"git@developer.sourcefind.cn:OpenDAS/mmcv.git" did not exist on "362a90f8bfffe62d5802925944f540ed16b2731e"
Commit cfe2c2be authored by rprenger
Browse files

Adding checks for total number of tokens to keep server from crashing

parent 5f694372
...@@ -28,6 +28,8 @@ from .forward_step import ForwardStep ...@@ -28,6 +28,8 @@ from .forward_step import ForwardStep
from .sampling import sample from .sampling import sample
from .beam_utils import BeamHypotheses from .beam_utils import BeamHypotheses
MAX_TOKENS_TO_OOM = 12000 # (rprenger) Perfect value depends on hardware and network
def score_and_return_on_first_stage(model, tokens, lengths): def score_and_return_on_first_stage(model, tokens, lengths):
"""Function for just scoring. """Function for just scoring.
Arguments: Arguments:
...@@ -133,11 +135,12 @@ def generate_tokens_probs_and_return_on_first_stage( ...@@ -133,11 +135,12 @@ def generate_tokens_probs_and_return_on_first_stage(
batch_size = tokens.size(0) batch_size = tokens.size(0)
min_prompt_length = lengths.min().item() min_prompt_length = lengths.min().item()
max_sequence_length = tokens.size(1) max_sequence_length = tokens.size(1)
max_sequence_length = min(max_sequence_length, args.max_position_embeddings)
if max_sequence_length > args.max_position_embeddings:
raise ValueError("Length of prompt + tokens_to_generate longer than allowed")
# If the context is too big, this happens if max_sequence_length * batch_size >= MAX_TOKENS_TO_OOM:
if min_prompt_length >= max_sequence_length: raise ValueError("Too many tokens. " + str(max_sequence_length*batch_size)+ " is greater than "+str(MAX_TOKENS_TO_OOM))
raise ValueError("context length + tokens_to_generate too large")
# forward step. # forward step.
forward_step = ForwardStep(model, batch_size, max_sequence_length) forward_step = ForwardStep(model, batch_size, max_sequence_length)
......
...@@ -204,7 +204,7 @@ class MegatronGenerate(Resource): ...@@ -204,7 +204,7 @@ class MegatronGenerate(Resource):
"logprobs": response_logprobs}) "logprobs": response_logprobs})
except ValueError as ve: except ValueError as ve:
return "Length of prompt + tokens_to_generate longer than allowed" return ve.args[0]
print("end time: ", datetime.datetime.now()) print("end time: ", datetime.datetime.now())
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment