Commit 0bb597b4 authored by Jared Casper

Merge branch 'oom' into 'main'

Adding checks for total number of tokens to keep server from crashing

See merge request ADLR/megatron-lm!428
parents 2bc3c1a4 cfe2c2be
@@ -28,6 +28,8 @@ from .forward_step import ForwardStep
 from .sampling import sample
 from .beam_utils import BeamHypotheses
+MAX_TOKENS_TO_OOM = 12000  # (rprenger) Perfect value depends on hardware and network
 def score_and_return_on_first_stage(model, tokens, lengths):
     """Function for just scoring.
     Arguments:
@@ -133,11 +135,12 @@ def generate_tokens_probs_and_return_on_first_stage(
     batch_size = tokens.size(0)
     min_prompt_length = lengths.min().item()
     max_sequence_length = tokens.size(1)
-    max_sequence_length = min(max_sequence_length, args.max_position_embeddings)
+    if max_sequence_length > args.max_position_embeddings:
+        raise ValueError("Length of prompt + tokens_to_generate longer than allowed")
     # If the context is too big, this happens
     if min_prompt_length >= max_sequence_length:
         raise ValueError("context length + tokens_to_generate too large")
+    if max_sequence_length * batch_size >= MAX_TOKENS_TO_OOM:
+        raise ValueError("Too many tokens. " + str(max_sequence_length*batch_size) + " is greater than " + str(MAX_TOKENS_TO_OOM))
     # forward step.
     forward_step = ForwardStep(model, batch_size, max_sequence_length)
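Taken together, the two new guards reject oversized requests before any forward pass is launched, turning a would-be GPU out-of-memory crash into a descriptive error. Below is a minimal, self-contained sketch of the same logic; `validate_request` and the example sizes are illustrative only, and the real code reads `max_position_embeddings` from Megatron's global `args` rather than taking it as a parameter.

```python
MAX_TOKENS_TO_OOM = 12000  # heuristic cap; the right value depends on hardware and network


def validate_request(batch_size, max_sequence_length, max_position_embeddings):
    """Stand-alone version of the checks added by this merge (illustrative)."""
    if max_sequence_length > max_position_embeddings:
        raise ValueError("Length of prompt + tokens_to_generate longer than allowed")
    if max_sequence_length * batch_size >= MAX_TOKENS_TO_OOM:
        raise ValueError("Too many tokens. " + str(max_sequence_length * batch_size)
                         + " is greater than " + str(MAX_TOKENS_TO_OOM))


# 4 prompts x 2048 tokens = 8192 total tokens: accepted
validate_request(batch_size=4, max_sequence_length=2048, max_position_embeddings=2048)

# 8 prompts x 2048 tokens = 16384 total tokens: rejected up front instead of OOM-ing the GPU
try:
    validate_request(batch_size=8, max_sequence_length=2048, max_position_embeddings=2048)
except ValueError as ve:
    print(ve.args[0])  # Too many tokens. 16384 is greater than 12000
```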
@@ -226,7 +226,7 @@ class MegatronGenerate(Resource):
                                "logprobs": response_logprobs})
         except ValueError as ve:
-            return "Length of prompt + tokens_to_generate longer than allowed"
+            return ve.args[0]
         print("end time: ", datetime.datetime.now())
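On the API side, the MegatronGenerate handler already catches ValueError; the change above makes it return the exception's own message, so a client that exceeds MAX_TOKENS_TO_OOM sees which limit was hit rather than a generic string. A rough sketch of that pattern, assuming a Flask-RESTful resource analogous to MegatronGenerate (`GenerateSketch` and `generate_fn` below are hypothetical stand-ins, not Megatron's actual API):

```python
import datetime

from flask import jsonify, request
from flask_restful import Resource


class GenerateSketch(Resource):
    """Minimal illustration of the error-handling pattern; not the real MegatronGenerate."""

    def __init__(self, generate_fn):
        # generate_fn stands in for Megatron's generation entry point.
        self.generate_fn = generate_fn

    def put(self):
        try:
            text, logprobs = self.generate_fn(request.get_json())
            result = jsonify({"text": text, "logprobs": logprobs})
        except ValueError as ve:
            # Return the message attached to the exception, e.g.
            # "Too many tokens. 16384 is greater than 12000",
            # instead of a hard-coded string.
            return ve.args[0]
        print("end time: ", datetime.datetime.now())
        return result
```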