# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Inference-only Qwen3Next model."""
4

5
from collections.abc import Iterable
6
from itertools import islice
7
8
9
10
11
12
13

import torch
from einops import rearrange
from torch import nn
from transformers.activations import ACT2FN

from vllm.compilation.decorators import support_torch_compile
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
from vllm.config import (
    CacheConfig,
    ModelConfig,
    SpeculativeConfig,
    VllmConfig,
    get_current_vllm_config,
)
from vllm.distributed import (
    divide,
    get_ep_group,
    get_pp_group,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    tensor_model_parallel_all_gather,
)
29
30
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
31
from vllm.model_executor.custom_op import CustomOp
32
from vllm.model_executor.layers.attention import Attention
33
from vllm.model_executor.layers.fla.ops import (
34
35
36
    chunk_gated_delta_rule as fla_chunk_gated_delta_rule,
)
from vllm.model_executor.layers.fla.ops import (
37
    fused_sigmoid_gating_delta_rule_update,
38
)
39
from vllm.model_executor.layers.fla.ops.chunk import l2norm_fwd
40
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
41
42
43
44
from vllm.model_executor.layers.layernorm import (
    GemmaRMSNorm as Qwen3NextRMSNorm,
)
from vllm.model_executor.layers.layernorm import RMSNormGated
45
46
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
47
    MergedColumnParallelLinear,
48
49
50
51
    QKVParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
52
53
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.abstract import MambaBase
54
from vllm.model_executor.layers.mamba.mamba_mixer2 import mamba_v2_sharded_weight_loader
55
from vllm.model_executor.layers.mamba.mamba_utils import (
56
57
    MambaStateCopyFunc,
    MambaStateCopyFuncCalculator,
58
59
60
    MambaStateDtypeCalculator,
    MambaStateShapeCalculator,
)
61
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
62
63
64
    causal_conv1d_fn,
    causal_conv1d_update,
)
65
66
67
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
68
69
70
    ParallelLMHead,
    VocabParallelEmbedding,
)
71
from vllm.model_executor.model_loader.weight_utils import (
72
    default_weight_loader,
73
    maybe_remap_kv_scale_name,
74
75
    sharded_weight_loader,
)
76
from vllm.model_executor.models.qwen2_moe import Qwen2MoeMLP as Qwen3NextMLP
77
from vllm.model_executor.models.utils import sequence_parallel_chunk
78
79
80
81
82
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import Qwen3NextConfig
from vllm.triton_utils import tl, triton
83
from vllm.utils.torch_utils import direct_register_custom_op
84
from vllm.v1.attention.backend import AttentionMetadata
85
86
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata

87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
from .interfaces import (
    HasInnerState,
    IsHybrid,
    MixtureOfExperts,
    SupportsLoRA,
    SupportsPP,
)
from .utils import (
    AutoWeightsLoader,
    PPMissingLayer,
    extract_layer_index,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
103
104
105
106
107
108

# Module-level logger, obtained from ``init_logger`` for this module's name.
logger = init_logger(__name__)

# Type alias for a pair of tensors.
# NOTE(review): presumably a (key, value) cache pair; no use of this alias is
# visible in this part of the file - confirm before relying on it.
KVCache = tuple[torch.Tensor, torch.Tensor]


109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
def fi_chunk_gated_delta_rule(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
    initial_state: torch.Tensor,
    output_final_state: bool,
    cu_seqlens: torch.LongTensor | None = None,
    use_qk_l2norm_in_kernel: bool = True,
):
    """Run the chunked gated delta rule through FlashInfer's GDN prefill op.

    Adapts fla-style arguments to ``flashinfer.gdn_prefill``: the leading
    dimension of ``q``/``k``/``v``/``g``/``beta`` is squeezed away, the kernel
    is given ``exp(g)`` rather than ``g``, and ``g``, ``beta`` and
    ``initial_state`` are cast to float32.

    Args:
        q, k, v: Query/key/value tensors with a leading dimension of size 1.
        g: Gate tensor; exponentiated before being handed to the kernel.
        beta: Beta tensor, cast to float32.
        initial_state: Initial recurrent state, cast to float32.
        output_final_state: Forwarded unchanged to the FlashInfer kernel.
        cu_seqlens: Forwarded unchanged to the FlashInfer kernel.
        use_qk_l2norm_in_kernel: When true, ``q`` and ``k`` are L2-normalized
            here (via ``l2norm_fwd``) before the kernel call.

    Returns:
        ``(output, final_state)``; ``output`` has its leading dimension
        restored so it matches the fla output format ``(1, L, H, D)``.
    """
    # Imported lazily so flashinfer is only required when this path runs.
    from flashinfer.gdn_prefill import (
        chunk_gated_delta_rule as chunk_gated_delta_rule_fi,
    )

    if use_qk_l2norm_in_kernel:
        q, k = l2norm_fwd(q), l2norm_fwd(k)

    def _drop_leading_dim(t: torch.Tensor) -> torch.Tensor:
        # The FlashInfer kernel is fed tensors without the leading batch dim.
        return t.squeeze(0).contiguous()

    q_fi = _drop_leading_dim(q)
    k_fi = _drop_leading_dim(k)
    v_fi = _drop_leading_dim(v)
    g_fi = _drop_leading_dim(g).to(torch.float32)
    beta_fi = _drop_leading_dim(beta).to(torch.float32)
    state_fi = initial_state.to(torch.float32)

    output, final_state = chunk_gated_delta_rule_fi(
        q=q_fi,
        k=k_fi,
        v=v_fi,
        g=torch.exp(g_fi),
        beta=beta_fi,
        initial_state=state_fi,
        output_final_state=output_final_state,
        cu_seqlens=cu_seqlens,
    )

    # Unsqueeze back to 4D (1, L, H, D) to match fla output format
    return output.unsqueeze(0), final_state
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212


@CustomOp.register("chunk_gated_delta_rule")
class ChunkGatedDeltaRule(CustomOp):
    """Custom op that dispatches the chunked gated delta rule to a backend.

    On CUDA devices reporting device capability 90 the FlashInfer-based
    ``fi_chunk_gated_delta_rule`` is used; everywhere else the fla
    implementation (``fla_chunk_gated_delta_rule``) is used.
    """

    def __init__(self) -> None:
        super().__init__()
        use_flashinfer = (
            current_platform.is_cuda() and current_platform.is_device_capability(90)
        )
        if use_flashinfer:
            logger.info_once(
                "Using FlashInfer GDN prefill kernel on CUDA compute capability 90"
            )
        # Override the forward dispatch chosen by the base class.
        self._forward_method = (
            self.forward_cuda if use_flashinfer else self.forward_native
        )

    def forward_cuda(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        g: torch.Tensor,
        beta: torch.Tensor,
        initial_state: torch.Tensor,
        output_final_state: bool,
        cu_seqlens: torch.LongTensor | None = None,
        use_qk_l2norm_in_kernel: bool = True,
    ):
        """FlashInfer path: forwards every argument to the FlashInfer wrapper."""
        call_args = dict(
            q=q,
            k=k,
            v=v,
            g=g,
            beta=beta,
            initial_state=initial_state,
            output_final_state=output_final_state,
            cu_seqlens=cu_seqlens,
            use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
        )
        return fi_chunk_gated_delta_rule(**call_args)

    def forward_native(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        g: torch.Tensor,
        beta: torch.Tensor,
        initial_state: torch.Tensor,
        output_final_state: bool,
        cu_seqlens: torch.LongTensor | None = None,
        use_qk_l2norm_in_kernel: bool = True,
    ):
        """fla path: forwards every argument to the fla implementation."""
        call_args = dict(
            q=q,
            k=k,
            v=v,
            g=g,
            beta=beta,
            initial_state=initial_state,
            output_final_state=output_final_state,
            cu_seqlens=cu_seqlens,
            use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
        )
        return fla_chunk_gated_delta_rule(**call_args)


213
class Qwen3NextSparseMoeBlock(nn.Module):
    """Sparse mixture-of-experts block with an optional gated shared expert.

    Builds a replicated router (``gate``), a replicated single-output
    ``shared_expert_gate``, an optional shared-expert MLP, and a
    ``SharedFusedMoE`` that wraps the routed experts together with the
    shared expert.
    """

    def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
        """Create the router, the optional shared expert and the fused experts.

        Args:
            vllm_config: Source of the HF text config, the parallel config and
                the quantization config.
            prefix: Parameter-name prefix for the submodules created here.

        Raises:
            ValueError: If the tensor-parallel size exceeds the expert count.
        """
        super().__init__()

        config = vllm_config.model_config.hf_text_config
        parallel_config = vllm_config.parallel_config
        quant_config = vllm_config.quant_config

        self.tp_size = get_tensor_model_parallel_world_size()

        self.ep_group = get_ep_group().device_group
        self.ep_rank = get_ep_group().rank_in_group
        self.ep_size = self.ep_group.size()
        self.n_routed_experts = config.num_experts

        self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe

        if self.tp_size > config.num_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {config.num_experts}."
            )

        # Load balancing settings.
        # The EPLB config is read from the *current* vLLM config, as before,
        # but without rebinding (shadowing) the ``vllm_config`` parameter.
        eplb_config = get_current_vllm_config().parallel_config.eplb_config
        self.enable_eplb = parallel_config.enable_eplb

        self.n_logical_experts = self.n_routed_experts
        self.n_redundant_experts = eplb_config.num_redundant_experts
        self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts
        self.n_local_physical_experts = self.n_physical_experts // self.ep_size

        # Contiguous range of physical experts assigned to this EP rank.
        self.physical_expert_start = self.ep_rank * self.n_local_physical_experts
        self.physical_expert_end = (
            self.physical_expert_start + self.n_local_physical_experts
        )

        # Router and shared-expert gate are created without quantization.
        self.gate = ReplicatedLinear(
            config.hidden_size,
            config.num_experts,
            bias=False,
            quant_config=None,
            prefix=f"{prefix}.gate",
        )

        self.shared_expert_gate = ReplicatedLinear(
            config.hidden_size,
            1,
            bias=False,
            quant_config=None,
            prefix=f"{prefix}.shared_expert_gate",
        )

        if config.shared_expert_intermediate_size > 0:
            self.shared_expert = Qwen3NextMLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.shared_expert_intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                reduce_results=False,
                expert_gate=self.shared_expert_gate,
                prefix=f"{prefix}.shared_expert",
            )
        else:
            self.shared_expert = None

        self.experts = SharedFusedMoE(
            shared_experts=self.shared_expert,
            gate=self.gate,
            num_experts=self.n_routed_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            reduce_results=False,
            renormalize=getattr(config, "norm_topk_prob", True),
            quant_config=quant_config,
            prefix=f"{prefix}.experts",
            enable_eplb=self.enable_eplb,
            num_redundant_experts=self.n_redundant_experts,
            is_sequence_parallel=self.is_sequence_parallel,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route tokens through the experts and return a same-shaped tensor.

        Args:
            hidden_states: Either a 1D tensor ``(hidden_dim,)`` or a 2D tensor
                ``(num_tokens, hidden_dim)``.

        Returns:
            The MoE output, viewed back to the input's original shape.
        """
        # NOTE: hidden_states can have either 1D or 2D shape, so the hidden
        # size is taken from the last dimension and the token count is read
        # only after flattening to 2D (unpacking ``.shape`` into two names
        # would raise for a 1D input).
        orig_shape = hidden_states.shape
        hidden_dim = orig_shape[-1]
        hidden_states = hidden_states.view(-1, hidden_dim)
        num_tokens = hidden_states.shape[0]

        if self.is_sequence_parallel:
            hidden_states = sequence_parallel_chunk(hidden_states)

        if self.experts.is_internal_router:
            # In this case, the gate/router runs inside the FusedMoE class
            final_hidden_states = self.experts(
                hidden_states=hidden_states, router_logits=hidden_states
            )
        else:
            # router_logits: (num_tokens, n_experts)
            router_logits, _ = self.gate(hidden_states)
            final_hidden_states = self.experts(
                hidden_states=hidden_states, router_logits=router_logits
            )

        if self.shared_expert is not None:
            # With a shared expert the fused layer's result is indexed as a
            # pair whose two entries are summed.
            # NOTE(review): presumably (shared_out, routed_out) - confirm
            # against SharedFusedMoE.
            final_hidden_states = final_hidden_states[0] + final_hidden_states[1]

        if self.is_sequence_parallel:
            final_hidden_states = tensor_model_parallel_all_gather(
                final_hidden_states, 0
            )
            # Keep only the first num_tokens rows of the gathered result.
            final_hidden_states = final_hidden_states[:num_tokens]
        elif self.tp_size > 1:
            final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(  # noqa E501
                final_hidden_states
            )

        return final_hidden_states.view(orig_shape)


class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
    @property
    def mamba_type(self) -> str:
        """Type tag reported by this layer: always ``"gdn_attention"``."""
        return "gdn_attention"
337
338
339

    def get_state_dtype(self) -> tuple[torch.dtype, torch.dtype]:
        """Return the pair of state dtypes for this layer.

        Delegates to ``MambaStateDtypeCalculator.gated_delta_net_state_dtype``
        using the model dtype and the two mamba cache dtypes from the cache
        config.
        """
        cache_cfg = self.cache_config
        return MambaStateDtypeCalculator.gated_delta_net_state_dtype(
            self.model_config.dtype,
            cache_cfg.mamba_cache_dtype,
            cache_cfg.mamba_ssm_cache_dtype,
        )
344
345
346

    def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
        """Return the pair of state shapes for this layer.

        Delegates to ``MambaStateShapeCalculator.gated_delta_net_state_shape``
        with the TP size, head counts, head dims, conv kernel size and the
        number of speculative tokens.
        """
        shape_args = (
            self.tp_size,
            self.num_k_heads,
            self.num_v_heads,
            self.head_k_dim,
            self.head_v_dim,
            self.conv_kernel_size,
            self.num_spec,
        )
        return MambaStateShapeCalculator.gated_delta_net_state_shape(*shape_args)
355
356
357
358

    def __init__(
        self,
        config: Qwen3NextConfig,
        model_config: ModelConfig | None = None,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        speculative_config: SpeculativeConfig | None = None,
        prefix: str = "",
    ) -> None:
        """Build the gated-delta-net layer and register it by ``prefix``.

        Creates the conv1d weight holder, the fused ``in_proj_qkvz`` and
        ``in_proj_ba`` projections (through overridable factory methods), the
        per-head ``dt_bias`` / ``A_log`` parameters, the gated RMS norm, the
        output projection and the chunked delta-rule op, then stores ``self``
        in the compilation config's ``static_forward_context``.

        Raises:
            ValueError: If ``prefix`` is already registered in the static
                forward context.
        """
        super().__init__()

        # Parallel layout.
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()

        # Dimensions taken from the HF config.
        self.hidden_size = config.hidden_size
        self.num_v_heads = config.linear_num_value_heads
        self.num_k_heads = config.linear_num_key_heads
        self.head_k_dim = config.linear_key_head_dim
        self.head_v_dim = config.linear_value_head_dim
        self.key_dim = self.head_k_dim * self.num_k_heads
        self.value_dim = self.head_v_dim * self.num_v_heads

        self.conv_kernel_size = config.linear_conv_kernel_dim
        self.layer_idx = extract_layer_index(prefix)
        self.activation = config.hidden_act
        self.act = ACT2FN[config.hidden_act]
        self.layer_norm_epsilon = config.rms_norm_eps
        self.prefix = prefix

        # Keep handles to every config object passed in.
        self.config = config
        self.model_config = model_config
        self.cache_config = cache_config
        self.quant_config = quant_config
        self.speculative_config = speculative_config

        # Number of speculative tokens; zero when speculation is not configured.
        if self.speculative_config:
            self.num_spec = self.speculative_config.num_speculative_tokens
        else:
            self.num_spec = 0

        # QKV
        # A ColumnParallelLinear is used as the holder of the conv1d weight;
        # the weight gets an extra singleton dim inserted at position 1.
        self.conv_dim = self.key_dim * 2 + self.value_dim
        self.conv1d = ColumnParallelLinear(
            input_size=self.conv_kernel_size,
            output_size=self.conv_dim,
            bias=False,
            prefix=f"{prefix}.conv1d",
        )
        self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)

        # Projection of the input hidden states.
        # Qwen3-Next and Qwen3.5 have different qkv_proj layouts, so the
        # qkvz projection is created through an overridable factory method.
        self.in_proj_qkvz = self.create_qkvz_proj(
            hidden_size=self.hidden_size,
            key_dim=self.key_dim,
            value_dim=self.value_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.in_proj_qkvz",
        )
        # NOTE(review): a pre-existing comment here stated that ba_proj
        # doesn't support blockwise fp8 quantization; quant_config is
        # nevertheless forwarded to the factory below - confirm intent.
        # Qwen3-Next and Qwen3.5 have different in_proj_ba checkpoint
        # layouts, so a factory method creates this projection as well.
        self.in_proj_ba = self.create_ba_proj(
            hidden_size=self.hidden_size,
            num_v_heads=self.num_v_heads,
            quant_config=quant_config,
            prefix=f"{prefix}.in_proj_ba",
        )

        # Swap the conv1d weight's default loader for the mamba-v2 sharded
        # loader, described as three segments: query, key, value.
        qk_shard_spec = (self.key_dim, 0, False)
        v_shard_spec = (self.value_dim, 0, False)
        conv_weight_loader = mamba_v2_sharded_weight_loader(
            [qk_shard_spec, qk_shard_spec, v_shard_spec],
            self.tp_size,
            self.tp_rank,
        )
        delattr(self.conv1d.weight, "weight_loader")
        set_weight_attrs(self.conv1d.weight, {"weight_loader": conv_weight_loader})

        # Per-rank, per-value-head parameters; both get a sharded loader.
        self.dt_bias = nn.Parameter(
            torch.ones(self.num_v_heads // self.tp_size),
        )
        self.A_log = nn.Parameter(
            torch.empty(
                divide(self.num_v_heads, self.tp_size),
            )
        )
        set_weight_attrs(self.A_log, {"weight_loader": sharded_weight_loader(0)})
        set_weight_attrs(self.dt_bias, {"weight_loader": sharded_weight_loader(0)})

        self.norm = RMSNormGated(
            self.head_v_dim,
            eps=self.layer_norm_epsilon,
            group_size=None,
            norm_before_gate=True,
            device=current_platform.current_device(),
        )

        self.out_proj = RowParallelLinear(
            self.value_dim,
            self.hidden_size,
            bias=False,
            input_is_parallel=True,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
        )

        self.chunk_gated_delta_rule = ChunkGatedDeltaRule()

        # Register this layer under its prefix; prefixes must be unique.
        static_ctx = get_current_vllm_config().compilation_config
        if prefix in static_ctx.static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        static_ctx.static_forward_context[prefix] = self

483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
    def create_qkvz_proj(
        self,
        hidden_size: int,
        key_dim: int,
        value_dim: int,
        quant_config: QuantizationConfig | None,
        prefix: str,
    ) -> MergedColumnParallelLinear:
        """Create the fused q/k/v/z input projection as one output shard.

        The single shard's width is ``2 * key_dim + 2 * value_dim``
        (query + key + value + z).
        """
        fused_width = 2 * key_dim + 2 * value_dim
        return MergedColumnParallelLinear(
            input_size=hidden_size,
            output_sizes=[fused_width],
            bias=False,
            quant_config=quant_config,
            prefix=prefix,
        )

499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
    def create_ba_proj(
        self,
        hidden_size: int,
        num_v_heads: int,
        quant_config: QuantizationConfig | None,
        prefix: str,
    ) -> MergedColumnParallelLinear:
        """Create the fused b/a input projection as one output shard.

        Qwen3-Next stores in_proj_ba as a single fused weight with an
        interleaved GQA layout: [b_g0, a_g0, b_g1, a_g1, ...], one group per
        key-head group. A single output shard is used so that ColumnParallel
        sharding preserves this interleaving across TP ranks.
        Qwen3.5 overrides this to use [num_v_heads, num_v_heads] because its
        checkpoint has separate in_proj_b and in_proj_a weights.
        """
        fused_width = 2 * num_v_heads
        return MergedColumnParallelLinear(
            input_size=hidden_size,
            output_sizes=[fused_width],
            bias=False,
            quant_config=quant_config,
            prefix=prefix,
        )

521
522
    def fix_query_key_value_ordering(
        self,
        mixed_qkvz: torch.Tensor,
        mixed_ba: torch.Tensor,
    ):
        """Split the fused projections into ``query, key, value, z, b, a``.

        Both inputs are viewed as ``[..., local_key_heads, per_group_width]``
        and then split along the per-group axis (``dim=2``, i.e. the inputs
        are expected to be 2D ``[tokens, features]``).

        Args:
            mixed_qkvz: Output of ``in_proj_qkvz``; per key-head group it
                holds query, key, value and z features back to back.
            mixed_ba: Output of ``in_proj_ba``; per key-head group it holds
                b and a features back to back.

        Returns:
            ``(query, key, value, z, b, a)`` where ``query``/``key`` are
            ``[tokens, local_key_heads, head_k_dim]``, ``value``/``z`` are
            ``[tokens, local_value_heads, head_v_dim]`` and ``b``/``a`` are
            ``[tokens, local_value_heads]``.
        """
        local_k_heads = self.num_k_heads // self.tp_size
        local_v_heads = self.num_v_heads // self.tp_size
        v_heads_per_k = self.num_v_heads // self.num_k_heads

        # Per key-head group: q + k + (v + z for every value head in the group).
        qkvz_group_width = (
            self.head_k_dim
            + self.head_k_dim
            + (self.head_v_dim + self.head_v_dim)
            * self.num_v_heads
            // self.num_k_heads
        )
        # Per key-head group: b + a for every value head in the group.
        ba_group_width = 2 * self.num_v_heads // self.num_k_heads

        # Each tensor is reshaped from its OWN leading dims (the ba shape was
        # previously derived from mixed_qkvz).
        mixed_qkvz = mixed_qkvz.view(
            *mixed_qkvz.size()[:-1], local_k_heads, qkvz_group_width
        )
        mixed_ba = mixed_ba.view(*mixed_ba.size()[:-1], local_k_heads, ba_group_width)

        qkvz_split_sizes = [
            self.head_k_dim,
            self.head_k_dim,
            v_heads_per_k * self.head_v_dim,
            v_heads_per_k * self.head_v_dim,
        ]
        ba_split_sizes = [v_heads_per_k, v_heads_per_k]

        # [l, ng, hk + hk + np/ng * hv + np/ng * hv]
        # --> [l, ng, hk], [l, ng, hk], [l, ng, np/ng * hv], [l, ng, np/ng * hv]
        (query, key, value, z) = torch.split(mixed_qkvz, qkvz_split_sizes, dim=2)
        # [l, ng, np/ng + np/ng] --> [l, ng, np/ng], [l, ng, np/ng]
        (b, a) = torch.split(mixed_ba, ba_split_sizes, dim=2)

        # [l, ng, np/ng * hv] -> [l, np, hv]
        value = value.reshape(value.size(0), -1, self.head_v_dim)
        z = z.reshape(z.size(0), -1, self.head_v_dim)
        # [l, ng, np/ng] -> [l, np]
        b = b.reshape(b.size(0), local_v_heads)
        a = a.reshape(a.size(0), local_v_heads)

        return query, key, value, z, b, a

    def rearrange_mixed_qkv(self, mixed_qkv):
        """Split a packed qkv tensor and reshape each part to ``(1, l, h, d)``.

        Returns ``(None, None, None)`` when ``mixed_qkv`` is ``None``;
        otherwise returns contiguous ``query``, ``key`` and ``value`` tensors.
        """
        if mixed_qkv is None:
            return None, None, None

        # Per-rank widths of the q/k and v segments along the last dim.
        local_key_dim = self.key_dim // self.tp_size
        local_value_dim = self.value_dim // self.tp_size

        query, key, value = torch.split(
            mixed_qkv,
            [local_key_dim, local_key_dim, local_value_dim],
            dim=-1,
        )

        # Unflatten heads and prepend a singleton leading dim.
        query = rearrange(query, "l (h d) -> 1 l h d", d=self.head_k_dim)
        key = rearrange(key, "l (h d) -> 1 l h d", d=self.head_k_dim)
        value = rearrange(value, "l (h d) -> 1 l h d", d=self.head_v_dim)

        return query.contiguous(), key.contiguous(), value.contiguous()
590
591
592
593
594
595

    def forward(
        self,
        hidden_states: torch.Tensor,
        output: torch.Tensor,
    ):
        """Run the layer and write the result into ``output[:num_tokens]``.

        Three stages:
        1. Input projection
        2. Core attention (custom op)
        3. Output projection

        Nothing is returned; ``output`` is updated in place.
        """
        num_tokens = hidden_states.size(0)

        # ------------------------------------------------------------
        # Stage 1: input projection
        # ------------------------------------------------------------
        qkvz_states, _ = self.in_proj_qkvz(hidden_states)
        ba_states, _ = self.in_proj_ba(hidden_states)
        query, key, value, z, b, a = self.fix_query_key_value_ordering(
            qkvz_states, ba_states
        )
        # Flatten the head axes so q, k and v can be packed along the last dim.
        flat_query = rearrange(query, "l p d -> l (p d)")
        flat_key = rearrange(key, "l p d -> l (p d)")
        flat_value = rearrange(value, "l p d -> l (p d)")
        mixed_qkv = torch.cat((flat_query, flat_key, flat_value), dim=-1)

        # ------------------------------------------------------------
        # Stage 2: core attention (custom op)
        # ------------------------------------------------------------
        # Note: we should not use torch.empty here like other attention backends,
        # see discussions in https://github.com/vllm-project/vllm/pull/28182
        attn_out_shape = (
            num_tokens,
            self.num_v_heads // self.tp_size,
            self.head_v_dim,
        )
        core_attn_out = torch.zeros(
            attn_out_shape,
            dtype=hidden_states.dtype,
            device=hidden_states.device,
        )

        # The custom op receives core_attn_out as its output buffer and
        # locates this layer through its prefix.
        torch.ops.vllm.gdn_attention_core(
            mixed_qkv,
            b,
            a,
            core_attn_out,
            self.prefix,
        )

        # ------------------------------------------------------------
        # Stage 3: output projection
        # ------------------------------------------------------------
        gate_shape = z.shape
        # The gated norm is applied on 2D views of the attention output and z.
        attn_2d = core_attn_out.reshape(-1, core_attn_out.shape[-1])
        gate_2d = z.reshape(-1, z.shape[-1])
        normed = self.norm(attn_2d, gate_2d).reshape(gate_shape)
        merged_heads = rearrange(normed, "... h d -> ... (h d)")
        output[:num_tokens], _ = self.out_proj(merged_heads)

648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
    def _warmup_prefill_kernels(self, mixed_qkv: torch.Tensor) -> None:
        """Trigger autotuning of the GDN prefill kernels during V1 profiling.

        Why this exists: in a V1 profile run ``_forward_core`` exits early
        because ``attn_metadata`` is ``None``, so the autotuned kernels
        behind ``chunk_gated_delta_rule`` (e.g. ``solve_tril``,
        ``chunk_scaled_dot_kkt``) never execute.  vLLM then hands most of
        the remaining GPU memory to the KV cache, and the first real request
        can OOM when the autotuner tries to benchmark.

        What it does: runs ``chunk_gated_delta_rule`` on tiny dummy tensors
        while memory is still plentiful.  Autotuner results are cached
        globally, so only the first layer pays the benchmarking cost.

        Sequence lengths used: most kernels use a fixed
        ``BT = chunk_size`` (64), but ``chunk_fwd_kernel_o`` derives ``BT``
        from the sequence length as
        ``min(64, max(16, next_power_of_2(T)))`` and ``BT`` is part of its
        autotune key, so passes are run with T = 16, 32 and 64.

        The decode path (``fused_sigmoid_gating_delta_rule_update``) has
        fixed kernel parameters and needs no warmup.

        The method runs at most once per layer instance; a failed warmup
        pass is logged as a warning and does not raise.
        """
        # One-shot guard: the flag's presence marks the layer as warmed up.
        if hasattr(self, "_prefill_kernels_warmed_up"):
            return
        self._prefill_kernels_warmed_up = True

        device = mixed_qkv.device
        dtype = mixed_qkv.dtype
        local_k_heads = self.num_k_heads // self.tp_size
        local_v_heads = self.num_v_heads // self.tp_size
        _, state_dtype = self.get_state_dtype()

        def _dummy(*shape: int) -> torch.Tensor:
            # Random tensor matching the device and dtype of the real input.
            return torch.randn(*shape, device=device, dtype=dtype)

        # One pass per possible BT of chunk_fwd_kernel_o:
        #   T=16 -> BT=16, T=32 -> BT=32, T=64 -> BT=64.
        # The other kernels always use BT=chunk_size(64); their autotune
        # cache is filled by the first pass and reused afterwards.
        for T in (16, 32, 64):
            q = _dummy(1, T, local_k_heads, self.head_k_dim)
            k = _dummy(1, T, local_k_heads, self.head_k_dim)
            v = _dummy(1, T, local_v_heads, self.head_v_dim)
            g = _dummy(1, T, local_v_heads)
            beta = _dummy(1, T, local_v_heads)
            state = torch.zeros(
                1,
                local_v_heads,
                self.head_v_dim,
                self.head_k_dim,
                device=device,
                dtype=state_dtype,
            )
            # A single sequence covering all T tokens.
            cu_seqlens = torch.tensor([0, T], device=device, dtype=torch.long)

            try:
                self.chunk_gated_delta_rule(
                    q=q,
                    k=k,
                    v=v,
                    g=g,
                    beta=beta,
                    initial_state=state,
                    output_final_state=False,
                    cu_seqlens=cu_seqlens,
                    use_qk_l2norm_in_kernel=True,
                )
            except Exception:
                # Best effort: warmup failure must not abort profiling.
                logger.warning(
                    "GDN prefill kernel warmup (T=%d) failed for layer %s. "
                    "First inference may OOM due to autotuner.",
                    T,
                    self.prefix,
                    exc_info=True,
                )
            else:
                logger.debug(
                    "GDN prefill kernel warmup (T=%d) completed for layer %s",
                    T,
                    self.prefix,
                )
            finally:
                # Drop the dummy tensors before the next pass allocates more.
                del q, k, v, g, beta, state, cu_seqlens

        torch.accelerator.empty_cache()

743
    def _forward_core(
744
        self,
745
746
747
748
        mixed_qkv: torch.Tensor,
        b: torch.Tensor,
        a: torch.Tensor,
        core_attn_out: torch.Tensor,
749
    ):
750
751
752
        """
        Core attention computation (called by custom op).
        """
753
754
755
756
        forward_context = get_forward_context()
        attn_metadata: AttentionMetadata = forward_context.attn_metadata

        if attn_metadata is None:
757
758
759
            # V1 profile run — warm up prefill kernels so that
            # autotuning completes before KV cache allocation.
            self._warmup_prefill_kernels(mixed_qkv)
760
761
762
763
764
765
766
767
768
            return

        assert isinstance(attn_metadata, dict)
        attn_metadata = attn_metadata[self.prefix]
        assert isinstance(attn_metadata, GDNAttentionMetadata)
        has_initial_state = attn_metadata.has_initial_state
        spec_query_start_loc = attn_metadata.spec_query_start_loc
        non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc
        spec_sequence_masks = attn_metadata.spec_sequence_masks
769
770
        spec_token_indx = attn_metadata.spec_token_indx
        non_spec_token_indx = attn_metadata.non_spec_token_indx
771
772
773
774
775
        spec_state_indices_tensor = attn_metadata.spec_state_indices_tensor  # noqa: E501
        non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor  # noqa: E501
        self_kv_cache = self.kv_cache[forward_context.virtual_engine]
        conv_state = self_kv_cache[0].transpose(-1, -2)
        ssm_state = self_kv_cache[1]
776
        num_actual_tokens = attn_metadata.num_actual_tokens
777
        num_accepted_tokens = attn_metadata.num_accepted_tokens
778

779
780
781
        mixed_qkv = mixed_qkv[:num_actual_tokens]
        b = b[:num_actual_tokens]
        a = a[:num_actual_tokens]
782

783
        # 1. Convolution sequence transformation
784
785
786
        conv_weights = self.conv1d.weight.view(
            self.conv1d.weight.size(0), self.conv1d.weight.size(2)
        )
787
788

        if spec_sequence_masks is not None:
789
            if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0:
790
791
792
                mixed_qkv_spec = mixed_qkv
                mixed_qkv_non_spec = None
            else:
793
794
                mixed_qkv_spec = mixed_qkv.index_select(0, spec_token_indx)
                mixed_qkv_non_spec = mixed_qkv.index_select(0, non_spec_token_indx)
795
796
797
798
        else:
            mixed_qkv_spec = None
            mixed_qkv_non_spec = mixed_qkv

799
        # 1.1: Process the multi-query part
800
801
802
803
804
805
806
        if spec_sequence_masks is not None:
            mixed_qkv_spec = causal_conv1d_update(
                mixed_qkv_spec,
                conv_state,
                conv_weights,
                self.conv1d.bias,
                self.activation,
807
808
809
                conv_state_indices=spec_state_indices_tensor[:, 0][
                    : attn_metadata.num_spec_decodes
                ],
810
                num_accepted_tokens=num_accepted_tokens,
811
812
                query_start_loc=spec_query_start_loc,
                max_query_len=spec_state_indices_tensor.size(-1),
813
814
815
                validate_data=False,
            )

816
        # 1.2: Process the remaining part
817
        if attn_metadata.num_prefills > 0:
818
            mixed_qkv_non_spec_T = mixed_qkv_non_spec.transpose(0, 1)
819
            # - "cache_indices" updates the conv_state cache in positions
820
            #   pointed to by "state_indices_tensor"
821
            mixed_qkv_non_spec = causal_conv1d_fn(
822
                mixed_qkv_non_spec_T,
823
824
825
826
827
828
829
                conv_weights,
                self.conv1d.bias,
                activation=self.activation,
                conv_states=conv_state,
                has_initial_state=has_initial_state,
                cache_indices=non_spec_state_indices_tensor,
                query_start_loc=non_spec_query_start_loc,
830
                metadata=attn_metadata,
831
832
833
834
835
836
837
838
            ).transpose(0, 1)
        elif attn_metadata.num_decodes > 0:
            mixed_qkv_non_spec = causal_conv1d_update(
                mixed_qkv_non_spec,
                conv_state,
                conv_weights,
                self.conv1d.bias,
                self.activation,
839
                conv_state_indices=non_spec_state_indices_tensor[
840
                    : attn_metadata.num_actual_tokens
841
                ],
842
843
844
845
846
                validate_data=True,
            )
        else:
            mixed_qkv_non_spec = None

847
        query_spec, key_spec, value_spec = self.rearrange_mixed_qkv(mixed_qkv_spec)
848
        query_non_spec, key_non_spec, value_non_spec = self.rearrange_mixed_qkv(
849
850
            mixed_qkv_non_spec
        )
851

852
853
854
        if attn_metadata.num_prefills > 0:
            g, beta = fused_gdn_gating(self.A_log, a, b, self.dt_bias)
            if spec_sequence_masks is not None:
855
856
                g_non_spec = g.index_select(1, non_spec_token_indx)
                beta_non_spec = beta.index_select(1, non_spec_token_indx)
857
858
859
            else:
                g_non_spec = g
                beta_non_spec = beta
860
        else:
861
862
            g_non_spec = None
            beta_non_spec = None
863

864
        # 2. Recurrent attention
865

866
        # 2.1: Process the multi-query part
867
        if spec_sequence_masks is not None:
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
            core_attn_out_spec, last_recurrent_state = (
                fused_sigmoid_gating_delta_rule_update(
                    A_log=self.A_log,
                    a=a,
                    b=b,
                    dt_bias=self.dt_bias,
                    q=query_spec,
                    k=key_spec,
                    v=value_spec,
                    initial_state=ssm_state,
                    inplace_final_state=True,
                    cu_seqlens=spec_query_start_loc[
                        : attn_metadata.num_spec_decodes + 1
                    ],
                    ssm_state_indices=spec_state_indices_tensor,
                    num_accepted_tokens=num_accepted_tokens,
                    use_qk_l2norm_in_kernel=True,
                )
886
            )
887
888
889
        else:
            core_attn_out_spec, last_recurrent_state = None, None

890
        # 2.2: Process the remaining part
891
        if attn_metadata.num_prefills > 0:
892
            initial_state = ssm_state[non_spec_state_indices_tensor].contiguous()
893
894
895
896
            initial_state[~has_initial_state, ...] = 0
            (
                core_attn_out_non_spec,
                last_recurrent_state,
897
            ) = self.chunk_gated_delta_rule(
898
899
900
901
902
903
904
905
906
907
908
909
                q=query_non_spec,
                k=key_non_spec,
                v=value_non_spec,
                g=g_non_spec,
                beta=beta_non_spec,
                initial_state=initial_state,
                output_final_state=True,
                cu_seqlens=non_spec_query_start_loc,
                use_qk_l2norm_in_kernel=True,
            )
            # Init cache
            ssm_state[non_spec_state_indices_tensor] = last_recurrent_state.to(
910
911
                ssm_state.dtype
            )
912
913
        elif attn_metadata.num_decodes > 0:
            core_attn_out_non_spec, last_recurrent_state = (
914
915
916
917
918
                fused_sigmoid_gating_delta_rule_update(
                    A_log=self.A_log,
                    a=a,
                    b=b,
                    dt_bias=self.dt_bias,
919
920
921
922
923
                    q=query_non_spec,
                    k=key_non_spec,
                    v=value_non_spec,
                    initial_state=ssm_state,
                    inplace_final_state=True,
924
925
926
                    cu_seqlens=non_spec_query_start_loc[
                        : attn_metadata.num_decodes + 1
                    ],
927
928
                    ssm_state_indices=non_spec_state_indices_tensor,
                    use_qk_l2norm_in_kernel=True,
929
930
                )
            )
931
932
933
        else:
            core_attn_out_non_spec, last_recurrent_state = None, None

934
        # 3. Merge core attention output
935
        if spec_sequence_masks is not None and core_attn_out_non_spec is not None:
936
            merged_out = torch.empty(
937
938
939
940
                (1, num_actual_tokens, *core_attn_out_spec.shape[2:]),
                dtype=core_attn_out_non_spec.dtype,
                device=core_attn_out_non_spec.device,
            )
941
942
943
            merged_out.index_copy_(1, spec_token_indx, core_attn_out_spec)
            merged_out.index_copy_(1, non_spec_token_indx, core_attn_out_non_spec)
            core_attn_out[:num_actual_tokens] = merged_out.squeeze(0)
944
        elif spec_sequence_masks is not None:
945
            core_attn_out[:num_actual_tokens] = core_attn_out_spec.squeeze(0)
946
        else:
947
            core_attn_out[:num_actual_tokens] = core_attn_out_non_spec.squeeze(0)
948
949
950
951
952
953


class Qwen3NextAttention(nn.Module):
    """Full-attention block of a Qwen3Next decoder layer.

    Projects the hidden states to q/k/v (plus an optional per-head output
    gate packed next to q), applies per-head RMSNorm to q and k, applies
    rotary embeddings, runs the attention backend, optionally multiplies the
    result by ``sigmoid(gate)``, and writes the output projection into a
    caller-provided buffer.
    """

    def __init__(
        self,
        config: Qwen3NextConfig,
        model_config: ModelConfig | None = None,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        """Build the projections, norms, rotary embedding and attention op.

        Args:
            config: HF model config; supplies head counts, ``head_dim``,
                RoPE parameters and ``rms_norm_eps``.
            model_config: Accepted for signature parity with the linear
                attention layer; not read here.
            cache_config: Forwarded to ``Attention``.
            quant_config: Forwarded to the linear layers and ``Attention``.
            prefix: Dotted module path used to name the sub-layers.
        """
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = config.num_attention_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = config.num_key_value_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        # The fallback must divide by the *global* head count: dividing by
        # the per-rank ``num_heads`` would scale head_dim by tp_size.
        self.head_dim = config.head_dim or (
            self.hidden_size // self.total_num_heads
        )
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.dual_chunk_attention_config = getattr(
            config, "dual_chunk_attention_config", None
        )
        self.attn_output_gate = getattr(config, "attn_output_gate", True)

        # When the output gate is enabled the q projection is twice as wide:
        # each head carries its query followed by its gate.
        self.qkv_proj = QKVParallelLinear(
            config.hidden_size,
            self.head_dim,
            self.total_num_heads * (1 + self.attn_output_gate),
            self.total_num_kv_heads,
            bias=getattr(config, "qkv_bias", False),
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            config.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self.rotary_emb = get_rope(
            head_size=self.head_dim,
            max_position=config.max_position_embeddings,
            rope_parameters=config.rope_parameters,
            dual_chunk_attention_config=self.dual_chunk_attention_config,
        )

        # The extra keyword arguments are only passed when dual-chunk
        # attention is configured.
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
            **{
                "layer_idx": extract_layer_index(prefix),
                "dual_chunk_attention_config": self.dual_chunk_attention_config,
            }
            if self.dual_chunk_attention_config
            else {},
        )

        self.q_norm = Qwen3NextRMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = Qwen3NextRMSNorm(self.head_dim, eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        output: torch.Tensor,
        hidden_states: torch.Tensor,
    ):
        """Run attention and store the projected result in ``output``.

        Args:
            positions: Token positions handed to the rotary embedding.
            output: Pre-allocated buffer that receives the ``o_proj`` result
                (written in place via ``output[:]``).
            hidden_states: Input activations for this layer.
        """
        qkv, _ = self.qkv_proj(hidden_states)

        if self.attn_output_gate:
            q_gate, k, v = qkv.split(
                [self.q_size * 2, self.kv_size, self.kv_size], dim=-1
            )
            # Per head the projection holds [query | gate]; split them on the
            # last axis after exposing the head dimension.
            orig_shape = q_gate.shape[:-1]
            q_gate = q_gate.view(*orig_shape, self.num_heads, -1)
            q, gate = torch.chunk(q_gate, 2, dim=-1)
            q = q.reshape(*orig_shape, -1)
            gate = gate.reshape(*orig_shape, -1)
        else:
            q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

        # q/k norms operate per head, so expose the head axis, normalise,
        # then flatten back.
        q = self.q_norm(q.view(-1, self.num_heads, self.head_dim)).view(
            -1, self.num_heads * self.head_dim
        )
        k = self.k_norm(k.view(-1, self.num_kv_heads, self.head_dim)).view(
            -1, self.num_kv_heads * self.head_dim
        )

        q, k = self.rotary_emb(positions, q, k)

        attn_output = self.attn(q, k, v)

        if self.attn_output_gate:
            gate = torch.sigmoid(gate)
            attn_output = attn_output * gate

        output[:], _ = self.o_proj(attn_output)


class Qwen3NextDecoderLayer(nn.Module):
    """One Qwen3Next decoder layer: token mixer followed by an MLP.

    The token mixer is either a gated-delta-net linear attention layer or a
    full attention layer, chosen by ``layer_type``. The MLP is a sparse MoE
    block or a dense MLP, chosen from the config. Both sub-blocks can be
    rescaled by learned ``layer_scale`` parameters.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        layer_type: str,
        prefix: str = "",
    ) -> None:
        """Construct the layer.

        Args:
            vllm_config: Engine configuration; the model, cache, quantization
                and speculative configs are read from it.
            layer_type: ``"linear_attention"`` or ``"full_attention"``.
            prefix: Dotted module path; the layer index is parsed from it.

        Raises:
            ValueError: If ``layer_type`` is not one of the two known values.
        """
        super().__init__()

        model_config = vllm_config.model_config
        # Use the text config, matching Qwen3NextModel (which derives
        # ``layer_type`` from it) and Qwen3NextForCausalLM.
        config = model_config.hf_text_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        speculative_config = vllm_config.speculative_config

        self.layer_type = layer_type
        self.layer_idx = extract_layer_index(prefix)

        if self.layer_type == "linear_attention":
            self.linear_attn = Qwen3NextGatedDeltaNet(
                config,
                model_config=model_config,
                cache_config=cache_config,
                quant_config=quant_config,
                speculative_config=speculative_config,
                prefix=f"{prefix}.linear_attn",
            )
        elif self.layer_type == "full_attention":
            self.self_attn = Qwen3NextAttention(
                config,
                model_config=model_config,
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=f"{prefix}.self_attn",
            )
        else:
            raise ValueError(f"Invalid layer_type {self.layer_type}")

        # A layer is sparse when it is not listed in ``mlp_only_layers``, the
        # model has experts, and its 1-based index is a multiple of
        # ``decoder_sparse_step``.
        mlp_only_layers = (
            [] if not hasattr(config, "mlp_only_layers") else config.mlp_only_layers
        )
        if (self.layer_idx not in mlp_only_layers) and (
            config.num_experts > 0
            and (self.layer_idx + 1) % config.decoder_sparse_step == 0
        ):
            self.mlp = Qwen3NextSparseMoeBlock(
                vllm_config=vllm_config,
                prefix=f"{prefix}.mlp",
            )
        else:
            self.mlp = Qwen3NextMLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                prefix=f"{prefix}.mlp",
            )

        self.input_layernorm = Qwen3NextRMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )
        self.post_attention_layernorm = Qwen3NextRMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

        self.layer_scale = getattr(config, "layer_scale", False)
        if self.layer_scale:
            # Stored as offsets from 1: the forward pass multiplies by
            # ``scale + 1``, so zero-initialised parameters are an identity.
            self.attn_layer_scale = torch.nn.Parameter(
                torch.zeros(
                    1,
                    1,
                    config.hidden_size,
                ),
            )
            self.ffn_layer_scale = torch.nn.Parameter(
                torch.zeros(
                    1,
                    1,
                    config.hidden_size,
                ),
            )

    def forward(
        self,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
        positions: torch.Tensor | None = None,
        **kwargs: object,
    ):
        """Apply the token mixer and the MLP.

        Args:
            hidden_states: Input activations.
            residual: Running residual stream, or ``None`` for the first
                layer (in which case ``hidden_states`` seeds it).
            positions: Token positions; only passed to full attention.
            **kwargs: Ignored.

        Returns:
            ``(hidden_states, residual)`` for the next layer.

        Raises:
            ValueError: If ``self.layer_type`` is not a known value.
        """
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)

        # Both mixers write their result into a caller-provided buffer.
        self_attention_output = torch.empty_like(hidden_states)
        if self.layer_type == "linear_attention":
            self.linear_attn(
                hidden_states=hidden_states,
                output=self_attention_output,
            )
        elif self.layer_type == "full_attention":
            self.self_attn(
                hidden_states=hidden_states,
                output=self_attention_output,
                positions=positions,
            )
        else:
            raise ValueError("Invalid layer_type")
        hidden_states = self_attention_output

        if self.layer_scale:
            if len(hidden_states.shape) == 2:
                # 2-D activations: drop the scale's leading axis so it
                # broadcasts as (1, hidden).
                hidden_states = hidden_states * (
                    self.attn_layer_scale.to(hidden_states.dtype)[0] + 1
                )
            else:
                hidden_states = hidden_states * (
                    self.attn_layer_scale.to(hidden_states.dtype) + 1
                )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)

        if self.layer_scale:
            if len(hidden_states.shape) == 2:
                hidden_states = hidden_states * (
                    self.ffn_layer_scale.to(hidden_states.dtype)[0] + 1
                )
            else:
                assert len(hidden_states.shape) == len(self.ffn_layer_scale.shape), (
                    f"shape must be the same {len(hidden_states.shape)}, "
                    f"{len(self.ffn_layer_scale.shape)}"
                )
                hidden_states = hidden_states * (
                    self.ffn_layer_scale.to(hidden_states.dtype) + 1
                )

        return hidden_states, residual


@support_torch_compile
class Qwen3NextModel(nn.Module):
    """Qwen3Next backbone: embeddings, decoder layers and the final norm.

    Only the layers assigned to the current pipeline-parallel rank are run
    in ``forward``; the final norm is a ``PPMissingLayer`` placeholder on
    ranks other than the last one.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the embedding table, decoder layers and final norm.

        Args:
            vllm_config: Engine configuration.
            prefix: Dotted module path used to name the sub-layers.
        """
        super().__init__()

        config: Qwen3NextConfig = vllm_config.model_config.hf_text_config
        parallel_config = vllm_config.parallel_config

        eplb_config = parallel_config.eplb_config
        self.num_redundant_experts = eplb_config.num_redundant_experts

        self.config = config

        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
        )

        def get_layer(prefix: str):
            # The layer kind is looked up by the index parsed from the prefix.
            return Qwen3NextDecoderLayer(
                vllm_config,
                layer_type=config.layer_types[extract_layer_index(prefix)],
                prefix=prefix,
            )

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers"
        )
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size
        )

        if get_pp_group().is_last_rank:
            self.norm = Qwen3NextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer()

        # Layer indices whose inputs are additionally returned by forward();
        # empty by default.
        self.aux_hidden_state_layers: tuple[int, ...] = ()

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]:
        """Run the decoder layers owned by this pipeline rank.

        Args:
            input_ids: Token ids; used on the first rank when
                ``inputs_embeds`` is not given.
            positions: Token positions forwarded to every layer.
            intermediate_tensors: Hidden states and residual from the
                previous pipeline rank; required on non-first ranks.
            inputs_embeds: Optional precomputed embeddings.

        Returns:
            ``IntermediateTensors`` on non-last ranks. On the last rank, the
            normalised hidden states, or ``(hidden_states, aux_hidden_states)``
            when any auxiliary hidden states were collected.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.embed_input_ids(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        aux_hidden_states = []
        for layer_idx, layer in enumerate(
            islice(self.layers, self.start_layer, self.end_layer),
            start=self.start_layer,
        ):
            if layer_idx in self.aux_hidden_state_layers:
                # Capture the layer input with the residual folded back in.
                aux_hidden_states.append(
                    hidden_states + residual if residual is not None else hidden_states
                )
            hidden_states, residual = layer(
                positions=positions,
                hidden_states=hidden_states,
                residual=residual,
            )

        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )
        hidden_states, _ = self.norm(hidden_states, residual)
        if aux_hidden_states:
            return hidden_states, aux_hidden_states
        return hidden_states

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Return ``(param_name, weight_name, expert_id, shard_id)`` tuples."""
        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        return SharedFusedMoE.make_expert_params_mapping(
            self,
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=getattr(self.config, "num_experts", 0),
            num_redundant_experts=self.num_redundant_experts,
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Each checkpoint name is tried, in order, against the stacked
        (fused) parameter mappings, the expert mappings, and finally as a
        plain parameter name. The candidate name produced by a mapping is
        kept in a local variable and only replaces ``name`` once the weight
        has actually been loaded, so a skipped mapping never leaks a
        rewritten name into later mappings or into the warning message.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        Returns:
            The set of parameter names that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        expert_params_mapping = self.get_expert_mapping()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            if name.startswith("mtp."):
                continue

            # Remapping the name of FP8 kv-scale.
            if name.endswith("scale"):
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue

                # Expert weights are handled by the expert mapping below.
                if "mlp.experts" in name:
                    continue

                mapped_name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if mapped_name.endswith(".bias") and mapped_name not in params_dict:
                    continue
                # Skip layers on other devices.
                if is_pp_missing_parameter(mapped_name, self):
                    continue
                if mapped_name not in params_dict:
                    continue
                param = params_dict[mapped_name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                name = mapped_name
                break
            else:
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
                    mapped_name = name.replace(weight_name, param_name)
                    # Skip layers on other devices.
                    if is_pp_missing_parameter(mapped_name, self):
                        continue
                    # Covers extra GPTQ biases as well as any other name this
                    # module does not own.
                    if mapped_name not in params_dict:
                        continue
                    param = params_dict[mapped_name]
                    weight_loader = param.weight_loader
                    weight_loader(
                        param,
                        loaded_weight,
                        mapped_name,
                        shard_id=shard_id,
                        expert_id=expert_id,
                    )
                    name = mapped_name
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    if is_pp_missing_parameter(name, self):
                        continue
                    if name not in params_dict:
                        logger.warning_once(
                            f"Parameter {name} not found in params_dict, skip loading"
                        )
                        continue
                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
class QwenNextMixtureOfExperts(MixtureOfExperts):
    """MoE bookkeeping mixin for Qwen3Next models.

    Collects the expert modules of ``self.model.layers`` and mirrors the
    expert-count metadata of the sparse MoE blocks onto the model object.
    """

    def update_physical_experts_metadata(
        self,
        num_physical_experts: int,
        num_local_physical_experts: int,
    ) -> None:
        """Propagate new physical expert counts to every sparse MoE block.

        Args:
            num_physical_experts: New total number of physical experts.
            num_local_physical_experts: Local physical expert count; must
                equal the value already stored on the model.
        """
        assert self.num_local_physical_experts == num_local_physical_experts
        self.num_physical_experts = num_physical_experts
        self.num_local_physical_experts = num_local_physical_experts
        self.num_redundant_experts = num_physical_experts - self.num_logical_experts
        for layer in self.model.layers:
            # Same guard as set_moe_parameters(): entries that are not
            # decoder layers are skipped instead of being probed for ``.mlp``.
            if isinstance(layer, Qwen3NextDecoderLayer) and isinstance(
                layer.mlp, Qwen3NextSparseMoeBlock
            ):
                moe = layer.mlp
                moe.n_local_physical_experts = num_local_physical_experts
                moe.n_physical_experts = num_physical_experts
                moe.n_redundant_experts = self.num_redundant_experts
                moe.experts.update_expert_map()

    def set_moe_parameters(self):
        """Collect the MoE layers and copy their hyperparameters onto ``self``.

        Raises:
            RuntimeError: If no decoder layer with a sparse MoE block exists
                in ``self.model.layers``.
        """
        self.expert_weights = []

        self.moe_layers = []
        example_moe = None
        for layer in self.model.layers:
            if isinstance(layer, Qwen3NextDecoderLayer) and isinstance(
                layer.mlp, Qwen3NextSparseMoeBlock
            ):
                # The last sparse block found serves as the metadata source.
                example_moe = layer.mlp
                self.moe_layers.append(layer.mlp.experts)

        if example_moe is None:
            raise RuntimeError("No Qwen3Next layer found in the model.layers.")

        # Set MoE hyperparameters
        self.num_moe_layers = len(self.moe_layers)
        self.num_expert_groups = 1
        self.num_shared_experts = 0
        self.num_logical_experts = example_moe.n_logical_experts
        self.num_physical_experts = example_moe.n_physical_experts
        self.num_local_physical_experts = example_moe.n_local_physical_experts
        self.num_routed_experts = example_moe.n_routed_experts
        self.num_redundant_experts = example_moe.n_redundant_experts


1444
class Qwen3NextForCausalLM(
    nn.Module,
    HasInnerState,
    SupportsLoRA,
    SupportsPP,
    QwenNextMixtureOfExperts,
    IsHybrid,
):
    """Qwen3Next backbone with a language-model head."""

    # Fused parameter name -> checkpoint sub-module names packed into it.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": ["gate_proj", "up_proj"],
        "in_proj_qkvz": ["in_proj_qkvz"],
        "in_proj_ba": ["in_proj_ba"],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the backbone, LM head and logits processor.

        Args:
            vllm_config: Engine configuration.
            prefix: Dotted module path used to name the sub-modules.

        Raises:
            NotImplementedError: If ``cache_config.mamba_cache_mode`` is
                ``"all"``.
        """
        config = vllm_config.model_config.hf_text_config
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        cache_config = vllm_config.cache_config

        scheduler_config = vllm_config.scheduler_config
        if cache_config.mamba_cache_mode == "all":
            raise NotImplementedError(
                "Qwen3Next currently does not support 'all' prefix caching, "
                "please use '--mamba-cache-mode=align' instead"
            )
        self.quant_config = vllm_config.quant_config

        super().__init__()
        self.config = config
        self.scheduler_config = scheduler_config
        self.model = Qwen3NextModel(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )

        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

        # Set MoE hyperparameters
        self.set_moe_parameters()

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the token embeddings from the backbone."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ):
        """Run the backbone and return whatever it returns.

        Extra keyword arguments are accepted and ignored.
        """
        hidden_states = self.model(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )

        return hidden_states

    @classmethod
    def get_mamba_state_dtype_from_config(
        cls,
        vllm_config: "VllmConfig",
    ) -> tuple[torch.dtype, torch.dtype]:
        """Return the gated-delta-net state dtypes derived from the config."""
        return MambaStateDtypeCalculator.gated_delta_net_state_dtype(
            vllm_config.model_config.dtype,
            vllm_config.cache_config.mamba_cache_dtype,
            vllm_config.cache_config.mamba_ssm_cache_dtype,
        )

    @classmethod
    def get_mamba_state_shape_from_config(
        cls, vllm_config: "VllmConfig"
    ) -> tuple[tuple[int, int], tuple[int, int]]:
        """Return the gated-delta-net state shapes derived from the config."""
        parallel_config = vllm_config.parallel_config
        hf_config = vllm_config.model_config.hf_text_config
        tp_size = parallel_config.tensor_parallel_size
        # Without speculative decoding no extra speculative tokens are used.
        num_spec = (
            vllm_config.speculative_config.num_speculative_tokens
            if vllm_config.speculative_config
            else 0
        )
        return MambaStateShapeCalculator.gated_delta_net_state_shape(
            tp_size,
            hf_config.linear_num_key_heads,
            hf_config.linear_num_value_heads,
            hf_config.linear_key_head_dim,
            hf_config.linear_value_head_dim,
            hf_config.linear_conv_kernel_dim,
            num_spec,
        )

    @classmethod
    def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]:
        """Return the gated-delta-net state copy functions."""
        return MambaStateCopyFuncCalculator.gated_delta_net_state_copy_func()

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Project ``hidden_states`` through the LM head to logits."""
        return self.logits_processor(self.lm_head, hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, skipping names that start with ``mtp.``.

        Returns:
            The set of loaded names reported by ``AutoWeightsLoader``.
        """
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=["mtp."],
        )
        return loader.load_weights(weights)

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Return the backbone's expert parameter mapping."""
        return self.model.get_expert_mapping()


1568
1569
1570
1571
1572
def gdn_attention_core(
    mixed_qkv: torch.Tensor,
    b: torch.Tensor,
    a: torch.Tensor,
    core_attn_out: torch.Tensor,
    layer_name: str,
) -> None:
    """
    Custom-op entry point for the core attention computation.

    Covers only the convolution + recurrent attention part; the input and
    output projections are handled outside this op. The layer is looked up by
    ``layer_name`` in the current forward context's ``no_compile_layers`` and
    its ``_forward_core`` is invoked with the given tensors. Nothing is
    returned; ``core_attn_out`` is passed to the layer as the output tensor.
    """
    layer = get_forward_context().no_compile_layers[layer_name]
    layer._forward_core(mixed_qkv=mixed_qkv, b=b, a=a, core_attn_out=core_attn_out)


def gdn_attention_core_fake(
    mixed_qkv: torch.Tensor,
    b: torch.Tensor,
    a: torch.Tensor,
    core_attn_out: torch.Tensor,
    layer_name: str,
) -> None:
    """Fake implementation used by torch.compile.

    Takes the same arguments as ``gdn_attention_core``, performs no
    computation and returns ``None``.
    """
    return None


# Register the GDN core computation as a custom op. ``core_attn_out`` is
# declared as the argument the op mutates, and the fake implementation is
# supplied alongside the real one.
direct_register_custom_op(
    op_name="gdn_attention_core",
    op_func=gdn_attention_core,
    fake_impl=gdn_attention_core_fake,
    mutates_args=["core_attn_out"],
)


@triton.jit
def fused_gdn_gating_kernel(
    g,
    beta_output,
    A_log,
    a,
    b,
    dt_bias,
    seq_len,
    NUM_HEADS: tl.constexpr,
    beta: tl.constexpr,
    threshold: tl.constexpr,
    BLK_HEADS: tl.constexpr,
):
    # Each program handles one (axis-0, axis-1, head-block) tile of the grid.
    pid_batch = tl.program_id(0)
    pid_seq = tl.program_id(1)
    pid_blk = tl.program_id(2)

    heads = pid_blk * BLK_HEADS + tl.arange(0, BLK_HEADS)
    # Guard the tail block when NUM_HEADS is not a multiple of BLK_HEADS.
    in_range = heads < NUM_HEADS
    # Flat element offset shared by the a/b inputs and the g/beta_output outputs.
    flat = pid_batch * seq_len * NUM_HEADS + pid_seq * NUM_HEADS + heads

    # A_log and dt_bias are indexed by head only; a and b by the flat offset.
    a_log_vals = tl.load(A_log + heads, mask=in_range)
    a_vals = tl.load(a + flat, mask=in_range)
    b_vals = tl.load(b + flat, mask=in_range)
    bias_vals = tl.load(dt_bias + heads, mask=in_range)

    # If the model is loaded in fp16, without the .float() here, A might be -inf
    x = a_vals.to(tl.float32) + bias_vals.to(tl.float32)
    # softplus(x) with the input passed through unchanged above the threshold
    scaled = beta * x
    softplus_x = tl.where(
        scaled <= threshold, (1 / beta) * tl.log(1 + tl.exp(scaled)), x
    )
    gate = -tl.exp(a_log_vals.to(tl.float32)) * softplus_x
    tl.store(g + flat, gate.to(g.dtype.element_ty), mask=in_range)

    # beta_output = sigmoid(b), computed in float32 and cast to the output dtype
    sig = tl.sigmoid(b_vals.to(tl.float32))
    tl.store(
        beta_output + flat, sig.to(beta_output.dtype.element_ty), mask=in_range
    )
1643
1644
1645
1646
1647


def fused_gdn_gating(
    A_log: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    dt_bias: torch.Tensor,
    beta: float = 1.0,
    threshold: float = 20.0,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Compute the Gated Delta Net gate ``g`` and ``beta_output`` in one kernel.

    Equivalent formulas:
        g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
        beta_output = b.sigmoid()

    ``a`` must be 2-D; its shape is unpacked as ``(batch, num_heads)``.

    Returns:
        ``(g, beta_output)``, both of shape ``(1, batch, num_heads)``. ``g``
        is float32 on ``a``'s device; ``beta_output`` has ``b``'s dtype and
        device.

    TODO maybe use torch.compile to replace this triton kernel
    """
    # Number of heads handled by each kernel program.
    heads_per_block = 8
    n_rows, n_heads = a.shape
    # The kernel is launched with a single step along grid axis 1.
    steps = 1

    out_shape = (1, n_rows, n_heads)
    g = torch.empty(out_shape, dtype=torch.float32, device=a.device)
    beta_output = torch.empty(out_shape, dtype=b.dtype, device=b.device)

    launch_grid = (n_rows, steps, triton.cdiv(n_heads, heads_per_block))
    fused_gdn_gating_kernel[launch_grid](
        g,
        beta_output,
        A_log,
        a,
        b,
        dt_bias,
        steps,
        n_heads,
        beta,
        threshold,
        heads_per_block,
        num_warps=1,
    )
    return g, beta_output