Commit 8d405805 authored by rprenger's avatar rprenger
Browse files

Changing defaults and query sanitation to keep it from crashing on reasonable queries

parent 0dd5cc75
......@@ -19,7 +19,6 @@ import threading
from flask import Flask, request, jsonify, current_app
from flask_restful import Resource, Api
from megatron import get_args
from megatron import mpu
from megatron.text_generation import generate_and_post_process
......@@ -68,7 +67,7 @@ class MegatronGenerate(Resource):
if not isinstance(logprobs, bool):
return "logprobs must be a boolean value"
temperature = args.temperature
temperature = 1.0
if "temperature" in request.get_json():
temperature = request.get_json()["temperature"]
if not (type(temperature) == int or type(temperature) == float):
......@@ -76,7 +75,7 @@ class MegatronGenerate(Resource):
if not (0.0 < temperature <= 100.0):
return "temperature must be a positive number less than or equal to 100.0"
top_k = args.top_k
top_k = 0.0
if "top_k" in request.get_json():
top_k = request.get_json()["top_k"]
if not (type(top_k) == int):
......@@ -84,11 +83,13 @@ class MegatronGenerate(Resource):
if not (0 < top_k <= 1000):
return "top_k must be equal to or greater than 0 and less than or equal to 1000"
top_p = args.top_p
top_p = 0.0
if "top_p" in request.get_json():
top_p = request.get_json()["top_p"]
if not (type(top_p) == float):
return "top_p must be a positive float less than or equal to 1.0"
if top_p > 0.0 and top_k > 0.0:
return "cannot set both top-k and top-p samplings."
if not (0 < top_p <= 1.0):
return "top_p must be less than or equal to 1.0"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment