pixtral.py 41.9 KB
Newer Older
Patrick von Platen's avatar
Patrick von Platen committed
1
from dataclasses import dataclass, fields
2
from functools import cached_property
Patrick von Platen's avatar
Patrick von Platen committed
3
from itertools import tee
4
from typing import Iterable, List, Mapping, Optional, Set, Tuple, Union
Patrick von Platen's avatar
Patrick von Platen committed
5

6
import numpy
Patrick von Platen's avatar
Patrick von Platen committed
7
8
9
10
11
import torch
import torch.nn as nn
import torch.nn.functional as F
from mistral_common.protocol.instruct.messages import ImageChunk
from PIL import Image
12
from transformers import PixtralVisionConfig
13
14
15
from transformers.models.pixtral.image_processing_pixtral import (
    _num_image_tokens)
from transformers.models.pixtral.modeling_pixtral import (
16
    PixtralRotaryEmbedding, apply_rotary_pos_emb, position_ids_in_meshgrid)
Patrick von Platen's avatar
Patrick von Platen committed
17
18

from vllm.attention import AttentionMetadata
19
from vllm.config import ModelConfig, VllmConfig
20
from vllm.distributed import divide, get_tensor_model_parallel_world_size
21
22
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
                         InputContext, token_inputs)
23
from vllm.model_executor.layers.activation import get_act_and_mul_fn
Patrick von Platen's avatar
Patrick von Platen committed
24
from vllm.model_executor.layers.layernorm import RMSNorm
25
26
27
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
Patrick von Platen's avatar
Patrick von Platen committed
28
from vllm.model_executor.layers.quantization import QuantizationConfig
Joe Runde's avatar
Joe Runde committed
29
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
Patrick von Platen's avatar
Patrick von Platen committed
30
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
31
from vllm.model_executor.models.utils import merge_multimodal_embeddings
Patrick von Platen's avatar
Patrick von Platen committed
32
from vllm.model_executor.sampling_metadata import SamplingMetadata
33
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
34
from vllm.multimodal.inputs import NestedTensors, PlaceholderRange
35
from vllm.multimodal.utils import (cached_get_tokenizer,
36
37
                                   consecutive_placeholder_ranges,
                                   resolve_visual_encoder_outputs)
38
from vllm.sequence import IntermediateTensors, SequenceData
39
40
from vllm.transformers_utils.processor import cached_get_processor
from vllm.utils import is_list_of
Patrick von Platen's avatar
Patrick von Platen committed
41

42
from .interfaces import SupportsMultiModal, SupportsPP
43
from .utils import init_vllm_registered_model, maybe_prefix
Patrick von Platen's avatar
Patrick von Platen committed
44

45
46
47
48
49
50
try:
    from xformers import ops as xops
except ImportError:
    # xformers is optional at import time; the Mistral-format vision encoder
    # raises an ImportError later if it is actually used without xformers.
    USE_XFORMERS_OPS = False
else:
    USE_XFORMERS_OPS = True

Patrick von Platen's avatar
Patrick von Platen committed
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69

def get_max_pixtral_image_tokens(ctx: InputContext):
    """Upper bound on the image tokens a single image can contribute.

    Read from the Mistral tokenizer's multimodal encoder config: the number
    of ``image_patch_size`` patches that fit along ``max_image_size``,
    squared.
    """
    model_config = ctx.model_config
    tokenizer = cached_get_tokenizer(
        model_config.tokenizer, tokenizer_mode=model_config.tokenizer_mode)
    encoder_config = tokenizer.instruct.mm_encoder.mm_config

    max_side = encoder_config.max_image_size
    patch_side = encoder_config.image_patch_size
    patches_per_side = max_side // patch_side
    return patches_per_side**2


def dummy_data_for_pixtral(ctx: InputContext, seq_len: int,
                           mm_counts: Mapping[str, int]):
    """Build dummy profiling data: image placeholder tokens plus blank images.

    Args:
        ctx: Context of the loaded model.
        seq_len: Total length of the dummy prompt.
        mm_counts: Per-modality item counts. Not read here; the number of
            images comes from the model's ``limit_per_prompt`` instead.

    Returns:
        DummyData whose prompt starts with ``num_images`` runs of the image
        token id (one token per patch of a 256x256 dummy image), padded with
        token id 0 up to ``seq_len``, together with the matching images and
        placeholder ranges.
    """
    tokenizer = cached_get_tokenizer(
        ctx.model_config.tokenizer,
        tokenizer_mode=ctx.model_config.tokenizer_mode)

    mm_encoder = tokenizer.mistral.instruct_tokenizer.mm_encoder
    patch_size = mm_encoder.mm_config.image_patch_size
    image_token_id = mm_encoder.special_ids.img

    mm_config = ctx.model_config.multimodal_config
    num_images = mm_config.limit_per_prompt.get("image", 1)

    # dummy size: a small, square, all-black image
    size = 256
    image = Image.new("RGB", (size, size), color=0)

    # One placeholder token per patch of the dummy image.
    image_feature_size = (size**2) // (patch_size**2)
    num_image_tokens = image_feature_size * num_images

    seq_data = SequenceData.from_prompt_token_counts(
        (image_token_id, num_image_tokens),
        (0, seq_len - num_image_tokens),
    )

    mm_data = {"image": num_images * [image]}
    mm_placeholders = {
        "image":
        consecutive_placeholder_ranges(num_items=num_images,
                                       item_size=image_feature_size)
    }
    return DummyData(seq_data, mm_data, mm_placeholders)
Patrick von Platen's avatar
Patrick von Platen committed
96
97
98


def input_mapper_for_pixtral(ctx: InputContext,
                             data: object) -> MultiModalKwargs:
    """Maps the input image data to its MultiModalKwargs.

    Args:
        ctx: Context of the loaded model.
        data: A single image or a list of images. Each one is wrapped in a
            mistral_common ``ImageChunk`` and run through the Mistral
            tokenizer's multimodal encoder.

    Returns:
        MultiModalKwargs whose ``"images"`` entry is a list of float16
        tensors, one per input image, built from each encoding's ``image``
        array.
    """
    model_config = ctx.model_config
    tokenizer = cached_get_tokenizer(
        model_config.tokenizer, tokenizer_mode=model_config.tokenizer_mode)
    mm_encoder = tokenizer.instruct.mm_encoder

    data_list = data if isinstance(data, list) else [data]

    images = []
    for image_data in data_list:
        encoding = mm_encoder(ImageChunk(image=image_data))
        # Do not hard-code device="cuda": that breaks non-CUDA platforms.
        # NOTE(review): this relies on the model runner moving multimodal
        # kwargs to the execution device (MultiModalKwargs.as_kwargs(...,
        # device=...)) before forward(); confirm for every runner in use.
        pixel_values = torch.from_numpy(encoding.image).to(
            dtype=torch.float16)
        images.append(pixel_values)

    return MultiModalKwargs({"images": images})
Patrick von Platen's avatar
Patrick von Platen committed
127
128


129
130
def input_processor_for_pixtral(ctx: InputContext, inputs: DecoderOnlyInputs):
    """Check that image inputs are accompanied by image placeholder tokens.

    When ``inputs`` carries ``multi_modal_data["image"]``, its
    ``prompt_token_ids`` must already contain the multimodal encoder's image
    token id (``mm_encoder.special_ids.img``); otherwise a ValueError is
    raised. ``inputs`` is returned unchanged.

    Raises:
        ValueError: If images are given but no image token id is present.
    """
    multi_modal_data = inputs.get("multi_modal_data")
    if multi_modal_data is not None and "image" in multi_modal_data:
        tokenizer = cached_get_tokenizer(
            ctx.model_config.tokenizer,
            tokenizer_mode=ctx.model_config.tokenizer_mode)

        mm_encoder = tokenizer.mistral.instruct_tokenizer.mm_encoder
        image_token_id = mm_encoder.special_ids.img

        if image_token_id not in inputs['prompt_token_ids']:
            raise ValueError(
                f"You've passed {inputs=} without {image_token_id=}."
                " Make sure to process your input via mistral_common's"
                " tokenizer or pass a chat completion request."
                " For more info, see: "
                "https://github.com/vllm-project/vllm/issues/8411.")

    return inputs
Patrick von Platen's avatar
Patrick von Platen committed
148
149
150
151
152


