import sys
import torch
import os.path as osp
import os
import warnings
from .base import BaseModel
from ..smp import *
from PIL import Image

'''
    Please follow the instructions below to download the checkpoint:
    https://github.com/dvlab-research/MGM?tab=readme-ov-file#pretrained-weights
'''


class Mini_Gemini(BaseModel):
    """VLMEvalKit wrapper around the Mini-Gemini (MGM) vision-language model.

    Only the 'YanweiLi/MGM-7B-HD' checkpoint is supported. ``root`` must point
    to a local clone of https://github.com/dvlab-research/MGM with the
    pretrained weights placed under ``work_dirs/MGM/MGM-7B-HD`` (see the
    repo's "structure" section for the expected layout).
    """

    INSTALL_REQ = True   # requires the MGM repo to be cloned and importable
    INTERLEAVE = False   # only single-image prompts are supported

    def __init__(self, model_path, root=None, conv_mode='llava_v1', **kwargs):
        """Load the MGM checkpoint.

        Args:
            model_path: must be 'YanweiLi/MGM-7B-HD'.
            root: path to the cloned MGM repository.
            conv_mode: conversation template name used by MGM.
            **kwargs: generation overrides (temperature, num_beams, ...).

        Raises:
            ValueError: if ``root`` is not given or ``model_path`` is unsupported.
        """
        if root is None:
            # Fail fast with an actionable message instead of warning and
            # raising a bare, message-less ValueError.
            raise ValueError(
                'Please set `root` to Mini_Gemini code directory, '
                'which is cloned from here: "https://github.com/dvlab-research/MGM?tab=readme-ov-file" '
            )
        warnings.warn('Please follow the instructions of Mini_Gemini to put the ckpt file in the right place, \
            which can be found at https://github.com/dvlab-research/MGM?tab=readme-ov-file#structure')
        # Use an explicit raise rather than `assert`: asserts are stripped
        # when Python runs with -O, silently skipping this validation.
        if model_path != 'YanweiLi/MGM-7B-HD':
            raise ValueError('We only support MGM-7B-HD for now')
        self.model_path = model_path
        sys.path.append(root)  # make the `mgm` package importable
        try:
            from mgm.model.builder import load_pretrained_model
            from mgm.mm_utils import get_model_name_from_path
        except Exception as e:
            logging.critical(
                'Please first install Mini_Gemini and set the root path to use Mini_Gemini, '
                'which is cloned from here: "https://github.com/dvlab-research/MGM?tab=readme-ov-file" '
            )
            raise e

        # MGM resolves some resources relative to its repo root, so chdir
        # there while loading.  The try/finally guarantees the original cwd
        # is restored even if loading fails part-way through.
        VLMEvalKit_path = os.getcwd()
        os.chdir(root)
        try:
            model_path = osp.join(root, 'work_dirs', 'MGM', 'MGM-7B-HD')
            try:
                model_name = get_model_name_from_path(model_path)
            except Exception as e:
                logging.critical(
                    'Please follow the instructions of Mini_Gemini to put the ckpt file in the right place, '
                    'which can be found at https://github.com/dvlab-research/MGM?tab=readme-ov-file#structure'
                )
                raise e

            tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, None, model_name)
        finally:
            os.chdir(VLMEvalKit_path)
        self.model = model
        self.tokenizer = tokenizer
        self.image_processor = image_processor
        self.conv_mode = conv_mode

        # Greedy decoding by default; any user-supplied kwargs win, and
        # `do_sample` is derived from the effective temperature.
        kwargs_default = dict(temperature=float(0), num_beams=1, top_p=None, max_new_tokens=1024, use_cache=True)
        kwargs_default.update(kwargs)
        do_sample = kwargs_default['temperature'] > 0
        kwargs_default.update({'do_sample': do_sample})
        self.kwargs = kwargs_default

    def generate_inner(self, message, dataset=None):
        """Run one generation round for a single-image prompt.

        Args:
            message: VLMEvalKit message structure (text + one image path).
            dataset: optional dataset name, forwarded to prompt building.

        Returns:
            The decoded model response as a stripped string.
        """
        try:
            from mgm.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, \
                DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
            from mgm.conversation import conv_templates
            from mgm.mm_utils import tokenizer_image_token, process_images
        except Exception as e:
            logging.critical(
                'Please first install Mini_Gemini and set the root path to use Mini_Gemini, '
                'which is cloned from here: "https://github.com/dvlab-research/MGM?tab=readme-ov-file" '
            )
            raise e

        prompt, image = self.message_to_promptimg(message, dataset=dataset)
        # Normalize to RGB: palette / RGBA images would otherwise reach the
        # image processor with an unexpected channel layout.
        image = Image.open(image).convert('RGB')
        # Build the conversation prompt with the image placeholder tokens.
        prompt = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + prompt
        conv = conv_templates[self.conv_mode].copy()
        conv.append_message(conv.roles[0], prompt)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')
        input_ids = input_ids.unsqueeze(0).cuda()
        # HD checkpoints carry an auxiliary (high-res) image branch: the
        # processor is temporarily switched to the aux resolution, while the
        # original crop size is remembered as `image_size_raw`.
        if hasattr(self.model.config, 'image_size_aux'):
            if not hasattr(self.image_processor, 'image_size_raw'):
                self.image_processor.image_size_raw = self.image_processor.crop_size.copy()
            self.image_processor.crop_size['height'] = self.model.config.image_size_aux
            self.image_processor.crop_size['width'] = self.model.config.image_size_aux
            self.image_processor.size['shortest_edge'] = self.model.config.image_size_aux
        image_tensor = process_images([image], self.image_processor, self.model.config)[0]
        image_grid = getattr(self.model.config, 'image_grid', 1)
        if hasattr(self.model.config, 'image_size_aux'):
            raw_shape = [
                self.image_processor.image_size_raw['height'] * image_grid,
                self.image_processor.image_size_raw['width'] * image_grid
            ]
            # The high-res tensor feeds the aux branch; the main branch gets
            # a bilinear downsample back to the raw (grid-expanded) size.
            image_tensor_aux = image_tensor
            image_tensor = torch.nn.functional.interpolate(
                image_tensor[None],
                size=raw_shape,
                mode='bilinear',
                align_corners=False
            )[0]
        else:
            image_tensor_aux = []
        if image_grid >= 2:
            # Split the grid-expanded image into image_grid**2 crops of the
            # raw size, batched along dim 0.
            raw_image = image_tensor.reshape(
                3, image_grid, self.image_processor.image_size_raw['height'],
                image_grid, self.image_processor.image_size_raw['width']
            )
            raw_image = raw_image.permute(1, 3, 0, 2, 4)
            raw_image = raw_image.reshape(
                -1, 3, self.image_processor.image_size_raw['height'], self.image_processor.image_size_raw['width']
            )

            if getattr(self.model.config, 'image_global', False):
                # Append a global (whole-image) view after the crops.
                global_image = image_tensor
                if len(global_image.shape) == 3:
                    global_image = global_image[None]
                global_image = torch.nn.functional.interpolate(
                    global_image,
                    size=[
                        self.image_processor.image_size_raw['height'],
                        self.image_processor.image_size_raw['width']
                    ],
                    mode='bilinear',
                    align_corners=False
                )
                # [image_crops, image_global]
                raw_image = torch.cat([raw_image, global_image], dim=0)
            image_tensor = raw_image.contiguous()

        images = image_tensor[None].to(dtype=self.model.dtype, device='cuda', non_blocking=True)
        if len(image_tensor_aux) > 0:
            images_aux = image_tensor_aux[None].to(dtype=self.model.dtype, device='cuda', non_blocking=True)
        else:
            images_aux = None

        with torch.inference_mode():
            output_ids = self.model.generate(
                input_ids,
                images=images,
                images_aux=images_aux,
                # no_repeat_ngram_size=3,
                bos_token_id=self.tokenizer.bos_token_id,  # Begin of sequence token
                eos_token_id=self.tokenizer.eos_token_id,  # End of sequence token
                pad_token_id=self.tokenizer.pad_token_id,  # Pad token
                **self.kwargs
            )

        outputs = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
        return outputs