Unverified Commit edbb37f7 authored by Sid Kiblawi, committed by GitHub

Add `inputs_embeds` functionality when generating with BioGPT (#21889)

* initial commit to add inputs_embeds to generation

* formatting
parent 3412f597
@@ -703,17 +703,27 @@ class BioGptForCausalLM(BioGptPreTrainedModel):
             cross_attentions=outputs.cross_attentions,
         )
 
-    def prepare_inputs_for_generation(self, input_ids, attention_mask, past_key_values=None, **kwargs):
+    def prepare_inputs_for_generation(
+        self, input_ids, attention_mask, inputs_embeds=None, past_key_values=None, **kwargs
+    ):
         # only last token for inputs_ids if past is defined in kwargs
         if past_key_values:
             input_ids = input_ids[:, -1].unsqueeze(-1)
 
-        return {
-            "input_ids": input_ids,
-            "attention_mask": attention_mask,
-            "past_key_values": past_key_values,
-            "use_cache": kwargs.get("use_cache"),
-        }
+        if inputs_embeds is not None and past_key_values is None:
+            model_inputs = {"inputs_embeds": inputs_embeds}
+        else:
+            model_inputs = {"input_ids": input_ids}
+
+        model_inputs.update(
+            {
+                "attention_mask": attention_mask,
+                "past_key_values": past_key_values,
+                "use_cache": kwargs.get("use_cache"),
+            }
+        )
+
+        return model_inputs
 
     @staticmethod
     def _reorder_cache(past_key_values, beam_idx):
...
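For context, here is a minimal usage sketch of the path this change enables: passing `inputs_embeds` directly to `generate` with BioGPT. It assumes the public `microsoft/biogpt` checkpoint; the prompt and variable names are illustrative, not part of this commit.

```python
# Sketch only (not part of this commit): generating from inputs_embeds
# with BioGPT. Assumes the public microsoft/biogpt checkpoint.
import torch
from transformers import BioGptForCausalLM, BioGptTokenizer

tokenizer = BioGptTokenizer.from_pretrained("microsoft/biogpt")
model = BioGptForCausalLM.from_pretrained("microsoft/biogpt")

inputs = tokenizer("COVID-19 is", return_tensors="pt")

# Embed the prompt manually; this exercises the new branch in
# prepare_inputs_for_generation that forwards inputs_embeds on the
# first decoding step (before any past_key_values exist).
inputs_embeds = model.get_input_embeddings()(inputs.input_ids)

with torch.no_grad():
    output_ids = model.generate(
        inputs_embeds=inputs_embeds,
        attention_mask=inputs.attention_mask,
        max_new_tokens=20,
    )
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```

Note that `inputs_embeds` is only consumed on the first forward pass; once `past_key_values` are populated, subsequent steps feed the newly generated token ids, which is why the added branch checks `past_key_values is None` before choosing `inputs_embeds` over `input_ids`.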