sglang — Commit a4bf5c6a (unverified)

Support Kimi Linear (#12469)

Authored Nov 01, 2025 by Ke Bao; committed by GitHub Oct 31, 2025
Co-authored-by: yizhang2077 <1109276519@qq.com>
Parent: 30ad1070

Showing 18 changed files with 2847 additions and 112 deletions (+2847, -112)
python/sglang/srt/configs/__init__.py                               +2     -0
python/sglang/srt/configs/kimi_linear.py                            +160   -0
python/sglang/srt/configs/mamba_utils.py                            +66    -0
python/sglang/srt/configs/model_config.py                           +7     -0
python/sglang/srt/layers/attention/attention_registry.py            +3     -0
python/sglang/srt/layers/attention/fla/chunk_delta_h.py             +61    -32
python/sglang/srt/layers/attention/fla/fused_recurrent.py           +17    -4
python/sglang/srt/layers/attention/fla/kda.py                       +1359  -0
python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py    +223   -0
python/sglang/srt/layers/attention/triton_backend.py                +4     -1
python/sglang/srt/mem_cache/memory_pool.py                          +165   -54
python/sglang/srt/model_executor/model_runner.py                    +23    -3
python/sglang/srt/models/deepseek_v2.py                             +25    -18
python/sglang/srt/models/kimi_linear.py                             +678   -0
python/sglang/srt/server_args.py                                    +5     -0
python/sglang/srt/utils/hf_transformers_utils.py                    +2     -0
test/srt/models/test_kimi_linear_models.py                          +46    -0
test/srt/run_suite.py                                               +1     -0
python/sglang/srt/configs/__init__.py

@@ -6,6 +6,7 @@ from sglang.srt.configs.dots_vlm import DotsVLMConfig
 from sglang.srt.configs.exaone import ExaoneConfig
 from sglang.srt.configs.falcon_h1 import FalconH1Config
 from sglang.srt.configs.janus_pro import MultiModalityConfig
+from sglang.srt.configs.kimi_linear import KimiLinearConfig
 from sglang.srt.configs.kimi_vl import KimiVLConfig
 from sglang.srt.configs.kimi_vl_moonvit import MoonViTConfig
 from sglang.srt.configs.longcat_flash import LongcatFlashConfig
@@ -31,6 +32,7 @@ __all__ = [
     "Step3TextConfig",
     "Step3VisionEncoderConfig",
     "Olmo3Config",
+    "KimiLinearConfig",
     "Qwen3NextConfig",
     "DotsVLMConfig",
     "DotsOCRConfig",
python/sglang/srt/configs/kimi_linear.py (new file, 0 → 100644)

# Adapted from: https://github.com/vllm-project/vllm/blob/0384aa7150c4c9778efca041ffd1beb3ad2bd694/vllm/transformers_utils/configs/kimi_linear.py
from transformers.configuration_utils import PretrainedConfig

from sglang.srt.configs.mamba_utils import KimiLinearCacheParams, KimiLinearStateShape
from sglang.srt.layers.dp_attention import get_attention_tp_size


class KimiLinearConfig(PretrainedConfig):
    model_type = "kimi_linear"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        model_type="kimi_linear",
        vocab_size=163840,
        hidden_size=4096,
        head_dim=None,
        intermediate_size=11008,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=None,
        hidden_act="silu",
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        rope_theta=10000.0,
        rope_scaling=None,
        tie_word_embeddings=False,
        moe_intermediate_size: int | None = None,
        moe_renormalize: bool = True,
        moe_router_activation_func: str = "sigmoid",
        num_experts: int | None = None,
        num_experts_per_token: int | None = None,
        num_shared_experts: int = 0,
        routed_scaling_factor: float = 1.0,
        first_k_dense_replace: int = 0,
        moe_layer_freq: int = 1,
        use_grouped_topk: bool = True,
        num_expert_group: int = 1,
        topk_group: int = 1,
        q_lora_rank: int | None = None,
        kv_lora_rank: int | None = None,
        qk_nope_head_dim: int | None = None,
        qk_rope_head_dim: int | None = None,
        v_head_dim: int | None = None,
        mla_use_nope: bool | None = False,
        num_nextn_predict_layers: int = 0,
        linear_attn_config: dict | None = None,
        **kwargs,
    ):
        self.model_type = model_type
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.head_dim = (
            head_dim if head_dim is not None else hidden_size // num_attention_heads
        )
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.mla_use_nope = mla_use_nope

        # moe config
        self.n_routed_experts = self.num_experts = num_experts
        self.num_experts_per_token = num_experts_per_token
        self.moe_renormalize = moe_renormalize
        self.num_shared_experts = num_shared_experts
        self.routed_scaling_factor = routed_scaling_factor
        self.moe_router_activation_func = moe_router_activation_func
        assert self.moe_router_activation_func in ("softmax", "sigmoid")
        self.moe_intermediate_size = moe_intermediate_size
        self.first_k_dense_replace = first_k_dense_replace
        self.moe_layer_freq = moe_layer_freq
        self.use_grouped_topk = use_grouped_topk
        self.num_expert_group = num_expert_group
        self.topk_group = topk_group
        self.num_nextn_predict_layers = num_nextn_predict_layers

        if linear_attn_config is not None:
            assert linear_attn_config["kda_layers"] is not None
            assert linear_attn_config["full_attn_layers"] is not None
        self.linear_attn_config = linear_attn_config

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )

    @property
    def is_mla(self):
        return (
            self.q_lora_rank is not None
            or self.kv_lora_rank is not None
            or self.qk_nope_head_dim is not None
            or self.qk_rope_head_dim is not None
            or self.v_head_dim is not None
            or self.mla_use_nope is True
        )

    @property
    def is_moe(self):
        return self.num_experts is not None

    @property
    def is_linear_attn(self) -> bool:
        return not (
            self.linear_attn_config is None
            or (
                isinstance(self.linear_attn_config, dict)
                and self.linear_attn_config["kda_layers"] is not None
                and len(self.linear_attn_config["kda_layers"]) == 0
            )
        )

    def is_kda_layer(self, layer_idx: int):
        return (
            self.linear_attn_config is not None
            and (layer_idx + 1) in self.linear_attn_config["kda_layers"]
        )

    @property
    def linear_layer_ids(self):
        return [i for i in range(self.num_hidden_layers) if self.is_kda_layer(i)]

    @property
    def full_attention_layer_ids(self):
        return [i for i in range(self.num_hidden_layers) if not self.is_kda_layer(i)]

    @property
    def mamba2_cache_params(self) -> KimiLinearCacheParams:
        shape = KimiLinearStateShape.create(
            tp_world_size=get_attention_tp_size(),
            num_heads=self.linear_attn_config["num_heads"],
            head_dim=self.linear_attn_config["head_dim"],
            conv_kernel_size=self.linear_attn_config["short_conv_kernel_size"],
        )
        return KimiLinearCacheParams(shape=shape, layers=self.linear_layer_ids)
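Editor's note: the layer-classification logic above is simple enough to check in isolation. The sketch below mirrors `is_kda_layer` / `linear_layer_ids` / `full_attention_layer_ids` in plain Python, with no sglang imports; the `kda_layers` / `full_attn_layers` values and the 8-layer depth are illustrative assumptions, not values taken from a real checkpoint. Note that the config lists are 1-based, hence the `(layer_idx + 1)`.

# Minimal sketch of KimiLinearConfig's layer classification (hypothetical layer lists).
linear_attn_config = {
    "kda_layers": [1, 2, 3, 5, 6, 7],   # hypothetical: 1-based indices of KDA (linear) layers
    "full_attn_layers": [4, 8],          # hypothetical: 1-based indices of full-attention layers
}
num_hidden_layers = 8

def is_kda_layer(layer_idx: int) -> bool:
    # matches the config method: 0-based layer index checked against 1-based list
    return (layer_idx + 1) in linear_attn_config["kda_layers"]

linear_layer_ids = [i for i in range(num_hidden_layers) if is_kda_layer(i)]
full_attention_layer_ids = [i for i in range(num_hidden_layers) if not is_kda_layer(i)]

print(linear_layer_ids)          # [0, 1, 2, 4, 5, 6]
print(full_attention_layer_ids)  # [3, 7]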
python/sglang/srt/configs/mamba_utils.py

@@ -14,6 +14,7 @@
 import os
 from dataclasses import dataclass, field
+from typing import List, Optional

 import numpy as np
 import torch
@@ -115,3 +116,68 @@ class Mamba2CacheParams:
            int(np.prod(self.shape.conv)) * self.dtype.conv.itemsize
            + int(np.prod(self.shape.temporal)) * self.dtype.temporal.itemsize
        ) * len(self.layers)


@dataclass(kw_only=True, frozen=True)
class KimiLinearStateShape:
    conv: List[tuple[int, int]]
    temporal: tuple[int, int, int]

    num_heads: int
    head_dim: int
    num_k_heads: int
    head_k_dim: int
    conv_kernel: int
    num_spec: int

    @staticmethod
    def create(
        *,
        tp_world_size: int,
        num_heads: int,
        head_dim: int,
        num_k_heads: Optional[int] = None,
        head_k_dim: Optional[int] = None,
        conv_kernel_size: int = 4,
        num_spec: int = 0,
    ) -> "KimiLinearStateShape":
        if num_k_heads is None:
            num_k_heads = num_heads
        if head_k_dim is None:
            head_k_dim = head_dim

        proj_size = num_heads * head_dim
        proj_k_size = num_k_heads * head_k_dim

        conv_state_shape = (divide(proj_size, tp_world_size), conv_kernel_size - 1)
        conv_state_k_shape = (divide(proj_k_size, tp_world_size), conv_kernel_size - 1)
        temporal_state_shape = (divide(num_heads, tp_world_size), head_dim, head_dim)

        conv_state_shape = conv_state_shape[1], conv_state_shape[0]
        conv_state_k_shape = conv_state_k_shape[1], conv_state_k_shape[0]

        return KimiLinearStateShape(
            conv=[conv_state_shape, conv_state_k_shape, conv_state_k_shape],
            temporal=temporal_state_shape,
            num_heads=num_heads,
            head_dim=head_dim,
            num_k_heads=num_k_heads,
            head_k_dim=head_k_dim,
            conv_kernel=conv_kernel_size,
            num_spec=num_spec,
        )


@dataclass(kw_only=True, frozen=True)
class KimiLinearCacheParams:
    shape: KimiLinearStateShape
    dtype: Mamba2StateDType = field(default_factory=mamba2_state_dtype)
    layers: list[int]

    @property
    def mamba_cache_per_req(self) -> int:
        return (
            int(np.sum([np.prod(conv_shape) for conv_shape in self.shape.conv]))
            * self.dtype.conv.itemsize
            + int(np.prod(self.shape.temporal)) * self.dtype.temporal.itemsize
        ) * len(self.layers)
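Editor's note: for intuition, `mamba_cache_per_req` is just the three conv states plus the temporal (SSM) state, summed over the KDA layers. A small worked example of that arithmetic with made-up dimensions and assumed itemsizes (the real head counts, dims, and the dtypes returned by `mamba2_state_dtype` may differ):

# Hypothetical numbers, only to illustrate the mamba_cache_per_req arithmetic.
import numpy as np

num_heads, head_dim, conv_kernel, tp = 16, 128, 4, 1
conv_q = (conv_kernel - 1, num_heads * head_dim // tp)   # (3, 2048), after the axis swap in create()
conv_shapes = [conv_q, conv_q, conv_q]                   # q/k/v conv states (k == q when defaults apply)
temporal = (num_heads // tp, head_dim, head_dim)         # (16, 128, 128)

conv_itemsize, temporal_itemsize = 2, 4                  # assumed: bf16 conv state, fp32 temporal state
per_layer = (
    int(np.sum([np.prod(s) for s in conv_shapes])) * conv_itemsize
    + int(np.prod(temporal)) * temporal_itemsize
)
print(per_layer)                  # 18432*2 + 262144*4 = 1_085_440 bytes per KDA layer
print(per_layer * 24)             # ~26 MB per request if 24 layers were KDA layers (illustrative)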
python/sglang/srt/configs/model_config.py

@@ -366,6 +366,13 @@ class ModelConfig:
            self.qk_rope_head_dim = self.hf_text_config.qk_rope_head_dim
            self.v_head_dim = self.hf_text_config.v_head_dim
            self.qk_nope_head_dim = self.hf_text_config.qk_nope_head_dim
+        elif "KimiLinearForCausalLM" in self.hf_config.architectures:
+            self.head_dim = 72
+            self.attention_arch = AttentionArch.MLA
+            self.kv_lora_rank = self.hf_config.kv_lora_rank
+            self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim
+            self.v_head_dim = self.hf_config.v_head_dim
+            self.qk_nope_head_dim = self.hf_config.qk_nope_head_dim
        else:
            if (
                "MistralModel" in self.hf_config.architectures
python/sglang/srt/layers/attention/attention_registry.py

@@ -189,6 +189,7 @@ def attn_backend_wrapper(runner: "ModelRunner", full_attn_backend: "AttentionBac
        from sglang.srt.layers.attention.hybrid_linear_attn_backend import (
            GDNAttnBackend,
            HybridLinearAttnBackend,
+            KimiLinearAttnBackend,
            Mamba2AttnBackend,
        )
        from sglang.srt.utils import is_blackwell, is_npu
@@ -207,6 +208,8 @@ def attn_backend_wrapper(runner: "ModelRunner", full_attn_backend: "AttentionBac
            linear_attn_backend = GDNAttnBackend(runner)
        elif runner.mamba2_config is not None:
            linear_attn_backend = Mamba2AttnBackend(runner)
+        elif runner.kimi_linear_config is not None:
+            linear_attn_backend = KimiLinearAttnBackend(runner)
        else:
            raise ValueError(
                "Expected hybrid GDN or NemotronH models, but got unknown model."
python/sglang/srt/layers/attention/fla/chunk_delta_h.py
(removed lines are prefixed with "-", added lines with "+"; unprefixed lines are unchanged context)

@@ -21,6 +21,7 @@ NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8, 16]
@triton.heuristics(
    {
        "USE_G": lambda args: args["g"] is not None,
+       "USE_GK": lambda args: args["gk"] is not None,
        "USE_INITIAL_STATE": lambda args: args["h0"] is not None,
        "STORE_FINAL_STATE": lambda args: args["ht"] is not None,
        "SAVE_NEW_VALUE": lambda args: args["v_new"] is not None,
@@ -44,6 +45,7 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
    w,
    v_new,
    g,
+   gk,
    h,
    h0,
    ht,
@@ -57,6 +59,7 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
    BT: tl.constexpr,
    BV: tl.constexpr,
    USE_G: tl.constexpr,
+   USE_GK: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    STORE_FINAL_STATE: tl.constexpr,
    SAVE_NEW_VALUE: tl.constexpr,
@@ -86,12 +89,12 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
        b_h4 = tl.zeros([64, BV], dtype=tl.float32)

    # calculate offset
-   h += (boh * H + i_h) * K * V
-   v += (bos * H + i_h) * V
-   k += (bos * Hg + i_h // (H // Hg)) * K
-   w += (bos * H + i_h) * K
+   h += ((boh * H + i_h) * K * V).to(tl.int64)
+   v += ((bos * H + i_h) * V).to(tl.int64)
+   k += ((bos * Hg + i_h // (H // Hg)) * K).to(tl.int64)
+   w += ((bos * H + i_h) * K).to(tl.int64)
    if SAVE_NEW_VALUE:
-       v_new += (bos * H + i_h) * V
+       v_new += ((bos * H + i_h) * V).to(tl.int64)
    stride_v = H * V
    stride_h = H * K * V
    stride_k = Hg * K
@@ -143,58 +146,48 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
            )
            tl.store(p_h4, b_h4.to(p_h4.dtype.element_ty), boundary_check=(0, 1))
-       p_v = tl.make_block_ptr(v, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
-       p_v_new = (
-           tl.make_block_ptr(v_new, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
-           if SAVE_NEW_VALUE
-           else None
-       )
-       b_v_new = tl.zeros([BT, BV], dtype=tl.float32)
        p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 0), (BT, 64), (1, 0))
        b_w = tl.load(p_w, boundary_check=(0, 1))
-       b_v_new += tl.dot(b_w, b_h1.to(b_w.dtype))
+       b_v = tl.dot(b_w, b_h1.to(b_w.dtype))
        if K > 64:
            p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 64), (BT, 64), (1, 0))
            b_w = tl.load(p_w, boundary_check=(0, 1))
-           b_v_new += tl.dot(b_w, b_h2.to(b_w.dtype))
+           b_v += tl.dot(b_w, b_h2.to(b_w.dtype))
        if K > 128:
            p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 128), (BT, 64), (1, 0))
            b_w = tl.load(p_w, boundary_check=(0, 1))
-           b_v_new += tl.dot(b_w, b_h3.to(b_w.dtype))
+           b_v += tl.dot(b_w, b_h3.to(b_w.dtype))
        if K > 192:
            p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 192), (BT, 64), (1, 0))
            b_w = tl.load(p_w, boundary_check=(0, 1))
-           b_v_new += tl.dot(b_w, b_h4.to(b_w.dtype))
-       b_v_new = -b_v_new + tl.load(p_v, boundary_check=(0, 1))
+           b_v += tl.dot(b_w, b_h4.to(b_w.dtype))
+       p_v = tl.make_block_ptr(v, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+       b_v = tl.load(p_v, boundary_check=(0, 1)) - b_v
        if SAVE_NEW_VALUE:
-           p_v_new = tl.make_block_ptr(v_new, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
-           tl.store(p_v_new, b_v_new.to(p_v_new.dtype.element_ty), boundary_check=(0, 1))
+           p_v = tl.make_block_ptr(v_new, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+           tl.store(p_v, b_v.to(p_v.dtype.element_ty), boundary_check=(0, 1))
+       last_idx = min((i_t + 1) * BT, T) - 1
        if USE_G:
-           last_idx = min((i_t + 1) * BT, T) - 1
            b_g_last = tl.load(g + bos * H + last_idx * H + i_h)
            p_g = tl.make_block_ptr(g + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
            b_g = tl.load(p_g, boundary_check=(0,))
-           b_v_new = b_v_new * safe_exp(b_g_last - b_g)[:, None]
+           b_v = b_v * safe_exp(b_g_last - b_g)[:, None]
            b_g_last = exp(b_g_last)
            b_h1 = b_h1 * b_g_last
            if K > 64:
@@ -203,30 +196,64 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
                b_h3 = b_h3 * b_g_last
            if K > 192:
                b_h4 = b_h4 * b_g_last
-       b_v_new = b_v_new.to(k.dtype.element_ty)
+       if USE_GK:
+           o_k1 = tl.arange(0, 64)
+           b_gk_last1 = tl.load(
+               gk + (bos + last_idx) * H * K + i_h * K + o_k1,
+               mask=(o_k1 < K),
+               other=0.0,
+           )
+           b_h1 *= exp(b_gk_last1)[:, None]
+           if K > 64:
+               o_k2 = 64 + o_k1
+               b_gk_last2 = tl.load(
+                   gk + (bos + last_idx) * H * K + i_h * K + o_k2,
+                   mask=(o_k2 < K),
+                   other=0.0,
+               )
+               b_h2 *= exp(b_gk_last2)[:, None]
+           if K > 128:
+               o_k3 = 128 + o_k1
+               b_gk_last3 = tl.load(
+                   gk + (bos + last_idx) * H * K + i_h * K + o_k3,
+                   mask=(o_k3 < K),
+                   other=0.0,
+               )
+               b_h3 *= exp(b_gk_last3)[:, None]
+           if K > 192:
+               o_k4 = 192 + o_k1
+               b_gk_last4 = tl.load(
+                   gk + (bos + last_idx) * H * K + i_h * K + o_k4,
+                   mask=(o_k4 < K),
+                   other=0.0,
+               )
+               b_h4 *= exp(b_gk_last4)[:, None]
+       b_v = b_v.to(k.dtype.element_ty)
        p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (0, i_t * BT), (64, BT), (0, 1))
        b_k = tl.load(p_k, boundary_check=(0, 1))
-       b_h1 += tl.dot(b_k, b_v_new)
+       b_h1 += tl.dot(b_k, b_v)
        if K > 64:
            p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (64, i_t * BT), (64, BT), (0, 1))
            b_k = tl.load(p_k, boundary_check=(0, 1))
-           b_h2 += tl.dot(b_k, b_v_new)
+           b_h2 += tl.dot(b_k, b_v)
        if K > 128:
            p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (128, i_t * BT), (64, BT), (0, 1))
            b_k = tl.load(p_k, boundary_check=(0, 1))
-           b_h3 += tl.dot(b_k, b_v_new)
+           b_h3 += tl.dot(b_k, b_v)
        if K > 192:
            p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1))
            b_k = tl.load(p_k, boundary_check=(0, 1))
-           b_h4 += tl.dot(b_k, b_v_new)
+           b_h4 += tl.dot(b_k, b_v)

    # epilogue
    if STORE_FINAL_STATE:
@@ -254,6 +281,7 @@ def chunk_gated_delta_rule_fwd_h(
    w: torch.Tensor,
    u: torch.Tensor,
    g: Optional[torch.Tensor] = None,
+   gk: Optional[torch.Tensor] = None,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    chunk_size: int = 64,  # SY: remove this argument and force chunk size 64?
@@ -296,6 +324,7 @@ def chunk_gated_delta_rule_fwd_h(
        w=w,
        v_new=v_new,
        g=g,
+       gk=gk,
        h=h,
        h0=initial_state,
        ht=final_state,
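Editor's note: one recurring change in this file is wrapping the flattened pointer offsets in `.to(tl.int64)`. This is presumably guarding against 32-bit overflow of offsets like `(bos * H + i_h) * K * V` on long sequences with large per-head state; a quick back-of-the-envelope check with illustrative numbers (not values from the kernel itself):

# Illustrative only: when would a 32-bit offset of the form (boh * H + i_h) * K * V overflow?
H, K, V = 32, 128, 128           # hypothetical head count and per-head state dims
per_chunk = H * K * V            # 524288 elements of h per chunk position
print((2**31) // per_chunk)      # 4096: beyond ~4K chunk positions the int32 product wraps,
                                 # so promoting the offset to int64 keeps the pointer math safe.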
python/sglang/srt/layers/attention/fla/fused_recurrent.py
(removed lines are prefixed with "-", added lines with "+"; unprefixed lines are unchanged context)

@@ -44,6 +44,7 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
    IS_BETA_HEADWISE: tl.constexpr,  # whether beta is headwise vector or scalar,
    USE_QK_L2NORM_IN_KERNEL: tl.constexpr,
    IS_VARLEN: tl.constexpr,
+   IS_KDA: tl.constexpr,
):
    i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_n, i_hv = i_nh // HV, i_nh % HV
@@ -67,7 +68,11 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
        p_beta = beta + (bos * HV + i_hv) * V + o_v
    else:
        p_beta = beta + bos * HV + i_hv
-   p_g = g + bos * HV + i_hv
+   if not IS_KDA:
+       p_g = g + bos * HV + i_hv
+   else:
+       p_gk = g + (bos * HV + i_hv) * K + o_k
    p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v

    mask_k = o_k < K
@@ -83,14 +88,18 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
        b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32)
        b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
        b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
-       b_g = tl.load(p_g).to(tl.float32)
        if USE_QK_L2NORM_IN_KERNEL:
            b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q)) + 1e-6)
            b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k)) + 1e-6)
        b_q = b_q * scale
        # [BK, BV]
-       b_h *= exp(b_g)
+       if not IS_KDA:
+           b_g = tl.load(p_g).to(tl.float32)
+           b_h *= exp(b_g)
+       else:
+           b_gk = tl.load(p_gk).to(tl.float32)
+           b_h *= exp(b_gk[:, None])
        # [BV]
        b_v -= tl.sum(b_h * b_k[:, None], 0)
        if IS_BETA_HEADWISE:
@@ -108,7 +117,10 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
        p_k += H * K
        p_o += HV * V
        p_v += HV * V
-       p_g += HV
+       if not IS_KDA:
+           p_g += HV
+       else:
+           p_gk += HV * K
        p_beta += HV * (V if IS_BETA_HEADWISE else 1)

    if STORE_FINAL_STATE:
@@ -165,6 +177,7 @@ def fused_recurrent_gated_delta_rule_fwd(
        BV=BV,
        IS_BETA_HEADWISE=beta.ndim == v.ndim,
        USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel,
+       IS_KDA=False,
        num_warps=num_warps,
        num_stages=num_stages,
    )
python/sglang/srt/layers/attention/fla/kda.py (new file, 0 → 100644)

# Adapted from https://github.com/vllm-project/vllm/blob/0384aa7150c4c9778efca041ffd1beb3ad2bd694/vllm/model_executor/layers/fla/ops/kda.py
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
import torch
import torch.nn as nn
import triton
import triton.language as tl

from sglang.srt.layers.attention.fla.chunk_delta_h import chunk_gated_delta_rule_fwd_h
from sglang.srt.layers.attention.fla.cumsum import chunk_local_cumsum
from sglang.srt.layers.attention.fla.fused_recurrent import (
    fused_recurrent_gated_delta_rule_fwd_kernel,
)
from sglang.srt.layers.attention.fla.index import prepare_chunk_indices
from sglang.srt.layers.attention.fla.l2norm import l2norm_fwd
from sglang.srt.layers.attention.fla.op import exp, log
from sglang.srt.layers.attention.fla.solve_tril import solve_tril
from sglang.srt.layers.attention.fla.utils import is_amd

BT_LIST_AUTOTUNE = [32, 64, 128]
NUM_WARPS_AUTOTUNE = [2, 4, 8, 16] if is_amd else [4, 8, 16, 32]


def cdiv(a: int, b: int) -> int:
    """Ceiling division."""
    return -(a // -b)


def next_power_of_2(n: int) -> int:
    """The next power of 2 (inclusive)"""
    if n < 1:
        return 1
    return 1 << (n - 1).bit_length()


def fused_recurrent_kda_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
    scale: float,
    initial_state: torch.Tensor,
    inplace_final_state: bool = True,
    cu_seqlens: torch.LongTensor | None = None,
    # ssm_state_indices: torch.Tensor | None = None,
    num_accepted_tokens: torch.Tensor | None = None,
    use_qk_l2norm_in_kernel: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    B, T, H, K, V = *k.shape, v.shape[-1]
    HV = v.shape[2]
    N = B if cu_seqlens is None else len(cu_seqlens) - 1
    BK, BV = next_power_of_2(K), min(next_power_of_2(V), 8)
    NK, NV = cdiv(K, BK), cdiv(V, BV)
    assert NK == 1, "NK > 1 is not supported yet"
    num_stages = 3
    num_warps = 1

    o = torch.empty_like(k)
    if inplace_final_state:
        final_state = initial_state
    else:
        final_state = q.new_empty(T, HV, K, V, dtype=initial_state.dtype)

    stride_init_state_token = initial_state.stride(0)
    stride_final_state_token = final_state.stride(0)

    # if ssm_state_indices is None:
    #     stride_indices_seq, stride_indices_tok = 1, 1
    # elif ssm_state_indices.ndim == 1:
    #     stride_indices_seq, stride_indices_tok = ssm_state_indices.stride(0), 1
    # else:
    #     stride_indices_seq, stride_indices_tok = ssm_state_indices.stride()

    grid = (NK, NV, N * HV)
    fused_recurrent_gated_delta_rule_fwd_kernel[grid](
        q=q,
        k=k,
        v=v,
        g=g,
        beta=beta,
        o=o,
        h0=initial_state,
        ht=final_state,
        cu_seqlens=cu_seqlens,
        # ssm_state_indices=ssm_state_indices,
        # num_accepted_tokens=num_accepted_tokens,
        scale=scale,
        # N=N,
        T=T,
        B=B,
        H=H,
        HV=HV,
        K=K,
        V=V,
        BK=BK,
        BV=BV,
        # stride_init_state_token=stride_init_state_token,
        # stride_final_state_token=stride_final_state_token,
        # stride_indices_seq=stride_indices_seq,
        # stride_indices_tok=stride_indices_tok,
        IS_BETA_HEADWISE=beta.ndim == v.ndim,
        USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel,
        # INPLACE_FINAL_STATE=inplace_final_state,
        IS_KDA=True,
        num_warps=num_warps,
        num_stages=num_stages,
    )
    return o, final_state


def fused_recurrent_kda(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor = None,
    scale: float = None,
    initial_state: torch.Tensor = None,
    inplace_final_state: bool = True,
    use_qk_l2norm_in_kernel: bool = True,
    cu_seqlens: torch.LongTensor | None = None,
    # ssm_state_indices: torch.LongTensor | None = None,
    **kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
    if cu_seqlens is not None and q.shape[0] != 1:
        raise ValueError(
            f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
            f"Please flatten variable-length inputs before processing."
        )
    if scale is None:
        scale = k.shape[-1] ** -0.5
    o, final_state = fused_recurrent_kda_fwd(
        q=q.contiguous(),
        k=k.contiguous(),
        v=v.contiguous(),
        g=g.contiguous(),
        beta=beta.contiguous(),
        scale=scale,
        initial_state=initial_state,
        inplace_final_state=inplace_final_state,
        cu_seqlens=cu_seqlens,
        # ssm_state_indices=ssm_state_indices,
        num_accepted_tokens=None,
        use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
    )
    return o, final_state


@triton.heuristics(
    {
        "STORE_RESIDUAL_OUT": lambda args: args["residual_out"] is not None,
        "HAS_RESIDUAL": lambda args: args["residual"] is not None,
        "HAS_WEIGHT": lambda args: args["w"] is not None,
        "HAS_BIAS": lambda args: args["b"] is not None,
    }
)
@triton.jit
def layer_norm_gated_fwd_kernel(
    x,  # pointer to the input
    g,  # pointer to the gate
    y,  # pointer to the output
    w,  # pointer to the weights
    b,  # pointer to the biases
    residual,  # pointer to the residual
    residual_out,  # pointer to the residual
    mean,  # pointer to the mean
    rstd,  # pointer to the 1/std
    eps,  # epsilon to avoid division by zero
    T,  # number of rows in x
    D: tl.constexpr,  # number of columns in x
    BT: tl.constexpr,
    BD: tl.constexpr,
    ACTIVATION: tl.constexpr,
    IS_RMS_NORM: tl.constexpr,
    STORE_RESIDUAL_OUT: tl.constexpr,
    HAS_RESIDUAL: tl.constexpr,
    HAS_WEIGHT: tl.constexpr,
    HAS_BIAS: tl.constexpr,
):
    i_t = tl.program_id(0)
    o_d = tl.arange(0, BD)
    m_d = o_d < D

    p_x = tl.make_block_ptr(x, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
    b_x = tl.load(p_x, boundary_check=(0, 1)).to(tl.float32)
    if HAS_RESIDUAL:
        p_res = tl.make_block_ptr(residual, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
        b_x += tl.load(p_res, boundary_check=(0, 1)).to(tl.float32)
    if STORE_RESIDUAL_OUT:
        p_res_out = tl.make_block_ptr(residual_out, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
        tl.store(p_res_out, b_x.to(p_res_out.dtype.element_ty), boundary_check=(0, 1))
    if not IS_RMS_NORM:
        b_mean = tl.sum(b_x, axis=1) / D
        p_mean = tl.make_block_ptr(mean, (T,), (1,), (i_t * BT,), (BT,), (0,))
        tl.store(p_mean, b_mean.to(p_mean.dtype.element_ty), boundary_check=(0,))
        b_xbar = tl.where(m_d[None, :], b_x - b_mean[:, None], 0.0)
        b_var = tl.sum(b_xbar * b_xbar, axis=1) / D
    else:
        b_xbar = tl.where(m_d[None, :], b_x, 0.0)
        b_var = tl.sum(b_xbar * b_xbar, axis=1) / D
    b_rstd = 1 / tl.sqrt(b_var + eps)
    p_rstd = tl.make_block_ptr(rstd, (T,), (1,), (i_t * BT,), (BT,), (0,))
    tl.store(p_rstd, b_rstd.to(p_rstd.dtype.element_ty), boundary_check=(0,))

    if HAS_WEIGHT:
        b_w = tl.load(w + o_d, mask=m_d).to(tl.float32)
    if HAS_BIAS:
        b_b = tl.load(b + o_d, mask=m_d).to(tl.float32)
    b_x_hat = (
        (b_x - b_mean[:, None]) * b_rstd[:, None] if not IS_RMS_NORM else b_x * b_rstd[:, None]
    )
    b_y = b_x_hat * b_w[None, :] if HAS_WEIGHT else b_x_hat
    if HAS_BIAS:
        b_y = b_y + b_b[None, :]

    # swish/sigmoid output gate
    p_g = tl.make_block_ptr(g, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
    b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32)
    if ACTIVATION == "swish" or ACTIVATION == "silu":
        b_y = b_y * b_g * tl.sigmoid(b_g)
    elif ACTIVATION == "sigmoid":
        b_y = b_y * tl.sigmoid(b_g)

    # Write output
    p_y = tl.make_block_ptr(y, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
    tl.store(p_y, b_y.to(p_y.dtype.element_ty), boundary_check=(0, 1))


@triton.heuristics(
    {
        "STORE_RESIDUAL_OUT": lambda args: args["residual_out"] is not None,
        "HAS_RESIDUAL": lambda args: args["residual"] is not None,
        "HAS_WEIGHT": lambda args: args["w"] is not None,
        "HAS_BIAS": lambda args: args["b"] is not None,
    }
)
@triton.jit
def layer_norm_gated_fwd_kernel1(
    x,  # pointer to the input
    g,  # pointer to the gate
    y,  # pointer to the output
    w,  # pointer to the weights
    b,  # pointer to the biases
    residual,  # pointer to the residual
    residual_out,  # pointer to the residual
    mean,  # pointer to the mean
    rstd,  # pointer to the 1/std
    eps,  # epsilon to avoid division by zero
    D: tl.constexpr,  # number of columns in x
    BD: tl.constexpr,
    ACTIVATION: tl.constexpr,
    IS_RMS_NORM: tl.constexpr,
    STORE_RESIDUAL_OUT: tl.constexpr,
    HAS_RESIDUAL: tl.constexpr,
    HAS_WEIGHT: tl.constexpr,
    HAS_BIAS: tl.constexpr,
):
    i_t = tl.program_id(0)
    x += i_t * D
    y += i_t * D
    g += i_t * D
    if HAS_RESIDUAL:
        residual += i_t * D
    if STORE_RESIDUAL_OUT:
        residual_out += i_t * D

    o_d = tl.arange(0, BD)
    m_d = o_d < D
    b_x = tl.load(x + o_d, mask=m_d, other=0.0).to(tl.float32)
    if HAS_RESIDUAL:
        b_x += tl.load(residual + o_d, mask=m_d, other=0.0).to(tl.float32)
    if STORE_RESIDUAL_OUT:
        tl.store(residual_out + o_d, b_x, mask=m_d)
    if not IS_RMS_NORM:
        b_mean = tl.sum(b_x, axis=0) / D
        tl.store(mean + i_t, b_mean)
        b_xbar = tl.where(m_d, b_x - b_mean, 0.0)
        b_var = tl.sum(b_xbar * b_xbar, axis=0) / D
    else:
        b_xbar = tl.where(m_d, b_x, 0.0)
        b_var = tl.sum(b_xbar * b_xbar, axis=0) / D
    b_rstd = 1 / tl.sqrt(b_var + eps)
    tl.store(rstd + i_t, b_rstd)

    if HAS_WEIGHT:
        b_w = tl.load(w + o_d, mask=m_d).to(tl.float32)
    if HAS_BIAS:
        b_b = tl.load(b + o_d, mask=m_d).to(tl.float32)
    b_x_hat = (b_x - b_mean) * b_rstd if not IS_RMS_NORM else b_x * b_rstd
    b_y = b_x_hat * b_w if HAS_WEIGHT else b_x_hat
    if HAS_BIAS:
        b_y = b_y + b_b

    # swish/sigmoid output gate
    b_g = tl.load(g + o_d, mask=m_d, other=0.0).to(tl.float32)
    if ACTIVATION == "swish" or ACTIVATION == "silu":
        b_y = b_y * b_g * tl.sigmoid(b_g)
    elif ACTIVATION == "sigmoid":
        b_y = b_y * tl.sigmoid(b_g)

    # Write output
    tl.store(y + o_d, b_y, mask=m_d)


def layer_norm_gated_fwd(
    x: torch.Tensor,
    g: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    activation: str = "swish",
    eps: float = 1e-5,
    residual: torch.Tensor = None,
    out_dtype: torch.dtype = None,
    residual_dtype: torch.dtype = None,
    is_rms_norm: bool = False,
):
    if residual is not None:
        residual_dtype = residual.dtype
    T, D = x.shape
    if residual is not None:
        assert residual.shape == (T, D)
    if weight is not None:
        assert weight.shape == (D,)
    if bias is not None:
        assert bias.shape == (D,)
    # allocate output
    y = x if out_dtype is None else torch.empty_like(x, dtype=out_dtype)
    if residual is not None or (residual_dtype is not None and residual_dtype != x.dtype):
        residual_out = torch.empty(T, D, device=x.device, dtype=residual_dtype)
    else:
        residual_out = None
    mean = (
        torch.empty((T,), dtype=torch.float, device=x.device) if not is_rms_norm else None
    )
    rstd = torch.empty((T,), dtype=torch.float, device=x.device)
    # Less than 64KB per feature: enqueue fused kernel
    MAX_FUSED_SIZE = 65536 // x.element_size()
    BD = min(MAX_FUSED_SIZE, next_power_of_2(D))
    if D > BD:
        raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
    # heuristics for number of warps
    if D <= 512:
        BT = 32
        layer_norm_gated_fwd_kernel[(cdiv(T, BT),)](
            x=x,
            g=g,
            y=y,
            w=weight,
            b=bias,
            residual=residual,
            residual_out=residual_out,
            mean=mean,
            rstd=rstd,
            eps=eps,
            T=T,
            D=D,
            BD=BD,
            BT=BT,
            ACTIVATION=activation,
            IS_RMS_NORM=is_rms_norm,
            num_warps=4,
        )
    else:
        layer_norm_gated_fwd_kernel1[(T,)](
            x=x,
            g=g,
            y=y,
            w=weight,
            b=bias,
            residual=residual,
            residual_out=residual_out,
            mean=mean,
            rstd=rstd,
            eps=eps,
            D=D,
            BD=BD,
            ACTIVATION=activation,
            IS_RMS_NORM=is_rms_norm,
            num_warps=4,
        )
    # residual_out is None if residual is None and residual_dtype == input_dtype
    return y, mean, rstd, residual_out if residual_out is not None else x


def rms_norm_gated(
    x: torch.Tensor,
    g: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    activation: str = "swish",
    residual: torch.Tensor | None = None,
    prenorm: bool = False,
    residual_in_fp32: bool = False,
    eps: float = 1e-6,
):
    x_shape_og = x.shape
    # reshape input data into 2D tensor
    x = x.contiguous().reshape(-1, x.shape[-1])
    g = g.contiguous().reshape(-1, g.shape[-1])
    if residual is not None:
        assert residual.shape == x_shape_og
        residual = residual.contiguous().reshape(-1, residual.shape[-1])
    residual_dtype = (
        residual.dtype
        if residual is not None
        else (torch.float if residual_in_fp32 else None)
    )
    y, _, _, residual_out = layer_norm_gated_fwd(
        x=x,
        g=g,
        weight=weight,
        bias=bias,
        activation=activation,
        eps=eps,
        residual=residual,
        residual_dtype=residual_dtype,
        is_rms_norm=True,
    )
    y = y.reshape(x_shape_og)
    return y if not prenorm else (y, residual_out.reshape(x_shape_og))


class FusedRMSNormGated(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        elementwise_affine: bool = True,
        eps: float = 1e-5,
        activation: str = "swish",
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.hidden_size = hidden_size
        self.elementwise_affine = elementwise_affine
        self.eps = eps
        self.activation = activation
        if self.activation not in ["swish", "silu", "sigmoid"]:
            raise ValueError(f"Unsupported activation: {self.activation}")
        if elementwise_affine:
            self.weight = nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        else:
            self.register_parameter("weight", None)
        self.register_parameter("bias", None)

    def forward(
        self,
        x: torch.Tensor,
        g: torch.Tensor,
        residual: torch.Tensor | None = None,
        prenorm: bool = False,
        residual_in_fp32: bool = False,
    ) -> torch.Tensor:
        return rms_norm_gated(
            x,
            g,
            self.weight,
            self.bias,
            self.activation,
            residual=residual,
            eps=self.eps,
            prenorm=prenorm,
            residual_in_fp32=residual_in_fp32,
        )


@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.autotune(
    configs=[
        triton.Config({"BK": BK}, num_warps=num_warps, num_stages=num_stages)
        for BK in [32, 64]
        for num_warps in [1, 2, 4, 8]
        for num_stages in [2, 3, 4]
    ],
    key=["BC"],
)
@triton.jit(do_not_specialize=["T"])
def chunk_kda_scaled_dot_kkt_fwd_kernel_intra_sub_inter(
    q,
    k,
    g,
    beta,
    A,
    Aqk,
    scale,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    BT: tl.constexpr,
    BC: tl.constexpr,
    BK: tl.constexpr,
    NC: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    i_t, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H
    i_i, i_j = i_c // NC, i_c % NC
    if IS_VARLEN:
        i_n, i_t = (
            tl.load(chunk_indices + i_t * 2).to(tl.int32),
            tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
        )
        bos, eos = (
            tl.load(cu_seqlens + i_n).to(tl.int32),
            tl.load(cu_seqlens + i_n + 1).to(tl.int32),
        )
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T
    if i_t * BT + i_i * BC >= T:
        return
    if i_i <= i_j:
        return

    q += (bos * H + i_h) * K
    k += (bos * H + i_h) * K
    g += (bos * H + i_h) * K
    A += (bos * H + i_h) * BT
    Aqk += (bos * H + i_h) * BT

    p_b = tl.make_block_ptr(beta + bos * H + i_h, (T,), (H,), (i_t * BT + i_i * BC,), (BC,), (0,))
    b_b = tl.load(p_b, boundary_check=(0,))

    b_A = tl.zeros([BC, BC], dtype=tl.float32)
    b_Aqk = tl.zeros([BC, BC], dtype=tl.float32)
    for i_k in range(tl.cdiv(K, BK)):
        p_q = tl.make_block_ptr(q, (T, K), (H * K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
        p_k = tl.make_block_ptr(k, (T, K), (H * K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
        p_g = tl.make_block_ptr(g, (T, K), (H * K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
        b_kt = tl.make_block_ptr(k, (K, T), (1, H * K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))
        p_gk = tl.make_block_ptr(g, (K, T), (1, H * K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))
        o_k = i_k * BK + tl.arange(0, BK)
        m_k = o_k < K
        # [BK,]
        b_gn = tl.load(g + (i_t * BT + i_i * BC) * H * K + o_k, mask=m_k, other=0)
        # [BC, BK]
        b_g = tl.load(p_g, boundary_check=(0, 1))
        b_k = tl.load(p_k, boundary_check=(0, 1)) * exp(b_g - b_gn[None, :])
        # [BK, BC]
        b_gk = tl.load(p_gk, boundary_check=(0, 1))
        b_kt = tl.load(b_kt, boundary_check=(0, 1))
        # [BC, BC]
        b_ktg = b_kt * exp(b_gn[:, None] - b_gk)
        b_A += tl.dot(b_k, b_ktg)
        b_q = tl.load(p_q, boundary_check=(0, 1))
        b_qg = b_q * exp(b_g - b_gn[None, :]) * scale
        b_Aqk += tl.dot(b_qg, b_ktg)

    b_A *= b_b[:, None]
    p_A = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
    tl.store(p_A, b_A.to(A.dtype.element_ty), boundary_check=(0, 1))
    p_Aqk = tl.make_block_ptr(Aqk, (T, BT), (H * BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
    tl.store(p_Aqk, b_Aqk.to(Aqk.dtype.element_ty), boundary_check=(0, 1))


@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.autotune(
    configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8]],
    key=["BK", "BT"],
)
@triton.jit(do_not_specialize=["T"])
def chunk_kda_scaled_dot_kkt_fwd_kernel_intra_sub_intra(
    q,
    k,
    g,
    beta,
    A,
    Aqk,
    scale,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    BT: tl.constexpr,
    BC: tl.constexpr,
    BK: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    i_t, i_i, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        i_n, i_t = (
            tl.load(chunk_indices + i_t * 2).to(tl.int32),
            tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
        )
        bos, eos = (
            tl.load(cu_seqlens + i_n).to(tl.int32),
            tl.load(cu_seqlens + i_n + 1).to(tl.int32),
        )
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T
    if i_t * BT + i_i * BC >= T:
        return

    o_i = tl.arange(0, BC)
    o_k = tl.arange(0, BK)
    m_k = o_k < K
    m_A = (i_t * BT + i_i * BC + o_i) < T
    o_A = (bos + i_t * BT + i_i * BC + o_i) * H * BT + i_h * BT + i_i * BC
    p_q = tl.make_block_ptr(
        q + (bos * H + i_h) * K,
        (T, K),
        (H * K, 1),
        (i_t * BT + i_i * BC, 0),
        (BC, BK),
        (1, 0),
    )
    p_k = tl.make_block_ptr(
        k + (bos * H + i_h) * K,
        (T, K),
        (H * K, 1),
        (i_t * BT + i_i * BC, 0),
        (BC, BK),
        (1, 0),
    )
    p_g = tl.make_block_ptr(
        g + (bos * H + i_h) * K,
        (T, K),
        (H * K, 1),
        (i_t * BT + i_i * BC, 0),
        (BC, BK),
        (1, 0),
    )
    b_q = tl.load(p_q, boundary_check=(0, 1))
    b_k = tl.load(p_k, boundary_check=(0, 1))
    b_g = tl.load(p_g, boundary_check=(0, 1))

    p_b = beta + (bos + i_t * BT + i_i * BC + o_i) * H + i_h
    b_k = b_k * tl.load(p_b, mask=m_A, other=0)[:, None]

    p_kt = k + (bos + i_t * BT + i_i * BC) * H * K + i_h * K + o_k
    p_gk = g + (bos + i_t * BT + i_i * BC) * H * K + i_h * K + o_k
    for j in range(0, min(BC, T - i_t * BT - i_i * BC)):
        b_kt = tl.load(p_kt, mask=m_k, other=0).to(tl.float32)
        b_gk = tl.load(p_gk, mask=m_k, other=0).to(tl.float32)
        b_ktg = b_kt[None, :] * exp(b_g - b_gk[None, :])
        b_A = tl.sum(b_k * b_ktg, 1)
        b_A = tl.where(o_i > j, b_A, 0.0)
        b_Aqk = tl.sum(b_q * b_ktg, 1)
        b_Aqk = tl.where(o_i >= j, b_Aqk * scale, 0.0)
        tl.store(A + o_A + j, b_A, mask=m_A)
        tl.store(Aqk + o_A + j, b_Aqk, mask=m_A)
        p_kt += H * K
        p_gk += H * K


def chunk_kda_scaled_dot_kkt_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    gk: torch.Tensor | None = None,
    beta: torch.Tensor | None = None,
    scale: float | None = None,
    cu_seqlens: torch.LongTensor | None = None,
    chunk_size: int = 64,
    output_dtype: torch.dtype = torch.float32,
) -> tuple[torch.Tensor, torch.Tensor]:
    r"""
    Compute beta * K * K^T.

    Args:
        k (torch.Tensor):
            The key tensor of shape `[B, T, H, K]`.
        beta (torch.Tensor):
            The beta tensor of shape `[B, T, H]`.
        gk (torch.Tensor):
            The cumulative sum of the gate tensor of shape `[B, T, H, K]` applied to the key tensor. Default: `None`.
        cu_seqlens (torch.LongTensor):
            The cumulative sequence lengths of the input tensor.
            Default: None
        chunk_size (int):
            The chunk size. Default: 64.
        output_dtype (torch.dtype):
            The dtype of the output tensor. Default: `torch.float32`

    Returns:
        beta * K * K^T of shape `[B, T, H, BT]` where `BT` is the chunk size.
    """
    B, T, H, K = k.shape
    assert K <= 256
    BT = chunk_size
    chunk_indices = (
        prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
    )
    NT = cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
    BC = min(16, BT)
    NC = cdiv(BT, BC)
    BK = max(next_power_of_2(K), 16)
    A = torch.zeros(B, T, H, BT, device=k.device, dtype=output_dtype)
    Aqk = torch.zeros(B, T, H, BT, device=k.device, dtype=output_dtype)
    grid = (NT, NC * NC, B * H)
    chunk_kda_scaled_dot_kkt_fwd_kernel_intra_sub_inter[grid](
        q=q,
        k=k,
        g=gk,
        beta=beta,
        A=A,
        Aqk=Aqk,
        scale=scale,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        K=K,
        BT=BT,
        BC=BC,
        NC=NC,
    )
    grid = (NT, NC, B * H)
    chunk_kda_scaled_dot_kkt_fwd_kernel_intra_sub_intra[grid](
        q=q,
        k=k,
        g=gk,
        beta=beta,
        A=A,
        Aqk=Aqk,
        scale=scale,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        K=K,
        BT=BT,
        BC=BC,
        BK=BK,
    )
    return A, Aqk


@triton.heuristics(
    {
        "STORE_QG": lambda args: args["qg"] is not None,
        "STORE_KG": lambda args: args["kg"] is not None,
        "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
    }
)
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [2, 4, 8]
        for num_stages in [2, 3, 4]
    ],
    key=["H", "K", "V", "BT", "BK", "BV", "IS_VARLEN"],
)
@triton.jit(do_not_specialize=["T"])
def recompute_w_u_fwd_kernel(
    q,
    k,
    qg,
    kg,
    v,
    beta,
    w,
    u,
    A,
    gk,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    STORE_QG: tl.constexpr,
    STORE_KG: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    DOT_PRECISION: tl.constexpr,
):
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        i_n, i_t = (
            tl.load(chunk_indices + i_t * 2).to(tl.int32),
            tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
        )
        bos, eos = (
            tl.load(cu_seqlens + i_n).to(tl.int32),
            tl.load(cu_seqlens + i_n + 1).to(tl.int32),
        )
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T
    p_b = tl.make_block_ptr(beta + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
    b_b = tl.load(p_b, boundary_check=(0,))

    p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    b_A = tl.load(p_A, boundary_check=(0, 1))

    for i_v in range(tl.cdiv(V, BV)):
        p_v = tl.make_block_ptr(
            v + (bos * H + i_h) * V,
            (T, V),
            (H * V, 1),
            (i_t * BT, i_v * BV),
            (BT, BV),
            (1, 0),
        )
        p_u = tl.make_block_ptr(
            u + (bos * H + i_h) * V,
            (T, V),
            (H * V, 1),
            (i_t * BT, i_v * BV),
            (BT, BV),
            (1, 0),
        )
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_vb = (b_v * b_b[:, None]).to(b_v.dtype)
        b_u = tl.dot(b_A, b_vb, input_precision=DOT_PRECISION)
        tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1))

    for i_k in range(tl.cdiv(K, BK)):
        p_w = tl.make_block_ptr(
            w + (bos * H + i_h) * K,
            (T, K),
            (H * K, 1),
            (i_t * BT, i_k * BK),
            (BT, BK),
            (1, 0),
        )
        p_k = tl.make_block_ptr(
            k + (bos * H + i_h) * K,
            (T, K),
            (H * K, 1),
            (i_t * BT, i_k * BK),
            (BT, BK),
            (1, 0),
        )
        b_k = tl.load(p_k, boundary_check=(0, 1))
        b_kb = b_k * b_b[:, None]
        p_gk = tl.make_block_ptr(
            gk + (bos * H + i_h) * K,
            (T, K),
            (H * K, 1),
            (i_t * BT, i_k * BK),
            (BT, BK),
            (1, 0),
        )
        b_gk = tl.load(p_gk, boundary_check=(0, 1))
        b_kb *= exp(b_gk)
        if STORE_QG:
            p_q = tl.make_block_ptr(
                q + (bos * H + i_h) * K,
                (T, K),
                (H * K, 1),
                (i_t * BT, i_k * BK),
                (BT, BK),
                (1, 0),
            )
            p_qg = tl.make_block_ptr(
                qg + (bos * H + i_h) * K,
                (T, K),
                (H * K, 1),
                (i_t * BT, i_k * BK),
                (BT, BK),
                (1, 0),
            )
            b_q = tl.load(p_q, boundary_check=(0, 1))
            b_qg = b_q * exp(b_gk)
            tl.store(p_qg, b_qg.to(p_qg.dtype.element_ty), boundary_check=(0, 1))
        if STORE_KG:
            last_idx = min(i_t * BT + BT, T) - 1
            o_k = i_k * BK + tl.arange(0, BK)
            m_k = o_k < K
            b_gn = tl.load(
                gk + ((bos + last_idx) * H + i_h) * K + o_k,
                mask=m_k,
                other=0.0,
            )
            b_kg = b_k * exp(b_gn - b_gk)
            p_kg = tl.make_block_ptr(
                kg + (bos * H + i_h) * K,
                (T, K),
                (H * K, 1),
                (i_t * BT, i_k * BK),
                (BT, BK),
                (1, 0),
            )
            tl.store(p_kg, b_kg.to(p_kg.dtype.element_ty), boundary_check=(0, 1))
        b_w = tl.dot(b_A, b_kb.to(b_k.dtype))
        tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))


def recompute_w_u_fwd(
    k: torch.Tensor,
    v: torch.Tensor,
    beta: torch.Tensor,
    A: torch.Tensor,
    q: torch.Tensor | None = None,
    gk: torch.Tensor | None = None,
    cu_seqlens: torch.LongTensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    B, T, H, K, V = *k.shape, v.shape[-1]
    BT = A.shape[-1]
    BK = 64
    BV = 64
    chunk_indices = (
        prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
    )
    NT = cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
    w = torch.empty_like(k)
    u = torch.empty_like(v)
    kg = torch.empty_like(k) if gk is not None else None
    recompute_w_u_fwd_kernel[(NT, B * H)](
        q=q,
        k=k,
        qg=None,
        kg=kg,
        v=v,
        beta=beta,
        w=w,
        u=u,
        A=A,
        gk=gk,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BK=BK,
        BV=BV,
        DOT_PRECISION="ieee",
    )
    return w, u, None, kg


@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.autotune(
    configs=[
        triton.Config({"BK": BK, "BV": BV}, num_warps=num_warps, num_stages=num_stages)
        for BK in [32, 64]
        for BV in [64, 128]
        for num_warps in [2, 4, 8]
        for num_stages in [2, 3, 4]
    ],
    key=["BT"],
)
@triton.jit(do_not_specialize=["T"])
def chunk_gla_fwd_kernel_o(
    q,
    v,
    g,
    h,
    o,
    A,
    cu_seqlens,
    chunk_indices,
    scale,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        i_tg = i_t
        i_n, i_t = (
            tl.load(chunk_indices + i_t * 2).to(tl.int32),
            tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
        )
        bos, eos = (
            tl.load(cu_seqlens + i_n).to(tl.int32),
            tl.load(cu_seqlens + i_n + 1).to(tl.int32),
        )
        T = eos - bos
        NT = tl.cdiv(T, BT)
    else:
        NT = tl.cdiv(T, BT)
        i_tg = i_b * NT + i_t
        bos, eos = i_b * T, i_b * T + T

    m_s = tl.arange(0, BT)[:, None] >= tl.arange(0, BT)[None, :]
    b_o = tl.zeros([BT, BV], dtype=tl.float32)
    for i_k in range(tl.cdiv(K, BK)):
        p_q = tl.make_block_ptr(
            q + (bos * H + i_h) * K,
            (T, K),
            (H * K, 1),
            (i_t * BT, i_k * BK),
            (BT, BK),
            (1, 0),
        )
        p_g = tl.make_block_ptr(
            g + (bos * H + i_h) * K,
            (T, K),
            (H * K, 1),
            (i_t * BT, i_k * BK),
            (BT, BK),
            (1, 0),
        )
        p_h = tl.make_block_ptr(
            h + (i_tg * H + i_h) * K * V,
            (K, V),
            (V, 1),
            (i_k * BK, i_v * BV),
            (BK, BV),
            (1, 0),
        )
        # [BT, BK]
        b_q = tl.load(p_q, boundary_check=(0, 1))
        b_q = (b_q * scale).to(b_q.dtype)
        # [BT, BK]
        b_g = tl.load(p_g, boundary_check=(0, 1))
        # [BT, BK]
        b_qg = (b_q * exp(b_g)).to(b_q.dtype)
        # [BK, BV]
        b_h = tl.load(p_h, boundary_check=(0, 1))
        # works but dkw, owing to divine benevolence
        # [BT, BV]
        if i_k >= 0:
            b_o += tl.dot(b_qg, b_h.to(b_qg.dtype))
    p_v = tl.make_block_ptr(
        v + (bos * H + i_h) * V,
        (T, V),
        (H * V, 1),
        (i_t * BT, i_v * BV),
        (BT, BV),
        (1, 0),
    )
    p_o = tl.make_block_ptr(
        o + (bos * H + i_h) * V,
        (T, V),
        (H * V, 1),
        (i_t * BT, i_v * BV),
        (BT, BV),
        (1, 0),
    )
    p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    # [BT, BV]
    b_v = tl.load(p_v, boundary_check=(0, 1))
    # [BT, BT]
    b_A = tl.load(p_A, boundary_check=(0, 1))
    b_A = tl.where(m_s, b_A, 0.0).to(b_v.dtype)
    b_o += tl.dot(b_A, b_v, allow_tf32=False)
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))


def chunk_gla_fwd_o_gk(
    q: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    A: torch.Tensor,
    h: torch.Tensor,
    o: torch.Tensor,
    scale: float,
    cu_seqlens: torch.LongTensor | None = None,
    chunk_size: int = 64,
):
    B, T, H, K, V = *q.shape, v.shape[-1]
    BT = chunk_size
    chunk_indices = (
        prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None
    )
    NT = cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)

    def grid(meta):
        return (cdiv(V, meta["BV"]), NT, B * H)

    chunk_gla_fwd_kernel_o[grid](
        q=q,
        v=v,
        g=g,
        h=h,
        o=o,
        A=A,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        scale=scale,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
    )
    return o


def chunk_kda_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
    scale: float,
    initial_state: torch.Tensor,
    output_final_state: bool,
    cu_seqlens: torch.LongTensor | None = None,
):
    chunk_size = 64
    g = chunk_local_cumsum(g, chunk_size=chunk_size, cu_seqlens=cu_seqlens)
    # the intra Aqk is kept in fp32
    # the computation has very marginal effect on the entire throughput
    A, Aqk = chunk_kda_scaled_dot_kkt_fwd(
        q=q,
        k=k,
        gk=g,
        beta=beta,
        scale=scale,
        cu_seqlens=cu_seqlens,
        output_dtype=torch.float32,
    )
    A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype)
    w, u, _, kg = recompute_w_u_fwd(
        k=k,
        v=v,
        beta=beta,
        A=A,
        gk=g,
        cu_seqlens=cu_seqlens,
    )
    del A
    h, v_new, final_state = chunk_gated_delta_rule_fwd_h(
        k=kg,
        w=w,
        u=u,
        gk=g,
        initial_state=initial_state,
        output_final_state=output_final_state,
        cu_seqlens=cu_seqlens,
    )
    del w, u, kg
    o = chunk_gla_fwd_o_gk(
        q=q,
        v=v_new,
        g=g,
        A=Aqk,
        h=h,
        o=v,
        scale=scale,
        cu_seqlens=cu_seqlens,
        chunk_size=chunk_size,
    )
    del Aqk, v_new, h
    return o, final_state


def chunk_kda(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
    scale: float = None,
    initial_state: torch.Tensor = None,
    output_final_state: bool = False,
    use_qk_l2norm_in_kernel: bool = False,
    cu_seqlens: torch.LongTensor | None = None,
    **kwargs,
):
    if scale is None:
        scale = k.shape[-1] ** -0.5
    if use_qk_l2norm_in_kernel:
        q = l2norm_fwd(q.contiguous())
        k = l2norm_fwd(k.contiguous())
    o, final_state = chunk_kda_fwd(
        q=q,
        k=k,
        v=v.contiguous(),
        g=g.contiguous(),
        beta=beta.contiguous(),
        scale=scale,
        initial_state=initial_state.contiguous(),
        output_final_state=output_final_state,
        cu_seqlens=cu_seqlens,
    )
    return o, final_state


@triton.autotune(
    configs=[
        triton.Config({"BT": bt}, num_warps=nw, num_stages=ns)
        for bt in BT_LIST_AUTOTUNE
        for nw in NUM_WARPS_AUTOTUNE
        for ns in [2, 3]
    ],
    key=["H", "D"],
)
@triton.jit
def kda_gate_fwd_kernel(
    g,
    A,
    y,
    g_bias,
    beta: tl.constexpr,
    threshold: tl.constexpr,
    T,
    H,
    D: tl.constexpr,
    BT: tl.constexpr,
    BD: tl.constexpr,
    HAS_BIAS: tl.constexpr,
):
    i_t, i_h = tl.program_id(0), tl.program_id(1)
    n_t = i_t * BT

    b_a = tl.load(A + i_h).to(tl.float32)
    b_a = -tl.exp(b_a)

    stride_row = H * D
    stride_col = 1

    g_ptr = tl.make_block_ptr(
        base=g + i_h * D,
        shape=(T, D),
        strides=(stride_row, stride_col),
        offsets=(n_t, 0),
        block_shape=(BT, BD),
        order=(1, 0),
    )
    y_ptr = tl.make_block_ptr(
        base=y + i_h * D,
        shape=(T, D),
        strides=(stride_row, stride_col),
        offsets=(n_t, 0),
        block_shape=(BT, BD),
        order=(1, 0),
    )

    b_g = tl.load(g_ptr, boundary_check=(0, 1)).to(tl.float32)
    if HAS_BIAS:
        n_d = tl.arange(0, BD)
        bias_mask = n_d < D
        b_bias = tl.load(g_bias + i_h * D + n_d, mask=bias_mask, other=0.0).to(tl.float32)
        b_g = b_g + b_bias[None, :]

    # softplus(x, beta) = (1/beta) * log(1 + exp(beta * x))
    # When beta * x > threshold, use linear approximation x
    # Use threshold to switch to linear when beta*x > threshold
    g_scaled = b_g * beta
    use_linear = g_scaled > threshold
    sp = tl.where(use_linear, b_g, (1.0 / beta) * log(1.0 + tl.exp(g_scaled)))

    b_y = b_a * sp
    tl.store(y_ptr, b_y.to(y.dtype.element_ty), boundary_check=(0, 1))


def fused_kda_gate(
    g: torch.Tensor,
    A: torch.Tensor,
    head_k_dim: int,
    g_bias: torch.Tensor | None = None,
    beta: float = 1.0,
    threshold: float = 20.0,
) -> torch.Tensor:
    """
    Forward pass for KDA gate:

    input g: [..., H*D]
    param A: [H] or [1, 1, H, 1]
    beta: softplus beta parameter
    threshold: softplus threshold parameter

    return : [..., H, D]
    """
    orig_shape = g.shape[:-1]
    g = g.view(-1, g.shape[-1])
    T = g.shape[0]
    HD = g.shape[1]
    H = A.numel()
    assert H * head_k_dim == HD

    y = torch.empty_like(g, dtype=torch.float32)

    def grid(meta):
        return (cdiv(T, meta["BT"]), H)

    kda_gate_fwd_kernel[grid](
        g,
        A,
        y,
        g_bias,
        beta,
        threshold,
        T,
        H,
        head_k_dim,
        BD=next_power_of_2(head_k_dim),
        HAS_BIAS=g_bias is not None,
    )
    y = y.view(*orig_shape, H, head_k_dim)
    return y
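Editor's note: as a sanity reference for `fused_kda_gate`, the same math can be written unfused in a few lines of plain PyTorch: the gate output is `-exp(A)` times a thresholded softplus of `g (+ bias)`, reshaped to `[..., H, D]`. The sketch below is not part of the commit; it assumes `A` is the log-space per-head parameter described in the docstring and uses `torch.nn.functional.softplus`, whose `beta`/`threshold` behavior matches the kernel's linear fallback.

import torch
import torch.nn.functional as F

def kda_gate_reference(
    g: torch.Tensor,                      # [..., H*D]
    A: torch.Tensor,                      # [H] or [1, 1, H, 1], log-space parameter
    head_k_dim: int,
    g_bias: torch.Tensor | None = None,   # [H*D]
    beta: float = 1.0,
    threshold: float = 20.0,
) -> torch.Tensor:
    H = A.numel()
    x = g.float()
    if g_bias is not None:
        x = x + g_bias.float()
    # thresholded softplus, mirroring the kernel's switch to the linear branch when beta*x > threshold
    sp = F.softplus(x, beta=beta, threshold=threshold)
    # broadcast -exp(A[h]) across the D gate channels of each head
    a = -torch.exp(A.float().flatten()).repeat_interleave(head_k_dim)
    return (a * sp).view(*g.shape[:-1], H, head_k_dim)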
python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py
View file @ a4bf5c6a
from typing import Optional, Union

import torch
from einops import rearrange

from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.attention.fla.chunk import chunk_gated_delta_rule
...
...
@@ -10,6 +11,11 @@ from sglang.srt.layers.attention.fla.fused_recurrent import (
from sglang.srt.layers.attention.fla.fused_sigmoid_gating_recurrent import (
    fused_sigmoid_gating_delta_rule_update,
)
from sglang.srt.layers.attention.fla.kda import (
    chunk_kda,
    fused_kda_gate,
    fused_recurrent_kda,
)
from sglang.srt.layers.attention.mamba.causal_conv1d_triton import (
    PAD_SLOT_ID,
    causal_conv1d_fn,
...
...
@@ -227,6 +233,223 @@ class MambaAttnBackendBase(AttentionBackend):
        return 1  # Mamba attn does not use seq lens to index kv cache


class KimiLinearAttnBackend(MambaAttnBackendBase):
    """Attention backend using Mamba kernel."""
    def forward_decode(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache: bool = True,
        **kwargs,
    ):
        q_proj_states = kwargs["q_proj_states"]
        k_proj_states = kwargs["k_proj_states"]
        v_proj_states = kwargs["v_proj_states"]
        q_conv_weights = kwargs["q_conv_weights"]
        k_conv_weights = kwargs["k_conv_weights"]
        v_conv_weights = kwargs["v_conv_weights"]
        q_conv_bias = kwargs["q_conv_bias"]
        k_conv_bias = kwargs["k_conv_bias"]
        v_conv_bias = kwargs["v_conv_bias"]
        A_log = kwargs["A_log"]
        dt_bias = kwargs["dt_bias"]
        b_proj = kwargs["b_proj"]
        f_a_proj = kwargs["f_a_proj"]
        f_b_proj = kwargs["f_b_proj"]
        hidden_states = kwargs["hidden_states"]
        head_dim = kwargs["head_dim"]
        layer_id = kwargs["layer_id"]

        layer_cache = self.req_to_token_pool.mamba2_layer_cache(layer_id)
        q_conv_state, k_conv_state, v_conv_state = layer_cache.conv
        ssm_states = layer_cache.temporal
        query_start_loc = self.forward_metadata.query_start_loc
        cache_indices = self.forward_metadata.mamba_cache_indices

        q_conv_state = q_conv_state.transpose(-1, -2)
        k_conv_state = k_conv_state.transpose(-1, -2)
        v_conv_state = v_conv_state.transpose(-1, -2)

        q = causal_conv1d_update(
            q_proj_states,
            q_conv_state,
            q_conv_weights,
            q_conv_bias,
            activation="silu",
            conv_state_indices=cache_indices,
        )
        k = causal_conv1d_update(
            k_proj_states,
            k_conv_state,
            k_conv_weights,
            k_conv_bias,
            activation="silu",
            conv_state_indices=cache_indices,
        )
        v = causal_conv1d_update(
            v_proj_states,
            v_conv_state,
            v_conv_weights,
            v_conv_bias,
            activation="silu",
            conv_state_indices=cache_indices,
        )

        q, k, v = map(
            lambda x: rearrange(x, "n (h d) -> 1 n h d", d=head_dim), (q, k, v)
        )

        beta = b_proj(hidden_states)[0].float().sigmoid()
        g = f_b_proj(f_a_proj(hidden_states)[0])[0]
        g = fused_kda_gate(g, A_log, head_dim, g_bias=dt_bias)
        beta = beta.unsqueeze(0)
        g = g.unsqueeze(0)

        initial_state = ssm_states[cache_indices].contiguous()
        (
            core_attn_out,
            last_recurrent_state,
        ) = fused_recurrent_kda(
            q=q,
            k=k,
            v=v,
            g=g,
            beta=beta,
            initial_state=initial_state,
            use_qk_l2norm_in_kernel=True,
            cu_seqlens=query_start_loc,
        )
        ssm_states[cache_indices] = last_recurrent_state

        return core_attn_out
    def forward_extend(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache: bool = True,
        **kwargs,
    ):
        from sglang.srt.layers.attention.mamba.causal_conv1d_triton import (
            causal_conv1d_fn,
        )

        q_proj_states = kwargs["q_proj_states"]
        k_proj_states = kwargs["k_proj_states"]
        v_proj_states = kwargs["v_proj_states"]
        q_conv_weights = kwargs["q_conv_weights"]
        k_conv_weights = kwargs["k_conv_weights"]
        v_conv_weights = kwargs["v_conv_weights"]
        q_conv_bias = kwargs["q_conv_bias"]
        k_conv_bias = kwargs["k_conv_bias"]
        v_conv_bias = kwargs["v_conv_bias"]
        A_log = kwargs["A_log"]
        dt_bias = kwargs["dt_bias"]
        b_proj = kwargs["b_proj"]
        f_a_proj = kwargs["f_a_proj"]
        f_b_proj = kwargs["f_b_proj"]
        hidden_states = kwargs["hidden_states"]
        head_dim = kwargs["head_dim"]
        layer_id = kwargs["layer_id"]

        query_start_loc = self.forward_metadata.query_start_loc
        cache_indices = self.forward_metadata.mamba_cache_indices

        mamba_cache_params = self.req_to_token_pool.mamba2_layer_cache(layer_id)
        conv_state_q, conv_state_k, conv_state_v = mamba_cache_params.conv
        # deal with strides
        conv_state_q = conv_state_q.transpose(-1, -2)
        conv_state_k = conv_state_k.transpose(-1, -2)
        conv_state_v = conv_state_v.transpose(-1, -2)
        ssm_states = mamba_cache_params.temporal

        has_initial_state = forward_batch.extend_prefix_lens > 0

        q_proj_states = q_proj_states.transpose(0, 1)
        k_proj_states = k_proj_states.transpose(0, 1)
        v_proj_states = v_proj_states.transpose(0, 1)

        q = causal_conv1d_fn(
            q_proj_states,
            q_conv_weights,
            q_conv_bias,
            activation="silu",
            conv_states=conv_state_q,
            has_initial_state=has_initial_state,
            cache_indices=cache_indices,
            query_start_loc=query_start_loc,
            seq_lens_cpu=forward_batch.extend_seq_lens_cpu,
        ).transpose(0, 1)
        k = causal_conv1d_fn(
            k_proj_states,
            k_conv_weights,
            k_conv_bias,
            activation="silu",
            conv_states=conv_state_k,
            has_initial_state=has_initial_state,
            cache_indices=cache_indices,
            query_start_loc=query_start_loc,
            seq_lens_cpu=forward_batch.extend_seq_lens_cpu,
        ).transpose(0, 1)
        v = causal_conv1d_fn(
            v_proj_states,
            v_conv_weights,
            v_conv_bias,
            activation="silu",
            conv_states=conv_state_v,
            has_initial_state=has_initial_state,
            cache_indices=cache_indices,
            query_start_loc=query_start_loc,
            seq_lens_cpu=forward_batch.extend_seq_lens_cpu,
        ).transpose(0, 1)

        q, k, v = map(
            lambda x: rearrange(x, "n (h d) -> 1 n h d", d=head_dim), (q, k, v)
        )

        beta = b_proj(hidden_states)[0].float().sigmoid()
        g = f_b_proj(f_a_proj(hidden_states)[0])[0]
        g = fused_kda_gate(g, A_log, head_dim, g_bias=dt_bias)
        beta = beta.unsqueeze(0)
        g = g.unsqueeze(0)

        initial_state = ssm_states[cache_indices].contiguous()
        (
            core_attn_out,
            last_recurrent_state,
        ) = chunk_kda(
            q=q,
            k=k,
            v=v,
            g=g,
            beta=beta,
            initial_state=initial_state,
            output_final_state=True,
            use_qk_l2norm_in_kernel=True,
            cu_seqlens=query_start_loc,
        )
        ssm_states[cache_indices] = last_recurrent_state

        return core_attn_out


class GDNAttnBackend(MambaAttnBackendBase):
    """Attention backend using Mamba kernel."""
...
...
python/sglang/srt/layers/attention/triton_backend.py
View file @ a4bf5c6a
...
...
@@ -92,7 +92,10 @@ class TritonAttnBackend(AttentionBackend):
        self.num_kv_head = model_runner.model_config.get_num_kv_heads(
            get_attention_tp_size()
        )
-        if model_runner.hybrid_gdn_config is not None:
+        if (
+            model_runner.hybrid_gdn_config is not None
+            or model_runner.kimi_linear_config is not None
+        ):
            # For hybrid linear models, layer_id = 0 may not be full attention
            self.v_head_dim = model_runner.token_to_kv_pool.get_v_head_dim()
        else:
...
...
python/sglang/srt/mem_cache/memory_pool.py
View file @ a4bf5c6a
...
...
@@ -17,7 +17,7 @@ from __future__ import annotations
from dataclasses import dataclass

-from sglang.srt.configs.mamba_utils import Mamba2CacheParams
+from sglang.srt.configs.mamba_utils import KimiLinearCacheParams, Mamba2CacheParams
from sglang.srt.layers.attention.nsa import index_buf_accessor
from sglang.srt.layers.attention.nsa.quant_k_cache import quantize_k_cache
from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter
...
...
@@ -33,7 +33,7 @@ KVCache actually holds the physical kv cache.
import abc
import logging
-from contextlib import nullcontext
+from contextlib import contextmanager, nullcontext
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

import numpy as np
...
...
@@ -59,7 +59,9 @@ if _is_npu:
    import torch_npu


-def get_tensor_size_bytes(t: torch.Tensor):
+def get_tensor_size_bytes(t: Union[torch.Tensor, List[torch.Tensor]]):
+    if isinstance(t, list):
+        return sum(get_tensor_size_bytes(x) for x in t)
    return np.prod(t.shape) * t.dtype.itemsize
...
...
@@ -116,10 +118,15 @@ class ReqToTokenPool:
class MambaPool:
    @dataclass(frozen=True, kw_only=True)
    class State:
-        conv: torch.Tensor
+        conv: Union[torch.Tensor, List[torch.Tensor]]
        temporal: torch.Tensor

        def at_layer_idx(self, layer: int):
+            if isinstance(self.conv, list):
+                return type(self)(
+                    conv=[v[layer] for v in self.conv],
+                    temporal=self.temporal[layer],
+                )
            return type(self)(**{k: v[layer] for k, v in vars(self).items()})

        def mem_usage_bytes(self):
...
...
@@ -127,14 +134,14 @@ class MambaPool:
    @dataclass(frozen=True, kw_only=True)
    class SpeculativeState(State):
-        intermediate_ssm: torch.Tensor
+        intermediate_ssm: Union[torch.Tensor, List[torch.Tensor]]
        intermediate_conv_window: torch.Tensor

    def __init__(
        self,
        *,
        size: int,
-        cache_params: "Mamba2CacheParams",
+        cache_params: Union["Mamba2CacheParams", "KimiLinearCacheParams"],
        device: str,
        speculative_num_draft_tokens: Optional[int] = None,
    ):
...
...
@@ -157,18 +164,29 @@ class MambaPool:
        else:
            self.custom_mem_pool = None

+        self.is_kda_cache = isinstance(cache_params, KimiLinearCacheParams)
        with (
            torch.cuda.use_mem_pool(self.custom_mem_pool)
            if self.enable_custom_mem_pool
            else nullcontext()
        ):
-            # assume conv_state = (dim, state_len)
-            assert conv_state_shape[0] > conv_state_shape[1]
-            conv_state = torch.zeros(
-                size=(num_mamba_layers, size + 1) + conv_state_shape,
-                dtype=conv_dtype,
-                device=device,
-            )
+            if self.is_kda_cache:
+                conv_state = [
+                    torch.zeros(
+                        size=(num_mamba_layers, size + 1) + conv_shape,
+                        dtype=conv_dtype,
+                        device=device,
+                    )
+                    for conv_shape in conv_state_shape
+                ]
+            else:
+                # assume conv_state = (dim, state_len)
+                assert conv_state_shape[0] > conv_state_shape[1]
+                conv_state = torch.zeros(
+                    size=(num_mamba_layers, size + 1) + conv_state_shape,
+                    dtype=conv_dtype,
+                    device=device,
+                )
            temporal_state = torch.zeros(
                size=(num_mamba_layers, size + 1) + temporal_state_shape,
                dtype=ssm_dtype,
...
...
@@ -191,17 +209,34 @@ class MambaPool:
            )
            # Cache intermediate conv windows (last K-1 inputs) per draft token during target verify
            # Shape: [num_layers, size + 1, speculative_num_draft_tokens, dim, K-1]
-            intermediate_conv_window_cache = torch.zeros(
-                size=(
-                    num_mamba_layers,
-                    size + 1,
-                    speculative_num_draft_tokens,
-                    conv_state_shape[0],
-                    conv_state_shape[1],
-                ),
-                dtype=conv_dtype,
-                device="cuda",
-            )
+            if self.is_kda_cache:
+                intermediate_conv_window_cache = [
+                    torch.zeros(
+                        size=(
+                            num_mamba_layers,
+                            size + 1,
+                            speculative_num_draft_tokens,
+                            conv_shape[0],
+                            conv_shape[1],
+                        ),
+                        dtype=conv_dtype,
+                        device="cuda",
+                    )
+                    for conv_shape in conv_state_shape
+                ]
+            else:
+                intermediate_conv_window_cache = torch.zeros(
+                    size=(
+                        num_mamba_layers,
+                        size + 1,
+                        speculative_num_draft_tokens,
+                        conv_state_shape[0],
+                        conv_state_shape[1],
+                    ),
+                    dtype=conv_dtype,
+                    device="cuda",
+                )
            self.mamba_cache = self.SpeculativeState(
                conv=conv_state,
                temporal=temporal_state,
...
...
@@ -255,15 +290,25 @@ class MambaPool:
        if free_index.numel() == 0:
            return
        self.free_slots = torch.cat((self.free_slots, free_index))
-        self.mamba_cache.conv[:, free_index] = self.mamba_cache.temporal[
-            :, free_index
-        ] = 0
+        if self.is_kda_cache:
+            for i in range(len(self.mamba_cache.conv)):
+                self.mamba_cache.conv[i][:, free_index] = 0
+        else:
+            self.mamba_cache.conv[:, free_index] = 0
+        self.mamba_cache.temporal[:, free_index] = 0

    def clear(self):
        self.free_slots = torch.arange(self.size, dtype=torch.int64, device=self.device)

    def copy_from(self, src_index: torch.Tensor, dst_index: torch.Tensor):
-        self.mamba_cache.conv[:, dst_index] = self.mamba_cache.conv[:, src_index]
+        if self.is_kda_cache:
+            for i in range(len(self.mamba_cache.conv)):
+                self.mamba_cache.conv[i][:, dst_index] = self.mamba_cache.conv[i][
+                    :, src_index
+                ]
+        else:
+            self.mamba_cache.conv[:, dst_index] = self.mamba_cache.conv[:, src_index]
        self.mamba_cache.temporal[:, dst_index] = self.mamba_cache.temporal[
            :, src_index
        ]
...
...
@@ -304,7 +349,7 @@ class HybridReqToTokenPool(ReqToTokenPool):
        max_context_len: int,
        device: str,
        enable_memory_saver: bool,
-        cache_params: "Mamba2CacheParams",
+        cache_params: Union["Mamba2CacheParams", "KimiLinearCacheParams"],
        speculative_num_draft_tokens: int = None,
    ):
        super().__init__(
...
...
@@ -323,7 +368,7 @@ class HybridReqToTokenPool(ReqToTokenPool):
    def _init_mamba_pool(
        self,
        size: int,
-        cache_params: "Mamba2CacheParams",
+        cache_params: Union["Mamba2CacheParams", "KimiLinearCacheParams"],
        device: str,
        speculative_num_draft_tokens: int = None,
    ):
...
...
@@ -812,6 +857,10 @@ class HybridLinearKVPool(KVCache):
        enable_kvcache_transpose: bool,
        device: str,
        mamba_pool: MambaPool,
        # TODO: refactor mla related args
        use_mla: bool = False,
        kv_lora_rank: int = None,
        qk_rope_head_dim: int = None,
    ):
        self.size = size
        self.dtype = dtype
...
...
@@ -825,25 +874,42 @@ class HybridLinearKVPool(KVCache):
        self.mamba_pool = mamba_pool
        # TODO MHATransposedTokenToKVPool if enable_kvcache_transpose is True
        assert not enable_kvcache_transpose
-        if _is_npu:
-            TokenToKVPoolClass = AscendTokenToKVPool
-        else:
-            TokenToKVPoolClass = MHATokenToKVPool
-        self.full_kv_pool = TokenToKVPoolClass(
-            size=size,
-            page_size=self.page_size,
-            dtype=dtype,
-            head_num=head_num,
-            head_dim=head_dim,
-            layer_num=self.full_layer_nums,
-            device=device,
-            enable_memory_saver=False,
-        )
+        self.use_mla = use_mla
+        if not use_mla:
+            if _is_npu:
+                TokenToKVPoolClass = AscendTokenToKVPool
+            else:
+                TokenToKVPoolClass = MHATokenToKVPool
+            self.full_kv_pool = TokenToKVPoolClass(
+                size=size,
+                page_size=self.page_size,
+                dtype=dtype,
+                head_num=head_num,
+                head_dim=head_dim,
+                layer_num=self.full_layer_nums,
+                device=device,
+                enable_memory_saver=False,
+            )
+        else:
+            TokenToKVPoolClass = MLATokenToKVPool
+            self.full_kv_pool = TokenToKVPoolClass(
+                size=size,
+                page_size=self.page_size,
+                dtype=dtype,
+                layer_num=self.full_layer_nums,
+                device=device,
+                kv_lora_rank=kv_lora_rank,
+                qk_rope_head_dim=qk_rope_head_dim,
+                enable_memory_saver=False,
+            )
        self.full_attention_layer_id_mapping = {
            id: i for i, id in enumerate(full_attention_layer_ids)
        }
-        k_size, v_size = self.get_kv_size_bytes()
-        self.mem_usage = (k_size + v_size) / GB
+        if use_mla:
+            self.mem_usage = self.get_kv_size_bytes() / GB
+        else:
+            k_size, v_size = self.get_kv_size_bytes()
+            self.mem_usage = (k_size + v_size) / GB

    def get_kv_size_bytes(self):
        return self.full_kv_pool.get_kv_size_bytes()
...
...
@@ -879,6 +945,21 @@ class HybridLinearKVPool(KVCache):
        layer_id = self._transfer_full_attention_id(layer_id)
        return self.full_kv_pool.get_kv_buffer(layer_id)

    @contextmanager
    def _transfer_id_context(self, layer: RadixAttention):
        @contextmanager
        def _patch_layer_id(layer):
            original_layer_id = layer.layer_id
            layer.layer_id = self._transfer_full_attention_id(layer.layer_id)
            try:
                yield
            finally:
                layer.layer_id = original_layer_id

        with _patch_layer_id(layer):
            yield

    def set_kv_buffer(
        self,
        layer: RadixAttention,
...
...
v_scale
:
float
=
1.0
,
):
layer_id
=
self
.
_transfer_full_attention_id
(
layer
.
layer_id
)
self
.
full_kv_pool
.
set_kv_buffer
(
None
,
loc
,
cache_k
,
cache_v
,
k_scale
,
v_scale
,
layer_id_override
=
layer_id
,
)
if
not
self
.
use_mla
:
self
.
full_kv_pool
.
set_kv_buffer
(
None
,
loc
,
cache_k
,
cache_v
,
k_scale
,
v_scale
,
layer_id_override
=
layer_id
,
)
else
:
with
self
.
_transfer_id_context
(
layer
):
self
.
full_kv_pool
.
set_kv_buffer
(
layer
,
loc
,
cache_k
,
cache_v
,
)
def
get_v_head_dim
(
self
):
return
self
.
full_kv_pool
.
get_value_buffer
(
0
).
shape
[
-
1
]
def
set_mla_kv_buffer
(
self
,
layer
:
RadixAttention
,
loc
:
torch
.
Tensor
,
cache_k_nope
:
torch
.
Tensor
,
cache_k_rope
:
torch
.
Tensor
,
):
assert
self
.
use_mla
,
"set_mla_kv_buffer called when use_mla is False"
with
self
.
_transfer_id_context
(
layer
):
self
.
full_kv_pool
.
set_mla_kv_buffer
(
layer
,
loc
,
cache_k_nope
,
cache_k_rope
)
def
get_mla_kv_buffer
(
self
,
layer
:
RadixAttention
,
loc
:
torch
.
Tensor
,
dst_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
):
assert
self
.
use_mla
,
"get_mla_kv_buffer called when use_mla is False"
with
self
.
_transfer_id_context
(
layer
):
return
self
.
full_kv_pool
.
get_mla_kv_buffer
(
layer
,
loc
,
dst_dtype
)
class
SWAKVPool
(
KVCache
):
"""KV cache with separate pools for full and SWA attention layers."""
...
...
python/sglang/srt/model_executor/model_runner.py
View file @ a4bf5c6a
...
...
@@ -29,7 +29,12 @@ from typing import Callable, List, Optional, Tuple, Union
import torch
import torch.distributed as dist

-from sglang.srt.configs import FalconH1Config, NemotronHConfig, Qwen3NextConfig
+from sglang.srt.configs import (
+    FalconH1Config,
+    KimiLinearConfig,
+    NemotronHConfig,
+    Qwen3NextConfig,
+)
from sglang.srt.configs.device_config import DeviceConfig
from sglang.srt.configs.load_config import LoadConfig, LoadFormat
from sglang.srt.configs.model_config import (
...
...
@@ -1358,9 +1363,16 @@ class ModelRunner:
                return config
        return None

+    @property
+    def kimi_linear_config(self):
+        config = self.model_config.hf_config
+        if isinstance(config, KimiLinearConfig):
+            return config
+        return None
+
    @property
    def mambaish_config(self):
-        return self.mamba2_config or self.hybrid_gdn_config
+        return self.mamba2_config or self.hybrid_gdn_config or self.kimi_linear_config

    def set_num_token_hybrid(self):
        if (
...
...
@@ -1691,7 +1703,7 @@ class ModelRunner:
                end_layer=self.end_layer,
                index_head_dim=get_nsa_index_head_dim(self.model_config.hf_config),
            )
-        elif self.use_mla_backend:
+        elif self.use_mla_backend and not self.mambaish_config:
            assert not is_nsa_model
            self.token_to_kv_pool = MLATokenToKVPool(
                self.max_total_num_tokens,
...
...
@@ -1735,6 +1747,12 @@ class ModelRunner:
                device=self.device,
            )
        elif config := self.mambaish_config:
+            extra_args = {}
+            if self.use_mla_backend:
+                extra_args = {
+                    "kv_lora_rank": self.model_config.kv_lora_rank,
+                    "qk_rope_head_dim": self.model_config.qk_rope_head_dim,
+                }
            self.token_to_kv_pool = HybridLinearKVPool(
                page_size=self.page_size,
                size=self.max_total_num_tokens,
...
...
@@ -1750,6 +1768,8 @@ class ModelRunner:
                enable_kvcache_transpose=False,
                device=self.device,
                mamba_pool=self.req_to_token_pool.mamba_pool,
+                use_mla=self.use_mla_backend,
+                **extra_args,
            )
        else:
            self.token_to_kv_pool = MHATokenToKVPool(
...
...
python/sglang/srt/models/deepseek_v2.py
View file @ a4bf5c6a
...
...
@@ -1075,6 +1075,7 @@ class DeepseekV2AttentionMLA(nn.Module):
        layer_id: int = None,
        prefix: str = "",
        alt_stream: Optional[torch.cuda.Stream] = None,
+        skip_rope: bool = False,
    ) -> None:
        super().__init__()
        self.layer_id = layer_id
...
...
@@ -1182,23 +1183,26 @@ class DeepseekV2AttentionMLA(nn.Module):
        )
        self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)

-        self.rotary_emb = get_rope_wrapper(
-            qk_rope_head_dim,
-            rotary_dim=qk_rope_head_dim,
-            max_position=max_position_embeddings,
-            base=rope_theta,
-            rope_scaling=rope_scaling,
-            is_neox_style=False,
-            device=get_global_server_args().device,
-        )
-
-        if rope_scaling:
-            mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
-            scaling_factor = rope_scaling["factor"]
-            mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
-            self.scaling = self.scaling * mscale * mscale
-        else:
-            self.rotary_emb.forward = self.rotary_emb.forward_native
+        if not skip_rope:
+            self.rotary_emb = get_rope_wrapper(
+                qk_rope_head_dim,
+                rotary_dim=qk_rope_head_dim,
+                max_position=max_position_embeddings,
+                base=rope_theta,
+                rope_scaling=rope_scaling,
+                is_neox_style=False,
+                device=get_global_server_args().device,
+            )
+
+            if rope_scaling:
+                mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
+                scaling_factor = rope_scaling["factor"]
+                mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
+                self.scaling = self.scaling * mscale * mscale
+            else:
+                self.rotary_emb.forward = self.rotary_emb.forward_native
+        else:
+            self.rotary_emb = None

        self.attn_mqa = RadixAttention(
            self.num_local_heads,
...
...
latent_cache
=
latent_cache
.
unsqueeze
(
1
)
kv_a
=
self
.
kv_a_layernorm
(
kv_a
)
k_pe
=
latent_cache
[:,
:,
self
.
kv_lora_rank
:]
q_pe
,
k_pe
=
self
.
rotary_emb
(
positions
,
q_pe
,
k_pe
)
if
self
.
rotary_emb
is
not
None
:
q_pe
,
k_pe
=
self
.
rotary_emb
(
positions
,
q_pe
,
k_pe
)
q
[...,
self
.
qk_nope_head_dim
:]
=
q_pe
self
.
_set_mla_kv_buffer
(
latent_cache
,
kv_a
,
k_pe
,
forward_batch
)
...
...
@@ -1646,8 +1651,10 @@ class DeepseekV2AttentionMLA(nn.Module):
            q_nope_out = q_nope_out.transpose(0, 1)

-        if not self._fuse_rope_for_trtllm_mla(forward_batch) and (
-            not _use_aiter or not _is_gfx95_supported or self.use_nsa
+        if (
+            self.rotary_emb is not None
+            and (not self._fuse_rope_for_trtllm_mla(forward_batch))
+            and (not _use_aiter or not _is_gfx95_supported or self.use_nsa)
        ):
            q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
...
python/sglang/srt/models/kimi_linear.py 0 → 100644
View file @ a4bf5c6a
# Adapted from: https://github.com/vllm-project/vllm/blob/0384aa7150c4c9778efca041ffd1beb3ad2bd694/vllm/model_executor/models/kimi_linear.py
from collections.abc import Iterable
from typing import Optional

import torch
from einops import rearrange
from torch import nn

from sglang.srt.configs.kimi_linear import KimiLinearConfig
from sglang.srt.distributed import (
    divide,
    get_pp_group,
    get_tensor_model_parallel_world_size,
    tensor_model_parallel_all_reduce,
)
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
from sglang.srt.layers.attention.fla.kda import FusedRMSNormGated
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
    ColumnParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
from sglang.srt.layers.moe.topk import TopK, TopKOutputFormat
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.utils import PPMissingLayer
from sglang.srt.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import (
    default_weight_loader,
    maybe_remap_kv_scale_name,
    sharded_weight_loader,
)
from sglang.srt.models.deepseek_v2 import DeepseekV2AttentionMLA as KimiMLAAttention
from sglang.srt.models.llama import LlamaMLP as KimiMLP
from sglang.srt.models.transformers import maybe_prefix
from sglang.srt.utils import make_layers
from sglang.srt.utils.common import BumpAllocator, add_prefix, set_weight_attrs


class KimiMoE(nn.Module):
    def __init__(
        self,
        config: KimiLinearConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        layer_idx: int = 0,
    ):
        super().__init__()
        hidden_size = config.hidden_size
        intermediate_size = config.intermediate_size
        moe_intermediate_size = config.moe_intermediate_size
        num_experts = config.num_experts
        moe_renormalize = config.moe_renormalize
        self.tp_size = get_tensor_model_parallel_world_size()
        self.routed_scaling_factor = config.routed_scaling_factor
        self.num_shared_experts = config.num_shared_experts
        self.layer_idx = layer_idx

        if config.hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {config.hidden_act}. "
                "Only silu is supported for now."
            )

        # Gate always runs at half / full precision for now.
        self.gate = ReplicatedLinear(
            hidden_size,
            num_experts,
            bias=False,
            quant_config=None,
            prefix=f"{prefix}.gate",
        )
        self.gate.e_score_correction_bias = nn.Parameter(torch.empty(num_experts))

        self.experts = get_moe_impl_class(quant_config)(
            num_experts=config.n_routed_experts,
            top_k=config.num_experts_per_token,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            layer_id=self.layer_idx,
            quant_config=quant_config,
            routed_scaling_factor=self.routed_scaling_factor,
            prefix=add_prefix("experts", prefix),
        )

        self.topk = TopK(
            top_k=config.num_experts_per_token,
            renormalize=moe_renormalize,
            use_grouped_topk=True,
            num_expert_group=config.num_expert_group,
            topk_group=config.topk_group,
            correction_bias=self.gate.e_score_correction_bias,
            quant_config=quant_config,
            routed_scaling_factor=self.routed_scaling_factor,
            apply_routed_scaling_factor_on_output=self.experts.should_fuse_routed_scaling_factor_in_topk,
            # Some Fp4 MoE backends require the output format to be bypassed but the MTP layers are unquantized
            # and requires the output format to be standard. We use quant_config to determine the output format.
            output_format=TopKOutputFormat.STANDARD if quant_config is None else None,
        )

        if self.num_shared_experts is not None:
            intermediate_size = moe_intermediate_size * self.num_shared_experts
            self.shared_experts = KimiMLP(
                hidden_size=config.hidden_size,
                intermediate_size=intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                reduce_results=False,
            )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        num_tokens, hidden_size = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_size)
        if self.num_shared_experts is not None:
            shared_output = self.shared_experts(hidden_states)
        router_logits, _ = self.gate(hidden_states)
        topk_output = self.topk(hidden_states, router_logits)
        final_hidden_states = self.experts(hidden_states, topk_output)
        if shared_output is not None:
            final_hidden_states = final_hidden_states + shared_output
        if self.tp_size > 1:
            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
        return final_hidden_states.view(num_tokens, hidden_size)


class KimiDeltaAttention(nn.Module):
    def __init__(
        self,
        layer_idx: int,
        hidden_size: int,
        config: KimiLinearConfig,
        quant_config: Optional[QuantizationConfig] = None,
        rms_norm_eps: float = 1e-5,
        prefix: str = "",
        **kwargs,
    ) -> None:
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.hidden_size = hidden_size
        self.config = config
        self.head_dim = config.linear_attn_config["head_dim"]
        self.num_heads = config.linear_attn_config["num_heads"]
        self.layer_idx = layer_idx
        self.prefix = prefix

        assert self.num_heads % self.tp_size == 0
        self.local_num_heads = divide(self.num_heads, self.tp_size)

        projection_size = self.head_dim * self.num_heads
        self.conv_size = config.linear_attn_config["short_conv_kernel_size"]

        self.q_proj = ColumnParallelLinear(
            self.hidden_size,
            projection_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.q_proj",
        )
        self.k_proj = ColumnParallelLinear(
            self.hidden_size,
            projection_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.k_proj",
        )
        self.v_proj = ColumnParallelLinear(
            self.hidden_size,
            projection_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.v_proj",
        )

        self.f_a_proj = ReplicatedLinear(
            self.hidden_size,
            self.head_dim,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.f_a_proj",
        )
        self.f_b_proj = ColumnParallelLinear(
            self.head_dim,
            projection_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.f_b_proj",
        )
        self.dt_bias = nn.Parameter(
            torch.empty(divide(projection_size, self.tp_size), dtype=torch.float32)
        )
        set_weight_attrs(self.dt_bias, {"weight_loader": sharded_weight_loader(0)})

        self.b_proj = ColumnParallelLinear(
            self.hidden_size,
            self.num_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.b_proj",
        )

        self.q_conv1d = ColumnParallelLinear(
            input_size=self.conv_size,
            output_size=projection_size,
            bias=False,
            params_dtype=torch.float32,
            prefix=f"{prefix}.q_conv1d",
        )
        self.k_conv1d = ColumnParallelLinear(
            input_size=self.conv_size,
            output_size=projection_size,
            bias=False,
            params_dtype=torch.float32,
            prefix=f"{prefix}.k_conv1d",
        )
        self.v_conv1d = ColumnParallelLinear(
            input_size=self.conv_size,
            output_size=projection_size,
            bias=False,
            params_dtype=torch.float32,
            prefix=f"{prefix}.v_conv1d",
        )
        # unsqueeze to fit conv1d weights shape into the linear weights shape.
        # Can't do this in `weight_loader` since it already exists in
        # `ColumnParallelLinear` and `set_weight_attrs`
        # doesn't allow to override it
        self.q_conv1d.weight.data = self.q_conv1d.weight.data.unsqueeze(1)
        self.k_conv1d.weight.data = self.k_conv1d.weight.data.unsqueeze(1)
        self.v_conv1d.weight.data = self.v_conv1d.weight.data.unsqueeze(1)

        self.A_log = nn.Parameter(
            torch.empty(1, 1, self.local_num_heads, 1, dtype=torch.float32)
        )
        set_weight_attrs(self.A_log, {"weight_loader": sharded_weight_loader(2)})

        self.g_a_proj = ReplicatedLinear(
            self.hidden_size,
            self.head_dim,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.g_a_proj",
        )
        self.g_b_proj = ColumnParallelLinear(
            self.head_dim,
            projection_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.g_b_proj",
        )
        self.o_norm = FusedRMSNormGated(
            self.head_dim, eps=rms_norm_eps, activation="sigmoid"
        )
        self.o_proj = RowParallelLinear(
            projection_size,
            self.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        zero_allocator: BumpAllocator,
    ) -> None:
        q_proj_states = self.q_proj(hidden_states)[0]
        k_proj_states = self.k_proj(hidden_states)[0]
        v_proj_states = self.v_proj(hidden_states)[0]
        q_conv_weights = self.q_conv1d.weight.view(
            self.q_conv1d.weight.size(0), self.q_conv1d.weight.size(2)
        )
        k_conv_weights = self.k_conv1d.weight.view(
            self.k_conv1d.weight.size(0), self.k_conv1d.weight.size(2)
        )
        v_conv_weights = self.v_conv1d.weight.view(
            self.v_conv1d.weight.size(0), self.v_conv1d.weight.size(2)
        )
        kwargs = {
            "q_proj_states": q_proj_states,
            "k_proj_states": k_proj_states,
            "v_proj_states": v_proj_states,
            "q_conv_weights": q_conv_weights,
            "k_conv_weights": k_conv_weights,
            "v_conv_weights": v_conv_weights,
            "q_conv_bias": self.q_conv1d.bias,
            "k_conv_bias": self.k_conv1d.bias,
            "v_conv_bias": self.v_conv1d.bias,
            "dt_bias": self.dt_bias,
            "b_proj": self.b_proj,
            "f_a_proj": self.f_a_proj,
            "f_b_proj": self.f_b_proj,
            "A_log": self.A_log,
            "head_dim": self.head_dim,
            "hidden_states": hidden_states,
            "layer_id": self.layer_idx,
        }
        core_attn_out = forward_batch.attn_backend.forward(
            q=None,
            k=None,
            v=None,
            layer=None,
            forward_batch=forward_batch,
            **kwargs,
        )
        g_proj_states = self.g_b_proj(self.g_a_proj(hidden_states)[0])[0]
        g = rearrange(g_proj_states, "... (h d) -> ... h d", d=self.head_dim)
        core_attn_out = self.o_norm(core_attn_out, g)
        core_attn_out = rearrange(core_attn_out, "1 n h d -> n (h d)")
        return self.o_proj(core_attn_out)[0]


class KimiDecoderLayer(nn.Module):
    def __init__(
        self,
        config: KimiLinearConfig,
        layer_idx: int,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        self.is_moe = config.is_moe
        if config.is_kda_layer(layer_idx):
            self.self_attn = KimiDeltaAttention(
                layer_idx=layer_idx,
                hidden_size=config.hidden_size,
                config=config,
                quant_config=quant_config,
                prefix=f"{prefix}.self_attn",
            )
        else:
            self.self_attn = KimiMLAAttention(
                layer_id=layer_idx,
                hidden_size=self.hidden_size,
                num_heads=config.num_attention_heads,
                quant_config=quant_config,
                prefix=f"{prefix}.self_attn",
                config=config,
                qk_nope_head_dim=config.qk_nope_head_dim,
                qk_rope_head_dim=config.qk_rope_head_dim,
                v_head_dim=config.v_head_dim,
                q_lora_rank=config.q_lora_rank,
                kv_lora_rank=config.kv_lora_rank,
                skip_rope=True,
            )

        if (
            self.is_moe
            and config.num_experts is not None
            and layer_idx >= config.first_k_dense_replace
            and layer_idx % config.moe_layer_freq == 0
        ):
            self.block_sparse_moe = KimiMoE(
                config=config,
                quant_config=quant_config,
                layer_idx=layer_idx,
                prefix=f"{prefix}.mlp",
            )
            self.mlp = self.block_sparse_moe
        else:
            self.mlp = KimiMLP(
                hidden_size=self.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                prefix=f"{prefix}.mlp",
            )
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        residual: Optional[torch.Tensor],
        zero_allocator: BumpAllocator,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        hidden_states = self.self_attn(
            hidden_states=hidden_states,
            positions=positions,
            forward_batch=forward_batch,
            zero_allocator=zero_allocator,
        )
        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


class KimiLinearModel(nn.Module):
    def __init__(
        self,
        config: KimiLinearConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.pp_group = get_pp_group()
        if self.pp_group.is_first_rank:
            self.embed_tokens = VocabParallelEmbedding(
                config.vocab_size,
                config.hidden_size,
                prefix=f"{prefix}.embed_tokens",
            )
        else:
            self.embed_tokens = PPMissingLayer()

        self.layers, self.start_layer, self.end_layer = make_layers(
            config.num_hidden_layers,
            lambda idx, prefix: KimiDecoderLayer(
                layer_idx=idx,
                config=config,
                quant_config=quant_config,
                prefix=prefix,
            ),
            pp_rank=self.pp_group.rank_in_group,
            pp_size=self.pp_group.world_size,
            prefix=f"{prefix}.layers",
        )
        if self.pp_group.is_last_rank:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer()

        world_size = get_tensor_model_parallel_world_size()
        assert (
            config.num_attention_heads % world_size == 0
        ), "num_attention_heads must be divisible by world_size"

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        inputs_embeds: torch.Tensor | None = None,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
    ) -> torch.Tensor:
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.embed_tokens(input_ids)
            residual = None
        else:
            assert pp_proxy_tensors is not None
            hidden_states = pp_proxy_tensors["hidden_states"]
            residual = pp_proxy_tensors["residual"]

        total_num_layers = self.end_layer - self.start_layer
        device = hidden_states.device
        zero_allocator = BumpAllocator(
            buffer_size=total_num_layers * 2,
            dtype=torch.float32,
            device=device,
        )
        # TODO: capture aux hidden states
        aux_hidden_states = []
        for i in range(self.start_layer, self.end_layer):
            ctx = get_global_expert_distribution_recorder().with_current_layer(i)
            with ctx:
                layer = self.layers[i]
                hidden_states, residual = layer(
                    positions=positions,
                    hidden_states=hidden_states,
                    forward_batch=forward_batch,
                    residual=residual,
                    zero_allocator=zero_allocator,
                )
        if not self.pp_group.is_last_rank:
            return PPProxyTensors(
                {
                    "hidden_states": hidden_states,
                    "residual": residual,
                }
            )
        else:
            if hidden_states.shape[0] != 0:
                if residual is None:
                    hidden_states = self.norm(hidden_states)
                else:
                    hidden_states, _ = self.norm(hidden_states, residual)

        if len(aux_hidden_states) == 0:
            return hidden_states

        return hidden_states, aux_hidden_states


class KimiLinearForCausalLM(nn.Module):
    def __init__(
        self,
        config: KimiLinearConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.model = KimiLinearModel(
            config, quant_config, prefix=maybe_prefix(prefix, "model")
        )
        self.pp_group = get_pp_group()
        if self.pp_group.is_last_rank:
            self.lm_head = ParallelLMHead(
                self.config.vocab_size,
                self.config.hidden_size,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "lm_head"),
            )
        else:
            self.lm_head = PPMissingLayer()
        logit_scale = getattr(self.config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(config=config, logit_scale=logit_scale)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        inputs_embeds: Optional[torch.Tensor] = None,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
    ) -> torch.Tensor:
        hidden_states = self.model(
            input_ids,
            positions,
            forward_batch,
            inputs_embeds,
            pp_proxy_tensors,
        )
        if self.pp_group.is_last_rank:
            return self.logits_processor(
                input_ids, hidden_states, self.lm_head, forward_batch
            )
        else:
            return hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        if self.config.is_moe:
            # Params for weights, fp8 weight scales, fp8 activation scales
            # (param_name, weight_name, expert_id, shard_id)
            expert_params_mapping = FusedMoE.make_expert_params_mapping(
                ckpt_gate_proj_name="w1",
                ckpt_down_proj_name="w2",
                ckpt_up_proj_name="w3",
                num_experts=self.config.num_experts,
            )
        else:
            expert_params_mapping = []

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for args in weights:
            name, loaded_weight = args[:2]
            kwargs = args[2] if len(args) > 2 else {}
            if "rotary_emb.inv_freq" in name:
                continue

            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                # We have mlp.experts[0].gate_proj in the checkpoint.
                # Since we handle the experts below in expert_params_mapping,
                # we need to skip here BEFORE we update the name, otherwise
                # name will be updated to mlp.experts[0].gate_up_proj, which
                # will then be updated below in expert_params_mapping
                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                if ("mlp.experts." in name) and name not in params_dict:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # if is_pp_missing_parameter(name, self):
                #     continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for idx, (
                    param_name,
                    weight_name,
                    expert_id,
                    shard_id,
                ) in enumerate(expert_params_mapping):
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    # if is_pp_missing_parameter(name, self):
                    #     continue
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(
                        param,
                        loaded_weight,
                        name,
                        expert_id=expert_id,
                        shard_id=shard_id,
                    )
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if (
                        name.endswith(".bias")
                        and name not in params_dict
                        and not self.config.is_linear_attn
                    ):  # noqa: E501
                        continue

                    # Remapping the name of FP8 kv-scale.
                    name = maybe_remap_kv_scale_name(name, params_dict)
                    if name is None:
                        continue

                    # if is_pp_missing_parameter(name, self):
                    #     continue
                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight, **kwargs)
            loaded_params.add(name)

        for layer_id in self.config.full_attention_layer_ids:
            self_attn = self.model.layers[layer_id].self_attn
            w_kc, w_vc = self_attn.kv_b_proj.weight.unflatten(
                0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
            ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
            self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
            self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
            if hasattr(self_attn.kv_b_proj, "weight_scale"):
                self_attn.w_scale = self_attn.kv_b_proj.weight_scale


EntryClass = KimiLinearForCausalLM
python/sglang/srt/server_args.py
View file @ a4bf5c6a
...
...
@@ -1028,6 +1028,11 @@ class ServerArgs:
            logger.info(
                f"Using {self.attention_backend} as attention backend for {model_arch}."
            )
+        elif model_arch in ["KimiLinearForCausalLM"]:
+            logger.warning(
+                f"Disabling Radix Cache for {model_arch} as it is not yet supported."
+            )
+            self.disable_radix_cache = True

        if is_deepseek_nsa(hf_config):
            if (
...
...
python/sglang/srt/utils/hf_transformers_utils.py
View file @ a4bf5c6a
...
...
@@ -43,6 +43,7 @@ from sglang.srt.configs import (
    DotsVLMConfig,
    ExaoneConfig,
    FalconH1Config,
+    KimiLinearConfig,
    KimiVLConfig,
    LongcatFlashConfig,
    MultiModalityConfig,
...
...
@@ -68,6 +69,7 @@ _CONFIG_REGISTRY: List[Type[PretrainedConfig]] = [
    Step3VLConfig,
    LongcatFlashConfig,
    Olmo3Config,
+    KimiLinearConfig,
    Qwen3NextConfig,
    FalconH1Config,
    DotsVLMConfig,
...
...
test/srt/models/test_kimi_linear_models.py 0 → 100644
View file @ a4bf5c6a
import unittest
from types import SimpleNamespace

from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval
from sglang.test.test_utils import (
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    CustomTestCase,
    popen_launch_server,
)


class TestKimiLinear(CustomTestCase):
    @classmethod
    def setUpClass(cls):
        cls.model = "moonshotai/Kimi-Linear-48B-A3B-Instruct"
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=["--tp-size", "2", "--trust-remote"],
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def test_gsm8k(self):
        args = SimpleNamespace(
            num_shots=5,
            data_path=None,
            num_questions=200,
            max_new_tokens=512,
            parallel=128,
            host="http://127.0.0.1",
            port=int(self.base_url.split(":")[-1]),
        )
        metrics = run_eval(args)
        print(f"{metrics=}")
        self.assertGreater(metrics["accuracy"], 0.88)


if __name__ == "__main__":
    unittest.main()
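
Since the test defines a standard unittest entry point, it can presumably be launched on its own with python3 test/srt/models/test_kimi_linear_models.py; this assumes the model weights are downloadable and that two GPUs are available for the --tp-size 2 server launch.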
test/srt/run_suite.py
View file @ a4bf5c6a
...
...
@@ -151,6 +151,7 @@ suites = {
        TestFile("layers/attention/mamba/test_mamba2_mixer.py", 50),
        TestFile("lora/test_lora_tp.py", 116),
        TestFile("models/test_glm4_moe_models.py", 100),
+        TestFile("models/test_kimi_linear_models.py", 90),
        TestFile("rl/test_update_weights_from_distributed.py", 103),
        TestFile("test_data_parallelism.py", 73),
        TestFile("test_disaggregation_basic.py", 400),
...
...