qwen3_next.py 28.9 KB
Newer Older
1
2
3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Inference-only Qwen3Next model."""
4

5
from collections.abc import Iterable
6
from itertools import islice
7
8
9
10
11

import torch
from torch import nn

from vllm.compilation.decorators import support_torch_compile
12
13
14
15
16
17
18
19
20
21
22
23
from vllm.config import (
    CacheConfig,
    ModelConfig,
    VllmConfig,
    get_current_vllm_config,
)
from vllm.distributed import (
    get_ep_group,
    get_pp_group,
    get_tensor_model_parallel_world_size,
    tensor_model_parallel_all_gather,
)
24
from vllm.logger import init_logger
25
from vllm.model_executor.layers.attention import Attention
26
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
27
28
29
from vllm.model_executor.layers.layernorm import (
    GemmaRMSNorm as Qwen3NextRMSNorm,
)
30
31
32
33
34
from vllm.model_executor.layers.linear import (
    QKVParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
35
from vllm.model_executor.layers.logits_processor import LogitsProcessor
36
from vllm.model_executor.layers.mamba.gdn_linear_attn import GatedDeltaNetAttention
37
from vllm.model_executor.layers.mamba.mamba_utils import (
38
39
    MambaStateCopyFunc,
    MambaStateCopyFuncCalculator,
40
41
42
    MambaStateDtypeCalculator,
    MambaStateShapeCalculator,
)
43
44
45
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
46
47
48
    ParallelLMHead,
    VocabParallelEmbedding,
)
49
from vllm.model_executor.model_loader.weight_utils import (
50
    default_weight_loader,
51
    maybe_remap_kv_scale_name,
52
)
53
from vllm.model_executor.models.qwen2_moe import Qwen2MoeMLP as Qwen3NextMLP
54
from vllm.model_executor.models.utils import sequence_parallel_chunk
55
from vllm.sequence import IntermediateTensors
56
from vllm.transformers_utils.configs.qwen3_next import Qwen3NextConfig
57

58
from .interfaces import (
59
    EagleModelMixin,
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
    HasInnerState,
    IsHybrid,
    MixtureOfExperts,
    SupportsLoRA,
    SupportsPP,
)
from .utils import (
    AutoWeightsLoader,
    PPMissingLayer,
    extract_layer_index,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
75
76
77
78
79
80
81

# Alias for a two-tensor tuple; by its name, presumably a (key, value) cache
# pair — it is not referenced in this module itself.
KVCache = tuple[torch.Tensor, torch.Tensor]

logger = init_logger(__name__)  # per-module logger


class Qwen3NextSparseMoeBlock(nn.Module):
    """Sparse mixture-of-experts MLP block used by Qwen3Next decoder layers.

    Routed experts, plus an optional shared expert (built when
    ``shared_expert_intermediate_size > 0``), are executed through
    ``SharedFusedMoE``. The block also handles sequence-parallel chunking /
    all-gather and the tensor-parallel reduction of the expert output.
    """

    def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_text_config
        parallel_config = vllm_config.parallel_config
        quant_config = vllm_config.quant_config

        self.tp_size = get_tensor_model_parallel_world_size()

        self.ep_group = get_ep_group().device_group
        self.ep_rank = get_ep_group().rank_in_group
        self.ep_size = self.ep_group.size()
        self.n_routed_experts = config.num_experts

        self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe

        if self.tp_size > config.num_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {config.num_experts}."
            )

        # Load balancing (EPLB) settings. Both values are read from the
        # parallel_config of the vllm_config passed in, so `enable_eplb` and
        # the redundant-expert count always describe the same configuration
        # (instead of mixing in a separately fetched global config).
        eplb_config = parallel_config.eplb_config
        self.enable_eplb = parallel_config.enable_eplb

        self.n_logical_experts = self.n_routed_experts
        self.n_redundant_experts = eplb_config.num_redundant_experts
        self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts
        self.n_local_physical_experts = self.n_physical_experts // self.ep_size

        # Contiguous range of physical expert ids owned by this EP rank.
        self.physical_expert_start = self.ep_rank * self.n_local_physical_experts
        self.physical_expert_end = (
            self.physical_expert_start + self.n_local_physical_experts
        )

        # Router producing one logit per routed expert (kept unquantized).
        self.gate = ReplicatedLinear(
            config.hidden_size,
            config.num_experts,
            bias=False,
            quant_config=None,
            prefix=f"{prefix}.gate",
        )

        # Scalar gate applied to the shared expert's output.
        self.shared_expert_gate = ReplicatedLinear(
            config.hidden_size,
            1,
            bias=False,
            quant_config=None,
            prefix=f"{prefix}.shared_expert_gate",
        )

        if config.shared_expert_intermediate_size > 0:
            self.shared_expert = Qwen3NextMLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.shared_expert_intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                reduce_results=False,
                expert_gate=self.shared_expert_gate,
                prefix=f"{prefix}.shared_expert",
            )
        else:
            self.shared_expert = None

        self.experts = SharedFusedMoE(
            shared_experts=self.shared_expert,
            gate=self.gate,
            num_experts=self.n_routed_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            reduce_results=False,
            renormalize=getattr(config, "norm_topk_prob", True),
            quant_config=quant_config,
            prefix=f"{prefix}.experts",
            enable_eplb=self.enable_eplb,
            num_redundant_experts=self.n_redundant_experts,
            is_sequence_parallel=self.is_sequence_parallel,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Run the MoE block.

        Args:
            hidden_states: Either a 2D ``(num_tokens, hidden_size)`` tensor or
                a 1D ``(hidden_size,)`` tensor holding a single token.

        Returns:
            A tensor with the same shape as ``hidden_states``.
        """
        orig_shape = hidden_states.shape
        # Flatten to 2D first and only then read the token count. Unpacking
        # ``shape`` into two names directly would fail for 1D inputs.
        hidden_dim = orig_shape[-1]
        hidden_states = hidden_states.view(-1, hidden_dim)
        num_tokens = hidden_states.shape[0]

        if self.is_sequence_parallel:
            hidden_states = sequence_parallel_chunk(hidden_states)

        if self.experts.is_internal_router:
            # The gate/router runs inside the FusedMoE class; hidden_states is
            # passed in the router_logits slot in this mode.
            final_hidden_states = self.experts(
                hidden_states=hidden_states, router_logits=hidden_states
            )
        else:
            # router_logits: (num_tokens, n_experts)
            router_logits, _ = self.gate(hidden_states)
            final_hidden_states = self.experts(
                hidden_states=hidden_states, router_logits=router_logits
            )

        if self.shared_expert is not None:
            # With a shared expert, SharedFusedMoE returns a pair of outputs
            # (shared and routed); the block output is their sum.
            final_hidden_states = final_hidden_states[0] + final_hidden_states[1]

        if self.is_sequence_parallel:
            final_hidden_states = tensor_model_parallel_all_gather(
                final_hidden_states, 0
            )
            # Trim the gathered tensor back to this rank's token count
            # (presumably dropping chunk padding — confirm against
            # sequence_parallel_chunk).
            final_hidden_states = final_hidden_states[:num_tokens]
        elif self.tp_size > 1:
            final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(  # noqa E501
                final_hidden_states
            )

        return final_hidden_states.view(orig_shape)


class Qwen3NextAttention(nn.Module):
    """Full (softmax) attention layer of Qwen3Next.

    Queries and keys get per-head RMS norms before rotary embedding. When
    ``attn_output_gate`` is enabled (the default), ``qkv_proj`` also emits a
    per-head gate laid out next to each head's query. The attention output
    is multiplied by ``sigmoid(gate)`` before ``o_proj``.
    """

    def __init__(
        self,
        config: Qwen3NextConfig,
        model_config: ModelConfig | None = None,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = config.num_attention_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = config.num_key_value_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        # Fallback when the config leaves head_dim unset. It must divide by
        # the *total* head count: dividing by the per-rank num_heads would
        # make head_dim grow with the tensor-parallel size.
        self.head_dim = config.head_dim or (
            self.hidden_size // self.total_num_heads
        )
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.dual_chunk_attention_config = getattr(
            config, "dual_chunk_attention_config", None
        )
        self.attn_output_gate = getattr(config, "attn_output_gate", True)

        # With the output gate, the "query" part of the projection is twice
        # as wide: it carries both queries and gates.
        self.qkv_proj = QKVParallelLinear(
            config.hidden_size,
            self.head_dim,
            self.total_num_heads * (1 + self.attn_output_gate),
            self.total_num_kv_heads,
            bias=getattr(config, "qkv_bias", False),
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            config.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self.rotary_emb = get_rope(
            head_size=self.head_dim,
            max_position=config.max_position_embeddings,
            rope_parameters=config.rope_parameters,
            dual_chunk_attention_config=self.dual_chunk_attention_config,
        )

        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
            # Extra kwargs are only passed for dual-chunk attention.
            **{
                "layer_idx": extract_layer_index(prefix),
                "dual_chunk_attention_config": self.dual_chunk_attention_config,
            }
            if self.dual_chunk_attention_config
            else {},
        )

        self.q_norm = Qwen3NextRMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = Qwen3NextRMSNorm(self.head_dim, eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        output: torch.Tensor,
        hidden_states: torch.Tensor,
    ):
        """Compute attention for ``hidden_states`` and write it into ``output``.

        ``output`` is filled in place via ``output[:] = ...``. Nothing is
        returned.
        """
        qkv, _ = self.qkv_proj(hidden_states)

        if self.attn_output_gate:
            q_gate, k, v = qkv.split(
                [self.q_size * 2, self.kv_size, self.kv_size], dim=-1
            )
            # Per head, the first half of the slice is the query and the
            # second half is the gate.
            orig_shape = q_gate.shape[:-1]
            q_gate = q_gate.view(*orig_shape, self.num_heads, -1)
            q, gate = torch.chunk(q_gate, 2, dim=-1)
            q = q.reshape(*orig_shape, -1)
            gate = gate.reshape(*orig_shape, -1)
        else:
            q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

        # Per-head RMS norm on queries and keys.
        q = self.q_norm(q.view(-1, self.num_heads, self.head_dim)).view(
            -1, self.num_heads * self.head_dim
        )
        k = self.k_norm(k.view(-1, self.num_kv_heads, self.head_dim)).view(
            -1, self.num_kv_heads * self.head_dim
        )

        q, k = self.rotary_emb(positions, q, k)

        attn_output = self.attn(q, k, v)

        if self.attn_output_gate:
            gate = torch.sigmoid(gate)
            attn_output = attn_output * gate

        output[:], _ = self.o_proj(attn_output)


class Qwen3NextDecoderLayer(nn.Module):
    """One Qwen3Next decoder layer: token mixer followed by an MLP.

    The token mixer is a Gated DeltaNet linear-attention module when
    ``layer_type == "linear_attention"`` and full attention when
    ``layer_type == "full_attention"``. The MLP is a sparse MoE block on
    layers selected by ``num_experts`` / ``decoder_sparse_step`` /
    ``mlp_only_layers``; every other layer uses a dense MLP. Both sub-blocks
    are pre-normed. When ``config.layer_scale`` is set, each sub-block output
    is scaled by ``(1 + learned_scale)``.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        layer_type: str,
        prefix: str = "",
    ) -> None:
        super().__init__()

        # Use the text sub-config, consistent with Qwen3NextModel and
        # Qwen3NextSparseMoeBlock. For text-only models this is expected to
        # be the same object as hf_config.
        config = vllm_config.model_config.hf_text_config
        model_config = vllm_config.model_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.layer_type = layer_type
        self.layer_idx = extract_layer_index(prefix)

        if self.layer_type == "linear_attention":
            self.linear_attn = GatedDeltaNetAttention(
                config,
                vllm_config=vllm_config,
                prefix=f"{prefix}.linear_attn",
                gqa_interleaved_layout=True,
            )
        elif self.layer_type == "full_attention":
            self.self_attn = Qwen3NextAttention(
                config,
                model_config=model_config,
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=f"{prefix}.self_attn",
            )
        else:
            raise ValueError(f"Invalid layer_type {self.layer_type}")

        mlp_only_layers = getattr(config, "mlp_only_layers", [])
        is_moe_layer = (
            self.layer_idx not in mlp_only_layers
            and config.num_experts > 0
            and (self.layer_idx + 1) % config.decoder_sparse_step == 0
        )
        if is_moe_layer:
            self.mlp = Qwen3NextSparseMoeBlock(
                vllm_config=vllm_config,
                prefix=f"{prefix}.mlp",
            )
        else:
            self.mlp = Qwen3NextMLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                prefix=f"{prefix}.mlp",
            )

        self.input_layernorm = Qwen3NextRMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )
        self.post_attention_layernorm = Qwen3NextRMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

        self.layer_scale = getattr(config, "layer_scale", False)
        if self.layer_scale:
            # Zero-initialised, so (1 + scale) starts as the identity.
            self.attn_layer_scale = torch.nn.Parameter(
                torch.zeros(1, 1, config.hidden_size),
            )
            self.ffn_layer_scale = torch.nn.Parameter(
                torch.zeros(1, 1, config.hidden_size),
            )

    @staticmethod
    def _apply_layer_scale(
        hidden_states: torch.Tensor,
        layer_scale: torch.Tensor,
        check_rank: bool = False,
    ) -> torch.Tensor:
        """Return ``hidden_states * (1 + layer_scale)``, cast to its dtype.

        ``layer_scale`` has shape ``(1, 1, hidden_size)``. For 2D activations
        its leading singleton dimension is dropped before broadcasting. With
        ``check_rank``, a non-2D input must have the same rank as the scale.
        """
        scale = layer_scale.to(hidden_states.dtype)
        if hidden_states.dim() == 2:
            return hidden_states * (scale[0] + 1)
        if check_rank:
            assert hidden_states.dim() == layer_scale.dim(), (
                f"shape must be the same {hidden_states.dim()}, "
                f"{layer_scale.dim()}"
            )
        return hidden_states * (scale + 1)

    def forward(
        self,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
        positions: torch.Tensor | None = None,
        **kwargs: object,
    ):
        """Apply the layer.

        Args:
            hidden_states: Layer input.
            residual: Running residual stream, or ``None`` for the first layer
                processed on this rank; ``hidden_states`` then seeds it.
            positions: Token positions; only used by full-attention layers.

        Returns:
            ``(hidden_states, residual)`` for the next layer.
        """
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)

        # Both token mixers write their result into a preallocated buffer.
        self_attention_output = torch.empty_like(hidden_states)
        if self.layer_type == "linear_attention":
            self.linear_attn(
                hidden_states=hidden_states,
                output=self_attention_output,
            )
        elif self.layer_type == "full_attention":
            self.self_attn(
                hidden_states=hidden_states,
                output=self_attention_output,
                positions=positions,
            )
        else:
            raise ValueError("Invalid layer_type")
        hidden_states = self_attention_output

        if self.layer_scale:
            hidden_states = self._apply_layer_scale(
                hidden_states, self.attn_layer_scale
            )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)

        if self.layer_scale:
            hidden_states = self._apply_layer_scale(
                hidden_states, self.ffn_layer_scale, check_rank=True
            )

        return hidden_states, residual


@support_torch_compile
class Qwen3NextModel(nn.Module, EagleModelMixin):
    """Qwen3Next decoder stack: token embedding, decoder layers, final norm.

    Supports pipeline parallelism: only this rank's slice of layers
    (``start_layer`` to ``end_layer``) runs. Non-last ranks return
    ``IntermediateTensors`` instead of normalised hidden states.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config: Qwen3NextConfig = vllm_config.model_config.hf_text_config
        parallel_config = vllm_config.parallel_config

        eplb_config = parallel_config.eplb_config
        self.num_redundant_experts = eplb_config.num_redundant_experts

        self.config = config
        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
        )

        def get_layer(prefix: str):
            # The layer type (linear vs. full attention) comes from the
            # per-layer list in the config, indexed by the layer number
            # parsed from the prefix.
            return Qwen3NextDecoderLayer(
                vllm_config,
                layer_type=config.layer_types[extract_layer_index(prefix)],
                prefix=prefix,
            )

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers"
        )
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size
        )

        if get_pp_group().is_last_rank:
            self.norm = Qwen3NextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer()

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]:
        """Run this rank's decoder layers.

        The first PP rank embeds ``input_ids``, or uses ``inputs_embeds`` when
        given. Other ranks resume from ``intermediate_tensors``.

        Returns:
            On non-last PP ranks, ``IntermediateTensors`` with
            ``hidden_states`` and ``residual``. On the last rank, the
            normalised hidden states, or ``(hidden_states, aux_hidden_states)``
            when auxiliary hidden states were collected.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.embed_input_ids(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        aux_hidden_states = self._maybe_add_hidden_state([], 0, hidden_states, residual)
        for layer_idx, layer in enumerate(
            islice(self.layers, self.start_layer, self.end_layer),
            start=self.start_layer,
        ):
            hidden_states, residual = layer(
                positions=positions,
                hidden_states=hidden_states,
                residual=residual,
            )
            self._maybe_add_hidden_state(
                aux_hidden_states, layer_idx + 1, hidden_states, residual
            )

        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )
        hidden_states, _ = self.norm(hidden_states, residual)
        if aux_hidden_states:
            return hidden_states, aux_hidden_states
        return hidden_states

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Return the checkpoint-to-FusedMoE expert weight mapping.

        Each entry is ``(param_name, weight_name, expert_id, shard_id)``.
        """
        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        return SharedFusedMoE.make_expert_params_mapping(
            self,
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=getattr(self.config, "num_experts", 0),
            num_redundant_experts=self.num_redundant_experts,
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights into this model (MTP weights are skipped).

        Weights are handled by three paths, tried in order:
        1. Stacked: ``q/k/v_proj`` are packed into ``qkv_proj`` and dense
           ``gate/up_proj`` into ``gate_up_proj``.
        2. Expert: per-expert weights are routed via :meth:`get_expert_mapping`.
        3. Plain: the weight is loaded by name.

        Returns:
            The set of parameter names that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        expert_params_mapping = self.get_expert_mapping()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if name.startswith("mtp."):
                continue

            # Remapping the name of FP8 kv-scale.
            if name.endswith("scale"):
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue

            # Each candidate mapping works on its own `mapped_name`. `name`
            # is rebound only after a weight has actually been loaded, so a
            # skipped candidate cannot corrupt the name seen by later
            # candidates or the fallback paths. (Otherwise "q_proj" ->
            # "qkv_proj" would then also match the "v_proj" entry.)
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                # Expert weights are handled by expert_params_mapping below.
                if "mlp.experts" in name:
                    continue
                mapped_name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if mapped_name.endswith(".bias") and mapped_name not in params_dict:
                    continue
                # Skip layers on other devices.
                if is_pp_missing_parameter(mapped_name, self):
                    continue
                if mapped_name not in params_dict:
                    continue
                param = params_dict[mapped_name]
                param.weight_loader(param, loaded_weight, shard_id)
                name = mapped_name
                break
            else:
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
                    mapped_name = name.replace(weight_name, param_name)
                    # Skip layers on other devices.
                    if is_pp_missing_parameter(mapped_name, self):
                        continue
                    # Skip parameters this model does not have; this also
                    # covers extra ".bias" / "_bias" tensors in GPTQ models.
                    if mapped_name not in params_dict:
                        continue
                    param = params_dict[mapped_name]
                    param.weight_loader(
                        param,
                        loaded_weight,
                        mapped_name,
                        shard_id=shard_id,
                        expert_id=expert_id,
                    )
                    name = mapped_name
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    if is_pp_missing_parameter(name, self):
                        continue
                    if name not in params_dict:
                        logger.warning_once(
                            f"Parameter {name} not found in params_dict, skip loading"
                        )
                        continue
                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
class QwenNextMixtureOfExperts(MixtureOfExperts):
    """MixtureOfExperts / EPLB hooks shared by Qwen3Next causal-LM models.

    The host class must expose the decoder stack as ``self.model.layers``.
    """

    def update_physical_experts_metadata(
        self,
        num_physical_experts: int,
        num_local_physical_experts: int,
    ) -> None:
        """Propagate a new physical-expert layout to every local MoE block.

        Entries of ``model.layers`` without an ``mlp`` attribute are skipped.
        These are presumably ``PPMissingLayer`` placeholders for layers owned
        by other pipeline-parallel ranks. Accessing ``.mlp`` on them directly
        would raise ``AttributeError``.
        """
        assert self.num_local_physical_experts == num_local_physical_experts
        self.num_physical_experts = num_physical_experts
        self.num_local_physical_experts = num_local_physical_experts
        self.num_redundant_experts = num_physical_experts - self.num_logical_experts
        for layer in self.model.layers:
            moe = getattr(layer, "mlp", None)
            if not isinstance(moe, Qwen3NextSparseMoeBlock):
                continue
            moe.n_local_physical_experts = num_local_physical_experts
            moe.n_physical_experts = num_physical_experts
            moe.n_redundant_experts = self.num_redundant_experts
            moe.experts.update_expert_map()

    def set_moe_parameters(self):
        """Collect this rank's MoE layers and record MoE hyperparameters.

        Raises:
            RuntimeError: if no decoder layer with a sparse MoE block is found.
        """
        self.expert_weights = []

        self.moe_layers = []
        example_moe = None
        for layer in self.model.layers:
            if isinstance(layer, Qwen3NextDecoderLayer) and isinstance(
                layer.mlp, Qwen3NextSparseMoeBlock
            ):
                example_moe = layer.mlp
                self.moe_layers.append(layer.mlp.experts)

        if example_moe is None:
            raise RuntimeError("No Qwen3Next layer found in the model.layers.")

        # Set MoE hyperparameters (all MoE layers share the same layout, so
        # any one of them serves as the example).
        self.num_moe_layers = len(self.moe_layers)
        self.num_expert_groups = 1
        self.num_shared_experts = 0
        self.num_logical_experts = example_moe.n_logical_experts
        self.num_physical_experts = example_moe.n_physical_experts
        self.num_local_physical_experts = example_moe.n_local_physical_experts
        self.num_routed_experts = example_moe.n_routed_experts
        self.num_redundant_experts = example_moe.n_redundant_experts


689
class Qwen3NextForCausalLM(
    nn.Module,
    HasInnerState,
    SupportsLoRA,
    SupportsPP,
    QwenNextMixtureOfExperts,
    IsHybrid,
):
    """Qwen3Next causal LM: ``Qwen3NextModel`` plus LM head and logits.

    The model is hybrid: Gated DeltaNet linear-attention layers keep
    recurrent state, and the ``get_mamba_state_*`` classmethods describe it.
    """

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": ["gate_proj", "up_proj"],
        "in_proj_qkvz": ["in_proj_qkvz"],
        "in_proj_ba": ["in_proj_ba"],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        # PyTorch requires nn.Module.__init__ to run before any attribute is
        # assigned on the module, so call it first.
        super().__init__()

        config = vllm_config.model_config.hf_text_config
        cache_config = vllm_config.cache_config

        # Reject the unsupported 'all' mamba cache mode up front.
        if cache_config.mamba_cache_mode == "all":
            raise NotImplementedError(
                "Qwen3Next currently does not support 'all' prefix caching, "
                "please use '--mamba-cache-mode=align' instead"
            )

        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.quant_config = vllm_config.quant_config
        self.config = config
        self.scheduler_config = vllm_config.scheduler_config

        self.model = Qwen3NextModel(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )

        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

        # Set MoE hyperparameters (reads self.model.layers, so it must run
        # after self.model is built).
        self.set_moe_parameters()

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings via the inner model."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ):
        """Run the decoder stack and return its output unchanged.

        Extra keyword arguments are accepted and ignored.
        """
        hidden_states = self.model(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )

        return hidden_states

    @classmethod
    def get_mamba_state_dtype_from_config(
        cls,
        vllm_config: "VllmConfig",
    ) -> tuple[torch.dtype, torch.dtype]:
        """Return the dtypes of the Gated DeltaNet recurrent state buffers."""
        return MambaStateDtypeCalculator.gated_delta_net_state_dtype(
            vllm_config.model_config.dtype,
            vllm_config.cache_config.mamba_cache_dtype,
            vllm_config.cache_config.mamba_ssm_cache_dtype,
        )

    @classmethod
    def get_mamba_state_shape_from_config(
        cls, vllm_config: "VllmConfig"
    ) -> tuple[tuple[int, int], tuple[int, int]]:
        """Return the per-rank shapes of the Gated DeltaNet state buffers.

        The speculative token count, or 0 without speculative decoding, is
        passed through to the shape calculator.
        """
        parallel_config = vllm_config.parallel_config
        hf_config = vllm_config.model_config.hf_text_config
        tp_size = parallel_config.tensor_parallel_size
        num_spec = (
            vllm_config.speculative_config.num_speculative_tokens
            if vllm_config.speculative_config
            else 0
        )
        return MambaStateShapeCalculator.gated_delta_net_state_shape(
            tp_size,
            hf_config.linear_num_key_heads,
            hf_config.linear_num_value_heads,
            hf_config.linear_key_head_dim,
            hf_config.linear_value_head_dim,
            hf_config.linear_conv_kernel_dim,
            num_spec,
        )

    @classmethod
    def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]:
        """Return the copy functions for the Gated DeltaNet state buffers."""
        return MambaStateCopyFuncCalculator.gated_delta_net_state_copy_func()

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Project hidden states to vocabulary logits."""
        return self.logits_processor(self.lm_head, hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load weights, skipping multi-token-prediction (``mtp.``) tensors."""
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=["mtp."],
        )
        return loader.load_weights(weights)

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Return the expert weight mapping of the inner model."""
        return self.model.get_expert_mapping()