Unverified commit b6e06c84, authored by Boris Dayma, committed by GitHub

fix(flax): generate with logits processor/warper (#16231)

parent 1c1e377e
@@ -560,10 +560,10 @@ class FlaxGenerationMixin:
 # apply min_length, ...
 logits = logits_processor(state.sequences, logits, state.cur_len)
-# apply top_k, top_k, temperature
+# apply top_p, top_k, temperature
 logits = logits_warper(logits, logits, state.cur_len)
-next_token = jax.random.categorical(prng_key, model_outputs.logits[:, -1], axis=-1)
+next_token = jax.random.categorical(prng_key, logits, axis=-1)
 next_is_sent_finished = state.is_sent_finished | (next_token == eos_token_id)
 next_token = next_token * ~next_is_sent_finished + pad_token_id * next_is_sent_finished
...
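The effect of the changed sampling line can be illustrated with a small, self-contained JAX sketch. It is not the library's code: `top_k_warp` is a hypothetical stand-in for a logits warper, and the logits values, `k=2`, and the 200 sample keys are made up for illustration. Sampling from the raw logits (as the old line did with `model_outputs.logits[:, -1]`) ignores the warper, while sampling from the warped `logits` respects it.

```python
import jax
import jax.numpy as jnp


def top_k_warp(logits, k):
    """Hypothetical stand-in for a logits warper: keep the k largest logits per row."""
    # jax.lax.top_k returns (values, indices) sorted in descending order.
    kth_largest = jax.lax.top_k(logits, k)[0][..., -1:]
    # Everything below the k-th largest logit gets zero probability.
    return jnp.where(logits < kth_largest, -jnp.inf, logits)


# One batch row over a toy vocabulary of five tokens (illustrative values).
raw_logits = jnp.array([[2.0, 1.0, 0.5, 0.0, -1.0]])
warped_logits = top_k_warp(raw_logits, k=2)

# Draw one token per key so the two sampling strategies can be compared.
keys = jax.random.split(jax.random.PRNGKey(0), 200)


def sample(logits):
    return jax.vmap(lambda key: jax.random.categorical(key, logits, axis=-1))(keys)


tokens_from_raw = sample(raw_logits)        # mirrors the line before the fix
tokens_from_warped = sample(warped_logits)  # mirrors the line after the fix

# Tokens drawn from the warped logits always stay within the top-2 ids {0, 1};
# tokens drawn from the raw logits are free to fall outside that set.
print(jnp.unique(tokens_from_warped))
print(jnp.unique(tokens_from_raw))
```

In this sketch the masked entries are set to `-inf`, so `jax.random.categorical` can never select them. That is why passing the processed `logits` to the sampler, rather than the model's raw output, is what makes settings such as top_k take effect.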