import sys
import torch
from PIL import Image
from .base import BaseModel
from ..smp import *
from ..dataset import DATASET_TYPE


class mPLUG_Owl2(BaseModel):

    INSTALL_REQ = True
    INTERLEAVE = False

    def __init__(self, model_path='MAGAer13/mplug-owl2-llama2-7b', **kwargs):
        try:
            from mplug_owl2.model.builder import load_pretrained_model
            from mplug_owl2.mm_utils import get_model_name_from_path
        except Exception as e:
            logging.critical('Please install mPLUG_Owl2 before using mPLUG_Owl2. ')
            raise e

        model_name = get_model_name_from_path(model_path)
        # Load on CPU first, then move the model to GPU.
        tokenizer, model, image_processor, context_len = load_pretrained_model(
            model_path, None, model_name, load_8bit=False, load_4bit=False, device='cpu')

        self.model = model.cuda()
        self.device = self.model.device
        self.image_processor = image_processor
        tokenizer.padding_side = 'left'
        tokenizer.pad_token_id = tokenizer.eos_token_id
        self.tokenizer = tokenizer
        self.context_len = context_len

        # Default generation config; user-supplied kwargs override these values.
        kwargs_default = dict(
            max_new_tokens=512, do_sample=False, num_beams=1,
            min_new_tokens=1, length_penalty=1, num_return_sequences=1)
        kwargs_default.update(kwargs)
        self.kwargs = kwargs_default
        warnings.warn(f'Following kwargs received: {self.kwargs}, will use as generation config. ')

    def use_custom_prompt(self, dataset):
        assert dataset is not None
        if listinstr(['MMMU'], dataset):
            return False
        if DATASET_TYPE(dataset) == 'MCQ' or dataset == 'MMVet':
            return True
        return False

    def build_prompt(self, line, dataset=None):
        assert dataset is None or isinstance(dataset, str)
        assert self.use_custom_prompt(dataset)
        tgt_path = self.dump_image(line, dataset)

        question = line['question']
        if dataset == 'MMVet':
            prompt = question + '\nAnswer the question directly. '
        elif DATASET_TYPE(dataset) == 'MCQ':
            # Collect the available options (columns A, B, C, ...) from the record.
            options = {
                cand: line[cand]
                for cand in string.ascii_uppercase
                if cand in line and not pd.isna(line[cand])
            }
            options_prompt = ''
            for key, item in options.items():
                options_prompt += f'{key}. {item}\n'

            hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None
            prompt = f'Hint: {hint}\n' if hint is not None else ''
            prompt += f'{question}\n'
            prompt += (
                f"{options_prompt}\nAnswer with the option's letter from the given choices directly. "
                if len(options) else 'Answer the question directly. '
            )
        else:
            raise NotImplementedError

        message = [dict(type='text', value=prompt)]
        message.extend([dict(type='image', value=s) for s in tgt_path])
        return message

    def generate_inner(self, message, dataset=None):
        from mplug_owl2.constants import IMAGE_TOKEN_INDEX
        from mplug_owl2.mm_utils import process_images, tokenizer_image_token
        kwargs = cp.deepcopy(self.kwargs)
        # Dataset-specific tweaks to the generation config.
        if dataset in ['MMVet', 'LLaVABench']:
            kwargs['length_penalty'] = 0
        elif dataset is not None and DATASET_TYPE(dataset) == 'VQA':
            kwargs['length_penalty'] = 0
        elif dataset is not None and DATASET_TYPE(dataset) == 'MCQ':
            kwargs['max_new_tokens'] = 10

        num_images = len([x for x in message if x['type'] == 'image'])
        assert num_images >= 0
        prompt_full = 'USER: '
        images = []
        if num_images == 1:
            prompt, image = self.message_to_promptimg(message, dataset=dataset)
            prompt_full += f'<|image|>{prompt} \nASSISTANT: '
            images.append(image)
        else:
            # Interleave <|image|> placeholders with text segments in message order.
            for msg in message:
                if msg['type'] == 'image':
                    images.append(msg['value'])
                    prompt_full += '<|image|>'
                elif msg['type'] == 'text':
                    prompt_full += msg['value']
            prompt_full += '\nASSISTANT: '

        def preproc_image(fname):
            # Stretch each image to a square of side max(width, height).
            image = Image.open(fname).convert('RGB')
            max_edge = max(image.size)
            image = image.resize((max_edge, max_edge))
            return image

        images = [preproc_image(fname) for fname in images]
        image_tensor = process_images(images, self.image_processor)
        image_tensor = image_tensor.to(self.device, dtype=torch.float16)

        input_ids = tokenizer_image_token(
            prompt_full, self.tokenizer, IMAGE_TOKEN_INDEX,
            return_tensors='pt').unsqueeze(0).to(self.device)

        with torch.inference_mode():
            output_ids = self.model.generate(
                input_ids=input_ids, images=image_tensor, output_hidden_states=True,
                use_cache=True, **kwargs)

        answer = self.tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
        # Truncate at the EOS token before returning.
        return answer.split('</s>')[0]
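

# A minimal usage sketch (not part of the upstream module). It assumes mPLUG-Owl2 and
# its dependencies are installed, a CUDA device is available, and that 'example.jpg'
# exists on disk; the image path and the question are purely illustrative.
if __name__ == '__main__':
    model = mPLUG_Owl2(max_new_tokens=128)
    message = [
        dict(type='image', value='example.jpg'),   # hypothetical local image
        dict(type='text', value='Describe the image in one sentence.'),
    ]
    # With a single image, generate_inner routes through message_to_promptimg.
    print(model.generate_inner(message))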