OpenDAS / ktransformers · Commits

Commit 64742bec, authored May 14, 2025 by qiyuxinlin
parent e8e83308

    update torch MLA kernel

Showing 1 changed file with 167 additions and 0 deletions:
ktransformers/operators/balance_serve_attention.py  (+167, -0)
...
@@ -288,3 +288,170 @@ class KQwen3MoeAttention(BaseInjectedModule, Qwen3MoeAttention):
        attn_output = self.o_proj(attn_output.view(q_len, self.num_heads * self.head_dim), bsz_tensors)
        return attn_output

class deepseek_torch_attn(BaseInjectedModule, DeepseekV2Attention):
    def __init__(self,
                 key: str,
                 gguf_loader: GGUFLoader,
                 config: PretrainedConfig,
                 orig_module: nn.Module,
                 prefill_device: str = "cuda",
                 generate_device: str = "cuda",
                 chunck_size: int = 1000,
                 **kwargs):
        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
        self.orig_module.__init__(orig_module.config, orig_module.layer_idx)
        self.chunck_size = chunck_size  # TODO, generate chunck_size automatically.
    def get_absorbed(self) -> Tuple[torch.Tensor, torch.Tensor]:
        if not (hasattr(self, 'q_absorb') and hasattr(self, 'out_absorb')):
            kv_b_proj = self.kv_b_proj.weight.view(self.num_heads, -1, self.kv_lora_rank)
            q_absorb = kv_b_proj[:, :self.qk_nope_head_dim, :].reshape(-1, self.kv_lora_rank)
            out_absorb = kv_b_proj[:, self.qk_nope_head_dim:, :].reshape(-1, self.kv_lora_rank)
            self.q_absorb = nn.Linear(self.kv_lora_rank, self.num_heads * self.qk_nope_head_dim,
                                      bias=False, dtype=q_absorb.dtype, device=q_absorb.device)
            self.q_absorb.weight.data = q_absorb
            self.out_absorb = nn.Linear(self.kv_lora_rank, self.num_heads * self.v_head_dim,
                                        bias=False, dtype=out_absorb.dtype, device=out_absorb.device)
            self.out_absorb.weight.data = out_absorb
            # del self.orig_module.kv_b_proj
        q_absorb = self.q_absorb.weight.view(self.num_heads, self.qk_nope_head_dim, self.kv_lora_rank)
        out_absorb = self.out_absorb.weight.view(self.num_heads, self.v_head_dim, self.kv_lora_rank)
        return q_absorb, out_absorb
    def forward(self,
                hidden_states: torch.Tensor,
                kv_cache: KDeepSeekV3Cache,
                position_ids: torch.Tensor,
                wrapper: None,
                num_tokens_tensors: torch.Tensor,
                page_idx: torch.Tensor,
                page_offset: torch.Tensor,
                attention_masks: Optional[list[torch.Tensor]] = None,
                q_indptr: Optional[torch.Tensor] = None,
                kv_indices: Optional[torch.Tensor] = None,
                kv_indptr: Optional[torch.Tensor] = None,
                bsz_tensors: Optional[torch.Tensor] = None,
                last_page_len: Optional[torch.Tensor] = None,
                ):
        # range bsz_tensors
        final_attention_output = torch.tensor([], device=hidden_states.device)
        for i in range(bsz_tensors[0]):
            batch_num_tokens_tensors = q_indptr[i + 1] - q_indptr[i]
            batch_last_page_len = last_page_len[i]
            # kv_total_len is kv_len, batch_compressed_kv is compressed_kv, batch_k_pe is k_pe
            batch_page_idx = page_idx[q_indptr[i]:q_indptr[i + 1]]
            batch_page_offset = page_offset[q_indptr[i]:q_indptr[i + 1]]
            # kv_page_nums is the number of pages for the current batch
            kv_page_nums = kv_indptr[i + 1] - kv_indptr[i]
            # kv_total_len is the total length of the kv cache for the current batch (kv_len for the algorithm)
            kv_total_len = kv_page_nums * kv_cache.page_size
            if batch_last_page_len is not None:
                kv_total_len = kv_total_len - (kv_cache.page_size - batch_last_page_len)
            # print(f"kv_total_len's shape {kv_total_len.shape}")
            # kv_index is the index of the kv cache pages for the current batch
            kv_index = kv_indices[kv_indptr[i]:kv_indptr[i + 1]]
            # we can index [kv_index, page_offset_indices] to get the kv cache for the current batch
            # from q_indptr[i] to q_indptr[i+1] is the range of the current batch
            batch_hidden_states = hidden_states[q_indptr[i]:q_indptr[i + 1]]
            batch_position_ids = position_ids[q_indptr[i]:q_indptr[i + 1]]
            q_len, _ = batch_hidden_states.size()
            # print("q_len -> ", q_len)

            if self.q_lora_rank is None:
                q = self.q_proj(batch_hidden_states, batch_num_tokens_tensors)
            else:
                q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(batch_hidden_states, batch_num_tokens_tensors),
                                                     batch_num_tokens_tensors),
                                  batch_num_tokens_tensors)
            # for v3: bsz, q_len, num_heads(128), qk_head_dim(192 = 128(nope) + 64(rope))
            q = q.view(q_len, self.num_heads, self.q_head_dim)
            # q_nope is [q_len, num_heads(128), qk_nope_head_dim(128)]
            # q_pe is [q_len, num_heads(128), qk_rope_head_dim(64)]
            q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
            # compressed_kv is [q_len, kv_lora_rank(512) + rope(64)]
            compressed_kv = self.kv_a_proj_with_mqa(batch_hidden_states, batch_num_tokens_tensors)
            # compressed_kv is [q_len, kv_lora_rank(512)], k_pe is [q_len, rope(64)]
            compressed_kv, k_pe = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
            compressed_kv = compressed_kv.contiguous()
            compressed_kv = self.kv_a_layernorm(compressed_kv, batch_num_tokens_tensors)
            # k_pe is [q_len, 1, qk_rope_head_dim(64)]
            k_pe = k_pe.view(q_len, 1, self.qk_rope_head_dim)
            # compressed_kv is [q_len, 1, kv_lora_rank(512)]
            compressed_kv = compressed_kv.view(q_len, 1, self.kv_lora_rank)

            cos, sin = self.rotary_emb(q_pe, batch_position_ids.unsqueeze(0))
            # print(f"q_pe shape {q_pe.shape}, k_pe shape {k_pe.shape}")
            q_pe, k_pe = apply_rotary_pos_emb(q_pe.unsqueeze(0), k_pe.unsqueeze(0), cos, sin, unsqueeze_dim=2)
            q_pe = q_pe.squeeze(0)
            # q_pe is [num_heads(128), q_len, qk_rope_head_dim(64)]
            q_pe.transpose_(0, 1)

            if kv_cache is not None:
                cache_kwargs = {"sin": sin, "cos": cos,
                                "page_idx": batch_page_idx, "page_offset": batch_page_offset}  # Specific to RoPE models
                compressed_kv_with_k_pe = kv_cache.update(compressed_kv.unsqueeze(0), k_pe, self.layer_idx,
                                                          batch_page_idx, batch_page_offset, cache_kwargs)
                compressed_kv = compressed_kv_with_k_pe[:, :, :, :self.kv_lora_rank].view(-1, kv_cache.page_size, self.kv_lora_rank)
                k_pe = compressed_kv_with_k_pe[:, :, :, self.kv_lora_rank:].view(-1, kv_cache.page_size, self.qk_rope_head_dim)
            # q_absorb is [num_heads(128), qk_nope_head_dim(128), kv_lora_rank(512)]
            # out_absorb is [num_heads(128), v_head_dim(128), kv_lora_rank(512)]; v_head_dim is also the nope dim
            q_absorb, out_absorb = self.get_absorbed()
            # q_nope is [num_heads(128), q_len, qk_nope_head_dim(128)]
            q_nope = q_nope.transpose(0, 1)  # q_len is 1, no GPU overhead, same below
            # q_nope is [num_heads(128), q_len, kv_lora_rank(512)]
            q_nope = torch.matmul(q_nope, q_absorb)  # batched MM
            # # q_nope is [q_len, num_heads(128), kv_lora_rank(512)]
            # q_nope = q_nope.transpose(0, 1)

            # we need to index out the compressed_kv and k_pe for the current batch
            batch_compressed_kv = None
            batch_k_pe = None
            for page_index in kv_index:
                if kv_total_len > kv_cache.page_size:
                    tmp_compressed_kv = compressed_kv[page_index, 0:kv_cache.page_size, :]
                    tmp_k_pe = k_pe[page_index, 0:kv_cache.page_size, :]
                    if batch_compressed_kv is None or batch_k_pe is None:
                        batch_compressed_kv = tmp_compressed_kv
                        batch_k_pe = tmp_k_pe
                    else:
                        batch_compressed_kv = torch.cat((batch_compressed_kv, tmp_compressed_kv), dim=0)
                        batch_k_pe = torch.cat((batch_k_pe, tmp_k_pe), dim=0)
                    kv_total_len -= kv_cache.page_size
                else:
                    tmp_compressed_kv = compressed_kv[page_index, 0:kv_total_len, :]
                    tmp_k_pe = k_pe[page_index, 0:kv_total_len, :]
                    if batch_compressed_kv is None or batch_k_pe is None:
                        batch_compressed_kv = tmp_compressed_kv
                        batch_k_pe = tmp_k_pe
                    else:
                        batch_compressed_kv = torch.cat((batch_compressed_kv, tmp_compressed_kv), dim=0)
                        batch_k_pe = torch.cat((batch_k_pe, tmp_k_pe), dim=0)
                    break

            # batch_compressed_kv is [kv_total_len(k_len), kv_lora_rank(512)]
            # batch_k_pe is [kv_total_len(k_len), qk_rope_head_dim(64)]
            attention_weights = (torch.matmul(q_pe, batch_k_pe.mT)
                                 + torch.matmul(q_nope, batch_compressed_kv.mT)) * self.softmax_scale
            # attention_weights is [num_heads(128), q_len, k_len]
            # attention_weights = attention_weights.transpose(0,1).unsqueeze(0).squeeze(-1).expand(q_len,-1,-1).transpose(0,1)
            # attention_masks[i] is [q_len, k_len]
            attention_weights = (attention_weights + attention_masks[i])
            # attention_weights shape is [num_heads(128), q_len, k_len]
            attention_weights = nn.functional.softmax(attention_weights, dim=-1, dtype=torch.float32).to(q_pe.dtype)
            attn_output = torch.matmul(attention_weights, batch_compressed_kv)  # [num_heads(128), q_len, kv_lora_rank(512)]
            # out_absorb is [num_heads(128), v_head_dim(128), kv_lora_rank(512)]; transpose it to
            # [num_heads(128), kv_lora_rank(512), v_head_dim(128)] so it can multiply attn_output
            out_absorb = out_absorb.transpose(1, 2)
            # q for q_len, n for num_heads, h for v_head_dim, v for kv_lora_rank
            attn_output = torch.matmul(attn_output, out_absorb)  # [num_heads(128), q_len, v_head_dim(128)]
            attn_output = attn_output.transpose(0, 1)  # [q_len, num_heads(128), v_head_dim(128)]
            attn_output = attn_output.reshape(q_len, self.num_heads * self.v_head_dim)
            attn_output = self.o_proj(attn_output, batch_num_tokens_tensors)

            final_attention_output = torch.cat((final_attention_output, attn_output), dim=0)

        return final_attention_output
\ No newline at end of file
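
Note on the matrix-absorption trick behind get_absorbed: kv_b_proj decompresses the 512-dim latent KV into per-head k_nope and v, so the nope score q_nope · k_nope^T can instead be computed by folding the k_nope half of kv_b_proj into the query and scoring directly against the compressed cache, which is why this kernel never materializes full keys. A minimal, self-contained sketch of that equivalence follows; the head count and dimensions are illustrative placeholders, not values read from the DeepSeek config.

    import torch

    # Illustrative sizes only (DeepSeek-V3 uses 128 heads, 128 nope dims, 512 lora rank).
    num_heads, qk_nope_head_dim, v_head_dim, kv_lora_rank, k_len = 4, 8, 8, 16, 5

    # kv_b_proj maps the compressed KV (kv_lora_rank) to per-head k_nope and v.
    w_kv_b = torch.randn(num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank)
    w = w_kv_b.view(num_heads, qk_nope_head_dim + v_head_dim, kv_lora_rank)
    w_uk, w_uv = w[:, :qk_nope_head_dim, :], w[:, qk_nope_head_dim:, :]   # the two "absorbed" halves

    c_kv = torch.randn(k_len, kv_lora_rank)                # compressed KV cache entries
    q_nope = torch.randn(num_heads, 1, qk_nope_head_dim)   # one query token per head

    # Naive path: decompress the keys, then score against full k_nope.
    k_nope = torch.einsum('kr,hdr->hkd', c_kv, w_uk)       # [heads, k_len, nope_dim]
    scores_naive = torch.matmul(q_nope, k_nope.transpose(1, 2))

    # Absorbed path (what the kernel does): fold w_uk into the query, score against c_kv directly.
    q_absorbed = torch.matmul(q_nope, w_uk)                # [heads, 1, kv_lora_rank]
    scores_absorbed = torch.matmul(q_absorbed, c_kv.T)     # [heads, 1, k_len]

    print(torch.allclose(scores_naive, scores_absorbed, atol=1e-5))  # True

The same idea applies on the output side: out_absorb keeps the value half of kv_b_proj so attention output is computed in the 512-dim latent space and only projected up to v_head_dim at the end.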
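
For readers unfamiliar with the paged layout consumed by forward: kv_indices lists the cache pages owned by each request, kv_indptr delimits each request's slice of that list, and last_page_len gives how many slots of the final page are filled; the inner for-loop concatenates those pages back into a contiguous [kv_len, kv_lora_rank] tensor. A small standalone sketch of the same gather, with made-up page contents and sizes (names here are illustrative and not the balance_serve cache API):

    import torch

    page_size, kv_lora_rank = 4, 8
    num_pages = 6

    # A toy paged cache: page i, slot j holds an 8-dim compressed-KV vector.
    paged_kv = torch.arange(num_pages * page_size * kv_lora_rank, dtype=torch.float32)
    paged_kv = paged_kv.view(num_pages, page_size, kv_lora_rank)

    # Two requests: request 0 owns pages [2, 5], request 1 owns page [0].
    kv_indices = torch.tensor([2, 5, 0])
    kv_indptr = torch.tensor([0, 2, 3])      # request i owns kv_indices[kv_indptr[i]:kv_indptr[i+1]]
    last_page_len = torch.tensor([3, 2])     # filled slots in each request's last page

    def gather_request_kv(i: int) -> torch.Tensor:
        """Concatenate request i's pages into a contiguous [kv_len, kv_lora_rank] tensor."""
        pages = kv_indices[kv_indptr[i]:kv_indptr[i + 1]]
        kv_total_len = len(pages) * page_size - (page_size - last_page_len[i])
        chunks = []
        for page_index in pages:
            take = min(page_size, int(kv_total_len))   # full pages first, partial last page
            chunks.append(paged_kv[page_index, :take, :])
            kv_total_len -= take
        return torch.cat(chunks, dim=0)

    print(gather_request_kv(0).shape)  # torch.Size([7, 8])  (one full page + 3 slots)
    print(gather_request_kv(1).shape)  # torch.Size([2, 8])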