Unverified Commit 10056d89 authored by Joao Gante, committed by GitHub

OPT: BLIP2-ready `prepare_inputs_for_generation` (#21477)

parent baf4bacb
```diff
@@ -965,21 +965,25 @@ class OPTForCausalLM(OPTPreTrainedModel):
     )
     def prepare_inputs_for_generation(
-        self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, **kwargs
+        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
     ):
-        # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
-        if attention_mask is None:
-            attention_mask = input_ids.new_ones(input_ids.shape)
-
         if past_key_values:
             input_ids = input_ids[:, -1:]
-        # first step, decoder_cached_states are empty
-        return {
-            "input_ids": input_ids,  # encoder_outputs is defined. input_ids not needed
-            "attention_mask": attention_mask,
-            "past_key_values": past_key_values,
-            "use_cache": use_cache,
-        }
+
+        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+        if inputs_embeds is not None and past_key_values is None:
+            model_inputs = {"inputs_embeds": inputs_embeds}
+        else:
+            model_inputs = {"input_ids": input_ids}
+
+        model_inputs.update(
+            {
+                "past_key_values": past_key_values,
+                "use_cache": kwargs.get("use_cache"),
+                "attention_mask": attention_mask,
+            }
+        )
+        return model_inputs
 
     @staticmethod
     def _reorder_cache(past, beam_idx):
...
```
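What the change does: `prepare_inputs_for_generation` now accepts `inputs_embeds` and forwards them only on the first generation step (when no `past_key_values` exist yet); on later steps it falls back to the last generated token id. This is what lets BLIP-2 prime OPT's decoder with projected image embeddings instead of token ids. Note also that `use_cache` moved from an explicit keyword argument into `**kwargs`, so it is now read via `kwargs.get("use_cache")`. Below is a minimal sketch that exercises the new branching directly; the checkpoint name (`facebook/opt-125m`), the dummy continuation token, and the direct calls to `prepare_inputs_for_generation` are illustrative assumptions, not part of the commit.

```python
# Minimal sketch of the new inputs_embeds branch (assumes this patch is applied and the
# "facebook/opt-125m" checkpoint; in practice generate()/BLIP-2 drives these calls).
import torch
from transformers import AutoTokenizer, OPTForCausalLM

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
model = OPTForCausalLM.from_pretrained("facebook/opt-125m")

inputs = tokenizer("A picture of", return_tensors="pt")
# BLIP-2 would build these embeddings from projected image features; here we simply
# embed the prompt tokens to keep the sketch self-contained.
embeds = model.get_input_embeddings()(inputs.input_ids)

# Step 1: no past_key_values yet, so the method returns inputs_embeds and drops input_ids.
step1 = model.prepare_inputs_for_generation(
    inputs.input_ids,
    inputs_embeds=embeds,
    attention_mask=inputs.attention_mask,
    use_cache=True,
)
assert "inputs_embeds" in step1 and "input_ids" not in step1

out = model(**step1)  # populates the key/value cache

# Step 2: past_key_values exist, so only the last token id is kept and inputs_embeds is ignored.
next_token = torch.tensor([[tokenizer.eos_token_id]])  # arbitrary illustrative token
step2 = model.prepare_inputs_for_generation(
    torch.cat([inputs.input_ids, next_token], dim=-1),
    past_key_values=out.past_key_values,
    inputs_embeds=embeds,
    attention_mask=torch.ones(1, inputs.input_ids.shape[1] + 1, dtype=torch.long),
    use_cache=True,
)
assert "input_ids" in step2 and step2["input_ids"].shape[-1] == 1
```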