import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer
from PIL import Image
import warnings
import re

from .base import BaseModel
from ..smp import *
from ..dataset import DATASET_TYPE


class BunnyLLama3(BaseModel):
    INSTALL_REQ = False
    INTERLEAVE = False

    def __init__(self, model_path='BAAI/Bunny-v1_1-Llama-3-8B-V', **kwargs):
        assert model_path is not None
        transformers.logging.set_verbosity_error()
        transformers.logging.disable_progress_bar()
        warnings.filterwarnings('ignore')
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path, device_map='auto', trust_remote_code=True)
        self.kwargs = kwargs

    def use_custom_prompt(self, dataset):
        # Custom prompts are used for multiple-choice and yes/no datasets, plus MathVista.
        return listinstr(['MCQ', 'Y/N'], DATASET_TYPE(dataset)) or listinstr(['mathvista'], dataset.lower())

    def build_prompt(self, line, dataset):
        if dataset is None:
            dataset = self.dataset
        if isinstance(line, int):
            line = self.data.iloc[line]

        tgt_path = self.dump_image(line, dataset)
        prompt = line['question']

        if DATASET_TYPE(dataset) == 'MCQ':
            if listinstr(['mmmu'], dataset.lower()):
                hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None
                # MMMU hints are expected to be merged into the question already.
                assert hint is None

                question = line['question']
                # Strip the angle brackets from '<image n>' placeholders, e.g. '<image 1>' -> 'image 1'.
                question = re.sub(r'<image \d+>', lambda x: x.group(0)[1:-1], question)

                options = {
                    cand: line[cand]
                    for cand in string.ascii_uppercase
                    if cand in line and not pd.isna(line[cand])
                }
                options_prompt = '\n'
                for key, item in options.items():
                    options_prompt += f'({key}) {item}\n'

                prompt = question
                if len(options):
                    prompt += options_prompt
                    prompt += "\nAnswer with the option's letter from the given choices directly."
                else:
                    prompt += '\nAnswer the question using a single word or phrase.'
            else:
                hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None
                prompt = ''
                if hint is not None:
                    prompt += f'{hint}\n'

                question = line['question']
                options = {
                    cand: line[cand]
                    for cand in string.ascii_uppercase
                    if cand in line and not pd.isna(line[cand])
                }
                options_prompt = '\n'
                for key, item in options.items():
                    options_prompt += f'{key}. {item}\n'

                prompt += question + options_prompt
                if listinstr(['cn', 'ccbench'], dataset.lower()):
                    prompt += '请直接回答选项字母。'  # "Answer with the option's letter directly."
                else:
                    prompt += "Answer with the option's letter from the given choices directly."
        elif DATASET_TYPE(dataset) == 'Y/N':
            if listinstr(['mme'], dataset.lower()):
                if not listinstr(
                        ['code_reasoning', 'commonsense_reasoning', 'numerical_calculation', 'text_translation'],
                        line['category']):
                    prompt = prompt.replace(' Please answer yes or no.',
                                            '\nAnswer the question using a single word or phrase.')
            elif listinstr(['pope'], dataset.lower()):
                prompt = prompt.replace(' Please answer yes or no.',
                                        '\nAnswer the question using a single word or phrase.')
        elif listinstr(['mathvista'], dataset.lower()):
            # Reorder MathVista prompts from 'Hint / Question / Choices' to 'Question / Choices / Hint'.
            match = re.search(r'Hint: (.*?)\nQuestion: (.*?)\n(Choices:\n(.*))?', prompt + '\n', re.DOTALL)
            prompt = match.group(2)
            if match.group(4) is not None:
                prompt += '\n' + match.group(4).rstrip('\n')
            prompt += '\n' + match.group(1)
        else:
            raise ValueError(
                f"Bunny doesn't implement a custom prompt for {dataset}. "
                "It should use the default prompt, but didn't.")

        msgs = []
        if isinstance(tgt_path, list):
            msgs.extend([dict(type='image', value=p) for p in tgt_path])
        else:
            msgs = [dict(type='image', value=tgt_path)]
        msgs.append(dict(type='text', value=prompt))
        return msgs

    def generate_inner(self, message, dataset=None):
        prompt, image_path = self.message_to_promptimg(message, dataset=dataset)

        text = (f'A chat between a curious user and an artificial intelligence assistant. '
                f"The assistant gives helpful, detailed, and polite answers to the user's questions. "
                f'USER: <image>\n{prompt} ASSISTANT:')
        # Tokenize the text on either side of the '<image>' placeholder and splice in the
        # image token index (-200) where the image features will be inserted.
        text_chunks = [self.tokenizer(chunk).input_ids for chunk in text.split('<image>')]
        input_ids = torch.tensor(
            text_chunks[0] + [-200] + text_chunks[1][1:], dtype=torch.long
        ).unsqueeze(0).to(self.model.device)
        image = Image.open(image_path).convert('RGB')
        image_tensor = self.model.process_images(
            [image], self.model.config).to(dtype=self.model.dtype, device=self.model.device)

        output_ids = self.model.generate(input_ids, images=image_tensor, max_new_tokens=128, use_cache=True)[0]
        response = self.tokenizer.decode(output_ids[input_ids.shape[1]:], skip_special_tokens=True)
        return response
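

# Minimal usage sketch (an assumption, not part of this module): within VLMEvalKit the
# class is normally constructed via the framework's model registry, but it can also be
# used directly once the package is importable. The import path and 'demo.jpg' below
# are hypothetical; `generate` is inherited from BaseModel and dispatches to
# `generate_inner` defined above.
#
#     from vlmeval.vlm import BunnyLLama3
#     model = BunnyLLama3()
#     message = [
#         dict(type='image', value='demo.jpg'),
#         dict(type='text', value='Describe the image in one sentence.'),
#     ]
#     print(model.generate(message))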