pixtral.py 47.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import math
5
from collections.abc import Iterable, Mapping, Sequence
Patrick von Platen's avatar
Patrick von Platen committed
6
from dataclasses import dataclass, fields
7
from functools import cached_property
8
from typing import Annotated, Literal
Patrick von Platen's avatar
Patrick von Platen committed
9
10
11
12

import torch
import torch.nn as nn
import torch.nn.functional as F
13
14
from mistral_common.protocol.instruct.chunk import ImageChunk, TextChunk
from mistral_common.protocol.instruct.messages import UserMessage
15
from mistral_common.protocol.instruct.request import ChatCompletionRequest
16
from mistral_common.tokens.tokenizers.multimodal import ImageEncoder
Patrick von Platen's avatar
Patrick von Platen committed
17
from PIL import Image
18
from transformers import BatchFeature, PixtralVisionConfig, TensorType
19
from transformers.image_utils import ImageInput
20
from transformers.models.pixtral.image_processing_pixtral import (
21
22
    _num_image_tokens as _get_pixtral_hf_num_image_tokens,
)
23
from transformers.models.pixtral.modeling_pixtral import (
24
25
26
27
    PixtralRotaryEmbedding,
    apply_rotary_pos_emb,
    position_ids_in_meshgrid,
)
28
from transformers.tokenization_utils_base import TextInput
Patrick von Platen's avatar
Patrick von Platen committed
29

30
from vllm.config import VllmConfig
31
from vllm.config.multimodal import BaseDummyOptions
32
from vllm.distributed import divide, get_tensor_model_parallel_world_size
33
from vllm.model_executor.layers.activation import get_act_and_mul_fn
34
from vllm.model_executor.layers.conv import Conv2dLayer
Patrick von Platen's avatar
Patrick von Platen committed
35
from vllm.model_executor.layers.layernorm import RMSNorm
36
37
38
39
40
from vllm.model_executor.layers.linear import (
    MergedColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
Patrick von Platen's avatar
Patrick von Platen committed
41
42
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
43
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargsItems
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalUUIDDict,
    NestedTensors,
)
from vllm.multimodal.parse import ImageProcessorItems, ImageSize, MultiModalDataItems
from vllm.multimodal.processing import (
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    MultiModalProcessingInfo,
    PromptReplacement,
    PromptUpdate,
    PromptUpdateDetails,
)
59
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
60
from vllm.platforms import current_platform
61
from vllm.sequence import IntermediateTensors
62
from vllm.tokenizers import MistralTokenizer, cached_tokenizer_from_config
63
from vllm.utils.tensor_schema import TensorSchema, TensorShape
Patrick von Platen's avatar
Patrick von Platen committed
64

65
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
66
from .utils import init_vllm_registered_model, maybe_prefix
67
68
69
70
71
from .vision import (
    VisionEncoderInfo,
    VisionFeatureSelectStrategy,
    resolve_visual_encoder_outputs,
)
Patrick von Platen's avatar
Patrick von Platen committed
72

73
try:
    # Note: vLLM does not install xformers by default.
    from xformers import ops as xops

    # Xformers FA is not compatible with B200 (device capability 100), so fall
    # back to SDPA there even when xformers is importable.
    USE_XFORMERS_OPS = not (
        current_platform.is_cuda() and current_platform.has_device_capability(100)
    )
except ImportError:
    USE_XFORMERS_OPS = False


# Value of `VisionEncoderArgs.mm_projector_id` that enables the PatchMerger
# projector in PixtralForConditionalGeneration.
PATCH_MERGE = "patch_merge"

Patrick von Platen's avatar
Patrick von Platen committed
87

88
class PixtralImagePixelInputs(TensorSchema):
    """
    Dimensions:
        - bn: Batch size * number of images
        - c: Number of channels (3)
        - h: Height of each image
        - w: Width of each image

    The result of collecting `ImageEncoding.image` from each prompt. Because
    `h` and `w` are dynamic, `images` may be a list of per-image tensors with
    differing spatial sizes rather than a single stacked tensor.
    """

    type: Literal["pixel_values"] = "pixel_values"

    images: Annotated[
        torch.Tensor | list[torch.Tensor],
        TensorShape("bn", 3, "h", "w", dynamic_dims={"h", "w"}),
    ]
Patrick von Platen's avatar
Patrick von Platen committed
105
106


107
108
109
class PixtralProcessorAdapter:
    """
    Provide a HF-compatible interface for
    `mistral_common.tokens.tokenizers.multimodal.ImageEncoder`.

    Both the image placeholder token ids and the per-image pixel tensors are
    produced by the tokenizer's multimodal encoder.
    """

    def __init__(self, tokenizer: MistralTokenizer) -> None:
        super().__init__()

        self.tokenizer = tokenizer

    @property
    def image_processor(self) -> ImageEncoder:
        """The tokenizer's multimodal encoder, asserted to be an `ImageEncoder`."""
        image_encoder = self.tokenizer.instruct.mm_encoder
        assert isinstance(image_encoder, ImageEncoder)
        return image_encoder

    @cached_property
    def image_break_id(self) -> int:
        """Id of the `img_break` special token (separates rows of image tokens)."""
        return self.image_processor.special_ids.img_break

    @cached_property
    def image_token_id(self) -> int:
        """Id of the `img` special token (one per image patch position)."""
        return self.image_processor.special_ids.img

    @cached_property
    def image_end_id(self) -> int:
        """Id of the `img_end` special token."""
        return self.image_processor.special_ids.img_end

    @cached_property
    def image_size(self) -> int:
        """The encoder's configured `max_image_size`."""
        return self.image_processor.mm_config.max_image_size

    @cached_property
    def patch_size(self) -> int:
        """The encoder's configured `image_patch_size`."""
        return self.image_processor.mm_config.image_patch_size

    def __call__(
        self,
        text: TextInput | list[TextInput] | None = None,
        images: ImageInput | list[ImageInput] | None = None,
        return_tensors: str | TensorType | None = None,
        **kwargs,
    ) -> Mapping[str, NestedTensors]:
        """Tokenize text, or encode images into placeholder tokens and pixels.

        Args:
            text: Prompt text. When images are given, every entry must be
                empty: image placeholders must come from `mistral_common`,
                not from raw text.
            images: Images to encode, each wrapped in an `ImageChunk`.
            return_tensors: Accepted for HF API compatibility; unused.
            **kwargs: Ignored.

        Returns:
            Without images: ``{"input_ids": tensor}`` from the tokenizer.
            With images: a ``BatchFeature`` whose ``input_ids`` is the
            concatenation of every image's tokens, repeated once per entry of
            ``text``, and whose ``images`` is a list of per-image tensors.

        Raises:
            ValueError: If any ``text`` entry is non-empty while images are
                given.
        """
        # Normalize both inputs to lists.
        if text is None:
            text = []
        if not isinstance(text, list):
            text = [text]
        if images is None:
            images = []
        if not isinstance(images, list):
            images = [images]

        if not images:
            input_ids = self.tokenizer(text).input_ids

            return {"input_ids": torch.tensor(input_ids)}

        # Allow dummy text, which is used for profiling as well as token inputs
        if any(len(t) > 0 for t in text):
            raise ValueError(
                "You've passed text inputs instead of token inputs. "
                "Make sure to process your input via `mistral_common`'s "
                "tokenizer or pass a chat completion request. "
                "For more info, see: "
                "https://github.com/vllm-project/vllm/issues/8411."
            )

        # Resolve (and type-check) the encoder once instead of once per image.
        image_encoder = self.image_processor

        images_processed: list[torch.Tensor] = []
        images_tokens: list[torch.Tensor] = []
        for image in images:
            image_inputs = image_encoder(ImageChunk(image=image))
            images_processed.append(torch.tensor(image_inputs.image))
            images_tokens.append(torch.tensor(image_inputs.tokens))

        return BatchFeature(
            {
                "input_ids": torch.cat(images_tokens)[None].expand(len(text), -1),
                "images": images_processed,
            }
        )
192
193
194
195
196
197
198
199
200
201
202
203
204


class PixtralProcessingInfo(BaseProcessingInfo):
    """Processing metadata for Pixtral, backed by a Mistral tokenizer."""

    def get_tokenizer(self) -> MistralTokenizer:
        """Return the cached tokenizer, which must be a `MistralTokenizer`.

        Raises:
            ValueError: If the model was not loaded with the Mistral tokenizer.
        """
        tokenizer = cached_tokenizer_from_config(self.ctx.model_config)
        if not isinstance(tokenizer, MistralTokenizer):
            raise ValueError("This model requires `--tokenizer-mode mistral`")

        return tokenizer

    def get_hf_processor(self, **kwargs: object) -> PixtralProcessorAdapter:
        """Return an adapter around the tokenizer's image encoder.

        ``kwargs`` are accepted and ignored. Callers such as
        ``PixtralMultiModalProcessor._get_prompt_updates`` forward
        ``hf_processor_mm_kwargs`` here, and the adapter takes no configuration.
        Without ``**kwargs``, any non-empty forwarded kwargs would raise
        ``TypeError``.
        """
        return PixtralProcessorAdapter(self.get_tokenizer())

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        # Only images are supported; `None` sets no explicit per-prompt count
        # (presumably treated as unlimited by the caller).
        return {"image": None}

    def get_vision_config(
        self,
        processor: PixtralProcessorAdapter | None = None,
    ) -> PixtralVisionConfig:
        """Build a `PixtralVisionConfig` from the encoder's image/patch sizes."""
        if processor is None:
            processor = self.get_hf_processor()

        return PixtralVisionConfig(
            image_size=processor.image_size,
            patch_size=processor.patch_size,
        )

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        processor: PixtralProcessorAdapter | None = None,
    ) -> int:
        """Return the number of `img` tokens for an image of the given size.

        A blank RGB image of the requested size is passed to the encoder's
        private `_image_to_num_tokens`. The result is ``ncols * nrows``, which
        excludes the row-break/end tokens.
        """
        if processor is None:
            processor = self.get_hf_processor()

        ncols, nrows = processor.image_processor._image_to_num_tokens(
            Image.new("RGB", (image_width, image_height))
        )

        return ncols * nrows

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return a square image at the encoder's `max_image_size`."""
        image_processor = self.get_hf_processor().image_processor
        max_image_size = image_processor.mm_config.max_image_size

        return ImageSize(width=max_image_size, height=max_image_size)


class PixtralDummyInputsBuilder(BaseDummyInputsBuilder[PixtralProcessingInfo]):
    """Build profiling inputs by encoding a dummy chat completion request."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        # Empty on purpose: PixtralProcessorAdapter rejects non-empty text
        # whenever images are present.
        return ""

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        """Return ``mm_counts["image"]`` dummy images at the maximum image size."""
        num_images = mm_counts.get("image", 0)

        target_width, target_height = self.info.get_image_size_with_most_features()

        image_overrides = mm_options.get("image") if mm_options else None

        return {
            "image": self._get_dummy_images(
                width=target_width,
                height=target_height,
                num_images=num_images,
                overrides=image_overrides,
            )
        }

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> ProcessorInputs:
        """Tokenize a one-message chat request containing the dummy images.

        The prompt is returned as token ids from `mistral_common`'s
        ``encode_chat_completion``, so the image placeholder tokens are
        already present in it.
        """
        tokenizer = self.info.get_tokenizer()

        dummy_text = self.get_dummy_text(mm_counts)
        dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts, mm_options)
        dummy_images = dummy_mm_data.get("image", [])
        tokenization_kwargs = {"truncation": False}

        request = ChatCompletionRequest(
            messages=[
                UserMessage(
                    content=[
                        TextChunk(text=dummy_text),
                        *(ImageChunk(image=image) for image in dummy_images),
                    ]
                ),
            ]
        )
        res = tokenizer.mistral.encode_chat_completion(request)
        dummy_tokens = res.tokens

        return ProcessorInputs(
            prompt=dummy_tokens,
            mm_data=dummy_mm_data,
            tokenization_kwargs=tokenization_kwargs,
        )
299

Patrick von Platen's avatar
Patrick von Platen committed
300

301
class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo]):
    """Multimodal processor for prompts already tokenized by mistral_common."""

    def _get_mm_fields_config(
        self,
        hf_inputs: Mapping[str, NestedTensors],
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        # `images` holds one entry per image item.
        return dict(images=MultiModalFieldConfig.batched("image"))

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Describe the placeholder token layout of each image.

        The target never matches because the chat template has already
        inserted the placeholder tokens (see `_cached_apply_hf_processor`).
        The replacement only serves to locate those tokens.
        """
        processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)

        image_break_id = processor.image_break_id
        image_token_id = processor.image_token_id
        image_end_id = processor.image_end_id

        def get_replacement(item_idx: int):
            images = mm_items.get_items("image", ImageProcessorItems)
            image_size = images.get_image_size(item_idx)

            ncols, nrows = processor.image_processor._image_to_num_tokens(
                Image.new("RGB", (image_size.width, image_size.height))
            )

            # Each row is `ncols` image tokens followed by a break token; the
            # final break is replaced by the end-of-image token.
            tokens = ([image_token_id] * ncols + [image_break_id]) * nrows
            tokens[-1] = image_end_id

            # Presumably marks only `image_token_id` positions as embedding
            # slots (break/end tokens excluded).
            return PromptUpdateDetails.select_token_id(tokens, image_token_id)

        return [
            PromptReplacement(
                modality="image",
                target="",  # Never match the prompt (see below note)
                replacement=get_replacement,
            ),
        ]

    def _cached_apply_hf_processor(
        self,
        prompt: str | list[int],
        mm_data_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object],
        mm_uuids: MultiModalUUIDDict | None = None,
    ) -> tuple[list[int], MultiModalProcessingInfo, bool]:
        prompt_ids, mm_info, _ = super()._cached_apply_hf_processor(
            prompt=prompt,
            mm_data_items=mm_data_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
            tokenization_kwargs=tokenization_kwargs,
            mm_uuids=mm_uuids,
        )

        # NOTE: The tokens are already inserted by the chat template
        return prompt_ids, mm_info, True
Patrick von Platen's avatar
Patrick von Platen committed
360

361

362
363
364
365
366
367
@MULTIMODAL_REGISTRY.register_processor(
    PixtralMultiModalProcessor,
    info=PixtralProcessingInfo,
    dummy_inputs=PixtralDummyInputsBuilder,
)
class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
    """Pixtral: a vision encoder plus projector feeding a registered text LM."""

    merge_by_field_config = True

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        # Image placeholders are inserted as tokens by the Mistral chat
        # template, so no textual placeholder is used.
        if modality.startswith("image"):
            return None

        raise ValueError("Only image modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.config = config
        self.multimodal_config = multimodal_config

        # Keep only the vision-config keys that VisionEncoderArgs declares.
        dataclass_fields = {field.name for field in fields(VisionEncoderArgs)}
        vision_args = {
            key: value
            for key, value in self.config.vision_config.to_dict().items()
            if key in dataclass_fields
        }

        self.vision_args = VisionEncoderArgs(**vision_args)

        # init MistralForCausalLM
        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )

        # The vision tower is only built when images are allowed per prompt.
        if multimodal_config.get_limit_per_prompt("image"):
            self.vision_encoder = VisionTransformer(self.vision_args)
            self.pre_mm_projector_norm = (
                RMSNorm(self.vision_args.hidden_size, eps=1e-5)
                if self.vision_args.add_pre_mm_projector_layer_norm
                else None
            )
            self.patch_merger = (
                PatchMerger(
                    vision_encoder_dim=self.vision_args.hidden_size,
                    spatial_merge_size=self.vision_args.spatial_merge_size,
                    use_mlp_bias=False,
                )
                if self.vision_args.mm_projector_id == PATCH_MERGE
                else None
            )
            self.vision_language_adapter = VisionLanguageAdapter(
                self.vision_args, dim=config.text_config.hidden_size
            )
        else:
            self.vision_encoder = None
            self.pre_mm_projector_norm = None
            self.patch_merger = None
            self.vision_language_adapter = None

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> PixtralImagePixelInputs | None:
        """Pop ``images`` from ``kwargs``; return None when absent."""
        images = kwargs.pop("images", None)
        if images is None:
            return None

        return PixtralImagePixelInputs(
            type="pixel_values",
            images=images,
        )

    def _process_image_input(
        self,
        image_input: PixtralImagePixelInputs,
    ) -> tuple[torch.Tensor, ...]:
        """Encode images and project them into the text embedding space.

        Returns:
            One embedding tensor per image, produced by ``torch.split``.
        """
        assert (
            self.vision_encoder is not None and self.vision_language_adapter is not None
        )

        images = image_input["images"]
        image_features = self.vision_encoder(images)
        # Number of encoder output rows contributed by each image.
        feature_sizes = [image_feature.shape[0] for image_feature in image_features]
        image_features = torch.cat(image_features)
        if self.pre_mm_projector_norm is not None:
            image_features = self.pre_mm_projector_norm(image_features)
        if self.patch_merger is not None:
            patch_size = self.vision_args.patch_size
            spatial_merge_size_square = self.vision_args.spatial_merge_size**2
            img_patch_dims = [
                (img.shape[1] // patch_size, img.shape[2] // patch_size)
                for img in images
            ]
            # Merging shrinks each image's feature count by merge_size**2.
            feature_sizes = [
                feature_size // spatial_merge_size_square
                for feature_size in feature_sizes
            ]
            image_features = self.patch_merger(
                image_features, image_sizes=img_patch_dims
            )
        image_embeds = self.vision_language_adapter(image_features)
        image_embeds = torch.split(image_embeds, feature_sizes)
        return image_embeds

    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []

        return self._process_image_input(image_input)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        """Run forward pass for pixtral."""
        # Non-first pipeline stages consume intermediate tensors, not embeds.
        if intermediate_tensors is not None:
            inputs_embeds = None

        hidden_states = self.language_model.model(
            input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
        )

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load vision weights directly; stream all others to the text model.

        A checkpoint name that starts with one of the vision prefixes below
        (checked in order) is loaded into that submodule after its first
        dotted component is dropped. If the submodule was not built (None),
        the weight is skipped. Every other weight is yielded to
        ``language_model.load_weights``.
        """
        # Ordered: prefixes are tested in this sequence.
        vision_modules = {
            "vision_encoder": self.vision_encoder,
            "patch_merger": self.patch_merger,
            "pre_mm_projector_norm": self.pre_mm_projector_norm,
            "vision_language_adapter": self.vision_language_adapter,
        }
        # Get references to parameters for direct loading (built modules only).
        vision_params = {
            module_prefix: dict(module.named_parameters())
            for module_prefix, module in vision_modules.items()
            if module is not None
        }

        def llm_weights_generator():
            # Single pass over weights
            for name, w in weights:
                matched = next(
                    (p for p in vision_modules if name.startswith(p)), None
                )
                if matched is None:
                    # LLM weights: yield them to be loaded
                    # by language_model.load_weights
                    yield (name, w)
                    continue
                if matched not in vision_params:
                    # Submodule not instantiated (e.g. images disabled).
                    continue
                trimmed_name = ".".join(name.split(".")[1:])
                param = vision_params[matched][trimmed_name]
                with torch.no_grad():
                    default_weight_loader(param, w)

        # Now we call the language model load with the generator
        self.language_model.load_weights(llm_weights_generator())
Patrick von Platen's avatar
Patrick von Platen committed
584
585
586
587
588
589
590
591
592
593
594
595
596
597


# Vision encoder
@dataclass
class VisionEncoderArgs:
    """Hyperparameters of the Pixtral vision encoder.

    PixtralForConditionalGeneration fills this from the keys of
    ``config.vision_config`` that match these field names; other keys are
    dropped.
    """

    hidden_size: int
    num_channels: int
    image_size: int
    patch_size: int
    intermediate_size: int
    num_hidden_layers: int
    num_attention_heads: int
    rope_theta: float  # for rope-2D
    image_token_id: int
    adapter_bias: bool = True
    # Merge factor per spatial axis; features shrink by its square when the
    # PatchMerger is used.
    spatial_merge_size: int = 1
    # Apply an RMSNorm to vision features before the multimodal projector.
    add_pre_mm_projector_layer_norm: bool = False
    # Projector selector; the value PATCH_MERGE enables the PatchMerger.
    mm_projector_id: str = ""
Patrick von Platen's avatar
Patrick von Platen committed
602
603


604
def _reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
Patrick von Platen's avatar
Patrick von Platen committed
605
606
607
608
609
610
611
612
613
614
    """
    freqs_cis: complex - (seq_len, head_dim / 2)
    x: complex - (bsz, seq_len, head_dim / 2)
    """
    ndim = x.ndim
    assert ndim > 1
    assert freqs_cis.shape == (x.shape[1], x.shape[-1]), (
        freqs_cis.shape,
        (x.shape[1], x.shape[-1]),
    )
615
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
Patrick von Platen's avatar
Patrick von Platen committed
616
617
618
619
620
621
622
623
624
625
626
627
628
629
    return freqs_cis.view(*shape)


def precompute_freqs_cis_2d(
    dim: int,
    height: int,
    width: int,
    theta: float,
) -> torch.Tensor:
    """
    freqs_cis: 2D complex tensor of shape (height, width, dim // 2)
        to be indexed by (height, width) position tuples

    The last axis is split in two halves. The first half rotates by
    ``row * freqs[::2]``, so it depends only on the row. The second half
    rotates by ``col * freqs[1::2]``, so it depends only on the column. All
    entries have unit magnitude.
    """
    # (dim / 2) frequency bases
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))

    h = torch.arange(height, device=freqs.device)
    w = torch.arange(width, device=freqs.device)

    # Rotation angles per row (even bases) and per column (odd bases).
    freqs_h = torch.outer(h, freqs[::2]).float()
    freqs_w = torch.outer(w, freqs[1::2]).float()
    freqs_2d = torch.cat(
        [
            freqs_h[:, None, :].repeat(1, width, 1),
            freqs_w[None, :, :].repeat(height, 1, 1),
        ],
        dim=-1,
    )
    return torch.polar(torch.ones_like(freqs_2d), freqs_2d)


def apply_rotary_emb_vit(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Rotate query/key feature pairs by the complex rotations in ``freqs_cis``.

    Adjacent feature pairs of the last dim of ``xq``/``xk`` are treated as
    complex numbers, multiplied by ``freqs_cis``, and flattened back. Dim 1 of
    the inputs must match ``freqs_cis.shape[0]``, and the last dim must be
    ``2 * freqs_cis.shape[1]``. In this file the inputs are
    (batch, patches, n_heads, head_dim). Outputs keep the input shapes and
    dtypes.
    """
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    assert freqs_cis.dtype == torch.complex64
    freqs_cis = _reshape_for_broadcast(freqs_cis, xq_)
    # flatten(-2) merges the trailing (head_dim / 2, 2) pair axes for any rank;
    # the former flatten(3) only did so for 4-D inputs.
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(-2)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(-2)
    return xq_out.type_as(xq), xk_out.type_as(xk)


class FeedForward(nn.Module):
    """Gated MLP: ``w2(silu(w1(x)) * w3(x))``, all projections bias-free."""

    def __init__(self, args: VisionEncoderArgs):
        super().__init__()
        assert args.intermediate_size is not None
        self.w1 = nn.Linear(args.hidden_size, args.intermediate_size, bias=False)
        self.w2 = nn.Linear(args.intermediate_size, args.hidden_size, bias=False)
        self.w3 = nn.Linear(args.hidden_size, args.intermediate_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


class Attention(nn.Module):
    """Multi-head self-attention with 2D rotary embeddings on q and k.

    Uses xformers' memory-efficient attention when ``USE_XFORMERS_OPS`` is set,
    otherwise ``scaled_dot_product_attention``. ``mask`` is forwarded as
    ``attn_bias`` or ``attn_mask`` respectively. NOTE(review): the caller
    presumably builds a backend-appropriate mask object; confirm at the call
    site.
    """

    def __init__(self, args: VisionEncoderArgs):
        super().__init__()
        self.args = args
        assert not args.hidden_size % args.num_attention_heads
        self.n_heads = args.num_attention_heads
        self.head_dim = args.hidden_size // args.num_attention_heads

        self.wq = nn.Linear(args.hidden_size, args.hidden_size, bias=False)
        self.wk = nn.Linear(args.hidden_size, args.hidden_size, bias=False)
        self.wv = nn.Linear(args.hidden_size, args.hidden_size, bias=False)
        self.wo = nn.Linear(args.hidden_size, args.hidden_size, bias=False)

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor,
        freqs_cis: torch.Tensor,
    ) -> torch.Tensor:
        """Attend over ``x`` of shape (batch, patches, hidden_size)."""
        batch, patches, _ = x.shape

        q, k, v = self.wq(x), self.wk(x), self.wv(x)
        q = q.reshape(batch, patches, self.n_heads, self.head_dim)
        k = k.reshape(batch, patches, self.n_heads, self.head_dim)
        v = v.reshape(batch, patches, self.n_heads, self.head_dim)

        q, k = apply_rotary_emb_vit(q, k, freqs_cis=freqs_cis)

        if USE_XFORMERS_OPS:
            # xformers consumes (batch, seq, heads, head_dim) directly.
            out = xops.memory_efficient_attention(q, k, v, attn_bias=mask)
        else:
            # SDPA expects (batch, heads, seq, head_dim).
            q = q.transpose(1, 2)
            k = k.transpose(1, 2)
            v = v.transpose(1, 2)
            out = nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask)
            out = out.transpose(1, 2)

        out = out.reshape(batch, patches, self.n_heads * self.head_dim)
        return self.wo(out)


class TransformerBlock(nn.Module):
    """Pre-norm layer: ``h = x + attn(norm(x))``; ``out = h + ffn(norm(h))``."""

    def __init__(self, args: VisionEncoderArgs):
        super().__init__()
        self.attention = Attention(args)
        self.feed_forward = FeedForward(args)
        self.attention_norm = RMSNorm(args.hidden_size, eps=1e-5)
        self.ffn_norm = RMSNorm(args.hidden_size, eps=1e-5)

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor,
        freqs_cis: torch.Tensor,
    ) -> torch.Tensor:
        # Invoke submodules via __call__ (not `.forward`) so that any
        # registered forward hooks are honored.
        h = x + self.attention(
            self.attention_norm(x), mask=mask, freqs_cis=freqs_cis
        )
        return h + self.feed_forward(self.ffn_norm(h))


class Transformer(nn.Module):
    """Sequential stack of ``args.num_hidden_layers`` TransformerBlock layers."""

    def __init__(self, args: VisionEncoderArgs):
        super().__init__()
        self.layers = torch.nn.ModuleList(
            TransformerBlock(args) for _ in range(args.num_hidden_layers)
        )

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor,
        freqs_cis: torch.Tensor | None,
    ) -> torch.Tensor:
        """Run ``x`` through every layer in order, sharing mask and freqs_cis."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden, mask=mask, freqs_cis=freqs_cis)
        return hidden


755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
def position_meshgrid(
    patch_embeds_list: list[torch.Tensor],
) -> torch.Tensor:
    """Return the (row, col) grid position of every patch of every image.

    Args:
        patch_embeds_list: Per-image patch grids; only the last two
            dimensions (rows, cols) are read.

    Returns:
        Integer tensor of shape (sum_i rows_i * cols_i, 2). The rows are
        ordered image by image, and row-major within each image.
    """
    per_image: list[torch.Tensor] = []
    for patch_grid in patch_embeds_list:
        rows = torch.arange(patch_grid.shape[-2])
        cols = torch.arange(patch_grid.shape[-1])
        row_ids, col_ids = torch.meshgrid(rows, cols, indexing="ij")
        per_image.append(torch.stack((row_ids, col_ids), dim=-1).reshape(-1, 2))
    return torch.cat(per_image)


class VisionTransformer(nn.Module):
    """Mistral-format Pixtral vision encoder.

    Each image is patchified by a strided convolution. The patch tokens of
    all images are concatenated into one sequence, normalized, and run
    through a transformer with 2D rotary position embeddings. A
    block-diagonal attention mask delimits the images so that tokens only
    attend within their own image.
    """

    def __init__(self, args: VisionEncoderArgs):
        super().__init__()
        self.args = args
        self.patch_conv = Conv2dLayer(
            in_channels=args.num_channels,
            out_channels=args.hidden_size,
            kernel_size=args.patch_size,
            stride=args.patch_size,
            bias=False,
        )
        self.ln_pre = RMSNorm(args.hidden_size, eps=1e-5)
        self.transformer = Transformer(args)

        head_dim = self.args.hidden_size // self.args.num_attention_heads
        assert head_dim % 2 == 0, "ROPE requires even head_dim"
        # The rotary table is built lazily on first access of ``freqs_cis``.
        self._freqs_cis: torch.Tensor | None = None

    @property
    def max_patches_per_side(self) -> int:
        """Number of patches along one side of an ``args.image_size`` image."""
        return self.args.image_size // self.args.patch_size

    @property
    def device(self) -> torch.types.Device:
        """Device of the module's first parameter."""
        return next(self.parameters()).device

    @property
    def dtype(self) -> torch.dtype:
        """Dtype of the module's first parameter."""
        return next(self.parameters()).dtype

    @property
    def freqs_cis(self) -> torch.Tensor:
        """Rotary frequency table indexed as ``[row, col]``.

        The table is computed once, cached, and moved to the module's device
        whenever the device differs (for example, after ``.to()``).
        """
        if self._freqs_cis is None:
            self._freqs_cis = precompute_freqs_cis_2d(
                dim=self.args.hidden_size // self.args.num_attention_heads,
                height=self.max_patches_per_side,
                width=self.max_patches_per_side,
                theta=self.args.rope_theta,
            )

        if self._freqs_cis.device != self.device:
            self._freqs_cis = self._freqs_cis.to(device=self.device)

        return self._freqs_cis

    def forward(
        self,
        images: list[torch.Tensor],
    ) -> tuple[torch.Tensor, ...]:
        """Encode a list of variable-size images.

        Args:
            images: List of N_img images, each of shape (C, H, W). The
                spatial sizes may differ between images.

        Returns:
            Tuple of N_img tensors, one per input image and in the same
            order. Each tensor has shape (N_toks_i, D), where N_toks_i is
            that image's patch count and D is ``args.hidden_size``.
        """
        # Convolve each image separately because their spatial sizes differ.
        patch_embeds_list = [
            self.patch_conv(img.unsqueeze(0).to(self.dtype)) for img in images
        ]

        # (1, D, h, w) -> (1, h * w, D)
        patch_embeds = [p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list]
        embed_sizes = [p.shape[1] for p in patch_embeds]

        # Flatten all images into a single (1, N_toks, D) sequence.
        patch_embeds = torch.cat(patch_embeds, dim=1)
        patch_embeds = self.ln_pre(patch_embeds)

        # Gather the rotary frequencies for every token's (row, col) position.
        positions = position_meshgrid(patch_embeds_list).to(self.device)
        freqs_cis = self.freqs_cis[positions[:, 0], positions[:, 1]]

        # Block-diagonal mask: tokens attend only within their own image.
        # embed_sizes[i] == h_i * w_i for the i-th convolution output.
        if USE_XFORMERS_OPS:
            mask = xops.fmha.attn_bias.BlockDiagonalMask.from_seqlens(embed_sizes)
        else:
            from transformers.models.pixtral.modeling_pixtral import (
                generate_block_attention_mask,
            )

            mask = generate_block_attention_mask(embed_sizes, patch_embeds)

        out = self.transformer(patch_embeds, mask=mask, freqs_cis=freqs_cis)

        # Drop the batch dim and split back into one tensor per image.
        return torch.split(out.squeeze(0), embed_sizes)
Patrick von Platen's avatar
Patrick von Platen committed
864
865
866
867
868
869
870
871
872


class VisionLanguageAdapter(nn.Module):
    """Two-layer GELU MLP mapping ``args.hidden_size`` features to width ``dim``."""

    def __init__(self, args: VisionEncoderArgs, dim: int):
        super().__init__()
        assert isinstance(args, VisionEncoderArgs)
        use_bias = args.adapter_bias
        self.w_in = nn.Linear(args.hidden_size, dim, bias=use_bias)
        self.gelu = nn.GELU()
        self.w_out = nn.Linear(dim, dim, bias=use_bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` with ``w_in``, apply GELU, then project with ``w_out``."""
        hidden = self.w_in(x)
        hidden = self.gelu(hidden)
        return self.w_out(hidden)
880
881


Patrick von Platen's avatar
Patrick von Platen committed
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
class PatchMerger(nn.Module):
    """Learned merging of spatial_merge_size ** 2 patches into one token.

    Each non-overlapping ``spatial_merge_size x spatial_merge_size`` window of
    patch tokens is concatenated into a single feature vector. That vector is
    then projected back to ``vision_encoder_dim`` by a linear layer.
    """

    def __init__(
        self,
        vision_encoder_dim: int,
        spatial_merge_size: int,
        use_mlp_bias: bool = False,
    ) -> None:
        super().__init__()

        self.spatial_merge_size = spatial_merge_size
        # One merged window concatenates spatial_merge_size ** 2 tokens.
        self.mlp_input_dim = vision_encoder_dim * spatial_merge_size**2

        self.merging_layer = nn.Linear(
            self.mlp_input_dim,
            vision_encoder_dim,
            bias=use_mlp_bias,
        )

    def forward(
        self, x: torch.Tensor, image_sizes: list[tuple[int, int]]
    ) -> torch.Tensor:
        """Merge windows of patch tokens.

        Args:
            x: (N, vision_encoder_dim) concatenated patch tokens of all images.
            image_sizes: Per-image (height, width), measured in tokens. The
                areas must sum to N.

        Returns:
            (N / spatial_merge_size ** 2, vision_encoder_dim) merged tokens.
        """
        # image_sizes are given in tokens, so their areas must cover x exactly.
        assert sum(h * w for h, w in image_sizes) == len(x)
        window_features = self.permute(x, image_sizes)
        return self.merging_layer(window_features)

    def permute(
        self,
        x: torch.Tensor,
        image_sizes: list[tuple[int, int]],
    ) -> torch.Tensor:
        """Reorder patch tokens so each merge window becomes one row.

        Args:
            x: (N, D) flattened and concatenated patch tokens of all images.
            image_sizes: Per-image (height, width), measured in tokens.

        Returns:
            (N / spatial_merge_size ** 2, D * spatial_merge_size ** 2) tensor
            with one row per window, ordered image by image.
        """
        sub_grids = get_sub_grids(
            x=x, image_sizes=image_sizes, spatial_merge_size=self.spatial_merge_size
        )
        # Each grid is (d, k, k, n_windows). Flatten the leading axes and
        # transpose so that every row holds one window's features.
        return torch.cat(
            [grid.view(-1, grid.shape[-1]).t() for grid in sub_grids], dim=0
        )


def get_sub_grids(
    x: torch.Tensor,
    image_sizes: list[tuple[int, int]],
    spatial_merge_size: int,
) -> list[torch.Tensor]:
    """Cut each image's token grid into non-overlapping square windows.

    Args:
        x: (N, d) concatenated patch tokens of all images, row-major per image.
        image_sizes: Per-image (height, width), measured in tokens.
        spatial_merge_size: Side length of each square window.

    Returns:
        One tensor per image, shaped
        (d, spatial_merge_size, spatial_merge_size, n_windows).
    """
    hidden_dim = x.shape[-1]
    kernel = spatial_merge_size
    token_counts = [height * width for height, width in image_sizes]

    windows_per_image: list[torch.Tensor] = []
    for (height, width), tokens in zip(image_sizes, x.split(token_counts)):
        # (h * w, d) -> (1, d, h, w): channels-first image layout for unfold.
        grid = tokens.view(height, width, hidden_dim).permute(2, 0, 1).unsqueeze(0)
        unfolded = torch.nn.functional.unfold(grid, kernel_size=kernel, stride=kernel)
        # (1, d * k * k, n_windows) -> (d, k, k, n_windows)
        windows_per_image.append(
            unfolded.view(1, hidden_dim, kernel, kernel, -1)[0]
        )
    return windows_per_image


982
983
984
985
986
987
988
989
#### HF Transformers version of Pixtral ####
# Based on https://github.com/huggingface/transformers/blob/d7950bff82b18c823193d17d72188c5e46d06c83/src/transformers/models/pixtral/modeling_pixtral.py
# This model follows the Llava family: image embeddings are inserted in place
# of the `[IMG]` placeholder tokens.
# The model uses [`PixtralVisionModel`] for its vision encoder,
# and [`MistralForCausalLM`] for its language decoder.


990
991
992
993
994
995
996
class PixtralHFEncoderInfo(VisionEncoderInfo[PixtralVisionConfig]):
    """Image-token accounting for the HF-format Pixtral vision encoder."""

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Return the patch-grid area (rows * cols) for an image of this size."""
        grid_cols, grid_rows = self.get_patch_grid_size(
            image_width=image_width,
            image_height=image_height,
        )
        return grid_rows * grid_cols

    def get_image_size(self) -> int:
        """Return the configured maximum (longest-edge) image size."""
        return self.vision_config.image_size

    def get_patch_size(self) -> int:
        """Return the effective patch size, including any spatial merge."""
        # Mistral3 configs define spatial_merge_size; when absent it counts as 1.
        merge = getattr(self.hf_config, "spatial_merge_size", 1)
        return self.vision_config.patch_size * merge

    def get_patch_grid_length(self) -> int:
        """Return the number of patches along one side of a max-size image."""
        # Interpolation is applied, so image_size need not be divisible by the
        # patch size; the integer division is intentional.
        return self.get_image_size() // self.get_patch_size()

    # Adapted from: https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/pixtral/image_processing_pixtral.py#L99
    def get_patch_grid_size(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> tuple[int, int]:
        """Return ``(ncols, nrows)`` of the patch grid after downscaling.

        If either side exceeds the maximum image size, the image is scaled
        down with its aspect ratio preserved before patches are counted.
        """
        longest_edge = self.get_image_size()
        patch = self.get_patch_size()

        scale = max(image_width / longest_edge, image_height / longest_edge)
        if scale > 1:
            # Flooring guarantees that neither side ends up above longest_edge.
            image_width = int(math.floor(image_width / scale))
            image_height = int(math.floor(image_height / scale))

        nrows, ncols = _get_pixtral_hf_num_image_tokens(
            (image_height, image_width),
            (patch, patch),
        )  # type: ignore

        return ncols, nrows
1040
1041
1042


class PixtralHFMLP(nn.Module):
    """Gated feed-forward block of the HF-format Pixtral vision encoder.

    The gate and up projections are fused into one column-parallel layer with
    two ``intermediate_size`` outputs. The ``act_and_mul`` function selected
    by ``config.hidden_act`` is applied to the fused output, and the result
    is passed through a row-parallel down projection.
    """

    def __init__(
        self,
        config: PixtralVisionConfig,
        quant_config: QuantizationConfig | None = None,
        *,
        prefix: str = "",
    ) -> None:
        super().__init__()

        assert config.intermediate_size is not None
        hidden = config.hidden_size
        intermediate = config.intermediate_size

        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=hidden,
            output_sizes=[intermediate, intermediate],
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            input_size=intermediate,
            output_size=hidden,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
        )
        self.act_and_mul = get_act_and_mul_fn(config.hidden_act)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the fused gate/up projection, the activation, and down_proj."""
        fused, _ = self.gate_up_proj(x)
        projected, _ = self.down_proj(self.act_and_mul(fused))
        return projected
1074
1075
1076


class PixtralHFAttention(nn.Module):
    """Multi-head self-attention of the HF-format Pixtral vision encoder.

    Q/K/V are produced by a single tensor-parallel ``QKVParallelLinear``.
    ``n_heads`` is the per-rank head count, i.e.
    ``num_attention_heads / tp_size``. Rotary position embeddings are applied
    with HF's ``apply_rotary_pos_emb``. Attention runs through xformers when
    ``USE_XFORMERS_OPS`` is set and through PyTorch SDPA otherwise.
    """

    def __init__(
        self,
        config: PixtralVisionConfig,
        quant_config: QuantizationConfig | None = None,
        *,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config
        assert not config.hidden_size % config.num_attention_heads
        self.total_num_heads = config.num_attention_heads
        tp_size = get_tensor_model_parallel_world_size()
        # Heads are sharded across tensor-parallel ranks.
        self.n_heads = divide(config.num_attention_heads, tp_size)
        self.head_dim = config.hidden_size // config.num_attention_heads

        self.qkv_proj = QKVParallelLinear(
            hidden_size=config.hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        assert self.total_num_heads * self.head_dim == config.hidden_size
        self.o_proj = RowParallelLinear(
            input_size=config.hidden_size,
            output_size=config.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Attend over the packed patch sequence.

        Args:
            hidden_states: (batch, patches, hidden_size) input features.
            attention_mask: Block-diagonal mask or bias. It is an xformers
                bias when ``USE_XFORMERS_OPS`` is set, and the tensor passed
                as SDPA's ``attn_mask`` otherwise.
            position_embeddings: ``(cos, sin)`` pair forwarded to
                ``apply_rotary_pos_emb``.

        Returns:
            ``(attn_output, None)``. The second slot mirrors HF's
            attention-weights output and is always ``None`` here.
        """
        batch, patches, _ = hidden_states.size()

        qkv_states, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv_states.chunk(3, dim=-1)

        # HF's rotary helper expects (batch, heads, patches, head_dim).
        q = q.view(batch, patches, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch, patches, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch, patches, self.n_heads, self.head_dim)
        cos, sin = position_embeddings
        q, k = apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=0)

        if USE_XFORMERS_OPS:
            # Transpose q and k back to (batch, patches, heads, head_dim) to
            # match the layout of v.
            q = q.transpose(1, 2).contiguous()
            k = k.transpose(1, 2).contiguous()
            out = xops.memory_efficient_attention(q, k, v, attn_bias=attention_mask)
        else:
            # SDPA works head-major, so bring v to (batch, heads, patches, d) too.
            v = v.transpose(1, 2)
            out = nn.functional.scaled_dot_product_attention(
                q, k, v, attn_mask=attention_mask
            )
            out = out.transpose(1, 2)

        out = out.reshape(batch, patches, self.n_heads * self.head_dim)
        attn_output, _ = self.o_proj(out)

        return attn_output, None
1144
1145
1146


class PixtralHFTransformerBlock(nn.Module):
    """Pre-norm transformer layer of the HF-format Pixtral vision encoder.

    The block computes an attention sub-layer and a feed-forward sub-layer.
    Each sub-layer is preceded by its own RMSNorm and followed by a residual
    add.
    """

    def __init__(
        self,
        config: PixtralVisionConfig,
        quant_config: QuantizationConfig | None = None,
        *,
        prefix: str = "",
    ) -> None:
        super().__init__()

        hidden = config.hidden_size
        self.attention_norm = RMSNorm(hidden, eps=1e-5)
        self.attention = PixtralHFAttention(
            config, quant_config=quant_config, prefix=f"{prefix}.attention"
        )
        self.feed_forward = PixtralHFMLP(
            config, quant_config=quant_config, prefix=f"{prefix}.feed_forward"
        )
        self.ffn_norm = RMSNorm(hidden, eps=1e-5)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        position_embeddings: torch.Tensor,
    ) -> torch.Tensor:
        """Return ``hidden_states`` after the attention and MLP residual steps."""
        normed = self.attention_norm(hidden_states)
        attn_out, _ = self.attention.forward(
            normed,
            attention_mask=attention_mask,
            position_embeddings=position_embeddings,
        )
        residual = hidden_states + attn_out
        ffn_out = self.feed_forward.forward(self.ffn_norm(residual))
        return residual + ffn_out


class PixtralHFTransformer(nn.Module):
    """Stack of PixtralHFTransformerBlock layers.

    ``num_hidden_layers_override``, when given, replaces
    ``config.num_hidden_layers`` as the number of layers that are built.
    """

    def __init__(
        self,
        config: PixtralVisionConfig,
        quant_config: QuantizationConfig | None = None,
        *,
        num_hidden_layers_override: int | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        num_hidden_layers = (
            config.num_hidden_layers
            if num_hidden_layers_override is None
            else num_hidden_layers_override
        )

        self.layers = nn.ModuleList(
            [
                PixtralHFTransformerBlock(
                    config=config,
                    quant_config=quant_config,
                    prefix=f"{prefix}.layers.{layer_idx}",
                )
                for layer_idx in range(num_hidden_layers)
            ]
        )

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        return_all_hidden_states: bool,
    ) -> torch.Tensor | list[torch.Tensor]:
        """Run every layer over ``x``.

        Args:
            x: Packed patch features.
            attention_mask: Mask/bias forwarded to every layer.
            position_embeddings: ``(cos, sin)`` rotary pair forwarded to every
                layer.
            return_all_hidden_states: If True, collect intermediate states.

        Returns:
            The last layer's output. If ``return_all_hidden_states`` is set,
            a list is returned instead: the input followed by every layer's
            output, ``len(self.layers) + 1`` entries in total.
        """
        hidden_states_pool = [x]

        for layer in self.layers:
            x = layer(x, attention_mask, position_embeddings)
            if return_all_hidden_states:
                hidden_states_pool.append(x)
        # With multiple feature sample layers, all hidden states are returned
        # in order so the caller can pick the ones it needs by index.
        if return_all_hidden_states:
            return hidden_states_pool
        return x


class PixtralHFVisionModel(nn.Module):
1230
1231
1232
    def __init__(
        self,
        config: PixtralVisionConfig,
        quant_config: QuantizationConfig | None = None,
        *,
        num_hidden_layers_override: int | None = None,
        require_post_norm: bool | None = None,
        prefix: str = "",
    ) -> None:
        """Build the patch convolution, pre-norm, transformer and rotary embedding.

        Raises:
            ValueError: If more layers are requested than ``config`` defines,
                or if ``require_post_norm`` is True (this model has no
                post-layernorm).
        """
        super().__init__()

        self.config = config

        self.patch_conv = Conv2dLayer(
            in_channels=config.num_channels,
            out_channels=config.hidden_size,
            kernel_size=config.patch_size,
            stride=config.patch_size,
            bias=False,
        )
        self.ln_pre = RMSNorm(config.hidden_size, eps=1e-5)
        self.transformer = PixtralHFTransformer(
            config,
            quant_config,
            num_hidden_layers_override=num_hidden_layers_override,
            prefix=f"{prefix}.transformer",
        )

        available_layers = config.num_hidden_layers
        requested_layers = len(self.transformer.layers)
        if requested_layers > available_layers:
            raise ValueError(
                f"The original encoder only has {available_layers} "
                f"layers, but you requested {requested_layers} "
                "layers."
            )

        if require_post_norm is True:
            raise ValueError("PixtralHFVisionModel does not have post-layernorm")

        # dtype/device are captured once, from the first registered parameter.
        first_param = next(self.parameters())
        self.dtype = first_param.dtype
        self.device = first_param.device
        self.patch_positional_embedding = PixtralRotaryEmbedding(config, self.device)
1273
1274
1275

    def forward(
        self,
        pixel_values: list[torch.Tensor],
        *,
        select_layers: list[int] | None = None,
        feature_select_strategy: VisionFeatureSelectStrategy | None = None,
    ) -> tuple[torch.Tensor, ...]:
        """Encode a batch of variable-size images.

        Args:
            pixel_values: One (C, H, W) tensor per image. Batched requests can
                contribute several images, each with its own shape, so this
                is a list rather than a single tensor.
            select_layers: Layer indices whose features should be
                concatenated and used as the visual encoder output. If none
                are provided, the last layer is used.
            feature_select_strategy: Forwarded unchanged to
                ``resolve_visual_encoder_outputs``.

        Returns:
            Tuple with one tensor per input image, in input order, holding
            that image's patch-token features. The feature width is whatever
            ``resolve_visual_encoder_outputs`` produces for the chosen layers.
        """
        # Convolve each image separately because their spatial sizes differ.
        conv_outputs = [
            self.patch_conv(img.unsqueeze(0).to(self.dtype)) for img in pixel_values
        ]
        # Token count per image: h * w of its (1, D, h, w) convolution output.
        seqlens = [out.shape[-2] * out.shape[-1] for out in conv_outputs]

        # (1, D, h, w) -> (1, h * w, D), then concatenate into one sequence.
        patch_embeds = torch.cat(
            [out.flatten(2).permute(0, 2, 1) for out in conv_outputs], dim=1
        )
        patch_embeds = self.ln_pre(patch_embeds)

        grid_width = self.config.image_size // self.config.patch_size
        position_ids = position_ids_in_meshgrid(
            conv_outputs,
            max_width=grid_width,
        ).to(self.device)
        position_embedding = self.patch_positional_embedding(patch_embeds, position_ids)

        # Block-diagonal mask keeps attention within each image.
        if USE_XFORMERS_OPS:
            attention_mask = xops.fmha.attn_bias.BlockDiagonalMask.from_seqlens(
                seqlens,
            )
        else:
            from transformers.models.pixtral.modeling_pixtral import (
                generate_block_attention_mask,
            )

            attention_mask = generate_block_attention_mask(seqlens, patch_embeds)

        hidden = self.transformer(
            patch_embeds,
            attention_mask,
            position_embedding,
            return_all_hidden_states=select_layers is not None,
        )

        features = resolve_visual_encoder_outputs(
            hidden,
            None,
            select_layers=select_layers,
            max_possible_layers=self.config.num_hidden_layers,
            feature_select_strategy=feature_select_strategy,
        )

        # Drop the batch dim and split back into one tensor per image.
        return torch.split(features.squeeze(0), seqlens)
1344
1345
1346

    # (TODO) Add prefix argument for filtering out weights to be loaded
    #        ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986
1347
    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Copy checkpoint tensors into this module's parameters.

        Checkpoint entries named with ``.q_proj``/``.k_proj``/``.v_proj`` or
        ``.gate_proj``/``.up_proj`` are renamed to the fused ``.qkv_proj`` /
        ``.gate_up_proj`` parameters. They are then loaded through that
        parameter's ``weight_loader`` with the matching shard id. Every other
        entry goes through the parameter's ``weight_loader``, or through
        ``default_weight_loader`` when the parameter has none. Entries for
        transformer layers at or beyond ``len(self.transformer.layers)`` are
        skipped; those layers were dropped by ``num_hidden_layers_override``.

        Returns:
            Names (after fused-name remapping) of all parameters that were
            loaded.
        """
        # (fused param suffix, checkpoint suffix, shard id)
        stacked_params_mapping = [
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        kept_layers = len(self.transformer.layers)

        for name, loaded_weight in weights:
            if (
                name.startswith("transformer.layers")
                and int(name.split(".")[2]) >= kept_layers
            ):
                continue

            # The first mapping entry whose checkpoint suffix occurs in the name.
            stacked = next(
                (
                    mapping_entry
                    for mapping_entry in stacked_params_mapping
                    if mapping_entry[1] in name
                ),
                None,
            )
            if stacked is None:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            else:
                param_name, weight_name, shard_id = stacked
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
            loaded_params.add(name)
        return loaded_params