Unverified Commit fa35cda9 authored by ymfa, committed by GitHub

Pass encoder outputs into GenerationMixin (#10599)

* Pass encoder_outputs into generate()

* Remove an if-statement

* Reformat

* Minimize changes to generate()

* Comment on input_ids
parent 00cad2e5
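
The change lets callers hand precomputed (or externally modified) encoder outputs straight to generate(), so the encoder does not have to be re-run inside generation. A minimal usage sketch, assuming a BART-style seq2seq checkpoint; the model name, example text, and generation settings below are illustrative and not part of this commit:

import torch
from transformers import BartForConditionalGeneration, BartTokenizer

tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")

inputs = tokenizer(["Studies have shown that owning a dog is good for you."], return_tensors="pt")

# Run the encoder once, outside of generate()
with torch.no_grad():
    encoder_outputs = model.get_encoder()(**inputs, return_dict=True)

# Pass the precomputed outputs in; `input_ids` can be omitted because dummy ids
# are now derived from encoder_outputs. The matching attention_mask is passed
# explicitly so padded source positions stay masked.
summary_ids = model.generate(
    encoder_outputs=encoder_outputs,
    attention_mask=inputs["attention_mask"],
    num_beams=4,
)
print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True))
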
@@ -376,7 +376,14 @@ class GenerationMixin:
         """
         return logits

-    def _prepare_input_ids_for_generation(self, bos_token_id: int) -> torch.LongTensor:
+    def _prepare_input_ids_for_generation(
+        self, bos_token_id: Optional[int], encoder_outputs: Optional[ModelOutput]
+    ) -> torch.LongTensor:
+        if self.config.is_encoder_decoder and encoder_outputs is not None:
+            # make dummy input_ids with value -100, as a sanity check ensuring that they won't be used for encoding
+            shape = encoder_outputs.last_hidden_state.size()[:-1]
+            return torch.ones(shape, dtype=torch.long, device=self.device) * -100
+
         if bos_token_id is None:
             raise ValueError("`bos_token_id` has to be defined when no `input_ids` are provided.")
         return torch.ones((1, 1), dtype=torch.long, device=self.device) * bos_token_id
@@ -395,6 +402,7 @@ class GenerationMixin:
     def _prepare_encoder_decoder_kwargs_for_generation(
         self, input_ids: torch.LongTensor, model_kwargs
     ) -> Dict[str, Any]:
+        if "encoder_outputs" not in model_kwargs:
             # retrieve encoder hidden states
             encoder = self.get_encoder()
             encoder_kwargs = {
@@ -887,7 +895,7 @@ class GenerationMixin:
         if input_ids is None:
             # init `input_ids` with bos_token_id
-            input_ids = self._prepare_input_ids_for_generation(bos_token_id)
+            input_ids = self._prepare_input_ids_for_generation(bos_token_id, model_kwargs.get("encoder_outputs"))

         if model_kwargs.get("attention_mask", None) is None:
             # init `attention_mask` depending on `pad_token_id`
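
One detail worth calling out in the first hunk: the dummy input_ids only carry the (batch_size, sequence_length) shape taken from encoder_outputs.last_hidden_state, and -100 acts as the sanity check mentioned in the comment, since these ids are never meant to reach the encoder. A tiny standalone sketch of that shape logic, with made-up tensor sizes:

import torch

# Stand-in for encoder_outputs.last_hidden_state: (batch_size=2, src_len=7, hidden_size=16)
last_hidden_state = torch.zeros(2, 7, 16)

# Same logic as the new branch in _prepare_input_ids_for_generation:
# drop the hidden dimension, fill with -100
shape = last_hidden_state.size()[:-1]
dummy_input_ids = torch.ones(shape, dtype=torch.long) * -100

print(dummy_input_ids.shape)   # torch.Size([2, 7])
print(dummy_input_ids[0, :3])  # tensor([-100, -100, -100])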