model.py 22.3 KB
Newer Older
luopl's avatar
luopl committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
import math
import re
from typing import Iterable, List, Optional, Tuple

import numpy as np
import torch
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.mm_utils import (
    get_anyres_image_grid_shape,  # unpad_image, unpad_image_shape
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.qwen2 import Qwen2ForCausalLM
from sglang.srt.utils import add_prefix
from torch import nn
from transformers import (
    CLIPVisionConfig,
    CLIPVisionModel,
    SiglipVisionConfig,
    SiglipVisionModel,
)

from ..vlm_hf_model.configuration_mineru2 import Mineru2QwenConfig
from ..vlm_hf_model.modeling_mineru2 import build_vision_projector
from ...utils.models_download_utils import auto_download_and_get_model_root_path


def flatten_nested_list(nested_list):
    if isinstance(nested_list, list):
        return [item for sublist in nested_list for item in flatten_nested_list(sublist)]
    else:
        return [nested_list]


def downgrade_modality(modality):
    modality_str = str(modality)
    if "MULTI_IMAGES" in modality_str:
        return "multi-images"
    if "IMAGE" in modality_str:
        return "image"
    if "VIDEO" in modality_str:
        return "video"
    if "AUDIO" in modality_str:
        return "audio"
    raise ValueError(f"Unexpected modality: {modality_str}")


class Mineru2QwenForCausalLM(nn.Module):
    def __init__(
        self,
        config: Mineru2QwenConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config

        if getattr(self.config, "projector_hidden_act", None) is None:
            self.config.projector_hidden_act = "gelu"
        if getattr(self.config, "image_token_index", None) is None:
            self.config.image_token_index = 151646

        # load vision tower
        mm_vision_tower = self.config.mm_vision_tower
        model_root_path = auto_download_and_get_model_root_path(mm_vision_tower, "vlm")
        mm_vision_tower = f"{model_root_path}/{mm_vision_tower}"

        if "clip" in mm_vision_tower:
            vision_config = CLIPVisionConfig.from_pretrained(mm_vision_tower)
            self.vision_tower = CLIPVisionModel(vision_config)  # type: ignore
        elif "siglip" in mm_vision_tower:
            vision_config = SiglipVisionConfig.from_pretrained(mm_vision_tower)
            self.vision_tower = SiglipVisionModel(vision_config)  # type: ignore
            # Siglip needs all feature tokens
            self.config.mm_vision_select_feature = "full"
        else:
            raise ValueError(f"Unexpected mm_vision_tower: {mm_vision_tower}")

        ### EDIT: change projector
        # the name `projector` contains `proj` which is often used in attention layers, which can cause bugs in quantization.
        self.multi_modal_mlp = build_vision_projector(config)

        self.language_model = Qwen2ForCausalLM(
            config,
            quant_config=quant_config,
            prefix=add_prefix("language_model", prefix),
        )

        if "unpad" in getattr(config, "mm_patch_merge_type", ""):
            self.language_model.model.image_newline = nn.Parameter(torch.empty(config.hidden_size))

        language_model_device = next(self.language_model.parameters()).device
        self.vision_tower = self.vision_tower.to(language_model_device)
        self.vision_tower.eval()

        self.vision_feature_layer = self.config.mm_vision_select_layer
        self.vision_feature_select_strategy = self.config.mm_vision_select_feature
        self.image_size = self.vision_tower.config.image_size
        self.patch_size = self.vision_tower.config.patch_size

        self.mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat")
        self.image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square")
        self.image_grid_pinpoints = getattr(self.config, "image_grid_pinpoints", None)

        self.image_feature_len = int((self.image_size // self.patch_size) ** 2)
        if self.vision_feature_select_strategy in ("patch", "full"):
            pass
        elif self.vision_feature_select_strategy == "cls_patch":
            self.image_feature_len += 1
        else:
            raise ValueError(f"Unexpected select feature: {self.select_feature}")

    def pad_input_ids(self, input_ids: List[int], image_inputs):
        if hasattr(image_inputs, "mm_items"):  # MultimodalInputs
            # sglang==0.4.5.post3
            image_sizes = flatten_nested_list([item.image_sizes for item in image_inputs.mm_items])
            pad_values = [item.pad_value for item in image_inputs.mm_items]
        else:  # ImageInputs
            # sglang==0.4.4.post1
            image_sizes = image_inputs.image_sizes
            pad_values = image_inputs.pad_values

        # hardcode for spatial_unpad + anyres
        # if image_inputs.modalities is not None and (
        #     "multi-images" in image_inputs.modalities or "video" in image_inputs.modalities
        # ):
        #     image_aspect_ratio = "pad"
        # else:
        #     image_aspect_ratio = "anyres"

        offset_list = []
        image_inputs.image_pad_len = []
        for image_idx, image_s in enumerate(image_sizes):
            if len(image_sizes) > 16:
                # 2x2 pooling with stride 2
                new_image_feature_len = math.ceil(self.image_size / self.patch_size / 2) ** 2
            else:
                new_image_feature_len = self.image_feature_len  # multiimage

            height = width = self.num_patches_per_side
            if "anyres" in self.config.image_aspect_ratio:
                num_patch_width, num_patch_height = get_anyres_image_grid_shape(
                    image_s,
                    self.image_grid_pinpoints,
                    self.vision_tower.config.image_size,
                )
                h = num_patch_height * height
                w = num_patch_width * width

                ### EDIT: remove `unpad_image_shape`
                # new_h, new_w = unpad_image_shape(h, w, image_s)
                new_h, new_w = h, w

                if "anyres_max" in self.config.image_aspect_ratio:
                    matched_anyres_max_num_patches = re.match(r".*anyres_max_(\d+)", self.config.image_aspect_ratio)
                    if matched_anyres_max_num_patches:
                        max_num_patches = int(matched_anyres_max_num_patches.group(1))
                        times = math.sqrt(new_h * new_w / (max_num_patches * self.image_feature_len))
                        if times > 1.1:
                            new_h = int(new_h // times)
                            new_w = int(new_w // times)
                new_image_feature_len += new_h * (new_w + 1)

            try:
                offset = input_ids.index(self.config.image_token_index)
            except ValueError:
                offset = 0
            # old_len + pad_len - 1, because we need to remove image_token_id
            input_ids = input_ids[:offset] + [pad_values[image_idx]] * new_image_feature_len + input_ids[offset + 1 :]
            offset_list.append(offset)
            image_inputs.image_pad_len.append(new_image_feature_len)

        image_inputs.image_offsets = offset_list
        return input_ids

    def encode_images(self, pixel_values: torch.Tensor) -> torch.Tensor:
        pixel_values = pixel_values.to(device=self.vision_tower.device, dtype=self.vision_tower.dtype)
        image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
        # NOTE: This is not memory efficient. (output_hidden_states=True) will save all the hidden stated.

        selected_image_feature = image_outputs.hidden_states[self.vision_feature_layer]
        if self.vision_feature_select_strategy in ["default", "patch"]:
            selected_image_feature = selected_image_feature[:, 1:]
        elif self.vision_feature_select_strategy == "full":
            selected_image_feature = selected_image_feature
        else:
            raise ValueError(f"Unexpected select feature strategy: {self.vision_feature_select_strategy}")

        image_features = self.multi_modal_mlp(selected_image_feature)
        return image_features

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        if hasattr(forward_batch, "mm_inputs"):
            # sglang==0.4.5.post3
            image_inputs = forward_batch.mm_inputs
            is_sglang_mm_inputs = True
        else:
            # sglang==0.4.4.post1
            image_inputs = forward_batch.image_inputs
            is_sglang_mm_inputs = False

        if image_inputs is None:
            image_inputs = []

        if forward_batch.forward_mode.is_extend():
            # Clamp input ids. This is because the input_ids for the image tokens are
            # filled with the hash values of the image for the prefix matching in the radix attention.
            # There values are useless because their embeddings will be replaced by vision embeddings anyway.
            input_ids.clamp_(min=0, max=self.config.vocab_size - 1)

            # Embed text inputs
            input_embeds = self.language_model.model.embed_tokens(input_ids)

            # Got List[List[str]] extend it to List[str]
            # The length of the List should be equal to batch size
            modalities_list = []
            max_image_offset = []
            for im in image_inputs:
                if im:
                    if hasattr(im, "mm_items"):
                        # sglang==0.4.5.post3
                        modalities_list.extend([downgrade_modality(item.modality) for item in im.mm_items])
                    elif im.modalities is not None:
                        # sglang==0.4.4.post1
                        modalities_list.extend(im.modalities)
                if im and im.image_offsets:
                    max_image_offset.append(np.max(np.array(im.image_offsets) + np.array(im.image_pad_len)))
                else:
                    max_image_offset.append(-1)

            start_positions = positions[forward_batch.extend_start_loc].cpu().numpy()
            need_vision = start_positions <= np.array(max_image_offset)

            if need_vision.any():
                bs = forward_batch.batch_size

                if is_sglang_mm_inputs:
                    # sglang==0.4.5.post3
                    pixel_values = flatten_nested_list(
                        [[item.pixel_values for item in image_inputs[i].mm_items] for i in range(bs) if need_vision[i]]
                    )  # image_inputs[batch_idx].mm_items[item_idx].pixel_values is Tensor
                    image_sizes = [
                        flatten_nested_list([item.image_sizes for item in image_inputs[i].mm_items])
                        for i in range(bs)
                        if need_vision[i]
                    ]  # image_inputs[batch_idx].mm_items[item_idx].image_sizes should be tuple, but is list of tuple for now.
                else:
                    # sglang==0.4.4.post1
                    pixel_values = [image_inputs[i].pixel_values for i in range(bs) if need_vision[i]]
                    image_sizes = [image_inputs[i].image_sizes for i in range(bs) if need_vision[i]]

                ########## Encode Image ########

                if pixel_values[0].ndim == 4:
                    # llava-hd: BS, num_patch, C=3, H=336, W=336, num_patch obtained from process_images
                    np.concatenate(pixel_values, axis=0)
                    # ndim=4
                    concat_images = torch.tensor(
                        np.concatenate(pixel_values, axis=0),
                        device=self.vision_tower.device,
                    )
                    image_features = self.encode_images(concat_images)
                    split_sizes = [image.shape[0] for image in pixel_values]
                    image_features = torch.split(image_features, split_sizes, dim=0)
                    # hd image_features: BS, num_patch, 576, 4096
                else:
                    # normal pixel: BS, C=3, H=336, W=336
                    pixel_values = torch.tensor(np.array(pixel_values), device=self.vision_tower.device)
                    image_features = self.encode_images(pixel_values)
                    # image_features: BS, 576, 4096

                if self.mm_patch_merge_type.startswith("spatial"):
                    new_image_features = []
                    height = width = self.num_patches_per_side
                    for image_idx, image_feature in enumerate(image_features):
                        if modalities_list[image_idx] == "image":
                            image_aspect_ratio = self.config.image_aspect_ratio  # single image
                        elif modalities_list[image_idx] == "multi-images" or modalities_list[image_idx] == "video":
                            image_aspect_ratio = "pad"  # multi image
                        # image_aspect_ratio = (
                        #     "anyres" if len(image_sizes[image_idx]) == 1 else "pad"
                        # )
                        if (
                            image_feature.shape[0] > 1
                            and "anyres" in image_aspect_ratio
                            and modalities_list[image_idx] == "image"
                        ):
                            base_image_feature = image_feature[0]
                            image_feature = image_feature[1:]
                            assert height * width == base_image_feature.shape[0]

                            if "anyres_max" in image_aspect_ratio:
                                matched_anyres_max_num_patches = re.match(r".*anyres_max_(\d+)", image_aspect_ratio)
                                if matched_anyres_max_num_patches:
                                    max_num_patches = int(matched_anyres_max_num_patches.group(1))

                            if image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio:
                                vision_tower_image_size = self.image_size
                                try:
                                    num_patch_width, num_patch_height = get_anyres_image_grid_shape(
                                        image_sizes[image_idx][0],
                                        self.config.image_grid_pinpoints,
                                        vision_tower_image_size,
                                    )
                                except Exception as e:
                                    print(f"Error: {e}")
                                    num_patch_width, num_patch_height = 2, 2
                                image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
                            else:
                                image_feature = image_feature.view(2, 2, height, width, -1)

                            if "unpad" in self.mm_patch_merge_type:
                                unit = image_feature.shape[2]
                                image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
                                image_feature = image_feature.flatten(1, 2).flatten(2, 3)

                                ### EDIT: remove `unpad_image`
                                # image_feature = unpad_image(image_feature, image_sizes[image_idx][0])

                                if "anyres_max" in image_aspect_ratio and matched_anyres_max_num_patches:
                                    c, h, w = image_feature.shape
                                    times = math.sqrt(h * w / (max_num_patches * unit**2))
                                    if times > 1.1:
                                        image_feature = image_feature[None]
                                        image_feature = nn.functional.interpolate(
                                            image_feature,
                                            [int(h // times), int(w // times)],
                                            mode="bilinear",
                                        )[0]
                                image_feature = torch.cat(
                                    (
                                        image_feature,
                                        self.language_model.model.image_newline[:, None, None].expand(
                                            *image_feature.shape[:-1], 1
                                        ),
                                    ),
                                    dim=-1,
                                )
                                image_feature = image_feature.flatten(1, 2).transpose(0, 1)
                            else:
                                image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous()
                                image_feature = image_feature.flatten(0, 3)
                            image_feature = torch.cat((base_image_feature, image_feature), dim=0)
                            image_feature = image_feature.unsqueeze(0)
                        else:
                            if modalities_list[image_idx] == "video":  # video
                                # 2x2 pooling
                                num_of_frames = image_feature.shape[0]
                                image_feature = image_feature.view(num_of_frames, height, width, -1)
                                image_feature = image_feature.permute(0, 3, 1, 2).contiguous()  # N, C, H, W
                                height, weight = image_feature.shape[2:]
                                scaled_shape = [
                                    math.ceil(height / 2),
                                    math.ceil(weight / 2),
                                ]
                                image_feature = nn.functional.interpolate(image_feature, size=scaled_shape, mode="bilinear")
                                image_feature = image_feature.flatten(2).transpose(1, 2).contiguous()  # N, C, H*W
                            if "unpad" in self.mm_patch_merge_type:
                                image_feature = torch.cat(
                                    (
                                        image_feature,
                                        # Expand to (bs, 1, hidden_dim) and concat at the end of the image tokens
                                        self.language_model.model.image_newline[None, None].expand(
                                            image_feature.shape[0],
                                            1,
                                            image_feature.shape[-1],
                                        ),
                                    ),
                                    dim=1,
                                )

                        new_image_features.append(image_feature)
                    image_features = new_image_features

                # Fill in the placeholder for the image
                extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
                extend_seq_lens = forward_batch.extend_seq_lens.cpu().numpy()
                prefix_lens_cpu = forward_batch.extend_prefix_lens_cpu
                pt = 0
                for i in range(bs):
                    if not need_vision[i]:
                        continue

                    start_idx = extend_start_loc_cpu[i]
                    seq_len = extend_seq_lens[i]
                    prefix_len = prefix_lens_cpu[i]

                    # Multiple images
                    for image_idx, image_offset in enumerate(image_inputs[i].image_offsets):
                        if image_offset + image_inputs[i].image_pad_len[image_idx] <= prefix_len:
                            continue
                        if image_offset >= prefix_len + seq_len:
                            break

                        tmp_image_feature = image_features[pt][image_idx]
                        pad_len = tmp_image_feature.shape[0]

                        input_offset = image_offset - prefix_len
                        left_idx = start_idx + input_offset
                        right_idx = left_idx + pad_len
                        assert right_idx > start_idx
                        if input_offset < 0:
                            left_idx = start_idx
                            tmp_image_feature = tmp_image_feature[-input_offset:]
                        if right_idx > start_idx + seq_len:
                            tmp_image_feature = tmp_image_feature[: start_idx + seq_len - right_idx]
                            right_idx = start_idx + seq_len
                        try:
                            input_embeds[left_idx:right_idx] = tmp_image_feature
                        except RuntimeError as e:
                            print(f"RuntimeError in image encoding: {e}")
                            print(f"{input_embeds.shape=}, {tmp_image_feature.shape=}")
                            print(f"{start_idx=}, {image_offset=}, {prefix_len=}, {pad_len=}")
                    pt += 1

            return self.language_model(input_ids, positions, forward_batch, input_embeds=input_embeds)
        elif forward_batch.forward_mode.is_decode():
            return self.language_model(input_ids, positions, forward_batch)
        else:
            raise ValueError(f"Unexpected forward mode: {forward_batch.forward_mode}")

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        projector_weights = {
            "model.mm_projector": "multi_modal_mlp",
            "model.vision_tower.vision_tower": "vision_tower",
            # Update the vision tower weights if we find them in the checkpoint (it may be finetuned).
            "model.image_newline": "language_model.model.image_newline",
        }
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "projector" in name or "vision_tower" in name or "image_newline" in name:
                for weight_name, param_name in projector_weights.items():
                    if weight_name in name:
                        name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            else:
                self.language_model.load_weights([(name, loaded_weight)])

    @property
    def num_patches_per_side(self):
        return self.image_size // self.patch_size


EntryClass = [Mineru2QwenForCausalLM]