Commit 83d78670 authored by Jared Casper
Browse files

Merge branch 'beam_search' into 'main'

fix a bug for size mismatch

See merge request ADLR/megatron-lm!438
parents 189e72a7 b7b2d6a9
...@@ -395,7 +395,7 @@ def beam_search_and_return_on_first_stage(model, tokens, lengths, beam_size, sto ...@@ -395,7 +395,7 @@ def beam_search_and_return_on_first_stage(model, tokens, lengths, beam_size, sto
# if cannot find stop token, add open beams to hyps # if cannot find stop token, add open beams to hyps
if not done: if not done:
for beam_id in range(beam_size): for beam_id in range(beam_size):
beam_hyp.add(tokens[beam_id].clone(), scores[beam_id], context_length + 1 - prompt_length) beam_hyp.add(tokens[beam_id].clone(), scores[beam_id].squeeze(), context_length + 1 - prompt_length)
# rank based on scores # rank based on scores
sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0], reverse=True) sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0], reverse=True)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment