qwen3_next.py 28.7 KB
Newer Older
1
2
3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Inference-only Qwen3Next model."""
4

5
from collections.abc import Iterable
6
from itertools import islice
7
8
9
10
11

import torch
from torch import nn

from vllm.compilation.decorators import support_torch_compile
12
13
14
15
16
17
18
19
20
21
22
23
from vllm.config import (
    CacheConfig,
    ModelConfig,
    VllmConfig,
    get_current_vllm_config,
)
from vllm.distributed import (
    get_ep_group,
    get_pp_group,
    get_tensor_model_parallel_world_size,
    tensor_model_parallel_all_gather,
)
24
from vllm.logger import init_logger
25
from vllm.model_executor.layers.attention import Attention
26
27
28
29
from vllm.model_executor.layers.fused_moe import (
    FusedMoE,
    fused_moe_make_expert_params_mapping,
)
30
31
32
from vllm.model_executor.layers.layernorm import (
    GemmaRMSNorm as Qwen3NextRMSNorm,
)
33
34
35
36
37
from vllm.model_executor.layers.linear import (
    QKVParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
38
from vllm.model_executor.layers.logits_processor import LogitsProcessor
39
from vllm.model_executor.layers.mamba.gdn_linear_attn import GatedDeltaNetAttention
40
from vllm.model_executor.layers.mamba.mamba_utils import (
41
42
    MambaStateCopyFunc,
    MambaStateCopyFuncCalculator,
43
44
45
    MambaStateDtypeCalculator,
    MambaStateShapeCalculator,
)
46
47
48
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
49
50
51
    ParallelLMHead,
    VocabParallelEmbedding,
)
52
from vllm.model_executor.model_loader.weight_utils import (
53
    default_weight_loader,
54
    maybe_remap_kv_scale_name,
55
)
56
from vllm.model_executor.models.qwen2_moe import Qwen2MoeMLP as Qwen3NextMLP
57
from vllm.model_executor.models.utils import sequence_parallel_chunk
58
from vllm.sequence import IntermediateTensors
59
from vllm.transformers_utils.configs.qwen3_next import Qwen3NextConfig
60

61
from .interfaces import (
62
    EagleModelMixin,
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
    HasInnerState,
    IsHybrid,
    MixtureOfExperts,
    SupportsLoRA,
    SupportsPP,
)
from .utils import (
    AutoWeightsLoader,
    PPMissingLayer,
    extract_layer_index,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
78
79
80
81
82
83
84

# Module-level logger, created through vLLM's `init_logger` helper.
logger = init_logger(__name__)

# Type alias for a pair of tensors.
# NOTE(review): presumably the (key, value) cache pair; the alias is not
# referenced in the visible part of this file — confirm before relying on it.
KVCache = tuple[torch.Tensor, torch.Tensor]


class Qwen3NextSparseMoeBlock(nn.Module):
    """Sparse mixture-of-experts block with an optional shared expert.

    The block owns a replicated router (``gate``), a replicated one-output
    ``shared_expert_gate``, an optional shared-expert MLP, and a ``FusedMoE``
    layer (``experts``) that is handed the router and the shared expert.
    """

    def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
        """Build the router, the optional shared expert and the fused experts.

        Args:
            vllm_config: Engine configuration. The model hyperparameters are
                read from ``model_config.hf_text_config``; parallel and
                quantization settings from ``parallel_config`` and
                ``quant_config``.
            prefix: Dotted module path used to name the sub-layers.

        Raises:
            ValueError: If the tensor-parallel world size is larger than
                ``config.num_experts``.
        """
        super().__init__()

        config = vllm_config.model_config.hf_text_config
        parallel_config = vllm_config.parallel_config
        quant_config = vllm_config.quant_config

        self.tp_size = get_tensor_model_parallel_world_size()

        ep_group = get_ep_group()
        self.ep_group = ep_group.device_group
        self.ep_rank = ep_group.rank_in_group
        self.ep_size = self.ep_group.size()
        self.n_routed_experts = config.num_experts

        self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe

        if self.tp_size > config.num_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {config.num_experts}."
            )

        # Load balancing settings. Both values come from the config that was
        # passed in, so the constructor argument is never shadowed.
        eplb_config = parallel_config.eplb_config
        self.enable_eplb = parallel_config.enable_eplb

        self.n_logical_experts = self.n_routed_experts
        self.n_redundant_experts = eplb_config.num_redundant_experts
        self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts
        self.n_local_physical_experts = self.n_physical_experts // self.ep_size

        # Half-open range [start, end) of physical experts for this EP rank.
        self.physical_expert_start = self.ep_rank * self.n_local_physical_experts
        self.physical_expert_end = (
            self.physical_expert_start + self.n_local_physical_experts
        )

        # The router and the shared-expert gate are built without quantization.
        self.gate = ReplicatedLinear(
            config.hidden_size,
            config.num_experts,
            bias=False,
            quant_config=None,
            prefix=f"{prefix}.gate",
        )

        self.shared_expert_gate = ReplicatedLinear(
            config.hidden_size,
            1,
            bias=False,
            quant_config=None,
            prefix=f"{prefix}.shared_expert_gate",
        )

        if config.shared_expert_intermediate_size > 0:
            self.shared_expert = Qwen3NextMLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.shared_expert_intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                reduce_results=False,
                expert_gate=self.shared_expert_gate,
                is_sequence_parallel=self.is_sequence_parallel,
                prefix=f"{prefix}.shared_expert",
            )
        else:
            self.shared_expert = None

        self.experts = FusedMoE(
            shared_experts=self.shared_expert,
            gate=self.gate,
            num_experts=self.n_routed_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            renormalize=getattr(config, "norm_topk_prob", True),
            quant_config=quant_config,
            prefix=f"{prefix}.experts",
            enable_eplb=self.enable_eplb,
            num_redundant_experts=self.n_redundant_experts,
            is_sequence_parallel=self.is_sequence_parallel,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Run the experts on ``hidden_states`` and restore the input shape.

        Args:
            hidden_states: Tensor whose last dimension is the hidden size.
                It may be 1D or 2D; it is flattened to 2D internally.

        Returns:
            The expert output, viewed back to the shape of the input.
        """
        # NOTE: hidden_states can have either 1D or 2D shape, so the hidden
        # size is taken from the last dimension and the token count from the
        # flattened view rather than by unpacking a 2-tuple.
        orig_shape = hidden_states.shape
        hidden_dim = orig_shape[-1]
        hidden_states = hidden_states.view(-1, hidden_dim)
        num_tokens = hidden_states.shape[0]

        if self.is_sequence_parallel:
            hidden_states = sequence_parallel_chunk(hidden_states)

        if self.experts.is_internal_router:
            # In this case, the gate/router runs inside the FusedMoE class,
            # so the hidden states are passed in place of router logits.
            router_logits = hidden_states
        else:
            # router_logits: (num_tokens, n_experts)
            router_logits, _ = self.gate(hidden_states)

        final_hidden_states = self.experts(
            hidden_states=hidden_states, router_logits=router_logits
        )

        if self.is_sequence_parallel:
            final_hidden_states = tensor_model_parallel_all_gather(
                final_hidden_states, 0
            )
            # Keep only the first num_tokens rows of the gathered result.
            final_hidden_states = final_hidden_states[:num_tokens]

        return final_hidden_states.view(orig_shape)


