Unverified Commit 735968b6 authored by Boris Dayma's avatar Boris Dayma Committed by GitHub
Browse files

fix: sampling in flax keeps EOS (#28378)

parent 7e0ddf89
...@@ -716,8 +716,8 @@ class FlaxGenerationMixin: ...@@ -716,8 +716,8 @@ class FlaxGenerationMixin:
next_token = jax.random.categorical(prng_key, logits, axis=-1) next_token = jax.random.categorical(prng_key, logits, axis=-1)
next_token = next_token * ~state.is_sent_finished + pad_token_id * state.is_sent_finished
next_is_sent_finished = state.is_sent_finished | (next_token == eos_token_id) next_is_sent_finished = state.is_sent_finished | (next_token == eos_token_id)
next_token = next_token * ~next_is_sent_finished + pad_token_id * next_is_sent_finished
next_token = next_token[:, None] next_token = next_token[:, None]
next_sequences = lax.dynamic_update_slice(state.sequences, next_token, (0, state.cur_len)) next_sequences = lax.dynamic_update_slice(state.sequences, next_token, (0, state.cur_len))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment