mistral3.py 19.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4

from collections.abc import Iterable, Mapping, Sequence
5
from typing import Annotated, Literal
6
7
8

import torch
import torch.nn as nn
9
from transformers import BatchFeature, Mistral3Config, PixtralVisionConfig
10
11
12
from transformers.models.pixtral import PixtralProcessor

from vllm.config import VllmConfig
13
from vllm.config.multimodal import BaseDummyOptions
14
from vllm.inputs import MultiModalDataDict
15
16
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.layernorm import RMSNorm
17
from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
18
from vllm.model_executor.layers.quantization import QuantizationConfig
19
from vllm.model_executor.models.module_mapping import MultiModelKeys
20
from vllm.multimodal import MULTIMODAL_REGISTRY
21
22
23
24
25
26
from vllm.multimodal.inputs import (
    MultiModalFieldConfig,
    MultiModalKwargsItems,
)
from vllm.multimodal.parse import ImageProcessorItems, ImageSize, MultiModalDataItems
from vllm.multimodal.processing import (
27
    BaseDummyInputsBuilder,
28
29
30
31
32
33
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptReplacement,
    PromptUpdate,
    PromptUpdateDetails,
)
34
from vllm.sequence import IntermediateTensors
35
from vllm.utils.tensor_schema import TensorSchema, TensorShape
36

37
38
from .interfaces import (
    MultiModalEmbeddings,
39
    SupportsEagle,
40
    SupportsEagle3,
41
42
43
44
    SupportsLoRA,
    SupportsMultiModal,
    SupportsPP,
)
45
from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel
46
47
48
from .utils import (
    AutoWeightsLoader,
    WeightsMapper,
49
    get_layer_index,
50
51
52
    init_vllm_registered_model,
    maybe_prefix,
)
53
from .vision import get_vision_encoder_info
54
55


56
class Mistral3ImagePixelInputs(TensorSchema):
    """
    Raw Pixtral-style pixel inputs consumed by Mistral3's vision tower.

    Dimensions:
        - bn: Batch size * number of images
        - c: Number of channels (3)
        - h: Height of each image
        - w: Width of each image
    """

    # Tag identifying this input variant; read back via `image_input["type"]`.
    type: Literal["pixel_values_pixtral"] = "pixel_values_pixtral"

    # Images in one batch may have different `h`/`w`; when they do, the
    # pixels arrive as a list of per-image tensors instead of one stacked
    # tensor (hence `h` and `w` are declared dynamic).
    pixel_values: Annotated[
        torch.Tensor | list[torch.Tensor],
        TensorShape("bn", 3, "h", "w", dynamic_dims={"h", "w"}),
    ]

74
75
76
77
78
79

class Mistral3PatchMerger(nn.Module):
    """
    Learned merging of spatial_merge_size ** 2 patches.

    Each image's patch embeddings are laid out on that image's patch grid.
    Every non-overlapping ``spatial_merge_size x spatial_merge_size`` window
    is flattened into one vector, and ``merging_layer`` projects that vector
    back down to ``vision_hidden_size``.
    """

    def __init__(
        self, vision_hidden_size: int, spatial_merge_size: int, patch_size: int
    ):
        super().__init__()

        self.vision_hidden_size = vision_hidden_size
        self.spatial_merge_size = spatial_merge_size
        self.patch_size = patch_size

        self.merging_layer = nn.Linear(
            vision_hidden_size * self.spatial_merge_size**2,
            vision_hidden_size,
            bias=False,
        )

    def forward(
        self,
        image_features: torch.Tensor,
        image_sizes: Sequence[tuple[int, int]],
    ) -> torch.Tensor:
        """Merge neighbouring patch embeddings of every image.

        Args:
            image_features: Patch embeddings of all images concatenated along
                dim 0, i.e. ``(total_patches, hidden)``. The patches of each
                image are viewed as a row-major ``(rows, cols)`` grid.
            image_sizes: Per-image ``(height, width)`` in pixels, in the same
                order as the images in ``image_features``.

        Returns:
            Tensor of shape ``(total_merged_tokens, vision_hidden_size)``.
        """
        merge = self.spatial_merge_size
        dim = image_features.shape[-1]

        # Pixel sizes -> patch-grid sizes (rows, cols).
        grid_sizes = [
            (size[0] // self.patch_size, size[1] // self.patch_size)
            for size in image_sizes
        ]
        tokens_per_image = [rows * cols for rows, cols in grid_sizes]

        merged_per_image = []
        for image_tokens, (rows, cols) in zip(
            image_features.split(tokens_per_image), grid_sizes
        ):
            # (rows * cols, dim) -> (1, dim, rows, cols) so unfold slides
            # over the spatial patch grid.
            image_grid = image_tokens.view(rows, cols, dim)
            image_grid = image_grid.permute(2, 0, 1).unsqueeze(0)
            # unfold yields (1, dim * merge**2, num_windows): each column is
            # one merge x merge window, flattened channel-major.
            windows = torch.nn.functional.unfold(
                image_grid,
                kernel_size=merge,
                stride=merge,
            )
            merged_per_image.append(windows.view(dim * merge**2, -1).t())

        return self.merging_layer(torch.cat(merged_per_image, dim=0))


class Mistral3MultiModalProjector(nn.Module):
    """Map Pixtral vision-tower outputs into the text model's hidden space.

    Applies, in order: RMSNorm over the vision features, spatial patch
    merging (``Mistral3PatchMerger``), ``linear_1``, the configured
    activation, and ``linear_2``.
    """

    def __init__(
        self,
        vision_hidden_size: int,
        text_hidden_size: int,
        spatial_merge_size: int,
        patch_size: int,
        projector_hidden_act: str,
        multimodal_projector_bias: bool,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()

        self.norm = RMSNorm(vision_hidden_size, eps=1e-5)
        self.patch_merger = Mistral3PatchMerger(
            vision_hidden_size=vision_hidden_size,
            spatial_merge_size=spatial_merge_size,
            patch_size=patch_size,
        )

        self.linear_1 = ColumnParallelLinear(
            vision_hidden_size,
            text_hidden_size,
            bias=multimodal_projector_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.linear_1",
        )
        self.act = get_act_fn(projector_hidden_act)
        self.linear_2 = RowParallelLinear(
            text_hidden_size,
            text_hidden_size,
            bias=multimodal_projector_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.linear_2",
        )

    def forward(
        self,
        image_features: torch.Tensor,
        image_sizes: Sequence[tuple[int, int]],
    ) -> torch.Tensor:
        """Project vision features to text-model embeddings.

        Args:
            image_features: Vision-tower outputs of all images concatenated
                along dim 0.
            image_sizes: Per-image ``(height, width)`` in pixels; used by the
                patch merger to recover each image's patch grid.

        Returns:
            One embedding of size ``text_hidden_size`` per merged token.
        """
        image_features = self.norm(image_features)
        image_features = self.patch_merger(image_features, image_sizes)
        # The parallel linear layers return a 2-tuple; only the first
        # element (the output tensor) is used here.
        hidden_states, _ = self.linear_1(image_features)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.linear_2(hidden_states)
        return hidden_states


173
174
class Mistral3ProcessingInfo(BaseProcessingInfo):
    """Config, processor and token-count lookups for Mistral3 inputs."""

    def get_hf_config(self) -> Mistral3Config:
        """Return the HF model config, typed as ``Mistral3Config``."""
        return self.ctx.get_hf_config(Mistral3Config)

    def get_vision_encoder_info(self):
        """Return the vision-encoder helper built from the HF config."""
        hf_config = self.get_hf_config()
        return get_vision_encoder_info(hf_config)

    def get_hf_processor(self, **kwargs: object):
        """Return the HF ``PixtralProcessor``, forwarding any overrides."""
        return self.ctx.get_hf_processor(PixtralProcessor, **kwargs)

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        """Only the image modality is supported; no explicit count is set.

        NOTE(review): ``None`` presumably means "no fixed per-prompt image
        limit" in vLLM's convention — confirm against BaseProcessingInfo.
        """
        limits: dict[str, int | None] = {"image": None}
        return limits

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Delegate the per-image token count to the vision encoder info."""
        encoder_info = self.get_vision_encoder_info()
        num_tokens = encoder_info.get_num_image_tokens(
            image_width=image_width,
            image_height=image_height,
        )
        return num_tokens

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return a square image whose side is the encoder's image size."""
        side = self.get_vision_encoder_info().get_image_size()
        return ImageSize(width=side, height=side)


204
class Mistral3DummyInputsBuilder(BaseDummyInputsBuilder[Mistral3ProcessingInfo]):
    """Build dummy prompts and images for Mistral3 (e.g. for profiling)."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one image placeholder token per requested image."""
        image_count = mm_counts.get("image", 0)
        placeholder = self.info.get_hf_processor().image_token
        return placeholder * image_count

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions],
    ) -> MultiModalDataDict:
        """Return dummy images sized to produce the most image features."""
        image_count = mm_counts.get("image", 0)
        width, height = self.info.get_image_size_with_most_features()
        overrides = mm_options.get("image")

        dummy_images = self._get_dummy_images(
            width=width,
            height=height,
            num_images=image_count,
            overrides=overrides,
        )
        return {"image": dummy_images}


235
class Mistral3MultiModalProcessor(BaseMultiModalProcessor[Mistral3ProcessingInfo]):
    """Multimodal processor for Mistral3 with Pixtral-style image prompts."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor, then crop padding off each image's pixels."""
        outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )

        padded = outputs.get("pixel_values")
        if padded is None:
            return outputs

        # Crop every image back to its own (h, w) instead of keeping batch
        # padding, so each image's output does not depend on the other
        # images in the batch (needed for the cache to work correctly).
        sizes = outputs["image_sizes"]
        assert len(padded) == len(sizes)

        outputs["pixel_values"] = [
            pixels[:, :height, :width] for pixels, (height, width) in zip(padded, sizes)
        ]
        return outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Both image fields are batched per image item."""
        per_image = MultiModalFieldConfig.batched
        return {
            "pixel_values": per_image("image"),
            "image_embeds": per_image("image"),
        }

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Expand each image placeholder into its full token grid."""
        processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        hf_config = self.info.get_hf_config()
        vocab = self.info.get_tokenizer().get_vocab()

        break_id = vocab[processor.image_break_token]
        img_id = hf_config.image_token_index
        end_id = vocab[processor.image_end_token]

        assert isinstance(hf_config.vision_config, PixtralVisionConfig)
        encoder_info = PixtralHFEncoderInfo(hf_config)

        def get_replacement(item_idx: int):
            image_items = mm_items.get_items("image", ImageProcessorItems)
            size = image_items.get_image_size(item_idx)

            ncols, nrows = encoder_info.get_patch_grid_size(
                image_width=size.width,
                image_height=size.height,
            )

            # Each grid row is `ncols` image tokens plus a break token; the
            # very last break token becomes the end-of-image token.
            row = [img_id] * ncols + [break_id]
            grid = row * nrows
            grid[-1] = end_id

            return PromptUpdateDetails.select_token_id(grid, img_id)

        replacement = PromptReplacement(
            modality="image",
            target=[img_id],
            replacement=get_replacement,
        )
        return [replacement]


314
def _get_num_hidden_layers(hf_config: Mistral3Config) -> int:
    """Determine the number of hidden layers to initialize up to in the
    visual encoder.

    Args:
        hf_config: Model config with vision feature layer(s).

    Returns:
        The value of ``get_layer_index`` for the configured feature layer.
        When several feature layers are configured, it returns the largest
        such value, i.e. the deepest layer.

    Raises:
        ValueError: If ``vision_feature_layer`` is an empty list/tuple.
        TypeError: If ``vision_feature_layer`` is neither an int nor a
            list/tuple.
    """
    feature_layers = hf_config.vision_feature_layer
    num_hidden_layers = hf_config.vision_config.num_hidden_layers

    # If we have one feature layer, initialize up to that layer
    if isinstance(feature_layers, int):
        return get_layer_index(feature_layers, num_hidden_layers)

    # If we have multiple feature layers, initialize up to the deepest one
    if isinstance(feature_layers, (list, tuple)):
        if not feature_layers:
            # Without this, max() would fail with an opaque message.
            raise ValueError("vision_feature_layer must not be empty")
        return max(get_layer_index(idx, num_hidden_layers) for idx in feature_layers)

    raise TypeError(
        f"vision_feature_layer type: {type(feature_layers)} is not supported"
    )
332
333


334
335
def init_vision_tower_for_mistral3(
    hf_config: Mistral3Config,
    quant_config: QuantizationConfig | None,
    *,
    require_post_norm: bool | None = None,
    prefix: str = "",
) -> PixtralHFVisionModel:
    """Build the Pixtral vision tower for Mistral3.

    The tower is truncated to the depth returned by
    ``_get_num_hidden_layers``, so layers past the deepest required feature
    layer are never created.

    Args:
        hf_config: Mistral3 config; its ``vision_config`` must be a
            ``PixtralVisionConfig``.
        quant_config: Optional quantization config for the tower.
        require_post_norm: Forwarded unchanged to ``PixtralHFVisionModel``.
        prefix: Module name prefix for the tower's parameters.
    """
    depth = _get_num_hidden_layers(hf_config)

    pixtral_config = hf_config.vision_config
    assert isinstance(pixtral_config, PixtralVisionConfig)

    tower = PixtralHFVisionModel(
        pixtral_config,
        quant_config=quant_config,
        num_hidden_layers_override=depth,
        require_post_norm=require_post_norm,
        prefix=prefix,
    )
    return tower


@MULTIMODAL_REGISTRY.register_processor(
    Mistral3MultiModalProcessor,
    info=Mistral3ProcessingInfo,
    dummy_inputs=Mistral3DummyInputsBuilder,
)
class Mistral3ForConditionalGeneration(
    nn.Module,
    SupportsLoRA,
    SupportsMultiModal,
    SupportsPP,
    SupportsEagle,
    SupportsEagle3,
):
    """Mistral3 vision-language model.

    Combines three parts:
    - a Pixtral vision tower (``vision_tower``);
    - a projector that merges patches and maps features into the text hidden
      space (``multi_modal_projector``);
    - a text model created through ``init_vllm_registered_model``
      (``language_model``).
    """

    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            # mapping for new names in checkpoint saved after transformers v4.52
            "model.language_model.": "language_model.model.",
            "model.vision_tower.": "vision_tower.",
            "model.multi_modal_projector.": "multi_modal_projector.",
            "lm_head.": "language_model.lm_head.",
            # Some PEFT LoRAs are trained against the text submodule directly
            # and produce names like `base_model.model.model.layers.*`.
            # NOTE: this catch-all must stay after the more specific
            # "model.*" prefixes above (assuming WeightsMapper applies
            # prefixes in insertion order), otherwise vision/projector
            # weights would be routed into the language model.
            "model.": "language_model.model.",
        }
    )

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        if modality.startswith("image"):
            return None

        raise ValueError("Only image modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.config = config
        self.multimodal_config = multimodal_config

        # NOTE: These are special cases for Pixtral-12B in the HF-format
        # https://huggingface.co/mistral-community/pixtral-12b/blob/main/config.json  # noqa
        if (
            config.text_config.architectures is None
            and config.text_config.model_type == "mistral"
        ):
            config.text_config.architectures = ["MistralForCausalLM"]
        if (
            config.projector_hidden_act is None
            and config.vision_config.hidden_act == "gelu"
        ):
            config.projector_hidden_act = "gelu"

        with self._mark_tower_model(vllm_config, "image"):
            self.vision_tower = init_vision_tower_for_mistral3(
                config,
                quant_config=quant_config,
                require_post_norm=False,
                prefix=maybe_prefix(prefix, "vision_tower"),
            )
            self.multi_modal_projector = Mistral3MultiModalProjector(
                vision_hidden_size=config.vision_config.hidden_size,
                text_hidden_size=config.text_config.hidden_size,
                projector_hidden_act=config.projector_hidden_act,
                spatial_merge_size=config.spatial_merge_size,
                patch_size=config.vision_config.patch_size,
                multimodal_projector_bias=config.multimodal_projector_bias,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "multi_modal_projector"),
            )

        with self._mark_language_model(vllm_config):
            self.language_model = init_vllm_registered_model(
                vllm_config=vllm_config,
                hf_config=config.text_config,
                prefix=maybe_prefix(prefix, "language_model"),
            )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> Mistral3ImagePixelInputs | None:
        """Extract image inputs from the multimodal kwargs.

        Returns:
            ``None`` when the request carries no image data at all, otherwise
            the validated pixel inputs.

        Raises:
            ValueError: If only precomputed ``image_embeds`` are supplied.
                This model can only encode raw ``pixel_values``.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is None:
            # No embeddings-input schema exists for this model. Fail clearly
            # instead of building a pixel schema with `pixel_values=None`.
            raise ValueError(
                "Mistral3 does not support precomputed `image_embeds`; "
                "pass `pixel_values` instead."
            )

        return Mistral3ImagePixelInputs(
            type="pixel_values_pixtral",
            pixel_values=pixel_values,
        )

    def _process_image_input(
        self,
        image_input: Mistral3ImagePixelInputs,
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Encode images and project them into text-model embeddings.

        Returns:
            The projector output directly when the vision tower returns a
            single tensor. Otherwise, a tuple with one embedding tensor per
            image.
        """
        # The only input variant is raw pixels (see `type` on the schema),
        # so there is no precomputed-embeddings branch here.
        pixel_values = image_input["pixel_values"]

        # (height, width) in pixels of each (unpadded) image.
        image_sizes = [(img.shape[-2], img.shape[-1]) for img in pixel_values]

        image_features = self.vision_tower(pixel_values)

        if isinstance(image_features, torch.Tensor):
            return self.multi_modal_projector(image_features, image_sizes)

        # The patch merger emits one token per spatial_merge_size**2 patches,
        # so this is each image's share of the projector output.
        feature_sizes = [
            image_feature.shape[0] // self.config.spatial_merge_size**2
            for image_feature in image_features
        ]

        image_embeds = self.multi_modal_projector(
            torch.cat(image_features), image_sizes
        )

        if len(feature_sizes) > 1:
            return torch.split(image_embeds, feature_sizes)
        return (image_embeds,)

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        """Return per-image embeddings, or ``[]`` when there are no images."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []

        vision_embeddings = self._process_image_input(image_input)

        return vision_embeddings

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        """Run forward pass for Mistral3.

        One key thing to understand is the `input_ids` already accounts for the
        positions of the to-be-inserted image embeddings.

        Concretely, the multimodal processor replaces each image placeholder
        token in the prompt with a grid of tokens. For each of the image's
        ``nrows`` patch rows, the grid holds ``ncols`` image tokens followed
        by an image-break token. The final break token is replaced by an
        image-end token. ``ncols`` and ``nrows`` come from the vision
        encoder info and depend on the image size, so the number of image
        tokens varies per image. Only the image-token positions are selected
        to receive vision embeddings.

        This way, the `positions` and `attn_metadata` are consistent
        with the `input_ids`.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Position indices for the input tokens.
            intermediate_tensors: Intermediate tensors from prior forward pass.
            inputs_embeds: Optional tensor of input embeddings.

        Info:
            [`Mistral3ImagePixelInputs`][vllm.model_executor.models.mistral3.Mistral3ImagePixelInputs]
        """
        # On non-first pipeline stages the hidden states come from
        # `intermediate_tensors`, so any provided embeddings are ignored.
        if intermediate_tensors is not None:
            inputs_embeds = None

        hidden_states = self.language_model.model(
            input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
        )

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, renaming them via `hf_to_vllm_mapper`."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        return MultiModelKeys.from_string_field(
            language_model="language_model",
            connector="multi_modal_projector",
            tower_model="vision_tower",
        )