step3_vl.py 41 KB
Newer Older
Song's avatar
Song committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from collections.abc import Iterable, Mapping, Sequence
from itertools import product
from math import ceil, sqrt
from typing import Any, Literal, Optional, TypedDict, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from transformers import BatchFeature, PretrainedConfig, TensorType

18
from vllm.attention.layer import MultiHeadAttention
Song's avatar
Song committed
19
20
21
22
23
24
25
26
27
from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
28
                                    MultiModalKwargsItems, NestedTensors)
Song's avatar
Song committed
29
30
31
32
33
from vllm.multimodal.parse import ImageSize, MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                        BaseProcessingInfo, PromptReplacement,
                                        PromptUpdate, PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
34
from vllm.multimodal.utils import run_dp_sharded_vision_model
Song's avatar
Song committed
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import Step3VisionEncoderConfig
from vllm.transformers_utils.tokenizer import AnyTokenizer

from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
                    init_vllm_registered_model, maybe_prefix,
                    merge_multimodal_embeddings)


class Step3VLImagePixelInputs(TypedDict):
    """Raw-pixel image inputs as produced by ``Step3VLProcessor``."""
    type: Literal["pixel_values"]
    # One preprocessed global view per image, concatenated along dim 0.
    pixel_values: torch.Tensor
    # Preprocessed local crops of all images, or None when no image was
    # split into crops.
    patch_pixel_values: Optional[torch.Tensor]
    # Number of local crops contributed by each image, in image order.
    num_patches: list[int]


class Step3VLImageEmbeddingInputs(TypedDict):
    """Precomputed image embeddings supplied instead of pixels."""
    type: Literal["image_embeds"]
    image_embeds: torch.Tensor


Step3VLImageInputs = Union[Step3VLImagePixelInputs,
                           Step3VLImageEmbeddingInputs]

# (global view, local crops, per-crop newline mask or None).  The third
# element is the boolean mask built by ImagePatcher.__call__, where True
# marks a crop that closes a row of the crop grid.
ImageWithPatches = tuple[Image.Image, list[Image.Image], list[bool] | None]

# Longest side (in pixels) an image is downscaled to before cropping.
MAX_IMAGE_SIZE: int = 3024


class Step3VisionProcessor:
    """Convert PIL images into normalized, fixed-size pixel tensors.

    Two pipelines are built: one for the global view (``size`` x ``size``)
    and one for local crops (``patch_size`` x ``patch_size``, falling back
    to ``size``).  Both share the same mean/std normalization and
    interpolation mode.
    """

    def __init__(self, size, interpolation_mode="bicubic", patch_size=None):
        mean = [0.48145466, 0.4578275, 0.40821073]
        std = [0.26862954, 0.26130258, 0.27577711]
        patch_size = patch_size if patch_size is not None else size
        # Any value other than "bicubic" selects bilinear interpolation.
        interpolation = (InterpolationMode.BICUBIC
                         if interpolation_mode == "bicubic" else
                         InterpolationMode.BILINEAR)

        self.transform = self._build_transform(size, mean, std,
                                               interpolation)
        # patch_size falls back to size above, so a separate crop pipeline
        # is always built.
        self.patch_transform = self._build_transform(patch_size, mean, std,
                                                     interpolation)

    @staticmethod
    def _build_transform(size, mean, std, interpolation):
        """Return a ToTensor -> Normalize -> Resize((size, size)) pipeline.

        Normalization is applied before resizing; this order is part of the
        model's preprocessing and must not be swapped.
        """
        return transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
            transforms.Resize((size, size),
                              interpolation=interpolation,
                              antialias=True),
        ])

    def __call__(self, image, is_patch=False):
        """Return ``{"pixel_values": tensor}`` with a leading batch dim of 1.

        ``is_patch`` selects the local-crop pipeline instead of the
        global-view pipeline.
        """
        transform = self.patch_transform if is_patch else self.transform
        return {"pixel_values": transform(image).unsqueeze(0)}


class ImagePatcher:
    """Split an image into a resized global view plus a grid of local crops.

    Sizing happens in three steps:

    1. Very thin images are square-padded.
    2. Images larger than ``MAX_IMAGE_SIZE`` are downscaled.
    3. The canvas is snapped to a multiple of a chosen window size.

    The canvas is then tiled with ``window_size`` x ``window_size`` crops.
    """

    def determine_window_size(self, long: int, short: int) -> int:
        """Return the crop window size, or 0 when the image is not cropped.

        When the long side is at most 728, the image is only cropped if its
        aspect ratio exceeds 1.5, and the window is the short side.
        Otherwise the window is 504, capped at the short side when the
        aspect ratio exceeds 4.
        """
        if long <= 728:
            return short if long / short > 1.5 else 0
        return min(short, 504) if long / short > 4 else 504

    def slide_window(
        self,
        width: int,
        height: int,
        sizes: list[tuple[int, int]],
        steps: list[tuple[int, int]],
        img_rate_thr: float = 0.6,
    ) -> tuple[list[tuple[int, int, int, int]], tuple[int, int]]:
        """Enumerate sliding windows over a ``width`` x ``height`` canvas.

        Args:
            width: Canvas width in pixels.
            height: Canvas height in pixels.
            sizes: ``(w, h)`` window size for each window scale.
            steps: ``(step_w, step_h)`` stride paired with each entry of
                ``sizes``.
            img_rate_thr: Checked to lie in [0, 1]; not otherwise used.

        Returns:
            ``(windows, (x_num, y_num))``. Each window is ``(x, y, w, h)``,
            a top-left corner plus a size, listed row by row.
            ``x_num``/``y_num`` are the column/row counts of the last entry
            in ``sizes``.
        """
        assert 1 >= img_rate_thr >= 0, "The `img_rate_thr` should lie in 0~1"
        windows = []
        for size, step in zip(sizes, steps):
            size_w, size_h = size
            step_w, step_h = step

            x_num = 1 if width <= size_w else ceil((width - size_w) / step_w +
                                                   1)
            x_start = [step_w * i for i in range(x_num)]
            # Pull the last window back so it ends flush with the edge.
            if len(x_start) > 1 and x_start[-1] + size_w > width:
                x_start[-1] = width - size_w

            y_num = 1 if height <= size_h else ceil((height - size_h) /
                                                    step_h + 1)
            y_start = [step_h * i for i in range(y_num)]
            if len(y_start) > 1 and y_start[-1] + size_h > height:
                y_start[-1] = height - size_h

            # product(y, x) yields row-major order; swap columns to (x, y).
            start = np.array(list(product(y_start, x_start)), dtype=int)
            start[:, [0, 1]] = start[:, [1, 0]]
            windows.append(np.concatenate([start, start + size], axis=1))
        windows = np.concatenate(windows, axis=0)

        return [(int(box[0]), int(box[1]), int(box[2] - box[0]),
                 int(box[3] - box[1])) for box in windows], (x_num, y_num)

    def square_pad(self, img: Image.Image) -> Image.Image:
        """Pad ``img`` on the right/bottom with fill value 0 to a square."""
        w, h = img.size
        if w == h:
            return img
        size = max(w, h)
        padded = Image.new(img.mode, (size, size), 0)
        padded.paste(img, (0, 0))
        return padded

    def get_image_size_for_padding(self, img_width: int,
                                   img_height: int) -> tuple[int, int]:
        """Return the size after padding: a square of the long side when the
        short side is < 32 and the aspect ratio exceeds 4, else unchanged."""
        ratio = img_width / img_height
        if min(img_height, img_width) < 32 and (ratio > 4 or ratio < 1 / 4):
            new_size = max(img_height, img_width)
            return new_size, new_size
        return img_width, img_height

    def get_image_size_for_preprocess(self, img_width: int,
                                      img_height: int) -> tuple[int, int]:
        """Downscale (truncating) so the long side is <= MAX_IMAGE_SIZE."""
        if max(img_height, img_width) > MAX_IMAGE_SIZE:
            scale_factor = MAX_IMAGE_SIZE / max(img_height, img_width)
            img_width = int(img_width * scale_factor)
            img_height = int(img_height * scale_factor)
        return img_width, img_height

    def get_image_size_for_crop(self, img_width: int, img_height: int,
                                window_size: int):
        """Snap each side to a multiple of ``window_size``.

        A side shorter than the window is kept as is. Otherwise the side is
        rounded down to a multiple of the window, or up when the fractional
        remainder exceeds 0.2.
        """
        w_ratio = img_width / window_size
        h_ratio = img_height / window_size

        if w_ratio < 1:
            width_new = img_width
        else:
            decimal_w = w_ratio - img_width // window_size
            w_ratio = int(w_ratio) + 1 if decimal_w > 0.2 else int(w_ratio)
            width_new = window_size * w_ratio
        if h_ratio < 1:
            height_new = img_height
        else:
            decimal_h = h_ratio - img_height // window_size
            h_ratio = int(h_ratio) + 1 if decimal_h > 0.2 else int(h_ratio)
            height_new = window_size * h_ratio
        return int(width_new), int(height_new)

    def patch_crop(self, img: Image.Image, i: int, j: int, th: int, tw: int):
        """Crop a ``tw`` x ``th`` region whose top-left is (x=j, y=i)."""
        target = img.crop((j, i, j + tw, i + th))
        return target

    def get_num_patches(self, img_width: int,
                        img_height: int) -> tuple[int, int]:
        """Predict ``(num_patches, num_newlines)`` for an image size.

        Mirrors the sizing steps of ``__call__`` without touching pixels.
        ``num_newlines`` equals the number of True entries in the mask
        ``__call__`` returns. The windows always form a full
        ``x_num`` x ``y_num`` grid, so this is one less than the row count.
        """
        img_width, img_height = self.get_image_size_for_padding(
            img_width, img_height)
        img_width, img_height = self.get_image_size_for_preprocess(
            img_width, img_height)
        window_size = self.determine_window_size(max(img_height, img_width),
                                                 min(img_height, img_width))
        if window_size == 0:
            return 0, 0

        img_width, img_height = self.get_image_size_for_crop(
            img_width, img_height, window_size)
        windows, (x_num, _) = self.slide_window(
            img_width, img_height, [(window_size, window_size)],
            [(window_size, window_size)])
        num_patches = len(windows)
        # One newline per completed row, but none after the final patch.
        num_newlines = (num_patches - 1) // x_num + 1
        if num_patches > 0 and num_patches % x_num == 0:
            num_newlines -= 1
        return num_patches, num_newlines

    def __call__(
        self, img: Image.Image
    ) -> tuple[Image.Image, list[Image.Image], list[bool] | None]:
        """Return ``(global_image, patches, newline_mask)`` for ``img``.

        ``newline_mask[i]`` is True when patch ``i`` closes a row of the crop
        grid; the final patch is never marked. When the image is not
        cropped, ``patches`` is empty and ``newline_mask`` is None.
        """
        img_width, img_height = img.size
        new_img_width, new_img_height = self.get_image_size_for_padding(
            img_width, img_height)
        if new_img_width != img_width or new_img_height != img_height:
            img = self.square_pad(img)
            img_width, img_height = img.size

        new_img_width, new_img_height = self.get_image_size_for_preprocess(
            img_width, img_height)
        img = img.resize((new_img_width, new_img_height),
                         Image.Resampling.BILINEAR)
        window_size = self.determine_window_size(
            max(new_img_height, new_img_width),
            min(new_img_height, new_img_width))

        if window_size == 0:
            return img, [], None

        crop_width, crop_height = self.get_image_size_for_crop(
            new_img_width, new_img_height, window_size)
        # Compare against the image we currently hold. The pre-downscale
        # dimensions are stale once the resize above has run.
        if (crop_width, crop_height) != img.size:
            img_for_crop = img.resize((crop_width, crop_height),
                                      Image.Resampling.BILINEAR)
        else:
            img_for_crop = img

        windows, (x_num, _) = self.slide_window(
            crop_width, crop_height, [(window_size, window_size)],
            [(window_size, window_size)])
        patches = []
        newline_ids = set()
        for patch_id, (x, y, patch_w, patch_h) in enumerate(windows):
            patches.append(
                self.patch_crop(img_for_crop, y, x, patch_h, patch_w))
            if (patch_id + 1) % x_num == 0:
                newline_ids.add(patch_id)

        # No newline after the very last patch.
        newline_ids.discard(len(patches) - 1)

        if not patches:
            return img, patches, None
        return img, patches, [i in newline_ids for i in range(len(patches))]


class Step3VLProcessor:
    """Turn text and images into Step3-VL model inputs.

    Each image is split by ``ImagePatcher`` into a global view and optional
    local crops. Every ``<im_patch>`` placeholder in the text is replaced
    by the token expansion of the corresponding image.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        tokenizer: AnyTokenizer,
    ) -> None:
        super().__init__()

        self.config = config
        self.tokenizer = tokenizer

        # Global view is resized to image_size, local crops to patch_size.
        self.image_size = 728
        self.patch_size = 504
        self.image_preprocessor = Step3VisionProcessor(self.image_size,
                                                       "bilinear",
                                                       self.patch_size)

        # Number of <im_patch> feature tokens per global view / per crop.
        self.num_image_feature_size = 169
        self.num_patch_feature_size = 81
        self.image_token = "<im_patch>"
        self.image_feature_placeholder = (self.image_token *
                                          self.num_image_feature_size)
        self.patch_feature_placeholder = (self.image_token *
                                          self.num_patch_feature_size)

        self.patcher = ImagePatcher()

    @property
    def image_token_id(self) -> int:
        # Calls get_vocab() on every access, so avoid reading it in loops.
        return self.tokenizer.get_vocab()[self.image_token]

    def get_num_image_tokens(self, img_width: int, img_height: int) -> int:
        """Return the prompt length of one image's token expansion.

        Each crop costs ``num_patch_feature_size`` + 2 (start/end) tokens.
        The global view costs ``num_image_feature_size`` + 2, plus one token
        per row newline. This must agree with ``_get_image_repl_features``.
        """
        num_patches, num_newlines = self.patcher.get_num_patches(
            img_width, img_height)

        return num_patches * (
            self.num_patch_feature_size +
            2) + self.num_image_feature_size + 2 + num_newlines

    def _split_images(self,
                      images: list[Image.Image]) -> list[ImageWithPatches]:
        """Run the patcher on every image, preserving order."""
        return [self.patcher(img) for img in images]

    def _convert_images_to_pixel_values(
        self,
        images: list[Image.Image],
        is_patch: bool = False,
    ) -> list[torch.Tensor]:
        """Preprocess each image (crop pipeline when ``is_patch``)."""
        return [
            self.image_preprocessor(img, is_patch=is_patch)["pixel_values"]
            for img in images
        ]

    def _get_patch_repl(
        self,
        num_patches: int,
        patch_newline_mask: list[bool] | None,
    ) -> tuple[str, list[int]]:
        """Build the prompt text and token ids for an image's local crops.

        Each crop becomes ``<patch_start>``, then ``num_patch_feature_size``
        image tokens, then ``<patch_end>``. A ``<patch_newline>`` follows
        when ``patch_newline_mask[i]`` is set; a ``None`` mask adds none.
        """
        if num_patches <= 0:
            return "", []
        if patch_newline_mask is not None:
            assert len(patch_newline_mask) == num_patches

        # Resolve special-token ids once rather than once per crop.
        start_id = self.tokenizer.convert_tokens_to_ids("<patch_start>")
        end_id = self.tokenizer.convert_tokens_to_ids("<patch_end>")
        newline_id = self.tokenizer.convert_tokens_to_ids("<patch_newline>")
        patch_text = f"<patch_start>{self.patch_feature_placeholder}<patch_end>"
        patch_ids = ([start_id] +
                     [self.image_token_id] * self.num_patch_feature_size +
                     [end_id])

        text_parts: list[str] = []
        token_ids: list[int] = []
        for i in range(num_patches):
            text_parts.append(patch_text)
            token_ids.extend(patch_ids)
            if patch_newline_mask and patch_newline_mask[i]:
                text_parts.append("<patch_newline>")
                token_ids.append(newline_id)
        return "".join(text_parts), token_ids

    def _get_image_repl(
        self,
        num_images: int,
    ) -> tuple[str, list[int]]:
        """Build text/ids for ``num_images`` global views wrapped in
        ``<im_start>``/``<im_end>``."""
        text = f"<im_start>{self.image_feature_placeholder}<im_end>"
        token_ids = [
            self.tokenizer.convert_tokens_to_ids("<im_start>")
        ] + [self.image_token_id] * self.num_image_feature_size + [
            self.tokenizer.convert_tokens_to_ids("<im_end>")
        ]
        return text * num_images, token_ids * num_images

    def _get_image_repl_features(
        self,
        num_images: int,
        num_patches: int,
        patch_new_line_idx: Optional[list[bool]],
    ) -> tuple[str, list[int]]:
        """Return the full expansion: crop tokens followed by global view."""
        if num_patches > 0:
            patch_repl, patch_repl_ids = self._get_patch_repl(
                num_patches, patch_new_line_idx)
        else:
            patch_repl = ""
            patch_repl_ids = []
        image_repl, image_repl_ids = self._get_image_repl(num_images)
        return patch_repl + image_repl, patch_repl_ids + image_repl_ids

    def replace_placeholder(self, text: str, placeholder: str,
                            repls: list[str]) -> str:
        """Replace the k-th ``placeholder`` occurrence with ``repls[k]``.

        Raises:
            ValueError: if the occurrence count differs from ``len(repls)``.
        """
        parts = text.split(placeholder)

        if len(parts) - 1 != len(repls):
            raise ValueError(
                "The number of placeholders does not match the number of replacements."  # noqa: E501
            )

        result = [parts[0]]
        for i, repl in enumerate(repls):
            result.append(repl)
            result.append(parts[i + 1])

        return "".join(result)

    def __call__(
        self,
        text: Optional[Union[str, list[str]]] = None,
        images: Optional[Union[Image.Image, list[Image.Image]]] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
    ) -> BatchFeature:
        """Tokenize ``text`` with image placeholders expanded.

        With images, the result also carries ``pixel_values`` and
        ``num_patches``. It carries ``patch_pixel_values`` and
        ``patch_newline_mask`` only when at least one image was cropped.
        """
        if text is None:
            text = []
        if not isinstance(text, list):
            text = [text]
        if images is None:
            images = []
        if not isinstance(images, list):
            images = [images]

        if len(images) == 0:
            image_inputs = {}
            text_inputs = self.tokenizer(text)
        else:
            splitted_images_data = self._split_images(images)
            pixel_values_lst = []
            patch_pixel_values_lst = []
            patch_newline_mask_lst = []
            image_repl_str_lst = []
            image_repl_ids_lst = []
            num_patches = []
            for raw_img, img_patches, patch_newline_mask in splitted_images_data:  # noqa: E501
                pixel_values_lst.extend(
                    self._convert_images_to_pixel_values([raw_img]))

                if len(img_patches) > 0:
                    patch_pixel_values_lst.extend(
                        self._convert_images_to_pixel_values(img_patches,
                                                             is_patch=True))
                num_patches.append(len(img_patches))

                image_repl_str, image_repl_ids = self._get_image_repl_features(
                    1, len(img_patches), patch_newline_mask)
                image_repl_str_lst.append(image_repl_str)
                image_repl_ids_lst.extend(image_repl_ids)

                if patch_newline_mask is not None:
                    patch_newline_mask_lst.extend(patch_newline_mask)

            image_inputs = {
                "pixel_values": torch.cat(pixel_values_lst),
                "num_patches": num_patches,
            }
            if patch_pixel_values_lst:
                image_inputs["patch_pixel_values"] = torch.cat(
                    patch_pixel_values_lst)
            if patch_newline_mask_lst:
                image_inputs["patch_newline_mask"] = torch.tensor(
                    patch_newline_mask_lst, dtype=torch.bool)

            text = [
                self.replace_placeholder(t, self.image_token,
                                         image_repl_str_lst) for t in text
            ]
            text_inputs = self.tokenizer(text)

        return BatchFeature(
            {
                **text_inputs,
                **image_inputs,
            },
            tensor_type=return_tensors,
        )


class Step3VLProcessingInfo(BaseProcessingInfo):
    """Image limits and token-count helpers for Step3-VL."""

    def get_hf_processor(self) -> Step3VLProcessor:
        return Step3VLProcessor(
            self.get_hf_config(),
            self.get_tokenizer(),
        )

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        # Presumably None means "no fixed per-prompt image limit"; confirm
        # against BaseProcessingInfo.
        return {"image": None}

    def get_max_image_tokens(self) -> int:
        """Token count of the image returned by
        ``get_image_size_with_most_features``."""
        hf_processor = self.get_hf_processor()
        target_size = self.get_image_size_with_most_features()
        return hf_processor.get_num_image_tokens(target_size.width,
                                                 target_size.height)

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        return {"image": self.get_max_image_tokens()}

    def get_image_size_with_most_features(self) -> ImageSize:
        # Larger inputs are downscaled to MAX_IMAGE_SIZE before cropping.
        # NOTE(review): very elongated images (e.g. 3024x32) are tiled with
        # a window capped at the short side and yield more crops than this
        # square. Confirm the square is the intended worst case.
        return ImageSize(MAX_IMAGE_SIZE, MAX_IMAGE_SIZE)

    def get_num_mm_tokens(self, mm_data: MultiModalDataDict) -> int:
        """Sum the prompt tokens contributed by every image in ``mm_data``.

        Raises:
            ValueError: if ``mm_data`` has any key other than ``"image"``.
        """
        if len(mm_data) != 1 or "image" not in mm_data:
            raise ValueError(
                "mm_data could only contain one key 'image' for Step3-VL")

        image_data = mm_data["image"]
        if not isinstance(image_data, (list, tuple)):
            image_data = [image_data]

        # Build the processor once instead of once per image.
        hf_processor = self.get_hf_processor()
        return sum(
            hf_processor.get_num_image_tokens(img.width, img.height)
            for img in image_data)


class Step3VLDummyInputsBuilder(BaseDummyInputsBuilder[Step3VLProcessingInfo]):
    """Dummy prompt/image builder used by the multimodal profiling path."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one ``<im_patch>`` placeholder per requested image."""
        placeholder = "<im_patch>"
        image_count = mm_counts.get("image", 0)
        return placeholder * image_count

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> MultiModalDataDict:
        """Return dummy images sized to the info's largest-feature size."""
        largest = self.info.get_image_size_with_most_features()
        dummy_images = self._get_dummy_images(
            width=largest.width,
            height=largest.height,
            num_images=mm_counts.get("image", 0),
        )
        return {"image": dummy_images}


class Step3VLMultiModalProcessor(BaseMultiModalProcessor[Step3VLProcessingInfo]
                                 ):
    """Multimodal processor that expands Step3-VL image placeholders."""

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Replace each image placeholder token with its full expansion.

        The expansion for image ``item_idx`` is rebuilt from the processed
        ``num_patches`` and ``patch_newline_mask``, using the same helper
        ``Step3VLProcessor`` used. Positions holding the image token id are
        selected as embedding positions.
        """
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        image_placeholder_token_id = hf_processor.image_token_id

        def get_replacement_step1o(item_idx: int):
            out_item = out_mm_kwargs["image"][item_idx]
            num_patches = int(out_item["num_patches"].data)
            # The newline mask only exists for images that were cropped.
            patch_newline_mask = (
                out_item["patch_newline_mask"].data.tolist()
                if num_patches > 0 else None)
            image_repl_ids = hf_processor._get_image_repl_features(
                1, num_patches, patch_newline_mask)[1]
            return PromptUpdateDetails.select_token_id(
                seq=image_repl_ids,
                embed_token_id=image_placeholder_token_id,
            )

        return [
            PromptReplacement(
                modality="image",
                target=[image_placeholder_token_id],
                replacement=get_replacement_step1o,
            )
        ]

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Describe how processor outputs map to individual images.

        Crop-level tensors are flat across all images and are split back
        per image using ``num_patches``.
        """
        num_patches = hf_inputs.get("num_patches", torch.empty(0))

        return dict(
            pixel_values=MultiModalFieldConfig.batched("image"),
            patch_pixel_values=MultiModalFieldConfig.flat_from_sizes(
                "image", num_patches),
            num_patches=MultiModalFieldConfig.batched("image"),
            patch_newline_mask=MultiModalFieldConfig.flat_from_sizes(
                "image", num_patches),
        )


def get_abs_pos(abs_pos, tgt_size):
    """Resize a class-row + square-grid position table to a new grid size.

    ``abs_pos`` holds one leading class-token row followed by a square
    ``src x src`` grid of position embeddings; a leading dim of size 1 is
    squeezed away. ``tgt_size`` is a number of grid positions, and its
    truncated square root gives the target side.

    If the source and target sides match, ``abs_pos`` is returned as is.
    Otherwise the grid is resized in float32 with antialiased bicubic
    interpolation, cast back to the input dtype, and the class row is
    re-attached. The result has shape ``(1, side * side + 1, dim)``.
    """
    embed_dim = abs_pos.size(-1)
    table = abs_pos.squeeze(0)
    src_side = int(math.sqrt(table.shape[0] - 1))
    tgt_side = int(math.sqrt(tgt_size))

    if src_side == tgt_side:
        return abs_pos

    cls_row = table[:1]
    # (N, D) -> (1, D, src, src) so F.interpolate treats D as channels.
    grid = table[1:].view(1, src_side, src_side, embed_dim)
    grid = grid.permute(0, 3, 1, 2).contiguous().to(torch.float32)
    resized = F.interpolate(
        grid,
        size=(tgt_side, tgt_side),
        mode='bicubic',
        antialias=True,
        align_corners=False,
    ).to(abs_pos.dtype)
    resized = resized.permute(0, 2, 3, 1).view(tgt_side * tgt_side,
                                                embed_dim)
    merged = torch.cat([cls_row, resized], dim=0)
    return merged.view(1, tgt_side * tgt_side + 1, embed_dim)


class Step3VisionEmbeddings(nn.Module):
    """Patch, class and position embeddings for the Step3 vision tower.

    ``forward`` prepends ``pad_tp_size - 1`` extra copies of the
    (position-augmented) class row, so the output length is
    ``num_grid_patches + pad_tp_size``.
    """

    def __init__(self, config: Step3VisionEncoderConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.class_embedding = nn.Parameter(torch.randn(1, self.embed_dim))

        # kernel_size == stride == patch_size: one output cell per patch.
        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=True,
        )

        self.num_patches = (self.image_size // self.patch_size)**2
        # Hard-coded padding count: total number of leading class rows.
        self.pad_tp_size = 4
        # Table keeps P + 1 rows (class + grid) so pretrained weights load.
        num_positions = self.num_patches + 1
        self.position_embedding = torch.nn.Embedding(num_positions,
                                                     self.embed_dim)
        self.register_buffer("position_ids",
                             torch.arange(num_positions).expand((1, -1)),
                             persistent=False)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Embed ``pixel_values`` (assumed ``(B, C, H, W)``; confirm)."""
        # Conv output is (B, D, gh, gw); flatten/transpose to (B, gh*gw, D).
        grid_embeds = self.patch_embedding(pixel_values).flatten(2)
        grid_embeds = grid_embeds.transpose(1, 2)

        cls_embeds = self.class_embedding.expand(pixel_values.shape[0], 1,
                                                 -1)
        tokens = torch.cat([cls_embeds, grid_embeds], dim=1)
        pos_table = self.position_embedding(self.position_ids)
        tokens = tokens + get_abs_pos(pos_table, grid_embeds.size(1))

        extra_cls = tokens[:, :1, :].repeat(1, self.pad_tp_size - 1, 1)
        return torch.cat([extra_cls, tokens], dim=1)


class Step3VisionAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self,
                 config,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = "",
                 use_data_parallel: bool = False):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.total_num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.total_num_heads

        self.scale = self.head_dim**-0.5

        # use_data_parallel turns off tensor-parallel sharding: tp_size is 1
        # here and disable_tp is passed to the linear layers below.
        tp_size = (1 if use_data_parallel else
                   get_tensor_model_parallel_world_size())
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size

        self.q_size = self.num_heads * self.head_dim

        self.qkv_proj = QKVParallelLinear(
            self.embed_dim,
            self.head_dim,
            self.total_num_heads,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
            disable_tp=use_data_parallel,
        )
        self.out_proj = RowParallelLinear(self.embed_dim,
                                          self.embed_dim,
                                          bias=True,
                                          quant_config=quant_config,
                                          prefix=f"{prefix}.out_proj",
                                          disable_tp=use_data_parallel)

        # Use unified MultiHeadAttention with automatic backend selection
        self.attn = MultiHeadAttention(self.num_heads, self.head_dim,
                                       self.scale)

    def forward(
        self,
        hidden_states: torch.Tensor,
    ):
        """Input shape: Batch x Time x Channel"""
        qkv, _ = self.qkv_proj(hidden_states)
        # Fused projection output is split into equal q/k/v thirds.
        q, k, v = qkv.chunk(chunks=3, dim=-1)

        attn_output = self.attn(q, k, v)

        attn_output, _ = self.out_proj(attn_output)

        return attn_output


class Step3VisionMLP(nn.Module):
    """Feed-forward block: fc1 -> activation -> fc2.

    Maps hidden_size -> intermediate_size -> hidden_size. fc1 is
    column-parallel and fc2 row-parallel; ``use_data_parallel`` disables the
    tensor-parallel split.
    """

    def __init__(self,
                 config,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = "",
                 use_data_parallel: bool = False):
        super().__init__()
        self.config = config
        self.activation_fn = get_act_fn(config.hidden_act)
        self.fc1 = ColumnParallelLinear(config.hidden_size,
                                        config.intermediate_size,
                                        bias=True,
                                        quant_config=quant_config,
                                        prefix=f"{prefix}.fc1",
                                        disable_tp=use_data_parallel)
        self.fc2 = RowParallelLinear(config.intermediate_size,
                                     config.hidden_size,
                                     bias=True,
                                     quant_config=quant_config,
                                     prefix=f"{prefix}.fc2",
                                     disable_tp=use_data_parallel)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Parallel linear layers return (output, bias); bias is discarded.
        hidden_states, _ = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states, _ = self.fc2(hidden_states)
        return hidden_states


class Step3VisionEncoderLayer(nn.Module):
    """One transformer layer of the Step3 vision encoder.

    Each sub-block's output is normalized before being added back to the
    residual stream::

        h = h + layer_norm1(self_attn(h))
        h = h + layer_norm2(mlp(h))

    Args:
        config: Vision encoder config (reads ``hidden_size`` and
            ``layer_norm_eps``; also passed to the attention and MLP).
        quant_config: Optional quantization config for attention and MLP.
        prefix: Parameter-name prefix for weight loading.
        use_data_parallel: Forwarded to the attention and MLP sub-modules.
    """

    def __init__(self,
                 config: Step3VisionEncoderConfig,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = "",
                 use_data_parallel: bool = False):
        super().__init__()
        self.use_data_parallel = use_data_parallel
        self.embed_dim = config.hidden_size
        self.self_attn = Step3VisionAttention(
            config,
            quant_config,
            prefix=f"{prefix}.self_attn",
            use_data_parallel=self.use_data_parallel)
        self.layer_norm1 = nn.LayerNorm(self.embed_dim,
                                        eps=config.layer_norm_eps)
        self.mlp = Step3VisionMLP(config,
                                  quant_config,
                                  prefix=f"{prefix}.mlp",
                                  use_data_parallel=self.use_data_parallel)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim,
                                        eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run attention then MLP, each with a normalized residual branch."""
        hidden_states = hidden_states + self.layer_norm1(
            self.self_attn(hidden_states))
        hidden_states = hidden_states + self.layer_norm2(
            self.mlp(hidden_states))
        return hidden_states


class Step3VisionEncoder(nn.Module):
    """Stack of ``config.num_hidden_layers`` Step3 vision encoder layers.

    Args:
        config: Vision encoder config.
        quant_config: Optional quantization config for every layer.
        prefix: Parameter-name prefix; layer ``i`` uses
            ``f"{prefix}.layers.{i}"``.
        use_data_parallel: Forwarded to every layer.
    """

    def __init__(self,
                 config: Step3VisionEncoderConfig,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = "",
                 use_data_parallel: bool = False):
        super().__init__()
        self.config = config
        self.use_data_parallel = use_data_parallel
        self.layers = nn.ModuleList([
            Step3VisionEncoderLayer(config,
                                    quant_config,
                                    prefix=f"{prefix}.layers.{i}",
                                    use_data_parallel=self.use_data_parallel)
            for i in range(config.num_hidden_layers)
        ])

    def forward(
        self,
        inputs_embeds: torch.Tensor,
    ) -> torch.Tensor:
        """Apply every encoder layer in order and return the final states."""
        hidden_states = inputs_embeds
        for encoder_layer in self.layers:
            hidden_states = encoder_layer(hidden_states)
        return hidden_states


class Step3VisionTransformer(nn.Module):
    """Step3 vision tower: patch embeddings followed by the encoder stack.

    Args:
        config: Vision encoder config (reads ``image_size``).
        quant_config: Optional quantization config for the encoder.
        prefix: Parameter-name prefix; the encoder lives under
            ``f"{prefix}.transformer"``.
        use_data_parallel: When true, the encoder is executed through
            ``run_dp_sharded_vision_model`` instead of being called directly.
    """

    def __init__(self,
                 config: Step3VisionEncoderConfig,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = "",
                 use_data_parallel: bool = False):
        super().__init__()
        self.config = config
        self.use_data_parallel = use_data_parallel
        self.image_size = config.image_size
        self.embeddings = Step3VisionEmbeddings(config)
        self.transformer = Step3VisionEncoder(
            config,
            quant_config,
            prefix=f"{prefix}.transformer",
            use_data_parallel=self.use_data_parallel)

    def forward(
        self,
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:
        """Embed ``pixel_values`` and run them through the encoder."""
        hidden_states = self.embeddings(pixel_values)
        if self.use_data_parallel:
            # NOTE(review): presumably shards the batch across ranks and
            # gathers the outputs; see run_dp_sharded_vision_model.
            hidden_states = run_dp_sharded_vision_model(
                hidden_states, self.transformer)
        else:
            hidden_states = self.transformer(inputs_embeds=hidden_states)
        return hidden_states


@MULTIMODAL_REGISTRY.register_processor(Step3VLMultiModalProcessor,
                                        info=Step3VLProcessingInfo,
                                        dummy_inputs=Step3VLDummyInputsBuilder)
class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
                                      SupportsPP):
    """Step3-VL: a vision tower and projector feeding a registered text LM.

    Images are encoded by ``Step3VisionTransformer``, downsampled by two
    convolutions and projected by ``vit_large_projector`` to
    ``config.hidden_size``. The results are written into the token
    embeddings at positions that hold ``config.image_token_id``. The language
    model is built by ``init_vllm_registered_model`` from
    ``config.text_config``.
    """

    # Re-root checkpoint text weights under ``language_model``.
    hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={
        "model.": "language_model.model.",
        "lm_head.": "language_model.lm_head.",
    })

    supports_encoder_tp_data = True

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
        """Return the prompt placeholder string for ``modality``.

        Raises:
            ValueError: if ``modality`` does not start with ``"image"``.
        """
        if modality.startswith("image"):
            return "<im_patch>"

        raise ValueError("Only image modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__()

        config = vllm_config.model_config.hf_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.config = config
        self.multimodal_config = multimodal_config
        # Selects the data-parallel code path of the vision tower.
        self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"

        if multimodal_config.get_limit_per_prompt("image"):
            # The vision tower is built without a quantization config.
            self.vision_model = Step3VisionTransformer(
                config.vision_config,
                None,
                prefix=maybe_prefix(prefix, "vision_model"),
                use_data_parallel=self.use_data_parallel)
            # Projector:
            #   2x2 conv (stride = understand_projector_stride)
            #   -> 3x3 stride-2 conv that doubles the channels
            #   -> linear map to the language model hidden size.
            self.vit_downsampler = nn.Conv2d(
                config.vision_config.hidden_size,
                config.vision_config.output_hidden_size,
                kernel_size=2,
                stride=config.understand_projector_stride)
            self.vit_downsampler2 = nn.Conv2d(
                config.vision_config.output_hidden_size,
                config.vision_config.output_hidden_size * 2,
                kernel_size=3,
                stride=2,
                padding=1,
            )
            self.vit_large_projector = nn.Linear(
                config.vision_config.output_hidden_size * 2,
                config.hidden_size,
                bias=config.projector_bias,
            )
        else:
            # The image limit per prompt is falsy: no vision modules are
            # built, and load_weights skips their checkpoint tensors.
            self.vision_model = None
            self.vit_downsampler = None
            self.vit_downsampler2 = None
            self.vit_large_projector = None

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"))

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    @property
    def device(self):
        """Device of the first registered parameter."""
        return next(self.parameters()).device

    @property
    def dtype(self):
        """Dtype of the first registered parameter."""
        return next(self.parameters()).dtype

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[Step3VLImageInputs]:
        """Pop image inputs from ``kwargs`` and normalize their layout.

        Returns ``None`` when neither ``pixel_values`` nor ``image_embeds``
        is given. When ``pixel_values`` is present, the pixel path is used
        even if ``image_embeds`` is also given. All returned tensors are
        cast to the model dtype and moved to the model device.

        Raises:
            ValueError: if ``image_embeds`` has fewer than two dimensions.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        patch_pixel_values = kwargs.pop("patch_pixel_values", None)
        num_patches = kwargs.pop("num_patches", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            pixel_values = flatten_bn(pixel_values, concat=True)
            # Collapse all leading dims so the tensor is
            # (N, *last three dims).
            if pixel_values.dim() >= 3:
                pixel_values = pixel_values.view(-1, *pixel_values.shape[-3:])
            if patch_pixel_values is not None:
                patch_pixel_values = flatten_bn(patch_pixel_values,
                                                concat=True)
                patch_pixel_values = patch_pixel_values.view(
                    -1, *patch_pixel_values.shape[-3:])
                # Zero patches in total: treat as "no patch input".
                if patch_pixel_values.shape[0] == 0:
                    patch_pixel_values = None
            num_patches = flatten_bn(num_patches, concat=True).tolist()

            return Step3VLImagePixelInputs(
                type="pixel_values",
                pixel_values=pixel_values.to(self.dtype).to(self.device),
                patch_pixel_values=patch_pixel_values.to(self.dtype).to(
                    self.device) if patch_pixel_values is not None else None,
                num_patches=num_patches,
            )

        # image_embeds is not None here. Any rank >= 2 is flattened to
        # (num_tokens, last dim).
        if image_embeds.dim() < 2:
            raise ValueError(
                f"Unexpected shape for image_embeds: {image_embeds.shape}")
        image_embeds = image_embeds.view(-1, image_embeds.shape[-1])

        return Step3VLImageEmbeddingInputs(
            type="image_embeds",
            image_embeds=image_embeds.to(self.dtype).to(self.device),
        )

    def _process_image_features(self,
                                image_features: torch.Tensor) -> torch.Tensor:
        """Downsample vision tokens and project them to the LM hidden size.

        ``image_features`` is indexed as ``(B, P, C)``. The ``P`` tokens are
        assumed to form a square ``int(sqrt(P))`` x ``int(sqrt(P))`` grid.
        """
        B, P = image_features.shape[:2]
        HW = int(sqrt(P))
        # (B, P, C) -> (B, C, HW, HW) so the convolutions see a 2-D grid.
        image_features = image_features.permute(0, 2, 1).view(B, -1, HW, HW)
        image_features = self.vit_downsampler(image_features)
        image_features = self.vit_downsampler2(image_features)
        n_dim = image_features.size(1)
        # Back to a token sequence: (B, C', H'*W') -> (B, H'*W', C').
        image_features = image_features.view(B, n_dim, -1).permute(0, 2, 1)
        image_features = self.vit_large_projector(image_features)
        return image_features

    def _get_vision_model_output(self,
                                 input_tensor: torch.Tensor) -> torch.Tensor:
        """Run the vision tower and drop its first 4 output tokens.

        NOTE(review): the 4 dropped tokens are presumably non-patch tokens
        added by Step3VisionEmbeddings; confirm there.
        """
        return self.vision_model(input_tensor)[:, 4:]

    def _process_image_input(
            self, image_input: Step3VLImageInputs) -> NestedTensors:
        """Turn parsed image inputs into LM-space embeddings.

        For ``"image_embeds"`` input, the precomputed embeddings are returned
        unchanged; they bypass the vision tower and the projector. For pixel
        input, a list is returned with one 2-D tensor per image. Each tensor
        holds that image's patch-crop tokens (if any) followed by its
        global-view tokens.
        """
        if image_input["type"] == "image_embeds":
            # Already flattened to (num_tokens, hidden) by
            # _parse_and_validate_image_input. They must not be pushed
            # through the conv projector, which expects (B, P, C).
            return image_input["image_embeds"]

        image_features = self._get_vision_model_output(
            image_input["pixel_values"])
        patch_image_features = self._get_vision_model_output(
            image_input["patch_pixel_values"]
        ) if image_input["patch_pixel_values"] is not None else None
        num_patches = image_input["num_patches"]

        image_features = self._process_image_features(image_features)
        patch_image_features = self._process_image_features(
            patch_image_features) if patch_image_features is not None else None

        merged_image_features = []
        # Patch crops of all images are concatenated along dim 0.
        # cur_patch_idx walks that concatenation image by image.
        cur_patch_idx = 0
        for i, num_patch in enumerate(num_patches):
            cur_feature = []
            if num_patch > 0:
                patch_slice = patch_image_features[
                    cur_patch_idx:cur_patch_idx + num_patch]
                cur_feature.append(patch_slice.view(-1, patch_slice.shape[-1]))
            cur_feature.append(image_features[i].view(
                -1, image_features.shape[-1]))
            cur_patch_idx += num_patch
            merged_image_features.append(
                torch.cat(cur_feature) if len(cur_feature) >
                1 else cur_feature[0])
        return merged_image_features

    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
        """Return image embeddings for ``kwargs``, or ``None`` if none."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return None
        vision_embeddings = self._process_image_input(image_input)
        return vision_embeddings

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
    ) -> torch.Tensor:
        """Embed ``input_ids``, splicing in image embeddings if provided.

        ``None`` or an empty container of multimodal embeddings means
        text-only embedding.
        """
        if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
            return self.language_model.model.get_input_embeddings(input_ids)

        # Only text positions go through the embedding table. The image
        # positions (left uninitialized by torch.empty) are filled by
        # merge_multimodal_embeddings.
        # NOTE(review): presumably image_token_id is not looked up because it
        # may lie outside the text vocabulary; confirm.
        is_text = input_ids != self.config.image_token_id
        text_ids = input_ids[is_text]
        text_embeds = self.language_model.model.get_input_embeddings(text_ids)
        inputs_embeds = torch.empty(input_ids.shape[0],
                                    text_embeds.shape[-1],
                                    dtype=text_embeds.dtype,
                                    device=text_embeds.device)
        inputs_embeds[is_text] = text_embeds
        return merge_multimodal_embeddings(input_ids, inputs_embeds,
                                           multimodal_embeddings,
                                           self.config.image_token_id)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the language model, building embeddings on the first stage.

        With ``intermediate_tensors`` (a later pipeline stage), no embeddings
        are passed. Otherwise, if ``inputs_embeds`` is not given, they are
        built from ``input_ids`` plus any image inputs in ``kwargs``.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None
        elif inputs_embeds is None:
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
            # always pass the input via `inputs_embeds`
            # to make sure the computation graph is consistent
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      vision_embeddings)
            input_ids = None

        hidden_states = self.language_model(input_ids,
                                            positions,
                                            intermediate_tensors,
                                            inputs_embeds=inputs_embeds)

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> Optional[torch.Tensor]:
        """Delegate logit computation to the language model."""
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load checkpoint ``weights``, remapping HF prefixes.

        When the vision modules were not built (images disabled), their
        checkpoint tensors are skipped.
        """
        skip_prefixes = []
        if self.vision_model is None and self.vit_large_projector is None:
            skip_prefixes = [
                "vision_model.", "vit_downsampler.", "vit_downsampler2.",
                "vit_large_projector."
            ]

        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
        loaded_weights = loader.load_weights(weights,
                                             mapper=self.hf_to_vllm_mapper)
        return loaded_weights