# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen3MoE model compatible with HuggingFace weights."""
25

26
27
import typing
from collections.abc import Callable, Iterable
28
from itertools import islice
29
from typing import Any
30

zhuwenwen's avatar
zhuwenwen committed
31
32
import os
import re
33
34
35
import torch
from torch import nn

36
from vllm.attention.layer import Attention
37
from vllm.compilation.decorators import support_torch_compile
38
from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
39
40
41
42
43
44
from vllm.distributed import (
    get_ep_group,
    get_pp_group,
    get_tensor_model_parallel_world_size,
    tensor_model_parallel_all_gather,
)
45
46
47
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE
48
from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
49
from vllm.model_executor.layers.layernorm import RMSNorm
50
51
52
53
54
55
from vllm.model_executor.layers.linear import (
    MergedColumnParallelLinear,
    QKVParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
56
57
58
59
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
60
61
62
    ParallelLMHead,
    VocabParallelEmbedding,
)
63
from vllm.model_executor.model_loader.weight_utils import (
64
65
66
    default_weight_loader,
    maybe_remap_kv_scale_name,
)
67
from vllm.model_executor.models.utils import sequence_parallel_chunk
68
69
from vllm.sequence import IntermediateTensors

70
from .interfaces import MixtureOfExperts, SupportsEagle3, SupportsLoRA, SupportsPP
71
72
73
74
75
76
77
78
79
from .utils import (
    AutoWeightsLoader,
    PPMissingLayer,
    extract_layer_index,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
80

81
from vllm import envs
zhuwenwen's avatar
zhuwenwen committed
82
83
84
from vllm import _custom_ops as ops
from vllm.model_executor.utils import pad_weight, gemm_bank_conf
from vllm.utils import W8a8GetCacheJSON
85
from vllm.utils.torch_utils import direct_register_custom_op
86
87
88
89
90
91
92
93
94
95

# Module-level logger (used e.g. by the kv-scale remapping warning in weight loading).
logger = init_logger(__name__)


class Qwen3MoeMLP(nn.Module):
    """Dense gated (SiLU * up) feed-forward block.

    Used for decoder layers that are configured without routed experts.
    The gate and up projections share one merged column-parallel matmul, and
    the down projection is row-parallel.

    Raises:
        ValueError: if ``hidden_act`` is anything other than ``"silu"``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: QuantizationConfig | None = None,
        reduce_results: bool = True,
        prefix: str = "",
    ) -> None:
        super().__init__()
        # Two equally sized output shards: [gate | up].
        merged_output_sizes = [intermediate_size, intermediate_size]
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            merged_output_sizes,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            reduce_results=reduce_results,
            prefix=f"{prefix}.down_proj",
        )
        if hidden_act == "silu":
            self.act_fn = SiluAndMul()
        else:
            raise ValueError(
                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
            )

    def forward(self, x):
        """Return ``down_proj(silu(gate) * up)`` for the input activations."""
        fused_gate_up, _ = self.gate_up_proj(x)
        activated = self.act_fn(fused_gate_up)
        projected, _ = self.down_proj(activated)
        return projected


class Qwen3MoeSparseMoeBlock(nn.Module):
    """Routed mixture-of-experts feed-forward block.

    A replicated linear router (``self.gate``) produces one logit per expert.
    ``FusedMoE`` uses these logits to dispatch every token to
    ``config.num_experts_per_tok`` of the ``config.num_experts`` experts.
    Expert-parallel (EP) bookkeeping and EPLB redundant-expert settings are
    recorded on the instance.

    Raises:
        ValueError: if the tensor-parallel size exceeds the number of experts.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        super().__init__()

        config = vllm_config.model_config.hf_text_config
        parallel_config = vllm_config.parallel_config
        quant_config = vllm_config.quant_config

        self.tp_size = get_tensor_model_parallel_world_size()

        self.ep_group = get_ep_group().device_group
        self.ep_rank = get_ep_group().rank_in_group
        self.ep_size = self.ep_group.size()
        self.n_routed_experts = config.num_experts

        self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe

        if self.tp_size > config.num_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {config.num_experts}."
            )

        # Load balancing settings. The EPLB config is read from the *current*
        # global vLLM config rather than the ``vllm_config`` argument.
        vllm_config = get_current_vllm_config()
        eplb_config = vllm_config.parallel_config.eplb_config
        self.enable_eplb = parallel_config.enable_eplb

        self.n_logical_experts = self.n_routed_experts
        self.n_redundant_experts = eplb_config.num_redundant_experts
        self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts
        self.n_local_physical_experts = self.n_physical_experts // self.ep_size

        # Contiguous range of physical expert ids owned by this EP rank.
        self.physical_expert_start = self.ep_rank * self.n_local_physical_experts
        self.physical_expert_end = (
            self.physical_expert_start + self.n_local_physical_experts
        )

        self.experts = FusedMoE(
            num_experts=self.n_routed_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            reduce_results=True,
            renormalize=config.norm_topk_prob,
            quant_config=quant_config,
            prefix=f"{prefix}.experts",
            enable_eplb=self.enable_eplb,
            num_redundant_experts=self.n_redundant_experts,
            is_sequence_parallel=self.is_sequence_parallel,
            routing_method_type=RoutingMethodType.Renormalize,
        )

        self.gate = ReplicatedLinear(
            config.hidden_size,
            config.num_experts,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate",
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route tokens through the experts.

        Args:
            hidden_states: ``(num_tokens, hidden_size)``, or a single token
                ``(hidden_size,)``.

        Returns:
            A tensor with the same rank as the input (1-D in, 1-D out).
        """
        assert hidden_states.dim() <= 2, (
            "Qwen3MoeSparseMoeBlock only supports 1D or 2D inputs"
        )
        is_input_1d = hidden_states.dim() == 1
        hidden_dim = hidden_states.shape[-1]
        # view(-1, hidden_dim) promotes a 1-D token to shape (1, hidden_dim).
        hidden_states = hidden_states.view(-1, hidden_dim)
        # Count tokens *after* reshaping: unpacking ``.shape`` of a 1-D input
        # into two names would raise, despite 1-D being accepted above.
        num_tokens = hidden_states.shape[0]

        if self.is_sequence_parallel:
            hidden_states = sequence_parallel_chunk(hidden_states)

        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(hidden_states)
        final_hidden_states = self.experts(
            hidden_states=hidden_states, router_logits=router_logits
        )

        if self.is_sequence_parallel:
            # Gather the per-rank chunks exactly once, then drop any rows
            # beyond the original token count.
            final_hidden_states = tensor_model_parallel_all_gather(
                final_hidden_states, 0
            )
            final_hidden_states = final_hidden_states[:num_tokens]

        # return to 1d if input is 1d
        return final_hidden_states.squeeze(0) if is_input_1d else final_hidden_states


def rms_rotary_embedding_fuse(
    positions: torch.Tensor,
    query: torch.Tensor,
    head_size: int,
    cos_sin_cache: torch.Tensor,
    is_neox_style: bool,
    q_weight: torch.Tensor,
    k_weight: torch.Tensor,
    epsilon: float,
    key: torch.Tensor | None = None,
    q_bias: torch.Tensor | None = None,
    k_bias: torch.Tensor | None = None,
) -> None:
    """Fused per-head q/k RMSNorm + rotary embedding, applied in place.

    Fused replacement for the separate ``q_norm``/``k_norm`` +
    ``rotary_emb`` steps in ``Qwen3MoeAttention.forward``. It is registered
    below as ``torch.ops.vllm.rms_rotary_embedding_fuse`` with ``query`` and
    ``key`` declared as mutated. Callers must pass arguments in *this*
    signature's order; the wrapper translates them into the kernel's order.
    """
    # Imported lazily (presumably so this module still imports where
    # lightop is unavailable).
    from lightop import rms_rotary_embedding_fuse as fused_kernel

    # Kernel argument order: positions, query, key, head_size, cos_sin_cache,
    # is_neox_style, q_weight, k_weight, q_bias, k_bias, epsilon.
    fused_kernel(
        positions,
        query,
        key,
        head_size,
        cos_sin_cache,
        is_neox_style,
        q_weight,
        k_weight,
        q_bias,
        k_bias,
        epsilon,
    )


def rms_rotary_embedding_fuse_fake(
    positions: torch.Tensor,
    query: torch.Tensor,
    head_size: int,
    cos_sin_cache: torch.Tensor,
    is_neox_style: bool,
    q_weight: torch.Tensor,
    k_weight: torch.Tensor,
    epsilon: float,
    key: torch.Tensor | None = None,
    q_bias: torch.Tensor | None = None,
    k_bias: torch.Tensor | None = None,
) -> None:
    """Fake impl for graph tracing: the real op only mutates its inputs."""
    return None


# Registered at module import (previously done from inside the class body).
direct_register_custom_op(
    op_name="rms_rotary_embedding_fuse",
    op_func=rms_rotary_embedding_fuse,
    mutates_args=["query", "key"],
    fake_impl=rms_rotary_embedding_fuse_fake,
)


class Qwen3MoeAttention(nn.Module):
    """Self-attention with per-head q/k RMSNorm and rotary embeddings.

    Q, K and V come from a single tensor-parallel QKV projection. q and k are
    RMS-normalised per head, rotary position embeddings are applied, and the
    attention output is projected back by a row-parallel ``o_proj``.
    KV heads are partitioned across TP ranks when there are at least as many
    as ranks, otherwise replicated.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_parameters: dict[str, Any],
        max_position_embeddings: int = 8192,
        head_dim: int | None = None,
        rms_norm_eps: float = 1e-06,
        qkv_bias: bool = False,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        dual_chunk_attention_config: dict[str, Any] | None = None,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = head_dim or (hidden_size // self.total_num_heads)
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.max_position_embeddings = max_position_embeddings
        self.dual_chunk_attention_config = dual_chunk_attention_config

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=qkv_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            max_position=max_position_embeddings,
            rope_parameters=rope_parameters,
            dual_chunk_attention_config=dual_chunk_attention_config,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
            # Extra kwargs are only forwarded when dual-chunk attention is on.
            **{
                "layer_idx": extract_layer_index(prefix),
                "dual_chunk_attention_config": dual_chunk_attention_config,
            }
            if dual_chunk_attention_config
            else {},
        )

        self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
        self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Compute attention for ``hidden_states`` at the given ``positions``."""
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        # Add qk-norm, then rotary embedding.
        if envs.VLLM_USE_FUSED_RMS_ROPE:
            # Fused RMSNorm + RoPE path through the registered custom op.
            cos_sin_cache = self.rotary_emb.cos_sin_cache
            if cos_sin_cache.device != q.device or cos_sin_cache.dtype != q.dtype:
                cos_sin_cache = cos_sin_cache.to(
                    q.device, dtype=q.dtype, non_blocking=True
                )
                # Persist the converted cache so we don't re-copy/re-allocate
                # on every forward when the original buffer starts on CPU.
                self.rotary_emb.cos_sin_cache = cos_sin_cache
            # The op updates q and k in place; the split views may not be
            # contiguous, so work on contiguous copies.
            q = q.contiguous()
            k = k.contiguous()
            # Arguments follow the *registered* op signature
            # (positions, query, head_size, cos_sin_cache, is_neox_style,
            #  q_weight, k_weight, epsilon, key, q_bias, k_bias), not the
            # lightop kernel order. The previous call passed kernel order,
            # binding ``k`` to ``head_size`` and so on.
            torch.ops.vllm.rms_rotary_embedding_fuse(
                positions,
                q,
                self.head_dim,
                cos_sin_cache,
                self.rotary_emb.is_neox_style,
                self.q_norm.weight,
                self.k_norm.weight,
                self.q_norm.variance_epsilon,
                k,
                None,  # q_bias
                None,  # k_bias
            )
        else:
            q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim)
            if envs.VLLM_USE_APEX_RN:
                q_by_head = self.q_norm.forward_apex(q_by_head)
            else:
                q_by_head = self.q_norm.forward_cuda(q_by_head)
            q = q_by_head.view(q.shape)

            k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim)
            if envs.VLLM_USE_APEX_RN:
                k_by_head = self.k_norm.forward_apex(k_by_head)
            else:
                k_by_head = self.k_norm.forward_cuda(k_by_head)
            k = k_by_head.view(k.shape)
            q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output


class Qwen3MoeDecoderLayer(nn.Module):
    """One Qwen3-MoE transformer block: self-attention followed by an MLP.

    The MLP is a routed ``Qwen3MoeSparseMoeBlock`` unless the layer index is
    listed in ``config.mlp_only_layers``, the model has no experts, or the
    layer falls off the ``decoder_sparse_step`` cadence. In those cases a
    dense ``Qwen3MoeMLP`` is used.
    """

    def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__()

        text_config = vllm_config.model_config.hf_text_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.hidden_size = text_config.hidden_size
        self.self_attn = Qwen3MoeAttention(
            hidden_size=self.hidden_size,
            num_heads=text_config.num_attention_heads,
            num_kv_heads=text_config.num_key_value_heads,
            rope_parameters=text_config.rope_parameters,
            max_position_embeddings=getattr(
                text_config, "max_position_embeddings", 8192
            ),
            rms_norm_eps=text_config.rms_norm_eps,
            qkv_bias=getattr(text_config, "attention_bias", False),
            head_dim=getattr(text_config, "head_dim", None),
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
            dual_chunk_attention_config=getattr(
                text_config, "dual_chunk_attention_config", None
            ),
        )

        # Decide between the routed-experts block and a dense MLP.
        this_layer = extract_layer_index(prefix)
        dense_only_layers = getattr(text_config, "mlp_only_layers", [])
        use_sparse_moe = (
            this_layer not in dense_only_layers
            and text_config.num_experts > 0
            and (this_layer + 1) % text_config.decoder_sparse_step == 0
        )
        self.mlp = (
            Qwen3MoeSparseMoeBlock(vllm_config=vllm_config, prefix=f"{prefix}.mlp")
            if use_sparse_moe
            else Qwen3MoeMLP(
                hidden_size=text_config.hidden_size,
                intermediate_size=text_config.intermediate_size,
                hidden_act=text_config.hidden_act,
                quant_config=quant_config,
                prefix=f"{prefix}.mlp",
            )
        )
        self.input_layernorm = RMSNorm(
            text_config.hidden_size, eps=text_config.rms_norm_eps
        )
        self.post_attention_layernorm = RMSNorm(
            text_config.hidden_size, eps=text_config.rms_norm_eps
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run attention + MLP and return ``(hidden_states, residual)``.

        When ``residual`` is None (first layer), the raw input becomes the
        residual. Otherwise the norm is called with the residual and returns an
        updated one.
        """
        if residual is not None:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        else:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)

        attn_out = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )

        normed, residual = self.post_attention_layernorm(attn_out, residual)
        return self.mlp(normed), residual


@support_torch_compile
class Qwen3MoeModel(nn.Module):
    """Qwen3-MoE decoder stack: token embedding, decoder layers, final norm.

    Under pipeline parallelism only the layers in ``[start_layer, end_layer)``
    are run on this rank, and non-last ranks return ``IntermediateTensors``.
    Hidden states of the layers listed in ``aux_hidden_state_layers`` can be
    returned alongside the output (EAGLE3).
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_text_config
        quant_config = vllm_config.quant_config
        parallel_config = vllm_config.parallel_config
        eplb_config = parallel_config.eplb_config
        self.num_redundant_experts = eplb_config.num_redundant_experts

        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.config = config
        self.quant_config = quant_config

        self.quant_method = None
        if quant_config is not None:
            self.quant_method = quant_config.get_name()
            # NOTE: process-wide environment variables. Quantized checkpoints
            # opt out of LLAMA_NN (read a few lines below) and LM_NN
            # (presumably consumed elsewhere; not visible in this file).
            os.environ["LLAMA_NN"] = "0"
            os.environ["LM_NN"] = "0"

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=f"{prefix}.embed_tokens",
        )
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: Qwen3MoeDecoderLayer(vllm_config=vllm_config, prefix=prefix),
            prefix=f"{prefix}.layers",
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size
        )
        # Track layers for auxiliary hidden state outputs (EAGLE3)
        self.aux_hidden_state_layers: tuple[int, ...] = ()

        self.tritonsingleton = W8a8GetCacheJSON()

        # Backend feature switches, read once from the environment.
        self.use_llama_nn = os.environ.get("LLAMA_NN") == "1"
        self.use_gemm_pad = os.environ.get("GEMM_PAD") == "1"
        self.use_fa_pad = os.environ.get("FA_PAD") == "1"
        self.use_awq_pad = os.environ.get("AWQ_PAD") == "1"
        self.w8a8_strategy = int(os.getenv("W8A8_SUPPORT_METHODS", "1"))

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]:
        """Run this rank's slice of the decoder stack.

        Returns:
            ``IntermediateTensors`` on non-last pipeline ranks. Otherwise the
            normalised hidden states, paired with the collected auxiliary
            hidden states when any were collected.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.embed_input_ids(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        aux_hidden_states = []
        for layer_idx, layer in enumerate(
            islice(self.layers, self.start_layer, self.end_layer),
            start=self.start_layer,
        ):
            # Collect auxiliary hidden states if specified
            if layer_idx in self.aux_hidden_state_layers:
                aux_hidden_state = (
                    hidden_states + residual if residual is not None else hidden_states
                )
                aux_hidden_states.append(aux_hidden_state)
            hidden_states, residual = layer(positions, hidden_states, residual)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )
        hidden_states, _ = self.norm(hidden_states, residual)

        # Return auxiliary hidden states if collected
        if len(aux_hidden_states) > 0:
            return hidden_states, aux_hidden_states
        return hidden_states

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Return ``(param_name, weight_name, expert_id, shard_id)`` tuples."""
        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        return FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.num_experts,
            num_redundant_experts=self.num_redundant_experts,
        )

    def _apply_llama_nn_layout(self, params_dict: dict[str, nn.Parameter]) -> None:
        """Re-lay-out selected dense projection weights in place.

        For every parameter whose name contains one of the keywords below, the
        2-D tensor of shape ``(rows, cols)`` is transformed by
        ``ops.trans_w16_gemm`` into a scratch buffer, copied back into the
        parameter's storage, and then viewed as ``(cols, -1)``.
        """
        # re.escape keeps "." literal; the pattern is compiled once.
        weight_pattern = re.compile(
            "|".join(
                re.escape(keyword)
                for keyword in (
                    "gate_up_proj.weight",
                    "down_proj.weight",
                    "mlp.gate.weight",
                    "self_attn.qkv_proj.weight",
                    "self_attn.o_proj.weight",
                    "lm_head.weight",
                )
            )
        )
        # Loop-invariant: this used to be reset once per parameter.
        os.environ["LM_NN"] = "0"
        for layer_name, weight in params_dict.items():
            if weight_pattern.search(layer_name) is None:
                continue
            scratch = torch.zeros_like(weight.data)
            ori_shape = scratch.shape
            ops.trans_w16_gemm(scratch, weight.data, ori_shape[0], ori_shape[1])
            weight.data.copy_(scratch)
            weight.data = weight.data.reshape(ori_shape[1], -1)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this model's parameters.

        The loader handles the following:
        - KV-cache scales reported by the quant config.
        - Stacked projections: q/k/v are loaded into ``qkv_proj`` and
          gate/up into ``gate_up_proj``.
        - Per-expert weights, via ``get_expert_mapping()``.
        - All remaining weights, including FP8 kv-scale name remapping.

        When LLAMA_NN is enabled for an unquantized model and the last weight
        reports ``current_count == total_count``, the dense projection weights
        are re-laid-out afterwards.

        Returns:
            The set of parameter names that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        # Skip loading extra parameters for GPTQ/modelopt models.
        ignore_suffixes = (
            ".bias",
            "_bias",
            ".k_scale",
            "_k_scale",
            ".v_scale",
            "_v_scale",
            ".weight_scale",
            "_weight_scale",
            ".input_scale",
            "_input_scale",
        )

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        expert_params_mapping = self.get_expert_mapping()

        # Progress counters attached to each tensor on the LLAMA_NN path.
        # Pre-bound so that an empty ``weights`` iterable cannot raise
        # UnboundLocalError at the final check.
        current_count = total_count = None
        has_counts = False

        for name, loaded_weight in weights:
            if self.use_llama_nn:
                current_count = loaded_weight.current_count
                total_count = loaded_weight.total_count
                has_counts = True
            if self.quant_config is not None and (
                scale_name := self.quant_config.get_cache_scale(name)
            ):
                # Loading kv cache quantization scales
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                assert loaded_weight.numel() == 1, (
                    f"KV scale numel {loaded_weight.numel()} != 1"
                )
                loaded_weight = loaded_weight.squeeze()
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                # Skip non-stacked layers and experts (experts handled below).
                if weight_name not in name:
                    continue
                # We have mlp.experts[0].gate_proj in the checkpoint.
                # Since we handle the experts below in expert_params_mapping,
                # we need to skip here BEFORE we update the name, otherwise
                # name will be updated to mlp.experts[0].gate_up_proj, which
                # will then be updated below in expert_params_mapping
                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                if "mlp.experts" in name:
                    continue
                name = name.replace(weight_name, param_name)

                # Skip loading extra parameters for GPTQ/modelopt models.
                if name.endswith(ignore_suffixes) and name not in params_dict:
                    continue

                # Skip layers on other devices.
                if is_pp_missing_parameter(name, self):
                    continue
                if name.endswith("scale"):
                    # Remapping the name of FP8 kv-scale.
                    name = maybe_remap_kv_scale_name(name, params_dict)
                    if name is None:
                        continue
                if name not in params_dict:
                    continue

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                if weight_loader == default_weight_loader:
                    weight_loader(param, loaded_weight)
                else:
                    weight_loader(param, loaded_weight, shard_id)
                break
            else:
                is_expert_weight = False
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue

                    # Anyway, this is an expert weight and should not be
                    # attempted to load as other weights later
                    is_expert_weight = True

                    # Do not modify `name` since the loop may continue here
                    # Instead, create a new variable
                    name_mapped = name.replace(weight_name, param_name)

                    if is_pp_missing_parameter(name_mapped, self):
                        continue

                    # Skip loading extra parameters for GPTQ/modelopt models.
                    if (
                        name_mapped.endswith(ignore_suffixes)
                        and name_mapped not in params_dict
                    ):
                        continue

                    param = params_dict[name_mapped]
                    # We should ask the weight loader to return success or not
                    # here since otherwise we may skip experts with other
                    # available replicas.
                    weight_loader = typing.cast(
                        Callable[..., bool], param.weight_loader
                    )
                    success = weight_loader(
                        param,
                        loaded_weight,
                        name_mapped,
                        shard_id=shard_id,
                        expert_id=expert_id,
                        return_success=True,
                    )
                    if success:
                        name = name_mapped
                        break
                else:
                    if is_expert_weight:
                        # We've checked that this is an expert weight
                        # However it's not mapped locally to this rank
                        # So we simply skip it
                        continue

                    # Skip loading extra parameters for GPTQ/modelopt models.
                    if name.endswith(ignore_suffixes) and name not in params_dict:
                        continue
                    # Skip layers on other devices.
                    if is_pp_missing_parameter(name, self):
                        continue
                    # Remapping the name of FP8 kv-scale.
                    if name.endswith("kv_scale"):
                        remapped_kv_scale_name = name.replace(
                            ".kv_scale", ".attn.kv_scale"
                        )
                        if remapped_kv_scale_name not in params_dict:
                            logger.warning_once(
                                "Found kv scale in the checkpoint (e.g. %s), but not found the expected name in the model (e.g. %s). kv-scale is not loaded.",  # noqa: E501
                                name,
                                remapped_kv_scale_name,
                            )
                            continue
                        else:
                            name = remapped_kv_scale_name
                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)
            loaded_params.add(name)

        if (
            self.use_llama_nn
            and self.quant_method is None
            and has_counts
            and current_count == total_count
        ):
            self._apply_llama_nn_layout(params_dict)
        return loaded_params


class Qwen3MoeForCausalLM(
    nn.Module, SupportsPP, SupportsLoRA, SupportsEagle3, MixtureOfExperts
):
800
801
802
803
804
    # Packed (fused) module name -> the separate projection names it combines.
    # ``__init__`` may add a ``gate_up_proj`` entry to this same dict when the
    # config declares ``mlp_only_layers``.
    packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}

    # NOTE(review): presumably tells the model loader not to fall back to
    # PyTorch ``.pt`` checkpoint files — confirm against the loader.
    fall_back_to_pt_during_load = False

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the Qwen3-MoE causal LM and its MoE bookkeeping.

        Args:
            vllm_config: Engine configuration. The HF text config and the
                quantization config are read from it.
            prefix: Name prefix for submodules (combined via ``maybe_prefix``).

        Raises:
            AssertionError: if a non-placeholder layer is not a
                ``Qwen3MoeDecoderLayer``.
            RuntimeError: if no ``Qwen3MoeSparseMoeBlock`` is found among the
                non-placeholder decoder layers.
        """
        super().__init__()
        text_cfg = vllm_config.model_config.hf_text_config
        quant_cfg = vllm_config.quant_config
        self.config = text_cfg
        self.quant_config = quant_cfg

        # The gate_up_proj packing is only relevant when dense Qwen3MoeMLP
        # layers exist, i.e. when ``mlp_only_layers`` is non-empty.
        # NOTE(review): this writes into the *class-level* dict, so the new
        # entry is visible to every instance and to anything holding a
        # reference to the class attribute. Presumably intentional — confirm
        # before switching to a per-instance copy.
        if getattr(text_cfg, "mlp_only_layers", []):
            self.packed_modules_mapping["gate_up_proj"] = ["gate_proj", "up_proj"]

        # Submodules are created in this order: model, lm_head, logits_processor.
        self.model = Qwen3MoeModel(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )
        self.lm_head = ParallelLMHead(
            text_cfg.vocab_size,
            text_cfg.hidden_size,
            quant_config=quant_cfg,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        # Tied embeddings: the LM head reuses the input embedding parameter.
        if self.config.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight
        self.logits_processor = LogitsProcessor(text_cfg.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

        # ---- MoE bookkeeping ----
        self.expert_weights = []
        self.moe_layers = []

        # Collect the sparse-MoE blocks of every layer this model actually
        # holds. ``PPMissingLayer`` placeholders are skipped.
        sparse_blocks = []
        for blk in self.model.layers:
            if not isinstance(blk, PPMissingLayer):
                assert isinstance(blk, Qwen3MoeDecoderLayer)
                if isinstance(blk.mlp, Qwen3MoeSparseMoeBlock):
                    sparse_blocks.append(blk.mlp)
        self.moe_layers.extend(block.experts for block in sparse_blocks)

        if not sparse_blocks:
            raise RuntimeError("No Qwen3MoE layer found in the model.layers.")

        # Expert counts are read from the last sparse block found
        # (assumed uniform across layers).
        reference = sparse_blocks[-1]
        self.num_moe_layers = len(self.moe_layers)
        self.num_expert_groups = 1
        self.num_shared_experts = 0
        self.num_logical_experts = reference.n_logical_experts
        self.num_physical_experts = reference.n_physical_experts
        self.num_local_physical_experts = reference.n_local_physical_experts
        self.num_routed_experts = reference.n_routed_experts
        self.num_redundant_experts = reference.n_redundant_experts

    def update_physical_experts_metadata(
        self,
        num_physical_experts: int,
        num_local_physical_experts: int,
    ) -> None:
        """Update physical/redundant expert counts on the model and MoE blocks.

        The redundant-expert count is recomputed as
        ``num_physical_experts - num_logical_experts``. It is then pushed,
        together with both physical counts, into every
        ``Qwen3MoeSparseMoeBlock`` this model holds. Finally each block's
        experts are asked to rebuild their expert map.

        Args:
            num_physical_experts: New total number of physical experts.
            num_local_physical_experts: Number of physical experts on this
                rank. Asserted to equal the currently recorded value.
        """
        assert self.num_local_physical_experts == num_local_physical_experts
        self.num_physical_experts = num_physical_experts
        self.num_local_physical_experts = num_local_physical_experts
        self.num_redundant_experts = num_physical_experts - self.num_logical_experts
        for layer in self.model.layers:
            # Pipeline-parallel placeholder layers (``PPMissingLayer``, which
            # ``__init__`` also skips) are expected to carry no ``mlp``
            # submodule. Reading ``layer.mlp`` directly on them would raise
            # AttributeError, so fetch it defensively and skip non-MoE layers.
            moe = getattr(layer, "mlp", None)
            if not isinstance(moe, Qwen3MoeSparseMoeBlock):
                continue
            moe.n_local_physical_experts = num_local_physical_experts
            moe.n_physical_experts = num_physical_experts
            moe.n_redundant_experts = self.num_redundant_experts
            moe.experts.update_expert_map()

878
879
880
881
882
883
884
    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
        """Store ``layers`` as the inner model's ``aux_hidden_state_layers``."""
        inner = self.model
        inner.aux_hidden_state_layers = layers

    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
        """Return the default auxiliary hidden-state layer indices for EAGLE-3.

        The three indices are: layer 2, the middle layer, and the third-from-
        last layer, computed from ``len(self.model.layers)``.
        """
        depth = len(self.model.layers)
        early, middle, late = 2, depth // 2, depth - 3
        return (early, middle, late)

885
886
    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed ``input_ids`` via the inner model's ``embed_input_ids``."""
        embedded = self.model.embed_input_ids(input_ids)
        return embedded
887
888
889
890
891

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the decoder stack and return its output unchanged.

        The output is whatever ``self.model`` produces: a hidden-state tensor
        or ``IntermediateTensors`` (presumably the latter on non-final
        pipeline stages — confirm against ``Qwen3MoeModel.forward``).
        """
        return self.model(
            input_ids,
            positions,
            intermediate_tensors,
            inputs_embeds,
        )

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Project ``hidden_states`` to vocabulary logits.

        Delegates to ``self.logits_processor`` with ``self.lm_head`` as the
        projection and returns its result (which may be ``None``).
        """
        return self.logits_processor(self.lm_head, hidden_states)

907
    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load ``(name, tensor)`` pairs through an ``AutoWeightsLoader``.

        Returns the set of names reported by the loader.
        """
        return AutoWeightsLoader(self).load_weights(weights)
910
911

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Return ``self.model.get_expert_mapping()`` unchanged."""
        inner = self.model
        mapping = inner.get_expert_mapping()
        return mapping