Unverified Commit 30c6e1f5 authored by Yi Zhang and committed by GitHub

Qwen3-Next support (#10233)


Co-authored-by: cao1zhg <114661107+cao1zhg@users.noreply.github.com>
Co-authored-by: ispobock <ispobaoke@gmail.com>
Co-authored-by: Binyao Jiang <byjiang1996@gmail.com>
Co-authored-by: hebiao064 <hebiaobuaa@gmail.com>
Co-authored-by: Lifu Huang <lifu.hlf@gmail.com>
Co-authored-by: qingquansong <ustcsqq@gmail.com>
Co-authored-by: Yaoyao Ding <dingyaoyao.cs@gmail.com>
Co-authored-by: Ke Bao <ISPObaoke@163.com>
Co-authored-by: Minglei Zhu <mingleizhu1122@gmail.com>
parent bfe01a5e
......@@ -6,6 +6,7 @@ from sglang.srt.configs.janus_pro import MultiModalityConfig
from sglang.srt.configs.kimi_vl import KimiVLConfig
from sglang.srt.configs.kimi_vl_moonvit import MoonViTConfig
from sglang.srt.configs.longcat_flash import LongcatFlashConfig
from sglang.srt.configs.qwen3_next import Qwen3NextConfig
from sglang.srt.configs.step3_vl import (
Step3TextConfig,
Step3VisionEncoderConfig,
......@@ -24,4 +25,5 @@ __all__ = [
"Step3VLConfig",
"Step3TextConfig",
"Step3VisionEncoderConfig",
"Qwen3NextConfig",
]
......@@ -147,6 +147,9 @@ class ModelConfig:
):
self.hf_config.architectures[0] = "Ernie4_5_MoeForCausalLMMTP"
if is_draft_model and self.hf_config.architectures[0] == "Qwen3NextForCausalLM":
self.hf_config.architectures[0] = "Qwen3NextForCausalLMMTP"
# Check model type
self.is_generation = is_generation_model(
self.hf_config.architectures, is_embedding
......
# coding=utf-8
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Qwen3Hybrid model configuration"""
import enum
import os
import numpy as np
import torch
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_rope_utils import rope_config_validation
from transformers.utils import logging
from sglang.srt.distributed.utils import divide
from sglang.srt.layers.dp_attention import get_attention_tp_size
logger = logging.get_logger(__name__)
# NOTE: HybridLayerType
class HybridLayerType(enum.Enum):
full_attention = "attention"
swa_attention = "swa_attention"
linear_attention = "linear_attention"
mamba2 = "mamba"
class Qwen3NextConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`Qwen3NextModel`]. It is used to instantiate a
Qwen3-Next model according to the specified arguments, defining the model architecture.
Instantiating a configuration with the defaults will yield a similar configuration to that of
Qwen3-Next-80B-A3B-Instruct [Qwen/Qwen3-Next-80B-A3B-Instruct](https://huggingface.co/Qwen/Qwen3-Next-80B-A3B-Instruct).
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 151936):
Vocabulary size of the model. Defines the number of different tokens that can be represented by the
`inputs_ids`.
hidden_size (`int`, *optional*, defaults to 2048):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 5632):
Dimension of the MLP representations.
num_hidden_layers (`int`, *optional*, defaults to 48):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 16):
Number of attention heads for each attention layer in the Transformer encoder.
num_key_value_heads (`int`, *optional*, defaults to 2):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details checkout [this
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `2`.
hidden_act (`str`, *optional*, defaults to `"silu"`):
The non-linear activation function in the decoder.
max_position_embeddings (`int`, *optional*, defaults to 32768):
The maximum sequence length that this model might ever be used with.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the rms normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether the model's input and output word embeddings should be tied.
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
rope_scaling (`Dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
accordingly.
Expected contents:
`rope_type` (`str`):
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
'llama3'], with 'default' being the original RoPE implementation.
`factor` (`float`, *optional*):
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
original maximum pre-trained length.
`original_max_position_embeddings` (`int`, *optional*):
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
pretraining.
`attention_factor` (`float`, *optional*):
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
computation. If unspecified, it defaults to value recommended by the implementation, using the
`factor` field to infer the suggested value.
`beta_fast` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
ramp function. If unspecified, it defaults to 32.
`beta_slow` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
ramp function. If unspecified, it defaults to 1.
`short_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`long_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to long contexts (>
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`low_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
`high_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
partial_rotary_factor (`float`, *optional*, defaults to 0.25):
Percentage of the query and keys which will have rotary embedding.
attention_bias (`bool`, *optional*, defaults to `False`):
Whether to use a bias in the query, key, value and output projection layers during self-attention.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
head_dim (`int`, *optional*, defaults to 256):
Projection weights dimension in multi-head attention.
linear_conv_kernel_dim (`int`, *optional*, defaults to 4):
Kernel size of the convolution used in linear attention layers.
linear_key_head_dim (`int`, *optional*, defaults to 128):
Dimension of each key head in linear attention.
linear_value_head_dim (`int`, *optional*, defaults to 128):
Dimension of each value head in linear attention.
linear_num_key_heads (`int`, *optional*, defaults to 16):
Number of key heads used in linear attention layers.
linear_num_value_heads (`int`, *optional*, defaults to 32):
Number of value heads used in linear attention layers.
decoder_sparse_step (`int`, *optional*, defaults to 1):
The frequency of the MoE layer.
moe_intermediate_size (`int`, *optional*, defaults to 512):
Intermediate size of the routed expert.
shared_expert_intermediate_size (`int`, *optional*, defaults to 512):
Intermediate size of the shared expert.
num_experts_per_tok (`int`, *optional*, defaults to 10):
Number of selected experts.
num_experts (`int`, *optional*, defaults to 512):
Number of routed experts.
norm_topk_prob (`bool`, *optional*, defaults to `True`):
Whether to normalize the topk probabilities.
output_router_logits (`bool`, *optional*, defaults to `False`):
Whether or not the router logits should be returned by the model. Enabling this will also
allow the model to output the auxiliary loss, including load balancing loss and router z-loss.
router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
The aux loss factor for the total loss.
mlp_only_layers (`list[int]`, *optional*, defaults to `[]`):
Indicates which layers use Qwen3NextMLP rather than Qwen3NextSparseMoeBlock.
The list contains layer indices, from 0 to num_layers-1 if we have num_layers layers.
If `mlp_only_layers` is empty, `decoder_sparse_step` is used to determine the sparsity.
layer_types (`list[str]`, *optional*, defaults to None):
Types of each layer (attention or linear).
```python
>>> from transformers import Qwen3NextModel, Qwen3NextConfig
>>> # Initializing a Qwen3Next style configuration
>>> configuration = Qwen3NextConfig()
>>> # Initializing a model from the Qwen3-Next-80B-A3B style configuration
>>> model = Qwen3NextModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```
"""
model_type = "qwen3_next"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vocab_size=151936,
hidden_size=2048,
intermediate_size=5632,
num_hidden_layers=48,
num_attention_heads=16,
num_key_value_heads=2,
hidden_act="silu",
max_position_embeddings=32768,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
tie_word_embeddings=False,
rope_theta=10000.0,
rope_scaling=None,
partial_rotary_factor=0.25,
attention_bias=False,
attention_dropout=0.0,
head_dim=256,
linear_conv_kernel_dim=4,
linear_key_head_dim=128,
linear_value_head_dim=128,
linear_num_key_heads=16,
linear_num_value_heads=32,
decoder_sparse_step=1,
moe_intermediate_size=512,
shared_expert_intermediate_size=512,
num_experts_per_tok=10,
num_experts=512,
norm_topk_prob=True,
output_router_logits=False,
router_aux_loss_coef=0.001,
mlp_only_layers=[],
layer_types=None,
**kwargs,
):
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.partial_rotary_factor = partial_rotary_factor
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
self.head_dim = head_dim
rope_config_validation(self)
# linear attention (GDN) part
self.linear_conv_kernel_dim = linear_conv_kernel_dim
self.linear_key_head_dim = linear_key_head_dim
self.linear_value_head_dim = linear_value_head_dim
self.linear_num_key_heads = linear_num_key_heads
self.linear_num_value_heads = linear_num_value_heads
# MoE arguments
self.decoder_sparse_step = decoder_sparse_step
self.moe_intermediate_size = moe_intermediate_size
self.shared_expert_intermediate_size = shared_expert_intermediate_size
self.num_experts_per_tok = num_experts_per_tok
self.num_experts = num_experts
self.norm_topk_prob = norm_topk_prob
self.output_router_logits = output_router_logits
self.router_aux_loss_coef = router_aux_loss_coef
self.mlp_only_layers = mlp_only_layers
@property
def layers_block_type(self):
layer_type_list = []
for l in range(self.num_hidden_layers):
if (l + 1) % self.full_attention_interval == 0:
layer_type_list.append(HybridLayerType.full_attention.value)
else:
layer_type_list.append(HybridLayerType.linear_attention.value)
return layer_type_list
@property
def linear_layer_ids(self):
return [
i
for i, type_value in enumerate(self.layers_block_type)
if type_value == HybridLayerType.linear_attention.value
]
@property
def full_attention_layer_ids(self):
return [
i
for i, type_value in enumerate(self.layers_block_type)
if type_value == HybridLayerType.full_attention.value
]
@property
def hybrid_gdn_params(self):
world_size = get_attention_tp_size()
conv_dim = (
self.linear_key_head_dim * self.linear_num_key_heads * 2
+ self.linear_value_head_dim * self.linear_num_value_heads
)
conv_state_shape = (
divide(conv_dim, world_size),
self.linear_conv_kernel_dim - 1,
)
temporal_state_shape = (
divide(self.linear_num_value_heads, world_size),
self.linear_key_head_dim,
self.linear_value_head_dim,
)
conv_dtype = torch.bfloat16
dtype_map = {
"float32": torch.float32,
"bfloat16": torch.bfloat16,
}
ssm_dtype = dtype_map[os.environ["SGLANG_MAMBA_SSM_DTYPE"]]
mamba_layers = self.linear_layer_ids
return (
conv_state_shape,
temporal_state_shape,
conv_dtype,
ssm_dtype,
mamba_layers,
)
@property
def mamba_cache_per_req(self):
conv_state_shape, temporal_state_shape, conv_dtype, ssm_dtype, mamba_layers = (
self.hybrid_gdn_params
)
mamba_layers_len = len(mamba_layers)
return (
int(np.prod(conv_state_shape)) * conv_dtype.itemsize
+ int(np.prod(temporal_state_shape)) * ssm_dtype.itemsize
) * mamba_layers_len
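For reference, a minimal back-of-the-envelope sketch of the per-request state that `mamba_cache_per_req` accounts for, using the documented defaults. The TP size of 1, the `full_attention_interval` of 4 (so 36 of the 48 layers are linear-attention layers), and the dtypes below are assumptions for this sketch, not values read from a checkpoint.

```python
import numpy as np
import torch

# Illustrative arithmetic mirroring hybrid_gdn_params / mamba_cache_per_req.
tp = 1                                   # assumed attention TP size
key_dim, value_dim = 128, 128            # linear_key_head_dim, linear_value_head_dim
num_k_heads, num_v_heads = 16, 32        # linear_num_key_heads, linear_num_value_heads
conv_kernel = 4                          # linear_conv_kernel_dim

conv_dim = key_dim * num_k_heads * 2 + value_dim * num_v_heads
conv_state_shape = (conv_dim // tp, conv_kernel - 1)
temporal_state_shape = (num_v_heads // tp, key_dim, value_dim)

conv_bytes = int(np.prod(conv_state_shape)) * torch.bfloat16.itemsize
ssm_bytes = int(np.prod(temporal_state_shape)) * torch.float32.itemsize
num_linear_layers = 36                   # assumed: 48 layers, every 4th is full attention
per_req = (conv_bytes + ssm_bytes) * num_linear_layers
print(f"~{per_req / 2**20:.1f} MiB of mamba state per request")
```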
......@@ -42,6 +42,7 @@ from sglang.srt.configs import (
KimiVLConfig,
LongcatFlashConfig,
MultiModalityConfig,
Qwen3NextConfig,
Step3VLConfig,
)
from sglang.srt.configs.internvl import InternVLChatConfig
......@@ -58,6 +59,7 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
InternVLChatConfig.model_type: InternVLChatConfig,
Step3VLConfig.model_type: Step3VLConfig,
LongcatFlashConfig.model_type: LongcatFlashConfig,
Qwen3NextConfig.model_type: Qwen3NextConfig,
}
for name, cls in _CONFIG_REGISTRY.items():
......
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/mamba/ops/causal_conv1d.py
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2024, Tri Dao.
# Adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/causal_conv1d/causal_conv1d_interface.py
from typing import Optional
import torch
from sgl_kernel import causal_conv1d_fwd
from sgl_kernel import causal_conv1d_update as causal_conv1d_update_kernel
PAD_SLOT_ID = -1
def causal_conv1d_fn(
x: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor] = None,
query_start_loc: Optional[torch.Tensor] = None,
cache_indices: Optional[torch.Tensor] = None,
has_initial_state: Optional[torch.Tensor] = None,
conv_states: Optional[torch.Tensor] = None,
activation: Optional[str] = "silu",
pad_slot_id: int = PAD_SLOT_ID,
):
"""
x: (batch, dim, seqlen) or (dim,cu_seq_len) for varlen
sequences are concatenated from left to right for varlen
weight: (dim, width)
bias: (dim,)
query_start_loc: (batch + 1) int32
The cumulative sequence lengths of the sequences in
the batch, used to index into sequence. prepended by 0.
for example: query_start_loc = torch.Tensor([0,10,16,17]),
x.shape=(dim,17)
cache_indices: (batch) int32
indicates the corresponding state index,
like so: conv_state = conv_states[cache_indices[batch_id]]
has_initial_state: (batch) bool
indicates whether the kernel should take the current state as the initial
state for the calculations
conv_states: (...,dim,width - 1) itype
updated inplace if provided
activation: either None or "silu" or "swish"
pad_slot_id: int
if cache_indices is passed, lets the kernel identify padded
entries that will not be processed,
for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id]
in this case, the kernel will not process entries at
indices 0 and 3
out: (batch, dim, seqlen)
"""
if activation not in [None, "silu", "swish"]:
raise NotImplementedError("activation must be None, silu, or swish")
if x.stride(-1) != 1:
x = x.contiguous()
bias = bias.contiguous() if bias is not None else None
causal_conv1d_fwd(
x,
weight,
bias,
conv_states,
query_start_loc,
cache_indices,
has_initial_state,
activation in ["silu", "swish"],
pad_slot_id,
)
return x
def causal_conv1d_update(
x: torch.Tensor,
conv_state: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor] = None,
activation: Optional[str] = None,
cache_seqlens: Optional[torch.Tensor] = None,
conv_state_indices: Optional[torch.Tensor] = None,
pad_slot_id: int = PAD_SLOT_ID,
):
"""
x: (batch, dim) or (batch, dim, seqlen)
conv_state: (batch, dim, state_len), where state_len >= width - 1
weight: (dim, width)
bias: (dim,)
cache_seqlens: (batch,), dtype int32.
If not None, the conv_state is treated as a circular buffer.
The conv_state will be updated by copying x to the conv_state
starting at the index
@cache_seqlens % state_len.
conv_state_indices: (batch,), dtype int32
If not None, the conv_state is a larger tensor along the batch dim,
and we are selecting the batch coords specified by conv_state_indices.
Useful for a continuous batching scenario.
pad_slot_id: int
if conv_state_indices is passed, lets the kernel identify padded
entries that will not be processed,
for example: conv_state_indices = [pad_slot_id, 1, 20, pad_slot_id]
in this case, the kernel will not process entries at
indices 0 and 3
out: (batch, dim) or (batch, dim, seqlen)
"""
if activation not in [None, "silu", "swish"]:
raise NotImplementedError(
f"activation must be None, silu, or swish, actual: {activation}"
)
activation_val = activation in ["silu", "swish"]
unsqueeze = x.dim() == 2
if unsqueeze:
x = x.unsqueeze(-1)
causal_conv1d_update_kernel(
x,
conv_state,
weight,
bias,
activation_val,
cache_seqlens,
conv_state_indices,
pad_slot_id,
)
if unsqueeze:
x = x.squeeze(-1)
return x
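A hedged usage sketch of the two wrappers above, covering the varlen prefill path and the single-token decode update. The shapes are toy values and the snippet assumes a CUDA device with `sgl_kernel` installed; it is illustrative, not the exact call pattern used by the attention backend.

```python
import torch
from sglang.srt.layers.attention.mamba.causal_conv1d import (
    causal_conv1d_fn,
    causal_conv1d_update,
)

# Toy shapes for illustration only.
dim, width, batch = 64, 4, 3
weight = torch.randn(dim, width, device="cuda", dtype=torch.bfloat16)
conv_states = torch.zeros(batch, dim, width - 1, device="cuda", dtype=torch.bfloat16)

# Varlen prefill: three sequences of lengths 10, 6, 1 packed into (dim, 17).
x = torch.randn(dim, 17, device="cuda", dtype=torch.bfloat16)
query_start_loc = torch.tensor([0, 10, 16, 17], dtype=torch.int32, device="cuda")
cache_indices = torch.arange(batch, dtype=torch.int32, device="cuda")
has_initial_state = torch.zeros(batch, dtype=torch.bool, device="cuda")
out = causal_conv1d_fn(
    x,
    weight,
    query_start_loc=query_start_loc,
    cache_indices=cache_indices,
    has_initial_state=has_initial_state,
    conv_states=conv_states,
    activation="silu",
)

# Decode: one new token per sequence; the rolling conv state is updated in place.
x_step = torch.randn(batch, dim, device="cuda", dtype=torch.bfloat16)
out_step = causal_conv1d_update(x_step, conv_states, weight, activation="silu")
```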
from typing import Callable, List, Tuple
import torch
LoaderFunction = Callable[[torch.Tensor, torch.Tensor], None]
def mamba_v2_sharded_weight_loader(
shard_spec: List[Tuple[int, int, float]],
tp_size: int,
tp_rank: int,
) -> LoaderFunction:
"""Create a weight loader for mamba v2. This ensures that the projections
are correctly sharded so that they can be split into x, B, C. It also
ensures that all the groups corresponding to a head shard are placed
together with it.
"""
def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
# - track boundary of (sharded) param, and loaded_weight, respectively
boundary, loaded_boundary = 0, 0
# - iterate over the shard specs
for full_dim, extra, duplicate_groups in shard_spec:
# - full dim is the model dim (before TP).
# - extra > 0 means there is an expected overall increase
# of dimensions. This is because of replication.
# - ratio is used to map the tp_rank to the actual shard
# rank. This is useful when there is replication of
# groups to accompany head shards.
# - size of the loaded shard
shard_size = full_dim // tp_size
# - compute the rank into the loaded shard.
# - if there is replication, different TP shards will
# take from the same rank.
# NOTE: currently we only support duplication
# in the case where num_groups == 1
rank = 0 if duplicate_groups else tp_rank
# - leftmost boundary index into loaded weight.
loaded_skip = rank * shard_size
loaded_start_idx = loaded_boundary + loaded_skip
# - take these many dims from the loaded weight.
take = min(shard_size, full_dim - extra - loaded_skip)
# - always shard on dim 0
# - the ignore is for a mundane mypy error as it does not
# seem to handle slices well.
# https://github.com/python/mypy/issues/2410
param.data[
boundary : (boundary + take), ... # type: ignore[misc]
] = loaded_weight[
loaded_start_idx : (loaded_start_idx + take) # type: ignore[misc]
] # type: ignore[misc]
# move indexing boundaries
boundary += shard_size
loaded_boundary += full_dim - extra
return loader
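A small, hedged example of how the loader above might be used to shard a fused projection whose dim-0 layout is `[x | B | C]` across tensor-parallel ranks. The segment sizes and the `shard_spec` entries are made up for illustration, with no replication (`extra=0`, `duplicate_groups=False`).

```python
import torch

# Hypothetical fused projection split across tp_size=2 ranks.
x_dim, b_dim, c_dim = 256, 64, 64
tp_size, tp_rank = 2, 0

shard_spec = [(x_dim, 0, False), (b_dim, 0, False), (c_dim, 0, False)]
loader = mamba_v2_sharded_weight_loader(shard_spec, tp_size, tp_rank)

full_weight = torch.randn(x_dim + b_dim + c_dim, 32)
param = torch.empty((x_dim + b_dim + c_dim) // tp_size, 32)
loader(param, full_weight)  # rank 0 receives the first half of each segment
```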
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"1024": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 3
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 3
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 3
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 3
}
}
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"1024": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 2
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 2
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 2
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 3
}
}
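The two JSON files above are Triton tuning tables keyed by a problem size (the outer keys such as "1" or "2048"). A hedged sketch of how such a table could be consulted is shown below; the file name and the nearest-key lookup rule are assumptions for illustration, not the exact SGLang selection logic.

```python
import json
from pathlib import Path

def pick_kernel_config(path: str, m: int) -> dict:
    """Return the tuning entry for the smallest configured size >= m,
    falling back to the largest entry (an assumed lookup rule)."""
    table = json.loads(Path(path).read_text())
    keys = sorted(int(k) for k in table)
    chosen = next((k for k in keys if k >= m), keys[-1])
    return table[str(chosen)]

# cfg = pick_kernel_config("triton_config.json", m=300)
# cfg -> {"BLOCK_SIZE_M": ..., "BLOCK_SIZE_N": ..., "num_warps": ..., ...}
```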
......@@ -38,7 +38,7 @@ import threading
from enum import Enum, auto
from http import HTTPStatus
from itertools import chain
from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union
import numpy as np
import torch
......@@ -59,7 +59,7 @@ from sglang.srt.mem_cache.allocator import (
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
from sglang.srt.mem_cache.lora_radix_cache import LoRAKey, LoRARadixCache
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
from sglang.srt.metrics.collector import TimeStats
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
......@@ -962,8 +962,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
def is_empty(self):
return len(self.reqs) == 0
def alloc_req_slots(self, num_reqs: int):
req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
def alloc_req_slots(self, num_reqs: int, reqs: Optional[List[Req]] = None):
if isinstance(self.req_to_token_pool, HybridReqToTokenPool):
req_pool_indices = self.req_to_token_pool.alloc(num_reqs, reqs)
else:
req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
if req_pool_indices is None:
raise RuntimeError(
"alloc_req_slots runs out of memory. "
......@@ -1138,7 +1141,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
# Allocate req slots
bs = len(self.reqs)
req_pool_indices = self.alloc_req_slots(bs)
req_pool_indices = self.alloc_req_slots(bs, self.reqs)
# Init tensors
reqs = self.reqs
......
......@@ -1540,7 +1540,12 @@ class Scheduler(
chunked_req_to_exclude.add(self.chunked_req)
self.tree_cache.cache_unfinished_req(self.chunked_req, chunked=True)
# chunked request keeps its rid but will get a new req_pool_idx
self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
if self.tp_worker.worker.model_runner.is_hybrid_gdn:
self.req_to_token_pool.free(
self.chunked_req.req_pool_idx, free_mamba_cache=False
)
else:
self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
if self.last_batch and self.last_batch.forward_mode.is_extend():
if self.last_batch.chunked_req is not None:
# In the context of pipeline parallelism, after the last chunk, the current microbatch still tracks the outdated chunked_req.
......
......@@ -102,6 +102,204 @@ class ReqToTokenPool:
self.free_slots = list(range(self.size))
class MambaPool:
def __init__(
self,
size: int,
conv_dtype: torch.dtype,
ssm_dtype: torch.dtype,
num_mamba_layers: int,
conv_state_shape: Tuple[int, int],
temporal_state_shape: Tuple[int, int],
device: str,
speculative_num_draft_tokens: Optional[int] = None,
):
conv_state = torch.zeros(
size=(num_mamba_layers, size + 1) + conv_state_shape,
dtype=conv_dtype,
device=device,
)
temporal_state = torch.zeros(
size=(num_mamba_layers, size + 1) + temporal_state_shape,
dtype=ssm_dtype,
device=device,
)
if speculative_num_draft_tokens is not None:
mixed_qkv_cache = torch.empty(
size=(
num_mamba_layers,
size + 1,
speculative_num_draft_tokens,
conv_state_shape[0],
),
dtype=conv_dtype,
device="cuda",
)
# Cache intermediate SSM states per draft token during target verify
# Shape: [num_layers, size + 1, speculative_num_draft_tokens, HV, K, V]
intermediate_ssm_state_cache = torch.empty(
size=(
num_mamba_layers,
size + 1,
speculative_num_draft_tokens,
temporal_state_shape[0],
temporal_state_shape[1],
temporal_state_shape[2],
),
dtype=ssm_dtype,
device="cuda",
)
self.mamba_cache = (
conv_state,
temporal_state,
mixed_qkv_cache,
intermediate_ssm_state_cache,
)
else:
self.mamba_cache = (conv_state, temporal_state)
self.size = size
self.free_slots = list(range(size))
self.mem_usage = self.get_mamba_size() / GB
logger.info(
f"Mamba Cache is allocated. "
f"conv_state size: {conv_state.numel() * conv_state.itemsize / GB:.2f}GB, "
f"ssm_state size: {temporal_state.numel() * temporal_state.itemsize / GB:.2f}GB "
)
def get_mamba_params_all_layers(self):
return [self.mamba_cache[i] for i in range(len(self.mamba_cache))]
def get_mamba_params(self, layer_id: int):
return [self.mamba_cache[i][layer_id] for i in range(len(self.mamba_cache))]
def get_mamba_size(self):
return (
np.prod(self.mamba_cache[0].shape) * self.mamba_cache[0].dtype.itemsize
+ np.prod(self.mamba_cache[1].shape) * self.mamba_cache[1].dtype.itemsize
)
def available_size(self):
return len(self.free_slots)
def alloc(self, need_size: int) -> Optional[List[int]]:
if need_size > len(self.free_slots):
return None
select_index = self.free_slots[:need_size]
self.free_slots = self.free_slots[need_size:]
return select_index
def free(self, free_index: Union[int, List[int]]):
if isinstance(free_index, (int,)):
self.free_slots.append(free_index)
else:
self.free_slots.extend(free_index)
self.mamba_cache[0][:, free_index] = self.mamba_cache[1][:, free_index] = 0
def clear(self):
self.free_slots = list(range(self.size))
class HybridReqToTokenPool(ReqToTokenPool):
"""A memory pool that maps a request to its token locations."""
def __init__(
self,
size: int,
max_context_len: int,
device: str,
enable_memory_saver: bool,
conv_dtype: torch.dtype,
ssm_dtype: torch.dtype,
mamba_layers: List[int],
conv_state_shape: Tuple[int, int],
temporal_state_shape: Tuple[int, int],
speculative_num_draft_tokens: int,
):
super().__init__(
size=size,
max_context_len=max_context_len,
device=device,
enable_memory_saver=enable_memory_saver,
)
self.mamba_pool = MambaPool(
size,
conv_dtype,
ssm_dtype,
len(mamba_layers),
conv_state_shape,
temporal_state_shape,
device,
speculative_num_draft_tokens,
)
self.mamba_map = {layer_id: i for i, layer_id in enumerate(mamba_layers)}
self.device = device
self.req_index_to_mamba_index_mapping: torch.Tensor = torch.empty(
size, dtype=torch.int32, device=self.device
)
self.rid_to_mamba_index_mapping: Dict[str, int] = {}
self.mamba_index_to_rid_mapping: Dict[int, str] = {}
# For a chunked prefill request, we do not allocate a new mamba cache slot;
# we reuse the slot that was already allocated for this request.
def alloc(
self, need_size: int, reqs: Optional[List["Req"]] = None
) -> Optional[List[int]]:
select_index = super().alloc(need_size)
if select_index is None:
return None
mamba_index = []
for req in reqs:
rid = req.rid
if rid in self.rid_to_mamba_index_mapping:
mid = self.rid_to_mamba_index_mapping[rid]
elif (mid := self.mamba_pool.alloc(1)) is not None:
mid = mid[0]
self.rid_to_mamba_index_mapping[rid] = mid
self.mamba_index_to_rid_mapping[mid] = rid
mamba_index.append(mid)
assert len(select_index) == len(
mamba_index
), "Not enough space for the mamba cache; try increasing --max-mamba-cache-size."
self.req_index_to_mamba_index_mapping[select_index] = torch.tensor(
mamba_index, dtype=torch.int32, device=self.device
)
return select_index
def get_mamba_indices(self, req_indices: torch.Tensor) -> torch.Tensor:
return self.req_index_to_mamba_index_mapping[req_indices]
def get_mamba_params(self, layer_id: int):
assert layer_id in self.mamba_map
return self.mamba_pool.get_mamba_params(self.mamba_map[layer_id])
def get_mamba_params_all_layers(self):
return self.mamba_pool.get_mamba_params_all_layers()
# For chunked prefill, we cannot free the mamba cache here; it is still needed by later chunks.
def free(self, free_index: Union[int, List[int]], free_mamba_cache: bool = True):
super().free(free_index)
if free_mamba_cache:
mamba_index = self.req_index_to_mamba_index_mapping[free_index]
mamba_index_list = mamba_index.tolist()
if isinstance(mamba_index_list, int):
mamba_index_list = [mamba_index_list]
self.mamba_pool.free(mamba_index_list)
for mid in mamba_index_list:
rid = self.mamba_index_to_rid_mapping[mid]
self.mamba_index_to_rid_mapping.pop(mid)
self.rid_to_mamba_index_mapping.pop(rid)
def clear(self):
super().clear()
self.mamba_pool.clear()
class KVCache(abc.ABC):
@abc.abstractmethod
def __init__(
......@@ -441,6 +639,88 @@ class MHATokenToKVPool(KVCache):
)
class HybridLinearKVPool(KVCache):
"""KV cache with separate pools for full and linear attention layers."""
def __init__(
self,
size: int,
dtype: torch.dtype,
head_num: int,
head_dim: int,
full_attention_layer_ids: List[int],
enable_kvcache_transpose: bool,
device: str,
):
self.size = size
self.dtype = dtype
self.device = device
self.full_layer_nums = len(full_attention_layer_ids)
self.page_size = 1
# TODO MHATransposedTokenToKVPool if enable_kvcache_transpose is True
assert not enable_kvcache_transpose
self.full_kv_pool = MHATokenToKVPool(
size=size,
page_size=self.page_size,
dtype=dtype,
head_num=head_num,
head_dim=head_dim,
layer_num=self.full_layer_nums,
device=device,
enable_memory_saver=False,
)
self.full_attention_layer_id_mapping = {
id: i for i, id in enumerate(full_attention_layer_ids)
}
k_size, v_size = self.get_kv_size_bytes()
self.mem_usage = (k_size + v_size) / GB
def get_kv_size_bytes(self):
return self.full_kv_pool.get_kv_size_bytes()
def get_contiguous_buf_infos(self):
return self.full_kv_pool.get_contiguous_buf_infos()
def _transfer_full_attention_id(self, layer_id: int):
if layer_id not in self.full_attention_layer_id_mapping:
raise ValueError(
f"{layer_id=} not in full attention layers: {self.full_attention_layer_id_mapping.keys()}"
)
return self.full_attention_layer_id_mapping[layer_id]
def get_key_buffer(self, layer_id: int):
layer_id = self._transfer_full_attention_id(layer_id)
return self.full_kv_pool.get_key_buffer(layer_id)
def get_value_buffer(self, layer_id: int):
layer_id = self._transfer_full_attention_id(layer_id)
return self.full_kv_pool.get_value_buffer(layer_id)
def get_kv_buffer(self, layer_id: int):
layer_id = self._transfer_full_attention_id(layer_id)
return self.full_kv_pool.get_kv_buffer(layer_id)
def set_kv_buffer(
self,
layer: RadixAttention,
loc: torch.Tensor,
cache_k: torch.Tensor,
cache_v: torch.Tensor,
k_scale: float = 1.0,
v_scale: float = 1.0,
):
layer_id = self._transfer_full_attention_id(layer.layer_id)
self.full_kv_pool.set_kv_buffer(
None,
loc,
cache_k,
cache_v,
k_scale,
v_scale,
layer_id_override=layer_id,
)
class SWAKVPool(KVCache):
"""KV cache with separate pools for full and SWA attention layers."""
......
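A rough lifecycle sketch of the `HybridReqToTokenPool` introduced above, showing how a chunked-prefill free differs from a final free. The constructor arguments, the CPU device, and the `_Req` stand-in are assumptions for a toy example; inside SGLang the pool is built by `ModelRunner` from `hybrid_gdn_params`, and this snippet is not guaranteed to run outside that runtime.

```python
import torch
from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool

class _Req:  # minimal stand-in for sglang's Req; only .rid is used by alloc()
    def __init__(self, rid: str):
        self.rid = rid

pool = HybridReqToTokenPool(
    size=8,
    max_context_len=1024,
    device="cpu",                      # toy example; real pools live on GPU
    enable_memory_saver=False,
    conv_dtype=torch.bfloat16,
    ssm_dtype=torch.float32,
    mamba_layers=[0, 1, 2],
    conv_state_shape=(128, 3),
    temporal_state_shape=(4, 64, 64),
    speculative_num_draft_tokens=None,
)

reqs = [_Req("req-0"), _Req("req-1")]
slots = pool.alloc(len(reqs), reqs)    # one req slot + one mamba slot per rid
# Chunked prefill: release the req slot but keep the mamba state for the next chunk.
pool.free(slots[0], free_mamba_cache=False)
# Finished request: release both the req slot and its mamba slot.
pool.free(slots[1], free_mamba_cache=True)
```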
......@@ -85,6 +85,8 @@ from sglang.srt.mem_cache.memory_pool import (
AscendMLAPagedTokenToKVPool,
AscendTokenToKVPool,
DoubleSparseTokenToKVPool,
HybridLinearKVPool,
HybridReqToTokenPool,
MHATokenToKVPool,
MLATokenToKVPool,
ReqToTokenPool,
......@@ -303,6 +305,26 @@ class ModelRunner:
if architectures and not any("Llama4" in arch for arch in architectures):
self.is_hybrid = self.model_config.is_hybrid = True
if self.is_hybrid_gdn:
logger.warning("Hybrid GDN model detected, disable radix cache")
self.server_args.disable_radix_cache = True
self.server_args.attention_backend = "hybrid_linear_attn"
if self.server_args.max_mamba_cache_size is None:
if self.server_args.max_running_requests is not None:
self.server_args.max_mamba_cache_size = (
self.server_args.max_running_requests
)
else:
self.server_args.max_mamba_cache_size = 512
self.server_args.max_mamba_cache_size = (
self.server_args.max_mamba_cache_size
// (
self.server_args.dp_size
if self.server_args.enable_dp_attention
else 1
)
)
# For MTP models like DeepSeek-V3 or GLM-4.5, the MTP layer(s) are used separately as draft
# models for speculative decoding. In those cases, `num_nextn_predict_layers` is used to
# determine the number of layers.
......@@ -1080,6 +1102,8 @@ class ModelRunner:
"num_nextn_predict_layers",
self.num_effective_layers,
)
elif self.is_hybrid_gdn:
num_layers = len(self.model_config.hf_config.full_attention_layer_ids)
else:
num_layers = self.num_effective_layers
if self.use_mla_backend:
......@@ -1099,9 +1123,22 @@ class ModelRunner:
rest_memory = available_gpu_memory - total_gpu_memory * (
1 - self.mem_fraction_static
)
if self.is_hybrid_gdn:
rest_memory -= (
self.server_args.max_mamba_cache_size
* self.model_config.hf_config.mamba_cache_per_req
/ (1 << 30)
)
max_num_token = int(rest_memory * (1 << 30) // cell_size)
return max_num_token
@property
def is_hybrid_gdn(self):
return self.model_config.hf_config.architectures[0] in [
"Qwen3NextForCausalLM",
"Qwen3NextForCausalLMMTP",
]
def set_num_token_hybrid(self):
if (
"Llama4ForConditionalGeneration"
......@@ -1222,6 +1259,8 @@ class ModelRunner:
),
4096,
)
if self.is_hybrid_gdn:
max_num_reqs = min(max_num_reqs, self.server_args.max_mamba_cache_size)
if not self.spec_algorithm.is_none():
if self.is_draft_worker:
......@@ -1300,6 +1339,28 @@ class ModelRunner:
enable_memory_saver=self.server_args.enable_memory_saver,
pre_alloc_size=pre_alloc_size,
)
elif self.is_hybrid_gdn:
config = self.model_config.hf_config
(
conv_state_shape,
temporal_state_shape,
conv_dtype,
ssm_dtype,
mamba_layers,
) = config.hybrid_gdn_params
self.req_to_token_pool = HybridReqToTokenPool(
size=max_num_reqs,
max_context_len=self.model_config.context_len
+ extra_max_context_len,
device=self.device,
enable_memory_saver=self.server_args.enable_memory_saver,
conv_state_shape=conv_state_shape,
temporal_state_shape=temporal_state_shape,
conv_dtype=conv_dtype,
ssm_dtype=ssm_dtype,
mamba_layers=mamba_layers,
speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens,
)
else:
self.req_to_token_pool = ReqToTokenPool(
size=max_num_reqs,
......@@ -1382,6 +1443,23 @@ class ModelRunner:
enable_kvcache_transpose=False,
device=self.device,
)
elif self.is_hybrid_gdn:
self.token_to_kv_pool = HybridLinearKVPool(
size=self.max_total_num_tokens,
dtype=self.kv_cache_dtype,
head_num=self.model_config.get_num_kv_heads(
get_attention_tp_size()
),
head_dim=self.model_config.head_dim,
# if draft worker, we only need 1 attention layer's kv pool
full_attention_layer_ids=(
[0]
if self.is_draft_worker
else self.model_config.hf_config.full_attention_layer_ids
),
enable_kvcache_transpose=False,
device=self.device,
)
else:
self.token_to_kv_pool = MHATokenToKVPool(
self.max_total_num_tokens,
......@@ -1615,6 +1693,24 @@ class ModelRunner:
)
return DualChunkFlashAttentionBackend(self)
elif backend_str == "hybrid_linear_attn":
assert (
self.is_hybrid_gdn
), "hybrid_linear_attn backend can only be used with hybrid GDN models."
from sglang.srt.layers.attention.flashattention_backend import (
FlashAttentionBackend,
)
from sglang.srt.layers.attention.hybrid_linear_attn_backend import (
HybridLinearAttnBackend,
MambaAttnBackend,
)
full_attn_backend = FlashAttentionBackend(self)
linear_attn_backend = MambaAttnBackend(self)
full_attn_layers = self.model_config.hf_config.full_attention_layer_ids
return HybridLinearAttnBackend(
full_attn_backend, linear_attn_backend, full_attn_layers
)
else:
raise ValueError(f"Invalid attention backend: {backend_str}")
......
......@@ -35,6 +35,7 @@ from tqdm.auto import tqdm
from sglang.srt.configs.load_config import LoadConfig
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.distributed import get_tensor_model_parallel_rank
from sglang.srt.layers.dp_attention import get_attention_tp_rank
from sglang.srt.layers.quantization import QuantizationConfig, get_quantization_config
from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp4Config
from sglang.srt.utils import print_warning_once
......@@ -680,7 +681,7 @@ def sharded_weight_loader(shard_axis: int) -> LoaderFunction:
"""Create a weight loader that shards the weights along the given axis"""
def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
tp_rank = get_tensor_model_parallel_rank()
tp_rank = get_attention_tp_rank()
shard_size = param.data.shape[shard_axis]
start_idx = tp_rank * shard_size
......
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Inference-only Qwen3Next MTP Speculative Decoding."""
import logging
from typing import Iterable, Optional, Tuple
import torch
from torch import nn
from transformers import PretrainedConfig
from sglang.srt.distributed import get_pp_group, get_tensor_model_parallel_world_size
from sglang.srt.layers.layernorm import GemmaRMSNorm, RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.models.qwen3_moe import Qwen3MoeModel
from sglang.srt.models.qwen3_next import Qwen3NextForCausalLM, Qwen3NextModel
from sglang.srt.utils import add_prefix
logger = logging.getLogger(__name__)
class Qwen3NextForCausalLMMTP(Qwen3NextForCausalLM):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
nn.Module.__init__(self)
self.config = config
self.tp_size = get_tensor_model_parallel_world_size()
self.quant_config = quant_config
# if not set, model load will be broken in Qwen3NextForCausalLM load_weights()
self.pp_group = get_pp_group()
# self.determine_num_fused_shared_experts("Qwen3NextForCausalLMMTP")
# currently, based on the provided ckpt, we:
# (1) do not use dedicated MTP embeddings, since they are not provided in the ckpt; the target model embeddings are used directly
# (2) hardcode bias=False, since it is not provided
self.fc = nn.Linear(2 * config.hidden_size, config.hidden_size, bias=False)
if getattr(
config, "use_gemma_rms_norm", getattr(config, "apply_layernorm_1p", False)
):
logger.warning_once(
"Using Gemma RMSNorm for input normalization and post attn normalization."
)
RMSNorm_cls = GemmaRMSNorm
else:
RMSNorm_cls = RMSNorm
self.pre_fc_norm_embedding = RMSNorm_cls(
config.hidden_size, config.rms_norm_eps
)
self.pre_fc_norm_hidden = RMSNorm_cls(config.hidden_size, config.rms_norm_eps)
config.num_hidden_layers = 1
config.full_attention_interval = 1
self.model = Qwen3NextModel(
config, quant_config, prefix=add_prefix("model", prefix)
)
self.lm_head = ParallelLMHead(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=add_prefix("model.shared_head.head", prefix),
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
)
self.logits_processor = LogitsProcessor(config)
@torch.no_grad()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
input_embeds: Optional[torch.Tensor] = None,
**kwargs,
):
if input_embeds is None:
input_embeds = self.model.embed_tokens(input_ids)
input_embeds = self.pre_fc_norm_embedding(input_embeds)
hidden_states = self.pre_fc_norm_hidden(forward_batch.spec_info.hidden_states)
hidden_states = self.fc(torch.cat((input_embeds, hidden_states), dim=-1))
hidden_states = self.model(
input_ids,
positions,
forward_batch,
hidden_states,
)
return self.logits_processor(
input_ids, hidden_states, self.lm_head, forward_batch
)
def load_weights(
self, weights: Iterable[Tuple[str, torch.Tensor]], is_mtp: bool = False
):
super().load_weights(weights, is_mtp=True)
EntryClass = [Qwen3NextForCausalLMMTP]
......@@ -95,6 +95,7 @@ ATTENTION_BACKEND_CHOICES = [
"trtllm_mla",
"trtllm_mha",
"dual_chunk_flash_attn",
"hybrid_linear_attn",
# AMD specific
"aiter",
"wave",
......@@ -390,6 +391,10 @@ class ServerArgs:
enable_pdmux: bool = False
sm_group_num: int = 3
# Mamba cache
max_mamba_cache_size: Optional[int] = None
mamba_ssm_dtype: str = "float32"
# Deprecated arguments
enable_ep_moe: bool = False
enable_deepep_moe: bool = False
......@@ -835,6 +840,8 @@ class ServerArgs:
os.environ["SGLANG_ENABLE_TORCH_COMPILE"] = (
"1" if self.enable_torch_compile else "0"
)
os.environ["SGLANG_MAMBA_SSM_DTYPE"] = self.mamba_ssm_dtype
# Set env var before grammar backends init
os.environ["SGLANG_DISABLE_OUTLINES_DISK_CACHE"] = (
"1" if self.disable_outlines_disk_cache else "0"
......@@ -1714,7 +1721,20 @@ class ServerArgs:
default=ServerArgs.moe_dense_tp_size,
help="TP size for MoE dense MLP layers. This flag is useful when, with large TP size, there are errors caused by weights in MLP layers having dimension smaller than the min dimension GEMM supports.",
)
# Mamba Cache
parser.add_argument(
"--max-mamba-cache-size",
type=int,
default=ServerArgs.max_mamba_cache_size,
help="It is used for mamba cache memory static allocation.",
)
parser.add_argument(
"--mamba-ssm-dtype",
type=str,
default=ServerArgs.mamba_ssm_dtype,
choices=["float32", "bfloat16"],
help="It is used to tune mamba ssm dtype",
)
# Hierarchical cache
parser.add_argument(
"--enable-hierarchical-cache",
......
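A short sketch of how `--mamba-ssm-dtype` reaches the cache allocation: `ServerArgs` exports the value as `SGLANG_MAMBA_SSM_DTYPE` (see the env-var setup earlier in this file), and `Qwen3NextConfig.hybrid_gdn_params` maps the string back to a torch dtype when sizing the mamba pool. The snippet below just replays that mapping in isolation.

```python
import os
import torch

os.environ["SGLANG_MAMBA_SSM_DTYPE"] = "bfloat16"   # i.e. --mamba-ssm-dtype bfloat16
dtype_map = {"float32": torch.float32, "bfloat16": torch.bfloat16}
ssm_dtype = dtype_map[os.environ["SGLANG_MAMBA_SSM_DTYPE"]]
assert ssm_dtype is torch.bfloat16
```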
import bisect
from typing import TYPE_CHECKING, Callable
import torch
import torch.nn.functional as F
from sglang.srt.layers.attention.fla.fused_recurrent import (
fused_recurrent_gated_delta_rule_update,
)
from sglang.srt.layers.attention.mamba.causal_conv1d import causal_conv1d_fn
from sglang.srt.model_executor.cuda_graph_runner import (
CUDA_GRAPH_CAPTURE_FAILED_MSG,
CudaGraphRunner,
get_batch_sizes_to_capture,
get_global_graph_memory_pool,
model_capture_mode,
set_global_graph_memory_pool,
)
from sglang.srt.models.qwen3_next import Qwen3HybridLinearDecoderLayer
if TYPE_CHECKING:
from sglang.srt.speculative.eagle_worker import EAGLEWorker
class MambaStateUpdateCudaGraphRunner:
def __init__(self, eagle_worker: "EAGLEWorker"):
self.eagle_worker = eagle_worker
model_runner = eagle_worker.target_worker.model_runner
self.model_runner = model_runner
self.attn_backend = model_runner.attn_backend.attn_backend_list[1]
self.req_to_token_pool = self.attn_backend.req_to_token_pool
self.graphs = {}
self.output_buffers = {}
self.graph_input_buffer = None
self.stream = torch.cuda.Stream()
self.model = model_runner.model
self.enable_profile_cuda_graph = (
model_runner.server_args.enable_profile_cuda_graph
)
self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
self.max_bs = self.capture_bs[-1]
self.init_cuda_graph_state()
# Capture
try:
with model_capture_mode():
self.capture()
except RuntimeError as e:
raise Exception(
f"Capture cuda graph failed: {e}\n{CUDA_GRAPH_CAPTURE_FAILED_MSG}"
)
def init_cuda_graph_state(self):
self.mamba_cache = self.req_to_token_pool.mamba_pool.mamba_cache
self.num_tokens_per_bs = self.max_accepted_tokens = self.mamba_cache[2].shape[2]
num_mamba_layers = self.mamba_cache[0].shape[0]
conv_dtype = torch.bfloat16
conv_shape = self.mamba_cache[0].shape[2]
total_token_number = self.max_accepted_tokens * self.max_bs
self.mixed_qkv_cache = torch.empty(
size=(
num_mamba_layers,
total_token_number,
conv_shape,
),
dtype=conv_dtype,
device="cuda",
)
self.query_start_loc = torch.zeros(
(self.max_bs + 1,), dtype=torch.int32, device="cuda"
)
self.state_indices = torch.zeros(
(self.max_bs + 1,), dtype=torch.int32, device="cuda"
)
self.has_initial_states = torch.ones(
self.max_bs, dtype=torch.bool, device="cuda"
)
def capture(self):
CudaGraphRunner.capture(self)
def capture_one_batch_size(self, bs: int, forward: Callable):
"""
Capture CUDA Graph for a typical workload
"""
graph = torch.cuda.CUDAGraph()
stream = self.stream
total_token_number = bs * self.max_accepted_tokens
mixed_qkvs = self.mixed_qkv_cache[:, :total_token_number]
query_start_loc = self.query_start_loc[: bs + 1]
state_indices = self.state_indices[:bs]
has_initial_states = self.has_initial_states[:bs]
mamba_caches = self.req_to_token_pool.get_mamba_params_all_layers()
conv_states = mamba_caches[0]
mamba_map = self.req_to_token_pool.mamba_map
def run_once():
for i in range(len(self.model.model.layers)):
layer = self.model.model.layers[i]
if not isinstance(layer, Qwen3HybridLinearDecoderLayer):
continue
conv_weights = layer.linear_attn.conv1d.weight.view(
layer.linear_attn.conv1d.weight.size(0),
layer.linear_attn.conv1d.weight.size(2),
)
layer_id = mamba_map[i]
causal_conv1d_fn(
mixed_qkvs[layer_id].transpose(0, 1),
conv_weights,
layer.linear_attn.conv1d.bias,
activation=layer.linear_attn.activation,
conv_states=conv_states[layer_id],
has_initial_state=has_initial_states,
cache_indices=state_indices,
query_start_loc=query_start_loc,
)
return None
for _ in range(2):
torch.cuda.synchronize()
self.model_runner.tp_group.barrier()
run_once()
with torch.cuda.graph(
graph, pool=get_global_graph_memory_pool(), stream=stream
):
out = run_once()
set_global_graph_memory_pool(graph.pool())
return graph, out
def can_run(self, accepted_length):
bs = accepted_length.shape[0]
return bs <= self.max_bs
def replay_prepare(self, accepted_length):
request_number = accepted_length.shape[0]
# QQ: step = spec num_draft token num
num_draft_tokens = self.req_to_token_pool.mamba_pool.mamba_cache[2].shape[2]
query_start_loc = accepted_length.cumsum(-1, dtype=accepted_length.dtype)
query_start_loc = torch.cat(
[
torch.zeros(
1,
dtype=query_start_loc.dtype,
device=query_start_loc.device,
),
query_start_loc,
]
)
mask = torch.arange(num_draft_tokens, device=accepted_length.device).unsqueeze(
0
) < accepted_length.unsqueeze(1)
state_indices_tensor = self.attn_backend.forward_metadata.mamba_cache_indices[
:request_number
]
mamba_caches = self.req_to_token_pool.get_mamba_params_all_layers()
_, ssm_states, mix_qkv_cache, intermediate_state_cache = mamba_caches
mixed_qkvs = mamba_caches[2][:, state_indices_tensor][:, mask]
self.mixed_qkv_cache[:, : mixed_qkvs.shape[1]].copy_(mixed_qkvs)
self.query_start_loc[: request_number + 1] = query_start_loc
self.query_start_loc[request_number + 1 :] = self.query_start_loc[
request_number
]
self.state_indices[:request_number] = state_indices_tensor
self.state_indices[request_number:] = -1
valid_mask = accepted_length > 0
if intermediate_state_cache is not None:
last_steps = (accepted_length - 1).to(torch.int64)
valid_state_indices = state_indices_tensor[valid_mask].to(torch.int64)
ssm_states[:, valid_state_indices, :] = intermediate_state_cache[
:, valid_state_indices, last_steps
].to(ssm_states.dtype)
def replay(self, accepted_length):
# batch_size and num_seqs can be different in case there are finished examples
# in the batch, which will not be counted as num_seqs
raw_bs = accepted_length.shape[0]
index = bisect.bisect_left(self.capture_bs, raw_bs)
bs = self.capture_bs[index]
self.replay_prepare(accepted_length)
# Replay
self.graphs[bs].replay()
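For clarity, a toy illustration of the bucket selection used in `replay()`: the live batch size is rounded up to the nearest captured batch size with `bisect_left`, and that graph is replayed. The capture list below is an assumption, not the real `get_batch_sizes_to_capture` output.

```python
import bisect

capture_bs = [1, 2, 4, 8, 16, 32]   # assumed captured batch sizes
raw_bs = 5                          # live batch size after filtering finished reqs
bs = capture_bs[bisect.bisect_left(capture_bs, raw_bs)]
assert bs == 8                      # replay the graph captured for bs=8
```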
......@@ -214,6 +214,7 @@ class EAGLEWorker(TpModelWorker):
"triton": self._create_triton_decode_backend,
"aiter": self._create_aiter_decode_backend,
"fa3": self._create_fa3_decode_backend,
"hybrid_linear_attn": self._create_fa3_decode_backend,
"flashmla": self._create_flashmla_decode_backend,
"trtllm_mha": self._create_trtllm_mha_decode_backend,
"trtllm_mla": self._create_trtllm_mla_decode_backend,
......@@ -231,6 +232,7 @@ class EAGLEWorker(TpModelWorker):
"triton": self._create_triton_prefill_backend,
"aiter": self._create_aiter_prefill_backend,
"fa3": self._create_fa3_prefill_backend,
"hybrid_linear_attn": self._create_fa3_prefill_backend,
"trtllm_mha": self._create_trtllm_mha_prefill_backend,
"trtllm_mla": self._create_trtllm_mla_prefill_backend,
}
......@@ -405,6 +407,15 @@ class EAGLEWorker(TpModelWorker):
f"Capture draft extend cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. mem usage={(before_mem - after_mem):.2f} GB. avail mem={after_mem:.2f} GB."
)
if self.target_worker.model_runner.is_hybrid_gdn:
from sglang.srt.speculative.eagle_target_verify_cuda_graph_runner import (
MambaStateUpdateCudaGraphRunner,
)
self.cuda_graph_runner_for_target_verify = MambaStateUpdateCudaGraphRunner(
self
)
@property
def draft_model_runner(self):
return self.model_runner
......@@ -826,6 +837,24 @@ class EAGLEWorker(TpModelWorker):
]
logits_output.hidden_states = logits_output.hidden_states[res.accepted_indices]
# QQ: can be optimized
if self.target_worker.model_runner.is_hybrid_gdn:
# res.draft_input.accept_length is on GPU but may be empty for last verify?
accepted_length = (
torch.tensor(
res.accept_length_per_req_cpu,
device=logits_output.hidden_states.device,
dtype=torch.int32,
)
+ 1
)
if self.cuda_graph_runner_for_target_verify.can_run(accepted_length):
self.cuda_graph_runner_for_target_verify.replay(accepted_length)
else:
self.target_worker.model_runner.attn_backend.update_mamba_state_after_mtp_verify(
accepted_length, self.target_worker.model_runner.model
)
if batch.return_logprob:
self.add_logprob_values(batch, res, logits_output)
......