"vscode:/vscode.git/clone" did not exist on "384f0eb2f9d42e44094dbfd0917ccf4e6ddb462a"
Unverified commit 408453b4, authored by Raushan Turganbay, committed by GitHub

Add inputs embeds in generation (#30269)

* Add inputs embeds in generation

* always scale embeds

* fix-copies

* fix failing test

* fix copies once more

* remove embeds for models with scaling

* second try to revert

* codestyle
parent 6c1295a0
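
The net effect of the hunks below is that the touched models now accept `inputs_embeds` through `generate`. A minimal sketch of the intended usage, assuming a CodeGen checkpoint as an example (the checkpoint name and prompt are illustrative, not taken from this commit):

    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Example checkpoint only; any model touched in this diff should behave the same way.
    tokenizer = AutoTokenizer.from_pretrained("Salesforce/codegen-350M-mono")
    model = AutoModelForCausalLM.from_pretrained("Salesforce/codegen-350M-mono")

    input_ids = tokenizer("def fib(n):", return_tensors="pt").input_ids
    # Build the embeddings explicitly, here from the model's own embedding table.
    inputs_embeds = model.get_input_embeddings()(input_ids)

    # The first forward pass consumes `inputs_embeds`; later steps fall back to the
    # freshly generated token ids plus the KV cache (see the hunks below).
    out = model.generate(inputs_embeds=inputs_embeds, max_new_tokens=20)
    print(tokenizer.batch_decode(out, skip_special_tokens=True))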
@@ -594,7 +594,7 @@ class CodeGenForCausalLM(CodeGenPreTrainedModel):
     def set_output_embeddings(self, new_embeddings):
         self.lm_head = new_embeddings
 
-    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
+    def prepare_inputs_for_generation(self, input_ids, inputs_embeds=None, past_key_values=None, **kwargs):
         token_type_ids = kwargs.get("token_type_ids", None)
         # Omit tokens covered by past_key_values
         if past_key_values:
@@ -621,14 +621,22 @@ class CodeGenForCausalLM(CodeGenPreTrainedModel):
         if past_key_values:
             position_ids = position_ids[:, -input_ids.shape[1] :]
 
-        return {
-            "input_ids": input_ids,
-            "past_key_values": past_key_values,
-            "use_cache": kwargs.get("use_cache"),
-            "position_ids": position_ids,
-            "attention_mask": attention_mask,
-            "token_type_ids": token_type_ids,
-        }
+        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+        if inputs_embeds is not None and past_key_values is None:
+            model_inputs = {"inputs_embeds": inputs_embeds}
+        else:
+            model_inputs = {"input_ids": input_ids.contiguous()}
+
+        model_inputs.update(
+            {
+                "past_key_values": past_key_values,
+                "use_cache": kwargs.get("use_cache"),
+                "position_ids": position_ids,
+                "attention_mask": attention_mask,
+                "token_type_ids": token_type_ids,
+            }
+        )
+        return model_inputs
 
     @add_start_docstrings_to_model_forward(CODEGEN_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
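
As a hedged illustration of the branch added above (not part of the commit): on the prefill call there is no cache yet, so `prepare_inputs_for_generation` forwards the embeddings; once `past_key_values` is populated, the `else` branch returns only the freshly generated token ids. The checkpoint below is just an example.

    import torch
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("Salesforce/codegen-350M-mono")  # example checkpoint
    input_ids = torch.tensor([[1, 2, 3]])
    inputs_embeds = model.get_input_embeddings()(input_ids)

    # Prefill step: no cache yet, so the embeddings are routed through.
    step0 = model.prepare_inputs_for_generation(input_ids, inputs_embeds=inputs_embeds)
    assert "inputs_embeds" in step0 and "input_ids" not in step0

    # On subsequent steps `past_key_values` is set, so the `else` branch returns
    # {"input_ids": input_ids.contiguous()} and the embeddings are dropped.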
@@ -1212,6 +1212,7 @@ class FalconForCausalLM(FalconPreTrainedModel):
         past_key_values: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> dict:
         if past_key_values is not None:
@@ -1234,13 +1235,20 @@ class FalconForCausalLM(FalconPreTrainedModel):
         if past_key_values:
             position_ids = position_ids[:, -input_ids.shape[1] :]
 
-        return {
-            "input_ids": input_ids,
-            "position_ids": position_ids,
-            "past_key_values": past_key_values,
-            "use_cache": kwargs.get("use_cache"),
-            "attention_mask": attention_mask,
-        }
+        if inputs_embeds is not None and past_key_values is None:
+            model_inputs = {"inputs_embeds": inputs_embeds}
+        else:
+            model_inputs = {"input_ids": input_ids}
+
+        model_inputs.update(
+            {
+                "position_ids": position_ids,
+                "past_key_values": past_key_values,
+                "use_cache": kwargs.get("use_cache"),
+                "attention_mask": attention_mask,
+            }
+        )
+        return model_inputs
 
     @add_start_docstrings_to_model_forward(FALCON_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
@@ -1430,7 +1430,7 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
     def set_output_embeddings(self, new_embeddings):
         self.lm_head = new_embeddings
 
-    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
+    def prepare_inputs_for_generation(self, input_ids, inputs_embeds=None, past_key_values=None, **kwargs):
         token_type_ids = kwargs.get("token_type_ids", None)
         # Omit tokens covered by past_key_values
         if past_key_values:
@@ -1459,14 +1459,22 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
         else:
             position_ids = None
 
-        return {
-            "input_ids": input_ids,
-            "past_key_values": past_key_values,
-            "use_cache": kwargs.get("use_cache"),
-            "position_ids": position_ids,
-            "attention_mask": attention_mask,
-            "token_type_ids": token_type_ids,
-        }
+        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+        if inputs_embeds is not None and past_key_values is None:
+            model_inputs = {"inputs_embeds": inputs_embeds}
+        else:
+            model_inputs = {"input_ids": input_ids.contiguous()}
+
+        model_inputs.update(
+            {
+                "past_key_values": past_key_values,
+                "use_cache": kwargs.get("use_cache"),
+                "position_ids": position_ids,
+                "attention_mask": attention_mask,
+                "token_type_ids": token_type_ids,
+            }
+        )
+        return model_inputs
 
     @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=GPT2DoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC)
@@ -1658,7 +1658,7 @@ class PegasusForCausalLM(PegasusPreTrainedModel):
     )
 
     def prepare_inputs_for_generation(
-        self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, **kwargs
+        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, use_cache=None, **kwargs
     ):
         # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
         if attention_mask is None:
@@ -1676,12 +1676,19 @@ class PegasusForCausalLM(PegasusPreTrainedModel):
             input_ids = input_ids[:, remove_prefix_length:]
         # first step, decoder_cached_states are empty
-        return {
-            "input_ids": input_ids,  # encoder_outputs is defined. input_ids not needed
-            "attention_mask": attention_mask,
-            "past_key_values": past_key_values,
-            "use_cache": use_cache,
-        }
+        if inputs_embeds is not None and past_key_values is None:
+            model_inputs = {"inputs_embeds": inputs_embeds}
+        else:
+            model_inputs = {"input_ids": input_ids.contiguous()}
+
+        model_inputs.update(
+            {
+                "attention_mask": attention_mask,
+                "past_key_values": past_key_values,
+                "use_cache": use_cache,
+            }
+        )
+        return model_inputs
 
     @staticmethod
     def _reorder_cache(past_key_values, beam_idx):
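
One use case this kind of change unlocks (a sketch under assumptions, not from the commit): prepending learned soft-prompt vectors to the token embeddings and generating from the combined sequence. The prompt length, checkpoint name, and random initialisation below are placeholders.

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("Salesforce/codegen-350M-mono")  # example checkpoint
    model = AutoModelForCausalLM.from_pretrained("Salesforce/codegen-350M-mono")

    token_ids = tokenizer("def hello():", return_tensors="pt").input_ids
    token_embeds = model.get_input_embeddings()(token_ids)

    # Eight "soft prompt" vectors, randomly initialised purely for this sketch.
    soft_prompt = torch.randn(1, 8, token_embeds.shape[-1])
    inputs_embeds = torch.cat([soft_prompt, token_embeds], dim=1)
    attention_mask = torch.ones(inputs_embeds.shape[:2], dtype=torch.long)

    out = model.generate(inputs_embeds=inputs_embeds, attention_mask=attention_mask, max_new_tokens=16)
    print(tokenizer.batch_decode(out, skip_special_tokens=True))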