Unverified Commit ed4df855 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

fix beam search bug in tf as well (#4745)

parent 1b5820a5
......@@ -1218,7 +1218,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
continue
# test that beam scores match previously calculated scores if not eos and batch_idx not done
if eos_token_id is not None and all(
(token_id % vocab_size).numpy().item() is not eos_token_id for token_id in next_tokens[batch_idx]
(token_id % vocab_size).numpy().item() != eos_token_id for token_id in next_tokens[batch_idx]
):
assert tf.reduce_all(
next_scores[batch_idx, :num_beams] == tf.reshape(beam_scores, (batch_size, num_beams))[batch_idx]
......
......@@ -1528,7 +1528,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
# test that beam scores match previously calculated scores if not eos and batch_idx not done
if eos_token_id is not None and all(
(token_id % vocab_size).item() is not eos_token_id for token_id in next_tokens[batch_idx]
(token_id % vocab_size).item() != eos_token_id for token_id in next_tokens[batch_idx]
):
assert torch.all(
next_scores[batch_idx, :num_beams] == beam_scores.view(batch_size, num_beams)[batch_idx]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment