import requests
from PIL import Image
import torch
from transformers import AutoProcessor, LlavaNextForConditionalGeneration
from pathlib import Path
import os

current_dir = str(Path(__file__).resolve().parent)
pretrained = os.path.join(current_dir, "ckpts", "llava-v1.6-mistral-7b-hf")

# Load the model in half-precision
model = LlavaNextForConditionalGeneration.from_pretrained(pretrained, torch_dtype=torch.float16, device_map="auto")
processor = AutoProcessor.from_pretrained(pretrained)

# Get three different images
# url = "https://www.ilankelman.org/stopsigns/australia.jpg"
# image_stop = Image.open(requests.get(url, stream=True).raw)
image_stop = Image.open("./examples/image.png")

# url = "http://images.cocodataset.org/val2017/000000039769.jpg"
# image_cats = Image.open(requests.get(url, stream=True).raw)
image_cats = Image.open("./examples/cat.jpg")

# url = "https://hugging-face.cn/microsoft/kosmos-2-patch14-224/resolve/main/snowman.jpg"
# image_snowman = Image.open(requests.get(url, stream=True).raw)
image_snowman = Image.open("./examples/snowman.jpg")

# Prepare a batch of two prompts, where the first one is a multi-turn conversation and the second is not
conversation_1 = [
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": "What is shown in this image?"},
        ],
    },
    {
        "role": "assistant",
        "content": [
            {"type": "text", "text": "There is a lake in the image."},
        ],
    },
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": "What about this image? How many cats do you see?"},
        ],
    },
]

conversation_2 = [
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": "What is shown in this image?"},
        ],
    },
]

prompt_1 = processor.apply_chat_template(conversation_1, add_generation_prompt=True)
prompt_2 = processor.apply_chat_template(conversation_2, add_generation_prompt=True)
prompts = [prompt_1, prompt_2]

# We can simply feed the images in the order they are used in the text prompts
# Each "<image>" token consumes one image, leaving the next image for the subsequent "<image>" tokens
inputs = processor(images=[image_stop, image_cats, image_snowman], text=prompts, padding=True, return_tensors="pt").to(model.device)

# Generate
generate_ids = model.generate(**inputs, max_new_tokens=30)
print(processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False))
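
# The decoded output above still contains the templated prompt text for each
# batch entry. A minimal sketch of decoding only the newly generated tokens,
# assuming the tokenizer left-pads the batch (the usual setup for decoder-only
# generation), so every sequence's new tokens start right after the padded
# prompt length:
new_tokens = generate_ids[:, inputs["input_ids"].shape[1]:]
print(processor.batch_decode(new_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=False))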