Unverified Commit ff846e9b authored by Sanchit Gandhi's avatar Sanchit Gandhi Committed by GitHub
Browse files

[FlaxGenerate] Fix bug in decoder_start_token_id (#17035)

parent eb877f1f
......@@ -264,7 +264,7 @@ class FlaxGenerationMixin:
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
decoder_start_token_id = (
decoder_start_token_id if decoder_start_token_id else self.config.decoder_start_token_id
decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id
)
prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment