Commit 8d405805 authored by rprenger's avatar rprenger
Browse files

Changing defaults and query sanitation to keep it from crashing on reasonable queries

parent 0dd5cc75
...@@ -19,7 +19,6 @@ import threading ...@@ -19,7 +19,6 @@ import threading
from flask import Flask, request, jsonify, current_app from flask import Flask, request, jsonify, current_app
from flask_restful import Resource, Api from flask_restful import Resource, Api
from megatron import get_args from megatron import get_args
from megatron import mpu
from megatron.text_generation import generate_and_post_process from megatron.text_generation import generate_and_post_process
...@@ -68,7 +67,7 @@ class MegatronGenerate(Resource): ...@@ -68,7 +67,7 @@ class MegatronGenerate(Resource):
if not isinstance(logprobs, bool): if not isinstance(logprobs, bool):
return "logprobs must be a boolean value" return "logprobs must be a boolean value"
temperature = args.temperature temperature = 1.0
if "temperature" in request.get_json(): if "temperature" in request.get_json():
temperature = request.get_json()["temperature"] temperature = request.get_json()["temperature"]
if not (type(temperature) == int or type(temperature) == float): if not (type(temperature) == int or type(temperature) == float):
...@@ -76,7 +75,7 @@ class MegatronGenerate(Resource): ...@@ -76,7 +75,7 @@ class MegatronGenerate(Resource):
if not (0.0 < temperature <= 100.0): if not (0.0 < temperature <= 100.0):
return "temperature must be a positive number less than or equal to 100.0" return "temperature must be a positive number less than or equal to 100.0"
top_k = args.top_k top_k = 0.0
if "top_k" in request.get_json(): if "top_k" in request.get_json():
top_k = request.get_json()["top_k"] top_k = request.get_json()["top_k"]
if not (type(top_k) == int): if not (type(top_k) == int):
...@@ -84,11 +83,13 @@ class MegatronGenerate(Resource): ...@@ -84,11 +83,13 @@ class MegatronGenerate(Resource):
if not (0 < top_k <= 1000): if not (0 < top_k <= 1000):
return "top_k must be equal to or greater than 0 and less than or equal to 1000" return "top_k must be equal to or greater than 0 and less than or equal to 1000"
top_p = args.top_p top_p = 0.0
if "top_p" in request.get_json(): if "top_p" in request.get_json():
top_p = request.get_json()["top_p"] top_p = request.get_json()["top_p"]
if not (type(top_p) == float): if not (type(top_p) == float):
return "top_p must be a positive float less than or equal to 1.0" return "top_p must be a positive float less than or equal to 1.0"
if top_p > 0.0 and top_k > 0.0:
return "cannot set both top-k and top-p samplings."
if not (0 < top_p <= 1.0): if not (0 < top_p <= 1.0):
return "top_p must be less than or equal to 1.0" return "top_p must be less than or equal to 1.0"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment