"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "826496580beae08289452da0eda914bdc40a95bb"
Unverified Commit 9fdf158a authored by Lei Li, committed by GitHub

Add inputs_embeds functionality when generating with GPT-Neox (#22916)



* Support GPT-NeoX generation with inputs_embeds

* Update src/transformers/models/gpt_neox/modeling_gpt_neox.py

Great, thanks for the suggestion!

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

---------
Co-authored-by: Lei Li <tobiaslee@qq.com>
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
parent ec93b895
@@ -697,7 +697,9 @@ class GPTNeoXForCausalLM(GPTNeoXPreTrainedModel):
             attentions=outputs.attentions,
         )
 
-    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **kwargs):
+    def prepare_inputs_for_generation(
+        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
+    ):
         input_shape = input_ids.shape
 
         # cut decoder_input_ids if past is used
@@ -716,12 +718,21 @@ class GPTNeoXForCausalLM(GPTNeoXPreTrainedModel):
         if attention_mask is None:
             attention_mask = input_ids.new_ones(input_shape)
 
-        return {
-            "input_ids": input_ids,
-            "attention_mask": attention_mask,
-            "position_ids": position_ids,
-            "past_key_values": past_key_values,
-        }
+        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+        if inputs_embeds is not None and past_key_values is None:
+            model_inputs = {"inputs_embeds": inputs_embeds}
+        else:
+            model_inputs = {"input_ids": input_ids}
+
+        model_inputs.update(
+            {
+                "attention_mask": attention_mask,
+                "past_key_values": past_key_values,
+                "position_ids": position_ids,
+            }
+        )
+
+        return model_inputs
 
     def _reorder_cache(self, past_key_values, beam_idx):
         reordered_past = ()
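For context, a minimal usage sketch of the path this diff enables: passing inputs_embeds directly to generate. The checkpoint name and prompt below are illustrative assumptions (any GPT-NeoX checkpoint should behave the same); the sketch is not part of the commit itself.

# Minimal sketch (illustrative, not part of the commit). Assumes the public
# GPT-NeoX checkpoint "EleutherAI/pythia-70m" as an example model.
from transformers import AutoTokenizer, GPTNeoXForCausalLM

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-70m")
model = GPTNeoXForCausalLM.from_pretrained("EleutherAI/pythia-70m")

inputs = tokenizer("The capital of France is", return_tensors="pt")
# Compute input embeddings by hand, e.g. to prepend learned soft-prompt vectors.
inputs_embeds = model.get_input_embeddings()(inputs.input_ids)

# Per the branch added above, `inputs_embeds` are consumed only on the first
# generation step; subsequent steps feed back the generated token ids.
output_ids = model.generate(
    inputs_embeds=inputs_embeds,
    attention_mask=inputs.attention_mask,
    max_new_tokens=20,
)
# When generating from embeddings only, the returned ids contain just the
# newly generated tokens (there are no prompt ids to prepend).
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))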