Unverified Commit 083629c2 authored by ilyasch2, committed by GitHub

[model] Add mamba2 and Falcon-H1 support. (#10988)


Co-authored-by: Younes Belkada <younes.belkada@tii.ae>
Co-authored-by: Younes B <49240599+younesbelkada@users.noreply.github.com>
parent b658be6f
......@@ -4,6 +4,7 @@ from sglang.srt.configs.deepseekvl2 import DeepseekVL2Config
from sglang.srt.configs.dots_ocr import DotsOCRConfig
from sglang.srt.configs.dots_vlm import DotsVLMConfig
from sglang.srt.configs.exaone import ExaoneConfig
from sglang.srt.configs.falcon_h1 import FalconH1Config
from sglang.srt.configs.janus_pro import MultiModalityConfig
from sglang.srt.configs.kimi_vl import KimiVLConfig
from sglang.srt.configs.kimi_vl_moonvit import MoonViTConfig
......@@ -30,4 +31,5 @@ __all__ = [
"Qwen3NextConfig",
"DotsVLMConfig",
"DotsOCRConfig",
"FalconH1Config",
]
# coding=utf-8
# Copyright 2024 TII and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Falcon-H1 model configuration"""
import enum
import os
import numpy as np
import torch
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_rope_utils import rope_config_validation
from transformers.utils import logging
from sglang.srt.distributed.utils import divide
from sglang.srt.layers.attention.mamba.mamba_utils import MambaStateShapeCalculator
from sglang.srt.layers.dp_attention import (
get_attention_tp_size,
get_tensor_model_parallel_world_size,
)
logger = logging.get_logger(__name__)
class FalconH1Config(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`FalconH1Model`]. It is used to instantiate a
FalconH1Model model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a configuration similar to that of the Falcon-H1 family released by TII (e.g. [tiiuae/Falcon-H1-0.5B-Instruct](https://huggingface.co/tiiuae/Falcon-H1-0.5B-Instruct)).
The FalconH1Model is a hybrid [mamba2](https://github.com/state-spaces/mamba) architecture with SwiGLU, combining attention and Mamba-2 (SSM) mixers in every layer.
The checkpoints are trained and released by the Technology Innovation Institute (TII).
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 128000):
Vocabulary size of the FalconH1 model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`FalconH1Model`]
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether the model's input and output word embeddings should be tied. Note that this is only relevant if the
model has an output word embedding layer.
hidden_size (`int`, *optional*, defaults to 4096):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 14336):
Dimension of the MLP representations.
num_hidden_layers (`int`, *optional*, defaults to 32):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 32):
Number of attention heads for each attention layer in the Transformer encoder.
num_key_value_heads (`int`, *optional*, defaults to 8):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details, check out [this
paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to `8`.
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the decoder.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (`float`, *optional*, defaults to 1e-05):
The epsilon used by the rms normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
num_logits_to_keep (`int` or `None`, *optional*, defaults to 1):
Number of prompt logits to calculate during generation. If `None`, all logits will be calculated. If an
integer value, only last `num_logits_to_keep` logits will be calculated. Default is 1 because only the
logits of the last prompt token are needed for generation. For long sequences, the logits for the entire
sequence may use a lot of memory so, setting `num_logits_to_keep=1` will reduce memory footprint
significantly.
pad_token_id (`int`, *optional*, defaults to 0):
The id of the padding token.
bos_token_id (`int`, *optional*, defaults to 1):
The id of the "beginning-of-sequence" token.
eos_token_id (`int`, *optional*, defaults to 2):
The id of the "end-of-sequence" token.
max_position_embeddings (`int`, *optional*, defaults to 8192):
Max cached sequence length for the model
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
mamba_d_ssm (`int`, *optional*, defaults to 1024):
The dimension of the SSM state space latents.
mamba_n_heads (`int`, *optional*, defaults to 128):
The number of mamba heads used in the v2 implementation.
mamba_d_head (`int`, *optional*, defaults to `"auto"`):
Head embedding dimension size
mamba_n_groups (`int`, *optional*, defaults to 1):
The number of the mamba groups used in the v2 implementation.
mamba_d_state (`int`, *optional*, defaults to 256):
The dimension of the mamba state space latents
mamba_d_conv (`int`, *optional*, defaults to 4):
The size of the mamba convolution kernel
mamba_expand (`int`, *optional*, defaults to 2):
Expanding factor (relative to hidden_size) used to determine the mamba intermediate size
mamba_chunk_size (`int`, *optional*, defaults to 256):
The chunks in which to break the sequence when doing prefill/training
mamba_conv_bias (`bool`, *optional*, defaults to `True`):
Flag indicating whether or not to use bias in the convolution layer of the mamba mixer block.
mamba_proj_bias (`bool`, *optional*, defaults to `False`):
Flag indicating whether or not to use bias in the input and output projections (["in_proj", "out_proj"]) of the mamba mixer block
mamba_norm_before_gate (`bool`, *optional*, defaults to `True`):
Whether to use RMSNorm before the gate in the Mamba block
mamba_rms_norm (`bool`, *optional*, defaults to `False`):
Whether to use RMSNorm instead of LayerNorm in the Mamba block
projectors_bias (`bool`, *optional*, defaults to `False`):
Flag indicating whether or not to use bias in the input and output projections (["in_proj", "out_proj"]) of the attention block
rope_theta (`float`, *optional*, defaults to 100000.0):
The theta value used for the RoPE embeddings.
rope_scaling (`dict`, *optional*):
The scaling configuration for the RoPE embeddings. If `None`, no scaling is applied.
lm_head_multiplier (`float`, *optional*, defaults to 1.0):
The multiplier for the LM head. This is used to scale the output of the LM head.
embedding_multiplier (`float`, *optional*, defaults to 1.0):
The multiplier for the embedding layer. This is used to scale the output of the embedding layer.
mlp_multipliers (`list[float]`, *optional*):
The multipliers for the MLP layers. This is used to scale the output of the MLP layers. The first value is
the multiplier of gate layer, the second value is the multiplier of the down_proj layer.
key_multiplier (`float`, *optional*):
The multiplier for the key layer. This is used to scale the output of the key layer.
attention_out_multiplier (`float`, *optional*):
The multiplier for the attention output layer. This is used to scale the output of the attention output layer.
attention_in_multiplier (`float`, *optional*):
The multiplier for the attention input layer. This is used to scale the output of the attention input layer.
ssm_multipliers (`list[float]`, *optional*):
The multipliers for the SSM layers. This is used to scale the output of the SSM layers.
ssm_in_multiplier (`float`, *optional*):
The multiplier for the SSM input layer. This is used to scale the output of the SSM input layer.
ssm_out_multiplier (`float`, *optional*):
The multiplier for the SSM output layer. This is used to scale the output of the SSM output layer.
"""
model_type = "falcon_h1"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vocab_size=128000,
tie_word_embeddings=False,
hidden_size=4096,
intermediate_size=14336,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=8,
hidden_act="silu",
initializer_range=0.02,
rms_norm_eps=1e-5,
use_cache=True,
num_logits_to_keep=1,
pad_token_id=0,
bos_token_id=1,
eos_token_id=2,
max_position_embeddings=8192,
attention_dropout=0.0,
mamba_d_ssm=1024,
mamba_n_heads=128,
mamba_d_head="auto",
mamba_n_groups=1,
mamba_d_state=256,
mamba_d_conv=4,
mamba_expand=2,
mamba_chunk_size=256,
mamba_conv_bias=True,
mamba_proj_bias=False,
mamba_norm_before_gate=True,
mamba_rms_norm=False,
projectors_bias=False,
rope_theta=100000.0,
rope_scaling=None,
lm_head_multiplier=1.0,
embedding_multiplier=1.0,
mlp_multipliers=None,
key_multiplier=None,
attention_out_multiplier=None,
attention_in_multiplier=None,
ssm_multipliers=None,
ssm_in_multiplier=None,
ssm_out_multiplier=None,
**kwargs,
):
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.max_position_embeddings = max_position_embeddings
self.attention_dropout = attention_dropout
self.attention_bias = False
self.mlp_bias = False
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.num_logits_to_keep = num_logits_to_keep
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.projectors_bias = projectors_bias
mamba_intermediate = (
mamba_expand * hidden_size if mamba_d_ssm is None else mamba_d_ssm
)
if mamba_intermediate % mamba_n_heads != 0:
raise ValueError("mamba_n_heads must divide mamba_expand * hidden_size")
# for the mamba_v2, must satisfy the following
if mamba_d_head == "auto":
mamba_d_head = mamba_intermediate // mamba_n_heads
if mamba_d_head * mamba_n_heads != mamba_intermediate:
raise ValueError(
"The dimensions for the Mamba head state do not match the model intermediate_size"
)
self.mamba_d_ssm = mamba_d_ssm
self.mamba_n_heads = mamba_n_heads
self.mamba_d_head = mamba_d_head
self.mamba_n_groups = mamba_n_groups
self.mamba_d_state = mamba_d_state
self.mamba_d_conv = mamba_d_conv
self.mamba_expand = mamba_expand
self.mamba_chunk_size = mamba_chunk_size
self.mamba_conv_bias = mamba_conv_bias
self.mamba_proj_bias = mamba_proj_bias
self.mamba_norm_before_gate = mamba_norm_before_gate
self.mamba_rms_norm = mamba_rms_norm
self.lm_head_multiplier = lm_head_multiplier
self.embedding_multiplier = embedding_multiplier
if mlp_multipliers is not None:
self.mlp_multipliers = mlp_multipliers
else:
self.mlp_multipliers = [1.0, 1.0]
if attention_out_multiplier is not None:
self.attention_out_multiplier = attention_out_multiplier
else:
self.attention_out_multiplier = 1.0
if attention_in_multiplier is not None:
self.attention_in_multiplier = attention_in_multiplier
else:
self.attention_in_multiplier = 1.0
if key_multiplier is not None:
self.key_multiplier = key_multiplier
else:
self.key_multiplier = 1.0
if ssm_multipliers is not None:
self.ssm_multipliers = ssm_multipliers
else:
self.ssm_multipliers = [1.0, 1.0, 1.0, 1.0, 1.0]
if ssm_in_multiplier is not None:
self.ssm_in_multiplier = ssm_in_multiplier
else:
self.ssm_in_multiplier = 1.0
if ssm_out_multiplier is not None:
self.ssm_out_multiplier = ssm_out_multiplier
else:
self.ssm_out_multiplier = 1.0
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
@property
def layers_block_type(self):
return ["falcon_h1" for i in range(self.num_hidden_layers)]
@property
def mamba_cache_per_req(self):
conv_state_shape, temporal_state_shape, conv_dtype, ssm_dtype, mamba_layers = (
self.hybrid_gdn_params
)
mamba_layers_len = len(mamba_layers)
return (
int(np.prod(conv_state_shape)) * conv_dtype.itemsize
+ int(np.prod(temporal_state_shape)) * ssm_dtype.itemsize
) * mamba_layers_len
@property
def full_attention_layer_ids(self):
# Falcon-H1 applies attention in every layer
return range(self.num_hidden_layers)
@property
def linear_layer_ids(self):
# Falcon-H1 applies Mamba in every layer
return range(self.num_hidden_layers)
@property
def hybrid_gdn_params(self):
world_size = get_tensor_model_parallel_world_size()
n_groups = self.mamba_n_groups
if self.mamba_n_groups % world_size != 0:
# - for TP we shard conv_dim by sharding on n_groups,
# - but if n_groups cannot divide tp_size, we need to
# extend some extra groups
extra_groups = MambaStateShapeCalculator.extra_groups_for_head_shards(
self.mamba_n_groups, world_size
)
n_groups += extra_groups
conv_dim = self.mamba_d_ssm + 2 * n_groups * self.mamba_d_state
conv_state_shape = (
divide(conv_dim, world_size),
self.mamba_d_conv - 1,
)
# we TP-ize on the heads dimension
temporal_state_shape = (
self.mamba_d_state,
self.mamba_d_head,
divide(self.mamba_n_heads, world_size),
)
conv_dtype = torch.bfloat16
dtype_map = {
"float32": torch.float32,
"bfloat16": torch.bfloat16,
}
ssm_dtype = dtype_map[os.environ["SGLANG_MAMBA_SSM_DTYPE"]]
mamba_layers = self.linear_layer_ids
return (
conv_state_shape,
temporal_state_shape,
conv_dtype,
ssm_dtype,
mamba_layers,
)
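Putting the defaults above together gives a quick feel for the state sizes this config implies. The following is an illustrative sketch only (numbers recomputed from the formulas above, a single GPU so the TP world size is 1, and bfloat16 assumed for both the conv and SSM caches):

# Sketch: default FalconH1Config Mamba dimensions and per-request cache size (TP = 1).
mamba_d_ssm, mamba_n_heads = 1024, 128
mamba_n_groups, mamba_d_state, mamba_d_conv = 1, 256, 4
num_hidden_layers = 32

mamba_d_head = mamba_d_ssm // mamba_n_heads                    # "auto" -> 8
conv_dim = mamba_d_ssm + 2 * mamba_n_groups * mamba_d_state    # 1024 + 512 = 1536
conv_state_shape = (conv_dim, mamba_d_conv - 1)                # (1536, 3)
temporal_state_shape = (mamba_d_state, mamba_d_head, mamba_n_heads)  # (256, 8, 128)

itemsize = 2  # bfloat16 assumed for both the conv and SSM states
per_layer = (1536 * 3 + 256 * 8 * 128) * itemsize              # 533,504 bytes
per_request = per_layer * num_hidden_layers                    # ~16.3 MiB of Mamba cache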
......@@ -41,6 +41,7 @@ from sglang.srt.configs import (
DotsOCRConfig,
DotsVLMConfig,
ExaoneConfig,
FalconH1Config,
KimiVLConfig,
LongcatFlashConfig,
MultiModalityConfig,
......@@ -62,6 +63,7 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
Step3VLConfig.model_type: Step3VLConfig,
LongcatFlashConfig.model_type: LongcatFlashConfig,
Qwen3NextConfig.model_type: Qwen3NextConfig,
FalconH1Config.model_type: FalconH1Config,
DotsVLMConfig.model_type: DotsVLMConfig,
DotsOCRConfig.model_type: DotsOCRConfig,
}
......
......@@ -69,6 +69,7 @@ class MambaAttnBackend(AttentionBackend):
def init_forward_metadata(self, forward_batch: ForwardBatch):
bs = forward_batch.batch_size
if forward_batch.forward_mode.is_decode_or_idle():
query_start_loc = torch.arange(
0, bs + 1, dtype=torch.int32, device=self.device
......
# Adapted from: https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/mamba/mamba_utils.py
from sglang.srt.distributed.utils import divide
class MambaStateShapeCalculator:
@classmethod
def linear_attention_state_shape(
cls,
num_heads: int,
tp_size: int,
head_dim: int,
) -> tuple[tuple[int, int, int], ...]:
state_shape = (num_heads // tp_size, head_dim, head_dim)
return (state_shape,)
@classmethod
def mamba1_state_shape(
cls,
tp_world_size: int,
intermediate_size: int,
state_size: int,
conv_kernel: int,
) -> tuple[tuple[int, int], tuple[int, int]]:
conv_state_shape = (divide(intermediate_size, tp_world_size), conv_kernel - 1)
temporal_state_shape = (divide(intermediate_size, tp_world_size), state_size)
conv_state_shape = conv_state_shape[1], conv_state_shape[0]
return conv_state_shape, temporal_state_shape
@classmethod
def mamba2_state_shape(
cls,
tp_world_size: int,
intermediate_size: int,
n_groups: int,
num_heads: int,
head_dim: int,
state_size: int,
conv_kernel: int,
) -> tuple[tuple[int, int], tuple[int, int, int]]:
# if n_groups is not divisible by world_size, need to extend the shards
# to ensure all groups needed by a head is sharded along with it
n_groups = n_groups + cls.extra_groups_for_head_shards(n_groups, tp_world_size)
# heads and n_groups are TP-ed
conv_dim = intermediate_size + 2 * n_groups * state_size
# contiguous along 'dim' axis
conv_state_shape = (conv_kernel - 1, divide(conv_dim, tp_world_size))
# These are not TP-ed as they depend on A, dt_bias, D
# - they are typically small
# e.g., (h_heads, head_dim, state_size) = (128, 64, 128)
temporal_state_shape = (divide(num_heads, tp_world_size), head_dim, state_size)
return conv_state_shape, temporal_state_shape
@classmethod
def short_conv_state_shape(
cls,
tp_world_size: int,
intermediate_size: int,
conv_kernel: int,
) -> tuple[tuple[int, int]]:
conv_dim = divide(intermediate_size, tp_world_size)
conv_state_shape = (conv_kernel - 1, conv_dim)
return (conv_state_shape,)
@classmethod
def extra_groups_for_head_shards(cls, ngroups: int, tp_size: int):
"""Compute the increase in group numbers to account for
replication in order to accompany the head shards."""
# in the case ngroups % tp_size == 0, this will be zero
if ngroups % tp_size == 0:
return 0
# for n_groups == 1, this is exactly tp_size - n_groups
return tp_size - ngroups
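As a worked example of the group-padding rule above (illustrative numbers, not taken from this diff): with a single group and TP size 2, one extra group is replicated so every head shard carries its group, after which both states shard evenly.

# Sketch: expected shapes for one made-up configuration.
conv_shape, ssm_shape = MambaStateShapeCalculator.mamba2_state_shape(
    tp_world_size=2,
    intermediate_size=8192,
    n_groups=1,
    num_heads=128,
    head_dim=64,
    state_size=128,
    conv_kernel=4,
)
# conv_shape == (3, 4352)      # (conv_kernel - 1, (8192 + 2 * 2 * 128) // 2)
# ssm_shape == (64, 64, 128)   # (num_heads // 2, head_dim, state_size)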
from .mamba_ssm import selective_state_update
from .ssd_combined import mamba_chunk_scan_combined
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright (c) 2024, Tri Dao.
# Adapted from https://github.com/state-spaces/mamba/blob/60dadf2e0ee730ac337035d5533de10bc26e4847/mamba_ssm/ops/triton/layernorm_gated.py
import torch
import triton
import triton.language as tl
@triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
@triton.heuristics({"HAS_Z": lambda args: args["Z"] is not None})
@triton.jit
def _layer_norm_fwd_1pass_kernel(
X, # pointer to the input
Y, # pointer to the output
W, # pointer to the weights
B, # pointer to the biases
Z, # pointer to the other branch
Mean, # pointer to the mean
Rstd, # pointer to the 1/std
stride_x_row: tl.int64,
stride_y_row: tl.int64,
stride_z_row: tl.int64,
M: tl.int64, # number of rows in X
N: tl.int64, # number of columns in X
eps, # epsilon to avoid division by zero
BLOCK_N: tl.constexpr,
HAS_BIAS: tl.constexpr,
HAS_Z: tl.constexpr,
NORM_BEFORE_GATE: tl.constexpr,
IS_RMS_NORM: tl.constexpr,
):
# Map the program id to the row of X and Y it should compute.
row = tl.program_id(0)
group = tl.program_id(1)
X += row * stride_x_row + group * N
Y += row * stride_y_row + group * N
if HAS_Z:
Z += row * stride_z_row + group * N
if not IS_RMS_NORM:
Mean += group * M
Rstd += group * M
W += group * N
if HAS_BIAS:
B += group * N
# Compute mean and variance
cols = tl.arange(0, BLOCK_N)
x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
if HAS_Z and not NORM_BEFORE_GATE:
z = tl.load(Z + cols, mask=cols < N).to(tl.float32)
x *= z * tl.sigmoid(z)
if not IS_RMS_NORM:
mean = tl.sum(x, axis=0) / N
tl.store(Mean + row, mean)
xbar = tl.where(cols < N, x - mean, 0.0)
var = tl.sum(xbar * xbar, axis=0) / N
else:
xbar = tl.where(cols < N, x, 0.0)
var = tl.sum(xbar * xbar, axis=0) / N
rstd = 1 / tl.sqrt(var + eps)
tl.store(Rstd + row, rstd)
# Normalize and apply linear transformation
mask = cols < N
w = tl.load(W + cols, mask=mask).to(tl.float32)
if HAS_BIAS:
b = tl.load(B + cols, mask=mask).to(tl.float32)
x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
y = x_hat * w + b if HAS_BIAS else x_hat * w
if HAS_Z and NORM_BEFORE_GATE:
z = tl.load(Z + cols, mask=mask).to(tl.float32)
y *= z * tl.sigmoid(z)
# Write output
tl.store(Y + cols, y, mask=mask)
def _layer_norm_fwd(
x,
weight,
bias,
eps,
z=None,
out=None,
group_size=None,
norm_before_gate=True,
is_rms_norm=False,
):
M, N = x.shape
if group_size is None:
group_size = N
assert N % group_size == 0
ngroups = N // group_size
assert x.stride(-1) == 1
if z is not None:
assert z.stride(-1) == 1
assert z.shape == (M, N)
assert weight.shape == (N,)
assert weight.stride(-1) == 1
if bias is not None:
assert bias.stride(-1) == 1
assert bias.shape == (N,)
# allocate output
if out is not None:
assert out.shape == x.shape
else:
out = torch.empty_like(x)
assert out.stride(-1) == 1
mean = (
torch.empty((ngroups * M,), dtype=torch.float32, device=x.device)
if not is_rms_norm
else None
)
rstd = torch.empty((ngroups * M,), dtype=torch.float32, device=x.device)
# Less than 64KB per feature: enqueue fused kernel
MAX_FUSED_SIZE = 65536 // x.element_size()
BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size))
if group_size > BLOCK_N:
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
# heuristics for number of warps
num_warps = min(max(BLOCK_N // 256, 1), 8)
grid = (M, ngroups)
with torch.cuda.device(x.device.index):
_layer_norm_fwd_1pass_kernel[grid](
x,
out,
weight,
bias,
z,
mean,
rstd,
x.stride(0),
out.stride(0),
z.stride(0) if z is not None else 0,
M,
group_size,
eps,
BLOCK_N=BLOCK_N,
NORM_BEFORE_GATE=norm_before_gate,
IS_RMS_NORM=is_rms_norm,
num_warps=num_warps,
)
return out, mean, rstd
def rms_norm_gated(
x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True
):
x_shape_og = x.shape
# reshape input data into 2D tensor
x = x.reshape(-1, x.shape[-1])
if x.stride(-1) != 1:
x = x.contiguous()
if z is not None:
assert z.shape == x_shape_og
z = z.reshape(-1, z.shape[-1])
if z.stride(-1) != 1:
z = z.contiguous()
weight = weight.contiguous()
if bias is not None:
bias = bias.contiguous()
y, _, _ = _layer_norm_fwd(
x,
weight,
bias,
eps,
z=z,
group_size=group_size,
norm_before_gate=norm_before_gate,
is_rms_norm=True,
)
return y.reshape(x_shape_og)
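For intuition only: in the `group_size=None` case, `rms_norm_gated` above matches the eager sketch below up to accumulation error (`rms_norm_gated_ref` is a hypothetical reference helper, not part of this file; a CUDA device is assumed for the Triton path).

import torch
import torch.nn.functional as F


def rms_norm_gated_ref(x, weight, bias, z=None, eps=1e-6, norm_before_gate=True):
    # Eager reference for the single-group case of the Triton kernel above.
    x = x.float()
    if z is not None and not norm_before_gate:
        x = x * F.silu(z.float())  # gate applied before the norm
    y = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps) * weight.float()
    if bias is not None:
        y = y + bias.float()
    if z is not None and norm_before_gate:
        y = y * F.silu(z.float())  # gate applied after the norm
    return y


# x = torch.randn(4, 512, device="cuda", dtype=torch.bfloat16)
# z = torch.randn_like(x)
# w = torch.ones(512, device="cuda", dtype=torch.bfloat16)
# torch.testing.assert_close(rms_norm_gated(x, w, None, z=z).float(),
#                            rms_norm_gated_ref(x, w, None, z=z), atol=2e-2, rtol=2e-2)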
# Adapted from: https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/mamba/ops/mamba_ssm.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright (c) 2024, Tri Dao, Albert Gu.
# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/selective_state_update.py
import torch
import triton
import triton.language as tl
from packaging import version
from sglang.srt import _custom_ops as ops
PAD_SLOT_ID = -1
TRITON3 = version.parse(triton.__version__) >= version.parse("3.0.0")
if TRITON3:
@triton.jit
def softplus(dt):
dt = tl.where(dt <= 20.0, tl.math.log(tl.math.exp(dt) + 1), dt)
return dt
else:
@triton.jit
def softplus(dt):
dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt)
return dt
@triton.heuristics({"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None})
@triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None})
@triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None})
@triton.heuristics(
{
"HAS_STATE_BATCH_INDICES": lambda args: args["state_batch_indices_ptr"]
is not None
}
)
@triton.heuristics(
{"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])}
)
@triton.jit
def _selective_scan_update_kernel(
# Pointers to matrices
state_ptr,
x_ptr,
dt_ptr,
dt_bias_ptr,
A_ptr,
B_ptr,
C_ptr,
D_ptr,
z_ptr,
out_ptr,
state_batch_indices_ptr,
pad_slot_id,
# Matrix dimensions
batch,
nheads,
dim,
dstate,
nheads_ngroups_ratio,
# Strides
stride_state_batch,
stride_state_head,
stride_state_dim,
stride_state_dstate,
stride_x_batch,
stride_x_head,
stride_x_dim,
stride_dt_batch,
stride_dt_head,
stride_dt_dim,
stride_dt_bias_head,
stride_dt_bias_dim,
stride_A_head,
stride_A_dim,
stride_A_dstate,
stride_B_batch,
stride_B_group,
stride_B_dstate,
stride_C_batch,
stride_C_group,
stride_C_dstate,
stride_D_head,
stride_D_dim,
stride_z_batch,
stride_z_head,
stride_z_dim,
stride_out_batch,
stride_out_head,
stride_out_dim,
# Meta-parameters
DT_SOFTPLUS: tl.constexpr,
TIE_HDIM: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
HAS_DT_BIAS: tl.constexpr,
HAS_D: tl.constexpr,
HAS_Z: tl.constexpr,
HAS_STATE_BATCH_INDICES: tl.constexpr,
BLOCK_SIZE_DSTATE: tl.constexpr,
):
pid_m = tl.program_id(axis=0)
pid_b = tl.program_id(axis=1)
pid_h = tl.program_id(axis=2)
# If HAS_STATE_BATCH_INDICES is true, then the ssm state's batch coordinate
# is taken from the state_batch_indices_ptr. Otherwise, the state coordinate
# is the same as the batch id.
if HAS_STATE_BATCH_INDICES:
state_batch_indices_ptr += pid_b
state_batch_idx = tl.load(state_batch_indices_ptr).to(tl.int64)
state_ptr += state_batch_idx * stride_state_batch + pid_h * stride_state_head
else:
state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head
x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head
dt_ptr += pid_b * stride_dt_batch + pid_h * stride_dt_head
if HAS_DT_BIAS:
dt_bias_ptr += pid_h * stride_dt_bias_head
A_ptr += pid_h * stride_A_head
B_ptr += pid_b * stride_B_batch + (pid_h // nheads_ngroups_ratio) * stride_B_group
C_ptr += pid_b * stride_C_batch + (pid_h // nheads_ngroups_ratio) * stride_C_group
if HAS_Z:
z_ptr += pid_b * stride_z_batch + pid_h * stride_z_head
out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_n = tl.arange(0, BLOCK_SIZE_DSTATE)
state_ptrs = state_ptr + (
offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate
)
x_ptrs = x_ptr + offs_m * stride_x_dim
dt_ptrs = dt_ptr + offs_m * stride_dt_dim
if HAS_DT_BIAS:
dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim
if HAS_D:
D_ptr += pid_h * stride_D_head
A_ptrs = A_ptr + (
offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate
)
B_ptrs = B_ptr + offs_n * stride_B_dstate
C_ptrs = C_ptr + offs_n * stride_C_dstate
if HAS_D:
D_ptrs = D_ptr + offs_m * stride_D_dim
if HAS_Z:
z_ptrs = z_ptr + offs_m * stride_z_dim
out_ptrs = out_ptr + offs_m * stride_out_dim
mask = (offs_m[:, None] < dim) & (offs_n[None, :] < dstate)
if HAS_STATE_BATCH_INDICES:
mask &= state_batch_idx != pad_slot_id
state = tl.load(state_ptrs, mask=mask, other=0.0)
x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
if not TIE_HDIM:
dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
if HAS_DT_BIAS:
dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
if DT_SOFTPLUS:
dt = softplus(dt)
A = tl.load(
A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0
).to(tl.float32)
dA = tl.exp(A * dt[:, None])
else:
dt = tl.load(dt_ptr).to(tl.float32)
if HAS_DT_BIAS:
dt += tl.load(dt_bias_ptr).to(tl.float32)
if DT_SOFTPLUS:
dt = softplus(dt)
A = tl.load(A_ptr).to(tl.float32)
dA = tl.exp(A * dt) # scalar, not a matrix
B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
if HAS_D:
D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
if HAS_Z:
z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
dB = B[None, :] * dt[:, None] if not TIE_HDIM else B * dt
state = state * dA + dB * x[:, None]
mask = (offs_m[:, None] < dim) & (offs_n[None, :] < dstate)
if HAS_STATE_BATCH_INDICES:
mask &= state_batch_idx != pad_slot_id
tl.store(state_ptrs, state, mask=mask)
out = tl.sum(state * C[None, :], axis=1)
if HAS_D:
out += x * D
if HAS_Z:
out *= z * tl.sigmoid(z)
tl.store(out_ptrs, out, mask=offs_m < dim)
def selective_state_update(
state,
x,
dt,
A,
B,
C,
D=None,
z=None,
dt_bias=None,
dt_softplus=False,
state_batch_indices=None,
pad_slot_id=PAD_SLOT_ID,
out=None,
):
"""
Argument:
state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
x: (batch, dim) or (batch, nheads, dim)
dt: (batch, dim) or (batch, nheads, dim)
A: (dim, dstate) or (nheads, dim, dstate)
B: (batch, dstate) or (batch, ngroups, dstate)
C: (batch, dstate) or (batch, ngroups, dstate)
D: (dim,) or (nheads, dim)
z: (batch, dim) or (batch, nheads, dim)
dt_bias: (dim,) or (nheads, dim)
pad_slot_id: int
if state_batch_indices is passed, lets the kernel identify padded
entries that will not be processed,
for example: state_batch_indices = [pad_slot_id, 1, 20, pad_slot_id]
in this case, the kernel will not process entries at
indices 0 and 3
out: Preallocated ssm output tensor. Assume same shape as x.
In-place updated.
"""
if state.dim() == 3:
state = state.unsqueeze(1)
if x.dim() == 2:
x = x.unsqueeze(1)
if dt.dim() == 2:
dt = dt.unsqueeze(1)
if A.dim() == 2:
A = A.unsqueeze(0)
if B.dim() == 2:
B = B.unsqueeze(1)
if C.dim() == 2:
C = C.unsqueeze(1)
if D is not None and D.dim() == 1:
D = D.unsqueeze(0)
if z is not None and z.dim() == 2:
z = z.unsqueeze(1)
if dt_bias is not None and dt_bias.dim() == 1:
dt_bias = dt_bias.unsqueeze(0)
if out.dim() == 2:
out = out.unsqueeze(1)
_, nheads, dim, dstate = state.shape
batch = x.shape[0]
assert x.shape == (batch, nheads, dim)
assert dt.shape == x.shape
assert A.shape == (nheads, dim, dstate)
ngroups = B.shape[1]
assert nheads % ngroups == 0, "nheads must be divisible by ngroups"
assert B.shape == (batch, ngroups, dstate)
assert C.shape == B.shape
if D is not None:
assert D.shape == (nheads, dim)
if z is not None:
assert z.shape == x.shape
if dt_bias is not None:
assert dt_bias.shape == (nheads, dim)
if state_batch_indices is not None:
assert state_batch_indices.shape == (batch,)
assert out.shape == x.shape
grid = lambda META: (triton.cdiv(dim, META["BLOCK_SIZE_M"]), batch, nheads)
z_strides = (z.stride(0), z.stride(1), z.stride(2)) if z is not None else (0, 0, 0)
# We don't want autotune since it will overwrite the state
# We instead tune by hand.
BLOCK_SIZE_M, num_warps = (
(32, 4)
if dstate <= 16
else (
(16, 4)
if dstate <= 32
else ((8, 4) if dstate <= 64 else ((4, 4) if dstate <= 128 else ((4, 8))))
)
)
tie_hdim = (
A.stride(-1) == 0
and A.stride(-2) == 0
and dt.stride(-1) == 0
and dt_bias.stride(-1) == 0
)
with torch.cuda.device(x.device.index):
_selective_scan_update_kernel[grid](
state,
x,
dt,
dt_bias,
A,
B,
C,
D,
z,
out,
state_batch_indices,
pad_slot_id,
batch,
nheads,
dim,
dstate,
nheads // ngroups,
state.stride(0),
state.stride(1),
state.stride(2),
state.stride(3),
x.stride(0),
x.stride(1),
x.stride(2),
dt.stride(0),
dt.stride(1),
dt.stride(2),
*(dt_bias.stride(0), dt_bias.stride(1)) if dt_bias is not None else (0, 0),
A.stride(0),
A.stride(1),
A.stride(2),
B.stride(0),
B.stride(1),
B.stride(2),
C.stride(0),
C.stride(1),
C.stride(2),
*(D.stride(0), D.stride(1)) if D is not None else (0, 0),
z_strides[0],
z_strides[1],
z_strides[2],
out.stride(0),
out.stride(1),
out.stride(2),
dt_softplus,
tie_hdim,
BLOCK_SIZE_M,
num_warps=num_warps,
)
def selective_scan_fn(
u,
ssm_states,
delta,
A,
B,
C,
D=None,
z=None,
delta_bias=None,
delta_softplus=False,
query_start_loc=None,
cache_indices=None,
has_initial_state=None,
pad_slot_id=PAD_SLOT_ID,
) -> torch.Tensor:
"""
u: (dim, total_length) for varlen or (batch, dim, seqlen)
applies changes in place.
ssm_states: (batch, dim, dstate) or (batch, nheads, dim, dstate)
applies changes in place.
delta: (dim, total_length) for varlen or (batch, dim, seqlen)
A: (dim, dstate)
B: (ngroups, dstate, total_length) for varlen or
(batch,ngroups,dstate,seqlen)
C: (ngroups, dstate, total_length) for varlen or
(batch,ngroups,dstate,seqlen)
D: (dim,)
z: (dim, total_length) for varlen or (batch, dim, seqlen)
delta_bias: (dim,)
query_start_loc: (batch + 1) int32
The cumulative sequence lengths of the sequences in
the batch, used to index into sequence. prepended with 0.
for example: query_start_loc = torch.Tensor([0,10,16,17]),
x.shape=(dim,17)
cache_indices: (batch) int32
A tensor with each cell is a correspondent
input and output ssm_state index
has_initial_state: (batch) bool
A tensor populated with ones and zeros,
indicate if the ssm_state at the corresponding index should be
used as initial state. Not providing argument assumes
there's no initial state
pad_slot_id: int
if cache_indices is passed, lets the kernel identify padding entries
that will not be processed,
for example: cache_indices = [pad_slot_id, 1 ,20 ,pad_slot_id]
in this case, the kernel will not process entries at indices 0 and 3
returns
output: (dim, total_length) for varlen or (batch, dim, seqlen)
supports inplace replacement
"""
if u.stride(-1) != 1:
u = u.contiguous()
if delta.stride(-1) != 1:
delta = delta.contiguous()
if D is not None:
D = D.contiguous()
if B.stride(-1) != 1:
B = B.contiguous()
if C.stride(-1) != 1:
C = C.contiguous()
if z is not None and z.stride(-1) != 1:
z = z.contiguous()
if B.dim() == 3 and query_start_loc is None:
B = B.unsqueeze(1)
if B.dim() == 2 and query_start_loc is not None:
B = B.unsqueeze(0)
if C.dim() == 3 and query_start_loc is None:
C = C.unsqueeze(1)
if C.dim() == 2 and query_start_loc is not None:
C = C.unsqueeze(0)
ops.selective_scan_fwd(
u,
delta,
A,
B,
C,
D,
z,
delta_bias,
delta_softplus,
query_start_loc,
cache_indices,
has_initial_state,
ssm_states,
pad_slot_id,
)
if z is None:
return delta # output written inplace to delta
else:
return z # output written inplace to z
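To make `selective_state_update`'s shape contract concrete: per head and channel it computes `dt = softplus(dt + dt_bias)` when `dt_softplus=True`, then `state = state * exp(A * dt) + (B * dt) * x` and `y = (state * C).sum(-1) + D * x`, gated by `z * sigmoid(z)`. A minimal single-step decode sketch with made-up sizes (CUDA assumed):

import torch

batch, nheads, dim, dstate = 2, 4, 8, 16
dev = "cuda"
state = torch.zeros(batch, nheads, dim, dstate, device=dev)
x = torch.randn(batch, nheads, dim, device=dev)
dt = torch.rand(batch, nheads, dim, device=dev)
A = -torch.rand(nheads, dim, dstate, device=dev)      # negative so exp(A * dt) decays
B = torch.randn(batch, 1, dstate, device=dev)         # ngroups = 1
C = torch.randn(batch, 1, dstate, device=dev)
D = torch.randn(nheads, dim, device=dev)
z = torch.randn(batch, nheads, dim, device=dev)
dt_bias = torch.zeros(nheads, dim, device=dev)
out = torch.empty_like(x)

# Updates `state` in place and writes the gated output into `out`.
selective_state_update(state, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias,
                       dt_softplus=True, out=out)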
# Adapted from: https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/mamba/ops/ssd_bmm.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright (c) 2024, Tri Dao, Albert Gu.
# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_bmm.py
# ruff: noqa: E501,SIM102
import math
import torch
import triton
import triton.language as tl
# @triton.autotune(
# configs=[
# triton.Config(
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
# num_stages=3,
# num_warps=8,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=4,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=4,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=4,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=4,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=4,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
# num_stages=5,
# num_warps=2,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
# num_stages=5,
# num_warps=2,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=2,
# ),
# ],
# key=["chunk_size", "K", "IS_CAUSAL"],
# )
@triton.jit
def _bmm_chunk_fwd_kernel(
# Pointers to matrices
a_ptr,
b_ptr,
out_ptr,
seq_idx_ptr,
# Matrix dimensions
seqlen,
chunk_size,
K,
ngroups,
stride_a_batch,
stride_a_seqlen,
stride_a_head,
stride_ak,
stride_b_batch,
stride_b_seqlen,
stride_b_head,
stride_bk,
stride_out_batch,
stride_out_chunk,
stride_out_head,
stride_outm,
stride_outn,
stride_seq_idx_batch,
stride_seq_idx_seqlen,
# Meta-parameters
IS_CAUSAL: tl.constexpr,
dot_dtype: tl.constexpr,
HAS_SEQ_IDX: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr = 16,
BLOCK_SIZE_N: tl.constexpr = 16,
BLOCK_SIZE_K: tl.constexpr = 16,
):
pid_b = tl.program_id(axis=1)
pid_ch = tl.program_id(axis=2).to(tl.int64)
pid_c = pid_ch // ngroups
pid_h = pid_ch - pid_c * ngroups
num_pid_n = tl.cdiv(chunk_size, BLOCK_SIZE_N)
pid_m = tl.program_id(axis=0) // num_pid_n
pid_n = tl.program_id(axis=0) % num_pid_n
if IS_CAUSAL:
if pid_n * BLOCK_SIZE_N >= (pid_m + 1) * BLOCK_SIZE_M:
return
a_ptr += (
pid_b * stride_a_batch
+ pid_c * chunk_size * stride_a_seqlen
+ pid_h * stride_a_head
)
b_ptr += (
pid_b * stride_b_batch
+ pid_c * chunk_size * stride_b_seqlen
+ pid_h * stride_b_head
)
if HAS_SEQ_IDX:
seq_idx_ptr += (
pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
)
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (offs_m[:, None] * stride_a_seqlen + offs_k[None, :] * stride_ak)
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_b_seqlen)
chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
a = tl.load(
a_ptrs,
mask=(offs_m[:, None] < chunk_size_limit)
& (offs_k[None, :] < K - k * BLOCK_SIZE_K),
other=0.0,
).to(dot_dtype)
b = tl.load(
b_ptrs,
mask=(offs_k[:, None] < K - k * BLOCK_SIZE_K)
& (offs_n[None, :] < chunk_size_limit),
other=0.0,
).to(dot_dtype)
acc += tl.dot(a, b)
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
if HAS_SEQ_IDX:
chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
seq_idx_m = tl.load(
seq_idx_ptr + offs_m * stride_seq_idx_seqlen,
mask=offs_m < chunk_size_limit,
other=-1,
)
seq_idx_n = tl.load(
seq_idx_ptr + offs_n * stride_seq_idx_seqlen,
mask=offs_n < chunk_size_limit,
other=-2,
)
acc = tl.where(seq_idx_m[:, None] == seq_idx_n[None, :], acc, 0.0)
out = acc.to(out_ptr.dtype.element_ty)
out_ptr += (
pid_b * stride_out_batch + pid_c * stride_out_chunk + pid_h * stride_out_head
)
out_ptrs = out_ptr + (stride_outm * offs_m[:, None] + offs_n[None, :] * stride_outn)
tl.store(
out_ptrs,
out,
mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size),
)
def _bmm_chunk_fwd(a, b, chunk_size, seq_idx=None, causal=False, output_dtype=None):
"""
Argument:
a: (batch, seqlen, k) or (batch, seqlen, ngroups, k)
b: (batch, seqlen, k) or (batch, seqlen, ngroups, k)
seq_idx: (batch, seqlen) or None. out[i, j] for seq_idx[i] != seq_idx[j] will be zeroed out.
causal: if True, then out[i, j] for i > j will be arbitrary, only out[i, j] for i <= j are
guaranteed to be correct.
Return:
out: (batch, nchunks, chunk_size, chunk_size) or (batch, nchunks, ngroups, chunk_size, chunk_size)
"""
# Check constraints.
has_groups = a.dim() == 4
if not has_groups:
batch, seqlen, k = a.shape
else:
batch, seqlen, ngroups, k = a.shape
assert b.shape == a.shape
if seq_idx is not None:
assert seq_idx.shape == (batch, seqlen)
if a.stride(-1) != 1 and a.stride(1) != 1:
a = a.contiguous()
if b.stride(-1) != 1 and b.stride(1) != 1:
b = b.contiguous()
nchunks = math.ceil(seqlen / chunk_size)
# Allocates output.
out_dtype = a.dtype if output_dtype is None else output_dtype
out = torch.empty(
(
(batch, nchunks, chunk_size, chunk_size)
if not has_groups
else (batch, nchunks, ngroups, chunk_size, chunk_size)
),
device=a.device,
dtype=out_dtype,
)
dot_dtype = (
tl.bfloat16
if a.dtype == torch.bfloat16 or b.dtype == torch.bfloat16
else (
tl.float16
if a.dtype == torch.float16 or b.dtype == torch.float16
else tl.float32
)
)
grid = lambda META: (
triton.cdiv(chunk_size, META["BLOCK_SIZE_M"])
* triton.cdiv(chunk_size, META["BLOCK_SIZE_N"]),
batch,
nchunks if not has_groups else nchunks * ngroups,
)
with torch.cuda.device(a.device.index):
_bmm_chunk_fwd_kernel[grid](
a,
b,
out,
seq_idx,
seqlen,
chunk_size,
k,
ngroups if has_groups else 1,
a.stride(0),
a.stride(1),
0 if not has_groups else a.stride(2),
a.stride(-1),
b.stride(0),
b.stride(1),
0 if not has_groups else b.stride(2),
b.stride(-1),
out.stride(0),
out.stride(1),
0 if not has_groups else out.stride(2),
out.stride(-2),
out.stride(-1),
*(
(seq_idx.stride(0), seq_idx.stride(1))
if seq_idx is not None
else (0, 0)
),
causal,
dot_dtype,
HAS_SEQ_IDX=seq_idx is not None,
)
return out
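In eager terms, for the no-group, no-`seq_idx` case, `_bmm_chunk_fwd(a, b, chunk_size)` returns one `A_chunk @ B_chunk^T` matrix per chunk, with the tail chunk zero-padded. A rough sketch of that equivalence (`bmm_chunk_ref` is a hypothetical reference helper; a loose tolerance is used since `tl.dot` may accumulate in TF32):

import math

import torch
import torch.nn.functional as F


def bmm_chunk_ref(a, b, chunk_size):
    # Eager reference: per-chunk a @ b^T with the last chunk zero-padded.
    batch, seqlen, k = a.shape
    nchunks = math.ceil(seqlen / chunk_size)
    pad = nchunks * chunk_size - seqlen
    a = F.pad(a, (0, 0, 0, pad)).reshape(batch, nchunks, chunk_size, k)
    b = F.pad(b, (0, 0, 0, pad)).reshape(batch, nchunks, chunk_size, k)
    return torch.einsum("bcmk,bcnk->bcmn", a, b)


# a = torch.randn(1, 300, 64, device="cuda")
# b = torch.randn_like(a)
# torch.testing.assert_close(_bmm_chunk_fwd(a, b, 128), bmm_chunk_ref(a, b, 128),
#                            atol=1e-2, rtol=1e-2)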
This diff is collapsed.
# Adapted from: https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/mamba/ops/ssd_combined.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright (c) 2024, Tri Dao, Albert Gu.
# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_combined.py
# ruff: noqa: E501
import torch
import triton
import triton.language as tl
from einops import rearrange
from packaging import version
from .ssd_bmm import _bmm_chunk_fwd
from .ssd_chunk_scan import _chunk_scan_fwd
from .ssd_chunk_state import _chunk_cumsum_fwd, _chunk_state_fwd, chunk_state_varlen
from .ssd_state_passing import _state_passing_fwd
TRITON_22 = version.parse(triton.__version__) >= version.parse("2.2.0")
def is_int_pow_2(n):
return isinstance(n, int) and n > 0 and (n & (n - 1)) == 0
def _mamba_chunk_scan_combined_fwd(
x,
dt,
A,
B,
C,
chunk_size,
D=None,
z=None,
dt_bias=None,
initial_states=None,
seq_idx=None,
chunk_indices=None,
chunk_offsets=None,
cu_seqlens=None,
dt_softplus=False,
dt_limit=(0.0, float("inf")),
state_dtype=None,
out=None,
):
assert is_int_pow_2(chunk_size), "chunk_size must be integer power of 2"
batch, seqlen, nheads, headdim = x.shape
_, _, ngroups, dstate = B.shape
assert nheads % ngroups == 0
assert B.shape == (batch, seqlen, ngroups, dstate)
assert dt.shape == (batch, seqlen, nheads)
assert A.shape == (nheads,)
assert C.shape == B.shape
if z is not None:
assert z.shape == x.shape
if D is not None:
assert D.shape == (nheads, headdim) or D.shape == (nheads,)
if seq_idx is not None:
assert seq_idx.shape == (batch, seqlen)
if B.stride(-1) != 1:
B = B.contiguous()
if C.stride(-1) != 1:
C = C.contiguous()
if (
x.stride(-1) != 1 and x.stride(1) != 1
): # Either M or K dimension should be contiguous
x = x.contiguous()
if (
z is not None and z.stride(-1) != 1 and z.stride(1) != 1
): # Either M or K dimension should be contiguous
z = z.contiguous()
if D is not None and D.stride(-1) != 1:
D = D.contiguous()
if initial_states is not None:
if cu_seqlens is None:
assert initial_states.shape == (batch, nheads, headdim, dstate)
else:
assert initial_states.shape == (
len(cu_seqlens) - 1,
nheads,
headdim,
dstate,
)
# This function executes 5 sub-functions for computing mamba
# - a good resource is the blog https://goombalab.github.io/blog/2024/mamba2-part3-algorithm/
# which has a minimal implementation to understand the below operations
# - as explained by the blog, mamba is a special case of causal attention
# - the idea is to chunk the attention matrix and compute each
# submatrix separately using different optimizations.
# - see the blog and paper for a visualization of the submatrices
# which we refer to in the comments below
# 1. Compute chunked cumsum of A * dt
# - here dt may go through a softplus activation
dA_cumsum, dt = _chunk_cumsum_fwd(
dt, A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus, dt_limit=dt_limit
)
# 2. Compute the state for each intra-chunk
# (right term of low-rank factorization of off-diagonal blocks; B terms)
states = _chunk_state_fwd(B, x, dt, dA_cumsum, seq_idx=seq_idx, states_in_fp32=True)
# 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
# (middle term of factorization of off-diag blocks; A terms)
# - for handling chunked prefill, this requires i) initial_states
# ii) seq_idx iii) is_cont_batched and (iv) chunk_offsets to be all specified.
# - When a new seq_idx is detected, we will stop passing the prev_state
# and switch accordingly to the init_state corresponding to the new seq_idx.
# - We will also make sure that the dA_cumsum is taken only from the start of the
# sequence (hence we need the full dA_cumsum tensor and not just the values at chunk boundaries)
# - this will ensure that states will be updated with the rightmost flushed seq_idx
# of the previous chunk. This implies that the first chunk of states is either 0
# or equal to init_states of the first example.
states, final_states = _state_passing_fwd(
rearrange(states, "... p n -> ... (p n)"),
dA_cumsum,
initial_states=(
rearrange(initial_states, "... p n -> ... (p n)")
if initial_states is not None
else None
),
seq_idx=seq_idx,
chunk_size=chunk_size,
out_dtype=state_dtype if state_dtype is not None else C.dtype,
is_cont_batched=cu_seqlens is not None,
chunk_offsets=chunk_offsets,
)
states, final_states = (
rearrange(t, "... (p n) -> ... p n", n=dstate) for t in [states, final_states]
)
# 4. Compute batched matrix multiply for C_j^T B_i terms
CB = _bmm_chunk_fwd(C, B, chunk_size, seq_idx=seq_idx, output_dtype=torch.float32)
# 5. Scan and compute the diagonal blocks, taking into
# account past causal states.
# - if initial states are provided, then states information will be
# augmented with initial_states.
# - to do this properly, we need to account for example changes in
# the continuous batch, therefore we introduce pseudo chunks, which is
# a chunk that is split up each time an example changes.
# - in each (pseudo) chunk, we detect if the previous (pseudo) chunk had
# a seq_idx change, in which case we take states information from
# init_states.
out_x = _chunk_scan_fwd(
CB,
x,
dt,
dA_cumsum,
C,
states,
D=D,
z=z,
seq_idx=seq_idx,
chunk_indices=chunk_indices,
chunk_offsets=chunk_offsets,
initial_states=initial_states,
out=out,
)
if cu_seqlens is None:
return out_x, dt, dA_cumsum, states, final_states
else:
assert (
batch == 1
), "passing cu_seqlens to get the varlen states is only supported if batch dimension is 1"
varlen_states = chunk_state_varlen(
B.squeeze(0),
x.squeeze(0),
dt.squeeze(0),
dA_cumsum.squeeze(0),
cu_seqlens,
states.squeeze(0),
initial_states=initial_states,
)
return out_x, dt, dA_cumsum, states, final_states, varlen_states
def mamba_chunk_scan_combined(
x,
dt,
A,
B,
C,
chunk_size,
D=None,
z=None,
dt_bias=None,
initial_states=None,
seq_idx=None,
chunk_indices=None,
chunk_offsets=None,
cu_seqlens=None,
dt_softplus=False,
dt_limit=(0.0, float("inf")),
out=None,
return_final_states=False,
return_varlen_states=False,
state_dtype=None,
):
"""
Argument:
x: (batch, seqlen, nheads, headdim)
dt: (batch, seqlen, nheads)
A: (nheads)
B: (batch, seqlen, ngroups, dstate)
C: (batch, seqlen, ngroups, dstate)
chunk_size: int
D: (nheads, headdim) or (nheads,)
z: (batch, seqlen, nheads, headdim)
dt_bias: (nheads,)
initial_states: (batch, nheads, headdim, dstate)
seq_idx: (batch, seqlen)
cu_seqlens: (num_sequences + 1) or None, only used if return_varlen_states is True
dt_softplus: Whether to apply softplus to dt
out: Preallocated output tensor
state_dtype: The data type of the ssm state
"""
if not return_varlen_states:
cu_seqlens = None
else:
assert (
cu_seqlens is not None
), "cu_seqlens must be provided if return_varlen_states is True"
out_x, dt_out, dA_cumsum, states, final_states, *rest = (
_mamba_chunk_scan_combined_fwd(
x,
dt,
A,
B,
C,
chunk_size,
D=D,
z=z,
dt_bias=dt_bias,
initial_states=initial_states,
seq_idx=seq_idx,
chunk_indices=chunk_indices,
chunk_offsets=chunk_offsets,
cu_seqlens=cu_seqlens,
dt_softplus=dt_softplus,
dt_limit=dt_limit,
out=out,
state_dtype=state_dtype,
)
)
if not return_varlen_states:
if not return_final_states:
return
else:
return final_states
else:
varlen_states = rest[0]
return (
(varlen_states)
if not return_final_states
else (final_states, varlen_states)
)
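A minimal prefill-style call sketch for `mamba_chunk_scan_combined` (illustrative sizes only; `chunk_size` must be a power of two and a CUDA device is assumed):

import torch

batch, seqlen, nheads, headdim, ngroups, dstate = 1, 512, 8, 64, 1, 128
dev = "cuda"
x = torch.randn(batch, seqlen, nheads, headdim, device=dev, dtype=torch.bfloat16)
dt = torch.rand(batch, seqlen, nheads, device=dev, dtype=torch.bfloat16)
A = -torch.rand(nheads, device=dev)                    # per-head decay rate
B = torch.randn(batch, seqlen, ngroups, dstate, device=dev, dtype=torch.bfloat16)
C = torch.randn_like(B)
D = torch.randn(nheads, device=dev)
dt_bias = torch.zeros(nheads, device=dev)
out = torch.empty_like(x)                              # scan output, written in place

final_states = mamba_chunk_scan_combined(
    x, dt, A, B, C, chunk_size=256,
    D=D, dt_bias=dt_bias, dt_softplus=True,
    out=out, return_final_states=True,
)
# final_states has shape (batch, nheads, headdim, dstate): the SSM state after the last token.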
# Adapted from: https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright (c) 2024, Tri Dao, Albert Gu.
# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_state_passing.py
# ruff: noqa: E501
import torch
import triton
import triton.language as tl
# @triton.autotune(
# configs=[
# triton.Config({"BLOCK_SIZE": 64}),
# triton.Config({"BLOCK_SIZE": 128}),
# triton.Config({"BLOCK_SIZE": 256}),
# triton.Config({"BLOCK_SIZE": 512}),
# triton.Config({"BLOCK_SIZE": 1024}),
# triton.Config({"BLOCK_SIZE": 2048}),
# ],
# key=["dim"],
# )
@triton.jit
def _state_passing_fwd_kernel(
# Pointers to matrices
states_ptr,
out_ptr,
final_states_ptr,
dA_cs_ptr,
initstates_ptr,
seq_idx_ptr,
chunk_offsets_ptr,
chunk_meta_num,
# Matrix dimensions
dim,
nchunks,
seqlen,
chunk_size,
# Strides
stride_states_batch,
stride_states_chunk,
stride_states_head,
stride_states_dim,
stride_out_batch,
stride_out_chunk,
stride_out_head,
stride_out_dim,
stride_final_states_batch,
stride_final_states_head,
stride_final_states_dim,
stride_dA_cs_batch,
stride_dA_cs_chunk,
stride_dA_cs_head,
stride_dA_cs_csize,
stride_initstates_batch,
stride_initstates_head,
stride_initstates_dim,
stride_seq_idx_batch,
stride_seq_idx_seqlen,
# Meta-parameters
HAS_INITSTATES: tl.constexpr,
HAS_SEQ_IDX: tl.constexpr,
IS_CONT_BATCHED: tl.constexpr,
BLOCK_SIZE: tl.constexpr = 16,
):
pid_b = tl.program_id(axis=1)
pid_h = tl.program_id(axis=2)
pid_m = tl.program_id(axis=0)
states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head
dA_cs_ptr += (
pid_b * stride_dA_cs_batch
+ pid_h * stride_dA_cs_head
+ (chunk_size - 1) * stride_dA_cs_csize
)
out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head
final_states_ptr += (
pid_b * stride_final_states_batch + pid_h * stride_final_states_head
)
if HAS_INITSTATES:
initstates_ptr += pid_h * stride_initstates_head
if not IS_CONT_BATCHED:
initstates_ptr += pid_b * stride_initstates_batch
if HAS_SEQ_IDX:
seq_idx_ptr += pid_b * stride_seq_idx_batch
offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
states_ptrs = states_ptr + offs_m * stride_states_dim
out_ptrs = out_ptr + offs_m * stride_out_dim
final_states_ptrs = final_states_ptr + offs_m * stride_final_states_dim
# - states will be the past state of the sequence that continues on the current chunk
if not HAS_INITSTATES:
states = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
else:
initstates_ptr += offs_m * stride_initstates_dim
initstates_ptrs = initstates_ptr
# - for continuous batches, the first chunk starts from the first
# sequence's init state
states = tl.load(initstates_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
tl.store(out_ptrs, states, mask=offs_m < dim)
out_ptrs += stride_out_chunk
prev_seq_idx_chunk_end = 0
logical_chunk_idx = 0
for c in range(nchunks):
new_states = tl.load(states_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
dA_cs = tl.load(dA_cs_ptr).to(tl.float32)
scale_mask = True
if HAS_SEQ_IDX:
# - the seq to pass forward is the one that is flushed to the right
# boundary.
# - that is given by seq_idx_chunk_end below: the sequence index at the end of the chunk.
seq_idx_chunk_end = tl.load(
seq_idx_ptr
+ (min((c + 1) * chunk_size, seqlen) - 1) * stride_seq_idx_seqlen
)
if HAS_INITSTATES:
if IS_CONT_BATCHED and prev_seq_idx_chunk_end != seq_idx_chunk_end:
# this means in the current chunk the rightmost flushed seq
# has changed.
# - so we do not propagate the state from previous chunk
# - but rather we load that sequence's init state
initstates_ptrs = (
initstates_ptr + seq_idx_chunk_end * stride_initstates_batch
)
# - update state with seq_idx_new's init state
states = tl.load(initstates_ptrs, mask=offs_m < dim, other=0.0).to(
tl.float32
)
# - we need to consider the cumsum only of the last sequence in the chunk
# - find its starting position (given by c_off of the logical chunk index)
# - and subtract the cumsum just before that position from the total cumsum
# - first, update the logical chunk index (add the number of sequences in the current physical chunk):
# sequence index at the start of the current chunk
seq_idx_chunk_start = tl.load(
seq_idx_ptr
+ min(c * chunk_size, seqlen) * stride_seq_idx_seqlen
)
logical_chunk_idx += seq_idx_chunk_end - seq_idx_chunk_start
# - load the chunk offset:
c_off = tl.load(
chunk_offsets_ptr + logical_chunk_idx,
mask=logical_chunk_idx < chunk_meta_num,
other=0,
)
# - if offset is 0, then the sequence starts at the beginning of the chunk, and we don't need to subtract anything
if c_off > 0:
# - dA_cs_ptr currently points to the cumsum at the end of the chunk - subtract the chunk size and add the offset
dA_cs_boundary = tl.load(
dA_cs_ptr
- (chunk_size - 1) * stride_dA_cs_csize
+ (c_off - 1) * stride_dA_cs_csize,
mask=(c_off - 1) > -1 and c_off < chunk_size,
other=0.0,
)
dA_cs -= dA_cs_boundary
# - increment logical chunk index for every physical chunk
logical_chunk_idx += 1
else:
scale_mask = seq_idx_chunk_end == prev_seq_idx_chunk_end
prev_seq_idx_chunk_end = seq_idx_chunk_end
scale = tl.where(scale_mask, tl.exp(dA_cs), 0.0)
states = scale * states + new_states
if c < nchunks - 1:
tl.store(out_ptrs, states, mask=offs_m < dim)
else:
tl.store(final_states_ptrs, states, mask=offs_m < dim)
states_ptrs += stride_states_chunk
dA_cs_ptr += stride_dA_cs_chunk
out_ptrs += stride_out_chunk
def _state_passing_fwd(
states,
dA_cumsum,
initial_states=None,
seq_idx=None,
chunk_size=None,
out_dtype=None,
is_cont_batched=False,
chunk_offsets=None,
):
batch, nchunks, nheads, dim = states.shape
if chunk_size is None:
chunk_size = dA_cumsum.shape[-1]
else:
assert chunk_size == dA_cumsum.shape[-1]
assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
if initial_states is not None:
if is_cont_batched:
# - if cu_seqlens is provided, then the initial states
# are used for continuous batching. In which case we
# require seq_idx to be provided
assert (
seq_idx is not None
), "seq_idx must be provided for continuous batching"
# - we also need chunk_offsets to be provided, to account
# for computation of dA_cumsum from the start of the
# sequence
assert (
chunk_offsets is not None
), "chunk_offsets must be provided for continuous batching"
else:
# - this is the regular batching case, where initial
# states are used for each example of the batch.
assert initial_states.shape == (batch, nheads, dim)
if seq_idx is not None:
seqlen = seq_idx.shape[-1]
assert seq_idx.shape == (batch, seqlen)
out_dtype = states.dtype if out_dtype is None else out_dtype
out = torch.empty(
(batch, nchunks, nheads, dim), device=states.device, dtype=out_dtype
)
final_states = torch.empty(
(batch, nheads, dim), device=states.device, dtype=torch.float32
)
grid = lambda META: (triton.cdiv(dim, META["BLOCK_SIZE"]), batch, nheads)
with torch.cuda.device(states.device.index):
_state_passing_fwd_kernel[grid](
states,
out,
final_states,
dA_cumsum,
initial_states,
seq_idx,
chunk_offsets,
len(chunk_offsets) if chunk_offsets is not None else 0,
dim,
nchunks,
seqlen if seq_idx is not None else 0,
chunk_size,
states.stride(0),
states.stride(1),
states.stride(2),
states.stride(3),
out.stride(0),
out.stride(1),
out.stride(2),
out.stride(3),
final_states.stride(0),
final_states.stride(1),
final_states.stride(2),
dA_cumsum.stride(0),
dA_cumsum.stride(2),
dA_cumsum.stride(1),
dA_cumsum.stride(3),
*(
(
initial_states.stride(0),
initial_states.stride(1),
initial_states.stride(2),
)
if initial_states is not None
else (0, 0, 0)
),
*(
(seq_idx.stride(0), seq_idx.stride(1))
if seq_idx is not None
else (0, 0)
),
HAS_INITSTATES=initial_states is not None,
HAS_SEQ_IDX=seq_idx is not None,
IS_CONT_BATCHED=is_cont_batched,
)
return out, final_states
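In the simple case (no `seq_idx`, no initial states), the kernel above reduces to a short recurrence over chunks: the state entering chunk `c` is the fully-decayed state entering chunk `c-1` plus chunk `c-1`'s contribution. A rough eager sketch of that case (`state_passing_ref` is a hypothetical reference helper, not part of this file):

import torch


def state_passing_ref(states, dA_cumsum):
    # states:    (batch, nchunks, nheads, dim)        per-chunk contributions
    # dA_cumsum: (batch, nheads, nchunks, chunk_size) within-chunk cumulative log-decay
    batch, nchunks, nheads, dim = states.shape
    out = torch.zeros(batch, nchunks, nheads, dim, dtype=torch.float32, device=states.device)
    cur = torch.zeros(batch, nheads, dim, dtype=torch.float32, device=states.device)
    for c in range(nchunks):
        decay = dA_cumsum[:, :, c, -1].float().exp().unsqueeze(-1)  # total decay over chunk c
        cur = decay * cur + states[:, c].float()
        if c < nchunks - 1:
            out[:, c + 1] = cur
    # out[:, c] is the state entering chunk c (out[:, 0] is zero); cur is the final state.
    return out, cur


# states = torch.randn(1, 4, 8, 64 * 128, device="cuda")
# dA_cumsum = -torch.rand(1, 8, 4, 256, device="cuda").cumsum(-1)
# ref_out, ref_final = state_passing_ref(states, dA_cumsum)
# out, final = _state_passing_fwd(states, dA_cumsum)  # Triton path above, CUDA tensors required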
......@@ -1297,6 +1297,7 @@ class ModelRunner:
return self.model_config.hf_config.architectures[0] in [
"Qwen3NextForCausalLM",
"Qwen3NextForCausalLMMTP",
"FalconH1ForCausalLM",
]
def set_num_token_hybrid(self):
......
This diff is collapsed.
import unittest
from types import SimpleNamespace
from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval
from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
)
class TestFalconH1(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = "tiiuae/Falcon-H1-0.5B-Instruct"
cls.base_url = DEFAULT_URL_FOR_TEST
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--tensor-parallel-size",
"1",
],
)
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
def test_gsm8k(self):
args = SimpleNamespace(
num_shots=5,
data_path=None,
num_questions=200,
max_new_tokens=512,
parallel=128,
host="http://127.0.0.1",
port=int(self.base_url.split(":")[-1]),
)
metrics = run_eval(args)
print(f"{metrics=}")
self.assertGreater(metrics["accuracy"], 0.74)
class TestFalconH1TP4(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = "tiiuae/Falcon-H1-0.5B-Instruct"
cls.base_url = DEFAULT_URL_FOR_TEST
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--tensor-parallel-size",
"4",
],
)
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
def test_gsm8k(self):
args = SimpleNamespace(
num_shots=5,
data_path=None,
num_questions=200,
max_new_tokens=512,
parallel=128,
host="http://127.0.0.1",
port=int(self.base_url.split(":")[-1]),
)
metrics = run_eval(args)
print(f"{metrics=}")
self.assertGreater(metrics["accuracy"], 0.74)
class TestFalconH1NoGatedRMS(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = "tiiuae/Falcon-H1-1.5B-Instruct"
cls.base_url = DEFAULT_URL_FOR_TEST
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--tensor-parallel-size",
"1",
],
)
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
def test_gsm8k(self):
args = SimpleNamespace(
num_shots=5,
data_path=None,
num_questions=200,
max_new_tokens=512,
parallel=128,
host="http://127.0.0.1",
port=int(self.base_url.split(":")[-1]),
)
metrics = run_eval(args)
print(f"{metrics=}")
self.assertGreater(metrics["accuracy"], 0.74)
class TestFalconH1NoGatedTP4(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = "tiiuae/Falcon-H1-1.5B-Instruct"
cls.base_url = DEFAULT_URL_FOR_TEST
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--tensor-parallel-size",
"4",
],
)
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
def test_gsm8k(self):
args = SimpleNamespace(
num_shots=5,
data_path=None,
num_questions=200,
max_new_tokens=512,
parallel=128,
host="http://127.0.0.1",
port=int(self.base_url.split(":")[-1]),
)
metrics = run_eval(args)
print(f"{metrics=}")
self.assertGreater(metrics["accuracy"], 0.74)
......@@ -140,6 +140,7 @@ suites = {
TestFile("test_local_attn.py", 250),
TestFile("test_pp_single_node.py", 372),
TestFile("models/test_qwen3_next_models.py", 200),
TestFile("models/test_falcon_h1_models.py", 200),
TestFile("test_multi_instance_release_memory_occupation.py", 64),
],
"per-commit-8-gpu": [
......