aria.py 23.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from collections.abc import Iterable, Mapping, Sequence
4
from typing import Annotated, Optional, Union
5
6
7

import torch
import torch.nn as nn
8
9
10
from transformers import AriaConfig, AriaTextConfig, BatchFeature
from transformers.models.aria.modeling_aria import AriaCrossAttention
from transformers.models.aria.processing_aria import AriaProcessor
11
12
13
14
15
16
17
18
19
20
21
22

from vllm.config import CacheConfig, QuantizationConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_rank
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader, maybe_remap_kv_scale_name)
from vllm.multimodal import MULTIMODAL_REGISTRY
23
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
24
                                    MultiModalKwargsItems)
25
from vllm.multimodal.parse import MultiModalDataItems
26
from vllm.multimodal.processing import (BaseMultiModalProcessor,
27
28
                                        BaseProcessingInfo, PromptReplacement,
                                        PromptUpdate)
29
from vllm.multimodal.profiling import BaseDummyInputsBuilder
30
from vllm.sequence import IntermediateTensors
31
from vllm.utils.tensor_schema import TensorSchema, TensorShape
32

33
# yapf: disable
34
from .idefics2_vision_model import Idefics2VisionConfig
35
36
37
from .idefics2_vision_model import (
    Idefics2VisionTransformer as Idefics3VisionTransformer)
# yapf: enable
38
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsQuant
39
40
from .llama import LlamaDecoderLayer, LlamaMLP, LlamaModel
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
41
                    is_pp_missing_parameter, maybe_prefix)
42
43


44
class AriaImagePixelInputs(TensorSchema):
    """
    Pixel inputs for the Aria vision tower. The batch and per-prompt image
    dimensions are already flattened into one leading dimension.

    Dimensions:
        - bn: Batch size * number of images
        - h: Height of each image
        - w: Width of each image

    ``pixel_values`` always has exactly 3 channels. ``pixel_mask`` may be
    ``None``; when present it shares the leading and spatial dimensions of
    ``pixel_values`` but has no channel dimension.
    """

    pixel_values: Annotated[
        torch.Tensor,
        TensorShape("bn", 3, "h", "w"),
    ]

    pixel_mask: Annotated[
        Optional[torch.Tensor],
        TensorShape("bn", "h", "w"),
    ]

64

65
66
class AriaVisionTransformer(Idefics3VisionTransformer, SupportsQuant):
    """Vision tower for Aria.

    Reuses the Idefics vision transformer (imported here under the
    ``Idefics3VisionTransformer`` alias), with two differences:

    * the final post-LayerNorm is replaced by ``nn.Identity``;
    * separate ``q_proj``/``k_proj``/``v_proj`` checkpoint tensors are loaded
      into the fused ``qkv_proj`` parameter.
    """

    packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}

    def __init__(
        self,
        config: Idefics2VisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__(config, quant_config=quant_config, prefix=prefix)
        # Aria applies no normalization after the last encoder layer, so the
        # parent's post-LayerNorm becomes a pass-through.
        self.post_layernorm = nn.Identity()

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Copy checkpoint tensors into this module's parameters.

        A tensor whose name contains ``q_proj``/``k_proj``/``v_proj`` goes
        into the matching shard of the fused ``qkv_proj`` parameter. Any other
        tensor is loaded with its parameter's ``weight_loader``, falling back
        to ``default_weight_loader``. Tensors under ``post_layernorm`` are
        skipped.

        Returns:
            Names of the parameters that received data.
        """
        qkv_shards = (
            # (checkpoint name fragment, fused parameter name, shard id)
            ("q_proj", "qkv_proj", "q"),
            ("k_proj", "qkv_proj", "k"),
            ("v_proj", "qkv_proj", "v"),
        )
        named_params = dict(self.named_parameters())
        loaded: set[str] = set()

        for ckpt_name, tensor in weights:
            # post_layernorm was replaced by nn.Identity and has no params.
            if "post_layernorm" in ckpt_name:
                continue

            # First matching shard, in q -> k -> v order.
            shard = next(
                ((fragment, fused, shard_id)
                 for fragment, fused, shard_id in qkv_shards
                 if fragment in ckpt_name),
                None,
            )

            if shard is None:
                target = named_params[ckpt_name]
                loader = getattr(target, "weight_loader",
                                 default_weight_loader)
                loader(target, tensor)
                loaded.add(ckpt_name)
                continue

            fragment, fused, shard_id = shard
            fused_name = ckpt_name.replace(fragment, fused)
            target = named_params[fused_name]
            target.weight_loader(target, tensor, shard_id)
            loaded.add(fused_name)

        return loaded


113
class AriaProjectorMLP(nn.Module):
    """Feed-forward head of the Aria projector.

    Computes ``linear_out(gelu_new(linear_in(x)))``. ``linear_in`` is
    column-parallel, ``linear_out`` is row-parallel, and neither has a bias.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        output_dim: int,
    ) -> None:
        super().__init__()

        self.linear_in = ColumnParallelLinear(
            in_features, hidden_features, bias=False)
        self.linear_out = RowParallelLinear(
            hidden_features, output_dim, bias=False)
        self.act = get_act_fn("gelu_new")

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the two projections with the activation in between."""
        # Each parallel linear returns a 2-tuple; only the first item is used.
        projected, _unused = self.linear_in(hidden_states)
        activated = self.act(projected)
        result, _unused = self.linear_out(activated)
        return result


class AriaProjector(nn.Module):
    """
    Projects vision-tower features into the language model's hidden space.

    A bank of learned queries cross-attends to the vision features. The
    result then passes through a LayerNorm and a feed-forward layer.

    Args:
        config: [AriaConfig](https://huggingface.co/docs/transformers/main/model_doc/aria#transformers.AriaConfig)
            containing projector configuration parameters.

    Outputs:
        A tensor with the shape of (batch_size, query_number, output_dim)
    """

    def __init__(self, config: AriaConfig) -> None:
        super().__init__()

        vision_cfg = config.vision_config
        text_cfg = config.text_config

        # Maps an input patch count to how many learned queries to use.
        self.patch_to_query_dict = config.projector_patch_to_query_dict
        self.in_features = vision_cfg.hidden_size
        self.num_heads = vision_cfg.num_attention_heads
        self.kv_dim = vision_cfg.hidden_size
        self.hidden_features = text_cfg.hidden_size
        self.output_dim = text_cfg.hidden_size

        # One shared bank of queries; each call uses a prefix of it.
        self.query = nn.Parameter(
            torch.empty(config.max_value_projector_patch_to_query_dict,
                        self.in_features))

        self.cross_attn = AriaCrossAttention(config)

        self.layer_norm = nn.LayerNorm(self.in_features)
        self.feed_forward = AriaProjectorMLP(self.in_features,
                                             self.hidden_features,
                                             self.output_dim)

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Project vision features ``x`` through the learned queries.

        Args:
            x: Vision features. Dim 0 is the batch and dim 1 the patch count.
                Presumably the shape is (batch, num_patches, in_features);
                confirm.
            attn_mask: Optional per-patch mask with one row per batch
                element. It is repeated per attention head and broadcast
                across queries before being passed to the cross-attention.

        Raises:
            KeyError: If the patch count has no entry in
                ``patch_to_query_dict``.
        """
        batch_size = x.shape[0]
        num_patches = x.shape[1]

        if num_patches not in self.patch_to_query_dict:
            raise KeyError(f"Number of patches {num_patches} not found in "
                           "patch_to_query_dict amongst possible values "
                           f"{self.patch_to_query_dict.keys()}.")
        query_num = self.patch_to_query_dict[num_patches]

        # Copy the first `query_num` learned queries for every batch element.
        queries = self.query[:query_num].unsqueeze(0).repeat(batch_size, 1, 1)

        if attn_mask is not None:
            # One mask row per head along dim 0, then expanded over queries.
            attn_mask = (attn_mask.repeat_interleave(self.num_heads, 0)
                         .unsqueeze(1)
                         .expand(-1, queries.size(1), -1))

        attended = self.cross_attn(x, queries, attn_mask=attn_mask)
        return self.feed_forward(self.layer_norm(attended))


class AriaFusedMoE(FusedMoE):
    """``FusedMoE`` with a weight loader for Aria's pre-stacked experts.

    Aria checkpoints keep every expert of a projection in a single tensor.
    ``fc1`` holds the two merged halves of the first projection; ``fc2`` holds
    the second projection. This loader slices out the current tensor-parallel
    rank's part and transposes it into the parameter layout.

    Note: Loading expert weights with quantization is not supported.
    """

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
                      shard_id: str) -> None:
        """Copy this rank's slice of a stacked expert tensor into ``param``.

        Args:
            param: Destination expert parameter.
            loaded_weight: Checkpoint tensor containing all experts.
            shard_id: ``'w13'`` for the merged fc1 weights or ``'w2'`` for
                fc2. Any other value is silently ignored.
        """
        tp_rank = get_tensor_model_parallel_rank()

        if shard_id == 'w13':
            # loaded_weight: (num_experts, hidden_size,
            #                 2 * moe_intermediate_size)
            if self.tp_size > 1:
                # Split each half of the last dim across ranks separately,
                # so every rank keeps matching slices of both halves.
                first_half, second_half = loaded_weight.chunk(2, dim=-1)
                local_weight = torch.cat(
                    [
                        first_half.chunk(self.tp_size, dim=-1)[tp_rank],
                        second_half.chunk(self.tp_size, dim=-1)[tp_rank],
                    ],
                    dim=-1,
                )
            else:
                local_weight = loaded_weight
        elif shard_id == 'w2':
            # loaded_weight: (num_experts, moe_intermediate_size, hidden_size)
            if self.tp_size > 1:
                local_weight = loaded_weight.chunk(self.tp_size,
                                                   dim=1)[tp_rank]
            else:
                local_weight = loaded_weight
        else:
            return

        param.data.copy_(local_weight.transpose(1, 2))


231
class AriaTextMoELayer(nn.Module):
    """
    Mixture of Experts (MoE) Layer for the AriaMoE model.

    A learned router produces one logit per expert. The top
    ``config.moe_topk`` of the ``config.moe_num_experts`` sparse experts
    process each token. Their output is added to that of an always-active
    shared-expert MLP.
    """

    def __init__(
        self,
        config: AriaTextConfig,
        quant_config: Optional[QuantizationConfig],
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config

        # Router projection weights: one row per expert.
        self.router_weight = nn.Parameter(
            torch.empty(
                (self.config.moe_num_experts, self.config.hidden_size)))

        self.experts = AriaFusedMoE(
            num_experts=config.moe_num_experts,
            top_k=config.moe_topk,
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            quant_config=quant_config,
            reduce_results=True,
            prefix=f"{prefix}.experts",
        )
        # Give the shared experts their fully qualified prefix, as for
        # `self.experts`. Quantization configs that select or skip layers by
        # name then see the real module path instead of a bare sub-layer name.
        self.shared_experts = LlamaMLP(
            config.hidden_size,
            config.intermediate_size * config.moe_num_shared_experts,
            "silu",
            quant_config=quant_config,
            bias=config.mlp_bias,
            prefix=f"{prefix}.shared_experts",
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the MoE Layer.

        Args:
            hidden_states: Token hidden states whose last dimension is
                ``hidden_size``. Presumably the shape is
                (num_tokens, hidden_size), as vLLM decoder layers use;
                confirm.

        Returns:
            torch.Tensor: Sum of the routed-expert and shared-expert outputs.
        """
        router_logits = torch.nn.functional.linear(hidden_states,
                                                   self.router_weight)

        # NOTE: hidden_states will be modified inplace by `FusedMoE`, so the
        # shared experts get a copy taken before the routed experts run.
        hidden_states_copy = hidden_states.clone()
        sparse_expert_output = self.experts(hidden_states, router_logits)
        shared_expert_output = self.shared_experts(hidden_states_copy)

        return sparse_expert_output + shared_expert_output


293
class AriaTextDecoderLayer(LlamaDecoderLayer):
    """
    ``LlamaDecoderLayer`` whose dense MLP is replaced by Aria's
    Mixture-of-Experts layer (``AriaTextMoELayer``).

    Only ``self.mlp`` is overridden; everything else comes from the parent.
    """

    def __init__(
        self,
        config: AriaTextConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__(config, cache_config, quant_config, prefix)

        # Swap the MLP built by the parent constructor for the MoE layer.
        moe_prefix = f"{prefix}.mlp"
        self.mlp = AriaTextMoELayer(config,
                                    quant_config=quant_config,
                                    prefix=moe_prefix)
311
312


313
class AriaTextModel(LlamaModel, SupportsQuant):
    """
    Llama backbone for Aria whose layers are ``AriaTextDecoderLayer``
    (MoE) instead of the standard ``LlamaDecoderLayer``.
    """

    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
        "experts.w13_weight": ["experts.fc1.weight"],
        "experts.w2_weight": ["experts.fc2.weight"],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(
            vllm_config=vllm_config,
            prefix=prefix,
            layer_type=AriaTextDecoderLayer,
        )

    # Adapted from LlamaModel.load_weights; the stacked mapping additionally
    # routes the pre-stacked expert tensors (fc1 -> w13, fc2 -> w2).
    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this model.

        Returns:
            Names of the parameters (including kv-cache scale parameters)
            that received data.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
            ("experts.w13_weight", "experts.fc1.weight", 'w13'),
            ("experts.w2_weight", "experts.fc2.weight", 'w2'),
        ]
        # Rotary-embedding tensors that are never loaded. Models trained
        # using ColossalAI may include the cos/sin caches in the checkpoint.
        rotary_tensors = ("rotary_emb.inv_freq", "rotary_emb.cos_cached",
                          "rotary_emb.sin_cached")
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            if any(fragment in name for fragment in rotary_tensors):
                continue

            scale_name = (self.quant_config.get_cache_scale(name)
                          if self.quant_config is not None else None)
            if scale_name:
                # Loading kv cache quantization scales
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                if loaded_weight.dim() != 0:
                    loaded_weight = loaded_weight[0]
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue

            # NOTE: a `continue` inside this loop advances to the next mapping
            # entry, not the next weight. If no later entry matches the
            # rewritten name, the loop ends without `break`. The `else` branch
            # then re-applies the same bias / pipeline-parallel checks, which
            # skip the weight.
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                param.weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None or is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                load_fn = getattr(param, "weight_loader",
                                  default_weight_loader)
                load_fn(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


400
class AriaProcessingInfo(BaseProcessingInfo):
    """Exposes Aria's HF config, vision config, processor and token counts."""

    def get_hf_config(self):
        return self.ctx.get_hf_config(AriaConfig)

    def get_vision_config(self):
        hf_config = self.get_hf_config()
        return hf_config.vision_config

    def get_hf_processor(self, **kwargs: object):
        return self.ctx.get_hf_processor(AriaProcessor, **kwargs)

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        # Only images are supported; no per-prompt count limit is set here.
        return {"image": None}

    def get_num_image_tokens(self) -> int:
        """Return the largest query count in the projector's
        patch-to-query map, used as the per-image placeholder length."""
        query_counts = self.get_hf_config().projector_patch_to_query_dict
        return max(query_counts.values())


class AriaDummyInputsBuilder(BaseDummyInputsBuilder[AriaProcessingInfo]):
    """Builds dummy text and image inputs for Aria profiling runs."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one image token per requested dummy image."""
        tokenizer = self.info.get_hf_processor().tokenizer
        image_token: str = tokenizer.image_token  # type: ignore
        return image_token * mm_counts.get("image", 0)

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> MultiModalDataDict:
        """Return ``mm_counts["image"]`` square dummy images. Their side
        length is the vision config's ``image_size``."""
        side = self.info.get_vision_config().image_size
        dummy_images = self._get_dummy_images(
            width=side,
            height=side,
            num_images=mm_counts.get("image", 0),
        )
        return {"image": dummy_images}


447
class AriaMultiModalProcessor(BaseMultiModalProcessor[AriaProcessingInfo]):
    """Multimodal processor for Aria.

    Declares the per-image tensor fields and expands each image token to the
    fixed placeholder length reported by ``AriaProcessingInfo``.
    """

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        # Both tensors are batched with one entry per image.
        return {
            "pixel_values": MultiModalFieldConfig.batched("image"),
            "pixel_mask": MultiModalFieldConfig.batched("image"),
        }

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        image_token_id = self.info.get_hf_config().image_token_index
        num_image_tokens = self.info.get_num_image_tokens()

        # Each single image token becomes a run of `num_image_tokens` copies.
        expand_image = PromptReplacement(
            modality="image",
            target=[image_token_id],
            replacement=[image_token_id] * num_image_tokens,
        )
        return [expand_image]
477
478


479
480
481
@MULTIMODAL_REGISTRY.register_processor(AriaMultiModalProcessor,
                                        info=AriaProcessingInfo,
                                        dummy_inputs=AriaDummyInputsBuilder)
class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
    """
    Aria model for conditional generation tasks.

    This model combines a vision tower, a multi-modal projector, and a language
    model to perform tasks that involve both image and text inputs.
    """

    # Renames checkpoint tensors to this module's parameter names. It covers
    # both the layout saved after transformers v4.52 and the original one;
    # the MoE router weight is renamed by suffix.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            # mapping for new names in checkpoint saved after transformers v4.52
            "model.language_model.": "language_model.model.",
            "model.vision_tower.": "vision_tower.",
            "model.multi_modal_projector.": "multi_modal_projector.",
            # mapping for original checkpoint
            "language_model.model": "language_model",
            "language_model.lm_head": "lm_head",
        },
        orig_to_new_suffix={
            "router.weight": "router_weight",
        },
    )

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
        """Return the prompt placeholder for one item of ``modality``.

        Raises:
            ValueError: If ``modality`` does not start with ``"image"``.
        """
        if modality.startswith("image"):
            return "<|fim_prefix|><|img|><|fim_suffix|>"

        raise ValueError("Only image modality is supported")

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config

        self.config = config
        # Use maybe_prefix like the other submodules. With an empty prefix
        # this gives "vision_tower" rather than ".vision_tower".
        self.vision_tower = AriaVisionTransformer(
            config.vision_config,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "vision_tower"),
        )
        self.multi_modal_projector = AriaProjector(config)
        self.vocab_size = config.text_config.vocab_size
        self.language_model = AriaTextModel(
            vllm_config=vllm_config.with_hf_config(config.text_config),
            prefix=maybe_prefix(prefix, "language_model.model"),
        )
        self.pad_token_id = (self.config.pad_token_id
                             if self.config.pad_token_id is not None else -1)
        self.unpadded_vocab_size = config.text_config.vocab_size
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.text_config.hidden_size,
            org_num_embeddings=self.language_model.org_vocab_size,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        logit_scale = getattr(config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                self.vocab_size, logit_scale)

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[AriaImagePixelInputs]:
        """Pop ``pixel_values`` / ``pixel_mask`` from ``kwargs`` and flatten
        their batch and image dimensions.

        Returns:
            ``None`` when no ``pixel_values`` were passed. Otherwise the
            validated inputs; ``pixel_mask`` stays ``None`` when absent.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        pixel_mask = kwargs.pop("pixel_mask", None)

        if pixel_values is None:
            return None

        flat_pixel_values = flatten_bn(pixel_values, concat=True)
        # The mask is optional (both the schema and
        # _create_patch_attention_mask accept None). Only flatten it when it
        # was provided rather than handing None to flatten_bn.
        flat_pixel_mask = (flatten_bn(pixel_mask, concat=True)
                           if pixel_mask is not None else None)

        return AriaImagePixelInputs(
            pixel_values=flat_pixel_values,
            pixel_mask=flat_pixel_mask,
        )

    def _create_patch_attention_mask(
            self,
            pixel_mask: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
        """Reduce a per-pixel mask to a boolean per-patch mask.

        Dims 1 and 2 of ``pixel_mask`` are tiled into non-overlapping
        ``patch_size`` x ``patch_size`` windows. A patch is True when the sum
        of its window is positive, i.e. it contains at least one valid pixel
        for a 0/1 mask. Returns ``None`` when ``pixel_mask`` is ``None``.
        """
        if pixel_mask is None:
            return None

        patch_size = self.vision_tower.config.patch_size
        patches_subgrid = pixel_mask.unfold(
            dimension=1,
            size=patch_size,
            step=patch_size,
        ).unfold(
            dimension=2,
            size=patch_size,
            step=patch_size,
        )
        return (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()

    def _process_image_input(
        self, image_input: AriaImagePixelInputs
    ) -> torch.Tensor:
        """Run the vision tower and the projector on validated pixel inputs.

        Returns:
            The projector output tensor.
        """
        assert self.vision_tower is not None

        pixel_values = image_input['pixel_values']
        pixel_mask = image_input['pixel_mask']

        patch_attention_mask = self._create_patch_attention_mask(pixel_mask)

        image_outputs = self.vision_tower(
            pixel_values=pixel_values,
            patch_attention_mask=patch_attention_mask,
        )

        # Flatten the patch grid and invert it, so the projector mask is
        # True at patches that contain no valid pixels.
        image_attn_mask = None
        if patch_attention_mask is not None:
            flattened_mask = patch_attention_mask.flatten(1)
            image_attn_mask = torch.logical_not(flattened_mask)

        return self.multi_modal_projector(image_outputs, image_attn_mask)

    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

    def get_multimodal_embeddings(self,
                                  **kwargs: object) -> MultiModalEmbeddings:
        """Encode the images in ``kwargs``; return ``[]`` when there are
        none."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []
        multimodal_embeddings = self._process_image_input(image_input)
        return multimodal_embeddings

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the language model.

        When ``inputs_embeds`` is not given, it is built here: image
        embeddings are computed from ``kwargs`` and merged at positions where
        ``input_ids`` equals ``config.image_token_index``.
        """
        if inputs_embeds is None:
            multimodal_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(
                input_ids,
                multimodal_embeddings,
                is_multimodal=input_ids == self.config.image_token_index,
            )
            input_ids = None

        hidden_states = self.language_model(
            input_ids,
            positions,
            intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )

        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load all submodule weights after renaming via hf_to_vllm_mapper."""
        loader = AutoWeightsLoader(self)
        loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)