"requirements-rocm.txt" did not exist on "bb1ba58f064731b179d586ae32fdaaaea439098d"
qwen3_next.py 62.3 KB
Newer Older
1
2
3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Inference-only Qwen3Next model."""
4

5
from collections.abc import Iterable
6
from itertools import islice
7
8
9
10
11
12

import torch
from einops import rearrange
from torch import nn
from transformers.activations import ACT2FN

13
from vllm import envs
14
from vllm.compilation.decorators import support_torch_compile
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
from vllm.config import (
    CacheConfig,
    ModelConfig,
    SpeculativeConfig,
    VllmConfig,
    get_current_vllm_config,
)
from vllm.distributed import (
    divide,
    get_ep_group,
    get_pp_group,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    tensor_model_parallel_all_gather,
)
30
31
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
32
from vllm.model_executor.custom_op import CustomOp
33
from vllm.model_executor.layers.attention import Attention
34
from vllm.model_executor.layers.fla.ops import (
35
36
37
    chunk_gated_delta_rule as fla_chunk_gated_delta_rule,
)
from vllm.model_executor.layers.fla.ops import (
38
    fused_recurrent_gated_delta_rule_packed_decode,
39
    fused_sigmoid_gating_delta_rule_update,
40
)
41
from vllm.model_executor.layers.fla.ops.chunk import l2norm_fwd
42
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
43
44
45
46
from vllm.model_executor.layers.layernorm import (
    GemmaRMSNorm as Qwen3NextRMSNorm,
)
from vllm.model_executor.layers.layernorm import RMSNormGated
47
48
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
49
    MergedColumnParallelLinear,
50
51
52
53
    QKVParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
54
55
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.abstract import MambaBase
56
from vllm.model_executor.layers.mamba.mamba_mixer2 import mamba_v2_sharded_weight_loader
57
from vllm.model_executor.layers.mamba.mamba_utils import (
58
59
    MambaStateCopyFunc,
    MambaStateCopyFuncCalculator,
60
61
62
    MambaStateDtypeCalculator,
    MambaStateShapeCalculator,
)
63
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
64
65
66
    causal_conv1d_fn,
    causal_conv1d_update,
)
67
68
69
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
70
71
72
    ParallelLMHead,
    VocabParallelEmbedding,
)
73
from vllm.model_executor.model_loader.weight_utils import (
74
    default_weight_loader,
75
    maybe_remap_kv_scale_name,
76
77
    sharded_weight_loader,
)
78
from vllm.model_executor.models.qwen2_moe import Qwen2MoeMLP as Qwen3NextMLP
79
from vllm.model_executor.models.utils import sequence_parallel_chunk
80
81
82
83
84
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import Qwen3NextConfig
from vllm.triton_utils import tl, triton
85
from vllm.utils.torch_utils import direct_register_custom_op
86
from vllm.v1.attention.backend import AttentionMetadata
87
88
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata

89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
from .interfaces import (
    HasInnerState,
    IsHybrid,
    MixtureOfExperts,
    SupportsLoRA,
    SupportsPP,
)
from .utils import (
    AutoWeightsLoader,
    PPMissingLayer,
    extract_layer_index,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
105
106
107
108
109
110

# Module-level logger for this model file.
logger = init_logger(__name__)

# Type alias for a pair of tensors; the name suggests a (key, value) cache
# pair. NOTE(review): not referenced in this chunk — confirm it is still used.
KVCache = tuple[torch.Tensor, torch.Tensor]


111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
def fi_chunk_gated_delta_rule(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
    initial_state: torch.Tensor,
    output_final_state: bool,
    cu_seqlens: torch.LongTensor | None = None,
    use_qk_l2norm_in_kernel: bool = True,
):
    """Run the chunked gated delta rule with FlashInfer's GDN prefill kernel.

    Adapts FLA-style inputs to the FlashInfer API:

    * the leading dimension of size 1 is squeezed from ``q``, ``k``, ``v``,
      ``g`` and ``beta``;
    * ``g`` is exponentiated (``torch.exp``) in float32 before the call;
    * ``beta`` and ``initial_state`` are cast to float32.

    Args:
        q, k, v: Query/key/value tensors with a leading dimension of 1.
        g: Gate tensor; passed to FlashInfer as ``exp(g)``.
        beta: Per-token beta tensor.
        initial_state: Initial recurrent state (cast to float32).
        output_final_state: Whether FlashInfer should return the final state.
        cu_seqlens: Optional cumulative sequence lengths for packed input.
        use_qk_l2norm_in_kernel: If True, ``q`` and ``k`` are normalized with
            ``l2norm_fwd`` here before calling FlashInfer.

    Returns:
        ``(output, final_state)``; ``output`` gets its leading dimension of 1
        back so it matches the FLA output format.
    """
    # Imported inside the function so FlashInfer is only needed when this
    # code path actually runs.
    from flashinfer.gdn_prefill import (
        chunk_gated_delta_rule as chunk_gated_delta_rule_fi,
    )

    if use_qk_l2norm_in_kernel:
        q = l2norm_fwd(q)
        k = l2norm_fwd(k)

    # use flashinfer implementation
    q = q.squeeze(0).contiguous()
    k = k.squeeze(0).contiguous()
    v = v.squeeze(0).contiguous()

    g = g.squeeze(0).contiguous()
    beta = beta.squeeze(0).contiguous()
    fi_state = initial_state.to(torch.float32)
    fi_g = g.to(torch.float32)
    fi_beta = beta.to(torch.float32)
    output, final_state = chunk_gated_delta_rule_fi(
        q=q,
        k=k,
        v=v,
        g=torch.exp(fi_g),
        beta=fi_beta,
        initial_state=fi_state,
        output_final_state=output_final_state,
        cu_seqlens=cu_seqlens,
    )
    # Unsqueeze back to 4D (1, L, H, D) to match fla output format
    return output.unsqueeze(0), final_state
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214


@CustomOp.register("chunk_gated_delta_rule")
class ChunkGatedDeltaRule(CustomOp):
    """Chunked gated delta rule with a per-platform kernel choice.

    The constructor binds ``_forward_method`` once:

    * CUDA devices with compute capability 90 -> ``forward_cuda``
      (FlashInfer prefill kernel via ``fi_chunk_gated_delta_rule``);
    * every other platform -> ``forward_native``
      (FLA kernel via ``fla_chunk_gated_delta_rule``).

    Both forwards accept the same keyword arguments and return whatever the
    selected kernel returns.
    """

    def __init__(self) -> None:
        super().__init__()
        on_sm90_cuda = (
            current_platform.is_cuda()
            and current_platform.is_device_capability(90)
        )
        if not on_sm90_cuda:
            self._forward_method = self.forward_native
            return

        logger.info_once(
            "Using FlashInfer GDN prefill kernel on CUDA compute capability 90"
        )
        self._forward_method = self.forward_cuda

    def forward_cuda(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        g: torch.Tensor,
        beta: torch.Tensor,
        initial_state: torch.Tensor,
        output_final_state: bool,
        cu_seqlens: torch.LongTensor | None = None,
        use_qk_l2norm_in_kernel: bool = True,
    ):
        """Forward through the FlashInfer-backed implementation."""
        kernel = fi_chunk_gated_delta_rule
        return kernel(
            q=q,
            k=k,
            v=v,
            g=g,
            beta=beta,
            initial_state=initial_state,
            output_final_state=output_final_state,
            cu_seqlens=cu_seqlens,
            use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
        )

    def forward_native(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        g: torch.Tensor,
        beta: torch.Tensor,
        initial_state: torch.Tensor,
        output_final_state: bool,
        cu_seqlens: torch.LongTensor | None = None,
        use_qk_l2norm_in_kernel: bool = True,
    ):
        """Forward through the FLA implementation."""
        kernel = fla_chunk_gated_delta_rule
        return kernel(
            q=q,
            k=k,
            v=v,
            g=g,
            beta=beta,
            initial_state=initial_state,
            output_final_state=output_final_state,
            cu_seqlens=cu_seqlens,
            use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
        )


215
class Qwen3NextSparseMoeBlock(nn.Module):
    """Sparse mixture-of-experts feed-forward block used by Qwen3-Next.

    Tokens are routed to ``config.num_experts`` routed experts through
    ``SharedFusedMoE``. The router is either ``self.gate`` or the one inside
    ``SharedFusedMoE`` (when ``experts.is_internal_router`` is set). When
    ``config.shared_expert_intermediate_size > 0``, a shared expert gated by
    ``self.shared_expert_gate`` also runs, and its output is added to the
    routed output.
    """

    def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
        """Build the router, the optional shared expert and the fused experts.

        Args:
            vllm_config: Engine config. The model's ``hf_text_config``,
                ``parallel_config`` and ``quant_config`` are read from it.
            prefix: Parameter-name prefix for the sub-layers.

        Raises:
            ValueError: If the tensor-parallel size exceeds
                ``config.num_experts``.
        """
        super().__init__()

        config = vllm_config.model_config.hf_text_config
        parallel_config = vllm_config.parallel_config
        quant_config = vllm_config.quant_config

        self.tp_size = get_tensor_model_parallel_world_size()

        self.ep_group = get_ep_group().device_group
        self.ep_rank = get_ep_group().rank_in_group
        self.ep_size = self.ep_group.size()
        self.n_routed_experts = config.num_experts

        self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe

        if self.tp_size > config.num_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {config.num_experts}."
            )

        # Load balancing settings. EPLB settings are read from the same
        # ``parallel_config`` as ``enable_eplb``. Re-fetching the global
        # current config here could mix settings from two different configs.
        eplb_config = parallel_config.eplb_config
        self.enable_eplb = parallel_config.enable_eplb

        self.n_logical_experts = self.n_routed_experts
        self.n_redundant_experts = eplb_config.num_redundant_experts
        self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts
        # Each EP rank owns n_physical_experts // ep_size consecutive physical
        # experts: the half-open range [physical_expert_start, physical_expert_end).
        self.n_local_physical_experts = self.n_physical_experts // self.ep_size

        self.physical_expert_start = self.ep_rank * self.n_local_physical_experts
        self.physical_expert_end = (
            self.physical_expert_start + self.n_local_physical_experts
        )

        # Router: one logit per routed expert, kept unquantized.
        self.gate = ReplicatedLinear(
            config.hidden_size,
            config.num_experts,
            bias=False,
            quant_config=None,
            prefix=f"{prefix}.gate",
        )

        # Single-output gate handed to the shared expert as ``expert_gate``.
        self.shared_expert_gate = ReplicatedLinear(
            config.hidden_size,
            1,
            bias=False,
            quant_config=None,
            prefix=f"{prefix}.shared_expert_gate",
        )

        if config.shared_expert_intermediate_size > 0:
            self.shared_expert = Qwen3NextMLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.shared_expert_intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                reduce_results=False,
                expert_gate=self.shared_expert_gate,
                prefix=f"{prefix}.shared_expert",
            )
        else:
            self.shared_expert = None

        # reduce_results=False: the TP reduction / SP gather is done
        # explicitly in ``forward``.
        self.experts = SharedFusedMoE(
            shared_experts=self.shared_expert,
            gate=self.gate,
            num_experts=self.n_routed_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            reduce_results=False,
            renormalize=getattr(config, "norm_topk_prob", True),
            quant_config=quant_config,
            prefix=f"{prefix}.experts",
            enable_eplb=self.enable_eplb,
            num_redundant_experts=self.n_redundant_experts,
            is_sequence_parallel=self.is_sequence_parallel,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the MoE block.

        Args:
            hidden_states: ``[num_tokens, hidden_size]``, or a single token
                ``[hidden_size]``.

        Returns:
            Tensor with the same shape as ``hidden_states``.
        """
        orig_shape = hidden_states.shape
        # Flatten to 2-D before reading the token count so that 1-D
        # (single-token) input works too. Unpacking ``shape`` directly into
        # two names would raise for 1-D input.
        hidden_dim = hidden_states.shape[-1]
        hidden_states = hidden_states.view(-1, hidden_dim)
        num_tokens = hidden_states.shape[0]

        if self.is_sequence_parallel:
            # Presumably keeps only this rank's chunk of tokens; the result
            # is all-gathered back below.
            hidden_states = sequence_parallel_chunk(hidden_states)

        if self.experts.is_internal_router:
            # In this case, the gate/router runs inside the FusedMoE class
            final_hidden_states = self.experts(
                hidden_states=hidden_states, router_logits=hidden_states
            )
        else:
            # router_logits: (num_tokens, n_experts)
            router_logits, _ = self.gate(hidden_states)
            final_hidden_states = self.experts(
                hidden_states=hidden_states, router_logits=router_logits
            )

        if self.shared_expert is not None:
            # With a shared expert the fused MoE returns two outputs; sum them.
            final_hidden_states = final_hidden_states[0] + final_hidden_states[1]

        if self.is_sequence_parallel:
            final_hidden_states = tensor_model_parallel_all_gather(
                final_hidden_states, 0
            )
            # Trim back to the original token count (presumably drops any
            # padding added by sequence_parallel_chunk).
            final_hidden_states = final_hidden_states[:num_tokens]
        elif self.tp_size > 1:
            final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(  # noqa E501
                final_hidden_states
            )

        return final_hidden_states.view(orig_shape)


class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
    @property
    def mamba_type(self) -> str:
        """Identifier of this Mamba-style layer type: ``"gdn_attention"``."""
        return "gdn_attention"

    def get_state_dtype(self) -> tuple[torch.dtype, torch.dtype]:
        """Return the dtypes of this layer's two cached states.

        Delegates to ``MambaStateDtypeCalculator.gated_delta_net_state_dtype``
        with the model dtype and the cache config's ``mamba_cache_dtype`` and
        ``mamba_ssm_cache_dtype``.

        NOTE(review): the order is presumably (conv state dtype, recurrent/SSM
        state dtype). Confirm this against the calculator.
        """
        return MambaStateDtypeCalculator.gated_delta_net_state_dtype(
            self.model_config.dtype,
            self.cache_config.mamba_cache_dtype,
            self.cache_config.mamba_ssm_cache_dtype,
        )

    def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
        """Return the shapes of this layer's two cached states.

        Delegates to ``MambaStateShapeCalculator.gated_delta_net_state_shape``
        with the following inputs:

        * the tensor-parallel size;
        * the key/value head counts and head dims;
        * the conv kernel size;
        * the number of speculative tokens (``num_spec``).
        """
        return MambaStateShapeCalculator.gated_delta_net_state_shape(
            self.tp_size,
            self.num_k_heads,
            self.num_v_heads,
            self.head_k_dim,
            self.head_v_dim,
            self.conv_kernel_size,
            self.num_spec,
        )

    def __init__(
        self,
        config: Qwen3NextConfig,
        model_config: ModelConfig | None = None,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        speculative_config: SpeculativeConfig | None = None,
        prefix: str = "",
    ) -> None:
        """Build the gated delta net (linear attention) layer.

        Args:
            config: Model config. Supplies the ``linear_*`` head counts and
                head dims, the conv kernel size, the activation and the
                RMS-norm epsilon.
            model_config: Stored on the layer; its ``dtype`` is read by
                ``get_state_dtype``.
            cache_config: Stored on the layer; its Mamba cache dtypes are
                read by ``get_state_dtype``.
            quant_config: Forwarded to the input/output projections.
            speculative_config: If set, its ``num_speculative_tokens``
                becomes ``self.num_spec``; otherwise ``num_spec`` is 0.
            prefix: Unique layer name. It is registered in the compilation
                config's ``static_forward_context``.

        Raises:
            ValueError: If ``prefix`` is already registered in
                ``static_forward_context``.
        """
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()
        self.hidden_size = config.hidden_size
        self.num_v_heads = config.linear_num_value_heads
        self.num_k_heads = config.linear_num_key_heads
        self.head_k_dim = config.linear_key_head_dim
        self.head_v_dim = config.linear_value_head_dim
        self.key_dim = self.head_k_dim * self.num_k_heads
        self.value_dim = self.head_v_dim * self.num_v_heads

        self.conv_kernel_size = config.linear_conv_kernel_dim
        self.layer_idx = extract_layer_index(prefix)
        self.activation = config.hidden_act
        self.act = ACT2FN[config.hidden_act]
        self.layer_norm_epsilon = config.rms_norm_eps
        self.prefix = prefix

        self.config = config
        self.model_config = model_config
        self.cache_config = cache_config
        self.quant_config = quant_config
        self.speculative_config = speculative_config
        self.num_spec = (
            self.speculative_config.num_speculative_tokens
            if self.speculative_config
            else 0
        )

        # QKV causal conv. The weight is held in a ColumnParallelLinear
        # (sharded over conv_dim across TP ranks) and then given a singleton
        # middle dim via unsqueeze(1).
        self.conv_dim = self.key_dim * 2 + self.value_dim
        self.conv1d = ColumnParallelLinear(
            input_size=self.conv_kernel_size,
            output_size=self.conv_dim,
            bias=False,
            prefix=f"{prefix}.conv1d",
        )
        self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)

        # Projections of the input hidden states. Qwen3-Next and Qwen3.5 use
        # different checkpoint layouts for qkvz and ba, so both projections
        # are built through factory methods that subclasses can override.
        self.in_proj_qkvz = self.create_qkvz_proj(
            hidden_size=self.hidden_size,
            key_dim=self.key_dim,
            value_dim=self.value_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.in_proj_qkvz",
        )
        # ba_proj doesn't support blockwise fp8 quantization.
        self.in_proj_ba = self.create_ba_proj(
            hidden_size=self.hidden_size,
            num_v_heads=self.num_v_heads,
            quant_config=quant_config,
            prefix=f"{prefix}.in_proj_ba",
        )

        query_key_settings = (self.key_dim, 0, False)
        value_settings = (self.value_dim, 0, False)

        # Replace the default conv weight loader with one configured per
        # (q, k, v) segment. Presumably this gives each TP rank matching
        # slices of all three segments.
        delattr(self.conv1d.weight, "weight_loader")
        set_weight_attrs(
            self.conv1d.weight,
            {
                "weight_loader": mamba_v2_sharded_weight_loader(
                    [
                        query_key_settings,
                        query_key_settings,
                        value_settings,
                    ],
                    self.tp_size,
                    self.tp_rank,
                )
            },
        )

        # Per-value-head parameters, sharded along dim 0 across TP ranks.
        # Both use ``divide``, so a non-divisible head count fails the same
        # way for each.
        self.dt_bias = nn.Parameter(
            torch.ones(divide(self.num_v_heads, self.tp_size)),
        )

        self.A_log = nn.Parameter(
            torch.empty(
                divide(self.num_v_heads, self.tp_size),
            )
        )

        set_weight_attrs(self.A_log, {"weight_loader": sharded_weight_loader(0)})
        set_weight_attrs(self.dt_bias, {"weight_loader": sharded_weight_loader(0)})

        self.norm = RMSNormGated(
            self.head_v_dim,
            eps=self.layer_norm_epsilon,
            group_size=None,
            norm_before_gate=True,
            device=current_platform.current_device(),
        )

        self.out_proj = RowParallelLinear(
            self.value_dim,
            self.hidden_size,
            bias=False,
            input_is_parallel=True,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
        )

        self.chunk_gated_delta_rule = ChunkGatedDeltaRule()

        self.enable_packed_recurrent_decode = (
            envs.VLLM_ENABLE_FLA_PACKED_RECURRENT_DECODE
        )

        # Register this layer under its name. Presumably this lets the
        # ``gdn_attention_core`` custom op find it via ``self.prefix``.
        compilation_config = get_current_vllm_config().compilation_config
        if prefix in compilation_config.static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        compilation_config.static_forward_context[prefix] = self

488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
    def create_qkvz_proj(
        self,
        hidden_size: int,
        key_dim: int,
        value_dim: int,
        quant_config: QuantizationConfig | None,
        prefix: str,
    ) -> MergedColumnParallelLinear:
        return MergedColumnParallelLinear(
            input_size=hidden_size,
            output_sizes=[sum((key_dim, key_dim, value_dim, value_dim))],
            bias=False,
            quant_config=quant_config,
            prefix=prefix,
        )

504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
    def create_ba_proj(
        self,
        hidden_size: int,
        num_v_heads: int,
        quant_config: QuantizationConfig | None,
        prefix: str,
    ) -> MergedColumnParallelLinear:
        # Qwen3-Next stores in_proj_ba as a single fused weight with an
        # interleaved GQA layout: [b_g0, a_g0, b_g1, a_g1, ...] where
        # each group corresponds to a key-head group. We must use a single
        # output shard so that ColumnParallel sharding preserves this
        # interleaved structure across TP ranks.
        # Qwen3.5 overrides this to use [num_v_heads, num_v_heads] since
        # its checkpoint has separate in_proj_b and in_proj_a weights.
        return MergedColumnParallelLinear(
            input_size=hidden_size,
            output_sizes=[num_v_heads * 2],
            bias=False,
            quant_config=quant_config,
            prefix=prefix,
        )

