"examples/others/tensorize_vllm_model.py" did not exist on "e656f638de122095bab21a5f5f9c21d2e4974b07"
pixtral.py 48.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import math
5
from collections.abc import Iterable, Mapping, Sequence
Patrick von Platen's avatar
Patrick von Platen committed
6
from dataclasses import dataclass, fields
7
from functools import cached_property
8
from typing import Annotated, Literal, Optional, Union
Patrick von Platen's avatar
Patrick von Platen committed
9
10
11
12

import torch
import torch.nn as nn
import torch.nn.functional as F
13
14
15
from mistral_common.protocol.instruct.messages import (ImageChunk, TextChunk,
                                                       UserMessage)
from mistral_common.protocol.instruct.request import ChatCompletionRequest
16
from mistral_common.tokens.tokenizers.multimodal import ImageEncoder
Patrick von Platen's avatar
Patrick von Platen committed
17
from PIL import Image
18
19
from transformers import PixtralVisionConfig, TensorType
from transformers.image_utils import ImageInput
20
from transformers.models.pixtral.image_processing_pixtral import (
21
    _num_image_tokens as _get_pixtral_hf_num_image_tokens)
22
from transformers.models.pixtral.modeling_pixtral import (
23
    PixtralRotaryEmbedding, apply_rotary_pos_emb, position_ids_in_meshgrid)
24
from transformers.tokenization_utils_base import TextInput
Patrick von Platen's avatar
Patrick von Platen committed
25

26
from vllm.config import VllmConfig
27
from vllm.distributed import divide, get_tensor_model_parallel_world_size
28
from vllm.model_executor.layers.activation import get_act_and_mul_fn
Patrick von Platen's avatar
Patrick von Platen committed
29
from vllm.model_executor.layers.layernorm import RMSNorm
30
31
32
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
Patrick von Platen's avatar
Patrick von Platen committed
33
34
35
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
36
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargsItems
37
38
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                    NestedTensors)
39
40
41
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
                                   MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
42
43
                                        BaseProcessingInfo,
                                        MultiModalProcessingInfo,
44
45
                                        PromptReplacement, PromptUpdate,
                                        PromptUpdateDetails)
46
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
47
from vllm.platforms import current_platform
48
49
50
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import (MistralTokenizer,
                                               cached_tokenizer_from_config)
51
from vllm.utils.tensor_schema import TensorSchema, TensorShape
Patrick von Platen's avatar
Patrick von Platen committed
52

53
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
54
from .utils import (flatten_bn, init_vllm_registered_model, maybe_prefix,
55
                    merge_multimodal_embeddings)
56
from .vision import VisionEncoderInfo, resolve_visual_encoder_outputs
Patrick von Platen's avatar
Patrick von Platen committed
57

58
59
try:
    from xformers import ops as xops
except ImportError:
    # xformers is optional; fall back to torch SDPA in Attention.forward.
    USE_XFORMERS_OPS = False
else:
    # Xformers FA is not compatible with B200
    # (`has_device_capability(100)` on CUDA), so disable it there even
    # though the package imported successfully. Only the import itself is
    # guarded by the `try` above.
    USE_XFORMERS_OPS = not (current_platform.is_cuda()
                            and current_platform.has_device_capability(100))

# Value of `VisionEncoderArgs.mm_projector_id` that enables the
# patch-merger projector.
PATCH_MERGE = "patch_merge"


class PixtralImagePixelInputs(TensorSchema):
    """
    Pixel-value inputs for the Pixtral vision encoder.

    Dimensions:
        - bn: Batch size * number of images
        - c: Number of channels (3)
        - h: Height of each image
        - w: Width of each image

    Each entry is the processed pixel tensor (`ImageEncoding.image`) that
    the `mistral_common` image encoder produced for one image. `h` and `w`
    are dynamic and may differ between images, which is why a list of
    per-image tensors is accepted as well as a single stacked tensor.
    """

    type: Literal["pixel_values"] = "pixel_values"

    images: Annotated[Union[torch.Tensor, list[torch.Tensor]],
                      TensorShape("bn", 3, "h", "w", dynamic_dims={"h", "w"})]


class PixtralProcessorAdapter:
    """
    Provide a HF-compatible interface for
    `mistral_common.tokens.tokenizers.multimodal.ImageEncoder`.

    Wraps a `MistralTokenizer` and exposes the special image token ids and
    encoder settings that the multimodal processor needs.
    """

    def __init__(self, tokenizer: MistralTokenizer) -> None:
        super().__init__()

        # The wrapped tokenizer; its `instruct.mm_encoder` encodes images.
        self.tokenizer = tokenizer

    @property
    def image_processor(self) -> ImageEncoder:
        """The tokenizer's multimodal image encoder."""
        encoder = self.tokenizer.instruct.mm_encoder
        assert isinstance(encoder, ImageEncoder)
        return encoder

    @cached_property
    def image_break_id(self) -> int:
        """Id of the token that closes each row of image tokens."""
        return self.image_processor.special_ids.img_break

    @cached_property
    def image_token_id(self) -> int:
        """Id of the image placeholder token."""
        return self.image_processor.special_ids.img

    @cached_property
    def image_end_id(self) -> int:
        """Id of the token that closes the last row of an image."""
        return self.image_processor.special_ids.img_end

    @cached_property
    def image_size(self) -> int:
        """`max_image_size` from the encoder's multimodal config."""
        return self.image_processor.mm_config.max_image_size

    @cached_property
    def patch_size(self) -> int:
        """`image_patch_size` from the encoder's multimodal config."""
        return self.image_processor.mm_config.image_patch_size

    def __call__(
        self,
        text: Optional[Union[TextInput, list[TextInput]]] = None,
        images: Optional[Union[ImageInput, list[ImageInput]]] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        **kwargs,
    ) -> Mapping[str, NestedTensors]:
        """
        Tokenize text-only inputs, or encode images.

        Without images, `text` is tokenized by the Mistral tokenizer and
        returned as `input_ids`. With images, every text entry must be
        empty. The result then holds the concatenated image tokens of all
        images (repeated once per text entry) under `input_ids`, plus the
        list of per-image pixel tensors under `images`.

        `return_tensors` and extra keyword arguments are accepted for HF
        compatibility and ignored.

        NOTE(review): if `text` is omitted while images are given,
        `input_ids` has zero rows.

        Raises:
            ValueError: if non-empty text is passed together with images.
        """
        texts = [] if text is None else text
        if not isinstance(texts, list):
            texts = [texts]

        image_list = [] if images is None else images
        if not isinstance(image_list, list):
            image_list = [image_list]

        if not image_list:
            token_ids = self.tokenizer(texts).input_ids
            return {"input_ids": torch.tensor(token_ids)}

        # Allow dummy text, which is used for profiling as well as token inputs
        if any(map(len, texts)):
            raise ValueError(
                "You've passed text inputs instead of token inputs. "
                "Make sure to process your input via `mistral_common`'s "
                "tokenizer or pass a chat completion request. "
                "For more info, see: "
                "https://github.com/vllm-project/vllm/issues/8411.")

        pixel_tensors: list[torch.Tensor] = []
        token_tensors: list[torch.Tensor] = []
        for image in image_list:
            encoding = self.image_processor(ImageChunk(image=image))
            pixel_tensors.append(torch.tensor(encoding.image))
            token_tensors.append(torch.tensor(encoding.tokens))

        all_tokens = torch.cat(token_tensors).unsqueeze(0)
        return {
            "input_ids": all_tokens.expand(len(texts), -1),
            "images": pixel_tensors,
        }


class PixtralProcessingInfo(BaseProcessingInfo):
    """Processing info for Pixtral; requires a Mistral-format tokenizer."""

    def get_tokenizer(self) -> MistralTokenizer:
        """
        Return the cached tokenizer for this model.

        Raises:
            ValueError: if the tokenizer is not a `MistralTokenizer`.
        """
        tokenizer = cached_tokenizer_from_config(self.ctx.model_config)
        if isinstance(tokenizer, MistralTokenizer):
            return tokenizer

        raise ValueError("This model requires `--tokenizer-mode mistral`")

    def get_hf_processor(self) -> PixtralProcessorAdapter:
        """Wrap the Mistral tokenizer in a new processor adapter."""
        return PixtralProcessorAdapter(self.get_tokenizer())

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        # Only images are supported; no explicit per-prompt limit is set.
        return {"image": None}

    def get_vision_config(
        self,
        processor: Optional[PixtralProcessorAdapter] = None,
    ):
        """Build a `PixtralVisionConfig` from the encoder's sizes."""
        adapter = self.get_hf_processor() if processor is None else processor

        return PixtralVisionConfig(
            image_size=adapter.image_size,
            patch_size=adapter.patch_size,
        )

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        processor: Optional[PixtralProcessorAdapter] = None,
    ) -> int:
        """
        Count the image placeholder tokens for an image of the given size.

        This is `ncols * nrows` as reported by the encoder for a blank
        image of that size. Row-break and end-of-image tokens are not
        included.
        """
        adapter = self.get_hf_processor() if processor is None else processor

        blank = Image.new("RGB", (image_width, image_height))
        ncols, nrows = adapter.image_processor._image_to_num_tokens(blank)
        return ncols * nrows

    def get_image_size_with_most_features(self) -> ImageSize:
        """Square image whose side is the encoder's `max_image_size`."""
        side = self.get_hf_processor().image_size
        return ImageSize(width=side, height=side)


class PixtralDummyInputsBuilder(BaseDummyInputsBuilder[PixtralProcessingInfo]):
    """Builds dummy prompts with maximum-size images for profiling."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        # No text is needed: the image tokens come from the chat encoding
        # in `get_dummy_processor_inputs`.
        return ""

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> MultiModalDataDict:
        """Return `mm_counts["image"]` (default 0) images of maximum size."""
        image_count = mm_counts.get("image", 0)

        width, height = self.info.get_image_size_with_most_features()
        dummy_images = self._get_dummy_images(width=width,
                                              height=height,
                                              num_images=image_count)
        return {"image": dummy_images}

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        """
        Encode a dummy chat request with the Mistral tokenizer.

        The request holds the dummy text followed by the dummy images. Its
        tokens become the prompt, and truncation is disabled.
        """
        tokenizer = self.info.get_tokenizer()

        text = self.get_dummy_text(mm_counts)
        mm_data = self.get_dummy_mm_data(seq_len, mm_counts)

        content = [TextChunk(text=text)]
        content += [ImageChunk(image=img) for img in mm_data.get("image", [])]

        request = ChatCompletionRequest(
            messages=[UserMessage(content=content)])
        encoded = tokenizer.mistral.encode_chat_completion(request)

        return ProcessorInputs(prompt=encoded.tokens,
                               mm_data=mm_data,
                               tokenization_kwargs={"truncation": False})


class PixtralMultiModalProcessor(
        BaseMultiModalProcessor[PixtralProcessingInfo]):
    """
    Multimodal processor for Pixtral.

    The chat template already inserts the image tokens, so the prompt
    replacement declared here never matches the prompt text.
    """

    def _get_mm_fields_config(
        self,
        hf_inputs: Mapping[str, NestedTensors],
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        # Each entry of the "images" output belongs to one "image" item.
        return {"images": MultiModalFieldConfig.batched("image")}

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)

        row_break = processor.image_break_id
        img_token = processor.image_token_id
        img_end = processor.image_end_id

        def get_replacement(item_idx: int):
            image_items = mm_items.get_items("image", ImageProcessorItems)
            size = image_items.get_image_size(item_idx)

            blank = Image.new("RGB", (size.width, size.height))
            ncols, nrows = processor.image_processor._image_to_num_tokens(
                blank)

            # `nrows` rows of `ncols` image tokens, each row followed by a
            # break token; the final break becomes the end-of-image token.
            row = [img_token] * ncols + [row_break]
            tokens = row * nrows
            tokens[-1] = img_end

            return PromptUpdateDetails.select_token_id(tokens, img_token)

        return [
            PromptReplacement(
                modality="image",
                # Never match the prompt: the image tokens are already
                # inserted by the chat template.
                target="",
                replacement=get_replacement,
            ),
        ]

    def _cached_apply_hf_processor(
        self,
        prompt: Union[str, list[int]],
        mm_data_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object],
        mm_hash_overrides: Optional[dict[str, list[str]]] = None,
    ) -> tuple[list[int], MultiModalProcessingInfo, bool]:
        """
        Delegate to the base implementation, then force the final flag to
        `True`.

        NOTE: The tokens are already inserted by the chat template. The
        flag presumably tells the caller that prompt updates are already
        applied (confirm against `BaseMultiModalProcessor`).
        """
        base_result = super()._cached_apply_hf_processor(
            prompt=prompt,
            mm_data_items=mm_data_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
            tokenization_kwargs=tokenization_kwargs,
            mm_hash_overrides=mm_hash_overrides,
        )
        prompt_ids, mm_info, _ = base_result

        return prompt_ids, mm_info, True


@MULTIMODAL_REGISTRY.register_processor(PixtralMultiModalProcessor,
                                        info=PixtralProcessingInfo,
                                        dummy_inputs=PixtralDummyInputsBuilder)
class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
                                      SupportsPP):
    """
    Pixtral: a vision encoder whose image features are projected into the
    embedding space of a registered vLLM language model.
    """

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
        """Images use no text placeholder; other modalities are rejected."""
        if not modality.startswith("image"):
            raise ValueError("Only image modality is supported")

        return None

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        model_config = vllm_config.model_config
        config = model_config.hf_config
        self.config = config
        self.multimodal_config = model_config.multimodal_config

        # The HF vision config may carry extra keys; keep only the ones
        # that VisionEncoderArgs declares.
        arg_names = {f.name for f in fields(VisionEncoderArgs)}
        vision_config = config.vision_config.to_dict()
        self.vision_args = VisionEncoderArgs(
            **{k: v
               for k, v in vision_config.items() if k in arg_names})
        args = self.vision_args

        # init MistralForCausalLM
        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )

        self.vision_encoder = VisionTransformer(args)

        # Optional projector pieces exist only when the config enables them;
        # load_weights checks the same conditions.
        if args.add_pre_mm_projector_layer_norm:
            self.pre_mm_projector_norm = RMSNorm(args.hidden_size, eps=1e-5)

        if args.mm_projector_id == PATCH_MERGE:
            self.patch_merger = PatchMerger(
                vision_encoder_dim=args.hidden_size,
                spatial_merge_size=args.spatial_merge_size,
                use_mlp_bias=False,
            )

        self.vision_language_adapter = VisionLanguageAdapter(
            args, dim=config.text_config.hidden_size)

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[PixtralImagePixelInputs]:
        """Wrap the `images` kwarg in a schema, or return None if absent."""
        images = kwargs.get("images")
        if images is None:
            return None

        return PixtralImagePixelInputs(
            type="pixel_values",
            images=flatten_bn(images),
        )

    def _process_image_input(
        self,
        image_input: PixtralImagePixelInputs,
    ) -> tuple[torch.Tensor, ...]:
        """Encode images and project them; one embedding tensor per image."""
        images = image_input["images"]
        per_image = self.vision_encoder(images)
        feature_sizes = [feats.shape[0] for feats in per_image]
        features = torch.cat(per_image)

        args = self.vision_args
        if args.add_pre_mm_projector_layer_norm:
            features = self.pre_mm_projector_norm(features)

        if args.mm_projector_id == PATCH_MERGE:
            # After merging, each image's feature count shrinks by a factor
            # of spatial_merge_size**2.
            merge_area = args.spatial_merge_size**2
            patch = args.patch_size
            grid_sizes = [(img.shape[1] // patch, img.shape[2] // patch)
                          for img in images]
            feature_sizes = [size // merge_area for size in feature_sizes]
            features = self.patch_merger(features, image_sizes=grid_sizes)

        image_embeds = self.vision_language_adapter(features)
        return torch.split(image_embeds, feature_sizes)

    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

    def get_multimodal_embeddings(self,
                                  **kwargs: object) -> MultiModalEmbeddings:
        """Return per-image embeddings, or [] when no images were given."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        return ([] if image_input is None else
                self._process_image_input(image_input))

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
    ) -> torch.Tensor:
        """
        Embed `input_ids` with the language model. If image embeddings are
        given, merge them in at the `image_token_id` positions.
        """
        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
        if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
            return inputs_embeds

        return merge_multimodal_embeddings(
            input_ids,
            inputs_embeds,
            multimodal_embeddings,
            self.vision_args.image_token_id,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run forward pass for pixtral."""
        if intermediate_tensors is not None:
            # Intermediate tensors take precedence over embeddings.
            inputs_embeds = None

        # NOTE: In v1, inputs_embeds is always generated at model runner, this
        # condition is for v0 compatibility.
        elif inputs_embeds is None:
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      vision_embeddings)
            input_ids = None

        return self.language_model.model(input_ids,
                                         positions,
                                         intermediate_tensors,
                                         inputs_embeds=inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Delegate logit computation to the language model."""
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """
        Load checkpoint weights in a single pass.

        Weights whose names start with `vision_encoder`, `patch_merger`,
        `pre_mm_projector_norm` or `vision_language_adapter` are copied
        directly into the matching submodule, with the leading module name
        stripped. All other weights are streamed to
        `language_model.load_weights`.
        """
        args = self.vision_args
        uses_patch_merge = args.mm_projector_id == PATCH_MERGE
        uses_pre_norm = args.add_pre_mm_projector_layer_norm

        # Name prefix -> parameters of the submodule it targets, checked in
        # this order. Submodules that were not created get an empty table.
        direct_targets: dict[str, dict[str, nn.Parameter]] = {
            "vision_encoder":
            dict(self.vision_encoder.named_parameters()),
            "patch_merger":
            (dict(self.patch_merger.named_parameters())
             if uses_patch_merge else {}),
            "pre_mm_projector_norm":
            (dict(self.pre_mm_projector_norm.named_parameters())
             if uses_pre_norm else {}),
            "vision_language_adapter":
            dict(self.vision_language_adapter.named_parameters()),
        }

        def llm_weights():
            for name, tensor in weights:
                for prefix, params in direct_targets.items():
                    if name.startswith(prefix):
                        # Submodule-relative key: everything after the first
                        # dot.
                        local_name = name.partition(".")[2]
                        param = params[local_name]
                        with torch.no_grad():
                            default_weight_loader(param, tensor)
                        break
                else:
                    # LLM weights: yield them to be loaded
                    # by language_model.load_weights
                    yield name, tensor

        self.language_model.load_weights(llm_weights())


# Vision encoder
@dataclass
class VisionEncoderArgs:
    """
    Hyper-parameters of the Pixtral vision encoder and its projector.

    Instances are built from the HF vision config, keeping only the keys
    declared here.
    """

    # Transformer dimensions (required, positional order matters).
    hidden_size: int
    num_channels: int
    image_size: int
    patch_size: int
    intermediate_size: int
    num_hidden_layers: int
    num_attention_heads: int
    rope_theta: float  # for rope-2D
    image_token_id: int  # token id whose embeddings are replaced by images

    # Projector options (all optional).
    adapter_bias: bool = True
    spatial_merge_size: int = 1
    add_pre_mm_projector_layer_norm: bool = False
    mm_projector_id: str = ""  # "patch_merge" enables the patch merger


def _reshape_for_broadcast(freqs_cis: torch.Tensor,
                           x: torch.Tensor) -> torch.Tensor:
    """
    freqs_cis: complex - (seq_len, head_dim / 2)
    x: complex - (bsz, seq_len, head_dim / 2)
    """
    ndim = x.ndim
    assert ndim > 1
    assert freqs_cis.shape == (x.shape[1], x.shape[-1]), (
        freqs_cis.shape,
        (x.shape[1], x.shape[-1]),
    )
    shape = [
        d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)
    ]
    return freqs_cis.view(*shape)


