ox696c / ktransformers · Commits · 2ffb43f9

Commit 2ffb43f9, authored Feb 18, 2025 by ceerrep

fix: adapt prefix cache in `forward_linux_flashinfer`

Parent: c70b6f4d

Showing 1 changed file with 22 additions and 10 deletions
ktransformers/operators/attention.py (+22, -10)
...
@@ -374,6 +374,16 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
             compressed_kv = self.kv_a_layernorm(compressed_kv)
             k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim)
             compressed_kv = compressed_kv.view(bsz, q_len, 1, self.kv_lora_rank)
+            kv_seq_len = q_len
+            if past_key_value is not None:
+                if self.layer_idx is None:
+                    raise ValueError(
+                        f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
+                        "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
+                        "with a layer index."
+                    )
+                kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
             cos, sin = self.rotary_emb(q_pe, position_ids)
             q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, unsqueeze_dim=2)
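The block added in this first hunk computes the effective key/value length when a cached prefix is present: the current q_len plus however many positions the cache already holds. Below is a minimal, self-contained sketch of that bookkeeping (not part of the commit); ToyCache is a hypothetical stand-in for the cache object ktransformers actually passes in, mimicking only the get_usable_length call the diff relies on, and the 8-token cached prefix is an illustrative assumption.

import torch

class ToyCache:
    """Hypothetical stand-in for the cache object; only get_usable_length is mimicked."""
    def __init__(self):
        self.compressed_kv = None                 # [bsz, cached_len, 1, kv_lora_rank]

    def get_usable_length(self, new_seq_length, layer_idx):
        # Number of already-cached positions that can be reused as a prefix.
        return 0 if self.compressed_kv is None else self.compressed_kv.shape[1]

past_key_value = ToyCache()
past_key_value.compressed_kv = torch.zeros(1, 8, 1, 512)   # pretend 8 prefix tokens were cached earlier

q_len = 3                                         # tokens in the current forward pass
kv_seq_len = q_len
if past_key_value is not None:
    kv_seq_len += past_key_value.get_usable_length(kv_seq_len, layer_idx=0)
print(kv_seq_len)                                 # 11 = 8 cached prefix tokens + 3 new tokens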
...
@@ -453,26 +463,28 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
                 cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
-                k_pe.squeeze(0)
-                compressed_kv.squeeze(0)
-                past_key_value.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs)
-                k_pe.unsqueeze(0)
-                compressed_kv.unsqueeze(0)
-            k_pe = k_pe[:, :q_len]
-            compressed_kv = compressed_kv[:, :q_len]
+                compressed_kv_with_k_pe, _ = past_key_value.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs)
+                compressed_kv, k_pe = torch.split(compressed_kv_with_k_pe, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
+                k_pe = k_pe.view(bsz, -1, self.qk_rope_head_dim)
+                k_pe = k_pe[:, :kv_seq_len]
+                compressed_kv = compressed_kv.view(bsz, -1, self.kv_lora_rank)
+                compressed_kv = compressed_kv[:, :kv_seq_len]
             kv = (
                 self.kv_b_proj(compressed_kv)
-                .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
+                .view(bsz, kv_seq_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
             )
             k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
             query_states = k_pe.new_empty(bsz, q_len, self.num_heads, self.q_head_dim)
             query_states[:, :, :, :self.qk_nope_head_dim] = q_nope
             query_states[:, :, :, self.qk_nope_head_dim:] = q_pe
-            key_states = k_pe.new_empty(bsz, q_len, self.num_heads, self.q_head_dim)
+            key_states = k_pe.new_empty(bsz, kv_seq_len, self.num_heads, self.q_head_dim)
             key_states[:, :, :, :self.qk_nope_head_dim] = k_nope
-            key_states[:, :, :, self.qk_nope_head_dim:] = k_pe
+            key_states[:, :, :, self.qk_nope_head_dim:] = k_pe.view(bsz, kv_seq_len, 1, -1)
-            value_states = value_states.view(bsz, q_len, self.num_heads, self.v_head_dim)
+            value_states = value_states.view(bsz, kv_seq_len, self.num_heads, self.v_head_dim)
             value_states_padded = torch.nn.functional.pad(value_states, [0, query_states.shape[-1] - value_states.shape[-1]], value=0)

             attn_output = flash_attn_func(
...
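The second hunk is the prefix-cache adaptation itself: instead of squeezing the new tokens into the cache and slicing back to q_len, the cache update is now expected to return the full cached sequence (compressed_kv_with_k_pe), which is split back into the compressed KV and the rotary key and sliced to kv_seq_len. A minimal, self-contained sketch of that reshaping follows (not part of the commit); the dimensions kv_lora_rank=512 and qk_rope_head_dim=64 are assumed DeepSeek-V2-style values used here only for illustration.

import torch

bsz, kv_seq_len = 1, 11
kv_lora_rank, qk_rope_head_dim = 512, 64          # assumed values, for illustration only

# Shape the cache update is expected to hand back: one row per cached position,
# holding [compressed_kv | k_pe] along the last dimension.
compressed_kv_with_k_pe = torch.randn(bsz, kv_seq_len, 1, kv_lora_rank + qk_rope_head_dim)

# Split the packed tensor back into the compressed KV and the rotary key,
# then flatten and slice to the full cached length, mirroring the new code path.
compressed_kv, k_pe = torch.split(
    compressed_kv_with_k_pe, [kv_lora_rank, qk_rope_head_dim], dim=-1
)
k_pe = k_pe.view(bsz, -1, qk_rope_head_dim)[:, :kv_seq_len]
compressed_kv = compressed_kv.view(bsz, -1, kv_lora_rank)[:, :kv_seq_len]

print(k_pe.shape)                                 # torch.Size([1, 11, 64])
print(compressed_kv.shape)                        # torch.Size([1, 11, 512])

With kv_seq_len positions flowing into kv_b_proj, key_states, and value_states, a previously cached prefix can be reused instead of being truncated to the current call's q_len.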