# molmo.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import math
5
from collections.abc import Iterable, Mapping, Sequence
6
from dataclasses import dataclass
7
from functools import cached_property, partial
8
from itertools import islice
9
from typing import Annotated, Optional, Union
10

11
import numpy as np
12
import torch
13
14
import torch.nn as nn
import torch.nn.functional as F
15
from einops import rearrange
16
from transformers import BatchFeature, PretrainedConfig, ProcessorMixin, TensorType
17
18
from transformers.image_utils import ImageInput
from transformers.tokenization_utils_base import TextInput
19

20
from vllm.attention import Attention
21
from vllm.attention.layer import MultiHeadAttention
22
from vllm.compilation.decorators import support_torch_compile
23
from vllm.config import CacheConfig, VllmConfig
24
from vllm.config.multimodal import BaseDummyOptions
25
26
27
28
29
30
31
32
from vllm.distributed import (
    get_pp_group,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    split_tensor_along_last_dim,
    tensor_model_parallel_all_gather,
)
from vllm.model_executor.layers.activation import MulAndSilu, QuickGELU, SiluAndMul
33
from vllm.model_executor.layers.layernorm import RMSNorm
34
35
36
37
38
39
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    MergedColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
40
from vllm.model_executor.layers.logits_processor import LogitsProcessor
41
from vllm.model_executor.layers.quantization import QuantizationConfig
42
43
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
44
45
46
    ParallelLMHead,
    VocabParallelEmbedding,
)
47
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
48
from vllm.model_executor.models.module_mapping import MultiModelKeys
49
from vllm.multimodal import MULTIMODAL_REGISTRY
50
51
52
53
54
55
56
57
58
59
60
61
62
63
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
)
from vllm.multimodal.parse import ImageProcessorItems, ImageSize, MultiModalDataItems
from vllm.multimodal.processing import (
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptIndexTargets,
    PromptInsertion,
    PromptUpdate,
    PromptUpdateDetails,
)
64
from vllm.multimodal.profiling import BaseDummyInputsBuilder
65
from vllm.sequence import IntermediateTensors
66
from vllm.utils.tensor_schema import TensorSchema, TensorShape
67

68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
from .interfaces import (
    MultiModalEmbeddings,
    SupportsLoRA,
    SupportsMultiModal,
    SupportsPP,
    SupportsQuant,
)
from .utils import (
    AutoWeightsLoader,
    WeightsMapper,
    flatten_bn,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
84
85
86
87
88

# TODO: hard-coded for now. Consider making it configurable.
VIT_LAYERS = [-2, -9]
NUM_PREFIX_TOKENS = 1
ADDITIONAL_VOCAB_SIZE = 128
89
90
91
92
93
IMAGE_PATCH_TOKEN = "<im_patch>"
IM_COL_TOKEN = "<im_col>"
IM_START_TOKEN = "<im_start>"
IM_END_TOKEN = "<im_end>"
POOLING_SIZE = 2
94
95


96
class MolmoImageInputs(TensorSchema):
    """
    Batched Molmo image inputs, validated against symbolic tensor shapes.

    Dimensions:
        - bn: Batch size * number of images
        - nc: Number of crops (dynamic)
        - np: Number of patches
        - tp: Token sequence positions
        - pd: Patch dimension
    """

    # Pixel patches for every crop. The number of crops may vary per batch
    # and image, so this may be passed as a list of tensors.
    images: Annotated[
        Union[torch.Tensor, list[torch.Tensor]],
        TensorShape("bn", "nc", "np", "pd", dynamic_dims={"nc"}),
    ]

    # Optional per-patch mask for every crop (same crop layout as `images`).
    image_masks: Annotated[
        Optional[Union[torch.Tensor, list[torch.Tensor]]],
        TensorShape("bn", "nc", "np", dynamic_dims={"nc"}),
    ]

    # A boolean mask indicating which image features correspond to patch tokens.
    feat_is_patch: Annotated[
        Union[torch.Tensor, list[torch.Tensor]],
        TensorShape("bn", "nc", "tp", dynamic_dims={"nc"}),
    ]

    # One crop count per entry along `bn`.
    num_crops: Annotated[torch.Tensor, TensorShape("bn")]
123

124
125
126

@dataclass
class VisionBackboneConfig:
    """Hyper-parameters of Molmo's ViT image encoder."""

    image_default_input_size: tuple[int, int] = (336, 336)
    image_patch_size: int = 14
    image_pos_patch_size: int = 14
    image_emb_dim: int = 1024
    image_num_heads: int = 16
    image_num_key_value_heads: int = 16
    image_num_layers: int = 23
    image_mlp_dim: int = 4096
    image_mlp_activations: str = "quick_gelu"
    image_num_pos: int = 577
    image_norm_eps: float = 1e-5

    def __post_init__(self):
        # Accept any sequence (e.g. a list) and store it as a tuple so the
        # field matches its annotation.
        self.image_default_input_size = tuple(self.image_default_input_size)  # type: ignore[assignment]

    @property
    def image_num_patch(self):
        """Return the (height, width) patch grid of the default input size."""
        h, w = self.image_default_input_size
        return h // self.image_patch_size, w // self.image_patch_size


class ViTMLP(nn.Module):
    """Feed-forward sub-layer of a ViT block: w1 -> QuickGELU -> w2."""

    def __init__(
        self,
        config: VisionBackboneConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        emb_dim = config.image_emb_dim
        mlp_dim = config.image_mlp_dim
        # Up-projection (column-parallel linear).
        self.w1 = ColumnParallelLinear(
            emb_dim, mlp_dim, bias=True, quant_config=quant_config
        )
        # Only QuickGELU is supported for the vision tower.
        assert config.image_mlp_activations == "quick_gelu"
        self.act = QuickGELU()
        # Down-projection back to the embedding width (row-parallel linear).
        self.w2 = RowParallelLinear(
            mlp_dim, emb_dim, bias=True, quant_config=quant_config
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The parallel linears return (output, bias); only the output is used.
        hidden, _bias = self.w1(x)
        activated = self.act(hidden)
        out, _bias = self.w2(activated)
        return out


class MultiHeadDotProductAttention(nn.Module):
    """Multi-head attention used in Vision Transformer.

    ``nlayers`` widens the q/k/v input projections to ``nlayers * hidden``.
    This lets the module attend over features formed by concatenating several
    ViT layers along the last dimension.
    """

    def __init__(
        self,
        config: VisionBackboneConfig,
        use_bias: bool = True,
        nlayers: int = 1,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()

        self.hidden_size = config.image_emb_dim
        self.total_num_heads = config.image_num_heads
        tp_size = get_tensor_model_parallel_world_size()

        assert self.hidden_size % self.total_num_heads == 0
        assert self.total_num_heads % tp_size == 0

        self.num_heads = self.total_num_heads // tp_size
        self.head_dim = self.hidden_size // self.total_num_heads

        self.total_num_kv_heads = config.image_num_key_value_heads
        # KV heads must divide evenly across TP ranks (or vice versa); each
        # rank keeps at least one KV head.
        if self.total_num_kv_heads >= tp_size:
            assert self.total_num_kv_heads % tp_size == 0
        else:
            assert tp_size % self.total_num_kv_heads == 0

        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)

        self.wq = ColumnParallelLinear(
            nlayers * self.hidden_size,
            self.total_num_heads * self.head_dim,
            bias=use_bias,
            quant_config=quant_config,
        )
        self.wk = ColumnParallelLinear(
            nlayers * self.hidden_size,
            self.total_num_kv_heads * self.head_dim,
            bias=use_bias,
            quant_config=quant_config,
        )
        self.wv = ColumnParallelLinear(
            nlayers * self.hidden_size,
            self.total_num_kv_heads * self.head_dim,
            bias=use_bias,
            quant_config=quant_config,
        )
        self.wo = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            self.hidden_size,
            bias=use_bias,
            quant_config=quant_config,
        )

        self.scale = self.head_dim**-0.5
        self.attn = MultiHeadAttention(
            self.num_heads, self.head_dim, self.scale, num_kv_heads=self.num_kv_heads
        )

    def forward(
        self, inputs_q: torch.Tensor, inputs_kv: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Attend from ``inputs_q`` to ``inputs_kv``.

        When ``inputs_kv`` is omitted this is self-attention over ``inputs_q``.
        """
        inputs_kv = inputs_q if inputs_kv is None else inputs_kv

        xq, _ = self.wq(inputs_q)
        xk, _ = self.wk(inputs_kv)
        xv, _ = self.wv(inputs_kv)

        output = self.attn(xq, xk, xv)
        output, _ = self.wo(output)
        return output


class ResidualAttentionBlock(nn.Module):
    """Residual attention block used in Vision Transformer.

    Pre-norm layout: ``x + attention(LN(x))`` then ``x + feed_forward(LN(x))``.
    """

    def __init__(
        self,
        config: VisionBackboneConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.attention = MultiHeadDotProductAttention(config, quant_config=quant_config)
        self.feed_forward = ViTMLP(config, quant_config)
        self.attention_norm = nn.LayerNorm(
            config.image_emb_dim,
            eps=config.image_norm_eps,
        )
        self.ffn_norm = nn.LayerNorm(
            config.image_emb_dim,
            eps=config.image_norm_eps,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attention(self.attention_norm(x))
        x = x + self.feed_forward(self.ffn_norm(x))
        return x


class BlockCollection(nn.Module):
    """Collection of residual attention blocks used in Vision Transformer."""

    def __init__(
        self,
        config: VisionBackboneConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.resblocks = nn.ModuleList(
            [
                ResidualAttentionBlock(config, quant_config)
                for _ in range(config.image_num_layers)
            ]
        )

    def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
        """Run every block in order and return each block's output.

        Element ``i`` of the result is the output of block ``i``. Callers can
        therefore select intermediate layers, including by negative index.
        """
        hidden_states = []
        for block in self.resblocks:
            x = block(x)
            hidden_states.append(x)
        return hidden_states


def _expand_token(token: torch.Tensor, batch_size: int) -> torch.Tensor:
    return token.view(1, 1, -1).expand(batch_size, -1, -1)


class VisionTransformer(nn.Module):
    """Vision Transformer used in Vision Backbone."""

    def __init__(
        self,
        config: VisionBackboneConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        scale = config.image_emb_dim**-0.5
        self.patch_num = config.image_num_patch
        self.class_embedding = nn.Parameter(torch.randn(config.image_emb_dim) * scale)
        self.num_prefix_tokens: int = NUM_PREFIX_TOKENS
        self.positional_embedding = nn.Parameter(
            torch.randn(config.image_num_pos, config.image_emb_dim) * scale
        )
        image_patch_size = config.image_patch_size
        # Flattened RGB patches (patch * patch * 3 values) -> embedding.
        self.patch_embedding = nn.Linear(
            image_patch_size * image_patch_size * 3,
            config.image_emb_dim,
            bias=False,
        )
        self.pre_ln = nn.LayerNorm(config.image_emb_dim, eps=config.image_norm_eps)
        self.transformer = BlockCollection(config, quant_config)

    def add_pos_emb(
        self, x: torch.Tensor, patch_num: tuple[int, int]
    ) -> torch.Tensor:
        """Add positional embeddings to ``x`` (class token first, then patches).

        Row 0 of ``positional_embedding`` is the class-token embedding. The
        remaining rows form a square grid, which is bicubically resized when
        it does not match ``patch_num``.
        """
        cls_emb = self.positional_embedding[0:1]
        pos_emb = self.positional_embedding[1:]

        # Exact integer square root of the patch-embedding count.
        side = math.isqrt(pos_emb.shape[0])
        pos_emb = pos_emb.reshape((side, side, pos_emb.shape[1]))

        patch_num_0, patch_num_1 = patch_num

        if pos_emb.shape[0] != patch_num_0 or pos_emb.shape[1] != patch_num_1:
            # from https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
            pos_emb = pos_emb.unsqueeze(0).permute(0, 3, 1, 2)
            pos_emb = F.interpolate(
                pos_emb,
                size=(patch_num_0, patch_num_1),
                mode="bicubic",
                align_corners=False,
                antialias=True,
            )
            pos_emb = pos_emb.permute(0, 2, 3, 1).squeeze(0)

        pos_emb = pos_emb.reshape(-1, pos_emb.shape[-1])
        x = x + torch.cat([cls_emb[None, :, :], pos_emb[None, :, :]], dim=1).to(x.dtype)
        return x

    def forward(
        self, x: torch.Tensor, patch_num: Optional[tuple[int, int]] = None
    ) -> list[torch.Tensor]:
        """
        : param x: (batch_size, num_patch, n_pixels)
        : param patch_num: (rows, cols) patch grid; defaults to the config grid
        : return: the output of every transformer block, in order
        """
        if patch_num is None:
            patch_num = self.patch_num

        x = self.patch_embedding(x)

        # Prepend the class token, then add positional embeddings.
        cls_tokens = _expand_token(self.class_embedding, x.shape[0]).to(x.dtype)
        x = torch.cat([cls_tokens, x], dim=1)
        x = self.add_pos_emb(x, patch_num)

        x = self.pre_ln(x)
        return self.transformer(x)


class MolmoAttention(nn.Module):
    """Molmo's LLM attention.

    Fused QKV projection, then optional RMSNorm on q/k over the full
    (all-ranks) width, then rotary embedding, vLLM ``Attention``, and the
    output projection.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        self.tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = config.num_attention_heads

        assert self.hidden_size % self.total_num_heads == 0
        assert self.total_num_heads % self.tp_size == 0

        self.num_heads = self.total_num_heads // self.tp_size
        # Without an explicit KV-head count, use one KV head per query head.
        self.total_num_kv_heads = config.num_key_value_heads or self.total_num_heads
        if self.total_num_kv_heads >= self.tp_size:
            assert self.total_num_kv_heads % self.tp_size == 0
        else:
            assert self.tp_size % self.total_num_kv_heads == 0

        self.num_kv_heads = max(1, self.total_num_kv_heads // self.tp_size)
        self.head_dim = self.hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta

        # Attention input projection. Projects x -> (q, k, v)
        self.qkv_proj = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=config.qkv_bias,
            quant_config=quant_config,
        )

        self.tp_rank: Optional[int] = None
        self.k_norm: Optional[nn.Module] = None
        self.q_norm: Optional[nn.Module] = None
        if config.attention_layer_norm:
            self.tp_rank = get_tensor_model_parallel_rank()
            # Norms cover the full (un-sharded) q/k widths; _apply_qk_norm
            # gathers the shards before normalizing.
            self.k_norm = RMSNorm(
                self.total_num_kv_heads * self.head_dim, eps=config.layer_norm_eps
            )
            self.q_norm = RMSNorm(config.hidden_size, eps=config.layer_norm_eps)

        # Rotary embeddings.
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=self.max_position_embeddings,
            base=self.rope_theta,
        )
        self.scaling = self.head_dim**-0.5
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

        # Attention output projection.
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            self.hidden_size,
            bias=False,
            quant_config=quant_config,
        )

    def _apply_qk_norm(
        self, q: torch.Tensor, k: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """RMS-normalize q and k over their full width.

        Under tensor parallelism the shards are all-gathered, normalized, and
        then split back so this rank keeps only its own slice.
        """
        if self.tp_size > 1:
            q = tensor_model_parallel_all_gather(q.contiguous())
            k = tensor_model_parallel_all_gather(k.contiguous())
        q = self.q_norm(q)
        k = self.k_norm(k)
        if self.tp_size > 1:
            splitter = partial(split_tensor_along_last_dim, num_partitions=self.tp_size)
            q = splitter(q)[self.tp_rank]
            k = splitter(k)[self.tp_rank]
        return q, k

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        if self.q_norm is not None and self.k_norm is not None:
            q, k = self._apply_qk_norm(q, k)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output


500
class LanguageModelMLP(nn.Module):
    """Molmo's LLM mlp: fused gate/up projection, ``MulAndSilu``, down projection."""

    def __init__(
        self,
        config: PretrainedConfig,
        input_dim: Optional[int] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        # config.intermediate_size is halved; the fused projection below emits
        # two halves of this size.
        self.intermediate_size = config.intermediate_size // 2

        self.gate_up_proj = MergedColumnParallelLinear(
            input_dim or self.hidden_size,
            [self.intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
        )
        # Activation function.
        self.act_fn = MulAndSilu()
        # Feed-forward output projection.
        self.down_proj = RowParallelLinear(
            self.intermediate_size,
            self.hidden_size,
            bias=False,
            quant_config=quant_config,
        )

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class ImageProjectorMLP(nn.Module):
    """Molmo's image_projector mlp: fused projection, ``SiluAndMul``, down projection."""

    def __init__(
        self,
        config: PretrainedConfig,
        input_dim: Optional[int] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        # config.intermediate_size is halved; the fused projection below emits
        # two halves of this size.
        self.intermediate_size = config.intermediate_size // 2

        self.merged_linear = MergedColumnParallelLinear(
            input_dim or self.hidden_size,
            [self.intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
        )
        # Activation function.
        self.act_fn = SiluAndMul()

        # Feed-forward output projection.
        self.down_proj = RowParallelLinear(
            self.intermediate_size,
            self.hidden_size,
            bias=False,
            quant_config=quant_config,
        )

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        gate_up, _ = self.merged_linear(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class MolmoDecoderLayer(nn.Module):
    """Pre-norm Molmo decoder layer (RMSNorm -> attention -> RMSNorm -> MLP).

    The residual stream is threaded through the fused ``RMSNorm(x, residual)``
    calls and returned alongside the hidden states.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        # Attention block.
        self.self_attn = MolmoAttention(
            config, cache_config, quant_config, prefix=f"{prefix}.self_attn"
        )

        # MLP block.
        self.mlp = LanguageModelMLP(config, quant_config=quant_config)

        # LayerNorm
        assert config.layer_norm_type == "rms"
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.layer_norm_eps
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Return ``(hidden_states, residual)`` for the next layer."""
        # Self Attention
        if residual is None:
            # First layer: the raw input becomes the residual stream.
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )

        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


class MolmoDecoderNormAfterLayer(MolmoDecoderLayer):
    """Post-norm decoder variant: each sub-layer output is normalized before
    being added to its residual. The layer always returns ``residual=None``.
    """

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        # Self Attention
        residual = hidden_states
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )

        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = hidden_states + residual
        residual = hidden_states

        hidden_states = self.mlp(hidden_states)
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = hidden_states + residual
        residual = None
        return hidden_states, residual


650
651
class MolmoVisionBackbone(nn.Module, SupportsQuant):
    """Molmo image encoder.

    ViT features from ``VIT_LAYERS`` are concatenated. Each
    ``POOLING_SIZE x POOLING_SIZE`` patch window is then attention-pooled into
    one token, and the tokens are projected with ``ImageProjectorMLP``.
    """

    packed_modules_mapping = {"merged_linear": ["gate_proj", "up_proj"]}

    def __init__(
        self,
        config: PretrainedConfig,
        vision_config: VisionBackboneConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.vit_layers = VIT_LAYERS
        self.image_num_patch = vision_config.image_num_patch
        # Pooled tokens per side: ceil(num_patch / POOLING_SIZE), because the
        # grid is padded up to a multiple of POOLING_SIZE before pooling.
        self.llm_patches_per_crop = (
            (self.image_num_patch[0] + POOLING_SIZE - 1) // POOLING_SIZE,
            (self.image_num_patch[1] + POOLING_SIZE - 1) // POOLING_SIZE,
        )
        self.image_vit = VisionTransformer(vision_config, quant_config=quant_config)
        self.num_prefix_tokens = self.image_vit.num_prefix_tokens
        assert self.num_prefix_tokens in {0, 1}, (
            "Only 0 or 1 prefix tokens are supported"
        )
        self.image_pooling_2d = MultiHeadDotProductAttention(
            vision_config, nlayers=len(self.vit_layers), quant_config=quant_config
        )
        self.image_projector = ImageProjectorMLP(
            config,
            input_dim=vision_config.image_emb_dim,
            quant_config=quant_config,
        )

        image_dim = vision_config.image_emb_dim * len(self.vit_layers)
        # Row 0 is added to fully padded patches (mask == 0). Row 1 is added to
        # partially padded patches (0 < mask < 1). See forward().
        self.pad_embed = nn.Parameter(torch.zeros((2, image_dim)))

    @property
    def dtype(self) -> torch.dtype:
        return self.image_vit.patch_embedding.weight.dtype

    @property
    def device(self) -> torch.device:
        return self.image_vit.patch_embedding.weight.device

    def encode_image(self, images: torch.Tensor) -> torch.Tensor:
        """
        : param images: (batch_size, num_crops, num_patch, n_pixels)
        : return: (batch_size, num_crops, num_patch, len(vit_layers) * emb_dim)
        """
        B, T, N, D = images.shape
        images = images.view(B * T, N, D)

        # Crops whose pixels are all -1 (presumably padding crops) get their
        # features zeroed below.
        mask = ~torch.all(images == -1, dim=(1, 2), keepdim=True)

        image_features = self.image_vit(images)

        if self.vit_layers is not None:
            image_features = torch.cat(
                [image_features[layer] for layer in self.vit_layers], dim=-1
            )
        else:
            image_features = image_features[-1]

        if self.num_prefix_tokens > 0:
            # Drop the class token.
            image_features = image_features[:, 1:]

        image_features = image_features * mask
        image_features = image_features.view(B, T, N, -1)

        return image_features

    def forward(
        self,
        images: torch.Tensor,
        image_masks: torch.Tensor,
    ) -> torch.Tensor:
        """
        : param images: (batch_size, num_crops, num_patch, n_pixels)
        : param image_masks: per-patch masks; assumed
            (batch_size, num_crops, num_patch) so they broadcast against the
            features. TODO confirm.
        : return: (batch_size, num_crops, pooled_patches, d_model)
        """
        batch_size, num_image = images.shape[:2]
        images = images.to(device=self.device, dtype=self.dtype)
        image_features = self.encode_image(images)

        og_dtype = image_features.dtype
        assert image_masks is not None
        pad_embed = self.pad_embed[:, None, None, None, :]
        all_pad = image_masks == 0
        partial_pad = torch.logical_and(image_masks < 1, torch.logical_not(all_pad)).to(
            dtype=torch.float32
        )
        all_pad = all_pad.to(dtype=torch.float32)
        image_features = image_features + pad_embed[0] * torch.unsqueeze(all_pad, -1)
        image_features = image_features + pad_embed[1] * torch.unsqueeze(
            partial_pad, -1
        )

        image_features = image_features.to(og_dtype)

        image_features = image_features.reshape(
            (batch_size, num_image) + self.image_num_patch + (-1,),
        )

        # Pad each side of the patch grid up to a multiple of POOLING_SIZE,
        # independently for height and width.
        num_patch_h, num_patch_w = self.image_num_patch
        pad_h = (-num_patch_h) % POOLING_SIZE
        pad_w = (-num_patch_w) % POOLING_SIZE
        if pad_h or pad_w:
            # F.pad pairs start from the last dim: channels, width, height.
            image_features = F.pad(image_features, (0, 0, 0, pad_w, 0, pad_h))

        # image pooling
        image_features = rearrange(
            image_features,
            "b n (h dh) (w dw) c -> (b n h w) (dh dw) c",
            dh=POOLING_SIZE,
            dw=POOLING_SIZE,
        )

        # Each window attends from its mean feature to all of its patches.
        query = image_features.mean(-2, keepdim=True)
        image_features = self.image_pooling_2d(query, image_features)

        h, w = self.llm_patches_per_crop
        image_features = image_features.view(batch_size, num_image, h * w, -1)

        image_features = self.image_projector(image_features)

        # image_features: (batch_size, num_image, num_patch, d_model)
        return image_features

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors, fusing gate_proj/up_proj into merged_linear.

        Returns the set of parameter names that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("merged_linear", "gate_proj", 0),
            ("merged_linear", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params

807

808
@support_torch_compile
class MolmoModel(nn.Module, SupportsQuant):
    """Molmo language model: token embeddings, decoder layers, final RMSNorm."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config

        # Embedding table = base vocabulary plus ADDITIONAL_VOCAB_SIZE extra rows.
        self.embedding_size = config.embedding_size or config.vocab_size
        self.embedding_size += ADDITIONAL_VOCAB_SIZE
        self.embed_tokens = VocabParallelEmbedding(
            self.embedding_size,
            config.hidden_size,
            quant_config=quant_config,
        )

        decoder_layer = (
            MolmoDecoderNormAfterLayer if config.norm_after else MolmoDecoderLayer
        )
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: decoder_layer(
                config, cache_config, quant_config, prefix=prefix
            ),
            prefix=f"{prefix}.layers",
        )

        assert config.layer_norm_type == "rms"
        self.norm = RMSNorm(config.hidden_size, config.layer_norm_eps)

        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run this pipeline stage's layers.

        Returns final normalized hidden states on the last pipeline rank, and
        ``IntermediateTensors`` otherwise.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.embed_tokens(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        # Apply blocks one-by-one.
        for layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states, residual = layer(
                positions,
                hidden_states,
                residual,
            )
        if not get_pp_group().is_last_rank:
            # NOTE(review): with norm_after layers `residual` is None here;
            # verify pipeline-parallel transfer tolerates that.
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )
        if residual is not None:
            hidden_states, _ = self.norm(hidden_states, residual)
        else:
            hidden_states = self.norm(hidden_states)
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors by exact parameter name; return loaded names."""
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            if name.endswith(".bias") and name not in params_dict:
                continue
            if is_pp_missing_parameter(name, self):
                continue

            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params

899

900
901
902
def _lowest_multiple(x: int, k: int) -> int:
    return (x // k) * k

903

904
905
906
907
908
909
910
911
912
913
def get_num_patches(
    num_tiles: int,
    *,
    crop_patches: int,
    left_margin: int,
    right_margin: int,
    pooling_size: int,
) -> int:
    """Count the patches along one image axis after overlap cropping.

    A single tile contributes all ``crop_patches``. With several tiles, the
    first tile keeps its left margin, the last keeps its right margin, and
    every interior tile keeps only the margin-free window. Each tile's count
    is rounded up to a multiple of ``pooling_size`` before summing.
    """

    def _round_up(n: int) -> int:
        # Ceil to a multiple of pooling_size via the floor helper.
        return _lowest_multiple(n + pooling_size - 1, pooling_size)

    if num_tiles == 1:
        return _round_up(crop_patches)

    window = crop_patches - left_margin - right_margin
    first = _round_up(window + left_margin)
    inner = _round_up(window)
    last = _round_up(window + right_margin)
    return first + inner * (num_tiles - 2) + last


def get_patches_grid_size(
    *,
    tiling_h: int,
    tiling_w: int,
    crop_patches: int,
    left_margin: int,
    right_margin: int,
    pooling_size: int,
) -> tuple[int, int]:
    """Return ``(nrows, ncols)`` patch counts for a ``tiling_h`` x ``tiling_w`` crop grid.

    Both axes share the same crop, margin and pooling settings; each axis is
    counted independently with :func:`get_num_patches`.
    """
    axis_settings = dict(
        crop_patches=crop_patches,
        left_margin=left_margin,
        right_margin=right_margin,
        pooling_size=pooling_size,
    )
    nrows = get_num_patches(tiling_h, **axis_settings)
    ncols = get_num_patches(tiling_w, **axis_settings)
    return nrows, ncols


def get_candidate_tilings(max_num: int) -> list[tuple[int, int]]:
    """List every ``(rows, cols)`` grid with at most ``max_num`` cells.

    The result is ordered by cell count (``rows * cols``). The sort is stable,
    so grids with equal area stay in row-major generation order.
    """
    grids = [
        (rows, cols)
        for rows in range(1, max_num + 1)
        # rows * cols <= max_num  <=>  cols <= max_num // rows (positive ints)
        for cols in range(1, max_num // rows + 1)
    ]
    grids.sort(key=lambda grid: grid[0] * grid[1])
    return grids


def select_tiling(
    *,
    height: int,
    width: int,
    patch_size: int,
    max_num_patches: int,
):
    """Pick the crop grid that best fits an image of ``height`` x ``width`` pixels.

    Candidates come from :func:`get_candidate_tilings`; a ``(rows, cols)``
    grid covers ``rows * patch_size`` by ``cols * patch_size`` pixels. The
    scale needed to fit the image into a candidate is the smaller of its two
    per-axis ratios. If every candidate would need downscaling, the one that
    needs the least downscaling wins. Otherwise the candidate with the
    smallest scale >= 1 wins.

    Returns:
        The chosen row of the int32 candidate array, i.e. ``[rows, cols]``.
    """
    grids = np.array(get_candidate_tilings(max_num_patches), dtype=np.int32)
    image_hw = np.array([height, width], dtype=np.float32)

    # Per-axis ratio, then keep the limiting (smaller) one for each candidate.
    per_axis = (grids * patch_size).astype(np.float32) / image_hw
    fit_scale = per_axis.min(axis=-1, keepdims=True)

    shrinks = fit_scale < 1
    if shrinks.all():
        best = fit_scale.argmax()
    else:
        # Downscaling candidates get a huge sentinel (10e9 == 1e10) so argmin
        # selects the smallest scale that does not shrink the image.
        best = np.where(fit_scale < 1.0, 10e9, fit_scale).argmin()

    return grids[best]


class MolmoProcessorWrapper:
    """
    Wraps `MolmoProcessor` so that it can be called directly.

    Besides forwarding text/images to the wrapped processor, this exposes the
    image-processor settings (crop limit, input size, patch size, overlap
    margins, token grid lengths) and the special-token ids that are needed to
    work out how many placeholder tokens an image expands to.

    The original definition can be found here:
    https://huggingface.co/allenai/Molmo-7B-D-0924/blob/main/preprocessing_molmo.py
    """

    def __init__(self, processor: ProcessorMixin):
        super().__init__()

        # The wrapped processor; its `image_processor` and `tokenizer`
        # attributes are read lazily by the properties below.
        self.processor = processor

    def _image_processor_int(self, attr: str) -> int:
        """Read an int-valued setting from the wrapped image processor."""
        value = getattr(self.processor.image_processor, attr)  # type: ignore
        assert isinstance(value, int)
        return value

    @cached_property
    def vocab(self) -> dict[str, int]:
        return self.processor.tokenizer.vocab  # type: ignore

    @cached_property
    def max_crops(self) -> int:
        return self._image_processor_int("max_crops")

    @cached_property
    def base_image_input_size(self) -> tuple[int, int]:
        # Either a single int (square input) or a (height, width) sequence.
        size = self.processor.image_processor.base_image_input_size  # type: ignore
        if isinstance(size, int):
            return size, size

        return tuple(size)

    @cached_property
    def image_patch_size(self) -> int:
        return self._image_processor_int("image_patch_size")

    @cached_property
    def overlap_margins(self) -> tuple[int, int]:
        margins = self.processor.image_processor.overlap_margins  # type: ignore
        left_margin, right_margin = margins
        assert isinstance(left_margin, int)
        assert isinstance(right_margin, int)

        return left_margin, right_margin

    @cached_property
    def image_token_length_w(self) -> int:
        return self._image_processor_int("image_token_length_w")

    @cached_property
    def image_token_length_h(self) -> int:
        return self._image_processor_int("image_token_length_h")

    @property
    def message_format(self) -> Optional[str]:
        return "role"

    @property
    def always_start_with_space(self) -> bool:
        return True

    @cached_property
    def image_patch_id(self) -> int:
        return self.vocab[IMAGE_PATCH_TOKEN]

    @cached_property
    def im_col_id(self) -> int:
        return self.vocab[IM_COL_TOKEN]

    @cached_property
    def im_start_id(self) -> int:
        return self.vocab[IM_START_TOKEN]

    @cached_property
    def im_end_id(self) -> int:
        return self.vocab[IM_END_TOKEN]

    @property
    def pooling_size(self) -> int:
        return POOLING_SIZE

    def select_tiling(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> tuple[int, int]:
        """Return the ``(tiling_w, tiling_h)`` crop grid chosen for an image.

        The overlap margins are subtracted from the image size first. Each
        tile is then sized to the margin-free crop window.
        """
        max_crops = self.max_crops
        left_margin, right_margin = self.overlap_margins
        base_height = self.base_image_input_size[0]
        patch_px = self.image_patch_size

        margin_patches = left_margin + right_margin
        margin_px = patch_px * margin_patches
        crop_patches = base_height // patch_px
        window_px = (crop_patches - margin_patches) * patch_px

        tiling_h, tiling_w = select_tiling(
            height=image_height - margin_px,
            width=image_width - margin_px,
            patch_size=window_px,
            max_num_patches=max_crops,
        )
        return tiling_w, tiling_h

    def get_patches_grid_size(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> tuple[int, int]:
        """Return the ``(ncols, nrows)`` patch grid an image is cropped into."""
        left_margin, right_margin = self.overlap_margins
        base_height = self.base_image_input_size[0]
        patch_px = self.image_patch_size
        pooling_size = self.pooling_size

        tiling_w, tiling_h = self.select_tiling(
            image_height=image_height,
            image_width=image_width,
        )

        nrows, ncols = get_patches_grid_size(
            tiling_h=tiling_h,
            tiling_w=tiling_w,
            crop_patches=base_height // patch_px,
            left_margin=left_margin,
            right_margin=right_margin,
            pooling_size=pooling_size,
        )
        return ncols, nrows

    def __call__(
        self,
        text: Optional[Union[TextInput, list[TextInput]]] = None,
        images: Optional[Union[ImageInput, list[ImageInput]]] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        **kwargs,
    ) -> BatchFeature:
        """Run the wrapped processor and attach per-image crop metadata.

        ``input_ids`` gains a leading batch dimension. When the processor
        reports ``image_input_idx``, it is replaced by ``feat_is_patch``
        (``image_input_idx >= 0``), ``num_crops`` (tiles per image + 1) and
        ``img_patch_id``. ``return_tensors`` is accepted but not used.
        """
        outputs = self.processor.process(text, images, **kwargs)  # type: ignore

        if images is None:
            image_list = []
        elif isinstance(images, list):
            image_list = images
        else:
            image_list = [images]

        outputs["input_ids"] = outputs.pop("input_ids").unsqueeze(0)

        image_input_idx = outputs.pop("image_input_idx", None)
        if image_input_idx is None:
            return BatchFeature(outputs)

        feat_is_patch = image_input_idx >= 0

        # `image.size` is read as (width, height).
        tilings = [
            self.select_tiling(image_width=img.size[0], image_height=img.size[1])
            for img in image_list
        ]
        # For each image: tiling_h * tiling_w + extra
        num_crops = torch.tensor(tilings).prod(-1) + 1
        assert num_crops.sum() == len(feat_is_patch)

        outputs["feat_is_patch"] = feat_is_patch
        outputs["num_crops"] = num_crops
        outputs["img_patch_id"] = self.image_patch_id

        return BatchFeature(outputs)
1185
1186
1187


class MolmoProcessingInfo(BaseProcessingInfo):
    """Processing metadata for Molmo: processor access and image-token counts."""

    def get_hf_processor(self, **kwargs: object) -> MolmoProcessorWrapper:
        return MolmoProcessorWrapper(self.ctx.get_hf_processor(**kwargs))

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        return {"image": None}

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        processor: Optional[MolmoProcessorWrapper],
    ) -> int:
        """Count the image tokens produced for an image of the given size.

        This is the fixed ``image_token_length_w * image_token_length_h``
        block plus the pooled patch grid ``((ncols + 1) // pool) *
        ((nrows + 1) // pool)``.
        """
        wrapper = processor if processor is not None else self.get_hf_processor()

        ncols, nrows = wrapper.get_patches_grid_size(
            image_width=image_width,
            image_height=image_height,
        )
        pool = wrapper.pooling_size

        extra_tokens = wrapper.image_token_length_w * wrapper.image_token_length_h
        pooled_cols = (ncols + 1) // pool
        pooled_rows = (nrows + 1) // pool

        return extra_tokens + pooled_cols * pooled_rows

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return the candidate image size that yields the most image tokens.

        Every candidate tiling is scaled by the base input size and scored
        with :meth:`get_num_image_tokens`. Ties keep the earliest candidate.

        Raises:
            ValueError: if no candidate yields a positive token count.
        """
        processor = self.get_hf_processor()

        candidates = get_candidate_tilings(processor.max_crops)
        base_h, base_w = processor.base_image_input_size

        best_tokens = 0
        best_size = None
        for wr, hr in candidates:
            width = base_w * wr
            height = base_h * hr

            n_tokens = self.get_num_image_tokens(
                image_width=width,
                image_height=height,
                processor=processor,
            )
            if n_tokens > best_tokens:
                best_tokens = n_tokens
                best_size = ImageSize(width=width, height=height)

        if best_tokens == 0 or best_size is None:
            raise ValueError("Cannot have a largest feature size of 0!")

        return best_size


class MolmoDummyInputsBuilder(BaseDummyInputsBuilder[MolmoProcessingInfo]):
    """Build dummy profiling inputs for Molmo: empty text plus worst-case images."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        return ""

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
    ) -> MultiModalDataDict:
        """Create ``mm_counts["image"]`` dummy images at the token-maximizing size."""
        target_width, target_height = self.info.get_image_size_with_most_features()
        num_images = mm_counts.get("image", 0)
        image_overrides = None if not mm_options else mm_options.get("image")

        dummy_images = self._get_dummy_images(
            width=target_width,
            height=target_height,
            num_images=num_images,
            overrides=image_overrides,
        )
        return {"image": dummy_images}


class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
    """Multimodal processor that inserts Molmo image-token blocks into prompts."""

    def _apply_hf_processor_tokens_only(
        self,
        prompt_tokens: list[int],
    ) -> list[int]:
        """Re-tokenize a prompt through the Molmo chat template.

        The token ids are decoded to text, passed through the processor's
        ``get_tokens_input``, then run through the wrapped processor. The
        resulting ``input_ids`` row is returned.
        """
        processor = self.info.get_hf_processor()

        # Apply the chat template to the tokens
        prompt_text = self.info.get_tokenizer().decode(prompt_tokens)
        templated = processor.processor.get_tokens_input(  # type: ignore
            prompt_text,
            message_format=processor.message_format,
            always_start_with_space=processor.always_start_with_space,
        )

        processed = self.info.ctx.call_hf_processor(
            processor,  # type: ignore
            {"tokens": templated},
        )
        # input_ids carries a leading batch dim of exactly one row.
        (prompt_ids,) = processed.pop("input_ids").tolist()

        return prompt_ids

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Describe how processor outputs are split per image.

        Crop-level tensors are split by ``num_crops``, ``num_crops`` itself
        is batched per image, and ``img_patch_id`` is shared across images.
        """
        num_crops = hf_inputs.get("num_crops", torch.empty(0))
        num_images = len(num_crops)
        per_crop = MultiModalFieldConfig.flat_from_sizes

        return dict(
            images=per_crop("image", num_crops),
            image_masks=per_crop("image", num_crops),
            feat_is_patch=per_crop("image", num_crops),
            num_crops=MultiModalFieldConfig.batched("image"),
            img_patch_id=MultiModalFieldConfig.shared("image", num_images),
        )

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Insert one image-token block per image right after ``<|endoftext|>``.

        Each block holds a fixed ``image_token_length_h`` x
        ``image_token_length_w`` grid followed by the image's pooled patch
        grid. Both are framed by start/end tokens, with a column token
        closing every row. Only patch tokens receive embeddings.
        """
        processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)

        tokens_w = processor.image_token_length_w
        tokens_h = processor.image_token_length_h
        pool = processor.pooling_size

        patch_id = processor.image_patch_id
        col_id = processor.im_col_id
        start_id = processor.im_start_id
        end_id = processor.im_end_id

        def _grid_tokens(n_cols: int, n_rows: int) -> list[int]:
            # [start] + n_rows * ([patch] * n_cols + [col]) + [end]
            row = [patch_id] * n_cols + [col_id]
            return [start_id] + row * n_rows + [end_id]

        extra_joint = _grid_tokens(tokens_w, tokens_h)

        def get_insertion_molmo(item_idx: int):
            image_items = mm_items.get_items("image", ImageProcessorItems)
            size = image_items.get_image_size(item_idx)

            ncols, nrows = processor.get_patches_grid_size(
                image_width=size.width,
                image_height=size.height,
            )
            joint = _grid_tokens((ncols + 1) // pool, (nrows + 1) // pool)

            return PromptUpdateDetails.select_token_id(
                extra_joint + joint,
                embed_token_id=patch_id,
            )

        insertion = PromptInsertion(
            modality="image",
            target=PromptIndexTargets.prefix("<|endoftext|>"),
            insertion=get_insertion_molmo,
        )
        return [insertion]


1357
1358
1359
1360
1361
1362
1363
1364
@MULTIMODAL_REGISTRY.register_processor(
    MolmoMultiModalProcessor,
    info=MolmoProcessingInfo,
    dummy_inputs=MolmoDummyInputsBuilder,
)
class MolmoForCausalLM(
    nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, SupportsQuant
):
    """Molmo vision-language model for causal generation.

    Combines ``MolmoVisionBackbone`` (image encoder and projector), the
    ``MolmoModel`` decoder, and an LM head that is either a separate
    ``ParallelLMHead`` or, with ``config.weight_tying``, the decoder's input
    embedding.
    """

    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_substr={
            # vision backbone mapping
            "image_projector.w1.": "image_projector.gate_proj.",
            "image_projector.w3.": "image_projector.up_proj.",
            "image_projector.w2.": "image_projector.down_proj.",
            # language backbone mapping
            "att_proj": "self_attn.qkv_proj",
            "attn_out": "self_attn.o_proj",
            "q_norm": "self_attn.q_norm",
            "k_norm": "self_attn.k_norm",
            "ff_proj": "mlp.gate_up_proj",
            "ff_out": "mlp.down_proj",
            "attn_norm": "input_layernorm",
            "ff_norm": "post_attention_layernorm",
        },
        orig_to_new_prefix={
            # vision backbone mapping
            "model.vision_backbone.": "vision_backbone.",
            # language backbone mapping
            "model.transformer.blocks.": "model.layers.",
            "model.transformer.ln_f.": "model.norm.",
            # lm_head is renamed to model.transformer.mlp.down_proj firstly,
            # we need to run a second renaming for it
            "model.transformer.mlp.down_proj.": "lm_head.",
        },
    )

    packed_modules_mapping = {
        "qkv_proj": ["qkv_proj"],
        "gate_up_proj": ["gate_up_proj"],  # language model
        "merged_linear": ["gate_proj", "up_proj"],  # image_projector
    }

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
        # Image tokens are inserted by the multimodal processor, so no
        # textual placeholder is needed.
        if modality.startswith("image"):
            return None

        raise ValueError("Only image modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config
        lora_config = vllm_config.lora_config
        self.config = config
        self.multimodal_config = multimodal_config
        self.lora_config = lora_config

        vision_config = VisionBackboneConfig()
        self.vision_backbone = MolmoVisionBackbone(config, vision_config, quant_config)
        self.model = MolmoModel(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )
        # Filled in from processor outputs by `_parse_and_validate_image_input`.
        self.img_patch_id = None

        if self.config.weight_tying:
            # Reuse the decoder's input embedding as the output projection.
            # MolmoModel exposes it as `embed_tokens` (the checkpoint's
            # `wte.*` tensors are merged into `model.embed_tokens.weight` at
            # load time). There is no `transformer` submodule to reach into.
            self.lm_head = self.model.embed_tokens
        else:
            self.lm_head = ParallelLMHead(
                config.embedding_size or config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "lm_head"),
            )

        self.logits_processor = LogitsProcessor(
            config.embedding_size or config.vocab_size
        )

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

    def _parse_and_validate_image_input(
        self,
        **kwargs: object,
    ) -> Optional[MolmoImageInputs]:
        """Collect image tensors from ``kwargs``; return None if there are no images.

        As a side effect, records the (single, unique) image patch token id
        in ``self.img_patch_id``.

        Raises:
            ValueError: if ``num_crops`` or ``img_patch_id`` has an
                unexpected type.
        """
        images = kwargs.pop("images", None)
        image_masks = kwargs.pop("image_masks", None)
        feat_is_patch = kwargs.pop("feat_is_patch", None)
        num_crops = kwargs.pop("num_crops", None)

        if images is None:
            return None

        if not isinstance(num_crops, (torch.Tensor, list)):
            raise ValueError(
                f"Incorrect type of num_crops. Got type: {type(num_crops)}"
            )
        num_crops = flatten_bn(num_crops, concat=True)

        img_patch_id = kwargs.pop("img_patch_id", None)
        if not isinstance(img_patch_id, torch.Tensor):
            raise ValueError(
                f"Incorrect type of img_patch_id. Got type: {type(img_patch_id)}"
            )
        # `.item()` fails loudly if more than one distinct id was supplied.
        self.img_patch_id = img_patch_id.flatten().unique().item()

        return MolmoImageInputs(
            images=images,
            image_masks=image_masks,
            feat_is_patch=feat_is_patch,
            num_crops=num_crops,
        )

    def _process_image_input(
        self,
        image_input: MolmoImageInputs,
    ) -> list[torch.Tensor]:
        """Encode all crops at once and return the patch features of each image."""
        images = image_input["images"]
        image_masks = image_input["image_masks"]
        feat_is_patch = image_input["feat_is_patch"]
        num_crops = image_input["num_crops"]

        # Call the vision backbone on the whole batch at once
        images_flat = flatten_bn(images, concat=True)
        image_masks_flat = (
            None if image_masks is None else flatten_bn(image_masks, concat=True)
        )
        feat_is_patch_flat = flatten_bn(feat_is_patch, concat=True)

        image_features_flat = self.vision_backbone(
            images=images_flat.unsqueeze(0),
            image_masks=(
                None if image_masks_flat is None else image_masks_flat.unsqueeze(0)
            ),
        ).squeeze(0)

        # Crops per image; used to split the flat outputs back per image.
        crop_counts = num_crops.tolist()

        # Only the features corresponding to patch tokens are relevant
        return [
            feats[f_is_patch]
            for feats, f_is_patch in zip(
                image_features_flat.split(crop_counts),
                feat_is_patch_flat.split(crop_counts),
            )
        ]

    def get_language_model(self) -> torch.nn.Module:
        return self.model

    def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []

        return self._process_image_input(image_input)

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: torch.LongTensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> torch.Tensor:
        # Non-first pipeline ranks consume intermediate tensors, not embeddings.
        if intermediate_tensors is not None:
            inputs_embeds = None

        hidden_states = self.model(
            input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
        )

        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load a HF Molmo checkpoint, merging the split embedding tables first."""
        loader = AutoWeightsLoader(self)
        weights = _get_weights_with_merged_embedding(weights)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        return MultiModelKeys.from_string_field(
            language_model="model",
            connector="vision_backbone.image_projector",
            tower_model="vision_backbone",
        )

1551
1552

def _get_weights_with_merged_embedding(
1553
    weights: Iterable[tuple[str, torch.Tensor]],
1554
) -> Iterable[tuple[str, torch.Tensor]]:
1555
1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
1567
1568
1569
    embedding_weights = {}
    for name, weight in weights:
        if "wte.embedding" in name:
            embedding_weights["embedding"] = weight
        elif "wte.new_embedding" in name:
            embedding_weights["new_embedding"] = weight
        else:
            yield (name, weight)
    # this is compatible with most of quantization,
    # because they won't quantize embed_tokens
    embedding_weights = torch.cat(
        [embedding_weights["embedding"], embedding_weights["new_embedding"]],
        dim=0,
    )
    yield ("model.embed_tokens.weight", embedding_weights)