qwen2_moe.py 25.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2MoE model compatible with HuggingFace weights."""
27

28
from collections.abc import Iterable
29
from itertools import islice
30
from typing import Any
31
32
33
34

import torch
import torch.nn.functional as F
from torch import nn
35
from transformers import Qwen2MoeConfig
36

37
from vllm.attention.layer import Attention
38
from vllm.compilation.decorators import support_torch_compile
39
from vllm.config import CacheConfig, VllmConfig
40
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
41
from vllm.logger import init_logger
42
from vllm.model_executor.layers.activation import SiluAndMul
43
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
44
from vllm.model_executor.layers.layernorm import RMSNorm
45
46
47
48
49
50
from vllm.model_executor.layers.linear import (
    MergedColumnParallelLinear,
    QKVParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
51
from vllm.model_executor.layers.logits_processor import LogitsProcessor
52
from vllm.model_executor.layers.quantization import QuantizationConfig
53
54
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
55
56
57
    ParallelLMHead,
    VocabParallelEmbedding,
)
58
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
59
from vllm.sequence import IntermediateTensors
60

61
from .interfaces import SupportsLoRA, SupportsPP
62
63
64
65
66
67
68
69
from .utils import (
    AutoWeightsLoader,
    extract_layer_index,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
zhuwenwen's avatar
zhuwenwen committed
70
71
72
73
74
import os
import re
from vllm import _custom_ops as ops
from vllm.model_executor.utils import pad_weight, gemm_bank_conf

75
76
# Module-level logger, created through vLLM's `init_logger` with this
# module's `__name__`.
logger = init_logger(__name__)

77
78
79
80
81
82
83

class Qwen2MoeMLP(nn.Module):
    """Gated MLP: merged gate/up projection, SiLU-and-mul, down projection.

    When ``expert_gate`` is given, the MLP output is additionally scaled by
    ``sigmoid(expert_gate(x)[0])``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: QuantizationConfig | None = None,
        reduce_results: bool = True,
        expert_gate: torch.nn.Linear | None = None,
        prefix: str = "",
    ) -> None:
        """Build the projections and the activation.

        Args:
            hidden_size: Input size of ``gate_up_proj`` and output size of
                ``down_proj``.
            intermediate_size: Size of each of the two merged gate/up outputs
                and input size of ``down_proj``.
            hidden_act: Activation name; only ``"silu"`` is accepted.
            quant_config: Passed through to both linear layers.
            reduce_results: Passed through to ``down_proj``.
            expert_gate: Optional module whose first output gates the MLP
                result through a sigmoid.
            prefix: Parameter-name prefix for the sub-layers.

        Raises:
            ValueError: If ``hidden_act`` is not ``"silu"``.
        """
        super().__init__()
        # Validate before building any layer so a bad config fails fast.
        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
            )
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            reduce_results=reduce_results,
            prefix=f"{prefix}.down_proj",
        )
        self.act_fn = SiluAndMul()
        self.expert_gate = expert_gate

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the MLP to ``x`` and, if configured, the sigmoid expert gate."""
        # Both linear layers return a tuple; only the first element is used.
        gate_up, _ = self.gate_up_proj(x)
        out = self.act_fn(gate_up)
        out, _ = self.down_proj(out)

        if self.expert_gate is not None:
            out = F.sigmoid(self.expert_gate(x)[0]) * out

        return out
121
122
123
124
125


class Qwen2MoeSparseMoeBlock(nn.Module):
    """Sparse MoE layer: a replicated router, fused experts, optional shared expert.

    The shared expert (a ``Qwen2MoeMLP`` gated by ``shared_expert_gate``) is
    only created when ``config.shared_expert_intermediate_size > 0``; the gate
    layer itself is always created.
    """

    def __init__(
        self,
        config: Qwen2MoeConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        """Build the router, the optional shared expert, and the fused experts.

        Args:
            config: Model config; supplies the hidden size, expert counts and
                sizes, activation name, and top-k settings.
            quant_config: Passed to the shared expert and the fused experts.
                The router and the shared-expert gate are built with
                ``quant_config=None``.
            prefix: Parameter-name prefix for the sub-layers.

        Raises:
            ValueError: If the tensor-parallel world size exceeds
                ``config.num_experts``.
        """
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()

        if self.tp_size > config.num_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {config.num_experts}."
            )

        # Router producing one logit per expert.
        self.gate = ReplicatedLinear(
            config.hidden_size,
            config.num_experts,
            bias=False,
            quant_config=None,
            prefix=f"{prefix}.gate",
        )

        # Single-output gate applied (through a sigmoid) to the shared expert.
        self.shared_expert_gate = ReplicatedLinear(
            config.hidden_size,
            1,
            bias=False,
            quant_config=None,
            prefix=f"{prefix}.shared_expert_gate",
        )

        if config.shared_expert_intermediate_size > 0:
            self.shared_expert = Qwen2MoeMLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.shared_expert_intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                # The reduction is performed once in `forward`, after the
                # shared and routed outputs have been summed.
                reduce_results=False,
                expert_gate=self.shared_expert_gate,
                prefix=f"{prefix}.shared_expert",
            )
        else:
            self.shared_expert = None

        self.experts = SharedFusedMoE(
            shared_experts=self.shared_expert,
            num_experts=config.num_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            reduce_results=False,
            renormalize=config.norm_topk_prob,
            quant_config=quant_config,
            prefix=f"{prefix}.experts",
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route ``hidden_states`` through the experts and return the same shape."""
        # NOTE: hidden_states can have either 1D or 2D shape.
        orig_shape = hidden_states.shape
        hidden_dim = hidden_states.shape[-1]
        hidden_states = hidden_states.view(-1, hidden_dim)

        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(hidden_states)
        final_hidden_states = self.experts(
            hidden_states=hidden_states, router_logits=router_logits
        )
        if self.shared_expert is not None:
            # With a shared expert the fused layer's result is indexed as a
            # pair and the two parts are summed (presumably the shared and
            # routed outputs -- confirm against SharedFusedMoE).
            final_hidden_states = final_hidden_states[0] + final_hidden_states[1]
        if self.tp_size > 1:
            final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(  # noqa E501
                final_hidden_states
            )

        return final_hidden_states.view(orig_shape)
199
200
201
202
203
204
205
206


class Qwen2MoeAttention(nn.Module):
    """Self-attention with a fused QKV projection, rotary embedding, and output projection."""

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_parameters: dict[str, Any] | None = None,
        max_position_embeddings: int = 8192,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        dual_chunk_attention_config: dict[str, Any] | None = None,
    ) -> None:
        """Build the projections, the rotary embedding, and the attention op.

        Args:
            hidden_size: Model hidden size; the head dimension is
                ``hidden_size // num_heads``.
            num_heads: Total number of query heads (before tensor-parallel
                partitioning).
            num_kv_heads: Total number of key/value heads (before
                tensor-parallel partitioning).
            rope_parameters: Passed through to ``get_rope``.
            max_position_embeddings: Passed to ``get_rope`` as ``max_position``.
            cache_config: Passed through to ``Attention``.
            quant_config: Passed to the linear layers and to ``Attention``.
            prefix: Parameter-name prefix for the sub-layers.
            dual_chunk_attention_config: When truthy, forwarded to both
                ``get_rope`` and ``Attention`` (the latter together with the
                layer index parsed from ``prefix``).
        """
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = hidden_size // self.total_num_heads
        # Per-rank widths used to split the fused QKV output in `forward`.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.max_position_embeddings = max_position_embeddings
        self.dual_chunk_attention_config = dual_chunk_attention_config

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            max_position=max_position_embeddings,
            rope_parameters=rope_parameters,
            dual_chunk_attention_config=dual_chunk_attention_config,
        )

        # The dual-chunk keyword arguments are only passed when the feature
        # is configured; otherwise `Attention` receives no extra kwargs.
        if dual_chunk_attention_config:
            extra_attn_kwargs: dict[str, Any] = {
                "layer_idx": extract_layer_index(prefix),
                "dual_chunk_attention_config": dual_chunk_attention_config,
            }
        else:
            extra_attn_kwargs = {}

        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
            **extra_attn_kwargs,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Project to Q/K/V, apply rotary embedding to Q and K, attend, project out."""
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output


class Qwen2MoeDecoderLayer(nn.Module):
    """One decoder layer: RMSNorm, self-attention, RMSNorm, then a dense or sparse MLP."""

    def __init__(
        self,
        config: Qwen2MoeConfig,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        """Build the attention block, the MLP (dense or MoE), and both norms.

        The MLP is a ``Qwen2MoeSparseMoeBlock`` when the layer index (parsed
        from ``prefix``) is not listed in ``config.mlp_only_layers``,
        ``config.num_experts > 0``, and ``(layer_idx + 1)`` is a multiple of
        ``config.decoder_sparse_step``; otherwise it is a dense ``Qwen2MoeMLP``.

        Args:
            config: Model config.
            cache_config: Passed through to the attention block.
            quant_config: Passed through to the attention block and the MLP.
            prefix: Parameter-name prefix; also the source of the layer index.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        dual_chunk_attention_config = getattr(
            config, "dual_chunk_attention_config", None
        )
        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
        self.self_attn = Qwen2MoeAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            rope_parameters=config.rope_parameters,
            max_position_embeddings=max_position_embeddings,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
            dual_chunk_attention_config=dual_chunk_attention_config,
        )

        # Note: Qwen/Qwen2-57B-A14B-Instruct does not have
        # `mlp_only_layers` in the config.
        layer_idx = extract_layer_index(prefix)
        mlp_only_layers = getattr(config, "mlp_only_layers", [])
        if (layer_idx not in mlp_only_layers) and (
            config.num_experts > 0 and (layer_idx + 1) % config.decoder_sparse_step == 0
        ):
            self.mlp = Qwen2MoeSparseMoeBlock(
                config=config, quant_config=quant_config, prefix=f"{prefix}.mlp"
            )
        else:
            self.mlp = Qwen2MoeMLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                prefix=f"{prefix}.mlp",
            )
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the layer and return ``(hidden_states, residual)``.

        Args:
            positions: Forwarded to the attention block.
            hidden_states: Layer input.
            residual: Residual carried from the previous layer, or ``None``
                for the first layer, in which case the un-normalised input
                becomes the residual.
        """
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            # Called with a residual, the norm returns the updated pair.
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


364
@support_torch_compile
class Qwen2MoeModel(nn.Module):
    """Qwen2-MoE backbone: token embedding, decoder layers, and final RMSNorm.

    Also owns checkpoint loading for its parameters, including an optional
    in-place weight rewrite that is enabled by the ``LLAMA_NN=1`` environment
    variable for unquantized models.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the embedding, the decoder layers, and the final norm.

        Args:
            vllm_config: Supplies the HF model config, cache config, and
                quantization config.
            prefix: Parameter-name prefix for the sub-modules.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.vocab_size = config.vocab_size
        self.config = config

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=f"{prefix}.embed_tokens",
        )
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: Qwen2MoeDecoderLayer(
                config=config,
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=prefix,
            ),
            prefix=f"{prefix}.layers",
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size
        )

        # Always define both attributes so later reads never depend on
        # whether a quantization config was supplied.
        self.quant_config = quant_config
        self.quant_method = (
            quant_config.get_name() if quant_config is not None else None
        )

        # Feature switches read once from the environment at construction.
        # Only `use_llama_nn` is consulted in this class; the padding flags
        # are stored but not read here.
        self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
        self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
        self.use_fa_pad = os.environ.get('FA_PAD') == '1'
        self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run this pipeline stage's slice of the decoder stack.

        On the first pipeline rank the input is ``inputs_embeds`` when given,
        otherwise the embedding of ``input_ids``. On other ranks the hidden
        states and residual are taken from ``intermediate_tensors``.

        Returns:
            The normalised hidden states on the last pipeline rank; otherwise
            an ``IntermediateTensors`` holding ``hidden_states`` and
            ``residual`` for the next stage.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.embed_input_ids(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        for layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states, residual = layer(positions, hidden_states, residual)
        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Return the expert mapping used by ``load_weights``."""
        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        return SharedFusedMoE.make_expert_params_mapping(
            self,
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.num_experts,
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Each checkpoint entry is tried, in order, against the stacked
        (QKV / gate-up) mapping, then the expert mapping, and finally loaded
        as a plain parameter. When ``LLAMA_NN=1`` was set at construction and
        the model is unquantized, selected dense weights are rewritten in
        place afterwards.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        Returns:
            The names recorded for the entries that were processed.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        expert_params_mapping = self.get_expert_mapping()
        for name, loaded_weight in weights:
            for param_name, weight_name, shard_id in stacked_params_mapping:
                # Skip non-stacked layers and experts (experts handled below).
                if weight_name not in name:
                    continue
                # We have mlp.experts[0].gate_proj in the checkpoint.
                # Since we handle the experts below in expert_params_mapping,
                # we need to skip here BEFORE we update the name, otherwise
                # name will be updated to mlp.experts[0].gate_up_proj, which
                # will then be updated below in expert_params_mapping
                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                if "mlp.experts" in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if (
                    name.endswith(".bias") or name.endswith("_bias")
                ) and name not in params_dict:
                    continue
                # Skip layers on other devices.
                if is_pp_missing_parameter(name, self):
                    continue
                if name not in params_dict:
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)

                    # Skip layers on other devices.
                    if is_pp_missing_parameter(name, self):
                        continue
                    # Skip loading extra bias for GPTQ models.
                    if (
                        name.endswith(".bias") or name.endswith("_bias")
                    ) and name not in params_dict:
                        continue
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(
                        param,
                        loaded_weight,
                        name,
                        shard_id=shard_id,
                        expert_id=expert_id,
                    )
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if (
                        name.endswith(".bias") or name.endswith("_bias")
                    ) and name not in params_dict:
                        continue
                    # Skip layers on other devices.
                    if is_pp_missing_parameter(name, self):
                        continue
                    # Remapping the name of FP8 kv-scale.
                    if name.endswith("kv_scale"):
                        remapped_kv_scale_name = name.replace(
                            ".kv_scale", ".attn.kv_scale"
                        )
                        if remapped_kv_scale_name not in params_dict:
                            logger.warning_once(
                                "Found kv_scale in the checkpoint (e.g. %s), but not found the expected name in the model (e.g. %s). kv_scale is not loaded.",  #  noqa: E501
                                name,
                                remapped_kv_scale_name,
                            )
                            continue
                        else:
                            name = remapped_kv_scale_name
                    # GGUF: make sure that shared_expert_gate is a 2D tensor.
                    if (
                        "mlp.shared_expert_gate" in name
                        and len(loaded_weight.shape) == 1
                    ):
                        loaded_weight = loaded_weight[None, :]
                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)
            loaded_params.add(name)

        if self.use_llama_nn and self.quant_method is None:
            self._apply_llama_nn_layout(params_dict, loaded_params)

        return loaded_params

    def _apply_llama_nn_layout(
        self,
        params_dict: dict[str, nn.Parameter],
        loaded_params: set[str],
    ) -> None:
        """Rewrite selected dense weights in place with ``ops.trans_w16_gemm``.

        Only parameters whose name contains one of the listed suffixes are
        touched. The GEMM_PAD / FA_PAD padding steps are currently disabled,
        so no padding is applied here.
        """
        if not loaded_params:
            return

        # Loop-invariant, so it is set once before any weight is rewritten.
        os.environ['LM_NN'] = '0'

        layer_key_words = (
            "gate_up_proj.weight",
            "down_proj.weight",
            "mlp.gate.weight",
            "self_attn.qkv_proj.weight",
            "self_attn.o_proj.weight",
            "lm_head.weight",
        )
        # Escape the keys so that "." only matches a literal dot, and compile
        # the alternation once instead of once per parameter.
        key_pattern = re.compile("|".join(re.escape(key) for key in layer_key_words))

        for layer_name in loaded_params:
            if key_pattern.search(layer_name) is None:
                continue
            weight = params_dict[layer_name]

            # The op writes into a scratch tensor of the same shape, which is
            # then copied back over the parameter's storage.
            # NOTE(review): presumably a layout change for the vendor GEMM
            # kernel -- confirm against `vllm._custom_ops.trans_w16_gemm`.
            transformed = torch.zeros_like(weight.data)
            original_shape = transformed.shape
            ops.trans_w16_gemm(
                transformed, weight.data, original_shape[0], original_shape[1]
            )
            weight.data.copy_(transformed)

            # Re-expose the rewritten data with the original second dimension
            # leading.
            weight.data = weight.data.reshape(original_shape[1], -1)
590
591


592
class Qwen2MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
    """Causal-LM wrapper: ``Qwen2MoeModel`` backbone plus LM head and logits processor."""

    fall_back_to_pt_during_load = False
    # Checkpoint sub-module names that are fused into a single vLLM module.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ]
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the backbone, the LM head, and the logits processor.

        Args:
            vllm_config: Supplies the HF model config and quantization config,
                and is forwarded to ``Qwen2MoeModel``.
            prefix: Parameter-name prefix for the sub-modules.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config
        # Only perform the following mapping when Qwen2MoeMLP exists
        if (
            getattr(config, "mlp_only_layers", [])
            or config.shared_expert_intermediate_size > 0
        ):
            # NOTE: no instance attribute of this name exists, so this writes
            # into the class-level dict and the entry is shared by every
            # instance created afterwards.
            self.packed_modules_mapping["gate_up_proj"] = ["gate_proj", "up_proj"]

        self.model = Qwen2MoeModel(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        if self.config.tie_word_embeddings:
            # Share one weight tensor between the input embedding and LM head.
            self.lm_head.weight = self.model.embed_tokens.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the backbone's token embeddings for ``input_ids``."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the backbone and return its output unchanged (no logits here)."""
        hidden_states = self.model(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Compute logits from ``hidden_states`` using the LM head."""
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load ``weights`` through ``AutoWeightsLoader`` and return its result."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Return the backbone's expert mapping."""
        return self.model.get_expert_mapping()