import copy as cp
import warnings

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from .base import BaseModel
from ..smp import listinstr
from ..dataset import DATASET_TYPE
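
# Minimal usage sketch (the image path 'demo.jpg' is hypothetical; both model
# names are the defaults used below):
#   model = QwenVL()
#   print(model.generate_inner([
#       {'type': 'image', 'value': 'demo.jpg'},
#       {'type': 'text', 'value': 'Describe this image.'},
#   ]))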


class QwenVL(BaseModel):
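    """Wrapper around the Qwen-VL base model (completion-style, no chat template)."""
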
INSTALL_REQ = False
INTERLEAVE = True

    def __init__(self, model_path='Qwen/Qwen-VL', **kwargs):
assert model_path is not None
self.model_path = model_path
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
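        # Left-pad batched prompts and reuse the end-of-document token as the
        # pad token (the Qwen tokenizer does not define a dedicated one).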
tokenizer.padding_side = 'left'
tokenizer.pad_token_id = tokenizer.eod_id
self.tokenizer = tokenizer
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path, device_map='cuda', trust_remote_code=True).eval()
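        # Greedy decoding by default; user-supplied kwargs override these below.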
default_kwargs = dict(
do_sample=False,
num_beams=1,
max_new_tokens=512,
min_new_tokens=1,
num_return_sequences=1,
use_cache=True,
output_hidden_states=True,
pad_token_id=tokenizer.eod_id,
eos_token_id=tokenizer.eod_id)
default_kwargs.update(kwargs)
self.kwargs = default_kwargs
        warnings.warn(f'Received kwargs: {self.kwargs}; they will be used as the generation config.')
torch.cuda.empty_cache()

    def adjust_kwargs(self, dataset):
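        """Cap `max_new_tokens` with a budget suited to the dataset type."""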
kwargs = cp.deepcopy(self.kwargs)
if DATASET_TYPE(dataset) in ['MCQ', 'Y/N']:
kwargs['max_new_tokens'] = 32
elif DATASET_TYPE(dataset) == 'Caption' and 'COCO' in dataset:
kwargs['max_new_tokens'] = 32
elif DATASET_TYPE(dataset) == 'VQA':
if listinstr(['OCRVQA', 'ChartQA', 'DocVQA'], dataset):
kwargs['max_new_tokens'] = 100
elif listinstr(['TextVQA'], dataset):
kwargs['max_new_tokens'] = 10
return kwargs

    def generate_inner(self, message, dataset=None):
if dataset is not None:
kwargs = self.adjust_kwargs(dataset)
else:
kwargs = self.kwargs
prompt = ''
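        # Qwen-VL expects image paths wrapped in <img>...</img> tags,
        # interleaved with the text.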
for s in message:
if s['type'] == 'image':
                prompt += f'<img>{s["value"]}</img>'
elif s['type'] == 'text':
prompt += s['value']
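        # Append an answer cue so VQA-style datasets elicit a short, direct reply.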
if dataset is not None and DATASET_TYPE(dataset) == 'VQA':
prompt += ' Answer:'
encoded = self.tokenizer([prompt], return_tensors='pt', padding='longest')
input_ids = encoded.input_ids.to('cuda')
attention_mask = encoded.attention_mask.to('cuda')
pred = self.model.generate(
input_ids=input_ids,
attention_mask=attention_mask,
**kwargs)
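        # Decode only the newly generated tokens, skipping the echoed prompt.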
        answer = self.tokenizer.decode(
            pred[0][input_ids.size(1):].cpu(), skip_special_tokens=True).strip()
return answer


class QwenVLChat(BaseModel):
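    """Wrapper around the chat-tuned Qwen-VL-Chat model."""
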
INSTALL_REQ = False
INTERLEAVE = True

    def __init__(self, model_path='Qwen/Qwen-VL-Chat', **kwargs):
assert model_path is not None
self.model_path = model_path
self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path, device_map='cuda', trust_remote_code=True).eval()
torch.cuda.empty_cache()
self.kwargs = kwargs
        warnings.warn(f'Received kwargs: {self.kwargs}; they will be used as the generation config.')

    def build_history(self, message):
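        """Convert alternating user/assistant turns into (query, response) pairs for model.chat."""
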
def concat_tilist(tilist):
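            # Flatten a list of text/image items into one prompt string,
            # numbering images as 'Picture N'.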
image_cnt = 1
prompt = ''
for item in tilist:
if item['type'] == 'text':
prompt += item['value']
elif item['type'] == 'image':
prompt += f"Picture {image_cnt}:
{item['value']}\n"
image_cnt += 1
return prompt

        # The history must consist of complete (user, assistant) turn pairs.
        assert len(message) % 2 == 0
hist = []
for i in range(len(message) // 2):
m1, m2 = message[2 * i], message[2 * i + 1]
assert m1['role'] == 'user' and m2['role'] == 'assistant'
hist.append((concat_tilist(m1['content']), concat_tilist(m2['content'])))
return hist

    def generate_inner(self, message, dataset=None):
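        # from_list_format renders interleaved image/text items into Qwen's
        # tagged prompt format.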
        vl_list = [
            {'image': s['value']} if s['type'] == 'image' else {'text': s['value']}
            for s in message
        ]
query = self.tokenizer.from_list_format(vl_list)
response, _ = self.model.chat(self.tokenizer, query=query, history=None, **self.kwargs)
return response

    def chat_inner(self, message, dataset=None):
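        """Answer the final user message, replaying earlier turns as chat history."""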
assert len(message) % 2 == 1 and message[-1]['role'] == 'user'
history = self.build_history(message[:-1])
vl_list = [
{'image': s['value']} if s['type'] == 'image' else {'text': s['value']}
for s in message[-1]['content']
]
query = self.tokenizer.from_list_format(vl_list)
response, _ = self.model.chat(self.tokenizer, query=query, history=history, **self.kwargs)
return response