Unverified Commit a4bf5c6a authored by Ke Bao, committed by GitHub

Support Kimi Linear (#12469)


Co-authored-by: yizhang2077 <1109276519@qq.com>
parent 30ad1070
...@@ -6,6 +6,7 @@ from sglang.srt.configs.dots_vlm import DotsVLMConfig ...@@ -6,6 +6,7 @@ from sglang.srt.configs.dots_vlm import DotsVLMConfig
from sglang.srt.configs.exaone import ExaoneConfig from sglang.srt.configs.exaone import ExaoneConfig
from sglang.srt.configs.falcon_h1 import FalconH1Config from sglang.srt.configs.falcon_h1 import FalconH1Config
from sglang.srt.configs.janus_pro import MultiModalityConfig from sglang.srt.configs.janus_pro import MultiModalityConfig
from sglang.srt.configs.kimi_linear import KimiLinearConfig
from sglang.srt.configs.kimi_vl import KimiVLConfig from sglang.srt.configs.kimi_vl import KimiVLConfig
from sglang.srt.configs.kimi_vl_moonvit import MoonViTConfig from sglang.srt.configs.kimi_vl_moonvit import MoonViTConfig
from sglang.srt.configs.longcat_flash import LongcatFlashConfig from sglang.srt.configs.longcat_flash import LongcatFlashConfig
...@@ -31,6 +32,7 @@ __all__ = [ ...@@ -31,6 +32,7 @@ __all__ = [
"Step3TextConfig", "Step3TextConfig",
"Step3VisionEncoderConfig", "Step3VisionEncoderConfig",
"Olmo3Config", "Olmo3Config",
"KimiLinearConfig",
"Qwen3NextConfig", "Qwen3NextConfig",
"DotsVLMConfig", "DotsVLMConfig",
"DotsOCRConfig", "DotsOCRConfig",
......
# Adapted from: https://github.com/vllm-project/vllm/blob/0384aa7150c4c9778efca041ffd1beb3ad2bd694/vllm/transformers_utils/configs/kimi_linear.py
from transformers.configuration_utils import PretrainedConfig
from sglang.srt.configs.mamba_utils import KimiLinearCacheParams, KimiLinearStateShape
from sglang.srt.layers.dp_attention import get_attention_tp_size
class KimiLinearConfig(PretrainedConfig):
model_type = "kimi_linear"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
model_type="kimi_linear",
vocab_size=163840,
hidden_size=4096,
head_dim=None,
intermediate_size=11008,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=None,
hidden_act="silu",
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
pad_token_id=0,
bos_token_id=1,
eos_token_id=2,
rope_theta=10000.0,
rope_scaling=None,
tie_word_embeddings=False,
moe_intermediate_size: int | None = None,
moe_renormalize: bool = True,
moe_router_activation_func: str = "sigmoid",
num_experts: int | None = None,
num_experts_per_token: int | None = None,
num_shared_experts: int = 0,
routed_scaling_factor: float = 1.0,
first_k_dense_replace: int = 0,
moe_layer_freq: int = 1,
use_grouped_topk: bool = True,
num_expert_group: int = 1,
topk_group: int = 1,
q_lora_rank: int | None = None,
kv_lora_rank: int | None = None,
qk_nope_head_dim: int | None = None,
qk_rope_head_dim: int | None = None,
v_head_dim: int | None = None,
mla_use_nope: bool | None = False,
num_nextn_predict_layers: int = 0,
linear_attn_config: dict | None = None,
**kwargs,
):
self.model_type = model_type
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.head_dim = (
head_dim if head_dim is not None else hidden_size // num_attention_heads
)
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.q_lora_rank = q_lora_rank
self.kv_lora_rank = kv_lora_rank
self.qk_nope_head_dim = qk_nope_head_dim
self.qk_rope_head_dim = qk_rope_head_dim
self.v_head_dim = v_head_dim
self.mla_use_nope = mla_use_nope
# moe config
self.n_routed_experts = self.num_experts = num_experts
self.num_experts_per_token = num_experts_per_token
self.moe_renormalize = moe_renormalize
self.num_shared_experts = num_shared_experts
self.routed_scaling_factor = routed_scaling_factor
self.moe_router_activation_func = moe_router_activation_func
assert self.moe_router_activation_func in ("softmax", "sigmoid")
self.moe_intermediate_size = moe_intermediate_size
self.first_k_dense_replace = first_k_dense_replace
self.moe_layer_freq = moe_layer_freq
self.use_grouped_topk = use_grouped_topk
self.num_expert_group = num_expert_group
self.topk_group = topk_group
self.num_nextn_predict_layers = num_nextn_predict_layers
if linear_attn_config is not None:
assert linear_attn_config["kda_layers"] is not None
assert linear_attn_config["full_attn_layers"] is not None
self.linear_attn_config = linear_attn_config
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
@property
def is_mla(self):
return (
self.q_lora_rank is not None
or self.kv_lora_rank is not None
or self.qk_nope_head_dim is not None
or self.qk_rope_head_dim is not None
or self.v_head_dim is not None
or self.mla_use_nope is True
)
@property
def is_moe(self):
return self.num_experts is not None
@property
def is_linear_attn(self) -> bool:
return not (
self.linear_attn_config is None
or (
isinstance(self.linear_attn_config, dict)
and self.linear_attn_config["kda_layers"] is not None
and len(self.linear_attn_config["kda_layers"]) == 0
)
)
def is_kda_layer(self, layer_idx: int):
return (
self.linear_attn_config is not None
and (layer_idx + 1) in self.linear_attn_config["kda_layers"]
)
@property
def linear_layer_ids(self):
return [i for i in range(self.num_hidden_layers) if self.is_kda_layer(i)]
@property
def full_attention_layer_ids(self):
return [i for i in range(self.num_hidden_layers) if not self.is_kda_layer(i)]
@property
def mamba2_cache_params(self) -> KimiLinearCacheParams:
shape = KimiLinearStateShape.create(
tp_world_size=get_attention_tp_size(),
num_heads=self.linear_attn_config["num_heads"],
head_dim=self.linear_attn_config["head_dim"],
conv_kernel_size=self.linear_attn_config["short_conv_kernel_size"],
)
return KimiLinearCacheParams(shape=shape, layers=self.linear_layer_ids)
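For reference, a minimal sketch (not part of this commit) of how the layer-classification properties above behave, using hypothetical values for `linear_attn_config`; the keys mirror the ones read by `mamba2_cache_params` and `is_kda_layer`:

from sglang.srt.configs.kimi_linear import KimiLinearConfig

# Hypothetical 4-layer toy config: KDA on layers 1-3, full attention on layer 4
# (the kda_layers / full_attn_layers lists are 1-indexed, see is_kda_layer above).
cfg = KimiLinearConfig(
    num_hidden_layers=4,
    linear_attn_config={
        "kda_layers": [1, 2, 3],
        "full_attn_layers": [4],
        "num_heads": 16,
        "head_dim": 128,
        "short_conv_kernel_size": 4,
    },
)
assert cfg.is_linear_attn
assert cfg.linear_layer_ids == [0, 1, 2]          # 0-indexed layer ids
assert cfg.full_attention_layer_ids == [3]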
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
import os import os
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import List, Optional
import numpy as np import numpy as np
import torch import torch
...@@ -115,3 +116,68 @@ class Mamba2CacheParams: ...@@ -115,3 +116,68 @@ class Mamba2CacheParams:
int(np.prod(self.shape.conv)) * self.dtype.conv.itemsize int(np.prod(self.shape.conv)) * self.dtype.conv.itemsize
+ int(np.prod(self.shape.temporal)) * self.dtype.temporal.itemsize + int(np.prod(self.shape.temporal)) * self.dtype.temporal.itemsize
) * len(self.layers) ) * len(self.layers)
@dataclass(kw_only=True, frozen=True)
class KimiLinearStateShape:
conv: List[tuple[int, int]]
temporal: tuple[int, int, int]
num_heads: int
head_dim: int
num_k_heads: int
head_k_dim: int
conv_kernel: int
num_spec: int
@staticmethod
def create(
*,
tp_world_size: int,
num_heads: int,
head_dim: int,
num_k_heads: Optional[int] = None,
head_k_dim: Optional[int] = None,
conv_kernel_size: int = 4,
num_spec: int = 0,
) -> "KimiLinearStateShape":
if num_k_heads is None:
num_k_heads = num_heads
if head_k_dim is None:
head_k_dim = head_dim
proj_size = num_heads * head_dim
proj_k_size = num_k_heads * head_k_dim
conv_state_shape = (divide(proj_size, tp_world_size), conv_kernel_size - 1)
conv_state_k_shape = (divide(proj_k_size, tp_world_size), conv_kernel_size - 1)
temporal_state_shape = (divide(num_heads, tp_world_size), head_dim, head_dim)
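        # Conv states are stored transposed as (conv_kernel - 1, dim); the KDA
        # backend transposes them back to (dim, conv_kernel - 1) before calling
        # the causal-conv1d kernels.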
conv_state_shape = conv_state_shape[1], conv_state_shape[0]
conv_state_k_shape = conv_state_k_shape[1], conv_state_k_shape[0]
return KimiLinearStateShape(
conv=[conv_state_shape, conv_state_k_shape, conv_state_k_shape],
temporal=temporal_state_shape,
num_heads=num_heads,
head_dim=head_dim,
num_k_heads=num_k_heads,
head_k_dim=head_k_dim,
conv_kernel=conv_kernel_size,
num_spec=num_spec,
)
@dataclass(kw_only=True, frozen=True)
class KimiLinearCacheParams:
shape: KimiLinearStateShape
dtype: Mamba2StateDType = field(default_factory=mamba2_state_dtype)
layers: list[int]
@property
def mamba_cache_per_req(self) -> int:
return (
int(np.sum([np.prod(conv_shape) for conv_shape in self.shape.conv]))
* self.dtype.conv.itemsize
+ int(np.prod(self.shape.temporal)) * self.dtype.temporal.itemsize
) * len(self.layers)
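A small illustration (not part of this commit) of the per-request cache accounting above, with made-up sizes; the real values come from `linear_attn_config`:

from sglang.srt.configs.mamba_utils import KimiLinearCacheParams, KimiLinearStateShape

shape = KimiLinearStateShape.create(
    tp_world_size=1,
    num_heads=32,
    head_dim=128,
    conv_kernel_size=4,
)
# Three conv states of shape (conv_kernel - 1, num_heads * head_dim) plus one
# temporal state of shape (num_heads, head_dim, head_dim) per KDA layer.
params = KimiLinearCacheParams(shape=shape, layers=list(range(24)))  # e.g. 24 KDA layers
print(params.mamba_cache_per_req)  # bytes reserved per request across the 24 layers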
...@@ -366,6 +366,13 @@ class ModelConfig: ...@@ -366,6 +366,13 @@ class ModelConfig:
self.qk_rope_head_dim = self.hf_text_config.qk_rope_head_dim self.qk_rope_head_dim = self.hf_text_config.qk_rope_head_dim
self.v_head_dim = self.hf_text_config.v_head_dim self.v_head_dim = self.hf_text_config.v_head_dim
self.qk_nope_head_dim = self.hf_text_config.qk_nope_head_dim self.qk_nope_head_dim = self.hf_text_config.qk_nope_head_dim
elif "KimiLinearForCausalLM" in self.hf_config.architectures:
self.head_dim = 72
self.attention_arch = AttentionArch.MLA
self.kv_lora_rank = self.hf_config.kv_lora_rank
self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim
self.v_head_dim = self.hf_config.v_head_dim
self.qk_nope_head_dim = self.hf_config.qk_nope_head_dim
else: else:
if ( if (
"MistralModel" in self.hf_config.architectures "MistralModel" in self.hf_config.architectures
......
...@@ -189,6 +189,7 @@ def attn_backend_wrapper(runner: "ModelRunner", full_attn_backend: "AttentionBac ...@@ -189,6 +189,7 @@ def attn_backend_wrapper(runner: "ModelRunner", full_attn_backend: "AttentionBac
from sglang.srt.layers.attention.hybrid_linear_attn_backend import ( from sglang.srt.layers.attention.hybrid_linear_attn_backend import (
GDNAttnBackend, GDNAttnBackend,
HybridLinearAttnBackend, HybridLinearAttnBackend,
KimiLinearAttnBackend,
Mamba2AttnBackend, Mamba2AttnBackend,
) )
from sglang.srt.utils import is_blackwell, is_npu from sglang.srt.utils import is_blackwell, is_npu
...@@ -207,6 +208,8 @@ def attn_backend_wrapper(runner: "ModelRunner", full_attn_backend: "AttentionBac ...@@ -207,6 +208,8 @@ def attn_backend_wrapper(runner: "ModelRunner", full_attn_backend: "AttentionBac
linear_attn_backend = GDNAttnBackend(runner) linear_attn_backend = GDNAttnBackend(runner)
elif runner.mamba2_config is not None: elif runner.mamba2_config is not None:
linear_attn_backend = Mamba2AttnBackend(runner) linear_attn_backend = Mamba2AttnBackend(runner)
elif runner.kimi_linear_config is not None:
linear_attn_backend = KimiLinearAttnBackend(runner)
else: else:
raise ValueError( raise ValueError(
"Expected hybrid GDN or NemotronH models, but got unknown model." "Expected hybrid GDN or NemotronH models, but got unknown model."
......
...@@ -21,6 +21,7 @@ NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8, 16] ...@@ -21,6 +21,7 @@ NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8, 16]
@triton.heuristics( @triton.heuristics(
{ {
"USE_G": lambda args: args["g"] is not None, "USE_G": lambda args: args["g"] is not None,
"USE_GK": lambda args: args["gk"] is not None,
"USE_INITIAL_STATE": lambda args: args["h0"] is not None, "USE_INITIAL_STATE": lambda args: args["h0"] is not None,
"STORE_FINAL_STATE": lambda args: args["ht"] is not None, "STORE_FINAL_STATE": lambda args: args["ht"] is not None,
"SAVE_NEW_VALUE": lambda args: args["v_new"] is not None, "SAVE_NEW_VALUE": lambda args: args["v_new"] is not None,
...@@ -44,6 +45,7 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64( ...@@ -44,6 +45,7 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
w, w,
v_new, v_new,
g, g,
gk,
h, h,
h0, h0,
ht, ht,
...@@ -57,6 +59,7 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64( ...@@ -57,6 +59,7 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
BT: tl.constexpr, BT: tl.constexpr,
BV: tl.constexpr, BV: tl.constexpr,
USE_G: tl.constexpr, USE_G: tl.constexpr,
USE_GK: tl.constexpr,
USE_INITIAL_STATE: tl.constexpr, USE_INITIAL_STATE: tl.constexpr,
STORE_FINAL_STATE: tl.constexpr, STORE_FINAL_STATE: tl.constexpr,
SAVE_NEW_VALUE: tl.constexpr, SAVE_NEW_VALUE: tl.constexpr,
...@@ -86,12 +89,12 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64( ...@@ -86,12 +89,12 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
b_h4 = tl.zeros([64, BV], dtype=tl.float32) b_h4 = tl.zeros([64, BV], dtype=tl.float32)
# calculate offset # calculate offset
h += (boh * H + i_h) * K * V h += ((boh * H + i_h) * K * V).to(tl.int64)
v += (bos * H + i_h) * V v += ((bos * H + i_h) * V).to(tl.int64)
k += (bos * Hg + i_h // (H // Hg)) * K k += ((bos * Hg + i_h // (H // Hg)) * K).to(tl.int64)
w += (bos * H + i_h) * K w += ((bos * H + i_h) * K).to(tl.int64)
if SAVE_NEW_VALUE: if SAVE_NEW_VALUE:
v_new += (bos * H + i_h) * V v_new += ((bos * H + i_h) * V).to(tl.int64)
stride_v = H * V stride_v = H * V
stride_h = H * K * V stride_h = H * K * V
stride_k = Hg * K stride_k = Hg * K
...@@ -143,58 +146,48 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64( ...@@ -143,58 +146,48 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
) )
tl.store(p_h4, b_h4.to(p_h4.dtype.element_ty), boundary_check=(0, 1)) tl.store(p_h4, b_h4.to(p_h4.dtype.element_ty), boundary_check=(0, 1))
p_v = tl.make_block_ptr(
v, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
)
p_v_new = (
tl.make_block_ptr(
v_new, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
)
if SAVE_NEW_VALUE
else None
)
b_v_new = tl.zeros([BT, BV], dtype=tl.float32)
p_w = tl.make_block_ptr( p_w = tl.make_block_ptr(
w, (T, K), (stride_w, 1), (i_t * BT, 0), (BT, 64), (1, 0) w, (T, K), (stride_w, 1), (i_t * BT, 0), (BT, 64), (1, 0)
) )
b_w = tl.load(p_w, boundary_check=(0, 1)) b_w = tl.load(p_w, boundary_check=(0, 1))
b_v_new += tl.dot(b_w, b_h1.to(b_w.dtype)) b_v = tl.dot(b_w, b_h1.to(b_w.dtype))
if K > 64: if K > 64:
p_w = tl.make_block_ptr( p_w = tl.make_block_ptr(
w, (T, K), (stride_w, 1), (i_t * BT, 64), (BT, 64), (1, 0) w, (T, K), (stride_w, 1), (i_t * BT, 64), (BT, 64), (1, 0)
) )
b_w = tl.load(p_w, boundary_check=(0, 1)) b_w = tl.load(p_w, boundary_check=(0, 1))
b_v_new += tl.dot(b_w, b_h2.to(b_w.dtype)) b_v += tl.dot(b_w, b_h2.to(b_w.dtype))
if K > 128: if K > 128:
p_w = tl.make_block_ptr( p_w = tl.make_block_ptr(
w, (T, K), (stride_w, 1), (i_t * BT, 128), (BT, 64), (1, 0) w, (T, K), (stride_w, 1), (i_t * BT, 128), (BT, 64), (1, 0)
) )
b_w = tl.load(p_w, boundary_check=(0, 1)) b_w = tl.load(p_w, boundary_check=(0, 1))
b_v_new += tl.dot(b_w, b_h3.to(b_w.dtype)) b_v += tl.dot(b_w, b_h3.to(b_w.dtype))
if K > 192: if K > 192:
p_w = tl.make_block_ptr( p_w = tl.make_block_ptr(
w, (T, K), (stride_w, 1), (i_t * BT, 192), (BT, 64), (1, 0) w, (T, K), (stride_w, 1), (i_t * BT, 192), (BT, 64), (1, 0)
) )
b_w = tl.load(p_w, boundary_check=(0, 1)) b_w = tl.load(p_w, boundary_check=(0, 1))
b_v_new += tl.dot(b_w, b_h4.to(b_w.dtype)) b_v += tl.dot(b_w, b_h4.to(b_w.dtype))
b_v_new = -b_v_new + tl.load(p_v, boundary_check=(0, 1)) p_v = tl.make_block_ptr(
v, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
)
b_v = tl.load(p_v, boundary_check=(0, 1)) - b_v
if SAVE_NEW_VALUE: if SAVE_NEW_VALUE:
p_v_new = tl.make_block_ptr( p_v = tl.make_block_ptr(
v_new, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0) v_new, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
) )
tl.store( tl.store(p_v, b_v.to(p_v.dtype.element_ty), boundary_check=(0, 1))
p_v_new, b_v_new.to(p_v_new.dtype.element_ty), boundary_check=(0, 1)
)
last_idx = min((i_t + 1) * BT, T) - 1
if USE_G: if USE_G:
last_idx = min((i_t + 1) * BT, T) - 1
b_g_last = tl.load(g + bos * H + last_idx * H + i_h) b_g_last = tl.load(g + bos * H + last_idx * H + i_h)
p_g = tl.make_block_ptr( p_g = tl.make_block_ptr(
g + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,) g + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)
) )
b_g = tl.load(p_g, boundary_check=(0,)) b_g = tl.load(p_g, boundary_check=(0,))
b_v_new = b_v_new * safe_exp(b_g_last - b_g)[:, None] b_v = b_v * safe_exp(b_g_last - b_g)[:, None]
b_g_last = exp(b_g_last) b_g_last = exp(b_g_last)
b_h1 = b_h1 * b_g_last b_h1 = b_h1 * b_g_last
if K > 64: if K > 64:
...@@ -203,30 +196,64 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64( ...@@ -203,30 +196,64 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
b_h3 = b_h3 * b_g_last b_h3 = b_h3 * b_g_last
if K > 192: if K > 192:
b_h4 = b_h4 * b_g_last b_h4 = b_h4 * b_g_last
b_v_new = b_v_new.to(k.dtype.element_ty)
if USE_GK:
o_k1 = tl.arange(0, 64)
b_gk_last1 = tl.load(
gk + (bos + last_idx) * H * K + i_h * K + o_k1,
mask=(o_k1 < K),
other=0.0,
)
b_h1 *= exp(b_gk_last1)[:, None]
if K > 64:
o_k2 = 64 + o_k1
b_gk_last2 = tl.load(
gk + (bos + last_idx) * H * K + i_h * K + o_k2,
mask=(o_k2 < K),
other=0.0,
)
b_h2 *= exp(b_gk_last2)[:, None]
if K > 128:
o_k3 = 128 + o_k1
b_gk_last3 = tl.load(
gk + (bos + last_idx) * H * K + i_h * K + o_k3,
mask=(o_k3 < K),
other=0.0,
)
b_h3 *= exp(b_gk_last3)[:, None]
if K > 192:
o_k4 = 192 + o_k1
b_gk_last4 = tl.load(
gk + (bos + last_idx) * H * K + i_h * K + o_k4,
mask=(o_k4 < K),
other=0.0,
)
b_h4 *= exp(b_gk_last4)[:, None]
b_v = b_v.to(k.dtype.element_ty)
p_k = tl.make_block_ptr( p_k = tl.make_block_ptr(
k, (K, T), (1, stride_k), (0, i_t * BT), (64, BT), (0, 1) k, (K, T), (1, stride_k), (0, i_t * BT), (64, BT), (0, 1)
) )
b_k = tl.load(p_k, boundary_check=(0, 1)) b_k = tl.load(p_k, boundary_check=(0, 1))
b_h1 += tl.dot(b_k, b_v_new) b_h1 += tl.dot(b_k, b_v)
if K > 64: if K > 64:
p_k = tl.make_block_ptr( p_k = tl.make_block_ptr(
k, (K, T), (1, stride_k), (64, i_t * BT), (64, BT), (0, 1) k, (K, T), (1, stride_k), (64, i_t * BT), (64, BT), (0, 1)
) )
b_k = tl.load(p_k, boundary_check=(0, 1)) b_k = tl.load(p_k, boundary_check=(0, 1))
b_h2 += tl.dot(b_k, b_v_new) b_h2 += tl.dot(b_k, b_v)
if K > 128: if K > 128:
p_k = tl.make_block_ptr( p_k = tl.make_block_ptr(
k, (K, T), (1, stride_k), (128, i_t * BT), (64, BT), (0, 1) k, (K, T), (1, stride_k), (128, i_t * BT), (64, BT), (0, 1)
) )
b_k = tl.load(p_k, boundary_check=(0, 1)) b_k = tl.load(p_k, boundary_check=(0, 1))
b_h3 += tl.dot(b_k, b_v_new) b_h3 += tl.dot(b_k, b_v)
if K > 192: if K > 192:
p_k = tl.make_block_ptr( p_k = tl.make_block_ptr(
k, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1) k, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1)
) )
b_k = tl.load(p_k, boundary_check=(0, 1)) b_k = tl.load(p_k, boundary_check=(0, 1))
b_h4 += tl.dot(b_k, b_v_new) b_h4 += tl.dot(b_k, b_v)
# epilogue # epilogue
if STORE_FINAL_STATE: if STORE_FINAL_STATE:
...@@ -254,6 +281,7 @@ def chunk_gated_delta_rule_fwd_h( ...@@ -254,6 +281,7 @@ def chunk_gated_delta_rule_fwd_h(
w: torch.Tensor, w: torch.Tensor,
u: torch.Tensor, u: torch.Tensor,
g: Optional[torch.Tensor] = None, g: Optional[torch.Tensor] = None,
gk: Optional[torch.Tensor] = None,
initial_state: Optional[torch.Tensor] = None, initial_state: Optional[torch.Tensor] = None,
output_final_state: bool = False, output_final_state: bool = False,
chunk_size: int = 64, # SY: remove this argument and force chunk size 64? chunk_size: int = 64, # SY: remove this argument and force chunk size 64?
...@@ -296,6 +324,7 @@ def chunk_gated_delta_rule_fwd_h( ...@@ -296,6 +324,7 @@ def chunk_gated_delta_rule_fwd_h(
w=w, w=w,
v_new=v_new, v_new=v_new,
g=g, g=g,
gk=gk,
h=h, h=h,
h0=initial_state, h0=initial_state,
ht=final_state, ht=final_state,
......
...@@ -44,6 +44,7 @@ def fused_recurrent_gated_delta_rule_fwd_kernel( ...@@ -44,6 +44,7 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
IS_BETA_HEADWISE: tl.constexpr, # whether beta is headwise vector or scalar, IS_BETA_HEADWISE: tl.constexpr, # whether beta is headwise vector or scalar,
USE_QK_L2NORM_IN_KERNEL: tl.constexpr, USE_QK_L2NORM_IN_KERNEL: tl.constexpr,
IS_VARLEN: tl.constexpr, IS_VARLEN: tl.constexpr,
IS_KDA: tl.constexpr,
): ):
i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
i_n, i_hv = i_nh // HV, i_nh % HV i_n, i_hv = i_nh // HV, i_nh % HV
...@@ -67,7 +68,11 @@ def fused_recurrent_gated_delta_rule_fwd_kernel( ...@@ -67,7 +68,11 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
p_beta = beta + (bos * HV + i_hv) * V + o_v p_beta = beta + (bos * HV + i_hv) * V + o_v
else: else:
p_beta = beta + bos * HV + i_hv p_beta = beta + bos * HV + i_hv
p_g = g + bos * HV + i_hv if not IS_KDA:
p_g = g + bos * HV + i_hv
else:
p_gk = g + (bos * HV + i_hv) * K + o_k
p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v
mask_k = o_k < K mask_k = o_k < K
...@@ -83,14 +88,18 @@ def fused_recurrent_gated_delta_rule_fwd_kernel( ...@@ -83,14 +88,18 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32)
b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32) b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32) b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
b_g = tl.load(p_g).to(tl.float32)
if USE_QK_L2NORM_IN_KERNEL: if USE_QK_L2NORM_IN_KERNEL:
b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q) + 1e-6)) b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q) + 1e-6))
b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k) + 1e-6)) b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k) + 1e-6))
b_q = b_q * scale b_q = b_q * scale
# [BK, BV] # [BK, BV]
b_h *= exp(b_g) if not IS_KDA:
b_g = tl.load(p_g).to(tl.float32)
b_h *= exp(b_g)
else:
b_gk = tl.load(p_gk).to(tl.float32)
b_h *= exp(b_gk[:, None])
# [BV] # [BV]
b_v -= tl.sum(b_h * b_k[:, None], 0) b_v -= tl.sum(b_h * b_k[:, None], 0)
if IS_BETA_HEADWISE: if IS_BETA_HEADWISE:
...@@ -108,7 +117,10 @@ def fused_recurrent_gated_delta_rule_fwd_kernel( ...@@ -108,7 +117,10 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
p_k += H * K p_k += H * K
p_o += HV * V p_o += HV * V
p_v += HV * V p_v += HV * V
p_g += HV if not IS_KDA:
p_g += HV
else:
p_gk += HV * K
p_beta += HV * (V if IS_BETA_HEADWISE else 1) p_beta += HV * (V if IS_BETA_HEADWISE else 1)
if STORE_FINAL_STATE: if STORE_FINAL_STATE:
...@@ -165,6 +177,7 @@ def fused_recurrent_gated_delta_rule_fwd( ...@@ -165,6 +177,7 @@ def fused_recurrent_gated_delta_rule_fwd(
BV=BV, BV=BV,
IS_BETA_HEADWISE=beta.ndim == v.ndim, IS_BETA_HEADWISE=beta.ndim == v.ndim,
USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel, USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel,
IS_KDA=False,
num_warps=num_warps, num_warps=num_warps,
num_stages=num_stages, num_stages=num_stages,
) )
......
# Adapted from https://github.com/vllm-project/vllm/blob/0384aa7150c4c9778efca041ffd1beb3ad2bd694/vllm/model_executor/layers/fla/ops/kda.py
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
import torch
import torch.nn as nn
import triton
import triton.language as tl
from sglang.srt.layers.attention.fla.chunk_delta_h import chunk_gated_delta_rule_fwd_h
from sglang.srt.layers.attention.fla.cumsum import chunk_local_cumsum
from sglang.srt.layers.attention.fla.fused_recurrent import (
fused_recurrent_gated_delta_rule_fwd_kernel,
)
from sglang.srt.layers.attention.fla.index import prepare_chunk_indices
from sglang.srt.layers.attention.fla.l2norm import l2norm_fwd
from sglang.srt.layers.attention.fla.op import exp, log
from sglang.srt.layers.attention.fla.solve_tril import solve_tril
from sglang.srt.layers.attention.fla.utils import is_amd
BT_LIST_AUTOTUNE = [32, 64, 128]
NUM_WARPS_AUTOTUNE = [2, 4, 8, 16] if is_amd else [4, 8, 16, 32]
def cdiv(a: int, b: int) -> int:
"""Ceiling division."""
return -(a // -b)
def next_power_of_2(n: int) -> int:
"""The next power of 2 (inclusive)"""
if n < 1:
return 1
return 1 << (n - 1).bit_length()
def fused_recurrent_kda_fwd(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: float,
initial_state: torch.Tensor,
inplace_final_state: bool = True,
cu_seqlens: torch.LongTensor | None = None,
# ssm_state_indices: torch.Tensor | None = None,
num_accepted_tokens: torch.Tensor | None = None,
use_qk_l2norm_in_kernel: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
B, T, H, K, V = *k.shape, v.shape[-1]
HV = v.shape[2]
N = B if cu_seqlens is None else len(cu_seqlens) - 1
BK, BV = next_power_of_2(K), min(next_power_of_2(V), 8)
NK, NV = cdiv(K, BK), cdiv(V, BV)
assert NK == 1, "NK > 1 is not supported yet"
num_stages = 3
num_warps = 1
o = torch.empty_like(k)
if inplace_final_state:
final_state = initial_state
else:
final_state = q.new_empty(T, HV, K, V, dtype=initial_state.dtype)
stride_init_state_token = initial_state.stride(0)
stride_final_state_token = final_state.stride(0)
# if ssm_state_indices is None:
# stride_indices_seq, stride_indices_tok = 1, 1
# elif ssm_state_indices.ndim == 1:
# stride_indices_seq, stride_indices_tok = ssm_state_indices.stride(0), 1
# else:
# stride_indices_seq, stride_indices_tok = ssm_state_indices.stride()
grid = (NK, NV, N * HV)
fused_recurrent_gated_delta_rule_fwd_kernel[grid](
q=q,
k=k,
v=v,
g=g,
beta=beta,
o=o,
h0=initial_state,
ht=final_state,
cu_seqlens=cu_seqlens,
# ssm_state_indices=ssm_state_indices,
# num_accepted_tokens=num_accepted_tokens,
scale=scale,
# N=N,
T=T,
B=B,
H=H,
HV=HV,
K=K,
V=V,
BK=BK,
BV=BV,
# stride_init_state_token=stride_init_state_token,
# stride_final_state_token=stride_final_state_token,
# stride_indices_seq=stride_indices_seq,
# stride_indices_tok=stride_indices_tok,
IS_BETA_HEADWISE=beta.ndim == v.ndim,
USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel,
# INPLACE_FINAL_STATE=inplace_final_state,
IS_KDA=True,
num_warps=num_warps,
num_stages=num_stages,
)
return o, final_state
def fused_recurrent_kda(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor = None,
scale: float = None,
initial_state: torch.Tensor = None,
inplace_final_state: bool = True,
use_qk_l2norm_in_kernel: bool = True,
cu_seqlens: torch.LongTensor | None = None,
# ssm_state_indices: torch.LongTensor | None = None,
**kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
if cu_seqlens is not None and q.shape[0] != 1:
raise ValueError(
f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
f"Please flatten variable-length inputs before processing."
)
if scale is None:
scale = k.shape[-1] ** -0.5
o, final_state = fused_recurrent_kda_fwd(
q=q.contiguous(),
k=k.contiguous(),
v=v.contiguous(),
g=g.contiguous(),
beta=beta.contiguous(),
scale=scale,
initial_state=initial_state,
inplace_final_state=inplace_final_state,
cu_seqlens=cu_seqlens,
# ssm_state_indices=ssm_state_indices,
num_accepted_tokens=None,
use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
)
return o, final_state
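For reviewers, a rough sketch (not part of this commit) of the decode-path call shape for `fused_recurrent_kda`, mirroring how `forward_decode` below invokes it; sizes are hypothetical and a CUDA device is required:

import torch

from sglang.srt.layers.attention.fla.kda import fused_recurrent_kda

# Two sequences decoding one token each, flattened to a single varlen batch.
H, K, V = 4, 64, 64                       # hypothetical head count / dims
n_tok, n_seq = 2, 2
cu_seqlens = torch.tensor([0, 1, 2], dtype=torch.long, device="cuda")
q = torch.randn(1, n_tok, H, K, device="cuda", dtype=torch.bfloat16)
k = torch.randn(1, n_tok, H, K, device="cuda", dtype=torch.bfloat16)
v = torch.randn(1, n_tok, H, V, device="cuda", dtype=torch.bfloat16)
g = torch.randn(1, n_tok, H, K, device="cuda", dtype=torch.float32)   # per-channel log-gates
beta = torch.rand(1, n_tok, H, device="cuda", dtype=torch.float32)    # scalar beta per head
state = torch.zeros(n_seq, H, K, V, device="cuda", dtype=torch.float32)

o, state = fused_recurrent_kda(
    q=q, k=k, v=v, g=g, beta=beta,
    initial_state=state,                  # updated in place (inplace_final_state=True)
    use_qk_l2norm_in_kernel=True,
    cu_seqlens=cu_seqlens,
)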
@triton.heuristics(
{
"STORE_RESIDUAL_OUT": lambda args: args["residual_out"] is not None,
"HAS_RESIDUAL": lambda args: args["residual"] is not None,
"HAS_WEIGHT": lambda args: args["w"] is not None,
"HAS_BIAS": lambda args: args["b"] is not None,
}
)
@triton.jit
def layer_norm_gated_fwd_kernel(
x, # pointer to the input
g, # pointer to the gate
y, # pointer to the output
w, # pointer to the weights
b, # pointer to the biases
residual, # pointer to the residual
residual_out, # pointer to the residual
mean, # pointer to the mean
rstd, # pointer to the 1/std
eps, # epsilon to avoid division by zero
T, # number of rows in x
D: tl.constexpr, # number of columns in x
BT: tl.constexpr,
BD: tl.constexpr,
ACTIVATION: tl.constexpr,
IS_RMS_NORM: tl.constexpr,
STORE_RESIDUAL_OUT: tl.constexpr,
HAS_RESIDUAL: tl.constexpr,
HAS_WEIGHT: tl.constexpr,
HAS_BIAS: tl.constexpr,
):
i_t = tl.program_id(0)
o_d = tl.arange(0, BD)
m_d = o_d < D
p_x = tl.make_block_ptr(x, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
b_x = tl.load(p_x, boundary_check=(0, 1)).to(tl.float32)
if HAS_RESIDUAL:
p_res = tl.make_block_ptr(
residual, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0)
)
b_x += tl.load(p_res, boundary_check=(0, 1)).to(tl.float32)
if STORE_RESIDUAL_OUT:
p_res_out = tl.make_block_ptr(
residual_out, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0)
)
tl.store(p_res_out, b_x.to(p_res_out.dtype.element_ty), boundary_check=(0, 1))
if not IS_RMS_NORM:
b_mean = tl.sum(b_x, axis=1) / D
p_mean = tl.make_block_ptr(mean, (T,), (1,), (i_t * BT,), (BT,), (0,))
tl.store(p_mean, b_mean.to(p_mean.dtype.element_ty), boundary_check=(0,))
b_xbar = tl.where(m_d[None, :], b_x - b_mean[:, None], 0.0)
b_var = tl.sum(b_xbar * b_xbar, axis=1) / D
else:
b_xbar = tl.where(m_d[None, :], b_x, 0.0)
b_var = tl.sum(b_xbar * b_xbar, axis=1) / D
b_rstd = 1 / tl.sqrt(b_var + eps)
p_rstd = tl.make_block_ptr(rstd, (T,), (1,), (i_t * BT,), (BT,), (0,))
tl.store(p_rstd, b_rstd.to(p_rstd.dtype.element_ty), boundary_check=(0,))
if HAS_WEIGHT:
b_w = tl.load(w + o_d, mask=m_d).to(tl.float32)
if HAS_BIAS:
b_b = tl.load(b + o_d, mask=m_d).to(tl.float32)
b_x_hat = (
(b_x - b_mean[:, None]) * b_rstd[:, None]
if not IS_RMS_NORM
else b_x * b_rstd[:, None]
)
b_y = b_x_hat * b_w[None, :] if HAS_WEIGHT else b_x_hat
if HAS_BIAS:
b_y = b_y + b_b[None, :]
# swish/sigmoid output gate
p_g = tl.make_block_ptr(g, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32)
if ACTIVATION == "swish" or ACTIVATION == "silu":
b_y = b_y * b_g * tl.sigmoid(b_g)
elif ACTIVATION == "sigmoid":
b_y = b_y * tl.sigmoid(b_g)
# Write output
p_y = tl.make_block_ptr(y, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
tl.store(p_y, b_y.to(p_y.dtype.element_ty), boundary_check=(0, 1))
@triton.heuristics(
{
"STORE_RESIDUAL_OUT": lambda args: args["residual_out"] is not None,
"HAS_RESIDUAL": lambda args: args["residual"] is not None,
"HAS_WEIGHT": lambda args: args["w"] is not None,
"HAS_BIAS": lambda args: args["b"] is not None,
}
)
@triton.jit
def layer_norm_gated_fwd_kernel1(
x, # pointer to the input
g, # pointer to the gate
y, # pointer to the output
w, # pointer to the weights
b, # pointer to the biases
residual, # pointer to the residual
residual_out, # pointer to the residual
mean, # pointer to the mean
rstd, # pointer to the 1/std
eps, # epsilon to avoid division by zero
D: tl.constexpr, # number of columns in x
BD: tl.constexpr,
ACTIVATION: tl.constexpr,
IS_RMS_NORM: tl.constexpr,
STORE_RESIDUAL_OUT: tl.constexpr,
HAS_RESIDUAL: tl.constexpr,
HAS_WEIGHT: tl.constexpr,
HAS_BIAS: tl.constexpr,
):
i_t = tl.program_id(0)
x += i_t * D
y += i_t * D
g += i_t * D
if HAS_RESIDUAL:
residual += i_t * D
if STORE_RESIDUAL_OUT:
residual_out += i_t * D
o_d = tl.arange(0, BD)
m_d = o_d < D
b_x = tl.load(x + o_d, mask=m_d, other=0.0).to(tl.float32)
if HAS_RESIDUAL:
b_x += tl.load(residual + o_d, mask=m_d, other=0.0).to(tl.float32)
if STORE_RESIDUAL_OUT:
tl.store(residual_out + o_d, b_x, mask=m_d)
if not IS_RMS_NORM:
b_mean = tl.sum(b_x, axis=0) / D
tl.store(mean + i_t, b_mean)
b_xbar = tl.where(m_d, b_x - b_mean, 0.0)
b_var = tl.sum(b_xbar * b_xbar, axis=0) / D
else:
b_xbar = tl.where(m_d, b_x, 0.0)
b_var = tl.sum(b_xbar * b_xbar, axis=0) / D
b_rstd = 1 / tl.sqrt(b_var + eps)
tl.store(rstd + i_t, b_rstd)
if HAS_WEIGHT:
b_w = tl.load(w + o_d, mask=m_d).to(tl.float32)
if HAS_BIAS:
b_b = tl.load(b + o_d, mask=m_d).to(tl.float32)
b_x_hat = (b_x - b_mean) * b_rstd if not IS_RMS_NORM else b_x * b_rstd
b_y = b_x_hat * b_w if HAS_WEIGHT else b_x_hat
if HAS_BIAS:
b_y = b_y + b_b
# swish/sigmoid output gate
b_g = tl.load(g + o_d, mask=m_d, other=0.0).to(tl.float32)
if ACTIVATION == "swish" or ACTIVATION == "silu":
b_y = b_y * b_g * tl.sigmoid(b_g)
elif ACTIVATION == "sigmoid":
b_y = b_y * tl.sigmoid(b_g)
# Write output
tl.store(y + o_d, b_y, mask=m_d)
def layer_norm_gated_fwd(
x: torch.Tensor,
g: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor,
activation: str = "swish",
eps: float = 1e-5,
residual: torch.Tensor = None,
out_dtype: torch.dtype = None,
residual_dtype: torch.dtype = None,
is_rms_norm: bool = False,
):
if residual is not None:
residual_dtype = residual.dtype
T, D = x.shape
if residual is not None:
assert residual.shape == (T, D)
if weight is not None:
assert weight.shape == (D,)
if bias is not None:
assert bias.shape == (D,)
# allocate output
y = x if out_dtype is None else torch.empty_like(x, dtype=out_dtype)
if residual is not None or (
residual_dtype is not None and residual_dtype != x.dtype
):
residual_out = torch.empty(T, D, device=x.device, dtype=residual_dtype)
else:
residual_out = None
mean = (
torch.empty((T,), dtype=torch.float, device=x.device)
if not is_rms_norm
else None
)
rstd = torch.empty((T,), dtype=torch.float, device=x.device)
# Less than 64KB per feature: enqueue fused kernel
MAX_FUSED_SIZE = 65536 // x.element_size()
BD = min(MAX_FUSED_SIZE, next_power_of_2(D))
if D > BD:
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
# heuristics for number of warps
if D <= 512:
BT = 32
layer_norm_gated_fwd_kernel[(cdiv(T, BT),)](
x=x,
g=g,
y=y,
w=weight,
b=bias,
residual=residual,
residual_out=residual_out,
mean=mean,
rstd=rstd,
eps=eps,
T=T,
D=D,
BD=BD,
BT=BT,
ACTIVATION=activation,
IS_RMS_NORM=is_rms_norm,
num_warps=4,
)
else:
layer_norm_gated_fwd_kernel1[(T,)](
x=x,
g=g,
y=y,
w=weight,
b=bias,
residual=residual,
residual_out=residual_out,
mean=mean,
rstd=rstd,
eps=eps,
D=D,
BD=BD,
ACTIVATION=activation,
IS_RMS_NORM=is_rms_norm,
num_warps=4,
)
# residual_out is None if residual is None and residual_dtype == input_dtype
return y, mean, rstd, residual_out if residual_out is not None else x
def rms_norm_gated(
x: torch.Tensor,
g: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor,
activation: str = "swish",
residual: torch.Tensor | None = None,
prenorm: bool = False,
residual_in_fp32: bool = False,
eps: float = 1e-6,
):
x_shape_og = x.shape
# reshape input data into 2D tensor
x = x.contiguous().reshape(-1, x.shape[-1])
g = g.contiguous().reshape(-1, g.shape[-1])
if residual is not None:
assert residual.shape == x_shape_og
residual = residual.contiguous().reshape(-1, residual.shape[-1])
residual_dtype = (
residual.dtype
if residual is not None
else (torch.float if residual_in_fp32 else None)
)
y, _, _, residual_out = layer_norm_gated_fwd(
x=x,
g=g,
weight=weight,
bias=bias,
activation=activation,
eps=eps,
residual=residual,
residual_dtype=residual_dtype,
is_rms_norm=True,
)
y = y.reshape(x_shape_og)
return y if not prenorm else (y, residual_out.reshape(x_shape_og))
class FusedRMSNormGated(nn.Module):
def __init__(
self,
hidden_size: int,
elementwise_affine: bool = True,
eps: float = 1e-5,
activation: str = "swish",
device: torch.device | None = None,
dtype: torch.dtype | None = None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.hidden_size = hidden_size
self.elementwise_affine = elementwise_affine
self.eps = eps
self.activation = activation
if self.activation not in ["swish", "silu", "sigmoid"]:
raise ValueError(f"Unsupported activation: {self.activation}")
if elementwise_affine:
self.weight = nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
else:
self.register_parameter("weight", None)
self.register_parameter("bias", None)
def forward(
self,
x: torch.Tensor,
g: torch.Tensor,
residual: torch.Tensor | None = None,
prenorm: bool = False,
residual_in_fp32: bool = False,
) -> torch.Tensor:
return rms_norm_gated(
x,
g,
self.weight,
self.bias,
self.activation,
residual=residual,
eps=self.eps,
prenorm=prenorm,
residual_in_fp32=residual_in_fp32,
)
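A short usage sketch (not part of this commit) for `FusedRMSNormGated`; sizes are hypothetical and a CUDA device is required:

import torch
import torch.nn as nn

from sglang.srt.layers.attention.fla.kda import FusedRMSNormGated

norm = FusedRMSNormGated(128, eps=1e-6, activation="sigmoid", device="cuda", dtype=torch.bfloat16)
nn.init.ones_(norm.weight)                # normally loaded from the checkpoint
x = torch.randn(16, 128, device="cuda", dtype=torch.bfloat16)     # values to normalize
gate = torch.randn(16, 128, device="cuda", dtype=torch.bfloat16)  # gating branch
y = norm(x, gate)                         # rms_norm(x) * weight, gated by sigmoid(gate)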
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.autotune(
configs=[
triton.Config({"BK": BK}, num_warps=num_warps, num_stages=num_stages)
for BK in [32, 64]
for num_warps in [1, 2, 4, 8]
for num_stages in [2, 3, 4]
],
key=["BC"],
)
@triton.jit(do_not_specialize=["T"])
def chunk_kda_scaled_dot_kkt_fwd_kernel_intra_sub_inter(
q,
k,
g,
beta,
A,
Aqk,
scale,
cu_seqlens,
chunk_indices,
T,
H: tl.constexpr,
K: tl.constexpr,
BT: tl.constexpr,
BC: tl.constexpr,
BK: tl.constexpr,
NC: tl.constexpr,
IS_VARLEN: tl.constexpr,
):
i_t, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
i_b, i_h = i_bh // H, i_bh % H
i_i, i_j = i_c // NC, i_c % NC
if IS_VARLEN:
i_n, i_t = (
tl.load(chunk_indices + i_t * 2).to(tl.int32),
tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
)
bos, eos = (
tl.load(cu_seqlens + i_n).to(tl.int32),
tl.load(cu_seqlens + i_n + 1).to(tl.int32),
)
T = eos - bos
else:
bos, eos = i_b * T, i_b * T + T
if i_t * BT + i_i * BC >= T:
return
if i_i <= i_j:
return
q += (bos * H + i_h) * K
k += (bos * H + i_h) * K
g += (bos * H + i_h) * K
A += (bos * H + i_h) * BT
Aqk += (bos * H + i_h) * BT
p_b = tl.make_block_ptr(
beta + bos * H + i_h, (T,), (H,), (i_t * BT + i_i * BC,), (BC,), (0,)
)
b_b = tl.load(p_b, boundary_check=(0,))
b_A = tl.zeros([BC, BC], dtype=tl.float32)
b_Aqk = tl.zeros([BC, BC], dtype=tl.float32)
for i_k in range(tl.cdiv(K, BK)):
p_q = tl.make_block_ptr(
q, (T, K), (H * K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)
)
p_k = tl.make_block_ptr(
k, (T, K), (H * K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)
)
p_g = tl.make_block_ptr(
g, (T, K), (H * K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)
)
b_kt = tl.make_block_ptr(
k, (K, T), (1, H * K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1)
)
p_gk = tl.make_block_ptr(
g, (K, T), (1, H * K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1)
)
o_k = i_k * BK + tl.arange(0, BK)
m_k = o_k < K
# [BK,]
b_gn = tl.load(g + (i_t * BT + i_i * BC) * H * K + o_k, mask=m_k, other=0)
# [BC, BK]
b_g = tl.load(p_g, boundary_check=(0, 1))
b_k = tl.load(p_k, boundary_check=(0, 1)) * exp(b_g - b_gn[None, :])
# [BK, BC]
b_gk = tl.load(p_gk, boundary_check=(0, 1))
b_kt = tl.load(b_kt, boundary_check=(0, 1))
# [BC, BC]
b_ktg = b_kt * exp(b_gn[:, None] - b_gk)
b_A += tl.dot(b_k, b_ktg)
b_q = tl.load(p_q, boundary_check=(0, 1))
b_qg = b_q * exp(b_g - b_gn[None, :]) * scale
b_Aqk += tl.dot(b_qg, b_ktg)
b_A *= b_b[:, None]
p_A = tl.make_block_ptr(
A, (T, BT), (H * BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0)
)
tl.store(p_A, b_A.to(A.dtype.element_ty), boundary_check=(0, 1))
p_Aqk = tl.make_block_ptr(
Aqk, (T, BT), (H * BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0)
)
tl.store(p_Aqk, b_Aqk.to(Aqk.dtype.element_ty), boundary_check=(0, 1))
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.autotune(
configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8]],
key=["BK", "BT"],
)
@triton.jit(do_not_specialize=["T"])
def chunk_kda_scaled_dot_kkt_fwd_kernel_intra_sub_intra(
q,
k,
g,
beta,
A,
Aqk,
scale,
cu_seqlens,
chunk_indices,
T,
H: tl.constexpr,
K: tl.constexpr,
BT: tl.constexpr,
BC: tl.constexpr,
BK: tl.constexpr,
IS_VARLEN: tl.constexpr,
):
i_t, i_i, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
i_b, i_h = i_bh // H, i_bh % H
if IS_VARLEN:
i_n, i_t = (
tl.load(chunk_indices + i_t * 2).to(tl.int32),
tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
)
bos, eos = (
tl.load(cu_seqlens + i_n).to(tl.int32),
tl.load(cu_seqlens + i_n + 1).to(tl.int32),
)
T = eos - bos
else:
bos, eos = i_b * T, i_b * T + T
if i_t * BT + i_i * BC >= T:
return
o_i = tl.arange(0, BC)
o_k = tl.arange(0, BK)
m_k = o_k < K
m_A = (i_t * BT + i_i * BC + o_i) < T
o_A = (bos + i_t * BT + i_i * BC + o_i) * H * BT + i_h * BT + i_i * BC
p_q = tl.make_block_ptr(
q + (bos * H + i_h) * K,
(T, K),
(H * K, 1),
(i_t * BT + i_i * BC, 0),
(BC, BK),
(1, 0),
)
p_k = tl.make_block_ptr(
k + (bos * H + i_h) * K,
(T, K),
(H * K, 1),
(i_t * BT + i_i * BC, 0),
(BC, BK),
(1, 0),
)
p_g = tl.make_block_ptr(
g + (bos * H + i_h) * K,
(T, K),
(H * K, 1),
(i_t * BT + i_i * BC, 0),
(BC, BK),
(1, 0),
)
b_q = tl.load(p_q, boundary_check=(0, 1))
b_k = tl.load(p_k, boundary_check=(0, 1))
b_g = tl.load(p_g, boundary_check=(0, 1))
p_b = beta + (bos + i_t * BT + i_i * BC + o_i) * H + i_h
b_k = b_k * tl.load(p_b, mask=m_A, other=0)[:, None]
p_kt = k + (bos + i_t * BT + i_i * BC) * H * K + i_h * K + o_k
p_gk = g + (bos + i_t * BT + i_i * BC) * H * K + i_h * K + o_k
for j in range(0, min(BC, T - i_t * BT - i_i * BC)):
b_kt = tl.load(p_kt, mask=m_k, other=0).to(tl.float32)
b_gk = tl.load(p_gk, mask=m_k, other=0).to(tl.float32)
b_ktg = b_kt[None, :] * exp(b_g - b_gk[None, :])
b_A = tl.sum(b_k * b_ktg, 1)
b_A = tl.where(o_i > j, b_A, 0.0)
b_Aqk = tl.sum(b_q * b_ktg, 1)
b_Aqk = tl.where(o_i >= j, b_Aqk * scale, 0.0)
tl.store(A + o_A + j, b_A, mask=m_A)
tl.store(Aqk + o_A + j, b_Aqk, mask=m_A)
p_kt += H * K
p_gk += H * K
def chunk_kda_scaled_dot_kkt_fwd(
q: torch.Tensor,
k: torch.Tensor,
gk: torch.Tensor | None = None,
beta: torch.Tensor | None = None,
scale: float | None = None,
cu_seqlens: torch.LongTensor | None = None,
chunk_size: int = 64,
output_dtype: torch.dtype = torch.float32,
) -> tuple[torch.Tensor, torch.Tensor]:
r"""
Compute beta * K * K^T.
Args:
k (torch.Tensor):
The key tensor of shape `[B, T, H, K]`.
beta (torch.Tensor):
The beta tensor of shape `[B, T, H]`.
gk (torch.Tensor):
The cumulative sum of the gate tensor of shape `[B, T, H, K]` applied to the key tensor. Default: `None`.
cu_seqlens (torch.LongTensor):
The cumulative sequence lengths of the input tensor.
Default: None
chunk_size (int):
The chunk size. Default: 64.
output_dtype (torch.dtype):
The dtype of the output tensor. Default: `torch.float32`
Returns:
beta * K * K^T of shape `[B, T, H, BT]` where `BT` is the chunk size.
"""
B, T, H, K = k.shape
assert K <= 256
BT = chunk_size
chunk_indices = (
prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
)
NT = cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
BC = min(16, BT)
NC = cdiv(BT, BC)
BK = max(next_power_of_2(K), 16)
A = torch.zeros(B, T, H, BT, device=k.device, dtype=output_dtype)
Aqk = torch.zeros(B, T, H, BT, device=k.device, dtype=output_dtype)
grid = (NT, NC * NC, B * H)
chunk_kda_scaled_dot_kkt_fwd_kernel_intra_sub_inter[grid](
q=q,
k=k,
g=gk,
beta=beta,
A=A,
Aqk=Aqk,
scale=scale,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
T=T,
H=H,
K=K,
BT=BT,
BC=BC,
NC=NC,
)
grid = (NT, NC, B * H)
chunk_kda_scaled_dot_kkt_fwd_kernel_intra_sub_intra[grid](
q=q,
k=k,
g=gk,
beta=beta,
A=A,
Aqk=Aqk,
scale=scale,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
T=T,
H=H,
K=K,
BT=BT,
BC=BC,
BK=BK,
)
return A, Aqk
@triton.heuristics(
{
"STORE_QG": lambda args: args["qg"] is not None,
"STORE_KG": lambda args: args["kg"] is not None,
"IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
}
)
@triton.autotune(
configs=[
triton.Config({}, num_warps=num_warps, num_stages=num_stages)
for num_warps in [2, 4, 8]
for num_stages in [2, 3, 4]
],
key=["H", "K", "V", "BT", "BK", "BV", "IS_VARLEN"],
)
@triton.jit(do_not_specialize=["T"])
def recompute_w_u_fwd_kernel(
q,
k,
qg,
kg,
v,
beta,
w,
u,
A,
gk,
cu_seqlens,
chunk_indices,
T,
H: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BT: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr,
STORE_QG: tl.constexpr,
STORE_KG: tl.constexpr,
IS_VARLEN: tl.constexpr,
DOT_PRECISION: tl.constexpr,
):
i_t, i_bh = tl.program_id(0), tl.program_id(1)
i_b, i_h = i_bh // H, i_bh % H
if IS_VARLEN:
i_n, i_t = (
tl.load(chunk_indices + i_t * 2).to(tl.int32),
tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
)
bos, eos = (
tl.load(cu_seqlens + i_n).to(tl.int32),
tl.load(cu_seqlens + i_n + 1).to(tl.int32),
)
T = eos - bos
else:
bos, eos = i_b * T, i_b * T + T
p_b = tl.make_block_ptr(beta + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
b_b = tl.load(p_b, boundary_check=(0,))
p_A = tl.make_block_ptr(
A + (bos * H + i_h) * BT, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)
)
b_A = tl.load(p_A, boundary_check=(0, 1))
for i_v in range(tl.cdiv(V, BV)):
p_v = tl.make_block_ptr(
v + (bos * H + i_h) * V,
(T, V),
(H * V, 1),
(i_t * BT, i_v * BV),
(BT, BV),
(1, 0),
)
p_u = tl.make_block_ptr(
u + (bos * H + i_h) * V,
(T, V),
(H * V, 1),
(i_t * BT, i_v * BV),
(BT, BV),
(1, 0),
)
b_v = tl.load(p_v, boundary_check=(0, 1))
b_vb = (b_v * b_b[:, None]).to(b_v.dtype)
b_u = tl.dot(b_A, b_vb, input_precision=DOT_PRECISION)
tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1))
for i_k in range(tl.cdiv(K, BK)):
p_w = tl.make_block_ptr(
w + (bos * H + i_h) * K,
(T, K),
(H * K, 1),
(i_t * BT, i_k * BK),
(BT, BK),
(1, 0),
)
p_k = tl.make_block_ptr(
k + (bos * H + i_h) * K,
(T, K),
(H * K, 1),
(i_t * BT, i_k * BK),
(BT, BK),
(1, 0),
)
b_k = tl.load(p_k, boundary_check=(0, 1))
b_kb = b_k * b_b[:, None]
p_gk = tl.make_block_ptr(
gk + (bos * H + i_h) * K,
(T, K),
(H * K, 1),
(i_t * BT, i_k * BK),
(BT, BK),
(1, 0),
)
b_gk = tl.load(p_gk, boundary_check=(0, 1))
b_kb *= exp(b_gk)
if STORE_QG:
p_q = tl.make_block_ptr(
q + (bos * H + i_h) * K,
(T, K),
(H * K, 1),
(i_t * BT, i_k * BK),
(BT, BK),
(1, 0),
)
p_qg = tl.make_block_ptr(
qg + (bos * H + i_h) * K,
(T, K),
(H * K, 1),
(i_t * BT, i_k * BK),
(BT, BK),
(1, 0),
)
b_q = tl.load(p_q, boundary_check=(0, 1))
b_qg = b_q * exp(b_gk)
tl.store(p_qg, b_qg.to(p_qg.dtype.element_ty), boundary_check=(0, 1))
if STORE_KG:
last_idx = min(i_t * BT + BT, T) - 1
o_k = i_k * BK + tl.arange(0, BK)
m_k = o_k < K
b_gn = tl.load(
gk + ((bos + last_idx) * H + i_h) * K + o_k, mask=m_k, other=0.0
)
b_kg = b_k * exp(b_gn - b_gk)
p_kg = tl.make_block_ptr(
kg + (bos * H + i_h) * K,
(T, K),
(H * K, 1),
(i_t * BT, i_k * BK),
(BT, BK),
(1, 0),
)
tl.store(p_kg, b_kg.to(p_kg.dtype.element_ty), boundary_check=(0, 1))
b_w = tl.dot(b_A, b_kb.to(b_k.dtype))
tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))
def recompute_w_u_fwd(
k: torch.Tensor,
v: torch.Tensor,
beta: torch.Tensor,
A: torch.Tensor,
q: torch.Tensor | None = None,
gk: torch.Tensor | None = None,
cu_seqlens: torch.LongTensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
B, T, H, K, V = *k.shape, v.shape[-1]
BT = A.shape[-1]
BK = 64
BV = 64
chunk_indices = (
prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
)
NT = cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
w = torch.empty_like(k)
u = torch.empty_like(v)
kg = torch.empty_like(k) if gk is not None else None
recompute_w_u_fwd_kernel[(NT, B * H)](
q=q,
k=k,
qg=None,
kg=kg,
v=v,
beta=beta,
w=w,
u=u,
A=A,
gk=gk,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
T=T,
H=H,
K=K,
V=V,
BT=BT,
BK=BK,
BV=BV,
DOT_PRECISION="ieee",
)
return w, u, None, kg
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.autotune(
configs=[
triton.Config({"BK": BK, "BV": BV}, num_warps=num_warps, num_stages=num_stages)
for BK in [32, 64]
for BV in [64, 128]
for num_warps in [2, 4, 8]
for num_stages in [2, 3, 4]
],
key=["BT"],
)
@triton.jit(do_not_specialize=["T"])
def chunk_gla_fwd_kernel_o(
q,
v,
g,
h,
o,
A,
cu_seqlens,
chunk_indices,
scale,
T,
H: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BT: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr,
IS_VARLEN: tl.constexpr,
):
i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
i_b, i_h = i_bh // H, i_bh % H
if IS_VARLEN:
i_tg = i_t
i_n, i_t = (
tl.load(chunk_indices + i_t * 2).to(tl.int32),
tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
)
bos, eos = (
tl.load(cu_seqlens + i_n).to(tl.int32),
tl.load(cu_seqlens + i_n + 1).to(tl.int32),
)
T = eos - bos
NT = tl.cdiv(T, BT)
else:
NT = tl.cdiv(T, BT)
i_tg = i_b * NT + i_t
bos, eos = i_b * T, i_b * T + T
m_s = tl.arange(0, BT)[:, None] >= tl.arange(0, BT)[None, :]
b_o = tl.zeros([BT, BV], dtype=tl.float32)
for i_k in range(tl.cdiv(K, BK)):
p_q = tl.make_block_ptr(
q + (bos * H + i_h) * K,
(T, K),
(H * K, 1),
(i_t * BT, i_k * BK),
(BT, BK),
(1, 0),
)
p_g = tl.make_block_ptr(
g + (bos * H + i_h) * K,
(T, K),
(H * K, 1),
(i_t * BT, i_k * BK),
(BT, BK),
(1, 0),
)
p_h = tl.make_block_ptr(
h + (i_tg * H + i_h) * K * V,
(K, V),
(V, 1),
(i_k * BK, i_v * BV),
(BK, BV),
(1, 0),
)
# [BT, BK]
b_q = tl.load(p_q, boundary_check=(0, 1))
b_q = (b_q * scale).to(b_q.dtype)
# [BT, BK]
b_g = tl.load(p_g, boundary_check=(0, 1))
# [BT, BK]
b_qg = (b_q * exp(b_g)).to(b_q.dtype)
# [BK, BV]
b_h = tl.load(p_h, boundary_check=(0, 1))
        # works but dkw (don't know why); comment kept from the upstream FLA kernel
# [BT, BV]
if i_k >= 0:
b_o += tl.dot(b_qg, b_h.to(b_qg.dtype))
p_v = tl.make_block_ptr(
v + (bos * H + i_h) * V,
(T, V),
(H * V, 1),
(i_t * BT, i_v * BV),
(BT, BV),
(1, 0),
)
p_o = tl.make_block_ptr(
o + (bos * H + i_h) * V,
(T, V),
(H * V, 1),
(i_t * BT, i_v * BV),
(BT, BV),
(1, 0),
)
p_A = tl.make_block_ptr(
A + (bos * H + i_h) * BT, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)
)
# [BT, BV]
b_v = tl.load(p_v, boundary_check=(0, 1))
# [BT, BT]
b_A = tl.load(p_A, boundary_check=(0, 1))
b_A = tl.where(m_s, b_A, 0.0).to(b_v.dtype)
b_o += tl.dot(b_A, b_v, allow_tf32=False)
tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
def chunk_gla_fwd_o_gk(
q: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
A: torch.Tensor,
h: torch.Tensor,
o: torch.Tensor,
scale: float,
cu_seqlens: torch.LongTensor | None = None,
chunk_size: int = 64,
):
B, T, H, K, V = *q.shape, v.shape[-1]
BT = chunk_size
chunk_indices = (
prepare_chunk_indices(cu_seqlens, chunk_size)
if cu_seqlens is not None
else None
)
NT = cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
def grid(meta):
return (cdiv(V, meta["BV"]), NT, B * H)
chunk_gla_fwd_kernel_o[grid](
q=q,
v=v,
g=g,
h=h,
o=o,
A=A,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
scale=scale,
T=T,
H=H,
K=K,
V=V,
BT=BT,
)
return o
def chunk_kda_fwd(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: float,
initial_state: torch.Tensor,
output_final_state: bool,
cu_seqlens: torch.LongTensor | None = None,
):
chunk_size = 64
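    # Chunkwise KDA forward, composed from the kernels above:
    #   1. chunk-local cumulative sum of the per-channel log-gates g
    #   2. intra-chunk A = beta * K K^T and Aqk = scale * Q K^T with the gates applied
    #   3. invert the intra-chunk lower-triangular system (solve_tril)
    #   4. recompute w, u and the gate-scaled keys kg from A
    #   5. chunk-recurrent state update producing h, v_new and the final state
    #   6. output = inter-chunk contribution (q against h) + intra-chunk Aqk @ v_new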
g = chunk_local_cumsum(g, chunk_size=chunk_size, cu_seqlens=cu_seqlens)
# the intra Aqk is kept in fp32
# the computation has very marginal effect on the entire throughput
A, Aqk = chunk_kda_scaled_dot_kkt_fwd(
q=q,
k=k,
gk=g,
beta=beta,
scale=scale,
cu_seqlens=cu_seqlens,
output_dtype=torch.float32,
)
A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype)
w, u, _, kg = recompute_w_u_fwd(
k=k,
v=v,
beta=beta,
A=A,
gk=g,
cu_seqlens=cu_seqlens,
)
del A
h, v_new, final_state = chunk_gated_delta_rule_fwd_h(
k=kg,
w=w,
u=u,
gk=g,
initial_state=initial_state,
output_final_state=output_final_state,
cu_seqlens=cu_seqlens,
)
del w, u, kg
o = chunk_gla_fwd_o_gk(
q=q,
v=v_new,
g=g,
A=Aqk,
h=h,
o=v,
scale=scale,
cu_seqlens=cu_seqlens,
chunk_size=chunk_size,
)
del Aqk, v_new, h
return o, final_state
def chunk_kda(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: float = None,
initial_state: torch.Tensor = None,
output_final_state: bool = False,
use_qk_l2norm_in_kernel: bool = False,
cu_seqlens: torch.LongTensor | None = None,
**kwargs,
):
if scale is None:
scale = k.shape[-1] ** -0.5
if use_qk_l2norm_in_kernel:
q = l2norm_fwd(q.contiguous())
k = l2norm_fwd(k.contiguous())
o, final_state = chunk_kda_fwd(
q=q,
k=k,
v=v.contiguous(),
g=g.contiguous(),
beta=beta.contiguous(),
scale=scale,
initial_state=initial_state.contiguous(),
output_final_state=output_final_state,
cu_seqlens=cu_seqlens,
)
return o, final_state
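A rough prefill-path sketch (not part of this commit) for `chunk_kda`, mirroring the varlen layout used by `forward_extend` below; sizes are hypothetical and a CUDA device is required:

import torch

from sglang.srt.layers.attention.fla.kda import chunk_kda

# Two sequences of 5 and 7 prompt tokens, flattened to a single varlen batch.
H, K, V = 4, 64, 64                       # hypothetical head count / dims
n_tok, n_seq = 12, 2
cu_seqlens = torch.tensor([0, 5, 12], dtype=torch.long, device="cuda")
q = torch.randn(1, n_tok, H, K, device="cuda", dtype=torch.bfloat16)
k = torch.randn(1, n_tok, H, K, device="cuda", dtype=torch.bfloat16)
v = torch.randn(1, n_tok, H, V, device="cuda", dtype=torch.bfloat16)
g = torch.randn(1, n_tok, H, K, device="cuda", dtype=torch.float32)   # per-channel log-gates
beta = torch.rand(1, n_tok, H, device="cuda", dtype=torch.float32)
state = torch.zeros(n_seq, H, K, V, device="cuda", dtype=torch.float32)

o, final_state = chunk_kda(
    q=q, k=k, v=v, g=g, beta=beta,
    initial_state=state,
    output_final_state=True,
    use_qk_l2norm_in_kernel=True,
    cu_seqlens=cu_seqlens,
)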
@triton.autotune(
configs=[
triton.Config({"BT": bt}, num_warps=nw, num_stages=ns)
for bt in BT_LIST_AUTOTUNE
for nw in NUM_WARPS_AUTOTUNE
for ns in [2, 3]
],
key=["H", "D"],
)
@triton.jit
def kda_gate_fwd_kernel(
g,
A,
y,
g_bias,
beta: tl.constexpr,
threshold: tl.constexpr,
T,
H,
D: tl.constexpr,
BT: tl.constexpr,
BD: tl.constexpr,
HAS_BIAS: tl.constexpr,
):
i_t, i_h = tl.program_id(0), tl.program_id(1)
n_t = i_t * BT
b_a = tl.load(A + i_h).to(tl.float32)
b_a = -tl.exp(b_a)
stride_row = H * D
stride_col = 1
g_ptr = tl.make_block_ptr(
base=g + i_h * D,
shape=(T, D),
strides=(stride_row, stride_col),
offsets=(n_t, 0),
block_shape=(BT, BD),
order=(1, 0),
)
y_ptr = tl.make_block_ptr(
base=y + i_h * D,
shape=(T, D),
strides=(stride_row, stride_col),
offsets=(n_t, 0),
block_shape=(BT, BD),
order=(1, 0),
)
b_g = tl.load(g_ptr, boundary_check=(0, 1)).to(tl.float32)
if HAS_BIAS:
n_d = tl.arange(0, BD)
bias_mask = n_d < D
b_bias = tl.load(g_bias + i_h * D + n_d, mask=bias_mask, other=0.0).to(
tl.float32
)
b_g = b_g + b_bias[None, :]
# softplus(x, beta) = (1/beta) * log(1 + exp(beta * x))
    # when beta * x > threshold, fall back to the linear approximation x
g_scaled = b_g * beta
use_linear = g_scaled > threshold
sp = tl.where(use_linear, b_g, (1.0 / beta) * log(1.0 + tl.exp(g_scaled)))
b_y = b_a * sp
tl.store(y_ptr, b_y.to(y.dtype.element_ty), boundary_check=(0, 1))
def fused_kda_gate(
g: torch.Tensor,
A: torch.Tensor,
head_k_dim: int,
g_bias: torch.Tensor | None = None,
beta: float = 1.0,
threshold: float = 20.0,
) -> torch.Tensor:
"""
Forward pass for KDA gate:
input g: [..., H*D]
param A: [H] or [1, 1, H, 1]
beta: softplus beta parameter
threshold: softplus threshold parameter
return : [..., H, D]
"""
orig_shape = g.shape[:-1]
g = g.view(-1, g.shape[-1])
T = g.shape[0]
HD = g.shape[1]
H = A.numel()
assert H * head_k_dim == HD
y = torch.empty_like(g, dtype=torch.float32)
def grid(meta):
return (cdiv(T, meta["BT"]), H)
kda_gate_fwd_kernel[grid](
g,
A,
y,
g_bias,
beta,
threshold,
T,
H,
head_k_dim,
BD=next_power_of_2(head_k_dim),
HAS_BIAS=g_bias is not None,
)
y = y.view(*orig_shape, H, head_k_dim)
return y
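A plain-PyTorch reference (not part of this commit) of what `fused_kda_gate` computes, with made-up sizes:

import torch
import torch.nn.functional as F

H, D = 2, 4                               # hypothetical head count / head dim
g = torch.randn(3, H * D)                 # [..., H*D] pre-gate activations
A = torch.randn(H)                        # per-head parameter
ref = -torch.exp(A)[None, :, None] * F.softplus(g.view(3, H, D).float(), beta=1.0, threshold=20.0)
# fused_kda_gate(g.cuda(), A.cuda(), D) computes the same values with the Triton
# kernel above and returns a float32 tensor of shape [3, H, D].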
from typing import Optional, Union from typing import Optional, Union
import torch import torch
from einops import rearrange
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.attention.fla.chunk import chunk_gated_delta_rule from sglang.srt.layers.attention.fla.chunk import chunk_gated_delta_rule
...@@ -10,6 +11,11 @@ from sglang.srt.layers.attention.fla.fused_recurrent import ( ...@@ -10,6 +11,11 @@ from sglang.srt.layers.attention.fla.fused_recurrent import (
from sglang.srt.layers.attention.fla.fused_sigmoid_gating_recurrent import ( from sglang.srt.layers.attention.fla.fused_sigmoid_gating_recurrent import (
fused_sigmoid_gating_delta_rule_update, fused_sigmoid_gating_delta_rule_update,
) )
from sglang.srt.layers.attention.fla.kda import (
chunk_kda,
fused_kda_gate,
fused_recurrent_kda,
)
from sglang.srt.layers.attention.mamba.causal_conv1d_triton import ( from sglang.srt.layers.attention.mamba.causal_conv1d_triton import (
PAD_SLOT_ID, PAD_SLOT_ID,
causal_conv1d_fn, causal_conv1d_fn,
...@@ -227,6 +233,223 @@ class MambaAttnBackendBase(AttentionBackend): ...@@ -227,6 +233,223 @@ class MambaAttnBackendBase(AttentionBackend):
return 1 # Mamba attn does not use seq lens to index kv cache return 1 # Mamba attn does not use seq lens to index kv cache
class KimiLinearAttnBackend(MambaAttnBackendBase):
"""Attention backend using Mamba kernel."""
def forward_decode(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache: bool = True,
**kwargs,
):
q_proj_states = kwargs["q_proj_states"]
k_proj_states = kwargs["k_proj_states"]
v_proj_states = kwargs["v_proj_states"]
q_conv_weights = kwargs["q_conv_weights"]
k_conv_weights = kwargs["k_conv_weights"]
v_conv_weights = kwargs["v_conv_weights"]
q_conv_bias = kwargs["q_conv_bias"]
k_conv_bias = kwargs["k_conv_bias"]
v_conv_bias = kwargs["v_conv_bias"]
A_log = kwargs["A_log"]
dt_bias = kwargs["dt_bias"]
b_proj = kwargs["b_proj"]
f_a_proj = kwargs["f_a_proj"]
f_b_proj = kwargs["f_b_proj"]
hidden_states = kwargs["hidden_states"]
head_dim = kwargs["head_dim"]
layer_id = kwargs["layer_id"]
layer_cache = self.req_to_token_pool.mamba2_layer_cache(layer_id)
q_conv_state, k_conv_state, v_conv_state = layer_cache.conv
ssm_states = layer_cache.temporal
query_start_loc = self.forward_metadata.query_start_loc
cache_indices = self.forward_metadata.mamba_cache_indices
q_conv_state = q_conv_state.transpose(-1, -2)
k_conv_state = k_conv_state.transpose(-1, -2)
v_conv_state = v_conv_state.transpose(-1, -2)
q = causal_conv1d_update(
q_proj_states,
q_conv_state,
q_conv_weights,
q_conv_bias,
activation="silu",
conv_state_indices=cache_indices,
)
k = causal_conv1d_update(
k_proj_states,
k_conv_state,
k_conv_weights,
k_conv_bias,
activation="silu",
conv_state_indices=cache_indices,
)
v = causal_conv1d_update(
v_proj_states,
v_conv_state,
v_conv_weights,
v_conv_bias,
activation="silu",
conv_state_indices=cache_indices,
)
q, k, v = map(
lambda x: rearrange(x, "n (h d) -> 1 n h d", d=head_dim), (q, k, v)
)
beta = b_proj(hidden_states)[0].float().sigmoid()
g = f_b_proj(f_a_proj(hidden_states)[0])[0]
g = fused_kda_gate(g, A_log, head_dim, g_bias=dt_bias)
beta = beta.unsqueeze(0)
g = g.unsqueeze(0)
initial_state = ssm_states[cache_indices].contiguous()
(
core_attn_out,
last_recurrent_state,
) = fused_recurrent_kda(
q=q,
k=k,
v=v,
g=g,
beta=beta,
initial_state=initial_state,
use_qk_l2norm_in_kernel=True,
cu_seqlens=query_start_loc,
)
ssm_states[cache_indices] = last_recurrent_state
return core_attn_out
def forward_extend(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache: bool = True,
**kwargs,
):
from sglang.srt.layers.attention.mamba.causal_conv1d_triton import (
causal_conv1d_fn,
)
q_proj_states = kwargs["q_proj_states"]
k_proj_states = kwargs["k_proj_states"]
v_proj_states = kwargs["v_proj_states"]
q_conv_weights = kwargs["q_conv_weights"]
k_conv_weights = kwargs["k_conv_weights"]
v_conv_weights = kwargs["v_conv_weights"]
q_conv_bias = kwargs["q_conv_bias"]
k_conv_bias = kwargs["k_conv_bias"]
v_conv_bias = kwargs["v_conv_bias"]
A_log = kwargs["A_log"]
dt_bias = kwargs["dt_bias"]
b_proj = kwargs["b_proj"]
f_a_proj = kwargs["f_a_proj"]
f_b_proj = kwargs["f_b_proj"]
hidden_states = kwargs["hidden_states"]
head_dim = kwargs["head_dim"]
layer_id = kwargs["layer_id"]
query_start_loc = self.forward_metadata.query_start_loc
cache_indices = self.forward_metadata.mamba_cache_indices
mamba_cache_params = self.req_to_token_pool.mamba2_layer_cache(layer_id)
conv_state_q, conv_state_k, conv_state_v = mamba_cache_params.conv
# Transpose cached conv states to the stride layout the causal-conv1d kernels expect.
conv_state_q = conv_state_q.transpose(-1, -2)
conv_state_k = conv_state_k.transpose(-1, -2)
conv_state_v = conv_state_v.transpose(-1, -2)
ssm_states = mamba_cache_params.temporal
has_initial_state = forward_batch.extend_prefix_lens > 0
q_proj_states = q_proj_states.transpose(0, 1)
k_proj_states = k_proj_states.transpose(0, 1)
v_proj_states = v_proj_states.transpose(0, 1)
q = causal_conv1d_fn(
q_proj_states,
q_conv_weights,
q_conv_bias,
activation="silu",
conv_states=conv_state_q,
has_initial_state=has_initial_state,
cache_indices=cache_indices,
query_start_loc=query_start_loc,
seq_lens_cpu=forward_batch.extend_seq_lens_cpu,
).transpose(0, 1)
k = causal_conv1d_fn(
k_proj_states,
k_conv_weights,
k_conv_bias,
activation="silu",
conv_states=conv_state_k,
has_initial_state=has_initial_state,
cache_indices=cache_indices,
query_start_loc=query_start_loc,
seq_lens_cpu=forward_batch.extend_seq_lens_cpu,
).transpose(0, 1)
v = causal_conv1d_fn(
v_proj_states,
v_conv_weights,
v_conv_bias,
activation="silu",
conv_states=conv_state_v,
has_initial_state=has_initial_state,
cache_indices=cache_indices,
query_start_loc=query_start_loc,
seq_lens_cpu=forward_batch.extend_seq_lens_cpu,
).transpose(0, 1)
q, k, v = map(
lambda x: rearrange(x, "n (h d) -> 1 n h d", d=head_dim), (q, k, v)
)
beta = b_proj(hidden_states)[0].float().sigmoid()
g = f_b_proj(f_a_proj(hidden_states)[0])[0]
g = fused_kda_gate(g, A_log, head_dim, g_bias=dt_bias)
beta = beta.unsqueeze(0)
g = g.unsqueeze(0)
initial_state = ssm_states[cache_indices].contiguous()
(
core_attn_out,
last_recurrent_state,
) = chunk_kda(
q=q,
k=k,
v=v,
g=g,
beta=beta,
initial_state=initial_state,
output_final_state=True,
use_qk_l2norm_in_kernel=True,
cu_seqlens=query_start_loc,
)
ssm_states[cache_indices] = last_recurrent_state
return core_attn_out
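Prefill uses the chunked kernel over variable-length packed sequences; the per-request boundaries come from query_start_loc, which is passed as cu_seqlens. A small sketch with hypothetical lengths of how the packed token axis and cu_seqlens relate:

import torch

seq_lens = torch.tensor([5, 2, 9])                                # hypothetical extend lengths
cu_seqlens = torch.nn.functional.pad(seq_lens.cumsum(0), (1, 0))  # tensor([0, 5, 7, 16])

# q/k/v are packed along one token axis of length cu_seqlens[-1];
# request i owns tokens cu_seqlens[i]:cu_seqlens[i + 1].
packed_q = torch.randn(1, int(cu_seqlens[-1]), 8, 64)             # [1, n_tokens, heads, head_dim]
second_req_q = packed_q[:, cu_seqlens[1] : cu_seqlens[2]]         # the 2 tokens of request 1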
class GDNAttnBackend(MambaAttnBackendBase): class GDNAttnBackend(MambaAttnBackendBase):
"""Attention backend using Mamba kernel.""" """Attention backend using Mamba kernel."""
......
...@@ -92,7 +92,10 @@ class TritonAttnBackend(AttentionBackend): ...@@ -92,7 +92,10 @@ class TritonAttnBackend(AttentionBackend):
self.num_kv_head = model_runner.model_config.get_num_kv_heads( self.num_kv_head = model_runner.model_config.get_num_kv_heads(
get_attention_tp_size() get_attention_tp_size()
) )
if model_runner.hybrid_gdn_config is not None: if (
model_runner.hybrid_gdn_config is not None
or model_runner.kimi_linear_config is not None
):
# For hybrid linear models, layer_id = 0 may not be full attention # For hybrid linear models, layer_id = 0 may not be full attention
self.v_head_dim = model_runner.token_to_kv_pool.get_v_head_dim() self.v_head_dim = model_runner.token_to_kv_pool.get_v_head_dim()
else: else:
......
...@@ -17,7 +17,7 @@ from __future__ import annotations ...@@ -17,7 +17,7 @@ from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
from sglang.srt.configs.mamba_utils import Mamba2CacheParams from sglang.srt.configs.mamba_utils import KimiLinearCacheParams, Mamba2CacheParams
from sglang.srt.layers.attention.nsa import index_buf_accessor from sglang.srt.layers.attention.nsa import index_buf_accessor
from sglang.srt.layers.attention.nsa.quant_k_cache import quantize_k_cache from sglang.srt.layers.attention.nsa.quant_k_cache import quantize_k_cache
from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter
...@@ -33,7 +33,7 @@ KVCache actually holds the physical kv cache. ...@@ -33,7 +33,7 @@ KVCache actually holds the physical kv cache.
import abc import abc
import logging import logging
from contextlib import nullcontext from contextlib import contextmanager, nullcontext
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import numpy as np import numpy as np
...@@ -59,7 +59,9 @@ if _is_npu: ...@@ -59,7 +59,9 @@ if _is_npu:
import torch_npu import torch_npu
def get_tensor_size_bytes(t: torch.Tensor): def get_tensor_size_bytes(t: Union[torch.Tensor, List[torch.Tensor]]):
if isinstance(t, list):
return sum(get_tensor_size_bytes(x) for x in t)
return np.prod(t.shape) * t.dtype.itemsize return np.prod(t.shape) * t.dtype.itemsize
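get_tensor_size_bytes now also accepts a list of tensors so that KDA's separate q/k/v conv-state buffers are counted in memory accounting. A quick illustrative check with hypothetical shapes:

import numpy as np
import torch

def get_tensor_size_bytes(t):
    if isinstance(t, list):
        return sum(get_tensor_size_bytes(x) for x in t)
    return np.prod(t.shape) * t.dtype.itemsize

conv = [torch.zeros(4, 9, 2048, 3, dtype=torch.float32) for _ in range(3)]
assert get_tensor_size_bytes(conv) == 3 * 4 * 9 * 2048 * 3 * 4  # three buffers, 4 bytes per element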
...@@ -116,10 +118,15 @@ class ReqToTokenPool: ...@@ -116,10 +118,15 @@ class ReqToTokenPool:
class MambaPool: class MambaPool:
@dataclass(frozen=True, kw_only=True) @dataclass(frozen=True, kw_only=True)
class State: class State:
conv: torch.Tensor conv: Union[torch.Tensor, List[torch.Tensor]]
temporal: torch.Tensor temporal: torch.Tensor
def at_layer_idx(self, layer: int): def at_layer_idx(self, layer: int):
if isinstance(self.conv, list):
return type(self)(
conv=[v[layer] for v in self.conv],
temporal=self.temporal[layer],
)
return type(self)(**{k: v[layer] for k, v in vars(self).items()}) return type(self)(**{k: v[layer] for k, v in vars(self).items()})
def mem_usage_bytes(self): def mem_usage_bytes(self):
...@@ -127,14 +134,14 @@ class MambaPool: ...@@ -127,14 +134,14 @@ class MambaPool:
@dataclass(frozen=True, kw_only=True) @dataclass(frozen=True, kw_only=True)
class SpeculativeState(State): class SpeculativeState(State):
intermediate_ssm: torch.Tensor intermediate_ssm: Union[torch.Tensor, List[torch.Tensor]]
intermediate_conv_window: torch.Tensor intermediate_conv_window: torch.Tensor
def __init__( def __init__(
self, self,
*, *,
size: int, size: int,
cache_params: "Mamba2CacheParams", cache_params: Union["Mamba2CacheParams", "KimiLinearCacheParams"],
device: str, device: str,
speculative_num_draft_tokens: Optional[int] = None, speculative_num_draft_tokens: Optional[int] = None,
): ):
...@@ -157,18 +164,29 @@ class MambaPool: ...@@ -157,18 +164,29 @@ class MambaPool:
else: else:
self.custom_mem_pool = None self.custom_mem_pool = None
self.is_kda_cache = isinstance(cache_params, KimiLinearCacheParams)
with ( with (
torch.cuda.use_mem_pool(self.custom_mem_pool) torch.cuda.use_mem_pool(self.custom_mem_pool)
if self.enable_custom_mem_pool if self.enable_custom_mem_pool
else nullcontext() else nullcontext()
): ):
# assume conv_state = (dim, state_len) if self.is_kda_cache:
assert conv_state_shape[0] > conv_state_shape[1] conv_state = [
conv_state = torch.zeros( torch.zeros(
size=(num_mamba_layers, size + 1) + conv_state_shape, size=(num_mamba_layers, size + 1) + conv_shape,
dtype=conv_dtype, dtype=conv_dtype,
device=device, device=device,
) )
for conv_shape in conv_state_shape
]
else:
# assume conv_state = (dim, state_len)
assert conv_state_shape[0] > conv_state_shape[1]
conv_state = torch.zeros(
size=(num_mamba_layers, size + 1) + conv_state_shape,
dtype=conv_dtype,
device=device,
)
temporal_state = torch.zeros( temporal_state = torch.zeros(
size=(num_mamba_layers, size + 1) + temporal_state_shape, size=(num_mamba_layers, size + 1) + temporal_state_shape,
dtype=ssm_dtype, dtype=ssm_dtype,
...@@ -191,17 +209,34 @@ class MambaPool: ...@@ -191,17 +209,34 @@ class MambaPool:
) )
# Cache intermediate conv windows (last K-1 inputs) per draft token during target verify # Cache intermediate conv windows (last K-1 inputs) per draft token during target verify
# Shape: [num_layers, size + 1, speculative_num_draft_tokens, dim, K-1] # Shape: [num_layers, size + 1, speculative_num_draft_tokens, dim, K-1]
intermediate_conv_window_cache = torch.zeros(
size=( if self.is_kda_cache:
num_mamba_layers, intermediate_conv_window_cache = [
size + 1, torch.zeros(
speculative_num_draft_tokens, size=(
conv_state_shape[0], num_mamba_layers,
conv_state_shape[1], size + 1,
), speculative_num_draft_tokens,
dtype=conv_dtype, conv_shape[0],
device="cuda", conv_shape[1],
) ),
dtype=conv_dtype,
device="cuda",
)
for conv_shape in conv_state_shape
]
else:
intermediate_conv_window_cache = torch.zeros(
size=(
num_mamba_layers,
size + 1,
speculative_num_draft_tokens,
conv_state_shape[0],
conv_state_shape[1],
),
dtype=conv_dtype,
device="cuda",
)
self.mamba_cache = self.SpeculativeState( self.mamba_cache = self.SpeculativeState(
conv=conv_state, conv=conv_state,
temporal=temporal_state, temporal=temporal_state,
...@@ -255,15 +290,25 @@ class MambaPool: ...@@ -255,15 +290,25 @@ class MambaPool:
if free_index.numel() == 0: if free_index.numel() == 0:
return return
self.free_slots = torch.cat((self.free_slots, free_index)) self.free_slots = torch.cat((self.free_slots, free_index))
self.mamba_cache.conv[:, free_index] = self.mamba_cache.temporal[ if self.is_kda_cache:
:, free_index for i in range(len(self.mamba_cache.conv)):
] = 0 self.mamba_cache.conv[i][:, free_index] = 0
else:
self.mamba_cache.conv[:, free_index] = 0
self.mamba_cache.temporal[:, free_index] = 0
def clear(self): def clear(self):
self.free_slots = torch.arange(self.size, dtype=torch.int64, device=self.device) self.free_slots = torch.arange(self.size, dtype=torch.int64, device=self.device)
def copy_from(self, src_index: torch.Tensor, dst_index: torch.Tensor): def copy_from(self, src_index: torch.Tensor, dst_index: torch.Tensor):
self.mamba_cache.conv[:, dst_index] = self.mamba_cache.conv[:, src_index] if self.is_kda_cache:
for i in range(len(self.mamba_cache.conv)):
self.mamba_cache.conv[i][:, dst_index] = self.mamba_cache.conv[i][
:, src_index
]
else:
self.mamba_cache.conv[:, dst_index] = self.mamba_cache.conv[:, src_index]
self.mamba_cache.temporal[:, dst_index] = self.mamba_cache.temporal[ self.mamba_cache.temporal[:, dst_index] = self.mamba_cache.temporal[
:, src_index :, src_index
] ]
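Because the KDA cache stores its conv state as a list of three tensors (one per q/k/v short convolution) rather than a single tensor, the free/copy paths above branch on is_kda_cache. A hypothetical helper, not part of this change, that treats both layouts uniformly:

import torch
from typing import Callable, List, Union

def for_each_conv_state(
    conv: Union[torch.Tensor, List[torch.Tensor]],
    fn: Callable[[torch.Tensor], None],
) -> None:
    # Apply fn to each conv-state tensor, whether the cache stores a single
    # tensor (Mamba2) or a list of q/k/v tensors (KDA).
    for state in conv if isinstance(conv, list) else [conv]:
        fn(state)

# e.g. zeroing freed slots for either layout:
# for_each_conv_state(cache.conv, lambda s: s.index_fill_(1, free_index, 0))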
...@@ -304,7 +349,7 @@ class HybridReqToTokenPool(ReqToTokenPool): ...@@ -304,7 +349,7 @@ class HybridReqToTokenPool(ReqToTokenPool):
max_context_len: int, max_context_len: int,
device: str, device: str,
enable_memory_saver: bool, enable_memory_saver: bool,
cache_params: "Mamba2CacheParams", cache_params: Union["Mamba2CacheParams", "KimiLinearCacheParams"],
speculative_num_draft_tokens: int = None, speculative_num_draft_tokens: int = None,
): ):
super().__init__( super().__init__(
...@@ -323,7 +368,7 @@ class HybridReqToTokenPool(ReqToTokenPool): ...@@ -323,7 +368,7 @@ class HybridReqToTokenPool(ReqToTokenPool):
def _init_mamba_pool( def _init_mamba_pool(
self, self,
size: int, size: int,
cache_params: "Mamba2CacheParams", cache_params: Union["Mamba2CacheParams", "KimiLinearCacheParams"],
device: str, device: str,
speculative_num_draft_tokens: int = None, speculative_num_draft_tokens: int = None,
): ):
...@@ -812,6 +857,10 @@ class HybridLinearKVPool(KVCache): ...@@ -812,6 +857,10 @@ class HybridLinearKVPool(KVCache):
enable_kvcache_transpose: bool, enable_kvcache_transpose: bool,
device: str, device: str,
mamba_pool: MambaPool, mamba_pool: MambaPool,
# TODO: refactor mla related args
use_mla: bool = False,
kv_lora_rank: int = None,
qk_rope_head_dim: int = None,
): ):
self.size = size self.size = size
self.dtype = dtype self.dtype = dtype
...@@ -825,25 +874,42 @@ class HybridLinearKVPool(KVCache): ...@@ -825,25 +874,42 @@ class HybridLinearKVPool(KVCache):
self.mamba_pool = mamba_pool self.mamba_pool = mamba_pool
# TODO MHATransposedTokenToKVPool if enable_kvcache_transpose is True # TODO MHATransposedTokenToKVPool if enable_kvcache_transpose is True
assert not enable_kvcache_transpose assert not enable_kvcache_transpose
if _is_npu: self.use_mla = use_mla
TokenToKVPoolClass = AscendTokenToKVPool if not use_mla:
if _is_npu:
TokenToKVPoolClass = AscendTokenToKVPool
else:
TokenToKVPoolClass = MHATokenToKVPool
self.full_kv_pool = TokenToKVPoolClass(
size=size,
page_size=self.page_size,
dtype=dtype,
head_num=head_num,
head_dim=head_dim,
layer_num=self.full_layer_nums,
device=device,
enable_memory_saver=False,
)
else: else:
TokenToKVPoolClass = MHATokenToKVPool TokenToKVPoolClass = MLATokenToKVPool
self.full_kv_pool = TokenToKVPoolClass( self.full_kv_pool = TokenToKVPoolClass(
size=size, size=size,
page_size=self.page_size, page_size=self.page_size,
dtype=dtype, dtype=dtype,
head_num=head_num, layer_num=self.full_layer_nums,
head_dim=head_dim, device=device,
layer_num=self.full_layer_nums, kv_lora_rank=kv_lora_rank,
device=device, qk_rope_head_dim=qk_rope_head_dim,
enable_memory_saver=False, enable_memory_saver=False,
) )
self.full_attention_layer_id_mapping = { self.full_attention_layer_id_mapping = {
id: i for i, id in enumerate(full_attention_layer_ids) id: i for i, id in enumerate(full_attention_layer_ids)
} }
k_size, v_size = self.get_kv_size_bytes() if use_mla:
self.mem_usage = (k_size + v_size) / GB self.mem_usage = self.get_kv_size_bytes() / GB
else:
k_size, v_size = self.get_kv_size_bytes()
self.mem_usage = (k_size + v_size) / GB
def get_kv_size_bytes(self): def get_kv_size_bytes(self):
return self.full_kv_pool.get_kv_size_bytes() return self.full_kv_pool.get_kv_size_bytes()
...@@ -879,6 +945,21 @@ class HybridLinearKVPool(KVCache): ...@@ -879,6 +945,21 @@ class HybridLinearKVPool(KVCache):
layer_id = self._transfer_full_attention_id(layer_id) layer_id = self._transfer_full_attention_id(layer_id)
return self.full_kv_pool.get_kv_buffer(layer_id) return self.full_kv_pool.get_kv_buffer(layer_id)
@contextmanager
def _transfer_id_context(self, layer: RadixAttention):
    # Temporarily remap the global layer id to its index in the dense
    # full-attention pool, restoring the original id on exit.
    original_layer_id = layer.layer_id
    layer.layer_id = self._transfer_full_attention_id(layer.layer_id)
    try:
        yield
    finally:
        layer.layer_id = original_layer_id
def set_kv_buffer( def set_kv_buffer(
self, self,
layer: RadixAttention, layer: RadixAttention,
...@@ -889,19 +970,49 @@ class HybridLinearKVPool(KVCache): ...@@ -889,19 +970,49 @@ class HybridLinearKVPool(KVCache):
v_scale: float = 1.0, v_scale: float = 1.0,
): ):
layer_id = self._transfer_full_attention_id(layer.layer_id) layer_id = self._transfer_full_attention_id(layer.layer_id)
self.full_kv_pool.set_kv_buffer( if not self.use_mla:
None, self.full_kv_pool.set_kv_buffer(
loc, None,
cache_k, loc,
cache_v, cache_k,
k_scale, cache_v,
v_scale, k_scale,
layer_id_override=layer_id, v_scale,
) layer_id_override=layer_id,
)
else:
with self._transfer_id_context(layer):
self.full_kv_pool.set_kv_buffer(
layer,
loc,
cache_k,
cache_v,
)
def get_v_head_dim(self): def get_v_head_dim(self):
return self.full_kv_pool.get_value_buffer(0).shape[-1] return self.full_kv_pool.get_value_buffer(0).shape[-1]
def set_mla_kv_buffer(
self,
layer: RadixAttention,
loc: torch.Tensor,
cache_k_nope: torch.Tensor,
cache_k_rope: torch.Tensor,
):
assert self.use_mla, "set_mla_kv_buffer called when use_mla is False"
with self._transfer_id_context(layer):
self.full_kv_pool.set_mla_kv_buffer(layer, loc, cache_k_nope, cache_k_rope)
def get_mla_kv_buffer(
self,
layer: RadixAttention,
loc: torch.Tensor,
dst_dtype: Optional[torch.dtype] = None,
):
assert self.use_mla, "get_mla_kv_buffer called when use_mla is False"
with self._transfer_id_context(layer):
return self.full_kv_pool.get_mla_kv_buffer(layer, loc, dst_dtype)
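In the hybrid pool, the dense MLA KV buffers only cover the full-attention layers, so global layer ids must be remapped to their local index before touching self.full_kv_pool; _transfer_id_context does this by patching layer.layer_id around the call. A small illustration, assuming a hypothetical layer layout:

# Hypothetical hybrid layout: global layers 3, 7, 11 are full (MLA) attention,
# the rest are KDA layers with no slot in the dense KV pool.
full_attention_layer_ids = [3, 7, 11]
full_attention_layer_id_mapping = {lid: i for i, lid in enumerate(full_attention_layer_ids)}

# A set_kv_buffer/get_mla_kv_buffer call arriving for global layer 7 is served
# from local slot 1 of the dense pool.
assert full_attention_layer_id_mapping[7] == 1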
class SWAKVPool(KVCache): class SWAKVPool(KVCache):
"""KV cache with separate pools for full and SWA attention layers.""" """KV cache with separate pools for full and SWA attention layers."""
......
...@@ -29,7 +29,12 @@ from typing import Callable, List, Optional, Tuple, Union ...@@ -29,7 +29,12 @@ from typing import Callable, List, Optional, Tuple, Union
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from sglang.srt.configs import FalconH1Config, NemotronHConfig, Qwen3NextConfig from sglang.srt.configs import (
FalconH1Config,
KimiLinearConfig,
NemotronHConfig,
Qwen3NextConfig,
)
from sglang.srt.configs.device_config import DeviceConfig from sglang.srt.configs.device_config import DeviceConfig
from sglang.srt.configs.load_config import LoadConfig, LoadFormat from sglang.srt.configs.load_config import LoadConfig, LoadFormat
from sglang.srt.configs.model_config import ( from sglang.srt.configs.model_config import (
...@@ -1358,9 +1363,16 @@ class ModelRunner: ...@@ -1358,9 +1363,16 @@ class ModelRunner:
return config return config
return None return None
@property
def kimi_linear_config(self):
config = self.model_config.hf_config
if isinstance(config, KimiLinearConfig):
return config
return None
@property @property
def mambaish_config(self): def mambaish_config(self):
return self.mamba2_config or self.hybrid_gdn_config return self.mamba2_config or self.hybrid_gdn_config or self.kimi_linear_config
def set_num_token_hybrid(self): def set_num_token_hybrid(self):
if ( if (
...@@ -1691,7 +1703,7 @@ class ModelRunner: ...@@ -1691,7 +1703,7 @@ class ModelRunner:
end_layer=self.end_layer, end_layer=self.end_layer,
index_head_dim=get_nsa_index_head_dim(self.model_config.hf_config), index_head_dim=get_nsa_index_head_dim(self.model_config.hf_config),
) )
elif self.use_mla_backend: elif self.use_mla_backend and not self.mambaish_config:
assert not is_nsa_model assert not is_nsa_model
self.token_to_kv_pool = MLATokenToKVPool( self.token_to_kv_pool = MLATokenToKVPool(
self.max_total_num_tokens, self.max_total_num_tokens,
...@@ -1735,6 +1747,12 @@ class ModelRunner: ...@@ -1735,6 +1747,12 @@ class ModelRunner:
device=self.device, device=self.device,
) )
elif config := self.mambaish_config: elif config := self.mambaish_config:
extra_args = {}
if self.use_mla_backend:
extra_args = {
"kv_lora_rank": self.model_config.kv_lora_rank,
"qk_rope_head_dim": self.model_config.qk_rope_head_dim,
}
self.token_to_kv_pool = HybridLinearKVPool( self.token_to_kv_pool = HybridLinearKVPool(
page_size=self.page_size, page_size=self.page_size,
size=self.max_total_num_tokens, size=self.max_total_num_tokens,
...@@ -1750,6 +1768,8 @@ class ModelRunner: ...@@ -1750,6 +1768,8 @@ class ModelRunner:
enable_kvcache_transpose=False, enable_kvcache_transpose=False,
device=self.device, device=self.device,
mamba_pool=self.req_to_token_pool.mamba_pool, mamba_pool=self.req_to_token_pool.mamba_pool,
use_mla=self.use_mla_backend,
**extra_args,
) )
else: else:
self.token_to_kv_pool = MHATokenToKVPool( self.token_to_kv_pool = MHATokenToKVPool(
......
...@@ -1075,6 +1075,7 @@ class DeepseekV2AttentionMLA(nn.Module): ...@@ -1075,6 +1075,7 @@ class DeepseekV2AttentionMLA(nn.Module):
layer_id: int = None, layer_id: int = None,
prefix: str = "", prefix: str = "",
alt_stream: Optional[torch.cuda.Stream] = None, alt_stream: Optional[torch.cuda.Stream] = None,
skip_rope: bool = False,
) -> None: ) -> None:
super().__init__() super().__init__()
self.layer_id = layer_id self.layer_id = layer_id
...@@ -1182,23 +1183,26 @@ class DeepseekV2AttentionMLA(nn.Module): ...@@ -1182,23 +1183,26 @@ class DeepseekV2AttentionMLA(nn.Module):
) )
self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps) self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
self.rotary_emb = get_rope_wrapper( if not skip_rope:
qk_rope_head_dim, self.rotary_emb = get_rope_wrapper(
rotary_dim=qk_rope_head_dim, qk_rope_head_dim,
max_position=max_position_embeddings, rotary_dim=qk_rope_head_dim,
base=rope_theta, max_position=max_position_embeddings,
rope_scaling=rope_scaling, base=rope_theta,
is_neox_style=False, rope_scaling=rope_scaling,
device=get_global_server_args().device, is_neox_style=False,
) device=get_global_server_args().device,
)
if rope_scaling: if rope_scaling:
mscale_all_dim = rope_scaling.get("mscale_all_dim", False) mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
scaling_factor = rope_scaling["factor"] scaling_factor = rope_scaling["factor"]
mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
self.scaling = self.scaling * mscale * mscale self.scaling = self.scaling * mscale * mscale
else:
self.rotary_emb.forward = self.rotary_emb.forward_native
else: else:
self.rotary_emb.forward = self.rotary_emb.forward_native self.rotary_emb = None
self.attn_mqa = RadixAttention( self.attn_mqa = RadixAttention(
self.num_local_heads, self.num_local_heads,
...@@ -1487,7 +1491,8 @@ class DeepseekV2AttentionMLA(nn.Module): ...@@ -1487,7 +1491,8 @@ class DeepseekV2AttentionMLA(nn.Module):
latent_cache = latent_cache.unsqueeze(1) latent_cache = latent_cache.unsqueeze(1)
kv_a = self.kv_a_layernorm(kv_a) kv_a = self.kv_a_layernorm(kv_a)
k_pe = latent_cache[:, :, self.kv_lora_rank :] k_pe = latent_cache[:, :, self.kv_lora_rank :]
q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) if self.rotary_emb is not None:
q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
q[..., self.qk_nope_head_dim :] = q_pe q[..., self.qk_nope_head_dim :] = q_pe
self._set_mla_kv_buffer(latent_cache, kv_a, k_pe, forward_batch) self._set_mla_kv_buffer(latent_cache, kv_a, k_pe, forward_batch)
...@@ -1646,8 +1651,10 @@ class DeepseekV2AttentionMLA(nn.Module): ...@@ -1646,8 +1651,10 @@ class DeepseekV2AttentionMLA(nn.Module):
q_nope_out = q_nope_out.transpose(0, 1) q_nope_out = q_nope_out.transpose(0, 1)
if not self._fuse_rope_for_trtllm_mla(forward_batch) and ( if (
not _use_aiter or not _is_gfx95_supported or self.use_nsa self.rotary_emb is not None
and (not self._fuse_rope_for_trtllm_mla(forward_batch))
and (not _use_aiter or not _is_gfx95_supported or self.use_nsa)
): ):
q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
......
# Adapted from: https://github.com/vllm-project/vllm/blob/0384aa7150c4c9778efca041ffd1beb3ad2bd694/vllm/model_executor/models/kimi_linear.py
from collections.abc import Iterable
from typing import Optional
import torch
from einops import rearrange
from torch import nn
from sglang.srt.configs.kimi_linear import KimiLinearConfig
from sglang.srt.distributed import (
divide,
get_pp_group,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce,
)
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
from sglang.srt.layers.attention.fla.kda import FusedRMSNormGated
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
ColumnParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
from sglang.srt.layers.moe.topk import TopK, TopKOutputFormat
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.utils import PPMissingLayer
from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import (
default_weight_loader,
maybe_remap_kv_scale_name,
sharded_weight_loader,
)
from sglang.srt.models.deepseek_v2 import DeepseekV2AttentionMLA as KimiMLAAttention
from sglang.srt.models.llama import LlamaMLP as KimiMLP
from sglang.srt.models.transformers import maybe_prefix
from sglang.srt.utils import make_layers
from sglang.srt.utils.common import BumpAllocator, add_prefix, set_weight_attrs
class KimiMoE(nn.Module):
def __init__(
self,
config: KimiLinearConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
layer_idx: int = 0,
):
super().__init__()
hidden_size = config.hidden_size
intermediate_size = config.intermediate_size
moe_intermediate_size = config.moe_intermediate_size
num_experts = config.num_experts
moe_renormalize = config.moe_renormalize
self.tp_size = get_tensor_model_parallel_world_size()
self.routed_scaling_factor = config.routed_scaling_factor
self.num_shared_experts = config.num_shared_experts
self.layer_idx = layer_idx
if config.hidden_act != "silu":
raise ValueError(
f"Unsupported activation: {config.hidden_act}. "
"Only silu is supported for now."
)
# Gate always runs at half / full precision for now.
self.gate = ReplicatedLinear(
hidden_size,
num_experts,
bias=False,
quant_config=None,
prefix=f"{prefix}.gate",
)
self.gate.e_score_correction_bias = nn.Parameter(torch.empty(num_experts))
self.experts = get_moe_impl_class(quant_config)(
num_experts=config.n_routed_experts,
top_k=config.num_experts_per_token,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
layer_id=self.layer_idx,
quant_config=quant_config,
routed_scaling_factor=self.routed_scaling_factor,
prefix=add_prefix("experts", prefix),
)
self.topk = TopK(
top_k=config.num_experts_per_token,
renormalize=moe_renormalize,
use_grouped_topk=True,
num_expert_group=config.num_expert_group,
topk_group=config.topk_group,
correction_bias=self.gate.e_score_correction_bias,
quant_config=quant_config,
routed_scaling_factor=self.routed_scaling_factor,
apply_routed_scaling_factor_on_output=self.experts.should_fuse_routed_scaling_factor_in_topk,
# Some Fp4 MoE backends require the output format to be bypassed but the MTP layers are unquantized
# and require the output format to be standard. We use quant_config to determine the output format.
output_format=TopKOutputFormat.STANDARD if quant_config is None else None,
)
if self.num_shared_experts is not None:
intermediate_size = moe_intermediate_size * self.num_shared_experts
self.shared_experts = KimiMLP(
hidden_size=config.hidden_size,
intermediate_size=intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
reduce_results=False,
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_size = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_size)
if self.num_shared_experts is not None:
shared_output = self.shared_experts(hidden_states)
router_logits, _ = self.gate(hidden_states)
topk_output = self.topk(hidden_states, router_logits)
final_hidden_states = self.experts(hidden_states, topk_output)
if shared_output is not None:
final_hidden_states = final_hidden_states + shared_output
if self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
return final_hidden_states.view(num_tokens, hidden_size)
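The gate follows a DeepSeek-style sigmoid router: the correction bias only influences which experts are selected, while the combine weights come from the unbiased scores and are renormalized. A hedged sketch of that routing rule under the config above (num_expert_group == 1, so grouped top-k reduces to plain top-k); the TopK layer is the authoritative implementation:

import torch

def route_tokens(router_logits, correction_bias, top_k, routed_scaling_factor):
    scores = router_logits.sigmoid()                                # moe_router_activation_func="sigmoid"
    _, expert_ids = (scores + correction_bias).topk(top_k, dim=-1)  # bias affects selection only
    weights = scores.gather(-1, expert_ids)
    weights = weights / weights.sum(-1, keepdim=True)               # moe_renormalize=True
    return expert_ids, weights * routed_scaling_factor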
class KimiDeltaAttention(nn.Module):
def __init__(
self,
layer_idx: int,
hidden_size: int,
config: KimiLinearConfig,
quant_config: Optional[QuantizationConfig] = None,
rms_norm_eps: float = 1e-5,
prefix: str = "",
**kwargs,
) -> None:
super().__init__()
self.tp_size = get_tensor_model_parallel_world_size()
self.hidden_size = hidden_size
self.config = config
self.head_dim = config.linear_attn_config["head_dim"]
self.num_heads = config.linear_attn_config["num_heads"]
self.layer_idx = layer_idx
self.prefix = prefix
assert self.num_heads % self.tp_size == 0
self.local_num_heads = divide(self.num_heads, self.tp_size)
projection_size = self.head_dim * self.num_heads
self.conv_size = config.linear_attn_config["short_conv_kernel_size"]
self.q_proj = ColumnParallelLinear(
self.hidden_size,
projection_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.q_proj",
)
self.k_proj = ColumnParallelLinear(
self.hidden_size,
projection_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.k_proj",
)
self.v_proj = ColumnParallelLinear(
self.hidden_size,
projection_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.v_proj",
)
self.f_a_proj = ReplicatedLinear(
self.hidden_size,
self.head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.f_a_proj",
)
self.f_b_proj = ColumnParallelLinear(
self.head_dim,
projection_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.f_b_proj",
)
self.dt_bias = nn.Parameter(
torch.empty(divide(projection_size, self.tp_size), dtype=torch.float32)
)
set_weight_attrs(self.dt_bias, {"weight_loader": sharded_weight_loader(0)})
self.b_proj = ColumnParallelLinear(
self.hidden_size,
self.num_heads,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.b_proj",
)
self.q_conv1d = ColumnParallelLinear(
input_size=self.conv_size,
output_size=projection_size,
bias=False,
params_dtype=torch.float32,
prefix=f"{prefix}.q_conv1d",
)
self.k_conv1d = ColumnParallelLinear(
input_size=self.conv_size,
output_size=projection_size,
bias=False,
params_dtype=torch.float32,
prefix=f"{prefix}.k_conv1d",
)
self.v_conv1d = ColumnParallelLinear(
input_size=self.conv_size,
output_size=projection_size,
bias=False,
params_dtype=torch.float32,
prefix=f"{prefix}.v_conv1d",
)
# Unsqueeze to fit the conv1d weight shape into the linear weight shape.
# This can't be done in `weight_loader` since one already exists in
# `ColumnParallelLinear`, and `set_weight_attrs` doesn't allow overriding it.
self.q_conv1d.weight.data = self.q_conv1d.weight.data.unsqueeze(1)
self.k_conv1d.weight.data = self.k_conv1d.weight.data.unsqueeze(1)
self.v_conv1d.weight.data = self.v_conv1d.weight.data.unsqueeze(1)
self.A_log = nn.Parameter(
torch.empty(1, 1, self.local_num_heads, 1, dtype=torch.float32)
)
set_weight_attrs(self.A_log, {"weight_loader": sharded_weight_loader(2)})
self.g_a_proj = ReplicatedLinear(
self.hidden_size,
self.head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.g_a_proj",
)
self.g_b_proj = ColumnParallelLinear(
self.head_dim,
projection_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.g_b_proj",
)
self.o_norm = FusedRMSNormGated(
self.head_dim, eps=rms_norm_eps, activation="sigmoid"
)
self.o_proj = RowParallelLinear(
projection_size,
self.hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
def forward(
self,
hidden_states: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
zero_allocator: BumpAllocator,
) -> torch.Tensor:
q_proj_states = self.q_proj(hidden_states)[0]
k_proj_states = self.k_proj(hidden_states)[0]
v_proj_states = self.v_proj(hidden_states)[0]
q_conv_weights = self.q_conv1d.weight.view(
self.q_conv1d.weight.size(0), self.q_conv1d.weight.size(2)
)
k_conv_weights = self.k_conv1d.weight.view(
self.k_conv1d.weight.size(0), self.k_conv1d.weight.size(2)
)
v_conv_weights = self.v_conv1d.weight.view(
self.v_conv1d.weight.size(0), self.v_conv1d.weight.size(2)
)
kwargs = {
"q_proj_states": q_proj_states,
"k_proj_states": k_proj_states,
"v_proj_states": v_proj_states,
"q_conv_weights": q_conv_weights,
"k_conv_weights": k_conv_weights,
"v_conv_weights": v_conv_weights,
"q_conv_bias": self.q_conv1d.bias,
"k_conv_bias": self.k_conv1d.bias,
"v_conv_bias": self.v_conv1d.bias,
"dt_bias": self.dt_bias,
"b_proj": self.b_proj,
"f_a_proj": self.f_a_proj,
"f_b_proj": self.f_b_proj,
"A_log": self.A_log,
"head_dim": self.head_dim,
"hidden_states": hidden_states,
"layer_id": self.layer_idx,
}
core_attn_out = forward_batch.attn_backend.forward(
q=None,
k=None,
v=None,
layer=None,
forward_batch=forward_batch,
**kwargs,
)
g_proj_states = self.g_b_proj(self.g_a_proj(hidden_states)[0])[0]
g = rearrange(g_proj_states, "... (h d) -> ... h d", d=self.head_dim)
core_attn_out = self.o_norm(core_attn_out, g)
core_attn_out = rearrange(core_attn_out, "1 n h d -> n (h d)")
return self.o_proj(core_attn_out)[0]
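The output path applies a low-rank gate (g_a_proj followed by g_b_proj) through FusedRMSNormGated with a sigmoid activation before o_proj. An unfused reference of what such a gated RMSNorm computes, assuming the gate multiplies the normalized output elementwise (the fused kernel in fla.kda is authoritative):

import torch

def gated_rms_norm_ref(x, gate, weight, eps=1e-5):
    # x, gate: [..., head_dim]; weight: [head_dim] learned scale.
    rms = torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + eps)
    y = x.float() * rms * weight.float()
    return (y * torch.sigmoid(gate.float())).to(x.dtype)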
class KimiDecoderLayer(nn.Module):
def __init__(
self,
config: KimiLinearConfig,
layer_idx: int,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
self.is_moe = config.is_moe
if config.is_kda_layer(layer_idx):
self.self_attn = KimiDeltaAttention(
layer_idx=layer_idx,
hidden_size=config.hidden_size,
config=config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
)
else:
self.self_attn = KimiMLAAttention(
layer_id=layer_idx,
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
config=config,
qk_nope_head_dim=config.qk_nope_head_dim,
qk_rope_head_dim=config.qk_rope_head_dim,
v_head_dim=config.v_head_dim,
q_lora_rank=config.q_lora_rank,
kv_lora_rank=config.kv_lora_rank,
skip_rope=True,
)
if (
self.is_moe
and config.num_experts is not None
and layer_idx >= config.first_k_dense_replace
and layer_idx % config.moe_layer_freq == 0
):
self.block_sparse_moe = KimiMoE(
config=config,
quant_config=quant_config,
layer_idx=layer_idx,
prefix=f"{prefix}.mlp",
)
self.mlp = self.block_sparse_moe
else:
self.mlp = KimiMLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
residual: Optional[torch.Tensor],
zero_allocator: BumpAllocator,
) -> tuple[torch.Tensor, torch.Tensor]:
# Self Attention
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(hidden_states, residual)
hidden_states = self.self_attn(
hidden_states=hidden_states,
positions=positions,
forward_batch=forward_batch,
zero_allocator=zero_allocator,
)
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
class KimiLinearModel(nn.Module):
def __init__(
self,
config: KimiLinearConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.config = config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.pp_group = get_pp_group()
if self.pp_group.is_first_rank:
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
prefix=f"{prefix}.embed_tokens",
)
else:
self.embed_tokens = PPMissingLayer()
self.layers, self.start_layer, self.end_layer = make_layers(
config.num_hidden_layers,
lambda idx, prefix: KimiDecoderLayer(
layer_idx=idx,
config=config,
quant_config=quant_config,
prefix=prefix,
),
pp_rank=self.pp_group.rank_in_group,
pp_size=self.pp_group.world_size,
prefix=f"{prefix}.layers",
)
if self.pp_group.is_last_rank:
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
else:
self.norm = PPMissingLayer()
world_size = get_tensor_model_parallel_world_size()
assert (
config.num_attention_heads % world_size == 0
), "num_attention_heads must be divisible by world_size"
def forward(
self,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
forward_batch: ForwardBatch,
inputs_embeds: torch.Tensor | None = None,
pp_proxy_tensors: Optional[PPProxyTensors] = None,
) -> torch.Tensor:
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.embed_tokens(input_ids)
residual = None
else:
assert pp_proxy_tensors is not None
hidden_states = pp_proxy_tensors["hidden_states"]
residual = pp_proxy_tensors["residual"]
total_num_layers = self.end_layer - self.start_layer
device = hidden_states.device
zero_allocator = BumpAllocator(
buffer_size=total_num_layers * 2,
dtype=torch.float32,
device=device,
)
# TODO: capture aux hidden states
aux_hidden_states = []
for i in range(self.start_layer, self.end_layer):
ctx = get_global_expert_distribution_recorder().with_current_layer(i)
with ctx:
layer = self.layers[i]
hidden_states, residual = layer(
positions=positions,
hidden_states=hidden_states,
forward_batch=forward_batch,
residual=residual,
zero_allocator=zero_allocator,
)
if not self.pp_group.is_last_rank:
return PPProxyTensors(
{
"hidden_states": hidden_states,
"residual": residual,
}
)
else:
if hidden_states.shape[0] != 0:
if residual is None:
hidden_states = self.norm(hidden_states)
else:
hidden_states, _ = self.norm(hidden_states, residual)
if len(aux_hidden_states) == 0:
return hidden_states
return hidden_states, aux_hidden_states
class KimiLinearForCausalLM(nn.Module):
def __init__(
self,
config: KimiLinearConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
self.quant_config = quant_config
self.model = KimiLinearModel(
config, quant_config, prefix=maybe_prefix(prefix, "model")
)
self.pp_group = get_pp_group()
if self.pp_group.is_last_rank:
self.lm_head = ParallelLMHead(
self.config.vocab_size,
self.config.hidden_size,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "lm_head"),
)
else:
self.lm_head = PPMissingLayer()
logit_scale = getattr(self.config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(config=config, logit_scale=logit_scale)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
inputs_embeds: Optional[torch.Tensor] = None,
pp_proxy_tensors: Optional[PPProxyTensors] = None,
) -> torch.Tensor:
hidden_states = self.model(
input_ids,
positions,
forward_batch,
inputs_embeds,
pp_proxy_tensors,
)
if self.pp_group.is_last_rank:
return self.logits_processor(
input_ids, hidden_states, self.lm_head, forward_batch
)
else:
return hidden_states
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".gate_up_proj", ".gate_proj", 0),
(".gate_up_proj", ".up_proj", 1),
]
if self.config.is_moe:
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="w1",
ckpt_down_proj_name="w2",
ckpt_up_proj_name="w3",
num_experts=self.config.num_experts,
)
else:
expert_params_mapping = []
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
for args in weights:
name, loaded_weight = args[:2]
kwargs = args[2] if len(args) > 2 else {}
if "rotary_emb.inv_freq" in name:
continue
if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
# We have mlp.experts[0].gate_proj in the checkpoint.
# Since we handle the experts below in expert_params_mapping,
# we need to skip here BEFORE we update the name, otherwise
# name will be updated to mlp.experts[0].gate_up_proj, which
# will then be updated below in expert_params_mapping
# for mlp.experts[0].gate_gate_up_proj, which breaks loading.
if ("mlp.experts." in name) and name not in params_dict:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# if is_pp_missing_parameter(name, self):
# continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
for idx, (param_name, weight_name, expert_id, shard_id) in enumerate(
expert_params_mapping
):
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# if is_pp_missing_parameter(name, self):
# continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(
param,
loaded_weight,
name,
expert_id=expert_id,
shard_id=shard_id,
)
break
else:
# Skip loading extra bias for GPTQ models.
if (
name.endswith(".bias")
and name not in params_dict
and not self.config.is_linear_attn
): # noqa: E501
continue
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
# if is_pp_missing_parameter(name, self):
# continue
param = params_dict[name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
weight_loader(param, loaded_weight, **kwargs)
loaded_params.add(name)
for layer_id in self.config.full_attention_layer_ids:
self_attn = self.model.layers[layer_id].self_attn
w_kc, w_vc = self_attn.kv_b_proj.weight.unflatten(
0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
if hasattr(self_attn.kv_b_proj, "weight_scale"):
self_attn.w_scale = self_attn.kv_b_proj.weight_scale
EntryClass = KimiLinearForCausalLM
...@@ -1028,6 +1028,11 @@ class ServerArgs: ...@@ -1028,6 +1028,11 @@ class ServerArgs:
logger.info( logger.info(
f"Using {self.attention_backend} as attention backend for {model_arch}." f"Using {self.attention_backend} as attention backend for {model_arch}."
) )
elif model_arch in ["KimiLinearForCausalLM"]:
logger.warning(
f"Disabling Radix Cache for {model_arch} as it is not yet supported."
)
self.disable_radix_cache = True
if is_deepseek_nsa(hf_config): if is_deepseek_nsa(hf_config):
if ( if (
......
...@@ -43,6 +43,7 @@ from sglang.srt.configs import ( ...@@ -43,6 +43,7 @@ from sglang.srt.configs import (
DotsVLMConfig, DotsVLMConfig,
ExaoneConfig, ExaoneConfig,
FalconH1Config, FalconH1Config,
KimiLinearConfig,
KimiVLConfig, KimiVLConfig,
LongcatFlashConfig, LongcatFlashConfig,
MultiModalityConfig, MultiModalityConfig,
...@@ -68,6 +69,7 @@ _CONFIG_REGISTRY: List[Type[PretrainedConfig]] = [ ...@@ -68,6 +69,7 @@ _CONFIG_REGISTRY: List[Type[PretrainedConfig]] = [
Step3VLConfig, Step3VLConfig,
LongcatFlashConfig, LongcatFlashConfig,
Olmo3Config, Olmo3Config,
KimiLinearConfig,
Qwen3NextConfig, Qwen3NextConfig,
FalconH1Config, FalconH1Config,
DotsVLMConfig, DotsVLMConfig,
......
import unittest
from types import SimpleNamespace
from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval
from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
)
class TestKimiLinear(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = "moonshotai/Kimi-Linear-48B-A3B-Instruct"
cls.base_url = DEFAULT_URL_FOR_TEST
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=["--tp-size", "2", "--trust-remote"],
)
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
def test_gsm8k(self):
args = SimpleNamespace(
num_shots=5,
data_path=None,
num_questions=200,
max_new_tokens=512,
parallel=128,
host="http://127.0.0.1",
port=int(self.base_url.split(":")[-1]),
)
metrics = run_eval(args)
print(f"{metrics=}")
self.assertGreater(metrics["accuracy"], 0.88)
if __name__ == "__main__":
unittest.main()
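Once the server in setUpClass is up, a quick manual smoke test can hit the OpenAI-compatible completions route (hedged: the URL below assumes sglang's default port 30000, whereas the unit test derives the port from DEFAULT_URL_FOR_TEST):

import requests

resp = requests.post(
    "http://127.0.0.1:30000/v1/completions",
    json={
        "model": "moonshotai/Kimi-Linear-48B-A3B-Instruct",
        "prompt": "The capital of France is",
        "max_tokens": 8,
    },
)
print(resp.json()["choices"][0]["text"])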
...@@ -151,6 +151,7 @@ suites = { ...@@ -151,6 +151,7 @@ suites = {
TestFile("layers/attention/mamba/test_mamba2_mixer.py", 50), TestFile("layers/attention/mamba/test_mamba2_mixer.py", 50),
TestFile("lora/test_lora_tp.py", 116), TestFile("lora/test_lora_tp.py", 116),
TestFile("models/test_glm4_moe_models.py", 100), TestFile("models/test_glm4_moe_models.py", 100),
TestFile("models/test_kimi_linear_models.py", 90),
TestFile("rl/test_update_weights_from_distributed.py", 103), TestFile("rl/test_update_weights_from_distributed.py", 103),
TestFile("test_data_parallelism.py", 73), TestFile("test_data_parallelism.py", 73),
TestFile("test_disaggregation_basic.py", 400), TestFile("test_disaggregation_basic.py", 400),
......