You need to sign in or sign up before continuing.
Commit 8d405805 authored by rprenger
Browse files

Changing defaults and query sanitation to keep it from crashing on reasonable queries

parent 0dd5cc75
...@@ -19,7 +19,6 @@ import threading ...@@ -19,7 +19,6 @@ import threading
from flask import Flask, request, jsonify, current_app from flask import Flask, request, jsonify, current_app
from flask_restful import Resource, Api from flask_restful import Resource, Api
from megatron import get_args from megatron import get_args
from megatron import mpu
from megatron.text_generation import generate_and_post_process from megatron.text_generation import generate_and_post_process
...@@ -68,7 +67,7 @@ class MegatronGenerate(Resource): ...@@ -68,7 +67,7 @@ class MegatronGenerate(Resource):
if not isinstance(logprobs, bool): if not isinstance(logprobs, bool):
return "logprobs must be a boolean value" return "logprobs must be a boolean value"
temperature = args.temperature temperature = 1.0
if "temperature" in request.get_json(): if "temperature" in request.get_json():
temperature = request.get_json()["temperature"] temperature = request.get_json()["temperature"]
if not (type(temperature) == int or type(temperature) == float): if not (type(temperature) == int or type(temperature) == float):
...@@ -76,7 +75,7 @@ class MegatronGenerate(Resource): ...@@ -76,7 +75,7 @@ class MegatronGenerate(Resource):
if not (0.0 < temperature <= 100.0): if not (0.0 < temperature <= 100.0):
return "temperature must be a positive number less than or equal to 100.0" return "temperature must be a positive number less than or equal to 100.0"
top_k = args.top_k top_k = 0.0
if "top_k" in request.get_json(): if "top_k" in request.get_json():
top_k = request.get_json()["top_k"] top_k = request.get_json()["top_k"]
if not (type(top_k) == int): if not (type(top_k) == int):
...@@ -84,11 +83,13 @@ class MegatronGenerate(Resource): ...@@ -84,11 +83,13 @@ class MegatronGenerate(Resource):
if not (0 < top_k <= 1000): if not (0 < top_k <= 1000):
return "top_k must be equal to or greater than 0 and less than or equal to 1000" return "top_k must be equal to or greater than 0 and less than or equal to 1000"
top_p = args.top_p top_p = 0.0
if "top_p" in request.get_json(): if "top_p" in request.get_json():
top_p = request.get_json()["top_p"] top_p = request.get_json()["top_p"]
if not (type(top_p) == float): if not (type(top_p) == float):
return "top_p must be a positive float less than or equal to 1.0" return "top_p must be a positive float less than or equal to 1.0"
if top_p > 0.0 and top_k > 0.0:
return "cannot set both top-k and top-p samplings."
if not (0 < top_p <= 1.0): if not (0 < top_p <= 1.0):
return "top_p must be less than or equal to 1.0" return "top_p must be less than or equal to 1.0"
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment