deepseek_vl2.py 21.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
# adapted from https://github.com/deepseek-ai/DeepSeek-VL2/blob/faf18023f24b962b32d9f0a2d89e402a8d383a78/deepseek_vl2/models/modeling_deepseek_vl_v2.py
"""Inference-only Deepseek-VL2 model compatible with HuggingFace weights."""
6

7
import math
8
from collections.abc import Iterable, Mapping, Sequence
9
from typing import Annotated, Literal, TypeAlias
10
11
12
13
14

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
15
from transformers import BatchFeature
16
17

from vllm.config import VllmConfig
18
from vllm.config.multimodal import BaseDummyOptions
19
from vllm.distributed import get_tensor_model_parallel_world_size
20
from vllm.model_executor.layers.quantization import QuantizationConfig
21
from vllm.model_executor.models.transformers.utils import replace_linear_class
22
from vllm.multimodal import MULTIMODAL_REGISTRY
23
24
25
26
27
28
29
30
31
32
33
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
)
from vllm.multimodal.parse import (
    ImageEmbeddingItems,
    ImageProcessorItems,
    ImageSize,
    MultiModalDataItems,
)
34
35
from vllm.multimodal.processing import BaseDummyInputsBuilder
from vllm.multimodal.processing.processor import (
36
37
38
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    MultiModalProcessingInfo,
39
    ProcessorInputs,
40
41
    PromptReplacement,
    PromptUpdate,
42
    TimingContext,
43
)
44
from vllm.sequence import IntermediateTensors
45
from vllm.tokenizers import cached_tokenizer_from_config
46
47
48
49
50
51
from vllm.transformers_utils.configs.deepseek_vl2 import (
    DeepseekVLV2Config,
    MlpProjectorConfig,
    VisionEncoderConfig,
)
from vllm.transformers_utils.processors.deepseek_vl2 import DeepseekVLV2Processor
52
from vllm.utils.tensor_schema import TensorSchema, TensorShape
53
from vllm.utils.torch_utils import set_default_torch_dtype
54

55
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
56
57
58
59
60
61
from .utils import (
    AutoWeightsLoader,
    WeightsMapper,
    init_vllm_registered_model,
    maybe_prefix,
)
62
63
64
65
66

# The image token id may be various
_IMAGE_TOKEN = "<image>"


class DeepseekVL2ImagePixelInputs(TensorSchema):
    """
    Raw pixel inputs: every tile (global view + local tiles) of every image,
    flattened along the first dimension. The channel dimension is fixed to 3.

    Dimensions:
        - bnp: Batch size * number of images * number of patches (tiles)
        - bn: Batch size * number of images
        - h: Height of each tile
        - w: Width of each tile
    """

    type: Literal["pixel_values"]
    data: Annotated[torch.Tensor, TensorShape("bnp", 3, "h", "w", dynamic_dims={"bnp"})]
    # One row per image: (num_width_tiles, num_height_tiles).
    images_spatial_crop: Annotated[torch.Tensor, TensorShape("bn", 2)]


class DeepseekVL2VImageEmbeddingInputs(TensorSchema):
    """
    Precomputed image embeddings, passed through to the language model
    without running the vision tower or projector.

    Dimensions:
        - bn: Batch size * number of images
        - f: Image feature size
        - h: Hidden size (must match language model backbone)
    """

    type: Literal["image_embeds"]
    data: Annotated[torch.Tensor | list[torch.Tensor], TensorShape("bn", "f", "h")]


# Either representation of image inputs accepted by the model: raw pixels
# (encoded by the vision tower) or ready-made embeddings.
DeepseekVL2ImageInputs: TypeAlias = (
    DeepseekVL2ImagePixelInputs | DeepseekVL2VImageEmbeddingInputs
)


class MlpProjector(nn.Module):
    """Project vision-encoder tokens into the language model's embedding space.

    Supported ``cfg.projector_type`` values:

    * ``"downsample_mlp_gelu"``: the (square) token grid is zero-padded to a
      multiple of ``cfg.downsample_ratio``. Each ``ratio x ratio`` neighbourhood
      is merged into one token by channel concatenation. The merged tokens then
      pass through a GELU MLP (``cfg.depth`` linear layers, minimum two).
    * ``"linear"``: a single token-wise linear layer.
    """

    def __init__(self, cfg: MlpProjectorConfig):
        super().__init__()

        self.cfg = cfg
        self.projector_type = cfg.projector_type
        if cfg.token_pooling:
            raise NotImplementedError("Token pooling is not supported currently.")

        if self.projector_type == "downsample_mlp_gelu":
            mlp_depth = cfg.depth
            mlp_ratio = cfg.mlp_ratio
            # The first layer consumes a ratio x ratio neighbourhood whose
            # channels were concatenated in forward().
            modules = [
                nn.Linear(
                    cfg.input_dim * cfg.downsample_ratio * cfg.downsample_ratio,
                    cfg.n_embed * mlp_ratio,
                )
            ]
            for _ in range(1, mlp_depth - 1):
                modules.append(nn.GELU())
                modules.append(
                    nn.Linear(cfg.n_embed * mlp_ratio, cfg.n_embed * mlp_ratio)
                )
            modules.append(nn.GELU())
            modules.append(nn.Linear(cfg.n_embed * mlp_ratio, cfg.n_embed))
            modules = nn.Sequential(*modules)
        elif self.projector_type == "linear":
            modules = nn.Linear(cfg.input_dim, cfg.n_embed)
        else:
            raise NotImplementedError(
                f"Unsupported projector type: {cfg.projector_type}"
            )

        self.layers = modules

    def forward(self, x):
        """Project ``x`` of shape ``(batch, num_tokens, input_dim)``.

        For ``"downsample_mlp_gelu"`` the ``num_tokens`` tokens must form a
        square ``side x side`` grid. The output then has
        ``ceil(side / downsample_ratio) ** 2`` tokens.

        Raises:
            ValueError: if ``num_tokens`` is not a perfect square when the
                downsampling projector is used.
        """
        bs, hw, input_dim = x.shape

        if self.projector_type == "downsample_mlp_gelu":
            ratio = self.cfg.downsample_ratio
            # Exact integer square root (float sqrt can round for large hw).
            h = w = math.isqrt(hw)
            if h * w != hw:
                raise ValueError(
                    f"Expected a square grid of vision tokens, got {hw} tokens."
                )
            # Pad each side up to the next multiple of the downsample ratio.
            pad = (ratio - h % ratio) % ratio
            x = x.reshape(bs, h, w, input_dim)
            if pad > 0:
                # F.pad lists pads from the last dim backwards:
                # channels untouched, width padded right, height padded bottom.
                x = F.pad(x, (0, 0, 0, pad, 0, pad), "constant", 0)
            # Merge each ratio x ratio window by concatenating its channels.
            x = x.permute(0, 3, 1, 2)  # B, C, H, W
            x = F.unfold(
                x,
                kernel_size=ratio,
                stride=ratio,
                padding=0,
            )  # B, C * ratio**2, num_windows
            x = x.permute(0, 2, 1)  # B, num_windows, C * ratio**2

        return self.layers(x)


class DeepseekVL2ProcessingInfo(BaseProcessingInfo):
    """Access to the DeepSeek-VL2 HF config/processor and image-token counts."""

    def get_hf_config(self):
        return self.ctx.get_hf_config(DeepseekVLV2Config)

    def get_hf_processor(self, **kwargs: object):
        return self.ctx.get_hf_processor(DeepseekVLV2Processor, **kwargs)

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        # Only images are supported; no explicit per-prompt count is declared.
        return {"image": None}

    def get_num_image_tokens(
        self, *, image_width: int, image_height: int, cropping: bool = True
    ) -> int:
        """Return how many placeholder tokens a single image expands to.

        Args:
            image_width: Image width in pixels.
            image_height: Image height in pixels.
            cropping: If True, the local-tile grid is taken from the HF
                processor's ``select_best_resolution``. Otherwise a single
                ``1 x 1`` local tile is assumed.

        Returns:
            The sum of three parts:

            * ``h * (w + 1)`` for the global view (one newline token per row).
            * ``(th * h) * (tw * w + 1)`` for the ``th x tw`` grid of local tiles.
            * One view-separator token.
        """
        hf_processor = self.get_hf_processor()
        image_size = hf_processor.image_size
        patch_size = hf_processor.patch_size
        downsample_ratio = hf_processor.downsample_ratio

        if cropping:
            best_width, best_height = hf_processor.select_best_resolution(
                (image_width, image_height)
            )
            num_width_tiles = best_width // image_size
            num_height_tiles = best_height // image_size
        else:
            num_width_tiles = num_height_tiles = 1

        # Side of one tile's token grid after the projector pads and
        # downsamples it (hence the ceil).
        h = w = math.ceil((image_size // patch_size) / downsample_ratio)

        global_views_tokens = h * (w + 1)
        local_views_tokens = (num_height_tiles * h) * (num_width_tiles * w + 1)
        return global_views_tokens + local_views_tokens + 1

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return the candidate resolution that yields the most image tokens.

        Entries of ``candidate_resolutions`` are unpacked as (height, width).
        """
        hf_config = self.get_hf_config()
        candidate_resolutions = hf_config.candidate_resolutions
        height, width = max(
            candidate_resolutions,
            key=lambda x: self.get_num_image_tokens(
                image_width=x[1], image_height=x[0]
            ),
        )
        return ImageSize(width=width, height=height)


class DeepseekVL2DummyInputsBuilder(BaseDummyInputsBuilder[DeepseekVL2ProcessingInfo]):
    """Builds dummy prompt text and maximally-sized dummy images."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one image placeholder token per requested image."""
        num_images = mm_counts.get("image", 0)

        processor = self.info.get_hf_processor()
        image_token = processor.image_token

        return image_token * num_images

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions],
    ) -> MultiModalDataDict:
        """Return dummy images at the resolution that produces the most tokens.

        ``seq_len`` is not used here. Per-modality overrides from
        ``mm_options["image"]`` are forwarded to the image generator.
        """
        num_images = mm_counts.get("image", 0)

        max_image_size = self.info.get_image_size_with_most_features()

        image_overrides = mm_options.get("image")

        return {
            "image": self._get_dummy_images(
                width=max_image_size.width,
                height=max_image_size.height,
                num_images=num_images,
                overrides=image_overrides,
            )
        }


class DeepseekVL2MultiModalProcessor(
    BaseMultiModalProcessor[DeepseekVL2ProcessingInfo]
):
    """Expands each image placeholder into the number of tokens the vision
    pipeline produces for that image."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor and attach a per-image ``num_patches`` count.

        Text-only prompts bypass the HF processor and are tokenized directly.
        """
        if not mm_data:
            tokenizer = self.info.get_tokenizer()
            return tokenizer(prompt, add_special_tokens=True, return_tensors="pt")

        processed_outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )

        # Tiles per image: num_width_tiles * num_height_tiles local tiles plus
        # one global view. This is used to split the flat pixel_values per image.
        processed_outputs["num_patches"] = (
            processed_outputs["images_spatial_crop"].prod(-1) + 1
        )

        return processed_outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Describe how each processor output is split across images."""
        num_patches = hf_inputs.get("num_patches", torch.empty(0))

        return dict(
            # All tiles of all images are concatenated; split by num_patches.
            pixel_values=MultiModalFieldConfig.flat_from_sizes("image", num_patches),
            images_spatial_crop=MultiModalFieldConfig.batched("image"),
            image_embeds=MultiModalFieldConfig.batched("image"),
        )

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Replace each image token with that image's full token span."""
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)

        image_token_id = hf_processor.image_token_id
        assert isinstance(image_token_id, int)

        def get_replacement_deepseek_vl2(item_idx: int):
            images = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems)
            )

            if isinstance(images, ImageEmbeddingItems):
                num_image_tokens = images.get_feature_size(item_idx)
            else:
                image_size = images.get_image_size(item_idx)

                num_image_tokens = self.info.get_num_image_tokens(
                    image_width=image_size.width,
                    image_height=image_size.height,
                    # Tiling is only used when the prompt has at most 2 images.
                    cropping=len(images) <= 2,
                )
            return [image_token_id] * num_image_tokens

        return [
            PromptReplacement(
                modality="image",
                target=[image_token_id],
                replacement=get_replacement_deepseek_vl2,
            )
        ]

    def _cached_apply_hf_processor(
        self,
        inputs: ProcessorInputs,
        timing_ctx: TimingContext,
    ) -> tuple[list[int], MultiModalProcessingInfo, bool]:
        # The processor logic is different for len(images) <= 2 vs > 2.
        # The processing cache assumes the output for an image does not
        # depend on how many images are in the prompt, so caching is only
        # used for the common (<= 2 images) case.
        if inputs.mm_data_items.get_count("image", strict=False) > 2:
            return self._apply_hf_processor(inputs, timing_ctx)

        return super()._cached_apply_hf_processor(inputs, timing_ctx)


