Commit d5878167 authored by mashun1's avatar mashun1
Browse files

llava-next

parents
Pipeline #2589 failed with stages
in 0 seconds
icon.png

53.8 KB

# LLaVA-NeXT: Stronger LLMs Supercharge Multimodal Capabilities in the Wild
## 论文
`LLaVA-NeXT: Stronger LLMs Supercharge Multimodal Capabilities in the Wild`
* https://llava-vl.github.io/blog/2024-05-10-llava-next-stronger-llms/
## 模型结构
参考[README.md](../README.md)
## 算法原理
参考[README.md](../README.md)
## 数据集
## 训练
## 推理
```bash
python image.py
```
注意:在运行前需修改文件中的模型路径。
### 评估
```bash
export HF_ENDPOINT=https://hf-mirror.com
huggingface-cli login --token $HUGGINGFACE_TOKEN --add-to-git-credential
```
注意:此命令为自动下载评估数据所需。
```bash
accelerate launch --num_processes=8 \
-m lmms_eval \
--model llava \
--model_args pretrained=/path/to/llama3-llava-next-8b,conv_template=llava_llama_3 \
--tasks ai2d,chartqa,docvqa_val,mme,mmbench_en_dev \
--batch_size 1 \
--log_samples \
--log_samples_suffix llava_next \
--output_path ./logs/
```
```bash
accelerate launch --num_processes=1 \
-m lmms_eval \
--model llava \
--model_args pretrained=/path/to/llava-next-72b,conv_template=qwen_1_5,model_name=llava_qwen,device_map=auto \
--tasks ai2d,chartqa,docvqa_val,mme,mmbench_en_dev \
--batch_size 1 \
--log_samples \
--log_samples_suffix llava_next \
--output_path ./logs/
```
## result
![alt text](readme_imgs/multimodal-8b.png)
### 精度
## 应用场景
参考[README.md](../README.md)
## 预训练权重
|model|url|
|:---:|:---:|
|llama3-llava-next-8b|[hf](https://huggingface.co/lmms-lab/llama3-llava-next-8b) \| [SCNet]() |
|llava-next-qwen-32b|[hf](https://huggingface.co/lmms-lab/llava-next-qwen-32b) \| [SCNet]() |
模型下载后保存至`ckpts`(需自行创建).
## 源码仓库及问题反馈
参考[README.md](../README.md)
## 参考资料
* https://github.com/LLaVA-VL/LLaVA-NeXT/blob/main/docs/LLaVA-NeXT.md
* https://llava-vl.github.io/blog/2024-05-10-llava-next-stronger-llms/
\ No newline at end of file
from llava.model.builder import load_pretrained_model
from llava.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IGNORE_INDEX
from llava.conversation import conv_templates, SeparatorStyle
from PIL import Image
import requests
import copy
import torch
from pathlib import Path
import os
current_dir = str(Path(__file__).resolve().parent)
pretrained = os.path.join(current_dir, "ckpts", "llama3-llava-next-8b")
model_name = "llava_llama3"
device = "cuda"
device_map = "auto"
tokenizer, model, image_processor, max_length = load_pretrained_model(pretrained, None, model_name, device_map=device_map) # Add any other thing you want to pass in llava_model_args
model.eval()
model.tie_weights()
image = Image.open("./examples/llava_v1_5_radar.jpg")
image_tensor = process_images([image], image_processor, model.config)
image_tensor = [_image.to(dtype=torch.float16, device=device) for _image in image_tensor]
conv_template = "llava_llama_3" # Make sure you use correct chat template for different models
question = DEFAULT_IMAGE_TOKEN + "\nWhat is shown in this image?"
conv = copy.deepcopy(conv_templates[conv_template])
conv.tokenizer = tokenizer
conv.append_message(conv.roles[0], question)
conv.append_message(conv.roles[1], None)
prompt_question = conv.get_prompt()
input_ids = tokenizer_image_token(prompt_question, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(device)
image_sizes = [image.size]
cont = model.generate(
input_ids,
images=image_tensor,
image_sizes=image_sizes,
do_sample=False,
temperature=0,
max_new_tokens=256,
)
text_outputs = tokenizer.batch_decode(cont, skip_special_tokens=True)
print(text_outputs)
\ No newline at end of file
import requests
from PIL import Image
import torch
from transformers import AutoProcessor, LlavaNextForConditionalGeneration
from pathlib import Path
import os
current_dir = str(Path(__file__).resolve().parent)
pretrained = os.path.join(current_dir, "ckpts", "llava-v1.6-mistral-7b-hf")
# Load the model in half-precision
model = LlavaNextForConditionalGeneration.from_pretrained(pretrained, torch_dtype=torch.float16, device_map="auto")
processor = AutoProcessor.from_pretrained(pretrained)
# Get three different images
# url = "https://www.ilankelman.org/stopsigns/australia.jpg"
# image_stop = Image.open(requests.get(url, stream=True).raw)
image_stop = Image.open("./examples/image.png")
# url = "http://images.cocodataset.org/val2017/000000039769.jpg"
# image_cats = Image.open(requests.get(url, stream=True).raw)
image_cats = Image.open("./examples/cat.jpg")
# url = "https://hugging-face.cn/microsoft/kosmos-2-patch14-224/resolve/main/snowman.jpg"
# image_snowman = Image.open(requests.get(url, stream=True).raw)
image_snowman = Image.open("./examples/snowman.jpg")
# Prepare a batch of two prompts, where the first one is a multi-turn conversation and the second is not
conversation_1 = [
{
"role": "user",
"content": [
{"type": "image"},
{"type": "text", "text": "What is shown in this image?"},
],
},
{
"role": "assistant",
"content": [
{"type": "text", "text": "There is a lake in the image."},
],
},
{
"role": "user",
"content": [
{"type": "image"},
{"type": "text", "text": "What about this image? How many cats do you see?"},
],
},
]
conversation_2 = [
{
"role": "user",
"content": [
{"type": "image"},
{"type": "text", "text": "What is shown in this image?"},
],
},
]
prompt_1 = processor.apply_chat_template(conversation_1, add_generation_prompt=True)
prompt_2 = processor.apply_chat_template(conversation_2, add_generation_prompt=True)
prompts = [prompt_1, prompt_2]
# We can simply feed images in the order they have to be used in the text prompt
# Each "<image>" token uses one image leaving the next for the subsequent "<image>" tokens
inputs = processor(images=[image_stop, image_cats, image_snowman], text=prompts, padding=True, return_tensors="pt").to(model.device)
# Generate
generate_ids = model.generate(**inputs, max_new_tokens=30)
print(processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False))
\ No newline at end of file
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration
import torch
from PIL import Image
import requests
from pathlib import Path
import os
current_dir = str(Path(__file__).resolve().parent)
pretrained = os.path.join(current_dir, "ckpts", "llava-v1.6-mistral-7b-hf")
processor = LlavaNextProcessor.from_pretrained(pretrained)
model = LlavaNextForConditionalGeneration.from_pretrained(pretrained, torch_dtype=torch.float16, low_cpu_mem_usage=True)
model.to("cuda")
# prepare image and text prompt, using the appropriate prompt template
# url = "https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_v1_5_radar.jpg?raw=true"
# image = Image.open(requests.get(url, stream=True).raw)
image = Image.open("./examples/image.png")
# Define a chat history and use `apply_chat_template` to get correctly formatted prompt
# Each value in "content" has to be a list of dicts with types ("text", "image")
conversation = [
{
"role": "user",
"content": [
{"type": "text", "text": "What is shown in this image?"},
{"type": "image"},
],
},
]
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
inputs = processor(images=image, text=prompt, return_tensors="pt").to("cuda")
# autoregressively complete prompt
output = model.generate(**inputs, max_new_tokens=100)
print(processor.decode(output[0], skip_special_tokens=True))
# LLaVA-NeXT: Tackling Multi-image, Video, and 3D in Large Multimodal Models
## 论文
`LLaVA-NeXT-Interleave: Tackling Multi-image, Video, and 3D
in Large Multimodal Models`
* https://arxiv.org/pdf/2407.07895
## 模型结构
参考[README.md](../README.md)
## 算法原理
[README.md](../README.md)基础上,LLaVA-NeXT-Interleave的核心是通过统一的数据格式和联合训练策略,实现多模态任务的泛化与迁移。
## 数据集
## 训练
## 推理
在 inference分支中
### 原生
```bash
python ../playground/demo/interleave_demo.py --model_path path/to/ckpt
```
### hf
```bash
python inference_hf.py
```
注意:运行前需要修改脚本中相应路径。
## result
![alt text](readme_imgs/result.png)
### 精度
## 应用场景
参考[README.md](../README.md)
## 预训练权重
|model|url|
|:---:|:---:|
|llava-next-interleave-qwen-7b|[hf](https://huggingface.co/lmms-lab/llava-next-interleave-qwen-7b) \| [SCNet]() |
|llava-next-interleave-qwen-0.5b|[hf](https://hf-mirror.com/lmms-lab/llava-next-interleave-qwen-0.5b) \| [SCNet]() |
|llava-interleave-qwen-0.5b-hf|[hf](https://huggingface.co/llava-hf/llava-interleave-qwen-0.5b-hf) \| [SCNet]() |
|llava-interleave-qwen-7b-hf|[hf](https://huggingface.co/llava-hf/llava-interleave-qwen-7b-hf) \| [SCNet]() |
|llava-interleave-qwen-7b-dpo-hf|[hf](https://huggingface.co/llava-hf/llava-interleave-qwen-7b-dpo-hf) \| [SCNet]() |
模型下载后保存至`ckpts`(需自行创建).
## 源码仓库及问题反馈
参考[README.md](../README.md)
## 参考资料
* https://github.com/LLaVA-VL/LLaVA-NeXT/blob/main/docs/LLaVA-NeXT-Interleave.md
* https://llava-vl.github.io/blog/2024-06-16-llava-next-interleave/
import requests
from PIL import Image
import torch
from transformers import AutoProcessor, LlavaForConditionalGeneration
from pathlib import Path
import os
current_dir = str(Path(__file__).resolve().parent)
model_id = os.path.join(current_dir, "ckpts", "llava-interleave-qwen-0.5b-hf")
model = LlavaForConditionalGeneration.from_pretrained(
model_id,
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
).to(0)
processor = AutoProcessor.from_pretrained(model_id)
# Define a chat history and use `apply_chat_template` to get correctly formatted prompt
# Each value in "content" has to be a list of dicts with types ("text", "image")
conversation = [
{
"role": "user",
"content": [
{"type": "text", "text": "What are these?"},
{"type": "image"},
],
},
]
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
# image_file = "http://images.cocodataset.org/val2017/000000039769.jpg"
# raw_image = Image.open(requests.get(image_file, stream=True).raw)
raw_image = Image.open("./examples/cat.jpg")
inputs = processor(images=raw_image, text=prompt, return_tensors='pt').to(0, torch.float16)
output = model.generate(**inputs, max_new_tokens=200, do_sample=False)
print(processor.decode(output[0][2:], skip_special_tokens=True))
from .model import LlavaLlamaForCausalLM
CONTROLLER_HEART_BEAT_EXPIRATION = 30
WORKER_HEART_BEAT_INTERVAL = 15
LOGDIR = "."
# Model Constants
IGNORE_INDEX = -100
IMAGE_TOKEN_INDEX = -200
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"
import dataclasses
from enum import auto, Enum
from typing import List, Any, Dict, Union, Tuple
import re
import base64
from io import BytesIO
from PIL import Image
from transformers import AutoTokenizer
class SeparatorStyle(Enum):
"""Different separator style."""
SINGLE = auto()
TWO = auto()
MPT = auto()
PLAIN = auto()
CHATML = auto()
LLAMA_2 = auto()
LLAMA_3 = auto()
QWEN = auto()
GEMMA = auto()
@dataclasses.dataclass
class Conversation:
"""A class that keeps all conversation history."""
system: str
roles: List[str]
messages: List[List[str]]
offset: int
sep_style: SeparatorStyle = SeparatorStyle.SINGLE
sep: str = "###"
sep2: str = None
version: str = "Unknown"
tokenizer_id: str = ""
tokenizer: Any = None
# Stop criteria (the default one is EOS token)
stop_str: Union[str, List[str]] = None
# Stops generation if meeting any token in this list
stop_token_ids: List[int] = None
skip_next: bool = False
def get_prompt(self):
messages = self.messages
if len(messages) > 0 and type(messages[0][1]) is tuple:
messages = self.messages.copy()
init_role, init_msg = messages[0].copy()
init_msg = init_msg[0]
if "mmtag" in self.version:
init_msg = init_msg.replace("<image>", "").strip()
messages[0] = (init_role, init_msg)
messages.insert(0, (self.roles[0], "<Image><image></Image>"))
messages.insert(1, (self.roles[1], "Received."))
elif not init_msg.startswith("<image>"):
init_msg = init_msg.replace("<image>", "").strip()
messages[0] = (init_role, "<image>\n" + init_msg)
else:
messages[0] = (init_role, init_msg)
if self.sep_style == SeparatorStyle.SINGLE:
ret = self.system + self.sep
for role, message in messages:
if message:
if type(message) is tuple:
message, _, _ = message
ret += role + ": " + message + self.sep
else:
ret += role + ":"
elif self.sep_style == SeparatorStyle.TWO:
seps = [self.sep, self.sep2]
ret = self.system + seps[0]
for i, (role, message) in enumerate(messages):
if message:
if type(message) is tuple:
message, _, _ = message
ret += role + ": " + message + seps[i % 2]
else:
ret += role + ":"
elif self.sep_style == SeparatorStyle.CHATML:
ret = "" if self.system == "" else self.system + self.sep + "\n"
for role, message in messages:
if message:
if type(message) is tuple:
message, images, _ = message
message = "<image>" * len(images) + message
ret += role + "\n" + message + self.sep + "\n"
else:
ret += role + "\n"
return ret
elif self.sep_style == SeparatorStyle.LLAMA_3:
if self.tokenizer is None:
raise ValueError("Llama 3 tokenizer is not available. Make sure you have the necessary permissions.")
chat_template_messages = [{"role": "system", "content": self.system}]
for role, message in messages:
if message:
if type(message) is tuple:
message, images = message
message = "<image>" * len(images) + message
chat_template_messages.append({"role": role, "content": message})
# print(chat_template_messages)
return self.tokenizer.apply_chat_template(chat_template_messages, tokenize=False, add_generation_prompt=True)
# ret = "" if self.system == "" else self.system + self.sep + "\n"
# for role, message in messages:
# if message:
# if type(message) is tuple:
# message, images = message
# message = "<image>" * len(images) + message
# ret += role + "\n" + message + self.sep + "\n"
# else:
# ret += role + "\n"
# return ret
elif self.sep_style == SeparatorStyle.MPT:
ret = self.system + self.sep
for role, message in messages:
if message:
if type(message) is tuple:
message, _, _ = message
ret += role + message + self.sep
else:
ret += role
elif self.sep_style == SeparatorStyle.GEMMA:
ret = ""
for i, (role, message) in enumerate(messages):
assert role == self.roles[i % 2], "Conversation should alternate user/assistant/user/assistant/..."
if message:
if type(message) is tuple:
message, _, _ = message
ret += role + message + self.sep
else:
ret += role
elif self.sep_style == SeparatorStyle.LLAMA_2:
wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n" if len(msg) > 0 else msg
wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
ret = ""
for i, (role, message) in enumerate(messages):
if i == 0:
assert message, "first message should not be none"
assert role == self.roles[0], "first message should come from user"
if message:
if type(message) is tuple:
message, _, _ = message
if i == 0:
message = wrap_sys(self.system) + message
if i % 2 == 0:
message = wrap_inst(message)
ret += self.sep + message
else:
ret += " " + message + " " + self.sep2
else:
ret += ""
ret = ret.lstrip(self.sep)
elif self.sep_style == SeparatorStyle.PLAIN:
seps = [self.sep, self.sep2]
ret = self.system
for i, (role, message) in enumerate(messages):
if message:
if type(message) is tuple:
message, _, _ = message
ret += message + seps[i % 2]
else:
ret += ""
else:
raise ValueError(f"Invalid style: {self.sep_style}")
return ret
def append_message(self, role, message):
self.messages.append([role, message])
def process_image(self, image, image_process_mode, return_pil=False, image_format="PNG"):
if image_process_mode == "Pad":
def expand2square(pil_img, background_color=(122, 116, 104)):
width, height = pil_img.size
if width == height:
return pil_img
elif width > height:
result = Image.new(pil_img.mode, (width, width), background_color)
result.paste(pil_img, (0, (width - height) // 2))
return result
else:
result = Image.new(pil_img.mode, (height, height), background_color)
result.paste(pil_img, ((height - width) // 2, 0))
return result
image = expand2square(image)
elif image_process_mode in ["Default", "Crop"]:
pass
elif image_process_mode == "Resize":
image = image.resize((336, 336))
else:
raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
if type(image) is not Image.Image:
image = Image.open(image).convert("RGB")
max_hw, min_hw = max(image.size), min(image.size)
aspect_ratio = max_hw / min_hw
max_len, min_len = 672, 448
shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
longest_edge = int(shortest_edge * aspect_ratio)
W, H = image.size
if H > W:
H, W = longest_edge, shortest_edge
else:
H, W = shortest_edge, longest_edge
image = image.resize((W, H))
if return_pil:
return image
else:
buffered = BytesIO()
image.save(buffered, format=image_format)
img_b64_str = base64.b64encode(buffered.getvalue()).decode()
return img_b64_str
def get_images(self, return_pil=False, return_path=False):
images = []
for i, (role, msg) in enumerate(self.messages[self.offset :]):
if i % 2 == 0:
if type(msg) is tuple:
msg, image, image_process_mode = msg
if type(image) != list:
image = [image]
for img in image:
if not return_path and self.is_image_file(img):
img = self.process_image(img, image_process_mode, return_pil=return_pil)
else:
images.append(img)
return images
def is_image_file(self, filename):
image_extensions = [".png", ".jpg", ".jpeg", ".gif", ".bmp", ".tiff", ".webp"]
return any(filename.lower().endswith(ext) for ext in image_extensions)
def is_video_file(self, filename):
video_extensions = [".mp4", ".mov", ".avi", ".mkv", ".wmv", ".flv", ".mpeg", ".mpg"]
return any(filename.lower().endswith(ext) for ext in video_extensions)
def to_gradio_chatbot(self):
ret = []
for i, (role, msg) in enumerate(self.messages[self.offset :]):
if i % 2 == 0:
if type(msg) is tuple:
msg, image, image_process_mode = msg
if type(image) != list:
image = [image]
if len(image) == 1:
msg = "<image>\n" + msg.replace("<image>", "").strip()
else:
msg = re.sub(r"(<image>)\n(?=<image>)", r"\1 ", msg)
img_str_list = []
for img in image:
if self.is_image_file(img):
img_b64_str = self.process_image(img, "Default", return_pil=False, image_format="JPEG")
img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}" style="max-width: 256px; max-height: 256px; width: auto; height: auto; object-fit: contain;"/>'
img_str_list.append(img_str)
elif self.is_video_file(img):
ret.append(((img,), None))
msg = msg.strip()
img_place_holder = ""
for img_str in img_str_list:
img_place_holder += f"{img_str}\n\n"
if len(img_str_list) > 0:
msg = f"{img_place_holder}\n\n{msg}"
if len(msg) > 0:
ret.append([msg, None])
else:
ret.append([msg, None])
else:
ret[-1][-1] = msg
return ret
def copy(self):
return Conversation(system=self.system, roles=self.roles, messages=[[x, y] for x, y in self.messages], offset=self.offset, sep_style=self.sep_style, sep=self.sep, sep2=self.sep2, version=self.version)
def dict(self):
if len(self.get_images()) > 0:
return {
"system": self.system,
"roles": self.roles,
"messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
"offset": self.offset,
"sep": self.sep,
"sep2": self.sep2,
}
return {
"system": self.system,
"roles": self.roles,
"messages": self.messages,
"offset": self.offset,
"sep": self.sep,
"sep2": self.sep2,
}
conv_vicuna_v0 = Conversation(
system="A chat between a curious human and an artificial intelligence assistant. " "The assistant gives helpful, detailed, and polite answers to the human's questions.",
roles=("Human", "Assistant"),
messages=[
["Human", "What are the key differences between renewable and non-renewable energy sources?"],
[
"Assistant",
"Renewable energy sources are those that can be replenished naturally in a relatively "
"short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
"Non-renewable energy sources, on the other hand, are finite and will eventually be "
"depleted, such as coal, oil, and natural gas. Here are some key differences between "
"renewable and non-renewable energy sources:\n"
"1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
"energy sources are finite and will eventually run out.\n"
"2. Environmental impact: Renewable energy sources have a much lower environmental impact "
"than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
"and other negative effects.\n"
"3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
"have lower operational costs than non-renewable sources.\n"
"4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
"locations than non-renewable sources.\n"
"5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
"situations and needs, while non-renewable sources are more rigid and inflexible.\n"
"6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
"non-renewable sources are not, and their depletion can lead to economic and social instability.\n",
],
],
offset=2,
sep_style=SeparatorStyle.SINGLE,
sep="###",
)
conv_vicuna_v1 = Conversation(
system="A chat between a curious user and an artificial intelligence assistant. " "The assistant gives helpful, detailed, and polite answers to the user's questions.",
roles=("USER", "ASSISTANT"),
version="v1",
messages=[],
offset=0,
sep_style=SeparatorStyle.TWO,
sep=" ",
sep2="</s>",
)
conv_llama_2 = Conversation(
system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
roles=("USER", "ASSISTANT"),
version="llama_v2",
messages=[],
offset=0,
sep_style=SeparatorStyle.LLAMA_2,
sep="<s>",
sep2="</s>",
)
conv_llava_llama_2 = Conversation(
system="You are a helpful language and vision assistant. " "You are able to understand the visual content that the user provides, " "and assist the user with a variety of tasks using natural language.",
roles=("USER", "ASSISTANT"),
version="llama_v2",
messages=[],
offset=0,
sep_style=SeparatorStyle.LLAMA_2,
sep="<s>",
sep2="</s>",
)
def safe_load_tokenizer(tokenizer_id):
try:
return AutoTokenizer.from_pretrained(tokenizer_id)
except Exception:
return None
conv_llava_llama_3 = Conversation(
system="You are a helpful language and vision assistant. " "You are able to understand the visual content that the user provides, " "and assist the user with a variety of tasks using natural language.",
roles=("user", "assistant"),
version="llama_v3",
messages=[],
offset=0,
sep="<|eot_id|>",
sep_style=SeparatorStyle.LLAMA_3,
tokenizer_id="meta-llama/Meta-Llama-3-8B-Instruct",
tokenizer=safe_load_tokenizer("meta-llama/Meta-Llama-3-8B-Instruct"),
stop_token_ids=[128009],
)
conv_mistral_instruct = Conversation(
system="",
roles=("USER", "ASSISTANT"),
version="llama_v2",
messages=[],
offset=0,
sep_style=SeparatorStyle.LLAMA_2,
sep="",
sep2="</s>",
)
conv_llava_llama_2_simple = Conversation(
system="Answer the questions about the visual content that the user provides.",
roles=("USER", "ASSISTANT"),
version="llama_v2",
messages=[],
offset=0,
sep_style=SeparatorStyle.LLAMA_2,
sep="<s>",
sep2="</s>",
)
conv_llava_llama_2_mmtag = Conversation(
system="Answer the questions about the visual content that the user provides." "The visual content will be provided with the following format: <Image>visual content</Image>.",
roles=("USER", "ASSISTANT"),
version="llama_v2_mmtag",
messages=[],
offset=0,
sep_style=SeparatorStyle.LLAMA_2,
sep="<s>",
sep2="</s>",
)
conv_mpt = Conversation(
system="""<|im_start|>system
A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
version="mpt",
messages=[],
offset=0,
sep_style=SeparatorStyle.MPT,
sep="<|im_end|>",
)
conv_qwen = Conversation(
system="""<|im_start|>system
You are a helpful assistant.""",
roles=("<|im_start|>user", "<|im_start|>assistant"),
version="qwen",
messages=[],
offset=0,
sep_style=SeparatorStyle.CHATML,
sep="<|im_end|>",
)
conv_gemma_instruct = Conversation(system="", roles=("<start_of_turn>user\n", "<start_of_turn>model\n"), version="gemma", messages=[], offset=0, sep_style=SeparatorStyle.GEMMA, sep="<end_of_turn>\n")
conv_llava_plain = Conversation(
system="",
roles=("", ""),
messages=[],
offset=0,
sep_style=SeparatorStyle.PLAIN,
sep="\n",
)
conv_llava_v0 = Conversation(
system="A chat between a curious human and an artificial intelligence assistant. " "The assistant gives helpful, detailed, and polite answers to the human's questions.",
roles=("Human", "Assistant"),
messages=[],
offset=0,
sep_style=SeparatorStyle.SINGLE,
sep="###",
)
conv_llava_v0_mmtag = Conversation(
system="A chat between a curious user and an artificial intelligence assistant. "
"The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
"The visual content will be provided with the following format: <Image>visual content</Image>.",
roles=("Human", "Assistant"),
messages=[],
offset=0,
sep_style=SeparatorStyle.SINGLE,
sep="###",
version="v0_mmtag",
)
conv_llava_v1 = Conversation(
system="A chat between a curious human and an artificial intelligence assistant. " "The assistant gives helpful, detailed, and polite answers to the human's questions.",
roles=("USER", "ASSISTANT"),
version="v1",
messages=[],
offset=0,
sep_style=SeparatorStyle.TWO,
sep=" ",
sep2="</s>",
)
conv_llava_v1_mmtag = Conversation(
system="A chat between a curious user and an artificial intelligence assistant. "
"The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
"The visual content will be provided with the following format: <Image>visual content</Image>.",
roles=("USER", "ASSISTANT"),
messages=[],
offset=0,
sep_style=SeparatorStyle.TWO,
sep=" ",
sep2="</s>",
version="v1_mmtag",
)
conv_mistral_orca = Conversation(
system="""<|im_start|>system
You are MistralOrca, a large language model trained by Alignment Lab AI. Write out your reasoning step-by-step to be sure you get the right answers!""",
roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
version="mpt",
messages=[],
offset=0,
sep_style=SeparatorStyle.MPT,
sep="<|im_end|>",
)
conv_mistral_zephyr = Conversation(
system="""<|system|>
You are a helpful AI assistant.""",
roles=("<|user|>\n", "<|assistant|>\n"),
version="mpt",
messages=[],
offset=0,
sep_style=SeparatorStyle.MPT,
sep="</s>",
)
conv_mistral_direct = Conversation(
system="""<|im_start|>system
Answer the questions.""",
roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
version="mpt",
messages=[],
offset=0,
sep_style=SeparatorStyle.MPT,
sep="<|im_end|>",
)
conv_chatml_direct = Conversation(
system="""<|im_start|>system
Answer the questions.""",
roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
version="mpt",
messages=[],
offset=0,
sep_style=SeparatorStyle.MPT,
sep="<|im_end|>",
)
default_conversation = conv_vicuna_v0
conv_templates = {
"default": conv_vicuna_v0,
"v0": conv_vicuna_v0,
"v1": conv_vicuna_v1,
"vicuna_v1": conv_vicuna_v1,
"llama_2": conv_llama_2,
"mistral_instruct": conv_mistral_instruct,
"mistral_orca": conv_mistral_orca,
"mistral_zephyr": conv_mistral_zephyr,
"mistral_direct": conv_mistral_direct,
"plain": conv_llava_plain,
"v0_plain": conv_llava_plain,
"chatml_direct": conv_chatml_direct,
"llava_v0": conv_llava_v0,
"llava_v0_mmtag": conv_llava_v0_mmtag,
"llava_v1": conv_llava_v1,
"llava_v1_mmtag": conv_llava_v1_mmtag,
"llava_llama_2": conv_llava_llama_2,
"llava_llama_3": conv_llava_llama_3,
"llava_llama_2_simple": conv_llava_llama_2_simple,
"llava_llama_2_mmtag": conv_llava_llama_2_mmtag,
"llava_mistral_instruct": conv_mistral_instruct,
"mpt": conv_mpt,
"qwen_1_5": conv_qwen,
"qwen_2": conv_qwen,
"gemma_instruct": conv_gemma_instruct,
}
if __name__ == "__main__":
print(default_conversation.get_prompt())
import re
from rouge import Rouge
import argparse
import os
import json
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
spot_the_diff = ["Spot-the-Diff", "Birds-to-Words", "CLEVR-Change"]
image_edit_instruct = ["IEdit", "HQ-Edit", "MagicBrush"]
visual_story_telling = ["AESOP", "FlintstonesSV", "PororoSV", "VIST"]
visual_cloze = ["COMICS_Dialogue", "RecipeQA_VisualCloze"]
text_rich_vqa = ["WebQA", "TQA", "OCR-VQA", "DocVQA"]
multi_image_vqa = ["MIT-States_StateCoherence", "MIT-States_PropertyCoherence", "VISION", "RecipeQA_ImageCoherence"]
puzzle = ["RAVEN"]
nlrv2 = ["NLVR2_Mantis"]
qbench = ["QBench"]
class Eval:
def __init__(self):
self.periodStrip = re.compile("(?!<=\d)(\.)(?!\d)")
self.commaStrip = re.compile("(\d)(\,)(\d)")
self.punct = [
";",
r"/",
"[",
"]",
'"',
"{",
"}",
"(",
")",
"=",
"+",
"\\",
"_",
"-",
">",
"<",
"@",
"`",
",",
"?",
"!",
]
def processPunctuation(self, inText):
outText = inText
for p in self.punct:
if (p + " " in inText or " " + p in inText) or (
re.search(self.commaStrip, inText) != None
):
outText = outText.replace(p, "")
else:
outText = outText.replace(p, " ")
outText = self.periodStrip.sub("", outText, re.UNICODE)
return outText
def process(self, answer):
answer = answer.replace("\n", " ")
answer = answer.replace("\t", " ")
answer = answer.strip()
answer = self.processPunctuation(answer)
answer = answer.strip('\'')
answer = answer.strip('\"')
answer = answer.strip(')')
answer = answer.strip('(')
answer = answer.strip().lower()
return answer
def evaluate_rouge(self,preds):
rouge = Rouge()
acc = {'f': []}
eval_list = []
for i, res in enumerate(preds):
sample_id = res['sample_id']
# print(sample_id)
gt_ans = self.process(res["gt_response"])
pred_ans = self.process(res["pred_response"])
# assert gt_ans != ''
if gt_ans == '':
continue
if pred_ans == '':
s = 0
else:
if len(pred_ans) > 512:
pred_ans = pred_ans[0: 512]
s = rouge.get_scores(pred_ans, gt_ans)[0]['rouge-l']['f']
acc['f'].append(s)
eval_list.append({'id':str(sample_id),'score':str(round(s,3))})
results = {'Rouge-L f': np.mean(acc['f'])}
return results,eval_list
def judge_multi_choice(self,sample):
sample_id = sample['sample_id']
gt_ans = sample["gt_response"]
pred_ans = sample["pred_response"]
if ":" in pred_ans:
a_list = pred_ans.split(":")
a_list = [a.strip() for a in a_list ]
for a in a_list:
if len(a) == 1 and a[-1] in ["a", "b", "c", "d", "e", "f", "g", "h"]:
pred_ans = a
if pred_ans == gt_ans:
return 1
else:
return 0
def process_sample(self,sample):
sample["gt_response"] = self.process(sample["gt_response"])
sample["pred_response"] = self.process(sample["pred_response"])
def evaluate_multichoice(self, preditions):
correct = 0
eval_list = []
for i, sample in enumerate(preditions):
self.process_sample(sample)
score = self.judge_multi_choice(sample)
sample_id = sample['sample_id']
sample['result'] = score
eval_list.append({'id':str(sample_id),'score':str(score)})
correct+=score
return {'Accuracy':correct/len(preditions)},eval_list
def evaluate_multi_choice_image(self,preditions):
correct = 0
eval_list = []
for i,sample in enumerate(preditions):
gt_ans = self.process(sample["gt_response"])
pred_ans = self.process(sample["pred_response"])
sample_id = sample['sample_id']
if ":" in pred_ans:
a_list = pred_ans.split(":")
a_list = [a.strip() for a in a_list ]
for a in a_list:
if len(a) == 1 and a[-1] in ["a", "b", "c", "d", "e", "f", "g", "h"]:
pred_ans = a
if gt_ans == pred_ans:
score = 1
else:
score = 0
sample_id = sample['sample_id']
sample['result'] = score
eval_list.append({'id':str(sample_id),'score':str(score)})
correct+=score
return {'Accuracy':correct/len(preditions)},eval_list
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--result-dir', type=str, required=True)
args = parser.parse_args()
result_file = os.path.join(args.result_dir, "result.jsonl")
if not os.path.exists(result_file):
print('No prediction file found')
exit(0)
with open(result_file, 'r') as f:
preds_all = [json.loads(line) for line in f]
preds_all_dict = dict()
for pred in preds_all:
if pred["dataset"] not in preds_all_dict:
preds_all_dict[pred["dataset"]] = list()
preds_all_dict[pred["dataset"]].append(pred)
image_choice_dataset_list = ["recipeqa-RecipeQA_VisualCloze", "RecipeQA_ImageCoherence", "COMICS_Panel"]
E = Eval()
eval_result_list = dict()
eval_result_list_detail = dict()
for dataset in preds_all_dict:
preds = preds_all_dict[dataset]
question_type = preds[0]["question_type"]
if question_type == 'open-ended':
eval_result, eval_list = E.evaluate_rouge(preds)
elif question_type == 'multi-choice' or dataset == 'nlrv2':
if dataset in image_choice_dataset_list:
eval_result, eval_list = E.evaluate_multi_choice_image(preds)
else:
eval_result, eval_list = E.evaluate_multichoice(preds)
else:
eval_result = 'Dataset not supported'
print('Dataset not supported')
exit(0)
print(dataset, end = ': ')
print(eval_result)
eval_result_list[dataset] = eval_result
eval_result_list_detail[dataset] = eval_list
os.makedirs(args.result_dir, exist_ok=True)
with open(os.path.join(args.result_dir, 'eval_dataset.json'), 'w') as f:
json.dump(eval_result_list, f, indent=4)
with open(os.path.join(args.result_dir,'eval_dataset_details.json'), 'w') as f:
json.dump(eval_result_list_detail, f, indent=4)
eval_cat_list = dict()
print()
# spot_the_diff
score = 0
count = 0
for dataset in eval_result_list:
if dataset in spot_the_diff:
count += 1
score += list(eval_result_list[dataset].values())[0]
if count > 0:
score /= count
eval_cat_list["spot_the_diff"] = score
print("spot_the_diff", end = ': ')
print('{:.2f}'.format(100 * score))
# image_edit_instruct
score = 0
count = 0
for dataset in eval_result_list:
if dataset in image_edit_instruct:
count += 1
score += list(eval_result_list[dataset].values())[0]
if count > 0:
score /= count
eval_cat_list["image_edit_instruct"] = score
print("image_edit_instruct", end = ': ')
print('{:.2f}'.format(100 * score))
# visual_story_telling
score = 0
count = 0
for dataset in eval_result_list:
if dataset in visual_story_telling:
count += 1
score += list(eval_result_list[dataset].values())[0]
if count > 0:
score /= count
eval_cat_list["visual_story_telling"] = score
print("visual_story_telling", end = ': ')
print('{:.2f}'.format(100 * score))
# visual_cloze
score = 0
count = 0
for dataset in eval_result_list:
if dataset in visual_cloze:
count += 1
score += list(eval_result_list[dataset].values())[0]
if count > 0:
score /= count
eval_cat_list["visual_cloze"] = score
print("visual_cloze", end = ': ')
print('{:.2f}'.format(100 * score))
# text_rich_vqa
score = 0
count = 0
for dataset in eval_result_list:
if dataset in text_rich_vqa:
count += 1
score += list(eval_result_list[dataset].values())[0]
if count > 0:
score /= count
eval_cat_list["text_rich_vqa"] = score
print("text_rich_vqa", end = ': ')
print('{:.2f}'.format(100 * score))
# multi_image_vqa
score = 0
count = 0
for dataset in eval_result_list:
if dataset in multi_image_vqa:
count += 1
score += list(eval_result_list[dataset].values())[0]
if count > 0:
score /= count
eval_cat_list["multi_image_vqa"] = score
print("multi_image_vqa", end = ': ')
print('{:.2f}'.format(100 * score))
# puzzle
score = 0
count = 0
for dataset in eval_result_list:
if dataset in puzzle:
count += 1
score += list(eval_result_list[dataset].values())[0]
if count > 0:
score /= count
eval_cat_list["puzzle"] = score
print("puzzle", end = ': ')
print('{:.2f}'.format(100 * score))
# nlrv2
score = 0
count = 0
for dataset in eval_result_list:
if dataset in nlrv2:
count += 1
score += list(eval_result_list[dataset].values())[0]
if count > 0:
score /= count
eval_cat_list["nlrv2"] = score
print("nlrv2", end = ': ')
print('{:.2f}'.format(100 * score))
# qbench
score = 0
count = 0
for dataset in eval_result_list:
if dataset in qbench:
count += 1
score += list(eval_result_list[dataset].values())[0]
if count > 0:
score /= count
eval_cat_list["qbench"] = score
print("qbench", end = ': ')
print('{:.2f}'.format(100 * score))
with open(os.path.join(args.result_dir,'eval_cat.json'), 'w') as f:
json.dump(eval_cat_list, f, indent=4)
\ No newline at end of file
import argparse
import torch
import os
import json
from tqdm import tqdm
import shortuuid
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
from llava.constants import IGNORE_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IMAGE_TOKEN_INDEX
from typing import Dict, Optional, Sequence, List
import transformers
import re
from PIL import Image
import math
def split_list(lst, n):
"""Split a list into n (roughly) equal-sized chunks"""
chunk_size = math.ceil(len(lst) / n) # integer division
return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]
def get_chunk(lst, n, k):
chunks = split_list(lst, n)
return chunks[k]
def preprocess_qwen(sources, tokenizer: transformers.PreTrainedTokenizer, has_image: bool = False, max_len=2048, system_message: str = "You are a helpful assistant.") -> Dict:
roles = {"human": "<|im_start|>user", "gpt": "<|im_start|>assistant"}
im_start, im_end = tokenizer.additional_special_tokens_ids
nl_tokens = tokenizer("\n").input_ids
_system = tokenizer("system").input_ids + nl_tokens
_user = tokenizer("user").input_ids + nl_tokens
_assistant = tokenizer("assistant").input_ids + nl_tokens
# Apply prompt templates
input_ids, targets = [], []
source = sources
if roles[source[0]["from"]] != roles["human"]:
source = source[1:]
input_id, target = [], []
system = [im_start] + _system + tokenizer(system_message).input_ids + [im_end] + nl_tokens
input_id += system
target += [im_start] + [IGNORE_INDEX] * (len(system) - 3) + [im_end] + nl_tokens
assert len(input_id) == len(target)
for j, sentence in enumerate(source):
role = roles[sentence["from"]]
if has_image and sentence["value"] is not None and "<image>" in sentence["value"]:
num_image = len(re.findall(DEFAULT_IMAGE_TOKEN, sentence["value"]))
texts = sentence["value"].split('<image>')
_input_id = tokenizer(role).input_ids + nl_tokens
for i,text in enumerate(texts):
_input_id += tokenizer(text).input_ids
if i<len(texts)-1:
_input_id += [IMAGE_TOKEN_INDEX] + nl_tokens
_input_id += [im_end] + nl_tokens
assert sum([i==IMAGE_TOKEN_INDEX for i in _input_id])==num_image
else:
if sentence["value"] is None:
_input_id = tokenizer(role).input_ids + nl_tokens
else:
_input_id = tokenizer(role).input_ids + nl_tokens + tokenizer(sentence["value"]).input_ids + [im_end] + nl_tokens
input_id += _input_id
if role == "<|im_start|>user":
_target = [im_start] + [IGNORE_INDEX] * (len(_input_id) - 3) + [im_end] + nl_tokens
elif role == "<|im_start|>assistant":
_target = [im_start] + [IGNORE_INDEX] * len(tokenizer(role).input_ids) + _input_id[len(tokenizer(role).input_ids) + 1 : -2] + [im_end] + nl_tokens
else:
raise NotImplementedError
target += _target
input_ids.append(input_id)
targets.append(target)
input_ids = torch.tensor(input_ids, dtype=torch.long)
targets = torch.tensor(targets, dtype=torch.long)
return input_ids
def eval_model(args):
# Model
disable_torch_init()
model_path = os.path.expanduser(args.model_path)
model_name = get_model_name_from_path(model_path)
tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name)
# Data
with open(os.path.expanduser(args.question_file)) as f:
questions = json.load(f)
questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
answers_file = os.path.expanduser(args.answers_file)
os.makedirs(os.path.dirname(answers_file), exist_ok=True)
ans_file = open(answers_file, "w")
for line in tqdm(questions):
idx = line["sample_id"]
question_type = line["metadata"]["question_type"]
dataset_name = line["metadata"]["dataset"]
gt = line["conversations"][1]["value"]
image_files = line["image"]
qs = line["conversations"][0]["value"]
cur_prompt = args.extra_prompt + qs
args.conv_mode = "qwen_1_5"
conv = conv_templates[args.conv_mode].copy()
conv.append_message(conv.roles[0], qs)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
input_ids = preprocess_qwen([line["conversations"][0],{'from': 'gpt','value': None}], tokenizer, has_image=True).cuda()
img_num = list(input_ids.squeeze()).count(IMAGE_TOKEN_INDEX)
image_tensors = []
for image_file in image_files:
image = Image.open(os.path.join(args.image_folder, image_file))
image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values']
image_tensors.append(image_tensor.half().cuda())
# image_tensors = torch.cat(image_tensors, dim=0)
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
keywords = [stop_str]
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
with torch.inference_mode():
output_ids = model.generate(
input_ids,
images=image_tensors,
do_sample=True if args.temperature > 0 else False,
temperature=args.temperature,
top_p=args.top_p,
num_beams=args.num_beams,
# no_repeat_ngram_size=3,
max_new_tokens=1024,
use_cache=True)
outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
outputs = outputs.strip()
if outputs.endswith(stop_str):
outputs = outputs[:-len(stop_str)]
outputs = outputs.strip()
ans_id = shortuuid.uuid()
ans_file.write(json.dumps({
"dataset": dataset_name,
"sample_id": idx,
"prompt": cur_prompt,
"pred_response": outputs,
"gt_response": gt,
"shortuuid": ans_id,
"model_id": model_name,
"question_type": question_type,
}) + "\n")
ans_file.flush()
if len(line["conversations"]) > 2:
for i in range(2, len(line["conversations"]), 2):
input_ids = torch.cat((input_ids, output_ids), dim=1)
gt = line["conversations"][i + 1]["value"]
qs = line["conversations"][i]["value"]
cur_prompt = args.extra_prompt + qs
args.conv_mode = "qwen_1_5"
conv = conv_templates[args.conv_mode].copy()
conv.append_message(conv.roles[0], qs)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
input_ids_new = preprocess_qwen([line["conversations"][i],{'from': 'gpt','value': None}], tokenizer, has_image=True).cuda()
input_ids = torch.cat((input_ids, input_ids_new), dim=1)
img_num = list(input_ids_new.squeeze()).count(IMAGE_TOKEN_INDEX)
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
keywords = [stop_str]
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
with torch.inference_mode():
output_ids = model.generate(
input_ids,
images=image_tensors,
do_sample=True if args.temperature > 0 else False,
temperature=args.temperature,
top_p=args.top_p,
num_beams=args.num_beams,
# no_repeat_ngram_size=3,
max_new_tokens=1024,
use_cache=True)
outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
outputs = outputs.strip()
if outputs.endswith(stop_str):
outputs = outputs[:-len(stop_str)]
outputs = outputs.strip()
ans_id = shortuuid.uuid()
ans_file.write(json.dumps({
"dataset": dataset_name,
"sample_id": idx,
"prompt": cur_prompt,
"pred_response": outputs,
"gt_response": gt,
"shortuuid": ans_id,
"model_id": model_name,
"question_type": question_type,
}) + "\n")
ans_file.flush()
ans_file.close()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
parser.add_argument("--model-base", type=str, default=None)
parser.add_argument("--image-folder", type=str, default="")
parser.add_argument("--extra-prompt", type=str, default="")
parser.add_argument("--question-file", type=str, default="tables/question.jsonl")
parser.add_argument("--answers-file", type=str, default="answer.jsonl")
parser.add_argument("--conv-mode", type=str, default="llava_v1")
parser.add_argument("--num-chunks", type=int, default=1)
parser.add_argument("--chunk-idx", type=int, default=0)
parser.add_argument("--temperature", type=float, default=0.2)
parser.add_argument("--top_p", type=float, default=None)
parser.add_argument("--num_beams", type=int, default=1)
parser.add_argument("--test_size", type=int, default=10000000)
args = parser.parse_args()
eval_model(args)
\ No newline at end of file
from PIL import Image
from io import BytesIO
import base64
import math
import ast
import re
import torch
from transformers import StoppingCriteria
from llava.constants import IMAGE_TOKEN_INDEX
def resize_and_center_crop(image, shortest_edge_length):
# Calculate new dimensions and resize
aspect_ratio = float(image.width) / float(image.height)
if aspect_ratio > 1:
new_width = int(shortest_edge_length * aspect_ratio)
new_height = shortest_edge_length
else:
new_width = shortest_edge_length
new_height = int(shortest_edge_length / aspect_ratio)
resized_image = image.resize((new_width, new_height), Image.ANTIALIAS)
# Calculate the position and perform the center crop
left = (new_width - shortest_edge_length) / 2
top = (new_height - shortest_edge_length) / 2
right = (new_width + shortest_edge_length) / 2
bottom = (new_height + shortest_edge_length) / 2
cropped_image = resized_image.crop((left, top, right, bottom))
return cropped_image
def auto_pad_images(image, grid_params):
assert isinstance(image, Image.Image), "Input should be a Pillow Image"
assert len(grid_params) > 0, "Grid parameters should not be empty"
# Step 1: Calculate and find the closest aspect ratio
input_width, input_height = image.size
input_aspect_ratio = input_width / input_height
candidate_resolutions = [(w / h, w, h) for w in grid_params for h in grid_params]
closest_aspect_ratio = min(candidate_resolutions, key=lambda x: abs(input_aspect_ratio - x[0]))
candidate_resolutions = [(x[1], x[2]) for x in candidate_resolutions if abs(x[0] - closest_aspect_ratio[0]) < 1e-3]
target_resolution = min(candidate_resolutions, key=lambda res: abs(max(input_width, input_height) / max(res) - 1))
resize_width, resize_height = target_resolution
if input_width > input_height:
resize_height = int(resize_width / input_aspect_ratio)
else:
resize_width = int(resize_height * input_aspect_ratio)
resized_image = image.resize((resize_width, resize_height), Image.ANTIALIAS)
# Step 5: Pad the resized image if necessary to match the target resolution
pad_width = target_resolution[0] - resize_width
pad_height = target_resolution[1] - resize_height
padded_image = Image.new("RGB", target_resolution, color=(0, 0, 0))
padded_image.paste(resized_image, (pad_width // 2, pad_height // 2))
return padded_image
def extract_patches(image, patch_size, overlap_ratio):
assert isinstance(image, Image.Image), "Input should be a Pillow Image"
assert patch_size > 0, "Patch size should be greater than 0"
assert 0 <= overlap_ratio < 1, "Overlap ratio should be between 0 and 1"
W, H = image.size
patches = []
stride = int(patch_size * (1 - overlap_ratio))
num_patches_y = (H - patch_size) // stride + 1
num_patches_x = (W - patch_size) // stride + 1
y_start = (H - (num_patches_y - 1) * stride - patch_size) // 2
x_start = (W - (num_patches_x - 1) * stride - patch_size) // 2
for y in range(y_start, y_start + num_patches_y * stride, stride):
for x in range(x_start, x_start + num_patches_x * stride, stride):
patch = image.crop((x, y, x + patch_size, y + patch_size))
patches.append(patch)
return patches
def process_highres_image_crop_split(image, data_args, processor=None):
crop_resolution = data_args.image_crop_resolution
split_resolution = data_args.image_split_resolution
if processor is None:
processor = data_args.image_processor
image_crop = resize_and_center_crop(image, crop_resolution)
image_patches = extract_patches(image_crop, patch_size=split_resolution, overlap_ratio=0)
image_patches = [processor.preprocess(image_patch, return_tensors="pt")["pixel_values"][0] for image_patch in image_patches]
return torch.stack(image_patches, dim=0)
def process_highres_image(image, processor, grid_pinpoints):
grid_params = [int(x) for x in grid_pinpoints.split(",")]
width_height = max(image.size)
fit_grid_params = [x for x in grid_params if x >= width_height]
if len(fit_grid_params) == 0:
select_size = max(grid_params)
else:
select_size = min(fit_grid_params)
# FIXME: always select the 448
select_size = max(grid_params)
image_padded = expand2square(image, tuple(int(x * 255) for x in processor.image_mean))
# FIXME: this seems to be a bug that it always resizes instead of padding
image_original_resize = image.resize((processor.size["shortest_edge"], processor.size["shortest_edge"]))
image_padded = image_padded.resize((select_size, select_size))
image_patches = extract_patches(image_padded, patch_size=processor.size["shortest_edge"], overlap_ratio=0)
image_patches = [image_original_resize] + image_patches
image_patches = [processor.preprocess(image_patch, return_tensors="pt")["pixel_values"][0] for image_patch in image_patches]
return torch.stack(image_patches, dim=0)
def select_best_resolution(original_size, possible_resolutions):
"""
Selects the best resolution from a list of possible resolutions based on the original size.
Args:
original_size (tuple): The original size of the image in the format (width, height).
possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].
Returns:
tuple: The best fit resolution in the format (width, height).
"""
original_width, original_height = original_size
best_fit = None
max_effective_resolution = 0
min_wasted_resolution = float("inf")
for width, height in possible_resolutions:
# Calculate the downscaled size to keep the aspect ratio
scale = min(width / original_width, height / original_height)
downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
# Calculate effective and wasted resolutions
effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
wasted_resolution = (width * height) - effective_resolution
if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
max_effective_resolution = effective_resolution
min_wasted_resolution = wasted_resolution
best_fit = (width, height)
return best_fit
def resize_and_pad_image(image, target_resolution):
"""
Resize and pad an image to a target resolution while maintaining aspect ratio.
Args:
image (PIL.Image.Image): The input image.
target_resolution (tuple): The target resolution (width, height) of the image.
Returns:
PIL.Image.Image: The resized and padded image.
"""
original_width, original_height = image.size
target_width, target_height = target_resolution
# Determine which dimension (width or height) to fill
scale_w = target_width / original_width
scale_h = target_height / original_height
if scale_w < scale_h:
# Width will be filled completely
new_width = target_width
new_height = min(math.ceil(original_height * scale_w), target_height)
else:
# Height will be filled completely
new_height = target_height
new_width = min(math.ceil(original_width * scale_h), target_width)
# Resize the image
resized_image = image.resize((new_width, new_height))
# Create a new image with the target size and paste the resized image onto it
new_image = Image.new("RGB", (target_width, target_height), (0, 0, 0))
paste_x = (target_width - new_width) // 2
paste_y = (target_height - new_height) // 2
new_image.paste(resized_image, (paste_x, paste_y))
return new_image
def divide_to_patches(image, patch_size):
"""
Divides an image into patches of a specified size.
Args:
image (PIL.Image.Image): The input image.
patch_size (int): The size of each patch.
Returns:
list: A list of PIL.Image.Image objects representing the patches.
"""
patches = []
width, height = image.size
for i in range(0, height, patch_size):
for j in range(0, width, patch_size):
box = (j, i, j + patch_size, i + patch_size)
patch = image.crop(box)
patches.append(patch)
return patches
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
"""
Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
Args:
image_size (tuple): The size of the input image in the format (width, height).
grid_pinpoints (str): A string representation of a list of possible resolutions.
patch_size (int): The size of each image patch.
Returns:
tuple: The shape of the image patch grid in the format (width, height).
"""
if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints:
assert patch_size in [224, 336, 384, 448, 512], "patch_size should be in [224, 336, 384, 448, 512]"
# Use regex to extract the range from the input string
matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
range_start = tuple(map(int, matches[0]))
range_end = tuple(map(int, matches[-1]))
# Generate a matrix of tuples from (range_start[0], range_start[1]) to (range_end[0], range_end[1])
grid_pinpoints = [(i, j) for i in range(range_start[0], range_end[0] + 1) for j in range(range_start[1], range_end[1] + 1)]
# Multiply all elements by patch_size
grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints]
if type(grid_pinpoints) is list:
possible_resolutions = grid_pinpoints
else:
possible_resolutions = ast.literal_eval(grid_pinpoints)
width, height = select_best_resolution(image_size, possible_resolutions)
return width // patch_size, height // patch_size
def process_anyres_image(image, processor, grid_pinpoints):
"""
Process an image with variable resolutions.
Args:
image (PIL.Image.Image): The input image to be processed.
processor: The image processor object.
grid_pinpoints (str): A string representation of a list of possible resolutions.
Returns:
torch.Tensor: A tensor containing the processed image patches.
"""
# Convert grid_pinpoints from string to list
if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints:
try:
patch_size = processor.size[0]
except Exception as e:
patch_size = processor.size["shortest_edge"]
assert patch_size in [224, 336, 384, 448, 512], "patch_size should be in [224, 336, 384, 448, 512]"
# Use regex to extract the range from the input string
matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
range_start = tuple(map(int, matches[0]))
range_end = tuple(map(int, matches[-1]))
# Generate a matrix of tuples from (range_start[0], range_start[1]) to (range_end[0], range_end[1])
grid_pinpoints = [(i, j) for i in range(range_start[0], range_end[0] + 1) for j in range(range_start[1], range_end[1] + 1)]
# Multiply all elements by patch_size
grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints]
if type(grid_pinpoints) is list:
possible_resolutions = grid_pinpoints
else:
possible_resolutions = ast.literal_eval(grid_pinpoints)
best_resolution = select_best_resolution(image.size, possible_resolutions)
image_padded = resize_and_pad_image(image, best_resolution)
patches = divide_to_patches(image_padded, processor.crop_size["height"])
# FIXME: this seems to be a bug that it resizes instead of pad.
# but to keep it consistent with previous, i will keep it as it is
# TODO: uncomment below to ablate with the padding
if isinstance(processor.size, dict):
shortest_edge = processor.size["shortest_edge"]
else:
shortest_edge = min(processor.size)
image_original_resize = image.resize((shortest_edge, shortest_edge))
# image_padded_square = expand2square(image, tuple(int(x*255) for x in processor.image_mean))
# image_original_resize = image_padded_square.resize((processor.size['shortest_edge'], processor.size['shortest_edge']))
image_patches = [image_original_resize] + patches
image_patches = [processor.preprocess(image_patch, return_tensors="pt")["pixel_values"][0] for image_patch in image_patches]
return torch.stack(image_patches, dim=0)
def load_image_from_base64(image):
return Image.open(BytesIO(base64.b64decode(image)))
def expand2square(pil_img, background_color):
width, height = pil_img.size
if width == height:
return pil_img
elif width > height:
result = Image.new(pil_img.mode, (width, width), background_color)
result.paste(pil_img, (0, (width - height) // 2))
return result
else:
result = Image.new(pil_img.mode, (height, height), background_color)
result.paste(pil_img, ((height - width) // 2, 0))
return result
def process_images(images, image_processor, model_cfg):
image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
new_images = []
if image_aspect_ratio == "highres":
for image in images:
image = process_highres_image(image, image_processor, model_cfg.image_grid_pinpoints)
new_images.append(image)
elif image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio:
for image in images:
image = process_anyres_image(image, image_processor, model_cfg.image_grid_pinpoints)
new_images.append(image)
elif image_aspect_ratio == "crop_split":
for image in images:
image = process_highres_image_crop_split(image, model_cfg, image_processor)
new_images.append(image)
elif image_aspect_ratio == "pad":
for image in images:
image = expand2square(image, tuple(int(x * 255) for x in image_processor.image_mean))
image = image_processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
new_images.append(image)
else:
return image_processor.preprocess(images, return_tensors="pt")["pixel_values"]
if all(x.shape == new_images[0].shape for x in new_images):
new_images = torch.stack(new_images, dim=0)
return new_images
def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split("<image>")]
def insert_separator(X, sep):
return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1]
input_ids = []
offset = 0
if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
offset = 1
input_ids.append(prompt_chunks[0][0])
for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
input_ids.extend(x[offset:])
if return_tensors is not None:
if return_tensors == "pt":
return torch.tensor(input_ids, dtype=torch.long)
raise ValueError(f"Unsupported tensor type: {return_tensors}")
return input_ids
def get_model_name_from_path(model_path):
model_path = model_path.strip("/")
model_paths = model_path.split("/")
if model_paths[-1].startswith("checkpoint-"):
return model_paths[-2] + "_" + model_paths[-1]
else:
return model_paths[-1]
class KeywordsStoppingCriteria(StoppingCriteria):
def __init__(self, keywords, tokenizer, input_ids):
self.keywords = keywords
self.keyword_ids = []
for keyword in keywords:
cur_keyword_ids = tokenizer(keyword).input_ids
if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
cur_keyword_ids = cur_keyword_ids[1:]
self.keyword_ids.append(torch.tensor(cur_keyword_ids))
self.tokenizer = tokenizer
self.start_len = input_ids.shape[1]
def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
assert output_ids.shape[0] == 1, "Only support batch size 1 (yet)" # TODO
offset = min(output_ids.shape[1] - self.start_len, 3)
self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
for keyword_id in self.keyword_ids:
if output_ids[0, -keyword_id.shape[0] :] == keyword_id:
return True
outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
for keyword in self.keywords:
if keyword in outputs:
return True
return False
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment