import torch
from PIL import Image
from abc import abstractproperty
import sys
import os.path as osp
from .base import BaseModel
from ..smp import *
from ..dataset import DATASET_TYPE
import copy


class VILA(BaseModel):

    INSTALL_REQ = True
    INTERLEAVE = True

    def __init__(self, model_path='Efficient-Large-Model/Llama-3-VILA1.5-8b', **kwargs):
        try:
            from llava.model.builder import load_pretrained_model
            from llava.mm_utils import get_model_name_from_path
            from llava.mm_utils import process_images, tokenizer_image_token, KeywordsStoppingCriteria
            from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN  # noqa E501
            from llava.conversation import conv_templates, SeparatorStyle
        except Exception as err:
            logging.critical('Please install VILA before using VILA')
            logging.critical('Please install VILA from https://github.com/NVlabs/VILA')
            logging.critical('Please install VLMEvalKit after installing VILA')
            logging.critical('VILA is supported only with transformers==4.36.2')
            raise err

        warnings.warn('Please install the latest version of VILA from GitHub before you evaluate the VILA model.')
        assert osp.exists(model_path) or len(model_path.split('/')) == 2

        model_name = get_model_name_from_path(model_path)

        # Load on CPU first, then move the model to GPU once loading succeeds
        try:
            self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
                model_path=model_path,
                model_base=None,
                model_name=model_name,
                device='cpu',
                device_map='cpu'
            )
        except Exception as err:
            logging.critical('Error loading VILA model: ')
            raise err

        self.model = self.model.cuda()

        # Pick the conversation template based on the model size in the path
        if '3b' in model_path:
            self.conv_mode = 'vicuna_v1'
        if '8b' in model_path:
            self.conv_mode = 'llama_3'
        elif '13b' in model_path:
            self.conv_mode = 'vicuna_v1'
        elif '40b' in model_path:
            self.conv_mode = 'hermes-2'

        kwargs_default = dict(do_sample=False, temperature=0, max_new_tokens=512, top_p=None, num_beams=1, use_cache=True)  # noqa E501
        kwargs_default.update(kwargs)
        self.kwargs = kwargs_default
        warnings.warn(f'Using the following kwargs for generation config: {self.kwargs}')

        # Keep references to the llava helpers needed at generation time
        self.conv_templates = conv_templates
        self.process_images = process_images
        self.tokenizer_image_token = tokenizer_image_token
        self.DEFAULT_IMAGE_TOKEN = DEFAULT_IMAGE_TOKEN
        self.SeparatorStyle = SeparatorStyle
        self.IMAGE_TOKEN_INDEX = IMAGE_TOKEN_INDEX
        self.KeywordsStoppingCriteria = KeywordsStoppingCriteria

    def use_custom_prompt(self, dataset):
        assert dataset is not None
        # TODO see if custom prompt needed
        return False

    def generate_inner(self, message, dataset=None):
        # Support interleaved text and images: each image contributes an image
        # placeholder token to the prompt, in order of appearance
        content, images = '', []
        for msg in message:
            if msg['type'] == 'text':
                content += msg['value']
            elif msg['type'] == 'image':
                image = Image.open(msg['value']).convert('RGB')
                images.append(image)
                content += (self.DEFAULT_IMAGE_TOKEN + '\n')

        image_tensor = self.process_images(
            images, self.image_processor, self.model.config).to(self.model.device, dtype=torch.float16)

        conv = self.conv_templates[self.conv_mode].copy()
        conv.append_message(conv.roles[0], content)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        input_ids = self.tokenizer_image_token(
            prompt, self.tokenizer, self.IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()

        stop_str = conv.sep if conv.sep_style != self.SeparatorStyle.TWO else conv.sep2
        keywords = [stop_str]
        stopping_criteria = self.KeywordsStoppingCriteria(keywords, self.tokenizer, input_ids)

        with torch.inference_mode():
            output_ids = self.model.generate(
                input_ids,
                images=image_tensor,
                stopping_criteria=[stopping_criteria],
                **self.kwargs)

        output = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
        return output
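
# Minimal usage sketch, assuming VILA and its weights are installed and a CUDA device
# is available; the image path below is a placeholder, not a real file:
#
#   model = VILA(model_path='Efficient-Large-Model/Llama-3-VILA1.5-8b')
#   message = [
#       dict(type='image', value='/path/to/example.jpg'),
#       dict(type='text', value='Describe this image.'),
#   ]
#   print(model.generate_inner(message))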