Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
tilelang
Commits
667632cc
Unverified
Commit
667632cc
authored
Dec 22, 2025
by
guchaoyang
Committed by
GitHub
Dec 22, 2025
Browse files
Merge branch 'main' into dcu
parents
d6dd2ddf
a874e4e8
Changes
313
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2093 additions
and
1382 deletions
+2093
-1382
examples/flash_decoding/example_gqa_decode.py
examples/flash_decoding/example_gqa_decode.py
+135
-142
examples/flash_decoding/example_gqa_decode_varlen_logits.py
examples/flash_decoding/example_gqa_decode_varlen_logits.py
+130
-181
examples/flash_decoding/example_gqa_decode_varlen_logits_paged.py
.../flash_decoding/example_gqa_decode_varlen_logits_paged.py
+679
-0
examples/flash_decoding/example_mha_inference.py
examples/flash_decoding/example_mha_inference.py
+73
-78
examples/fusedmoe/example_fusedmoe_tilelang.py
examples/fusedmoe/example_fusedmoe_tilelang.py
+166
-197
examples/fusedmoe/example_fusedmoe_torch.py
examples/fusedmoe/example_fusedmoe_torch.py
+37
-54
examples/fusedmoe/test_example_fusedmoe.py
examples/fusedmoe/test_example_fusedmoe.py
+2
-7
examples/gdn/example_chunk_delta_bwd.py
examples/gdn/example_chunk_delta_bwd.py
+140
-102
examples/gdn/example_chunk_delta_h.py
examples/gdn/example_chunk_delta_h.py
+123
-81
examples/gdn/example_chunk_o.py
examples/gdn/example_chunk_o.py
+50
-41
examples/gdn/example_chunk_o_bwd.py
examples/gdn/example_chunk_o_bwd.py
+99
-109
examples/gdn/example_chunk_scaled_dot_kkt.py
examples/gdn/example_chunk_scaled_dot_kkt.py
+28
-30
examples/gdn/example_cumsum.py
examples/gdn/example_cumsum.py
+20
-24
examples/gdn/example_wy_fast.py
examples/gdn/example_wy_fast.py
+35
-46
examples/gdn/example_wy_fast_bwd_split.py
examples/gdn/example_wy_fast_bwd_split.py
+119
-118
examples/gdn/test_example_gdn_compilation.py
examples/gdn/test_example_gdn_compilation.py
+210
-81
examples/gdn/test_utils.py
examples/gdn/test_utils.py
+6
-8
examples/gemm/README.md
examples/gemm/README.md
+8
-8
examples/gemm/example_gemm.py
examples/gemm/example_gemm.py
+4
-5
examples/gemm/example_gemm_autotune.py
examples/gemm/example_gemm_autotune.py
+29
-70
No files found.
Too many changes to show.
To preserve performance only
313 of 313+
files are displayed.
Plain diff
Email patch
examples/flash_decoding/example_gqa_decode.py
View file @
667632cc
...
...
@@ -15,18 +15,12 @@ torch.random.manual_seed(0)
def
get_configs
():
block_N
=
[
64
,
128
]
block_H
=
[
64
]
num_split
=
[
2
,
4
,
8
]
num_split
=
[
1
,
2
,
4
,
8
]
num_stages
=
[
1
,
2
,
3
]
threads
=
[
128
]
_configs
=
list
(
itertools
.
product
(
block_N
,
block_H
,
num_split
,
num_stages
,
threads
))
configs
=
[{
'block_N'
:
c
[
0
],
'block_H'
:
c
[
1
],
'num_split'
:
c
[
2
],
'num_stages'
:
c
[
3
],
'threads'
:
c
[
4
]
}
for
c
in
_configs
]
configs
=
[{
"block_N"
:
c
[
0
],
"block_H"
:
c
[
1
],
"num_split"
:
c
[
2
],
"num_stages"
:
c
[
3
],
"threads"
:
c
[
4
]}
for
c
in
_configs
]
return
configs
...
...
@@ -42,29 +36,25 @@ def get_heuristic_config() -> Tuple[Dict, int]:
if
sm_version
==
89
:
cfg
=
dict
(
block_N
=
128
,
block_H
=
64
,
num_split
=
1
,
num_stages
=
0
,
threads
=
128
)
else
:
cfg
=
dict
(
block_N
=
128
,
block_H
=
64
,
num_split
=
1
,
num_stages
=
2
,
threads
=
128
)
cfg
=
dict
(
block_N
=
128
,
block_H
=
64
,
num_split
=
8
,
num_stages
=
2
,
threads
=
128
)
return
cfg
,
sm_version
# TODO(lei): fix warp specialized and tma lower pass
def
get_pass_configs
():
return
{
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
}
return
{
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
}
@
autotune
(
configs
=
get_configs
(),
warmup
=
10
,
rep
=
10
)
@
tilelang
.
jit
(
out_idx
=
[
6
],
pass_configs
=
get_pass_configs
())
def
flashattn
(
batch
,
heads
,
groups
,
seqlen_kv
,
dim
,
block_N
,
block_H
,
num_split
,
num_stages
,
threads
):
scale
=
(
1.0
/
dim
)
**
0.5
*
1.44269504
# log2(e)
def
flashattn
(
batch
,
heads
,
groups
,
seqlen_kv
,
dim
,
block_N
,
block_H
,
num_split
,
num_stages
,
threads
):
scale
=
(
1.0
/
dim
)
**
0.5
*
1.44269504
# log2(e)
shape_q
=
[
batch
,
heads
,
dim
]
shape_k
=
[
batch
,
seqlen_kv
,
groups
,
dim
]
shape_v
=
[
batch
,
seqlen_kv
,
groups
,
dim
]
shape_o
=
[
batch
,
heads
,
dim
]
dtype
=
"
float16
"
accum_dtype
=
"
float
"
dtype
=
T
.
float16
accum_dtype
=
T
.
float
32
kv_group_num
=
heads
//
groups
part_shape
=
[
batch
,
heads
,
num_split
,
dim
]
...
...
@@ -73,11 +63,11 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split,
@
T
.
macro
def
flash_attn
(
Q
:
T
.
Tensor
(
shape_q
,
dtype
),
K
:
T
.
Tensor
(
shape_k
,
dtype
),
V
:
T
.
Tensor
(
shape_v
,
dtype
),
mask
:
T
.
Tensor
([
batch
,
seqlen_kv
,
groups
],
"uint8"
),
Output
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Q
:
T
.
Tensor
(
shape_q
,
dtype
),
K
:
T
.
Tensor
(
shape_k
,
dtype
),
V
:
T
.
Tensor
(
shape_v
,
dtype
),
mask
:
T
.
Tensor
([
batch
,
seqlen_kv
,
groups
],
"uint8"
),
Output
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
):
with
T
.
Kernel
(
batch
,
heads
//
valid_block_H
,
num_split
,
threads
=
threads
)
as
(
bx
,
by
,
bz
):
Q_shared
=
T
.
alloc_shared
([
block_H
,
dim
],
dtype
)
...
...
@@ -98,23 +88,24 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split,
hid
=
by
cur_kv_head
=
hid
//
(
kv_group_num
//
valid_block_H
)
T
.
copy
(
Q
[
bid
,
hid
*
valid_block_H
:
hid
*
valid_block_H
+
block_H
,
:],
Q_shared
)
T
.
copy
(
Q
[
bid
,
hid
*
valid_block_H
:
hid
*
valid_block_H
+
block_H
,
:],
Q_shared
)
T
.
fill
(
acc_o
,
0
)
T
.
fill
(
logsum
,
0
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
loop_range
=
T
.
ceildiv
((
seqlen_kv
//
num_split
),
block_N
)
for
k
in
T
.
Pipelined
(
loop_range
,
num_stages
=
num_stages
):
T
.
copy
(
K
[
bid
,
k
*
block_N
:
(
k
+
1
)
*
block_N
,
cur_kv_head
,
:],
K_shared
)
T
.
copy
(
mask
[
bid
,
k
*
block_N
:
(
k
+
1
)
*
block_N
,
cur_kv_head
],
mask_local
)
T
.
copy
(
K
[
bid
,
k
*
block_N
:
(
k
+
1
)
*
block_N
,
cur_kv_head
,
:],
K_shared
)
T
.
copy
(
mask
[
bid
,
k
*
block_N
:
(
k
+
1
)
*
block_N
,
cur_kv_head
],
mask_local
)
T
.
clear
(
acc_s
)
T
.
gemm
(
Q_shared
,
K_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
for
i
,
j
in
T
.
Parallel
(
block_H
,
block_N
):
acc_s
[
i
,
j
]
=
T
.
if_then_else
(
mask_local
[
j
]
!=
0
,
acc_s
[
i
,
j
],
-
T
.
infinity
(
accum_dtype
))
acc_s
[
i
,
j
]
=
T
.
if_then_else
(
mask_local
[
j
]
!=
0
,
acc_s
[
i
,
j
],
-
T
.
infinity
(
accum_dtype
))
T
.
copy
(
scores_max
,
scores_max_prev
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
T
.
reduce_max
(
acc_s
,
scores_max
,
dim
=
1
,
clear
=
False
)
for
i
in
T
.
Parallel
(
block_H
):
scores_max
[
i
]
=
T
.
max
(
scores_max
[
i
],
scores_max_prev
[
i
])
for
i
in
T
.
Parallel
(
block_H
):
scores_scale
[
i
]
=
T
.
exp2
(
scores_max_prev
[
i
]
*
scale
-
scores_max
[
i
]
*
scale
)
for
i
,
j
in
T
.
Parallel
(
block_H
,
block_N
):
...
...
@@ -125,23 +116,23 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split,
T
.
copy
(
acc_s
,
acc_s_cast
)
for
i
,
j
in
T
.
Parallel
(
block_H
,
dim
):
acc_o
[
i
,
j
]
*=
scores_scale
[
i
]
T
.
copy
(
V
[
bid
,
k
*
block_N
:
(
k
+
1
)
*
block_N
,
cur_kv_head
,
:],
V_shared
)
T
.
copy
(
V
[
bid
,
k
*
block_N
:
(
k
+
1
)
*
block_N
,
cur_kv_head
,
:],
V_shared
)
T
.
gemm
(
acc_s_cast
,
V_shared
,
acc_o
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
for
i
,
j
in
T
.
Parallel
(
block_H
,
dim
):
acc_o
[
i
,
j
]
/=
logsum
[
i
]
for
i
in
T
.
Parallel
(
block_H
):
logsum
[
i
]
=
T
.
log2
(
logsum
[
i
])
+
scores_max
[
i
]
*
scale
T
.
copy
(
acc_o
[:
valid_block_H
,
:],
O_shared
)
T
.
copy
(
O_shared
,
Output
[
bid
,
hid
*
valid_block_H
:
(
hid
+
1
)
*
valid_block_H
,
:])
T
.
copy
(
O_shared
,
Output
[
bid
,
hid
*
valid_block_H
:
(
hid
+
1
)
*
valid_block_H
,
:])
@
T
.
macro
def
flash_attn_split
(
Q
:
T
.
Tensor
(
shape_q
,
dtype
),
K
:
T
.
Tensor
(
shape_k
,
dtype
),
V
:
T
.
Tensor
(
shape_v
,
dtype
),
mask
:
T
.
Tensor
([
batch
,
seqlen_kv
,
groups
],
"uint8"
),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
dtype
),
Output_partial
:
T
.
Tensor
(
part_shape
,
dtype
),
Q
:
T
.
Tensor
(
shape_q
,
dtype
),
K
:
T
.
Tensor
(
shape_k
,
dtype
),
V
:
T
.
Tensor
(
shape_v
,
dtype
),
mask
:
T
.
Tensor
([
batch
,
seqlen_kv
,
groups
],
"uint8"
),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
dtype
),
Output_partial
:
T
.
Tensor
(
part_shape
,
dtype
),
):
with
T
.
Kernel
(
batch
,
heads
//
valid_block_H
,
num_split
,
threads
=
threads
)
as
(
bx
,
by
,
bz
):
Q_shared
=
T
.
alloc_shared
([
block_H
,
dim
],
dtype
)
...
...
@@ -163,7 +154,7 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split,
sid
=
bz
cur_kv_head
=
hid
//
(
kv_group_num
//
valid_block_H
)
T
.
copy
(
Q
[
bid
,
hid
*
valid_block_H
:
hid
*
valid_block_H
+
block_H
,
:],
Q_shared
)
T
.
copy
(
Q
[
bid
,
hid
*
valid_block_H
:
hid
*
valid_block_H
+
block_H
,
:],
Q_shared
)
T
.
fill
(
acc_o
,
0
)
T
.
fill
(
logsum
,
0
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
...
...
@@ -172,22 +163,31 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split,
for
k
in
T
.
Pipelined
(
loop_range
,
num_stages
=
num_stages
):
T
.
copy
(
K
[
bid
,
(
seqlen_kv
//
num_split
)
*
sid
+
k
*
valid_block_N
:(
seqlen_kv
//
num_split
)
*
sid
+
(
k
+
1
)
*
valid_block_N
,
cur_kv_head
,
:],
K_shared
)
K
[
bid
,
(
seqlen_kv
//
num_split
)
*
sid
+
k
*
valid_block_N
:
(
seqlen_kv
//
num_split
)
*
sid
+
(
k
+
1
)
*
valid_block_N
,
cur_kv_head
,
:,
],
K_shared
,
)
T
.
copy
(
mask
[
bid
,
(
seqlen_kv
//
num_split
)
*
sid
+
k
*
valid_block_N
:(
seqlen_kv
//
num_split
)
*
sid
+
(
k
+
1
)
*
valid_block_N
,
cur_kv_head
],
mask_local
)
mask
[
bid
,
(
seqlen_kv
//
num_split
)
*
sid
+
k
*
valid_block_N
:
(
seqlen_kv
//
num_split
)
*
sid
+
(
k
+
1
)
*
valid_block_N
,
cur_kv_head
,
],
mask_local
,
)
T
.
clear
(
acc_s
)
T
.
gemm
(
Q_shared
,
K_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
for
i
,
j
in
T
.
Parallel
(
block_H
,
block_N
):
acc_s
[
i
,
j
]
=
T
.
if_then_else
((
mask_local
[
j
]
!=
0
)
&
(
j
<
seqlen_kv
//
num_split
),
acc_s
[
i
,
j
],
-
T
.
infinity
(
accum_dtype
))
acc_s
[
i
,
j
]
=
T
.
if_then_else
((
mask_local
[
j
]
!=
0
)
&
(
j
<
seqlen_kv
//
num_split
),
acc_s
[
i
,
j
],
-
T
.
infinity
(
accum_dtype
))
T
.
copy
(
scores_max
,
scores_max_prev
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
T
.
reduce_max
(
acc_s
,
scores_max
,
dim
=
1
,
clear
=
False
)
for
i
in
T
.
Parallel
(
block_H
):
scores_max
[
i
]
=
T
.
max
(
scores_max
[
i
],
scores_max_prev
[
i
])
for
i
in
T
.
Parallel
(
block_H
):
scores_scale
[
i
]
=
T
.
exp2
(
scores_max_prev
[
i
]
*
scale
-
scores_max
[
i
]
*
scale
)
for
i
,
j
in
T
.
Parallel
(
block_H
,
block_N
):
...
...
@@ -199,9 +199,14 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split,
for
i
,
j
in
T
.
Parallel
(
block_H
,
dim
):
acc_o
[
i
,
j
]
*=
scores_scale
[
i
]
T
.
copy
(
V
[
bid
,
(
seqlen_kv
//
num_split
)
*
sid
+
k
*
valid_block_N
:(
seqlen_kv
//
num_split
)
*
sid
+
(
k
+
1
)
*
valid_block_N
,
cur_kv_head
,
:],
V_shared
)
V
[
bid
,
(
seqlen_kv
//
num_split
)
*
sid
+
k
*
valid_block_N
:
(
seqlen_kv
//
num_split
)
*
sid
+
(
k
+
1
)
*
valid_block_N
,
cur_kv_head
,
:,
],
V_shared
,
)
T
.
gemm
(
acc_s_cast
,
V_shared
,
acc_o
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
for
i
,
j
in
T
.
Parallel
(
block_H
,
dim
):
acc_o
[
i
,
j
]
/=
logsum
[
i
]
...
...
@@ -212,72 +217,74 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split,
if
i
<
valid_block_H
:
glse
[
bid
,
hid
*
valid_block_H
+
i
,
sid
]
=
logsum
[
i
]
T
.
copy
(
acc_o
[:
valid_block_H
,
:],
O_shared
)
T
.
copy
(
O_shared
,
Output_partial
[
bid
,
hid
*
valid_block_H
:(
hid
+
1
)
*
valid_block_H
,
sid
,
:])
T
.
copy
(
O_shared
,
Output_partial
[
bid
,
hid
*
valid_block_H
:
(
hid
+
1
)
*
valid_block_H
,
sid
,
:])
@
T
.
macro
def
combine
(
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
dtype
),
Output_partial
:
T
.
Tensor
(
part_shape
,
dtype
),
Output
:
T
.
Tensor
(
shape_o
,
dtype
),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
dtype
),
Output_partial
:
T
.
Tensor
(
part_shape
,
dtype
),
Output
:
T
.
Tensor
(
shape_o
,
dtype
),
):
with
T
.
Kernel
(
heads
,
batch
,
threads
=
128
)
as
(
by
,
bz
):
po_local
=
T
.
alloc_fragment
([
dim
],
dtype
)
o_accum_local
=
T
.
alloc_fragment
([
dim
],
accum_dtype
)
lse_local
=
T
.
alloc_fragment
([
num_split
,
128
],
dtype
)
lse_local_split
=
T
.
alloc_local
([
1
],
accum_dtype
)
lse_logsum_local
=
T
.
alloc_local
([
1
],
accum_dtype
)
lse_logsum_local
=
T
.
alloc_fragment
([
128
],
accum_dtype
)
lse_max_local
=
T
.
alloc_fragment
([
128
],
accum_dtype
)
scale_local
=
T
.
alloc_
local
([
1
],
accum_dtype
)
scale_local
=
T
.
alloc_
fragment
([
1
28
],
accum_dtype
)
T
.
annotate_layout
({
lse_logsum_local
:
T
.
Fragment
(
lse_logsum_local
.
shape
,
forward_thread_fn
=
lambda
i
:
i
),
lse_max_local
:
T
.
Fragment
(
lse_max_local
.
shape
,
forward_thread_fn
=
lambda
i
:
i
),
# lse_local: (local_id, thread_id)
lse_local
:
T
.
Fragment
(
lse_local
.
shape
,
forward_fn
=
lambda
i
,
j
:
(
j
,
i
)),
})
T
.
annotate_layout
(
{
lse_logsum_local
:
T
.
Fragment
(
lse_logsum_local
.
shape
,
forward_thread_fn
=
lambda
i
:
i
),
lse_max_local
:
T
.
Fragment
(
lse_max_local
.
shape
,
forward_thread_fn
=
lambda
i
:
i
),
# lse_local: (local_id, thread_id)
lse_local
:
T
.
Fragment
(
lse_local
.
shape
,
forward_fn
=
lambda
i
,
j
:
(
j
,
i
)),
}
)
T
.
clear
(
lse_logsum_local
)
T
.
clear
(
o_accum_local
)
for
k
,
j
in
T
.
Parallel
(
num_split
,
128
):
lse_local
[
k
,
j
]
=
glse
[
bz
,
by
,
k
]
T
.
reduce_max
(
lse_local
,
lse_max_local
,
dim
=
0
,
clear
=
True
)
for
k
in
T
.
Pipelined
(
num_split
,
num_stages
=
1
):
lse_local_split
[
0
]
=
glse
[
bz
,
by
,
k
]
lse_logsum_local
[
0
]
+=
T
.
exp2
(
lse_local_split
[
0
]
-
lse_max_local
[
0
])
lse_logsum_local
[
0
]
=
T
.
log2
(
lse_logsum_local
[
0
])
+
lse_max_local
[
0
]
for
k
in
T
.
serial
(
num_split
):
for
j
in
T
.
Parallel
(
128
):
lse_logsum_local
[
j
]
+=
T
.
exp2
(
lse_local
[
k
,
j
]
-
lse_max_local
[
j
])
for
j
in
T
.
Parallel
(
128
):
lse_logsum_local
[
j
]
=
T
.
log2
(
lse_logsum_local
[
j
])
+
lse_max_local
[
j
]
for
k
in
T
.
serial
(
num_split
):
for
i
in
T
.
Parallel
(
dim
):
po_local
[
i
]
=
Output_partial
[
bz
,
by
,
k
,
i
]
lse_local_split
[
0
]
=
glse
[
bz
,
by
,
k
]
scale_local
[
0
]
=
T
.
exp2
(
lse_local_split
[
0
]
-
lse_logsum_local
[
0
])
for
j
in
T
.
Parallel
(
128
):
scale_local
[
j
]
=
T
.
exp2
(
lse_local
[
k
,
j
]
-
lse_logsum_local
[
j
])
# Note: Pay attention to dim and the number of threads in Parallel
for
i
in
T
.
Parallel
(
dim
):
o_accum_local
[
i
]
+=
po_local
[
i
]
*
scale_local
[
0
]
o_accum_local
[
i
]
+=
po_local
[
i
]
*
scale_local
[
i
]
for
i
in
T
.
Parallel
(
dim
):
Output
[
bz
,
by
,
i
]
=
o_accum_local
[
i
]
@
T
.
prim_func
def
flashattn_gqa_decode_split
(
Q
:
T
.
Tensor
(
shape_q
,
dtype
),
K
:
T
.
Tensor
(
shape_k
,
dtype
),
V
:
T
.
Tensor
(
shape_v
,
dtype
),
mask
:
T
.
Tensor
([
batch
,
seqlen_kv
,
groups
],
"uint8"
),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
dtype
),
Output_partial
:
T
.
Tensor
(
part_shape
,
dtype
),
Output
:
T
.
Tensor
(
shape_o
,
dtype
),
Q
:
T
.
Tensor
(
shape_q
,
dtype
),
K
:
T
.
Tensor
(
shape_k
,
dtype
),
V
:
T
.
Tensor
(
shape_v
,
dtype
),
mask
:
T
.
Tensor
([
batch
,
seqlen_kv
,
groups
],
"uint8"
),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
dtype
),
Output_partial
:
T
.
Tensor
(
part_shape
,
dtype
),
Output
:
T
.
Tensor
(
shape_o
,
dtype
),
):
flash_attn_split
(
Q
,
K
,
V
,
mask
,
glse
,
Output_partial
)
combine
(
glse
,
Output_partial
,
Output
)
@
T
.
prim_func
def
flashattn_gqa_decode_no_split
(
Q
:
T
.
Tensor
(
shape_q
,
dtype
),
K
:
T
.
Tensor
(
shape_k
,
dtype
),
V
:
T
.
Tensor
(
shape_v
,
dtype
),
mask
:
T
.
Tensor
([
batch
,
seqlen_kv
,
groups
],
"uint8"
),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
dtype
),
Output_partial
:
T
.
Tensor
(
part_shape
,
dtype
),
Output
:
T
.
Tensor
(
shape_o
,
dtype
),
Q
:
T
.
Tensor
(
shape_q
,
dtype
),
K
:
T
.
Tensor
(
shape_k
,
dtype
),
V
:
T
.
Tensor
(
shape_v
,
dtype
),
mask
:
T
.
Tensor
([
batch
,
seqlen_kv
,
groups
],
"uint8"
),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
dtype
),
Output_partial
:
T
.
Tensor
(
part_shape
,
dtype
),
Output
:
T
.
Tensor
(
shape_o
,
dtype
),
):
flash_attn
(
Q
,
K
,
V
,
mask
,
Output
)
...
...
@@ -300,27 +307,21 @@ def ref_program(query, key, value, mask, glse, Output_partial):
dim
=
query
.
shape
[
-
1
]
num_head_groups
=
query
.
shape
[
1
]
//
key
.
shape
[
2
]
scale
=
dim
**
0.5
key
=
rearrange
(
key
,
'
b n h d -> b h n d
'
)
# [batch_size, groups, seqlen_kv, dim]
value
=
rearrange
(
value
,
'
b n h d -> b h n d
'
)
# [batch_size, groups, seqlen_kv, dim]
key
=
rearrange
(
key
,
"
b n h d -> b h n d
"
)
# [batch_size, groups, seqlen_kv, dim]
value
=
rearrange
(
value
,
"
b n h d -> b h n d
"
)
# [batch_size, groups, seqlen_kv, dim]
query
=
rearrange
(
query
,
'b (h g) d -> b g h d'
,
g
=
num_head_groups
)
# [batch_size, num_head_groups, groups, dim]
query
=
rearrange
(
query
,
"b (h g) d -> b g h d"
,
g
=
num_head_groups
)
# [batch_size, num_head_groups, groups, dim]
scores
=
einsum
(
query
,
key
,
'b g h d, b h s d -> b g h s'
)
# [batch_size, num_head_groups, groups, seqlen_kv]
scores
=
einsum
(
query
,
key
,
"b g h d, b h s d -> b g h s"
)
# [batch_size, num_head_groups, groups, seqlen_kv]
if
mask
is
not
None
:
mask
=
rearrange
(
mask
,
'
b s h -> b h s
'
)
mask
=
rearrange
(
mask
,
"
b s h -> b h s
"
)
mask
=
mask
.
unsqueeze
(
1
)
scores
=
scores
.
masked_fill
(
mask
==
0
,
float
(
'
-inf
'
))
scores
=
scores
.
masked_fill
(
mask
==
0
,
float
(
"
-inf
"
))
attention
=
F
.
softmax
(
scores
/
scale
,
dim
=-
1
)
# [batch_size, num_head_groups, groups, seqlen_kv]
attention
=
F
.
softmax
(
scores
/
scale
,
dim
=-
1
)
# [batch_size, num_head_groups, groups, seqlen_kv]
out
=
einsum
(
attention
,
value
,
'b g h s, b h s d -> b g h d'
)
# [batch_size, num_head_groups, groups, dim]
out
=
rearrange
(
out
,
'b g h d -> b (h g) d'
)
# [batch_size, heads, dim]
out
=
einsum
(
attention
,
value
,
"b g h s, b h s d -> b g h d"
)
# [batch_size, num_head_groups, groups, dim]
out
=
rearrange
(
out
,
"b g h d -> b (h g) d"
)
# [batch_size, heads, dim]
return
out
...
...
@@ -334,16 +335,12 @@ def flash_split_ref(Q, K, V, mask):
seqlen_kv
=
K
.
size
(
1
)
num_head_groups
=
nheads
//
groups
scale
=
(
1.0
/
dim
)
**
0.5
*
1.44269504
# log2(e)
scale
=
(
1.0
/
dim
)
**
0.5
*
1.44269504
# log2(e)
acc_s
=
torch
.
empty
((
batch
,
num_head_groups
,
groups
,
block_N
),
device
=
"cuda"
,
dtype
=
torch
.
float
)
acc_s_cast
=
torch
.
empty
((
batch
,
num_head_groups
,
groups
,
block_N
),
device
=
"cuda"
,
dtype
=
torch
.
float16
)
acc_s_cast
=
torch
.
empty
((
batch
,
num_head_groups
,
groups
,
block_N
),
device
=
"cuda"
,
dtype
=
torch
.
float16
)
acc_o
=
torch
.
empty
((
batch
,
num_head_groups
,
groups
,
dim
),
device
=
"cuda"
,
dtype
=
torch
.
float
)
scores_max
=
torch
.
empty
((
batch
,
num_head_groups
,
groups
),
device
=
"cuda"
,
dtype
=
torch
.
float
)
scores_max_prev
=
torch
.
empty
((
batch
,
num_head_groups
,
groups
),
device
=
"cuda"
,
dtype
=
torch
.
float
)
scores_max_prev
=
torch
.
empty
((
batch
,
num_head_groups
,
groups
),
device
=
"cuda"
,
dtype
=
torch
.
float
)
scores_scale
=
torch
.
empty
((
batch
,
num_head_groups
,
groups
),
device
=
"cuda"
,
dtype
=
torch
.
float
)
scores_sum
=
torch
.
empty
((
batch
,
num_head_groups
,
groups
),
device
=
"cuda"
,
dtype
=
torch
.
float
)
logsum
=
torch
.
empty
((
batch
,
num_head_groups
,
groups
),
device
=
"cuda"
,
dtype
=
torch
.
float
)
...
...
@@ -351,25 +348,25 @@ def flash_split_ref(Q, K, V, mask):
glogsum
=
torch
.
empty
((
num_split
,
batch
,
nheads
),
device
=
"cuda"
,
dtype
=
torch
.
float
)
Q_
=
Q
*
scale
Q_
=
rearrange
(
Q_
,
'
b (h g) d -> b g h d
'
,
g
=
num_head_groups
)
Q_
=
rearrange
(
Q_
,
"
b (h g) d -> b g h d
"
,
g
=
num_head_groups
)
for
ks
in
range
(
num_split
):
acc_o
.
fill_
(
0
)
logsum
.
fill_
(
0
)
scores_max
.
fill_
(
float
(
'
-inf
'
))
scores_max_prev
.
fill_
(
float
(
'
-inf
'
))
scores_max
.
fill_
(
float
(
"
-inf
"
))
scores_max_prev
.
fill_
(
float
(
"
-inf
"
))
for
i
in
range
(
int
((
seqlen_kv
//
num_split
)
/
block_N
)):
acc_s
.
fill_
(
0
)
acc_s
=
torch
.
einsum
(
'bghd,bkhd->bghk'
,
Q_
,
K
[:,
(
seqlen_kv
//
num_split
)
*
ks
+
i
*
block_N
:(
seqlen_kv
//
num_split
)
*
ks
+
(
i
+
1
)
*
block_N
,
:,
:])
# [batch, nheads, block_N]
acc_s
=
torch
.
einsum
(
"bghd,bkhd->bghk"
,
Q_
,
K
[:,
(
seqlen_kv
//
num_split
)
*
ks
+
i
*
block_N
:
(
seqlen_kv
//
num_split
)
*
ks
+
(
i
+
1
)
*
block_N
,
:,
:],
)
# [batch, nheads, block_N]
if
mask
is
not
None
:
mask_local
=
mask
[:,
(
seqlen_kv
//
num_split
)
*
ks
+
i
*
block_N
:(
seqlen_kv
//
num_split
)
*
ks
+
(
i
+
1
)
*
block_N
,
:]
mask_local
=
rearrange
(
mask_local
,
'b s h -> b h s'
)
mask_local
=
mask
[:,
(
seqlen_kv
//
num_split
)
*
ks
+
i
*
block_N
:
(
seqlen_kv
//
num_split
)
*
ks
+
(
i
+
1
)
*
block_N
,
:]
mask_local
=
rearrange
(
mask_local
,
"b s h -> b h s"
)
mask_local
=
mask_local
.
unsqueeze
(
1
)
acc_s
=
acc_s
.
masked_fill
(
mask_local
==
0
,
float
(
'
-inf
'
))
acc_s
=
acc_s
.
masked_fill
(
mask_local
==
0
,
float
(
"
-inf
"
))
scores_max_prev
=
scores_max
scores_max
=
acc_s
.
max
(
dim
=-
1
,
keepdim
=
False
).
values
# [batch, nheads]
scores_scale
=
torch
.
exp2
(
scores_max_prev
-
scores_max
)
# [batch, nheads]
...
...
@@ -377,15 +374,16 @@ def flash_split_ref(Q, K, V, mask):
acc_s
=
torch
.
exp2
(
acc_s
-
scores_max
[:,
:,
:,
None
])
acc_s_cast
=
acc_s
.
to
(
torch
.
float16
)
# [batch, nheads, block_N]
acc_o
+=
torch
.
einsum
(
'bghk,bkhd->bghd'
,
acc_s_cast
,
V
[:,
(
seqlen_kv
//
num_split
)
*
ks
+
i
*
block_N
:(
seqlen_kv
//
num_split
)
*
ks
+
(
i
+
1
)
*
block_N
,
:,
:])
"bghk,bkhd->bghd"
,
acc_s_cast
,
V
[:,
(
seqlen_kv
//
num_split
)
*
ks
+
i
*
block_N
:
(
seqlen_kv
//
num_split
)
*
ks
+
(
i
+
1
)
*
block_N
,
:,
:],
)
scores_sum
=
acc_s
.
sum
(
dim
=-
1
,
keepdim
=
False
)
logsum
=
logsum
*
scores_scale
+
scores_sum
acc_o_out
=
rearrange
(
acc_o
,
'
b g h d->b (h g) d
'
)
logsum_out
=
rearrange
(
logsum
,
'
b g h->b (h g)
'
)
acc_o_out
=
rearrange
(
acc_o
,
"
b g h d->b (h g) d
"
)
logsum_out
=
rearrange
(
logsum
,
"
b g h->b (h g)
"
)
acc_o_out
/=
logsum_out
[:,
:,
None
]
logsum_out
=
torch
.
log2
(
logsum_out
)
+
rearrange
(
scores_max
,
'
b g h->b (h g)
'
)
logsum_out
=
torch
.
log2
(
logsum_out
)
+
rearrange
(
scores_max
,
"
b g h->b (h g)
"
)
gacc_o
[
ks
,
:,
:,
:]
=
acc_o_out
glogsum
[
ks
,
:,
:]
=
logsum_out
...
...
@@ -421,7 +419,7 @@ def calc_sim(x, y, name="tensor"):
x
,
y
=
x
.
data
.
double
(),
y
.
data
.
double
()
denominator
=
(
x
*
x
+
y
*
y
).
sum
()
if
denominator
==
0
:
print_red_warning
(
f
'
{
name
}
all zero
'
)
print_red_warning
(
f
"
{
name
}
all zero
"
)
return
1
sim
=
2
*
(
x
*
y
).
sum
()
/
denominator
return
sim
...
...
@@ -429,28 +427,23 @@ def calc_sim(x, y, name="tensor"):
def
assert_similar
(
x
,
y
,
eps
=
1e-2
,
name
=
"tensor"
,
assert_
=
False
,
print_
=
True
):
sim
=
calc_sim
(
x
,
y
,
name
)
diff
=
1.
-
sim
diff
=
1.
0
-
sim
if
not
(
0
<=
diff
<=
eps
):
print_red_warning
(
f
'
{
name
}
Error:
{
diff
}
'
)
print_red_warning
(
f
"
{
name
}
Error:
{
diff
}
"
)
if
assert_
:
raise
AssertionError
(
f
'
{
name
}
Error:
{
diff
}
'
)
raise
AssertionError
(
f
"
{
name
}
Error:
{
diff
}
"
)
else
:
if
print_
:
print
(
f
'
passed:
{
name
}
diff=
{
diff
}
'
)
print
(
f
"
passed:
{
name
}
diff=
{
diff
}
"
)
def
main
(
batch
:
int
=
1
,
heads
:
int
=
32
,
groups
:
int
=
8
,
kv_seqlen
:
int
=
8192
,
dim
:
int
=
128
,
tune
:
bool
=
False
):
def
main
(
batch
:
int
=
1
,
heads
:
int
=
32
,
groups
:
int
=
8
,
kv_seqlen
:
int
=
8192
,
dim
:
int
=
128
,
tune
:
bool
=
False
):
batch
,
heads
,
groups
,
kv_seqlen
,
dim
=
batch
,
heads
,
groups
,
kv_seqlen
,
dim
qk_flops
=
2
*
batch
*
heads
*
kv_seqlen
*
dim
pv_flops
=
2
*
batch
*
heads
*
kv_seqlen
*
dim
total_flops
=
qk_flops
+
pv_flops
if
(
not
tune
)
:
if
not
tune
:
config
,
sm_version
=
get_heuristic_config
()
kernel
=
flashattn
(
batch
,
heads
,
groups
,
kv_seqlen
,
dim
,
**
config
)
profiler
=
kernel
.
get_profiler
(
tensor_supply_type
=
tilelang
.
TensorSupplyType
.
Auto
)
...
...
@@ -470,7 +463,7 @@ def main(batch: int = 1,
print
(
o_ref
)
assert_similar
(
o
,
o_ref
,
name
=
"o_ref"
)
assert_similar
(
o_ref_split
,
o_ref
,
name
=
"o_ref_split"
)
assert_similar
(
o
,
o_ref_split
,
name
=
"o_ref_split"
)
print
(
"All checks pass."
)
latency
=
profiler
.
do_bench
(
ref_program
,
warmup
=
500
)
...
...
@@ -492,11 +485,11 @@ def main(batch: int = 1,
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'
--batch
'
,
type
=
int
,
default
=
1
,
help
=
'
batch size
'
)
parser
.
add_argument
(
'
--heads
'
,
type
=
int
,
default
=
32
,
help
=
'
heads
'
)
parser
.
add_argument
(
'
--groups
'
,
type
=
int
,
default
=
8
,
help
=
'
groups
'
)
parser
.
add_argument
(
'
--kv_seqlen
'
,
type
=
int
,
default
=
8192
,
help
=
'
kv sequence length
'
)
parser
.
add_argument
(
'
--dim
'
,
type
=
int
,
default
=
128
,
help
=
'
dim
'
)
parser
.
add_argument
(
'
--tune
'
,
action
=
'
store_true
'
,
help
=
'
tune configs
'
)
parser
.
add_argument
(
"
--batch
"
,
type
=
int
,
default
=
1
,
help
=
"
batch size
"
)
parser
.
add_argument
(
"
--heads
"
,
type
=
int
,
default
=
32
,
help
=
"
heads
"
)
parser
.
add_argument
(
"
--groups
"
,
type
=
int
,
default
=
8
,
help
=
"
groups
"
)
parser
.
add_argument
(
"
--kv_seqlen
"
,
type
=
int
,
default
=
8192
,
help
=
"
kv sequence length
"
)
parser
.
add_argument
(
"
--dim
"
,
type
=
int
,
default
=
128
,
help
=
"
dim
"
)
parser
.
add_argument
(
"
--tune
"
,
action
=
"
store_true
"
,
help
=
"
tune configs
"
)
args
=
parser
.
parse_args
()
main
(
args
.
batch
,
args
.
heads
,
args
.
groups
,
args
.
kv_seqlen
,
args
.
dim
,
args
.
tune
)
examples/flash_decoding/example_gqa_decode_varlen_logits.py
View file @
667632cc
...
...
@@ -19,8 +19,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
batch
,
num_key_value_heads
,
slen
,
head_dim
=
hidden_states
.
shape
if
n_rep
==
1
:
return
hidden_states
hidden_states
=
hidden_states
[:,
:,
None
,
:,
:].
expand
(
batch
,
num_key_value_heads
,
n_rep
,
slen
,
head_dim
)
hidden_states
=
hidden_states
[:,
:,
None
,
:,
:].
expand
(
batch
,
num_key_value_heads
,
n_rep
,
slen
,
head_dim
)
return
hidden_states
.
reshape
(
batch
,
num_key_value_heads
*
n_rep
,
slen
,
head_dim
)
...
...
@@ -74,14 +73,9 @@ def _fwd_inner(
return
m_i
,
l_i
,
acc
@
triton
.
autotune
(
configs
=
[
triton
.
Config
({},
num_warps
=
num_warps
,
num_stages
=
num_stages
)
for
num_warps
in
[
4
,
8
]
\
for
num_stages
in
[
2
,
4
]
\
],
key
=
[
'gqa_group_size'
,
'BLOCK_N'
,
'BLOCK_D'
,
'BLOCK_H'
],
configs
=
[
triton
.
Config
({},
num_warps
=
num_warps
,
num_stages
=
num_stages
)
for
num_warps
in
[
4
,
8
]
for
num_stages
in
[
2
,
4
]],
key
=
[
"gqa_group_size"
,
"BLOCK_N"
,
"BLOCK_D"
,
"BLOCK_H"
],
)
@
triton
.
jit
def
_fwd_kernel_varlen
(
...
...
@@ -107,13 +101,12 @@ def _fwd_kernel_varlen(
stride_od
,
stride_sb
,
stride_sh
,
stride_sn
,
#bmask shape [b, q_h, seq/BLOCK_N]
stride_sn
,
#
bmask shape [b, q_h, seq/BLOCK_N]
gqa_group_size
:
tl
.
constexpr
,
BLOCK_H
:
tl
.
constexpr
,
BLOCK_N
:
tl
.
constexpr
,
BLOCK_D
:
tl
.
constexpr
,
):
off_z
=
tl
.
program_id
(
0
)
off_h_for_kv
=
tl
.
program_id
(
1
)
off_h_q
=
off_h_for_kv
*
gqa_group_size
...
...
@@ -134,8 +127,7 @@ def _fwd_kernel_varlen(
S_ptrs
=
S
+
off_z
*
stride_sb
+
off_h_q
*
stride_sh
mask_h
=
offs_h
<
gqa_group_size
q
=
tl
.
load
(
Q_ptrs
+
offs_d
[
None
,
:]
*
stride_qd
+
offs_h
[:,
None
]
*
stride_qh
,
mask
=
mask_h
[:,
None
])
q
=
tl
.
load
(
Q_ptrs
+
offs_d
[
None
,
:]
*
stride_qd
+
offs_h
[:,
None
]
*
stride_qh
,
mask
=
mask_h
[:,
None
])
if
s_aux
is
not
None
:
sink
=
tl
.
load
(
s_aux
+
off_h_q
+
offs_h
,
mask
=
mask_h
).
to
(
tl
.
float32
)
...
...
@@ -189,14 +181,12 @@ def _fwd_kernel_varlen(
acc
=
acc
.
to
(
O
.
dtype
.
element_ty
)
tl
.
store
(
O_ptrs
+
offs_h
[:,
None
]
*
stride_oh
+
offs_d
[
None
,
:]
*
stride_od
,
acc
,
mask
=
mask_h
[:,
None
])
tl
.
store
(
O_ptrs
+
offs_h
[:,
None
]
*
stride_oh
+
offs_d
[
None
,
:]
*
stride_od
,
acc
,
mask
=
mask_h
[:,
None
])
def
get_configs
():
import
itertools
block_N
=
[
64
,
128
]
block_H
=
[
64
]
num_split
=
[
1
]
...
...
@@ -204,38 +194,23 @@ def get_configs():
threads
=
[
128
]
_configs
=
list
(
itertools
.
product
(
block_N
,
block_H
,
num_split
,
num_stages
,
threads
))
configs
=
[{
'block_N'
:
c
[
0
],
'block_H'
:
c
[
1
],
'num_split'
:
c
[
2
],
'num_stages'
:
c
[
3
],
'threads'
:
c
[
4
]
}
for
c
in
_configs
]
configs
=
[{
"block_N"
:
c
[
0
],
"block_H"
:
c
[
1
],
"num_split"
:
c
[
2
],
"num_stages"
:
c
[
3
],
"threads"
:
c
[
4
]}
for
c
in
_configs
]
return
configs
@
autotune
(
configs
=
get_configs
(),
warmup
=
10
,
rep
=
10
)
@
tilelang
.
jit
(
out_idx
=
[
-
2
,
-
1
],
debug_root_path
=
"./examples/flash_decoding"
)
def
flashattn
(
batch
,
heads
,
k_heads
,
max_seqlen_kv
,
total_seqlen_k
,
dim
,
has_sink
,
block_N
=
128
,
block_H
=
64
,
num_split
=
1
,
num_stages
=
1
,
threads
=
128
):
scale
=
(
1.0
/
dim
)
**
0.5
*
1.44269504
# log2(e)
def
flashattn
(
batch
,
heads
,
k_heads
,
max_seqlen_kv
,
total_seqlen_k
,
dim
,
has_sink
,
block_N
=
128
,
block_H
=
64
,
num_split
=
1
,
num_stages
=
1
,
threads
=
128
):
scale
=
(
1.0
/
dim
)
**
0.5
*
1.44269504
# log2(e)
shape_q
=
[
batch
,
heads
,
dim
]
shape_k
=
[
total_seqlen_k
,
k_heads
,
dim
]
shape_v
=
[
total_seqlen_k
,
k_heads
,
dim
]
shape_o
=
[
batch
,
heads
,
dim
]
shape_s
=
[
batch
,
heads
,
math
.
ceil
(
max_seqlen_kv
/
block_N
)]
dtype
=
"
float16
"
accum_dtype
=
"
float
"
dtype
=
T
.
float16
accum_dtype
=
T
.
float
32
kv_group_num
=
heads
//
k_heads
valid_block_H
=
min
(
block_H
,
kv_group_num
)
...
...
@@ -243,13 +218,13 @@ def flashattn(batch,
@
T
.
macro
def
flash_attn
(
Q
:
T
.
Tensor
(
shape_q
,
dtype
),
K
:
T
.
Tensor
(
shape_k
,
dtype
),
V
:
T
.
Tensor
(
shape_v
,
dtype
),
cu_seqlens_k
:
T
.
Tensor
([
batch
+
1
],
"
int32
"
),
s_aux
:
T
.
Tensor
([
heads
],
"
float32
"
),
Output
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
S
:
T
.
Tensor
(
shape_s
,
dtype
),
Q
:
T
.
Tensor
(
shape_q
,
dtype
),
K
:
T
.
Tensor
(
shape_k
,
dtype
),
V
:
T
.
Tensor
(
shape_v
,
dtype
),
cu_seqlens_k
:
T
.
Tensor
([
batch
+
1
],
T
.
int32
),
s_aux
:
T
.
Tensor
([
heads
],
T
.
float32
),
Output
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
S
:
T
.
Tensor
(
shape_s
,
dtype
),
):
with
T
.
Kernel
(
batch
,
heads
//
valid_block_H
,
num_split
,
threads
=
threads
)
as
(
bx
,
by
,
bz
):
Q_shared
=
T
.
alloc_shared
([
block_H
,
dim
],
dtype
)
...
...
@@ -266,15 +241,17 @@ def flashattn(batch,
logsum
=
T
.
alloc_fragment
([
block_H
],
accum_dtype
)
S_shared
=
T
.
alloc_shared
([
block_H
,
math
.
ceil
(
max_seqlen_kv
/
block_N
)],
dtype
)
# S_fragment = T.alloc_fragment([block_H, math.ceil(max_seqlen_kv / block_N)], accum_dtype)
s_aux_shared
=
T
.
alloc_shared
([
block_H
],
"float32"
)
T
.
annotate_layout
({
# Q_shared: tilelang.layout.make_swizzled_layout(Q_shared),
# K_shared: tilelang.layout.make_swizzled_layout(K_shared),
# V_shared: tilelang.layout.make_swizzled_layout(V_shared),
# O_shared: tilelang.layout.make_swizzled_layout(O_shared),
# S_shared: tilelang.layout.make_swizzled_layout(S_shared),
})
s_aux_shared
=
T
.
alloc_shared
([
block_H
],
T
.
float32
)
T
.
annotate_layout
(
{
# Q_shared: tilelang.layout.make_swizzled_layout(Q_shared),
# K_shared: tilelang.layout.make_swizzled_layout(K_shared),
# V_shared: tilelang.layout.make_swizzled_layout(V_shared),
# O_shared: tilelang.layout.make_swizzled_layout(O_shared),
# S_shared: tilelang.layout.make_swizzled_layout(S_shared),
}
)
bid
=
bx
hid
=
by
...
...
@@ -284,7 +261,7 @@ def flashattn(batch,
cur_end_k
=
cu_seqlens_k
[
bid
+
1
]
cur_seqlen_k
=
cur_end_k
-
cur_start_k
T
.
copy
(
Q
[
bid
,
hid
*
valid_block_H
:
hid
*
valid_block_H
+
block_H
,
:],
Q_shared
)
T
.
copy
(
Q
[
bid
,
hid
*
valid_block_H
:
hid
*
valid_block_H
+
block_H
,
:],
Q_shared
)
T
.
fill
(
acc_o
,
0
)
T
.
fill
(
logsum
,
0
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
...
...
@@ -292,15 +269,13 @@ def flashattn(batch,
# loop_range = T.ceildiv((seqlen_kv // num_split), block_N)
loop_range
=
T
.
ceildiv
((
cur_seqlen_k
//
num_split
),
block_N
)
for
k
in
T
.
Pipelined
(
loop_range
,
num_stages
=
num_stages
):
T
.
copy
(
K
[
cur_start_k
+
k
*
block_N
:
cur_start_k
+
(
k
+
1
)
*
block_N
,
cur_kv_head
,
:],
K_shared
)
T
.
copy
(
K
[
cur_start_k
+
k
*
block_N
:
cur_start_k
+
(
k
+
1
)
*
block_N
,
cur_kv_head
,
:],
K_shared
)
T
.
clear
(
acc_s
)
T
.
gemm
(
Q_shared
,
K_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
for
i
,
j
in
T
.
Parallel
(
block_H
,
block_N
):
# acc_s[i, j] = T.if_then_else(mask_local[j] != 0 and k * block_N + j < cur_seqlen_k, acc_s[i, j],
# -T.infinity(accum_dtype))
acc_s
[
i
,
j
]
=
T
.
if_then_else
(
k
*
block_N
+
j
<
cur_seqlen_k
,
acc_s
[
i
,
j
],
-
T
.
infinity
(
accum_dtype
))
acc_s
[
i
,
j
]
=
T
.
if_then_else
(
k
*
block_N
+
j
<
cur_seqlen_k
,
acc_s
[
i
,
j
],
-
T
.
infinity
(
accum_dtype
))
T
.
copy
(
scores_max
,
scores_max_prev
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
T
.
reduce_max
(
acc_s
,
scores_max
,
dim
=
1
,
clear
=
False
)
...
...
@@ -320,12 +295,11 @@ def flashattn(batch,
T
.
copy
(
acc_s
,
acc_s_cast
)
for
i
,
j
in
T
.
Parallel
(
block_H
,
dim
):
acc_o
[
i
,
j
]
*=
scores_scale
[
i
]
T
.
copy
(
V
[
cur_start_k
+
k
*
block_N
:
cur_start_k
+
(
k
+
1
)
*
block_N
,
cur_kv_head
,
:],
V_shared
)
T
.
copy
(
V
[
cur_start_k
+
k
*
block_N
:
cur_start_k
+
(
k
+
1
)
*
block_N
,
cur_kv_head
,
:],
V_shared
)
T
.
gemm
(
acc_s_cast
,
V_shared
,
acc_o
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
if
has_sink
:
T
.
copy
(
s_aux
[
hid
*
valid_block_H
:
hid
*
valid_block_H
+
block_H
],
s_aux_shared
)
T
.
copy
(
s_aux
[
hid
*
valid_block_H
:
hid
*
valid_block_H
+
block_H
],
s_aux_shared
)
for
i
in
T
.
Parallel
(
block_H
):
logsum
[
i
]
+=
s_aux_shared
[
i
]
for
i
,
j
in
T
.
Parallel
(
block_H
,
dim
):
...
...
@@ -338,20 +312,19 @@ def flashattn(batch,
for
i
in
T
.
Parallel
(
block_H
):
logsum
[
i
]
=
T
.
log2
(
logsum
[
i
])
+
scores_max
[
i
]
*
scale
T
.
copy
(
acc_o
[:
valid_block_H
,
:],
O_shared
)
T
.
copy
(
O_shared
,
Output
[
bid
,
hid
*
valid_block_H
:
(
hid
+
1
)
*
valid_block_H
,
:])
T
.
copy
(
O_shared
,
Output
[
bid
,
hid
*
valid_block_H
:
(
hid
+
1
)
*
valid_block_H
,
:])
# T.copy(S_fragment, S_shared)
T
.
copy
(
S_shared
[:
valid_block_H
,
:],
S
[
bid
,
hid
*
valid_block_H
:(
hid
+
1
)
*
valid_block_H
,
:])
T
.
copy
(
S_shared
[:
valid_block_H
,
:],
S
[
bid
,
hid
*
valid_block_H
:
(
hid
+
1
)
*
valid_block_H
,
:])
@
T
.
prim_func
def
flashattn_gqa_decode_no_split
(
Q
:
T
.
Tensor
(
shape_q
,
dtype
),
K
:
T
.
Tensor
(
shape_k
,
dtype
),
V
:
T
.
Tensor
(
shape_v
,
dtype
),
cu_seqlens_k
:
T
.
Tensor
([
batch
+
1
],
"
int32
"
),
s_aux
:
T
.
Tensor
([
heads
],
"
float32
"
),
Output
:
T
.
Tensor
(
shape_o
,
dtype
),
S
:
T
.
Tensor
(
shape_s
,
dtype
),
Q
:
T
.
Tensor
(
shape_q
,
dtype
),
K
:
T
.
Tensor
(
shape_k
,
dtype
),
V
:
T
.
Tensor
(
shape_v
,
dtype
),
cu_seqlens_k
:
T
.
Tensor
([
batch
+
1
],
T
.
int32
),
s_aux
:
T
.
Tensor
([
heads
],
T
.
float32
),
Output
:
T
.
Tensor
(
shape_o
,
dtype
),
S
:
T
.
Tensor
(
shape_s
,
dtype
),
):
flash_attn
(
Q
,
K
,
V
,
cu_seqlens_k
,
s_aux
,
Output
,
S
)
...
...
@@ -388,9 +361,7 @@ def flash_attn_with_attn_pool_decode_tilelang(
gqa_group_size
=
q_h
//
k_h
O_tl
=
torch
.
zeros_like
(
Q
)
S_tl
=
torch
.
zeros
((
batch
,
q_h
,
math
.
ceil
(
real_max_k_seqlen
/
block_size
)),
dtype
=
Q
.
dtype
,
device
=
Q
.
device
)
S_tl
=
torch
.
zeros
((
batch
,
q_h
,
math
.
ceil
(
real_max_k_seqlen
/
block_size
)),
dtype
=
Q
.
dtype
,
device
=
Q
.
device
)
O_tl
,
S_tl
=
tl_kernel
(
Q
,
K
,
V
,
cu_seqlens_k
,
s_aux
)
if
use_per_kv_head_sparse_index
:
...
...
@@ -433,9 +404,7 @@ def flash_attn_with_attn_pool_decode(
BLOCK_H
=
64
O
=
torch
.
zeros_like
(
Q
)
S
=
torch
.
zeros
((
batch
,
q_h
,
math
.
ceil
(
max_seqlen_k
/
block_size
)),
dtype
=
Q
.
dtype
,
device
=
Q
.
device
)
S
=
torch
.
zeros
((
batch
,
q_h
,
math
.
ceil
(
max_seqlen_k
/
block_size
)),
dtype
=
Q
.
dtype
,
device
=
Q
.
device
)
def
grid
(
META
):
return
(
batch
,
k_h
)
...
...
@@ -480,18 +449,18 @@ def test_equal_seqlen_decode_main(args):
real_max_k_seqlen
=
args
.
k_seqlen
head_size
=
args
.
head_size
block_size
=
args
.
block_size
dtype
=
torch
.
bfloat16
if
args
.
dtype
==
"
bfloat16
"
else
torch
.
float16
dtype
=
torch
.
bfloat16
if
args
.
dtype
==
T
.
bfloat16
else
torch
.
float16
# For decode, query is just 1 token per batch
q
=
torch
.
randn
(
batch_size
,
q_heads
,
head_size
,
device
=
'
cuda
'
,
dtype
=
dtype
)
k
=
torch
.
randn
(
batch_size
,
kv_heads
,
k_seqlen
,
head_size
,
device
=
'
cuda
'
,
dtype
=
dtype
)
v
=
torch
.
randn
(
batch_size
,
kv_heads
,
k_seqlen
,
head_size
,
device
=
'
cuda
'
,
dtype
=
dtype
)
q
=
torch
.
randn
(
batch_size
,
q_heads
,
head_size
,
device
=
"
cuda
"
,
dtype
=
dtype
)
k
=
torch
.
randn
(
batch_size
,
kv_heads
,
k_seqlen
,
head_size
,
device
=
"
cuda
"
,
dtype
=
dtype
)
v
=
torch
.
randn
(
batch_size
,
kv_heads
,
k_seqlen
,
head_size
,
device
=
"
cuda
"
,
dtype
=
dtype
)
softmax_scale
=
1.0
/
math
.
sqrt
(
head_size
)
# Generate sink values if needed
sink
=
None
if
args
.
test_sink
:
sink
=
torch
.
randn
(
q_heads
,
device
=
'
cuda
'
,
dtype
=
torch
.
float32
)
*
0.1
# Small sink values
sink
=
torch
.
randn
(
q_heads
,
device
=
"
cuda
"
,
dtype
=
torch
.
float32
)
*
0.1
# Small sink values
print
(
f
"Using sink attention with sink values:
{
sink
}
"
)
# Convert to varlen format for K, V
...
...
@@ -499,8 +468,7 @@ def test_equal_seqlen_decode_main(args):
v_varlen
=
v
.
transpose
(
1
,
2
).
reshape
(
batch_size
*
k_seqlen
,
kv_heads
,
head_size
)
# Generate cumulative sequence lengths
cu_seqlens_k
=
torch
.
arange
(
0
,
(
batch_size
+
1
)
*
k_seqlen
,
k_seqlen
,
device
=
'cuda'
,
dtype
=
torch
.
int32
)
cu_seqlens_k
=
torch
.
arange
(
0
,
(
batch_size
+
1
)
*
k_seqlen
,
k_seqlen
,
device
=
"cuda"
,
dtype
=
torch
.
int32
)
max_seqlen_k
=
k_seqlen
print
(
f
"q shape:
{
q
.
shape
}
"
)
...
...
@@ -510,8 +478,7 @@ def test_equal_seqlen_decode_main(args):
num_tokens
,
q_h
,
head_size
=
q
.
shape
batch
=
cu_seqlens_k
.
size
(
0
)
-
1
k_h
=
k_varlen
.
size
(
1
)
tl_kernel
=
flashattn
(
batch
,
q_h
,
k_h
,
args
.
k_seqlen
,
cu_seqlens_k
[
-
1
].
item
(),
head_size
,
args
.
test_sink
)
tl_kernel
=
flashattn
(
batch
,
q_h
,
k_h
,
args
.
k_seqlen
,
cu_seqlens_k
[
-
1
].
item
(),
head_size
,
args
.
test_sink
)
# Test our decode kernel
O_triton
,
S_triton
=
flash_attn_with_attn_pool_decode
(
...
...
@@ -524,7 +491,8 @@ def test_equal_seqlen_decode_main(args):
args
.
num_split
,
softmax_scale
,
s_aux
=
sink
,
block_size
=
block_size
)
block_size
=
block_size
,
)
O_tilelang
,
S_tilelang
=
flash_attn_with_attn_pool_decode_tilelang
(
q
,
k_varlen
,
...
...
@@ -539,9 +507,7 @@ def test_equal_seqlen_decode_main(args):
tl_kernel
=
tl_kernel
,
)
for
i
in
range
(
batch_size
):
S_tilelang
[
i
,
:,
math
.
ceil
((
cu_seqlens_k
[
i
+
1
].
item
()
-
cu_seqlens_k
[
i
].
item
())
/
block_size
):]
=
0
S_tilelang
[
i
,
:,
math
.
ceil
((
cu_seqlens_k
[
i
+
1
].
item
()
-
cu_seqlens_k
[
i
].
item
())
/
block_size
)
:]
=
0
# Compute torch reference
q_expanded
=
q
.
unsqueeze
(
2
)
# [b, q_heads, 1, head_size]
...
...
@@ -550,14 +516,12 @@ def test_equal_seqlen_decode_main(args):
if
sink
is
None
:
# Standard scaled dot-product attention
logits
=
torch
.
matmul
(
q_expanded
,
k_repeat
.
transpose
(
-
2
,
-
1
))
*
softmax_scale
# [batch, q_heads, 1, seqlen_k]
logits
=
torch
.
matmul
(
q_expanded
,
k_repeat
.
transpose
(
-
2
,
-
1
))
*
softmax_scale
# [batch, q_heads, 1, seqlen_k]
attn_weights
=
torch
.
softmax
(
logits
,
dim
=-
1
)
O_torch
=
torch
.
matmul
(
attn_weights
,
v_repeat
).
squeeze
(
2
)
# [batch, q_heads, head_size]
else
:
# s_aux attention
logits
=
torch
.
matmul
(
q_expanded
,
k_repeat
.
transpose
(
-
2
,
-
1
))
*
softmax_scale
# [batch, q_heads, 1, seqlen_k]
logits
=
torch
.
matmul
(
q_expanded
,
k_repeat
.
transpose
(
-
2
,
-
1
))
*
softmax_scale
# [batch, q_heads, 1, seqlen_k]
sink_expanded
=
sink
.
view
(
1
,
q_heads
,
1
,
1
)
# [1, q_heads, 1, 1]
logits_max
=
torch
.
max
(
logits
,
dim
=-
1
,
keepdim
=
True
).
values
...
...
@@ -566,15 +530,15 @@ def test_equal_seqlen_decode_main(args):
unnormalized_scores
=
torch
.
exp
(
logits
-
logits_or_sinks_max
)
normalizer
=
unnormalized_scores
.
sum
(
dim
=-
1
,
keepdim
=
True
)
+
sinks
attn_weights
=
unnormalized_scores
/
normalizer
O_torch
=
torch
.
matmul
(
attn_weights
.
to
(
v_repeat
.
dtype
),
v_repeat
).
squeeze
(
2
)
# [batch, q_heads, head_size]
O_torch
=
torch
.
matmul
(
attn_weights
.
to
(
v_repeat
.
dtype
),
v_repeat
).
squeeze
(
2
)
# [batch, q_heads, head_size]
# Compute attention score pooling
attn_score_pooled
=
torch
.
max_pool2d
(
attn_weights
.
squeeze
(
2
),
# [b, q_heads, k_seqlen]
kernel_size
=
(
q_heads
,
block_size
),
stride
=
(
q_heads
,
block_size
),
ceil_mode
=
True
).
to
(
torch
.
float16
)
ceil_mode
=
True
,
).
to
(
torch
.
float16
)
print
(
"S_tilelang"
,
S_tilelang
)
print
(
"attn_score_pooled"
,
attn_score_pooled
)
...
...
@@ -588,15 +552,10 @@ def test_equal_seqlen_decode_main(args):
print
(
f
"Max difference in S:
{
max_diff_s
.
item
()
}
"
)
print
(
f
"Max difference in O_tilelang:
{
max_diff_o_tilelang
.
item
()
}
"
)
print
(
f
"Max difference in S_tilelang:
{
max_diff_s_tilelang
.
item
()
}
"
)
assert
torch
.
allclose
(
O_triton
,
O_torch
,
atol
=
1e-2
,
rtol
=
1e-2
),
f
"Output mismatch:
{
max_diff_o
.
item
()
}
"
assert
torch
.
allclose
(
S_triton
,
attn_score_pooled
,
atol
=
1e-2
,
rtol
=
1e-2
),
f
"Score mismatch:
{
max_diff_s
.
item
()
}
"
assert
torch
.
allclose
(
O_tilelang
,
O_torch
,
atol
=
1e-2
,
rtol
=
1e-2
),
f
"Output mismatch:
{
max_diff_o_tilelang
.
item
()
}
"
assert
torch
.
allclose
(
S_tilelang
,
attn_score_pooled
,
atol
=
1e-2
,
rtol
=
1e-2
),
f
"Score mismatch:
{
max_diff_s_tilelang
.
item
()
}
"
assert
torch
.
allclose
(
O_triton
,
O_torch
,
atol
=
1e-2
,
rtol
=
1e-2
),
f
"Output mismatch:
{
max_diff_o
.
item
()
}
"
assert
torch
.
allclose
(
S_triton
,
attn_score_pooled
,
atol
=
1e-2
,
rtol
=
1e-2
),
f
"Score mismatch:
{
max_diff_s
.
item
()
}
"
assert
torch
.
allclose
(
O_tilelang
,
O_torch
,
atol
=
1e-2
,
rtol
=
1e-2
),
f
"Output mismatch:
{
max_diff_o_tilelang
.
item
()
}
"
assert
torch
.
allclose
(
S_tilelang
,
attn_score_pooled
,
atol
=
1e-2
,
rtol
=
1e-2
),
f
"Score mismatch:
{
max_diff_s_tilelang
.
item
()
}
"
print
(
"✅ All tests passed!"
)
...
...
@@ -609,14 +568,14 @@ def test_varlen_decode_main(args):
real_max_k_seqlen
=
args
.
k_seqlen
head_size
=
args
.
head_size
block_size
=
args
.
block_size
dtype
=
torch
.
bfloat16
if
args
.
dtype
==
"
bfloat16
"
else
torch
.
float16
dtype
=
torch
.
bfloat16
if
args
.
dtype
==
T
.
bfloat16
else
torch
.
float16
print
(
f
"Testing decode kernel with variable sequence lengths (max_k_seqlen=
{
max_k_seqlen
}
)"
)
# Generate sink values if needed
sink
=
None
if
args
.
test_sink
:
sink
=
torch
.
randn
(
q_heads
,
device
=
'
cuda
'
,
dtype
=
torch
.
float32
)
*
0.1
# Small sink values
sink
=
torch
.
randn
(
q_heads
,
device
=
"
cuda
"
,
dtype
=
torch
.
float32
)
*
0.1
# Small sink values
print
(
f
"Using sink attention with sink values:
{
sink
}
"
)
# Generate variable length k sequences
...
...
@@ -624,7 +583,7 @@ def test_varlen_decode_main(args):
print
(
f
"k_seqlens:
{
k_seqlens
}
"
)
# Generate cumulative sequence lengths for k
cu_seqlens_k
=
torch
.
zeros
(
batch_size
+
1
,
device
=
'
cuda
'
,
dtype
=
torch
.
int32
)
cu_seqlens_k
=
torch
.
zeros
(
batch_size
+
1
,
device
=
"
cuda
"
,
dtype
=
torch
.
int32
)
total_k_tokens
=
0
for
i
in
range
(
batch_size
):
cu_seqlens_k
[
i
]
=
total_k_tokens
...
...
@@ -634,9 +593,9 @@ def test_varlen_decode_main(args):
print
(
f
"cu_seqlens_k:
{
cu_seqlens_k
}
"
)
# Generate tensors - Q is [batch_size, q_heads, head_size] for decode
q_decode
=
torch
.
randn
(
batch_size
,
q_heads
,
head_size
,
device
=
'
cuda
'
,
dtype
=
dtype
)
k_varlen
=
torch
.
randn
(
total_k_tokens
,
kv_heads
,
head_size
,
device
=
'
cuda
'
,
dtype
=
dtype
)
v_varlen
=
torch
.
randn
(
total_k_tokens
,
kv_heads
,
head_size
,
device
=
'
cuda
'
,
dtype
=
dtype
)
q_decode
=
torch
.
randn
(
batch_size
,
q_heads
,
head_size
,
device
=
"
cuda
"
,
dtype
=
dtype
)
k_varlen
=
torch
.
randn
(
total_k_tokens
,
kv_heads
,
head_size
,
device
=
"
cuda
"
,
dtype
=
dtype
)
v_varlen
=
torch
.
randn
(
total_k_tokens
,
kv_heads
,
head_size
,
device
=
"
cuda
"
,
dtype
=
dtype
)
softmax_scale
=
1.0
/
math
.
sqrt
(
head_size
)
max_seqlen_k
=
int
(
k_seqlens
.
max
())
...
...
@@ -649,8 +608,7 @@ def test_varlen_decode_main(args):
num_tokens
,
q_h
,
head_size
=
q_decode
.
shape
batch
=
cu_seqlens_k
.
size
(
0
)
-
1
k_h
=
k_varlen
.
size
(
1
)
tl_kernel
=
flashattn
(
batch
,
q_h
,
k_h
,
args
.
k_seqlen
,
cu_seqlens_k
[
-
1
].
item
(),
head_size
,
args
.
test_sink
)
tl_kernel
=
flashattn
(
batch
,
q_h
,
k_h
,
args
.
k_seqlen
,
cu_seqlens_k
[
-
1
].
item
(),
head_size
,
args
.
test_sink
)
# Test our decode kernel
O_triton
,
S_triton
=
flash_attn_with_attn_pool_decode
(
...
...
@@ -663,7 +621,8 @@ def test_varlen_decode_main(args):
args
.
num_split
,
softmax_scale
,
s_aux
=
sink
,
block_size
=
block_size
)
block_size
=
block_size
,
)
O_tilelang
,
S_tilelang
=
flash_attn_with_attn_pool_decode_tilelang
(
q_decode
,
k_varlen
,
...
...
@@ -678,9 +637,7 @@ def test_varlen_decode_main(args):
tl_kernel
=
tl_kernel
,
)
for
i
in
range
(
batch_size
):
S_tilelang
[
i
,
:,
math
.
ceil
((
cu_seqlens_k
[
i
+
1
].
item
()
-
cu_seqlens_k
[
i
].
item
())
/
block_size
):]
=
0
S_tilelang
[
i
,
:,
math
.
ceil
((
cu_seqlens_k
[
i
+
1
].
item
()
-
cu_seqlens_k
[
i
].
item
())
/
block_size
)
:]
=
0
# Create torch reference - pad tensors for comparison
k_padded_list
=
[]
...
...
@@ -694,8 +651,8 @@ def test_varlen_decode_main(args):
k_end
=
cu_seqlens_k
[
i
+
1
]
# Pad to max_seqlen_k
k_padded
=
torch
.
zeros
(
max_seqlen_k
,
kv_heads
,
head_size
,
device
=
'
cuda
'
,
dtype
=
dtype
)
v_padded
=
torch
.
zeros
(
max_seqlen_k
,
kv_heads
,
head_size
,
device
=
'
cuda
'
,
dtype
=
dtype
)
k_padded
=
torch
.
zeros
(
max_seqlen_k
,
kv_heads
,
head_size
,
device
=
"
cuda
"
,
dtype
=
dtype
)
v_padded
=
torch
.
zeros
(
max_seqlen_k
,
kv_heads
,
head_size
,
device
=
"
cuda
"
,
dtype
=
dtype
)
k_padded
[:
actual_k_len
]
=
k_varlen
[
k_start
:
k_end
]
v_padded
[:
actual_k_len
]
=
v_varlen
[
k_start
:
k_end
]
...
...
@@ -704,10 +661,8 @@ def test_varlen_decode_main(args):
v_padded_list
.
append
(
v_padded
)
# Stack to create batched tensors [b, max_seqlen, kv_heads, head_size]
k_padded_batched
=
torch
.
stack
(
k_padded_list
,
dim
=
0
).
transpose
(
1
,
2
)
# [b, kv_heads, max_seqlen, head_size]
v_padded_batched
=
torch
.
stack
(
v_padded_list
,
dim
=
0
).
transpose
(
1
,
2
)
# [b, kv_heads, max_seqlen, head_size]
k_padded_batched
=
torch
.
stack
(
k_padded_list
,
dim
=
0
).
transpose
(
1
,
2
)
# [b, kv_heads, max_seqlen, head_size]
v_padded_batched
=
torch
.
stack
(
v_padded_list
,
dim
=
0
).
transpose
(
1
,
2
)
# [b, kv_heads, max_seqlen, head_size]
# Expand q to match kv heads: [b, q_heads, 1, head_size]
q_expanded
=
q_decode
.
unsqueeze
(
2
)
# [b, q_heads, 1, head_size]
...
...
@@ -717,20 +672,17 @@ def test_varlen_decode_main(args):
print
(
f
"v_padded_batched shape:
{
v_padded_batched
.
shape
}
"
)
# Compute torch reference
k_repeat
=
repeat_kv
(
k_padded_batched
,
q_heads
//
kv_heads
)
# [b, q_heads, max_seqlen, head_size]
v_repeat
=
repeat_kv
(
v_padded_batched
,
q_heads
//
kv_heads
)
# [b, q_heads, max_seqlen, head_size]
k_repeat
=
repeat_kv
(
k_padded_batched
,
q_heads
//
kv_heads
)
# [b, q_heads, max_seqlen, head_size]
v_repeat
=
repeat_kv
(
v_padded_batched
,
q_heads
//
kv_heads
)
# [b, q_heads, max_seqlen, head_size]
if
sink
is
None
:
# Standard attention computation: [b, q_heads, 1, head_size] @ [b, q_heads, head_size, max_seqlen]
attn_score
=
torch
.
matmul
(
q_expanded
,
k_repeat
.
transpose
(
-
2
,
-
1
))
*
softmax_scale
# [b, q_heads, 1, max_seqlen]
attn_score
=
torch
.
matmul
(
q_expanded
,
k_repeat
.
transpose
(
-
2
,
-
1
))
*
softmax_scale
# [b, q_heads, 1, max_seqlen]
# Apply sequence length masking
for
i
in
range
(
batch_size
):
actual_k_len
=
k_seqlens
[
i
]
attn_score
[
i
,
:,
:,
actual_k_len
:]
=
float
(
'
-inf
'
)
attn_score
[
i
,
:,
:,
actual_k_len
:]
=
float
(
"
-inf
"
)
attn_weights
=
attn_score
.
softmax
(
dim
=-
1
)
# [b, q_heads, 1, max_seqlen]
...
...
@@ -743,13 +695,12 @@ def test_varlen_decode_main(args):
O_torch
=
torch
.
matmul
(
attn_weights
,
v_repeat
)
# [b, q_heads, 1, head_size]
else
:
# s_aux attention
logits
=
torch
.
matmul
(
q_expanded
,
k_repeat
.
transpose
(
-
2
,
-
1
))
*
softmax_scale
# [b, q_heads, 1, max_seqlen]
logits
=
torch
.
matmul
(
q_expanded
,
k_repeat
.
transpose
(
-
2
,
-
1
))
*
softmax_scale
# [b, q_heads, 1, max_seqlen]
# Apply sequence length masking
for
i
in
range
(
batch_size
):
actual_k_len
=
k_seqlens
[
i
]
logits
[
i
,
:,
:,
actual_k_len
:]
=
float
(
'
-inf
'
)
logits
[
i
,
:,
:,
actual_k_len
:]
=
float
(
"
-inf
"
)
sink_expanded
=
sink
.
view
(
1
,
q_heads
,
1
,
1
)
# [1, q_heads, 1, 1]
logits_max
=
torch
.
max
(
logits
,
dim
=-
1
,
keepdim
=
True
).
values
...
...
@@ -765,8 +716,7 @@ def test_varlen_decode_main(args):
attn_weights
[
i
,
:,
:,
actual_k_len
:]
=
0.0
# Compute output: [b, q_heads, 1, max_seqlen] @ [b, q_heads, max_seqlen, head_size]
O_torch
=
torch
.
matmul
(
attn_weights
.
to
(
v_repeat
.
dtype
),
v_repeat
)
# [b, q_heads, 1, head_size]
O_torch
=
torch
.
matmul
(
attn_weights
.
to
(
v_repeat
.
dtype
),
v_repeat
)
# [b, q_heads, 1, head_size]
O_torch
=
O_torch
.
squeeze
(
2
)
# [b, q_heads, head_size]
...
...
@@ -775,7 +725,8 @@ def test_varlen_decode_main(args):
attn_weights
.
squeeze
(
2
),
# [b, q_heads, max_seqlen]
kernel_size
=
(
q_heads
,
block_size
),
stride
=
(
q_heads
,
block_size
),
ceil_mode
=
True
).
to
(
dtype
=
torch
.
float16
)
# [b, 1, ceil(max_seqlen/block_size)]
ceil_mode
=
True
,
).
to
(
dtype
=
torch
.
float16
)
# [b, 1, ceil(max_seqlen/block_size)]
print
(
f
"O_triton shape:
{
O_triton
.
shape
}
"
)
print
(
f
"O_tilelang shape:
{
O_tilelang
.
shape
}
"
)
...
...
@@ -791,22 +742,16 @@ def test_varlen_decode_main(args):
print
(
f
"Max difference in O_tilelang:
{
max_diff_o_tl
.
item
()
}
"
)
max_diff_s
=
torch
.
max
(
torch
.
abs
(
S_triton
-
attn_score_pooled
))
max_diff_s_tl
=
torch
.
max
(
torch
.
abs
(
S_tilelang
[:,
:,
:
math
.
ceil
(
max_seqlen_k
/
block_size
)]
-
attn_score_pooled
))
max_diff_s_tl
=
torch
.
max
(
torch
.
abs
(
S_tilelang
[:,
:,
:
math
.
ceil
(
max_seqlen_k
/
block_size
)]
-
attn_score_pooled
))
print
(
f
"Max difference in S:
{
max_diff_s
.
item
()
}
"
)
print
(
f
"Max difference in S_tilelang:
{
max_diff_s_tl
.
item
()
}
"
)
assert
torch
.
allclose
(
O_triton
,
O_torch
,
atol
=
1e-2
,
rtol
=
1e-2
),
f
"Output mismatch:
{
max_diff_o
.
item
()
}
"
assert
torch
.
allclose
(
S_triton
,
attn_score_pooled
,
atol
=
1e-2
,
rtol
=
1e-2
),
f
"Score mismatch:
{
max_diff_s
.
item
()
}
"
assert
torch
.
allclose
(
O_tilelang
,
O_torch
,
atol
=
1e-2
,
rtol
=
1e-2
),
f
"Output mismatch:
{
max_diff_o_tl
.
item
()
}
"
assert
torch
.
allclose
(
S_tilelang
[:,
:,
:
math
.
ceil
(
max_seqlen_k
/
block_size
)],
attn_score_pooled
,
atol
=
1e-2
,
rtol
=
1e-2
),
f
"Score mismatch:
{
max_diff_s_tl
.
item
()
}
"
assert
torch
.
allclose
(
O_triton
,
O_torch
,
atol
=
1e-2
,
rtol
=
1e-2
),
f
"Output mismatch:
{
max_diff_o
.
item
()
}
"
assert
torch
.
allclose
(
S_triton
,
attn_score_pooled
,
atol
=
1e-2
,
rtol
=
1e-2
),
f
"Score mismatch:
{
max_diff_s
.
item
()
}
"
assert
torch
.
allclose
(
O_tilelang
,
O_torch
,
atol
=
1e-2
,
rtol
=
1e-2
),
f
"Output mismatch:
{
max_diff_o_tl
.
item
()
}
"
assert
torch
.
allclose
(
S_tilelang
[:,
:,
:
math
.
ceil
(
max_seqlen_k
/
block_size
)],
attn_score_pooled
,
atol
=
1e-2
,
rtol
=
1e-2
),
(
f
"Score mismatch:
{
max_diff_s_tl
.
item
()
}
"
)
print
(
"✅ All tests passed!"
)
...
...
@@ -844,7 +789,7 @@ def speed_benchmark_decode_comparison(args):
max_k_seqlen
=
args
.
k_seqlen
head_size
=
args
.
head_size
block_size
=
args
.
block_size
dtype
=
torch
.
bfloat16
if
args
.
dtype
==
"
bfloat16
"
else
torch
.
float16
dtype
=
torch
.
bfloat16
if
args
.
dtype
==
T
.
bfloat16
else
torch
.
float16
print
(
"
\n
=== Decode Speed Benchmark Comparison ==="
)
print
(
"Configuration:"
)
...
...
@@ -865,7 +810,7 @@ def speed_benchmark_decode_comparison(args):
k_seqlens
=
torch
.
full
((
batch_size
,),
max_k_seqlen
,
dtype
=
int
)
# Generate cumulative sequence lengths for k
cu_seqlens_k
=
torch
.
zeros
(
batch_size
+
1
,
device
=
'
cuda
'
,
dtype
=
torch
.
int32
)
cu_seqlens_k
=
torch
.
zeros
(
batch_size
+
1
,
device
=
"
cuda
"
,
dtype
=
torch
.
int32
)
total_k_tokens
=
0
for
i
in
range
(
batch_size
):
cu_seqlens_k
[
i
]
=
total_k_tokens
...
...
@@ -873,9 +818,9 @@ def speed_benchmark_decode_comparison(args):
cu_seqlens_k
[
batch_size
]
=
total_k_tokens
# Generate tensors
q_decode
=
torch
.
randn
(
batch_size
,
q_heads
,
head_size
,
device
=
'
cuda
'
,
dtype
=
dtype
)
k_varlen
=
torch
.
randn
(
total_k_tokens
,
kv_heads
,
head_size
,
device
=
'
cuda
'
,
dtype
=
dtype
)
v_varlen
=
torch
.
randn
(
total_k_tokens
,
kv_heads
,
head_size
,
device
=
'
cuda
'
,
dtype
=
dtype
)
q_decode
=
torch
.
randn
(
batch_size
,
q_heads
,
head_size
,
device
=
"
cuda
"
,
dtype
=
dtype
)
k_varlen
=
torch
.
randn
(
total_k_tokens
,
kv_heads
,
head_size
,
device
=
"
cuda
"
,
dtype
=
dtype
)
v_varlen
=
torch
.
randn
(
total_k_tokens
,
kv_heads
,
head_size
,
device
=
"
cuda
"
,
dtype
=
dtype
)
softmax_scale
=
1.0
/
math
.
sqrt
(
head_size
)
max_seqlen_k
=
int
(
k_seqlens
.
max
())
...
...
@@ -883,7 +828,7 @@ def speed_benchmark_decode_comparison(args):
# Generate sink values if needed
sink
=
None
if
args
.
test_sink
:
sink
=
torch
.
randn
(
q_heads
,
device
=
'
cuda
'
,
dtype
=
torch
.
float32
)
*
0.1
# Small sink values
sink
=
torch
.
randn
(
q_heads
,
device
=
"
cuda
"
,
dtype
=
torch
.
float32
)
*
0.1
# Small sink values
print
(
" Using sink attention with sink values"
)
print
(
"Setup complete:"
)
...
...
@@ -896,8 +841,7 @@ def speed_benchmark_decode_comparison(args):
num_tokens
,
q_h
,
head_size
=
q_decode
.
shape
batch
=
cu_seqlens_k
.
size
(
0
)
-
1
k_h
=
k_varlen
.
size
(
1
)
tl_kernel
=
flashattn
(
batch
,
q_h
,
k_h
,
args
.
k_seqlen
,
cu_seqlens_k
[
-
1
].
item
(),
head_size
,
args
.
test_sink
)
tl_kernel
=
flashattn
(
batch
,
q_h
,
k_h
,
args
.
k_seqlen
,
cu_seqlens_k
[
-
1
].
item
(),
head_size
,
args
.
test_sink
)
# Benchmark
print
(
"⚡ Benchmarking Tilelang kernel (100 iterations)..."
)
...
...
@@ -920,36 +864,41 @@ def speed_benchmark_decode_comparison(args):
# Benchmark
print
(
"⚡ Benchmarking Triton kernel (100 iterations)..."
)
triton_time
=
do_bench
(
flash_attn_with_attn_pool_decode
,
q_decode
,
k_varlen
,
v_varlen
,
cu_seqlens_k
,
max_seqlen_k
,
args
.
k_seqlen
,
1
,
softmax_scale
,
sink
,
block_size
)
triton_time
=
do_bench
(
flash_attn_with_attn_pool_decode
,
q_decode
,
k_varlen
,
v_varlen
,
cu_seqlens_k
,
max_seqlen_k
,
args
.
k_seqlen
,
1
,
softmax_scale
,
sink
,
block_size
,
)
print
(
f
"Average decode kernel time Triton:
{
triton_time
:.
3
f
}
ms"
)
print
(
f
"Speedup:
{
(
triton_time
/
tilelang_time
):.
3
f
}
"
)
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
(
description
=
'Flash Attention Decode with Attention Pooling'
)
parser
.
add_argument
(
'--batch_size'
,
type
=
int
,
default
=
1
,
help
=
'Batch size'
)
parser
.
add_argument
(
'--q_heads'
,
type
=
int
,
default
=
32
,
help
=
'Number of query heads'
)
parser
.
add_argument
(
'--kv_heads'
,
type
=
int
,
default
=
8
,
help
=
'Number of key-value heads'
)
parser
.
add_argument
(
'--k_seqlen'
,
type
=
int
,
default
=
8192
,
help
=
'Key sequence length'
)
parser
.
add_argument
(
'--head_size'
,
type
=
int
,
default
=
128
,
choices
=
[
64
,
128
,
256
],
help
=
'Head dimension'
)
parser
.
add_argument
(
'--block_size'
,
type
=
int
,
default
=
64
,
help
=
'Block size for computation'
)
parser
.
add_argument
(
'--dtype'
,
type
=
str
,
default
=
'bfloat16'
,
choices
=
[
'float16'
,
'bfloat16'
],
help
=
'Data type'
)
parser
.
add_argument
(
'--test_varlen'
,
action
=
'store_true'
,
help
=
'Test with truly variable sequence lengths'
)
parser
.
add_argument
(
'--test_sink'
,
action
=
'store_true'
,
help
=
'Test with sink attention mechanism'
)
parser
.
add_argument
(
'--benchmark'
,
action
=
'store_true'
,
help
=
'Run speed benchmark'
)
parser
.
add_argument
(
'--num_split'
,
type
=
int
,
default
=
1
,
choices
=
[
1
,
16
],
help
=
'Number of splits'
)
parser
=
argparse
.
ArgumentParser
(
description
=
"Flash Attention Decode with Attention Pooling"
)
parser
.
add_argument
(
"--batch_size"
,
type
=
int
,
default
=
1
,
help
=
"Batch size"
)
parser
.
add_argument
(
"--q_heads"
,
type
=
int
,
default
=
32
,
help
=
"Number of query heads"
)
parser
.
add_argument
(
"--kv_heads"
,
type
=
int
,
default
=
8
,
help
=
"Number of key-value heads"
)
parser
.
add_argument
(
"--k_seqlen"
,
type
=
int
,
default
=
8192
,
help
=
"Key sequence length"
)
parser
.
add_argument
(
"--head_size"
,
type
=
int
,
default
=
128
,
choices
=
[
64
,
128
,
256
],
help
=
"Head dimension"
)
parser
.
add_argument
(
"--block_size"
,
type
=
int
,
default
=
64
,
help
=
"Block size for computation"
)
parser
.
add_argument
(
"--dtype"
,
type
=
str
,
default
=
T
.
bfloat16
,
choices
=
[
T
.
float16
,
T
.
bfloat16
],
help
=
"Data type"
)
parser
.
add_argument
(
"--test_varlen"
,
action
=
"store_true"
,
help
=
"Test with truly variable sequence lengths"
)
parser
.
add_argument
(
"--test_sink"
,
action
=
"store_true"
,
help
=
"Test with sink attention mechanism"
)
parser
.
add_argument
(
"--benchmark"
,
action
=
"store_true"
,
help
=
"Run speed benchmark"
)
parser
.
add_argument
(
"--num_split"
,
type
=
int
,
default
=
1
,
choices
=
[
1
,
16
],
help
=
"Number of splits"
)
args
=
parser
.
parse_args
()
args
.
test_sink
=
True
args
.
test_varlen
=
False
args
.
dtype
=
'
float16
'
args
.
dtype
=
T
.
float16
args
.
num_split
=
1
if
args
.
benchmark
:
...
...
examples/flash_decoding/example_gqa_decode_varlen_logits_paged.py
0 → 100644
View file @
667632cc
import
torch
import
math
import
argparse
import
tilelang
import
tilelang.language
as
T
from
example_gqa_decode_varlen_logits
import
flash_attn_with_attn_pool_decode
,
repeat_kv
,
do_bench
torch
.
manual_seed
(
0
)
def
get_configs
():
import
itertools
block_N
=
[
64
,
128
]
block_H
=
[
64
]
num_split
=
[
1
]
num_stages
=
[
1
,
2
,
3
]
threads
=
[
128
]
_configs
=
list
(
itertools
.
product
(
block_N
,
block_H
,
num_split
,
num_stages
,
threads
))
configs
=
[{
"block_N"
:
c
[
0
],
"block_H"
:
c
[
1
],
"num_split"
:
c
[
2
],
"num_stages"
:
c
[
3
],
"threads"
:
c
[
4
]}
for
c
in
_configs
]
return
configs
# @autotune(configs=get_configs(), warmup=10, rep=10)
@
tilelang
.
jit
(
out_idx
=
[
-
2
,
-
1
],
debug_root_path
=
"./examples/flash_decoding"
)
def
flashattn
(
batch
,
heads
,
k_heads
,
max_seqlen_kv
,
total_seqlen_k
,
dim
,
has_sink
,
page_block_size
,
block_N
=
128
,
block_H
=
64
,
num_split
=
1
,
num_stages
=
1
,
threads
=
128
,
):
scale
=
(
1.0
/
dim
)
**
0.5
*
1.44269504
# log2(e)
shape_q
=
[
batch
,
heads
,
dim
]
shape_k
=
[
total_seqlen_k
,
k_heads
,
dim
]
shape_v
=
[
total_seqlen_k
,
k_heads
,
dim
]
shape_o
=
[
batch
,
heads
,
dim
]
shape_s
=
[
batch
,
heads
,
math
.
ceil
(
max_seqlen_kv
/
block_N
)]
dtype
=
T
.
float16
accum_dtype
=
T
.
float32
kv_group_num
=
heads
//
k_heads
assert
page_block_size
>=
block_N
and
page_block_size
%
block_N
==
0
,
(
"page_block_size must be larger than block_N and a multiple of block_N"
)
valid_block_H
=
min
(
block_H
,
kv_group_num
)
# TODO: check if max_seqlen_kv is correct for varlen case
@
T
.
macro
def
flash_attn
(
Q
:
T
.
Tensor
(
shape_q
,
dtype
),
K
:
T
.
Tensor
(
shape_k
,
dtype
),
V
:
T
.
Tensor
(
shape_v
,
dtype
),
cu_seqlens_k
:
T
.
Tensor
([
batch
+
1
],
T
.
int32
),
s_aux
:
T
.
Tensor
([
heads
],
T
.
float32
),
BLOCK_TABLE
:
T
.
Tensor
([
batch
,
math
.
ceil
(
max_seqlen_kv
/
block_N
)],
T
.
int32
),
Output
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
S
:
T
.
Tensor
(
shape_s
,
dtype
),
):
with
T
.
Kernel
(
batch
,
heads
//
valid_block_H
,
num_split
,
threads
=
threads
)
as
(
bx
,
by
,
bz
):
Q_shared
=
T
.
alloc_shared
([
block_H
,
dim
],
dtype
)
K_shared
=
T
.
alloc_shared
([
block_N
,
dim
],
dtype
)
V_shared
=
T
.
alloc_shared
([
block_N
,
dim
],
dtype
)
O_shared
=
T
.
alloc_shared
([
valid_block_H
,
dim
],
dtype
)
acc_s
=
T
.
alloc_fragment
([
block_H
,
block_N
],
accum_dtype
)
acc_s_cast
=
T
.
alloc_fragment
([
block_H
,
block_N
],
dtype
)
acc_o
=
T
.
alloc_fragment
([
block_H
,
dim
],
accum_dtype
)
scores_max
=
T
.
alloc_fragment
([
block_H
],
accum_dtype
)
scores_max_prev
=
T
.
alloc_fragment
([
block_H
],
accum_dtype
)
scores_scale
=
T
.
alloc_fragment
([
block_H
],
accum_dtype
)
scores_sum
=
T
.
alloc_fragment
([
block_H
],
accum_dtype
)
logsum
=
T
.
alloc_fragment
([
block_H
],
accum_dtype
)
S_shared
=
T
.
alloc_shared
([
block_H
,
math
.
ceil
(
max_seqlen_kv
/
block_N
)],
dtype
)
s_aux_shared
=
T
.
alloc_shared
([
block_H
],
T
.
float32
)
bid
=
bx
hid
=
by
cur_kv_head
=
hid
//
(
kv_group_num
//
valid_block_H
)
cur_start_k
=
cu_seqlens_k
[
bid
]
cur_end_k
=
cu_seqlens_k
[
bid
+
1
]
cur_seqlen_k
=
cur_end_k
-
cur_start_k
T
.
copy
(
Q
[
bid
,
hid
*
valid_block_H
:
hid
*
valid_block_H
+
block_H
,
:],
Q_shared
)
T
.
fill
(
acc_o
,
0
)
T
.
fill
(
logsum
,
0
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
# loop_range = T.ceildiv((seqlen_kv // num_split), block_N)
loop_range
=
T
.
ceildiv
((
cur_seqlen_k
//
num_split
),
block_N
)
for
k
in
T
.
Pipelined
(
loop_range
,
num_stages
=
num_stages
):
k_start
=
BLOCK_TABLE
[
bid
,
(
k
*
block_N
)
//
page_block_size
]
*
page_block_size
+
(
k
*
block_N
)
%
page_block_size
T
.
copy
(
K
[
cur_start_k
+
k_start
:
cur_start_k
+
k_start
+
block_N
,
cur_kv_head
,
:],
K_shared
)
T
.
clear
(
acc_s
)
T
.
gemm
(
Q_shared
,
K_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
for
i
,
j
in
T
.
Parallel
(
block_H
,
block_N
):
acc_s
[
i
,
j
]
=
T
.
if_then_else
(
k
*
block_N
+
j
<
cur_seqlen_k
,
acc_s
[
i
,
j
],
-
T
.
infinity
(
accum_dtype
))
T
.
copy
(
scores_max
,
scores_max_prev
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
T
.
reduce_max
(
acc_s
,
scores_max
,
dim
=
1
,
clear
=
False
)
# scores_max_prev is m_i
# scores_max is row_max->m_ij in triton
T
.
copy
(
scores_max
,
S_shared
[:,
k
])
# scores_scale is alpha in triton
for
i
in
T
.
Parallel
(
block_H
):
scores_scale
[
i
]
=
T
.
exp2
(
scores_max_prev
[
i
]
*
scale
-
scores_max
[
i
]
*
scale
)
for
i
,
j
in
T
.
Parallel
(
block_H
,
block_N
):
acc_s
[
i
,
j
]
=
T
.
exp2
(
acc_s
[
i
,
j
]
*
scale
-
scores_max
[
i
]
*
scale
)
T
.
reduce_sum
(
acc_s
,
scores_sum
,
dim
=
1
)
# scores_sum is l_ij in triton
# logsum is l_i in triton
for
i
in
T
.
Parallel
(
block_H
):
logsum
[
i
]
=
logsum
[
i
]
*
scores_scale
[
i
]
+
scores_sum
[
i
]
T
.
copy
(
acc_s
,
acc_s_cast
)
for
i
,
j
in
T
.
Parallel
(
block_H
,
dim
):
acc_o
[
i
,
j
]
*=
scores_scale
[
i
]
v_start
=
BLOCK_TABLE
[
bid
,
(
k
*
block_N
)
//
page_block_size
]
*
page_block_size
+
(
k
*
block_N
)
%
page_block_size
T
.
copy
(
V
[
cur_start_k
+
v_start
:
cur_start_k
+
v_start
+
block_N
,
cur_kv_head
,
:],
V_shared
)
T
.
gemm
(
acc_s_cast
,
V_shared
,
acc_o
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
if
has_sink
:
T
.
copy
(
s_aux
[
hid
*
valid_block_H
:
hid
*
valid_block_H
+
block_H
],
s_aux_shared
)
for
i
in
T
.
Parallel
(
block_H
):
logsum
[
i
]
+=
s_aux_shared
[
i
]
for
i
,
j
in
T
.
Parallel
(
block_H
,
dim
):
acc_o
[
i
,
j
]
/=
logsum
[
i
]
for
h
,
k
in
T
.
Parallel
(
block_H
,
math
.
ceil
(
max_seqlen_kv
/
block_N
)):
S_shared
[
h
,
k
]
=
T
.
exp2
((
S_shared
[
h
,
k
]
-
scores_max
[
h
])
*
scale
)
/
logsum
[
h
]
for
i
in
T
.
Parallel
(
block_H
):
logsum
[
i
]
=
T
.
log2
(
logsum
[
i
])
+
scores_max
[
i
]
*
scale
T
.
copy
(
acc_o
[:
valid_block_H
,
:],
O_shared
)
T
.
copy
(
O_shared
,
Output
[
bid
,
hid
*
valid_block_H
:
(
hid
+
1
)
*
valid_block_H
,
:])
T
.
copy
(
S_shared
[:
valid_block_H
,
:],
S
[
bid
,
hid
*
valid_block_H
:
(
hid
+
1
)
*
valid_block_H
,
:])
@T.prim_func
def flashattn_gqa_decode_no_split(
        # NOTE(review): this prim_func is nested inside a kernel-factory function;
        # shape_q/shape_k/shape_v/shape_o/shape_s, batch, heads, max_seqlen_kv,
        # page_block_size and the flash_attn macro are closure variables of that
        # enclosing (not fully visible) factory.
        Q: T.Tensor(shape_q, dtype),
        K: T.Tensor(shape_k, dtype),
        V: T.Tensor(shape_v, dtype),
        # Exclusive prefix sums of per-batch KV lengths; length batch + 1.
        cu_seqlens_k: T.Tensor([batch + 1], T.int32),
        # Per-head attention-sink logits (float32), one entry per query head.
        s_aux: T.Tensor([heads], T.float32),
        # Paged-KV indirection table: one physical block id per logical page.
        BLOCK_TABLE: T.Tensor([batch, math.ceil(max_seqlen_kv / page_block_size)], T.int32),
        Output: T.Tensor(shape_o, dtype),
        # Per-block pooled attention logits written by the macro.
        S: T.Tensor(shape_s, dtype),
):
    # Single-pass (num_split == 1) decode: delegate everything to the
    # flash_attn macro defined earlier in the factory.
    flash_attn(Q, K, V, cu_seqlens_k, s_aux, BLOCK_TABLE, Output, S)
# TODO: split version
return
flashattn_gqa_decode_no_split
def flash_attn_with_attn_pool_decode_tilelang(
    Q: torch.Tensor,  ## [tq = b, q_h, q_dim]
    K: torch.Tensor,  ## [tk, k_h, k_dim]
    V: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_k: int,
    real_max_k_seqlen: int,
    num_split: int,
    softmax_scale: float,
    s_aux: torch.Tensor = None,
    block_size: int = 64,
    use_per_kv_head_sparse_index: bool = False,
    tl_kernel=None,
    block_table: torch.Tensor = None,
):
    """Run a pre-compiled tilelang decode kernel and pool its score output.

    Args:
        Q: Query tensor ``[batch, q_heads, head_dim]`` (one decode token per batch).
        K, V: Varlen key/value tensors ``[total_kv_tokens, kv_heads, head_dim]``.
        cu_seqlens_k: ``[batch + 1]`` cumulative KV sequence lengths.
        max_seqlen_k, real_max_k_seqlen, num_split, softmax_scale, block_size:
            Kept for signature compatibility with the triton variant; the
            compiled ``tl_kernel`` already baked these in, so they are unused here.
        s_aux: Optional per-query-head sink values, forwarded to the kernel.
        use_per_kv_head_sparse_index: If True, pool scores per KV-head group;
            otherwise pool across all query heads.
        tl_kernel: Compiled tilelang kernel returning ``(O, S)``.
        block_table: Paged-KV block table, forwarded to the kernel.

    Returns:
        Tuple ``(O_tl, S_tl)``: attention output and max-pooled block scores.
    """
    num_tokens, q_h, head_size = Q.shape
    batch = cu_seqlens_k.size(0) - 1
    k_h = K.size(1)
    assert Q.dim() == K.dim() == 3
    assert Q.size(2) == K.size(2)
    assert cu_seqlens_k.dim() == 1
    assert head_size in {64, 128, 256}
    assert Q.is_contiguous()
    assert K.is_contiguous()
    assert V.is_contiguous()
    gqa_group_size = q_h // k_h
    # The compiled kernel allocates and returns its own outputs; the previous
    # torch.zeros_like(Q) / torch.zeros(...) pre-allocations here were dead
    # (immediately overwritten) and have been removed.
    O_tl, S_tl = tl_kernel(Q, K, V, cu_seqlens_k, s_aux, block_table)
    if use_per_kv_head_sparse_index:
        # One pooled score row per KV head (collapse each GQA group).
        S_tl = torch.max_pool2d(S_tl, kernel_size=(gqa_group_size, 1), stride=(gqa_group_size, 1))
    else:
        # One pooled score row per batch (collapse all query heads).
        S_tl = torch.max_pool2d(S_tl, kernel_size=(q_h, 1), stride=(q_h, 1))
    return O_tl, S_tl
def test_equal_seqlen_decode_main(args):
    """Test decode kernel with equal sequence lengths"""
    # Correctness test: triton and tilelang decode kernels vs. a dense torch
    # reference, with every batch element sharing the same KV length.
    print("Testing decode kernel with equal sequence lengths")
    batch_size = args.batch_size
    q_heads = args.q_heads
    kv_heads = args.kv_heads
    k_seqlen = args.k_seqlen
    real_max_k_seqlen = args.k_seqlen
    head_size = args.head_size
    block_size = args.block_size
    page_block_size = args.page_block_size
    dtype = torch.bfloat16 if args.dtype == T.bfloat16 else torch.float16
    # For decode, query is just 1 token per batch
    q = torch.randn(batch_size, q_heads, head_size, device="cuda", dtype=dtype)
    k = torch.randn(batch_size, kv_heads, k_seqlen, head_size, device="cuda", dtype=dtype)
    v = torch.randn(batch_size, kv_heads, k_seqlen, head_size, device="cuda", dtype=dtype)
    softmax_scale = 1.0 / math.sqrt(head_size)
    # Generate sink values if needed
    sink = None
    if args.test_sink:
        sink = torch.randn(q_heads, device="cuda", dtype=torch.float32) * 0.1  # Small sink values
        print(f"Using sink attention with sink values: {sink}")
    # Convert to varlen format for K, V: [b, kv_h, s, d] -> [b*s, kv_h, d]
    k_varlen = k.transpose(1, 2).reshape(batch_size * k_seqlen, kv_heads, head_size).contiguous()
    v_varlen = v.transpose(1, 2).reshape(batch_size * k_seqlen, kv_heads, head_size).contiguous()
    # Generate cumulative sequence lengths (uniform stride since lengths are equal)
    cu_seqlens_k = torch.arange(0, (batch_size + 1) * k_seqlen, k_seqlen, device="cuda", dtype=torch.int32)
    max_seqlen_k = k_seqlen
    print(f"q shape: {q.shape}")
    print(f"k_varlen shape: {k_varlen.shape}")
    print(f"v_varlen shape: {v_varlen.shape}")
    num_tokens, q_h, head_size = q.shape
    batch = cu_seqlens_k.size(0) - 1
    k_h = k_varlen.size(1)
    # Compile the tilelang kernel for this exact problem shape.
    tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size, args.test_sink, page_block_size)
    # Identity paged-KV block table: pages are numbered consecutively in
    # cu_seqlens_k order, so physical layout == logical layout.
    block_table = torch.zeros(batch, math.ceil(real_max_k_seqlen / page_block_size), device="cuda", dtype=torch.int32)
    block_cnt = 0
    for i in range(batch):
        cur_seqlen = cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item()
        for j in range(math.ceil(cur_seqlen / page_block_size)):
            block_table[i, j] = block_cnt
            block_cnt += 1
    block_cnt = 0
    # Test our decode kernel
    O_triton, S_triton = flash_attn_with_attn_pool_decode(
        q,
        k_varlen,
        v_varlen,
        cu_seqlens_k,
        max_seqlen_k,
        real_max_k_seqlen,
        args.num_split,
        softmax_scale,
        s_aux=sink,
        block_size=block_size,
    )
    O_tilelang, S_tilelang = flash_attn_with_attn_pool_decode_tilelang(
        q,
        k_varlen,
        v_varlen,
        cu_seqlens_k,
        max_seqlen_k,
        real_max_k_seqlen,
        args.num_split,
        softmax_scale,
        s_aux=sink,
        block_size=block_size,
        tl_kernel=tl_kernel,
        block_table=block_table,
    )
    # Zero pooled scores past each sequence's last valid block before comparing.
    for i in range(batch_size):
        S_tilelang[i, :, math.ceil((cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item()) / block_size):] = 0
    # Compute torch reference
    q_expanded = q.unsqueeze(2)  # [b, q_heads, 1, head_size]
    k_repeat = repeat_kv(k, q_heads // kv_heads)  # [b, q_heads, k_seqlen, head_size]
    v_repeat = repeat_kv(v, q_heads // kv_heads)  # [b, q_heads, k_seqlen, head_size]
    if sink is None:
        # Standard scaled dot-product attention
        logits = torch.matmul(q_expanded, k_repeat.transpose(-2, -1)) * softmax_scale  # [batch, q_heads, 1, seqlen_k]
        attn_weights = torch.softmax(logits, dim=-1)
        O_torch = torch.matmul(attn_weights, v_repeat).squeeze(2)  # [batch, q_heads, head_size]
    else:
        # s_aux attention: the sink logit joins the softmax normalizer but
        # contributes no value vector.
        logits = torch.matmul(q_expanded, k_repeat.transpose(-2, -1)) * softmax_scale  # [batch, q_heads, 1, seqlen_k]
        sink_expanded = sink.view(1, q_heads, 1, 1)  # [1, q_heads, 1, 1]
        logits_max = torch.max(logits, dim=-1, keepdim=True).values
        logits_or_sinks_max = torch.maximum(logits_max, sink_expanded)
        sinks = torch.exp(sink_expanded - logits_or_sinks_max)
        unnormalized_scores = torch.exp(logits - logits_or_sinks_max)
        normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks
        attn_weights = unnormalized_scores / normalizer
        O_torch = torch.matmul(attn_weights.to(v_repeat.dtype), v_repeat).squeeze(2)  # [batch, q_heads, head_size]
    # Compute attention score pooling: max over (all heads) x (block_size keys).
    attn_score_pooled = torch.max_pool2d(
        attn_weights.squeeze(2),  # [b, q_heads, k_seqlen]
        kernel_size=(q_heads, block_size),
        stride=(q_heads, block_size),
        ceil_mode=True,
    ).to(torch.float16)
    print("S_tilelang", S_tilelang)
    print("attn_score_pooled", attn_score_pooled)
    max_diff_o = torch.max(torch.abs(O_triton - O_torch))
    max_diff_s = torch.max(torch.abs(S_triton - attn_score_pooled))
    max_diff_o_tilelang = torch.max(torch.abs(O_tilelang - O_torch))
    max_diff_s_tilelang = torch.max(torch.abs(S_tilelang - attn_score_pooled))
    print(f"Max difference in O: {max_diff_o.item()}")
    print(f"Max difference in S: {max_diff_s.item()}")
    print(f"Max difference in O_tilelang: {max_diff_o_tilelang.item()}")
    print(f"Max difference in S_tilelang: {max_diff_s_tilelang.item()}")
    assert torch.allclose(O_triton, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o.item()}"
    assert torch.allclose(S_triton, attn_score_pooled, atol=1e-2, rtol=1e-2), f"Score mismatch: {max_diff_s.item()}"
    assert torch.allclose(O_tilelang, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o_tilelang.item()}"
    assert torch.allclose(S_tilelang, attn_score_pooled, atol=1e-2, rtol=1e-2), f"Score mismatch: {max_diff_s_tilelang.item()}"
    print("✅ All tests passed!")
def test_varlen_decode_main(args):
    """Test decode kernel with variable sequence lengths"""
    # Correctness test with per-batch random KV lengths; the torch reference is
    # built by padding each varlen sequence to max_seqlen_k and masking.
    batch_size = args.batch_size
    q_heads = args.q_heads
    kv_heads = args.kv_heads
    max_k_seqlen = args.k_seqlen  # Use as max sequence length
    real_max_k_seqlen = args.k_seqlen
    head_size = args.head_size
    block_size = args.block_size
    page_block_size = args.page_block_size
    dtype = torch.bfloat16 if args.dtype == T.bfloat16 else torch.float16
    print(f"Testing decode kernel with variable sequence lengths (max_k_seqlen={max_k_seqlen})")
    # Generate sink values if needed
    sink = None
    if args.test_sink:
        sink = torch.randn(q_heads, device="cuda", dtype=torch.float32) * 0.1  # Small sink values
        print(f"Using sink attention with sink values: {sink}")
    # Generate variable length k sequences in [max/4, max]
    k_seqlens = torch.randint(max_k_seqlen // 4, max_k_seqlen + 1, size=(batch_size,))
    print(f"k_seqlens: {k_seqlens}")
    # Generate cumulative sequence lengths for k
    cu_seqlens_k = torch.zeros(batch_size + 1, device="cuda", dtype=torch.int32)
    total_k_tokens = 0
    for i in range(batch_size):
        cu_seqlens_k[i] = total_k_tokens
        total_k_tokens += k_seqlens[i]
    cu_seqlens_k[batch_size] = total_k_tokens
    print(f"cu_seqlens_k: {cu_seqlens_k}")
    # Generate tensors - Q is [batch_size, q_heads, head_size] for decode
    q_decode = torch.randn(batch_size, q_heads, head_size, device="cuda", dtype=dtype)
    k_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype)
    v_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype)
    softmax_scale = 1.0 / math.sqrt(head_size)
    max_seqlen_k = int(k_seqlens.max())
    print(f"Actual max_seqlen_k: {max_seqlen_k}")
    print(f"q_decode shape: {q_decode.shape}")
    print(f"k_varlen shape: {k_varlen.shape}")
    print(f"v_varlen shape: {v_varlen.shape}")
    num_tokens, q_h, head_size = q_decode.shape
    batch = cu_seqlens_k.size(0) - 1
    k_h = k_varlen.size(1)
    # Compile the tilelang kernel for this exact problem shape.
    tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size, args.test_sink, page_block_size)
    # Identity paged-KV block table (pages numbered consecutively).
    block_table = torch.zeros(batch, math.ceil(real_max_k_seqlen / page_block_size), device="cuda", dtype=torch.int32)
    block_cnt = 0
    for i in range(batch):
        cur_seqlen = cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item()
        for j in range(math.ceil(cur_seqlen / page_block_size)):
            block_table[i, j] = block_cnt
            block_cnt += 1
    block_cnt = 0
    # Test our decode kernel
    O_triton, S_triton = flash_attn_with_attn_pool_decode(
        q_decode,
        k_varlen,
        v_varlen,
        cu_seqlens_k,
        max_seqlen_k,
        real_max_k_seqlen,
        args.num_split,
        softmax_scale,
        s_aux=sink,
        block_size=block_size,
    )
    O_tilelang, S_tilelang = flash_attn_with_attn_pool_decode_tilelang(
        q_decode,
        k_varlen,
        v_varlen,
        cu_seqlens_k,
        max_seqlen_k,
        real_max_k_seqlen,
        args.num_split,
        softmax_scale,
        s_aux=sink,
        block_size=block_size,
        tl_kernel=tl_kernel,
        block_table=block_table,
    )
    # Zero pooled scores past each sequence's last valid block before comparing.
    for i in range(batch_size):
        S_tilelang[i, :, math.ceil((cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item()) / block_size):] = 0
    # Create torch reference - pad tensors for comparison
    k_padded_list = []
    v_padded_list = []
    for i in range(batch_size):
        actual_k_len = k_seqlens[i]
        # Extract and pad k, v for this batch
        k_start = cu_seqlens_k[i]
        k_end = cu_seqlens_k[i + 1]
        # Pad to max_seqlen_k
        k_padded = torch.zeros(max_seqlen_k, kv_heads, head_size, device="cuda", dtype=dtype)
        v_padded = torch.zeros(max_seqlen_k, kv_heads, head_size, device="cuda", dtype=dtype)
        k_padded[:actual_k_len] = k_varlen[k_start:k_end]
        v_padded[:actual_k_len] = v_varlen[k_start:k_end]
        k_padded_list.append(k_padded)
        v_padded_list.append(v_padded)
    # Stack to create batched tensors [b, max_seqlen, kv_heads, head_size]
    k_padded_batched = torch.stack(k_padded_list, dim=0).transpose(1, 2)  # [b, kv_heads, max_seqlen, head_size]
    v_padded_batched = torch.stack(v_padded_list, dim=0).transpose(1, 2)  # [b, kv_heads, max_seqlen, head_size]
    # Expand q to match kv heads: [b, q_heads, 1, head_size]
    q_expanded = q_decode.unsqueeze(2)  # [b, q_heads, 1, head_size]
    print(f"q_expanded shape: {q_expanded.shape}")
    print(f"k_padded_batched shape: {k_padded_batched.shape}")
    print(f"v_padded_batched shape: {v_padded_batched.shape}")
    # Compute torch reference
    k_repeat = repeat_kv(k_padded_batched, q_heads // kv_heads)  # [b, q_heads, max_seqlen, head_size]
    v_repeat = repeat_kv(v_padded_batched, q_heads // kv_heads)  # [b, q_heads, max_seqlen, head_size]
    if sink is None:
        # Standard attention computation: [b, q_heads, 1, head_size] @ [b, q_heads, head_size, max_seqlen]
        attn_score = torch.matmul(q_expanded, k_repeat.transpose(-2, -1)) * softmax_scale  # [b, q_heads, 1, max_seqlen]
        # Apply sequence length masking
        for i in range(batch_size):
            actual_k_len = k_seqlens[i]
            attn_score[i, :, :, actual_k_len:] = float("-inf")
        attn_weights = attn_score.softmax(dim=-1)  # [b, q_heads, 1, max_seqlen]
        # Mask out invalid positions
        for i in range(batch_size):
            actual_k_len = k_seqlens[i]
            attn_weights[i, :, :, actual_k_len:] = 0.0
        # Compute output: [b, q_heads, 1, max_seqlen] @ [b, q_heads, max_seqlen, head_size]
        O_torch = torch.matmul(attn_weights, v_repeat)  # [b, q_heads, 1, head_size]
    else:
        # s_aux attention: sink logit joins the normalizer, adds no value.
        logits = torch.matmul(q_expanded, k_repeat.transpose(-2, -1)) * softmax_scale  # [b, q_heads, 1, max_seqlen]
        # Apply sequence length masking
        for i in range(batch_size):
            actual_k_len = k_seqlens[i]
            logits[i, :, :, actual_k_len:] = float("-inf")
        sink_expanded = sink.view(1, q_heads, 1, 1)  # [1, q_heads, 1, 1]
        logits_max = torch.max(logits, dim=-1, keepdim=True).values
        logits_or_sinks_max = torch.maximum(logits_max, sink_expanded)
        sinks = torch.exp(sink_expanded - logits_or_sinks_max)
        unnormalized_scores = torch.exp(logits - logits_or_sinks_max)
        normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks
        attn_weights = unnormalized_scores / normalizer
        # Mask out invalid positions
        for i in range(batch_size):
            actual_k_len = k_seqlens[i]
            attn_weights[i, :, :, actual_k_len:] = 0.0
        # Compute output: [b, q_heads, 1, max_seqlen] @ [b, q_heads, max_seqlen, head_size]
        O_torch = torch.matmul(attn_weights.to(v_repeat.dtype), v_repeat)  # [b, q_heads, 1, head_size]
    O_torch = O_torch.squeeze(2)  # [b, q_heads, head_size]
    # Compute attention score pooling for S
    attn_score_pooled = torch.max_pool2d(
        attn_weights.squeeze(2),  # [b, q_heads, max_seqlen]
        kernel_size=(q_heads, block_size),
        stride=(q_heads, block_size),
        ceil_mode=True,
    ).to(dtype=torch.float16)  # [b, 1, ceil(max_seqlen/block_size)]
    print(f"O_triton shape: {O_triton.shape}")
    print(f"O_tilelang shape: {O_tilelang.shape}")
    print(f"O_torch shape: {O_torch.shape}")
    print(f"S_triton shape: {S_triton.shape}")
    print(f"S_tilelang shape: {S_tilelang.shape}")
    print(f"attn_score_pooled shape: {attn_score_pooled.shape}")
    # Compare results
    max_diff_o = torch.max(torch.abs(O_triton - O_torch))
    max_diff_o_tl = torch.max(torch.abs(O_tilelang - O_torch))
    print(f"Max difference in O: {max_diff_o.item()}")
    print(f"Max difference in O_tilelang: {max_diff_o_tl.item()}")
    max_diff_s = torch.max(torch.abs(S_triton - attn_score_pooled))
    # S_tilelang is sized by real_max_k_seqlen; compare only the blocks the
    # reference (sized by the actual max_seqlen_k) covers.
    max_diff_s_tl = torch.max(torch.abs(S_tilelang[:, :, :math.ceil(max_seqlen_k / block_size)] - attn_score_pooled))
    print(f"Max difference in S: {max_diff_s.item()}")
    print(f"Max difference in S_tilelang: {max_diff_s_tl.item()}")
    assert torch.allclose(O_triton, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o.item()}"
    assert torch.allclose(S_triton, attn_score_pooled, atol=1e-2, rtol=1e-2), f"Score mismatch: {max_diff_s.item()}"
    assert torch.allclose(O_tilelang, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o_tl.item()}"
    assert torch.allclose(
        S_tilelang[:, :, :math.ceil(max_seqlen_k / block_size)],
        attn_score_pooled,
        atol=1e-2,
        rtol=1e-2), (f"Score mismatch: {max_diff_s_tl.item()}")
    print("✅ All tests passed!")
def speed_benchmark_decode_comparison(args):
    """Speed benchmark for decode kernel"""
    # Benchmarks the tilelang kernel against the triton reference via do_bench
    # and prints the speedup; performs no correctness checks.
    batch_size = args.batch_size
    q_heads = args.q_heads
    kv_heads = args.kv_heads
    max_k_seqlen = args.k_seqlen
    real_max_k_seqlen = args.k_seqlen
    head_size = args.head_size
    block_size = args.block_size
    page_block_size = args.page_block_size
    dtype = torch.bfloat16 if args.dtype == T.bfloat16 else torch.float16
    print("\n=== Decode Speed Benchmark Comparison ===")
    print("Configuration:")
    print(f"  Batch size: {batch_size}")
    print(f"  Q heads: {q_heads}, KV heads: {kv_heads}")
    print(f"  Max K sequence length: {max_k_seqlen}")
    print(f"  Head size: {head_size}")
    print(f"  Block size: {block_size}")
    print(f"  Data type: {dtype}")
    print(f"  Variable lengths: {args.test_varlen}")
    print(f"  s_aux attention: {args.test_sink}")
    print()
    # Generate input data
    if args.test_varlen:
        k_seqlens = torch.randint(max_k_seqlen // 4, max_k_seqlen + 1, size=(batch_size,))
    else:
        k_seqlens = torch.full((batch_size,), max_k_seqlen, dtype=int)
    # Generate cumulative sequence lengths for k
    cu_seqlens_k = torch.zeros(batch_size + 1, device="cuda", dtype=torch.int32)
    total_k_tokens = 0
    for i in range(batch_size):
        cu_seqlens_k[i] = total_k_tokens
        total_k_tokens += k_seqlens[i]
    cu_seqlens_k[batch_size] = total_k_tokens
    # Generate tensors
    q_decode = torch.randn(batch_size, q_heads, head_size, device="cuda", dtype=dtype)
    k_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype)
    v_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype)
    softmax_scale = 1.0 / math.sqrt(head_size)
    max_seqlen_k = int(k_seqlens.max())
    # Generate sink values if needed
    sink = None
    if args.test_sink:
        sink = torch.randn(q_heads, device="cuda", dtype=torch.float32) * 0.1  # Small sink values
        print("  Using sink attention with sink values")
    print("Setup complete:")
    print(f"  Total K tokens: {total_k_tokens}")
    print(f"  Actual max K seq len: {max_seqlen_k}")
    if args.test_varlen:
        print(f"  K sequence lengths: {k_seqlens.tolist()}")
    # Warmup
    num_tokens, q_h, head_size = q_decode.shape
    batch = cu_seqlens_k.size(0) - 1
    k_h = k_varlen.size(1)
    # Compile the tilelang kernel for this exact problem shape.
    tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size, args.test_sink, page_block_size)
    # Identity paged-KV block table (pages numbered consecutively).
    block_table = torch.zeros(batch, math.ceil(real_max_k_seqlen / page_block_size), device="cuda", dtype=torch.int32)
    block_cnt = 0
    for i in range(batch):
        cur_seqlen = cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item()
        for j in range(math.ceil(cur_seqlen / page_block_size)):
            block_table[i, j] = block_cnt
            block_cnt += 1
    block_cnt = 0
    # Benchmark
    print("⚡ Benchmarking Tilelang kernel (100 iterations)...")
    # NOTE(review): both benchmarks pass num_split=1 positionally, regardless
    # of args.num_split — confirm this is intentional.
    tilelang_time = do_bench(
        flash_attn_with_attn_pool_decode_tilelang,
        q_decode,
        k_varlen,
        v_varlen,
        cu_seqlens_k,
        max_seqlen_k,
        args.k_seqlen,
        1,
        softmax_scale,
        sink,
        block_size,
        False,
        tl_kernel,
        block_table,
    )
    print(f"Average decode kernel time Tilelang: {tilelang_time:.3f} ms")
    # Benchmark
    print("⚡ Benchmarking Triton kernel (100 iterations)...")
    triton_time = do_bench(
        flash_attn_with_attn_pool_decode,
        q_decode,
        k_varlen,
        v_varlen,
        cu_seqlens_k,
        max_seqlen_k,
        args.k_seqlen,
        1,
        softmax_scale,
        sink,
        block_size,
    )
    print(f"Average decode kernel time Triton: {triton_time:.3f} ms")
    print(f"Speedup: {(triton_time / tilelang_time):.3f}")
if __name__ == "__main__":
    # CLI entry point: parse options, then dispatch to the benchmark or one of
    # the two correctness tests.
    parser = argparse.ArgumentParser(description="Flash Attention Decode with Attention Pooling")
    parser.add_argument("--batch_size", type=int, default=1, help="Batch size")
    parser.add_argument("--q_heads", type=int, default=32, help="Number of query heads")
    parser.add_argument("--kv_heads", type=int, default=8, help="Number of key-value heads")
    parser.add_argument("--k_seqlen", type=int, default=8192, help="Key sequence length")
    parser.add_argument("--head_size", type=int, default=128, choices=[64, 128, 256], help="Head dimension")
    parser.add_argument("--block_size", type=int, default=128, help="Block size for computation")
    # NOTE(review): default/choices here are T.* dtype objects while type=str —
    # a user-supplied string must compare equal to T.float16/T.bfloat16 for
    # choices to accept it; verify T dtypes are string-like.
    parser.add_argument("--dtype", type=str, default=T.bfloat16, choices=[T.float16, T.bfloat16], help="Data type")
    parser.add_argument("--test_varlen", action="store_true", help="Test with truly variable sequence lengths")
    parser.add_argument("--test_sink", action="store_true", help="Test with sink attention mechanism")
    parser.add_argument("--benchmark", action="store_true", help="Run speed benchmark")
    parser.add_argument("--num_split", type=int, default=1, choices=[1, 16], help="Number of splits")
    parser.add_argument("--page_block_size", type=int, default=128, help="Page block size")
    args = parser.parse_args()
    # NOTE(review): the parsed flags below are unconditionally overridden,
    # which pins the example to the sink+varlen fp16 single-split path and
    # makes the corresponding CLI options no-ops — confirm this is deliberate.
    args.test_sink = True
    args.test_varlen = True
    args.dtype = T.float16
    args.num_split = 1
    if args.benchmark:
        speed_benchmark_decode_comparison(args)
    elif args.test_varlen:
        test_varlen_decode_main(args)
    else:
        # Unreachable while test_varlen is forced True above.
        test_equal_seqlen_decode_main(args)
examples/flash_decoding/example_mha_inference.py
View file @
667632cc
...
...
@@ -10,12 +10,12 @@ num_split = 4
@
tilelang
.
jit
(
out_idx
=
[
5
])
def
flashattn
(
batch
,
heads
,
seqlen_q
,
seqlen_kv
,
dim
,
is_causal
,
block_M
,
block_N
):
scale
=
(
1.0
/
dim
)
**
0.5
*
1.44269504
# log2(e)
scale
=
(
1.0
/
dim
)
**
0.5
*
1.44269504
# log2(e)
shape_q
=
[
batch
,
seqlen_q
,
heads
,
dim
]
shape_kv
=
[
batch
,
seqlen_kv
,
heads
,
dim
]
part_shape
=
[
batch
,
seqlen_q
,
heads
,
num_split
,
dim
]
dtype
=
"
float16
"
accum_dtype
=
"
float
"
dtype
=
T
.
float16
accum_dtype
=
T
.
float
32
@
T
.
macro
def
MMA0
(
...
...
@@ -29,14 +29,11 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_
bid
:
T
.
int32
,
sid
:
T
.
int32
,
):
T
.
copy
(
K
[
bid
,
(
seqlen_kv
//
num_split
)
*
sid
+
k
*
block_N
:(
seqlen_kv
//
num_split
)
*
sid
+
(
k
+
1
)
*
block_N
,
hid
,
:],
K_shared
)
T
.
copy
(
K
[
bid
,
(
seqlen_kv
//
num_split
)
*
sid
+
k
*
block_N
:
(
seqlen_kv
//
num_split
)
*
sid
+
(
k
+
1
)
*
block_N
,
hid
,
:],
K_shared
)
# TODO: Handle causal split case
if
is_causal
:
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
acc_s
[
i
,
j
]
=
T
.
if_then_else
(
mid
*
block_M
+
i
>=
k
*
block_N
+
j
,
0
,
-
T
.
infinity
(
acc_s
.
dtype
))
acc_s
[
i
,
j
]
=
T
.
if_then_else
(
mid
*
block_M
+
i
>=
k
*
block_N
+
j
,
0
,
-
T
.
infinity
(
acc_s
.
dtype
))
else
:
T
.
clear
(
acc_s
)
T
.
gemm
(
Q_shared
,
K_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
...
...
@@ -52,24 +49,24 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_
bid
:
T
.
int32
,
sid
:
T
.
int32
,
):
T
.
copy
(
V
[
bid
,
(
seqlen_kv
//
num_split
)
*
sid
+
k
*
block_N
:(
seqlen_kv
//
num_split
)
*
sid
+
(
k
+
1
)
*
block_N
,
hid
,
:],
V_shared
)
T
.
copy
(
V
[
bid
,
(
seqlen_kv
//
num_split
)
*
sid
+
k
*
block_N
:
(
seqlen_kv
//
num_split
)
*
sid
+
(
k
+
1
)
*
block_N
,
hid
,
:],
V_shared
)
T
.
gemm
(
acc_s_cast
,
V_shared
,
acc_o
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
@
T
.
macro
def
Softmax
(
acc_s
:
T
.
FragmentBuffer
([
block_M
,
block_N
],
accum_dtype
),
acc_s_cast
:
T
.
FragmentBuffer
([
block_M
,
block_N
],
dtype
),
scores_max
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
scores_max_prev
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
scores_scale
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
scores_sum
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
logsum
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
acc_s
:
T
.
FragmentBuffer
([
block_M
,
block_N
],
accum_dtype
),
acc_s_cast
:
T
.
FragmentBuffer
([
block_M
,
block_N
],
dtype
),
scores_max
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
scores_max_prev
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
scores_scale
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
scores_sum
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
logsum
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
):
T
.
copy
(
scores_max
,
scores_max_prev
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
T
.
reduce_max
(
acc_s
,
scores_max
,
dim
=
1
,
clear
=
False
)
for
i
in
T
.
Parallel
(
block_M
):
scores_max
[
i
]
=
T
.
max
(
scores_max
[
i
],
scores_max_prev
[
i
])
# To do causal softmax, we need to set the scores_max to 0 if it is -inf
# This process is called Check_inf in FlashAttention3 code, and it only need to be done
# in the first ceil_div(kBlockM, kBlockN) steps.
...
...
@@ -89,23 +86,21 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_
@
T
.
macro
def
Rescale
(
acc_o
:
T
.
FragmentBuffer
([
block_M
,
dim
],
accum_dtype
),
scores_scale
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
acc_o
:
T
.
FragmentBuffer
([
block_M
,
dim
],
accum_dtype
),
scores_scale
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
):
for
i
,
j
in
T
.
Parallel
(
block_M
,
dim
):
acc_o
[
i
,
j
]
*=
scores_scale
[
i
]
@
T
.
macro
def
flash_attn_split
(
Q
:
T
.
Tensor
(
shape_q
,
dtype
),
K
:
T
.
Tensor
(
shape_kv
,
dtype
),
V
:
T
.
Tensor
(
shape_kv
,
dtype
),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
,
seqlen_q
],
dtype
),
Output_partial
:
T
.
Tensor
(
part_shape
,
dtype
),
Q
:
T
.
Tensor
(
shape_q
,
dtype
),
K
:
T
.
Tensor
(
shape_kv
,
dtype
),
V
:
T
.
Tensor
(
shape_kv
,
dtype
),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
,
seqlen_q
],
dtype
),
Output_partial
:
T
.
Tensor
(
part_shape
,
dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
seqlen_q
,
block_M
),
heads
*
batch
,
num_split
,
threads
=
128
)
as
(
bx
,
by
,
bz
):
with
T
.
Kernel
(
T
.
ceildiv
(
seqlen_q
,
block_M
),
heads
*
batch
,
num_split
,
threads
=
128
)
as
(
bx
,
by
,
bz
):
Q_shared
=
T
.
alloc_shared
([
block_M
,
dim
],
dtype
)
K_shared
=
T
.
alloc_shared
([
block_N
,
dim
],
dtype
)
V_shared
=
T
.
alloc_shared
([
block_N
,
dim
],
dtype
)
...
...
@@ -126,39 +121,36 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_
# NOTE(wt): tma barrier has some problems with padded dimensions (seq_q here) currently
# disable relevant tma copy and use SIMT as fallback for now
T
.
copy
(
Q
[
bid
,
mid
*
block_M
:
(
mid
+
1
)
*
block_M
,
hid
,
:],
Q_shared
,
disable_tma
=
True
)
T
.
copy
(
Q
[
bid
,
mid
*
block_M
:
(
mid
+
1
)
*
block_M
,
hid
,
:],
Q_shared
,
disable_tma
=
True
)
T
.
fill
(
acc_o
,
0
)
T
.
fill
(
logsum
,
0
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
# TODO: Handle causal split case
loop_range
=
(
T
.
min
(
T
.
ceildiv
(
seqlen_kv
,
block_N
),
T
.
ceildiv
(
(
mid
+
1
)
*
block_M
,
block_N
))
if
is_causal
else
T
.
ceildiv
(
(
seqlen_kv
//
num_split
),
block_N
))
T
.
min
(
T
.
ceildiv
(
seqlen_kv
,
block_N
),
T
.
ceildiv
((
mid
+
1
)
*
block_M
,
block_N
))
if
is_causal
else
T
.
ceildiv
((
seqlen_kv
//
num_split
),
block_N
)
)
for
k
in
T
.
Pipelined
(
loop_range
,
num_stages
=
2
):
MMA0
(
K
,
Q_shared
,
K_shared
,
acc_s
,
k
,
mid
,
hid
,
bid
,
sid
)
Softmax
(
acc_s
,
acc_s_cast
,
scores_max
,
scores_max_prev
,
scores_scale
,
scores_sum
,
logsum
)
Softmax
(
acc_s
,
acc_s_cast
,
scores_max
,
scores_max_prev
,
scores_scale
,
scores_sum
,
logsum
)
Rescale
(
acc_o
,
scores_scale
)
MMA1
(
V
,
V_shared
,
acc_s_cast
,
acc_o
,
k
,
hid
,
bid
,
sid
)
for
i
,
j
in
T
.
Parallel
(
block_M
,
dim
):
acc_o
[
i
,
j
]
/=
logsum
[
i
]
for
i
in
T
.
Parallel
(
block_M
):
logsum
[
i
]
=
T
.
log2
(
logsum
[
i
])
+
scores_max
[
i
]
*
scale
T
.
copy
(
logsum
,
glse
[
bid
,
hid
,
sid
,
mid
*
block_M
:
(
mid
+
1
)
*
block_M
])
T
.
copy
(
logsum
,
glse
[
bid
,
hid
,
sid
,
mid
*
block_M
:
(
mid
+
1
)
*
block_M
])
T
.
copy
(
acc_o
,
O_shared
)
T
.
copy
(
O_shared
,
Output_partial
[
bid
,
mid
*
block_M
:(
mid
+
1
)
*
block_M
,
hid
,
sid
,
:],
disable_tma
=
True
)
T
.
copy
(
O_shared
,
Output_partial
[
bid
,
mid
*
block_M
:
(
mid
+
1
)
*
block_M
,
hid
,
sid
,
:],
disable_tma
=
True
)
@
T
.
macro
def
combine
(
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
,
seqlen_q
],
dtype
),
Output_partial
:
T
.
Tensor
(
part_shape
,
dtype
),
Output
:
T
.
Tensor
(
shape_q
,
dtype
),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
,
seqlen_q
],
dtype
),
Output_partial
:
T
.
Tensor
(
part_shape
,
dtype
),
Output
:
T
.
Tensor
(
shape_q
,
dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
seqlen_q
,
block_M
),
heads
,
batch
,
threads
=
128
)
as
(
bx
,
by
,
bz
):
po_local
=
T
.
alloc_fragment
([
block_M
,
dim
],
dtype
)
...
...
@@ -171,20 +163,25 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_
lse_max_local
=
T
.
alloc_fragment
([
block_M
],
accum_dtype
)
scale_local
=
T
.
alloc_fragment
([
block_M
],
accum_dtype
)
T
.
annotate_layout
({
o_accum_local
:
T
.
Fragment
(
o_accum_local
.
shape
,
forward_thread_fn
=
lambda
i
,
j
:
i
),
o_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
o_shared
),
po_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
po_shared
),
})
T
.
annotate_layout
(
{
o_accum_local
:
T
.
Fragment
(
o_accum_local
.
shape
,
forward_thread_fn
=
lambda
i
,
j
:
i
),
o_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
o_shared
),
po_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
po_shared
),
}
)
T
.
clear
(
lse_logsum_local
)
T
.
clear
(
o_accum_local
)
T
.
copy
(
glse
[
bz
,
by
,
:,
bx
*
block_M
:(
bx
+
1
)
*
block_M
,
],
lse_local
)
T
.
copy
(
glse
[
bz
,
by
,
:,
bx
*
block_M
:
(
bx
+
1
)
*
block_M
,
],
lse_local
,
)
T
.
reduce_max
(
lse_local
,
lse_max_local
,
dim
=
0
,
clear
=
False
)
for
k
in
T
.
Pipelined
(
num_split
):
T
.
copy
(
lse_local
[
k
,
:],
lse_local_split
)
...
...
@@ -193,10 +190,7 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_
for
i
in
T
.
Parallel
(
block_M
):
lse_logsum_local
[
i
]
=
T
.
log2
(
lse_logsum_local
[
i
])
+
lse_max_local
[
i
]
for
k
in
T
.
Pipelined
(
num_split
,
num_stages
=
2
):
T
.
copy
(
Output_partial
[
bz
,
bx
*
block_M
:(
bx
+
1
)
*
block_M
,
by
,
k
,
:],
po_shared
,
disable_tma
=
True
)
T
.
copy
(
Output_partial
[
bz
,
bx
*
block_M
:
(
bx
+
1
)
*
block_M
,
by
,
k
,
:],
po_shared
,
disable_tma
=
True
)
T
.
copy
(
po_shared
,
po_local
)
for
i
in
T
.
Parallel
(
block_M
):
lse_local_split
[
i
]
=
lse_local
[
k
,
i
]
...
...
@@ -205,16 +199,16 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_
for
i
,
j
in
T
.
Parallel
(
block_M
,
dim
):
o_accum_local
[
i
,
j
]
+=
po_local
[
i
,
j
]
*
scale_local
[
i
]
T
.
copy
(
o_accum_local
,
o_shared
)
T
.
copy
(
o_shared
,
Output
[
bz
,
bx
*
block_M
:
(
bx
+
1
)
*
block_M
,
by
,
:],
disable_tma
=
True
)
T
.
copy
(
o_shared
,
Output
[
bz
,
bx
*
block_M
:
(
bx
+
1
)
*
block_M
,
by
,
:],
disable_tma
=
True
)
@
T
.
prim_func
def
flashattn_mha_inference
(
Q
:
T
.
Tensor
(
shape_q
,
dtype
),
K
:
T
.
Tensor
(
shape_kv
,
dtype
),
V
:
T
.
Tensor
(
shape_kv
,
dtype
),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
,
seqlen_q
],
dtype
),
Output_partial
:
T
.
Tensor
(
part_shape
,
dtype
),
# [batch, seqlen_q, heads, num_split, dim]
Output
:
T
.
Tensor
(
shape_q
,
dtype
),
Q
:
T
.
Tensor
(
shape_q
,
dtype
),
K
:
T
.
Tensor
(
shape_kv
,
dtype
),
V
:
T
.
Tensor
(
shape_kv
,
dtype
),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
,
seqlen_q
],
dtype
),
Output_partial
:
T
.
Tensor
(
part_shape
,
dtype
),
# [batch, seqlen_q, heads, num_split, dim]
Output
:
T
.
Tensor
(
shape_q
,
dtype
),
):
flash_attn_split
(
Q
,
K
,
V
,
glse
,
Output_partial
)
combine
(
glse
,
Output_partial
,
Output
)
...
...
@@ -225,10 +219,10 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_
def
ref_program
(
Q
,
K
,
V
,
glse
,
Output_partial
,
causal
):
assert
causal
is
False
dim
=
Q
.
size
(
-
1
)
scores
=
torch
.
einsum
(
'
bqhd,bkhd->bhqk
'
,
Q
,
K
)
scores
=
torch
.
einsum
(
"
bqhd,bkhd->bhqk
"
,
Q
,
K
)
scores
=
scores
/
torch
.
sqrt
(
torch
.
tensor
(
dim
,
dtype
=
scores
.
dtype
))
attention_weights
=
F
.
softmax
(
scores
,
dim
=-
1
)
output
=
torch
.
einsum
(
'
bhqk,bkhd->bqhd
'
,
attention_weights
,
V
)
output
=
torch
.
einsum
(
"
bhqk,bkhd->bqhd
"
,
attention_weights
,
V
)
return
output
...
...
@@ -256,7 +250,7 @@ def flash_split_ref(Q, K, V, causal):
block_N
=
128
seqlen_kv
=
K
.
size
(
1
)
scale
=
(
1.0
/
dim
)
**
0.5
*
1.44269504
# log2(e)
scale
=
(
1.0
/
dim
)
**
0.5
*
1.44269504
# log2(e)
acc_s
=
torch
.
empty
((
batch
,
nheads
,
block_M
,
block_N
),
device
=
"cuda"
,
dtype
=
torch
.
float
)
acc_s_cast
=
torch
.
empty
((
batch
,
nheads
,
block_M
,
block_N
),
device
=
"cuda"
,
dtype
=
torch
.
float16
)
acc_o
=
torch
.
empty
((
batch
,
block_M
,
nheads
,
dim
),
device
=
"cuda"
,
dtype
=
torch
.
float
)
...
...
@@ -273,14 +267,15 @@ def flash_split_ref(Q, K, V, causal):
for
ks
in
range
(
num_split
):
acc_o
.
fill_
(
0
)
logsum
.
fill_
(
0
)
scores_max
.
fill_
(
float
(
'
-inf
'
))
scores_max_prev
.
fill_
(
float
(
'
-inf
'
))
scores_max
.
fill_
(
float
(
"
-inf
"
))
scores_max_prev
.
fill_
(
float
(
"
-inf
"
))
for
i
in
range
(
int
((
seqlen_kv
//
num_split
)
/
block_N
)):
acc_s
.
fill_
(
0
)
acc_s
=
torch
.
einsum
(
'bqhd,bkhd->bhqk'
,
Q_
,
K
[:,
(
seqlen_kv
//
num_split
)
*
ks
+
i
*
block_N
:(
seqlen_kv
//
num_split
)
*
ks
+
(
i
+
1
)
*
block_N
,
:,
:])
# [batch, seqlen, nheads, block_N]
acc_s
=
torch
.
einsum
(
"bqhd,bkhd->bhqk"
,
Q_
,
K
[:,
(
seqlen_kv
//
num_split
)
*
ks
+
i
*
block_N
:
(
seqlen_kv
//
num_split
)
*
ks
+
(
i
+
1
)
*
block_N
,
:,
:],
)
# [batch, seqlen, nheads, block_N]
scores_max_prev
=
scores_max
scores_max
=
acc_s
.
max
(
dim
=-
1
,
keepdim
=
False
).
values
# [blockM]
scores_scale
=
torch
.
exp2
(
scores_max_prev
-
scores_max
)
...
...
@@ -288,9 +283,10 @@ def flash_split_ref(Q, K, V, causal):
acc_s
=
torch
.
exp2
(
acc_s
-
scores_max
[:,
:,
:,
None
])
acc_s_cast
=
acc_s
.
to
(
torch
.
float16
)
acc_o
+=
torch
.
einsum
(
'bhqk,bkhd->bqhd'
,
acc_s_cast
,
V
[:,
(
seqlen_kv
//
num_split
)
*
ks
+
i
*
block_N
:(
seqlen_kv
//
num_split
)
*
ks
+
(
i
+
1
)
*
block_N
,
:,
:])
"bhqk,bkhd->bqhd"
,
acc_s_cast
,
V
[:,
(
seqlen_kv
//
num_split
)
*
ks
+
i
*
block_N
:
(
seqlen_kv
//
num_split
)
*
ks
+
(
i
+
1
)
*
block_N
,
:,
:],
)
scores_sum
=
acc_s
.
sum
(
dim
=-
1
,
keepdim
=
False
)
logsum
=
logsum
*
scores_scale
+
scores_sum
acc_o
/=
logsum
[:,
:,
:,
None
].
transpose
(
1
,
2
)
...
...
@@ -298,8 +294,7 @@ def flash_split_ref(Q, K, V, causal):
gacc_o
[
ks
,
:,
:,
:,
:]
=
acc_o
glogsum
[
ks
,
:,
:,
:]
=
logsum
return
glogsum
.
to
(
torch
.
float16
).
permute
(
1
,
2
,
0
,
3
),
gacc_o
.
to
(
torch
.
float16
).
permute
(
1
,
2
,
3
,
0
,
4
)
return
glogsum
.
to
(
torch
.
float16
).
permute
(
1
,
2
,
0
,
3
),
gacc_o
.
to
(
torch
.
float16
).
permute
(
1
,
2
,
3
,
0
,
4
)
def
main
(
BATCH
=
1
,
H
=
32
,
Q_CTX
=
128
,
KV_CTX
=
8192
,
D_HEAD
=
128
,
causal
=
False
):
...
...
examples/fusedmoe/example_fusedmoe_tilelang.py
View file @
667632cc
...
...
@@ -9,17 +9,18 @@ from example_fusedmoe_torch import *
@
tilelang
.
jit
(
pass_configs
=
{
"tl.disable_tma_lower"
:
True
,
"tl.disable_warp_specialized"
:
True
})
def
moe_forward_tilelang_shared
(
d_hidden
,
d_expert
,
n_shared_experts
,
dtype
,
num_tokens
,
block_token
=
128
,
block_dhidden
=
128
,
block_dexpert
=
128
,
threads
=
256
,
num_stages
=
1
):
def
moe_forward_tilelang_shared
(
d_hidden
,
d_expert
,
n_shared_experts
,
dtype
,
num_tokens
,
block_token
=
128
,
block_dhidden
=
128
,
block_dexpert
=
128
,
threads
=
256
,
num_stages
=
1
,
):
scale
=
1.44269504
# log2(e)
# Parameters
...
...
@@ -32,21 +33,19 @@ def moe_forward_tilelang_shared(d_hidden,
shared_W_up_shape
=
(
dexpert
,
dhidden
)
shared_W_down_shape
=
(
dhidden
,
dexpert
)
accum_type
=
"
float32
"
accum_type
=
T
.
float32
@
T
.
prim_func
def
kernel_shared
(
input
:
T
.
Tensor
(
input_shape
,
dtype
),
# type: ignore
shared_W_gate
:
T
.
Tensor
(
shared_W_gate_shape
,
dtype
),
# type: ignore
shared_W_up
:
T
.
Tensor
(
shared_W_up_shape
,
dtype
),
# type: ignore
shared_W_down
:
T
.
Tensor
(
shared_W_down_shape
,
dtype
),
# type: ignore
up_logits
:
T
.
Tensor
((
num_tokens
,
dexpert
),
dtype
),
# type: ignore
output
:
T
.
Tensor
(
input_shape
,
dtype
),
# type: ignore
input
:
T
.
Tensor
(
input_shape
,
dtype
),
# type: ignore
shared_W_gate
:
T
.
Tensor
(
shared_W_gate_shape
,
dtype
),
# type: ignore
shared_W_up
:
T
.
Tensor
(
shared_W_up_shape
,
dtype
),
# type: ignore
shared_W_down
:
T
.
Tensor
(
shared_W_down_shape
,
dtype
),
# type: ignore
up_logits
:
T
.
Tensor
((
num_tokens
,
dexpert
),
dtype
),
# type: ignore
output
:
T
.
Tensor
(
input_shape
,
dtype
),
# type: ignore
):
# Step 1: Compute gate and up logits
with
T
.
Kernel
(
T
.
ceildiv
(
num_tokens
,
block_token
),
T
.
ceildiv
(
dexpert
,
block_dexpert
),
threads
=
threads
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
num_tokens
,
block_token
),
T
.
ceildiv
(
dexpert
,
block_dexpert
),
threads
=
threads
)
as
(
bx
,
by
):
# Split the block to shared experts and routed experts
input_shared
=
T
.
alloc_fragment
((
block_token
,
block_dhidden
),
dtype
=
dtype
)
W_gate_shared
=
T
.
alloc_shared
((
block_dexpert
,
block_dhidden
),
dtype
=
dtype
)
...
...
@@ -70,16 +69,13 @@ def moe_forward_tilelang_shared(d_hidden,
# Fuse with SiLU and element-wise product
for
i
,
j
in
T
.
Parallel
(
block_token
,
block_dexpert
):
gate_logits_local
[
i
,
j
]
=
gate_logits_local
[
i
,
j
]
*
(
1.0
/
(
1.0
+
T
.
exp2
(
-
gate_logits_local
[
i
,
j
]
*
scale
)))
gate_logits_local
[
i
,
j
]
=
gate_logits_local
[
i
,
j
]
*
(
1.0
/
(
1.0
+
T
.
exp2
(
-
gate_logits_local
[
i
,
j
]
*
scale
)))
up_logits_local
[
i
,
j
]
=
up_logits_local
[
i
,
j
]
*
gate_logits_local
[
i
,
j
]
T
.
copy
(
up_logits_local
,
up_logits
[
bx
*
block_token
,
by
*
block_dexpert
])
# Step 2: Compute down logits
with
T
.
Kernel
(
T
.
ceildiv
(
num_tokens
,
block_token
),
T
.
ceildiv
(
dhidden
,
block_dhidden
),
threads
=
threads
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
num_tokens
,
block_token
),
T
.
ceildiv
(
dhidden
,
block_dhidden
),
threads
=
threads
)
as
(
bx
,
by
):
up_logits_shared
=
T
.
alloc_fragment
((
block_token
,
block_dexpert
),
dtype
=
dtype
)
W_down_shared
=
T
.
alloc_shared
((
block_dhidden
,
block_dexpert
),
dtype
=
dtype
)
output_local
=
T
.
alloc_fragment
((
block_token
,
block_dhidden
),
dtype
=
accum_type
)
...
...
@@ -98,20 +94,21 @@ def moe_forward_tilelang_shared(d_hidden,
@
tilelang
.
jit
(
pass_configs
=
{
"tl.disable_tma_lower"
:
True
,
"tl.disable_warp_specialized"
:
True
})
def
moe_forward_tilelang_routed
(
d_hidden
,
d_expert
,
n_routed_experts
,
dtype
,
group_sum
,
group_count
,
block_token
=
128
,
block_dhidden
=
128
,
block_dexpert
=
128
,
threads
=
256
,
num_stages
=
1
,
k_pack
=
1
,
coalesced_width
=
None
):
def
moe_forward_tilelang_routed
(
d_hidden
,
d_expert
,
n_routed_experts
,
dtype
,
group_sum
,
group_count
,
block_token
=
128
,
block_dhidden
=
128
,
block_dexpert
=
128
,
threads
=
256
,
num_stages
=
1
,
k_pack
=
1
,
coalesced_width
=
None
,
):
scale
=
1.44269504
# log2(e)
# Parameters
...
...
@@ -124,7 +121,7 @@ def moe_forward_tilelang_routed(d_hidden,
# group_count = len(group_sizes_list)
# M = sum([(group_size + block_token - 1) // block_token for group_size in group_sizes_list])
M
=
math
.
ceil
(
group_sum
/
block_token
)
+
group_count
accum_dtype
=
"
float32
"
accum_dtype
=
T
.
float32
# Tensors: Note that input shape is reshape to (bs * seq_len * n_experts_per_token, dhidden) for grouped gemm
input_shape
=
(
group_sum
,
dhidden
)
...
...
@@ -132,22 +129,22 @@ def moe_forward_tilelang_routed(d_hidden,
routed_expert_gate_shape
=
(
n_routed_experts
,
dexpert
,
dhidden
)
routed_expert_up_shape
=
(
n_routed_experts
,
dexpert
,
dhidden
)
routed_expert_down_shape
=
(
n_routed_experts
,
dhidden
,
dexpert
)
routed_expert_weights_shape
=
(
group_sum
)
group_sizes_shape
=
(
n_routed_experts
)
routed_expert_weights_shape
=
group_sum
group_sizes_shape
=
n_routed_experts
@
T
.
prim_func
def
kernel
(
input
:
T
.
Tensor
(
input_shape
,
dtype
),
# type: ignore
routed_expert_gate
:
T
.
Tensor
(
routed_expert_gate_shape
,
dtype
),
# type: ignore
routed_expert_up
:
T
.
Tensor
(
routed_expert_up_shape
,
dtype
),
# type: ignore
routed_expert_down
:
T
.
Tensor
(
routed_expert_down_shape
,
dtype
),
# type: ignore
routed_expert_weights
:
T
.
Tensor
(
routed_expert_weights_shape
,
dtype
),
# type: ignore
group_sizes
:
T
.
Tensor
(
group_sizes_shape
,
"
int32
"
),
# type: ignore
group_offsets
:
T
.
Tensor
(
group_sizes_shape
,
"
int32
"
),
# type: ignore
group_padded_offsets
:
T
.
Tensor
(
group_sizes_shape
,
"
int32
"
),
# type: ignore
group_idx_for_bx
:
T
.
Tensor
((
M
,),
"
int32
"
),
# type: ignore
up_logits
:
T
.
Tensor
(
intermediate_shape
,
dtype
),
# type: ignore
output
:
T
.
Tensor
(
input_shape
,
dtype
),
# type: ignore
input
:
T
.
Tensor
(
input_shape
,
dtype
),
# type: ignore
routed_expert_gate
:
T
.
Tensor
(
routed_expert_gate_shape
,
dtype
),
# type: ignore
routed_expert_up
:
T
.
Tensor
(
routed_expert_up_shape
,
dtype
),
# type: ignore
routed_expert_down
:
T
.
Tensor
(
routed_expert_down_shape
,
dtype
),
# type: ignore
routed_expert_weights
:
T
.
Tensor
(
routed_expert_weights_shape
,
dtype
),
# type: ignore
group_sizes
:
T
.
Tensor
(
group_sizes_shape
,
T
.
int32
),
# type: ignore
group_offsets
:
T
.
Tensor
(
group_sizes_shape
,
T
.
int32
),
# type: ignore
group_padded_offsets
:
T
.
Tensor
(
group_sizes_shape
,
T
.
int32
),
# type: ignore
group_idx_for_bx
:
T
.
Tensor
((
M
,),
T
.
int32
),
# type: ignore
up_logits
:
T
.
Tensor
(
intermediate_shape
,
dtype
),
# type: ignore
output
:
T
.
Tensor
(
input_shape
,
dtype
),
# type: ignore
):
# Step 1: Compute gate and up logits
with
T
.
Kernel
(
M
,
T
.
ceildiv
(
dexpert
,
block_dexpert
),
threads
=
threads
)
as
(
bx
,
by
):
...
...
@@ -158,8 +155,8 @@ def moe_forward_tilelang_routed(d_hidden,
gate_logits_local
=
T
.
alloc_fragment
((
block_token
,
block_dexpert
),
dtype
=
accum_dtype
)
up_logits_local
=
T
.
alloc_fragment
((
block_token
,
block_dexpert
),
dtype
=
accum_dtype
)
cur_group_idx
=
T
.
alloc_local
([
1
],
"
int32
"
)
cur_group_size
=
T
.
alloc_local
([
1
],
"
int32
"
)
cur_group_idx
=
T
.
alloc_local
([
1
],
T
.
int32
)
cur_group_size
=
T
.
alloc_local
([
1
],
T
.
int32
)
T
.
use_swizzle
(
10
,
enable
=
True
)
...
...
@@ -168,48 +165,37 @@ def moe_forward_tilelang_routed(d_hidden,
cur_group_idx
[
0
]
=
group_idx_for_bx
[
bx
]
cur_group_size
[
0
]
=
group_sizes
[
cur_group_idx
[
0
]]
m_start
=
m_start_padded
-
group_padded_offsets
[
cur_group_idx
[
0
]]
+
group_offsets
[
cur_group_idx
[
0
]]
actual_rows
=
T
.
max
(
0
,
T
.
min
(
block_token
,
cur_group_size
[
0
]
-
(
m_start_padded
-
group_padded_offsets
[
cur_group_idx
[
0
]])))
m_start
=
m_start_padded
-
group_padded_offsets
[
cur_group_idx
[
0
]]
+
group_offsets
[
cur_group_idx
[
0
]]
actual_rows
=
T
.
max
(
0
,
T
.
min
(
block_token
,
cur_group_size
[
0
]
-
(
m_start_padded
-
group_padded_offsets
[
cur_group_idx
[
0
]])))
T
.
clear
(
gate_logits_local
)
T
.
clear
(
up_logits_local
)
for
k
in
T
.
Pipelined
(
T
.
ceildiv
(
dhidden
,
block_dhidden
),
num_stages
=
num_stages
):
T
.
copy
(
input
[
m_start
:
m_start
+
block_token
,
k
*
block_dhidden
:
(
k
+
1
)
*
block_dhidden
],
input
[
m_start
:
m_start
+
block_token
,
k
*
block_dhidden
:
(
k
+
1
)
*
block_dhidden
],
input_shared
,
coalesced_width
=
coalesced_width
)
coalesced_width
=
coalesced_width
,
)
T
.
copy
(
routed_expert_gate
[
cur_group_idx
[
0
],
by
*
block_dexpert
:(
by
+
1
)
*
block_dexpert
,
k
*
block_dhidden
:(
k
+
1
)
*
block_dhidden
],
routed_expert_gate_shared
,
coalesced_width
=
coalesced_width
)
T
.
gemm
(
input_shared
,
routed_expert_gate
[
cur_group_idx
[
0
],
by
*
block_dexpert
:
(
by
+
1
)
*
block_dexpert
,
k
*
block_dhidden
:
(
k
+
1
)
*
block_dhidden
],
routed_expert_gate_shared
,
gate_logits_local
,
k_pack
=
k_pack
,
transpose_B
=
True
)
coalesced_width
=
coalesced_width
,
)
T
.
gemm
(
input_shared
,
routed_expert_gate_shared
,
gate_logits_local
,
k_pack
=
k_pack
,
transpose_B
=
True
)
T
.
copy
(
routed_expert_up
[
cur_group_idx
[
0
],
by
*
block_dexpert
:(
by
+
1
)
*
block_dexpert
,
k
*
block_dhidden
:(
k
+
1
)
*
block_dhidden
],
routed_expert_up
[
cur_group_idx
[
0
],
by
*
block_dexpert
:
(
by
+
1
)
*
block_dexpert
,
k
*
block_dhidden
:
(
k
+
1
)
*
block_dhidden
],
routed_expert_up_shared
,
coalesced_width
=
coalesced_width
)
T
.
gemm
(
input_shared
,
routed_expert_up_shared
,
up_logits_local
,
k_pack
=
k_pack
,
transpose_B
=
True
)
coalesced_width
=
coalesced_width
,
)
T
.
gemm
(
input_shared
,
routed_expert_up_shared
,
up_logits_local
,
k_pack
=
k_pack
,
transpose_B
=
True
)
for
i
,
j
in
T
.
Parallel
(
block_token
,
block_dexpert
):
gate_logits_local
[
i
,
j
]
=
gate_logits_local
[
i
,
j
]
*
(
1.0
/
(
1.0
+
T
.
exp2
(
-
gate_logits_local
[
i
,
j
]
*
scale
)))
gate_logits_local
[
i
,
j
]
=
gate_logits_local
[
i
,
j
]
*
(
1.0
/
(
1.0
+
T
.
exp2
(
-
gate_logits_local
[
i
,
j
]
*
scale
)))
up_logits_local
[
i
,
j
]
=
up_logits_local
[
i
,
j
]
*
gate_logits_local
[
i
,
j
]
for
i
,
j
in
T
.
Parallel
(
block_token
,
block_dexpert
):
...
...
@@ -222,8 +208,8 @@ def moe_forward_tilelang_routed(d_hidden,
routed_expert_down_shared
=
T
.
alloc_shared
((
block_dhidden
,
block_dexpert
),
dtype
=
dtype
)
output_local
=
T
.
alloc_fragment
((
block_token
,
block_dhidden
),
dtype
=
accum_dtype
)
cur_group_idx
=
T
.
alloc_local
([
1
],
"
int32
"
)
cur_group_size
=
T
.
alloc_local
([
1
],
"
int32
"
)
cur_group_idx
=
T
.
alloc_local
([
1
],
T
.
int32
)
cur_group_size
=
T
.
alloc_local
([
1
],
T
.
int32
)
T
.
use_swizzle
(
10
,
enable
=
True
)
...
...
@@ -232,50 +218,35 @@ def moe_forward_tilelang_routed(d_hidden,
cur_group_idx
[
0
]
=
group_idx_for_bx
[
bx
]
cur_group_size
[
0
]
=
group_sizes
[
cur_group_idx
[
0
]]
m_start
=
m_start_padded
-
group_padded_offsets
[
cur_group_idx
[
0
]]
+
group_offsets
[
cur_group_idx
[
0
]]
actual_rows
=
T
.
max
(
0
,
T
.
min
(
block_token
,
cur_group_size
[
0
]
-
(
m_start_padded
-
group_padded_offsets
[
cur_group_idx
[
0
]])))
m_start
=
m_start_padded
-
group_padded_offsets
[
cur_group_idx
[
0
]]
+
group_offsets
[
cur_group_idx
[
0
]]
actual_rows
=
T
.
max
(
0
,
T
.
min
(
block_token
,
cur_group_size
[
0
]
-
(
m_start_padded
-
group_padded_offsets
[
cur_group_idx
[
0
]])))
T
.
clear
(
output_local
)
for
k
in
T
.
Pipelined
(
T
.
ceildiv
(
dexpert
,
block_dexpert
),
num_stages
=
num_stages
):
T
.
copy
(
up_logits
[
m_start
:
m_start
+
block_token
,
k
*
block_dexpert
:(
k
+
1
)
*
block_dexpert
],
up_logits
[
m_start
:
m_start
+
block_token
,
k
*
block_dexpert
:
(
k
+
1
)
*
block_dexpert
],
up_logits_shared
,
coalesced_width
=
coalesced_width
)
coalesced_width
=
coalesced_width
,
)
T
.
copy
(
routed_expert_down
[
cur_group_idx
[
0
],
by
*
block_dhidden
:(
by
+
1
)
*
block_dhidden
,
k
*
block_dexpert
:(
k
+
1
)
*
block_dexpert
],
routed_expert_down_shared
,
coalesced_width
=
coalesced_width
)
T
.
gemm
(
up_logits_shared
,
routed_expert_down
[
cur_group_idx
[
0
],
by
*
block_dhidden
:
(
by
+
1
)
*
block_dhidden
,
k
*
block_dexpert
:
(
k
+
1
)
*
block_dexpert
],
routed_expert_down_shared
,
output_local
,
k_pack
=
k_pack
,
transpose_B
=
True
)
coalesced_width
=
coalesced_width
,
)
T
.
gemm
(
up_logits_shared
,
routed_expert_down_shared
,
output_local
,
k_pack
=
k_pack
,
transpose_B
=
True
)
for
i
,
j
in
T
.
Parallel
(
block_token
,
block_dhidden
):
if
i
<
actual_rows
:
output
[
m_start
+
i
,
by
*
block_dhidden
+
j
]
=
output_local
[
i
,
j
]
*
routed_expert_weights
[
m_start
+
i
]
output
[
m_start
+
i
,
by
*
block_dhidden
+
j
]
=
output_local
[
i
,
j
]
*
routed_expert_weights
[
m_start
+
i
]
return
kernel
class
Expert
(
nn
.
Module
):
def
__init__
(
self
,
config
:
Dict
,
gate
:
torch
.
Tensor
,
up
:
torch
.
Tensor
,
down
:
torch
.
Tensor
,
d_expert
:
Optional
[
int
]
=
None
):
def
__init__
(
self
,
config
:
Dict
,
gate
:
torch
.
Tensor
,
up
:
torch
.
Tensor
,
down
:
torch
.
Tensor
,
d_expert
:
Optional
[
int
]
=
None
):
super
().
__init__
()
self
.
config
=
config
self
.
act_fn
=
nn
.
SiLU
()
...
...
@@ -294,14 +265,13 @@ class Expert(nn.Module):
class
MoEGate
(
nn
.
Module
):
def
__init__
(
self
,
config
:
Dict
,
weights
:
Dict
):
super
().
__init__
()
self
.
top_k
:
int
=
config
[
"n_experts_per_token"
]
self
.
num_experts
:
int
=
config
[
"n_routed_experts"
]
self
.
d_hidden
:
int
=
config
[
"d_hidden"
]
self
.
W_g_weight
=
weights
[
'
router.weight
'
].
t
()
self
.
W_g_weight
=
weights
[
"
router.weight
"
].
t
()
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
logits
=
x
@
self
.
W_g_weight
...
...
@@ -312,76 +282,69 @@ class MoEGate(nn.Module):
class
MoE
(
nn
.
Module
):
def
__init__
(
self
,
config
:
Dict
,
shared_kernel
:
tilelang
.
JITKernel
,
routed_kernel
:
tilelang
.
JITKernel
,
weights
:
Dict
,
padding_M
:
int
=
128
):
def
__init__
(
self
,
config
:
Dict
,
shared_kernel
:
tilelang
.
JITKernel
,
routed_kernel
:
tilelang
.
JITKernel
,
weights
:
Dict
,
padding_M
:
int
=
128
):
super
().
__init__
()
self
.
config
=
config
self
.
shared_kernel
=
shared_kernel
self
.
routed_kernel
=
routed_kernel
self
.
padding_M
=
padding_M
self
.
experts
=
nn
.
ModuleList
([
Expert
(
config
,
gate
=
weights
[
f
'experts.
{
i
}
.0.weight'
],
up
=
weights
[
f
'experts.
{
i
}
.1.weight'
],
down
=
weights
[
f
'experts.
{
i
}
.2.weight'
])
for
i
in
range
(
config
[
"n_routed_experts"
])
])
self
.
experts
=
nn
.
ModuleList
(
[
Expert
(
config
,
gate
=
weights
[
f
"experts.
{
i
}
.0.weight"
],
up
=
weights
[
f
"experts.
{
i
}
.1.weight"
],
down
=
weights
[
f
"experts.
{
i
}
.2.weight"
],
)
for
i
in
range
(
config
[
"n_routed_experts"
])
]
)
self
.
device
=
torch
.
device
(
"cuda"
)
self
.
gating_network
=
MoEGate
(
config
,
weights
).
to
(
self
.
device
)
shared_expert_dim
=
config
[
"d_expert"
]
*
config
[
"n_shared_experts"
]
self
.
shared_expert
=
Expert
(
config
=
config
,
gate
=
weights
[
'shared_experts.0.weight'
],
up
=
weights
[
'shared_experts.1.weight'
],
down
=
weights
[
'shared_experts.2.weight'
],
d_expert
=
shared_expert_dim
).
to
(
self
.
device
)
gate
=
weights
[
"shared_experts.0.weight"
],
up
=
weights
[
"shared_experts.1.weight"
],
down
=
weights
[
"shared_experts.2.weight"
],
d_expert
=
shared_expert_dim
,
).
to
(
self
.
device
)
self
.
expert_cache
=
torch
.
zeros
(
(
config
[
"batch_size"
]
*
config
[
"seq_len"
],
config
[
"d_hidden"
]),
dtype
=
torch
.
float16
,
device
=
self
.
device
)
self
.
stacked_expert_w_gate
=
torch
.
stack
([
expert
.
W_gate_weight
for
expert
in
self
.
experts
],
dim
=
0
)
self
.
stacked_expert_w_up
=
torch
.
stack
([
expert
.
W_up_weight
for
expert
in
self
.
experts
],
dim
=
0
)
self
.
stacked_expert_w_down
=
torch
.
stack
([
expert
.
W_down_weight
for
expert
in
self
.
experts
],
dim
=
0
)
(
config
[
"batch_size"
]
*
config
[
"seq_len"
],
config
[
"d_hidden"
]),
dtype
=
torch
.
float16
,
device
=
self
.
device
)
self
.
stacked_expert_w_gate
=
torch
.
stack
([
expert
.
W_gate_weight
for
expert
in
self
.
experts
],
dim
=
0
)
self
.
stacked_expert_w_up
=
torch
.
stack
([
expert
.
W_up_weight
for
expert
in
self
.
experts
],
dim
=
0
)
self
.
stacked_expert_w_down
=
torch
.
stack
([
expert
.
W_down_weight
for
expert
in
self
.
experts
],
dim
=
0
)
self
.
stacked_expert_tokens
=
torch
.
empty
(
(
config
[
"batch_size"
]
*
config
[
"seq_len"
]
*
config
[
"n_experts_per_token"
],
self
.
config
[
"d_hidden"
]),
(
config
[
"batch_size"
]
*
config
[
"seq_len"
]
*
config
[
"n_experts_per_token"
],
self
.
config
[
"d_hidden"
]),
dtype
=
torch
.
float16
,
device
=
self
.
device
)
device
=
self
.
device
,
)
self
.
stacked_expert_weights
=
torch
.
empty
(
(
config
[
"batch_size"
]
*
config
[
"seq_len"
]
*
config
[
"n_experts_per_token"
]),
dtype
=
torch
.
float16
,
device
=
self
.
device
)
(
config
[
"batch_size"
]
*
config
[
"seq_len"
]
*
config
[
"n_experts_per_token"
]),
dtype
=
torch
.
float16
,
device
=
self
.
device
)
self
.
stacked_expert_tokens_idxs
=
torch
.
empty
(
(
config
[
"batch_size"
]
*
config
[
"seq_len"
]
*
config
[
"n_experts_per_token"
]),
dtype
=
torch
.
int64
,
device
=
self
.
device
)
(
config
[
"batch_size"
]
*
config
[
"seq_len"
]
*
config
[
"n_experts_per_token"
]),
dtype
=
torch
.
int64
,
device
=
self
.
device
)
self
.
up_logits_shared
=
torch
.
empty
(
(
config
[
"batch_size"
]
*
config
[
"seq_len"
],
self
.
config
[
"d_expert"
]),
dtype
=
torch
.
float16
,
device
=
self
.
device
)
(
config
[
"batch_size"
]
*
config
[
"seq_len"
],
self
.
config
[
"d_expert"
]),
dtype
=
torch
.
float16
,
device
=
self
.
device
)
self
.
expert_output_shared
=
torch
.
empty
(
(
config
[
"batch_size"
]
*
config
[
"seq_len"
],
self
.
config
[
"d_hidden"
]),
dtype
=
torch
.
float16
,
device
=
self
.
device
)
(
config
[
"batch_size"
]
*
config
[
"seq_len"
],
self
.
config
[
"d_hidden"
]),
dtype
=
torch
.
float16
,
device
=
self
.
device
)
self
.
up_logits_routed
=
torch
.
empty
(
(
config
[
"batch_size"
]
*
config
[
"seq_len"
]
*
config
[
"n_experts_per_token"
],
self
.
config
[
"d_expert"
]),
(
config
[
"batch_size"
]
*
config
[
"seq_len"
]
*
config
[
"n_experts_per_token"
],
self
.
config
[
"d_expert"
]),
dtype
=
torch
.
float16
,
device
=
self
.
device
)
device
=
self
.
device
,
)
self
.
expert_output_routed
=
torch
.
empty
(
(
config
[
"batch_size"
]
*
config
[
"seq_len"
]
*
config
[
"n_experts_per_token"
],
self
.
config
[
"d_hidden"
]),
(
config
[
"batch_size"
]
*
config
[
"seq_len"
]
*
config
[
"n_experts_per_token"
],
self
.
config
[
"d_hidden"
]),
dtype
=
torch
.
float16
,
device
=
self
.
device
)
device
=
self
.
device
,
)
@
torch
.
no_grad
()
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
...
...
@@ -413,22 +376,20 @@ class MoE(nn.Module):
self
.
stacked_expert_tokens
[
start_idx
:
end_idx
]
=
expert_tokens
self
.
stacked_expert_tokens_idxs
[
start_idx
:
end_idx
]
=
exp_token_idxs
self
.
stacked_expert_weights
[
start_idx
:
end_idx
]
=
flat_expert_weights
[
idxs
[
start_idx
:
end_idx
]]
self
.
stacked_expert_weights
[
start_idx
:
end_idx
]
=
flat_expert_weights
[
idxs
[
start_idx
:
end_idx
]]
group_sizes
=
torch
.
tensor
(
counts
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
group_offset
=
torch
.
tensor
(
tokens_per_expert
-
counts
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
group_offset
=
torch
.
tensor
(
tokens_per_expert
-
counts
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
group_padded_offsets
=
[
0
for
_
in
range
(
len
(
group_sizes
))]
for
i
in
range
(
1
,
len
(
group_sizes
)):
group_padded_offsets
[
i
]
=
group_padded_offsets
[
i
-
1
]
+
math
.
ceil
(
(
counts
[
i
-
1
]
+
1
)
/
self
.
padding_M
)
*
self
.
padding_M
group_padded_offsets
[
i
]
=
group_padded_offsets
[
i
-
1
]
+
math
.
ceil
((
counts
[
i
-
1
]
+
1
)
/
self
.
padding_M
)
*
self
.
padding_M
block_token
=
128
M
=
math
.
ceil
(
self
.
config
[
"batch_size"
]
*
self
.
config
[
"seq_len"
]
*
self
.
config
[
"n_experts_per_token"
]
/
block_token
)
+
self
.
config
[
"n_routed_experts"
]
M
=
(
math
.
ceil
(
self
.
config
[
"batch_size"
]
*
self
.
config
[
"seq_len"
]
*
self
.
config
[
"n_experts_per_token"
]
/
block_token
)
+
self
.
config
[
"n_routed_experts"
]
)
group_idx_for_bx
=
[
0
for
_
in
range
(
M
)]
for
bx
in
range
(
M
):
...
...
@@ -437,8 +398,7 @@ class MoE(nn.Module):
if
m_start_padded
>=
group_padded_offsets
[
i
]:
group_idx_for_bx
[
bx
]
=
i
group_padded_offsets
=
torch
.
tensor
(
group_padded_offsets
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
group_padded_offsets
=
torch
.
tensor
(
group_padded_offsets
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
group_idx_for_bx
=
torch
.
tensor
(
group_idx_for_bx
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
# Multi-stream execution
...
...
@@ -448,11 +408,19 @@ class MoE(nn.Module):
with
torch
.
cuda
.
stream
(
routed_stream
):
# Tilelang version: Grouped GEMM
self
.
routed_kernel
(
self
.
stacked_expert_tokens
,
self
.
stacked_expert_w_gate
,
self
.
stacked_expert_w_up
,
self
.
stacked_expert_w_down
,
self
.
stacked_expert_weights
,
group_sizes
,
group_offset
,
group_padded_offsets
,
group_idx_for_bx
,
self
.
up_logits_routed
,
self
.
expert_output_routed
)
self
.
routed_kernel
(
self
.
stacked_expert_tokens
,
self
.
stacked_expert_w_gate
,
self
.
stacked_expert_w_up
,
self
.
stacked_expert_w_down
,
self
.
stacked_expert_weights
,
group_sizes
,
group_offset
,
group_padded_offsets
,
group_idx_for_bx
,
self
.
up_logits_routed
,
self
.
expert_output_routed
,
)
# Scatter reduce
self
.
expert_cache
=
torch
.
scatter_reduce
(
...
...
@@ -460,14 +428,19 @@ class MoE(nn.Module):
0
,
self
.
stacked_expert_tokens_idxs
.
view
(
-
1
,
1
).
repeat
(
1
,
x_flat
.
shape
[
-
1
]),
self
.
expert_output_routed
,
reduce
=
'sum'
)
reduce
=
"sum"
,
)
routed_output
=
self
.
expert_cache
.
view
(
*
orig_shape
)
with
torch
.
cuda
.
stream
(
shared_stream
):
self
.
shared_kernel
(
x_flat
,
self
.
shared_expert
.
W_gate_weight
,
self
.
shared_expert
.
W_up_weight
,
self
.
shared_expert
.
W_down_weight
,
self
.
up_logits_shared
,
self
.
expert_output_shared
)
self
.
shared_kernel
(
x_flat
,
self
.
shared_expert
.
W_gate_weight
,
self
.
shared_expert
.
W_up_weight
,
self
.
shared_expert
.
W_down_weight
,
self
.
up_logits_shared
,
self
.
expert_output_shared
,
)
shared_output
=
self
.
expert_output_shared
.
view
(
*
orig_shape
)
torch
.
cuda
.
synchronize
()
...
...
@@ -491,14 +464,15 @@ def custom_kernel(data: Tuple[torch.Tensor, Dict, Dict]) -> torch.Tensor:
"""
input_tensor
,
weights
,
config
=
data
dtype_str
=
"
float16
"
dtype_str
=
T
.
float16
shared_kernel
=
moe_forward_tilelang_shared
(
config
[
"d_hidden"
],
config
[
"d_expert"
],
config
[
"n_shared_experts"
],
dtype
=
dtype_str
,
num_tokens
=
config
[
"batch_size"
]
*
config
[
"seq_len"
])
num_tokens
=
config
[
"batch_size"
]
*
config
[
"seq_len"
],
)
routed_kernel
=
moe_forward_tilelang_routed
(
config
[
"d_hidden"
],
config
[
"d_expert"
],
...
...
@@ -512,7 +486,8 @@ def custom_kernel(data: Tuple[torch.Tensor, Dict, Dict]) -> torch.Tensor:
threads
=
256
,
num_stages
=
1
,
k_pack
=
1
,
coalesced_width
=
2
)
coalesced_width
=
2
,
)
moe
=
MoE
(
config
,
shared_kernel
,
routed_kernel
,
weights
,
padding_M
=
128
)
...
...
@@ -521,13 +496,7 @@ def custom_kernel(data: Tuple[torch.Tensor, Dict, Dict]) -> torch.Tensor:
return
output
def
main
(
d_hidden
=
7168
,
d_expert
=
2048
,
n_routed_experts
=
8
,
n_shared_experts
=
1
,
n_experts_per_token
=
4
,
batch_size
=
1
,
seq_len
=
8192
):
def
main
(
d_hidden
=
7168
,
d_expert
=
2048
,
n_routed_experts
=
8
,
n_shared_experts
=
1
,
n_experts_per_token
=
4
,
batch_size
=
1
,
seq_len
=
8192
):
config
=
{
"dhidden"
:
d_hidden
,
"dexpert"
:
d_expert
,
...
...
@@ -536,7 +505,7 @@ def main(d_hidden=7168,
"nexpertspertoken"
:
n_experts_per_token
,
"bs"
:
batch_size
,
"seqlen"
:
seq_len
,
"seed"
:
81394
"seed"
:
81394
,
}
data
=
generate_input
(
**
config
)
...
...
examples/fusedmoe/example_fusedmoe_torch.py
View file @
667632cc
...
...
@@ -6,7 +6,6 @@ from typing import Dict, Tuple, Optional
# Reference code in PyTorch
class
ExpertTorch
(
nn
.
Module
):
def
__init__
(
self
,
config
:
Dict
,
d_expert
:
Optional
[
int
]
=
None
):
super
().
__init__
()
self
.
config
=
config
...
...
@@ -25,7 +24,6 @@ class ExpertTorch(nn.Module):
class
MoEGateTorch
(
nn
.
Module
):
def
__init__
(
self
,
config
:
Dict
):
super
().
__init__
()
self
.
top_k
:
int
=
config
[
"n_experts_per_token"
]
...
...
@@ -43,12 +41,10 @@ class MoEGateTorch(nn.Module):
class
MoETorch
(
nn
.
Module
):
def
__init__
(
self
,
config
:
Dict
):
super
().
__init__
()
self
.
config
=
config
self
.
experts
=
nn
.
ModuleList
(
[
ExpertTorch
(
config
)
for
_
in
range
(
config
[
"n_routed_experts"
])])
self
.
experts
=
nn
.
ModuleList
([
ExpertTorch
(
config
)
for
_
in
range
(
config
[
"n_routed_experts"
])])
self
.
gating_network
=
MoEGateTorch
(
config
)
shared_expert_dim
=
config
[
"d_expert"
]
*
config
[
"n_shared_experts"
]
self
.
shared_expert
=
ExpertTorch
(
config
=
config
,
d_expert
=
shared_expert_dim
)
...
...
@@ -67,8 +63,7 @@ class MoETorch(nn.Module):
return
routed_output
+
shared_output
@
torch
.
no_grad
()
def
moe_infer
(
self
,
x
:
torch
.
Tensor
,
flat_expert_indices
:
torch
.
Tensor
,
flat_expert_weights
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
moe_infer
(
self
,
x
:
torch
.
Tensor
,
flat_expert_indices
:
torch
.
Tensor
,
flat_expert_weights
:
torch
.
Tensor
)
->
torch
.
Tensor
:
expert_cache
=
torch
.
zeros_like
(
x
)
# test_expert_cache = torch.zeros((x.shape[0] * self.config["n_experts_per_token"], self.config["d_hidden"]))
# test_expert_tokens = torch.zeros((x.shape[0] * self.config["n_experts_per_token"], self.config["d_hidden"]))
...
...
@@ -91,8 +86,7 @@ class MoETorch(nn.Module):
expert_out
=
expert
(
expert_tokens
)
expert_out
.
mul_
(
flat_expert_weights
[
idxs
[
start_idx
:
end_idx
]])
expert_cache
.
scatter_reduce_
(
0
,
exp_token_idxs
.
view
(
-
1
,
1
).
repeat
(
1
,
x
.
shape
[
-
1
]),
expert_out
,
reduce
=
'sum'
)
expert_cache
.
scatter_reduce_
(
0
,
exp_token_idxs
.
view
(
-
1
,
1
).
repeat
(
1
,
x
.
shape
[
-
1
]),
expert_out
,
reduce
=
"sum"
)
return
expert_cache
...
...
@@ -116,21 +110,21 @@ def ref_kernel(data: Tuple[torch.Tensor, Dict, Dict]) -> torch.Tensor:
moe
=
MoETorch
(
config
)
# Fill in the given weights of the model
moe
.
gating_network
.
W_g
.
weight
=
nn
.
Parameter
(
weights
[
'
router.weight
'
])
moe
.
gating_network
.
W_g
.
weight
=
nn
.
Parameter
(
weights
[
"
router.weight
"
])
for
i
in
range
(
num_experts
):
gate_proj_weight
=
weights
[
f
'
experts.
{
i
}
.0.weight
'
]
up_proj_weight
=
weights
[
f
'
experts.
{
i
}
.1.weight
'
]
down_proj_weight
=
weights
[
f
'
experts.
{
i
}
.2.weight
'
]
gate_proj_weight
=
weights
[
f
"
experts.
{
i
}
.0.weight
"
]
up_proj_weight
=
weights
[
f
"
experts.
{
i
}
.1.weight
"
]
down_proj_weight
=
weights
[
f
"
experts.
{
i
}
.2.weight
"
]
# Transpose weights to match expected shape for nn.Linear
moe
.
experts
[
i
].
W_gate
.
weight
=
nn
.
Parameter
(
gate_proj_weight
.
t
())
moe
.
experts
[
i
].
W_up
.
weight
=
nn
.
Parameter
(
up_proj_weight
.
t
())
moe
.
experts
[
i
].
W_down
.
weight
=
nn
.
Parameter
(
down_proj_weight
.
t
())
moe
.
shared_expert
.
W_gate
.
weight
=
nn
.
Parameter
(
weights
[
'
shared_experts.0.weight
'
].
t
())
moe
.
shared_expert
.
W_up
.
weight
=
nn
.
Parameter
(
weights
[
'
shared_experts.1.weight
'
].
t
())
moe
.
shared_expert
.
W_down
.
weight
=
nn
.
Parameter
(
weights
[
'
shared_experts.2.weight
'
].
t
())
moe
.
shared_expert
.
W_gate
.
weight
=
nn
.
Parameter
(
weights
[
"
shared_experts.0.weight
"
].
t
())
moe
.
shared_expert
.
W_up
.
weight
=
nn
.
Parameter
(
weights
[
"
shared_experts.1.weight
"
].
t
())
moe
.
shared_expert
.
W_down
.
weight
=
nn
.
Parameter
(
weights
[
"
shared_experts.2.weight
"
].
t
())
output
=
moe
(
input_tensor
)
...
...
@@ -140,10 +134,9 @@ def ref_kernel(data: Tuple[torch.Tensor, Dict, Dict]) -> torch.Tensor:
# Input generation for the reference code
def
generate_input
(
dhidden
:
int
,
dexpert
:
int
,
nroutedexperts
:
int
,
nsharedexperts
:
int
,
nexpertspertoken
:
int
,
bs
:
int
,
seqlen
:
int
,
seed
:
int
)
->
Tuple
[
torch
.
Tensor
,
Dict
,
Dict
]:
def
generate_input
(
dhidden
:
int
,
dexpert
:
int
,
nroutedexperts
:
int
,
nsharedexperts
:
int
,
nexpertspertoken
:
int
,
bs
:
int
,
seqlen
:
int
,
seed
:
int
)
->
Tuple
[
torch
.
Tensor
,
Dict
,
Dict
]:
# Really dumb but for now _ isn't parsing correctly.
d_hidden
=
dhidden
d_expert
=
dexpert
...
...
@@ -163,50 +156,40 @@ def generate_input(dhidden: int, dexpert: int, nroutedexperts: int, nsharedexper
"seq_len"
:
seq_len
,
}
gen
=
torch
.
Generator
(
device
=
'
cuda
'
)
gen
=
torch
.
Generator
(
device
=
"
cuda
"
)
gen
.
manual_seed
(
seed
)
num_experts
=
n_routed_experts
expert_dim
=
d_expert
weights
=
{}
input_tensor
=
torch
.
randn
((
batch_size
,
seq_len
,
d_hidden
),
device
=
'cuda'
,
dtype
=
torch
.
float16
,
generator
=
gen
).
contiguous
()
input_tensor
=
torch
.
randn
((
batch_size
,
seq_len
,
d_hidden
),
device
=
"cuda"
,
dtype
=
torch
.
float16
,
generator
=
gen
).
contiguous
()
# Initialize router weights
weights
[
'router.weight'
]
=
torch
.
randn
(
(
num_experts
,
d_hidden
),
device
=
"cuda"
,
dtype
=
torch
.
float16
,
generator
=
gen
)
/
math
.
sqrt
(
d_hidden
)
weights
[
"router.weight"
]
=
torch
.
randn
((
num_experts
,
d_hidden
),
device
=
"cuda"
,
dtype
=
torch
.
float16
,
generator
=
gen
)
/
math
.
sqrt
(
d_hidden
)
for
i
in
range
(
num_experts
):
weights
[
f
'experts.
{
i
}
.0.weight'
]
=
torch
.
randn
(
(
d_hidden
,
expert_dim
),
device
=
'cuda'
,
dtype
=
torch
.
float16
,
generator
=
gen
)
/
math
.
sqrt
(
expert_dim
)
weights
[
f
'experts.
{
i
}
.1.weight'
]
=
torch
.
randn
(
(
d_hidden
,
expert_dim
),
device
=
'cuda'
,
dtype
=
torch
.
float16
,
generator
=
gen
)
/
math
.
sqrt
(
expert_dim
)
weights
[
f
'experts.
{
i
}
.2.weight'
]
=
torch
.
randn
(
(
expert_dim
,
d_hidden
),
device
=
'cuda'
,
dtype
=
torch
.
float16
,
generator
=
gen
)
/
math
.
sqrt
(
d_hidden
)
weights
[
'shared_experts.0.weight'
]
=
torch
.
randn
(
(
d_hidden
,
expert_dim
*
n_shared_experts
),
device
=
'cuda'
,
dtype
=
torch
.
float16
,
generator
=
gen
)
/
math
.
sqrt
(
expert_dim
*
n_shared_experts
)
weights
[
'shared_experts.1.weight'
]
=
torch
.
randn
(
(
d_hidden
,
expert_dim
*
n_shared_experts
),
device
=
'cuda'
,
dtype
=
torch
.
float16
,
generator
=
gen
)
/
math
.
sqrt
(
expert_dim
*
n_shared_experts
)
weights
[
'shared_experts.2.weight'
]
=
torch
.
randn
((
expert_dim
*
n_shared_experts
,
d_hidden
),
device
=
'cuda'
,
dtype
=
torch
.
float16
,
generator
=
gen
)
/
math
.
sqrt
(
d_hidden
)
weights
[
f
"experts.
{
i
}
.0.weight"
]
=
torch
.
randn
(
(
d_hidden
,
expert_dim
),
device
=
"cuda"
,
dtype
=
torch
.
float16
,
generator
=
gen
)
/
math
.
sqrt
(
expert_dim
)
weights
[
f
"experts.
{
i
}
.1.weight"
]
=
torch
.
randn
(
(
d_hidden
,
expert_dim
),
device
=
"cuda"
,
dtype
=
torch
.
float16
,
generator
=
gen
)
/
math
.
sqrt
(
expert_dim
)
weights
[
f
"experts.
{
i
}
.2.weight"
]
=
torch
.
randn
(
(
expert_dim
,
d_hidden
),
device
=
"cuda"
,
dtype
=
torch
.
float16
,
generator
=
gen
)
/
math
.
sqrt
(
d_hidden
)
weights
[
"shared_experts.0.weight"
]
=
torch
.
randn
(
(
d_hidden
,
expert_dim
*
n_shared_experts
),
device
=
"cuda"
,
dtype
=
torch
.
float16
,
generator
=
gen
)
/
math
.
sqrt
(
expert_dim
*
n_shared_experts
)
weights
[
"shared_experts.1.weight"
]
=
torch
.
randn
(
(
d_hidden
,
expert_dim
*
n_shared_experts
),
device
=
"cuda"
,
dtype
=
torch
.
float16
,
generator
=
gen
)
/
math
.
sqrt
(
expert_dim
*
n_shared_experts
)
weights
[
"shared_experts.2.weight"
]
=
torch
.
randn
(
(
expert_dim
*
n_shared_experts
,
d_hidden
),
device
=
"cuda"
,
dtype
=
torch
.
float16
,
generator
=
gen
)
/
math
.
sqrt
(
d_hidden
)
return
(
input_tensor
,
weights
,
config
)
...
...
examples/fusedmoe/test_example_fusedmoe.py
View file @
667632cc
...
...
@@ -4,13 +4,8 @@ import example_fusedmoe_tilelang
def
test_example_fusedmoe_tilelang
():
example_fusedmoe_tilelang
.
main
(
d_hidden
=
1024
,
d_expert
=
256
,
n_routed_experts
=
8
,
n_shared_experts
=
1
,
n_experts_per_token
=
4
,
batch_size
=
1
,
seq_len
=
1024
)
d_hidden
=
1024
,
d_expert
=
256
,
n_routed_experts
=
8
,
n_shared_experts
=
1
,
n_experts_per_token
=
4
,
batch_size
=
1
,
seq_len
=
1024
)
if
__name__
==
"__main__"
:
...
...
examples/gdn/example_chunk_delta_bwd.py
View file @
667632cc
...
...
@@ -12,6 +12,7 @@ print(tilelang.__file__, flush=True)
# sys.path.insert(0, "/home/tzj/flash-linear-attention")
try
:
import
fla
print
(
fla
.
__file__
,
flush
=
True
)
from
fla.ops.common.chunk_delta_h
import
chunk_gated_delta_rule_bwd_dhu
except
ImportError
:
...
...
@@ -24,7 +25,7 @@ import torch.nn.functional as F
torch
.
random
.
manual_seed
(
0
)
# torch.set_printoptions(profile="full")
from
utils
import
*
from
test_
utils
import
assert_similar
def
prepare_input
(
...
...
@@ -49,6 +50,7 @@ def prepare_input(
G
=
F
.
logsigmoid
(
G
)
try
:
from
fla.ops.utils.cumsum
import
chunk_local_cumsum
G
=
chunk_local_cumsum
(
G
,
chunk_size
)
except
ImportError
:
print
(
"fla not found, skip cumsum"
)
...
...
@@ -125,8 +127,11 @@ def torch_chunk_gated_delta_rule_bwd_dhu(
DV
=
dv
.
shape
[
-
1
]
block_S
=
64
BS
=
S
//
block_S
dh
,
dh0
,
dv2
=
torch
.
empty
((
B
,
BS
,
H
,
DK
,
DV
),
dtype
=
output_dtype
),
torch
.
empty
(
(
B
,
H
,
DK
,
DV
),
dtype
=
state_dtype
),
torch
.
empty
((
B
,
S
,
H
,
DV
),
dtype
=
output_dtype
)
dh
,
dh0
,
dv2
=
(
torch
.
empty
((
B
,
BS
,
H
,
DK
,
DV
),
dtype
=
output_dtype
),
torch
.
empty
((
B
,
H
,
DK
,
DV
),
dtype
=
state_dtype
),
torch
.
empty
((
B
,
S
,
H
,
DV
),
dtype
=
output_dtype
),
)
dh_tmp
=
torch
.
empty
((
B
,
H
,
DK
,
DV
),
dtype
=
accum_dtype
)
dv_tmp
=
torch
.
empty
((
B
,
S
,
H
,
DV
),
dtype
=
accum_dtype
)
Q_tmp
=
torch
.
empty
((
B
,
S
,
H
,
DK
),
dtype
=
accum_dtype
)
...
...
@@ -138,34 +143,30 @@ def torch_chunk_gated_delta_rule_bwd_dhu(
for
i_s
in
range
(
BS
-
1
,
-
1
,
-
1
):
dh
[:,
i_s
,
:,
:,
:]
=
dh_tmp
dv_tmp
=
torch
.
matmul
(
K
[:,
i_s
*
block_S
:(
i_s
+
1
)
*
block_S
,
:,
:].
permute
(
0
,
2
,
1
,
3
),
dh_tmp
.
to
(
K
.
dtype
)).
permute
(
0
,
2
,
1
,
3
)
dv_tmp
=
torch
.
matmul
(
K
[:,
i_s
*
block_S
:
(
i_s
+
1
)
*
block_S
,
:,
:].
permute
(
0
,
2
,
1
,
3
),
dh_tmp
.
to
(
K
.
dtype
)).
permute
(
0
,
2
,
1
,
3
)
if
use_g
:
for
i_bh
in
range
(
B
*
H
):
i_b
,
i_h
=
i_bh
//
H
,
i_bh
%
H
for
i_s2
in
range
(
block_S
):
if
G
[
i_b
,
i_s
*
block_S
+
block_S
-
1
,
i_h
]
-
G
[
i_b
,
i_s
*
block_S
+
i_s2
,
i_h
]
<=
0
:
dv_tmp
[
i_b
,
i_s2
,
i_h
,
:]
*=
torch
.
exp
(
G
[
i_b
,
i_s
*
block_S
+
block_S
-
1
,
i_h
]
-
G
[
i_b
,
i_s
*
block_S
+
i_s2
,
i_h
])
if
G
[
i_b
,
i_s
*
block_S
+
block_S
-
1
,
i_h
]
-
G
[
i_b
,
i_s
*
block_S
+
i_s2
,
i_h
]
<=
0
:
dv_tmp
[
i_b
,
i_s2
,
i_h
,
:]
*=
torch
.
exp
(
G
[
i_b
,
i_s
*
block_S
+
block_S
-
1
,
i_h
]
-
G
[
i_b
,
i_s
*
block_S
+
i_s2
,
i_h
])
else
:
dv_tmp
[
i_b
,
i_s2
,
i_h
,
:]
=
0
dv_tmp
+=
dv
[:,
i_s
*
block_S
:
(
i_s
+
1
)
*
block_S
,
:,
:]
dv2
[:,
i_s
*
block_S
:
(
i_s
+
1
)
*
block_S
,
:,
:]
=
dv_tmp
dv_tmp
+=
dv
[:,
i_s
*
block_S
:
(
i_s
+
1
)
*
block_S
,
:,
:]
dv2
[:,
i_s
*
block_S
:
(
i_s
+
1
)
*
block_S
,
:,
:]
=
dv_tmp
if
use_g
:
G_last
=
G
[:,
i_s
*
block_S
+
block_S
-
1
,
:]
for
i_bh
in
range
(
B
*
H
):
i_b
,
i_h
=
i_bh
//
H
,
i_bh
%
H
dh_tmp
[
i_b
,
i_h
,
:,
:]
*=
torch
.
exp
(
G_last
[
i_b
,
i_h
])
Q_tmp
=
Q
[:,
i_s
*
block_S
:
(
i_s
+
1
)
*
block_S
,
:,
:]
Q_tmp
=
Q
[:,
i_s
*
block_S
:
(
i_s
+
1
)
*
block_S
,
:,
:]
for
i_s2
in
range
(
block_S
):
for
i_k
in
range
(
DK
):
Q_tmp
[:,
i_s2
,
:,
i_k
]
*=
torch
.
exp
(
G
[:,
i_s
*
block_S
+
i_s2
,
:])
Q_tmp
*=
scale
W_tmp
=
W
[:,
i_s
*
block_S
:
(
i_s
+
1
)
*
block_S
,
:,
:]
dO_tmp
=
dO
[:,
i_s
*
block_S
:
(
i_s
+
1
)
*
block_S
,
:,
:]
W_tmp
=
W
[:,
i_s
*
block_S
:
(
i_s
+
1
)
*
block_S
,
:,
:]
dO_tmp
=
dO
[:,
i_s
*
block_S
:
(
i_s
+
1
)
*
block_S
,
:,
:]
torch
.
backends
.
cuda
.
matmul
.
allow_tf32
=
True
dh_tmp
+=
torch
.
matmul
(
Q_tmp
.
permute
(
0
,
2
,
3
,
1
),
dO_tmp
.
permute
(
0
,
2
,
1
,
3
))
...
...
@@ -223,19 +224,19 @@ def tilelang_chunk_gated_delta_rule_bwd_dhu(
@
T
.
prim_func
def
kernel
(
# Input
Q
:
T
.
Tensor
(
Q_shape
,
dtype
=
input_dtype
),
K
:
T
.
Tensor
(
K_shape
,
dtype
=
input_dtype
),
W
:
T
.
Tensor
(
W_shape
,
dtype
=
input_dtype
),
G
:
T
.
Tensor
(
G_shape
,
dtype
=
gate_dtype
),
h0
:
T
.
Tensor
(
h0_shape
,
dtype
=
input_dtype
),
dht
:
T
.
Tensor
(
dht_shape
,
dtype
=
input_dtype
),
dO
:
T
.
Tensor
(
dO_shape
,
dtype
=
input_dtype
),
dv
:
T
.
Tensor
(
dv_shape
,
dtype
=
input_dtype
),
# Output
dh
:
T
.
Tensor
(
dh_shape
,
dtype
=
output_dtype
),
dh0
:
T
.
Tensor
(
dh0_shape
,
dtype
=
state_dtype
),
dv2
:
T
.
Tensor
(
dv2_shape
,
dtype
=
output_dtype
),
# Input
Q
:
T
.
Tensor
(
Q_shape
,
dtype
=
input_dtype
),
K
:
T
.
Tensor
(
K_shape
,
dtype
=
input_dtype
),
W
:
T
.
Tensor
(
W_shape
,
dtype
=
input_dtype
),
G
:
T
.
Tensor
(
G_shape
,
dtype
=
gate_dtype
),
h0
:
T
.
Tensor
(
h0_shape
,
dtype
=
input_dtype
),
dht
:
T
.
Tensor
(
dht_shape
,
dtype
=
input_dtype
),
dO
:
T
.
Tensor
(
dO_shape
,
dtype
=
input_dtype
),
dv
:
T
.
Tensor
(
dv_shape
,
dtype
=
input_dtype
),
# Output
dh
:
T
.
Tensor
(
dh_shape
,
dtype
=
output_dtype
),
dh0
:
T
.
Tensor
(
dh0_shape
,
dtype
=
state_dtype
),
dv2
:
T
.
Tensor
(
dv2_shape
,
dtype
=
output_dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
DV
,
block_DV
),
B
*
H
,
threads
=
threads
)
as
(
bv
,
bbh
):
bb
,
bh
=
bbh
//
H
,
bbh
%
H
...
...
@@ -249,13 +250,13 @@ def tilelang_chunk_gated_delta_rule_bwd_dhu(
dv_fragment
=
T
.
alloc_fragment
((
block_S
,
block_DV
),
dtype
=
accum_dtype
)
dv_fragment_2
=
T
.
alloc_fragment
((
block_S
,
block_DV
),
dtype
=
accum_dtype
)
dO_shared
=
T
.
alloc_shared
((
block_S
,
block_DV
),
dtype
=
input_dtype
)
dO_shared_t
=
T
.
alloc_shared
((
block_DV
,
block_S
),
dtype
=
"
float32
"
)
dO_fragment
=
T
.
alloc_fragment
((
block_S
,
block_DV
),
dtype
=
"
float32
"
)
dO_fragment_t
=
T
.
alloc_fragment
((
block_DV
,
block_S
),
dtype
=
"
float32
"
)
dO_shared_t
=
T
.
alloc_shared
((
block_DV
,
block_S
),
dtype
=
T
.
float32
)
dO_fragment
=
T
.
alloc_fragment
((
block_S
,
block_DV
),
dtype
=
T
.
float32
)
dO_fragment_t
=
T
.
alloc_fragment
((
block_DV
,
block_S
),
dtype
=
T
.
float32
)
K_shared
=
T
.
alloc_shared
((
block_S
,
DK
),
dtype
=
input_dtype
)
Q_shared
=
T
.
alloc_shared
((
block_S
,
DK
),
dtype
=
input_dtype
)
Q_shared_fp32
=
T
.
alloc_shared
((
block_S
,
DK
),
dtype
=
"
float32
"
)
Q_shared_fp32
=
T
.
alloc_shared
((
block_S
,
DK
),
dtype
=
T
.
float32
)
W_shared
=
T
.
alloc_shared
((
block_S
,
DK
),
dtype
=
input_dtype
)
G_last_local
=
T
.
alloc_local
((
1
),
dtype
=
gate_dtype
)
...
...
@@ -269,20 +270,22 @@ def tilelang_chunk_gated_delta_rule_bwd_dhu(
T
.
use_swizzle
(
10
)
T
.
annotate_layout
({
b_dh_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
b_dh_shared
),
b_dh_shared_fp32
:
tilelang
.
layout
.
make_swizzled_layout
(
b_dh_shared_fp32
),
dv_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
dv_shared
),
dO_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
dO_shared
),
dO_shared_t
:
tilelang
.
layout
.
make_swizzled_layout
(
dO_shared_t
),
K_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
K_shared
),
Q_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
Q_shared
),
Q_shared_fp32
:
tilelang
.
layout
.
make_swizzled_layout
(
Q_shared_fp32
),
W_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
W_shared
),
})
T
.
annotate_layout
(
{
b_dh_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
b_dh_shared
),
b_dh_shared_fp32
:
tilelang
.
layout
.
make_swizzled_layout
(
b_dh_shared_fp32
),
dv_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
dv_shared
),
dO_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
dO_shared
),
dO_shared_t
:
tilelang
.
layout
.
make_swizzled_layout
(
dO_shared_t
),
K_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
K_shared
),
Q_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
Q_shared
),
Q_shared_fp32
:
tilelang
.
layout
.
make_swizzled_layout
(
Q_shared_fp32
),
W_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
W_shared
),
}
)
if
use_final_state_gradient
:
T
.
copy
(
dht
[
bb
,
bh
,
0
:
DK
,
bv
*
block_DV
:
(
bv
+
1
)
*
block_DV
],
b_dh_shared
)
T
.
copy
(
dht
[
bb
,
bh
,
0
:
DK
,
bv
*
block_DV
:
(
bv
+
1
)
*
block_DV
],
b_dh_shared
)
T
.
copy
(
b_dh_shared
,
b_dh_fragment
)
else
:
T
.
clear
(
b_dh_fragment
)
...
...
@@ -293,17 +296,14 @@ def tilelang_chunk_gated_delta_rule_bwd_dhu(
# Store the updated dh
T
.
copy
(
b_dh_fragment
,
b_dh_shared
)
T
.
copy
(
b_dh_shared
,
dh
[
bb
,
i_s_inv
,
bh
,
0
:
DK
,
bv
*
block_DV
:
(
bv
+
1
)
*
block_DV
])
T
.
copy
(
b_dh_shared
,
dh
[
bb
,
i_s_inv
,
bh
,
0
:
DK
,
bv
*
block_DV
:
(
bv
+
1
)
*
block_DV
])
# Update dv
T
.
copy
(
K
[
bb
,
i_s_inv
*
block_S
:
(
i_s_inv
+
1
)
*
block_S
,
bh
,
0
:
DK
],
K_shared
)
T
.
copy
(
K
[
bb
,
i_s_inv
*
block_S
:
(
i_s_inv
+
1
)
*
block_S
,
bh
,
0
:
DK
],
K_shared
)
T
.
gemm
(
K_shared
,
b_dh_shared
,
dv_fragment
,
clear_accum
=
True
)
if
use_g
:
T
.
copy
(
G
[
bb
,
i_s_inv
*
block_S
:(
i_s_inv
+
1
)
*
block_S
,
bh
],
G_shared
,
disable_tma
=
True
)
T
.
copy
(
G
[
bb
,
i_s_inv
*
block_S
:
(
i_s_inv
+
1
)
*
block_S
,
bh
],
G_shared
,
disable_tma
=
True
)
T
.
copy
(
G_shared
,
G_fragment
)
G_last_local
[
0
]
=
G_shared
[
block_S
-
1
]
G_last_local_exp
[
0
]
=
T
.
exp
(
G_last_local
[
0
])
...
...
@@ -313,27 +313,22 @@ def tilelang_chunk_gated_delta_rule_bwd_dhu(
# with T.If(G_last_local[0] - G_shared[i_s2] <= 0):
with
T
.
If
(
G_last_local
[
0
]
-
G_fragment
[
i_s2
]
<=
0
):
with
T
.
Then
():
dv_fragment
[
i_s2
,
i_v
]
=
dv_fragment
[
i_s2
,
i_v
]
*
G_fragment_post
[
i_s2
]
dv_fragment
[
i_s2
,
i_v
]
=
dv_fragment
[
i_s2
,
i_v
]
*
G_fragment_post
[
i_s2
]
with
T
.
Else
():
dv_fragment
[
i_s2
,
i_v
]
=
0
T
.
copy
(
dv
[
bb
,
i_s_inv
*
block_S
:(
i_s_inv
+
1
)
*
block_S
,
bh
,
bv
*
block_DV
:(
bv
+
1
)
*
block_DV
],
dv_shared
)
T
.
copy
(
dv
[
bb
,
i_s_inv
*
block_S
:
(
i_s_inv
+
1
)
*
block_S
,
bh
,
bv
*
block_DV
:
(
bv
+
1
)
*
block_DV
],
dv_shared
)
T
.
copy
(
dv_shared
,
dv_fragment_2
)
for
i_s2
,
i_v
in
T
.
Parallel
(
block_S
,
block_DV
):
dv_fragment
[
i_s2
,
i_v
]
=
dv_fragment
[
i_s2
,
i_v
]
+
dv_fragment_2
[
i_s2
,
i_v
]
# Store the updated dv
T
.
copy
(
dv_fragment
,
dv_shared
)
T
.
copy
(
dv_shared
,
dv2
[
bb
,
i_s_inv
*
block_S
:(
i_s_inv
+
1
)
*
block_S
,
bh
,
bv
*
block_DV
:(
bv
+
1
)
*
block_DV
])
T
.
copy
(
dv_shared
,
dv2
[
bb
,
i_s_inv
*
block_S
:
(
i_s_inv
+
1
)
*
block_S
,
bh
,
bv
*
block_DV
:
(
bv
+
1
)
*
block_DV
])
# Update dh
T
.
copy
(
Q
[
bb
,
i_s_inv
*
block_S
:
(
i_s_inv
+
1
)
*
block_S
,
bh
,
0
:
DK
],
Q_shared
)
T
.
copy
(
W
[
bb
,
i_s_inv
*
block_S
:
(
i_s_inv
+
1
)
*
block_S
,
bh
,
0
:
DK
],
W_shared
)
T
.
copy
(
Q
[
bb
,
i_s_inv
*
block_S
:
(
i_s_inv
+
1
)
*
block_S
,
bh
,
0
:
DK
],
Q_shared
)
T
.
copy
(
W
[
bb
,
i_s_inv
*
block_S
:
(
i_s_inv
+
1
)
*
block_S
,
bh
,
0
:
DK
],
W_shared
)
T
.
clear
(
Q_fragment
)
if
use_g
:
...
...
@@ -353,9 +348,7 @@ def tilelang_chunk_gated_delta_rule_bwd_dhu(
for
i_s2
,
i_k
in
T
.
Parallel
(
block_S
,
DK
):
Q_fragment_t
[
i_k
,
i_s2
]
=
Q_fragment
[
i_s2
,
i_k
]
T
.
copy
(
dO
[
bb
,
i_s_inv
*
block_S
:(
i_s_inv
+
1
)
*
block_S
,
bh
,
bv
*
block_DV
:(
bv
+
1
)
*
block_DV
],
dO_shared
)
T
.
copy
(
dO
[
bb
,
i_s_inv
*
block_S
:
(
i_s_inv
+
1
)
*
block_S
,
bh
,
bv
*
block_DV
:
(
bv
+
1
)
*
block_DV
],
dO_shared
)
T
.
copy
(
dO_shared
,
dO_fragment
)
for
i_s2
,
i_v
in
T
.
Parallel
(
block_S
,
block_DV
):
dO_fragment_t
[
i_v
,
i_s2
]
=
dO_fragment
[
i_s2
,
i_v
]
...
...
@@ -369,7 +362,7 @@ def tilelang_chunk_gated_delta_rule_bwd_dhu(
b_dh_fragment
[
i_k
,
i_v
]
+=
b_dh_fragment_1
[
i_k
,
i_v
]
-
b_dh_fragment_2
[
i_k
,
i_v
]
if
use_initial_state
:
T
.
copy
(
b_dh_fragment
,
dh0
[
bb
,
bh
,
0
:
DK
,
bv
*
block_DV
:
(
bv
+
1
)
*
block_DV
])
T
.
copy
(
b_dh_fragment
,
dh0
[
bb
,
bh
,
0
:
DK
,
bv
*
block_DV
:
(
bv
+
1
)
*
block_DV
])
return
kernel
...
...
@@ -444,44 +437,61 @@ def run_test(
num_stages
=
0
,
use_torch
=
False
,
):
Q
,
K
,
W
,
G
,
h0
,
dht
,
dO
,
dv
=
prepare_input
(
B
,
S
,
H
,
DK
,
DV
,
chunk_size
,
getattr
(
torch
,
input_dtype
),
getattr
(
torch
,
output_dtype
),
getattr
(
torch
,
accum_dtype
),
getattr
(
torch
,
gate_dtype
),
getattr
(
torch
,
state_dtype
))
dh_ref
,
dh0_ref
,
dv2_ref
=
prepare_output
(
B
,
S
,
H
,
DK
,
DV
,
chunk_size
,
getattr
(
torch
,
output_dtype
),
getattr
(
torch
,
gate_dtype
),
getattr
(
torch
,
state_dtype
))
dh_tilelang
,
dh0_tilelang
,
dv2_tilelang
=
prepare_output
(
B
,
S
,
H
,
DK
,
DV
,
chunk_size
,
getattr
(
torch
,
output_dtype
),
getattr
(
torch
,
gate_dtype
),
getattr
(
torch
,
state_dtype
))
Q
,
K
,
W
,
G
,
h0
,
dht
,
dO
,
dv
=
prepare_input
(
B
,
S
,
H
,
DK
,
DV
,
chunk_size
,
getattr
(
torch
,
input_dtype
),
getattr
(
torch
,
output_dtype
),
getattr
(
torch
,
accum_dtype
),
getattr
(
torch
,
gate_dtype
),
getattr
(
torch
,
state_dtype
),
)
dh_ref
,
dh0_ref
,
dv2_ref
=
prepare_output
(
B
,
S
,
H
,
DK
,
DV
,
chunk_size
,
getattr
(
torch
,
output_dtype
),
getattr
(
torch
,
gate_dtype
),
getattr
(
torch
,
state_dtype
)
)
dh_tilelang
,
dh0_tilelang
,
dv2_tilelang
=
prepare_output
(
B
,
S
,
H
,
DK
,
DV
,
chunk_size
,
getattr
(
torch
,
output_dtype
),
getattr
(
torch
,
gate_dtype
),
getattr
(
torch
,
state_dtype
)
)
# fla ref
print
(
"fla running..."
,
flush
=
True
)
if
use_g
:
dh_ref
,
dh0_ref
,
dv2_ref
=
chunk_gated_delta_rule_bwd_dhu
(
Q
,
K
,
W
,
G
,
h0
,
dht
,
dO
,
dv
,
scale
)
dh_ref
,
dh0_ref
,
dv2_ref
=
chunk_gated_delta_rule_bwd_dhu
(
Q
,
K
,
W
,
G
,
h0
,
dht
,
dO
,
dv
,
scale
)
else
:
G
=
G
.
fill_
(
0
)
dh_ref
,
dh0_ref
,
dv2_ref
=
chunk_gated_delta_rule_bwd_dhu
(
Q
,
K
,
W
,
G
,
h0
,
dht
,
dO
,
dv
,
scale
)
dh_ref
,
dh0_ref
,
dv2_ref
=
chunk_gated_delta_rule_bwd_dhu
(
Q
,
K
,
W
,
G
,
h0
,
dht
,
dO
,
dv
,
scale
)
# tilelang
print
(
"tilelang running..."
,
flush
=
True
)
kernel
=
tilelang_chunk_gated_delta_rule_bwd_dhu
(
B
,
S
,
H
,
DK
,
DV
,
input_dtype
,
output_dtype
,
accum_dtype
,
gate_dtype
,
state_dtype
,
chunk_size
,
scale
,
use_g
,
use_initial_state
,
use_final_state_gradient
,
block_DV
,
threads
,
num_stages
)
kernel
=
tilelang_chunk_gated_delta_rule_bwd_dhu
(
B
,
S
,
H
,
DK
,
DV
,
input_dtype
,
output_dtype
,
accum_dtype
,
gate_dtype
,
state_dtype
,
chunk_size
,
scale
,
use_g
,
use_initial_state
,
use_final_state_gradient
,
block_DV
,
threads
,
num_stages
,
)
# kernel = tilelang.compile(program)
print
(
kernel
.
get_kernel_source
())
dh_tilelang
,
dh0_tilelang
,
dv2_tilelang
=
kernel
(
Q
,
K
,
W
,
G
,
h0
,
dht
,
dO
,
dv
)
fla_time
=
do_bench
(
chunk_gated_delta_rule_bwd_dhu
,
Q
,
K
,
W
,
G
,
h0
,
dht
,
dO
,
dv
,
scale
,
chunk_size
=
chunk_size
)
fla_time
=
do_bench
(
chunk_gated_delta_rule_bwd_dhu
,
Q
,
K
,
W
,
G
,
h0
,
dht
,
dO
,
dv
,
scale
,
chunk_size
=
chunk_size
)
tilelang_time
=
do_bench
(
kernel
,
Q
,
K
,
W
,
G
,
h0
,
dht
,
dO
,
dv
)
print
(
f
"fla time:
{
fla_time
}
ms"
)
...
...
@@ -496,19 +506,47 @@ def run_test(
print
(
"torch running..."
,
flush
=
True
)
if
use_g
:
dh_ref_torch
,
dh0_ref_torch
,
dv2_ref_torch
=
torch_chunk_gated_delta_rule_bwd_dhu
(
Q
,
K
,
W
,
G
,
h0
,
dht
,
dO
,
dv
,
scale
,
use_g
,
use_initial_state
,
use_final_state_gradient
,
getattr
(
torch
,
input_dtype
),
getattr
(
torch
,
output_dtype
),
getattr
(
torch
,
accum_dtype
),
getattr
(
torch
,
gate_dtype
),
getattr
(
torch
,
state_dtype
))
Q
,
K
,
W
,
G
,
h0
,
dht
,
dO
,
dv
,
scale
,
use_g
,
use_initial_state
,
use_final_state_gradient
,
getattr
(
torch
,
input_dtype
),
getattr
(
torch
,
output_dtype
),
getattr
(
torch
,
accum_dtype
),
getattr
(
torch
,
gate_dtype
),
getattr
(
torch
,
state_dtype
),
)
dh_ref_torch
=
dh_ref_torch
.
cuda
()
dh0_ref_torch
=
dh0_ref_torch
.
cuda
()
dv2_ref_torch
=
dv2_ref_torch
.
cuda
()
else
:
dh_ref_torch
,
dh0_ref_torch
,
dv2_ref_torch
=
torch_chunk_gated_delta_rule_bwd_dhu
(
Q
,
K
,
W
,
None
,
h0
,
dht
,
dO
,
dv
,
scale
,
use_g
,
use_initial_state
,
use_final_state_gradient
,
getattr
(
torch
,
input_dtype
),
getattr
(
torch
,
output_dtype
),
getattr
(
torch
,
accum_dtype
),
getattr
(
torch
,
gate_dtype
),
getattr
(
torch
,
state_dtype
))
Q
,
K
,
W
,
None
,
h0
,
dht
,
dO
,
dv
,
scale
,
use_g
,
use_initial_state
,
use_final_state_gradient
,
getattr
(
torch
,
input_dtype
),
getattr
(
torch
,
output_dtype
),
getattr
(
torch
,
accum_dtype
),
getattr
(
torch
,
gate_dtype
),
getattr
(
torch
,
state_dtype
),
)
dh_ref_torch
=
dh_ref_torch
.
cuda
()
dh0_ref_torch
=
dh0_ref_torch
.
cuda
()
dv2_ref_torch
=
dv2_ref_torch
.
cuda
()
...
...
@@ -554,11 +592,11 @@ def main():
H
=
8
,
DK
=
DK
,
DV
=
128
,
input_dtype
=
"
bfloat16
"
,
output_dtype
=
"
bfloat16
"
,
accum_dtype
=
"
float32
"
,
gate_dtype
=
"
float32
"
,
state_dtype
=
"
float32
"
,
input_dtype
=
T
.
bfloat16
,
output_dtype
=
T
.
bfloat16
,
accum_dtype
=
T
.
float32
,
gate_dtype
=
T
.
float32
,
state_dtype
=
T
.
float32
,
chunk_size
=
64
,
scale
=
DK
**-
0.5
,
use_g
=
True
,
...
...
examples/gdn/example_chunk_delta_h.py
View file @
667632cc
...
...
@@ -3,12 +3,14 @@
import
sys
# noqa: F401
import
tilelang
import
tilelang.language
as
T
from
tilelang.autotuner
import
autotune
# Add your fla repository path to sys.path
# Currently we use the fla repository from the flash-linear-attention project at commit id f03cb3ae
# sys.path.insert(0, "/home/tzj/flash-linear-attention")
try
:
import
fla
print
(
fla
.
__file__
)
from
fla.ops.common.chunk_delta_h
import
chunk_gated_delta_rule_fwd_h
except
ImportError
:
...
...
@@ -19,7 +21,7 @@ import torch
import
torch.nn.functional
as
F
from
tilelang.engine.callback
import
register_cuda_postproc_callback
# noqa: F401
from
utils
import
*
from
test_
utils
import
assert_similar
# (zhengju) We can slightly modify the generated cuda code from tilelang lowering
# in the debug folder to make the performance better. To enable this callback,
...
...
@@ -55,6 +57,7 @@ def prepare_input(
G
=
F
.
logsigmoid
(
G
)
try
:
from
fla.ops.utils.cumsum
import
chunk_local_cumsum
G
=
chunk_local_cumsum
(
G
,
chunk_size
)
except
ImportError
:
print
(
"fla not found, skip cumsum"
)
...
...
@@ -80,7 +83,21 @@ def prepare_output(
return
h
,
final_state
,
V_new
@
tilelang
.
jit
(
out_idx
=
[
-
3
,
-
2
,
-
1
])
def
get_configs
():
import
itertools
block_DK
=
[
32
,
64
,
128
]
block_DV
=
[
32
,
64
,
128
]
threads
=
[
128
,
256
]
num_stages
=
[
1
,
2
,
3
]
_configs
=
list
(
itertools
.
product
(
block_DK
,
block_DV
,
threads
,
num_stages
))
configs
=
[{
"block_DK"
:
c
[
0
],
"block_DV"
:
c
[
1
],
"threads"
:
c
[
2
],
"num_stages"
:
c
[
3
]}
for
c
in
_configs
]
return
configs
@
autotune
(
configs
=
get_configs
(),
warmup
=
3
,
rep
=
5
)
@
tilelang
.
jit
(
out_idx
=
[
-
3
,
-
2
,
-
1
],
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
True
})
def
tilelang_chunk_gated_delta_rule_fwd_h
(
# task config
B
,
...
...
@@ -94,15 +111,15 @@ def tilelang_chunk_gated_delta_rule_fwd_h(
gate_dtype
,
state_dtype
,
chunk_size
,
use_g
=
True
,
use_initial_state
=
True
,
store_final_state
=
True
,
save_new_value
=
True
,
use_g
,
use_initial_state
,
store_final_state
,
save_new_value
,
# kernel config
block_DK
=
64
,
block_DV
=
64
,
threads
=
256
,
num_stages
=
0
,
block_DV
=
32
,
threads
=
128
,
num_stages
=
1
,
):
block_S
=
chunk_size
BS
=
S
//
block_S
...
...
@@ -118,14 +135,14 @@ def tilelang_chunk_gated_delta_rule_fwd_h(
@
T
.
prim_func
def
kernel
(
K
:
T
.
Tensor
(
K_shape
,
dtype
=
input_dtype
),
W
:
T
.
Tensor
(
W_shape
,
dtype
=
input_dtype
),
U
:
T
.
Tensor
(
U_shape
,
dtype
=
input_dtype
),
G
:
T
.
Tensor
(
G_shape
,
dtype
=
gate_dtype
),
initial_state
:
T
.
Tensor
(
initial_state_shape
,
dtype
=
input_dtype
),
h
:
T
.
Tensor
(
h_shape
,
dtype
=
output_dtype
),
final_state
:
T
.
Tensor
(
final_state_shape
,
dtype
=
state_dtype
),
V_new
:
T
.
Tensor
(
V_shape
,
dtype
=
output_dtype
),
K
:
T
.
Tensor
(
K_shape
,
dtype
=
input_dtype
),
W
:
T
.
Tensor
(
W_shape
,
dtype
=
input_dtype
),
U
:
T
.
Tensor
(
U_shape
,
dtype
=
input_dtype
),
G
:
T
.
Tensor
(
G_shape
,
dtype
=
gate_dtype
),
initial_state
:
T
.
Tensor
(
initial_state_shape
,
dtype
=
input_dtype
),
h
:
T
.
Tensor
(
h_shape
,
dtype
=
output_dtype
),
final_state
:
T
.
Tensor
(
final_state_shape
,
dtype
=
state_dtype
),
V_new
:
T
.
Tensor
(
V_shape
,
dtype
=
output_dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
DV
,
block_DV
),
B
*
H
,
threads
=
threads
)
as
(
bv
,
bbh
):
bb
,
bh
=
bbh
//
H
,
bbh
%
H
...
...
@@ -143,35 +160,35 @@ def tilelang_chunk_gated_delta_rule_fwd_h(
G_shared
=
T
.
alloc_shared
((
block_S
,
block_DV
),
dtype
=
gate_dtype
)
G_fragment
=
T
.
alloc_fragment
((
block_S
,
block_DV
),
dtype
=
gate_dtype
)
T
.
annotate_layout
({
b_h_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
b_h_shared
),
U_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
U_shared
),
W_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
W_shared
),
V_new_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
V_new_shared
),
K_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
K_shared
),
G_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
G_shared
),
})
T
.
annotate_layout
(
{
b_h_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
b_h_shared
),
U_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
U_shared
),
W_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
W_shared
),
V_new_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
V_new_shared
),
K_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
K_shared
),
G_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
G_shared
),
}
)
T
.
use_swizzle
(
10
)
if
use_initial_state
:
T
.
copy
(
initial_state
[
bb
,
bh
,
0
:
DK
,
bv
*
block_DV
:
(
bv
+
1
)
*
block_DV
],
b_h_shared
)
T
.
copy
(
initial_state
[
bb
,
bh
,
0
:
DK
,
bv
*
block_DV
:
(
bv
+
1
)
*
block_DV
],
b_h_shared
)
T
.
copy
(
b_h_shared
,
b_h_fragment
)
else
:
T
.
clear
(
b_h_fragment
)
for
i_s
in
T
.
Pipelined
(
T
.
ceildiv
(
S
,
block_S
),
num_stages
=
num_stages
):
# Store previous result to the hidden tensor, like the epilogue
T
.
copy
(
b_h_shared
,
h
[
bb
,
i_s
,
bh
,
0
:
DK
,
bv
*
block_DV
:
(
bv
+
1
)
*
block_DV
])
T
.
copy
(
b_h_shared
,
h
[
bb
,
i_s
,
bh
,
0
:
DK
,
bv
*
block_DV
:
(
bv
+
1
)
*
block_DV
])
# Recurrence
T
.
copy
(
W
[
bb
,
i_s
*
block_S
:
(
i_s
+
1
)
*
block_S
,
bh
,
0
:
DK
],
W_shared
)
T
.
copy
(
W
[
bb
,
i_s
*
block_S
:
(
i_s
+
1
)
*
block_S
,
bh
,
0
:
DK
],
W_shared
)
T
.
gemm
(
W_shared
,
b_h_shared
,
V_new_fragment
,
clear_accum
=
True
)
# U - W * S
T
.
copy
(
U
[
bb
,
i_s
*
block_S
:(
i_s
+
1
)
*
block_S
,
bh
,
bv
*
block_DV
:(
bv
+
1
)
*
block_DV
],
U_shared
)
T
.
copy
(
U
[
bb
,
i_s
*
block_S
:
(
i_s
+
1
)
*
block_S
,
bh
,
bv
*
block_DV
:
(
bv
+
1
)
*
block_DV
],
U_shared
)
T
.
copy
(
U_shared
,
U_fragment
)
for
i_s2
,
i_v
in
T
.
Parallel
(
block_S
,
block_DV
):
V_new_fragment
[
i_s2
,
i_v
]
=
-
V_new_fragment
[
i_s2
,
i_v
]
+
U_fragment
[
i_s2
,
i_v
]
...
...
@@ -179,11 +196,9 @@ def tilelang_chunk_gated_delta_rule_fwd_h(
# Save V_new
if
save_new_value
:
T
.
copy
(
V_new_fragment
,
dst
=
V_new_shared
)
T
.
copy
(
V_new_shared
,
V_new
[
bb
,
i_s
*
block_S
:(
i_s
+
1
)
*
block_S
,
bh
,
bv
*
block_DV
:(
bv
+
1
)
*
block_DV
])
T
.
copy
(
V_new_shared
,
V_new
[
bb
,
i_s
*
block_S
:
(
i_s
+
1
)
*
block_S
,
bh
,
bv
*
block_DV
:
(
bv
+
1
)
*
block_DV
])
T
.
copy
(
K
[
bb
,
i_s
*
block_S
:
(
i_s
+
1
)
*
block_S
,
bh
,
0
:
DK
],
K_shared
)
T
.
copy
(
K
[
bb
,
i_s
*
block_S
:
(
i_s
+
1
)
*
block_S
,
bh
,
0
:
DK
],
K_shared
)
# use_g
if
use_g
:
G_last_local
[
0
]
=
G
[
bb
,
(
i_s
+
1
)
*
block_S
-
1
,
bh
]
...
...
@@ -193,11 +208,12 @@ def tilelang_chunk_gated_delta_rule_fwd_h(
for
i_s2
,
i_v
in
T
.
Parallel
(
block_S
,
block_DV
):
with
T
.
If
(
G_last_local
[
0
]
-
G_fragment
[
i_s2
,
i_v
]
<=
0
):
with
T
.
Then
():
V_new_fragment
[
i_s2
,
i_v
]
=
V_new_fragment
[
i_s2
,
i_v
]
*
T
.
exp
(
G_last_local
[
0
]
-
G_fragment
[
i_s2
,
i_v
])
V_new_fragment
[
i_s2
,
i_v
]
=
V_new_fragment
[
i_s2
,
i_v
]
*
T
.
exp2
(
(
G_last_local
[
0
]
-
G_fragment
[
i_s2
,
i_v
])
*
1.442695
)
with
T
.
Else
():
V_new_fragment
[
i_s2
,
i_v
]
=
0
G_last_local
[
0
]
=
T
.
exp
(
G_last_local
[
0
])
G_last_local
[
0
]
=
T
.
exp
2
(
G_last_local
[
0
]
*
1.442695
)
for
i_k
,
i_v
in
T
.
Parallel
(
DK
,
block_DV
):
b_h_fragment
[
i_k
,
i_v
]
*=
G_last_local
[
0
]
...
...
@@ -209,7 +225,7 @@ def tilelang_chunk_gated_delta_rule_fwd_h(
# Save final state
if
store_final_state
:
T
.
copy
(
b_h_fragment
,
final_state
[
bb
,
bh
,
0
:
DK
,
bv
*
block_DV
:
(
bv
+
1
)
*
block_DV
])
T
.
copy
(
b_h_fragment
,
final_state
[
bb
,
bh
,
0
:
DK
,
bv
*
block_DV
:
(
bv
+
1
)
*
block_DV
])
return
kernel
...
...
@@ -260,47 +276,77 @@ def run_test(
threads
=
128
,
num_stages
=
0
,
):
K
,
W
,
U
,
G
,
initial_state
=
prepare_input
(
B
,
S
,
H
,
DK
,
DV
,
chunk_size
,
getattr
(
torch
,
input_dtype
),
getattr
(
torch
,
output_dtype
),
getattr
(
torch
,
accum_dtype
),
getattr
(
torch
,
gate_dtype
))
h_ref
,
final_state_ref
,
V_new_ref
=
prepare_output
(
B
,
S
,
H
,
DK
,
DV
,
chunk_size
,
getattr
(
torch
,
output_dtype
),
getattr
(
torch
,
state_dtype
))
h_tilelang
,
final_state_tilelang
,
V_new_tilelang
=
prepare_output
(
B
,
S
,
H
,
DK
,
DV
,
chunk_size
,
getattr
(
torch
,
output_dtype
),
getattr
(
torch
,
state_dtype
))
K
,
W
,
U
,
G
,
initial_state
=
prepare_input
(
B
,
S
,
H
,
DK
,
DV
,
chunk_size
,
getattr
(
torch
,
input_dtype
),
getattr
(
torch
,
output_dtype
),
getattr
(
torch
,
accum_dtype
),
getattr
(
torch
,
gate_dtype
),
)
h_ref
,
final_state_ref
,
V_new_ref
=
prepare_output
(
B
,
S
,
H
,
DK
,
DV
,
chunk_size
,
getattr
(
torch
,
output_dtype
),
getattr
(
torch
,
state_dtype
)
)
h_tilelang
,
final_state_tilelang
,
V_new_tilelang
=
prepare_output
(
B
,
S
,
H
,
DK
,
DV
,
chunk_size
,
getattr
(
torch
,
output_dtype
),
getattr
(
torch
,
state_dtype
)
)
# fla ref
h_ref
,
V_new_ref
,
final_state_ref
=
chunk_gated_delta_rule_fwd_h
(
K
,
W
,
U
,
G
,
initial_state
,
store_final_state
,
chunk_size
,
save_new_value
)
h_ref
,
V_new_ref
,
final_state_ref
=
chunk_gated_delta_rule_fwd_h
(
k
=
K
,
w
=
W
,
u
=
U
,
g
=
G
,
initial_state
=
initial_state
,
output_final_state
=
store_final_state
,
chunk_size
=
chunk_size
,
save_new_value
=
save_new_value
,
)
# tilelang
kernel
=
tilelang_chunk_gated_delta_rule_fwd_h
(
B
,
S
,
H
,
DK
,
DV
,
input_dtype
,
output_dtype
,
accum_dtype
,
gate_dtype
,
state_dtype
,
chunk_size
,
use_g
,
use_initial_state
,
store_final_state
,
save_new_value
,
block_DK
,
block_DV
,
threads
,
num_stages
)
kernel
=
tilelang_chunk_gated_delta_rule_fwd_h
(
B
,
S
,
H
,
DK
,
DV
,
input_dtype
,
output_dtype
,
accum_dtype
,
gate_dtype
,
state_dtype
,
chunk_size
,
use_g
,
use_initial_state
,
store_final_state
,
save_new_value
,
)
h_tilelang
,
final_state_tilelang
,
V_new_tilelang
=
kernel
(
K
,
W
,
U
,
G
,
initial_state
)
# (zhengju) If you want to print the generated cuda code, you can uncomment the following line
# print("CUDA Code:\n", kernel.get_kernel_source())
fla_time
=
do_bench
(
chunk_gated_delta_rule_fwd_h
,
K
,
W
,
U
,
G
,
initial_state
,
store_final_state
,
chunk_size
,
save_new_value
)
fla_time
=
do_bench
(
chunk_gated_delta_rule_fwd_h
,
k
=
K
,
w
=
W
,
u
=
U
,
g
=
G
,
initial_state
=
initial_state
,
output_final_state
=
store_final_state
,
chunk_size
=
chunk_size
,
save_new_value
=
save_new_value
,
)
tilelang_time
=
do_bench
(
kernel
,
K
,
W
,
U
,
G
,
initial_state
)
# check correctness
try
:
h_ref_fp32
=
h_ref
.
to
(
torch
.
float32
)
h_tilelang_fp32
=
h_tilelang
.
to
(
torch
.
float32
)
assert_similar
(
h_ref_fp32
,
h_tilelang_fp32
,
eps
=
1e-5
,
name
=
"tilelang chunk gated delta rule fwd h"
,
raise_assert
=
False
)
assert_similar
(
h_ref_fp32
,
h_tilelang_fp32
,
eps
=
1e-5
,
name
=
"tilelang chunk gated delta rule fwd h"
,
raise_assert
=
False
)
print
(
"tilelang chunk gated delta rule fwd h passed √"
)
except
Exception
as
e
:
print
(
"tilelang chunk gated delta rule fwd h failed ✗"
)
...
...
@@ -314,7 +360,8 @@ def run_test(
final_state_tilelang_fp32
,
eps
=
1e-5
,
name
=
"tilelang chunk gated delta rule fwd final_state"
,
raise_assert
=
False
)
raise_assert
=
False
,
)
print
(
"tilelang chunk gated delta rule fwd final_state passed √"
)
except
Exception
as
e
:
print
(
"tilelang chunk gated delta rule fwd final_state failed ✗"
)
...
...
@@ -323,12 +370,7 @@ def run_test(
try
:
V_new_ref_fp32
=
V_new_ref
.
to
(
torch
.
float32
)
V_new_tilelang_fp32
=
V_new_tilelang
.
to
(
torch
.
float32
)
assert_similar
(
V_new_ref_fp32
,
V_new_tilelang_fp32
,
eps
=
1e-5
,
name
=
"tilelang chunk gated delta rule fwd V_new"
,
raise_assert
=
False
)
assert_similar
(
V_new_ref_fp32
,
V_new_tilelang_fp32
,
eps
=
1e-5
,
name
=
"tilelang chunk gated delta rule fwd V_new"
,
raise_assert
=
False
)
print
(
"tilelang chunk gated delta rule fwd V_new passed √"
)
except
Exception
as
e
:
print
(
"tilelang chunk gated delta rule fwd V_new failed ✗"
)
...
...
@@ -345,20 +387,20 @@ def main():
H
=
32
,
DK
=
128
,
DV
=
128
,
input_dtype
=
"
bfloat16
"
,
output_dtype
=
"
bfloat16
"
,
accum_dtype
=
"
float32
"
,
gate_dtype
=
"
float32
"
,
state_dtype
=
"
float32
"
,
input_dtype
=
T
.
bfloat16
,
output_dtype
=
T
.
bfloat16
,
accum_dtype
=
T
.
float32
,
gate_dtype
=
T
.
float32
,
state_dtype
=
T
.
float32
,
chunk_size
=
64
,
use_g
=
True
,
use_initial_state
=
Tru
e
,
use_initial_state
=
Fals
e
,
store_final_state
=
True
,
save_new_value
=
True
,
block_DK
=
64
,
block_DK
=
32
,
block_DV
=
32
,
threads
=
128
,
num_stages
=
1
,
num_stages
=
2
,
)
...
...
examples/gdn/example_chunk_o.py
View file @
667632cc
...
...
@@ -9,6 +9,7 @@ import sys # noqa: F401
# sys.path.insert(0, "/home/tzj/flash-linear-attention")
try
:
import
fla
print
(
fla
.
__file__
)
from
fla.ops.common.chunk_o
import
chunk_fwd_o
except
ImportError
:
...
...
@@ -87,16 +88,14 @@ def tilelang_chunk_fwd_o(
@
T
.
prim_func
def
kernel
(
Q
:
T
.
Tensor
(
Q_shape
,
dtype
=
input_dtype
),
K
:
T
.
Tensor
(
K_shape
,
dtype
=
input_dtype
),
V
:
T
.
Tensor
(
V_shape
,
dtype
=
input_dtype
),
HIDDEN
:
T
.
Tensor
(
H_shape
,
dtype
=
input_dtype
),
G
:
T
.
Tensor
(
G_shape
,
dtype
=
gate_dtype
),
O
:
T
.
Tensor
(
O_shape
,
dtype
=
output_dtype
),
Q
:
T
.
Tensor
(
Q_shape
,
dtype
=
input_dtype
),
K
:
T
.
Tensor
(
K_shape
,
dtype
=
input_dtype
),
V
:
T
.
Tensor
(
V_shape
,
dtype
=
input_dtype
),
HIDDEN
:
T
.
Tensor
(
H_shape
,
dtype
=
input_dtype
),
G
:
T
.
Tensor
(
G_shape
,
dtype
=
gate_dtype
),
O
:
T
.
Tensor
(
O_shape
,
dtype
=
output_dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
DV
,
block_DV
),
T
.
ceildiv
(
S
,
block_S
),
B
*
H
,
threads
=
threads
)
as
(
bv
,
bs
,
bbh
):
with
T
.
Kernel
(
T
.
ceildiv
(
DV
,
block_DV
),
T
.
ceildiv
(
S
,
block_S
),
B
*
H
,
threads
=
threads
)
as
(
bv
,
bs
,
bbh
):
bb
,
bh
=
bbh
//
H
,
bbh
%
H
Q_shared
=
T
.
alloc_shared
((
block_S
,
block_DK
),
dtype
=
input_dtype
)
K_shared
=
T
.
alloc_shared
((
block_S
,
block_DK
),
dtype
=
input_dtype
)
...
...
@@ -109,28 +108,24 @@ def tilelang_chunk_fwd_o(
G_shared
=
T
.
alloc_shared
((
block_S
,),
dtype
=
gate_dtype
,
scope
=
"shared"
)
G_diff_local
=
T
.
alloc_fragment
((
block_S
,
block_S
),
dtype
=
gate_dtype
)
T
.
annotate_layout
({
Q_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
Q_shared
),
K_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
K_shared
),
V_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
V_shared
),
H_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
H_shared
),
A_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
A_shared
),
O_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
O_shared
),
})
T
.
annotate_layout
(
{
Q_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
Q_shared
),
K_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
K_shared
),
V_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
V_shared
),
H_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
H_shared
),
A_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
A_shared
),
O_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
O_shared
),
}
)
T
.
clear
(
A_fragment
)
T
.
clear
(
O_fragment
)
T
.
disable_warp_group_reg_alloc
()
for
i_k
in
T
.
Pipelined
(
T
.
ceildiv
(
DK
,
block_DK
),
num_stages
=
num_stages
):
T
.
copy
(
Q
[
bb
,
bs
*
block_S
:(
bs
+
1
)
*
block_S
,
bh
,
i_k
*
block_DK
:(
i_k
+
1
)
*
block_DK
],
Q_shared
)
T
.
copy
(
K
[
bb
,
bs
*
block_S
:(
bs
+
1
)
*
block_S
,
bh
,
i_k
*
block_DK
:(
i_k
+
1
)
*
block_DK
],
K_shared
)
T
.
copy
(
HIDDEN
[
bb
,
bs
,
bh
,
i_k
*
block_DK
:(
i_k
+
1
)
*
block_DK
,
bv
*
block_DV
:(
bv
+
1
)
*
block_DV
],
H_shared
)
T
.
copy
(
Q
[
bb
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
,
bh
,
i_k
*
block_DK
:
(
i_k
+
1
)
*
block_DK
],
Q_shared
)
T
.
copy
(
K
[
bb
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
,
bh
,
i_k
*
block_DK
:
(
i_k
+
1
)
*
block_DK
],
K_shared
)
T
.
copy
(
HIDDEN
[
bb
,
bs
,
bh
,
i_k
*
block_DK
:
(
i_k
+
1
)
*
block_DK
,
bv
*
block_DV
:
(
bv
+
1
)
*
block_DV
],
H_shared
)
T
.
gemm
(
Q_shared
,
H_shared
,
O_fragment
)
T
.
gemm
(
Q_shared
,
K_shared
,
A_fragment
,
transpose_B
=
True
)
...
...
@@ -145,8 +140,7 @@ def tilelang_chunk_fwd_o(
for
i_s1
,
i_s2
in
T
.
Parallel
(
block_S
,
block_S
):
with
T
.
If
(
G_diff_local
[
i_s1
,
i_s2
]
<=
0
):
with
T
.
Then
():
A_fragment
[
i_s1
,
i_s2
]
=
A_fragment
[
i_s1
,
i_s2
]
*
T
.
exp
(
G_diff_local
[
i_s1
,
i_s2
])
A_fragment
[
i_s1
,
i_s2
]
=
A_fragment
[
i_s1
,
i_s2
]
*
T
.
exp
(
G_diff_local
[
i_s1
,
i_s2
])
with
T
.
Else
():
A_fragment
[
i_s1
,
i_s2
]
=
0
...
...
@@ -155,8 +149,7 @@ def tilelang_chunk_fwd_o(
with
T
.
Then
():
A_fragment
[
i_s1
,
i_s2
]
=
0
T
.
copy
(
V
[
bb
,
bs
*
block_S
:(
bs
+
1
)
*
block_S
,
bh
,
bv
*
block_DV
:(
bv
+
1
)
*
block_DV
],
V_shared
)
T
.
copy
(
V
[
bb
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
,
bh
,
bv
*
block_DV
:
(
bv
+
1
)
*
block_DV
],
V_shared
)
T
.
copy
(
A_fragment
,
A_shared
)
T
.
gemm
(
A_shared
,
V_shared
,
O_fragment
)
...
...
@@ -164,8 +157,7 @@ def tilelang_chunk_fwd_o(
O_fragment
[
i_s
,
i_v
]
=
O_fragment
[
i_s
,
i_v
]
*
scale
T
.
copy
(
O_fragment
,
O_shared
)
T
.
copy
(
O_shared
,
O
[
bb
,
bs
*
block_S
:(
bs
+
1
)
*
block_S
,
bh
,
bv
*
block_DV
:(
bv
+
1
)
*
block_DV
])
T
.
copy
(
O_shared
,
O
[
bb
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
,
bh
,
bv
*
block_DV
:
(
bv
+
1
)
*
block_DV
])
return
kernel
...
...
@@ -191,8 +183,9 @@ def run_test(
output_dtype_torch
=
getattr
(
torch
,
output_dtype
)
accum_dtype_torch
=
getattr
(
torch
,
accum_dtype
)
gate_dtype_torch
=
getattr
(
torch
,
gate_dtype
)
Q
,
K
,
V
,
HIDDEN
,
G
=
prepare_input
(
B
,
S
,
H
,
DK
,
DV
,
chunk_size
,
input_dtype_torch
,
output_dtype_torch
,
accum_dtype_torch
,
gate_dtype_torch
)
Q
,
K
,
V
,
HIDDEN
,
G
=
prepare_input
(
B
,
S
,
H
,
DK
,
DV
,
chunk_size
,
input_dtype_torch
,
output_dtype_torch
,
accum_dtype_torch
,
gate_dtype_torch
)
scale
=
1.0
/
DK
**
0.5
O_ref
=
prepare_output
(
B
,
S
,
H
,
DK
,
DV
,
chunk_size
,
output_dtype_torch
)
...
...
@@ -200,9 +193,25 @@ def run_test(
block_S
=
chunk_size
O_tilelang
=
prepare_output
(
B
,
S
,
H
,
DK
,
DV
,
chunk_size
,
output_dtype_torch
)
kernel
=
tilelang_chunk_fwd_o
(
B
,
S
,
H
,
DK
,
DV
,
input_dtype
,
output_dtype
,
accum_dtype
,
gate_dtype
,
chunk_size
,
scale
,
use_g
,
block_S
,
block_DK
,
block_DV
,
threads
,
num_stages
)
kernel
=
tilelang_chunk_fwd_o
(
B
,
S
,
H
,
DK
,
DV
,
input_dtype
,
output_dtype
,
accum_dtype
,
gate_dtype
,
chunk_size
,
scale
,
use_g
,
block_S
,
block_DK
,
block_DV
,
threads
,
num_stages
,
)
O_tilelang
=
kernel
(
Q
,
K
,
V
,
HIDDEN
,
G
)
try
:
...
...
@@ -221,10 +230,10 @@ def main():
DK
=
128
,
DV
=
128
,
chunk_size
=
64
,
input_dtype
=
"
bfloat16
"
,
output_dtype
=
"
bfloat16
"
,
accum_dtype
=
"
float32
"
,
gate_dtype
=
"
float32
"
,
input_dtype
=
T
.
bfloat16
,
output_dtype
=
T
.
bfloat16
,
accum_dtype
=
T
.
float32
,
gate_dtype
=
T
.
float32
,
use_g
=
True
,
block_DK
=
128
,
block_DV
=
128
,
...
...
examples/gdn/example_chunk_o_bwd.py
View file @
667632cc
...
...
@@ -12,6 +12,7 @@ from tilelang.engine.callback import register_cuda_postproc_callback # noqa: F4
# sys.path.insert(0, "/home/tzj/flash-linear-attention")
try
:
import
fla
print
(
fla
.
__file__
)
from
fla.ops.common.chunk_o
import
chunk_bwd_dqkwg
except
ImportError
:
...
...
@@ -19,7 +20,7 @@ except ImportError:
fla
=
None
import
torch
from
utils
import
*
from
test_
utils
import
assert_similar
torch
.
random
.
manual_seed
(
0
)
# torch.set_printoptions(profile="full")
...
...
@@ -108,10 +109,8 @@ def prepare_output(
@
tilelang
.
jit
(
out_idx
=
[
-
4
,
-
3
,
-
2
,
-
1
],
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
})
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
},
)
def
tilelang_chunk_o_bwd_dqkwg
(
# task config
B
,
...
...
@@ -155,25 +154,23 @@ def tilelang_chunk_o_bwd_dqkwg(
@
T
.
prim_func
def
kernel
(
# input
Q
:
T
.
Tensor
(
Q_shape
,
dtype
=
input_dtype
),
K
:
T
.
Tensor
(
K_shape
,
dtype
=
input_dtype
),
V
:
T
.
Tensor
(
V_shape
,
dtype
=
input_dtype
),
h
:
T
.
Tensor
(
h_shape
,
dtype
=
input_dtype
),
G
:
T
.
Tensor
(
G_shape
,
dtype
=
gate_dtype
),
dO
:
T
.
Tensor
(
dO_shape
,
dtype
=
input_dtype
),
dh
:
T
.
Tensor
(
dh_shape
,
dtype
=
input_dtype
),
dv
:
T
.
Tensor
(
dv_shape
,
dtype
=
input_dtype
),
W
:
T
.
Tensor
(
W_shape
,
dtype
=
input_dtype
),
# output
dq
:
T
.
Tensor
(
dq_shape
,
dtype
=
output_dtype
),
dk
:
T
.
Tensor
(
dk_shape
,
dtype
=
output_dtype
),
dw
:
T
.
Tensor
(
dw_shape
,
dtype
=
output_dtype
),
dg
:
T
.
Tensor
(
dg_shape
,
dtype
=
gate_dtype
),
# input
Q
:
T
.
Tensor
(
Q_shape
,
dtype
=
input_dtype
),
K
:
T
.
Tensor
(
K_shape
,
dtype
=
input_dtype
),
V
:
T
.
Tensor
(
V_shape
,
dtype
=
input_dtype
),
h
:
T
.
Tensor
(
h_shape
,
dtype
=
input_dtype
),
G
:
T
.
Tensor
(
G_shape
,
dtype
=
gate_dtype
),
dO
:
T
.
Tensor
(
dO_shape
,
dtype
=
input_dtype
),
dh
:
T
.
Tensor
(
dh_shape
,
dtype
=
input_dtype
),
dv
:
T
.
Tensor
(
dv_shape
,
dtype
=
input_dtype
),
W
:
T
.
Tensor
(
W_shape
,
dtype
=
input_dtype
),
# output
dq
:
T
.
Tensor
(
dq_shape
,
dtype
=
output_dtype
),
dk
:
T
.
Tensor
(
dk_shape
,
dtype
=
output_dtype
),
dw
:
T
.
Tensor
(
dw_shape
,
dtype
=
output_dtype
),
dg
:
T
.
Tensor
(
dg_shape
,
dtype
=
gate_dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
DK
,
block_DK
),
T
.
ceildiv
(
S
,
block_S
),
B
*
H
,
threads
=
threads
)
as
(
bk
,
bs
,
bbh
):
with
T
.
Kernel
(
T
.
ceildiv
(
DK
,
block_DK
),
T
.
ceildiv
(
S
,
block_S
),
B
*
H
,
threads
=
threads
)
as
(
bk
,
bs
,
bbh
):
bb
,
bh
=
bbh
//
H
,
bbh
%
H
V_shared
=
T
.
alloc_shared
((
block_S
,
block_DV
),
dtype
=
input_dtype
)
...
...
@@ -212,15 +209,17 @@ def tilelang_chunk_o_bwd_dqkwg(
T
.
use_swizzle
(
10
)
T
.
annotate_layout
({
V_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
V_shared
),
dO_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
dO_shared
),
h_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
h_shared
),
dh_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
dh_shared
),
dv_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
dv_shared
),
q_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
q_shared
),
k_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
k_shared
),
})
T
.
annotate_layout
(
{
V_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
V_shared
),
dO_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
dO_shared
),
h_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
h_shared
),
dh_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
dh_shared
),
dv_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
dv_shared
),
q_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
q_shared
),
k_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
k_shared
),
}
)
T
.
clear
(
dg_last_local
)
T
.
clear
(
G_last_local
)
...
...
@@ -235,18 +234,10 @@ def tilelang_chunk_o_bwd_dqkwg(
T
.
clear
(
dw_fragment
)
for
i_v
in
T
.
Pipelined
(
T
.
ceildiv
(
DV
,
block_DV
),
num_stages
=
num_stages
):
T
.
copy
(
V
[
bb
,
bs
*
block_S
:(
bs
+
1
)
*
block_S
,
bh
,
i_v
*
block_DV
:(
i_v
+
1
)
*
block_DV
],
V_shared
)
T
.
copy
(
dO
[
bb
,
bs
*
block_S
:(
bs
+
1
)
*
block_S
,
bh
,
i_v
*
block_DV
:(
i_v
+
1
)
*
block_DV
],
dO_shared
)
T
.
copy
(
h
[
bb
,
bs
,
bh
,
bk
*
block_DK
:(
bk
+
1
)
*
block_DK
,
i_v
*
block_DV
:(
i_v
+
1
)
*
block_DV
],
h_shared
)
T
.
copy
(
dh
[
bb
,
bs
,
bh
,
bk
*
block_DK
:(
bk
+
1
)
*
block_DK
,
i_v
*
block_DV
:(
i_v
+
1
)
*
block_DV
],
dh_shared
)
T
.
copy
(
V
[
bb
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
,
bh
,
i_v
*
block_DV
:
(
i_v
+
1
)
*
block_DV
],
V_shared
)
T
.
copy
(
dO
[
bb
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
,
bh
,
i_v
*
block_DV
:
(
i_v
+
1
)
*
block_DV
],
dO_shared
)
T
.
copy
(
h
[
bb
,
bs
,
bh
,
bk
*
block_DK
:
(
bk
+
1
)
*
block_DK
,
i_v
*
block_DV
:
(
i_v
+
1
)
*
block_DV
],
h_shared
)
T
.
copy
(
dh
[
bb
,
bs
,
bh
,
bk
*
block_DK
:
(
bk
+
1
)
*
block_DK
,
i_v
*
block_DV
:
(
i_v
+
1
)
*
block_DV
],
dh_shared
)
if
use_g
:
T
.
clear
(
dg_last_fragment_scalar
)
...
...
@@ -254,9 +245,7 @@ def tilelang_chunk_o_bwd_dqkwg(
# for i_kv in T.Parallel(block_DK * block_DV):
# dg_last_fragment[i_kv] = h_shared[i_kv // block_DV, i_kv % block_DV] * dh_shared[i_kv // block_DV, i_kv % block_DV]
for
i_kv
in
T
.
Parallel
(
block_DK
*
block_DV
):
dg_last_fragment
[
i_kv
]
=
h_shared
[
i_kv
//
block_DV
,
i_kv
%
block_DV
]
*
dh_shared
[
i_kv
//
block_DV
,
i_kv
%
block_DV
]
dg_last_fragment
[
i_kv
]
=
h_shared
[
i_kv
//
block_DV
,
i_kv
%
block_DV
]
*
dh_shared
[
i_kv
//
block_DV
,
i_kv
%
block_DV
]
T
.
reduce_sum
(
dg_last_fragment
,
dg_last_fragment_scalar
,
dim
=-
1
,
clear
=
False
)
dg_last_local
[
0
]
+=
dg_last_fragment_scalar
[
0
]
...
...
@@ -265,22 +254,16 @@ def tilelang_chunk_o_bwd_dqkwg(
T
.
gemm
(
V_shared
,
dh_shared
,
dk_fragment
,
transpose_B
=
True
)
if
use_dw
:
T
.
copy
(
dv
[
bb
,
bs
*
block_S
:(
bs
+
1
)
*
block_S
,
bh
,
i_v
*
block_DV
:(
i_v
+
1
)
*
block_DV
],
dv_shared
)
T
.
copy
(
dv
[
bb
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
,
bh
,
i_v
*
block_DV
:
(
i_v
+
1
)
*
block_DV
],
dv_shared
)
T
.
gemm
(
dv_shared
,
h_shared
,
dw_fragment
,
transpose_B
=
True
)
if
use_dw
:
for
i_s
,
i_k
in
T
.
Parallel
(
block_S
,
block_DK
):
dw_fragment
[
i_s
,
i_k
]
=
-
dw_fragment
[
i_s
,
i_k
]
T
.
copy
(
dw_fragment
,
dw
[
bb
,
bs
*
block_S
:(
bs
+
1
)
*
block_S
,
bh
,
bk
*
block_DK
:(
bk
+
1
)
*
block_DK
])
T
.
copy
(
Q
[
bb
,
bs
*
block_S
:(
bs
+
1
)
*
block_S
,
bh
,
bk
*
block_DK
:(
bk
+
1
)
*
block_DK
],
q_shared
)
T
.
copy
(
K
[
bb
,
bs
*
block_S
:(
bs
+
1
)
*
block_S
,
bh
,
bk
*
block_DK
:(
bk
+
1
)
*
block_DK
],
k_shared
)
T
.
copy
(
dw_fragment
,
dw
[
bb
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
,
bh
,
bk
*
block_DK
:
(
bk
+
1
)
*
block_DK
])
T
.
copy
(
Q
[
bb
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
,
bh
,
bk
*
block_DK
:
(
bk
+
1
)
*
block_DK
],
q_shared
)
T
.
copy
(
K
[
bb
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
,
bh
,
bk
*
block_DK
:
(
bk
+
1
)
*
block_DK
],
k_shared
)
T
.
copy
(
q_shared
,
q_fragment
)
T
.
copy
(
k_shared
,
k_fragment
)
...
...
@@ -294,8 +277,7 @@ def tilelang_chunk_o_bwd_dqkwg(
dg_last_local
[
0
]
=
dg_last_local
[
0
]
*
T
.
exp
(
G
[
bb
,
bs
*
block_S
+
block_S
-
1
,
bh
])
for
i_s
,
i_k
in
T
.
Parallel
(
block_S
,
block_DK
):
dq_fragment
[
i_s
,
i_k
]
=
dq_fragment
[
i_s
,
i_k
]
*
T
.
exp
(
G
[
bb
,
bs
*
block_S
+
i_s
,
bh
])
*
scale
dq_fragment
[
i_s
,
i_k
]
=
dq_fragment
[
i_s
,
i_k
]
*
T
.
exp
(
G
[
bb
,
bs
*
block_S
+
i_s
,
bh
])
*
scale
T
.
clear
(
dg_fragment_reduce_tmp
)
for
i_s
,
i_k
in
T
.
Parallel
(
block_S
,
block_DK
):
dg_fragment_reduce_tmp
[
i_s
,
i_k
]
=
dq_fragment
[
i_s
,
i_k
]
*
q_shared
[
i_s
,
i_k
]
...
...
@@ -305,8 +287,7 @@ def tilelang_chunk_o_bwd_dqkwg(
for
i_s
,
i_k
in
T
.
Parallel
(
block_S
,
block_DK
):
with
T
.
If
(
G_last_local
[
0
]
-
G
[
bb
,
bs
*
block_S
+
i_s
,
bh
]
<=
0
):
with
T
.
Then
():
dk_fragment
[
i_s
,
i_k
]
=
dk_fragment
[
i_s
,
i_k
]
*
T
.
exp
(
G_last_local
[
0
]
-
G
[
bb
,
bs
*
block_S
+
i_s
,
bh
])
dk_fragment
[
i_s
,
i_k
]
=
dk_fragment
[
i_s
,
i_k
]
*
T
.
exp
(
G_last_local
[
0
]
-
G
[
bb
,
bs
*
block_S
+
i_s
,
bh
])
with
T
.
Else
():
dk_fragment
[
i_s
,
i_k
]
=
0
T
.
clear
(
dg_fragment_reduce_tmp
)
...
...
@@ -325,12 +306,11 @@ def tilelang_chunk_o_bwd_dqkwg(
dg_last_local
[
1
]
=
dg_last_fragment_scalar_2
[
0
]
for
i_s1
,
i_s2
in
T
.
Parallel
(
block_S
,
block_S
):
with
T
.
If
(
i_s1
>=
i_s2
and
G
[
bb
,
bs
*
block_S
+
i_s1
,
bh
]
-
G
[
bb
,
bs
*
block_S
+
i_s2
,
bh
]
<=
0
):
with
T
.
If
(
i_s1
>=
i_s2
and
G
[
bb
,
bs
*
block_S
+
i_s1
,
bh
]
-
G
[
bb
,
bs
*
block_S
+
i_s2
,
bh
]
<=
0
):
with
T
.
Then
():
ds_fragment
[
i_s1
,
i_s2
]
=
ds_fragment
[
i_s1
,
i_s2
]
*
T
.
exp
(
G
[
bb
,
bs
*
block_S
+
i_s1
,
bh
]
-
G
[
bb
,
bs
*
block_S
+
i_s2
,
bh
])
*
scale
ds_fragment
[
i_s1
,
i_s2
]
=
(
ds_fragment
[
i_s1
,
i_s2
]
*
T
.
exp
(
G
[
bb
,
bs
*
block_S
+
i_s1
,
bh
]
-
G
[
bb
,
bs
*
block_S
+
i_s2
,
bh
])
*
scale
)
with
T
.
Else
():
ds_fragment
[
i_s1
,
i_s2
]
=
0
...
...
@@ -338,8 +318,7 @@ def tilelang_chunk_o_bwd_dqkwg(
T
.
clear
(
ds_fragment_positive_transpose
)
T
.
gemm
(
q_shared
,
k_shared
,
ds_fragment_positive
,
transpose_B
=
True
)
for
i_s1
,
i_s2
in
T
.
Parallel
(
block_S
,
block_S
):
ds_fragment_positive
[
i_s1
,
i_s2
]
=
ds_fragment
[
i_s1
,
i_s2
]
*
ds_fragment_positive
[
i_s1
,
i_s2
]
ds_fragment_positive
[
i_s1
,
i_s2
]
=
ds_fragment
[
i_s1
,
i_s2
]
*
ds_fragment_positive
[
i_s1
,
i_s2
]
# FIXME: The reduce_sum statement with clear=True will cause an error of warp specialized pass
T
.
reduce_sum
(
ds_fragment_positive
,
dg_fragment
,
dim
=
1
,
clear
=
False
)
...
...
@@ -363,15 +342,10 @@ def tilelang_chunk_o_bwd_dqkwg(
for
i_s
in
T
.
Parallel
(
block_S
):
with
T
.
If
(
i_s
>=
block_S
-
1
):
# noqa: SIM117
with
T
.
Then
():
dg_fragment_final
[
i_s
]
=
dg_fragment_final
[
i_s
]
+
dg_last_local
[
0
]
+
dg_last_local
[
1
]
T
.
copy
(
dq_fragment
,
dq
[
bb
,
bs
*
block_S
:(
bs
+
1
)
*
block_S
,
bh
,
bk
*
block_DK
:(
bk
+
1
)
*
block_DK
])
T
.
copy
(
dk_fragment
,
dk
[
bb
,
bs
*
block_S
:(
bs
+
1
)
*
block_S
,
bh
,
bk
*
block_DK
:(
bk
+
1
)
*
block_DK
])
dg_fragment_final
[
i_s
]
=
dg_fragment_final
[
i_s
]
+
dg_last_local
[
0
]
+
dg_last_local
[
1
]
T
.
copy
(
dq_fragment
,
dq
[
bb
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
,
bh
,
bk
*
block_DK
:
(
bk
+
1
)
*
block_DK
])
T
.
copy
(
dk_fragment
,
dk
[
bb
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
,
bh
,
bk
*
block_DK
:
(
bk
+
1
)
*
block_DK
])
for
i_s
in
T
.
Parallel
(
block_S
):
dg
[
bk
,
bb
,
bs
*
block_S
+
i_s
,
bh
]
=
dg_fragment_final
[
i_s
]
...
...
@@ -387,12 +361,8 @@ def tilelang_chunk_o_bwd_dqkwg(
for
i_s
,
i_k
in
T
.
Parallel
(
block_S
,
block_DK
):
dq_fragment
[
i_s
,
i_k
]
=
dq_fragment
[
i_s
,
i_k
]
*
scale
dk_fragment
[
i_s
,
i_k
]
=
dk_fragment
[
i_s
,
i_k
]
+
dk_fragment_2
[
i_s
,
i_k
]
*
scale
T
.
copy
(
dq_fragment
,
dq
[
bb
,
bs
*
block_S
:(
bs
+
1
)
*
block_S
,
bh
,
bk
*
block_DK
:(
bk
+
1
)
*
block_DK
])
T
.
copy
(
dk_fragment
,
dk
[
bb
,
bs
*
block_S
:(
bs
+
1
)
*
block_S
,
bh
,
bk
*
block_DK
:(
bk
+
1
)
*
block_DK
])
T
.
copy
(
dq_fragment
,
dq
[
bb
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
,
bh
,
bk
*
block_DK
:
(
bk
+
1
)
*
block_DK
])
T
.
copy
(
dk_fragment
,
dk
[
bb
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
,
bh
,
bk
*
block_DK
:
(
bk
+
1
)
*
block_DK
])
return
kernel
...
...
@@ -442,33 +412,53 @@ def run_test(
threads
=
256
,
num_stages
=
0
,
):
Q
,
K
,
V
,
h
,
G
,
dO
,
dh
,
dv
,
W
=
prepare_input
(
B
,
S
,
H
,
DK
,
DV
,
chunk_size
,
getattr
(
torch
,
input_dtype
),
getattr
(
torch
,
output_dtype
),
getattr
(
torch
,
accum_dtype
),
getattr
(
torch
,
gate_dtype
),
getattr
(
torch
,
state_dtype
))
dq_ref
,
dk_ref
,
dw_ref
,
dg_ref
=
prepare_output
(
B
,
S
,
H
,
DK
,
DV
,
chunk_size
,
getattr
(
torch
,
output_dtype
),
getattr
(
torch
,
gate_dtype
),
getattr
(
torch
,
state_dtype
),
block_DK
)
Q
,
K
,
V
,
h
,
G
,
dO
,
dh
,
dv
,
W
=
prepare_input
(
B
,
S
,
H
,
DK
,
DV
,
chunk_size
,
getattr
(
torch
,
input_dtype
),
getattr
(
torch
,
output_dtype
),
getattr
(
torch
,
accum_dtype
),
getattr
(
torch
,
gate_dtype
),
getattr
(
torch
,
state_dtype
),
)
dq_ref
,
dk_ref
,
dw_ref
,
dg_ref
=
prepare_output
(
B
,
S
,
H
,
DK
,
DV
,
chunk_size
,
getattr
(
torch
,
output_dtype
),
getattr
(
torch
,
gate_dtype
),
getattr
(
torch
,
state_dtype
),
block_DK
)
dq_tilelang
,
dk_tilelang
,
dw_tilelang
,
dg_tilelang
=
prepare_output
(
B
,
S
,
H
,
DK
,
DV
,
chunk_size
,
getattr
(
torch
,
output_dtype
),
getattr
(
torch
,
gate_dtype
),
getattr
(
torch
,
state_dtype
),
block_DK
)
B
,
S
,
H
,
DK
,
DV
,
chunk_size
,
getattr
(
torch
,
output_dtype
),
getattr
(
torch
,
gate_dtype
),
getattr
(
torch
,
state_dtype
),
block_DK
)
# ref
if
use_g
:
dq_ref
,
dk_ref
,
dw_ref
,
dg_ref
=
chunk_bwd_dqkwg
(
Q
,
K
,
V
,
G
,
dO
,
h
,
dh
,
dv
,
W
,
chunk_size
=
chunk_size
,
scale
=
scale
)
dq_ref
,
dk_ref
,
dw_ref
,
dg_ref
=
chunk_bwd_dqkwg
(
Q
,
K
,
V
,
G
,
dO
,
h
,
dh
,
dv
,
W
,
chunk_size
=
chunk_size
,
scale
=
scale
)
else
:
dq_ref
,
dk_ref
,
dw_ref
,
dg_ref
=
chunk_bwd_dqkwg
(
Q
,
K
,
V
,
None
,
dO
,
h
,
dh
,
dv
,
W
,
chunk_size
=
chunk_size
,
scale
=
scale
)
dq_ref
,
dk_ref
,
dw_ref
,
dg_ref
=
chunk_bwd_dqkwg
(
Q
,
K
,
V
,
None
,
dO
,
h
,
dh
,
dv
,
W
,
chunk_size
=
chunk_size
,
scale
=
scale
)
# tilelang
kernel
=
tilelang_chunk_o_bwd_dqkwg
(
B
,
S
,
H
,
DK
,
DV
,
input_dtype
,
output_dtype
,
accum_dtype
,
gate_dtype
,
state_dtype
,
chunk_size
,
scale
,
use_g
,
use_dw
,
block_DK
,
block_DV
,
threads
,
num_stages
)
print
(
kernel
.
get_kernel_source
())
kernel
=
tilelang_chunk_o_bwd_dqkwg
(
B
,
S
,
H
,
DK
,
DV
,
input_dtype
,
output_dtype
,
accum_dtype
,
gate_dtype
,
state_dtype
,
chunk_size
,
scale
,
use_g
,
use_dw
,
block_DK
,
block_DV
,
threads
,
num_stages
,
)
dq_tilelang
,
dk_tilelang
,
dw_tilelang
,
dg_tilelang
=
kernel
(
Q
,
K
,
V
,
h
,
G
,
dO
,
dh
,
dv
,
W
)
if
use_g
:
...
...
@@ -515,11 +505,11 @@ def main():
H
=
8
,
DK
=
DK
,
DV
=
DV
,
input_dtype
=
"
bfloat16
"
,
output_dtype
=
"
bfloat16
"
,
accum_dtype
=
"
float32
"
,
gate_dtype
=
"
float32
"
,
state_dtype
=
"
float32
"
,
input_dtype
=
T
.
bfloat16
,
output_dtype
=
T
.
bfloat16
,
accum_dtype
=
T
.
float32
,
gate_dtype
=
T
.
float32
,
state_dtype
=
T
.
float32
,
chunk_size
=
64
,
scale
=
DK
**-
0.5
,
# scale=1,
...
...
examples/gdn/example_chunk_scaled_dot_kkt.py
View file @
667632cc
...
...
@@ -9,6 +9,7 @@ import sys # noqa: F401
# sys.path.insert(0, "/home/tzj/flash-linear-attention")
try
:
import
fla
print
(
fla
.
__file__
)
from
fla.ops.common.chunk_scaled_dot_kkt
import
chunk_scaled_dot_kkt_fwd
except
ImportError
:
...
...
@@ -56,9 +57,9 @@ def tilelang_chunk_scaled_dot_kkt_fwd(
H
,
DK
,
chunk_size
=
64
,
input_dtype
=
"
bfloat16
"
,
output_dtype
=
"
bfloat16
"
,
accum_dtype
=
"
float32
"
,
input_dtype
=
T
.
bfloat16
,
output_dtype
=
T
.
bfloat16
,
accum_dtype
=
T
.
float32
,
use_g
=
True
,
# kernel config
block_S
=
64
,
...
...
@@ -75,10 +76,10 @@ def tilelang_chunk_scaled_dot_kkt_fwd(
@
T
.
prim_func
def
kernel
(
K
:
T
.
Tensor
(
K_shape
,
dtype
=
input_dtype
),
Beta
:
T
.
Tensor
(
Beta_shape
,
dtype
=
input_dtype
),
G
:
T
.
Tensor
(
G_shape
,
dtype
=
accum_dtype
),
A
:
T
.
Tensor
(
output_shape
,
dtype
=
output_dtype
),
K
:
T
.
Tensor
(
K_shape
,
dtype
=
input_dtype
),
Beta
:
T
.
Tensor
(
Beta_shape
,
dtype
=
input_dtype
),
G
:
T
.
Tensor
(
G_shape
,
dtype
=
accum_dtype
),
A
:
T
.
Tensor
(
output_shape
,
dtype
=
output_dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
S
,
block_S
),
B
*
H
,
threads
=
threads
)
as
(
bs
,
bbh
):
bb
,
bh
=
bbh
//
H
,
bbh
%
H
...
...
@@ -93,10 +94,12 @@ def tilelang_chunk_scaled_dot_kkt_fwd(
G_shared
=
T
.
alloc_shared
((
block_S
,),
dtype
=
accum_dtype
,
scope
=
"shared"
)
G_diff_local
=
T
.
alloc_fragment
((
block_S
,
block_S
),
dtype
=
accum_dtype
)
T
.
annotate_layout
({
K_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
K_shared
),
A_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
A_shared
),
})
T
.
annotate_layout
(
{
K_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
K_shared
),
A_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
A_shared
),
}
)
T
.
fill
(
A_fragment
,
0
)
T
.
disable_warp_group_reg_alloc
()
...
...
@@ -104,9 +107,7 @@ def tilelang_chunk_scaled_dot_kkt_fwd(
Beta_shared
[
i_s
]
=
Beta
[
bb
,
bs
*
block_S
+
i_s
,
bh
]
for
i_k
in
T
.
Pipelined
(
T
.
ceildiv
(
DK
,
block_DK
),
num_stages
=
num_stages
):
T
.
copy
(
K
[
bb
,
bs
*
block_S
:(
bs
+
1
)
*
block_S
,
bh
,
i_k
*
block_DK
:(
i_k
+
1
)
*
block_DK
],
K_shared
)
T
.
copy
(
K
[
bb
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
,
bh
,
i_k
*
block_DK
:
(
i_k
+
1
)
*
block_DK
],
K_shared
)
for
i_s
,
i_k2
in
T
.
Parallel
(
block_S
,
block_DK
):
Beta_K_fragment
[
i_s
,
i_k2
]
=
K_shared
[
i_s
,
i_k2
]
*
Beta_shared
[
i_s
]
T
.
gemm
(
Beta_K_fragment
,
K_shared
,
A_fragment
,
transpose_B
=
True
)
...
...
@@ -119,8 +120,7 @@ def tilelang_chunk_scaled_dot_kkt_fwd(
for
i_s1
,
i_s2
in
T
.
Parallel
(
block_S
,
block_S
):
with
T
.
If
(
G_diff_local
[
i_s1
,
i_s2
]
<=
0
and
i_s1
>
i_s2
):
with
T
.
Then
():
A_fragment
[
i_s1
,
i_s2
]
=
A_fragment
[
i_s1
,
i_s2
]
*
T
.
exp
(
G_diff_local
[
i_s1
,
i_s2
])
A_fragment
[
i_s1
,
i_s2
]
=
A_fragment
[
i_s1
,
i_s2
]
*
T
.
exp
(
G_diff_local
[
i_s1
,
i_s2
])
with
T
.
Else
():
A_fragment
[
i_s1
,
i_s2
]
=
0
else
:
...
...
@@ -130,7 +130,7 @@ def tilelang_chunk_scaled_dot_kkt_fwd(
A_fragment
[
i_s1
,
i_s2
]
=
0
T
.
copy
(
A_fragment
,
A_shared
)
T
.
copy
(
A_shared
,
A
[
bb
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
,
bh
,
:])
T
.
copy
(
A_shared
,
A
[
bb
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
,
bh
,
:])
return
kernel
...
...
@@ -149,24 +149,21 @@ def run_test(
threads
,
num_stages
,
):
K
,
Beta
,
G
=
prepare_input
(
B
,
S
,
H
,
DK
,
getattr
(
torch
,
input_dtype
),
getattr
(
torch
,
output_dtype
),
getattr
(
torch
,
accum_dtype
))
K
,
Beta
,
G
=
prepare_input
(
B
,
S
,
H
,
DK
,
getattr
(
torch
,
input_dtype
),
getattr
(
torch
,
output_dtype
),
getattr
(
torch
,
accum_dtype
))
A_ref
=
prepare_output
(
B
,
S
,
H
,
chunk_size
,
getattr
(
torch
,
output_dtype
))
A_tilelang
=
prepare_output
(
B
,
S
,
H
,
chunk_size
,
getattr
(
torch
,
output_dtype
))
# reference
if
use_g
:
A_ref
=
chunk_scaled_dot_kkt_fwd
(
K
,
Beta
,
G
,
chunk_size
=
chunk_size
,
output_dtype
=
getattr
(
torch
,
output_dtype
))
A_ref
=
chunk_scaled_dot_kkt_fwd
(
K
,
Beta
,
G
,
chunk_size
=
chunk_size
,
output_dtype
=
getattr
(
torch
,
output_dtype
))
else
:
A_ref
=
chunk_scaled_dot_kkt_fwd
(
K
,
Beta
,
None
,
chunk_size
=
chunk_size
,
output_dtype
=
getattr
(
torch
,
output_dtype
))
A_ref
=
chunk_scaled_dot_kkt_fwd
(
K
,
Beta
,
None
,
chunk_size
=
chunk_size
,
output_dtype
=
getattr
(
torch
,
output_dtype
))
# tilelang
block_S
=
chunk_size
kernel
=
tilelang_chunk_scaled_dot_kkt_fwd
(
B
,
S
,
H
,
DK
,
chunk_size
,
input_dtype
,
output_dtype
,
accum_dtype
,
use_g
,
block_S
,
block_DK
,
threads
,
num_stages
)
kernel
=
tilelang_chunk_scaled_dot_kkt_fwd
(
B
,
S
,
H
,
DK
,
chunk_size
,
input_dtype
,
output_dtype
,
accum_dtype
,
use_g
,
block_S
,
block_DK
,
threads
,
num_stages
)
A_tilelang
=
kernel
(
K
,
Beta
,
G
)
try
:
...
...
@@ -186,13 +183,14 @@ def main():
H
=
32
,
DK
=
128
,
chunk_size
=
64
,
input_dtype
=
"
bfloat16
"
,
output_dtype
=
"
bfloat16
"
,
accum_dtype
=
"
float32
"
,
input_dtype
=
T
.
bfloat16
,
output_dtype
=
T
.
bfloat16
,
accum_dtype
=
T
.
float32
,
use_g
=
True
,
block_DK
=
64
,
threads
=
128
,
num_stages
=
2
)
num_stages
=
2
,
)
if
__name__
==
"__main__"
:
...
...
examples/gdn/example_cumsum.py
View file @
667632cc
...
...
@@ -10,6 +10,7 @@ import sys # noqa: F401
# sys.path.insert(0, "/home/tzj/flash-linear-attention")
try
:
import
fla
print
(
fla
.
__file__
)
from
fla.ops.utils.cumsum
import
chunk_local_cumsum_scalar
except
ImportError
:
...
...
@@ -20,11 +21,8 @@ import torch
@
tilelang
.
jit
(
out_idx
=
[
-
1
],
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
})
out_idx
=
[
-
1
],
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
}
)
def
tilelang_chunk_local_cumsum_scalar
(
# task config
B
,
...
...
@@ -34,43 +32,43 @@ def tilelang_chunk_local_cumsum_scalar(
is_varlen
=
False
,
head_first
=
False
,
reverse
=
False
,
input_dtype
=
"
float16
"
,
output_dtype
=
"
float32
"
,
input_dtype
=
T
.
float16
,
output_dtype
=
T
.
float32
,
# kernel config
block_S
=
64
,
threads
=
256
,
use_fragment
=
False
,
):
G_shape
=
(
B
,
H
,
S
)
if
head_first
else
(
B
,
S
,
H
)
assert
chunk_size
==
2
**
(
chunk_size
.
bit_length
()
-
1
),
"chunk_size must be a power of 2"
assert
chunk_size
==
2
**
(
chunk_size
.
bit_length
()
-
1
),
"chunk_size must be a power of 2"
assert
chunk_size
==
block_S
,
"chunk_size must be equal to block_S"
@
T
.
prim_func
def
kernel
(
G
:
T
.
Tensor
(
G_shape
,
dtype
=
input_dtype
),
G_new
:
T
.
Tensor
(
G_shape
,
dtype
=
output_dtype
),
G
:
T
.
Tensor
(
G_shape
,
dtype
=
input_dtype
),
G_new
:
T
.
Tensor
(
G_shape
,
dtype
=
output_dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
S
,
block_S
),
B
*
H
,
threads
=
threads
)
as
(
bs
,
bbh
):
bb
,
bh
=
bbh
//
H
,
bbh
%
H
G_shared
=
T
.
alloc_shared
((
1
,
block_S
),
dtype
=
output_dtype
,
scope
=
"shared"
)
if
head_first
:
T
.
copy
(
G
[
bb
,
bh
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
],
G_shared
)
T
.
copy
(
G
[
bb
,
bh
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
],
G_shared
)
else
:
T
.
copy
(
G
[
bb
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
,
bh
],
G_shared
)
T
.
copy
(
G
[
bb
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
,
bh
],
G_shared
)
if
use_fragment
:
G_fragment
=
T
.
alloc_fragment
((
1
,
block_S
),
dtype
=
output_dtype
,
scope
=
"shared"
)
T
.
copy
(
G_shared
,
G_fragment
)
T
.
cumsum
(
G_fragment
,
dim
=
1
,
reverse
=
reverse
)
if
head_first
:
T
.
copy
(
G_fragment
,
G_new
[
bb
,
bh
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
])
T
.
copy
(
G_fragment
,
G_new
[
bb
,
bh
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
])
else
:
T
.
copy
(
G_fragment
,
G_new
[
bb
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
,
bh
])
T
.
copy
(
G_fragment
,
G_new
[
bb
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
,
bh
])
else
:
T
.
cumsum
(
G_shared
,
dim
=
1
,
reverse
=
reverse
)
if
head_first
:
T
.
copy
(
G_shared
,
G_new
[
bb
,
bh
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
])
T
.
copy
(
G_shared
,
G_new
[
bb
,
bh
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
])
else
:
T
.
copy
(
G_shared
,
G_new
[
bb
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
,
bh
])
T
.
copy
(
G_shared
,
G_new
[
bb
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
,
bh
])
return
kernel
...
...
@@ -113,11 +111,8 @@ def run_test(
# reference cumsum
G_new_ref
=
chunk_local_cumsum_scalar
(
g
=
G
,
chunk_size
=
chunk_size
,
reverse
=
reverse
,
head_first
=
head_first
,
output_dtype
=
getattr
(
torch
,
output_dtype
))
g
=
G
,
chunk_size
=
chunk_size
,
reverse
=
reverse
,
head_first
=
head_first
,
output_dtype
=
getattr
(
torch
,
output_dtype
)
)
# tilelang cumsum
block_S
=
chunk_size
...
...
@@ -159,10 +154,11 @@ def main():
chunk_size
=
64
,
reverse
=
True
,
head_first
=
False
,
input_dtype
=
"
float32
"
,
output_dtype
=
"
float32
"
,
input_dtype
=
T
.
float32
,
output_dtype
=
T
.
float32
,
threads
=
256
,
use_fragment
=
False
)
use_fragment
=
False
,
)
if
__name__
==
"__main__"
:
...
...
examples/gdn/example_wy_fast.py
View file @
667632cc
...
...
@@ -9,6 +9,7 @@ import sys # noqa: F401
# sys.path.insert(0, "/home/tzj/flash-linear-attention")
try
:
import
fla
print
(
fla
.
__file__
)
from
fla.ops.gated_delta_rule.wy_fast
import
recompute_w_u_fwd
except
ImportError
:
...
...
@@ -73,13 +74,13 @@ def tilelang_recompute_w_u_fwd(
@
T
.
prim_func
def
kernel
(
K
:
T
.
Tensor
(
K_shape
,
dtype
=
input_dtype
),
V
:
T
.
Tensor
(
V_shape
,
dtype
=
input_dtype
),
Beta
:
T
.
Tensor
(
Beta_shape
,
dtype
=
input_dtype
),
G
:
T
.
Tensor
(
G_shape
,
dtype
=
gate_dtype
),
A
:
T
.
Tensor
(
A_shape
,
dtype
=
output_dtype
),
W
:
T
.
Tensor
(
K_shape
,
dtype
=
output_dtype
),
U
:
T
.
Tensor
(
V_shape
,
dtype
=
output_dtype
),
K
:
T
.
Tensor
(
K_shape
,
dtype
=
input_dtype
),
V
:
T
.
Tensor
(
V_shape
,
dtype
=
input_dtype
),
Beta
:
T
.
Tensor
(
Beta_shape
,
dtype
=
input_dtype
),
G
:
T
.
Tensor
(
G_shape
,
dtype
=
gate_dtype
),
A
:
T
.
Tensor
(
A_shape
,
dtype
=
output_dtype
),
W
:
T
.
Tensor
(
K_shape
,
dtype
=
output_dtype
),
U
:
T
.
Tensor
(
V_shape
,
dtype
=
output_dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
S
,
block_S
),
B
*
H
,
threads
=
threads
)
as
(
bs
,
bbh
):
bb
,
bh
=
bbh
//
H
,
bbh
%
H
...
...
@@ -95,49 +96,42 @@ def tilelang_recompute_w_u_fwd(
W_Beta_shared
=
T
.
alloc_shared
((
block_S
,
block_DK
),
dtype
=
input_dtype
)
U_Beta_shared
=
T
.
alloc_shared
((
block_S
,
block_DV
),
dtype
=
input_dtype
)
T
.
annotate_layout
({
K_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
K_shared
),
V_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
V_shared
),
A_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
A_shared
),
W_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
W_shared
),
U_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
U_shared
),
W_Beta_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
W_Beta_shared
),
U_Beta_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
U_Beta_shared
),
})
T
.
annotate_layout
(
{
K_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
K_shared
),
V_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
V_shared
),
A_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
A_shared
),
W_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
W_shared
),
U_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
U_shared
),
W_Beta_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
W_Beta_shared
),
U_Beta_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
U_Beta_shared
),
}
)
T
.
disable_warp_group_reg_alloc
()
for
i_s
in
T
.
Parallel
(
block_S
):
Beta_shared
[
i_s
]
=
Beta
[
bb
,
bs
*
block_S
+
i_s
,
bh
]
G_shared
[
i_s
]
=
T
.
exp
(
G
[
bb
,
bs
*
block_S
+
i_s
,
bh
])
T
.
copy
(
A
[
bb
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
,
bh
,
:],
A_shared
)
T
.
copy
(
A
[
bb
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
,
bh
,
:],
A_shared
)
for
i_v
in
T
.
Pipelined
(
T
.
ceildiv
(
DV
,
block_DV
),
num_stages
=
num_stages
):
T
.
copy
(
V
[
bb
,
bs
*
block_S
:(
bs
+
1
)
*
block_S
,
bh
,
i_v
*
block_DV
:(
i_v
+
1
)
*
block_DV
],
V_shared
)
T
.
copy
(
V
[
bb
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
,
bh
,
i_v
*
block_DV
:
(
i_v
+
1
)
*
block_DV
],
V_shared
)
for
i_s
,
i_v2
in
T
.
Parallel
(
block_S
,
block_DV
):
U_Beta_shared
[
i_s
,
i_v2
]
=
V_shared
[
i_s
,
i_v2
]
*
Beta_shared
[
i_s
]
T
.
gemm
(
A_shared
,
U_Beta_shared
,
U_fragment
,
clear_accum
=
True
)
# First copy to smem, then copy to gmem to reduce U2RU instructions
T
.
copy
(
U_fragment
,
U_shared
)
T
.
copy
(
U_shared
,
U
[
bb
,
bs
*
block_S
:(
bs
+
1
)
*
block_S
,
bh
,
i_v
*
block_DV
:(
i_v
+
1
)
*
block_DV
])
T
.
copy
(
U_shared
,
U
[
bb
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
,
bh
,
i_v
*
block_DV
:
(
i_v
+
1
)
*
block_DV
])
for
i_k
in
T
.
Pipelined
(
T
.
ceildiv
(
DK
,
block_DK
),
num_stages
=
num_stages
):
T
.
copy
(
K
[
bb
,
bs
*
block_S
:(
bs
+
1
)
*
block_S
,
bh
,
i_k
*
block_DK
:(
i_k
+
1
)
*
block_DK
],
K_shared
)
T
.
copy
(
K
[
bb
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
,
bh
,
i_k
*
block_DK
:
(
i_k
+
1
)
*
block_DK
],
K_shared
)
for
i_s
,
i_k2
in
T
.
Parallel
(
block_S
,
block_DK
):
W_Beta_shared
[
i_s
,
i_k2
]
=
K_shared
[
i_s
,
i_k2
]
*
Beta_shared
[
i_s
]
*
G_shared
[
i_s
]
W_Beta_shared
[
i_s
,
i_k2
]
=
K_shared
[
i_s
,
i_k2
]
*
Beta_shared
[
i_s
]
*
G_shared
[
i_s
]
T
.
gemm
(
A_shared
,
W_Beta_shared
,
W_fragment
,
clear_accum
=
True
)
# First copy to smem, then copy to gmem to reduce U2RU instructions
T
.
copy
(
W_fragment
,
W_shared
)
T
.
copy
(
W_shared
,
W
[
bb
,
bs
*
block_S
:(
bs
+
1
)
*
block_S
,
bh
,
i_k
*
block_DK
:(
i_k
+
1
)
*
block_DK
])
T
.
copy
(
W_shared
,
W
[
bb
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
,
bh
,
i_k
*
block_DK
:
(
i_k
+
1
)
*
block_DK
])
return
kernel
...
...
@@ -159,15 +153,8 @@ def run_test(
num_stages
,
):
K
,
V
,
Beta
,
G
,
A
=
prepare_input
(
B
,
S
,
H
,
DK
,
DV
,
chunk_size
,
getattr
(
torch
,
input_dtype
),
getattr
(
torch
,
output_dtype
),
gate_dtype
=
getattr
(
torch
,
gate_dtype
))
B
,
S
,
H
,
DK
,
DV
,
chunk_size
,
getattr
(
torch
,
input_dtype
),
getattr
(
torch
,
output_dtype
),
gate_dtype
=
getattr
(
torch
,
gate_dtype
)
)
W_ref
,
U_ref
=
prepare_output
(
B
,
S
,
H
,
DK
,
DV
,
getattr
(
torch
,
output_dtype
))
W_tilelang
,
U_tilelang
=
prepare_output
(
B
,
S
,
H
,
DK
,
DV
,
getattr
(
torch
,
output_dtype
))
...
...
@@ -191,7 +178,8 @@ def run_test(
block_DK
=
block_DK
,
block_DV
=
block_DV
,
threads
=
threads
,
num_stages
=
num_stages
)
num_stages
=
num_stages
,
)
print
(
kernel
.
get_kernel_source
())
W_tilelang
,
U_tilelang
=
kernel
(
K
,
V
,
Beta
,
G
,
A
)
...
...
@@ -217,14 +205,15 @@ def main():
DK
=
128
,
DV
=
128
,
chunk_size
=
64
,
input_dtype
=
"
bfloat16
"
,
output_dtype
=
"
bfloat16
"
,
gate_dtype
=
"
float32
"
,
accum_dtype
=
"
float32
"
,
input_dtype
=
T
.
bfloat16
,
output_dtype
=
T
.
bfloat16
,
gate_dtype
=
T
.
float32
,
accum_dtype
=
T
.
float32
,
block_DK
=
64
,
block_DV
=
32
,
threads
=
128
,
num_stages
=
3
)
num_stages
=
3
,
)
if
__name__
==
"__main__"
:
...
...
examples/gdn/example_wy_fast_bwd_split.py
View file @
667632cc
...
...
@@ -10,6 +10,7 @@ import tilelang.language as T
# sys.path.insert(0, "/home/tzj/flash-linear-attention")
try
:
import
fla
print
(
fla
.
__file__
)
from
fla.ops.gated_delta_rule.wy_fast
import
bwd_prepare_wy_repr
except
ImportError
:
...
...
@@ -93,10 +94,8 @@ def prepare_output(
@
tilelang
.
jit
(
out_idx
=
[
-
5
,
-
4
,
-
3
,
-
2
,
-
1
],
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
})
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
},
)
def
tilelang_wy_fast_bwd
(
# task config
B
,
...
...
@@ -135,20 +134,20 @@ def tilelang_wy_fast_bwd(
@
T
.
prim_func
def
kernel
(
# input
K
:
T
.
Tensor
(
K_shape
,
dtype
=
input_dtype
),
V
:
T
.
Tensor
(
V_shape
,
dtype
=
input_dtype
),
Beta
:
T
.
Tensor
(
Beta_shape
,
dtype
=
input_dtype
),
G
:
T
.
Tensor
(
G_shape
,
dtype
=
gate_dtype
),
A
:
T
.
Tensor
(
A_shape
,
dtype
=
input_dtype
),
dw
:
T
.
Tensor
(
dw_shape
,
dtype
=
input_dtype
),
du
:
T
.
Tensor
(
du_shape
,
dtype
=
input_dtype
),
# output
dA
:
T
.
Tensor
(
dA_shape
,
dtype
=
input_dtype
),
dk
:
T
.
Tensor
(
dk_shape
,
dtype
=
output_dtype
),
dv
:
T
.
Tensor
(
dv_shape
,
dtype
=
output_dtype
),
dbeta
:
T
.
Tensor
(
dbeta_shape
,
dtype
=
output_dtype
),
dg
:
T
.
Tensor
(
dg_shape
,
dtype
=
gate_dtype
),
# input
K
:
T
.
Tensor
(
K_shape
,
dtype
=
input_dtype
),
V
:
T
.
Tensor
(
V_shape
,
dtype
=
input_dtype
),
Beta
:
T
.
Tensor
(
Beta_shape
,
dtype
=
input_dtype
),
G
:
T
.
Tensor
(
G_shape
,
dtype
=
gate_dtype
),
A
:
T
.
Tensor
(
A_shape
,
dtype
=
input_dtype
),
dw
:
T
.
Tensor
(
dw_shape
,
dtype
=
input_dtype
),
du
:
T
.
Tensor
(
du_shape
,
dtype
=
input_dtype
),
# output
dA
:
T
.
Tensor
(
dA_shape
,
dtype
=
input_dtype
),
dk
:
T
.
Tensor
(
dk_shape
,
dtype
=
output_dtype
),
dv
:
T
.
Tensor
(
dv_shape
,
dtype
=
output_dtype
),
dbeta
:
T
.
Tensor
(
dbeta_shape
,
dtype
=
output_dtype
),
dg
:
T
.
Tensor
(
dg_shape
,
dtype
=
gate_dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
S
,
block_S
),
B
*
H
,
threads
=
threads
)
as
(
bs
,
bbh
):
bb
,
bh
=
bbh
//
H
,
bbh
%
H
...
...
@@ -187,7 +186,7 @@ def tilelang_wy_fast_bwd(
T
.
clear
(
dbeta_fragment_v
)
T
.
clear
(
dg_fragment
)
T
.
copy
(
A
[
bb
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
,
bh
,
:],
A_shared
)
T
.
copy
(
A
[
bb
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
,
bh
,
:],
A_shared
)
for
i_s
in
T
.
Parallel
(
block_S
):
Beta_shared
[
i_s
]
=
Beta
[
bb
,
bs
*
block_S
+
i_s
,
bh
]
G_shared
[
i_s
]
=
G
[
bb
,
bs
*
block_S
+
i_s
,
bh
]
...
...
@@ -195,51 +194,37 @@ def tilelang_wy_fast_bwd(
# Update dk
for
i_k
in
T
.
Pipelined
(
T
.
ceildiv
(
DK
,
block_DK
),
num_stages
=
num_stages
):
T
.
copy
(
K
[
bb
,
bs
*
block_S
:(
bs
+
1
)
*
block_S
,
bh
,
i_k
*
block_DK
:(
i_k
+
1
)
*
block_DK
],
K_shared
)
T
.
copy
(
K
[
bb
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
,
bh
,
i_k
*
block_DK
:
(
i_k
+
1
)
*
block_DK
],
K_shared
)
for
i_s
,
i_k2
in
T
.
Parallel
(
block_S
,
block_DK
):
K_shared_beta_g
[
i_s
,
i_k2
]
=
K_shared
[
i_s
,
i_k2
]
*
Beta_shared
[
i_s
]
*
G_shared_exp
[
i_s
]
T
.
copy
(
dw
[
bb
,
bs
*
block_S
:(
bs
+
1
)
*
block_S
,
bh
,
i_k
*
block_DK
:(
i_k
+
1
)
*
block_DK
],
dw_shared
)
K_shared_beta_g
[
i_s
,
i_k2
]
=
K_shared
[
i_s
,
i_k2
]
*
Beta_shared
[
i_s
]
*
G_shared_exp
[
i_s
]
T
.
copy
(
dw
[
bb
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
,
bh
,
i_k
*
block_DK
:
(
i_k
+
1
)
*
block_DK
],
dw_shared
)
T
.
gemm
(
dw_shared
,
K_shared_beta_g
,
dA_fragment
,
transpose_B
=
True
)
T
.
gemm
(
A_shared
,
dw_shared
,
dk_fragment_beta_g
,
clear_accum
=
True
,
transpose_A
=
True
)
for
i_s
,
i_k2
in
T
.
Parallel
(
block_S
,
block_DK
):
dk_fragment
[
i_s
,
i_k2
]
=
dk_fragment_beta_g
[
i_s
,
i_k2
]
*
Beta_shared
[
i_s
]
*
G_shared_exp
[
i_s
]
dk_fragment
[
i_s
,
i_k2
]
=
dk_fragment_beta_g
[
i_s
,
i_k2
]
*
Beta_shared
[
i_s
]
*
G_shared_exp
[
i_s
]
# for i_s, i_k2 in T.Parallel(block_S, block_DK):
# dbeta_fragment[i_s] = dbeta_fragment[i_s] + dk_fragment_beta_g[i_s, i_k2] * K_shared[i_s, i_k2] * G_shared_exp[i_s]
for
i_s
,
i_k2
in
T
.
Parallel
(
block_S
,
block_DK
):
dbeta_fragment_reduce_tmpk
[
i_s
,
i_k2
]
=
dk_fragment_beta_g
[
i_s
,
i_k2
]
*
K_shared
[
i_s
,
i_k2
]
*
G_shared_exp
[
i_s
]
dbeta_fragment_reduce_tmpk
[
i_s
,
i_k2
]
=
dk_fragment_beta_g
[
i_s
,
i_k2
]
*
K_shared
[
i_s
,
i_k2
]
*
G_shared_exp
[
i_s
]
T
.
reduce_sum
(
dbeta_fragment_reduce_tmpk
,
dbeta_fragment_k
,
dim
=
1
,
clear
=
False
)
# for i_s, i_k2 in T.Parallel(block_S, block_DK):
# dg_fragment[i_s] = dg_fragment[i_s] + dk_fragment_beta_g[i_s, i_k2] * K_shared[i_s, i_k2] * G_shared_exp[i_s] * Beta_shared[i_s]
for
i_s
,
i_k2
in
T
.
Parallel
(
block_S
,
block_DK
):
dg_fragment_reduce_tmp
[
i_s
,
i_k2
]
=
dk_fragment_beta_g
[
i_s
,
i_k2
]
*
K_shared
[
i_s
,
i_k2
]
*
G_shared_exp
[
i_s
]
*
Beta_shared
[
i_s
]
dg_fragment_reduce_tmp
[
i_s
,
i_k2
]
=
(
dk_fragment_beta_g
[
i_s
,
i_k2
]
*
K_shared
[
i_s
,
i_k2
]
*
G_shared_exp
[
i_s
]
*
Beta_shared
[
i_s
]
)
T
.
reduce_sum
(
dg_fragment_reduce_tmp
,
dg_fragment
,
dim
=
1
,
clear
=
False
)
# correct dk
T
.
copy
(
dk_fragment
,
dk
[
bb
,
bs
*
block_S
:(
bs
+
1
)
*
block_S
,
bh
,
i_k
*
block_DK
:(
i_k
+
1
)
*
block_DK
])
T
.
copy
(
dk_fragment
,
dk
[
bb
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
,
bh
,
i_k
*
block_DK
:
(
i_k
+
1
)
*
block_DK
])
# Update dv
for
i_v
in
T
.
Pipelined
(
T
.
ceildiv
(
DV
,
block_DV
),
num_stages
=
num_stages
):
T
.
copy
(
V
[
bb
,
bs
*
block_S
:(
bs
+
1
)
*
block_S
,
bh
,
i_v
*
block_DV
:(
i_v
+
1
)
*
block_DV
],
V_shared
)
T
.
copy
(
V
[
bb
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
,
bh
,
i_v
*
block_DV
:
(
i_v
+
1
)
*
block_DV
],
V_shared
)
for
i_s
,
i_v2
in
T
.
Parallel
(
block_S
,
block_DV
):
V_shared_beta
[
i_s
,
i_v2
]
=
V_shared
[
i_s
,
i_v2
]
*
Beta_shared
[
i_s
]
T
.
copy
(
du
[
bb
,
bs
*
block_S
:(
bs
+
1
)
*
block_S
,
bh
,
i_v
*
block_DV
:(
i_v
+
1
)
*
block_DV
],
du_shared
)
T
.
copy
(
du
[
bb
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
,
bh
,
i_v
*
block_DV
:
(
i_v
+
1
)
*
block_DV
],
du_shared
)
T
.
gemm
(
du_shared
,
V_shared_beta
,
dA_fragment
,
transpose_B
=
True
)
T
.
gemm
(
A_shared
,
du_shared
,
dv_fragment_beta
,
clear_accum
=
True
,
transpose_A
=
True
)
for
i_s
,
i_v2
in
T
.
Parallel
(
block_S
,
block_DV
):
...
...
@@ -247,30 +232,22 @@ def tilelang_wy_fast_bwd(
# for i_s, i_v2 in T.Parallel(block_S, block_DV):
# dbeta_fragment[i_s] = dbeta_fragment[i_s] + dv_fragment_beta[i_s, i_v2] * V_shared[i_s, i_v2]
for
i_s
,
i_v2
in
T
.
Parallel
(
block_S
,
block_DV
):
dbeta_fragment_reduce_tmpv
[
i_s
,
i_v2
]
=
dv_fragment_beta
[
i_s
,
i_v2
]
*
V_shared
[
i_s
,
i_v2
]
dbeta_fragment_reduce_tmpv
[
i_s
,
i_v2
]
=
dv_fragment_beta
[
i_s
,
i_v2
]
*
V_shared
[
i_s
,
i_v2
]
T
.
reduce_sum
(
dbeta_fragment_reduce_tmpv
,
dbeta_fragment_v
,
dim
=
1
,
clear
=
False
)
T
.
copy
(
dv_fragment
,
dv
[
bb
,
bs
*
block_S
:(
bs
+
1
)
*
block_S
,
bh
,
i_v
*
block_DV
:(
i_v
+
1
)
*
block_DV
])
T
.
copy
(
dv_fragment
,
dv
[
bb
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
,
bh
,
i_v
*
block_DV
:
(
i_v
+
1
)
*
block_DV
])
# Temporary store dbeta, dg and dA
for
i_s
in
T
.
Parallel
(
block_S
):
dbeta
[
bb
,
bs
*
block_S
+
i_s
,
bh
]
=
dbeta_fragment_k
[
i_s
]
+
dbeta_fragment_v
[
i_s
]
dg
[
bb
,
bs
*
block_S
+
i_s
,
bh
]
=
dg_fragment
[
i_s
]
# correct dA
T
.
copy
(
dA_fragment
,
dA
[
bb
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
,
bh
,
:])
T
.
copy
(
dA_fragment
,
dA
[
bb
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
,
bh
,
:])
return
kernel
@
tilelang
.
jit
(
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
})
@
tilelang
.
jit
(
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
})
def
tilelang_wy_fast_bwd_split
(
# task config
B
,
...
...
@@ -308,20 +285,20 @@ def tilelang_wy_fast_bwd_split(
@
T
.
prim_func
def
kernel
(
# input
K
:
T
.
Tensor
(
K_shape
,
dtype
=
input_dtype
),
V
:
T
.
Tensor
(
V_shape
,
dtype
=
input_dtype
),
Beta
:
T
.
Tensor
(
Beta_shape
,
dtype
=
input_dtype
),
G
:
T
.
Tensor
(
G_shape
,
dtype
=
gate_dtype
),
A
:
T
.
Tensor
(
A_shape
,
dtype
=
input_dtype
),
dw
:
T
.
Tensor
(
dw_shape
,
dtype
=
input_dtype
),
du
:
T
.
Tensor
(
du_shape
,
dtype
=
input_dtype
),
dA
:
T
.
Tensor
(
dA_shape
,
dtype
=
input_dtype
),
dk
:
T
.
Tensor
(
dk_shape
,
dtype
=
output_dtype
),
dv
:
T
.
Tensor
(
dv_shape
,
dtype
=
output_dtype
),
dbeta_k
:
T
.
Tensor
(
dbeta_shape
,
dtype
=
output_dtype
),
dg_A_positive
:
T
.
Tensor
(
dA_shape
,
dtype
=
gate_dtype
),
dg_A_negative
:
T
.
Tensor
(
dA_shape
,
dtype
=
gate_dtype
),
# input
K
:
T
.
Tensor
(
K_shape
,
dtype
=
input_dtype
),
V
:
T
.
Tensor
(
V_shape
,
dtype
=
input_dtype
),
Beta
:
T
.
Tensor
(
Beta_shape
,
dtype
=
input_dtype
),
G
:
T
.
Tensor
(
G_shape
,
dtype
=
gate_dtype
),
A
:
T
.
Tensor
(
A_shape
,
dtype
=
input_dtype
),
dw
:
T
.
Tensor
(
dw_shape
,
dtype
=
input_dtype
),
du
:
T
.
Tensor
(
du_shape
,
dtype
=
input_dtype
),
dA
:
T
.
Tensor
(
dA_shape
,
dtype
=
input_dtype
),
dk
:
T
.
Tensor
(
dk_shape
,
dtype
=
output_dtype
),
dv
:
T
.
Tensor
(
dv_shape
,
dtype
=
output_dtype
),
dbeta_k
:
T
.
Tensor
(
dbeta_shape
,
dtype
=
output_dtype
),
dg_A_positive
:
T
.
Tensor
(
dA_shape
,
dtype
=
gate_dtype
),
dg_A_negative
:
T
.
Tensor
(
dA_shape
,
dtype
=
gate_dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
S
,
block_S
),
B
*
H
,
threads
=
threads
)
as
(
bs
,
bbh
):
bb
,
bh
=
bbh
//
H
,
bbh
%
H
...
...
@@ -350,7 +327,7 @@ def tilelang_wy_fast_bwd_split(
T
.
clear
(
dA_A_fragment_1
)
T
.
clear
(
dA_A_fragment_2
)
T
.
copy
(
A
[
bb
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
,
bh
,
:],
A_shared
)
T
.
copy
(
A
[
bb
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
,
bh
,
:],
A_shared
)
for
i_s
in
T
.
Parallel
(
block_S
):
Beta_shared
[
i_s
]
=
Beta
[
bb
,
bs
*
block_S
+
i_s
,
bh
]
G_shared
[
i_s
]
=
G
[
bb
,
bs
*
block_S
+
i_s
,
bh
]
...
...
@@ -361,7 +338,7 @@ def tilelang_wy_fast_bwd_split(
# for i_s in T.Parallel(block_S):
# dbeta_fragment[i_s] = dbeta[bb, bs * block_S + i_s, bh]
# dg_fragment[i_s] = dg[bb, bs * block_S + i_s, bh]
T
.
copy
(
dA
[
bb
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
,
bh
,
:],
dA_shared
)
T
.
copy
(
dA
[
bb
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
,
bh
,
:],
dA_shared
)
# T.copy(dA_shared, dA[bb, bs * block_S:(bs + 1) * block_S, bh, :])
# Update dA
...
...
@@ -385,8 +362,7 @@ def tilelang_wy_fast_bwd_split(
for
i_s1
,
i_s2
in
T
.
Parallel
(
block_S
,
block_S
):
with
T
.
If
(
G
[
bb
,
bs
*
block_S
+
i_s1
,
bh
]
-
G
[
bb
,
bs
*
block_S
+
i_s2
,
bh
]
<=
0
):
with
T
.
Then
():
dA_fragment
[
i_s1
,
i_s2
]
*=
T
.
exp
(
G
[
bb
,
bs
*
block_S
+
i_s1
,
bh
]
-
G
[
bb
,
bs
*
block_S
+
i_s2
,
bh
])
dA_fragment
[
i_s1
,
i_s2
]
*=
T
.
exp
(
G
[
bb
,
bs
*
block_S
+
i_s1
,
bh
]
-
G
[
bb
,
bs
*
block_S
+
i_s2
,
bh
])
with
T
.
Else
():
dA_fragment
[
i_s1
,
i_s2
]
=
0
T
.
copy
(
dA_fragment
,
dA_shared
)
...
...
@@ -397,12 +373,8 @@ def tilelang_wy_fast_bwd_split(
# Update dk using previous dk
T
.
clear
(
A_fragment
)
for
i_k
in
T
.
Pipelined
(
T
.
ceildiv
(
DK
,
block_DK
),
num_stages
=
num_stages
):
T
.
copy
(
K
[
bb
,
bs
*
block_S
:(
bs
+
1
)
*
block_S
,
bh
,
i_k
*
block_DK
:(
i_k
+
1
)
*
block_DK
],
K_shared
)
T
.
copy
(
dk
[
bb
,
bs
*
block_S
:(
bs
+
1
)
*
block_S
,
bh
,
i_k
*
block_DK
:(
i_k
+
1
)
*
block_DK
],
dk_shared
)
T
.
copy
(
K
[
bb
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
,
bh
,
i_k
*
block_DK
:
(
i_k
+
1
)
*
block_DK
],
K_shared
)
T
.
copy
(
dk
[
bb
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
,
bh
,
i_k
*
block_DK
:
(
i_k
+
1
)
*
block_DK
],
dk_shared
)
T
.
copy
(
dk_shared
,
dk_fragment
)
for
i_s
,
i_k2
in
T
.
Parallel
(
block_S
,
block_DK
):
K_shared_beta
[
i_s
,
i_k2
]
=
K_shared
[
i_s
,
i_k2
]
*
Beta_shared
[
i_s
]
...
...
@@ -411,18 +383,14 @@ def tilelang_wy_fast_bwd_split(
# for i_s, i_k2 in T.Parallel(block_S, block_DK):
# dbeta_fragment[i_s] = dbeta_fragment[i_s] + dk_fragment_beta[i_s, i_k2] * K_shared[i_s, i_k2]
for
i_s
,
i_k2
in
T
.
Parallel
(
block_S
,
block_DK
):
dbeta_fragment_reduce_tmpk
[
i_s
,
i_k2
]
=
dk_fragment_beta
[
i_s
,
i_k2
]
*
K_shared
[
i_s
,
i_k2
]
dbeta_fragment_reduce_tmpk
[
i_s
,
i_k2
]
=
dk_fragment_beta
[
i_s
,
i_k2
]
*
K_shared
[
i_s
,
i_k2
]
T
.
reduce_sum
(
dbeta_fragment_reduce_tmpk
,
dbeta_fragment_k
,
dim
=
1
,
clear
=
False
)
T
.
gemm
(
dA_shared
,
K_shared_beta
,
dk_fragment
,
transpose_A
=
True
)
for
i_s
,
i_k2
in
T
.
Parallel
(
block_S
,
block_DK
):
dk_shared_beta
[
i_s
,
i_k2
]
=
dk_fragment_beta
[
i_s
,
i_k2
]
*
Beta_shared
[
i_s
]
for
i_s
,
i_k2
in
T
.
Parallel
(
block_S
,
block_DK
):
dk_fragment
[
i_s
,
i_k2
]
=
dk_fragment
[
i_s
,
i_k2
]
+
dk_shared_beta
[
i_s
,
i_k2
]
T
.
copy
(
dk_fragment
,
dk
[
bb
,
bs
*
block_S
:(
bs
+
1
)
*
block_S
,
bh
,
i_k
*
block_DK
:(
i_k
+
1
)
*
block_DK
])
T
.
copy
(
dk_fragment
,
dk
[
bb
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
,
bh
,
i_k
*
block_DK
:
(
i_k
+
1
)
*
block_DK
])
# Update dg and dbeta
T
.
copy
(
A_fragment
,
A_shared
)
...
...
@@ -460,19 +428,25 @@ def run_test(
threads
=
128
,
num_stages
=
0
,
):
K
,
V
,
Beta
,
G
,
A
,
dw
,
du
=
prepare_input
(
B
,
S
,
H
,
DK
,
DV
,
chunk_size
,
getattr
(
torch
,
input_dtype
),
getattr
(
torch
,
output_dtype
),
getattr
(
torch
,
accum_dtype
),
getattr
(
torch
,
gate_dtype
),
getattr
(
torch
,
state_dtype
))
dk_ref
,
dv_ref
,
dbeta_ref
,
dg_ref
=
prepare_output
(
B
,
S
,
H
,
DK
,
DV
,
chunk_size
,
getattr
(
torch
,
output_dtype
),
getattr
(
torch
,
gate_dtype
),
getattr
(
torch
,
state_dtype
))
K
,
V
,
Beta
,
G
,
A
,
dw
,
du
=
prepare_input
(
B
,
S
,
H
,
DK
,
DV
,
chunk_size
,
getattr
(
torch
,
input_dtype
),
getattr
(
torch
,
output_dtype
),
getattr
(
torch
,
accum_dtype
),
getattr
(
torch
,
gate_dtype
),
getattr
(
torch
,
state_dtype
),
)
dk_ref
,
dv_ref
,
dbeta_ref
,
dg_ref
=
prepare_output
(
B
,
S
,
H
,
DK
,
DV
,
chunk_size
,
getattr
(
torch
,
output_dtype
),
getattr
(
torch
,
gate_dtype
),
getattr
(
torch
,
state_dtype
)
)
dk_tilelang
,
dv_tilelang
,
dbeta_tilelang
,
dg_tilelang
=
prepare_output
(
B
,
S
,
H
,
DK
,
DV
,
chunk_size
,
getattr
(
torch
,
output_dtype
),
getattr
(
torch
,
gate_dtype
),
getattr
(
torch
,
state_dtype
)
)
B
,
S
,
H
,
DK
,
DV
,
chunk_size
,
getattr
(
torch
,
output_dtype
),
getattr
(
torch
,
gate_dtype
),
getattr
(
torch
,
state_dtype
)
)
BS
=
chunk_size
dA_tilelang
=
torch
.
empty
(
B
,
S
,
H
,
BS
,
dtype
=
getattr
(
torch
,
input_dtype
)).
cuda
()
dbeta_tilelang_k
=
torch
.
empty
(
B
,
S
,
H
,
dtype
=
getattr
(
torch
,
output_dtype
)).
cuda
()
...
...
@@ -480,28 +454,55 @@ def run_test(
dg_tilelang_A_negative
=
torch
.
empty
(
B
,
S
,
H
,
BS
,
dtype
=
getattr
(
torch
,
gate_dtype
)).
cuda
()
# ref
dk_ref
,
dv_ref
,
dbeta_ref
,
dg_ref
=
bwd_prepare_wy_repr
(
K
,
V
,
G
,
Beta
,
A
,
dw
,
du
,
cu_seqlens
=
None
)
dk_ref
,
dv_ref
,
dbeta_ref
,
dg_ref
=
bwd_prepare_wy_repr
(
K
,
V
,
G
,
Beta
,
A
,
dw
,
du
,
cu_seqlens
=
None
)
# tilelang
kernel
=
tilelang_wy_fast_bwd
(
B
,
S
,
H
,
DK
,
DV
,
input_dtype
,
output_dtype
,
accum_dtype
,
gate_dtype
,
state_dtype
,
chunk_size
,
block_DK
,
block_DV
,
threads
,
num_stages
)
dA_tilelang
,
dk_tilelang
,
dv_tilelang
,
dbeta_tilelang
,
dg_tilelang
=
kernel
(
K
,
V
,
Beta
,
G
,
A
,
dw
,
du
)
kernel
=
tilelang_wy_fast_bwd
(
B
,
S
,
H
,
DK
,
DV
,
input_dtype
,
output_dtype
,
accum_dtype
,
gate_dtype
,
state_dtype
,
chunk_size
,
block_DK
,
block_DV
,
threads
,
num_stages
,
)
dA_tilelang
,
dk_tilelang
,
dv_tilelang
,
dbeta_tilelang
,
dg_tilelang
=
kernel
(
K
,
V
,
Beta
,
G
,
A
,
dw
,
du
)
torch
.
cuda
.
synchronize
()
kernel_split
=
tilelang_wy_fast_bwd_split
(
B
,
S
,
H
,
DK
,
DV
,
input_dtype
,
output_dtype
,
accum_dtype
,
gate_dtype
,
state_dtype
,
chunk_size
,
block_DK
,
block_DV
,
threads
,
num_stages
)
kernel_split
(
K
,
V
,
Beta
,
G
,
A
,
dw
,
du
,
dA_tilelang
,
dk_tilelang
,
dv_tilelang
,
dbeta_tilelang_k
,
dg_tilelang_A_positive
,
dg_tilelang_A_negative
)
kernel_split
=
tilelang_wy_fast_bwd_split
(
B
,
S
,
H
,
DK
,
DV
,
input_dtype
,
output_dtype
,
accum_dtype
,
gate_dtype
,
state_dtype
,
chunk_size
,
block_DK
,
block_DV
,
threads
,
num_stages
,
)
kernel_split
(
K
,
V
,
Beta
,
G
,
A
,
dw
,
du
,
dA_tilelang
,
dk_tilelang
,
dv_tilelang
,
dbeta_tilelang_k
,
dg_tilelang_A_positive
,
dg_tilelang_A_negative
)
torch
.
cuda
.
synchronize
()
dbeta_tilelang
=
dbeta_tilelang_k
+
dbeta_tilelang
dg_tilelang
=
dg_tilelang
+
dg_tilelang_A_positive
.
sum
(
dim
=-
1
)
-
dg_tilelang_A_negative
.
sum
(
dim
=-
1
)
dg_tilelang
=
dg_tilelang
+
dg_tilelang_A_positive
.
sum
(
dim
=-
1
)
-
dg_tilelang_A_negative
.
sum
(
dim
=-
1
)
from
test_utils
import
assert_similar
from
utils
import
assert_similar
assert_similar
(
dk_ref
,
dk_tilelang
,
eps
=
1e-5
,
name
=
"dk"
,
raise_assert
=
False
)
assert_similar
(
dv_ref
,
dv_tilelang
,
eps
=
1e-5
,
name
=
"dv"
,
raise_assert
=
False
)
assert_similar
(
dbeta_ref
,
dbeta_tilelang
,
eps
=
1e-5
,
name
=
"dbeta"
,
raise_assert
=
False
)
...
...
@@ -517,11 +518,11 @@ def main():
H
=
8
,
DK
=
DK
,
DV
=
DV
,
input_dtype
=
"
bfloat16
"
,
output_dtype
=
"
bfloat16
"
,
accum_dtype
=
"
float32
"
,
gate_dtype
=
"
float32
"
,
state_dtype
=
"
float32
"
,
input_dtype
=
T
.
bfloat16
,
output_dtype
=
T
.
bfloat16
,
accum_dtype
=
T
.
float32
,
gate_dtype
=
T
.
float32
,
state_dtype
=
T
.
float32
,
chunk_size
=
64
,
block_DK
=
32
,
block_DV
=
32
,
...
...
examples/gdn/test_example_gdn_compilation.py
View file @
667632cc
import
tilelang.testing
import
torch
import
tilelang.testing
from
tilelang
import
language
as
T
B
=
1
S
=
1024
# small but for test only.
H
=
32
DK
=
128
DV
=
128
input_dtype
=
"
bfloat16
"
output_dtype
=
"
bfloat16
"
accum_dtype
=
"
float32
"
gate_dtype
=
"
float32
"
state_dtype
=
"
float32
"
input_dtype
=
T
.
bfloat16
output_dtype
=
T
.
bfloat16
accum_dtype
=
T
.
float32
gate_dtype
=
T
.
float32
state_dtype
=
T
.
float32
chunk_size
=
64
use_g
=
True
use_initial_state
=
True
...
...
@@ -25,16 +26,10 @@ num_stages = 1
def
test_example_wy_fast_compilation
():
from
example_wy_fast
import
tilelang_recompute_w_u_fwd
,
prepare_input
K
,
V
,
Beta
,
G
,
A
=
prepare_input
(
B
,
S
,
H
,
DK
,
DV
,
chunk_size
,
getattr
(
torch
,
input_dtype
),
getattr
(
torch
,
output_dtype
),
gate_dtype
=
getattr
(
torch
,
gate_dtype
))
B
,
S
,
H
,
DK
,
DV
,
chunk_size
,
getattr
(
torch
,
input_dtype
),
getattr
(
torch
,
output_dtype
),
gate_dtype
=
getattr
(
torch
,
gate_dtype
)
)
# tilelang
block_S
=
chunk_size
kernel
=
tilelang_recompute_w_u_fwd
(
...
...
@@ -52,22 +47,31 @@ def test_example_wy_fast_compilation():
block_DK
=
block_DK
,
block_DV
=
block_DV
,
threads
=
threads
,
num_stages
=
num_stages
)
num_stages
=
num_stages
,
)
print
(
kernel
.
get_kernel_source
())
W_tilelang
,
U_tilelang
=
kernel
(
K
,
V
,
Beta
,
G
,
A
)
def
test_example_wy_fast_bwd_split_compilation
():
from
example_wy_fast_bwd_split
import
tilelang_wy_fast_bwd
,
tilelang_wy_fast_bwd_split
,
prepare_input
,
prepare_output
K
,
V
,
Beta
,
G
,
A
,
dw
,
du
=
prepare_input
(
B
,
S
,
H
,
DK
,
DV
,
chunk_size
,
getattr
(
torch
,
input_dtype
),
getattr
(
torch
,
output_dtype
),
getattr
(
torch
,
accum_dtype
),
getattr
(
torch
,
gate_dtype
),
getattr
(
torch
,
state_dtype
))
K
,
V
,
Beta
,
G
,
A
,
dw
,
du
=
prepare_input
(
B
,
S
,
H
,
DK
,
DV
,
chunk_size
,
getattr
(
torch
,
input_dtype
),
getattr
(
torch
,
output_dtype
),
getattr
(
torch
,
accum_dtype
),
getattr
(
torch
,
gate_dtype
),
getattr
(
torch
,
state_dtype
),
)
dk_tilelang
,
dv_tilelang
,
dbeta_tilelang
,
dg_tilelang
=
prepare_output
(
B
,
S
,
H
,
DK
,
DV
,
chunk_size
,
getattr
(
torch
,
output_dtype
),
getattr
(
torch
,
gate_dtype
),
getattr
(
torch
,
state_dtype
)
)
B
,
S
,
H
,
DK
,
DV
,
chunk_size
,
getattr
(
torch
,
output_dtype
),
getattr
(
torch
,
gate_dtype
),
getattr
(
torch
,
state_dtype
)
)
BS
=
chunk_size
dA_tilelang
=
torch
.
empty
(
B
,
S
,
H
,
BS
,
dtype
=
getattr
(
torch
,
input_dtype
)).
cuda
()
dbeta_tilelang_k
=
torch
.
empty
(
B
,
S
,
H
,
dtype
=
getattr
(
torch
,
output_dtype
)).
cuda
()
...
...
@@ -75,67 +79,146 @@ def test_example_wy_fast_bwd_split_compilation():
dg_tilelang_A_negative
=
torch
.
empty
(
B
,
S
,
H
,
BS
,
dtype
=
getattr
(
torch
,
gate_dtype
)).
cuda
()
# tilelang
kernel
=
tilelang_wy_fast_bwd
(
B
,
S
,
H
,
DK
,
DV
,
input_dtype
,
output_dtype
,
accum_dtype
,
gate_dtype
,
state_dtype
,
chunk_size
,
block_DK
,
block_DV
,
threads
,
num_stages
)
dA_tilelang
,
dk_tilelang
,
dv_tilelang
,
dbeta_tilelang
,
dg_tilelang
=
kernel
(
K
,
V
,
Beta
,
G
,
A
,
dw
,
du
)
kernel
=
tilelang_wy_fast_bwd
(
B
,
S
,
H
,
DK
,
DV
,
input_dtype
,
output_dtype
,
accum_dtype
,
gate_dtype
,
state_dtype
,
chunk_size
,
block_DK
,
block_DV
,
threads
,
num_stages
,
)
dA_tilelang
,
dk_tilelang
,
dv_tilelang
,
dbeta_tilelang
,
dg_tilelang
=
kernel
(
K
,
V
,
Beta
,
G
,
A
,
dw
,
du
)
torch
.
cuda
.
synchronize
()
kernel_split
=
tilelang_wy_fast_bwd_split
(
B
,
S
,
H
,
DK
,
DV
,
input_dtype
,
output_dtype
,
accum_dtype
,
gate_dtype
,
state_dtype
,
chunk_size
,
block_DK
,
block_DV
,
threads
,
num_stages
)
kernel_split
(
K
,
V
,
Beta
,
G
,
A
,
dw
,
du
,
dA_tilelang
,
dk_tilelang
,
dv_tilelang
,
dbeta_tilelang_k
,
dg_tilelang_A_positive
,
dg_tilelang_A_negative
)
kernel_split
=
tilelang_wy_fast_bwd_split
(
B
,
S
,
H
,
DK
,
DV
,
input_dtype
,
output_dtype
,
accum_dtype
,
gate_dtype
,
state_dtype
,
chunk_size
,
block_DK
,
block_DV
,
threads
,
num_stages
,
)
kernel_split
(
K
,
V
,
Beta
,
G
,
A
,
dw
,
du
,
dA_tilelang
,
dk_tilelang
,
dv_tilelang
,
dbeta_tilelang_k
,
dg_tilelang_A_positive
,
dg_tilelang_A_negative
)
torch
.
cuda
.
synchronize
()
dbeta_tilelang
=
dbeta_tilelang_k
+
dbeta_tilelang
dg_tilelang
=
dg_tilelang
+
dg_tilelang_A_positive
.
sum
(
dim
=-
1
)
-
dg_tilelang_A_negative
.
sum
(
dim
=-
1
)
dg_tilelang
=
dg_tilelang
+
dg_tilelang_A_positive
.
sum
(
dim
=-
1
)
-
dg_tilelang_A_negative
.
sum
(
dim
=-
1
)
def
test_example_chunk_o_compilation
():
from
example_chunk_o
import
tilelang_chunk_fwd_o
,
prepare_input
Q
,
K
,
V
,
HIDDEN
,
G
=
prepare_input
(
B
,
S
,
H
,
DK
,
DV
,
chunk_size
,
getattr
(
torch
,
input_dtype
),
getattr
(
torch
,
output_dtype
),
getattr
(
torch
,
accum_dtype
),
getattr
(
torch
,
gate_dtype
))
Q
,
K
,
V
,
HIDDEN
,
G
=
prepare_input
(
B
,
S
,
H
,
DK
,
DV
,
chunk_size
,
getattr
(
torch
,
input_dtype
),
getattr
(
torch
,
output_dtype
),
getattr
(
torch
,
accum_dtype
),
getattr
(
torch
,
gate_dtype
),
)
scale
=
1.0
/
DK
**
0.5
block_S
=
chunk_size
kernel
=
tilelang_chunk_fwd_o
(
B
,
S
,
H
,
DK
,
DV
,
input_dtype
,
output_dtype
,
accum_dtype
,
gate_dtype
,
chunk_size
,
scale
,
use_g
,
block_S
,
block_DK
,
block_DV
,
threads
,
num_stages
)
kernel
=
tilelang_chunk_fwd_o
(
B
,
S
,
H
,
DK
,
DV
,
input_dtype
,
output_dtype
,
accum_dtype
,
gate_dtype
,
chunk_size
,
scale
,
use_g
,
block_S
,
block_DK
,
block_DV
,
threads
,
num_stages
,
)
O_tilelang
=
kernel
(
Q
,
K
,
V
,
HIDDEN
,
G
)
# noqa: F841
def
test_example_chunk_o_bwd_compilation
():
from
example_chunk_o_bwd
import
tilelang_chunk_o_bwd_dqkwg
,
prepare_input
Q
,
K
,
V
,
h
,
G
,
dO
,
dh
,
dv
,
W
=
prepare_input
(
B
,
S
,
H
,
DK
,
DV
,
chunk_size
,
getattr
(
torch
,
input_dtype
),
getattr
(
torch
,
output_dtype
),
getattr
(
torch
,
accum_dtype
),
getattr
(
torch
,
gate_dtype
),
getattr
(
torch
,
state_dtype
))
kernel
=
tilelang_chunk_o_bwd_dqkwg
(
B
,
S
,
H
,
DK
,
DV
,
input_dtype
,
output_dtype
,
accum_dtype
,
gate_dtype
,
state_dtype
,
chunk_size
,
1.0
,
use_g
,
True
,
block_DK
,
block_DV
,
threads
,
num_stages
)
dq_tilelang
,
dk_tilelang
,
dw_tilelang
,
dg_tilelang
=
kernel
(
Q
,
K
,
V
,
h
,
G
,
dO
,
dh
,
dv
,
W
)
# noqa: F841
Q
,
K
,
V
,
h
,
G
,
dO
,
dh
,
dv
,
W
=
prepare_input
(
B
,
S
,
H
,
DK
,
DV
,
chunk_size
,
getattr
(
torch
,
input_dtype
),
getattr
(
torch
,
output_dtype
),
getattr
(
torch
,
accum_dtype
),
getattr
(
torch
,
gate_dtype
),
getattr
(
torch
,
state_dtype
),
)
kernel
=
tilelang_chunk_o_bwd_dqkwg
(
B
,
S
,
H
,
DK
,
DV
,
input_dtype
,
output_dtype
,
accum_dtype
,
gate_dtype
,
state_dtype
,
chunk_size
,
1.0
,
use_g
,
True
,
block_DK
,
block_DV
,
threads
,
num_stages
,
)
dq_tilelang
,
dk_tilelang
,
dw_tilelang
,
dg_tilelang
=
kernel
(
Q
,
K
,
V
,
h
,
G
,
dO
,
dh
,
dv
,
W
)
# noqa: F841
if
use_g
:
dg_tilelang
=
dg_tilelang
.
sum
(
dim
=
0
)
def
test_example_chunk_scaled_dot_kkt_compilation
():
from
example_chunk_scaled_dot_kkt
import
tilelang_chunk_scaled_dot_kkt_fwd
,
prepare_input
K
,
Beta
,
G
=
prepare_input
(
B
,
S
,
H
,
DK
,
getattr
(
torch
,
input_dtype
),
getattr
(
torch
,
output_dtype
),
getattr
(
torch
,
accum_dtype
))
K
,
Beta
,
G
=
prepare_input
(
B
,
S
,
H
,
DK
,
getattr
(
torch
,
input_dtype
),
getattr
(
torch
,
output_dtype
),
getattr
(
torch
,
accum_dtype
))
block_S
=
chunk_size
kernel
=
tilelang_chunk_scaled_dot_kkt_fwd
(
B
,
S
,
H
,
DK
,
chunk_size
,
input_dtype
,
output_dtype
,
accum_dtype
,
use_g
,
block_S
,
block_DK
,
threads
,
num_stages
)
kernel
=
tilelang_chunk_scaled_dot_kkt_fwd
(
B
,
S
,
H
,
DK
,
chunk_size
,
input_dtype
,
output_dtype
,
accum_dtype
,
use_g
,
block_S
,
block_DK
,
threads
,
num_stages
)
A_tilelang
=
kernel
(
K
,
Beta
,
G
)
# noqa: F841
def
test_example_cumsum_compilation
():
from
example_cumsum
import
tilelang_chunk_local_cumsum_scalar
,
prepare_cumsum_input
,
prepare_cumsum_output
G
=
prepare_cumsum_input
(
B
,
S
,
H
,
getattr
(
torch
,
gate_dtype
))
G_new_tilelang
=
prepare_cumsum_output
(
B
,
S
,
H
,
getattr
(
torch
,
gate_dtype
))
block_S
=
chunk_size
...
...
@@ -157,33 +240,79 @@ def test_example_cumsum_compilation():
def
test_example_chunk_delta_h_compilation
():
from
example_chunk_delta_h
import
tilelang_chunk_gated_delta_rule_fwd_h
,
prepare_input
K
,
W
,
U
,
G
,
initial_state
=
prepare_input
(
B
,
S
,
H
,
DK
,
DV
,
chunk_size
,
getattr
(
torch
,
input_dtype
),
getattr
(
torch
,
output_dtype
),
getattr
(
torch
,
accum_dtype
),
getattr
(
torch
,
gate_dtype
))
kernel
=
tilelang_chunk_gated_delta_rule_fwd_h
(
B
,
S
,
H
,
DK
,
DV
,
input_dtype
,
output_dtype
,
accum_dtype
,
gate_dtype
,
state_dtype
,
chunk_size
,
use_g
,
use_initial_state
,
store_final_state
,
save_new_value
,
block_DK
,
block_DV
,
threads
,
num_stages
)
h_tilelang
,
final_state_tilelang
,
V_new_tilelang
=
kernel
(
K
,
W
,
U
,
G
,
initial_state
)
# noqa: F841
K
,
W
,
U
,
G
,
initial_state
=
prepare_input
(
B
,
S
,
H
,
DK
,
DV
,
chunk_size
,
getattr
(
torch
,
input_dtype
),
getattr
(
torch
,
output_dtype
),
getattr
(
torch
,
accum_dtype
),
getattr
(
torch
,
gate_dtype
),
)
kernel
=
tilelang_chunk_gated_delta_rule_fwd_h
(
B
,
S
,
H
,
DK
,
DV
,
input_dtype
,
output_dtype
,
accum_dtype
,
gate_dtype
,
state_dtype
,
chunk_size
,
use_g
,
use_initial_state
,
store_final_state
,
save_new_value
,
block_DK
,
block_DV
,
threads
,
num_stages
,
)
h_tilelang
,
final_state_tilelang
,
V_new_tilelang
=
kernel
(
K
,
W
,
U
,
G
,
initial_state
)
# noqa: F841
def
test_example_chunk_delta_bwd_compilation
():
from
example_chunk_delta_bwd
import
tilelang_chunk_gated_delta_rule_bwd_dhu
,
prepare_input
Q
,
K
,
W
,
G
,
h0
,
dht
,
dO
,
dv
=
prepare_input
(
B
,
S
,
H
,
DK
,
DV
,
chunk_size
,
getattr
(
torch
,
input_dtype
),
getattr
(
torch
,
output_dtype
),
getattr
(
torch
,
accum_dtype
),
getattr
(
torch
,
gate_dtype
),
getattr
(
torch
,
state_dtype
))
kernel
=
tilelang_chunk_gated_delta_rule_bwd_dhu
(
B
,
S
,
H
,
DK
,
DV
,
input_dtype
,
output_dtype
,
accum_dtype
,
gate_dtype
,
state_dtype
,
chunk_size
,
1.0
,
use_g
,
use_initial_state
,
use_final_state_gradient
,
block_DV
,
threads
,
num_stages
)
Q
,
K
,
W
,
G
,
h0
,
dht
,
dO
,
dv
=
prepare_input
(
B
,
S
,
H
,
DK
,
DV
,
chunk_size
,
getattr
(
torch
,
input_dtype
),
getattr
(
torch
,
output_dtype
),
getattr
(
torch
,
accum_dtype
),
getattr
(
torch
,
gate_dtype
),
getattr
(
torch
,
state_dtype
),
)
kernel
=
tilelang_chunk_gated_delta_rule_bwd_dhu
(
B
,
S
,
H
,
DK
,
DV
,
input_dtype
,
output_dtype
,
accum_dtype
,
gate_dtype
,
state_dtype
,
chunk_size
,
1.0
,
use_g
,
use_initial_state
,
use_final_state_gradient
,
block_DV
,
threads
,
num_stages
,
)
dh_tilelang
,
dh0_tilelang
,
dv2_tilelang
=
kernel
(
Q
,
K
,
W
,
G
,
h0
,
dht
,
dO
,
dv
)
# noqa: F841
...
...
examples/gdn/utils.py
→
examples/gdn/
test_
utils.py
View file @
667632cc
...
...
@@ -9,7 +9,7 @@ def calc_sim(x, y, name="tensor"):
x
,
y
=
x
.
data
.
double
(),
y
.
data
.
double
()
denominator
=
(
x
*
x
+
y
*
y
).
sum
()
if
denominator
==
0
:
print_red_warning
(
f
'
{
name
}
all zero
'
)
print_red_warning
(
f
"
{
name
}
all zero
"
)
return
1
sim
=
2
*
(
x
*
y
).
sum
()
/
denominator
return
sim
...
...
@@ -19,21 +19,19 @@ def assert_similar(x, y, eps=1e-8, name="tensor", data="", raise_assert=True):
x_mask
=
torch
.
isfinite
(
x
)
y_mask
=
torch
.
isfinite
(
y
)
if
not
torch
.
all
(
x_mask
==
y_mask
):
print_red_warning
(
f
'
{
name
}
Error: isfinite mask mismatch
'
)
print_red_warning
(
f
"
{
name
}
Error: isfinite mask mismatch
"
)
if
raise_assert
:
raise
AssertionError
if
not
torch
.
isclose
(
x
.
masked_fill
(
x_mask
,
0
),
y
.
masked_fill
(
y_mask
,
0
),
rtol
=
0
,
atol
=
0
,
equal_nan
=
True
).
all
():
print_red_warning
(
f
'
{
name
}
Error: nonfinite value mismatch'
)
if
not
torch
.
isclose
(
x
.
masked_fill
(
x_mask
,
0
),
y
.
masked_fill
(
y_mask
,
0
),
rtol
=
0
,
atol
=
0
,
equal_nan
=
True
).
all
():
print_red_warning
(
f
"
{
name
}
Error: nonfinite value mismatch"
)
if
raise_assert
:
raise
AssertionError
x
=
x
.
masked_fill
(
~
x_mask
,
0
)
y
=
y
.
masked_fill
(
~
y_mask
,
0
)
sim
=
calc_sim
(
x
,
y
,
name
)
diff
=
1.
-
sim
diff
=
1.
0
-
sim
if
not
(
0
<=
diff
<=
eps
):
print_red_warning
(
f
'
{
name
}
Error:
{
diff
}
'
)
print_red_warning
(
f
"
{
name
}
Error:
{
diff
}
"
)
if
raise_assert
:
raise
AssertionError
else
:
...
...
examples/gemm/README.md
View file @
667632cc
...
...
@@ -53,7 +53,7 @@ import tilelang
from
tilelang
import
Profiler
import
tilelang.language
as
T
def
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
=
"
float16
"
,
accum_dtype
=
"
float
"
):
def
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
=
T
.
float16
,
accum_dtype
=
T
.
float
):
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
...
...
@@ -176,7 +176,7 @@ import tilelang.language as T
# that helps align data for MMA (Matrix Multiply-Accumulate) operations.
from
tilelang.intrinsics
import
make_mma_swizzle_layout
as
make_swizzle_layout
def
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
=
"
float16
"
,
accum_dtype
=
"
float
"
):
def
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
=
T
.
float16
,
accum_dtype
=
T
.
float
):
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
...
...
@@ -265,18 +265,18 @@ def tl_matmul(
accum_dtype
,
):
assert
in_dtype
in
[
"
float16
"
,
"
int8
"
,
T
.
float16
,
T
.
int8
,
],
"Currently only float16 and int8 are supported"
assert
out_dtype
in
[
"
float16
"
,
"
float32
"
,
"
int32
"
,
T
.
float16
,
T
.
float32
,
T
.
int32
,
],
"Currently only float16, float32 and int32 are supported"
micro_size_x
=
micro_size_y
=
micro_size_k
=
16
if
out_dtype
==
"
int32
"
:
if
out_dtype
==
T
.
int32
:
micro_size_k
=
32
# This is a debug config
...
...
examples/gemm/example_gemm.py
View file @
667632cc
...
...
@@ -3,13 +3,12 @@ import tilelang.language as T
@
tilelang
.
jit
(
out_idx
=
[
-
1
])
def
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
def
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
=
T
.
float16
,
accum_dtype
=
T
.
float32
):
@
T
.
prim_func
def
gemm
(
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
128
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
dtype
)
...
...
examples/gemm/example_gemm_autotune.py
View file @
667632cc
...
...
@@ -51,9 +51,9 @@ def get_configs(M, N, K, with_roller=False, topk=20):
M
=
M
,
N
=
N
,
K
=
K
,
in_dtype
=
"
float16
"
,
out_dtype
=
"
float16
"
,
accum_dtype
=
"
float
"
,
in_dtype
=
T
.
float16
,
out_dtype
=
T
.
float16
,
accum_dtype
=
T
.
float
32
,
).
with_arch
(
arch
)
func
=
carve_template
.
equivalent_function
()
...
...
@@ -90,7 +90,8 @@ def get_configs(M, N, K, with_roller=False, topk=20):
num_stages
,
thread_num
,
enable_rasterization
,
))
)
)
configs
=
[
{
...
...
@@ -100,13 +101,13 @@ def get_configs(M, N, K, with_roller=False, topk=20):
"num_stages"
:
c
[
3
],
"thread_num"
:
c
[
4
],
"enable_rasteration"
:
c
[
5
],
# keep param name for backward-compat
}
for
c
in
_configs
}
for
c
in
_configs
]
return
configs
def
get_best_config
(
M
,
N
,
K
,
with_roller
=
False
):
def
kernel
(
block_M
=
None
,
block_N
=
None
,
...
...
@@ -115,17 +116,16 @@ def get_best_config(M, N, K, with_roller=False):
thread_num
=
None
,
enable_rasteration
=
None
,
):
dtype
=
"
bfloat16
"
accum_dtype
=
"
float
"
dtype
=
T
.
bfloat16
accum_dtype
=
T
.
float
32
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
B
:
T
.
Tensor
((
N
,
K
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
B
:
T
.
Tensor
((
N
,
K
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
thread_num
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
thread_num
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
dtype
)
B_shared
=
T
.
alloc_shared
((
block_N
,
block_K
),
dtype
)
C_local
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
...
...
@@ -146,15 +146,18 @@ def get_best_config(M, N, K, with_roller=False):
return
main
autotuner
=
AutoTuner
.
from_kernel
(
kernel
=
kernel
,
configs
=
get_configs
(
M
,
N
,
K
,
with_roller
)).
set_compile_args
(
autotuner
=
(
AutoTuner
.
from_kernel
(
kernel
=
kernel
,
configs
=
get_configs
(
M
,
N
,
K
,
with_roller
))
.
set_compile_args
(
out_idx
=
[
-
1
],
target
=
"auto"
,
).
set_profile_args
(
)
.
set_profile_args
(
supply_type
=
tl
.
TensorSupplyType
.
Integer
,
ref_prog
=
ref_program
,
skip_check
=
False
,
)
)
return
autotuner
.
run
(
warmup
=
3
,
rep
=
20
)
...
...
@@ -167,52 +170,20 @@ def get_heuristic_config() -> dict:
sm_version
=
sm_major
*
10
+
sm_minor
print
(
f
"CUDA device capability:
{
sm_version
}
"
)
if
sm_version
in
{
80
}:
return
{
"block_M"
:
128
,
"block_N"
:
256
,
"block_K"
:
32
,
"num_stages"
:
2
,
"thread_num"
:
128
,
"enable_rasteration"
:
True
}
return
{
"block_M"
:
128
,
"block_N"
:
256
,
"block_K"
:
32
,
"num_stages"
:
2
,
"thread_num"
:
128
,
"enable_rasteration"
:
True
}
elif
sm_version
in
{
90
}:
return
{
"block_M"
:
128
,
"block_N"
:
256
,
"block_K"
:
64
,
"num_stages"
:
3
,
"thread_num"
:
256
,
"enable_rasteration"
:
True
}
return
{
"block_M"
:
128
,
"block_N"
:
256
,
"block_K"
:
64
,
"num_stages"
:
3
,
"thread_num"
:
256
,
"enable_rasteration"
:
True
}
else
:
return
{
"block_M"
:
128
,
"block_N"
:
256
,
"block_K"
:
32
,
"num_stages"
:
0
,
"thread_num"
:
128
,
"enable_rasteration"
:
True
}
return
{
"block_M"
:
128
,
"block_N"
:
256
,
"block_K"
:
32
,
"num_stages"
:
0
,
"thread_num"
:
128
,
"enable_rasteration"
:
True
}
@
tl
.
jit
(
out_idx
=
[
-
1
])
def
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
num_stages
,
thread_num
,
enable_rasteration
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
def
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
num_stages
,
thread_num
,
enable_rasteration
,
dtype
=
T
.
float16
,
accum_dtype
=
T
.
float32
):
@
T
.
prim_func
def
gemm_autotune
(
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
B
:
T
.
Tensor
((
N
,
K
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
B
:
T
.
Tensor
((
N
,
K
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
thread_num
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
dtype
)
...
...
@@ -236,11 +207,7 @@ def matmul(M,
return
gemm_autotune
def
main
(
M
:
int
=
4096
,
N
:
int
=
4096
,
K
:
int
=
4096
,
use_autotune
:
bool
=
False
,
with_roller
:
bool
=
False
):
def
main
(
M
:
int
=
4096
,
N
:
int
=
4096
,
K
:
int
=
4096
,
use_autotune
:
bool
=
False
,
with_roller
:
bool
=
False
):
use_autotune
=
True
if
use_autotune
:
result
=
get_best_config
(
M
,
N
,
K
,
with_roller
)
...
...
@@ -266,15 +233,7 @@ if __name__ == "__main__":
parser
.
add_argument
(
"--m"
,
type
=
int
,
default
=
4096
,
help
=
"Matrix dimension M"
)
parser
.
add_argument
(
"--n"
,
type
=
int
,
default
=
4096
,
help
=
"Matrix dimension N"
)
parser
.
add_argument
(
"--k"
,
type
=
int
,
default
=
4096
,
help
=
"Matrix dimension K"
)
parser
.
add_argument
(
"--use_autotune"
,
action
=
"store_true"
,
default
=
False
,
help
=
"Whether to use autotune for matmul configs"
)
parser
.
add_argument
(
"--with_roller"
,
action
=
"store_true"
,
default
=
False
,
help
=
"Whether to enable BitBLAS roller for search space"
)
parser
.
add_argument
(
"--use_autotune"
,
action
=
"store_true"
,
default
=
False
,
help
=
"Whether to use autotune for matmul configs"
)
parser
.
add_argument
(
"--with_roller"
,
action
=
"store_true"
,
default
=
False
,
help
=
"Whether to enable BitBLAS roller for search space"
)
args
=
parser
.
parse_args
()
main
(
args
.
m
,
args
.
n
,
args
.
k
,
args
.
use_autotune
,
args
.
with_roller
)
Prev
1
…
5
6
7
8
9
10
11
12
13
…
16
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment