qwen3_moe.py 38.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24

# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen3MoE model compatible with HuggingFace weights."""
25

26
27
import typing
from collections.abc import Callable, Iterable
28
from itertools import islice
29
from typing import Any
30

zhuwenwen's avatar
zhuwenwen committed
31
32
import os
import re
33
import torch
34
import torch.nn.functional as F
35
36
from torch import nn

37
from vllm.attention.layer import Attention
38
from vllm.compilation.decorators import support_torch_compile
39
from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
40
41
42
43
44
45
from vllm.distributed import (
    get_ep_group,
    get_pp_group,
    get_tensor_model_parallel_world_size,
    tensor_model_parallel_all_gather,
)
46
47
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
48
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
49
50
51
52
53
54
55
56
57
58
59
60
61
try:
    from vllm.model_executor.layers.fused_moe.router_capture import (
        maybe_record_router_logits,
    )
except ImportError:
    # The router-capture module could not be imported; install a stand-in
    # with the same keyword-only signature so call sites need no guarding.
    def maybe_record_router_logits(
        *,
        layer_name: str,
        router_logits: torch.Tensor,
        top_k: int,
    ) -> None:
        """Do nothing: fallback used when router capture is unavailable."""
        return
62
from vllm.model_executor.layers.layernorm import RMSNorm
63
64
65
66
67
68
from vllm.model_executor.layers.linear import (
    MergedColumnParallelLinear,
    QKVParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
69
70
71
72
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
73
74
75
    ParallelLMHead,
    VocabParallelEmbedding,
)
76
from vllm.model_executor.model_loader.weight_utils import (
77
78
79
    default_weight_loader,
    maybe_remap_kv_scale_name,
)
80
from vllm.model_executor.models.utils import sequence_parallel_chunk
81
82
from vllm.sequence import IntermediateTensors

83
from .interfaces import MixtureOfExperts, SupportsEagle3, SupportsLoRA, SupportsPP
84
85
86
87
88
89
90
91
92
from .utils import (
    AutoWeightsLoader,
    PPMissingLayer,
    extract_layer_index,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
93

94
from vllm import envs
zhuwenwen's avatar
zhuwenwen committed
95
96
97
from vllm import _custom_ops as ops
from vllm.model_executor.utils import pad_weight, gemm_bank_conf
from vllm.utils import W8a8GetCacheJSON
98
from vllm.utils.torch_utils import direct_register_custom_op
99
100
101
102
103
104
105
106
107
logger = init_logger(__name__)


class Qwen3MoeMLP(nn.Module):
    """Gated feed-forward block: merged gate/up projection, SiLU-and-mul,
    then a down projection.

    When ``expert_gate`` is supplied, the block output is additionally scaled
    by ``sigmoid(expert_gate(x)[0])``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: QuantizationConfig | None = None,
        reduce_results: bool = True,
        expert_gate: torch.nn.Linear | None = None,
        prefix: str = "",
    ) -> None:
        """Build the projections and activation.

        Raises:
            ValueError: if ``hidden_act`` is anything other than ``"silu"``.
        """
        super().__init__()

        # Gate and up projections share one merged linear layer with two
        # equally sized output partitions.
        merged_output_sizes = [intermediate_size, intermediate_size]
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            merged_output_sizes,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            reduce_results=reduce_results,
            prefix=f"{prefix}.down_proj",
        )

        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
            )

        self.act_fn = SiluAndMul()
        self.expert_gate = expert_gate

    def forward(self, x):
        """Run the MLP on ``x`` and return the (optionally gated) result."""
        # Both linear layers return a tuple; only the first element is used.
        merged, _ = self.gate_up_proj(x)
        activated = self.act_fn(merged)
        projected, _ = self.down_proj(activated)

        if self.expert_gate is None:
            return projected

        # The gate is computed from the block input, not from the MLP output.
        gate_logits = self.expert_gate(x)[0]
        return F.sigmoid(gate_logits) * projected
145
146
147
148
149


class Qwen3MoeSparseMoeBlock(nn.Module):
    """Sparse mixture-of-experts block: a replicated router (``gate``), an
    optional sigmoid-gated shared expert, and the fused routed experts.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        """Build the router, the optional shared expert and the fused experts.

        Args:
            vllm_config: engine configuration; the HF text config, parallel
                config and quantization config are read from it.
            prefix: parameter-name prefix of this block. It is also stored as
                the layer name handed to the router-capture hook.

        Raises:
            ValueError: if the tensor-parallel world size exceeds
                ``config.num_experts``.
        """
        super().__init__()

        config = vllm_config.model_config.hf_text_config
        parallel_config = vllm_config.parallel_config
        quant_config = vllm_config.quant_config

        self.tp_size = get_tensor_model_parallel_world_size()

        self.ep_group = get_ep_group().device_group
        self.ep_rank = get_ep_group().rank_in_group
        self.ep_size = self.ep_group.size()
        self.n_routed_experts = config.num_experts

        self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe
        self._router_top_k = int(config.num_experts_per_tok)
        self._router_capture_layer_name = prefix

        if self.tp_size > config.num_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {config.num_experts}."
            )

        # Load balancing settings.
        # The EPLB settings are read from the *current* global config, as the
        # original code did; a separate local name is used so the
        # ``vllm_config`` parameter is not shadowed.
        current_vllm_config = get_current_vllm_config()
        eplb_config = current_vllm_config.parallel_config.eplb_config
        self.enable_eplb = parallel_config.enable_eplb

        self.n_logical_experts = self.n_routed_experts
        self.n_redundant_experts = eplb_config.num_redundant_experts
        self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts
        self.n_local_physical_experts = self.n_physical_experts // self.ep_size

        # Half-open range [start, end) of physical expert ids for this EP rank.
        self.physical_expert_start = self.ep_rank * self.n_local_physical_experts
        self.physical_expert_end = (
            self.physical_expert_start + self.n_local_physical_experts
        )

        self.gate = ReplicatedLinear(
            config.hidden_size,
            config.num_experts,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate",
        )

        shared_expert_intermediate_size = getattr(
            config, "shared_expert_intermediate_size", 0
        )
        if shared_expert_intermediate_size > 0:
            # The shared-expert gate is always built unquantized.
            self.shared_expert_gate = ReplicatedLinear(
                config.hidden_size,
                1,
                bias=False,
                quant_config=None,
                prefix=f"{prefix}.shared_expert_gate",
            )
            self.shared_expert = Qwen3MoeMLP(
                hidden_size=config.hidden_size,
                intermediate_size=shared_expert_intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                reduce_results=False,
                expert_gate=self.shared_expert_gate,
                prefix=f"{prefix}.shared_expert",
            )
        else:
            self.shared_expert_gate = None
            self.shared_expert = None

        self.experts = SharedFusedMoE(
            shared_experts=self.shared_expert,
            gate=self.gate,
            num_experts=self.n_routed_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            reduce_results=False,
            renormalize=config.norm_topk_prob,
            quant_config=quant_config,
            prefix=f"{prefix}.experts",
            enable_eplb=self.enable_eplb,
            num_redundant_experts=self.n_redundant_experts,
            is_sequence_parallel=self.is_sequence_parallel,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route tokens through the experts and combine the outputs.

        Args:
            hidden_states: a 1D ``(hidden_dim,)`` or 2D
                ``(num_tokens, hidden_dim)`` tensor.

        Returns:
            A tensor with the same rank as the input.
        """
        assert hidden_states.dim() <= 2, (
            "Qwen3MoeSparseMoeBlock only supports 1D or 2D inputs"
        )
        is_input_1d = hidden_states.dim() == 1
        # Read the hidden size from the last axis and the token count from the
        # 2D view. Unpacking ``hidden_states.shape`` into two names raised
        # ValueError for the 1D inputs the assertion above claims to accept.
        hidden_dim = hidden_states.shape[-1]
        hidden_states = hidden_states.view(-1, hidden_dim)
        num_tokens = hidden_states.shape[0]

        if self.is_sequence_parallel:
            hidden_states = sequence_parallel_chunk(hidden_states)

        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(hidden_states)

        # Router capture is skipped while torch.compile is tracing.
        if not (hasattr(torch, "compiler") and torch.compiler.is_compiling()):
            capture_enabled = envs.VLLM_MOE_ROUTER_CAPTURE
            if capture_enabled:
                maybe_record_router_logits(
                    layer_name=self._router_capture_layer_name,
                    router_logits=router_logits,
                    top_k=self._router_top_k,
                )

        shared_out, fused_out = self.experts(
            hidden_states=hidden_states, router_logits=router_logits
        )
        final_hidden_states = (
            shared_out + fused_out if shared_out is not None else fused_out
        )

        if self.is_sequence_parallel:
            final_hidden_states = tensor_model_parallel_all_gather(
                final_hidden_states, 0
            )
            # Keep only the first ``num_tokens`` rows of the gathered result
            # (presumably the chunking step padded the input - TODO confirm).
            final_hidden_states = final_hidden_states[:num_tokens]
        elif self.tp_size > 1:
            final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(  # noqa E501
                final_hidden_states
            )

        # return to 1d if input is 1d
        return final_hidden_states.squeeze(0) if is_input_1d else final_hidden_states
279
280
281
282
283
284
285
286


class Qwen3MoeAttention(nn.Module):
    """Attention layer with RMSNorm applied to Q and K per head before the
    rotary embedding.

    ``forward`` picks one of four Q/K preparation paths depending on the
    ``envs`` flags and on the rank of ``positions``; all of them feed the
    same attention and output projection.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_parameters: dict[str, Any],
        max_position_embeddings: int = 8192,
        head_dim: int | None = None,
        rms_norm_eps: float = 1e-06,
        qkv_bias: bool = False,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        dual_chunk_attention_config: dict[str, Any] | None = None,
    ) -> None:
        """Create the QKV/output projections, rotary embedding, attention
        backend and the per-head Q/K norms."""
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()

        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size

        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads < tp_size:
            # Fewer KV heads than TP ranks: the KV heads are replicated
            # across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        else:
            # At least as many KV heads as TP ranks: the KV heads are
            # partitioned across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)

        self.head_dim = head_dim or (hidden_size // self.total_num_heads)
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.max_position_embeddings = max_position_embeddings
        self.dual_chunk_attention_config = dual_chunk_attention_config

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=qkv_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            max_position=max_position_embeddings,
            rope_parameters=rope_parameters,
            dual_chunk_attention_config=dual_chunk_attention_config,
        )

        # The layer index and dual-chunk config are forwarded to the
        # attention backend only when dual-chunk attention is configured.
        extra_attn_kwargs: dict[str, Any] = {}
        if dual_chunk_attention_config:
            extra_attn_kwargs = {
                "layer_idx": extract_layer_index(prefix),
                "dual_chunk_attention_config": dual_chunk_attention_config,
            }
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
            **extra_attn_kwargs,
        )

        self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
        self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)

    def _synced_cos_sin_cache(self, like: torch.Tensor) -> torch.Tensor:
        """Return the rotary cos/sin cache on ``like``'s device and dtype.

        If the cache has to be converted, the converted tensor is stored back
        on ``self.rotary_emb`` so later calls reuse it.
        """
        cache = self.rotary_emb.cos_sin_cache
        if cache.device != like.device or cache.dtype != like.dtype:
            cache = cache.to(like.device, dtype=like.dtype, non_blocking=True)
            self.rotary_emb.cos_sin_cache = cache
        return cache

    def _rms_norm_per_head(self, norm: RMSNorm, x: torch.Tensor) -> torch.Tensor:
        """Apply ``norm`` over each ``head_dim`` slice of ``x``'s last axis."""
        per_head = x.view(*x.shape[:-1], x.shape[-1] // self.head_dim, self.head_dim)
        if envs.VLLM_USE_APEX_RN:
            per_head = norm.forward_apex(per_head)
        else:
            per_head = norm.forward_cuda(per_head)
        return per_head.view(x.shape)

    def _norm_and_rotate(
        self,
        positions: torch.Tensor,
        q: torch.Tensor,
        k: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply Q/K RMSNorm and rotary embedding to already-split Q and K."""
        if envs.VLLM_USE_FUSED_RMS_ROPE and positions.ndim == 1:
            # Fused RMSNorm + RoPE path through custom op.
            cos_sin_cache = self._synced_cos_sin_cache(q)
            q = q.contiguous()
            k = k.contiguous()
            # The op's return value is not used; q and k are passed on as-is
            # afterwards (presumably updated in place - TODO confirm).
            torch.ops.vllm.rms_rotary_embedding_fuse(
                positions,
                q,
                k,
                self.head_dim,
                cos_sin_cache,
                self.rotary_emb.is_neox_style,
                self.q_norm.weight,
                self.k_norm.weight,
                None,
                None,
                self.q_norm.variance_epsilon,
            )
            return q, k

        if (
            envs.VLLM_USE_FUSED_RMS_ROPE
            and positions.ndim == 2
            and getattr(self.rotary_emb, "mrope_section", None) is not None
        ):
            # Fused RMSNorm + M-RoPE path through custom op.
            cos_sin_cache = self._synced_cos_sin_cache(q)
            cos, sin = cos_sin_cache[positions].chunk(2, dim=-1)

            q = q.contiguous()
            k = k.contiguous()
            cos = cos.contiguous()
            sin = sin.contiguous()
            mrope_section = self.rotary_emb.mrope_section
            assert mrope_section is not None and len(mrope_section) == 3
            torch.ops.vllm.rms_mrope_fuse(
                q,
                k,
                cos,
                sin,
                self.head_dim,
                self.rotary_emb.rotary_dim,
                mrope_section[0],
                mrope_section[1],
                mrope_section[2],
                self.rotary_emb.mrope_interleaved,
                self.q_norm.weight,
                self.k_norm.weight,
                self.q_norm.variance_epsilon,
                None,
                None,
            )
            return q, k

        # Unfused fallback: per-head norm on Q, then K, then rotary embedding.
        q = self._rms_norm_per_head(self.q_norm, q)
        k = self._rms_norm_per_head(self.k_norm, k)
        return self.rotary_emb(positions, q, k)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Project to QKV, normalize and rotate Q/K, attend, and project out."""
        qkv, _ = self.qkv_proj(hidden_states)

        if envs.VLLM_USE_QKV_SPLIT_RMS_ROPE and positions.ndim == 1:
            # Single custom op that splits QKV and applies norm + RoPE.
            cos_sin_cache = self._synced_cos_sin_cache(qkv)
            q, k, v = torch.ops.vllm.qkv_split_rms_rotary_embedding_fuse(
                positions,
                qkv,
                self.q_size,
                self.kv_size,
                self.head_dim,
                cos_sin_cache,
                self.rotary_emb.is_neox_style,
                self.q_norm.weight,
                self.k_norm.weight,
                None,
                None,
                self.q_norm.variance_epsilon,
            )
        else:
            q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
            q, k = self._norm_and_rotate(positions, q, k)

        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output


class Qwen3MoeDecoderLayer(nn.Module):
    """One decoder layer: pre-norm self-attention followed by either a sparse
    MoE block or a dense MLP, with the residual carried alongside."""

    def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
        """Build the attention, feed-forward and layer-norm submodules.

        The layer index is parsed from ``prefix`` and decides whether the
        feed-forward part is sparse (MoE) or dense.
        """
        super().__init__()

        config = vllm_config.model_config.hf_text_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.hidden_size = config.hidden_size
        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
        dual_chunk_attention_config = getattr(
            config, "dual_chunk_attention_config", None
        )

        self.self_attn = Qwen3MoeAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            rope_parameters=config.rope_parameters,
            max_position_embeddings=max_position_embeddings,
            rms_norm_eps=config.rms_norm_eps,
            qkv_bias=getattr(config, "attention_bias", False),
            head_dim=getattr(config, "head_dim", None),
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
            dual_chunk_attention_config=dual_chunk_attention_config,
        )

        # `mlp_only_layers` in the config lists layers forced to a dense MLP.
        layer_idx = extract_layer_index(prefix)
        mlp_only_layers = getattr(config, "mlp_only_layers", [])

        # A layer is sparse when it is not forced dense, the model has
        # experts, and its 1-based index is a multiple of decoder_sparse_step.
        use_sparse_moe = (layer_idx not in mlp_only_layers) and (
            config.num_experts > 0 and (layer_idx + 1) % config.decoder_sparse_step == 0
        )
        mlp_prefix = f"{prefix}.mlp"
        if use_sparse_moe:
            self.mlp = Qwen3MoeSparseMoeBlock(
                vllm_config=vllm_config, prefix=mlp_prefix
            )
        else:
            self.mlp = Qwen3MoeMLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                prefix=mlp_prefix,
            )

        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run attention and the feed-forward block.

        Returns:
            ``(hidden_states, residual)`` to be passed to the next layer.
        """
        # Self Attention
        if residual is not None:
            # The two-argument norm call returns the normalized states
            # together with the updated residual.
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        else:
            # No residual supplied: the incoming states become the residual.
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)

        attn_out = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )

        # Fully Connected
        normed, residual = self.post_attention_layernorm(attn_out, residual)
        mlp_out = self.mlp(normed)
        return mlp_out, residual


555
556
557
558
559
560
561
562
563
@support_torch_compile(
    dynamic_arg_dims={
        "input_ids": 0,
        # positions is of shape (3, seq_len) if mrope is enabled,
        # otherwise (seq_len, ).
        "positions": -1,
        "intermediate_tensors": 0,
        "inputs_embeds": 0,
    })
564
565
566
567
class Qwen3MoeModel(nn.Module):
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the embedding, decoder layers and final norm, and record
        quantization and environment-driven settings on the instance."""
        super().__init__()

        config = vllm_config.model_config.hf_text_config
        quant_config = vllm_config.quant_config
        self.num_redundant_experts = (
            vllm_config.parallel_config.eplb_config.num_redundant_experts
        )

        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.config = config
        self.quant_config = quant_config

        self.quant_method = None
        if quant_config is not None:
            self.quant_method = quant_config.get_name()
            # NOTE: this mutates process-wide environment variables whenever
            # a quantization config is present.
            os.environ.update({"LLAMA_NN": "0", "LM_NN": "0"})

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=f"{prefix}.embed_tokens",
        )

        def _build_layer(prefix: str) -> Qwen3MoeDecoderLayer:
            # Factory handed to make_layers; one call per layer prefix.
            return Qwen3MoeDecoderLayer(vllm_config=vllm_config, prefix=prefix)

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            _build_layer,
            prefix=f"{prefix}.layers",
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size
        )

        # Track layers for auxiliary hidden state outputs (EAGLE3)
        self.aux_hidden_state_layers: tuple[int, ...] = ()

        self.tritonsingleton = W8a8GetCacheJSON()

        # Feature switches are enabled only by the exact string "1"; they are
        # read after the LLAMA_NN override above.
        self.use_llama_nn = os.getenv("LLAMA_NN") == "1"
        self.use_gemm_pad = os.getenv("GEMM_PAD") == "1"
        self.use_fa_pad = os.getenv("FA_PAD") == "1"
        self.use_awq_pad = os.getenv("AWQ_PAD") == "1"
        self.w8a8_strategy = envs.VLLM_W8A8_BACKEND
610

611
    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the result of ``self.embed_tokens`` applied to ``input_ids``."""
        token_embeddings = self.embed_tokens(input_ids)
        return token_embeddings

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]:
        """Run this pipeline stage's decoder layers.

        Returns:
            * ``IntermediateTensors`` with ``hidden_states`` and ``residual``
              when this is not the last pipeline rank;
            * ``(hidden_states, aux_hidden_states)`` when auxiliary hidden
              states were collected;
            * otherwise the final normalized ``hidden_states``.
        """
        pp_group = get_pp_group()

        if pp_group.is_first_rank:
            # Precomputed embeddings take precedence over token ids.
            hidden_states = (
                inputs_embeds
                if inputs_embeds is not None
                else self.embed_input_ids(input_ids)
            )
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        aux_hidden_states = []
        for layer_idx in range(self.start_layer, self.end_layer):
            # Collect auxiliary hidden states if specified; the snapshot is
            # taken *before* the layer runs.
            if layer_idx in self.aux_hidden_state_layers:
                if residual is not None:
                    aux_hidden_states.append(hidden_states + residual)
                else:
                    aux_hidden_states.append(hidden_states)
            decoder_layer = self.layers[layer_idx]
            hidden_states, residual = decoder_layer(positions, hidden_states, residual)

        if not pp_group.is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )

        hidden_states, _ = self.norm(hidden_states, residual)

        # Return auxiliary hidden states if collected
        if aux_hidden_states:
            return hidden_states, aux_hidden_states
        return hidden_states

656
657
658
    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Return the expert parameter mapping used when loading weights.

        Each entry is ``(param_name, weight_name, expert_id, shard_id)``.
        """
        # Params for weights, fp8 weight scales, fp8 activation scales
        expert_mapping = SharedFusedMoE.make_expert_params_mapping(
            self,
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.num_experts,
            num_redundant_experts=self.num_redundant_experts,
        )
        return expert_mapping
667

668
    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
669
670
671
672
673
674
675
676
677
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

678
        # Skip loading extra parameters for GPTQ/modelopt models.
679
680
681
682
683
684
685
686
687
688
689
690
        ignore_suffixes = (
            ".bias",
            "_bias",
            ".k_scale",
            "_k_scale",
            ".v_scale",
            "_v_scale",
            ".weight_scale",
            "_weight_scale",
            ".input_scale",
            "_input_scale",
        )
691

692
        params_dict = dict(self.named_parameters())
693
        loaded_params: set[str] = set()
694
        expert_params_mapping = self.get_expert_mapping()
695
        for name, loaded_weight in weights:
zhuwenwen's avatar
zhuwenwen committed
696
697
698
            if self.use_llama_nn:
                current_count = loaded_weight.current_count 
                total_count = loaded_weight.total_count
699
700
701
702
703
704
705
706
707
708
709
710
711
            if self.quant_config is not None and (
                scale_name := self.quant_config.get_cache_scale(name)
            ):
                # Loading kv cache quantization scales
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                assert loaded_weight.numel() == 1, (
                    f"KV scale numel {loaded_weight.numel()} != 1"
                )
                loaded_weight = loaded_weight.squeeze()
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue
712
            for param_name, weight_name, shard_id in stacked_params_mapping:
713
714
715
716
717
718
719
720
721
722
723
724
                # Skip non-stacked layers and experts (experts handled below).
                if weight_name not in name:
                    continue
                # We have mlp.experts[0].gate_proj in the checkpoint.
                # Since we handle the experts below in expert_params_mapping,
                # we need to skip here BEFORE we update the name, otherwise
                # name will be updated to mlp.experts[0].gate_up_proj, which
                # will then be updated below in expert_params_mapping
                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                if "mlp.experts" in name:
                    continue
                name = name.replace(weight_name, param_name)
725
726
727

                # Skip loading extra parameters for GPTQ/modelopt models.
                if name.endswith(ignore_suffixes) and name not in params_dict:
728
                    continue
729

730
731
732
                # Skip layers on other devices.
                if is_pp_missing_parameter(name, self):
                    continue
733
734
735
736
737
                if name.endswith("scale"):
                    # Remapping the name of FP8 kv-scale.
                    name = maybe_remap_kv_scale_name(name, params_dict)
                    if name is None:
                        continue
738
739
740
741
                if name not in params_dict:
                    continue

                param = params_dict[name]
742
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
743
744
745
746
                if weight_loader == default_weight_loader:
                    weight_loader(param, loaded_weight)
                else:
                    weight_loader(param, loaded_weight, shard_id)
747
748
                break
            else:
749
                is_expert_weight = False
750
751
752
753
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
754
755
756
757
758
759
760
761
762
763

                    # Anyway, this is an expert weight and should not be
                    # attempted to load as other weights later
                    is_expert_weight = True

                    # Do not modify `name` since the loop may continue here
                    # Instead, create a new variable
                    name_mapped = name.replace(weight_name, param_name)

                    if is_pp_missing_parameter(name_mapped, self):
764
                        continue
765

766
                    # Skip loading extra parameters for GPTQ/modelopt models.
767
768
769
770
                    if (
                        name_mapped.endswith(ignore_suffixes)
                        and name_mapped not in params_dict
                    ):
771
                        continue
772
773
774
775
776

                    param = params_dict[name_mapped]
                    # We should ask the weight loader to return success or not
                    # here since otherwise we may skip experts with other
                    # available replicas.
777
778
779
780
781
782
783
784
785
786
787
                    weight_loader = typing.cast(
                        Callable[..., bool], param.weight_loader
                    )
                    success = weight_loader(
                        param,
                        loaded_weight,
                        name_mapped,
                        shard_id=shard_id,
                        expert_id=expert_id,
                        return_success=True,
                    )
788
789
790
                    if success:
                        name = name_mapped
                        break
791
                else:
792
793
794
795
796
797
                    if is_expert_weight:
                        # We've checked that this is an expert weight
                        # However it's not mapped locally to this rank
                        # So we simply skip it
                        continue

798
                    # Skip loading extra parameters for GPTQ/modelopt models.
799
                    if name.endswith(ignore_suffixes) and name not in params_dict:
800
801
802
803
804
805
806
                        continue
                    # Skip layers on other devices.
                    if is_pp_missing_parameter(name, self):
                        continue
                    # Remapping the name of FP8 kv-scale.
                    if name.endswith("kv_scale"):
                        remapped_kv_scale_name = name.replace(
807
808
                            ".kv_scale", ".attn.kv_scale"
                        )
809
810
                        if remapped_kv_scale_name not in params_dict:
                            logger.warning_once(
811
812
813
814
                                "Found kv scale in the checkpoint (e.g. %s), but not found the expected name in the model (e.g. %s). kv-scale is not loaded.",  # noqa: E501
                                name,
                                remapped_kv_scale_name,
                            )
815
816
817
818
                            continue
                        else:
                            name = remapped_kv_scale_name
                    param = params_dict[name]
819
820
821
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
822
823
                    weight_loader(param, loaded_weight)
            loaded_params.add(name)
zhuwenwen's avatar
zhuwenwen committed
824
        
zhuwenwen's avatar
zhuwenwen committed
825
        if self.use_llama_nn and self.quant_method is None and current_count==total_count:
zhuwenwen's avatar
zhuwenwen committed
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
            lay_key_words = [
                "gate_up_proj.weight",
                "down_proj.weight",
                "mlp.gate.weight",
                "self_attn.qkv_proj.weight",
                "self_attn.o_proj.weight",
                "lm_head.weight",
            ]
            combined_words = "|".join(lay_key_words)
            
            # lay_qkv_words = ["self_attn.qkv_proj.weight"]   
            # qkv_words = "|".join(lay_qkv_words)  
            
            # lay_qkv_bias_words = ["self_attn.qkv_proj.bias"]   
            # qkv_bias_words = "|".join(lay_qkv_bias_words) 
            
zhuwenwen's avatar
zhuwenwen committed
842
843
            # for layername in loaded_params:
            for layername in params_dict.keys():
zhuwenwen's avatar
zhuwenwen committed
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
                weight = params_dict[layername]
                os.environ['LM_NN'] = '0' 
                # if self.use_fa_pad and (re.findall(qkv_bias_words, layername)):
                #     weight.data = pad_weight(weight.data, 32)
                    
                matches = re.findall(combined_words, layername)
                if matches:   
                    # if self.use_gemm_pad and gemm_bank_conf(weight.data.shape[0]):
                    #     weight.data = pad_weight(weight.data, 32)  
                    
                    # if self.use_fa_pad and (re.findall(qkv_words, layername)):
                    #     if not gemm_bank_conf(weight.data.shape[0]):
                    #         weight.data = pad_weight(weight.data, 32)
                        
                    _weight = torch.zeros_like(weight.data)
                    ori_shape =_weight.shape
                    
                    ops.trans_w16_gemm(_weight, weight.data, _weight.shape[0], _weight.shape[1])
                    weight.data.copy_(_weight)
                    
                    weight.data=weight.data.reshape(ori_shape[1],-1)
865
        return loaded_params
866
867


868
869
870
class Qwen3MoeForCausalLM(
    nn.Module, SupportsPP, SupportsLoRA, SupportsEagle3, MixtureOfExperts
):
    """Qwen3-MoE causal language model.

    Wraps a ``Qwen3MoeModel`` backbone with a ``ParallelLMHead`` and a
    ``LogitsProcessor``, and records the MoE bookkeeping attributes
    (expert counts, list of per-layer expert modules) gathered from the
    sparse MoE blocks found in ``self.model.layers``.
    """

    # Maps a fused module name to the per-projection names it packs.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ]
    }

    fall_back_to_pt_during_load = False

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the backbone, LM head and MoE metadata.

        Args:
            vllm_config: Global vLLM configuration; the HF text config and
                the quantization config are read from it.
            prefix: Module-name prefix passed through ``maybe_prefix`` when
                naming the ``model`` and ``lm_head`` submodules.

        Raises:
            RuntimeError: If no layer in ``self.model.layers`` has a
                ``Qwen3MoeSparseMoeBlock`` as its ``mlp``.
        """
        super().__init__()
        config = vllm_config.model_config.hf_text_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config

        # Only perform the following mapping when Qwen3MoeMLP exists
        # NOTE(review): this is an item assignment on the class-level dict,
        # so the new key is visible to every instance and to anything else
        # holding a reference to that dict. Presumably intentional; confirm
        # before replacing it with a per-instance copy.
        if getattr(config, "mlp_only_layers", []):
            self.packed_modules_mapping["gate_up_proj"] = ["gate_proj", "up_proj"]

        self.model = Qwen3MoeModel(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        if self.config.tie_word_embeddings:
            # Share the embedding matrix with the output projection.
            self.lm_head.weight = self.model.embed_tokens.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

        # Set MoE hyperparameters
        self.expert_weights = []

        self.moe_layers = []
        example_layer = None
        for layer in self.model.layers:
            # Placeholder layers are skipped; they carry no MoE block.
            if isinstance(layer, PPMissingLayer):
                continue

            assert isinstance(layer, Qwen3MoeDecoderLayer)
            if isinstance(layer.mlp, Qwen3MoeSparseMoeBlock):
                # The last sparse block seen is used as the template for the
                # expert counts copied onto this module below.
                example_layer = layer.mlp
                self.moe_layers.append(layer.mlp.experts)

        if example_layer is None:
            raise RuntimeError("No Qwen3MoE layer found in the model.layers.")

        self.num_moe_layers = len(self.moe_layers)
        self.num_expert_groups = 1
        self.num_shared_experts = 0
        self.num_logical_experts = example_layer.n_logical_experts
        self.num_physical_experts = example_layer.n_physical_experts
        self.num_local_physical_experts = example_layer.n_local_physical_experts
        self.num_routed_experts = example_layer.n_routed_experts
        self.num_redundant_experts = example_layer.n_redundant_experts

    def update_physical_experts_metadata(
        self,
        num_physical_experts: int,
        num_local_physical_experts: int,
    ) -> None:
        """Propagate new physical-expert counts to this module and its MoE blocks.

        Args:
            num_physical_experts: New total number of physical experts.
            num_local_physical_experts: Number of local physical experts; must
                equal the value already stored on this module.

        The redundant-expert count is recomputed as
        ``num_physical_experts - self.num_logical_experts`` and every sparse
        MoE block then has ``update_expert_map()`` called on its experts.
        """
        assert self.num_local_physical_experts == num_local_physical_experts
        self.num_physical_experts = num_physical_experts
        self.num_local_physical_experts = num_local_physical_experts
        self.num_redundant_experts = num_physical_experts - self.num_logical_experts
        for layer in self.model.layers:
            # Mirror the guard used in __init__: placeholder layers are
            # skipped before touching ``layer.mlp``.
            if isinstance(layer, PPMissingLayer):
                continue
            if isinstance(layer.mlp, Qwen3MoeSparseMoeBlock):
                moe = layer.mlp
                moe.n_local_physical_experts = num_local_physical_experts
                moe.n_physical_experts = num_physical_experts
                moe.n_redundant_experts = self.num_redundant_experts
                moe.experts.update_expert_map()

    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
        """Store ``layers`` on the backbone as ``aux_hidden_state_layers``."""
        self.model.aux_hidden_state_layers = layers

    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
        """Return the default layer indices ``(2, n // 2, n - 3)``.

        ``n`` is ``len(self.model.layers)``.
        """
        num_layers = len(self.model.layers)
        return (2, num_layers // 2, num_layers - 3)

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate token embedding to the backbone model."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the backbone and return its output unchanged.

        All four arguments are forwarded positionally to ``self.model``.
        """
        hidden_states = self.model(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Apply the logits processor with the LM head to ``hidden_states``."""
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load ``(name, tensor)`` pairs through ``AutoWeightsLoader``.

        Returns whatever the loader's ``load_weights`` returns.
        """
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Return the expert parameter mapping provided by the backbone."""
        return self.model.get_expert_mapping()