Commit d5878167 authored by mashun1's avatar mashun1
Browse files

llava-next

parents
Pipeline #2589 failed with stages
in 0 seconds
icon.png

53.8 KB

# LLaVA-NeXT: Stronger LLMs Supercharge Multimodal Capabilities in the Wild
## 论文
`LLaVA-NeXT: Stronger LLMs Supercharge Multimodal Capabilities in the Wild`
* https://llava-vl.github.io/blog/2024-05-10-llava-next-stronger-llms/
## 模型结构
参考[README.md](../README.md)
## 算法原理
参考[README.md](../README.md)
## 数据集
## 训练
## 推理
```bash
python image.py
```
注意:在运行前需修改文件中的模型路径。
### 评估
```bash
export HF_ENDPOINT=https://hf-mirror.com
huggingface-cli login --token $HUGGINGFACE_TOKEN --add-to-git-credential
```
注意:此命令为自动下载评估数据所需。
```bash
accelerate launch --num_processes=8 \
-m lmms_eval \
--model llava \
--model_args pretrained=/path/to/llama3-llava-next-8b,conv_template=llava_llama_3 \
--tasks ai2d,chartqa,docvqa_val,mme,mmbench_en_dev \
--batch_size 1 \
--log_samples \
--log_samples_suffix llava_next \
--output_path ./logs/
```
```bash
accelerate launch --num_processes=1 \
-m lmms_eval \
--model llava \
--model_args pretrained=/path/to/llava-next-72b,conv_template=qwen_1_5,model_name=llava_qwen,device_map=auto \
--tasks ai2d,chartqa,docvqa_val,mme,mmbench_en_dev \
--batch_size 1 \
--log_samples \
--log_samples_suffix llava_next \
--output_path ./logs/
```
## result
![alt text](readme_imgs/multimodal-8b.png)
### 精度
## 应用场景
参考[README.md](../README.md)
## 预训练权重
|model|url|
|:---:|:---:|
|llama3-llava-next-8b|[hf](https://huggingface.co/lmms-lab/llama3-llava-next-8b) \| [SCNet]() |
|llava-next-qwen-32b|[hf](https://huggingface.co/lmms-lab/llava-next-qwen-32b) \| [SCNet]() |
模型下载后保存至`ckpts`(需自行创建).
## 源码仓库及问题反馈
参考[README.md](../README.md)
## 参考资料
* https://github.com/LLaVA-VL/LLaVA-NeXT/blob/main/docs/LLaVA-NeXT.md
* https://llava-vl.github.io/blog/2024-05-10-llava-next-stronger-llms/
\ No newline at end of file
from llava.model.builder import load_pretrained_model
from llava.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IGNORE_INDEX
from llava.conversation import conv_templates, SeparatorStyle
from PIL import Image
import requests
import copy
import torch
from pathlib import Path
import os
# Single-image inference demo for llama3-llava-next-8b using the native LLaVA loader.
# The checkpoint directory is resolved relative to this script, not the working directory.
current_dir = str(Path(__file__).resolve().parent)
pretrained = os.path.join(current_dir, "ckpts", "llama3-llava-next-8b")
model_name = "llava_llama3"
device = "cuda"
device_map = "auto"
tokenizer, model, image_processor, max_length = load_pretrained_model(pretrained, None, model_name, device_map=device_map) # Add any other thing you want to pass in llava_model_args
model.eval()
model.tie_weights()
# NOTE: unlike `pretrained`, this image path is relative to the current working directory.
image = Image.open("./examples/llava_v1_5_radar.jpg")
image_tensor = process_images([image], image_processor, model.config)
# Cast every processed image to float16 and move it to `device`.
image_tensor = [_image.to(dtype=torch.float16, device=device) for _image in image_tensor]
conv_template = "llava_llama_3" # Make sure you use correct chat template for different models
# The question starts with the image placeholder token followed by the text prompt.
question = DEFAULT_IMAGE_TOKEN + "\nWhat is shown in this image?"
# Deep-copy the shared template so appending messages does not modify the registry entry.
conv = copy.deepcopy(conv_templates[conv_template])
conv.tokenizer = tokenizer
conv.append_message(conv.roles[0], question)
# A None message leaves the second role's turn open for the model to complete.
conv.append_message(conv.roles[1], None)
prompt_question = conv.get_prompt()
# unsqueeze(0) adds the batch dimension before moving the ids to `device`.
input_ids = tokenizer_image_token(prompt_question, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(device)
# PIL's Image.size is (width, height).
image_sizes = [image.size]
# do_sample=False requests deterministic decoding, capped at 256 new tokens.
cont = model.generate(
    input_ids,
    images=image_tensor,
    image_sizes=image_sizes,
    do_sample=False,
    temperature=0,
    max_new_tokens=256,
)
text_outputs = tokenizer.batch_decode(cont, skip_special_tokens=True)
print(text_outputs)
\ No newline at end of file
import requests
from PIL import Image
import torch
from transformers import AutoProcessor, LlavaNextForConditionalGeneration
from pathlib import Path
import os
# Batched multi-image inference demo using the Hugging Face LlavaNext classes.
# The checkpoint directory is resolved relative to this script, not the working directory.
current_dir = str(Path(__file__).resolve().parent)
pretrained = os.path.join(current_dir, "ckpts", "llava-v1.6-mistral-7b-hf")
# Load the model in half-precision
model = LlavaNextForConditionalGeneration.from_pretrained(pretrained, torch_dtype=torch.float16, device_map="auto")
processor = AutoProcessor.from_pretrained(pretrained)
# Get three different images
# The commented-out lines are the remote downloads that the local example files replace.
# NOTE: the local image paths below are relative to the current working directory.
# url = "https://www.ilankelman.org/stopsigns/australia.jpg"
# image_stop = Image.open(requests.get(url, stream=True).raw)
image_stop = Image.open("./examples/image.png")
# url = "http://images.cocodataset.org/val2017/000000039769.jpg"
# image_cats = Image.open(requests.get(url, stream=True).raw)
image_cats = Image.open("./examples/cat.jpg")
# url = "https://hugging-face.cn/microsoft/kosmos-2-patch14-224/resolve/main/snowman.jpg"
# image_snowman = Image.open(requests.get(url, stream=True).raw)
image_snowman = Image.open("./examples/snowman.jpg")
# Prepare a batch of two prompts, where the first one is a multi-turn conversation and the second is not
# conversation_1 contains two image placeholders (one per user turn).
conversation_1 = [
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": "What is shown in this image?"},
        ],
    },
    {
        "role": "assistant",
        "content": [
            {"type": "text", "text": "There is a lake in the image."},
        ],
    },
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": "What about this image? How many cats do you see?"},
        ],
    },
]
# conversation_2 is a single user turn with one image placeholder.
conversation_2 = [
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": "What is shown in this image?"},
        ],
    },
]
prompt_1 = processor.apply_chat_template(conversation_1, add_generation_prompt=True)
prompt_2 = processor.apply_chat_template(conversation_2, add_generation_prompt=True)
prompts = [prompt_1, prompt_2]
# We can simply feed images in the order they have to be used in the text prompt
# Each "<image>" token uses one image leaving the next for the subsequent "<image>" tokens
# Three images are passed for the three placeholders (two in prompt_1, one in prompt_2).
# NOTE(review): padding=True is presumably needed because the two prompts differ in length — confirm.
inputs = processor(images=[image_stop, image_cats, image_snowman], text=prompts, padding=True, return_tensors="pt").to(model.device)
# Generate
generate_ids = model.generate(**inputs, max_new_tokens=30)
print(processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False))
\ No newline at end of file
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration
import torch
from PIL import Image
import requests
from pathlib import Path
import os
# Single-image inference demo using the Hugging Face LlavaNext classes.
# The checkpoint directory is resolved relative to this script, not the working directory.
current_dir = str(Path(__file__).resolve().parent)
pretrained = os.path.join(current_dir, "ckpts", "llava-v1.6-mistral-7b-hf")
processor = LlavaNextProcessor.from_pretrained(pretrained)
# Load the weights in float16, then move the whole model to the GPU.
model = LlavaNextForConditionalGeneration.from_pretrained(pretrained, torch_dtype=torch.float16, low_cpu_mem_usage=True)
model.to("cuda")
# prepare image and text prompt, using the appropriate prompt template
# The commented-out lines are the remote download that the local example file replaces.
# url = "https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_v1_5_radar.jpg?raw=true"
# image = Image.open(requests.get(url, stream=True).raw)
# NOTE: this image path is relative to the current working directory.
image = Image.open("./examples/image.png")
# Define a chat history and use `apply_chat_template` to get correctly formatted prompt
# Each value in "content" has to be a list of dicts with types ("text", "image")
conversation = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "What is shown in this image?"},
            {"type": "image"},
        ],
    },
]
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
inputs = processor(images=image, text=prompt, return_tensors="pt").to("cuda")
# autoregressively complete prompt
output = model.generate(**inputs, max_new_tokens=100)
# Decode the first (only) sequence of the batch.
print(processor.decode(output[0], skip_special_tokens=True))
# LLaVA-NeXT-Interleave: Tackling Multi-image, Video, and 3D in Large Multimodal Models
## 论文
`LLaVA-NeXT-Interleave: Tackling Multi-image, Video, and 3D in Large Multimodal Models`
* https://arxiv.org/pdf/2407.07895
## 模型结构
参考[README.md](../README.md)
## 算法原理
在[README.md](../README.md)的基础上,LLaVA-NeXT-Interleave的核心是通过统一的数据格式和联合训练策略,实现多模态任务的泛化与迁移。
## 数据集
## 训练
## 推理
在 inference分支中
### 原生
```bash
python ../playground/demo/interleave_demo.py --model_path path/to/ckpt
```
### hf
```bash
python inference_hf.py
```
注意:运行前需要修改脚本中相应路径。
## result
![alt text](readme_imgs/result.png)
### 精度
## 应用场景
参考[README.md](../README.md)
## 预训练权重
|model|url|
|:---:|:---:|
|llava-next-interleave-qwen-7b|[hf](https://huggingface.co/lmms-lab/llava-next-interleave-qwen-7b) \| [SCNet]() |
|llava-next-interleave-qwen-0.5b|[hf](https://hf-mirror.com/lmms-lab/llava-next-interleave-qwen-0.5b) \| [SCNet]() |
|llava-interleave-qwen-0.5b-hf|[hf](https://huggingface.co/llava-hf/llava-interleave-qwen-0.5b-hf) \| [SCNet]() |
|llava-interleave-qwen-7b-hf|[hf](https://huggingface.co/llava-hf/llava-interleave-qwen-7b-hf) \| [SCNet]() |
|llava-interleave-qwen-7b-dpo-hf|[hf](https://huggingface.co/llava-hf/llava-interleave-qwen-7b-dpo-hf) \| [SCNet]() |
模型下载后保存至`ckpts`(需自行创建).
## 源码仓库及问题反馈
参考[README.md](../README.md)
## 参考资料
* https://github.com/LLaVA-VL/LLaVA-NeXT/blob/main/docs/LLaVA-NeXT-Interleave.md
* https://llava-vl.github.io/blog/2024-06-16-llava-next-interleave/
import requests
from PIL import Image
import torch
from transformers import AutoProcessor, LlavaForConditionalGeneration
from pathlib import Path
import os
# Single-image inference demo for llava-interleave-qwen-0.5b-hf (Hugging Face classes).
# The checkpoint directory is resolved relative to this script, not the working directory.
current_dir = str(Path(__file__).resolve().parent)
model_id = os.path.join(current_dir, "ckpts", "llava-interleave-qwen-0.5b-hf")
# Load the weights in float16 and move the model to device index 0.
model = LlavaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
).to(0)
processor = AutoProcessor.from_pretrained(model_id)
# Define a chat history and use `apply_chat_template` to get correctly formatted prompt
# Each value in "content" has to be a list of dicts with types ("text", "image")
conversation = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "What are these?"},
            {"type": "image"},
        ],
    },
]
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
# The commented-out lines are the remote download that the local example file replaces.
# image_file = "http://images.cocodataset.org/val2017/000000039769.jpg"
# raw_image = Image.open(requests.get(image_file, stream=True).raw)
# NOTE: this image path is relative to the current working directory.
raw_image = Image.open("./examples/cat.jpg")
# Inputs are moved to device index 0 with torch.float16 requested as dtype.
inputs = processor(images=raw_image, text=prompt, return_tensors='pt').to(0, torch.float16)
output = model.generate(**inputs, max_new_tokens=200, do_sample=False)
# NOTE(review): the first two generated ids are dropped before decoding; presumably
# leading special tokens — confirm against the tokenizer.
print(processor.decode(output[0][2:], skip_special_tokens=True))
from .model import LlavaLlamaForCausalLM
# Heartbeat settings for the controller/worker setup.
# NOTE(review): units are presumably seconds — confirm against the code that reads them.
CONTROLLER_HEART_BEAT_EXPIRATION = 30
WORKER_HEART_BEAT_INTERVAL = 15
# NOTE(review): presumably the directory log files are written to ("." = working directory).
LOGDIR = "."
# Model Constants
# Value used to fill target positions that should be masked out.
IGNORE_INDEX = -100
# Placeholder id inserted into the input ids where an image belongs.
IMAGE_TOKEN_INDEX = -200
# Text placeholder that marks an image position inside a prompt.
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"
This diff is collapsed.
import re
from rouge import Rouge
import argparse
import os
import json
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
# Benchmark categories. Each list names the datasets whose per-dataset scores are
# averaged into one category score by the __main__ section of this script.
spot_the_diff = ["Spot-the-Diff", "Birds-to-Words", "CLEVR-Change"]
image_edit_instruct = ["IEdit", "HQ-Edit", "MagicBrush"]
visual_story_telling = ["AESOP", "FlintstonesSV", "PororoSV", "VIST"]
visual_cloze = ["COMICS_Dialogue", "RecipeQA_VisualCloze"]
text_rich_vqa = ["WebQA", "TQA", "OCR-VQA", "DocVQA"]
multi_image_vqa = ["MIT-States_StateCoherence", "MIT-States_PropertyCoherence", "VISION", "RecipeQA_ImageCoherence"]
puzzle = ["RAVEN"]
# NOTE(review): the variable is spelled "nlrv2" while the dataset is "NLVR2_Mantis";
# the spelling is also used as a runtime key elsewhere, so it is left unchanged.
nlrv2 = ["NLVR2_Mantis"]
qbench = ["QBench"]
class Eval:
    """Answer normalisation and scoring helpers for the interleave benchmark.

    Open-ended answers are scored with ROUGE-L F-measure; multiple-choice
    answers are scored by exact match after normalisation.
    """

    # Single-letter options recognised when a prediction looks like "answer: b".
    _CHOICE_LETTERS = ("a", "b", "c", "d", "e", "f", "g", "h")

    def __init__(self):
        # NOTE(review): "(?!<=\d)" is a look-ahead for the literal text "<=<digit>",
        # not a look-behind; it looks like "(?<!\d)" was intended. The pattern is kept
        # verbatim so that scores stay comparable with earlier runs.
        self.periodStrip = re.compile(r"(?!<=\d)(\.)(?!\d)")
        self.commaStrip = re.compile(r"(\d)(\,)(\d)")
        self.punct = [
            ";",
            r"/",
            "[",
            "]",
            '"',
            "{",
            "}",
            "(",
            ")",
            "=",
            "+",
            "\\",
            "_",
            "-",
            ">",
            "<",
            "@",
            "`",
            ",",
            "?",
            "!",
        ]

    def processPunctuation(self, inText):
        """Remove or blank out punctuation in `inText` and strip periods.

        A punctuation mark is deleted when it touches a space in the original
        text, or when the text contains a digit,digit comma; otherwise it is
        replaced by a space. Periods not followed by a digit are then removed.
        """
        outText = inText
        # Loop-invariant: whether the text contains a comma between two digits.
        has_digit_comma = re.search(self.commaStrip, inText) is not None
        for p in self.punct:
            if (p + " " in inText or " " + p in inText) or has_digit_comma:
                outText = outText.replace(p, "")
            else:
                outText = outText.replace(p, " ")
        # The third positional argument of Pattern.sub is `count`, not `flags`;
        # passing re.UNICODE there capped the number of stripped periods at 32.
        outText = self.periodStrip.sub("", outText)
        return outText

    def process(self, answer):
        """Return `answer` normalised for comparison.

        Newlines and tabs become spaces, punctuation is processed, surrounding
        quotes and parentheses are stripped, and the result is lower-cased.
        """
        answer = answer.replace("\n", " ")
        answer = answer.replace("\t", " ")
        answer = answer.strip()
        answer = self.processPunctuation(answer)
        answer = answer.strip('\'')
        answer = answer.strip('\"')
        answer = answer.strip(')')
        answer = answer.strip('(')
        answer = answer.strip().lower()
        return answer

    def _extract_choice(self, pred_ans):
        """Return the last single option letter found between ':' separators.

        When `pred_ans` has no ':' or no segment is a single option letter,
        `pred_ans` is returned unchanged.
        """
        if ":" in pred_ans:
            for part in (a.strip() for a in pred_ans.split(":")):
                if len(part) == 1 and part in self._CHOICE_LETTERS:
                    pred_ans = part
        return pred_ans

    def evaluate_rouge(self, preds):
        """Score open-ended predictions with ROUGE-L F-measure.

        Args:
            preds: Iterable of dicts with "sample_id", "gt_response" and
                "pred_response".

        Returns:
            A tuple `(results, eval_list)`: `results` is
            `{'Rouge-L f': mean_score}` and `eval_list` holds one
            `{'id', 'score'}` dict (both strings) per scored sample.
        """
        rouge = Rouge()
        acc = {'f': []}
        eval_list = []
        for res in preds:
            sample_id = res['sample_id']
            gt_ans = self.process(res["gt_response"])
            pred_ans = self.process(res["pred_response"])

            # Samples whose reference is empty after normalisation are not scored.
            if gt_ans == '':
                continue

            if pred_ans == '':
                s = 0
            else:
                # Only the first 512 characters of the prediction are scored.
                if len(pred_ans) > 512:
                    pred_ans = pred_ans[0: 512]
                s = rouge.get_scores(pred_ans, gt_ans)[0]['rouge-l']['f']
            acc['f'].append(s)
            eval_list.append({'id': str(sample_id), 'score': str(round(s, 3))})
        results = {'Rouge-L f': np.mean(acc['f'])}
        return results, eval_list

    def judge_multi_choice(self, sample):
        """Return 1 when the sample's prediction matches its reference, else 0.

        Expects "gt_response" and "pred_response" to be normalised already
        (see `process_sample`).
        """
        gt_ans = sample["gt_response"]
        pred_ans = self._extract_choice(sample["pred_response"])
        if pred_ans == gt_ans:
            return 1
        else:
            return 0

    def process_sample(self, sample):
        """Normalise "gt_response" and "pred_response" of `sample` in place."""
        sample["gt_response"] = self.process(sample["gt_response"])
        sample["pred_response"] = self.process(sample["pred_response"])

    def evaluate_multichoice(self, preditions):
        """Score multiple-choice predictions by exact match.

        Each sample is normalised in place and gets a 'result' key (0 or 1).

        Returns:
            A tuple `({'Accuracy': fraction_correct}, eval_list)`.
        """
        correct = 0
        eval_list = []
        for sample in preditions:
            self.process_sample(sample)
            score = self.judge_multi_choice(sample)
            sample_id = sample['sample_id']
            sample['result'] = score
            eval_list.append({'id': str(sample_id), 'score': str(score)})
            correct += score
        return {'Accuracy': correct / len(preditions)}, eval_list

    def evaluate_multi_choice_image(self, preditions):
        """Score image-choice predictions by exact match.

        Unlike `evaluate_multichoice`, the responses stored in each sample are
        left untouched; only the 'result' key (0 or 1) is added.

        Returns:
            A tuple `({'Accuracy': fraction_correct}, eval_list)`.
        """
        correct = 0
        eval_list = []
        for sample in preditions:
            gt_ans = self.process(sample["gt_response"])
            pred_ans = self._extract_choice(self.process(sample["pred_response"]))
            sample_id = sample['sample_id']

            if gt_ans == pred_ans:
                score = 1
            else:
                score = 0
            sample['result'] = score
            eval_list.append({'id': str(sample_id), 'score': str(score)})
            correct += score
        return {'Accuracy': correct / len(preditions)}, eval_list
if __name__ == "__main__":
    # Command-line entry point: score <result-dir>/result.jsonl per dataset and per
    # category, and write eval_dataset.json, eval_dataset_details.json and
    # eval_cat.json back into the same directory.
    parser = argparse.ArgumentParser()
    parser.add_argument('--result-dir', type=str, required=True)
    args = parser.parse_args()

    result_file = os.path.join(args.result_dir, "result.jsonl")
    if not os.path.exists(result_file):
        print('No prediction file found')
        raise SystemExit(0)
    with open(result_file, 'r') as f:
        # Skip blank lines (e.g. a trailing newline) instead of failing on them.
        preds_all = [json.loads(line) for line in f if line.strip()]

    # Group predictions by dataset, keeping first-seen dataset order.
    preds_all_dict = dict()
    for pred in preds_all:
        preds_all_dict.setdefault(pred["dataset"], list()).append(pred)

    image_choice_dataset_list = ["recipeqa-RecipeQA_VisualCloze", "RecipeQA_ImageCoherence", "COMICS_Panel"]
    E = Eval()

    eval_result_list = dict()
    eval_result_list_detail = dict()

    for dataset in preds_all_dict:
        preds = preds_all_dict[dataset]
        # The question type of the first prediction decides the metric for the dataset.
        question_type = preds[0]["question_type"]

        if question_type == 'open-ended':
            eval_result, eval_list = E.evaluate_rouge(preds)
        elif question_type == 'multi-choice' or dataset == 'nlrv2':
            if dataset in image_choice_dataset_list:
                eval_result, eval_list = E.evaluate_multi_choice_image(preds)
            else:
                eval_result, eval_list = E.evaluate_multichoice(preds)
        else:
            eval_result = 'Dataset not supported'
            print('Dataset not supported')
            raise SystemExit(0)

        print(dataset, end = ': ')
        print(eval_result)

        eval_result_list[dataset] = eval_result
        eval_result_list_detail[dataset] = eval_list

    os.makedirs(args.result_dir, exist_ok=True)
    with open(os.path.join(args.result_dir, 'eval_dataset.json'), 'w') as f:
        json.dump(eval_result_list, f, indent=4)

    with open(os.path.join(args.result_dir,'eval_dataset_details.json'), 'w') as f:
        json.dump(eval_result_list_detail, f, indent=4)

    eval_cat_list = dict()
    print()

    # Category name -> datasets belonging to it, in reporting order.
    category_datasets = {
        "spot_the_diff": spot_the_diff,
        "image_edit_instruct": image_edit_instruct,
        "visual_story_telling": visual_story_telling,
        "visual_cloze": visual_cloze,
        "text_rich_vqa": text_rich_vqa,
        "multi_image_vqa": multi_image_vqa,
        "puzzle": puzzle,
        "nlrv2": nlrv2,
        "qbench": qbench,
    }

    for category, members in category_datasets.items():
        # Each per-dataset result dict holds a single metric; take its value.
        scores = [
            list(eval_result_list[dataset].values())[0]
            for dataset in eval_result_list
            if dataset in members
        ]
        # Categories with no evaluated dataset are neither printed nor saved.
        if not scores:
            continue
        score = sum(scores) / len(scores)
        eval_cat_list[category] = score
        print(category, end = ': ')
        print('{:.2f}'.format(100 * score))

    with open(os.path.join(args.result_dir,'eval_cat.json'), 'w') as f:
        json.dump(eval_cat_list, f, indent=4)
\ No newline at end of file
import argparse
import torch
import os
import json
from tqdm import tqdm
import shortuuid
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
from llava.constants import IGNORE_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IMAGE_TOKEN_INDEX
from typing import Dict, Optional, Sequence, List
import transformers
import re
from PIL import Image
import math
def split_list(lst, n):
    """Split a list into at most n (roughly) equal-sized chunks.

    Args:
        lst: The sequence to split.
        n: The requested number of chunks.

    Returns:
        A list of consecutive slices of `lst`. Fewer than `n` chunks are
        returned when `lst` has fewer than `n` items; an empty `lst` yields [].
    """
    # An empty input would give chunk_size == 0, which range() rejects as a step.
    if len(lst) == 0:
        return []
    chunk_size = math.ceil(len(lst) / n)  # round up so that at most n chunks result
    return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]
def get_chunk(lst, n, k):
    """Return the k-th of the n chunks that `split_list` produces for `lst`.

    Args:
        lst: The sequence to split.
        n: The number of chunks (workers).
        k: The index of the chunk to return.

    Returns:
        The k-th chunk. When `k` is a valid worker index (`k < n`) but the list
        was too short to produce that many chunks, an empty list is returned so
        the worker simply has nothing to process.

    Raises:
        IndexError: If `k` is not smaller than `n` and no such chunk exists.
    """
    chunks = split_list(lst, n)
    # More workers than chunks: this worker's share is empty rather than an error.
    if len(chunks) <= k < n:
        return []
    return chunks[k]
def preprocess_qwen(sources, tokenizer: transformers.PreTrainedTokenizer, has_image: bool = False, max_len=2048, system_message: str = "You are a helpful assistant.") -> torch.Tensor:
    """Tokenize one conversation into Qwen chat-format input ids.

    Args:
        sources: List of turn dicts with a "from" key ("human" or "gpt") and a
            "value" key (the text, or None for a turn left open for generation).
        tokenizer: Tokenizer whose `additional_special_tokens_ids` unpacks to
            exactly two ids, used as the turn start and turn end markers.
        has_image: When True, each "<image>" in a turn is replaced by
            IMAGE_TOKEN_INDEX followed by the newline tokens.
        max_len: Not used by this function.
        system_message: Text of the system turn placed before the conversation.

    Returns:
        A torch.long tensor of shape (1, sequence_length) holding the input ids.
        Training targets are assembled as well but are not returned.
    """
    roles = {"human": "<|im_start|>user", "gpt": "<|im_start|>assistant"}
    # Unpacking fails unless the tokenizer has exactly two additional special tokens.
    im_start, im_end = tokenizer.additional_special_tokens_ids
    nl_tokens = tokenizer("\n").input_ids
    _system = tokenizer("system").input_ids + nl_tokens
    # _user and _assistant are computed but not used below.
    _user = tokenizer("user").input_ids + nl_tokens
    _assistant = tokenizer("assistant").input_ids + nl_tokens
    # Apply prompt templates
    input_ids, targets = [], []
    source = sources
    # Drop a leading turn that does not come from the human side.
    if roles[source[0]["from"]] != roles["human"]:
        source = source[1:]
    input_id, target = [], []
    # System turn: <im_start> "system" \n <system_message> <im_end> \n
    system = [im_start] + _system + tokenizer(system_message).input_ids + [im_end] + nl_tokens
    input_id += system
    # In the targets everything between the start marker and the end marker is masked.
    target += [im_start] + [IGNORE_INDEX] * (len(system) - 3) + [im_end] + nl_tokens
    assert len(input_id) == len(target)
    for j, sentence in enumerate(source):
        role = roles[sentence["from"]]
        if has_image and sentence["value"] is not None and "<image>" in sentence["value"]:
            num_image = len(re.findall(DEFAULT_IMAGE_TOKEN, sentence["value"]))
            # Tokenize the text pieces around each "<image>" separately and put
            # IMAGE_TOKEN_INDEX (plus a newline) between consecutive pieces.
            texts = sentence["value"].split('<image>')
            _input_id = tokenizer(role).input_ids + nl_tokens
            for i,text in enumerate(texts):
                _input_id += tokenizer(text).input_ids
                if i<len(texts)-1:
                    _input_id += [IMAGE_TOKEN_INDEX] + nl_tokens
            _input_id += [im_end] + nl_tokens
            # Sanity check: one image token per "<image>" placeholder.
            assert sum([i==IMAGE_TOKEN_INDEX for i in _input_id])==num_image
        else:
            if sentence["value"] is None:
                # Open turn: only the role header, so the model continues from here.
                _input_id = tokenizer(role).input_ids + nl_tokens
            else:
                _input_id = tokenizer(role).input_ids + nl_tokens + tokenizer(sentence["value"]).input_ids + [im_end] + nl_tokens
        input_id += _input_id
        if role == "<|im_start|>user":
            # User turns are fully masked in the targets.
            _target = [im_start] + [IGNORE_INDEX] * (len(_input_id) - 3) + [im_end] + nl_tokens
        elif role == "<|im_start|>assistant":
            # Assistant turns keep their content tokens; only the role header is masked.
            _target = [im_start] + [IGNORE_INDEX] * len(tokenizer(role).input_ids) + _input_id[len(tokenizer(role).input_ids) + 1 : -2] + [im_end] + nl_tokens
        else:
            raise NotImplementedError
        target += _target
    # A single conversation is processed, so the batch dimension is 1.
    input_ids.append(input_id)
    targets.append(target)
    input_ids = torch.tensor(input_ids, dtype=torch.long)
    targets = torch.tensor(targets, dtype=torch.long)
    # Only the input ids are returned; `targets` is discarded.
    return input_ids
def _generate_answer(model, tokenizer, input_ids, image_tensors, stop_str, args):
    """Generate a reply for `input_ids` and return `(output_ids, cleaned_text)`."""
    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=image_tensors,
            do_sample=True if args.temperature > 0 else False,
            temperature=args.temperature,
            top_p=args.top_p,
            num_beams=args.num_beams,
            # no_repeat_ngram_size=3,
            max_new_tokens=1024,
            use_cache=True)

    outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
    outputs = outputs.strip()
    # Remove a trailing stop string if the model emitted it.
    if outputs.endswith(stop_str):
        outputs = outputs[:-len(stop_str)]
    outputs = outputs.strip()
    return output_ids, outputs


def eval_model(args):
    """Run the model on every question and write one JSON line per conversation turn.

    Args:
        args: Parsed command-line namespace. Reads `model_path`, `model_base`,
            `question_file`, `num_chunks`, `chunk_idx`, `answers_file`,
            `image_folder`, `extra_prompt`, `temperature`, `top_p` and
            `num_beams`. `args.conv_mode` is overwritten with "qwen_1_5".

    Each output line holds the dataset name, sample id, prompt, predicted and
    reference responses, a fresh short UUID, the model name and the question type.
    """
    # Model
    disable_torch_init()
    model_path = os.path.expanduser(args.model_path)
    model_name = get_model_name_from_path(model_path)
    tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name)

    # Data
    with open(os.path.expanduser(args.question_file)) as f:
        questions = json.load(f)
    questions = get_chunk(questions, args.num_chunks, args.chunk_idx)

    answers_file = os.path.expanduser(args.answers_file)
    # A bare file name has no directory part, and os.makedirs("") raises.
    answers_dir = os.path.dirname(answers_file)
    if answers_dir:
        os.makedirs(answers_dir, exist_ok=True)

    # The context manager closes the answers file even if a sample fails.
    with open(answers_file, "w") as ans_file:
        for line in tqdm(questions):
            idx = line["sample_id"]
            question_type = line["metadata"]["question_type"]
            dataset_name = line["metadata"]["dataset"]
            conversations = line["conversations"]
            image_files = line["image"]

            # The Qwen template is always used, whatever --conv-mode was given.
            args.conv_mode = "qwen_1_5"
            conv = conv_templates[args.conv_mode].copy()
            stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2

            # All images of the sample are passed to every turn.
            image_tensors = []
            for image_file in image_files:
                image = Image.open(os.path.join(args.image_folder, image_file))
                image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values']
                image_tensors.append(image_tensor.half().cuda())

            input_ids = None
            output_ids = None
            # Turns alternate question (even index) / reference answer (odd index).
            for i in range(0, len(conversations), 2):
                qs = conversations[i]["value"]
                gt = conversations[i + 1]["value"]
                cur_prompt = args.extra_prompt + qs

                turn_ids = preprocess_qwen([conversations[i], {'from': 'gpt', 'value': None}], tokenizer, has_image=True).cuda()
                if input_ids is None:
                    input_ids = turn_ids
                else:
                    # Follow-up turn: previous context + previous model output + new
                    # question. NOTE: preprocess_qwen prepends the system turn to
                    # every question, so it is repeated here (unchanged behaviour).
                    input_ids = torch.cat((input_ids, output_ids, turn_ids), dim=1)

                output_ids, outputs = _generate_answer(model, tokenizer, input_ids, image_tensors, stop_str, args)

                ans_id = shortuuid.uuid()
                ans_file.write(json.dumps({
                    "dataset": dataset_name,
                    "sample_id": idx,
                    "prompt": cur_prompt,
                    "pred_response": outputs,
                    "gt_response": gt,
                    "shortuuid": ans_id,
                    "model_id": model_name,
                    "question_type": question_type,
                }) + "\n")
                # Flush after every turn so partial results survive an interruption.
                ans_file.flush()
if __name__ == "__main__":
    # Command-line entry point: parse the options and run the evaluation.
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
    parser.add_argument("--model-base", type=str, default=None)
    # Folder that each sample's image file names are joined onto.
    parser.add_argument("--image-folder", type=str, default="")
    # Text prepended to the question in the recorded "prompt" field.
    parser.add_argument("--extra-prompt", type=str, default="")
    # NOTE: eval_model reads this file with json.load, so it must be one JSON
    # document even though the default name ends in ".jsonl".
    parser.add_argument("--question-file", type=str, default="tables/question.jsonl")
    parser.add_argument("--answers-file", type=str, default="answer.jsonl")
    # NOTE: eval_model overwrites args.conv_mode with "qwen_1_5".
    parser.add_argument("--conv-mode", type=str, default="llava_v1")
    # The questions are split into --num-chunks parts; this run handles part --chunk-idx.
    parser.add_argument("--num-chunks", type=int, default=1)
    parser.add_argument("--chunk-idx", type=int, default=0)
    # Sampling is enabled only when the temperature is greater than 0.
    parser.add_argument("--temperature", type=float, default=0.2)
    parser.add_argument("--top_p", type=float, default=None)
    parser.add_argument("--num_beams", type=int, default=1)
    # NOTE(review): --test_size is not read by eval_model in the visible code.
    parser.add_argument("--test_size", type=int, default=10000000)
    args = parser.parse_args()
    eval_model(args)
\ No newline at end of file
from PIL import Image
from io import BytesIO
import base64
import math
import ast
import re
import torch
from transformers import StoppingCriteria
from llava.constants import IMAGE_TOKEN_INDEX
def resize_and_center_crop(image, shortest_edge_length):
    """Resize `image` so its shorter side equals `shortest_edge_length`, then center-crop a square.

    Args:
        image (PIL.Image.Image): The input image.
        shortest_edge_length (int): Target length of the shorter side, which is
            also the side length of the returned square.

    Returns:
        PIL.Image.Image: The center crop of the resized image.
    """
    # Calculate new dimensions and resize
    aspect_ratio = float(image.width) / float(image.height)
    if aspect_ratio > 1:
        new_width = int(shortest_edge_length * aspect_ratio)
        new_height = shortest_edge_length
    else:
        new_width = shortest_edge_length
        new_height = int(shortest_edge_length / aspect_ratio)
    # Image.ANTIALIAS was an alias of LANCZOS and was removed in Pillow 10.
    resized_image = image.resize((new_width, new_height), Image.LANCZOS)

    # Calculate the position and perform the center crop
    left = (new_width - shortest_edge_length) / 2
    top = (new_height - shortest_edge_length) / 2
    right = (new_width + shortest_edge_length) / 2
    bottom = (new_height + shortest_edge_length) / 2
    cropped_image = resized_image.crop((left, top, right, bottom))

    return cropped_image
def auto_pad_images(image, grid_params):
    """Resize `image` to fit the best-matching grid resolution and pad it with black.

    Args:
        image (PIL.Image.Image): The input image.
        grid_params: Non-empty sequence of side lengths; every (w, h) pair drawn
            from it is a candidate target resolution.

    Returns:
        PIL.Image.Image: An RGB image of the chosen target resolution with the
        resized input pasted in the center.
    """
    assert isinstance(image, Image.Image), "Input should be a Pillow Image"
    assert len(grid_params) > 0, "Grid parameters should not be empty"

    # Find the candidate aspect ratio closest to the input's
    input_width, input_height = image.size
    input_aspect_ratio = input_width / input_height
    candidate_resolutions = [(w / h, w, h) for w in grid_params for h in grid_params]
    closest_aspect_ratio = min(candidate_resolutions, key=lambda x: abs(input_aspect_ratio - x[0]))

    # Among candidates with that aspect ratio, pick the one whose longer side is
    # closest to the input's longer side
    candidate_resolutions = [(x[1], x[2]) for x in candidate_resolutions if abs(x[0] - closest_aspect_ratio[0]) < 1e-3]
    target_resolution = min(candidate_resolutions, key=lambda res: abs(max(input_width, input_height) / max(res) - 1))

    # Keep the input's aspect ratio: fix the longer side and derive the other
    resize_width, resize_height = target_resolution
    if input_width > input_height:
        resize_height = int(resize_width / input_aspect_ratio)
    else:
        resize_width = int(resize_height * input_aspect_ratio)
    # Image.ANTIALIAS was an alias of LANCZOS and was removed in Pillow 10.
    resized_image = image.resize((resize_width, resize_height), Image.LANCZOS)

    # Pad the resized image if necessary to match the target resolution
    pad_width = target_resolution[0] - resize_width
    pad_height = target_resolution[1] - resize_height
    padded_image = Image.new("RGB", target_resolution, color=(0, 0, 0))
    padded_image.paste(resized_image, (pad_width // 2, pad_height // 2))

    return padded_image
def extract_patches(image, patch_size, overlap_ratio):
    """Cut `image` into a centered grid of square patches.

    Args:
        image (PIL.Image.Image): The image to split.
        patch_size (int): Side length of each patch; must be positive.
        overlap_ratio (float): Fraction in [0, 1) by which neighbouring patches overlap.

    Returns:
        list: The patches in row-major order (top to bottom, left to right).
    """
    assert isinstance(image, Image.Image), "Input should be a Pillow Image"
    assert patch_size > 0, "Patch size should be greater than 0"
    assert 0 <= overlap_ratio < 1, "Overlap ratio should be between 0 and 1"

    width, height = image.size
    step = int(patch_size * (1 - overlap_ratio))

    # Number of patches that fit along each axis
    rows = (height - patch_size) // step + 1
    cols = (width - patch_size) // step + 1

    # Offsets that center the patch grid inside the image
    top0 = (height - (rows - 1) * step - patch_size) // 2
    left0 = (width - (cols - 1) * step - patch_size) // 2

    return [
        image.crop((
            left0 + col * step,
            top0 + row * step,
            left0 + col * step + patch_size,
            top0 + row * step + patch_size,
        ))
        for row in range(rows)
        for col in range(cols)
    ]
def process_highres_image_crop_split(image, data_args, processor=None):
    """Center-crop `image`, split the crop into patches and preprocess each patch.

    Args:
        image: The input image.
        data_args: Object providing `image_crop_resolution`,
            `image_split_resolution` and, when `processor` is None,
            `image_processor`.
        processor: Optional image processor; falls back to `data_args.image_processor`.

    Returns:
        torch.Tensor: The preprocessed patches stacked along a new first dimension.
    """
    crop_size = data_args.image_crop_resolution
    patch_size = data_args.image_split_resolution
    processor = data_args.image_processor if processor is None else processor

    cropped = resize_and_center_crop(image, crop_size)
    pixel_values = []
    for patch in extract_patches(cropped, patch_size=patch_size, overlap_ratio=0):
        encoded = processor.preprocess(patch, return_tensors="pt")
        pixel_values.append(encoded["pixel_values"][0])
    return torch.stack(pixel_values, dim=0)
def process_highres_image(image, processor, grid_pinpoints):
    """Build a stack of preprocessed views: the resized original plus grid patches.

    Args:
        image: The input image.
        processor: Image processor providing `image_mean`, `size["shortest_edge"]`
            and `preprocess`.
        grid_pinpoints (str): Comma-separated candidate side lengths.

    Returns:
        torch.Tensor: The preprocessed views stacked along a new first dimension,
        with the resized original first.
    """
    sizes = [int(token) for token in grid_pinpoints.split(",")]
    longest_side = max(image.size)
    large_enough = [size for size in sizes if size >= longest_side]
    select_size = min(large_enough) if len(large_enough) != 0 else max(sizes)
    # FIXME: always select the 448
    select_size = max(sizes)

    base_edge = processor.size["shortest_edge"]
    fill_colour = tuple(int(channel * 255) for channel in processor.image_mean)
    squared = expand2square(image, fill_colour)

    # FIXME: this seems to be a bug that it always resizes instead of padding
    overview = image.resize((base_edge, base_edge))
    squared = squared.resize((select_size, select_size))

    views = [overview] + extract_patches(squared, patch_size=processor.size["shortest_edge"], overlap_ratio=0)
    tensors = []
    for view in views:
        tensors.append(processor.preprocess(view, return_tensors="pt")["pixel_values"][0])
    return torch.stack(tensors, dim=0)
def select_best_resolution(original_size, possible_resolutions):
    """
    Pick the candidate resolution that preserves the most image detail.

    Each candidate is scored by the area the image covers after being scaled
    (aspect ratio kept) to fit inside it, capped at the original area. Ties are
    broken in favour of the candidate that leaves the least unused area.

    Args:
        original_size (tuple): The original size of the image in the format (width, height).
        possible_resolutions (list): Candidate resolutions in the format
            [(width1, height1), (width2, height2), ...].

    Returns:
        tuple: The best fit resolution in the format (width, height), or None
        when no candidate was selected.
    """
    src_w, src_h = original_size
    src_area = src_w * src_h

    winner = None
    best_effective = 0
    least_wasted = float("inf")

    for cand_w, cand_h in possible_resolutions:
        # Largest uniform scale at which the image still fits the candidate.
        ratio = min(cand_w / src_w, cand_h / src_h)
        fitted_area = int(src_w * ratio) * int(src_h * ratio)

        # Upscaling adds no information, so cap at the original area.
        effective = min(fitted_area, src_area)
        wasted = cand_w * cand_h - effective

        if effective < best_effective:
            continue
        if effective == best_effective and wasted >= least_wasted:
            continue

        winner = (cand_w, cand_h)
        best_effective = effective
        least_wasted = wasted

    return winner
def resize_and_pad_image(image, target_resolution):
    """
    Scale an image to fit a target resolution (aspect ratio kept) and centre it on a black canvas.

    Args:
        image (PIL.Image.Image): The input image.
        target_resolution (tuple): The target resolution (width, height) of the image.

    Returns:
        PIL.Image.Image: A new RGB image of exactly the target size with the
        resized input pasted in the centre.
    """
    src_w, src_h = image.size
    target_width, target_height = target_resolution

    width_ratio = target_width / src_w
    height_ratio = target_height / src_h

    if width_ratio < height_ratio:
        # Width is the limiting dimension: fill it and derive the height.
        fitted = (target_width, min(math.ceil(src_h * width_ratio), target_height))
    else:
        # Height is the limiting dimension: fill it and derive the width.
        fitted = (min(math.ceil(src_w * height_ratio), target_width), target_height)

    scaled = image.resize(fitted)

    # Black background; the scaled image is centred on it.
    canvas = Image.new("RGB", (target_width, target_height), (0, 0, 0))
    offset = ((target_width - fitted[0]) // 2, (target_height - fitted[1]) // 2)
    canvas.paste(scaled, offset)
    return canvas
def divide_to_patches(image, patch_size):
    """
    Cut an image into square tiles, row by row from the top-left corner.

    Args:
        image (PIL.Image.Image): The input image.
        patch_size (int): Side length of each tile.

    Returns:
        list: The cropped tiles in row-major order. Crop boxes are not clamped
        to the image bounds, so tiles on the right/bottom edge may extend past it.
    """
    width, height = image.size
    return [
        image.crop((left, top, left + patch_size, top + patch_size))
        for top in range(0, height, patch_size)
        for left in range(0, width, patch_size)
    ]
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
    """
    Calculate the shape of the image patch grid after the preprocessing for images of any resolution.

    Args:
        image_size (tuple): The size of the input image in the format (width, height).
        grid_pinpoints (str | list | tuple): The candidate resolutions. Either an
            already-parsed sequence of (width, height) pairs, the string
            representation of such a list, or a range string whose first and last
            "(AxB)" groups give the inclusive grid range in units of patches.
        patch_size (int): The size of each image patch.

    Returns:
        tuple: The shape of the image patch grid in the format (width, height).
    """
    if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints:
        assert patch_size in [224, 336, 384, 448, 512], "patch_size should be in [224, 336, 384, 448, 512]"
        # Use regex to extract the range from the input string
        matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
        range_start = tuple(map(int, matches[0]))
        range_end = tuple(map(int, matches[-1]))
        # Generate every (i, j) grid from range_start to range_end, both inclusive
        grid_pinpoints = [(i, j) for i in range(range_start[0], range_end[0] + 1) for j in range(range_start[1], range_end[1] + 1)]
        # Convert grid counts to pixel resolutions
        grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints]

    # Accept any already-parsed list/tuple (including subclasses); only a string
    # representation needs to go through literal_eval.
    if isinstance(grid_pinpoints, (list, tuple)):
        possible_resolutions = grid_pinpoints
    else:
        possible_resolutions = ast.literal_eval(grid_pinpoints)

    width, height = select_best_resolution(image_size, possible_resolutions)
    return width // patch_size, height // patch_size
def process_anyres_image(image, processor, grid_pinpoints):
    """
    Process an image with variable resolutions.

    The image is resized and padded to the best-fitting candidate resolution and
    cut into square patches; a resized copy of the whole image is prepended as
    the first entry before everything is run through ``processor.preprocess``.

    Args:
        image (PIL.Image.Image): The input image to be processed.
        processor: The image processor object. Must expose ``size``,
            ``crop_size["height"]`` and ``preprocess``.
        grid_pinpoints (str | list | tuple): The candidate resolutions. Either an
            already-parsed sequence of (width, height) pairs, the string
            representation of such a list, or a range string whose first and last
            "(AxB)" groups give the inclusive grid range in units of patches.

    Returns:
        torch.Tensor: A tensor containing the processed image patches.
    """
    # Convert grid_pinpoints from string to list
    if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints:
        # processor.size may be an indexable sequence or a dict keyed by
        # "shortest_edge"; only the lookup failures trigger the fallback.
        try:
            patch_size = processor.size[0]
        except (KeyError, TypeError, IndexError):
            patch_size = processor.size["shortest_edge"]
        assert patch_size in [224, 336, 384, 448, 512], "patch_size should be in [224, 336, 384, 448, 512]"
        # Use regex to extract the range from the input string
        matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
        range_start = tuple(map(int, matches[0]))
        range_end = tuple(map(int, matches[-1]))
        # Generate every (i, j) grid from range_start to range_end, both inclusive
        grid_pinpoints = [(i, j) for i in range(range_start[0], range_end[0] + 1) for j in range(range_start[1], range_end[1] + 1)]
        # Convert grid counts to pixel resolutions
        grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints]

    # Accept any already-parsed list/tuple (including subclasses); only a string
    # representation needs to go through literal_eval.
    if isinstance(grid_pinpoints, (list, tuple)):
        possible_resolutions = grid_pinpoints
    else:
        possible_resolutions = ast.literal_eval(grid_pinpoints)

    best_resolution = select_best_resolution(image.size, possible_resolutions)
    image_padded = resize_and_pad_image(image, best_resolution)
    patches = divide_to_patches(image_padded, processor.crop_size["height"])

    # FIXME: this seems to be a bug that it resizes instead of pad.
    # but to keep it consistent with previous, i will keep it as it is
    # TODO: uncomment below to ablate with the padding
    if isinstance(processor.size, dict):
        shortest_edge = processor.size["shortest_edge"]
    else:
        shortest_edge = min(processor.size)
    image_original_resize = image.resize((shortest_edge, shortest_edge))
    # image_padded_square = expand2square(image, tuple(int(x*255) for x in processor.image_mean))
    # image_original_resize = image_padded_square.resize((processor.size['shortest_edge'], processor.size['shortest_edge']))

    # The whole-image view always comes first, followed by the grid patches.
    image_patches = [image_original_resize] + patches
    image_patches = [processor.preprocess(image_patch, return_tensors="pt")["pixel_values"][0] for image_patch in image_patches]
    return torch.stack(image_patches, dim=0)
def load_image_from_base64(image):
    """Decode a base64-encoded payload and open the result as an image.

    Args:
        image: The base64-encoded image data, as accepted by ``base64.b64decode``.

    Returns:
        The object returned by ``Image.open`` for the decoded bytes.
    """
    raw_bytes = base64.b64decode(image)
    buffer = BytesIO(raw_bytes)
    return Image.open(buffer)
def expand2square(pil_img, background_color):
    """Pad an image to a square canvas, centring it along the shorter axis.

    Args:
        pil_img (PIL.Image.Image): The input image.
        background_color: Fill colour for the padded area, passed to ``Image.new``.

    Returns:
        PIL.Image.Image: The input itself when it is already square; otherwise a
        new image in the same mode whose side equals the longer input dimension.
    """
    width, height = pil_img.size
    if width == height:
        return pil_img

    side = max(width, height)
    canvas = Image.new(pil_img.mode, (side, side), background_color)
    if width > height:
        # Landscape: pad above and below.
        offset = (0, (width - height) // 2)
    else:
        # Portrait: pad left and right.
        offset = ((height - width) // 2, 0)
    canvas.paste(pil_img, offset)
    return canvas
def process_images(images, image_processor, model_cfg):
    """Preprocess a batch of images according to ``model_cfg.image_aspect_ratio``.

    Supported modes:
        * ``"highres"``: ``process_highres_image`` per image.
        * ``"anyres"`` or any string containing ``"anyres_max"``:
          ``process_anyres_image`` per image.
        * ``"crop_split"``: ``process_highres_image_crop_split`` per image.
        * ``"pad"``: pad to a square filled with the processor's mean colour,
          then preprocess.
        * anything else, including a missing or ``None`` attribute: the whole
          list is handed to ``image_processor.preprocess`` in one call.

    Args:
        images (list): The input images.
        image_processor: Image processor exposing ``preprocess`` (and
            ``image_mean`` for the ``"pad"`` mode).
        model_cfg: Config object; ``image_aspect_ratio`` is optional, while
            ``image_grid_pinpoints`` is read for the highres/anyres modes.

    Returns:
        For the default mode, the ``"pixel_values"`` entry returned by the
        processor. Otherwise a single stacked tensor when all per-image results
        share one shape, or a list of per-image tensors when they do not (an
        empty input yields an empty list).
    """
    image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
    # Guard the substring test: the attribute defaults to None, and
    # `"anyres_max" in None` would raise TypeError before the default branch
    # below could ever be reached.
    is_anyres = image_aspect_ratio == "anyres" or (isinstance(image_aspect_ratio, str) and "anyres_max" in image_aspect_ratio)

    new_images = []
    if image_aspect_ratio == "highres":
        for image in images:
            image = process_highres_image(image, image_processor, model_cfg.image_grid_pinpoints)
            new_images.append(image)
    elif is_anyres:
        for image in images:
            image = process_anyres_image(image, image_processor, model_cfg.image_grid_pinpoints)
            new_images.append(image)
    elif image_aspect_ratio == "crop_split":
        for image in images:
            image = process_highres_image_crop_split(image, model_cfg, image_processor)
            new_images.append(image)
    elif image_aspect_ratio == "pad":
        for image in images:
            image = expand2square(image, tuple(int(x * 255) for x in image_processor.image_mean))
            image = image_processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
            new_images.append(image)
    else:
        return image_processor.preprocess(images, return_tensors="pt")["pixel_values"]

    # Stack only when there is something to stack and every result has the same
    # shape; torch.stack rejects an empty list.
    if new_images and all(x.shape == new_images[0].shape for x in new_images):
        new_images = torch.stack(new_images, dim=0)
    return new_images
def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
    """Tokenize a prompt, replacing every "<image>" marker with ``image_token_index``.

    The prompt is split on "<image>", each piece is tokenized separately, and the
    pieces are re-joined with a single ``image_token_index`` between them. When
    the first piece starts with the tokenizer's BOS id, that id is emitted once
    at the front and the first token of every piece is dropped.

    Args:
        prompt (str): The prompt text, possibly containing "<image>" markers.
        tokenizer: Callable tokenizer whose result exposes ``input_ids``; also
            provides ``bos_token_id``.
        image_token_index (int): The id inserted for each "<image>" marker.
        return_tensors (str | None): ``None`` for a plain list, ``"pt"`` for a
            ``torch.long`` tensor.

    Returns:
        list | torch.Tensor: The token ids.

    Raises:
        ValueError: If ``return_tensors`` is neither ``None`` nor ``"pt"``.
    """
    chunks = [tokenizer(piece).input_ids for piece in prompt.split("<image>")]

    input_ids = []
    offset = 0
    if chunks and len(chunks[0]) > 0 and chunks[0][0] == tokenizer.bos_token_id:
        # Keep one BOS up front; `offset` then skips the first token of each piece.
        offset = 1
        input_ids.append(chunks[0][0])

    # The separator is padded to offset + 1 entries so that slicing it with
    # `offset`, exactly like a text piece, leaves a single image token.
    separator = [image_token_index] * (offset + 1)
    for position, chunk in enumerate(chunks):
        if position > 0:
            input_ids.extend(separator[offset:])
        input_ids.extend(chunk[offset:])

    if return_tensors is None:
        return input_ids
    if return_tensors == "pt":
        return torch.tensor(input_ids, dtype=torch.long)
    raise ValueError(f"Unsupported tensor type: {return_tensors}")
def get_model_name_from_path(model_path):
    """Derive a model name from a "/"-separated checkpoint path.

    Leading and trailing slashes are ignored. When the last path component
    starts with "checkpoint-" and has a parent component, the name is
    "<parent>_<last>"; otherwise it is the last component alone.

    Args:
        model_path (str): Path to the model directory.

    Returns:
        str: The derived model name.
    """
    parts = model_path.strip("/").split("/")
    # A bare "checkpoint-N" has no parent directory to prefix; indexing
    # parts[-2] would raise IndexError, so return the component unchanged.
    if len(parts) > 1 and parts[-1].startswith("checkpoint-"):
        return parts[-2] + "_" + parts[-1]
    return parts[-1]
class KeywordsStoppingCriteria(StoppingCriteria):
    """Stopping criterion that fires when any of the given keywords shows up in the output.

    A keyword is detected either by an exact token-id match at the end of the
    sequence or by finding its text in the decoded tail of the sequence.

    Args:
        keywords (list[str]): The stop strings.
        tokenizer: Tokenizer used to encode the keywords and decode the output.
        input_ids (torch.Tensor): The prompt ids; ``input_ids.shape[1]`` is
            recorded as the prompt length.
    """

    def __init__(self, keywords, tokenizer, input_ids):
        self.keywords = keywords
        self.keyword_ids = []
        for keyword in keywords:
            cur_keyword_ids = tokenizer(keyword).input_ids
            # Drop a leading BOS so the ids can be matched against the sequence tail.
            if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
                cur_keyword_ids = cur_keyword_ids[1:]
            self.keyword_ids.append(torch.tensor(cur_keyword_ids))
        self.tokenizer = tokenizer
        self.start_len = input_ids.shape[1]

    def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when a keyword is found at the end of ``output_ids``.

        Args:
            output_ids (torch.LongTensor): The ids produced so far; the batch
                dimension must be 1.
            scores (torch.FloatTensor): Unused.

        Returns:
            bool: Whether generation should stop.
        """
        assert output_ids.shape[0] == 1, "Only support batch size 1 (yet)"  # TODO
        # NOTE(review): when this evaluates to 0 the slice `-offset:` below
        # selects the whole sequence rather than nothing - confirm intended.
        offset = min(output_ids.shape[1] - self.start_len, 3)
        self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
        for keyword_id in self.keyword_ids:
            # torch.equal yields one bool and is False on a length mismatch;
            # an elementwise `==` inside `if` is ambiguous for multi-token
            # keywords and errors when the sequence is shorter than the keyword.
            if torch.equal(output_ids[0, -keyword_id.shape[0] :], keyword_id):
                return True
        # Fall back to a text match on the decoded tail of the sequence.
        outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
        for keyword in self.keywords:
            if keyword in outputs:
                return True
        return False
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment