Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
fa183e92
Unverified
Commit
fa183e92
authored
Nov 13, 2025
by
Jiangyun Zhu
Committed by
GitHub
Nov 13, 2025
Browse files
[Bugfix] fix kimi-linear crash (#28445)
Signed-off-by:
zjy0516
<
riverclouds.zhu@qq.com
>
parent
4ab34f6e
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
12 additions
and
9 deletions
+12
-9
vllm/model_executor/layers/kda.py
vllm/model_executor/layers/kda.py
+12
-9
No files found.
vllm/model_executor/layers/kda.py
View file @
fa183e92
...
...
@@ -44,7 +44,6 @@ def kda_attention(
k_proj_states
:
torch
.
Tensor
,
v_proj_states
:
torch
.
Tensor
,
g1
:
torch
.
Tensor
,
g2
:
torch
.
Tensor
,
beta
:
torch
.
Tensor
,
core_attn_out
:
torch
.
Tensor
,
layer_name
:
str
,
...
...
@@ -56,7 +55,6 @@ def kda_attention(
k_proj_states
=
k_proj_states
,
v_proj_states
=
v_proj_states
,
g1
=
g1
,
g2
=
g2
,
beta
=
beta
,
core_attn_out
=
core_attn_out
,
)
...
...
@@ -67,7 +65,6 @@ def kda_attention_fake(
k_proj_states
:
torch
.
Tensor
,
v_proj_states
:
torch
.
Tensor
,
g1
:
torch
.
Tensor
,
g2
:
torch
.
Tensor
,
beta
:
torch
.
Tensor
,
core_attn_out
:
torch
.
Tensor
,
layer_name
:
str
,
...
...
@@ -284,7 +281,6 @@ class KimiDeltaAttention(nn.Module, MambaBase):
k
,
v
,
g1
,
g2
,
beta
,
core_attn_out
,
self
.
prefix
,
...
...
@@ -299,7 +295,6 @@ class KimiDeltaAttention(nn.Module, MambaBase):
k_proj_states
:
torch
.
Tensor
,
v_proj_states
:
torch
.
Tensor
,
g1
:
torch
.
Tensor
,
g2
:
torch
.
Tensor
,
beta
:
torch
.
Tensor
,
core_attn_out
:
torch
.
Tensor
,
)
->
None
:
...
...
@@ -316,8 +311,15 @@ class KimiDeltaAttention(nn.Module, MambaBase):
has_initial_state
=
attn_metadata
.
has_initial_state
non_spec_query_start_loc
=
attn_metadata
.
non_spec_query_start_loc
non_spec_state_indices_tensor
=
attn_metadata
.
non_spec_state_indices_tensor
# noqa: E501
num_actual_tokens
=
attn_metadata
.
num_actual_tokens
constant_caches
=
self
.
kv_cache
[
forward_context
.
virtual_engine
]
q_proj_states
=
q_proj_states
[:
num_actual_tokens
]
k_proj_states
=
k_proj_states
[:
num_actual_tokens
]
v_proj_states
=
v_proj_states
[:
num_actual_tokens
]
g1
=
g1
[:
num_actual_tokens
]
beta
=
beta
[:
num_actual_tokens
]
(
conv_state_q
,
conv_state_k
,
conv_state_v
,
recurrent_state
)
=
constant_caches
# deal with strides
conv_state_q
=
conv_state_q
.
transpose
(
-
1
,
-
2
)
...
...
@@ -372,7 +374,7 @@ class KimiDeltaAttention(nn.Module, MambaBase):
).
transpose
(
0
,
1
)
else
:
decode_conv_indices
=
non_spec_state_indices_tensor
[
:
attn_metadata
.
num_
decode
s
:
attn_metadata
.
num_
actual_token
s
]
q
=
causal_conv1d_update
(
q_proj_states
,
...
...
@@ -438,8 +440,9 @@ class KimiDeltaAttention(nn.Module, MambaBase):
beta
=
beta
,
initial_state
=
recurrent_state
,
use_qk_l2norm_in_kernel
=
True
,
cu_seqlens
=
non_spec_query_start_loc
,
cu_seqlens
=
non_spec_query_start_loc
[:
attn_metadata
.
num_decodes
+
1
]
,
ssm_state_indices
=
non_spec_state_indices_tensor
,
)
assert
core_attn_out_non_spec
.
shape
==
core_attn_out
.
shape
core_attn_out
[:]
=
core_attn_out_non_spec
core_attn_out
[
0
,
:
num_actual_tokens
]
=
core_attn_out_non_spec
[
0
,
:
num_actual_tokens
]
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment