mistral3.py 20.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5

from abc import abstractmethod
from collections.abc import Iterable, Mapping, Sequence
6
from typing import Annotated, Final, Literal, Protocol, TypeVar
7
8
9

import torch
import torch.nn as nn
10
11
12
13
14
15
from transformers import (
    BatchFeature,
    Mistral3Config,
    PixtralVisionConfig,
    PretrainedConfig,
)
16
17
18
from transformers.models.pixtral import PixtralProcessor

from vllm.config import VllmConfig
19
from vllm.config.multimodal import BaseDummyOptions
20
21
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.layernorm import RMSNorm
22
from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
23
from vllm.model_executor.layers.quantization import QuantizationConfig
24
from vllm.model_executor.models.module_mapping import MultiModelKeys
25
from vllm.multimodal import MULTIMODAL_REGISTRY
26
from vllm.multimodal.cache import BaseMultiModalProcessorCache
27
28
29
30
31
32
33
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
)
from vllm.multimodal.parse import ImageProcessorItems, ImageSize, MultiModalDataItems
from vllm.multimodal.processing import (
34
    BaseDummyInputsBuilder,
35
36
37
38
39
40
41
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    InputProcessingContext,
    PromptReplacement,
    PromptUpdate,
    PromptUpdateDetails,
)
42
from vllm.sequence import IntermediateTensors
43
from vllm.utils.tensor_schema import TensorSchema, TensorShape
44

45
46
from .interfaces import (
    MultiModalEmbeddings,
47
    SupportsEagle,
48
    SupportsEagle3,
49
50
51
52
    SupportsLoRA,
    SupportsMultiModal,
    SupportsPP,
)
53
from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel
54
55
56
from .utils import (
    AutoWeightsLoader,
    WeightsMapper,
57
    get_layer_index,
58
59
60
    init_vllm_registered_model,
    maybe_prefix,
)
61
from .vision import get_vision_encoder_info
62
63


64
class Mistral3ImagePixelInputs(TensorSchema):
    """
    Raw image pixels fed to the Pixtral vision tower of Mistral3.

    Dimensions:
        - bn: Batch size * number of images
        - c: Number of channels (3)
        - h: Height of each image
        - w: Width of each image
    """

    type: Literal["pixel_values_pixtral"] = "pixel_values_pixtral"

    # Height and width are allowed to vary from image to image ("h" and "w"
    # are dynamic dims). When they do, the images cannot be stacked and are
    # carried as a list of per-image tensors rather than one batched tensor.
    pixel_values: Annotated[
        torch.Tensor | list[torch.Tensor],
        TensorShape("bn", 3, "h", "w", dynamic_dims={"h", "w"}),
    ]

82
83
84
85
86
87

class Mistral3PatchMerger(nn.Module):
    """
    Learned merging of spatial_merge_size ** 2 patches
    """

88
89
90
    def __init__(
        self, vision_hidden_size: int, spatial_merge_size: int, patch_size: int
    ):
91
92
93
94
95
        super().__init__()

        self.vision_hidden_size = vision_hidden_size
        self.spatial_merge_size = spatial_merge_size
        self.patch_size = patch_size
96
97
98
99
100
        self.merging_layer = nn.Linear(
            vision_hidden_size * self.spatial_merge_size**2,
            vision_hidden_size,
            bias=False,
        )
101

102
103
104
105
106
107
108
    def forward(
        self, image_features: torch.Tensor, image_sizes: torch.Tensor
    ) -> torch.Tensor:
        image_sizes = [
            (image_size[0] // self.patch_size, image_size[1] // self.patch_size)
            for image_size in image_sizes
        ]
109
110
111
112
113
114

        tokens_per_image = [h * w for h, w in image_sizes]
        d = image_features.shape[-1]

        permuted_tensor = []
        for image_index, image_tokens in enumerate(
115
116
            image_features.split(tokens_per_image)
        ):
117
118
            # Reshape image_tokens into a 2D grid
            h, w = image_sizes[image_index]
119
            image_grid = image_tokens.view(h, w, d).permute(2, 0, 1).unsqueeze(0)
120
121
122
            grid = torch.nn.functional.unfold(
                image_grid,
                kernel_size=self.spatial_merge_size,
123
124
                stride=self.spatial_merge_size,
            )
125
126
127
128
129
130
131
132
133
            grid = grid.view(d * self.spatial_merge_size**2, -1).t()
            permuted_tensor.append(grid)

        image_features = torch.cat(permuted_tensor, dim=0)
        image_features = self.merging_layer(image_features)
        return image_features


class Mistral3MultiModalProjector(nn.Module):
    """Map Pixtral vision features into the text model's hidden space.

    The forward pass applies, in order: an RMSNorm over the vision features,
    the learned spatial patch merger, a column-parallel linear layer, the
    configured activation, and a row-parallel linear layer.
    """

    def __init__(
        self,
        vision_hidden_size: int,
        text_hidden_size: int,
        spatial_merge_size: int,
        patch_size: int,
        projector_hidden_act: str,
        multimodal_projector_bias: bool,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        norm_eps: float = 1e-5,
    ):
        """
        Args:
            vision_hidden_size: Hidden size of the vision encoder output.
            text_hidden_size: Hidden size of the language model.
            spatial_merge_size: Side length of the patch window that the
                patch merger folds into one token.
            patch_size: Vision encoder patch size in pixels.
            projector_hidden_act: Activation name passed to ``get_act_fn``.
            multimodal_projector_bias: Whether both linear layers use a bias.
            quant_config: Optional quantization config for the linear layers.
            prefix: Parameter-name prefix of this module.
            norm_eps: Epsilon of the input RMSNorm. Defaults to ``1e-5``,
                the value that was previously hard-coded here.
                NOTE(review): the HF reference implementation appears to use
                ``text_config.rms_norm_eps``; confirm before wiring it up.
        """
        super().__init__()

        self.norm = RMSNorm(vision_hidden_size, eps=norm_eps)
        self.patch_merger = Mistral3PatchMerger(
            vision_hidden_size=vision_hidden_size,
            spatial_merge_size=spatial_merge_size,
            patch_size=patch_size,
        )

        self.linear_1 = ColumnParallelLinear(
            vision_hidden_size,
            text_hidden_size,
            bias=multimodal_projector_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.linear_1",
        )
        self.act = get_act_fn(projector_hidden_act)
        self.linear_2 = RowParallelLinear(
            text_hidden_size,
            text_hidden_size,
            bias=multimodal_projector_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.linear_2",
        )

    def forward(
        self,
        image_features: torch.Tensor,
        image_sizes: Sequence[tuple[int, int]] | torch.Tensor,
    ) -> torch.Tensor:
        """Project the concatenated patch features of all images.

        Args:
            image_features: Vision-encoder patch features of every image,
                concatenated along the first dimension.
            image_sizes: Per-image ``(height, width)`` in pixels, forwarded
                unchanged to the patch merger.

        Returns:
            The projected features, one row per merged image token.
        """
        image_features = self.norm(image_features)
        image_features = self.patch_merger(image_features, image_sizes)
        # The parallel linear layers return (output, bias); the bias part is
        # unused here.
        hidden_states, _ = self.linear_1(image_features)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.linear_2(hidden_states)
        return hidden_states


class LlavaLikeConfig(Protocol):
    """Structural type describing the HF config attributes read in this file."""

    # Config of the vision encoder; asserted elsewhere to be a
    # PixtralVisionConfig for Mistral3.
    vision_config: Final[PretrainedConfig]
    # Token id that marks an image placeholder in the prompt.
    image_token_index: Final[int]
    vision_feature_select_strategy: Final[str]
    # Vision encoder layer index, or several indices, whose features are used.
    vision_feature_layer: Final[int | list[int]]
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202


class LlavaLikeProcessor(Protocol):
    """Structural type for HF processors that expose an image token string."""

    # Text of the image placeholder token (repeated to build dummy prompts).
    image_token: Final[str]


class BaseLlavaProcessingInfo(BaseProcessingInfo):
    """Processing info shared by LLaVA-style models with a Mistral3 config."""

    def get_hf_config(self) -> LlavaLikeConfig:
        # Mistral3Config is used structurally as a LlavaLikeConfig.
        hf_config = self.ctx.get_hf_config(Mistral3Config)
        return hf_config

    def get_vision_encoder_info(self):
        config = self.get_hf_config()
        return get_vision_encoder_info(config)

    @abstractmethod
    def get_hf_processor(self, **kwargs: object) -> LlavaLikeProcessor:
        raise NotImplementedError

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        # Only images are supported; None presumably means "no fixed limit".
        return {"image": None}

    def get_num_image_tokens(self, *, image_width: int, image_height: int) -> int:
        # Token counting is delegated entirely to the vision encoder info.
        encoder_info = self.get_vision_encoder_info()
        return encoder_info.get_num_image_tokens(
            image_width=image_width, image_height=image_height
        )

    def get_image_size_with_most_features(self) -> ImageSize:
        # The encoder's native image size, used as a square.
        side = self.get_vision_encoder_info().get_image_size()
        return ImageSize(width=side, height=side)


# Processing-info type accepted by the dummy-input builder and processor factory.
_I = TypeVar("_I", bound=BaseLlavaProcessingInfo)


class Mistral3DummyInputsBuilder(BaseDummyInputsBuilder[_I]):
    """Build placeholder text and images for Mistral3."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        # One image placeholder token per requested image.
        image_token = self.info.get_hf_processor().image_token
        return image_token * mm_counts.get("image", 0)

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions],
    ) -> MultiModalDataDict:
        # Size the dummy images using the size that yields the most features.
        width, height = self.info.get_image_size_with_most_features()
        dummy_images = self._get_dummy_images(
            width=width,
            height=height,
            num_images=mm_counts.get("image", 0),
            overrides=mm_options.get("image"),
        )
        return {"image": dummy_images}


class Mistral3ProcessingInfo(BaseLlavaProcessingInfo):
    """Processing info for Mistral3, backed by the HF ``PixtralProcessor``."""

    def get_hf_processor(self, **kwargs: object):
        # All kwargs are forwarded unchanged to the context's processor getter.
        hf_processor = self.ctx.get_hf_processor(PixtralProcessor, **kwargs)
        return hf_processor


263
class Mistral3MultiModalProcessor(BaseMultiModalProcessor[Mistral3ProcessingInfo]):
    """Multimodal processor that expands image placeholders for Mistral3."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor, then crop padding off every image.

        Raises:
            ValueError: If the HF processor returns different numbers of
                pixel tensors and image sizes.
        """
        processed_outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )

        pixel_values = processed_outputs.get("pixel_values")
        if pixel_values is not None:
            # Avoid padding since we need the output for each image to be
            # independent of other images for the cache to work correctly
            image_sizes = processed_outputs["image_sizes"]
            # Explicit check instead of `assert`: asserts are stripped under
            # `python -O`, after which `zip` below would silently drop images.
            if len(pixel_values) != len(image_sizes):
                raise ValueError(
                    f"HF processor returned {len(pixel_values)} pixel_values "
                    f"but {len(image_sizes)} image_sizes"
                )

            # Crop each padded (C, H, W) image back to its own (h, w).
            processed_outputs["pixel_values"] = [
                p[:, :h, :w] for p, (h, w) in zip(pixel_values, image_sizes)
            ]

        return processed_outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Both pixel values and image embeddings are batched per image."""
        return dict(
            pixel_values=MultiModalFieldConfig.batched("image"),
            image_embeds=MultiModalFieldConfig.batched("image"),
        )

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Expand every image token into that image's full token grid.

        An image whose patch grid is ``nrows x ncols`` becomes ``nrows`` rows
        of ``ncols`` image tokens. Each row is closed by an image-break token,
        except the last row, which is closed by an image-end token.
        """
        processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        hf_config = self.info.get_hf_config()
        tokenizer = self.info.get_tokenizer()
        vocab = tokenizer.get_vocab()

        image_break_id = vocab[processor.image_break_token]
        image_token_id = hf_config.image_token_index
        image_end_id = vocab[processor.image_end_token]

        assert isinstance(hf_config.vision_config, PixtralVisionConfig)
        encoder_info = PixtralHFEncoderInfo(hf_config)

        def get_replacement(item_idx: int):
            images = mm_items.get_items("image", ImageProcessorItems)
            image_size = images.get_image_size(item_idx)

            ncols, nrows = encoder_info.get_patch_grid_size(
                image_width=image_size.width,
                image_height=image_size.height,
            )

            tokens = ([image_token_id] * ncols + [image_break_id]) * nrows
            # The final row ends with the image-end token instead of a break.
            tokens[-1] = image_end_id

            # Only positions holding `image_token_id` are selected; presumably
            # these are where the vision embeddings are placed.
            return PromptUpdateDetails.select_token_id(tokens, image_token_id)

        return [
            PromptReplacement(
                modality="image",
                target=[image_token_id],
                replacement=get_replacement,
            ),
        ]


def _build_mistral3_info(
    ctx: InputProcessingContext,
) -> BaseLlavaProcessingInfo:
    """Create the processing info for a Mistral3 model.

    Only Pixtral vision encoders are supported; the assertion enforces that.
    """
    vision_config = ctx.get_hf_config(Mistral3Config).vision_config
    assert isinstance(vision_config, PixtralVisionConfig)
    return Mistral3ProcessingInfo(ctx)


def _build_mistral3_processor(
    info: _I,
    dummy_inputs: BaseDummyInputsBuilder[_I],
    *,
    cache: BaseMultiModalProcessorCache | None = None,
) -> BaseMultiModalProcessor:
    """Create the Mistral3 multimodal processor; ``info`` must be Mistral3's."""
    assert isinstance(info, Mistral3ProcessingInfo)
    processor = Mistral3MultiModalProcessor(
        info,
        dummy_inputs,  # type: ignore
        cache=cache,
    )
    return processor


def _get_num_hidden_layers(hf_config: LlavaLikeConfig) -> int:
    """Determine the number of hidden layers to initialize up to in the
    visual encoder.

    Args:
        hf_config: Model config with vision feature layer(s).

    Returns:
        ``get_layer_index`` applied to the single feature layer, or the
        maximum over all feature layers when several are configured.

    Raises:
        TypeError: If ``vision_feature_layer`` is neither an int nor a
            list/tuple.
        ValueError: If ``vision_feature_layer`` is an empty list/tuple.
    """
    feature_layers = hf_config.vision_feature_layer
    num_hidden_layers = hf_config.vision_config.num_hidden_layers
    # If we have one feature layer, initialize up to that layer
    if isinstance(feature_layers, int):
        return get_layer_index(feature_layers, num_hidden_layers)
    # If we have multiple feature layers, initialize up to the deepest one
    if isinstance(feature_layers, (list, tuple)):
        if not feature_layers:
            # max() on an empty sequence would fail with an opaque message.
            raise ValueError("vision_feature_layer must not be empty")
        return max(get_layer_index(idx, num_hidden_layers) for idx in feature_layers)

    raise TypeError(
        f"vision_feature_layer type: {type(feature_layers)} is not supported"
    )
382
383
384
385


def init_vision_tower_for_llava(
    hf_config: LlavaLikeConfig,
    quant_config: QuantizationConfig | None,
    *,
    require_post_norm: bool | None = None,
    prefix: str = "",
) -> PixtralHFVisionModel:
    """Build the Pixtral vision tower, truncated at the deepest feature layer.

    Args:
        hf_config: Model config; its ``vision_config`` must be a
            ``PixtralVisionConfig``.
        quant_config: Optional quantization config for the tower.
        require_post_norm: Forwarded unchanged to ``PixtralHFVisionModel``.
        prefix: Parameter-name prefix of the tower.
    """
    # Layers beyond the deepest requested feature layer are never built.
    layers_to_build = _get_num_hidden_layers(hf_config)

    vision_config = hf_config.vision_config
    assert isinstance(vision_config, PixtralVisionConfig)

    return PixtralHFVisionModel(
        vision_config,
        quant_config=quant_config,
        num_hidden_layers_override=layers_to_build,
        require_post_norm=require_post_norm,
        prefix=prefix,
    )


@MULTIMODAL_REGISTRY.register_processor(
    _build_mistral3_processor,
    info=_build_mistral3_info,
    dummy_inputs=Mistral3DummyInputsBuilder,
)
class Mistral3ForConditionalGeneration(
    nn.Module,
    SupportsLoRA,
    SupportsMultiModal,
    SupportsPP,
    SupportsEagle,
    SupportsEagle3,
):
    """Mistral3: Pixtral vision tower + projector + a registered text model."""

    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            # mapping for new names in checkpoint saved after transformers v4.52
            "model.language_model.": "language_model.model.",
            "model.vision_tower.": "vision_tower.",
            "model.multi_modal_projector.": "multi_modal_projector.",
            "lm_head.": "language_model.lm_head.",
            # Some PEFT LoRAs are trained against the text submodule directly
            # and produce names like `base_model.model.model.layers.*`.
            # NOTE(review): this catch-all must stay after the more specific
            # `model.*` prefixes above; confirm WeightsMapper applies prefixes
            # in insertion order.
            "model.": "language_model.model.",
        }
    )

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder text for an item of ``modality``.

        Raises:
            ValueError: For any modality other than image.
        """
        if modality.startswith("image"):
            # The prompt must contain the image token, because the prompt
            # replacement in `Mistral3MultiModalProcessor._get_prompt_updates`
            # targets `[image_token_id]`. Returning None meant no placeholder
            # was inserted for text prompts. "[IMG]" is assumed to be the
            # text form of `config.image_token_index` for Pixtral tokenizers.
            return "[IMG]"

        raise ValueError("Only image modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.config = config
        self.multimodal_config = multimodal_config

        # NOTE: These are special cases for Pixtral-12B in the HF-format
        # https://huggingface.co/mistral-community/pixtral-12b/blob/main/config.json  # noqa
        if (
            config.text_config.architectures is None
            and config.text_config.model_type == "mistral"
        ):
            config.text_config.architectures = ["MistralForCausalLM"]
        if (
            config.projector_hidden_act is None
            and config.vision_config.hidden_act == "gelu"
        ):
            config.projector_hidden_act = "gelu"

        with self._mark_tower_model(vllm_config, "image"):
            self.vision_tower = init_vision_tower_for_llava(
                config,
                quant_config=quant_config,
                require_post_norm=False,
                prefix=maybe_prefix(prefix, "vision_tower"),
            )
            self.multi_modal_projector = Mistral3MultiModalProjector(
                vision_hidden_size=config.vision_config.hidden_size,
                text_hidden_size=config.text_config.hidden_size,
                projector_hidden_act=config.projector_hidden_act,
                spatial_merge_size=config.spatial_merge_size,
                patch_size=config.vision_config.patch_size,
                multimodal_projector_bias=config.multimodal_projector_bias,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "multi_modal_projector"),
            )

        with self._mark_language_model(vllm_config):
            self.language_model = init_vllm_registered_model(
                vllm_config=vllm_config,
                hf_config=config.text_config,
                prefix=maybe_prefix(prefix, "language_model"),
            )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> Mistral3ImagePixelInputs | None:
        """Extract the pixel inputs from the multimodal kwargs.

        Returns ``None`` when neither pixels nor embeddings are present.

        Raises:
            ValueError: If only ``image_embeds`` are supplied. This model
                consumes raw pixels only; previously such input was passed on
                as ``pixel_values=None``.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is None:
            raise ValueError(
                "Mistral3 does not support precomputed image embeddings; "
                "pass `pixel_values` instead."
            )

        return Mistral3ImagePixelInputs(
            type="pixel_values_pixtral",
            pixel_values=pixel_values,
        )

    def _process_image_input(
        self,
        image_input: Mistral3ImagePixelInputs,
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Encode images and project them into the text embedding space.

        Returns a single tensor when the vision tower returns one tensor;
        otherwise a tuple with one tensor per image.
        """
        pixel_values = image_input["pixel_values"]
        # (height, width) in pixels of every image, from its last two dims.
        image_sizes = [(img.shape[-2], img.shape[-1]) for img in pixel_values]

        image_features = self.vision_tower(pixel_values)

        if isinstance(image_features, torch.Tensor):
            return self.multi_modal_projector(image_features, image_sizes)

        # The projector's patch merger shrinks each image's token count by
        # spatial_merge_size**2; these sizes split the output back per image.
        feature_sizes = [
            image_feature.shape[0] // self.config.spatial_merge_size**2
            for image_feature in image_features
        ]

        image_embeds = self.multi_modal_projector(
            torch.cat(image_features), image_sizes
        )
        if len(feature_sizes) > 1:
            image_embeds = torch.split(image_embeds, feature_sizes)
        else:
            image_embeds = (image_embeds,)
        return image_embeds

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        """Return the projected image embeddings, or ``[]`` without images."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []

        vision_embeddings = self._process_image_input(image_input)

        return vision_embeddings

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        """Run forward pass for Mistral3.

        `input_ids` already accounts for the positions of the to-be-inserted
        image embeddings. Before the model runs, the multimodal processor
        expands each image token in the prompt into the image's full token
        grid: `ncols` image tokens per row, with rows separated by
        image-break tokens and the last row ending in an image-end token (see
        `Mistral3MultiModalProcessor._get_prompt_updates`). This keeps
        `positions` and the attention metadata consistent with `input_ids`.

        This method does not run the vision tower itself. Image embeddings
        are produced by `embed_multimodal` and are presumably merged into
        `inputs_embeds` by the caller.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Position indices for the input tokens.
            intermediate_tensors: Intermediate tensors from prior forward pass.
            inputs_embeds: Optional tensor of input embeddings.

        Info:
            [`Mistral3ImagePixelInputs`][vllm.model_executor.models.mistral3.Mistral3ImagePixelInputs]
        """
        # Intermediate tensors from a previous pipeline stage take precedence
        # over any input embeddings.
        if intermediate_tensors is not None:
            inputs_embeds = None

        hidden_states = self.language_model.model(
            input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
        )

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        return MultiModelKeys.from_string_field(
            language_model="language_model",
            connector="multi_modal_projector",
            tower_model="vision_tower",
        )