Commit 0bb597b4 authored by Jared Casper's avatar Jared Casper
Browse files

Merge branch 'oom' into 'main'

Adding checks for total number of tokens to keep server from crashing

See merge request ADLR/megatron-lm!428
parents 2bc3c1a4 cfe2c2be
@@ -28,6 +28,8 @@ from .forward_step import ForwardStep
 from .sampling import sample
 from .beam_utils import BeamHypotheses

+MAX_TOKENS_TO_OOM = 12000  # (rprenger) Perfect value depends on hardware and network
+
 def score_and_return_on_first_stage(model, tokens, lengths):
     """Function for just scoring.
     Arguments:
@@ -133,11 +135,12 @@ def generate_tokens_probs_and_return_on_first_stage(
     batch_size = tokens.size(0)
     min_prompt_length = lengths.min().item()
     max_sequence_length = tokens.size(1)
-    max_sequence_length = min(max_sequence_length, args.max_position_embeddings)

-    # If the context is too big, this happens
-    if min_prompt_length >= max_sequence_length:
-        raise ValueError("context length + tokens_to_generate too large")
+    if max_sequence_length > args.max_position_embeddings:
+        raise ValueError("Length of prompt + tokens_to_generate longer than allowed")
+
+    if max_sequence_length * batch_size >= MAX_TOKENS_TO_OOM:
+        raise ValueError("Too many tokens. " + str(max_sequence_length*batch_size) + " is greater than " + str(MAX_TOKENS_TO_OOM))

     # forward step.
     forward_step = ForwardStep(model, batch_size, max_sequence_length)
 ...
@@ -226,7 +226,7 @@ class MegatronGenerate(Resource):
                                 "logprobs": response_logprobs})
             except ValueError as ve:
-                return "Length of prompt + tokens_to_generate longer than allowed"
+                return ve.args[0]
             print("end time: ", datetime.datetime.now())
 ...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment