Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
tilelang
Commits
667632cc
Unverified
Commit
667632cc
authored
Dec 22, 2025
by
guchaoyang
Committed by
GitHub
Dec 22, 2025
Browse files
Merge branch 'main' into dcu
parents
d6dd2ddf
a874e4e8
Changes
313
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2093 additions
and
1382 deletions
+2093
-1382
examples/flash_decoding/example_gqa_decode.py
examples/flash_decoding/example_gqa_decode.py
+135
-142
examples/flash_decoding/example_gqa_decode_varlen_logits.py
examples/flash_decoding/example_gqa_decode_varlen_logits.py
+130
-181
examples/flash_decoding/example_gqa_decode_varlen_logits_paged.py
.../flash_decoding/example_gqa_decode_varlen_logits_paged.py
+679
-0
examples/flash_decoding/example_mha_inference.py
examples/flash_decoding/example_mha_inference.py
+73
-78
examples/fusedmoe/example_fusedmoe_tilelang.py
examples/fusedmoe/example_fusedmoe_tilelang.py
+166
-197
examples/fusedmoe/example_fusedmoe_torch.py
examples/fusedmoe/example_fusedmoe_torch.py
+37
-54
examples/fusedmoe/test_example_fusedmoe.py
examples/fusedmoe/test_example_fusedmoe.py
+2
-7
examples/gdn/example_chunk_delta_bwd.py
examples/gdn/example_chunk_delta_bwd.py
+140
-102
examples/gdn/example_chunk_delta_h.py
examples/gdn/example_chunk_delta_h.py
+123
-81
examples/gdn/example_chunk_o.py
examples/gdn/example_chunk_o.py
+50
-41
examples/gdn/example_chunk_o_bwd.py
examples/gdn/example_chunk_o_bwd.py
+99
-109
examples/gdn/example_chunk_scaled_dot_kkt.py
examples/gdn/example_chunk_scaled_dot_kkt.py
+28
-30
examples/gdn/example_cumsum.py
examples/gdn/example_cumsum.py
+20
-24
examples/gdn/example_wy_fast.py
examples/gdn/example_wy_fast.py
+35
-46
examples/gdn/example_wy_fast_bwd_split.py
examples/gdn/example_wy_fast_bwd_split.py
+119
-118
examples/gdn/test_example_gdn_compilation.py
examples/gdn/test_example_gdn_compilation.py
+210
-81
examples/gdn/test_utils.py
examples/gdn/test_utils.py
+6
-8
examples/gemm/README.md
examples/gemm/README.md
+8
-8
examples/gemm/example_gemm.py
examples/gemm/example_gemm.py
+4
-5
examples/gemm/example_gemm_autotune.py
examples/gemm/example_gemm_autotune.py
+29
-70
No files found.
Too many changes to show.
To preserve performance only
313 of 313+
files are displayed.
Plain diff
Email patch
examples/flash_decoding/example_gqa_decode.py
View file @
667632cc
...
@@ -15,18 +15,12 @@ torch.random.manual_seed(0)
...
@@ -15,18 +15,12 @@ torch.random.manual_seed(0)
def
get_configs
():
def
get_configs
():
block_N
=
[
64
,
128
]
block_N
=
[
64
,
128
]
block_H
=
[
64
]
block_H
=
[
64
]
num_split
=
[
2
,
4
,
8
]
num_split
=
[
1
,
2
,
4
,
8
]
num_stages
=
[
1
,
2
,
3
]
num_stages
=
[
1
,
2
,
3
]
threads
=
[
128
]
threads
=
[
128
]
_configs
=
list
(
itertools
.
product
(
block_N
,
block_H
,
num_split
,
num_stages
,
threads
))
_configs
=
list
(
itertools
.
product
(
block_N
,
block_H
,
num_split
,
num_stages
,
threads
))
configs
=
[{
configs
=
[{
"block_N"
:
c
[
0
],
"block_H"
:
c
[
1
],
"num_split"
:
c
[
2
],
"num_stages"
:
c
[
3
],
"threads"
:
c
[
4
]}
for
c
in
_configs
]
'block_N'
:
c
[
0
],
'block_H'
:
c
[
1
],
'num_split'
:
c
[
2
],
'num_stages'
:
c
[
3
],
'threads'
:
c
[
4
]
}
for
c
in
_configs
]
return
configs
return
configs
...
@@ -42,29 +36,25 @@ def get_heuristic_config() -> Tuple[Dict, int]:
...
@@ -42,29 +36,25 @@ def get_heuristic_config() -> Tuple[Dict, int]:
if
sm_version
==
89
:
if
sm_version
==
89
:
cfg
=
dict
(
block_N
=
128
,
block_H
=
64
,
num_split
=
1
,
num_stages
=
0
,
threads
=
128
)
cfg
=
dict
(
block_N
=
128
,
block_H
=
64
,
num_split
=
1
,
num_stages
=
0
,
threads
=
128
)
else
:
else
:
cfg
=
dict
(
block_N
=
128
,
block_H
=
64
,
num_split
=
1
,
num_stages
=
2
,
threads
=
128
)
cfg
=
dict
(
block_N
=
128
,
block_H
=
64
,
num_split
=
8
,
num_stages
=
2
,
threads
=
128
)
return
cfg
,
sm_version
return
cfg
,
sm_version
# TODO(lei): fix warp specialized and tma lower pass
# TODO(lei): fix warp specialized and tma lower pass
def
get_pass_configs
():
def
get_pass_configs
():
return
{
return
{
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
}
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
}
@
autotune
(
configs
=
get_configs
(),
warmup
=
10
,
rep
=
10
)
@
autotune
(
configs
=
get_configs
(),
warmup
=
10
,
rep
=
10
)
@
tilelang
.
jit
(
out_idx
=
[
6
],
pass_configs
=
get_pass_configs
())
@
tilelang
.
jit
(
out_idx
=
[
6
],
pass_configs
=
get_pass_configs
())
def
flashattn
(
batch
,
heads
,
groups
,
seqlen_kv
,
dim
,
block_N
,
block_H
,
num_split
,
num_stages
,
def
flashattn
(
batch
,
heads
,
groups
,
seqlen_kv
,
dim
,
block_N
,
block_H
,
num_split
,
num_stages
,
threads
):
threads
):
scale
=
(
1.0
/
dim
)
**
0.5
*
1.44269504
# log2(e)
scale
=
(
1.0
/
dim
)
**
0.5
*
1.44269504
# log2(e)
shape_q
=
[
batch
,
heads
,
dim
]
shape_q
=
[
batch
,
heads
,
dim
]
shape_k
=
[
batch
,
seqlen_kv
,
groups
,
dim
]
shape_k
=
[
batch
,
seqlen_kv
,
groups
,
dim
]
shape_v
=
[
batch
,
seqlen_kv
,
groups
,
dim
]
shape_v
=
[
batch
,
seqlen_kv
,
groups
,
dim
]
shape_o
=
[
batch
,
heads
,
dim
]
shape_o
=
[
batch
,
heads
,
dim
]
dtype
=
"
float16
"
dtype
=
T
.
float16
accum_dtype
=
"
float
"
accum_dtype
=
T
.
float
32
kv_group_num
=
heads
//
groups
kv_group_num
=
heads
//
groups
part_shape
=
[
batch
,
heads
,
num_split
,
dim
]
part_shape
=
[
batch
,
heads
,
num_split
,
dim
]
...
@@ -73,11 +63,11 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split,
...
@@ -73,11 +63,11 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split,
@
T
.
macro
@
T
.
macro
def
flash_attn
(
def
flash_attn
(
Q
:
T
.
Tensor
(
shape_q
,
dtype
),
Q
:
T
.
Tensor
(
shape_q
,
dtype
),
K
:
T
.
Tensor
(
shape_k
,
dtype
),
K
:
T
.
Tensor
(
shape_k
,
dtype
),
V
:
T
.
Tensor
(
shape_v
,
dtype
),
V
:
T
.
Tensor
(
shape_v
,
dtype
),
mask
:
T
.
Tensor
([
batch
,
seqlen_kv
,
groups
],
"uint8"
),
mask
:
T
.
Tensor
([
batch
,
seqlen_kv
,
groups
],
"uint8"
),
Output
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Output
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
):
):
with
T
.
Kernel
(
batch
,
heads
//
valid_block_H
,
num_split
,
threads
=
threads
)
as
(
bx
,
by
,
bz
):
with
T
.
Kernel
(
batch
,
heads
//
valid_block_H
,
num_split
,
threads
=
threads
)
as
(
bx
,
by
,
bz
):
Q_shared
=
T
.
alloc_shared
([
block_H
,
dim
],
dtype
)
Q_shared
=
T
.
alloc_shared
([
block_H
,
dim
],
dtype
)
...
@@ -98,23 +88,24 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split,
...
@@ -98,23 +88,24 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split,
hid
=
by
hid
=
by
cur_kv_head
=
hid
//
(
kv_group_num
//
valid_block_H
)
cur_kv_head
=
hid
//
(
kv_group_num
//
valid_block_H
)
T
.
copy
(
Q
[
bid
,
hid
*
valid_block_H
:
hid
*
valid_block_H
+
block_H
,
:],
Q_shared
)
T
.
copy
(
Q
[
bid
,
hid
*
valid_block_H
:
hid
*
valid_block_H
+
block_H
,
:],
Q_shared
)
T
.
fill
(
acc_o
,
0
)
T
.
fill
(
acc_o
,
0
)
T
.
fill
(
logsum
,
0
)
T
.
fill
(
logsum
,
0
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
loop_range
=
T
.
ceildiv
((
seqlen_kv
//
num_split
),
block_N
)
loop_range
=
T
.
ceildiv
((
seqlen_kv
//
num_split
),
block_N
)
for
k
in
T
.
Pipelined
(
loop_range
,
num_stages
=
num_stages
):
for
k
in
T
.
Pipelined
(
loop_range
,
num_stages
=
num_stages
):
T
.
copy
(
K
[
bid
,
k
*
block_N
:
(
k
+
1
)
*
block_N
,
cur_kv_head
,
:],
K_shared
)
T
.
copy
(
K
[
bid
,
k
*
block_N
:
(
k
+
1
)
*
block_N
,
cur_kv_head
,
:],
K_shared
)
T
.
copy
(
mask
[
bid
,
k
*
block_N
:
(
k
+
1
)
*
block_N
,
cur_kv_head
],
mask_local
)
T
.
copy
(
mask
[
bid
,
k
*
block_N
:
(
k
+
1
)
*
block_N
,
cur_kv_head
],
mask_local
)
T
.
clear
(
acc_s
)
T
.
clear
(
acc_s
)
T
.
gemm
(
Q_shared
,
K_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
T
.
gemm
(
Q_shared
,
K_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
for
i
,
j
in
T
.
Parallel
(
block_H
,
block_N
):
for
i
,
j
in
T
.
Parallel
(
block_H
,
block_N
):
acc_s
[
i
,
j
]
=
T
.
if_then_else
(
mask_local
[
j
]
!=
0
,
acc_s
[
i
,
j
],
acc_s
[
i
,
j
]
=
T
.
if_then_else
(
mask_local
[
j
]
!=
0
,
acc_s
[
i
,
j
],
-
T
.
infinity
(
accum_dtype
))
-
T
.
infinity
(
accum_dtype
))
T
.
copy
(
scores_max
,
scores_max_prev
)
T
.
copy
(
scores_max
,
scores_max_prev
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
T
.
reduce_max
(
acc_s
,
scores_max
,
dim
=
1
,
clear
=
False
)
T
.
reduce_max
(
acc_s
,
scores_max
,
dim
=
1
,
clear
=
False
)
for
i
in
T
.
Parallel
(
block_H
):
scores_max
[
i
]
=
T
.
max
(
scores_max
[
i
],
scores_max_prev
[
i
])
for
i
in
T
.
Parallel
(
block_H
):
for
i
in
T
.
Parallel
(
block_H
):
scores_scale
[
i
]
=
T
.
exp2
(
scores_max_prev
[
i
]
*
scale
-
scores_max
[
i
]
*
scale
)
scores_scale
[
i
]
=
T
.
exp2
(
scores_max_prev
[
i
]
*
scale
-
scores_max
[
i
]
*
scale
)
for
i
,
j
in
T
.
Parallel
(
block_H
,
block_N
):
for
i
,
j
in
T
.
Parallel
(
block_H
,
block_N
):
...
@@ -125,23 +116,23 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split,
...
@@ -125,23 +116,23 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split,
T
.
copy
(
acc_s
,
acc_s_cast
)
T
.
copy
(
acc_s
,
acc_s_cast
)
for
i
,
j
in
T
.
Parallel
(
block_H
,
dim
):
for
i
,
j
in
T
.
Parallel
(
block_H
,
dim
):
acc_o
[
i
,
j
]
*=
scores_scale
[
i
]
acc_o
[
i
,
j
]
*=
scores_scale
[
i
]
T
.
copy
(
V
[
bid
,
k
*
block_N
:
(
k
+
1
)
*
block_N
,
cur_kv_head
,
:],
V_shared
)
T
.
copy
(
V
[
bid
,
k
*
block_N
:
(
k
+
1
)
*
block_N
,
cur_kv_head
,
:],
V_shared
)
T
.
gemm
(
acc_s_cast
,
V_shared
,
acc_o
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
T
.
gemm
(
acc_s_cast
,
V_shared
,
acc_o
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
for
i
,
j
in
T
.
Parallel
(
block_H
,
dim
):
for
i
,
j
in
T
.
Parallel
(
block_H
,
dim
):
acc_o
[
i
,
j
]
/=
logsum
[
i
]
acc_o
[
i
,
j
]
/=
logsum
[
i
]
for
i
in
T
.
Parallel
(
block_H
):
for
i
in
T
.
Parallel
(
block_H
):
logsum
[
i
]
=
T
.
log2
(
logsum
[
i
])
+
scores_max
[
i
]
*
scale
logsum
[
i
]
=
T
.
log2
(
logsum
[
i
])
+
scores_max
[
i
]
*
scale
T
.
copy
(
acc_o
[:
valid_block_H
,
:],
O_shared
)
T
.
copy
(
acc_o
[:
valid_block_H
,
:],
O_shared
)
T
.
copy
(
O_shared
,
Output
[
bid
,
hid
*
valid_block_H
:
(
hid
+
1
)
*
valid_block_H
,
:])
T
.
copy
(
O_shared
,
Output
[
bid
,
hid
*
valid_block_H
:
(
hid
+
1
)
*
valid_block_H
,
:])
@
T
.
macro
@
T
.
macro
def
flash_attn_split
(
def
flash_attn_split
(
Q
:
T
.
Tensor
(
shape_q
,
dtype
),
Q
:
T
.
Tensor
(
shape_q
,
dtype
),
K
:
T
.
Tensor
(
shape_k
,
dtype
),
K
:
T
.
Tensor
(
shape_k
,
dtype
),
V
:
T
.
Tensor
(
shape_v
,
dtype
),
V
:
T
.
Tensor
(
shape_v
,
dtype
),
mask
:
T
.
Tensor
([
batch
,
seqlen_kv
,
groups
],
"uint8"
),
mask
:
T
.
Tensor
([
batch
,
seqlen_kv
,
groups
],
"uint8"
),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
dtype
),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
dtype
),
Output_partial
:
T
.
Tensor
(
part_shape
,
dtype
),
Output_partial
:
T
.
Tensor
(
part_shape
,
dtype
),
):
):
with
T
.
Kernel
(
batch
,
heads
//
valid_block_H
,
num_split
,
threads
=
threads
)
as
(
bx
,
by
,
bz
):
with
T
.
Kernel
(
batch
,
heads
//
valid_block_H
,
num_split
,
threads
=
threads
)
as
(
bx
,
by
,
bz
):
Q_shared
=
T
.
alloc_shared
([
block_H
,
dim
],
dtype
)
Q_shared
=
T
.
alloc_shared
([
block_H
,
dim
],
dtype
)
...
@@ -163,7 +154,7 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split,
...
@@ -163,7 +154,7 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split,
sid
=
bz
sid
=
bz
cur_kv_head
=
hid
//
(
kv_group_num
//
valid_block_H
)
cur_kv_head
=
hid
//
(
kv_group_num
//
valid_block_H
)
T
.
copy
(
Q
[
bid
,
hid
*
valid_block_H
:
hid
*
valid_block_H
+
block_H
,
:],
Q_shared
)
T
.
copy
(
Q
[
bid
,
hid
*
valid_block_H
:
hid
*
valid_block_H
+
block_H
,
:],
Q_shared
)
T
.
fill
(
acc_o
,
0
)
T
.
fill
(
acc_o
,
0
)
T
.
fill
(
logsum
,
0
)
T
.
fill
(
logsum
,
0
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
...
@@ -172,22 +163,31 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split,
...
@@ -172,22 +163,31 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split,
for
k
in
T
.
Pipelined
(
loop_range
,
num_stages
=
num_stages
):
for
k
in
T
.
Pipelined
(
loop_range
,
num_stages
=
num_stages
):
T
.
copy
(
T
.
copy
(
K
[
bid
,
(
seqlen_kv
//
num_split
)
*
sid
+
K
[
k
*
valid_block_N
:(
seqlen_kv
//
num_split
)
*
sid
+
(
k
+
1
)
*
valid_block_N
,
bid
,
cur_kv_head
,
:],
K_shared
)
(
seqlen_kv
//
num_split
)
*
sid
+
k
*
valid_block_N
:
(
seqlen_kv
//
num_split
)
*
sid
+
(
k
+
1
)
*
valid_block_N
,
cur_kv_head
,
:,
],
K_shared
,
)
T
.
copy
(
T
.
copy
(
mask
[
bid
,
(
seqlen_kv
//
num_split
)
*
sid
+
mask
[
k
*
valid_block_N
:(
seqlen_kv
//
num_split
)
*
sid
+
(
k
+
1
)
*
valid_block_N
,
bid
,
cur_kv_head
],
mask_local
)
(
seqlen_kv
//
num_split
)
*
sid
+
k
*
valid_block_N
:
(
seqlen_kv
//
num_split
)
*
sid
+
(
k
+
1
)
*
valid_block_N
,
cur_kv_head
,
],
mask_local
,
)
T
.
clear
(
acc_s
)
T
.
clear
(
acc_s
)
T
.
gemm
(
Q_shared
,
K_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
T
.
gemm
(
Q_shared
,
K_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
for
i
,
j
in
T
.
Parallel
(
block_H
,
block_N
):
for
i
,
j
in
T
.
Parallel
(
block_H
,
block_N
):
acc_s
[
i
,
acc_s
[
i
,
j
]
=
T
.
if_then_else
((
mask_local
[
j
]
!=
0
)
&
(
j
<
seqlen_kv
//
num_split
),
acc_s
[
i
,
j
],
-
T
.
infinity
(
accum_dtype
))
j
]
=
T
.
if_then_else
((
mask_local
[
j
]
!=
0
)
&
(
j
<
seqlen_kv
//
num_split
),
acc_s
[
i
,
j
],
-
T
.
infinity
(
accum_dtype
))
T
.
copy
(
scores_max
,
scores_max_prev
)
T
.
copy
(
scores_max
,
scores_max_prev
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
T
.
reduce_max
(
acc_s
,
scores_max
,
dim
=
1
,
clear
=
False
)
T
.
reduce_max
(
acc_s
,
scores_max
,
dim
=
1
,
clear
=
False
)
for
i
in
T
.
Parallel
(
block_H
):
scores_max
[
i
]
=
T
.
max
(
scores_max
[
i
],
scores_max_prev
[
i
])
for
i
in
T
.
Parallel
(
block_H
):
for
i
in
T
.
Parallel
(
block_H
):
scores_scale
[
i
]
=
T
.
exp2
(
scores_max_prev
[
i
]
*
scale
-
scores_max
[
i
]
*
scale
)
scores_scale
[
i
]
=
T
.
exp2
(
scores_max_prev
[
i
]
*
scale
-
scores_max
[
i
]
*
scale
)
for
i
,
j
in
T
.
Parallel
(
block_H
,
block_N
):
for
i
,
j
in
T
.
Parallel
(
block_H
,
block_N
):
...
@@ -199,9 +199,14 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split,
...
@@ -199,9 +199,14 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split,
for
i
,
j
in
T
.
Parallel
(
block_H
,
dim
):
for
i
,
j
in
T
.
Parallel
(
block_H
,
dim
):
acc_o
[
i
,
j
]
*=
scores_scale
[
i
]
acc_o
[
i
,
j
]
*=
scores_scale
[
i
]
T
.
copy
(
T
.
copy
(
V
[
bid
,
(
seqlen_kv
//
num_split
)
*
sid
+
V
[
k
*
valid_block_N
:(
seqlen_kv
//
num_split
)
*
sid
+
(
k
+
1
)
*
valid_block_N
,
bid
,
cur_kv_head
,
:],
V_shared
)
(
seqlen_kv
//
num_split
)
*
sid
+
k
*
valid_block_N
:
(
seqlen_kv
//
num_split
)
*
sid
+
(
k
+
1
)
*
valid_block_N
,
cur_kv_head
,
:,
],
V_shared
,
)
T
.
gemm
(
acc_s_cast
,
V_shared
,
acc_o
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
T
.
gemm
(
acc_s_cast
,
V_shared
,
acc_o
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
for
i
,
j
in
T
.
Parallel
(
block_H
,
dim
):
for
i
,
j
in
T
.
Parallel
(
block_H
,
dim
):
acc_o
[
i
,
j
]
/=
logsum
[
i
]
acc_o
[
i
,
j
]
/=
logsum
[
i
]
...
@@ -212,72 +217,74 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split,
...
@@ -212,72 +217,74 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split,
if
i
<
valid_block_H
:
if
i
<
valid_block_H
:
glse
[
bid
,
hid
*
valid_block_H
+
i
,
sid
]
=
logsum
[
i
]
glse
[
bid
,
hid
*
valid_block_H
+
i
,
sid
]
=
logsum
[
i
]
T
.
copy
(
acc_o
[:
valid_block_H
,
:],
O_shared
)
T
.
copy
(
acc_o
[:
valid_block_H
,
:],
O_shared
)
T
.
copy
(
O_shared
,
Output_partial
[
bid
,
hid
*
valid_block_H
:(
hid
+
1
)
*
valid_block_H
,
T
.
copy
(
O_shared
,
Output_partial
[
bid
,
hid
*
valid_block_H
:
(
hid
+
1
)
*
valid_block_H
,
sid
,
:])
sid
,
:])
@
T
.
macro
@
T
.
macro
def
combine
(
def
combine
(
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
dtype
),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
dtype
),
Output_partial
:
T
.
Tensor
(
part_shape
,
dtype
),
Output_partial
:
T
.
Tensor
(
part_shape
,
dtype
),
Output
:
T
.
Tensor
(
shape_o
,
dtype
),
Output
:
T
.
Tensor
(
shape_o
,
dtype
),
):
):
with
T
.
Kernel
(
heads
,
batch
,
threads
=
128
)
as
(
by
,
bz
):
with
T
.
Kernel
(
heads
,
batch
,
threads
=
128
)
as
(
by
,
bz
):
po_local
=
T
.
alloc_fragment
([
dim
],
dtype
)
po_local
=
T
.
alloc_fragment
([
dim
],
dtype
)
o_accum_local
=
T
.
alloc_fragment
([
dim
],
accum_dtype
)
o_accum_local
=
T
.
alloc_fragment
([
dim
],
accum_dtype
)
lse_local
=
T
.
alloc_fragment
([
num_split
,
128
],
dtype
)
lse_local
=
T
.
alloc_fragment
([
num_split
,
128
],
dtype
)
lse_local_split
=
T
.
alloc_local
([
1
],
accum_dtype
)
lse_logsum_local
=
T
.
alloc_fragment
([
128
],
accum_dtype
)
lse_logsum_local
=
T
.
alloc_local
([
1
],
accum_dtype
)
lse_max_local
=
T
.
alloc_fragment
([
128
],
accum_dtype
)
lse_max_local
=
T
.
alloc_fragment
([
128
],
accum_dtype
)
scale_local
=
T
.
alloc_
local
([
1
],
accum_dtype
)
scale_local
=
T
.
alloc_
fragment
([
1
28
],
accum_dtype
)
T
.
annotate_layout
({
T
.
annotate_layout
(
lse_logsum_local
:
T
.
Fragment
(
lse_logsum_local
.
shape
,
forward_thread_fn
=
lambda
i
:
i
),
{
lse_max_local
:
T
.
Fragment
(
lse_max_local
.
shape
,
forward_thread_fn
=
lambda
i
:
i
),
lse_logsum_local
:
T
.
Fragment
(
lse_logsum_local
.
shape
,
forward_thread_fn
=
lambda
i
:
i
),
# lse_local: (local_id, thread_id)
lse_max_local
:
T
.
Fragment
(
lse_max_local
.
shape
,
forward_thread_fn
=
lambda
i
:
i
),
lse_local
:
T
.
Fragment
(
lse_local
.
shape
,
forward_fn
=
lambda
i
,
j
:
(
j
,
i
)),
# lse_local: (local_id, thread_id)
})
lse_local
:
T
.
Fragment
(
lse_local
.
shape
,
forward_fn
=
lambda
i
,
j
:
(
j
,
i
)),
}
)
T
.
clear
(
lse_logsum_local
)
T
.
clear
(
lse_logsum_local
)
T
.
clear
(
o_accum_local
)
T
.
clear
(
o_accum_local
)
for
k
,
j
in
T
.
Parallel
(
num_split
,
128
):
for
k
,
j
in
T
.
Parallel
(
num_split
,
128
):
lse_local
[
k
,
j
]
=
glse
[
bz
,
by
,
k
]
lse_local
[
k
,
j
]
=
glse
[
bz
,
by
,
k
]
T
.
reduce_max
(
lse_local
,
lse_max_local
,
dim
=
0
,
clear
=
True
)
T
.
reduce_max
(
lse_local
,
lse_max_local
,
dim
=
0
,
clear
=
True
)
for
k
in
T
.
Pipelined
(
num_split
,
num_stages
=
1
):
for
k
in
T
.
serial
(
num_split
):
lse_local_split
[
0
]
=
glse
[
bz
,
by
,
k
]
for
j
in
T
.
Parallel
(
128
):
lse_logsum_local
[
0
]
+=
T
.
exp2
(
lse_local_split
[
0
]
-
lse_max_local
[
0
])
lse_logsum_local
[
j
]
+=
T
.
exp2
(
lse_local
[
k
,
j
]
-
lse_max_local
[
j
])
lse_logsum_local
[
0
]
=
T
.
log2
(
lse_logsum_local
[
0
])
+
lse_max_local
[
0
]
for
j
in
T
.
Parallel
(
128
):
lse_logsum_local
[
j
]
=
T
.
log2
(
lse_logsum_local
[
j
])
+
lse_max_local
[
j
]
for
k
in
T
.
serial
(
num_split
):
for
k
in
T
.
serial
(
num_split
):
for
i
in
T
.
Parallel
(
dim
):
for
i
in
T
.
Parallel
(
dim
):
po_local
[
i
]
=
Output_partial
[
bz
,
by
,
k
,
i
]
po_local
[
i
]
=
Output_partial
[
bz
,
by
,
k
,
i
]
lse_local_split
[
0
]
=
glse
[
bz
,
by
,
k
]
for
j
in
T
.
Parallel
(
128
):
scale_local
[
0
]
=
T
.
exp2
(
lse_local_split
[
0
]
-
lse_logsum_local
[
0
])
scale_local
[
j
]
=
T
.
exp2
(
lse_local
[
k
,
j
]
-
lse_logsum_local
[
j
])
# Note: Pay attention to dim and the number of threads in Parallel
for
i
in
T
.
Parallel
(
dim
):
for
i
in
T
.
Parallel
(
dim
):
o_accum_local
[
i
]
+=
po_local
[
i
]
*
scale_local
[
0
]
o_accum_local
[
i
]
+=
po_local
[
i
]
*
scale_local
[
i
]
for
i
in
T
.
Parallel
(
dim
):
for
i
in
T
.
Parallel
(
dim
):
Output
[
bz
,
by
,
i
]
=
o_accum_local
[
i
]
Output
[
bz
,
by
,
i
]
=
o_accum_local
[
i
]
@
T
.
prim_func
@
T
.
prim_func
def
flashattn_gqa_decode_split
(
def
flashattn_gqa_decode_split
(
Q
:
T
.
Tensor
(
shape_q
,
dtype
),
Q
:
T
.
Tensor
(
shape_q
,
dtype
),
K
:
T
.
Tensor
(
shape_k
,
dtype
),
K
:
T
.
Tensor
(
shape_k
,
dtype
),
V
:
T
.
Tensor
(
shape_v
,
dtype
),
V
:
T
.
Tensor
(
shape_v
,
dtype
),
mask
:
T
.
Tensor
([
batch
,
seqlen_kv
,
groups
],
"uint8"
),
mask
:
T
.
Tensor
([
batch
,
seqlen_kv
,
groups
],
"uint8"
),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
dtype
),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
dtype
),
Output_partial
:
T
.
Tensor
(
part_shape
,
dtype
),
Output_partial
:
T
.
Tensor
(
part_shape
,
dtype
),
Output
:
T
.
Tensor
(
shape_o
,
dtype
),
Output
:
T
.
Tensor
(
shape_o
,
dtype
),
):
):
flash_attn_split
(
Q
,
K
,
V
,
mask
,
glse
,
Output_partial
)
flash_attn_split
(
Q
,
K
,
V
,
mask
,
glse
,
Output_partial
)
combine
(
glse
,
Output_partial
,
Output
)
combine
(
glse
,
Output_partial
,
Output
)
@
T
.
prim_func
@
T
.
prim_func
def
flashattn_gqa_decode_no_split
(
def
flashattn_gqa_decode_no_split
(
Q
:
T
.
Tensor
(
shape_q
,
dtype
),
Q
:
T
.
Tensor
(
shape_q
,
dtype
),
K
:
T
.
Tensor
(
shape_k
,
dtype
),
K
:
T
.
Tensor
(
shape_k
,
dtype
),
V
:
T
.
Tensor
(
shape_v
,
dtype
),
V
:
T
.
Tensor
(
shape_v
,
dtype
),
mask
:
T
.
Tensor
([
batch
,
seqlen_kv
,
groups
],
"uint8"
),
mask
:
T
.
Tensor
([
batch
,
seqlen_kv
,
groups
],
"uint8"
),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
dtype
),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
dtype
),
Output_partial
:
T
.
Tensor
(
part_shape
,
dtype
),
Output_partial
:
T
.
Tensor
(
part_shape
,
dtype
),
Output
:
T
.
Tensor
(
shape_o
,
dtype
),
Output
:
T
.
Tensor
(
shape_o
,
dtype
),
):
):
flash_attn
(
Q
,
K
,
V
,
mask
,
Output
)
flash_attn
(
Q
,
K
,
V
,
mask
,
Output
)
...
@@ -300,27 +307,21 @@ def ref_program(query, key, value, mask, glse, Output_partial):
...
@@ -300,27 +307,21 @@ def ref_program(query, key, value, mask, glse, Output_partial):
dim
=
query
.
shape
[
-
1
]
dim
=
query
.
shape
[
-
1
]
num_head_groups
=
query
.
shape
[
1
]
//
key
.
shape
[
2
]
num_head_groups
=
query
.
shape
[
1
]
//
key
.
shape
[
2
]
scale
=
dim
**
0.5
scale
=
dim
**
0.5
key
=
rearrange
(
key
,
'
b n h d -> b h n d
'
)
# [batch_size, groups, seqlen_kv, dim]
key
=
rearrange
(
key
,
"
b n h d -> b h n d
"
)
# [batch_size, groups, seqlen_kv, dim]
value
=
rearrange
(
value
,
'
b n h d -> b h n d
'
)
# [batch_size, groups, seqlen_kv, dim]
value
=
rearrange
(
value
,
"
b n h d -> b h n d
"
)
# [batch_size, groups, seqlen_kv, dim]
query
=
rearrange
(
query
=
rearrange
(
query
,
"b (h g) d -> b g h d"
,
g
=
num_head_groups
)
# [batch_size, num_head_groups, groups, dim]
query
,
'b (h g) d -> b g h d'
,
g
=
num_head_groups
)
# [batch_size, num_head_groups, groups, dim]
scores
=
einsum
(
scores
=
einsum
(
query
,
key
,
"b g h d, b h s d -> b g h s"
)
# [batch_size, num_head_groups, groups, seqlen_kv]
query
,
key
,
'b g h d, b h s d -> b g h s'
)
# [batch_size, num_head_groups, groups, seqlen_kv]
if
mask
is
not
None
:
if
mask
is
not
None
:
mask
=
rearrange
(
mask
,
'
b s h -> b h s
'
)
mask
=
rearrange
(
mask
,
"
b s h -> b h s
"
)
mask
=
mask
.
unsqueeze
(
1
)
mask
=
mask
.
unsqueeze
(
1
)
scores
=
scores
.
masked_fill
(
mask
==
0
,
float
(
'
-inf
'
))
scores
=
scores
.
masked_fill
(
mask
==
0
,
float
(
"
-inf
"
))
attention
=
F
.
softmax
(
attention
=
F
.
softmax
(
scores
/
scale
,
dim
=-
1
)
# [batch_size, num_head_groups, groups, seqlen_kv]
scores
/
scale
,
dim
=-
1
)
# [batch_size, num_head_groups, groups, seqlen_kv]
out
=
einsum
(
attention
,
value
,
out
=
einsum
(
attention
,
value
,
"b g h s, b h s d -> b g h d"
)
# [batch_size, num_head_groups, groups, dim]
'b g h s, b h s d -> b g h d'
)
# [batch_size, num_head_groups, groups, dim]
out
=
rearrange
(
out
,
"b g h d -> b (h g) d"
)
# [batch_size, heads, dim]
out
=
rearrange
(
out
,
'b g h d -> b (h g) d'
)
# [batch_size, heads, dim]
return
out
return
out
...
@@ -334,16 +335,12 @@ def flash_split_ref(Q, K, V, mask):
...
@@ -334,16 +335,12 @@ def flash_split_ref(Q, K, V, mask):
seqlen_kv
=
K
.
size
(
1
)
seqlen_kv
=
K
.
size
(
1
)
num_head_groups
=
nheads
//
groups
num_head_groups
=
nheads
//
groups
scale
=
(
1.0
/
dim
)
**
0.5
*
1.44269504
# log2(e)
scale
=
(
1.0
/
dim
)
**
0.5
*
1.44269504
# log2(e)
acc_s
=
torch
.
empty
((
batch
,
num_head_groups
,
groups
,
block_N
),
device
=
"cuda"
,
dtype
=
torch
.
float
)
acc_s
=
torch
.
empty
((
batch
,
num_head_groups
,
groups
,
block_N
),
device
=
"cuda"
,
dtype
=
torch
.
float
)
acc_s_cast
=
torch
.
empty
((
batch
,
num_head_groups
,
groups
,
block_N
),
acc_s_cast
=
torch
.
empty
((
batch
,
num_head_groups
,
groups
,
block_N
),
device
=
"cuda"
,
dtype
=
torch
.
float16
)
device
=
"cuda"
,
dtype
=
torch
.
float16
)
acc_o
=
torch
.
empty
((
batch
,
num_head_groups
,
groups
,
dim
),
device
=
"cuda"
,
dtype
=
torch
.
float
)
acc_o
=
torch
.
empty
((
batch
,
num_head_groups
,
groups
,
dim
),
device
=
"cuda"
,
dtype
=
torch
.
float
)
scores_max
=
torch
.
empty
((
batch
,
num_head_groups
,
groups
),
device
=
"cuda"
,
dtype
=
torch
.
float
)
scores_max
=
torch
.
empty
((
batch
,
num_head_groups
,
groups
),
device
=
"cuda"
,
dtype
=
torch
.
float
)
scores_max_prev
=
torch
.
empty
((
batch
,
num_head_groups
,
groups
),
scores_max_prev
=
torch
.
empty
((
batch
,
num_head_groups
,
groups
),
device
=
"cuda"
,
dtype
=
torch
.
float
)
device
=
"cuda"
,
dtype
=
torch
.
float
)
scores_scale
=
torch
.
empty
((
batch
,
num_head_groups
,
groups
),
device
=
"cuda"
,
dtype
=
torch
.
float
)
scores_scale
=
torch
.
empty
((
batch
,
num_head_groups
,
groups
),
device
=
"cuda"
,
dtype
=
torch
.
float
)
scores_sum
=
torch
.
empty
((
batch
,
num_head_groups
,
groups
),
device
=
"cuda"
,
dtype
=
torch
.
float
)
scores_sum
=
torch
.
empty
((
batch
,
num_head_groups
,
groups
),
device
=
"cuda"
,
dtype
=
torch
.
float
)
logsum
=
torch
.
empty
((
batch
,
num_head_groups
,
groups
),
device
=
"cuda"
,
dtype
=
torch
.
float
)
logsum
=
torch
.
empty
((
batch
,
num_head_groups
,
groups
),
device
=
"cuda"
,
dtype
=
torch
.
float
)
...
@@ -351,25 +348,25 @@ def flash_split_ref(Q, K, V, mask):
...
@@ -351,25 +348,25 @@ def flash_split_ref(Q, K, V, mask):
glogsum
=
torch
.
empty
((
num_split
,
batch
,
nheads
),
device
=
"cuda"
,
dtype
=
torch
.
float
)
glogsum
=
torch
.
empty
((
num_split
,
batch
,
nheads
),
device
=
"cuda"
,
dtype
=
torch
.
float
)
Q_
=
Q
*
scale
Q_
=
Q
*
scale
Q_
=
rearrange
(
Q_
,
'
b (h g) d -> b g h d
'
,
g
=
num_head_groups
)
Q_
=
rearrange
(
Q_
,
"
b (h g) d -> b g h d
"
,
g
=
num_head_groups
)
for
ks
in
range
(
num_split
):
for
ks
in
range
(
num_split
):
acc_o
.
fill_
(
0
)
acc_o
.
fill_
(
0
)
logsum
.
fill_
(
0
)
logsum
.
fill_
(
0
)
scores_max
.
fill_
(
float
(
'
-inf
'
))
scores_max
.
fill_
(
float
(
"
-inf
"
))
scores_max_prev
.
fill_
(
float
(
'
-inf
'
))
scores_max_prev
.
fill_
(
float
(
"
-inf
"
))
for
i
in
range
(
int
((
seqlen_kv
//
num_split
)
/
block_N
)):
for
i
in
range
(
int
((
seqlen_kv
//
num_split
)
/
block_N
)):
acc_s
.
fill_
(
0
)
acc_s
.
fill_
(
0
)
acc_s
=
torch
.
einsum
(
'bghd,bkhd->bghk'
,
Q_
,
acc_s
=
torch
.
einsum
(
K
[:,
(
seqlen_kv
//
num_split
)
*
ks
+
"bghd,bkhd->bghk"
,
i
*
block_N
:(
seqlen_kv
//
num_split
)
*
ks
+
Q_
,
(
i
+
1
)
*
block_N
,
:,
:])
# [batch, nheads, block_N]
K
[:,
(
seqlen_kv
//
num_split
)
*
ks
+
i
*
block_N
:
(
seqlen_kv
//
num_split
)
*
ks
+
(
i
+
1
)
*
block_N
,
:,
:],
)
# [batch, nheads, block_N]
if
mask
is
not
None
:
if
mask
is
not
None
:
mask_local
=
mask
[:,
(
seqlen_kv
//
num_split
)
*
ks
+
mask_local
=
mask
[:,
(
seqlen_kv
//
num_split
)
*
ks
+
i
*
block_N
:
(
seqlen_kv
//
num_split
)
*
ks
+
(
i
+
1
)
*
block_N
,
:]
i
*
block_N
:(
seqlen_kv
//
num_split
)
*
ks
+
(
i
+
1
)
*
block_N
,
:]
mask_local
=
rearrange
(
mask_local
,
"b s h -> b h s"
)
mask_local
=
rearrange
(
mask_local
,
'b s h -> b h s'
)
mask_local
=
mask_local
.
unsqueeze
(
1
)
mask_local
=
mask_local
.
unsqueeze
(
1
)
acc_s
=
acc_s
.
masked_fill
(
mask_local
==
0
,
float
(
'
-inf
'
))
acc_s
=
acc_s
.
masked_fill
(
mask_local
==
0
,
float
(
"
-inf
"
))
scores_max_prev
=
scores_max
scores_max_prev
=
scores_max
scores_max
=
acc_s
.
max
(
dim
=-
1
,
keepdim
=
False
).
values
# [batch, nheads]
scores_max
=
acc_s
.
max
(
dim
=-
1
,
keepdim
=
False
).
values
# [batch, nheads]
scores_scale
=
torch
.
exp2
(
scores_max_prev
-
scores_max
)
# [batch, nheads]
scores_scale
=
torch
.
exp2
(
scores_max_prev
-
scores_max
)
# [batch, nheads]
...
@@ -377,15 +374,16 @@ def flash_split_ref(Q, K, V, mask):
...
@@ -377,15 +374,16 @@ def flash_split_ref(Q, K, V, mask):
acc_s
=
torch
.
exp2
(
acc_s
-
scores_max
[:,
:,
:,
None
])
acc_s
=
torch
.
exp2
(
acc_s
-
scores_max
[:,
:,
:,
None
])
acc_s_cast
=
acc_s
.
to
(
torch
.
float16
)
# [batch, nheads, block_N]
acc_s_cast
=
acc_s
.
to
(
torch
.
float16
)
# [batch, nheads, block_N]
acc_o
+=
torch
.
einsum
(
acc_o
+=
torch
.
einsum
(
'bghk,bkhd->bghd'
,
acc_s_cast
,
"bghk,bkhd->bghd"
,
V
[:,
(
seqlen_kv
//
num_split
)
*
ks
+
i
*
block_N
:(
seqlen_kv
//
num_split
)
*
ks
+
acc_s_cast
,
(
i
+
1
)
*
block_N
,
:,
:])
V
[:,
(
seqlen_kv
//
num_split
)
*
ks
+
i
*
block_N
:
(
seqlen_kv
//
num_split
)
*
ks
+
(
i
+
1
)
*
block_N
,
:,
:],
)
scores_sum
=
acc_s
.
sum
(
dim
=-
1
,
keepdim
=
False
)
scores_sum
=
acc_s
.
sum
(
dim
=-
1
,
keepdim
=
False
)
logsum
=
logsum
*
scores_scale
+
scores_sum
logsum
=
logsum
*
scores_scale
+
scores_sum
acc_o_out
=
rearrange
(
acc_o
,
'
b g h d->b (h g) d
'
)
acc_o_out
=
rearrange
(
acc_o
,
"
b g h d->b (h g) d
"
)
logsum_out
=
rearrange
(
logsum
,
'
b g h->b (h g)
'
)
logsum_out
=
rearrange
(
logsum
,
"
b g h->b (h g)
"
)
acc_o_out
/=
logsum_out
[:,
:,
None
]
acc_o_out
/=
logsum_out
[:,
:,
None
]
logsum_out
=
torch
.
log2
(
logsum_out
)
+
rearrange
(
scores_max
,
'
b g h->b (h g)
'
)
logsum_out
=
torch
.
log2
(
logsum_out
)
+
rearrange
(
scores_max
,
"
b g h->b (h g)
"
)
gacc_o
[
ks
,
:,
:,
:]
=
acc_o_out
gacc_o
[
ks
,
:,
:,
:]
=
acc_o_out
glogsum
[
ks
,
:,
:]
=
logsum_out
glogsum
[
ks
,
:,
:]
=
logsum_out
...
@@ -421,7 +419,7 @@ def calc_sim(x, y, name="tensor"):
...
@@ -421,7 +419,7 @@ def calc_sim(x, y, name="tensor"):
x
,
y
=
x
.
data
.
double
(),
y
.
data
.
double
()
x
,
y
=
x
.
data
.
double
(),
y
.
data
.
double
()
denominator
=
(
x
*
x
+
y
*
y
).
sum
()
denominator
=
(
x
*
x
+
y
*
y
).
sum
()
if
denominator
==
0
:
if
denominator
==
0
:
print_red_warning
(
f
'
{
name
}
all zero
'
)
print_red_warning
(
f
"
{
name
}
all zero
"
)
return
1
return
1
sim
=
2
*
(
x
*
y
).
sum
()
/
denominator
sim
=
2
*
(
x
*
y
).
sum
()
/
denominator
return
sim
return
sim
...
@@ -429,28 +427,23 @@ def calc_sim(x, y, name="tensor"):
...
@@ -429,28 +427,23 @@ def calc_sim(x, y, name="tensor"):
def
assert_similar
(
x
,
y
,
eps
=
1e-2
,
name
=
"tensor"
,
assert_
=
False
,
print_
=
True
):
def
assert_similar
(
x
,
y
,
eps
=
1e-2
,
name
=
"tensor"
,
assert_
=
False
,
print_
=
True
):
sim
=
calc_sim
(
x
,
y
,
name
)
sim
=
calc_sim
(
x
,
y
,
name
)
diff
=
1.
-
sim
diff
=
1.
0
-
sim
if
not
(
0
<=
diff
<=
eps
):
if
not
(
0
<=
diff
<=
eps
):
print_red_warning
(
f
'
{
name
}
Error:
{
diff
}
'
)
print_red_warning
(
f
"
{
name
}
Error:
{
diff
}
"
)
if
assert_
:
if
assert_
:
raise
AssertionError
(
f
'
{
name
}
Error:
{
diff
}
'
)
raise
AssertionError
(
f
"
{
name
}
Error:
{
diff
}
"
)
else
:
else
:
if
print_
:
if
print_
:
print
(
f
'
passed:
{
name
}
diff=
{
diff
}
'
)
print
(
f
"
passed:
{
name
}
diff=
{
diff
}
"
)
def
main
(
batch
:
int
=
1
,
def
main
(
batch
:
int
=
1
,
heads
:
int
=
32
,
groups
:
int
=
8
,
kv_seqlen
:
int
=
8192
,
dim
:
int
=
128
,
tune
:
bool
=
False
):
heads
:
int
=
32
,
groups
:
int
=
8
,
kv_seqlen
:
int
=
8192
,
dim
:
int
=
128
,
tune
:
bool
=
False
):
batch
,
heads
,
groups
,
kv_seqlen
,
dim
=
batch
,
heads
,
groups
,
kv_seqlen
,
dim
batch
,
heads
,
groups
,
kv_seqlen
,
dim
=
batch
,
heads
,
groups
,
kv_seqlen
,
dim
qk_flops
=
2
*
batch
*
heads
*
kv_seqlen
*
dim
qk_flops
=
2
*
batch
*
heads
*
kv_seqlen
*
dim
pv_flops
=
2
*
batch
*
heads
*
kv_seqlen
*
dim
pv_flops
=
2
*
batch
*
heads
*
kv_seqlen
*
dim
total_flops
=
qk_flops
+
pv_flops
total_flops
=
qk_flops
+
pv_flops
if
(
not
tune
)
:
if
not
tune
:
config
,
sm_version
=
get_heuristic_config
()
config
,
sm_version
=
get_heuristic_config
()
kernel
=
flashattn
(
batch
,
heads
,
groups
,
kv_seqlen
,
dim
,
**
config
)
kernel
=
flashattn
(
batch
,
heads
,
groups
,
kv_seqlen
,
dim
,
**
config
)
profiler
=
kernel
.
get_profiler
(
tensor_supply_type
=
tilelang
.
TensorSupplyType
.
Auto
)
profiler
=
kernel
.
get_profiler
(
tensor_supply_type
=
tilelang
.
TensorSupplyType
.
Auto
)
...
@@ -470,7 +463,7 @@ def main(batch: int = 1,
...
@@ -470,7 +463,7 @@ def main(batch: int = 1,
print
(
o_ref
)
print
(
o_ref
)
assert_similar
(
o
,
o_ref
,
name
=
"o_ref"
)
assert_similar
(
o
,
o_ref
,
name
=
"o_ref"
)
assert_similar
(
o_ref_split
,
o_ref
,
name
=
"o_ref_split"
)
assert_similar
(
o
,
o_ref_split
,
name
=
"o_ref_split"
)
print
(
"All checks pass."
)
print
(
"All checks pass."
)
latency
=
profiler
.
do_bench
(
ref_program
,
warmup
=
500
)
latency
=
profiler
.
do_bench
(
ref_program
,
warmup
=
500
)
...
@@ -492,11 +485,11 @@ def main(batch: int = 1,
...
@@ -492,11 +485,11 @@ def main(batch: int = 1,
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'
--batch
'
,
type
=
int
,
default
=
1
,
help
=
'
batch size
'
)
parser
.
add_argument
(
"
--batch
"
,
type
=
int
,
default
=
1
,
help
=
"
batch size
"
)
parser
.
add_argument
(
'
--heads
'
,
type
=
int
,
default
=
32
,
help
=
'
heads
'
)
parser
.
add_argument
(
"
--heads
"
,
type
=
int
,
default
=
32
,
help
=
"
heads
"
)
parser
.
add_argument
(
'
--groups
'
,
type
=
int
,
default
=
8
,
help
=
'
groups
'
)
parser
.
add_argument
(
"
--groups
"
,
type
=
int
,
default
=
8
,
help
=
"
groups
"
)
parser
.
add_argument
(
'
--kv_seqlen
'
,
type
=
int
,
default
=
8192
,
help
=
'
kv sequence length
'
)
parser
.
add_argument
(
"
--kv_seqlen
"
,
type
=
int
,
default
=
8192
,
help
=
"
kv sequence length
"
)
parser
.
add_argument
(
'
--dim
'
,
type
=
int
,
default
=
128
,
help
=
'
dim
'
)
parser
.
add_argument
(
"
--dim
"
,
type
=
int
,
default
=
128
,
help
=
"
dim
"
)
parser
.
add_argument
(
'
--tune
'
,
action
=
'
store_true
'
,
help
=
'
tune configs
'
)
parser
.
add_argument
(
"
--tune
"
,
action
=
"
store_true
"
,
help
=
"
tune configs
"
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
main
(
args
.
batch
,
args
.
heads
,
args
.
groups
,
args
.
kv_seqlen
,
args
.
dim
,
args
.
tune
)
main
(
args
.
batch
,
args
.
heads
,
args
.
groups
,
args
.
kv_seqlen
,
args
.
dim
,
args
.
tune
)
examples/flash_decoding/example_gqa_decode_varlen_logits.py
View file @
667632cc
...
@@ -19,8 +19,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
...
@@ -19,8 +19,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
batch
,
num_key_value_heads
,
slen
,
head_dim
=
hidden_states
.
shape
batch
,
num_key_value_heads
,
slen
,
head_dim
=
hidden_states
.
shape
if
n_rep
==
1
:
if
n_rep
==
1
:
return
hidden_states
return
hidden_states
hidden_states
=
hidden_states
[:,
:,
None
,
:,
:].
expand
(
batch
,
num_key_value_heads
,
n_rep
,
slen
,
hidden_states
=
hidden_states
[:,
:,
None
,
:,
:].
expand
(
batch
,
num_key_value_heads
,
n_rep
,
slen
,
head_dim
)
head_dim
)
return
hidden_states
.
reshape
(
batch
,
num_key_value_heads
*
n_rep
,
slen
,
head_dim
)
return
hidden_states
.
reshape
(
batch
,
num_key_value_heads
*
n_rep
,
slen
,
head_dim
)
...
@@ -74,14 +73,9 @@ def _fwd_inner(
...
@@ -74,14 +73,9 @@ def _fwd_inner(
return
m_i
,
l_i
,
acc
return
m_i
,
l_i
,
acc
@
triton
.
autotune
(
@
triton
.
autotune
(
configs
=
[
configs
=
[
triton
.
Config
({},
num_warps
=
num_warps
,
num_stages
=
num_stages
)
for
num_warps
in
[
4
,
8
]
for
num_stages
in
[
2
,
4
]],
triton
.
Config
({},
num_warps
=
num_warps
,
num_stages
=
num_stages
)
key
=
[
"gqa_group_size"
,
"BLOCK_N"
,
"BLOCK_D"
,
"BLOCK_H"
],
for
num_warps
in
[
4
,
8
]
\
for
num_stages
in
[
2
,
4
]
\
],
key
=
[
'gqa_group_size'
,
'BLOCK_N'
,
'BLOCK_D'
,
'BLOCK_H'
],
)
)
@
triton
.
jit
@
triton
.
jit
def
_fwd_kernel_varlen
(
def
_fwd_kernel_varlen
(
...
@@ -107,13 +101,12 @@ def _fwd_kernel_varlen(
...
@@ -107,13 +101,12 @@ def _fwd_kernel_varlen(
stride_od
,
stride_od
,
stride_sb
,
stride_sb
,
stride_sh
,
stride_sh
,
stride_sn
,
#bmask shape [b, q_h, seq/BLOCK_N]
stride_sn
,
#
bmask shape [b, q_h, seq/BLOCK_N]
gqa_group_size
:
tl
.
constexpr
,
gqa_group_size
:
tl
.
constexpr
,
BLOCK_H
:
tl
.
constexpr
,
BLOCK_H
:
tl
.
constexpr
,
BLOCK_N
:
tl
.
constexpr
,
BLOCK_N
:
tl
.
constexpr
,
BLOCK_D
:
tl
.
constexpr
,
BLOCK_D
:
tl
.
constexpr
,
):
):
off_z
=
tl
.
program_id
(
0
)
off_z
=
tl
.
program_id
(
0
)
off_h_for_kv
=
tl
.
program_id
(
1
)
off_h_for_kv
=
tl
.
program_id
(
1
)
off_h_q
=
off_h_for_kv
*
gqa_group_size
off_h_q
=
off_h_for_kv
*
gqa_group_size
...
@@ -134,8 +127,7 @@ def _fwd_kernel_varlen(
...
@@ -134,8 +127,7 @@ def _fwd_kernel_varlen(
S_ptrs
=
S
+
off_z
*
stride_sb
+
off_h_q
*
stride_sh
S_ptrs
=
S
+
off_z
*
stride_sb
+
off_h_q
*
stride_sh
mask_h
=
offs_h
<
gqa_group_size
mask_h
=
offs_h
<
gqa_group_size
q
=
tl
.
load
(
q
=
tl
.
load
(
Q_ptrs
+
offs_d
[
None
,
:]
*
stride_qd
+
offs_h
[:,
None
]
*
stride_qh
,
mask
=
mask_h
[:,
None
])
Q_ptrs
+
offs_d
[
None
,
:]
*
stride_qd
+
offs_h
[:,
None
]
*
stride_qh
,
mask
=
mask_h
[:,
None
])
if
s_aux
is
not
None
:
if
s_aux
is
not
None
:
sink
=
tl
.
load
(
s_aux
+
off_h_q
+
offs_h
,
mask
=
mask_h
).
to
(
tl
.
float32
)
sink
=
tl
.
load
(
s_aux
+
off_h_q
+
offs_h
,
mask
=
mask_h
).
to
(
tl
.
float32
)
...
@@ -189,14 +181,12 @@ def _fwd_kernel_varlen(
...
@@ -189,14 +181,12 @@ def _fwd_kernel_varlen(
acc
=
acc
.
to
(
O
.
dtype
.
element_ty
)
acc
=
acc
.
to
(
O
.
dtype
.
element_ty
)
tl
.
store
(
tl
.
store
(
O_ptrs
+
offs_h
[:,
None
]
*
stride_oh
+
offs_d
[
None
,
:]
*
stride_od
,
acc
,
mask
=
mask_h
[:,
None
])
O_ptrs
+
offs_h
[:,
None
]
*
stride_oh
+
offs_d
[
None
,
:]
*
stride_od
,
acc
,
mask
=
mask_h
[:,
None
])
def
get_configs
():
def
get_configs
():
import
itertools
import
itertools
block_N
=
[
64
,
128
]
block_N
=
[
64
,
128
]
block_H
=
[
64
]
block_H
=
[
64
]
num_split
=
[
1
]
num_split
=
[
1
]
...
@@ -204,38 +194,23 @@ def get_configs():
...
@@ -204,38 +194,23 @@ def get_configs():
threads
=
[
128
]
threads
=
[
128
]
_configs
=
list
(
itertools
.
product
(
block_N
,
block_H
,
num_split
,
num_stages
,
threads
))
_configs
=
list
(
itertools
.
product
(
block_N
,
block_H
,
num_split
,
num_stages
,
threads
))
configs
=
[{
configs
=
[{
"block_N"
:
c
[
0
],
"block_H"
:
c
[
1
],
"num_split"
:
c
[
2
],
"num_stages"
:
c
[
3
],
"threads"
:
c
[
4
]}
for
c
in
_configs
]
'block_N'
:
c
[
0
],
'block_H'
:
c
[
1
],
'num_split'
:
c
[
2
],
'num_stages'
:
c
[
3
],
'threads'
:
c
[
4
]
}
for
c
in
_configs
]
return
configs
return
configs
@
autotune
(
configs
=
get_configs
(),
warmup
=
10
,
rep
=
10
)
@
autotune
(
configs
=
get_configs
(),
warmup
=
10
,
rep
=
10
)
@
tilelang
.
jit
(
out_idx
=
[
-
2
,
-
1
],
debug_root_path
=
"./examples/flash_decoding"
)
@
tilelang
.
jit
(
out_idx
=
[
-
2
,
-
1
],
debug_root_path
=
"./examples/flash_decoding"
)
def
flashattn
(
batch
,
def
flashattn
(
heads
,
batch
,
heads
,
k_heads
,
max_seqlen_kv
,
total_seqlen_k
,
dim
,
has_sink
,
block_N
=
128
,
block_H
=
64
,
num_split
=
1
,
num_stages
=
1
,
threads
=
128
k_heads
,
):
max_seqlen_kv
,
scale
=
(
1.0
/
dim
)
**
0.5
*
1.44269504
# log2(e)
total_seqlen_k
,
dim
,
has_sink
,
block_N
=
128
,
block_H
=
64
,
num_split
=
1
,
num_stages
=
1
,
threads
=
128
):
scale
=
(
1.0
/
dim
)
**
0.5
*
1.44269504
# log2(e)
shape_q
=
[
batch
,
heads
,
dim
]
shape_q
=
[
batch
,
heads
,
dim
]
shape_k
=
[
total_seqlen_k
,
k_heads
,
dim
]
shape_k
=
[
total_seqlen_k
,
k_heads
,
dim
]
shape_v
=
[
total_seqlen_k
,
k_heads
,
dim
]
shape_v
=
[
total_seqlen_k
,
k_heads
,
dim
]
shape_o
=
[
batch
,
heads
,
dim
]
shape_o
=
[
batch
,
heads
,
dim
]
shape_s
=
[
batch
,
heads
,
math
.
ceil
(
max_seqlen_kv
/
block_N
)]
shape_s
=
[
batch
,
heads
,
math
.
ceil
(
max_seqlen_kv
/
block_N
)]
dtype
=
"
float16
"
dtype
=
T
.
float16
accum_dtype
=
"
float
"
accum_dtype
=
T
.
float
32
kv_group_num
=
heads
//
k_heads
kv_group_num
=
heads
//
k_heads
valid_block_H
=
min
(
block_H
,
kv_group_num
)
valid_block_H
=
min
(
block_H
,
kv_group_num
)
...
@@ -243,13 +218,13 @@ def flashattn(batch,
...
@@ -243,13 +218,13 @@ def flashattn(batch,
@
T
.
macro
@
T
.
macro
def
flash_attn
(
def
flash_attn
(
Q
:
T
.
Tensor
(
shape_q
,
dtype
),
Q
:
T
.
Tensor
(
shape_q
,
dtype
),
K
:
T
.
Tensor
(
shape_k
,
dtype
),
K
:
T
.
Tensor
(
shape_k
,
dtype
),
V
:
T
.
Tensor
(
shape_v
,
dtype
),
V
:
T
.
Tensor
(
shape_v
,
dtype
),
cu_seqlens_k
:
T
.
Tensor
([
batch
+
1
],
"
int32
"
),
cu_seqlens_k
:
T
.
Tensor
([
batch
+
1
],
T
.
int32
),
s_aux
:
T
.
Tensor
([
heads
],
"
float32
"
),
s_aux
:
T
.
Tensor
([
heads
],
T
.
float32
),
Output
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
Output
:
T
.
Tensor
([
batch
,
heads
,
dim
],
dtype
),
S
:
T
.
Tensor
(
shape_s
,
dtype
),
S
:
T
.
Tensor
(
shape_s
,
dtype
),
):
):
with
T
.
Kernel
(
batch
,
heads
//
valid_block_H
,
num_split
,
threads
=
threads
)
as
(
bx
,
by
,
bz
):
with
T
.
Kernel
(
batch
,
heads
//
valid_block_H
,
num_split
,
threads
=
threads
)
as
(
bx
,
by
,
bz
):
Q_shared
=
T
.
alloc_shared
([
block_H
,
dim
],
dtype
)
Q_shared
=
T
.
alloc_shared
([
block_H
,
dim
],
dtype
)
...
@@ -266,15 +241,17 @@ def flashattn(batch,
...
@@ -266,15 +241,17 @@ def flashattn(batch,
logsum
=
T
.
alloc_fragment
([
block_H
],
accum_dtype
)
logsum
=
T
.
alloc_fragment
([
block_H
],
accum_dtype
)
S_shared
=
T
.
alloc_shared
([
block_H
,
math
.
ceil
(
max_seqlen_kv
/
block_N
)],
dtype
)
S_shared
=
T
.
alloc_shared
([
block_H
,
math
.
ceil
(
max_seqlen_kv
/
block_N
)],
dtype
)
# S_fragment = T.alloc_fragment([block_H, math.ceil(max_seqlen_kv / block_N)], accum_dtype)
# S_fragment = T.alloc_fragment([block_H, math.ceil(max_seqlen_kv / block_N)], accum_dtype)
s_aux_shared
=
T
.
alloc_shared
([
block_H
],
"float32"
)
s_aux_shared
=
T
.
alloc_shared
([
block_H
],
T
.
float32
)
T
.
annotate_layout
({
T
.
annotate_layout
(
# Q_shared: tilelang.layout.make_swizzled_layout(Q_shared),
{
# K_shared: tilelang.layout.make_swizzled_layout(K_shared),
# Q_shared: tilelang.layout.make_swizzled_layout(Q_shared),
# V_shared: tilelang.layout.make_swizzled_layout(V_shared),
# K_shared: tilelang.layout.make_swizzled_layout(K_shared),
# O_shared: tilelang.layout.make_swizzled_layout(O_shared),
# V_shared: tilelang.layout.make_swizzled_layout(V_shared),
# S_shared: tilelang.layout.make_swizzled_layout(S_shared),
# O_shared: tilelang.layout.make_swizzled_layout(O_shared),
})
# S_shared: tilelang.layout.make_swizzled_layout(S_shared),
}
)
bid
=
bx
bid
=
bx
hid
=
by
hid
=
by
...
@@ -284,7 +261,7 @@ def flashattn(batch,
...
@@ -284,7 +261,7 @@ def flashattn(batch,
cur_end_k
=
cu_seqlens_k
[
bid
+
1
]
cur_end_k
=
cu_seqlens_k
[
bid
+
1
]
cur_seqlen_k
=
cur_end_k
-
cur_start_k
cur_seqlen_k
=
cur_end_k
-
cur_start_k
T
.
copy
(
Q
[
bid
,
hid
*
valid_block_H
:
hid
*
valid_block_H
+
block_H
,
:],
Q_shared
)
T
.
copy
(
Q
[
bid
,
hid
*
valid_block_H
:
hid
*
valid_block_H
+
block_H
,
:],
Q_shared
)
T
.
fill
(
acc_o
,
0
)
T
.
fill
(
acc_o
,
0
)
T
.
fill
(
logsum
,
0
)
T
.
fill
(
logsum
,
0
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
...
@@ -292,15 +269,13 @@ def flashattn(batch,
...
@@ -292,15 +269,13 @@ def flashattn(batch,
# loop_range = T.ceildiv((seqlen_kv // num_split), block_N)
# loop_range = T.ceildiv((seqlen_kv // num_split), block_N)
loop_range
=
T
.
ceildiv
((
cur_seqlen_k
//
num_split
),
block_N
)
loop_range
=
T
.
ceildiv
((
cur_seqlen_k
//
num_split
),
block_N
)
for
k
in
T
.
Pipelined
(
loop_range
,
num_stages
=
num_stages
):
for
k
in
T
.
Pipelined
(
loop_range
,
num_stages
=
num_stages
):
T
.
copy
(
K
[
cur_start_k
+
k
*
block_N
:
cur_start_k
+
(
k
+
1
)
*
block_N
,
cur_kv_head
,
:],
T
.
copy
(
K
[
cur_start_k
+
k
*
block_N
:
cur_start_k
+
(
k
+
1
)
*
block_N
,
cur_kv_head
,
:],
K_shared
)
K_shared
)
T
.
clear
(
acc_s
)
T
.
clear
(
acc_s
)
T
.
gemm
(
Q_shared
,
K_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
T
.
gemm
(
Q_shared
,
K_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
for
i
,
j
in
T
.
Parallel
(
block_H
,
block_N
):
for
i
,
j
in
T
.
Parallel
(
block_H
,
block_N
):
# acc_s[i, j] = T.if_then_else(mask_local[j] != 0 and k * block_N + j < cur_seqlen_k, acc_s[i, j],
# acc_s[i, j] = T.if_then_else(mask_local[j] != 0 and k * block_N + j < cur_seqlen_k, acc_s[i, j],
# -T.infinity(accum_dtype))
# -T.infinity(accum_dtype))
acc_s
[
i
,
j
]
=
T
.
if_then_else
(
k
*
block_N
+
j
<
cur_seqlen_k
,
acc_s
[
i
,
j
],
acc_s
[
i
,
j
]
=
T
.
if_then_else
(
k
*
block_N
+
j
<
cur_seqlen_k
,
acc_s
[
i
,
j
],
-
T
.
infinity
(
accum_dtype
))
-
T
.
infinity
(
accum_dtype
))
T
.
copy
(
scores_max
,
scores_max_prev
)
T
.
copy
(
scores_max
,
scores_max_prev
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
T
.
reduce_max
(
acc_s
,
scores_max
,
dim
=
1
,
clear
=
False
)
T
.
reduce_max
(
acc_s
,
scores_max
,
dim
=
1
,
clear
=
False
)
...
@@ -320,12 +295,11 @@ def flashattn(batch,
...
@@ -320,12 +295,11 @@ def flashattn(batch,
T
.
copy
(
acc_s
,
acc_s_cast
)
T
.
copy
(
acc_s
,
acc_s_cast
)
for
i
,
j
in
T
.
Parallel
(
block_H
,
dim
):
for
i
,
j
in
T
.
Parallel
(
block_H
,
dim
):
acc_o
[
i
,
j
]
*=
scores_scale
[
i
]
acc_o
[
i
,
j
]
*=
scores_scale
[
i
]
T
.
copy
(
V
[
cur_start_k
+
k
*
block_N
:
cur_start_k
+
(
k
+
1
)
*
block_N
,
cur_kv_head
,
:],
T
.
copy
(
V
[
cur_start_k
+
k
*
block_N
:
cur_start_k
+
(
k
+
1
)
*
block_N
,
cur_kv_head
,
:],
V_shared
)
V_shared
)
T
.
gemm
(
acc_s_cast
,
V_shared
,
acc_o
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
T
.
gemm
(
acc_s_cast
,
V_shared
,
acc_o
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
if
has_sink
:
if
has_sink
:
T
.
copy
(
s_aux
[
hid
*
valid_block_H
:
hid
*
valid_block_H
+
block_H
],
s_aux_shared
)
T
.
copy
(
s_aux
[
hid
*
valid_block_H
:
hid
*
valid_block_H
+
block_H
],
s_aux_shared
)
for
i
in
T
.
Parallel
(
block_H
):
for
i
in
T
.
Parallel
(
block_H
):
logsum
[
i
]
+=
s_aux_shared
[
i
]
logsum
[
i
]
+=
s_aux_shared
[
i
]
for
i
,
j
in
T
.
Parallel
(
block_H
,
dim
):
for
i
,
j
in
T
.
Parallel
(
block_H
,
dim
):
...
@@ -338,20 +312,19 @@ def flashattn(batch,
...
@@ -338,20 +312,19 @@ def flashattn(batch,
for
i
in
T
.
Parallel
(
block_H
):
for
i
in
T
.
Parallel
(
block_H
):
logsum
[
i
]
=
T
.
log2
(
logsum
[
i
])
+
scores_max
[
i
]
*
scale
logsum
[
i
]
=
T
.
log2
(
logsum
[
i
])
+
scores_max
[
i
]
*
scale
T
.
copy
(
acc_o
[:
valid_block_H
,
:],
O_shared
)
T
.
copy
(
acc_o
[:
valid_block_H
,
:],
O_shared
)
T
.
copy
(
O_shared
,
Output
[
bid
,
hid
*
valid_block_H
:
(
hid
+
1
)
*
valid_block_H
,
:])
T
.
copy
(
O_shared
,
Output
[
bid
,
hid
*
valid_block_H
:
(
hid
+
1
)
*
valid_block_H
,
:])
# T.copy(S_fragment, S_shared)
# T.copy(S_fragment, S_shared)
T
.
copy
(
S_shared
[:
valid_block_H
,
:],
S
[
bid
,
T
.
copy
(
S_shared
[:
valid_block_H
,
:],
S
[
bid
,
hid
*
valid_block_H
:
(
hid
+
1
)
*
valid_block_H
,
:])
hid
*
valid_block_H
:(
hid
+
1
)
*
valid_block_H
,
:])
@
T
.
prim_func
@
T
.
prim_func
def
flashattn_gqa_decode_no_split
(
def
flashattn_gqa_decode_no_split
(
Q
:
T
.
Tensor
(
shape_q
,
dtype
),
Q
:
T
.
Tensor
(
shape_q
,
dtype
),
K
:
T
.
Tensor
(
shape_k
,
dtype
),
K
:
T
.
Tensor
(
shape_k
,
dtype
),
V
:
T
.
Tensor
(
shape_v
,
dtype
),
V
:
T
.
Tensor
(
shape_v
,
dtype
),
cu_seqlens_k
:
T
.
Tensor
([
batch
+
1
],
"
int32
"
),
cu_seqlens_k
:
T
.
Tensor
([
batch
+
1
],
T
.
int32
),
s_aux
:
T
.
Tensor
([
heads
],
"
float32
"
),
s_aux
:
T
.
Tensor
([
heads
],
T
.
float32
),
Output
:
T
.
Tensor
(
shape_o
,
dtype
),
Output
:
T
.
Tensor
(
shape_o
,
dtype
),
S
:
T
.
Tensor
(
shape_s
,
dtype
),
S
:
T
.
Tensor
(
shape_s
,
dtype
),
):
):
flash_attn
(
Q
,
K
,
V
,
cu_seqlens_k
,
s_aux
,
Output
,
S
)
flash_attn
(
Q
,
K
,
V
,
cu_seqlens_k
,
s_aux
,
Output
,
S
)
...
@@ -388,9 +361,7 @@ def flash_attn_with_attn_pool_decode_tilelang(
...
@@ -388,9 +361,7 @@ def flash_attn_with_attn_pool_decode_tilelang(
gqa_group_size
=
q_h
//
k_h
gqa_group_size
=
q_h
//
k_h
O_tl
=
torch
.
zeros_like
(
Q
)
O_tl
=
torch
.
zeros_like
(
Q
)
S_tl
=
torch
.
zeros
((
batch
,
q_h
,
math
.
ceil
(
real_max_k_seqlen
/
block_size
)),
S_tl
=
torch
.
zeros
((
batch
,
q_h
,
math
.
ceil
(
real_max_k_seqlen
/
block_size
)),
dtype
=
Q
.
dtype
,
device
=
Q
.
device
)
dtype
=
Q
.
dtype
,
device
=
Q
.
device
)
O_tl
,
S_tl
=
tl_kernel
(
Q
,
K
,
V
,
cu_seqlens_k
,
s_aux
)
O_tl
,
S_tl
=
tl_kernel
(
Q
,
K
,
V
,
cu_seqlens_k
,
s_aux
)
if
use_per_kv_head_sparse_index
:
if
use_per_kv_head_sparse_index
:
...
@@ -433,9 +404,7 @@ def flash_attn_with_attn_pool_decode(
...
@@ -433,9 +404,7 @@ def flash_attn_with_attn_pool_decode(
BLOCK_H
=
64
BLOCK_H
=
64
O
=
torch
.
zeros_like
(
Q
)
O
=
torch
.
zeros_like
(
Q
)
S
=
torch
.
zeros
((
batch
,
q_h
,
math
.
ceil
(
max_seqlen_k
/
block_size
)),
S
=
torch
.
zeros
((
batch
,
q_h
,
math
.
ceil
(
max_seqlen_k
/
block_size
)),
dtype
=
Q
.
dtype
,
device
=
Q
.
device
)
dtype
=
Q
.
dtype
,
device
=
Q
.
device
)
def
grid
(
META
):
def
grid
(
META
):
return
(
batch
,
k_h
)
return
(
batch
,
k_h
)
...
@@ -480,18 +449,18 @@ def test_equal_seqlen_decode_main(args):
...
@@ -480,18 +449,18 @@ def test_equal_seqlen_decode_main(args):
real_max_k_seqlen
=
args
.
k_seqlen
real_max_k_seqlen
=
args
.
k_seqlen
head_size
=
args
.
head_size
head_size
=
args
.
head_size
block_size
=
args
.
block_size
block_size
=
args
.
block_size
dtype
=
torch
.
bfloat16
if
args
.
dtype
==
"
bfloat16
"
else
torch
.
float16
dtype
=
torch
.
bfloat16
if
args
.
dtype
==
T
.
bfloat16
else
torch
.
float16
# For decode, query is just 1 token per batch
# For decode, query is just 1 token per batch
q
=
torch
.
randn
(
batch_size
,
q_heads
,
head_size
,
device
=
'
cuda
'
,
dtype
=
dtype
)
q
=
torch
.
randn
(
batch_size
,
q_heads
,
head_size
,
device
=
"
cuda
"
,
dtype
=
dtype
)
k
=
torch
.
randn
(
batch_size
,
kv_heads
,
k_seqlen
,
head_size
,
device
=
'
cuda
'
,
dtype
=
dtype
)
k
=
torch
.
randn
(
batch_size
,
kv_heads
,
k_seqlen
,
head_size
,
device
=
"
cuda
"
,
dtype
=
dtype
)
v
=
torch
.
randn
(
batch_size
,
kv_heads
,
k_seqlen
,
head_size
,
device
=
'
cuda
'
,
dtype
=
dtype
)
v
=
torch
.
randn
(
batch_size
,
kv_heads
,
k_seqlen
,
head_size
,
device
=
"
cuda
"
,
dtype
=
dtype
)
softmax_scale
=
1.0
/
math
.
sqrt
(
head_size
)
softmax_scale
=
1.0
/
math
.
sqrt
(
head_size
)
# Generate sink values if needed
# Generate sink values if needed
sink
=
None
sink
=
None
if
args
.
test_sink
:
if
args
.
test_sink
:
sink
=
torch
.
randn
(
q_heads
,
device
=
'
cuda
'
,
dtype
=
torch
.
float32
)
*
0.1
# Small sink values
sink
=
torch
.
randn
(
q_heads
,
device
=
"
cuda
"
,
dtype
=
torch
.
float32
)
*
0.1
# Small sink values
print
(
f
"Using sink attention with sink values:
{
sink
}
"
)
print
(
f
"Using sink attention with sink values:
{
sink
}
"
)
# Convert to varlen format for K, V
# Convert to varlen format for K, V
...
@@ -499,8 +468,7 @@ def test_equal_seqlen_decode_main(args):
...
@@ -499,8 +468,7 @@ def test_equal_seqlen_decode_main(args):
v_varlen
=
v
.
transpose
(
1
,
2
).
reshape
(
batch_size
*
k_seqlen
,
kv_heads
,
head_size
)
v_varlen
=
v
.
transpose
(
1
,
2
).
reshape
(
batch_size
*
k_seqlen
,
kv_heads
,
head_size
)
# Generate cumulative sequence lengths
# Generate cumulative sequence lengths
cu_seqlens_k
=
torch
.
arange
(
cu_seqlens_k
=
torch
.
arange
(
0
,
(
batch_size
+
1
)
*
k_seqlen
,
k_seqlen
,
device
=
"cuda"
,
dtype
=
torch
.
int32
)
0
,
(
batch_size
+
1
)
*
k_seqlen
,
k_seqlen
,
device
=
'cuda'
,
dtype
=
torch
.
int32
)
max_seqlen_k
=
k_seqlen
max_seqlen_k
=
k_seqlen
print
(
f
"q shape:
{
q
.
shape
}
"
)
print
(
f
"q shape:
{
q
.
shape
}
"
)
...
@@ -510,8 +478,7 @@ def test_equal_seqlen_decode_main(args):
...
@@ -510,8 +478,7 @@ def test_equal_seqlen_decode_main(args):
num_tokens
,
q_h
,
head_size
=
q
.
shape
num_tokens
,
q_h
,
head_size
=
q
.
shape
batch
=
cu_seqlens_k
.
size
(
0
)
-
1
batch
=
cu_seqlens_k
.
size
(
0
)
-
1
k_h
=
k_varlen
.
size
(
1
)
k_h
=
k_varlen
.
size
(
1
)
tl_kernel
=
flashattn
(
batch
,
q_h
,
k_h
,
args
.
k_seqlen
,
cu_seqlens_k
[
-
1
].
item
(),
head_size
,
tl_kernel
=
flashattn
(
batch
,
q_h
,
k_h
,
args
.
k_seqlen
,
cu_seqlens_k
[
-
1
].
item
(),
head_size
,
args
.
test_sink
)
args
.
test_sink
)
# Test our decode kernel
# Test our decode kernel
O_triton
,
S_triton
=
flash_attn_with_attn_pool_decode
(
O_triton
,
S_triton
=
flash_attn_with_attn_pool_decode
(
...
@@ -524,7 +491,8 @@ def test_equal_seqlen_decode_main(args):
...
@@ -524,7 +491,8 @@ def test_equal_seqlen_decode_main(args):
args
.
num_split
,
args
.
num_split
,
softmax_scale
,
softmax_scale
,
s_aux
=
sink
,
s_aux
=
sink
,
block_size
=
block_size
)
block_size
=
block_size
,
)
O_tilelang
,
S_tilelang
=
flash_attn_with_attn_pool_decode_tilelang
(
O_tilelang
,
S_tilelang
=
flash_attn_with_attn_pool_decode_tilelang
(
q
,
q
,
k_varlen
,
k_varlen
,
...
@@ -539,9 +507,7 @@ def test_equal_seqlen_decode_main(args):
...
@@ -539,9 +507,7 @@ def test_equal_seqlen_decode_main(args):
tl_kernel
=
tl_kernel
,
tl_kernel
=
tl_kernel
,
)
)
for
i
in
range
(
batch_size
):
for
i
in
range
(
batch_size
):
S_tilelang
[
i
,
:,
S_tilelang
[
i
,
:,
math
.
ceil
((
cu_seqlens_k
[
i
+
1
].
item
()
-
cu_seqlens_k
[
i
].
item
())
/
block_size
)
:]
=
0
math
.
ceil
((
cu_seqlens_k
[
i
+
1
].
item
()
-
cu_seqlens_k
[
i
].
item
())
/
block_size
):]
=
0
# Compute torch reference
# Compute torch reference
q_expanded
=
q
.
unsqueeze
(
2
)
# [b, q_heads, 1, head_size]
q_expanded
=
q
.
unsqueeze
(
2
)
# [b, q_heads, 1, head_size]
...
@@ -550,14 +516,12 @@ def test_equal_seqlen_decode_main(args):
...
@@ -550,14 +516,12 @@ def test_equal_seqlen_decode_main(args):
if
sink
is
None
:
if
sink
is
None
:
# Standard scaled dot-product attention
# Standard scaled dot-product attention
logits
=
torch
.
matmul
(
q_expanded
,
k_repeat
.
transpose
(
logits
=
torch
.
matmul
(
q_expanded
,
k_repeat
.
transpose
(
-
2
,
-
1
))
*
softmax_scale
# [batch, q_heads, 1, seqlen_k]
-
2
,
-
1
))
*
softmax_scale
# [batch, q_heads, 1, seqlen_k]
attn_weights
=
torch
.
softmax
(
logits
,
dim
=-
1
)
attn_weights
=
torch
.
softmax
(
logits
,
dim
=-
1
)
O_torch
=
torch
.
matmul
(
attn_weights
,
v_repeat
).
squeeze
(
2
)
# [batch, q_heads, head_size]
O_torch
=
torch
.
matmul
(
attn_weights
,
v_repeat
).
squeeze
(
2
)
# [batch, q_heads, head_size]
else
:
else
:
# s_aux attention
# s_aux attention
logits
=
torch
.
matmul
(
q_expanded
,
k_repeat
.
transpose
(
logits
=
torch
.
matmul
(
q_expanded
,
k_repeat
.
transpose
(
-
2
,
-
1
))
*
softmax_scale
# [batch, q_heads, 1, seqlen_k]
-
2
,
-
1
))
*
softmax_scale
# [batch, q_heads, 1, seqlen_k]
sink_expanded
=
sink
.
view
(
1
,
q_heads
,
1
,
1
)
# [1, q_heads, 1, 1]
sink_expanded
=
sink
.
view
(
1
,
q_heads
,
1
,
1
)
# [1, q_heads, 1, 1]
logits_max
=
torch
.
max
(
logits
,
dim
=-
1
,
keepdim
=
True
).
values
logits_max
=
torch
.
max
(
logits
,
dim
=-
1
,
keepdim
=
True
).
values
...
@@ -566,15 +530,15 @@ def test_equal_seqlen_decode_main(args):
...
@@ -566,15 +530,15 @@ def test_equal_seqlen_decode_main(args):
unnormalized_scores
=
torch
.
exp
(
logits
-
logits_or_sinks_max
)
unnormalized_scores
=
torch
.
exp
(
logits
-
logits_or_sinks_max
)
normalizer
=
unnormalized_scores
.
sum
(
dim
=-
1
,
keepdim
=
True
)
+
sinks
normalizer
=
unnormalized_scores
.
sum
(
dim
=-
1
,
keepdim
=
True
)
+
sinks
attn_weights
=
unnormalized_scores
/
normalizer
attn_weights
=
unnormalized_scores
/
normalizer
O_torch
=
torch
.
matmul
(
attn_weights
.
to
(
v_repeat
.
dtype
),
O_torch
=
torch
.
matmul
(
attn_weights
.
to
(
v_repeat
.
dtype
),
v_repeat
).
squeeze
(
2
)
# [batch, q_heads, head_size]
v_repeat
).
squeeze
(
2
)
# [batch, q_heads, head_size]
# Compute attention score pooling
# Compute attention score pooling
attn_score_pooled
=
torch
.
max_pool2d
(
attn_score_pooled
=
torch
.
max_pool2d
(
attn_weights
.
squeeze
(
2
),
# [b, q_heads, k_seqlen]
attn_weights
.
squeeze
(
2
),
# [b, q_heads, k_seqlen]
kernel_size
=
(
q_heads
,
block_size
),
kernel_size
=
(
q_heads
,
block_size
),
stride
=
(
q_heads
,
block_size
),
stride
=
(
q_heads
,
block_size
),
ceil_mode
=
True
).
to
(
torch
.
float16
)
ceil_mode
=
True
,
).
to
(
torch
.
float16
)
print
(
"S_tilelang"
,
S_tilelang
)
print
(
"S_tilelang"
,
S_tilelang
)
print
(
"attn_score_pooled"
,
attn_score_pooled
)
print
(
"attn_score_pooled"
,
attn_score_pooled
)
...
@@ -588,15 +552,10 @@ def test_equal_seqlen_decode_main(args):
...
@@ -588,15 +552,10 @@ def test_equal_seqlen_decode_main(args):
print
(
f
"Max difference in S:
{
max_diff_s
.
item
()
}
"
)
print
(
f
"Max difference in S:
{
max_diff_s
.
item
()
}
"
)
print
(
f
"Max difference in O_tilelang:
{
max_diff_o_tilelang
.
item
()
}
"
)
print
(
f
"Max difference in O_tilelang:
{
max_diff_o_tilelang
.
item
()
}
"
)
print
(
f
"Max difference in S_tilelang:
{
max_diff_s_tilelang
.
item
()
}
"
)
print
(
f
"Max difference in S_tilelang:
{
max_diff_s_tilelang
.
item
()
}
"
)
assert
torch
.
allclose
(
assert
torch
.
allclose
(
O_triton
,
O_torch
,
atol
=
1e-2
,
rtol
=
1e-2
),
f
"Output mismatch:
{
max_diff_o
.
item
()
}
"
O_triton
,
O_torch
,
atol
=
1e-2
,
rtol
=
1e-2
),
f
"Output mismatch:
{
max_diff_o
.
item
()
}
"
assert
torch
.
allclose
(
S_triton
,
attn_score_pooled
,
atol
=
1e-2
,
rtol
=
1e-2
),
f
"Score mismatch:
{
max_diff_s
.
item
()
}
"
assert
torch
.
allclose
(
assert
torch
.
allclose
(
O_tilelang
,
O_torch
,
atol
=
1e-2
,
rtol
=
1e-2
),
f
"Output mismatch:
{
max_diff_o_tilelang
.
item
()
}
"
S_triton
,
attn_score_pooled
,
atol
=
1e-2
,
rtol
=
1e-2
),
f
"Score mismatch:
{
max_diff_s
.
item
()
}
"
assert
torch
.
allclose
(
S_tilelang
,
attn_score_pooled
,
atol
=
1e-2
,
rtol
=
1e-2
),
f
"Score mismatch:
{
max_diff_s_tilelang
.
item
()
}
"
assert
torch
.
allclose
(
O_tilelang
,
O_torch
,
atol
=
1e-2
,
rtol
=
1e-2
),
f
"Output mismatch:
{
max_diff_o_tilelang
.
item
()
}
"
assert
torch
.
allclose
(
S_tilelang
,
attn_score_pooled
,
atol
=
1e-2
,
rtol
=
1e-2
),
f
"Score mismatch:
{
max_diff_s_tilelang
.
item
()
}
"
print
(
"✅ All tests passed!"
)
print
(
"✅ All tests passed!"
)
...
@@ -609,14 +568,14 @@ def test_varlen_decode_main(args):
...
@@ -609,14 +568,14 @@ def test_varlen_decode_main(args):
real_max_k_seqlen
=
args
.
k_seqlen
real_max_k_seqlen
=
args
.
k_seqlen
head_size
=
args
.
head_size
head_size
=
args
.
head_size
block_size
=
args
.
block_size
block_size
=
args
.
block_size
dtype
=
torch
.
bfloat16
if
args
.
dtype
==
"
bfloat16
"
else
torch
.
float16
dtype
=
torch
.
bfloat16
if
args
.
dtype
==
T
.
bfloat16
else
torch
.
float16
print
(
f
"Testing decode kernel with variable sequence lengths (max_k_seqlen=
{
max_k_seqlen
}
)"
)
print
(
f
"Testing decode kernel with variable sequence lengths (max_k_seqlen=
{
max_k_seqlen
}
)"
)
# Generate sink values if needed
# Generate sink values if needed
sink
=
None
sink
=
None
if
args
.
test_sink
:
if
args
.
test_sink
:
sink
=
torch
.
randn
(
q_heads
,
device
=
'
cuda
'
,
dtype
=
torch
.
float32
)
*
0.1
# Small sink values
sink
=
torch
.
randn
(
q_heads
,
device
=
"
cuda
"
,
dtype
=
torch
.
float32
)
*
0.1
# Small sink values
print
(
f
"Using sink attention with sink values:
{
sink
}
"
)
print
(
f
"Using sink attention with sink values:
{
sink
}
"
)
# Generate variable length k sequences
# Generate variable length k sequences
...
@@ -624,7 +583,7 @@ def test_varlen_decode_main(args):
...
@@ -624,7 +583,7 @@ def test_varlen_decode_main(args):
print
(
f
"k_seqlens:
{
k_seqlens
}
"
)
print
(
f
"k_seqlens:
{
k_seqlens
}
"
)
# Generate cumulative sequence lengths for k
# Generate cumulative sequence lengths for k
cu_seqlens_k
=
torch
.
zeros
(
batch_size
+
1
,
device
=
'
cuda
'
,
dtype
=
torch
.
int32
)
cu_seqlens_k
=
torch
.
zeros
(
batch_size
+
1
,
device
=
"
cuda
"
,
dtype
=
torch
.
int32
)
total_k_tokens
=
0
total_k_tokens
=
0
for
i
in
range
(
batch_size
):
for
i
in
range
(
batch_size
):
cu_seqlens_k
[
i
]
=
total_k_tokens
cu_seqlens_k
[
i
]
=
total_k_tokens
...
@@ -634,9 +593,9 @@ def test_varlen_decode_main(args):
...
@@ -634,9 +593,9 @@ def test_varlen_decode_main(args):
print
(
f
"cu_seqlens_k:
{
cu_seqlens_k
}
"
)
print
(
f
"cu_seqlens_k:
{
cu_seqlens_k
}
"
)
# Generate tensors - Q is [batch_size, q_heads, head_size] for decode
# Generate tensors - Q is [batch_size, q_heads, head_size] for decode
q_decode
=
torch
.
randn
(
batch_size
,
q_heads
,
head_size
,
device
=
'
cuda
'
,
dtype
=
dtype
)
q_decode
=
torch
.
randn
(
batch_size
,
q_heads
,
head_size
,
device
=
"
cuda
"
,
dtype
=
dtype
)
k_varlen
=
torch
.
randn
(
total_k_tokens
,
kv_heads
,
head_size
,
device
=
'
cuda
'
,
dtype
=
dtype
)
k_varlen
=
torch
.
randn
(
total_k_tokens
,
kv_heads
,
head_size
,
device
=
"
cuda
"
,
dtype
=
dtype
)
v_varlen
=
torch
.
randn
(
total_k_tokens
,
kv_heads
,
head_size
,
device
=
'
cuda
'
,
dtype
=
dtype
)
v_varlen
=
torch
.
randn
(
total_k_tokens
,
kv_heads
,
head_size
,
device
=
"
cuda
"
,
dtype
=
dtype
)
softmax_scale
=
1.0
/
math
.
sqrt
(
head_size
)
softmax_scale
=
1.0
/
math
.
sqrt
(
head_size
)
max_seqlen_k
=
int
(
k_seqlens
.
max
())
max_seqlen_k
=
int
(
k_seqlens
.
max
())
...
@@ -649,8 +608,7 @@ def test_varlen_decode_main(args):
...
@@ -649,8 +608,7 @@ def test_varlen_decode_main(args):
num_tokens
,
q_h
,
head_size
=
q_decode
.
shape
num_tokens
,
q_h
,
head_size
=
q_decode
.
shape
batch
=
cu_seqlens_k
.
size
(
0
)
-
1
batch
=
cu_seqlens_k
.
size
(
0
)
-
1
k_h
=
k_varlen
.
size
(
1
)
k_h
=
k_varlen
.
size
(
1
)
tl_kernel
=
flashattn
(
batch
,
q_h
,
k_h
,
args
.
k_seqlen
,
cu_seqlens_k
[
-
1
].
item
(),
head_size
,
tl_kernel
=
flashattn
(
batch
,
q_h
,
k_h
,
args
.
k_seqlen
,
cu_seqlens_k
[
-
1
].
item
(),
head_size
,
args
.
test_sink
)
args
.
test_sink
)
# Test our decode kernel
# Test our decode kernel
O_triton
,
S_triton
=
flash_attn_with_attn_pool_decode
(
O_triton
,
S_triton
=
flash_attn_with_attn_pool_decode
(
...
@@ -663,7 +621,8 @@ def test_varlen_decode_main(args):
...
@@ -663,7 +621,8 @@ def test_varlen_decode_main(args):
args
.
num_split
,
args
.
num_split
,
softmax_scale
,
softmax_scale
,
s_aux
=
sink
,
s_aux
=
sink
,
block_size
=
block_size
)
block_size
=
block_size
,
)
O_tilelang
,
S_tilelang
=
flash_attn_with_attn_pool_decode_tilelang
(
O_tilelang
,
S_tilelang
=
flash_attn_with_attn_pool_decode_tilelang
(
q_decode
,
q_decode
,
k_varlen
,
k_varlen
,
...
@@ -678,9 +637,7 @@ def test_varlen_decode_main(args):
...
@@ -678,9 +637,7 @@ def test_varlen_decode_main(args):
tl_kernel
=
tl_kernel
,
tl_kernel
=
tl_kernel
,
)
)
for
i
in
range
(
batch_size
):
for
i
in
range
(
batch_size
):
S_tilelang
[
i
,
:,
S_tilelang
[
i
,
:,
math
.
ceil
((
cu_seqlens_k
[
i
+
1
].
item
()
-
cu_seqlens_k
[
i
].
item
())
/
block_size
)
:]
=
0
math
.
ceil
((
cu_seqlens_k
[
i
+
1
].
item
()
-
cu_seqlens_k
[
i
].
item
())
/
block_size
):]
=
0
# Create torch reference - pad tensors for comparison
# Create torch reference - pad tensors for comparison
k_padded_list
=
[]
k_padded_list
=
[]
...
@@ -694,8 +651,8 @@ def test_varlen_decode_main(args):
...
@@ -694,8 +651,8 @@ def test_varlen_decode_main(args):
k_end
=
cu_seqlens_k
[
i
+
1
]
k_end
=
cu_seqlens_k
[
i
+
1
]
# Pad to max_seqlen_k
# Pad to max_seqlen_k
k_padded
=
torch
.
zeros
(
max_seqlen_k
,
kv_heads
,
head_size
,
device
=
'
cuda
'
,
dtype
=
dtype
)
k_padded
=
torch
.
zeros
(
max_seqlen_k
,
kv_heads
,
head_size
,
device
=
"
cuda
"
,
dtype
=
dtype
)
v_padded
=
torch
.
zeros
(
max_seqlen_k
,
kv_heads
,
head_size
,
device
=
'
cuda
'
,
dtype
=
dtype
)
v_padded
=
torch
.
zeros
(
max_seqlen_k
,
kv_heads
,
head_size
,
device
=
"
cuda
"
,
dtype
=
dtype
)
k_padded
[:
actual_k_len
]
=
k_varlen
[
k_start
:
k_end
]
k_padded
[:
actual_k_len
]
=
k_varlen
[
k_start
:
k_end
]
v_padded
[:
actual_k_len
]
=
v_varlen
[
k_start
:
k_end
]
v_padded
[:
actual_k_len
]
=
v_varlen
[
k_start
:
k_end
]
...
@@ -704,10 +661,8 @@ def test_varlen_decode_main(args):
...
@@ -704,10 +661,8 @@ def test_varlen_decode_main(args):
v_padded_list
.
append
(
v_padded
)
v_padded_list
.
append
(
v_padded
)
# Stack to create batched tensors [b, max_seqlen, kv_heads, head_size]
# Stack to create batched tensors [b, max_seqlen, kv_heads, head_size]
k_padded_batched
=
torch
.
stack
(
k_padded_batched
=
torch
.
stack
(
k_padded_list
,
dim
=
0
).
transpose
(
1
,
2
)
# [b, kv_heads, max_seqlen, head_size]
k_padded_list
,
dim
=
0
).
transpose
(
1
,
2
)
# [b, kv_heads, max_seqlen, head_size]
v_padded_batched
=
torch
.
stack
(
v_padded_list
,
dim
=
0
).
transpose
(
1
,
2
)
# [b, kv_heads, max_seqlen, head_size]
v_padded_batched
=
torch
.
stack
(
v_padded_list
,
dim
=
0
).
transpose
(
1
,
2
)
# [b, kv_heads, max_seqlen, head_size]
# Expand q to match kv heads: [b, q_heads, 1, head_size]
# Expand q to match kv heads: [b, q_heads, 1, head_size]
q_expanded
=
q_decode
.
unsqueeze
(
2
)
# [b, q_heads, 1, head_size]
q_expanded
=
q_decode
.
unsqueeze
(
2
)
# [b, q_heads, 1, head_size]
...
@@ -717,20 +672,17 @@ def test_varlen_decode_main(args):
...
@@ -717,20 +672,17 @@ def test_varlen_decode_main(args):
print
(
f
"v_padded_batched shape:
{
v_padded_batched
.
shape
}
"
)
print
(
f
"v_padded_batched shape:
{
v_padded_batched
.
shape
}
"
)
# Compute torch reference
# Compute torch reference
k_repeat
=
repeat_kv
(
k_padded_batched
,
k_repeat
=
repeat_kv
(
k_padded_batched
,
q_heads
//
kv_heads
)
# [b, q_heads, max_seqlen, head_size]
q_heads
//
kv_heads
)
# [b, q_heads, max_seqlen, head_size]
v_repeat
=
repeat_kv
(
v_padded_batched
,
q_heads
//
kv_heads
)
# [b, q_heads, max_seqlen, head_size]
v_repeat
=
repeat_kv
(
v_padded_batched
,
q_heads
//
kv_heads
)
# [b, q_heads, max_seqlen, head_size]
if
sink
is
None
:
if
sink
is
None
:
# Standard attention computation: [b, q_heads, 1, head_size] @ [b, q_heads, head_size, max_seqlen]
# Standard attention computation: [b, q_heads, 1, head_size] @ [b, q_heads, head_size, max_seqlen]
attn_score
=
torch
.
matmul
(
q_expanded
,
k_repeat
.
transpose
(
attn_score
=
torch
.
matmul
(
q_expanded
,
k_repeat
.
transpose
(
-
2
,
-
1
))
*
softmax_scale
# [b, q_heads, 1, max_seqlen]
-
2
,
-
1
))
*
softmax_scale
# [b, q_heads, 1, max_seqlen]
# Apply sequence length masking
# Apply sequence length masking
for
i
in
range
(
batch_size
):
for
i
in
range
(
batch_size
):
actual_k_len
=
k_seqlens
[
i
]
actual_k_len
=
k_seqlens
[
i
]
attn_score
[
i
,
:,
:,
actual_k_len
:]
=
float
(
'
-inf
'
)
attn_score
[
i
,
:,
:,
actual_k_len
:]
=
float
(
"
-inf
"
)
attn_weights
=
attn_score
.
softmax
(
dim
=-
1
)
# [b, q_heads, 1, max_seqlen]
attn_weights
=
attn_score
.
softmax
(
dim
=-
1
)
# [b, q_heads, 1, max_seqlen]
...
@@ -743,13 +695,12 @@ def test_varlen_decode_main(args):
...
@@ -743,13 +695,12 @@ def test_varlen_decode_main(args):
O_torch
=
torch
.
matmul
(
attn_weights
,
v_repeat
)
# [b, q_heads, 1, head_size]
O_torch
=
torch
.
matmul
(
attn_weights
,
v_repeat
)
# [b, q_heads, 1, head_size]
else
:
else
:
# s_aux attention
# s_aux attention
logits
=
torch
.
matmul
(
q_expanded
,
k_repeat
.
transpose
(
logits
=
torch
.
matmul
(
q_expanded
,
k_repeat
.
transpose
(
-
2
,
-
1
))
*
softmax_scale
# [b, q_heads, 1, max_seqlen]
-
2
,
-
1
))
*
softmax_scale
# [b, q_heads, 1, max_seqlen]
# Apply sequence length masking
# Apply sequence length masking
for
i
in
range
(
batch_size
):
for
i
in
range
(
batch_size
):
actual_k_len
=
k_seqlens
[
i
]
actual_k_len
=
k_seqlens
[
i
]
logits
[
i
,
:,
:,
actual_k_len
:]
=
float
(
'
-inf
'
)
logits
[
i
,
:,
:,
actual_k_len
:]
=
float
(
"
-inf
"
)
sink_expanded
=
sink
.
view
(
1
,
q_heads
,
1
,
1
)
# [1, q_heads, 1, 1]
sink_expanded
=
sink
.
view
(
1
,
q_heads
,
1
,
1
)
# [1, q_heads, 1, 1]
logits_max
=
torch
.
max
(
logits
,
dim
=-
1
,
keepdim
=
True
).
values
logits_max
=
torch
.
max
(
logits
,
dim
=-
1
,
keepdim
=
True
).
values
...
@@ -765,8 +716,7 @@ def test_varlen_decode_main(args):
...
@@ -765,8 +716,7 @@ def test_varlen_decode_main(args):
attn_weights
[
i
,
:,
:,
actual_k_len
:]
=
0.0
attn_weights
[
i
,
:,
:,
actual_k_len
:]
=
0.0
# Compute output: [b, q_heads, 1, max_seqlen] @ [b, q_heads, max_seqlen, head_size]
# Compute output: [b, q_heads, 1, max_seqlen] @ [b, q_heads, max_seqlen, head_size]
O_torch
=
torch
.
matmul
(
attn_weights
.
to
(
v_repeat
.
dtype
),
O_torch
=
torch
.
matmul
(
attn_weights
.
to
(
v_repeat
.
dtype
),
v_repeat
)
# [b, q_heads, 1, head_size]
v_repeat
)
# [b, q_heads, 1, head_size]
O_torch
=
O_torch
.
squeeze
(
2
)
# [b, q_heads, head_size]
O_torch
=
O_torch
.
squeeze
(
2
)
# [b, q_heads, head_size]
...
@@ -775,7 +725,8 @@ def test_varlen_decode_main(args):
...
@@ -775,7 +725,8 @@ def test_varlen_decode_main(args):
attn_weights
.
squeeze
(
2
),
# [b, q_heads, max_seqlen]
attn_weights
.
squeeze
(
2
),
# [b, q_heads, max_seqlen]
kernel_size
=
(
q_heads
,
block_size
),
kernel_size
=
(
q_heads
,
block_size
),
stride
=
(
q_heads
,
block_size
),
stride
=
(
q_heads
,
block_size
),
ceil_mode
=
True
).
to
(
dtype
=
torch
.
float16
)
# [b, 1, ceil(max_seqlen/block_size)]
ceil_mode
=
True
,
).
to
(
dtype
=
torch
.
float16
)
# [b, 1, ceil(max_seqlen/block_size)]
print
(
f
"O_triton shape:
{
O_triton
.
shape
}
"
)
print
(
f
"O_triton shape:
{
O_triton
.
shape
}
"
)
print
(
f
"O_tilelang shape:
{
O_tilelang
.
shape
}
"
)
print
(
f
"O_tilelang shape:
{
O_tilelang
.
shape
}
"
)
...
@@ -791,22 +742,16 @@ def test_varlen_decode_main(args):
...
@@ -791,22 +742,16 @@ def test_varlen_decode_main(args):
print
(
f
"Max difference in O_tilelang:
{
max_diff_o_tl
.
item
()
}
"
)
print
(
f
"Max difference in O_tilelang:
{
max_diff_o_tl
.
item
()
}
"
)
max_diff_s
=
torch
.
max
(
torch
.
abs
(
S_triton
-
attn_score_pooled
))
max_diff_s
=
torch
.
max
(
torch
.
abs
(
S_triton
-
attn_score_pooled
))
max_diff_s_tl
=
torch
.
max
(
max_diff_s_tl
=
torch
.
max
(
torch
.
abs
(
S_tilelang
[:,
:,
:
math
.
ceil
(
max_seqlen_k
/
block_size
)]
-
attn_score_pooled
))
torch
.
abs
(
S_tilelang
[:,
:,
:
math
.
ceil
(
max_seqlen_k
/
block_size
)]
-
attn_score_pooled
))
print
(
f
"Max difference in S:
{
max_diff_s
.
item
()
}
"
)
print
(
f
"Max difference in S:
{
max_diff_s
.
item
()
}
"
)
print
(
f
"Max difference in S_tilelang:
{
max_diff_s_tl
.
item
()
}
"
)
print
(
f
"Max difference in S_tilelang:
{
max_diff_s_tl
.
item
()
}
"
)
assert
torch
.
allclose
(
assert
torch
.
allclose
(
O_triton
,
O_torch
,
atol
=
1e-2
,
rtol
=
1e-2
),
f
"Output mismatch:
{
max_diff_o
.
item
()
}
"
O_triton
,
O_torch
,
atol
=
1e-2
,
rtol
=
1e-2
),
f
"Output mismatch:
{
max_diff_o
.
item
()
}
"
assert
torch
.
allclose
(
S_triton
,
attn_score_pooled
,
atol
=
1e-2
,
rtol
=
1e-2
),
f
"Score mismatch:
{
max_diff_s
.
item
()
}
"
assert
torch
.
allclose
(
assert
torch
.
allclose
(
O_tilelang
,
O_torch
,
atol
=
1e-2
,
rtol
=
1e-2
),
f
"Output mismatch:
{
max_diff_o_tl
.
item
()
}
"
S_triton
,
attn_score_pooled
,
atol
=
1e-2
,
rtol
=
1e-2
),
f
"Score mismatch:
{
max_diff_s
.
item
()
}
"
assert
torch
.
allclose
(
S_tilelang
[:,
:,
:
math
.
ceil
(
max_seqlen_k
/
block_size
)],
attn_score_pooled
,
atol
=
1e-2
,
rtol
=
1e-2
),
(
assert
torch
.
allclose
(
f
"Score mismatch:
{
max_diff_s_tl
.
item
()
}
"
O_tilelang
,
O_torch
,
atol
=
1e-2
,
rtol
=
1e-2
),
f
"Output mismatch:
{
max_diff_o_tl
.
item
()
}
"
)
assert
torch
.
allclose
(
S_tilelang
[:,
:,
:
math
.
ceil
(
max_seqlen_k
/
block_size
)],
attn_score_pooled
,
atol
=
1e-2
,
rtol
=
1e-2
),
f
"Score mismatch:
{
max_diff_s_tl
.
item
()
}
"
print
(
"✅ All tests passed!"
)
print
(
"✅ All tests passed!"
)
...
@@ -844,7 +789,7 @@ def speed_benchmark_decode_comparison(args):
...
@@ -844,7 +789,7 @@ def speed_benchmark_decode_comparison(args):
max_k_seqlen
=
args
.
k_seqlen
max_k_seqlen
=
args
.
k_seqlen
head_size
=
args
.
head_size
head_size
=
args
.
head_size
block_size
=
args
.
block_size
block_size
=
args
.
block_size
dtype
=
torch
.
bfloat16
if
args
.
dtype
==
"
bfloat16
"
else
torch
.
float16
dtype
=
torch
.
bfloat16
if
args
.
dtype
==
T
.
bfloat16
else
torch
.
float16
print
(
"
\n
=== Decode Speed Benchmark Comparison ==="
)
print
(
"
\n
=== Decode Speed Benchmark Comparison ==="
)
print
(
"Configuration:"
)
print
(
"Configuration:"
)
...
@@ -865,7 +810,7 @@ def speed_benchmark_decode_comparison(args):
...
@@ -865,7 +810,7 @@ def speed_benchmark_decode_comparison(args):
k_seqlens
=
torch
.
full
((
batch_size
,),
max_k_seqlen
,
dtype
=
int
)
k_seqlens
=
torch
.
full
((
batch_size
,),
max_k_seqlen
,
dtype
=
int
)
# Generate cumulative sequence lengths for k
# Generate cumulative sequence lengths for k
cu_seqlens_k
=
torch
.
zeros
(
batch_size
+
1
,
device
=
'
cuda
'
,
dtype
=
torch
.
int32
)
cu_seqlens_k
=
torch
.
zeros
(
batch_size
+
1
,
device
=
"
cuda
"
,
dtype
=
torch
.
int32
)
total_k_tokens
=
0
total_k_tokens
=
0
for
i
in
range
(
batch_size
):
for
i
in
range
(
batch_size
):
cu_seqlens_k
[
i
]
=
total_k_tokens
cu_seqlens_k
[
i
]
=
total_k_tokens
...
@@ -873,9 +818,9 @@ def speed_benchmark_decode_comparison(args):
...
@@ -873,9 +818,9 @@ def speed_benchmark_decode_comparison(args):
cu_seqlens_k
[
batch_size
]
=
total_k_tokens
cu_seqlens_k
[
batch_size
]
=
total_k_tokens
# Generate tensors
# Generate tensors
q_decode
=
torch
.
randn
(
batch_size
,
q_heads
,
head_size
,
device
=
'
cuda
'
,
dtype
=
dtype
)
q_decode
=
torch
.
randn
(
batch_size
,
q_heads
,
head_size
,
device
=
"
cuda
"
,
dtype
=
dtype
)
k_varlen
=
torch
.
randn
(
total_k_tokens
,
kv_heads
,
head_size
,
device
=
'
cuda
'
,
dtype
=
dtype
)
k_varlen
=
torch
.
randn
(
total_k_tokens
,
kv_heads
,
head_size
,
device
=
"
cuda
"
,
dtype
=
dtype
)
v_varlen
=
torch
.
randn
(
total_k_tokens
,
kv_heads
,
head_size
,
device
=
'
cuda
'
,
dtype
=
dtype
)
v_varlen
=
torch
.
randn
(
total_k_tokens
,
kv_heads
,
head_size
,
device
=
"
cuda
"
,
dtype
=
dtype
)
softmax_scale
=
1.0
/
math
.
sqrt
(
head_size
)
softmax_scale
=
1.0
/
math
.
sqrt
(
head_size
)
max_seqlen_k
=
int
(
k_seqlens
.
max
())
max_seqlen_k
=
int
(
k_seqlens
.
max
())
...
@@ -883,7 +828,7 @@ def speed_benchmark_decode_comparison(args):
...
@@ -883,7 +828,7 @@ def speed_benchmark_decode_comparison(args):
# Generate sink values if needed
# Generate sink values if needed
sink
=
None
sink
=
None
if
args
.
test_sink
:
if
args
.
test_sink
:
sink
=
torch
.
randn
(
q_heads
,
device
=
'
cuda
'
,
dtype
=
torch
.
float32
)
*
0.1
# Small sink values
sink
=
torch
.
randn
(
q_heads
,
device
=
"
cuda
"
,
dtype
=
torch
.
float32
)
*
0.1
# Small sink values
print
(
" Using sink attention with sink values"
)
print
(
" Using sink attention with sink values"
)
print
(
"Setup complete:"
)
print
(
"Setup complete:"
)
...
@@ -896,8 +841,7 @@ def speed_benchmark_decode_comparison(args):
...
@@ -896,8 +841,7 @@ def speed_benchmark_decode_comparison(args):
num_tokens
,
q_h
,
head_size
=
q_decode
.
shape
num_tokens
,
q_h
,
head_size
=
q_decode
.
shape
batch
=
cu_seqlens_k
.
size
(
0
)
-
1
batch
=
cu_seqlens_k
.
size
(
0
)
-
1
k_h
=
k_varlen
.
size
(
1
)
k_h
=
k_varlen
.
size
(
1
)
tl_kernel
=
flashattn
(
batch
,
q_h
,
k_h
,
args
.
k_seqlen
,
cu_seqlens_k
[
-
1
].
item
(),
head_size
,
tl_kernel
=
flashattn
(
batch
,
q_h
,
k_h
,
args
.
k_seqlen
,
cu_seqlens_k
[
-
1
].
item
(),
head_size
,
args
.
test_sink
)
args
.
test_sink
)
# Benchmark
# Benchmark
print
(
"⚡ Benchmarking Tilelang kernel (100 iterations)..."
)
print
(
"⚡ Benchmarking Tilelang kernel (100 iterations)..."
)
...
@@ -920,36 +864,41 @@ def speed_benchmark_decode_comparison(args):
...
@@ -920,36 +864,41 @@ def speed_benchmark_decode_comparison(args):
# Benchmark
# Benchmark
print
(
"⚡ Benchmarking Triton kernel (100 iterations)..."
)
print
(
"⚡ Benchmarking Triton kernel (100 iterations)..."
)
triton_time
=
do_bench
(
flash_attn_with_attn_pool_decode
,
q_decode
,
k_varlen
,
v_varlen
,
triton_time
=
do_bench
(
cu_seqlens_k
,
max_seqlen_k
,
args
.
k_seqlen
,
1
,
softmax_scale
,
sink
,
flash_attn_with_attn_pool_decode
,
block_size
)
q_decode
,
k_varlen
,
v_varlen
,
cu_seqlens_k
,
max_seqlen_k
,
args
.
k_seqlen
,
1
,
softmax_scale
,
sink
,
block_size
,
)
print
(
f
"Average decode kernel time Triton:
{
triton_time
:.
3
f
}
ms"
)
print
(
f
"Average decode kernel time Triton:
{
triton_time
:.
3
f
}
ms"
)
print
(
f
"Speedup:
{
(
triton_time
/
tilelang_time
):.
3
f
}
"
)
print
(
f
"Speedup:
{
(
triton_time
/
tilelang_time
):.
3
f
}
"
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
(
description
=
'Flash Attention Decode with Attention Pooling'
)
parser
=
argparse
.
ArgumentParser
(
description
=
"Flash Attention Decode with Attention Pooling"
)
parser
.
add_argument
(
'--batch_size'
,
type
=
int
,
default
=
1
,
help
=
'Batch size'
)
parser
.
add_argument
(
"--batch_size"
,
type
=
int
,
default
=
1
,
help
=
"Batch size"
)
parser
.
add_argument
(
'--q_heads'
,
type
=
int
,
default
=
32
,
help
=
'Number of query heads'
)
parser
.
add_argument
(
"--q_heads"
,
type
=
int
,
default
=
32
,
help
=
"Number of query heads"
)
parser
.
add_argument
(
'--kv_heads'
,
type
=
int
,
default
=
8
,
help
=
'Number of key-value heads'
)
parser
.
add_argument
(
"--kv_heads"
,
type
=
int
,
default
=
8
,
help
=
"Number of key-value heads"
)
parser
.
add_argument
(
'--k_seqlen'
,
type
=
int
,
default
=
8192
,
help
=
'Key sequence length'
)
parser
.
add_argument
(
"--k_seqlen"
,
type
=
int
,
default
=
8192
,
help
=
"Key sequence length"
)
parser
.
add_argument
(
parser
.
add_argument
(
"--head_size"
,
type
=
int
,
default
=
128
,
choices
=
[
64
,
128
,
256
],
help
=
"Head dimension"
)
'--head_size'
,
type
=
int
,
default
=
128
,
choices
=
[
64
,
128
,
256
],
help
=
'Head dimension'
)
parser
.
add_argument
(
"--block_size"
,
type
=
int
,
default
=
64
,
help
=
"Block size for computation"
)
parser
.
add_argument
(
'--block_size'
,
type
=
int
,
default
=
64
,
help
=
'Block size for computation'
)
parser
.
add_argument
(
"--dtype"
,
type
=
str
,
default
=
T
.
bfloat16
,
choices
=
[
T
.
float16
,
T
.
bfloat16
],
help
=
"Data type"
)
parser
.
add_argument
(
parser
.
add_argument
(
"--test_varlen"
,
action
=
"store_true"
,
help
=
"Test with truly variable sequence lengths"
)
'--dtype'
,
type
=
str
,
default
=
'bfloat16'
,
choices
=
[
'float16'
,
'bfloat16'
],
help
=
'Data type'
)
parser
.
add_argument
(
"--test_sink"
,
action
=
"store_true"
,
help
=
"Test with sink attention mechanism"
)
parser
.
add_argument
(
parser
.
add_argument
(
"--benchmark"
,
action
=
"store_true"
,
help
=
"Run speed benchmark"
)
'--test_varlen'
,
action
=
'store_true'
,
help
=
'Test with truly variable sequence lengths'
)
parser
.
add_argument
(
"--num_split"
,
type
=
int
,
default
=
1
,
choices
=
[
1
,
16
],
help
=
"Number of splits"
)
parser
.
add_argument
(
'--test_sink'
,
action
=
'store_true'
,
help
=
'Test with sink attention mechanism'
)
parser
.
add_argument
(
'--benchmark'
,
action
=
'store_true'
,
help
=
'Run speed benchmark'
)
parser
.
add_argument
(
'--num_split'
,
type
=
int
,
default
=
1
,
choices
=
[
1
,
16
],
help
=
'Number of splits'
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
args
.
test_sink
=
True
args
.
test_sink
=
True
args
.
test_varlen
=
False
args
.
test_varlen
=
False
args
.
dtype
=
'
float16
'
args
.
dtype
=
T
.
float16
args
.
num_split
=
1
args
.
num_split
=
1
if
args
.
benchmark
:
if
args
.
benchmark
:
...
...
examples/flash_decoding/example_gqa_decode_varlen_logits_paged.py
0 → 100644
View file @
667632cc
import
torch
import
math
import
argparse
import
tilelang
import
tilelang.language
as
T
from
example_gqa_decode_varlen_logits
import
flash_attn_with_attn_pool_decode
,
repeat_kv
,
do_bench
# Seed torch's RNG at import time so the randomly generated test tensors
# (q/k/v, sink values, sequence lengths) are reproducible run-to-run.
torch.manual_seed(0)
def get_configs():
    """Enumerate autotuning configurations for the decode kernel.

    Returns a list of keyword-argument dicts (one per configuration), the
    Cartesian product of the tuning axes below. Currently only ``block_N``
    and ``num_stages`` have more than one candidate value.
    """
    import itertools

    search_space = {
        "block_N": [64, 128],
        "block_H": [64],
        "num_split": [1],
        "num_stages": [1, 2, 3],
        "threads": [128],
    }
    axis_names = list(search_space)
    return [
        dict(zip(axis_names, combo))
        for combo in itertools.product(*search_space.values())
    ]
# @autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(out_idx=[-2, -1], debug_root_path="./examples/flash_decoding")
def flashattn(
    batch,
    heads,
    k_heads,
    max_seqlen_kv,
    total_seqlen_k,
    dim,
    has_sink,
    page_block_size,
    block_N=128,
    block_H=64,
    num_split=1,
    num_stages=1,
    threads=128,
):
    """Build a paged GQA flash-decode kernel with per-block attention scores.

    Compiles (via ``tilelang.jit``) a single-split decode kernel that reads
    K/V through a paged block table, computes online-softmax attention for
    one query token per batch element, and additionally emits ``S`` — the
    per-``block_N`` running score blocks used downstream for pooling.

    Args:
        batch: number of sequences (one decode query token each).
        heads: number of query heads.
        k_heads: number of KV heads (``heads`` must be divisible by it).
        max_seqlen_kv: maximum KV length; sizes the score output ``S``.
        total_seqlen_k: total KV tokens across all sequences (varlen layout).
        dim: head dimension.
        has_sink: when true, ``s_aux`` is added to the softmax normalizer.
        page_block_size: size of one KV-cache page; must be >= ``block_N``
            and a multiple of it (asserted below).
        block_N / block_H / num_split / num_stages / threads: tiling and
            scheduling knobs (``num_split`` > 1 is not implemented yet —
            see the TODO at the bottom).

    Returns:
        The compiled ``flashattn_gqa_decode_no_split`` prim_func; its last
        two arguments (``Output``, ``S``) are produced as outputs
        (``out_idx=[-2, -1]``).
    """
    # Fold log2(e) into the softmax scale so exp2/log2 can be used throughout.
    scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
    shape_q = [batch, heads, dim]
    shape_k = [total_seqlen_k, k_heads, dim]
    shape_v = [total_seqlen_k, k_heads, dim]
    shape_o = [batch, heads, dim]
    # One score entry per block_N-sized chunk of the (padded) KV sequence.
    shape_s = [batch, heads, math.ceil(max_seqlen_kv / block_N)]
    dtype = T.float16
    accum_dtype = T.float32
    kv_group_num = heads // k_heads

    assert page_block_size >= block_N and page_block_size % block_N == 0, (
        "page_block_size must be larger than block_N and a multiple of block_N")

    # Number of query heads actually processed per kernel instance.
    valid_block_H = min(block_H, kv_group_num)

    # TODO: check if max_seqlen_kv is correct for varlen case
    @T.macro
    def flash_attn(
            Q: T.Tensor(shape_q, dtype),
            K: T.Tensor(shape_k, dtype),
            V: T.Tensor(shape_v, dtype),
            cu_seqlens_k: T.Tensor([batch + 1], T.int32),
            s_aux: T.Tensor([heads], T.float32),
            # NOTE(review): declared here with math.ceil(max_seqlen_kv / block_N)
            # columns, but the prim_func below declares the same argument with
            # math.ceil(max_seqlen_kv / page_block_size) columns, and the indexing
            # below divides by page_block_size — confirm which shape is intended.
            BLOCK_TABLE: T.Tensor([batch, math.ceil(max_seqlen_kv / block_N)], T.int32),
            Output: T.Tensor([batch, heads, dim], dtype),
            S: T.Tensor(shape_s, dtype),
    ):
        # Grid: one block per (sequence, head-group, split). bz is unused since
        # only the num_split == 1 path is implemented.
        with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz):
            Q_shared = T.alloc_shared([block_H, dim], dtype)
            K_shared = T.alloc_shared([block_N, dim], dtype)
            V_shared = T.alloc_shared([block_N, dim], dtype)
            O_shared = T.alloc_shared([valid_block_H, dim], dtype)
            acc_s = T.alloc_fragment([block_H, block_N], accum_dtype)
            acc_s_cast = T.alloc_fragment([block_H, block_N], dtype)
            acc_o = T.alloc_fragment([block_H, dim], accum_dtype)
            scores_max = T.alloc_fragment([block_H], accum_dtype)
            scores_max_prev = T.alloc_fragment([block_H], accum_dtype)
            scores_scale = T.alloc_fragment([block_H], accum_dtype)
            scores_sum = T.alloc_fragment([block_H], accum_dtype)
            logsum = T.alloc_fragment([block_H], accum_dtype)
            # Per-KV-block raw row maxima; normalized into scores at the end.
            S_shared = T.alloc_shared([block_H, math.ceil(max_seqlen_kv / block_N)], dtype)
            s_aux_shared = T.alloc_shared([block_H], T.float32)

            bid = bx
            hid = by
            # KV head serving this group of query heads.
            cur_kv_head = hid // (kv_group_num // valid_block_H)

            # Varlen bounds of this sequence's K/V rows.
            cur_start_k = cu_seqlens_k[bid]
            cur_end_k = cu_seqlens_k[bid + 1]
            cur_seqlen_k = cur_end_k - cur_start_k

            T.copy(Q[bid, hid * valid_block_H:hid * valid_block_H + block_H, :], Q_shared)
            T.fill(acc_o, 0)
            T.fill(logsum, 0)
            T.fill(scores_max, -T.infinity(accum_dtype))

            # loop_range = T.ceildiv((seqlen_kv // num_split), block_N)
            loop_range = T.ceildiv((cur_seqlen_k // num_split), block_N)
            for k in T.Pipelined(loop_range, num_stages=num_stages):
                # Paged addressing: logical offset k*block_N is mapped through the
                # block table to a physical page, then offset by cur_start_k
                # (pages are numbered per-sequence, relative to its varlen start).
                k_start = BLOCK_TABLE[bid, (k * block_N) // page_block_size] * page_block_size + (
                    k * block_N) % page_block_size
                T.copy(K[cur_start_k + k_start:cur_start_k + k_start + block_N, cur_kv_head, :],
                       K_shared)
                T.clear(acc_s)
                T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                # Mask out-of-range columns of the last partial block.
                for i, j in T.Parallel(block_H, block_N):
                    acc_s[i, j] = T.if_then_else(k * block_N + j < cur_seqlen_k, acc_s[i, j],
                                                 -T.infinity(accum_dtype))
                T.copy(scores_max, scores_max_prev)
                T.fill(scores_max, -T.infinity(accum_dtype))
                T.reduce_max(acc_s, scores_max, dim=1, clear=False)
                # scores_max_prev is m_i
                # scores_max is row_max->m_ij in triton
                # Stash this block's raw row max; normalized after the loop.
                T.copy(scores_max, S_shared[:, k])
                # scores_scale is alpha in triton
                for i in T.Parallel(block_H):
                    scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
                for i, j in T.Parallel(block_H, block_N):
                    acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
                T.reduce_sum(acc_s, scores_sum, dim=1)
                # scores_sum is l_ij in triton
                # logsum is l_i in triton
                for i in T.Parallel(block_H):
                    logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
                T.copy(acc_s, acc_s_cast)
                # Rescale the running output by alpha before accumulating P @ V.
                for i, j in T.Parallel(block_H, dim):
                    acc_o[i, j] *= scores_scale[i]
                v_start = BLOCK_TABLE[bid, (k * block_N) // page_block_size] * page_block_size + (
                    k * block_N) % page_block_size
                T.copy(V[cur_start_k + v_start:cur_start_k + v_start + block_N, cur_kv_head, :],
                       V_shared)
                T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)

            # Attention-sink variant: the sink logit joins the normalizer only.
            if has_sink:
                T.copy(s_aux[hid * valid_block_H:hid * valid_block_H + block_H], s_aux_shared)
                for i in T.Parallel(block_H):
                    logsum[i] += s_aux_shared[i]
            for i, j in T.Parallel(block_H, dim):
                acc_o[i, j] /= logsum[i]
            # Turn stored per-block maxima into normalized softmax scores.
            # NOTE: the loop variable k here shadows the pipeline index above.
            for h, k in T.Parallel(block_H, math.ceil(max_seqlen_kv / block_N)):
                S_shared[h, k] = T.exp2((S_shared[h, k] - scores_max[h]) * scale) / logsum[h]
            for i in T.Parallel(block_H):
                logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
            T.copy(acc_o[:valid_block_H, :], O_shared)
            T.copy(O_shared, Output[bid, hid * valid_block_H:(hid + 1) * valid_block_H, :])
            T.copy(S_shared[:valid_block_H, :], S[bid, hid * valid_block_H:(hid + 1) * valid_block_H, :])

    @T.prim_func
    def flashattn_gqa_decode_no_split(
            Q: T.Tensor(shape_q, dtype),
            K: T.Tensor(shape_k, dtype),
            V: T.Tensor(shape_v, dtype),
            cu_seqlens_k: T.Tensor([batch + 1], T.int32),
            s_aux: T.Tensor([heads], T.float32),
            BLOCK_TABLE: T.Tensor([batch, math.ceil(max_seqlen_kv / page_block_size)], T.int32),
            Output: T.Tensor(shape_o, dtype),
            S: T.Tensor(shape_s, dtype),
    ):
        flash_attn(Q, K, V, cu_seqlens_k, s_aux, BLOCK_TABLE, Output, S)

    # TODO: split version
    return flashattn_gqa_decode_no_split
def flash_attn_with_attn_pool_decode_tilelang(
    Q: torch.Tensor,  ## [tq = b, q_h, q_dim]
    K: torch.Tensor,  ## [tk, k_h, k_dim]
    V: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_k: int,
    real_max_k_seqlen: int,
    num_split: int,
    softmax_scale: float,
    s_aux: torch.Tensor = None,
    block_size: int = 64,
    use_per_kv_head_sparse_index: bool = False,
    tl_kernel=None,
    block_table: torch.Tensor = None,
):
    """Run the paged tilelang decode kernel and max-pool its block scores.

    Validates the varlen inputs, invokes ``tl_kernel`` (a compiled
    ``flashattn`` prim_func, expected to return ``(Output, S)``), and pools
    the per-block score tensor ``S`` over the head dimension.

    Args:
        Q: decode queries, one token per sequence, [batch, q_h, head_size].
        K / V: varlen KV caches, [total_k_tokens, k_h, head_size].
        cu_seqlens_k: int32 cumulative KV lengths, [batch + 1].
        max_seqlen_k / real_max_k_seqlen / num_split / softmax_scale /
        block_size: kept for signature parity with the triton wrapper; the
            compiled kernel already has these baked in, so they are unused here.
        s_aux: optional per-head sink logits, forwarded to the kernel.
        use_per_kv_head_sparse_index: pool scores per KV-head group
            (kernel ``(q_h // k_h, 1)``) instead of across all query heads.
        tl_kernel: compiled kernel from ``flashattn(...)``.
        block_table: int32 page table forwarded to the kernel.

    Returns:
        Tuple ``(O_tl, S_tl)``: attention output and pooled block scores
        (S assumed [batch, heads, num_blocks] — pooled along dim 1).
    """
    num_tokens, q_h, head_size = Q.shape
    batch = cu_seqlens_k.size(0) - 1
    k_h = K.size(1)
    assert Q.dim() == K.dim() == 3
    assert Q.size(2) == K.size(2)
    assert cu_seqlens_k.dim() == 1
    assert head_size in {64, 128, 256}
    assert Q.is_contiguous()
    assert K.is_contiguous()
    assert V.is_contiguous()
    gqa_group_size = q_h // k_h
    # The kernel allocates and returns its own outputs (out_idx=[-2, -1]);
    # pre-allocating O_tl/S_tl here (as before) was dead code and is removed.
    O_tl, S_tl = tl_kernel(Q, K, V, cu_seqlens_k, s_aux, block_table)
    if use_per_kv_head_sparse_index:
        # One pooled score row per KV head (max over its query-head group).
        S_tl = torch.max_pool2d(S_tl, kernel_size=(gqa_group_size, 1), stride=(gqa_group_size, 1))
    else:
        # One pooled score row per batch element (max over all query heads).
        S_tl = torch.max_pool2d(S_tl, kernel_size=(q_h, 1), stride=(q_h, 1))
    return O_tl, S_tl
def test_equal_seqlen_decode_main(args):
    """Test decode kernel with equal sequence lengths.

    Builds a batch where every sequence has exactly ``args.k_seqlen`` KV
    tokens, runs the triton reference wrapper, the paged tilelang kernel,
    and a plain PyTorch reference, then cross-checks outputs and pooled
    attention scores. Requires a CUDA device.
    """
    print("Testing decode kernel with equal sequence lengths")
    batch_size = args.batch_size
    q_heads = args.q_heads
    kv_heads = args.kv_heads
    k_seqlen = args.k_seqlen
    real_max_k_seqlen = args.k_seqlen
    head_size = args.head_size
    block_size = args.block_size
    page_block_size = args.page_block_size
    # NOTE(review): assumes args.dtype holds a tilelang dtype token
    # (T.bfloat16 / T.float16), not a plain string — confirm against argparse.
    dtype = torch.bfloat16 if args.dtype == T.bfloat16 else torch.float16

    # For decode, query is just 1 token per batch
    q = torch.randn(batch_size, q_heads, head_size, device="cuda", dtype=dtype)
    k = torch.randn(batch_size, kv_heads, k_seqlen, head_size, device="cuda", dtype=dtype)
    v = torch.randn(batch_size, kv_heads, k_seqlen, head_size, device="cuda", dtype=dtype)
    softmax_scale = 1.0 / math.sqrt(head_size)

    # Generate sink values if needed
    sink = None
    if args.test_sink:
        sink = torch.randn(q_heads, device="cuda", dtype=torch.float32) * 0.1  # Small sink values
        print(f"Using sink attention with sink values: {sink}")

    # Convert to varlen format for K, V: [total_tokens, kv_heads, head_size]
    k_varlen = k.transpose(1, 2).reshape(batch_size * k_seqlen, kv_heads, head_size).contiguous()
    v_varlen = v.transpose(1, 2).reshape(batch_size * k_seqlen, kv_heads, head_size).contiguous()

    # Generate cumulative sequence lengths (equal-length case: multiples of k_seqlen)
    cu_seqlens_k = torch.arange(
        0, (batch_size + 1) * k_seqlen, k_seqlen, device="cuda", dtype=torch.int32)
    max_seqlen_k = k_seqlen

    print(f"q shape: {q.shape}")
    print(f"k_varlen shape: {k_varlen.shape}")
    print(f"v_varlen shape: {v_varlen.shape}")

    num_tokens, q_h, head_size = q.shape
    batch = cu_seqlens_k.size(0) - 1
    k_h = k_varlen.size(1)
    # Compile the paged tilelang kernel for this exact problem shape.
    tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen,
                          cu_seqlens_k[-1].item(), head_size, args.test_sink, page_block_size)

    # Build an identity page table: pages are numbered per-sequence (reset each
    # batch element), matching the kernel's cur_start_k-relative addressing.
    block_table = torch.zeros(
        batch, math.ceil(real_max_k_seqlen / page_block_size), device="cuda", dtype=torch.int32)
    block_cnt = 0
    for i in range(batch):
        cur_seqlen = cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item()
        for j in range(math.ceil(cur_seqlen / page_block_size)):
            block_table[i, j] = block_cnt
            block_cnt += 1
        block_cnt = 0

    # Test our decode kernel
    O_triton, S_triton = flash_attn_with_attn_pool_decode(
        q,
        k_varlen,
        v_varlen,
        cu_seqlens_k,
        max_seqlen_k,
        real_max_k_seqlen,
        args.num_split,
        softmax_scale,
        s_aux=sink,
        block_size=block_size,
    )
    O_tilelang, S_tilelang = flash_attn_with_attn_pool_decode_tilelang(
        q,
        k_varlen,
        v_varlen,
        cu_seqlens_k,
        max_seqlen_k,
        real_max_k_seqlen,
        args.num_split,
        softmax_scale,
        s_aux=sink,
        block_size=block_size,
        tl_kernel=tl_kernel,
        block_table=block_table,
    )
    # Zero score blocks past each sequence's true length before comparison.
    for i in range(batch_size):
        S_tilelang[i, :,
                   math.ceil((cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item()) /
                             block_size):] = 0

    # Compute torch reference
    q_expanded = q.unsqueeze(2)  # [b, q_heads, 1, head_size]
    k_repeat = repeat_kv(k, q_heads // kv_heads)  # [b, q_heads, k_seqlen, head_size]
    v_repeat = repeat_kv(v, q_heads // kv_heads)  # [b, q_heads, k_seqlen, head_size]
    if sink is None:
        # Standard scaled dot-product attention
        logits = torch.matmul(q_expanded, k_repeat.transpose(
            -2, -1)) * softmax_scale  # [batch, q_heads, 1, seqlen_k]
        attn_weights = torch.softmax(logits, dim=-1)
        O_torch = torch.matmul(attn_weights, v_repeat).squeeze(2)  # [batch, q_heads, head_size]
    else:
        # s_aux attention: the sink logit joins the softmax normalizer only.
        logits = torch.matmul(q_expanded, k_repeat.transpose(
            -2, -1)) * softmax_scale  # [batch, q_heads, 1, seqlen_k]
        sink_expanded = sink.view(1, q_heads, 1, 1)  # [1, q_heads, 1, 1]
        logits_max = torch.max(logits, dim=-1, keepdim=True).values
        logits_or_sinks_max = torch.maximum(logits_max, sink_expanded)
        sinks = torch.exp(sink_expanded - logits_or_sinks_max)
        unnormalized_scores = torch.exp(logits - logits_or_sinks_max)
        normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks
        attn_weights = unnormalized_scores / normalizer
        O_torch = torch.matmul(attn_weights.to(v_repeat.dtype),
                               v_repeat).squeeze(2)  # [batch, q_heads, head_size]

    # Compute attention score pooling (max over all heads and each block_size chunk)
    attn_score_pooled = torch.max_pool2d(
        attn_weights.squeeze(2),  # [b, q_heads, k_seqlen]
        kernel_size=(q_heads, block_size),
        stride=(q_heads, block_size),
        ceil_mode=True,
    ).to(torch.float16)
    print("S_tilelang", S_tilelang)
    print("attn_score_pooled", attn_score_pooled)

    max_diff_o = torch.max(torch.abs(O_triton - O_torch))
    max_diff_s = torch.max(torch.abs(S_triton - attn_score_pooled))
    max_diff_o_tilelang = torch.max(torch.abs(O_tilelang - O_torch))
    max_diff_s_tilelang = torch.max(torch.abs(S_tilelang - attn_score_pooled))
    print(f"Max difference in O: {max_diff_o.item()}")
    print(f"Max difference in S: {max_diff_s.item()}")
    print(f"Max difference in O_tilelang: {max_diff_o_tilelang.item()}")
    print(f"Max difference in S_tilelang: {max_diff_s_tilelang.item()}")
    assert torch.allclose(
        O_triton, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o.item()}"
    assert torch.allclose(
        S_triton, attn_score_pooled, atol=1e-2, rtol=1e-2), f"Score mismatch: {max_diff_s.item()}"
    assert torch.allclose(
        O_tilelang, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o_tilelang.item()}"
    assert torch.allclose(
        S_tilelang, attn_score_pooled, atol=1e-2,
        rtol=1e-2), f"Score mismatch: {max_diff_s_tilelang.item()}"
    print("✅ All tests passed!")
def
test_varlen_decode_main
(
args
):
"""Test decode kernel with variable sequence lengths"""
batch_size
=
args
.
batch_size
q_heads
=
args
.
q_heads
kv_heads
=
args
.
kv_heads
max_k_seqlen
=
args
.
k_seqlen
# Use as max sequence length
real_max_k_seqlen
=
args
.
k_seqlen
head_size
=
args
.
head_size
block_size
=
args
.
block_size
page_block_size
=
args
.
page_block_size
dtype
=
torch
.
bfloat16
if
args
.
dtype
==
T
.
bfloat16
else
torch
.
float16
print
(
f
"Testing decode kernel with variable sequence lengths (max_k_seqlen=
{
max_k_seqlen
}
)"
)
# Generate sink values if needed
sink
=
None
if
args
.
test_sink
:
sink
=
torch
.
randn
(
q_heads
,
device
=
"cuda"
,
dtype
=
torch
.
float32
)
*
0.1
# Small sink values
print
(
f
"Using sink attention with sink values:
{
sink
}
"
)
# Generate variable length k sequences
k_seqlens
=
torch
.
randint
(
max_k_seqlen
//
4
,
max_k_seqlen
+
1
,
size
=
(
batch_size
,))
print
(
f
"k_seqlens:
{
k_seqlens
}
"
)
# Generate cumulative sequence lengths for k
cu_seqlens_k
=
torch
.
zeros
(
batch_size
+
1
,
device
=
"cuda"
,
dtype
=
torch
.
int32
)
total_k_tokens
=
0
for
i
in
range
(
batch_size
):
cu_seqlens_k
[
i
]
=
total_k_tokens
total_k_tokens
+=
k_seqlens
[
i
]
cu_seqlens_k
[
batch_size
]
=
total_k_tokens
print
(
f
"cu_seqlens_k:
{
cu_seqlens_k
}
"
)
# Generate tensors - Q is [batch_size, q_heads, head_size] for decode
q_decode
=
torch
.
randn
(
batch_size
,
q_heads
,
head_size
,
device
=
"cuda"
,
dtype
=
dtype
)
k_varlen
=
torch
.
randn
(
total_k_tokens
,
kv_heads
,
head_size
,
device
=
"cuda"
,
dtype
=
dtype
)
v_varlen
=
torch
.
randn
(
total_k_tokens
,
kv_heads
,
head_size
,
device
=
"cuda"
,
dtype
=
dtype
)
softmax_scale
=
1.0
/
math
.
sqrt
(
head_size
)
max_seqlen_k
=
int
(
k_seqlens
.
max
())
print
(
f
"Actual max_seqlen_k:
{
max_seqlen_k
}
"
)
print
(
f
"q_decode shape:
{
q_decode
.
shape
}
"
)
print
(
f
"k_varlen shape:
{
k_varlen
.
shape
}
"
)
print
(
f
"v_varlen shape:
{
v_varlen
.
shape
}
"
)
num_tokens
,
q_h
,
head_size
=
q_decode
.
shape
batch
=
cu_seqlens_k
.
size
(
0
)
-
1
k_h
=
k_varlen
.
size
(
1
)
tl_kernel
=
flashattn
(
batch
,
q_h
,
k_h
,
args
.
k_seqlen
,
cu_seqlens_k
[
-
1
].
item
(),
head_size
,
args
.
test_sink
,
page_block_size
)
block_table
=
torch
.
zeros
(
batch
,
math
.
ceil
(
real_max_k_seqlen
/
page_block_size
),
device
=
"cuda"
,
dtype
=
torch
.
int32
)
block_cnt
=
0
for
i
in
range
(
batch
):
cur_seqlen
=
cu_seqlens_k
[
i
+
1
].
item
()
-
cu_seqlens_k
[
i
].
item
()
for
j
in
range
(
math
.
ceil
(
cur_seqlen
/
page_block_size
)):
block_table
[
i
,
j
]
=
block_cnt
block_cnt
+=
1
block_cnt
=
0
# Test our decode kernel
O_triton
,
S_triton
=
flash_attn_with_attn_pool_decode
(
q_decode
,
k_varlen
,
v_varlen
,
cu_seqlens_k
,
max_seqlen_k
,
real_max_k_seqlen
,
args
.
num_split
,
softmax_scale
,
s_aux
=
sink
,
block_size
=
block_size
,
)
O_tilelang
,
S_tilelang
=
flash_attn_with_attn_pool_decode_tilelang
(
q_decode
,
k_varlen
,
v_varlen
,
cu_seqlens_k
,
max_seqlen_k
,
real_max_k_seqlen
,
args
.
num_split
,
softmax_scale
,
s_aux
=
sink
,
block_size
=
block_size
,
tl_kernel
=
tl_kernel
,
block_table
=
block_table
,
)
for
i
in
range
(
batch_size
):
S_tilelang
[
i
,
:,
math
.
ceil
((
cu_seqlens_k
[
i
+
1
].
item
()
-
cu_seqlens_k
[
i
].
item
())
/
block_size
)
:]
=
0
# Create torch reference - pad tensors for comparison
k_padded_list
=
[]
v_padded_list
=
[]
for
i
in
range
(
batch_size
):
actual_k_len
=
k_seqlens
[
i
]
# Extract and pad k, v for this batch
k_start
=
cu_seqlens_k
[
i
]
k_end
=
cu_seqlens_k
[
i
+
1
]
# Pad to max_seqlen_k
k_padded
=
torch
.
zeros
(
max_seqlen_k
,
kv_heads
,
head_size
,
device
=
"cuda"
,
dtype
=
dtype
)
v_padded
=
torch
.
zeros
(
max_seqlen_k
,
kv_heads
,
head_size
,
device
=
"cuda"
,
dtype
=
dtype
)
k_padded
[:
actual_k_len
]
=
k_varlen
[
k_start
:
k_end
]
v_padded
[:
actual_k_len
]
=
v_varlen
[
k_start
:
k_end
]
k_padded_list
.
append
(
k_padded
)
v_padded_list
.
append
(
v_padded
)
# Stack to create batched tensors [b, max_seqlen, kv_heads, head_size]
k_padded_batched
=
torch
.
stack
(
k_padded_list
,
dim
=
0
).
transpose
(
1
,
2
)
# [b, kv_heads, max_seqlen, head_size]
v_padded_batched
=
torch
.
stack
(
v_padded_list
,
dim
=
0
).
transpose
(
1
,
2
)
# [b, kv_heads, max_seqlen, head_size]
# Expand q to match kv heads: [b, q_heads, 1, head_size]
q_expanded
=
q_decode
.
unsqueeze
(
2
)
# [b, q_heads, 1, head_size]
print
(
f
"q_expanded shape:
{
q_expanded
.
shape
}
"
)
print
(
f
"k_padded_batched shape:
{
k_padded_batched
.
shape
}
"
)
print
(
f
"v_padded_batched shape:
{
v_padded_batched
.
shape
}
"
)
# Compute torch reference
k_repeat
=
repeat_kv
(
k_padded_batched
,
q_heads
//
kv_heads
)
# [b, q_heads, max_seqlen, head_size]
v_repeat
=
repeat_kv
(
v_padded_batched
,
q_heads
//
kv_heads
)
# [b, q_heads, max_seqlen, head_size]
if
sink
is
None
:
# Standard attention computation: [b, q_heads, 1, head_size] @ [b, q_heads, head_size, max_seqlen]
attn_score
=
torch
.
matmul
(
q_expanded
,
k_repeat
.
transpose
(
-
2
,
-
1
))
*
softmax_scale
# [b, q_heads, 1, max_seqlen]
# Apply sequence length masking
for
i
in
range
(
batch_size
):
actual_k_len
=
k_seqlens
[
i
]
attn_score
[
i
,
:,
:,
actual_k_len
:]
=
float
(
"-inf"
)
attn_weights
=
attn_score
.
softmax
(
dim
=-
1
)
# [b, q_heads, 1, max_seqlen]
# Mask out invalid positions
for
i
in
range
(
batch_size
):
actual_k_len
=
k_seqlens
[
i
]
attn_weights
[
i
,
:,
:,
actual_k_len
:]
=
0.0
# Compute output: [b, q_heads, 1, max_seqlen] @ [b, q_heads, max_seqlen, head_size]
O_torch
=
torch
.
matmul
(
attn_weights
,
v_repeat
)
# [b, q_heads, 1, head_size]
else
:
# s_aux attention
logits
=
torch
.
matmul
(
q_expanded
,
k_repeat
.
transpose
(
-
2
,
-
1
))
*
softmax_scale
# [b, q_heads, 1, max_seqlen]
# Apply sequence length masking
for
i
in
range
(
batch_size
):
actual_k_len
=
k_seqlens
[
i
]
logits
[
i
,
:,
:,
actual_k_len
:]
=
float
(
"-inf"
)
sink_expanded
=
sink
.
view
(
1
,
q_heads
,
1
,
1
)
# [1, q_heads, 1, 1]
logits_max
=
torch
.
max
(
logits
,
dim
=-
1
,
keepdim
=
True
).
values
logits_or_sinks_max
=
torch
.
maximum
(
logits_max
,
sink_expanded
)
sinks
=
torch
.
exp
(
sink_expanded
-
logits_or_sinks_max
)
unnormalized_scores
=
torch
.
exp
(
logits
-
logits_or_sinks_max
)
normalizer
=
unnormalized_scores
.
sum
(
dim
=-
1
,
keepdim
=
True
)
+
sinks
attn_weights
=
unnormalized_scores
/
normalizer
# Mask out invalid positions
for
i
in
range
(
batch_size
):
actual_k_len
=
k_seqlens
[
i
]
attn_weights
[
i
,
:,
:,
actual_k_len
:]
=
0.0
# Compute output: [b, q_heads, 1, max_seqlen] @ [b, q_heads, max_seqlen, head_size]
O_torch
=
torch
.
matmul
(
attn_weights
.
to
(
v_repeat
.
dtype
),
v_repeat
)
# [b, q_heads, 1, head_size]
O_torch
=
O_torch
.
squeeze
(
2
)
# [b, q_heads, head_size]
# Compute attention score pooling for S
attn_score_pooled
=
torch
.
max_pool2d
(
attn_weights
.
squeeze
(
2
),
# [b, q_heads, max_seqlen]
kernel_size
=
(
q_heads
,
block_size
),
stride
=
(
q_heads
,
block_size
),
ceil_mode
=
True
,
).
to
(
dtype
=
torch
.
float16
)
# [b, 1, ceil(max_seqlen/block_size)]
print
(
f
"O_triton shape:
{
O_triton
.
shape
}
"
)
print
(
f
"O_tilelang shape:
{
O_tilelang
.
shape
}
"
)
print
(
f
"O_torch shape:
{
O_torch
.
shape
}
"
)
print
(
f
"S_triton shape:
{
S_triton
.
shape
}
"
)
print
(
f
"S_tilelang shape:
{
S_tilelang
.
shape
}
"
)
print
(
f
"attn_score_pooled shape:
{
attn_score_pooled
.
shape
}
"
)
# Compare results
max_diff_o
=
torch
.
max
(
torch
.
abs
(
O_triton
-
O_torch
))
max_diff_o_tl
=
torch
.
max
(
torch
.
abs
(
O_tilelang
-
O_torch
))
print
(
f
"Max difference in O:
{
max_diff_o
.
item
()
}
"
)
print
(
f
"Max difference in O_tilelang:
{
max_diff_o_tl
.
item
()
}
"
)
max_diff_s
=
torch
.
max
(
torch
.
abs
(
S_triton
-
attn_score_pooled
))
max_diff_s_tl
=
torch
.
max
(
torch
.
abs
(
S_tilelang
[:,
:,
:
math
.
ceil
(
max_seqlen_k
/
block_size
)]
-
attn_score_pooled
))
print
(
f
"Max difference in S:
{
max_diff_s
.
item
()
}
"
)
print
(
f
"Max difference in S_tilelang:
{
max_diff_s_tl
.
item
()
}
"
)
assert
torch
.
allclose
(
O_triton
,
O_torch
,
atol
=
1e-2
,
rtol
=
1e-2
),
f
"Output mismatch:
{
max_diff_o
.
item
()
}
"
assert
torch
.
allclose
(
S_triton
,
attn_score_pooled
,
atol
=
1e-2
,
rtol
=
1e-2
),
f
"Score mismatch:
{
max_diff_s
.
item
()
}
"
assert
torch
.
allclose
(
O_tilelang
,
O_torch
,
atol
=
1e-2
,
rtol
=
1e-2
),
f
"Output mismatch:
{
max_diff_o_tl
.
item
()
}
"
assert
torch
.
allclose
(
S_tilelang
[:,
:,
:
math
.
ceil
(
max_seqlen_k
/
block_size
)],
attn_score_pooled
,
atol
=
1e-2
,
rtol
=
1e-2
),
(
f
"Score mismatch:
{
max_diff_s_tl
.
item
()
}
"
)
print
(
"✅ All tests passed!"
)
def speed_benchmark_decode_comparison(args):
    """Speed benchmark for decode kernel.

    Builds a synthetic varlen GQA decode workload from ``args``, compiles the
    TileLang kernel plus a paged-KV block table, then times the TileLang and
    Triton decode implementations with ``do_bench`` and reports the speedup.

    Args:
        args: parsed CLI namespace; reads batch_size, q_heads, kv_heads,
            k_seqlen, head_size, block_size, page_block_size, dtype,
            test_varlen, test_sink.

    Side effects: allocates CUDA tensors, compiles a kernel, prints results.
    Returns: None.
    """
    batch_size = args.batch_size
    q_heads = args.q_heads
    kv_heads = args.kv_heads
    max_k_seqlen = args.k_seqlen
    # NOTE(review): real_max_k_seqlen is always identical to max_k_seqlen here;
    # presumably kept as a separate name to mirror the correctness test's API.
    real_max_k_seqlen = args.k_seqlen
    head_size = args.head_size
    block_size = args.block_size
    page_block_size = args.page_block_size
    # Map the TileLang dtype string onto the matching torch dtype.
    dtype = torch.bfloat16 if args.dtype == T.bfloat16 else torch.float16
    print("\n=== Decode Speed Benchmark Comparison ===")
    print("Configuration:")
    print(f"  Batch size: {batch_size}")
    print(f"  Q heads: {q_heads}, KV heads: {kv_heads}")
    print(f"  Max K sequence length: {max_k_seqlen}")
    print(f"  Head size: {head_size}")
    print(f"  Block size: {block_size}")
    print(f"  Data type: {dtype}")
    print(f"  Variable lengths: {args.test_varlen}")
    print(f"  s_aux attention: {args.test_sink}")
    print()
    # Generate input data: per-batch K sequence lengths, either random in
    # [max//4, max] (varlen mode) or all equal to max_k_seqlen.
    if args.test_varlen:
        k_seqlens = torch.randint(max_k_seqlen // 4, max_k_seqlen + 1, size=(batch_size,))
    else:
        k_seqlens = torch.full((batch_size,), max_k_seqlen, dtype=int)
    # Generate cumulative sequence lengths for k (exclusive prefix sum, with
    # the grand total appended at index batch_size).
    cu_seqlens_k = torch.zeros(batch_size + 1, device="cuda", dtype=torch.int32)
    total_k_tokens = 0
    for i in range(batch_size):
        cu_seqlens_k[i] = total_k_tokens
        total_k_tokens += k_seqlens[i]
    cu_seqlens_k[batch_size] = total_k_tokens
    # Generate tensors: single-token query per batch, packed (varlen) K/V.
    q_decode = torch.randn(batch_size, q_heads, head_size, device="cuda", dtype=dtype)
    k_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype)
    v_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype)
    softmax_scale = 1.0 / math.sqrt(head_size)
    max_seqlen_k = int(k_seqlens.max())
    # Generate sink values if needed
    sink = None
    if args.test_sink:
        sink = torch.randn(q_heads, device="cuda", dtype=torch.float32) * 0.1  # Small sink values
        print("  Using sink attention with sink values")
    print("Setup complete:")
    print(f"  Total K tokens: {total_k_tokens}")
    print(f"  Actual max K seq len: {max_seqlen_k}")
    if args.test_varlen:
        print(f"  K sequence lengths: {k_seqlens.tolist()}")
    # Warmup
    num_tokens, q_h, head_size = q_decode.shape
    batch = cu_seqlens_k.size(0) - 1
    k_h = k_varlen.size(1)
    # Compile the TileLang kernel once, outside the timed region.
    tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen,
                          cu_seqlens_k[-1].item(), head_size, args.test_sink,
                          page_block_size)
    # Paged-KV block table: row i lists the physical page ids for batch i,
    # assigned contiguously in batch order.
    block_table = torch.zeros(
        batch,
        math.ceil(real_max_k_seqlen / page_block_size),
        device="cuda",
        dtype=torch.int32)
    block_cnt = 0
    for i in range(batch):
        cur_seqlen = cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item()
        for j in range(math.ceil(cur_seqlen / page_block_size)):
            block_table[i, j] = block_cnt
            block_cnt += 1
    block_cnt = 0
    # Benchmark
    print("⚡ Benchmarking Tilelang kernel (100 iterations)...")
    tilelang_time = do_bench(
        flash_attn_with_attn_pool_decode_tilelang,
        q_decode,
        k_varlen,
        v_varlen,
        cu_seqlens_k,
        max_seqlen_k,
        args.k_seqlen,
        1,
        softmax_scale,
        sink,
        block_size,
        False,
        tl_kernel,
        block_table,
    )
    print(f"Average decode kernel time Tilelang: {tilelang_time:.3f} ms")
    # Benchmark
    print("⚡ Benchmarking Triton kernel (100 iterations)...")
    triton_time = do_bench(
        flash_attn_with_attn_pool_decode,
        q_decode,
        k_varlen,
        v_varlen,
        cu_seqlens_k,
        max_seqlen_k,
        args.k_seqlen,
        1,
        softmax_scale,
        sink,
        block_size,
    )
    print(f"Average decode kernel time Triton: {triton_time:.3f} ms")
    print(f"Speedup: {(triton_time / tilelang_time):.3f}")
if __name__ == "__main__":
    # CLI entry point: builds the benchmark/test configuration and dispatches
    # to either the speed benchmark or one of the correctness tests.
    parser = argparse.ArgumentParser(description="Flash Attention Decode with Attention Pooling")
    parser.add_argument("--batch_size", type=int, default=1, help="Batch size")
    parser.add_argument("--q_heads", type=int, default=32, help="Number of query heads")
    parser.add_argument("--kv_heads", type=int, default=8, help="Number of key-value heads")
    parser.add_argument("--k_seqlen", type=int, default=8192, help="Key sequence length")
    parser.add_argument(
        "--head_size", type=int, default=128, choices=[64, 128, 256], help="Head dimension")
    parser.add_argument("--block_size", type=int, default=128, help="Block size for computation")
    parser.add_argument(
        "--dtype",
        type=str,
        default=T.bfloat16,
        choices=[T.float16, T.bfloat16],
        help="Data type")
    parser.add_argument(
        "--test_varlen", action="store_true", help="Test with truly variable sequence lengths")
    parser.add_argument(
        "--test_sink", action="store_true", help="Test with sink attention mechanism")
    parser.add_argument("--benchmark", action="store_true", help="Run speed benchmark")
    parser.add_argument(
        "--num_split", type=int, default=1, choices=[1, 16], help="Number of splits")
    parser.add_argument("--page_block_size", type=int, default=128, help="Page block size")
    args = parser.parse_args()
    # NOTE(review): the four assignments below unconditionally override the
    # parsed CLI values, so --test_sink/--test_varlen/--dtype/--num_split have
    # no effect from the command line. This looks like intentional pinning of
    # the exercised configuration (sink + varlen, fp16, single split) for this
    # example, but confirm it is not leftover debug code.
    args.test_sink = True
    args.test_varlen = True
    args.dtype = T.float16
    args.num_split = 1
    # Dispatch: benchmark takes priority; otherwise varlen test (always taken
    # given the override above), with the equal-seqlen test as the fallback.
    if args.benchmark:
        speed_benchmark_decode_comparison(args)
    elif args.test_varlen:
        test_varlen_decode_main(args)
    else:
        test_equal_seqlen_decode_main(args)
examples/flash_decoding/example_mha_inference.py
View file @
667632cc
...
@@ -10,12 +10,12 @@ num_split = 4
...
@@ -10,12 +10,12 @@ num_split = 4
@
tilelang
.
jit
(
out_idx
=
[
5
])
@
tilelang
.
jit
(
out_idx
=
[
5
])
def
flashattn
(
batch
,
heads
,
seqlen_q
,
seqlen_kv
,
dim
,
is_causal
,
block_M
,
block_N
):
def
flashattn
(
batch
,
heads
,
seqlen_q
,
seqlen_kv
,
dim
,
is_causal
,
block_M
,
block_N
):
scale
=
(
1.0
/
dim
)
**
0.5
*
1.44269504
# log2(e)
scale
=
(
1.0
/
dim
)
**
0.5
*
1.44269504
# log2(e)
shape_q
=
[
batch
,
seqlen_q
,
heads
,
dim
]
shape_q
=
[
batch
,
seqlen_q
,
heads
,
dim
]
shape_kv
=
[
batch
,
seqlen_kv
,
heads
,
dim
]
shape_kv
=
[
batch
,
seqlen_kv
,
heads
,
dim
]
part_shape
=
[
batch
,
seqlen_q
,
heads
,
num_split
,
dim
]
part_shape
=
[
batch
,
seqlen_q
,
heads
,
num_split
,
dim
]
dtype
=
"
float16
"
dtype
=
T
.
float16
accum_dtype
=
"
float
"
accum_dtype
=
T
.
float
32
@
T
.
macro
@
T
.
macro
def
MMA0
(
def
MMA0
(
...
@@ -29,14 +29,11 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_
...
@@ -29,14 +29,11 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_
bid
:
T
.
int32
,
bid
:
T
.
int32
,
sid
:
T
.
int32
,
sid
:
T
.
int32
,
):
):
T
.
copy
(
T
.
copy
(
K
[
bid
,
(
seqlen_kv
//
num_split
)
*
sid
+
k
*
block_N
:
(
seqlen_kv
//
num_split
)
*
sid
+
(
k
+
1
)
*
block_N
,
hid
,
:],
K_shared
)
K
[
bid
,
(
seqlen_kv
//
num_split
)
*
sid
+
k
*
block_N
:(
seqlen_kv
//
num_split
)
*
sid
+
(
k
+
1
)
*
block_N
,
hid
,
:],
K_shared
)
# TODO: Handle causal split case
# TODO: Handle causal split case
if
is_causal
:
if
is_causal
:
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
acc_s
[
i
,
j
]
=
T
.
if_then_else
(
mid
*
block_M
+
i
>=
k
*
block_N
+
j
,
0
,
acc_s
[
i
,
j
]
=
T
.
if_then_else
(
mid
*
block_M
+
i
>=
k
*
block_N
+
j
,
0
,
-
T
.
infinity
(
acc_s
.
dtype
))
-
T
.
infinity
(
acc_s
.
dtype
))
else
:
else
:
T
.
clear
(
acc_s
)
T
.
clear
(
acc_s
)
T
.
gemm
(
Q_shared
,
K_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
T
.
gemm
(
Q_shared
,
K_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
...
@@ -52,24 +49,24 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_
...
@@ -52,24 +49,24 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_
bid
:
T
.
int32
,
bid
:
T
.
int32
,
sid
:
T
.
int32
,
sid
:
T
.
int32
,
):
):
T
.
copy
(
T
.
copy
(
V
[
bid
,
(
seqlen_kv
//
num_split
)
*
sid
+
k
*
block_N
:
(
seqlen_kv
//
num_split
)
*
sid
+
(
k
+
1
)
*
block_N
,
hid
,
:],
V_shared
)
V
[
bid
,
(
seqlen_kv
//
num_split
)
*
sid
+
k
*
block_N
:(
seqlen_kv
//
num_split
)
*
sid
+
(
k
+
1
)
*
block_N
,
hid
,
:],
V_shared
)
T
.
gemm
(
acc_s_cast
,
V_shared
,
acc_o
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
T
.
gemm
(
acc_s_cast
,
V_shared
,
acc_o
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
@
T
.
macro
@
T
.
macro
def
Softmax
(
def
Softmax
(
acc_s
:
T
.
FragmentBuffer
([
block_M
,
block_N
],
accum_dtype
),
acc_s
:
T
.
FragmentBuffer
([
block_M
,
block_N
],
accum_dtype
),
acc_s_cast
:
T
.
FragmentBuffer
([
block_M
,
block_N
],
dtype
),
acc_s_cast
:
T
.
FragmentBuffer
([
block_M
,
block_N
],
dtype
),
scores_max
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
scores_max
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
scores_max_prev
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
scores_max_prev
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
scores_scale
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
scores_scale
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
scores_sum
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
scores_sum
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
logsum
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
logsum
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
):
):
T
.
copy
(
scores_max
,
scores_max_prev
)
T
.
copy
(
scores_max
,
scores_max_prev
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
T
.
reduce_max
(
acc_s
,
scores_max
,
dim
=
1
,
clear
=
False
)
T
.
reduce_max
(
acc_s
,
scores_max
,
dim
=
1
,
clear
=
False
)
for
i
in
T
.
Parallel
(
block_M
):
scores_max
[
i
]
=
T
.
max
(
scores_max
[
i
],
scores_max_prev
[
i
])
# To do causal softmax, we need to set the scores_max to 0 if it is -inf
# To do causal softmax, we need to set the scores_max to 0 if it is -inf
# This process is called Check_inf in FlashAttention3 code, and it only need to be done
# This process is called Check_inf in FlashAttention3 code, and it only need to be done
# in the first ceil_div(kBlockM, kBlockN) steps.
# in the first ceil_div(kBlockM, kBlockN) steps.
...
@@ -89,23 +86,21 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_
...
@@ -89,23 +86,21 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_
@
T
.
macro
@
T
.
macro
def
Rescale
(
def
Rescale
(
acc_o
:
T
.
FragmentBuffer
([
block_M
,
dim
],
accum_dtype
),
acc_o
:
T
.
FragmentBuffer
([
block_M
,
dim
],
accum_dtype
),
scores_scale
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
scores_scale
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
):
):
for
i
,
j
in
T
.
Parallel
(
block_M
,
dim
):
for
i
,
j
in
T
.
Parallel
(
block_M
,
dim
):
acc_o
[
i
,
j
]
*=
scores_scale
[
i
]
acc_o
[
i
,
j
]
*=
scores_scale
[
i
]
@
T
.
macro
@
T
.
macro
def
flash_attn_split
(
def
flash_attn_split
(
Q
:
T
.
Tensor
(
shape_q
,
dtype
),
Q
:
T
.
Tensor
(
shape_q
,
dtype
),
K
:
T
.
Tensor
(
shape_kv
,
dtype
),
K
:
T
.
Tensor
(
shape_kv
,
dtype
),
V
:
T
.
Tensor
(
shape_kv
,
dtype
),
V
:
T
.
Tensor
(
shape_kv
,
dtype
),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
,
seqlen_q
],
dtype
),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
,
seqlen_q
],
dtype
),
Output_partial
:
T
.
Tensor
(
part_shape
,
dtype
),
Output_partial
:
T
.
Tensor
(
part_shape
,
dtype
),
):
):
with
T
.
Kernel
(
with
T
.
Kernel
(
T
.
ceildiv
(
seqlen_q
,
block_M
),
heads
*
batch
,
num_split
,
threads
=
128
)
as
(
bx
,
by
,
bz
):
T
.
ceildiv
(
seqlen_q
,
block_M
),
heads
*
batch
,
num_split
,
threads
=
128
)
as
(
bx
,
by
,
bz
):
Q_shared
=
T
.
alloc_shared
([
block_M
,
dim
],
dtype
)
Q_shared
=
T
.
alloc_shared
([
block_M
,
dim
],
dtype
)
K_shared
=
T
.
alloc_shared
([
block_N
,
dim
],
dtype
)
K_shared
=
T
.
alloc_shared
([
block_N
,
dim
],
dtype
)
V_shared
=
T
.
alloc_shared
([
block_N
,
dim
],
dtype
)
V_shared
=
T
.
alloc_shared
([
block_N
,
dim
],
dtype
)
...
@@ -126,39 +121,36 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_
...
@@ -126,39 +121,36 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_
# NOTE(wt): tma barrier has some problems with padded dimensions (seq_q here) currently
# NOTE(wt): tma barrier has some problems with padded dimensions (seq_q here) currently
# disable relevant tma copy and use SIMT as fallback for now
# disable relevant tma copy and use SIMT as fallback for now
T
.
copy
(
Q
[
bid
,
mid
*
block_M
:
(
mid
+
1
)
*
block_M
,
hid
,
:],
Q_shared
,
disable_tma
=
True
)
T
.
copy
(
Q
[
bid
,
mid
*
block_M
:
(
mid
+
1
)
*
block_M
,
hid
,
:],
Q_shared
,
disable_tma
=
True
)
T
.
fill
(
acc_o
,
0
)
T
.
fill
(
acc_o
,
0
)
T
.
fill
(
logsum
,
0
)
T
.
fill
(
logsum
,
0
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
# TODO: Handle causal split case
# TODO: Handle causal split case
loop_range
=
(
loop_range
=
(
T
.
min
(
T
.
ceildiv
(
seqlen_kv
,
block_N
),
T
.
ceildiv
(
T
.
min
(
T
.
ceildiv
(
seqlen_kv
,
block_N
),
T
.
ceildiv
((
mid
+
1
)
*
block_M
,
block_N
))
(
mid
+
1
)
*
block_M
,
block_N
))
if
is_causal
else
T
.
ceildiv
(
if
is_causal
(
seqlen_kv
//
num_split
),
block_N
))
else
T
.
ceildiv
((
seqlen_kv
//
num_split
),
block_N
)
)
for
k
in
T
.
Pipelined
(
loop_range
,
num_stages
=
2
):
for
k
in
T
.
Pipelined
(
loop_range
,
num_stages
=
2
):
MMA0
(
K
,
Q_shared
,
K_shared
,
acc_s
,
k
,
mid
,
hid
,
bid
,
sid
)
MMA0
(
K
,
Q_shared
,
K_shared
,
acc_s
,
k
,
mid
,
hid
,
bid
,
sid
)
Softmax
(
acc_s
,
acc_s_cast
,
scores_max
,
scores_max_prev
,
scores_scale
,
scores_sum
,
Softmax
(
acc_s
,
acc_s_cast
,
scores_max
,
scores_max_prev
,
scores_scale
,
scores_sum
,
logsum
)
logsum
)
Rescale
(
acc_o
,
scores_scale
)
Rescale
(
acc_o
,
scores_scale
)
MMA1
(
V
,
V_shared
,
acc_s_cast
,
acc_o
,
k
,
hid
,
bid
,
sid
)
MMA1
(
V
,
V_shared
,
acc_s_cast
,
acc_o
,
k
,
hid
,
bid
,
sid
)
for
i
,
j
in
T
.
Parallel
(
block_M
,
dim
):
for
i
,
j
in
T
.
Parallel
(
block_M
,
dim
):
acc_o
[
i
,
j
]
/=
logsum
[
i
]
acc_o
[
i
,
j
]
/=
logsum
[
i
]
for
i
in
T
.
Parallel
(
block_M
):
for
i
in
T
.
Parallel
(
block_M
):
logsum
[
i
]
=
T
.
log2
(
logsum
[
i
])
+
scores_max
[
i
]
*
scale
logsum
[
i
]
=
T
.
log2
(
logsum
[
i
])
+
scores_max
[
i
]
*
scale
T
.
copy
(
logsum
,
glse
[
bid
,
hid
,
sid
,
mid
*
block_M
:
(
mid
+
1
)
*
block_M
])
T
.
copy
(
logsum
,
glse
[
bid
,
hid
,
sid
,
mid
*
block_M
:
(
mid
+
1
)
*
block_M
])
T
.
copy
(
acc_o
,
O_shared
)
T
.
copy
(
acc_o
,
O_shared
)
T
.
copy
(
T
.
copy
(
O_shared
,
Output_partial
[
bid
,
mid
*
block_M
:
(
mid
+
1
)
*
block_M
,
hid
,
sid
,
:],
disable_tma
=
True
)
O_shared
,
Output_partial
[
bid
,
mid
*
block_M
:(
mid
+
1
)
*
block_M
,
hid
,
sid
,
:],
disable_tma
=
True
)
@
T
.
macro
@
T
.
macro
def
combine
(
def
combine
(
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
,
seqlen_q
],
dtype
),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
,
seqlen_q
],
dtype
),
Output_partial
:
T
.
Tensor
(
part_shape
,
dtype
),
Output_partial
:
T
.
Tensor
(
part_shape
,
dtype
),
Output
:
T
.
Tensor
(
shape_q
,
dtype
),
Output
:
T
.
Tensor
(
shape_q
,
dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
seqlen_q
,
block_M
),
heads
,
batch
,
threads
=
128
)
as
(
bx
,
by
,
bz
):
with
T
.
Kernel
(
T
.
ceildiv
(
seqlen_q
,
block_M
),
heads
,
batch
,
threads
=
128
)
as
(
bx
,
by
,
bz
):
po_local
=
T
.
alloc_fragment
([
block_M
,
dim
],
dtype
)
po_local
=
T
.
alloc_fragment
([
block_M
,
dim
],
dtype
)
...
@@ -171,20 +163,25 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_
...
@@ -171,20 +163,25 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_
lse_max_local
=
T
.
alloc_fragment
([
block_M
],
accum_dtype
)
lse_max_local
=
T
.
alloc_fragment
([
block_M
],
accum_dtype
)
scale_local
=
T
.
alloc_fragment
([
block_M
],
accum_dtype
)
scale_local
=
T
.
alloc_fragment
([
block_M
],
accum_dtype
)
T
.
annotate_layout
({
T
.
annotate_layout
(
o_accum_local
:
T
.
Fragment
(
o_accum_local
.
shape
,
forward_thread_fn
=
lambda
i
,
j
:
i
),
{
o_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
o_shared
),
o_accum_local
:
T
.
Fragment
(
o_accum_local
.
shape
,
forward_thread_fn
=
lambda
i
,
j
:
i
),
po_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
po_shared
),
o_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
o_shared
),
})
po_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
po_shared
),
}
)
T
.
clear
(
lse_logsum_local
)
T
.
clear
(
lse_logsum_local
)
T
.
clear
(
o_accum_local
)
T
.
clear
(
o_accum_local
)
T
.
copy
(
glse
[
T
.
copy
(
bz
,
glse
[
by
,
bz
,
:,
by
,
bx
*
block_M
:(
bx
+
1
)
*
block_M
,
:,
],
lse_local
)
bx
*
block_M
:
(
bx
+
1
)
*
block_M
,
],
lse_local
,
)
T
.
reduce_max
(
lse_local
,
lse_max_local
,
dim
=
0
,
clear
=
False
)
T
.
reduce_max
(
lse_local
,
lse_max_local
,
dim
=
0
,
clear
=
False
)
for
k
in
T
.
Pipelined
(
num_split
):
for
k
in
T
.
Pipelined
(
num_split
):
T
.
copy
(
lse_local
[
k
,
:],
lse_local_split
)
T
.
copy
(
lse_local
[
k
,
:],
lse_local_split
)
...
@@ -193,10 +190,7 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_
...
@@ -193,10 +190,7 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_
for
i
in
T
.
Parallel
(
block_M
):
for
i
in
T
.
Parallel
(
block_M
):
lse_logsum_local
[
i
]
=
T
.
log2
(
lse_logsum_local
[
i
])
+
lse_max_local
[
i
]
lse_logsum_local
[
i
]
=
T
.
log2
(
lse_logsum_local
[
i
])
+
lse_max_local
[
i
]
for
k
in
T
.
Pipelined
(
num_split
,
num_stages
=
2
):
for
k
in
T
.
Pipelined
(
num_split
,
num_stages
=
2
):
T
.
copy
(
T
.
copy
(
Output_partial
[
bz
,
bx
*
block_M
:
(
bx
+
1
)
*
block_M
,
by
,
k
,
:],
po_shared
,
disable_tma
=
True
)
Output_partial
[
bz
,
bx
*
block_M
:(
bx
+
1
)
*
block_M
,
by
,
k
,
:],
po_shared
,
disable_tma
=
True
)
T
.
copy
(
po_shared
,
po_local
)
T
.
copy
(
po_shared
,
po_local
)
for
i
in
T
.
Parallel
(
block_M
):
for
i
in
T
.
Parallel
(
block_M
):
lse_local_split
[
i
]
=
lse_local
[
k
,
i
]
lse_local_split
[
i
]
=
lse_local
[
k
,
i
]
...
@@ -205,16 +199,16 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_
...
@@ -205,16 +199,16 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_
for
i
,
j
in
T
.
Parallel
(
block_M
,
dim
):
for
i
,
j
in
T
.
Parallel
(
block_M
,
dim
):
o_accum_local
[
i
,
j
]
+=
po_local
[
i
,
j
]
*
scale_local
[
i
]
o_accum_local
[
i
,
j
]
+=
po_local
[
i
,
j
]
*
scale_local
[
i
]
T
.
copy
(
o_accum_local
,
o_shared
)
T
.
copy
(
o_accum_local
,
o_shared
)
T
.
copy
(
o_shared
,
Output
[
bz
,
bx
*
block_M
:
(
bx
+
1
)
*
block_M
,
by
,
:],
disable_tma
=
True
)
T
.
copy
(
o_shared
,
Output
[
bz
,
bx
*
block_M
:
(
bx
+
1
)
*
block_M
,
by
,
:],
disable_tma
=
True
)
@
T
.
prim_func
@
T
.
prim_func
def
flashattn_mha_inference
(
def
flashattn_mha_inference
(
Q
:
T
.
Tensor
(
shape_q
,
dtype
),
Q
:
T
.
Tensor
(
shape_q
,
dtype
),
K
:
T
.
Tensor
(
shape_kv
,
dtype
),
K
:
T
.
Tensor
(
shape_kv
,
dtype
),
V
:
T
.
Tensor
(
shape_kv
,
dtype
),
V
:
T
.
Tensor
(
shape_kv
,
dtype
),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
,
seqlen_q
],
dtype
),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
,
seqlen_q
],
dtype
),
Output_partial
:
T
.
Tensor
(
part_shape
,
dtype
),
# [batch, seqlen_q, heads, num_split, dim]
Output_partial
:
T
.
Tensor
(
part_shape
,
dtype
),
# [batch, seqlen_q, heads, num_split, dim]
Output
:
T
.
Tensor
(
shape_q
,
dtype
),
Output
:
T
.
Tensor
(
shape_q
,
dtype
),
):
):
flash_attn_split
(
Q
,
K
,
V
,
glse
,
Output_partial
)
flash_attn_split
(
Q
,
K
,
V
,
glse
,
Output_partial
)
combine
(
glse
,
Output_partial
,
Output
)
combine
(
glse
,
Output_partial
,
Output
)
...
@@ -225,10 +219,10 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_
...
@@ -225,10 +219,10 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_
def
ref_program
(
Q
,
K
,
V
,
glse
,
Output_partial
,
causal
):
def
ref_program
(
Q
,
K
,
V
,
glse
,
Output_partial
,
causal
):
assert
causal
is
False
assert
causal
is
False
dim
=
Q
.
size
(
-
1
)
dim
=
Q
.
size
(
-
1
)
scores
=
torch
.
einsum
(
'
bqhd,bkhd->bhqk
'
,
Q
,
K
)
scores
=
torch
.
einsum
(
"
bqhd,bkhd->bhqk
"
,
Q
,
K
)
scores
=
scores
/
torch
.
sqrt
(
torch
.
tensor
(
dim
,
dtype
=
scores
.
dtype
))
scores
=
scores
/
torch
.
sqrt
(
torch
.
tensor
(
dim
,
dtype
=
scores
.
dtype
))
attention_weights
=
F
.
softmax
(
scores
,
dim
=-
1
)
attention_weights
=
F
.
softmax
(
scores
,
dim
=-
1
)
output
=
torch
.
einsum
(
'
bhqk,bkhd->bqhd
'
,
attention_weights
,
V
)
output
=
torch
.
einsum
(
"
bhqk,bkhd->bqhd
"
,
attention_weights
,
V
)
return
output
return
output
...
@@ -256,7 +250,7 @@ def flash_split_ref(Q, K, V, causal):
...
@@ -256,7 +250,7 @@ def flash_split_ref(Q, K, V, causal):
block_N
=
128
block_N
=
128
seqlen_kv
=
K
.
size
(
1
)
seqlen_kv
=
K
.
size
(
1
)
scale
=
(
1.0
/
dim
)
**
0.5
*
1.44269504
# log2(e)
scale
=
(
1.0
/
dim
)
**
0.5
*
1.44269504
# log2(e)
acc_s
=
torch
.
empty
((
batch
,
nheads
,
block_M
,
block_N
),
device
=
"cuda"
,
dtype
=
torch
.
float
)
acc_s
=
torch
.
empty
((
batch
,
nheads
,
block_M
,
block_N
),
device
=
"cuda"
,
dtype
=
torch
.
float
)
acc_s_cast
=
torch
.
empty
((
batch
,
nheads
,
block_M
,
block_N
),
device
=
"cuda"
,
dtype
=
torch
.
float16
)
acc_s_cast
=
torch
.
empty
((
batch
,
nheads
,
block_M
,
block_N
),
device
=
"cuda"
,
dtype
=
torch
.
float16
)
acc_o
=
torch
.
empty
((
batch
,
block_M
,
nheads
,
dim
),
device
=
"cuda"
,
dtype
=
torch
.
float
)
acc_o
=
torch
.
empty
((
batch
,
block_M
,
nheads
,
dim
),
device
=
"cuda"
,
dtype
=
torch
.
float
)
...
@@ -273,14 +267,15 @@ def flash_split_ref(Q, K, V, causal):
...
@@ -273,14 +267,15 @@ def flash_split_ref(Q, K, V, causal):
for
ks
in
range
(
num_split
):
for
ks
in
range
(
num_split
):
acc_o
.
fill_
(
0
)
acc_o
.
fill_
(
0
)
logsum
.
fill_
(
0
)
logsum
.
fill_
(
0
)
scores_max
.
fill_
(
float
(
'
-inf
'
))
scores_max
.
fill_
(
float
(
"
-inf
"
))
scores_max_prev
.
fill_
(
float
(
'
-inf
'
))
scores_max_prev
.
fill_
(
float
(
"
-inf
"
))
for
i
in
range
(
int
((
seqlen_kv
//
num_split
)
/
block_N
)):
for
i
in
range
(
int
((
seqlen_kv
//
num_split
)
/
block_N
)):
acc_s
.
fill_
(
0
)
acc_s
.
fill_
(
0
)
acc_s
=
torch
.
einsum
(
'bqhd,bkhd->bhqk'
,
Q_
,
acc_s
=
torch
.
einsum
(
K
[:,
(
seqlen_kv
//
num_split
)
*
ks
+
"bqhd,bkhd->bhqk"
,
i
*
block_N
:(
seqlen_kv
//
num_split
)
*
ks
+
Q_
,
(
i
+
1
)
*
block_N
,
:,
:])
# [batch, seqlen, nheads, block_N]
K
[:,
(
seqlen_kv
//
num_split
)
*
ks
+
i
*
block_N
:
(
seqlen_kv
//
num_split
)
*
ks
+
(
i
+
1
)
*
block_N
,
:,
:],
)
# [batch, seqlen, nheads, block_N]
scores_max_prev
=
scores_max
scores_max_prev
=
scores_max
scores_max
=
acc_s
.
max
(
dim
=-
1
,
keepdim
=
False
).
values
# [blockM]
scores_max
=
acc_s
.
max
(
dim
=-
1
,
keepdim
=
False
).
values
# [blockM]
scores_scale
=
torch
.
exp2
(
scores_max_prev
-
scores_max
)
scores_scale
=
torch
.
exp2
(
scores_max_prev
-
scores_max
)
...
@@ -288,9 +283,10 @@ def flash_split_ref(Q, K, V, causal):
...
@@ -288,9 +283,10 @@ def flash_split_ref(Q, K, V, causal):
acc_s
=
torch
.
exp2
(
acc_s
-
scores_max
[:,
:,
:,
None
])
acc_s
=
torch
.
exp2
(
acc_s
-
scores_max
[:,
:,
:,
None
])
acc_s_cast
=
acc_s
.
to
(
torch
.
float16
)
acc_s_cast
=
acc_s
.
to
(
torch
.
float16
)
acc_o
+=
torch
.
einsum
(
acc_o
+=
torch
.
einsum
(
'bhqk,bkhd->bqhd'
,
acc_s_cast
,
"bhqk,bkhd->bqhd"
,
V
[:,
(
seqlen_kv
//
num_split
)
*
ks
+
i
*
block_N
:(
seqlen_kv
//
num_split
)
*
ks
+
acc_s_cast
,
(
i
+
1
)
*
block_N
,
:,
:])
V
[:,
(
seqlen_kv
//
num_split
)
*
ks
+
i
*
block_N
:
(
seqlen_kv
//
num_split
)
*
ks
+
(
i
+
1
)
*
block_N
,
:,
:],
)
scores_sum
=
acc_s
.
sum
(
dim
=-
1
,
keepdim
=
False
)
scores_sum
=
acc_s
.
sum
(
dim
=-
1
,
keepdim
=
False
)
logsum
=
logsum
*
scores_scale
+
scores_sum
logsum
=
logsum
*
scores_scale
+
scores_sum
acc_o
/=
logsum
[:,
:,
:,
None
].
transpose
(
1
,
2
)
acc_o
/=
logsum
[:,
:,
:,
None
].
transpose
(
1
,
2
)
...
@@ -298,8 +294,7 @@ def flash_split_ref(Q, K, V, causal):
...
@@ -298,8 +294,7 @@ def flash_split_ref(Q, K, V, causal):
gacc_o
[
ks
,
:,
:,
:,
:]
=
acc_o
gacc_o
[
ks
,
:,
:,
:,
:]
=
acc_o
glogsum
[
ks
,
:,
:,
:]
=
logsum
glogsum
[
ks
,
:,
:,
:]
=
logsum
return
glogsum
.
to
(
torch
.
float16
).
permute
(
1
,
2
,
0
,
return
glogsum
.
to
(
torch
.
float16
).
permute
(
1
,
2
,
0
,
3
),
gacc_o
.
to
(
torch
.
float16
).
permute
(
1
,
2
,
3
,
0
,
4
)
3
),
gacc_o
.
to
(
torch
.
float16
).
permute
(
1
,
2
,
3
,
0
,
4
)
def
main
(
BATCH
=
1
,
H
=
32
,
Q_CTX
=
128
,
KV_CTX
=
8192
,
D_HEAD
=
128
,
causal
=
False
):
def
main
(
BATCH
=
1
,
H
=
32
,
Q_CTX
=
128
,
KV_CTX
=
8192
,
D_HEAD
=
128
,
causal
=
False
):
...
...
examples/fusedmoe/example_fusedmoe_tilelang.py
View file @
667632cc
...
@@ -9,17 +9,18 @@ from example_fusedmoe_torch import *
...
@@ -9,17 +9,18 @@ from example_fusedmoe_torch import *
@
tilelang
.
jit
(
pass_configs
=
{
"tl.disable_tma_lower"
:
True
,
"tl.disable_warp_specialized"
:
True
})
@
tilelang
.
jit
(
pass_configs
=
{
"tl.disable_tma_lower"
:
True
,
"tl.disable_warp_specialized"
:
True
})
def
moe_forward_tilelang_shared
(
d_hidden
,
def
moe_forward_tilelang_shared
(
d_expert
,
d_hidden
,
n_shared_experts
,
d_expert
,
dtype
,
n_shared_experts
,
num_tokens
,
dtype
,
block_token
=
128
,
num_tokens
,
block_dhidden
=
128
,
block_token
=
128
,
block_dexpert
=
128
,
block_dhidden
=
128
,
threads
=
256
,
block_dexpert
=
128
,
num_stages
=
1
):
threads
=
256
,
num_stages
=
1
,
):
scale
=
1.44269504
# log2(e)
scale
=
1.44269504
# log2(e)
# Parameters
# Parameters
...
@@ -32,21 +33,19 @@ def moe_forward_tilelang_shared(d_hidden,
...
@@ -32,21 +33,19 @@ def moe_forward_tilelang_shared(d_hidden,
shared_W_up_shape
=
(
dexpert
,
dhidden
)
shared_W_up_shape
=
(
dexpert
,
dhidden
)
shared_W_down_shape
=
(
dhidden
,
dexpert
)
shared_W_down_shape
=
(
dhidden
,
dexpert
)
accum_type
=
"
float32
"
accum_type
=
T
.
float32
@
T
.
prim_func
@
T
.
prim_func
def
kernel_shared
(
def
kernel_shared
(
input
:
T
.
Tensor
(
input_shape
,
dtype
),
# type: ignore
input
:
T
.
Tensor
(
input_shape
,
dtype
),
# type: ignore
shared_W_gate
:
T
.
Tensor
(
shared_W_gate_shape
,
dtype
),
# type: ignore
shared_W_gate
:
T
.
Tensor
(
shared_W_gate_shape
,
dtype
),
# type: ignore
shared_W_up
:
T
.
Tensor
(
shared_W_up_shape
,
dtype
),
# type: ignore
shared_W_up
:
T
.
Tensor
(
shared_W_up_shape
,
dtype
),
# type: ignore
shared_W_down
:
T
.
Tensor
(
shared_W_down_shape
,
dtype
),
# type: ignore
shared_W_down
:
T
.
Tensor
(
shared_W_down_shape
,
dtype
),
# type: ignore
up_logits
:
T
.
Tensor
((
num_tokens
,
dexpert
),
dtype
),
# type: ignore
up_logits
:
T
.
Tensor
((
num_tokens
,
dexpert
),
dtype
),
# type: ignore
output
:
T
.
Tensor
(
input_shape
,
dtype
),
# type: ignore
output
:
T
.
Tensor
(
input_shape
,
dtype
),
# type: ignore
):
):
# Step 1: Compute gate and up logits
# Step 1: Compute gate and up logits
with
T
.
Kernel
(
with
T
.
Kernel
(
T
.
ceildiv
(
num_tokens
,
block_token
),
T
.
ceildiv
(
dexpert
,
block_dexpert
),
threads
=
threads
)
as
(
bx
,
by
):
T
.
ceildiv
(
num_tokens
,
block_token
),
T
.
ceildiv
(
dexpert
,
block_dexpert
),
threads
=
threads
)
as
(
bx
,
by
):
# Split the block to shared experts and routed experts
# Split the block to shared experts and routed experts
input_shared
=
T
.
alloc_fragment
((
block_token
,
block_dhidden
),
dtype
=
dtype
)
input_shared
=
T
.
alloc_fragment
((
block_token
,
block_dhidden
),
dtype
=
dtype
)
W_gate_shared
=
T
.
alloc_shared
((
block_dexpert
,
block_dhidden
),
dtype
=
dtype
)
W_gate_shared
=
T
.
alloc_shared
((
block_dexpert
,
block_dhidden
),
dtype
=
dtype
)
...
@@ -70,16 +69,13 @@ def moe_forward_tilelang_shared(d_hidden,
...
@@ -70,16 +69,13 @@ def moe_forward_tilelang_shared(d_hidden,
# Fuse with SiLU and element-wise product
# Fuse with SiLU and element-wise product
for
i
,
j
in
T
.
Parallel
(
block_token
,
block_dexpert
):
for
i
,
j
in
T
.
Parallel
(
block_token
,
block_dexpert
):
gate_logits_local
[
i
,
j
]
=
gate_logits_local
[
i
,
j
]
*
(
gate_logits_local
[
i
,
j
]
=
gate_logits_local
[
i
,
j
]
*
(
1.0
/
(
1.0
+
T
.
exp2
(
-
gate_logits_local
[
i
,
j
]
*
scale
)))
1.0
/
(
1.0
+
T
.
exp2
(
-
gate_logits_local
[
i
,
j
]
*
scale
)))
up_logits_local
[
i
,
j
]
=
up_logits_local
[
i
,
j
]
*
gate_logits_local
[
i
,
j
]
up_logits_local
[
i
,
j
]
=
up_logits_local
[
i
,
j
]
*
gate_logits_local
[
i
,
j
]
T
.
copy
(
up_logits_local
,
up_logits
[
bx
*
block_token
,
by
*
block_dexpert
])
T
.
copy
(
up_logits_local
,
up_logits
[
bx
*
block_token
,
by
*
block_dexpert
])
# Step 2: Compute down logits
# Step 2: Compute down logits
with
T
.
Kernel
(
with
T
.
Kernel
(
T
.
ceildiv
(
num_tokens
,
block_token
),
T
.
ceildiv
(
dhidden
,
block_dhidden
),
threads
=
threads
)
as
(
bx
,
by
):
T
.
ceildiv
(
num_tokens
,
block_token
),
T
.
ceildiv
(
dhidden
,
block_dhidden
),
threads
=
threads
)
as
(
bx
,
by
):
up_logits_shared
=
T
.
alloc_fragment
((
block_token
,
block_dexpert
),
dtype
=
dtype
)
up_logits_shared
=
T
.
alloc_fragment
((
block_token
,
block_dexpert
),
dtype
=
dtype
)
W_down_shared
=
T
.
alloc_shared
((
block_dhidden
,
block_dexpert
),
dtype
=
dtype
)
W_down_shared
=
T
.
alloc_shared
((
block_dhidden
,
block_dexpert
),
dtype
=
dtype
)
output_local
=
T
.
alloc_fragment
((
block_token
,
block_dhidden
),
dtype
=
accum_type
)
output_local
=
T
.
alloc_fragment
((
block_token
,
block_dhidden
),
dtype
=
accum_type
)
...
@@ -98,20 +94,21 @@ def moe_forward_tilelang_shared(d_hidden,
...
@@ -98,20 +94,21 @@ def moe_forward_tilelang_shared(d_hidden,
@
tilelang
.
jit
(
pass_configs
=
{
"tl.disable_tma_lower"
:
True
,
"tl.disable_warp_specialized"
:
True
})
@
tilelang
.
jit
(
pass_configs
=
{
"tl.disable_tma_lower"
:
True
,
"tl.disable_warp_specialized"
:
True
})
def
moe_forward_tilelang_routed
(
d_hidden
,
def
moe_forward_tilelang_routed
(
d_expert
,
d_hidden
,
n_routed_experts
,
d_expert
,
dtype
,
n_routed_experts
,
group_sum
,
dtype
,
group_count
,
group_sum
,
block_token
=
128
,
group_count
,
block_dhidden
=
128
,
block_token
=
128
,
block_dexpert
=
128
,
block_dhidden
=
128
,
threads
=
256
,
block_dexpert
=
128
,
num_stages
=
1
,
threads
=
256
,
k_pack
=
1
,
num_stages
=
1
,
coalesced_width
=
None
):
k_pack
=
1
,
coalesced_width
=
None
,
):
scale
=
1.44269504
# log2(e)
scale
=
1.44269504
# log2(e)
# Parameters
# Parameters
...
@@ -124,7 +121,7 @@ def moe_forward_tilelang_routed(d_hidden,
...
@@ -124,7 +121,7 @@ def moe_forward_tilelang_routed(d_hidden,
# group_count = len(group_sizes_list)
# group_count = len(group_sizes_list)
# M = sum([(group_size + block_token - 1) // block_token for group_size in group_sizes_list])
# M = sum([(group_size + block_token - 1) // block_token for group_size in group_sizes_list])
M
=
math
.
ceil
(
group_sum
/
block_token
)
+
group_count
M
=
math
.
ceil
(
group_sum
/
block_token
)
+
group_count
accum_dtype
=
"
float32
"
accum_dtype
=
T
.
float32
# Tensors: Note that input shape is reshape to (bs * seq_len * n_experts_per_token, dhidden) for grouped gemm
# Tensors: Note that input shape is reshape to (bs * seq_len * n_experts_per_token, dhidden) for grouped gemm
input_shape
=
(
group_sum
,
dhidden
)
input_shape
=
(
group_sum
,
dhidden
)
...
@@ -132,22 +129,22 @@ def moe_forward_tilelang_routed(d_hidden,
...
@@ -132,22 +129,22 @@ def moe_forward_tilelang_routed(d_hidden,
routed_expert_gate_shape
=
(
n_routed_experts
,
dexpert
,
dhidden
)
routed_expert_gate_shape
=
(
n_routed_experts
,
dexpert
,
dhidden
)
routed_expert_up_shape
=
(
n_routed_experts
,
dexpert
,
dhidden
)
routed_expert_up_shape
=
(
n_routed_experts
,
dexpert
,
dhidden
)
routed_expert_down_shape
=
(
n_routed_experts
,
dhidden
,
dexpert
)
routed_expert_down_shape
=
(
n_routed_experts
,
dhidden
,
dexpert
)
routed_expert_weights_shape
=
(
group_sum
)
routed_expert_weights_shape
=
group_sum
group_sizes_shape
=
(
n_routed_experts
)
group_sizes_shape
=
n_routed_experts
@
T
.
prim_func
@
T
.
prim_func
def
kernel
(
def
kernel
(
input
:
T
.
Tensor
(
input_shape
,
dtype
),
# type: ignore
input
:
T
.
Tensor
(
input_shape
,
dtype
),
# type: ignore
routed_expert_gate
:
T
.
Tensor
(
routed_expert_gate_shape
,
dtype
),
# type: ignore
routed_expert_gate
:
T
.
Tensor
(
routed_expert_gate_shape
,
dtype
),
# type: ignore
routed_expert_up
:
T
.
Tensor
(
routed_expert_up_shape
,
dtype
),
# type: ignore
routed_expert_up
:
T
.
Tensor
(
routed_expert_up_shape
,
dtype
),
# type: ignore
routed_expert_down
:
T
.
Tensor
(
routed_expert_down_shape
,
dtype
),
# type: ignore
routed_expert_down
:
T
.
Tensor
(
routed_expert_down_shape
,
dtype
),
# type: ignore
routed_expert_weights
:
T
.
Tensor
(
routed_expert_weights_shape
,
dtype
),
# type: ignore
routed_expert_weights
:
T
.
Tensor
(
routed_expert_weights_shape
,
dtype
),
# type: ignore
group_sizes
:
T
.
Tensor
(
group_sizes_shape
,
"
int32
"
),
# type: ignore
group_sizes
:
T
.
Tensor
(
group_sizes_shape
,
T
.
int32
),
# type: ignore
group_offsets
:
T
.
Tensor
(
group_sizes_shape
,
"
int32
"
),
# type: ignore
group_offsets
:
T
.
Tensor
(
group_sizes_shape
,
T
.
int32
),
# type: ignore
group_padded_offsets
:
T
.
Tensor
(
group_sizes_shape
,
"
int32
"
),
# type: ignore
group_padded_offsets
:
T
.
Tensor
(
group_sizes_shape
,
T
.
int32
),
# type: ignore
group_idx_for_bx
:
T
.
Tensor
((
M
,),
"
int32
"
),
# type: ignore
group_idx_for_bx
:
T
.
Tensor
((
M
,),
T
.
int32
),
# type: ignore
up_logits
:
T
.
Tensor
(
intermediate_shape
,
dtype
),
# type: ignore
up_logits
:
T
.
Tensor
(
intermediate_shape
,
dtype
),
# type: ignore
output
:
T
.
Tensor
(
input_shape
,
dtype
),
# type: ignore
output
:
T
.
Tensor
(
input_shape
,
dtype
),
# type: ignore
):
):
# Step 1: Compute gate and up logits
# Step 1: Compute gate and up logits
with
T
.
Kernel
(
M
,
T
.
ceildiv
(
dexpert
,
block_dexpert
),
threads
=
threads
)
as
(
bx
,
by
):
with
T
.
Kernel
(
M
,
T
.
ceildiv
(
dexpert
,
block_dexpert
),
threads
=
threads
)
as
(
bx
,
by
):
...
@@ -158,8 +155,8 @@ def moe_forward_tilelang_routed(d_hidden,
...
@@ -158,8 +155,8 @@ def moe_forward_tilelang_routed(d_hidden,
gate_logits_local
=
T
.
alloc_fragment
((
block_token
,
block_dexpert
),
dtype
=
accum_dtype
)
gate_logits_local
=
T
.
alloc_fragment
((
block_token
,
block_dexpert
),
dtype
=
accum_dtype
)
up_logits_local
=
T
.
alloc_fragment
((
block_token
,
block_dexpert
),
dtype
=
accum_dtype
)
up_logits_local
=
T
.
alloc_fragment
((
block_token
,
block_dexpert
),
dtype
=
accum_dtype
)
cur_group_idx
=
T
.
alloc_local
([
1
],
"
int32
"
)
cur_group_idx
=
T
.
alloc_local
([
1
],
T
.
int32
)
cur_group_size
=
T
.
alloc_local
([
1
],
"
int32
"
)
cur_group_size
=
T
.
alloc_local
([
1
],
T
.
int32
)
T
.
use_swizzle
(
10
,
enable
=
True
)
T
.
use_swizzle
(
10
,
enable
=
True
)
...
@@ -168,48 +165,37 @@ def moe_forward_tilelang_routed(d_hidden,
...
@@ -168,48 +165,37 @@ def moe_forward_tilelang_routed(d_hidden,
cur_group_idx
[
0
]
=
group_idx_for_bx
[
bx
]
cur_group_idx
[
0
]
=
group_idx_for_bx
[
bx
]
cur_group_size
[
0
]
=
group_sizes
[
cur_group_idx
[
0
]]
cur_group_size
[
0
]
=
group_sizes
[
cur_group_idx
[
0
]]
m_start
=
m_start_padded
-
group_padded_offsets
[
cur_group_idx
[
0
]]
+
group_offsets
[
m_start
=
m_start_padded
-
group_padded_offsets
[
cur_group_idx
[
0
]]
+
group_offsets
[
cur_group_idx
[
0
]]
cur_group_idx
[
0
]]
actual_rows
=
T
.
max
(
0
,
T
.
min
(
block_token
,
cur_group_size
[
0
]
-
(
m_start_padded
-
group_padded_offsets
[
cur_group_idx
[
0
]])))
actual_rows
=
T
.
max
(
0
,
T
.
min
(
block_token
,
cur_group_size
[
0
]
-
(
m_start_padded
-
group_padded_offsets
[
cur_group_idx
[
0
]])))
T
.
clear
(
gate_logits_local
)
T
.
clear
(
gate_logits_local
)
T
.
clear
(
up_logits_local
)
T
.
clear
(
up_logits_local
)
for
k
in
T
.
Pipelined
(
T
.
ceildiv
(
dhidden
,
block_dhidden
),
num_stages
=
num_stages
):
for
k
in
T
.
Pipelined
(
T
.
ceildiv
(
dhidden
,
block_dhidden
),
num_stages
=
num_stages
):
T
.
copy
(
T
.
copy
(
input
[
m_start
:
m_start
+
block_token
,
k
*
block_dhidden
:
(
k
+
1
)
*
block_dhidden
],
input
[
m_start
:
m_start
+
block_token
,
k
*
block_dhidden
:
(
k
+
1
)
*
block_dhidden
],
input_shared
,
input_shared
,
coalesced_width
=
coalesced_width
)
coalesced_width
=
coalesced_width
,
)
T
.
copy
(
T
.
copy
(
routed_expert_gate
[
cur_group_idx
[
0
],
routed_expert_gate
[
by
*
block_dexpert
:(
by
+
1
)
*
block_dexpert
,
cur_group_idx
[
0
],
by
*
block_dexpert
:
(
by
+
1
)
*
block_dexpert
,
k
*
block_dhidden
:
(
k
+
1
)
*
block_dhidden
k
*
block_dhidden
:(
k
+
1
)
*
block_dhidden
],
],
routed_expert_gate_shared
,
coalesced_width
=
coalesced_width
)
T
.
gemm
(
input_shared
,
routed_expert_gate_shared
,
routed_expert_gate_shared
,
gate_logits_local
,
coalesced_width
=
coalesced_width
,
k_pack
=
k_pack
,
)
transpose_B
=
True
)
T
.
gemm
(
input_shared
,
routed_expert_gate_shared
,
gate_logits_local
,
k_pack
=
k_pack
,
transpose_B
=
True
)
T
.
copy
(
T
.
copy
(
routed_expert_up
[
cur_group_idx
[
0
],
by
*
block_dexpert
:(
by
+
1
)
*
block_dexpert
,
routed_expert_up
[
k
*
block_dhidden
:(
k
+
1
)
*
block_dhidden
],
cur_group_idx
[
0
],
by
*
block_dexpert
:
(
by
+
1
)
*
block_dexpert
,
k
*
block_dhidden
:
(
k
+
1
)
*
block_dhidden
],
routed_expert_up_shared
,
routed_expert_up_shared
,
coalesced_width
=
coalesced_width
)
coalesced_width
=
coalesced_width
,
T
.
gemm
(
)
input_shared
,
T
.
gemm
(
input_shared
,
routed_expert_up_shared
,
up_logits_local
,
k_pack
=
k_pack
,
transpose_B
=
True
)
routed_expert_up_shared
,
up_logits_local
,
k_pack
=
k_pack
,
transpose_B
=
True
)
for
i
,
j
in
T
.
Parallel
(
block_token
,
block_dexpert
):
for
i
,
j
in
T
.
Parallel
(
block_token
,
block_dexpert
):
gate_logits_local
[
i
,
j
]
=
gate_logits_local
[
i
,
j
]
*
(
gate_logits_local
[
i
,
j
]
=
gate_logits_local
[
i
,
j
]
*
(
1.0
/
(
1.0
+
T
.
exp2
(
-
gate_logits_local
[
i
,
j
]
*
scale
)))
1.0
/
(
1.0
+
T
.
exp2
(
-
gate_logits_local
[
i
,
j
]
*
scale
)))
up_logits_local
[
i
,
j
]
=
up_logits_local
[
i
,
j
]
*
gate_logits_local
[
i
,
j
]
up_logits_local
[
i
,
j
]
=
up_logits_local
[
i
,
j
]
*
gate_logits_local
[
i
,
j
]
for
i
,
j
in
T
.
Parallel
(
block_token
,
block_dexpert
):
for
i
,
j
in
T
.
Parallel
(
block_token
,
block_dexpert
):
...
@@ -222,8 +208,8 @@ def moe_forward_tilelang_routed(d_hidden,
...
@@ -222,8 +208,8 @@ def moe_forward_tilelang_routed(d_hidden,
routed_expert_down_shared
=
T
.
alloc_shared
((
block_dhidden
,
block_dexpert
),
dtype
=
dtype
)
routed_expert_down_shared
=
T
.
alloc_shared
((
block_dhidden
,
block_dexpert
),
dtype
=
dtype
)
output_local
=
T
.
alloc_fragment
((
block_token
,
block_dhidden
),
dtype
=
accum_dtype
)
output_local
=
T
.
alloc_fragment
((
block_token
,
block_dhidden
),
dtype
=
accum_dtype
)
cur_group_idx
=
T
.
alloc_local
([
1
],
"
int32
"
)
cur_group_idx
=
T
.
alloc_local
([
1
],
T
.
int32
)
cur_group_size
=
T
.
alloc_local
([
1
],
"
int32
"
)
cur_group_size
=
T
.
alloc_local
([
1
],
T
.
int32
)
T
.
use_swizzle
(
10
,
enable
=
True
)
T
.
use_swizzle
(
10
,
enable
=
True
)
...
@@ -232,50 +218,35 @@ def moe_forward_tilelang_routed(d_hidden,
...
@@ -232,50 +218,35 @@ def moe_forward_tilelang_routed(d_hidden,
cur_group_idx
[
0
]
=
group_idx_for_bx
[
bx
]
cur_group_idx
[
0
]
=
group_idx_for_bx
[
bx
]
cur_group_size
[
0
]
=
group_sizes
[
cur_group_idx
[
0
]]
cur_group_size
[
0
]
=
group_sizes
[
cur_group_idx
[
0
]]
m_start
=
m_start_padded
-
group_padded_offsets
[
cur_group_idx
[
0
]]
+
group_offsets
[
m_start
=
m_start_padded
-
group_padded_offsets
[
cur_group_idx
[
0
]]
+
group_offsets
[
cur_group_idx
[
0
]]
cur_group_idx
[
0
]]
actual_rows
=
T
.
max
(
0
,
T
.
min
(
block_token
,
cur_group_size
[
0
]
-
(
m_start_padded
-
group_padded_offsets
[
cur_group_idx
[
0
]])))
actual_rows
=
T
.
max
(
0
,
T
.
min
(
block_token
,
cur_group_size
[
0
]
-
(
m_start_padded
-
group_padded_offsets
[
cur_group_idx
[
0
]])))
T
.
clear
(
output_local
)
T
.
clear
(
output_local
)
for
k
in
T
.
Pipelined
(
T
.
ceildiv
(
dexpert
,
block_dexpert
),
num_stages
=
num_stages
):
for
k
in
T
.
Pipelined
(
T
.
ceildiv
(
dexpert
,
block_dexpert
),
num_stages
=
num_stages
):
T
.
copy
(
T
.
copy
(
up_logits
[
m_start
:
m_start
+
block_token
,
up_logits
[
m_start
:
m_start
+
block_token
,
k
*
block_dexpert
:
(
k
+
1
)
*
block_dexpert
],
k
*
block_dexpert
:(
k
+
1
)
*
block_dexpert
],
up_logits_shared
,
up_logits_shared
,
coalesced_width
=
coalesced_width
)
coalesced_width
=
coalesced_width
,
)
T
.
copy
(
T
.
copy
(
routed_expert_down
[
cur_group_idx
[
0
],
routed_expert_down
[
by
*
block_dhidden
:(
by
+
1
)
*
block_dhidden
,
cur_group_idx
[
0
],
by
*
block_dhidden
:
(
by
+
1
)
*
block_dhidden
,
k
*
block_dexpert
:
(
k
+
1
)
*
block_dexpert
k
*
block_dexpert
:(
k
+
1
)
*
block_dexpert
],
],
routed_expert_down_shared
,
coalesced_width
=
coalesced_width
)
T
.
gemm
(
up_logits_shared
,
routed_expert_down_shared
,
routed_expert_down_shared
,
output_local
,
coalesced_width
=
coalesced_width
,
k_pack
=
k_pack
,
)
transpose_B
=
True
)
T
.
gemm
(
up_logits_shared
,
routed_expert_down_shared
,
output_local
,
k_pack
=
k_pack
,
transpose_B
=
True
)
for
i
,
j
in
T
.
Parallel
(
block_token
,
block_dhidden
):
for
i
,
j
in
T
.
Parallel
(
block_token
,
block_dhidden
):
if
i
<
actual_rows
:
if
i
<
actual_rows
:
output
[
m_start
+
i
,
by
*
block_dhidden
+
output
[
m_start
+
i
,
by
*
block_dhidden
+
j
]
=
output_local
[
i
,
j
]
*
routed_expert_weights
[
m_start
+
i
]
j
]
=
output_local
[
i
,
j
]
*
routed_expert_weights
[
m_start
+
i
]
return
kernel
return
kernel
class
Expert
(
nn
.
Module
):
class
Expert
(
nn
.
Module
):
def
__init__
(
self
,
config
:
Dict
,
gate
:
torch
.
Tensor
,
up
:
torch
.
Tensor
,
down
:
torch
.
Tensor
,
d_expert
:
Optional
[
int
]
=
None
):
def
__init__
(
self
,
config
:
Dict
,
gate
:
torch
.
Tensor
,
up
:
torch
.
Tensor
,
down
:
torch
.
Tensor
,
d_expert
:
Optional
[
int
]
=
None
):
super
().
__init__
()
super
().
__init__
()
self
.
config
=
config
self
.
config
=
config
self
.
act_fn
=
nn
.
SiLU
()
self
.
act_fn
=
nn
.
SiLU
()
...
@@ -294,14 +265,13 @@ class Expert(nn.Module):
...
@@ -294,14 +265,13 @@ class Expert(nn.Module):
class
MoEGate
(
nn
.
Module
):
class
MoEGate
(
nn
.
Module
):
def
__init__
(
self
,
config
:
Dict
,
weights
:
Dict
):
def
__init__
(
self
,
config
:
Dict
,
weights
:
Dict
):
super
().
__init__
()
super
().
__init__
()
self
.
top_k
:
int
=
config
[
"n_experts_per_token"
]
self
.
top_k
:
int
=
config
[
"n_experts_per_token"
]
self
.
num_experts
:
int
=
config
[
"n_routed_experts"
]
self
.
num_experts
:
int
=
config
[
"n_routed_experts"
]
self
.
d_hidden
:
int
=
config
[
"d_hidden"
]
self
.
d_hidden
:
int
=
config
[
"d_hidden"
]
self
.
W_g_weight
=
weights
[
'
router.weight
'
].
t
()
self
.
W_g_weight
=
weights
[
"
router.weight
"
].
t
()
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
logits
=
x
@
self
.
W_g_weight
logits
=
x
@
self
.
W_g_weight
...
@@ -312,76 +282,69 @@ class MoEGate(nn.Module):
...
@@ -312,76 +282,69 @@ class MoEGate(nn.Module):
class
MoE
(
nn
.
Module
):
class
MoE
(
nn
.
Module
):
def
__init__
(
def
__init__
(
self
,
self
,
config
:
Dict
,
shared_kernel
:
tilelang
.
JITKernel
,
routed_kernel
:
tilelang
.
JITKernel
,
weights
:
Dict
,
padding_M
:
int
=
128
config
:
Dict
,
):
shared_kernel
:
tilelang
.
JITKernel
,
routed_kernel
:
tilelang
.
JITKernel
,
weights
:
Dict
,
padding_M
:
int
=
128
):
super
().
__init__
()
super
().
__init__
()
self
.
config
=
config
self
.
config
=
config
self
.
shared_kernel
=
shared_kernel
self
.
shared_kernel
=
shared_kernel
self
.
routed_kernel
=
routed_kernel
self
.
routed_kernel
=
routed_kernel
self
.
padding_M
=
padding_M
self
.
padding_M
=
padding_M
self
.
experts
=
nn
.
ModuleList
([
self
.
experts
=
nn
.
ModuleList
(
Expert
(
[
config
,
Expert
(
gate
=
weights
[
f
'experts.
{
i
}
.0.weight'
],
config
,
up
=
weights
[
f
'experts.
{
i
}
.1.weight'
],
gate
=
weights
[
f
"experts.
{
i
}
.0.weight"
],
down
=
weights
[
f
'experts.
{
i
}
.2.weight'
])
for
i
in
range
(
config
[
"n_routed_experts"
])
up
=
weights
[
f
"experts.
{
i
}
.1.weight"
],
])
down
=
weights
[
f
"experts.
{
i
}
.2.weight"
],
)
for
i
in
range
(
config
[
"n_routed_experts"
])
]
)
self
.
device
=
torch
.
device
(
"cuda"
)
self
.
device
=
torch
.
device
(
"cuda"
)
self
.
gating_network
=
MoEGate
(
config
,
weights
).
to
(
self
.
device
)
self
.
gating_network
=
MoEGate
(
config
,
weights
).
to
(
self
.
device
)
shared_expert_dim
=
config
[
"d_expert"
]
*
config
[
"n_shared_experts"
]
shared_expert_dim
=
config
[
"d_expert"
]
*
config
[
"n_shared_experts"
]
self
.
shared_expert
=
Expert
(
self
.
shared_expert
=
Expert
(
config
=
config
,
config
=
config
,
gate
=
weights
[
'shared_experts.0.weight'
],
gate
=
weights
[
"shared_experts.0.weight"
],
up
=
weights
[
'shared_experts.1.weight'
],
up
=
weights
[
"shared_experts.1.weight"
],
down
=
weights
[
'shared_experts.2.weight'
],
down
=
weights
[
"shared_experts.2.weight"
],
d_expert
=
shared_expert_dim
).
to
(
self
.
device
)
d_expert
=
shared_expert_dim
,
).
to
(
self
.
device
)
self
.
expert_cache
=
torch
.
zeros
(
self
.
expert_cache
=
torch
.
zeros
(
(
config
[
"batch_size"
]
*
config
[
"seq_len"
],
config
[
"d_hidden"
]),
(
config
[
"batch_size"
]
*
config
[
"seq_len"
],
config
[
"d_hidden"
]),
dtype
=
torch
.
float16
,
device
=
self
.
device
dtype
=
torch
.
float16
,
)
device
=
self
.
device
)
self
.
stacked_expert_w_gate
=
torch
.
stack
([
expert
.
W_gate_weight
for
expert
in
self
.
experts
],
dim
=
0
)
self
.
stacked_expert_w_gate
=
torch
.
stack
([
expert
.
W_gate_weight
for
expert
in
self
.
experts
],
self
.
stacked_expert_w_up
=
torch
.
stack
([
expert
.
W_up_weight
for
expert
in
self
.
experts
],
dim
=
0
)
dim
=
0
)
self
.
stacked_expert_w_down
=
torch
.
stack
([
expert
.
W_down_weight
for
expert
in
self
.
experts
],
dim
=
0
)
self
.
stacked_expert_w_up
=
torch
.
stack
([
expert
.
W_up_weight
for
expert
in
self
.
experts
],
dim
=
0
)
self
.
stacked_expert_w_down
=
torch
.
stack
([
expert
.
W_down_weight
for
expert
in
self
.
experts
],
dim
=
0
)
self
.
stacked_expert_tokens
=
torch
.
empty
(
self
.
stacked_expert_tokens
=
torch
.
empty
(
(
config
[
"batch_size"
]
*
config
[
"seq_len"
]
*
config
[
"n_experts_per_token"
],
(
config
[
"batch_size"
]
*
config
[
"seq_len"
]
*
config
[
"n_experts_per_token"
],
self
.
config
[
"d_hidden"
]),
self
.
config
[
"d_hidden"
]),
dtype
=
torch
.
float16
,
dtype
=
torch
.
float16
,
device
=
self
.
device
)
device
=
self
.
device
,
)
self
.
stacked_expert_weights
=
torch
.
empty
(
self
.
stacked_expert_weights
=
torch
.
empty
(
(
config
[
"batch_size"
]
*
config
[
"seq_len"
]
*
config
[
"n_experts_per_token"
]),
(
config
[
"batch_size"
]
*
config
[
"seq_len"
]
*
config
[
"n_experts_per_token"
]),
dtype
=
torch
.
float16
,
device
=
self
.
device
dtype
=
torch
.
float16
,
)
device
=
self
.
device
)
self
.
stacked_expert_tokens_idxs
=
torch
.
empty
(
self
.
stacked_expert_tokens_idxs
=
torch
.
empty
(
(
config
[
"batch_size"
]
*
config
[
"seq_len"
]
*
config
[
"n_experts_per_token"
]),
(
config
[
"batch_size"
]
*
config
[
"seq_len"
]
*
config
[
"n_experts_per_token"
]),
dtype
=
torch
.
int64
,
device
=
self
.
device
dtype
=
torch
.
int64
,
)
device
=
self
.
device
)
self
.
up_logits_shared
=
torch
.
empty
(
self
.
up_logits_shared
=
torch
.
empty
(
(
config
[
"batch_size"
]
*
config
[
"seq_len"
],
self
.
config
[
"d_expert"
]),
(
config
[
"batch_size"
]
*
config
[
"seq_len"
],
self
.
config
[
"d_expert"
]),
dtype
=
torch
.
float16
,
device
=
self
.
device
dtype
=
torch
.
float16
,
)
device
=
self
.
device
)
self
.
expert_output_shared
=
torch
.
empty
(
self
.
expert_output_shared
=
torch
.
empty
(
(
config
[
"batch_size"
]
*
config
[
"seq_len"
],
self
.
config
[
"d_hidden"
]),
(
config
[
"batch_size"
]
*
config
[
"seq_len"
],
self
.
config
[
"d_hidden"
]),
dtype
=
torch
.
float16
,
device
=
self
.
device
dtype
=
torch
.
float16
,
)
device
=
self
.
device
)
self
.
up_logits_routed
=
torch
.
empty
(
self
.
up_logits_routed
=
torch
.
empty
(
(
config
[
"batch_size"
]
*
config
[
"seq_len"
]
*
config
[
"n_experts_per_token"
],
(
config
[
"batch_size"
]
*
config
[
"seq_len"
]
*
config
[
"n_experts_per_token"
],
self
.
config
[
"d_expert"
]),
self
.
config
[
"d_expert"
]),
dtype
=
torch
.
float16
,
dtype
=
torch
.
float16
,
device
=
self
.
device
)
device
=
self
.
device
,
)
self
.
expert_output_routed
=
torch
.
empty
(
self
.
expert_output_routed
=
torch
.
empty
(
(
config
[
"batch_size"
]
*
config
[
"seq_len"
]
*
config
[
"n_experts_per_token"
],
(
config
[
"batch_size"
]
*
config
[
"seq_len"
]
*
config
[
"n_experts_per_token"
],
self
.
config
[
"d_hidden"
]),
self
.
config
[
"d_hidden"
]),
dtype
=
torch
.
float16
,
dtype
=
torch
.
float16
,
device
=
self
.
device
)
device
=
self
.
device
,
)
@
torch
.
no_grad
()
@
torch
.
no_grad
()
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
...
@@ -413,22 +376,20 @@ class MoE(nn.Module):
...
@@ -413,22 +376,20 @@ class MoE(nn.Module):
self
.
stacked_expert_tokens
[
start_idx
:
end_idx
]
=
expert_tokens
self
.
stacked_expert_tokens
[
start_idx
:
end_idx
]
=
expert_tokens
self
.
stacked_expert_tokens_idxs
[
start_idx
:
end_idx
]
=
exp_token_idxs
self
.
stacked_expert_tokens_idxs
[
start_idx
:
end_idx
]
=
exp_token_idxs
self
.
stacked_expert_weights
[
start_idx
:
end_idx
]
=
flat_expert_weights
[
self
.
stacked_expert_weights
[
start_idx
:
end_idx
]
=
flat_expert_weights
[
idxs
[
start_idx
:
end_idx
]]
idxs
[
start_idx
:
end_idx
]]
group_sizes
=
torch
.
tensor
(
counts
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
group_sizes
=
torch
.
tensor
(
counts
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
group_offset
=
torch
.
tensor
(
group_offset
=
torch
.
tensor
(
tokens_per_expert
-
counts
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
tokens_per_expert
-
counts
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
group_padded_offsets
=
[
0
for
_
in
range
(
len
(
group_sizes
))]
group_padded_offsets
=
[
0
for
_
in
range
(
len
(
group_sizes
))]
for
i
in
range
(
1
,
len
(
group_sizes
)):
for
i
in
range
(
1
,
len
(
group_sizes
)):
group_padded_offsets
[
i
]
=
group_padded_offsets
[
i
-
1
]
+
math
.
ceil
(
group_padded_offsets
[
i
]
=
group_padded_offsets
[
i
-
1
]
+
math
.
ceil
((
counts
[
i
-
1
]
+
1
)
/
self
.
padding_M
)
*
self
.
padding_M
(
counts
[
i
-
1
]
+
1
)
/
self
.
padding_M
)
*
self
.
padding_M
block_token
=
128
block_token
=
128
M
=
math
.
ceil
(
M
=
(
self
.
config
[
"batch_size"
]
*
self
.
config
[
"seq_len"
]
*
math
.
ceil
(
self
.
config
[
"batch_size"
]
*
self
.
config
[
"seq_len"
]
*
self
.
config
[
"n_experts_per_token"
]
/
block_token
)
self
.
config
[
"n_experts_per_token"
]
/
block_token
)
+
self
.
config
[
"n_routed_experts"
]
+
self
.
config
[
"n_routed_experts"
]
)
group_idx_for_bx
=
[
0
for
_
in
range
(
M
)]
group_idx_for_bx
=
[
0
for
_
in
range
(
M
)]
for
bx
in
range
(
M
):
for
bx
in
range
(
M
):
...
@@ -437,8 +398,7 @@ class MoE(nn.Module):
...
@@ -437,8 +398,7 @@ class MoE(nn.Module):
if
m_start_padded
>=
group_padded_offsets
[
i
]:
if
m_start_padded
>=
group_padded_offsets
[
i
]:
group_idx_for_bx
[
bx
]
=
i
group_idx_for_bx
[
bx
]
=
i
group_padded_offsets
=
torch
.
tensor
(
group_padded_offsets
=
torch
.
tensor
(
group_padded_offsets
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
group_padded_offsets
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
group_idx_for_bx
=
torch
.
tensor
(
group_idx_for_bx
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
group_idx_for_bx
=
torch
.
tensor
(
group_idx_for_bx
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
# Multi-stream execution
# Multi-stream execution
...
@@ -448,11 +408,19 @@ class MoE(nn.Module):
...
@@ -448,11 +408,19 @@ class MoE(nn.Module):
with
torch
.
cuda
.
stream
(
routed_stream
):
with
torch
.
cuda
.
stream
(
routed_stream
):
# Tilelang version: Grouped GEMM
# Tilelang version: Grouped GEMM
self
.
routed_kernel
(
self
.
stacked_expert_tokens
,
self
.
stacked_expert_w_gate
,
self
.
routed_kernel
(
self
.
stacked_expert_w_up
,
self
.
stacked_expert_w_down
,
self
.
stacked_expert_tokens
,
self
.
stacked_expert_weights
,
group_sizes
,
group_offset
,
self
.
stacked_expert_w_gate
,
group_padded_offsets
,
group_idx_for_bx
,
self
.
up_logits_routed
,
self
.
stacked_expert_w_up
,
self
.
expert_output_routed
)
self
.
stacked_expert_w_down
,
self
.
stacked_expert_weights
,
group_sizes
,
group_offset
,
group_padded_offsets
,
group_idx_for_bx
,
self
.
up_logits_routed
,
self
.
expert_output_routed
,
)
# Scatter reduce
# Scatter reduce
self
.
expert_cache
=
torch
.
scatter_reduce
(
self
.
expert_cache
=
torch
.
scatter_reduce
(
...
@@ -460,14 +428,19 @@ class MoE(nn.Module):
...
@@ -460,14 +428,19 @@ class MoE(nn.Module):
0
,
0
,
self
.
stacked_expert_tokens_idxs
.
view
(
-
1
,
1
).
repeat
(
1
,
x_flat
.
shape
[
-
1
]),
self
.
stacked_expert_tokens_idxs
.
view
(
-
1
,
1
).
repeat
(
1
,
x_flat
.
shape
[
-
1
]),
self
.
expert_output_routed
,
self
.
expert_output_routed
,
reduce
=
'sum'
)
reduce
=
"sum"
,
)
routed_output
=
self
.
expert_cache
.
view
(
*
orig_shape
)
routed_output
=
self
.
expert_cache
.
view
(
*
orig_shape
)
with
torch
.
cuda
.
stream
(
shared_stream
):
with
torch
.
cuda
.
stream
(
shared_stream
):
self
.
shared_kernel
(
self
.
shared_kernel
(
x_flat
,
self
.
shared_expert
.
W_gate_weight
,
x_flat
,
self
.
shared_expert
.
W_up_weight
,
self
.
shared_expert
.
W_down_weight
,
self
.
shared_expert
.
W_gate_weight
,
self
.
up_logits_shared
,
self
.
expert_output_shared
)
self
.
shared_expert
.
W_up_weight
,
self
.
shared_expert
.
W_down_weight
,
self
.
up_logits_shared
,
self
.
expert_output_shared
,
)
shared_output
=
self
.
expert_output_shared
.
view
(
*
orig_shape
)
shared_output
=
self
.
expert_output_shared
.
view
(
*
orig_shape
)
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
...
@@ -491,14 +464,15 @@ def custom_kernel(data: Tuple[torch.Tensor, Dict, Dict]) -> torch.Tensor:
...
@@ -491,14 +464,15 @@ def custom_kernel(data: Tuple[torch.Tensor, Dict, Dict]) -> torch.Tensor:
"""
"""
input_tensor
,
weights
,
config
=
data
input_tensor
,
weights
,
config
=
data
dtype_str
=
"
float16
"
dtype_str
=
T
.
float16
shared_kernel
=
moe_forward_tilelang_shared
(
shared_kernel
=
moe_forward_tilelang_shared
(
config
[
"d_hidden"
],
config
[
"d_hidden"
],
config
[
"d_expert"
],
config
[
"d_expert"
],
config
[
"n_shared_experts"
],
config
[
"n_shared_experts"
],
dtype
=
dtype_str
,
dtype
=
dtype_str
,
num_tokens
=
config
[
"batch_size"
]
*
config
[
"seq_len"
])
num_tokens
=
config
[
"batch_size"
]
*
config
[
"seq_len"
],
)
routed_kernel
=
moe_forward_tilelang_routed
(
routed_kernel
=
moe_forward_tilelang_routed
(
config
[
"d_hidden"
],
config
[
"d_hidden"
],
config
[
"d_expert"
],
config
[
"d_expert"
],
...
@@ -512,7 +486,8 @@ def custom_kernel(data: Tuple[torch.Tensor, Dict, Dict]) -> torch.Tensor:
...
@@ -512,7 +486,8 @@ def custom_kernel(data: Tuple[torch.Tensor, Dict, Dict]) -> torch.Tensor:
threads
=
256
,
threads
=
256
,
num_stages
=
1
,
num_stages
=
1
,
k_pack
=
1
,
k_pack
=
1
,
coalesced_width
=
2
)
coalesced_width
=
2
,
)
moe
=
MoE
(
config
,
shared_kernel
,
routed_kernel
,
weights
,
padding_M
=
128
)
moe
=
MoE
(
config
,
shared_kernel
,
routed_kernel
,
weights
,
padding_M
=
128
)
...
@@ -521,13 +496,7 @@ def custom_kernel(data: Tuple[torch.Tensor, Dict, Dict]) -> torch.Tensor:
...
@@ -521,13 +496,7 @@ def custom_kernel(data: Tuple[torch.Tensor, Dict, Dict]) -> torch.Tensor:
return
output
return
output
def
main
(
d_hidden
=
7168
,
def
main
(
d_hidden
=
7168
,
d_expert
=
2048
,
n_routed_experts
=
8
,
n_shared_experts
=
1
,
n_experts_per_token
=
4
,
batch_size
=
1
,
seq_len
=
8192
):
d_expert
=
2048
,
n_routed_experts
=
8
,
n_shared_experts
=
1
,
n_experts_per_token
=
4
,
batch_size
=
1
,
seq_len
=
8192
):
config
=
{
config
=
{
"dhidden"
:
d_hidden
,
"dhidden"
:
d_hidden
,
"dexpert"
:
d_expert
,
"dexpert"
:
d_expert
,
...
@@ -536,7 +505,7 @@ def main(d_hidden=7168,
...
@@ -536,7 +505,7 @@ def main(d_hidden=7168,
"nexpertspertoken"
:
n_experts_per_token
,
"nexpertspertoken"
:
n_experts_per_token
,
"bs"
:
batch_size
,
"bs"
:
batch_size
,
"seqlen"
:
seq_len
,
"seqlen"
:
seq_len
,
"seed"
:
81394
"seed"
:
81394
,
}
}
data
=
generate_input
(
**
config
)
data
=
generate_input
(
**
config
)
...
...
examples/fusedmoe/example_fusedmoe_torch.py
View file @
667632cc
...
@@ -6,7 +6,6 @@ from typing import Dict, Tuple, Optional
...
@@ -6,7 +6,6 @@ from typing import Dict, Tuple, Optional
# Reference code in PyTorch
# Reference code in PyTorch
class
ExpertTorch
(
nn
.
Module
):
class
ExpertTorch
(
nn
.
Module
):
def
__init__
(
self
,
config
:
Dict
,
d_expert
:
Optional
[
int
]
=
None
):
def
__init__
(
self
,
config
:
Dict
,
d_expert
:
Optional
[
int
]
=
None
):
super
().
__init__
()
super
().
__init__
()
self
.
config
=
config
self
.
config
=
config
...
@@ -25,7 +24,6 @@ class ExpertTorch(nn.Module):
...
@@ -25,7 +24,6 @@ class ExpertTorch(nn.Module):
class
MoEGateTorch
(
nn
.
Module
):
class
MoEGateTorch
(
nn
.
Module
):
def
__init__
(
self
,
config
:
Dict
):
def
__init__
(
self
,
config
:
Dict
):
super
().
__init__
()
super
().
__init__
()
self
.
top_k
:
int
=
config
[
"n_experts_per_token"
]
self
.
top_k
:
int
=
config
[
"n_experts_per_token"
]
...
@@ -43,12 +41,10 @@ class MoEGateTorch(nn.Module):
...
@@ -43,12 +41,10 @@ class MoEGateTorch(nn.Module):
class
MoETorch
(
nn
.
Module
):
class
MoETorch
(
nn
.
Module
):
def
__init__
(
self
,
config
:
Dict
):
def
__init__
(
self
,
config
:
Dict
):
super
().
__init__
()
super
().
__init__
()
self
.
config
=
config
self
.
config
=
config
self
.
experts
=
nn
.
ModuleList
(
self
.
experts
=
nn
.
ModuleList
([
ExpertTorch
(
config
)
for
_
in
range
(
config
[
"n_routed_experts"
])])
[
ExpertTorch
(
config
)
for
_
in
range
(
config
[
"n_routed_experts"
])])
self
.
gating_network
=
MoEGateTorch
(
config
)
self
.
gating_network
=
MoEGateTorch
(
config
)
shared_expert_dim
=
config
[
"d_expert"
]
*
config
[
"n_shared_experts"
]
shared_expert_dim
=
config
[
"d_expert"
]
*
config
[
"n_shared_experts"
]
self
.
shared_expert
=
ExpertTorch
(
config
=
config
,
d_expert
=
shared_expert_dim
)
self
.
shared_expert
=
ExpertTorch
(
config
=
config
,
d_expert
=
shared_expert_dim
)
...
@@ -67,8 +63,7 @@ class MoETorch(nn.Module):
...
@@ -67,8 +63,7 @@ class MoETorch(nn.Module):
return
routed_output
+
shared_output
return
routed_output
+
shared_output
@
torch
.
no_grad
()
@
torch
.
no_grad
()
def
moe_infer
(
self
,
x
:
torch
.
Tensor
,
flat_expert_indices
:
torch
.
Tensor
,
def
moe_infer
(
self
,
x
:
torch
.
Tensor
,
flat_expert_indices
:
torch
.
Tensor
,
flat_expert_weights
:
torch
.
Tensor
)
->
torch
.
Tensor
:
flat_expert_weights
:
torch
.
Tensor
)
->
torch
.
Tensor
:
expert_cache
=
torch
.
zeros_like
(
x
)
expert_cache
=
torch
.
zeros_like
(
x
)
# test_expert_cache = torch.zeros((x.shape[0] * self.config["n_experts_per_token"], self.config["d_hidden"]))
# test_expert_cache = torch.zeros((x.shape[0] * self.config["n_experts_per_token"], self.config["d_hidden"]))
# test_expert_tokens = torch.zeros((x.shape[0] * self.config["n_experts_per_token"], self.config["d_hidden"]))
# test_expert_tokens = torch.zeros((x.shape[0] * self.config["n_experts_per_token"], self.config["d_hidden"]))
...
@@ -91,8 +86,7 @@ class MoETorch(nn.Module):
...
@@ -91,8 +86,7 @@ class MoETorch(nn.Module):
expert_out
=
expert
(
expert_tokens
)
expert_out
=
expert
(
expert_tokens
)
expert_out
.
mul_
(
flat_expert_weights
[
idxs
[
start_idx
:
end_idx
]])
expert_out
.
mul_
(
flat_expert_weights
[
idxs
[
start_idx
:
end_idx
]])
expert_cache
.
scatter_reduce_
(
expert_cache
.
scatter_reduce_
(
0
,
exp_token_idxs
.
view
(
-
1
,
1
).
repeat
(
1
,
x
.
shape
[
-
1
]),
expert_out
,
reduce
=
"sum"
)
0
,
exp_token_idxs
.
view
(
-
1
,
1
).
repeat
(
1
,
x
.
shape
[
-
1
]),
expert_out
,
reduce
=
'sum'
)
return
expert_cache
return
expert_cache
...
@@ -116,21 +110,21 @@ def ref_kernel(data: Tuple[torch.Tensor, Dict, Dict]) -> torch.Tensor:
...
@@ -116,21 +110,21 @@ def ref_kernel(data: Tuple[torch.Tensor, Dict, Dict]) -> torch.Tensor:
moe
=
MoETorch
(
config
)
moe
=
MoETorch
(
config
)
# Fill in the given weights of the model
# Fill in the given weights of the model
moe
.
gating_network
.
W_g
.
weight
=
nn
.
Parameter
(
weights
[
'
router.weight
'
])
moe
.
gating_network
.
W_g
.
weight
=
nn
.
Parameter
(
weights
[
"
router.weight
"
])
for
i
in
range
(
num_experts
):
for
i
in
range
(
num_experts
):
gate_proj_weight
=
weights
[
f
'
experts.
{
i
}
.0.weight
'
]
gate_proj_weight
=
weights
[
f
"
experts.
{
i
}
.0.weight
"
]
up_proj_weight
=
weights
[
f
'
experts.
{
i
}
.1.weight
'
]
up_proj_weight
=
weights
[
f
"
experts.
{
i
}
.1.weight
"
]
down_proj_weight
=
weights
[
f
'
experts.
{
i
}
.2.weight
'
]
down_proj_weight
=
weights
[
f
"
experts.
{
i
}
.2.weight
"
]
# Transpose weights to match expected shape for nn.Linear
# Transpose weights to match expected shape for nn.Linear
moe
.
experts
[
i
].
W_gate
.
weight
=
nn
.
Parameter
(
gate_proj_weight
.
t
())
moe
.
experts
[
i
].
W_gate
.
weight
=
nn
.
Parameter
(
gate_proj_weight
.
t
())
moe
.
experts
[
i
].
W_up
.
weight
=
nn
.
Parameter
(
up_proj_weight
.
t
())
moe
.
experts
[
i
].
W_up
.
weight
=
nn
.
Parameter
(
up_proj_weight
.
t
())
moe
.
experts
[
i
].
W_down
.
weight
=
nn
.
Parameter
(
down_proj_weight
.
t
())
moe
.
experts
[
i
].
W_down
.
weight
=
nn
.
Parameter
(
down_proj_weight
.
t
())
moe
.
shared_expert
.
W_gate
.
weight
=
nn
.
Parameter
(
weights
[
'
shared_experts.0.weight
'
].
t
())
moe
.
shared_expert
.
W_gate
.
weight
=
nn
.
Parameter
(
weights
[
"
shared_experts.0.weight
"
].
t
())
moe
.
shared_expert
.
W_up
.
weight
=
nn
.
Parameter
(
weights
[
'
shared_experts.1.weight
'
].
t
())
moe
.
shared_expert
.
W_up
.
weight
=
nn
.
Parameter
(
weights
[
"
shared_experts.1.weight
"
].
t
())
moe
.
shared_expert
.
W_down
.
weight
=
nn
.
Parameter
(
weights
[
'
shared_experts.2.weight
'
].
t
())
moe
.
shared_expert
.
W_down
.
weight
=
nn
.
Parameter
(
weights
[
"
shared_experts.2.weight
"
].
t
())
output
=
moe
(
input_tensor
)
output
=
moe
(
input_tensor
)
...
@@ -140,10 +134,9 @@ def ref_kernel(data: Tuple[torch.Tensor, Dict, Dict]) -> torch.Tensor:
...
@@ -140,10 +134,9 @@ def ref_kernel(data: Tuple[torch.Tensor, Dict, Dict]) -> torch.Tensor:
# Input generation for the reference code
# Input generation for the reference code
def
generate_input
(
dhidden
:
int
,
dexpert
:
int
,
nroutedexperts
:
int
,
nsharedexperts
:
int
,
def
generate_input
(
nexpertspertoken
:
int
,
bs
:
int
,
seqlen
:
int
,
dhidden
:
int
,
dexpert
:
int
,
nroutedexperts
:
int
,
nsharedexperts
:
int
,
nexpertspertoken
:
int
,
bs
:
int
,
seqlen
:
int
,
seed
:
int
seed
:
int
)
->
Tuple
[
torch
.
Tensor
,
Dict
,
Dict
]:
)
->
Tuple
[
torch
.
Tensor
,
Dict
,
Dict
]:
# Really dumb but for now _ isn't parsing correctly.
# Really dumb but for now _ isn't parsing correctly.
d_hidden
=
dhidden
d_hidden
=
dhidden
d_expert
=
dexpert
d_expert
=
dexpert
...
@@ -163,50 +156,40 @@ def generate_input(dhidden: int, dexpert: int, nroutedexperts: int, nsharedexper
...
@@ -163,50 +156,40 @@ def generate_input(dhidden: int, dexpert: int, nroutedexperts: int, nsharedexper
"seq_len"
:
seq_len
,
"seq_len"
:
seq_len
,
}
}
gen
=
torch
.
Generator
(
device
=
'
cuda
'
)
gen
=
torch
.
Generator
(
device
=
"
cuda
"
)
gen
.
manual_seed
(
seed
)
gen
.
manual_seed
(
seed
)
num_experts
=
n_routed_experts
num_experts
=
n_routed_experts
expert_dim
=
d_expert
expert_dim
=
d_expert
weights
=
{}
weights
=
{}
input_tensor
=
torch
.
randn
((
batch_size
,
seq_len
,
d_hidden
),
input_tensor
=
torch
.
randn
((
batch_size
,
seq_len
,
d_hidden
),
device
=
"cuda"
,
dtype
=
torch
.
float16
,
generator
=
gen
).
contiguous
()
device
=
'cuda'
,
dtype
=
torch
.
float16
,
generator
=
gen
).
contiguous
()
# Initialize router weights
# Initialize router weights
weights
[
'router.weight'
]
=
torch
.
randn
(
weights
[
"router.weight"
]
=
torch
.
randn
((
num_experts
,
d_hidden
),
device
=
"cuda"
,
dtype
=
torch
.
float16
,
generator
=
gen
)
/
math
.
sqrt
(
d_hidden
)
(
num_experts
,
d_hidden
),
device
=
"cuda"
,
dtype
=
torch
.
float16
,
generator
=
gen
)
/
math
.
sqrt
(
d_hidden
)
for
i
in
range
(
num_experts
):
for
i
in
range
(
num_experts
):
weights
[
f
'experts.
{
i
}
.0.weight'
]
=
torch
.
randn
(
weights
[
f
"experts.
{
i
}
.0.weight"
]
=
torch
.
randn
(
(
d_hidden
,
expert_dim
),
device
=
'cuda'
,
dtype
=
torch
.
float16
,
(
d_hidden
,
expert_dim
),
device
=
"cuda"
,
dtype
=
torch
.
float16
,
generator
=
gen
generator
=
gen
)
/
math
.
sqrt
(
expert_dim
)
)
/
math
.
sqrt
(
expert_dim
)
weights
[
f
'experts.
{
i
}
.1.weight'
]
=
torch
.
randn
(
weights
[
f
"experts.
{
i
}
.1.weight"
]
=
torch
.
randn
(
(
d_hidden
,
expert_dim
),
device
=
'cuda'
,
dtype
=
torch
.
float16
,
(
d_hidden
,
expert_dim
),
device
=
"cuda"
,
dtype
=
torch
.
float16
,
generator
=
gen
generator
=
gen
)
/
math
.
sqrt
(
expert_dim
)
)
/
math
.
sqrt
(
expert_dim
)
weights
[
f
'experts.
{
i
}
.2.weight'
]
=
torch
.
randn
(
weights
[
f
"experts.
{
i
}
.2.weight"
]
=
torch
.
randn
(
(
expert_dim
,
d_hidden
),
device
=
'cuda'
,
dtype
=
torch
.
float16
,
(
expert_dim
,
d_hidden
),
device
=
"cuda"
,
dtype
=
torch
.
float16
,
generator
=
gen
generator
=
gen
)
/
math
.
sqrt
(
d_hidden
)
)
/
math
.
sqrt
(
d_hidden
)
weights
[
'shared_experts.0.weight'
]
=
torch
.
randn
(
weights
[
"shared_experts.0.weight"
]
=
torch
.
randn
(
(
d_hidden
,
expert_dim
*
n_shared_experts
),
(
d_hidden
,
expert_dim
*
n_shared_experts
),
device
=
"cuda"
,
dtype
=
torch
.
float16
,
generator
=
gen
device
=
'cuda'
,
)
/
math
.
sqrt
(
expert_dim
*
n_shared_experts
)
dtype
=
torch
.
float16
,
weights
[
"shared_experts.1.weight"
]
=
torch
.
randn
(
generator
=
gen
)
/
math
.
sqrt
(
expert_dim
*
n_shared_experts
)
(
d_hidden
,
expert_dim
*
n_shared_experts
),
device
=
"cuda"
,
dtype
=
torch
.
float16
,
generator
=
gen
weights
[
'shared_experts.1.weight'
]
=
torch
.
randn
(
)
/
math
.
sqrt
(
expert_dim
*
n_shared_experts
)
(
d_hidden
,
expert_dim
*
n_shared_experts
),
weights
[
"shared_experts.2.weight"
]
=
torch
.
randn
(
device
=
'cuda'
,
(
expert_dim
*
n_shared_experts
,
d_hidden
),
device
=
"cuda"
,
dtype
=
torch
.
float16
,
generator
=
gen
dtype
=
torch
.
float16
,
)
/
math
.
sqrt
(
d_hidden
)
generator
=
gen
)
/
math
.
sqrt
(
expert_dim
*
n_shared_experts
)
weights
[
'shared_experts.2.weight'
]
=
torch
.
randn
((
expert_dim
*
n_shared_experts
,
d_hidden
),
device
=
'cuda'
,
dtype
=
torch
.
float16
,
generator
=
gen
)
/
math
.
sqrt
(
d_hidden
)
return
(
input_tensor
,
weights
,
config
)
return
(
input_tensor
,
weights
,
config
)
...
...
examples/fusedmoe/test_example_fusedmoe.py
View file @
667632cc
...
@@ -4,13 +4,8 @@ import example_fusedmoe_tilelang
...
@@ -4,13 +4,8 @@ import example_fusedmoe_tilelang
def
test_example_fusedmoe_tilelang
():
def
test_example_fusedmoe_tilelang
():
example_fusedmoe_tilelang
.
main
(
example_fusedmoe_tilelang
.
main
(
d_hidden
=
1024
,
d_hidden
=
1024
,
d_expert
=
256
,
n_routed_experts
=
8
,
n_shared_experts
=
1
,
n_experts_per_token
=
4
,
batch_size
=
1
,
seq_len
=
1024
d_expert
=
256
,
)
n_routed_experts
=
8
,
n_shared_experts
=
1
,
n_experts_per_token
=
4
,
batch_size
=
1
,
seq_len
=
1024
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
examples/gdn/example_chunk_delta_bwd.py
View file @
667632cc
...
@@ -12,6 +12,7 @@ print(tilelang.__file__, flush=True)
...
@@ -12,6 +12,7 @@ print(tilelang.__file__, flush=True)
# sys.path.insert(0, "/home/tzj/flash-linear-attention")
# sys.path.insert(0, "/home/tzj/flash-linear-attention")
try
:
try
:
import
fla
import
fla
print
(
fla
.
__file__
,
flush
=
True
)
print
(
fla
.
__file__
,
flush
=
True
)
from
fla.ops.common.chunk_delta_h
import
chunk_gated_delta_rule_bwd_dhu
from
fla.ops.common.chunk_delta_h
import
chunk_gated_delta_rule_bwd_dhu
except
ImportError
:
except
ImportError
:
...
@@ -24,7 +25,7 @@ import torch.nn.functional as F
...
@@ -24,7 +25,7 @@ import torch.nn.functional as F
torch
.
random
.
manual_seed
(
0
)
torch
.
random
.
manual_seed
(
0
)
# torch.set_printoptions(profile="full")
# torch.set_printoptions(profile="full")
from
utils
import
*
from
test_
utils
import
assert_similar
def
prepare_input
(
def
prepare_input
(
...
@@ -49,6 +50,7 @@ def prepare_input(
...
@@ -49,6 +50,7 @@ def prepare_input(
G
=
F
.
logsigmoid
(
G
)
G
=
F
.
logsigmoid
(
G
)
try
:
try
:
from
fla.ops.utils.cumsum
import
chunk_local_cumsum
from
fla.ops.utils.cumsum
import
chunk_local_cumsum
G
=
chunk_local_cumsum
(
G
,
chunk_size
)
G
=
chunk_local_cumsum
(
G
,
chunk_size
)
except
ImportError
:
except
ImportError
:
print
(
"fla not found, skip cumsum"
)
print
(
"fla not found, skip cumsum"
)
...
@@ -125,8 +127,11 @@ def torch_chunk_gated_delta_rule_bwd_dhu(
...
@@ -125,8 +127,11 @@ def torch_chunk_gated_delta_rule_bwd_dhu(
DV
=
dv
.
shape
[
-
1
]
DV
=
dv
.
shape
[
-
1
]
block_S
=
64
block_S
=
64
BS
=
S
//
block_S
BS
=
S
//
block_S
dh
,
dh0
,
dv2
=
torch
.
empty
((
B
,
BS
,
H
,
DK
,
DV
),
dtype
=
output_dtype
),
torch
.
empty
(
dh
,
dh0
,
dv2
=
(
(
B
,
H
,
DK
,
DV
),
dtype
=
state_dtype
),
torch
.
empty
((
B
,
S
,
H
,
DV
),
dtype
=
output_dtype
)
torch
.
empty
((
B
,
BS
,
H
,
DK
,
DV
),
dtype
=
output_dtype
),
torch
.
empty
((
B
,
H
,
DK
,
DV
),
dtype
=
state_dtype
),
torch
.
empty
((
B
,
S
,
H
,
DV
),
dtype
=
output_dtype
),
)
dh_tmp
=
torch
.
empty
((
B
,
H
,
DK
,
DV
),
dtype
=
accum_dtype
)
dh_tmp
=
torch
.
empty
((
B
,
H
,
DK
,
DV
),
dtype
=
accum_dtype
)
dv_tmp
=
torch
.
empty
((
B
,
S
,
H
,
DV
),
dtype
=
accum_dtype
)
dv_tmp
=
torch
.
empty
((
B
,
S
,
H
,
DV
),
dtype
=
accum_dtype
)
Q_tmp
=
torch
.
empty
((
B
,
S
,
H
,
DK
),
dtype
=
accum_dtype
)
Q_tmp
=
torch
.
empty
((
B
,
S
,
H
,
DK
),
dtype
=
accum_dtype
)
...
@@ -138,34 +143,30 @@ def torch_chunk_gated_delta_rule_bwd_dhu(
...
@@ -138,34 +143,30 @@ def torch_chunk_gated_delta_rule_bwd_dhu(
for
i_s
in
range
(
BS
-
1
,
-
1
,
-
1
):
for
i_s
in
range
(
BS
-
1
,
-
1
,
-
1
):
dh
[:,
i_s
,
:,
:,
:]
=
dh_tmp
dh
[:,
i_s
,
:,
:,
:]
=
dh_tmp
dv_tmp
=
torch
.
matmul
(
K
[:,
i_s
*
block_S
:(
i_s
+
1
)
*
block_S
,
:,
:].
permute
(
0
,
2
,
1
,
3
),
dv_tmp
=
torch
.
matmul
(
K
[:,
i_s
*
block_S
:
(
i_s
+
1
)
*
block_S
,
:,
:].
permute
(
0
,
2
,
1
,
3
),
dh_tmp
.
to
(
K
.
dtype
)).
permute
(
0
,
2
,
1
,
3
)
dh_tmp
.
to
(
K
.
dtype
)).
permute
(
0
,
2
,
1
,
3
)
if
use_g
:
if
use_g
:
for
i_bh
in
range
(
B
*
H
):
for
i_bh
in
range
(
B
*
H
):
i_b
,
i_h
=
i_bh
//
H
,
i_bh
%
H
i_b
,
i_h
=
i_bh
//
H
,
i_bh
%
H
for
i_s2
in
range
(
block_S
):
for
i_s2
in
range
(
block_S
):
if
G
[
i_b
,
i_s
*
block_S
+
block_S
-
1
,
i_h
]
-
G
[
i_b
,
i_s
*
block_S
+
i_s2
,
if
G
[
i_b
,
i_s
*
block_S
+
block_S
-
1
,
i_h
]
-
G
[
i_b
,
i_s
*
block_S
+
i_s2
,
i_h
]
<=
0
:
i_h
]
<=
0
:
dv_tmp
[
i_b
,
i_s2
,
i_h
,
:]
*=
torch
.
exp
(
G
[
i_b
,
i_s
*
block_S
+
block_S
-
1
,
i_h
]
-
G
[
i_b
,
i_s
*
block_S
+
i_s2
,
i_h
])
dv_tmp
[
i_b
,
i_s2
,
i_h
,
:]
*=
torch
.
exp
(
G
[
i_b
,
i_s
*
block_S
+
block_S
-
1
,
i_h
]
-
G
[
i_b
,
i_s
*
block_S
+
i_s2
,
i_h
])
else
:
else
:
dv_tmp
[
i_b
,
i_s2
,
i_h
,
:]
=
0
dv_tmp
[
i_b
,
i_s2
,
i_h
,
:]
=
0
dv_tmp
+=
dv
[:,
i_s
*
block_S
:
(
i_s
+
1
)
*
block_S
,
:,
:]
dv_tmp
+=
dv
[:,
i_s
*
block_S
:
(
i_s
+
1
)
*
block_S
,
:,
:]
dv2
[:,
i_s
*
block_S
:
(
i_s
+
1
)
*
block_S
,
:,
:]
=
dv_tmp
dv2
[:,
i_s
*
block_S
:
(
i_s
+
1
)
*
block_S
,
:,
:]
=
dv_tmp
if
use_g
:
if
use_g
:
G_last
=
G
[:,
i_s
*
block_S
+
block_S
-
1
,
:]
G_last
=
G
[:,
i_s
*
block_S
+
block_S
-
1
,
:]
for
i_bh
in
range
(
B
*
H
):
for
i_bh
in
range
(
B
*
H
):
i_b
,
i_h
=
i_bh
//
H
,
i_bh
%
H
i_b
,
i_h
=
i_bh
//
H
,
i_bh
%
H
dh_tmp
[
i_b
,
i_h
,
:,
:]
*=
torch
.
exp
(
G_last
[
i_b
,
i_h
])
dh_tmp
[
i_b
,
i_h
,
:,
:]
*=
torch
.
exp
(
G_last
[
i_b
,
i_h
])
Q_tmp
=
Q
[:,
i_s
*
block_S
:
(
i_s
+
1
)
*
block_S
,
:,
:]
Q_tmp
=
Q
[:,
i_s
*
block_S
:
(
i_s
+
1
)
*
block_S
,
:,
:]
for
i_s2
in
range
(
block_S
):
for
i_s2
in
range
(
block_S
):
for
i_k
in
range
(
DK
):
for
i_k
in
range
(
DK
):
Q_tmp
[:,
i_s2
,
:,
i_k
]
*=
torch
.
exp
(
G
[:,
i_s
*
block_S
+
i_s2
,
:])
Q_tmp
[:,
i_s2
,
:,
i_k
]
*=
torch
.
exp
(
G
[:,
i_s
*
block_S
+
i_s2
,
:])
Q_tmp
*=
scale
Q_tmp
*=
scale
W_tmp
=
W
[:,
i_s
*
block_S
:
(
i_s
+
1
)
*
block_S
,
:,
:]
W_tmp
=
W
[:,
i_s
*
block_S
:
(
i_s
+
1
)
*
block_S
,
:,
:]
dO_tmp
=
dO
[:,
i_s
*
block_S
:
(
i_s
+
1
)
*
block_S
,
:,
:]
dO_tmp
=
dO
[:,
i_s
*
block_S
:
(
i_s
+
1
)
*
block_S
,
:,
:]
torch
.
backends
.
cuda
.
matmul
.
allow_tf32
=
True
torch
.
backends
.
cuda
.
matmul
.
allow_tf32
=
True
dh_tmp
+=
torch
.
matmul
(
Q_tmp
.
permute
(
0
,
2
,
3
,
1
),
dO_tmp
.
permute
(
0
,
2
,
1
,
3
))
dh_tmp
+=
torch
.
matmul
(
Q_tmp
.
permute
(
0
,
2
,
3
,
1
),
dO_tmp
.
permute
(
0
,
2
,
1
,
3
))
...
@@ -223,19 +224,19 @@ def tilelang_chunk_gated_delta_rule_bwd_dhu(
...
@@ -223,19 +224,19 @@ def tilelang_chunk_gated_delta_rule_bwd_dhu(
@
T
.
prim_func
@
T
.
prim_func
def
kernel
(
def
kernel
(
# Input
# Input
Q
:
T
.
Tensor
(
Q_shape
,
dtype
=
input_dtype
),
Q
:
T
.
Tensor
(
Q_shape
,
dtype
=
input_dtype
),
K
:
T
.
Tensor
(
K_shape
,
dtype
=
input_dtype
),
K
:
T
.
Tensor
(
K_shape
,
dtype
=
input_dtype
),
W
:
T
.
Tensor
(
W_shape
,
dtype
=
input_dtype
),
W
:
T
.
Tensor
(
W_shape
,
dtype
=
input_dtype
),
G
:
T
.
Tensor
(
G_shape
,
dtype
=
gate_dtype
),
G
:
T
.
Tensor
(
G_shape
,
dtype
=
gate_dtype
),
h0
:
T
.
Tensor
(
h0_shape
,
dtype
=
input_dtype
),
h0
:
T
.
Tensor
(
h0_shape
,
dtype
=
input_dtype
),
dht
:
T
.
Tensor
(
dht_shape
,
dtype
=
input_dtype
),
dht
:
T
.
Tensor
(
dht_shape
,
dtype
=
input_dtype
),
dO
:
T
.
Tensor
(
dO_shape
,
dtype
=
input_dtype
),
dO
:
T
.
Tensor
(
dO_shape
,
dtype
=
input_dtype
),
dv
:
T
.
Tensor
(
dv_shape
,
dtype
=
input_dtype
),
dv
:
T
.
Tensor
(
dv_shape
,
dtype
=
input_dtype
),
# Output
# Output
dh
:
T
.
Tensor
(
dh_shape
,
dtype
=
output_dtype
),
dh
:
T
.
Tensor
(
dh_shape
,
dtype
=
output_dtype
),
dh0
:
T
.
Tensor
(
dh0_shape
,
dtype
=
state_dtype
),
dh0
:
T
.
Tensor
(
dh0_shape
,
dtype
=
state_dtype
),
dv2
:
T
.
Tensor
(
dv2_shape
,
dtype
=
output_dtype
),
dv2
:
T
.
Tensor
(
dv2_shape
,
dtype
=
output_dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
DV
,
block_DV
),
B
*
H
,
threads
=
threads
)
as
(
bv
,
bbh
):
with
T
.
Kernel
(
T
.
ceildiv
(
DV
,
block_DV
),
B
*
H
,
threads
=
threads
)
as
(
bv
,
bbh
):
bb
,
bh
=
bbh
//
H
,
bbh
%
H
bb
,
bh
=
bbh
//
H
,
bbh
%
H
...
@@ -249,13 +250,13 @@ def tilelang_chunk_gated_delta_rule_bwd_dhu(
...
@@ -249,13 +250,13 @@ def tilelang_chunk_gated_delta_rule_bwd_dhu(
dv_fragment
=
T
.
alloc_fragment
((
block_S
,
block_DV
),
dtype
=
accum_dtype
)
dv_fragment
=
T
.
alloc_fragment
((
block_S
,
block_DV
),
dtype
=
accum_dtype
)
dv_fragment_2
=
T
.
alloc_fragment
((
block_S
,
block_DV
),
dtype
=
accum_dtype
)
dv_fragment_2
=
T
.
alloc_fragment
((
block_S
,
block_DV
),
dtype
=
accum_dtype
)
dO_shared
=
T
.
alloc_shared
((
block_S
,
block_DV
),
dtype
=
input_dtype
)
dO_shared
=
T
.
alloc_shared
((
block_S
,
block_DV
),
dtype
=
input_dtype
)
dO_shared_t
=
T
.
alloc_shared
((
block_DV
,
block_S
),
dtype
=
"
float32
"
)
dO_shared_t
=
T
.
alloc_shared
((
block_DV
,
block_S
),
dtype
=
T
.
float32
)
dO_fragment
=
T
.
alloc_fragment
((
block_S
,
block_DV
),
dtype
=
"
float32
"
)
dO_fragment
=
T
.
alloc_fragment
((
block_S
,
block_DV
),
dtype
=
T
.
float32
)
dO_fragment_t
=
T
.
alloc_fragment
((
block_DV
,
block_S
),
dtype
=
"
float32
"
)
dO_fragment_t
=
T
.
alloc_fragment
((
block_DV
,
block_S
),
dtype
=
T
.
float32
)
K_shared
=
T
.
alloc_shared
((
block_S
,
DK
),
dtype
=
input_dtype
)
K_shared
=
T
.
alloc_shared
((
block_S
,
DK
),
dtype
=
input_dtype
)
Q_shared
=
T
.
alloc_shared
((
block_S
,
DK
),
dtype
=
input_dtype
)
Q_shared
=
T
.
alloc_shared
((
block_S
,
DK
),
dtype
=
input_dtype
)
Q_shared_fp32
=
T
.
alloc_shared
((
block_S
,
DK
),
dtype
=
"
float32
"
)
Q_shared_fp32
=
T
.
alloc_shared
((
block_S
,
DK
),
dtype
=
T
.
float32
)
W_shared
=
T
.
alloc_shared
((
block_S
,
DK
),
dtype
=
input_dtype
)
W_shared
=
T
.
alloc_shared
((
block_S
,
DK
),
dtype
=
input_dtype
)
G_last_local
=
T
.
alloc_local
((
1
),
dtype
=
gate_dtype
)
G_last_local
=
T
.
alloc_local
((
1
),
dtype
=
gate_dtype
)
...
@@ -269,20 +270,22 @@ def tilelang_chunk_gated_delta_rule_bwd_dhu(
...
@@ -269,20 +270,22 @@ def tilelang_chunk_gated_delta_rule_bwd_dhu(
T
.
use_swizzle
(
10
)
T
.
use_swizzle
(
10
)
T
.
annotate_layout
({
T
.
annotate_layout
(
b_dh_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
b_dh_shared
),
{
b_dh_shared_fp32
:
tilelang
.
layout
.
make_swizzled_layout
(
b_dh_shared_fp32
),
b_dh_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
b_dh_shared
),
dv_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
dv_shared
),
b_dh_shared_fp32
:
tilelang
.
layout
.
make_swizzled_layout
(
b_dh_shared_fp32
),
dO_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
dO_shared
),
dv_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
dv_shared
),
dO_shared_t
:
tilelang
.
layout
.
make_swizzled_layout
(
dO_shared_t
),
dO_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
dO_shared
),
K_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
K_shared
),
dO_shared_t
:
tilelang
.
layout
.
make_swizzled_layout
(
dO_shared_t
),
Q_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
Q_shared
),
K_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
K_shared
),
Q_shared_fp32
:
tilelang
.
layout
.
make_swizzled_layout
(
Q_shared_fp32
),
Q_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
Q_shared
),
W_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
W_shared
),
Q_shared_fp32
:
tilelang
.
layout
.
make_swizzled_layout
(
Q_shared_fp32
),
})
W_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
W_shared
),
}
)
if
use_final_state_gradient
:
if
use_final_state_gradient
:
T
.
copy
(
dht
[
bb
,
bh
,
0
:
DK
,
bv
*
block_DV
:
(
bv
+
1
)
*
block_DV
],
b_dh_shared
)
T
.
copy
(
dht
[
bb
,
bh
,
0
:
DK
,
bv
*
block_DV
:
(
bv
+
1
)
*
block_DV
],
b_dh_shared
)
T
.
copy
(
b_dh_shared
,
b_dh_fragment
)
T
.
copy
(
b_dh_shared
,
b_dh_fragment
)
else
:
else
:
T
.
clear
(
b_dh_fragment
)
T
.
clear
(
b_dh_fragment
)
...
@@ -293,17 +296,14 @@ def tilelang_chunk_gated_delta_rule_bwd_dhu(
...
@@ -293,17 +296,14 @@ def tilelang_chunk_gated_delta_rule_bwd_dhu(
# Store the updated dh
# Store the updated dh
T
.
copy
(
b_dh_fragment
,
b_dh_shared
)
T
.
copy
(
b_dh_fragment
,
b_dh_shared
)
T
.
copy
(
b_dh_shared
,
dh
[
bb
,
i_s_inv
,
bh
,
0
:
DK
,
bv
*
block_DV
:
(
bv
+
1
)
*
block_DV
])
T
.
copy
(
b_dh_shared
,
dh
[
bb
,
i_s_inv
,
bh
,
0
:
DK
,
bv
*
block_DV
:
(
bv
+
1
)
*
block_DV
])
# Update dv
# Update dv
T
.
copy
(
K
[
bb
,
i_s_inv
*
block_S
:
(
i_s_inv
+
1
)
*
block_S
,
bh
,
0
:
DK
],
K_shared
)
T
.
copy
(
K
[
bb
,
i_s_inv
*
block_S
:
(
i_s_inv
+
1
)
*
block_S
,
bh
,
0
:
DK
],
K_shared
)
T
.
gemm
(
K_shared
,
b_dh_shared
,
dv_fragment
,
clear_accum
=
True
)
T
.
gemm
(
K_shared
,
b_dh_shared
,
dv_fragment
,
clear_accum
=
True
)
if
use_g
:
if
use_g
:
T
.
copy
(
T
.
copy
(
G
[
bb
,
i_s_inv
*
block_S
:
(
i_s_inv
+
1
)
*
block_S
,
bh
],
G_shared
,
disable_tma
=
True
)
G
[
bb
,
i_s_inv
*
block_S
:(
i_s_inv
+
1
)
*
block_S
,
bh
],
G_shared
,
disable_tma
=
True
)
T
.
copy
(
G_shared
,
G_fragment
)
T
.
copy
(
G_shared
,
G_fragment
)
G_last_local
[
0
]
=
G_shared
[
block_S
-
1
]
G_last_local
[
0
]
=
G_shared
[
block_S
-
1
]
G_last_local_exp
[
0
]
=
T
.
exp
(
G_last_local
[
0
])
G_last_local_exp
[
0
]
=
T
.
exp
(
G_last_local
[
0
])
...
@@ -313,27 +313,22 @@ def tilelang_chunk_gated_delta_rule_bwd_dhu(
...
@@ -313,27 +313,22 @@ def tilelang_chunk_gated_delta_rule_bwd_dhu(
# with T.If(G_last_local[0] - G_shared[i_s2] <= 0):
# with T.If(G_last_local[0] - G_shared[i_s2] <= 0):
with
T
.
If
(
G_last_local
[
0
]
-
G_fragment
[
i_s2
]
<=
0
):
with
T
.
If
(
G_last_local
[
0
]
-
G_fragment
[
i_s2
]
<=
0
):
with
T
.
Then
():
with
T
.
Then
():
dv_fragment
[
i_s2
,
dv_fragment
[
i_s2
,
i_v
]
=
dv_fragment
[
i_s2
,
i_v
]
*
G_fragment_post
[
i_s2
]
i_v
]
=
dv_fragment
[
i_s2
,
i_v
]
*
G_fragment_post
[
i_s2
]
with
T
.
Else
():
with
T
.
Else
():
dv_fragment
[
i_s2
,
i_v
]
=
0
dv_fragment
[
i_s2
,
i_v
]
=
0
T
.
copy
(
T
.
copy
(
dv
[
bb
,
i_s_inv
*
block_S
:
(
i_s_inv
+
1
)
*
block_S
,
bh
,
bv
*
block_DV
:
(
bv
+
1
)
*
block_DV
],
dv_shared
)
dv
[
bb
,
i_s_inv
*
block_S
:(
i_s_inv
+
1
)
*
block_S
,
bh
,
bv
*
block_DV
:(
bv
+
1
)
*
block_DV
],
dv_shared
)
T
.
copy
(
dv_shared
,
dv_fragment_2
)
T
.
copy
(
dv_shared
,
dv_fragment_2
)
for
i_s2
,
i_v
in
T
.
Parallel
(
block_S
,
block_DV
):
for
i_s2
,
i_v
in
T
.
Parallel
(
block_S
,
block_DV
):
dv_fragment
[
i_s2
,
i_v
]
=
dv_fragment
[
i_s2
,
i_v
]
+
dv_fragment_2
[
i_s2
,
i_v
]
dv_fragment
[
i_s2
,
i_v
]
=
dv_fragment
[
i_s2
,
i_v
]
+
dv_fragment_2
[
i_s2
,
i_v
]
# Store the updated dv
# Store the updated dv
T
.
copy
(
dv_fragment
,
dv_shared
)
T
.
copy
(
dv_fragment
,
dv_shared
)
T
.
copy
(
T
.
copy
(
dv_shared
,
dv2
[
bb
,
i_s_inv
*
block_S
:
(
i_s_inv
+
1
)
*
block_S
,
bh
,
bv
*
block_DV
:
(
bv
+
1
)
*
block_DV
])
dv_shared
,
dv2
[
bb
,
i_s_inv
*
block_S
:(
i_s_inv
+
1
)
*
block_S
,
bh
,
bv
*
block_DV
:(
bv
+
1
)
*
block_DV
])
# Update dh
# Update dh
T
.
copy
(
Q
[
bb
,
i_s_inv
*
block_S
:
(
i_s_inv
+
1
)
*
block_S
,
bh
,
0
:
DK
],
Q_shared
)
T
.
copy
(
Q
[
bb
,
i_s_inv
*
block_S
:
(
i_s_inv
+
1
)
*
block_S
,
bh
,
0
:
DK
],
Q_shared
)
T
.
copy
(
W
[
bb
,
i_s_inv
*
block_S
:
(
i_s_inv
+
1
)
*
block_S
,
bh
,
0
:
DK
],
W_shared
)
T
.
copy
(
W
[
bb
,
i_s_inv
*
block_S
:
(
i_s_inv
+
1
)
*
block_S
,
bh
,
0
:
DK
],
W_shared
)
T
.
clear
(
Q_fragment
)
T
.
clear
(
Q_fragment
)
if
use_g
:
if
use_g
:
...
@@ -353,9 +348,7 @@ def tilelang_chunk_gated_delta_rule_bwd_dhu(
...
@@ -353,9 +348,7 @@ def tilelang_chunk_gated_delta_rule_bwd_dhu(
for
i_s2
,
i_k
in
T
.
Parallel
(
block_S
,
DK
):
for
i_s2
,
i_k
in
T
.
Parallel
(
block_S
,
DK
):
Q_fragment_t
[
i_k
,
i_s2
]
=
Q_fragment
[
i_s2
,
i_k
]
Q_fragment_t
[
i_k
,
i_s2
]
=
Q_fragment
[
i_s2
,
i_k
]
T
.
copy
(
T
.
copy
(
dO
[
bb
,
i_s_inv
*
block_S
:
(
i_s_inv
+
1
)
*
block_S
,
bh
,
bv
*
block_DV
:
(
bv
+
1
)
*
block_DV
],
dO_shared
)
dO
[
bb
,
i_s_inv
*
block_S
:(
i_s_inv
+
1
)
*
block_S
,
bh
,
bv
*
block_DV
:(
bv
+
1
)
*
block_DV
],
dO_shared
)
T
.
copy
(
dO_shared
,
dO_fragment
)
T
.
copy
(
dO_shared
,
dO_fragment
)
for
i_s2
,
i_v
in
T
.
Parallel
(
block_S
,
block_DV
):
for
i_s2
,
i_v
in
T
.
Parallel
(
block_S
,
block_DV
):
dO_fragment_t
[
i_v
,
i_s2
]
=
dO_fragment
[
i_s2
,
i_v
]
dO_fragment_t
[
i_v
,
i_s2
]
=
dO_fragment
[
i_s2
,
i_v
]
...
@@ -369,7 +362,7 @@ def tilelang_chunk_gated_delta_rule_bwd_dhu(
...
@@ -369,7 +362,7 @@ def tilelang_chunk_gated_delta_rule_bwd_dhu(
b_dh_fragment
[
i_k
,
i_v
]
+=
b_dh_fragment_1
[
i_k
,
i_v
]
-
b_dh_fragment_2
[
i_k
,
i_v
]
b_dh_fragment
[
i_k
,
i_v
]
+=
b_dh_fragment_1
[
i_k
,
i_v
]
-
b_dh_fragment_2
[
i_k
,
i_v
]
if
use_initial_state
:
if
use_initial_state
:
T
.
copy
(
b_dh_fragment
,
dh0
[
bb
,
bh
,
0
:
DK
,
bv
*
block_DV
:
(
bv
+
1
)
*
block_DV
])
T
.
copy
(
b_dh_fragment
,
dh0
[
bb
,
bh
,
0
:
DK
,
bv
*
block_DV
:
(
bv
+
1
)
*
block_DV
])
return
kernel
return
kernel
...
@@ -444,44 +437,61 @@ def run_test(
...
@@ -444,44 +437,61 @@ def run_test(
num_stages
=
0
,
num_stages
=
0
,
use_torch
=
False
,
use_torch
=
False
,
):
):
Q
,
K
,
W
,
G
,
h0
,
dht
,
dO
,
dv
=
prepare_input
(
B
,
S
,
H
,
DK
,
DV
,
chunk_size
,
Q
,
K
,
W
,
G
,
h0
,
dht
,
dO
,
dv
=
prepare_input
(
getattr
(
torch
,
input_dtype
),
B
,
getattr
(
torch
,
output_dtype
),
S
,
getattr
(
torch
,
accum_dtype
),
H
,
getattr
(
torch
,
gate_dtype
),
DK
,
getattr
(
torch
,
state_dtype
))
DV
,
dh_ref
,
dh0_ref
,
dv2_ref
=
prepare_output
(
B
,
S
,
H
,
DK
,
DV
,
chunk_size
,
chunk_size
,
getattr
(
torch
,
output_dtype
),
getattr
(
torch
,
input_dtype
),
getattr
(
torch
,
gate_dtype
),
getattr
(
torch
,
output_dtype
),
getattr
(
torch
,
state_dtype
))
getattr
(
torch
,
accum_dtype
),
dh_tilelang
,
dh0_tilelang
,
dv2_tilelang
=
prepare_output
(
B
,
S
,
H
,
DK
,
DV
,
chunk_size
,
getattr
(
torch
,
gate_dtype
),
getattr
(
torch
,
output_dtype
),
getattr
(
torch
,
state_dtype
),
getattr
(
torch
,
gate_dtype
),
)
getattr
(
torch
,
state_dtype
))
dh_ref
,
dh0_ref
,
dv2_ref
=
prepare_output
(
B
,
S
,
H
,
DK
,
DV
,
chunk_size
,
getattr
(
torch
,
output_dtype
),
getattr
(
torch
,
gate_dtype
),
getattr
(
torch
,
state_dtype
)
)
dh_tilelang
,
dh0_tilelang
,
dv2_tilelang
=
prepare_output
(
B
,
S
,
H
,
DK
,
DV
,
chunk_size
,
getattr
(
torch
,
output_dtype
),
getattr
(
torch
,
gate_dtype
),
getattr
(
torch
,
state_dtype
)
)
# fla ref
# fla ref
print
(
"fla running..."
,
flush
=
True
)
print
(
"fla running..."
,
flush
=
True
)
if
use_g
:
if
use_g
:
dh_ref
,
dh0_ref
,
dv2_ref
=
chunk_gated_delta_rule_bwd_dhu
(
Q
,
K
,
W
,
G
,
h0
,
dht
,
dO
,
dv
,
dh_ref
,
dh0_ref
,
dv2_ref
=
chunk_gated_delta_rule_bwd_dhu
(
Q
,
K
,
W
,
G
,
h0
,
dht
,
dO
,
dv
,
scale
)
scale
)
else
:
else
:
G
=
G
.
fill_
(
0
)
G
=
G
.
fill_
(
0
)
dh_ref
,
dh0_ref
,
dv2_ref
=
chunk_gated_delta_rule_bwd_dhu
(
Q
,
K
,
W
,
G
,
h0
,
dht
,
dO
,
dv
,
dh_ref
,
dh0_ref
,
dv2_ref
=
chunk_gated_delta_rule_bwd_dhu
(
Q
,
K
,
W
,
G
,
h0
,
dht
,
dO
,
dv
,
scale
)
scale
)
# tilelang
# tilelang
print
(
"tilelang running..."
,
flush
=
True
)
print
(
"tilelang running..."
,
flush
=
True
)
kernel
=
tilelang_chunk_gated_delta_rule_bwd_dhu
(
B
,
S
,
H
,
DK
,
DV
,
input_dtype
,
output_dtype
,
kernel
=
tilelang_chunk_gated_delta_rule_bwd_dhu
(
accum_dtype
,
gate_dtype
,
state_dtype
,
B
,
chunk_size
,
scale
,
use_g
,
use_initial_state
,
S
,
use_final_state_gradient
,
block_DV
,
threads
,
H
,
num_stages
)
DK
,
DV
,
input_dtype
,
output_dtype
,
accum_dtype
,
gate_dtype
,
state_dtype
,
chunk_size
,
scale
,
use_g
,
use_initial_state
,
use_final_state_gradient
,
block_DV
,
threads
,
num_stages
,
)
# kernel = tilelang.compile(program)
# kernel = tilelang.compile(program)
print
(
kernel
.
get_kernel_source
())
print
(
kernel
.
get_kernel_source
())
dh_tilelang
,
dh0_tilelang
,
dv2_tilelang
=
kernel
(
Q
,
K
,
W
,
G
,
h0
,
dht
,
dO
,
dv
)
dh_tilelang
,
dh0_tilelang
,
dv2_tilelang
=
kernel
(
Q
,
K
,
W
,
G
,
h0
,
dht
,
dO
,
dv
)
fla_time
=
do_bench
(
fla_time
=
do_bench
(
chunk_gated_delta_rule_bwd_dhu
,
Q
,
K
,
W
,
G
,
h0
,
dht
,
dO
,
dv
,
scale
,
chunk_size
=
chunk_size
)
chunk_gated_delta_rule_bwd_dhu
,
Q
,
K
,
W
,
G
,
h0
,
dht
,
dO
,
dv
,
scale
,
chunk_size
=
chunk_size
)
tilelang_time
=
do_bench
(
kernel
,
Q
,
K
,
W
,
G
,
h0
,
dht
,
dO
,
dv
)
tilelang_time
=
do_bench
(
kernel
,
Q
,
K
,
W
,
G
,
h0
,
dht
,
dO
,
dv
)
print
(
f
"fla time:
{
fla_time
}
ms"
)
print
(
f
"fla time:
{
fla_time
}
ms"
)
...
@@ -496,19 +506,47 @@ def run_test(
...
@@ -496,19 +506,47 @@ def run_test(
print
(
"torch running..."
,
flush
=
True
)
print
(
"torch running..."
,
flush
=
True
)
if
use_g
:
if
use_g
:
dh_ref_torch
,
dh0_ref_torch
,
dv2_ref_torch
=
torch_chunk_gated_delta_rule_bwd_dhu
(
dh_ref_torch
,
dh0_ref_torch
,
dv2_ref_torch
=
torch_chunk_gated_delta_rule_bwd_dhu
(
Q
,
K
,
W
,
G
,
h0
,
dht
,
dO
,
dv
,
scale
,
use_g
,
use_initial_state
,
Q
,
use_final_state_gradient
,
getattr
(
torch
,
input_dtype
),
getattr
(
torch
,
output_dtype
),
K
,
getattr
(
torch
,
accum_dtype
),
getattr
(
torch
,
W
,
gate_dtype
),
getattr
(
torch
,
state_dtype
))
G
,
h0
,
dht
,
dO
,
dv
,
scale
,
use_g
,
use_initial_state
,
use_final_state_gradient
,
getattr
(
torch
,
input_dtype
),
getattr
(
torch
,
output_dtype
),
getattr
(
torch
,
accum_dtype
),
getattr
(
torch
,
gate_dtype
),
getattr
(
torch
,
state_dtype
),
)
dh_ref_torch
=
dh_ref_torch
.
cuda
()
dh_ref_torch
=
dh_ref_torch
.
cuda
()
dh0_ref_torch
=
dh0_ref_torch
.
cuda
()
dh0_ref_torch
=
dh0_ref_torch
.
cuda
()
dv2_ref_torch
=
dv2_ref_torch
.
cuda
()
dv2_ref_torch
=
dv2_ref_torch
.
cuda
()
else
:
else
:
dh_ref_torch
,
dh0_ref_torch
,
dv2_ref_torch
=
torch_chunk_gated_delta_rule_bwd_dhu
(
dh_ref_torch
,
dh0_ref_torch
,
dv2_ref_torch
=
torch_chunk_gated_delta_rule_bwd_dhu
(
Q
,
K
,
W
,
None
,
h0
,
dht
,
dO
,
dv
,
scale
,
use_g
,
use_initial_state
,
Q
,
use_final_state_gradient
,
getattr
(
torch
,
input_dtype
),
getattr
(
torch
,
output_dtype
),
K
,
getattr
(
torch
,
accum_dtype
),
getattr
(
torch
,
W
,
gate_dtype
),
getattr
(
torch
,
state_dtype
))
None
,
h0
,
dht
,
dO
,
dv
,
scale
,
use_g
,
use_initial_state
,
use_final_state_gradient
,
getattr
(
torch
,
input_dtype
),
getattr
(
torch
,
output_dtype
),
getattr
(
torch
,
accum_dtype
),
getattr
(
torch
,
gate_dtype
),
getattr
(
torch
,
state_dtype
),
)
dh_ref_torch
=
dh_ref_torch
.
cuda
()
dh_ref_torch
=
dh_ref_torch
.
cuda
()
dh0_ref_torch
=
dh0_ref_torch
.
cuda
()
dh0_ref_torch
=
dh0_ref_torch
.
cuda
()
dv2_ref_torch
=
dv2_ref_torch
.
cuda
()
dv2_ref_torch
=
dv2_ref_torch
.
cuda
()
...
@@ -554,11 +592,11 @@ def main():
...
@@ -554,11 +592,11 @@ def main():
H
=
8
,
H
=
8
,
DK
=
DK
,
DK
=
DK
,
DV
=
128
,
DV
=
128
,
input_dtype
=
"
bfloat16
"
,
input_dtype
=
T
.
bfloat16
,
output_dtype
=
"
bfloat16
"
,
output_dtype
=
T
.
bfloat16
,
accum_dtype
=
"
float32
"
,
accum_dtype
=
T
.
float32
,
gate_dtype
=
"
float32
"
,
gate_dtype
=
T
.
float32
,
state_dtype
=
"
float32
"
,
state_dtype
=
T
.
float32
,
chunk_size
=
64
,
chunk_size
=
64
,
scale
=
DK
**-
0.5
,
scale
=
DK
**-
0.5
,
use_g
=
True
,
use_g
=
True
,
...
...
examples/gdn/example_chunk_delta_h.py
View file @
667632cc
...
@@ -3,12 +3,14 @@
...
@@ -3,12 +3,14 @@
import
sys
# noqa: F401
import
sys
# noqa: F401
import
tilelang
import
tilelang
import
tilelang.language
as
T
import
tilelang.language
as
T
from
tilelang.autotuner
import
autotune
# Add your fla repository path to sys.path
# Add your fla repository path to sys.path
# Currently we use the fla repository from the flash-linear-attention project at commit id f03cb3ae
# Currently we use the fla repository from the flash-linear-attention project at commit id f03cb3ae
# sys.path.insert(0, "/home/tzj/flash-linear-attention")
# sys.path.insert(0, "/home/tzj/flash-linear-attention")
try
:
try
:
import
fla
import
fla
print
(
fla
.
__file__
)
print
(
fla
.
__file__
)
from
fla.ops.common.chunk_delta_h
import
chunk_gated_delta_rule_fwd_h
from
fla.ops.common.chunk_delta_h
import
chunk_gated_delta_rule_fwd_h
except
ImportError
:
except
ImportError
:
...
@@ -19,7 +21,7 @@ import torch
...
@@ -19,7 +21,7 @@ import torch
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
tilelang.engine.callback
import
register_cuda_postproc_callback
# noqa: F401
from
tilelang.engine.callback
import
register_cuda_postproc_callback
# noqa: F401
from
utils
import
*
from
test_
utils
import
assert_similar
# (zhengju) We can slightly modify the generated cuda code from tilelang lowering
# (zhengju) We can slightly modify the generated cuda code from tilelang lowering
# in the debug folder to make the performance better. To enable this callback,
# in the debug folder to make the performance better. To enable this callback,
...
@@ -55,6 +57,7 @@ def prepare_input(
...
@@ -55,6 +57,7 @@ def prepare_input(
G
=
F
.
logsigmoid
(
G
)
G
=
F
.
logsigmoid
(
G
)
try
:
try
:
from
fla.ops.utils.cumsum
import
chunk_local_cumsum
from
fla.ops.utils.cumsum
import
chunk_local_cumsum
G
=
chunk_local_cumsum
(
G
,
chunk_size
)
G
=
chunk_local_cumsum
(
G
,
chunk_size
)
except
ImportError
:
except
ImportError
:
print
(
"fla not found, skip cumsum"
)
print
(
"fla not found, skip cumsum"
)
...
@@ -80,7 +83,21 @@ def prepare_output(
...
@@ -80,7 +83,21 @@ def prepare_output(
return
h
,
final_state
,
V_new
return
h
,
final_state
,
V_new
@
tilelang
.
jit
(
out_idx
=
[
-
3
,
-
2
,
-
1
])
def
get_configs
():
import
itertools
block_DK
=
[
32
,
64
,
128
]
block_DV
=
[
32
,
64
,
128
]
threads
=
[
128
,
256
]
num_stages
=
[
1
,
2
,
3
]
_configs
=
list
(
itertools
.
product
(
block_DK
,
block_DV
,
threads
,
num_stages
))
configs
=
[{
"block_DK"
:
c
[
0
],
"block_DV"
:
c
[
1
],
"threads"
:
c
[
2
],
"num_stages"
:
c
[
3
]}
for
c
in
_configs
]
return
configs
@
autotune
(
configs
=
get_configs
(),
warmup
=
3
,
rep
=
5
)
@
tilelang
.
jit
(
out_idx
=
[
-
3
,
-
2
,
-
1
],
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
True
})
def
tilelang_chunk_gated_delta_rule_fwd_h
(
def
tilelang_chunk_gated_delta_rule_fwd_h
(
# task config
# task config
B
,
B
,
...
@@ -94,15 +111,15 @@ def tilelang_chunk_gated_delta_rule_fwd_h(
...
@@ -94,15 +111,15 @@ def tilelang_chunk_gated_delta_rule_fwd_h(
gate_dtype
,
gate_dtype
,
state_dtype
,
state_dtype
,
chunk_size
,
chunk_size
,
use_g
=
True
,
use_g
,
use_initial_state
=
True
,
use_initial_state
,
store_final_state
=
True
,
store_final_state
,
save_new_value
=
True
,
save_new_value
,
# kernel config
# kernel config
block_DK
=
64
,
block_DK
=
64
,
block_DV
=
64
,
block_DV
=
32
,
threads
=
256
,
threads
=
128
,
num_stages
=
0
,
num_stages
=
1
,
):
):
block_S
=
chunk_size
block_S
=
chunk_size
BS
=
S
//
block_S
BS
=
S
//
block_S
...
@@ -118,14 +135,14 @@ def tilelang_chunk_gated_delta_rule_fwd_h(
...
@@ -118,14 +135,14 @@ def tilelang_chunk_gated_delta_rule_fwd_h(
@
T
.
prim_func
@
T
.
prim_func
def
kernel
(
def
kernel
(
K
:
T
.
Tensor
(
K_shape
,
dtype
=
input_dtype
),
K
:
T
.
Tensor
(
K_shape
,
dtype
=
input_dtype
),
W
:
T
.
Tensor
(
W_shape
,
dtype
=
input_dtype
),
W
:
T
.
Tensor
(
W_shape
,
dtype
=
input_dtype
),
U
:
T
.
Tensor
(
U_shape
,
dtype
=
input_dtype
),
U
:
T
.
Tensor
(
U_shape
,
dtype
=
input_dtype
),
G
:
T
.
Tensor
(
G_shape
,
dtype
=
gate_dtype
),
G
:
T
.
Tensor
(
G_shape
,
dtype
=
gate_dtype
),
initial_state
:
T
.
Tensor
(
initial_state_shape
,
dtype
=
input_dtype
),
initial_state
:
T
.
Tensor
(
initial_state_shape
,
dtype
=
input_dtype
),
h
:
T
.
Tensor
(
h_shape
,
dtype
=
output_dtype
),
h
:
T
.
Tensor
(
h_shape
,
dtype
=
output_dtype
),
final_state
:
T
.
Tensor
(
final_state_shape
,
dtype
=
state_dtype
),
final_state
:
T
.
Tensor
(
final_state_shape
,
dtype
=
state_dtype
),
V_new
:
T
.
Tensor
(
V_shape
,
dtype
=
output_dtype
),
V_new
:
T
.
Tensor
(
V_shape
,
dtype
=
output_dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
DV
,
block_DV
),
B
*
H
,
threads
=
threads
)
as
(
bv
,
bbh
):
with
T
.
Kernel
(
T
.
ceildiv
(
DV
,
block_DV
),
B
*
H
,
threads
=
threads
)
as
(
bv
,
bbh
):
bb
,
bh
=
bbh
//
H
,
bbh
%
H
bb
,
bh
=
bbh
//
H
,
bbh
%
H
...
@@ -143,35 +160,35 @@ def tilelang_chunk_gated_delta_rule_fwd_h(
...
@@ -143,35 +160,35 @@ def tilelang_chunk_gated_delta_rule_fwd_h(
G_shared
=
T
.
alloc_shared
((
block_S
,
block_DV
),
dtype
=
gate_dtype
)
G_shared
=
T
.
alloc_shared
((
block_S
,
block_DV
),
dtype
=
gate_dtype
)
G_fragment
=
T
.
alloc_fragment
((
block_S
,
block_DV
),
dtype
=
gate_dtype
)
G_fragment
=
T
.
alloc_fragment
((
block_S
,
block_DV
),
dtype
=
gate_dtype
)
T
.
annotate_layout
({
T
.
annotate_layout
(
b_h_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
b_h_shared
),
{
U_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
U_shared
),
b_h_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
b_h_shared
),
W_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
W_shared
),
U_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
U_shared
),
V_new_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
V_new_shared
),
W_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
W_shared
),
K_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
K_shared
),
V_new_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
V_new_shared
),
G_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
G_shared
),
K_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
K_shared
),
})
G_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
G_shared
),
}
)
T
.
use_swizzle
(
10
)
T
.
use_swizzle
(
10
)
if
use_initial_state
:
if
use_initial_state
:
T
.
copy
(
initial_state
[
bb
,
bh
,
0
:
DK
,
bv
*
block_DV
:
(
bv
+
1
)
*
block_DV
],
b_h_shared
)
T
.
copy
(
initial_state
[
bb
,
bh
,
0
:
DK
,
bv
*
block_DV
:
(
bv
+
1
)
*
block_DV
],
b_h_shared
)
T
.
copy
(
b_h_shared
,
b_h_fragment
)
T
.
copy
(
b_h_shared
,
b_h_fragment
)
else
:
else
:
T
.
clear
(
b_h_fragment
)
T
.
clear
(
b_h_fragment
)
for
i_s
in
T
.
Pipelined
(
T
.
ceildiv
(
S
,
block_S
),
num_stages
=
num_stages
):
for
i_s
in
T
.
Pipelined
(
T
.
ceildiv
(
S
,
block_S
),
num_stages
=
num_stages
):
# Store previous result to the hidden tensor, like the epilogue
# Store previous result to the hidden tensor, like the epilogue
T
.
copy
(
b_h_shared
,
h
[
bb
,
i_s
,
bh
,
0
:
DK
,
bv
*
block_DV
:
(
bv
+
1
)
*
block_DV
])
T
.
copy
(
b_h_shared
,
h
[
bb
,
i_s
,
bh
,
0
:
DK
,
bv
*
block_DV
:
(
bv
+
1
)
*
block_DV
])
# Recurrence
# Recurrence
T
.
copy
(
W
[
bb
,
i_s
*
block_S
:
(
i_s
+
1
)
*
block_S
,
bh
,
0
:
DK
],
W_shared
)
T
.
copy
(
W
[
bb
,
i_s
*
block_S
:
(
i_s
+
1
)
*
block_S
,
bh
,
0
:
DK
],
W_shared
)
T
.
gemm
(
W_shared
,
b_h_shared
,
V_new_fragment
,
clear_accum
=
True
)
T
.
gemm
(
W_shared
,
b_h_shared
,
V_new_fragment
,
clear_accum
=
True
)
# U - W * S
# U - W * S
T
.
copy
(
T
.
copy
(
U
[
bb
,
i_s
*
block_S
:
(
i_s
+
1
)
*
block_S
,
bh
,
bv
*
block_DV
:
(
bv
+
1
)
*
block_DV
],
U_shared
)
U
[
bb
,
i_s
*
block_S
:(
i_s
+
1
)
*
block_S
,
bh
,
bv
*
block_DV
:(
bv
+
1
)
*
block_DV
],
U_shared
)
T
.
copy
(
U_shared
,
U_fragment
)
T
.
copy
(
U_shared
,
U_fragment
)
for
i_s2
,
i_v
in
T
.
Parallel
(
block_S
,
block_DV
):
for
i_s2
,
i_v
in
T
.
Parallel
(
block_S
,
block_DV
):
V_new_fragment
[
i_s2
,
i_v
]
=
-
V_new_fragment
[
i_s2
,
i_v
]
+
U_fragment
[
i_s2
,
i_v
]
V_new_fragment
[
i_s2
,
i_v
]
=
-
V_new_fragment
[
i_s2
,
i_v
]
+
U_fragment
[
i_s2
,
i_v
]
...
@@ -179,11 +196,9 @@ def tilelang_chunk_gated_delta_rule_fwd_h(
...
@@ -179,11 +196,9 @@ def tilelang_chunk_gated_delta_rule_fwd_h(
# Save V_new
# Save V_new
if
save_new_value
:
if
save_new_value
:
T
.
copy
(
V_new_fragment
,
dst
=
V_new_shared
)
T
.
copy
(
V_new_fragment
,
dst
=
V_new_shared
)
T
.
copy
(
T
.
copy
(
V_new_shared
,
V_new
[
bb
,
i_s
*
block_S
:
(
i_s
+
1
)
*
block_S
,
bh
,
bv
*
block_DV
:
(
bv
+
1
)
*
block_DV
])
V_new_shared
,
V_new
[
bb
,
i_s
*
block_S
:(
i_s
+
1
)
*
block_S
,
bh
,
bv
*
block_DV
:(
bv
+
1
)
*
block_DV
])
T
.
copy
(
K
[
bb
,
i_s
*
block_S
:
(
i_s
+
1
)
*
block_S
,
bh
,
0
:
DK
],
K_shared
)
T
.
copy
(
K
[
bb
,
i_s
*
block_S
:
(
i_s
+
1
)
*
block_S
,
bh
,
0
:
DK
],
K_shared
)
# use_g
# use_g
if
use_g
:
if
use_g
:
G_last_local
[
0
]
=
G
[
bb
,
(
i_s
+
1
)
*
block_S
-
1
,
bh
]
G_last_local
[
0
]
=
G
[
bb
,
(
i_s
+
1
)
*
block_S
-
1
,
bh
]
...
@@ -193,11 +208,12 @@ def tilelang_chunk_gated_delta_rule_fwd_h(
...
@@ -193,11 +208,12 @@ def tilelang_chunk_gated_delta_rule_fwd_h(
for
i_s2
,
i_v
in
T
.
Parallel
(
block_S
,
block_DV
):
for
i_s2
,
i_v
in
T
.
Parallel
(
block_S
,
block_DV
):
with
T
.
If
(
G_last_local
[
0
]
-
G_fragment
[
i_s2
,
i_v
]
<=
0
):
with
T
.
If
(
G_last_local
[
0
]
-
G_fragment
[
i_s2
,
i_v
]
<=
0
):
with
T
.
Then
():
with
T
.
Then
():
V_new_fragment
[
i_s2
,
i_v
]
=
V_new_fragment
[
i_s2
,
i_v
]
*
T
.
exp
(
V_new_fragment
[
i_s2
,
i_v
]
=
V_new_fragment
[
i_s2
,
i_v
]
*
T
.
exp2
(
G_last_local
[
0
]
-
G_fragment
[
i_s2
,
i_v
])
(
G_last_local
[
0
]
-
G_fragment
[
i_s2
,
i_v
])
*
1.442695
)
with
T
.
Else
():
with
T
.
Else
():
V_new_fragment
[
i_s2
,
i_v
]
=
0
V_new_fragment
[
i_s2
,
i_v
]
=
0
G_last_local
[
0
]
=
T
.
exp
(
G_last_local
[
0
])
G_last_local
[
0
]
=
T
.
exp
2
(
G_last_local
[
0
]
*
1.442695
)
for
i_k
,
i_v
in
T
.
Parallel
(
DK
,
block_DV
):
for
i_k
,
i_v
in
T
.
Parallel
(
DK
,
block_DV
):
b_h_fragment
[
i_k
,
i_v
]
*=
G_last_local
[
0
]
b_h_fragment
[
i_k
,
i_v
]
*=
G_last_local
[
0
]
...
@@ -209,7 +225,7 @@ def tilelang_chunk_gated_delta_rule_fwd_h(
...
@@ -209,7 +225,7 @@ def tilelang_chunk_gated_delta_rule_fwd_h(
# Save final state
# Save final state
if
store_final_state
:
if
store_final_state
:
T
.
copy
(
b_h_fragment
,
final_state
[
bb
,
bh
,
0
:
DK
,
bv
*
block_DV
:
(
bv
+
1
)
*
block_DV
])
T
.
copy
(
b_h_fragment
,
final_state
[
bb
,
bh
,
0
:
DK
,
bv
*
block_DV
:
(
bv
+
1
)
*
block_DV
])
return
kernel
return
kernel
...
@@ -260,47 +276,77 @@ def run_test(
...
@@ -260,47 +276,77 @@ def run_test(
threads
=
128
,
threads
=
128
,
num_stages
=
0
,
num_stages
=
0
,
):
):
K
,
W
,
U
,
G
,
initial_state
=
prepare_input
(
B
,
S
,
H
,
DK
,
DV
,
chunk_size
,
K
,
W
,
U
,
G
,
initial_state
=
prepare_input
(
getattr
(
torch
,
input_dtype
),
B
,
getattr
(
torch
,
output_dtype
),
S
,
getattr
(
torch
,
accum_dtype
),
H
,
getattr
(
torch
,
gate_dtype
))
DK
,
h_ref
,
final_state_ref
,
V_new_ref
=
prepare_output
(
B
,
S
,
H
,
DK
,
DV
,
chunk_size
,
DV
,
getattr
(
torch
,
output_dtype
),
chunk_size
,
getattr
(
torch
,
state_dtype
))
getattr
(
torch
,
input_dtype
),
h_tilelang
,
final_state_tilelang
,
V_new_tilelang
=
prepare_output
(
B
,
S
,
H
,
DK
,
DV
,
chunk_size
,
getattr
(
torch
,
output_dtype
),
getattr
(
torch
,
output_dtype
),
getattr
(
torch
,
accum_dtype
),
getattr
(
torch
,
state_dtype
))
getattr
(
torch
,
gate_dtype
),
)
h_ref
,
final_state_ref
,
V_new_ref
=
prepare_output
(
B
,
S
,
H
,
DK
,
DV
,
chunk_size
,
getattr
(
torch
,
output_dtype
),
getattr
(
torch
,
state_dtype
)
)
h_tilelang
,
final_state_tilelang
,
V_new_tilelang
=
prepare_output
(
B
,
S
,
H
,
DK
,
DV
,
chunk_size
,
getattr
(
torch
,
output_dtype
),
getattr
(
torch
,
state_dtype
)
)
# fla ref
# fla ref
h_ref
,
V_new_ref
,
final_state_ref
=
chunk_gated_delta_rule_fwd_h
(
K
,
W
,
U
,
G
,
initial_state
,
h_ref
,
V_new_ref
,
final_state_ref
=
chunk_gated_delta_rule_fwd_h
(
store_final_state
,
chunk_size
,
k
=
K
,
save_new_value
)
w
=
W
,
u
=
U
,
g
=
G
,
initial_state
=
initial_state
,
output_final_state
=
store_final_state
,
chunk_size
=
chunk_size
,
save_new_value
=
save_new_value
,
)
# tilelang
# tilelang
kernel
=
tilelang_chunk_gated_delta_rule_fwd_h
(
B
,
S
,
H
,
DK
,
DV
,
input_dtype
,
output_dtype
,
kernel
=
tilelang_chunk_gated_delta_rule_fwd_h
(
accum_dtype
,
gate_dtype
,
state_dtype
,
chunk_size
,
B
,
use_g
,
use_initial_state
,
store_final_state
,
S
,
save_new_value
,
block_DK
,
block_DV
,
threads
,
H
,
num_stages
)
DK
,
DV
,
input_dtype
,
output_dtype
,
accum_dtype
,
gate_dtype
,
state_dtype
,
chunk_size
,
use_g
,
use_initial_state
,
store_final_state
,
save_new_value
,
)
h_tilelang
,
final_state_tilelang
,
V_new_tilelang
=
kernel
(
K
,
W
,
U
,
G
,
initial_state
)
h_tilelang
,
final_state_tilelang
,
V_new_tilelang
=
kernel
(
K
,
W
,
U
,
G
,
initial_state
)
# (zhengju) If you want to print the generated cuda code, you can uncomment the following line
# (zhengju) If you want to print the generated cuda code, you can uncomment the following line
# print("CUDA Code:\n", kernel.get_kernel_source())
# print("CUDA Code:\n", kernel.get_kernel_source())
fla_time
=
do_bench
(
chunk_gated_delta_rule_fwd_h
,
K
,
W
,
U
,
G
,
initial_state
,
store_final_state
,
fla_time
=
do_bench
(
chunk_size
,
save_new_value
)
chunk_gated_delta_rule_fwd_h
,
k
=
K
,
w
=
W
,
u
=
U
,
g
=
G
,
initial_state
=
initial_state
,
output_final_state
=
store_final_state
,
chunk_size
=
chunk_size
,
save_new_value
=
save_new_value
,
)
tilelang_time
=
do_bench
(
kernel
,
K
,
W
,
U
,
G
,
initial_state
)
tilelang_time
=
do_bench
(
kernel
,
K
,
W
,
U
,
G
,
initial_state
)
# check correctness
# check correctness
try
:
try
:
h_ref_fp32
=
h_ref
.
to
(
torch
.
float32
)
h_ref_fp32
=
h_ref
.
to
(
torch
.
float32
)
h_tilelang_fp32
=
h_tilelang
.
to
(
torch
.
float32
)
h_tilelang_fp32
=
h_tilelang
.
to
(
torch
.
float32
)
assert_similar
(
assert_similar
(
h_ref_fp32
,
h_tilelang_fp32
,
eps
=
1e-5
,
name
=
"tilelang chunk gated delta rule fwd h"
,
raise_assert
=
False
)
h_ref_fp32
,
h_tilelang_fp32
,
eps
=
1e-5
,
name
=
"tilelang chunk gated delta rule fwd h"
,
raise_assert
=
False
)
print
(
"tilelang chunk gated delta rule fwd h passed √"
)
print
(
"tilelang chunk gated delta rule fwd h passed √"
)
except
Exception
as
e
:
except
Exception
as
e
:
print
(
"tilelang chunk gated delta rule fwd h failed ✗"
)
print
(
"tilelang chunk gated delta rule fwd h failed ✗"
)
...
@@ -314,7 +360,8 @@ def run_test(
...
@@ -314,7 +360,8 @@ def run_test(
final_state_tilelang_fp32
,
final_state_tilelang_fp32
,
eps
=
1e-5
,
eps
=
1e-5
,
name
=
"tilelang chunk gated delta rule fwd final_state"
,
name
=
"tilelang chunk gated delta rule fwd final_state"
,
raise_assert
=
False
)
raise_assert
=
False
,
)
print
(
"tilelang chunk gated delta rule fwd final_state passed √"
)
print
(
"tilelang chunk gated delta rule fwd final_state passed √"
)
except
Exception
as
e
:
except
Exception
as
e
:
print
(
"tilelang chunk gated delta rule fwd final_state failed ✗"
)
print
(
"tilelang chunk gated delta rule fwd final_state failed ✗"
)
...
@@ -323,12 +370,7 @@ def run_test(
...
@@ -323,12 +370,7 @@ def run_test(
try
:
try
:
V_new_ref_fp32
=
V_new_ref
.
to
(
torch
.
float32
)
V_new_ref_fp32
=
V_new_ref
.
to
(
torch
.
float32
)
V_new_tilelang_fp32
=
V_new_tilelang
.
to
(
torch
.
float32
)
V_new_tilelang_fp32
=
V_new_tilelang
.
to
(
torch
.
float32
)
assert_similar
(
assert_similar
(
V_new_ref_fp32
,
V_new_tilelang_fp32
,
eps
=
1e-5
,
name
=
"tilelang chunk gated delta rule fwd V_new"
,
raise_assert
=
False
)
V_new_ref_fp32
,
V_new_tilelang_fp32
,
eps
=
1e-5
,
name
=
"tilelang chunk gated delta rule fwd V_new"
,
raise_assert
=
False
)
print
(
"tilelang chunk gated delta rule fwd V_new passed √"
)
print
(
"tilelang chunk gated delta rule fwd V_new passed √"
)
except
Exception
as
e
:
except
Exception
as
e
:
print
(
"tilelang chunk gated delta rule fwd V_new failed ✗"
)
print
(
"tilelang chunk gated delta rule fwd V_new failed ✗"
)
...
@@ -345,20 +387,20 @@ def main():
...
@@ -345,20 +387,20 @@ def main():
H
=
32
,
H
=
32
,
DK
=
128
,
DK
=
128
,
DV
=
128
,
DV
=
128
,
input_dtype
=
"
bfloat16
"
,
input_dtype
=
T
.
bfloat16
,
output_dtype
=
"
bfloat16
"
,
output_dtype
=
T
.
bfloat16
,
accum_dtype
=
"
float32
"
,
accum_dtype
=
T
.
float32
,
gate_dtype
=
"
float32
"
,
gate_dtype
=
T
.
float32
,
state_dtype
=
"
float32
"
,
state_dtype
=
T
.
float32
,
chunk_size
=
64
,
chunk_size
=
64
,
use_g
=
True
,
use_g
=
True
,
use_initial_state
=
Tru
e
,
use_initial_state
=
Fals
e
,
store_final_state
=
True
,
store_final_state
=
True
,
save_new_value
=
True
,
save_new_value
=
True
,
block_DK
=
64
,
block_DK
=
32
,
block_DV
=
32
,
block_DV
=
32
,
threads
=
128
,
threads
=
128
,
num_stages
=
1
,
num_stages
=
2
,
)
)
...
...
examples/gdn/example_chunk_o.py
View file @
667632cc
...
@@ -9,6 +9,7 @@ import sys # noqa: F401
...
@@ -9,6 +9,7 @@ import sys # noqa: F401
# sys.path.insert(0, "/home/tzj/flash-linear-attention")
# sys.path.insert(0, "/home/tzj/flash-linear-attention")
try
:
try
:
import
fla
import
fla
print
(
fla
.
__file__
)
print
(
fla
.
__file__
)
from
fla.ops.common.chunk_o
import
chunk_fwd_o
from
fla.ops.common.chunk_o
import
chunk_fwd_o
except
ImportError
:
except
ImportError
:
...
@@ -87,16 +88,14 @@ def tilelang_chunk_fwd_o(
...
@@ -87,16 +88,14 @@ def tilelang_chunk_fwd_o(
@
T
.
prim_func
@
T
.
prim_func
def
kernel
(
def
kernel
(
Q
:
T
.
Tensor
(
Q_shape
,
dtype
=
input_dtype
),
Q
:
T
.
Tensor
(
Q_shape
,
dtype
=
input_dtype
),
K
:
T
.
Tensor
(
K_shape
,
dtype
=
input_dtype
),
K
:
T
.
Tensor
(
K_shape
,
dtype
=
input_dtype
),
V
:
T
.
Tensor
(
V_shape
,
dtype
=
input_dtype
),
V
:
T
.
Tensor
(
V_shape
,
dtype
=
input_dtype
),
HIDDEN
:
T
.
Tensor
(
H_shape
,
dtype
=
input_dtype
),
HIDDEN
:
T
.
Tensor
(
H_shape
,
dtype
=
input_dtype
),
G
:
T
.
Tensor
(
G_shape
,
dtype
=
gate_dtype
),
G
:
T
.
Tensor
(
G_shape
,
dtype
=
gate_dtype
),
O
:
T
.
Tensor
(
O_shape
,
dtype
=
output_dtype
),
O
:
T
.
Tensor
(
O_shape
,
dtype
=
output_dtype
),
):
):
with
T
.
Kernel
(
with
T
.
Kernel
(
T
.
ceildiv
(
DV
,
block_DV
),
T
.
ceildiv
(
S
,
block_S
),
B
*
H
,
threads
=
threads
)
as
(
bv
,
bs
,
bbh
):
T
.
ceildiv
(
DV
,
block_DV
),
T
.
ceildiv
(
S
,
block_S
),
B
*
H
,
threads
=
threads
)
as
(
bv
,
bs
,
bbh
):
bb
,
bh
=
bbh
//
H
,
bbh
%
H
bb
,
bh
=
bbh
//
H
,
bbh
%
H
Q_shared
=
T
.
alloc_shared
((
block_S
,
block_DK
),
dtype
=
input_dtype
)
Q_shared
=
T
.
alloc_shared
((
block_S
,
block_DK
),
dtype
=
input_dtype
)
K_shared
=
T
.
alloc_shared
((
block_S
,
block_DK
),
dtype
=
input_dtype
)
K_shared
=
T
.
alloc_shared
((
block_S
,
block_DK
),
dtype
=
input_dtype
)
...
@@ -109,28 +108,24 @@ def tilelang_chunk_fwd_o(
...
@@ -109,28 +108,24 @@ def tilelang_chunk_fwd_o(
G_shared
=
T
.
alloc_shared
((
block_S
,),
dtype
=
gate_dtype
,
scope
=
"shared"
)
G_shared
=
T
.
alloc_shared
((
block_S
,),
dtype
=
gate_dtype
,
scope
=
"shared"
)
G_diff_local
=
T
.
alloc_fragment
((
block_S
,
block_S
),
dtype
=
gate_dtype
)
G_diff_local
=
T
.
alloc_fragment
((
block_S
,
block_S
),
dtype
=
gate_dtype
)
T
.
annotate_layout
({
T
.
annotate_layout
(
Q_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
Q_shared
),
{
K_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
K_shared
),
Q_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
Q_shared
),
V_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
V_shared
),
K_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
K_shared
),
H_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
H_shared
),
V_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
V_shared
),
A_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
A_shared
),
H_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
H_shared
),
O_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
O_shared
),
A_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
A_shared
),
})
O_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
O_shared
),
}
)
T
.
clear
(
A_fragment
)
T
.
clear
(
A_fragment
)
T
.
clear
(
O_fragment
)
T
.
clear
(
O_fragment
)
T
.
disable_warp_group_reg_alloc
()
T
.
disable_warp_group_reg_alloc
()
for
i_k
in
T
.
Pipelined
(
T
.
ceildiv
(
DK
,
block_DK
),
num_stages
=
num_stages
):
for
i_k
in
T
.
Pipelined
(
T
.
ceildiv
(
DK
,
block_DK
),
num_stages
=
num_stages
):
T
.
copy
(
T
.
copy
(
Q
[
bb
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
,
bh
,
i_k
*
block_DK
:
(
i_k
+
1
)
*
block_DK
],
Q_shared
)
Q
[
bb
,
bs
*
block_S
:(
bs
+
1
)
*
block_S
,
bh
,
i_k
*
block_DK
:(
i_k
+
1
)
*
block_DK
],
T
.
copy
(
K
[
bb
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
,
bh
,
i_k
*
block_DK
:
(
i_k
+
1
)
*
block_DK
],
K_shared
)
Q_shared
)
T
.
copy
(
HIDDEN
[
bb
,
bs
,
bh
,
i_k
*
block_DK
:
(
i_k
+
1
)
*
block_DK
,
bv
*
block_DV
:
(
bv
+
1
)
*
block_DV
],
H_shared
)
T
.
copy
(
K
[
bb
,
bs
*
block_S
:(
bs
+
1
)
*
block_S
,
bh
,
i_k
*
block_DK
:(
i_k
+
1
)
*
block_DK
],
K_shared
)
T
.
copy
(
HIDDEN
[
bb
,
bs
,
bh
,
i_k
*
block_DK
:(
i_k
+
1
)
*
block_DK
,
bv
*
block_DV
:(
bv
+
1
)
*
block_DV
],
H_shared
)
T
.
gemm
(
Q_shared
,
H_shared
,
O_fragment
)
T
.
gemm
(
Q_shared
,
H_shared
,
O_fragment
)
T
.
gemm
(
Q_shared
,
K_shared
,
A_fragment
,
transpose_B
=
True
)
T
.
gemm
(
Q_shared
,
K_shared
,
A_fragment
,
transpose_B
=
True
)
...
@@ -145,8 +140,7 @@ def tilelang_chunk_fwd_o(
...
@@ -145,8 +140,7 @@ def tilelang_chunk_fwd_o(
for
i_s1
,
i_s2
in
T
.
Parallel
(
block_S
,
block_S
):
for
i_s1
,
i_s2
in
T
.
Parallel
(
block_S
,
block_S
):
with
T
.
If
(
G_diff_local
[
i_s1
,
i_s2
]
<=
0
):
with
T
.
If
(
G_diff_local
[
i_s1
,
i_s2
]
<=
0
):
with
T
.
Then
():
with
T
.
Then
():
A_fragment
[
i_s1
,
i_s2
]
=
A_fragment
[
i_s1
,
i_s2
]
*
T
.
exp
(
A_fragment
[
i_s1
,
i_s2
]
=
A_fragment
[
i_s1
,
i_s2
]
*
T
.
exp
(
G_diff_local
[
i_s1
,
i_s2
])
G_diff_local
[
i_s1
,
i_s2
])
with
T
.
Else
():
with
T
.
Else
():
A_fragment
[
i_s1
,
i_s2
]
=
0
A_fragment
[
i_s1
,
i_s2
]
=
0
...
@@ -155,8 +149,7 @@ def tilelang_chunk_fwd_o(
...
@@ -155,8 +149,7 @@ def tilelang_chunk_fwd_o(
with
T
.
Then
():
with
T
.
Then
():
A_fragment
[
i_s1
,
i_s2
]
=
0
A_fragment
[
i_s1
,
i_s2
]
=
0
T
.
copy
(
V
[
bb
,
bs
*
block_S
:(
bs
+
1
)
*
block_S
,
bh
,
bv
*
block_DV
:(
bv
+
1
)
*
block_DV
],
T
.
copy
(
V
[
bb
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
,
bh
,
bv
*
block_DV
:
(
bv
+
1
)
*
block_DV
],
V_shared
)
V_shared
)
T
.
copy
(
A_fragment
,
A_shared
)
T
.
copy
(
A_fragment
,
A_shared
)
T
.
gemm
(
A_shared
,
V_shared
,
O_fragment
)
T
.
gemm
(
A_shared
,
V_shared
,
O_fragment
)
...
@@ -164,8 +157,7 @@ def tilelang_chunk_fwd_o(
...
@@ -164,8 +157,7 @@ def tilelang_chunk_fwd_o(
O_fragment
[
i_s
,
i_v
]
=
O_fragment
[
i_s
,
i_v
]
*
scale
O_fragment
[
i_s
,
i_v
]
=
O_fragment
[
i_s
,
i_v
]
*
scale
T
.
copy
(
O_fragment
,
O_shared
)
T
.
copy
(
O_fragment
,
O_shared
)
T
.
copy
(
O_shared
,
O
[
bb
,
bs
*
block_S
:(
bs
+
1
)
*
block_S
,
bh
,
T
.
copy
(
O_shared
,
O
[
bb
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
,
bh
,
bv
*
block_DV
:
(
bv
+
1
)
*
block_DV
])
bv
*
block_DV
:(
bv
+
1
)
*
block_DV
])
return
kernel
return
kernel
...
@@ -191,8 +183,9 @@ def run_test(
...
@@ -191,8 +183,9 @@ def run_test(
output_dtype_torch
=
getattr
(
torch
,
output_dtype
)
output_dtype_torch
=
getattr
(
torch
,
output_dtype
)
accum_dtype_torch
=
getattr
(
torch
,
accum_dtype
)
accum_dtype_torch
=
getattr
(
torch
,
accum_dtype
)
gate_dtype_torch
=
getattr
(
torch
,
gate_dtype
)
gate_dtype_torch
=
getattr
(
torch
,
gate_dtype
)
Q
,
K
,
V
,
HIDDEN
,
G
=
prepare_input
(
B
,
S
,
H
,
DK
,
DV
,
chunk_size
,
input_dtype_torch
,
Q
,
K
,
V
,
HIDDEN
,
G
=
prepare_input
(
output_dtype_torch
,
accum_dtype_torch
,
gate_dtype_torch
)
B
,
S
,
H
,
DK
,
DV
,
chunk_size
,
input_dtype_torch
,
output_dtype_torch
,
accum_dtype_torch
,
gate_dtype_torch
)
scale
=
1.0
/
DK
**
0.5
scale
=
1.0
/
DK
**
0.5
O_ref
=
prepare_output
(
B
,
S
,
H
,
DK
,
DV
,
chunk_size
,
output_dtype_torch
)
O_ref
=
prepare_output
(
B
,
S
,
H
,
DK
,
DV
,
chunk_size
,
output_dtype_torch
)
...
@@ -200,9 +193,25 @@ def run_test(
...
@@ -200,9 +193,25 @@ def run_test(
block_S
=
chunk_size
block_S
=
chunk_size
O_tilelang
=
prepare_output
(
B
,
S
,
H
,
DK
,
DV
,
chunk_size
,
output_dtype_torch
)
O_tilelang
=
prepare_output
(
B
,
S
,
H
,
DK
,
DV
,
chunk_size
,
output_dtype_torch
)
kernel
=
tilelang_chunk_fwd_o
(
B
,
S
,
H
,
DK
,
DV
,
input_dtype
,
output_dtype
,
accum_dtype
,
kernel
=
tilelang_chunk_fwd_o
(
gate_dtype
,
chunk_size
,
scale
,
use_g
,
block_S
,
block_DK
,
block_DV
,
B
,
threads
,
num_stages
)
S
,
H
,
DK
,
DV
,
input_dtype
,
output_dtype
,
accum_dtype
,
gate_dtype
,
chunk_size
,
scale
,
use_g
,
block_S
,
block_DK
,
block_DV
,
threads
,
num_stages
,
)
O_tilelang
=
kernel
(
Q
,
K
,
V
,
HIDDEN
,
G
)
O_tilelang
=
kernel
(
Q
,
K
,
V
,
HIDDEN
,
G
)
try
:
try
:
...
@@ -221,10 +230,10 @@ def main():
...
@@ -221,10 +230,10 @@ def main():
DK
=
128
,
DK
=
128
,
DV
=
128
,
DV
=
128
,
chunk_size
=
64
,
chunk_size
=
64
,
input_dtype
=
"
bfloat16
"
,
input_dtype
=
T
.
bfloat16
,
output_dtype
=
"
bfloat16
"
,
output_dtype
=
T
.
bfloat16
,
accum_dtype
=
"
float32
"
,
accum_dtype
=
T
.
float32
,
gate_dtype
=
"
float32
"
,
gate_dtype
=
T
.
float32
,
use_g
=
True
,
use_g
=
True
,
block_DK
=
128
,
block_DK
=
128
,
block_DV
=
128
,
block_DV
=
128
,
...
...
examples/gdn/example_chunk_o_bwd.py
View file @
667632cc
...
@@ -12,6 +12,7 @@ from tilelang.engine.callback import register_cuda_postproc_callback # noqa: F4
...
@@ -12,6 +12,7 @@ from tilelang.engine.callback import register_cuda_postproc_callback # noqa: F4
# sys.path.insert(0, "/home/tzj/flash-linear-attention")
# sys.path.insert(0, "/home/tzj/flash-linear-attention")
try
:
try
:
import
fla
import
fla
print
(
fla
.
__file__
)
print
(
fla
.
__file__
)
from
fla.ops.common.chunk_o
import
chunk_bwd_dqkwg
from
fla.ops.common.chunk_o
import
chunk_bwd_dqkwg
except
ImportError
:
except
ImportError
:
...
@@ -19,7 +20,7 @@ except ImportError:
...
@@ -19,7 +20,7 @@ except ImportError:
fla
=
None
fla
=
None
import
torch
import
torch
from
utils
import
*
from
test_
utils
import
assert_similar
torch
.
random
.
manual_seed
(
0
)
torch
.
random
.
manual_seed
(
0
)
# torch.set_printoptions(profile="full")
# torch.set_printoptions(profile="full")
...
@@ -108,10 +109,8 @@ def prepare_output(
...
@@ -108,10 +109,8 @@ def prepare_output(
@
tilelang
.
jit
(
@
tilelang
.
jit
(
out_idx
=
[
-
4
,
-
3
,
-
2
,
-
1
],
out_idx
=
[
-
4
,
-
3
,
-
2
,
-
1
],
pass_configs
=
{
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
},
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
)
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
})
def
tilelang_chunk_o_bwd_dqkwg
(
def
tilelang_chunk_o_bwd_dqkwg
(
# task config
# task config
B
,
B
,
...
@@ -155,25 +154,23 @@ def tilelang_chunk_o_bwd_dqkwg(
...
@@ -155,25 +154,23 @@ def tilelang_chunk_o_bwd_dqkwg(
@
T
.
prim_func
@
T
.
prim_func
def
kernel
(
def
kernel
(
# input
# input
Q
:
T
.
Tensor
(
Q_shape
,
dtype
=
input_dtype
),
Q
:
T
.
Tensor
(
Q_shape
,
dtype
=
input_dtype
),
K
:
T
.
Tensor
(
K_shape
,
dtype
=
input_dtype
),
K
:
T
.
Tensor
(
K_shape
,
dtype
=
input_dtype
),
V
:
T
.
Tensor
(
V_shape
,
dtype
=
input_dtype
),
V
:
T
.
Tensor
(
V_shape
,
dtype
=
input_dtype
),
h
:
T
.
Tensor
(
h_shape
,
dtype
=
input_dtype
),
h
:
T
.
Tensor
(
h_shape
,
dtype
=
input_dtype
),
G
:
T
.
Tensor
(
G_shape
,
dtype
=
gate_dtype
),
G
:
T
.
Tensor
(
G_shape
,
dtype
=
gate_dtype
),
dO
:
T
.
Tensor
(
dO_shape
,
dtype
=
input_dtype
),
dO
:
T
.
Tensor
(
dO_shape
,
dtype
=
input_dtype
),
dh
:
T
.
Tensor
(
dh_shape
,
dtype
=
input_dtype
),
dh
:
T
.
Tensor
(
dh_shape
,
dtype
=
input_dtype
),
dv
:
T
.
Tensor
(
dv_shape
,
dtype
=
input_dtype
),
dv
:
T
.
Tensor
(
dv_shape
,
dtype
=
input_dtype
),
W
:
T
.
Tensor
(
W_shape
,
dtype
=
input_dtype
),
W
:
T
.
Tensor
(
W_shape
,
dtype
=
input_dtype
),
# output
# output
dq
:
T
.
Tensor
(
dq_shape
,
dtype
=
output_dtype
),
dq
:
T
.
Tensor
(
dq_shape
,
dtype
=
output_dtype
),
dk
:
T
.
Tensor
(
dk_shape
,
dtype
=
output_dtype
),
dk
:
T
.
Tensor
(
dk_shape
,
dtype
=
output_dtype
),
dw
:
T
.
Tensor
(
dw_shape
,
dtype
=
output_dtype
),
dw
:
T
.
Tensor
(
dw_shape
,
dtype
=
output_dtype
),
dg
:
T
.
Tensor
(
dg_shape
,
dtype
=
gate_dtype
),
dg
:
T
.
Tensor
(
dg_shape
,
dtype
=
gate_dtype
),
):
):
with
T
.
Kernel
(
with
T
.
Kernel
(
T
.
ceildiv
(
DK
,
block_DK
),
T
.
ceildiv
(
S
,
block_S
),
B
*
H
,
threads
=
threads
)
as
(
bk
,
bs
,
bbh
):
T
.
ceildiv
(
DK
,
block_DK
),
T
.
ceildiv
(
S
,
block_S
),
B
*
H
,
threads
=
threads
)
as
(
bk
,
bs
,
bbh
):
bb
,
bh
=
bbh
//
H
,
bbh
%
H
bb
,
bh
=
bbh
//
H
,
bbh
%
H
V_shared
=
T
.
alloc_shared
((
block_S
,
block_DV
),
dtype
=
input_dtype
)
V_shared
=
T
.
alloc_shared
((
block_S
,
block_DV
),
dtype
=
input_dtype
)
...
@@ -212,15 +209,17 @@ def tilelang_chunk_o_bwd_dqkwg(
...
@@ -212,15 +209,17 @@ def tilelang_chunk_o_bwd_dqkwg(
T
.
use_swizzle
(
10
)
T
.
use_swizzle
(
10
)
T
.
annotate_layout
({
T
.
annotate_layout
(
V_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
V_shared
),
{
dO_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
dO_shared
),
V_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
V_shared
),
h_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
h_shared
),
dO_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
dO_shared
),
dh_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
dh_shared
),
h_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
h_shared
),
dv_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
dv_shared
),
dh_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
dh_shared
),
q_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
q_shared
),
dv_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
dv_shared
),
k_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
k_shared
),
q_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
q_shared
),
})
k_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
k_shared
),
}
)
T
.
clear
(
dg_last_local
)
T
.
clear
(
dg_last_local
)
T
.
clear
(
G_last_local
)
T
.
clear
(
G_last_local
)
...
@@ -235,18 +234,10 @@ def tilelang_chunk_o_bwd_dqkwg(
...
@@ -235,18 +234,10 @@ def tilelang_chunk_o_bwd_dqkwg(
T
.
clear
(
dw_fragment
)
T
.
clear
(
dw_fragment
)
for
i_v
in
T
.
Pipelined
(
T
.
ceildiv
(
DV
,
block_DV
),
num_stages
=
num_stages
):
for
i_v
in
T
.
Pipelined
(
T
.
ceildiv
(
DV
,
block_DV
),
num_stages
=
num_stages
):
T
.
copy
(
T
.
copy
(
V
[
bb
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
,
bh
,
i_v
*
block_DV
:
(
i_v
+
1
)
*
block_DV
],
V_shared
)
V
[
bb
,
bs
*
block_S
:(
bs
+
1
)
*
block_S
,
bh
,
i_v
*
block_DV
:(
i_v
+
1
)
*
block_DV
],
T
.
copy
(
dO
[
bb
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
,
bh
,
i_v
*
block_DV
:
(
i_v
+
1
)
*
block_DV
],
dO_shared
)
V_shared
)
T
.
copy
(
h
[
bb
,
bs
,
bh
,
bk
*
block_DK
:
(
bk
+
1
)
*
block_DK
,
i_v
*
block_DV
:
(
i_v
+
1
)
*
block_DV
],
h_shared
)
T
.
copy
(
T
.
copy
(
dh
[
bb
,
bs
,
bh
,
bk
*
block_DK
:
(
bk
+
1
)
*
block_DK
,
i_v
*
block_DV
:
(
i_v
+
1
)
*
block_DV
],
dh_shared
)
dO
[
bb
,
bs
*
block_S
:(
bs
+
1
)
*
block_S
,
bh
,
i_v
*
block_DV
:(
i_v
+
1
)
*
block_DV
],
dO_shared
)
T
.
copy
(
h
[
bb
,
bs
,
bh
,
bk
*
block_DK
:(
bk
+
1
)
*
block_DK
,
i_v
*
block_DV
:(
i_v
+
1
)
*
block_DV
],
h_shared
)
T
.
copy
(
dh
[
bb
,
bs
,
bh
,
bk
*
block_DK
:(
bk
+
1
)
*
block_DK
,
i_v
*
block_DV
:(
i_v
+
1
)
*
block_DV
],
dh_shared
)
if
use_g
:
if
use_g
:
T
.
clear
(
dg_last_fragment_scalar
)
T
.
clear
(
dg_last_fragment_scalar
)
...
@@ -254,9 +245,7 @@ def tilelang_chunk_o_bwd_dqkwg(
...
@@ -254,9 +245,7 @@ def tilelang_chunk_o_bwd_dqkwg(
# for i_kv in T.Parallel(block_DK * block_DV):
# for i_kv in T.Parallel(block_DK * block_DV):
# dg_last_fragment[i_kv] = h_shared[i_kv // block_DV, i_kv % block_DV] * dh_shared[i_kv // block_DV, i_kv % block_DV]
# dg_last_fragment[i_kv] = h_shared[i_kv // block_DV, i_kv % block_DV] * dh_shared[i_kv // block_DV, i_kv % block_DV]
for
i_kv
in
T
.
Parallel
(
block_DK
*
block_DV
):
for
i_kv
in
T
.
Parallel
(
block_DK
*
block_DV
):
dg_last_fragment
[
i_kv
]
=
h_shared
[
i_kv
//
block_DV
,
i_kv
%
dg_last_fragment
[
i_kv
]
=
h_shared
[
i_kv
//
block_DV
,
i_kv
%
block_DV
]
*
dh_shared
[
i_kv
//
block_DV
,
i_kv
%
block_DV
]
block_DV
]
*
dh_shared
[
i_kv
//
block_DV
,
i_kv
%
block_DV
]
T
.
reduce_sum
(
dg_last_fragment
,
dg_last_fragment_scalar
,
dim
=-
1
,
clear
=
False
)
T
.
reduce_sum
(
dg_last_fragment
,
dg_last_fragment_scalar
,
dim
=-
1
,
clear
=
False
)
dg_last_local
[
0
]
+=
dg_last_fragment_scalar
[
0
]
dg_last_local
[
0
]
+=
dg_last_fragment_scalar
[
0
]
...
@@ -265,22 +254,16 @@ def tilelang_chunk_o_bwd_dqkwg(
...
@@ -265,22 +254,16 @@ def tilelang_chunk_o_bwd_dqkwg(
T
.
gemm
(
V_shared
,
dh_shared
,
dk_fragment
,
transpose_B
=
True
)
T
.
gemm
(
V_shared
,
dh_shared
,
dk_fragment
,
transpose_B
=
True
)
if
use_dw
:
if
use_dw
:
T
.
copy
(
T
.
copy
(
dv
[
bb
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
,
bh
,
i_v
*
block_DV
:
(
i_v
+
1
)
*
block_DV
],
dv_shared
)
dv
[
bb
,
bs
*
block_S
:(
bs
+
1
)
*
block_S
,
bh
,
i_v
*
block_DV
:(
i_v
+
1
)
*
block_DV
],
dv_shared
)
T
.
gemm
(
dv_shared
,
h_shared
,
dw_fragment
,
transpose_B
=
True
)
T
.
gemm
(
dv_shared
,
h_shared
,
dw_fragment
,
transpose_B
=
True
)
if
use_dw
:
if
use_dw
:
for
i_s
,
i_k
in
T
.
Parallel
(
block_S
,
block_DK
):
for
i_s
,
i_k
in
T
.
Parallel
(
block_S
,
block_DK
):
dw_fragment
[
i_s
,
i_k
]
=
-
dw_fragment
[
i_s
,
i_k
]
dw_fragment
[
i_s
,
i_k
]
=
-
dw_fragment
[
i_s
,
i_k
]
T
.
copy
(
T
.
copy
(
dw_fragment
,
dw
[
bb
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
,
bh
,
bk
*
block_DK
:
(
bk
+
1
)
*
block_DK
])
dw_fragment
,
dw
[
bb
,
bs
*
block_S
:(
bs
+
1
)
*
block_S
,
bh
,
bk
*
block_DK
:(
bk
+
1
)
*
block_DK
])
T
.
copy
(
Q
[
bb
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
,
bh
,
bk
*
block_DK
:
(
bk
+
1
)
*
block_DK
],
q_shared
)
T
.
copy
(
K
[
bb
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
,
bh
,
bk
*
block_DK
:
(
bk
+
1
)
*
block_DK
],
k_shared
)
T
.
copy
(
Q
[
bb
,
bs
*
block_S
:(
bs
+
1
)
*
block_S
,
bh
,
bk
*
block_DK
:(
bk
+
1
)
*
block_DK
],
q_shared
)
T
.
copy
(
K
[
bb
,
bs
*
block_S
:(
bs
+
1
)
*
block_S
,
bh
,
bk
*
block_DK
:(
bk
+
1
)
*
block_DK
],
k_shared
)
T
.
copy
(
q_shared
,
q_fragment
)
T
.
copy
(
q_shared
,
q_fragment
)
T
.
copy
(
k_shared
,
k_fragment
)
T
.
copy
(
k_shared
,
k_fragment
)
...
@@ -294,8 +277,7 @@ def tilelang_chunk_o_bwd_dqkwg(
...
@@ -294,8 +277,7 @@ def tilelang_chunk_o_bwd_dqkwg(
dg_last_local
[
0
]
=
dg_last_local
[
0
]
*
T
.
exp
(
G
[
bb
,
bs
*
block_S
+
block_S
-
1
,
bh
])
dg_last_local
[
0
]
=
dg_last_local
[
0
]
*
T
.
exp
(
G
[
bb
,
bs
*
block_S
+
block_S
-
1
,
bh
])
for
i_s
,
i_k
in
T
.
Parallel
(
block_S
,
block_DK
):
for
i_s
,
i_k
in
T
.
Parallel
(
block_S
,
block_DK
):
dq_fragment
[
i_s
,
i_k
]
=
dq_fragment
[
i_s
,
i_k
]
*
T
.
exp
(
G
[
bb
,
bs
*
block_S
+
i_s
,
dq_fragment
[
i_s
,
i_k
]
=
dq_fragment
[
i_s
,
i_k
]
*
T
.
exp
(
G
[
bb
,
bs
*
block_S
+
i_s
,
bh
])
*
scale
bh
])
*
scale
T
.
clear
(
dg_fragment_reduce_tmp
)
T
.
clear
(
dg_fragment_reduce_tmp
)
for
i_s
,
i_k
in
T
.
Parallel
(
block_S
,
block_DK
):
for
i_s
,
i_k
in
T
.
Parallel
(
block_S
,
block_DK
):
dg_fragment_reduce_tmp
[
i_s
,
i_k
]
=
dq_fragment
[
i_s
,
i_k
]
*
q_shared
[
i_s
,
i_k
]
dg_fragment_reduce_tmp
[
i_s
,
i_k
]
=
dq_fragment
[
i_s
,
i_k
]
*
q_shared
[
i_s
,
i_k
]
...
@@ -305,8 +287,7 @@ def tilelang_chunk_o_bwd_dqkwg(
...
@@ -305,8 +287,7 @@ def tilelang_chunk_o_bwd_dqkwg(
for
i_s
,
i_k
in
T
.
Parallel
(
block_S
,
block_DK
):
for
i_s
,
i_k
in
T
.
Parallel
(
block_S
,
block_DK
):
with
T
.
If
(
G_last_local
[
0
]
-
G
[
bb
,
bs
*
block_S
+
i_s
,
bh
]
<=
0
):
with
T
.
If
(
G_last_local
[
0
]
-
G
[
bb
,
bs
*
block_S
+
i_s
,
bh
]
<=
0
):
with
T
.
Then
():
with
T
.
Then
():
dk_fragment
[
i_s
,
i_k
]
=
dk_fragment
[
i_s
,
i_k
]
*
T
.
exp
(
dk_fragment
[
i_s
,
i_k
]
=
dk_fragment
[
i_s
,
i_k
]
*
T
.
exp
(
G_last_local
[
0
]
-
G
[
bb
,
bs
*
block_S
+
i_s
,
bh
])
G_last_local
[
0
]
-
G
[
bb
,
bs
*
block_S
+
i_s
,
bh
])
with
T
.
Else
():
with
T
.
Else
():
dk_fragment
[
i_s
,
i_k
]
=
0
dk_fragment
[
i_s
,
i_k
]
=
0
T
.
clear
(
dg_fragment_reduce_tmp
)
T
.
clear
(
dg_fragment_reduce_tmp
)
...
@@ -325,12 +306,11 @@ def tilelang_chunk_o_bwd_dqkwg(
...
@@ -325,12 +306,11 @@ def tilelang_chunk_o_bwd_dqkwg(
dg_last_local
[
1
]
=
dg_last_fragment_scalar_2
[
0
]
dg_last_local
[
1
]
=
dg_last_fragment_scalar_2
[
0
]
for
i_s1
,
i_s2
in
T
.
Parallel
(
block_S
,
block_S
):
for
i_s1
,
i_s2
in
T
.
Parallel
(
block_S
,
block_S
):
with
T
.
If
(
i_s1
>=
i_s2
and
with
T
.
If
(
i_s1
>=
i_s2
and
G
[
bb
,
bs
*
block_S
+
i_s1
,
bh
]
-
G
[
bb
,
bs
*
block_S
+
i_s2
,
bh
]
<=
0
):
G
[
bb
,
bs
*
block_S
+
i_s1
,
bh
]
-
G
[
bb
,
bs
*
block_S
+
i_s2
,
bh
]
<=
0
):
with
T
.
Then
():
with
T
.
Then
():
ds_fragment
[
i_s1
,
i_s2
]
=
ds_fragment
[
ds_fragment
[
i_s1
,
i_s2
]
=
(
i_s1
,
i_s2
]
*
T
.
exp
(
G
[
bb
,
bs
*
block_S
+
i_s1
,
bh
]
-
ds_fragment
[
i_s1
,
i_s2
]
*
T
.
exp
(
G
[
bb
,
bs
*
block_S
+
i_s1
,
bh
]
-
G
[
bb
,
bs
*
block_S
+
i_s2
,
bh
])
*
scale
G
[
bb
,
bs
*
block_S
+
i_s2
,
bh
])
*
scale
)
with
T
.
Else
():
with
T
.
Else
():
ds_fragment
[
i_s1
,
i_s2
]
=
0
ds_fragment
[
i_s1
,
i_s2
]
=
0
...
@@ -338,8 +318,7 @@ def tilelang_chunk_o_bwd_dqkwg(
...
@@ -338,8 +318,7 @@ def tilelang_chunk_o_bwd_dqkwg(
T
.
clear
(
ds_fragment_positive_transpose
)
T
.
clear
(
ds_fragment_positive_transpose
)
T
.
gemm
(
q_shared
,
k_shared
,
ds_fragment_positive
,
transpose_B
=
True
)
T
.
gemm
(
q_shared
,
k_shared
,
ds_fragment_positive
,
transpose_B
=
True
)
for
i_s1
,
i_s2
in
T
.
Parallel
(
block_S
,
block_S
):
for
i_s1
,
i_s2
in
T
.
Parallel
(
block_S
,
block_S
):
ds_fragment_positive
[
ds_fragment_positive
[
i_s1
,
i_s2
]
=
ds_fragment
[
i_s1
,
i_s2
]
*
ds_fragment_positive
[
i_s1
,
i_s2
]
i_s1
,
i_s2
]
=
ds_fragment
[
i_s1
,
i_s2
]
*
ds_fragment_positive
[
i_s1
,
i_s2
]
# FIXME: The reduce_sum statement with clear=True will cause an error of warp specialized pass
# FIXME: The reduce_sum statement with clear=True will cause an error of warp specialized pass
T
.
reduce_sum
(
ds_fragment_positive
,
dg_fragment
,
dim
=
1
,
clear
=
False
)
T
.
reduce_sum
(
ds_fragment_positive
,
dg_fragment
,
dim
=
1
,
clear
=
False
)
...
@@ -363,15 +342,10 @@ def tilelang_chunk_o_bwd_dqkwg(
...
@@ -363,15 +342,10 @@ def tilelang_chunk_o_bwd_dqkwg(
for
i_s
in
T
.
Parallel
(
block_S
):
for
i_s
in
T
.
Parallel
(
block_S
):
with
T
.
If
(
i_s
>=
block_S
-
1
):
# noqa: SIM117
with
T
.
If
(
i_s
>=
block_S
-
1
):
# noqa: SIM117
with
T
.
Then
():
with
T
.
Then
():
dg_fragment_final
[
dg_fragment_final
[
i_s
]
=
dg_fragment_final
[
i_s
]
+
dg_last_local
[
0
]
+
dg_last_local
[
1
]
i_s
]
=
dg_fragment_final
[
i_s
]
+
dg_last_local
[
0
]
+
dg_last_local
[
1
]
T
.
copy
(
dq_fragment
,
dq
[
bb
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
,
bh
,
bk
*
block_DK
:
(
bk
+
1
)
*
block_DK
])
T
.
copy
(
T
.
copy
(
dk_fragment
,
dk
[
bb
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
,
bh
,
bk
*
block_DK
:
(
bk
+
1
)
*
block_DK
])
dq_fragment
,
dq
[
bb
,
bs
*
block_S
:(
bs
+
1
)
*
block_S
,
bh
,
bk
*
block_DK
:(
bk
+
1
)
*
block_DK
])
T
.
copy
(
dk_fragment
,
dk
[
bb
,
bs
*
block_S
:(
bs
+
1
)
*
block_S
,
bh
,
bk
*
block_DK
:(
bk
+
1
)
*
block_DK
])
for
i_s
in
T
.
Parallel
(
block_S
):
for
i_s
in
T
.
Parallel
(
block_S
):
dg
[
bk
,
bb
,
bs
*
block_S
+
i_s
,
bh
]
=
dg_fragment_final
[
i_s
]
dg
[
bk
,
bb
,
bs
*
block_S
+
i_s
,
bh
]
=
dg_fragment_final
[
i_s
]
...
@@ -387,12 +361,8 @@ def tilelang_chunk_o_bwd_dqkwg(
...
@@ -387,12 +361,8 @@ def tilelang_chunk_o_bwd_dqkwg(
for
i_s
,
i_k
in
T
.
Parallel
(
block_S
,
block_DK
):
for
i_s
,
i_k
in
T
.
Parallel
(
block_S
,
block_DK
):
dq_fragment
[
i_s
,
i_k
]
=
dq_fragment
[
i_s
,
i_k
]
*
scale
dq_fragment
[
i_s
,
i_k
]
=
dq_fragment
[
i_s
,
i_k
]
*
scale
dk_fragment
[
i_s
,
i_k
]
=
dk_fragment
[
i_s
,
i_k
]
+
dk_fragment_2
[
i_s
,
i_k
]
*
scale
dk_fragment
[
i_s
,
i_k
]
=
dk_fragment
[
i_s
,
i_k
]
+
dk_fragment_2
[
i_s
,
i_k
]
*
scale
T
.
copy
(
T
.
copy
(
dq_fragment
,
dq
[
bb
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
,
bh
,
bk
*
block_DK
:
(
bk
+
1
)
*
block_DK
])
dq_fragment
,
dq
[
bb
,
bs
*
block_S
:(
bs
+
1
)
*
block_S
,
bh
,
T
.
copy
(
dk_fragment
,
dk
[
bb
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
,
bh
,
bk
*
block_DK
:
(
bk
+
1
)
*
block_DK
])
bk
*
block_DK
:(
bk
+
1
)
*
block_DK
])
T
.
copy
(
dk_fragment
,
dk
[
bb
,
bs
*
block_S
:(
bs
+
1
)
*
block_S
,
bh
,
bk
*
block_DK
:(
bk
+
1
)
*
block_DK
])
return
kernel
return
kernel
...
@@ -442,33 +412,53 @@ def run_test(
...
@@ -442,33 +412,53 @@ def run_test(
threads
=
256
,
threads
=
256
,
num_stages
=
0
,
num_stages
=
0
,
):
):
Q
,
K
,
V
,
h
,
G
,
dO
,
dh
,
dv
,
W
=
prepare_input
(
B
,
S
,
H
,
DK
,
DV
,
chunk_size
,
Q
,
K
,
V
,
h
,
G
,
dO
,
dh
,
dv
,
W
=
prepare_input
(
getattr
(
torch
,
input_dtype
),
B
,
getattr
(
torch
,
output_dtype
),
S
,
getattr
(
torch
,
accum_dtype
),
H
,
getattr
(
torch
,
gate_dtype
),
DK
,
getattr
(
torch
,
state_dtype
))
DV
,
dq_ref
,
dk_ref
,
dw_ref
,
dg_ref
=
prepare_output
(
B
,
S
,
H
,
DK
,
DV
,
chunk_size
,
chunk_size
,
getattr
(
torch
,
output_dtype
),
getattr
(
torch
,
input_dtype
),
getattr
(
torch
,
gate_dtype
),
getattr
(
torch
,
output_dtype
),
getattr
(
torch
,
state_dtype
),
block_DK
)
getattr
(
torch
,
accum_dtype
),
getattr
(
torch
,
gate_dtype
),
getattr
(
torch
,
state_dtype
),
)
dq_ref
,
dk_ref
,
dw_ref
,
dg_ref
=
prepare_output
(
B
,
S
,
H
,
DK
,
DV
,
chunk_size
,
getattr
(
torch
,
output_dtype
),
getattr
(
torch
,
gate_dtype
),
getattr
(
torch
,
state_dtype
),
block_DK
)
dq_tilelang
,
dk_tilelang
,
dw_tilelang
,
dg_tilelang
=
prepare_output
(
dq_tilelang
,
dk_tilelang
,
dw_tilelang
,
dg_tilelang
=
prepare_output
(
B
,
S
,
H
,
DK
,
DV
,
chunk_size
,
getattr
(
torch
,
output_dtype
),
getattr
(
torch
,
gate_dtype
),
B
,
S
,
H
,
DK
,
DV
,
chunk_size
,
getattr
(
torch
,
output_dtype
),
getattr
(
torch
,
gate_dtype
),
getattr
(
torch
,
state_dtype
),
block_DK
getattr
(
torch
,
state_dtype
),
block_DK
)
)
# ref
# ref
if
use_g
:
if
use_g
:
dq_ref
,
dk_ref
,
dw_ref
,
dg_ref
=
chunk_bwd_dqkwg
(
dq_ref
,
dk_ref
,
dw_ref
,
dg_ref
=
chunk_bwd_dqkwg
(
Q
,
K
,
V
,
G
,
dO
,
h
,
dh
,
dv
,
W
,
chunk_size
=
chunk_size
,
scale
=
scale
)
Q
,
K
,
V
,
G
,
dO
,
h
,
dh
,
dv
,
W
,
chunk_size
=
chunk_size
,
scale
=
scale
)
else
:
else
:
dq_ref
,
dk_ref
,
dw_ref
,
dg_ref
=
chunk_bwd_dqkwg
(
dq_ref
,
dk_ref
,
dw_ref
,
dg_ref
=
chunk_bwd_dqkwg
(
Q
,
K
,
V
,
None
,
dO
,
h
,
dh
,
dv
,
W
,
chunk_size
=
chunk_size
,
scale
=
scale
)
Q
,
K
,
V
,
None
,
dO
,
h
,
dh
,
dv
,
W
,
chunk_size
=
chunk_size
,
scale
=
scale
)
# tilelang
# tilelang
kernel
=
tilelang_chunk_o_bwd_dqkwg
(
B
,
S
,
H
,
DK
,
DV
,
input_dtype
,
output_dtype
,
accum_dtype
,
kernel
=
tilelang_chunk_o_bwd_dqkwg
(
gate_dtype
,
state_dtype
,
chunk_size
,
scale
,
use_g
,
use_dw
,
B
,
block_DK
,
block_DV
,
threads
,
num_stages
)
S
,
print
(
kernel
.
get_kernel_source
())
H
,
DK
,
DV
,
input_dtype
,
output_dtype
,
accum_dtype
,
gate_dtype
,
state_dtype
,
chunk_size
,
scale
,
use_g
,
use_dw
,
block_DK
,
block_DV
,
threads
,
num_stages
,
)
dq_tilelang
,
dk_tilelang
,
dw_tilelang
,
dg_tilelang
=
kernel
(
Q
,
K
,
V
,
h
,
G
,
dO
,
dh
,
dv
,
W
)
dq_tilelang
,
dk_tilelang
,
dw_tilelang
,
dg_tilelang
=
kernel
(
Q
,
K
,
V
,
h
,
G
,
dO
,
dh
,
dv
,
W
)
if
use_g
:
if
use_g
:
...
@@ -515,11 +505,11 @@ def main():
...
@@ -515,11 +505,11 @@ def main():
H
=
8
,
H
=
8
,
DK
=
DK
,
DK
=
DK
,
DV
=
DV
,
DV
=
DV
,
input_dtype
=
"
bfloat16
"
,
input_dtype
=
T
.
bfloat16
,
output_dtype
=
"
bfloat16
"
,
output_dtype
=
T
.
bfloat16
,
accum_dtype
=
"
float32
"
,
accum_dtype
=
T
.
float32
,
gate_dtype
=
"
float32
"
,
gate_dtype
=
T
.
float32
,
state_dtype
=
"
float32
"
,
state_dtype
=
T
.
float32
,
chunk_size
=
64
,
chunk_size
=
64
,
scale
=
DK
**-
0.5
,
scale
=
DK
**-
0.5
,
# scale=1,
# scale=1,
...
...
examples/gdn/example_chunk_scaled_dot_kkt.py
View file @
667632cc
...
@@ -9,6 +9,7 @@ import sys # noqa: F401
...
@@ -9,6 +9,7 @@ import sys # noqa: F401
# sys.path.insert(0, "/home/tzj/flash-linear-attention")
# sys.path.insert(0, "/home/tzj/flash-linear-attention")
try
:
try
:
import
fla
import
fla
print
(
fla
.
__file__
)
print
(
fla
.
__file__
)
from
fla.ops.common.chunk_scaled_dot_kkt
import
chunk_scaled_dot_kkt_fwd
from
fla.ops.common.chunk_scaled_dot_kkt
import
chunk_scaled_dot_kkt_fwd
except
ImportError
:
except
ImportError
:
...
@@ -56,9 +57,9 @@ def tilelang_chunk_scaled_dot_kkt_fwd(
...
@@ -56,9 +57,9 @@ def tilelang_chunk_scaled_dot_kkt_fwd(
H
,
H
,
DK
,
DK
,
chunk_size
=
64
,
chunk_size
=
64
,
input_dtype
=
"
bfloat16
"
,
input_dtype
=
T
.
bfloat16
,
output_dtype
=
"
bfloat16
"
,
output_dtype
=
T
.
bfloat16
,
accum_dtype
=
"
float32
"
,
accum_dtype
=
T
.
float32
,
use_g
=
True
,
use_g
=
True
,
# kernel config
# kernel config
block_S
=
64
,
block_S
=
64
,
...
@@ -75,10 +76,10 @@ def tilelang_chunk_scaled_dot_kkt_fwd(
...
@@ -75,10 +76,10 @@ def tilelang_chunk_scaled_dot_kkt_fwd(
@
T
.
prim_func
@
T
.
prim_func
def
kernel
(
def
kernel
(
K
:
T
.
Tensor
(
K_shape
,
dtype
=
input_dtype
),
K
:
T
.
Tensor
(
K_shape
,
dtype
=
input_dtype
),
Beta
:
T
.
Tensor
(
Beta_shape
,
dtype
=
input_dtype
),
Beta
:
T
.
Tensor
(
Beta_shape
,
dtype
=
input_dtype
),
G
:
T
.
Tensor
(
G_shape
,
dtype
=
accum_dtype
),
G
:
T
.
Tensor
(
G_shape
,
dtype
=
accum_dtype
),
A
:
T
.
Tensor
(
output_shape
,
dtype
=
output_dtype
),
A
:
T
.
Tensor
(
output_shape
,
dtype
=
output_dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
S
,
block_S
),
B
*
H
,
threads
=
threads
)
as
(
bs
,
bbh
):
with
T
.
Kernel
(
T
.
ceildiv
(
S
,
block_S
),
B
*
H
,
threads
=
threads
)
as
(
bs
,
bbh
):
bb
,
bh
=
bbh
//
H
,
bbh
%
H
bb
,
bh
=
bbh
//
H
,
bbh
%
H
...
@@ -93,10 +94,12 @@ def tilelang_chunk_scaled_dot_kkt_fwd(
...
@@ -93,10 +94,12 @@ def tilelang_chunk_scaled_dot_kkt_fwd(
G_shared
=
T
.
alloc_shared
((
block_S
,),
dtype
=
accum_dtype
,
scope
=
"shared"
)
G_shared
=
T
.
alloc_shared
((
block_S
,),
dtype
=
accum_dtype
,
scope
=
"shared"
)
G_diff_local
=
T
.
alloc_fragment
((
block_S
,
block_S
),
dtype
=
accum_dtype
)
G_diff_local
=
T
.
alloc_fragment
((
block_S
,
block_S
),
dtype
=
accum_dtype
)
T
.
annotate_layout
({
T
.
annotate_layout
(
K_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
K_shared
),
{
A_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
A_shared
),
K_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
K_shared
),
})
A_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
A_shared
),
}
)
T
.
fill
(
A_fragment
,
0
)
T
.
fill
(
A_fragment
,
0
)
T
.
disable_warp_group_reg_alloc
()
T
.
disable_warp_group_reg_alloc
()
...
@@ -104,9 +107,7 @@ def tilelang_chunk_scaled_dot_kkt_fwd(
...
@@ -104,9 +107,7 @@ def tilelang_chunk_scaled_dot_kkt_fwd(
Beta_shared
[
i_s
]
=
Beta
[
bb
,
bs
*
block_S
+
i_s
,
bh
]
Beta_shared
[
i_s
]
=
Beta
[
bb
,
bs
*
block_S
+
i_s
,
bh
]
for
i_k
in
T
.
Pipelined
(
T
.
ceildiv
(
DK
,
block_DK
),
num_stages
=
num_stages
):
for
i_k
in
T
.
Pipelined
(
T
.
ceildiv
(
DK
,
block_DK
),
num_stages
=
num_stages
):
T
.
copy
(
T
.
copy
(
K
[
bb
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
,
bh
,
i_k
*
block_DK
:
(
i_k
+
1
)
*
block_DK
],
K_shared
)
K
[
bb
,
bs
*
block_S
:(
bs
+
1
)
*
block_S
,
bh
,
i_k
*
block_DK
:(
i_k
+
1
)
*
block_DK
],
K_shared
)
for
i_s
,
i_k2
in
T
.
Parallel
(
block_S
,
block_DK
):
for
i_s
,
i_k2
in
T
.
Parallel
(
block_S
,
block_DK
):
Beta_K_fragment
[
i_s
,
i_k2
]
=
K_shared
[
i_s
,
i_k2
]
*
Beta_shared
[
i_s
]
Beta_K_fragment
[
i_s
,
i_k2
]
=
K_shared
[
i_s
,
i_k2
]
*
Beta_shared
[
i_s
]
T
.
gemm
(
Beta_K_fragment
,
K_shared
,
A_fragment
,
transpose_B
=
True
)
T
.
gemm
(
Beta_K_fragment
,
K_shared
,
A_fragment
,
transpose_B
=
True
)
...
@@ -119,8 +120,7 @@ def tilelang_chunk_scaled_dot_kkt_fwd(
...
@@ -119,8 +120,7 @@ def tilelang_chunk_scaled_dot_kkt_fwd(
for
i_s1
,
i_s2
in
T
.
Parallel
(
block_S
,
block_S
):
for
i_s1
,
i_s2
in
T
.
Parallel
(
block_S
,
block_S
):
with
T
.
If
(
G_diff_local
[
i_s1
,
i_s2
]
<=
0
and
i_s1
>
i_s2
):
with
T
.
If
(
G_diff_local
[
i_s1
,
i_s2
]
<=
0
and
i_s1
>
i_s2
):
with
T
.
Then
():
with
T
.
Then
():
A_fragment
[
i_s1
,
i_s2
]
=
A_fragment
[
i_s1
,
i_s2
]
*
T
.
exp
(
A_fragment
[
i_s1
,
i_s2
]
=
A_fragment
[
i_s1
,
i_s2
]
*
T
.
exp
(
G_diff_local
[
i_s1
,
i_s2
])
G_diff_local
[
i_s1
,
i_s2
])
with
T
.
Else
():
with
T
.
Else
():
A_fragment
[
i_s1
,
i_s2
]
=
0
A_fragment
[
i_s1
,
i_s2
]
=
0
else
:
else
:
...
@@ -130,7 +130,7 @@ def tilelang_chunk_scaled_dot_kkt_fwd(
...
@@ -130,7 +130,7 @@ def tilelang_chunk_scaled_dot_kkt_fwd(
A_fragment
[
i_s1
,
i_s2
]
=
0
A_fragment
[
i_s1
,
i_s2
]
=
0
T
.
copy
(
A_fragment
,
A_shared
)
T
.
copy
(
A_fragment
,
A_shared
)
T
.
copy
(
A_shared
,
A
[
bb
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
,
bh
,
:])
T
.
copy
(
A_shared
,
A
[
bb
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
,
bh
,
:])
return
kernel
return
kernel
...
@@ -149,24 +149,21 @@ def run_test(
...
@@ -149,24 +149,21 @@ def run_test(
threads
,
threads
,
num_stages
,
num_stages
,
):
):
K
,
Beta
,
G
=
prepare_input
(
B
,
S
,
H
,
DK
,
getattr
(
torch
,
input_dtype
),
K
,
Beta
,
G
=
prepare_input
(
B
,
S
,
H
,
DK
,
getattr
(
torch
,
input_dtype
),
getattr
(
torch
,
output_dtype
),
getattr
(
torch
,
accum_dtype
))
getattr
(
torch
,
output_dtype
),
getattr
(
torch
,
accum_dtype
))
A_ref
=
prepare_output
(
B
,
S
,
H
,
chunk_size
,
getattr
(
torch
,
output_dtype
))
A_ref
=
prepare_output
(
B
,
S
,
H
,
chunk_size
,
getattr
(
torch
,
output_dtype
))
A_tilelang
=
prepare_output
(
B
,
S
,
H
,
chunk_size
,
getattr
(
torch
,
output_dtype
))
A_tilelang
=
prepare_output
(
B
,
S
,
H
,
chunk_size
,
getattr
(
torch
,
output_dtype
))
# reference
# reference
if
use_g
:
if
use_g
:
A_ref
=
chunk_scaled_dot_kkt_fwd
(
A_ref
=
chunk_scaled_dot_kkt_fwd
(
K
,
Beta
,
G
,
chunk_size
=
chunk_size
,
output_dtype
=
getattr
(
torch
,
output_dtype
))
K
,
Beta
,
G
,
chunk_size
=
chunk_size
,
output_dtype
=
getattr
(
torch
,
output_dtype
))
else
:
else
:
A_ref
=
chunk_scaled_dot_kkt_fwd
(
A_ref
=
chunk_scaled_dot_kkt_fwd
(
K
,
Beta
,
None
,
chunk_size
=
chunk_size
,
output_dtype
=
getattr
(
torch
,
output_dtype
))
K
,
Beta
,
None
,
chunk_size
=
chunk_size
,
output_dtype
=
getattr
(
torch
,
output_dtype
))
# tilelang
# tilelang
block_S
=
chunk_size
block_S
=
chunk_size
kernel
=
tilelang_chunk_scaled_dot_kkt_fwd
(
B
,
S
,
H
,
DK
,
chunk_size
,
input_dtype
,
output_dtype
,
kernel
=
tilelang_chunk_scaled_dot_kkt_fwd
(
accum_dtype
,
use_g
,
block_S
,
block_DK
,
threads
,
B
,
S
,
H
,
DK
,
chunk_size
,
input_dtype
,
output_dtype
,
accum_dtype
,
use_g
,
block_S
,
block_DK
,
threads
,
num_stages
num_stages
)
)
A_tilelang
=
kernel
(
K
,
Beta
,
G
)
A_tilelang
=
kernel
(
K
,
Beta
,
G
)
try
:
try
:
...
@@ -186,13 +183,14 @@ def main():
...
@@ -186,13 +183,14 @@ def main():
H
=
32
,
H
=
32
,
DK
=
128
,
DK
=
128
,
chunk_size
=
64
,
chunk_size
=
64
,
input_dtype
=
"
bfloat16
"
,
input_dtype
=
T
.
bfloat16
,
output_dtype
=
"
bfloat16
"
,
output_dtype
=
T
.
bfloat16
,
accum_dtype
=
"
float32
"
,
accum_dtype
=
T
.
float32
,
use_g
=
True
,
use_g
=
True
,
block_DK
=
64
,
block_DK
=
64
,
threads
=
128
,
threads
=
128
,
num_stages
=
2
)
num_stages
=
2
,
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
examples/gdn/example_cumsum.py
View file @
667632cc
...
@@ -10,6 +10,7 @@ import sys # noqa: F401
...
@@ -10,6 +10,7 @@ import sys # noqa: F401
# sys.path.insert(0, "/home/tzj/flash-linear-attention")
# sys.path.insert(0, "/home/tzj/flash-linear-attention")
try
:
try
:
import
fla
import
fla
print
(
fla
.
__file__
)
print
(
fla
.
__file__
)
from
fla.ops.utils.cumsum
import
chunk_local_cumsum_scalar
from
fla.ops.utils.cumsum
import
chunk_local_cumsum_scalar
except
ImportError
:
except
ImportError
:
...
@@ -20,11 +21,8 @@ import torch
...
@@ -20,11 +21,8 @@ import torch
@
tilelang
.
jit
(
@
tilelang
.
jit
(
out_idx
=
[
-
1
],
out_idx
=
[
-
1
],
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
}
pass_configs
=
{
)
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
})
def
tilelang_chunk_local_cumsum_scalar
(
def
tilelang_chunk_local_cumsum_scalar
(
# task config
# task config
B
,
B
,
...
@@ -34,43 +32,43 @@ def tilelang_chunk_local_cumsum_scalar(
...
@@ -34,43 +32,43 @@ def tilelang_chunk_local_cumsum_scalar(
is_varlen
=
False
,
is_varlen
=
False
,
head_first
=
False
,
head_first
=
False
,
reverse
=
False
,
reverse
=
False
,
input_dtype
=
"
float16
"
,
input_dtype
=
T
.
float16
,
output_dtype
=
"
float32
"
,
output_dtype
=
T
.
float32
,
# kernel config
# kernel config
block_S
=
64
,
block_S
=
64
,
threads
=
256
,
threads
=
256
,
use_fragment
=
False
,
use_fragment
=
False
,
):
):
G_shape
=
(
B
,
H
,
S
)
if
head_first
else
(
B
,
S
,
H
)
G_shape
=
(
B
,
H
,
S
)
if
head_first
else
(
B
,
S
,
H
)
assert
chunk_size
==
2
**
(
chunk_size
.
bit_length
()
-
1
),
"chunk_size must be a power of 2"
assert
chunk_size
==
2
**
(
chunk_size
.
bit_length
()
-
1
),
"chunk_size must be a power of 2"
assert
chunk_size
==
block_S
,
"chunk_size must be equal to block_S"
assert
chunk_size
==
block_S
,
"chunk_size must be equal to block_S"
@
T
.
prim_func
@
T
.
prim_func
def
kernel
(
def
kernel
(
G
:
T
.
Tensor
(
G_shape
,
dtype
=
input_dtype
),
G
:
T
.
Tensor
(
G_shape
,
dtype
=
input_dtype
),
G_new
:
T
.
Tensor
(
G_shape
,
dtype
=
output_dtype
),
G_new
:
T
.
Tensor
(
G_shape
,
dtype
=
output_dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
S
,
block_S
),
B
*
H
,
threads
=
threads
)
as
(
bs
,
bbh
):
with
T
.
Kernel
(
T
.
ceildiv
(
S
,
block_S
),
B
*
H
,
threads
=
threads
)
as
(
bs
,
bbh
):
bb
,
bh
=
bbh
//
H
,
bbh
%
H
bb
,
bh
=
bbh
//
H
,
bbh
%
H
G_shared
=
T
.
alloc_shared
((
1
,
block_S
),
dtype
=
output_dtype
,
scope
=
"shared"
)
G_shared
=
T
.
alloc_shared
((
1
,
block_S
),
dtype
=
output_dtype
,
scope
=
"shared"
)
if
head_first
:
if
head_first
:
T
.
copy
(
G
[
bb
,
bh
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
],
G_shared
)
T
.
copy
(
G
[
bb
,
bh
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
],
G_shared
)
else
:
else
:
T
.
copy
(
G
[
bb
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
,
bh
],
G_shared
)
T
.
copy
(
G
[
bb
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
,
bh
],
G_shared
)
if
use_fragment
:
if
use_fragment
:
G_fragment
=
T
.
alloc_fragment
((
1
,
block_S
),
dtype
=
output_dtype
,
scope
=
"shared"
)
G_fragment
=
T
.
alloc_fragment
((
1
,
block_S
),
dtype
=
output_dtype
,
scope
=
"shared"
)
T
.
copy
(
G_shared
,
G_fragment
)
T
.
copy
(
G_shared
,
G_fragment
)
T
.
cumsum
(
G_fragment
,
dim
=
1
,
reverse
=
reverse
)
T
.
cumsum
(
G_fragment
,
dim
=
1
,
reverse
=
reverse
)
if
head_first
:
if
head_first
:
T
.
copy
(
G_fragment
,
G_new
[
bb
,
bh
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
])
T
.
copy
(
G_fragment
,
G_new
[
bb
,
bh
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
])
else
:
else
:
T
.
copy
(
G_fragment
,
G_new
[
bb
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
,
bh
])
T
.
copy
(
G_fragment
,
G_new
[
bb
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
,
bh
])
else
:
else
:
T
.
cumsum
(
G_shared
,
dim
=
1
,
reverse
=
reverse
)
T
.
cumsum
(
G_shared
,
dim
=
1
,
reverse
=
reverse
)
if
head_first
:
if
head_first
:
T
.
copy
(
G_shared
,
G_new
[
bb
,
bh
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
])
T
.
copy
(
G_shared
,
G_new
[
bb
,
bh
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
])
else
:
else
:
T
.
copy
(
G_shared
,
G_new
[
bb
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
,
bh
])
T
.
copy
(
G_shared
,
G_new
[
bb
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
,
bh
])
return
kernel
return
kernel
...
@@ -113,11 +111,8 @@ def run_test(
...
@@ -113,11 +111,8 @@ def run_test(
# reference cumsum
# reference cumsum
G_new_ref
=
chunk_local_cumsum_scalar
(
G_new_ref
=
chunk_local_cumsum_scalar
(
g
=
G
,
g
=
G
,
chunk_size
=
chunk_size
,
reverse
=
reverse
,
head_first
=
head_first
,
output_dtype
=
getattr
(
torch
,
output_dtype
)
chunk_size
=
chunk_size
,
)
reverse
=
reverse
,
head_first
=
head_first
,
output_dtype
=
getattr
(
torch
,
output_dtype
))
# tilelang cumsum
# tilelang cumsum
block_S
=
chunk_size
block_S
=
chunk_size
...
@@ -159,10 +154,11 @@ def main():
...
@@ -159,10 +154,11 @@ def main():
chunk_size
=
64
,
chunk_size
=
64
,
reverse
=
True
,
reverse
=
True
,
head_first
=
False
,
head_first
=
False
,
input_dtype
=
"
float32
"
,
input_dtype
=
T
.
float32
,
output_dtype
=
"
float32
"
,
output_dtype
=
T
.
float32
,
threads
=
256
,
threads
=
256
,
use_fragment
=
False
)
use_fragment
=
False
,
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
examples/gdn/example_wy_fast.py
View file @
667632cc
...
@@ -9,6 +9,7 @@ import sys # noqa: F401
...
@@ -9,6 +9,7 @@ import sys # noqa: F401
# sys.path.insert(0, "/home/tzj/flash-linear-attention")
# sys.path.insert(0, "/home/tzj/flash-linear-attention")
try
:
try
:
import
fla
import
fla
print
(
fla
.
__file__
)
print
(
fla
.
__file__
)
from
fla.ops.gated_delta_rule.wy_fast
import
recompute_w_u_fwd
from
fla.ops.gated_delta_rule.wy_fast
import
recompute_w_u_fwd
except
ImportError
:
except
ImportError
:
...
@@ -73,13 +74,13 @@ def tilelang_recompute_w_u_fwd(
...
@@ -73,13 +74,13 @@ def tilelang_recompute_w_u_fwd(
@
T
.
prim_func
@
T
.
prim_func
def
kernel
(
def
kernel
(
K
:
T
.
Tensor
(
K_shape
,
dtype
=
input_dtype
),
K
:
T
.
Tensor
(
K_shape
,
dtype
=
input_dtype
),
V
:
T
.
Tensor
(
V_shape
,
dtype
=
input_dtype
),
V
:
T
.
Tensor
(
V_shape
,
dtype
=
input_dtype
),
Beta
:
T
.
Tensor
(
Beta_shape
,
dtype
=
input_dtype
),
Beta
:
T
.
Tensor
(
Beta_shape
,
dtype
=
input_dtype
),
G
:
T
.
Tensor
(
G_shape
,
dtype
=
gate_dtype
),
G
:
T
.
Tensor
(
G_shape
,
dtype
=
gate_dtype
),
A
:
T
.
Tensor
(
A_shape
,
dtype
=
output_dtype
),
A
:
T
.
Tensor
(
A_shape
,
dtype
=
output_dtype
),
W
:
T
.
Tensor
(
K_shape
,
dtype
=
output_dtype
),
W
:
T
.
Tensor
(
K_shape
,
dtype
=
output_dtype
),
U
:
T
.
Tensor
(
V_shape
,
dtype
=
output_dtype
),
U
:
T
.
Tensor
(
V_shape
,
dtype
=
output_dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
S
,
block_S
),
B
*
H
,
threads
=
threads
)
as
(
bs
,
bbh
):
with
T
.
Kernel
(
T
.
ceildiv
(
S
,
block_S
),
B
*
H
,
threads
=
threads
)
as
(
bs
,
bbh
):
bb
,
bh
=
bbh
//
H
,
bbh
%
H
bb
,
bh
=
bbh
//
H
,
bbh
%
H
...
@@ -95,49 +96,42 @@ def tilelang_recompute_w_u_fwd(
...
@@ -95,49 +96,42 @@ def tilelang_recompute_w_u_fwd(
W_Beta_shared
=
T
.
alloc_shared
((
block_S
,
block_DK
),
dtype
=
input_dtype
)
W_Beta_shared
=
T
.
alloc_shared
((
block_S
,
block_DK
),
dtype
=
input_dtype
)
U_Beta_shared
=
T
.
alloc_shared
((
block_S
,
block_DV
),
dtype
=
input_dtype
)
U_Beta_shared
=
T
.
alloc_shared
((
block_S
,
block_DV
),
dtype
=
input_dtype
)
T
.
annotate_layout
({
T
.
annotate_layout
(
K_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
K_shared
),
{
V_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
V_shared
),
K_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
K_shared
),
A_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
A_shared
),
V_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
V_shared
),
W_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
W_shared
),
A_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
A_shared
),
U_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
U_shared
),
W_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
W_shared
),
W_Beta_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
W_Beta_shared
),
U_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
U_shared
),
U_Beta_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
U_Beta_shared
),
W_Beta_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
W_Beta_shared
),
})
U_Beta_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
U_Beta_shared
),
}
)
T
.
disable_warp_group_reg_alloc
()
T
.
disable_warp_group_reg_alloc
()
for
i_s
in
T
.
Parallel
(
block_S
):
for
i_s
in
T
.
Parallel
(
block_S
):
Beta_shared
[
i_s
]
=
Beta
[
bb
,
bs
*
block_S
+
i_s
,
bh
]
Beta_shared
[
i_s
]
=
Beta
[
bb
,
bs
*
block_S
+
i_s
,
bh
]
G_shared
[
i_s
]
=
T
.
exp
(
G
[
bb
,
bs
*
block_S
+
i_s
,
bh
])
G_shared
[
i_s
]
=
T
.
exp
(
G
[
bb
,
bs
*
block_S
+
i_s
,
bh
])
T
.
copy
(
A
[
bb
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
,
bh
,
:],
A_shared
)
T
.
copy
(
A
[
bb
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
,
bh
,
:],
A_shared
)
for
i_v
in
T
.
Pipelined
(
T
.
ceildiv
(
DV
,
block_DV
),
num_stages
=
num_stages
):
for
i_v
in
T
.
Pipelined
(
T
.
ceildiv
(
DV
,
block_DV
),
num_stages
=
num_stages
):
T
.
copy
(
T
.
copy
(
V
[
bb
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
,
bh
,
i_v
*
block_DV
:
(
i_v
+
1
)
*
block_DV
],
V_shared
)
V
[
bb
,
bs
*
block_S
:(
bs
+
1
)
*
block_S
,
bh
,
i_v
*
block_DV
:(
i_v
+
1
)
*
block_DV
],
V_shared
)
for
i_s
,
i_v2
in
T
.
Parallel
(
block_S
,
block_DV
):
for
i_s
,
i_v2
in
T
.
Parallel
(
block_S
,
block_DV
):
U_Beta_shared
[
i_s
,
i_v2
]
=
V_shared
[
i_s
,
i_v2
]
*
Beta_shared
[
i_s
]
U_Beta_shared
[
i_s
,
i_v2
]
=
V_shared
[
i_s
,
i_v2
]
*
Beta_shared
[
i_s
]
T
.
gemm
(
A_shared
,
U_Beta_shared
,
U_fragment
,
clear_accum
=
True
)
T
.
gemm
(
A_shared
,
U_Beta_shared
,
U_fragment
,
clear_accum
=
True
)
# First copy to smem, then copy to gmem to reduce U2RU instructions
# First copy to smem, then copy to gmem to reduce U2RU instructions
T
.
copy
(
U_fragment
,
U_shared
)
T
.
copy
(
U_fragment
,
U_shared
)
T
.
copy
(
T
.
copy
(
U_shared
,
U
[
bb
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
,
bh
,
i_v
*
block_DV
:
(
i_v
+
1
)
*
block_DV
])
U_shared
,
U
[
bb
,
bs
*
block_S
:(
bs
+
1
)
*
block_S
,
bh
,
i_v
*
block_DV
:(
i_v
+
1
)
*
block_DV
])
for
i_k
in
T
.
Pipelined
(
T
.
ceildiv
(
DK
,
block_DK
),
num_stages
=
num_stages
):
for
i_k
in
T
.
Pipelined
(
T
.
ceildiv
(
DK
,
block_DK
),
num_stages
=
num_stages
):
T
.
copy
(
T
.
copy
(
K
[
bb
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
,
bh
,
i_k
*
block_DK
:
(
i_k
+
1
)
*
block_DK
],
K_shared
)
K
[
bb
,
bs
*
block_S
:(
bs
+
1
)
*
block_S
,
bh
,
i_k
*
block_DK
:(
i_k
+
1
)
*
block_DK
],
K_shared
)
for
i_s
,
i_k2
in
T
.
Parallel
(
block_S
,
block_DK
):
for
i_s
,
i_k2
in
T
.
Parallel
(
block_S
,
block_DK
):
W_Beta_shared
[
i_s
,
W_Beta_shared
[
i_s
,
i_k2
]
=
K_shared
[
i_s
,
i_k2
]
*
Beta_shared
[
i_s
]
*
G_shared
[
i_s
]
i_k2
]
=
K_shared
[
i_s
,
i_k2
]
*
Beta_shared
[
i_s
]
*
G_shared
[
i_s
]
T
.
gemm
(
A_shared
,
W_Beta_shared
,
W_fragment
,
clear_accum
=
True
)
T
.
gemm
(
A_shared
,
W_Beta_shared
,
W_fragment
,
clear_accum
=
True
)
# First copy to smem, then copy to gmem to reduce U2RU instructions
# First copy to smem, then copy to gmem to reduce U2RU instructions
T
.
copy
(
W_fragment
,
W_shared
)
T
.
copy
(
W_fragment
,
W_shared
)
T
.
copy
(
T
.
copy
(
W_shared
,
W
[
bb
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
,
bh
,
i_k
*
block_DK
:
(
i_k
+
1
)
*
block_DK
])
W_shared
,
W
[
bb
,
bs
*
block_S
:(
bs
+
1
)
*
block_S
,
bh
,
i_k
*
block_DK
:(
i_k
+
1
)
*
block_DK
])
return
kernel
return
kernel
...
@@ -159,15 +153,8 @@ def run_test(
...
@@ -159,15 +153,8 @@ def run_test(
num_stages
,
num_stages
,
):
):
K
,
V
,
Beta
,
G
,
A
=
prepare_input
(
K
,
V
,
Beta
,
G
,
A
=
prepare_input
(
B
,
B
,
S
,
H
,
DK
,
DV
,
chunk_size
,
getattr
(
torch
,
input_dtype
),
getattr
(
torch
,
output_dtype
),
gate_dtype
=
getattr
(
torch
,
gate_dtype
)
S
,
)
H
,
DK
,
DV
,
chunk_size
,
getattr
(
torch
,
input_dtype
),
getattr
(
torch
,
output_dtype
),
gate_dtype
=
getattr
(
torch
,
gate_dtype
))
W_ref
,
U_ref
=
prepare_output
(
B
,
S
,
H
,
DK
,
DV
,
getattr
(
torch
,
output_dtype
))
W_ref
,
U_ref
=
prepare_output
(
B
,
S
,
H
,
DK
,
DV
,
getattr
(
torch
,
output_dtype
))
W_tilelang
,
U_tilelang
=
prepare_output
(
B
,
S
,
H
,
DK
,
DV
,
getattr
(
torch
,
output_dtype
))
W_tilelang
,
U_tilelang
=
prepare_output
(
B
,
S
,
H
,
DK
,
DV
,
getattr
(
torch
,
output_dtype
))
...
@@ -191,7 +178,8 @@ def run_test(
...
@@ -191,7 +178,8 @@ def run_test(
block_DK
=
block_DK
,
block_DK
=
block_DK
,
block_DV
=
block_DV
,
block_DV
=
block_DV
,
threads
=
threads
,
threads
=
threads
,
num_stages
=
num_stages
)
num_stages
=
num_stages
,
)
print
(
kernel
.
get_kernel_source
())
print
(
kernel
.
get_kernel_source
())
W_tilelang
,
U_tilelang
=
kernel
(
K
,
V
,
Beta
,
G
,
A
)
W_tilelang
,
U_tilelang
=
kernel
(
K
,
V
,
Beta
,
G
,
A
)
...
@@ -217,14 +205,15 @@ def main():
...
@@ -217,14 +205,15 @@ def main():
DK
=
128
,
DK
=
128
,
DV
=
128
,
DV
=
128
,
chunk_size
=
64
,
chunk_size
=
64
,
input_dtype
=
"
bfloat16
"
,
input_dtype
=
T
.
bfloat16
,
output_dtype
=
"
bfloat16
"
,
output_dtype
=
T
.
bfloat16
,
gate_dtype
=
"
float32
"
,
gate_dtype
=
T
.
float32
,
accum_dtype
=
"
float32
"
,
accum_dtype
=
T
.
float32
,
block_DK
=
64
,
block_DK
=
64
,
block_DV
=
32
,
block_DV
=
32
,
threads
=
128
,
threads
=
128
,
num_stages
=
3
)
num_stages
=
3
,
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
examples/gdn/example_wy_fast_bwd_split.py
View file @
667632cc
...
@@ -10,6 +10,7 @@ import tilelang.language as T
...
@@ -10,6 +10,7 @@ import tilelang.language as T
# sys.path.insert(0, "/home/tzj/flash-linear-attention")
# sys.path.insert(0, "/home/tzj/flash-linear-attention")
try
:
try
:
import
fla
import
fla
print
(
fla
.
__file__
)
print
(
fla
.
__file__
)
from
fla.ops.gated_delta_rule.wy_fast
import
bwd_prepare_wy_repr
from
fla.ops.gated_delta_rule.wy_fast
import
bwd_prepare_wy_repr
except
ImportError
:
except
ImportError
:
...
@@ -93,10 +94,8 @@ def prepare_output(
...
@@ -93,10 +94,8 @@ def prepare_output(
@
tilelang
.
jit
(
@
tilelang
.
jit
(
out_idx
=
[
-
5
,
-
4
,
-
3
,
-
2
,
-
1
],
out_idx
=
[
-
5
,
-
4
,
-
3
,
-
2
,
-
1
],
pass_configs
=
{
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
},
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
)
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
})
def
tilelang_wy_fast_bwd
(
def
tilelang_wy_fast_bwd
(
# task config
# task config
B
,
B
,
...
@@ -135,20 +134,20 @@ def tilelang_wy_fast_bwd(
...
@@ -135,20 +134,20 @@ def tilelang_wy_fast_bwd(
@
T
.
prim_func
@
T
.
prim_func
def
kernel
(
def
kernel
(
# input
# input
K
:
T
.
Tensor
(
K_shape
,
dtype
=
input_dtype
),
K
:
T
.
Tensor
(
K_shape
,
dtype
=
input_dtype
),
V
:
T
.
Tensor
(
V_shape
,
dtype
=
input_dtype
),
V
:
T
.
Tensor
(
V_shape
,
dtype
=
input_dtype
),
Beta
:
T
.
Tensor
(
Beta_shape
,
dtype
=
input_dtype
),
Beta
:
T
.
Tensor
(
Beta_shape
,
dtype
=
input_dtype
),
G
:
T
.
Tensor
(
G_shape
,
dtype
=
gate_dtype
),
G
:
T
.
Tensor
(
G_shape
,
dtype
=
gate_dtype
),
A
:
T
.
Tensor
(
A_shape
,
dtype
=
input_dtype
),
A
:
T
.
Tensor
(
A_shape
,
dtype
=
input_dtype
),
dw
:
T
.
Tensor
(
dw_shape
,
dtype
=
input_dtype
),
dw
:
T
.
Tensor
(
dw_shape
,
dtype
=
input_dtype
),
du
:
T
.
Tensor
(
du_shape
,
dtype
=
input_dtype
),
du
:
T
.
Tensor
(
du_shape
,
dtype
=
input_dtype
),
# output
# output
dA
:
T
.
Tensor
(
dA_shape
,
dtype
=
input_dtype
),
dA
:
T
.
Tensor
(
dA_shape
,
dtype
=
input_dtype
),
dk
:
T
.
Tensor
(
dk_shape
,
dtype
=
output_dtype
),
dk
:
T
.
Tensor
(
dk_shape
,
dtype
=
output_dtype
),
dv
:
T
.
Tensor
(
dv_shape
,
dtype
=
output_dtype
),
dv
:
T
.
Tensor
(
dv_shape
,
dtype
=
output_dtype
),
dbeta
:
T
.
Tensor
(
dbeta_shape
,
dtype
=
output_dtype
),
dbeta
:
T
.
Tensor
(
dbeta_shape
,
dtype
=
output_dtype
),
dg
:
T
.
Tensor
(
dg_shape
,
dtype
=
gate_dtype
),
dg
:
T
.
Tensor
(
dg_shape
,
dtype
=
gate_dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
S
,
block_S
),
B
*
H
,
threads
=
threads
)
as
(
bs
,
bbh
):
with
T
.
Kernel
(
T
.
ceildiv
(
S
,
block_S
),
B
*
H
,
threads
=
threads
)
as
(
bs
,
bbh
):
bb
,
bh
=
bbh
//
H
,
bbh
%
H
bb
,
bh
=
bbh
//
H
,
bbh
%
H
...
@@ -187,7 +186,7 @@ def tilelang_wy_fast_bwd(
...
@@ -187,7 +186,7 @@ def tilelang_wy_fast_bwd(
T
.
clear
(
dbeta_fragment_v
)
T
.
clear
(
dbeta_fragment_v
)
T
.
clear
(
dg_fragment
)
T
.
clear
(
dg_fragment
)
T
.
copy
(
A
[
bb
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
,
bh
,
:],
A_shared
)
T
.
copy
(
A
[
bb
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
,
bh
,
:],
A_shared
)
for
i_s
in
T
.
Parallel
(
block_S
):
for
i_s
in
T
.
Parallel
(
block_S
):
Beta_shared
[
i_s
]
=
Beta
[
bb
,
bs
*
block_S
+
i_s
,
bh
]
Beta_shared
[
i_s
]
=
Beta
[
bb
,
bs
*
block_S
+
i_s
,
bh
]
G_shared
[
i_s
]
=
G
[
bb
,
bs
*
block_S
+
i_s
,
bh
]
G_shared
[
i_s
]
=
G
[
bb
,
bs
*
block_S
+
i_s
,
bh
]
...
@@ -195,51 +194,37 @@ def tilelang_wy_fast_bwd(
...
@@ -195,51 +194,37 @@ def tilelang_wy_fast_bwd(
# Update dk
# Update dk
for
i_k
in
T
.
Pipelined
(
T
.
ceildiv
(
DK
,
block_DK
),
num_stages
=
num_stages
):
for
i_k
in
T
.
Pipelined
(
T
.
ceildiv
(
DK
,
block_DK
),
num_stages
=
num_stages
):
T
.
copy
(
T
.
copy
(
K
[
bb
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
,
bh
,
i_k
*
block_DK
:
(
i_k
+
1
)
*
block_DK
],
K_shared
)
K
[
bb
,
bs
*
block_S
:(
bs
+
1
)
*
block_S
,
bh
,
i_k
*
block_DK
:(
i_k
+
1
)
*
block_DK
],
K_shared
)
for
i_s
,
i_k2
in
T
.
Parallel
(
block_S
,
block_DK
):
for
i_s
,
i_k2
in
T
.
Parallel
(
block_S
,
block_DK
):
K_shared_beta_g
[
i_s
,
K_shared_beta_g
[
i_s
,
i_k2
]
=
K_shared
[
i_s
,
i_k2
]
*
Beta_shared
[
i_s
]
*
G_shared_exp
[
i_s
]
i_k2
]
=
K_shared
[
i_s
,
T
.
copy
(
dw
[
bb
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
,
bh
,
i_k
*
block_DK
:
(
i_k
+
1
)
*
block_DK
],
dw_shared
)
i_k2
]
*
Beta_shared
[
i_s
]
*
G_shared_exp
[
i_s
]
T
.
copy
(
dw
[
bb
,
bs
*
block_S
:(
bs
+
1
)
*
block_S
,
bh
,
i_k
*
block_DK
:(
i_k
+
1
)
*
block_DK
],
dw_shared
)
T
.
gemm
(
dw_shared
,
K_shared_beta_g
,
dA_fragment
,
transpose_B
=
True
)
T
.
gemm
(
dw_shared
,
K_shared_beta_g
,
dA_fragment
,
transpose_B
=
True
)
T
.
gemm
(
A_shared
,
dw_shared
,
dk_fragment_beta_g
,
clear_accum
=
True
,
transpose_A
=
True
)
T
.
gemm
(
A_shared
,
dw_shared
,
dk_fragment_beta_g
,
clear_accum
=
True
,
transpose_A
=
True
)
for
i_s
,
i_k2
in
T
.
Parallel
(
block_S
,
block_DK
):
for
i_s
,
i_k2
in
T
.
Parallel
(
block_S
,
block_DK
):
dk_fragment
[
dk_fragment
[
i_s
,
i_k2
]
=
dk_fragment_beta_g
[
i_s
,
i_k2
]
*
Beta_shared
[
i_s
]
*
G_shared_exp
[
i_s
]
i_s
,
i_k2
]
=
dk_fragment_beta_g
[
i_s
,
i_k2
]
*
Beta_shared
[
i_s
]
*
G_shared_exp
[
i_s
]
# for i_s, i_k2 in T.Parallel(block_S, block_DK):
# for i_s, i_k2 in T.Parallel(block_S, block_DK):
# dbeta_fragment[i_s] = dbeta_fragment[i_s] + dk_fragment_beta_g[i_s, i_k2] * K_shared[i_s, i_k2] * G_shared_exp[i_s]
# dbeta_fragment[i_s] = dbeta_fragment[i_s] + dk_fragment_beta_g[i_s, i_k2] * K_shared[i_s, i_k2] * G_shared_exp[i_s]
for
i_s
,
i_k2
in
T
.
Parallel
(
block_S
,
block_DK
):
for
i_s
,
i_k2
in
T
.
Parallel
(
block_S
,
block_DK
):
dbeta_fragment_reduce_tmpk
[
i_s
,
i_k2
]
=
dk_fragment_beta_g
[
dbeta_fragment_reduce_tmpk
[
i_s
,
i_k2
]
=
dk_fragment_beta_g
[
i_s
,
i_k2
]
*
K_shared
[
i_s
,
i_k2
]
*
G_shared_exp
[
i_s
]
i_s
,
i_k2
]
*
K_shared
[
i_s
,
i_k2
]
*
G_shared_exp
[
i_s
]
T
.
reduce_sum
(
dbeta_fragment_reduce_tmpk
,
dbeta_fragment_k
,
dim
=
1
,
clear
=
False
)
T
.
reduce_sum
(
dbeta_fragment_reduce_tmpk
,
dbeta_fragment_k
,
dim
=
1
,
clear
=
False
)
# for i_s, i_k2 in T.Parallel(block_S, block_DK):
# for i_s, i_k2 in T.Parallel(block_S, block_DK):
# dg_fragment[i_s] = dg_fragment[i_s] + dk_fragment_beta_g[i_s, i_k2] * K_shared[i_s, i_k2] * G_shared_exp[i_s] * Beta_shared[i_s]
# dg_fragment[i_s] = dg_fragment[i_s] + dk_fragment_beta_g[i_s, i_k2] * K_shared[i_s, i_k2] * G_shared_exp[i_s] * Beta_shared[i_s]
for
i_s
,
i_k2
in
T
.
Parallel
(
block_S
,
block_DK
):
for
i_s
,
i_k2
in
T
.
Parallel
(
block_S
,
block_DK
):
dg_fragment_reduce_tmp
[
i_s
,
i_k2
]
=
dk_fragment_beta_g
[
i_s
,
i_k2
]
*
K_shared
[
dg_fragment_reduce_tmp
[
i_s
,
i_k2
]
=
(
i_s
,
i_k2
]
*
G_shared_exp
[
i_s
]
*
Beta_shared
[
i_s
]
dk_fragment_beta_g
[
i_s
,
i_k2
]
*
K_shared
[
i_s
,
i_k2
]
*
G_shared_exp
[
i_s
]
*
Beta_shared
[
i_s
]
)
T
.
reduce_sum
(
dg_fragment_reduce_tmp
,
dg_fragment
,
dim
=
1
,
clear
=
False
)
T
.
reduce_sum
(
dg_fragment_reduce_tmp
,
dg_fragment
,
dim
=
1
,
clear
=
False
)
# correct dk
# correct dk
T
.
copy
(
T
.
copy
(
dk_fragment
,
dk
[
bb
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
,
bh
,
i_k
*
block_DK
:
(
i_k
+
1
)
*
block_DK
])
dk_fragment
,
dk
[
bb
,
bs
*
block_S
:(
bs
+
1
)
*
block_S
,
bh
,
i_k
*
block_DK
:(
i_k
+
1
)
*
block_DK
])
# Update dv
# Update dv
for
i_v
in
T
.
Pipelined
(
T
.
ceildiv
(
DV
,
block_DV
),
num_stages
=
num_stages
):
for
i_v
in
T
.
Pipelined
(
T
.
ceildiv
(
DV
,
block_DV
),
num_stages
=
num_stages
):
T
.
copy
(
T
.
copy
(
V
[
bb
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
,
bh
,
i_v
*
block_DV
:
(
i_v
+
1
)
*
block_DV
],
V_shared
)
V
[
bb
,
bs
*
block_S
:(
bs
+
1
)
*
block_S
,
bh
,
i_v
*
block_DV
:(
i_v
+
1
)
*
block_DV
],
V_shared
)
for
i_s
,
i_v2
in
T
.
Parallel
(
block_S
,
block_DV
):
for
i_s
,
i_v2
in
T
.
Parallel
(
block_S
,
block_DV
):
V_shared_beta
[
i_s
,
i_v2
]
=
V_shared
[
i_s
,
i_v2
]
*
Beta_shared
[
i_s
]
V_shared_beta
[
i_s
,
i_v2
]
=
V_shared
[
i_s
,
i_v2
]
*
Beta_shared
[
i_s
]
T
.
copy
(
T
.
copy
(
du
[
bb
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
,
bh
,
i_v
*
block_DV
:
(
i_v
+
1
)
*
block_DV
],
du_shared
)
du
[
bb
,
bs
*
block_S
:(
bs
+
1
)
*
block_S
,
bh
,
i_v
*
block_DV
:(
i_v
+
1
)
*
block_DV
],
du_shared
)
T
.
gemm
(
du_shared
,
V_shared_beta
,
dA_fragment
,
transpose_B
=
True
)
T
.
gemm
(
du_shared
,
V_shared_beta
,
dA_fragment
,
transpose_B
=
True
)
T
.
gemm
(
A_shared
,
du_shared
,
dv_fragment_beta
,
clear_accum
=
True
,
transpose_A
=
True
)
T
.
gemm
(
A_shared
,
du_shared
,
dv_fragment_beta
,
clear_accum
=
True
,
transpose_A
=
True
)
for
i_s
,
i_v2
in
T
.
Parallel
(
block_S
,
block_DV
):
for
i_s
,
i_v2
in
T
.
Parallel
(
block_S
,
block_DV
):
...
@@ -247,30 +232,22 @@ def tilelang_wy_fast_bwd(
...
@@ -247,30 +232,22 @@ def tilelang_wy_fast_bwd(
# for i_s, i_v2 in T.Parallel(block_S, block_DV):
# for i_s, i_v2 in T.Parallel(block_S, block_DV):
# dbeta_fragment[i_s] = dbeta_fragment[i_s] + dv_fragment_beta[i_s, i_v2] * V_shared[i_s, i_v2]
# dbeta_fragment[i_s] = dbeta_fragment[i_s] + dv_fragment_beta[i_s, i_v2] * V_shared[i_s, i_v2]
for
i_s
,
i_v2
in
T
.
Parallel
(
block_S
,
block_DV
):
for
i_s
,
i_v2
in
T
.
Parallel
(
block_S
,
block_DV
):
dbeta_fragment_reduce_tmpv
[
i_s
,
dbeta_fragment_reduce_tmpv
[
i_s
,
i_v2
]
=
dv_fragment_beta
[
i_s
,
i_v2
]
*
V_shared
[
i_s
,
i_v2
]
i_v2
]
=
dv_fragment_beta
[
i_s
,
i_v2
]
*
V_shared
[
i_s
,
i_v2
]
T
.
reduce_sum
(
dbeta_fragment_reduce_tmpv
,
dbeta_fragment_v
,
dim
=
1
,
clear
=
False
)
T
.
reduce_sum
(
dbeta_fragment_reduce_tmpv
,
dbeta_fragment_v
,
dim
=
1
,
clear
=
False
)
T
.
copy
(
T
.
copy
(
dv_fragment
,
dv
[
bb
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
,
bh
,
i_v
*
block_DV
:
(
i_v
+
1
)
*
block_DV
])
dv_fragment
,
dv
[
bb
,
bs
*
block_S
:(
bs
+
1
)
*
block_S
,
bh
,
i_v
*
block_DV
:(
i_v
+
1
)
*
block_DV
])
# Temporary store dbeta, dg and dA
# Temporary store dbeta, dg and dA
for
i_s
in
T
.
Parallel
(
block_S
):
for
i_s
in
T
.
Parallel
(
block_S
):
dbeta
[
bb
,
bs
*
block_S
+
i_s
,
bh
]
=
dbeta_fragment_k
[
i_s
]
+
dbeta_fragment_v
[
i_s
]
dbeta
[
bb
,
bs
*
block_S
+
i_s
,
bh
]
=
dbeta_fragment_k
[
i_s
]
+
dbeta_fragment_v
[
i_s
]
dg
[
bb
,
bs
*
block_S
+
i_s
,
bh
]
=
dg_fragment
[
i_s
]
dg
[
bb
,
bs
*
block_S
+
i_s
,
bh
]
=
dg_fragment
[
i_s
]
# correct dA
# correct dA
T
.
copy
(
dA_fragment
,
dA
[
bb
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
,
bh
,
:])
T
.
copy
(
dA_fragment
,
dA
[
bb
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
,
bh
,
:])
return
kernel
return
kernel
@
tilelang
.
jit
(
@
tilelang
.
jit
(
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
})
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
})
def
tilelang_wy_fast_bwd_split
(
def
tilelang_wy_fast_bwd_split
(
# task config
# task config
B
,
B
,
...
@@ -308,20 +285,20 @@ def tilelang_wy_fast_bwd_split(
...
@@ -308,20 +285,20 @@ def tilelang_wy_fast_bwd_split(
@
T
.
prim_func
@
T
.
prim_func
def
kernel
(
def
kernel
(
# input
# input
K
:
T
.
Tensor
(
K_shape
,
dtype
=
input_dtype
),
K
:
T
.
Tensor
(
K_shape
,
dtype
=
input_dtype
),
V
:
T
.
Tensor
(
V_shape
,
dtype
=
input_dtype
),
V
:
T
.
Tensor
(
V_shape
,
dtype
=
input_dtype
),
Beta
:
T
.
Tensor
(
Beta_shape
,
dtype
=
input_dtype
),
Beta
:
T
.
Tensor
(
Beta_shape
,
dtype
=
input_dtype
),
G
:
T
.
Tensor
(
G_shape
,
dtype
=
gate_dtype
),
G
:
T
.
Tensor
(
G_shape
,
dtype
=
gate_dtype
),
A
:
T
.
Tensor
(
A_shape
,
dtype
=
input_dtype
),
A
:
T
.
Tensor
(
A_shape
,
dtype
=
input_dtype
),
dw
:
T
.
Tensor
(
dw_shape
,
dtype
=
input_dtype
),
dw
:
T
.
Tensor
(
dw_shape
,
dtype
=
input_dtype
),
du
:
T
.
Tensor
(
du_shape
,
dtype
=
input_dtype
),
du
:
T
.
Tensor
(
du_shape
,
dtype
=
input_dtype
),
dA
:
T
.
Tensor
(
dA_shape
,
dtype
=
input_dtype
),
dA
:
T
.
Tensor
(
dA_shape
,
dtype
=
input_dtype
),
dk
:
T
.
Tensor
(
dk_shape
,
dtype
=
output_dtype
),
dk
:
T
.
Tensor
(
dk_shape
,
dtype
=
output_dtype
),
dv
:
T
.
Tensor
(
dv_shape
,
dtype
=
output_dtype
),
dv
:
T
.
Tensor
(
dv_shape
,
dtype
=
output_dtype
),
dbeta_k
:
T
.
Tensor
(
dbeta_shape
,
dtype
=
output_dtype
),
dbeta_k
:
T
.
Tensor
(
dbeta_shape
,
dtype
=
output_dtype
),
dg_A_positive
:
T
.
Tensor
(
dA_shape
,
dtype
=
gate_dtype
),
dg_A_positive
:
T
.
Tensor
(
dA_shape
,
dtype
=
gate_dtype
),
dg_A_negative
:
T
.
Tensor
(
dA_shape
,
dtype
=
gate_dtype
),
dg_A_negative
:
T
.
Tensor
(
dA_shape
,
dtype
=
gate_dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
S
,
block_S
),
B
*
H
,
threads
=
threads
)
as
(
bs
,
bbh
):
with
T
.
Kernel
(
T
.
ceildiv
(
S
,
block_S
),
B
*
H
,
threads
=
threads
)
as
(
bs
,
bbh
):
bb
,
bh
=
bbh
//
H
,
bbh
%
H
bb
,
bh
=
bbh
//
H
,
bbh
%
H
...
@@ -350,7 +327,7 @@ def tilelang_wy_fast_bwd_split(
...
@@ -350,7 +327,7 @@ def tilelang_wy_fast_bwd_split(
T
.
clear
(
dA_A_fragment_1
)
T
.
clear
(
dA_A_fragment_1
)
T
.
clear
(
dA_A_fragment_2
)
T
.
clear
(
dA_A_fragment_2
)
T
.
copy
(
A
[
bb
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
,
bh
,
:],
A_shared
)
T
.
copy
(
A
[
bb
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
,
bh
,
:],
A_shared
)
for
i_s
in
T
.
Parallel
(
block_S
):
for
i_s
in
T
.
Parallel
(
block_S
):
Beta_shared
[
i_s
]
=
Beta
[
bb
,
bs
*
block_S
+
i_s
,
bh
]
Beta_shared
[
i_s
]
=
Beta
[
bb
,
bs
*
block_S
+
i_s
,
bh
]
G_shared
[
i_s
]
=
G
[
bb
,
bs
*
block_S
+
i_s
,
bh
]
G_shared
[
i_s
]
=
G
[
bb
,
bs
*
block_S
+
i_s
,
bh
]
...
@@ -361,7 +338,7 @@ def tilelang_wy_fast_bwd_split(
...
@@ -361,7 +338,7 @@ def tilelang_wy_fast_bwd_split(
# for i_s in T.Parallel(block_S):
# for i_s in T.Parallel(block_S):
# dbeta_fragment[i_s] = dbeta[bb, bs * block_S + i_s, bh]
# dbeta_fragment[i_s] = dbeta[bb, bs * block_S + i_s, bh]
# dg_fragment[i_s] = dg[bb, bs * block_S + i_s, bh]
# dg_fragment[i_s] = dg[bb, bs * block_S + i_s, bh]
T
.
copy
(
dA
[
bb
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
,
bh
,
:],
dA_shared
)
T
.
copy
(
dA
[
bb
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
,
bh
,
:],
dA_shared
)
# T.copy(dA_shared, dA[bb, bs * block_S:(bs + 1) * block_S, bh, :])
# T.copy(dA_shared, dA[bb, bs * block_S:(bs + 1) * block_S, bh, :])
# Update dA
# Update dA
...
@@ -385,8 +362,7 @@ def tilelang_wy_fast_bwd_split(
...
@@ -385,8 +362,7 @@ def tilelang_wy_fast_bwd_split(
for
i_s1
,
i_s2
in
T
.
Parallel
(
block_S
,
block_S
):
for
i_s1
,
i_s2
in
T
.
Parallel
(
block_S
,
block_S
):
with
T
.
If
(
G
[
bb
,
bs
*
block_S
+
i_s1
,
bh
]
-
G
[
bb
,
bs
*
block_S
+
i_s2
,
bh
]
<=
0
):
with
T
.
If
(
G
[
bb
,
bs
*
block_S
+
i_s1
,
bh
]
-
G
[
bb
,
bs
*
block_S
+
i_s2
,
bh
]
<=
0
):
with
T
.
Then
():
with
T
.
Then
():
dA_fragment
[
i_s1
,
i_s2
]
*=
T
.
exp
(
G
[
bb
,
bs
*
block_S
+
i_s1
,
bh
]
-
dA_fragment
[
i_s1
,
i_s2
]
*=
T
.
exp
(
G
[
bb
,
bs
*
block_S
+
i_s1
,
bh
]
-
G
[
bb
,
bs
*
block_S
+
i_s2
,
bh
])
G
[
bb
,
bs
*
block_S
+
i_s2
,
bh
])
with
T
.
Else
():
with
T
.
Else
():
dA_fragment
[
i_s1
,
i_s2
]
=
0
dA_fragment
[
i_s1
,
i_s2
]
=
0
T
.
copy
(
dA_fragment
,
dA_shared
)
T
.
copy
(
dA_fragment
,
dA_shared
)
...
@@ -397,12 +373,8 @@ def tilelang_wy_fast_bwd_split(
...
@@ -397,12 +373,8 @@ def tilelang_wy_fast_bwd_split(
# Update dk using previous dk
# Update dk using previous dk
T
.
clear
(
A_fragment
)
T
.
clear
(
A_fragment
)
for
i_k
in
T
.
Pipelined
(
T
.
ceildiv
(
DK
,
block_DK
),
num_stages
=
num_stages
):
for
i_k
in
T
.
Pipelined
(
T
.
ceildiv
(
DK
,
block_DK
),
num_stages
=
num_stages
):
T
.
copy
(
T
.
copy
(
K
[
bb
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
,
bh
,
i_k
*
block_DK
:
(
i_k
+
1
)
*
block_DK
],
K_shared
)
K
[
bb
,
bs
*
block_S
:(
bs
+
1
)
*
block_S
,
bh
,
i_k
*
block_DK
:(
i_k
+
1
)
*
block_DK
],
T
.
copy
(
dk
[
bb
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
,
bh
,
i_k
*
block_DK
:
(
i_k
+
1
)
*
block_DK
],
dk_shared
)
K_shared
)
T
.
copy
(
dk
[
bb
,
bs
*
block_S
:(
bs
+
1
)
*
block_S
,
bh
,
i_k
*
block_DK
:(
i_k
+
1
)
*
block_DK
],
dk_shared
)
T
.
copy
(
dk_shared
,
dk_fragment
)
T
.
copy
(
dk_shared
,
dk_fragment
)
for
i_s
,
i_k2
in
T
.
Parallel
(
block_S
,
block_DK
):
for
i_s
,
i_k2
in
T
.
Parallel
(
block_S
,
block_DK
):
K_shared_beta
[
i_s
,
i_k2
]
=
K_shared
[
i_s
,
i_k2
]
*
Beta_shared
[
i_s
]
K_shared_beta
[
i_s
,
i_k2
]
=
K_shared
[
i_s
,
i_k2
]
*
Beta_shared
[
i_s
]
...
@@ -411,18 +383,14 @@ def tilelang_wy_fast_bwd_split(
...
@@ -411,18 +383,14 @@ def tilelang_wy_fast_bwd_split(
# for i_s, i_k2 in T.Parallel(block_S, block_DK):
# for i_s, i_k2 in T.Parallel(block_S, block_DK):
# dbeta_fragment[i_s] = dbeta_fragment[i_s] + dk_fragment_beta[i_s, i_k2] * K_shared[i_s, i_k2]
# dbeta_fragment[i_s] = dbeta_fragment[i_s] + dk_fragment_beta[i_s, i_k2] * K_shared[i_s, i_k2]
for
i_s
,
i_k2
in
T
.
Parallel
(
block_S
,
block_DK
):
for
i_s
,
i_k2
in
T
.
Parallel
(
block_S
,
block_DK
):
dbeta_fragment_reduce_tmpk
[
i_s
,
dbeta_fragment_reduce_tmpk
[
i_s
,
i_k2
]
=
dk_fragment_beta
[
i_s
,
i_k2
]
*
K_shared
[
i_s
,
i_k2
]
i_k2
]
=
dk_fragment_beta
[
i_s
,
i_k2
]
*
K_shared
[
i_s
,
i_k2
]
T
.
reduce_sum
(
dbeta_fragment_reduce_tmpk
,
dbeta_fragment_k
,
dim
=
1
,
clear
=
False
)
T
.
reduce_sum
(
dbeta_fragment_reduce_tmpk
,
dbeta_fragment_k
,
dim
=
1
,
clear
=
False
)
T
.
gemm
(
dA_shared
,
K_shared_beta
,
dk_fragment
,
transpose_A
=
True
)
T
.
gemm
(
dA_shared
,
K_shared_beta
,
dk_fragment
,
transpose_A
=
True
)
for
i_s
,
i_k2
in
T
.
Parallel
(
block_S
,
block_DK
):
for
i_s
,
i_k2
in
T
.
Parallel
(
block_S
,
block_DK
):
dk_shared_beta
[
i_s
,
i_k2
]
=
dk_fragment_beta
[
i_s
,
i_k2
]
*
Beta_shared
[
i_s
]
dk_shared_beta
[
i_s
,
i_k2
]
=
dk_fragment_beta
[
i_s
,
i_k2
]
*
Beta_shared
[
i_s
]
for
i_s
,
i_k2
in
T
.
Parallel
(
block_S
,
block_DK
):
for
i_s
,
i_k2
in
T
.
Parallel
(
block_S
,
block_DK
):
dk_fragment
[
i_s
,
i_k2
]
=
dk_fragment
[
i_s
,
i_k2
]
+
dk_shared_beta
[
i_s
,
i_k2
]
dk_fragment
[
i_s
,
i_k2
]
=
dk_fragment
[
i_s
,
i_k2
]
+
dk_shared_beta
[
i_s
,
i_k2
]
T
.
copy
(
T
.
copy
(
dk_fragment
,
dk
[
bb
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
,
bh
,
i_k
*
block_DK
:
(
i_k
+
1
)
*
block_DK
])
dk_fragment
,
dk
[
bb
,
bs
*
block_S
:(
bs
+
1
)
*
block_S
,
bh
,
i_k
*
block_DK
:(
i_k
+
1
)
*
block_DK
])
# Update dg and dbeta
# Update dg and dbeta
T
.
copy
(
A_fragment
,
A_shared
)
T
.
copy
(
A_fragment
,
A_shared
)
...
@@ -460,19 +428,25 @@ def run_test(
...
@@ -460,19 +428,25 @@ def run_test(
threads
=
128
,
threads
=
128
,
num_stages
=
0
,
num_stages
=
0
,
):
):
K
,
V
,
Beta
,
G
,
A
,
dw
,
du
=
prepare_input
(
B
,
S
,
H
,
DK
,
DV
,
chunk_size
,
K
,
V
,
Beta
,
G
,
A
,
dw
,
du
=
prepare_input
(
getattr
(
torch
,
input_dtype
),
B
,
getattr
(
torch
,
output_dtype
),
S
,
getattr
(
torch
,
H
,
accum_dtype
),
getattr
(
torch
,
gate_dtype
),
DK
,
getattr
(
torch
,
state_dtype
))
DV
,
dk_ref
,
dv_ref
,
dbeta_ref
,
dg_ref
=
prepare_output
(
B
,
S
,
H
,
DK
,
DV
,
chunk_size
,
chunk_size
,
getattr
(
torch
,
output_dtype
),
getattr
(
torch
,
input_dtype
),
getattr
(
torch
,
gate_dtype
),
getattr
(
torch
,
output_dtype
),
getattr
(
torch
,
state_dtype
))
getattr
(
torch
,
accum_dtype
),
getattr
(
torch
,
gate_dtype
),
getattr
(
torch
,
state_dtype
),
)
dk_ref
,
dv_ref
,
dbeta_ref
,
dg_ref
=
prepare_output
(
B
,
S
,
H
,
DK
,
DV
,
chunk_size
,
getattr
(
torch
,
output_dtype
),
getattr
(
torch
,
gate_dtype
),
getattr
(
torch
,
state_dtype
)
)
dk_tilelang
,
dv_tilelang
,
dbeta_tilelang
,
dg_tilelang
=
prepare_output
(
dk_tilelang
,
dv_tilelang
,
dbeta_tilelang
,
dg_tilelang
=
prepare_output
(
B
,
S
,
H
,
DK
,
DV
,
chunk_size
,
getattr
(
torch
,
output_dtype
),
getattr
(
torch
,
gate_dtype
),
B
,
S
,
H
,
DK
,
DV
,
chunk_size
,
getattr
(
torch
,
output_dtype
),
getattr
(
torch
,
gate_dtype
),
getattr
(
torch
,
state_dtype
)
getattr
(
torch
,
state_dtype
)
)
)
BS
=
chunk_size
BS
=
chunk_size
dA_tilelang
=
torch
.
empty
(
B
,
S
,
H
,
BS
,
dtype
=
getattr
(
torch
,
input_dtype
)).
cuda
()
dA_tilelang
=
torch
.
empty
(
B
,
S
,
H
,
BS
,
dtype
=
getattr
(
torch
,
input_dtype
)).
cuda
()
dbeta_tilelang_k
=
torch
.
empty
(
B
,
S
,
H
,
dtype
=
getattr
(
torch
,
output_dtype
)).
cuda
()
dbeta_tilelang_k
=
torch
.
empty
(
B
,
S
,
H
,
dtype
=
getattr
(
torch
,
output_dtype
)).
cuda
()
...
@@ -480,28 +454,55 @@ def run_test(
...
@@ -480,28 +454,55 @@ def run_test(
dg_tilelang_A_negative
=
torch
.
empty
(
B
,
S
,
H
,
BS
,
dtype
=
getattr
(
torch
,
gate_dtype
)).
cuda
()
dg_tilelang_A_negative
=
torch
.
empty
(
B
,
S
,
H
,
BS
,
dtype
=
getattr
(
torch
,
gate_dtype
)).
cuda
()
# ref
# ref
dk_ref
,
dv_ref
,
dbeta_ref
,
dg_ref
=
bwd_prepare_wy_repr
(
dk_ref
,
dv_ref
,
dbeta_ref
,
dg_ref
=
bwd_prepare_wy_repr
(
K
,
V
,
G
,
Beta
,
A
,
dw
,
du
,
cu_seqlens
=
None
)
K
,
V
,
G
,
Beta
,
A
,
dw
,
du
,
cu_seqlens
=
None
)
# tilelang
# tilelang
kernel
=
tilelang_wy_fast_bwd
(
B
,
S
,
H
,
DK
,
DV
,
input_dtype
,
output_dtype
,
accum_dtype
,
kernel
=
tilelang_wy_fast_bwd
(
gate_dtype
,
state_dtype
,
chunk_size
,
block_DK
,
block_DV
,
threads
,
B
,
num_stages
)
S
,
dA_tilelang
,
dk_tilelang
,
dv_tilelang
,
dbeta_tilelang
,
dg_tilelang
=
kernel
(
H
,
K
,
V
,
Beta
,
G
,
A
,
dw
,
du
)
DK
,
DV
,
input_dtype
,
output_dtype
,
accum_dtype
,
gate_dtype
,
state_dtype
,
chunk_size
,
block_DK
,
block_DV
,
threads
,
num_stages
,
)
dA_tilelang
,
dk_tilelang
,
dv_tilelang
,
dbeta_tilelang
,
dg_tilelang
=
kernel
(
K
,
V
,
Beta
,
G
,
A
,
dw
,
du
)
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
kernel_split
=
tilelang_wy_fast_bwd_split
(
B
,
S
,
H
,
DK
,
DV
,
input_dtype
,
output_dtype
,
kernel_split
=
tilelang_wy_fast_bwd_split
(
accum_dtype
,
gate_dtype
,
state_dtype
,
chunk_size
,
B
,
block_DK
,
block_DV
,
threads
,
num_stages
)
S
,
kernel_split
(
K
,
V
,
Beta
,
G
,
A
,
dw
,
du
,
dA_tilelang
,
dk_tilelang
,
dv_tilelang
,
dbeta_tilelang_k
,
H
,
dg_tilelang_A_positive
,
dg_tilelang_A_negative
)
DK
,
DV
,
input_dtype
,
output_dtype
,
accum_dtype
,
gate_dtype
,
state_dtype
,
chunk_size
,
block_DK
,
block_DV
,
threads
,
num_stages
,
)
kernel_split
(
K
,
V
,
Beta
,
G
,
A
,
dw
,
du
,
dA_tilelang
,
dk_tilelang
,
dv_tilelang
,
dbeta_tilelang_k
,
dg_tilelang_A_positive
,
dg_tilelang_A_negative
)
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
dbeta_tilelang
=
dbeta_tilelang_k
+
dbeta_tilelang
dbeta_tilelang
=
dbeta_tilelang_k
+
dbeta_tilelang
dg_tilelang
=
dg_tilelang
+
dg_tilelang_A_positive
.
sum
(
dim
=-
1
)
-
dg_tilelang_A_negative
.
sum
(
dg_tilelang
=
dg_tilelang
+
dg_tilelang_A_positive
.
sum
(
dim
=-
1
)
-
dg_tilelang_A_negative
.
sum
(
dim
=-
1
)
dim
=-
1
)
from
test_utils
import
assert_similar
from
utils
import
assert_similar
assert_similar
(
dk_ref
,
dk_tilelang
,
eps
=
1e-5
,
name
=
"dk"
,
raise_assert
=
False
)
assert_similar
(
dk_ref
,
dk_tilelang
,
eps
=
1e-5
,
name
=
"dk"
,
raise_assert
=
False
)
assert_similar
(
dv_ref
,
dv_tilelang
,
eps
=
1e-5
,
name
=
"dv"
,
raise_assert
=
False
)
assert_similar
(
dv_ref
,
dv_tilelang
,
eps
=
1e-5
,
name
=
"dv"
,
raise_assert
=
False
)
assert_similar
(
dbeta_ref
,
dbeta_tilelang
,
eps
=
1e-5
,
name
=
"dbeta"
,
raise_assert
=
False
)
assert_similar
(
dbeta_ref
,
dbeta_tilelang
,
eps
=
1e-5
,
name
=
"dbeta"
,
raise_assert
=
False
)
...
@@ -517,11 +518,11 @@ def main():
...
@@ -517,11 +518,11 @@ def main():
H
=
8
,
H
=
8
,
DK
=
DK
,
DK
=
DK
,
DV
=
DV
,
DV
=
DV
,
input_dtype
=
"
bfloat16
"
,
input_dtype
=
T
.
bfloat16
,
output_dtype
=
"
bfloat16
"
,
output_dtype
=
T
.
bfloat16
,
accum_dtype
=
"
float32
"
,
accum_dtype
=
T
.
float32
,
gate_dtype
=
"
float32
"
,
gate_dtype
=
T
.
float32
,
state_dtype
=
"
float32
"
,
state_dtype
=
T
.
float32
,
chunk_size
=
64
,
chunk_size
=
64
,
block_DK
=
32
,
block_DK
=
32
,
block_DV
=
32
,
block_DV
=
32
,
...
...
examples/gdn/test_example_gdn_compilation.py
View file @
667632cc
import
tilelang.testing
import
torch
import
torch
import
tilelang.testing
from
tilelang
import
language
as
T
B
=
1
B
=
1
S
=
1024
# small but for test only.
S
=
1024
# small but for test only.
H
=
32
H
=
32
DK
=
128
DK
=
128
DV
=
128
DV
=
128
input_dtype
=
"
bfloat16
"
input_dtype
=
T
.
bfloat16
output_dtype
=
"
bfloat16
"
output_dtype
=
T
.
bfloat16
accum_dtype
=
"
float32
"
accum_dtype
=
T
.
float32
gate_dtype
=
"
float32
"
gate_dtype
=
T
.
float32
state_dtype
=
"
float32
"
state_dtype
=
T
.
float32
chunk_size
=
64
chunk_size
=
64
use_g
=
True
use_g
=
True
use_initial_state
=
True
use_initial_state
=
True
...
@@ -25,16 +26,10 @@ num_stages = 1
...
@@ -25,16 +26,10 @@ num_stages = 1
def
test_example_wy_fast_compilation
():
def
test_example_wy_fast_compilation
():
from
example_wy_fast
import
tilelang_recompute_w_u_fwd
,
prepare_input
from
example_wy_fast
import
tilelang_recompute_w_u_fwd
,
prepare_input
K
,
V
,
Beta
,
G
,
A
=
prepare_input
(
K
,
V
,
Beta
,
G
,
A
=
prepare_input
(
B
,
B
,
S
,
H
,
DK
,
DV
,
chunk_size
,
getattr
(
torch
,
input_dtype
),
getattr
(
torch
,
output_dtype
),
gate_dtype
=
getattr
(
torch
,
gate_dtype
)
S
,
)
H
,
DK
,
DV
,
chunk_size
,
getattr
(
torch
,
input_dtype
),
getattr
(
torch
,
output_dtype
),
gate_dtype
=
getattr
(
torch
,
gate_dtype
))
# tilelang
# tilelang
block_S
=
chunk_size
block_S
=
chunk_size
kernel
=
tilelang_recompute_w_u_fwd
(
kernel
=
tilelang_recompute_w_u_fwd
(
...
@@ -52,22 +47,31 @@ def test_example_wy_fast_compilation():
...
@@ -52,22 +47,31 @@ def test_example_wy_fast_compilation():
block_DK
=
block_DK
,
block_DK
=
block_DK
,
block_DV
=
block_DV
,
block_DV
=
block_DV
,
threads
=
threads
,
threads
=
threads
,
num_stages
=
num_stages
)
num_stages
=
num_stages
,
)
print
(
kernel
.
get_kernel_source
())
print
(
kernel
.
get_kernel_source
())
W_tilelang
,
U_tilelang
=
kernel
(
K
,
V
,
Beta
,
G
,
A
)
W_tilelang
,
U_tilelang
=
kernel
(
K
,
V
,
Beta
,
G
,
A
)
def
test_example_wy_fast_bwd_split_compilation
():
def
test_example_wy_fast_bwd_split_compilation
():
from
example_wy_fast_bwd_split
import
tilelang_wy_fast_bwd
,
tilelang_wy_fast_bwd_split
,
prepare_input
,
prepare_output
from
example_wy_fast_bwd_split
import
tilelang_wy_fast_bwd
,
tilelang_wy_fast_bwd_split
,
prepare_input
,
prepare_output
K
,
V
,
Beta
,
G
,
A
,
dw
,
du
=
prepare_input
(
B
,
S
,
H
,
DK
,
DV
,
chunk_size
,
getattr
(
torch
,
input_dtype
),
K
,
V
,
Beta
,
G
,
A
,
dw
,
du
=
prepare_input
(
getattr
(
torch
,
output_dtype
),
B
,
getattr
(
torch
,
S
,
accum_dtype
),
getattr
(
torch
,
gate_dtype
),
H
,
getattr
(
torch
,
state_dtype
))
DK
,
DV
,
chunk_size
,
getattr
(
torch
,
input_dtype
),
getattr
(
torch
,
output_dtype
),
getattr
(
torch
,
accum_dtype
),
getattr
(
torch
,
gate_dtype
),
getattr
(
torch
,
state_dtype
),
)
dk_tilelang
,
dv_tilelang
,
dbeta_tilelang
,
dg_tilelang
=
prepare_output
(
dk_tilelang
,
dv_tilelang
,
dbeta_tilelang
,
dg_tilelang
=
prepare_output
(
B
,
S
,
H
,
DK
,
DV
,
chunk_size
,
getattr
(
torch
,
output_dtype
),
getattr
(
torch
,
gate_dtype
),
B
,
S
,
H
,
DK
,
DV
,
chunk_size
,
getattr
(
torch
,
output_dtype
),
getattr
(
torch
,
gate_dtype
),
getattr
(
torch
,
state_dtype
)
getattr
(
torch
,
state_dtype
)
)
)
BS
=
chunk_size
BS
=
chunk_size
dA_tilelang
=
torch
.
empty
(
B
,
S
,
H
,
BS
,
dtype
=
getattr
(
torch
,
input_dtype
)).
cuda
()
dA_tilelang
=
torch
.
empty
(
B
,
S
,
H
,
BS
,
dtype
=
getattr
(
torch
,
input_dtype
)).
cuda
()
dbeta_tilelang_k
=
torch
.
empty
(
B
,
S
,
H
,
dtype
=
getattr
(
torch
,
output_dtype
)).
cuda
()
dbeta_tilelang_k
=
torch
.
empty
(
B
,
S
,
H
,
dtype
=
getattr
(
torch
,
output_dtype
)).
cuda
()
...
@@ -75,67 +79,146 @@ def test_example_wy_fast_bwd_split_compilation():
...
@@ -75,67 +79,146 @@ def test_example_wy_fast_bwd_split_compilation():
dg_tilelang_A_negative
=
torch
.
empty
(
B
,
S
,
H
,
BS
,
dtype
=
getattr
(
torch
,
gate_dtype
)).
cuda
()
dg_tilelang_A_negative
=
torch
.
empty
(
B
,
S
,
H
,
BS
,
dtype
=
getattr
(
torch
,
gate_dtype
)).
cuda
()
# tilelang
# tilelang
kernel
=
tilelang_wy_fast_bwd
(
B
,
S
,
H
,
DK
,
DV
,
input_dtype
,
output_dtype
,
accum_dtype
,
kernel
=
tilelang_wy_fast_bwd
(
gate_dtype
,
state_dtype
,
chunk_size
,
block_DK
,
block_DV
,
threads
,
B
,
num_stages
)
S
,
dA_tilelang
,
dk_tilelang
,
dv_tilelang
,
dbeta_tilelang
,
dg_tilelang
=
kernel
(
H
,
K
,
V
,
Beta
,
G
,
A
,
dw
,
du
)
DK
,
DV
,
input_dtype
,
output_dtype
,
accum_dtype
,
gate_dtype
,
state_dtype
,
chunk_size
,
block_DK
,
block_DV
,
threads
,
num_stages
,
)
dA_tilelang
,
dk_tilelang
,
dv_tilelang
,
dbeta_tilelang
,
dg_tilelang
=
kernel
(
K
,
V
,
Beta
,
G
,
A
,
dw
,
du
)
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
kernel_split
=
tilelang_wy_fast_bwd_split
(
B
,
S
,
H
,
DK
,
DV
,
input_dtype
,
output_dtype
,
kernel_split
=
tilelang_wy_fast_bwd_split
(
accum_dtype
,
gate_dtype
,
state_dtype
,
chunk_size
,
B
,
block_DK
,
block_DV
,
threads
,
num_stages
)
S
,
kernel_split
(
K
,
V
,
Beta
,
G
,
A
,
dw
,
du
,
dA_tilelang
,
dk_tilelang
,
dv_tilelang
,
dbeta_tilelang_k
,
H
,
dg_tilelang_A_positive
,
dg_tilelang_A_negative
)
DK
,
DV
,
input_dtype
,
output_dtype
,
accum_dtype
,
gate_dtype
,
state_dtype
,
chunk_size
,
block_DK
,
block_DV
,
threads
,
num_stages
,
)
kernel_split
(
K
,
V
,
Beta
,
G
,
A
,
dw
,
du
,
dA_tilelang
,
dk_tilelang
,
dv_tilelang
,
dbeta_tilelang_k
,
dg_tilelang_A_positive
,
dg_tilelang_A_negative
)
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
dbeta_tilelang
=
dbeta_tilelang_k
+
dbeta_tilelang
dbeta_tilelang
=
dbeta_tilelang_k
+
dbeta_tilelang
dg_tilelang
=
dg_tilelang
+
dg_tilelang_A_positive
.
sum
(
dim
=-
1
)
-
dg_tilelang_A_negative
.
sum
(
dg_tilelang
=
dg_tilelang
+
dg_tilelang_A_positive
.
sum
(
dim
=-
1
)
-
dg_tilelang_A_negative
.
sum
(
dim
=-
1
)
dim
=-
1
)
def
test_example_chunk_o_compilation
():
def
test_example_chunk_o_compilation
():
from
example_chunk_o
import
tilelang_chunk_fwd_o
,
prepare_input
from
example_chunk_o
import
tilelang_chunk_fwd_o
,
prepare_input
Q
,
K
,
V
,
HIDDEN
,
G
=
prepare_input
(
B
,
S
,
H
,
DK
,
DV
,
chunk_size
,
getattr
(
torch
,
input_dtype
),
getattr
(
torch
,
output_dtype
),
getattr
(
torch
,
accum_dtype
),
Q
,
K
,
V
,
HIDDEN
,
G
=
prepare_input
(
getattr
(
torch
,
gate_dtype
))
B
,
S
,
H
,
DK
,
DV
,
chunk_size
,
getattr
(
torch
,
input_dtype
),
getattr
(
torch
,
output_dtype
),
getattr
(
torch
,
accum_dtype
),
getattr
(
torch
,
gate_dtype
),
)
scale
=
1.0
/
DK
**
0.5
scale
=
1.0
/
DK
**
0.5
block_S
=
chunk_size
block_S
=
chunk_size
kernel
=
tilelang_chunk_fwd_o
(
B
,
S
,
H
,
DK
,
DV
,
input_dtype
,
output_dtype
,
accum_dtype
,
kernel
=
tilelang_chunk_fwd_o
(
gate_dtype
,
chunk_size
,
scale
,
use_g
,
block_S
,
block_DK
,
block_DV
,
B
,
threads
,
num_stages
)
S
,
H
,
DK
,
DV
,
input_dtype
,
output_dtype
,
accum_dtype
,
gate_dtype
,
chunk_size
,
scale
,
use_g
,
block_S
,
block_DK
,
block_DV
,
threads
,
num_stages
,
)
O_tilelang
=
kernel
(
Q
,
K
,
V
,
HIDDEN
,
G
)
# noqa: F841
O_tilelang
=
kernel
(
Q
,
K
,
V
,
HIDDEN
,
G
)
# noqa: F841
def
test_example_chunk_o_bwd_compilation
():
def
test_example_chunk_o_bwd_compilation
():
from
example_chunk_o_bwd
import
tilelang_chunk_o_bwd_dqkwg
,
prepare_input
from
example_chunk_o_bwd
import
tilelang_chunk_o_bwd_dqkwg
,
prepare_input
Q
,
K
,
V
,
h
,
G
,
dO
,
dh
,
dv
,
W
=
prepare_input
(
B
,
S
,
H
,
DK
,
DV
,
chunk_size
,
getattr
(
torch
,
input_dtype
),
Q
,
K
,
V
,
h
,
G
,
dO
,
dh
,
dv
,
W
=
prepare_input
(
getattr
(
torch
,
output_dtype
),
B
,
getattr
(
torch
,
accum_dtype
),
S
,
getattr
(
torch
,
gate_dtype
),
H
,
getattr
(
torch
,
state_dtype
))
DK
,
kernel
=
tilelang_chunk_o_bwd_dqkwg
(
B
,
S
,
H
,
DK
,
DV
,
input_dtype
,
output_dtype
,
accum_dtype
,
DV
,
gate_dtype
,
state_dtype
,
chunk_size
,
1.0
,
use_g
,
True
,
chunk_size
,
block_DK
,
block_DV
,
threads
,
num_stages
)
getattr
(
torch
,
input_dtype
),
dq_tilelang
,
dk_tilelang
,
dw_tilelang
,
dg_tilelang
=
kernel
(
Q
,
K
,
V
,
h
,
G
,
dO
,
dh
,
dv
,
getattr
(
torch
,
output_dtype
),
W
)
# noqa: F841
getattr
(
torch
,
accum_dtype
),
getattr
(
torch
,
gate_dtype
),
getattr
(
torch
,
state_dtype
),
)
kernel
=
tilelang_chunk_o_bwd_dqkwg
(
B
,
S
,
H
,
DK
,
DV
,
input_dtype
,
output_dtype
,
accum_dtype
,
gate_dtype
,
state_dtype
,
chunk_size
,
1.0
,
use_g
,
True
,
block_DK
,
block_DV
,
threads
,
num_stages
,
)
dq_tilelang
,
dk_tilelang
,
dw_tilelang
,
dg_tilelang
=
kernel
(
Q
,
K
,
V
,
h
,
G
,
dO
,
dh
,
dv
,
W
)
# noqa: F841
if
use_g
:
if
use_g
:
dg_tilelang
=
dg_tilelang
.
sum
(
dim
=
0
)
dg_tilelang
=
dg_tilelang
.
sum
(
dim
=
0
)
def
test_example_chunk_scaled_dot_kkt_compilation
():
def
test_example_chunk_scaled_dot_kkt_compilation
():
from
example_chunk_scaled_dot_kkt
import
tilelang_chunk_scaled_dot_kkt_fwd
,
prepare_input
from
example_chunk_scaled_dot_kkt
import
tilelang_chunk_scaled_dot_kkt_fwd
,
prepare_input
K
,
Beta
,
G
=
prepare_input
(
B
,
S
,
H
,
DK
,
getattr
(
torch
,
input_dtype
),
getattr
(
torch
,
output_dtype
),
getattr
(
torch
,
accum_dtype
))
K
,
Beta
,
G
=
prepare_input
(
B
,
S
,
H
,
DK
,
getattr
(
torch
,
input_dtype
),
getattr
(
torch
,
output_dtype
),
getattr
(
torch
,
accum_dtype
))
block_S
=
chunk_size
block_S
=
chunk_size
kernel
=
tilelang_chunk_scaled_dot_kkt_fwd
(
B
,
S
,
H
,
DK
,
chunk_size
,
input_dtype
,
output_dtype
,
kernel
=
tilelang_chunk_scaled_dot_kkt_fwd
(
accum_dtype
,
use_g
,
block_S
,
block_DK
,
threads
,
B
,
S
,
H
,
DK
,
chunk_size
,
input_dtype
,
output_dtype
,
accum_dtype
,
use_g
,
block_S
,
block_DK
,
threads
,
num_stages
num_stages
)
)
A_tilelang
=
kernel
(
K
,
Beta
,
G
)
# noqa: F841
A_tilelang
=
kernel
(
K
,
Beta
,
G
)
# noqa: F841
def
test_example_cumsum_compilation
():
def
test_example_cumsum_compilation
():
from
example_cumsum
import
tilelang_chunk_local_cumsum_scalar
,
prepare_cumsum_input
,
prepare_cumsum_output
from
example_cumsum
import
tilelang_chunk_local_cumsum_scalar
,
prepare_cumsum_input
,
prepare_cumsum_output
G
=
prepare_cumsum_input
(
B
,
S
,
H
,
getattr
(
torch
,
gate_dtype
))
G
=
prepare_cumsum_input
(
B
,
S
,
H
,
getattr
(
torch
,
gate_dtype
))
G_new_tilelang
=
prepare_cumsum_output
(
B
,
S
,
H
,
getattr
(
torch
,
gate_dtype
))
G_new_tilelang
=
prepare_cumsum_output
(
B
,
S
,
H
,
getattr
(
torch
,
gate_dtype
))
block_S
=
chunk_size
block_S
=
chunk_size
...
@@ -157,33 +240,79 @@ def test_example_cumsum_compilation():
...
@@ -157,33 +240,79 @@ def test_example_cumsum_compilation():
def
test_example_chunk_delta_h_compilation
():
def
test_example_chunk_delta_h_compilation
():
from
example_chunk_delta_h
import
tilelang_chunk_gated_delta_rule_fwd_h
,
prepare_input
from
example_chunk_delta_h
import
tilelang_chunk_gated_delta_rule_fwd_h
,
prepare_input
K
,
W
,
U
,
G
,
initial_state
=
prepare_input
(
B
,
S
,
H
,
DK
,
DV
,
chunk_size
,
getattr
(
torch
,
input_dtype
),
K
,
W
,
U
,
G
,
initial_state
=
prepare_input
(
getattr
(
torch
,
output_dtype
),
B
,
getattr
(
torch
,
accum_dtype
),
S
,
getattr
(
torch
,
gate_dtype
))
H
,
kernel
=
tilelang_chunk_gated_delta_rule_fwd_h
(
B
,
S
,
H
,
DK
,
DV
,
input_dtype
,
output_dtype
,
DK
,
accum_dtype
,
gate_dtype
,
state_dtype
,
chunk_size
,
DV
,
use_g
,
use_initial_state
,
store_final_state
,
chunk_size
,
save_new_value
,
block_DK
,
block_DV
,
threads
,
getattr
(
torch
,
input_dtype
),
num_stages
)
getattr
(
torch
,
output_dtype
),
h_tilelang
,
final_state_tilelang
,
V_new_tilelang
=
kernel
(
K
,
W
,
U
,
G
,
getattr
(
torch
,
accum_dtype
),
initial_state
)
# noqa: F841
getattr
(
torch
,
gate_dtype
),
)
kernel
=
tilelang_chunk_gated_delta_rule_fwd_h
(
B
,
S
,
H
,
DK
,
DV
,
input_dtype
,
output_dtype
,
accum_dtype
,
gate_dtype
,
state_dtype
,
chunk_size
,
use_g
,
use_initial_state
,
store_final_state
,
save_new_value
,
block_DK
,
block_DV
,
threads
,
num_stages
,
)
h_tilelang
,
final_state_tilelang
,
V_new_tilelang
=
kernel
(
K
,
W
,
U
,
G
,
initial_state
)
# noqa: F841
def
test_example_chunk_delta_bwd_compilation
():
def
test_example_chunk_delta_bwd_compilation
():
from
example_chunk_delta_bwd
import
tilelang_chunk_gated_delta_rule_bwd_dhu
,
prepare_input
from
example_chunk_delta_bwd
import
tilelang_chunk_gated_delta_rule_bwd_dhu
,
prepare_input
Q
,
K
,
W
,
G
,
h0
,
dht
,
dO
,
dv
=
prepare_input
(
B
,
S
,
H
,
DK
,
DV
,
chunk_size
,
getattr
(
torch
,
input_dtype
),
Q
,
K
,
W
,
G
,
h0
,
dht
,
dO
,
dv
=
prepare_input
(
getattr
(
torch
,
output_dtype
),
B
,
getattr
(
torch
,
accum_dtype
),
S
,
getattr
(
torch
,
gate_dtype
),
H
,
getattr
(
torch
,
state_dtype
))
DK
,
kernel
=
tilelang_chunk_gated_delta_rule_bwd_dhu
(
B
,
S
,
H
,
DK
,
DV
,
input_dtype
,
output_dtype
,
DV
,
accum_dtype
,
gate_dtype
,
state_dtype
,
chunk_size
,
chunk_size
,
1.0
,
use_g
,
use_initial_state
,
getattr
(
torch
,
input_dtype
),
use_final_state_gradient
,
block_DV
,
threads
,
getattr
(
torch
,
output_dtype
),
num_stages
)
getattr
(
torch
,
accum_dtype
),
getattr
(
torch
,
gate_dtype
),
getattr
(
torch
,
state_dtype
),
)
kernel
=
tilelang_chunk_gated_delta_rule_bwd_dhu
(
B
,
S
,
H
,
DK
,
DV
,
input_dtype
,
output_dtype
,
accum_dtype
,
gate_dtype
,
state_dtype
,
chunk_size
,
1.0
,
use_g
,
use_initial_state
,
use_final_state_gradient
,
block_DV
,
threads
,
num_stages
,
)
dh_tilelang
,
dh0_tilelang
,
dv2_tilelang
=
kernel
(
Q
,
K
,
W
,
G
,
h0
,
dht
,
dO
,
dv
)
# noqa: F841
dh_tilelang
,
dh0_tilelang
,
dv2_tilelang
=
kernel
(
Q
,
K
,
W
,
G
,
h0
,
dht
,
dO
,
dv
)
# noqa: F841
...
...
examples/gdn/utils.py
→
examples/gdn/
test_
utils.py
View file @
667632cc
...
@@ -9,7 +9,7 @@ def calc_sim(x, y, name="tensor"):
...
@@ -9,7 +9,7 @@ def calc_sim(x, y, name="tensor"):
x
,
y
=
x
.
data
.
double
(),
y
.
data
.
double
()
x
,
y
=
x
.
data
.
double
(),
y
.
data
.
double
()
denominator
=
(
x
*
x
+
y
*
y
).
sum
()
denominator
=
(
x
*
x
+
y
*
y
).
sum
()
if
denominator
==
0
:
if
denominator
==
0
:
print_red_warning
(
f
'
{
name
}
all zero
'
)
print_red_warning
(
f
"
{
name
}
all zero
"
)
return
1
return
1
sim
=
2
*
(
x
*
y
).
sum
()
/
denominator
sim
=
2
*
(
x
*
y
).
sum
()
/
denominator
return
sim
return
sim
...
@@ -19,21 +19,19 @@ def assert_similar(x, y, eps=1e-8, name="tensor", data="", raise_assert=True):
...
@@ -19,21 +19,19 @@ def assert_similar(x, y, eps=1e-8, name="tensor", data="", raise_assert=True):
x_mask
=
torch
.
isfinite
(
x
)
x_mask
=
torch
.
isfinite
(
x
)
y_mask
=
torch
.
isfinite
(
y
)
y_mask
=
torch
.
isfinite
(
y
)
if
not
torch
.
all
(
x_mask
==
y_mask
):
if
not
torch
.
all
(
x_mask
==
y_mask
):
print_red_warning
(
f
'
{
name
}
Error: isfinite mask mismatch
'
)
print_red_warning
(
f
"
{
name
}
Error: isfinite mask mismatch
"
)
if
raise_assert
:
if
raise_assert
:
raise
AssertionError
raise
AssertionError
if
not
torch
.
isclose
(
if
not
torch
.
isclose
(
x
.
masked_fill
(
x_mask
,
0
),
y
.
masked_fill
(
y_mask
,
0
),
rtol
=
0
,
atol
=
0
,
equal_nan
=
True
).
all
():
x
.
masked_fill
(
x_mask
,
0
),
y
.
masked_fill
(
y_mask
,
0
),
rtol
=
0
,
atol
=
0
,
print_red_warning
(
f
"
{
name
}
Error: nonfinite value mismatch"
)
equal_nan
=
True
).
all
():
print_red_warning
(
f
'
{
name
}
Error: nonfinite value mismatch'
)
if
raise_assert
:
if
raise_assert
:
raise
AssertionError
raise
AssertionError
x
=
x
.
masked_fill
(
~
x_mask
,
0
)
x
=
x
.
masked_fill
(
~
x_mask
,
0
)
y
=
y
.
masked_fill
(
~
y_mask
,
0
)
y
=
y
.
masked_fill
(
~
y_mask
,
0
)
sim
=
calc_sim
(
x
,
y
,
name
)
sim
=
calc_sim
(
x
,
y
,
name
)
diff
=
1.
-
sim
diff
=
1.
0
-
sim
if
not
(
0
<=
diff
<=
eps
):
if
not
(
0
<=
diff
<=
eps
):
print_red_warning
(
f
'
{
name
}
Error:
{
diff
}
'
)
print_red_warning
(
f
"
{
name
}
Error:
{
diff
}
"
)
if
raise_assert
:
if
raise_assert
:
raise
AssertionError
raise
AssertionError
else
:
else
:
...
...
examples/gemm/README.md
View file @
667632cc
...
@@ -53,7 +53,7 @@ import tilelang
...
@@ -53,7 +53,7 @@ import tilelang
from
tilelang
import
Profiler
from
tilelang
import
Profiler
import
tilelang.language
as
T
import
tilelang.language
as
T
def
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
=
"
float16
"
,
accum_dtype
=
"
float
"
):
def
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
=
T
.
float16
,
accum_dtype
=
T
.
float
):
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
...
@@ -176,7 +176,7 @@ import tilelang.language as T
...
@@ -176,7 +176,7 @@ import tilelang.language as T
# that helps align data for MMA (Matrix Multiply-Accumulate) operations.
# that helps align data for MMA (Matrix Multiply-Accumulate) operations.
from
tilelang.intrinsics
import
make_mma_swizzle_layout
as
make_swizzle_layout
from
tilelang.intrinsics
import
make_mma_swizzle_layout
as
make_swizzle_layout
def
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
=
"
float16
"
,
accum_dtype
=
"
float
"
):
def
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
=
T
.
float16
,
accum_dtype
=
T
.
float
):
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
...
@@ -265,18 +265,18 @@ def tl_matmul(
...
@@ -265,18 +265,18 @@ def tl_matmul(
accum_dtype
,
accum_dtype
,
):
):
assert
in_dtype
in
[
assert
in_dtype
in
[
"
float16
"
,
T
.
float16
,
"
int8
"
,
T
.
int8
,
],
"Currently only float16 and int8 are supported"
],
"Currently only float16 and int8 are supported"
assert
out_dtype
in
[
assert
out_dtype
in
[
"
float16
"
,
T
.
float16
,
"
float32
"
,
T
.
float32
,
"
int32
"
,
T
.
int32
,
],
"Currently only float16, float32 and int32 are supported"
],
"Currently only float16, float32 and int32 are supported"
micro_size_x
=
micro_size_y
=
micro_size_k
=
16
micro_size_x
=
micro_size_y
=
micro_size_k
=
16
if
out_dtype
==
"
int32
"
:
if
out_dtype
==
T
.
int32
:
micro_size_k
=
32
micro_size_k
=
32
# This is a debug config
# This is a debug config
...
...
examples/gemm/example_gemm.py
View file @
667632cc
...
@@ -3,13 +3,12 @@ import tilelang.language as T
...
@@ -3,13 +3,12 @@ import tilelang.language as T
@
tilelang
.
jit
(
out_idx
=
[
-
1
])
@
tilelang
.
jit
(
out_idx
=
[
-
1
])
def
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
def
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
=
T
.
float16
,
accum_dtype
=
T
.
float32
):
@
T
.
prim_func
@
T
.
prim_func
def
gemm
(
def
gemm
(
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
128
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
128
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
dtype
)
A_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
dtype
)
...
...
examples/gemm/example_gemm_autotune.py
View file @
667632cc
...
@@ -51,9 +51,9 @@ def get_configs(M, N, K, with_roller=False, topk=20):
...
@@ -51,9 +51,9 @@ def get_configs(M, N, K, with_roller=False, topk=20):
M
=
M
,
M
=
M
,
N
=
N
,
N
=
N
,
K
=
K
,
K
=
K
,
in_dtype
=
"
float16
"
,
in_dtype
=
T
.
float16
,
out_dtype
=
"
float16
"
,
out_dtype
=
T
.
float16
,
accum_dtype
=
"
float
"
,
accum_dtype
=
T
.
float
32
,
).
with_arch
(
arch
)
).
with_arch
(
arch
)
func
=
carve_template
.
equivalent_function
()
func
=
carve_template
.
equivalent_function
()
...
@@ -90,7 +90,8 @@ def get_configs(M, N, K, with_roller=False, topk=20):
...
@@ -90,7 +90,8 @@ def get_configs(M, N, K, with_roller=False, topk=20):
num_stages
,
num_stages
,
thread_num
,
thread_num
,
enable_rasterization
,
enable_rasterization
,
))
)
)
configs
=
[
configs
=
[
{
{
...
@@ -100,13 +101,13 @@ def get_configs(M, N, K, with_roller=False, topk=20):
...
@@ -100,13 +101,13 @@ def get_configs(M, N, K, with_roller=False, topk=20):
"num_stages"
:
c
[
3
],
"num_stages"
:
c
[
3
],
"thread_num"
:
c
[
4
],
"thread_num"
:
c
[
4
],
"enable_rasteration"
:
c
[
5
],
# keep param name for backward-compat
"enable_rasteration"
:
c
[
5
],
# keep param name for backward-compat
}
for
c
in
_configs
}
for
c
in
_configs
]
]
return
configs
return
configs
def
get_best_config
(
M
,
N
,
K
,
with_roller
=
False
):
def
get_best_config
(
M
,
N
,
K
,
with_roller
=
False
):
def
kernel
(
def
kernel
(
block_M
=
None
,
block_M
=
None
,
block_N
=
None
,
block_N
=
None
,
...
@@ -115,17 +116,16 @@ def get_best_config(M, N, K, with_roller=False):
...
@@ -115,17 +116,16 @@ def get_best_config(M, N, K, with_roller=False):
thread_num
=
None
,
thread_num
=
None
,
enable_rasteration
=
None
,
enable_rasteration
=
None
,
):
):
dtype
=
"
bfloat16
"
dtype
=
T
.
bfloat16
accum_dtype
=
"
float
"
accum_dtype
=
T
.
float
32
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
B
:
T
.
Tensor
((
N
,
K
),
dtype
),
B
:
T
.
Tensor
((
N
,
K
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
):
):
with
T
.
Kernel
(
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
thread_num
)
as
(
bx
,
by
):
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
thread_num
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
dtype
)
A_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
dtype
)
B_shared
=
T
.
alloc_shared
((
block_N
,
block_K
),
dtype
)
B_shared
=
T
.
alloc_shared
((
block_N
,
block_K
),
dtype
)
C_local
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
C_local
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
...
@@ -146,15 +146,18 @@ def get_best_config(M, N, K, with_roller=False):
...
@@ -146,15 +146,18 @@ def get_best_config(M, N, K, with_roller=False):
return
main
return
main
autotuner
=
AutoTuner
.
from_kernel
(
autotuner
=
(
kernel
=
kernel
,
configs
=
get_configs
(
M
,
N
,
K
,
with_roller
)).
set_compile_args
(
AutoTuner
.
from_kernel
(
kernel
=
kernel
,
configs
=
get_configs
(
M
,
N
,
K
,
with_roller
))
.
set_compile_args
(
out_idx
=
[
-
1
],
out_idx
=
[
-
1
],
target
=
"auto"
,
target
=
"auto"
,
).
set_profile_args
(
)
.
set_profile_args
(
supply_type
=
tl
.
TensorSupplyType
.
Integer
,
supply_type
=
tl
.
TensorSupplyType
.
Integer
,
ref_prog
=
ref_program
,
ref_prog
=
ref_program
,
skip_check
=
False
,
skip_check
=
False
,
)
)
)
return
autotuner
.
run
(
warmup
=
3
,
rep
=
20
)
return
autotuner
.
run
(
warmup
=
3
,
rep
=
20
)
...
@@ -167,52 +170,20 @@ def get_heuristic_config() -> dict:
...
@@ -167,52 +170,20 @@ def get_heuristic_config() -> dict:
sm_version
=
sm_major
*
10
+
sm_minor
sm_version
=
sm_major
*
10
+
sm_minor
print
(
f
"CUDA device capability:
{
sm_version
}
"
)
print
(
f
"CUDA device capability:
{
sm_version
}
"
)
if
sm_version
in
{
80
}:
if
sm_version
in
{
80
}:
return
{
return
{
"block_M"
:
128
,
"block_N"
:
256
,
"block_K"
:
32
,
"num_stages"
:
2
,
"thread_num"
:
128
,
"enable_rasteration"
:
True
}
"block_M"
:
128
,
"block_N"
:
256
,
"block_K"
:
32
,
"num_stages"
:
2
,
"thread_num"
:
128
,
"enable_rasteration"
:
True
}
elif
sm_version
in
{
90
}:
elif
sm_version
in
{
90
}:
return
{
return
{
"block_M"
:
128
,
"block_N"
:
256
,
"block_K"
:
64
,
"num_stages"
:
3
,
"thread_num"
:
256
,
"enable_rasteration"
:
True
}
"block_M"
:
128
,
"block_N"
:
256
,
"block_K"
:
64
,
"num_stages"
:
3
,
"thread_num"
:
256
,
"enable_rasteration"
:
True
}
else
:
else
:
return
{
return
{
"block_M"
:
128
,
"block_N"
:
256
,
"block_K"
:
32
,
"num_stages"
:
0
,
"thread_num"
:
128
,
"enable_rasteration"
:
True
}
"block_M"
:
128
,
"block_N"
:
256
,
"block_K"
:
32
,
"num_stages"
:
0
,
"thread_num"
:
128
,
"enable_rasteration"
:
True
}
@
tl
.
jit
(
out_idx
=
[
-
1
])
@
tl
.
jit
(
out_idx
=
[
-
1
])
def
matmul
(
M
,
def
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
num_stages
,
thread_num
,
enable_rasteration
,
dtype
=
T
.
float16
,
accum_dtype
=
T
.
float32
):
N
,
K
,
block_M
,
block_N
,
block_K
,
num_stages
,
thread_num
,
enable_rasteration
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
@
T
.
prim_func
@
T
.
prim_func
def
gemm_autotune
(
def
gemm_autotune
(
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
B
:
T
.
Tensor
((
N
,
K
),
dtype
),
B
:
T
.
Tensor
((
N
,
K
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
thread_num
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
thread_num
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
dtype
)
A_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
dtype
)
...
@@ -236,11 +207,7 @@ def matmul(M,
...
@@ -236,11 +207,7 @@ def matmul(M,
return
gemm_autotune
return
gemm_autotune
def
main
(
M
:
int
=
4096
,
def
main
(
M
:
int
=
4096
,
N
:
int
=
4096
,
K
:
int
=
4096
,
use_autotune
:
bool
=
False
,
with_roller
:
bool
=
False
):
N
:
int
=
4096
,
K
:
int
=
4096
,
use_autotune
:
bool
=
False
,
with_roller
:
bool
=
False
):
use_autotune
=
True
use_autotune
=
True
if
use_autotune
:
if
use_autotune
:
result
=
get_best_config
(
M
,
N
,
K
,
with_roller
)
result
=
get_best_config
(
M
,
N
,
K
,
with_roller
)
...
@@ -266,15 +233,7 @@ if __name__ == "__main__":
...
@@ -266,15 +233,7 @@ if __name__ == "__main__":
parser
.
add_argument
(
"--m"
,
type
=
int
,
default
=
4096
,
help
=
"Matrix dimension M"
)
parser
.
add_argument
(
"--m"
,
type
=
int
,
default
=
4096
,
help
=
"Matrix dimension M"
)
parser
.
add_argument
(
"--n"
,
type
=
int
,
default
=
4096
,
help
=
"Matrix dimension N"
)
parser
.
add_argument
(
"--n"
,
type
=
int
,
default
=
4096
,
help
=
"Matrix dimension N"
)
parser
.
add_argument
(
"--k"
,
type
=
int
,
default
=
4096
,
help
=
"Matrix dimension K"
)
parser
.
add_argument
(
"--k"
,
type
=
int
,
default
=
4096
,
help
=
"Matrix dimension K"
)
parser
.
add_argument
(
parser
.
add_argument
(
"--use_autotune"
,
action
=
"store_true"
,
default
=
False
,
help
=
"Whether to use autotune for matmul configs"
)
"--use_autotune"
,
parser
.
add_argument
(
"--with_roller"
,
action
=
"store_true"
,
default
=
False
,
help
=
"Whether to enable BitBLAS roller for search space"
)
action
=
"store_true"
,
default
=
False
,
help
=
"Whether to use autotune for matmul configs"
)
parser
.
add_argument
(
"--with_roller"
,
action
=
"store_true"
,
default
=
False
,
help
=
"Whether to enable BitBLAS roller for search space"
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
main
(
args
.
m
,
args
.
n
,
args
.
k
,
args
.
use_autotune
,
args
.
with_roller
)
main
(
args
.
m
,
args
.
n
,
args
.
k
,
args
.
use_autotune
,
args
.
with_roller
)
Prev
1
…
5
6
7
8
9
10
11
12
13
…
16
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment