molmo.py 50 KB
Newer Older
1
2
3
4
5
import math
import re
from array import array
from dataclasses import dataclass
from functools import lru_cache, partial
6
from typing import Iterable, List, Mapping, Optional, Set, Tuple, TypedDict
7
8
9
10
11
12
13
14
15

import torch
from einops import rearrange
from PIL import Image
from torch import nn
from torch.nn import functional as F
from transformers import PretrainedConfig

from vllm.attention import Attention, AttentionMetadata
16
from vllm.attention.layer import MultiHeadAttention
17
from vllm.compilation.decorators import support_torch_compile
18
from vllm.config import CacheConfig, VllmConfig
19
20
21
22
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
                              split_tensor_along_last_dim,
                              tensor_model_parallel_all_gather)
23
24
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
                         InputContext, token_inputs)
25
26
27
28
29
30
31
32
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.activation import QuickGELU, SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
33
from vllm.model_executor.layers.quantization import QuantizationConfig
34
from vllm.model_executor.layers.rotary_embedding import get_rope
Joe Runde's avatar
Joe Runde committed
35
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
36
37
38
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
39
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
40
from vllm.multimodal.inputs import NestedTensors, PlaceholderRange
41
from vllm.multimodal.utils import cached_get_tokenizer
42
43
44
45
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
                           SequenceData)
from vllm.transformers_utils.processor import get_processor

46
from .interfaces import SupportsMultiModal, SupportsPP
47
from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter,
48
                    make_empty_intermediate_tensors_factory, make_layers,
49
                    maybe_prefix, merge_multimodal_embeddings)
50
51
52
53
54

# TODO: hard-coded for now. Consider making it configurable.
# Indices into the list of per-layer ViT hidden states; the selected layers
# are concatenated along the feature dimension to form the image features.
VIT_LAYERS = [-2, -9]
# Number of prefix (class) tokens the ViT prepends to the patch sequence.
NUM_PREFIX_TOKENS = 1
# Extra embedding rows added on top of the base vocabulary size.
ADDITIONAL_VOCAB_SIZE = 128

# Image special-token IDs. NOTE(review): not referenced in the visible part
# of this file; presumably fallbacks when the tokenizer does not supply
# them -- confirm against the input processor.
DEFAULT_IMAGE_PATCH_TOKEN_ID = 152066
DEFAULT_IM_START_TOKEN_ID = 152067
DEFAULT_IM_END_TOKEN_ID = 152064
DEFAULT_IM_COL_TOKEN_ID = 152065
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81


class MolmoImageInputs(TypedDict):
    """Typed schema of the image inputs handed to the Molmo model."""

    images: torch.Tensor
    """Shape:
    `(batch_size, num_crops, num_patch, patch_dim)`
    """

    image_input_idx: torch.Tensor
    """Shape:
    `(batch_size, num_crops, num_patch)`
    """

    seq_len: torch.Tensor
    """Shape:
    `(batch_size, )`
    """

    image_masks: Optional[torch.Tensor]
    """Shape:
    `(batch_size, num_crops, num_patch)`
    """

    image_start_end: Tuple[int, int]
    """Starting and ending index of placeholder
    tokens
    """

87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198

@dataclass
class VisionBackboneConfig:
    """Hyper-parameters of the Molmo ViT image encoder.

    The defaults describe a 336x336 input cut into 14x14-pixel patches.
    """

    # Input resolution as (height, width); normalised to a tuple on init.
    image_default_input_size: Tuple[int, int] = (336, 336)
    image_patch_size: int = 14
    image_pos_patch_size: int = 14
    image_emb_dim: int = 1024
    image_num_heads: int = 16
    image_num_key_value_heads: int = 16
    image_num_layers: int = 23
    image_mlp_dim: int = 4096
    image_mlp_activations: str = "quick_gelu"
    image_num_pos: int = 577
    image_norm_eps: float = 1e-5

    def __post_init__(self):
        # Accept any iterable (e.g. a list) and store it as a tuple.
        size_as_tuple = tuple(self.image_default_input_size)
        self.image_default_input_size = size_as_tuple  # type: ignore[assignment]

    @property
    def image_num_patch(self):
        """Patch-grid size (rows, cols) for the default input resolution."""
        rows, cols = self.image_default_input_size
        step = self.image_patch_size
        return (rows // step, cols // step)


class ViTMLP(nn.Module):
    """Two-layer feed-forward network of a ViT block.

    Computes ``w2(quick_gelu(w1(x)))`` with tensor-parallel linears
    (column-parallel up-projection, row-parallel down-projection).
    """

    def __init__(
        self,
        config: VisionBackboneConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        emb_dim = config.image_emb_dim
        mlp_dim = config.image_mlp_dim

        self.w1 = ColumnParallelLinear(
            emb_dim,
            mlp_dim,
            bias=True,
            quant_config=quant_config,
        )
        # Only QuickGELU is implemented for the vision MLP.
        assert config.image_mlp_activations == "quick_gelu"
        self.act = QuickGELU()
        self.w2 = RowParallelLinear(
            mlp_dim,
            emb_dim,
            bias=True,
            quant_config=quant_config,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The parallel linears return (output, bias); the bias is unused.
        hidden, _unused_bias = self.w1(x)
        activated = self.act(hidden)
        projected, _unused_bias = self.w2(activated)
        return projected


class MultiHeadDotProductAttention(nn.Module):
    """Multi-head attention used in Vision Transformer.

    Also reused by ``MolmoVisionBackbone`` as the 2x2 patch-pooling
    attention; there ``nlayers > 1`` because its inputs are the
    concatenation of several ViT layers' features.
    """

    def __init__(
        self,
        config: VisionBackboneConfig,
        use_bias: bool = True,
        nlayers: int = 1,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """
        Args:
            config: vision backbone hyper-parameters.
            use_bias: whether the q/k/v/o projections carry a bias.
            nlayers: multiplier on the projection input width; the
                q/k/v projections take ``nlayers * image_emb_dim`` features.
            quant_config: optional quantization config for the linears.
        """
        super().__init__()

        self.hidden_size = config.image_emb_dim
        self.total_num_heads = config.image_num_heads
        tp_size = get_tensor_model_parallel_world_size()

        assert self.hidden_size % self.total_num_heads == 0
        assert self.total_num_heads % tp_size == 0

        self.num_heads = self.total_num_heads // tp_size
        self.head_dim = self.hidden_size // self.total_num_heads

        self.total_num_kv_heads = config.image_num_key_value_heads
        # KV heads must either split evenly across TP ranks or, when there
        # are fewer KV heads than ranks, divide the rank count.
        if self.total_num_kv_heads >= tp_size:
            assert self.total_num_kv_heads % tp_size == 0
        else:
            assert tp_size % self.total_num_kv_heads == 0

        # Every rank keeps at least one KV head.
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)

        self.wq = ColumnParallelLinear(
            nlayers * self.hidden_size,
            self.total_num_heads * self.head_dim,
            bias=use_bias,
            quant_config=quant_config,
        )
        self.wk = ColumnParallelLinear(
            nlayers * self.hidden_size,
            self.total_num_kv_heads * self.head_dim,
            bias=use_bias,
            quant_config=quant_config,
        )
        self.wv = ColumnParallelLinear(
            nlayers * self.hidden_size,
            self.total_num_kv_heads * self.head_dim,
            bias=use_bias,
            quant_config=quant_config,
        )
        self.wo = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            self.hidden_size,
            bias=use_bias,
            quant_config=quant_config,
        )

        self.scale = self.head_dim**-0.5
        self.attn = MultiHeadAttention(self.num_heads,
                                       self.head_dim,
                                       self.scale,
                                       num_kv_heads=self.num_kv_heads)

    def forward(self,
                inputs_q: torch.Tensor,
                inputs_kv: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Attend from ``inputs_q`` to ``inputs_kv``.

        Args:
            inputs_q: tensor the queries are projected from.
            inputs_kv: tensor the keys and values are projected from. If
                ``None``, ``inputs_q`` is used (self-attention).

        Returns:
            The attention output after the ``wo`` output projection.
        """
        if inputs_kv is not None:
            inputs_k = inputs_kv
            inputs_v = inputs_kv
        else:
            inputs_k = inputs_q
            inputs_v = inputs_q

        # Parallel linears return (output, bias); the bias is unused.
        xq, _ = self.wq(inputs_q)
        xk, _ = self.wk(inputs_k)
        xv, _ = self.wv(inputs_v)

        output = self.attn(xq, xk, xv)

        output, _ = self.wo(output)

        return output


class ResidualAttentionBlock(nn.Module):
    """Pre-norm ViT block.

    Applies LayerNorm -> attention, then LayerNorm -> MLP, with each
    sub-block wrapped in a residual connection.
    """

    def __init__(
        self,
        config: VisionBackboneConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.attention = MultiHeadDotProductAttention(
            config, quant_config=quant_config)
        self.feed_forward = ViTMLP(config, quant_config)

        width = config.image_emb_dim
        epsilon = config.image_norm_eps
        self.attention_norm = nn.LayerNorm(width, eps=epsilon)
        self.ffn_norm = nn.LayerNorm(width, eps=epsilon)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        after_attn = x + self.attention(self.attention_norm(x))
        mlp_delta = self.feed_forward(self.ffn_norm(after_attn))
        return after_attn + mlp_delta


class BlockCollection(nn.Module):
    """Stack of ``config.image_num_layers`` ViT residual attention blocks.

    ``forward`` returns every block's output, not just the last one, so
    callers can select intermediate layers.
    """

    def __init__(
        self,
        config: VisionBackboneConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        blocks = [
            ResidualAttentionBlock(config, quant_config)
            for _layer_idx in range(config.image_num_layers)
        ]
        self.resblocks = nn.ModuleList(blocks)

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        per_layer_outputs: List[torch.Tensor] = []
        current = x
        for block in self.resblocks:
            current = block(current)
            per_layer_outputs.append(current)
        return per_layer_outputs


def _expand_token(token: torch.Tensor, batch_size: int) -> torch.Tensor:
    return token.view(1, 1, -1).expand(batch_size, -1, -1)


class VisionTransformer(nn.Module):
    """Vision Transformer used in Vision Backbone.

    Embeds flattened pixel patches linearly, prepends a class token, adds
    (optionally interpolated) positional embeddings, applies a pre-LayerNorm,
    and returns the hidden states of every transformer block.
    """

    def __init__(
        self,
        config: VisionBackboneConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        scale = config.image_emb_dim**-0.5
        self.patch_num = config.image_num_patch
        self.class_embedding = nn.Parameter(
            torch.randn(config.image_emb_dim) * scale)
        self.num_prefix_tokens: int = NUM_PREFIX_TOKENS
        # Row 0 is the class-token position; rows 1: form a square grid.
        self.positional_embedding = nn.Parameter(
            torch.randn(config.image_num_pos, config.image_emb_dim) * scale)
        image_patch_size = config.image_patch_size
        # Each patch is flattened to patch_size * patch_size * 3 values.
        self.patch_embedding = nn.Linear(
            image_patch_size * image_patch_size * 3,
            config.image_emb_dim,
            bias=False,
        )
        self.pre_ln = nn.LayerNorm(config.image_emb_dim,
                                   eps=config.image_norm_eps)
        self.transformer = BlockCollection(config, quant_config)

    def add_pos_emb(self, x: torch.Tensor,
                    patch_num: Tuple[int, int]) -> torch.Tensor:
        """Add positional embeddings to ``x`` (class token first).

        Args:
            x: embeddings whose dim 1 is ``1 + rows * cols`` tokens.
            patch_num: ``(rows, cols)`` of the patch grid in ``x``. If it
                differs from the stored square grid, the grid embeddings are
                bicubically resized to match.

        Raises:
            ValueError: if the stored grid embeddings are not a perfect
                square number of positions.
        """
        cls_emb = self.positional_embedding[0:1]
        pos_emb = self.positional_embedding[1:]

        num_grid_pos = pos_emb.shape[0]
        side = math.isqrt(num_grid_pos)
        if side * side != num_grid_pos:
            raise ValueError(
                f"Expected a square number of positional embeddings, got "
                f"{num_grid_pos}")
        pos_emb = pos_emb.reshape((side, side, pos_emb.shape[1]))

        patch_num_0, patch_num_1 = patch_num

        if pos_emb.shape[0] != patch_num_0 or pos_emb.shape[1] != patch_num_1:
            # from https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
            pos_emb = pos_emb.unsqueeze(0).permute(0, 3, 1, 2)
            pos_emb = F.interpolate(
                pos_emb,
                size=(patch_num_0, patch_num_1),
                mode="bicubic",
                align_corners=False,
                antialias=True,
            )
            pos_emb = pos_emb.permute(0, 2, 3, 1).squeeze(0)

        pos_emb = pos_emb.reshape(-1, pos_emb.shape[-1])
        x = x + torch.cat([cls_emb[None, :, :], pos_emb[None, :, :]],
                          dim=1).to(x.dtype)
        return x

    def forward(
        self,
        x: torch.Tensor,
        patch_num: Optional[Tuple[int, int]] = None,
    ) -> List[torch.Tensor]:
        """
        Args:
            x: ``(batch_size, num_patch, n_pixels)`` flattened patches.
            patch_num: ``(rows, cols)`` patch grid; defaults to the grid
                derived from the config.

        Returns:
            One hidden-state tensor per transformer block.
        """
        if x.dim() != 3:
            raise ValueError(
                f"Expected x of shape (batch_size, num_patch, n_pixels), "
                f"got {tuple(x.shape)}")
        if patch_num is None:
            patch_num = self.patch_num

        x = self.patch_embedding(x)

        # class embeddings and positional embeddings
        class_tokens = _expand_token(self.class_embedding,
                                     x.shape[0]).to(x.dtype)
        x = torch.cat([class_tokens, x], dim=1)
        x = self.add_pos_emb(x, patch_num)

        x = self.pre_ln(x)

        hidden_states = self.transformer(x)
        return hidden_states


class MolmoAttention(nn.Module):
    """Molmo's LLM attention.

    Fused QKV projection, optional RMS normalization of q and k, rotary
    position embedding, KV-cached attention through vLLM's ``Attention``
    layer, and a row-parallel output projection.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        self.tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = config.num_attention_heads

        assert self.hidden_size % self.total_num_heads == 0
        assert self.total_num_heads % self.tp_size == 0

        self.num_heads = self.total_num_heads // self.tp_size
        # A missing/zero KV head count falls back to one KV head per head.
        self.total_num_kv_heads = config.num_key_value_heads \
            or self.total_num_heads
        if self.total_num_kv_heads >= self.tp_size:
            assert self.total_num_kv_heads % self.tp_size == 0
        else:
            assert self.tp_size % self.total_num_kv_heads == 0

        # Every rank keeps at least one KV head.
        self.num_kv_heads = max(1, self.total_num_kv_heads // self.tp_size)
        self.head_dim = self.hidden_size // self.total_num_heads
        # Per-rank widths of the q and k/v slices of the fused projection.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta

        # Attention input projection. Projects x -> (q, k, v)
        self.qkv_proj = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=config.qkv_bias,
            quant_config=quant_config,
        )

        self.tp_rank: Optional[int] = None
        self.k_norm: Optional[nn.Module] = None
        self.q_norm: Optional[nn.Module] = None
        if config.attention_layer_norm:
            self.tp_rank = get_tensor_model_parallel_rank()
            # The norms span the full (unsharded) q/k widths; see
            # _apply_qk_norm for how TP shards are handled.
            self.k_norm = RMSNorm(self.total_num_kv_heads * self.head_dim,
                                  eps=config.layer_norm_eps)
            self.q_norm = RMSNorm(config.hidden_size,
                                  eps=config.layer_norm_eps)

        # Rotary embeddings.
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=self.max_position_embeddings,
            base=self.rope_theta,
        )
        self.scaling = self.head_dim**-0.5
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

        # Attention output projection.
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            self.hidden_size,
            bias=False,
            quant_config=quant_config,
        )

    def _apply_qk_norm(self, q: torch.Tensor,
                       k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """RMS-normalize q and k over their full, unsharded width.

        Under tensor parallelism each rank holds only a slice, so the slices
        are all-gathered, normalized, then re-split and this rank's slice is
        kept.
        """
        if self.tp_size > 1:
            q = tensor_model_parallel_all_gather(q.contiguous())
            k = tensor_model_parallel_all_gather(k.contiguous())
        q = self.q_norm.forward_native(q)
        k = self.k_norm.forward_native(k)
        if self.tp_size > 1:
            splitter = partial(split_tensor_along_last_dim,
                               num_partitions=self.tp_size)
            q = splitter(q)[self.tp_rank]
            k = splitter(k)[self.tp_rank]
        return q, k

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        # QK-norm is applied before rotary embedding.
        if self.q_norm is not None and self.k_norm is not None:
            q, k = self._apply_qk_norm(q, k)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output


464
465
466
467
468
469
470
471
472
473
class SwiGLU(nn.Module):
    """Gated SiLU over a fused ``[value, gate]`` last dimension.

    The input is split into two chunks along the last dimension: the first
    is the value and the second is passed through SiLU as the gate. The
    order is reversed relative to vLLM's ``SiluAndMul``.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        value, gate = torch.chunk(x, 2, dim=-1)
        gated = F.silu(gate)
        return value * gated


class LanuageModelMLP(nn.Module):
    """Molmo's LLM mlp.

    Computes ``down_proj(SwiGLU(gate_up_proj(x)))``. The fused projection
    emits ``[value, gate]`` halves, which ``SwiGLU`` consumes in that order.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 input_dim: Optional[int] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        """
        Args:
            config: LLM config (reads ``hidden_size``, ``intermediate_size``).
            input_dim: input width; defaults to ``config.hidden_size``.
            quant_config: optional quantization config for the linears.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        # Halved here and doubled by the merged projection below, so the
        # fused output width matches config.intermediate_size.
        self.intermediate_size = config.intermediate_size // 2

        self.gate_up_proj = MergedColumnParallelLinear(
            input_dim or self.hidden_size,
            [self.intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
        )
        # Activation function.
        self.act_fn = SwiGLU()
        # Feed-forward output projection.
        self.down_proj = RowParallelLinear(
            self.intermediate_size,
            self.hidden_size,
            bias=False,
            quant_config=quant_config,
        )

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        # Parallel linears return (output, bias); the bias is unused.
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class ImageProjectorMLP(nn.Module):
    """Molmo's image_projector mlp.

    Projects pooled image features into the LLM hidden size via
    ``down_proj(SiluAndMul(merged_linear(x)))``.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        input_dim: Optional[int] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        """
        Args:
            config: LLM config (reads ``hidden_size``, ``intermediate_size``).
            input_dim: input feature width; defaults to
                ``config.hidden_size``.
            quant_config: optional quantization config for the linears.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        # Halved here and doubled by the merged projection below.
        self.intermediate_size = config.intermediate_size // 2

        self.merged_linear = MergedColumnParallelLinear(
            input_dim or self.hidden_size,
            [self.intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
        )
        # Activation function.
        self.act_fn = SiluAndMul()

        # Feed-forward output projection.
        self.down_proj = RowParallelLinear(
            self.intermediate_size,
            self.hidden_size,
            bias=False,
            quant_config=quant_config,
        )

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        # Parallel linears return (output, bias); the bias is unused.
        gate_up, _ = self.merged_linear(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class MolmoDecoderLayer(nn.Module):
    """Pre-norm Molmo decoder layer: RMSNorm -> attention -> RMSNorm -> MLP.

    The residual stream is threaded through RMSNorm's two-argument form
    rather than added inside this layer.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        # Attention block.
        self.self_attn = MolmoAttention(config,
                                        cache_config,
                                        quant_config,
                                        prefix=f"{prefix}.self_attn")

        # MLP block.
        self.mlp = LanuageModelMLP(config, quant_config=quant_config)

        # LayerNorm
        assert config.layer_norm_type == "rms"
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.layer_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.layer_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run one layer.

        Args:
            residual: residual stream from the previous layer, or ``None``
                for the first layer (the input then becomes the residual).

        Returns:
            ``(hidden_states, residual)`` for the next layer.
        """
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            # Two-argument RMSNorm returns (normalized, updated residual).
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


class MolmoDecoderNormAfterLayer(MolmoDecoderLayer):
    """Decoder-layer variant that normalizes *after* each sub-block.

    Computes ``h = x + norm1(attn(x))`` and then ``out = h + norm2(mlp(h))``,
    with the residual adds done inline. The incoming ``residual`` argument is
    ignored, and the returned residual is always ``None``.
    """

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        # Self Attention
        residual = hidden_states
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Norm is applied to the sub-block output, then the skip is added.
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = hidden_states + residual
        residual = hidden_states

        hidden_states = self.mlp(hidden_states)
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = hidden_states + residual
        # The residual is already folded in, so none is carried forward.
        residual = None
        return hidden_states, residual


class MolmoVisionBackbone(nn.Module):
    """Molmo image encoder: ViT features -> 2x2 attention pooling -> MLP.

    The features of several ViT layers (``VIT_LAYERS``) are concatenated,
    pooled over 2x2 patch windows by an attention layer, and projected into
    the LLM hidden size.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        vision_config: VisionBackboneConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.vit_layers = VIT_LAYERS
        self.image_num_patch = vision_config.image_num_patch
        # Each axis is pooled by 2; odd axes are padded by one first (see
        # forward), hence the ceil-division.
        self.llm_patches_per_crop = (
            (self.image_num_patch[0] + 1) // 2,
            (self.image_num_patch[1] + 1) // 2,
        )
        self.image_vit = VisionTransformer(vision_config,
                                           quant_config=quant_config)
        self.num_prefix_tokens = self.image_vit.num_prefix_tokens
        assert self.num_prefix_tokens in {
            0, 1
        }, "Only 0 or 1 prefix tokens are supported"
        # The pooling attention consumes len(vit_layers) concatenated layers.
        self.image_pooling_2d = MultiHeadDotProductAttention(
            vision_config,
            nlayers=len(self.vit_layers),
            quant_config=quant_config)
        self.image_projector = ImageProjectorMLP(
            config,
            input_dim=vision_config.image_emb_dim,
            quant_config=quant_config,
        )

        image_dim = vision_config.image_emb_dim * len(self.vit_layers)
        # Learned offsets: row 0 for fully padded patches, row 1 for
        # partially padded patches (see forward).
        self.pad_embed = nn.Parameter(torch.zeros((2, image_dim)))

    @property
    def dtype(self) -> torch.dtype:
        return self.image_vit.patch_embedding.weight.dtype

    @property
    def device(self) -> torch.device:
        return self.image_vit.patch_embedding.weight.device

    def encode_image(self, images: torch.Tensor) -> torch.Tensor:
        """Run the ViT and gather the selected layers' patch features.

        Args:
            images: ``(batch_size, num_crops, num_patch, n_pixels)``. A crop
                whose values are all ``-1`` is treated as padding and its
                features are zeroed.

        Returns:
            ``(batch_size, num_crops, num_patch, feature_dim)``.
        """
        B, T, N, D = images.shape

        # True for crops that are not entirely -1 (i.e. real crops).
        mask = ~torch.all(
            images.view(B * T, N, D) == -1, dim=(1, 2), keepdim=True)

        images = images.view(B * T, N, D)
        image_features = self.image_vit(images)

        if self.vit_layers is not None:
            features = []
            for layer in self.vit_layers:
                features.append(image_features[layer])
            image_features = torch.cat(features, dim=-1)
        else:
            image_features = image_features[-1]

        # Drop the class/prefix token(s) so only patch features remain.
        if self.num_prefix_tokens > 0:
            image_features = image_features[:, self.num_prefix_tokens:]

        image_features = image_features * mask
        image_features = image_features.view(B, T, N, -1)

        return image_features

    def forward(
        self, images: torch.Tensor, image_masks: torch.Tensor
    ) -> torch.Tensor:
        """Encode, pool and project image crops.

        Args:
            images: ``(batch_size, num_crops, num_patch, n_pixels)``.
            image_masks: per-patch mask; ``0`` marks a fully padded patch and
                other values ``< 1`` mark a partially padded patch.

        Returns:
            ``(batch_size, num_image, num_patch, d_model)`` image features.
        """
        # image_features: (batch_size, num_crops(=num_image), num_patch, nximage_emb_dim) # noqa: E501
        batch_size, num_image = images.shape[:2]
        images = images.to(device=self.device, dtype=self.dtype)
        image_features = self.encode_image(images)

        og_dtype = image_features.dtype
        assert image_masks is not None
        pad_embed = self.pad_embed[:, None, None, None, :]
        all_pad = image_masks == 0
        partial_pad = torch.logical_and(
            image_masks < 1,
            torch.logical_not(all_pad)).to(dtype=torch.float32)
        all_pad = all_pad.to(dtype=torch.float32)
        image_features = image_features + pad_embed[0] * torch.unsqueeze(
            all_pad, -1)
        image_features = image_features + pad_embed[1] * torch.unsqueeze(
            partial_pad, -1)

        image_features = image_features.to(og_dtype)

        image_features = image_features.reshape(
            (batch_size, num_image) + self.image_num_patch + (-1, ), )

        # Pad each odd grid axis by one so 2x2 pooling tiles the grid
        # exactly. Handling the axes independently keeps this consistent
        # with llm_patches_per_crop. F.pad pairs run from the last dim
        # backwards: (channels, width, height).
        pad_h = self.image_num_patch[0] % 2
        pad_w = self.image_num_patch[1] % 2
        if pad_h or pad_w:
            image_features = F.pad(
                image_features,
                (0, 0, 0, pad_w, 0, pad_h),
            )

        # image pooling: gather each 2x2 window into its own sequence.
        image_features = rearrange(
            image_features,
            'b n (h dh) (w dw) c -> (b n h w) (dh dw) c',
            dh=2,
            dw=2,
        )

        # The window mean is the query; the 4 window patches are keys/values.
        query = image_features.mean(-2, keepdim=True)
        image_features = self.image_pooling_2d(query, image_features)

        h, w = self.llm_patches_per_crop
        image_features = image_features.view(batch_size, num_image, h * w, -1)

        image_features = self.image_projector(image_features)

        # image_features: (batch_size, num_image, num_patch, d_model)
        return image_features

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint weights.

        Separate ``gate_proj``/``up_proj`` tensors are routed into the fused
        ``merged_linear`` shards.

        Returns:
            The parameter names that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("merged_linear", "gate_proj", 0),
            ("merged_linear", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()

        for name, loaded_weight in weights:
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # No stacked mapping handled this weight: load it directly.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params

793

794
@support_torch_compile
class MolmoModel(nn.Module):
    """Molmo language decoder: token embedding, decoder layers, final RMSNorm.

    The layer stack is built with ``make_layers``, which also yields
    ``start_layer`` / ``end_layer``. ``forward`` runs only the layers in
    that range.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config

        # The embedding table holds the base vocabulary plus
        # ADDITIONAL_VOCAB_SIZE extra rows.
        self.embedding_size = config.embedding_size or config.vocab_size
        self.embedding_size += ADDITIONAL_VOCAB_SIZE
        self.embed_tokens = VocabParallelEmbedding(
            self.embedding_size,
            config.hidden_size,
            quant_config=quant_config,
        )

        decoder_layer = (MolmoDecoderNormAfterLayer
                         if config.norm_after else MolmoDecoderLayer)
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: decoder_layer(
                config, cache_config, quant_config, prefix=prefix),
            prefix=f"{prefix}.layers",
        )

        assert config.layer_norm_type == "rms"
        self.norm = RMSNorm(config.hidden_size, config.layer_norm_eps)

        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Return ``embed_tokens(input_ids)``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run this rank's slice of the decoder.

        Where the input comes from depends on the pipeline-parallel rank:
        - First rank: hidden states are ``inputs_embeds`` if given, otherwise
          ``embed_tokens(input_ids)``.
        - Later ranks: hidden states and residual are read from
          ``intermediate_tensors``.

        What is returned also depends on the rank:
        - Non-last ranks: an ``IntermediateTensors`` holding
          ``hidden_states`` and ``residual``.
        - Last rank: the final-normed hidden states.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.embed_tokens(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        # Apply blocks one-by-one; kv_caches is indexed relative to this
        # rank's first layer.
        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[i - self.start_layer],
                attn_metadata,
                residual,
            )
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })
        if residual is not None:
            hidden_states, _ = self.norm(hidden_states, residual)
        else:
            hidden_states = self.norm(hidden_states)
        return hidden_states

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load weights by exact parameter name.

        ``.bias`` tensors with no matching parameter are skipped, as are names
        for which ``is_pp_missing_parameter(name, self)`` is true. Each
        parameter's ``weight_loader`` is used when present, otherwise
        ``default_weight_loader``.

        Returns:
            The set of parameter names that were loaded.
        """
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()

        for name, loaded_weight in weights:
            if name.endswith(".bias") and name not in params_dict:
                continue
            if is_pp_missing_parameter(name, self):
                continue

            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params

895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947

# Memoized wrapper around get_processor: repeated calls with the same
# (hashable) arguments reuse the previously returned processor. The bound
# below is lru_cache's default, written out explicitly.
cached_get_processor = lru_cache(maxsize=128)(get_processor)


def get_num_patches(num_tiles: int, crop_patches: int, left_margin: int,
                    right_margin: int, pooling_size: int) -> int:
    """Return the patch count along one axis of a crop tiling.

    Each segment is rounded up to a multiple of ``pooling_size``.

    With a single tile (``num_tiles <= 1``) the whole crop is used. With
    several tiles, each crop contributes only its window
    ``crop_patches - left_margin - right_margin``:
    - the first tile also keeps its left margin;
    - the last tile also keeps its right margin;
    - every tile in between contributes just the window.
    """

    def _round_up(count: int) -> int:
        # Smallest multiple of pooling_size that is >= count.
        return (count + pooling_size - 1) // pooling_size * pooling_size

    if num_tiles <= 1:
        return _round_up(crop_patches)

    window = crop_patches - (left_margin + right_margin)
    first = _round_up(window + left_margin)
    inner = _round_up(window)
    last = _round_up(window + right_margin)
    return first + (num_tiles - 2) * inner + last


def get_tokens(tiling_h: int, tiling_w: int, crop_patches: int,
               left_margin: int, right_margin: int, pooling_size: int) -> int:
    """Return the image-token count for a ``tiling_h x tiling_w`` crop grid.

    The total has two parts:
    - Crop grid: after pooling, each row has ``cols // pooling_size`` tokens
      plus one extra (column) token; ``rows // pooling_size`` such rows are
      emitted, plus 2 wrapping tokens.
    - Single-crop part: ``side * (side + 1) + 2`` tokens, where ``side`` is
      ``crop_patches / pooling_size`` rounded up.
    """
    rows = get_num_patches(tiling_h, crop_patches, left_margin, right_margin,
                           pooling_size)
    cols = get_num_patches(tiling_w, crop_patches, left_margin, right_margin,
                           pooling_size)

    tokens_per_row = cols // pooling_size + 1
    grid_tokens = tokens_per_row * (rows // pooling_size) + 2

    side = (crop_patches + pooling_size - 1) // pooling_size
    single_crop_tokens = side * (side + 1) + 2

    return single_crop_tokens + grid_tokens


def get_max_tokens(max_crops: int, crop_patches: int, left_margin: int,
                   right_margin: int, pooling_size: int) -> int:
    """Return the largest ``get_tokens`` value over all admissible tilings.

    A tiling ``(rows, cols)`` is admissible when both are at least 1 and
    ``rows * cols <= max_crops``. Raises ``ValueError`` (from ``max``) when
    no tiling is admissible, i.e. ``max_crops < 1``.
    """
    candidate_counts = [
        get_tokens(rows, cols, crop_patches, left_margin, right_margin,
                   pooling_size)
        for rows in range(1, max_crops + 1)
        for cols in range(1, max_crops + 1)
        if rows * cols <= max_crops
    ]
    return max(candidate_counts)


def get_max_molmo_image_tokens(ctx: InputContext) -> int:
    """Return the maximum number of image tokens one Molmo image can produce.

    Crop geometry (``max_crops``, crop size in patches, overlap margins) is
    read from the HF processor's ``image_processor``. ``get_max_tokens`` then
    takes the maximum over every admissible tiling.
    """
    processor = cached_get_processor(
        ctx.model_config.model,
        trust_remote_code=ctx.model_config.trust_remote_code,
        revision=ctx.model_config.code_revision)
    image_processor = processor.image_processor

    # Patches along one side of a single crop.
    crop_patches = (image_processor.base_image_input_size[0] //
                    image_processor.image_patch_size)
    # NOTE(review): pooling size is hard-coded to 2; presumably this matches
    # the vision backbone's patch pooling, so confirm if that ever changes.
    pooling_size = 2
    max_llm_image_tokens = get_max_tokens(
        image_processor.max_crops,
        crop_patches,
        image_processor.overlap_margins[0],
        image_processor.overlap_margins[1],
        pooling_size,
    )
    return max_llm_image_tokens


# NOTE: preprocessing for the image data has been included in the
# 'input_processor_for_molmo' function
def image_input_mapper_for_molmo(
    ctx: InputContext,
    data: object,
):
    """Wrap already-preprocessed Molmo image data in ``MultiModalKwargs``.

    ``data`` is expected to be the image dict built by the input processor
    (or dummy-data builder), or a one-element list holding such a dict.
    The ``raw_mm_data`` entry (the dummy PIL image added by
    ``dummy_data_for_molmo``) is popped from the dict before wrapping.
    """
    if isinstance(data, list):
        assert len(data) == 1, "Molmo supports only one image per prompt."
        data = data[0]

    # Remove unused dummy PIL image
    data.pop('raw_mm_data', None)
    return MultiModalKwargs(data)
977
978
979
980


def dummy_data_for_molmo(ctx: InputContext, seq_len: int,
                         mm_counts: Mapping[str, int]):
    """Build dummy sequence and image data for the Molmo model.

    The dummy image is sized as a ``max_crops x 1`` (vertical) crop tiling.
    It is run through the HF processor, and the first
    ``1 + max_llm_image_tokens`` ids of the result are kept and followed by
    zeros. The placeholder range covers every image-related token id found
    in that sequence.

    Raises:
        RuntimeError: if ``seq_len`` cannot hold the maximum number of image
            tokens plus one leading token.
    """
    processor = cached_get_processor(
        ctx.model_config.model,
        trust_remote_code=ctx.model_config.trust_remote_code,
        revision=ctx.model_config.code_revision)
    image_processor = processor.image_processor

    base_image_input_d = image_processor.image_patch_size
    left_margin, right_margin = image_processor.overlap_margins
    max_crops = image_processor.max_crops

    # Assume: prompt_token_ids always starts with bos_token_id followed image tokens # noqa: E501
    max_llm_image_tokens = get_max_molmo_image_tokens(ctx)
    if seq_len - max_llm_image_tokens - 1 < 0:
        raise RuntimeError(
            f"Molmo cannot process {max_crops} crops in a prompt, "
            "please increase max_model_len or reduce number of crops")

    # The vertical image has the maximum number of image tokens due to column tokens. # noqa: E501
    tiling = (max_crops, 1)
    total_margin_pixels = base_image_input_d * (right_margin + left_margin)
    crop_patches = image_processor.base_image_input_size[
        0] // base_image_input_d
    crop_window_patches = crop_patches - (right_margin + left_margin)
    crop_window_size = crop_window_patches * base_image_input_d

    h = crop_window_size * tiling[0] + total_margin_pixels
    w = crop_window_size * tiling[1] + total_margin_pixels

    dummy_image = Image.new("RGB", (w, h), color="red")

    out = processor.process("dummy prompt", dummy_image)

    token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
                      out["input_ids"][:1 + max_llm_image_tokens])
    token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
                       [0]) * (seq_len - max_llm_image_tokens - 1)
    dummy_seqdata = SequenceData(token_ids)
    dummy_imgdata = {
        "images": out["images"],
        "image_input_idx": out["image_input_idx"],
        # image_input_mapper_for_molmo pops this key before wrapping.
        "raw_mm_data": dummy_image,
    }
    if "image_masks" in out:
        dummy_imgdata["image_masks"] = out["image_masks"]
    dummy_imgdata["seq_len"] = torch.tensor(seq_len, dtype=torch.long)

    # Locate the image tokens: offset is the first image-token position and
    # size is the total count (offset stays -1 if none are present).
    image_token_ids = (DEFAULT_IMAGE_PATCH_TOKEN_ID, DEFAULT_IM_START_TOKEN_ID,
                       DEFAULT_IM_END_TOKEN_ID, DEFAULT_IM_COL_TOKEN_ID)
    size = 0
    offset = -1
    for i, token_id in enumerate(token_ids):
        if token_id in image_token_ids:
            if offset < 0:
                offset = i
            size += 1
    dummy_imgdata["image_start_end"] = (offset, offset + size)
    return DummyData(seq_data=dummy_seqdata,
                     multi_modal_data={"image": dummy_imgdata},
                     multi_modal_placeholders={
                         "image":
                         [PlaceholderRange(offset=offset, length=size)]
                     })
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057


def pad_images(
    max_total_crops: int,
    images: torch.Tensor,
    image_input_idx: torch.Tensor,
    image_masks: Optional[torch.Tensor] = None,
):
    """Pad the leading (crop) dimension up to ``max_total_crops`` with -1.

    ``images`` is padded on its first of three dims; ``image_input_idx`` and
    ``image_masks`` (if given) on their first of two dims.

    Returns:
        ``(images, image_input_idx, image_masks)`` after padding; the masks
        entry stays ``None`` when no masks were passed.
    """
    extra_crops = max_total_crops - images.shape[0]
    # F.pad takes (before, after) pairs starting from the LAST dimension.
    pad_3d = (0, 0, 0, 0, 0, extra_crops)
    pad_2d = (0, 0, 0, extra_crops)

    padded_images = F.pad(images, pad_3d, value=-1)
    padded_idx = F.pad(image_input_idx, pad_2d, value=-1)
    if image_masks is None:
        padded_masks = None
    else:
        padded_masks = F.pad(image_masks, pad_2d, value=-1)
    return padded_images, padded_idx, padded_masks


1058
def input_processor_for_molmo(ctx: InputContext, inputs: DecoderOnlyInputs):
    """Run the HF Molmo processor and package its output as token inputs.

    How the processor is called depends on the prompt:
    - Text prompts that begin with ``User:`` are passed with
      ``message_format="none"``, i.e. treated as already formatted.
    - Other text prompts use the processor's default formatting.
    - Token-only inputs are passed via ``tokens=``.

    Image tensors are padded to ``1 + max_crops`` crops with -1. When no
    image is given, all-(-1) placeholder tensors of the same layout are used
    instead. The placeholder range spans every image-related token id in the
    processed prompt.
    """
    prompt = inputs.get("prompt")
    multi_modal_data = inputs.get("multi_modal_data")
    image = None if multi_modal_data is None else multi_modal_data.get("image")

    model_config = ctx.model_config
    processor = cached_get_processor(
        model_config.model,
        trust_remote_code=model_config.trust_remote_code,
        revision=model_config.code_revision)
    tokenizer = cached_get_tokenizer(
        model_config.tokenizer,
        trust_remote_code=model_config.trust_remote_code)

    # NOTE: message formatting for raw text prompt is only applied for
    # offline inference; for online inference, the prompt is always in
    # instruction format and tokenized.
    if prompt is not None and re.match(r"^User:[\s\S]*?(Assistant:)*$",
                                       prompt):
        out = processor.process(prompt, image, message_format="none")
    elif prompt is not None:
        out = processor.process(prompt, image)
    else:
        out = processor.process(None, image, tokens=inputs["prompt_token_ids"])

    image_processor = processor.image_processor
    max_total_crops = 1 + image_processor.max_crops
    # Must be bound on every path: the no-image branch only creates masks
    # when image_padding_mask is set (previously a NameError otherwise).
    image_masks = None
    if image is not None:
        images, image_input_idx, image_masks = pad_images(
            max_total_crops,
            out["images"],
            out["image_input_idx"],
            out.get("image_masks"),
        )
    else:
        base_image_input_size = image_processor.base_image_input_size
        image_patch_size = image_processor.image_patch_size
        image_num_patch = (
            base_image_input_size[0] // image_patch_size,
            base_image_input_size[1] // image_patch_size,
        )
        n_pixels = image_patch_size * image_patch_size * 3
        n_patches = image_num_patch[0] * image_num_patch[1]

        image_length_w = image_processor.image_token_length_w
        image_length_h = image_processor.image_token_length_h
        tokens_per_image = image_length_w * image_length_h
        images = torch.full(
            (max_total_crops, n_patches, n_pixels),
            -1,
            dtype=torch.float32,
        )
        image_input_idx = torch.full(
            (max_total_crops, tokens_per_image),
            -1,
            dtype=torch.int32,
        )
        if image_processor.image_padding_mask:
            image_masks = torch.full(
                (max_total_crops, n_patches),
                -1,
                dtype=torch.float32,
            )

    image_data = dict(
        images=images,
        image_input_idx=image_input_idx,
    )
    if image_masks is not None:
        image_data["image_masks"] = image_masks

    new_prompt_token_ids = out["input_ids"].tolist()
    image_data["seq_len"] = torch.tensor(len(new_prompt_token_ids),
                                         dtype=torch.long)

    multi_modal_data = dict(image=image_data)

    # offset: first image-token position (-1 if none); size: image-token count.
    image_token_ids = (DEFAULT_IMAGE_PATCH_TOKEN_ID, DEFAULT_IM_START_TOKEN_ID,
                       DEFAULT_IM_END_TOKEN_ID, DEFAULT_IM_COL_TOKEN_ID)
    size = 0
    offset = -1
    for i, token_id in enumerate(new_prompt_token_ids):
        if token_id in image_token_ids:
            if offset < 0:
                offset = i
            size += 1
    image_data["image_start_end"] = (offset, offset + size)

    if prompt is None:
        prompt = tokenizer.decode(new_prompt_token_ids)

    return token_inputs(
        prompt_token_ids=new_prompt_token_ids,
        prompt=prompt,
        multi_modal_data=multi_modal_data,
        multi_modal_placeholders={
            "image": [PlaceholderRange(offset=offset, length=size)]
        },
    )


@MULTIMODAL_REGISTRY.register_image_input_mapper(image_input_mapper_for_molmo)
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_molmo_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_molmo)
@INPUT_REGISTRY.register_input_processor(input_processor_for_molmo)
class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
    """Molmo vision-language model for causal generation.

    Combines three parts:
    - ``MolmoVisionBackbone``, which produces image features;
    - ``MolmoModel``, the decoder;
    - an LM head.

    Image features are placed at the embeddings of image-related placeholder
    tokens before the decoder runs.
    """

    # Renames HF checkpoint keys to this module's parameter names.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_substr={
            # vision backbone mapping
            "image_projector.w1.": "image_projector.gate_proj.",
            "image_projector.w3.": "image_projector.up_proj.",
            "image_projector.w2.": "image_projector.down_proj.",
            # language backbone mapping
            "att_proj": "self_attn.qkv_proj",
            "attn_out": "self_attn.o_proj",
            "q_norm": "self_attn.q_norm",
            "k_norm": "self_attn.k_norm",
            "ff_proj": "mlp.gate_up_proj",
            "ff_out": "mlp.down_proj",
            "attn_norm": "input_layernorm",
            "ff_norm": "post_attention_layernorm",
        },
        orig_to_new_prefix={
            # vision backbone mapping
            "model.vision_backbone.": "vision_backbone.",
            # language backbone mapping
            "model.transformer.blocks.": "model.layers.",
            "model.transformer.ln_f.": "model.norm.",
            # lm_head is renamed to model.transformer.mlp.down_proj firstly,
            # we need to run a second renaming for it
            "model.transformer.mlp.down_proj.": "lm_head.",
        },
    )

    # BitandBytes specific attributes
    bitsandbytes_stacked_params_mapping = {
        "gate_proj": ("merged_linear", 0),
        "up_proj": ("merged_linear", 1),
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.config = config
        self.multimodal_config = multimodal_config

        vision_config = VisionBackboneConfig()
        self.vision_backbone = MolmoVisionBackbone(config, vision_config,
                                                   quant_config)
        self.model = MolmoModel(vllm_config=vllm_config,
                                prefix=maybe_prefix(prefix, "model"))

        if self.config.weight_tying:
            # Tie to the decoder's token embedding. MolmoModel exposes it as
            # ``embed_tokens``; it has no ``transformer.wte`` attribute.
            self.lm_head = self.model.embed_tokens
        else:
            self.lm_head = ParallelLMHead(
                config.embedding_size or config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
            )

        self.logits_processor = LogitsProcessor(config.embedding_size
                                                or config.vocab_size)
        self.sampler = get_sampler()

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def _parse_and_validate_image_input(
        self,
        **kwargs: object,
    ) -> Optional[MolmoImageInputs]:
        """Extract image inputs from ``kwargs``, or ``None`` if no images.

        Raises:
            ValueError: if ``images`` is present but ``image_input_idx``,
                ``seq_len`` or ``image_start_end`` is missing.
        """
        images = kwargs.pop("images", None)
        image_masks = kwargs.pop("image_masks", None)
        image_start_end = kwargs.pop("image_start_end", None)
        if images is None:
            return None

        image_input_idx = kwargs.pop("image_input_idx", None)
        seq_len = kwargs.pop("seq_len", None)
        if image_input_idx is None:
            raise ValueError("image_input_idx is required for Molmo model.")
        if seq_len is None:
            raise ValueError("seq_len is required for Molmo model.")
        if image_start_end is None:
            raise ValueError("image_start_end is required for Molmo model.")
        if not isinstance(seq_len, torch.Tensor):
            seq_len = torch.tensor(seq_len)

        return MolmoImageInputs(
            images=images,
            image_input_idx=image_input_idx,
            seq_len=seq_len,
            image_masks=image_masks,
            image_start_end=image_start_end,
        )

    def _process_image_input(
        self,
        image_input: MolmoImageInputs,
    ) -> torch.Tensor:
        """Run the vision backbone on the images (and masks, if any)."""
        image_features = self.vision_backbone(
            images=image_input["images"],
            image_masks=image_input["image_masks"],
        )

        return image_features

    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
        """Return one tensor of vision embeddings per sequence, or ``None``.

        Steps:
        1. Each image feature vector is placed at the position given by
           ``image_input_idx``, shifted by the total ``seq_len`` of the
           preceding sequences. Features with a negative index are zeroed.
        2. The per-position result is split by ``seq_len``.
        3. Each sequence's piece is sliced to its ``image_start_end`` range.
        """
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return None
        image_features = self._process_image_input(image_input)
        image_input_idx = image_input["image_input_idx"]
        seq_len = image_input["seq_len"]

        batch_size, num_image, num_patch = image_features.shape[:3]
        assert image_input_idx.shape == (batch_size, num_image, num_patch)

        # insert the image feature into the embedding.
        image_features = image_features.view(batch_size, num_image * num_patch,
                                             -1)
        image_input_idx = image_input_idx.view(batch_size,
                                               num_image * num_patch)

        # Negative indices mark padding: zero those features, and map their
        # index to 0 so the one-hot matrix below stays in range (their
        # contribution is zero anyway).
        valid = image_input_idx >= 0
        image_features = image_features * valid[:, :, None].to(
            image_features.dtype)
        image_features = image_features.view(
            batch_size * num_image * num_patch, -1).contiguous()

        image_input_idx = image_input_idx * valid.to(image_input_idx.dtype)
        # Shift each sequence's indices by the total length of earlier ones.
        offset = torch.cat([seq_len.new_zeros(1),
                            seq_len.cumsum(dim=0)[:-1]],
                           dim=0)[:, None]
        image_input_idx = image_input_idx + offset.to(image_input_idx.dtype)
        image_input_idx = image_input_idx.flatten()[:, None]
        # mat[n, m] is 1 where feature n's index equals position m.
        mat = image_input_idx == torch.arange(
            seq_len.sum().item(), device=image_features.device)[None, :]
        mat = mat.to(image_features.dtype)

        # Note: In this original implementation from AI2, the final
        # vision_embeddings will always be the same length
        # of input embeddings.
        vision_embeddings = torch.einsum('nd,nm->md', image_features, mat)

        # Split by the sizes of the input sequences. For each full embedding,
        # extract the actual vision embeddings to be merged.
        vision_embeddings = list(vision_embeddings.split(seq_len.tolist()))
        for i in range(len(vision_embeddings)):
            start, end = image_input['image_start_end'][i]
            vision_embeddings[i] = vision_embeddings[i][start:end]

        return vision_embeddings

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[NestedTensors] = None,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and merge in ``multimodal_embeddings``, if given.

        The merged positions are those holding the image patch / start / end
        / column token ids.
        """
        inputs_embeds = self.model.get_input_embeddings(input_ids)
        if multimodal_embeddings is not None:
            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, multimodal_embeddings, [
                    DEFAULT_IMAGE_PATCH_TOKEN_ID, DEFAULT_IM_START_TOKEN_ID,
                    DEFAULT_IM_END_TOKEN_ID, DEFAULT_IM_COL_TOKEN_ID
                ])
        return inputs_embeds

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: torch.LongTensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> SamplerOutput:
        """Run the decoder and return whatever ``self.model`` returns.

        NOTE(review): the value returned is ``self.model``'s output (hidden
        states, or ``IntermediateTensors`` on non-last pipeline ranks), not
        a ``SamplerOutput`` as annotated.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None
        # NOTE: In v1, inputs_embeds is always generated at model runner, this
        # condition is for v0 compatibility.
        elif inputs_embeds is None:
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      vision_embeddings)
            input_ids = None

        hidden_states = self.model(input_ids,
                                   positions,
                                   kv_caches,
                                   attn_metadata,
                                   intermediate_tensors,
                                   inputs_embeds=inputs_embeds)

        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Compute logits from ``hidden_states`` via the LM head."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(
            self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:
        """Load HF checkpoint weights via ``AutoWeightsLoader``.

        The split token embedding is merged first, then names are rewritten
        with ``hf_to_vllm_mapper``.
        """
        loader = AutoWeightsLoader(self)
        weights = _get_weights_with_merged_embedding(weights)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398


def _get_weights_with_merged_embedding(
    weights: Iterable[Tuple[str, torch.Tensor]]
) -> Iterable[Tuple[str, torch.Tensor]]:
    embedding_weights = {}
    for name, weight in weights:
        if "wte.embedding" in name:
            embedding_weights["embedding"] = weight
        elif "wte.new_embedding" in name:
            embedding_weights["new_embedding"] = weight
        else:
            yield (name, weight)
    # this is compatible with most of quantization,
    # because they won't quantize embed_tokens
    embedding_weights = torch.cat(
        [embedding_weights["embedding"], embedding_weights["new_embedding"]],
        dim=0,
    )
    yield ("model.embed_tokens.weight", embedding_weights)