Unverified Commit b1dbdf22 authored by Suraj Patil's avatar Suraj Patil Committed by GitHub
Browse files

pass params to encode (#14370)

parent e92190c0
...@@ -132,13 +132,13 @@ class FlaxGenerationMixin: ...@@ -132,13 +132,13 @@ class FlaxGenerationMixin:
state = body_fn(state) state = body_fn(state)
return state return state
def _prepare_encoder_decoder_kwargs_for_generation(self, input_ids, model_kwargs): def _prepare_encoder_decoder_kwargs_for_generation(self, input_ids, params, model_kwargs):
encoder_kwargs = { encoder_kwargs = {
argument: value argument: value
for argument, value in model_kwargs.items() for argument, value in model_kwargs.items()
if not (argument.startswith("decoder_") or argument.startswith("cross_attn")) if not (argument.startswith("decoder_") or argument.startswith("cross_attn"))
} }
model_kwargs["encoder_outputs"] = self.encode(input_ids, return_dict=True, **encoder_kwargs) model_kwargs["encoder_outputs"] = self.encode(input_ids, params=params, return_dict=True, **encoder_kwargs)
return model_kwargs return model_kwargs
@staticmethod @staticmethod
...@@ -251,7 +251,7 @@ class FlaxGenerationMixin: ...@@ -251,7 +251,7 @@ class FlaxGenerationMixin:
if self.config.is_encoder_decoder: if self.config.is_encoder_decoder:
# add encoder_outputs to model_kwargs # add encoder_outputs to model_kwargs
model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(input_ids, model_kwargs) model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(input_ids, params, model_kwargs)
# prepare decoder_input_ids for generation # prepare decoder_input_ids for generation
input_ids = jnp.ones((input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id input_ids = jnp.ones((input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment