import torch
from PIL import Image
from abc import abstractproperty
from .base import BaseModel
from ..smp import *
from ..dataset import DATASET_TYPE
import warnings


class Mantis(BaseModel):
    """
    Mantis Model
    This implementation is adapted from the Llava model in llava.py and the Idefics model in idefics.py.
    """

    INSTALL_REQ = True
    INTERLEAVE = True

    DEFAULT_IMAGE_TOKEN = '<image>'
    IMAGE_TOKEN_INDEX = -200

    def __init__(self, model_path='TIGER-Lab/Mantis-8B-siglip-llama3', **kwargs):
        assert model_path is not None
        try:
            from mantis.models.mllava import LlavaForConditionalGeneration, MLlavaProcessor
            from mantis.models.mfuyu import MFuyuForCausalLM, MFuyuProcessor
            from mantis.models.conversation import conv_mllava_v1 as default_conv, conv_templates
        except Exception as e:
            logging.critical(
                "Mantis is not installed. Please install it with 'pip install "
                "git+https://github.com/TIGER-AI-Lab/Mantis.git' to use this model."
            )
            raise e

        try:
            from transformers import AutoModelForVision2Seq, AutoProcessor
        except Exception as e:
            logging.critical(f'{type(e)}: {e}')
            logging.critical("Upgrade transformers to use Mantis's idefics model.\nError: %s" % e)

        # Attention implementation for inference, one of "sdpa", "eager", "flash_attention_2".
        # FA2 does not seem to help during inference:
        # https://discuss.huggingface.co/t/flash-attention-has-no-effect-on-inference/73453/5
        # if is_flash_attn_2_available:
        #     best_fit_attn_implementation = "flash_attention_2"
        # flash_attn has a bug that says: "ERROR Error query and key must have the same dtype" when generating.
        try:
            import flash_attn
            best_fit_attn_implementation = 'flash_attention_2'
        except ImportError:
            best_fit_attn_implementation = 'eager'

        self.model_path = model_path
        attn_implementation = best_fit_attn_implementation
        self._is_idefics = 'idefics' in model_path.lower()

        # Load the "non-idefics" Mantis model.
        if not self._is_idefics:
            if 'fuyu' in model_path.lower():
                self.processor = MFuyuProcessor.from_pretrained(self.model_path)
                model = MFuyuForCausalLM.from_pretrained(
                    self.model_path,
                    device_map='cuda',
                    attn_implementation=attn_implementation,
                    torch_dtype=torch.float16
                )
            else:
                self.processor = MLlavaProcessor.from_pretrained(self.model_path)
                model = LlavaForConditionalGeneration.from_pretrained(
                    self.model_path,
                    device_map='cuda',
                    attn_implementation=attn_implementation,
                    torch_dtype=torch.float16
                )
        else:
            self.processor = AutoProcessor.from_pretrained(self.model_path)
            model = AutoModelForVision2Seq.from_pretrained(
                self.model_path,
                device_map='cuda',
                torch_dtype=torch.float16
            )

        model = model.eval()
        self.model = model.cuda()
        kwargs_default = dict(do_sample=False, temperature=0, max_new_tokens=1024, top_p=None, num_beams=1)
        kwargs_default.update(kwargs)
        self.kwargs = kwargs_default
        warnings.warn(f'Following kwargs received: {self.kwargs}, will use as generation config.')
        self.tokenizer = self.processor.tokenizer
        self.default_conv = default_conv
        self.conv_templates = conv_templates
    def use_custom_prompt(self, dataset):
        assert dataset is not None
        if DATASET_TYPE(dataset) == 'MCQ':
            return True
        return False

    def build_prompt(self, line, dataset=None):
        assert self.use_custom_prompt(dataset)
        assert dataset is None or isinstance(dataset, str)
        tgt_path = self.dump_image(line, dataset)

        question = line['question']
        hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None
        if hint is not None:
            question = hint + '\n' + question

        options = {
            cand: line[cand]
            for cand in string.ascii_uppercase
            if cand in line and not pd.isna(line[cand])
        }
        for key, item in options.items():
            question += f'\n{key}. {item}'
        prompt = question

        if len(options):
            prompt += (
                '\n请直接回答选项字母。' if cn_string(prompt) else
                "\nAnswer with the option's letter from the given choices directly."
            )
        else:
            prompt += '\n请直接回答问题。' if cn_string(prompt) else '\nAnswer the question directly.'

        message = [dict(type='image', value=s) for s in tgt_path]
        message.append(dict(type='text', value=prompt))
        return message
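    # Illustrative sketch of what build_prompt above returns, using hypothetical values: for an MCQ
    # line with question 'Which shape is red?' and options A='circle', B='square', the message is
    # roughly
    #   [{'type': 'image', 'value': '<dumped image path>'},
    #    {'type': 'text', 'value': "Which shape is red?\nA. circle\nB. square\n"
    #                              "Answer with the option's letter from the given choices directly."}]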
    def output_process(self, answer):
        if '<s>' in answer:
            answer = answer.replace('<s>', '').strip()
        if '[/INST]' in answer:
            answer = answer.split('[/INST]')[1].strip()
        elif 'ASSISTANT:' in answer:
            answer = answer.split('ASSISTANT:')[1].strip()
        elif 'assistant\n' in answer:
            answer = answer.split('assistant\n')[1].strip()
        elif '<|end_header_id|>\n\n' in answer:
            answer = answer.split('<|end_header_id|>\n\n')[2].strip()

        if '</s>' in answer:
            answer = answer.split('</s>')[0].strip()
        elif '<|im_end|>' in answer:
            answer = answer.split('<|im_end|>')[0].strip()
        elif '<|eot_id|>' in answer:
            answer = answer.split('<|eot_id|>')[0].strip()
        elif '<end_of_utterance>' in answer:
            answer = answer.split('<end_of_utterance>')[0].strip()
        elif '|ENDOFTEXT|' in answer:
            answer = answer.split('|ENDOFTEXT|')[0].strip()
        return answer

    def generate_inner(self, message, dataset=None):
        content, images = '', []
        ide_content, question = [], ''
        for msg in message:
            if msg['type'] == 'text':
                content += msg['value']
                question += msg['value']
            else:
                images.append(Image.open(msg['value']).convert('RGB'))
                content += (self.DEFAULT_IMAGE_TOKEN + '\n')
                ide_content.append({'type': 'image'})

        if self._is_idefics:
            # Follow the idefics implementation:
            ide_content.append({'type': 'text', 'text': question})
            prompt = [{'role': 'user', 'content': ide_content}]
            prompt = self.processor.apply_chat_template(prompt, add_generation_prompt=True)
        else:
            # Follow the Mantis code base to stay consistent with it:
            # https://github.com/TIGER-AI-Lab/Mantis/blob/main/mantis/models/mllava/utils.py#L33
            # Users don't need to define a chat template, as it is done here.
            if 'llama-3' in self.model.language_model.name_or_path.lower():
                conv = self.conv_templates['llama_3']
                terminators = [
                    self.processor.tokenizer.eos_token_id,
                    self.processor.tokenizer.convert_tokens_to_ids('<|eot_id|>')
                ]
            else:
                conv = self.default_conv
                terminators = [self.processor.tokenizer.eos_token_id]
            # Using EOT because end of *text* is more accurate for what we're doing than end of *sentence*.
            if 'eos_token_id' not in self.kwargs:
                self.kwargs['eos_token_id'] = terminators

            conv = conv.copy()
            conv.append_message(conv.roles[0], content)
            conv.append_message(conv.roles[1], '')
            assert conv.messages[-1][0] == conv.roles[1] and conv.messages[-1][1] == '', 'Format check'
            prompt = conv.get_prompt()

        inputs = self.processor(prompt, images, return_tensors='pt', truncation=True)
        # FIXME: the Fuyu processor returns a list for `image_patches` instead of a torch tensor.
        # This weird behavior needs fixing.
        if 'image_patches' in inputs.keys():
            inputs['image_patches'] = inputs['image_patches'][0]
        inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
        output = self.model.generate(**inputs, **self.kwargs)
        output = output[0]
        generated_ids = output[inputs['input_ids'].shape[-1]:]
        answer = self.processor.decode(generated_ids, skip_special_tokens=True)
        answer = self.output_process(answer)
        return answer
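
# Minimal usage sketch, not part of the evaluation pipeline. It assumes a CUDA GPU, an installed
# Mantis package, and that this module is executed inside its package (it uses relative imports),
# e.g. via `python -m`. The image paths and the question below are hypothetical placeholders.
if __name__ == '__main__':
    model = Mantis(model_path='TIGER-Lab/Mantis-8B-siglip-llama3', max_new_tokens=256)
    message = [
        dict(type='image', value='demo_left.jpg'),   # hypothetical local image path
        dict(type='image', value='demo_right.jpg'),  # hypothetical local image path
        dict(type='text', value='What is different between the two images?'),
    ]
    print(model.generate_inner(message))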