# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Inference-only Qwen3Next model."""

5
from collections.abc import Iterable
6
from itertools import islice
7
8
9
10
11
12
13

import torch
from einops import rearrange
from torch import nn
from transformers.activations import ACT2FN

from vllm.compilation.decorators import support_torch_compile
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
from vllm.config import (
    CacheConfig,
    ModelConfig,
    SpeculativeConfig,
    VllmConfig,
    get_current_vllm_config,
)
from vllm.distributed import (
    divide,
    get_ep_group,
    get_pp_group,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    tensor_model_parallel_all_gather,
)
29
30
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
31
from vllm.model_executor.custom_op import CustomOp
32
from vllm.model_executor.layers.attention import Attention
33
from vllm.model_executor.layers.fla.ops import (
34
35
36
    chunk_gated_delta_rule as fla_chunk_gated_delta_rule,
)
from vllm.model_executor.layers.fla.ops import (
37
38
    fused_recurrent_gated_delta_rule,
)
39
from vllm.model_executor.layers.fla.ops.chunk import l2norm_fwd
40
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
41
42
43
44
from vllm.model_executor.layers.layernorm import (
    GemmaRMSNorm as Qwen3NextRMSNorm,
)
from vllm.model_executor.layers.layernorm import RMSNormGated
45
46
47
48
49
50
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    QKVParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
51
52
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.abstract import MambaBase
53
from vllm.model_executor.layers.mamba.mamba_mixer2 import mamba_v2_sharded_weight_loader
54
from vllm.model_executor.layers.mamba.mamba_utils import (
55
56
    MambaStateCopyFunc,
    MambaStateCopyFuncCalculator,
57
58
59
    MambaStateDtypeCalculator,
    MambaStateShapeCalculator,
)
60
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
61
62
63
    causal_conv1d_fn,
    causal_conv1d_update,
)
64
65
66
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
67
68
69
    ParallelLMHead,
    VocabParallelEmbedding,
)
70
from vllm.model_executor.model_loader.weight_utils import (
71
    default_weight_loader,
72
    maybe_remap_kv_scale_name,
73
74
    sharded_weight_loader,
)
75
from vllm.model_executor.models.qwen2_moe import Qwen2MoeMLP as Qwen3NextMLP
76
from vllm.model_executor.models.utils import sequence_parallel_chunk
77
78
79
80
81
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import Qwen3NextConfig
from vllm.triton_utils import tl, triton
82
from vllm.utils.torch_utils import direct_register_custom_op
83
from vllm.v1.attention.backend import AttentionMetadata
84
85
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata

86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
from .interfaces import (
    HasInnerState,
    IsHybrid,
    MixtureOfExperts,
    SupportsLoRA,
    SupportsPP,
)
from .utils import (
    AutoWeightsLoader,
    PPMissingLayer,
    extract_layer_index,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
102
103
104
105
106
107

# Module-level logger, named after this module.
logger = init_logger(__name__)

# Type alias for a pair of tensors (presumably the two halves of a per-layer
# cache -- TODO confirm against the code that consumes it).
KVCache = tuple[torch.Tensor, torch.Tensor]


108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
def fi_chunk_gated_delta_rule(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
    initial_state: torch.Tensor,
    output_final_state: bool,
    cu_seqlens: torch.LongTensor | None = None,
    head_first: bool = False,
    use_qk_l2norm_in_kernel: bool = True,
):
    """Run the chunked gated delta rule through FlashInfer's prefill kernel.

    The signature mirrors the FLA ``chunk_gated_delta_rule`` wrapper so the
    two can be swapped; the inputs are adapted to what the FlashInfer entry
    point is given here:

    * ``q`` / ``k`` are L2-normalized up front (via ``l2norm_fwd``) when
      ``use_qk_l2norm_in_kernel`` is set, since the flag itself is not
      forwarded to FlashInfer.
    * ``q``, ``k``, ``v``, ``g`` and ``beta`` have their leading dimension
      squeezed (only removed if it has size 1) and are made contiguous.
    * ``g`` is cast to float32 and exponentiated; ``beta`` and
      ``initial_state`` are cast to float32.

    Args:
        q, k, v: Query/key/value tensors.
        g: Gate tensor; ``exp(g)`` is what FlashInfer receives.
        beta: Beta tensor.
        initial_state: Initial recurrent state.
        output_final_state: Forwarded unchanged to FlashInfer.
        cu_seqlens: Forwarded unchanged to FlashInfer.
        head_first: Accepted for signature compatibility; not used.
        use_qk_l2norm_in_kernel: Whether to L2-normalize ``q`` and ``k``.

    Returns:
        Whatever ``flashinfer.gdn_prefill.chunk_gated_delta_rule`` returns.
    """
    # Imported lazily so FlashInfer is only required when this path runs.
    from flashinfer.gdn_prefill import (
        chunk_gated_delta_rule as chunk_gated_delta_rule_fi,
    )

    if use_qk_l2norm_in_kernel:
        q, k = l2norm_fwd(q), l2norm_fwd(k)

    # use flashinfer implementation
    q, k, v, g, beta = (
        tensor.squeeze(0).contiguous() for tensor in (q, k, v, g, beta)
    )

    state_fp32 = initial_state.to(torch.float32)
    gate_fp32 = g.to(torch.float32)
    beta_fp32 = beta.to(torch.float32)

    return chunk_gated_delta_rule_fi(
        q=q,
        k=k,
        v=v,
        g=torch.exp(gate_fp32),
        beta=beta_fp32,
        initial_state=state_fp32,
        output_final_state=output_final_state,
        cu_seqlens=cu_seqlens,
    )


@CustomOp.register("chunk_gated_delta_rule")
class ChunkGatedDeltaRule(CustomOp):
    """Custom op that picks a chunked gated-delta-rule implementation.

    On CUDA devices reporting device capability 90 the FlashInfer-backed
    wrapper is used; everywhere else the FLA implementation is used. The
    choice is made once, at construction time, by setting
    ``self._forward_method``.
    """

    def __init__(self) -> None:
        super().__init__()
        use_flashinfer = (
            current_platform.is_cuda()
            and current_platform.is_device_capability(90)
        )
        if not use_flashinfer:
            self._forward_method = self.forward_native
            return

        logger.info_once(
            "Using FlashInfer GDN prefill kernel on CUDA compute capability 90"
        )
        self._forward_method = self.forward_cuda

    def forward_cuda(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        g: torch.Tensor,
        beta: torch.Tensor,
        initial_state: torch.Tensor,
        output_final_state: bool,
        cu_seqlens: torch.LongTensor | None = None,
        head_first: bool = False,
        use_qk_l2norm_in_kernel: bool = True,
    ):
        """Delegate to ``fi_chunk_gated_delta_rule`` with all arguments as-is."""
        forwarded = dict(
            q=q,
            k=k,
            v=v,
            g=g,
            beta=beta,
            initial_state=initial_state,
            output_final_state=output_final_state,
            cu_seqlens=cu_seqlens,
            head_first=head_first,
            use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
        )
        return fi_chunk_gated_delta_rule(**forwarded)

    def forward_native(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        g: torch.Tensor,
        beta: torch.Tensor,
        initial_state: torch.Tensor,
        output_final_state: bool,
        cu_seqlens: torch.LongTensor | None = None,
        head_first: bool = False,
        use_qk_l2norm_in_kernel: bool = True,
    ):
        """Delegate to the FLA ``chunk_gated_delta_rule`` with all arguments as-is."""
        forwarded = dict(
            q=q,
            k=k,
            v=v,
            g=g,
            beta=beta,
            initial_state=initial_state,
            output_final_state=output_final_state,
            cu_seqlens=cu_seqlens,
            head_first=head_first,
            use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
        )
        return fla_chunk_gated_delta_rule(**forwarded)


215
class Qwen3NextSparseMoeBlock(nn.Module):
    """Sparse mixture-of-experts block with an optional gated shared expert.

    Holds a replicated router (``gate``), a replicated single-output
    ``shared_expert_gate``, an optional shared-expert MLP and the fused
    routed experts (``SharedFusedMoE``).
    """

    def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
        """Build the router, the shared expert (if configured) and the experts.

        Args:
            vllm_config: Source of the HF text config, parallel config and
                quantization config.
            prefix: Name prefix for the sub-layers.

        Raises:
            ValueError: If the tensor-parallel size exceeds the number of
                experts.
        """
        super().__init__()

        config = vllm_config.model_config.hf_text_config
        parallel_config = vllm_config.parallel_config
        quant_config = vllm_config.quant_config

        self.tp_size = get_tensor_model_parallel_world_size()

        self.ep_group = get_ep_group().device_group
        self.ep_rank = get_ep_group().rank_in_group
        self.ep_size = self.ep_group.size()
        self.n_routed_experts = config.num_experts

        self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe

        if self.tp_size > config.num_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {config.num_experts}."
            )

        # Load balancing settings.
        # The EPLB config is read from the *current* vLLM config (as before);
        # a separate local name avoids shadowing the constructor argument.
        current_vllm_config = get_current_vllm_config()
        eplb_config = current_vllm_config.parallel_config.eplb_config
        self.enable_eplb = parallel_config.enable_eplb

        self.n_logical_experts = self.n_routed_experts
        self.n_redundant_experts = eplb_config.num_redundant_experts
        self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts
        self.n_local_physical_experts = self.n_physical_experts // self.ep_size

        # Half-open range [start, end) of physical experts for this EP rank.
        self.physical_expert_start = self.ep_rank * self.n_local_physical_experts
        self.physical_expert_end = (
            self.physical_expert_start + self.n_local_physical_experts
        )

        self.gate = ReplicatedLinear(
            config.hidden_size,
            config.num_experts,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate",
        )

        # Single-output gate for the shared expert; built without quantization.
        self.shared_expert_gate = ReplicatedLinear(
            config.hidden_size,
            1,
            bias=False,
            quant_config=None,
            prefix=f"{prefix}.shared_expert_gate",
        )

        if config.shared_expert_intermediate_size > 0:
            self.shared_expert = Qwen3NextMLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.shared_expert_intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                reduce_results=False,
                expert_gate=self.shared_expert_gate,
                prefix=f"{prefix}.shared_expert",
            )
        else:
            self.shared_expert = None

        self.experts = SharedFusedMoE(
            shared_experts=self.shared_expert,
            gate=self.gate,
            num_experts=self.n_routed_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            reduce_results=False,
            renormalize=getattr(config, "norm_topk_prob", True),
            quant_config=quant_config,
            prefix=f"{prefix}.experts",
            enable_eplb=self.enable_eplb,
            num_redundant_experts=self.n_redundant_experts,
            is_sequence_parallel=self.is_sequence_parallel,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route tokens through the experts and return the combined output.

        Args:
            hidden_states: Input whose last dimension is the hidden size.
                A 1D tensor is treated as a single token.

        Returns:
            A tensor with the same shape as the input.
        """
        # NOTE: hidden_states can have either 1D or 2D shape.
        orig_shape = hidden_states.shape
        # Take the hidden size from the last dim and the token count from the
        # flattened view, so a 1D input does not fail on tuple unpacking.
        hidden_dim = orig_shape[-1]
        hidden_states = hidden_states.view(-1, hidden_dim)
        num_tokens = hidden_states.shape[0]

        if self.is_sequence_parallel:
            hidden_states = sequence_parallel_chunk(hidden_states)

        if self.experts.is_internal_router:
            # In this case, the gate/router runs inside the FusedMoE class
            final_hidden_states = self.experts(
                hidden_states=hidden_states, router_logits=hidden_states
            )
        else:
            # router_logits: (num_tokens, n_experts)
            router_logits, _ = self.gate(hidden_states)
            final_hidden_states = self.experts(
                hidden_states=hidden_states, router_logits=router_logits
            )

        if self.shared_expert is not None:
            # With a shared expert the experts call yields two tensors that
            # are summed here.
            final_hidden_states = final_hidden_states[0] + final_hidden_states[1]

        if self.is_sequence_parallel:
            final_hidden_states = tensor_model_parallel_all_gather(
                final_hidden_states, 0
            )
            # Keep only the first num_tokens rows of the gathered result.
            final_hidden_states = final_hidden_states[:num_tokens]
        elif self.tp_size > 1:
            final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(  # noqa E501
                final_hidden_states
            )

        return final_hidden_states.view(orig_shape)


class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
    """Gated delta-net (linear attention) layer for Qwen3Next.

    The layer projects the hidden states into q/k/v/z and b/a streams, runs a
    causal 1D convolution plus a gated delta-rule recurrence over q/k/v
    (``_forward_core``), then applies a gated RMS norm and an output
    projection. It registers itself in the compilation config's
    ``static_forward_context`` under ``prefix``.
    """

    @property
    def mamba_type(self) -> str:
        """Identifier of this layer's state/attention type."""
        return "gdn_attention"

    def get_state_dtype(self) -> tuple[torch.dtype, torch.dtype]:
        """Return the state dtypes computed from the model and cache configs."""
        return MambaStateDtypeCalculator.gated_delta_net_state_dtype(
            self.model_config.dtype, self.cache_config.mamba_cache_dtype
        )

    def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
        """Return the state shapes computed from this layer's dimensions."""
        return MambaStateShapeCalculator.gated_delta_net_state_shape(
            self.tp_size,
            self.num_k_heads,
            self.num_v_heads,
            self.head_k_dim,
            self.head_v_dim,
            self.conv_kernel_size,
            self.num_spec,
        )

    def __init__(
        self,
        config: Qwen3NextConfig,
        model_config: ModelConfig | None = None,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        speculative_config: SpeculativeConfig | None = None,
        prefix: str = "",
    ) -> None:
        """Create the projections, convolution weights, gates and norm.

        Args:
            config: Model config providing the linear-attention dimensions.
            model_config: Stored; read by ``get_state_dtype``.
            cache_config: Stored; read by ``get_state_dtype``.
            quant_config: Quantization config for the linear projections.
            speculative_config: If given, its ``num_speculative_tokens`` is
                used as ``num_spec``; otherwise ``num_spec`` is 0.
            prefix: Layer name; also the key in ``static_forward_context``.

        Raises:
            ValueError: If a layer with the same ``prefix`` is already
                registered.
        """
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()
        self.hidden_size = config.hidden_size
        self.num_v_heads = config.linear_num_value_heads
        self.num_k_heads = config.linear_num_key_heads
        self.head_k_dim = config.linear_key_head_dim
        self.head_v_dim = config.linear_value_head_dim
        self.key_dim = self.head_k_dim * self.num_k_heads
        self.value_dim = self.head_v_dim * self.num_v_heads

        self.conv_kernel_size = config.linear_conv_kernel_dim
        self.layer_idx = extract_layer_index(prefix)
        self.activation = config.hidden_act
        self.act = ACT2FN[config.hidden_act]
        self.layer_norm_epsilon = config.rms_norm_eps
        self.prefix = prefix

        self.config = config
        self.model_config = model_config
        self.cache_config = cache_config
        self.quant_config = quant_config
        self.speculative_config = speculative_config
        self.num_spec = (
            self.speculative_config.num_speculative_tokens
            if self.speculative_config
            else 0
        )

        # QKV
        self.conv_dim = self.key_dim * 2 + self.value_dim
        # The conv weights are held in a ColumnParallelLinear whose weight is
        # given an extra dim at position 1 after construction.
        self.conv1d = ColumnParallelLinear(
            input_size=self.conv_kernel_size,
            output_size=self.conv_dim,
            bias=False,
            prefix=f"{prefix}.conv1d",
        )
        self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)

        # projection of the input hidden states
        self.projection_size_qkvz = self.key_dim * 2 + self.value_dim * 2
        self.projection_size_ba = self.num_v_heads * 2
        self.in_proj_qkvz = ColumnParallelLinear(
            input_size=self.hidden_size,
            output_size=self.projection_size_qkvz,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.in_proj_qkvz",
        )
        # ba_proj doesn't support blockwise fp8 quantization.
        self.in_proj_ba = ColumnParallelLinear(
            input_size=self.hidden_size,
            output_size=self.projection_size_ba,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.in_proj_ba",
        )

        query_key_settings = (self.key_dim, 0, False)
        value_settings = (self.value_dim, 0, False)

        # Replace the default loader so the q, k and v segments of the conv
        # weight are each sharded across tensor-parallel ranks.
        delattr(self.conv1d.weight, "weight_loader")
        set_weight_attrs(
            self.conv1d.weight,
            {
                "weight_loader": mamba_v2_sharded_weight_loader(
                    [
                        query_key_settings,
                        query_key_settings,
                        value_settings,
                    ],
                    self.tp_size,
                    self.tp_rank,
                )
            },
        )

        # selective projection used to make dt, B and C input dependant

        # time step projection (discretization)
        # instantiate once and copy inv_dt in init_weights of PretrainedModel
        # Both per-head parameters use the same checked division for the
        # per-rank head count.
        num_local_v_heads = divide(self.num_v_heads, self.tp_size)
        self.dt_bias = nn.Parameter(
            torch.ones(num_local_v_heads),
        )
        self.A_log = nn.Parameter(
            torch.empty(
                num_local_v_heads,
            )
        )

        set_weight_attrs(self.A_log, {"weight_loader": sharded_weight_loader(0)})
        set_weight_attrs(self.dt_bias, {"weight_loader": sharded_weight_loader(0)})

        self.norm = RMSNormGated(
            self.head_v_dim,
            eps=self.layer_norm_epsilon,
            group_size=None,
            norm_before_gate=True,
            device=current_platform.current_device(),
            dtype=config.dtype,
        )

        self.out_proj = RowParallelLinear(
            self.value_dim,
            self.hidden_size,
            bias=False,
            input_is_parallel=True,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
        )

        self.chunk_gated_delta_rule = ChunkGatedDeltaRule()

        # Register under `prefix` so the layer can be found by name through
        # the forward context (see `_forward_core`, which indexes the
        # attention metadata by `self.prefix`).
        compilation_config = get_current_vllm_config().compilation_config
        if prefix in compilation_config.static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        compilation_config.static_forward_context[prefix] = self

    def fix_query_key_value_ordering(
        self,
        mixed_qkvz,
        mixed_ba,
    ):
        """
        Derives `query`, `key`, `value`, `z`, `b` and `a` tensors from
        `mixed_qkvz` and `mixed_ba`.

        Both inputs are viewed per local key-head group, split along the last
        axis, and the value-side tensors are reshaped to one row per value
        head. The splits use ``dim=2``, so the inputs are expected to be 2D
        (tokens, features).
        """
        v_heads_per_k_head = self.num_v_heads // self.num_k_heads

        new_tensor_shape_qkvz = mixed_qkvz.size()[:-1] + (
            self.num_k_heads // self.tp_size,
            (
                self.head_k_dim
                + self.head_k_dim
                + (self.head_v_dim + self.head_v_dim)
                * self.num_v_heads
                // self.num_k_heads
            ),
        )
        # Use mixed_ba's own leading dims for the tensor being reshaped.
        new_tensor_shape_ba = mixed_ba.size()[:-1] + (
            self.num_k_heads // self.tp_size,
            2 * self.num_v_heads // self.num_k_heads,
        )

        mixed_qkvz = mixed_qkvz.view(*new_tensor_shape_qkvz)
        mixed_ba = mixed_ba.view(*new_tensor_shape_ba)

        split_arg_list_qkvz = [
            self.head_k_dim,
            self.head_k_dim,
            (v_heads_per_k_head * self.head_v_dim),
            (v_heads_per_k_head * self.head_v_dim),
        ]
        split_arg_list_ba = [
            v_heads_per_k_head,
            v_heads_per_k_head,
        ]

        # [l, ng, (hn + hn + np/ng * hn + np/ng * hn)]
        # --> [l, ng, hn], [l, ng, hn], [l, ng, np/ng * hn], [l, ng, np/ng * hn]
        # and [l, ng, (np/ng + np/ng)] --> [l, ng, np/ng], [l, ng, np/ng]
        (query, key, value, z) = torch.split(mixed_qkvz, split_arg_list_qkvz, dim=2)
        (b, a) = torch.split(mixed_ba, split_arg_list_ba, dim=2)

        # [l, ng, np/ng * hn] -> [l, np, hn]
        value = value.reshape(value.size(0), -1, self.head_v_dim)
        z = z.reshape(z.size(0), -1, self.head_v_dim)
        b = b.reshape(b.size(0), self.num_v_heads // self.tp_size)
        a = a.reshape(a.size(0), self.num_v_heads // self.tp_size)

        return query, key, value, z, b, a

    def rearrange_mixed_qkv(self, mixed_qkv):
        """Split a packed ``(l, q|k|v)`` tensor into per-head q, k and v.

        Returns three contiguous tensors shaped ``(1, l, heads, head_dim)``,
        or ``(None, None, None)`` when ``mixed_qkv`` is ``None``.
        """
        if mixed_qkv is None:
            return None, None, None
        query, key, value = torch.split(
            mixed_qkv,
            [
                self.key_dim // self.tp_size,
                self.key_dim // self.tp_size,
                self.value_dim // self.tp_size,
            ],
            dim=-1,
        )
        query, key = map(
            lambda x: rearrange(x, "l (h d) -> 1 l h d", d=self.head_k_dim),
            (query, key),
        )
        value = rearrange(value, "l (h d) -> 1 l h d", d=self.head_v_dim)
        return query.contiguous(), key.contiguous(), value.contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        output: torch.Tensor,
    ):
        """
        Forward pass with three parts:
        1. Input projection
        2. Core attention (custom op)
        3. Output projection

        The result is written into ``output[:num_tokens]``; nothing is
        returned.
        """
        num_tokens = hidden_states.size(0)

        # ============================================================
        # Part 1: Input Projection
        # ============================================================
        projected_states_qkvz, _ = self.in_proj_qkvz(hidden_states)
        projected_states_ba, _ = self.in_proj_ba(hidden_states)
        query, key, value, z, b, a = self.fix_query_key_value_ordering(
            projected_states_qkvz, projected_states_ba
        )
        # Flatten the head dims and pack q, k, v side by side for the conv.
        query, key, value = map(
            lambda x: rearrange(x, "l p d -> l (p d)"), (query, key, value)
        )
        mixed_qkv = torch.cat((query, key, value), dim=-1)

        # ============================================================
        # Part 2: Core Attention (Custom Op)
        # ============================================================
        # Note: we should not use torch.empty here like other attention backends,
        # see discussions in https://github.com/vllm-project/vllm/pull/28182
        core_attn_out = torch.zeros(
            (num_tokens, self.num_v_heads // self.tp_size, self.head_v_dim),
            dtype=hidden_states.dtype,
            device=hidden_states.device,
        )

        # The op fills `core_attn_out` in place; the layer is identified by
        # its prefix.
        torch.ops.vllm.gdn_attention_core(
            mixed_qkv,
            b,
            a,
            core_attn_out,
            self.prefix,
        )

        # ============================================================
        # Part 3: Output Projection
        # ============================================================
        z_shape_og = z.shape
        # Reshape input data into 2D tensor
        core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1])
        z = z.reshape(-1, z.shape[-1])
        core_attn_out = self.norm(core_attn_out, z)
        core_attn_out = core_attn_out.reshape(z_shape_og)
        core_attn_out = rearrange(core_attn_out, "... h d -> ... (h d)")
        output[:num_tokens], _ = self.out_proj(core_attn_out)

    def _forward_core(
        self,
        mixed_qkv: torch.Tensor,
        b: torch.Tensor,
        a: torch.Tensor,
        core_attn_out: torch.Tensor,
    ):
        """
        Core attention computation (called by custom op).

        Reads this layer's metadata and cache from the forward context, runs
        the causal convolution and the gated delta-rule recurrence separately
        for speculative ("spec") and non-speculative tokens, and writes the
        merged result into ``core_attn_out[:num_actual_tokens]``. Returns
        without touching ``core_attn_out`` when no attention metadata is set.
        """
        forward_context = get_forward_context()
        attn_metadata: AttentionMetadata = forward_context.attn_metadata

        if attn_metadata is None:
            # V1 profile run
            return

        assert isinstance(attn_metadata, dict)
        attn_metadata = attn_metadata[self.prefix]
        assert isinstance(attn_metadata, GDNAttentionMetadata)
        has_initial_state = attn_metadata.has_initial_state
        spec_query_start_loc = attn_metadata.spec_query_start_loc
        non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc
        spec_sequence_masks = attn_metadata.spec_sequence_masks
        spec_token_indx = attn_metadata.spec_token_indx
        non_spec_token_indx = attn_metadata.non_spec_token_indx
        spec_state_indices_tensor = attn_metadata.spec_state_indices_tensor  # noqa: E501
        non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor  # noqa: E501
        self_kv_cache = self.kv_cache[forward_context.virtual_engine]
        # Cache entry 0 is the conv state (used with its last two dims
        # swapped); entry 1 is the recurrent (ssm) state.
        conv_state = self_kv_cache[0].transpose(-1, -2)
        ssm_state = self_kv_cache[1]
        num_actual_tokens = attn_metadata.num_actual_tokens
        num_accepted_tokens = attn_metadata.num_accepted_tokens

        # Only the first num_actual_tokens rows are processed.
        mixed_qkv = mixed_qkv[:num_actual_tokens]
        b = b[:num_actual_tokens]
        a = a[:num_actual_tokens]

        # 1. Convolution sequence transformation
        # Drop the dim added by unsqueeze(1) in __init__.
        conv_weights = self.conv1d.weight.view(
            self.conv1d.weight.size(0), self.conv1d.weight.size(2)
        )

        if spec_sequence_masks is not None:
            if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0:
                # Every token is a spec token; no gather needed.
                mixed_qkv_spec = mixed_qkv
                mixed_qkv_non_spec = None
            else:
                mixed_qkv_spec = mixed_qkv.index_select(0, spec_token_indx)
                mixed_qkv_non_spec = mixed_qkv.index_select(0, non_spec_token_indx)
        else:
            mixed_qkv_spec = None
            mixed_qkv_non_spec = mixed_qkv

        # 1.1: Process the multi-query part
        if spec_sequence_masks is not None:
            mixed_qkv_spec = causal_conv1d_update(
                mixed_qkv_spec,
                conv_state,
                conv_weights,
                self.conv1d.bias,
                self.activation,
                conv_state_indices=spec_state_indices_tensor[:, 0][
                    : attn_metadata.num_spec_decodes
                ],
                num_accepted_tokens=num_accepted_tokens,
                query_start_loc=spec_query_start_loc,
                max_query_len=spec_state_indices_tensor.size(-1),
                validate_data=False,
            )

        # 1.2: Process the remaining part
        if attn_metadata.num_prefills > 0:
            mixed_qkv_non_spec_T = mixed_qkv_non_spec.transpose(0, 1)
            # - "cache_indices" updates the conv_state cache in positions
            #   pointed to by "state_indices_tensor"
            mixed_qkv_non_spec = causal_conv1d_fn(
                mixed_qkv_non_spec_T,
                conv_weights,
                self.conv1d.bias,
                activation=self.activation,
                conv_states=conv_state,
                has_initial_state=has_initial_state,
                cache_indices=non_spec_state_indices_tensor,
                query_start_loc=non_spec_query_start_loc,
                metadata=attn_metadata,
            ).transpose(0, 1)
        elif attn_metadata.num_decodes > 0:
            mixed_qkv_non_spec = causal_conv1d_update(
                mixed_qkv_non_spec,
                conv_state,
                conv_weights,
                self.conv1d.bias,
                self.activation,
                conv_state_indices=non_spec_state_indices_tensor[
                    : attn_metadata.num_actual_tokens
                ],
                validate_data=True,
            )
        else:
            mixed_qkv_non_spec = None

        query_spec, key_spec, value_spec = self.rearrange_mixed_qkv(mixed_qkv_spec)
        query_non_spec, key_non_spec, value_non_spec = self.rearrange_mixed_qkv(
            mixed_qkv_non_spec
        )

        g, beta = fused_gdn_gating(self.A_log, a, b, self.dt_bias)

        # g and beta are indexed along dim 1 below (tokens on axis 1).
        if spec_sequence_masks is not None:
            if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0:
                g_spec = g
                beta_spec = beta
                g_non_spec = None
                beta_non_spec = None
            else:
                g_spec = g.index_select(1, spec_token_indx)
                beta_spec = beta.index_select(1, spec_token_indx)
                g_non_spec = g.index_select(1, non_spec_token_indx)
                beta_non_spec = beta.index_select(1, non_spec_token_indx)
        else:
            g_spec = None
            beta_spec = None
            g_non_spec = g
            beta_non_spec = beta

        # 2. Recurrent attention

        # 2.1: Process the multi-query part
        if spec_sequence_masks is not None:
            core_attn_out_spec, last_recurrent_state = fused_recurrent_gated_delta_rule(
                q=query_spec,
                k=key_spec,
                v=value_spec,
                g=g_spec,
                beta=beta_spec,
                initial_state=ssm_state,
                inplace_final_state=True,
                cu_seqlens=spec_query_start_loc[: attn_metadata.num_spec_decodes + 1],
                ssm_state_indices=spec_state_indices_tensor,
                num_accepted_tokens=num_accepted_tokens,
                use_qk_l2norm_in_kernel=True,
            )
        else:
            core_attn_out_spec, last_recurrent_state = None, None

        # 2.2: Process the remaining part
        if attn_metadata.num_prefills > 0:
            initial_state = ssm_state[non_spec_state_indices_tensor].contiguous()
            # Sequences without an initial state start from zeros.
            initial_state[~has_initial_state, ...] = 0
            (
                core_attn_out_non_spec,
                last_recurrent_state,
            ) = self.chunk_gated_delta_rule(
                q=query_non_spec,
                k=key_non_spec,
                v=value_non_spec,
                g=g_non_spec,
                beta=beta_non_spec,
                initial_state=initial_state,
                output_final_state=True,
                cu_seqlens=non_spec_query_start_loc,
                head_first=False,
                use_qk_l2norm_in_kernel=True,
            )
            # Init cache
            ssm_state[non_spec_state_indices_tensor] = last_recurrent_state.to(
                ssm_state.dtype
            )
        elif attn_metadata.num_decodes > 0:
            core_attn_out_non_spec, last_recurrent_state = (
                fused_recurrent_gated_delta_rule(
                    q=query_non_spec,
                    k=key_non_spec,
                    v=value_non_spec,
                    g=g_non_spec,
                    beta=beta_non_spec,
                    initial_state=ssm_state,
                    inplace_final_state=True,
                    cu_seqlens=non_spec_query_start_loc[
                        : attn_metadata.num_decodes + 1
                    ],
                    ssm_state_indices=non_spec_state_indices_tensor,
                    use_qk_l2norm_in_kernel=True,
                )
            )
        else:
            core_attn_out_non_spec, last_recurrent_state = None, None

        # 3. Merge core attention output
        if spec_sequence_masks is not None and core_attn_out_non_spec is not None:
            # Scatter both partial outputs back to their token positions.
            merged_out = torch.empty(
                (1, num_actual_tokens, *core_attn_out_spec.shape[2:]),
                dtype=core_attn_out_non_spec.dtype,
                device=core_attn_out_non_spec.device,
            )
            merged_out.index_copy_(1, spec_token_indx, core_attn_out_spec)
            merged_out.index_copy_(1, non_spec_token_indx, core_attn_out_non_spec)
            core_attn_out[:num_actual_tokens] = merged_out.squeeze(0)
        elif spec_sequence_masks is not None:
            core_attn_out[:num_actual_tokens] = core_attn_out_spec.squeeze(0)
        else:
            core_attn_out[:num_actual_tokens] = core_attn_out_non_spec.squeeze(0)
813
814
815
816
817
818


class Qwen3NextAttention(nn.Module):
    """Full-attention block of Qwen3Next with per-head Q/K norm and output gate.

    The fused QKV projection emits twice as many query channels when
    ``attn_output_gate`` is enabled; the extra half is split off per head and
    used as a sigmoid gate on the attention output before ``o_proj``.
    """

    def __init__(
        self,
        config: Qwen3NextConfig,
        model_config: ModelConfig | None = None,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        """Build the projections, rotary embedding, attention op and norms.

        Args:
            config: HF model config providing head counts, sizes and RoPE setup.
            model_config: Accepted for interface parity; not read here.
            cache_config: Forwarded to the ``Attention`` layer.
            quant_config: Forwarded to the linear layers and ``Attention``.
            prefix: Dotted module path used to name the sub-layers.
        """
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = config.num_attention_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = config.num_key_value_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        # The fallback must divide by the *global* head count: dividing by the
        # per-rank count would inflate head_dim by a factor of tp_size.
        self.head_dim = config.head_dim or (
            self.hidden_size // self.total_num_heads
        )
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.dual_chunk_attention_config = getattr(
            config, "dual_chunk_attention_config", None
        )
        self.attn_output_gate = getattr(config, "attn_output_gate", True)

        self.qkv_proj = QKVParallelLinear(
            config.hidden_size,
            self.head_dim,
            # bool arithmetic: doubles the query heads when the gate is on.
            self.total_num_heads * (1 + self.attn_output_gate),
            self.total_num_kv_heads,
            bias=getattr(config, "qkv_bias", False),
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            config.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self.rotary_emb = get_rope(
            head_size=self.head_dim,
            max_position=config.max_position_embeddings,
            rope_parameters=config.rope_parameters,
            dual_chunk_attention_config=self.dual_chunk_attention_config,
        )

        # The dual-chunk keyword arguments are only passed when the feature
        # is configured; otherwise Attention gets no extra arguments.
        if self.dual_chunk_attention_config:
            extra_attn_kwargs = {
                "layer_idx": extract_layer_index(prefix),
                "dual_chunk_attention_config": self.dual_chunk_attention_config,
            }
        else:
            extra_attn_kwargs = {}

        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
            **extra_attn_kwargs,
        )

        self.q_norm = Qwen3NextRMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = Qwen3NextRMSNorm(self.head_dim, eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        output: torch.Tensor,
        hidden_states: torch.Tensor,
    ):
        """Run attention on ``hidden_states`` and write the result into ``output``.

        Nothing is returned; the projected attention output is stored in-place
        via ``output[:]``.
        """
        qkv, _ = self.qkv_proj(hidden_states)

        if self.attn_output_gate:
            q_gate, k, v = qkv.split(
                [self.q_size * 2, self.kv_size, self.kv_size], dim=-1
            )
            # Within each head the channels are laid out as [query | gate],
            # so split per head and then flatten the heads back together.
            orig_shape = q_gate.shape[:-1]
            q_gate = q_gate.view(*orig_shape, self.num_heads, -1)
            q, gate = torch.chunk(q_gate, 2, dim=-1)
            q = q.reshape(*orig_shape, -1)
            gate = gate.reshape(*orig_shape, -1)
        else:
            q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

        # RMSNorm is applied per head, hence the temporary 3-D view.
        q = self.q_norm(q.view(-1, self.num_heads, self.head_dim)).view(
            -1, self.num_heads * self.head_dim
        )
        k = self.k_norm(k.view(-1, self.num_kv_heads, self.head_dim)).view(
            -1, self.num_kv_heads * self.head_dim
        )

        q, k = self.rotary_emb(positions, q, k)

        attn_output = self.attn(q, k, v)

        if self.attn_output_gate:
            gate = torch.sigmoid(gate)
            attn_output = attn_output * gate

        output[:], _ = self.o_proj(attn_output)


class Qwen3NextDecoderLayer(nn.Module):
    """One Qwen3Next decoder layer: token mixer, then MLP, with pre-norms.

    The token mixer is a gated delta net (``"linear_attention"``) or a
    standard attention block (``"full_attention"``); the MLP is dense or a
    sparse MoE block depending on the config and the layer index.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        layer_type: str,
        prefix: str = "",
    ) -> None:
        """Create the mixer, MLP, norms and optional layer-scale parameters.

        Args:
            vllm_config: Engine config; model, cache, quantization and
                speculative settings are read from it.
            layer_type: ``"linear_attention"`` or ``"full_attention"``.
            prefix: Dotted module path; the layer index is parsed from it.

        Raises:
            ValueError: If ``layer_type`` is not one of the two supported kinds.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        model_config = vllm_config.model_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        speculative_config = vllm_config.speculative_config

        self.layer_type = layer_type
        self.layer_idx = extract_layer_index(prefix)

        if self.layer_type == "linear_attention":
            self.linear_attn = Qwen3NextGatedDeltaNet(
                config,
                model_config=model_config,
                cache_config=cache_config,
                quant_config=quant_config,
                speculative_config=speculative_config,
                prefix=f"{prefix}.linear_attn",
            )
        elif self.layer_type == "full_attention":
            self.self_attn = Qwen3NextAttention(
                config,
                model_config=model_config,
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=f"{prefix}.self_attn",
            )
        else:
            raise ValueError(f"Invalid layer_type {self.layer_type}")

        # Layers listed in `mlp_only_layers` always use the dense MLP; the
        # remaining ones become MoE every `decoder_sparse_step` layers.
        mlp_only_layers = getattr(config, "mlp_only_layers", [])
        if (self.layer_idx not in mlp_only_layers) and (
            config.num_experts > 0
            and (self.layer_idx + 1) % config.decoder_sparse_step == 0
        ):
            self.mlp = Qwen3NextSparseMoeBlock(
                vllm_config=vllm_config,
                prefix=f"{prefix}.mlp",
            )
        else:
            self.mlp = Qwen3NextMLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                prefix=f"{prefix}.mlp",
            )

        self.input_layernorm = Qwen3NextRMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )
        self.post_attention_layernorm = Qwen3NextRMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

        # Optional learned per-channel scales, stored as an offset from 1:
        # the forward pass multiplies by `scale + 1`, so zero init is identity.
        self.layer_scale = getattr(config, "layer_scale", False)
        if self.layer_scale:
            self.attn_layer_scale = torch.nn.Parameter(
                torch.zeros(
                    1,
                    1,
                    config.hidden_size,
                    dtype=config.dtype,
                ),
            )
            self.ffn_layer_scale = torch.nn.Parameter(
                torch.zeros(
                    1,
                    1,
                    config.hidden_size,
                    dtype=config.dtype,
                ),
            )

    def forward(
        self,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
        positions: torch.Tensor | None = None,
        **kwargs: object,
    ):
        """Apply the layer and return ``(hidden_states, residual)``.

        Args:
            hidden_states: Layer input.
            residual: Running residual, or ``None`` for the first layer, in
                which case the input itself becomes the residual.
            positions: Token positions; only forwarded to full attention.
            **kwargs: Accepted and ignored.

        Raises:
            ValueError: If ``self.layer_type`` is not a supported kind.
        """
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            # With a residual the norm returns the updated residual as well.
            hidden_states, residual = self.input_layernorm(hidden_states, residual)

        # Both mixers write their result into a caller-provided buffer.
        self_attention_output = torch.empty_like(hidden_states)
        if self.layer_type == "linear_attention":
            self.linear_attn(
                hidden_states=hidden_states,
                output=self_attention_output,
            )
        elif self.layer_type == "full_attention":
            self.self_attn(
                hidden_states=hidden_states,
                output=self_attention_output,
                positions=positions,
            )
        else:
            raise ValueError("Invalid layer_type")
        hidden_states = self_attention_output

        if self.layer_scale:
            if len(hidden_states.shape) == 2:
                # 2-D activations: drop the leading axis of the (1, 1, H) scale.
                hidden_states = hidden_states * (
                    self.attn_layer_scale.to(hidden_states.dtype)[0] + 1
                )
            else:
                hidden_states = hidden_states * (
                    self.attn_layer_scale.to(hidden_states.dtype) + 1
                )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)

        if self.layer_scale:
            if len(hidden_states.shape) == 2:
                hidden_states = hidden_states * (
                    self.ffn_layer_scale.to(hidden_states.dtype)[0] + 1
                )
            else:
                assert len(hidden_states.shape) == len(self.ffn_layer_scale.shape), (
                    f"shape must be the same {len(hidden_states.shape)}, "
                    f"{len(self.ffn_layer_scale.shape)}"
                )
                hidden_states = hidden_states * (
                    self.ffn_layer_scale.to(hidden_states.dtype) + 1
                )

        return hidden_states, residual


@support_torch_compile
class Qwen3NextModel(nn.Module):
    """Qwen3Next decoder stack: embedding, decoder layers and final norm."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the embedding, this rank's decoder layers and the final norm.

        Args:
            vllm_config: Engine config; the HF text config and the parallel
                (EPLB) settings are read from it.
            prefix: Dotted module path used to name the decoder layers.
        """
        super().__init__()

        config: Qwen3NextConfig = vllm_config.model_config.hf_text_config
        parallel_config = vllm_config.parallel_config

        eplb_config = parallel_config.eplb_config
        self.num_redundant_experts = eplb_config.num_redundant_experts

        self.config = config

        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
        )

        def get_layer(prefix: str):
            # The mixer kind of each layer is looked up by its index.
            return Qwen3NextDecoderLayer(
                vllm_config,
                layer_type=config.layer_types[extract_layer_index(prefix)],
                prefix=prefix,
            )

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers"
        )
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size
        )

        # Only the last pipeline rank applies the final norm.
        if get_pp_group().is_last_rank:
            self.norm = Qwen3NextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer()

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up the token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the decoder layers owned by this pipeline rank.

        Args:
            input_ids: Token ids; used on the first rank when
                ``inputs_embeds`` is not given.
            positions: Token positions forwarded to every layer.
            intermediate_tensors: Hidden state and residual from the previous
                rank; required on every rank but the first.
            inputs_embeds: Precomputed embeddings that bypass the lookup.

        Returns:
            The normalized hidden states on the last rank, otherwise an
            ``IntermediateTensors`` carrying ``hidden_states`` and ``residual``.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.embed_input_ids(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        for layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states, residual = layer(
                positions=positions,
                hidden_states=hidden_states,
                residual=residual,
            )

        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Return the checkpoint-to-parameter mapping for the MoE experts."""
        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        return SharedFusedMoE.make_expert_params_mapping(
            self,
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=getattr(self.config, "num_experts", 0),
            num_redundant_experts=self.num_redundant_experts,
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Each tensor is tried against the stacked (fused) projections first,
        then the per-expert mapping, and finally loaded under its own name.
        A remapped name is kept in a local until the load succeeds, so a
        rejected candidate never alters the name seen by later attempts.

        Args:
            weights: ``(checkpoint_name, tensor)`` pairs.

        Returns:
            The parameter names that were actually loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        expert_params_mapping = self.get_expert_mapping()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            if name.startswith("mtp."):
                continue

            # Remapping the name of FP8 kv-scale.
            if name.endswith("scale"):
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue

                # Expert weights are handled by the expert mapping below.
                if "mlp.experts" in name:
                    continue

                mapped_name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if mapped_name.endswith(".bias") and mapped_name not in params_dict:
                    continue
                # Skip layers on other devices.
                if is_pp_missing_parameter(mapped_name, self):
                    continue
                if mapped_name not in params_dict:
                    continue
                param = params_dict[mapped_name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                name = mapped_name
                break
            else:
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
                    mapped_name = name.replace(weight_name, param_name)
                    # Skip layers on other devices.
                    if is_pp_missing_parameter(mapped_name, self):
                        continue
                    # Skip loading extra bias for GPTQ models.
                    if (
                        mapped_name.endswith((".bias", "_bias"))
                        and mapped_name not in params_dict
                    ):
                        continue
                    if mapped_name not in params_dict:
                        continue
                    param = params_dict[mapped_name]
                    weight_loader = param.weight_loader
                    weight_loader(
                        param,
                        loaded_weight,
                        mapped_name,
                        shard_id=shard_id,
                        expert_id=expert_id,
                    )
                    name = mapped_name
                    break
                else:
                    # Neither mapping applied: load under the original name.
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    if is_pp_missing_parameter(name, self):
                        continue
                    if name not in params_dict:
                        logger.warning_once(
                            f"Parameter {name} not found in params_dict, skip loading"
                        )
                        continue
                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
class QwenNextMixtureOfExperts(MixtureOfExperts):
    """MoE bookkeeping shared by Qwen3Next model classes.

    Expects the concrete class to provide ``self.model.layers``.
    """

    def update_physical_experts_metadata(
        self,
        num_physical_experts: int,
        num_local_physical_experts: int,
    ) -> None:
        """Propagate new physical-expert counts to every MoE block.

        The local expert count must stay unchanged (asserted); the redundant
        count is recomputed as physical minus logical experts.
        """
        assert self.num_local_physical_experts == num_local_physical_experts
        self.num_physical_experts = num_physical_experts
        self.num_local_physical_experts = num_local_physical_experts
        self.num_redundant_experts = num_physical_experts - self.num_logical_experts
        for layer in self.model.layers:
            # Same guard as set_moe_parameters: only real decoder layers are
            # inspected, so entries without an `mlp` attribute are skipped
            # instead of raising AttributeError.
            if isinstance(layer, Qwen3NextDecoderLayer) and isinstance(
                layer.mlp, Qwen3NextSparseMoeBlock
            ):
                moe = layer.mlp
                moe.n_local_physical_experts = num_local_physical_experts
                moe.n_physical_experts = num_physical_experts
                moe.n_redundant_experts = self.num_redundant_experts
                moe.experts.update_expert_map()

    def set_moe_parameters(self):
        """Collect the MoE layers and copy their hyperparameters onto ``self``.

        Raises:
            RuntimeError: If no decoder layer with a sparse MoE block exists.
        """
        self.expert_weights = []

        self.moe_layers = []
        example_moe = None
        for layer in self.model.layers:
            if isinstance(layer, Qwen3NextDecoderLayer) and isinstance(
                layer.mlp, Qwen3NextSparseMoeBlock
            ):
                # The last MoE block found supplies the shared hyperparameters.
                example_moe = layer.mlp
                self.moe_layers.append(layer.mlp.experts)

        if example_moe is None:
            raise RuntimeError("No Qwen3Next layer found in the model.layers.")

        # Set MoE hyperparameters
        self.num_moe_layers = len(self.moe_layers)
        self.num_expert_groups = 1
        self.num_shared_experts = 0
        self.num_logical_experts = example_moe.n_logical_experts
        self.num_physical_experts = example_moe.n_physical_experts
        self.num_local_physical_experts = example_moe.n_local_physical_experts
        self.num_routed_experts = example_moe.n_routed_experts
        self.num_redundant_experts = example_moe.n_redundant_experts


1299
class Qwen3NextForCausalLM(
    nn.Module,
    HasInnerState,
    SupportsLoRA,
    SupportsPP,
    QwenNextMixtureOfExperts,
    IsHybrid,
):
    """Qwen3Next decoder with a language-model head."""

    # Fused parameter name -> checkpoint sub-module names packed into it.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the decoder, LM head and logits processor.

        Raises:
            NotImplementedError: If ``mamba_cache_mode`` is ``"all"``.
        """
        config = vllm_config.model_config.hf_text_config
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        cache_config = vllm_config.cache_config

        scheduler_config = vllm_config.scheduler_config
        if cache_config.mamba_cache_mode == "all":
            raise NotImplementedError(
                "Qwen3Next currently does not support 'all' prefix caching, "
                "please use '--mamba-cache-mode=align' instead"
            )
        self.quant_config = vllm_config.quant_config

        super().__init__()
        self.config = config
        self.scheduler_config = scheduler_config
        self.model = Qwen3NextModel(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )

        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

        # Set MoE hyperparameters
        self.set_moe_parameters()

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate the embedding lookup to the inner model."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ):
        """Run the inner model and return its output unchanged.

        Extra keyword arguments are accepted and ignored.
        """
        hidden_states = self.model(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )

        return hidden_states

    @classmethod
    def get_mamba_state_dtype_from_config(
        cls,
        vllm_config: "VllmConfig",
    ) -> tuple[torch.dtype, torch.dtype]:
        """Return the gated-delta-net state dtypes for this configuration."""
        return MambaStateDtypeCalculator.gated_delta_net_state_dtype(
            vllm_config.model_config.dtype, vllm_config.cache_config.mamba_cache_dtype
        )

    @classmethod
    def get_mamba_state_shape_from_config(
        cls, vllm_config: "VllmConfig"
    ) -> tuple[tuple[int, int], tuple[int, int]]:
        """Return the gated-delta-net state shapes for this configuration.

        The number of speculative tokens is taken as 0 when speculative
        decoding is not configured.
        """
        parallel_config = vllm_config.parallel_config
        hf_config = vllm_config.model_config.hf_text_config
        tp_size = parallel_config.tensor_parallel_size
        num_spec = (
            vllm_config.speculative_config.num_speculative_tokens
            if vllm_config.speculative_config
            else 0
        )
        return MambaStateShapeCalculator.gated_delta_net_state_shape(
            tp_size,
            hf_config.linear_num_key_heads,
            hf_config.linear_num_value_heads,
            hf_config.linear_key_head_dim,
            hf_config.linear_value_head_dim,
            hf_config.linear_conv_kernel_dim,
            num_spec,
        )

    @classmethod
    def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]:
        """Return the state-copy functions for the gated delta net states."""
        return MambaStateCopyFuncCalculator.gated_delta_net_state_copy_func()

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Project ``hidden_states`` to vocabulary logits through the LM head."""
        return self.logits_processor(self.lm_head, hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, skipping every ``mtp.``-prefixed tensor."""
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=["mtp."],
        )
        return loader.load_weights(weights)

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Return the expert parameter mapping of the inner model."""
        return self.model.get_expert_mapping()


1419
1420
1421
1422
1423
def gdn_attention_core(
    mixed_qkv: torch.Tensor,
    b: torch.Tensor,
    a: torch.Tensor,
    core_attn_out: torch.Tensor,
    layer_name: str,
) -> None:
    """
    Custom op for the core attention computation.
    Only handles the convolution + recurrent attention part.
    Input/output projections are handled outside this op.

    The layer is resolved by ``layer_name`` from the current forward
    context and its ``_forward_core`` is invoked with the given tensors.
    Nothing is returned; ``core_attn_out`` is the output buffer.
    """
    forward_context: ForwardContext = get_forward_context()
    # The op receives only a name, so the layer instance is looked up here.
    layer = forward_context.no_compile_layers[layer_name]
    layer._forward_core(
        mixed_qkv=mixed_qkv,
        b=b,
        a=a,
        core_attn_out=core_attn_out,
    )


def gdn_attention_core_fake(
    mixed_qkv: torch.Tensor,
    b: torch.Tensor,
    a: torch.Tensor,
    core_attn_out: torch.Tensor,
    layer_name: str,
) -> None:
    """Fake implementation for torch.compile.

    Mirrors the signature of ``gdn_attention_core`` but does no work and
    leaves every argument untouched.
    """
    return


# Register the GDN core as a custom op. `core_attn_out` is declared as
# mutated because the op returns None and that buffer carries the result.
direct_register_custom_op(
    op_name="gdn_attention_core",
    op_func=gdn_attention_core,
    mutates_args=["core_attn_out"],
    fake_impl=gdn_attention_core_fake,
)


@triton.jit
def fused_gdn_gating_kernel(
    g,
    beta_output,
    A_log,
    a,
    b,
    dt_bias,
    seq_len,
    NUM_HEADS: tl.constexpr,
    beta: tl.constexpr,
    threshold: tl.constexpr,
    BLK_HEADS: tl.constexpr,
):
    # Computes, for one block of BLK_HEADS heads of one (batch, seq) position:
    #   g           = -exp(A_log) * softplus(a + dt_bias)
    #   beta_output = sigmoid(b)
    # `a`, `b`, `g` and `beta_output` are addressed with the same flat offset;
    # `A_log` and `dt_bias` are indexed per head.
    i_b, i_s, i_d = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    head_off = i_d * BLK_HEADS + tl.arange(0, BLK_HEADS)
    off = i_b * seq_len * NUM_HEADS + i_s * NUM_HEADS + head_off
    # Mask off the lanes past the last head in the final block.
    mask = head_off < NUM_HEADS
    blk_A_log = tl.load(A_log + head_off, mask=mask)
    blk_a = tl.load(a + off, mask=mask)
    blk_b = tl.load(b + off, mask=mask)
    blk_bias = tl.load(dt_bias + head_off, mask=mask)
    # If the model is loaded in fp16, without the .float() here, A might be -inf
    x = blk_a.to(tl.float32) + blk_bias.to(tl.float32)
    # Softplus with a linear branch once beta * x exceeds the threshold.
    softplus_x = tl.where(
        beta * x <= threshold, (1 / beta) * tl.log(1 + tl.exp(beta * x)), x
    )
    blk_g = -tl.exp(blk_A_log.to(tl.float32)) * softplus_x
    tl.store(g + off, blk_g.to(g.dtype.element_ty), mask=mask)
    # compute beta_output = sigmoid(b)
    blk_beta_output = tl.sigmoid(blk_b.to(tl.float32))
    tl.store(
        beta_output + off, blk_beta_output.to(beta_output.dtype.element_ty), mask=mask
    )
1494
1495
1496
1497
1498


def fused_gdn_gating(
    A_log: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    dt_bias: torch.Tensor,
    beta: float = 1.0,
    threshold: float = 20.0,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Fused computation of g and beta for Gated Delta Net.
    g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
    beta_output = b.sigmoid()
    TODO maybe use torch.compile to replace this triton kernel

    Args:
        A_log: Per-head log-decay parameter.
        a: 2-D gate input of shape ``(batch, num_heads)``.
        b: Input to the sigmoid; addressed with the same offsets as ``a``.
        dt_bias: Per-head bias added to ``a`` before the softplus.
        beta: Softplus sharpness.
        threshold: Value of ``beta * x`` above which softplus is linear.

    Returns:
        ``(g, beta_output)``, both shaped ``(1, batch, num_heads)``; ``g`` is
        float32 and ``beta_output`` keeps the dtype of ``b``.
    """
    batch, num_heads = a.shape
    seq_len = 1
    # Heads handled per program; sizes the grid and the kernel block alike.
    blk_heads = 8
    grid = (batch, seq_len, triton.cdiv(num_heads, blk_heads))
    g = torch.empty(1, batch, num_heads, dtype=torch.float32, device=a.device)
    beta_output = torch.empty(1, batch, num_heads, dtype=b.dtype, device=b.device)
    fused_gdn_gating_kernel[grid](
        g,
        beta_output,
        A_log,
        a,
        b,
        dt_bias,
        seq_len,
        num_heads,
        beta,
        threshold,
        blk_heads,
        num_warps=1,
    )
    return g, beta_output