@MULTIMODAL_REGISTRY.register_processor(
    DeepseekVL2MultiModalProcessor,
    info=DeepseekVL2ProcessingInfo,
    dummy_inputs=DeepseekVL2DummyInputsBuilder,
)
class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
    """DeepSeek-VL2: timm SigLIP ViT + MLP projector + vLLM text model.

    Each image becomes one embedding sequence built from a global view and a
    grid of local tiles. Their order is set by ``config.global_view_pos``. A
    learned newline embedding ends every row, and a learned separator embedding
    sits between the two views.
    """

    # The checkpoint stores the text model under "language."; this module
    # stores it as ``language_model``.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "language.": "language_model.",
        }
    )

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder for an item of ``modality``.

        Raises:
            ValueError: for any non-image modality.
        """
        if modality.startswith("image"):
            return _IMAGE_TOKEN

        raise ValueError("Only image modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config: DeepseekVLV2Config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.config = config
        self.multimodal_config = multimodal_config

        self.vision_config = config.vision_config
        self.projector_config = config.projector_config
        self.text_config = config.text_config

        model_config = vllm_config.model_config
        tokenizer = cached_tokenizer_from_config(model_config)
        # The placeholder's token id is tokenizer-specific, so look it up.
        self.image_token_id: int = tokenizer.vocab[_IMAGE_TOKEN]

        with self._mark_tower_model(vllm_config, "image"):
            self.vision = self._init_vision_module(
                self.vision_config, quant_config, maybe_prefix(prefix, "vision")
            )

            self.projector = MlpProjector(self.projector_config)
            self.tile_tag = config.tile_tag
            self.global_view_pos = config.global_view_pos

            # special token for image token sequence format
            embed_std = 1 / torch.sqrt(
                torch.tensor(self.projector_config.n_embed, dtype=torch.float32)
            )
            if self.tile_tag == "2D":
                # <|view_seperator|>, <|\n|>
                self.image_newline = nn.Parameter(
                    torch.randn(self.projector_config.n_embed) * embed_std
                )
                # The misspelling "seperator" comes from the original
                # implementation; the attribute name is kept as-is.
                self.view_seperator = nn.Parameter(
                    torch.randn(self.projector_config.n_embed) * embed_std
                )
            else:
                raise ValueError(
                    f"Only 2D tile_tag is supported currently, got: {self.tile_tag}"
                )

        with self._mark_language_model(vllm_config):
            self.language_model = init_vllm_registered_model(
                vllm_config=vllm_config,
                hf_config=self.text_config,
                prefix=maybe_prefix(prefix, "language"),
            )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def _get_parent_and_attr(self, root: torch.nn.Module, dotted_name: str):
        """Return (parent_module, final_attr_name) for a dotted module path."""
        names = dotted_name.split(".")
        parent = root
        for n in names[:-1]:
            parent = getattr(parent, n)
        return parent, names[-1]

    def patch_vit_for_tp(
        self, vit: torch.nn.Module, quant_config: QuantizationConfig | None
    ):
        """Swap the ``fc1``/``fc2`` linears of timm ``Mlp`` blocks in ``vit``
        for tensor-parallel equivalents, in place.

        ``fc1`` becomes column-parallel and ``fc2`` becomes row-parallel.

        NOTE(review): ``prefix`` is the path relative to ``vit``, not the
        full model path. Confirm this is what quantization configs expect.

        Returns:
            The same ``vit`` object.
        """
        try:
            import timm
        except ImportError as e:
            raise ImportError("Please install timm") from e

        tp_styles = {"fc1": "colwise", "fc2": "rowwise"}

        # Snapshot the module list first, because the loop replaces submodules
        # of the tree it would otherwise be walking lazily.
        for name, module in list(vit.named_modules()):
            if not isinstance(module, nn.Linear):
                continue
            parent, attr_name = self._get_parent_and_attr(vit, name)
            style = tp_styles.get(attr_name)
            if style is None or not isinstance(parent, timm.layers.Mlp):
                continue
            new_linear = replace_linear_class(
                module, style, quant_config, prefix=name
            )
            setattr(parent, attr_name, new_linear)

        return vit

    def _init_vision_module(
        self,
        vision_config: VisionEncoderConfig,
        quant_config: QuantizationConfig | None,
        prefix: str = "",
    ) -> nn.Module:
        """Create the timm SigLIP-so400m/14 (384) vision tower.

        The tower is created with ``pretrained=False``; its weights come from
        ``load_weights``. Under tensor parallelism its MLPs are patched for TP.
        The tower is finally cast to the current default dtype.

        NOTE(review): ``vision_config`` and ``prefix`` are currently unused;
        the architecture is hard-coded.
        """
        # TODO: refactor vision model through timm wrapper from transformers
        try:
            import timm
        except ImportError as e:
            raise ImportError("Please install timm") from e

        with set_default_torch_dtype(torch.float16):
            model = timm.create_model(
                "vit_so400m_patch14_siglip_384.webli",
                pretrained=False,
                num_classes=0,
                dynamic_img_size=True,
                dynamic_img_pad=True,
            )

        if get_tensor_model_parallel_world_size() > 1:
            model = self.patch_vit_for_tp(model, quant_config)

        model = model.to(dtype=torch.get_default_dtype())
        return model

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> DeepseekVL2ImageInputs | None:
        """Build a validated image-input schema from model kwargs.

        Pixel values take precedence over image embeddings. Returns None when
        neither is present.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        images_spatial_crop = kwargs.pop("images_spatial_crop", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            # Every tile is a square of the vision encoder's input size.
            expected_h = expected_w = self.vision_config.image_size
            return DeepseekVL2ImagePixelInputs(
                type="pixel_values",
                data=pixel_values,
                images_spatial_crop=images_spatial_crop,
                resolve_bindings={
                    "h": expected_h,
                    "w": expected_w,
                },
            )

        if image_embeds is not None:
            return DeepseekVL2VImageEmbeddingInputs(
                type="image_embeds",
                data=image_embeds,
            )

        raise AssertionError("This line should be unreachable.")

    def _pixel_values_to_embedding(
        self,
        pixel_values: torch.Tensor,
        images_spatial_crop: torch.Tensor,
    ) -> list[torch.Tensor]:
        """Encode all tiles and arrange them into one sequence per image.

        Args:
            pixel_values: For each image in order, one global-view tile
                followed by its local tiles.
            images_spatial_crop: Row ``i`` is ``(num_width_tiles,
                num_height_tiles)`` for image ``i``. Processing stops at the
                first row that contains a zero.

        Returns:
            One ``[num_tokens, D]`` tensor per image, laid out as global view
            (with row newlines), view separator, and local views (with row
            newlines). The order of the two views is swapped when
            ``global_view_pos != "head"``.
        """
        # [batch_all_tiles, vit_seq_len, c]
        images_feature = self.vision.forward_features(pixel_values)

        # [batch_all_tiles, hw, D]
        images_embeds = self.projector(images_feature)

        _, hw, n_dim = images_embeds.shape
        # Exact integer side length of each tile's square token grid.
        h = w = math.isqrt(hw)

        # fill image token based on self.tile_tag & self.global_view_pos
        tile_index = 0
        vision_embeddings = []
        for jdx in range(images_spatial_crop.size(0)):
            # Convert to Python ints once, instead of comparing, slicing and
            # sizing einops axes with 0-d tensors.
            num_width_tiles, num_height_tiles = images_spatial_crop[jdx].tolist()
            if num_width_tiles == 0 or num_height_tiles == 0:
                break
            num_tiles_in_image = num_width_tiles * num_height_tiles

            # [hw, D]
            global_features = images_embeds[tile_index]

            # [num_height_tiles * num_width_tiles, hw, D]
            local_features = images_embeds[
                tile_index + 1 : tile_index + 1 + num_tiles_in_image
            ]
            tile_index += num_tiles_in_image + 1

            # ----------------- global view add newline -----------------
            # [hw, D] -> [h, w, D]
            global_features = global_features.view(h, w, n_dim)

            # [D] -> [h, 1, D]
            new_lines_in_global = repeat(self.image_newline, "d -> h 1 d", h=h)

            # cat([h, w, D], [h, 1, D], dim=1) -> [h, w + 1, D]
            global_features = torch.cat([global_features, new_lines_in_global], dim=1)

            # [h, w + 1, D] -> [h * (w + 1), D]
            global_features = global_features.view(-1, n_dim)

            # ----------------- local view add newline -----------------
            # [num_height_tiles * num_width_tiles, h * w, D] ->
            # [num_height_tiles * h, num_width_tiles * w, D]
            local_features = rearrange(
                local_features,
                "(th tw) (h w) d -> (th h) (tw w) d",
                th=num_height_tiles,
                tw=num_width_tiles,
                h=h,
                w=w,
            )

            # [D] -> [num_height_tiles * h, 1, D]
            new_lines_in_local = repeat(
                self.image_newline, "d -> (th h) 1 d", th=num_height_tiles, h=h
            )

            # [num_height_tiles * h, num_width_tiles * w + 1, D]
            local_features = torch.cat([local_features, new_lines_in_local], dim=1)

            # [num_height_tiles * h, num_width_tiles * w + 1, D]
            #   --> [(num_height_tiles * h) * (num_width_tiles * w + 1), D]
            local_features = local_features.view(-1, n_dim)

            # merge global and local tiles around the view separator
            if self.global_view_pos == "head":
                global_local_features = torch.cat(
                    [
                        global_features,
                        self.view_seperator[None, :],
                        local_features,
                    ]
                )
            else:
                global_local_features = torch.cat(
                    [
                        local_features,
                        self.view_seperator[None, :],
                        global_features,
                    ]
                )

            vision_embeddings.append(global_local_features)
        return vision_embeddings

    def _process_image_input(
        self, image_input: DeepseekVL2ImageInputs
    ) -> torch.Tensor | list[torch.Tensor]:
        """Return embeddings for validated image inputs.

        Precomputed embeddings are returned unchanged.
        """
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        pixel_values = image_input["data"]
        images_spatial_crop = image_input["images_spatial_crop"]

        return self._pixel_values_to_embedding(
            pixel_values=pixel_values, images_spatial_crop=images_spatial_crop
        )

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        """Return per-image embeddings, or an empty list when there are no images."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []
        vision_embeddings = self._process_image_input(image_input)
        return vision_embeddings

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ):
        """Run the language model.

        ``inputs_embeds`` is ignored when ``intermediate_tensors`` is given.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        hidden_states = self.language_model(
            input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
        )

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, renaming "language." to "language_model."."""
        loader = AutoWeightsLoader(self)
        autoloaded_weights = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
        return autoloaded_weights