Commit b7b2d6a9 authored by Peng Xu's avatar Peng Xu
Browse files

fix a bug for size mismatch

parent 2e65bee5
...@@ -388,7 +388,7 @@ def beam_search_and_return_on_first_stage(model, tokens, lengths, beam_size, sto ...@@ -388,7 +388,7 @@ def beam_search_and_return_on_first_stage(model, tokens, lengths, beam_size, sto
# if cannot find stop token, add open beams to hyps # if cannot find stop token, add open beams to hyps
if not done: if not done:
for beam_id in range(beam_size): for beam_id in range(beam_size):
beam_hyp.add(tokens[beam_id].clone(), scores[beam_id], context_length + 1 - prompt_length) beam_hyp.add(tokens[beam_id].clone(), scores[beam_id].squeeze(), context_length + 1 - prompt_length)
# rank based on scores # rank based on scores
sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0], reverse=True) sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0], reverse=True)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment