Unverified Commit fd734be1 authored by théo gigant's avatar théo gigant Committed by GitHub
Browse files

fix issue with logit processor during beam search in Flax (#29636)

fix issue with logit processor in beam search in Flax
parent 691c3d73
...@@ -911,7 +911,7 @@ class FlaxGenerationMixin: ...@@ -911,7 +911,7 @@ class FlaxGenerationMixin:
# add new logprobs to existing running logprobs scores. # add new logprobs to existing running logprobs scores.
log_probs = jax.nn.log_softmax(logits) log_probs = jax.nn.log_softmax(logits)
log_probs = logits_processor( log_probs = logits_processor(
flatten_beam_dim(running_sequences), flatten_beam_dim(log_probs), state.cur_len flatten_beam_dim(state.running_sequences), flatten_beam_dim(log_probs), state.cur_len
) )
log_probs = unflatten_beam_dim(log_probs, batch_size, num_beams) log_probs = unflatten_beam_dim(log_probs, batch_size, num_beams)
log_probs = log_probs + jnp.expand_dims(state.running_scores, axis=2) log_probs = log_probs + jnp.expand_dims(state.running_scores, axis=2)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment