Commit bc5ebf0f authored by luopl

Initial commit
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import warnings
import os.path as osp
from PIL import Image
from .base import BaseModel
from ..smp import *
from ..dataset import DATASET_TYPE
class QH_360VL(BaseModel):
INSTALL_REQ = False
INTERLEAVE = False
def __init__(self, model_path='qihoo360/360VL-70B', **kwargs):
assert model_path is not None
self.model_path = model_path
self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
self.model = AutoModelForCausalLM.from_pretrained(model_path,
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
device_map='auto',
trust_remote_code=True).eval()
vision_tower = self.model.get_vision_tower()
vision_tower.load_model()
vision_tower.to(device='cuda', dtype=torch.float16)
self.image_processor = vision_tower.image_processor
self.tokenizer.pad_token = self.tokenizer.eos_token
self.kwargs = kwargs
warnings.warn(f'Following kwargs received: {self.kwargs}, will use as generation config. ')
torch.cuda.empty_cache()
def generate(self, message, dataset=None):
prompt, image_path = self.message_to_promptimg(message, dataset=dataset)
print(prompt)
image = Image.open(image_path).convert('RGB')
terminators = [
self.tokenizer.convert_tokens_to_ids('<|eot_id|>')
]
inputs = self.model.build_conversation_input_ids(self.tokenizer,
query=prompt,
image=image,
image_processor=self.image_processor)
input_ids = inputs['input_ids'].to(device='cuda', non_blocking=True)
images = inputs['image'].to(dtype=torch.float16, device='cuda', non_blocking=True)
output_ids = self.model.generate(input_ids=input_ids,
images=images,
do_sample=False,
num_beams=1,
max_new_tokens=512,
eos_token_id=terminators,
use_cache=True)
input_token_len = input_ids.shape[1]
outputs = self.tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
response = outputs.strip()
return response
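# Minimal usage sketch for QH_360VL (illustrative only; the image path is a placeholder and
# message_to_promptimg is inherited from BaseModel):
#   model = QH_360VL()
#   message = [dict(type='image', value='demo.jpg'),
#              dict(type='text', value='Describe the image.')]
#   print(model.generate(message))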
from .model import Qwen2VLChat
from .prompt import Qwen2VLPromptMixin
from __future__ import annotations
import os
import sys
import warnings
import math
import logging
import torch
from ..base import BaseModel
from .prompt import Qwen2VLPromptMixin
from ...smp import get_rank_and_world_size, get_gpu_memory, auto_split_flag
def ensure_image_url(image: str) -> str:
prefixes = ['http://', 'https://', 'file://', 'data:image;']
if any(image.startswith(prefix) for prefix in prefixes):
return image
if os.path.exists(image):
return 'file://' + image
raise ValueError(f'Invalid image: {image}')
def ensure_video_url(video: str) -> str:
prefixes = ['http://', 'https://', 'file://', 'data:video;']
if any(video.startswith(prefix) for prefix in prefixes):
return video
if os.path.exists(video):
return 'file://' + video
raise ValueError(f'Invalid video: {video}')
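# Illustrative check of the two URL helpers above (a minimal sketch; the temporary file
# created here is only for demonstration and is not part of the original module):
if __name__ == '__main__':
    import tempfile
    # Already-prefixed inputs are returned unchanged.
    assert ensure_image_url('https://example.com/cat.png') == 'https://example.com/cat.png'
    # Existing local paths are rewritten with a file:// prefix.
    with tempfile.NamedTemporaryFile(suffix='.mp4') as tmp:
        assert ensure_video_url(tmp.name) == 'file://' + tmp.name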
def split_model():
device_map = {}
total_gpus = torch.cuda.device_count()
rank, world_size = get_rank_and_world_size()
num_gpus = total_gpus // world_size
# the extra 8 are virtual layers reserved to account for the memory of the visual encoder
num_layers = 80 + 8
num_layers_per_gpu = math.ceil(num_layers / num_gpus)
num_layers_per_gpu = [num_layers_per_gpu] * num_gpus
num_layers_per_gpu[0] -= 6
num_layers_per_gpu[-1] -= 2
layer_cnt = 0
for i, num_layer in enumerate(num_layers_per_gpu):
for j in range(num_layer):
device_map[f'model.layers.{layer_cnt}'] = rank + i * world_size
layer_cnt += 1
last_gpu = rank + (num_gpus - 1) * world_size
device_map['visual'] = rank
device_map['model.embed_tokens'] = rank
device_map['model.norm'] = last_gpu
device_map['model.rotary_emb'] = last_gpu
device_map['lm_head'] = last_gpu
return device_map
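# Worked example (illustrative, assuming 8 visible GPUs and world_size == 1):
# num_gpus = 8 and num_layers = 80 + 8 = 88, so num_layers_per_gpu starts as [11] * 8;
# GPU 0 is then trimmed to 5 layers (it also hosts 'visual' and 'model.embed_tokens') and
# the last GPU to 9 layers (it also hosts 'model.norm', 'model.rotary_emb' and 'lm_head').
# The 80 real decoder layers therefore map as
#   model.layers.0 .. model.layers.4   -> GPU 0
#   model.layers.5 .. model.layers.15  -> GPU 1, and so on up to model.layers.79 -> GPU 7.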
class Qwen2VLChat(Qwen2VLPromptMixin, BaseModel):
INSTALL_REQ = False
INTERLEAVE = True
VIDEO_LLM = True
def __init__(
self,
model_path: str,
min_pixels: int | None = None,
max_pixels: int | None = None,
max_new_tokens=2048,
top_p=0.001,
top_k=1,
temperature=0.01,
repetition_penalty=1.0,
use_custom_prompt: bool = True,
system_prompt: str | None = None,
verbose: bool = False,
):
super().__init__(use_custom_prompt=use_custom_prompt)
self.min_pixels = min_pixels
self.max_pixels = max_pixels
self.generate_kwargs = dict(
max_new_tokens=max_new_tokens,
top_p=top_p,
top_k=top_k,
temperature=temperature,
repetition_penalty=repetition_penalty,
)
self.system_prompt = system_prompt
self.verbose = verbose
self.fps = 2.0
self.nframe = 64
self.FRAME_FACTOR = 2
from transformers import Qwen2VLForConditionalGeneration, Qwen2VLProcessor
rank, world_size = get_rank_and_world_size()
assert model_path is not None
self.model_path = model_path
self.processor = Qwen2VLProcessor.from_pretrained(model_path)
gpu_mems = get_gpu_memory()
max_gpu_mem = max(gpu_mems) if gpu_mems != [] else -1
assert max_gpu_mem > 0
# If AUTO_SPLIT is set, shard the model across all visible GPUs within a single process
if auto_split_flag():
assert world_size == 1, 'Only support world_size == 1 when AUTO_SPLIT is set for non-72B Qwen2-VL'
# Use all visible GPUs to run a single model instance
self.model = Qwen2VLForConditionalGeneration.from_pretrained(
model_path, torch_dtype='auto', device_map='auto', attn_implementation='flash_attention_2'
)
elif '72b' not in self.model_path.lower():
self.model = Qwen2VLForConditionalGeneration.from_pretrained(
model_path, torch_dtype='auto', device_map='cpu', attn_implementation='flash_attention_2'
)
self.model.cuda().eval()
else:
self.model = Qwen2VLForConditionalGeneration.from_pretrained(
model_path, torch_dtype='auto', device_map=split_model(), attn_implementation='flash_attention_2'
)
self.model.eval()
torch.cuda.empty_cache()
def _prepare_content(self, inputs: list[dict[str, str]], dataset: str | None = None) -> list[dict[str, str]]:
"""
inputs list[dict[str, str]], each dict has keys: ['type', 'value']
"""
content = []
for s in inputs:
if s['type'] == 'image':
item = {'type': 'image', 'image': ensure_image_url(s['value'])}
if dataset == 'OCRBench':
item['min_pixels'] = 10 * 10 * 28 * 28
warnings.warn(f"OCRBench dataset uses custom min_pixels={item['min_pixels']}")
if self.max_pixels is not None:
item['max_pixels'] = self.max_pixels
else:
if self.min_pixels is not None:
item['min_pixels'] = self.min_pixels
if self.max_pixels is not None:
item['max_pixels'] = self.max_pixels
elif s['type'] == 'video':
item = {'type': 'video', 'video': ensure_video_url(s['value'])}
if self.fps is not None:
item['fps'] = self.fps
elif self.nframe is not None:
import cv2
video = cv2.VideoCapture(s['value'])
frame_count = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
video.release()
if frame_count < self.nframe:
new_frame_count = frame_count // self.FRAME_FACTOR * self.FRAME_FACTOR
print(f"use {new_frame_count} for {s['value']}")
item['nframes'] = new_frame_count
else:
item['nframes'] = self.nframe
elif s['type'] == 'text':
item = {'type': 'text', 'text': s['value']}
else:
raise ValueError(f"Invalid message type: {s['type']}, {s}")
content.append(item)
return content
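# Illustrative input/output of _prepare_content (values are made up; pixel limits are only
# attached when min_pixels / max_pixels are configured or the dataset overrides them):
#   inputs  = [{'type': 'image', 'value': '/data/demo.jpg'},
#              {'type': 'text',  'value': 'Describe the image.'}]
#   content = [{'type': 'image', 'image': 'file:///data/demo.jpg'},
#              {'type': 'text',  'text': 'Describe the image.'}]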
def generate_inner(self, message, dataset=None):
try:
from qwen_vl_utils import process_vision_info
except Exception as err:
logging.critical("qwen_vl_utils not found, please install it via 'pip install qwen-vl-utils'")
raise err
messages = []
if self.system_prompt is not None:
messages.append({'role': 'system', 'content': self.system_prompt})
messages.append({'role': 'user', 'content': self._prepare_content(message, dataset=dataset)})
if self.verbose:
print(f'\033[31m{messages}\033[0m')
text = self.processor.apply_chat_template([messages], tokenize=False, add_generation_prompt=True)
images, videos = process_vision_info([messages])
inputs = self.processor(text=text, images=images, videos=videos, padding=True, return_tensors='pt')
inputs = inputs.to('cuda')
generated_ids = self.model.generate(
**inputs,
**self.generate_kwargs,
)
generated_ids = [
output_ids[len(input_ids):] for input_ids, output_ids in zip(inputs.input_ids, generated_ids)
]
out = self.processor.tokenizer.batch_decode(
generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
response = out[0]
if self.verbose:
print(f'\033[32m{response}\033[0m')
return response
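# Minimal usage sketch (an illustration under assumptions: BaseModel is expected to expose a
# generate() wrapper that forwards the VLMEvalKit message format to generate_inner, and the
# checkpoint name below is only an example):
#   model = Qwen2VLChat('Qwen/Qwen2-VL-7B-Instruct', use_custom_prompt=False)
#   message = [dict(type='image', value='demo.jpg'),
#              dict(type='text', value='What is shown in this picture?')]
#   print(model.generate(message))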
from __future__ import annotations
class Qwen2VLPromptMixin:
"""
Mixin class for Qwen2VLChat to build custom prompt for different datasets.
Requires the following methods to be implemented in the subclass:
- dump_image(line, dataset: str) -> str | list[str]
Implements the following methods:
- use_custom_prompt(dataset: str) -> bool
- build_prompt(line, dataset: str) -> list[dict[str, str]]
"""
def __init__(self, *args, use_custom_prompt: bool = True, **kwargs) -> None:
super().__init__(*args, **kwargs)
self._use_custom_prompt = use_custom_prompt
def set_dump_image(self, dump_image_func):
self.dump_image_func = dump_image_func
def dump_image(self, line, dataset):
return self.dump_image_func(line)
def use_custom_prompt(self, dataset: str) -> bool:
from vlmeval.dataset import DATASET_TYPE
dataset_type = DATASET_TYPE(dataset, default=None)
if not self._use_custom_prompt:
return False
if dataset in {'MMMU_DEV_VAL', 'MMMU_TEST'}:
return True
if dataset_type == 'MCQ':
return True
if dataset_type == 'Y/N' and dataset in {'HallusionBench', 'POPE'}: # MME has its own prompt
return True
if dataset_type == 'VQA' and dataset not in {'MMVet'}: # MMVet VQA has its own prompt
return True
return False
def build_prompt(self, line, dataset: str) -> list[dict[str, str]]:
from vlmeval.dataset import DATASET_TYPE
if dataset in {'MMMU_DEV_VAL', 'MMMU_TEST'}:
return self._build_mmmu_prompt(line, dataset)
dataset_type = DATASET_TYPE(dataset, default=None)
if dataset_type == 'MCQ':
return self._build_mcq_prompt(line, dataset)
if dataset_type == 'Y/N':
return self._build_yorn_prompt(line, dataset)
if dataset_type == 'VQA':
return self._build_vqa_prompt(line, dataset)
raise ValueError(f'Unsupported dataset: {dataset}')
def _build_mmmu_prompt(self, line, dataset: str) -> list[dict[str, str]]:
"""change the prompt for MMMU dataset: keep all images at beginning."""
import string
import pandas as pd
tgt_path = self.dump_image(line, dataset)
question = line['question']
options = {cand: line[cand] for cand in string.ascii_uppercase if cand in line and not pd.isna(line[cand])}
options_prompt = 'Options:\n'
for key, item in options.items():
options_prompt += f'{key}. {item}\n'
hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None
prompt = ''
if hint is not None:
prompt += f'Hint: {hint}\n'
prompt += f'Question: {question}\n'
if len(options):
prompt += options_prompt
prompt += 'Please select the correct answer from the options above. \n'
prompt = prompt.rstrip()
msgs = []
if isinstance(tgt_path, list):
msgs.extend([dict(type='image', value=p) for p in tgt_path])
else:
msgs = [dict(type='image', value=tgt_path)]
msgs.append(dict(type='text', value=prompt))
return msgs
def _build_mcq_prompt(self, line, dataset: str) -> list[dict[str, str]]:
"""change the prompt for MCQ dataset: use chinese prompt if the question contains chinese characters."""
MCQ_CN_PROMPT = '请直接回答选项字母。'
MCQ_EN_PROMPT = 'Please select the correct answer from the options above.'
import string
import pandas as pd
def cn_string(s):
import re
if re.search('[\u4e00-\u9fff]', s):
return True
return False
tgt_path = self.dump_image(line, dataset)
question = line['question']
options = {cand: line[cand] for cand in string.ascii_uppercase if cand in line and not pd.isna(line[cand])}
options_prompt = 'Options:\n'
for key, item in options.items():
options_prompt += f'{key}. {item}\n'
hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None
prompt = ''
if hint is not None:
prompt += f'Hint: {hint}\n'
prompt += f'Question: {question}\n'
if len(options):
prompt += options_prompt
prompt += MCQ_CN_PROMPT if cn_string(prompt) else MCQ_EN_PROMPT
prompt = prompt.rstrip()
msgs = []
if isinstance(tgt_path, list):
msgs.extend([dict(type='image', value=p) for p in tgt_path])
else:
msgs = [dict(type='image', value=tgt_path)]
msgs.append(dict(type='text', value=prompt))
return msgs
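# Illustrative text prompt produced by _build_mcq_prompt for an English question
# (all values are made up):
#   Hint: Read the diagram carefully.
#   Question: Which shape has the largest area?
#   Options:
#   A. circle
#   B. square
#   Please select the correct answer from the options above.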
def _build_yorn_prompt(self, line, dataset: str) -> list[dict[str, str]]:
"""change the prompt for YORN dataset:"""
YORN_PROMPT = ' Please answer yes or no.'
tgt_path = self.dump_image(line, dataset)
question = line['question']
msgs = []
if isinstance(tgt_path, list):
msgs.extend([dict(type='image', value=p) for p in tgt_path])
else:
msgs = [dict(type='image', value=tgt_path)]
msgs.append(dict(type='text', value=question))
assert msgs[-1]['type'] == 'text'
msgs[-1]['value'] += YORN_PROMPT
return msgs
def _build_vqa_prompt(self, line, dataset: str) -> list[dict[str, str]]:
"""change the prompt for VQA dataset:"""
VQA_PROMPT = '\nPlease try to answer the question with short words or phrases if possible.'
tgt_path = self.dump_image(line, dataset)
question = line['question']
msgs = []
if isinstance(tgt_path, list):
msgs.extend([dict(type='image', value=p) for p in tgt_path])
else:
msgs = [dict(type='image', value=tgt_path)]
msgs.append(dict(type='text', value=question))
assert msgs[-1]['type'] == 'text'
msgs[-1]['value'] += VQA_PROMPT
return msgs
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import warnings
import copy as cp
from .base import BaseModel
from ..smp import isimg, listinstr
from ..dataset import DATASET_TYPE
class QwenVL(BaseModel):
INSTALL_REQ = False
INTERLEAVE = True
def __init__(self, model_path='Qwen/Qwen-VL', **kwargs):
assert model_path is not None
self.model_path = model_path
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
tokenizer.padding_side = 'left'
tokenizer.pad_token_id = tokenizer.eod_id
self.tokenizer = tokenizer
self.model = AutoModelForCausalLM.from_pretrained(model_path, device_map='cuda', trust_remote_code=True).eval()
default_kwargs = dict(
do_sample=False,
num_beams=1,
max_new_tokens=512,
min_new_tokens=1,
num_return_sequences=1,
use_cache=True,
output_hidden_states=True,
pad_token_id=tokenizer.eod_id,
eos_token_id=tokenizer.eod_id)
default_kwargs.update(kwargs)
self.kwargs = default_kwargs
warnings.warn(f'Following kwargs received: {self.kwargs}, will use as generation config. ')
torch.cuda.empty_cache()
def adjust_kwargs(self, dataset):
kwargs = cp.deepcopy(self.kwargs)
if DATASET_TYPE(dataset) in ['MCQ', 'Y/N']:
kwargs['max_new_tokens'] = 32
elif DATASET_TYPE(dataset) == 'Caption' and 'COCO' in dataset:
kwargs['max_new_tokens'] = 32
elif DATASET_TYPE(dataset) == 'VQA':
if listinstr(['OCRVQA', 'ChartQA', 'DocVQA'], dataset):
kwargs['max_new_tokens'] = 100
elif listinstr(['TextVQA'], dataset):
kwargs['max_new_tokens'] = 10
return kwargs
def generate_inner(self, message, dataset=None):
if dataset is not None:
kwargs = self.adjust_kwargs(dataset)
else:
kwargs = self.kwargs
prompt = ''
for s in message:
if s['type'] == 'image':
prompt += f'<img>{s["value"]}</img>'
elif s['type'] == 'text':
prompt += s['value']
if dataset is not None and DATASET_TYPE(dataset) == 'VQA':
prompt += ' Answer:'
encoded = self.tokenizer([prompt], return_tensors='pt', padding='longest')
input_ids = encoded.input_ids.to('cuda')
attention_mask = encoded.attention_mask.to('cuda')
pred = self.model.generate(
input_ids=input_ids,
attention_mask=attention_mask,
**kwargs)
answer = self.tokenizer.decode(pred[0][input_ids.size(1):].cpu(), skip_special_tokens=True).strip()
return answer
class QwenVLChat(BaseModel):
INSTALL_REQ = False
INTERLEAVE = True
def __init__(self, model_path='Qwen/Qwen-VL-Chat', **kwargs):
assert model_path is not None
self.model_path = model_path
self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
self.model = AutoModelForCausalLM.from_pretrained(model_path, device_map='cuda', trust_remote_code=True).eval()
torch.cuda.empty_cache()
self.kwargs = kwargs
warnings.warn(f'Following kwargs received: {self.kwargs}, will use as generation config. ')
def build_history(self, message):
def concat_tilist(tilist):
image_cnt = 1
prompt = ''
for item in tilist:
if item['type'] == 'text':
prompt += item['value']
elif item['type'] == 'image':
prompt += f"Picture {image_cnt}: <img>{item['value']}</img>\n"
image_cnt += 1
return prompt
assert len(message) % 2 == 0
hist = []
for i in range(len(message) // 2):
m1, m2 = message[2 * i], message[2 * i + 1]
assert m1['role'] == 'user' and m2['role'] == 'assistant'
hist.append((concat_tilist(m1['content']), concat_tilist(m2['content'])))
return hist
def generate_inner(self, message, dataset=None):
vl_list = [{'image': s['value']} if s['type'] == 'image' else {'text': s['value']} for s in message]
query = self.tokenizer.from_list_format(vl_list)
response, _ = self.model.chat(self.tokenizer, query=query, history=None, **self.kwargs)
return response
def chat_inner(self, message, dataset=None):
assert len(message) % 2 == 1 and message[-1]['role'] == 'user'
history = self.build_history(message[:-1])
vl_list = [
{'image': s['value']} if s['type'] == 'image' else {'text': s['value']}
for s in message[-1]['content']
]
query = self.tokenizer.from_list_format(vl_list)
response, _ = self.model.chat(self.tokenizer, query=query, history=history, **self.kwargs)
return response
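# Minimal usage sketch for QwenVLChat (illustrative only; the image path is a placeholder):
#   model = QwenVLChat('Qwen/Qwen-VL-Chat')
#   message = [dict(type='image', value='demo.jpg'),
#              dict(type='text', value='Describe this picture.')]
#   print(model.generate_inner(message))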
import sys
import torch
import os.path as osp
import os
import warnings
from .base import BaseModel
from ..dataset import DATASET_TYPE
from ..smp import *
from PIL import Image
'''
Please follow the instructions to download ckpt.
https://github.com/RBDash-Team/RBDash?tab=readme-ov-file#pretrained-weights
'''
class RBDash(BaseModel):
INSTALL_REQ = True
INTERLEAVE = False
def __init__(self, model_path, root=None, conv_mode='qwen', **kwargs):
from huggingface_hub import snapshot_download
if root is None:
raise ValueError('Please set `root` to RBDash code directory, \
which is cloned from here: "https://github.com/RBDash-Team/RBDash?tab=readme-ov-file" ')
warnings.warn('Please follow the instructions of RBDash to put the ckpt file in the right place, \
which can be found at https://github.com/RBDash-Team/RBDash?tab=readme-ov-file#structure')
assert model_path == 'RBDash-Team/RBDash-v1.5', 'We only support RBDash-v1.5 for now'
sys.path.append(root)
try:
from rbdash.model.builder import load_pretrained_model
from rbdash.mm_utils import get_model_name_from_path
except Exception as err:
logging.critical(
'Please first install RBdash and set the root path to use RBdash, '
'which is cloned from here: "https://github.com/RBDash-Team/RBDash?tab=readme-ov-file" '
)
raise err
VLMEvalKit_path = os.getcwd()
os.chdir(root)
warnings.warn('Please set `root` to RBdash code directory, \
which is cloned from here: "https://github.com/RBDash-Team/RBDash?tab=readme-ov-file" ')
try:
model_name = get_model_name_from_path(model_path)
except Exception as err:
logging.critical(
'Please follow the instructions of RBdash to put the ckpt file in the right place, '
'which can be found at https://github.com/RBDash-Team/RBDash?tab=readme-ov-file#structure'
)
raise err
download_model_path = snapshot_download(model_path)
internvit_local_dir = './model_zoo/OpenGVLab/InternViT-6B-448px-V1-5'
os.makedirs(internvit_local_dir, exist_ok=True)
snapshot_download('OpenGVLab/InternViT-6B-448px-V1-5', local_dir=internvit_local_dir)
convnext_local_dir = './model_zoo/OpenAI/openclip-convnext-large-d-320-laion2B-s29B-b131K-ft-soup'
os.makedirs(convnext_local_dir, exist_ok=True)
snapshot_download('laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft-soup', local_dir=convnext_local_dir)
preprocessor_url = 'https://huggingface.co/openai/clip-vit-large-patch14-336/resolve/main/preprocessor_config.json'
download_file_path = osp.join(convnext_local_dir, 'preprocessor_config.json')
if not osp.exists(download_file_path):
print(f'download preprocessor to {download_file_path}')
download_file(preprocessor_url, download_file_path)
tokenizer, model, image_processor, image_processor_aux, context_len = load_pretrained_model(
download_model_path, None, model_name, device_map='auto'
)
os.chdir(VLMEvalKit_path)
self.model = model
self.tokenizer = tokenizer
self.image_processor = image_processor
self.image_processor_aux = image_processor_aux
self.conv_mode = conv_mode
if tokenizer.unk_token is None:
tokenizer.unk_token = '<|endoftext|>'
tokenizer.pad_token = tokenizer.unk_token
kwargs_default = dict(temperature=float(0.2), num_beams=1, top_p=None, max_new_tokens=128, use_cache=True)
kwargs_default.update(kwargs)
self.kwargs = kwargs_default
def generate_inner(self, message, dataset=None):
try:
from rbdash.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, \
DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from rbdash.conversation import conv_templates
from rbdash.mm_utils import tokenizer_image_token, process_images
except Exception as err:
logging.critical(
'Please first install RBdash and set the root path to use RBdash, '
'which is cloned from here: "https://github.com/RBDash-Team/RBDash?tab=readme-ov-file" '
)
raise err
prompt, image = self.message_to_promptimg(message, dataset=dataset)
image = Image.open(image).convert('RGB')
if self.model.config.mm_use_im_start_end:
prompt = (
DEFAULT_IM_START_TOKEN
+ DEFAULT_IMAGE_TOKEN
+ DEFAULT_IM_END_TOKEN
+ '\n'
+ prompt
)
else:
prompt = DEFAULT_IMAGE_TOKEN + '\n' + prompt
conv = conv_templates[self.conv_mode].copy()
conv.append_message(conv.roles[0], prompt)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')
input_ids = input_ids.unsqueeze(0).cuda()
if hasattr(self.model.config, 'image_size_aux'):
if not hasattr(self.image_processor, 'image_size_raw'):
self.image_processor.image_size_raw = self.image_processor.crop_size.copy()
self.image_processor.crop_size['height'] = self.model.config.image_size_aux
self.image_processor.crop_size['width'] = self.model.config.image_size_aux
self.image_processor.size['shortest_edge'] = self.model.config.image_size_aux
self.image_processor_aux.crop_size['height'] = self.model.config.image_size_aux
self.image_processor_aux.crop_size['width'] = self.model.config.image_size_aux
self.image_processor_aux.size[
'shortest_edge'
] = self.model.config.image_size_aux
image_tensor = process_images([image], self.image_processor, self.model.config)[0]
image_grid = getattr(self.model.config, 'image_grid', 1)
if hasattr(self.model.config, 'image_size_aux'):
raw_shape = [
self.image_processor.image_size_raw['height'] * image_grid,
self.image_processor.image_size_raw['width'] * image_grid
]
if self.image_processor is not self.image_processor_aux:
image_tensor_aux = process_images([image], self.image_processor_aux, self.model.config)[
0
]
else:
image_tensor_aux = image_tensor
image_tensor = torch.nn.functional.interpolate(
image_tensor[None],
size=raw_shape,
mode='bilinear',
align_corners=False
)[0]
else:
image_tensor_aux = []
if image_grid >= 2:
raw_image = image_tensor.reshape(
3, image_grid, self.image_processor.image_size_raw['height'],
image_grid, self.image_processor.image_size_raw['width']
)
raw_image = raw_image.permute(1, 3, 0, 2, 4)
raw_image = raw_image.reshape(
-1, 3, self.image_processor.image_size_raw['height'], self.image_processor.image_size_raw['width']
)
if getattr(self.model.config, 'image_global', False):
global_image = image_tensor
if len(global_image.shape) == 3:
global_image = global_image[None]
global_image = torch.nn.functional.interpolate(
global_image,
size=[
self.image_processor.image_size_raw['height'],
self.image_processor.image_size_raw['width']
],
mode='bilinear',
align_corners=False
)
raw_image = torch.cat([raw_image, global_image], dim=0)
image_tensor = raw_image.contiguous()
images = image_tensor[None].to(dtype=self.model.dtype, device='cuda', non_blocking=True)
if len(image_tensor_aux) > 0:
images_aux = image_tensor_aux[None].to(dtype=self.model.dtype, device='cuda', non_blocking=True)
else:
images_aux = None
with torch.inference_mode():
output_ids = self.model.generate(
input_ids,
max_new_tokens=512,
images=images,
images_aux=images_aux,
do_sample=True if self.kwargs['temperature'] > 0 else False,
temperature=self.kwargs['temperature'],
top_p=self.kwargs['top_p'],
num_beams=self.kwargs['num_beams']
)
outputs = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
return outputs
def use_custom_prompt(self, dataset):
assert dataset is not None
if listinstr(['MMDU', 'MME-RealWorld', 'MME-RealWorld-CN'], dataset):
# For multi-turn datasets we do not use a custom prompt
return False
if 'mme' in dataset.lower():
return True
elif 'hallusionbench' in dataset.lower():
return True
elif 'mmmu' in dataset.lower():
return True
elif 'mmbench' in dataset.lower():
return True
return False
def build_mme(self, line):
question = line['question']
prompt = question + ' Answer the question using a single word or phrase.'
return prompt
def build_hallusionbench(self, line):
question = line['question']
prompt = question + '\nAnswer the question using a single word or phrase.'
return prompt
def build_mmbench(self, line):
question = line['question']
options = {
cand: line[cand]
for cand in string.ascii_uppercase
if cand in line and not pd.isna(line[cand])
}
options_prompt = 'Options:\n'
for key, item in options.items():
options_prompt += f'{key}. {item}\n'
hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None
prompt = ''
if hint is not None:
prompt += f'Hint: {hint}\n'
prompt += f'Question: {question}\n'
if len(options):
prompt += options_prompt
prompt += "Answer with the option's letter from the given choices directly."
else:
prompt += 'Answer the question using a single word or phrase.'
return prompt
def build_mmmu(self, line):
question = line['question']
options = {
cand: line[cand]
for cand in string.ascii_uppercase
if cand in line and not pd.isna(line[cand])
}
options_prompt = 'Options:\n'
for key, item in options.items():
options_prompt += f'({key}) {item}\n'
hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None
prompt = ''
if hint is not None:
prompt += f'Hint: {hint}\n'
prompt += f'Question: {question}\n'
if len(options):
prompt += options_prompt
prompt += "Answer with the option's letter from the given choices directly."
else:
prompt += 'Answer the question using a single word or phrase.'
return prompt
def build_prompt(self, line, dataset=None):
assert dataset is None or isinstance(dataset, str)
assert self.use_custom_prompt(dataset)
tgt_path = self.dump_image(line, dataset)
if 'mme' in dataset.lower():
prompt = self.build_mme(line)
elif 'hallusionbench' in dataset.lower():
prompt = self.build_hallusionbench(line)
elif 'mmmu' in dataset.lower():
prompt = self.build_mmmu(line)
elif 'mmbench' in dataset.lower():
prompt = self.build_mmbench(line)
ret = [dict(type='text', value=prompt)]
ret.extend([dict(type='image', value=s) for s in tgt_path])
return ret
import math
import pandas as pd
import random
import re
import string
import torch
import torch.distributed as dist
import torchvision.transforms as T
import transformers
import warnings
from PIL import Image
from torchvision.transforms.functional import InterpolationMode
from transformers import AutoTokenizer, AutoConfig, AutoModel, CLIPImageProcessor
from .base import BaseModel
from ..dataset import DATASET_TYPE, DATASET_MODALITY
from ..smp import *
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
def build_transform(input_size):
MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
transform = T.Compose([
T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
T.ToTensor(),
T.Normalize(mean=MEAN, std=STD)
])
return transform
def process_response(response, dataset_name):
if dataset_name is None:
return response
if listinstr(['ChartQA', 'OCRVQA'], dataset_name):
if len(response) >= 1 and response[-1] == '.':
response = response[:-1]
return response
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
best_ratio_diff = float('inf')
best_ratio = (1, 1)
area = width * height
for ratio in target_ratios:
target_aspect_ratio = ratio[0] / ratio[1]
ratio_diff = abs(aspect_ratio - target_aspect_ratio)
if ratio_diff < best_ratio_diff:
best_ratio_diff = ratio_diff
best_ratio = ratio
elif ratio_diff == best_ratio_diff:
if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
best_ratio = ratio
return best_ratio
def dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False):
orig_width, orig_height = image.size
aspect_ratio = orig_width / orig_height
# calculate the existing image aspect ratio
target_ratios = set(
(i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
i * j <= max_num and i * j >= min_num)
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
# find the closest aspect ratio to the target
target_aspect_ratio = find_closest_aspect_ratio(
aspect_ratio, target_ratios, orig_width, orig_height, image_size)
# calculate the target width and height
target_width = image_size * target_aspect_ratio[0]
target_height = image_size * target_aspect_ratio[1]
blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
# resize the image
resized_img = image.resize((target_width, target_height))
processed_images = []
for i in range(blocks):
box = (
(i % (target_width // image_size)) * image_size,
(i // (target_width // image_size)) * image_size,
((i % (target_width // image_size)) + 1) * image_size,
((i // (target_width // image_size)) + 1) * image_size
)
# split the image
split_img = resized_img.crop(box)
processed_images.append(split_img)
assert len(processed_images) == blocks
if use_thumbnail and len(processed_images) != 1:
thumbnail_img = image.resize((image_size, image_size))
processed_images.append(thumbnail_img)
return processed_images
def load_image(image_file, input_size=448, max_num=6, upscale=False):
image = Image.open(image_file).convert('RGB')
if upscale:
image = image.resize((image.width * 2, image.height * 2), Image.BILINEAR)
transform = build_transform(input_size=input_size)
images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
pixel_values = [transform(image) for image in images]
pixel_values = torch.stack(pixel_values)
return pixel_values
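# Worked example (illustrative): a 1000 x 500 image with image_size=448 and max_num=6 has
# aspect_ratio = 2.0, so find_closest_aspect_ratio picks the (2, 1) grid. The image is
# resized to 896 x 448 and split into two 448 x 448 tiles; with use_thumbnail=True a third
# 448 x 448 thumbnail of the whole image is appended. load_image then stacks the transformed
# crops into a pixel_values tensor of shape [3, 3, 448, 448].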
def get_local_rank_and_local_world_size():
if not dist.is_available():
return 0, 1
if not dist.is_initialized():
return 0, 1
if 'SLURM_LOCALID' in os.environ:
local_rank = int(os.environ['SLURM_LOCALID'])
local_world_size = int(os.environ['SLURM_NTASKS_PER_NODE'])
return local_rank, local_world_size
if 'LOCAL_RANK' in os.environ and 'LOCAL_WORLD_SIZE' in os.environ:
return int(os.environ['LOCAL_RANK']), int(os.environ['LOCAL_WORLD_SIZE'])
raise NotImplementedError(
"Fail to get local_rank and local_world_size! "
"Please ensure that you set the environment variable "
"`LOCAL_RANK` and `LOCAL_WORLD_SIZE`"
)
def build_multi_choice_prompt(line, dataset=None):
question = line['question']
hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None
if hint is not None:
question = hint + '\n' + question
options = {
cand: line[cand]
for cand in string.ascii_uppercase
if cand in line and not pd.isna(line[cand])
}
for key, item in options.items():
question += f'\n{key}. {item}'
prompt = question
if len(options):
prompt += '\n请直接回答选项字母。' if cn_string(
prompt) else "\nAnswer with the option's letter from the given choices directly."
else:
prompt += '\n请直接回答问题。' if cn_string(prompt) else '\nAnswer the question directly.'
return prompt
def build_video_prompt(prompt, dataset=None, max_frames=64):
for start in range(0, max_frames, 8):
images_to_remove = ''.join([f'<Image-{i}>' for i in range(start + 1, start + 9)])
prompt = prompt.replace(images_to_remove, '')
for i in range(max_frames):
prompt = prompt.replace(f'Image-{i + 1}', f'Frame-{i + 1}')
if listinstr(['MMBench-Video'], dataset):
prompt = prompt.replace('\nAnswer:', '')
elif listinstr(['Video-MME'], dataset):
prompt = prompt.replace('\nAnswer:', '')
prompt += "\nAnswer with the option's letter from the given choices directly."
elif listinstr(['MVBench'], dataset):
prompt = prompt.replace('Best option:(', '')
return prompt
def reorganize_prompt(message, image_num, dataset=None):
if dataset is not None and listinstr(['MUIRBench'], dataset):
prompt = '\n'.join([x['value'] for x in message if x['type'] == 'text'])
images_to_remove = ' '.join(['<image>'] * image_num)
prompt = prompt.replace(images_to_remove, '')
for i in range(image_num):
prompt = prompt.replace('<image>', f'<Image-{i + 1}>', 1)
prompt = ''.join([f'Image-{i + 1}: <image>\n' for i in range(image_num)]) + prompt
elif image_num == 1:
prompt = '<image>\n' + '\n'.join([x['value'] for x in message if x['type'] == 'text'])
else:
prompt, image_idx = '', 1
for x in message:
if x['type'] == 'text':
prompt += x['value']
elif x['type'] == 'image':
prompt += f'<Image-{image_idx}>'
image_idx += 1
prompt = ''.join([f'Image-{i + 1}: <image>\n' for i in range(image_num)]) + prompt
images_to_remove = ''.join([f'<Image-{i + 1}>' for i in range(image_num)])
prompt = prompt.replace(images_to_remove, '')
return prompt
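# Illustrative behaviour of reorganize_prompt (values are made up): for a message containing
# two images followed by the text 'Compare the two images.', the inline <Image-1><Image-2>
# placeholders are stripped and the returned prompt becomes
#   'Image-1: <image>\nImage-2: <image>\nCompare the two images.'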
class SailVL(BaseModel):
INSTALL_REQ = False
INTERLEAVE = True
def __init__(self,
model_path='BytedanceDouyinContent/SAIL-VL-2B',
load_in_8bit=False,
**kwargs):
assert model_path is not None
assert version_cmp(transformers.__version__, '4.36.2', 'ge')
self.model_path = model_path
self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, use_fast=False)
# Regular expression to match the pattern 'Image' followed by a number, e.g. Image1
self.pattern = r'Image(\d+)'
# Replacement pattern to insert a hyphen between 'Image' and the number, e.g. Image-1
self.replacement = r'Image-\1'
# Convert InternVL2 response to dataset format
# e.g. Image1 -> Image-1
# Regular expression to match the pattern 'Image-' followed by a number
self.reverse_pattern = r'Image-(\d+)'
# Replacement pattern to remove the hyphen (Image-1 -> Image1)
self.reverse_replacement = r'Image\1'
self.model = AutoModel.from_pretrained(
model_path,
torch_dtype=torch.bfloat16,
load_in_8bit=load_in_8bit,
trust_remote_code=True,
low_cpu_mem_usage=True).eval().cuda()
self.device = 'cuda'
self.image_size = self.model.config.vision_config.image_size
kwargs_default = dict(do_sample=False, max_new_tokens=4096, top_p=None)
kwargs_default.update(kwargs)
self.kwargs = kwargs_default
warnings.warn(f'Following kwargs received: {self.kwargs}, will use as generation config. ')
def use_custom_prompt(self, dataset):
assert dataset is not None
if listinstr(['MMDU', 'MME-RealWorld', 'MME-RealWorld-CN'], dataset):
# For multi-turn datasets we do not use a custom prompt
return False
if DATASET_MODALITY(dataset) == 'VIDEO':
# For video benchmarks we do not use a custom prompt here
return False
else:
return True
def build_prompt(self, line, dataset=None):
assert self.use_custom_prompt(dataset)
assert dataset is None or isinstance(dataset, str)
tgt_path = self.dump_image(line, dataset)
if dataset is not None and DATASET_TYPE(dataset) == 'Y/N':
question = line['question']
if listinstr(['MME'], dataset):
prompt = question + ' Answer the question using a single word or phrase.'
elif listinstr(['HallusionBench', 'AMBER'], dataset):
prompt = question + ' Please answer yes or no. Answer the question using a single word or phrase.'
else:
prompt = question
elif dataset is not None and DATASET_TYPE(dataset) == 'MCQ':
prompt = build_multi_choice_prompt(line, dataset)
elif dataset is not None and DATASET_TYPE(dataset) == 'VQA':
question = line['question']
if listinstr(['LLaVABench', 'WildVision'], dataset):
prompt = question + '\nAnswer this question in detail.'
elif listinstr(['OCRVQA', 'TextVQA', 'ChartQA', 'DocVQA', 'InfoVQA', 'OCRBench',
'DUDE', 'SLIDEVQA', 'GQA', 'MMLongBench_DOC'], dataset):
prompt = question + '\nAnswer the question using a single word or phrase.'
elif listinstr(['MathVista', 'MathVision', 'VCR', 'MTVQA', 'MMVet', 'MathVerse',
'MMDU', 'CRPE', 'MIA-Bench', 'MM-Math', 'DynaMath', 'QSpatial'], dataset):
prompt = question
else:
prompt = question + '\nAnswer the question using a single word or phrase.'
else:
# VQA_ex_prompt: OlympiadBench, VizWiz
prompt = line['question']
message = [dict(type='text', value=prompt)]
message.extend([dict(type='image', value=s) for s in tgt_path])
return message
def set_max_num(self, dataset):
# The total limit on the number of images processed, set to avoid Out-of-Memory issues.
self.total_max_num = 64
if dataset is None:
self.max_num = 6
return None
res_12_datasets = ['ChartQA_TEST', 'MMMU_DEV_VAL', 'MMMU_TEST', 'MME-RealWorld',
'VCR_EN', 'VCR_ZH', 'OCRVQA']
res_18_datasets = ['DocVQA_VAL', 'DocVQA_TEST', 'DUDE', 'MMLongBench_DOC', 'SLIDEVQA']
res_24_datasets = ['InfoVQA_VAL', 'InfoVQA_TEST', 'OCRBench', 'HRBench4K', 'HRBench8K']
if DATASET_MODALITY(dataset) == 'VIDEO':
self.max_num = 1
elif listinstr(res_12_datasets, dataset):
self.max_num = 12
elif listinstr(res_18_datasets, dataset):
self.max_num = 18
elif listinstr(res_24_datasets, dataset):
self.max_num = 24
else:
self.max_num = 6
def generate_inner(self, message, dataset=None):
self.set_max_num(dataset)
image_num = len([x for x in message if x['type'] == 'image'])
max_num = max(1, min(self.max_num, self.total_max_num // image_num))
prompt = reorganize_prompt(message, image_num, dataset=dataset)
if dataset is not None and DATASET_MODALITY(dataset) == 'VIDEO':
prompt = build_video_prompt(prompt, dataset)
if image_num > 1:
image_path = [x['value'] for x in message if x['type'] == 'image']
num_patches_list, pixel_values_list = [], []
for image_idx, file_name in enumerate(image_path):
upscale_flag = image_idx == 0 and dataset is not None and listinstr(['MMMU'], dataset)
curr_pixel_values = load_image(
file_name, max_num=max_num, upscale=upscale_flag).to(self.device).to(torch.bfloat16)
num_patches_list.append(curr_pixel_values.size(0))
pixel_values_list.append(curr_pixel_values)
pixel_values = torch.cat(pixel_values_list, dim=0)
elif image_num == 1:
image_path = [x['value'] for x in message if x['type'] == 'image'][0]
upscale_flag = dataset is not None and listinstr(['MMMU'], dataset)
pixel_values = load_image(
image_path, max_num=max_num, upscale=upscale_flag).to(self.device).to(torch.bfloat16)
num_patches_list = [pixel_values.size(0)]
else:
pixel_values = None
num_patches_list = []
with torch.no_grad():
response = self.model.chat(
self.tokenizer,
pixel_values=pixel_values,
num_patches_list=num_patches_list,
question=prompt,
generation_config=self.kwargs,
verbose=True
)
response = process_response(response, dataset_name=dataset)
return response
def build_history(self, message):
# Global Variables
image_path = []
image_cnt = 0
def concat_tilist(tilist):
nonlocal image_cnt # Declare image_cnt as nonlocal to modify it
prompt = ''
for item in tilist:
# Substitute the pattern in the text
if item['type'] == 'text':
prompt += re.sub(self.pattern, self.replacement, item['value'])
elif item['type'] == 'image':
image_cnt += 1
prompt += '<image>\n'
image_path.append(item['value'])
return prompt
# Only previous messages
assert len(message) % 2 == 0
history = []
for i in range(len(message) // 2):
m1, m2 = message[2 * i], message[2 * i + 1]
assert m1['role'] == 'user' and m2['role'] == 'assistant'
history.append((concat_tilist(m1['content']), concat_tilist(m2['content'])))
return history, image_path, image_cnt
def chat_inner(self, message, dataset=None):
self.set_max_num(dataset)
kwargs_default = dict(do_sample=False, max_new_tokens=512, top_p=None, num_beams=1)
self.kwargs = kwargs_default
if len(message) > 1:
history, image_path, image_cnt = self.build_history(message[:-1])
else:
history, image_path, image_cnt = None, [], 1
current_msg = message[-1]
question = ''
# If message is just text in the conversation
if len(current_msg['content']) == 1 and current_msg['content'][0]['type'] == 'text':
question = current_msg['content'][0]['value']
question = re.sub(self.pattern, self.replacement, question) # Fix pattern as per InternVL
else:
for msg in current_msg['content']:
if msg['type'] == 'text':
question += re.sub(self.pattern, self.replacement, msg['value'])
elif msg['type'] == 'image':
image_cnt += 1
question += '<image>\n'
image_path.append(msg['value'])
if image_cnt > 1:
num_patches_list = []
pixel_values_list = []
for image_idx, file_name in enumerate(image_path):
upscale_flag = image_idx == 0 and dataset is not None and listinstr(['MMMU_DEV_VAL'], dataset)
curr_pixel_values = load_image(
file_name, max_num=self.max_num, upscale=upscale_flag).to(self.device).to(torch.bfloat16)
num_patches_list.append(curr_pixel_values.size(0))
pixel_values_list.append(curr_pixel_values)
pixel_values = torch.cat(pixel_values_list, dim=0)
elif image_cnt == 1:
upscale_flag = listinstr(['MMMU_DEV_VAL'], dataset)
pixel_values = load_image(
image_path, max_num=self.max_num, upscale=upscale_flag).to(self.device).to(torch.bfloat16)
num_patches_list = [pixel_values.size(0)]
else:
pixel_values = None
num_patches_list = []
response, history = self.model.chat(
self.tokenizer,
pixel_values=pixel_values,
num_patches_list=num_patches_list,
question=question,
generation_config=self.kwargs,
history=history,
return_history=True
)
response = re.sub(self.reverse_pattern, self.reverse_replacement, response)
return response
import torch
from PIL import Image
from abc import abstractproperty
import sys
import os.path as osp
from .base import BaseModel
from ..smp import *
from ..dataset import DATASET_TYPE
import copy
class SliME(BaseModel):
INSTALL_REQ = True
INTERLEAVE = True
DEFAULT_IMAGE_TOKEN = '<image>'
IMAGE_TOKEN_INDEX = -200
def __init__(self, model_path='yifanzhang114/SliME-Llama3-8B', **kwargs):
assert model_path is not None
try:
from llava.model.builder import load_pretrained_model
from llava.conversation import conv_templates
from llava.mm_utils import get_model_name_from_path, tokenizer_image_token
except Exception as err:
logging.critical('Please install requirements on https://github.com/yfzhang114/SliME before using SliME')
raise err
model_name = get_model_name_from_path(model_path)
tokenizer, model, image_processor, _ = load_pretrained_model(model_path, None, model_name, device_map=None)
model.cuda().eval()
model.tie_weights()
if 'llama3' in model_path.lower():
conv_mode = 'llama3'
elif 'vicuna' in model_path.lower():
conv_mode = 'v1'
self.conv_template = conv_mode
self.conv_templates = conv_templates
self.tokenizer = tokenizer
self.model = model
self.image_processor = image_processor
self.tokenizer_image_token = tokenizer_image_token
def generate_inner(self, message, dataset=None):
content, images = '', []
for msg in message:
if msg['type'] == 'text':
content += msg['value']
else:
images.append(Image.open(msg['value']).convert('RGB'))
content += (self.DEFAULT_IMAGE_TOKEN + '\n')
preprocess = self.image_processor.preprocess
image_tokenizer = self.tokenizer_image_token
image_tensor = [
preprocess(f, return_tensors='pt')['pixel_values'][0].half().cuda() for f in images
]
image_tensor = torch.stack(image_tensor)
conv = copy.deepcopy(self.conv_templates[self.conv_template])
conv.messages = list(conv.messages)
conv.append_message(conv.roles[0], content)
conv.append_message(conv.roles[1], None)
prompt_question = conv.get_prompt()
input_ids = image_tokenizer(prompt_question, self.tokenizer, self.IMAGE_TOKEN_INDEX, return_tensors='pt')
input_ids = input_ids.unsqueeze(0).cuda()
cont = self.model.generate(
input_ids,
images=image_tensor,
do_sample=False,
temperature=0,
max_new_tokens=512,
)
text_outputs = self.tokenizer.batch_decode(cont, skip_special_tokens=True)[0]
return text_outputs
import torch
import os.path as osp
import warnings
from .base import BaseModel
from ..smp import splitlen
from PIL import Image
import os
import math
class SmolVLM(BaseModel):
INSTALL_REQ = True
INTERLEAVE = True
def __init__(self, model_path='HuggingFaceTB/SmolVLM-Instruct', **kwargs):
from transformers import AutoProcessor, Idefics3ForConditionalGeneration
assert osp.exists(model_path) or splitlen(model_path) == 2
self.processor = AutoProcessor.from_pretrained(model_path)
self.model = Idefics3ForConditionalGeneration.from_pretrained(
model_path,
torch_dtype=torch.float32,
device_map='cuda'
)
kwargs_default = {'max_new_tokens': 512,
'use_cache': True}
kwargs_default.update(kwargs)
self.kwargs = kwargs_default
warnings.warn(f'Following kwargs received: {self.kwargs}, will use as generation config.')
torch.cuda.empty_cache()
def generate_inner(self, message, dataset=None):
if dataset in ['MMBench_DEV_EN', 'MMBench_TEST_EN', 'MMBench_DEV_CN', 'MMBench_TEST_CN', 'MMBench',
'MMBench_CN', 'MMBench_DEV_EN_V11', 'MMBench_DEV_CN_V11', 'MMBench_TEST_EN_V11',
'MMBench_TEST_CN_V11', 'MMBench_V11', 'MMBench_CN_V11', 'CCBench']:
formatted_messages, formatted_images = self.build_prompt_mmbench(message)
elif dataset in ['MMMU_DEV_VAL', 'MMMU_TEST']:
formatted_messages, formatted_images = self.build_prompt_mmmu(message)
elif dataset in ['MathVista_MINI']:
formatted_messages, formatted_images = self.build_prompt_mathvista(message)
elif dataset in ['MME', 'MMVet', 'OCRVQA_TEST', 'OCRVQA_TESTCORE', 'TextVQA_VAL',
'ChartQA_TEST', 'DocVQA_VAL', 'DocVQA_TEST', 'InfoVQA_VAL', 'InfoVQA_TEST']:
formatted_messages, formatted_images = self.build_prompt_default(message, add_brief=True)
elif dataset == 'HallusionBench':
formatted_messages, formatted_images = self.build_prompt_default(message, add_yes_or_no=True)
elif dataset in ['MMStar', 'SEEDBench_IMG', 'AI2D_TEST', 'ScienceQA_VAL', 'ScienceQA_TEST']:
formatted_messages, formatted_images = self.build_prompt_puremcq(message)
else:
formatted_messages, formatted_images = self.build_prompt_default(message)
images = [formatted_images] if isinstance(formatted_images, Image.Image) else formatted_images
inputs = self.processor(text=formatted_messages, images=images, return_tensors="pt")
inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
generated_ids = self.model.generate(**inputs, **self.kwargs)
generated_text = self.processor.batch_decode(
generated_ids[:, inputs['input_ids'].size(1):],
skip_special_tokens=True
)[0]
return generated_text.strip()
def build_prompt_default(self, message, add_brief=False, add_yes_or_no=False):
from transformers.image_utils import load_image
prompt, images = 'User:', []
for msg in message:
if msg['type'] == 'image':
img = load_image(msg['value'])
images.append(img)
prompt += '<image>'
elif msg['type'] == 'text':
prompt += msg['value'].strip()
if add_brief:
prompt += '\nGive a very brief answer.'
if add_yes_or_no:
prompt += '\nAnswer yes or no.'
prompt += '<end_of_utterance>\nAssistant:'
return prompt, images
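# Illustrative prompt produced by build_prompt_default for one image and one question with
# add_brief=True (values are made up):
#   'User:<image>What is shown here?\nGive a very brief answer.<end_of_utterance>\nAssistant:'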
def build_prompt_puremcq(self, message):
from transformers.image_utils import load_image
replace_mapping = {
'\nOptions:': '\nChoices:',
'Please select the correct answer from the options above.': 'Answer with the letter.',
}
prompt, images = 'User:', []
for msg in message:
if msg['type'] == 'image':
img = load_image(msg['value'])
images.append(img)
prompt += '<image>'
elif msg['type'] == 'text':
instruction = msg['value'].strip()
for k, v in replace_mapping.items():
instruction = instruction.replace(k, v)
prompt += instruction
prompt += '<end_of_utterance>\nAssistant: Answer:'
return prompt, images
def build_prompt_mt(self, message):
from transformers.image_utils import load_image
prompt, images = '', []
for msg in message:
if msg['role'] == 'user':
prompt += 'User: '
elif msg['role'] == 'assistant':
prompt += 'Assistant: '
for item in msg['content']:
if item['type'] == 'image':
img = load_image(item['value'])
images.append(img)
elif item['type'] == 'text':
prompt += item['value'].strip()
prompt += '<end_of_utterance>\n'
return prompt + 'Assistant: ', images
def build_prompt_mmbench(self, message):
from transformers.image_utils import load_image
replace_mapping = {
'\nOptions:': '\nChoices:',
'Please select the correct answer from the options above.': 'Answer with a letter.',
}
prompt, images = 'User:', []
for msg in message:
if msg['type'] == 'image':
img = load_image(msg['value'])
images.append(img)
prompt += '<image>'
elif msg['type'] == 'text':
instruction = msg['value'].strip()
for k, v in replace_mapping.items():
instruction = instruction.replace(k, v)
# Swap hint and question
if instruction.startswith('Hint:'):
hint, question = instruction.split('\nQuestion:')
question, choices = question.split('\nChoices:')
instruction = (
'Question:' + question + '\n' + hint + '\nChoices:' + choices
)
prompt += instruction
prompt += '<end_of_utterance>\nAssistant: Answer:'
return prompt, images
def build_prompt_mmmu(self, message):
from transformers.image_utils import load_image
replace_mapping = {
'Question:': '',
'Please select the correct answer from the options above.': 'Answer with the letter.',
'\nOptions:': '\nChoices:',
}
prompt, images, img_counter = 'User: Question: ', [], 1
for msg in message:
if msg['type'] == 'image':
prompt += f'<image {img_counter}>:<image>\n'
img_counter += 1
img_counter = 1
for msg in message:
if msg['type'] == 'image':
img = load_image(msg['value'])
images.append(img)
prompt += f' <image {img_counter}> '
img_counter += 1
elif msg['type'] == 'text':
instruction = msg['value'].strip()
for k, v in replace_mapping.items():
instruction = instruction.replace(k, v)
prompt += instruction.strip()
prompt += '<end_of_utterance>\nAssistant:'
if 'A.' in prompt and 'B.' in prompt:
prompt += ' Answer:'
return prompt, images
def build_prompt_mathvista(self, message):
from transformers.image_utils import load_image
replace_mapping = {
'(A) ': 'A. ',
'(B) ': 'B. ',
'(C) ': 'C. ',
'(D) ': 'D. ',
'(E) ': 'E. ',
'(F) ': 'F. ',
'(G) ': 'G. ',
'(H) ': 'H. ',
'\nOptions:': '\nChoices:',
'Hint: ': '',
}
prompt, images = 'User:', []
for msg in message:
if msg['type'] == 'image':
img = load_image(msg['value'])
images.append(img)
prompt += '<image>'
elif msg['type'] == 'text':
instruction = msg['value'].strip()
for k, v in replace_mapping.items():
instruction = instruction.replace(k, v)
prompt += instruction.strip()
prompt += '<end_of_utterance>\nAssistant:'
if 'A.' in prompt and 'B.' in prompt:
prompt += ' Answer:'
return prompt, images
def chat_inner(self, message, dataset=None):
formatted_messages, formatted_images = self.build_prompt_mt(message)
images = [formatted_images] if isinstance(formatted_images, Image.Image) else formatted_images
resulting_messages = [{"role": "user", "content": [{"type": "image"}]
+ [{"type": "text", "text": formatted_messages}]}]
prompt = self.processor.apply_chat_template(resulting_messages, add_generation_prompt=True)
inputs = self.processor(text=prompt, images=images, return_tensors="pt")
inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
generated_ids = self.model.generate(**inputs, **self.kwargs)
generated_text = self.processor.batch_decode(
generated_ids[:, inputs['input_ids'].size(1):],
skip_special_tokens=True
)[0]
return generated_text.strip()
import sys
import torch
from abc import abstractproperty
from .base import BaseModel
from ..smp import *
from ..dataset import DATASET_TYPE
from transformers import AutoTokenizer, BitsAndBytesConfig
class TransCoreM(BaseModel):
INSTALL_REQ = True
INTERLEAVE = False
def load_pretrained_model(self, model_path, load_8bit=False, load_4bit=False, revision='main'):
from transcorem.model import TransCoreMQWenForCausalLM
from transcorem.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
import transcorem.config_param as config_param
kwargs = {'revision': revision}
if load_8bit:
kwargs['load_in_8bit'] = True
elif load_4bit:
kwargs['load_in_4bit'] = True
kwargs['quantization_config'] = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.float16,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type='nf4'
)
else:
kwargs['torch_dtype'] = torch.float16
config_param.model_path = model_path
tokenizer = AutoTokenizer.from_pretrained(
model_path, use_fast=False, revision=revision, trust_remote_code=True)
model = TransCoreMQWenForCausalLM.from_pretrained(
model_path, low_cpu_mem_usage=True, trust_remote_code=True, **kwargs)
image_processor = None
mm_use_im_start_end = getattr(model.config, 'mm_use_im_start_end', False)
mm_use_im_patch_token = getattr(model.config, 'mm_use_im_patch_token', True)
if mm_use_im_patch_token:
tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
if mm_use_im_start_end:
tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
model.resize_token_embeddings(len(tokenizer))
vision_tower = model.get_vision_tower()
if not vision_tower.is_loaded:
vision_tower.load_model()
vision_tower.to(device='cpu', dtype=torch.float16)
image_processor = vision_tower.image_processor
if hasattr(model.config, 'max_sequence_length'):
context_len = model.config.max_sequence_length
else:
context_len = 2048
return tokenizer, model, image_processor, context_len
def __init__(self,
root=None,
revision='main',
**kwargs):
self.root = root
self.revision = revision
sys.path.append(root)
model_path = 'PCIResearch/TransCore-M'
assert osp.exists(model_path) or splitlen(model_path) == 2
self.tokenizer, self.model, self.image_processor, self.context_len = self.load_pretrained_model(
model_path=model_path, revision=revision)
self.model = self.model.cuda()
print('==============conv_mode: transcorem_v1')
self.conv_mode = 'transcorem_v1'
kwargs_default = dict(do_sample=False, temperature=0.0, max_new_tokens=512, top_p=None, num_beams=1)
kwargs_default.update(kwargs)
self.kwargs = kwargs_default
warnings.warn(f'Following kwargs received: {self.kwargs}, will use as generation config. ')
def use_custom_prompt(self, dataset):
assert dataset is not None
if DATASET_TYPE(dataset) == 'MCQ':
return True
return False
def build_prompt(self, line, dataset=None):
assert dataset is None or isinstance(dataset, str)
assert self.use_custom_prompt(dataset)
tgt_path = self.dump_image(line, dataset)
question = line['question']
hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None
if hint is not None:
question = hint + '\n' + question
options = {
cand: line[cand]
for cand in string.ascii_uppercase
if cand in line and not pd.isna(line[cand])
}
for key, item in options.items():
question += f'\n{key}. {item}'
prompt = question
if len(options):
prompt += (
'\n请直接回答选项字母。' if cn_string(prompt) else
"\nAnswer with the option's letter from the given choices directly."
)
else:
prompt += '\n请直接回答问题。' if cn_string(prompt) else '\nAnswer the question directly.'
message = [dict(type='text', value=prompt)]
message.extend([dict(type='image', value=f) for f in tgt_path])
return message
def generate_inner(self, message, dataset=None):
from transcorem.mm_utils import highres_process_images, tokenizer_image_token, KeywordsStoppingCriteria
from transcorem.constants import (
IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN)
from transcorem.conversation import conv_templates, SeparatorStyle
prompt, image_path = self.message_to_promptimg(message, dataset=dataset)
image = Image.open(image_path).convert('RGB')
args = abstractproperty()
args.image_aspect_ratio = 'pad'
image_patches = highres_process_images(image, self.image_processor, args, base_reso=336)
image_patches = [patch.unsqueeze(0).to('cuda', dtype=torch.float16) for patch in image_patches]
if self.model.config.mm_use_im_start_end:
inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + prompt
else:
inp = DEFAULT_IMAGE_TOKEN + '\n' + prompt
conv = conv_templates[self.conv_mode].copy()
conv.append_message(conv.roles[0], inp)
conv.append_message(conv.roles[1], None)
prompt_conv = conv.get_prompt()
input_ids = tokenizer_image_token(prompt_conv, self.tokenizer, IMAGE_TOKEN_INDEX,
return_tensors='pt').unsqueeze(0).cuda()
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
keywords = [stop_str]
stopping_criteria = KeywordsStoppingCriteria(keywords, self.tokenizer, input_ids)
with torch.inference_mode():
output_ids = self.model.generate(
input_ids,
images=image_patches,
use_cache=True,
stopping_criteria=[stopping_criteria],
**self.kwargs)
input_token_len = input_ids.shape[1]
n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
if n_diff_input_output > 0:
print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
outputs = self.tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
outputs = outputs.strip()
if outputs.endswith(stop_str):
outputs = outputs[:-len(stop_str)]
outputs = outputs.strip()
return outputs
from .valley_eagle_chat import ValleyEagleChat
accelerate==0.34.2
bert-score==0.3.13
byted-wandb==0.13.72
datasets==2.21.0
decord==0.6.0
deepspeed==0.9.5
einops==0.8.0
evaluate==0.4.3
fastapi==0.115.0
flash_attn
ftfy==6.2.3
markdown2==2.5.0
ninja==1.11.1.1
nltk==3.9.1
numpy==1.26.4
omegaconf==2.3.0
openai==0.28
opencv-python-headless==4.10.0.84
packaging==24.1
pandas==2.2.2
peft==0.5.0
prettytable==3.11.0
protobuf==3.20.3
pyarrow==15.0.0
pydantic==1.10.14
qwen_vl_utils
requests==2.32.3
rouge-score==0.1.2
scikit-image==0.24.0
scikit-learn==1.5.2
sentencepiece==0.1.97
timm==0.6.7
tokenizers>=0.13.3
torchmetrics
transformers==4.45.2
uvicorn==0.30.6
CONTROLLER_HEART_BEAT_EXPIRATION = 30
WORKER_HEART_BEAT_INTERVAL = 15
CONV_SEP = "###"
LOGDIR = "./valley/serve/serve_logs"
SERVE_IMAGE = "./valley/serve/serve_logs"
SHELL_UI_HEADER = '''
██╗ ██╗ █████╗ ██╗ ██╗ ███████╗██╗ ██╗ ██████╗██╗ ██╗ █████╗ ████████╗
██║ ██║██╔══██╗██║ ██║ ██╔════╝╚██╗ ██╔╝ ██╔════╝██║ ██║██╔══██╗╚══██╔══╝
██║ ██║███████║██║ ██║ █████╗ ╚████╔╝ ██║ ███████║███████║ ██║
╚██╗ ██╔╝██╔══██║██║ ██║ ██╔══╝ ╚██╔╝ ██║ ██╔══██║██╔══██║ ██║
╚████╔╝ ██║ ██║███████╗███████╗███████╗ ██║ ╚██████╗██║ ██║██║ ██║ ██║
╚═══╝ ╚═╝ ╚═╝╚══════╝╚══════╝╚══════╝ ╚═╝ ╚═════╝╚═╝ ╚═╝╚═╝ ╚═╝ ╚═╝
'''
import dataclasses
from enum import auto, Enum
from typing import List, Tuple
class SeparatorStyle(Enum):
"""Different separator style."""
SINGLE = auto()
TWO = auto()
MPT = auto()
PLAIN = auto()
LLAMA_2 = auto()
MISTRAL = auto()
QWEN2 = auto()
GEMMA2 = auto()
@dataclasses.dataclass
class Conversation:
"""A class that keeps all conversation history."""
system: str
roles: Tuple[str] = ("USER", "ASSISTANT")
messages: Tuple[List[str]] = ()
offset: int = 0
sep_style: SeparatorStyle = SeparatorStyle.SINGLE
sep: str = "###"
sep2: str = None
version: str = "Unknown"
skip_next: bool = False
def get_prompt(self):
messages = self.messages
if len(messages) > 0 and type(messages[0][1]) is tuple:
messages = self.messages.copy()
init_role, init_msg = messages[0].copy()
init_msg = init_msg[0].replace("<image>", "").strip()
if 'mmtag' in self.version:
messages[0] = (init_role, init_msg)
messages.insert(0, (self.roles[0], "<Image><image></Image>"))
messages.insert(1, (self.roles[1], "Received."))
else:
messages[0] = (init_role, "<image>\n" + init_msg)
if self.sep_style == SeparatorStyle.SINGLE:
ret = self.system + self.sep if self.system is not None else ''
for role, message in messages:
if message:
if type(message) is tuple:
message, _, _ = message
ret += role + ": " + message + self.sep # keep space after ":"
else:
ret += role + ":"
elif self.sep_style == SeparatorStyle.TWO:
seps = [self.sep, self.sep2]
ret = self.system + seps[0] if self.system is not None else ''
for i, (role, message) in enumerate(messages):
if message:
if type(message) is tuple:
message, _, _ = message
ret += role + ": " + message + seps[i % 2] # keep space after ":"
else:
ret += role + ":"
elif self.sep_style == SeparatorStyle.MPT:
ret = self.system + self.sep if self.system is not None else ''
for role, message in messages:
if message:
if type(message) is tuple:
message, _, _ = message
ret += role + message + self.sep
else:
ret += role
elif self.sep_style == SeparatorStyle.LLAMA_2:
def wrap_sys(msg):
return f"<<SYS>>\n{msg}\n<</SYS>>\n\n"
def wrap_inst(msg):
return f"[INST] {msg} [/INST]"
ret = ""
for i, (role, message) in enumerate(messages):
if i == 0:
assert message, "first message should not be none"
assert role == self.roles[0], "first message should come from user"
if message:
if type(message) is tuple:
message, _, _ = message
if i == 0:
if self.system is not None:
message = wrap_sys(self.system) + message
else:
message = message
if i % 2 == 0:
message = wrap_inst(message)
ret += self.sep + message
else:
ret += " " + message + " " + self.sep2
else:
ret += ""
ret = ret.lstrip(self.sep)
elif self.sep_style == SeparatorStyle.MISTRAL:
'''text = "[INST] What is your favourite condiment? [/INST]"
"Well, I'm quite partial to a good squeeze of fresh lemon juice. It adds just the right amount \
of zesty flavour to whatever I'm cooking up in the kitchen!</s> "
"[INST] Do you have mayonnaise recipes? [/INST]"'''
def wrap_sys(msg):
return f"[INST] {msg} [/INST]Sure!"
def wrap_inst(msg):
return f"[INST] {msg} [/INST]"
ret = ""
for i, (role, message) in enumerate(messages):
if i == 0:
assert message, "first message should not be none"
assert role == self.roles[0], "first message should come from user"
if message:
if type(message) is tuple:
message, _, _ = message
if i == 0:
if self.system is not None:
ret = wrap_sys(self.system) + self.sep2
else:
ret = ""
if i % 2 == 0:
message = wrap_inst(message)
ret += message
else:
ret += message + self.sep2
else:
ret += ""
elif self.sep_style == SeparatorStyle.PLAIN:
seps = [self.sep, self.sep2]
ret = self.system if self.system is not None else ''
for i, (role, message) in enumerate(messages):
if message:
if type(message) is tuple:
message, _, _ = message
ret += message + seps[i % 2]
else:
ret += ""
        elif self.sep_style == SeparatorStyle.QWEN2:
            # Not implemented here; falling through would leave `ret` undefined,
            # so fail loudly instead of raising a confusing UnboundLocalError.
            raise NotImplementedError("QWEN2 prompt construction is not implemented")
        elif self.sep_style == SeparatorStyle.GEMMA2:
            raise NotImplementedError("GEMMA2 prompt construction is not implemented")
else:
raise ValueError(f"Invalid style: {self.sep_style}")
return ret
def append_message(self, role, message):
self.messages.append([role, message])
def get_images(self, return_pil=False):
images = []
for i, (role, msg) in enumerate(self.messages[self.offset:]):
if i % 2 == 0:
if type(msg) is tuple:
import base64
from io import BytesIO
from PIL import Image
msg, image, image_process_mode = msg
if image_process_mode == "Pad":
def expand2square(pil_img, background_color=(122, 116, 104)):
width, height = pil_img.size
if width == height:
return pil_img
elif width > height:
result = Image.new(pil_img.mode, (width, width), background_color)
result.paste(pil_img, (0, (width - height) // 2))
return result
else:
result = Image.new(pil_img.mode, (height, height), background_color)
result.paste(pil_img, ((height - width) // 2, 0))
return result
image = expand2square(image)
elif image_process_mode in ["Default", "Crop"]:
pass
elif image_process_mode == "Resize":
image = image.resize((336, 336))
else:
raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
max_hw, min_hw = max(image.size), min(image.size)
aspect_ratio = max_hw / min_hw
max_len, min_len = 800, 400
shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
longest_edge = int(shortest_edge * aspect_ratio)
W, H = image.size
if longest_edge != max(image.size):
if H > W:
H, W = longest_edge, shortest_edge
else:
H, W = shortest_edge, longest_edge
image = image.resize((W, H))
if return_pil:
images.append(image)
else:
buffered = BytesIO()
image.save(buffered, format="PNG")
img_b64_str = base64.b64encode(buffered.getvalue()).decode()
images.append(img_b64_str)
return images
def to_gradio_chatbot(self):
ret = []
for i, (role, msg) in enumerate(self.messages[self.offset:]):
if i % 2 == 0:
if type(msg) is tuple:
import base64
from io import BytesIO
msg, image, image_process_mode = msg
max_hw, min_hw = max(image.size), min(image.size)
aspect_ratio = max_hw / min_hw
max_len, min_len = 800, 400
shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
longest_edge = int(shortest_edge * aspect_ratio)
W, H = image.size
if H > W:
H, W = longest_edge, shortest_edge
else:
H, W = shortest_edge, longest_edge
image = image.resize((W, H))
buffered = BytesIO()
image.save(buffered, format="JPEG")
img_b64_str = base64.b64encode(buffered.getvalue()).decode()
                    img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}" alt="user upload image" />'
msg = img_str + msg.replace('<image>', '').strip()
ret.append([msg, None])
else:
ret.append([msg, None])
else:
ret[-1][-1] = msg
return ret
def copy(self):
return Conversation(
system=self.system,
roles=self.roles,
messages=[[x, y] for x, y in self.messages],
offset=self.offset,
sep_style=self.sep_style,
sep=self.sep,
sep2=self.sep2,
version=self.version)
def dict(self):
if len(self.get_images()) > 0:
return {
"system": self.system,
"roles": self.roles,
"messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
"offset": self.offset,
"sep": self.sep,
"sep2": self.sep2,
}
return {
"system": self.system,
"roles": self.roles,
"messages": self.messages,
"offset": self.offset,
"sep": self.sep,
"sep2": self.sep2,
}
conv_vicuna_v0 = Conversation(
system="A chat between a curious human and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
roles=("Human", "Assistant"),
version="v0",
messages=(
("Human", "What are the key differences between renewable and non-renewable energy sources?"),
("Assistant",
"Renewable energy sources are those that can be replenished naturally in a relatively "
"short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
"Non-renewable energy sources, on the other hand, are finite and will eventually be "
"depleted, such as coal, oil, and natural gas. Here are some key differences between "
"renewable and non-renewable energy sources:\n"
"1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
"energy sources are finite and will eventually run out.\n"
"2. Environmental impact: Renewable energy sources have a much lower environmental impact "
"than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
"and other negative effects.\n"
"3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
"have lower operational costs than non-renewable sources.\n"
"4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
"locations than non-renewable sources.\n"
"5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
"situations and needs, while non-renewable sources are more rigid and inflexible.\n"
"6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
"non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
),
offset=2,
sep_style=SeparatorStyle.SINGLE,
sep="###",
)
conv_vicuna_v1 = Conversation(
system="A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
roles=("USER", "ASSISTANT"),
version="v1",
messages=(),
offset=0,
sep_style=SeparatorStyle.TWO,
sep=" ",
sep2="</s>",
)
conv_llama_2 = Conversation(
system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being \
safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. \
Please ensure that your responses are socially unbiased and positive in nature. \
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not \
correct. If you don't know the answer to a question, please don't share false information.""",
roles=("USER", "ASSISTANT"),
version="llama_2",
messages=(),
offset=0,
sep_style=SeparatorStyle.LLAMA_2,
sep="<s>",
sep2="</s>",
)
conv_mistral = Conversation(
system="A chat between a curious human and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
roles=("USER", "ASSISTANT"),
version="mistral",
messages=(),
offset=0,
sep_style=SeparatorStyle.MISTRAL,
sep="<s>",
sep2="</s>",
)
conv_qwen2 = Conversation(
system="You are a helpful assistant.",
roles=("user", "assistant"),
version="qwen2",
messages=(),
offset=0,
sep_style=SeparatorStyle.QWEN2,
sep="<|im_start|>",
sep2="<|im_end|>\n",
)
conv_gemma2 = Conversation(
system="You are a helpful assistant.",
roles=("user", "model"),
version="gemma2",
messages=(),
offset=0,
sep_style=SeparatorStyle.GEMMA2,
sep="<start_of_turn>",
sep2="<end_of_turn>\n",
)
conv_valley_v1 = Conversation(
    system="You are Valley, a large language and vision assistant trained by ByteDance. "
    "You are able to understand the visual content or video that the user provides, \
and assist the user with a variety of tasks using natural language. "
"Follow the instructions carefully and explain your answers in detail.",
roles=("USER", "ASSISTANT"),
version="v1",
messages=(),
offset=0,
sep_style=SeparatorStyle.TWO,
sep=" ",
sep2="</s>",
)
conv_valley_mfe_v1 = Conversation(
    system="You are Valley, a large-scale language and vision assistant designed to aid in \
the detection of misleading functionality and effect in input visual content. The currently \
imported videos are mainly designed for the e-commerce live streaming field. "
    "You have the ability to understand multiple languages. "
    "You can understand videos and help people determine whether there are misleading \
functionality and effect in input visual content. Misleading functional effects refer to exaggerating \
before-and-after comparisons in videos, falsely describing curative effects, and violating objective \
scientific laws. Examples of misleading functional effects include unrealistic before-after comparisons, \
unrealistic promises, false medical promises, or violations of science. "
    "Follow the instructions carefully and explain the reason.",
roles=("USER", "ASSISTANT"),
version="v1",
messages=(),
offset=0,
sep_style=SeparatorStyle.TWO,
sep=" ",
sep2="</s>",
)
conv_valley_multilabel = Conversation(
system="A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
roles=("USER", "ASSISTANT"),
version="v1",
messages=(),
offset=0,
sep_style=SeparatorStyle.TWO,
sep=" ",
sep2="</s>",
)
conv_llava_v0_mmtag = Conversation(
system="A chat between a curious user and an artificial intelligence assistant. "
"The assistant is able to understand the visual content that the user provides, \
and assist the user with a variety of tasks using natural language. "
"The visual content will be provided with the following format: <Image>visual content</Image>.",
roles=("Human", "Assistant"),
messages=(
),
offset=0,
sep_style=SeparatorStyle.SINGLE,
sep="###",
version="v0_mmtag",
)
conv_mistral_instruct = Conversation(
system="",
roles=("USER", "ASSISTANT"),
version="llama_2",
messages=(),
offset=0,
sep_style=SeparatorStyle.LLAMA_2,
sep="",
sep2="</s>",
)
conv_llava_v1 = Conversation(
system="A chat between a curious human and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
roles=("USER", "ASSISTANT"),
version="v1",
messages=(),
offset=0,
sep_style=SeparatorStyle.TWO,
sep=" ",
sep2="</s>",
)
conv_belle = Conversation(
system="你是字节跳动训练的大型语言视觉助手 Chinese-Valley。"
"你能够理解用户提供的视觉内容或视频,并使用自然语言协助用户完成各种任务。"
"请仔细按照人类的指令进行回答,并详细解释你的答案。",
roles=("Human", "Assistant"),
messages=[],
offset=0,
sep_style=SeparatorStyle.SINGLE,
sep="###",
)
conv_no_system = Conversation(
system=None
)
conv_void_system = Conversation(
system=''
)
default_conversation = conv_vicuna_v0
conv_templates = {
"default": conv_vicuna_v0,
"v0": conv_vicuna_v0,
"v1": conv_vicuna_v1,
"vicuna_v1": conv_vicuna_v1,
"llama_2": conv_llama_2,
"v0_mmtag": conv_llava_v0_mmtag,
"llava_v1": conv_llava_v1,
"belle": conv_belle,
    "valley_v1": conv_valley_v1,
    "mistral": conv_mistral,
"mistral_instruct": conv_mistral_instruct,
"qwen2": conv_qwen2,
"gemma2": conv_gemma2,
}
prompt_templates = {
    "mfe_v1": conv_valley_mfe_v1,
    "multilabel_v1": conv_valley_multilabel,
    "no": conv_no_system,
    "void": conv_void_system,
}
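# Minimal usage sketch (the "valley_v1" key is just an example of a registered template):
#   conv = conv_templates["valley_v1"].copy()
#   conv.append_message(conv.roles[0], "<image>\nDescribe the image.")
#   conv.append_message(conv.roles[1], None)
#   prompt = conv.get_prompt()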
if __name__ == "__main__":
print(default_conversation.get_prompt())
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from transformers import AutoConfig, AutoModelForCausalLM, Qwen2Config, Qwen2ForCausalLM, Qwen2Model
from transformers.modeling_outputs import CausalLMOutputWithPast
from ..valley_arch import ValleyMetaModel, ValleyMetaForCausalLM
class ValleyConfig(Qwen2Config):
model_type = "valley"
class ValleyQwen2Model(ValleyMetaModel, Qwen2Model):
config_class = ValleyConfig
def __init__(self, config: Qwen2Config):
super(ValleyQwen2Model, self).__init__(config)
class ValleyQwen2ForCausalLM(Qwen2ForCausalLM, ValleyMetaForCausalLM):
config_class = ValleyConfig
def __init__(self, config):
super(Qwen2ForCausalLM, self).__init__(config)
self.model = ValleyQwen2Model(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
def get_model(self):
return self.model
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
images: Optional[torch.FloatTensor] = None,
return_dict: Optional[bool] = None,
image_sizes: Optional[List[List[int]]] = None,
pixel_values: Optional[torch.Tensor] = None,
pixel_values_videos: Optional[torch.FloatTensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
video_grid_thw: Optional[torch.LongTensor] = None,
pack_ids=None
) -> Union[Tuple, CausalLMOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if inputs_embeds is None:
(
input_ids,
position_ids,
attention_mask,
past_key_values,
inputs_embeds,
labels
) = self.prepare_inputs_labels_for_multimodal(
input_ids,
position_ids,
attention_mask,
past_key_values,
labels,
images,
image_sizes,
pixel_values,
pixel_values_videos,
image_grid_thw,
video_grid_thw,
pack_ids
)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
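            # Unlike the stock Qwen2 loss, this keeps one (mean-reduced) loss value per
            # sample via the list comprehension below instead of averaging the whole batch.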
loss_fct = CrossEntropyLoss(reduction='mean')
bs = shift_labels.shape[0]
shift_labels = shift_labels.to(shift_logits.device)
loss = torch.stack([loss_fct(shift_logits[i], shift_labels[i]) for i in range(bs)])
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
):
if past_key_values:
input_ids = input_ids[:, -1:]
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
model_inputs.update(
{
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask,
"images": kwargs.get("images", None),
"image_sizes": kwargs.get("image_sizes", None),
"pixel_values": kwargs.get("pixel_values", None),
"pixel_values_videos": kwargs.get("pixel_values_videos", None),
"image_grid_thw": kwargs.get("image_grid_thw", None),
"video_grid_thw": kwargs.get("video_grid_thw", None),
"pack_ids": kwargs.get("pack_ids", None),
}
)
return model_inputs
AutoConfig.register("valley", ValleyConfig)
AutoModelForCausalLM.register(ValleyConfig, ValleyQwen2ForCausalLM)
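# With the registrations above, a checkpoint whose config declares model_type "valley"
# can be loaded through the Auto classes. Minimal sketch (hypothetical local path):
#   from transformers import AutoModelForCausalLM
#   model = AutoModelForCausalLM.from_pretrained("/path/to/valley-qwen2-checkpoint")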
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from transformers import Qwen2VLConfig, Qwen2VLModel, Qwen2VLForConditionalGeneration
from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLCausalLMOutputWithPast
from transformers.models.qwen2_vl.modeling_qwen2_vl import _prepare_4d_causal_attention_mask_with_cache_position
from transformers.cache_utils import StaticCache
from ..valley_arch import ValleyMetaModel, ValleyMetaForCausalLM
class ValleyConfig(Qwen2VLConfig):
model_type = "valley"
class ValleyQwen2VLModel(ValleyMetaModel, Qwen2VLModel):
config_class = ValleyConfig
def __init__(self, config: Qwen2VLConfig):
super(ValleyQwen2VLModel, self).__init__(config)
class ValleyQwen2VLForCausalLM(Qwen2VLForConditionalGeneration, ValleyMetaForCausalLM):
config_class = ValleyConfig
def __init__(self, config):
super(ValleyQwen2VLForCausalLM, self).__init__(config)
self.model = ValleyQwen2VLModel(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
self.visual.requires_grad_(False)
def get_model(self):
return self.model
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
pixel_values: Optional[torch.Tensor] = None,
pixel_values_videos: Optional[torch.FloatTensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
video_grid_thw: Optional[torch.LongTensor] = None,
rope_deltas: Optional[torch.LongTensor] = None,
**kwargs
) -> Union[Tuple, Qwen2VLCausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Returns:
Example:
```python
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
>>> model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
>>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
>>> messages = [
{
"role": "user",
"content": [
{"type": "image"},
{"type": "text", "text": "What is shown in this image?"},
],
},
]
>>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        >>> inputs = processor(text=[text], images=[image], padding=True, return_tensors="pt")
        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"The image shows a street scene with a red stop sign in the foreground."
"In the background, there is a large red gate with Chinese characters ..."
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if inputs_embeds is None:
inputs_embeds = self.model.embed_tokens(input_ids)
if pixel_values is not None:
pixel_values = pixel_values.type(self.visual.get_dtype())
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw).to(inputs_embeds.device)
image_mask = input_ids == self.config.image_token_id
if self.training:
inputs_embeds = inputs_embeds.clone()
inputs_embeds[image_mask] = image_embeds
if pixel_values_videos is not None:
pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype())
video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw).to(inputs_embeds.device)
video_mask = input_ids == self.config.video_token_id
inputs_embeds[video_mask] = video_embeds
if attention_mask is not None:
attention_mask = attention_mask.to(inputs_embeds.device)
outputs = self.model(
input_ids=None,
position_ids=position_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
logits = logits.float()
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return Qwen2VLCausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
rope_deltas=rope_deltas,
)
def prepare_inputs_for_generation(
self,
input_ids,
past_key_values=None,
attention_mask=None,
inputs_embeds=None,
cache_position=None,
position_ids=None,
use_cache=True,
pixel_values=None,
pixel_values_videos=None,
image_grid_thw=None,
video_grid_thw=None,
**kwargs,
):
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
# Exception 1: when passing input_embeds, input_ids may be missing entries
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
if past_key_values is not None:
if inputs_embeds is not None: # Exception 1
input_ids = input_ids[:, -cache_position.shape[0]:]
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
input_ids = input_ids[:, cache_position]
rope_deltas = kwargs.get("rope_deltas", None)
if attention_mask is not None and position_ids is None:
if cache_position is None or (cache_position is not None and cache_position[0] == 0):
position_ids, rope_deltas = self.get_rope_index(
input_ids, image_grid_thw, video_grid_thw, attention_mask
)
else:
batch_size, seq_length = input_ids.shape
delta = (
cache_position[0] + rope_deltas if cache_position is not None and rope_deltas is not None else 0
)
position_ids = torch.arange(seq_length, device=input_ids.device)
position_ids = position_ids.view(1, -1).expand(batch_size, -1)
position_ids = position_ids.add(delta)
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
if cache_position[0] != 0:
pixel_values = None
pixel_values_videos = None
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and cache_position[0] == 0:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
if inputs_embeds is not None:
                batch_size, sequence_length, _ = inputs_embeds.shape
device = inputs_embeds.device
else:
batch_size, sequence_length = input_ids.shape
device = input_ids.device
dtype = self.lm_head.weight.dtype
min_dtype = torch.finfo(dtype).min
attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=past_key_values.get_max_length(),
dtype=dtype,
device=device,
min_dtype=min_dtype,
cache_position=cache_position,
batch_size=batch_size,
)
model_inputs.update(
{
"position_ids": position_ids,
"past_key_values": past_key_values,
"use_cache": use_cache,
"attention_mask": attention_mask,
"pixel_values": pixel_values,
"pixel_values_videos": pixel_values_videos,
"image_grid_thw": image_grid_thw,
"video_grid_thw": video_grid_thw,
"rope_deltas": rope_deltas,
}
)
return model_inputs
import torch
from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VisionTransformerPretrainedModel
from ...util.vision_encoder_config import qwen2vl_vit_config
def build_vision_tower(vision_tower_cfg, **kwargs):
vision_tower = getattr(vision_tower_cfg, "mm_vision_tower", getattr(vision_tower_cfg, "vision_tower", None))
if getattr(vision_tower_cfg, "language", None) is None:
vision_tower_cfg.language = "chinese" if "chinese" in vision_tower else "english"
print(f"language: {vision_tower_cfg.language}, vision_tower: {vision_tower}")
if "siglip-so400m-patch14-384" in vision_tower:
from .siglip_encoder import SigLipVisionTower
qwen2vl_vision_tower = Qwen2VisionTransformerPretrainedModel._from_config(qwen2vl_vit_config)
qwen2vl_vision_tower.requires_grad_(False)
return SigLipVisionTower(vision_tower, args=vision_tower_cfg, **kwargs), qwen2vl_vision_tower
else:
raise ValueError(f"Unknown vision tower: {vision_tower}")
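# Minimal usage sketch (the checkpoint id is an assumption; the config must also carry
# the fields SigLipVisionTower reads, e.g. mm_vision_select_layer):
#   cfg.mm_vision_tower = "google/siglip-so400m-patch14-384"
#   siglip_tower, qwen2vl_tower = build_vision_tower(cfg)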
import torch
import torch.nn as nn
from ...util.vision_encoder_config import siglip_config
class SigLipVisionTower(nn.Module):
def __init__(self, vision_tower, args, delay_load=False, cache_dir="./cache_dir"):
super().__init__()
self.is_loaded = False
self.image_tower_name = vision_tower
self.select_layer = args.mm_vision_select_layer
self.select_feature = getattr(args, "mm_vision_select_feature", "patch")
self.cache_dir = cache_dir
if not delay_load:
self.load_model()
else:
from transformers import SiglipVisionConfig, SiglipVisionModel
self.cfg_only = SiglipVisionConfig.from_pretrained(self.image_tower_name, cache_dir=self.cache_dir)
            self.vision_tower = SiglipVisionModel.from_pretrained(self.image_tower_name) # dummy-load
def load_model(self):
from transformers import SiglipImageProcessor, SiglipVisionModel
self.image_processor = SiglipImageProcessor.from_pretrained(self.image_tower_name)
self.vision_tower = SiglipVisionModel._from_config(siglip_config)
self.vision_tower.requires_grad_(False)
self.image_processor.crop_size = self.image_processor.size["height"]
self.is_loaded = True
def feature_select(self, image_forward_outs):
assert self.select_feature == "cls_patch"
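        # SigLIP's last_hidden_state has no CLS token, so the first patch token is
        # duplicated at the front of the sequence to emulate a "cls_patch" layout.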
image_features = torch.cat([image_forward_outs[:, :1, :], image_forward_outs], dim=1)
return image_features
@torch.no_grad()
def forward(self, images):
if type(images) is list:
image_features = []
for image in images:
image_forward_out = self.vision_tower(
image.to(device=self.device, dtype=self.dtype).unsqueeze(0),
output_hidden_states=True,
return_dict=True,
)
image_feature = self.feature_select(image_forward_out.last_hidden_state).to(image.dtype)
image_features.append(image_feature)
else:
image_forward_outs = self.vision_tower(
images.to(device=self.device, dtype=self.dtype),
output_hidden_states=True,
return_dict=True,
)
image_features = self.feature_select(image_forward_outs.last_hidden_state).to(images.dtype)
return image_features
@property
def dummy_feature(self):
return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
@property
def dtype(self):
return self.vision_tower.dtype
@property
def device(self):
return self.vision_tower.device
@property
def config(self):
if self.is_loaded:
return self.vision_tower.config
else:
return self.cfg_only
@property
def hidden_size(self):
return self.config.hidden_size
@property
def num_patches(self):
return (self.config.image_size // self.config.patch_size) ** 2
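# Minimal usage sketch (checkpoint id and config fields are assumptions):
#   tower = SigLipVisionTower("google/siglip-so400m-patch14-384", args=cfg)
#   feats = tower(images)  # (num_images, num_patches + 1, hidden_size), see feature_select()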
import torch
import torch.nn as nn
import re
import math
class IdentityMap(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x, *args, **kwargs):
return x
@property
def config(self):
return {"mm_projector_type": 'identity'}
class IdentityPatchMap(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x, *args, **kwargs):
"""
It is used to remove the first token (cls token) in the image feature.
Args:
x (torch.Tensor): image features
shape (F, v, D)
Returns:
shape (F, n, D) where n is token_num that has been reduced (n = v - 1)
"""
return x[:, 1:, :]
@property
def config(self):
return {"mm_projector_type": 'identity_patch'}
class SimpleResBlock(nn.Module):
def __init__(self, channels):
super().__init__()
self.pre_norm = nn.LayerNorm(channels)
self.proj = nn.Sequential(
nn.Linear(channels, channels),
nn.GELU(),
nn.Linear(channels, channels)
)
def forward(self, x):
x = self.pre_norm(x)
return x + self.proj(x)
class SpatialPyramidPooling(nn.Module):
def __init__(self, pool_sizes=[2, 3, 5, 8], pool_mode='max'):
super(SpatialPyramidPooling, self).__init__()
self.pool_sizes = pool_sizes
self.pooling_method = {'max': nn.AdaptiveMaxPool2d, 'mean': nn.AdaptiveAvgPool2d}[pool_mode]
        self.layers = nn.ModuleList([self.pooling_method(i) for i in pool_sizes])
def forward(self, x):
        b, c, h, w = x.size()
pooled = []
for p in self.layers:
pooled.append(p(x).view(b, c, -1))
return torch.cat(pooled, -1)
class LinearAdapter(nn.Linear):
def __init__(self, mm_hidden_size, hidden_size):
super(LinearAdapter, self).__init__(mm_hidden_size, hidden_size)
self.mm_projector_type = 'linear'
class ConvAdapter(nn.Module):
def __init__(self, dim_in, dim_out, mlp_hidden_dim=None):
super().__init__()
self.mm_projector_type = 'conv_adapter'
if mlp_hidden_dim is None:
self.mlp = nn.Sequential(
nn.Linear(dim_in, dim_out),
nn.GELU(),
nn.Linear(dim_out, dim_out)
)
else:
self.mlp = nn.Sequential(
nn.Linear(dim_in, mlp_hidden_dim),
nn.GELU(),
nn.Linear(mlp_hidden_dim, dim_out)
)
self.conv = nn.Conv2d(dim_out, dim_out, kernel_size=(3, 3), stride=(2, 2), padding=1)
def forward(self, x):
"""
Args:
x (torch.Tensor): image features
shape (F, v, D)
Returns:
shape (F, n, D) where n is token_num that has been reduced
"""
x = self.mlp(x)
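        # From here: drop the CLS token, fold the sequence back into its (s x s) grid,
        # and downsample with the stride-2 conv. Example trace for a 24 x 24 grid (v = 577):
        #   (F, 577, dim_out) -> (F, 24, 24, dim_out) -> conv -> (F, 144, dim_out)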
f, v, d = x.shape
s = int(math.sqrt(v - 1))
x = x[:, 1:, :] # remove cls_token
x = x.reshape(f, s, s, d).permute([0, 3, 1, 2])
x = self.conv(x)
x = x.permute([0, 2, 3, 1]).reshape(f, -1, d)
return x
class PoolAdapter(nn.Module):
def __init__(self, dim_in, dim_out, pool_out_size=4):
super().__init__()
self.mm_projector_type = 'pool_adapter'
self.pool_h, self.pool_w = pool_out_size, pool_out_size
self.mlp = nn.Sequential(
nn.Linear(dim_in, dim_out),
nn.GELU(),
nn.Linear(dim_out, dim_out)
)
def forward(self, x):
"""
Args:
x (torch.Tensor): image features
shape (F, v, D)
Returns:
shape (F, n, D) where n is token_num that has been reduced
"""
f, v, d = x.shape
# print(x.shape) # torch.Size([image_num, vit_token_num, dim_in]) [8, 257, 1024] f, v, d = x.shape
s = int(math.sqrt(v - 1))
x = x[:, 1:, :] # remove cls_token
x = x.reshape(f, s, s, d)
if s % self.pool_h == 0 and s % self.pool_w == 0:
x = x.reshape(f, self.pool_h, s // self.pool_h, self.pool_w, s // self.pool_w, d)
x = x.permute([0, 1, 3, 5, 2, 4]).reshape(f, self.pool_h * self.pool_w, d, -1).mean(-1)
x = self.mlp(x) # [f, h*w, d]
else:
raise ValueError()
return x
class PoolAdapterCLS(nn.Module):
def __init__(self, dim_in, dim_out, pool_out_size=4):
super().__init__()
self.mm_projector_type = 'pool_adapter_w_cls'
self.pool_h, self.pool_w = pool_out_size, pool_out_size
self.mlp = nn.Sequential(
nn.Linear(dim_in, dim_out),
nn.GELU(),
nn.Linear(dim_out, dim_out)
)
def forward(self, x):
"""
Args:
x (torch.Tensor): image features
shape (F, v, D)
Returns:
shape (F, n, D) where n is token_num that has been reduced
"""
f, v, d = x.shape
# print(x.shape) # torch.Size([image_num, vit_token_num, dim_in]) [8, 257, 1024] f, v, d = x.shape
s = int(math.sqrt(v - 1))
cls = x[:, 0, :]
feature = x[:, 1:, :] # remove cls_token
feature = feature.reshape(f, s, s, d)
if s % self.pool_h == 0 and s % self.pool_w == 0:
feature = feature.reshape(f, self.pool_h, s // self.pool_h, self.pool_w, s // self.pool_w, d)
feature = feature.permute([0, 1, 3, 5, 2, 4]).reshape(f, self.pool_h * self.pool_w, d, -1).mean(-1)
feature = torch.concat([cls.unsqueeze(1), feature], dim=1)
feature = self.mlp(feature) # [f, h*w, d]
else:
raise ValueError()
return feature
class AdaptPooler(nn.Module):
def __init__(self, dim_in, dim_out, pool_out_size=4):
super().__init__()
self.mm_projector_type = 'adapt_pooler'
self.pool_h, self.pool_w = pool_out_size, pool_out_size
self.mlp = nn.Sequential(
nn.Linear(dim_in, dim_out),
nn.GELU(),
nn.Linear(dim_out, dim_out)
)
def forward(self, x):
"""
Args:
x (torch.Tensor): image features
shape (F, v, D)
Returns:
shape (F, n, D) where n is token_num that has been reduced
"""
x = self.mlp(x)
f, v, d = x.shape
s = int(math.sqrt(v - 1))
x = x[:, 1:, :] # remove cls_token
x = x.reshape(f, s, s, d)
x = x.reshape(f, self.pool_h, s // self.pool_h, self.pool_w, s // self.pool_w, d)
x = x.permute([0, 1, 3, 5, 2, 4]).reshape(f, self.pool_h * self.pool_w, d, -1).mean(-1)
return x
class AdaptPoolerCLS(nn.Module):
def __init__(self, dim_in, dim_out, pool_out_size=4):
super().__init__()
self.mm_projector_type = 'adapt_pooler_w_cls'
self.pool_h, self.pool_w = pool_out_size, pool_out_size
self.mlp = nn.Sequential(
nn.Linear(dim_in, dim_out),
nn.GELU(),
nn.Linear(dim_out, dim_out)
)
def forward(self, x):
"""
Args:
x (torch.Tensor): image features
shape (F, v, D)
Returns:
shape (F, n, D) where n is token_num that has been reduced
"""
x = self.mlp(x)
f, v, d = x.shape
s = int(math.sqrt(v - 1))
cls = x[:, 0, :]
feature = x[:, 1:, :] # remove cls_token
feature = feature.reshape(f, s, s, d)
feature = feature.reshape(f, self.pool_h, s // self.pool_h, self.pool_w, s // self.pool_w, d)
feature = feature.permute([0, 1, 3, 5, 2, 4]).reshape(f, self.pool_h * self.pool_w, d, -1).mean(-1)
return torch.concat([cls.unsqueeze(1), feature], dim=1)
class AdaptPyraPooler(nn.Module):
def __init__(self, dim_in, dim_out):
super().__init__()
self.mm_projector_type = 'adapt_pyrapooler'
self.mlp = nn.Sequential(
nn.Linear(dim_in, dim_out),
nn.GELU(),
nn.Linear(dim_out, dim_out)
)
self.pool = SpatialPyramidPooling([2, 3, 5, 8], pool_mode='max')
def forward(self, x):
"""
Args:
x (torch.Tensor): image features
shape (F, v, D)
Returns:
shape (F, n, D) where n is token_num that has been reduced
"""
x = self.mlp(x)
f, v, d = x.shape
s = int(math.sqrt(v - 1))
x = x[:, 1:, :] # remove cls_token
x = x.reshape(f, s, s, d).permute([0, 3, 1, 2])
x = self.pool(x).permute([0, 2, 1])
return x
class MlpPixelShuffle(nn.Module):
def __init__(self, dim_in, dim_out, pixelshuffle_downsample_ratio, mlp_hidden_dim=None):
super().__init__()
self.mm_projector_type = 'mlp_pixel_shuffle'
if mlp_hidden_dim is None:
self.mlp = nn.Sequential(
nn.Linear(int(dim_in * (pixelshuffle_downsample_ratio ** 2)), dim_out),
nn.GELU(),
nn.Linear(dim_out, dim_out)
)
else:
self.mlp = nn.Sequential(
nn.Linear(int(dim_in * (pixelshuffle_downsample_ratio ** 2)), mlp_hidden_dim),
nn.GELU(),
nn.Linear(mlp_hidden_dim, dim_out)
)
self.scale_factor = pixelshuffle_downsample_ratio
def pixel_shuffle(self, x, scale_factor=2):
        # x is laid out as (N, W, H, C); fold each scale_factor x scale_factor spatial
        # block into the channel dimension.
n, w, h, c = x.size()
# N, W, H, C --> N, W, H / scale, C * scale
x = x.view(n, w, int(h / scale_factor), int(c * scale_factor))
# N, W, H / scale, C * scale --> N, H / scale, W, C * scale
x = x.permute(0, 2, 1, 3).contiguous()
# N, H / scale, W, C * scale --> N, H / scale, W / scale, C * (scale ** 2)
x = x.view(n, int(h / scale_factor), int(w / scale_factor),
int(c * (scale_factor * scale_factor)))
x = x.permute(0, 2, 1, 3).contiguous()
return x
def forward(self, x):
"""
Args:
x (torch.Tensor): image features
shape (F, v, D)
Returns:
shape (F, n, D) where n is token_num that has been reduced
"""
x = x[:, 1:, :] # remove cls_token
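        # From here: fold the remaining tokens into an (h x w) grid, pixel-shuffle spatial
        # blocks into channels, then project. Example trace for a 24 x 24 grid, scale 2:
        #   (F, 576, dim_in) -> (F, 24, 24, dim_in) -> (F, 12, 12, 4 * dim_in) -> mlp -> (F, 144, dim_out)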
h = w = int(x.shape[1] ** 0.5)
x = x.view(x.shape[0], h, w, -1)
x = self.pixel_shuffle(x, self.scale_factor)
x = self.mlp(x)
        x = x.view(x.shape[0], -1, x.shape[-1])
return x
class OvisConvAdapter(nn.Module):
def __init__(self, dim_in, dim_out, vocab_size, tokenize_function="softmax"):
super().__init__()
self.mm_projector_type = 'ovis_conv_adapter'
self.conv = nn.Conv2d(dim_in, dim_in, kernel_size=(3, 3), stride=(2, 2), padding=1)
self.mlp = torch.nn.Sequential(
torch.nn.Linear(dim_in, vocab_size, bias=False),
torch.nn.LayerNorm(vocab_size)
)
self.embedding = torch.nn.Embedding(vocab_size, dim_out)
self.tokenize_function = tokenize_function
def tokenize(self, logits):
def st_argmax(y_soft, dim): # straight-through softmax
index = y_soft.max(dim, keepdim=True)[1]
y_hard = torch.zeros_like(y_soft, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0)
ret = y_hard - y_soft.detach() + y_soft
return ret
if self.tokenize_function == 'softmax':
tokens = torch.nn.functional.softmax(logits, dim=-1)
elif self.tokenize_function == 'gumbel_argmax':
tokens = torch.nn.functional.gumbel_softmax(logits, tau=self.config.tau, hard=True)
elif self.tokenize_function == 'st_argmax':
tokens = st_argmax(logits, dim=-1)
else:
raise ValueError(
                'Invalid `tokenize_function`, expected softmax, gumbel_argmax or st_argmax,'
                f' but got {self.tokenize_function}'
)
return tokens
def forward(self, x):
"""
Args:
x (torch.Tensor): image features
shape (F, v, D)
Returns:
shape (F, n, D) where n is token_num that has been reduced
"""
# conv
f, v, d = x.shape
s = int(math.sqrt(v - 1))
x = x[:, 1:, :] # remove cls_token
x = x.reshape(f, s, s, d).permute([0, 3, 1, 2])
x = self.conv(x)
x = x.permute([0, 2, 3, 1]).reshape(f, -1, d)
# tokenize
logits = self.mlp(x)
visual_tokens = self.tokenize(logits)
# get embeddings
out = torch.matmul(visual_tokens, self.embedding.weight)
return out
def build_vision_projector(config, delay_load=False, **kwargs):
projector_type = getattr(config, 'mm_projector_type', 'linear')
if projector_type == 'linear':
return LinearAdapter(config.mm_hidden_size, config.hidden_size)
elif projector_type == 'pool_adapter':
return PoolAdapter(config.mm_hidden_size, config.hidden_size, config.pool_out_size)
elif projector_type == 'adapt_pooler':
return AdaptPooler(config.mm_hidden_size, config.hidden_size, config.pool_out_size)
elif projector_type == 'adapt_pyrapooler':
return AdaptPyraPooler(config.mm_hidden_size, config.hidden_size)
elif projector_type == 'adapt_pooler_w_cls':
return AdaptPoolerCLS(config.mm_hidden_size, config.hidden_size, config.pool_out_size)
elif projector_type == 'pool_adapter_w_cls':
return PoolAdapterCLS(config.mm_hidden_size, config.hidden_size, config.pool_out_size)
elif projector_type == 'conv_adapter':
return ConvAdapter(config.mm_hidden_size, config.hidden_size, getattr(config, "mlp_hidden_dim", None))
elif projector_type == 'mlp_pixel_shuffle':
return MlpPixelShuffle(config.mm_hidden_size, config.hidden_size,
config.pixelshuffle_downsample_ratio, getattr(config, "mlp_hidden_dim", None))
elif projector_type == 'ovis_conv_adapter':
return OvisConvAdapter(config.mm_hidden_size, config.hidden_size, getattr(config, "mlp_hidden_dim", 32000),
getattr(config, "tokenize_function", "softmax"))
mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
if mlp_gelu_match:
mlp_depth = int(mlp_gelu_match.group(1))
modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
for _ in range(1, mlp_depth):
modules.append(nn.GELU())
modules.append(nn.Linear(config.hidden_size, config.hidden_size))
mm_projector = nn.Sequential(*modules)
# this line is for fixing bug in valley/model/valley_arch.py line 72.
# If the projector is 2 layer mlp, projector has no attr named mm_projector_type.
mm_projector.mm_projector_type = projector_type
return mm_projector
if projector_type == 'identity':
return IdentityMap()
if projector_type == 'identity_patch':
return IdentityPatchMap()
raise ValueError(f'Unknown projector type: {projector_type}')
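# Minimal selection sketch (the config values below are assumptions for illustration):
#   config.mm_projector_type = 'mlp2x_gelu'     # matched by the mlpNx_gelu regex branch above
#   config.mm_hidden_size, config.hidden_size = 1152, 3584
#   projector = build_vision_projector(config)  # -> nn.Sequential(Linear, GELU, Linear)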
from torch import nn
class AvgPoolTokenCompressor(nn.Module):
"""
A PyTorch module for compressing tokens using average pooling.
This module performs average pooling on the input tensor to reduce its spatial dimensions
by a specified scale factor.
Attributes:
scale (int): The scale factor for downsampling.
Example:
>>> compressor = AvgPoolTokenCompressor(scale=2)
>>> input_tensor = torch.randn(1, 256, 4096) # Shape: [B, N, dim]
>>> output_tensor = compressor(input_tensor)
>>> print(output_tensor.shape) # Expected shape: [1, 64, 4096]
"""
def __init__(self, scale) -> None:
super(AvgPoolTokenCompressor, self).__init__()
self.scale = scale
def _inner_forward(self, x):
scale = self.scale
B, N, dim = x.shape
H = W = int(N ** 0.5)
x = x.view(B, H, W, dim)
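        # Group the H x W token grid into non-overlapping scale x scale patches,
        # average each patch, and flatten back to a (B, N / scale**2, dim) sequence.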
return x.view(B, H // scale, scale, W // scale, scale, dim) \
.permute(0, 1, 3, 5, 2, 4) \
.reshape(B, H // scale, W // scale, dim, scale * scale) \
.mean(dim=-1) \
.squeeze(dim=-1) \
.reshape(B, -1, dim)
def forward(self, x):
if type(x) is list:
x = [self._inner_forward(item.unsqueeze(0)).squeeze(0) for item in x]
else:
x = self._inner_forward(x)
return x