Commit 554bb262 authored by rprenger

Code that keeps it from dying when the input prompts are too long

parent a3770921
@@ -113,10 +113,11 @@ def generate(model,
     context_tokens_tensor, context_length_tensor = tokenize_prompts(
         prompts=prompts, tokens_to_generate=tokens_to_generate, add_BOS=add_BOS)
     if just_score:
         return score_and_return_on_first_stage(
             model, context_tokens_tensor, context_length_tensor)
     # Main inference function.
     # Note that the outputs are available on the first stage.
     return generate_tokens_probs_and_return_on_first_stage(
...
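For reference, the length check added in the next hunk leans on the shape contract of tokenize_prompts: each prompt row in context_tokens_tensor is padded out to the longest prompt plus tokens_to_generate, and context_length_tensor holds the unpadded prompt lengths. A rough standalone sketch of that contract (fake_tokenize_prompts is a made-up stand-in that splits on whitespace instead of running the real tokenizer):

import torch

def fake_tokenize_prompts(prompts, tokens_to_generate, pad_id=0):
    # Stand-in for tokenize_prompts: whitespace "tokens" instead of BPE, but the
    # same shapes -- every row is padded to max prompt length + tokens_to_generate,
    # and lengths records each prompt's real length.
    tokenized = [[hash(w) % 1000 for w in p.split()] for p in prompts]
    lengths = torch.tensor([len(t) for t in tokenized], dtype=torch.long)
    width = int(lengths.max()) + tokens_to_generate
    tokens = torch.full((len(prompts), width), pad_id, dtype=torch.long)
    for i, t in enumerate(tokenized):
        tokens[i, :len(t)] = torch.tensor(t, dtype=torch.long)
    return tokens, lengths

tokens, lengths = fake_tokenize_prompts(
    ["hello world", "a somewhat longer prompt"], tokens_to_generate=4)
print(tokens.shape, lengths)  # torch.Size([2, 8]) tensor([2, 4])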
@@ -130,6 +130,10 @@ def generate_tokens_probs_and_return_on_first_stage(
     min_prompt_length = lengths.min().item()
     max_sequence_length = tokens.size(1)
     max_sequence_length = min(max_sequence_length, args.max_position_embeddings)
+
+    # If the context is too big, this happens
+    if min_prompt_length >= max_sequence_length:
+        raise ValueError
     # forward step.
     forward_step = ForwardStep(model, batch_size, max_sequence_length)
...
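The new guard fires only when even the shortest prompt in the batch already fills the whole sequence budget, i.e. after clamping to max_position_embeddings there is no room left to generate a single token, which is presumably the failure the commit message calls "dying". A small arithmetic sketch with made-up numbers (2048, 100 and the prompt lengths are all illustrative):

# All numbers below are illustrative.
max_position_embeddings = 2048          # args.max_position_embeddings
tokens_to_generate = 100
prompt_lengths = [2050, 2100]           # what `lengths` holds per prompt

padded_width = max(prompt_lengths) + tokens_to_generate           # tokens.size(1) == 2200
max_sequence_length = min(padded_width, max_position_embeddings)  # clamped to 2048
min_prompt_length = min(prompt_lengths)                           # 2050

# 2050 >= 2048: no position left to generate into, so the patched code raises
# ValueError up front instead of running the sampling loop off the end.
print(min_prompt_length >= max_sequence_length)  # True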
@@ -36,9 +36,6 @@ class MegatronGenerate(Resource):
     def put(self):
         args = get_args()
-        print("request IP: " + str(request.remote_addr))
-        print(json.dumps(request.get_json()),flush=True)
-        print("current time: ", datetime.datetime.now())
         if not "prompts" in request.get_json():
             return "prompts argument required", 400
@@ -106,19 +103,26 @@ class MegatronGenerate(Resource):
             return "add_BOS must be a boolean value"
         with lock:  # Need to get lock to keep multiple threads from hitting code
+            print("request IP: " + str(request.remote_addr))
+            print(json.dumps(request.get_json()),flush=True)
+            print("start time: ", datetime.datetime.now())
             MegatronGenerate.send_do_generate()  # Tell other ranks we're doing generate
-            response, response_seg, response_logprobs, _ = \
-                generate_and_post_process(
-                self.model,
-                prompts=prompts,
-                tokens_to_generate=tokens_to_generate,
-                return_output_log_probs=logprobs,
-                top_k_sampling=top_k,
-                top_p_sampling=top_p,
-                temperature=temperature,
-                add_BOS=add_BOS,
-                use_eod_token_for_early_termination=True,
-                just_score=just_score)
+            try:
+                response, response_seg, response_logprobs, _ = \
+                    generate_and_post_process(
+                    self.model,
+                    prompts=prompts,
+                    tokens_to_generate=tokens_to_generate,
+                    return_output_log_probs=logprobs,
+                    top_k_sampling=top_k,
+                    top_p_sampling=top_p,
+                    temperature=temperature,
+                    add_BOS=add_BOS,
+                    use_eod_token_for_early_termination=True,
+                    just_score=just_score)
+            except ValueError as ve:
+                return "Length of prompt + tokens_to_generate longer than allowed"
+            print("end time: ", datetime.datetime.now())
         return jsonify({"text": response,
                         "segments": response_seg,
...
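On the server side the same ValueError is caught inside the lock and turned into an error response, so one oversized request no longer takes the whole server down. Below is a stripped-down stand-in for that pattern (a plain Flask route with a fake_generate placeholder, not Megatron's actual MegatronGenerate resource); the sketch returns an explicit 400 to match the other validation returns in the handler, whereas the diff returns the bare string:

from flask import Flask, request, jsonify

app = Flask(__name__)

def fake_generate(prompts, max_position_embeddings=2048):
    # Crude stand-in for generate_and_post_process: raise the same ValueError the
    # patched generation code raises when a prompt already fills the whole budget.
    if min(len(p.split()) for p in prompts) >= max_position_embeddings:
        raise ValueError
    return [p + " ..." for p in prompts]

@app.route("/api", methods=["PUT"])
def put_generate():
    payload = request.get_json()
    if "prompts" not in payload:
        return "prompts argument required", 400
    try:
        text = fake_generate(payload["prompts"])
    except ValueError:
        # Explicit 400 here; the diff itself returns just the string.
        return "Length of prompt + tokens_to_generate longer than allowed", 400
    return jsonify({"text": text})

if __name__ == "__main__":
    app.run(port=5000)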
@@ -78,4 +78,7 @@ if __name__ == "__main__":
         choice = torch.cuda.LongTensor(1)
         torch.distributed.broadcast(choice, 0)
         if choice[0].item() == 0:
-            generate_and_post_process(model)
+            try:
+                generate_and_post_process(model)
+            except ValueError as ve:
+                pass
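This last hunk matters because every rank enters generate_and_post_process together: rank 0 runs the HTTP server, while the remaining ranks sit in this broadcast loop. When a prompt is too long they all raise the same ValueError; rank 0 now turns it into an error response, so the worker ranks have to swallow it too and go back to waiting, otherwise they would die and the next request would hang on the broadcast. A toy reproduction of that pattern with gloo on CPU (do_generate, the port, and the two simulated "requests" are all made up):

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def do_generate(rank, too_long):
    # Stand-in for generate_and_post_process: every rank hits the same failure.
    if too_long:
        raise ValueError
    print(f"rank {rank}: generated")

def worker(rank, world_size):
    dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29512",
                            rank=rank, world_size=world_size)
    for too_long in (True, False):        # two "requests"; the first is oversized
        choice = torch.zeros(1, dtype=torch.long)   # 0 == "do generate"
        dist.broadcast(choice, src=0)     # rank 0 tells the other ranks what to do
        if choice[0].item() == 0:
            try:
                do_generate(rank, too_long)
            except ValueError:
                pass                      # stay alive and wait for the next request
    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)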