OpenDAS / ktransformers · Commits

Commit 8974cc9d (unverified)
Authored May 14, 2025 by wang jiahao; committed via GitHub on May 14, 2025
Parents: 4e015ccc 64742bec

Merge pull request #1297 from kvcache-ai/fix-workspace-buffer

update torch MLA kernel

Showing 1 changed file with 167 additions and 0 deletions.

ktransformers/operators/balance_serve_attention.py (+167, -0)
@@ -288,3 +288,170 @@ class KQwen3MoeAttention(BaseInjectedModule, Qwen3MoeAttention):
        attn_output = self.o_proj(attn_output.view(q_len, self.num_heads * self.head_dim), bsz_tensors)
        return attn_output

class deepseek_torch_attn(BaseInjectedModule, DeepseekV2Attention):
    def __init__(self,
                 key: str,
                 gguf_loader: GGUFLoader,
                 config: PretrainedConfig,
                 orig_module: nn.Module,
                 prefill_device: str = "cuda",
                 generate_device: str = "cuda",
                 chunck_size: int = 1000,
                 **kwargs):
        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
        self.orig_module.__init__(orig_module.config, orig_module.layer_idx)
        self.chunck_size = chunck_size  # TODO: generate chunck_size automatically.

    def get_absorbed(self) -> Tuple[torch.Tensor, torch.Tensor]:
        if not (hasattr(self, 'q_absorb') and hasattr(self, 'out_absorb')):
            kv_b_proj = self.kv_b_proj.weight.view(self.num_heads, -1, self.kv_lora_rank)
            q_absorb = kv_b_proj[:, :self.qk_nope_head_dim, :].reshape(-1, self.kv_lora_rank)
            out_absorb = kv_b_proj[:, self.qk_nope_head_dim:, :].reshape(-1, self.kv_lora_rank)
            self.q_absorb = nn.Linear(self.kv_lora_rank, self.num_heads * self.qk_nope_head_dim,
                                      bias=False, dtype=q_absorb.dtype, device=q_absorb.device)
            self.q_absorb.weight.data = q_absorb
            self.out_absorb = nn.Linear(self.kv_lora_rank, self.num_heads * self.v_head_dim,
                                        bias=False, dtype=out_absorb.dtype, device=out_absorb.device)
            self.out_absorb.weight.data = out_absorb
            # del self.orig_module.kv_b_proj
        q_absorb = self.q_absorb.weight.view(self.num_heads, self.qk_nope_head_dim, self.kv_lora_rank)
        out_absorb = self.out_absorb.weight.view(self.num_heads, self.v_head_dim, self.kv_lora_rank)
        return q_absorb, out_absorb
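
    # What get_absorbed() sets up: the two halves of kv_b_proj's weight are the
    # per-head up-projections for nope-keys and values (W_UK / W_UV in DeepSeek-V2's
    # notation). Keeping them "absorbed" lets attention run directly in the
    # compressed kv_lora_rank(512) space: scores are computed from (q_nope @ W_UK)
    # against the cached latent, and W_UV is applied only after the softmax-weighted
    # sum, so full per-head keys and values are never materialized.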

    def forward(self,
                hidden_states: torch.Tensor,
                kv_cache: KDeepSeekV3Cache,
                position_ids: torch.Tensor,
                wrapper: None,
                num_tokens_tensors: torch.Tensor,
                page_idx: torch.Tensor,
                page_offset: torch.Tensor,
                attention_masks: Optional[list[torch.Tensor]] = None,
                q_indptr: Optional[torch.Tensor] = None,
                kv_indices: Optional[torch.Tensor] = None,
                kv_indptr: Optional[torch.Tensor] = None,
                bsz_tensors: Optional[torch.Tensor] = None,
                last_page_len: Optional[torch.Tensor] = None,
                ):
        # range over bsz_tensors
        final_attention_output = torch.tensor([], device=hidden_states.device)
        for i in range(bsz_tensors[0]):
            batch_num_tokens_tensors = q_indptr[i + 1] - q_indptr[i]
            batch_last_page_len = last_page_len[i]
            # kv_total_len is kv_len, batch_compressed_kv is compressed_kv, batch_k_pe is k_pe
            batch_page_idx = page_idx[q_indptr[i]:q_indptr[i + 1]]
            batch_page_offset = page_offset[q_indptr[i]:q_indptr[i + 1]]
            # kv_page_nums is the number of pages for the current batch
            kv_page_nums = kv_indptr[i + 1] - kv_indptr[i]
            # kv_total_len is the total length of the kv cache for the current batch (kv_len for the algorithm)
            kv_total_len = kv_page_nums * kv_cache.page_size
            if batch_last_page_len is not None:
                kv_total_len = kv_total_len - (kv_cache.page_size - batch_last_page_len)
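            # Worked example (hypothetical numbers): with page_size = 64, 3 pages
            # and batch_last_page_len = 10, kv_total_len = 3 * 64 - (64 - 10) = 138.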
            # print(f"kv_total_len's shape {kv_total_len.shape}")
            # kv_index is the index of the kv cache pages for the current batch
            kv_index = kv_indices[kv_indptr[i]:kv_indptr[i + 1]]
            # we can index [kv_index, page_offset_indices] to get the kv cache for the current batch
            # from q_indptr[i] to q_indptr[i+1] is the range of the current batch
            batch_hidden_states = hidden_states[q_indptr[i]:q_indptr[i + 1]]
            batch_position_ids = position_ids[q_indptr[i]:q_indptr[i + 1]]
            q_len, _ = batch_hidden_states.size()
            # print("q_len -> ", q_len)
            if self.q_lora_rank is None:
                q = self.q_proj(batch_hidden_states, batch_num_tokens_tensors)
            else:
                q = self.q_b_proj(
                    self.q_a_layernorm(
                        self.q_a_proj(batch_hidden_states, batch_num_tokens_tensors),
                        batch_num_tokens_tensors),
                    batch_num_tokens_tensors)
            # for v3: bsz, q_len, num_heads(128), qk_head_dim(192 = 128(nope) + 64(rope))
            q = q.view(q_len, self.num_heads, self.q_head_dim)
            # q_nope is [q_len, num_heads(128), qk_nope_head_dim(128)]
            # q_pe is [q_len, num_heads(128), qk_rope_head_dim(64)]
            q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
            # compressed_kv is [q_len, kv_lora_rank(512) + rope(64)]
            compressed_kv = self.kv_a_proj_with_mqa(batch_hidden_states, batch_num_tokens_tensors)
            # compressed_kv is [q_len, kv_lora_rank(512)], k_pe is [q_len, rope(64)]
            compressed_kv, k_pe = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
            compressed_kv = compressed_kv.contiguous()
            compressed_kv = self.kv_a_layernorm(compressed_kv, batch_num_tokens_tensors)
            # k_pe is [q_len, 1, qk_rope_head_dim(64)]
            k_pe = k_pe.view(q_len, 1, self.qk_rope_head_dim)
            # compressed_kv is [q_len, 1, kv_lora_rank(512)]
            compressed_kv = compressed_kv.view(q_len, 1, self.kv_lora_rank)
            cos, sin = self.rotary_emb(q_pe, batch_position_ids.unsqueeze(0))
            # print(f"q_pe shape {q_pe.shape}, k_pe shape {k_pe.shape}")
            q_pe, k_pe = apply_rotary_pos_emb(q_pe.unsqueeze(0), k_pe.unsqueeze(0), cos, sin, unsqueeze_dim=2)
            q_pe = q_pe.squeeze(0)
            # q_pe is [num_heads(128), q_len, qk_rope_head_dim(64)]
            q_pe.transpose_(0, 1)
            if kv_cache is not None:
                # Specific to RoPE models
                cache_kwargs = {"sin": sin, "cos": cos, "page_idx": batch_page_idx, "page_offset": batch_page_offset}
                compressed_kv_with_k_pe = kv_cache.update(compressed_kv.unsqueeze(0), k_pe, self.layer_idx,
                                                          batch_page_idx, batch_page_offset, cache_kwargs)
                compressed_kv = compressed_kv_with_k_pe[:, :, :, :self.kv_lora_rank].view(-1, kv_cache.page_size, self.kv_lora_rank)
                k_pe = compressed_kv_with_k_pe[:, :, :, self.kv_lora_rank:].view(-1, kv_cache.page_size, self.qk_rope_head_dim)
            # q_absorb is [num_heads(128), qk_nope_head_dim(128), kv_lora_rank(512)]
            # out_absorb is [num_heads(128), v_head_dim(128), kv_lora_rank(512)]; v_head_dim is also the nope dim
            q_absorb, out_absorb = self.get_absorbed()
            # q_nope is [num_heads(128), q_len, qk_nope_head_dim(128)]
            q_nope = q_nope.transpose(0, 1)  # q_len is 1, no GPU overhead, same below
            # q_nope becomes [num_heads(128), q_len, kv_lora_rank(512)]
            q_nope = torch.matmul(q_nope, q_absorb)  # batched MM
            # # q_nope is [q_len, num_heads(128), kv_lora_rank(512)]
            # q_nope = q_nope.transpose(0, 1)
            # we need to index out the compressed_kv and k_pe for the current batch
            batch_compressed_kv = None
            batch_k_pe = None
            for page_index in kv_index:
                if kv_total_len > kv_cache.page_size:
                    tmp_compressed_kv = compressed_kv[page_index, 0:kv_cache.page_size, :]
                    tmp_k_pe = k_pe[page_index, 0:kv_cache.page_size, :]
                    if batch_compressed_kv is None or batch_k_pe is None:
                        batch_compressed_kv = tmp_compressed_kv
                        batch_k_pe = tmp_k_pe
                    else:
                        batch_compressed_kv = torch.cat((batch_compressed_kv, tmp_compressed_kv), dim=0)
                        batch_k_pe = torch.cat((batch_k_pe, tmp_k_pe), dim=0)
                    kv_total_len -= kv_cache.page_size
                else:
                    tmp_compressed_kv = compressed_kv[page_index, 0:kv_total_len, :]
                    tmp_k_pe = k_pe[page_index, 0:kv_total_len, :]
                    if batch_compressed_kv is None or batch_k_pe is None:
                        batch_compressed_kv = tmp_compressed_kv
                        batch_k_pe = tmp_k_pe
                    else:
                        batch_compressed_kv = torch.cat((batch_compressed_kv, tmp_compressed_kv), dim=0)
                        batch_k_pe = torch.cat((batch_k_pe, tmp_k_pe), dim=0)
                    break
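            # Continuing the hypothetical example above: with kv_index = [7, 2, 9] and
            # kv_total_len = 138 (page_size 64), the loop takes 64 rows from page 7,
            # 64 from page 2, then the remaining 10 from page 9 before breaking,
            # yielding contiguous [138, kv_lora_rank] and [138, qk_rope_head_dim] tensors.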
            # batch_compressed_kv is [kv_total_len(k_len), kv_lora_rank(512)]
            # batch_k_pe is [kv_total_len(k_len), qk_rope_head_dim(64)]
            attention_weights = (torch.matmul(q_pe, batch_k_pe.mT)
                                 + torch.matmul(q_nope, batch_compressed_kv.mT)) * self.softmax_scale
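            # The score decomposes into a RoPE part and an absorbed nope part:
            # scores = (q_pe @ k_pe^T + (q_nope @ W_UK) @ c_kv^T) * softmax_scale,
            # which equals the usual q @ k^T because the nope key is W_UK @ c_kv.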
            # attention_weights is [num_heads(128), q_len, k_len]
            # attention_weights = attention_weights.transpose(0,1).unsqueeze(0).squeeze(-1).expand(q_len,-1,-1).transpose(0,1)
            # attention_masks[i] is [q_len, k_len]
            attention_weights = attention_weights + attention_masks[i]
            # attention_weights shape is [num_heads(128), q_len, k_len]
            attention_weights = nn.functional.softmax(attention_weights, dim=-1, dtype=torch.float32).to(q_pe.dtype)
            attn_output = torch.matmul(attention_weights, batch_compressed_kv)  # [num_heads(128), q_len, lora_rank(512)]
            # transpose out_absorb to [num_heads(128), kv_lora_rank(512), v_head_dim(128)]
            out_absorb = out_absorb.transpose(1, 2)
            # q for q_len, n for num_heads, h for v_head_dim, v for kv_lora_rank
            attn_output = torch.matmul(attn_output, out_absorb)  # [num_heads(128), q_len, v_head_dim(128)]
            attn_output = attn_output.transpose(0, 1)  # [q_len, num_heads(128), v_head_dim(128)]
            attn_output = attn_output.reshape(q_len, self.num_heads * self.v_head_dim)
            attn_output = self.o_proj(attn_output, batch_num_tokens_tensors)
            final_attention_output = torch.cat((final_attention_output, attn_output), dim=0)
        return final_attention_output
\ No newline at end of file
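
For orientation, here is a minimal, self-contained sketch of the absorbed-MLA score and weighted-sum path that deepseek_torch_attn implements above. All shapes are toy values rather than DeepSeek-V3's real ones, the tensors are random stand-ins, and RoPE application, layernorms, the linear projections, and the paged-cache bookkeeping are omitted; it illustrates the technique, not the module's API.

    # Hypothetical sketch of absorbed MLA attention. q_absorb / out_absorb stand in
    # for the W_UK / W_UV halves of kv_b_proj that get_absorbed() extracts.
    import torch

    num_heads, nope_dim, rope_dim, lora_rank, v_dim = 4, 16, 8, 32, 16
    q_len, k_len = 1, 10
    scale = (nope_dim + rope_dim) ** -0.5

    # absorbed projections (taken from kv_b_proj in the real module)
    q_absorb = torch.randn(num_heads, nope_dim, lora_rank)
    out_absorb = torch.randn(num_heads, v_dim, lora_rank)

    # inputs: decoupled query halves, plus the latent KV gathered from cache pages
    q_nope = torch.randn(num_heads, q_len, nope_dim)
    q_pe = torch.randn(num_heads, q_len, rope_dim)   # would carry RoPE in reality
    c_kv = torch.randn(k_len, lora_rank)             # compressed latent, shared by heads
    k_pe = torch.randn(k_len, rope_dim)              # rope key, shared by heads

    # scores in the compressed space: q_pe @ k_pe^T + (q_nope @ W_UK) @ c_kv^T
    q_lora = torch.matmul(q_nope, q_absorb)          # [heads, q_len, lora_rank]
    scores = (torch.matmul(q_pe, k_pe.mT) + torch.matmul(q_lora, c_kv.mT)) * scale
    weights = torch.softmax(scores, dim=-1)          # [heads, q_len, k_len]

    # the weighted sum stays in lora_rank space; W_UV is applied only afterwards
    ctx = torch.matmul(weights, c_kv)                          # [heads, q_len, lora_rank]
    out = torch.matmul(ctx, out_absorb.transpose(1, 2))        # [heads, q_len, v_dim]
    print(out.transpose(0, 1).reshape(q_len, num_heads * v_dim).shape)

Because the softmax-weighted sum contracts against c_kv rather than per-head values, the cache holds only lora_rank + rope_dim numbers per token, which is the point of the MLA design this kernel serves.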