Commit 554bb262 authored by rprenger

Keep the text generation server from dying when the input prompts are too long

parent a3770921
@@ -113,6 +113,7 @@ def generate(model,
     context_tokens_tensor, context_length_tensor = tokenize_prompts(
         prompts=prompts, tokens_to_generate=tokens_to_generate, add_BOS=add_BOS)
     if just_score:
         return score_and_return_on_first_stage(
             model, context_tokens_tensor, context_length_tensor)
...
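For context, callers reach this scoring short-circuit through the `generate()` entry point shown in the hunk header. A minimal call sketch follows; the import path and the exact keyword set are assumptions inferred from the keywords visible in this diff (`prompts`, `tokens_to_generate`, `add_BOS`, `just_score`), not a confirmed public API:

```python
# Sketch only: import path and signature are assumed from the diff above.
from megatron.text_generation.api import generate

log_probs = generate(
    model,                                  # an initialized Megatron model
    prompts=["The capital of France is"],   # raw text prompts
    tokens_to_generate=0,                   # score only; sample nothing new
    add_BOS=False,
    just_score=True,                        # take the scoring fast path
)
```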
@@ -131,6 +131,10 @@ def generate_tokens_probs_and_return_on_first_stage(
     max_sequence_length = tokens.size(1)
     max_sequence_length = min(max_sequence_length, args.max_position_embeddings)
+    # If even the shortest prompt already fills the maximum sequence length,
+    # there is no room to generate: fail fast instead of crashing downstream.
+    if min_prompt_length >= max_sequence_length:
+        raise ValueError("prompt length + tokens_to_generate exceeds max_sequence_length")
     # forward step.
     forward_step = ForwardStep(model, batch_size, max_sequence_length)
...
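Pulled out of the surrounding function, the new guard amounts to the sketch below. It assumes `tokens` is padded to the longest prompt plus `tokens_to_generate`, and that `context_length_tensor` holds each prompt's unpadded length; both are inferences from the lines above, not confirmed internals:

```python
import torch

def assert_room_to_generate(tokens: torch.Tensor,
                            context_length_tensor: torch.Tensor,
                            max_position_embeddings: int) -> int:
    """Sketch of the guard above: refuse to run when even the shortest
    prompt already fills the model's maximum sequence length."""
    # tokens is assumed [batch, max_prompt_len + tokens_to_generate]
    max_sequence_length = min(tokens.size(1), max_position_embeddings)
    min_prompt_length = context_length_tensor.min().item()
    if min_prompt_length >= max_sequence_length:
        raise ValueError("prompt length + tokens_to_generate exceeds "
                         "max_sequence_length")
    return max_sequence_length
```

For a model with `max_position_embeddings = 2048`, a batch containing a 2048-token prompt trips the guard regardless of `tokens_to_generate`, since the prompt alone already occupies every available position.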
@@ -36,9 +36,6 @@ class MegatronGenerate(Resource):
     def put(self):
         args = get_args()
-        print("request IP: " + str(request.remote_addr))
-        print(json.dumps(request.get_json()),flush=True)
-        print("current time: ", datetime.datetime.now())
         if not "prompts" in request.get_json():
             return "prompts argument required", 400
@@ -106,7 +103,11 @@ class MegatronGenerate(Resource):
             return "add_BOS must be a boolean value"
         with lock:  # Need to get lock to keep multiple threads from hitting code
+            print("request IP: " + str(request.remote_addr))
+            print(json.dumps(request.get_json()),flush=True)
+            print("start time: ", datetime.datetime.now())
             MegatronGenerate.send_do_generate()  # Tell other ranks we're doing generate
+            try:
+                response, response_seg, response_logprobs, _ = \
+                    generate_and_post_process(
+                        self.model,
@@ -119,6 +120,9 @@ class MegatronGenerate(Resource):
+                        add_BOS=add_BOS,
+                        use_eod_token_for_early_termination=True,
+                        just_score=just_score)
+            except ValueError as ve:
+                return "Length of prompt + tokens_to_generate longer than allowed"
+            print("end time: ", datetime.datetime.now())
             return jsonify({"text": response,
                             "segments": response_seg,
...
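On the client side, this changes the failure mode from a crashed server to a plain-text error body. A hedged example using `requests` follows; the host, port, and `/api` route are assumptions about the deployment, and `tokens_to_generate` as a request field is inferred from the error message rather than shown in this diff. Note that because the `except` branch returns the string without an explicit status code, Flask still replies 200 OK, so the client has to sniff the body rather than the status:

```python
import requests

URL = "http://localhost:5000/api"   # host, port, and route are assumptions

resp = requests.put(URL, json={
    "prompts": ["some very long prompt ..."],
    "tokens_to_generate": 128,
})

try:
    data = resp.json()              # success: {"text": ..., "segments": ...}
    print(data["text"])
except ValueError:
    # Failure: the handler returned the plain-text length-error string,
    # which is not JSON, so .json() raises and we land here.
    print("server rejected request:", resp.text)
```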
@@ -78,4 +78,7 @@ if __name__ == "__main__":
     choice = torch.cuda.LongTensor(1)
     torch.distributed.broadcast(choice, 0)
     if choice[0].item() == 0:
-        generate_and_post_process(model)
+        try:
+            generate_and_post_process(model)
+        except ValueError as ve:
+            pass
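The bare `except`/`pass` is deliberate: every rank enters `generate_and_post_process()` in lockstep after the broadcast, so when the length guard fires it raises on all ranks at once. Rank 0's HTTP handler (previous hunk) reports the error to the client; the worker ranks only need to survive and loop back for the next broadcast. An annotated sketch of that loop, where the enclosing `while` loop, the distributed setup, and the meaning of `choice == 0` are inferred from context rather than shown in this diff:

```python
# Hedged sketch of the worker-rank loop this hunk patches. Assumes
# torch.distributed is already initialized and `model` is loaded.
import torch

while True:
    choice = torch.cuda.LongTensor(1)
    torch.distributed.broadcast(choice, 0)    # rank 0 announces the next op
    if choice[0].item() == 0:                 # 0 assumed to mean "generate"
        try:
            generate_and_post_process(model)  # raises if prompts are too long
        except ValueError:
            pass  # rank 0 already reported the error over HTTP; keep serving
```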