Unverified Commit a4bf5c6a authored by Ke Bao's avatar Ke Bao Committed by GitHub
Browse files

Support Kimi Linear (#12469)


Co-authored-by: default avataryizhang2077 <1109276519@qq.com>
parent 30ad1070
......@@ -6,6 +6,7 @@ from sglang.srt.configs.dots_vlm import DotsVLMConfig
from sglang.srt.configs.exaone import ExaoneConfig
from sglang.srt.configs.falcon_h1 import FalconH1Config
from sglang.srt.configs.janus_pro import MultiModalityConfig
from sglang.srt.configs.kimi_linear import KimiLinearConfig
from sglang.srt.configs.kimi_vl import KimiVLConfig
from sglang.srt.configs.kimi_vl_moonvit import MoonViTConfig
from sglang.srt.configs.longcat_flash import LongcatFlashConfig
......@@ -31,6 +32,7 @@ __all__ = [
"Step3TextConfig",
"Step3VisionEncoderConfig",
"Olmo3Config",
"KimiLinearConfig",
"Qwen3NextConfig",
"DotsVLMConfig",
"DotsOCRConfig",
......
# Adapted from: https://github.com/vllm-project/vllm/blob/0384aa7150c4c9778efca041ffd1beb3ad2bd694/vllm/transformers_utils/configs/kimi_linear.py
from transformers.configuration_utils import PretrainedConfig
from sglang.srt.configs.mamba_utils import KimiLinearCacheParams, KimiLinearStateShape
from sglang.srt.layers.dp_attention import get_attention_tp_size
class KimiLinearConfig(PretrainedConfig):
model_type = "kimi_linear"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
model_type="kimi_linear",
vocab_size=163840,
hidden_size=4096,
head_dim=None,
intermediate_size=11008,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=None,
hidden_act="silu",
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
pad_token_id=0,
bos_token_id=1,
eos_token_id=2,
rope_theta=10000.0,
rope_scaling=None,
tie_word_embeddings=False,
moe_intermediate_size: int | None = None,
moe_renormalize: bool = True,
moe_router_activation_func: str = "sigmoid",
num_experts: int | None = None,
num_experts_per_token: int | None = None,
num_shared_experts: int = 0,
routed_scaling_factor: float = 1.0,
first_k_dense_replace: int = 0,
moe_layer_freq: int = 1,
use_grouped_topk: bool = True,
num_expert_group: int = 1,
topk_group: int = 1,
q_lora_rank: int | None = None,
kv_lora_rank: int | None = None,
qk_nope_head_dim: int | None = None,
qk_rope_head_dim: int | None = None,
v_head_dim: int | None = None,
mla_use_nope: bool | None = False,
num_nextn_predict_layers: int = 0,
linear_attn_config: dict | None = None,
**kwargs,
):
self.model_type = model_type
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.head_dim = (
head_dim if head_dim is not None else hidden_size // num_attention_heads
)
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.q_lora_rank = q_lora_rank
self.kv_lora_rank = kv_lora_rank
self.qk_nope_head_dim = qk_nope_head_dim
self.qk_rope_head_dim = qk_rope_head_dim
self.v_head_dim = v_head_dim
self.mla_use_nope = mla_use_nope
# moe config
self.n_routed_experts = self.num_experts = num_experts
self.num_experts_per_token = num_experts_per_token
self.moe_renormalize = moe_renormalize
self.num_shared_experts = num_shared_experts
self.routed_scaling_factor = routed_scaling_factor
self.moe_router_activation_func = moe_router_activation_func
assert self.moe_router_activation_func in ("softmax", "sigmoid")
self.moe_intermediate_size = moe_intermediate_size
self.first_k_dense_replace = first_k_dense_replace
self.moe_layer_freq = moe_layer_freq
self.use_grouped_topk = use_grouped_topk
self.num_expert_group = num_expert_group
self.topk_group = topk_group
self.num_nextn_predict_layers = num_nextn_predict_layers
if linear_attn_config is not None:
assert linear_attn_config["kda_layers"] is not None
assert linear_attn_config["full_attn_layers"] is not None
self.linear_attn_config = linear_attn_config
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
@property
def is_mla(self):
return (
self.q_lora_rank is not None
or self.kv_lora_rank is not None
or self.qk_nope_head_dim is not None
or self.qk_rope_head_dim is not None
or self.v_head_dim is not None
or self.mla_use_nope is True
)
@property
def is_moe(self):
return self.num_experts is not None
@property
def is_linear_attn(self) -> bool:
return not (
self.linear_attn_config is None
or (
isinstance(self.linear_attn_config, dict)
and self.linear_attn_config["kda_layers"] is not None
and len(self.linear_attn_config["kda_layers"]) == 0
)
)
def is_kda_layer(self, layer_idx: int):
return (
self.linear_attn_config is not None
and (layer_idx + 1) in self.linear_attn_config["kda_layers"]
)
@property
def linear_layer_ids(self):
return [i for i in range(self.num_hidden_layers) if self.is_kda_layer(i)]
@property
def full_attention_layer_ids(self):
return [i for i in range(self.num_hidden_layers) if not self.is_kda_layer(i)]
@property
def mamba2_cache_params(self) -> KimiLinearCacheParams:
shape = KimiLinearStateShape.create(
tp_world_size=get_attention_tp_size(),
num_heads=self.linear_attn_config["num_heads"],
head_dim=self.linear_attn_config["head_dim"],
conv_kernel_size=self.linear_attn_config["short_conv_kernel_size"],
)
return KimiLinearCacheParams(shape=shape, layers=self.linear_layer_ids)
......@@ -14,6 +14,7 @@
import os
from dataclasses import dataclass, field
from typing import List, Optional
import numpy as np
import torch
......@@ -115,3 +116,68 @@ class Mamba2CacheParams:
int(np.prod(self.shape.conv)) * self.dtype.conv.itemsize
+ int(np.prod(self.shape.temporal)) * self.dtype.temporal.itemsize
) * len(self.layers)
@dataclass(kw_only=True, frozen=True)
class KimiLinearStateShape:
conv: List[tuple[int, int]]
temporal: tuple[int, int, int]
num_heads: int
head_dim: int
num_k_heads: int
head_k_dim: int
conv_kernel: int
num_spec: int
@staticmethod
def create(
*,
tp_world_size: int,
num_heads: int,
head_dim: int,
num_k_heads: Optional[int] = None,
head_k_dim: Optional[int] = None,
conv_kernel_size: int = 4,
num_spec: int = 0,
) -> "KimiLinearStateShape":
if num_k_heads is None:
num_k_heads = num_heads
if head_k_dim is None:
head_k_dim = head_dim
proj_size = num_heads * head_dim
proj_k_size = num_k_heads * head_k_dim
conv_state_shape = (divide(proj_size, tp_world_size), conv_kernel_size - 1)
conv_state_k_shape = (divide(proj_k_size, tp_world_size), conv_kernel_size - 1)
temporal_state_shape = (divide(num_heads, tp_world_size), head_dim, head_dim)
conv_state_shape = conv_state_shape[1], conv_state_shape[0]
conv_state_k_shape = conv_state_k_shape[1], conv_state_k_shape[0]
return KimiLinearStateShape(
conv=[conv_state_shape, conv_state_k_shape, conv_state_k_shape],
temporal=temporal_state_shape,
num_heads=num_heads,
head_dim=head_dim,
num_k_heads=num_k_heads,
head_k_dim=head_k_dim,
conv_kernel=conv_kernel_size,
num_spec=num_spec,
)
@dataclass(kw_only=True, frozen=True)
class KimiLinearCacheParams:
shape: KimiLinearStateShape
dtype: Mamba2StateDType = field(default_factory=mamba2_state_dtype)
layers: list[int]
@property
def mamba_cache_per_req(self) -> int:
return (
int(np.sum([np.prod(conv_shape) for conv_shape in self.shape.conv]))
* self.dtype.conv.itemsize
+ int(np.prod(self.shape.temporal)) * self.dtype.temporal.itemsize
) * len(self.layers)
......@@ -366,6 +366,13 @@ class ModelConfig:
self.qk_rope_head_dim = self.hf_text_config.qk_rope_head_dim
self.v_head_dim = self.hf_text_config.v_head_dim
self.qk_nope_head_dim = self.hf_text_config.qk_nope_head_dim
elif "KimiLinearForCausalLM" in self.hf_config.architectures:
self.head_dim = 72
self.attention_arch = AttentionArch.MLA
self.kv_lora_rank = self.hf_config.kv_lora_rank
self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim
self.v_head_dim = self.hf_config.v_head_dim
self.qk_nope_head_dim = self.hf_config.qk_nope_head_dim
else:
if (
"MistralModel" in self.hf_config.architectures
......
......@@ -189,6 +189,7 @@ def attn_backend_wrapper(runner: "ModelRunner", full_attn_backend: "AttentionBac
from sglang.srt.layers.attention.hybrid_linear_attn_backend import (
GDNAttnBackend,
HybridLinearAttnBackend,
KimiLinearAttnBackend,
Mamba2AttnBackend,
)
from sglang.srt.utils import is_blackwell, is_npu
......@@ -207,6 +208,8 @@ def attn_backend_wrapper(runner: "ModelRunner", full_attn_backend: "AttentionBac
linear_attn_backend = GDNAttnBackend(runner)
elif runner.mamba2_config is not None:
linear_attn_backend = Mamba2AttnBackend(runner)
elif runner.kimi_linear_config is not None:
linear_attn_backend = KimiLinearAttnBackend(runner)
else:
raise ValueError(
"Expected hybrid GDN or NemotronH models, but got unknown model."
......
......@@ -21,6 +21,7 @@ NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8, 16]
@triton.heuristics(
{
"USE_G": lambda args: args["g"] is not None,
"USE_GK": lambda args: args["gk"] is not None,
"USE_INITIAL_STATE": lambda args: args["h0"] is not None,
"STORE_FINAL_STATE": lambda args: args["ht"] is not None,
"SAVE_NEW_VALUE": lambda args: args["v_new"] is not None,
......@@ -44,6 +45,7 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
w,
v_new,
g,
gk,
h,
h0,
ht,
......@@ -57,6 +59,7 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
BT: tl.constexpr,
BV: tl.constexpr,
USE_G: tl.constexpr,
USE_GK: tl.constexpr,
USE_INITIAL_STATE: tl.constexpr,
STORE_FINAL_STATE: tl.constexpr,
SAVE_NEW_VALUE: tl.constexpr,
......@@ -86,12 +89,12 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
b_h4 = tl.zeros([64, BV], dtype=tl.float32)
# calculate offset
h += (boh * H + i_h) * K * V
v += (bos * H + i_h) * V
k += (bos * Hg + i_h // (H // Hg)) * K
w += (bos * H + i_h) * K
h += ((boh * H + i_h) * K * V).to(tl.int64)
v += ((bos * H + i_h) * V).to(tl.int64)
k += ((bos * Hg + i_h // (H // Hg)) * K).to(tl.int64)
w += ((bos * H + i_h) * K).to(tl.int64)
if SAVE_NEW_VALUE:
v_new += (bos * H + i_h) * V
v_new += ((bos * H + i_h) * V).to(tl.int64)
stride_v = H * V
stride_h = H * K * V
stride_k = Hg * K
......@@ -143,58 +146,48 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
)
tl.store(p_h4, b_h4.to(p_h4.dtype.element_ty), boundary_check=(0, 1))
p_v = tl.make_block_ptr(
v, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
)
p_v_new = (
tl.make_block_ptr(
v_new, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
)
if SAVE_NEW_VALUE
else None
)
b_v_new = tl.zeros([BT, BV], dtype=tl.float32)
p_w = tl.make_block_ptr(
w, (T, K), (stride_w, 1), (i_t * BT, 0), (BT, 64), (1, 0)
)
b_w = tl.load(p_w, boundary_check=(0, 1))
b_v_new += tl.dot(b_w, b_h1.to(b_w.dtype))
b_v = tl.dot(b_w, b_h1.to(b_w.dtype))
if K > 64:
p_w = tl.make_block_ptr(
w, (T, K), (stride_w, 1), (i_t * BT, 64), (BT, 64), (1, 0)
)
b_w = tl.load(p_w, boundary_check=(0, 1))
b_v_new += tl.dot(b_w, b_h2.to(b_w.dtype))
b_v += tl.dot(b_w, b_h2.to(b_w.dtype))
if K > 128:
p_w = tl.make_block_ptr(
w, (T, K), (stride_w, 1), (i_t * BT, 128), (BT, 64), (1, 0)
)
b_w = tl.load(p_w, boundary_check=(0, 1))
b_v_new += tl.dot(b_w, b_h3.to(b_w.dtype))
b_v += tl.dot(b_w, b_h3.to(b_w.dtype))
if K > 192:
p_w = tl.make_block_ptr(
w, (T, K), (stride_w, 1), (i_t * BT, 192), (BT, 64), (1, 0)
)
b_w = tl.load(p_w, boundary_check=(0, 1))
b_v_new += tl.dot(b_w, b_h4.to(b_w.dtype))
b_v_new = -b_v_new + tl.load(p_v, boundary_check=(0, 1))
b_v += tl.dot(b_w, b_h4.to(b_w.dtype))
p_v = tl.make_block_ptr(
v, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
)
b_v = tl.load(p_v, boundary_check=(0, 1)) - b_v
if SAVE_NEW_VALUE:
p_v_new = tl.make_block_ptr(
p_v = tl.make_block_ptr(
v_new, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
)
tl.store(
p_v_new, b_v_new.to(p_v_new.dtype.element_ty), boundary_check=(0, 1)
)
tl.store(p_v, b_v.to(p_v.dtype.element_ty), boundary_check=(0, 1))
last_idx = min((i_t + 1) * BT, T) - 1
if USE_G:
last_idx = min((i_t + 1) * BT, T) - 1
b_g_last = tl.load(g + bos * H + last_idx * H + i_h)
p_g = tl.make_block_ptr(
g + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)
)
b_g = tl.load(p_g, boundary_check=(0,))
b_v_new = b_v_new * safe_exp(b_g_last - b_g)[:, None]
b_v = b_v * safe_exp(b_g_last - b_g)[:, None]
b_g_last = exp(b_g_last)
b_h1 = b_h1 * b_g_last
if K > 64:
......@@ -203,30 +196,64 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
b_h3 = b_h3 * b_g_last
if K > 192:
b_h4 = b_h4 * b_g_last
b_v_new = b_v_new.to(k.dtype.element_ty)
if USE_GK:
o_k1 = tl.arange(0, 64)
b_gk_last1 = tl.load(
gk + (bos + last_idx) * H * K + i_h * K + o_k1,
mask=(o_k1 < K),
other=0.0,
)
b_h1 *= exp(b_gk_last1)[:, None]
if K > 64:
o_k2 = 64 + o_k1
b_gk_last2 = tl.load(
gk + (bos + last_idx) * H * K + i_h * K + o_k2,
mask=(o_k2 < K),
other=0.0,
)
b_h2 *= exp(b_gk_last2)[:, None]
if K > 128:
o_k3 = 128 + o_k1
b_gk_last3 = tl.load(
gk + (bos + last_idx) * H * K + i_h * K + o_k3,
mask=(o_k3 < K),
other=0.0,
)
b_h3 *= exp(b_gk_last3)[:, None]
if K > 192:
o_k4 = 192 + o_k1
b_gk_last4 = tl.load(
gk + (bos + last_idx) * H * K + i_h * K + o_k4,
mask=(o_k4 < K),
other=0.0,
)
b_h4 *= exp(b_gk_last4)[:, None]
b_v = b_v.to(k.dtype.element_ty)
p_k = tl.make_block_ptr(
k, (K, T), (1, stride_k), (0, i_t * BT), (64, BT), (0, 1)
)
b_k = tl.load(p_k, boundary_check=(0, 1))
b_h1 += tl.dot(b_k, b_v_new)
b_h1 += tl.dot(b_k, b_v)
if K > 64:
p_k = tl.make_block_ptr(
k, (K, T), (1, stride_k), (64, i_t * BT), (64, BT), (0, 1)
)
b_k = tl.load(p_k, boundary_check=(0, 1))
b_h2 += tl.dot(b_k, b_v_new)
b_h2 += tl.dot(b_k, b_v)
if K > 128:
p_k = tl.make_block_ptr(
k, (K, T), (1, stride_k), (128, i_t * BT), (64, BT), (0, 1)
)
b_k = tl.load(p_k, boundary_check=(0, 1))
b_h3 += tl.dot(b_k, b_v_new)
b_h3 += tl.dot(b_k, b_v)
if K > 192:
p_k = tl.make_block_ptr(
k, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1)
)
b_k = tl.load(p_k, boundary_check=(0, 1))
b_h4 += tl.dot(b_k, b_v_new)
b_h4 += tl.dot(b_k, b_v)
# epilogue
if STORE_FINAL_STATE:
......@@ -254,6 +281,7 @@ def chunk_gated_delta_rule_fwd_h(
w: torch.Tensor,
u: torch.Tensor,
g: Optional[torch.Tensor] = None,
gk: Optional[torch.Tensor] = None,
initial_state: Optional[torch.Tensor] = None,
output_final_state: bool = False,
chunk_size: int = 64, # SY: remove this argument and force chunk size 64?
......@@ -296,6 +324,7 @@ def chunk_gated_delta_rule_fwd_h(
w=w,
v_new=v_new,
g=g,
gk=gk,
h=h,
h0=initial_state,
ht=final_state,
......
......@@ -44,6 +44,7 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
IS_BETA_HEADWISE: tl.constexpr, # whether beta is headwise vector or scalar,
USE_QK_L2NORM_IN_KERNEL: tl.constexpr,
IS_VARLEN: tl.constexpr,
IS_KDA: tl.constexpr,
):
i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
i_n, i_hv = i_nh // HV, i_nh % HV
......@@ -67,7 +68,11 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
p_beta = beta + (bos * HV + i_hv) * V + o_v
else:
p_beta = beta + bos * HV + i_hv
p_g = g + bos * HV + i_hv
if not IS_KDA:
p_g = g + bos * HV + i_hv
else:
p_gk = g + (bos * HV + i_hv) * K + o_k
p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v
mask_k = o_k < K
......@@ -83,14 +88,18 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32)
b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
b_g = tl.load(p_g).to(tl.float32)
if USE_QK_L2NORM_IN_KERNEL:
b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q) + 1e-6))
b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k) + 1e-6))
b_q = b_q * scale
# [BK, BV]
b_h *= exp(b_g)
if not IS_KDA:
b_g = tl.load(p_g).to(tl.float32)
b_h *= exp(b_g)
else:
b_gk = tl.load(p_gk).to(tl.float32)
b_h *= exp(b_gk[:, None])
# [BV]
b_v -= tl.sum(b_h * b_k[:, None], 0)
if IS_BETA_HEADWISE:
......@@ -108,7 +117,10 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
p_k += H * K
p_o += HV * V
p_v += HV * V
p_g += HV
if not IS_KDA:
p_g += HV
else:
p_gk += HV * K
p_beta += HV * (V if IS_BETA_HEADWISE else 1)
if STORE_FINAL_STATE:
......@@ -165,6 +177,7 @@ def fused_recurrent_gated_delta_rule_fwd(
BV=BV,
IS_BETA_HEADWISE=beta.ndim == v.ndim,
USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel,
IS_KDA=False,
num_warps=num_warps,
num_stages=num_stages,
)
......
# Adapted from https://github.com/vllm-project/vllm/blob/0384aa7150c4c9778efca041ffd1beb3ad2bd694/vllm/model_executor/layers/fla/ops/kda.py
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
import torch
import torch.nn as nn
import triton
import triton.language as tl
from sglang.srt.layers.attention.fla.chunk_delta_h import chunk_gated_delta_rule_fwd_h
from sglang.srt.layers.attention.fla.cumsum import chunk_local_cumsum
from sglang.srt.layers.attention.fla.fused_recurrent import (
fused_recurrent_gated_delta_rule_fwd_kernel,
)
from sglang.srt.layers.attention.fla.index import prepare_chunk_indices
from sglang.srt.layers.attention.fla.l2norm import l2norm_fwd
from sglang.srt.layers.attention.fla.op import exp, log
from sglang.srt.layers.attention.fla.solve_tril import solve_tril
from sglang.srt.layers.attention.fla.utils import is_amd
BT_LIST_AUTOTUNE = [32, 64, 128]
NUM_WARPS_AUTOTUNE = [2, 4, 8, 16] if is_amd else [4, 8, 16, 32]
def cdiv(a: int, b: int) -> int:
"""Ceiling division."""
return -(a // -b)
def next_power_of_2(n: int) -> int:
"""The next power of 2 (inclusive)"""
if n < 1:
return 1
return 1 << (n - 1).bit_length()
def fused_recurrent_kda_fwd(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: float,
initial_state: torch.Tensor,
inplace_final_state: bool = True,
cu_seqlens: torch.LongTensor | None = None,
# ssm_state_indices: torch.Tensor | None = None,
num_accepted_tokens: torch.Tensor | None = None,
use_qk_l2norm_in_kernel: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
B, T, H, K, V = *k.shape, v.shape[-1]
HV = v.shape[2]
N = B if cu_seqlens is None else len(cu_seqlens) - 1
BK, BV = next_power_of_2(K), min(next_power_of_2(V), 8)
NK, NV = cdiv(K, BK), cdiv(V, BV)
assert NK == 1, "NK > 1 is not supported yet"
num_stages = 3
num_warps = 1
o = torch.empty_like(k)
if inplace_final_state:
final_state = initial_state
else:
final_state = q.new_empty(T, HV, K, V, dtype=initial_state.dtype)
stride_init_state_token = initial_state.stride(0)
stride_final_state_token = final_state.stride(0)
# if ssm_state_indices is None:
# stride_indices_seq, stride_indices_tok = 1, 1
# elif ssm_state_indices.ndim == 1:
# stride_indices_seq, stride_indices_tok = ssm_state_indices.stride(0), 1
# else:
# stride_indices_seq, stride_indices_tok = ssm_state_indices.stride()
grid = (NK, NV, N * HV)
fused_recurrent_gated_delta_rule_fwd_kernel[grid](
q=q,
k=k,
v=v,
g=g,
beta=beta,
o=o,
h0=initial_state,
ht=final_state,
cu_seqlens=cu_seqlens,
# ssm_state_indices=ssm_state_indices,
# num_accepted_tokens=num_accepted_tokens,
scale=scale,
# N=N,
T=T,
B=B,
H=H,
HV=HV,
K=K,
V=V,
BK=BK,
BV=BV,
# stride_init_state_token=stride_init_state_token,
# stride_final_state_token=stride_final_state_token,
# stride_indices_seq=stride_indices_seq,
# stride_indices_tok=stride_indices_tok,
IS_BETA_HEADWISE=beta.ndim == v.ndim,
USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel,
# INPLACE_FINAL_STATE=inplace_final_state,
IS_KDA=True,
num_warps=num_warps,
num_stages=num_stages,
)
return o, final_state
def fused_recurrent_kda(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor = None,
scale: float = None,
initial_state: torch.Tensor = None,
inplace_final_state: bool = True,
use_qk_l2norm_in_kernel: bool = True,
cu_seqlens: torch.LongTensor | None = None,
# ssm_state_indices: torch.LongTensor | None = None,
**kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
if cu_seqlens is not None and q.shape[0] != 1:
raise ValueError(
f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
f"Please flatten variable-length inputs before processing."
)
if scale is None:
scale = k.shape[-1] ** -0.5
o, final_state = fused_recurrent_kda_fwd(
q=q.contiguous(),
k=k.contiguous(),
v=v.contiguous(),
g=g.contiguous(),
beta=beta.contiguous(),
scale=scale,
initial_state=initial_state,
inplace_final_state=inplace_final_state,
cu_seqlens=cu_seqlens,
# ssm_state_indices=ssm_state_indices,
num_accepted_tokens=None,
use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
)
return o, final_state
@triton.heuristics(
{
"STORE_RESIDUAL_OUT": lambda args: args["residual_out"] is not None,
"HAS_RESIDUAL": lambda args: args["residual"] is not None,
"HAS_WEIGHT": lambda args: args["w"] is not None,
"HAS_BIAS": lambda args: args["b"] is not None,
}
)
@triton.jit
def layer_norm_gated_fwd_kernel(
x, # pointer to the input
g, # pointer to the gate
y, # pointer to the output
w, # pointer to the weights
b, # pointer to the biases
residual, # pointer to the residual
residual_out, # pointer to the residual
mean, # pointer to the mean
rstd, # pointer to the 1/std
eps, # epsilon to avoid division by zero
T, # number of rows in x
D: tl.constexpr, # number of columns in x
BT: tl.constexpr,
BD: tl.constexpr,
ACTIVATION: tl.constexpr,
IS_RMS_NORM: tl.constexpr,
STORE_RESIDUAL_OUT: tl.constexpr,
HAS_RESIDUAL: tl.constexpr,
HAS_WEIGHT: tl.constexpr,
HAS_BIAS: tl.constexpr,
):
i_t = tl.program_id(0)
o_d = tl.arange(0, BD)
m_d = o_d < D
p_x = tl.make_block_ptr(x, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
b_x = tl.load(p_x, boundary_check=(0, 1)).to(tl.float32)
if HAS_RESIDUAL:
p_res = tl.make_block_ptr(
residual, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0)
)
b_x += tl.load(p_res, boundary_check=(0, 1)).to(tl.float32)
if STORE_RESIDUAL_OUT:
p_res_out = tl.make_block_ptr(
residual_out, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0)
)
tl.store(p_res_out, b_x.to(p_res_out.dtype.element_ty), boundary_check=(0, 1))
if not IS_RMS_NORM:
b_mean = tl.sum(b_x, axis=1) / D
p_mean = tl.make_block_ptr(mean, (T,), (1,), (i_t * BT,), (BT,), (0,))
tl.store(p_mean, b_mean.to(p_mean.dtype.element_ty), boundary_check=(0,))
b_xbar = tl.where(m_d[None, :], b_x - b_mean[:, None], 0.0)
b_var = tl.sum(b_xbar * b_xbar, axis=1) / D
else:
b_xbar = tl.where(m_d[None, :], b_x, 0.0)
b_var = tl.sum(b_xbar * b_xbar, axis=1) / D
b_rstd = 1 / tl.sqrt(b_var + eps)
p_rstd = tl.make_block_ptr(rstd, (T,), (1,), (i_t * BT,), (BT,), (0,))
tl.store(p_rstd, b_rstd.to(p_rstd.dtype.element_ty), boundary_check=(0,))
if HAS_WEIGHT:
b_w = tl.load(w + o_d, mask=m_d).to(tl.float32)
if HAS_BIAS:
b_b = tl.load(b + o_d, mask=m_d).to(tl.float32)
b_x_hat = (
(b_x - b_mean[:, None]) * b_rstd[:, None]
if not IS_RMS_NORM
else b_x * b_rstd[:, None]
)
b_y = b_x_hat * b_w[None, :] if HAS_WEIGHT else b_x_hat
if HAS_BIAS:
b_y = b_y + b_b[None, :]
# swish/sigmoid output gate
p_g = tl.make_block_ptr(g, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32)
if ACTIVATION == "swish" or ACTIVATION == "silu":
b_y = b_y * b_g * tl.sigmoid(b_g)
elif ACTIVATION == "sigmoid":
b_y = b_y * tl.sigmoid(b_g)
# Write output
p_y = tl.make_block_ptr(y, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
tl.store(p_y, b_y.to(p_y.dtype.element_ty), boundary_check=(0, 1))
@triton.heuristics(
{
"STORE_RESIDUAL_OUT": lambda args: args["residual_out"] is not None,
"HAS_RESIDUAL": lambda args: args["residual"] is not None,
"HAS_WEIGHT": lambda args: args["w"] is not None,
"HAS_BIAS": lambda args: args["b"] is not None,
}
)
@triton.jit
def layer_norm_gated_fwd_kernel1(
x, # pointer to the input
g, # pointer to the gate
y, # pointer to the output
w, # pointer to the weights
b, # pointer to the biases
residual, # pointer to the residual
residual_out, # pointer to the residual
mean, # pointer to the mean
rstd, # pointer to the 1/std
eps, # epsilon to avoid division by zero
D: tl.constexpr, # number of columns in x
BD: tl.constexpr,
ACTIVATION: tl.constexpr,
IS_RMS_NORM: tl.constexpr,
STORE_RESIDUAL_OUT: tl.constexpr,
HAS_RESIDUAL: tl.constexpr,
HAS_WEIGHT: tl.constexpr,
HAS_BIAS: tl.constexpr,
):
i_t = tl.program_id(0)
x += i_t * D
y += i_t * D
g += i_t * D
if HAS_RESIDUAL:
residual += i_t * D
if STORE_RESIDUAL_OUT:
residual_out += i_t * D
o_d = tl.arange(0, BD)
m_d = o_d < D
b_x = tl.load(x + o_d, mask=m_d, other=0.0).to(tl.float32)
if HAS_RESIDUAL:
b_x += tl.load(residual + o_d, mask=m_d, other=0.0).to(tl.float32)
if STORE_RESIDUAL_OUT:
tl.store(residual_out + o_d, b_x, mask=m_d)
if not IS_RMS_NORM:
b_mean = tl.sum(b_x, axis=0) / D
tl.store(mean + i_t, b_mean)
b_xbar = tl.where(m_d, b_x - b_mean, 0.0)
b_var = tl.sum(b_xbar * b_xbar, axis=0) / D
else:
b_xbar = tl.where(m_d, b_x, 0.0)
b_var = tl.sum(b_xbar * b_xbar, axis=0) / D
b_rstd = 1 / tl.sqrt(b_var + eps)
tl.store(rstd + i_t, b_rstd)
if HAS_WEIGHT:
b_w = tl.load(w + o_d, mask=m_d).to(tl.float32)
if HAS_BIAS:
b_b = tl.load(b + o_d, mask=m_d).to(tl.float32)
b_x_hat = (b_x - b_mean) * b_rstd if not IS_RMS_NORM else b_x * b_rstd
b_y = b_x_hat * b_w if HAS_WEIGHT else b_x_hat
if HAS_BIAS:
b_y = b_y + b_b
# swish/sigmoid output gate
b_g = tl.load(g + o_d, mask=m_d, other=0.0).to(tl.float32)
if ACTIVATION == "swish" or ACTIVATION == "silu":
b_y = b_y * b_g * tl.sigmoid(b_g)
elif ACTIVATION == "sigmoid":
b_y = b_y * tl.sigmoid(b_g)
# Write output
tl.store(y + o_d, b_y, mask=m_d)
def layer_norm_gated_fwd(
x: torch.Tensor,
g: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor,
activation: str = "swish",
eps: float = 1e-5,
residual: torch.Tensor = None,
out_dtype: torch.dtype = None,
residual_dtype: torch.dtype = None,
is_rms_norm: bool = False,
):
if residual is not None:
residual_dtype = residual.dtype
T, D = x.shape
if residual is not None:
assert residual.shape == (T, D)
if weight is not None:
assert weight.shape == (D,)
if bias is not None:
assert bias.shape == (D,)
# allocate output
y = x if out_dtype is None else torch.empty_like(x, dtype=out_dtype)
if residual is not None or (
residual_dtype is not None and residual_dtype != x.dtype
):
residual_out = torch.empty(T, D, device=x.device, dtype=residual_dtype)
else:
residual_out = None
mean = (
torch.empty((T,), dtype=torch.float, device=x.device)
if not is_rms_norm
else None
)
rstd = torch.empty((T,), dtype=torch.float, device=x.device)
# Less than 64KB per feature: enqueue fused kernel
MAX_FUSED_SIZE = 65536 // x.element_size()
BD = min(MAX_FUSED_SIZE, next_power_of_2(D))
if D > BD:
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
# heuristics for number of warps
if D <= 512:
BT = 32
layer_norm_gated_fwd_kernel[(cdiv(T, BT),)](
x=x,
g=g,
y=y,
w=weight,
b=bias,
residual=residual,
residual_out=residual_out,
mean=mean,
rstd=rstd,
eps=eps,
T=T,
D=D,
BD=BD,
BT=BT,
ACTIVATION=activation,
IS_RMS_NORM=is_rms_norm,
num_warps=4,
)
else:
layer_norm_gated_fwd_kernel1[(T,)](
x=x,
g=g,
y=y,
w=weight,
b=bias,
residual=residual,
residual_out=residual_out,
mean=mean,
rstd=rstd,
eps=eps,
D=D,
BD=BD,
ACTIVATION=activation,
IS_RMS_NORM=is_rms_norm,
num_warps=4,
)
# residual_out is None if residual is None and residual_dtype == input_dtype
return y, mean, rstd, residual_out if residual_out is not None else x
def rms_norm_gated(
x: torch.Tensor,
g: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor,
activation: str = "swish",
residual: torch.Tensor | None = None,
prenorm: bool = False,
residual_in_fp32: bool = False,
eps: float = 1e-6,
):
x_shape_og = x.shape
# reshape input data into 2D tensor
x = x.contiguous().reshape(-1, x.shape[-1])
g = g.contiguous().reshape(-1, g.shape[-1])
if residual is not None:
assert residual.shape == x_shape_og
residual = residual.contiguous().reshape(-1, residual.shape[-1])
residual_dtype = (
residual.dtype
if residual is not None
else (torch.float if residual_in_fp32 else None)
)
y, _, _, residual_out = layer_norm_gated_fwd(
x=x,
g=g,
weight=weight,
bias=bias,
activation=activation,
eps=eps,
residual=residual,
residual_dtype=residual_dtype,
is_rms_norm=True,
)
y = y.reshape(x_shape_og)
return y if not prenorm else (y, residual_out.reshape(x_shape_og))
class FusedRMSNormGated(nn.Module):
def __init__(
self,
hidden_size: int,
elementwise_affine: bool = True,
eps: float = 1e-5,
activation: str = "swish",
device: torch.device | None = None,
dtype: torch.dtype | None = None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.hidden_size = hidden_size
self.elementwise_affine = elementwise_affine
self.eps = eps
self.activation = activation
if self.activation not in ["swish", "silu", "sigmoid"]:
raise ValueError(f"Unsupported activation: {self.activation}")
if elementwise_affine:
self.weight = nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
else:
self.register_parameter("weight", None)
self.register_parameter("bias", None)
def forward(
self,
x: torch.Tensor,
g: torch.Tensor,
residual: torch.Tensor | None = None,
prenorm: bool = False,
residual_in_fp32: bool = False,
) -> torch.Tensor:
return rms_norm_gated(
x,
g,
self.weight,
self.bias,
self.activation,
residual=residual,
eps=self.eps,
prenorm=prenorm,
residual_in_fp32=residual_in_fp32,
)
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.autotune(
configs=[
triton.Config({"BK": BK}, num_warps=num_warps, num_stages=num_stages)
for BK in [32, 64]
for num_warps in [1, 2, 4, 8]
for num_stages in [2, 3, 4]
],
key=["BC"],
)
@triton.jit(do_not_specialize=["T"])
def chunk_kda_scaled_dot_kkt_fwd_kernel_intra_sub_inter(
q,
k,
g,
beta,
A,
Aqk,
scale,
cu_seqlens,
chunk_indices,
T,
H: tl.constexpr,
K: tl.constexpr,
BT: tl.constexpr,
BC: tl.constexpr,
BK: tl.constexpr,
NC: tl.constexpr,
IS_VARLEN: tl.constexpr,
):
i_t, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
i_b, i_h = i_bh // H, i_bh % H
i_i, i_j = i_c // NC, i_c % NC
if IS_VARLEN:
i_n, i_t = (
tl.load(chunk_indices + i_t * 2).to(tl.int32),
tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
)
bos, eos = (
tl.load(cu_seqlens + i_n).to(tl.int32),
tl.load(cu_seqlens + i_n + 1).to(tl.int32),
)
T = eos - bos
else:
bos, eos = i_b * T, i_b * T + T
if i_t * BT + i_i * BC >= T:
return
if i_i <= i_j:
return
q += (bos * H + i_h) * K
k += (bos * H + i_h) * K
g += (bos * H + i_h) * K
A += (bos * H + i_h) * BT
Aqk += (bos * H + i_h) * BT
p_b = tl.make_block_ptr(
beta + bos * H + i_h, (T,), (H,), (i_t * BT + i_i * BC,), (BC,), (0,)
)
b_b = tl.load(p_b, boundary_check=(0,))
b_A = tl.zeros([BC, BC], dtype=tl.float32)
b_Aqk = tl.zeros([BC, BC], dtype=tl.float32)
for i_k in range(tl.cdiv(K, BK)):
p_q = tl.make_block_ptr(
q, (T, K), (H * K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)
)
p_k = tl.make_block_ptr(
k, (T, K), (H * K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)
)
p_g = tl.make_block_ptr(
g, (T, K), (H * K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)
)
b_kt = tl.make_block_ptr(
k, (K, T), (1, H * K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1)
)
p_gk = tl.make_block_ptr(
g, (K, T), (1, H * K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1)
)
o_k = i_k * BK + tl.arange(0, BK)
m_k = o_k < K
# [BK,]
b_gn = tl.load(g + (i_t * BT + i_i * BC) * H * K + o_k, mask=m_k, other=0)
# [BC, BK]
b_g = tl.load(p_g, boundary_check=(0, 1))
b_k = tl.load(p_k, boundary_check=(0, 1)) * exp(b_g - b_gn[None, :])
# [BK, BC]
b_gk = tl.load(p_gk, boundary_check=(0, 1))
b_kt = tl.load(b_kt, boundary_check=(0, 1))
# [BC, BC]
b_ktg = b_kt * exp(b_gn[:, None] - b_gk)
b_A += tl.dot(b_k, b_ktg)
b_q = tl.load(p_q, boundary_check=(0, 1))
b_qg = b_q * exp(b_g - b_gn[None, :]) * scale
b_Aqk += tl.dot(b_qg, b_ktg)
b_A *= b_b[:, None]
p_A = tl.make_block_ptr(
A, (T, BT), (H * BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0)
)
tl.store(p_A, b_A.to(A.dtype.element_ty), boundary_check=(0, 1))
p_Aqk = tl.make_block_ptr(
Aqk, (T, BT), (H * BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0)
)
tl.store(p_Aqk, b_Aqk.to(Aqk.dtype.element_ty), boundary_check=(0, 1))
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.autotune(
configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8]],
key=["BK", "BT"],
)
@triton.jit(do_not_specialize=["T"])
def chunk_kda_scaled_dot_kkt_fwd_kernel_intra_sub_intra(
q,
k,
g,
beta,
A,
Aqk,
scale,
cu_seqlens,
chunk_indices,
T,
H: tl.constexpr,
K: tl.constexpr,
BT: tl.constexpr,
BC: tl.constexpr,
BK: tl.constexpr,
IS_VARLEN: tl.constexpr,
):
i_t, i_i, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
i_b, i_h = i_bh // H, i_bh % H
if IS_VARLEN:
i_n, i_t = (
tl.load(chunk_indices + i_t * 2).to(tl.int32),
tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
)
bos, eos = (
tl.load(cu_seqlens + i_n).to(tl.int32),
tl.load(cu_seqlens + i_n + 1).to(tl.int32),
)
T = eos - bos
else:
bos, eos = i_b * T, i_b * T + T
if i_t * BT + i_i * BC >= T:
return
o_i = tl.arange(0, BC)
o_k = tl.arange(0, BK)
m_k = o_k < K
m_A = (i_t * BT + i_i * BC + o_i) < T
o_A = (bos + i_t * BT + i_i * BC + o_i) * H * BT + i_h * BT + i_i * BC
p_q = tl.make_block_ptr(
q + (bos * H + i_h) * K,
(T, K),
(H * K, 1),
(i_t * BT + i_i * BC, 0),
(BC, BK),
(1, 0),
)
p_k = tl.make_block_ptr(
k + (bos * H + i_h) * K,
(T, K),
(H * K, 1),
(i_t * BT + i_i * BC, 0),
(BC, BK),
(1, 0),
)
p_g = tl.make_block_ptr(
g + (bos * H + i_h) * K,
(T, K),
(H * K, 1),
(i_t * BT + i_i * BC, 0),
(BC, BK),
(1, 0),
)
b_q = tl.load(p_q, boundary_check=(0, 1))
b_k = tl.load(p_k, boundary_check=(0, 1))
b_g = tl.load(p_g, boundary_check=(0, 1))
p_b = beta + (bos + i_t * BT + i_i * BC + o_i) * H + i_h
b_k = b_k * tl.load(p_b, mask=m_A, other=0)[:, None]
p_kt = k + (bos + i_t * BT + i_i * BC) * H * K + i_h * K + o_k
p_gk = g + (bos + i_t * BT + i_i * BC) * H * K + i_h * K + o_k
for j in range(0, min(BC, T - i_t * BT - i_i * BC)):
b_kt = tl.load(p_kt, mask=m_k, other=0).to(tl.float32)
b_gk = tl.load(p_gk, mask=m_k, other=0).to(tl.float32)
b_ktg = b_kt[None, :] * exp(b_g - b_gk[None, :])
b_A = tl.sum(b_k * b_ktg, 1)
b_A = tl.where(o_i > j, b_A, 0.0)
b_Aqk = tl.sum(b_q * b_ktg, 1)
b_Aqk = tl.where(o_i >= j, b_Aqk * scale, 0.0)
tl.store(A + o_A + j, b_A, mask=m_A)
tl.store(Aqk + o_A + j, b_Aqk, mask=m_A)
p_kt += H * K
p_gk += H * K
def chunk_kda_scaled_dot_kkt_fwd(
q: torch.Tensor,
k: torch.Tensor,
gk: torch.Tensor | None = None,
beta: torch.Tensor | None = None,
scale: float | None = None,
cu_seqlens: torch.LongTensor | None = None,
chunk_size: int = 64,
output_dtype: torch.dtype = torch.float32,
) -> tuple[torch.Tensor, torch.Tensor]:
r"""
Compute beta * K * K^T.
Args:
k (torch.Tensor):
The key tensor of shape `[B, T, H, K]`.
beta (torch.Tensor):
The beta tensor of shape `[B, T, H]`.
gk (torch.Tensor):
The cumulative sum of the gate tensor of shape `[B, T, H, K]` applied to the key tensor. Default: `None`.
cu_seqlens (torch.LongTensor):
The cumulative sequence lengths of the input tensor.
Default: None
chunk_size (int):
The chunk size. Default: 64.
output_dtype (torch.dtype):
The dtype of the output tensor. Default: `torch.float32`
Returns:
beta * K * K^T of shape `[B, T, H, BT]` where `BT` is the chunk size.
"""
B, T, H, K = k.shape
assert K <= 256
BT = chunk_size
chunk_indices = (
prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
)
NT = cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
BC = min(16, BT)
NC = cdiv(BT, BC)
BK = max(next_power_of_2(K), 16)
A = torch.zeros(B, T, H, BT, device=k.device, dtype=output_dtype)
Aqk = torch.zeros(B, T, H, BT, device=k.device, dtype=output_dtype)
grid = (NT, NC * NC, B * H)
chunk_kda_scaled_dot_kkt_fwd_kernel_intra_sub_inter[grid](
q=q,
k=k,
g=gk,
beta=beta,
A=A,
Aqk=Aqk,
scale=scale,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
T=T,
H=H,
K=K,
BT=BT,
BC=BC,
NC=NC,
)
grid = (NT, NC, B * H)
chunk_kda_scaled_dot_kkt_fwd_kernel_intra_sub_intra[grid](
q=q,
k=k,
g=gk,
beta=beta,
A=A,
Aqk=Aqk,
scale=scale,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
T=T,
H=H,
K=K,
BT=BT,
BC=BC,
BK=BK,
)
return A, Aqk
@triton.heuristics(
{
"STORE_QG": lambda args: args["qg"] is not None,
"STORE_KG": lambda args: args["kg"] is not None,
"IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
}
)
@triton.autotune(
configs=[
triton.Config({}, num_warps=num_warps, num_stages=num_stages)
for num_warps in [2, 4, 8]
for num_stages in [2, 3, 4]
],
key=["H", "K", "V", "BT", "BK", "BV", "IS_VARLEN"],
)
@triton.jit(do_not_specialize=["T"])
def recompute_w_u_fwd_kernel(
q,
k,
qg,
kg,
v,
beta,
w,
u,
A,
gk,
cu_seqlens,
chunk_indices,
T,
H: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BT: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr,
STORE_QG: tl.constexpr,
STORE_KG: tl.constexpr,
IS_VARLEN: tl.constexpr,
DOT_PRECISION: tl.constexpr,
):
i_t, i_bh = tl.program_id(0), tl.program_id(1)
i_b, i_h = i_bh // H, i_bh % H
if IS_VARLEN:
i_n, i_t = (
tl.load(chunk_indices + i_t * 2).to(tl.int32),
tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
)
bos, eos = (
tl.load(cu_seqlens + i_n).to(tl.int32),
tl.load(cu_seqlens + i_n + 1).to(tl.int32),
)
T = eos - bos
else:
bos, eos = i_b * T, i_b * T + T
p_b = tl.make_block_ptr(beta + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
b_b = tl.load(p_b, boundary_check=(0,))
p_A = tl.make_block_ptr(
A + (bos * H + i_h) * BT, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)
)
b_A = tl.load(p_A, boundary_check=(0, 1))
for i_v in range(tl.cdiv(V, BV)):
p_v = tl.make_block_ptr(
v + (bos * H + i_h) * V,
(T, V),
(H * V, 1),
(i_t * BT, i_v * BV),
(BT, BV),
(1, 0),
)
p_u = tl.make_block_ptr(
u + (bos * H + i_h) * V,
(T, V),
(H * V, 1),
(i_t * BT, i_v * BV),
(BT, BV),
(1, 0),
)
b_v = tl.load(p_v, boundary_check=(0, 1))
b_vb = (b_v * b_b[:, None]).to(b_v.dtype)
b_u = tl.dot(b_A, b_vb, input_precision=DOT_PRECISION)
tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1))
for i_k in range(tl.cdiv(K, BK)):
p_w = tl.make_block_ptr(
w + (bos * H + i_h) * K,
(T, K),
(H * K, 1),
(i_t * BT, i_k * BK),
(BT, BK),
(1, 0),
)
p_k = tl.make_block_ptr(
k + (bos * H + i_h) * K,
(T, K),
(H * K, 1),
(i_t * BT, i_k * BK),
(BT, BK),
(1, 0),
)
b_k = tl.load(p_k, boundary_check=(0, 1))
b_kb = b_k * b_b[:, None]
p_gk = tl.make_block_ptr(
gk + (bos * H + i_h) * K,
(T, K),
(H * K, 1),
(i_t * BT, i_k * BK),
(BT, BK),
(1, 0),
)
b_gk = tl.load(p_gk, boundary_check=(0, 1))
b_kb *= exp(b_gk)
if STORE_QG:
p_q = tl.make_block_ptr(
q + (bos * H + i_h) * K,
(T, K),
(H * K, 1),
(i_t * BT, i_k * BK),
(BT, BK),
(1, 0),
)
p_qg = tl.make_block_ptr(
qg + (bos * H + i_h) * K,
(T, K),
(H * K, 1),
(i_t * BT, i_k * BK),
(BT, BK),
(1, 0),
)
b_q = tl.load(p_q, boundary_check=(0, 1))
b_qg = b_q * exp(b_gk)
tl.store(p_qg, b_qg.to(p_qg.dtype.element_ty), boundary_check=(0, 1))
if STORE_KG:
last_idx = min(i_t * BT + BT, T) - 1
o_k = i_k * BK + tl.arange(0, BK)
m_k = o_k < K
b_gn = tl.load(
gk + ((bos + last_idx) * H + i_h) * K + o_k, mask=m_k, other=0.0
)
b_kg = b_k * exp(b_gn - b_gk)
p_kg = tl.make_block_ptr(
kg + (bos * H + i_h) * K,
(T, K),
(H * K, 1),
(i_t * BT, i_k * BK),
(BT, BK),
(1, 0),
)
tl.store(p_kg, b_kg.to(p_kg.dtype.element_ty), boundary_check=(0, 1))
b_w = tl.dot(b_A, b_kb.to(b_k.dtype))
tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))
def recompute_w_u_fwd(
k: torch.Tensor,
v: torch.Tensor,
beta: torch.Tensor,
A: torch.Tensor,
q: torch.Tensor | None = None,
gk: torch.Tensor | None = None,
cu_seqlens: torch.LongTensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
B, T, H, K, V = *k.shape, v.shape[-1]
BT = A.shape[-1]
BK = 64
BV = 64
chunk_indices = (
prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
)
NT = cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
w = torch.empty_like(k)
u = torch.empty_like(v)
kg = torch.empty_like(k) if gk is not None else None
recompute_w_u_fwd_kernel[(NT, B * H)](
q=q,
k=k,
qg=None,
kg=kg,
v=v,
beta=beta,
w=w,
u=u,
A=A,
gk=gk,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
T=T,
H=H,
K=K,
V=V,
BT=BT,
BK=BK,
BV=BV,
DOT_PRECISION="ieee",
)
return w, u, None, kg
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.autotune(
configs=[
triton.Config({"BK": BK, "BV": BV}, num_warps=num_warps, num_stages=num_stages)
for BK in [32, 64]
for BV in [64, 128]
for num_warps in [2, 4, 8]
for num_stages in [2, 3, 4]
],
key=["BT"],
)
@triton.jit(do_not_specialize=["T"])
def chunk_gla_fwd_kernel_o(
q,
v,
g,
h,
o,
A,
cu_seqlens,
chunk_indices,
scale,
T,
H: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BT: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr,
IS_VARLEN: tl.constexpr,
):
i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
i_b, i_h = i_bh // H, i_bh % H
if IS_VARLEN:
i_tg = i_t
i_n, i_t = (
tl.load(chunk_indices + i_t * 2).to(tl.int32),
tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
)
bos, eos = (
tl.load(cu_seqlens + i_n).to(tl.int32),
tl.load(cu_seqlens + i_n + 1).to(tl.int32),
)
T = eos - bos
NT = tl.cdiv(T, BT)
else:
NT = tl.cdiv(T, BT)
i_tg = i_b * NT + i_t
bos, eos = i_b * T, i_b * T + T
m_s = tl.arange(0, BT)[:, None] >= tl.arange(0, BT)[None, :]
b_o = tl.zeros([BT, BV], dtype=tl.float32)
for i_k in range(tl.cdiv(K, BK)):
p_q = tl.make_block_ptr(
q + (bos * H + i_h) * K,
(T, K),
(H * K, 1),
(i_t * BT, i_k * BK),
(BT, BK),
(1, 0),
)
p_g = tl.make_block_ptr(
g + (bos * H + i_h) * K,
(T, K),
(H * K, 1),
(i_t * BT, i_k * BK),
(BT, BK),
(1, 0),
)
p_h = tl.make_block_ptr(
h + (i_tg * H + i_h) * K * V,
(K, V),
(V, 1),
(i_k * BK, i_v * BV),
(BK, BV),
(1, 0),
)
# [BT, BK]
b_q = tl.load(p_q, boundary_check=(0, 1))
b_q = (b_q * scale).to(b_q.dtype)
# [BT, BK]
b_g = tl.load(p_g, boundary_check=(0, 1))
# [BT, BK]
b_qg = (b_q * exp(b_g)).to(b_q.dtype)
# [BK, BV]
b_h = tl.load(p_h, boundary_check=(0, 1))
# works but dkw, owing to divine benevolence
# [BT, BV]
if i_k >= 0:
b_o += tl.dot(b_qg, b_h.to(b_qg.dtype))
p_v = tl.make_block_ptr(
v + (bos * H + i_h) * V,
(T, V),
(H * V, 1),
(i_t * BT, i_v * BV),
(BT, BV),
(1, 0),
)
p_o = tl.make_block_ptr(
o + (bos * H + i_h) * V,
(T, V),
(H * V, 1),
(i_t * BT, i_v * BV),
(BT, BV),
(1, 0),
)
p_A = tl.make_block_ptr(
A + (bos * H + i_h) * BT, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)
)
# [BT, BV]
b_v = tl.load(p_v, boundary_check=(0, 1))
# [BT, BT]
b_A = tl.load(p_A, boundary_check=(0, 1))
b_A = tl.where(m_s, b_A, 0.0).to(b_v.dtype)
b_o += tl.dot(b_A, b_v, allow_tf32=False)
tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
def chunk_gla_fwd_o_gk(
q: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
A: torch.Tensor,
h: torch.Tensor,
o: torch.Tensor,
scale: float,
cu_seqlens: torch.LongTensor | None = None,
chunk_size: int = 64,
):
B, T, H, K, V = *q.shape, v.shape[-1]
BT = chunk_size
chunk_indices = (
prepare_chunk_indices(cu_seqlens, chunk_size)
if cu_seqlens is not None
else None
)
NT = cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
def grid(meta):
return (cdiv(V, meta["BV"]), NT, B * H)
chunk_gla_fwd_kernel_o[grid](
q=q,
v=v,
g=g,
h=h,
o=o,
A=A,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
scale=scale,
T=T,
H=H,
K=K,
V=V,
BT=BT,
)
return o
def chunk_kda_fwd(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: float,
initial_state: torch.Tensor,
output_final_state: bool,
cu_seqlens: torch.LongTensor | None = None,
):
chunk_size = 64
g = chunk_local_cumsum(g, chunk_size=chunk_size, cu_seqlens=cu_seqlens)
# the intra Aqk is kept in fp32
# the computation has very marginal effect on the entire throughput
A, Aqk = chunk_kda_scaled_dot_kkt_fwd(
q=q,
k=k,
gk=g,
beta=beta,
scale=scale,
cu_seqlens=cu_seqlens,
output_dtype=torch.float32,
)
A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype)
w, u, _, kg = recompute_w_u_fwd(
k=k,
v=v,
beta=beta,
A=A,
gk=g,
cu_seqlens=cu_seqlens,
)
del A
h, v_new, final_state = chunk_gated_delta_rule_fwd_h(
k=kg,
w=w,
u=u,
gk=g,
initial_state=initial_state,
output_final_state=output_final_state,
cu_seqlens=cu_seqlens,
)
del w, u, kg
o = chunk_gla_fwd_o_gk(
q=q,
v=v_new,
g=g,
A=Aqk,
h=h,
o=v,
scale=scale,
cu_seqlens=cu_seqlens,
chunk_size=chunk_size,
)
del Aqk, v_new, h
return o, final_state
def chunk_kda(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: float = None,
initial_state: torch.Tensor = None,
output_final_state: bool = False,
use_qk_l2norm_in_kernel: bool = False,
cu_seqlens: torch.LongTensor | None = None,
**kwargs,
):
if scale is None:
scale = k.shape[-1] ** -0.5
if use_qk_l2norm_in_kernel:
q = l2norm_fwd(q.contiguous())
k = l2norm_fwd(k.contiguous())
o, final_state = chunk_kda_fwd(
q=q,
k=k,
v=v.contiguous(),
g=g.contiguous(),
beta=beta.contiguous(),
scale=scale,
initial_state=initial_state.contiguous(),
output_final_state=output_final_state,
cu_seqlens=cu_seqlens,
)
return o, final_state
@triton.autotune(
configs=[
triton.Config({"BT": bt}, num_warps=nw, num_stages=ns)
for bt in BT_LIST_AUTOTUNE
for nw in NUM_WARPS_AUTOTUNE
for ns in [2, 3]
],
key=["H", "D"],
)
@triton.jit
def kda_gate_fwd_kernel(
g,
A,
y,
g_bias,
beta: tl.constexpr,
threshold: tl.constexpr,
T,
H,
D: tl.constexpr,
BT: tl.constexpr,
BD: tl.constexpr,
HAS_BIAS: tl.constexpr,
):
i_t, i_h = tl.program_id(0), tl.program_id(1)
n_t = i_t * BT
b_a = tl.load(A + i_h).to(tl.float32)
b_a = -tl.exp(b_a)
stride_row = H * D
stride_col = 1
g_ptr = tl.make_block_ptr(
base=g + i_h * D,
shape=(T, D),
strides=(stride_row, stride_col),
offsets=(n_t, 0),
block_shape=(BT, BD),
order=(1, 0),
)
y_ptr = tl.make_block_ptr(
base=y + i_h * D,
shape=(T, D),
strides=(stride_row, stride_col),
offsets=(n_t, 0),
block_shape=(BT, BD),
order=(1, 0),
)
b_g = tl.load(g_ptr, boundary_check=(0, 1)).to(tl.float32)
if HAS_BIAS:
n_d = tl.arange(0, BD)
bias_mask = n_d < D
b_bias = tl.load(g_bias + i_h * D + n_d, mask=bias_mask, other=0.0).to(
tl.float32
)
b_g = b_g + b_bias[None, :]
# softplus(x, beta) = (1/beta) * log(1 + exp(beta * x))
# When beta * x > threshold, use linear approximation x
# Use threshold to switch to linear when beta*x > threshold
g_scaled = b_g * beta
use_linear = g_scaled > threshold
sp = tl.where(use_linear, b_g, (1.0 / beta) * log(1.0 + tl.exp(g_scaled)))
b_y = b_a * sp
tl.store(y_ptr, b_y.to(y.dtype.element_ty), boundary_check=(0, 1))
def fused_kda_gate(
g: torch.Tensor,
A: torch.Tensor,
head_k_dim: int,
g_bias: torch.Tensor | None = None,
beta: float = 1.0,
threshold: float = 20.0,
) -> torch.Tensor:
"""
Forward pass for KDA gate:
input g: [..., H*D]
param A: [H] or [1, 1, H, 1]
beta: softplus beta parameter
threshold: softplus threshold parameter
return : [..., H, D]
"""
orig_shape = g.shape[:-1]
g = g.view(-1, g.shape[-1])
T = g.shape[0]
HD = g.shape[1]
H = A.numel()
assert H * head_k_dim == HD
y = torch.empty_like(g, dtype=torch.float32)
def grid(meta):
return (cdiv(T, meta["BT"]), H)
kda_gate_fwd_kernel[grid](
g,
A,
y,
g_bias,
beta,
threshold,
T,
H,
head_k_dim,
BD=next_power_of_2(head_k_dim),
HAS_BIAS=g_bias is not None,
)
y = y.view(*orig_shape, H, head_k_dim)
return y
from typing import Optional, Union
import torch
from einops import rearrange
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.attention.fla.chunk import chunk_gated_delta_rule
......@@ -10,6 +11,11 @@ from sglang.srt.layers.attention.fla.fused_recurrent import (
from sglang.srt.layers.attention.fla.fused_sigmoid_gating_recurrent import (
fused_sigmoid_gating_delta_rule_update,
)
from sglang.srt.layers.attention.fla.kda import (
chunk_kda,
fused_kda_gate,
fused_recurrent_kda,
)
from sglang.srt.layers.attention.mamba.causal_conv1d_triton import (
PAD_SLOT_ID,
causal_conv1d_fn,
......@@ -227,6 +233,223 @@ class MambaAttnBackendBase(AttentionBackend):
return 1 # Mamba attn does not use seq lens to index kv cache
class KimiLinearAttnBackend(MambaAttnBackendBase):
"""Attention backend using Mamba kernel."""
def forward_decode(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache: bool = True,
**kwargs,
):
q_proj_states = kwargs["q_proj_states"]
k_proj_states = kwargs["k_proj_states"]
v_proj_states = kwargs["v_proj_states"]
q_conv_weights = kwargs["q_conv_weights"]
k_conv_weights = kwargs["k_conv_weights"]
v_conv_weights = kwargs["v_conv_weights"]
q_conv_bias = kwargs["q_conv_bias"]
k_conv_bias = kwargs["k_conv_bias"]
v_conv_bias = kwargs["v_conv_bias"]
A_log = kwargs["A_log"]
dt_bias = kwargs["dt_bias"]
b_proj = kwargs["b_proj"]
f_a_proj = kwargs["f_a_proj"]
f_b_proj = kwargs["f_b_proj"]
hidden_states = kwargs["hidden_states"]
head_dim = kwargs["head_dim"]
layer_id = kwargs["layer_id"]
layer_cache = self.req_to_token_pool.mamba2_layer_cache(layer_id)
q_conv_state, k_conv_state, v_conv_state = layer_cache.conv
ssm_states = layer_cache.temporal
query_start_loc = self.forward_metadata.query_start_loc
cache_indices = self.forward_metadata.mamba_cache_indices
q_conv_state = q_conv_state.transpose(-1, -2)
k_conv_state = k_conv_state.transpose(-1, -2)
v_conv_state = v_conv_state.transpose(-1, -2)
q = causal_conv1d_update(
q_proj_states,
q_conv_state,
q_conv_weights,
q_conv_bias,
activation="silu",
conv_state_indices=cache_indices,
)
k = causal_conv1d_update(
k_proj_states,
k_conv_state,
k_conv_weights,
k_conv_bias,
activation="silu",
conv_state_indices=cache_indices,
)
v = causal_conv1d_update(
v_proj_states,
v_conv_state,
v_conv_weights,
v_conv_bias,
activation="silu",
conv_state_indices=cache_indices,
)
q, k, v = map(
lambda x: rearrange(x, "n (h d) -> 1 n h d", d=head_dim), (q, k, v)
)
beta = b_proj(hidden_states)[0].float().sigmoid()
g = f_b_proj(f_a_proj(hidden_states)[0])[0]
g = fused_kda_gate(g, A_log, head_dim, g_bias=dt_bias)
beta = beta.unsqueeze(0)
g = g.unsqueeze(0)
initial_state = ssm_states[cache_indices].contiguous()
(
core_attn_out,
last_recurrent_state,
) = fused_recurrent_kda(
q=q,
k=k,
v=v,
g=g,
beta=beta,
initial_state=initial_state,
use_qk_l2norm_in_kernel=True,
cu_seqlens=query_start_loc,
)
ssm_states[cache_indices] = last_recurrent_state
return core_attn_out
def forward_extend(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache: bool = True,
**kwargs,
):
from sglang.srt.layers.attention.mamba.causal_conv1d_triton import (
causal_conv1d_fn,
)
q_proj_states = kwargs["q_proj_states"]
k_proj_states = kwargs["k_proj_states"]
v_proj_states = kwargs["v_proj_states"]
q_conv_weights = kwargs["q_conv_weights"]
k_conv_weights = kwargs["k_conv_weights"]
v_conv_weights = kwargs["v_conv_weights"]
q_conv_bias = kwargs["q_conv_bias"]
k_conv_bias = kwargs["k_conv_bias"]
v_conv_bias = kwargs["v_conv_bias"]
A_log = kwargs["A_log"]
dt_bias = kwargs["dt_bias"]
b_proj = kwargs["b_proj"]
f_a_proj = kwargs["f_a_proj"]
f_b_proj = kwargs["f_b_proj"]
hidden_states = kwargs["hidden_states"]
head_dim = kwargs["head_dim"]
layer_id = kwargs["layer_id"]
query_start_loc = self.forward_metadata.query_start_loc
cache_indices = self.forward_metadata.mamba_cache_indices
mamba_cache_params = self.req_to_token_pool.mamba2_layer_cache(layer_id)
conv_state_q, conv_state_k, conv_state_v = mamba_cache_params.conv
# deal with strides
conv_state_q = conv_state_q.transpose(-1, -2)
conv_state_k = conv_state_k.transpose(-1, -2)
conv_state_v = conv_state_v.transpose(-1, -2)
ssm_states = mamba_cache_params.temporal
has_initial_state = forward_batch.extend_prefix_lens > 0
q_proj_states = q_proj_states.transpose(0, 1)
k_proj_states = k_proj_states.transpose(0, 1)
v_proj_states = v_proj_states.transpose(0, 1)
q = causal_conv1d_fn(
q_proj_states,
q_conv_weights,
q_conv_bias,
activation="silu",
conv_states=conv_state_q,
has_initial_state=has_initial_state,
cache_indices=cache_indices,
query_start_loc=query_start_loc,
seq_lens_cpu=forward_batch.extend_seq_lens_cpu,
).transpose(0, 1)
k = causal_conv1d_fn(
k_proj_states,
k_conv_weights,
k_conv_bias,
activation="silu",
conv_states=conv_state_k,
has_initial_state=has_initial_state,
cache_indices=cache_indices,
query_start_loc=query_start_loc,
seq_lens_cpu=forward_batch.extend_seq_lens_cpu,
).transpose(0, 1)
v = causal_conv1d_fn(
v_proj_states,
v_conv_weights,
v_conv_bias,
activation="silu",
conv_states=conv_state_v,
has_initial_state=has_initial_state,
cache_indices=cache_indices,
query_start_loc=query_start_loc,
seq_lens_cpu=forward_batch.extend_seq_lens_cpu,
).transpose(0, 1)
q, k, v = map(
lambda x: rearrange(x, "n (h d) -> 1 n h d", d=head_dim), (q, k, v)
)
beta = b_proj(hidden_states)[0].float().sigmoid()
g = f_b_proj(f_a_proj(hidden_states)[0])[0]
g = fused_kda_gate(g, A_log, head_dim, g_bias=dt_bias)
beta = beta.unsqueeze(0)
g = g.unsqueeze(0)
initial_state = ssm_states[cache_indices].contiguous()
(
core_attn_out,
last_recurrent_state,
) = chunk_kda(
q=q,
k=k,
v=v,
g=g,
beta=beta,
initial_state=initial_state,
output_final_state=True,
use_qk_l2norm_in_kernel=True,
cu_seqlens=query_start_loc,
)
ssm_states[cache_indices] = last_recurrent_state
return core_attn_out
class GDNAttnBackend(MambaAttnBackendBase):
"""Attention backend using Mamba kernel."""
......
......@@ -92,7 +92,10 @@ class TritonAttnBackend(AttentionBackend):
self.num_kv_head = model_runner.model_config.get_num_kv_heads(
get_attention_tp_size()
)
if model_runner.hybrid_gdn_config is not None:
if (
model_runner.hybrid_gdn_config is not None
or model_runner.kimi_linear_config is not None
):
# For hybrid linear models, layer_id = 0 may not be full attention
self.v_head_dim = model_runner.token_to_kv_pool.get_v_head_dim()
else:
......
......@@ -17,7 +17,7 @@ from __future__ import annotations
from dataclasses import dataclass
from sglang.srt.configs.mamba_utils import Mamba2CacheParams
from sglang.srt.configs.mamba_utils import KimiLinearCacheParams, Mamba2CacheParams
from sglang.srt.layers.attention.nsa import index_buf_accessor
from sglang.srt.layers.attention.nsa.quant_k_cache import quantize_k_cache
from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter
......@@ -33,7 +33,7 @@ KVCache actually holds the physical kv cache.
import abc
import logging
from contextlib import nullcontext
from contextlib import contextmanager, nullcontext
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import numpy as np
......@@ -59,7 +59,9 @@ if _is_npu:
import torch_npu
def get_tensor_size_bytes(t: torch.Tensor):
def get_tensor_size_bytes(t: Union[torch.Tensor, List[torch.Tensor]]):
if isinstance(t, list):
return sum(get_tensor_size_bytes(x) for x in t)
return np.prod(t.shape) * t.dtype.itemsize
......@@ -116,10 +118,15 @@ class ReqToTokenPool:
class MambaPool:
@dataclass(frozen=True, kw_only=True)
class State:
conv: torch.Tensor
conv: Union[torch.Tensor, List[torch.Tensor]]
temporal: torch.Tensor
def at_layer_idx(self, layer: int):
if isinstance(self.conv, list):
return type(self)(
conv=[v[layer] for v in self.conv],
temporal=self.temporal[layer],
)
return type(self)(**{k: v[layer] for k, v in vars(self).items()})
def mem_usage_bytes(self):
......@@ -127,14 +134,14 @@ class MambaPool:
@dataclass(frozen=True, kw_only=True)
class SpeculativeState(State):
intermediate_ssm: torch.Tensor
intermediate_ssm: Union[torch.Tensor, List[torch.Tensor]]
intermediate_conv_window: torch.Tensor
def __init__(
self,
*,
size: int,
cache_params: "Mamba2CacheParams",
cache_params: Union["Mamba2CacheParams", "KimiLinearCacheParams"],
device: str,
speculative_num_draft_tokens: Optional[int] = None,
):
......@@ -157,18 +164,29 @@ class MambaPool:
else:
self.custom_mem_pool = None
self.is_kda_cache = isinstance(cache_params, KimiLinearCacheParams)
with (
torch.cuda.use_mem_pool(self.custom_mem_pool)
if self.enable_custom_mem_pool
else nullcontext()
):
# assume conv_state = (dim, state_len)
assert conv_state_shape[0] > conv_state_shape[1]
conv_state = torch.zeros(
size=(num_mamba_layers, size + 1) + conv_state_shape,
dtype=conv_dtype,
device=device,
)
if self.is_kda_cache:
conv_state = [
torch.zeros(
size=(num_mamba_layers, size + 1) + conv_shape,
dtype=conv_dtype,
device=device,
)
for conv_shape in conv_state_shape
]
else:
# assume conv_state = (dim, state_len)
assert conv_state_shape[0] > conv_state_shape[1]
conv_state = torch.zeros(
size=(num_mamba_layers, size + 1) + conv_state_shape,
dtype=conv_dtype,
device=device,
)
temporal_state = torch.zeros(
size=(num_mamba_layers, size + 1) + temporal_state_shape,
dtype=ssm_dtype,
......@@ -191,17 +209,34 @@ class MambaPool:
)
# Cache intermediate conv windows (last K-1 inputs) per draft token during target verify
# Shape: [num_layers, size + 1, speculative_num_draft_tokens, dim, K-1]
intermediate_conv_window_cache = torch.zeros(
size=(
num_mamba_layers,
size + 1,
speculative_num_draft_tokens,
conv_state_shape[0],
conv_state_shape[1],
),
dtype=conv_dtype,
device="cuda",
)
if self.is_kda_cache:
intermediate_conv_window_cache = [
torch.zeros(
size=(
num_mamba_layers,
size + 1,
speculative_num_draft_tokens,
conv_shape[0],
conv_shape[1],
),
dtype=conv_dtype,
device="cuda",
)
for conv_shape in conv_state_shape
]
else:
intermediate_conv_window_cache = torch.zeros(
size=(
num_mamba_layers,
size + 1,
speculative_num_draft_tokens,
conv_state_shape[0],
conv_state_shape[1],
),
dtype=conv_dtype,
device="cuda",
)
self.mamba_cache = self.SpeculativeState(
conv=conv_state,
temporal=temporal_state,
......@@ -255,15 +290,25 @@ class MambaPool:
if free_index.numel() == 0:
return
self.free_slots = torch.cat((self.free_slots, free_index))
self.mamba_cache.conv[:, free_index] = self.mamba_cache.temporal[
:, free_index
] = 0
if self.is_kda_cache:
for i in range(len(self.mamba_cache.conv)):
self.mamba_cache.conv[i][:, free_index] = 0
else:
self.mamba_cache.conv[:, free_index] = 0
self.mamba_cache.temporal[:, free_index] = 0
def clear(self):
self.free_slots = torch.arange(self.size, dtype=torch.int64, device=self.device)
def copy_from(self, src_index: torch.Tensor, dst_index: torch.Tensor):
self.mamba_cache.conv[:, dst_index] = self.mamba_cache.conv[:, src_index]
if self.is_kda_cache:
for i in range(len(self.mamba_cache.conv)):
self.mamba_cache.conv[i][:, dst_index] = self.mamba_cache.conv[i][
:, src_index
]
else:
self.mamba_cache.conv[:, dst_index] = self.mamba_cache.conv[:, src_index]
self.mamba_cache.temporal[:, dst_index] = self.mamba_cache.temporal[
:, src_index
]
......@@ -304,7 +349,7 @@ class HybridReqToTokenPool(ReqToTokenPool):
max_context_len: int,
device: str,
enable_memory_saver: bool,
cache_params: "Mamba2CacheParams",
cache_params: Union["Mamba2CacheParams", "KimiLinearCacheParams"],
speculative_num_draft_tokens: int = None,
):
super().__init__(
......@@ -323,7 +368,7 @@ class HybridReqToTokenPool(ReqToTokenPool):
def _init_mamba_pool(
self,
size: int,
cache_params: "Mamba2CacheParams",
cache_params: Union["Mamba2CacheParams", "KimiLinearCacheParams"],
device: str,
speculative_num_draft_tokens: int = None,
):
......@@ -812,6 +857,10 @@ class HybridLinearKVPool(KVCache):
enable_kvcache_transpose: bool,
device: str,
mamba_pool: MambaPool,
# TODO: refactor mla related args
use_mla: bool = False,
kv_lora_rank: int = None,
qk_rope_head_dim: int = None,
):
self.size = size
self.dtype = dtype
......@@ -825,25 +874,42 @@ class HybridLinearKVPool(KVCache):
self.mamba_pool = mamba_pool
# TODO MHATransposedTokenToKVPool if enable_kvcache_transpose is True
assert not enable_kvcache_transpose
if _is_npu:
TokenToKVPoolClass = AscendTokenToKVPool
self.use_mla = use_mla
if not use_mla:
if _is_npu:
TokenToKVPoolClass = AscendTokenToKVPool
else:
TokenToKVPoolClass = MHATokenToKVPool
self.full_kv_pool = TokenToKVPoolClass(
size=size,
page_size=self.page_size,
dtype=dtype,
head_num=head_num,
head_dim=head_dim,
layer_num=self.full_layer_nums,
device=device,
enable_memory_saver=False,
)
else:
TokenToKVPoolClass = MHATokenToKVPool
self.full_kv_pool = TokenToKVPoolClass(
size=size,
page_size=self.page_size,
dtype=dtype,
head_num=head_num,
head_dim=head_dim,
layer_num=self.full_layer_nums,
device=device,
enable_memory_saver=False,
)
TokenToKVPoolClass = MLATokenToKVPool
self.full_kv_pool = TokenToKVPoolClass(
size=size,
page_size=self.page_size,
dtype=dtype,
layer_num=self.full_layer_nums,
device=device,
kv_lora_rank=kv_lora_rank,
qk_rope_head_dim=qk_rope_head_dim,
enable_memory_saver=False,
)
self.full_attention_layer_id_mapping = {
id: i for i, id in enumerate(full_attention_layer_ids)
}
k_size, v_size = self.get_kv_size_bytes()
self.mem_usage = (k_size + v_size) / GB
if use_mla:
self.mem_usage = self.get_kv_size_bytes() / GB
else:
k_size, v_size = self.get_kv_size_bytes()
self.mem_usage = (k_size + v_size) / GB
def get_kv_size_bytes(self):
return self.full_kv_pool.get_kv_size_bytes()
......@@ -879,6 +945,21 @@ class HybridLinearKVPool(KVCache):
layer_id = self._transfer_full_attention_id(layer_id)
return self.full_kv_pool.get_kv_buffer(layer_id)
@contextmanager
def _transfer_id_context(self, layer: RadixAttention):
@contextmanager
def _patch_layer_id(layer):
original_layer_id = layer.layer_id
layer.layer_id = self._transfer_full_attention_id(layer.layer_id)
try:
yield
finally:
layer.layer_id = original_layer_id
with _patch_layer_id(layer):
yield
def set_kv_buffer(
self,
layer: RadixAttention,
......@@ -889,19 +970,49 @@ class HybridLinearKVPool(KVCache):
v_scale: float = 1.0,
):
layer_id = self._transfer_full_attention_id(layer.layer_id)
self.full_kv_pool.set_kv_buffer(
None,
loc,
cache_k,
cache_v,
k_scale,
v_scale,
layer_id_override=layer_id,
)
if not self.use_mla:
self.full_kv_pool.set_kv_buffer(
None,
loc,
cache_k,
cache_v,
k_scale,
v_scale,
layer_id_override=layer_id,
)
else:
with self._transfer_id_context(layer):
self.full_kv_pool.set_kv_buffer(
layer,
loc,
cache_k,
cache_v,
)
def get_v_head_dim(self):
return self.full_kv_pool.get_value_buffer(0).shape[-1]
def set_mla_kv_buffer(
self,
layer: RadixAttention,
loc: torch.Tensor,
cache_k_nope: torch.Tensor,
cache_k_rope: torch.Tensor,
):
assert self.use_mla, "set_mla_kv_buffer called when use_mla is False"
with self._transfer_id_context(layer):
self.full_kv_pool.set_mla_kv_buffer(layer, loc, cache_k_nope, cache_k_rope)
def get_mla_kv_buffer(
self,
layer: RadixAttention,
loc: torch.Tensor,
dst_dtype: Optional[torch.dtype] = None,
):
assert self.use_mla, "get_mla_kv_buffer called when use_mla is False"
with self._transfer_id_context(layer):
return self.full_kv_pool.get_mla_kv_buffer(layer, loc, dst_dtype)
class SWAKVPool(KVCache):
"""KV cache with separate pools for full and SWA attention layers."""
......
......@@ -29,7 +29,12 @@ from typing import Callable, List, Optional, Tuple, Union
import torch
import torch.distributed as dist
from sglang.srt.configs import FalconH1Config, NemotronHConfig, Qwen3NextConfig
from sglang.srt.configs import (
FalconH1Config,
KimiLinearConfig,
NemotronHConfig,
Qwen3NextConfig,
)
from sglang.srt.configs.device_config import DeviceConfig
from sglang.srt.configs.load_config import LoadConfig, LoadFormat
from sglang.srt.configs.model_config import (
......@@ -1358,9 +1363,16 @@ class ModelRunner:
return config
return None
@property
def kimi_linear_config(self):
config = self.model_config.hf_config
if isinstance(config, KimiLinearConfig):
return config
return None
@property
def mambaish_config(self):
return self.mamba2_config or self.hybrid_gdn_config
return self.mamba2_config or self.hybrid_gdn_config or self.kimi_linear_config
def set_num_token_hybrid(self):
if (
......@@ -1691,7 +1703,7 @@ class ModelRunner:
end_layer=self.end_layer,
index_head_dim=get_nsa_index_head_dim(self.model_config.hf_config),
)
elif self.use_mla_backend:
elif self.use_mla_backend and not self.mambaish_config:
assert not is_nsa_model
self.token_to_kv_pool = MLATokenToKVPool(
self.max_total_num_tokens,
......@@ -1735,6 +1747,12 @@ class ModelRunner:
device=self.device,
)
elif config := self.mambaish_config:
extra_args = {}
if self.use_mla_backend:
extra_args = {
"kv_lora_rank": self.model_config.kv_lora_rank,
"qk_rope_head_dim": self.model_config.qk_rope_head_dim,
}
self.token_to_kv_pool = HybridLinearKVPool(
page_size=self.page_size,
size=self.max_total_num_tokens,
......@@ -1750,6 +1768,8 @@ class ModelRunner:
enable_kvcache_transpose=False,
device=self.device,
mamba_pool=self.req_to_token_pool.mamba_pool,
use_mla=self.use_mla_backend,
**extra_args,
)
else:
self.token_to_kv_pool = MHATokenToKVPool(
......
......@@ -1075,6 +1075,7 @@ class DeepseekV2AttentionMLA(nn.Module):
layer_id: int = None,
prefix: str = "",
alt_stream: Optional[torch.cuda.Stream] = None,
skip_rope: bool = False,
) -> None:
super().__init__()
self.layer_id = layer_id
......@@ -1182,23 +1183,26 @@ class DeepseekV2AttentionMLA(nn.Module):
)
self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
self.rotary_emb = get_rope_wrapper(
qk_rope_head_dim,
rotary_dim=qk_rope_head_dim,
max_position=max_position_embeddings,
base=rope_theta,
rope_scaling=rope_scaling,
is_neox_style=False,
device=get_global_server_args().device,
)
if not skip_rope:
self.rotary_emb = get_rope_wrapper(
qk_rope_head_dim,
rotary_dim=qk_rope_head_dim,
max_position=max_position_embeddings,
base=rope_theta,
rope_scaling=rope_scaling,
is_neox_style=False,
device=get_global_server_args().device,
)
if rope_scaling:
mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
scaling_factor = rope_scaling["factor"]
mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
self.scaling = self.scaling * mscale * mscale
if rope_scaling:
mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
scaling_factor = rope_scaling["factor"]
mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
self.scaling = self.scaling * mscale * mscale
else:
self.rotary_emb.forward = self.rotary_emb.forward_native
else:
self.rotary_emb.forward = self.rotary_emb.forward_native
self.rotary_emb = None
self.attn_mqa = RadixAttention(
self.num_local_heads,
......@@ -1487,7 +1491,8 @@ class DeepseekV2AttentionMLA(nn.Module):
latent_cache = latent_cache.unsqueeze(1)
kv_a = self.kv_a_layernorm(kv_a)
k_pe = latent_cache[:, :, self.kv_lora_rank :]
q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
if self.rotary_emb is not None:
q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
q[..., self.qk_nope_head_dim :] = q_pe
self._set_mla_kv_buffer(latent_cache, kv_a, k_pe, forward_batch)
......@@ -1646,8 +1651,10 @@ class DeepseekV2AttentionMLA(nn.Module):
q_nope_out = q_nope_out.transpose(0, 1)
if not self._fuse_rope_for_trtllm_mla(forward_batch) and (
not _use_aiter or not _is_gfx95_supported or self.use_nsa
if (
self.rotary_emb is not None
and (not self._fuse_rope_for_trtllm_mla(forward_batch))
and (not _use_aiter or not _is_gfx95_supported or self.use_nsa)
):
q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
......
# Adapted from: https://github.com/vllm-project/vllm/blob/0384aa7150c4c9778efca041ffd1beb3ad2bd694/vllm/model_executor/models/kimi_linear.py
from collections.abc import Iterable
from typing import Optional
import torch
from einops import rearrange
from torch import nn
from sglang.srt.configs.kimi_linear import KimiLinearConfig
from sglang.srt.distributed import (
divide,
get_pp_group,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce,
)
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
from sglang.srt.layers.attention.fla.kda import FusedRMSNormGated
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
ColumnParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
from sglang.srt.layers.moe.topk import TopK, TopKOutputFormat
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.utils import PPMissingLayer
from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import (
default_weight_loader,
maybe_remap_kv_scale_name,
sharded_weight_loader,
)
from sglang.srt.models.deepseek_v2 import DeepseekV2AttentionMLA as KimiMLAAttention
from sglang.srt.models.llama import LlamaMLP as KimiMLP
from sglang.srt.models.transformers import maybe_prefix
from sglang.srt.utils import make_layers
from sglang.srt.utils.common import BumpAllocator, add_prefix, set_weight_attrs
class KimiMoE(nn.Module):
def __init__(
self,
config: KimiLinearConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
layer_idx: int = 0,
):
super().__init__()
hidden_size = config.hidden_size
intermediate_size = config.intermediate_size
moe_intermediate_size = config.moe_intermediate_size
num_experts = config.num_experts
moe_renormalize = config.moe_renormalize
self.tp_size = get_tensor_model_parallel_world_size()
self.routed_scaling_factor = config.routed_scaling_factor
self.num_shared_experts = config.num_shared_experts
self.layer_idx = layer_idx
if config.hidden_act != "silu":
raise ValueError(
f"Unsupported activation: {config.hidden_act}. "
"Only silu is supported for now."
)
# Gate always runs at half / full precision for now.
self.gate = ReplicatedLinear(
hidden_size,
num_experts,
bias=False,
quant_config=None,
prefix=f"{prefix}.gate",
)
self.gate.e_score_correction_bias = nn.Parameter(torch.empty(num_experts))
self.experts = get_moe_impl_class(quant_config)(
num_experts=config.n_routed_experts,
top_k=config.num_experts_per_token,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
layer_id=self.layer_idx,
quant_config=quant_config,
routed_scaling_factor=self.routed_scaling_factor,
prefix=add_prefix("experts", prefix),
)
self.topk = TopK(
top_k=config.num_experts_per_token,
renormalize=moe_renormalize,
use_grouped_topk=True,
num_expert_group=config.num_expert_group,
topk_group=config.topk_group,
correction_bias=self.gate.e_score_correction_bias,
quant_config=quant_config,
routed_scaling_factor=self.routed_scaling_factor,
apply_routed_scaling_factor_on_output=self.experts.should_fuse_routed_scaling_factor_in_topk,
# Some Fp4 MoE backends require the output format to be bypassed but the MTP layers are unquantized
# and requires the output format to be standard. We use quant_config to determine the output format.
output_format=TopKOutputFormat.STANDARD if quant_config is None else None,
)
if self.num_shared_experts is not None:
intermediate_size = moe_intermediate_size * self.num_shared_experts
self.shared_experts = KimiMLP(
hidden_size=config.hidden_size,
intermediate_size=intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
reduce_results=False,
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_size = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_size)
if self.num_shared_experts is not None:
shared_output = self.shared_experts(hidden_states)
router_logits, _ = self.gate(hidden_states)
topk_output = self.topk(hidden_states, router_logits)
final_hidden_states = self.experts(hidden_states, topk_output)
if shared_output is not None:
final_hidden_states = final_hidden_states + shared_output
if self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
return final_hidden_states.view(num_tokens, hidden_size)
class KimiDeltaAttention(nn.Module):
def __init__(
self,
layer_idx: int,
hidden_size: int,
config: KimiLinearConfig,
quant_config: Optional[QuantizationConfig] = None,
rms_norm_eps: float = 1e-5,
prefix: str = "",
**kwargs,
) -> None:
super().__init__()
self.tp_size = get_tensor_model_parallel_world_size()
self.hidden_size = hidden_size
self.config = config
self.head_dim = config.linear_attn_config["head_dim"]
self.num_heads = config.linear_attn_config["num_heads"]
self.layer_idx = layer_idx
self.prefix = prefix
assert self.num_heads % self.tp_size == 0
self.local_num_heads = divide(self.num_heads, self.tp_size)
projection_size = self.head_dim * self.num_heads
self.conv_size = config.linear_attn_config["short_conv_kernel_size"]
self.q_proj = ColumnParallelLinear(
self.hidden_size,
projection_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.q_proj",
)
self.k_proj = ColumnParallelLinear(
self.hidden_size,
projection_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.k_proj",
)
self.v_proj = ColumnParallelLinear(
self.hidden_size,
projection_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.v_proj",
)
self.f_a_proj = ReplicatedLinear(
self.hidden_size,
self.head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.f_a_proj",
)
self.f_b_proj = ColumnParallelLinear(
self.head_dim,
projection_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.f_b_proj",
)
self.dt_bias = nn.Parameter(
torch.empty(divide(projection_size, self.tp_size), dtype=torch.float32)
)
set_weight_attrs(self.dt_bias, {"weight_loader": sharded_weight_loader(0)})
self.b_proj = ColumnParallelLinear(
self.hidden_size,
self.num_heads,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.b_proj",
)
self.q_conv1d = ColumnParallelLinear(
input_size=self.conv_size,
output_size=projection_size,
bias=False,
params_dtype=torch.float32,
prefix=f"{prefix}.q_conv1d",
)
self.k_conv1d = ColumnParallelLinear(
input_size=self.conv_size,
output_size=projection_size,
bias=False,
params_dtype=torch.float32,
prefix=f"{prefix}.k_conv1d",
)
self.v_conv1d = ColumnParallelLinear(
input_size=self.conv_size,
output_size=projection_size,
bias=False,
params_dtype=torch.float32,
prefix=f"{prefix}.v_conv1d",
)
# unsqueeze to fit conv1d weights shape into the linear weights shape.
# Can't do this in `weight_loader` since it already exists in
# `ColumnParallelLinear` and `set_weight_attrs`
# doesn't allow to override it
self.q_conv1d.weight.data = self.q_conv1d.weight.data.unsqueeze(1)
self.k_conv1d.weight.data = self.k_conv1d.weight.data.unsqueeze(1)
self.v_conv1d.weight.data = self.v_conv1d.weight.data.unsqueeze(1)
self.A_log = nn.Parameter(
torch.empty(1, 1, self.local_num_heads, 1, dtype=torch.float32)
)
set_weight_attrs(self.A_log, {"weight_loader": sharded_weight_loader(2)})
self.g_a_proj = ReplicatedLinear(
self.hidden_size,
self.head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.g_a_proj",
)
self.g_b_proj = ColumnParallelLinear(
self.head_dim,
projection_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.g_b_proj",
)
self.o_norm = FusedRMSNormGated(
self.head_dim, eps=rms_norm_eps, activation="sigmoid"
)
self.o_proj = RowParallelLinear(
projection_size,
self.hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
def forward(
self,
hidden_states: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
zero_allocator: BumpAllocator,
) -> None:
q_proj_states = self.q_proj(hidden_states)[0]
k_proj_states = self.k_proj(hidden_states)[0]
v_proj_states = self.v_proj(hidden_states)[0]
q_conv_weights = self.q_conv1d.weight.view(
self.q_conv1d.weight.size(0), self.q_conv1d.weight.size(2)
)
k_conv_weights = self.k_conv1d.weight.view(
self.k_conv1d.weight.size(0), self.k_conv1d.weight.size(2)
)
v_conv_weights = self.v_conv1d.weight.view(
self.v_conv1d.weight.size(0), self.v_conv1d.weight.size(2)
)
kwargs = {
"q_proj_states": q_proj_states,
"k_proj_states": k_proj_states,
"v_proj_states": v_proj_states,
"q_conv_weights": q_conv_weights,
"k_conv_weights": k_conv_weights,
"v_conv_weights": v_conv_weights,
"q_conv_bias": self.q_conv1d.bias,
"k_conv_bias": self.k_conv1d.bias,
"v_conv_bias": self.v_conv1d.bias,
"dt_bias": self.dt_bias,
"b_proj": self.b_proj,
"f_a_proj": self.f_a_proj,
"f_b_proj": self.f_b_proj,
"A_log": self.A_log,
"head_dim": self.head_dim,
"hidden_states": hidden_states,
"layer_id": self.layer_idx,
}
core_attn_out = forward_batch.attn_backend.forward(
q=None,
k=None,
v=None,
layer=None,
forward_batch=forward_batch,
**kwargs,
)
g_proj_states = self.g_b_proj(self.g_a_proj(hidden_states)[0])[0]
g = rearrange(g_proj_states, "... (h d) -> ... h d", d=self.head_dim)
core_attn_out = self.o_norm(core_attn_out, g)
core_attn_out = rearrange(core_attn_out, "1 n h d -> n (h d)")
return self.o_proj(core_attn_out)[0]
class KimiDecoderLayer(nn.Module):
def __init__(
self,
config: KimiLinearConfig,
layer_idx: int,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
self.is_moe = config.is_moe
if config.is_kda_layer(layer_idx):
self.self_attn = KimiDeltaAttention(
layer_idx=layer_idx,
hidden_size=config.hidden_size,
config=config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
)
else:
self.self_attn = KimiMLAAttention(
layer_id=layer_idx,
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
config=config,
qk_nope_head_dim=config.qk_nope_head_dim,
qk_rope_head_dim=config.qk_rope_head_dim,
v_head_dim=config.v_head_dim,
q_lora_rank=config.q_lora_rank,
kv_lora_rank=config.kv_lora_rank,
skip_rope=True,
)
if (
self.is_moe
and config.num_experts is not None
and layer_idx >= config.first_k_dense_replace
and layer_idx % config.moe_layer_freq == 0
):
self.block_sparse_moe = KimiMoE(
config=config,
quant_config=quant_config,
layer_idx=layer_idx,
prefix=f"{prefix}.mlp",
)
self.mlp = self.block_sparse_moe
else:
self.mlp = KimiMLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
residual: Optional[torch.Tensor],
zero_allocator: BumpAllocator,
) -> tuple[torch.Tensor, torch.Tensor]:
# Self Attention
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(hidden_states, residual)
hidden_states = self.self_attn(
hidden_states=hidden_states,
positions=positions,
forward_batch=forward_batch,
zero_allocator=zero_allocator,
)
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
class KimiLinearModel(nn.Module):
def __init__(
self,
config: KimiLinearConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.config = config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.pp_group = get_pp_group()
if self.pp_group.is_first_rank:
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
prefix=f"{prefix}.embed_tokens",
)
else:
self.embed_tokens = PPMissingLayer()
self.layers, self.start_layer, self.end_layer = make_layers(
config.num_hidden_layers,
lambda idx, prefix: KimiDecoderLayer(
layer_idx=idx,
config=config,
quant_config=quant_config,
prefix=prefix,
),
pp_rank=self.pp_group.rank_in_group,
pp_size=self.pp_group.world_size,
prefix=f"{prefix}.layers",
)
if self.pp_group.is_last_rank:
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
else:
self.norm = PPMissingLayer()
world_size = get_tensor_model_parallel_world_size()
assert (
config.num_attention_heads % world_size == 0
), "num_attention_heads must be divisible by world_size"
def forward(
self,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
forward_batch: ForwardBatch,
inputs_embeds: torch.Tensor | None = None,
pp_proxy_tensors: Optional[PPProxyTensors] = None,
) -> torch.Tensor:
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.embed_tokens(input_ids)
residual = None
else:
assert pp_proxy_tensors is not None
hidden_states = pp_proxy_tensors["hidden_states"]
residual = pp_proxy_tensors["residual"]
total_num_layers = self.end_layer - self.start_layer
device = hidden_states.device
zero_allocator = BumpAllocator(
buffer_size=total_num_layers * 2,
dtype=torch.float32,
device=device,
)
# TODO: capture aux hidden states
aux_hidden_states = []
for i in range(self.start_layer, self.end_layer):
ctx = get_global_expert_distribution_recorder().with_current_layer(i)
with ctx:
layer = self.layers[i]
hidden_states, residual = layer(
positions=positions,
hidden_states=hidden_states,
forward_batch=forward_batch,
residual=residual,
zero_allocator=zero_allocator,
)
if not self.pp_group.is_last_rank:
return PPProxyTensors(
{
"hidden_states": hidden_states,
"residual": residual,
}
)
else:
if hidden_states.shape[0] != 0:
if residual is None:
hidden_states = self.norm(hidden_states)
else:
hidden_states, _ = self.norm(hidden_states, residual)
if len(aux_hidden_states) == 0:
return hidden_states
return hidden_states, aux_hidden_states
class KimiLinearForCausalLM(nn.Module):
def __init__(
self,
config: KimiLinearConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
self.quant_config = quant_config
self.model = KimiLinearModel(
config, quant_config, prefix=maybe_prefix(prefix, "model")
)
self.pp_group = get_pp_group()
if self.pp_group.is_last_rank:
self.lm_head = ParallelLMHead(
self.config.vocab_size,
self.config.hidden_size,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "lm_head"),
)
else:
self.lm_head = PPMissingLayer()
logit_scale = getattr(self.config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(config=config, logit_scale=logit_scale)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
inputs_embeds: Optional[torch.Tensor] = None,
pp_proxy_tensors: Optional[PPProxyTensors] = None,
) -> torch.Tensor:
hidden_states = self.model(
input_ids,
positions,
forward_batch,
inputs_embeds,
pp_proxy_tensors,
)
if self.pp_group.is_last_rank:
return self.logits_processor(
input_ids, hidden_states, self.lm_head, forward_batch
)
else:
return hidden_states
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".gate_up_proj", ".gate_proj", 0),
(".gate_up_proj", ".up_proj", 1),
]
if self.config.is_moe:
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="w1",
ckpt_down_proj_name="w2",
ckpt_up_proj_name="w3",
num_experts=self.config.num_experts,
)
else:
expert_params_mapping = []
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
for args in weights:
name, loaded_weight = args[:2]
kwargs = args[2] if len(args) > 2 else {}
if "rotary_emb.inv_freq" in name:
continue
if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
# We have mlp.experts[0].gate_proj in the checkpoint.
# Since we handle the experts below in expert_params_mapping,
# we need to skip here BEFORE we update the name, otherwise
# name will be updated to mlp.experts[0].gate_up_proj, which
# will then be updated below in expert_params_mapping
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
if ("mlp.experts." in name) and name not in params_dict:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# if is_pp_missing_parameter(name, self):
# continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
for idx, (param_name, weight_name, expert_id, shard_id) in enumerate(
expert_params_mapping
):
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# if is_pp_missing_parameter(name, self):
# continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(
param,
loaded_weight,
name,
expert_id=expert_id,
shard_id=shard_id,
)
break
else:
# Skip loading extra bias for GPTQ models.
if (
name.endswith(".bias")
and name not in params_dict
and not self.config.is_linear_attn
): # noqa: E501
continue
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
# if is_pp_missing_parameter(name, self):
# continue
param = params_dict[name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
weight_loader(param, loaded_weight, **kwargs)
loaded_params.add(name)
for layer_id in self.config.full_attention_layer_ids:
self_attn = self.model.layers[layer_id].self_attn
w_kc, w_vc = self_attn.kv_b_proj.weight.unflatten(
0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
if hasattr(self_attn.kv_b_proj, "weight_scale"):
self_attn.w_scale = self_attn.kv_b_proj.weight_scale
EntryClass = KimiLinearForCausalLM
......@@ -1028,6 +1028,11 @@ class ServerArgs:
logger.info(
f"Using {self.attention_backend} as attention backend for {model_arch}."
)
elif model_arch in ["KimiLinearForCausalLM"]:
logger.warning(
f"Disabling Radix Cache for {model_arch} as it is not yet supported."
)
self.disable_radix_cache = True
if is_deepseek_nsa(hf_config):
if (
......
......@@ -43,6 +43,7 @@ from sglang.srt.configs import (
DotsVLMConfig,
ExaoneConfig,
FalconH1Config,
KimiLinearConfig,
KimiVLConfig,
LongcatFlashConfig,
MultiModalityConfig,
......@@ -68,6 +69,7 @@ _CONFIG_REGISTRY: List[Type[PretrainedConfig]] = [
Step3VLConfig,
LongcatFlashConfig,
Olmo3Config,
KimiLinearConfig,
Qwen3NextConfig,
FalconH1Config,
DotsVLMConfig,
......
import unittest
from types import SimpleNamespace
from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval
from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
)
class TestKimiLinear(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = "moonshotai/Kimi-Linear-48B-A3B-Instruct"
cls.base_url = DEFAULT_URL_FOR_TEST
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=["--tp-size", "2", "--trust-remote"],
)
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
def test_gsm8k(self):
args = SimpleNamespace(
num_shots=5,
data_path=None,
num_questions=200,
max_new_tokens=512,
parallel=128,
host="http://127.0.0.1",
port=int(self.base_url.split(":")[-1]),
)
metrics = run_eval(args)
print(f"{metrics=}")
self.assertGreater(metrics["accuracy"], 0.88)
if __name__ == "__main__":
unittest.main()
......@@ -151,6 +151,7 @@ suites = {
TestFile("layers/attention/mamba/test_mamba2_mixer.py", 50),
TestFile("lora/test_lora_tp.py", 116),
TestFile("models/test_glm4_moe_models.py", 100),
TestFile("models/test_kimi_linear_models.py", 90),
TestFile("rl/test_update_weights_from_distributed.py", 103),
TestFile("test_data_parallelism.py", 73),
TestFile("test_disaggregation_basic.py", 400),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment