OpenDAS / tilelang · Commit b7ca76f1
"git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "d452a1665f802bfbe75372b6c942272b252c70a2"
Authored Feb 26, 2025 by Yu Cheng; committed by GitHub, Feb 26, 2025
[Dev] Update MLA decode kernel (#120)
Parent: 524991fe
Showing 1 changed file with 84 additions and 62 deletions.
examples/flash_decoding/example_mla_decode.py (+84, -62, view file @ b7ca76f1)
@@ -3,17 +3,13 @@ import torch.nn.functional as F
 import tilelang
 from tilelang.autotuner import *
 import tilelang.language as T
+from einops import rearrange, einsum
 
-num_split = 4
+num_split = 1
 
 
 def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H):
     scale = (1.0 / (dim + pe_dim))**0.5 * 1.44269504  # log2(e)
-    shape_q = [batch, heads, (dim + pe_dim)]
-    shape_k = [batch, seqlen_kv, kv_head_num, (dim + pe_dim)]
-    shape_v = [batch, seqlen_kv, kv_head_num, dim]
-    shape_o = [batch, heads, dim]
-    part_shape = [batch, heads, num_split, dim]
     dtype = "float16"
     accum_dtype = "float"
     kv_group_num = heads // kv_head_num
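Note (not part of the commit): the unchanged scale line folds log2(e) = 1.44269504 into the 1/sqrt(dim + pe_dim) softmax scale so the kernel can use exp2 (see the scores_scale update further down) instead of exp. A minimal sketch of that identity, with illustrative sizes:

import math
import torch

dim, pe_dim = 512, 64                                   # head sizes used elsewhere in this example
scale = (1.0 / (dim + pe_dim))**0.5 * 1.44269504        # log2(e) folded into the softmax scale
softmax_scale = 1.0 / math.sqrt(dim + pe_dim)

scores = torch.randn(8)                                 # hypothetical raw attention scores
# 2 ** (x * log2(e)) == e ** x, so exp2 with the folded scale matches exp with the plain scale
lhs = torch.exp2(scores * scale)
rhs = torch.exp(scores * softmax_scale)
assert torch.allclose(lhs, rhs)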
@@ -22,19 +18,23 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
 
     @T.macro
     def flash_attn_split(
-            Q: T.Buffer(shape_q, dtype),
-            K: T.Buffer(shape_k, dtype),
-            V: T.Buffer(shape_v, dtype),
+            Q: T.Buffer([batch, heads, dim], dtype),
+            Q_pe: T.Buffer([batch, heads, pe_dim], dtype),
+            KV: T.Buffer([batch, seqlen_kv, kv_head_num, dim], dtype),
+            K_pe: T.Buffer([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
             glse: T.Buffer([batch, heads, num_split], dtype),
-            Output_partial: T.Buffer(part_shape, dtype),
+            Output_partial: T.Buffer([batch, heads, num_split, dim], dtype),
     ):
         with T.Kernel(
-                batch, heads // min(block_H, kv_group_num), num_split, threads=128) as (bx, by, bz):
-            Q_shared = T.alloc_shared([block_H, (dim + pe_dim)], dtype)
-            K_shared = T.alloc_shared([block_N, (dim + pe_dim)], dtype)
-            V_shared = T.alloc_shared([block_N, dim], dtype)
+                batch, heads // min(block_H, kv_group_num), num_split, threads=256) as (bx, by, bz):
+            Q_shared = T.alloc_shared([block_H, dim], dtype)
+            S_shared = T.alloc_shared([block_H, block_N], dtype)
+            Q_pe_shared = T.alloc_shared([block_H, pe_dim], dtype)
+            KV_shared = T.alloc_shared([block_N, dim], dtype)
+            K_pe_shared = T.alloc_shared([block_N, pe_dim], dtype)
             O_shared = T.alloc_shared([block_H, dim], dtype)
             acc_s = T.alloc_fragment([block_H, block_N], accum_dtype)
+            acc_s_0 = T.alloc_fragment([block_H, block_N], accum_dtype)
             acc_s_cast = T.alloc_fragment([block_H, block_N], dtype)
             acc_o = T.alloc_fragment([block_H, dim], accum_dtype)
             scores_max = T.alloc_fragment([block_H], accum_dtype)
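Note (not part of the commit): the macro now takes separate no-PE and PE buffers (Q/Q_pe, KV/K_pe) instead of the packed [batch, heads, dim + pe_dim] layouts. A hedged sketch of how inputs prepared for the old signature could be split for the new one, assuming the PE channels were packed last:

import torch

batch, heads, dim, pe_dim = 2, 16, 512, 64          # illustrative sizes
q_packed = torch.randn(batch, heads, dim + pe_dim)  # layout the old kernel expected

# Assumed layout: the rotary/PE channels occupy the last pe_dim positions.
q, q_pe = q_packed[..., :dim], q_packed[..., dim:]
assert q.shape == (batch, heads, dim) and q_pe.shape == (batch, heads, pe_dim)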
@@ -53,20 +53,32 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
             })
             T.copy(Q[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :], Q_shared)
+            T.copy(Q_pe[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :], Q_pe_shared)
             T.fill(acc_o, 0)
             T.fill(logsum, 0)
             T.fill(scores_max, -T.infinity(accum_dtype))
 
             loop_range = T.ceildiv((seqlen_kv // num_split), block_N)
-            for k in T.Pipelined(loop_range, num_stages=1):
-                T.copy(
-                    K[bid, (seqlen_kv // num_split) * sid +
-                      k * block_N:(seqlen_kv // num_split) * sid + (k + 1) * block_N,
-                      cur_kv_head, :], K_shared)
+            for k in T.Pipelined(loop_range, num_stages=2):
+                kv_start = (seqlen_kv // num_split) * sid + k * block_N
+                kv_end = (seqlen_kv // num_split) * sid + (k + 1) * block_N
+                T.copy(KV[bid, kv_start:kv_end, cur_kv_head, :], KV_shared)
+                T.copy(K_pe[bid, kv_start:kv_end, cur_kv_head, :], K_pe_shared)
-                T.clear(acc_s)
-                T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
+                T.clear(acc_s_0)
+                T.gemm(Q_shared, KV_shared, acc_s_0, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
+                T.gemm(Q_pe_shared, K_pe_shared, acc_s_0, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
                 T.copy(scores_max, scores_max_prev)
                 T.fill(scores_max, -T.infinity(accum_dtype))
+                T.copy(acc_s_0, S_shared)
+                T.copy(S_shared, acc_s)
                 T.reduce_max(acc_s, scores_max, dim=1, clear=False)
                 for i in T.Parallel(block_H):
                     scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
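Note (not part of the commit): accumulating the two GEMMs into acc_s_0 matches the single packed GEMM the old code performed, because [Q | Q_pe] @ [K | K_pe]^T = Q @ K^T + Q_pe @ K_pe^T. A quick check of that identity with illustrative tile sizes:

import torch

block_H, block_N, dim, pe_dim = 64, 64, 512, 64     # illustrative tile sizes
q = torch.randn(block_H, dim, dtype=torch.float64)  # float64 keeps the comparison tight
q_pe = torch.randn(block_H, pe_dim, dtype=torch.float64)
k = torch.randn(block_N, dim, dtype=torch.float64)
k_pe = torch.randn(block_N, pe_dim, dtype=torch.float64)

packed = torch.cat([q, q_pe], dim=-1) @ torch.cat([k, k_pe], dim=-1).T
split = q @ k.T + q_pe @ k_pe.T
assert torch.allclose(packed, split)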
@@ -78,11 +90,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
                 T.copy(acc_s, acc_s_cast)
                 for i, j in T.Parallel(block_H, dim):
                     acc_o[i, j] *= scores_scale[i]
-                T.copy(
-                    V[bid, (seqlen_kv // num_split) * sid +
-                      k * block_N:(seqlen_kv // num_split) * sid + (k + 1) * block_N,
-                      cur_kv_head, :], V_shared)
-                T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
+                T.gemm(acc_s_cast, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullCol)
             for i, j in T.Parallel(block_H, dim):
                 acc_o[i, j] /= logsum[i]
             for i in T.Parallel(block_H):
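Note (not part of the commit): the acc_o rescale by scores_scale inside the loop plus the final division by logsum is the usual flash-attention online softmax. A small standalone sketch of the same recurrence over two key blocks, checked against a full softmax:

import torch

torch.manual_seed(0)
scores = torch.randn(128)        # one query row against 128 keys
v = torch.randn(128, 16)         # illustrative value vectors
blocks = scores.split(64), v.split(64)

m = torch.tensor(float("-inf"))  # running max
denom = torch.tensor(0.0)        # running sum of exponentials
acc = torch.zeros(16)            # running (unnormalized) output
for s_blk, v_blk in zip(*blocks):
    m_new = torch.maximum(m, s_blk.max())
    rescale = torch.exp(m - m_new)      # shrink previous accumulator to the new max
    p = torch.exp(s_blk - m_new)
    denom = denom * rescale + p.sum()
    acc = acc * rescale + p @ v_blk
    m = m_new
out = acc / denom                       # final normalization, like acc_o /= logsum

ref = torch.softmax(scores, dim=0) @ v
assert torch.allclose(out, ref, atol=1e-5)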
@@ -96,8 +104,8 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
     @T.macro
     def combine(
             glse: T.Buffer([batch, heads, num_split], dtype),
-            Output_partial: T.Buffer(part_shape, dtype),
-            Output: T.Buffer(shape_o, dtype),
+            Output_partial: T.Buffer([batch, heads, num_split, dim], dtype),
+            Output: T.Buffer([batch, heads, dim], dtype),
     ):
         with T.Kernel(heads, batch, threads=128) as (by, bz):
             po_local = T.alloc_fragment([dim], dtype)
@@ -133,50 +141,63 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
     @T.prim_func
     def main(
-            Q: T.Buffer(shape_q, dtype),
-            K: T.Buffer(shape_k, dtype),
-            V: T.Buffer(shape_v, dtype),
+            Q: T.Buffer([batch, heads, dim], dtype),
+            Q_pe: T.Buffer([batch, heads, pe_dim], dtype),
+            KV: T.Buffer([batch, seqlen_kv, kv_head_num, dim], dtype),
+            K_pe: T.Buffer([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
             glse: T.Buffer([batch, heads, num_split], dtype),
-            Output_partial: T.Buffer(part_shape, dtype),  # [batch, heads, num_split, dim]
-            Output: T.Buffer(shape_o, dtype),
+            Output_partial: T.Buffer([batch, heads, num_split, dim], dtype),
+            Output: T.Buffer([batch, heads, dim], dtype),
     ):
-        flash_attn_split(Q, K, V, glse, Output_partial)
+        flash_attn_split(Q, Q_pe, KV, K_pe, glse, Output_partial)
         combine(glse, Output_partial, Output)
 
     return main
 
 
-def ref_program(query, key, value, glse, Output_partial):
+def ref_program(q, q_pe, kv, k_pe, glse, Output_partial):
     # """
     # Inputs:
-    # - query (Tensor): [batch, heads, dim]
-    # - key (Tensor): [batch, seqlen_kv, kv_head_num, dim]
-    # - value (Tensor): [batch, seqlen_kv, kv_head_num, dim]
+    # - q (Tensor): [batch, heads, dim]
+    # - q_pe (Tensor): [batch, heads, pe_dim]
+    # - kv (Tensor): [batch, seqlen_kv, kv_head_num, dim]
+    # - k_pe (Tensor): [batch, seqlen_kv, kv_head_num, pe_dim]
+    # - glse (Tensor): [batch, heads, num_split]
+    # - Output_partial (Tensor): [batch, heads, num_split, dim]
     # Outputs:
    # - output (Tensor): [batch, heads, dim]
     # """
-    from einops import rearrange
-    batch_size, query_heads, dim = query.shape  # [batch_size, query_heads, dim]
-    _, seqlen_kv, kv_heads, _ = key.shape  # [batch_size, seqlen_kv, kv_heads, kv_dim]
-    dim_v = value.shape[-1]
-    assert kv_heads == 1, "kv_heads must be 1"
-    query_expanded = rearrange(query, 'b h d -> b h 1 d')  # [batch_size, query_heads, 1, dim]
-    key_expanded = key.expand(-1, -1, query_heads, -1)  # [batch_size, query_heads, seqlen_kv, dim]
-    value_expanded = value.expand(-1, -1, query_heads, -1)  # [batch_size, query_heads, seqlen_kv, dim]
-    key_expanded = rearrange(key_expanded, 'b n h d -> b h n d')  # [batch_size, kv_head_num, seqlen_kv, dim]
-    value_expanded = rearrange(value_expanded, 'b n h d -> b h n d')  # [batch_size, query_heads, seqlen_kv, dim]
-    scores = torch.matmul(query_expanded, key_expanded.transpose(-1, -2))  # [batch_size, query_heads, 1, seqlen_kv]
-    scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
-    attention_weights = F.softmax(scores, dim=-1)  # [batch_size, query_heads, 1, seqlen_kv]
-    output = torch.matmul(attention_weights, value_expanded)  # [batch_size, query_heads, 1, dim]
-    return output.view(batch_size, query_heads, dim_v)
+    dim = q.shape[-1]
+    pe_dim = q_pe.shape[-1]
+    num_head_groups = q.shape[1] // kv.shape[2]
+    scale = (dim + pe_dim)**0.5
+    q = rearrange(q, 'b (h g) d -> b g h d', g=num_head_groups)  # [batch_size, num_head_groups, groups, dim]
+    q_pe = rearrange(q_pe, 'b (h g) d -> b g h d', g=num_head_groups)  # [batch_size, num_head_groups, groups, pe_dim]
+    kv = rearrange(kv, 'b n h d -> b h n d')  # [batch_size, groups, seqlen_kv, dim]
+    k_pe = rearrange(k_pe, 'b n h d -> b h n d')  # [batch_size, num_head_groups, groups, pe_dim]
+    query = torch.concat([q, q_pe], dim=-1)
+    key = torch.concat([kv, k_pe], dim=-1)
+    scores = einsum(query, key, 'b g h d, b h s d -> b g h s')  # [batch_size, num_head_groups, groups, seqlen_kv]
+    attention = F.softmax(scores / scale, dim=-1)  # [batch_size, num_head_groups, groups, seqlen_kv]
+    out = einsum(attention, kv, 'b g h s, b h s d -> b g h d')  # [batch_size, num_head_groups, groups, dim]
+    out = rearrange(out, 'b g h d -> b (h g) d')  # [batch_size, heads, dim]
+    return out
 
 
 def flash_split_ref(Q, K, V):
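Note (not part of the commit): a hedged usage sketch of the new reference with the shapes from the docstring, using small illustrative sizes and kv_head_num = 1 as the kernel assumes. glse and Output_partial are not read by ref_program, so None stands in for them; ref_program as defined above must be in scope.

import torch

batch, heads, kv_head_num, seqlen_kv, dim, pe_dim = 1, 16, 1, 256, 64, 16
q = torch.randn(batch, heads, dim)
q_pe = torch.randn(batch, heads, pe_dim)
kv = torch.randn(batch, seqlen_kv, kv_head_num, dim)
k_pe = torch.randn(batch, seqlen_kv, kv_head_num, pe_dim)

out = ref_program(q, q_pe, kv, k_pe, None, None)  # glse / Output_partial unused by the reference
assert out.shape == (batch, heads, dim)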
@@ -251,7 +272,7 @@ def reduce_ref(Q, K, V, glse, Output_partial):
 
 
 if __name__ == "__main__":
-    BATCH, H_Q, KV_H, KV_CTX, D_HEAD, DPE = 64, 128, 1, 8192, 512, 64
+    BATCH, H_Q, KV_H, KV_CTX, D_HEAD, DPE = 128, 128, 1, 8192, 512, 64
     qk_flops = 2 * BATCH * H_Q * KV_CTX * (D_HEAD + DPE)
     pv_flops = 2 * BATCH * H_Q * KV_CTX * D_HEAD
     total_flops = qk_flops + pv_flops
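Note (not part of the commit): with the new config the estimate above works out to roughly 0.29 TFLOPs of work per call; a quick check of the formulas:

BATCH, H_Q, KV_CTX, D_HEAD, DPE = 128, 128, 8192, 512, 64
qk_flops = 2 * BATCH * H_Q * KV_CTX * (D_HEAD + DPE)   # ~1.55e11
pv_flops = 2 * BATCH * H_Q * KV_CTX * D_HEAD           # ~1.37e11
total_flops = qk_flops + pv_flops
print(total_flops / 1e12)                              # ~0.292 TFLOPs per forward call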
@@ -260,8 +281,9 @@ if __name__ == "__main__":
 
     program = flashattn(BATCH, H_Q, KV_H, KV_CTX, D_HEAD, DPE, BLOCK_N, BLOCK_H)
     mod, params = tilelang.lower(program)
-    mod = tilelang.Profiler(mod, params, [5], tilelang.TensorSupplyType.Normal)
+    mod = tilelang.Profiler(mod, params, [6], tilelang.TensorSupplyType.Normal)
     mod.assert_allclose(ref_program, rtol=0.01, atol=0.01)
-    latency = mod.do_bench(mod.func, warmup=500)
+    print("All close")
+    latency = mod.do_bench(mod.func, n_warmup=10, n_repeat=10, profiler="torch")
     print("Tile-lang: {:.2f} ms".format(latency))
     print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
\ No newline at end of file