"vscode:/vscode.git/clone" did not exist on "322a0be6ba963f8f376f00a5daa55848de0ea596"
molmo.py 50.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import math
5
from collections.abc import Iterable, Mapping, Sequence
6
from dataclasses import dataclass
7
from functools import cached_property, partial
8
from itertools import islice
9
from typing import Annotated
10

11
import numpy as np
12
import torch
13
14
import torch.nn as nn
import torch.nn.functional as F
15
from einops import rearrange
16
from transformers import BatchFeature, PretrainedConfig, ProcessorMixin, TensorType
17
18
from transformers.image_utils import ImageInput
from transformers.tokenization_utils_base import TextInput
19

20
from vllm.attention.layer import Attention, MultiHeadAttention
21
from vllm.compilation.decorators import support_torch_compile
22
from vllm.config import CacheConfig, VllmConfig
23
from vllm.config.multimodal import BaseDummyOptions
24
25
26
27
28
29
30
31
from vllm.distributed import (
    get_pp_group,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    split_tensor_along_last_dim,
    tensor_model_parallel_all_gather,
)
from vllm.model_executor.layers.activation import MulAndSilu, QuickGELU, SiluAndMul
32
from vllm.model_executor.layers.layernorm import RMSNorm
33
34
35
36
37
38
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    MergedColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
39
from vllm.model_executor.layers.logits_processor import LogitsProcessor
40
from vllm.model_executor.layers.quantization import QuantizationConfig
41
42
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
43
44
45
    ParallelLMHead,
    VocabParallelEmbedding,
)
46
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
47
from vllm.model_executor.models.module_mapping import MultiModelKeys
48
from vllm.multimodal import MULTIMODAL_REGISTRY
49
50
51
52
53
54
55
56
57
58
59
60
61
62
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
)
from vllm.multimodal.parse import ImageProcessorItems, ImageSize, MultiModalDataItems
from vllm.multimodal.processing import (
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptIndexTargets,
    PromptInsertion,
    PromptUpdate,
    PromptUpdateDetails,
)
63
from vllm.multimodal.profiling import BaseDummyInputsBuilder
64
from vllm.sequence import IntermediateTensors
65
from vllm.utils.tensor_schema import TensorSchema, TensorShape
66

67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
from .interfaces import (
    MultiModalEmbeddings,
    SupportsLoRA,
    SupportsMultiModal,
    SupportsPP,
    SupportsQuant,
)
from .utils import (
    AutoWeightsLoader,
    WeightsMapper,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
82
83
84
85
86

# TODO: hard-coded for now. Consider making it configurable.
# Indices (counted from the end) of the ViT blocks whose hidden states are
# concatenated to form the per-patch image features.
VIT_LAYERS = [-2, -9]
# Number of prefix (class) token positions the ViT prepends to each crop.
NUM_PREFIX_TOKENS = 1
# Extra embedding rows appended to the text vocabulary of the LLM.
ADDITIONAL_VOCAB_SIZE = 128

# Special image-token strings.
IMAGE_PATCH_TOKEN = "<im_patch>"
IM_COL_TOKEN = "<im_col>"
IM_START_TOKEN = "<im_start>"
IM_END_TOKEN = "<im_end>"
# Side length of the square window used to pool ViT patch features.
POOLING_SIZE = 2
92
93


94
class MolmoImageInputs(TensorSchema):
    """
    Shape-checked image inputs for Molmo.

    Dimensions:
        - bn: Batch size * number of images
        - bnc: Batch size * number of images * number of crops (dynamic)
        - np: Number of patches
        - tp: Token sequence positions
        - pd: Patch dimension
    """

    # Flattened patch pixels for every crop.
    images: Annotated[torch.Tensor, TensorShape("bnc", "np", "pd")]

    # Optional per-patch mask for every crop.
    image_masks: Annotated[torch.Tensor | None, TensorShape("bnc", "np")]

    image_input_idx: Annotated[torch.Tensor, TensorShape("bnc", "tp")]
    """An index tensor that maps image features to their corresponding patch tokens."""

    # Number of crops belonging to each image.
    num_crops: Annotated[torch.Tensor, TensorShape("bn")]
112

113
114
115

@dataclass
class VisionBackboneConfig:
    """Hyper-parameters of the Molmo ViT image encoder.

    The defaults describe a 336x336 input split into 14x14 patches, i.e. a
    24x24 patch grid; ``image_num_pos`` = 24 * 24 + 1 adds one position for
    the class token.
    """

    image_default_input_size: tuple[int, int] = (336, 336)
    image_patch_size: int = 14
    image_pos_patch_size: int = 14
    image_emb_dim: int = 1024
    image_num_heads: int = 16
    image_num_key_value_heads: int = 16
    image_num_layers: int = 23
    image_mlp_dim: int = 4096
    image_mlp_activations: str = "quick_gelu"
    image_num_pos: int = 577
    image_norm_eps: float = 1e-5

    def __post_init__(self):
        # Callers may pass any sequence (e.g. a list from a deserialized
        # config); normalize it to the annotated tuple type.
        self.image_default_input_size = tuple(self.image_default_input_size)  # type: ignore[assignment]

    @property
    def image_num_patch(self) -> tuple[int, int]:
        """Patch-grid size (rows, cols) for the default input size."""
        h, w = self.image_default_input_size
        return h // self.image_patch_size, w // self.image_patch_size


class ViTMLP(nn.Module):
    """MLP used in Vision Transformer.

    Two projections (emb_dim -> mlp_dim -> emb_dim) around a QuickGELU
    activation. The first projection is column-parallel and the second is
    row-parallel across tensor-parallel ranks.
    """

    def __init__(
        self,
        config: VisionBackboneConfig,
        quant_config: QuantizationConfig | None = None,
    ):
        super().__init__()
        self.w1 = ColumnParallelLinear(
            config.image_emb_dim,
            config.image_mlp_dim,
            bias=True,
            quant_config=quant_config,
        )
        # Activation function. Only QuickGELU is implemented.
        assert config.image_mlp_activations == "quick_gelu"
        self.act = QuickGELU()
        self.w2 = RowParallelLinear(
            config.image_mlp_dim,
            config.image_emb_dim,
            bias=True,
            quant_config=quant_config,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The parallel linear layers return a pair; only the first element
        # (the projected tensor) is used.
        x, _ = self.w1(x)
        x = self.act(x)
        x, _ = self.w2(x)
        return x


class MultiHeadDotProductAttention(nn.Module):
    """Multi-head attention used in Vision Transformer.

    It is also used for 2D attention pooling in the vision backbone. There,
    queries and keys/values come from different tensors, and the input width
    is ``nlayers`` times the embedding width because features from several
    ViT layers are concatenated.
    """

    def __init__(
        self,
        config: VisionBackboneConfig,
        use_bias: bool = True,
        nlayers: int = 1,
        quant_config: QuantizationConfig | None = None,
    ):
        super().__init__()

        self.hidden_size = config.image_emb_dim
        self.total_num_heads = config.image_num_heads
        tp_size = get_tensor_model_parallel_world_size()

        assert self.hidden_size % self.total_num_heads == 0
        assert self.total_num_heads % tp_size == 0

        self.num_heads = self.total_num_heads // tp_size
        self.head_dim = self.hidden_size // self.total_num_heads

        self.total_num_kv_heads = config.image_num_key_value_heads
        # KV heads must either split evenly across ranks or evenly divide the
        # number of ranks; every rank keeps at least one KV head.
        if self.total_num_kv_heads >= tp_size:
            assert self.total_num_kv_heads % tp_size == 0
        else:
            assert tp_size % self.total_num_kv_heads == 0

        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)

        # Input width scales with nlayers (concatenated multi-layer features).
        input_dim = nlayers * self.hidden_size
        self.wq = ColumnParallelLinear(
            input_dim,
            self.total_num_heads * self.head_dim,
            bias=use_bias,
            quant_config=quant_config,
        )
        self.wk = ColumnParallelLinear(
            input_dim,
            self.total_num_kv_heads * self.head_dim,
            bias=use_bias,
            quant_config=quant_config,
        )
        self.wv = ColumnParallelLinear(
            input_dim,
            self.total_num_kv_heads * self.head_dim,
            bias=use_bias,
            quant_config=quant_config,
        )
        self.wo = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            self.hidden_size,
            bias=use_bias,
            quant_config=quant_config,
        )

        self.scale = self.head_dim**-0.5
        self.attn = MultiHeadAttention(
            self.num_heads, self.head_dim, self.scale, num_kv_heads=self.num_kv_heads
        )

    def forward(
        self, inputs_q: torch.Tensor, inputs_kv: torch.Tensor | None = None
    ) -> torch.Tensor:
        """Attend from ``inputs_q`` to ``inputs_kv``.

        When ``inputs_kv`` is None this is self-attention over ``inputs_q``.
        """
        if inputs_kv is None:
            inputs_kv = inputs_q

        xq, _ = self.wq(inputs_q)
        xk, _ = self.wk(inputs_kv)
        xv, _ = self.wv(inputs_kv)

        output = self.attn(xq, xk, xv)
        output, _ = self.wo(output)
        return output


class ResidualAttentionBlock(nn.Module):
    """Residual attention block used in Vision Transformer.

    Pre-LayerNorm layout: ``x + attention(norm(x))`` followed by
    ``x + feed_forward(norm(x))``.
    """

    def __init__(
        self,
        config: VisionBackboneConfig,
        quant_config: QuantizationConfig | None = None,
    ):
        super().__init__()
        self.attention = MultiHeadDotProductAttention(config, quant_config=quant_config)
        self.feed_forward = ViTMLP(config, quant_config)
        self.attention_norm = nn.LayerNorm(
            config.image_emb_dim,
            eps=config.image_norm_eps,
        )
        self.ffn_norm = nn.LayerNorm(
            config.image_emb_dim,
            eps=config.image_norm_eps,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attention(self.attention_norm(x))
        x = x + self.feed_forward(self.ffn_norm(x))
        return x


class BlockCollection(nn.Module):
    """Collection of residual attention blocks used in Vision Transformer."""

    def __init__(
        self,
        config: VisionBackboneConfig,
        quant_config: QuantizationConfig | None = None,
    ):
        super().__init__()
        self.resblocks = nn.ModuleList(
            [
                ResidualAttentionBlock(config, quant_config)
                for _ in range(config.image_num_layers)
            ]
        )

    def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
        """Run every block and return each block's output, in order.

        All intermediate states are kept so that callers can select
        specific layers (e.g. by negative index).
        """
        hidden_states = []
        for block in self.resblocks:
            x = block(x)
            hidden_states.append(x)
        return hidden_states


def _expand_token(token: torch.Tensor, batch_size: int) -> torch.Tensor:
    """Broadcast ``token`` to shape (batch_size, 1, token.numel()).

    The result is an ``expand`` view, so no data is copied.
    """
    flat = token.view(1, 1, -1)
    return flat.expand(batch_size, -1, -1)


class VisionTransformer(nn.Module):
    """Vision Transformer used in Vision Backbone.

    The forward pass:
    1. projects flattened patches linearly;
    2. prepends a learned class token;
    3. adds positional embeddings, interpolated when the patch grid differs;
    4. applies a pre-LayerNorm;
    5. runs the residual blocks and returns every block's hidden state.
    """

    def __init__(
        self,
        config: VisionBackboneConfig,
        quant_config: QuantizationConfig | None = None,
    ):
        super().__init__()
        scale = config.image_emb_dim**-0.5
        self.patch_num = config.image_num_patch
        self.class_embedding = nn.Parameter(torch.randn(config.image_emb_dim) * scale)
        self.num_prefix_tokens: int = NUM_PREFIX_TOKENS
        # Row 0 is the class-token position; the remaining rows form a square
        # grid of patch positions (see add_pos_emb).
        self.positional_embedding = nn.Parameter(
            torch.randn(config.image_num_pos, config.image_emb_dim) * scale
        )

        image_patch_size = config.image_patch_size
        # Each patch arrives flattened as patch_size * patch_size * 3 values.
        self.patch_embedding = nn.Linear(
            image_patch_size * image_patch_size * 3,
            config.image_emb_dim,
            bias=False,
        )
        self.pre_ln = nn.LayerNorm(config.image_emb_dim, eps=config.image_norm_eps)
        self.transformer = BlockCollection(config, quant_config)

    def add_pos_emb(
        self, x: torch.Tensor, patch_num: tuple[int, int]
    ) -> torch.Tensor:
        """Add positional embeddings to ``x`` (class token first, then patches).

        The stored patch embeddings form a square grid. When ``patch_num``
        (rows, cols) differs from that grid, they are resized with bicubic
        interpolation.
        """
        cls_emb = self.positional_embedding[0:1]
        pos_emb = self.positional_embedding[1:]

        grid = math.isqrt(pos_emb.shape[0])
        pos_emb = pos_emb.reshape((grid, grid, pos_emb.shape[1]))

        patch_num_0, patch_num_1 = patch_num

        if pos_emb.shape[0] != patch_num_0 or pos_emb.shape[1] != patch_num_1:
            # from https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
            pos_emb = pos_emb.unsqueeze(0).permute(0, 3, 1, 2)
            pos_emb = F.interpolate(
                pos_emb,
                size=(patch_num_0, patch_num_1),
                mode="bicubic",
                align_corners=False,
                antialias=True,
            )
            pos_emb = pos_emb.permute(0, 2, 3, 1).squeeze(0)

        pos_emb = pos_emb.reshape(-1, pos_emb.shape[-1])
        x = x + torch.cat([cls_emb[None, :, :], pos_emb[None, :, :]], dim=1).to(x.dtype)
        return x

    def forward(
        self, x: torch.Tensor, patch_num: tuple[int, int] | None = None
    ) -> list[torch.Tensor]:
        """Encode flattened image patches.

        Args:
            x: (batch_size, num_patch, n_pixels) flattened patches.
            patch_num: (rows, cols) patch grid of ``x``. Defaults to the grid
                implied by the config's default input size.

        Returns:
            The hidden state of every residual block; each includes the
            class-token position.
        """
        if patch_num is None:
            patch_num = self.patch_num
        batch_size, _, _ = x.shape

        x = self.patch_embedding(x)

        # Prepend the class token, then add class/patch positional embeddings.
        class_token = _expand_token(self.class_embedding, batch_size).to(x.dtype)
        x = torch.cat([class_token, x], dim=1)
        x = self.add_pos_emb(x, patch_num)

        x = self.pre_ln(x)

        hidden_states = self.transformer(x)
        return hidden_states


class MolmoAttention(nn.Module):
    """Molmo's LLM attention.

    The layer applies, in order:
    1. a fused QKV projection;
    2. an optional RMSNorm on q and k (``config.attention_layer_norm``);
    3. rotary position embeddings;
    4. vLLM ``Attention``;
    5. an output projection.

    Heads are sharded across tensor-parallel ranks.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        self.tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = config.num_attention_heads

        assert self.hidden_size % self.total_num_heads == 0
        assert self.total_num_heads % self.tp_size == 0

        self.num_heads = self.total_num_heads // self.tp_size
        self.total_num_kv_heads = config.num_key_value_heads or self.total_num_heads
        # KV heads must either split evenly across ranks or evenly divide the
        # number of ranks; every rank keeps at least one KV head.
        if self.total_num_kv_heads >= self.tp_size:
            assert self.total_num_kv_heads % self.tp_size == 0
        else:
            assert self.tp_size % self.total_num_kv_heads == 0

        self.num_kv_heads = max(1, self.total_num_kv_heads // self.tp_size)
        self.head_dim = self.hidden_size // self.total_num_heads
        # Per-rank widths of the q and k/v slices of the fused QKV output.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.max_position_embeddings = config.max_position_embeddings

        # Attention input projection. Projects x -> (q, k, v)
        self.qkv_proj = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=config.qkv_bias,
            quant_config=quant_config,
        )

        self.tp_rank: int | None = None
        self.k_norm: nn.Module | None = None
        self.q_norm: nn.Module | None = None
        if config.attention_layer_norm:
            self.tp_rank = get_tensor_model_parallel_rank()
            # The norms span the full (unsharded) q / k widths.
            self.k_norm = RMSNorm(
                self.total_num_kv_heads * self.head_dim, eps=config.layer_norm_eps
            )
            self.q_norm = RMSNorm(config.hidden_size, eps=config.layer_norm_eps)

        # Rotary embeddings.
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=self.max_position_embeddings,
            rope_parameters=config.rope_parameters,
        )
        self.scaling = self.head_dim**-0.5
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

        # Attention output projection.
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            self.hidden_size,
            bias=False,
            quant_config=quant_config,
        )

    def _apply_qk_norm(
        self, q: torch.Tensor, k: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply RMSNorm to q and k over their full (unsharded) width.

        Under tensor parallelism, the local shards are all-gathered,
        normalized, and split back so that this rank keeps only its slice.
        """
        if self.tp_size > 1:
            q = tensor_model_parallel_all_gather(q.contiguous())
            k = tensor_model_parallel_all_gather(k.contiguous())
        q = self.q_norm(q)
        k = self.k_norm(k)
        if self.tp_size > 1:
            splitter = partial(split_tensor_along_last_dim, num_partitions=self.tp_size)
            q = splitter(q)[self.tp_rank]
            k = splitter(k)[self.tp_rank]
        return q, k

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        if self.q_norm is not None and self.k_norm is not None:
            q, k = self._apply_qk_norm(q, k)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output


488
class LanguageModelMLP(nn.Module):
    """Molmo's LLM mlp.

    Gated feed-forward block: a fused two-way projection, a gated SiLU
    activation (``MulAndSilu``) and a down projection back to hidden_size.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        input_dim: int | None = None,
        quant_config: QuantizationConfig | None = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        # config.intermediate_size is split evenly between the two fused
        # projections.
        self.intermediate_size = config.intermediate_size // 2

        self.gate_up_proj = MergedColumnParallelLinear(
            input_dim or self.hidden_size,
            [self.intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
        )
        # Activation function. NOTE(review): this is MulAndSilu, whereas
        # ImageProjectorMLP uses SiluAndMul; presumably the fused halves are
        # ordered differently in the checkpoint. Confirm before unifying.
        self.act_fn = MulAndSilu()
        # Feed-forward output projection.
        self.down_proj = RowParallelLinear(
            self.intermediate_size,
            self.hidden_size,
            bias=False,
            quant_config=quant_config,
        )

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class ImageProjectorMLP(nn.Module):
    """Molmo's image_projector mlp.

    Maps pooled vision features (``input_dim`` wide) to the LLM hidden size
    through a fused two-way projection, ``SiluAndMul`` gating and a down
    projection.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        input_dim: int | None = None,
        quant_config: QuantizationConfig | None = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        # config.intermediate_size is split evenly between the two fused
        # projections.
        self.intermediate_size = config.intermediate_size // 2

        self.merged_linear = MergedColumnParallelLinear(
            input_dim or self.hidden_size,
            [self.intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
        )
        # Activation function.
        self.act_fn = SiluAndMul()

        # Feed-forward output projection.
        self.down_proj = RowParallelLinear(
            self.intermediate_size,
            self.hidden_size,
            bias=False,
            quant_config=quant_config,
        )

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        gate_up, _ = self.merged_linear(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class MolmoDecoderLayer(nn.Module):
    """Pre-norm Molmo decoder layer: RMSNorm -> attention -> RMSNorm -> MLP.

    The residual stream is passed between layers so that each RMSNorm can
    fuse the residual add: ``norm(hidden_states, residual)`` returns both the
    normalized states and the updated residual.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        # Attention block.
        self.self_attn = MolmoAttention(
            config, cache_config, quant_config, prefix=f"{prefix}.self_attn"
        )

        # MLP block.
        self.mlp = LanguageModelMLP(config, quant_config=quant_config)

        # LayerNorm
        assert config.layer_norm_type == "rms"
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.layer_norm_eps
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Run one layer.

        Returns:
            ``(hidden_states, residual)``. The MLP output is not yet added to
            the residual; the next layer's input norm (or the final norm)
            does that.
        """
        # Self Attention
        if residual is None:
            # First layer: the raw input starts the residual stream.
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )

        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


class MolmoDecoderNormAfterLayer(MolmoDecoderLayer):
    """Post-norm variant of :class:`MolmoDecoderLayer`.

    Each norm is applied to the attention / MLP *output* before it is added
    to the residual. The residual is therefore fully folded into
    ``hidden_states``, and the returned residual is always None.
    """

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        # Self Attention
        residual = hidden_states
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )

        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = hidden_states + residual
        residual = hidden_states

        hidden_states = self.mlp(hidden_states)
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = hidden_states + residual
        residual = None
        return hidden_states, residual


638
639
class MolmoVisionBackbone(nn.Module, SupportsQuant):
    """Image encoder: ViT features -> 2D attention pooling -> projector MLP.

    Produces one LLM-width embedding per pooled patch for every image crop.
    """

    packed_modules_mapping = {"merged_linear": ["gate_proj", "up_proj"]}

    def __init__(
        self,
        config: PretrainedConfig,
        vision_config: VisionBackboneConfig,
        quant_config: QuantizationConfig | None = None,
    ) -> None:
        super().__init__()
        self.vit_layers = VIT_LAYERS
        self.image_num_patch = vision_config.image_num_patch
        # Pooled grid per crop. Each axis is padded up to a multiple of
        # POOLING_SIZE in forward() and then pooled, i.e. a ceil-division.
        self.llm_patches_per_crop = (
            (self.image_num_patch[0] + POOLING_SIZE - 1) // POOLING_SIZE,
            (self.image_num_patch[1] + POOLING_SIZE - 1) // POOLING_SIZE,
        )
        self.image_vit = VisionTransformer(vision_config, quant_config=quant_config)
        self.num_prefix_tokens = self.image_vit.num_prefix_tokens
        assert self.num_prefix_tokens in {0, 1}, (
            "Only 0 or 1 prefix tokens are supported"
        )
        # Attention pooling over features concatenated from len(vit_layers)
        # ViT layers.
        self.image_pooling_2d = MultiHeadDotProductAttention(
            vision_config, nlayers=len(self.vit_layers), quant_config=quant_config
        )
        self.image_projector = ImageProjectorMLP(
            config,
            input_dim=vision_config.image_emb_dim,
            quant_config=quant_config,
        )

        image_dim = vision_config.image_emb_dim * len(self.vit_layers)
        # Row 0 is added to fully padded patches, row 1 to partially padded
        # ones (see forward).
        self.pad_embed = nn.Parameter(torch.zeros((2, image_dim)))

    @property
    def dtype(self) -> torch.dtype:
        return self.image_vit.patch_embedding.weight.dtype

    @property
    def device(self) -> torch.device:
        return self.image_vit.patch_embedding.weight.device

    def encode_image(self, images: torch.Tensor) -> torch.Tensor:
        """Run the ViT over every crop and gather multi-layer patch features.

        Args:
            images: (batch_size, num_crops, num_patch, n_pixels).

        Returns:
            (batch_size, num_crops, num_patch, feature_dim) features,
            concatenated across ``vit_layers``. Crops whose pixels are all
            -1 are zeroed.
        """
        B, T, N, D = images.shape
        images = images.view(B * T, N, D)

        # A crop filled entirely with -1 is treated as padding.
        mask = ~torch.all(images == -1, dim=(1, 2), keepdim=True)

        image_features = self.image_vit(images)

        if self.vit_layers is not None:
            image_features = torch.cat(
                [image_features[layer] for layer in self.vit_layers], dim=-1
            )
        else:
            image_features = image_features[-1]

        if self.num_prefix_tokens > 0:
            # Drop the class-token position.
            image_features = image_features[:, 1:]

        image_features = image_features * mask
        image_features = image_features.view(B, T, N, -1)

        return image_features

    def forward(
        self,
        images: torch.Tensor,
        image_masks: torch.Tensor,
    ) -> torch.Tensor:
        """Encode image crops into LLM-space embeddings.

        Args:
            images: (batch_size, num_crops, num_patch, n_pixels) patches.
            image_masks: Per-patch mask, broadcastable against
                (batch_size, num_crops, num_patch). 0 marks a fully padded
                patch; values in (0, 1) mark a partially padded one.

        Returns:
            (batch_size, num_crops, pooled_h * pooled_w, hidden_size).
        """
        batch_size, num_image = images.shape[:2]
        images = images.to(device=self.device, dtype=self.dtype)
        image_features = self.encode_image(images)

        og_dtype = image_features.dtype
        assert image_masks is not None
        pad_embed = self.pad_embed[:, None, None, None, :]
        all_pad = image_masks == 0
        partial_pad = torch.logical_and(image_masks < 1, torch.logical_not(all_pad)).to(
            dtype=torch.float32
        )
        all_pad = all_pad.to(dtype=torch.float32)
        image_features = image_features + pad_embed[0] * torch.unsqueeze(all_pad, -1)
        image_features = image_features + pad_embed[1] * torch.unsqueeze(
            partial_pad, -1
        )

        image_features = image_features.to(og_dtype)

        # (batch, crops, rows, cols, channels)
        image_features = image_features.reshape(
            (batch_size, num_image) + self.image_num_patch + (-1,),
        )

        # Pad each spatial axis up to a multiple of POOLING_SIZE so the grid
        # splits evenly into pooling windows. F.pad lists dims last-first:
        # (channels, cols, rows, crops, batch).
        missing_h = -self.image_num_patch[0] % POOLING_SIZE
        missing_w = -self.image_num_patch[1] % POOLING_SIZE
        if missing_h or missing_w:
            image_features = F.pad(
                image_features,
                (0, 0, 0, missing_w, 0, missing_h, 0, 0, 0, 0),
            )

        # Group each POOLING_SIZE x POOLING_SIZE window into its own sequence.
        image_features = rearrange(
            image_features,
            "b n (h dh) (w dw) c -> (b n h w) (dh dw) c",
            dh=POOLING_SIZE,
            dw=POOLING_SIZE,
        )

        # Attention pooling: the window mean queries the window's patches.
        query = image_features.mean(-2, keepdim=True)
        image_features = self.image_pooling_2d(query, image_features)

        h, w = self.llm_patches_per_crop
        image_features = image_features.view(batch_size, num_image, h * w, -1)

        image_features = self.image_projector(image_features)

        # image_features: (batch_size, num_image, num_patch, d_model)
        return image_features

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights and return the names of loaded params.

        Separate ``gate_proj`` / ``up_proj`` checkpoint tensors are routed
        into the fused ``merged_linear`` parameter as shards 0 and 1.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("merged_linear", "gate_proj", 0),
            ("merged_linear", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Not a stacked shard (or skipped above): load it directly.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params

795

796
@support_torch_compile
class MolmoModel(nn.Module, SupportsQuant):
    """Molmo's decoder-only language model: embeddings, decoder layers, final norm.

    Supports pipeline parallelism. Non-first ranks receive hidden states and
    residuals through ``IntermediateTensors``; non-last ranks return them.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config

        # Text vocabulary plus ADDITIONAL_VOCAB_SIZE extra rows.
        self.embedding_size = config.embedding_size or config.vocab_size
        self.embedding_size += ADDITIONAL_VOCAB_SIZE
        self.embed_tokens = VocabParallelEmbedding(
            self.embedding_size,
            config.hidden_size,
            quant_config=quant_config,
        )

        decoder_layer = (
            MolmoDecoderNormAfterLayer if config.norm_after else MolmoDecoderLayer
        )
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: decoder_layer(
                config, cache_config, quant_config, prefix=prefix
            ),
            prefix=f"{prefix}.layers",
        )

        assert config.layer_norm_type == "rms"
        self.norm = RMSNorm(config.hidden_size, config.layer_norm_eps)

        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run this rank's slice of decoder layers.

        Returns:
            ``IntermediateTensors`` on non-last pipeline ranks. On the last
            rank, the final-normalized hidden states.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.embed_tokens(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        # Apply blocks one-by-one.
        for layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states, residual = layer(
                positions,
                hidden_states,
                residual,
            )
        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )
        # Pre-norm layers leave the last residual add to the final norm;
        # post-norm layers return residual=None.
        if residual is not None:
            hidden_states, _ = self.norm(hidden_states, residual)
        else:
            hidden_states = self.norm(hidden_states)
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights one-to-one by name; return loaded names."""
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            # Skip biases the model does not have (e.g. extra GPTQ biases).
            if name.endswith(".bias") and name not in params_dict:
                continue
            if is_pp_missing_parameter(name, self):
                continue

            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params

887

888
889
890
def _lowest_multiple(x: int, k: int) -> int:
    return (x // k) * k

891

892
893
894
895
896
897
898
899
900
901
def get_num_patches(
    num_tiles: int,
    *,
    crop_patches: int,
    left_margin: int,
    right_margin: int,
    pooling_size: int,
) -> int:
    """Count the patches spanned along one axis of a tiled image.

    Every count is rounded up to a multiple of ``pooling_size``.

    - With a single tile, the full crop is counted.
    - Otherwise the first tile keeps its left margin and the last tile keeps
      its right margin.
    - Each of the ``num_tiles - 2`` tiles in between contributes only its
      margin-free window.
    """

    def round_up(n: int) -> int:
        return _lowest_multiple(n + pooling_size - 1, pooling_size)

    if num_tiles == 1:
        return round_up(crop_patches)

    window = crop_patches - (left_margin + right_margin)
    first = round_up(window + left_margin)
    inner = round_up(window)
    last = round_up(window + right_margin)

    return first + (num_tiles - 2) * inner + last


def get_patches_grid_size(
    *,
    tiling_h: int,
    tiling_w: int,
    crop_patches: int,
    left_margin: int,
    right_margin: int,
    pooling_size: int,
) -> tuple[int, int]:
    """Return ``(nrows, ncols)``: patch counts along height and width.

    Both axes share the same crop geometry; only the tile count differs
    (``tiling_h`` for rows, ``tiling_w`` for columns).
    """
    geometry = dict(
        crop_patches=crop_patches,
        left_margin=left_margin,
        right_margin=right_margin,
        pooling_size=pooling_size,
    )
    nrows = get_num_patches(tiling_h, **geometry)
    ncols = get_num_patches(tiling_w, **geometry)
    return nrows, ncols


def get_candidate_tilings(max_num: int) -> list[tuple[int, int]]:
    """List every ``(i, j)`` grid with ``i * j <= max_num``, fewest tiles first.

    Grids with the same tile count stay in ascending ``(i, j)`` order
    because the sort is stable.
    """
    grids: list[tuple[int, int]] = []
    for i in range(1, max_num + 1):
        # Any j larger than max_num // i would exceed the tile budget.
        for j in range(1, max_num // i + 1):
            grids.append((i, j))
    grids.sort(key=lambda grid: grid[0] * grid[1])
    return grids


def select_tiling(
    *,
    height: int,
    width: int,
    patch_size: int,
    max_num_patches: int,
):
    """Pick the crop grid best suited to an image of ``height`` x ``width``.

    Each candidate grid ``(i, j)`` from ``get_candidate_tilings`` covers
    ``(i * patch_size, j * patch_size)`` pixels. For each grid we compute the
    uniform scale that fits the image inside it.

    - If some grid needs no downscaling, the one requiring the smallest
      scale >= 1 wins.
    - Otherwise the grid that shrinks the image least wins.

    Returns the chosen ``(i, j)`` row (tiles along height, tiles along width)
    of the int32 candidate array.
    """
    grids = np.array(get_candidate_tilings(max_num_patches), dtype=np.int32)
    grid_pixels = (grids * patch_size).astype(np.float32)
    image_hw = np.array([height, width], dtype=np.float32)

    # Scale limited by the tighter of the two axes, one value per grid.
    fit_scale = (grid_pixels / image_hw).min(axis=-1, keepdims=True)

    if not (fit_scale < 1).all():
        # Exclude grids that would downscale, then take the tightest fit.
        best = np.where(fit_scale < 1.0, 10e9, fit_scale).argmin()
    else:
        best = fit_scale.argmax()

    return grids[best]


class MolmoProcessorWrapper:
    """
    Adapter that makes the remote-code `MolmoProcessor` directly callable and
    exposes the image settings vLLM needs as typed, cached attributes.

    The upstream processor is defined here:
    https://huggingface.co/allenai/Molmo-7B-D-0924/blob/main/preprocessing_molmo.py
    """

    def __init__(self, processor: ProcessorMixin):
        super().__init__()

        self.processor = processor

    def _int_image_setting(self, name: str) -> int:
        """Read an attribute of the image processor, asserting it is an int."""
        value = getattr(self.processor.image_processor, name)  # type: ignore
        assert isinstance(value, int)
        return value

    @cached_property
    def vocab(self) -> dict[str, int]:
        return self.processor.tokenizer.vocab  # type: ignore

    @cached_property
    def max_crops(self) -> int:
        return self._int_image_setting("max_crops")

    @cached_property
    def base_image_input_size(self) -> tuple[int, int]:
        # The processor may store a single int (square) or a sequence.
        size = self.processor.image_processor.base_image_input_size  # type: ignore
        return (size, size) if isinstance(size, int) else tuple(size)

    @cached_property
    def image_patch_size(self) -> int:
        return self._int_image_setting("image_patch_size")

    @cached_property
    def overlap_margins(self) -> tuple[int, int]:
        margins = self.processor.image_processor.overlap_margins  # type: ignore
        left, right = margins
        assert isinstance(left, int)
        assert isinstance(right, int)
        return left, right

    @cached_property
    def image_token_length_w(self) -> int:
        return self._int_image_setting("image_token_length_w")

    @cached_property
    def image_token_length_h(self) -> int:
        return self._int_image_setting("image_token_length_h")

    @property
    def message_format(self) -> str | None:
        # This wrapper always reports the "role" message format.
        return "role"

    @property
    def always_start_with_space(self) -> bool:
        return True

    @cached_property
    def image_patch_id(self) -> int:
        return self.vocab[IMAGE_PATCH_TOKEN]

    @cached_property
    def im_col_id(self) -> int:
        return self.vocab[IM_COL_TOKEN]

    @cached_property
    def im_start_id(self) -> int:
        return self.vocab[IM_START_TOKEN]

    @cached_property
    def im_end_id(self) -> int:
        return self.vocab[IM_END_TOKEN]

    @property
    def pooling_size(self) -> int:
        return POOLING_SIZE

    def select_tiling(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> tuple[int, int]:
        """Choose the crop grid for an image, returned as ``(tiling_w, tiling_h)``.

        The overlap margins are removed from both image dimensions and from
        the crop size before delegating to the module-level ``select_tiling``.
        """
        max_crops = self.max_crops
        left_margin, right_margin = self.overlap_margins
        crop_px = self.base_image_input_size[0]
        patch_px = self.image_patch_size

        margin_patches = left_margin + right_margin
        margin_px = patch_px * margin_patches
        window_px = (crop_px // patch_px - margin_patches) * patch_px

        tiling_h, tiling_w = select_tiling(
            height=image_height - margin_px,
            width=image_width - margin_px,
            patch_size=window_px,
            max_num_patches=max_crops,
        )
        return tiling_w, tiling_h

    def get_patches_grid_size(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> tuple[int, int]:
        """Return the patch grid for an image as ``(ncols, nrows)``."""
        left_margin, right_margin = self.overlap_margins
        crop_patches = self.base_image_input_size[0] // self.image_patch_size
        pooling_size = self.pooling_size

        tiling_w, tiling_h = self.select_tiling(
            image_height=image_height,
            image_width=image_width,
        )
        nrows, ncols = get_patches_grid_size(
            tiling_h=tiling_h,
            tiling_w=tiling_w,
            crop_patches=crop_patches,
            left_margin=left_margin,
            right_margin=right_margin,
            pooling_size=pooling_size,
        )
        return ncols, nrows

    def __call__(
        self,
        text: TextInput | list[TextInput] | None = None,
        images: ImageInput | list[ImageInput] | None = None,
        return_tensors: str | TensorType | None = None,
        **kwargs,
    ) -> BatchFeature:
        """Run ``processor.process`` and reshape its output for vLLM.

        - Adds a leading batch axis to ``input_ids``.
        - When image indices are present, attaches ``num_crops`` (one value
          per image) and ``img_patch_id``.
        - ``return_tensors`` is accepted but not used.
        """
        outputs = self.processor.process(  # type: ignore
            text, images, **kwargs
        )

        if images is None:
            image_list = []
        elif isinstance(images, list):
            image_list = images
        else:
            image_list = [images]

        outputs["input_ids"] = outputs.pop("input_ids").unsqueeze(0)

        image_input_idx = outputs.pop("image_input_idx", None)
        if image_input_idx is None:
            return BatchFeature(outputs)

        tilings = [
            self.select_tiling(image_width=img.size[0], image_height=img.size[1])
            for img in image_list
        ]
        # Per image: tiling_h * tiling_w crops plus one extra crop.
        num_crops = torch.tensor(tilings).prod(-1) + 1
        assert num_crops.sum() == len(image_input_idx)

        outputs["image_input_idx"] = image_input_idx
        outputs["num_crops"] = num_crops
        outputs["img_patch_id"] = self.image_patch_id
        return BatchFeature(outputs)
1173
1174
1175


class MolmoProcessingInfo(BaseProcessingInfo):
    def get_hf_processor(self, **kwargs: object) -> MolmoProcessorWrapper:
        """Return the context's HF processor wrapped in `MolmoProcessorWrapper`."""
        return MolmoProcessorWrapper(self.ctx.get_hf_processor(**kwargs))

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        # No fixed per-prompt cap on the number of images.
        return {"image": None}

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        processor: MolmoProcessorWrapper | None,
    ) -> int:
        """Count the prompt tokens inserted for one image of the given size.

        The count has two parts:

        - A fixed block: start/end markers plus ``image_token_length_h`` rows
          of (``image_token_length_w`` patches + 1 column marker).
        - An image-dependent block with the same layout over the pooled patch
          grid.

        This mirrors the token layout built in
        ``MolmoMultiModalProcessor._get_prompt_updates``.
        """
        proc = self.get_hf_processor() if processor is None else processor

        ncols, nrows = proc.get_patches_grid_size(
            image_width=image_width,
            image_height=image_height,
        )
        pool = proc.pooling_size

        # 2 for start/end + (w + 1) * h including column separators.
        extra = 2 + (proc.image_token_length_w + 1) * proc.image_token_length_h

        pooled_cols = (ncols + 1) // pool
        pooled_rows = (nrows + 1) // pool
        joint = 2 + (pooled_cols + 1) * pooled_rows

        return extra + joint

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return the candidate image size that yields the most image tokens.

        Candidates are the base input size multiplied by every tiling from
        ``get_candidate_tilings(max_crops)``.
        """
        processor = self.get_hf_processor()
        tilings = get_candidate_tilings(processor.max_crops)
        base_h, base_w = processor.base_image_input_size

        best_count, best_size = 0, None
        for wr, hr in tilings:
            width = base_w * wr
            height = base_h * hr
            count = self.get_num_image_tokens(
                image_width=width,
                image_height=height,
                processor=processor,
            )
            if count > best_count:
                best_count = count
                best_size = ImageSize(width=width, height=height)

        if best_count == 0 or best_size is None:
            raise ValueError("Cannot have a largest feature size of 0!")

        return best_size


class MolmoDummyInputsBuilder(BaseDummyInputsBuilder[MolmoProcessingInfo]):
    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        # No placeholder text is needed: images are inserted via prompt
        # updates (`get_placeholder_str` returns None for images).
        return ""

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        """Build dummy images sized by ``get_image_size_with_most_features``."""
        target_width, target_height = self.info.get_image_size_with_most_features()
        image_overrides = None if not mm_options else mm_options.get("image")

        dummy_images = self._get_dummy_images(
            width=target_width,
            height=target_height,
            num_images=mm_counts.get("image", 0),
            overrides=image_overrides,
        )
        return {"image": dummy_images}


class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
    def _apply_hf_processor_tokens_only(
        self,
        prompt_tokens: list[int],
    ) -> list[int]:
        """Re-tokenize an already-templated prompt through the Molmo processor."""
        processor = self.info.get_hf_processor()
        prompt_text = self.info.get_tokenizer().decode(prompt_tokens)

        # The chat template was already applied to the prompt tokens, so pass
        # message_format="none" to avoid applying it twice. The leading-space
        # behaviour follows `always_start_with_space`.
        tokens = processor.processor.get_tokens_input(  # type: ignore
            prompt_text,
            message_format="none",
            always_start_with_space=processor.always_start_with_space,
        )

        # Running the wrapper on the token list produces the final ids; the
        # original note says a BOS id is prepended here (done by the processor).
        processed_data = self.info.ctx.call_hf_processor(
            processor,  # type: ignore
            dict(tokens=tokens),
        )
        # input_ids carries a leading batch axis of size 1.
        (prompt_ids,) = processed_data.pop("input_ids").tolist()
        return prompt_ids

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Describe how each processor output is split across images."""
        num_crops = hf_inputs.get("num_crops", torch.empty(0))

        # These tensors are concatenated crop-wise across all images;
        # num_crops gives the number of rows belonging to each image.
        per_crop_fields = {
            key: MultiModalFieldConfig.flat_from_sizes("image", num_crops)
            for key in ("images", "image_masks", "image_input_idx")
        }
        return dict(
            **per_crop_fields,
            num_crops=MultiModalFieldConfig.batched("image"),
            img_patch_id=MultiModalFieldConfig.shared("image", len(num_crops)),
        )

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Insert each image's token block at the prompt's ``<|endoftext|>`` prefix."""
        processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)

        patch_id = processor.image_patch_id
        col_id = processor.im_col_id
        start_id = processor.im_start_id
        end_id = processor.im_end_id
        pooling_size = processor.pooling_size

        def grid_tokens(n_cols: int, n_rows: int) -> list[int]:
            # <start>, then n_rows x (n_cols patches + column marker), then <end>.
            row = [patch_id] * n_cols + [col_id]
            return [start_id] + row * n_rows + [end_id]

        # The fixed-size block is identical for every image.
        extra_joint = grid_tokens(
            processor.image_token_length_w, processor.image_token_length_h
        )

        def get_insertion_molmo(item_idx: int):
            images = mm_items.get_items("image", ImageProcessorItems)
            image_size = images.get_image_size(item_idx)

            ncols, nrows = processor.get_patches_grid_size(
                image_width=image_size.width,
                image_height=image_size.height,
            )
            joint = grid_tokens(
                (ncols + 1) // pooling_size,
                (nrows + 1) // pooling_size,
            )
            return PromptUpdateDetails.select_token_id(
                extra_joint + joint,
                embed_token_id=patch_id,
            )

        insertion = PromptInsertion(
            modality="image",
            target=PromptIndexTargets.prefix("<|endoftext|>"),
            insertion=get_insertion_molmo,
        )
        return [insertion]


1349
1350
1351
1352
1353
1354
1355
1356
@MULTIMODAL_REGISTRY.register_processor(
    MolmoMultiModalProcessor,
    info=MolmoProcessingInfo,
    dummy_inputs=MolmoDummyInputsBuilder,
)
class MolmoForCausalLM(
    nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, SupportsQuant
):
    """Molmo causal LM: a `MolmoVisionBackbone` feeding a `MolmoModel` decoder."""

    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_substr={
            # vision backbone mapping
            "image_projector.w1.": "image_projector.gate_proj.",
            "image_projector.w3.": "image_projector.up_proj.",
            "image_projector.w2.": "image_projector.down_proj.",
            # language backbone mapping
            "att_proj": "self_attn.qkv_proj",
            "attn_out": "self_attn.o_proj",
            "q_norm": "self_attn.q_norm",
            "k_norm": "self_attn.k_norm",
            "ff_proj": "mlp.gate_up_proj",
            "ff_out": "mlp.down_proj",
            "attn_norm": "input_layernorm",
            "ff_norm": "post_attention_layernorm",
        },
        orig_to_new_prefix={
            # vision backbone mapping
            "model.vision_backbone.": "vision_backbone.",
            # language backbone mapping
            "model.transformer.blocks.": "model.layers.",
            "model.transformer.ln_f.": "model.norm.",
            # lm_head is renamed to model.transformer.mlp.down_proj firstly,
            # we need to run a second renaming for it
            "model.transformer.mlp.down_proj.": "lm_head.",
        },
    )

    packed_modules_mapping = {
        "qkv_proj": ["qkv_proj"],
        "gate_up_proj": ["gate_up_proj"],  # language model
        "merged_linear": ["gate_proj", "up_proj"],  # image_projector
    }

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        # Images need no textual placeholder; they are inserted by the
        # multimodal processor's prompt updates.
        if modality.startswith("image"):
            return None

        raise ValueError("Only image modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.config = config
        self.multimodal_config = multimodal_config

        vision_config = VisionBackboneConfig()
        self.vision_backbone = MolmoVisionBackbone(config, vision_config, quant_config)
        self.model = MolmoModel(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )
        # Set from the multimodal kwargs in _parse_and_validate_image_input.
        self.img_patch_id = None

        if self.config.weight_tying:
            # Tie the output projection to the input embedding. MolmoModel
            # stores that embedding as `embed_tokens` (the checkpoint's `wte`
            # halves are merged into `model.embed_tokens.weight` on load);
            # it has no `transformer.wte` attribute.
            self.lm_head = self.model.embed_tokens
        else:
            self.lm_head = ParallelLMHead(
                config.embedding_size or config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "lm_head"),
            )

        self.logits_processor = LogitsProcessor(
            config.embedding_size or config.vocab_size
        )

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

    def _parse_and_validate_image_input(
        self,
        **kwargs: object,
    ) -> MolmoImageInputs | None:
        """Collect image kwargs into `MolmoImageInputs`; None without images.

        Also records ``img_patch_id`` on the module.
        """
        images = kwargs.pop("images", None)
        image_masks = kwargs.pop("image_masks", None)
        image_input_idx = kwargs.pop("image_input_idx", None)
        num_crops = kwargs.pop("num_crops", None)

        if images is None:
            return None

        img_patch_id = kwargs.pop("img_patch_id", None)
        if isinstance(img_patch_id, torch.Tensor):
            img_patch_id = img_patch_id.item()

        assert isinstance(img_patch_id, int)
        self.img_patch_id = img_patch_id

        return MolmoImageInputs(
            images=images,
            image_masks=image_masks,
            image_input_idx=image_input_idx,
            num_crops=num_crops,
        )

    def _process_image_input(
        self,
        image_input: MolmoImageInputs,
    ) -> list[torch.Tensor]:
        """Encode all crops at once and return per-image patch features.

        For each image, only positions with ``image_input_idx >= 0`` are
        kept, ordered by their ``image_input_idx`` value.
        """
        images = image_input["images"]
        image_masks = image_input["image_masks"]
        image_input_idx = image_input["image_input_idx"]
        num_crops = image_input["num_crops"]

        # Call the vision backbone on the whole batch at once
        image_features = self.vision_backbone(
            images=images.unsqueeze(0),
            image_masks=None if image_masks is None else image_masks.unsqueeze(0),
        ).squeeze(0)

        # Only the features corresponding to patch tokens are relevant
        # Re-order the features using the image_input_idx tensor
        results = []
        num_crops_list = num_crops.tolist()
        for feats, img_idx in zip(
            image_features.split(num_crops_list),
            image_input_idx.split(num_crops_list),
        ):
            is_valid = img_idx >= 0
            valid_img_idx = img_idx[is_valid]
            order = torch.argsort(valid_img_idx)
            results.append(feats[is_valid][order])
        return results

    def get_language_model(self) -> torch.nn.Module:
        return self.model

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []

        return self._process_image_input(image_input)

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: torch.LongTensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor:
        # Non-first pipeline ranks consume intermediate_tensors, not embeddings.
        if intermediate_tensors is not None:
            inputs_embeds = None

        hidden_states = self.model(
            input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
        )

        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Merge the split embedding, remap HF names, and load all weights."""
        loader = AutoWeightsLoader(self)
        weights = _get_weights_with_merged_embedding(weights)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        return MultiModelKeys.from_string_field(
            language_model="model",
            connector="vision_backbone.image_projector",
            tower_model="vision_backbone",
        )

1533
1534

def _get_weights_with_merged_embedding(
1535
    weights: Iterable[tuple[str, torch.Tensor]],
1536
) -> Iterable[tuple[str, torch.Tensor]]:
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
    embedding_weights = {}
    for name, weight in weights:
        if "wte.embedding" in name:
            embedding_weights["embedding"] = weight
        elif "wte.new_embedding" in name:
            embedding_weights["new_embedding"] = weight
        else:
            yield (name, weight)
    # this is compatible with most of quantization,
    # because they won't quantize embed_tokens
    embedding_weights = torch.cat(
        [embedding_weights["embedding"], embedding_weights["new_embedding"]],
        dim=0,
    )
    yield ("model.embed_tokens.weight", embedding_weights)