@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_pixtral)
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_pixtral_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_pixtral)
@INPUT_REGISTRY.register_input_processor(input_processor_for_pixtral)
class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
                                      SupportsPP):
    """Pixtral for the Mistral checkpoint format.

    Images are encoded by ``VisionTransformer`` and projected to the text
    hidden size by ``VisionLanguageAdapter``. The resulting embeddings are
    merged into the language model's input embeddings at the positions
    holding ``vision_args.image_token_id``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.config = config
        self.multimodal_config = multimodal_config

        # Forward only the vision_config keys that VisionEncoderArgs
        # declares; any other HF config entries are dropped.
        dataclass_fields = {field.name for field in fields(VisionEncoderArgs)}
        vision_args = {
            key: value
            for key, value in self.config.vision_config.to_dict().items()
            if key in dataclass_fields
        }
        self.vision_args = VisionEncoderArgs(**vision_args)

        # init MistralForCausalLM
        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )

        self.vision_encoder = VisionTransformer(self.vision_args)
        self.vision_language_adapter = VisionLanguageAdapter(
            self.vision_args, dim=config.text_config.hidden_size)

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    @cached_property
    def sampler(self):
        """The language model's sampler if it has one, else vLLM's default."""
        if hasattr(self.language_model, "sampler"):
            return self.language_model.sampler

        return get_sampler()

    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
        """Encode the ``images`` kwarg, if present, into text-space embeddings.

        Returns None when no images were passed.
        """
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return None
        vision_embeddings = self._process_image_input(image_input)
        return vision_embeddings

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[NestedTensors] = None,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and overwrite image-token slots with images.

        Slots whose id equals ``vision_args.image_token_id`` are filled from
        ``multimodal_embeddings`` when those are given.
        """
        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
        if multimodal_embeddings is not None:
            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, multimodal_embeddings,
                self.vision_args.image_token_id)
        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run forward pass for pixtral.

        Args:
            input_ids: Token ids; replaced by embeddings (and set to None)
                when this method builds ``inputs_embeds`` itself.
            positions: Token positions, forwarded to the language model.
            kv_caches: KV caches, forwarded to the language model.
            attn_metadata: Attention metadata, forwarded to the language
                model.
            intermediate_tensors: Intermediate tensors (pipeline parallel,
                per ``SupportsPP``); when given, ``inputs_embeds`` is
                discarded.
            inputs_embeds: Precomputed input embeddings (the V1 model
                runner always supplies these).
            **kwargs: Multimodal kwargs (``images``) used to build the
                embeddings when ``inputs_embeds`` is None.

        Returns:
            Whatever ``language_model.model`` returns.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        # NOTE: In v1, inputs_embeds is always generated at model runner, this
        # condition is for v0 compatibility.
        elif inputs_embeds is None:
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      vision_embeddings)
            input_ids = None

        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  kv_caches,
                                                  attn_metadata,
                                                  intermediate_tensors,
                                                  inputs_embeds=inputs_embeds)

        return hidden_states

    def _parse_and_validate_image_input(
        self,
        images: Optional[Union[List[List[torch.Tensor]], List[torch.Tensor],
                               torch.Tensor]] = None
    ) -> Optional[List[torch.Tensor]]:
        """Flatten the batched ``images`` kwarg into a flat list of images.

        Accepts:
          - None (no images), returned as-is;
          - a 5-D tensor ``(N, B, C, H, W)`` of equally sized images;
          - a list whose items are either tensors (split along dim 0) or
            lists of per-image tensors.
        """
        if images is None:
            return None

        if isinstance(images, torch.Tensor):
            # if passed as batch take all images
            N, B, C, H, W = images.shape
            images = images.reshape(N * B, C, H, W)
            images = [images[i] for i in range(images.size(0))]
        elif isinstance(images, list):
            # if passed as list flatten lists of tensors
            flatten_images = []
            for imgs_per_req in images:
                imgs_per_req = [
                    imgs_per_req[i] for i in range(imgs_per_req.size(0))
                ] if isinstance(imgs_per_req, torch.Tensor) else imgs_per_req

                flatten_images.extend(imgs_per_req)

            images = flatten_images

        return images

    def _process_image_input(self,
                             image_input: List[torch.Tensor]) -> torch.Tensor:
        """Encode images and project the features to the text hidden size."""
        return self.vision_language_adapter(self.vision_encoder(image_input))

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Delegate logit computation to the language model."""
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Delegate sampling to the language model."""
        return self.language_model.sample(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Route checkpoint weights to the language model and vision parts.

        Names starting with ``vision_encoder`` or ``vision_language_adapter``
        are loaded into that submodule, with their first dotted component
        stripped. Every other weight goes to ``language_model.load_weights``.
        """

        def is_vision_encoder_weights(weight: Tuple[str, torch.Tensor]):
            return weight[0].startswith("vision_encoder")

        def is_vision_lang_adapter_weights(weight: Tuple[str, torch.Tensor]):
            return weight[0].startswith("vision_language_adapter")

        def is_vision_weights(weight: Tuple[str, torch.Tensor]):
            return is_vision_encoder_weights(
                weight) or is_vision_lang_adapter_weights(weight)

        # ``weights`` is only an Iterable (possibly a one-shot generator), so
        # split it into three independent iterators instead of re-iterating.
        llm_weights, vision_encoder_weights, vision_lang_adapter_weights = tee(
            weights, 3)

        # llm
        llm_weights = filter(lambda x: not is_vision_weights(x), llm_weights)
        self.language_model.load_weights(llm_weights)

        # vision encoder
        vision_encoder_weights = filter(is_vision_encoder_weights,
                                        vision_encoder_weights)
        vision_encoder_dict = dict(self.vision_encoder.named_parameters())
        for name, loaded_weight in vision_encoder_weights:
            # cut 'vision_encoder.'
            name = name.partition(".")[2]
            param = vision_encoder_dict[name]
            default_weight_loader(param, loaded_weight)

        # adapter
        vision_lang_adapter_weights = filter(is_vision_lang_adapter_weights,
                                             vision_lang_adapter_weights)
        vision_lang_adapter_dict = dict(
            self.vision_language_adapter.named_parameters())
        for name, loaded_weight in vision_lang_adapter_weights:
            # cut 'vision_language_adapter.'
            name = name.partition(".")[2]
            param = vision_lang_adapter_dict[name]
            default_weight_loader(param, loaded_weight)


# Vision encoder
@dataclass
class VisionEncoderArgs:
    """Hyper-parameters of the Pixtral vision encoder and adapter.

    Built from the HF ``vision_config``; only keys matching these field
    names are passed in.
    """

    hidden_size: int
    num_channels: int  # input image channels (patch_conv in_channels)
    image_size: int
    patch_size: int  # patch_conv kernel size and stride
    intermediate_size: int  # FeedForward hidden width
    num_hidden_layers: int
    num_attention_heads: int
    rope_theta: float  # for rope-2D
    image_token_id: int  # token id whose embeddings are replaced by images
    adapter_bias: bool = True  # whether the adapter's Linear layers use bias
Patrick von Platen's avatar
Patrick von Platen committed
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443


def _reshape_for_broadcast(freqs_cis: torch.Tensor,
                           x: torch.Tensor) -> torch.Tensor:
    """View ``freqs_cis`` so that it broadcasts against ``x``.

    ``freqs_cis`` must have shape ``(x.shape[1], x.shape[-1])``. The result
    keeps those sizes at dims 1 and -1 and has size 1 at every other dim of
    ``x``. In this file ``x`` is the complex view of q built by
    ``apply_rotary_emb_vit``, i.e. ``(batch, patches, heads, head_dim / 2)``.
    """
    rank = x.ndim
    assert rank > 1
    expected = (x.shape[1], x.shape[-1])
    assert freqs_cis.shape == expected, (freqs_cis.shape, expected)

    broadcast_shape = [1] * rank
    broadcast_shape[1] = x.shape[1]
    broadcast_shape[-1] = x.shape[-1]
    return freqs_cis.view(*broadcast_shape)


def precompute_freqs_cis_2d(
    dim: int,
    height: int,
    width: int,
    theta: float,
) -> torch.Tensor:
    """Build the 2D rotary table for a ``height x width`` patch grid.

    The ``dim // 2`` base frequencies ``1 / theta**(2i / dim)`` are split in
    alternation. Even-indexed ones are multiplied by the row index and
    odd-indexed ones by the column index. The two halves are concatenated
    on the last axis and turned into unit complex numbers.

    Returns:
        complex64 tensor of shape ``(height, width, dim // 2)`` (for ``dim``
        divisible by 4), indexed by ``(row, col)`` patch positions.
    """
    exponents = torch.arange(0, dim, 2).float() / dim
    base_freqs = 1.0 / (theta**exponents)

    rows = torch.arange(height, device=base_freqs.device)
    cols = torch.arange(width, device=base_freqs.device)

    row_angles = torch.outer(rows, base_freqs[::2]).float()
    col_angles = torch.outer(cols, base_freqs[1::2]).float()

    # Broadcast each 1D angle table over the other grid axis.
    grid_angles = torch.cat(
        (
            row_angles.unsqueeze(1).expand(height, width, -1),
            col_angles.unsqueeze(0).expand(height, width, -1),
        ),
        dim=-1,
    )
    return torch.polar(torch.ones_like(grid_angles), grid_angles)


def apply_rotary_emb_vit(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Rotate query and key features by the complex factors in ``freqs_cis``.

    The last dim of ``xq`` / ``xk`` is read (in float32) as interleaved
    (real, imag) pairs. Each pair is multiplied by ``freqs_cis`` (must be
    complex64; broadcast via ``_reshape_for_broadcast``), unpacked back to
    real, and cast to the input dtype. The ``flatten(3)`` assumes 4-D
    inputs ``(batch, seq, heads, head_dim)``.
    """

    def _as_complex(t: torch.Tensor) -> torch.Tensor:
        pairs = t.float().reshape(*t.shape[:-1], -1, 2)
        return torch.view_as_complex(pairs)

    q_complex = _as_complex(xq)
    k_complex = _as_complex(xk)
    assert freqs_cis.dtype == torch.complex64
    rotation = _reshape_for_broadcast(freqs_cis, q_complex)

    q_rotated = torch.view_as_real(q_complex * rotation).flatten(3)
    k_rotated = torch.view_as_real(k_complex * rotation).flatten(3)
    return q_rotated.type_as(xq), k_rotated.type_as(xk)


class FeedForward(nn.Module):
    """Gated SiLU MLP computing ``w2(silu(w1(x)) * w3(x))``, bias-free."""

    def __init__(self, args: VisionEncoderArgs):
        super().__init__()
        assert args.intermediate_size is not None
        model_dim = args.hidden_size
        inner_dim = args.intermediate_size
        # Registration order (w1, w2, w3) is kept stable for state_dict keys.
        self.w1 = nn.Linear(model_dim, inner_dim, bias=False)
        self.w2 = nn.Linear(inner_dim, model_dim, bias=False)
        self.w3 = nn.Linear(model_dim, inner_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = F.silu(self.w1(x))
        up = self.w3(x)
        return self.w2(gate * up)


class Attention(nn.Module):
    """Bias-free multi-head self-attention with 2D rotary embeddings.

    The attention kernel is xformers' ``memory_efficient_attention``; the
    module-level ``xops`` name only exists when xformers imported.
    """

    def __init__(self, args: VisionEncoderArgs):
        super().__init__()
        self.args = args
        # hidden_size must split evenly across the heads.
        assert not args.hidden_size % args.num_attention_heads
        self.n_heads = args.num_attention_heads
        self.head_dim = args.hidden_size // args.num_attention_heads

        self.wq = nn.Linear(args.hidden_size, args.hidden_size, bias=False)
        self.wk = nn.Linear(args.hidden_size, args.hidden_size, bias=False)
        self.wv = nn.Linear(args.hidden_size, args.hidden_size, bias=False)
        self.wo = nn.Linear(args.hidden_size, args.hidden_size, bias=False)

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor,
        freqs_cis: torch.Tensor,
    ) -> torch.Tensor:
        """Self-attend over ``x``.

        Args:
            x: Input of shape ``(batch, patches, hidden_size)``.
            mask: Passed unchanged as xformers' ``attn_bias``. Despite the
                annotation, the caller in this file passes an xformers
                ``BlockDiagonalMask`` so patches only attend within their
                own image.
            freqs_cis: complex64 rotary factors of shape
                ``(patches, head_dim // 2)``, applied to q and k.

        Returns:
            Tensor of shape ``(batch, patches, hidden_size)``.
        """
        batch, patches, _ = x.shape

        q, k, v = self.wq(x), self.wk(x), self.wv(x)
        q = q.reshape(batch, patches, self.n_heads, self.head_dim)
        k = k.reshape(batch, patches, self.n_heads, self.head_dim)
        v = v.reshape(batch, patches, self.n_heads, self.head_dim)

        q, k = apply_rotary_emb_vit(q, k, freqs_cis=freqs_cis)
        out = xops.memory_efficient_attention(q, k, v, attn_bias=mask)
        out = out.reshape(batch, patches, self.n_heads * self.head_dim)
        return self.wo(out)


class TransformerBlock(nn.Module):
    """Pre-norm transformer layer: RMSNorm -> attention -> residual, then
    RMSNorm -> feed-forward -> residual (both norms use eps=1e-5)."""

    def __init__(self, args: VisionEncoderArgs):
        super().__init__()
        self.attention = Attention(args)
        self.feed_forward = FeedForward(args)
        self.attention_norm = RMSNorm(args.hidden_size, eps=1e-5)
        self.ffn_norm = RMSNorm(args.hidden_size, eps=1e-5)

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor,
        freqs_cis: torch.Tensor,
    ) -> torch.Tensor:
        """Apply one layer; ``mask`` and ``freqs_cis`` go to attention."""
        r = self.attention.forward(self.attention_norm(x),
                                   mask=mask,
                                   freqs_cis=freqs_cis)
        h = x + r
        r = self.feed_forward.forward(self.ffn_norm(h))
        out = h + r
        return out


class Transformer(nn.Module):
    """Stack of ``args.num_hidden_layers`` TransformerBlocks applied in order."""

    def __init__(self, args: VisionEncoderArgs):
        super().__init__()
        self.layers = torch.nn.ModuleList()
        for _ in range(args.num_hidden_layers):
            self.layers.append(TransformerBlock(args))

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor,
        freqs_cis: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Run ``x`` through every layer with the same mask and rotary table."""
        for layer in self.layers:
            x = layer(x, mask=mask, freqs_cis=freqs_cis)
        return x


503
def position_meshgrid(patch_embeds_list: List[torch.Tensor]) -> torch.Tensor:
    """Return the (row, col) grid coordinates of every patch, image by image.

    Args:
        patch_embeds_list: Per-image tensors whose last two dims are the
            patch-grid height and width (for example ``patch_conv`` outputs
            of shape ``(1, D, H_p, W_p)``).

    Returns:
        int64 tensor of shape ``(sum_i H_i * W_i, 2)`` holding row-major
        ``(row, col)`` pairs, concatenated in the order of the input list.
    """
    positions = torch.cat([
        torch.stack(
            torch.meshgrid(
                torch.arange(p.shape[-2]),
                torch.arange(p.shape[-1]),
                indexing="ij",
            ),
            dim=-1,
        ).reshape(-1, 2) for p in patch_embeds_list
    ])
    return positions


class VisionTransformer(nn.Module):
    """Pixtral vision encoder for variable-size images.

    Each image is patchified by a strided convolution. All patches of all
    images are concatenated into one sequence, RMS-normalized, given 2D
    rotary positions from their (row, col) grid coordinates, and run
    through the transformer. A block-diagonal mask keeps images from
    attending to each other.
    """

    def __init__(self, args: VisionEncoderArgs):
        super().__init__()
        self.args = args
        self.patch_conv = nn.Conv2d(
            in_channels=args.num_channels,
            out_channels=args.hidden_size,
            kernel_size=args.patch_size,
            stride=args.patch_size,
            bias=False,
        )
        self.ln_pre = RMSNorm(args.hidden_size, eps=1e-5)
        self.transformer = Transformer(args)

        head_dim = self.args.hidden_size // self.args.num_attention_heads
        assert head_dim % 2 == 0, "ROPE requires even head_dim"
        # Lazily built rotary table; see the ``freqs_cis`` property.
        self._freqs_cis: Optional[torch.Tensor] = None

    @property
    def max_patches_per_side(self) -> int:
        return self.args.image_size // self.args.patch_size

    @property
    def device(self) -> torch.device:
        return next(self.parameters()).device

    @property
    def dtype(self) -> torch.dtype:
        return next(self.parameters()).dtype

    @property
    def freqs_cis(self) -> torch.Tensor:
        """Rotary table for the largest patch grid, cached on this module.

        Built on first access and moved to the parameters' device whenever
        that device differs from the cached tensor's.
        """
        if self._freqs_cis is None:
            self._freqs_cis = precompute_freqs_cis_2d(
                dim=self.args.hidden_size // self.args.num_attention_heads,
                height=self.max_patches_per_side,
                width=self.max_patches_per_side,
                theta=self.args.rope_theta,
            )

        if self._freqs_cis.device != self.device:
            self._freqs_cis = self._freqs_cis.to(device=self.device)

        return self._freqs_cis

    def forward(
        self,
        images: List[torch.Tensor],
    ) -> torch.Tensor:
        """
        Args:
            images: list of N_img images of variable sizes,
                each of shape (C, H, W)
        Returns:
            image_features: tensor of token features for
                all tokens of all images of shape (N_toks, D)
        Raises:
            ImportError: if xformers is not installed.
        """
        # pass images through initial convolution independently
        patch_embeds_list = [
            self.patch_conv(img.unsqueeze(0).to(self.dtype)) for img in images
        ]

        # flatten to a single sequence
        patch_embeds = torch.cat(
            [p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list], dim=1)
        patch_embeds = self.ln_pre(patch_embeds)

        # positional embeddings
        positions = position_meshgrid(patch_embeds_list).to(self.device)
        freqs_cis = self.freqs_cis[positions[:, 0], positions[:, 1]]

        # pass through Transformer with a block diagonal mask delimiting images
        if USE_XFORMERS_OPS:
            mask = xops.fmha.attn_bias.BlockDiagonalMask.from_seqlens(
                [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], )
        else:
            raise ImportError("Xformers is required for Pixtral inference "
                              "with the Mistral format")
        out = self.transformer(patch_embeds, mask=mask, freqs_cis=freqs_cis)

        # remove batch dimension of the single sequence
        return out.squeeze(0)


class VisionLanguageAdapter(nn.Module):
    """Two-layer GELU MLP projecting vision features to the text width.

    ``w_in`` maps ``args.hidden_size -> dim`` and ``w_out`` maps
    ``dim -> dim``. Both use a bias iff ``args.adapter_bias``.
    """

    def __init__(self, args: VisionEncoderArgs, dim: int):
        super().__init__()
        assert isinstance(args, VisionEncoderArgs)
        self.w_in = nn.Linear(
            args.hidden_size,
            dim,
            bias=args.adapter_bias,
        )
        self.gelu = nn.GELU()
        self.w_out = nn.Linear(dim, dim, bias=args.adapter_bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w_out(self.gelu(self.w_in(x)))
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650


#### HF Transformers version of Pixtral ####
# Based on https://github.com/huggingface/transformers/blob/d7950bff82b18c823193d17d72188c5e46d06c83/src/transformers/models/pixtral/modeling_pixtral.py
# This model follows the LLaVA family: image embeddings are substituted for
# the `[IMG]` token placeholders.
# The model uses [`PixtralVisionModel`] for its vision encoder,
# and [`MistralForCausalLM`] for its language decoder.


def get_pixtral_hf_patch_grid_length(*, image_size: int,
                                     patch_size: int) -> int:
    """Number of patches along one side of an ``image_size`` square image.

    Floor division is used instead of asserting divisibility: because
    interpolation is applied, ``image_size`` need not be a multiple of
    ``patch_size``.
    """
    patches_per_side, _remainder = divmod(image_size, patch_size)
    return patches_per_side


def get_pixtral_hf_num_patches(*, image_size: int, patch_size: int) -> int:
    """Total patches in a square image: the per-side grid length squared."""
    side = get_pixtral_hf_patch_grid_length(image_size=image_size,
                                            patch_size=patch_size)
    total = side * side
    return total


def get_max_pixtral_hf_image_feature_size(
        hf_config: PixtralVisionConfig) -> int:
    """Patch count of a full-size (``image_size`` x ``image_size``) image."""
    full_side = hf_config.image_size
    patch_side = hf_config.patch_size
    return get_pixtral_hf_num_patches(image_size=full_side,
                                      patch_size=patch_side)


def get_max_pixtral_hf_image_tokens(hf_config: PixtralVisionConfig) -> int:
    """Maximum image tokens per image; identical to the max feature size."""
    max_tokens = get_max_pixtral_hf_image_feature_size(hf_config)
    return max_tokens


def dummy_seq_data_for_pixtral_hf(
        hf_config: PixtralVisionConfig,
        seq_len: int,
        num_images: int,
        *,
        image_token_id: int,
        image_feature_size_override: Optional[int] = None,
        mm_key: str = "image"):
    """Build a dummy prompt of image placeholder tokens and their ranges.

    Args:
        hf_config: Vision config giving the default per-image feature size.
        seq_len: Total dummy prompt length.
        num_images: Number of dummy images.
        image_token_id: Token id placed at every image placeholder position.
        image_feature_size_override: Per-image placeholder count to use
            instead of the config-derived maximum.
        mm_key: Modality key of the returned placeholder mapping.

    Returns:
        ``(SequenceData, {mm_key: placeholder ranges})``. The prompt starts
        with ``image_feature_size * num_images`` image tokens followed by
        token id 0 padding up to ``seq_len``.
    """
    if image_feature_size_override is None:
        image_feature_size = get_max_pixtral_hf_image_feature_size(hf_config)
    else:
        image_feature_size = image_feature_size_override

    num_image_tokens = image_feature_size * num_images
    seq_data = SequenceData.from_prompt_token_counts(
        (image_token_id, num_image_tokens),
        (0, seq_len - num_image_tokens),
    )
    placeholders = {
        mm_key:
        consecutive_placeholder_ranges(num_items=num_images,
                                       item_size=image_feature_size)
    }
    return seq_data, placeholders
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735


def dummy_image_for_pixtral_hf(
    hf_config: PixtralVisionConfig,
    num_images: int,
    *,
    image_width_override: Optional[int] = None,
    image_height_override: Optional[int] = None,
):
    """Create blank (black) RGB dummy image(s) for profiling.

    The image defaults to ``hf_config.image_size`` on both sides; either
    side can be overridden. Returns ``{"image": img}`` when ``num_images``
    is 1, otherwise ``{"image": [img] * num_images}`` (the same object
    repeated).
    """
    default_side = hf_config.image_size
    width = (default_side
             if image_width_override is None else image_width_override)
    height = (default_side
              if image_height_override is None else image_height_override)

    blank = Image.new("RGB", (width, height), color=0)
    payload = blank if num_images == 1 else [blank] * num_images
    return {"image": payload}


def get_pixtral_hf_image_feature_size(hf_config: PixtralVisionConfig,
                                      image_width: int,
                                      image_height: int) -> Tuple[int, int]:
    """Return ``(num_width_tokens, num_height_tokens)`` for an image.

    Images larger than ``hf_config.image_size`` on either side are first
    scaled down, preserving aspect ratio and rounding up. The patch-token
    grid is then taken from transformers' ``_num_image_tokens``.
    """
    # Adapted from transformers.models.pixtral.image_processing_pixtral.get_resize_output_image_size # noqa: E501
    # https://github.com/huggingface/transformers/blob/2bd4d5897dc73e8b172832070a6f9e567a0df017/src/transformers/models/pixtral/image_processing_pixtral.py#L180 # noqa: E501
    limit = hf_config.image_size
    patch = hf_config.patch_size

    scale = max(image_width / limit, image_height / limit)
    if scale > 1:
        image_width = int(numpy.ceil(image_width / scale))
        image_height = int(numpy.ceil(image_height / scale))

    rows, cols = _num_image_tokens((image_height, image_width),
                                   (patch, patch))
    return cols, rows


def _get_pixtral_hf_replace_tokens(hf_config: PixtralVisionConfig,
                                   image: Image.Image, image_token,
                                   image_break_token, image_end_token):
    """Build the token sequence that stands in for one image.

    The sequence has ``num_height_tokens`` rows. Each row is
    ``num_width_tokens`` copies of ``image_token`` followed by
    ``image_break_token``. The last row's break token is then replaced by
    ``image_end_token``. The token arguments may be strings (text prompt)
    or integer ids (token ids); they are copied as-is.
    """
    width, height = image.size
    num_width_tokens, num_height_tokens = get_pixtral_hf_image_feature_size(
        hf_config, image_width=width, image_height=height)
    row = [image_token] * num_width_tokens + [image_break_token]
    replace_tokens = row * num_height_tokens
    replace_tokens[-1] = image_end_token
    return replace_tokens


def input_processor_for_pixtral_hf(
    model_config: ModelConfig,
    hf_config: PixtralVisionConfig,
    inputs: DecoderOnlyInputs,
    *,
    image_token_id: int,
    image_feature_size_override: Optional[Union[int, List[int]]] = None,
) -> DecoderOnlyInputs:
    """Expand each image placeholder in ``inputs`` into Pixtral image tokens.

    Every image token in the prompt text and in ``prompt_token_ids`` is
    replaced by the sequence from ``_get_pixtral_hf_replace_tokens`` for
    the corresponding image. The location of each expanded sequence in the
    new token ids is recorded as a ``PlaceholderRange``.

    Args:
        model_config: Used to fetch the HF processor (``model_config.model``).
        hf_config: Vision config used to size each image's token grid.
        inputs: Decoder-only inputs. They are returned unchanged when they
            carry no ``"image"`` multi-modal data, and are never mutated.
        image_token_id: Not used; the image token id is looked up from the
            processor's tokenizer instead. Kept for interface compatibility.
        image_feature_size_override: Must be ``None``.

    Returns:
        New token inputs with the expanded prompt, the expanded token ids,
        the original multi-modal data and one placeholder range per image.

    Raises:
        TypeError: If the image data is neither an image nor a list of images.
        AssertionError: If the number of image tokens in the prompt text or
            token ids differs from the number of images.
    """
    assert image_feature_size_override is None, (
        "image_feature_size_override is not supported for Pixtral")

    multi_modal_data = inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return inputs

    processor = cached_get_processor(model_config.model)

    image_data = multi_modal_data["image"]
    if isinstance(image_data, Image.Image):
        image_data = [image_data]
    elif not is_list_of(image_data, Image.Image):
        raise TypeError(f"Invalid image type: {type(image_data)}")

    image_token = processor.image_token
    image_break_token = processor.image_break_token
    image_end_token = processor.image_end_token

    # Update the text prompt, if present.
    new_prompt = inputs.get("prompt")
    if new_prompt:
        parts = new_prompt.split(image_token)
        assert len(parts) - 1 == len(image_data)
        new_parts = [parts[0]]  # Text before the first image token
        for image, next_part in zip(image_data, parts[1:]):
            replace_tokens = _get_pixtral_hf_replace_tokens(
                hf_config, image, image_token, image_break_token,
                image_end_token)
            new_parts.append("".join(replace_tokens))
            new_parts.append(next_part)
        new_prompt = "".join(new_parts)

    # Update the token ids. The image token id is taken from the
    # processor's tokenizer (overriding the ``image_token_id`` argument), so
    # it matches the ``image_token`` string used for the prompt above.
    convert_tokens_to_ids = processor.tokenizer.convert_tokens_to_ids
    image_token_id = convert_tokens_to_ids(image_token)
    image_break_id = convert_tokens_to_ids(image_break_token)
    image_end_id = convert_tokens_to_ids(image_end_token)

    prompt_token_ids = inputs["prompt_token_ids"]
    num_placeholders = sum(1 for token_id in prompt_token_ids
                           if token_id == image_token_id)
    assert num_placeholders == len(image_data)

    replace_tokens_list = [
        _get_pixtral_hf_replace_tokens(hf_config, image, image_token_id,
                                       image_break_id, image_end_id)
        for image in image_data
    ]

    # Expand in a single forward pass into a *new* list, so the caller's
    # ``prompt_token_ids`` is left untouched. Each range's offset is the
    # output length at the moment its replacement is appended.
    new_token_ids: List[int] = []
    placeholder_ranges: List[PlaceholderRange] = []
    remaining_replacements = iter(replace_tokens_list)
    for token_id in prompt_token_ids:
        if token_id != image_token_id:
            new_token_ids.append(token_id)
            continue
        replace_tokens = next(remaining_replacements)
        placeholder_ranges.append(
            PlaceholderRange(offset=len(new_token_ids),
                             length=len(replace_tokens)))
        new_token_ids.extend(replace_tokens)

    return token_inputs(prompt_token_ids=new_token_ids,
                        prompt=new_prompt,
                        multi_modal_data=multi_modal_data,
                        multi_modal_placeholders={"image": placeholder_ranges})


class PixtralHFMLP(nn.Module):
    """Gated feed-forward block of the Pixtral-HF vision encoder.

    ``gate_up_proj`` computes both projections in one fused, column-parallel
    linear layer. ``act_and_mul`` comes from ``get_act_and_mul_fn`` for
    ``config.hidden_act`` (presumably act(gate) * up — see that helper).
    ``down_proj`` maps the result back to ``hidden_size``.
    """

    def __init__(
        self,
        config: PixtralVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        *,
        prefix: str = "",
    ) -> None:
        super().__init__()

        assert config.intermediate_size is not None

        hidden = config.hidden_size
        inner = config.intermediate_size
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=hidden,
            output_sizes=[inner, inner],
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            input_size=inner,
            output_size=hidden,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
        )
        self.act_and_mul = get_act_and_mul_fn(config.hidden_act)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # vLLM parallel linears return (output, bias); the bias is unused.
        fused, _ = self.gate_up_proj(x)
        out, _ = self.down_proj(self.act_and_mul(fused))
        return out


class PixtralHFAttention(nn.Module):
    """Multi-head self-attention for the Pixtral-HF vision encoder.

    The input/output projections are tensor-parallel: ``QKVParallelLinear``
    in and ``RowParallelLinear`` out. Each rank holds ``n_heads`` of the
    ``total_num_heads`` heads. Rotary embeddings are applied with HF's
    ``apply_rotary_pos_emb``. Attention runs through xformers when
    ``USE_XFORMERS_OPS`` is set, and through PyTorch SDPA otherwise.
    """

    def __init__(
        self,
        config: PixtralVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        *,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config
        assert not config.hidden_size % config.num_attention_heads
        self.total_num_heads = config.num_attention_heads
        tp_size = get_tensor_model_parallel_world_size()
        # Number of heads handled by this tensor-parallel rank.
        self.n_heads = divide(config.num_attention_heads, tp_size)
        self.head_dim = config.hidden_size // config.num_attention_heads

        self.qkv_proj = QKVParallelLinear(
            hidden_size=config.hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        assert self.total_num_heads * self.head_dim == config.hidden_size
        self.o_proj = RowParallelLinear(
            input_size=config.hidden_size,
            output_size=config.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        position_embeddings: torch.Tensor,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Apply rotary self-attention to ``hidden_states``.

        Args:
            hidden_states: Tensor of shape ``(batch, patches, hidden)``.
            attention_mask: The xformers ``attn_bias`` when
                ``USE_XFORMERS_OPS`` is set, otherwise the SDPA ``attn_mask``.
            position_embeddings: A ``(cos, sin)`` pair for the rotary
                embedding.

        Returns:
            ``(attn_output, None)``; the second element is always ``None``.
        """
        batch, patches, _ = hidden_states.size()

        qkv_states, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv_states.chunk(3, dim=-1)

        # Transpose q and k to apply HF's Rotary Position Embedding
        q = q.view(batch, patches, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch, patches, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch, patches, self.n_heads, self.head_dim)
        cos, sin = position_embeddings
        q, k = apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=0)

        if USE_XFORMERS_OPS:
            # Transpose q and k back for attention
            q = q.transpose(1, 2).contiguous()
            k = k.transpose(1, 2).contiguous()

            out = xops.memory_efficient_attention(q,
                                                  k,
                                                  v,
                                                  attn_bias=attention_mask)
        else:
            v = v.transpose(1, 2)
            out = nn.functional.scaled_dot_product_attention(
                q, k, v, attn_mask=attention_mask)
            out = out.transpose(1, 2)

        # ``out`` is (batch, patches, n_heads, head_dim). On the SDPA path
        # it is a transposed, possibly non-contiguous tensor, which
        # ``view`` cannot flatten; ``reshape`` copies only when needed.
        out = out.reshape(batch, patches, self.n_heads * self.head_dim)
        attn_output, _ = self.o_proj(out)

        return attn_output, None


class PixtralHFTransformerBlock(nn.Module):
    """One Pixtral-HF encoder layer.

    It runs pre-norm attention and then a pre-norm MLP. Each sub-layer is
    preceded by an RMSNorm (eps=1e-5) and wrapped in a residual add.
    """

    def __init__(
        self,
        config: PixtralVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        *,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.attention_norm = RMSNorm(config.hidden_size, eps=1e-5)
        self.attention = PixtralHFAttention(config,
                                            quant_config=quant_config,
                                            prefix=f"{prefix}.attention")
        self.feed_forward = PixtralHFMLP(config,
                                         quant_config=quant_config,
                                         prefix=f"{prefix}.feed_forward")
        self.ffn_norm = RMSNorm(config.hidden_size, eps=1e-5)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        position_embeddings: torch.Tensor,
    ) -> torch.Tensor:
        """Apply the layer.

        ``attention_mask`` and ``position_embeddings`` are forwarded
        unchanged to ``PixtralHFAttention``.
        """
        # Call the submodules rather than ``.forward`` so any registered
        # nn.Module hooks are honoured.
        attn_out, _ = self.attention(self.attention_norm(hidden_states),
                                     attention_mask=attention_mask,
                                     position_embeddings=position_embeddings)
        h = hidden_states + attn_out
        out = h + self.feed_forward(self.ffn_norm(h))
        return out


class PixtralHFTransformer(nn.Module):
    """A stack of ``PixtralHFTransformerBlock`` layers.

    ``num_hidden_layers_override``, when given, replaces
    ``config.num_hidden_layers`` as the number of layers built.
    """

    def __init__(
        self,
        config: PixtralVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        *,
        num_hidden_layers_override: Optional[int] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        if num_hidden_layers_override is None:
            num_hidden_layers = config.num_hidden_layers
        else:
            num_hidden_layers = num_hidden_layers_override

        self.layers = nn.ModuleList([
            PixtralHFTransformerBlock(config=config,
                                      quant_config=quant_config,
                                      prefix=f"{prefix}.layers.{layer_idx}")
            for layer_idx in range(num_hidden_layers)
        ])

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: torch.Tensor,
        position_embeddings: torch.Tensor,
        return_all_hidden_states: bool,
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        """Run ``x`` through every layer in order.

        Returns:
            If ``return_all_hidden_states`` is true, a list whose i-th entry
            is the output of layer i (the input ``x`` is not included).
            Otherwise, the output of the last layer.
        """
        hidden_states_pool: List[torch.Tensor] = []
        for layer in self.layers:
            x = layer(x, attention_mask, position_embeddings)
            if return_all_hidden_states:
                hidden_states_pool.append(x)

        # With multiple feature sample layers, callers pick the hidden
        # states they need from this list by index.
        if return_all_hidden_states:
            return hidden_states_pool
        return x


class PixtralHFVisionModel(nn.Module):
    """Pixtral vision encoder for HF-format checkpoints, built on vLLM layers.

    Each image is patchified on its own by ``patch_conv``. The patches of
    all images are concatenated into one sequence, normalised by
    ``ln_pre``, and run through ``PixtralHFTransformer``. The attention
    mask is a per-image block mask: xformers ``BlockDiagonalMask``, or HF's
    ``generate_block_attention_mask`` when xformers is not used.
    """

    def __init__(
        self,
        config: PixtralVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        *,
        num_hidden_layers_override: Optional[int] = None,
        require_post_norm: Optional[bool] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config

        self.patch_conv = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=config.hidden_size,
            kernel_size=config.patch_size,
            stride=config.patch_size,
            bias=False,
        )
        self.ln_pre = RMSNorm(config.hidden_size, eps=1e-5)
        self.transformer = PixtralHFTransformer(
            config,
            quant_config,
            num_hidden_layers_override=num_hidden_layers_override,
            prefix=f"{prefix}.transformer",
        )

        # An override may truncate the encoder but never extend it.
        max_layers = config.num_hidden_layers
        num_built = len(self.transformer.layers)
        if num_built > max_layers:
            raise ValueError(
                f"The original encoder only has {max_layers} "
                f"layers, but you requested {num_built} "
                "layers.")

        # This architecture has no post-layernorm to honour the request.
        if require_post_norm is True:
            raise ValueError(
                "PixtralHFVisionModel does not have post-layernorm")

        # dtype/device are captured once, from the first parameter.
        first_param = next(self.parameters())
        self.dtype = first_param.dtype
        self.device = first_param.device
        self.patch_positional_embedding = PixtralRotaryEmbedding(
            config, self.device)

    def forward(
        self,
        pixel_values: List[torch.Tensor],
        feature_sample_layers: Optional[list[int]] = None,
    ) -> torch.Tensor:
        """
        Args:
            pixel_values: Each image to be processed will be a separate tensor
                in pixel_values. This means it will be a list of tensors
                because multiple requests batched can have multiple images,
                each with their own shape potentially
            feature_sample_layers: Layer indices whose features should be
                concatenated and used as the visual encoder output. If none
                are provided, the last layer is used.

        Returns:
            image_features: tensor of token features for
                all tokens of all images of shape (N_toks, D)
        """
        # Images may differ in size, so each one is convolved separately
        # (presumably each ``img`` is (C, H, W); a batch dim is added here).
        per_image_patches = [
            self.patch_conv(img.unsqueeze(0).to(self.dtype))
            for img in pixel_values
        ]
        # Patch count per image: rows * cols of its conv output grid.
        seq_lens = [p.shape[-2] * p.shape[-1] for p in per_image_patches]

        # (1, hidden, h, w) -> (1, h*w, hidden), then join along dim 1.
        patch_embeds = torch.cat(
            [p.flatten(2).permute(0, 2, 1) for p in per_image_patches], dim=1)
        patch_embeds = self.ln_pre(patch_embeds)

        max_width = self.config.image_size // self.config.patch_size
        position_ids = position_ids_in_meshgrid(
            per_image_patches, max_width=max_width).to(self.device)
        position_embedding = self.patch_positional_embedding(
            patch_embeds, position_ids)

        if USE_XFORMERS_OPS:
            block_mask_cls = xops.fmha.attn_bias.BlockDiagonalMask
            attention_mask = block_mask_cls.from_seqlens(seq_lens)
        else:
            from transformers.models.pixtral.modeling_pixtral import (
                generate_block_attention_mask)
            attention_mask = generate_block_attention_mask(
                seq_lens, patch_embeds)

        out = self.transformer(
            patch_embeds,
            attention_mask,
            position_embedding,
            return_all_hidden_states=feature_sample_layers is not None)

        return resolve_visual_encoder_outputs(out, feature_sample_layers, None,
                                              self.config.num_hidden_layers)

    # (TODO) Add prefix argument for filtering out weights to be loaded
    #        ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986
    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load HF-format weights, fusing q/k/v and gate/up shards.

        Weights of transformer layers beyond those actually built (see
        ``num_hidden_layers_override``) are skipped.

        Returns:
            The (post-fusion) names of the parameters that were loaded.
        """
        # (fused param name, checkpoint shard name, shard id)
        stacked_params_mapping = [
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        num_layers = len(self.transformer.layers)
        loaded_params: Set[str] = set()

        for name, loaded_weight in weights:
            # Drop layers that were omitted via num_hidden_layers_override.
            if (name.startswith("transformer.layers")
                    and int(name.split(".")[2]) >= num_layers):
                continue

            # First mapping entry whose shard name occurs in ``name``.
            shard = next(
                (m for m in stacked_params_mapping if m[1] in name), None)
            if shard is None:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            else:
                fused_name, shard_name, shard_id = shard
                name = name.replace(shard_name, fused_name)
                param = params_dict[name]
                param.weight_loader(param, loaded_weight, shard_id)

            loaded_params.add(name)
        return loaded_params