Unverified Commit 077c00c0 authored by Boris Dayma's avatar Boris Dayma Committed by GitHub
Browse files

feat(flax): allow encoder_outputs in generate (#15554)

* feat(flax): allow encoder_outputs in generate

* doc(flax): encoder_outputs in generate

* fix: style

* fix: style
parent 8406fa6d
...@@ -217,7 +217,9 @@ class FlaxGenerationMixin: ...@@ -217,7 +217,9 @@ class FlaxGenerationMixin:
params (`Dict[str, jnp.ndarray]`, *optional*): params (`Dict[str, jnp.ndarray]`, *optional*):
Optionally the model parameters can be passed. Can be useful for parallelized generation. Optionally the model parameters can be passed. Can be useful for parallelized generation.
model_kwargs: model_kwargs:
Additional model specific kwargs will be forwarded to the `forward` function of the model. Additional model specific kwargs will be forwarded to the `forward` function of the model. If the model
is an encoder-decoder model, encoder specific kwargs should not be prefixed and decoder specific kwargs
should be prefixed with *decoder_*. Also accepts `encoder_outputs` to skip encoder part.
Return: Return:
[`~file_utils.ModelOutput`]. [`~file_utils.ModelOutput`].
...@@ -251,7 +253,8 @@ class FlaxGenerationMixin: ...@@ -251,7 +253,8 @@ class FlaxGenerationMixin:
if self.config.is_encoder_decoder: if self.config.is_encoder_decoder:
# add encoder_outputs to model_kwargs # add encoder_outputs to model_kwargs
model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(input_ids, params, model_kwargs) if model_kwargs.get("encoder_outputs") is None:
model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(input_ids, params, model_kwargs)
# prepare decoder_input_ids for generation # prepare decoder_input_ids for generation
input_ids = jnp.ones((input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id input_ids = jnp.ones((input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment