mixtral.py 21.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

Pierre Stock's avatar
Pierre Stock committed
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Mixtral model."""
26

27
28
import typing
from collections.abc import Callable, Iterable
29
from itertools import islice
Pierre Stock's avatar
Pierre Stock committed
30
31
32

import torch
from torch import nn
33
from transformers import MixtralConfig
Pierre Stock's avatar
Pierre Stock committed
34

35
from vllm.attention.layer import Attention
36
from vllm.compilation.decorators import support_torch_compile
37
from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
38
39
40
41
42
from vllm.distributed import (
    get_ep_group,
    get_pp_group,
    get_tensor_model_parallel_world_size,
)
43
from vllm.model_executor.layers.fused_moe import FusedMoE
Pierre Stock's avatar
Pierre Stock committed
44
from vllm.model_executor.layers.layernorm import RMSNorm
45
46
47
48
49
from vllm.model_executor.layers.linear import (
    QKVParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
50
from vllm.model_executor.layers.logits_processor import LogitsProcessor
51
from vllm.model_executor.layers.quantization import QuantizationConfig
52
from vllm.model_executor.layers.rotary_embedding import get_rope
Pierre Stock's avatar
Pierre Stock committed
53
from vllm.model_executor.layers.vocab_parallel_embedding import (
54
55
56
    ParallelLMHead,
    VocabParallelEmbedding,
)
57
from vllm.model_executor.model_loader.weight_utils import (
58
59
60
    default_weight_loader,
    maybe_remap_kv_scale_name,
)
61
from vllm.sequence import IntermediateTensors
Pierre Stock's avatar
Pierre Stock committed
62

63
from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
64
65
66
67
68
69
70
71
from .utils import (
    AutoWeightsLoader,
    PPMissingLayer,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
72

Pierre Stock's avatar
Pierre Stock committed
73

Philipp Moritz's avatar
Philipp Moritz committed
74
75
76
77
78
79
80
81
class MixtralMoE(nn.Module):
    """Sparse mixture-of-experts block used by every Mixtral decoder layer.

    A replicated, unquantized router (``gate``) scores each token against all
    experts, and a ``FusedMoE`` layer evaluates the selected experts. The
    experts are built with ``reduce_results=True`` so per-rank partial results
    are reduced inside the expert layer. The module also records the
    expert-parallel (EP) bookkeeping counters (logical / physical / local
    experts) that expert-parallel load balancing (EPLB) reads and updates.
    """

    def __init__(
        self,
        num_experts: int,
        top_k: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        tp_size: int | None = None,
        dp_size: int | None = None,
        prefix: str = "",
        enable_eplb: bool = False,
    ):
        """Create the router and the fused experts.

        Args:
            num_experts: Number of logical (routed) experts.
            top_k: Number of experts selected per token.
            hidden_size: Model hidden dimension.
            intermediate_size: Per-expert FFN intermediate dimension.
            params_dtype: Parameter dtype forwarded to the gate and experts.
            quant_config: Quantization config applied to the experts only;
                the gate is always built with ``quant_config=None``.
            tp_size: Tensor-parallel size forwarded to ``FusedMoE``.
            dp_size: Data-parallel size forwarded to ``FusedMoE``.
            prefix: Parameter-name prefix for this module.
            enable_eplb: Whether expert-parallel load balancing is enabled.
        """
        super().__init__()
        self.hidden_size = hidden_size

        # Expert-parallel topology as seen from this rank.
        ep = get_ep_group()
        self.ep_group = ep.device_group
        self.ep_rank = ep.rank_in_group
        self.ep_size = self.ep_group.size()

        # EPLB bookkeeping: physical experts are the logical experts plus any
        # redundant replicas; they are split across the EP group with floor
        # division, and this rank owns a contiguous [start, end) slice.
        eplb_config = get_current_vllm_config().parallel_config.eplb_config
        self.enable_eplb = enable_eplb
        self.n_routed_experts = self.n_logical_experts = num_experts
        self.n_redundant_experts = eplb_config.num_redundant_experts
        self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts
        per_rank = self.n_physical_experts // self.ep_size
        self.n_local_physical_experts = per_rank
        self.physical_expert_start = self.ep_rank * per_rank
        self.physical_expert_end = self.physical_expert_start + per_rank

        # Router: replicated on every rank and never quantized
        # (runs at half / full precision for now).
        self.gate = ReplicatedLinear(
            hidden_size,
            num_experts,
            bias=False,
            params_dtype=params_dtype,
            quant_config=None,
            prefix=f"{prefix}.gate",
        )

        self.experts = FusedMoE(
            num_experts=num_experts,
            top_k=top_k,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            params_dtype=params_dtype,
            reduce_results=True,
            renormalize=True,
            quant_config=quant_config,
            tp_size=tp_size,
            dp_size=dp_size,
            prefix=f"{prefix}.experts",
            enable_eplb=enable_eplb,
            num_redundant_experts=self.n_redundant_experts,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route each token to its experts; the output has the input's shape.

        ``hidden_states`` may be 1D or 2D. It is viewed as
        ``(-1, hidden_size)`` for routing and viewed back afterwards.
        """
        input_shape = hidden_states.shape
        tokens = hidden_states.view(-1, self.hidden_size)
        # logits: (num_tokens, n_experts)
        logits, _ = self.gate(tokens)
        return self.experts(tokens, logits).view(input_shape)
Pierre Stock's avatar
Pierre Stock committed
153
154
155


class MixtralAttention(nn.Module):
    """Self-attention for Mixtral with rotary position embeddings.

    Q, K and V come from one fused ``QKVParallelLinear`` projection and the
    output goes through a ``RowParallelLinear``, so attention heads are
    partitioned across tensor-parallel (TP) ranks. When there are fewer KV
    heads than TP ranks, each rank keeps a single (replicated) KV head.
    """

    def __init__(
        self,
        config: MixtralConfig,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        max_position: int = 4096 * 32,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        """Build the projections, rotary embedding and attention backend.

        Args:
            config: HF Mixtral config; read for ``head_dim`` (optional) and
                ``rope_parameters``.
            hidden_size: Model hidden dimension.
            num_heads: Total number of query heads across all TP ranks.
            num_kv_heads: Total number of key/value heads across all TP ranks.
            max_position: Maximum position passed to the rotary embedding.
            cache_config: KV-cache configuration for the attention layer.
            quant_config: Quantization config for projections and attention.
            prefix: Parameter-name prefix for this module.
        """
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()

        # Query heads must divide evenly across TP ranks.
        self.total_num_heads = num_heads
        assert num_heads % tp_size == 0
        self.num_heads = num_heads // tp_size

        self.total_num_kv_heads = num_kv_heads
        if num_kv_heads < tp_size:
            # Fewer KV heads than TP ranks: KV heads are replicated, so the
            # rank count must be a multiple of the KV-head count.
            assert tp_size % num_kv_heads == 0
        else:
            # At least as many KV heads as TP ranks: KV heads are partitioned.
            assert num_kv_heads % tp_size == 0
        self.num_kv_heads = max(1, num_kv_heads // tp_size)

        # MixtralConfig may carry an explicit head_dim (possibly None);
        # otherwise derive it from the hidden size.
        explicit_head_dim = getattr(config, "head_dim", None)
        self.head_dim = (
            hidden_size // num_heads if explicit_head_dim is None else explicit_head_dim
        )

        # Per-rank widths used to split the fused QKV output.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            max_position=max_position,
            rope_parameters=config.rope_parameters,
            is_neox_style=True,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Apply the QKV projection, RoPE, attention and output projection."""
        qkv, _ = self.qkv_proj(hidden_states)
        split_sizes = [self.q_size, self.kv_size, self.kv_size]
        q, k, v = qkv.split(split_sizes, dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        projected, _ = self.o_proj(self.attn(q, k, v))
        return projected


class MixtralDecoderLayer(nn.Module):
    """One Mixtral transformer layer.

    Layout: RMSNorm -> self-attention -> RMSNorm -> sparse MoE. The residual
    stream is carried alongside the hidden states and folded in by the
    RMSNorm calls, so ``forward`` returns both tensors.
    """

    def __init__(
        self,
        config: MixtralConfig,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        enable_eplb: bool = False,
    ) -> None:
        """Build attention, MoE and the two RMSNorms from ``config``.

        Args:
            config: HF Mixtral config (hidden size, head counts, expert
                counts, ``rms_norm_eps`` ...).
            cache_config: KV-cache configuration for the attention layer.
            quant_config: Quantization config for attention and experts.
            prefix: Parameter-name prefix for this layer.
            enable_eplb: Whether expert-parallel load balancing is enabled.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = MixtralAttention(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            max_position=config.max_position_embeddings,
            num_kv_heads=config.num_key_value_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        self.block_sparse_moe = MixtralMoE(
            num_experts=config.num_local_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            quant_config=quant_config,
            prefix=f"{prefix}.block_sparse_moe",
            enable_eplb=enable_eplb,
        )
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the layer and return ``(hidden_states, residual)``.

        Args:
            positions: Token positions forwarded to the rotary embedding.
            hidden_states: Output of the previous layer (or the embeddings).
            residual: Residual stream from the previous layer, or ``None`` for
                the first layer, in which case ``hidden_states`` itself
                becomes the residual.

        Returns:
            The MoE output and the updated residual. The residual is not
            added to the returned hidden states here; the next RMSNorm call
            that receives both tensors consumes it.
        """
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            # With a residual argument the norm returns both the normalized
            # states and the updated residual.
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.block_sparse_moe(hidden_states)
        return hidden_states, residual
Pierre Stock's avatar
Pierre Stock committed
292

293

294
@support_torch_compile
class MixtralModel(nn.Module):
    """Mixtral decoder stack: token embedding, decoder layers and final norm.

    With pipeline parallelism (PP) each rank instantiates only layers
    ``[start_layer, end_layer)`` (as returned by ``make_layers``). Non-final
    ranks hand ``hidden_states`` and ``residual`` to the next stage via
    ``IntermediateTensors``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the embedding, this rank's decoder layers and the final norm.

        Args:
            vllm_config: Global vLLM configuration (model, cache, quantization
                and parallel settings).
            prefix: Parameter-name prefix for this module.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        parallel_config = vllm_config.parallel_config

        self.config = config
        self.quant_config = quant_config

        self.vocab_size = config.vocab_size
        self.org_vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
        )

        self.enable_eplb = parallel_config.enable_eplb
        self.num_redundant_experts = parallel_config.eplb_config.num_redundant_experts

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: MixtralDecoderLayer(
                config,
                cache_config,
                quant_config=quant_config,
                prefix=prefix,
                enable_eplb=self.enable_eplb,
            ),
            prefix=f"{prefix}.layers",
        )

        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run this pipeline stage's decoder layers.

        On the first PP rank, the input is ``inputs_embeds`` when given and
        otherwise the embedding of ``input_ids``. Later ranks resume from
        ``intermediate_tensors``, which must then be provided.

        Returns:
            ``IntermediateTensors`` with ``hidden_states`` and ``residual``
            on non-final ranks. On the final rank, the hidden states after
            the final norm.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.embed_input_ids(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        for layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states, residual = layer(positions, hidden_states, residual)
        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )
        # The final norm folds in the last layer's residual.
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Return the checkpoint -> fused-expert parameter mapping.

        Checkpoint expert projections are named ``w1`` (gate), ``w2`` (down)
        and ``w3`` (up). Each entry is
        ``(param_name, weight_name, expert_id, shard_id)`` and covers weights,
        fp8 weight scales and fp8 activation scales.
        """
        return FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="w1",
            ckpt_down_proj_name="w2",
            ckpt_up_proj_name="w3",
            num_experts=self.config.num_local_experts,
            num_redundant_experts=self.num_redundant_experts,
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this model's parameters.

        Each ``(name, tensor)`` pair takes the first matching path:

        1. KV-cache quantization scales reported by
           ``quant_config.get_cache_scale`` (0-d tensors are used as-is;
           otherwise element ``[0]`` is loaded).
        2. Separate ``q_proj`` / ``k_proj`` / ``v_proj`` tensors, loaded as
           shards of the fused ``qkv_proj`` parameter.
        3. Expert tensors matched by ``get_expert_mapping()``, loaded into the
           fused expert parameters.
        4. Anything else, loaded by (kv-scale-remapped) name.

        The following tensors are skipped:

        * tensors for layers owned by other pipeline ranks;
        * bias tensors with no matching parameter (extra GPTQ biases);
        * kv-scale tensors whose remapped name has no matching parameter.

        Returns:
            The set of parameter names that received data.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        expert_params_mapping = self.get_expert_mapping()
        for name, loaded_weight in weights:
            if self.quant_config is not None and (
                scale_name := self.quant_config.get_cache_scale(name)
            ):
                # Loading kv cache quantization scales
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                loaded_weight = (
                    loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
                )
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models. `continue` moves on
                # to the next mapping; if no mapping loads the tensor, control
                # falls through to the `else` branch below, which re-applies
                # the same bias / PP-missing checks to the renamed name.
                if (
                    name.endswith(".bias") or name.endswith("_bias")
                ) and name not in params_dict:
                    continue
                # Skip layers on other devices.
                if is_pp_missing_parameter(name, self):
                    continue
                if name.endswith("scale"):
                    # Remapping the name of FP8 kv-scale.
                    name = maybe_remap_kv_scale_name(name, params_dict)
                    if name is None:
                        # No parameter to receive this scale. Leave the loop
                        # with `break`: a `continue` would evaluate
                        # `weight_name not in None` on the next mapping (or in
                        # the expert loop) and raise TypeError. The tensor is
                        # dropped after the loop.
                        break
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                is_expert_weight = False
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue

                    # The name matches an expert pattern. Even if this rank
                    # does not hold the expert, it must not be loaded as a
                    # regular weight below.
                    is_expert_weight = True
                    name_mapped = name.replace(weight_name, param_name)

                    # Skip layers on other devices.
                    if is_pp_missing_parameter(name_mapped, self):
                        continue

                    if (
                        name_mapped.endswith(".bias") or name_mapped.endswith("_bias")
                    ) and name_mapped not in params_dict:
                        continue

                    param = params_dict[name_mapped]
                    weight_loader = typing.cast(
                        Callable[..., bool], param.weight_loader
                    )
                    # With return_success=True the loader reports whether it
                    # actually stored this expert; on False, try further
                    # mappings.
                    success = weight_loader(
                        param,
                        loaded_weight,
                        name_mapped,
                        shard_id=shard_id,
                        expert_id=expert_id,
                        return_success=True,
                    )
                    if success:
                        name = name_mapped
                        break
                else:
                    if is_expert_weight:
                        # Expert tensor not stored by any mapping here: skip.
                        continue
                    # Skip loading extra bias for GPTQ models.
                    if (
                        name.endswith(".bias") or name.endswith("_bias")
                    ) and name not in params_dict:
                        continue
                    # Skip layers on other devices.
                    if is_pp_missing_parameter(name, self):
                        continue
                    # Remapping the name of FP8 kv-scale.
                    name = maybe_remap_kv_scale_name(name, params_dict)
                    if name is None:
                        continue

                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)
            if name is None:
                # Only reachable via the kv-scale `break` in the stacked branch.
                continue
            loaded_params.add(name)
        return loaded_params
480
481


482
class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP, MixtureOfExperts):
    """Mixtral causal language model.

    Wraps ``MixtralModel`` with an LM head and logits processor. It also
    exposes the metadata required by the LoRA, pipeline-parallel and
    mixture-of-experts (EPLB) interfaces.
    """

    fall_back_to_pt_during_load = False

    # Checkpoint q/k/v projections are packed into the fused qkv_proj.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
    }

    # LoRA specific attributes
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the backbone and LM head, then collect the MoE layers.

        Raises:
            RuntimeError: If this rank holds no ``MixtralMoE`` layer.
        """
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = hf_config
        self.quant_config = quant_config

        self.model = MixtralModel(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )
        self.lm_head = ParallelLMHead(
            hf_config.vocab_size,
            hf_config.hidden_size,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        if hf_config.tie_word_embeddings:
            # Share the embedding matrix with the output projection.
            self.lm_head.weight = self.model.embed_tokens.weight
        self.logits_processor = LogitsProcessor(hf_config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

        # Collect the fused-expert modules of every layer present on this
        # pipeline rank (PPMissingLayer placeholders are skipped).
        self.expert_weights = []
        self.moe_layers = []
        last_moe = None
        for layer in self.model.layers:
            if isinstance(layer, PPMissingLayer):
                continue
            assert isinstance(layer, MixtralDecoderLayer)
            moe = getattr(layer, "block_sparse_moe", None)
            if isinstance(moe, MixtralMoE):
                last_moe = moe
                self.moe_layers.append(moe.experts)

        self.num_moe_layers = len(self.moe_layers)

        if last_moe is None:
            raise RuntimeError("No MixtralMoE layer found  in model.layers.")

        # All MoE layers share the same expert configuration, so any one of
        # them can supply the model-level counters.
        self.num_logical_experts = last_moe.n_logical_experts
        self.num_physical_experts = last_moe.n_physical_experts
        self.num_local_physical_experts = last_moe.n_local_physical_experts
        self.num_routed_experts = last_moe.n_routed_experts
        self.num_redundant_experts = last_moe.n_redundant_experts
        self.num_expert_groups = 1
        self.num_shared_experts = 0

    def update_physical_experts_metadata(
        self,
        num_physical_experts: int,
        num_local_physical_experts: int,
    ) -> None:
        """Propagate a new physical-expert count to the model and MoE layers.

        The per-rank (local) count must not change; only the total and the
        derived redundant-expert count are updated. Each ``MixtralMoE`` then
        rebuilds its expert map.
        """
        assert self.num_local_physical_experts == num_local_physical_experts
        self.num_physical_experts = num_physical_experts
        self.num_local_physical_experts = num_local_physical_experts
        self.num_redundant_experts = num_physical_experts - self.num_logical_experts
        for layer in self.model.layers:
            moe = getattr(layer, "block_sparse_moe", None)
            if not isinstance(moe, MixtralMoE):
                continue
            moe.n_local_physical_experts = num_local_physical_experts
            moe.n_physical_experts = num_physical_experts
            moe.n_redundant_experts = self.num_redundant_experts
            moe.experts.update_expert_map()

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings via the backbone."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the backbone; logits are produced separately by ``compute_logits``."""
        return self.model(input_ids, positions, intermediate_tensors, inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Project final hidden states to vocabulary logits."""
        return self.logits_processor(self.lm_head, hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors through ``AutoWeightsLoader``."""
        return AutoWeightsLoader(self).load_weights(weights)

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Return the backbone's checkpoint -> fused-expert mapping."""
        return self.model.get_expert_mapping()