Unverified Commit fad15fba authored by Raushan Turganbay's avatar Raushan Turganbay Committed by GitHub
Browse files

Llava: generate without images (#32183)

* llava w/o images

* tests
parent 4ab33c2d
...@@ -104,14 +104,14 @@ class LlavaProcessor(ProcessorMixin): ...@@ -104,14 +104,14 @@ class LlavaProcessor(ProcessorMixin):
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
""" """
if images is not None: if images is not None:
pixel_values = self.image_processor(images, return_tensors=return_tensors)["pixel_values"] image_inputs = self.image_processor(images, return_tensors=return_tensors)
else: else:
pixel_values = None image_inputs = {}
text_inputs = self.tokenizer( text_inputs = self.tokenizer(
text, return_tensors=return_tensors, padding=padding, truncation=truncation, max_length=max_length text, return_tensors=return_tensors, padding=padding, truncation=truncation, max_length=max_length
) )
return BatchFeature(data={**text_inputs, "pixel_values": pixel_values}) return BatchFeature(data={**text_inputs, **image_inputs})
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
def batch_decode(self, *args, **kwargs): def batch_decode(self, *args, **kwargs):
......
...@@ -458,3 +458,16 @@ class LlavaForConditionalGenerationIntegrationTest(unittest.TestCase): ...@@ -458,3 +458,16 @@ class LlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
EXPECTED_OUTPUT = ['<|im_start|>', 'system', '\n', 'Answer', '▁the', '▁questions', '.', '<|im_end|>', '<|im_start|>', 'user', '\n', '<image>', '\n', 'What', '▁is', '▁shown', '▁in', '▁this', '▁image', '?', '<|im_end|>', '<|im_start|>', 'ass', 'istant', '\n'] # fmt: skip EXPECTED_OUTPUT = ['<|im_start|>', 'system', '\n', 'Answer', '▁the', '▁questions', '.', '<|im_end|>', '<|im_start|>', 'user', '\n', '<image>', '\n', 'What', '▁is', '▁shown', '▁in', '▁this', '▁image', '?', '<|im_end|>', '<|im_start|>', 'ass', 'istant', '\n'] # fmt: skip
self.assertEqual(slow_tokenizer.tokenize(prompt), EXPECTED_OUTPUT) self.assertEqual(slow_tokenizer.tokenize(prompt), EXPECTED_OUTPUT)
self.assertEqual(fast_tokenizer.tokenize(prompt), EXPECTED_OUTPUT) self.assertEqual(fast_tokenizer.tokenize(prompt), EXPECTED_OUTPUT)
@slow
@require_bitsandbytes
def test_generation_no_images(self):
    """Regression test: `generate` must run when the processor is given text only (no images)."""
    checkpoint = "llava-hf/llava-1.5-7b-hf"
    model = LlavaForConditionalGeneration.from_pretrained(checkpoint, load_in_4bit=True)
    processor = AutoProcessor.from_pretrained(checkpoint)

    # Build text-only inputs; the processor should omit `pixel_values` entirely.
    text_only_inputs = processor("Hello, I am", return_tensors="pt").to(torch_device)

    # We only care that generation completes without raising.
    _ = model.generate(**text_only_inputs, max_new_tokens=20)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment