qwen3_next.py 50.1 KB
Newer Older
1
2
3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Inference-only Qwen3Next model."""
4

5
from collections.abc import Iterable
6
from itertools import islice
7
8
9
10
11
12

import torch
from einops import rearrange
from torch import nn
from transformers.activations import ACT2FN

13
14
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.layer import Attention
15
from vllm.compilation.decorators import support_torch_compile
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
from vllm.config import (
    CacheConfig,
    ModelConfig,
    SpeculativeConfig,
    VllmConfig,
    get_current_vllm_config,
)
from vllm.distributed import (
    divide,
    get_ep_group,
    get_pp_group,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    tensor_model_parallel_all_gather,
)
31
32
33
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.fla.ops import (
34
35
36
    chunk_gated_delta_rule,
    fused_recurrent_gated_delta_rule,
)
37
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
38
from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
39
40
41
42
from vllm.model_executor.layers.layernorm import (
    GemmaRMSNorm as Qwen3NextRMSNorm,
)
from vllm.model_executor.layers.layernorm import RMSNormGated
43
44
45
46
47
48
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    QKVParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
49
50
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.abstract import MambaBase
51
from vllm.model_executor.layers.mamba.mamba_mixer2 import mamba_v2_sharded_weight_loader
52
from vllm.model_executor.layers.mamba.mamba_utils import (
53
54
55
    MambaStateDtypeCalculator,
    MambaStateShapeCalculator,
)
56
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
57
58
59
    causal_conv1d_fn,
    causal_conv1d_update,
)
60
61
62
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
63
64
65
    ParallelLMHead,
    VocabParallelEmbedding,
)
66
from vllm.model_executor.model_loader.weight_utils import (
67
68
69
    default_weight_loader,
    sharded_weight_loader,
)
70
from vllm.model_executor.models.qwen2_moe import Qwen2MoeMLP as Qwen3NextMLP
71
from vllm.model_executor.models.utils import sequence_parallel_chunk
72
73
74
75
76
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import Qwen3NextConfig
from vllm.triton_utils import tl, triton
77
from vllm.utils.torch_utils import direct_register_custom_op
78
79
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata

80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
from .interfaces import (
    HasInnerState,
    IsHybrid,
    MixtureOfExperts,
    SupportsLoRA,
    SupportsPP,
)
from .utils import (
    AutoWeightsLoader,
    PPMissingLayer,
    extract_layer_index,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
96
97
98
99
100
101
102

# Module-level logger created through vLLM's logging helper.
logger = init_logger(__name__)

# Alias for a pair of tensors; presumably (key cache, value cache) — TODO confirm.
KVCache = tuple[torch.Tensor, torch.Tensor]


class Qwen3NextSparseMoeBlock(nn.Module):
    """Sparse mixture-of-experts block with an optional gated shared expert.

    The block owns a replicated router (``gate``), a one-output replicated
    ``shared_expert_gate``, an optional shared-expert MLP and a
    ``SharedFusedMoE`` layer that runs the routed experts.
    """

    def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
        """Build the router, the optional shared expert and the fused experts.

        Args:
            vllm_config: Engine configuration; the HF model config, parallel
                config and quantization config are read from it.
            prefix: Dotted module-name prefix used to name the sub-layers.

        Raises:
            ValueError: If the tensor-parallel world size exceeds the number
                of experts in the model config.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        parallel_config = vllm_config.parallel_config
        quant_config = vllm_config.quant_config

        self.tp_size = get_tensor_model_parallel_world_size()

        self.ep_group = get_ep_group().device_group
        self.ep_rank = get_ep_group().rank_in_group
        self.ep_size = self.ep_group.size()
        self.n_routed_experts = config.num_experts

        self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe

        if self.tp_size > config.num_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {config.num_experts}."
            )

        # Load balancing settings.
        # The EPLB config is read from the *current* vLLM config (as before);
        # it is no longer bound to the name `vllm_config`, so the constructor
        # argument is not shadowed.
        eplb_config = get_current_vllm_config().parallel_config.eplb_config
        self.enable_eplb = parallel_config.enable_eplb

        self.n_logical_experts = self.n_routed_experts
        self.n_redundant_experts = eplb_config.num_redundant_experts
        self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts
        self.n_local_physical_experts = self.n_physical_experts // self.ep_size

        # Half-open range [start, end) of physical experts for this EP rank.
        self.physical_expert_start = self.ep_rank * self.n_local_physical_experts
        self.physical_expert_end = (
            self.physical_expert_start + self.n_local_physical_experts
        )

        self.gate = ReplicatedLinear(
            config.hidden_size,
            config.num_experts,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate",
        )

        # Created unquantized (quant_config=None) and handed to the shared
        # expert MLP below as its `expert_gate`.
        self.shared_expert_gate = ReplicatedLinear(
            config.hidden_size,
            1,
            bias=False,
            quant_config=None,
            prefix=f"{prefix}.shared_expert_gate",
        )

        if config.shared_expert_intermediate_size > 0:
            self.shared_expert = Qwen3NextMLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.shared_expert_intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                reduce_results=False,
                expert_gate=self.shared_expert_gate,
                prefix=f"{prefix}.shared_expert",
            )
        else:
            self.shared_expert = None

        self.experts = SharedFusedMoE(
            shared_experts=self.shared_expert,
            gate=self.gate,
            num_experts=self.n_routed_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            reduce_results=False,
            renormalize=config.norm_topk_prob,
            quant_config=quant_config,
            prefix=f"{prefix}.experts",
            enable_eplb=self.enable_eplb,
            num_redundant_experts=self.n_redundant_experts,
            is_sequence_parallel=self.is_sequence_parallel,
            routing_method_type=RoutingMethodType.Renormalize,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route tokens through the experts and return a same-shaped tensor.

        Args:
            hidden_states: Tensor whose last dimension is the hidden size;
                either 1D (a single token) or 2D (tokens x hidden).

        Returns:
            The expert output viewed back to the shape of ``hidden_states``.
        """
        # NOTE: hidden_states can have either 1D or 2D shape.
        orig_shape = hidden_states.shape
        # Take the hidden size from the last dim and the token count from the
        # flattened 2D view so a 1D input does not fail on tuple unpacking.
        hidden_dim = orig_shape[-1]
        hidden_states = hidden_states.view(-1, hidden_dim)
        num_tokens = hidden_states.shape[0]

        if self.is_sequence_parallel:
            hidden_states = sequence_parallel_chunk(hidden_states)

        if self.experts.is_internal_router:
            # In this case, the gate/router runs inside the FusedMoE class
            final_hidden_states = self.experts(
                hidden_states=hidden_states, router_logits=hidden_states
            )
        else:
            # router_logits: (num_tokens, n_experts)
            router_logits, _ = self.gate(hidden_states)
            final_hidden_states = self.experts(
                hidden_states=hidden_states, router_logits=router_logits
            )

        if self.shared_expert is not None:
            # With a shared expert the fused layer's result is indexed as a
            # pair whose two entries are summed here.
            final_hidden_states = final_hidden_states[0] + final_hidden_states[1]

        if self.is_sequence_parallel:
            final_hidden_states = tensor_model_parallel_all_gather(
                final_hidden_states, 0
            )
            # Keep only the first num_tokens rows of the gathered result
            # (presumably dropping padding added by the chunking — confirm).
            final_hidden_states = final_hidden_states[:num_tokens]
        elif self.tp_size > 1:
            final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(  # noqa E501
                final_hidden_states
            )

        return final_hidden_states.view(orig_shape)


class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
    """Gated delta-rule linear-attention layer with convolution and SSM state.

    The layer projects hidden states into q/k/v/z and b/a streams, runs a
    causal 1D convolution followed by a gated delta-rule recurrence inside the
    ``gdn_attention_core`` custom op, then applies a gated RMS norm and an
    output projection. It registers itself in the compilation config's
    static forward context under ``prefix``.
    """

    @property
    def mamba_type(self) -> str:
        """Identifier string for this layer's state/attention type."""
        return "gdn_attention"

    def get_state_dtype(self) -> tuple[torch.dtype, torch.dtype]:
        """Return the state dtypes computed from model and cache configs."""
        return MambaStateDtypeCalculator.gated_delta_net_state_dtype(
            self.model_config.dtype, self.cache_config.mamba_cache_dtype
        )

    def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
        """Return the state shapes computed from this layer's dimensions."""
        return MambaStateShapeCalculator.gated_delta_net_state_shape(
            self.tp_size,
            self.num_k_heads,
            self.num_v_heads,
            self.head_k_dim,
            self.head_v_dim,
            self.conv_kernel_size,
            self.num_spec,
        )

    def __init__(
        self,
        config: Qwen3NextConfig,
        model_config: ModelConfig | None = None,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        speculative_config: SpeculativeConfig | None = None,
        prefix: str = "",
    ) -> None:
        """Create projections, convolution weights, gating parameters and norm.

        Args:
            config: HF model config providing the ``linear_*`` head sizes,
                hidden size, activation name and RMS-norm epsilon.
            model_config: Stored and used by ``get_state_dtype``.
            cache_config: Stored and used by ``get_state_dtype``.
            quant_config: Passed to the input and output projections.
            speculative_config: When truthy, its ``num_speculative_tokens``
                becomes ``self.num_spec``; otherwise ``num_spec`` is 0.
            prefix: Dotted module name; also the key under which this layer is
                registered in the static forward context.

        Raises:
            ValueError: If ``prefix`` is already registered in the static
                forward context.
        """
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()
        self.hidden_size = config.hidden_size
        self.num_v_heads = config.linear_num_value_heads
        self.num_k_heads = config.linear_num_key_heads
        self.head_k_dim = config.linear_key_head_dim
        self.head_v_dim = config.linear_value_head_dim
        self.key_dim = self.head_k_dim * self.num_k_heads
        self.value_dim = self.head_v_dim * self.num_v_heads

        self.conv_kernel_size = config.linear_conv_kernel_dim
        self.layer_idx = extract_layer_index(prefix)
        self.activation = config.hidden_act
        self.act = ACT2FN[config.hidden_act]
        self.layer_norm_epsilon = config.rms_norm_eps
        self.prefix = prefix

        self.config = config
        self.model_config = model_config
        self.cache_config = cache_config
        self.quant_config = quant_config
        self.speculative_config = speculative_config
        self.num_spec = (
            self.speculative_config.num_speculative_tokens
            if self.speculative_config
            else 0
        )

        # QKV
        self.conv_dim = self.key_dim * 2 + self.value_dim
        # A column-parallel linear layer is used only as a container for the
        # convolution weight; a singleton dim is inserted at position 1 below.
        self.conv1d = ColumnParallelLinear(
            input_size=self.conv_kernel_size,
            output_size=self.conv_dim,
            bias=False,
            prefix=f"{prefix}.conv1d",
        )
        self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)

        # projection of the input hidden states
        self.projection_size_qkvz = self.key_dim * 2 + self.value_dim * 2
        self.projection_size_ba = self.num_v_heads * 2
        self.in_proj_qkvz = ColumnParallelLinear(
            input_size=self.hidden_size,
            output_size=self.projection_size_qkvz,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.in_proj_qkvz",
        )
        # ba_proj doesn't support blockwise fp8 quantization.
        self.in_proj_ba = ColumnParallelLinear(
            input_size=self.hidden_size,
            output_size=self.projection_size_ba,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.in_proj_ba",
        )

        query_key_settings = (self.key_dim, 0, False)
        value_settings = (self.value_dim, 0, False)

        # Replace the default loader on the conv weight with one that shards
        # the concatenated [query, key, value] segments separately.
        delattr(self.conv1d.weight, "weight_loader")
        set_weight_attrs(
            self.conv1d.weight,
            {
                "weight_loader": mamba_v2_sharded_weight_loader(
                    [
                        query_key_settings,
                        query_key_settings,
                        value_settings,
                    ],
                    self.tp_size,
                    self.tp_rank,
                )
            },
        )

        # selective projection used to make dt, B and C input dependent

        # time step projection (discretization)
        # instantiate once and copy inv_dt in init_weights of PretrainedModel
        # `divide` is used for both per-rank parameters so they are sized by
        # the same helper.
        self.dt_bias = nn.Parameter(
            torch.ones(divide(self.num_v_heads, self.tp_size)),
        )
        self.A_log = nn.Parameter(
            torch.empty(
                divide(self.num_v_heads, self.tp_size),
            )
        )

        set_weight_attrs(self.A_log, {"weight_loader": sharded_weight_loader(0)})
        set_weight_attrs(self.dt_bias, {"weight_loader": sharded_weight_loader(0)})

        self.norm = RMSNormGated(
            self.head_v_dim,
            eps=self.layer_norm_epsilon,
            group_size=None,
            norm_before_gate=True,
            device=current_platform.current_device(),
            dtype=config.dtype,
        )

        self.out_proj = RowParallelLinear(
            self.value_dim,
            self.hidden_size,
            bias=False,
            input_is_parallel=True,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
        )

        # Register this layer by name; `_forward_core` later looks up its
        # attention metadata with the same `self.prefix` key.
        compilation_config = get_current_vllm_config().compilation_config
        if prefix in compilation_config.static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        compilation_config.static_forward_context[prefix] = self

    def fix_query_key_value_ordering(
        self,
        mixed_qkvz,
        mixed_ba,
    ):
        """
        Derives `query`, `key`, `value`, `z`, `b` and `a` tensors from the
        interleaved `mixed_qkvz` and `mixed_ba` projections.

        Both inputs are viewed as [..., num_k_heads / tp_size, per-group width]
        and split along the last dimension. The final reshapes keep dim 0 and
        flatten the rest, so 2D inputs (tokens x features) are assumed.
        """
        v_heads_per_k_head = self.num_v_heads // self.num_k_heads

        new_tensor_shape_qkvz = mixed_qkvz.size()[:-1] + (
            self.num_k_heads // self.tp_size,
            (
                self.head_k_dim
                + self.head_k_dim
                + (self.head_v_dim + self.head_v_dim)
                * self.num_v_heads
                // self.num_k_heads
            ),
        )
        # Derive the b/a view shape from mixed_ba itself rather than from
        # mixed_qkvz, so each view depends only on its own tensor.
        new_tensor_shape_ba = mixed_ba.size()[:-1] + (
            self.num_k_heads // self.tp_size,
            2 * self.num_v_heads // self.num_k_heads,
        )

        mixed_qkvz = mixed_qkvz.view(*new_tensor_shape_qkvz)
        mixed_ba = mixed_ba.view(*new_tensor_shape_ba)

        split_arg_list_qkvz = [
            self.head_k_dim,
            self.head_k_dim,
            (v_heads_per_k_head * self.head_v_dim),
            (v_heads_per_k_head * self.head_v_dim),
        ]
        split_arg_list_ba = [
            v_heads_per_k_head,
            v_heads_per_k_head,
        ]

        # [l, ng, (hk + hk + np/ng * hv + np/ng * hv)]
        # --> [l, ng, hk], [l, ng, hk], [l, ng, np/ng * hv], [l, ng, np/ng * hv]
        (query, key, value, z) = torch.split(mixed_qkvz, split_arg_list_qkvz, dim=-1)
        # [l, ng, (np/ng + np/ng)] --> [l, ng, np/ng], [l, ng, np/ng]
        (b, a) = torch.split(mixed_ba, split_arg_list_ba, dim=-1)

        # [l, ng, np/ng * hv] -> [l, np, hv]
        value = value.reshape(value.size(0), -1, self.head_v_dim)
        z = z.reshape(z.size(0), -1, self.head_v_dim)
        b = b.reshape(b.size(0), self.num_v_heads // self.tp_size)
        a = a.reshape(a.size(0), self.num_v_heads // self.tp_size)

        return query, key, value, z, b, a

    def rearrange_mixed_qkv(self, mixed_qkv):
        """Split a fused [l, q|k|v] tensor into contiguous [1, l, h, d] tensors.

        Returns ``(None, None, None)`` when ``mixed_qkv`` is ``None``.
        """
        if mixed_qkv is None:
            return None, None, None
        query, key, value = torch.split(
            mixed_qkv,
            [
                self.key_dim // self.tp_size,
                self.key_dim // self.tp_size,
                self.value_dim // self.tp_size,
            ],
            dim=-1,
        )
        query, key = map(
            lambda x: rearrange(x, "l (h d) -> 1 l h d", d=self.head_k_dim),
            (query, key),
        )
        value = rearrange(value, "l (h d) -> 1 l h d", d=self.head_v_dim)
        return query.contiguous(), key.contiguous(), value.contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        output: torch.Tensor,
    ):
        """
        Forward pass with three parts:
        1. Input projection
        2. Core attention (custom op)
        3. Output projection

        The result is written into ``output[:num_tokens]``; nothing is
        returned.
        """
        num_tokens = hidden_states.size(0)

        # ============================================================
        # Part 1: Input Projection
        # ============================================================
        projected_states_qkvz, _ = self.in_proj_qkvz(hidden_states)
        projected_states_ba, _ = self.in_proj_ba(hidden_states)
        query, key, value, z, b, a = self.fix_query_key_value_ordering(
            projected_states_qkvz, projected_states_ba
        )
        # Flatten heads and concatenate q|k|v into one tensor for the core op.
        query, key, value = map(
            lambda x: rearrange(x, "l p d -> l (p d)"), (query, key, value)
        )
        mixed_qkv = torch.cat((query, key, value), dim=-1)

        # ============================================================
        # Part 2: Core Attention (Custom Op)
        # ============================================================
        # Note: we should not use torch.empty here like other attention backends,
        # see discussions in https://github.com/vllm-project/vllm/pull/28182
        core_attn_out = torch.zeros(
            (num_tokens, self.num_v_heads // self.tp_size, self.head_v_dim),
            dtype=hidden_states.dtype,
            device=hidden_states.device,
        )

        # The op receives this layer's prefix and fills core_attn_out in place.
        torch.ops.vllm.gdn_attention_core(
            mixed_qkv,
            b,
            a,
            core_attn_out,
            self.prefix,
        )

        # ============================================================
        # Part 3: Output Projection
        # ============================================================
        z_shape_og = z.shape
        # Reshape input data into 2D tensor
        core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1])
        z = z.reshape(-1, z.shape[-1])
        core_attn_out = self.norm(core_attn_out, z)
        core_attn_out = core_attn_out.reshape(z_shape_og)
        core_attn_out = rearrange(core_attn_out, "... h d -> ... (h d)")
        output[:num_tokens], _ = self.out_proj(core_attn_out)

    def _forward_core(
        self,
        mixed_qkv: torch.Tensor,
        b: torch.Tensor,
        a: torch.Tensor,
        core_attn_out: torch.Tensor,
    ):
        """
        Core attention computation (called by custom op).

        Reads this layer's metadata from the forward context, runs the causal
        convolution and the gated delta-rule recurrence separately for the
        speculative ("spec") and remaining ("non-spec") tokens, and writes the
        merged result into ``core_attn_out[:num_actual_tokens]``. Returns
        without writing anything when no attention metadata is present.
        """
        forward_context = get_forward_context()
        attn_metadata: AttentionMetadata = forward_context.attn_metadata

        if attn_metadata is None:
            # V1 profile run
            return

        # Metadata is a dict keyed by layer prefix.
        assert isinstance(attn_metadata, dict)
        attn_metadata = attn_metadata[self.prefix]
        assert isinstance(attn_metadata, GDNAttentionMetadata)
        has_initial_state = attn_metadata.has_initial_state
        spec_query_start_loc = attn_metadata.spec_query_start_loc
        non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc
        spec_sequence_masks = attn_metadata.spec_sequence_masks
        spec_token_indx = attn_metadata.spec_token_indx
        non_spec_token_indx = attn_metadata.non_spec_token_indx
        spec_state_indices_tensor = attn_metadata.spec_state_indices_tensor  # noqa: E501
        non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor  # noqa: E501
        # kv_cache is not assigned in this class's visible code; presumably it
        # is bound externally before the first forward — confirm.
        self_kv_cache = self.kv_cache[forward_context.virtual_engine]
        conv_state = self_kv_cache[0].transpose(-1, -2)
        ssm_state = self_kv_cache[1]
        num_actual_tokens = attn_metadata.num_actual_tokens
        num_accepted_tokens = attn_metadata.num_accepted_tokens

        # Trim inputs to the actual token count.
        mixed_qkv = mixed_qkv[:num_actual_tokens]
        b = b[:num_actual_tokens]
        a = a[:num_actual_tokens]

        # 1. Convolution sequence transformation
        # Drop the singleton dim inserted in __init__.
        conv_weights = self.conv1d.weight.view(
            self.conv1d.weight.size(0), self.conv1d.weight.size(2)
        )

        if spec_sequence_masks is not None:
            if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0:
                # Every token is speculative; no gather needed.
                mixed_qkv_spec = mixed_qkv
                mixed_qkv_non_spec = None
            else:
                mixed_qkv_spec = mixed_qkv.index_select(0, spec_token_indx)
                mixed_qkv_non_spec = mixed_qkv.index_select(0, non_spec_token_indx)
        else:
            mixed_qkv_spec = None
            mixed_qkv_non_spec = mixed_qkv

        # 1.1: Process the multi-query part
        if spec_sequence_masks is not None:
            mixed_qkv_spec = causal_conv1d_update(
                mixed_qkv_spec,
                conv_state,
                conv_weights,
                self.conv1d.bias,
                self.activation,
                conv_state_indices=spec_state_indices_tensor[:, 0][
                    : attn_metadata.num_spec_decodes
                ],
                num_accepted_tokens=num_accepted_tokens,
                query_start_loc=spec_query_start_loc,
                max_query_len=spec_state_indices_tensor.size(-1),
                validate_data=False,
            )

        # 1.2: Process the remaining part
        if attn_metadata.num_prefills > 0:
            mixed_qkv_non_spec_T = mixed_qkv_non_spec.transpose(0, 1)
            # - "cache_indices" updates the conv_state cache in positions
            #   pointed to by "state_indices_tensor"
            mixed_qkv_non_spec = causal_conv1d_fn(
                mixed_qkv_non_spec_T,
                conv_weights,
                self.conv1d.bias,
                activation=self.activation,
                conv_states=conv_state,
                has_initial_state=has_initial_state,
                cache_indices=non_spec_state_indices_tensor,
                query_start_loc=non_spec_query_start_loc,
                metadata=attn_metadata,
            ).transpose(0, 1)
        elif attn_metadata.num_decodes > 0:
            mixed_qkv_non_spec = causal_conv1d_update(
                mixed_qkv_non_spec,
                conv_state,
                conv_weights,
                self.conv1d.bias,
                self.activation,
                conv_state_indices=non_spec_state_indices_tensor[
                    : attn_metadata.num_actual_tokens
                ],
                validate_data=True,
            )
        else:
            mixed_qkv_non_spec = None

        query_spec, key_spec, value_spec = self.rearrange_mixed_qkv(mixed_qkv_spec)
        query_non_spec, key_non_spec, value_non_spec = self.rearrange_mixed_qkv(
            mixed_qkv_non_spec
        )

        # fused_gdn_gating is defined elsewhere in this module; its outputs
        # are indexed along dim 1 by token below.
        g, beta = fused_gdn_gating(self.A_log, a, b, self.dt_bias)

        if spec_sequence_masks is not None:
            if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0:
                g_spec = g
                beta_spec = beta
                g_non_spec = None
                beta_non_spec = None
            else:
                g_spec = g.index_select(1, spec_token_indx)
                beta_spec = beta.index_select(1, spec_token_indx)
                g_non_spec = g.index_select(1, non_spec_token_indx)
                beta_non_spec = beta.index_select(1, non_spec_token_indx)
        else:
            g_spec = None
            beta_spec = None
            g_non_spec = g
            beta_non_spec = beta

        # 2. Recurrent attention

        # 2.1: Process the multi-query part
        if spec_sequence_masks is not None:
            core_attn_out_spec, last_recurrent_state = fused_recurrent_gated_delta_rule(
                q=query_spec,
                k=key_spec,
                v=value_spec,
                g=g_spec,
                beta=beta_spec,
                initial_state=ssm_state,
                inplace_final_state=True,
                cu_seqlens=spec_query_start_loc[: attn_metadata.num_spec_decodes + 1],
                ssm_state_indices=spec_state_indices_tensor,
                num_accepted_tokens=num_accepted_tokens,
                use_qk_l2norm_in_kernel=True,
            )
        else:
            core_attn_out_spec, last_recurrent_state = None, None

        # 2.2: Process the remaining part
        if attn_metadata.num_prefills > 0:
            # Gather a copy of the per-sequence states and zero the entries
            # that have no initial state.
            initial_state = ssm_state[non_spec_state_indices_tensor].contiguous()
            initial_state[~has_initial_state, ...] = 0
            (
                core_attn_out_non_spec,
                last_recurrent_state,
            ) = chunk_gated_delta_rule(
                q=query_non_spec,
                k=key_non_spec,
                v=value_non_spec,
                g=g_non_spec,
                beta=beta_non_spec,
                initial_state=initial_state,
                output_final_state=True,
                cu_seqlens=non_spec_query_start_loc,
                head_first=False,
                use_qk_l2norm_in_kernel=True,
            )
            # Init cache
            ssm_state[non_spec_state_indices_tensor] = last_recurrent_state.to(
                ssm_state.dtype
            )
        elif attn_metadata.num_decodes > 0:
            core_attn_out_non_spec, last_recurrent_state = (
                fused_recurrent_gated_delta_rule(
                    q=query_non_spec,
                    k=key_non_spec,
                    v=value_non_spec,
                    g=g_non_spec,
                    beta=beta_non_spec,
                    initial_state=ssm_state,
                    inplace_final_state=True,
                    cu_seqlens=non_spec_query_start_loc[
                        : attn_metadata.num_decodes + 1
                    ],
                    ssm_state_indices=non_spec_state_indices_tensor,
                    use_qk_l2norm_in_kernel=True,
                )
            )
        else:
            core_attn_out_non_spec, last_recurrent_state = None, None

        # 3. Merge core attention output
        if spec_sequence_masks is not None and core_attn_out_non_spec is not None:
            # Scatter both partial outputs back to their original positions.
            merged_out = torch.empty(
                (1, num_actual_tokens, *core_attn_out_spec.shape[2:]),
                dtype=core_attn_out_non_spec.dtype,
                device=core_attn_out_non_spec.device,
            )
            merged_out.index_copy_(1, spec_token_indx, core_attn_out_spec)
            merged_out.index_copy_(1, non_spec_token_indx, core_attn_out_non_spec)
            core_attn_out[:num_actual_tokens] = merged_out.squeeze(0)
        elif spec_sequence_masks is not None:
            core_attn_out[:num_actual_tokens] = core_attn_out_spec.squeeze(0)
        else:
            core_attn_out[:num_actual_tokens] = core_attn_out_non_spec.squeeze(0)


class Qwen3NextAttention(nn.Module):
    """Full-attention layer with per-head q/k RMS norm and optional output gate.

    When ``attn_output_gate`` is enabled the QKV projection emits twice as
    many query channels; the extra half is passed through a sigmoid and
    multiplied into the attention output before the output projection.
    """

    def __init__(
        self,
        config: Qwen3NextConfig,
        model_config: ModelConfig | None = None,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        """Build the projections, rotary embedding, attention and norms.

        Args:
            config: HF model config (head counts, head dim, RoPE parameters,
                RMS-norm epsilon and optional gate/bias/dual-chunk settings).
            model_config: Accepted for signature parity; not used here.
            cache_config: Forwarded to the ``Attention`` layer.
            quant_config: Forwarded to the linear and attention layers.
            prefix: Dotted module-name prefix used to name the sub-layers.
        """
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = config.num_attention_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = config.num_key_value_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        # The fallback head size must use the *total* head count: dividing by
        # the per-rank count would inflate head_dim by tp_size when
        # config.head_dim is unset and tensor parallelism is enabled.
        self.head_dim = config.head_dim or (
            self.hidden_size // self.total_num_heads
        )
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.dual_chunk_attention_config = getattr(
            config, "dual_chunk_attention_config", None
        )
        self.attn_output_gate = getattr(config, "attn_output_gate", True)

        # With the output gate on, the query head count is doubled so the
        # projection yields query and gate channels together.
        self.qkv_proj = QKVParallelLinear(
            config.hidden_size,
            self.head_dim,
            self.total_num_heads * (1 + self.attn_output_gate),
            self.total_num_kv_heads,
            bias=getattr(config, "qkv_bias", False),
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            config.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self.rotary_emb = get_rope(
            head_size=self.head_dim,
            max_position=config.max_position_embeddings,
            rope_parameters=config.rope_parameters,
            dual_chunk_attention_config=self.dual_chunk_attention_config,
        )

        # Extra keyword arguments are passed only when dual-chunk attention
        # is configured.
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
            **{
                "layer_idx": extract_layer_index(prefix),
                "dual_chunk_attention_config": self.dual_chunk_attention_config,
            }
            if self.dual_chunk_attention_config
            else {},
        )

        self.q_norm = Qwen3NextRMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = Qwen3NextRMSNorm(self.head_dim, eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        output: torch.Tensor,
        hidden_states: torch.Tensor,
    ):
        """Run attention and write the projected result into ``output``.

        Args:
            positions: Position tensor passed to the rotary embedding.
            output: Destination tensor; fully overwritten via ``output[:]``.
            hidden_states: Input activations for the QKV projection.
        """
        qkv, _ = self.qkv_proj(hidden_states)

        if self.attn_output_gate:
            q_gate, k, v = qkv.split(
                [self.q_size * 2, self.kv_size, self.kv_size], dim=-1
            )
            # Query and gate are interleaved per head: view per head, then
            # halve each head's channels into (q, gate).
            orig_shape = q_gate.shape[:-1]
            q_gate = q_gate.view(*orig_shape, self.num_heads, -1)
            q, gate = torch.chunk(q_gate, 2, dim=-1)
            q = q.reshape(*orig_shape, -1)
            gate = gate.reshape(*orig_shape, -1)
        else:
            q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

        # Normalize q and k per head, then flatten heads back.
        q = self.q_norm(q.view(-1, self.num_heads, self.head_dim)).view(
            -1, self.num_heads * self.head_dim
        )
        k = self.k_norm(k.view(-1, self.num_kv_heads, self.head_dim)).view(
            -1, self.num_kv_heads * self.head_dim
        )

        q, k = self.rotary_emb(positions, q, k)

        attn_output = self.attn(q, k, v)

        if self.attn_output_gate:
            gate = torch.sigmoid(gate)
            attn_output = attn_output * gate

        output[:], _ = self.o_proj(attn_output)


class Qwen3NextDecoderLayer(nn.Module):
    """One Qwen3Next decoder layer: token mixer followed by an MLP.

    The token mixer is either a gated-delta-net block (``linear_attention``)
    or a full attention block (``full_attention``). The MLP is a sparse MoE
    block or a dense MLP depending on the layer index and the HF config.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        layer_type: str,
        prefix: str = "",
    ) -> None:
        """Build the sub-modules for this layer.

        Args:
            vllm_config: Global configuration; the HF, model, cache, quant and
                speculative configs are read from it.
            layer_type: ``"linear_attention"`` or ``"full_attention"``.
            prefix: Module prefix; the layer index is extracted from it.

        Raises:
            ValueError: If ``layer_type`` is not one of the two known values.
        """
        super().__init__()

        model_config = vllm_config.model_config
        config = model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        speculative_config = vllm_config.speculative_config

        self.layer_type = layer_type
        self.layer_idx = extract_layer_index(prefix)

        # Token mixer.
        if self.layer_type == "linear_attention":
            self.linear_attn = Qwen3NextGatedDeltaNet(
                config,
                model_config=model_config,
                cache_config=cache_config,
                quant_config=quant_config,
                speculative_config=speculative_config,
                prefix=f"{prefix}.linear_attn",
            )
        elif self.layer_type == "full_attention":
            self.self_attn = Qwen3NextAttention(
                config,
                model_config=model_config,
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=f"{prefix}.self_attn",
            )
        else:
            raise ValueError(f"Invalid layer_type {self.layer_type}")

        # Feed-forward: sparse MoE on every `decoder_sparse_step`-th layer
        # unless the layer is listed as dense-only.
        dense_only_layers = getattr(config, "mlp_only_layers", [])
        use_sparse_moe = (
            self.layer_idx not in dense_only_layers
            and config.num_experts > 0
            and (self.layer_idx + 1) % config.decoder_sparse_step == 0
        )
        if use_sparse_moe:
            self.mlp = Qwen3NextSparseMoeBlock(
                vllm_config=vllm_config,
                prefix=f"{prefix}.mlp",
            )
        else:
            self.mlp = Qwen3NextMLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                prefix=f"{prefix}.mlp",
            )

        norm_eps = config.rms_norm_eps
        self.input_layernorm = Qwen3NextRMSNorm(config.hidden_size, eps=norm_eps)
        self.post_attention_layernorm = Qwen3NextRMSNorm(
            config.hidden_size, eps=norm_eps
        )

        # Optional learned per-channel scales, stored as an offset from 1.
        self.layer_scale = getattr(config, "layer_scale", False)
        if self.layer_scale:
            scale_shape = (1, 1, config.hidden_size)
            self.attn_layer_scale = torch.nn.Parameter(
                torch.zeros(*scale_shape, dtype=config.dtype),
            )
            self.ffn_layer_scale = torch.nn.Parameter(
                torch.zeros(*scale_shape, dtype=config.dtype),
            )

    @staticmethod
    def _apply_layer_scale(
        hidden_states: torch.Tensor, scale: torch.Tensor
    ) -> torch.Tensor:
        """Multiply ``hidden_states`` by ``scale + 1`` in its own dtype.

        For 2-D activations the leading axis of ``scale`` is dropped first.
        """
        factor = scale.to(hidden_states.dtype)
        if hidden_states.ndim == 2:
            factor = factor[0]
        return hidden_states * (factor + 1)

    def forward(
        self,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
        positions: torch.Tensor = None,
        **kwargs: object,
    ):
        """Apply the token mixer and the MLP, carrying the residual stream.

        Args:
            hidden_states: Input activations.
            residual: Running residual, or ``None`` for the first layer.
            positions: Position tensor; only passed to the full-attention
                mixer.
            **kwargs: Accepted and ignored.

        Returns:
            Tuple ``(hidden_states, residual)`` for the next layer.

        Raises:
            ValueError: If ``self.layer_type`` is not a known value.
        """
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)

        # Both mixers write their result into a caller-provided buffer.
        mixer_out = torch.empty_like(hidden_states)
        if self.layer_type == "linear_attention":
            self.linear_attn(hidden_states=hidden_states, output=mixer_out)
        elif self.layer_type == "full_attention":
            self.self_attn(
                hidden_states=hidden_states,
                output=mixer_out,
                positions=positions,
            )
        else:
            raise ValueError("Invalid layer_type")
        hidden_states = mixer_out

        if self.layer_scale:
            hidden_states = self._apply_layer_scale(
                hidden_states, self.attn_layer_scale
            )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)

        if self.layer_scale:
            if hidden_states.ndim != 2:
                assert hidden_states.ndim == self.ffn_layer_scale.ndim, (
                    f"shape must be the same {len(hidden_states.shape)}, "
                    f"{len(self.ffn_layer_scale.shape)}"
                )
            hidden_states = self._apply_layer_scale(
                hidden_states, self.ffn_layer_scale
            )

        return hidden_states, residual


@support_torch_compile
class Qwen3NextModel(nn.Module):
    """Qwen3Next decoder stack: token embedding, decoder layers, final norm.

    The layers are created through ``make_layers``; only the slice
    ``[start_layer, end_layer)`` is executed in ``forward``. The final norm
    exists only on the last pipeline-parallel rank.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Create the embedding, the decoder layers and the final norm.

        Args:
            vllm_config: Global configuration; the HF config and the parallel
                (EPLB) config are read from it.
            prefix: Module prefix used to name the decoder layers.
        """
        super().__init__()

        config: Qwen3NextConfig = vllm_config.model_config.hf_config
        parallel_config = vllm_config.parallel_config

        eplb_config = parallel_config.eplb_config
        self.num_redundant_experts = eplb_config.num_redundant_experts

        self.config = config

        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
        )

        def get_layer(prefix: str):
            # The layer kind is looked up by the index encoded in the prefix.
            return Qwen3NextDecoderLayer(
                vllm_config,
                layer_type=config.layer_types[extract_layer_index(prefix)],
                prefix=prefix,
            )

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers"
        )
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size
        )

        if get_pp_group().is_last_rank:
            self.norm = Qwen3NextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer()

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the local slice of decoder layers.

        Args:
            input_ids: Token ids; embedded on the first pipeline rank unless
                ``inputs_embeds`` is given.
            positions: Position tensor forwarded to every layer.
            intermediate_tensors: ``hidden_states``/``residual`` received from
                the previous pipeline rank; required on non-first ranks.
            inputs_embeds: Optional pre-computed embeddings.

        Returns:
            The normalised hidden states on the last pipeline rank, otherwise
            an ``IntermediateTensors`` holding ``hidden_states`` and
            ``residual`` for the next rank.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.embed_input_ids(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        for layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states, residual = layer(
                positions=positions,
                hidden_states=hidden_states,
                residual=residual,
            )

        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Return the expert mapping used by ``load_weights``.

        Each entry is ``(param_name, weight_name, expert_id, shard_id)``.
        """
        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        return SharedFusedMoE.make_expert_params_mapping(
            self,
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.num_experts,
            num_redundant_experts=self.num_redundant_experts,
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Each checkpoint entry is tried, in order, against the stacked
        (fused QKV / gate-up) mapping, the expert mapping, and finally a
        direct by-name lookup. Rotary ``inv_freq`` buffers and ``mtp.``
        weights are skipped.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        Returns:
            The set of parameter names that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        expert_params_mapping = self.get_expert_mapping()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            if name.startswith("mtp."):
                continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue

                if "mlp.experts" in name:
                    continue

                # Keep the checkpoint name untouched until a candidate is
                # accepted: a rejected candidate must not leak a rewritten
                # key into later mapping entries or the fallback paths
                # (e.g. "gate_up_proj" also contains "up_proj").
                mapped_name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if mapped_name.endswith(".bias") and mapped_name not in params_dict:
                    continue
                # Skip layers on other devices.
                if is_pp_missing_parameter(mapped_name, self):
                    continue
                if mapped_name not in params_dict:
                    continue
                param = params_dict[mapped_name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                name = mapped_name
                break
            else:
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
                    # Same rule as above: do not modify `name` while the
                    # loop may still continue.
                    mapped_name = name.replace(weight_name, param_name)
                    # Skip layers on other devices.
                    if is_pp_missing_parameter(mapped_name, self):
                        continue
                    # Skip loading extra bias for GPTQ models.
                    if (
                        mapped_name.endswith((".bias", "_bias"))
                        and mapped_name not in params_dict
                    ):
                        continue
                    if mapped_name not in params_dict:
                        continue
                    param = params_dict[mapped_name]
                    weight_loader = param.weight_loader
                    weight_loader(
                        param,
                        loaded_weight,
                        mapped_name,
                        shard_id=shard_id,
                        expert_id=expert_id,
                    )
                    name = mapped_name
                    break
                else:
                    # Neither mapping applied: load by the checkpoint name.
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    if is_pp_missing_parameter(name, self):
                        continue
                    if name not in params_dict:
                        logger.warning_once(
                            f"Parameter {name} not found in params_dict, skip loading"
                        )
                        continue
                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
class QwenNextMixtureOfExperts(MixtureOfExperts):
    """Mixin that tracks MoE expert counts for a Qwen3Next model.

    It expects ``self.model.layers`` to be iterable and treats entries that
    are ``Qwen3NextDecoderLayer`` with a ``Qwen3NextSparseMoeBlock`` MLP as
    MoE layers.
    """

    def update_physical_experts_metadata(
        self,
        num_physical_experts: int,
        num_local_physical_experts: int,
    ) -> None:
        """Propagate new physical expert counts to every MoE layer.

        Args:
            num_physical_experts: New total number of physical experts.
            num_local_physical_experts: Number of local physical experts;
                must equal the currently stored value.
        """
        assert self.num_local_physical_experts == num_local_physical_experts
        self.num_physical_experts = num_physical_experts
        self.num_local_physical_experts = num_local_physical_experts
        self.num_redundant_experts = num_physical_experts - self.num_logical_experts
        for layer in self.model.layers:
            # Same guard as set_moe_parameters: entries that are not decoder
            # layers have no `mlp` attribute, so reading it would raise.
            if not isinstance(layer, Qwen3NextDecoderLayer):
                continue
            if isinstance(layer.mlp, Qwen3NextSparseMoeBlock):
                moe = layer.mlp
                moe.n_local_physical_experts = num_local_physical_experts
                moe.n_physical_experts = num_physical_experts
                moe.n_redundant_experts = self.num_redundant_experts
                moe.experts.update_expert_map()

    def set_moe_parameters(self):
        """Collect the MoE layers and copy their expert counts onto ``self``.

        Raises:
            RuntimeError: If no sparse MoE block is found in
                ``self.model.layers``.
        """
        self.expert_weights = []

        self.moe_layers = []
        example_moe = None
        for layer in self.model.layers:
            if isinstance(layer, Qwen3NextDecoderLayer) and isinstance(
                layer.mlp, Qwen3NextSparseMoeBlock
            ):
                # Any MoE block can serve as the template; the last one wins.
                example_moe = layer.mlp
                self.moe_layers.append(layer.mlp.experts)

        if example_moe is None:
            raise RuntimeError("No Qwen3Next layer found in the model.layers.")

        # Set MoE hyperparameters
        self.num_moe_layers = len(self.moe_layers)
        self.num_expert_groups = 1
        self.num_shared_experts = 0
        self.num_logical_experts = example_moe.n_logical_experts
        self.num_physical_experts = example_moe.n_physical_experts
        self.num_local_physical_experts = example_moe.n_local_physical_experts
        self.num_routed_experts = example_moe.n_routed_experts
        self.num_redundant_experts = example_moe.n_redundant_experts


1179
class Qwen3NextForCausalLM(
    nn.Module,
    HasInnerState,
    SupportsLoRA,
    SupportsPP,
    QwenNextMixtureOfExperts,
    IsHybrid,
):
    """Qwen3Next causal language model: decoder stack plus LM head."""

    # Fused module name -> checkpoint sub-module names packed into it.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the decoder stack, the LM head and the logits processor.

        Args:
            vllm_config: Global configuration. Prefix caching must be
                disabled in its cache config.
            prefix: Module prefix for the sub-modules.
        """
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        hf_config = vllm_config.model_config.hf_config
        sched_config = vllm_config.scheduler_config

        assert not vllm_config.cache_config.enable_prefix_caching, (
            "Qwen3Next currently does not support prefix caching"
        )
        self.quant_config = vllm_config.quant_config

        super().__init__()
        self.config = hf_config
        self.scheduler_config = sched_config

        self.model = Qwen3NextModel(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )
        self.lm_head = ParallelLMHead(
            hf_config.vocab_size,
            hf_config.hidden_size,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        self.logits_processor = LogitsProcessor(hf_config.vocab_size)

        # Re-export the inner model's factory under the same name.
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

        # Set MoE hyperparameters
        self.set_moe_parameters()

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the token embeddings produced by the inner model."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ):
        """Run the inner model and return whatever it returns.

        Extra keyword arguments are accepted and ignored.
        """
        return self.model(input_ids, positions, intermediate_tensors, inputs_embeds)

    @classmethod
    def get_mamba_state_dtype_from_config(
        cls,
        vllm_config: "VllmConfig",
    ) -> tuple[torch.dtype, torch.dtype]:
        """Return the gated-delta-net state dtypes for ``vllm_config``.

        They are derived from the model dtype and the cache config's
        ``mamba_cache_dtype``.
        """
        model_dtype = vllm_config.model_config.dtype
        cache_dtype = vllm_config.cache_config.mamba_cache_dtype
        return MambaStateDtypeCalculator.gated_delta_net_state_dtype(
            model_dtype, cache_dtype
        )

    @classmethod
    def get_mamba_state_shape_from_config(
        cls, vllm_config: "VllmConfig"
    ) -> tuple[tuple[int, int], tuple[int, int]]:
        """Return the gated-delta-net state shapes for ``vllm_config``.

        The shapes depend on the tensor-parallel size, the linear-attention
        head counts and dims from the HF config, and the number of
        speculative tokens (0 when speculation is not configured).
        """
        parallel_config = vllm_config.parallel_config
        hf_config = vllm_config.model_config.hf_config
        tp_size = parallel_config.tensor_parallel_size

        if vllm_config.speculative_config:
            num_spec = vllm_config.speculative_config.num_speculative_tokens
        else:
            num_spec = 0

        return MambaStateShapeCalculator.gated_delta_net_state_shape(
            tp_size,
            hf_config.linear_num_key_heads,
            hf_config.linear_num_value_heads,
            hf_config.linear_key_head_dim,
            hf_config.linear_value_head_dim,
            hf_config.linear_conv_kernel_dim,
            num_spec,
        )

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Project ``hidden_states`` through the LM head into logits."""
        return self.logits_processor(self.lm_head, hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, skipping everything under ``mtp.``.

        Returns:
            Whatever ``AutoWeightsLoader.load_weights`` returns for the
            given weights.
        """
        return AutoWeightsLoader(self, skip_prefixes=["mtp."]).load_weights(weights)

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Return the inner model's expert parameter mapping."""
        return self.model.get_expert_mapping()


1293
1294
1295
1296
1297
def gdn_attention_core(
    mixed_qkv: torch.Tensor,
    b: torch.Tensor,
    a: torch.Tensor,
    core_attn_out: torch.Tensor,
    layer_name: str,
) -> None:
    """
    Custom-op entry point for the core attention computation.

    Looks up the layer registered under ``layer_name`` in the current forward
    context and delegates to its ``_forward_core``. Only the convolution +
    recurrent attention part runs here; input/output projections are handled
    outside this op. The result is delivered through ``core_attn_out``.
    """
    layer = get_forward_context().no_compile_layers[layer_name]
    layer._forward_core(
        mixed_qkv=mixed_qkv,
        b=b,
        a=a,
        core_attn_out=core_attn_out,
    )


def gdn_attention_core_fake(
    mixed_qkv: torch.Tensor,
    b: torch.Tensor,
    a: torch.Tensor,
    core_attn_out: torch.Tensor,
    layer_name: str,
) -> None:
    """Fake implementation for torch.compile: does nothing, returns None."""
    return None


# Register the gated-delta-net core as a custom op. `core_attn_out` is
# declared as mutated because the op delivers its result through it, and the
# fake implementation is the no-op stand-in defined above.
direct_register_custom_op(
    op_name="gdn_attention_core",
    op_func=gdn_attention_core,
    fake_impl=gdn_attention_core_fake,
    mutates_args=["core_attn_out"],
)


# Triton kernel computing, per (batch, sequence position, block of heads):
#   g           = -exp(A_log) * softplus(a + dt_bias)
#   beta_output = sigmoid(b)
# `beta` and `threshold` parameterise the softplus; above the threshold the
# softplus input is passed through unchanged.
@triton.jit
def fused_gdn_gating_kernel(
    g,
    beta_output,
    A_log,
    a,
    b,
    dt_bias,
    seq_len,
    NUM_HEADS: tl.constexpr,
    beta: tl.constexpr,
    threshold: tl.constexpr,
    BLK_HEADS: tl.constexpr,
):
    pid_batch = tl.program_id(0)
    pid_seq = tl.program_id(1)
    pid_head_blk = tl.program_id(2)

    # Heads handled by this program; the tail block may be partially filled.
    head_idx = pid_head_blk * BLK_HEADS + tl.arange(0, BLK_HEADS)
    valid = head_idx < NUM_HEADS
    # Flat offset into the (batch, seq_len, NUM_HEADS)-indexed buffers.
    flat_idx = pid_batch * seq_len * NUM_HEADS + pid_seq * NUM_HEADS + head_idx

    # A_log and dt_bias are indexed per head; a and b per element.
    a_log_vals = tl.load(A_log + head_idx, mask=valid)
    a_vals = tl.load(a + flat_idx, mask=valid)
    b_vals = tl.load(b + flat_idx, mask=valid)
    bias_vals = tl.load(dt_bias + head_idx, mask=valid)

    # If the model is loaded in fp16, without the .float() here, A might be -inf
    x = a_vals.to(tl.float32) + bias_vals.to(tl.float32)
    softplus_x = tl.where(
        beta * x <= threshold, (1 / beta) * tl.log(1 + tl.exp(beta * x)), x
    )
    g_vals = -tl.exp(a_log_vals.to(tl.float32)) * softplus_x
    tl.store(g + flat_idx, g_vals.to(g.dtype.element_ty), mask=valid)

    # compute beta_output = sigmoid(b)
    beta_vals = tl.sigmoid(b_vals.to(tl.float32))
    tl.store(
        beta_output + flat_idx, beta_vals.to(beta_output.dtype.element_ty), mask=valid
    )
1368
1369
1370
1371
1372


def fused_gdn_gating(
    A_log: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    dt_bias: torch.Tensor,
    beta: float = 1.0,
    threshold: float = 20.0,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Fused computation of g and beta for Gated Delta Net.
    g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
    beta_output = b.sigmoid()

    ``a`` must be 2-D ``(batch, num_heads)``. Both results have shape
    ``(1, batch, num_heads)``; ``g`` is float32 on ``a``'s device and
    ``beta_output`` keeps ``b``'s dtype and device.

    TODO maybe use torch.compile to replace this triton kernel
    """
    batch, num_heads = a.shape
    seq_len = 1
    heads_per_program = 8  # BLK_HEADS of the kernel; also sizes the grid
    out_shape = (1, batch, num_heads)

    g = torch.empty(out_shape, dtype=torch.float32, device=a.device)
    beta_output = torch.empty(out_shape, dtype=b.dtype, device=b.device)

    launch_grid = (batch, seq_len, triton.cdiv(num_heads, heads_per_program))
    fused_gdn_gating_kernel[launch_grid](
        g,
        beta_output,
        A_log,
        a,
        b,
        dt_bias,
        seq_len,
        num_heads,
        beta,
        threshold,
        heads_per_program,
        num_warps=1,
    )
    return g, beta_output