Unverified commit abd25310, authored by Shauray Singh, committed by GitHub

Fix padding for IDEFICS (#26396)

* fix

* fixup

* tests

* fixup
parent 408b2b3c
@@ -280,7 +280,7 @@ class IdeficsProcessor(ProcessorMixin):
             else:
                 return fake_token + image_token + fake_token
 
-        all_texts = []
+        all_prompts = []
         all_images = []
         for sample in prompts:
             # the model was trained on samples starting with <s>
@@ -321,17 +321,18 @@ class IdeficsProcessor(ProcessorMixin):
             image_objects = self.image_processor(image_objects, transform=transform)
 
-            text_encoding = self.tokenizer(
-                text=full_text,
-                add_special_tokens=False,
-                padding=padding,
-                truncation=truncation,
-                max_length=max_length,
-            )
-            all_texts.append(text_encoding["input_ids"])
+            all_prompts.append(full_text)
             all_images.append(image_objects)
 
+        text_encoding = self.tokenizer(
+            text=all_prompts,
+            add_special_tokens=False,
+            padding=padding,
+            truncation=truncation,
+            max_length=max_length,
+        )
+        all_texts = text_encoding["input_ids"]
+
         max_seq_len = max(len(x) for x in all_texts)
         # max_num_images has to be at least 1 even when there are no images
...
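Why batching the tokenizer call fixes the padding: when each prompt is tokenized on its own, `padding="longest"` has nothing to pad against (every call sees a batch of one), so the returned batch comes back ragged. Tokenizing all prompts in one call lets the tokenizer pad relative to the whole batch. The sketch below is illustrative only and not part of the diff; it assumes `hf-internal-testing/llama-tokenizer` as a stand-in checkpoint for the real IDEFICS tokenizer (any tokenizer with a pad token would do).

```python
# Minimal sketch of the bug this commit fixes; the checkpoint is a stand-in,
# not the IDEFICS tokenizer.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
tokenizer.pad_token = tokenizer.unk_token  # mirror IDEFICS, which pads with <unk>

texts = ["a short prompt", "a noticeably longer prompt with many more tokens"]

# Before the fix: one tokenizer call per sample. padding="longest" is a no-op
# here, since each call only ever sees a single sequence.
per_sample = [tokenizer(t, padding="longest")["input_ids"] for t in texts]
assert len(per_sample[0]) != len(per_sample[1])  # ragged batch

# After the fix: one batched call, so every sequence is padded out to the
# longest prompt in the batch.
batched = tokenizer(texts, padding="longest")["input_ids"]
assert len(batched[0]) == len(batched[1])  # uniform batch
```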
@@ -141,6 +141,25 @@ class IdeficsProcessorTest(TestCasePlus):
         self.assertListEqual(decoded_tok, decoded_processor)
 
+    def test_tokenizer_padding(self):
+        image_processor = self.get_image_processor()
+        tokenizer = self.get_tokenizer(padding_side="right")
+        processor = IdeficsProcessor(tokenizer=tokenizer, image_processor=image_processor)
+
+        predicted_tokens = [
+            "<s>Describe this image.\nAssistant:<unk><unk><unk><unk><unk><unk><unk><unk><unk>",
+            "<s>Describe this image.\nAssistant:<unk><unk><unk><unk><unk><unk><unk><unk><unk><unk>",
+        ]
+        prompts = [[prompt] for prompt in self.prepare_prompts()[2]]
+
+        max_length = processor(prompts, padding="max_length", truncation=True, max_length=20)
+        longest = processor(prompts, padding="longest", truncation=True, max_length=30)
+
+        decoded_max_length = processor.tokenizer.decode(max_length["input_ids"][-1])
+        decoded_longest = processor.tokenizer.decode(longest["input_ids"][-1])
+
+        self.assertEqual(decoded_max_length, predicted_tokens[1])
+        self.assertEqual(decoded_longest, predicted_tokens[0])
+
     def test_model_input_names(self):
         image_processor = self.get_image_processor()
         tokenizer = self.get_tokenizer()
...
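For context on the expected strings in the test: `padding="max_length"` pads every sequence out to exactly `max_length` tokens, while `padding="longest"` pads only up to the longest prompt in the batch, which is why the two decoded outputs differ by one `<unk>`. A rough sketch of that difference, under the same stand-in-tokenizer assumption as above; the prompts and lengths are illustrative, not the real IDEFICS test inputs.

```python
# Illustrative only; the checkpoint is a stand-in for the real IDEFICS tokenizer.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
tokenizer.pad_token = tokenizer.unk_token  # mirror IDEFICS, which pads with <unk>

texts = ["Describe this image.\nAssistant:", "Describe this other image.\nAssistant:"]

enc_max = tokenizer(texts, padding="max_length", truncation=True, max_length=20)
enc_longest = tokenizer(texts, padding="longest", truncation=True, max_length=30)

# padding="max_length" always yields exactly max_length tokens per sequence.
assert all(len(ids) == 20 for ids in enc_max["input_ids"])

# padding="longest" only pads up to the longest sequence in the batch.
longest_len = max(len(tokenizer(t)["input_ids"]) for t in texts)
assert all(len(ids) == longest_len for ids in enc_longest["input_ids"])
```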