qwen3_next.py 28.7 KB
Newer Older
1
2
3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Inference-only Qwen3Next model."""
4

5
from collections.abc import Iterable
6
from itertools import islice
7
8
9
10
11

import torch
from torch import nn

from vllm.compilation.decorators import support_torch_compile
12
13
14
15
16
17
18
19
20
21
22
23
from vllm.config import (
    CacheConfig,
    ModelConfig,
    VllmConfig,
    get_current_vllm_config,
)
from vllm.distributed import (
    get_ep_group,
    get_pp_group,
    get_tensor_model_parallel_world_size,
    tensor_model_parallel_all_gather,
)
24
from vllm.logger import init_logger
25
from vllm.model_executor.layers.attention import Attention
26
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
27
28
29
from vllm.model_executor.layers.layernorm import (
    GemmaRMSNorm as Qwen3NextRMSNorm,
)
30
31
32
33
34
from vllm.model_executor.layers.linear import (
    QKVParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
35
from vllm.model_executor.layers.logits_processor import LogitsProcessor
36
from vllm.model_executor.layers.mamba.gdn_linear_attn import GatedDeltaNetAttention
37
from vllm.model_executor.layers.mamba.mamba_utils import (
38
39
    MambaStateCopyFunc,
    MambaStateCopyFuncCalculator,
40
41
42
    MambaStateDtypeCalculator,
    MambaStateShapeCalculator,
)
43
44
45
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
46
47
48
    ParallelLMHead,
    VocabParallelEmbedding,
)
49
from vllm.model_executor.model_loader.weight_utils import (
50
    default_weight_loader,
51
    maybe_remap_kv_scale_name,
52
)
53
from vllm.model_executor.models.qwen2_moe import Qwen2MoeMLP as Qwen3NextMLP
54
from vllm.model_executor.models.utils import sequence_parallel_chunk
55
from vllm.sequence import IntermediateTensors
56
from vllm.transformers_utils.configs.qwen3_next import Qwen3NextConfig
57

58
from .interfaces import (
59
    EagleModelMixin,
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
    HasInnerState,
    IsHybrid,
    MixtureOfExperts,
    SupportsLoRA,
    SupportsPP,
)
from .utils import (
    AutoWeightsLoader,
    PPMissingLayer,
    extract_layer_index,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
75
76
77
78
79
80
81

# Pair-of-tensors alias. Presumably (key_cache, value_cache); it is not
# referenced by the visible code in this file — TODO confirm remaining users.
KVCache = tuple[torch.Tensor, torch.Tensor]

# Module logger; Qwen3NextModel.load_weights reports skipped parameters here.
logger = init_logger(__name__)


class Qwen3NextSparseMoeBlock(nn.Module):
    """Sparse mixture-of-experts MLP used by Qwen3Next decoder layers.

    Routed experts are executed by ``SharedFusedMoE``. When
    ``config.shared_expert_intermediate_size > 0`` a dense shared expert
    (scaled by ``shared_expert_gate``) is passed to the same fused call.
    """

    def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_text_config
        parallel_config = vllm_config.parallel_config
        quant_config = vllm_config.quant_config

        self.tp_size = get_tensor_model_parallel_world_size()

        self.ep_group = get_ep_group().device_group
        self.ep_rank = get_ep_group().rank_in_group
        self.ep_size = self.ep_group.size()
        self.n_routed_experts = config.num_experts

        self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe

        if self.tp_size > config.num_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {config.num_experts}."
            )

        # Load balancing settings. Both the EPLB config and the enable flag
        # are read from the parallel config that was passed in, so they can
        # never come from two different configs.
        eplb_config = parallel_config.eplb_config
        self.enable_eplb = parallel_config.enable_eplb

        self.n_logical_experts = self.n_routed_experts
        self.n_redundant_experts = eplb_config.num_redundant_experts
        self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts
        self.n_local_physical_experts = self.n_physical_experts // self.ep_size

        # Contiguous range of physical expert ids owned by this EP rank.
        self.physical_expert_start = self.ep_rank * self.n_local_physical_experts
        self.physical_expert_end = (
            self.physical_expert_start + self.n_local_physical_experts
        )

        self.gate = ReplicatedLinear(
            config.hidden_size,
            config.num_experts,
            bias=False,
            quant_config=None,
            prefix=f"{prefix}.gate",
        )

        self.shared_expert_gate = ReplicatedLinear(
            config.hidden_size,
            1,
            bias=False,
            quant_config=None,
            prefix=f"{prefix}.shared_expert_gate",
        )

        if config.shared_expert_intermediate_size > 0:
            self.shared_expert = Qwen3NextMLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.shared_expert_intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                reduce_results=False,
                expert_gate=self.shared_expert_gate,
                is_sequence_parallel=self.is_sequence_parallel,
                prefix=f"{prefix}.shared_expert",
            )
        else:
            self.shared_expert = None

        self.experts = SharedFusedMoE(
            shared_experts=self.shared_expert,
            gate=self.gate,
            num_experts=self.n_routed_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            renormalize=getattr(config, "norm_topk_prob", True),
            quant_config=quant_config,
            prefix=f"{prefix}.experts",
            enable_eplb=self.enable_eplb,
            num_redundant_experts=self.n_redundant_experts,
            is_sequence_parallel=self.is_sequence_parallel,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route ``hidden_states`` through the (shared + routed) experts.

        Args:
            hidden_states: either ``(num_tokens, hidden_size)`` or a single
                token of shape ``(hidden_size,)``.

        Returns:
            A tensor with the same shape as ``hidden_states``.
        """
        # NOTE: hidden_states can have either 1D or 2D shape. Flatten to 2D
        # before reading the token count so both cases share one code path.
        orig_shape = hidden_states.shape
        hidden_dim = hidden_states.shape[-1]
        hidden_states = hidden_states.view(-1, hidden_dim)
        num_tokens = hidden_states.shape[0]

        if self.is_sequence_parallel:
            hidden_states = sequence_parallel_chunk(hidden_states)

        if self.experts.is_internal_router:
            # In this case, the gate/router runs inside the FusedMoE class
            final_hidden_states = self.experts(
                hidden_states=hidden_states, router_logits=hidden_states
            )
        else:
            # router_logits: (num_tokens, n_experts)
            router_logits, _ = self.gate(hidden_states)
            final_hidden_states = self.experts(
                hidden_states=hidden_states, router_logits=router_logits
            )

        if self.is_sequence_parallel:
            # Re-assemble the per-rank chunks along the token dim and trim
            # back to the original token count (the gathered tensor may be
            # longer, presumably because of chunk padding).
            final_hidden_states = tensor_model_parallel_all_gather(
                final_hidden_states, 0
            )
            final_hidden_states = final_hidden_states[:num_tokens]

        return final_hidden_states.view(orig_shape)


class Qwen3NextAttention(nn.Module):
    """Full (softmax) attention layer of Qwen3Next.

    Optionally applies a per-head sigmoid output gate
    (``config.attn_output_gate``, default ``True``). When the gate is on, each
    query head's projection is ``2 * head_dim`` wide: the first half is the
    query and the second half gates that head's attention output.
    """

    def __init__(
        self,
        config: Qwen3NextConfig,
        model_config: ModelConfig | None = None,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        # NOTE: model_config is accepted for interface compatibility; it is
        # not used by this layer.
        self.config = config
        self.hidden_size = config.hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = config.num_attention_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = config.num_key_value_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        # head_dim is a property of the model, not of the TP shard. When the
        # config does not set it, derive it from the *total* head count;
        # dividing by the per-rank num_heads would inflate it by tp_size.
        self.head_dim = getattr(config, "head_dim", None) or (
            self.hidden_size // self.total_num_heads
        )
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.dual_chunk_attention_config = getattr(
            config, "dual_chunk_attention_config", None
        )
        self.attn_output_gate = getattr(config, "attn_output_gate", True)

        # With the output gate, every query head also projects a gate vector.
        q_heads_multiplier = 2 if self.attn_output_gate else 1
        self.qkv_proj = QKVParallelLinear(
            config.hidden_size,
            self.head_dim,
            self.total_num_heads * q_heads_multiplier,
            self.total_num_kv_heads,
            bias=getattr(config, "qkv_bias", False),
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            config.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self.rotary_emb = get_rope(
            head_size=self.head_dim,
            max_position=config.max_position_embeddings,
            rope_parameters=config.rope_parameters,
            dual_chunk_attention_config=self.dual_chunk_attention_config,
        )

        # Extra Attention kwargs are only passed when dual-chunk attention
        # is configured.
        attn_kwargs = (
            {
                "layer_idx": extract_layer_index(prefix),
                "dual_chunk_attention_config": self.dual_chunk_attention_config,
            }
            if self.dual_chunk_attention_config
            else {}
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
            **attn_kwargs,
        )

        self.q_norm = Qwen3NextRMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = Qwen3NextRMSNorm(self.head_dim, eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        output: torch.Tensor,
        hidden_states: torch.Tensor,
    ):
        """Compute attention over ``hidden_states`` and store it in ``output``.

        The o_proj result is written in place into ``output`` (``output[:]``);
        nothing is returned.
        """
        qkv, _ = self.qkv_proj(hidden_states)

        if self.attn_output_gate:
            q_gate, k, v = qkv.split(
                [self.q_size * 2, self.kv_size, self.kv_size], dim=-1
            )
            # Per head: first half is the query, second half is the gate.
            orig_shape = q_gate.shape[:-1]
            q_gate = q_gate.view(*orig_shape, self.num_heads, -1)
            q, gate = torch.chunk(q_gate, 2, dim=-1)
            q = q.reshape(*orig_shape, -1)
            gate = gate.reshape(*orig_shape, -1)
        else:
            q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

        # Per-head RMS norm on queries and keys.
        q = self.q_norm(q.view(-1, self.num_heads, self.head_dim)).view(
            -1, self.num_heads * self.head_dim
        )
        k = self.k_norm(k.view(-1, self.num_kv_heads, self.head_dim)).view(
            -1, self.num_kv_heads * self.head_dim
        )

        q, k = self.rotary_emb(positions, q, k)

        attn_output = self.attn(q, k, v)

        if self.attn_output_gate:
            gate = torch.sigmoid(gate)
            attn_output = attn_output * gate

        output[:], _ = self.o_proj(attn_output)


class Qwen3NextDecoderLayer(nn.Module):
    """One Qwen3Next decoder layer: a token mixer followed by an MLP.

    ``layer_type`` selects the mixer: ``"linear_attention"`` builds a
    GatedDeltaNetAttention, ``"full_attention"`` builds a Qwen3NextAttention.
    The MLP is a Qwen3NextSparseMoeBlock on MoE layers, otherwise a dense MLP.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        layer_type: str,
        prefix: str = "",
    ) -> None:
        super().__init__()

        # Use hf_text_config, consistent with Qwen3NextSparseMoeBlock and
        # Qwen3NextModel (which picks this layer's layer_type from it).
        config = vllm_config.model_config.hf_text_config
        model_config = vllm_config.model_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.layer_type = layer_type
        self.layer_idx = extract_layer_index(prefix)

        if self.layer_type == "linear_attention":
            self.linear_attn = GatedDeltaNetAttention(
                config,
                vllm_config=vllm_config,
                prefix=f"{prefix}.linear_attn",
                gqa_interleaved_layout=True,
            )
        elif self.layer_type == "full_attention":
            self.self_attn = Qwen3NextAttention(
                config,
                model_config=model_config,
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=f"{prefix}.self_attn",
            )
        else:
            raise ValueError(f"Invalid layer_type {self.layer_type}")

        # A layer is MoE unless it is listed in mlp_only_layers, the model
        # has no experts, or it is not on the decoder_sparse_step cadence.
        mlp_only_layers = getattr(config, "mlp_only_layers", [])
        use_moe = self.layer_idx not in mlp_only_layers and (
            config.num_experts > 0
            and (self.layer_idx + 1) % config.decoder_sparse_step == 0
        )
        if use_moe:
            self.mlp = Qwen3NextSparseMoeBlock(
                vllm_config=vllm_config,
                prefix=f"{prefix}.mlp",
            )
        else:
            self.mlp = Qwen3NextMLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                prefix=f"{prefix}.mlp",
            )

        self.input_layernorm = Qwen3NextRMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )
        self.post_attention_layernorm = Qwen3NextRMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

        # Optional learned residual-branch scales, zero-initialised and
        # applied as (1 + scale) in forward.
        self.layer_scale = getattr(config, "layer_scale", False)
        if self.layer_scale:
            self.attn_layer_scale = torch.nn.Parameter(
                torch.zeros(
                    1,
                    1,
                    config.hidden_size,
                ),
            )
            self.ffn_layer_scale = torch.nn.Parameter(
                torch.zeros(
                    1,
                    1,
                    config.hidden_size,
                ),
            )

    @staticmethod
    def _apply_layer_scale(
        hidden_states: torch.Tensor, scale: torch.Tensor
    ) -> torch.Tensor:
        """Return ``hidden_states * (1 + scale)`` in ``hidden_states.dtype``.

        ``scale`` has shape ``(1, 1, hidden_size)``. For 2D inputs its leading
        singleton dim is dropped; any other input must have the same rank as
        ``scale``.
        """
        scale = scale.to(hidden_states.dtype)
        if len(hidden_states.shape) == 2:
            return hidden_states * (scale[0] + 1)
        assert len(hidden_states.shape) == len(scale.shape), (
            f"shape must be the same {len(hidden_states.shape)}, "
            f"{len(scale.shape)}"
        )
        return hidden_states * (scale + 1)

    def forward(
        self,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
        positions: torch.Tensor | None = None,
        **kwargs: object,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the mixer and MLP sub-blocks with fused residual RMS norms.

        Args:
            hidden_states: input activations of this layer.
            residual: running residual, or ``None`` for the first layer (the
                input then becomes the residual).
            positions: token positions; only used by full-attention layers.
            **kwargs: accepted and ignored.

        Returns:
            ``(hidden_states, residual)`` for the next layer.
        """
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)

        # Both mixers write their result into a preallocated output buffer.
        self_attention_output = torch.empty_like(hidden_states)
        if self.layer_type == "linear_attention":
            self.linear_attn(
                hidden_states=hidden_states,
                output=self_attention_output,
            )
        elif self.layer_type == "full_attention":
            self.self_attn(
                hidden_states=hidden_states,
                output=self_attention_output,
                positions=positions,
            )
        else:
            raise ValueError("Invalid layer_type")
        hidden_states = self_attention_output

        if self.layer_scale:
            hidden_states = self._apply_layer_scale(
                hidden_states, self.attn_layer_scale
            )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)

        if self.layer_scale:
            hidden_states = self._apply_layer_scale(
                hidden_states, self.ffn_layer_scale
            )

        return hidden_states, residual


@support_torch_compile
class Qwen3NextModel(nn.Module, EagleModelMixin):
    """Qwen3Next decoder stack: token embedding, hybrid layers, final norm.

    ``forward`` runs only layers ``[start_layer, end_layer)`` as returned by
    ``make_layers``. Non-first pipeline ranks read their input from
    ``intermediate_tensors``; non-last ranks return ``IntermediateTensors``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config: Qwen3NextConfig = vllm_config.model_config.hf_text_config
        parallel_config = vllm_config.parallel_config

        eplb_config = parallel_config.eplb_config
        self.num_redundant_experts = eplb_config.num_redundant_experts

        self.config = config
        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
        )

        def get_layer(prefix: str):
            # The mixer type of each layer comes from config.layer_types,
            # indexed by the layer number parsed from its prefix.
            return Qwen3NextDecoderLayer(
                vllm_config,
                layer_type=config.layer_types[extract_layer_index(prefix)],
                prefix=prefix,
            )

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers"
        )
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size
        )

        # Only the last pipeline rank applies the final norm.
        if get_pp_group().is_last_rank:
            self.norm = Qwen3NextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer()

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up the token embeddings of ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]:
        """Run this rank's slice of decoder layers.

        Returns:
            ``IntermediateTensors`` (hidden_states, residual) on non-last PP
            ranks. On the last rank: the final-normed hidden states, or
            ``(hidden_states, aux_hidden_states)`` when auxiliary hidden
            states were collected.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.embed_input_ids(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        # Optional per-layer hidden-state capture provided by EagleModelMixin
        # (presumably for EAGLE speculative decoding — see the mixin).
        aux_hidden_states = self._maybe_add_hidden_state([], 0, hidden_states, residual)
        for layer_idx, layer in enumerate(
            islice(self.layers, self.start_layer, self.end_layer),
            start=self.start_layer,
        ):
            hidden_states, residual = layer(
                positions=positions,
                hidden_states=hidden_states,
                residual=residual,
            )
            self._maybe_add_hidden_state(
                aux_hidden_states, layer_idx + 1, hidden_states, residual
            )

        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )
        hidden_states, _ = self.norm(hidden_states, residual)
        if aux_hidden_states:
            return hidden_states, aux_hidden_states
        return hidden_states

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        return SharedFusedMoE.make_expert_params_mapping(
            self,
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=getattr(self.config, "num_experts", 0),
            num_redundant_experts=self.num_redundant_experts,
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint ``(name, tensor)`` pairs into this module.

        Each weight is routed to exactly one of three paths, tried in order:

        1. stacked projections (``q/k/v_proj`` -> ``qkv_proj``,
           ``gate/up_proj`` -> ``gate_up_proj``), excluding routed experts;
        2. routed-expert weights, via :meth:`get_expert_mapping`;
        3. everything else, under its checkpoint name.

        Once a weight matches path 1 or 2 it is never re-tried elsewhere. If
        its mapped parameter is owned by another PP rank or does not exist
        (e.g. an extra GPTQ bias), the weight is skipped. Re-matching the
        rewritten name would be wrong: ``qkv_proj`` itself contains
        ``v_proj`` and would be rewritten again.

        Returns:
            Names of the parameters that received a weight.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        expert_params_mapping = self.get_expert_mapping()

        def is_loadable(param_name: str) -> bool:
            # False for parameters of layers on other PP ranks and for names
            # with no parameter here (covers extra GPTQ biases).
            return (
                not is_pp_missing_parameter(param_name, self)
                and param_name in params_dict
            )

        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            # MTP weights are not loaded by this module.
            if name.startswith("mtp."):
                continue

            # Remapping the name of FP8 kv-scale.
            if name.endswith("scale"):
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue

            # 1. Stacked (fused) projections; routed experts are excluded.
            stacked = next(
                (
                    entry
                    for entry in stacked_params_mapping
                    if entry[1] in name and "mlp.experts" not in name
                ),
                None,
            )
            if stacked is not None:
                param_name, weight_name, shard_id = stacked
                name = name.replace(weight_name, param_name)
                if not is_loadable(name):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                loaded_params.add(name)
                continue

            # 2. Routed-expert weights.
            expert = next(
                (entry for entry in expert_params_mapping if entry[1] in name),
                None,
            )
            if expert is not None:
                param_name, weight_name, expert_id, shard_id = expert
                name = name.replace(weight_name, param_name)
                if not is_loadable(name):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(
                    param,
                    loaded_weight,
                    name,
                    shard_id=shard_id,
                    expert_id=expert_id,
                )
                loaded_params.add(name)
                continue

            # 3. Everything else, under its checkpoint name.
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            if is_pp_missing_parameter(name, self):
                continue
            if name not in params_dict:
                logger.warning_once(
                    f"Parameter {name} not found in params_dict, skip loading"
                )
                continue
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
class QwenNextMixtureOfExperts(MixtureOfExperts):
    """MoE / EPLB bookkeeping mixin for Qwen3Next causal-LM models.

    Expects the concrete class to expose ``self.model.layers``.
    """

    def _iter_sparse_moe_blocks(self) -> Iterable[Qwen3NextSparseMoeBlock]:
        """Yield the sparse MoE block of every decoder layer on this rank.

        Layers are type-checked before ``.mlp`` is touched: entries owned by
        other pipeline ranks (presumably PPMissingLayer placeholders from
        make_layers) need not have an ``mlp`` attribute.
        """
        for layer in self.model.layers:
            if isinstance(layer, Qwen3NextDecoderLayer) and isinstance(
                layer.mlp, Qwen3NextSparseMoeBlock
            ):
                yield layer.mlp

    def update_physical_experts_metadata(
        self,
        num_physical_experts: int,
        num_local_physical_experts: int,
    ) -> None:
        """Propagate a new physical-expert layout to every local MoE block.

        The number of experts per rank must not change; only the global
        physical count (and hence the redundant count) may.
        """
        assert self.num_local_physical_experts == num_local_physical_experts
        self.num_physical_experts = num_physical_experts
        self.num_local_physical_experts = num_local_physical_experts
        self.num_redundant_experts = num_physical_experts - self.num_logical_experts
        for moe in self._iter_sparse_moe_blocks():
            moe.n_local_physical_experts = num_local_physical_experts
            moe.n_physical_experts = num_physical_experts
            moe.n_redundant_experts = self.num_redundant_experts
            moe.experts.update_expert_map()

    def set_moe_parameters(self):
        """Collect local MoE layers and copy expert counts onto the model.

        Raises:
            RuntimeError: if no sparse MoE layer exists on this rank.
        """
        self.expert_weights = []

        self.moe_layers = []
        example_moe = None
        for moe in self._iter_sparse_moe_blocks():
            example_moe = moe
            self.moe_layers.append(moe.experts)

        if example_moe is None:
            raise RuntimeError("No Qwen3Next layer found in the model.layers.")

        # Set MoE hyperparameters
        self.num_moe_layers = len(self.moe_layers)
        self.num_expert_groups = 1
        self.num_shared_experts = 0
        self.num_logical_experts = example_moe.n_logical_experts
        self.num_physical_experts = example_moe.n_physical_experts
        self.num_local_physical_experts = example_moe.n_local_physical_experts
        self.num_routed_experts = example_moe.n_routed_experts
        self.num_redundant_experts = example_moe.n_redundant_experts


682
class Qwen3NextForCausalLM(
    nn.Module,
    HasInnerState,
    SupportsLoRA,
    SupportsPP,
    QwenNextMixtureOfExperts,
    IsHybrid,
):
    """Qwen3Next causal language model: decoder stack plus LM head.

    The model is hybrid: its linear-attention layers keep gated-delta-net
    state, whose dtype / shape / copy helpers are the classmethods below.
    """

    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
        "in_proj_qkvz": ["in_proj_qkvz"],
        "in_proj_ba": ["in_proj_ba"],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_text_config
        model_config = vllm_config.model_config
        cache_config = vllm_config.cache_config

        # Only the 'align' prefix-caching mode of the mamba cache is supported.
        if cache_config.mamba_cache_mode == "all":
            raise NotImplementedError(
                "Qwen3Next currently does not support 'all' prefix caching, "
                "please use '--mamba-cache-mode=align' instead"
            )

        self.vllm_config = vllm_config
        self.model_config = model_config
        self.quant_config = vllm_config.quant_config
        self.config = config
        self.scheduler_config = vllm_config.scheduler_config

        self.model = Qwen3NextModel(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

        # Gather MoE layers and expert counts (QwenNextMixtureOfExperts).
        self.set_moe_parameters()

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate token embedding to the decoder stack."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ):
        """Run the decoder stack; extra ``kwargs`` are accepted and ignored."""
        return self.model(input_ids, positions, intermediate_tensors, inputs_embeds)

    @classmethod
    def get_mamba_state_dtype_from_config(
        cls,
        vllm_config: "VllmConfig",
    ) -> tuple[torch.dtype, torch.dtype]:
        """Dtypes of the gated-delta-net state, derived from the cache config."""
        model_dtype = vllm_config.model_config.dtype
        cache_config = vllm_config.cache_config
        return MambaStateDtypeCalculator.gated_delta_net_state_dtype(
            model_dtype,
            cache_config.mamba_cache_dtype,
            cache_config.mamba_ssm_cache_dtype,
        )

    @classmethod
    def get_mamba_state_shape_from_config(
        cls, vllm_config: "VllmConfig"
    ) -> tuple[tuple[int, int], tuple[int, int]]:
        """Per-rank gated-delta-net state shapes.

        Speculative tokens (if any) are forwarded to the shape calculator.
        NOTE(review): the second (temporal) state may be 3-D
        (heads, k_dim, v_dim); the return annotation may be too narrow —
        confirm against MambaStateShapeCalculator.
        """
        text_config = vllm_config.model_config.hf_text_config
        spec_config = vllm_config.speculative_config
        num_spec = spec_config.num_speculative_tokens if spec_config else 0
        return MambaStateShapeCalculator.gated_delta_net_state_shape(
            vllm_config.parallel_config.tensor_parallel_size,
            text_config.linear_num_key_heads,
            text_config.linear_num_value_heads,
            text_config.linear_key_head_dim,
            text_config.linear_value_head_dim,
            text_config.linear_conv_kernel_dim,
            num_spec,
        )

    @classmethod
    def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]:
        """Copy functions for the gated-delta-net state."""
        return MambaStateCopyFuncCalculator.gated_delta_net_state_copy_func()

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Project final hidden states to vocabulary logits."""
        return self.logits_processor(self.lm_head, hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load all weights except the ``mtp.`` ones via AutoWeightsLoader."""
        return AutoWeightsLoader(self, skip_prefixes=["mtp."]).load_weights(
            weights
        )

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Expert parameter mapping of the underlying decoder stack."""
        return self.model.get_expert_mapping()