"src/vscode:/vscode.git/clone" did not exist on "688a9228a820c419d9548ea2b44a6e4fe0a2cc1e"
Unverified Commit d6837aea authored by Netanel Haber's avatar Netanel Haber Committed by GitHub
Browse files

model: Support Hybrid Mamba2 NemotronHForCausalLM (nvidia/NVIDIA-Nemotron-Nano-9B-v2) (#10909)


Signed-off-by: Netanel Haber <nhaber@nvidia.com>
parent c882b5ae
...@@ -53,6 +53,7 @@ in the GitHub search bar.
| **Ling** (16.8B–290B) | `inclusionAI/Ling-lite`, `inclusionAI/Ling-plus` | InclusionAI’s open MoE models. Ling-Lite has 16.8B total / 2.75B active parameters, and Ling-Plus has 290B total / 28.8B active parameters. They are designed for high performance on NLP and complex reasoning tasks. |
| **Granite 3.0, 3.1** (IBM) | `ibm-granite/granite-3.1-8b-instruct` | IBM's open dense foundation models optimized for reasoning, code, and business AI use cases. Integrated with Red Hat and watsonx systems. |
| **Granite 3.0 MoE** (IBM) | `ibm-granite/granite-3.0-3b-a800m-instruct` | IBM’s Mixture-of-Experts models offering strong performance with cost-efficiency. MoE expert routing designed for enterprise deployment at scale. |
| **Llama Nemotron Super** (v1, v1.5, NVIDIA) | `nvidia/Llama-3_3-Nemotron-Super-49B-v1`, `nvidia/Llama-3_3-Nemotron-Super-49B-v1_5` | The [NVIDIA Nemotron](https://www.nvidia.com/en-us/ai-data-science/foundation-models/nemotron/) family builds on the strongest open models in the ecosystem by enhancing them with greater accuracy, efficiency, and transparency using NVIDIA open synthetic datasets, advanced techniques, and tools. This enables the creation of practical, right-sized, and high-performing AI agents. |
| **Llama Nemotron Super** (v1, v1.5, NVIDIA) | `nvidia/Llama-3_3-Nemotron-Super-49B-v1`, `nvidia/Llama-3_3-Nemotron-Super-49B-v1_5` | The [NVIDIA Nemotron](https://www.nvidia.com/en-us/ai-data-science/foundation-models/nemotron/) family of multimodal models provides state-of-the-art reasoning models specifically designed for enterprise-ready AI agents. |
| **Llama Nemotron Ultra** (v1, NVIDIA) | `nvidia/Llama-3_1-Nemotron-Ultra-253B-v1` | The [NVIDIA Nemotron](https://www.nvidia.com/en-us/ai-data-science/foundation-models/nemotron/) family builds on the strongest open models in the ecosystem by enhancing them with greater accuracy, efficiency, and transparency using NVIDIA open synthetic datasets, advanced techniques, and tools. This enables the creation of practical, right-sized, and high-performing AI agents. |
| **Llama Nemotron Ultra** (v1, NVIDIA) | `nvidia/Llama-3_1-Nemotron-Ultra-253B-v1` | The [NVIDIA Nemotron](https://www.nvidia.com/en-us/ai-data-science/foundation-models/nemotron/) family of multimodal models provides state-of-the-art reasoning models specifically designed for enterprise-ready AI agents. |
| **NVIDIA Nemotron Nano 2.0** | `nvidia/NVIDIA-Nemotron-Nano-9B-v2` | The [NVIDIA Nemotron](https://www.nvidia.com/en-us/ai-data-science/foundation-models/nemotron/) family of multimodal models provides state-of-the-art reasoning models specifically designed for enterprise-ready AI agents. `Nemotron-Nano-9B-v2` is a hybrid Mamba-Transformer language model designed to increase throughput for reasoning workloads while achieving state-of-the-art accuracy compared to similarly-sized models. |
| **StarCoder2** (3B-15B) | `bigcode/starcoder2-7b` | StarCoder2 is a family of open large language models (LLMs) specialized for code generation and understanding. It is the successor to StarCoder, jointly developed by the BigCode project (a collaboration between Hugging Face, ServiceNow Research, and other contributors). |
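With this change, the hybrid Mamba2 checkpoint can be served like any other supported model. Below is a minimal offline-generation sketch, not part of the commit: it assumes an sglang build that includes this change, and the prompt and sampling parameters are illustrative only (the server-based alternative is `python -m sglang.launch_server --model-path nvidia/NVIDIA-Nemotron-Nano-9B-v2`).

```python
# Minimal sketch (not from this commit): offline generation with the newly
# supported checkpoint via sglang's offline Engine API.
import sglang as sgl

if __name__ == "__main__":
    llm = sgl.Engine(model_path="nvidia/NVIDIA-Nemotron-Nano-9B-v2")
    prompts = ["Briefly explain what a hybrid Mamba-Transformer model is."]
    # Sampling parameters are illustrative, not tuned for this model.
    outputs = llm.generate(prompts, {"temperature": 0.6, "max_new_tokens": 128})
    for prompt, out in zip(prompts, outputs):
        print(prompt, out["text"])
    llm.shutdown()
```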
...@@ -9,6 +9,7 @@ from sglang.srt.configs.janus_pro import MultiModalityConfig
from sglang.srt.configs.kimi_vl import KimiVLConfig
from sglang.srt.configs.kimi_vl_moonvit import MoonViTConfig
from sglang.srt.configs.longcat_flash import LongcatFlashConfig
from sglang.srt.configs.nemotron_h import NemotronHConfig
from sglang.srt.configs.qwen3_next import Qwen3NextConfig
from sglang.srt.configs.step3_vl import (
Step3TextConfig,
...@@ -32,4 +33,5 @@ __all__ = [
"DotsVLMConfig",
"DotsOCRConfig",
"FalconH1Config",
"NemotronHConfig",
]
...@@ -15,16 +15,12 @@
"""Falcon-H1 model configuration"""
import enum
import os
import numpy as np
import torch
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_rope_utils import rope_config_validation
from transformers.utils import logging
from sglang.srt.distributed.utils import divide
from sglang.srt.configs.mamba_utils import Mamba2CacheParams, Mamba2StateShape
from sglang.srt.layers.attention.mamba.mamba_utils import MambaStateShapeCalculator
from sglang.srt.layers.dp_attention import (
get_attention_tp_size,
get_tensor_model_parallel_world_size,
...@@ -214,7 +210,7 @@ class FalconH1Config(PretrainedConfig):
self.rope_scaling = None
self.rope_scaling = rope_scaling
self.projectors_bias = projectors_bias
mamba_intermediate = (
self.mamba_intermediate = mamba_intermediate = (
mamba_expand * hidden_size if mamba_d_ssm is None else mamba_d_ssm
)
...@@ -294,18 +290,6 @@ class FalconH1Config(PretrainedConfig):
def layers_block_type(self):
return ["falcon_h1" for i in range(self.num_hidden_layers)]
@property
def mamba_cache_per_req(self):
conv_state_shape, temporal_state_shape, conv_dtype, ssm_dtype, mamba_layers = (
self.hybrid_gdn_params
)
mamba_layers_len = len(mamba_layers)
return (
int(np.prod(conv_state_shape)) * conv_dtype.itemsize
+ int(np.prod(temporal_state_shape)) * ssm_dtype.itemsize
) * mamba_layers_len
@property
def full_attention_layer_ids(self):
# For Falcon-H1, we do have attention on all layers
...@@ -317,44 +301,14 @@ class FalconH1Config(PretrainedConfig):
return range(self.num_hidden_layers)
@property
def hybrid_gdn_params(self):
world_size = get_tensor_model_parallel_world_size()
n_groups = self.mamba_n_groups
if self.mamba_n_groups % world_size != 0:
# - for TP we shard conv_dim by sharding on n_groups,
# - but if n_groups cannot divide tp_size, we need to
# extend some extra groups
extra_groups = MambaStateShapeCalculator.extra_groups_for_head_shards(
self.mamba_n_groups, world_size
)
n_groups += extra_groups
conv_dim = self.mamba_d_ssm + 2 * n_groups * self.mamba_d_state
conv_state_shape = (
divide(conv_dim, world_size),
self.mamba_d_conv - 1,
)
# we TP-ize on the heads dimension
temporal_state_shape = (
self.mamba_d_state,
self.mamba_d_head,
divide(self.mamba_n_heads, world_size),
)
conv_dtype = torch.bfloat16
dtype_map = {
"float32": torch.float32,
"bfloat16": torch.bfloat16,
}
ssm_dtype = dtype_map[os.environ["SGLANG_MAMBA_SSM_DTYPE"]]
mamba_layers = self.linear_layer_ids
return (
conv_state_shape,
temporal_state_shape,
conv_dtype,
ssm_dtype,
mamba_layers,
)
@property
def mamba2_cache_params(self):
shape = Mamba2StateShape.create(
tp_world_size=get_tensor_model_parallel_world_size(),
intermediate_size=self.mamba_intermediate,
n_groups=self.mamba_n_groups,
num_heads=self.mamba_n_heads,
head_dim=self.mamba_d_head,
state_size=self.mamba_d_state,
conv_kernel=self.mamba_d_conv,
)
return Mamba2CacheParams(shape=shape, layers=self.linear_layer_ids)
# Copyright 2025 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Common config utils for mamba2 - NemotronH, FalconH1, Qwen3Next, etc."""
import os
from dataclasses import dataclass, field
import numpy as np
import torch
from sglang.srt.distributed.utils import divide
def extra_groups_for_head_shards(ngroups: int, tp_size: int):
"""Compute the increase in group numbers to account for
replication in order to accompany the head shards."""
# in the case ngroups % tp_size == 0, this will be zero
if ngroups % tp_size == 0:
return 0
# for n_groups == 1, this is exactly tp_size - n_groups
return tp_size - ngroups
@dataclass(kw_only=True, frozen=True)
class Mamba2StateShape:
conv: tuple[int, int]
temporal: tuple[int, int, int]
intermediate_size: int
conv_dim: int
ssm_state_size: int
num_heads: int
head_dim: int
state_size: int
conv_kernel: int
@staticmethod
def create(
*,
tp_world_size: int,
intermediate_size: int,
n_groups: int,
num_heads: int,
head_dim: int,
state_size: int,
conv_kernel: int,
) -> "Mamba2StateShape":
# if n_groups is not divisible by world_size, need to extend the shards
# to ensure all groups needed by a head is sharded along with it
if n_groups % tp_world_size != 0:
extra_groups = extra_groups_for_head_shards(n_groups, tp_world_size)
n_groups += extra_groups
# heads and n_groups are TP-ed
conv_dim = intermediate_size + 2 * n_groups * state_size
# contiguous along 'dim' axis
conv_state_shape = divide(conv_dim, tp_world_size), conv_kernel - 1
# These are not TP-ed as they depend on A, dt_bias, D
# - they are typically small
# e.g., (h_heads, head_dim, state_size) = (128, 64, 128)
temporal_state_shape = (divide(num_heads, tp_world_size), head_dim, state_size)
return Mamba2StateShape(
conv=conv_state_shape,
temporal=temporal_state_shape,
intermediate_size=intermediate_size,
conv_dim=conv_dim,
ssm_state_size=state_size,
num_heads=num_heads,
head_dim=head_dim,
state_size=state_size,
conv_kernel=conv_kernel,
)
@dataclass(kw_only=True, frozen=True)
class Mamba2StateDType:
conv: torch.dtype
temporal: torch.dtype
CONV_DTYPE = torch.bfloat16
def mamba2_state_dtype() -> Mamba2StateDType:
dtype_map = {
"float32": torch.float32,
"bfloat16": torch.bfloat16,
}
ssm_dtype = dtype_map[os.environ["SGLANG_MAMBA_SSM_DTYPE"]]
return Mamba2StateDType(conv=CONV_DTYPE, temporal=ssm_dtype)
@dataclass(kw_only=True, frozen=True)
class Mamba2CacheParams:
shape: Mamba2StateShape
dtype: Mamba2StateDType = field(default_factory=mamba2_state_dtype)
layers: list[int]
@property
def mamba_cache_per_req(self) -> int:
return (
int(np.prod(self.shape.conv)) * self.dtype.conv.itemsize
+ int(np.prod(self.shape.temporal)) * self.dtype.temporal.itemsize
) * len(self.layers)
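To make the byte accounting above concrete, here is a small worked example, not part of the commit: it instantiates the dataclasses with NemotronHConfig-like defaults (128 Mamba heads of dimension 64, state size 128, 8 groups, conv kernel 4) at TP world size 1, passes the state dtypes explicitly so the sketch does not depend on the `SGLANG_MAMBA_SSM_DTYPE` environment variable, and uses a hypothetical Mamba layer count.

```python
# Worked example (not from this commit) of Mamba2CacheParams.mamba_cache_per_req.
import torch

from sglang.srt.configs.mamba_utils import (
    Mamba2CacheParams,
    Mamba2StateDType,
    Mamba2StateShape,
)

shape = Mamba2StateShape.create(
    tp_world_size=1,
    intermediate_size=128 * 64,  # num_heads * head_dim
    n_groups=8,
    num_heads=128,
    head_dim=64,
    state_size=128,
    conv_kernel=4,
)
# conv state: (8192 + 2 * 8 * 128, 4 - 1) = (10240, 3); temporal state: (128, 64, 128)
params = Mamba2CacheParams(
    shape=shape,
    dtype=Mamba2StateDType(conv=torch.bfloat16, temporal=torch.bfloat16),
    layers=list(range(27)),  # hypothetical: one entry per Mamba2 layer id
)
# (10240 * 3 + 128 * 64 * 128) elements * 2 bytes * 27 layers ~= 55.6 MiB per request
print(params.mamba_cache_per_req)
```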
# Copyright 2025 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/transformers_utils/configs/nemotron_h.py
"""NemotronH model configuration"""
import regex as re
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
from sglang.srt.configs.mamba_utils import Mamba2CacheParams, Mamba2StateShape
from sglang.srt.layers.dp_attention import get_attention_tp_size
logger = logging.get_logger(__name__)
MAMBA = "M"
ATTENTION = "*"
MLP = "-"
class NemotronHConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a
[`NemotronHModel`]. It is used to instantiate a NemotronH model according
to the specified arguments, defining the model architecture. Instantiating
a configuration with the defaults will yield a similar configuration to
that of the NemotronH-v0.1 model.
Args:
vocab_size (`int`, *optional*, defaults to 131072):
Vocabulary size of the NemotronH model. Defines the number of
different tokens that can be represented by the `inputs_ids`
passed when calling [`NemotronHModel`]
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether the model's input and output word embeddings should be
tied. Note that this is only relevant if the model has an output
word embedding layer.
hidden_size (`int`, *optional*, defaults to 4096):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 21504):
Dimension of the MLP representations.
num_hidden_layers (`int`, *optional*, defaults to 52):
Number of hidden layers in the Transformer encoder.
hybrid_override_pattern (`str`, *optional*, defaults to
`"M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M-"`):
The pattern of the hybrid model. The pattern is a string of
characters where each character represents
M: Mamba2, *: Attention, -: MLP
num_attention_heads (`int`, *optional*, defaults to 32):
Number of attention heads for each attention layer in the
Transformer encoder.
attention_head_dim (`int`, *optional*, defaults to 128):
Dimension of each attention head.
num_key_value_heads (`int`, *optional*, defaults to 8):
This is the number of key_value heads that should be used to
implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use
Multi Head Attention (MHA), if `num_key_value_heads=1` the model
will use Multi Query Attention (MQA) otherwise GQA is used.
mlp_hidden_act (`str`, *optional*, defaults to "relu2"):
The non-linear activation function in the MLP layers.
attention_bias (`bool`, *optional*, defaults to `False`):
Whether to use bias in attention layers.
mlp_bias (`bool`, *optional*, defaults to `False`):
Whether to use bias in MLP layers.
use_bias (`bool`, *optional*, defaults to `False`):
Whether to use bias in the model.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for
initializing all weight matrices.
layer_norm_epsilon (`float`, *optional*, defaults to 1e-5):
The epsilon used by the layer normalization layers.
residual_in_fp32 (`bool`, *optional*, defaults to `False`):
Whether or not residuals should be in `float32`. If set to `False`
residuals will keep the same `dtype` as the rest of the model.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values
attentions (not used by all models). Only relevant if
`config.is_decoder=True`.
num_logits_to_keep (`int` or `None`, *optional*, defaults to 1):
Number of prompt logits to calculate during generation. If `None`,
all logits will be calculated. If an integer value, only last
`num_logits_to_keep` logits will be calculated.
pad_token_id (`int`, *optional*, defaults to 0):
The id of the padding token.
bos_token_id (`int`, *optional*, defaults to 1):
The id of the "beginning-of-sequence" token.
eos_token_id (`int`, *optional*, defaults to 2):
The id of the "end-of-sequence" token.
sliding_window (`int`, *optional*, defaults to None):
Sliding window attention window size.
max_position_embeddings (`int`, *optional*, defaults to 4096):
The maximum sequence length that this model might ever be used
with.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
hidden_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the hidden states.
use_mamba_kernels (`bool`, *optional*, defaults to `True`):
Flag indicating whether or not to use the fast mamba kernels.
These are available only if `mamba-ssm` and `causal-conv1d`
are installed, and the mamba modules are running on a CUDA device.
ssm_state_size (`int`, *optional*, defaults to 128):
The dimension of the mamba state space latents.
mamba_num_heads (`int`, *optional*, defaults to 128):
Number of heads in Mamba layers.
mamba_n_groups (`int`, *optional*, defaults to 8):
Number of groups in Mamba layers.
mamba_head_dim (`int`, *optional*, defaults to 64):
Dimension of each Mamba head.
mamba_d_conv (`int`, *optional*, defaults to 4):
The size of the mamba convolution kernel.
mamba_expand (`int`, *optional*, defaults to 2):
Expanding factor used to determine the mamba intermediate size.
mamba_hidden_act (`str`, *optional*, defaults to "silu"):
The non-linear activation function in the Mamba layers.
mamba_dt_min (`float`, *optional*, defaults to 0.001):
Minimum value for the time step in Mamba.
mamba_dt_max (`float`, *optional*, defaults to 0.1):
Maximum value for the time step in Mamba.
mamba_dt_limit (`tuple`, *optional*, defaults to (0.0, float("inf"))):
Limits for the time step in Mamba.
mamba_dt_init_floor (`float`, *optional*, defaults to 1e-4):
Floor value for time step initialization in Mamba.
mamba_conv_bias (`bool`, *optional*, defaults to `True`):
Whether to use bias in the convolution layer of the mamba mixer
block.
mamba_proj_bias (`bool`, *optional*, defaults to `False`):
Whether to use bias in the input and output projections of the
mamba mixer block.
mamba_chunk_size (`int`, *optional*, defaults to 256):
Size of chunks for Mamba processing.
rescale_prenorm_residual (`bool`, *optional*, defaults to `True`):
Whether to rescale the pre-normalization residual connections.
"""
model_type = "nemotron_h"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vocab_size=131072,
tie_word_embeddings=False,
hidden_size=4096,
intermediate_size=21504,
num_hidden_layers=52,
hybrid_override_pattern="M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M-",
num_attention_heads=32,
head_dim=128,
num_key_value_heads=8, # nemo: num_query_groups
mlp_hidden_act="relu2",
attention_bias=False,
mlp_bias=False,
use_bias=False,
initializer_range=0.02, # nemo: init_method_std
layer_norm_epsilon=1e-5, # nemo: layernorm_epsilon
residual_in_fp32=False, # Megatron Core default value
use_cache=True,
num_logits_to_keep=1,
pad_token_id=0,
bos_token_id=1,
eos_token_id=2,
sliding_window=None,
max_position_embeddings=4096,
attention_dropout=0.0,
hidden_dropout=0.0, # * ADDED
use_mamba_kernels=True,
ssm_state_size=128, # mamba_state_size
mamba_num_heads=128,
mamba_n_groups=8, # nemo: mamba_ssm_ngroups = num_heads
mamba_head_dim=64,
mamba_d_conv=4,
mamba_expand=2,
mamba_hidden_act="silu",
mamba_dt_min=0.001,
mamba_dt_max=0.1,
mamba_dt_limit=(0.0, float("inf")),
mamba_dt_init_floor=1e-4,
mamba_conv_bias=True,
mamba_proj_bias=False,
mamba_chunk_size=256,
rescale_prenorm_residual=True,
**kwargs,
):
self.vocab_size = vocab_size
self.tie_word_embeddings = tie_word_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.hybrid_override_pattern = hybrid_override_pattern
self.num_attention_heads = num_attention_heads
self.head_dim = head_dim
self.sliding_window = sliding_window
self.max_position_embeddings = max_position_embeddings
self.attention_dropout = attention_dropout
self.hidden_dropout = hidden_dropout
# Validate hybrid_override_pattern
# M: Mamba2, *: Attention, -: MLP
assert len(self.hybrid_override_pattern) == self.num_hidden_layers, (
"hybrid_override_pattern must have same length as " "num_hidden_layers"
)
assert re.match(r"^[*-M]+$", self.hybrid_override_pattern), (
"hybrid_override_pattern must only contain characters " "'M', '*', or '-'"
)
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.mlp_hidden_act = mlp_hidden_act
self.attention_bias = attention_bias
self.mlp_bias = mlp_bias
self.use_bias = use_bias
self.initializer_range = initializer_range
self.layer_norm_epsilon = layer_norm_epsilon
self.residual_in_fp32 = residual_in_fp32
self.use_cache = use_cache
self.num_logits_to_keep = num_logits_to_keep
self.use_mamba_kernels = use_mamba_kernels
self.mamba_n_groups = mamba_n_groups
self.mamba_head_dim = mamba_head_dim
self.ssm_state_size = ssm_state_size
self.mamba_num_heads = mamba_num_heads
self.conv_kernel = mamba_d_conv
self.expand = mamba_expand
self.mamba_hidden_act = mamba_hidden_act
self.time_step_min = mamba_dt_min
self.time_step_max = mamba_dt_max
self.time_step_limit = mamba_dt_limit
self.time_step_floor = mamba_dt_init_floor
self.use_conv_bias = mamba_conv_bias
self.mamba_proj_bias = mamba_proj_bias
self.mamba_chunk_size = mamba_chunk_size
self.rescale_prenorm_residual = rescale_prenorm_residual
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
@property
def mamba_layer_ids(self):
return [
i
for i in range(self.num_hidden_layers)
if self.hybrid_override_pattern[i] == MAMBA
]
@property
def full_attention_layer_ids(self):
return [
i
for i in range(self.num_hidden_layers)
if self.hybrid_override_pattern[i] == ATTENTION
]
@property
def mamba2_cache_params(self) -> Mamba2CacheParams:
shape = Mamba2StateShape.create(
tp_world_size=get_attention_tp_size(),
intermediate_size=self.mamba_num_heads * self.mamba_head_dim,
n_groups=self.n_groups,
num_heads=self.mamba_num_heads,
head_dim=self.mamba_head_dim,
state_size=self.ssm_state_size,
conv_kernel=self.conv_kernel,
)
return Mamba2CacheParams(shape=shape, layers=self.mamba_layer_ids)
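The `hybrid_override_pattern` string drives the two layer-id properties above: each position is one decoder layer, with `M` for Mamba2, `*` for attention, and `-` for MLP. Below is a standalone sketch of that mapping, using a short hypothetical pattern rather than a real checkpoint's 52-character default.

```python
# Standalone sketch (not from this commit) of how hybrid_override_pattern maps to
# layer ids; the 9-layer pattern below is hypothetical.
MAMBA, ATTENTION, MLP = "M", "*", "-"

pattern = "M-M-M*-M-"
mamba_layer_ids = [i for i, c in enumerate(pattern) if c == MAMBA]
full_attention_layer_ids = [i for i, c in enumerate(pattern) if c == ATTENTION]
mlp_layer_ids = [i for i, c in enumerate(pattern) if c == MLP]

print(mamba_layer_ids)  # [0, 2, 4, 7]
print(full_attention_layer_ids)  # [5]
print(mlp_layer_ids)  # [1, 3, 6, 8]
```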
...@@ -15,14 +15,12 @@
"""Qwen3Hybrid model configuration"""
import enum
import os
import numpy as np
import torch
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_rope_utils import rope_config_validation
from transformers.utils import logging
from sglang.srt.configs.mamba_utils import Mamba2CacheParams, Mamba2StateShape
from sglang.srt.distributed.utils import divide
from sglang.srt.layers.dp_attention import get_attention_tp_size
...@@ -282,45 +280,15 @@ class Qwen3NextConfig(PretrainedConfig):
]
@property
def hybrid_gdn_params(self):
world_size = get_attention_tp_size()
conv_dim = (
self.linear_key_head_dim * self.linear_num_key_heads * 2
+ self.linear_value_head_dim * self.linear_num_value_heads
)
conv_state_shape = (
divide(conv_dim, world_size),
self.linear_conv_kernel_dim - 1,
)
temporal_state_shape = (
divide(self.linear_num_value_heads, world_size),
self.linear_key_head_dim,
self.linear_value_head_dim,
)
conv_dtype = torch.bfloat16
dtype_map = {
"float32": torch.float32,
"bfloat16": torch.bfloat16,
}
ssm_dtype = dtype_map[os.environ["SGLANG_MAMBA_SSM_DTYPE"]]
mamba_layers = self.linear_layer_ids
return (
conv_state_shape,
temporal_state_shape,
conv_dtype,
ssm_dtype,
mamba_layers,
)
@property
def mamba_cache_per_req(self):
conv_state_shape, temporal_state_shape, conv_dtype, ssm_dtype, mamba_layers = (
self.hybrid_gdn_params
)
mamba_layers_len = len(mamba_layers)
return (
int(np.prod(conv_state_shape)) * conv_dtype.itemsize
+ int(np.prod(temporal_state_shape)) * ssm_dtype.itemsize
) * mamba_layers_len
@property
def mamba2_cache_params(self) -> Mamba2CacheParams:
shape = Mamba2StateShape.create(
tp_world_size=get_attention_tp_size(),
intermediate_size=self.linear_value_head_dim * self.linear_num_value_heads,
n_groups=self.linear_num_key_heads,
num_heads=self.linear_num_value_heads,
head_dim=self.linear_value_head_dim,
state_size=self.linear_key_head_dim,
conv_kernel=self.linear_conv_kernel_dim,
)
return Mamba2CacheParams(shape=shape, layers=self.linear_layer_ids)
import logging
from typing import TYPE_CHECKING
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
# evade circular imports
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.model_executor.model_runner import ModelRunner
ATTENTION_BACKENDS = {}
...@@ -166,36 +173,41 @@ def create_dual_chunk_flash_attn_backend(runner):
return DualChunkFlashAttentionBackend(runner)
def attn_backend_wrapper(runner, full_attn_backend):
"""
Wrapper for special models like hybrid GDN, so we don't
need to change the code of the original attention backend.
"""
assert not (
runner.is_hybrid_gdn and runner.use_mla_backend
), "hybrid_gdn can only be used with non-MLA models."
# wrap for hybrid GDN models
if runner.is_hybrid_gdn:
from sglang.srt.utils import is_blackwell, is_npu
if is_blackwell():
assert (
runner.server_args.attention_backend == "triton"
or runner.server_args.attention_backend == "trtllm_mha"
), "triton or trtllm_mha backend are the only supported backends on Blackwell GPUs for hybrid GDN models, use --attention-backend triton or --attention-backend trtllm_mha to specify the backend."
if is_npu():
assert (
runner.server_args.attention_backend == "ascend"
), "ascend backend is the only supported backend on NPU for hybrid GDN models, use --attention-backend ascend to specify the backend."
logger.info(f"Using hybrid linear attention backend for hybrid GDN models.")
from sglang.srt.layers.attention.hybrid_linear_attn_backend import (
HybridLinearAttnBackend,
MambaAttnBackend,
)
linear_attn_backend = MambaAttnBackend(runner)
full_attn_layers = runner.model_config.hf_config.full_attention_layer_ids
return HybridLinearAttnBackend(
full_attn_backend, linear_attn_backend, full_attn_layers
)
def attn_backend_wrapper(runner: "ModelRunner", full_attn_backend: "AttentionBackend"):
"""
Wrapper for special models like hybrid GDN, so we don't
need to change the code of the original attention backend.
"""
assert not (
runner.hybrid_gdn_config is not None and runner.use_mla_backend
), "hybrid_gdn can only be used with non-MLA models."
if cfg := runner.mambaish_config:
from sglang.srt.layers.attention.hybrid_linear_attn_backend import (
GDNAttnBackend,
HybridLinearAttnBackend,
Mamba2AttnBackend,
)
from sglang.srt.utils import is_blackwell, is_npu
if runner.hybrid_gdn_config is not None:
if is_blackwell():
assert (
runner.server_args.attention_backend == "triton"
), "triton backend is the only supported backend on Blackwell GPUs for hybrid GDN models, use --attention-backend triton to specify the backend."
if is_npu():
assert (
runner.server_args.attention_backend == "ascend"
), "ascend backend is the only supported backend on NPU for hybrid GDN models, use --attention-backend ascend to specify the backend."
logger.info(f"Using hybrid linear attention backend for hybrid GDN models.")
linear_attn_backend = GDNAttnBackend(runner)
elif runner.mamba2_config is not None:
linear_attn_backend = Mamba2AttnBackend(runner)
else:
raise ValueError(
"Expected hybrid GDN or NemotronH models, but got unknown model."
)
full_attn_layers = cfg.full_attention_layer_ids
return HybridLinearAttnBackend(
full_attn_backend, linear_attn_backend, full_attn_layers
)
...
...@@ -181,6 +181,45 @@ def _layer_norm_fwd(
return out, mean, rstd
def rms_norm_gated(
*,
x,
weight,
bias,
z=None,
eps=1e-6,
group_size=None,
norm_before_gate=True,
is_rms_norm=False,
):
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
x_shape_og = x.shape
# reshape input data into 2D tensor
x = x.reshape(-1, x.shape[-1])
if x.stride(-1) != 1:
x = x.contiguous()
if z is not None:
assert z.shape == x_shape_og
z = z.reshape(-1, z.shape[-1])
if z.stride(-1) != 1:
z = z.contiguous()
weight = weight.contiguous()
if bias is not None:
bias = bias.contiguous()
y, mean, rstd = _layer_norm_fwd(
x,
weight,
bias,
eps,
z=z,
group_size=group_size,
norm_before_gate=norm_before_gate,
is_rms_norm=is_rms_norm,
)
return y.reshape(x_shape_og)
class LayerNormFn(torch.autograd.Function):
@staticmethod
...@@ -195,32 +234,16 @@ class LayerNormFn(torch.autograd.Function):
norm_before_gate=True,
is_rms_norm=False,
):
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
x_shape_og = x.shape
# reshape input data into 2D tensor
x = x.reshape(-1, x.shape[-1])
if x.stride(-1) != 1:
x = x.contiguous()
if z is not None:
assert z.shape == x_shape_og
z = z.reshape(-1, z.shape[-1])
if z.stride(-1) != 1:
z = z.contiguous()
weight = weight.contiguous()
if bias is not None:
bias = bias.contiguous()
y, mean, rstd = _layer_norm_fwd(
x,
weight,
bias,
eps,
z=z,
group_size=group_size,
norm_before_gate=norm_before_gate,
is_rms_norm=is_rms_norm,
)
return y.reshape(x_shape_og)
return rms_norm_gated(
x=x,
weight=weight,
bias=bias,
eps=eps,
z=z,
group_size=group_size,
norm_before_gate=norm_before_gate,
is_rms_norm=is_rms_norm,
)
def layernorm_fn(
...@@ -238,14 +261,6 @@ def layernorm_fn(
)
def rmsnorm_fn(
x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True
):
return LayerNormFn.apply(
x, weight, bias, z, eps, group_size, norm_before_gate, True
)
class LayerNorm(torch.nn.Module):
def __init__(
...@@ -284,6 +299,7 @@ class LayerNorm(torch.nn.Module):
group_size=self.group_size,
eps=self.eps,
norm_before_gate=self.norm_before_gate,
is_rms_norm=False,
)
...@@ -315,7 +331,7 @@ class RMSNorm(torch.nn.Module):
def forward(self, x, z=None):
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
return rmsnorm_fn(
return layernorm_fn(
x,
self.weight,
self.bias,
...@@ -323,4 +339,5 @@ class RMSNorm(torch.nn.Module):
eps=self.eps,
group_size=self.group_size,
norm_before_gate=self.norm_before_gate,
is_rms_norm=True,
)
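The docstrings above describe the gated-norm semantics as `norm(x) * silu(z)` when `norm_before_gate`, and `norm(x * silu(z))` otherwise; the Triton path fuses this. For reference, a minimal eager-mode PyTorch equivalent of the RMSNorm case (an illustration, not the kernel used by the commit) could look like this:

```python
import torch
import torch.nn.functional as F


def rms_norm_gated_reference(x, weight, z=None, eps=1e-6, norm_before_gate=True):
    """Eager reference for the gated RMSNorm semantics described above."""

    def rms_norm(t):
        variance = t.float().pow(2).mean(dim=-1, keepdim=True)
        return (t.float() * torch.rsqrt(variance + eps)).to(t.dtype) * weight

    if z is None:
        return rms_norm(x)
    if norm_before_gate:
        return rms_norm(x) * F.silu(z)  # norm(x) * silu(z)
    return rms_norm(x * F.silu(z))  # norm(x * silu(z))
```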
...@@ -14,14 +14,21 @@ from sglang.srt.layers.attention.fla.fused_sigmoid_gating_recurrent import (
fused_sigmoid_gating_delta_rule_update,
)
from sglang.srt.layers.attention.mamba.causal_conv1d_triton import (
PAD_SLOT_ID,
causal_conv1d_fn,
causal_conv1d_update,
)
from sglang.srt.layers.attention.mamba.mamba import MambaMixer2
from sglang.srt.layers.attention.mamba.mamba2_metadata import (
ForwardMetadata,
Mamba2Metadata,
)
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool
from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, MambaPool
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.models.qwen3_next import fused_gdn_gating
from sglang.srt.speculative.eagle_info import EagleDraftInput, EagleVerifyInput
from sglang.srt.speculative.spec_info import SpecInput
from sglang.srt.utils import is_cuda, is_npu
...@@ -47,18 +54,10 @@ elif is_npu():
causal_conv1d_update = causal_conv1d_update_npu
@dataclass
class ForwardMetadata:
query_start_loc: Optional[torch.Tensor]
mamba_cache_indices: torch.Tensor
class MambaAttnBackend(AttentionBackend):
"""Attention backend using Mamba kernel."""
class MambaAttnBackendBase(AttentionBackend):
def __init__(self, model_runner: ModelRunner):
super().__init__()
self.pad_slot_id = -1 # Default pad slot id
self.pad_slot_id = PAD_SLOT_ID
self.device = model_runner.device
self.req_to_token_pool: HybridReqToTokenPool = model_runner.req_to_token_pool
self.forward_metadata: ForwardMetadata = None
...@@ -67,7 +66,7 @@ class MambaAttnBackend(AttentionBackend):
self.cached_cuda_graph_decode_query_start_loc: torch.Tensor = None
self.cached_cuda_graph_verify_query_start_loc: torch.Tensor = None
def init_forward_metadata(self, forward_batch: ForwardBatch):
def _forward_metadata(self, forward_batch: ForwardBatch):
bs = forward_batch.batch_size
if forward_batch.forward_mode.is_decode_or_idle():
...@@ -97,11 +96,43 @@ class MambaAttnBackend(AttentionBackend):
mamba_cache_indices = self.req_to_token_pool.get_mamba_indices(
forward_batch.req_pool_indices
)
self.forward_metadata = ForwardMetadata(
return ForwardMetadata(
query_start_loc=query_start_loc,
mamba_cache_indices=mamba_cache_indices,
)
def init_forward_metadata(self, forward_batch: ForwardBatch):
self.forward_metadata = self._forward_metadata(forward_batch)
def init_forward_metadata_capture_cuda_graph(
self,
bs: int,
num_tokens: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
):
self.forward_metadata = self._capture_metadata(
bs, req_pool_indices, forward_mode
)
def init_forward_metadata_replay_cuda_graph(
self,
bs: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_sum: int,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
seq_lens_cpu: Optional[torch.Tensor],
):
self.forward_metadata = self._replay_metadata(
bs, req_pool_indices, forward_mode, spec_info, seq_lens_cpu
)
def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
assert (
max_num_tokens % max_bs == 0
...@@ -127,15 +158,8 @@ class MambaAttnBackend(AttentionBackend):
device=self.device,
)
def init_forward_metadata_capture_cuda_graph(
self,
bs: int,
num_tokens: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[SpecInput],
):
def _capture_metadata(
self, bs: int, req_pool_indices: torch.Tensor, forward_mode: ForwardMode
):
if forward_mode.is_decode_or_idle():
self.query_start_loc_list[bs - 1].copy_(
...@@ -149,18 +173,15 @@ class MambaAttnBackend(AttentionBackend):
raise ValueError(f"Invalid forward mode: {forward_mode=}")
mamba_indices = self.req_to_token_pool.get_mamba_indices(req_pool_indices)
self.state_indices_list[bs - 1][: len(mamba_indices)].copy_(mamba_indices)
self.forward_metadata = ForwardMetadata(
return ForwardMetadata(
query_start_loc=self.query_start_loc_list[bs - 1],
mamba_cache_indices=self.state_indices_list[bs - 1],
)
def init_forward_metadata_replay_cuda_graph(
self,
bs: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_sum: int,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[SpecInput],
seq_lens_cpu: Optional[torch.Tensor],
def _replay_metadata(
self,
bs: int,
req_pool_indices: torch.Tensor,
forward_mode: ForwardMode,
spec_info: Optional[SpecInput],
seq_lens_cpu: Optional[torch.Tensor],
...@@ -200,7 +221,7 @@ class MambaAttnBackend(AttentionBackend):
else:
raise ValueError(f"Invalid forward mode: {forward_mode=}")
self.forward_metadata = ForwardMetadata(
return ForwardMetadata(
query_start_loc=self.query_start_loc_list[bs - 1],
mamba_cache_indices=self.state_indices_list[bs - 1],
)
...@@ -208,6 +229,10 @@ class MambaAttnBackend(AttentionBackend):
def get_cuda_graph_seq_len_fill_value(self):
return 1 # Mamba attn does not use seq lens to index kv cache
class GDNAttnBackend(MambaAttnBackendBase):
"""Attention backend using Mamba kernel."""
def forward_decode(
self,
q: torch.Tensor,
...@@ -233,9 +258,9 @@ class MambaAttnBackend(AttentionBackend):
dt_bias = kwargs["dt_bias"]
layer_id = kwargs["layer_id"]
conv_states, ssm_states, *rest = self.req_to_token_pool.get_mamba_params(
layer_id
)
layer_cache = self.req_to_token_pool.mamba2_layer_cache(layer_id)
conv_states = layer_cache.conv
ssm_states = layer_cache.temporal
query_start_loc = self.forward_metadata.query_start_loc
cache_indices = self.forward_metadata.mamba_cache_indices
...@@ -313,13 +338,13 @@ class MambaAttnBackend(AttentionBackend):
query_start_loc = self.forward_metadata.query_start_loc
cache_indices = self.forward_metadata.mamba_cache_indices
mamba_cache_params = self.req_to_token_pool.mamba2_layer_cache(layer_id)
conv_states = mamba_cache_params.conv
ssm_states = mamba_cache_params.temporal
if is_target_verify:
(
conv_states,
ssm_states,
intermediate_state_cache,
intermediate_conv_window_cache,
) = self.req_to_token_pool.get_mamba_params(layer_id)
assert isinstance(mamba_cache_params, MambaPool.SpeculativeState)
intermediate_state_cache = mamba_cache_params.intermediate_ssm
intermediate_conv_window_cache = mamba_cache_params.intermediate_conv_window
has_initial_states = torch.ones(
seq_len // forward_batch.spec_info.draft_token_num,
dtype=torch.bool,
...@@ -327,9 +352,6 @@ class MambaAttnBackend(AttentionBackend):
)
conv_states_to_use = conv_states.clone()
else:
conv_states, ssm_states, *rest = self.req_to_token_pool.get_mamba_params(
layer_id
)
has_initial_states = forward_batch.extend_prefix_lens > 0
conv_states_to_use = conv_states
...@@ -424,16 +446,100 @@ class MambaAttnBackend(AttentionBackend):
return core_attn_out
class Mamba2AttnBackend(MambaAttnBackendBase):
"""Attention backend wrapper for Mamba2Mixer kernels."""
def __init__(self, model_runner: ModelRunner):
super().__init__(model_runner)
config = model_runner.mamba2_config
assert config is not None
self.mamba_chunk_size = config.mamba_chunk_size
def init_forward_metadata(self, forward_batch: ForwardBatch):
metadata = self._forward_metadata(forward_batch)
self.forward_metadata = Mamba2Metadata.prepare_mixed(
metadata.query_start_loc,
metadata.mamba_cache_indices,
self.mamba_chunk_size,
forward_batch,
)
def init_forward_metadata_capture_cuda_graph(
self,
bs: int,
num_tokens: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
):
metadata = self._capture_metadata(bs, req_pool_indices, forward_mode)
self.forward_metadata = Mamba2Metadata.prepare_decode(
metadata.query_start_loc, metadata.mamba_cache_indices, seq_lens
)
def init_forward_metadata_replay_cuda_graph(
self,
bs: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_sum: int,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
seq_lens_cpu: Optional[torch.Tensor],
):
metadata = self._replay_metadata(
bs, req_pool_indices, forward_mode, spec_info, seq_lens_cpu
)
self.forward_metadata = Mamba2Metadata.prepare_decode(
metadata.query_start_loc, metadata.mamba_cache_indices, seq_lens
)
def forward(
self,
mixer: MambaMixer2,
hidden_states: torch.Tensor,
output: torch.Tensor,
layer_id: int,
mup_vector: Optional[torch.Tensor] = None,
use_triton_causal_conv: bool = False,
):
assert isinstance(self.forward_metadata, Mamba2Metadata)
layer_cache = self.req_to_token_pool.mamba2_layer_cache(layer_id)
return mixer.forward(
hidden_states=hidden_states,
output=output,
layer_cache=layer_cache,
metadata=self.forward_metadata,
mup_vector=mup_vector,
use_triton_causal_conv=use_triton_causal_conv,
)
def forward_decode(self, *args, **kwargs):
raise NotImplementedError(
"Mamba2AttnBackend's forward is called directly instead of through HybridLinearAttnBackend, as it supports mixed prefill and decode"
)
def forward_extend(self, *args, **kwargs):
raise NotImplementedError(
"Mamba2AttnBackend's forward is called directly instead of through HybridLinearAttnBackend, as it supports mixed prefill and decode"
)
class HybridLinearAttnBackend(AttentionBackend):
"""Support different backends for prefill and decode."""
"""Manages a full and linear attention backend"""
def __init__(
self,
full_attn_backend: AttentionBackend,
linear_attn_backend: AttentionBackend,
linear_attn_backend: MambaAttnBackendBase,
full_attn_layers: list[int],
):
self.full_attn_layers = full_attn_layers
self.full_attn_backend = full_attn_backend
self.linear_attn_backend = linear_attn_backend
self.attn_backend_list = [full_attn_backend, linear_attn_backend]
def init_forward_metadata(self, forward_batch: ForwardBatch):
...@@ -489,7 +595,7 @@ class HybridLinearAttnBackend(AttentionBackend):
)
def get_cuda_graph_seq_len_fill_value(self):
return self.attn_backend_list[0].get_cuda_graph_seq_len_fill_value()
return self.full_attn_backend.get_cuda_graph_seq_len_fill_value()
def forward_decode(
self,
...@@ -503,10 +609,10 @@ class HybridLinearAttnBackend(AttentionBackend):
):
layer_id = layer.layer_id if layer else kwargs["layer_id"]
if layer_id in self.full_attn_layers:
return self.attn_backend_list[0].forward_decode(
return self.full_attn_backend.forward_decode(
q, k, v, layer, forward_batch, save_kv_cache, **kwargs
)
return self.attn_backend_list[1].forward_decode(
return self.linear_attn_backend.forward_decode(
q, k, v, layer, forward_batch, save_kv_cache, **kwargs
)
...@@ -522,10 +628,10 @@ class HybridLinearAttnBackend(AttentionBackend):
):
layer_id = layer.layer_id if layer else kwargs["layer_id"]
if layer_id in self.full_attn_layers:
return self.attn_backend_list[0].forward_extend(
return self.full_attn_backend.forward_extend(
q, k, v, layer, forward_batch, save_kv_cache, **kwargs
)
return self.attn_backend_list[1].forward_extend(
return self.linear_attn_backend.forward_extend(
q, k, v, layer, forward_batch, save_kv_cache, **kwargs
)
...@@ -568,20 +674,20 @@ class HybridLinearAttnBackend(AttentionBackend):
def update_mamba_state_after_mtp_verify(self, accepted_length, model):
request_number = accepted_length.shape[0]
state_indices_tensor = self.attn_backend_list[
1
].forward_metadata.mamba_cache_indices[:request_number]
mamba_caches = self.attn_backend_list[
1
].req_to_token_pool.get_mamba_params_all_layers()
(
conv_states,
ssm_states,
intermediate_state_cache,
intermediate_conv_window_cache,
) = mamba_caches
state_indices_tensor = (
self.linear_attn_backend.forward_metadata.mamba_cache_indices[
:request_number
]
)
mamba_caches = (
self.linear_attn_backend.req_to_token_pool.get_speculative_mamba2_params_all_layers()
)
conv_states = mamba_caches.conv
ssm_states = mamba_caches.temporal
intermediate_state_cache = mamba_caches.intermediate_ssm
intermediate_conv_window_cache = mamba_caches.intermediate_conv_window
# SSM state updates (chunked to reduce peak memory)
valid_mask = accepted_length > 0
...
...@@ -10,7 +10,7 @@ import torch
from sgl_kernel import causal_conv1d_fwd
from sgl_kernel import causal_conv1d_update as causal_conv1d_update_kernel
PAD_SLOT_ID = -1
from .causal_conv1d_triton import PAD_SLOT_ID
def causal_conv1d_fn(
...
...@@ -6,11 +6,11 @@ from typing import List, Optional, Union
import numpy as np
import torch
PAD_SLOT_ID = -1
import triton
import triton.language as tl
PAD_SLOT_ID = -1
@triton.jit()
def _causal_conv1d_fwd_kernel( # continuous batching
...@@ -672,7 +672,9 @@ def _causal_conv1d_update_kernel(
+ (conv_state_batch_coord * stride_conv_state_seq)
+ conv_state_token_offset * stride_conv_state_tok
+ (idx_feats * stride_conv_state_dim)[None, :]
+ ((idx_tokens + 1) * stride_conv_state_tok)[:, None]
+ ((idx_tokens + (1 if IS_SPEC_DECODING else seqlen)) * stride_conv_state_tok)[
:, None
]
) # [BLOCK_M, BLOCK_N]
mask = (
(conv_state_batch_coord < num_cache_lines)
...@@ -897,7 +899,10 @@ def causal_conv1d_update(
stride_state_indices = (
conv_state_indices.stride(0) if conv_state_indices is not None else 0
)
state_len = width - 1 + (seqlen - 1) # effective state_len needed
if num_accepted_tokens is not None:
state_len = width - 1 + (seqlen - 1) # effective state_len needed
else:
state_len = width - 1
np2_statelen = triton.next_power_of_2(state_len)
def grid(META):
...
# Copyright 2025 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Adapted from https://github.com/vllm-project/vllm/blob/2c58742dff8613a3bd7496f2008ce927e18d38d1/vllm/model_executor/layers/mamba/mamba2_metadata.py
import math
from dataclasses import dataclass
import torch
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@dataclass(kw_only=True)
class ForwardMetadata:
query_start_loc: torch.Tensor
mamba_cache_indices: torch.Tensor
@dataclass(kw_only=True)
class Mamba2Metadata(ForwardMetadata):
"""stable metadata across all mamba2 layers in the forward pass"""
num_prefills: int
num_prefill_tokens: int
num_decodes: int
@dataclass(kw_only=True, frozen=True)
class MixedMetadata:
has_initial_states: torch.Tensor
prep_initial_states: bool
chunk_size: int
seq_idx: torch.Tensor
chunk_indices: torch.Tensor
chunk_offsets: torch.Tensor
extend_seq_lens_cpu: list[int]
mixed_metadata: MixedMetadata | None = None
"""`mixed_metadata` is used for extend/mixed requests"""
@staticmethod
def _query_start_loc_to_chunk_indices_offsets(
query_start_loc: torch.Tensor, chunk_size: int, total_seqlens: int
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Args:
query_start_loc (torch.Tensor): 1D tensor of cumulative sequence
lengths, shape (num_seqs + 1,).
The first element should be 0. Each entry represents the starting
index of a sequence in the flattened token array.
chunk_size (int): The size of each physical mamba chunk
(number of tokens per chunk).
total_seqlens (int): The total number of tokens in the batch.
Returns:
Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
- chunk_indices (torch.Tensor): 1D tensor of indices
indicating the physical chunk for each logical chunk.
- chunk_offsets (torch.Tensor): 1D tensor of offsets
indicating the starting index of each logical chunk within
its physical chunk.
This function computes the chunk indices and offsets for the given
query_start_loc and chunk_size. Both are tensors of integers with length N,
where N is the number of logical (pseudo) chunks.
A logical chunk is a sequence of tokens that are all part of the same
sequence and are all in the same physical mamba chunk.
In other words, a logical chunk changes every time we cross a sequence
boundary or a physical mamba chunk boundary.
Logical chunks are needed to handle batched requests with initial states
(see _state_passing_fwd and _chunk_scan_fwd).
The chunk_indices tensor contains the index of the physical chunk for each
logical chunk.
The chunk_offsets tensor contains the offset (AKA starting index) of the
logical chunk in the physical chunk.
Example:
query_start_loc = [0, 5, 10]
chunk_size = 8
total_seqlens = 10
-> chunk_indices = [0, 0, 1]
-> chunk_offsets = [0, 5, 0]
In this example, we have 2 sequences, each with 5 tokens. The physical
chunk size is 8 tokens.
We have three logical chunks:
- the first logical chunk starts at token 0 in the first physical chunk
and contains all 5 tokens from the first sequence
- the second logical chunk starts at token 5 in the first physical chunk
and contains first 3 tokens from the second sequence
- the third logical chunk starts at token 0 in the second physical chunk
and contains the remaining 2 tokens from the second sequence
"""
cu_seqlens = query_start_loc[1:] # remove prepended 0
# outputs will have length expansion of chunks that do not divide
# chunk_size
N = (
math.ceil(total_seqlens / chunk_size)
+ (cu_seqlens[:-1] % chunk_size > 0).sum()
)
chunk_indices = torch.arange(N, dtype=torch.int, device=query_start_loc.device)
chunk_offsets = torch.zeros(
(N,), dtype=torch.int, device=query_start_loc.device
)
p = 0 # num of insertions
for s, e in zip(cu_seqlens[:-1], cu_seqlens[1:]):
# if does not divide chunk_size, then there is one chunk insertion
p += s % chunk_size > 0
# get the dimensions
# - the + 1 for _e is to shift the boundary by one chunk
# - this shifting is not needed if chunk_size divides e
_s, _e = s // chunk_size + p, e // chunk_size + p + (e % chunk_size > 0)
# adjust indices and offsets
chunk_indices[_s:_e] -= p
chunk_offsets[_s] = s % chunk_size
return chunk_indices, chunk_offsets
@staticmethod
def prepare_decode(
query_start_loc: torch.Tensor,
mamba_cache_indices: torch.Tensor,
seq_lens: torch.Tensor,
) -> "Mamba2Metadata":
"""This path is run during CUDA graph capture, i.e. decode only, so `num_prefills` is 0"""
return Mamba2Metadata(
query_start_loc=query_start_loc,
mamba_cache_indices=mamba_cache_indices,
num_decodes=len(seq_lens),
num_prefills=0,
num_prefill_tokens=0,
)
@classmethod
def prepare_mixed(
cls,
query_start_loc: torch.Tensor,
mamba_cache_indices: torch.Tensor,
chunk_size: int,
forward_batch: ForwardBatch,
) -> "Mamba2Metadata":
"""This path cannot run with CUDA graph, as it contains extend requests."""
if forward_batch.extend_num_tokens is None:
return cls.prepare_decode(
query_start_loc, mamba_cache_indices, forward_batch.seq_lens
)
num_prefills = len(forward_batch.extend_seq_lens)
num_prefill_tokens = forward_batch.extend_num_tokens
num_decodes = len(forward_batch.seq_lens) - num_prefills
context_lens_tensor = forward_batch.extend_prefix_lens
assert context_lens_tensor is not None
# precompute flag to avoid device syncs later
has_initial_states = context_lens_tensor > 0
prep_initial_states = torch.any(has_initial_states[:num_prefills]).item()
query_start_loc = query_start_loc[: num_prefills + 1]
seq_idx = torch.repeat_interleave(
torch.arange(
num_prefills, dtype=torch.int32, device=query_start_loc.device
),
query_start_loc.diff(),
output_size=num_prefill_tokens,
)
seq_idx.unsqueeze_(0)
# We compute the chunked-prefill metadata once at the top-level model
# forward and reuse it in the mamba layers. If it is not needed, it is
# simply ignored inside the mamba kernels.
chunk_offsets, chunk_indices = None, None
if prep_initial_states:
chunk_indices, chunk_offsets = (
cls._query_start_loc_to_chunk_indices_offsets(
query_start_loc, chunk_size, num_prefill_tokens
)
)
return Mamba2Metadata(
query_start_loc=query_start_loc,
mamba_cache_indices=mamba_cache_indices,
num_prefills=num_prefills,
num_prefill_tokens=num_prefill_tokens,
num_decodes=num_decodes,
mixed_metadata=cls.MixedMetadata(
has_initial_states=has_initial_states,
prep_initial_states=prep_initial_states,
chunk_size=chunk_size,
seq_idx=seq_idx,
chunk_indices=chunk_indices,
chunk_offsets=chunk_offsets,
extend_seq_lens_cpu=forward_batch.extend_seq_lens_cpu,
),
)
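# Illustrative call sketch (hypothetical; `attn_metadata`, `cache_indices`,
# `config`, and `forward_batch` are assumed names, not defined in this file).
# A top-level model forward would typically build the metadata once and pass
# it to every Mamba2 layer:
#   metadata = Mamba2Metadata.prepare_mixed(
#       query_start_loc=attn_metadata.query_start_loc,
#       mamba_cache_indices=cache_indices,
#       chunk_size=config.chunk_size,
#       forward_batch=forward_batch,
#   )
# The decode-only CUDA-graph path would instead use:
#   metadata = Mamba2Metadata.prepare_decode(query_start_loc, cache_indices, seq_lens)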
# Adapted from: https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/mamba/mamba_utils.py
from sglang.srt.distributed.utils import divide
class MambaStateShapeCalculator:
@classmethod
def linear_attention_state_shape(
cls,
num_heads: int,
tp_size: int,
head_dim: int,
) -> tuple[tuple[int, int, int], ...]:
state_shape = (num_heads // tp_size, head_dim, head_dim)
return (state_shape,)
@classmethod
def mamba1_state_shape(
cls,
tp_world_size: int,
intermediate_size: int,
state_size: int,
conv_kernel: int,
) -> tuple[tuple[int, int], tuple[int, int]]:
conv_state_shape = (divide(intermediate_size, tp_world_size), conv_kernel - 1)
temporal_state_shape = (divide(intermediate_size, tp_world_size), state_size)
# store as (conv_kernel - 1, channels), the same kernel-first layout used by
# mamba2_state_shape below
conv_state_shape = conv_state_shape[1], conv_state_shape[0]
return conv_state_shape, temporal_state_shape
@classmethod
def mamba2_state_shape(
cls,
tp_world_size: int,
intermediate_size: int,
n_groups: int,
num_heads: int,
head_dim: int,
state_size: int,
conv_kernel: int,
) -> tuple[tuple[int, int], tuple[int, int, int]]:
# if n_groups is not divisible by world_size, extend the shards so that
# every group needed by a head is sharded along with it
n_groups = n_groups + cls.extra_groups_for_head_shards(n_groups, tp_world_size)
# heads and n_groups are TP-ed
conv_dim = intermediate_size + 2 * n_groups * state_size
# contiguous along 'dim' axis
conv_state_shape = (conv_kernel - 1, divide(conv_dim, tp_world_size))
# These are not TP-ed as they depend on A, dt_bias, D
# - they are typically small
# e.g., (num_heads, head_dim, state_size) = (128, 64, 128)
temporal_state_shape = (divide(num_heads, tp_world_size), head_dim, state_size)
return conv_state_shape, temporal_state_shape
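# Worked example (illustrative numbers only, not tied to any particular
# checkpoint): tp_world_size=1, intermediate_size=8192, n_groups=8,
# num_heads=128, head_dim=64, state_size=128, conv_kernel=4 gives
#   conv_dim             = 8192 + 2 * 8 * 128 = 10240
#   conv_state_shape     = (3, 10240)
#   temporal_state_shape = (128, 64, 128)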
@classmethod
def short_conv_state_shape(
cls,
tp_world_size: int,
intermediate_size: int,
conv_kernel: int,
) -> tuple[tuple[int, int]]:
conv_dim = divide(intermediate_size, tp_world_size)
conv_state_shape = (conv_kernel - 1, conv_dim)
return (conv_state_shape,)
@classmethod
def extra_groups_for_head_shards(cls, ngroups: int, tp_size: int):
"""Compute the increase in group numbers to account for
replication in order to accompany the head shards."""
# in the case ngoups % tp_size == 0, this will be zero
if ngroups % tp_size == 0:
return 0
# for n_groups == 1, this is exactly tp_size - n_groups
return tp_size - ngroups
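# Illustrative checks (hypothetical, not part of the module):
#   MambaStateShapeCalculator.extra_groups_for_head_shards(ngroups=1, tp_size=8)  # -> 7
#   MambaStateShapeCalculator.extra_groups_for_head_shards(ngroups=8, tp_size=4)  # -> 0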
from typing import Union
import torch
from sglang.srt.custom_op import CustomOp
from sglang.srt.distributed.communication_op import (
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce,
)
from sglang.srt.distributed.parallel_state import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from sglang.srt.layers.attention.fla.layernorm_gated import rms_norm_gated
from sglang.srt.model_loader.weight_utils import sharded_weight_loader
from sglang.srt.utils.common import set_weight_attrs
class Mixer2RMSNormGated(CustomOp):
def __init__(
self,
full_hidden_size: int,
full_n_groups: int,
use_rms_norm: bool = True,
eps: float = 1e-6,
):
super().__init__()
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
self.full_hidden_size = full_hidden_size
self.group_size = full_hidden_size // full_n_groups
self.per_rank_hidden_size = full_hidden_size // self.tp_size
self.n_groups = full_hidden_size // self.group_size
self.variance_epsilon = eps
self.use_rms_norm = use_rms_norm
if self.use_rms_norm:
# Register norm weight only if we're actually applying RMSNorm
self.weight = torch.nn.Parameter(torch.ones(self.per_rank_hidden_size))
set_weight_attrs(self.weight, {"weight_loader": sharded_weight_loader(0)})
else:
# Avoid checkpoint mismatch by skipping unused parameter
self.register_parameter("weight", None)
assert (
self.full_hidden_size % self.tp_size == 0
), "Tensor parallel world size must divide hidden size."
def forward_native(
self,
x: torch.Tensor,
gate: torch.Tensor,
):
# Three tensor-parallel cases:
# 1. n_groups is 1
# In this case we parallelize along the reduction dim.
# Each rank computes a local sum of squares followed by AllReduce
# 2. tp_size divides n_groups
# Each rank only reduces within its local group(s).
# No collective ops necessary.
# 3. The general case can be pretty complicated so we AllGather
# the input and then redundantly compute the RMSNorm.
input_dtype = x.dtype
x = x * torch.nn.functional.silu(gate.to(torch.float32))
if not self.use_rms_norm:
return x.to(input_dtype)
if self.n_groups == 1:
if self.tp_size > 1:
# Compute local sum and then reduce to obtain global sum
local_sums = x.pow(2).sum(dim=-1, keepdim=True)
global_sums = tensor_model_parallel_all_reduce(local_sums)
# Calculate the variance
count = self.tp_size * x.shape[-1]
variance = global_sums / count
else:
variance = x.pow(2).mean(-1, keepdim=True)
x = x * torch.rsqrt(variance + self.variance_epsilon)
else:
redundant_tp: bool = self.n_groups % self.tp_size != 0
if redundant_tp:
# To handle the general case, redundantly apply the variance
x = tensor_model_parallel_all_gather(x, -1)
*prefix_dims, hidden_dim = x.shape
group_count = hidden_dim // self.group_size
x_grouped = x.view(*prefix_dims, group_count, self.group_size)
variance = x_grouped.pow(2).mean(-1, keepdim=True)
x_grouped = x_grouped * torch.rsqrt(variance + self.variance_epsilon)
x = x_grouped.view(*prefix_dims, hidden_dim)
if redundant_tp:
start = self.per_rank_hidden_size * self.tp_rank
end = start + self.per_rank_hidden_size
x = x[..., start:end]
return self.weight * x.to(input_dtype)
def forward_cuda(
self,
x: torch.Tensor,
gate: torch.Tensor,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
input_dtype = x.dtype
if not self.use_rms_norm:
# Keep gate in float32 for numerical stability during silu
return x * torch.nn.functional.silu(gate.to(torch.float32)).to(input_dtype)
if ((self.n_groups % self.tp_size) != 0) or self.n_groups != 1:
return self.forward_native(x, gate)
return rms_norm_gated(
x=x,
weight=self.weight.data,
bias=None,
z=gate,
eps=self.variance_epsilon,
norm_before_gate=False,
is_rms_norm=True,
)
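# Minimal usage sketch (assumes distributed state is already initialized with a
# TP world size of 1, so no collective ops run; tensor sizes are illustrative):
#   norm = Mixer2RMSNormGated(full_hidden_size=4096, full_n_groups=8)
#   x = torch.randn(2, 16, 4096)      # (batch, tokens, hidden)
#   gate = torch.randn(2, 16, 4096)
#   y = norm.forward_native(x, gate)  # gated RMSNorm over groups of size 512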
...@@ -15,56 +15,6 @@ import triton
import triton.language as tl
# @triton.autotune(
# configs=[
# triton.Config(
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
# num_stages=3,
# num_warps=8,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=4,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=4,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=4,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=4,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=4,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
# num_stages=5,
# num_warps=2,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
# num_stages=5,
# num_warps=2,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=2,
# ),
# ],
# key=["chunk_size", "K", "IS_CAUSAL"],
# )
@triton.jit
def _bmm_chunk_fwd_kernel(
# Pointers to matrices
...
...@@ -16,66 +16,6 @@ from packaging import version
TRITON_22 = version.parse(triton.__version__) >= version.parse("2.2.0")
# @triton.autotune(
# configs=[
# triton.Config(
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
# num_stages=3,
# num_warps=8,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=4,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=4,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=4,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=4,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64},
# num_stages=4,
# num_warps=4,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64},
# num_stages=4,
# num_warps=4,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=4,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
# num_stages=5,
# num_warps=2,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
# num_stages=5,
# num_warps=2,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=2,
# ),
# ],
# key=["chunk_size", "hdim", "dstate", "IS_CAUSAL"],
# )
@triton.jit
def _chunk_scan_fwd_kernel(
# Pointers to matrices
...
...@@ -17,17 +17,6 @@ import triton.language as tl
from .mamba_ssm import softplus
# @triton.autotune(
# configs=[
# triton.Config({"BLOCK_SIZE_H": 2}),
# triton.Config({"BLOCK_SIZE_H": 4}),
# triton.Config({"BLOCK_SIZE_H": 8}),
# triton.Config({"BLOCK_SIZE_H": 16}),
# triton.Config({"BLOCK_SIZE_H": 32}),
# triton.Config({"BLOCK_SIZE_H": 64}),
# ],
# key=["chunk_size", "nheads"],
# )
@triton.jit
def _chunk_cumsum_fwd_kernel(
# Pointers to matrices
...@@ -120,56 +109,6 @@ def _chunk_cumsum_fwd_kernel(
)
# @triton.autotune(
# configs=[
# triton.Config(
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
# num_stages=3,
# num_warps=8,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=4,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=4,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=4,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=4,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=4,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
# num_stages=5,
# num_warps=2,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
# num_stages=5,
# num_warps=2,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=2,
# ),
# ],
# key=["hdim", "dstate", "chunk_size"],
# )
@triton.jit
def _chunk_state_fwd_kernel(
# Pointers to matrices
...@@ -320,56 +259,6 @@ def _chunk_state_fwd_kernel(
tl.store(states_ptrs, states, mask=c_mask)
# @triton.autotune(
# configs=[
# triton.Config(
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
# num_stages=3,
# num_warps=8,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=4,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=4,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=4,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=4,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=4,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
# num_stages=5,
# num_warps=2,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
# num_stages=5,
# num_warps=2,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=2,
# ),
# ],
# key=["hdim", "dstate", "chunk_size"],
# )
@triton.jit
def _chunk_state_varlen_kernel(
# Pointers to matrices
...
...@@ -13,17 +13,6 @@ import triton
import triton.language as tl
# @triton.autotune(
# configs=[
# triton.Config({"BLOCK_SIZE": 64}),
# triton.Config({"BLOCK_SIZE": 128}),
# triton.Config({"BLOCK_SIZE": 256}),
# triton.Config({"BLOCK_SIZE": 512}),
# triton.Config({"BLOCK_SIZE": 1024}),
# triton.Config({"BLOCK_SIZE": 2048}),
# ],
# key=["dim"],
# )
@triton.jit
def _state_passing_fwd_kernel(
# Pointers to matrices
...
...@@ -85,7 +85,7 @@ class TritonAttnBackend(AttentionBackend):
self.num_kv_head = model_runner.model_config.get_num_kv_heads(
get_attention_tp_size()
)
-if model_runner.is_hybrid_gdn:
+if model_runner.hybrid_gdn_config is not None:
# For hybrid linear models, layer_id = 0 may not be full attention
self.v_head_dim = model_runner.token_to_kv_pool.get_v_head_dim()
else:
...