from PIL import Image
import torch

from .base import BaseModel
from ..smp import *


class XGenMM(BaseModel):
    """Wrapper for Salesforce xGen-MM (BLIP-3) interleaved image-text instruct models."""

    INSTALL_REQ = False
    INTERLEAVE = True

    def __init__(self, model_path='Salesforce/xgen-mm-phi3-mini-instruct-interleave-r-v1.5', **kwargs):
        try:
            from transformers import AutoModelForVision2Seq, AutoTokenizer, AutoImageProcessor
        except Exception as err:
            logging.critical('Please install the latest version of transformers.')
            raise err

        model = AutoModelForVision2Seq.from_pretrained(
            model_path, device_map='cuda', trust_remote_code=True, torch_dtype='auto'
        ).eval()
        tokenizer = AutoTokenizer.from_pretrained(
            model_path, trust_remote_code=True, use_fast=False, legacy=False
        )
        # Register the model's special tokens (e.g. the <image> placeholder) with the tokenizer
        tokenizer = model.update_special_tokens(tokenizer)
        tokenizer.eos_token = '<|end|>'
        tokenizer.padding_side = 'left'
        image_processor = AutoImageProcessor.from_pretrained(model_path, trust_remote_code=True)

        self.model = model
        self.image_processor = image_processor
        self.tokenizer = tokenizer
        self.kwargs = kwargs

    def apply_prompt_template(self, query):
        # Phi-3-style chat template expected by the instruct checkpoints
        s = (
            '<|system|>\nA chat between a curious user and an artificial intelligence assistant. '
            "The assistant gives helpful, detailed, and polite answers to the user's questions.<|end|>\n"
            f'<|user|>\n{query}<|end|>\n<|assistant|>\n'
        )
        return s

    def generate_inner(self, message, dataset=None):
        content, images, image_sizes = '', [], []
        for msg in message:
            if msg['type'] == 'text':
                content += msg['value']
            elif msg['type'] == 'image':
                image = Image.open(msg['value']).convert('RGB')
                images.append(
                    self.image_processor([image], image_aspect_ratio='anyres')['pixel_values'].to('cuda')
                )
                image_sizes.append(image.size)
                # Insert the image placeholder so the image is interleaved at this position in the prompt
                content += '<image>\n'

        inputs = {'pixel_values': [images]}
        prompt = self.apply_prompt_template(content)
        language_inputs = self.tokenizer([prompt], return_tensors='pt').to('cuda')
        inputs.update(language_inputs)

        # Greedy decoding by default; callers may override any of these via **kwargs
        generation_args = {
            'max_new_tokens': 1024,
            'temperature': 0.0,
            'do_sample': False,
            'top_p': None,
            'num_beams': 1,
        }
        generation_args.update(self.kwargs)

        generate_ids = self.model.generate(
            **inputs,
            image_size=[image_sizes],
            pad_token_id=self.tokenizer.pad_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
            **generation_args
        )

        # generate() returns only the newly generated tokens here, so decode
        # directly and truncate the answer at the first <|end|> marker
        response = self.tokenizer.decode(generate_ids[0], skip_special_tokens=True).split('<|end|>')[0]
        return response
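

# Hypothetical usage sketch (assumes VLMEvalKit-style message dicts and a local
# image file 'demo.jpg'; the file name is illustrative only):
#
#     model = XGenMM()
#     msgs = [
#         {'type': 'image', 'value': 'demo.jpg'},
#         {'type': 'text', 'value': 'Describe the image in one sentence.'},
#     ]
#     print(model.generate_inner(msgs))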