526
527
    def fix_query_key_value_ordering(
        self,
528
529
        mixed_qkvz: torch.Tensor,
        mixed_ba: torch.Tensor,
530
531
532
533
534
535
    ):
        """
        Derives `query`, `key` and `value` tensors from `mixed_qkvzba`.
        """
        new_tensor_shape_qkvz = mixed_qkvz.size()[:-1] + (
            self.num_k_heads // self.tp_size,
536
537
538
539
540
541
542
            (
                self.head_k_dim
                + self.head_k_dim
                + (self.head_v_dim + self.head_v_dim)
                * self.num_v_heads
                // self.num_k_heads
            ),
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
        )
        new_tensor_shape_ba = mixed_qkvz.size()[:-1] + (
            self.num_k_heads // self.tp_size,
            2 * self.num_v_heads // self.num_k_heads,
        )

        mixed_qkvz = mixed_qkvz.view(*new_tensor_shape_qkvz)
        mixed_ba = mixed_ba.view(*new_tensor_shape_ba)

        split_arg_list_qkvz = [
            self.head_k_dim,
            self.head_k_dim,
            (self.num_v_heads // self.num_k_heads * self.head_v_dim),
            (self.num_v_heads // self.num_k_heads * self.head_v_dim),
        ]
        split_arg_list_ba = [
            self.num_v_heads // self.num_k_heads,
560
            self.num_v_heads // self.num_k_heads,
561
562
563
564
565
        ]

        # [b, sq, ng, (hn + hn + np/ng * hn + np/ng + np/ng)]
        # --> [b, sq, ng, hn], [b, sq, ng, hn], [b, sq, ng, np/ng * hn],
        #  [b, sq, ng, np/ng * hn], [b, sq, ng, np/ng], [b, sq, ng, np/ng]
566
        (query, key, value, z) = torch.split(mixed_qkvz, split_arg_list_qkvz, dim=2)
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
        (b, a) = torch.split(mixed_ba, split_arg_list_ba, dim=2)

        # [b, sq, ng, np/ng * hn] -> [b, sq, np, hn]
        value = value.reshape(value.size(0), -1, self.head_v_dim)
        z = z.reshape(z.size(0), -1, self.head_v_dim)
        b = b.reshape(b.size(0), self.num_v_heads // self.tp_size)
        a = a.reshape(a.size(0), self.num_v_heads // self.tp_size)

        return query, key, value, z, b, a

    def rearrange_mixed_qkv(self, mixed_qkv):
        """Split packed q/k/v and reshape each part to ``[1, L, heads, dim]``.

        ``mixed_qkv`` is split along its last dim into the local query, key
        and value widths: ``key_dim // tp_size``, ``key_dim // tp_size`` and
        ``value_dim // tp_size``. Each part is then rearranged from
        ``[L, heads * dim]`` to a contiguous ``[1, L, heads, dim]``.

        Returns ``(None, None, None)`` when ``mixed_qkv`` is ``None``.
        """
        if mixed_qkv is None:
            return None, None, None
        query, key, value = torch.split(
            mixed_qkv,
            [
                self.key_dim // self.tp_size,
                self.key_dim // self.tp_size,
                self.value_dim // self.tp_size,
            ],
            dim=-1,
        )
        query, key = map(
            lambda x: rearrange(x, "l (h d) -> 1 l h d", d=self.head_k_dim),
            (query, key),
        )
        value = rearrange(value, "l (h d) -> 1 l h d", d=self.head_v_dim)
        return query.contiguous(), key.contiguous(), value.contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        output: torch.Tensor,
    ):
        """
        Forward pass with three parts:
        1. Input projection
        2. Core attention (custom op)
        3. Output projection

        Args:
            hidden_states: Input activations; the first dim is the number of
                tokens.
            output: Preallocated buffer. Rows ``[:num_tokens]`` are
                overwritten in place with the ``out_proj`` result; nothing
                is returned.
        """
        num_tokens = hidden_states.size(0)

        # ============================================================
        # Part 1: Input Projection
        # ============================================================
        projected_states_qkvz, _ = self.in_proj_qkvz(hidden_states)
        projected_states_ba, _ = self.in_proj_ba(hidden_states)
        query, key, value, z, b, a = self.fix_query_key_value_ordering(
            projected_states_qkvz, projected_states_ba
        )
        # Flatten heads so q/k/v can be concatenated along the feature dim.
        query, key, value = map(
            lambda x: rearrange(x, "l p d -> l (p d)"), (query, key, value)
        )
        mixed_qkv = torch.cat((query, key, value), dim=-1)

        # ============================================================
        # Part 2: Core Attention (Custom Op)
        # ============================================================
        # Note: we should not use torch.empty here like other attention backends,
        # see discussions in https://github.com/vllm-project/vllm/pull/28182
        core_attn_out = torch.zeros(
            (num_tokens, self.num_v_heads // self.tp_size, self.head_v_dim),
            dtype=hidden_states.dtype,
            device=hidden_states.device,
        )

        # core_attn_out is passed as an output argument; the op is expected
        # to fill it in place. The layer is identified by self.prefix.
        torch.ops.vllm.gdn_attention_core(
            mixed_qkv,
            b,
            a,
            core_attn_out,
            self.prefix,
        )

        # ============================================================
        # Part 3: Output Projection
        # ============================================================
        z_shape_og = z.shape
        # Reshape input data into 2D tensor: one head_v_dim vector per row,
        # gated by the matching row of z in the norm below.
        core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1])
        z = z.reshape(-1, z.shape[-1])
        core_attn_out = self.norm(core_attn_out, z)
        core_attn_out = core_attn_out.reshape(z_shape_og)
        core_attn_out = rearrange(core_attn_out, "... h d -> ... (h d)")
        output[:num_tokens], _ = self.out_proj(core_attn_out)

653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
    def _warmup_prefill_kernels(self, mixed_qkv: torch.Tensor) -> None:
        """Warm up GDN prefill kernels during V1 profiling.

        During V1 profile runs, ``_forward_core`` returns early because
        ``attn_metadata`` is ``None``, so the autotuned kernels used by
        ``chunk_gated_delta_rule`` (e.g. ``solve_tril``,
        ``chunk_scaled_dot_kkt``) are never invoked.  After profiling,
        vLLM allocates KV cache using most of the remaining GPU memory.
        When the first real inference triggers the autotuner it OOMs
        because there is not enough memory left for benchmarking.

        This method runs minimal forward passes through
        ``chunk_gated_delta_rule`` with small dummy tensors to force
        autotuning while GPU memory is still plentiful.  The autotuner
        results are cached globally, so only the first layer incurs
        actual benchmarking cost.

        Most kernels use a fixed ``BT = chunk_size`` (64), but
        ``chunk_fwd_kernel_o`` recomputes ``BT`` from the sequence
        length: ``min(64, max(16, next_power_of_2(T)))``.  Since ``BT``
        is part of its autotune key, we run warmup passes with T = 16,
        32, and 64 to cover all possible ``BT`` values.

        The decode path uses ``fused_sigmoid_gating_delta_rule_update``
        which has fixed kernel parameters (no autotuning), so only the
        prefill (chunked) path needs warming up.
        """
        if hasattr(self, "_prefill_kernels_warmed_up"):
            return
        self._prefill_kernels_warmed_up = True

        device = mixed_qkv.device
        dtype = mixed_qkv.dtype
        num_k_heads = self.num_k_heads // self.tp_size
        num_v_heads = self.num_v_heads // self.tp_size
        _, state_dtype = self.get_state_dtype()

        # Run warmup for each possible BT value of chunk_fwd_kernel_o:
        #   T=16 → BT=16, T=32 → BT=32, T=64 → BT=64.
        # Other kernels always use BT=chunk_size(64), so their autotune
        # cache is populated on the first pass and reused thereafter.
        for T in (16, 32, 64):
            q = torch.randn(
                1, T, num_k_heads, self.head_k_dim, device=device, dtype=dtype
            )
            k = torch.randn(
                1, T, num_k_heads, self.head_k_dim, device=device, dtype=dtype
            )
            v = torch.randn(
                1, T, num_v_heads, self.head_v_dim, device=device, dtype=dtype
            )
            g = torch.randn(1, T, num_v_heads, device=device, dtype=dtype)
            beta = torch.randn(1, T, num_v_heads, device=device, dtype=dtype)
            state = torch.zeros(
                1,
                num_v_heads,
                self.head_v_dim,
                self.head_k_dim,
                device=device,
                dtype=state_dtype,
            )
            cu_seqlens = torch.tensor([0, T], device=device, dtype=torch.long)

            try:
                self.chunk_gated_delta_rule(
                    q=q,
                    k=k,
                    v=v,
                    g=g,
                    beta=beta,
                    initial_state=state,
                    output_final_state=False,
                    cu_seqlens=cu_seqlens,
                    use_qk_l2norm_in_kernel=True,
                )
            except Exception:
                logger.warning(
                    "GDN prefill kernel warmup (T=%d) failed for "
                    "layer %s. First inference may OOM due to "
                    "autotuner.",
                    T,
                    self.prefix,
                    exc_info=True,
                )
            else:
                logger.debug(
                    "GDN prefill kernel warmup (T=%d) completed for layer %s",
                    T,
                    self.prefix,
                )
            finally:
                del q, k, v, g, beta, state, cu_seqlens

        torch.accelerator.empty_cache()

748
    def _forward_core(
749
        self,
750
751
752
753
        mixed_qkv: torch.Tensor,
        b: torch.Tensor,
        a: torch.Tensor,
        core_attn_out: torch.Tensor,
754
755
756
757
758
    ):
        forward_context = get_forward_context()
        attn_metadata: AttentionMetadata = forward_context.attn_metadata

        if attn_metadata is None:
759
760
761
            # V1 profile run — warm up prefill kernels so that
            # autotuning completes before KV cache allocation.
            self._warmup_prefill_kernels(mixed_qkv)
762
763
764
765
766
            return

        assert isinstance(attn_metadata, dict)
        attn_metadata = attn_metadata[self.prefix]
        assert isinstance(attn_metadata, GDNAttentionMetadata)
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782

        if (
            self.enable_packed_recurrent_decode
            and attn_metadata.spec_sequence_masks is None
            and attn_metadata.num_prefills == 0
            and attn_metadata.num_decodes > 0
        ):
            return self._forward_core_decode_non_spec(
                mixed_qkv=mixed_qkv,
                b=b,
                a=a,
                core_attn_out=core_attn_out,
                attn_metadata=attn_metadata,
                virtual_engine=forward_context.virtual_engine,
            )

783
784
785
786
        has_initial_state = attn_metadata.has_initial_state
        spec_query_start_loc = attn_metadata.spec_query_start_loc
        non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc
        spec_sequence_masks = attn_metadata.spec_sequence_masks
787
788
        spec_token_indx = attn_metadata.spec_token_indx
        non_spec_token_indx = attn_metadata.non_spec_token_indx
789
790
791
792
793
        spec_state_indices_tensor = attn_metadata.spec_state_indices_tensor  # noqa: E501
        non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor  # noqa: E501
        self_kv_cache = self.kv_cache[forward_context.virtual_engine]
        conv_state = self_kv_cache[0].transpose(-1, -2)
        ssm_state = self_kv_cache[1]
794
        num_actual_tokens = attn_metadata.num_actual_tokens
795
        num_accepted_tokens = attn_metadata.num_accepted_tokens
796

797
798
799
        mixed_qkv = mixed_qkv[:num_actual_tokens]
        b = b[:num_actual_tokens]
        a = a[:num_actual_tokens]
800

801
        # 1. Convolution sequence transformation
802
803
804
        conv_weights = self.conv1d.weight.view(
            self.conv1d.weight.size(0), self.conv1d.weight.size(2)
        )
805
806

        if spec_sequence_masks is not None:
807
            if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0:
808
809
810
                mixed_qkv_spec = mixed_qkv
                mixed_qkv_non_spec = None
            else:
811
812
                mixed_qkv_spec = mixed_qkv.index_select(0, spec_token_indx)
                mixed_qkv_non_spec = mixed_qkv.index_select(0, non_spec_token_indx)
813
814
815
816
        else:
            mixed_qkv_spec = None
            mixed_qkv_non_spec = mixed_qkv

817
        # 1.1: Process the multi-query part
818
819
820
821
822
823
824
        if spec_sequence_masks is not None:
            mixed_qkv_spec = causal_conv1d_update(
                mixed_qkv_spec,
                conv_state,
                conv_weights,
                self.conv1d.bias,
                self.activation,
825
826
827
                conv_state_indices=spec_state_indices_tensor[:, 0][
                    : attn_metadata.num_spec_decodes
                ],
828
                num_accepted_tokens=num_accepted_tokens,
829
830
                query_start_loc=spec_query_start_loc,
                max_query_len=spec_state_indices_tensor.size(-1),
831
832
833
                validate_data=False,
            )

834
        # 1.2: Process the remaining part
835
        if attn_metadata.num_prefills > 0:
836
            mixed_qkv_non_spec_T = mixed_qkv_non_spec.transpose(0, 1)
837
            # - "cache_indices" updates the conv_state cache in positions
838
            #   pointed to by "state_indices_tensor"
839
            mixed_qkv_non_spec = causal_conv1d_fn(
840
                mixed_qkv_non_spec_T,
841
842
843
844
845
846
847
                conv_weights,
                self.conv1d.bias,
                activation=self.activation,
                conv_states=conv_state,
                has_initial_state=has_initial_state,
                cache_indices=non_spec_state_indices_tensor,
                query_start_loc=non_spec_query_start_loc,
848
                metadata=attn_metadata,
849
850
851
852
853
854
855
856
            ).transpose(0, 1)
        elif attn_metadata.num_decodes > 0:
            mixed_qkv_non_spec = causal_conv1d_update(
                mixed_qkv_non_spec,
                conv_state,
                conv_weights,
                self.conv1d.bias,
                self.activation,
857
                conv_state_indices=non_spec_state_indices_tensor[
858
                    : attn_metadata.num_actual_tokens
859
                ],
860
861
862
863
864
                validate_data=True,
            )
        else:
            mixed_qkv_non_spec = None

865
        query_spec, key_spec, value_spec = self.rearrange_mixed_qkv(mixed_qkv_spec)
866
        query_non_spec, key_non_spec, value_non_spec = self.rearrange_mixed_qkv(
867
868
            mixed_qkv_non_spec
        )
869

870
871
872
        if attn_metadata.num_prefills > 0:
            g, beta = fused_gdn_gating(self.A_log, a, b, self.dt_bias)
            if spec_sequence_masks is not None:
873
874
                g_non_spec = g.index_select(1, non_spec_token_indx)
                beta_non_spec = beta.index_select(1, non_spec_token_indx)
875
876
877
            else:
                g_non_spec = g
                beta_non_spec = beta
878
        else:
879
880
            g_non_spec = None
            beta_non_spec = None
881

882
        # 2. Recurrent attention
883

884
        # 2.1: Process the multi-query part
885
        if spec_sequence_masks is not None:
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
            core_attn_out_spec, last_recurrent_state = (
                fused_sigmoid_gating_delta_rule_update(
                    A_log=self.A_log,
                    a=a,
                    b=b,
                    dt_bias=self.dt_bias,
                    q=query_spec,
                    k=key_spec,
                    v=value_spec,
                    initial_state=ssm_state,
                    inplace_final_state=True,
                    cu_seqlens=spec_query_start_loc[
                        : attn_metadata.num_spec_decodes + 1
                    ],
                    ssm_state_indices=spec_state_indices_tensor,
                    num_accepted_tokens=num_accepted_tokens,
                    use_qk_l2norm_in_kernel=True,
                )
904
            )
905
906
907
        else:
            core_attn_out_spec, last_recurrent_state = None, None

908
        # 2.2: Process the remaining part
909
        if attn_metadata.num_prefills > 0:
910
            initial_state = ssm_state[non_spec_state_indices_tensor].contiguous()
911
912
913
914
            initial_state[~has_initial_state, ...] = 0
            (
                core_attn_out_non_spec,
                last_recurrent_state,
915
            ) = self.chunk_gated_delta_rule(
916
917
918
919
920
921
922
923
924
925
926
927
                q=query_non_spec,
                k=key_non_spec,
                v=value_non_spec,
                g=g_non_spec,
                beta=beta_non_spec,
                initial_state=initial_state,
                output_final_state=True,
                cu_seqlens=non_spec_query_start_loc,
                use_qk_l2norm_in_kernel=True,
            )
            # Init cache
            ssm_state[non_spec_state_indices_tensor] = last_recurrent_state.to(
928
929
                ssm_state.dtype
            )
930
931
        elif attn_metadata.num_decodes > 0:
            core_attn_out_non_spec, last_recurrent_state = (
932
933
934
935
936
                fused_sigmoid_gating_delta_rule_update(
                    A_log=self.A_log,
                    a=a,
                    b=b,
                    dt_bias=self.dt_bias,
937
938
939
940
941
                    q=query_non_spec,
                    k=key_non_spec,
                    v=value_non_spec,
                    initial_state=ssm_state,
                    inplace_final_state=True,
942
943
944
                    cu_seqlens=non_spec_query_start_loc[
                        : attn_metadata.num_decodes + 1
                    ],
945
946
                    ssm_state_indices=non_spec_state_indices_tensor,
                    use_qk_l2norm_in_kernel=True,
947
948
                )
            )
949
950
951
        else:
            core_attn_out_non_spec, last_recurrent_state = None, None

952
        # 3. Merge core attention output
953
        if spec_sequence_masks is not None and core_attn_out_non_spec is not None:
954
            merged_out = torch.empty(
955
956
957
958
                (1, num_actual_tokens, *core_attn_out_spec.shape[2:]),
                dtype=core_attn_out_non_spec.dtype,
                device=core_attn_out_non_spec.device,
            )
959
960
961
            merged_out.index_copy_(1, spec_token_indx, core_attn_out_spec)
            merged_out.index_copy_(1, non_spec_token_indx, core_attn_out_non_spec)
            core_attn_out[:num_actual_tokens] = merged_out.squeeze(0)
962
        elif spec_sequence_masks is not None:
963
            core_attn_out[:num_actual_tokens] = core_attn_out_spec.squeeze(0)
964
        else:
965
            core_attn_out[:num_actual_tokens] = core_attn_out_non_spec.squeeze(0)
966

967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
    def _forward_core_decode_non_spec(
        self,
        mixed_qkv: torch.Tensor,
        b: torch.Tensor,
        a: torch.Tensor,
        core_attn_out: torch.Tensor,
        attn_metadata: GDNAttentionMetadata,
        virtual_engine: int,
    ):
        """Packed fast path for batches made only of non-spec decode tokens.

        Advances the causal conv1d state one step per token with
        ``causal_conv1d_update``. It then runs the packed recurrent gated
        delta-rule kernel, whose result is written through the ``out`` view
        into the first ``num_actual_tokens`` rows of ``core_attn_out``.
        Returns ``None``.
        """
        n_tok = attn_metadata.num_actual_tokens
        # Cache slot of each decode token; shared by the conv and SSM states.
        slots = attn_metadata.non_spec_state_indices_tensor[:n_tok]

        layer_cache = self.kv_cache[virtual_engine]
        conv_state, ssm_state = layer_cache[0].transpose(-1, -2), layer_cache[1]

        # Drop any rows beyond the real tokens of this step.
        mixed_qkv, b, a = mixed_qkv[:n_tok], b[:n_tok], a[:n_tok]

        # Flatten the conv weight to 2-D (the view requires dim 1 to be 1).
        weight = self.conv1d.weight
        conv_weights = weight.view(weight.size(0), weight.size(2))

        conv_out = causal_conv1d_update(
            mixed_qkv,
            conv_state,
            conv_weights,
            self.conv1d.bias,
            self.activation,
            conv_state_indices=slots,
            validate_data=False,
        )

        # NOTE(review): ssm_state is passed as initial_state and nothing is
        # written back here, so the kernel presumably updates it in place.
        fused_recurrent_gated_delta_rule_packed_decode(
            mixed_qkv=conv_out,
            a=a,
            b=b,
            A_log=self.A_log,
            dt_bias=self.dt_bias,
            scale=self.head_k_dim**-0.5,
            initial_state=ssm_state,
            out=core_attn_out[:n_tok].unsqueeze(1),
            ssm_state_indices=slots,
            use_qk_l2norm_in_kernel=True,
        )

1016
1017
1018
1019
1020

class Qwen3NextAttention(nn.Module):
    """Gated full (softmax) attention layer used by Qwen3Next.

    Hidden states are projected to Q, K and V. When ``attn_output_gate`` is
    enabled, Q is concatenated per head with an output gate. Q and K get
    per-head RMSNorm and rotary embeddings before the ``Attention`` op. The
    attention output is optionally multiplied by ``sigmoid(gate)`` before
    the output projection.
    """

    def __init__(
        self,
        config: Qwen3NextConfig,
        model_config: ModelConfig | None = None,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        """Build the projections, norms, rotary embedding and attention op.

        Args:
            config: Text config providing head counts, sizes and rope params.
            model_config: Unused here; accepted for a uniform constructor
                signature.
            cache_config: Forwarded to the ``Attention`` op.
            quant_config: Forwarded to the linear layers and ``Attention``.
            prefix: Module-name prefix, also used to derive the layer index.
        """
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = config.num_attention_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = config.num_key_value_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        # The fallback head size must come from the *global* head count.
        # Dividing by the per-rank count would inflate it by tp_size.
        self.head_dim = getattr(config, "head_dim", None) or (
            self.hidden_size // self.total_num_heads
        )
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.dual_chunk_attention_config = getattr(
            config, "dual_chunk_attention_config", None
        )
        # When enabled, the query projection also emits a per-head output
        # gate, doubling its width.
        self.attn_output_gate = getattr(config, "attn_output_gate", True)

        self.qkv_proj = QKVParallelLinear(
            config.hidden_size,
            self.head_dim,
            self.total_num_heads * (1 + self.attn_output_gate),
            self.total_num_kv_heads,
            bias=getattr(config, "qkv_bias", False),
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            config.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self.rotary_emb = get_rope(
            head_size=self.head_dim,
            max_position=config.max_position_embeddings,
            rope_parameters=config.rope_parameters,
            dual_chunk_attention_config=self.dual_chunk_attention_config,
        )

        # Extra kwargs are only passed when dual-chunk attention is configured.
        extra_attn_kwargs = (
            {
                "layer_idx": extract_layer_index(prefix),
                "dual_chunk_attention_config": self.dual_chunk_attention_config,
            }
            if self.dual_chunk_attention_config
            else {}
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
            **extra_attn_kwargs,
        )

        self.q_norm = Qwen3NextRMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = Qwen3NextRMSNorm(self.head_dim, eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        output: torch.Tensor,
        hidden_states: torch.Tensor,
    ):
        """Compute attention for ``hidden_states`` and write it into ``output``.

        Args:
            positions: Token positions consumed by the rotary embedding.
            output: Pre-allocated tensor that receives the ``o_proj`` result,
                assigned in place via ``output[:]``.
            hidden_states: Input activations of this layer.
        """
        qkv, _ = self.qkv_proj(hidden_states)

        if self.attn_output_gate:
            q_gate, k, v = qkv.split(
                [self.q_size * 2, self.kv_size, self.kv_size], dim=-1
            )
            # Q and gate are laid out per head as [q_h | gate_h]; split each
            # head's slice in half and flatten the heads back.
            orig_shape = q_gate.shape[:-1]
            q_gate = q_gate.view(*orig_shape, self.num_heads, -1)
            q, gate = torch.chunk(q_gate, 2, dim=-1)
            q = q.reshape(*orig_shape, -1)
            gate = gate.reshape(*orig_shape, -1)
        else:
            q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

        # Per-head RMSNorm on queries and keys.
        q = self.q_norm(q.view(-1, self.num_heads, self.head_dim)).view(
            -1, self.num_heads * self.head_dim
        )
        k = self.k_norm(k.view(-1, self.num_kv_heads, self.head_dim)).view(
            -1, self.num_kv_heads * self.head_dim
        )

        q, k = self.rotary_emb(positions, q, k)

        attn_output = self.attn(q, k, v)

        if self.attn_output_gate:
            gate = torch.sigmoid(gate)
            attn_output = attn_output * gate

        output[:], _ = self.o_proj(attn_output)


class Qwen3NextDecoderLayer(nn.Module):
    """One Qwen3Next transformer block.

    The token mixer depends on ``layer_type``: a gated delta-net
    linear-attention layer (``"linear_attention"``) or gated softmax
    attention (``"full_attention"``). The channel mixer is a sparse MoE
    block or a dense MLP. With ``config.layer_scale`` enabled, both mixer
    outputs are multiplied by ``(scale + 1)`` using learnable,
    zero-initialised scales.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        layer_type: str,
        prefix: str = "",
    ) -> None:
        """Build the token mixer, channel mixer, norms and optional scales.

        Args:
            vllm_config: Global vLLM config.
            layer_type: ``"linear_attention"`` or ``"full_attention"``.
            prefix: Module-name prefix; the layer index is parsed from it.

        Raises:
            ValueError: If ``layer_type`` is not one of the supported values.
        """
        super().__init__()

        # Use the text config, consistent with Qwen3NextModel; this equals
        # hf_config for plain text-only checkpoints.
        config = vllm_config.model_config.hf_text_config
        model_config = vllm_config.model_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        speculative_config = vllm_config.speculative_config

        self.layer_type = layer_type
        self.layer_idx = extract_layer_index(prefix)

        if self.layer_type == "linear_attention":
            self.linear_attn = Qwen3NextGatedDeltaNet(
                config,
                model_config=model_config,
                cache_config=cache_config,
                quant_config=quant_config,
                speculative_config=speculative_config,
                prefix=f"{prefix}.linear_attn",
            )
        elif self.layer_type == "full_attention":
            self.self_attn = Qwen3NextAttention(
                config,
                model_config=model_config,
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=f"{prefix}.self_attn",
            )
        else:
            raise ValueError(f"Invalid layer_type {self.layer_type}")

        # A missing or null ``mlp_only_layers`` means "no dense-only layers".
        mlp_only_layers = getattr(config, "mlp_only_layers", None) or []
        use_sparse_moe = (
            self.layer_idx not in mlp_only_layers
            and config.num_experts > 0
            and (self.layer_idx + 1) % config.decoder_sparse_step == 0
        )
        if use_sparse_moe:
            self.mlp = Qwen3NextSparseMoeBlock(
                vllm_config=vllm_config,
                prefix=f"{prefix}.mlp",
            )
        else:
            self.mlp = Qwen3NextMLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                prefix=f"{prefix}.mlp",
            )

        self.input_layernorm = Qwen3NextRMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )
        self.post_attention_layernorm = Qwen3NextRMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

        self.layer_scale = getattr(config, "layer_scale", False)
        if self.layer_scale:
            # Zero-initialised, so the applied factor ``scale + 1`` starts at 1.
            self.attn_layer_scale = torch.nn.Parameter(
                torch.zeros(
                    1,
                    1,
                    config.hidden_size,
                ),
            )
            self.ffn_layer_scale = torch.nn.Parameter(
                torch.zeros(
                    1,
                    1,
                    config.hidden_size,
                ),
            )

    def forward(
        self,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
        positions: torch.Tensor | None = None,
        **kwargs: object,
    ):
        """Run the block (pre-norm, with the norms carrying the residual).

        Args:
            hidden_states: Output of the previous layer (or the embeddings).
            residual: Running residual; ``None`` on the first layer, in which
                case ``hidden_states`` itself becomes the residual.
            positions: Token positions; only used by full-attention layers.
            **kwargs: Ignored.

        Returns:
            ``(hidden_states, residual)`` for the next layer.

        Raises:
            ValueError: If ``self.layer_type`` is not supported.
        """
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)

        # Both token mixers write their result into this pre-allocated buffer.
        self_attention_output = torch.empty_like(hidden_states)
        if self.layer_type == "linear_attention":
            self.linear_attn(
                hidden_states=hidden_states,
                output=self_attention_output,
            )
        elif self.layer_type == "full_attention":
            self.self_attn(
                hidden_states=hidden_states,
                output=self_attention_output,
                positions=positions,
            )
        else:
            raise ValueError(f"Invalid layer_type {self.layer_type}")
        hidden_states = self_attention_output

        if self.layer_scale:
            # For 2-D activations, drop the leading singleton dim of the
            # (1, 1, hidden) scale so it broadcasts over tokens.
            if len(hidden_states.shape) == 2:
                hidden_states = hidden_states * (
                    self.attn_layer_scale.to(hidden_states.dtype)[0] + 1
                )
            else:
                hidden_states = hidden_states * (
                    self.attn_layer_scale.to(hidden_states.dtype) + 1
                )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)

        if self.layer_scale:
            if len(hidden_states.shape) == 2:
                hidden_states = hidden_states * (
                    self.ffn_layer_scale.to(hidden_states.dtype)[0] + 1
                )
            else:
                assert len(hidden_states.shape) == len(self.ffn_layer_scale.shape), (
                    f"shape must be the same {len(hidden_states.shape)}, "
                    f"{len(self.ffn_layer_scale.shape)}"
                )
                hidden_states = hidden_states * (
                    self.ffn_layer_scale.to(hidden_states.dtype) + 1
                )

        return hidden_states, residual


@support_torch_compile
class Qwen3NextModel(nn.Module):
    """Qwen3Next decoder stack: token embedding, hybrid layers, final norm.

    Pipeline parallelism is handled as follows:
    - Only the first PP rank embeds tokens.
    - Only the last PP rank applies the final norm.
    - Intermediate ranks exchange ``hidden_states`` and ``residual`` through
      ``IntermediateTensors``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config: Qwen3NextConfig = vllm_config.model_config.hf_text_config
        parallel_config = vllm_config.parallel_config

        eplb_config = parallel_config.eplb_config
        self.num_redundant_experts = eplb_config.num_redundant_experts

        self.config = config

        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
        )

        def get_layer(prefix: str):
            # The layer kind comes from config.layer_types at this layer index.
            return Qwen3NextDecoderLayer(
                vllm_config,
                layer_type=config.layer_types[extract_layer_index(prefix)],
                prefix=prefix,
            )

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers"
        )
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size
        )

        if get_pp_group().is_last_rank:
            self.norm = Qwen3NextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer()

        # Indices of layers whose *input* is also returned as an auxiliary
        # hidden state. Empty by default.
        self.aux_hidden_state_layers: tuple[int, ...] = ()

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up the token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]:
        """Run this rank's slice of decoder layers.

        Returns:
            ``IntermediateTensors`` on non-last PP ranks. On the last rank,
            returns the normed hidden states, or a ``(hidden_states,
            aux_hidden_states)`` tuple when auxiliary states were collected.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.embed_input_ids(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        aux_hidden_states = []
        for layer_idx, layer in enumerate(
            islice(self.layers, self.start_layer, self.end_layer),
            start=self.start_layer,
        ):
            if layer_idx in self.aux_hidden_state_layers:
                aux_hidden_states.append(
                    hidden_states + residual if residual is not None else hidden_states
                )
            hidden_states, residual = layer(
                positions=positions,
                hidden_states=hidden_states,
                residual=residual,
            )

        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )
        hidden_states, _ = self.norm(hidden_states, residual)
        if aux_hidden_states:
            return hidden_states, aux_hidden_states
        return hidden_states

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Return ``(param_name, weight_name, expert_id, shard_id)`` tuples.

        These map per-expert checkpoint weights onto the fused MoE parameters.
        """
        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        return SharedFusedMoE.make_expert_params_mapping(
            self,
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=getattr(self.config, "num_experts", 0),
            num_redundant_experts=self.num_redundant_experts,
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Weights are routed as follows:
        - q/k/v go into ``qkv_proj`` and gate/up into ``gate_up_proj``, each
          with its shard id.
        - Per-expert MoE weights go to the fused expert parameters.
        - Everything else is loaded directly.

        Rotary ``inv_freq`` buffers and ``mtp.*`` weights are skipped.

        Returns:
            The (possibly remapped) names of the parameters that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        expert_params_mapping = self.get_expert_mapping()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            if name.startswith("mtp."):
                continue

            # Remapping the name of FP8 kv-scale.
            if name.endswith("scale"):
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue

                # Expert weights are handled by the expert mapping below.
                if "mlp.experts" in name:
                    continue

                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Skip layers on other devices.
                if is_pp_missing_parameter(name, self):
                    continue
                if name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # for/else: reached when no stacked mapping hit ``break``.
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    # Skip layers on other devices.
                    if is_pp_missing_parameter(name, self):
                        continue
                    # Skip loading extra bias for GPTQ models.
                    if (
                        name.endswith(".bias") or name.endswith("_bias")
                    ) and name not in params_dict:
                        continue
                    if name not in params_dict:
                        continue
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(
                        param,
                        loaded_weight,
                        name,
                        shard_id=shard_id,
                        expert_id=expert_id,
                    )
                    break
                else:
                    # Plain (non-stacked, non-expert) parameter.
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    if is_pp_missing_parameter(name, self):
                        continue
                    if name not in params_dict:
                        logger.warning_once(
                            f"Parameter {name} not found in params_dict, skip loading"
                        )
                        continue
                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
class QwenNextMixtureOfExperts(MixtureOfExperts):
    """Expert-load-balancing (EPLB) hooks shared by Qwen3Next models.

    Expects ``self.model.layers`` to contain ``Qwen3NextDecoderLayer``
    instances for the layers held on this rank.
    """

    def _iter_local_moe_blocks(self):
        """Yield the sparse MoE block of every local decoder layer."""
        for layer in self.model.layers:
            # Layers not owned by this pipeline rank are presumably
            # placeholders (e.g. PPMissingLayer) without an ``mlp``. Check
            # the layer type first so ``layer.mlp`` is never read on them.
            if isinstance(layer, Qwen3NextDecoderLayer) and isinstance(
                layer.mlp, Qwen3NextSparseMoeBlock
            ):
                yield layer.mlp

    def update_physical_experts_metadata(
        self,
        num_physical_experts: int,
        num_local_physical_experts: int,
    ) -> None:
        """Apply a new physical-expert layout to this model and its MoE blocks.

        Args:
            num_physical_experts: Total number of physical experts.
            num_local_physical_experts: Physical experts on this rank; must
                equal the current value.
        """
        assert self.num_local_physical_experts == num_local_physical_experts
        self.num_physical_experts = num_physical_experts
        self.num_local_physical_experts = num_local_physical_experts
        self.num_redundant_experts = num_physical_experts - self.num_logical_experts
        for moe in self._iter_local_moe_blocks():
            moe.n_local_physical_experts = num_local_physical_experts
            moe.n_physical_experts = num_physical_experts
            moe.n_redundant_experts = self.num_redundant_experts
            moe.experts.update_expert_map()

    def set_moe_parameters(self):
        """Collect the local MoE expert modules and derive EPLB expert counts.

        The counts are copied from the last MoE block found.

        Raises:
            RuntimeError: If no sparse MoE block exists among local layers.
        """
        self.expert_weights = []

        self.moe_layers = []
        example_moe = None
        for moe in self._iter_local_moe_blocks():
            example_moe = moe
            self.moe_layers.append(moe.experts)

        if example_moe is None:
            raise RuntimeError("No Qwen3Next layer found in the model.layers.")

        # Set MoE hyperparameters
        self.num_moe_layers = len(self.moe_layers)
        self.num_expert_groups = 1
        self.num_shared_experts = 0
        self.num_logical_experts = example_moe.n_logical_experts
        self.num_physical_experts = example_moe.n_physical_experts
        self.num_local_physical_experts = example_moe.n_local_physical_experts
        self.num_routed_experts = example_moe.n_routed_experts
        self.num_redundant_experts = example_moe.n_redundant_experts


1511
class Qwen3NextForCausalLM(
    nn.Module,
    HasInnerState,
    SupportsLoRA,
    SupportsPP,
    QwenNextMixtureOfExperts,
    IsHybrid,
):
    """Qwen3Next causal language model.

    Wraps ``Qwen3NextModel`` with a ``ParallelLMHead`` and a ``LogitsProcessor``.
    Exposes the gated-delta-net recurrent-state dtype, shape and copy functions
    through the class methods below.
    """

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": ["gate_proj", "up_proj"],
        "in_proj_qkvz": ["in_proj_qkvz"],
        "in_proj_ba": ["in_proj_ba"],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the decoder stack, LM head and logits processor.

        Args:
            vllm_config: Global engine configuration.
            prefix: Parameter-name prefix; ``"model"`` / ``"lm_head"`` are
                appended for the submodules.

        Raises:
            NotImplementedError: If ``cache_config.mamba_cache_mode`` is ``"all"``.
        """
        config = vllm_config.model_config.hf_text_config
        cache_config = vllm_config.cache_config
        scheduler_config = vllm_config.scheduler_config

        # Reject unsupported configurations before constructing anything.
        if cache_config.mamba_cache_mode == "all":
            raise NotImplementedError(
                "Qwen3Next currently does not support 'all' prefix caching, "
                "please use '--mamba-cache-mode=align' instead"
            )

        # nn.Module requires its __init__ to run before any attribute is
        # assigned on the instance.
        super().__init__()
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.quant_config = vllm_config.quant_config
        self.config = config
        self.scheduler_config = scheduler_config

        self.model = Qwen3NextModel(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

        # Set MoE hyperparameters
        self.set_moe_parameters()

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate token embedding to the inner model."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ):
        """Run the inner model and return its output unchanged.

        Extra keyword arguments are accepted and ignored.
        """
        hidden_states = self.model(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )

        return hidden_states

    @classmethod
    def get_mamba_state_dtype_from_config(
        cls,
        vllm_config: "VllmConfig",
    ) -> tuple[torch.dtype, torch.dtype]:
        """Return the recurrent-state dtypes for the gated-delta-net layers.

        Derived from the model dtype and the configured mamba cache dtypes.
        """
        return MambaStateDtypeCalculator.gated_delta_net_state_dtype(
            vllm_config.model_config.dtype,
            vllm_config.cache_config.mamba_cache_dtype,
            vllm_config.cache_config.mamba_ssm_cache_dtype,
        )

    @classmethod
    def get_mamba_state_shape_from_config(
        cls, vllm_config: "VllmConfig"
    ) -> tuple[tuple[int, int], tuple[int, int]]:
        """Return the per-rank recurrent-state shapes for the linear-attention layers.

        Uses the tensor-parallel size and the ``linear_*`` head and kernel sizes
        from the HF text config. ``num_spec`` is the number of speculative tokens
        when a speculative config is set, otherwise 0.
        """
        parallel_config = vllm_config.parallel_config
        hf_config = vllm_config.model_config.hf_text_config
        tp_size = parallel_config.tensor_parallel_size
        num_spec = (
            vllm_config.speculative_config.num_speculative_tokens
            if vllm_config.speculative_config
            else 0
        )
        return MambaStateShapeCalculator.gated_delta_net_state_shape(
            tp_size,
            hf_config.linear_num_key_heads,
            hf_config.linear_num_value_heads,
            hf_config.linear_key_head_dim,
            hf_config.linear_value_head_dim,
            hf_config.linear_conv_kernel_dim,
            num_spec,
        )

    @classmethod
    def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]:
        """Return the gated-delta-net state copy functions."""
        return MambaStateCopyFuncCalculator.gated_delta_net_state_copy_func()

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Project ``hidden_states`` through ``lm_head`` via the logits processor."""
        return self.logits_processor(self.lm_head, hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights and return the names that were loaded.

        ``AutoWeightsLoader`` is given ``skip_prefixes=["mtp."]``. Entries under
        ``mtp.`` are therefore not loaded here; presumably they are loaded by a
        separate MTP draft model (confirm).
        """
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=["mtp."],
        )
        return loader.load_weights(weights)

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Delegate the checkpoint-to-expert parameter mapping to the inner model."""
        return self.model.get_expert_mapping()


1635
1636
1637
1638
1639
def gdn_attention_core(
    mixed_qkv: torch.Tensor,
    b: torch.Tensor,
    a: torch.Tensor,
    core_attn_out: torch.Tensor,
    layer_name: str,
) -> None:
    """Custom-op entry point for the GDN core (convolution + recurrent attention).

    Looks up the layer registered under ``layer_name`` in the current forward
    context's ``no_compile_layers`` and runs its ``_forward_core``. The result
    goes into ``core_attn_out``; nothing is returned. Input and output
    projections happen outside this op.
    """
    layer = get_forward_context().no_compile_layers[layer_name]
    layer._forward_core(mixed_qkv=mixed_qkv, b=b, a=a, core_attn_out=core_attn_out)


def gdn_attention_core_fake(
    mixed_qkv: torch.Tensor,
    b: torch.Tensor,
    a: torch.Tensor,
    core_attn_out: torch.Tensor,
    layer_name: str,
) -> None:
    """No-op fake implementation of ``gdn_attention_core`` for torch.compile.

    The real op returns nothing and only writes into ``core_attn_out``, so
    there is no output to produce here.
    """


# Register gdn_attention_core as the "gdn_attention_core" custom op.
# core_attn_out is declared as mutated because the op writes its result into
# that buffer instead of returning a tensor; gdn_attention_core_fake is the
# fake implementation used by torch.compile.
direct_register_custom_op(
    op_name="gdn_attention_core",
    op_func=gdn_attention_core,
    fake_impl=gdn_attention_core_fake,
    mutates_args=["core_attn_out"],
)


@triton.jit
def fused_gdn_gating_kernel(
    g,
    beta_output,
    A_log,
    a,
    b,
    dt_bias,
    seq_len,
    NUM_HEADS: tl.constexpr,
    beta: tl.constexpr,
    threshold: tl.constexpr,
    BLK_HEADS: tl.constexpr,
):
    # Each program handles one batch row (axis 0), one sequence position
    # (axis 1) and one tile of BLK_HEADS consecutive heads (axis 2).
    # a, b, g and beta_output are addressed as row-major buffers holding
    # NUM_HEADS consecutive heads per (batch, position). A_log and dt_bias are
    # addressed by head index only. Lanes at or past NUM_HEADS are masked off.
    #   g           = -exp(A_log) * softplus(a + dt_bias)   (fp32 math)
    #   beta_output = sigmoid(b)                             (fp32 math)
    row = tl.program_id(0)
    pos = tl.program_id(1)
    tile = tl.program_id(2)

    heads = tile * BLK_HEADS + tl.arange(0, BLK_HEADS)
    flat = row * seq_len * NUM_HEADS + pos * NUM_HEADS + heads
    valid = heads < NUM_HEADS

    # Gather every input for this tile before any store.
    log_decay = tl.load(A_log + heads, mask=valid)
    a_tile = tl.load(a + flat, mask=valid)
    b_tile = tl.load(b + flat, mask=valid)
    bias = tl.load(dt_bias + heads, mask=valid)

    # Upcast before adding: if the model is loaded in fp16, skipping the
    # float32 conversion can make A come out as -inf.
    x = a_tile.to(tl.float32) + bias.to(tl.float32)
    bx = beta * x
    # softplus with beta/threshold, as in torch.nn.functional.softplus:
    # log(1 + exp(beta*x)) / beta up to the threshold, identity beyond it.
    softplus_x = tl.where(
        bx <= threshold, (1 / beta) * tl.log(1 + tl.exp(bx)), x
    )
    gate = -tl.exp(log_decay.to(tl.float32)) * softplus_x
    tl.store(g + flat, gate.to(g.dtype.element_ty), mask=valid)

    # beta_output = sigmoid(b)
    beta_tile = tl.sigmoid(b_tile.to(tl.float32))
    tl.store(
        beta_output + flat,
        beta_tile.to(beta_output.dtype.element_ty),
        mask=valid,
    )
1710
1711
1712
1713
1714


def fused_gdn_gating(
    A_log: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    dt_bias: torch.Tensor,
    beta: float = 1.0,
    threshold: float = 20.0,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Fused computation of g and beta for Gated Delta Net.

    Computes, in one Triton launch::

        g           = -A_log.float().exp() * F.softplus(a.float() + dt_bias)
        beta_output = b.sigmoid()

    ``beta`` and ``threshold`` are the softplus parameters.

    Args:
        A_log: Per-head values read by head index; at least ``num_heads`` elements.
        a: 2-D ``(batch, num_heads)`` input to the softplus gate.
        b: Input to the sigmoid; must have the same shape as ``a``.
        dt_bias: Per-head bias added to ``a``, read by head index.
        beta: Softplus ``beta``.
        threshold: Softplus ``threshold`` (linear above ``beta * x > threshold``).

    Returns:
        ``(g, beta_output)``, both shaped ``(1, batch, num_heads)``. ``g`` is
        float32 on ``a.device``. ``beta_output`` has ``b``'s dtype and device.

    Raises:
        ValueError: If ``b.shape != a.shape`` or ``a`` is not 2-D.

    TODO maybe use torch.compile to replace this triton kernel
    """
    if b.shape != a.shape:
        # The kernel reads b with a's offsets; a mismatch would read out of bounds.
        raise ValueError(
            f"fused_gdn_gating expects b.shape == a.shape, got "
            f"{tuple(b.shape)} vs {tuple(a.shape)}"
        )

    # The kernel addresses a/b with flat row-major offsets and A_log/dt_bias
    # with a bare head index, so strided views would be read incorrectly.
    # .contiguous() returns the tensor itself when it is already contiguous.
    A_log = A_log.contiguous()
    a = a.contiguous()
    b = b.contiguous()
    dt_bias = dt_bias.contiguous()

    # Heads per Triton program. The same value sizes the grid and BLK_HEADS.
    block_heads = 8

    batch, num_heads = a.shape
    seq_len = 1
    grid = (batch, seq_len, triton.cdiv(num_heads, block_heads))
    g = torch.empty(1, batch, num_heads, dtype=torch.float32, device=a.device)
    beta_output = torch.empty(1, batch, num_heads, dtype=b.dtype, device=b.device)
    fused_gdn_gating_kernel[grid](
        g,
        beta_output,
        A_log,
        a,
        b,
        dt_bias,
        seq_len,
        num_heads,
        beta,
        threshold,
        block_heads,
        num_warps=1,
    )
    return g, beta_output