Commit bdbb30fb authored by Jared Casper's avatar Jared Casper
Browse files

Merge branch 'beam_search' into 'main'

Fix bugs for beam search when using pipeline parallelization

See merge request ADLR/megatron-lm!426
parents 3f4e71df 2e65bee5
......@@ -300,10 +300,12 @@ def beam_search_and_return_on_first_stage(model, tokens, lengths, beam_size, sto
forward_step = ForwardStep(model, beam_size, final_sequence_length)
beam_hyp = BeamHypotheses(beam_size, length_penalty)
done = False
best_batches = None
done = torch.zeros(1, dtype=torch.uint8, device=torch.cuda.current_device())
scores = torch.zeros(beam_size,
dtype=torch.float32,
device=torch.cuda.current_device()).unsqueeze(1)
scores_size_tensor, tokens_size_tensor = None, None
# =============
# Run infernece
# =============
......@@ -358,28 +360,31 @@ def beam_search_and_return_on_first_stage(model, tokens, lengths, beam_size, sto
break
if beam_hyp.is_done(best_scores.max().item(), context_length + 1 - prompt_length):
done = True
break
done = torch.ones(1, dtype=torch.uint8, device=torch.cuda.current_device())
best_batches = tokens.new([item[2] for item in next_beams])
tokens = tokens[best_batches,:]
tokens[:, context_length] = tokens.new([item[0] for item in next_beams])
scores = scores.new([item[1] for item in next_beams]).unsqueeze(1)
# set inference key values to make it consistent with best beam index
forward_step.inference_params.swap_key_value_dict(best_batches)
# torch.distributed.barrier()
done = broadcast_from_last_pipeline_stage(1, torch.uint8, done)
if done:
break
# Update the tokens on the first stage so the next input to
# the network is correct.
copy_from_last_to_first_pipeline_stage(batch_size, torch.int64,
tokens[:, context_length])
copy_from_last_to_first_pipeline_stage(tokens.size(), torch.int64,
tokens)
# set inference key values to make it consistent with best beam index
best_batches = broadcast_from_last_pipeline_stage(beam_size, torch.int64, best_batches)
forward_step.inference_params.swap_key_value_dict(best_batches)
# Update the context length for the next token generation.
prev_context_length = context_length
copy_from_last_to_first_pipeline_stage(scores.size(0), torch.float32,
scores[:,0])
if mpu.is_pipeline_last_stage():
# if cannot find stop token, add open beams to hyps
if not done:
for beam_id in range(beam_size):
......@@ -392,6 +397,14 @@ def beam_search_and_return_on_first_stage(model, tokens, lengths, beam_size, sto
tokens = [sorted_hyps[i][1] for i in range(num_return_gen)]
scores = torch.stack(scores, dim=0)
tokens = torch.stack(tokens, dim=0)
scores_size_tensor = torch.tensor(scores.shape, dtype=torch.int64, device=torch.cuda.current_device())
tokens_size_tensor = torch.tensor(tokens.shape, dtype=torch.int64, device=torch.cuda.current_device())
scores_size_tensor = broadcast_from_last_pipeline_stage(1, torch.int64, scores_size_tensor)
tokens_size_tensor = broadcast_from_last_pipeline_stage(2, torch.int64, tokens_size_tensor)
scores = broadcast_from_last_to_first_pipeline_stage(tuple(scores_size_tensor), torch.float32, scores)
tokens = broadcast_from_last_to_first_pipeline_stage(tuple(tokens_size_tensor), torch.int64, tokens)
return tokens, scores
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment