Commit e5034150 authored by peng xu's avatar peng xu
Browse files

fix pipeline parallel for beam search

parent da11c982
......@@ -300,10 +300,12 @@ def beam_search_and_return_on_first_stage(model, tokens, lengths, beam_size, sto
forward_step = ForwardStep(model, beam_size, final_sequence_length)
beam_hyp = BeamHypotheses(beam_size, length_penalty)
done = False
best_batches = None
done = torch.zeros(1, dtype=torch.uint8, device=torch.cuda.current_device())
scores = torch.zeros(beam_size,
dtype=torch.float32,
device=torch.cuda.current_device()).unsqueeze(1)
scores_size_tensor, tokens_size_tensor = None, None
# =============
# Run infernece
# =============
......@@ -321,6 +323,10 @@ def beam_search_and_return_on_first_stage(model, tokens, lengths, beam_size, sto
# logits will be meanigful only in the last pipeline stage.
logits = forward_step(tokens2use, positions2use, attention_mask2use)
# if mpu.is_pipeline_first_stage():
# print('-' * 40)
# print(tokens[:, context_length-5:context_length+5])
# print(context_length)
if mpu.is_pipeline_last_stage():
vocab_size = logits.size(2)
......@@ -335,6 +341,10 @@ def beam_search_and_return_on_first_stage(model, tokens, lengths, beam_size, sto
best_beam_ids = torch.div(indices[: 2 * beam_size], vocab_size).trunc().long()
best_words = indices[:2 * beam_size] % vocab_size
best_scores = sorted_scores[: 2 * beam_size]
# print('*' * 40)
# print(best_beam_ids)
# print(best_words)
# print(context_length)
next_beams = []
for beam_token_rank, (token_id, beam_score, beam_id) in enumerate(
......@@ -358,40 +368,53 @@ def beam_search_and_return_on_first_stage(model, tokens, lengths, beam_size, sto
break
if beam_hyp.is_done(best_scores.max().item(), context_length + 1 - prompt_length):
done = True
break
done = torch.ones(1, dtype=torch.uint8, device=torch.cuda.current_device())
print("find all hyp exiting")
best_batches = tokens.new([item[2] for item in next_beams])
tokens = tokens[best_batches,:]
tokens[:, context_length] = tokens.new([item[0] for item in next_beams])
scores = scores.new([item[1] for item in next_beams]).unsqueeze(1)
# set inference key values to make it consistent with best beam index
forward_step.inference_params.swap_key_value_dict(best_batches)
# torch.distributed.barrier()
done = broadcast_from_last_pipeline_stage(1, torch.uint8, done)
if done:
print("break for loop")
break
# Update the tokens on the first stage so the next input to
# the network is correct.
copy_from_last_to_first_pipeline_stage(batch_size, torch.int64,
tokens[:, context_length])
copy_from_last_to_first_pipeline_stage(tokens.size(), torch.int64,
tokens)
# set inference key values to make it consistent with best beam index
best_batches = broadcast_from_last_pipeline_stage(beam_size, torch.int64, best_batches)
forward_step.inference_params.swap_key_value_dict(best_batches)
# Update the context length for the next token generation.
prev_context_length = context_length
copy_from_last_to_first_pipeline_stage(scores.size(0), torch.float32,
scores[:,0])
# if cannot find stop token, add open beams to hyps
if not done:
for beam_id in range(beam_size):
beam_hyp.add(tokens[beam_id].clone(), scores[beam_id], context_length + 1 - prompt_length)
# rank based on scores
sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0], reverse=True)
num_return_gen = min(num_return_gen, len(sorted_hyps))
scores = [sorted_hyps[i][0] for i in range(num_return_gen)]
tokens = [sorted_hyps[i][1] for i in range(num_return_gen)]
scores = torch.stack(scores, dim=0)
tokens = torch.stack(tokens, dim=0)
if mpu.is_pipeline_last_stage():
# if cannot find stop token, add open beams to hyps
if not done:
for beam_id in range(beam_size):
beam_hyp.add(tokens[beam_id].clone(), scores[beam_id], context_length + 1 - prompt_length)
# rank based on scores
sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0], reverse=True)
num_return_gen = min(num_return_gen, len(sorted_hyps))
scores = [sorted_hyps[i][0] for i in range(num_return_gen)]
tokens = [sorted_hyps[i][1] for i in range(num_return_gen)]
scores = torch.stack(scores, dim=0)
tokens = torch.stack(tokens, dim=0)
scores_size_tensor = torch.tensor(scores.shape, dtype=torch.int64, device=torch.cuda.current_device())
tokens_size_tensor = torch.tensor(tokens.shape, dtype=torch.int64, device=torch.cuda.current_device())
scores_size_tensor = broadcast_from_last_pipeline_stage(1, torch.int64, scores_size_tensor)
tokens_size_tensor = broadcast_from_last_pipeline_stage(2, torch.int64, tokens_size_tensor)
scores = broadcast_from_last_to_first_pipeline_stage(tuple(scores_size_tensor), torch.float32, scores)
tokens = broadcast_from_last_to_first_pipeline_stage(tuple(tokens_size_tensor), torch.int64, tokens)
return tokens, scores
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment