import sys
import torch
import os.path as osp
import os
import logging
import warnings
from .base import BaseModel
from ..smp import *
from PIL import Image

'''
Please follow the instructions to download the checkpoint:
https://github.com/dvlab-research/MGM?tab=readme-ov-file#pretrained-weights
'''


class Mini_Gemini(BaseModel):

    INSTALL_REQ = True
    INTERLEAVE = False

    def __init__(self, model_path, root=None, conv_mode='llava_v1', **kwargs):
        if root is None:
            warnings.warn(
                'Please set `root` to the Mini_Gemini code directory, which is cloned from here: '
                '"https://github.com/dvlab-research/MGM?tab=readme-ov-file"'
            )
            raise ValueError('`root` must point to the Mini_Gemini (MGM) code directory')
        warnings.warn(
            'Please follow the instructions of Mini_Gemini to put the ckpt file in the right place, '
            'which can be found at https://github.com/dvlab-research/MGM?tab=readme-ov-file#structure'
        )
        assert model_path == 'YanweiLi/MGM-7B-HD', 'We only support MGM-7B-HD for now'
        self.model_path = model_path
        sys.path.append(root)
        try:
            from mgm.model.builder import load_pretrained_model
            from mgm.mm_utils import get_model_name_from_path
        except Exception as e:
            logging.critical(
                'Please first install Mini_Gemini and set the root path to use Mini_Gemini, '
                'which is cloned from here: "https://github.com/dvlab-research/MGM?tab=readme-ov-file"'
            )
            raise e

        # MGM resolves checkpoint paths relative to its own repo root, so switch
        # the working directory while loading and restore it afterwards.
        VLMEvalKit_path = os.getcwd()
        os.chdir(root)
        model_path = osp.join(root, 'work_dirs', 'MGM', 'MGM-7B-HD')
        try:
            model_name = get_model_name_from_path(model_path)
        except Exception as e:
            logging.critical(
                'Please follow the instructions of Mini_Gemini to put the ckpt file in the right place, '
                'which can be found at https://github.com/dvlab-research/MGM?tab=readme-ov-file#structure'
            )
            raise e
        tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, None, model_name)
        os.chdir(VLMEvalKit_path)

        self.model = model
        self.tokenizer = tokenizer
        self.image_processor = image_processor
        self.conv_mode = conv_mode

        kwargs_default = dict(temperature=0.0, num_beams=1, top_p=None, max_new_tokens=1024, use_cache=True)
        kwargs_default.update(kwargs)
        # Greedy decoding unless a positive temperature is requested.
        do_sample = kwargs_default['temperature'] > 0
        kwargs_default.update({'do_sample': do_sample})
        self.kwargs = kwargs_default

    def generate_inner(self, message, dataset=None):
        try:
            from mgm.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, \
                DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
            from mgm.conversation import conv_templates
            from mgm.mm_utils import tokenizer_image_token, process_images
        except Exception as e:
            logging.critical(
                'Please first install Mini_Gemini and set the root path to use Mini_Gemini, '
                'which is cloned from here: "https://github.com/dvlab-research/MGM?tab=readme-ov-file"'
            )
            raise e

        prompt, image = self.message_to_promptimg(message, dataset=dataset)
        image = Image.open(image)
        # Prepend the image placeholder tokens, then wrap the prompt in the conversation template.
        prompt = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + prompt
        conv = conv_templates[self.conv_mode].copy()
        conv.append_message(conv.roles[0], prompt)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()
        input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')
        input_ids = input_ids.unsqueeze(0).cuda()

        # HD models carry an auxiliary high-resolution vision branch: preprocess at
        # the auxiliary resolution, remembering the original raw input size.
        if hasattr(self.model.config, 'image_size_aux'):
            if not hasattr(self.image_processor, 'image_size_raw'):
                self.image_processor.image_size_raw = self.image_processor.crop_size.copy()
            self.image_processor.crop_size['height'] = self.model.config.image_size_aux
            self.image_processor.crop_size['width'] = self.model.config.image_size_aux
            self.image_processor.size['shortest_edge'] = self.model.config.image_size_aux

        image_tensor = process_images([image], self.image_processor, self.model.config)[0]
        image_grid = getattr(self.model.config, 'image_grid', 1)
        if hasattr(self.model.config, 'image_size_aux'):
            # Keep the high-resolution tensor for the auxiliary branch and
            # downsample a copy to image_grid x the raw size for the main branch.
            raw_shape = [
                self.image_processor.image_size_raw['height'] * image_grid,
                self.image_processor.image_size_raw['width'] * image_grid
            ]
            image_tensor_aux = image_tensor
            image_tensor = torch.nn.functional.interpolate(
                image_tensor[None], size=raw_shape, mode='bilinear', align_corners=False
            )[0]
        else:
            image_tensor_aux = []

        if image_grid >= 2:
            # Split the image into an image_grid x image_grid set of crops,
            # each at the raw input resolution.
            raw_image = image_tensor.reshape(
                3, image_grid, self.image_processor.image_size_raw['height'],
                image_grid, self.image_processor.image_size_raw['width']
            )
            raw_image = raw_image.permute(1, 3, 0, 2, 4)
            raw_image = raw_image.reshape(
                -1, 3, self.image_processor.image_size_raw['height'],
                self.image_processor.image_size_raw['width']
            )

            if getattr(self.model.config, 'image_global', False):
                # Append a downsampled view of the whole image to the local crops.
                global_image = image_tensor
                if len(global_image.shape) == 3:
                    global_image = global_image[None]
                global_image = torch.nn.functional.interpolate(
                    global_image,
                    size=[
                        self.image_processor.image_size_raw['height'],
                        self.image_processor.image_size_raw['width']
                    ],
                    mode='bilinear',
                    align_corners=False
                )
                # [image_crops, image_global]
                raw_image = torch.cat([raw_image, global_image], dim=0)
            image_tensor = raw_image.contiguous()

        images = image_tensor[None].to(dtype=self.model.dtype, device='cuda', non_blocking=True)
        if len(image_tensor_aux) > 0:
            images_aux = image_tensor_aux[None].to(dtype=self.model.dtype, device='cuda', non_blocking=True)
        else:
            images_aux = None

        with torch.inference_mode():
            output_ids = self.model.generate(
                input_ids,
                images=images,
                images_aux=images_aux,
                # no_repeat_ngram_size=3,
                bos_token_id=self.tokenizer.bos_token_id,  # Begin of sequence token
                eos_token_id=self.tokenizer.eos_token_id,  # End of sequence token
                pad_token_id=self.tokenizer.pad_token_id,  # Pad token
                **self.kwargs
            )

        outputs = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
        return outputs
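
# A minimal usage sketch (illustrative, not part of the original file): it assumes
# the MGM repo is cloned to `root`, the MGM-7B-HD checkpoint sits under
# <root>/work_dirs/MGM/MGM-7B-HD as described in the #structure link above, and the
# message follows VLMEvalKit's list-of-dict format consumed by `message_to_promptimg`:
#
#     model = Mini_Gemini('YanweiLi/MGM-7B-HD', root='/path/to/MGM')
#     message = [
#         dict(type='image', value='/path/to/demo.jpg'),
#         dict(type='text', value='Describe this image.'),
#     ]
#     print(model.generate_inner(message))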