class Qwen3NextAttention(nn.Module):
    """Full-attention layer with per-head Q/K RMS norm and an optional gate.

    When ``attn_output_gate`` is enabled, the fused QKV projection emits a
    query block twice the usual width; it is split per head into the query
    and a gate, and the attention output is multiplied by ``sigmoid(gate)``
    before the output projection.
    """

    def __init__(
        self,
        config: Qwen3NextConfig,
        model_config: ModelConfig | None = None,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        """Create the projections, rotary embedding, attention and norms.

        Args:
            config: Model hyperparameters.
            model_config: Accepted for interface compatibility; not read here.
            cache_config: Forwarded to the ``Attention`` layer.
            quant_config: Forwarded to the linear and attention layers.
            prefix: Dotted module path used to name the sub-layers.
        """
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = config.num_attention_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = config.num_key_value_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than or equal to TP size, so we
            # partition the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        # The fallback head size is derived from the *total* head count; the
        # per-rank count would inflate it by tp_size under tensor parallelism.
        self.head_dim = config.head_dim or (self.hidden_size // self.total_num_heads)
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.dual_chunk_attention_config = getattr(
            config, "dual_chunk_attention_config", None
        )
        self.attn_output_gate = getattr(config, "attn_output_gate", True)

        self.qkv_proj = QKVParallelLinear(
            config.hidden_size,
            self.head_dim,
            # With the output gate enabled the query head count is doubled so
            # the projection also produces the gate values.
            self.total_num_heads * (1 + self.attn_output_gate),
            self.total_num_kv_heads,
            bias=getattr(config, "qkv_bias", False),
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            config.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self.rotary_emb = get_rope(
            head_size=self.head_dim,
            max_position=config.max_position_embeddings,
            rope_parameters=config.rope_parameters,
            dual_chunk_attention_config=self.dual_chunk_attention_config,
        )

        # The extra keyword arguments are only passed when dual-chunk
        # attention is configured.
        if self.dual_chunk_attention_config:
            dual_chunk_kwargs = {
                "layer_idx": extract_layer_index(prefix),
                "dual_chunk_attention_config": self.dual_chunk_attention_config,
            }
        else:
            dual_chunk_kwargs = {}

        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
            **dual_chunk_kwargs,
        )

        self.q_norm = Qwen3NextRMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = Qwen3NextRMSNorm(self.head_dim, eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        output: torch.Tensor,
        hidden_states: torch.Tensor,
    ):
        """Compute attention and write the projected result into ``output``.

        Args:
            positions: Positions passed to the rotary embedding.
            output: Pre-allocated tensor that receives the result in place.
            hidden_states: Input to the fused QKV projection.

        Returns:
            None; the result is written into ``output``.
        """
        qkv, _ = self.qkv_proj(hidden_states)

        if self.attn_output_gate:
            q_gate, k, v = qkv.split(
                [self.q_size * 2, self.kv_size, self.kv_size], dim=-1
            )
            # Each head's slice holds [query | gate]; split per head, then
            # flatten the heads back into the last dimension.
            orig_shape = q_gate.shape[:-1]
            q_gate = q_gate.view(*orig_shape, self.num_heads, -1)
            q, gate = torch.chunk(q_gate, 2, dim=-1)
            q = q.reshape(*orig_shape, -1)
            gate = gate.reshape(*orig_shape, -1)
        else:
            q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

        # RMS norm is applied per head, so reshape to (-1, heads, head_dim)
        # for the norm and flatten again afterwards.
        q = self.q_norm(q.view(-1, self.num_heads, self.head_dim)).view(
            -1, self.num_heads * self.head_dim
        )
        k = self.k_norm(k.view(-1, self.num_kv_heads, self.head_dim)).view(
            -1, self.num_kv_heads * self.head_dim
        )

        q, k = self.rotary_emb(positions, q, k)

        attn_output = self.attn(q, k, v)

        if self.attn_output_gate:
            gate = torch.sigmoid(gate)
            attn_output = attn_output * gate

        output[:], _ = self.o_proj(attn_output)


class Qwen3NextDecoderLayer(nn.Module):
    """One decoder layer: token mixer (linear or full attention) plus MLP.

    The token mixer is chosen by ``layer_type``; the feed-forward part is a
    sparse MoE block or a dense MLP depending on the layer index and config.
    Optional learned per-channel scales are applied after each sub-layer
    when ``config.layer_scale`` is truthy.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        layer_type: str,
        prefix: str = "",
    ) -> None:
        """Build the token mixer, the MLP and the two layer norms.

        Args:
            vllm_config: Engine configuration.
            layer_type: ``"linear_attention"`` or ``"full_attention"``.
            prefix: Dotted module path; the layer index is extracted from it.

        Raises:
            ValueError: If ``layer_type`` is not one of the two known values.
        """
        super().__init__()

        # hf_text_config is used here to match the other modules in this
        # file, which read the same hyperparameters from it.
        config = vllm_config.model_config.hf_text_config
        model_config = vllm_config.model_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.layer_type = layer_type
        self.layer_idx = extract_layer_index(prefix)

        if self.layer_type == "linear_attention":
            self.linear_attn = GatedDeltaNetAttention(
                config,
                vllm_config=vllm_config,
                prefix=f"{prefix}.linear_attn",
                gqa_interleaved_layout=True,
            )
        elif self.layer_type == "full_attention":
            self.self_attn = Qwen3NextAttention(
                config,
                model_config=model_config,
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=f"{prefix}.self_attn",
            )
        else:
            raise ValueError(f"Invalid layer_type {self.layer_type}")

        # A layer gets the sparse MoE block when it is not listed in
        # mlp_only_layers, experts are configured, and its 1-based index is
        # a multiple of decoder_sparse_step; otherwise it gets a dense MLP.
        mlp_only_layers = getattr(config, "mlp_only_layers", [])
        use_moe = (
            self.layer_idx not in mlp_only_layers
            and config.num_experts > 0
            and (self.layer_idx + 1) % config.decoder_sparse_step == 0
        )
        if use_moe:
            self.mlp = Qwen3NextSparseMoeBlock(
                vllm_config=vllm_config,
                prefix=f"{prefix}.mlp",
            )
        else:
            self.mlp = Qwen3NextMLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                prefix=f"{prefix}.mlp",
            )

        self.input_layernorm = Qwen3NextRMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )
        self.post_attention_layernorm = Qwen3NextRMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

        self.layer_scale = getattr(config, "layer_scale", False)
        if self.layer_scale:
            # Zero-initialised; the forward pass multiplies by (scale + 1).
            self.attn_layer_scale = torch.nn.Parameter(
                torch.zeros(1, 1, config.hidden_size),
            )
            self.ffn_layer_scale = torch.nn.Parameter(
                torch.zeros(1, 1, config.hidden_size),
            )

    @staticmethod
    def _apply_layer_scale(
        hidden_states: torch.Tensor, layer_scale: torch.Tensor
    ) -> torch.Tensor:
        """Multiply ``hidden_states`` by ``layer_scale + 1`` in its dtype."""
        scale = layer_scale.to(hidden_states.dtype)
        if len(hidden_states.shape) == 2:
            # Drop the leading unit dimension of the (1, 1, hidden) scale.
            scale = scale[0]
        return hidden_states * (scale + 1)

    def forward(
        self,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
        positions: torch.Tensor | None = None,
        **kwargs: object,
    ):
        """Apply the token mixer and the MLP with pre-norm residual handling.

        Args:
            hidden_states: Layer input.
            residual: Running residual, or ``None`` for the first layer.
            positions: Token positions; only passed to full attention.
            **kwargs: Ignored.

        Returns:
            A ``(hidden_states, residual)`` tuple.

        Raises:
            ValueError: If ``self.layer_type`` is not a known value.
        """
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)

        # Both token mixers write their result into this buffer in place.
        self_attention_output = torch.empty_like(hidden_states)
        if self.layer_type == "linear_attention":
            self.linear_attn(
                hidden_states=hidden_states,
                output=self_attention_output,
            )
        elif self.layer_type == "full_attention":
            self.self_attn(
                hidden_states=hidden_states,
                output=self_attention_output,
                positions=positions,
            )
        else:
            raise ValueError("Invalid layer_type")
        hidden_states = self_attention_output

        if self.layer_scale:
            hidden_states = self._apply_layer_scale(
                hidden_states, self.attn_layer_scale
            )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)

        if self.layer_scale:
            if len(hidden_states.shape) != 2:
                assert len(hidden_states.shape) == len(self.ffn_layer_scale.shape), (
                    f"shape must be the same {len(hidden_states.shape)}, "
                    f"{len(self.ffn_layer_scale.shape)}"
                )
            hidden_states = self._apply_layer_scale(
                hidden_states, self.ffn_layer_scale
            )

        return hidden_states, residual


@support_torch_compile
class Qwen3NextModel(nn.Module, EagleModelMixin):
    """Decoder stack of Qwen3-Next: embeddings, decoder layers, final norm."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the embedding table, this rank's decoder layers and the norm.

        Args:
            vllm_config: Engine configuration; hyperparameters are read from
                ``model_config.hf_text_config``.
            prefix: Dotted module path used to name the sub-layers.
        """
        super().__init__()

        config: Qwen3NextConfig = vllm_config.model_config.hf_text_config
        parallel_config = vllm_config.parallel_config

        eplb_config = parallel_config.eplb_config
        self.num_redundant_experts = eplb_config.num_redundant_experts

        self.config = config

        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
        )

        def get_layer(prefix: str):
            # The layer type is looked up by the index embedded in the prefix.
            return Qwen3NextDecoderLayer(
                vllm_config,
                layer_type=config.layer_types[extract_layer_index(prefix)],
                prefix=prefix,
            )

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers"
        )
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size
        )

        # Only the last pipeline rank applies the final norm.
        if get_pp_group().is_last_rank:
            self.norm = Qwen3NextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer()

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]:
        """Run this pipeline rank's slice of the decoder stack.

        Args:
            input_ids: Token ids; used on the first rank when
                ``inputs_embeds`` is ``None``.
            positions: Token positions forwarded to every layer.
            intermediate_tensors: ``hidden_states``/``residual`` from the
                previous rank; required on every rank but the first.
            inputs_embeds: Pre-computed embeddings that bypass the lookup.

        Returns:
            ``IntermediateTensors`` on every rank but the last. On the last
            rank, the normalized hidden states, paired with the collected
            auxiliary hidden states when that list is non-empty.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.embed_input_ids(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        aux_hidden_states = self._maybe_add_hidden_state([], 0, hidden_states, residual)
        for layer_idx, layer in enumerate(
            islice(self.layers, self.start_layer, self.end_layer),
            start=self.start_layer,
        ):
            hidden_states, residual = layer(
                positions=positions,
                hidden_states=hidden_states,
                residual=residual,
            )
            self._maybe_add_hidden_state(
                aux_hidden_states, layer_idx + 1, hidden_states, residual
            )

        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )
        hidden_states, _ = self.norm(hidden_states, residual)
        if aux_hidden_states:
            return hidden_states, aux_hidden_states
        return hidden_states

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Return the checkpoint-to-parameter mapping for expert weights."""
        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        return fused_moe_make_expert_params_mapping(
            self,
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=getattr(self.config, "num_experts", 0),
            num_redundant_experts=self.num_redundant_experts,
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Each tensor is tried, in order, against the stacked-projection
        mapping, the expert mapping, and finally a direct name lookup.

        Args:
            weights: ``(name, tensor)`` pairs from the checkpoint.

        Returns:
            The set of parameter names that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        expert_params_mapping = self.get_expert_mapping()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            if name.startswith("mtp."):
                continue

            # Remapping the name of FP8 kv-scale.
            if name.endswith("scale"):
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue

                # Expert weights are handled by the expert mapping below.
                if "mlp.experts" in name:
                    continue

                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Skip layers on other devices.
                if is_pp_missing_parameter(name, self):
                    continue
                if name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
                    # Do not rebind `name` here: the loop may continue, and
                    # later iterations and the fallback below must still see
                    # the checkpoint's own tensor name.
                    name_mapped = name.replace(weight_name, param_name)
                    # Skip layers on other devices.
                    if is_pp_missing_parameter(name_mapped, self):
                        continue
                    # Also covers extra bias tensors from GPTQ models that
                    # have no matching parameter.
                    if name_mapped not in params_dict:
                        continue
                    param = params_dict[name_mapped]
                    weight_loader = param.weight_loader
                    weight_loader(
                        param,
                        loaded_weight,
                        name_mapped,
                        shard_id=shard_id,
                        expert_id=expert_id,
                    )
                    # Record the parameter name that was actually loaded.
                    name = name_mapped
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    if is_pp_missing_parameter(name, self):
                        continue
                    if name not in params_dict:
                        logger.warning_once(
                            f"Parameter {name} not found in params_dict, skip loading"
                        )
                        continue
                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
class QwenNextMixtureOfExperts(MixtureOfExperts):
    """Mixin that tracks expert counts for the MoE layers in ``self.model``."""

    def update_physical_experts_metadata(
        self,
        num_physical_experts: int,
        num_local_physical_experts: int,
    ) -> None:
        """Propagate new physical-expert counts to every MoE block.

        Args:
            num_physical_experts: New total number of physical experts.
            num_local_physical_experts: Per-rank count; must equal the
                currently stored value.
        """
        assert self.num_local_physical_experts == num_local_physical_experts
        self.num_physical_experts = num_physical_experts
        self.num_local_physical_experts = num_local_physical_experts
        self.num_redundant_experts = num_physical_experts - self.num_logical_experts
        for layer in self.model.layers:
            # Same guard as set_moe_parameters: entries that are not decoder
            # layers have no `mlp` attribute and must be skipped.
            if not isinstance(layer, Qwen3NextDecoderLayer):
                continue
            if isinstance(layer.mlp, Qwen3NextSparseMoeBlock):
                moe = layer.mlp
                moe.n_local_physical_experts = num_local_physical_experts
                moe.n_physical_experts = num_physical_experts
                moe.n_redundant_experts = self.num_redundant_experts
                moe.experts.update_expert_map()

    def set_moe_parameters(self):
        """Collect the MoE layers and copy their expert counts onto ``self``.

        Raises:
            RuntimeError: If no decoder layer with a sparse MoE block exists
                in ``self.model.layers``.
        """
        self.expert_weights = []

        self.moe_layers = []
        example_moe = None
        for layer in self.model.layers:
            if isinstance(layer, Qwen3NextDecoderLayer) and isinstance(
                layer.mlp, Qwen3NextSparseMoeBlock
            ):
                # The last MoE block found supplies the counts below.
                example_moe = layer.mlp
                self.moe_layers.append(layer.mlp.experts)

        if example_moe is None:
            raise RuntimeError("No Qwen3Next layer found in the model.layers.")

        # Set MoE hyperparameters
        self.num_moe_layers = len(self.moe_layers)
        self.num_expert_groups = 1
        self.num_shared_experts = 0
        self.num_logical_experts = example_moe.n_logical_experts
        self.num_physical_experts = example_moe.n_physical_experts
        self.num_local_physical_experts = example_moe.n_local_physical_experts
        self.num_routed_experts = example_moe.n_routed_experts
        self.num_redundant_experts = example_moe.n_redundant_experts


685
class Qwen3NextForCausalLM(
    nn.Module,
    HasInnerState,
    SupportsLoRA,
    SupportsPP,
    QwenNextMixtureOfExperts,
    IsHybrid,
):
    """Qwen3-Next decoder stack with a language-model head."""

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": ["gate_proj", "up_proj"],
        "in_proj_qkvz": ["in_proj_qkvz"],
        "in_proj_ba": ["in_proj_ba"],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the decoder stack, the LM head and the logits processor.

        Args:
            vllm_config: Engine configuration.
            prefix: Dotted module path used to name the sub-modules.

        Raises:
            NotImplementedError: If ``cache_config.mamba_cache_mode`` is
                ``"all"``.
        """
        cache_config = vllm_config.cache_config
        if cache_config.mamba_cache_mode == "all":
            raise NotImplementedError(
                "Qwen3Next currently does not support 'all' prefix caching, "
                "please use '--mamba-cache-mode=align' instead"
            )

        # Initialise nn.Module before any attribute is assigned to `self`.
        super().__init__()

        config = vllm_config.model_config.hf_text_config
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.quant_config = vllm_config.quant_config
        self.config = config
        self.scheduler_config = vllm_config.scheduler_config

        self.model = Qwen3NextModel(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )

        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

        # Set MoE hyperparameters
        self.set_moe_parameters()

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the token embeddings for ``input_ids``."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ):
        """Run the decoder stack and return its output unchanged.

        Extra keyword arguments are accepted and ignored.
        """
        hidden_states = self.model(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )

        return hidden_states

    @classmethod
    def get_mamba_state_dtype_from_config(
        cls,
        vllm_config: "VllmConfig",
    ) -> tuple[torch.dtype, torch.dtype]:
        """Return the state dtypes computed from the model and cache config."""
        return MambaStateDtypeCalculator.gated_delta_net_state_dtype(
            vllm_config.model_config.dtype,
            vllm_config.cache_config.mamba_cache_dtype,
            vllm_config.cache_config.mamba_ssm_cache_dtype,
        )

    @classmethod
    def get_mamba_state_shape_from_config(
        cls, vllm_config: "VllmConfig"
    ) -> tuple[tuple[int, int], tuple[int, int]]:
        """Return the state shapes computed from the parallel and HF config.

        The number of speculative tokens is 0 when no speculative config is
        set.
        """
        parallel_config = vllm_config.parallel_config
        hf_config = vllm_config.model_config.hf_text_config
        tp_size = parallel_config.tensor_parallel_size
        num_spec = (
            vllm_config.speculative_config.num_speculative_tokens
            if vllm_config.speculative_config
            else 0
        )
        return MambaStateShapeCalculator.gated_delta_net_state_shape(
            tp_size,
            hf_config.linear_num_key_heads,
            hf_config.linear_num_value_heads,
            hf_config.linear_key_head_dim,
            hf_config.linear_value_head_dim,
            hf_config.linear_conv_kernel_dim,
            num_spec,
        )

    @classmethod
    def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]:
        """Return the pair of state-copy functions for gated delta net."""
        return MambaStateCopyFuncCalculator.gated_delta_net_state_copy_func()

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Project ``hidden_states`` through the LM head into logits."""
        return self.logits_processor(self.lm_head, hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, skipping names that start with ``mtp.``."""
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=["mtp."],
        )
        return loader.load_weights(weights)

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Return the expert weight mapping of the wrapped model."""
        return self.model.get_expert_mapping()