OpenDAS / ktransformers · Commits

Commit c51818c3 (unverified)
Authored Mar 15, 2025 by Atream; committed Mar 15, 2025 by GitHub

Merge pull request #902 from kvcache-ai/rollback-triton-prefill

rollback-triton-prefill

Parents: bda9cf15, 3934b9df
Showing 1 changed file with 8 additions and 17 deletions: ktransformers/operators/attention.py (+8, -17)
ktransformers/operators/attention.py (view file @ c51818c3)
@@ -325,27 +325,18 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         key_states[:, :, :, self.qk_nope_head_dim:] = k_pe.view(bsz, kv_seq_len, 1, -1)
         value_states = value_states.view(bsz, kv_seq_len, self.num_heads, self.v_head_dim)
+        value_states_padded = torch.nn.functional.pad(value_states, [0, query_states.shape[-1] - value_states.shape[-1]], value=0)
-        # for bsz = 1
-        attn_output = torch.zeros(bsz * q_len, self.num_heads, self.v_head_dim, device=hidden_states.device)
-        b_start_loc = torch.zeros(bsz, dtype=torch.int64, device=hidden_states.device)
-        b_seq_len = torch.full((bsz,), q_len, dtype=torch.int64, device=hidden_states.device)
-        max_input_len = q_len
-        context_attention_fwd(q=query_states.squeeze(0).view(-1, self.num_heads, self.q_head_dim),
-                              k=key_states.squeeze(0).view(-1, self.num_heads, self.q_head_dim),
-                              v=value_states.squeeze(0).view(-1, self.num_heads, self.v_head_dim),
-                              o=attn_output,
-                              b_start_loc=b_start_loc,
-                              b_seq_len=b_seq_len,
-                              max_input_len=max_input_len,
-                              is_causal=True)
+        attn_output = flash_attn_func(
+            query_states,
+            key_states,
+            value_states_padded,
+            softmax_scale=self.softmax_scale,
+            causal=True,
+        )
         if self.q_head_dim != self.v_head_dim:
-            attn_output = attn_output[:, :, :self.v_head_dim]
+            attn_output = attn_output[:, :, :, :self.v_head_dim]
         attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
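Context for the change above: the merge rolls prefill attention back from the Triton context_attention_fwd kernel to flash_attn_func. FlashAttention expects query, key, and value to share one head dimension, while DeepSeek-V2's MLA heads have q_head_dim larger than v_head_dim, so the value tensor is zero-padded up to q_head_dim before the call and the extra columns are sliced off afterwards ([:, :, :, :self.v_head_dim]). The padding is lossless: each attention output row is a weighted average of value rows, so the padded columns stay exactly zero. A minimal plain-PyTorch sketch of that pad-then-slice equivalence (illustrative names; a reference computation, not the fused flash-attn kernel):

import torch

def padded_attention_reference(q, k, v, softmax_scale):
    # q, k: (bsz, seq, heads, q_head_dim); v: (bsz, seq, heads, v_head_dim),
    # with v_head_dim <= q_head_dim. Illustrative reference, not flash-attn.
    q_head_dim, v_head_dim = q.shape[-1], v.shape[-1]

    # Zero-pad V's head dim up to q_head_dim, mirroring the F.pad call in the diff.
    v_padded = torch.nn.functional.pad(v, [0, q_head_dim - v_head_dim], value=0)

    # Plain causal softmax attention in (bsz, heads, seq, dim) layout,
    # matching causal=True in the flash_attn_func call.
    qt, kt, vt = (t.transpose(1, 2) for t in (q, k, v_padded))
    scores = torch.matmul(qt, kt.transpose(-1, -2)) * softmax_scale
    mask = torch.triu(torch.ones(q.shape[1], k.shape[1], dtype=torch.bool,
                                 device=q.device), diagonal=1)
    scores = scores.masked_fill(mask, float("-inf"))
    out = torch.matmul(scores.softmax(dim=-1), vt).transpose(1, 2)

    # The padded columns come out as exact zeros (weighted sums of zeros),
    # so slicing recovers the unpadded result, the same slice the new code applies.
    return out[..., :v_head_dim]

This identity is why the diff can pad value_states for flash_attn_func and then truncate attn_output without changing results.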