Commit 8b686288 authored by Jared Casper's avatar Jared Casper
Browse files

Merge branch 'fixing_safety' into 'main'

Memory safety checks were incorrect for the tokens_to_generate=0 case

See merge request ADLR/megatron-lm!447
parents d63c2541 e38d41ca
...@@ -47,10 +47,15 @@ def score_and_return_on_first_stage(model, tokens, lengths): ...@@ -47,10 +47,15 @@ def score_and_return_on_first_stage(model, tokens, lengths):
batch_size = tokens.size(0) batch_size = tokens.size(0)
max_prompt_length = lengths.max().item() max_prompt_length = lengths.max().item()
assert max_prompt_length == tokens.size(1) assert max_prompt_length == tokens.size(1)
max_sequence_length = min(max_prompt_length, args.max_position_embeddings)
if max_prompt_length > args.max_position_embeddings:
raise ValueError("Length of prompt + tokens_to_generate longer than allowed")
if max_prompt_length * batch_size >= MAX_TOKENS_TO_OOM:
raise ValueError("Too many tokens. " + str(max_prompt_length*batch_size)+ " is greater than "+str(MAX_TOKENS_TO_OOM))
# forward step. # forward step.
forward_step = ForwardStep(model, batch_size, max_sequence_length) forward_step = ForwardStep(model, batch_size, max_prompt_length)
# =================== # ===================
# Pre-allocate memory # Pre-allocate memory
...@@ -58,7 +63,7 @@ def score_and_return_on_first_stage(model, tokens, lengths): ...@@ -58,7 +63,7 @@ def score_and_return_on_first_stage(model, tokens, lengths):
# Log probability of the sequence (prompt + generated tokens). # Log probability of the sequence (prompt + generated tokens).
output_log_probs = None output_log_probs = None
output_log_probs_size = (batch_size, max_sequence_length - 1) output_log_probs_size = (batch_size, max_prompt_length - 1)
if mpu.is_pipeline_last_stage(): if mpu.is_pipeline_last_stage():
output_log_probs = torch.empty(output_log_probs_size, output_log_probs = torch.empty(output_log_probs_size,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment