qwen3_moe.py 40.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24

# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen3MoE model compatible with HuggingFace weights."""
25
26
import typing
from collections.abc import Callable, Iterable
27
from itertools import islice
28
from typing import Any, Optional, Union
29

zhuwenwen's avatar
zhuwenwen committed
30
31
import os
import re
32
33
34
35
36
import torch
from torch import nn

from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
37
38
from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
from vllm.distributed import (get_ep_group, get_pp_group,
39
40
                              get_tensor_model_parallel_world_size,
                              tensor_model_parallel_all_gather)
41
42
43
44
45
46
47
48
49
50
51
52
53
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               ReplicatedLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
54
55
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader, maybe_remap_kv_scale_name)
56
from vllm.model_executor.models.utils import sequence_parallel_chunk
57
58
from vllm.sequence import IntermediateTensors

59
from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
60
from vllm.utils import direct_register_custom_op
61
from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
62
                    is_pp_missing_parameter,
63
64
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)
zhuwenwen's avatar
zhuwenwen committed
65
import vllm.envs as envs
zhuwenwen's avatar
zhuwenwen committed
66
67
68
from vllm import _custom_ops as ops
from vllm.model_executor.utils import pad_weight, gemm_bank_conf
from vllm.utils import W8a8GetCacheJSON
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111

# Module-level logger named after this module (via vLLM's init_logger).
logger = init_logger(__name__)


class Qwen3MoeMLP(nn.Module):
    """Dense gated feed-forward block: gate/up projection, SiLU gating, down
    projection.

    Used by decoder layers that are not sparse-MoE layers. The gate and up
    projections share a single merged column-parallel linear layer whose
    output is combined by ``SiluAndMul``; ``reduce_results`` is forwarded to
    the row-parallel ``down_proj``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        reduce_results: bool = True,
        prefix: str = "",
    ) -> None:
        super().__init__()
        gate_up_prefix = f"{prefix}.gate_up_proj"
        down_prefix = f"{prefix}.down_proj"

        # One matmul yields both halves (gate and up), each of
        # `intermediate_size` columns.
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size, intermediate_size],
            bias=False,
            quant_config=quant_config,
            prefix=gate_up_prefix,
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            reduce_results=reduce_results,
            prefix=down_prefix,
        )

        # Only SiLU is accepted as the gating activation.
        if hidden_act == "silu":
            self.act_fn = SiluAndMul()
        else:
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")

    def forward(self, x):
        """Return ``down_proj(act_fn(gate_up_proj(x)))``.

        The second value returned by each projection layer is ignored.
        """
        gate_up_out, _ = self.gate_up_proj(x)
        projected, _ = self.down_proj(self.act_fn(gate_up_out))
        return projected


class Qwen3MoeSparseMoeBlock(nn.Module):
    """Top-k routed mixture-of-experts feed-forward block.

    A replicated linear ``gate`` produces per-expert router logits and a
    ``FusedMoE`` layer runs the selected experts. Optionally supports expert
    parallel load balancing (EPLB, with redundant physical experts) and
    sequence-parallel execution of the MoE.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        super().__init__()

        config = vllm_config.model_config.hf_text_config
        parallel_config = vllm_config.parallel_config
        quant_config = vllm_config.quant_config

        self.tp_size = get_tensor_model_parallel_world_size()

        self.ep_group = get_ep_group().device_group
        self.ep_rank = self.ep_group.rank()
        self.ep_size = self.ep_group.size()
        self.n_routed_experts = config.num_experts

        self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe

        if self.tp_size > config.num_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {config.num_experts}.")

        # Load balancing settings. Both `enable_eplb` and the EPLB sub-config
        # are read from the *passed-in* parallel_config (as Qwen3MoeModel
        # does) instead of the context-global config, so the two values can
        # never come from different configs.
        eplb_config = parallel_config.eplb_config
        self.enable_eplb = parallel_config.enable_eplb

        self.n_logical_experts = self.n_routed_experts
        self.n_redundant_experts = eplb_config.num_redundant_experts
        self.n_physical_experts = (self.n_logical_experts +
                                   self.n_redundant_experts)
        self.n_local_physical_experts = self.n_physical_experts // self.ep_size

        # Half-open range [start, end) of physical experts owned by this rank.
        self.physical_expert_start = (self.ep_rank *
                                      self.n_local_physical_experts)
        self.physical_expert_end = (self.physical_expert_start +
                                    self.n_local_physical_experts)

        self.experts = FusedMoE(num_experts=self.n_routed_experts,
                                top_k=config.num_experts_per_tok,
                                hidden_size=config.hidden_size,
                                intermediate_size=config.moe_intermediate_size,
                                reduce_results=True,
                                renormalize=config.norm_topk_prob,
                                quant_config=quant_config,
                                prefix=f"{prefix}.experts",
                                enable_eplb=self.enable_eplb,
                                num_redundant_experts=self.n_redundant_experts,
                                is_sequence_parallel=self.is_sequence_parallel)

        self.gate = ReplicatedLinear(config.hidden_size,
                                     config.num_experts,
                                     bias=False,
                                     quant_config=quant_config,
                                     prefix=f"{prefix}.gate")

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route ``hidden_states`` through the experts.

        Args:
            hidden_states: a 2D ``(num_tokens, hidden_dim)`` batch, or a
                single 1D ``(hidden_dim,)`` token.

        Returns:
            The expert output; squeezed back to 1D when the input was 1D.
        """
        assert 1 <= hidden_states.dim() <= 2, (
            "Qwen3MoeSparseMoeBlock only supports 1D or 2D inputs")
        is_input_1d = hidden_states.dim() == 1
        # Flatten first, then read the token count: unpacking
        # `hidden_states.shape` into two names would raise for 1D input.
        hidden_dim = hidden_states.shape[-1]
        hidden_states = hidden_states.view(-1, hidden_dim)
        num_tokens = hidden_states.shape[0]

        if self.is_sequence_parallel:
            hidden_states = sequence_parallel_chunk(hidden_states)

        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(hidden_states)
        final_hidden_states = self.experts(hidden_states=hidden_states,
                                           router_logits=router_logits)

        if self.is_sequence_parallel:
            final_hidden_states = tensor_model_parallel_all_gather(
                final_hidden_states, 0)
            # The gathered result may hold more rows than the input had;
            # keep only the original tokens.
            final_hidden_states = final_hidden_states[:num_tokens]

        # return to 1d if input is 1d
        return (final_hidden_states.squeeze(0)
                if is_input_1d else final_hidden_states)


class Qwen3MoeAttention(nn.Module):
    """Tensor-parallel self-attention with per-head Q/K RMSNorm and RoPE.

    Query and key/value head counts are configured separately; KV heads are
    partitioned across TP ranks when there are at least as many as ranks,
    otherwise replicated. QK-norm + rotary embedding is applied by one of
    three paths selected in ``forward``:

    * fused RMSNorm + RoPE custom op (``VLLM_USE_FUSED_RMS_ROPE`` set and 1D
      ``positions``);
    * fused RMSNorm + M-RoPE custom op (``VLLM_USE_FUSED_RMS_ROPE`` set, 2D
      ``positions`` and a rotary embedding exposing ``mrope_section``);
    * the unfused path: per-head RMSNorm, then ``self.rotary_emb``.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        head_dim: Optional[int] = None,
        rms_norm_eps: float = 1e-06,
        qkv_bias: bool = False,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        dual_chunk_attention_config: Optional[dict[str, Any]] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = head_dim or (hidden_size // self.total_num_heads)
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings
        self.dual_chunk_attention_config = dual_chunk_attention_config

        self.qkv_proj = QKVParallelLinear(hidden_size,
                                          self.head_dim,
                                          self.total_num_heads,
                                          self.total_num_kv_heads,
                                          bias=qkv_bias,
                                          quant_config=quant_config,
                                          prefix=f"{prefix}.qkv_proj")

        self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim,
                                        hidden_size,
                                        bias=False,
                                        quant_config=quant_config,
                                        prefix=f"{prefix}.o_proj")

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
            dual_chunk_attention_config=dual_chunk_attention_config,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
            **{
                "layer_idx": extract_layer_index(prefix),
                "dual_chunk_attention_config": dual_chunk_attention_config,
            } if dual_chunk_attention_config else {},
        )

        self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
        self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)

    # The functions below take no `self`: they are plain functions defined
    # in the class body solely so they can be registered as torch custom ops
    # (torch.ops.vllm.*) when the class is created. Their signatures and
    # annotations define the op schemas, so keep them unchanged.

    def rms_rotary_embedding_fuse(
        positions: torch.Tensor,
        query: torch.Tensor,
        key: Optional[torch.Tensor],
        head_size: int,
        cos_sin_cache: torch.Tensor,
        is_neox_style: bool,
        q_weight: torch.Tensor,
        k_weight: torch.Tensor,
        q_residual: Optional[torch.Tensor],
        k_residual: Optional[torch.Tensor],
        epsilon: float,
    ) -> None:
        """Fused Q/K RMSNorm + RoPE, mutating ``query`` and ``key`` in place.

        The kernel is chosen by ``VLLM_FUSED_RMS_ROPE_BACKEND``:

        * ``"lightop"`` -- ``lightop.rms_rotary_embedding_fuse``.
        * ``"vllm"`` -- ``torch.ops._C.rms_rotary_embedding_fuse``; a failure
          to import ``vllm._C`` is re-raised.
        * ``"auto"`` -- the ``vllm._C`` kernel when that op is available,
          otherwise lightop.

        Raises:
            ValueError: if the backend is not one of the values above.
        """
        backend = envs.VLLM_FUSED_RMS_ROPE_BACKEND
        if backend not in ("auto", "vllm", "lightop"):
            raise ValueError(
                "VLLM_FUSED_RMS_ROPE_BACKEND must be one of "
                "('auto', 'vllm', 'lightop'), got: %r" % backend)

        use_lightop = backend == "lightop"
        if not use_lightop:
            try:
                import vllm._C  # noqa: F401
            except Exception:
                # In "auto" mode a missing extension only means we fall back
                # to lightop below; in "vllm" mode it is fatal.
                if backend == "vllm":
                    raise
            use_lightop = backend == "auto" and not hasattr(
                torch.ops._C, "rms_rotary_embedding_fuse")

        if use_lightop:
            from lightop import rms_rotary_embedding_fuse as fused_kernel
        else:
            fused_kernel = torch.ops._C.rms_rotary_embedding_fuse

        fused_kernel(
            positions,
            query,
            key,
            head_size,
            cos_sin_cache,
            is_neox_style,
            q_weight,
            k_weight,
            q_residual,
            k_residual,
            epsilon,
        )

    def rms_rotary_embedding_fuse_fake(
        positions: torch.Tensor,
        query: torch.Tensor,
        key: Optional[torch.Tensor],
        head_size: int,
        cos_sin_cache: torch.Tensor,
        is_neox_style: bool,
        q_weight: torch.Tensor,
        k_weight: torch.Tensor,
        q_residual: Optional[torch.Tensor],
        k_residual: Optional[torch.Tensor],
        epsilon: float,
    ) -> None:
        # Fake impl intentionally left as no-op for graph tracing modes: the
        # real op only mutates its inputs and returns nothing.
        pass

    # Guarded so that a second definition of this class (e.g. a module
    # reload) does not try to register the same op twice.
    if not hasattr(torch.ops.vllm, "rms_rotary_embedding_fuse"):
        direct_register_custom_op(
            op_name="rms_rotary_embedding_fuse",
            op_func=rms_rotary_embedding_fuse,
            mutates_args=["query", "key"],
            fake_impl=rms_rotary_embedding_fuse_fake,
        )

    def rms_mrope_fuse(
        query: torch.Tensor,
        key: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        head_size: int,
        rotary_dim: int,
        mrope_section_t: int,
        mrope_section_h: int,
        mrope_section_w: int,
        mrope_interleaved: bool,
        q_weight: torch.Tensor,
        k_weight: torch.Tensor,
        q_residual: Optional[torch.Tensor],
        k_residual: Optional[torch.Tensor],
        epsilon: float,
    ) -> None:
        """Fused Q/K RMSNorm + M-RoPE via ``lightop``, in place on
        ``query``/``key``.

        The three section sizes are taken as separate ints and packed into a
        list for ``lightop.op.fuse_rms_mrope_cuda``.
        """
        from lightop import op as lightop_ops
        lightop_ops.fuse_rms_mrope_cuda(
            query,
            key,
            cos,
            sin,
            [mrope_section_t, mrope_section_h, mrope_section_w],
            head_size,
            rotary_dim,
            mrope_interleaved,
            q_weight,
            k_weight,
            q_residual,
            k_residual,
            epsilon,
        )

    def rms_mrope_fuse_fake(
        query: torch.Tensor,
        key: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        head_size: int,
        rotary_dim: int,
        mrope_section_t: int,
        mrope_section_h: int,
        mrope_section_w: int,
        mrope_interleaved: bool,
        q_weight: torch.Tensor,
        k_weight: torch.Tensor,
        q_residual: Optional[torch.Tensor],
        k_residual: Optional[torch.Tensor],
        epsilon: float,
    ) -> None:
        # Fake impl intentionally left as no-op for graph tracing modes.
        pass

    # Same duplicate-registration guard as for rms_rotary_embedding_fuse.
    if not hasattr(torch.ops.vllm, "rms_mrope_fuse"):
        direct_register_custom_op(
            op_name="rms_mrope_fuse",
            op_func=rms_mrope_fuse,
            mutates_args=["query", "key"],
            fake_impl=rms_mrope_fuse_fake,
        )

    def _cos_sin_cache_like(self, ref: torch.Tensor) -> torch.Tensor:
        """Return the rotary cos/sin cache on ``ref``'s device and dtype.

        A converted cache is written back to ``self.rotary_emb`` so the copy
        happens once instead of on every forward (e.g. when the original
        buffer starts on CPU).
        """
        cache = self.rotary_emb.cos_sin_cache
        if cache.device != ref.device or cache.dtype != ref.dtype:
            cache = cache.to(ref.device, dtype=ref.dtype, non_blocking=True)
            self.rotary_emb.cos_sin_cache = cache
        return cache

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Compute attention output for ``hidden_states`` at ``positions``.

        ``positions`` is 1D for standard RoPE; a 2D tensor selects the M-RoPE
        fused path when the rotary embedding has ``mrope_section``.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        # Add qk-norm, then RoPE.
        # NOTE(review): both fused paths read `cos_sin_cache` (and the 1D
        # path `is_neox_style`) straight off `self.rotary_emb`; confirm every
        # rotary class get_rope can return here exposes them.
        if envs.VLLM_USE_FUSED_RMS_ROPE and positions.ndim == 1:
            # Fused RMSNorm + RoPE custom op; q and k are updated in place.
            torch.ops.vllm.rms_rotary_embedding_fuse(
                positions,
                q,
                k,
                self.head_dim,
                self._cos_sin_cache_like(q),
                self.rotary_emb.is_neox_style,
                self.q_norm.weight,
                self.k_norm.weight,
                None,
                None,
                self.q_norm.variance_epsilon,
            )
        elif envs.VLLM_USE_FUSED_RMS_ROPE and positions.ndim == 2 and getattr(
                self.rotary_emb, "mrope_section", None) is not None:
            # Fused RMSNorm + M-RoPE custom op; cos/sin are gathered per
            # position and all kernel inputs are made contiguous first.
            cos_sin = self._cos_sin_cache_like(q)[positions]
            cos, sin = cos_sin.chunk(2, dim=-1)

            q = q.contiguous()
            k = k.contiguous()
            cos = cos.contiguous()
            sin = sin.contiguous()
            mrope_section = self.rotary_emb.mrope_section
            assert mrope_section is not None and len(mrope_section) == 3
            torch.ops.vllm.rms_mrope_fuse(
                q,
                k,
                cos,
                sin,
                self.head_dim,
                self.rotary_emb.rotary_dim,
                mrope_section[0],
                mrope_section[1],
                mrope_section[2],
                self.rotary_emb.mrope_interleaved,
                self.q_norm.weight,
                self.k_norm.weight,
                None,
                None,
                self.q_norm.variance_epsilon,
            )
        else:
            # Add qk-norm then RoPE (original path).
            q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
                               self.head_dim)
            if envs.VLLM_USE_APEX_RN:
                q_by_head = self.q_norm.forward_apex(q_by_head)
            else:
                q_by_head = self.q_norm.forward_cuda(q_by_head)
            q = q_by_head.view(q.shape)

            k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim,
                               self.head_dim)
            if envs.VLLM_USE_APEX_RN:
                k_by_head = self.k_norm.forward_apex(k_by_head)
            else:
                k_by_head = self.k_norm.forward_cuda(k_by_head)
            k = k_by_head.view(k.shape)
            q, k = self.rotary_emb(positions, q, k)

        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output


class Qwen3MoeDecoderLayer(nn.Module):
    """One transformer block: pre-norm self-attention followed by either a
    sparse MoE block or a dense MLP."""

    def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__()

        config = vllm_config.model_config.hf_text_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        dual_chunk_attention_config = getattr(config,
                                              "dual_chunk_attention_config",
                                              None)
        self.self_attn = Qwen3MoeAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            rms_norm_eps=config.rms_norm_eps,
            qkv_bias=getattr(config, 'attention_bias', False),
            head_dim=getattr(config, 'head_dim', None),
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
            dual_chunk_attention_config=dual_chunk_attention_config,
        )

        # `mlp_only_layers` in the config lists layers forced to a dense MLP;
        # a missing *or* None value means "no such layers" (a None value
        # would otherwise make the `in` test below raise TypeError).
        layer_idx = extract_layer_index(prefix)
        mlp_only_layers = getattr(config, "mlp_only_layers", None) or []
        # Sparse MoE on every `decoder_sparse_step`-th layer, if the model
        # has experts at all.
        is_sparse_layer = (
            layer_idx not in mlp_only_layers and config.num_experts > 0
            and (layer_idx + 1) % config.decoder_sparse_step == 0)
        if is_sparse_layer:
            self.mlp = Qwen3MoeSparseMoeBlock(vllm_config=vllm_config,
                                              prefix=f"{prefix}.mlp")
        else:
            self.mlp = Qwen3MoeMLP(hidden_size=config.hidden_size,
                                   intermediate_size=config.intermediate_size,
                                   hidden_act=config.hidden_act,
                                   quant_config=quant_config,
                                   prefix=f"{prefix}.mlp")
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run attention then the MLP.

        ``residual`` is None for the first layer; afterwards the layernorms
        are called as ``norm(hidden_states, residual)`` and return
        ``(normalized, residual)``. Returns the MLP output and the residual
        to pass to the next layer.
        """
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


@support_torch_compile(
    dynamic_arg_dims={
        "input_ids": 0,
        # positions is of shape (3, seq_len) if mrope is enabled,
        # otherwise (seq_len, ).
        "positions": -1,
        "intermediate_tensors": 0,
        "inputs_embeds": 0,
    })
class Qwen3MoeModel(nn.Module):
    """Qwen3-MoE decoder stack: token embedding, decoder layers, final norm.

    Supports pipeline parallelism (only layers ``[start_layer, end_layer)``
    are run on this rank) and an optional ``LLAMA_NN`` post-load weight
    relayout for unquantized checkpoints.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_text_config
        quant_config = vllm_config.quant_config
        parallel_config = vllm_config.parallel_config
        eplb_config = parallel_config.eplb_config
        self.num_redundant_experts = eplb_config.num_redundant_experts

        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.config = config

        # Always define ``quant_config`` (it used to exist only for quantized
        # models, so reading it on an unquantized model raised).
        self.quant_config = quant_config
        self.quant_method = None
        if quant_config is not None:
            self.quant_method = quant_config.get_name()
            # The LLAMA_NN relayout path is only used for unquantized
            # weights, so force it off (process-wide) when quantized.
            os.environ['LLAMA_NN'] = '0'
            os.environ['LM_NN'] = '0'

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=f"{prefix}.embed_tokens")
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: Qwen3MoeDecoderLayer(vllm_config=vllm_config,
                                                prefix=prefix),
            prefix=f"{prefix}.layers",
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

        self.tritonsingleton = W8a8GetCacheJSON()

        # Backend switches read from the environment (after the quantized
        # override above). Within this class only ``use_llama_nn`` is read;
        # the padding flags are stored but not used here.
        self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
        self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
        self.use_fa_pad = os.environ.get('FA_PAD') == '1'
        self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'
        self.w8a8_strategy = envs.VLLM_W8A8_BACKEND

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's slice of the decoder stack.

        The first PP rank starts from embeddings (``inputs_embeds`` if given,
        otherwise looked up from ``input_ids``); later ranks resume from
        ``intermediate_tensors``. Non-last ranks return the
        ``hidden_states``/``residual`` pair; the last rank returns the
        final-normed hidden states.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        for layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states, residual = layer(positions, hidden_states, residual)
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        return FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.num_experts,
            num_redundant_experts=self.num_redundant_experts)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this model's parameters.

        Each checkpoint tensor is routed through one of three paths:

        1. Stacked shards (``q_proj``/``k_proj``/``v_proj`` -> ``qkv_proj``,
           ``gate_proj``/``up_proj`` -> ``gate_up_proj``), non-expert only.
        2. Expert weights matched via ``get_expert_mapping()``; the loader
           is asked to report success so experts not mapped to this rank
           are skipped.
        3. Everything else, loaded by name (with FP8 ``kv_scale`` renaming).

        When ``use_llama_nn`` is set and the model is unquantized, selected
        weights are re-laid-out afterwards by ``_relayout_weights_for_nn``,
        but only if the last tensor seen had
        ``current_count == total_count`` (presumably "final shard", as set
        by the custom weight iterator -- TODO confirm).

        Args:
            weights: ``(name, tensor)`` pairs from the checkpoint. On the
                LLAMA_NN path each tensor must carry ``current_count`` and
                ``total_count`` attributes.

        Returns:
            The set of parameter names that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        # Skip loading extra parameters for GPTQ/modelopt models.
        ignore_suffixes = (".bias", "_bias", ".k_scale", "_k_scale",
                           ".v_scale", "_v_scale", ".weight_scale",
                           "_weight_scale", ".input_scale", "_input_scale")

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        expert_params_mapping = self.get_expert_mapping()
        # Tracks the last tensor's shard counters. Initialised here so an
        # empty ``weights`` iterable cannot leave it unbound (previously an
        # UnboundLocalError on the LLAMA_NN path).
        is_last_shard = False
        for name, loaded_weight in weights:
            if self.use_llama_nn:
                is_last_shard = (loaded_weight.current_count ==
                                 loaded_weight.total_count)

            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                # Skip non-stacked layers and experts (experts handled below).
                if weight_name not in name:
                    continue
                # We have mlp.experts[0].gate_proj in the checkpoint.
                # Since we handle the experts below in expert_params_mapping,
                # we need to skip here BEFORE we update the name, otherwise
                # name will be updated to mlp.experts[0].gate_up_proj, which
                # will then be updated below in expert_params_mapping
                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                if "mlp.experts" in name:
                    continue
                name = name.replace(weight_name, param_name)

                # Skip loading extra parameters for GPTQ/modelopt models.
                if name.endswith(ignore_suffixes) and name not in params_dict:
                    continue

                # Skip layers on other devices.
                if is_pp_missing_parameter(name, self):
                    continue

                if name.endswith("scale"):
                    # Remapping the name of FP8 kv-scale.
                    name = maybe_remap_kv_scale_name(name, params_dict)
                    if name is None:
                        continue

                if name not in params_dict:
                    continue

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                if weight_loader == default_weight_loader:
                    weight_loader(param, loaded_weight)
                else:
                    weight_loader(param, loaded_weight, shard_id)
                break
            else:
                is_expert_weight = False
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue

                    # Anyway, this is an expert weight and should not be
                    # attempted to load as other weights later
                    is_expert_weight = True

                    # Do not modify `name` since the loop may continue here
                    # Instead, create a new variable
                    name_mapped = name.replace(weight_name, param_name)

                    if is_pp_missing_parameter(name_mapped, self):
                        continue

                    # Skip loading extra parameters for GPTQ/modelopt models.
                    if name_mapped.endswith(
                            ignore_suffixes
                    ) and name_mapped not in params_dict:
                        continue

                    param = params_dict[name_mapped]
                    # We should ask the weight loader to return success or not
                    # here since otherwise we may skip experts with other
                    # available replicas.
                    weight_loader = typing.cast(Callable[..., bool],
                                                param.weight_loader)
                    success = weight_loader(param,
                                            loaded_weight,
                                            name_mapped,
                                            shard_id=shard_id,
                                            expert_id=expert_id,
                                            return_success=True)
                    if success:
                        name = name_mapped
                        break
                else:
                    if is_expert_weight:
                        # We've checked that this is an expert weight
                        # However it's not mapped locally to this rank
                        # So we simply skip it
                        continue

                    # Skip loading extra parameters for GPTQ/modelopt models.
                    if name.endswith(
                            ignore_suffixes) and name not in params_dict:
                        continue
                    # Skip layers on other devices.
                    if is_pp_missing_parameter(name, self):
                        continue
                    # Remapping the name of FP8 kv-scale.
                    if name.endswith("kv_scale"):
                        remapped_kv_scale_name = name.replace(
                            ".kv_scale", ".attn.kv_scale")
                        if remapped_kv_scale_name not in params_dict:
                            logger.warning_once(
                                "Found kv scale in the checkpoint (e.g. %s), but not found the expected name in the model (e.g. %s). kv-scale is not loaded.",  # noqa: E501
                                name,
                                remapped_kv_scale_name,
                            )
                            continue
                        else:
                            name = remapped_kv_scale_name
                    param = params_dict[name]
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)
            loaded_params.add(name)

        if self.use_llama_nn and self.quant_method is None and is_last_shard:
            self._relayout_weights_for_nn(params_dict)
        return loaded_params

    def _relayout_weights_for_nn(
            self, params_dict: dict[str, nn.Parameter]) -> None:
        """Re-lay-out selected 2-D weights in place via ``ops.trans_w16_gemm``.

        For every parameter whose name contains one of the keywords below,
        ``ops.trans_w16_gemm`` writes into a zero-initialised buffer of the
        same shape. The result is copied back, and the parameter data is then
        reshaped to ``(original_cols, -1)``. This is presumably a transposed
        layout expected by the NN GEMM backend -- TODO confirm against ops.
        """
        # Plain substring keywords. The previous regex join left "." unescaped,
        # so it matched any character there.
        keywords = (
            "gate_up_proj.weight",
            "down_proj.weight",
            "mlp.gate.weight",
            "self_attn.qkv_proj.weight",
            "self_attn.o_proj.weight",
            "lm_head.weight",
        )
        # Loop-invariant: previously re-set once per parameter.
        os.environ['LM_NN'] = '0'
        for layername, weight in params_dict.items():
            if not any(key in layername for key in keywords):
                continue
            _weight = torch.zeros_like(weight.data)
            ori_shape = _weight.shape
            ops.trans_w16_gemm(_weight, weight.data, ori_shape[0],
                               ori_shape[1])
            weight.data.copy_(_weight)
            weight.data = weight.data.reshape(ori_shape[1], -1)


class Qwen3MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA,
                          MixtureOfExperts):
    """Qwen3-MoE causal language model: ``Qwen3MoeModel`` plus an LM head.

    Also records the MoE metadata and implements the hooks
    (``set_eplb_state``, ``update_physical_experts_metadata``) used for
    expert load balancing over the ``FusedMoE`` layers on this rank.
    """

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    fall_back_to_pt_during_load = False

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_text_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config
        self.model = Qwen3MoeModel(vllm_config=vllm_config,
                                   prefix=maybe_prefix(prefix, "model"))
        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      quant_config=quant_config,
                                      prefix=maybe_prefix(prefix, "lm_head"))
        if self.config.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

        # Set MoE hyperparameters
        self.expert_weights = []

        self.moe_layers: list[FusedMoE] = []
        example_layer = None
        for layer in self.model.layers:
            # Layers owned by other pipeline stages are placeholders.
            if isinstance(layer, PPMissingLayer):
                continue

            assert isinstance(layer, Qwen3MoeDecoderLayer)
            if isinstance(layer.mlp, Qwen3MoeSparseMoeBlock):
                example_layer = layer.mlp
                self.moe_layers.append(layer.mlp.experts)

        if example_layer is None:
            raise RuntimeError("No Qwen3MoE layer found in the model.layers.")

        self.num_moe_layers = len(self.moe_layers)
        self.num_expert_groups = 1
        self.num_shared_experts = 0
        self.num_logical_experts = example_layer.n_logical_experts
        self.num_physical_experts = example_layer.n_physical_experts
        self.num_local_physical_experts = example_layer.n_local_physical_experts
        self.num_routed_experts = example_layer.n_routed_experts
        self.num_redundant_experts = example_layer.n_redundant_experts

    def set_eplb_state(
        self,
        expert_load_view: torch.Tensor,
        logical_to_physical_map: torch.Tensor,
        logical_replica_count: torch.Tensor,
    ) -> None:
        """Register expert weights and hand EPLB state to each local MoE layer.

        Note: each call appends to ``self.expert_weights``.
        """
        for layer_idx, layer in enumerate(self.moe_layers):
            # Register the expert weights.
            self.expert_weights.append(layer.get_expert_weights())
            layer.set_eplb_state(
                moe_layer_idx=layer_idx,
                expert_load_view=expert_load_view,
                logical_to_physical_map=logical_to_physical_map,
                logical_replica_count=logical_replica_count,
            )

    def update_physical_experts_metadata(
        self,
        num_physical_experts: int,
        num_local_physical_experts: int,
    ) -> None:
        """Propagate new physical/redundant expert counts to every local MoE block.

        The local expert count must not change (asserted).
        """
        assert self.num_local_physical_experts == num_local_physical_experts
        self.num_physical_experts = num_physical_experts
        self.num_local_physical_experts = num_local_physical_experts
        self.num_redundant_experts = (num_physical_experts -
                                      self.num_logical_experts)
        for layer in self.model.layers:
            # Skip other pipeline stages' placeholders, as __init__ does;
            # they have no ``mlp`` attribute, so reading it raised
            # AttributeError under pipeline parallelism.
            if isinstance(layer, PPMissingLayer):
                continue
            if isinstance(layer.mlp, Qwen3MoeSparseMoeBlock):
                moe = layer.mlp
                moe.n_local_physical_experts = num_local_physical_experts
                moe.n_physical_experts = num_physical_experts
                moe.n_redundant_experts = self.num_redundant_experts
                moe.experts.update_expert_map()

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                   inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        return self.model.get_expert_mapping()