"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "5ad960f1f4f77f436ddf3de3692d09949a27c2df"
Unverified commit edbb37f7, authored by Sid Kiblawi, committed by GitHub

Add `inputs_embeds` functionality when generating with BioGPT (#21889)

* initial commit to add inputs_embeds to generation

* formatting
parent 3412f597
@@ -703,17 +703,27 @@ class BioGptForCausalLM(BioGptPreTrainedModel):
             cross_attentions=outputs.cross_attentions,
         )

-    def prepare_inputs_for_generation(self, input_ids, attention_mask, past_key_values=None, **kwargs):
+    def prepare_inputs_for_generation(
+        self, input_ids, attention_mask, inputs_embeds=None, past_key_values=None, **kwargs
+    ):
         # only last token for inputs_ids if past is defined in kwargs
         if past_key_values:
             input_ids = input_ids[:, -1].unsqueeze(-1)

-        return {
-            "input_ids": input_ids,
-            "attention_mask": attention_mask,
-            "past_key_values": past_key_values,
-            "use_cache": kwargs.get("use_cache"),
-        }
+        if inputs_embeds is not None and past_key_values is None:
+            model_inputs = {"inputs_embeds": inputs_embeds}
+        else:
+            model_inputs = {"input_ids": input_ids}
+
+        model_inputs.update(
+            {
+                "attention_mask": attention_mask,
+                "past_key_values": past_key_values,
+                "use_cache": kwargs.get("use_cache"),
+            }
+        )
+
+        return model_inputs

     @staticmethod
     def _reorder_cache(past_key_values, beam_idx):
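For context, a minimal usage sketch of what this change enables: BioGPT can now be prompted through generate() with precomputed embeddings rather than token ids. The checkpoint and prompt below are illustrative and not part of the commit.

    from transformers import AutoTokenizer, BioGptForCausalLM

    # Illustrative checkpoint; any BioGPT checkpoint works the same way.
    tokenizer = AutoTokenizer.from_pretrained("microsoft/biogpt")
    model = BioGptForCausalLM.from_pretrained("microsoft/biogpt")

    inputs = tokenizer("COVID-19 is", return_tensors="pt")

    # Look up the embeddings manually instead of passing token ids.
    inputs_embeds = model.get_input_embeddings()(inputs["input_ids"])

    # With this commit, prepare_inputs_for_generation forwards inputs_embeds
    # on the first decoding step, so generate() accepts embeddings directly.
    output_ids = model.generate(
        inputs_embeds=inputs_embeds,
        attention_mask=inputs["attention_mask"],
        max_new_tokens=20,
    )
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))

Note that inputs_embeds only takes effect on the first forward pass: once past_key_values is populated, decoding continues from the last generated token id, which is why the new branch checks that past_key_values is None.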