mixtral.py 22 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

Pierre Stock's avatar
Pierre Stock committed
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Mixtral model."""
26

27
28
import typing
from collections.abc import Callable, Iterable
29
from itertools import islice
Pierre Stock's avatar
Pierre Stock committed
30
31
32

import torch
from torch import nn
33
from transformers import MixtralConfig
Pierre Stock's avatar
Pierre Stock committed
34

35
from vllm.compilation.decorators import support_torch_compile
36
from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
37
38
39
40
41
from vllm.distributed import (
    get_ep_group,
    get_pp_group,
    get_tensor_model_parallel_world_size,
)
42
from vllm.model_executor.layers.attention import Attention
43
44
45
46
from vllm.model_executor.layers.fused_moe import (
    FusedMoE,
    fused_moe_make_expert_params_mapping,
)
Pierre Stock's avatar
Pierre Stock committed
47
from vllm.model_executor.layers.layernorm import RMSNorm
48
49
50
51
52
from vllm.model_executor.layers.linear import (
    QKVParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
53
from vllm.model_executor.layers.logits_processor import LogitsProcessor
54
from vllm.model_executor.layers.quantization import QuantizationConfig
55
from vllm.model_executor.layers.rotary_embedding import get_rope
Pierre Stock's avatar
Pierre Stock committed
56
from vllm.model_executor.layers.vocab_parallel_embedding import (
57
58
59
    ParallelLMHead,
    VocabParallelEmbedding,
)
60
from vllm.model_executor.model_loader.weight_utils import (
61
62
63
    default_weight_loader,
    maybe_remap_kv_scale_name,
)
64
from vllm.sequence import IntermediateTensors
Pierre Stock's avatar
Pierre Stock committed
65

66
from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
67
68
69
70
71
72
73
74
from .utils import (
    AutoWeightsLoader,
    PPMissingLayer,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
75

Pierre Stock's avatar
Pierre Stock committed
76

Philipp Moritz's avatar
Philipp Moritz committed
77
78
79
80
81
82
83
84
class MixtralMoE(nn.Module):
    """A tensor-parallel MoE implementation for Mixtral that shards each expert
    across all ranks.

    Each expert's weights are sharded across all ranks and a fused MoE
    kernel is used for the forward pass, and finally we reduce the outputs
    across ranks.

    Args:
        num_experts: Number of experts; also the output width of the router
            gate and the logical/routed expert count recorded on the module.
        top_k: Forwarded to ``FusedMoE`` as ``top_k``.
        hidden_size: Model hidden size; inputs are flattened to
            ``(-1, hidden_size)`` in ``forward``.
        intermediate_size: Forwarded to ``FusedMoE``.
        params_dtype: Parameter dtype forwarded to the gate and the experts.
        quant_config: Quantization config forwarded to the experts only.
        tp_size: Forwarded to ``FusedMoE``.
        dp_size: Forwarded to ``FusedMoE``.
        prefix: Module-name prefix; ``.gate`` and ``.experts`` are appended.
        enable_eplb: Whether expert-parallel load balancing is enabled;
            stored on the module and forwarded to ``FusedMoE``.
    """

    def __init__(
        self,
        num_experts: int,
        top_k: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        tp_size: int | None = None,
        dp_size: int | None = None,
        prefix: str = "",
        enable_eplb: bool = False,
    ):
        super().__init__()
        self.hidden_size = hidden_size

        self.ep_group = get_ep_group().device_group
        self.ep_rank = get_ep_group().rank_in_group
        self.ep_size = self.ep_group.size()

        # Expert Parallelism Load balancing settings.
        vllm_config = get_current_vllm_config()
        parallel_config = vllm_config.parallel_config
        self.enable_eplb = enable_eplb

        self.n_routed_experts = num_experts
        self.n_logical_experts = num_experts
        self.n_redundant_experts = parallel_config.eplb_config.num_redundant_experts
        # Physical experts = logical experts plus the configured redundant
        # replicas; each EP rank owns a contiguous slice of them.
        self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts
        self.n_local_physical_experts = self.n_physical_experts // self.ep_size
        self.physical_expert_start = self.ep_rank * self.n_local_physical_experts
        self.physical_expert_end = (
            self.physical_expert_start + self.n_local_physical_experts
        )

        # Gate always runs at half / full precision for now.
        # (It is built with quant_config=None regardless of the model's
        # quantization config.)
        self.gate = ReplicatedLinear(
            hidden_size,
            num_experts,
            bias=False,
            params_dtype=params_dtype,
            quant_config=None,
            prefix=f"{prefix}.gate",
        )

        self.experts = FusedMoE(
            num_experts=num_experts,
            top_k=top_k,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            params_dtype=params_dtype,
            renormalize=True,
            quant_config=quant_config,
            tp_size=tp_size,
            dp_size=dp_size,
            prefix=f"{prefix}.experts",
            enable_eplb=self.enable_eplb,
            num_redundant_experts=self.n_redundant_experts,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route tokens through the experts and return the original shape.

        Args:
            hidden_states: Input activations whose trailing dimension is
                ``hidden_size``.

        Returns:
            The expert output viewed back to the shape of ``hidden_states``.
        """
        # NOTE: hidden_states can have either 1D or 2D shape.
        orig_shape = hidden_states.shape
        hidden_states = hidden_states.view(-1, self.hidden_size)
        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(hidden_states)
        final_hidden_states = self.experts(hidden_states, router_logits)
        return final_hidden_states.view(orig_shape)
Pierre Stock's avatar
Pierre Stock committed
155
156
157


class MixtralAttention(nn.Module):
    """Self-attention block for Mixtral with a fused QKV projection.

    Query heads are split evenly across tensor-parallel ranks. KV heads are
    either split (when there are at least as many KV heads as ranks) or
    replicated (when there are fewer).

    Args:
        config: Model config; read for the optional ``head_dim`` and for
            ``rope_parameters``.
        hidden_size: Model hidden size.
        num_heads: Total number of query heads across all ranks.
        num_kv_heads: Total number of key/value heads across all ranks.
        max_position: Maximum position passed to the rotary embedding.
        cache_config: Forwarded to ``Attention``.
        quant_config: Forwarded to the linear layers and to ``Attention``.
        prefix: Module-name prefix for the sub-layers.
    """

    def __init__(
        self,
        config: MixtralConfig,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        max_position: int = 4096 * 32,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than or equal to TP size, so we
            # partition the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        # At least one KV head per rank, even in the replicated case.
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        # MixtralConfig has an optional head_dim argument
        self.head_dim = getattr(config, "head_dim", None)
        if self.head_dim is None:
            self.head_dim = self.hidden_size // self.total_num_heads
        # Per-rank widths of the q and k/v slices of the fused QKV output.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            max_position=max_position,
            rope_parameters=config.rope_parameters,
            is_neox_style=True,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Apply QKV projection, rotary embedding, attention and output proj.

        Args:
            positions: Token positions passed to the rotary embedding.
            hidden_states: Input activations for this layer.

        Returns:
            The output of ``o_proj`` applied to the attention result.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        # Split the fused projection into this rank's q, k and v slices.
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output


class MixtralDecoderLayer(nn.Module):
    """One Mixtral transformer layer: self-attention followed by sparse MoE.

    Args:
        config: Model config supplying hidden size, head counts, maximum
            position embeddings, expert counts, intermediate size and
            RMSNorm epsilon.
        cache_config: Forwarded to the attention block.
        quant_config: Forwarded to the attention and MoE blocks.
        prefix: Module-name prefix; ``.self_attn`` and ``.block_sparse_moe``
            are appended for the sub-modules.
        enable_eplb: Forwarded to the MoE block.
    """

    def __init__(
        self,
        config: MixtralConfig,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        enable_eplb: bool = False,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = MixtralAttention(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            max_position=config.max_position_embeddings,
            num_kv_heads=config.num_key_value_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        self.block_sparse_moe = MixtralMoE(
            num_experts=config.num_local_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            quant_config=quant_config,
            prefix=f"{prefix}.block_sparse_moe",
            enable_eplb=enable_eplb,
        )
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the layer and return the new hidden states and residual.

        Args:
            positions: Token positions forwarded to the attention block.
            hidden_states: Input activations.
            residual: Residual carried from the previous layer, or ``None``
                when there is none yet; in that case the incoming
                ``hidden_states`` become the residual.

        Returns:
            A ``(hidden_states, residual)`` pair: the MoE output and the
            residual to hand to the next layer (or the final norm).
        """
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            # Called with a residual, the norm returns both the normalized
            # states and the updated residual.
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.block_sparse_moe(hidden_states)
        return hidden_states, residual
Pierre Stock's avatar
Pierre Stock committed
294

295

296
@support_torch_compile
class MixtralModel(nn.Module):
    """Mixtral decoder stack: token embedding, decoder layers and final norm.

    Layers are created through ``make_layers``, which also yields the
    ``start_layer``/``end_layer`` bounds that ``forward`` iterates over.

    Args:
        vllm_config: Engine configuration; the HF model config, cache config,
            quantization config and parallel config are read from it.
        prefix: Module-name prefix for the decoder layers.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        parallel_config = vllm_config.parallel_config

        self.config = config
        self.quant_config = quant_config

        self.vocab_size = config.vocab_size
        self.org_vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
        )

        self.enable_eplb = parallel_config.enable_eplb
        self.num_redundant_experts = parallel_config.eplb_config.num_redundant_experts

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: MixtralDecoderLayer(
                config,
                cache_config,
                quant_config=quant_config,
                prefix=prefix,
                enable_eplb=self.enable_eplb,
            ),
            prefix=f"{prefix}.layers",
        )

        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run this pipeline stage's slice of the decoder stack.

        Args:
            input_ids: Token ids; used on the first pipeline rank when
                ``inputs_embeds`` is not given.
            positions: Token positions forwarded to every layer.
            intermediate_tensors: ``hidden_states``/``residual`` received
                from the previous pipeline rank; required on non-first ranks.
            inputs_embeds: Optional precomputed embeddings that take
                precedence over ``input_ids`` on the first rank.

        Returns:
            On the last pipeline rank, the normalized hidden states;
            otherwise an ``IntermediateTensors`` holding ``hidden_states``
            and ``residual`` for the next rank.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.embed_input_ids(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        for layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states, residual = layer(positions, hidden_states, residual)
        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Return the checkpoint-to-parameter mapping for expert weights."""
        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        return fused_moe_make_expert_params_mapping(
            self,
            ckpt_gate_proj_name="w1",
            ckpt_down_proj_name="w2",
            ckpt_up_proj_name="w3",
            num_experts=self.config.num_local_experts,
            num_redundant_experts=self.num_redundant_experts,
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Each ``(name, tensor)`` pair is tried, in order, as: a KV-cache
        quantization scale, a shard of the fused ``qkv_proj``, an expert
        weight, and finally a plain parameter.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        Returns:
            The set of parameter names recorded as loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        expert_params_mapping = self.get_expert_mapping()
        for name, loaded_weight in weights:
            if self.quant_config is not None and (
                scale_name := self.quant_config.get_cache_scale(name)
            ):
                # Loading kv cache quantization scales
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                # Non-scalar scale tensors are reduced to their first element.
                loaded_weight = (
                    loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
                )
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if (
                    name.endswith(".bias") or name.endswith("_bias")
                ) and name not in params_dict:
                    continue
                # Skip layers on other devices.
                if is_pp_missing_parameter(name, self):
                    continue
                if name.endswith("scale"):
                    # Remapping the name of FP8 kv-scale.
                    remapped_name = maybe_remap_kv_scale_name(name, params_dict)
                    if remapped_name is None:
                        # Keep ``name`` a string: the remaining iterations
                        # (and the fallback path below) run substring checks
                        # on it and would raise TypeError on None. The
                        # fallback path repeats the remap and skips the
                        # tensor.
                        continue
                    name = remapped_name
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # No stacked shard was loaded; try the expert mappings.
                is_expert_weight = False
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping

                    if weight_name not in name:
                        continue

                    is_expert_weight = True
                    name_mapped = name.replace(weight_name, param_name)

                    # Skip layers on other devices.
                    if is_pp_missing_parameter(name_mapped, self):
                        continue

                    if (
                        name_mapped.endswith(".bias") or name_mapped.endswith("_bias")
                    ) and name_mapped not in params_dict:
                        continue

                    param = params_dict[name_mapped]
                    weight_loader = typing.cast(
                        Callable[..., bool], param.weight_loader
                    )
                    success = weight_loader(
                        param,
                        loaded_weight,
                        name_mapped,
                        shard_id=shard_id,
                        expert_id=expert_id,
                        return_success=True,
                    )
                    if success:
                        name = name_mapped
                        break
                else:
                    # An expert weight that no mapping managed to load is
                    # skipped without being recorded.
                    if is_expert_weight:
                        continue
                    # Skip loading extra bias for GPTQ models.
                    if (
                        name.endswith(".bias") or name.endswith("_bias")
                    ) and name not in params_dict:
                        continue
                    # Skip layers on other devices.
                    if is_pp_missing_parameter(name, self):
                        continue
                    # Remapping the name of FP8 kv-scale.
                    name = maybe_remap_kv_scale_name(name, params_dict)
                    if name is None:
                        continue

                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
483
484


485
class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP, MixtureOfExperts):
    """Mixtral causal language model: decoder stack plus LM head.

    Besides wrapping ``MixtralModel``, the constructor collects the fused
    expert modules of every locally present decoder layer and copies the
    expert-count metadata from one of them onto this object.

    Args:
        vllm_config: Engine configuration; the HF model config and the
            quantization config are read from it.
        prefix: Module-name prefix for ``model`` and ``lm_head``.

    Raises:
        RuntimeError: If no ``MixtralMoE`` layer is found in the model.
    """

    fall_back_to_pt_during_load = False

    # Checkpoint projections that are fused into a single parameter.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
    }

    # LoRA specific attributes
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.quant_config = quant_config

        self.model = MixtralModel(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )

        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        if self.config.tie_word_embeddings:
            # Share the embedding matrix with the output projection.
            self.lm_head.weight = self.model.embed_tokens.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

        self.expert_weights = []
        self.moe_layers = []
        example_moe = None

        for layer in self.model.layers:
            # Layers owned by other pipeline ranks are placeholders.
            if isinstance(layer, PPMissingLayer):
                continue
            assert isinstance(layer, MixtralDecoderLayer)
            if hasattr(layer, "block_sparse_moe") and isinstance(
                layer.block_sparse_moe, MixtralMoE
            ):
                example_moe = layer.block_sparse_moe
                self.moe_layers.append(layer.block_sparse_moe.experts)

        self.num_moe_layers = len(self.moe_layers)

        if example_moe is None:
            raise RuntimeError("No MixtralMoE layer found  in model.layers.")

        # Expert-count metadata is copied from the last MoE layer seen.
        self.num_logical_experts = example_moe.n_logical_experts
        self.num_physical_experts = example_moe.n_physical_experts
        self.num_local_physical_experts = example_moe.n_local_physical_experts
        self.num_routed_experts = example_moe.n_routed_experts
        self.num_redundant_experts = example_moe.n_redundant_experts
        self.num_expert_groups = 1
        self.num_shared_experts = 0

    def update_physical_experts_metadata(
        self,
        num_physical_experts: int,
        num_local_physical_experts: int,
    ) -> None:
        """Propagate new physical-expert counts to every MoE layer.

        Args:
            num_physical_experts: New total number of physical experts.
            num_local_physical_experts: Number of physical experts on this
                rank; must equal the value already stored on the model.
        """
        assert self.num_local_physical_experts == num_local_physical_experts
        self.num_physical_experts = num_physical_experts
        self.num_local_physical_experts = num_local_physical_experts
        self.num_redundant_experts = num_physical_experts - self.num_logical_experts
        for layer in self.model.layers:
            if hasattr(layer, "block_sparse_moe") and isinstance(
                layer.block_sparse_moe, MixtralMoE
            ):
                moe = layer.block_sparse_moe
                moe.n_local_physical_experts = num_local_physical_experts
                moe.n_physical_experts = num_physical_experts
                moe.n_redundant_experts = self.num_redundant_experts
                moe.experts.update_expert_map()

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings via the wrapped model."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the wrapped decoder stack and return its output unchanged."""
        hidden_states = self.model(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Compute logits from ``hidden_states`` using the LM head."""
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights through ``AutoWeightsLoader``."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Return the wrapped model's expert parameter mapping."""
        return self.model.get_expert_mapping()