Unverified Commit 977b2f05 authored by Gabriel Asher's avatar Gabriel Asher Committed by GitHub
Browse files

Add input_embeds functionality to gpt_neo Causal LM (#25659)

* Updated gpt_neo causalLM to support using input embeddings for generation

* added indentation

* Did make fixup
parent 908f8536
...@@ -680,7 +680,7 @@ class GPTNeoForCausalLM(GPTNeoPreTrainedModel): ...@@ -680,7 +680,7 @@ class GPTNeoForCausalLM(GPTNeoPreTrainedModel):
def set_output_embeddings(self, new_embeddings): def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings self.lm_head = new_embeddings
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs): def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
token_type_ids = kwargs.get("token_type_ids", None) token_type_ids = kwargs.get("token_type_ids", None)
# only last token for inputs_ids if past is defined in kwargs # only last token for inputs_ids if past is defined in kwargs
if past_key_values: if past_key_values:
...@@ -698,7 +698,14 @@ class GPTNeoForCausalLM(GPTNeoPreTrainedModel): ...@@ -698,7 +698,14 @@ class GPTNeoForCausalLM(GPTNeoPreTrainedModel):
if past_key_values: if past_key_values:
position_ids = position_ids[:, -1].unsqueeze(-1) position_ids = position_ids[:, -1].unsqueeze(-1)
return { # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
model_inputs.update(
{
"input_ids": input_ids, "input_ids": input_ids,
"past_key_values": past_key_values, "past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"), "use_cache": kwargs.get("use_cache"),
...@@ -706,6 +713,9 @@ class GPTNeoForCausalLM(GPTNeoPreTrainedModel): ...@@ -706,6 +713,9 @@ class GPTNeoForCausalLM(GPTNeoPreTrainedModel):
"attention_mask": attention_mask, "attention_mask": attention_mask,
"token_type_ids": token_type_ids, "token_type_ids": token_type_ids,
} }
)
return model_inputs
@add_start_docstrings_to_model_forward(GPT_NEO_INPUTS_DOCSTRING) @add_start_docstrings_to_model_forward(GPT_NEO_INPUTS_DOCSTRING)
@add_code_sample_docstrings( @add_code_sample_docstrings(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment