aria.py 23.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from collections.abc import Iterable, Mapping, Sequence
4
from typing import Annotated, Literal, Optional, Union
5
6
7

import torch
import torch.nn as nn
8
9
10
from transformers import AriaConfig, AriaTextConfig, BatchFeature
from transformers.models.aria.modeling_aria import AriaCrossAttention
from transformers.models.aria.processing_aria import AriaProcessor
11

12
from vllm.config import VllmConfig
13
from vllm.config.multimodal import BaseDummyOptions
14
15
16
from vllm.distributed import get_tensor_model_parallel_rank
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.fused_moe import FusedMoE
17
from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
18
from vllm.model_executor.layers.logits_processor import LogitsProcessor
19
from vllm.model_executor.layers.quantization import QuantizationConfig
20
21
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import (
22
23
24
    default_weight_loader,
    maybe_remap_kv_scale_name,
)
25
from vllm.multimodal import MULTIMODAL_REGISTRY
26
27
28
29
30
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
)
31
from vllm.multimodal.parse import MultiModalDataItems
32
33
34
35
36
37
from vllm.multimodal.processing import (
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptReplacement,
    PromptUpdate,
)
38
from vllm.multimodal.profiling import BaseDummyInputsBuilder
39
from vllm.sequence import IntermediateTensors
40
from vllm.utils.tensor_schema import TensorSchema, TensorShape
41

42
from .idefics2_vision_model import Idefics2VisionConfig
43
from .idefics2_vision_model import (
44
45
    Idefics2VisionTransformer as Idefics3VisionTransformer,
)
46
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsQuant
47
from .llama import LlamaDecoderLayer, LlamaMLP, LlamaModel
48
49
50
51
52
53
from .utils import (
    AutoWeightsLoader,
    WeightsMapper,
    is_pp_missing_parameter,
    maybe_prefix,
)
54
55


56
class AriaImagePixelInputs(TensorSchema):
    """
    Pixel inputs for the Aria vision tower.

    Dimensions:
        - bn: Batch size * number of images (a single combined leading
          dimension shared by both tensors)
        - h: Height of each image
        - w: Width of each image

    The channel dimension of `pixel_values` is fixed to 3 by the schema.
    `pixel_mask` is optional; when present it has the same `bn`, `h` and `w`
    dimensions as `pixel_values` and no channel dimension.
    """

    type: Literal["pixel_values"]

    pixel_values: Annotated[
        torch.Tensor,
        TensorShape("bn", 3, "h", "w"),
    ]

    pixel_mask: Annotated[
        Optional[torch.Tensor],
        TensorShape("bn", "h", "w"),
    ]

78

79
80
class AriaVisionTransformer(Idefics3VisionTransformer, SupportsQuant):
    """Idefics-style vision tower used by Aria, without a final LayerNorm."""

    # Fused vLLM parameter -> checkpoint modules it is assembled from.
    packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}

    def __init__(
        self,
        config: Idefics2VisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__(config, quant_config=quant_config, prefix=prefix)
        # Idefics3VisionTransformer normalizes after its last layer; Aria does
        # not, so the inherited post_layernorm becomes a no-op.
        self.post_layernorm = nn.Identity()

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this vision tower.

        Separate q/k/v projection tensors are routed into the fused
        ``qkv_proj`` parameter with the matching shard id. Every other tensor
        goes through its parameter's ``weight_loader`` attribute, falling back
        to ``default_weight_loader``. Tensors whose name contains
        ``post_layernorm`` are skipped.

        Returns:
            The set of parameter names that received a tensor.
        """
        # (fused param name, checkpoint name fragment, shard id)
        qkv_shards = (
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        )
        own_params = dict(self.named_parameters())
        seen: set[str] = set()

        for ckpt_name, tensor in weights:
            # post_layernorm is an nn.Identity here, so nothing can receive it.
            if "post_layernorm" in ckpt_name:
                continue

            # First matching q/k/v entry, in table order, or None.
            shard = next(
                (
                    (fused, piece, sid)
                    for fused, piece, sid in qkv_shards
                    if piece in ckpt_name
                ),
                None,
            )
            if shard is None:
                target = own_params[ckpt_name]
                loader = getattr(target, "weight_loader", default_weight_loader)
                loader(target, tensor)
            else:
                fused, piece, sid = shard
                ckpt_name = ckpt_name.replace(piece, fused)
                target = own_params[ckpt_name]
                target.weight_loader(target, tensor, sid)
            seen.add(ckpt_name)

        return seen


124
class AriaProjectorMLP(nn.Module):
    """Feed-forward head of the projector: linear -> GELU ("gelu_new") -> linear.

    The first projection is a ColumnParallelLinear and the second a
    RowParallelLinear; neither has a bias.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        output_dim: int,
    ) -> None:
        super().__init__()
        self.linear_in = ColumnParallelLinear(in_features, hidden_features, bias=False)
        self.linear_out = RowParallelLinear(hidden_features, output_dim, bias=False)
        self.act = get_act_fn("gelu_new")

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply linear_in, the activation, then linear_out.

        The parallel linear layers return a tuple; only its first element
        (the output tensor) is used.
        """
        projected, _ = self.linear_in(hidden_states)
        out, _ = self.linear_out(self.act(projected))
        return out


class AriaProjector(nn.Module):
    """
    Projects ViT outputs into the inputs expected by the MoE language model.

    A bank of learned queries cross-attends over the vision features. The
    result is layer-normalized and passed through a two-layer MLP.

    Args:
        config: [AriaConfig](https://huggingface.co/docs/transformers/main/model_doc/aria#transformers.AriaConfig)
            containing projector configuration parameters.

    Outputs:
        A tensor with the shape of (batch_size, query_number, output_dim)
    """

    def __init__(self, config: AriaConfig) -> None:
        super().__init__()

        vision_cfg = config.vision_config
        text_cfg = config.text_config

        # Maps a patch count to the number of learned queries used for it.
        self.patch_to_query_dict = config.projector_patch_to_query_dict
        self.in_features = vision_cfg.hidden_size
        self.num_heads = vision_cfg.num_attention_heads
        self.kv_dim = vision_cfg.hidden_size
        self.hidden_features = text_cfg.hidden_size
        self.output_dim = text_cfg.hidden_size

        # One row per learned query. forward() uses a prefix of these rows
        # whose length depends on the number of input patches.
        max_queries = config.max_value_projector_patch_to_query_dict
        self.query = nn.Parameter(torch.empty(max_queries, self.in_features))

        self.cross_attn = AriaCrossAttention(config)
        self.layer_norm = nn.LayerNorm(self.in_features)
        self.feed_forward = AriaProjectorMLP(
            self.in_features, self.hidden_features, self.output_dim
        )

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Map vision features ``x`` to language-model input features.

        Args:
            x: Vision features. Dim 0 is the batch and dim 1 the patch count,
                which must be a key of ``patch_to_query_dict``.
            attn_mask: Optional mask. It is repeated ``num_heads`` times
                along dim 0 and broadcast across the query axis before being
                handed to the cross-attention.

        Raises:
            KeyError: If the patch count has no entry in ``patch_to_query_dict``.
        """
        batch_size, num_patches = x.shape[:2]

        if num_patches not in self.patch_to_query_dict:
            raise KeyError(
                f"Number of patches {num_patches} not found in "
                "patch_to_query_dict amongst possible values "
                f"{self.patch_to_query_dict.keys()}."
            )

        learned = self.query[: self.patch_to_query_dict[num_patches]]
        queries = learned.unsqueeze(0).repeat(batch_size, 1, 1)

        if attn_mask is not None:
            per_head = attn_mask.repeat_interleave(self.num_heads, 0)
            attn_mask = per_head.unsqueeze(1).expand(-1, queries.size(1), -1)

        attended = self.cross_attn(x, queries, attn_mask=attn_mask)
        return self.feed_forward(self.layer_norm(attended))


class AriaFusedMoE(FusedMoE):
    """FusedMoE whose weight loader handles Aria's pre-packed expert tensors."""

    def weight_loader(
        self, param: nn.Parameter, loaded_weight: torch.Tensor, shard_id: str
    ) -> None:
        """Copy one packed all-experts tensor into ``param``.

        The tensor is sharded for tensor parallelism first.

        - ``"w13"``: the last axis is split into two halves. Each half is
          chunked across TP ranks separately, and this rank's two slices are
          concatenated again.
        - ``"w2"``: the tensor is chunked across TP ranks along dim 1.

        In both cases the last two axes are swapped before the copy. Any
        other ``shard_id`` is ignored. Loading expert weights with
        quantization is not supported.
        """
        rank = get_tensor_model_parallel_rank()
        sharded = self.tp_size > 1

        if shard_id == "w13":
            # loaded_weight: (num_experts, hidden_size, 2 * moe_intermediate_size)
            if not sharded:
                param.data.copy_(loaded_weight.transpose(1, 2))
                return
            first_half, second_half = loaded_weight.chunk(2, dim=-1)
            local_slices = [
                half.chunk(self.tp_size, dim=-1)[rank]
                for half in (first_half, second_half)
            ]
            param.data.copy_(torch.cat(local_slices, dim=-1).transpose(1, 2))
        elif shard_id == "w2":
            # loaded_weight: (num_experts, moe_intermediate_size, hidden_size)
            if not sharded:
                param.data.copy_(loaded_weight.transpose(1, 2))
                return
            local_slice = loaded_weight.chunk(self.tp_size, dim=1)[rank]
            param.data.copy_(local_slice.transpose(1, 2))


241
class AriaTextMoELayer(nn.Module):
    """
    Mixture-of-Experts feed-forward layer for the Aria text decoder.

    A bias-free linear router scores each token. The routed experts
    (``AriaFusedMoE`` with ``top_k=config.moe_topk``) process the tokens, and
    their output is added to the output of a shared MLP that sees every token.
    """

    def __init__(
        self,
        config: AriaTextConfig,
        quant_config: Optional[QuantizationConfig],
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config

        num_experts = config.moe_num_experts
        hidden = config.hidden_size

        # Router weight stored as (num_experts, hidden_size) for F.linear.
        self.router_weight = nn.Parameter(torch.empty((num_experts, hidden)))

        self.experts = AriaFusedMoE(
            num_experts=num_experts,
            top_k=config.moe_topk,
            hidden_size=hidden,
            intermediate_size=config.intermediate_size,
            quant_config=quant_config,
            reduce_results=True,
            prefix=f"{prefix}.experts",
        )
        # All shared experts are realized as one MLP whose intermediate width
        # is multiplied by the number of shared experts.
        self.shared_experts = LlamaMLP(
            hidden,
            config.intermediate_size * config.moe_num_shared_experts,
            "silu",
            quant_config=quant_config,
            bias=config.mlp_bias,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Run the routed and shared experts and return their sum.

        Args:
            hidden_states: Input activations whose last dimension is
                ``hidden_size``.

        Returns:
            torch.Tensor: Routed-expert output plus shared-expert output.
        """
        router_logits = torch.nn.functional.linear(hidden_states, self.router_weight)
        # FusedMoE modifies hidden_states in place, so the shared experts get
        # a copy taken before the routed experts run.
        shared_input = hidden_states.clone()
        routed_out = self.experts(hidden_states, router_logits)
        shared_out = self.shared_experts(shared_input)
        return routed_out + shared_out


302
class AriaTextDecoderLayer(LlamaDecoderLayer):
    """
    Llama decoder layer whose ``mlp`` is replaced by an ``AriaTextMoELayer``.
    """

    def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__(vllm_config, prefix)
        # Replace whatever `mlp` the parent constructor built (presumably the
        # dense Llama MLP) with the Mixture-of-Experts layer.
        self.mlp = AriaTextMoELayer(
            vllm_config.model_config.hf_config,
            quant_config=vllm_config.quant_config,
            prefix=f"{prefix}.mlp",
        )
318
319


320
class AriaTextModel(LlamaModel, SupportsQuant):
    """
    Llama backbone for Aria whose decoder layers are ``AriaTextDecoderLayer``
    (a Mixture-of-Experts feed-forward instead of the dense MLP).
    """

    # Fused vLLM parameter -> checkpoint tensors it is assembled from.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
        "experts.w13_weight": ["experts.fc1.weight"],
        "experts.w2_weight": ["experts.fc2.weight"],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(
            vllm_config=vllm_config, prefix=prefix, layer_type=AriaTextDecoderLayer
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors, including Aria's packed expert weights.

        Adapted from ``LlamaModel.load_weights``. The only addition is the two
        expert rows in the stacked-parameter table. Their shard ids
        ("w13"/"w2") are consumed by ``AriaFusedMoE.weight_loader``.

        Returns:
            Names of the parameters that were loaded.
        """
        # (vLLM param name fragment, checkpoint name fragment, shard id)
        stacked = (
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
            ("experts.w13_weight", "experts.fc1.weight", "w13"),
            ("experts.w2_weight", "experts.fc2.weight", "w2"),
        )
        ignored_markers = (
            "rotary_emb.inv_freq",
            # Models trained using ColossalAI may include these tensors in the
            # checkpoint.
            "rotary_emb.cos_cached",
            "rotary_emb.sin_cached",
        )
        by_name = dict(self.named_parameters())
        done: set[str] = set()

        for name, tensor in weights:
            if any(marker in name for marker in ignored_markers):
                continue

            scale_name = (
                self.quant_config.get_cache_scale(name)
                if self.quant_config is not None
                else None
            )
            if scale_name:
                # KV-cache quantization scale: load a scalar (or element 0).
                scale_param = by_name[scale_name]
                scale_loader = getattr(
                    scale_param, "weight_loader", default_weight_loader
                )
                scale_loader(scale_param, tensor if tensor.dim() == 0 else tensor[0])
                done.add(scale_name)
                continue

            for fused, piece, shard_id in stacked:
                if piece not in name:
                    continue
                name = name.replace(piece, fused)
                # A `continue` below moves on to the next table row with the
                # already-rewritten name. If no row breaks, the `else` branch
                # re-checks that rewritten name and skips it the same way.
                # GPTQ checkpoints may carry biases with no matching param.
                if name.endswith(".bias") and name not in by_name:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = by_name[name]
                param.weight_loader(param, tensor, shard_id)
                break
            else:
                # GPTQ checkpoints may carry biases with no matching param.
                if name.endswith(".bias") and name not in by_name:
                    continue
                # Remap FP8 kv-scale names; None means there is no target.
                name = maybe_remap_kv_scale_name(name, by_name)
                if name is None:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = by_name[name]
                loader = getattr(param, "weight_loader", default_weight_loader)
                loader(param, tensor)
            done.add(name)

        return done


406
407
class AriaProcessingInfo(BaseProcessingInfo):
    """Config, processor and token-count lookups for Aria multimodal processing."""

    def get_hf_config(self):
        return self.ctx.get_hf_config(AriaConfig)

    def get_vision_config(self):
        hf_config = self.get_hf_config()
        return hf_config.vision_config

    def get_hf_processor(self, **kwargs: object):
        return self.ctx.get_hf_processor(AriaProcessor, **kwargs)

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        # Only images are supported; None presumably means no fixed per-prompt
        # limit on their number.
        return {"image": None}

    def get_num_image_tokens(self) -> int:
        # Every image is budgeted at the largest query count any patch count
        # maps to in the projector table.
        patch_to_query = self.get_hf_config().projector_patch_to_query_dict
        return max(patch_to_query.values())


class AriaDummyInputsBuilder(BaseDummyInputsBuilder[AriaProcessingInfo]):
    """Builds dummy prompt text and images for Aria."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one tokenizer image token per requested image."""
        tokenizer = self.info.get_hf_processor().tokenizer
        image_token: str = tokenizer.image_token  # type: ignore
        return image_token * mm_counts.get("image", 0)

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
    ) -> MultiModalDataDict:
        """Return square dummy images sized to the vision config's image_size."""
        side = self.info.get_vision_config().image_size
        overrides = mm_options.get("image") if mm_options else None

        dummy_images = self._get_dummy_images(
            width=side,
            height=side,
            num_images=mm_counts.get("image", 0),
            overrides=overrides,
        )
        return {"image": dummy_images}


456
class AriaMultiModalProcessor(BaseMultiModalProcessor[AriaProcessingInfo]):
    """Multimodal processor that expands each image token to a fixed-length run."""

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        # Both image tensors are batched per image item.
        return {
            field_name: MultiModalFieldConfig.batched("image")
            for field_name in ("pixel_values", "pixel_mask")
        }

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        image_token_id = self.info.get_hf_config().image_token_index
        # Each single image token becomes get_num_image_tokens() copies of it.
        expanded = [image_token_id] * self.info.get_num_image_tokens()

        return [
            PromptReplacement(
                modality="image",
                target=[image_token_id],
                replacement=expanded,
            )
        ]
485
486


487
488
489
490
491
@MULTIMODAL_REGISTRY.register_processor(
    AriaMultiModalProcessor,
    info=AriaProcessingInfo,
    dummy_inputs=AriaDummyInputsBuilder,
)
class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
    """
    Aria model for conditional generation tasks.

    This model combines a vision tower, a multi-modal projector, and a language
    model to perform tasks that involve both image and text inputs.
    """

    merge_by_field_config = True

    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            # mapping for new names in checkpoint saved after transformers v4.52
            "model.language_model.": "language_model.model.",
            "model.vision_tower.": "vision_tower.",
            "model.multi_modal_projector.": "multi_modal_projector.",
            # mapping for original checkpoint
            "language_model.model": "language_model",
            "language_model.lm_head": "lm_head",
        },
        orig_to_new_suffix={
            "router.weight": "router_weight",
        },
    )

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
        """Return the prompt placeholder text for an image item.

        Raises:
            ValueError: If ``modality`` does not start with "image".
        """
        if modality.startswith("image"):
            return "<|fim_prefix|><|img|><|fim_suffix|>"

        raise ValueError("Only image modality is supported")

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.vision_tower = AriaVisionTransformer(
            config.vision_config,
            quant_config=quant_config,
            prefix=f"{prefix}.vision_tower",
        )
        self.multi_modal_projector = AriaProjector(config)
        self.vocab_size = config.text_config.vocab_size
        self.language_model = AriaTextModel(
            vllm_config=vllm_config.with_hf_config(config.text_config),
            prefix=maybe_prefix(prefix, "language_model.model"),
        )
        self.pad_token_id = (
            self.config.pad_token_id if self.config.pad_token_id is not None else -1
        )
        self.unpadded_vocab_size = config.text_config.vocab_size
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.text_config.hidden_size,
            org_num_embeddings=self.language_model.org_vocab_size,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        logit_scale = getattr(config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(
            self.unpadded_vocab_size, self.vocab_size, logit_scale
        )

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> Optional[AriaImagePixelInputs]:
        """Pop image tensors from ``kwargs`` and wrap them in the input schema.

        Returns None when no ``pixel_values`` were supplied.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        pixel_mask = kwargs.pop("pixel_mask", None)

        if pixel_values is None:
            return None

        return AriaImagePixelInputs(
            type="pixel_values",
            pixel_values=pixel_values,
            pixel_mask=pixel_mask,
        )

    def _create_patch_attention_mask(
        self,
        pixel_mask: Optional[torch.Tensor],
    ) -> Optional[torch.Tensor]:
        """Reduce a per-pixel mask to a per-patch boolean mask.

        ``pixel_mask`` is tiled into ``patch_size x patch_size`` windows along
        dims 1 and 2. A patch is True if any entry in its window is nonzero.
        Returns None when no mask is given.
        """
        if pixel_mask is None:
            return None

        patch_size = self.vision_tower.config.patch_size
        patches_subgrid = pixel_mask.unfold(
            dimension=1, size=patch_size, step=patch_size
        ).unfold(dimension=2, size=patch_size, step=patch_size)
        # The comparison already yields a bool tensor; no extra cast needed.
        return patches_subgrid.sum(dim=(-1, -2)) > 0

    def _process_image_input(self, image_input: AriaImagePixelInputs) -> torch.Tensor:
        """Encode images with the vision tower and map them through the projector.

        Returns:
            The single tensor produced by ``multi_modal_projector``.
        """
        assert self.vision_tower is not None

        pixel_values = image_input["pixel_values"]
        pixel_mask = image_input["pixel_mask"]

        patch_attention_mask = self._create_patch_attention_mask(pixel_mask)

        image_outputs = self.vision_tower(
            pixel_values=pixel_values,
            patch_attention_mask=patch_attention_mask,
        )
        image_attn_mask = None
        if patch_attention_mask is not None:
            # The projector's mask is the inverse of the patch mask: True
            # marks patches with no nonzero pixel-mask entries.
            flattened_mask = patch_attention_mask.flatten(1)
            image_attn_mask = torch.logical_not(flattened_mask)

        return self.multi_modal_projector(image_outputs, image_attn_mask)

    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

    def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
        """Return projected image embeddings, or [] when no images were passed."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []
        multimodal_embeddings = self._process_image_input(image_input)
        return multimodal_embeddings

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the language model.

        When ``inputs_embeds`` is absent, it is built from ``input_ids``. Image
        embeddings are merged in at positions equal to ``image_token_index``.
        """
        if inputs_embeds is None:
            multimodal_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(
                input_ids,
                multimodal_embeddings,
                is_multimodal=input_ids == self.config.image_token_index,
            )
            input_ids = None

        hidden_states = self.language_model(
            input_ids,
            positions,
            intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )

        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> None:
        """Load all weights, renaming checkpoint keys via ``hf_to_vllm_mapper``."""
        loader = AutoWeightsLoader(self)
        # NOTE(review): the value returned by AutoWeightsLoader.load_weights is
        # discarded here, so callers receive None. Confirm whether returning
        # the loaded-name set (as other models may do) is intended.
        loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)