Unverified Commit 5e6cd51b authored by Andy Ehrenberg's avatar Andy Ehrenberg Committed by GitHub
Browse files

Flax beam search fix (#21857)

parent b599b192
...@@ -448,10 +448,11 @@ class FlaxGenerationMixin: ...@@ -448,10 +448,11 @@ class FlaxGenerationMixin:
model_kwargs["encoder_outputs"]["last_hidden_state"], num_beams=generation_config.num_beams model_kwargs["encoder_outputs"]["last_hidden_state"], num_beams=generation_config.num_beams
) )
if "attention_mask" in model_kwargs: for kwarg in ["attention_mask", "decoder_attention_mask"]:
model_kwargs["attention_mask"] = self._expand_to_num_beams( if kwarg in model_kwargs:
model_kwargs["attention_mask"], num_beams=generation_config.num_beams model_kwargs[kwarg] = self._expand_to_num_beams(
) model_kwargs[kwarg], num_beams=generation_config.num_beams
)
return self._beam_search( return self._beam_search(
input_ids, input_ids,
...@@ -821,8 +822,9 @@ class FlaxGenerationMixin: ...@@ -821,8 +822,9 @@ class FlaxGenerationMixin:
model_kwargs["encoder_outputs"]["last_hidden_state"] = flatten_beam_dim( model_kwargs["encoder_outputs"]["last_hidden_state"] = flatten_beam_dim(
model_kwargs["encoder_outputs"]["last_hidden_state"] model_kwargs["encoder_outputs"]["last_hidden_state"]
) )
if "attention_mask" in model_kwargs: for kwarg in ["attention_mask", "decoder_attention_mask"]:
model_kwargs["attention_mask"] = flatten_beam_dim(model_kwargs["attention_mask"]) if kwarg in model_kwargs:
model_kwargs[kwarg] = flatten_beam_dim(model_kwargs[kwarg])
# initialize model specific kwargs # initialize model specific kwargs
model_kwargs = self.prepare_inputs_for_generation(flatten_beam_dim(input_ids), max_length, **model_kwargs) model_kwargs = self.prepare_inputs_for_generation(flatten_beam_dim(input_ids), max_length, **model_kwargs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment