import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import warnings
import copy as cp
from .base import BaseModel
from ..smp import isimg, listinstr
from ..dataset import DATASET_TYPE


class QwenVL(BaseModel):

    INSTALL_REQ = False
    INTERLEAVE = True

    def __init__(self, model_path='Qwen/Qwen-VL', **kwargs):
        assert model_path is not None
        self.model_path = model_path
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        tokenizer.padding_side = 'left'
        tokenizer.pad_token_id = tokenizer.eod_id
        self.tokenizer = tokenizer
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path, device_map='cuda', trust_remote_code=True).eval()
        default_kwargs = dict(
            do_sample=False,
            num_beams=1,
            max_new_tokens=512,
            min_new_tokens=1,
            num_return_sequences=1,
            use_cache=True,
            output_hidden_states=True,
            pad_token_id=tokenizer.eod_id,
            eos_token_id=tokenizer.eod_id)
        default_kwargs.update(kwargs)
        self.kwargs = default_kwargs
        warnings.warn(f'Following kwargs received: {self.kwargs}, will use as generation config. ')
        torch.cuda.empty_cache()

    def adjust_kwargs(self, dataset):
        # Cap generation length per dataset type: short answers for MCQ, Y/N,
        # and COCO captioning; longer budgets for OCR-heavy VQA benchmarks.
        kwargs = cp.deepcopy(self.kwargs)
        if DATASET_TYPE(dataset) in ['MCQ', 'Y/N']:
            kwargs['max_new_tokens'] = 32
        elif DATASET_TYPE(dataset) == 'Caption' and 'COCO' in dataset:
            kwargs['max_new_tokens'] = 32
        elif DATASET_TYPE(dataset) == 'VQA':
            if listinstr(['OCRVQA', 'ChartQA', 'DocVQA'], dataset):
                kwargs['max_new_tokens'] = 100
            elif listinstr(['TextVQA'], dataset):
                kwargs['max_new_tokens'] = 10
        return kwargs

    def generate_inner(self, message, dataset=None):
        if dataset is not None:
            kwargs = self.adjust_kwargs(dataset)
        else:
            kwargs = self.kwargs
        # Build an interleaved prompt: image paths are wrapped in the
        # <img>...</img> tags that the Qwen-VL tokenizer expects.
        prompt = ''
        for s in message:
            if s['type'] == 'image':
                prompt += f'<img>{s["value"]}</img>'
            elif s['type'] == 'text':
                prompt += s['value']
        if dataset is not None and DATASET_TYPE(dataset) == 'VQA':
            prompt += ' Answer:'
        encoded = self.tokenizer([prompt], return_tensors='pt', padding='longest')
        input_ids = encoded.input_ids.to('cuda')
        attention_mask = encoded.attention_mask.to('cuda')
        pred = self.model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs)
        # Decode only the newly generated tokens, skipping the prompt prefix.
        answer = self.tokenizer.decode(
            pred[0][input_ids.size(1):].cpu(), skip_special_tokens=True).strip()
        return answer
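
# Usage sketch (illustrative only, not part of the evaluation pipeline): the
# single-turn `message` schema consumed by `QwenVL.generate_inner` is a list
# of {'type', 'value'} dicts, where image values are file paths. This assumes
# a CUDA GPU; the image path below is a hypothetical placeholder.
def _demo_qwen_vl():
    model = QwenVL(model_path='Qwen/Qwen-VL')
    message = [
        {'type': 'image', 'value': '/path/to/example.jpg'},  # hypothetical path
        {'type': 'text', 'value': 'Describe this image in one sentence.'},
    ]
    return model.generate_inner(message)
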
class QwenVLChat(BaseModel):

    INSTALL_REQ = False
    INTERLEAVE = True

    def __init__(self, model_path='Qwen/Qwen-VL-Chat', **kwargs):
        assert model_path is not None
        self.model_path = model_path
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path, device_map='cuda', trust_remote_code=True).eval()
        torch.cuda.empty_cache()
        self.kwargs = kwargs
        warnings.warn(f'Following kwargs received: {self.kwargs}, will use as generation config. ')

    def build_history(self, message):

        def concat_tilist(tilist):
            # Flatten a list of text/image items into Qwen-VL's chat format,
            # numbering each image as 'Picture k: <img>path</img>'.
            image_cnt = 1
            prompt = ''
            for item in tilist:
                if item['type'] == 'text':
                    prompt += item['value']
                elif item['type'] == 'image':
                    prompt += f'Picture {image_cnt}: <img>{item["value"]}</img>\n'
                    image_cnt += 1
            return prompt

        # History must be complete (user, assistant) pairs.
        assert len(message) % 2 == 0
        hist = []
        for i in range(len(message) // 2):
            m1, m2 = message[2 * i], message[2 * i + 1]
            assert m1['role'] == 'user' and m2['role'] == 'assistant'
            hist.append((concat_tilist(m1['content']), concat_tilist(m2['content'])))
        return hist

    def generate_inner(self, message, dataset=None):
        vl_list = [
            {'image': s['value']} if s['type'] == 'image' else {'text': s['value']}
            for s in message
        ]
        query = self.tokenizer.from_list_format(vl_list)
        response, _ = self.model.chat(self.tokenizer, query=query, history=None, **self.kwargs)
        return response

    def chat_inner(self, message, dataset=None):
        # The conversation must end with a user turn; all earlier turns form
        # the (user, assistant) history pairs.
        assert len(message) % 2 == 1 and message[-1]['role'] == 'user'
        history = self.build_history(message[:-1])
        vl_list = [
            {'image': s['value']} if s['type'] == 'image' else {'text': s['value']}
            for s in message[-1]['content']
        ]
        query = self.tokenizer.from_list_format(vl_list)
        response, _ = self.model.chat(self.tokenizer, query=query, history=history, **self.kwargs)
        return response
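
# Usage sketch (illustrative only): a minimal multi-turn conversation for
# `QwenVLChat.chat_inner`. Turns are {'role', 'content'} dicts whose 'content'
# reuses the {'type', 'value'} item schema above, and the conversation must
# end on a user turn. The image path is a hypothetical placeholder and a CUDA
# GPU is assumed.
def _demo_qwen_vl_chat():
    model = QwenVLChat(model_path='Qwen/Qwen-VL-Chat')
    conversation = [
        {'role': 'user', 'content': [
            {'type': 'image', 'value': '/path/to/example.jpg'},  # hypothetical path
            {'type': 'text', 'value': 'What is in this picture?'},
        ]},
        {'role': 'assistant', 'content': [
            {'type': 'text', 'value': 'A dog playing in a park.'},
        ]},
        {'role': 'user', 'content': [
            {'type': 'text', 'value': 'What breed might the dog be?'},
        ]},
    ]
    return model.chat_inner(conversation)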