Unverified commit e3ecbaa4 authored by Arthur, committed by GitHub

Patch-past-refactor (#21050)

* small patches, forgot a line

* refactor PT

* the actual fix
parent 48d4e147
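
The hunks below rename the `past` keyword of `prepare_inputs_for_generation` to `past_key_values`, forward it to the decoder under the new name, drop the TF-only fallback that looked up a legacy `"past"` entry, and re-enable the previously skipped TF image-to-text pipeline test. As a rough sketch of the calling convention after the rename (note that `ToyDecoder` and `ToyVisionEncoderDecoder` are made-up stand-ins for illustration, not transformers classes), the delegation pattern looks like this:

```python
import torch


class ToyDecoder:
    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
        # With a cache present, only the most recently generated token is re-fed.
        if past_key_values is not None:
            input_ids = input_ids[:, -1:]
        return {"input_ids": input_ids, "past_key_values": past_key_values}


class ToyVisionEncoderDecoder:
    def __init__(self, decoder):
        self.decoder = decoder

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
    ):
        # Delegate under the new keyword; no fallback lookup of a legacy "past" entry.
        decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past_key_values=past_key_values)
        return {
            "attention_mask": attention_mask,
            "decoder_input_ids": decoder_inputs["input_ids"],
            "encoder_outputs": encoder_outputs,
            "past_key_values": decoder_inputs["past_key_values"],
            "use_cache": use_cache,
        }


model = ToyVisionEncoderDecoder(ToyDecoder())
ids = torch.tensor([[5, 6, 7]])
print(model.prepare_inputs_for_generation(ids))                                  # no cache: full sequence
print(model.prepare_inputs_for_generation(ids, past_key_values=(("k", "v"),)))   # cached: only the last token
```

The same reasoning applies to the cookiecutter template change: when `past_key_values` is truthy, only the last token is kept, so the guard is renamed from `if past:` to `if past_key_values:`.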
@@ -729,13 +729,11 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
         )
 
     def prepare_inputs_for_generation(
-        self, input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
+        self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
     ):
-        decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past=past)
+        decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past_key_values=past_key_values)
         decoder_attention_mask = decoder_inputs["attention_mask"] if "attention_mask" in decoder_inputs else None
         past_key_values = decoder_inputs.get("past_key_values")
-        if past_key_values is None:
-            past_key_values = decoder_inputs.get("past")  # e.g. on TF GPT2
         input_dict = {
             "pixel_values": None,  # needs to be passed to make Keras.layer.__call__ happy
             "attention_mask": attention_mask,
@@ -649,9 +649,9 @@ class VisionEncoderDecoderModel(PreTrainedModel):
         return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
 
     def prepare_inputs_for_generation(
-        self, input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
+        self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
     ):
-        decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past=past)
+        decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past_key_values=past_key_values)
         decoder_attention_mask = decoder_inputs["attention_mask"] if "attention_mask" in decoder_inputs else None
         input_dict = {
             "attention_mask": attention_mask,
@@ -3333,7 +3333,7 @@ class {{cookiecutter.camelcase_modelname}}ForCausalLM({{cookiecutter.camelcase_modelname}}PreTrainedModel):
         if attention_mask is None:
             attention_mask = input_ids.new_ones(input_ids.shape)
 
-        if past:
+        if past_key_values:
            input_ids = input_ids[:, -1:]
         # first step, decoder_cached_states are empty
         return {
@@ -55,7 +55,6 @@ class ImageToTextPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
         )
 
     @require_tf
-    @unittest.skip("Arthur will fix me!")
     def test_small_model_tf(self):
         pipe = pipeline("image-to-text", model="hf-internal-testing/tiny-random-vit-gpt2", framework="tf")
         image = "./tests/fixtures/tests_samples/COCO/000000039769.png"