Unverified Commit a4bf5c6a authored by Ke Bao's avatar Ke Bao Committed by GitHub
Browse files

Support Kimi Linear (#12469)


Co-authored-by: default avataryizhang2077 <1109276519@qq.com>
parent 30ad1070
......@@ -6,6 +6,7 @@ from sglang.srt.configs.dots_vlm import DotsVLMConfig
from sglang.srt.configs.exaone import ExaoneConfig
from sglang.srt.configs.falcon_h1 import FalconH1Config
from sglang.srt.configs.janus_pro import MultiModalityConfig
from sglang.srt.configs.kimi_linear import KimiLinearConfig
from sglang.srt.configs.kimi_vl import KimiVLConfig
from sglang.srt.configs.kimi_vl_moonvit import MoonViTConfig
from sglang.srt.configs.longcat_flash import LongcatFlashConfig
......@@ -31,6 +32,7 @@ __all__ = [
"Step3TextConfig",
"Step3VisionEncoderConfig",
"Olmo3Config",
"KimiLinearConfig",
"Qwen3NextConfig",
"DotsVLMConfig",
"DotsOCRConfig",
......
# Adapted from: https://github.com/vllm-project/vllm/blob/0384aa7150c4c9778efca041ffd1beb3ad2bd694/vllm/transformers_utils/configs/kimi_linear.py
from transformers.configuration_utils import PretrainedConfig
from sglang.srt.configs.mamba_utils import KimiLinearCacheParams, KimiLinearStateShape
from sglang.srt.layers.dp_attention import get_attention_tp_size
class KimiLinearConfig(PretrainedConfig):
model_type = "kimi_linear"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
model_type="kimi_linear",
vocab_size=163840,
hidden_size=4096,
head_dim=None,
intermediate_size=11008,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=None,
hidden_act="silu",
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
pad_token_id=0,
bos_token_id=1,
eos_token_id=2,
rope_theta=10000.0,
rope_scaling=None,
tie_word_embeddings=False,
moe_intermediate_size: int | None = None,
moe_renormalize: bool = True,
moe_router_activation_func: str = "sigmoid",
num_experts: int | None = None,
num_experts_per_token: int | None = None,
num_shared_experts: int = 0,
routed_scaling_factor: float = 1.0,
first_k_dense_replace: int = 0,
moe_layer_freq: int = 1,
use_grouped_topk: bool = True,
num_expert_group: int = 1,
topk_group: int = 1,
q_lora_rank: int | None = None,
kv_lora_rank: int | None = None,
qk_nope_head_dim: int | None = None,
qk_rope_head_dim: int | None = None,
v_head_dim: int | None = None,
mla_use_nope: bool | None = False,
num_nextn_predict_layers: int = 0,
linear_attn_config: dict | None = None,
**kwargs,
):
self.model_type = model_type
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.head_dim = (
head_dim if head_dim is not None else hidden_size // num_attention_heads
)
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.q_lora_rank = q_lora_rank
self.kv_lora_rank = kv_lora_rank
self.qk_nope_head_dim = qk_nope_head_dim
self.qk_rope_head_dim = qk_rope_head_dim
self.v_head_dim = v_head_dim
self.mla_use_nope = mla_use_nope
# moe config
self.n_routed_experts = self.num_experts = num_experts
self.num_experts_per_token = num_experts_per_token
self.moe_renormalize = moe_renormalize
self.num_shared_experts = num_shared_experts
self.routed_scaling_factor = routed_scaling_factor
self.moe_router_activation_func = moe_router_activation_func
assert self.moe_router_activation_func in ("softmax", "sigmoid")
self.moe_intermediate_size = moe_intermediate_size
self.first_k_dense_replace = first_k_dense_replace
self.moe_layer_freq = moe_layer_freq
self.use_grouped_topk = use_grouped_topk
self.num_expert_group = num_expert_group
self.topk_group = topk_group
self.num_nextn_predict_layers = num_nextn_predict_layers
if linear_attn_config is not None:
assert linear_attn_config["kda_layers"] is not None
assert linear_attn_config["full_attn_layers"] is not None
self.linear_attn_config = linear_attn_config
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
@property
def is_mla(self):
return (
self.q_lora_rank is not None
or self.kv_lora_rank is not None
or self.qk_nope_head_dim is not None
or self.qk_rope_head_dim is not None
or self.v_head_dim is not None
or self.mla_use_nope is True
)
@property
def is_moe(self):
return self.num_experts is not None
@property
def is_linear_attn(self) -> bool:
return not (
self.linear_attn_config is None
or (
isinstance(self.linear_attn_config, dict)
and self.linear_attn_config["kda_layers"] is not None
and len(self.linear_attn_config["kda_layers"]) == 0
)
)
def is_kda_layer(self, layer_idx: int):
return (
self.linear_attn_config is not None
and (layer_idx + 1) in self.linear_attn_config["kda_layers"]
)
@property
def linear_layer_ids(self):
return [i for i in range(self.num_hidden_layers) if self.is_kda_layer(i)]
@property
def full_attention_layer_ids(self):
return [i for i in range(self.num_hidden_layers) if not self.is_kda_layer(i)]
@property
def mamba2_cache_params(self) -> KimiLinearCacheParams:
shape = KimiLinearStateShape.create(
tp_world_size=get_attention_tp_size(),
num_heads=self.linear_attn_config["num_heads"],
head_dim=self.linear_attn_config["head_dim"],
conv_kernel_size=self.linear_attn_config["short_conv_kernel_size"],
)
return KimiLinearCacheParams(shape=shape, layers=self.linear_layer_ids)
......@@ -14,6 +14,7 @@
import os
from dataclasses import dataclass, field
from typing import List, Optional
import numpy as np
import torch
......@@ -115,3 +116,68 @@ class Mamba2CacheParams:
int(np.prod(self.shape.conv)) * self.dtype.conv.itemsize
+ int(np.prod(self.shape.temporal)) * self.dtype.temporal.itemsize
) * len(self.layers)
@dataclass(kw_only=True, frozen=True)
class KimiLinearStateShape:
conv: List[tuple[int, int]]
temporal: tuple[int, int, int]
num_heads: int
head_dim: int
num_k_heads: int
head_k_dim: int
conv_kernel: int
num_spec: int
@staticmethod
def create(
*,
tp_world_size: int,
num_heads: int,
head_dim: int,
num_k_heads: Optional[int] = None,
head_k_dim: Optional[int] = None,
conv_kernel_size: int = 4,
num_spec: int = 0,
) -> "KimiLinearStateShape":
if num_k_heads is None:
num_k_heads = num_heads
if head_k_dim is None:
head_k_dim = head_dim
proj_size = num_heads * head_dim
proj_k_size = num_k_heads * head_k_dim
conv_state_shape = (divide(proj_size, tp_world_size), conv_kernel_size - 1)
conv_state_k_shape = (divide(proj_k_size, tp_world_size), conv_kernel_size - 1)
temporal_state_shape = (divide(num_heads, tp_world_size), head_dim, head_dim)
conv_state_shape = conv_state_shape[1], conv_state_shape[0]
conv_state_k_shape = conv_state_k_shape[1], conv_state_k_shape[0]
return KimiLinearStateShape(
conv=[conv_state_shape, conv_state_k_shape, conv_state_k_shape],
temporal=temporal_state_shape,
num_heads=num_heads,
head_dim=head_dim,
num_k_heads=num_k_heads,
head_k_dim=head_k_dim,
conv_kernel=conv_kernel_size,
num_spec=num_spec,
)
@dataclass(kw_only=True, frozen=True)
class KimiLinearCacheParams:
shape: KimiLinearStateShape
dtype: Mamba2StateDType = field(default_factory=mamba2_state_dtype)
layers: list[int]
@property
def mamba_cache_per_req(self) -> int:
return (
int(np.sum([np.prod(conv_shape) for conv_shape in self.shape.conv]))
* self.dtype.conv.itemsize
+ int(np.prod(self.shape.temporal)) * self.dtype.temporal.itemsize
) * len(self.layers)
......@@ -366,6 +366,13 @@ class ModelConfig:
self.qk_rope_head_dim = self.hf_text_config.qk_rope_head_dim
self.v_head_dim = self.hf_text_config.v_head_dim
self.qk_nope_head_dim = self.hf_text_config.qk_nope_head_dim
elif "KimiLinearForCausalLM" in self.hf_config.architectures:
self.head_dim = 72
self.attention_arch = AttentionArch.MLA
self.kv_lora_rank = self.hf_config.kv_lora_rank
self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim
self.v_head_dim = self.hf_config.v_head_dim
self.qk_nope_head_dim = self.hf_config.qk_nope_head_dim
else:
if (
"MistralModel" in self.hf_config.architectures
......
......@@ -189,6 +189,7 @@ def attn_backend_wrapper(runner: "ModelRunner", full_attn_backend: "AttentionBac
from sglang.srt.layers.attention.hybrid_linear_attn_backend import (
GDNAttnBackend,
HybridLinearAttnBackend,
KimiLinearAttnBackend,
Mamba2AttnBackend,
)
from sglang.srt.utils import is_blackwell, is_npu
......@@ -207,6 +208,8 @@ def attn_backend_wrapper(runner: "ModelRunner", full_attn_backend: "AttentionBac
linear_attn_backend = GDNAttnBackend(runner)
elif runner.mamba2_config is not None:
linear_attn_backend = Mamba2AttnBackend(runner)
elif runner.kimi_linear_config is not None:
linear_attn_backend = KimiLinearAttnBackend(runner)
else:
raise ValueError(
"Expected hybrid GDN or NemotronH models, but got unknown model."
......
......@@ -21,6 +21,7 @@ NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8, 16]
@triton.heuristics(
{
"USE_G": lambda args: args["g"] is not None,
"USE_GK": lambda args: args["gk"] is not None,
"USE_INITIAL_STATE": lambda args: args["h0"] is not None,
"STORE_FINAL_STATE": lambda args: args["ht"] is not None,
"SAVE_NEW_VALUE": lambda args: args["v_new"] is not None,
......@@ -44,6 +45,7 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
w,
v_new,
g,
gk,
h,
h0,
ht,
......@@ -57,6 +59,7 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
BT: tl.constexpr,
BV: tl.constexpr,
USE_G: tl.constexpr,
USE_GK: tl.constexpr,
USE_INITIAL_STATE: tl.constexpr,
STORE_FINAL_STATE: tl.constexpr,
SAVE_NEW_VALUE: tl.constexpr,
......@@ -86,12 +89,12 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
b_h4 = tl.zeros([64, BV], dtype=tl.float32)
# calculate offset
h += (boh * H + i_h) * K * V
v += (bos * H + i_h) * V
k += (bos * Hg + i_h // (H // Hg)) * K
w += (bos * H + i_h) * K
h += ((boh * H + i_h) * K * V).to(tl.int64)
v += ((bos * H + i_h) * V).to(tl.int64)
k += ((bos * Hg + i_h // (H // Hg)) * K).to(tl.int64)
w += ((bos * H + i_h) * K).to(tl.int64)
if SAVE_NEW_VALUE:
v_new += (bos * H + i_h) * V
v_new += ((bos * H + i_h) * V).to(tl.int64)
stride_v = H * V
stride_h = H * K * V
stride_k = Hg * K
......@@ -143,58 +146,48 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
)
tl.store(p_h4, b_h4.to(p_h4.dtype.element_ty), boundary_check=(0, 1))
p_v = tl.make_block_ptr(
v, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
)
p_v_new = (
tl.make_block_ptr(
v_new, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
)
if SAVE_NEW_VALUE
else None
)
b_v_new = tl.zeros([BT, BV], dtype=tl.float32)
p_w = tl.make_block_ptr(
w, (T, K), (stride_w, 1), (i_t * BT, 0), (BT, 64), (1, 0)
)
b_w = tl.load(p_w, boundary_check=(0, 1))
b_v_new += tl.dot(b_w, b_h1.to(b_w.dtype))
b_v = tl.dot(b_w, b_h1.to(b_w.dtype))
if K > 64:
p_w = tl.make_block_ptr(
w, (T, K), (stride_w, 1), (i_t * BT, 64), (BT, 64), (1, 0)
)
b_w = tl.load(p_w, boundary_check=(0, 1))
b_v_new += tl.dot(b_w, b_h2.to(b_w.dtype))
b_v += tl.dot(b_w, b_h2.to(b_w.dtype))
if K > 128:
p_w = tl.make_block_ptr(
w, (T, K), (stride_w, 1), (i_t * BT, 128), (BT, 64), (1, 0)
)
b_w = tl.load(p_w, boundary_check=(0, 1))
b_v_new += tl.dot(b_w, b_h3.to(b_w.dtype))
b_v += tl.dot(b_w, b_h3.to(b_w.dtype))
if K > 192:
p_w = tl.make_block_ptr(
w, (T, K), (stride_w, 1), (i_t * BT, 192), (BT, 64), (1, 0)
)
b_w = tl.load(p_w, boundary_check=(0, 1))
b_v_new += tl.dot(b_w, b_h4.to(b_w.dtype))
b_v_new = -b_v_new + tl.load(p_v, boundary_check=(0, 1))
b_v += tl.dot(b_w, b_h4.to(b_w.dtype))
p_v = tl.make_block_ptr(
v, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
)
b_v = tl.load(p_v, boundary_check=(0, 1)) - b_v
if SAVE_NEW_VALUE:
p_v_new = tl.make_block_ptr(
p_v = tl.make_block_ptr(
v_new, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
)
tl.store(
p_v_new, b_v_new.to(p_v_new.dtype.element_ty), boundary_check=(0, 1)
)
tl.store(p_v, b_v.to(p_v.dtype.element_ty), boundary_check=(0, 1))
last_idx = min((i_t + 1) * BT, T) - 1
if USE_G:
last_idx = min((i_t + 1) * BT, T) - 1
b_g_last = tl.load(g + bos * H + last_idx * H + i_h)
p_g = tl.make_block_ptr(
g + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)
)
b_g = tl.load(p_g, boundary_check=(0,))
b_v_new = b_v_new * safe_exp(b_g_last - b_g)[:, None]
b_v = b_v * safe_exp(b_g_last - b_g)[:, None]
b_g_last = exp(b_g_last)
b_h1 = b_h1 * b_g_last
if K > 64:
......@@ -203,30 +196,64 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
b_h3 = b_h3 * b_g_last
if K > 192:
b_h4 = b_h4 * b_g_last
b_v_new = b_v_new.to(k.dtype.element_ty)
if USE_GK:
o_k1 = tl.arange(0, 64)
b_gk_last1 = tl.load(
gk + (bos + last_idx) * H * K + i_h * K + o_k1,
mask=(o_k1 < K),
other=0.0,
)
b_h1 *= exp(b_gk_last1)[:, None]
if K > 64:
o_k2 = 64 + o_k1
b_gk_last2 = tl.load(
gk + (bos + last_idx) * H * K + i_h * K + o_k2,
mask=(o_k2 < K),
other=0.0,
)
b_h2 *= exp(b_gk_last2)[:, None]
if K > 128:
o_k3 = 128 + o_k1
b_gk_last3 = tl.load(
gk + (bos + last_idx) * H * K + i_h * K + o_k3,
mask=(o_k3 < K),
other=0.0,
)
b_h3 *= exp(b_gk_last3)[:, None]
if K > 192:
o_k4 = 192 + o_k1
b_gk_last4 = tl.load(
gk + (bos + last_idx) * H * K + i_h * K + o_k4,
mask=(o_k4 < K),
other=0.0,
)
b_h4 *= exp(b_gk_last4)[:, None]
b_v = b_v.to(k.dtype.element_ty)
p_k = tl.make_block_ptr(
k, (K, T), (1, stride_k), (0, i_t * BT), (64, BT), (0, 1)
)
b_k = tl.load(p_k, boundary_check=(0, 1))
b_h1 += tl.dot(b_k, b_v_new)
b_h1 += tl.dot(b_k, b_v)
if K > 64:
p_k = tl.make_block_ptr(
k, (K, T), (1, stride_k), (64, i_t * BT), (64, BT), (0, 1)
)
b_k = tl.load(p_k, boundary_check=(0, 1))
b_h2 += tl.dot(b_k, b_v_new)
b_h2 += tl.dot(b_k, b_v)
if K > 128:
p_k = tl.make_block_ptr(
k, (K, T), (1, stride_k), (128, i_t * BT), (64, BT), (0, 1)
)
b_k = tl.load(p_k, boundary_check=(0, 1))
b_h3 += tl.dot(b_k, b_v_new)
b_h3 += tl.dot(b_k, b_v)
if K > 192:
p_k = tl.make_block_ptr(
k, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1)
)
b_k = tl.load(p_k, boundary_check=(0, 1))
b_h4 += tl.dot(b_k, b_v_new)
b_h4 += tl.dot(b_k, b_v)
# epilogue
if STORE_FINAL_STATE:
......@@ -254,6 +281,7 @@ def chunk_gated_delta_rule_fwd_h(
w: torch.Tensor,
u: torch.Tensor,
g: Optional[torch.Tensor] = None,
gk: Optional[torch.Tensor] = None,
initial_state: Optional[torch.Tensor] = None,
output_final_state: bool = False,
chunk_size: int = 64, # SY: remove this argument and force chunk size 64?
......@@ -296,6 +324,7 @@ def chunk_gated_delta_rule_fwd_h(
w=w,
v_new=v_new,
g=g,
gk=gk,
h=h,
h0=initial_state,
ht=final_state,
......
......@@ -44,6 +44,7 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
IS_BETA_HEADWISE: tl.constexpr, # whether beta is headwise vector or scalar,
USE_QK_L2NORM_IN_KERNEL: tl.constexpr,
IS_VARLEN: tl.constexpr,
IS_KDA: tl.constexpr,
):
i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
i_n, i_hv = i_nh // HV, i_nh % HV
......@@ -67,7 +68,11 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
p_beta = beta + (bos * HV + i_hv) * V + o_v
else:
p_beta = beta + bos * HV + i_hv
p_g = g + bos * HV + i_hv
if not IS_KDA:
p_g = g + bos * HV + i_hv
else:
p_gk = g + (bos * HV + i_hv) * K + o_k
p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v
mask_k = o_k < K
......@@ -83,14 +88,18 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32)
b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
b_g = tl.load(p_g).to(tl.float32)
if USE_QK_L2NORM_IN_KERNEL:
b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q) + 1e-6))
b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k) + 1e-6))
b_q = b_q * scale
# [BK, BV]
b_h *= exp(b_g)
if not IS_KDA:
b_g = tl.load(p_g).to(tl.float32)
b_h *= exp(b_g)
else:
b_gk = tl.load(p_gk).to(tl.float32)
b_h *= exp(b_gk[:, None])
# [BV]
b_v -= tl.sum(b_h * b_k[:, None], 0)
if IS_BETA_HEADWISE:
......@@ -108,7 +117,10 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
p_k += H * K
p_o += HV * V
p_v += HV * V
p_g += HV
if not IS_KDA:
p_g += HV
else:
p_gk += HV * K
p_beta += HV * (V if IS_BETA_HEADWISE else 1)
if STORE_FINAL_STATE:
......@@ -165,6 +177,7 @@ def fused_recurrent_gated_delta_rule_fwd(
BV=BV,
IS_BETA_HEADWISE=beta.ndim == v.ndim,
USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel,
IS_KDA=False,
num_warps=num_warps,
num_stages=num_stages,
)
......
This diff is collapsed.
from typing import Optional, Union
import torch
from einops import rearrange
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.attention.fla.chunk import chunk_gated_delta_rule
......@@ -10,6 +11,11 @@ from sglang.srt.layers.attention.fla.fused_recurrent import (
from sglang.srt.layers.attention.fla.fused_sigmoid_gating_recurrent import (
fused_sigmoid_gating_delta_rule_update,
)
from sglang.srt.layers.attention.fla.kda import (
chunk_kda,
fused_kda_gate,
fused_recurrent_kda,
)
from sglang.srt.layers.attention.mamba.causal_conv1d_triton import (
PAD_SLOT_ID,
causal_conv1d_fn,
......@@ -227,6 +233,223 @@ class MambaAttnBackendBase(AttentionBackend):
return 1 # Mamba attn does not use seq lens to index kv cache
class KimiLinearAttnBackend(MambaAttnBackendBase):
"""Attention backend using Mamba kernel."""
def forward_decode(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache: bool = True,
**kwargs,
):
q_proj_states = kwargs["q_proj_states"]
k_proj_states = kwargs["k_proj_states"]
v_proj_states = kwargs["v_proj_states"]
q_conv_weights = kwargs["q_conv_weights"]
k_conv_weights = kwargs["k_conv_weights"]
v_conv_weights = kwargs["v_conv_weights"]
q_conv_bias = kwargs["q_conv_bias"]
k_conv_bias = kwargs["k_conv_bias"]
v_conv_bias = kwargs["v_conv_bias"]
A_log = kwargs["A_log"]
dt_bias = kwargs["dt_bias"]
b_proj = kwargs["b_proj"]
f_a_proj = kwargs["f_a_proj"]
f_b_proj = kwargs["f_b_proj"]
hidden_states = kwargs["hidden_states"]
head_dim = kwargs["head_dim"]
layer_id = kwargs["layer_id"]
layer_cache = self.req_to_token_pool.mamba2_layer_cache(layer_id)
q_conv_state, k_conv_state, v_conv_state = layer_cache.conv
ssm_states = layer_cache.temporal
query_start_loc = self.forward_metadata.query_start_loc
cache_indices = self.forward_metadata.mamba_cache_indices
q_conv_state = q_conv_state.transpose(-1, -2)
k_conv_state = k_conv_state.transpose(-1, -2)
v_conv_state = v_conv_state.transpose(-1, -2)
q = causal_conv1d_update(
q_proj_states,
q_conv_state,
q_conv_weights,
q_conv_bias,
activation="silu",
conv_state_indices=cache_indices,
)
k = causal_conv1d_update(
k_proj_states,
k_conv_state,
k_conv_weights,
k_conv_bias,
activation="silu",
conv_state_indices=cache_indices,
)
v = causal_conv1d_update(
v_proj_states,
v_conv_state,
v_conv_weights,
v_conv_bias,
activation="silu",
conv_state_indices=cache_indices,
)
q, k, v = map(
lambda x: rearrange(x, "n (h d) -> 1 n h d", d=head_dim), (q, k, v)
)
beta = b_proj(hidden_states)[0].float().sigmoid()
g = f_b_proj(f_a_proj(hidden_states)[0])[0]
g = fused_kda_gate(g, A_log, head_dim, g_bias=dt_bias)
beta = beta.unsqueeze(0)
g = g.unsqueeze(0)
initial_state = ssm_states[cache_indices].contiguous()
(
core_attn_out,
last_recurrent_state,
) = fused_recurrent_kda(
q=q,
k=k,
v=v,
g=g,
beta=beta,
initial_state=initial_state,
use_qk_l2norm_in_kernel=True,
cu_seqlens=query_start_loc,
)
ssm_states[cache_indices] = last_recurrent_state
return core_attn_out
def forward_extend(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache: bool = True,
**kwargs,
):
from sglang.srt.layers.attention.mamba.causal_conv1d_triton import (
causal_conv1d_fn,
)
q_proj_states = kwargs["q_proj_states"]
k_proj_states = kwargs["k_proj_states"]
v_proj_states = kwargs["v_proj_states"]
q_conv_weights = kwargs["q_conv_weights"]
k_conv_weights = kwargs["k_conv_weights"]
v_conv_weights = kwargs["v_conv_weights"]
q_conv_bias = kwargs["q_conv_bias"]
k_conv_bias = kwargs["k_conv_bias"]
v_conv_bias = kwargs["v_conv_bias"]
A_log = kwargs["A_log"]
dt_bias = kwargs["dt_bias"]
b_proj = kwargs["b_proj"]
f_a_proj = kwargs["f_a_proj"]
f_b_proj = kwargs["f_b_proj"]
hidden_states = kwargs["hidden_states"]
head_dim = kwargs["head_dim"]
layer_id = kwargs["layer_id"]
query_start_loc = self.forward_metadata.query_start_loc
cache_indices = self.forward_metadata.mamba_cache_indices
mamba_cache_params = self.req_to_token_pool.mamba2_layer_cache(layer_id)
conv_state_q, conv_state_k, conv_state_v = mamba_cache_params.conv
# deal with strides
conv_state_q = conv_state_q.transpose(-1, -2)
conv_state_k = conv_state_k.transpose(-1, -2)
conv_state_v = conv_state_v.transpose(-1, -2)
ssm_states = mamba_cache_params.temporal
has_initial_state = forward_batch.extend_prefix_lens > 0
q_proj_states = q_proj_states.transpose(0, 1)
k_proj_states = k_proj_states.transpose(0, 1)
v_proj_states = v_proj_states.transpose(0, 1)
q = causal_conv1d_fn(
q_proj_states,
q_conv_weights,
q_conv_bias,
activation="silu",
conv_states=conv_state_q,
has_initial_state=has_initial_state,
cache_indices=cache_indices,
query_start_loc=query_start_loc,
seq_lens_cpu=forward_batch.extend_seq_lens_cpu,
).transpose(0, 1)
k = causal_conv1d_fn(
k_proj_states,
k_conv_weights,
k_conv_bias,
activation="silu",
conv_states=conv_state_k,
has_initial_state=has_initial_state,
cache_indices=cache_indices,
query_start_loc=query_start_loc,
seq_lens_cpu=forward_batch.extend_seq_lens_cpu,
).transpose(0, 1)
v = causal_conv1d_fn(
v_proj_states,
v_conv_weights,
v_conv_bias,
activation="silu",
conv_states=conv_state_v,
has_initial_state=has_initial_state,
cache_indices=cache_indices,
query_start_loc=query_start_loc,
seq_lens_cpu=forward_batch.extend_seq_lens_cpu,
).transpose(0, 1)
q, k, v = map(
lambda x: rearrange(x, "n (h d) -> 1 n h d", d=head_dim), (q, k, v)
)
beta = b_proj(hidden_states)[0].float().sigmoid()
g = f_b_proj(f_a_proj(hidden_states)[0])[0]
g = fused_kda_gate(g, A_log, head_dim, g_bias=dt_bias)
beta = beta.unsqueeze(0)
g = g.unsqueeze(0)
initial_state = ssm_states[cache_indices].contiguous()
(
core_attn_out,
last_recurrent_state,
) = chunk_kda(
q=q,
k=k,
v=v,
g=g,
beta=beta,
initial_state=initial_state,
output_final_state=True,
use_qk_l2norm_in_kernel=True,
cu_seqlens=query_start_loc,
)
ssm_states[cache_indices] = last_recurrent_state
return core_attn_out
class GDNAttnBackend(MambaAttnBackendBase):
"""Attention backend using Mamba kernel."""
......
......@@ -92,7 +92,10 @@ class TritonAttnBackend(AttentionBackend):
self.num_kv_head = model_runner.model_config.get_num_kv_heads(
get_attention_tp_size()
)
if model_runner.hybrid_gdn_config is not None:
if (
model_runner.hybrid_gdn_config is not None
or model_runner.kimi_linear_config is not None
):
# For hybrid linear models, layer_id = 0 may not be full attention
self.v_head_dim = model_runner.token_to_kv_pool.get_v_head_dim()
else:
......
......@@ -17,7 +17,7 @@ from __future__ import annotations
from dataclasses import dataclass
from sglang.srt.configs.mamba_utils import Mamba2CacheParams
from sglang.srt.configs.mamba_utils import KimiLinearCacheParams, Mamba2CacheParams
from sglang.srt.layers.attention.nsa import index_buf_accessor
from sglang.srt.layers.attention.nsa.quant_k_cache import quantize_k_cache
from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter
......@@ -33,7 +33,7 @@ KVCache actually holds the physical kv cache.
import abc
import logging
from contextlib import nullcontext
from contextlib import contextmanager, nullcontext
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import numpy as np
......@@ -59,7 +59,9 @@ if _is_npu:
import torch_npu
def get_tensor_size_bytes(t: torch.Tensor):
def get_tensor_size_bytes(t: Union[torch.Tensor, List[torch.Tensor]]):
if isinstance(t, list):
return sum(get_tensor_size_bytes(x) for x in t)
return np.prod(t.shape) * t.dtype.itemsize
......@@ -116,10 +118,15 @@ class ReqToTokenPool:
class MambaPool:
@dataclass(frozen=True, kw_only=True)
class State:
conv: torch.Tensor
conv: Union[torch.Tensor, List[torch.Tensor]]
temporal: torch.Tensor
def at_layer_idx(self, layer: int):
if isinstance(self.conv, list):
return type(self)(
conv=[v[layer] for v in self.conv],
temporal=self.temporal[layer],
)
return type(self)(**{k: v[layer] for k, v in vars(self).items()})
def mem_usage_bytes(self):
......@@ -127,14 +134,14 @@ class MambaPool:
@dataclass(frozen=True, kw_only=True)
class SpeculativeState(State):
intermediate_ssm: torch.Tensor
intermediate_ssm: Union[torch.Tensor, List[torch.Tensor]]
intermediate_conv_window: torch.Tensor
def __init__(
self,
*,
size: int,
cache_params: "Mamba2CacheParams",
cache_params: Union["Mamba2CacheParams", "KimiLinearCacheParams"],
device: str,
speculative_num_draft_tokens: Optional[int] = None,
):
......@@ -157,18 +164,29 @@ class MambaPool:
else:
self.custom_mem_pool = None
self.is_kda_cache = isinstance(cache_params, KimiLinearCacheParams)
with (
torch.cuda.use_mem_pool(self.custom_mem_pool)
if self.enable_custom_mem_pool
else nullcontext()
):
# assume conv_state = (dim, state_len)
assert conv_state_shape[0] > conv_state_shape[1]
conv_state = torch.zeros(
size=(num_mamba_layers, size + 1) + conv_state_shape,
dtype=conv_dtype,
device=device,
)
if self.is_kda_cache:
conv_state = [
torch.zeros(
size=(num_mamba_layers, size + 1) + conv_shape,
dtype=conv_dtype,
device=device,
)
for conv_shape in conv_state_shape
]
else:
# assume conv_state = (dim, state_len)
assert conv_state_shape[0] > conv_state_shape[1]
conv_state = torch.zeros(
size=(num_mamba_layers, size + 1) + conv_state_shape,
dtype=conv_dtype,
device=device,
)
temporal_state = torch.zeros(
size=(num_mamba_layers, size + 1) + temporal_state_shape,
dtype=ssm_dtype,
......@@ -191,17 +209,34 @@ class MambaPool:
)
# Cache intermediate conv windows (last K-1 inputs) per draft token during target verify
# Shape: [num_layers, size + 1, speculative_num_draft_tokens, dim, K-1]
intermediate_conv_window_cache = torch.zeros(
size=(
num_mamba_layers,
size + 1,
speculative_num_draft_tokens,
conv_state_shape[0],
conv_state_shape[1],
),
dtype=conv_dtype,
device="cuda",
)
if self.is_kda_cache:
intermediate_conv_window_cache = [
torch.zeros(
size=(
num_mamba_layers,
size + 1,
speculative_num_draft_tokens,
conv_shape[0],
conv_shape[1],
),
dtype=conv_dtype,
device="cuda",
)
for conv_shape in conv_state_shape
]
else:
intermediate_conv_window_cache = torch.zeros(
size=(
num_mamba_layers,
size + 1,
speculative_num_draft_tokens,
conv_state_shape[0],
conv_state_shape[1],
),
dtype=conv_dtype,
device="cuda",
)
self.mamba_cache = self.SpeculativeState(
conv=conv_state,
temporal=temporal_state,
......@@ -255,15 +290,25 @@ class MambaPool:
if free_index.numel() == 0:
return
self.free_slots = torch.cat((self.free_slots, free_index))
self.mamba_cache.conv[:, free_index] = self.mamba_cache.temporal[
:, free_index
] = 0
if self.is_kda_cache:
for i in range(len(self.mamba_cache.conv)):
self.mamba_cache.conv[i][:, free_index] = 0
else:
self.mamba_cache.conv[:, free_index] = 0
self.mamba_cache.temporal[:, free_index] = 0
def clear(self):
self.free_slots = torch.arange(self.size, dtype=torch.int64, device=self.device)
def copy_from(self, src_index: torch.Tensor, dst_index: torch.Tensor):
self.mamba_cache.conv[:, dst_index] = self.mamba_cache.conv[:, src_index]
if self.is_kda_cache:
for i in range(len(self.mamba_cache.conv)):
self.mamba_cache.conv[i][:, dst_index] = self.mamba_cache.conv[i][
:, src_index
]
else:
self.mamba_cache.conv[:, dst_index] = self.mamba_cache.conv[:, src_index]
self.mamba_cache.temporal[:, dst_index] = self.mamba_cache.temporal[
:, src_index
]
......@@ -304,7 +349,7 @@ class HybridReqToTokenPool(ReqToTokenPool):
max_context_len: int,
device: str,
enable_memory_saver: bool,
cache_params: "Mamba2CacheParams",
cache_params: Union["Mamba2CacheParams", "KimiLinearCacheParams"],
speculative_num_draft_tokens: int = None,
):
super().__init__(
......@@ -323,7 +368,7 @@ class HybridReqToTokenPool(ReqToTokenPool):
def _init_mamba_pool(
self,
size: int,
cache_params: "Mamba2CacheParams",
cache_params: Union["Mamba2CacheParams", "KimiLinearCacheParams"],
device: str,
speculative_num_draft_tokens: int = None,
):
......@@ -812,6 +857,10 @@ class HybridLinearKVPool(KVCache):
enable_kvcache_transpose: bool,
device: str,
mamba_pool: MambaPool,
# TODO: refactor mla related args
use_mla: bool = False,
kv_lora_rank: int = None,
qk_rope_head_dim: int = None,
):
self.size = size
self.dtype = dtype
......@@ -825,25 +874,42 @@ class HybridLinearKVPool(KVCache):
self.mamba_pool = mamba_pool
# TODO MHATransposedTokenToKVPool if enable_kvcache_transpose is True
assert not enable_kvcache_transpose
if _is_npu:
TokenToKVPoolClass = AscendTokenToKVPool
self.use_mla = use_mla
if not use_mla:
if _is_npu:
TokenToKVPoolClass = AscendTokenToKVPool
else:
TokenToKVPoolClass = MHATokenToKVPool
self.full_kv_pool = TokenToKVPoolClass(
size=size,
page_size=self.page_size,
dtype=dtype,
head_num=head_num,
head_dim=head_dim,
layer_num=self.full_layer_nums,
device=device,
enable_memory_saver=False,
)
else:
TokenToKVPoolClass = MHATokenToKVPool
self.full_kv_pool = TokenToKVPoolClass(
size=size,
page_size=self.page_size,
dtype=dtype,
head_num=head_num,
head_dim=head_dim,
layer_num=self.full_layer_nums,
device=device,
enable_memory_saver=False,
)
TokenToKVPoolClass = MLATokenToKVPool
self.full_kv_pool = TokenToKVPoolClass(
size=size,
page_size=self.page_size,
dtype=dtype,
layer_num=self.full_layer_nums,
device=device,
kv_lora_rank=kv_lora_rank,
qk_rope_head_dim=qk_rope_head_dim,
enable_memory_saver=False,
)
self.full_attention_layer_id_mapping = {
id: i for i, id in enumerate(full_attention_layer_ids)
}
k_size, v_size = self.get_kv_size_bytes()
self.mem_usage = (k_size + v_size) / GB
if use_mla:
self.mem_usage = self.get_kv_size_bytes() / GB
else:
k_size, v_size = self.get_kv_size_bytes()
self.mem_usage = (k_size + v_size) / GB
def get_kv_size_bytes(self):
return self.full_kv_pool.get_kv_size_bytes()
......@@ -879,6 +945,21 @@ class HybridLinearKVPool(KVCache):
layer_id = self._transfer_full_attention_id(layer_id)
return self.full_kv_pool.get_kv_buffer(layer_id)
@contextmanager
def _transfer_id_context(self, layer: RadixAttention):
@contextmanager
def _patch_layer_id(layer):
original_layer_id = layer.layer_id
layer.layer_id = self._transfer_full_attention_id(layer.layer_id)
try:
yield
finally:
layer.layer_id = original_layer_id
with _patch_layer_id(layer):
yield
def set_kv_buffer(
self,
layer: RadixAttention,
......@@ -889,19 +970,49 @@ class HybridLinearKVPool(KVCache):
v_scale: float = 1.0,
):
layer_id = self._transfer_full_attention_id(layer.layer_id)
self.full_kv_pool.set_kv_buffer(
None,
loc,
cache_k,
cache_v,
k_scale,
v_scale,
layer_id_override=layer_id,
)
if not self.use_mla:
self.full_kv_pool.set_kv_buffer(
None,
loc,
cache_k,
cache_v,
k_scale,
v_scale,
layer_id_override=layer_id,
)
else:
with self._transfer_id_context(layer):
self.full_kv_pool.set_kv_buffer(
layer,
loc,
cache_k,
cache_v,
)
def get_v_head_dim(self):
return self.full_kv_pool.get_value_buffer(0).shape[-1]
def set_mla_kv_buffer(
self,
layer: RadixAttention,
loc: torch.Tensor,
cache_k_nope: torch.Tensor,
cache_k_rope: torch.Tensor,
):
assert self.use_mla, "set_mla_kv_buffer called when use_mla is False"
with self._transfer_id_context(layer):
self.full_kv_pool.set_mla_kv_buffer(layer, loc, cache_k_nope, cache_k_rope)
def get_mla_kv_buffer(
self,
layer: RadixAttention,
loc: torch.Tensor,
dst_dtype: Optional[torch.dtype] = None,
):
assert self.use_mla, "get_mla_kv_buffer called when use_mla is False"
with self._transfer_id_context(layer):
return self.full_kv_pool.get_mla_kv_buffer(layer, loc, dst_dtype)
class SWAKVPool(KVCache):
"""KV cache with separate pools for full and SWA attention layers."""
......
......@@ -29,7 +29,12 @@ from typing import Callable, List, Optional, Tuple, Union
import torch
import torch.distributed as dist
from sglang.srt.configs import FalconH1Config, NemotronHConfig, Qwen3NextConfig
from sglang.srt.configs import (
FalconH1Config,
KimiLinearConfig,
NemotronHConfig,
Qwen3NextConfig,
)
from sglang.srt.configs.device_config import DeviceConfig
from sglang.srt.configs.load_config import LoadConfig, LoadFormat
from sglang.srt.configs.model_config import (
......@@ -1358,9 +1363,16 @@ class ModelRunner:
return config
return None
@property
def kimi_linear_config(self):
config = self.model_config.hf_config
if isinstance(config, KimiLinearConfig):
return config
return None
@property
def mambaish_config(self):
return self.mamba2_config or self.hybrid_gdn_config
return self.mamba2_config or self.hybrid_gdn_config or self.kimi_linear_config
def set_num_token_hybrid(self):
if (
......@@ -1691,7 +1703,7 @@ class ModelRunner:
end_layer=self.end_layer,
index_head_dim=get_nsa_index_head_dim(self.model_config.hf_config),
)
elif self.use_mla_backend:
elif self.use_mla_backend and not self.mambaish_config:
assert not is_nsa_model
self.token_to_kv_pool = MLATokenToKVPool(
self.max_total_num_tokens,
......@@ -1735,6 +1747,12 @@ class ModelRunner:
device=self.device,
)
elif config := self.mambaish_config:
extra_args = {}
if self.use_mla_backend:
extra_args = {
"kv_lora_rank": self.model_config.kv_lora_rank,
"qk_rope_head_dim": self.model_config.qk_rope_head_dim,
}
self.token_to_kv_pool = HybridLinearKVPool(
page_size=self.page_size,
size=self.max_total_num_tokens,
......@@ -1750,6 +1768,8 @@ class ModelRunner:
enable_kvcache_transpose=False,
device=self.device,
mamba_pool=self.req_to_token_pool.mamba_pool,
use_mla=self.use_mla_backend,
**extra_args,
)
else:
self.token_to_kv_pool = MHATokenToKVPool(
......
......@@ -1075,6 +1075,7 @@ class DeepseekV2AttentionMLA(nn.Module):
layer_id: int = None,
prefix: str = "",
alt_stream: Optional[torch.cuda.Stream] = None,
skip_rope: bool = False,
) -> None:
super().__init__()
self.layer_id = layer_id
......@@ -1182,23 +1183,26 @@ class DeepseekV2AttentionMLA(nn.Module):
)
self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
self.rotary_emb = get_rope_wrapper(
qk_rope_head_dim,
rotary_dim=qk_rope_head_dim,
max_position=max_position_embeddings,
base=rope_theta,
rope_scaling=rope_scaling,
is_neox_style=False,
device=get_global_server_args().device,
)
if not skip_rope:
self.rotary_emb = get_rope_wrapper(
qk_rope_head_dim,
rotary_dim=qk_rope_head_dim,
max_position=max_position_embeddings,
base=rope_theta,
rope_scaling=rope_scaling,
is_neox_style=False,
device=get_global_server_args().device,
)
if rope_scaling:
mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
scaling_factor = rope_scaling["factor"]
mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
self.scaling = self.scaling * mscale * mscale
if rope_scaling:
mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
scaling_factor = rope_scaling["factor"]
mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
self.scaling = self.scaling * mscale * mscale
else:
self.rotary_emb.forward = self.rotary_emb.forward_native
else:
self.rotary_emb.forward = self.rotary_emb.forward_native
self.rotary_emb = None
self.attn_mqa = RadixAttention(
self.num_local_heads,
......@@ -1487,7 +1491,8 @@ class DeepseekV2AttentionMLA(nn.Module):
latent_cache = latent_cache.unsqueeze(1)
kv_a = self.kv_a_layernorm(kv_a)
k_pe = latent_cache[:, :, self.kv_lora_rank :]
q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
if self.rotary_emb is not None:
q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
q[..., self.qk_nope_head_dim :] = q_pe
self._set_mla_kv_buffer(latent_cache, kv_a, k_pe, forward_batch)
......@@ -1646,8 +1651,10 @@ class DeepseekV2AttentionMLA(nn.Module):
q_nope_out = q_nope_out.transpose(0, 1)
if not self._fuse_rope_for_trtllm_mla(forward_batch) and (
not _use_aiter or not _is_gfx95_supported or self.use_nsa
if (
self.rotary_emb is not None
and (not self._fuse_rope_for_trtllm_mla(forward_batch))
and (not _use_aiter or not _is_gfx95_supported or self.use_nsa)
):
q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
......
This diff is collapsed.
......@@ -1028,6 +1028,11 @@ class ServerArgs:
logger.info(
f"Using {self.attention_backend} as attention backend for {model_arch}."
)
elif model_arch in ["KimiLinearForCausalLM"]:
logger.warning(
f"Disabling Radix Cache for {model_arch} as it is not yet supported."
)
self.disable_radix_cache = True
if is_deepseek_nsa(hf_config):
if (
......
......@@ -43,6 +43,7 @@ from sglang.srt.configs import (
DotsVLMConfig,
ExaoneConfig,
FalconH1Config,
KimiLinearConfig,
KimiVLConfig,
LongcatFlashConfig,
MultiModalityConfig,
......@@ -68,6 +69,7 @@ _CONFIG_REGISTRY: List[Type[PretrainedConfig]] = [
Step3VLConfig,
LongcatFlashConfig,
Olmo3Config,
KimiLinearConfig,
Qwen3NextConfig,
FalconH1Config,
DotsVLMConfig,
......
import unittest
from types import SimpleNamespace
from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval
from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
)
class TestKimiLinear(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = "moonshotai/Kimi-Linear-48B-A3B-Instruct"
cls.base_url = DEFAULT_URL_FOR_TEST
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=["--tp-size", "2", "--trust-remote"],
)
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
def test_gsm8k(self):
args = SimpleNamespace(
num_shots=5,
data_path=None,
num_questions=200,
max_new_tokens=512,
parallel=128,
host="http://127.0.0.1",
port=int(self.base_url.split(":")[-1]),
)
metrics = run_eval(args)
print(f"{metrics=}")
self.assertGreater(metrics["accuracy"], 0.88)
if __name__ == "__main__":
unittest.main()
......@@ -151,6 +151,7 @@ suites = {
TestFile("layers/attention/mamba/test_mamba2_mixer.py", 50),
TestFile("lora/test_lora_tp.py", 116),
TestFile("models/test_glm4_moe_models.py", 100),
TestFile("models/test_kimi_linear_models.py", 90),
TestFile("rl/test_update_weights_from_distributed.py", 103),
TestFile("test_data_parallelism.py", 73),
TestFile("test_disaggregation_basic.py", 400),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment