aria.py 24.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from collections.abc import Iterable, Mapping, Sequence
4
from typing import Annotated, Literal, Optional, Union
5
6
7

import torch
import torch.nn as nn
8
9
10
from transformers import AriaConfig, AriaTextConfig, BatchFeature
from transformers.models.aria.modeling_aria import AriaCrossAttention
from transformers.models.aria.processing_aria import AriaProcessor
11

12
from vllm.config import VllmConfig
13
from vllm.config.multimodal import BaseDummyOptions
14
15
16
17
18
19
from vllm.distributed import get_tensor_model_parallel_rank
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
20
from vllm.model_executor.layers.quantization import QuantizationConfig
21
22
23
24
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader, maybe_remap_kv_scale_name)
from vllm.multimodal import MULTIMODAL_REGISTRY
25
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
26
                                    MultiModalKwargsItems)
27
from vllm.multimodal.parse import MultiModalDataItems
28
from vllm.multimodal.processing import (BaseMultiModalProcessor,
29
30
                                        BaseProcessingInfo, PromptReplacement,
                                        PromptUpdate)
31
from vllm.multimodal.profiling import BaseDummyInputsBuilder
32
from vllm.sequence import IntermediateTensors
33
from vllm.utils.tensor_schema import TensorSchema, TensorShape
34

35
# yapf: disable
36
from .idefics2_vision_model import Idefics2VisionConfig
37
38
39
from .idefics2_vision_model import (
    Idefics2VisionTransformer as Idefics3VisionTransformer)
# yapf: enable
40
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsQuant
41
from .llama import LlamaDecoderLayer, LlamaMLP, LlamaModel
42
43
from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter,
                    maybe_prefix)
44
45


46
class AriaImagePixelInputs(TensorSchema):
    """
    Pixel inputs for the Aria vision tower.

    Dimensions:
        - bn: Batch size * number of images
        - h: Height of each image
        - w: Width of each image

    The channel dimension of `pixel_values` is fixed to 3.
    """

    type: Literal["pixel_values"]

    # Image pixels, one entry per image along the leading `bn` dimension.
    pixel_values: Annotated[
        torch.Tensor,
        TensorShape("bn", 3, "h", "w"),
    ]

    # Optional per-pixel mask with the same spatial size as `pixel_values`.
    # It is reduced to a per-patch mask by
    # `AriaForConditionalGeneration._create_patch_attention_mask`, where a
    # patch counts as valid if any of its mask values is nonzero.
    pixel_mask: Annotated[
        Optional[torch.Tensor],
        TensorShape("bn", "h", "w"),
    ]

68

69
70
class AriaVisionTransformer(Idefics3VisionTransformer, SupportsQuant):
    """Idefics3-style vision transformer used as Aria's vision tower.

    It behaves like its parent except that the normalization after the final
    encoder layer is replaced by a no-op.
    """

    packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}

    def __init__(
        self,
        config: Idefics2VisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__(config, quant_config=quant_config, prefix=prefix)
        # Aria omits the LayerNorm that Idefics3VisionTransformer applies after
        # the last layer, so the parent's module is swapped for an identity.
        self.post_layernorm = nn.Identity()

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Copy checkpoint tensors into this module's parameters.

        Tensors whose name contains ``q_proj``/``k_proj``/``v_proj`` are
        renamed to the fused ``qkv_proj`` parameter and loaded through its
        shard-aware ``weight_loader``. Every other tensor uses the parameter's
        ``weight_loader`` attribute, falling back to ``default_weight_loader``.
        Tensors belonging to ``post_layernorm`` are dropped.

        Returns:
            Names of the parameters that received a tensor.
        """
        qkv_shards = (
            # (fused param name, checkpoint shard name, shard id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        )
        params = dict(self.named_parameters())
        loaded: set[str] = set()

        for ckpt_name, tensor in weights:
            # post_layernorm was replaced by nn.Identity and has no params.
            if "post_layernorm" in ckpt_name:
                continue

            # First shard entry (in declaration order) that matches.
            shard = next(
                (entry for entry in qkv_shards if entry[1] in ckpt_name),
                None,
            )
            if shard is None:
                target = ckpt_name
                param = params[target]
                loader = getattr(param, "weight_loader",
                                 default_weight_loader)
                loader(param, tensor)
            else:
                fused_name, shard_name, shard_id = shard
                target = ckpt_name.replace(shard_name, fused_name)
                param = params[target]
                param.weight_loader(param, tensor, shard_id)
            loaded.add(target)

        return loaded


117
class AriaProjectorMLP(nn.Module):
    """Feed-forward tail of the Aria projector.

    Computes ``linear_out(gelu_new(linear_in(x)))``. ``linear_in`` is a
    column-parallel layer and ``linear_out`` a row-parallel layer, both built
    without bias.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        output_dim: int,
    ) -> None:
        super().__init__()

        self.linear_in = ColumnParallelLinear(
            in_features,
            hidden_features,
            bias=False,
        )
        self.linear_out = RowParallelLinear(
            hidden_features,
            output_dim,
            bias=False,
        )
        self.act = get_act_fn("gelu_new")

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # The parallel linear layers return a pair; only the first element
        # (the output tensor) is used, the second is discarded.
        projected = self.linear_in(hidden_states)[0]
        activated = self.act(projected)
        return self.linear_out(activated)[0]


class AriaProjector(nn.Module):
    """
    A projection module with one cross attention layer and one FFN layer, which
    projects ViT's outputs into MoE's inputs.

    Args:
        config: [AriaConfig](https://huggingface.co/docs/transformers/main/model_doc/aria#transformers.AriaConfig)
            containing projector configuration parameters.

    Outputs:
        A tensor with the shape of (batch_size, query_number, output_dim)
    """

    def __init__(self, config: AriaConfig) -> None:
        super().__init__()

        vision_cfg = config.vision_config
        text_cfg = config.text_config

        self.patch_to_query_dict = config.projector_patch_to_query_dict
        self.in_features = vision_cfg.hidden_size
        self.num_heads = vision_cfg.num_attention_heads
        self.kv_dim = vision_cfg.hidden_size
        self.hidden_features = text_cfg.hidden_size
        self.output_dim = text_cfg.hidden_size

        # Bank of learned queries; forward() uses a prefix of it whose length
        # depends on the number of incoming patches.
        self.query = nn.Parameter(
            torch.empty(config.max_value_projector_patch_to_query_dict,
                        self.in_features))

        self.cross_attn = AriaCrossAttention(config)

        self.layer_norm = nn.LayerNorm(self.in_features)
        self.feed_forward = AriaProjectorMLP(self.in_features,
                                             self.hidden_features,
                                             self.output_dim)

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Cross-attend learned queries over vision features, then apply FFN.

        Args:
            x: Vision features; dim 0 is the batch and dim 1 the patch count.
            attn_mask: Optional mask over patches, expanded per attention head
                and per query before being passed to ``cross_attn``.

        Raises:
            KeyError: If the patch count is not a key of
                ``patch_to_query_dict``.
        """
        batch_size = x.shape[0]
        num_patches = x.shape[1]

        if num_patches not in self.patch_to_query_dict:
            raise KeyError(f"Number of patches {num_patches} not found in "
                           "patch_to_query_dict amongst possible values "
                           f"{self.patch_to_query_dict.keys()}.")

        query_num = self.patch_to_query_dict[num_patches]
        selected = self.query[:query_num]
        # Give every batch element its own copy of the selected queries.
        queries = selected[None].repeat(batch_size, 1, 1)

        if attn_mask is not None:
            num_queries = queries.shape[1]
            # assumes attn_mask is (batch, num_patches) -> becomes
            # (batch * num_heads, num_queries, num_patches); TODO confirm
            per_head_mask = attn_mask.repeat_interleave(self.num_heads, 0)
            attn_mask = per_head_mask.unsqueeze(1).expand(-1, num_queries, -1)

        attended = self.cross_attn(x, queries, attn_mask=attn_mask)
        normalized = self.layer_norm(attended)
        return self.feed_forward(normalized)


class AriaFusedMoE(FusedMoE):
    """FusedMoE whose weight loader accepts Aria's pre-packed expert tensors."""

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
                      shard_id: str) -> None:
        """Copy one packed expert tensor into ``param`` for this TP rank.

        Aria checkpoints store the weights of all experts of a layer in a
        single tensor. ``fc1`` (shard ``"w13"``) already holds the two halves
        of the gated projection concatenated along the last dimension.

        Args:
            param: Destination parameter.
            loaded_weight: Packed checkpoint tensor. For ``"w13"`` its shape is
                (num_experts, hidden_size, 2 * moe_intermediate_size); for
                ``"w2"`` it is (num_experts, moe_intermediate_size,
                hidden_size).
            shard_id: Either ``"w13"`` or ``"w2"``.

        Raises:
            ValueError: If ``shard_id`` is neither ``"w13"`` nor ``"w2"``.

        Note:
            Loading expert weights with quantization is not supported.
        """
        if shard_id not in ("w13", "w2"):
            # Fail loudly instead of returning and leaving `param` holding
            # uninitialized memory.
            raise ValueError(f"Unsupported shard_id {shard_id!r} for "
                             "AriaFusedMoE; expected 'w13' or 'w2'.")

        tp_rank = get_tensor_model_parallel_rank()
        if shard_id == "w13":
            if self.tp_size > 1:
                # Shard each half along the intermediate dim separately so
                # every rank keeps matching slices of both halves, then
                # re-concatenate them in their original order.
                first_half, second_half = loaded_weight.chunk(2, dim=-1)
                first_rank = first_half.chunk(self.tp_size, dim=-1)[tp_rank]
                second_rank = second_half.chunk(self.tp_size, dim=-1)[tp_rank]
                loaded_weight = torch.cat([first_rank, second_rank], dim=-1)
        elif self.tp_size > 1:
            # "w2": shard along the intermediate dimension (dim 1).
            loaded_weight = loaded_weight.chunk(self.tp_size,
                                                dim=1)[tp_rank]

        # The parameter stores the last two dims swapped relative to the
        # checkpoint layout.
        param.data.copy_(loaded_weight.transpose(1, 2))


235
class AriaTextMoELayer(nn.Module):
    """
    Mixture of Experts (MoE) Layer for the AriaMoE model.

    This layer implements the MoE mechanism, which routes input tokens to
    different experts based on a routing algorithm, processes them through the
    experts, and then combines the outputs. A dense shared-experts MLP is
    applied to every token, and its output is added to the routed experts'
    output.
    """

    def __init__(
        self,
        config: AriaTextConfig,
        quant_config: Optional[QuantizationConfig],
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config

        # Router projection: hidden_size -> one logit per routed expert.
        self.router_weight = nn.Parameter(
            torch.empty(
                (self.config.moe_num_experts, self.config.hidden_size)))

        self.experts = AriaFusedMoE(
            num_experts=config.moe_num_experts,
            top_k=config.moe_topk,
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            quant_config=quant_config,
            reduce_results=True,
            prefix=f"{prefix}.experts",
        )
        # All shared experts are realised as one dense MLP whose intermediate
        # size is scaled by their count. It gets a prefix like the routed
        # experts so any prefix-dependent logic (e.g. per-layer quantization
        # rules) sees a meaningful, unique module name.
        self.shared_experts = LlamaMLP(
            config.hidden_size,
            config.intermediate_size * config.moe_num_shared_experts,
            "silu",
            quant_config=quant_config,
            bias=config.mlp_bias,
            prefix=f"{prefix}.shared_experts",
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the MoE Layer.

        Args:
            hidden_states: Input tensor whose last dimension is
                ``hidden_size``.

        Returns:
            torch.Tensor: Sum of the routed-experts output and the
            shared-experts output.
        """
        router_output = torch.nn.functional.linear(hidden_states,
                                                   self.router_weight)

        # NOTE: `FusedMoE` may modify `hidden_states` in place, so the shared
        # experts consume it first; this avoids cloning the activations.
        shared_expert_output = self.shared_experts(hidden_states)
        sparse_expert_output = self.experts(hidden_states, router_output)

        return sparse_expert_output + shared_expert_output


297
class AriaTextDecoderLayer(LlamaDecoderLayer):
    """
    Decoder layer for the AriaMoE text model: a `LlamaDecoderLayer` whose MLP
    is replaced by an `AriaTextMoELayer` (Mixture of Experts).
    """

    def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__(vllm_config, prefix)
        # Overwrite the `mlp` set up by the parent (presumably a dense
        # LlamaMLP; TODO confirm) with Aria's MoE layer.
        self.mlp = AriaTextMoELayer(
            vllm_config.model_config.hf_config,
            quant_config=vllm_config.quant_config,
            prefix=f"{prefix}.mlp",
        )
313
314


315
class AriaTextModel(LlamaModel, SupportsQuant):
    """
    Custom LlamaModel for the AriaMoE model which modifies the standard
    LlamaModel by building every decoder layer as an `AriaTextDecoderLayer`
    (whose MLP is a Mixture-of-Experts layer).
    """

    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
        "experts.w13_weight": ["experts.fc1.weight"],
        "experts.w2_weight": ["experts.fc2.weight"],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config,
                         prefix=prefix,
                         layer_type=AriaTextDecoderLayer)

    # Adapted from LlamaModel.load_weights with the modification of adding
    # the expert weights mapping to `stacked_params_mapping`
    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load text-model weights from ``(name, tensor)`` checkpoint pairs.

        Checkpoint names matching ``stacked_params_mapping`` are rewritten to
        the fused vLLM parameter and loaded with a shard id. This covers q/k/v,
        gate/up, and the packed expert tensors, whose ``"w13"``/``"w2"`` shard
        ids are handled by ``AriaFusedMoE.weight_loader``. KV-cache scales
        reported by the quant config are loaded directly.

        Skipped tensors:
        - rotary embedding caches;
        - extra biases with no matching parameter (GPTQ checkpoints);
        - parameters not present on this pipeline-parallel rank;
        - names for which ``maybe_remap_kv_scale_name`` returns ``None``.

        Returns:
            Names of the parameters (including KV-cache scales) that were
            loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
            ("experts.w13_weight", "experts.fc1.weight", 'w13'),
            ("experts.w2_weight", "experts.fc2.weight", 'w2'),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if ("rotary_emb.cos_cached" in name
                    or "rotary_emb.sin_cached" in name):
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            if (self.quant_config is not None and
                (scale_name := self.quant_config.get_cache_scale(name))):
                # Loading kv cache quantization scales
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
                                 loaded_weight[0])
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue

            # First mapping entry (in declaration order) whose checkpoint
            # shard name occurs in `name`.
            stacked = next((entry for entry in stacked_params_mapping
                            if entry[1] in name), None)
            if stacked is not None:
                param_name, weight_name, shard_id = stacked
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models. This skips the
                # whole tensor rather than resuming the mapping scan with an
                # already-rewritten name.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


402
class AriaProcessingInfo(BaseProcessingInfo):
    """Accessors for Aria's HF config/processor and image-token budget."""

    def get_hf_config(self):
        return self.ctx.get_hf_config(AriaConfig)

    def get_vision_config(self):
        hf_config = self.get_hf_config()
        return hf_config.vision_config

    def get_hf_processor(self, **kwargs: object):
        return self.ctx.get_hf_processor(AriaProcessor, **kwargs)

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        # `None` is returned as the limit for the "image" modality.
        return {"image": None}

    def get_num_image_tokens(self) -> int:
        """Return the largest query count in the projector's patch map."""
        patch_to_query = self.get_hf_config().projector_patch_to_query_dict
        return max(patch_to_query.values())


class AriaDummyInputsBuilder(BaseDummyInputsBuilder[AriaProcessingInfo]):
    """Builds dummy prompt text and images for Aria."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one image token per requested dummy image."""
        tokenizer = self.info.get_hf_processor().tokenizer
        image_token: str = tokenizer.image_token  # type: ignore
        return image_token * mm_counts.get("image", 0)

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
    ) -> MultiModalDataDict:
        """Return square dummy images sized to the vision config's image_size.
        """
        side = self.info.get_vision_config().image_size

        image_overrides = None
        if mm_options:
            image_overrides = mm_options.get("image")

        dummy_images = self._get_dummy_images(
            width=side,
            height=side,
            num_images=mm_counts.get("image", 0),
            overrides=image_overrides,
        )
        return {"image": dummy_images}


453
class AriaMultiModalProcessor(BaseMultiModalProcessor[AriaProcessingInfo]):
    """Multimodal processor that expands Aria image tokens in prompts."""

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        # Both tensors are batched per image item (a separate config each).
        return {
            field: MultiModalFieldConfig.batched("image")
            for field in ("pixel_values", "pixel_mask")
        }

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        image_token_id = self.info.get_hf_config().image_token_index
        num_image_tokens = self.info.get_num_image_tokens()

        # Each single image token is repeated `num_image_tokens` times.
        expanded = [image_token_id] * num_image_tokens
        image_replacement = PromptReplacement(
            modality="image",
            target=[image_token_id],
            replacement=expanded,
        )
        return [image_replacement]
483
484


485
486
487
@MULTIMODAL_REGISTRY.register_processor(AriaMultiModalProcessor,
                                        info=AriaProcessingInfo,
                                        dummy_inputs=AriaDummyInputsBuilder)
class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
    """
    Aria model for conditional generation tasks.

    This model combines a vision tower, a multi-modal projector, and a language
    model to perform tasks that involve both image and text inputs.
    """

    merge_by_field_config = True

    # NOTE(review): `language_model` is an AriaTextModel registered directly
    # under "language_model.", so new-style "model.language_model.*" names
    # only land correctly if WeightsMapper applies every matching prefix rule
    # in order ("model.language_model." -> "language_model.model." ->
    # "language_model."). Keep the rule order unchanged; confirm behavior.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            # mapping for new names in checkpoint saved after transformers v4.52
            "model.language_model.": "language_model.model.",
            "model.vision_tower.": "vision_tower.",
            "model.multi_modal_projector.": "multi_modal_projector.",
            # mapping for original checkpoint
            "language_model.model": "language_model",
            "language_model.lm_head": "lm_head",
        },
        orig_to_new_suffix={
            "router.weight": "router_weight",
        },
    )

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
        """Return the prompt placeholder text for an item of ``modality``.

        Raises:
            ValueError: If ``modality`` does not start with ``"image"``.
        """
        if modality.startswith("image"):
            return "<|fim_prefix|><|img|><|fim_suffix|>"

        raise ValueError("Only image modality is supported")

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.vision_tower = AriaVisionTransformer(
            config.vision_config,
            quant_config=quant_config,
            prefix=f"{prefix}.vision_tower",
        )
        self.multi_modal_projector = AriaProjector(config)
        self.vocab_size = config.text_config.vocab_size
        self.language_model = AriaTextModel(
            vllm_config=vllm_config.with_hf_config(config.text_config),
            prefix=maybe_prefix(prefix, "language_model.model"),
        )
        self.pad_token_id = (self.config.pad_token_id
                             if self.config.pad_token_id is not None else -1)
        self.unpadded_vocab_size = config.text_config.vocab_size
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.text_config.hidden_size,
            org_num_embeddings=self.language_model.org_vocab_size,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        logit_scale = getattr(config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                self.vocab_size, logit_scale)

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[AriaImagePixelInputs]:
        """Pop image tensors from ``kwargs``; ``None`` if no pixel values."""
        pixel_values = kwargs.pop("pixel_values", None)
        pixel_mask = kwargs.pop("pixel_mask", None)

        if pixel_values is None:
            return None

        return AriaImagePixelInputs(
            type="pixel_values",
            pixel_values=pixel_values,
            pixel_mask=pixel_mask,
        )

    def _create_patch_attention_mask(
        self,
        pixel_mask: Optional[torch.Tensor],
    ) -> Optional[torch.Tensor]:
        """Reduce a per-pixel mask to a per-patch boolean mask.

        Args:
            pixel_mask: Mask of shape (bn, h, w), or ``None``.

        Returns:
            ``None`` if ``pixel_mask`` is ``None``. Otherwise a boolean tensor
            of shape (bn, h // patch_size, w // patch_size) that is ``True``
            for patches containing at least one nonzero mask value.
        """
        if pixel_mask is None:
            return None

        patch_size = self.vision_tower.config.patch_size
        # Non-overlapping patch_size x patch_size windows over (h, w).
        patches_subgrid = pixel_mask.unfold(
            dimension=1,
            size=patch_size,
            step=patch_size,
        ).unfold(
            dimension=2,
            size=patch_size,
            step=patch_size,
        )
        # The comparison already yields a bool tensor.
        return patches_subgrid.sum(dim=(-1, -2)) > 0

    def _process_image_input(
            self, image_input: AriaImagePixelInputs) -> torch.Tensor:
        """Encode images with the vision tower and project them.

        Returns:
            The tensor produced by ``multi_modal_projector``.
        """
        assert self.vision_tower is not None

        pixel_values = image_input['pixel_values']
        pixel_mask = image_input['pixel_mask']

        patch_attention_mask = self._create_patch_attention_mask(pixel_mask)

        image_outputs = self.vision_tower(
            pixel_values=pixel_values,
            patch_attention_mask=patch_attention_mask,
        )
        image_attn_mask = None
        if patch_attention_mask is not None:
            # After inversion the mask is True at patches with no valid
            # pixels; presumably the cross-attention treats True as "masked
            # out" — confirm against AriaCrossAttention.
            flattened_mask = patch_attention_mask.flatten(1)
            image_attn_mask = torch.logical_not(flattened_mask)

        return self.multi_modal_projector(image_outputs, image_attn_mask)

    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

    def get_multimodal_embeddings(self,
                                  **kwargs: object) -> MultiModalEmbeddings:
        """Return projected image embeddings, or ``[]`` if no images given."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []
        multimodal_embeddings = self._process_image_input(image_input)
        return multimodal_embeddings

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        if inputs_embeds is None:
            multimodal_embeddings = self.get_multimodal_embeddings(**kwargs)
            # Image embeddings go at positions holding the image token id.
            inputs_embeds = self.get_input_embeddings(
                input_ids,
                multimodal_embeddings,
                is_multimodal=input_ids == self.config.image_token_index,
            )
            input_ids = None

        hidden_states = self.language_model(
            input_ids,
            positions,
            intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )

        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        # NOTE(review): the set returned by AutoWeightsLoader.load_weights is
        # discarded, so this method returns None; confirm that is intended.
        loader = AutoWeightsLoader(self)
        loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)