import re

import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image, ImageDraw, ImageFont
from transformers import AutoModel, AutoTokenizer

from ...dataset import DATASET_TYPE
from ...smp import *
from ..base import BaseModel

pattern = re.compile(r'[A-Z]')
conv_pattern = r'\[UNUSED_TOKEN_146\]user\n|\[UNUSED_TOKEN_146\]assistant\n|\[UNUSED_TOKEN_145\]'


def get_font():
    try:
        truetype_url = 'http://opencompass.openxlab.space/utils/Fonts/SimHei.ttf'
        ff = urlopen(truetype_url)
        font = ImageFont.truetype(ff, size=40)
    except Exception as e:
        logging.warning(f'{type(e)}: {e}')
        logging.warning('Failed to download the font; using the default one.')
        font = ImageFont.load_default(size=40)
    return font


def padding_560(b):
    # Pad the image vertically (with white) so its height becomes a multiple of 560.
    width, height = b.size
    tar = int(np.ceil(height / 560) * 560)
    top_padding = int((tar - height) / 2)
    bottom_padding = tar - height - top_padding
    left_padding = 0
    right_padding = 0
    b = transforms.functional.pad(
        b, [left_padding, top_padding, right_padding, bottom_padding],
        fill=[255, 255, 255])
    return b


def Identity_transform(img, hd_num=25):
    # Resize so the short side is 560 while keeping the aspect ratio,
    # operating on the transposed image when it is taller than wide.
    width, height = img.size
    trans = False
    if width < height:
        img = img.transpose(Image.TRANSPOSE)
        trans = True
        width, height = img.size
    ratio = width / height
    scale = 1
    new_h = int(scale * 560)
    new_w = int(new_h * ratio)
    img = transforms.functional.resize(img, [new_h, new_w])
    img = img.transpose(Image.TRANSPOSE)
    img = padding_560(img)
    width, height = img.size
    if not trans:
        img = img.transpose(Image.TRANSPOSE)
    return img


def HD_transform(img, im_num=36, id_scale=1.5):
    # Scale the image to the largest 560-pixel grid that fits the im_num
    # tile budget, then pad the height to a multiple of 560.
    width, height = img.size
    trans = False
    if width < height:
        img = img.transpose(Image.TRANSPOSE)
        trans = True
        width, height = img.size
    ratio = width / height
    scale = 1
    while scale * np.ceil(scale / ratio) <= im_num:
        scale += 1
    scale -= 1
    scale = min(np.ceil(width * id_scale / 560), scale)
    new_w = int(scale * 560)
    new_h = int(new_w / ratio)
    img = transforms.functional.resize(img, [new_h, new_w])
    img = padding_560(img)
    width, height = img.size
    assert width * height <= im_num * 560 * 560
    if trans:
        img = img.transpose(Image.TRANSPOSE)
    return img


def img_process(imgs):
    # Resize every frame so its long side is 1120, then concatenate the frames
    # (stacked vertically for landscape frames, side by side for portrait ones)
    # into one image, separated by divider lines and per-frame index labels.
    new_imgs = []
    for img in imgs:
        w, h = img.size
        scale = w / h
        if w > h:
            new_w = 560 * 2
            new_h = int(560 * 2 / scale)
        else:
            new_w = int(560 * 2 * scale)
            new_h = 560 * 2
        img = transforms.functional.resize(img, [new_h, new_w])
        new_imgs.append(img)
    imgs = new_imgs
    new_w = 0
    new_h = 0
    pad = 40
    if w > h:
        for im in imgs:
            w, h = im.size
            new_w = max(new_w, w)
            new_h += h + 10 + pad
        font = get_font()
        new_img = Image.new('RGB', (new_w, new_h), 'white')
        draw = ImageDraw.Draw(new_img)
        curr_h = 0
        for idx, im in enumerate(imgs):
            w, h = im.size
            new_img.paste(im, (0, pad + curr_h))
            draw.text((0, curr_h), f'<IMAGE {idx}>', font=font, fill='black')
            if idx + 1 < len(imgs):
                draw.line([(0, pad + curr_h + h + 5), (new_w, pad + curr_h + h + 5)],
                          fill='black', width=2)
            curr_h += h + 10 + pad
    else:
        for im in imgs:
            w, h = im.size
            new_w += w + 10
            new_h = max(new_h, h)
        new_h += pad
        font = get_font()
        new_img = Image.new('RGB', (new_w, new_h), 'white')
        draw = ImageDraw.Draw(new_img)
        curr_w = 0
        for idx, im in enumerate(imgs):
            w, h = im.size
            new_img.paste(im, (curr_w, pad))
            draw.text((curr_w, 0), f'<IMAGE {idx}>', font=font, fill='black')
            if idx + 1 < len(imgs):
                draw.line([(curr_w + w + 5, 0), (curr_w + w + 5, new_h)],
                          fill='black', width=2)
            curr_w += w + 10
    return new_img
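# A worked example of the HD_transform tile search above, assuming a
# hypothetical 1680x840 landscape input (not taken from the original file):
# ratio = 2.0, so the loop grows `scale` while scale * ceil(scale / 2) <= 36,
# stopping at scale = 8 (8 * 4 = 32 <= 36, whereas 9 * 5 = 45 > 36). The width
# cap min(ceil(1680 * 1.5 / 560), 8) = min(5, 8) = 5 then wins, so the image
# is resized to 2800x1400 and padded to 2800x1680, i.e. 5 x 3 = 15 tiles <= 36.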
meta_instruction = """You are an AI assistant whose name is InternLM (书生·浦语).
- InternLM (书生·浦语) is a conversational language model that is developed by Shanghai AI Laboratory (上海人工智能实验室). It is designed to be helpful, honest, and harmless.
- InternLM (书生·浦语) can understand and communicate fluently in the language chosen by the user such as English and 中文."""


def model_gen(model, text, images, need_bos=True, padding=False,
              beams=3, max_token=500, video_input=False):
    # Interleave text and image embeddings: the prompt is split on the
    # '<ImageHere>' placeholder and each gap is filled with the embedding of
    # the corresponding image; im_mask marks which positions are image tokens.
    # `padding` is accepted for call-site compatibility but is unused here.
    embeds = []
    im_mask = []
    im_idx = 0
    sub_q = text.split('<ImageHere>')
    add_im = len(sub_q) - 1
    for subtext in sub_q:
        if need_bos or len(subtext) > 0:
            text_embeds = model.encode_text(subtext, add_special_tokens=need_bos)
            embeds.append(text_embeds)
            im_mask.append(torch.zeros(text_embeds.shape[:2]).to(model.device))
            need_bos = False
        if im_idx < len(images) and add_im:
            image = images[im_idx]
            if video_input:
                image = Identity_transform(image)
            elif len(images) > 1:
                # Split the tile budget evenly across multiple images.
                image = HD_transform(image, im_num=model.hd_num // len(images),
                                     id_scale=model.id_scale)
            else:
                image = HD_transform(image, im_num=model.hd_num, id_scale=model.id_scale)
            image = model.vis_processor(image).unsqueeze(0).to(model.device)
            image_embeds = model.encode_img(image)
            im_idx += 1
            add_im -= 1
            embeds.append(image_embeds)
            im_mask.append(torch.ones(image_embeds.shape[:2], dtype=torch.long).to(model.device))
    embeds = torch.cat(embeds, dim=1)
    im_mask = torch.cat(im_mask, dim=1)
    im_mask = im_mask.bool()
    outputs = model.generate(inputs_embeds=embeds, im_mask=im_mask,
                             temperature=1.0, max_new_tokens=max_token,
                             num_beams=beams, do_sample=False,
                             repetition_penalty=1.0)
    output_token = outputs[0]
    # Drop a leading pad/bos token if present.
    if output_token[0] == 0 or output_token[0] == 1:
        output_token = output_token[1:]
    output_text = model.tokenizer.decode(output_token, add_special_tokens=False)
    output_text = output_text.split('[UNUSED_TOKEN_145]')[0].strip().split('<|im_end|>')[0].strip().split('The answer is')[-1].strip()  # noqa
    return output_text


class XComposer2d5(BaseModel):

    INSTALL_REQ = False
    INTERLEAVE = True

    def __init__(self, model_path='internlm/internlm-xcomposer2d5-7b',
                 id_scale=1.5, beam=3, **kwargs):
        assert model_path is not None
        self.model_path = model_path
        self.id_scale = id_scale
        self.beam = beam
        model = AutoModel.from_pretrained(
            self.model_path, device_map='cpu', trust_remote_code=True,
            local_files_only=True).cuda().eval()
        model.half()
        tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True)
        model.tokenizer = tokenizer
        self.model = model
        self.device = self.model.model.tok_embeddings.weight.device
        self.model.hd_num = 36
        self.model.id_scale = self.id_scale

    def message_to_promptimg(self, message, dataset=None, video_input=False):
        num_images = len([x for x in message if x['type'] == 'image'])
        if num_images == 0:
            prompt = '\n'.join([x['value'] for x in message if x['type'] == 'text'])
            image = None
        else:
            image = [Image.open(x['value']).convert('RGB') for x in message if x['type'] == 'image']
            if video_input:
                im_prompt = 'Here are some frames of a video. <ImageHere>'
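                # Long videos are subsampled below to at most 64 evenly spaced
                # frames before img_process stitches them into one labeled grid.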
                if len(image) > 64:
                    step = len(image) / 64
                    image = [image[int(i * step)] for i in range(64)]
                image = [img_process(image)]
            else:
                if len(image) > 1:
                    im_prompt = ' '.join([f'Image{im_idx + 1}: <ImageHere>;'
                                          for im_idx in range(len(image))])
                else:
                    im_prompt = '<ImageHere>'
            prompt = ''
            for x in message:
                if x['type'] == 'text' and x.get('role', '') != 'system':
                    prompt += x['value']
            # Strip any pre-existing conversation markup and keep only the user turn.
            sp = [i for i in re.split(conv_pattern, prompt) if i != '' and i != '\n']
            assert len(sp) <= 2
            q = sp[0]
            prompt = f'[UNUSED_TOKEN_146]user\n{im_prompt}{q}[UNUSED_TOKEN_145]\n[UNUSED_TOKEN_146]assistant\n'
        # Normalize option markers: '(A)' -> 'A.' and so on for A-J.
        for idx in range(10):
            idx = chr(65 + idx)
            prompt = prompt.replace(f'({idx})', f'{idx}.')
        return prompt, image

    def generate_mme(self, image_path, text):
        text = text.split('Please answer')[0].strip()
        text = f'{text} Answer this question briefly'
        text = f'[UNUSED_TOKEN_146]user\n{text}[UNUSED_TOKEN_145]\n[UNUSED_TOKEN_146]assistant\n'
        return model_gen(self.model, text, image_path,
                         need_bos=True, padding=True, beams=self.beam)

    def generate_multichoice(self, image_path, text, dataset):
        out = model_gen(self.model, text, image_path,
                        need_bos=True, padding=False, beams=self.beam, max_token=5)
        if 'mmmu' in dataset.lower():
            return out
        res = pattern.findall(out)
        if len(res) == 0:
            print('Error:', out)
            res = 'Z'
        return res[0]

    def generate_vqa(self, image_path, text):
        out = model_gen(self.model, text, image_path,
                        beams=self.beam, need_bos=True, max_token=100)
        return out

    def generate_vanilla(self, image_path, text):
        out = model_gen(self.model, text, image_path,
                        beams=self.beam, need_bos=True, max_token=500)
        return out

    def generate_brief(self, image_path, text):
        text = ('[UNUSED_TOKEN_146]user\nAnswer the question using a single word or phrase.'
                '{}[UNUSED_TOKEN_145]\n[UNUSED_TOKEN_146]assistant\n').format(text)
        out = model_gen(self.model, text, image_path,
                        beams=self.beam, need_bos=True, max_token=10)
        return out

    def generate_video(self, image_path, text):
        out = model_gen(
            self.model, text, image_path,
            beams=1,  # video inputs use single-beam decoding instead of self.beam
            need_bos=True, max_token=100, video_input=True)
        return out

    def set_max_num(self, dataset):
        # MME-RealWorld images are large; use a smaller tile budget for them.
        if dataset is not None and listinstr(['MME-RealWorld', 'MME-RealWorld-CN'], dataset):
            self.model.hd_num = 25

    def generate_inner(self, message, dataset=None):
        self.set_max_num(dataset)
        with torch.cuda.amp.autocast():
            if dataset is None:
                prompt, image_path = self.message_to_promptimg(message, dataset=dataset)
                return self.generate_vanilla(image_path, prompt)
            assert isinstance(dataset, str)
            if listinstr(['video', 'mvbench'], dataset.lower()):
                prompt, image_path = self.message_to_promptimg(message, dataset=dataset, video_input=True)
                return self.generate_video(image_path, prompt)
            else:
                prompt, image_path = self.message_to_promptimg(message, dataset=dataset)
                if dataset == 'MME':
                    return self.generate_mme(image_path, prompt)
                elif listinstr(['hallu', 'pope'], dataset.lower()):
                    return self.generate_brief(image_path, prompt)
                elif listinstr(['llava', 'mmvet'], dataset.lower()):
                    return self.generate_vanilla(image_path, prompt)
                elif DATASET_TYPE(dataset) == 'MCQ':
                    return self.generate_multichoice(image_path, prompt, dataset)
                elif listinstr(['MME-RealWorld', 'MME-RealWorld-CN'], dataset):
                    return self.generate_multichoice(image_path, prompt, dataset)
                elif DATASET_TYPE(dataset) == 'VQA':
                    return self.generate_vqa(image_path, prompt)
                else:
                    return self.generate_vanilla(image_path, prompt)
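    # For reference, a prompt built by message_to_promptimg for a hypothetical
    # two-image query (the question text below is illustrative, not from the
    # original file) would look like:
    #
    #   [UNUSED_TOKEN_146]user
    #   Image1: <ImageHere>; Image2: <ImageHere>;Which image is brighter?[UNUSED_TOKEN_145]
    #   [UNUSED_TOKEN_146]assistant
    #
    # model_gen then splits this on '<ImageHere>' to interleave the image
    # embeddings and decodes until the [UNUSED_TOKEN_145] end-of-turn marker.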
    def use_custom_prompt(self, dataset):
        assert dataset is not None
        return DATASET_TYPE(dataset) in ('MCQ', 'VQA')

    def build_mcqa(self, line):
        question = line['question']
        options = {
            cand: line[cand]
            for cand in string.ascii_uppercase
            if cand in line and not pd.isna(line[cand])
        }
        img_prompt = '[UNUSED_TOKEN_146]user\n'
        if len(options):
            options_prompt = ''
            for key, item in options.items():
                options_prompt += f'{key}. {item} '
            options_prompt = options_prompt.strip()
            hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None
            context = 'N/A' if hint is None else hint
            mid_prompt = 'Question: ' + question + '\nContext: ' + context + '\nOptions: ' + options_prompt
            # Priming the assistant turn with 'The answer is' nudges the model
            # to emit a bare option letter.
            ans_prompt = '[UNUSED_TOKEN_145]\n[UNUSED_TOKEN_146]assistant\nThe answer is'
        else:
            mid_prompt = f'Answer the question using a single word or phrase.{question}'
            ans_prompt = '[UNUSED_TOKEN_145]\n[UNUSED_TOKEN_146]assistant\n'
        prompt = img_prompt + mid_prompt + ans_prompt
        return prompt

    def build_prompt(self, line, dataset=None):
        assert dataset is None or isinstance(dataset, str)
        assert self.use_custom_prompt(dataset)
        tgt_path = self.dump_image(line, dataset)
        if DATASET_TYPE(dataset) == 'MCQ':
            prompt = self.build_mcqa(line)
        elif DATASET_TYPE(dataset) == 'VQA':
            q = line['question']
            if 'mathvista' in dataset.lower():
                prompt = f'[UNUSED_TOKEN_146]user\n{q}[UNUSED_TOKEN_145]\n[UNUSED_TOKEN_146]assistant\n'
            elif listinstr(['llava', 'mmvet'], dataset.lower()):
                prompt = ('[UNUSED_TOKEN_146]system\n{}[UNUSED_TOKEN_145]\n[UNUSED_TOKEN_146]user\n{} '
                          'Answer this question in detail.[UNUSED_TOKEN_145]\n'
                          '[UNUSED_TOKEN_146]assistant\n').format(meta_instruction, q)
            elif listinstr(['mmlongbench_doc', 'dude', 'slidevqa'], dataset.lower()):
                prompt = f'[UNUSED_TOKEN_146]user\n{q}[UNUSED_TOKEN_145]\n[UNUSED_TOKEN_146]assistant\n'
            else:
                prefix = 'Answer the question using a single word or phrase.'
                prompt = f'[UNUSED_TOKEN_146]user\n{prefix}{q}[UNUSED_TOKEN_145]\n[UNUSED_TOKEN_146]assistant\n'
        ret = [dict(type='text', value=prompt)]
        ret.extend([dict(type='image', value=s) for s in tgt_path])
        return ret
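# Usage sketch (not part of the original file): the wrapper is normally driven
# by the evaluation harness, but a minimal manual call, assuming a CUDA device
# and a local image at the hypothetical path 'demo.jpg', would look like:
#
#   model = XComposer2d5(model_path='internlm/internlm-xcomposer2d5-7b')
#   message = [
#       dict(type='image', value='demo.jpg'),
#       dict(type='text', value='Describe this image.'),
#   ]
#   print(model.generate_inner(message))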