Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3857eb87
Unverified
Commit
3857eb87
authored
Oct 31, 2025
by
Jiangyun Zhu
Committed by
GitHub
Oct 31, 2025
Browse files
[Perf] Decouple torch op from GDA to leverage torch.compile (#27871)
Signed-off-by:
zjy0516
<
riverclouds.zhu@qq.com
>
parent
933cdea4
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
68 additions
and
48 deletions
+68
-48
vllm/model_executor/layers/kda.py
vllm/model_executor/layers/kda.py
+68
-48
No files found.
vllm/model_executor/layers/kda.py
View file @
3857eb87
...
...
@@ -40,18 +40,36 @@ logger = init_logger(__name__)
def
kda_attention
(
hidden_states
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
q_proj_states
:
torch
.
Tensor
,
k_proj_states
:
torch
.
Tensor
,
v_proj_states
:
torch
.
Tensor
,
g1
:
torch
.
Tensor
,
g2
:
torch
.
Tensor
,
beta
:
torch
.
Tensor
,
core_attn_out
:
torch
.
Tensor
,
layer_name
:
str
,
)
->
None
:
forward_context
:
ForwardContext
=
get_forward_context
()
self
=
forward_context
.
no_compile_layers
[
layer_name
]
self
.
_forward
(
hidden_states
=
hidden_states
,
output
=
output
)
self
.
_forward
(
q_proj_states
=
q_proj_states
,
k_proj_states
=
k_proj_states
,
v_proj_states
=
v_proj_states
,
g1
=
g1
,
g2
=
g2
,
beta
=
beta
,
core_attn_out
=
core_attn_out
,
)
def
kda_attention_fake
(
hidden_states
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
q_proj_states
:
torch
.
Tensor
,
k_proj_states
:
torch
.
Tensor
,
v_proj_states
:
torch
.
Tensor
,
g1
:
torch
.
Tensor
,
g2
:
torch
.
Tensor
,
beta
:
torch
.
Tensor
,
core_attn_out
:
torch
.
Tensor
,
layer_name
:
str
,
)
->
None
:
return
...
...
@@ -60,7 +78,7 @@ def kda_attention_fake(
direct_register_custom_op
(
op_name
=
"kda_attention"
,
op_func
=
kda_attention
,
mutates_args
=
[
"
outp
ut"
],
mutates_args
=
[
"
core_attn_o
ut"
],
fake_impl
=
kda_attention_fake
,
)
...
...
@@ -241,37 +259,56 @@ class KimiDeltaAttention(nn.Module, MambaBase):
hidden_states
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
)
->
None
:
return
torch
.
ops
.
vllm
.
kda_attention
(
hidden_states
,
output
,
)
->
torch
.
Tensor
:
num_tokens
=
hidden_states
.
size
(
0
)
q
=
self
.
q_proj
(
hidden_states
)[
0
]
k
=
self
.
k_proj
(
hidden_states
)[
0
]
v
=
self
.
v_proj
(
hidden_states
)[
0
]
beta
=
self
.
b_proj
(
hidden_states
)[
0
].
float
().
sigmoid
()
g1
=
self
.
f_b_proj
(
self
.
f_a_proj
(
hidden_states
)[
0
])[
0
]
g1
=
fused_kda_gate
(
g1
,
self
.
A_log
,
self
.
head_dim
,
g_bias
=
self
.
dt_bias
)
beta
=
beta
.
unsqueeze
(
0
)
g1
=
g1
.
unsqueeze
(
0
)
g_proj_states
=
self
.
g_b_proj
(
self
.
g_a_proj
(
hidden_states
)[
0
])[
0
]
g2
=
rearrange
(
g_proj_states
,
"... (h d) -> ... h d"
,
d
=
self
.
head_dim
)
core_attn_out
=
torch
.
zeros
(
(
1
,
num_tokens
,
self
.
local_num_heads
,
self
.
head_dim
),
dtype
=
hidden_states
.
dtype
,
device
=
hidden_states
.
device
,
)
torch
.
ops
.
vllm
.
kda_attention
(
q
,
k
,
v
,
g1
,
g2
,
beta
,
core_attn_out
,
self
.
prefix
,
)
core_attn_out
=
self
.
o_norm
(
core_attn_out
,
g2
)
core_attn_out
=
rearrange
(
core_attn_out
,
"1 n h d -> n (h d)"
)
return
self
.
o_proj
(
core_attn_out
)[
0
]
def
_forward
(
self
,
hidden_states
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
q_proj_states
:
torch
.
Tensor
,
k_proj_states
:
torch
.
Tensor
,
v_proj_states
:
torch
.
Tensor
,
g1
:
torch
.
Tensor
,
g2
:
torch
.
Tensor
,
beta
:
torch
.
Tensor
,
core_attn_out
:
torch
.
Tensor
,
)
->
None
:
forward_context
=
get_forward_context
()
attn_metadata
:
AttentionMetadata
=
forward_context
.
attn_metadata
if
attn_metadata
is
None
:
# V1 profile run
# Mimic the memory allocation in the real run
q
=
torch
.
empty_like
(
hidden_states
)
k
=
torch
.
empty_like
(
hidden_states
)
v
=
torch
.
empty_like
(
hidden_states
)
g
=
hidden_states
.
new_empty
(
hidden_states
.
size
(
0
),
self
.
local_num_heads
,
self
.
head_dim
,
dtype
=
torch
.
float32
,
)
beta
=
torch
.
empty
(
hidden_states
.
size
(
0
),
self
.
local_num_heads
,
dtype
=
torch
.
float32
)
core_attn_out
=
torch
.
empty_like
(
hidden_states
)
# # V1 profile run
return
assert
isinstance
(
attn_metadata
,
dict
)
...
...
@@ -288,10 +325,6 @@ class KimiDeltaAttention(nn.Module, MambaBase):
conv_state_k
=
conv_state_k
.
transpose
(
-
1
,
-
2
)
conv_state_v
=
conv_state_v
.
transpose
(
-
1
,
-
2
)
q_proj_states
=
self
.
q_proj
(
hidden_states
)[
0
]
k_proj_states
=
self
.
k_proj
(
hidden_states
)[
0
]
v_proj_states
=
self
.
v_proj
(
hidden_states
)[
0
]
q_conv_weights
=
self
.
q_conv1d
.
weight
.
view
(
self
.
q_conv1d
.
weight
.
size
(
0
),
self
.
q_conv1d
.
weight
.
size
(
2
)
)
...
...
@@ -374,14 +407,6 @@ class KimiDeltaAttention(nn.Module, MambaBase):
lambda
x
:
rearrange
(
x
,
"n (h d) -> 1 n h d"
,
d
=
self
.
head_dim
),
(
q
,
k
,
v
)
)
beta
=
self
.
b_proj
(
hidden_states
)[
0
].
float
().
sigmoid
()
g
=
self
.
f_b_proj
(
self
.
f_a_proj
(
hidden_states
)[
0
])[
0
]
g
=
fused_kda_gate
(
g
,
self
.
A_log
,
self
.
head_dim
,
g_bias
=
self
.
dt_bias
)
beta
=
beta
.
unsqueeze
(
0
)
g
=
g
.
unsqueeze
(
0
)
if
attn_metadata
.
num_prefills
>
0
:
zero_idx
=
non_spec_state_indices_tensor
[
~
has_initial_state
]
recurrent_state
[
zero_idx
]
=
0
...
...
@@ -393,7 +418,7 @@ class KimiDeltaAttention(nn.Module, MambaBase):
q
=
q
,
k
=
k
,
v
=
v
,
g
=
g
,
g
=
g
1
,
beta
=
beta
,
initial_state
=
initial_state
,
output_final_state
=
True
,
...
...
@@ -410,17 +435,12 @@ class KimiDeltaAttention(nn.Module, MambaBase):
q
=
q
,
k
=
k
,
v
=
v
,
g
=
g
,
g
=
g
1
,
beta
=
beta
,
initial_state
=
recurrent_state
,
use_qk_l2norm_in_kernel
=
True
,
cu_seqlens
=
non_spec_query_start_loc
,
ssm_state_indices
=
non_spec_state_indices_tensor
,
)
g_proj_states
=
self
.
g_b_proj
(
self
.
g_a_proj
(
hidden_states
)[
0
])[
0
]
g
=
rearrange
(
g_proj_states
,
"... (h d) -> ... h d"
,
d
=
self
.
head_dim
)
core_attn_out
=
self
.
o_norm
(
core_attn_out_non_spec
,
g
)
core_attn_out
=
rearrange
(
core_attn_out
,
"1 n h d -> n (h d)"
)
output
[:]
=
self
.
o_proj
(
core_attn_out
)[
0
]
assert
core_attn_out_non_spec
.
shape
==
core_attn_out
.
shape
core_attn_out
[:]
=
core_attn_out_non_spec
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment