Commit bdbb30fb authored by Jared Casper

Merge branch 'beam_search' into 'main'

Fix bugs for beam search when using pipeline parallelization

See merge request ADLR/megatron-lm!426
parents 3f4e71df 2e65bee5
@@ -300,10 +300,12 @@ def beam_search_and_return_on_first_stage(model, tokens, lengths, beam_size, sto
     forward_step = ForwardStep(model, beam_size, final_sequence_length)
 
     beam_hyp = BeamHypotheses(beam_size, length_penalty)
-    done = False
+    best_batches = None
+    done = torch.zeros(1, dtype=torch.uint8, device=torch.cuda.current_device())
     scores = torch.zeros(beam_size,
                          dtype=torch.float32,
                          device=torch.cuda.current_device()).unsqueeze(1)
+    scores_size_tensor, tokens_size_tensor = None, None
     # =============
     # Run infernece
     # =============
@@ -358,40 +360,51 @@ def beam_search_and_return_on_first_stage(model, tokens, lengths, beam_size, sto
                         break
 
                 if beam_hyp.is_done(best_scores.max().item(), context_length + 1 - prompt_length):
-                    done = True
-                    break
+                    done = torch.ones(1, dtype=torch.uint8, device=torch.cuda.current_device())
 
                 best_batches = tokens.new([item[2] for item in next_beams])
                 tokens = tokens[best_batches,:]
                 tokens[:, context_length] = tokens.new([item[0] for item in next_beams])
                 scores = scores.new([item[1] for item in next_beams]).unsqueeze(1)
 
-                # set inference key values to make it consistent with best beam index
-                forward_step.inference_params.swap_key_value_dict(best_batches)
+            # torch.distributed.barrier()
+            done = broadcast_from_last_pipeline_stage(1, torch.uint8, done)
+            if done:
+                break
 
             # Update the tokens on the first stage so the next input to
             # the network is correct.
-            copy_from_last_to_first_pipeline_stage(batch_size, torch.int64,
-                                                   tokens[:, context_length])
+            copy_from_last_to_first_pipeline_stage(tokens.size(), torch.int64,
+                                                   tokens)
+
+            # set inference key values to make it consistent with best beam index
+            best_batches = broadcast_from_last_pipeline_stage(beam_size, torch.int64, best_batches)
+            forward_step.inference_params.swap_key_value_dict(best_batches)
 
             # Update the context length for the next token generation.
             prev_context_length = context_length
 
-        copy_from_last_to_first_pipeline_stage(scores.size(0), torch.float32,
-                                               scores[:,0])
-
-        # if cannot find stop token, add open beams to hyps
-        if not done:
-            for beam_id in range(beam_size):
-                beam_hyp.add(tokens[beam_id].clone(), scores[beam_id], context_length + 1 - prompt_length)
-
-        # rank based on scores
-        sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0], reverse=True)
-        num_return_gen = min(num_return_gen, len(sorted_hyps))
-        scores = [sorted_hyps[i][0] for i in range(num_return_gen)]
-        tokens = [sorted_hyps[i][1] for i in range(num_return_gen)]
-        scores = torch.stack(scores, dim=0)
-        tokens = torch.stack(tokens, dim=0)
+        if mpu.is_pipeline_last_stage():
+            # if cannot find stop token, add open beams to hyps
+            if not done:
+                for beam_id in range(beam_size):
+                    beam_hyp.add(tokens[beam_id].clone(), scores[beam_id], context_length + 1 - prompt_length)
+
+            # rank based on scores
+            sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0], reverse=True)
+            num_return_gen = min(num_return_gen, len(sorted_hyps))
+            scores = [sorted_hyps[i][0] for i in range(num_return_gen)]
+            tokens = [sorted_hyps[i][1] for i in range(num_return_gen)]
+            scores = torch.stack(scores, dim=0)
+            tokens = torch.stack(tokens, dim=0)
+            scores_size_tensor = torch.tensor(scores.shape, dtype=torch.int64, device=torch.cuda.current_device())
+            tokens_size_tensor = torch.tensor(tokens.shape, dtype=torch.int64, device=torch.cuda.current_device())
+
+        scores_size_tensor = broadcast_from_last_pipeline_stage(1, torch.int64, scores_size_tensor)
+        tokens_size_tensor = broadcast_from_last_pipeline_stage(2, torch.int64, tokens_size_tensor)
+
+        scores = broadcast_from_last_to_first_pipeline_stage(tuple(scores_size_tensor), torch.float32, scores)
+        tokens = broadcast_from_last_to_first_pipeline_stage(tuple(tokens_size_tensor), torch.int64, tokens)
 
     return tokens, scores
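Background on the pattern the fix relies on: with pipeline parallelism, only the last stage sees the logits, so it alone knows when beam search has terminated, which beam indices survived, and what shape the final scores/tokens tensors have. The diff therefore broadcasts a uint8 done flag, the best beam indices, and the result tensor shapes from the last pipeline stage so that every stage exits the generation loop on the same step and allocates correctly sized buffers. The sketch below illustrates that synchronization idea with plain torch.distributed; the helper names broadcast_done_flag and broadcast_result_shape and the last_rank argument are illustrative stand-ins, not Megatron-LM's broadcast_from_last_pipeline_stage / broadcast_from_last_to_first_pipeline_stage utilities.

import torch
import torch.distributed as dist


def broadcast_done_flag(is_done, last_rank, device):
    """Share a termination flag computed on the last pipeline stage with all ranks."""
    flag = torch.zeros(1, dtype=torch.uint8, device=device)
    if dist.get_rank() == last_rank:
        flag.fill_(int(is_done))
    # Collective call: every rank participates, and afterwards every rank holds
    # the same value, so all stages leave the generation loop on the same step.
    dist.broadcast(flag, src=last_rank)
    return bool(flag.item())


def broadcast_result_shape(result, ndim, last_rank, device):
    """Share the shape of a tensor that only the last stage has materialized."""
    shape = torch.zeros(ndim, dtype=torch.int64, device=device)
    if dist.get_rank() == last_rank:
        shape.copy_(torch.tensor(result.shape, dtype=torch.int64, device=device))
    dist.broadcast(shape, src=last_rank)
    # Other ranks can now allocate a buffer of this size and receive the data.
    return torch.Size(shape.tolist())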