Unverified Commit a4bf5c6a authored by Ke Bao, committed by GitHub

Support Kimi Linear (#12469)


Co-authored-by: yizhang2077 <1109276519@qq.com>
parent 30ad1070
...@@ -6,6 +6,7 @@ from sglang.srt.configs.dots_vlm import DotsVLMConfig
from sglang.srt.configs.exaone import ExaoneConfig
from sglang.srt.configs.falcon_h1 import FalconH1Config
from sglang.srt.configs.janus_pro import MultiModalityConfig
from sglang.srt.configs.kimi_linear import KimiLinearConfig
from sglang.srt.configs.kimi_vl import KimiVLConfig
from sglang.srt.configs.kimi_vl_moonvit import MoonViTConfig
from sglang.srt.configs.longcat_flash import LongcatFlashConfig
...@@ -31,6 +32,7 @@ __all__ = [
"Step3TextConfig",
"Step3VisionEncoderConfig",
"Olmo3Config",
"KimiLinearConfig",
"Qwen3NextConfig",
"DotsVLMConfig",
"DotsOCRConfig",
......
# Adapted from: https://github.com/vllm-project/vllm/blob/0384aa7150c4c9778efca041ffd1beb3ad2bd694/vllm/transformers_utils/configs/kimi_linear.py
from transformers.configuration_utils import PretrainedConfig
from sglang.srt.configs.mamba_utils import KimiLinearCacheParams, KimiLinearStateShape
from sglang.srt.layers.dp_attention import get_attention_tp_size
class KimiLinearConfig(PretrainedConfig):
model_type = "kimi_linear"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
model_type="kimi_linear",
vocab_size=163840,
hidden_size=4096,
head_dim=None,
intermediate_size=11008,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=None,
hidden_act="silu",
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
pad_token_id=0,
bos_token_id=1,
eos_token_id=2,
rope_theta=10000.0,
rope_scaling=None,
tie_word_embeddings=False,
moe_intermediate_size: int | None = None,
moe_renormalize: bool = True,
moe_router_activation_func: str = "sigmoid",
num_experts: int | None = None,
num_experts_per_token: int | None = None,
num_shared_experts: int = 0,
routed_scaling_factor: float = 1.0,
first_k_dense_replace: int = 0,
moe_layer_freq: int = 1,
use_grouped_topk: bool = True,
num_expert_group: int = 1,
topk_group: int = 1,
q_lora_rank: int | None = None,
kv_lora_rank: int | None = None,
qk_nope_head_dim: int | None = None,
qk_rope_head_dim: int | None = None,
v_head_dim: int | None = None,
mla_use_nope: bool | None = False,
num_nextn_predict_layers: int = 0,
linear_attn_config: dict | None = None,
**kwargs,
):
self.model_type = model_type
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.head_dim = (
head_dim if head_dim is not None else hidden_size // num_attention_heads
)
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.q_lora_rank = q_lora_rank
self.kv_lora_rank = kv_lora_rank
self.qk_nope_head_dim = qk_nope_head_dim
self.qk_rope_head_dim = qk_rope_head_dim
self.v_head_dim = v_head_dim
self.mla_use_nope = mla_use_nope
# moe config
self.n_routed_experts = self.num_experts = num_experts
self.num_experts_per_token = num_experts_per_token
self.moe_renormalize = moe_renormalize
self.num_shared_experts = num_shared_experts
self.routed_scaling_factor = routed_scaling_factor
self.moe_router_activation_func = moe_router_activation_func
assert self.moe_router_activation_func in ("softmax", "sigmoid")
self.moe_intermediate_size = moe_intermediate_size
self.first_k_dense_replace = first_k_dense_replace
self.moe_layer_freq = moe_layer_freq
self.use_grouped_topk = use_grouped_topk
self.num_expert_group = num_expert_group
self.topk_group = topk_group
self.num_nextn_predict_layers = num_nextn_predict_layers
if linear_attn_config is not None:
assert linear_attn_config["kda_layers"] is not None
assert linear_attn_config["full_attn_layers"] is not None
self.linear_attn_config = linear_attn_config
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
@property
def is_mla(self):
return (
self.q_lora_rank is not None
or self.kv_lora_rank is not None
or self.qk_nope_head_dim is not None
or self.qk_rope_head_dim is not None
or self.v_head_dim is not None
or self.mla_use_nope is True
)
@property
def is_moe(self):
return self.num_experts is not None
@property
def is_linear_attn(self) -> bool:
return not (
self.linear_attn_config is None
or (
isinstance(self.linear_attn_config, dict)
and self.linear_attn_config["kda_layers"] is not None
and len(self.linear_attn_config["kda_layers"]) == 0
)
)
def is_kda_layer(self, layer_idx: int):
return (
self.linear_attn_config is not None
and (layer_idx + 1) in self.linear_attn_config["kda_layers"]
)
@property
def linear_layer_ids(self):
return [i for i in range(self.num_hidden_layers) if self.is_kda_layer(i)]
@property
def full_attention_layer_ids(self):
return [i for i in range(self.num_hidden_layers) if not self.is_kda_layer(i)]
@property
def mamba2_cache_params(self) -> KimiLinearCacheParams:
shape = KimiLinearStateShape.create(
tp_world_size=get_attention_tp_size(),
num_heads=self.linear_attn_config["num_heads"],
head_dim=self.linear_attn_config["head_dim"],
conv_kernel_size=self.linear_attn_config["short_conv_kernel_size"],
)
return KimiLinearCacheParams(shape=shape, layers=self.linear_layer_ids)
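A note on indexing: `is_kda_layer` tests `(layer_idx + 1)` against `linear_attn_config["kda_layers"]`, so the config lists KDA layers 1-indexed while `linear_layer_ids` / `full_attention_layer_ids` return 0-indexed layer ids. The following self-contained sketch reproduces just that partition logic; the layer count and `linear_attn_config` values are made up for illustration (real checkpoints ship their own):

```python
# Sketch of the layer-partition logic above; numbers are illustrative only.
num_hidden_layers = 8
linear_attn_config = {
    "kda_layers": [1, 2, 3, 5, 6, 7],  # 1-indexed, as consumed by is_kda_layer()
    "full_attn_layers": [4, 8],
}

def is_kda_layer(layer_idx: int) -> bool:
    # Mirrors KimiLinearConfig.is_kda_layer: +1 converts to the config's 1-indexed ids.
    return (layer_idx + 1) in linear_attn_config["kda_layers"]

linear_layer_ids = [i for i in range(num_hidden_layers) if is_kda_layer(i)]
full_attention_layer_ids = [i for i in range(num_hidden_layers) if not is_kda_layer(i)]
print(linear_layer_ids)          # [0, 1, 2, 4, 5, 6]
print(full_attention_layer_ids)  # [3, 7]
```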
...@@ -14,6 +14,7 @@
import os
from dataclasses import dataclass, field
from typing import List, Optional
import numpy as np
import torch
...@@ -115,3 +116,68 @@ class Mamba2CacheParams:
int(np.prod(self.shape.conv)) * self.dtype.conv.itemsize
+ int(np.prod(self.shape.temporal)) * self.dtype.temporal.itemsize
) * len(self.layers)
@dataclass(kw_only=True, frozen=True)
class KimiLinearStateShape:
conv: List[tuple[int, int]]
temporal: tuple[int, int, int]
num_heads: int
head_dim: int
num_k_heads: int
head_k_dim: int
conv_kernel: int
num_spec: int
@staticmethod
def create(
*,
tp_world_size: int,
num_heads: int,
head_dim: int,
num_k_heads: Optional[int] = None,
head_k_dim: Optional[int] = None,
conv_kernel_size: int = 4,
num_spec: int = 0,
) -> "KimiLinearStateShape":
if num_k_heads is None:
num_k_heads = num_heads
if head_k_dim is None:
head_k_dim = head_dim
proj_size = num_heads * head_dim
proj_k_size = num_k_heads * head_k_dim
conv_state_shape = (divide(proj_size, tp_world_size), conv_kernel_size - 1)
conv_state_k_shape = (divide(proj_k_size, tp_world_size), conv_kernel_size - 1)
temporal_state_shape = (divide(num_heads, tp_world_size), head_dim, head_dim)
conv_state_shape = conv_state_shape[1], conv_state_shape[0]
conv_state_k_shape = conv_state_k_shape[1], conv_state_k_shape[0]
return KimiLinearStateShape(
conv=[conv_state_shape, conv_state_k_shape, conv_state_k_shape],
temporal=temporal_state_shape,
num_heads=num_heads,
head_dim=head_dim,
num_k_heads=num_k_heads,
head_k_dim=head_k_dim,
conv_kernel=conv_kernel_size,
num_spec=num_spec,
)
@dataclass(kw_only=True, frozen=True)
class KimiLinearCacheParams:
shape: KimiLinearStateShape
dtype: Mamba2StateDType = field(default_factory=mamba2_state_dtype)
layers: list[int]
@property
def mamba_cache_per_req(self) -> int:
return (
int(np.sum([np.prod(conv_shape) for conv_shape in self.shape.conv]))
* self.dtype.conv.itemsize
+ int(np.prod(self.shape.temporal)) * self.dtype.temporal.itemsize
) * len(self.layers)
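For sizing intuition, `mamba_cache_per_req` sums three conv buffers (one per q/k/v projection, each of shape `(conv_kernel - 1, proj_dim / tp)` after the swap in `create`) plus one recurrent state of shape `(num_heads / tp, head_dim, head_dim)` per KDA layer. A back-of-the-envelope sketch; every size, the TP degree, and the dtypes below are assumptions, since the real values come from the checkpoint's `linear_attn_config` and from `mamba2_state_dtype()`:

```python
import numpy as np

# Rough per-request cache estimate mirroring mamba_cache_per_req above.
tp_world_size = 2
num_heads, head_dim = 32, 128
conv_kernel_size = 4
num_kda_layers = 24

proj_dim = num_heads * head_dim // tp_world_size
conv_shape = (conv_kernel_size - 1, proj_dim)                      # one of three conv buffers
temporal_shape = (num_heads // tp_world_size, head_dim, head_dim)  # recurrent KDA state

conv_itemsize = np.dtype(np.float32).itemsize      # assumed fp32 conv state
temporal_itemsize = np.dtype(np.float32).itemsize  # assumed fp32 recurrent state

per_layer = (
    3 * int(np.prod(conv_shape)) * conv_itemsize
    + int(np.prod(temporal_shape)) * temporal_itemsize
)
print(f"{per_layer * num_kda_layers / 2**20:.1f} MiB per request")
```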
...@@ -366,6 +366,13 @@ class ModelConfig:
self.qk_rope_head_dim = self.hf_text_config.qk_rope_head_dim
self.v_head_dim = self.hf_text_config.v_head_dim
self.qk_nope_head_dim = self.hf_text_config.qk_nope_head_dim
elif "KimiLinearForCausalLM" in self.hf_config.architectures:
self.head_dim = 72
self.attention_arch = AttentionArch.MLA
self.kv_lora_rank = self.hf_config.kv_lora_rank
self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim
self.v_head_dim = self.hf_config.v_head_dim
self.qk_nope_head_dim = self.hf_config.qk_nope_head_dim
else:
if (
"MistralModel" in self.hf_config.architectures
......
...@@ -189,6 +189,7 @@ def attn_backend_wrapper(runner: "ModelRunner", full_attn_backend: "AttentionBac
from sglang.srt.layers.attention.hybrid_linear_attn_backend import (
GDNAttnBackend,
HybridLinearAttnBackend,
KimiLinearAttnBackend,
Mamba2AttnBackend,
)
from sglang.srt.utils import is_blackwell, is_npu
...@@ -207,6 +208,8 @@ def attn_backend_wrapper(runner: "ModelRunner", full_attn_backend: "AttentionBac
linear_attn_backend = GDNAttnBackend(runner)
elif runner.mamba2_config is not None:
linear_attn_backend = Mamba2AttnBackend(runner)
elif runner.kimi_linear_config is not None:
linear_attn_backend = KimiLinearAttnBackend(runner)
else:
raise ValueError(
"Expected hybrid GDN or NemotronH models, but got unknown model."
......
...@@ -21,6 +21,7 @@ NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8, 16]
@triton.heuristics(
{
"USE_G": lambda args: args["g"] is not None,
"USE_GK": lambda args: args["gk"] is not None,
"USE_INITIAL_STATE": lambda args: args["h0"] is not None,
"STORE_FINAL_STATE": lambda args: args["ht"] is not None,
"SAVE_NEW_VALUE": lambda args: args["v_new"] is not None,
...@@ -44,6 +45,7 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
w,
v_new,
g,
gk,
h,
h0,
ht,
...@@ -57,6 +59,7 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
BT: tl.constexpr,
BV: tl.constexpr,
USE_G: tl.constexpr,
USE_GK: tl.constexpr,
USE_INITIAL_STATE: tl.constexpr,
STORE_FINAL_STATE: tl.constexpr,
SAVE_NEW_VALUE: tl.constexpr,
...@@ -86,12 +89,12 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
b_h4 = tl.zeros([64, BV], dtype=tl.float32)
# calculate offset
h += (boh * H + i_h) * K * V h += ((boh * H + i_h) * K * V).to(tl.int64)
v += (bos * H + i_h) * V v += ((bos * H + i_h) * V).to(tl.int64)
k += (bos * Hg + i_h // (H // Hg)) * K k += ((bos * Hg + i_h // (H // Hg)) * K).to(tl.int64)
w += (bos * H + i_h) * K w += ((bos * H + i_h) * K).to(tl.int64)
if SAVE_NEW_VALUE: if SAVE_NEW_VALUE:
v_new += (bos * H + i_h) * V v_new += ((bos * H + i_h) * V).to(tl.int64)
stride_v = H * V
stride_h = H * K * V
stride_k = Hg * K
...@@ -143,58 +146,48 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
)
tl.store(p_h4, b_h4.to(p_h4.dtype.element_ty), boundary_check=(0, 1))
p_v = tl.make_block_ptr(
v, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
)
p_v_new = (
tl.make_block_ptr(
v_new, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
)
if SAVE_NEW_VALUE
else None
)
b_v_new = tl.zeros([BT, BV], dtype=tl.float32)
p_w = tl.make_block_ptr( p_w = tl.make_block_ptr(
w, (T, K), (stride_w, 1), (i_t * BT, 0), (BT, 64), (1, 0) w, (T, K), (stride_w, 1), (i_t * BT, 0), (BT, 64), (1, 0)
) )
b_w = tl.load(p_w, boundary_check=(0, 1)) b_w = tl.load(p_w, boundary_check=(0, 1))
b_v_new += tl.dot(b_w, b_h1.to(b_w.dtype)) b_v = tl.dot(b_w, b_h1.to(b_w.dtype))
if K > 64: if K > 64:
p_w = tl.make_block_ptr( p_w = tl.make_block_ptr(
w, (T, K), (stride_w, 1), (i_t * BT, 64), (BT, 64), (1, 0) w, (T, K), (stride_w, 1), (i_t * BT, 64), (BT, 64), (1, 0)
) )
b_w = tl.load(p_w, boundary_check=(0, 1)) b_w = tl.load(p_w, boundary_check=(0, 1))
b_v_new += tl.dot(b_w, b_h2.to(b_w.dtype)) b_v += tl.dot(b_w, b_h2.to(b_w.dtype))
if K > 128: if K > 128:
p_w = tl.make_block_ptr( p_w = tl.make_block_ptr(
w, (T, K), (stride_w, 1), (i_t * BT, 128), (BT, 64), (1, 0) w, (T, K), (stride_w, 1), (i_t * BT, 128), (BT, 64), (1, 0)
) )
b_w = tl.load(p_w, boundary_check=(0, 1)) b_w = tl.load(p_w, boundary_check=(0, 1))
b_v_new += tl.dot(b_w, b_h3.to(b_w.dtype)) b_v += tl.dot(b_w, b_h3.to(b_w.dtype))
if K > 192: if K > 192:
p_w = tl.make_block_ptr( p_w = tl.make_block_ptr(
w, (T, K), (stride_w, 1), (i_t * BT, 192), (BT, 64), (1, 0) w, (T, K), (stride_w, 1), (i_t * BT, 192), (BT, 64), (1, 0)
) )
b_w = tl.load(p_w, boundary_check=(0, 1)) b_w = tl.load(p_w, boundary_check=(0, 1))
b_v_new += tl.dot(b_w, b_h4.to(b_w.dtype)) b_v += tl.dot(b_w, b_h4.to(b_w.dtype))
b_v_new = -b_v_new + tl.load(p_v, boundary_check=(0, 1)) p_v = tl.make_block_ptr(
v, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
)
b_v = tl.load(p_v, boundary_check=(0, 1)) - b_v
if SAVE_NEW_VALUE: if SAVE_NEW_VALUE:
p_v_new = tl.make_block_ptr( p_v = tl.make_block_ptr(
v_new, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0) v_new, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
) )
tl.store( tl.store(p_v, b_v.to(p_v.dtype.element_ty), boundary_check=(0, 1))
p_v_new, b_v_new.to(p_v_new.dtype.element_ty), boundary_check=(0, 1)
)
last_idx = min((i_t + 1) * BT, T) - 1
if USE_G: if USE_G:
last_idx = min((i_t + 1) * BT, T) - 1
b_g_last = tl.load(g + bos * H + last_idx * H + i_h) b_g_last = tl.load(g + bos * H + last_idx * H + i_h)
p_g = tl.make_block_ptr( p_g = tl.make_block_ptr(
g + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,) g + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)
) )
b_g = tl.load(p_g, boundary_check=(0,)) b_g = tl.load(p_g, boundary_check=(0,))
b_v_new = b_v_new * safe_exp(b_g_last - b_g)[:, None] b_v = b_v * safe_exp(b_g_last - b_g)[:, None]
b_g_last = exp(b_g_last) b_g_last = exp(b_g_last)
b_h1 = b_h1 * b_g_last b_h1 = b_h1 * b_g_last
if K > 64: if K > 64:
...@@ -203,30 +196,64 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64( ...@@ -203,30 +196,64 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
b_h3 = b_h3 * b_g_last b_h3 = b_h3 * b_g_last
if K > 192: if K > 192:
b_h4 = b_h4 * b_g_last b_h4 = b_h4 * b_g_last
b_v_new = b_v_new.to(k.dtype.element_ty)
if USE_GK:
o_k1 = tl.arange(0, 64)
b_gk_last1 = tl.load(
gk + (bos + last_idx) * H * K + i_h * K + o_k1,
mask=(o_k1 < K),
other=0.0,
)
b_h1 *= exp(b_gk_last1)[:, None]
if K > 64:
o_k2 = 64 + o_k1
b_gk_last2 = tl.load(
gk + (bos + last_idx) * H * K + i_h * K + o_k2,
mask=(o_k2 < K),
other=0.0,
)
b_h2 *= exp(b_gk_last2)[:, None]
if K > 128:
o_k3 = 128 + o_k1
b_gk_last3 = tl.load(
gk + (bos + last_idx) * H * K + i_h * K + o_k3,
mask=(o_k3 < K),
other=0.0,
)
b_h3 *= exp(b_gk_last3)[:, None]
if K > 192:
o_k4 = 192 + o_k1
b_gk_last4 = tl.load(
gk + (bos + last_idx) * H * K + i_h * K + o_k4,
mask=(o_k4 < K),
other=0.0,
)
b_h4 *= exp(b_gk_last4)[:, None]
b_v = b_v.to(k.dtype.element_ty)
p_k = tl.make_block_ptr( p_k = tl.make_block_ptr(
k, (K, T), (1, stride_k), (0, i_t * BT), (64, BT), (0, 1) k, (K, T), (1, stride_k), (0, i_t * BT), (64, BT), (0, 1)
) )
b_k = tl.load(p_k, boundary_check=(0, 1)) b_k = tl.load(p_k, boundary_check=(0, 1))
b_h1 += tl.dot(b_k, b_v_new) b_h1 += tl.dot(b_k, b_v)
if K > 64: if K > 64:
p_k = tl.make_block_ptr( p_k = tl.make_block_ptr(
k, (K, T), (1, stride_k), (64, i_t * BT), (64, BT), (0, 1) k, (K, T), (1, stride_k), (64, i_t * BT), (64, BT), (0, 1)
) )
b_k = tl.load(p_k, boundary_check=(0, 1)) b_k = tl.load(p_k, boundary_check=(0, 1))
b_h2 += tl.dot(b_k, b_v_new) b_h2 += tl.dot(b_k, b_v)
if K > 128: if K > 128:
p_k = tl.make_block_ptr( p_k = tl.make_block_ptr(
k, (K, T), (1, stride_k), (128, i_t * BT), (64, BT), (0, 1) k, (K, T), (1, stride_k), (128, i_t * BT), (64, BT), (0, 1)
) )
b_k = tl.load(p_k, boundary_check=(0, 1)) b_k = tl.load(p_k, boundary_check=(0, 1))
b_h3 += tl.dot(b_k, b_v_new) b_h3 += tl.dot(b_k, b_v)
if K > 192: if K > 192:
p_k = tl.make_block_ptr( p_k = tl.make_block_ptr(
k, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1) k, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1)
) )
b_k = tl.load(p_k, boundary_check=(0, 1)) b_k = tl.load(p_k, boundary_check=(0, 1))
b_h4 += tl.dot(b_k, b_v_new) b_h4 += tl.dot(b_k, b_v)
# epilogue
if STORE_FINAL_STATE:
...@@ -254,6 +281,7 @@ def chunk_gated_delta_rule_fwd_h(
w: torch.Tensor,
u: torch.Tensor,
g: Optional[torch.Tensor] = None,
gk: Optional[torch.Tensor] = None,
initial_state: Optional[torch.Tensor] = None,
output_final_state: bool = False,
chunk_size: int = 64,  # SY: remove this argument and force chunk size 64?
...@@ -296,6 +324,7 @@ def chunk_gated_delta_rule_fwd_h(
w=w,
v_new=v_new,
g=g,
gk=gk,
h=h,
h0=initial_state,
ht=final_state,
......
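The new `gk` input adds a per-key (head-dimension-wise) decay on top of the existing per-head gate `g`: after each chunk, the carried state is additionally scaled by `exp(gk_last)` along its key dimension. The sketch below is a simplified, single-head, single-chunk reading of that state update (no K/V blocking, no boundary handling), not a drop-in equivalent of the Triton kernel:

```python
import torch

torch.manual_seed(0)
T, K, V = 16, 64, 64                       # chunk length, key dim, value dim
k = torch.randn(T, K)
w = torch.randn(T, K)
v = torch.randn(T, V)
h = torch.randn(K, V)                      # state carried in from the previous chunk
g = -torch.rand(T).cumsum(0) * 0.05        # per-token log-decay (USE_G path)
gk = -torch.rand(K) * 0.05                 # per-key log-decay at the chunk boundary (USE_GK path)

v_new = v - w @ h                               # b_v = v - w @ h
v_new = v_new * torch.exp(g[-1] - g)[:, None]   # USE_G: bring rows to the chunk boundary
h = h * torch.exp(g[-1])                        # USE_G: decay the carried state
h = h * torch.exp(gk)[:, None]                  # USE_GK: extra per-key decay of the state
h = h + k.t() @ v_new                           # accumulate this chunk's contribution
print(h.shape)                                  # torch.Size([64, 64])
```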
...@@ -44,6 +44,7 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
IS_BETA_HEADWISE: tl.constexpr,  # whether beta is headwise vector or scalar,
USE_QK_L2NORM_IN_KERNEL: tl.constexpr,
IS_VARLEN: tl.constexpr,
IS_KDA: tl.constexpr,
):
i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
i_n, i_hv = i_nh // HV, i_nh % HV
...@@ -67,7 +68,11 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
p_beta = beta + (bos * HV + i_hv) * V + o_v
else:
p_beta = beta + bos * HV + i_hv
p_g = g + bos * HV + i_hv if not IS_KDA:
p_g = g + bos * HV + i_hv
else:
p_gk = g + (bos * HV + i_hv) * K + o_k
p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v
mask_k = o_k < K
...@@ -83,14 +88,18 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32)
b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
b_g = tl.load(p_g).to(tl.float32)
if USE_QK_L2NORM_IN_KERNEL:
b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q) + 1e-6))
b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k) + 1e-6))
b_q = b_q * scale
# [BK, BV]
b_h *= exp(b_g) if not IS_KDA:
b_g = tl.load(p_g).to(tl.float32)
b_h *= exp(b_g)
else:
b_gk = tl.load(p_gk).to(tl.float32)
b_h *= exp(b_gk[:, None])
# [BV]
b_v -= tl.sum(b_h * b_k[:, None], 0)
if IS_BETA_HEADWISE:
...@@ -108,7 +117,10 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
p_k += H * K
p_o += HV * V
p_v += HV * V
p_g += HV if not IS_KDA:
p_g += HV
else:
p_gk += HV * K
p_beta += HV * (V if IS_BETA_HEADWISE else 1)
if STORE_FINAL_STATE:
...@@ -165,6 +177,7 @@ def fused_recurrent_gated_delta_rule_fwd(
BV=BV,
IS_BETA_HEADWISE=beta.ndim == v.ndim,
USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel,
IS_KDA=False,
num_warps=num_warps,
num_stages=num_stages,
)
......
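In the decode-time kernel the change is analogous: with `IS_KDA`, the scalar per-head decay `exp(g)` applied to the recurrent state is replaced by a per-key decay `exp(gk)`. A naive single-head, single-token reference of that step on one reading of the kernel (scalar beta, i.e. `IS_BETA_HEADWISE=False`):

```python
import torch

torch.manual_seed(0)
K, V = 64, 64
h = torch.zeros(K, V)                    # recurrent state for one head
q = torch.randn(K)
k = torch.randn(K)
v = torch.randn(V)
gk = -torch.rand(K) * 0.05               # per-key log-decay for this token (IS_KDA path)
beta = 0.9                               # scalar beta for illustration
scale = K ** -0.5

q = q / torch.sqrt((q * q).sum() + 1e-6) * scale   # USE_QK_L2NORM_IN_KERNEL
k = k / torch.sqrt((k * k).sum() + 1e-6)

h = h * torch.exp(gk)[:, None]           # per-key decay instead of exp(g)
v = v - (h * k[:, None]).sum(0)          # delta-rule residual
v = v * beta
h = h + k[:, None] * v[None, :]          # rank-1 state update
o = (h * q[:, None]).sum(0)              # output for this token
print(o.shape)                           # torch.Size([64])
```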
from typing import Optional, Union
import torch
from einops import rearrange
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.attention.fla.chunk import chunk_gated_delta_rule
...@@ -10,6 +11,11 @@ from sglang.srt.layers.attention.fla.fused_recurrent import (
from sglang.srt.layers.attention.fla.fused_sigmoid_gating_recurrent import (
fused_sigmoid_gating_delta_rule_update,
)
from sglang.srt.layers.attention.fla.kda import (
chunk_kda,
fused_kda_gate,
fused_recurrent_kda,
)
from sglang.srt.layers.attention.mamba.causal_conv1d_triton import (
PAD_SLOT_ID,
causal_conv1d_fn,
...@@ -227,6 +233,223 @@ class MambaAttnBackendBase(AttentionBackend):
return 1  # Mamba attn does not use seq lens to index kv cache
class KimiLinearAttnBackend(MambaAttnBackendBase):
"""Attention backend using Mamba kernel."""
def forward_decode(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache: bool = True,
**kwargs,
):
q_proj_states = kwargs["q_proj_states"]
k_proj_states = kwargs["k_proj_states"]
v_proj_states = kwargs["v_proj_states"]
q_conv_weights = kwargs["q_conv_weights"]
k_conv_weights = kwargs["k_conv_weights"]
v_conv_weights = kwargs["v_conv_weights"]
q_conv_bias = kwargs["q_conv_bias"]
k_conv_bias = kwargs["k_conv_bias"]
v_conv_bias = kwargs["v_conv_bias"]
A_log = kwargs["A_log"]
dt_bias = kwargs["dt_bias"]
b_proj = kwargs["b_proj"]
f_a_proj = kwargs["f_a_proj"]
f_b_proj = kwargs["f_b_proj"]
hidden_states = kwargs["hidden_states"]
head_dim = kwargs["head_dim"]
layer_id = kwargs["layer_id"]
layer_cache = self.req_to_token_pool.mamba2_layer_cache(layer_id)
q_conv_state, k_conv_state, v_conv_state = layer_cache.conv
ssm_states = layer_cache.temporal
query_start_loc = self.forward_metadata.query_start_loc
cache_indices = self.forward_metadata.mamba_cache_indices
q_conv_state = q_conv_state.transpose(-1, -2)
k_conv_state = k_conv_state.transpose(-1, -2)
v_conv_state = v_conv_state.transpose(-1, -2)
q = causal_conv1d_update(
q_proj_states,
q_conv_state,
q_conv_weights,
q_conv_bias,
activation="silu",
conv_state_indices=cache_indices,
)
k = causal_conv1d_update(
k_proj_states,
k_conv_state,
k_conv_weights,
k_conv_bias,
activation="silu",
conv_state_indices=cache_indices,
)
v = causal_conv1d_update(
v_proj_states,
v_conv_state,
v_conv_weights,
v_conv_bias,
activation="silu",
conv_state_indices=cache_indices,
)
q, k, v = map(
lambda x: rearrange(x, "n (h d) -> 1 n h d", d=head_dim), (q, k, v)
)
beta = b_proj(hidden_states)[0].float().sigmoid()
g = f_b_proj(f_a_proj(hidden_states)[0])[0]
g = fused_kda_gate(g, A_log, head_dim, g_bias=dt_bias)
beta = beta.unsqueeze(0)
g = g.unsqueeze(0)
initial_state = ssm_states[cache_indices].contiguous()
(
core_attn_out,
last_recurrent_state,
) = fused_recurrent_kda(
q=q,
k=k,
v=v,
g=g,
beta=beta,
initial_state=initial_state,
use_qk_l2norm_in_kernel=True,
cu_seqlens=query_start_loc,
)
ssm_states[cache_indices] = last_recurrent_state
return core_attn_out
def forward_extend(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache: bool = True,
**kwargs,
):
from sglang.srt.layers.attention.mamba.causal_conv1d_triton import (
causal_conv1d_fn,
)
q_proj_states = kwargs["q_proj_states"]
k_proj_states = kwargs["k_proj_states"]
v_proj_states = kwargs["v_proj_states"]
q_conv_weights = kwargs["q_conv_weights"]
k_conv_weights = kwargs["k_conv_weights"]
v_conv_weights = kwargs["v_conv_weights"]
q_conv_bias = kwargs["q_conv_bias"]
k_conv_bias = kwargs["k_conv_bias"]
v_conv_bias = kwargs["v_conv_bias"]
A_log = kwargs["A_log"]
dt_bias = kwargs["dt_bias"]
b_proj = kwargs["b_proj"]
f_a_proj = kwargs["f_a_proj"]
f_b_proj = kwargs["f_b_proj"]
hidden_states = kwargs["hidden_states"]
head_dim = kwargs["head_dim"]
layer_id = kwargs["layer_id"]
query_start_loc = self.forward_metadata.query_start_loc
cache_indices = self.forward_metadata.mamba_cache_indices
mamba_cache_params = self.req_to_token_pool.mamba2_layer_cache(layer_id)
conv_state_q, conv_state_k, conv_state_v = mamba_cache_params.conv
# deal with strides
conv_state_q = conv_state_q.transpose(-1, -2)
conv_state_k = conv_state_k.transpose(-1, -2)
conv_state_v = conv_state_v.transpose(-1, -2)
ssm_states = mamba_cache_params.temporal
has_initial_state = forward_batch.extend_prefix_lens > 0
q_proj_states = q_proj_states.transpose(0, 1)
k_proj_states = k_proj_states.transpose(0, 1)
v_proj_states = v_proj_states.transpose(0, 1)
q = causal_conv1d_fn(
q_proj_states,
q_conv_weights,
q_conv_bias,
activation="silu",
conv_states=conv_state_q,
has_initial_state=has_initial_state,
cache_indices=cache_indices,
query_start_loc=query_start_loc,
seq_lens_cpu=forward_batch.extend_seq_lens_cpu,
).transpose(0, 1)
k = causal_conv1d_fn(
k_proj_states,
k_conv_weights,
k_conv_bias,
activation="silu",
conv_states=conv_state_k,
has_initial_state=has_initial_state,
cache_indices=cache_indices,
query_start_loc=query_start_loc,
seq_lens_cpu=forward_batch.extend_seq_lens_cpu,
).transpose(0, 1)
v = causal_conv1d_fn(
v_proj_states,
v_conv_weights,
v_conv_bias,
activation="silu",
conv_states=conv_state_v,
has_initial_state=has_initial_state,
cache_indices=cache_indices,
query_start_loc=query_start_loc,
seq_lens_cpu=forward_batch.extend_seq_lens_cpu,
).transpose(0, 1)
q, k, v = map(
lambda x: rearrange(x, "n (h d) -> 1 n h d", d=head_dim), (q, k, v)
)
beta = b_proj(hidden_states)[0].float().sigmoid()
g = f_b_proj(f_a_proj(hidden_states)[0])[0]
g = fused_kda_gate(g, A_log, head_dim, g_bias=dt_bias)
beta = beta.unsqueeze(0)
g = g.unsqueeze(0)
initial_state = ssm_states[cache_indices].contiguous()
(
core_attn_out,
last_recurrent_state,
) = chunk_kda(
q=q,
k=k,
v=v,
g=g,
beta=beta,
initial_state=initial_state,
output_final_state=True,
use_qk_l2norm_in_kernel=True,
cu_seqlens=query_start_loc,
)
ssm_states[cache_indices] = last_recurrent_state
return core_attn_out
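Both paths follow the same recipe: run the short causal conv over the projected q/k/v (per-step `causal_conv1d_update` in decode, `causal_conv1d_fn` with `has_initial_state` in extend), reshape to the `(batch=1, tokens, heads, head_dim)` layout expected by the FLA kernels, and dispatch to `fused_recurrent_kda` (decode) or `chunk_kda` (prefill/extend), persisting the final recurrent state back into `ssm_states`. The `rearrange` call is just a reshape plus a leading batch dimension, as this small sketch with made-up sizes shows:

```python
import torch
from einops import rearrange

n_tokens, num_heads, head_dim = 10, 16, 128   # illustrative sizes
x = torch.randn(n_tokens, num_heads * head_dim)

a = rearrange(x, "n (h d) -> 1 n h d", d=head_dim)
b = x.view(n_tokens, num_heads, head_dim).unsqueeze(0)
assert torch.equal(a, b)
print(a.shape)   # torch.Size([1, 10, 16, 128])
```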
class GDNAttnBackend(MambaAttnBackendBase):
"""Attention backend using Mamba kernel."""
......
...@@ -92,7 +92,10 @@ class TritonAttnBackend(AttentionBackend):
self.num_kv_head = model_runner.model_config.get_num_kv_heads(
get_attention_tp_size()
)
if model_runner.hybrid_gdn_config is not None: if (
model_runner.hybrid_gdn_config is not None
or model_runner.kimi_linear_config is not None
):
# For hybrid linear models, layer_id = 0 may not be full attention
self.v_head_dim = model_runner.token_to_kv_pool.get_v_head_dim()
else:
......
...@@ -17,7 +17,7 @@ from __future__ import annotations ...@@ -17,7 +17,7 @@ from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
from sglang.srt.configs.mamba_utils import Mamba2CacheParams from sglang.srt.configs.mamba_utils import KimiLinearCacheParams, Mamba2CacheParams
from sglang.srt.layers.attention.nsa import index_buf_accessor from sglang.srt.layers.attention.nsa import index_buf_accessor
from sglang.srt.layers.attention.nsa.quant_k_cache import quantize_k_cache from sglang.srt.layers.attention.nsa.quant_k_cache import quantize_k_cache
from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter
...@@ -33,7 +33,7 @@ KVCache actually holds the physical kv cache. ...@@ -33,7 +33,7 @@ KVCache actually holds the physical kv cache.
import abc import abc
import logging import logging
from contextlib import nullcontext from contextlib import contextmanager, nullcontext
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import numpy as np import numpy as np
...@@ -59,7 +59,9 @@ if _is_npu: ...@@ -59,7 +59,9 @@ if _is_npu:
import torch_npu import torch_npu
def get_tensor_size_bytes(t: torch.Tensor): def get_tensor_size_bytes(t: Union[torch.Tensor, List[torch.Tensor]]):
if isinstance(t, list):
return sum(get_tensor_size_bytes(x) for x in t)
return np.prod(t.shape) * t.dtype.itemsize return np.prod(t.shape) * t.dtype.itemsize
...@@ -116,10 +118,15 @@ class ReqToTokenPool: ...@@ -116,10 +118,15 @@ class ReqToTokenPool:
class MambaPool: class MambaPool:
@dataclass(frozen=True, kw_only=True) @dataclass(frozen=True, kw_only=True)
class State: class State:
conv: torch.Tensor conv: Union[torch.Tensor, List[torch.Tensor]]
temporal: torch.Tensor temporal: torch.Tensor
def at_layer_idx(self, layer: int): def at_layer_idx(self, layer: int):
if isinstance(self.conv, list):
return type(self)(
conv=[v[layer] for v in self.conv],
temporal=self.temporal[layer],
)
return type(self)(**{k: v[layer] for k, v in vars(self).items()}) return type(self)(**{k: v[layer] for k, v in vars(self).items()})
def mem_usage_bytes(self): def mem_usage_bytes(self):
...@@ -127,14 +134,14 @@ class MambaPool: ...@@ -127,14 +134,14 @@ class MambaPool:
@dataclass(frozen=True, kw_only=True) @dataclass(frozen=True, kw_only=True)
class SpeculativeState(State): class SpeculativeState(State):
intermediate_ssm: torch.Tensor intermediate_ssm: Union[torch.Tensor, List[torch.Tensor]]
intermediate_conv_window: torch.Tensor intermediate_conv_window: torch.Tensor
def __init__( def __init__(
self, self,
*, *,
size: int, size: int,
cache_params: "Mamba2CacheParams", cache_params: Union["Mamba2CacheParams", "KimiLinearCacheParams"],
device: str, device: str,
speculative_num_draft_tokens: Optional[int] = None, speculative_num_draft_tokens: Optional[int] = None,
): ):
...@@ -157,18 +164,29 @@ class MambaPool: ...@@ -157,18 +164,29 @@ class MambaPool:
else: else:
self.custom_mem_pool = None self.custom_mem_pool = None
self.is_kda_cache = isinstance(cache_params, KimiLinearCacheParams)
with ( with (
torch.cuda.use_mem_pool(self.custom_mem_pool) torch.cuda.use_mem_pool(self.custom_mem_pool)
if self.enable_custom_mem_pool if self.enable_custom_mem_pool
else nullcontext() else nullcontext()
): ):
# assume conv_state = (dim, state_len) if self.is_kda_cache:
assert conv_state_shape[0] > conv_state_shape[1] conv_state = [
conv_state = torch.zeros( torch.zeros(
size=(num_mamba_layers, size + 1) + conv_state_shape, size=(num_mamba_layers, size + 1) + conv_shape,
dtype=conv_dtype, dtype=conv_dtype,
device=device, device=device,
) )
for conv_shape in conv_state_shape
]
else:
# assume conv_state = (dim, state_len)
assert conv_state_shape[0] > conv_state_shape[1]
conv_state = torch.zeros(
size=(num_mamba_layers, size + 1) + conv_state_shape,
dtype=conv_dtype,
device=device,
)
temporal_state = torch.zeros( temporal_state = torch.zeros(
size=(num_mamba_layers, size + 1) + temporal_state_shape, size=(num_mamba_layers, size + 1) + temporal_state_shape,
dtype=ssm_dtype, dtype=ssm_dtype,
...@@ -191,17 +209,34 @@ class MambaPool: ...@@ -191,17 +209,34 @@ class MambaPool:
) )
# Cache intermediate conv windows (last K-1 inputs) per draft token during target verify # Cache intermediate conv windows (last K-1 inputs) per draft token during target verify
# Shape: [num_layers, size + 1, speculative_num_draft_tokens, dim, K-1] # Shape: [num_layers, size + 1, speculative_num_draft_tokens, dim, K-1]
intermediate_conv_window_cache = torch.zeros(
size=( if self.is_kda_cache:
num_mamba_layers, intermediate_conv_window_cache = [
size + 1, torch.zeros(
speculative_num_draft_tokens, size=(
conv_state_shape[0], num_mamba_layers,
conv_state_shape[1], size + 1,
), speculative_num_draft_tokens,
dtype=conv_dtype, conv_shape[0],
device="cuda", conv_shape[1],
) ),
dtype=conv_dtype,
device="cuda",
)
for conv_shape in conv_state_shape
]
else:
intermediate_conv_window_cache = torch.zeros(
size=(
num_mamba_layers,
size + 1,
speculative_num_draft_tokens,
conv_state_shape[0],
conv_state_shape[1],
),
dtype=conv_dtype,
device="cuda",
)
self.mamba_cache = self.SpeculativeState( self.mamba_cache = self.SpeculativeState(
conv=conv_state, conv=conv_state,
temporal=temporal_state, temporal=temporal_state,
...@@ -255,15 +290,25 @@ class MambaPool: ...@@ -255,15 +290,25 @@ class MambaPool:
if free_index.numel() == 0: if free_index.numel() == 0:
return return
self.free_slots = torch.cat((self.free_slots, free_index)) self.free_slots = torch.cat((self.free_slots, free_index))
self.mamba_cache.conv[:, free_index] = self.mamba_cache.temporal[ if self.is_kda_cache:
:, free_index for i in range(len(self.mamba_cache.conv)):
] = 0 self.mamba_cache.conv[i][:, free_index] = 0
else:
self.mamba_cache.conv[:, free_index] = 0
self.mamba_cache.temporal[:, free_index] = 0
def clear(self): def clear(self):
self.free_slots = torch.arange(self.size, dtype=torch.int64, device=self.device) self.free_slots = torch.arange(self.size, dtype=torch.int64, device=self.device)
def copy_from(self, src_index: torch.Tensor, dst_index: torch.Tensor): def copy_from(self, src_index: torch.Tensor, dst_index: torch.Tensor):
self.mamba_cache.conv[:, dst_index] = self.mamba_cache.conv[:, src_index] if self.is_kda_cache:
for i in range(len(self.mamba_cache.conv)):
self.mamba_cache.conv[i][:, dst_index] = self.mamba_cache.conv[i][
:, src_index
]
else:
self.mamba_cache.conv[:, dst_index] = self.mamba_cache.conv[:, src_index]
self.mamba_cache.temporal[:, dst_index] = self.mamba_cache.temporal[ self.mamba_cache.temporal[:, dst_index] = self.mamba_cache.temporal[
:, src_index :, src_index
] ]
...@@ -304,7 +349,7 @@ class HybridReqToTokenPool(ReqToTokenPool): ...@@ -304,7 +349,7 @@ class HybridReqToTokenPool(ReqToTokenPool):
max_context_len: int, max_context_len: int,
device: str, device: str,
enable_memory_saver: bool, enable_memory_saver: bool,
cache_params: "Mamba2CacheParams", cache_params: Union["Mamba2CacheParams", "KimiLinearCacheParams"],
speculative_num_draft_tokens: int = None, speculative_num_draft_tokens: int = None,
): ):
super().__init__( super().__init__(
...@@ -323,7 +368,7 @@ class HybridReqToTokenPool(ReqToTokenPool): ...@@ -323,7 +368,7 @@ class HybridReqToTokenPool(ReqToTokenPool):
def _init_mamba_pool( def _init_mamba_pool(
self, self,
size: int, size: int,
cache_params: "Mamba2CacheParams", cache_params: Union["Mamba2CacheParams", "KimiLinearCacheParams"],
device: str, device: str,
speculative_num_draft_tokens: int = None, speculative_num_draft_tokens: int = None,
): ):
...@@ -812,6 +857,10 @@ class HybridLinearKVPool(KVCache): ...@@ -812,6 +857,10 @@ class HybridLinearKVPool(KVCache):
enable_kvcache_transpose: bool, enable_kvcache_transpose: bool,
device: str, device: str,
mamba_pool: MambaPool, mamba_pool: MambaPool,
# TODO: refactor mla related args
use_mla: bool = False,
kv_lora_rank: int = None,
qk_rope_head_dim: int = None,
): ):
self.size = size self.size = size
self.dtype = dtype self.dtype = dtype
...@@ -825,25 +874,42 @@ class HybridLinearKVPool(KVCache): ...@@ -825,25 +874,42 @@ class HybridLinearKVPool(KVCache):
self.mamba_pool = mamba_pool self.mamba_pool = mamba_pool
# TODO MHATransposedTokenToKVPool if enable_kvcache_transpose is True # TODO MHATransposedTokenToKVPool if enable_kvcache_transpose is True
assert not enable_kvcache_transpose assert not enable_kvcache_transpose
if _is_npu: self.use_mla = use_mla
TokenToKVPoolClass = AscendTokenToKVPool if not use_mla:
if _is_npu:
TokenToKVPoolClass = AscendTokenToKVPool
else:
TokenToKVPoolClass = MHATokenToKVPool
self.full_kv_pool = TokenToKVPoolClass(
size=size,
page_size=self.page_size,
dtype=dtype,
head_num=head_num,
head_dim=head_dim,
layer_num=self.full_layer_nums,
device=device,
enable_memory_saver=False,
)
else: else:
TokenToKVPoolClass = MHATokenToKVPool TokenToKVPoolClass = MLATokenToKVPool
self.full_kv_pool = TokenToKVPoolClass( self.full_kv_pool = TokenToKVPoolClass(
size=size, size=size,
page_size=self.page_size, page_size=self.page_size,
dtype=dtype, dtype=dtype,
head_num=head_num, layer_num=self.full_layer_nums,
head_dim=head_dim, device=device,
layer_num=self.full_layer_nums, kv_lora_rank=kv_lora_rank,
device=device, qk_rope_head_dim=qk_rope_head_dim,
enable_memory_saver=False, enable_memory_saver=False,
) )
self.full_attention_layer_id_mapping = { self.full_attention_layer_id_mapping = {
id: i for i, id in enumerate(full_attention_layer_ids) id: i for i, id in enumerate(full_attention_layer_ids)
} }
k_size, v_size = self.get_kv_size_bytes() if use_mla:
self.mem_usage = (k_size + v_size) / GB self.mem_usage = self.get_kv_size_bytes() / GB
else:
k_size, v_size = self.get_kv_size_bytes()
self.mem_usage = (k_size + v_size) / GB
def get_kv_size_bytes(self): def get_kv_size_bytes(self):
return self.full_kv_pool.get_kv_size_bytes() return self.full_kv_pool.get_kv_size_bytes()
...@@ -879,6 +945,21 @@ class HybridLinearKVPool(KVCache): ...@@ -879,6 +945,21 @@ class HybridLinearKVPool(KVCache):
layer_id = self._transfer_full_attention_id(layer_id) layer_id = self._transfer_full_attention_id(layer_id)
return self.full_kv_pool.get_kv_buffer(layer_id) return self.full_kv_pool.get_kv_buffer(layer_id)
@contextmanager
def _transfer_id_context(self, layer: RadixAttention):
@contextmanager
def _patch_layer_id(layer):
original_layer_id = layer.layer_id
layer.layer_id = self._transfer_full_attention_id(layer.layer_id)
try:
yield
finally:
layer.layer_id = original_layer_id
with _patch_layer_id(layer):
yield
def set_kv_buffer( def set_kv_buffer(
self, self,
layer: RadixAttention, layer: RadixAttention,
...@@ -889,19 +970,49 @@ class HybridLinearKVPool(KVCache): ...@@ -889,19 +970,49 @@ class HybridLinearKVPool(KVCache):
v_scale: float = 1.0, v_scale: float = 1.0,
): ):
layer_id = self._transfer_full_attention_id(layer.layer_id) layer_id = self._transfer_full_attention_id(layer.layer_id)
self.full_kv_pool.set_kv_buffer( if not self.use_mla:
None, self.full_kv_pool.set_kv_buffer(
loc, None,
cache_k, loc,
cache_v, cache_k,
k_scale, cache_v,
v_scale, k_scale,
layer_id_override=layer_id, v_scale,
) layer_id_override=layer_id,
)
else:
with self._transfer_id_context(layer):
self.full_kv_pool.set_kv_buffer(
layer,
loc,
cache_k,
cache_v,
)
def get_v_head_dim(self): def get_v_head_dim(self):
return self.full_kv_pool.get_value_buffer(0).shape[-1] return self.full_kv_pool.get_value_buffer(0).shape[-1]
def set_mla_kv_buffer(
self,
layer: RadixAttention,
loc: torch.Tensor,
cache_k_nope: torch.Tensor,
cache_k_rope: torch.Tensor,
):
assert self.use_mla, "set_mla_kv_buffer called when use_mla is False"
with self._transfer_id_context(layer):
self.full_kv_pool.set_mla_kv_buffer(layer, loc, cache_k_nope, cache_k_rope)
def get_mla_kv_buffer(
self,
layer: RadixAttention,
loc: torch.Tensor,
dst_dtype: Optional[torch.dtype] = None,
):
assert self.use_mla, "get_mla_kv_buffer called when use_mla is False"
with self._transfer_id_context(layer):
return self.full_kv_pool.get_mla_kv_buffer(layer, loc, dst_dtype)
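The layer-id indirection exists because `HybridLinearKVPool` only allocates full-attention KV buffers for the subset of layers that are not KDA layers, while the MLA pool's methods read `layer.layer_id`; the context manager temporarily rewrites it to the dense pool index. A minimal sketch of the mapping with a hypothetical layer layout:

```python
# Hypothetical hybrid layout: only these global layer ids use full attention.
full_attention_layer_ids = [3, 7, 11, 15]
full_attention_layer_id_mapping = {
    layer_id: i for i, layer_id in enumerate(full_attention_layer_ids)
}

# Global layer 11 is stored in the third (index 2) buffer of the full-attention pool.
assert full_attention_layer_id_mapping[11] == 2
```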
class SWAKVPool(KVCache): class SWAKVPool(KVCache):
"""KV cache with separate pools for full and SWA attention layers.""" """KV cache with separate pools for full and SWA attention layers."""
......
...@@ -29,7 +29,12 @@ from typing import Callable, List, Optional, Tuple, Union
import torch
import torch.distributed as dist
from sglang.srt.configs import FalconH1Config, NemotronHConfig, Qwen3NextConfig from sglang.srt.configs import (
FalconH1Config,
KimiLinearConfig,
NemotronHConfig,
Qwen3NextConfig,
)
from sglang.srt.configs.device_config import DeviceConfig
from sglang.srt.configs.load_config import LoadConfig, LoadFormat
from sglang.srt.configs.model_config import (
...@@ -1358,9 +1363,16 @@ class ModelRunner:
return config
return None
@property
def kimi_linear_config(self):
config = self.model_config.hf_config
if isinstance(config, KimiLinearConfig):
return config
return None
@property
def mambaish_config(self):
return self.mamba2_config or self.hybrid_gdn_config return self.mamba2_config or self.hybrid_gdn_config or self.kimi_linear_config
def set_num_token_hybrid(self):
if (
...@@ -1691,7 +1703,7 @@ class ModelRunner:
end_layer=self.end_layer,
index_head_dim=get_nsa_index_head_dim(self.model_config.hf_config),
)
elif self.use_mla_backend: elif self.use_mla_backend and not self.mambaish_config:
assert not is_nsa_model
self.token_to_kv_pool = MLATokenToKVPool(
self.max_total_num_tokens,
...@@ -1735,6 +1747,12 @@ class ModelRunner:
device=self.device,
)
elif config := self.mambaish_config:
extra_args = {}
if self.use_mla_backend:
extra_args = {
"kv_lora_rank": self.model_config.kv_lora_rank,
"qk_rope_head_dim": self.model_config.qk_rope_head_dim,
}
self.token_to_kv_pool = HybridLinearKVPool(
page_size=self.page_size,
size=self.max_total_num_tokens,
...@@ -1750,6 +1768,8 @@ class ModelRunner:
enable_kvcache_transpose=False,
device=self.device,
mamba_pool=self.req_to_token_pool.mamba_pool,
use_mla=self.use_mla_backend,
**extra_args,
)
else:
self.token_to_kv_pool = MHATokenToKVPool(
......
...@@ -1075,6 +1075,7 @@ class DeepseekV2AttentionMLA(nn.Module):
layer_id: int = None,
prefix: str = "",
alt_stream: Optional[torch.cuda.Stream] = None,
skip_rope: bool = False,
) -> None:
super().__init__()
self.layer_id = layer_id
...@@ -1182,23 +1183,26 @@ class DeepseekV2AttentionMLA(nn.Module):
)
self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
self.rotary_emb = get_rope_wrapper( if not skip_rope:
qk_rope_head_dim, self.rotary_emb = get_rope_wrapper(
rotary_dim=qk_rope_head_dim, qk_rope_head_dim,
max_position=max_position_embeddings, rotary_dim=qk_rope_head_dim,
base=rope_theta, max_position=max_position_embeddings,
rope_scaling=rope_scaling, base=rope_theta,
is_neox_style=False, rope_scaling=rope_scaling,
device=get_global_server_args().device, is_neox_style=False,
) device=get_global_server_args().device,
)
if rope_scaling: if rope_scaling:
mscale_all_dim = rope_scaling.get("mscale_all_dim", False) mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
scaling_factor = rope_scaling["factor"] scaling_factor = rope_scaling["factor"]
mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
self.scaling = self.scaling * mscale * mscale self.scaling = self.scaling * mscale * mscale
else:
self.rotary_emb.forward = self.rotary_emb.forward_native
else: else:
self.rotary_emb.forward = self.rotary_emb.forward_native self.rotary_emb = None
self.attn_mqa = RadixAttention(
self.num_local_heads,
...@@ -1487,7 +1491,8 @@ class DeepseekV2AttentionMLA(nn.Module):
latent_cache = latent_cache.unsqueeze(1)
kv_a = self.kv_a_layernorm(kv_a)
k_pe = latent_cache[:, :, self.kv_lora_rank :]
q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) if self.rotary_emb is not None:
q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
q[..., self.qk_nope_head_dim :] = q_pe
self._set_mla_kv_buffer(latent_cache, kv_a, k_pe, forward_batch)
...@@ -1646,8 +1651,10 @@ class DeepseekV2AttentionMLA(nn.Module):
q_nope_out = q_nope_out.transpose(0, 1)
if not self._fuse_rope_for_trtllm_mla(forward_batch) and ( if (
not _use_aiter or not _is_gfx95_supported or self.use_nsa self.rotary_emb is not None
and (not self._fuse_rope_for_trtllm_mla(forward_batch))
and (not _use_aiter or not _is_gfx95_supported or self.use_nsa)
):
q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
......
...@@ -1028,6 +1028,11 @@ class ServerArgs:
logger.info(
f"Using {self.attention_backend} as attention backend for {model_arch}."
)
elif model_arch in ["KimiLinearForCausalLM"]:
logger.warning(
f"Disabling Radix Cache for {model_arch} as it is not yet supported."
)
self.disable_radix_cache = True
if is_deepseek_nsa(hf_config):
if (
......
...@@ -43,6 +43,7 @@ from sglang.srt.configs import (
DotsVLMConfig,
ExaoneConfig,
FalconH1Config,
KimiLinearConfig,
KimiVLConfig,
LongcatFlashConfig,
MultiModalityConfig,
...@@ -68,6 +69,7 @@ _CONFIG_REGISTRY: List[Type[PretrainedConfig]] = [
Step3VLConfig,
LongcatFlashConfig,
Olmo3Config,
KimiLinearConfig,
Qwen3NextConfig,
FalconH1Config,
DotsVLMConfig,
......
import unittest
from types import SimpleNamespace
from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval
from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
)
class TestKimiLinear(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = "moonshotai/Kimi-Linear-48B-A3B-Instruct"
cls.base_url = DEFAULT_URL_FOR_TEST
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=["--tp-size", "2", "--trust-remote"],
)
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
def test_gsm8k(self):
args = SimpleNamespace(
num_shots=5,
data_path=None,
num_questions=200,
max_new_tokens=512,
parallel=128,
host="http://127.0.0.1",
port=int(self.base_url.split(":")[-1]),
)
metrics = run_eval(args)
print(f"{metrics=}")
self.assertGreater(metrics["accuracy"], 0.88)
if __name__ == "__main__":
unittest.main()
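For a quick manual check outside the unittest harness, the launched server can also be queried directly. The snippet below assumes the server is reachable on the default local port and that the payload shape matches SGLang's native `/generate` API; adjust host, port, and fields to your deployment:

```python
import requests

# Manual smoke test against a locally launched server (assumed host/port and schema).
resp = requests.post(
    "http://127.0.0.1:30000/generate",
    json={
        "text": "Question: What is 7 * 8? Answer:",
        "sampling_params": {"max_new_tokens": 16, "temperature": 0},
    },
    timeout=60,
)
print(resp.json()["text"])
```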
...@@ -151,6 +151,7 @@ suites = {
TestFile("layers/attention/mamba/test_mamba2_mixer.py", 50),
TestFile("lora/test_lora_tp.py", 116),
TestFile("models/test_glm4_moe_models.py", 100),
TestFile("models/test_kimi_linear_models.py", 90),
TestFile("rl/test_update_weights_from_distributed.py", 103),
TestFile("test_data_parallelism.py", 73),
TestFile("test_disaggregation_basic.py", 400),
......