ox696c / ktransformers
Commit 5ac26608, authored Feb 17, 2025 by ceerrep

fix: use flash_attn for faster prefill

Parent: bb0ccc7b
Showing 1 changed file with 26 additions and 14 deletions.

ktransformers/operators/attention.py  (+26, -14)
@@ -125,8 +125,8 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         # compressed_kv [pages, page_size, 1, self.kv_lora_rank]
         q_absorb, out_absorb = self.get_absorbed()
-        if hasattr(self.orig_module, 'kv_b_proj'):
-            del self.orig_module.kv_b_proj
+        # if hasattr(self.orig_module, 'kv_b_proj'):
+        #     del self.orig_module.kv_b_proj
         # q_nope [bsz, self.num_heads, q_len, self.qk_nope_head_dim]
         # q_pe [bsz, self.num_heads, q_len, self.qk_rope_head_dim]
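The deletion of kv_b_proj is commented out because the flash_attn prefill path added further down needs it to up-project the cached latent compressed_kv back into per-head keys and values; decode keeps using the absorbed matrices from get_absorbed(). A minimal shape sketch of that up-projection, assuming DeepSeek-V2-style sizes (kv_lora_rank=512, 128 heads, 128-dim nope/value heads) which are illustrative assumptions rather than values read from this file:

import torch

# Assumed DeepSeek-V2-style dimensions (not taken from this diff).
bsz, kv_seq_len = 1, 16
kv_lora_rank, num_heads = 512, 128
qk_nope_head_dim, v_head_dim = 128, 128

kv_b_proj = torch.nn.Linear(kv_lora_rank, num_heads * (qk_nope_head_dim + v_head_dim), bias=False)
compressed_kv = torch.randn(bsz, kv_seq_len, kv_lora_rank)

# Mirrors kv = self.kv_b_proj(compressed_kv).view(bsz, kv_seq_len, num_heads, nope + v)
kv = kv_b_proj(compressed_kv).view(bsz, kv_seq_len, num_heads, qk_nope_head_dim + v_head_dim)
k_nope, value_states = torch.split(kv, [qk_nope_head_dim, v_head_dim], dim=-1)
print(k_nope.shape, value_states.shape)  # both (1, 16, 128, 128)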
@@ -216,13 +216,23 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         compressed_kv = self.kv_a_layernorm(compressed_kv)
         k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim)
         compressed_kv = compressed_kv.view(bsz, q_len, 1, self.kv_lora_rank)
+        kv_seq_len = q_len
+        if past_key_value is not None:
+            if self.layer_idx is None:
+                raise ValueError(
+                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
+                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
+                    "with a layer index."
+                )
+            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
         cos, sin = self.rotary_emb(q_pe, position_ids)
         q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, unsqueeze_dim=2)
         # q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim] k_pe [bsz, q_len, 1, self.qk_rope_head_dim]

         # decode
-        if q_len == 1:
+        if self.use_triton and q_len == 1:
             if past_key_value is not None:
                 cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
                 compressed_kv_with_k_pe, page_table = past_key_value.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs)
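The added kv_seq_len bookkeeping asks the cache how many positions are already stored, so the prefill branch below can slice the full key/value history rather than just the current chunk. An illustrative sketch with Hugging Face's DynamicCache, which exposes the same get_usable_length/update interface used here (ktransformers' own paged cache is what actually backs past_key_value, so treat the shapes as assumptions):

import torch
from transformers.cache_utils import DynamicCache

layer_idx, q_len = 0, 4
past_key_value = DynamicCache()

# Pretend 12 positions are already cached for this layer (bsz, heads, seq, head_dim).
past_key_value.update(torch.zeros(1, 1, 12, 64), torch.zeros(1, 1, 12, 64), layer_idx)

kv_seq_len = q_len
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, layer_idx)
print(kv_seq_len)  # 16: the new chunk plus everything already in the cache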
@@ -287,26 +297,28 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
                 cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
                 k_pe.squeeze(0)
                 compressed_kv.squeeze(0)
-                past_key_value.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs)
-                k_pe.unsqueeze(0)
-                compressed_kv.unsqueeze(0)
-                k_pe = k_pe[:, :q_len]
-                compressed_kv = compressed_kv[:, :q_len]
+                compressed_kv_with_k_pe, _ = past_key_value.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs)
+                compressed_kv, k_pe = torch.split(
+                    compressed_kv_with_k_pe, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
+                k_pe = k_pe.view(bsz, -1, self.qk_rope_head_dim)
+                k_pe = k_pe[:, :kv_seq_len]
+                compressed_kv = compressed_kv.view(bsz, -1, self.kv_lora_rank)
+                compressed_kv = compressed_kv[:, :kv_seq_len]
             kv = (
                 self.kv_b_proj(compressed_kv)
-                .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
+                .view(bsz, kv_seq_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
             )
             k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
             query_states = k_pe.new_empty(bsz, q_len, self.num_heads, self.q_head_dim)
             query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
             query_states[:, :, :, self.qk_nope_head_dim :] = q_pe
-            key_states = k_pe.new_empty(bsz, q_len, self.num_heads, self.q_head_dim)
+            key_states = k_pe.new_empty(bsz, kv_seq_len, self.num_heads, self.q_head_dim)
             key_states[:, :, :, :self.qk_nope_head_dim] = k_nope
-            key_states[:, :, :, self.qk_nope_head_dim:] = k_pe
-            value_states = value_states.view(bsz, q_len, self.num_heads, self.v_head_dim)
+            key_states[:, :, :, self.qk_nope_head_dim:] = k_pe.view(bsz, kv_seq_len, 1, -1)
+            value_states = value_states.view(bsz, kv_seq_len, self.num_heads, self.v_head_dim)
             value_states_padded = torch.nn.functional.pad(value_states, [0, query_states.shape[-1] - value_states.shape[-1]], value=0)

             attn_output = flash_attn_func(
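The prefill path zero-pads value_states from v_head_dim up to the query head dimension because the flash-attention kernel expects Q, K and V to share one head dimension; the flash_attn_func call itself is truncated in this view, so its exact arguments are not shown. A CPU-runnable sketch of the pad-then-slice pattern, using torch's scaled_dot_product_attention as a stand-in for flash_attn_func (the stand-in and the causal setting are assumptions; the shapes follow the diff):

import torch
import torch.nn.functional as F

bsz, q_len, kv_seq_len, num_heads = 1, 4, 4, 8
q_head_dim, v_head_dim = 192, 128  # DeepSeek-V2-style sizes, assumed for illustration

query_states = torch.randn(bsz, q_len, num_heads, q_head_dim)
key_states = torch.randn(bsz, kv_seq_len, num_heads, q_head_dim)
value_states = torch.randn(bsz, kv_seq_len, num_heads, v_head_dim)

# Zero-pad V on its last dim so Q, K and V share one head dim, as in the diff.
value_states_padded = F.pad(value_states, [0, query_states.shape[-1] - value_states.shape[-1]], value=0)

# flash_attn_func takes (bsz, seqlen, heads, dim); SDPA wants (bsz, heads, seqlen, dim).
attn_output = F.scaled_dot_product_attention(
    query_states.transpose(1, 2),
    key_states.transpose(1, 2),
    value_states_padded.transpose(1, 2),
    is_causal=True,
).transpose(1, 2)

# Slice the padding back off so downstream projections see the real v_head_dim.
attn_output = attn_output[..., :v_head_dim]
print(attn_output.shape)  # torch.Size([1, 4, 8, 128])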
@@ -403,7 +415,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        if not self.use_triton:  # os.name == 'nt'
+        if os.name == 'nt' or hidden_states.shape[1] == 1:  # Use in decode
             return self.forward_windows(
                 hidden_states,
                 attention_mask,
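With this change, forward() routes on the platform and the chunk length: single-token decode steps (and Windows) stay on the existing forward_windows path, while multi-token prefill chunks take the new flash_attn path. A tiny sketch of that routing rule (the name of the prefill handler is not visible in this excerpt, so it is left as a placeholder string):

import os
import torch

def pick_path(hidden_states: torch.Tensor) -> str:
    # Mirrors the new condition: Windows or a single-token (decode) step falls back.
    if os.name == 'nt' or hidden_states.shape[1] == 1:
        return "forward_windows"
    return "flash_attn prefill path"

print(pick_path(torch.zeros(1, 1, 16)))    # decode step   -> forward_windows
print(pick_path(torch.zeros(1, 512, 16)))  # prefill chunk -> flash_attn prefill path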