def precompute_freqs_cis_2d(
    dim: int,
    height: int,
    width: int,
    theta: float,
) -> torch.Tensor:
    """
    Precompute unit-magnitude 2D rotary phases.

    Returns a complex tensor of shape (height, width, dim // 2), indexed by
    (row, col) patch position. Along the last axis, the leading entries use
    the even-indexed frequency bases scaled by the row index. The trailing
    entries use the odd-indexed bases scaled by the column index.
    """
    # One inverse frequency per pair of channels: theta ** (-2k / dim).
    inv_freq = 1.0 / (theta**(torch.arange(0, dim, 2).float() / dim))

    rows = torch.arange(height, device=inv_freq.device)
    cols = torch.arange(width, device=inv_freq.device)

    row_angles = torch.outer(rows, inv_freq[::2]).float()
    col_angles = torch.outer(cols, inv_freq[1::2]).float()

    # Tile the row angles across columns and the column angles across rows,
    # then concatenate them on the feature axis.
    angle_grid = torch.cat(
        [
            row_angles.unsqueeze(1).repeat(1, width, 1),
            col_angles.unsqueeze(0).repeat(height, 1, 1),
        ],
        dim=-1,
    )
    return torch.polar(torch.ones_like(angle_grid), angle_grid)


def apply_rotary_emb_vit(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Rotate queries and keys by the complex phases in `freqs_cis`.

    Consecutive pairs on the last axis of `xq`/`xk` are read as
    (real, imag) parts, multiplied by `freqs_cis` (broadcast with
    `_reshape_for_broadcast`), and flattened back from dim 3. The math runs
    in float32/complex64; outputs are cast back to the input dtypes.
    """

    def to_complex(t: torch.Tensor) -> torch.Tensor:
        pairs = t.float().reshape(*t.shape[:-1], -1, 2)
        return torch.view_as_complex(pairs)

    q_complex = to_complex(xq)
    k_complex = to_complex(xk)
    assert freqs_cis.dtype == torch.complex64

    phases = _reshape_for_broadcast(freqs_cis, q_complex)
    q_rotated = torch.view_as_real(q_complex * phases).flatten(3)
    k_rotated = torch.view_as_real(k_complex * phases).flatten(3)
    return q_rotated.type_as(xq), k_rotated.type_as(xk)


class FeedForward(nn.Module):
    """Bias-free gated MLP computing `w2(silu(w1(x)) * w3(x))`."""

    def __init__(self, args: VisionEncoderArgs):
        super().__init__()
        assert args.intermediate_size is not None

        model_dim = args.hidden_size
        inner_dim = args.intermediate_size
        self.w1 = nn.Linear(model_dim, inner_dim, bias=False)  # gate
        self.w2 = nn.Linear(inner_dim, model_dim, bias=False)  # output
        self.w3 = nn.Linear(model_dim, inner_dim, bias=False)  # up

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = F.silu(self.w1(x))
        return self.w2(gate * self.w3(x))


class Attention(nn.Module):
    """Multi-head self-attention with rotary embeddings for the ViT."""

    def __init__(self, args: VisionEncoderArgs):
        super().__init__()
        self.args = args
        assert not args.hidden_size % args.num_attention_heads
        self.n_heads = args.num_attention_heads
        self.head_dim = args.hidden_size // args.num_attention_heads

        dim = args.hidden_size
        self.wq = nn.Linear(dim, dim, bias=False)
        self.wk = nn.Linear(dim, dim, bias=False)
        self.wv = nn.Linear(dim, dim, bias=False)
        self.wo = nn.Linear(dim, dim, bias=False)

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor,
        freqs_cis: torch.Tensor,
    ) -> torch.Tensor:
        """
        Attend over `x` of shape (batch, patches, hidden_size).

        `mask` is passed as `attn_bias` to xformers, or as `attn_mask` to
        torch SDPA, depending on `USE_XFORMERS_OPS`.
        """
        batch, patches, _ = x.shape
        per_head = (batch, patches, self.n_heads, self.head_dim)

        q = self.wq(x).reshape(per_head)
        k = self.wk(x).reshape(per_head)
        v = self.wv(x).reshape(per_head)

        q, k = apply_rotary_emb_vit(q, k, freqs_cis=freqs_cis)

        if USE_XFORMERS_OPS:
            # xformers takes (batch, seq, heads, head_dim) directly.
            out = xops.memory_efficient_attention(q, k, v, attn_bias=mask)
        else:
            # SDPA wants heads before the sequence axis; swap there and back.
            out = F.scaled_dot_product_attention(q.transpose(1, 2),
                                                 k.transpose(1, 2),
                                                 v.transpose(1, 2),
                                                 attn_mask=mask)
            out = out.transpose(1, 2)

        merged = out.reshape(batch, patches, self.n_heads * self.head_dim)
        return self.wo(merged)


class TransformerBlock(nn.Module):
    """Pre-norm encoder layer.

    Computes ``h = x + attention(norm(x))`` followed by
    ``out = h + feed_forward(norm(h))``, each with its own RMSNorm.
    """

    def __init__(self, args: VisionEncoderArgs):
        super().__init__()
        self.attention = Attention(args)
        self.feed_forward = FeedForward(args)
        self.attention_norm = RMSNorm(args.hidden_size, eps=1e-5)
        self.ffn_norm = RMSNorm(args.hidden_size, eps=1e-5)

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor,
        freqs_cis: torch.Tensor,
    ) -> torch.Tensor:
        """Apply attention and feed-forward sub-layers with residuals.

        Args:
            x: input features, same layout as ``Attention.forward``.
            mask: attention mask forwarded unchanged to ``Attention``.
            freqs_cis: rotary frequencies forwarded unchanged to ``Attention``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        r = self.attention.forward(self.attention_norm(x),
                                   mask=mask,
                                   freqs_cis=freqs_cis)
        h = x + r
        r = self.feed_forward.forward(self.ffn_norm(h))
        out = h + r
        return out


class Transformer(nn.Module):
    """Stack of ``args.num_hidden_layers`` ``TransformerBlock`` layers."""

    def __init__(self, args: VisionEncoderArgs):
        super().__init__()
        self.layers = torch.nn.ModuleList()
        for _ in range(args.num_hidden_layers):
            self.layers.append(TransformerBlock(args))

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor,
        freqs_cis: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Run ``x`` through every layer in order.

        The same ``mask`` and ``freqs_cis`` are passed to every layer.
        """
        for layer in self.layers:
            x = layer(x, mask=mask, freqs_cis=freqs_cis)
        return x


def position_meshgrid(patch_embeds_list: list[torch.Tensor]) -> torch.Tensor:
    """Return the (row, col) grid coordinates of every patch of every image.

    Args:
        patch_embeds_list: per-image patch tensors; only the last two
            dimensions (grid height, grid width) are read.

    Returns:
        Integer tensor of shape (sum of H*W over all images, 2) holding
        (row, col) pairs in row-major order within each image, with images
        concatenated in input order. An empty list yields a (0, 2) tensor
        instead of letting ``torch.cat`` fail on an empty sequence.
    """
    if not patch_embeds_list:
        return torch.empty((0, 2), dtype=torch.long)

    positions = torch.cat([
        torch.stack(
            torch.meshgrid(
                torch.arange(p.shape[-2]),
                torch.arange(p.shape[-1]),
                indexing="ij",
            ),
            dim=-1,
        ).reshape(-1, 2) for p in patch_embeds_list
    ])
    return positions


class VisionTransformer(nn.Module):
    """Pixtral vision encoder operating on a list of variable-size images.

    Each image is patchified by a strided convolution, all patches are
    concatenated into one sequence and normalized, then run through a
    ``Transformer`` whose block-diagonal mask keeps patches of different
    images from attending to each other. 2D rotary embeddings encode each
    patch's (row, col) position within its own image.
    """

    def __init__(self, args: VisionEncoderArgs):
        super().__init__()
        self.args = args
        self.patch_conv = nn.Conv2d(
            in_channels=args.num_channels,
            out_channels=args.hidden_size,
            kernel_size=args.patch_size,
            stride=args.patch_size,
            bias=False,
        )
        self.ln_pre = RMSNorm(args.hidden_size, eps=1e-5)
        self.transformer = Transformer(args)

        head_dim = self.args.hidden_size // self.args.num_attention_heads
        assert head_dim % 2 == 0, "ROPE requires even head_dim"
        # Rotary table is built lazily on first access of ``freqs_cis``.
        self._freqs_cis: Optional[torch.Tensor] = None

    @property
    def max_patches_per_side(self) -> int:
        """Number of patches along one side of a max-size image."""
        return self.args.image_size // self.args.patch_size

    @property
    def device(self) -> torch.types.Device:
        return next(self.parameters()).device

    @property
    def dtype(self) -> torch.dtype:
        return next(self.parameters()).dtype

    @property
    def freqs_cis(self) -> torch.Tensor:
        """Rotary frequency table covering the maximal patch grid.

        Computed once with ``precompute_freqs_cis_2d`` and cached; the cached
        copy is moved whenever it is not on the module's current device.
        ``forward`` indexes it as ``freqs_cis[row, col]``.
        """
        if self._freqs_cis is None:
            self._freqs_cis = precompute_freqs_cis_2d(
                dim=self.args.hidden_size // self.args.num_attention_heads,
                height=self.max_patches_per_side,
                width=self.max_patches_per_side,
                theta=self.args.rope_theta,
            )

        if self._freqs_cis.device != self.device:
            self._freqs_cis = self._freqs_cis.to(device=self.device)

        return self._freqs_cis

    def forward(
        self,
        images: list[torch.Tensor],
    ) -> tuple[torch.Tensor, ...]:
        """
        Args:
            images: list of N_img images of variable sizes,
                each of shape (C, H, W)
        Returns:
            image_features: tuple with one tensor per image, each of shape
                (n_tokens_i, D) where n_tokens_i is the number of patches of
                image i
        """
        # pass images through initial convolution independently
        patch_embeds_list = [
            self.patch_conv(img.unsqueeze(0).to(self.dtype)) for img in images
        ]

        # (1, D, h, w) -> (1, h * w, D) for each image
        patch_embeds = [
            p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list
        ]
        # embed_sizes[i] == h_i * w_i, the token count of image i
        embed_sizes = [p.shape[1] for p in patch_embeds]

        # flatten to a single sequence
        patch_embeds = torch.cat(patch_embeds, dim=1)
        patch_embeds = self.ln_pre(patch_embeds)

        # positional embeddings
        positions = position_meshgrid(patch_embeds_list).to(self.device)
        freqs_cis = self.freqs_cis[positions[:, 0], positions[:, 1]]

        # pass through Transformer with a block diagonal mask delimiting images
        if USE_XFORMERS_OPS:
            mask = xops.fmha.attn_bias.BlockDiagonalMask.from_seqlens(
                embed_sizes)
        else:
            from transformers.models.pixtral.modeling_pixtral import (
                generate_block_attention_mask)
            mask = generate_block_attention_mask(embed_sizes, patch_embeds)
        out = self.transformer(patch_embeds, mask=mask, freqs_cis=freqs_cis)

        # squeeze dim 0 and split into separate tensors for each image
        return torch.split(out.squeeze(0), embed_sizes)


class VisionLanguageAdapter(nn.Module):
    """Two-layer MLP (Linear -> GELU -> Linear) projecting vision features.

    Maps features of width ``args.hidden_size`` to width ``dim``; both linear
    layers use ``args.adapter_bias`` for their bias setting.
    """

    def __init__(self, args: VisionEncoderArgs, dim: int):
        super().__init__()
        assert isinstance(args, VisionEncoderArgs)
        self.w_in = nn.Linear(
            args.hidden_size,
            dim,
            bias=args.adapter_bias,
        )
        self.gelu = nn.GELU()
        self.w_out = nn.Linear(dim, dim, bias=args.adapter_bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` (last dim ``hidden_size``) to last dim ``dim``."""
        return self.w_out(self.gelu(self.w_in(x)))


class PatchMerger(nn.Module):
    """
    Learned merging of spatial_merge_size ** 2 patches.

    Every non-overlapping (spatial_merge_size x spatial_merge_size) window of
    a per-image token grid is flattened into one vector of width
    ``vision_encoder_dim * spatial_merge_size ** 2`` and projected back to
    ``vision_encoder_dim`` by a single linear layer.
    """

    def __init__(
        self,
        vision_encoder_dim: int,
        spatial_merge_size: int,
        use_mlp_bias: bool = False,
    ) -> None:
        super().__init__()

        mlp_input_dim = vision_encoder_dim * (spatial_merge_size**2)

        self.spatial_merge_size = spatial_merge_size
        self.mlp_input_dim = mlp_input_dim

        self.merging_layer = nn.Linear(
            mlp_input_dim,
            vision_encoder_dim,
            bias=use_mlp_bias,
        )

    def forward(self, x: torch.Tensor,
                image_sizes: list[tuple[int, int]]) -> torch.Tensor:
        """Merge neighbouring tokens of every image.

        Args:
            x: (N, vision_encoder_dim) tokens of all images, concatenated.
            image_sizes: (height, width) of each image, in tokens.

        Returns:
            (N / spatial_merge_size ** 2, vision_encoder_dim) merged tokens.
        """
        # image_sizes specified in tokens
        assert sum(h * w for h, w in image_sizes) == len(x)

        # x is (N, vision_encoder_dim)
        x = self.permute(x, image_sizes)

        # x is (N / spatial_merge_size ** 2,
        #       vision_encoder_dim * spatial_merge_size ** 2)
        x = self.merging_layer(x)

        # x is (N / spatial_merge_size ** 2, vision_encoder_dim)
        return x

    def permute(
        self,
        x: torch.Tensor,
        image_sizes: list[tuple[int, int]],
    ) -> torch.Tensor:
        """
        Args:
            x: (N, D) where N is flattened and concatenated patch tokens
                for all images
            image_sizes: list of tuple of (height, width) in tokens for
                each image
        Returns:
            image_features: reorders patch tokens so each grid of
                (spatial_merge_size, spatial_merge_size) is contiguous.
                now (N / spatial_merge_size ** 2, D * spatial_merge_size ** 2)
        """

        sub_grids = get_sub_grids(
            x=x,
            image_sizes=image_sizes,
            spatial_merge_size=self.spatial_merge_size
        )  # list of [d x sub_grid_size x sub_grid_size x n_patches]
        permuted_tensor: list[torch.Tensor] = []
        for grid in sub_grids:
            n_patches = grid.shape[-1]
            permuted_tensor.append(grid.view(-1, n_patches).t(
            ))  # n_patches x d * sub_grid_size * sub_grid_size
        return torch.cat(
            permuted_tensor, dim=0
        )  # (N / spatial_merge_size ** 2, d * spatial_merge_size ** 2)


def get_sub_grids(
    x: torch.Tensor,
    image_sizes: list[tuple[int, int]],
    spatial_merge_size: int,
) -> list[torch.Tensor]:
    """Cut each image's token grid into non-overlapping square windows.

    Args:
        x: (N, d) patch tokens of all images, concatenated; image ``i``
            contributes ``h_i * w_i`` consecutive rows.
        image_sizes: (h, w) of each image, measured in tokens.
        spatial_merge_size: side length ``s`` of each square window.

    Returns:
        One tensor per image of shape (d, s, s, n_windows); slice
        ``[..., j]`` is the j-th window, windows in row-major order.
    """
    dim = x.shape[-1]
    side = spatial_merge_size
    counts = [height * width for height, width in image_sizes]

    result: list[torch.Tensor] = []
    for (height, width), chunk in zip(image_sizes, x.split(counts)):
        # (h*w, d) -> (1, d, h, w): unfold sees d channels over an h x w map
        grid = chunk.view(height, width, dim).permute(2, 0, 1).unsqueeze(0)
        windows = torch.nn.functional.unfold(grid,
                                             kernel_size=side,
                                             stride=side)
        # unfold yields (1, d*s*s, n_windows); expose the window axes
        result.append(windows.view(dim, side, side, -1))
    return result


#### HF Transformers version of Pixtral ####
# Based on https://github.com/huggingface/transformers/blob/d7950bff82b18c823193d17d72188c5e46d06c83/src/transformers/models/pixtral/modeling_pixtral.py
# This model follows the Llava family: image embeddings replace the
# `[IMG]` token placeholders.
# The model uses [`PixtralVisionModel`] for its vision encoder,
# and [`MistralForCausalLM`] for its language decoder.


class PixtralHFEncoderInfo(VisionEncoderInfo[PixtralVisionConfig]):
    """Image-size and token-count geometry for the HF Pixtral encoder."""

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Return the number of image tokens (``ncols * nrows``)."""
        ncols, nrows = self.get_patch_grid_size(
            image_width=image_width,
            image_height=image_height,
        )
        return ncols * nrows

    def get_image_size(self) -> int:
        """Return the maximal image side length from the vision config."""
        return self.vision_config.image_size

    def get_patch_size(self) -> int:
        """Return the pixel side length covered by one output token.

        This is the encoder patch size multiplied by ``spatial_merge_size``
        from the HF config, or by 1 when the config has no such field.
        """
        # spatial_merge_size is needed for Mistral3
        spatial_merge_size = getattr(self.hf_config, "spatial_merge_size", 1)
        return self.vision_config.patch_size * spatial_merge_size

    def get_patch_grid_length(self) -> int:
        """Return the number of tokens along one side of a max-size image."""
        image_size, patch_size = self.get_image_size(), self.get_patch_size()

        # Since interpolation is applied, the image size need not be divisible
        # assert image_size % patch_size == 0
        return image_size // patch_size

    # Adapted from: https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/pixtral/image_processing_pixtral.py#L99
    def get_patch_grid_size(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> tuple[int, int]:
        """Return ``(ncols, nrows)`` of the token grid for an image.

        Images larger than ``get_image_size()`` on either side are first
        scaled down (aspect ratio preserved, sides floored) so both sides fit.
        """
        max_width = max_height = self.get_image_size()
        patch_width = patch_height = self.get_patch_size()

        ratio = max(image_width / max_width, image_height / max_height)

        if ratio > 1:
            image_width = int(math.floor(image_width / ratio))
            image_height = int(math.floor(image_height / ratio))

        nrows, ncols = _get_pixtral_hf_num_image_tokens(
            (image_height, image_width),
            (patch_height, patch_width),
        )  # type: ignore

        return ncols, nrows


class PixtralHFMLP(nn.Module):
    """Tensor-parallel gated feed-forward block of the HF Pixtral encoder.

    ``gate_up_proj`` fuses the gate and up projections (each of width
    ``intermediate_size``). ``act_and_mul`` combines the two halves using
    ``config.hidden_act``. ``down_proj`` maps the result back to
    ``hidden_size``.
    """

    def __init__(
        self,
        config: PixtralVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        *,
        prefix: str = "",
    ) -> None:
        super().__init__()

        assert config.intermediate_size is not None
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=config.hidden_size,
            output_sizes=[config.intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj")
        self.down_proj = RowParallelLinear(input_size=config.intermediate_size,
                                           output_size=config.hidden_size,
                                           bias=False,
                                           quant_config=quant_config,
                                           prefix=f"{prefix}.down_proj")
        self.act_and_mul = get_act_and_mul_fn(config.hidden_act)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated MLP; the last dim goes hidden -> hidden."""
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_and_mul(gate_up)
        x, _ = self.down_proj(x)
        return x


class PixtralHFAttention(nn.Module):
    """Tensor-parallel multi-head self-attention for the HF Pixtral encoder.

    Q, K and V come from one fused ``QKVParallelLinear``. Each
    tensor-parallel rank holds ``num_attention_heads / tp_size`` heads. HF's
    rotary position embedding is applied to Q and K before attention.
    """

    def __init__(
        self,
        config: PixtralVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        *,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config
        assert not config.hidden_size % config.num_attention_heads
        self.total_num_heads = config.num_attention_heads
        tp_size = get_tensor_model_parallel_world_size()
        # Number of heads held by this tensor-parallel rank.
        self.n_heads = divide(config.num_attention_heads, tp_size)
        self.head_dim = config.hidden_size // config.num_attention_heads

        self.qkv_proj = QKVParallelLinear(
            hidden_size=config.hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        assert self.total_num_heads * self.head_dim == config.hidden_size
        self.o_proj = RowParallelLinear(
            input_size=config.hidden_size,
            output_size=config.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Args:
            hidden_states: (batch, patches, hidden_size) input features.
            attention_mask: xformers attention bias when ``USE_XFORMERS_OPS``
                is set, otherwise an ``attn_mask`` for
                ``scaled_dot_product_attention``.
            position_embeddings: ``(cos, sin)`` tables for HF's
                ``apply_rotary_pos_emb``.

        Returns:
            ``(attn_output, None)``; no attention weights are returned.
        """
        batch, patches, _ = hidden_states.size()

        qkv_states, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv_states.chunk(3, dim=-1)

        # Transpose q and k to apply HF's Rotary Position Embedding
        q = q.view(batch, patches, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch, patches, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch, patches, self.n_heads, self.head_dim)
        cos, sin = position_embeddings
        q, k = apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=0)

        if USE_XFORMERS_OPS:
            # Transpose q and k back for attention
            q = q.transpose(1, 2).contiguous()
            k = k.transpose(1, 2).contiguous()
            out = xops.memory_efficient_attention(q,
                                                  k,
                                                  v,
                                                  attn_bias=attention_mask)
        else:
            v = v.transpose(1, 2)
            out = nn.functional.scaled_dot_product_attention(
                q, k, v, attn_mask=attention_mask)
            out = out.transpose(1, 2)

        # reshape rather than view: after the SDPA branch's transpose the
        # tensor is not guaranteed to be contiguous, where view would raise.
        out = out.reshape(batch, patches, self.n_heads * self.head_dim)
        attn_output, _ = self.o_proj(out)

        return attn_output, None


class PixtralHFTransformerBlock(nn.Module):
    """Pre-norm HF Pixtral encoder layer.

    Computes ``h = x + attention(norm(x))`` followed by
    ``out = h + feed_forward(norm(h))``, each with its own RMSNorm.
    """

    def __init__(
        self,
        config: PixtralVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        *,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.attention_norm = RMSNorm(config.hidden_size, eps=1e-5)
        self.attention = PixtralHFAttention(config,
                                            quant_config=quant_config,
                                            prefix=f"{prefix}.attention")
        self.feed_forward = PixtralHFMLP(config,
                                         quant_config=quant_config,
                                         prefix=f"{prefix}.feed_forward")
        self.ffn_norm = RMSNorm(config.hidden_size, eps=1e-5)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
    ) -> torch.Tensor:
        """Apply attention and MLP sub-layers with residual connections.

        ``attention_mask`` and the ``(cos, sin)`` ``position_embeddings`` are
        forwarded unchanged to ``PixtralHFAttention``.
        """
        r, _ = self.attention.forward(self.attention_norm(hidden_states),
                                      attention_mask=attention_mask,
                                      position_embeddings=position_embeddings)
        h = hidden_states + r
        r = self.feed_forward.forward(self.ffn_norm(h))
        out = h + r
        return out


class PixtralHFTransformer(nn.Module):
    """Stack of ``PixtralHFTransformerBlock`` layers.

    The stack depth is ``config.num_hidden_layers``, or
    ``num_hidden_layers_override`` when that is given.
    """

    def __init__(
        self,
        config: PixtralVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        *,
        num_hidden_layers_override: Optional[int] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        if num_hidden_layers_override is None:
            num_hidden_layers = config.num_hidden_layers
        else:
            num_hidden_layers = num_hidden_layers_override

        self.layers = nn.ModuleList([
            PixtralHFTransformerBlock(config=config,
                                      quant_config=quant_config,
                                      prefix=f"{prefix}.layers.{layer_idx}")
            for layer_idx in range(num_hidden_layers)
        ])

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        return_all_hidden_states: bool,
    ) -> Union[torch.Tensor, list[torch.Tensor]]:
        """Run ``x`` through every layer.

        Returns:
            The final hidden states. When ``return_all_hidden_states`` is
            True, returns a list instead: the input followed by the output of
            every layer (``len(self.layers) + 1`` tensors).
        """
        hidden_states_pool = [x]

        for layer in self.layers:
            x = layer(x, attention_mask, position_embeddings)
            if return_all_hidden_states:
                hidden_states_pool.append(x)
        # If we have multiple feature sample layers, we return all hidden
        # states in order and grab the ones we need by index.
        if return_all_hidden_states:
            return hidden_states_pool
        return x


class PixtralHFVisionModel(nn.Module):
    """HF-format Pixtral vision encoder.

    Steps:
    - Patchify each image with a strided convolution.
    - Concatenate all patches into one sequence and normalize it.
    - Run the sequence through ``PixtralHFTransformer`` with a block-diagonal
      mask so patches attend only within their own image.
    - Split the result back into per-image tensors.
    """

    def __init__(
        self,
        config: PixtralVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        *,
        num_hidden_layers_override: Optional[int] = None,
        require_post_norm: Optional[bool] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config

        self.patch_conv = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=config.hidden_size,
            kernel_size=config.patch_size,
            stride=config.patch_size,
            bias=False,
        )
        self.ln_pre = RMSNorm(config.hidden_size, eps=1e-5)
        self.transformer = PixtralHFTransformer(
            config,
            quant_config,
            num_hidden_layers_override=num_hidden_layers_override,
            prefix=f"{prefix}.transformer",
        )

        num_hidden_layers = config.num_hidden_layers
        if len(self.transformer.layers) > config.num_hidden_layers:
            raise ValueError(
                f"The original encoder only has {num_hidden_layers} "
                f"layers, but you requested {len(self.transformer.layers)} "
                "layers.")

        if require_post_norm is True:
            msg = "PixtralHFVisionModel does not have post-layernorm"
            raise ValueError(msg)

        # Captured once from the freshly built parameters.
        self.dtype = next(self.parameters()).dtype
        self.device = next(self.parameters()).device
        self.patch_positional_embedding = PixtralRotaryEmbedding(
            config, self.device)

    def forward(
        self,
        pixel_values: list[torch.Tensor],
        feature_sample_layers: Optional[list[int]] = None,
    ) -> tuple[torch.Tensor, ...]:
        """
        Args:
            pixel_values: Each image to be processed will be a separate tensor
                in pixel_values. This means it will be a list of tensors
                because multiple requests batched can have multiple images,
                each with their own shape potentially
            feature_sample_layers: Layer indices whose features should be
                concatenated and used as the visual encoder output. If none
                are provided, the last layer is used.

        Returns:
            image_features: tuple with one tensor per image holding the token
                features of that image's patches (one row per patch)
        """
        # pass images through initial convolution independently
        patch_embeds_list = [
            self.patch_conv(img.unsqueeze(0).to(self.dtype))
            for img in pixel_values
        ]

        # (1, D, h, w) -> (1, h * w, D) for each image
        patch_embeds = [
            p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list
        ]
        # embed_sizes[i] == h_i * w_i, the token count of image i
        embed_sizes = [p.shape[1] for p in patch_embeds]

        # flatten to a single sequence
        patch_embeds = torch.cat(patch_embeds, dim=1)
        patch_embeds = self.ln_pre(patch_embeds)

        # positional embeddings
        position_ids = position_ids_in_meshgrid(
            patch_embeds_list,
            max_width=self.config.image_size // self.config.patch_size).to(
                self.device)
        position_embedding = self.patch_positional_embedding(
            patch_embeds, position_ids)

        # block-diagonal mask: patches only attend within their own image
        if USE_XFORMERS_OPS:
            attention_mask = xops.fmha.attn_bias.BlockDiagonalMask.from_seqlens(
                embed_sizes)
        else:
            from transformers.models.pixtral.modeling_pixtral import (
                generate_block_attention_mask)
            attention_mask = generate_block_attention_mask(
                embed_sizes, patch_embeds)

        return_all_hidden_states = feature_sample_layers is not None
        out = self.transformer(
            patch_embeds,
            attention_mask,
            position_embedding,
            return_all_hidden_states=return_all_hidden_states)

        out = resolve_visual_encoder_outputs(out, feature_sample_layers, None,
                                             self.config.num_hidden_layers)

        # squeeze dim 0 and split into separate tensors for each image
        return torch.split(out.squeeze(0), embed_sizes)

    # (TODO) Add prefix argument for filtering out weights to be loaded
    #        ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986
    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint weights into this module.

        Separate q/k/v and gate/up checkpoint tensors are routed into the
        fused ``qkv_proj`` / ``gate_up_proj`` parameters with their shard ids.
        Weights of transformer layers at index ``len(self.transformer.layers)``
        or above are skipped. This happens when ``num_hidden_layers_override``
        shortened the encoder.

        Returns:
            Names of the parameters that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        layer_count = len(self.transformer.layers)

        for name, loaded_weight in weights:
            # omit layers when num_hidden_layers_override is set
            if name.startswith("transformer.layers"):
                layer_idx = int(name.split(".")[2])
                if layer_idx >= layer_count:
                    continue

            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # not a fused shard: load the tensor directly
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params