import os

import torch
from PIL import Image
from abc import abstractproperty

from .base import BaseModel
from ..dataset import DATASET_TYPE
from ..smp import *


class Parrot(BaseModel):
    """Evaluation wrapper for the Parrot multimodal model (AIDC-AI/Parrot-7B)."""

    INSTALL_REQ = False
    INTERLEAVE = False

    def __init__(self, model_path='AIDC-AI/Parrot-7B', **kwargs):
        try:
            from parrot.model.parrot_arch import ParrotMetaForCausalLM
            from parrot.utils.constants import DEFAULT_IMAGE_TOKEN, BEGIN_LINE, END_LINE
            from parrot.model.conversation_formatter import ConversationFormatter
            from parrot.utils.mm_utils import process_images
        except Exception as e:
            logging.critical('Please install Parrot before using this wrapper.')
            logging.critical('Install it from https://github.com/AIDC-AI/Parrot '
                             'with `pip install -e . --no-deps` in the Parrot directory.')
            logging.critical('transformers==4.39.0 is recommended.')
            raise e

        self.process_images = process_images
        self.ConversationFormatter = ConversationFormatter
        self.DEFAULT_IMAGE_TOKEN = DEFAULT_IMAGE_TOKEN
        self.BEGIN_LINE = BEGIN_LINE
        self.END_LINE = END_LINE

        try:
            model_name = 'parrot_qwen2'
            model, tokenizer, conversation_formatter = ParrotMetaForCausalLM.build(
                model_name, model_path, mm_vision_tower='openai/clip-vit-large-patch14-336'
            )
            self.model = model.cuda()
            self.vision_tower = self.model.get_vision_tower()
            self.tokenizer = tokenizer
            self.conversation_formatter = conversation_formatter
            self.image_processor = self.model.get_vision_tower().image_processor
        except Exception as e:
            logging.critical('Error when loading Parrot model:')
            raise e

        self.kwargs = dict(
            do_sample=False,
            num_beams=1,
            max_new_tokens=512,
            repetition_penalty=None,
            use_cache=True,
            eos_token_id=self.tokenizer.eos_token_id,
            pad_token_id=self.tokenizer.pad_token_id
        )
        if int(os.environ.get('LOCAL_RANK', '0')) == 0:
            print(f'Following kwargs {self.kwargs} will be used as generation config.')
        self.count = 0

    def use_custom_prompt(self, dataset):
        # Custom prompts are only built for yes/no and multiple-choice datasets.
        return DATASET_TYPE(dataset) in ('Y/N', 'MCQ')

    def build_prompt(self, line, dataset=None):
        assert self.use_custom_prompt(dataset)
        assert isinstance(dataset, str)
        tgt_path = self.dump_image(line, dataset)

        if DATASET_TYPE(dataset) == 'Y/N':
            prompt = self.build_yorn_prompt(line, dataset)
        elif DATASET_TYPE(dataset) == 'MCQ':
            prompt = self.build_multi_choice_prompt(line, dataset)
        else:
            raise ValueError(f'Invalid dataset type: {DATASET_TYPE(dataset)}')

        message = [dict(type='text', value=prompt)]
        message.extend([dict(type='image', value=p) for p in tgt_path])
        return message

    def build_yorn_prompt(self, line, dataset=None):
        prompt = line['question']
        # Strip answer-format suffixes that some benchmarks append, then add a unified yes/no instruction.
        previous_suffixes = [' Please answer yes or no.', ' Yes or No', ' Answer in one sentence.']
        for previous_suffix in previous_suffixes:
            if prompt.endswith(previous_suffix):
                prompt = prompt[:-len(previous_suffix)]
                break
        prompt += ('\n请直接回答Yes或No。请用单个词或短语回答问题。' if cn_string(prompt)
                   else '\nPlease strictly answer Yes or No. Answer the question using a single word or phrase.')
        return prompt

    def build_multi_choice_prompt(self, line, dataset=None):
        question = line['question']
        hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None
        if hint is not None:
            question = hint + '\n' + question

        options = {
            cand: line[cand]
            for cand in string.ascii_uppercase
            if cand in line and not pd.isna(line[cand])
        }
        for key, item in options.items():
            question += f'\n{key}. {item}'

        prompt = question
        if len(options):
            # Pick the answer-format instruction matching the dataset's language suffix.
            default_prompt = "\nAnswer with the option's letter from the given choices directly."
            if dataset[-3:] == '_cn' or cn_string(prompt):
                default_prompt = '\n请直接用给定选项中的选项字母回答。'
            elif dataset[-3:] == '_pt':
                default_prompt = '\nResponda diretamente com a letra da opção das escolhas dadas.'
            elif dataset[-3:] == '_ar':
                default_prompt = '\nأجب مباشرةً بحرف الخيار من الاختيارات المعطاة.'
            elif dataset[-3:] == '_ru':
                default_prompt = '\nОтветьте буквой варианта из предложенных вариантов напрямую.'
            elif dataset[-3:] == '_tr':
                default_prompt = '\nVerilen seçeneklerden doğrudan seçeneğin harfi ile cevap verin.'
            prompt += default_prompt
            # prompt += (
            #     '\n请直接回答选项字母。' if cn_string(prompt) else
            #     "\nAnswer with the option's letter from the given choices directly."
            # )
        else:
            prompt += ('\n请用单个词或短语回答问题。' if cn_string(prompt)
                       else '\nAnswer the question using a single word or phrase.')
        return prompt

    def process_answer_prefix(self, answer, prefixes):
        # Return the text following the first matching prefix (case-insensitive); otherwise return the answer unchanged.
        for prefix in prefixes:
            if prefix in answer.lower():
                return answer[answer.lower().find(prefix) + len(prefix):]
        return answer

    def generate_inner(self, message, dataset=None):
        query, image_paths = self.prepare_inputs(message)
        images_list = [Image.open(image_path).convert('RGB') for image_path in image_paths]

        # abstractproperty() is only used here as a lightweight attribute container for preprocessing options.
        args = abstractproperty()
        args.image_aspect_ratio = 'pad'
        image_tensors = self.process_images(images_list, self.image_processor, args).cuda()

        prompt, input_ids = self.conversation_formatter.format_query(query)
        input_ids = input_ids.unsqueeze(0).cuda()

        with torch.inference_mode():
            kwargs = dict(images=image_tensors)
            kwargs.update(self.kwargs)
            output_ids = self.model.generate(input_ids, **kwargs)

        input_token_len = input_ids.shape[1]
        n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
        if n_diff_input_output > 0:
            print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
        response = self.tokenizer.batch_decode(
            output_ids[:, input_token_len:], skip_special_tokens=True
        )[0].strip(string.whitespace)

        answer = response
        if query.endswith("Answer with the option's letter from the given choices directly.") or query.endswith(
                '请直接回答选项字母。'):
            qtype = 'multiple-choice'
            # Iteratively strip punctuation and common prefixes until a single option letter remains.
            while True:
                answer = answer.strip(string.punctuation + string.whitespace)
                if len(answer) > 1:
                    if answer[0] in string.ascii_uppercase and answer[1] in string.whitespace + string.punctuation:
                        answer = answer[0]
                        break
                    elif answer[-1] in string.ascii_uppercase and answer[-2] in string.whitespace + string.punctuation:
                        answer = answer[-1]
                        break
                    elif listinstr(['answer is', 'answer:'], answer.lower()):
                        answer = self.process_answer_prefix(answer, ['answer is', 'answer:'])
                        answer = self.process_answer_prefix(answer, ['option'])
                    else:
                        break
                else:
                    break
        else:
            qtype = 'open'

        if self.count % 50 == 0 and int(os.environ.get('LOCAL_RANK', '0')) == 0:
            print(f'\n{self.BEGIN_LINE}')
            print(f'image_paths: {image_paths}\n')
            print(f'prompt: {prompt}\n')
            print(f'qtype: {qtype}\n')
            print(f'output: {response}\n')
            print(f'answer: {answer}\n')
            print(f'{self.END_LINE}\n', flush=True)
        self.count += 1
        return answer

    def prepare_inputs(self, message):
        # Flatten the message list into a single query string, replacing each image with the image placeholder token.
        prompt = ''
        image_paths = []
        image_count = 0
        text_count = 0
        pure_text = ''
        for msg in message:
            if msg['type'] == 'text':
                text_count += 1
                prompt += msg['value']
                pure_text += msg['value']
            elif msg['type'] == 'image':
                image_count += 1
                prompt += self.DEFAULT_IMAGE_TOKEN
                image_paths.append(msg['value'])
        # For the common single-image, single-text case, put the image token first.
        if image_count == 1 and text_count == 1:
            prompt = self.DEFAULT_IMAGE_TOKEN + '\n' + pure_text
        return prompt, image_paths
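

# Minimal usage sketch: shows the message format that prepare_inputs / generate_inner
# expect (a list of dicts whose 'type' is 'text' or 'image'). It assumes the Parrot
# package from https://github.com/AIDC-AI/Parrot is installed, a CUDA device is
# available, and 'demo.jpg' is a placeholder path for a local image; in the evaluation
# harness this class is normally driven through the BaseModel entry points rather than
# called directly like this.
if __name__ == '__main__':
    parrot = Parrot(model_path='AIDC-AI/Parrot-7B')
    demo_message = [
        dict(type='text', value='Describe the image in one sentence.'),
        dict(type='image', value='demo.jpg'),
    ]
    print(parrot.generate_inner(demo_message))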