Unverified Commit 9fdf158a authored by Lei Li's avatar Lei Li Committed by GitHub
Browse files

Add inputs_embeds functionality when generating with GPT-Neox (#22916)



* support gpt neox generate with inputs embeds

* Update src/transformers/models/gpt_neox/modeling_gpt_neox.py

great thx for the suggestion!
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

---------
Co-authored-by: Lei Li <tobiaslee@qq.com>
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
parent ec93b895
......@@ -697,7 +697,9 @@ class GPTNeoXForCausalLM(GPTNeoXPreTrainedModel):
attentions=outputs.attentions,
)
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **kwargs):
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
):
input_shape = input_ids.shape
# cut decoder_input_ids if past is used
......@@ -716,12 +718,21 @@ class GPTNeoXForCausalLM(GPTNeoXPreTrainedModel):
if attention_mask is None:
attention_mask = input_ids.new_ones(input_shape)
return {
"input_ids": input_ids,
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
model_inputs.update(
{
"attention_mask": attention_mask,
"position_ids": position_ids,
"past_key_values": past_key_values,
"position_ids": position_ids,
}
)
return model_inputs
def _reorder_cache(self, past_key_values, beam_idx):
reordered_past = ()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment