qwen3_next.py 65.1 KB
Newer Older
1
2
3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Inference-only Qwen3Next model."""
4

5
from collections.abc import Iterable
6
from itertools import islice
7
8
9
10
11
12

import torch
from einops import rearrange
from torch import nn
from transformers.activations import ACT2FN

13
from vllm import envs
14
from vllm.compilation.decorators import support_torch_compile
15
16
17
18
19
20
21
22
23
24
25
26
27
28
from vllm.config import (
    CacheConfig,
    ModelConfig,
    VllmConfig,
    get_current_vllm_config,
)
from vllm.distributed import (
    divide,
    get_ep_group,
    get_pp_group,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    tensor_model_parallel_all_gather,
)
29
30
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
31
from vllm.model_executor.custom_op import CustomOp
32
from vllm.model_executor.layers.attention import Attention
33
from vllm.model_executor.layers.fla.ops import (
34
35
36
    chunk_gated_delta_rule as fla_chunk_gated_delta_rule,
)
from vllm.model_executor.layers.fla.ops import (
37
    fused_recurrent_gated_delta_rule_packed_decode,
38
    fused_sigmoid_gating_delta_rule_update,
39
)
40
from vllm.model_executor.layers.fla.ops.chunk import l2norm_fwd
41
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
42
43
44
45
from vllm.model_executor.layers.layernorm import (
    GemmaRMSNorm as Qwen3NextRMSNorm,
)
from vllm.model_executor.layers.layernorm import RMSNormGated
46
47
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
48
    MergedColumnParallelLinear,
49
50
51
52
    QKVParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
53
54
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.abstract import MambaBase
55
from vllm.model_executor.layers.mamba.mamba_mixer2 import mamba_v2_sharded_weight_loader
56
from vllm.model_executor.layers.mamba.mamba_utils import (
57
58
    MambaStateCopyFunc,
    MambaStateCopyFuncCalculator,
59
60
61
    MambaStateDtypeCalculator,
    MambaStateShapeCalculator,
)
62
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
63
64
65
    causal_conv1d_fn,
    causal_conv1d_update,
)
66
67
68
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
69
70
71
    ParallelLMHead,
    VocabParallelEmbedding,
)
72
from vllm.model_executor.model_loader.weight_utils import (
73
    default_weight_loader,
74
    maybe_remap_kv_scale_name,
75
76
    sharded_weight_loader,
)
77
from vllm.model_executor.models.qwen2_moe import Qwen2MoeMLP as Qwen3NextMLP
78
from vllm.model_executor.models.utils import sequence_parallel_chunk
79
80
81
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
82
from vllm.transformers_utils.configs.qwen3_next import Qwen3NextConfig
83
from vllm.triton_utils import tl, triton
84
85
86
87
88
from vllm.utils.multi_stream_utils import maybe_execute_in_parallel
from vllm.utils.torch_utils import (
    aux_stream,
    direct_register_custom_op,
)
89
from vllm.v1.attention.backend import AttentionMetadata
90
91
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata

92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
from .interfaces import (
    HasInnerState,
    IsHybrid,
    MixtureOfExperts,
    SupportsLoRA,
    SupportsPP,
)
from .utils import (
    AutoWeightsLoader,
    PPMissingLayer,
    extract_layer_index,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
108
109
110
111
112
113

# Module-level logger for this model file.
logger = init_logger(__name__)

# Type alias for a pair of cache tensors; presumably (key, value) — confirm at use sites.
KVCache = tuple[torch.Tensor, torch.Tensor]


114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
def fi_chunk_gated_delta_rule(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
    initial_state: torch.Tensor,
    output_final_state: bool,
    cu_seqlens: torch.LongTensor | None = None,
    use_qk_l2norm_in_kernel: bool = True,
):
    """Run FlashInfer's chunked gated delta-rule prefill kernel.

    Takes the same arguments as the FLA ``chunk_gated_delta_rule`` so the two
    implementations are interchangeable. Every input has its leading dim
    squeezed (dim 0; presumably a batch dim of 1) before the FlashInfer call,
    and the output gets that dim back via ``unsqueeze(0)``.

    Args:
        q, k, v: Query/key/value tensors with a leading dim to squeeze.
        g: Gate tensor; FlashInfer receives ``exp(g)`` in float32.
        beta: Per-token beta; passed to FlashInfer in float32.
        initial_state: Recurrent state; passed to FlashInfer in float32.
        output_final_state: Whether FlashInfer should return the final state.
        cu_seqlens: Cumulative sequence lengths for variable-length batches.
        use_qk_l2norm_in_kernel: If True, q and k are L2-normalised with
            ``l2norm_fwd`` before the call.

    Returns:
        ``(output, final_state)``. ``output`` has dim 0 restored;
        ``final_state`` is ``None`` unless ``output_final_state`` is True.
    """
    from flashinfer.gdn_prefill import (
        chunk_gated_delta_rule as chunk_gated_delta_rule_fi,
    )

    # The FlashInfer call below takes no normalisation flag, so q/k are
    # normalised up front when requested.
    if use_qk_l2norm_in_kernel:
        q, k = l2norm_fwd(q), l2norm_fwd(k)

    # Drop dim 0 and make every tensor contiguous for FlashInfer.
    q, k, v, g, beta = (t.squeeze(0).contiguous() for t in (q, k, v, g, beta))

    result = chunk_gated_delta_rule_fi(
        q=q,
        k=k,
        v=v,
        # FLA receives g as-is while FlashInfer receives exp(g) in fp32;
        # presumably FlashInfer expects the decay in linear, not log, space.
        g=torch.exp(g.to(torch.float32)),
        beta=beta.to(torch.float32),
        initial_state=initial_state.to(torch.float32),
        output_final_state=output_final_state,
        cu_seqlens=cu_seqlens,
    )

    # FlashInfer returns (output, state) when output_final_state=True and
    # only the output otherwise; restore dim 0 to match the FLA format.
    if not output_final_state:
        return result.unsqueeze(0), None
    output, final_state = result
    return output.unsqueeze(0), final_state
161
162
163
164
165
166


@CustomOp.register("chunk_gated_delta_rule")
class ChunkGatedDeltaRule(CustomOp):
    """Chunked gated delta-rule kernel used on the GDN prefill path.

    The implementation is chosen once, at construction, from
    ``additional_config["gdn_prefill_backend"]``. The value is stripped and
    lower-cased, and defaults to ``"auto"``:

    * ``"flashinfer"``: FlashInfer (``forward_cuda``) when the platform is
      CUDA with device capability 90. Otherwise a warning is logged and the
      Triton/FLA kernel is used.
    * ``"triton"``: always the Triton/FLA kernel (``forward_native``).
    * ``"auto"``: FlashInfer when supported, else Triton/FLA.

    Any other value behaves like ``"auto"`` and logs a warning.
    """

    def __init__(self) -> None:
        super().__init__()

        backend = (
            str(
                get_current_vllm_config().additional_config.get(
                    "gdn_prefill_backend", "auto"
                )
            )
            .strip()
            .lower()
        )
        supports_flashinfer = (
            current_platform.is_cuda() and current_platform.is_device_capability(90)
        )

        if backend == "flashinfer":
            use_flashinfer = supports_flashinfer
            if not use_flashinfer:
                logger.warning_once(
                    "GDN prefill backend 'flashinfer' is selected but "
                    "cannot use this kernel on the current platform. "
                    "Falling back to Triton/FLA."
                )
        elif backend == "triton":
            use_flashinfer = False
        else:
            if backend != "auto":
                # Without this, a typo such as "flash_infer" would silently
                # fall back to auto-selection.
                logger.warning_once(
                    f"Unknown GDN prefill backend {backend!r}; expected one "
                    "of 'auto', 'flashinfer' or 'triton'. Using 'auto'."
                )
            use_flashinfer = supports_flashinfer

        if use_flashinfer:
            logger.info_once("Using FlashInfer GDN prefill kernel", scope="local")
            # NOTE(review): this hint names `--gdn-prefill-backend`, while the
            # backend is read from additional_config["gdn_prefill_backend"];
            # confirm the flag maps to that key.
            logger.info_once(
                "FlashInfer GDN prefill kernel is JIT-compiled; first run may "
                "take a while to compile. Set `--gdn-prefill-backend triton` to "
                "avoid JIT compile time.",
                scope="local",
            )
        else:
            logger.info_once("Using Triton/FLA GDN prefill kernel", scope="local")

        # Bind the chosen implementation explicitly rather than relying on
        # the dispatch CustomOp presumably sets up in super().__init__().
        self._forward_method = (
            self.forward_cuda if use_flashinfer else self.forward_native
        )

    def forward_cuda(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        g: torch.Tensor,
        beta: torch.Tensor,
        initial_state: torch.Tensor,
        output_final_state: bool,
        cu_seqlens: torch.LongTensor | None = None,
        use_qk_l2norm_in_kernel: bool = True,
    ):
        """FlashInfer implementation; forwards to ``fi_chunk_gated_delta_rule``."""
        return fi_chunk_gated_delta_rule(
            q=q,
            k=k,
            v=v,
            g=g,
            beta=beta,
            initial_state=initial_state,
            output_final_state=output_final_state,
            cu_seqlens=cu_seqlens,
            use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
        )

    def forward_native(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        g: torch.Tensor,
        beta: torch.Tensor,
        initial_state: torch.Tensor,
        output_final_state: bool,
        cu_seqlens: torch.LongTensor | None = None,
        use_qk_l2norm_in_kernel: bool = True,
    ):
        """Triton/FLA implementation; forwards to the FLA chunk kernel."""
        return fla_chunk_gated_delta_rule(
            q=q,
            k=k,
            v=v,
            g=g,
            beta=beta,
            initial_state=initial_state,
            output_final_state=output_final_state,
            cu_seqlens=cu_seqlens,
            use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
        )


257
class Qwen3NextSparseMoeBlock(nn.Module):
    """Sparse mixture-of-experts MLP with an optional gated shared expert.

    Routed experts run through ``SharedFusedMoE``. When
    ``config.shared_expert_intermediate_size > 0``, a ``Qwen3NextMLP`` shared
    expert gated by ``shared_expert_gate`` is also run and its output is
    added to the routed output. Honors EPLB (redundant experts) and
    sequence-parallel MoE settings from the parallel config.
    """

    def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_text_config
        parallel_config = vllm_config.parallel_config
        quant_config = vllm_config.quant_config

        self.tp_size = get_tensor_model_parallel_world_size()

        self.ep_group = get_ep_group().device_group
        self.ep_rank = get_ep_group().rank_in_group
        self.ep_size = self.ep_group.size()
        self.n_routed_experts = config.num_experts

        self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe

        if self.tp_size > config.num_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {config.num_experts}."
            )

        # Load balancing settings. Read them from the same parallel config
        # that supplies ``enable_eplb`` instead of re-fetching the global
        # vllm config, which shadowed the ``vllm_config`` argument and could
        # mix settings from two different configs.
        eplb_config = parallel_config.eplb_config
        self.enable_eplb = parallel_config.enable_eplb

        self.n_logical_experts = self.n_routed_experts
        self.n_redundant_experts = eplb_config.num_redundant_experts
        self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts
        self.n_local_physical_experts = self.n_physical_experts // self.ep_size

        # Contiguous range of physical experts owned by this EP rank.
        self.physical_expert_start = self.ep_rank * self.n_local_physical_experts
        self.physical_expert_end = (
            self.physical_expert_start + self.n_local_physical_experts
        )

        self.gate = ReplicatedLinear(
            config.hidden_size,
            config.num_experts,
            bias=False,
            quant_config=None,
            prefix=f"{prefix}.gate",
        )

        self.shared_expert_gate = ReplicatedLinear(
            config.hidden_size,
            1,
            bias=False,
            quant_config=None,
            prefix=f"{prefix}.shared_expert_gate",
        )

        if config.shared_expert_intermediate_size > 0:
            self.shared_expert = Qwen3NextMLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.shared_expert_intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                reduce_results=False,
                expert_gate=self.shared_expert_gate,
                prefix=f"{prefix}.shared_expert",
            )
        else:
            self.shared_expert = None

        self.experts = SharedFusedMoE(
            shared_experts=self.shared_expert,
            gate=self.gate,
            num_experts=self.n_routed_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            reduce_results=False,
            renormalize=getattr(config, "norm_topk_prob", True),
            quant_config=quant_config,
            prefix=f"{prefix}.experts",
            enable_eplb=self.enable_eplb,
            num_redundant_experts=self.n_redundant_experts,
            is_sequence_parallel=self.is_sequence_parallel,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the MoE block and return a tensor of the input's shape.

        Args:
            hidden_states: Either 1D ``[hidden]`` (a single token) or 2D
                ``[num_tokens, hidden]``.

        Returns:
            The combined routed (+ shared) expert output, reshaped back to
            ``hidden_states.shape``.
        """
        orig_shape = hidden_states.shape
        hidden_dim = hidden_states.shape[-1]
        hidden_states = hidden_states.view(-1, hidden_dim)
        # Take the token count after flattening: unpacking a 1D ``shape``
        # into two names would raise.
        num_tokens = hidden_states.shape[0]

        if self.is_sequence_parallel:
            hidden_states = sequence_parallel_chunk(hidden_states)

        if self.experts.is_internal_router:
            # In this case, the gate/router runs inside the FusedMoE class
            final_hidden_states = self.experts(
                hidden_states=hidden_states, router_logits=hidden_states
            )
        else:
            # router_logits: (num_tokens, n_experts)
            router_logits, _ = self.gate(hidden_states)
            final_hidden_states = self.experts(
                hidden_states=hidden_states, router_logits=router_logits
            )

        if self.shared_expert is not None:
            # With a shared expert, the experts call yields a pair of outputs
            # that are summed here.
            final_hidden_states = final_hidden_states[0] + final_hidden_states[1]

        if self.is_sequence_parallel:
            final_hidden_states = tensor_model_parallel_all_gather(
                final_hidden_states, 0
            )
            # Drop any rows beyond the original token count after gathering.
            final_hidden_states = final_hidden_states[:num_tokens]
        elif self.tp_size > 1:
            final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(  # noqa E501
                final_hidden_states
            )

        return final_hidden_states.view(orig_shape)


class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
    @property
    def mamba_type(self) -> str:
        """Tag naming this layer's state type: gated delta-net attention."""
        return "gdn_attention"
381
382
383

    def get_state_dtype(self) -> tuple[torch.dtype, torch.dtype]:
        """Return the dtypes of this layer's two cached state tensors.

        Delegates to ``MambaStateDtypeCalculator.gated_delta_net_state_dtype``
        with the model dtype and the cache config's mamba cache dtypes.
        """
        cache_cfg = self.cache_config
        return MambaStateDtypeCalculator.gated_delta_net_state_dtype(
            self.model_config.dtype,
            cache_cfg.mamba_cache_dtype,
            cache_cfg.mamba_ssm_cache_dtype,
        )
388
389
390

    def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
        """Return the shapes of this layer's two cached state tensors.

        Delegates to ``MambaStateShapeCalculator.gated_delta_net_state_shape``
        using the TP size, head counts and dims, conv kernel size, and number
        of speculative tokens.
        """
        shape_args = (
            self.tp_size,
            self.num_k_heads,
            self.num_v_heads,
            self.head_k_dim,
            self.head_v_dim,
            self.conv_kernel_size,
            self.num_spec,
        )
        return MambaStateShapeCalculator.gated_delta_net_state_shape(*shape_args)
399
400
401
402

    def __init__(
        self,
        config: Qwen3NextConfig,
        vllm_config: VllmConfig,
        prefix: str = "",
        create_in_proj_qkvz: bool = True,
    ) -> None:
        """Build the gated delta-net (linear attention) mixer.

        Args:
            config: Model config. Supplies the ``linear_*`` head counts and
                dims, ``linear_conv_kernel_dim``, ``hidden_act`` and
                ``rms_norm_eps``.
            vllm_config: Engine config. Supplies the model, cache, quant and
                speculative configs.
            prefix: Module name prefix. Also used as this layer's key in
                ``compilation_config.static_forward_context``.
            create_in_proj_qkvz: When False, ``in_proj_qkvz`` is not created
                here; a subclass creates its own input projections instead
                (e.g. ``in_proj_qkv`` and ``in_proj_z`` in Qwen3.5 with LoRA).

        Raises:
            ValueError: If ``prefix`` is already registered in the static
                forward context.
        """
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()
        self.hidden_size = config.hidden_size
        self.num_v_heads = config.linear_num_value_heads
        self.num_k_heads = config.linear_num_key_heads
        self.head_k_dim = config.linear_key_head_dim
        self.head_v_dim = config.linear_value_head_dim
        self.key_dim = self.head_k_dim * self.num_k_heads
        self.value_dim = self.head_v_dim * self.num_v_heads

        self.conv_kernel_size = config.linear_conv_kernel_dim
        self.layer_idx = extract_layer_index(prefix)
        self.activation = config.hidden_act
        self.act = ACT2FN[config.hidden_act]
        self.layer_norm_epsilon = config.rms_norm_eps
        self.prefix = prefix

        # Auxiliary stream plus two events, consumed by _forward_in_proj to
        # (potentially) overlap the qkvz and ba input projections.
        self.aux_stream = aux_stream()
        self.events = (
            [torch.cuda.Event(), torch.cuda.Event()]
            if current_platform.is_cuda_alike()
            else [None, None]
        )

        self.config = config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        self.speculative_config = vllm_config.speculative_config
        self.num_spec = (
            self.speculative_config.num_speculative_tokens
            if self.speculative_config
            else 0
        )

        # QKV causal conv1d. Its weight covers the concatenated q, k and v
        # channels and gains a singleton dim 1 via the unsqueeze below.
        self.conv_dim = self.key_dim * 2 + self.value_dim
        self.conv1d = ColumnParallelLinear(
            input_size=self.conv_kernel_size,
            output_size=self.conv_dim,
            bias=False,
            prefix=f"{prefix}.conv1d",
        )
        self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)

        # Projection of the input hidden states. Qwen3-Next and Qwen3.5 have
        # different qkvz / ba checkpoint layouts, so both projections are
        # built through overridable factory methods. When create_in_proj_qkvz
        # is False (e.g. LoRA enabled in Qwen3.5), the subclass creates
        # in_proj_qkv and in_proj_z separately.
        if create_in_proj_qkvz:
            self.in_proj_qkvz = self.create_qkvz_proj(
                hidden_size=self.hidden_size,
                key_dim=self.key_dim,
                value_dim=self.value_dim,
                quant_config=quant_config,
                prefix=f"{prefix}.in_proj_qkvz",
            )
        # ba_proj doesn't support blockwise fp8 quantization.
        # NOTE(review): quant_config is still forwarded below; confirm the
        # factory / quant method handles that case.
        self.in_proj_ba = self.create_ba_proj(
            hidden_size=self.hidden_size,
            num_v_heads=self.num_v_heads,
            quant_config=quant_config,
            prefix=f"{prefix}.in_proj_ba",
        )

        # The conv1d loader is given one (size, 0, False) setting for each of
        # the q, k and v sections, presumably so each section is sharded
        # across TP ranks independently.
        query_key_settings = (self.key_dim, 0, False)
        value_settings = (self.value_dim, 0, False)

        delattr(self.conv1d.weight, "weight_loader")
        set_weight_attrs(
            self.conv1d.weight,
            {
                "weight_loader": mamba_v2_sharded_weight_loader(
                    [
                        query_key_settings,
                        query_key_settings,
                        value_settings,
                    ],
                    self.tp_size,
                    self.tp_rank,
                )
            },
        )

        # Per-value-head parameters, sharded along dim 0 across TP ranks.
        # Both are sized with ``divide`` so a non-divisible head count is
        # rejected consistently (dt_bias previously used a bare ``//``).
        num_local_v_heads = divide(self.num_v_heads, self.tp_size)
        self.dt_bias = nn.Parameter(
            torch.ones(num_local_v_heads),
        )
        self.A_log = nn.Parameter(
            torch.empty(
                num_local_v_heads,
                dtype=torch.float32,
            )
        )

        set_weight_attrs(self.A_log, {"weight_loader": sharded_weight_loader(0)})
        set_weight_attrs(self.dt_bias, {"weight_loader": sharded_weight_loader(0)})

        self.norm = RMSNormGated(
            self.head_v_dim,
            eps=self.layer_norm_epsilon,
            group_size=None,
            norm_before_gate=True,
            device=current_platform.current_device(),
        )

        self.out_proj = RowParallelLinear(
            self.value_dim,
            self.hidden_size,
            bias=False,
            input_is_parallel=True,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
        )

        self.chunk_gated_delta_rule = ChunkGatedDeltaRule()
        self.enable_packed_recurrent_decode = (
            envs.VLLM_ENABLE_FLA_PACKED_RECURRENT_DECODE
        )

        # Register this layer under its prefix. The torch.ops.vllm custom
        # ops receive self.prefix, presumably to look the layer up here.
        compilation_config = get_current_vllm_config().compilation_config
        if prefix in compilation_config.static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        compilation_config.static_forward_context[prefix] = self

538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
    def create_qkvz_proj(
        self,
        hidden_size: int,
        key_dim: int,
        value_dim: int,
        quant_config: QuantizationConfig | None,
        prefix: str,
    ) -> MergedColumnParallelLinear:
        """Build the fused ``in_proj_qkvz`` projection.

        q, k, v and z come from one projection declared as a single output
        shard of width ``2 * key_dim + 2 * value_dim``. Subclasses with a
        different checkpoint layout can override this factory.
        """
        fused_width = 2 * key_dim + 2 * value_dim
        return MergedColumnParallelLinear(
            input_size=hidden_size,
            output_sizes=[fused_width],
            bias=False,
            quant_config=quant_config,
            prefix=prefix,
        )

554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
    def create_ba_proj(
        self,
        hidden_size: int,
        num_v_heads: int,
        quant_config: QuantizationConfig | None,
        prefix: str,
    ) -> MergedColumnParallelLinear:
        """Build the fused ``in_proj_ba`` projection for the ``b``/``a`` gates.

        Qwen3-Next stores ``in_proj_ba`` as one fused weight with an
        interleaved GQA layout ``[b_g0, a_g0, b_g1, a_g1, ...]``, one group
        per key-head group. Declaring a single output shard of width
        ``2 * num_v_heads`` makes ColumnParallel sharding preserve that
        interleaving across TP ranks. Qwen3.5 overrides this with
        ``[num_v_heads, num_v_heads]`` because its checkpoint has separate
        ``in_proj_b`` and ``in_proj_a`` weights.
        """
        single_interleaved_shard = [2 * num_v_heads]
        return MergedColumnParallelLinear(
            input_size=hidden_size,
            output_sizes=single_interleaved_shard,
            bias=False,
            quant_config=quant_config,
            prefix=prefix,
        )

576
577
    def fix_query_key_value_ordering(
        self,
        mixed_qkvz: torch.Tensor,
        mixed_ba: torch.Tensor,
    ):
        """
        Split the fused input projections into q, k, v, z, b and a.

        Both inputs are laid out group-by-group over the local key heads
        (``ng = num_k_heads // tp_size`` groups):

        * each ``mixed_qkvz`` group is ``[q | k | v | z]`` with widths
          ``head_k_dim, head_k_dim, r * head_v_dim, r * head_v_dim``;
        * each ``mixed_ba`` group is ``[b | a]`` with widths ``r, r``;

        where ``r = num_v_heads // num_k_heads``. The split along dim 2 and
        the ``size(0)`` reshapes assume 2D ``[num_tokens, features]`` inputs.

        Returns:
            ``(query, key, value, z, b, a)``:
            ``query``/``key``: ``[num_tokens, ng, head_k_dim]``;
            ``value``/``z``: ``[num_tokens, num_v_heads // tp_size, head_v_dim]``;
            ``b``/``a``: ``[num_tokens, num_v_heads // tp_size]``.
        """
        num_local_k_heads = self.num_k_heads // self.tp_size
        v_heads_per_k_head = self.num_v_heads // self.num_k_heads

        new_tensor_shape_qkvz = mixed_qkvz.size()[:-1] + (
            num_local_k_heads,
            (
                self.head_k_dim
                + self.head_k_dim
                + (self.head_v_dim + self.head_v_dim)
                * self.num_v_heads
                // self.num_k_heads
            ),
        )
        # Take the leading dims from ``mixed_ba`` itself rather than from
        # ``mixed_qkvz``; they only coincide because both are projections
        # of the same tokens.
        new_tensor_shape_ba = mixed_ba.size()[:-1] + (
            num_local_k_heads,
            2 * self.num_v_heads // self.num_k_heads,
        )

        mixed_qkvz = mixed_qkvz.view(*new_tensor_shape_qkvz)
        mixed_ba = mixed_ba.view(*new_tensor_shape_ba)

        split_arg_list_qkvz = [
            self.head_k_dim,
            self.head_k_dim,
            v_heads_per_k_head * self.head_v_dim,
            v_heads_per_k_head * self.head_v_dim,
        ]
        split_arg_list_ba = [v_heads_per_k_head, v_heads_per_k_head]

        # [tokens, ng, hk + hk + r*hv + r*hv]
        #   --> [tokens, ng, hk] x2, [tokens, ng, r*hv] x2
        # [tokens, ng, r + r] --> [tokens, ng, r] x2
        (query, key, value, z) = torch.split(mixed_qkvz, split_arg_list_qkvz, dim=2)
        (b, a) = torch.split(mixed_ba, split_arg_list_ba, dim=2)

        # [tokens, ng, r*hv] -> [tokens, ng*r, hv]; [tokens, ng, r] -> [tokens, ng*r]
        num_local_v_heads = self.num_v_heads // self.tp_size
        value = value.reshape(value.size(0), -1, self.head_v_dim)
        z = z.reshape(z.size(0), -1, self.head_v_dim)
        b = b.reshape(b.size(0), num_local_v_heads)
        a = a.reshape(a.size(0), num_local_v_heads)

        return query, key, value, z, b, a

    def rearrange_mixed_qkv(self, mixed_qkv):
        """Split a packed ``[tokens, q | k | v]`` tensor into per-head q, k, v.

        Each part gets a leading dim of 1 and its last dim is unflattened
        into heads: q and k become ``[1, tokens, heads, head_k_dim]`` and v
        becomes ``[1, tokens, heads, head_v_dim]``. All three are returned
        contiguous. A ``None`` input yields ``(None, None, None)``.
        """
        if mixed_qkv is None:
            return None, None, None

        qk_width = self.key_dim // self.tp_size
        v_width = self.value_dim // self.tp_size
        query, key, value = torch.split(
            mixed_qkv, [qk_width, qk_width, v_width], dim=-1
        )

        query = rearrange(query, "l (h d) -> 1 l h d", d=self.head_k_dim)
        key = rearrange(key, "l (h d) -> 1 l h d", d=self.head_k_dim)
        value = rearrange(value, "l (h d) -> 1 l h d", d=self.head_v_dim)
        return query.contiguous(), key.contiguous(), value.contiguous()
645
646
647
648
649
650

    def forward(
        self,
        hidden_states: torch.Tensor,
        output: torch.Tensor,
    ):
        """
        Run the gated delta-net mixer, writing into ``output[:num_tokens]``.

        Three stages:
        1. Input projection via ``torch.ops.vllm.gdn_in_proj``, followed by
           splitting into q, k, v, z, b and a.
        2. Core attention via ``torch.ops.vllm.gdn_attention_core``, which
           fills a zero-initialised output buffer.
        3. Gated RMS norm with ``z`` and the output projection.
        """
        num_tokens = hidden_states.size(0)

        # ---- Stage 1: input projection ---------------------------------
        qkvz_width = sum(self.in_proj_qkvz.output_sizes) // self.tp_size
        ba_width = sum(self.in_proj_ba.output_sizes) // self.tp_size
        projected_qkvz, projected_ba = torch.ops.vllm.gdn_in_proj(
            hidden_states, qkvz_width, ba_width, self.prefix
        )
        query, key, value, z, b, a = self.fix_query_key_value_ordering(
            projected_qkvz, projected_ba
        )
        flat_parts = [
            rearrange(part, "l p d -> l (p d)") for part in (query, key, value)
        ]
        mixed_qkv = torch.cat(flat_parts, dim=-1)

        # ---- Stage 2: core attention (custom op) -----------------------
        # Note: we should not use torch.empty here like other attention backends,
        # see discussions in https://github.com/vllm-project/vllm/pull/28182
        core_attn_out = torch.zeros(
            (num_tokens, self.num_v_heads // self.tp_size, self.head_v_dim),
            dtype=hidden_states.dtype,
            device=hidden_states.device,
        )
        torch.ops.vllm.gdn_attention_core(mixed_qkv, b, a, core_attn_out, self.prefix)

        # ---- Stage 3: gated norm + output projection -------------------
        gate_shape = z.shape
        # The norm operates on 2D tensors; restore z's shape afterwards.
        normed = self.norm(
            core_attn_out.reshape(-1, core_attn_out.shape[-1]),
            z.reshape(-1, z.shape[-1]),
        )
        normed = rearrange(normed.reshape(gate_shape), "... h d -> ... (h d)")
        output[:num_tokens], _ = self.out_proj(normed)

707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
    def _warmup_prefill_kernels(self, mixed_qkv: torch.Tensor) -> None:
        """Force GDN prefill kernel autotuning while GPU memory is still free.

        In a V1 profile run ``attn_metadata`` is ``None``, so
        ``_forward_core`` returns early. The autotuned kernels behind
        ``chunk_gated_delta_rule`` (e.g. ``solve_tril``,
        ``chunk_scaled_dot_kkt``) therefore never run. vLLM then gives most
        of the remaining GPU memory to the KV cache, and the first real
        request can OOM when the autotuner benchmarks candidate configs.

        To prevent that, this pushes tiny dummy batches through
        ``chunk_gated_delta_rule`` during profiling. Autotune results are
        cached globally, so only the first layer pays the benchmarking cost.
        Each layer instance runs this warmup at most once.

        ``chunk_fwd_kernel_o`` derives ``BT = min(64, max(16,
        next_power_of_2(T)))`` from the sequence length and autotunes on
        ``BT``. The warmup therefore uses T = 16, 32 and 64 to cover every
        ``BT``. The other kernels use a fixed ``BT = chunk_size`` (64).

        Only the prefill (chunked) path is warmed. The decode path
        (``fused_sigmoid_gating_delta_rule_update``) has fixed kernel
        parameters and does no autotuning.

        Failures are logged with a traceback and otherwise ignored.
        """
        if hasattr(self, "_prefill_kernels_warmed_up"):
            return
        self._prefill_kernels_warmed_up = True

        device = mixed_qkv.device
        act_opts = {"device": device, "dtype": mixed_qkv.dtype}
        local_k_heads = self.num_k_heads // self.tp_size
        local_v_heads = self.num_v_heads // self.tp_size
        recurrent_state_dtype = self.get_state_dtype()[1]

        # NOTE(review): g and beta are drawn with randn in the activation
        # dtype. Triton autotune cache keys can include argument dtypes, so
        # confirm these match the real prefill call (e.g. if g arrives as
        # float32 there, some kernels may still autotune on first use).

        # seq_len 16 / 32 / 64 -> chunk_fwd_kernel_o BT of 16 / 32 / 64.
        for seq_len in (16, 32, 64):
            q = torch.randn(1, seq_len, local_k_heads, self.head_k_dim, **act_opts)
            k = torch.randn(1, seq_len, local_k_heads, self.head_k_dim, **act_opts)
            v = torch.randn(1, seq_len, local_v_heads, self.head_v_dim, **act_opts)
            g = torch.randn(1, seq_len, local_v_heads, **act_opts)
            beta = torch.randn(1, seq_len, local_v_heads, **act_opts)
            state = torch.zeros(
                1,
                local_v_heads,
                self.head_v_dim,
                self.head_k_dim,
                device=device,
                dtype=recurrent_state_dtype,
            )
            cu_seqlens = torch.tensor([0, seq_len], device=device, dtype=torch.long)

            try:
                self.chunk_gated_delta_rule(
                    q=q,
                    k=k,
                    v=v,
                    g=g,
                    beta=beta,
                    initial_state=state,
                    output_final_state=False,
                    cu_seqlens=cu_seqlens,
                    use_qk_l2norm_in_kernel=True,
                )
            except Exception:
                logger.warning(
                    "GDN prefill kernel warmup (T=%d) failed for "
                    "layer %s. First inference may OOM due to "
                    "autotuner.",
                    seq_len,
                    self.prefix,
                    exc_info=True,
                )
            else:
                logger.debug(
                    "GDN prefill kernel warmup (T=%d) completed for layer %s",
                    seq_len,
                    self.prefix,
                )
            finally:
                del q, k, v, g, beta, state, cu_seqlens

        torch.accelerator.empty_cache()

802
803
804
805
806
807
808
809
810
811
812
813
    def _forward_in_proj(
        self, hidden_states: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply the two input projections of the gated delta net.

        ``in_proj_qkvz`` and ``in_proj_ba`` both consume ``hidden_states``.
        They are dispatched through ``maybe_execute_in_parallel`` together
        with ``self.events[0]``, ``self.events[1]`` and ``self.aux_stream``
        (presumably so the two projections can overlap on an auxiliary
        stream; see that helper for the exact scheduling).

        Returns:
            A pair ``(qkvz, ba)`` holding element ``[0]`` of each linear
            layer's output tuple.
        """

        def run_qkvz() -> torch.Tensor:
            return self.in_proj_qkvz(hidden_states)[0]

        def run_ba() -> torch.Tensor:
            return self.in_proj_ba(hidden_states)[0]

        first_event, second_event = self.events[0], self.events[1]
        qkvz, ba = maybe_execute_in_parallel(
            run_qkvz, run_ba, first_event, second_event, self.aux_stream
        )
        return qkvz, ba

814
    def _forward_core(
        self,
        mixed_qkv: torch.Tensor,
        b: torch.Tensor,
        a: torch.Tensor,
        core_attn_out: torch.Tensor,
    ):
        """Run the causal conv + gated delta rule for the current batch.

        Looks up this layer's ``GDNAttentionMetadata`` in the forward
        context. It then updates the conv and SSM state caches held in
        ``self.kv_cache[0]`` and writes the output for the first
        ``num_actual_tokens`` tokens into ``core_attn_out`` in place.

        Tokens of speculative-decode sequences ("spec") and of regular
        prefill/decode sequences ("non-spec") go through separate kernel
        calls. Their outputs are scattered back into token order at the end.

        Args:
            mixed_qkv: Projected, not yet convolved q/k/v; dim 0 is tokens.
            b: Raw gating input; dim 0 is tokens.
            a: Raw gating input; dim 0 is tokens.
                ``fused_gdn_gating(A_log, a, b, dt_bias)`` turns ``a``/``b``
                into ``(g, beta)``.
            core_attn_out: Output buffer, written in place.

        With no attention metadata (profile run), this only warms up the
        prefill kernels and returns.
        """
        forward_context = get_forward_context()
        attn_metadata: AttentionMetadata = forward_context.attn_metadata

        if attn_metadata is None:
            # V1 profile run — warm up prefill kernels so that
            # autotuning completes before KV cache allocation.
            self._warmup_prefill_kernels(mixed_qkv)
            return

        assert isinstance(attn_metadata, dict)
        attn_metadata = attn_metadata[self.prefix]
        assert isinstance(attn_metadata, GDNAttentionMetadata)

        # Pure non-spec decode batch: use the packed decode fast path.
        if (
            self.enable_packed_recurrent_decode
            and attn_metadata.spec_sequence_masks is None
            and attn_metadata.num_prefills == 0
            and attn_metadata.num_decodes > 0
        ):
            return self._forward_core_decode_non_spec(
                mixed_qkv=mixed_qkv,
                b=b,
                a=a,
                core_attn_out=core_attn_out,
                attn_metadata=attn_metadata,
            )

        has_initial_state = attn_metadata.has_initial_state
        spec_query_start_loc = attn_metadata.spec_query_start_loc
        non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc
        spec_sequence_masks = attn_metadata.spec_sequence_masks
        spec_token_indx = attn_metadata.spec_token_indx
        non_spec_token_indx = attn_metadata.non_spec_token_indx
        spec_state_indices_tensor = attn_metadata.spec_state_indices_tensor  # noqa: E501
        non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor  # noqa: E501
        self_kv_cache = self.kv_cache[0]
        conv_state = self_kv_cache[0].transpose(-1, -2)
        ssm_state = self_kv_cache[1]
        num_actual_tokens = attn_metadata.num_actual_tokens
        num_accepted_tokens = attn_metadata.num_accepted_tokens

        # Drop padding tokens beyond the real batch.
        mixed_qkv = mixed_qkv[:num_actual_tokens]
        b = b[:num_actual_tokens]
        a = a[:num_actual_tokens]

        # 1. Convolution sequence transformation
        conv_weights = self.conv1d.weight.view(
            self.conv1d.weight.size(0), self.conv1d.weight.size(2)
        )

        # Split tokens into the spec / non-spec groups. The gating inputs a/b
        # are split together with q/k/v. ``fused_sigmoid_gating_delta_rule_update``
        # receives one ``cu_seqlens`` for q/k/v and a/b alike, so a/b must be
        # packed in the same token order as the q/k/v they are paired with.
        # The prefill path below likewise index_selects g/beta by token.
        if spec_sequence_masks is not None:
            if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0:
                # Every token belongs to a spec sequence.
                mixed_qkv_spec = mixed_qkv
                mixed_qkv_non_spec = None
                a_spec, b_spec = a, b
            else:
                mixed_qkv_spec = mixed_qkv.index_select(0, spec_token_indx)
                mixed_qkv_non_spec = mixed_qkv.index_select(0, non_spec_token_indx)
                a_spec = a.index_select(0, spec_token_indx)
                b_spec = b.index_select(0, spec_token_indx)
        else:
            mixed_qkv_spec = None
            mixed_qkv_non_spec = mixed_qkv
            a_spec, b_spec = None, None

        # 1.1: Process the multi-query part
        if spec_sequence_masks is not None:
            mixed_qkv_spec = causal_conv1d_update(
                mixed_qkv_spec,
                conv_state,
                conv_weights,
                self.conv1d.bias,
                self.activation,
                conv_state_indices=spec_state_indices_tensor[:, 0][
                    : attn_metadata.num_spec_decodes
                ],
                num_accepted_tokens=num_accepted_tokens,
                query_start_loc=spec_query_start_loc,
                max_query_len=spec_state_indices_tensor.size(-1),
                validate_data=False,
            )

        # 1.2: Process the remaining part
        if attn_metadata.num_prefills > 0:
            mixed_qkv_non_spec_T = mixed_qkv_non_spec.transpose(0, 1)
            # - "cache_indices" updates the conv_state cache in positions
            #   pointed to by "state_indices_tensor"
            mixed_qkv_non_spec = causal_conv1d_fn(
                mixed_qkv_non_spec_T,
                conv_weights,
                self.conv1d.bias,
                activation=self.activation,
                conv_states=conv_state,
                has_initial_state=has_initial_state,
                cache_indices=non_spec_state_indices_tensor,
                query_start_loc=non_spec_query_start_loc,
                metadata=attn_metadata,
            ).transpose(0, 1)
        elif attn_metadata.num_decodes > 0:
            mixed_qkv_non_spec = causal_conv1d_update(
                mixed_qkv_non_spec,
                conv_state,
                conv_weights,
                self.conv1d.bias,
                self.activation,
                conv_state_indices=non_spec_state_indices_tensor[
                    : attn_metadata.num_actual_tokens
                ],
                validate_data=True,
            )
        else:
            mixed_qkv_non_spec = None

        query_spec, key_spec, value_spec = self.rearrange_mixed_qkv(mixed_qkv_spec)
        query_non_spec, key_non_spec, value_non_spec = self.rearrange_mixed_qkv(
            mixed_qkv_non_spec
        )

        # Prefill gates the full a/b once, then keeps the non-spec tokens.
        if attn_metadata.num_prefills > 0:
            g, beta = fused_gdn_gating(self.A_log, a, b, self.dt_bias)
            if spec_sequence_masks is not None:
                g_non_spec = g.index_select(1, non_spec_token_indx)
                beta_non_spec = beta.index_select(1, non_spec_token_indx)
            else:
                g_non_spec = g
                beta_non_spec = beta
        else:
            g_non_spec = None
            beta_non_spec = None

        # 2. Recurrent attention

        # 2.1: Process the multi-query part
        if spec_sequence_masks is not None:
            core_attn_out_spec, last_recurrent_state = (
                fused_sigmoid_gating_delta_rule_update(
                    A_log=self.A_log,
                    a=a_spec,
                    b=b_spec,
                    dt_bias=self.dt_bias,
                    q=query_spec,
                    k=key_spec,
                    v=value_spec,
                    initial_state=ssm_state,
                    inplace_final_state=True,
                    cu_seqlens=spec_query_start_loc[
                        : attn_metadata.num_spec_decodes + 1
                    ],
                    ssm_state_indices=spec_state_indices_tensor,
                    num_accepted_tokens=num_accepted_tokens,
                    use_qk_l2norm_in_kernel=True,
                )
            )
        else:
            core_attn_out_spec, last_recurrent_state = None, None

        # 2.2: Process the remaining part
        if attn_metadata.num_prefills > 0:
            initial_state = ssm_state[non_spec_state_indices_tensor].contiguous()
            # Sequences without cached history start from a zero state.
            initial_state[~has_initial_state, ...] = 0
            (
                core_attn_out_non_spec,
                last_recurrent_state,
            ) = self.chunk_gated_delta_rule(
                q=query_non_spec,
                k=key_non_spec,
                v=value_non_spec,
                g=g_non_spec,
                beta=beta_non_spec,
                initial_state=initial_state,
                output_final_state=True,
                cu_seqlens=non_spec_query_start_loc,
                use_qk_l2norm_in_kernel=True,
            )
            # Init cache
            ssm_state[non_spec_state_indices_tensor] = last_recurrent_state.to(
                ssm_state.dtype
            )
        elif attn_metadata.num_decodes > 0:
            if spec_sequence_masks is not None:
                # Mixed spec/non-spec batch: keep only the non-spec tokens'
                # gating inputs so they line up with query_non_spec & co.
                a_non_spec = a.index_select(0, non_spec_token_indx)
                b_non_spec = b.index_select(0, non_spec_token_indx)
            else:
                a_non_spec, b_non_spec = a, b
            core_attn_out_non_spec, last_recurrent_state = (
                fused_sigmoid_gating_delta_rule_update(
                    A_log=self.A_log,
                    a=a_non_spec,
                    b=b_non_spec,
                    dt_bias=self.dt_bias,
                    q=query_non_spec,
                    k=key_non_spec,
                    v=value_non_spec,
                    initial_state=ssm_state,
                    inplace_final_state=True,
                    cu_seqlens=non_spec_query_start_loc[
                        : attn_metadata.num_decodes + 1
                    ],
                    ssm_state_indices=non_spec_state_indices_tensor,
                    use_qk_l2norm_in_kernel=True,
                )
            )
        else:
            core_attn_out_non_spec, last_recurrent_state = None, None

        # 3. Merge core attention output back into original token order
        if spec_sequence_masks is not None and core_attn_out_non_spec is not None:
            merged_out = torch.empty(
                (1, num_actual_tokens, *core_attn_out_spec.shape[2:]),
                dtype=core_attn_out_non_spec.dtype,
                device=core_attn_out_non_spec.device,
            )
            merged_out.index_copy_(1, spec_token_indx, core_attn_out_spec)
            merged_out.index_copy_(1, non_spec_token_indx, core_attn_out_non_spec)
            core_attn_out[:num_actual_tokens] = merged_out.squeeze(0)
        elif spec_sequence_masks is not None:
            core_attn_out[:num_actual_tokens] = core_attn_out_spec.squeeze(0)
        else:
            core_attn_out[:num_actual_tokens] = core_attn_out_non_spec.squeeze(0)
1031

1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
    def _forward_core_decode_non_spec(
        self,
        mixed_qkv: torch.Tensor,
        b: torch.Tensor,
        a: torch.Tensor,
        core_attn_out: torch.Tensor,
        attn_metadata: GDNAttentionMetadata,
    ):
        """
        Core attention computation with a packed non-spec decode fast path.

        Steps:
        - Trim inputs to ``attn_metadata.num_actual_tokens``.
        - Run ``causal_conv1d_update`` against the conv-state cache slots
          given by ``non_spec_state_indices_tensor``.
        - Call ``fused_recurrent_gated_delta_rule_packed_decode`` once on the
          still-packed convolved q/k/v.

        A view of ``core_attn_out`` is passed as the kernel's ``out``
        argument. The result is produced only through that buffer; nothing
        is returned.
        """
        non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor  # noqa: E501
        self_kv_cache = self.kv_cache[0]
        conv_state = self_kv_cache[0].transpose(-1, -2)
        ssm_state = self_kv_cache[1]
        num_actual_tokens = attn_metadata.num_actual_tokens

        # Drop padding tokens beyond the real batch.
        mixed_qkv = mixed_qkv[:num_actual_tokens]
        b = b[:num_actual_tokens]
        a = a[:num_actual_tokens]
        # One state slot per token; presumably one token per decode sequence.
        state_indices = non_spec_state_indices_tensor[:num_actual_tokens]

        conv_weights = self.conv1d.weight.view(
            self.conv1d.weight.size(0), self.conv1d.weight.size(2)
        )
        mixed_qkv_non_spec = causal_conv1d_update(
            mixed_qkv,
            conv_state,
            conv_weights,
            self.conv1d.bias,
            self.activation,
            conv_state_indices=state_indices,
            validate_data=False,
        )
        # View of core_attn_out with an extra size-1 axis at dim 1; the kernel
        # writes its output through it.
        out_buf = core_attn_out[:num_actual_tokens].unsqueeze(1)
        fused_recurrent_gated_delta_rule_packed_decode(
            mixed_qkv=mixed_qkv_non_spec,
            a=a,
            b=b,
            A_log=self.A_log,
            dt_bias=self.dt_bias,
            scale=self.head_k_dim**-0.5,
            initial_state=ssm_state,
            out=out_buf,
            ssm_state_indices=state_indices,
            use_qk_l2norm_in_kernel=True,
        )

1080
1081
1082
1083
1084

class Qwen3NextAttention(nn.Module):
    """Full (softmax) attention layer of Qwen3-Next.

    Components:
    - Tensor-parallel QKV projection.
    - Per-head RMSNorm on q and k.
    - Rotary embedding.
    - vLLM ``Attention``.
    - Optional sigmoid output gate, whose logits come out of the same QKV
      projection when ``attn_output_gate`` is enabled.
    """

    def __init__(
        self,
        config: Qwen3NextConfig,
        model_config: ModelConfig | None = None,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = config.num_attention_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = config.num_key_value_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        # The fallback must use the *global* head count. ``num_heads`` is
        # already divided by tp_size, so hidden_size // num_heads would make
        # head_dim tp_size times too large.
        self.head_dim = getattr(config, "head_dim", None) or (
            self.hidden_size // self.total_num_heads
        )
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.dual_chunk_attention_config = getattr(
            config, "dual_chunk_attention_config", None
        )
        self.attn_output_gate = getattr(config, "attn_output_gate", True)

        # With the output gate enabled, the q part of the projection carries
        # twice as many "heads": each head's query plus its gate.
        self.qkv_proj = QKVParallelLinear(
            config.hidden_size,
            self.head_dim,
            self.total_num_heads * (1 + self.attn_output_gate),
            self.total_num_kv_heads,
            bias=getattr(config, "qkv_bias", False),
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            config.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self.rotary_emb = get_rope(
            head_size=self.head_dim,
            max_position=config.max_position_embeddings,
            rope_parameters=config.rope_parameters,
            dual_chunk_attention_config=self.dual_chunk_attention_config,
        )

        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
            **{
                "layer_idx": extract_layer_index(prefix),
                "dual_chunk_attention_config": self.dual_chunk_attention_config,
            }
            if self.dual_chunk_attention_config
            else {},
        )

        self.q_norm = Qwen3NextRMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = Qwen3NextRMSNorm(self.head_dim, eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        output: torch.Tensor,
        hidden_states: torch.Tensor,
    ):
        """Compute attention for ``hidden_states`` and write it into ``output``.

        The result is assigned in place via ``output[:] = ...``; nothing is
        returned.
        """
        qkv, _ = self.qkv_proj(hidden_states)

        if self.attn_output_gate:
            q_gate, k, v = qkv.split(
                [self.q_size * 2, self.kv_size, self.kv_size], dim=-1
            )
            # Per head, the q slice is [query | gate]: view by head, then
            # split the last dim in half.
            orig_shape = q_gate.shape[:-1]
            q_gate = q_gate.view(*orig_shape, self.num_heads, -1)
            q, gate = torch.chunk(q_gate, 2, dim=-1)
            q = q.reshape(*orig_shape, -1)
            gate = gate.reshape(*orig_shape, -1)
        else:
            q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

        # Per-head RMSNorm on q and k before rotary embedding.
        q = self.q_norm(q.view(-1, self.num_heads, self.head_dim)).view(
            -1, self.num_heads * self.head_dim
        )
        k = self.k_norm(k.view(-1, self.num_kv_heads, self.head_dim)).view(
            -1, self.num_kv_heads * self.head_dim
        )

        q, k = self.rotary_emb(positions, q, k)

        attn_output = self.attn(q, k, v)

        if self.attn_output_gate:
            gate = torch.sigmoid(gate)
            attn_output = attn_output * gate

        output[:], _ = self.o_proj(attn_output)


class Qwen3NextDecoderLayer(nn.Module):
    """One Qwen3-Next decoder block.

    Token mixer:
    - ``Qwen3NextGatedDeltaNet`` when ``layer_type == "linear_attention"``.
    - ``Qwen3NextAttention`` when ``layer_type == "full_attention"``.

    The token mixer is followed by a sparse-MoE or dense MLP. Each sub-layer
    is preceded by a residual-aware RMSNorm, called as ``norm(x, residual)``
    and returning ``(x, residual)``. When ``config.layer_scale`` is set, each
    sub-layer output is multiplied by ``(1 + learned_scale)``.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        layer_type: str,
        prefix: str = "",
    ) -> None:
        super().__init__()

        config = vllm_config.model_config.hf_config
        model_config = vllm_config.model_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.layer_type = layer_type
        self.layer_idx = extract_layer_index(prefix)

        if self.layer_type == "linear_attention":
            self.linear_attn = Qwen3NextGatedDeltaNet(
                config,
                vllm_config=vllm_config,
                prefix=f"{prefix}.linear_attn",
            )
        elif self.layer_type == "full_attention":
            self.self_attn = Qwen3NextAttention(
                config,
                model_config=model_config,
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=f"{prefix}.self_attn",
            )
        else:
            raise ValueError(f"Invalid layer_type {self.layer_type}")

        # A layer gets the MoE block unless it is listed as MLP-only, the
        # model has no experts, or it falls between decoder_sparse_step hits.
        mlp_only_layers = getattr(config, "mlp_only_layers", [])
        if (self.layer_idx not in mlp_only_layers) and (
            config.num_experts > 0
            and (self.layer_idx + 1) % config.decoder_sparse_step == 0
        ):
            self.mlp = Qwen3NextSparseMoeBlock(
                vllm_config=vllm_config,
                prefix=f"{prefix}.mlp",
            )
        else:
            self.mlp = Qwen3NextMLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                prefix=f"{prefix}.mlp",
            )

        self.input_layernorm = Qwen3NextRMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )
        self.post_attention_layernorm = Qwen3NextRMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

        # Zero-initialised scales: (scale + 1) starts out as the identity.
        self.layer_scale = getattr(config, "layer_scale", False)
        if self.layer_scale:
            self.attn_layer_scale = torch.nn.Parameter(
                torch.zeros(
                    1,
                    1,
                    config.hidden_size,
                ),
            )
            self.ffn_layer_scale = torch.nn.Parameter(
                torch.zeros(
                    1,
                    1,
                    config.hidden_size,
                ),
            )

    def forward(
        self,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
        positions: torch.Tensor | None = None,
        **kwargs: object,
    ):
        """Run the block.

        Args:
            hidden_states: Input activations.
            residual: Running residual. ``None`` on the first layer, in which
                case ``hidden_states`` becomes the residual.
            positions: Token positions; used only by full-attention layers.

        Returns:
            ``(hidden_states, residual)`` for the next layer.
        """
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)

        # Both token mixers write their result into this buffer in place.
        self_attention_output = torch.empty_like(hidden_states)
        if self.layer_type == "linear_attention":
            self.linear_attn(
                hidden_states=hidden_states,
                output=self_attention_output,
            )
        elif self.layer_type == "full_attention":
            self.self_attn(
                hidden_states=hidden_states,
                output=self_attention_output,
                positions=positions,
            )
        else:
            raise ValueError("Invalid layer_type")
        hidden_states = self_attention_output

        if self.layer_scale:
            # 2-D (tokens, hidden) input: drop the leading size-1 axis of the
            # (1, 1, hidden) scale.
            if len(hidden_states.shape) == 2:
                hidden_states = hidden_states * (
                    self.attn_layer_scale.to(hidden_states.dtype)[0] + 1
                )
            else:
                hidden_states = hidden_states * (
                    self.attn_layer_scale.to(hidden_states.dtype) + 1
                )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)

        if self.layer_scale:
            if len(hidden_states.shape) == 2:
                hidden_states = hidden_states * (
                    self.ffn_layer_scale.to(hidden_states.dtype)[0] + 1
                )
            else:
                assert len(hidden_states.shape) == len(self.ffn_layer_scale.shape), (
                    f"shape must be the same {len(hidden_states.shape)}, "
                    f"{len(self.ffn_layer_scale.shape)}"
                )
                hidden_states = hidden_states * (
                    self.ffn_layer_scale.to(hidden_states.dtype) + 1
                )

        return hidden_states, residual


@support_torch_compile
class Qwen3NextModel(nn.Module):
    """Qwen3-Next decoder stack: embedding, decoder layers and final norm.

    Supports pipeline parallelism:
    - Layers outside ``[start_layer, end_layer)`` belong to other stages.
    - Non-last stages return ``IntermediateTensors`` instead of normed
      hidden states.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config: Qwen3NextConfig = vllm_config.model_config.hf_text_config
        parallel_config = vllm_config.parallel_config

        eplb_config = parallel_config.eplb_config
        self.num_redundant_experts = eplb_config.num_redundant_experts

        self.config = config

        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
        )

        def get_layer(prefix: str):
            # The layer index parsed from the prefix selects linear vs. full
            # attention from config.layer_types.
            return Qwen3NextDecoderLayer(
                vllm_config,
                layer_type=config.layer_types[extract_layer_index(prefix)],
                prefix=prefix,
            )

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers"
        )
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size
        )

        if get_pp_group().is_last_rank:
            self.norm = Qwen3NextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer()

        # Layer indices whose *input* hidden state is also returned.
        self.aux_hidden_state_layers: tuple[int, ...] = ()

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]:
        """Run this pipeline stage's decoder layers.

        Returns:
            - Non-last PP ranks: ``IntermediateTensors`` holding
              ``hidden_states`` and ``residual``.
            - Otherwise: the final-normed hidden states, paired with the list
              of auxiliary hidden states when any were collected.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.embed_input_ids(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        aux_hidden_states = []
        for layer_idx, layer in enumerate(
            islice(self.layers, self.start_layer, self.end_layer),
            start=self.start_layer,
        ):
            if layer_idx in self.aux_hidden_state_layers:
                aux_hidden_states.append(
                    hidden_states + residual if residual is not None else hidden_states
                )
            hidden_states, residual = layer(
                positions=positions,
                hidden_states=hidden_states,
                residual=residual,
            )

        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )
        hidden_states, _ = self.norm(hidden_states, residual)
        if aux_hidden_states:
            return hidden_states, aux_hidden_states
        return hidden_states

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Checkpoint-to-parameter mapping for routed expert weights.

        Returns ``(param_name, weight_name, expert_id, shard_id)`` tuples.
        They cover weights, fp8 weight scales and fp8 activation scales.
        """
        return SharedFusedMoE.make_expert_params_mapping(
            self,
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=getattr(self.config, "num_experts", 0),
            num_redundant_experts=self.num_redundant_experts,
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this model's parameters.

        Each ``(name, tensor)`` pair is routed to one of three loaders, tried
        in this order:
        1. Stacked dense projections (q/k/v -> ``qkv_proj``, gate/up ->
           ``gate_up_proj``).
        2. Per-expert MoE weights via ``get_expert_mapping``.
        3. The parameter's own ``weight_loader``, or
           ``default_weight_loader`` when it has none.

        Skipped: rotary ``inv_freq`` buffers, ``mtp.*`` weights and
        parameters on other pipeline stages.

        Returns:
            The set of (possibly remapped) parameter names that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        expert_params_mapping = self.get_expert_mapping()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            # Multi-token-prediction weights are not part of this module.
            if name.startswith("mtp."):
                continue

            # Remapping the name of FP8 kv-scale.
            if name.endswith("scale"):
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue

                # Routed expert gate/up weights go through the expert mapping.
                if "mlp.experts" in name:
                    continue

                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Skip layers on other devices.
                if is_pp_missing_parameter(name, self):
                    continue
                if name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    # Skip layers on other devices.
                    if is_pp_missing_parameter(name, self):
                        continue
                    # Skip loading extra bias for GPTQ models.
                    if (
                        name.endswith(".bias") or name.endswith("_bias")
                    ) and name not in params_dict:
                        continue
                    if name not in params_dict:
                        continue
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(
                        param,
                        loaded_weight,
                        name,
                        shard_id=shard_id,
                        expert_id=expert_id,
                    )
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    if is_pp_missing_parameter(name, self):
                        continue
                    if name not in params_dict:
                        logger.warning_once(
                            f"Parameter {name} not found in params_dict, skip loading"
                        )
                        continue
                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
1556
class QwenNextMixtureOfExperts(MixtureOfExperts):
    """Expert-count bookkeeping for the Qwen3Next sparse-MoE layers.

    Mixed into the causal-LM class; the host must provide ``self.model.layers``.
    """

    def _iter_sparse_moe_blocks(self):
        """Yield the ``Qwen3NextSparseMoeBlock`` of every decoder layer that has one.

        Entries of ``self.model.layers`` that are not ``Qwen3NextDecoderLayer``
        instances (presumably placeholders for layers owned by another
        pipeline-parallel rank -- TODO confirm) are skipped without reading
        ``.mlp``, which such entries may not define.
        """
        for layer in self.model.layers:
            if isinstance(layer, Qwen3NextDecoderLayer) and isinstance(
                layer.mlp, Qwen3NextSparseMoeBlock
            ):
                yield layer.mlp

    def update_physical_experts_metadata(
        self,
        num_physical_experts: int,
        num_local_physical_experts: int,
    ) -> None:
        """Propagate new physical / redundant expert counts to every MoE layer.

        Args:
            num_physical_experts: Total physical experts; the redundant count
                is derived as this minus ``self.num_logical_experts``.
            num_local_physical_experts: Physical experts on this rank. Must
                equal the value already recorded on ``self``.
        """
        # The per-rank expert count is an invariant at this point.
        assert self.num_local_physical_experts == num_local_physical_experts
        self.num_physical_experts = num_physical_experts
        self.num_local_physical_experts = num_local_physical_experts
        self.num_redundant_experts = num_physical_experts - self.num_logical_experts
        # Same layer filter as set_moe_parameters, so layers without an
        # ``.mlp`` attribute are never touched.
        for moe in self._iter_sparse_moe_blocks():
            moe.n_local_physical_experts = num_local_physical_experts
            moe.n_physical_experts = num_physical_experts
            moe.n_redundant_experts = self.num_redundant_experts
            moe.experts.update_expert_map()

    def set_moe_parameters(self):
        """Collect the MoE expert modules and copy expert counts from them.

        Sets ``expert_weights`` (empty), ``moe_layers`` and the
        ``num_*`` hyperparameters on ``self``.

        Raises:
            RuntimeError: If no decoder layer holds a ``Qwen3NextSparseMoeBlock``.
        """
        self.expert_weights = []

        moe_blocks = list(self._iter_sparse_moe_blocks())
        self.moe_layers = [moe.experts for moe in moe_blocks]
        if not moe_blocks:
            raise RuntimeError("No Qwen3Next layer found in the model.layers.")
        # All MoE blocks are assumed to share one expert configuration; the
        # last one found is used as the reference.
        example_moe = moe_blocks[-1]

        # Set MoE hyperparameters
        self.num_moe_layers = len(self.moe_layers)
        self.num_expert_groups = 1
        self.num_shared_experts = 0
        self.num_logical_experts = example_moe.n_logical_experts
        self.num_physical_experts = example_moe.n_physical_experts
        self.num_local_physical_experts = example_moe.n_local_physical_experts
        self.num_routed_experts = example_moe.n_routed_experts
        self.num_redundant_experts = example_moe.n_redundant_experts


1571
class Qwen3NextForCausalLM(
    nn.Module,
    HasInnerState,
    SupportsLoRA,
    SupportsPP,
    QwenNextMixtureOfExperts,
    IsHybrid,
):
    """Qwen3Next causal language model: ``Qwen3NextModel`` plus an LM head.

    The mixin bases add recurrent-state (``HasInnerState`` / ``IsHybrid``),
    LoRA, pipeline-parallel and MoE expert-bookkeeping support.
    """

    # Fused module name -> checkpoint sub-module names packed into it.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": ["gate_proj", "up_proj"],
        "in_proj_qkvz": ["in_proj_qkvz"],
        "in_proj_ba": ["in_proj_ba"],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the decoder stack, LM head and logits processor.

        Args:
            vllm_config: Global vLLM configuration; the HF text config is
                read from ``vllm_config.model_config.hf_text_config``.
            prefix: Parameter-name prefix for the submodules.

        Raises:
            NotImplementedError: If ``cache_config.mamba_cache_mode`` is
                ``"all"``, which this model does not support.
        """
        config = vllm_config.model_config.hf_text_config

        # Reject the unsupported prefix-caching mode before any weights are
        # allocated.
        if vllm_config.cache_config.mamba_cache_mode == "all":
            raise NotImplementedError(
                "Qwen3Next currently does not support 'all' prefix caching, "
                "please use '--mamba-cache-mode=align' instead"
            )

        super().__init__()
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.quant_config = vllm_config.quant_config
        self.scheduler_config = vllm_config.scheduler_config
        self.config = config

        self.model = Qwen3NextModel(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

        # Must run after self.model exists: it scans self.model.layers for
        # MoE blocks and records the expert counts.
        self.set_moe_parameters()

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up input embeddings via the inner model."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ):
        """Run the decoder stack and return its output unchanged.

        Extra keyword arguments are accepted and ignored. Logits are not
        computed here; see ``compute_logits``.
        """
        return self.model(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )

    @classmethod
    def get_mamba_state_dtype_from_config(
        cls,
        vllm_config: "VllmConfig",
    ) -> tuple[torch.dtype, torch.dtype]:
        """Return the gated-delta-net state dtypes.

        The dtypes are derived from the model dtype and the cache's
        ``mamba_cache_dtype`` / ``mamba_ssm_cache_dtype`` settings.
        """
        return MambaStateDtypeCalculator.gated_delta_net_state_dtype(
            vllm_config.model_config.dtype,
            vllm_config.cache_config.mamba_cache_dtype,
            vllm_config.cache_config.mamba_ssm_cache_dtype,
        )

    @classmethod
    def get_mamba_state_shape_from_config(
        cls, vllm_config: "VllmConfig"
    ) -> tuple[tuple[int, int], tuple[int, int]]:
        """Return the gated-delta-net state shapes for this config.

        Depends on the tensor-parallel size, the ``linear_*`` head counts,
        head dims and conv kernel size from the HF text config, and the
        number of speculative tokens (0 when speculative decoding is off).
        """
        hf_config = vllm_config.model_config.hf_text_config
        spec_config = vllm_config.speculative_config
        num_spec = spec_config.num_speculative_tokens if spec_config else 0
        return MambaStateShapeCalculator.gated_delta_net_state_shape(
            vllm_config.parallel_config.tensor_parallel_size,
            hf_config.linear_num_key_heads,
            hf_config.linear_num_value_heads,
            hf_config.linear_key_head_dim,
            hf_config.linear_value_head_dim,
            hf_config.linear_conv_kernel_dim,
            num_spec,
        )

    @classmethod
    def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]:
        """Return the copy functions for the gated-delta-net state tensors."""
        return MambaStateCopyFuncCalculator.gated_delta_net_state_copy_func()

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Project hidden states to vocabulary logits via the LM head."""
        return self.logits_processor(self.lm_head, hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights and return the names that were loaded.

        Tensors under the ``mtp.`` prefix are skipped (presumably the
        multi-token-prediction head, loaded elsewhere -- TODO confirm).
        """
        loader = AutoWeightsLoader(self, skip_prefixes=["mtp."])
        return loader.load_weights(weights)

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Delegate the checkpoint-to-expert parameter mapping to the inner model."""
        return self.model.get_expert_mapping()


1695
1696
1697
1698
1699
1700
1701
1702
1703
1704
1705
1706
1707
1708
1709
1710
1711
1712
1713
1714
1715
1716
1717
1718
1719
1720
def gdn_in_proj(
    hidden_states: torch.Tensor,
    qkvz_output_size: int,
    ba_output_size: int,
    layer_name: str,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Custom-op entry point for a GDN layer's input projection.

    Looks up the layer registered as ``layer_name`` in the current forward
    context's ``no_compile_layers`` and returns its ``_forward_in_proj``
    result. The two size arguments are not read here; the fake
    implementation, which shares this signature, uses them to size its
    outputs.
    """
    layer = get_forward_context().no_compile_layers[layer_name]
    return layer._forward_in_proj(hidden_states)


def gdn_in_proj_fake(
    hidden_states: torch.Tensor,
    qkvz_output_size: int,
    ba_output_size: int,
    layer_name: str,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Shape-only stand-in for ``gdn_in_proj`` used by torch.compile.

    Returns two uninitialized tensors with ``hidden_states``' dtype and
    device, shaped ``(num_tokens, qkvz_output_size)`` and
    ``(num_tokens, ba_output_size)``, where ``num_tokens`` is
    ``hidden_states.shape[0]``. ``layer_name`` is ignored.
    """
    num_tokens = hidden_states.shape[0]
    qkvz_out = hidden_states.new_empty(num_tokens, qkvz_output_size)
    ba_out = hidden_states.new_empty(num_tokens, ba_output_size)
    return qkvz_out, ba_out


1721
1722
1723
1724
1725
def gdn_attention_core(
    mixed_qkv: torch.Tensor,
    b: torch.Tensor,
    a: torch.Tensor,
    core_attn_out: torch.Tensor,
    layer_name: str,
) -> None:
    """
    Custom op for the core attention computation.
    Only handles the convolution + recurrent attention part.
    Input/output projections are handled outside this op.

    Resolves the layer registered as ``layer_name`` in the current forward
    context's ``no_compile_layers`` and calls its ``_forward_core`` with the
    given tensors. Nothing is returned; the op is registered with
    ``core_attn_out`` as a mutated argument, so results are expected to be
    written into it.
    """
    layer = get_forward_context().no_compile_layers[layer_name]
    layer._forward_core(
        mixed_qkv=mixed_qkv,
        b=b,
        a=a,
        core_attn_out=core_attn_out,
    )


def gdn_attention_core_fake(
    mixed_qkv: torch.Tensor,
    b: torch.Tensor,
    a: torch.Tensor,
    core_attn_out: torch.Tensor,
    layer_name: str,
) -> None:
    """Fake implementation for torch.compile.

    The real op returns nothing and only mutates ``core_attn_out``, so the
    fake has no outputs to describe and does nothing.
    """
    return


1754
1755
1756
1757
1758
1759
# Register the GDN helpers as custom ops. At run time each op resolves its
# layer by name from the forward context; the fake implementations let
# torch.compile trace through the ops without running them.
direct_register_custom_op(
    op_name="gdn_in_proj",
    op_func=gdn_in_proj,
    fake_impl=gdn_in_proj_fake,
)

direct_register_custom_op(
    op_name="gdn_attention_core",
    op_func=gdn_attention_core,
    # The op returns None; its result lands in this caller-provided tensor.
    mutates_args=["core_attn_out"],
    fake_impl=gdn_attention_core_fake,
)


@triton.jit
def fused_gdn_gating_kernel(
    g,
    beta_output,
    A_log,
    a,
    b,
    dt_bias,
    seq_len,
    NUM_HEADS: tl.constexpr,
    beta: tl.constexpr,
    threshold: tl.constexpr,
    BLK_HEADS: tl.constexpr,
):
    # Program ids: axis 0 -> batch row, axis 1 -> sequence position,
    # axis 2 -> a block of BLK_HEADS consecutive heads.
    i_b, i_s, i_d = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    head_off = i_d * BLK_HEADS + tl.arange(0, BLK_HEADS)
    # Dense row-major offset into (batch, seq_len, NUM_HEADS) buffers; the
    # kernel receives no strides, so a/b/g/beta_output must be contiguous.
    off = i_b * seq_len * NUM_HEADS + i_s * NUM_HEADS + head_off
    # The last head block may run past NUM_HEADS.
    mask = head_off < NUM_HEADS
    # A_log and dt_bias are per-head and indexed by head only.
    blk_A_log = tl.load(A_log + head_off, mask=mask)
    blk_a = tl.load(a + off, mask=mask)
    blk_b = tl.load(b + off, mask=mask)
    blk_bias = tl.load(dt_bias + head_off, mask=mask)
    # If the model is loaded in fp16, without the .float() here, A might be -inf
    x = blk_a.to(tl.float32) + blk_bias.to(tl.float32)
    # softplus with a linear cut-over: above `threshold` (in beta*x units)
    # return x directly, mirroring torch.nn.functional.softplus.
    softplus_x = tl.where(
        beta * x <= threshold, (1 / beta) * tl.log(1 + tl.exp(beta * x)), x
    )
    # g = -exp(A_log) * softplus(a + dt_bias)
    blk_g = -tl.exp(blk_A_log.to(tl.float32)) * softplus_x
    tl.store(g + off, blk_g.to(g.dtype.element_ty), mask=mask)
    # compute beta_output = sigmoid(b)
    blk_beta_output = tl.sigmoid(blk_b.to(tl.float32))
    tl.store(
        beta_output + off, blk_beta_output.to(beta_output.dtype.element_ty), mask=mask
    )
1802
1803
1804
1805
1806


def fused_gdn_gating(
    A_log: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    dt_bias: torch.Tensor,
    beta: float = 1.0,
    threshold: float = 20.0,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Fused computation of g and beta for Gated Delta Net.
    g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
    beta_output = b.sigmoid()
    TODO maybe use torch.compile to replace this triton kernel

    Args:
        A_log: Per-head log decay parameter, indexed by head.
        a: 2-D ``(batch, num_heads)`` gate input.
        b: Same shape as ``a``; input to the sigmoid.
        dt_bias: Per-head bias added to ``a`` before the softplus.
        beta: Softplus ``beta``.
        threshold: Softplus linear cut-over threshold (in ``beta * x`` units).

    Returns:
        ``(g, beta_output)``, each shaped ``(1, batch, num_heads)``. ``g`` is
        float32; ``beta_output`` has ``b``'s dtype.
    """
    batch, num_heads = a.shape
    seq_len = 1
    # Heads per Triton program; used for both the grid and BLK_HEADS so the
    # two can never disagree.
    blk_heads = 8
    # The kernel indexes a/b with dense row-major offsets and gets no
    # strides, so strided views (e.g. slices of a fused projection) would be
    # read incorrectly. .contiguous() is a no-op for already-dense tensors.
    a = a.contiguous()
    b = b.contiguous()
    grid = (batch, seq_len, triton.cdiv(num_heads, blk_heads))
    g = torch.empty(1, batch, num_heads, dtype=torch.float32, device=a.device)
    beta_output = torch.empty(1, batch, num_heads, dtype=b.dtype, device=b.device)
    fused_gdn_gating_kernel[grid](
        g,
        beta_output,
        A_log,
        a,
        b,
        dt_bias,
        seq_len,
        num_heads,
        beta,
        threshold,
        blk_heads,
        num_warps=1,
    )
    return g, beta_output