OpenDAS / tilelang / Commits / 29051439

Commit 29051439 (Unverified), authored Dec 12, 2025 by Lei Wang, committed by GitHub on Dec 12, 2025

[Lint] Phaseout Yapf format and embrace ruff format (#1417)
parent e84b24bc
Showing 20 changed files with 1417 additions and 1439 deletions (+1417, -1439)
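The changes below are mechanical reformatting: the repository drops yapf in favor of ruff's black-compatible formatter, so the hunks are dominated by a handful of style rules rather than behavioral edits. As a rough illustration (my own sketch, not code from this commit; the names dim, block_N, k, K, help_text, scale, and block are placeholders that merely echo identifiers seen in the diff):

dim, block_N, k = 128, 64, 2
K = list(range(1024))

# Quote style: string literals are normalized to double quotes.
help_text = "batch size"  # yapf-era code often used 'batch size'

# Power operator: spaces are added around ** when an operand is not a simple name or literal.
scale = (1.0 / dim) ** 0.5  # previously written as (1.0 / dim)**0.5

# Slices: spaces appear around ':' when the bounds are compound expressions.
block = K[k * block_N : (k + 1) * block_N]  # previously K[k * block_N:(k + 1) * block_N]

print(help_text, round(scale, 4), len(block))

Long call sites and signatures are also re-wrapped (a trailing comma triggers the one-argument-per-line layout), which accounts for most of the added and removed line counts.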
examples/flash_attention/example_mha_fwd_varlen.py (+37, -45)
examples/flash_attention/test_example_flash_attention.py (+2, -4)
examples/flash_attention/varlen_utils.py (+9, -23)
examples/flash_decoding/example_gqa_decode.py (+115, -127)
examples/flash_decoding/example_gqa_decode_varlen_logits.py (+123, -174)
examples/flash_decoding/example_gqa_decode_varlen_logits_paged.py (+130, -162)
examples/flash_decoding/example_mha_inference.py (+69, -76)
examples/fusedmoe/example_fusedmoe_tilelang.py (+159, -190)
examples/fusedmoe/example_fusedmoe_torch.py (+37, -54)
examples/fusedmoe/test_example_fusedmoe.py (+2, -7)
examples/gdn/example_chunk_delta_bwd.py (+130, -92)
examples/gdn/example_chunk_delta_h.py (+74, -64)
examples/gdn/example_chunk_o.py (+46, -37)
examples/gdn/example_chunk_o_bwd.py (+93, -102)
examples/gdn/example_chunk_scaled_dot_kkt.py (+22, -24)
examples/gdn/example_cumsum.py (+16, -20)
examples/gdn/example_wy_fast.py (+31, -42)
examples/gdn/example_wy_fast_bwd_split.py (+113, -112)
examples/gdn/test_example_gdn_compilation.py (+203, -76)
examples/gdn/test_utils.py (+6, -8)
examples/flash_attention/example_mha_fwd_varlen.py

@@ -47,7 +47,7 @@ def attention_ref(
     if upcast:
         q, k, v = q.float(), k.float(), v.float()
     dim = q.shape[-1]
-    scale = (1.0 / dim)**0.5  # log2(e)
+    scale = (1.0 / dim) ** 0.5  # log2(e)
     k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2])
     v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2])
     scores = torch.einsum("bthd,bshd->bhts", q, k)

@@ -68,20 +68,13 @@ def attention_ref(
 @tilelang.jit(
-    out_idx=[6], pass_configs={
+    out_idx=[6],
+    pass_configs={
         tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-    })
+    },
+)
 def flashattn(batch_size, UQ, UKV, heads, dim, is_causal, block_M=64, block_N=64, num_stages=0, threads=32):
-    scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
+    scale = (1.0 / dim) ** 0.5 * 1.44269504  # log2(e)
     q_shape = [UQ, heads, dim]
     k_shape = [UKV, heads, dim]
     v_shape = [UKV, heads, dim]

@@ -100,9 +93,7 @@ def flashattn(batch_size,
             max_seqlen_q: T.int32,
             Output_unpad: T.Tensor(o_shape, dtype),
     ):
-        with T.Kernel(
-                T.ceildiv(max_seqlen_q, block_M), heads, batch_size, threads=threads) as (bx, by, bz):
+        with T.Kernel(T.ceildiv(max_seqlen_q, block_M), heads, batch_size, threads=threads) as (bx, by, bz):
             Q_shared = T.alloc_shared([block_M, dim], dtype, "shared")
             K_shared = T.alloc_shared([block_N, dim], dtype, "shared")
             V_shared = T.alloc_shared([block_N, dim], dtype, "shared")

@@ -151,15 +142,17 @@ def flashattn(batch_size,
                        K_shared[i, d] = 0
                if is_causal:
                    for i, j in T.Parallel(block_M, block_N):
-                        acc_s[i, j] = T.if_then_else((bx * block_M + i >= k * block_N + j) and
-                                                     (bx * block_M + i >= q_current_seqlen or
-                                                      k * block_N + j >= k_current_seqlen),
-                                                     -T.infinity(acc_s.dtype), 0)
+                        acc_s[i, j] = T.if_then_else(
+                            (bx * block_M + i >= k * block_N + j)
+                            and (bx * block_M + i >= q_current_seqlen or k * block_N + j >= k_current_seqlen),
+                            -T.infinity(acc_s.dtype),
+                            0,
+                        )
                else:
                    for i, j in T.Parallel(block_M, block_N):
                        acc_s[i, j] = T.if_then_else(
                            (bx * block_M + i >= q_current_seqlen or k * block_N + j >= k_current_seqlen),
                            -T.infinity(acc_s.dtype), 0)
                T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)

@@ -244,8 +237,7 @@ def main(batch: int = 8, heads: int = 64, seq_len: int = 2048, dim: int = 128):
        output_pad_fn,
        dq_pad_fn,
        dk_pad_fn,
-    ) = generate_qkv(
-        q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)
+    ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)
     UQ = q_unpad.shape[0]  # unpadded query length
     UK = k_unpad.shape[0]  # unpadded key length

@@ -287,10 +279,10 @@ def main(batch: int = 8, heads: int = 64, seq_len: int = 2048, dim: int = 128):
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument('--batch', type=int, default=8, help='batch size')
-    parser.add_argument('--heads', type=int, default=64, help='heads')
-    parser.add_argument('--seq_len', type=int, default=2048, help='sequence length')
-    parser.add_argument('--dim', type=int, default=128, help='dim')
+    parser.add_argument("--batch", type=int, default=8, help="batch size")
+    parser.add_argument("--heads", type=int, default=64, help="heads")
+    parser.add_argument("--seq_len", type=int, default=2048, help="sequence length")
+    parser.add_argument("--dim", type=int, default=128, help="dim")
     args = parser.parse_args()
     main(args.batch, args.heads, args.seq_len, args.dim)
examples/flash_attention/test_example_flash_attention.py

@@ -62,14 +62,12 @@ def test_example_mha_bwd_wgmma_pipelined():
 @tilelang.testing.requires_cuda
 @tilelang.testing.requires_cuda_compute_version_ge(9, 0)
 def test_example_gqa_fwd_bshd_wgmma_pipelined():
     example_gqa_fwd_bshd_wgmma_pipelined.main(
         batch=1, heads=16, seq_len=1024, dim=128, is_causal=False, groups=16, tune=False
     )


 @tilelang.testing.requires_cuda
 def test_example_gqa_fwd_bshd():
     example_gqa_fwd_bshd.main(
         batch=1, heads=16, seq_len=1024, dim=128, is_causal=False, groups=16, tune=False
     )


 @tilelang.testing.requires_cuda
examples/flash_attention/varlen_utils.py

@@ -9,22 +9,14 @@ def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random"):
     if mode == "full":
         lengths = torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32)
     elif mode == "random":
         lengths = torch.randint(max(1, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device)
     elif mode == "third":
         lengths = torch.randint(max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device)
-    padding_mask = (repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) < lengths)
+    padding_mask = repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) < lengths
     return padding_mask


 def generate_qkv(q, k, v, query_padding_mask=None, key_padding_mask=None, kvpacked=False, qkvpacked=False):
     """
     Arguments:
         q: (batch_size, seqlen_q, nheads, d)

@@ -39,15 +31,12 @@ def generate_qkv(q,
     if query_padding_mask is not None:
         q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, query_padding_mask)
         output_pad_fn = lambda output_unpad: pad_input(output_unpad, indices_q, batch_size, seqlen_q)
     else:
         q_unpad = rearrange(q, "b s h d -> (b s) h d")
         cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, device=q_unpad.device)
         max_seqlen_q = seqlen_q
         output_pad_fn = lambda output_unpad: rearrange(output_unpad, "(b s) h d -> b s h d", b=batch_size)

     if key_padding_mask is not None:
         k_unpad, indices_k, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask)

@@ -55,8 +44,7 @@ def generate_qkv(q,
     else:
         k_unpad = rearrange(k, "b s h d -> (b s) h d")
         v_unpad = rearrange(v, "b s h d -> (b s) h d")
         cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, device=k_unpad.device)
         max_seqlen_k = seqlen_k

     if qkvpacked:

@@ -67,8 +55,7 @@ def generate_qkv(q,
         if query_padding_mask is not None:
             dqkv_pad_fn = lambda dqkv_unpad: pad_input(dqkv_unpad, indices_q, batch_size, seqlen_q)
         else:
             dqkv_pad_fn = lambda dqkv_unpad: rearrange(dqkv_unpad, "(b s) t h d -> b s t h d", b=batch_size)
         return (
             qkv_unpad.detach().requires_grad_(),
             cu_seqlens_q,

@@ -84,8 +71,7 @@ def generate_qkv(q,
         if key_padding_mask is not None:
             dkv_pad_fn = lambda dkv_unpad: pad_input(dkv_unpad, indices_k, batch_size, seqlen_k)
         else:
             dkv_pad_fn = lambda dkv_unpad: rearrange(dkv_unpad, "(b s) t h d -> b s t h d", b=batch_size)
         return (
             q_unpad.detach().requires_grad_(),
             kv_unpad.detach().requires_grad_(),
examples/flash_decoding/example_gqa_decode.py

@@ -20,13 +20,7 @@ def get_configs():
     threads = [128]
     _configs = list(itertools.product(block_N, block_H, num_split, num_stages, threads))

-    configs = [{
-        'block_N': c[0],
-        'block_H': c[1],
-        'num_split': c[2],
-        'num_stages': c[3],
-        'threads': c[4]
-    } for c in _configs]
+    configs = [
+        {"block_N": c[0], "block_H": c[1], "num_split": c[2], "num_stages": c[3], "threads": c[4]}
+        for c in _configs
+    ]
     return configs

@@ -48,17 +42,13 @@ def get_heuristic_config() -> Tuple[Dict, int]:
 # TODO(lei): fix warp specialized and tma lower pass
 def get_pass_configs():
     return {
         tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
         tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True
     }


 @autotune(configs=get_configs(), warmup=10, rep=10)
 @tilelang.jit(out_idx=[6], pass_configs=get_pass_configs())
 def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split, num_stages, threads):
-    scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
+    scale = (1.0 / dim) ** 0.5 * 1.44269504  # log2(e)
     shape_q = [batch, heads, dim]
     shape_k = [batch, seqlen_kv, groups, dim]
     shape_v = [batch, seqlen_kv, groups, dim]

@@ -98,20 +88,19 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split,
             hid = by
             cur_kv_head = hid // (kv_group_num // valid_block_H)

-            T.copy(Q[bid, hid * valid_block_H:hid * valid_block_H + block_H, :], Q_shared)
+            T.copy(Q[bid, hid * valid_block_H : hid * valid_block_H + block_H, :], Q_shared)
             T.fill(acc_o, 0)
             T.fill(logsum, 0)
             T.fill(scores_max, -T.infinity(accum_dtype))

             loop_range = T.ceildiv((seqlen_kv // num_split), block_N)
             for k in T.Pipelined(loop_range, num_stages=num_stages):
-                T.copy(K[bid, k * block_N:(k + 1) * block_N, cur_kv_head, :], K_shared)
-                T.copy(mask[bid, k * block_N:(k + 1) * block_N, cur_kv_head], mask_local)
+                T.copy(K[bid, k * block_N : (k + 1) * block_N, cur_kv_head, :], K_shared)
+                T.copy(mask[bid, k * block_N : (k + 1) * block_N, cur_kv_head], mask_local)
                 T.clear(acc_s)
                 T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                 for i, j in T.Parallel(block_H, block_N):
-                    acc_s[i, j] = T.if_then_else(mask_local[j] != 0, acc_s[i, j],
-                                                 -T.infinity(accum_dtype))
+                    acc_s[i, j] = T.if_then_else(mask_local[j] != 0, acc_s[i, j], -T.infinity(accum_dtype))
                 T.copy(scores_max, scores_max_prev)
                 T.fill(scores_max, -T.infinity(accum_dtype))
                 T.reduce_max(acc_s, scores_max, dim=1, clear=False)

@@ -127,14 +116,14 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split,
                 T.copy(acc_s, acc_s_cast)
                 for i, j in T.Parallel(block_H, dim):
                     acc_o[i, j] *= scores_scale[i]
-                T.copy(V[bid, k * block_N:(k + 1) * block_N, cur_kv_head, :], V_shared)
+                T.copy(V[bid, k * block_N : (k + 1) * block_N, cur_kv_head, :], V_shared)
                 T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
             for i, j in T.Parallel(block_H, dim):
                 acc_o[i, j] /= logsum[i]
             for i in T.Parallel(block_H):
                 logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
             T.copy(acc_o[:valid_block_H, :], O_shared)
-            T.copy(O_shared, Output[bid, hid * valid_block_H:(hid + 1) * valid_block_H, :])
+            T.copy(O_shared, Output[bid, hid * valid_block_H : (hid + 1) * valid_block_H, :])

     @T.macro
     def flash_attn_split(

@@ -165,7 +154,7 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split,
             sid = bz
             cur_kv_head = hid // (kv_group_num // valid_block_H)

-            T.copy(Q[bid, hid * valid_block_H:hid * valid_block_H + block_H, :], Q_shared)
+            T.copy(Q[bid, hid * valid_block_H : hid * valid_block_H + block_H, :], Q_shared)
             T.fill(acc_o, 0)
             T.fill(logsum, 0)
             T.fill(scores_max, -T.infinity(accum_dtype))

@@ -174,19 +163,26 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split,
             for k in T.Pipelined(loop_range, num_stages=num_stages):
-                T.copy(
-                    K[bid, (seqlen_kv // num_split) * sid + k * valid_block_N:(seqlen_kv // num_split) * sid +
-                      (k + 1) * valid_block_N, cur_kv_head, :], K_shared)
-                T.copy(
-                    mask[bid, (seqlen_kv // num_split) * sid + k * valid_block_N:(seqlen_kv // num_split) * sid +
-                         (k + 1) * valid_block_N, cur_kv_head], mask_local)
+                T.copy(
+                    K[
+                        bid,
+                        (seqlen_kv // num_split) * sid + k * valid_block_N : (seqlen_kv // num_split) * sid
+                        + (k + 1) * valid_block_N,
+                        cur_kv_head,
+                        :,
+                    ],
+                    K_shared,
+                )
+                T.copy(
+                    mask[
+                        bid,
+                        (seqlen_kv // num_split) * sid + k * valid_block_N : (seqlen_kv // num_split) * sid
+                        + (k + 1) * valid_block_N,
+                        cur_kv_head,
+                    ],
+                    mask_local,
+                )
                 T.clear(acc_s)
                 T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                 for i, j in T.Parallel(block_H, block_N):
                     acc_s[i, j] = T.if_then_else((mask_local[j] != 0) & (j < seqlen_kv // num_split), acc_s[i, j],
                                                  -T.infinity(accum_dtype))
                 T.copy(scores_max, scores_max_prev)
                 T.fill(scores_max, -T.infinity(accum_dtype))
                 T.reduce_max(acc_s, scores_max, dim=1, clear=False)

@@ -203,9 +199,14 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split,
                 for i, j in T.Parallel(block_H, dim):
                     acc_o[i, j] *= scores_scale[i]
-                T.copy(
-                    V[bid, (seqlen_kv // num_split) * sid + k * valid_block_N:(seqlen_kv // num_split) * sid +
-                      (k + 1) * valid_block_N, cur_kv_head, :], V_shared)
+                T.copy(
+                    V[
+                        bid,
+                        (seqlen_kv // num_split) * sid + k * valid_block_N : (seqlen_kv // num_split) * sid
+                        + (k + 1) * valid_block_N,
+                        cur_kv_head,
+                        :,
+                    ],
+                    V_shared,
+                )
                 T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
             for i, j in T.Parallel(block_H, dim):
                 acc_o[i, j] /= logsum[i]

@@ -216,8 +217,7 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split,
                 if i < valid_block_H:
                     glse[bid, hid * valid_block_H + i, sid] = logsum[i]
             T.copy(acc_o[:valid_block_H, :], O_shared)
-            T.copy(O_shared,
-                   Output_partial[bid, hid * valid_block_H:(hid + 1) * valid_block_H, sid, :])
+            T.copy(O_shared, Output_partial[bid, hid * valid_block_H : (hid + 1) * valid_block_H, sid, :])

     @T.macro
     def combine(

@@ -233,12 +233,14 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split,
             lse_max_local = T.alloc_fragment([128], accum_dtype)
             scale_local = T.alloc_fragment([128], accum_dtype)

-            T.annotate_layout({
+            T.annotate_layout(
+                {
                 lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i),
                 lse_max_local: T.Fragment(lse_max_local.shape, forward_thread_fn=lambda i: i),
                 # lse_local: (local_id, thread_id)
                 lse_local: T.Fragment(lse_local.shape, forward_fn=lambda i, j: (j, i)),
-            })
+                }
+            )
             T.clear(lse_logsum_local)
             T.clear(o_accum_local)

@@ -305,27 +307,21 @@ def ref_program(query, key, value, mask, glse, Output_partial):
     dim = query.shape[-1]
     num_head_groups = query.shape[1] // key.shape[2]
     scale = dim**0.5
-    key = rearrange(key, 'b n h d -> b h n d')  # [batch_size, groups, seqlen_kv, dim]
-    value = rearrange(value, 'b n h d -> b h n d')  # [batch_size, groups, seqlen_kv, dim]
+    key = rearrange(key, "b n h d -> b h n d")  # [batch_size, groups, seqlen_kv, dim]
+    value = rearrange(value, "b n h d -> b h n d")  # [batch_size, groups, seqlen_kv, dim]

-    query = rearrange(query, 'b (h g) d -> b g h d', g=num_head_groups)  # [batch_size, num_head_groups, groups, dim]
+    query = rearrange(query, "b (h g) d -> b g h d", g=num_head_groups)  # [batch_size, num_head_groups, groups, dim]

-    scores = einsum(query, key, 'b g h d, b h s d -> b g h s')  # [batch_size, num_head_groups, groups, seqlen_kv]
+    scores = einsum(query, key, "b g h d, b h s d -> b g h s")  # [batch_size, num_head_groups, groups, seqlen_kv]
     if mask is not None:
-        mask = rearrange(mask, 'b s h -> b h s')
+        mask = rearrange(mask, "b s h -> b h s")
         mask = mask.unsqueeze(1)
-        scores = scores.masked_fill(mask == 0, float('-inf'))
+        scores = scores.masked_fill(mask == 0, float("-inf"))

     attention = F.softmax(scores / scale, dim=-1)  # [batch_size, num_head_groups, groups, seqlen_kv]

-    out = einsum(attention, value, 'b g h s, b h s d -> b g h d')  # [batch_size, num_head_groups, groups, dim]
-    out = rearrange(out, 'b g h d -> b (h g) d')  # [batch_size, heads, dim]
+    out = einsum(attention, value, "b g h s, b h s d -> b g h d")  # [batch_size, num_head_groups, groups, dim]
+    out = rearrange(out, "b g h d -> b (h g) d")  # [batch_size, heads, dim]
     return out

@@ -339,16 +335,12 @@ def flash_split_ref(Q, K, V, mask):
     seqlen_kv = K.size(1)
     num_head_groups = nheads // groups
-    scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
+    scale = (1.0 / dim) ** 0.5 * 1.44269504  # log2(e)
     acc_s = torch.empty((batch, num_head_groups, groups, block_N), device="cuda", dtype=torch.float)
     acc_s_cast = torch.empty((batch, num_head_groups, groups, block_N), device="cuda", dtype=torch.float16)
     acc_o = torch.empty((batch, num_head_groups, groups, dim), device="cuda", dtype=torch.float)
     scores_max = torch.empty((batch, num_head_groups, groups), device="cuda", dtype=torch.float)
     scores_max_prev = torch.empty((batch, num_head_groups, groups), device="cuda", dtype=torch.float)
     scores_scale = torch.empty((batch, num_head_groups, groups), device="cuda", dtype=torch.float)
     scores_sum = torch.empty((batch, num_head_groups, groups), device="cuda", dtype=torch.float)
     logsum = torch.empty((batch, num_head_groups, groups), device="cuda", dtype=torch.float)

@@ -356,25 +348,25 @@ def flash_split_ref(Q, K, V, mask):
     glogsum = torch.empty((num_split, batch, nheads), device="cuda", dtype=torch.float)

     Q_ = Q * scale
-    Q_ = rearrange(Q_, 'b (h g) d -> b g h d', g=num_head_groups)
+    Q_ = rearrange(Q_, "b (h g) d -> b g h d", g=num_head_groups)
     for ks in range(num_split):
         acc_o.fill_(0)
         logsum.fill_(0)
-        scores_max.fill_(float('-inf'))
-        scores_max_prev.fill_(float('-inf'))
+        scores_max.fill_(float("-inf"))
+        scores_max_prev.fill_(float("-inf"))
         for i in range(int((seqlen_kv // num_split) / block_N)):
             acc_s.fill_(0)
-            acc_s = torch.einsum(
-                'bghd,bkhd->bghk', Q_,
-                K[:, (seqlen_kv // num_split) * ks + i * block_N:(seqlen_kv // num_split) * ks +
-                  (i + 1) * block_N, :, :])  # [batch, nheads, block_N]
+            acc_s = torch.einsum(
+                "bghd,bkhd->bghk",
+                Q_,
+                K[:, (seqlen_kv // num_split) * ks + i * block_N : (seqlen_kv // num_split) * ks + (i + 1) * block_N, :, :],
+            )  # [batch, nheads, block_N]
             if mask is not None:
-                mask_local = mask[:, (seqlen_kv // num_split) * ks + i * block_N:(seqlen_kv // num_split) * ks +
-                                  (i + 1) * block_N, :]
-                mask_local = rearrange(mask_local, 'b s h -> b h s')
+                mask_local = mask[:, (seqlen_kv // num_split) * ks + i * block_N : (seqlen_kv // num_split) * ks + (i + 1) * block_N, :]
+                mask_local = rearrange(mask_local, "b s h -> b h s")
                 mask_local = mask_local.unsqueeze(1)
-                acc_s = acc_s.masked_fill(mask_local == 0, float('-inf'))
+                acc_s = acc_s.masked_fill(mask_local == 0, float("-inf"))
             scores_max_prev = scores_max
             scores_max = acc_s.max(dim=-1, keepdim=False).values  # [batch, nheads]
             scores_scale = torch.exp2(scores_max_prev - scores_max)  # [batch, nheads]

@@ -382,15 +374,16 @@ def flash_split_ref(Q, K, V, mask):
             acc_s = torch.exp2(acc_s - scores_max[:, :, :, None])
             acc_s_cast = acc_s.to(torch.float16)  # [batch, nheads, block_N]
-            acc_o += torch.einsum(
-                'bghk,bkhd->bghd', acc_s_cast,
-                V[:, (seqlen_kv // num_split) * ks + i * block_N:(seqlen_kv // num_split) * ks +
-                  (i + 1) * block_N, :, :])
+            acc_o += torch.einsum(
+                "bghk,bkhd->bghd",
+                acc_s_cast,
+                V[:, (seqlen_kv // num_split) * ks + i * block_N : (seqlen_kv // num_split) * ks + (i + 1) * block_N, :, :],
+            )
             scores_sum = acc_s.sum(dim=-1, keepdim=False)
             logsum = logsum * scores_scale + scores_sum
-        acc_o_out = rearrange(acc_o, 'b g h d->b (h g) d')
-        logsum_out = rearrange(logsum, 'b g h->b (h g)')
+        acc_o_out = rearrange(acc_o, "b g h d->b (h g) d")
+        logsum_out = rearrange(logsum, "b g h->b (h g)")
         acc_o_out /= logsum_out[:, :, None]
-        logsum_out = torch.log2(logsum_out) + rearrange(scores_max, 'b g h->b (h g)')
+        logsum_out = torch.log2(logsum_out) + rearrange(scores_max, "b g h->b (h g)")
         gacc_o[ks, :, :, :] = acc_o_out
         glogsum[ks, :, :] = logsum_out

@@ -426,7 +419,7 @@ def calc_sim(x, y, name="tensor"):
     x, y = x.data.double(), y.data.double()
     denominator = (x * x + y * y).sum()
     if denominator == 0:
-        print_red_warning(f'{name} all zero')
+        print_red_warning(f"{name} all zero")
         return 1
     sim = 2 * (x * y).sum() / denominator
     return sim

@@ -434,28 +427,23 @@ def calc_sim(x, y, name="tensor"):
 def assert_similar(x, y, eps=1e-2, name="tensor", assert_=False, print_=True):
     sim = calc_sim(x, y, name)
-    diff = 1. - sim
+    diff = 1.0 - sim
     if not (0 <= diff <= eps):
-        print_red_warning(f'{name} Error: {diff}')
+        print_red_warning(f"{name} Error: {diff}")
         if assert_:
-            raise AssertionError(f'{name} Error: {diff}')
+            raise AssertionError(f"{name} Error: {diff}")
     else:
         if print_:
-            print(f'passed: {name} diff={diff}')
+            print(f"passed: {name} diff={diff}")


 def main(batch: int = 1, heads: int = 32, groups: int = 8, kv_seqlen: int = 8192, dim: int = 128, tune: bool = False):
     batch, heads, groups, kv_seqlen, dim = batch, heads, groups, kv_seqlen, dim
     qk_flops = 2 * batch * heads * kv_seqlen * dim
     pv_flops = 2 * batch * heads * kv_seqlen * dim
     total_flops = qk_flops + pv_flops

-    if (not tune):
+    if not tune:
         config, sm_version = get_heuristic_config()
         kernel = flashattn(batch, heads, groups, kv_seqlen, dim, **config)
         profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Auto)

@@ -497,11 +485,11 @@ def main(batch: int = 1,
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument('--batch', type=int, default=1, help='batch size')
-    parser.add_argument('--heads', type=int, default=32, help='heads')
-    parser.add_argument('--groups', type=int, default=8, help='groups')
-    parser.add_argument('--kv_seqlen', type=int, default=8192, help='kv sequence length')
-    parser.add_argument('--dim', type=int, default=128, help='dim')
-    parser.add_argument('--tune', action='store_true', help='tune configs')
+    parser.add_argument("--batch", type=int, default=1, help="batch size")
+    parser.add_argument("--heads", type=int, default=32, help="heads")
+    parser.add_argument("--groups", type=int, default=8, help="groups")
+    parser.add_argument("--kv_seqlen", type=int, default=8192, help="kv sequence length")
+    parser.add_argument("--dim", type=int, default=128, help="dim")
+    parser.add_argument("--tune", action="store_true", help="tune configs")
     args = parser.parse_args()
     main(args.batch, args.heads, args.groups, args.kv_seqlen, args.dim, args.tune)
examples/flash_decoding/example_gqa_decode_varlen_logits.py
View file @
29051439
...
...
@@ -19,8 +19,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
batch
,
num_key_value_heads
,
slen
,
head_dim
=
hidden_states
.
shape
if
n_rep
==
1
:
return
hidden_states
hidden_states
=
hidden_states
[:,
:,
None
,
:,
:].
expand
(
batch
,
num_key_value_heads
,
n_rep
,
slen
,
head_dim
)
hidden_states
=
hidden_states
[:,
:,
None
,
:,
:].
expand
(
batch
,
num_key_value_heads
,
n_rep
,
slen
,
head_dim
)
return
hidden_states
.
reshape
(
batch
,
num_key_value_heads
*
n_rep
,
slen
,
head_dim
)
...
...
@@ -74,14 +73,9 @@ def _fwd_inner(
return
m_i
,
l_i
,
acc
@
triton
.
autotune
(
configs
=
[
triton
.
Config
({},
num_warps
=
num_warps
,
num_stages
=
num_stages
)
for
num_warps
in
[
4
,
8
]
\
for
num_stages
in
[
2
,
4
]
\
],
key
=
[
'gqa_group_size'
,
'BLOCK_N'
,
'BLOCK_D'
,
'BLOCK_H'
],
configs
=
[
triton
.
Config
({},
num_warps
=
num_warps
,
num_stages
=
num_stages
)
for
num_warps
in
[
4
,
8
]
for
num_stages
in
[
2
,
4
]],
key
=
[
"gqa_group_size"
,
"BLOCK_N"
,
"BLOCK_D"
,
"BLOCK_H"
],
)
@
triton
.
jit
def
_fwd_kernel_varlen
(
...
...
@@ -107,13 +101,12 @@ def _fwd_kernel_varlen(
stride_od
,
stride_sb
,
stride_sh
,
stride_sn
,
#bmask shape [b, q_h, seq/BLOCK_N]
stride_sn
,
#
bmask shape [b, q_h, seq/BLOCK_N]
gqa_group_size
:
tl
.
constexpr
,
BLOCK_H
:
tl
.
constexpr
,
BLOCK_N
:
tl
.
constexpr
,
BLOCK_D
:
tl
.
constexpr
,
):
off_z
=
tl
.
program_id
(
0
)
off_h_for_kv
=
tl
.
program_id
(
1
)
off_h_q
=
off_h_for_kv
*
gqa_group_size
...
...
@@ -134,8 +127,7 @@ def _fwd_kernel_varlen(
S_ptrs
=
S
+
off_z
*
stride_sb
+
off_h_q
*
stride_sh
mask_h
=
offs_h
<
gqa_group_size
q
=
tl
.
load
(
Q_ptrs
+
offs_d
[
None
,
:]
*
stride_qd
+
offs_h
[:,
None
]
*
stride_qh
,
mask
=
mask_h
[:,
None
])
q
=
tl
.
load
(
Q_ptrs
+
offs_d
[
None
,
:]
*
stride_qd
+
offs_h
[:,
None
]
*
stride_qh
,
mask
=
mask_h
[:,
None
])
if
s_aux
is
not
None
:
sink
=
tl
.
load
(
s_aux
+
off_h_q
+
offs_h
,
mask
=
mask_h
).
to
(
tl
.
float32
)
...
...
@@ -189,14 +181,12 @@ def _fwd_kernel_varlen(
acc
=
acc
.
to
(
O
.
dtype
.
element_ty
)
tl
.
store
(
O_ptrs
+
offs_h
[:,
None
]
*
stride_oh
+
offs_d
[
None
,
:]
*
stride_od
,
acc
,
mask
=
mask_h
[:,
None
])
tl
.
store
(
O_ptrs
+
offs_h
[:,
None
]
*
stride_oh
+
offs_d
[
None
,
:]
*
stride_od
,
acc
,
mask
=
mask_h
[:,
None
])
def
get_configs
():
import
itertools
block_N
=
[
64
,
128
]
block_H
=
[
64
]
num_split
=
[
1
]
...
...
@@ -204,31 +194,16 @@ def get_configs():
threads
=
[
128
]
_configs
=
list
(
itertools
.
product
(
block_N
,
block_H
,
num_split
,
num_stages
,
threads
))
configs
=
[{
'block_N'
:
c
[
0
],
'block_H'
:
c
[
1
],
'num_split'
:
c
[
2
],
'num_stages'
:
c
[
3
],
'threads'
:
c
[
4
]
}
for
c
in
_configs
]
configs
=
[{
"block_N"
:
c
[
0
],
"block_H"
:
c
[
1
],
"num_split"
:
c
[
2
],
"num_stages"
:
c
[
3
],
"threads"
:
c
[
4
]}
for
c
in
_configs
]
return
configs
@
autotune
(
configs
=
get_configs
(),
warmup
=
10
,
rep
=
10
)
@
tilelang
.
jit
(
out_idx
=
[
-
2
,
-
1
],
debug_root_path
=
"./examples/flash_decoding"
)
def
flashattn
(
batch
,
heads
,
k_heads
,
max_seqlen_kv
,
total_seqlen_k
,
dim
,
has_sink
,
block_N
=
128
,
block_H
=
64
,
num_split
=
1
,
num_stages
=
1
,
threads
=
128
):
scale
=
(
1.0
/
dim
)
**
0.5
*
1.44269504
# log2(e)
def
flashattn
(
batch
,
heads
,
k_heads
,
max_seqlen_kv
,
total_seqlen_k
,
dim
,
has_sink
,
block_N
=
128
,
block_H
=
64
,
num_split
=
1
,
num_stages
=
1
,
threads
=
128
):
scale
=
(
1.0
/
dim
)
**
0.5
*
1.44269504
# log2(e)
shape_q
=
[
batch
,
heads
,
dim
]
shape_k
=
[
total_seqlen_k
,
k_heads
,
dim
]
shape_v
=
[
total_seqlen_k
,
k_heads
,
dim
]
...
...
@@ -268,13 +243,15 @@ def flashattn(batch,
# S_fragment = T.alloc_fragment([block_H, math.ceil(max_seqlen_kv / block_N)], accum_dtype)
s_aux_shared
=
T
.
alloc_shared
([
block_H
],
"float32"
)
T
.
annotate_layout
({
T
.
annotate_layout
(
{
# Q_shared: tilelang.layout.make_swizzled_layout(Q_shared),
# K_shared: tilelang.layout.make_swizzled_layout(K_shared),
# V_shared: tilelang.layout.make_swizzled_layout(V_shared),
# O_shared: tilelang.layout.make_swizzled_layout(O_shared),
# S_shared: tilelang.layout.make_swizzled_layout(S_shared),
})
}
)
bid
=
bx
hid
=
by
...
...
@@ -284,7 +261,7 @@ def flashattn(batch,
cur_end_k
=
cu_seqlens_k
[
bid
+
1
]
cur_seqlen_k
=
cur_end_k
-
cur_start_k
T
.
copy
(
Q
[
bid
,
hid
*
valid_block_H
:
hid
*
valid_block_H
+
block_H
,
:],
Q_shared
)
T
.
copy
(
Q
[
bid
,
hid
*
valid_block_H
:
hid
*
valid_block_H
+
block_H
,
:],
Q_shared
)
T
.
fill
(
acc_o
,
0
)
T
.
fill
(
logsum
,
0
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
...
...
@@ -292,15 +269,13 @@ def flashattn(batch,
# loop_range = T.ceildiv((seqlen_kv // num_split), block_N)
loop_range
=
T
.
ceildiv
((
cur_seqlen_k
//
num_split
),
block_N
)
for
k
in
T
.
Pipelined
(
loop_range
,
num_stages
=
num_stages
):
T
.
copy
(
K
[
cur_start_k
+
k
*
block_N
:
cur_start_k
+
(
k
+
1
)
*
block_N
,
cur_kv_head
,
:],
K_shared
)
T
.
copy
(
K
[
cur_start_k
+
k
*
block_N
:
cur_start_k
+
(
k
+
1
)
*
block_N
,
cur_kv_head
,
:],
K_shared
)
T
.
clear
(
acc_s
)
T
.
gemm
(
Q_shared
,
K_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
for
i
,
j
in
T
.
Parallel
(
block_H
,
block_N
):
# acc_s[i, j] = T.if_then_else(mask_local[j] != 0 and k * block_N + j < cur_seqlen_k, acc_s[i, j],
# -T.infinity(accum_dtype))
acc_s
[
i
,
j
]
=
T
.
if_then_else
(
k
*
block_N
+
j
<
cur_seqlen_k
,
acc_s
[
i
,
j
],
-
T
.
infinity
(
accum_dtype
))
acc_s
[
i
,
j
]
=
T
.
if_then_else
(
k
*
block_N
+
j
<
cur_seqlen_k
,
acc_s
[
i
,
j
],
-
T
.
infinity
(
accum_dtype
))
T
.
copy
(
scores_max
,
scores_max_prev
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
T
.
reduce_max
(
acc_s
,
scores_max
,
dim
=
1
,
clear
=
False
)
...
...
@@ -320,12 +295,11 @@ def flashattn(batch,
T
.
copy
(
acc_s
,
acc_s_cast
)
for
i
,
j
in
T
.
Parallel
(
block_H
,
dim
):
acc_o
[
i
,
j
]
*=
scores_scale
[
i
]
T
.
copy
(
V
[
cur_start_k
+
k
*
block_N
:
cur_start_k
+
(
k
+
1
)
*
block_N
,
cur_kv_head
,
:],
V_shared
)
T
.
copy
(
V
[
cur_start_k
+
k
*
block_N
:
cur_start_k
+
(
k
+
1
)
*
block_N
,
cur_kv_head
,
:],
V_shared
)
T
.
gemm
(
acc_s_cast
,
V_shared
,
acc_o
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
if
has_sink
:
T
.
copy
(
s_aux
[
hid
*
valid_block_H
:
hid
*
valid_block_H
+
block_H
],
s_aux_shared
)
T
.
copy
(
s_aux
[
hid
*
valid_block_H
:
hid
*
valid_block_H
+
block_H
],
s_aux_shared
)
for
i
in
T
.
Parallel
(
block_H
):
logsum
[
i
]
+=
s_aux_shared
[
i
]
for
i
,
j
in
T
.
Parallel
(
block_H
,
dim
):
...
...
@@ -338,10 +312,9 @@ def flashattn(batch,
for
i
in
T
.
Parallel
(
block_H
):
logsum
[
i
]
=
T
.
log2
(
logsum
[
i
])
+
scores_max
[
i
]
*
scale
T
.
copy
(
acc_o
[:
valid_block_H
,
:],
O_shared
)
T
.
copy
(
O_shared
,
Output
[
bid
,
hid
*
valid_block_H
:
(
hid
+
1
)
*
valid_block_H
,
:])
T
.
copy
(
O_shared
,
Output
[
bid
,
hid
*
valid_block_H
:
(
hid
+
1
)
*
valid_block_H
,
:])
# T.copy(S_fragment, S_shared)
T
.
copy
(
S_shared
[:
valid_block_H
,
:],
S
[
bid
,
hid
*
valid_block_H
:(
hid
+
1
)
*
valid_block_H
,
:])
T
.
copy
(
S_shared
[:
valid_block_H
,
:],
S
[
bid
,
hid
*
valid_block_H
:
(
hid
+
1
)
*
valid_block_H
,
:])
@
T
.
prim_func
def
flashattn_gqa_decode_no_split
(
...
...
@@ -388,9 +361,7 @@ def flash_attn_with_attn_pool_decode_tilelang(
gqa_group_size
=
q_h
//
k_h
O_tl
=
torch
.
zeros_like
(
Q
)
S_tl
=
torch
.
zeros
((
batch
,
q_h
,
math
.
ceil
(
real_max_k_seqlen
/
block_size
)),
dtype
=
Q
.
dtype
,
device
=
Q
.
device
)
S_tl
=
torch
.
zeros
((
batch
,
q_h
,
math
.
ceil
(
real_max_k_seqlen
/
block_size
)),
dtype
=
Q
.
dtype
,
device
=
Q
.
device
)
O_tl
,
S_tl
=
tl_kernel
(
Q
,
K
,
V
,
cu_seqlens_k
,
s_aux
)
if
use_per_kv_head_sparse_index
:
...
...
@@ -433,9 +404,7 @@ def flash_attn_with_attn_pool_decode(
BLOCK_H
=
64
O
=
torch
.
zeros_like
(
Q
)
S
=
torch
.
zeros
((
batch
,
q_h
,
math
.
ceil
(
max_seqlen_k
/
block_size
)),
dtype
=
Q
.
dtype
,
device
=
Q
.
device
)
S
=
torch
.
zeros
((
batch
,
q_h
,
math
.
ceil
(
max_seqlen_k
/
block_size
)),
dtype
=
Q
.
dtype
,
device
=
Q
.
device
)
def
grid
(
META
):
return
(
batch
,
k_h
)
...
...
@@ -483,15 +452,15 @@ def test_equal_seqlen_decode_main(args):
dtype
=
torch
.
bfloat16
if
args
.
dtype
==
"bfloat16"
else
torch
.
float16
# For decode, query is just 1 token per batch
q
=
torch
.
randn
(
batch_size
,
q_heads
,
head_size
,
device
=
'
cuda
'
,
dtype
=
dtype
)
k
=
torch
.
randn
(
batch_size
,
kv_heads
,
k_seqlen
,
head_size
,
device
=
'
cuda
'
,
dtype
=
dtype
)
v
=
torch
.
randn
(
batch_size
,
kv_heads
,
k_seqlen
,
head_size
,
device
=
'
cuda
'
,
dtype
=
dtype
)
q
=
torch
.
randn
(
batch_size
,
q_heads
,
head_size
,
device
=
"
cuda
"
,
dtype
=
dtype
)
k
=
torch
.
randn
(
batch_size
,
kv_heads
,
k_seqlen
,
head_size
,
device
=
"
cuda
"
,
dtype
=
dtype
)
v
=
torch
.
randn
(
batch_size
,
kv_heads
,
k_seqlen
,
head_size
,
device
=
"
cuda
"
,
dtype
=
dtype
)
softmax_scale
=
1.0
/
math
.
sqrt
(
head_size
)
# Generate sink values if needed
sink
=
None
if
args
.
test_sink
:
sink
=
torch
.
randn
(
q_heads
,
device
=
'
cuda
'
,
dtype
=
torch
.
float32
)
*
0.1
# Small sink values
sink
=
torch
.
randn
(
q_heads
,
device
=
"
cuda
"
,
dtype
=
torch
.
float32
)
*
0.1
# Small sink values
print
(
f
"Using sink attention with sink values:
{
sink
}
"
)
# Convert to varlen format for K, V
...
...
@@ -499,8 +468,7 @@ def test_equal_seqlen_decode_main(args):
v_varlen
=
v
.
transpose
(
1
,
2
).
reshape
(
batch_size
*
k_seqlen
,
kv_heads
,
head_size
)
# Generate cumulative sequence lengths
cu_seqlens_k
=
torch
.
arange
(
0
,
(
batch_size
+
1
)
*
k_seqlen
,
k_seqlen
,
device
=
'cuda'
,
dtype
=
torch
.
int32
)
cu_seqlens_k
=
torch
.
arange
(
0
,
(
batch_size
+
1
)
*
k_seqlen
,
k_seqlen
,
device
=
"cuda"
,
dtype
=
torch
.
int32
)
max_seqlen_k
=
k_seqlen
print
(
f
"q shape:
{
q
.
shape
}
"
)
...
...
@@ -510,8 +478,7 @@ def test_equal_seqlen_decode_main(args):
num_tokens
,
q_h
,
head_size
=
q
.
shape
batch
=
cu_seqlens_k
.
size
(
0
)
-
1
k_h
=
k_varlen
.
size
(
1
)
tl_kernel
=
flashattn
(
batch
,
q_h
,
k_h
,
args
.
k_seqlen
,
cu_seqlens_k
[
-
1
].
item
(),
head_size
,
args
.
test_sink
)
tl_kernel
=
flashattn
(
batch
,
q_h
,
k_h
,
args
.
k_seqlen
,
cu_seqlens_k
[
-
1
].
item
(),
head_size
,
args
.
test_sink
)
# Test our decode kernel
O_triton
,
S_triton
=
flash_attn_with_attn_pool_decode
(
...
...
@@ -524,7 +491,8 @@ def test_equal_seqlen_decode_main(args):
args
.
num_split
,
softmax_scale
,
s_aux
=
sink
,
block_size
=
block_size
)
block_size
=
block_size
,
)
O_tilelang
,
S_tilelang
=
flash_attn_with_attn_pool_decode_tilelang
(
q
,
k_varlen
,
...
...
@@ -539,9 +507,7 @@ def test_equal_seqlen_decode_main(args):
tl_kernel
=
tl_kernel
,
)
for
i
in
range
(
batch_size
):
S_tilelang
[
i
,
:,
math
.
ceil
((
cu_seqlens_k
[
i
+
1
].
item
()
-
cu_seqlens_k
[
i
].
item
())
/
block_size
):]
=
0
S_tilelang
[
i
,
:,
math
.
ceil
((
cu_seqlens_k
[
i
+
1
].
item
()
-
cu_seqlens_k
[
i
].
item
())
/
block_size
)
:]
=
0
# Compute torch reference
q_expanded
=
q
.
unsqueeze
(
2
)
# [b, q_heads, 1, head_size]
...
...
@@ -550,14 +516,12 @@ def test_equal_seqlen_decode_main(args):
if
sink
is
None
:
# Standard scaled dot-product attention
logits
=
torch
.
matmul
(
q_expanded
,
k_repeat
.
transpose
(
-
2
,
-
1
))
*
softmax_scale
# [batch, q_heads, 1, seqlen_k]
logits
=
torch
.
matmul
(
q_expanded
,
k_repeat
.
transpose
(
-
2
,
-
1
))
*
softmax_scale
# [batch, q_heads, 1, seqlen_k]
attn_weights
=
torch
.
softmax
(
logits
,
dim
=-
1
)
O_torch
=
torch
.
matmul
(
attn_weights
,
v_repeat
).
squeeze
(
2
)
# [batch, q_heads, head_size]
else
:
# s_aux attention
logits
=
torch
.
matmul
(
q_expanded
,
k_repeat
.
transpose
(
-
2
,
-
1
))
*
softmax_scale
# [batch, q_heads, 1, seqlen_k]
logits
=
torch
.
matmul
(
q_expanded
,
k_repeat
.
transpose
(
-
2
,
-
1
))
*
softmax_scale
# [batch, q_heads, 1, seqlen_k]
sink_expanded
=
sink
.
view
(
1
,
q_heads
,
1
,
1
)
# [1, q_heads, 1, 1]
logits_max
=
torch
.
max
(
logits
,
dim
=-
1
,
keepdim
=
True
).
values
...
...
@@ -566,15 +530,15 @@ def test_equal_seqlen_decode_main(args):
unnormalized_scores
=
torch
.
exp
(
logits
-
logits_or_sinks_max
)
normalizer
=
unnormalized_scores
.
sum
(
dim
=-
1
,
keepdim
=
True
)
+
sinks
attn_weights
=
unnormalized_scores
/
normalizer
O_torch
=
torch
.
matmul
(
attn_weights
.
to
(
v_repeat
.
dtype
),
v_repeat
).
squeeze
(
2
)
# [batch, q_heads, head_size]
O_torch
=
torch
.
matmul
(
attn_weights
.
to
(
v_repeat
.
dtype
),
v_repeat
).
squeeze
(
2
)
# [batch, q_heads, head_size]
# Compute attention score pooling
attn_score_pooled
=
torch
.
max_pool2d
(
attn_weights
.
squeeze
(
2
),
# [b, q_heads, k_seqlen]
kernel_size
=
(
q_heads
,
block_size
),
stride
=
(
q_heads
,
block_size
),
ceil_mode
=
True
).
to
(
torch
.
float16
)
ceil_mode
=
True
,
).
to
(
torch
.
float16
)
print
(
"S_tilelang"
,
S_tilelang
)
print
(
"attn_score_pooled"
,
attn_score_pooled
)
...
...
@@ -588,15 +552,10 @@ def test_equal_seqlen_decode_main(args):
print
(
f
"Max difference in S:
{
max_diff_s
.
item
()
}
"
)
print
(
f
"Max difference in O_tilelang:
{
max_diff_o_tilelang
.
item
()
}
"
)
print
(
f
"Max difference in S_tilelang:
{
max_diff_s_tilelang
.
item
()
}
"
)
assert
torch
.
allclose
(
O_triton
,
O_torch
,
atol
=
1e-2
,
rtol
=
1e-2
),
f
"Output mismatch:
{
max_diff_o
.
item
()
}
"
assert
torch
.
allclose
(
S_triton
,
attn_score_pooled
,
atol
=
1e-2
,
rtol
=
1e-2
),
f
"Score mismatch:
{
max_diff_s
.
item
()
}
"
assert
torch
.
allclose
(
O_tilelang
,
O_torch
,
atol
=
1e-2
,
rtol
=
1e-2
),
f
"Output mismatch:
{
max_diff_o_tilelang
.
item
()
}
"
assert
torch
.
allclose
(
S_tilelang
,
attn_score_pooled
,
atol
=
1e-2
,
rtol
=
1e-2
),
f
"Score mismatch:
{
max_diff_s_tilelang
.
item
()
}
"
assert
torch
.
allclose
(
O_triton
,
O_torch
,
atol
=
1e-2
,
rtol
=
1e-2
),
f
"Output mismatch:
{
max_diff_o
.
item
()
}
"
assert
torch
.
allclose
(
S_triton
,
attn_score_pooled
,
atol
=
1e-2
,
rtol
=
1e-2
),
f
"Score mismatch:
{
max_diff_s
.
item
()
}
"
assert
torch
.
allclose
(
O_tilelang
,
O_torch
,
atol
=
1e-2
,
rtol
=
1e-2
),
f
"Output mismatch:
{
max_diff_o_tilelang
.
item
()
}
"
assert
torch
.
allclose
(
S_tilelang
,
attn_score_pooled
,
atol
=
1e-2
,
rtol
=
1e-2
),
f
"Score mismatch:
{
max_diff_s_tilelang
.
item
()
}
"
print
(
"✅ All tests passed!"
)
...
...
@@ -616,7 +575,7 @@ def test_varlen_decode_main(args):
# Generate sink values if needed
sink
=
None
if
args
.
test_sink
:
sink
=
torch
.
randn
(
q_heads
,
device
=
'
cuda
'
,
dtype
=
torch
.
float32
)
*
0.1
# Small sink values
sink
=
torch
.
randn
(
q_heads
,
device
=
"
cuda
"
,
dtype
=
torch
.
float32
)
*
0.1
# Small sink values
print
(
f
"Using sink attention with sink values:
{
sink
}
"
)
# Generate variable length k sequences
...
...
@@ -624,7 +583,7 @@ def test_varlen_decode_main(args):
print
(
f
"k_seqlens:
{
k_seqlens
}
"
)
# Generate cumulative sequence lengths for k
cu_seqlens_k
=
torch
.
zeros
(
batch_size
+
1
,
device
=
'
cuda
'
,
dtype
=
torch
.
int32
)
cu_seqlens_k
=
torch
.
zeros
(
batch_size
+
1
,
device
=
"
cuda
"
,
dtype
=
torch
.
int32
)
total_k_tokens
=
0
for
i
in
range
(
batch_size
):
cu_seqlens_k
[
i
]
=
total_k_tokens
...
...
@@ -634,9 +593,9 @@ def test_varlen_decode_main(args):
print
(
f
"cu_seqlens_k:
{
cu_seqlens_k
}
"
)
# Generate tensors - Q is [batch_size, q_heads, head_size] for decode
q_decode
=
torch
.
randn
(
batch_size
,
q_heads
,
head_size
,
device
=
'
cuda
'
,
dtype
=
dtype
)
k_varlen
=
torch
.
randn
(
total_k_tokens
,
kv_heads
,
head_size
,
device
=
'
cuda
'
,
dtype
=
dtype
)
v_varlen
=
torch
.
randn
(
total_k_tokens
,
kv_heads
,
head_size
,
device
=
'
cuda
'
,
dtype
=
dtype
)
q_decode
=
torch
.
randn
(
batch_size
,
q_heads
,
head_size
,
device
=
"
cuda
"
,
dtype
=
dtype
)
k_varlen
=
torch
.
randn
(
total_k_tokens
,
kv_heads
,
head_size
,
device
=
"
cuda
"
,
dtype
=
dtype
)
v_varlen
=
torch
.
randn
(
total_k_tokens
,
kv_heads
,
head_size
,
device
=
"
cuda
"
,
dtype
=
dtype
)
softmax_scale
=
1.0
/
math
.
sqrt
(
head_size
)
max_seqlen_k
=
int
(
k_seqlens
.
max
())
...
...
@@ -649,8 +608,7 @@ def test_varlen_decode_main(args):
num_tokens
,
q_h
,
head_size
=
q_decode
.
shape
batch
=
cu_seqlens_k
.
size
(
0
)
-
1
k_h
=
k_varlen
.
size
(
1
)
tl_kernel
=
flashattn
(
batch
,
q_h
,
k_h
,
args
.
k_seqlen
,
cu_seqlens_k
[
-
1
].
item
(),
head_size
,
args
.
test_sink
)
tl_kernel
=
flashattn
(
batch
,
q_h
,
k_h
,
args
.
k_seqlen
,
cu_seqlens_k
[
-
1
].
item
(),
head_size
,
args
.
test_sink
)
# Test our decode kernel
O_triton
,
S_triton
=
flash_attn_with_attn_pool_decode
(
...
...
@@ -663,7 +621,8 @@ def test_varlen_decode_main(args):
args
.
num_split
,
softmax_scale
,
s_aux
=
sink
,
block_size
=
block_size
)
block_size
=
block_size
,
)
O_tilelang
,
S_tilelang
=
flash_attn_with_attn_pool_decode_tilelang
(
q_decode
,
k_varlen
,
...
...
@@ -678,9 +637,7 @@ def test_varlen_decode_main(args):
tl_kernel
=
tl_kernel
,
)
for
i
in
range
(
batch_size
):
S_tilelang
[
i
,
:,
math
.
ceil
((
cu_seqlens_k
[
i
+
1
].
item
()
-
cu_seqlens_k
[
i
].
item
())
/
block_size
):]
=
0
S_tilelang
[
i
,
:,
math
.
ceil
((
cu_seqlens_k
[
i
+
1
].
item
()
-
cu_seqlens_k
[
i
].
item
())
/
block_size
)
:]
=
0
# Create torch reference - pad tensors for comparison
k_padded_list
=
[]
...
...
@@ -694,8 +651,8 @@ def test_varlen_decode_main(args):
k_end
=
cu_seqlens_k
[
i
+
1
]
# Pad to max_seqlen_k
k_padded
=
torch
.
zeros
(
max_seqlen_k
,
kv_heads
,
head_size
,
device
=
'
cuda
'
,
dtype
=
dtype
)
v_padded
=
torch
.
zeros
(
max_seqlen_k
,
kv_heads
,
head_size
,
device
=
'
cuda
'
,
dtype
=
dtype
)
k_padded
=
torch
.
zeros
(
max_seqlen_k
,
kv_heads
,
head_size
,
device
=
"
cuda
"
,
dtype
=
dtype
)
v_padded
=
torch
.
zeros
(
max_seqlen_k
,
kv_heads
,
head_size
,
device
=
"
cuda
"
,
dtype
=
dtype
)
k_padded
[:
actual_k_len
]
=
k_varlen
[
k_start
:
k_end
]
v_padded
[:
actual_k_len
]
=
v_varlen
[
k_start
:
k_end
]
...
...
@@ -704,10 +661,8 @@ def test_varlen_decode_main(args):
v_padded_list
.
append
(
v_padded
)
# Stack to create batched tensors [b, max_seqlen, kv_heads, head_size]
k_padded_batched
=
torch
.
stack
(
k_padded_list
,
dim
=
0
).
transpose
(
1
,
2
)
# [b, kv_heads, max_seqlen, head_size]
v_padded_batched
=
torch
.
stack
(
v_padded_list
,
dim
=
0
).
transpose
(
1
,
2
)
# [b, kv_heads, max_seqlen, head_size]
k_padded_batched
=
torch
.
stack
(
k_padded_list
,
dim
=
0
).
transpose
(
1
,
2
)
# [b, kv_heads, max_seqlen, head_size]
v_padded_batched
=
torch
.
stack
(
v_padded_list
,
dim
=
0
).
transpose
(
1
,
2
)
# [b, kv_heads, max_seqlen, head_size]
# Expand q to match kv heads: [b, q_heads, 1, head_size]
q_expanded
=
q_decode
.
unsqueeze
(
2
)
# [b, q_heads, 1, head_size]
...
...
@@ -717,20 +672,17 @@ def test_varlen_decode_main(args):
print
(
f
"v_padded_batched shape:
{
v_padded_batched
.
shape
}
"
)
# Compute torch reference
k_repeat
=
repeat_kv
(
k_padded_batched
,
q_heads
//
kv_heads
)
# [b, q_heads, max_seqlen, head_size]
v_repeat
=
repeat_kv
(
v_padded_batched
,
q_heads
//
kv_heads
)
# [b, q_heads, max_seqlen, head_size]
k_repeat
=
repeat_kv
(
k_padded_batched
,
q_heads
//
kv_heads
)
# [b, q_heads, max_seqlen, head_size]
v_repeat
=
repeat_kv
(
v_padded_batched
,
q_heads
//
kv_heads
)
# [b, q_heads, max_seqlen, head_size]
if
sink
is
None
:
# Standard attention computation: [b, q_heads, 1, head_size] @ [b, q_heads, head_size, max_seqlen]
attn_score
=
torch
.
matmul
(
q_expanded
,
k_repeat
.
transpose
(
-
2
,
-
1
))
*
softmax_scale
# [b, q_heads, 1, max_seqlen]
attn_score
=
torch
.
matmul
(
q_expanded
,
k_repeat
.
transpose
(
-
2
,
-
1
))
*
softmax_scale
# [b, q_heads, 1, max_seqlen]
# Apply sequence length masking
for
i
in
range
(
batch_size
):
actual_k_len
=
k_seqlens
[
i
]
attn_score
[
i
,
:,
:,
actual_k_len
:]
=
float
(
'
-inf
'
)
attn_score
[
i
,
:,
:,
actual_k_len
:]
=
float
(
"
-inf
"
)
attn_weights
=
attn_score
.
softmax
(
dim
=-
1
)
# [b, q_heads, 1, max_seqlen]
...
...
@@ -743,13 +695,12 @@ def test_varlen_decode_main(args):
O_torch
=
torch
.
matmul
(
attn_weights
,
v_repeat
)
# [b, q_heads, 1, head_size]
else
:
# s_aux attention
logits
=
torch
.
matmul
(
q_expanded
,
k_repeat
.
transpose
(
-
2
,
-
1
))
*
softmax_scale
# [b, q_heads, 1, max_seqlen]
logits
=
torch
.
matmul
(
q_expanded
,
k_repeat
.
transpose
(
-
2
,
-
1
))
*
softmax_scale
# [b, q_heads, 1, max_seqlen]
# Apply sequence length masking
for
i
in
range
(
batch_size
):
actual_k_len
=
k_seqlens
[
i
]
logits
[
i
,
:,
:,
actual_k_len
:]
=
float
(
'
-inf
'
)
logits
[
i
,
:,
:,
actual_k_len
:]
=
float
(
"
-inf
"
)
sink_expanded
=
sink
.
view
(
1
,
q_heads
,
1
,
1
)
# [1, q_heads, 1, 1]
logits_max
=
torch
.
max
(
logits
,
dim
=-
1
,
keepdim
=
True
).
values
...
...
@@ -765,8 +716,7 @@ def test_varlen_decode_main(args):
attn_weights
[
i
,
:,
:,
actual_k_len
:]
=
0.0
# Compute output: [b, q_heads, 1, max_seqlen] @ [b, q_heads, max_seqlen, head_size]
O_torch
=
torch
.
matmul
(
attn_weights
.
to
(
v_repeat
.
dtype
),
v_repeat
)
# [b, q_heads, 1, head_size]
O_torch
=
torch
.
matmul
(
attn_weights
.
to
(
v_repeat
.
dtype
),
v_repeat
)
# [b, q_heads, 1, head_size]
O_torch
=
O_torch
.
squeeze
(
2
)
# [b, q_heads, head_size]
...
...
@@ -775,7 +725,8 @@ def test_varlen_decode_main(args):
attn_weights
.
squeeze
(
2
),
# [b, q_heads, max_seqlen]
kernel_size
=
(
q_heads
,
block_size
),
stride
=
(
q_heads
,
block_size
),
ceil_mode
=
True
).
to
(
dtype
=
torch
.
float16
)
# [b, 1, ceil(max_seqlen/block_size)]
ceil_mode
=
True
,
).
to
(
dtype
=
torch
.
float16
)
# [b, 1, ceil(max_seqlen/block_size)]
print
(
f
"O_triton shape:
{
O_triton
.
shape
}
"
)
print
(
f
"O_tilelang shape:
{
O_tilelang
.
shape
}
"
)
...
...
@@ -791,22 +742,16 @@ def test_varlen_decode_main(args):
print
(
f
"Max difference in O_tilelang:
{
max_diff_o_tl
.
item
()
}
"
)
max_diff_s
=
torch
.
max
(
torch
.
abs
(
S_triton
-
attn_score_pooled
))
max_diff_s_tl
=
torch
.
max
(
torch
.
abs
(
S_tilelang
[:,
:,
:
math
.
ceil
(
max_seqlen_k
/
block_size
)]
-
attn_score_pooled
))
max_diff_s_tl
=
torch
.
max
(
torch
.
abs
(
S_tilelang
[:,
:,
:
math
.
ceil
(
max_seqlen_k
/
block_size
)]
-
attn_score_pooled
))
print
(
f
"Max difference in S:
{
max_diff_s
.
item
()
}
"
)
print
(
f
"Max difference in S_tilelang:
{
max_diff_s_tl
.
item
()
}
"
)
assert
torch
.
allclose
(
O_triton
,
O_torch
,
atol
=
1e-2
,
rtol
=
1e-2
),
f
"Output mismatch:
{
max_diff_o
.
item
()
}
"
assert
torch
.
allclose
(
S_triton
,
attn_score_pooled
,
atol
=
1e-2
,
rtol
=
1e-2
),
f
"Score mismatch:
{
max_diff_s
.
item
()
}
"
assert
torch
.
allclose
(
O_tilelang
,
O_torch
,
atol
=
1e-2
,
rtol
=
1e-2
),
f
"Output mismatch:
{
max_diff_o_tl
.
item
()
}
"
assert
torch
.
allclose
(
S_tilelang
[:,
:,
:
math
.
ceil
(
max_seqlen_k
/
block_size
)],
attn_score_pooled
,
atol
=
1e-2
,
rtol
=
1e-2
),
f
"Score mismatch:
{
max_diff_s_tl
.
item
()
}
"
assert
torch
.
allclose
(
O_triton
,
O_torch
,
atol
=
1e-2
,
rtol
=
1e-2
),
f
"Output mismatch:
{
max_diff_o
.
item
()
}
"
assert
torch
.
allclose
(
S_triton
,
attn_score_pooled
,
atol
=
1e-2
,
rtol
=
1e-2
),
f
"Score mismatch:
{
max_diff_s
.
item
()
}
"
assert
torch
.
allclose
(
O_tilelang
,
O_torch
,
atol
=
1e-2
,
rtol
=
1e-2
),
f
"Output mismatch:
{
max_diff_o_tl
.
item
()
}
"
assert
torch
.
allclose
(
S_tilelang
[:,
:,
:
math
.
ceil
(
max_seqlen_k
/
block_size
)],
attn_score_pooled
,
atol
=
1e-2
,
rtol
=
1e-2
),
(
f
"Score mismatch:
{
max_diff_s_tl
.
item
()
}
"
)
print
(
"✅ All tests passed!"
)
...
...
@@ -865,7 +810,7 @@ def speed_benchmark_decode_comparison(args):
k_seqlens
=
torch
.
full
((
batch_size
,),
max_k_seqlen
,
dtype
=
int
)
# Generate cumulative sequence lengths for k
cu_seqlens_k
=
torch
.
zeros
(
batch_size
+
1
,
device
=
'
cuda
'
,
dtype
=
torch
.
int32
)
cu_seqlens_k
=
torch
.
zeros
(
batch_size
+
1
,
device
=
"
cuda
"
,
dtype
=
torch
.
int32
)
total_k_tokens
=
0
for
i
in
range
(
batch_size
):
cu_seqlens_k
[
i
]
=
total_k_tokens
...
...
@@ -873,9 +818,9 @@ def speed_benchmark_decode_comparison(args):
cu_seqlens_k
[
batch_size
]
=
total_k_tokens
# Generate tensors
q_decode
=
torch
.
randn
(
batch_size
,
q_heads
,
head_size
,
device
=
'
cuda
'
,
dtype
=
dtype
)
k_varlen
=
torch
.
randn
(
total_k_tokens
,
kv_heads
,
head_size
,
device
=
'
cuda
'
,
dtype
=
dtype
)
v_varlen
=
torch
.
randn
(
total_k_tokens
,
kv_heads
,
head_size
,
device
=
'
cuda
'
,
dtype
=
dtype
)
q_decode
=
torch
.
randn
(
batch_size
,
q_heads
,
head_size
,
device
=
"
cuda
"
,
dtype
=
dtype
)
k_varlen
=
torch
.
randn
(
total_k_tokens
,
kv_heads
,
head_size
,
device
=
"
cuda
"
,
dtype
=
dtype
)
v_varlen
=
torch
.
randn
(
total_k_tokens
,
kv_heads
,
head_size
,
device
=
"
cuda
"
,
dtype
=
dtype
)
softmax_scale
=
1.0
/
math
.
sqrt
(
head_size
)
max_seqlen_k
=
int
(
k_seqlens
.
max
())
...
...
@@ -883,7 +828,7 @@ def speed_benchmark_decode_comparison(args):
# Generate sink values if needed
sink
=
None
if
args
.
test_sink
:
sink
=
torch
.
randn
(
q_heads
,
device
=
'
cuda
'
,
dtype
=
torch
.
float32
)
*
0.1
# Small sink values
sink
=
torch
.
randn
(
q_heads
,
device
=
"
cuda
"
,
dtype
=
torch
.
float32
)
*
0.1
# Small sink values
print
(
" Using sink attention with sink values"
)
print
(
"Setup complete:"
)
...
...
@@ -896,8 +841,7 @@ def speed_benchmark_decode_comparison(args):
    num_tokens, q_h, head_size = q_decode.shape
    batch = cu_seqlens_k.size(0) - 1
    k_h = k_varlen.size(1)
    tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size, args.test_sink)

    # Benchmark
    print("⚡ Benchmarking Tilelang kernel (100 iterations)...")
...
@@ -920,36 +864,41 @@ def speed_benchmark_decode_comparison(args):
    # Benchmark
    print("⚡ Benchmarking Triton kernel (100 iterations)...")
-   triton_time = do_bench(flash_attn_with_attn_pool_decode, q_decode, k_varlen, v_varlen, cu_seqlens_k, max_seqlen_k, args.k_seqlen, 1, softmax_scale, sink, block_size)
+   triton_time = do_bench(
+       flash_attn_with_attn_pool_decode,
+       q_decode,
+       k_varlen,
+       v_varlen,
+       cu_seqlens_k,
+       max_seqlen_k,
+       args.k_seqlen,
+       1,
+       softmax_scale,
+       sink,
+       block_size,
+   )
    print(f"Average decode kernel time Triton: {triton_time:.3f} ms")
    print(f"Speedup: {(triton_time / tilelang_time):.3f}")


if __name__ == "__main__":
-   parser = argparse.ArgumentParser(description='Flash Attention Decode with Attention Pooling')
-   parser.add_argument('--batch_size', type=int, default=1, help='Batch size')
-   parser.add_argument('--q_heads', type=int, default=32, help='Number of query heads')
-   parser.add_argument('--kv_heads', type=int, default=8, help='Number of key-value heads')
-   parser.add_argument('--k_seqlen', type=int, default=8192, help='Key sequence length')
-   parser.add_argument('--head_size', type=int, default=128, choices=[64, 128, 256], help='Head dimension')
-   parser.add_argument('--block_size', type=int, default=64, help='Block size for computation')
-   parser.add_argument('--dtype', type=str, default='bfloat16', choices=['float16', 'bfloat16'], help='Data type')
-   parser.add_argument('--test_varlen', action='store_true', help='Test with truly variable sequence lengths')
-   parser.add_argument('--test_sink', action='store_true', help='Test with sink attention mechanism')
-   parser.add_argument('--benchmark', action='store_true', help='Run speed benchmark')
-   parser.add_argument('--num_split', type=int, default=1, choices=[1, 16], help='Number of splits')
+   parser = argparse.ArgumentParser(description="Flash Attention Decode with Attention Pooling")
+   parser.add_argument("--batch_size", type=int, default=1, help="Batch size")
+   parser.add_argument("--q_heads", type=int, default=32, help="Number of query heads")
+   parser.add_argument("--kv_heads", type=int, default=8, help="Number of key-value heads")
+   parser.add_argument("--k_seqlen", type=int, default=8192, help="Key sequence length")
+   parser.add_argument("--head_size", type=int, default=128, choices=[64, 128, 256], help="Head dimension")
+   parser.add_argument("--block_size", type=int, default=64, help="Block size for computation")
+   parser.add_argument("--dtype", type=str, default="bfloat16", choices=["float16", "bfloat16"], help="Data type")
+   parser.add_argument("--test_varlen", action="store_true", help="Test with truly variable sequence lengths")
+   parser.add_argument("--test_sink", action="store_true", help="Test with sink attention mechanism")
+   parser.add_argument("--benchmark", action="store_true", help="Run speed benchmark")
+   parser.add_argument("--num_split", type=int, default=1, choices=[1, 16], help="Number of splits")
    args = parser.parse_args()
    args.test_sink = True
    args.test_varlen = False
-   args.dtype = 'float16'
+   args.dtype = "float16"
    args.num_split = 1
    if args.benchmark:
...
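Note: both benchmark branches above report an average kernel time from a `do_bench` helper and derive the speedup as `triton_time / tilelang_time`. A minimal sketch of such a CUDA-event timing loop, assuming only that the callable runs on the current CUDA stream (the helper below is illustrative, not the repository's `do_bench` implementation):

import torch

def time_kernel_ms(fn, *args, warmup=10, iters=100, **kwargs):
    # Warm up so one-time compilation and caching do not pollute the measurement.
    for _ in range(warmup):
        fn(*args, **kwargs)
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn(*args, **kwargs)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # average milliseconds per call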
examples/flash_decoding/example_gqa_decode_varlen_logits_paged.py  View file @ 29051439
...
@@ -10,6 +10,7 @@ torch.manual_seed(0)
def get_configs():
    import itertools

    block_N = [64, 128]
    block_H = [64]
    num_split = [1]
...
@@ -17,19 +18,14 @@ def get_configs():
    threads = [128]
    _configs = list(itertools.product(block_N, block_H, num_split, num_stages, threads))
-   configs = [{
-       'block_N': c[0],
-       'block_H': c[1],
-       'num_split': c[2],
-       'num_stages': c[3],
-       'threads': c[4]
-   } for c in _configs]
+   configs = [{"block_N": c[0], "block_H": c[1], "num_split": c[2], "num_stages": c[3], "threads": c[4]} for c in _configs]
    return configs


# @autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(out_idx=[-2, -1], debug_root_path="./examples/flash_decoding")
-def flashattn(batch,
+def flashattn(
+   batch,
    heads,
    k_heads,
    max_seqlen_kv,
...
@@ -41,8 +37,9 @@ def flashattn(batch,
    block_H=64,
    num_split=1,
    num_stages=1,
-   threads=128):
-   scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
+   threads=128,
+):
+   scale = (1.0 / dim) ** 0.5 * 1.44269504  # log2(e)
    shape_q = [batch, heads, dim]
    shape_k = [total_seqlen_k, k_heads, dim]
    shape_v = [total_seqlen_k, k_heads, dim]
...
@@ -51,7 +48,9 @@ def flashattn(batch,
    dtype = "float16"
    accum_dtype = "float"
    kv_group_num = heads // k_heads
-   assert page_block_size >= block_N and page_block_size % block_N == 0, "page_block_size must be larger than block_N and a multiple of block_N"
+   assert page_block_size >= block_N and page_block_size % block_N == 0, (
+       "page_block_size must be larger than block_N and a multiple of block_N"
+   )
    valid_block_H = min(block_H, kv_group_num)
    # TODO: check if max_seqlen_kv is correct for varlen case
...
@@ -91,7 +90,7 @@ def flashattn(batch,
            cur_end_k = cu_seqlens_k[bid + 1]
            cur_seqlen_k = cur_end_k - cur_start_k
            T.copy(Q[bid, hid * valid_block_H : hid * valid_block_H + block_H, :], Q_shared)
            T.fill(acc_o, 0)
            T.fill(logsum, 0)
            T.fill(scores_max, -T.infinity(accum_dtype))
...
@@ -99,15 +98,12 @@ def flashattn(batch,
            # loop_range = T.ceildiv((seqlen_kv // num_split), block_N)
            loop_range = T.ceildiv((cur_seqlen_k // num_split), block_N)
            for k in T.Pipelined(loop_range, num_stages=num_stages):
                k_start = BLOCK_TABLE[bid, (k * block_N) // page_block_size] * page_block_size + (k * block_N) % page_block_size
                T.copy(K[cur_start_k + k_start : cur_start_k + k_start + block_N, cur_kv_head, :], K_shared)
                T.clear(acc_s)
                T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                for i, j in T.Parallel(block_H, block_N):
                    acc_s[i, j] = T.if_then_else(k * block_N + j < cur_seqlen_k, acc_s[i, j], -T.infinity(accum_dtype))
                T.copy(scores_max, scores_max_prev)
                T.fill(scores_max, -T.infinity(accum_dtype))
                T.reduce_max(acc_s, scores_max, dim=1, clear=False)
...
@@ -127,14 +123,12 @@ def flashattn(batch,
                T.copy(acc_s, acc_s_cast)
                for i, j in T.Parallel(block_H, dim):
                    acc_o[i, j] *= scores_scale[i]
                v_start = BLOCK_TABLE[bid, (k * block_N) // page_block_size] * page_block_size + (k * block_N) % page_block_size
                T.copy(V[cur_start_k + v_start : cur_start_k + v_start + block_N, cur_kv_head, :], V_shared)
                T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
            if has_sink:
                T.copy(s_aux[hid * valid_block_H : hid * valid_block_H + block_H], s_aux_shared)
                for i in T.Parallel(block_H):
                    logsum[i] += s_aux_shared[i]
            for i, j in T.Parallel(block_H, dim):
...
@@ -144,9 +138,8 @@ def flashattn(batch,
            for i in T.Parallel(block_H):
                logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
            T.copy(acc_o[:valid_block_H, :], O_shared)
-           T.copy(O_shared, Output[bid, hid * valid_block_H:(hid + 1) * valid_block_H, :])
-           T.copy(S_shared[:valid_block_H, :], S[bid, hid * valid_block_H:(hid + 1) * valid_block_H, :])
+           T.copy(O_shared, Output[bid, hid * valid_block_H : (hid + 1) * valid_block_H, :])
+           T.copy(S_shared[:valid_block_H, :], S[bid, hid * valid_block_H : (hid + 1) * valid_block_H, :])

    @T.prim_func
    def flashattn_gqa_decode_no_split(
...
@@ -195,9 +188,7 @@ def flash_attn_with_attn_pool_decode_tilelang(
    gqa_group_size = q_h // k_h
    O_tl = torch.zeros_like(Q)
    S_tl = torch.zeros((batch, q_h, math.ceil(real_max_k_seqlen / block_size)), dtype=Q.dtype, device=Q.device)
    O_tl, S_tl = tl_kernel(Q, K, V, cu_seqlens_k, s_aux, block_table)

    if use_per_kv_head_sparse_index:
...
@@ -223,15 +214,15 @@ def test_equal_seqlen_decode_main(args):
    dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16

    # For decode, query is just 1 token per batch
-   q = torch.randn(batch_size, q_heads, head_size, device='cuda', dtype=dtype)
-   k = torch.randn(batch_size, kv_heads, k_seqlen, head_size, device='cuda', dtype=dtype)
-   v = torch.randn(batch_size, kv_heads, k_seqlen, head_size, device='cuda', dtype=dtype)
+   q = torch.randn(batch_size, q_heads, head_size, device="cuda", dtype=dtype)
+   k = torch.randn(batch_size, kv_heads, k_seqlen, head_size, device="cuda", dtype=dtype)
+   v = torch.randn(batch_size, kv_heads, k_seqlen, head_size, device="cuda", dtype=dtype)
    softmax_scale = 1.0 / math.sqrt(head_size)

    # Generate sink values if needed
    sink = None
    if args.test_sink:
-       sink = torch.randn(q_heads, device='cuda', dtype=torch.float32) * 0.1  # Small sink values
+       sink = torch.randn(q_heads, device="cuda", dtype=torch.float32) * 0.1  # Small sink values
        print(f"Using sink attention with sink values: {sink}")

    # Convert to varlen format for K, V
...
@@ -239,8 +230,7 @@ def test_equal_seqlen_decode_main(args):
    v_varlen = v.transpose(1, 2).reshape(batch_size * k_seqlen, kv_heads, head_size).contiguous()

    # Generate cumulative sequence lengths
-   cu_seqlens_k = torch.arange(0, (batch_size + 1) * k_seqlen, k_seqlen, device='cuda', dtype=torch.int32)
+   cu_seqlens_k = torch.arange(0, (batch_size + 1) * k_seqlen, k_seqlen, device="cuda", dtype=torch.int32)
    max_seqlen_k = k_seqlen

    print(f"q shape: {q.shape}")
...
@@ -250,11 +240,9 @@ def test_equal_seqlen_decode_main(args):
    num_tokens, q_h, head_size = q.shape
    batch = cu_seqlens_k.size(0) - 1
    k_h = k_varlen.size(1)
    tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size, args.test_sink, page_block_size)
-   block_table = torch.zeros(batch, math.ceil(real_max_k_seqlen / page_block_size), device='cuda', dtype=torch.int32)
+   block_table = torch.zeros(batch, math.ceil(real_max_k_seqlen / page_block_size), device="cuda", dtype=torch.int32)
    block_cnt = 0
    for i in range(batch):
        cur_seqlen = cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item()
...
@@ -274,7 +262,8 @@ def test_equal_seqlen_decode_main(args):
        args.num_split,
        softmax_scale,
        s_aux=sink,
-       block_size=block_size)
+       block_size=block_size,
+   )
    O_tilelang, S_tilelang = flash_attn_with_attn_pool_decode_tilelang(
        q,
        k_varlen,
...
@@ -290,9 +279,7 @@ def test_equal_seqlen_decode_main(args):
        block_table=block_table,
    )
    for i in range(batch_size):
-       S_tilelang[i, :, math.ceil((cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item()) / block_size):] = 0
+       S_tilelang[i, :, math.ceil((cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item()) / block_size) :] = 0

    # Compute torch reference
    q_expanded = q.unsqueeze(2)  # [b, q_heads, 1, head_size]
...
@@ -301,14 +288,12 @@ def test_equal_seqlen_decode_main(args):
    if sink is None:
        # Standard scaled dot-product attention
        logits = torch.matmul(q_expanded, k_repeat.transpose(-2, -1)) * softmax_scale  # [batch, q_heads, 1, seqlen_k]
        attn_weights = torch.softmax(logits, dim=-1)
        O_torch = torch.matmul(attn_weights, v_repeat).squeeze(2)  # [batch, q_heads, head_size]
    else:
        # s_aux attention
        logits = torch.matmul(q_expanded, k_repeat.transpose(-2, -1)) * softmax_scale  # [batch, q_heads, 1, seqlen_k]
        sink_expanded = sink.view(1, q_heads, 1, 1)  # [1, q_heads, 1, 1]
        logits_max = torch.max(logits, dim=-1, keepdim=True).values
...
@@ -317,15 +302,15 @@ def test_equal_seqlen_decode_main(args):
        unnormalized_scores = torch.exp(logits - logits_or_sinks_max)
        normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks
        attn_weights = unnormalized_scores / normalizer
        O_torch = torch.matmul(attn_weights.to(v_repeat.dtype), v_repeat).squeeze(2)  # [batch, q_heads, head_size]

    # Compute attention score pooling
    attn_score_pooled = torch.max_pool2d(
        attn_weights.squeeze(2),  # [b, q_heads, k_seqlen]
        kernel_size=(q_heads, block_size),
        stride=(q_heads, block_size),
-       ceil_mode=True).to(torch.float16)
+       ceil_mode=True,
+   ).to(torch.float16)

    print("S_tilelang", S_tilelang)
    print("attn_score_pooled", attn_score_pooled)
...
@@ -339,15 +324,10 @@ def test_equal_seqlen_decode_main(args):
    print(f"Max difference in S: {max_diff_s.item()}")
    print(f"Max difference in O_tilelang: {max_diff_o_tilelang.item()}")
    print(f"Max difference in S_tilelang: {max_diff_s_tilelang.item()}")
    assert torch.allclose(O_triton, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o.item()}"
    assert torch.allclose(S_triton, attn_score_pooled, atol=1e-2, rtol=1e-2), f"Score mismatch: {max_diff_s.item()}"
    assert torch.allclose(O_tilelang, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o_tilelang.item()}"
    assert torch.allclose(S_tilelang, attn_score_pooled, atol=1e-2, rtol=1e-2), f"Score mismatch: {max_diff_s_tilelang.item()}"

    print("✅ All tests passed!")
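Note: the `else` branch of the reference above implements the s_aux (attention-sink) softmax, where one extra per-head sink logit enters the normalizer but contributes no value vector. A standalone paraphrase of that check, assuming `logits` of shape [batch, heads, 1, seqlen] and `sink` of shape [heads] (a sketch for clarity, not the exact reference code):

import torch

def sink_softmax(logits: torch.Tensor, sink: torch.Tensor) -> torch.Tensor:
    sinks = sink.view(1, -1, 1, 1)                           # one sink logit per head, broadcast
    m = torch.maximum(logits.max(dim=-1, keepdim=True).values, sinks)
    scores = torch.exp(logits - m)                           # unnormalized attention scores
    normalizer = scores.sum(dim=-1, keepdim=True) + torch.exp(sinks - m)
    return scores / normalizer                               # rows sum to < 1; the remainder "leaks" into the sink

With `sink` set to negative infinity this reduces to the plain softmax branch used when `args.test_sink` is off.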
...
@@ -368,7 +348,7 @@ def test_varlen_decode_main(args):
    # Generate sink values if needed
    sink = None
    if args.test_sink:
-       sink = torch.randn(q_heads, device='cuda', dtype=torch.float32) * 0.1  # Small sink values
+       sink = torch.randn(q_heads, device="cuda", dtype=torch.float32) * 0.1  # Small sink values
        print(f"Using sink attention with sink values: {sink}")

    # Generate variable length k sequences
...
@@ -376,7 +356,7 @@ def test_varlen_decode_main(args):
    print(f"k_seqlens: {k_seqlens}")

    # Generate cumulative sequence lengths for k
-   cu_seqlens_k = torch.zeros(batch_size + 1, device='cuda', dtype=torch.int32)
+   cu_seqlens_k = torch.zeros(batch_size + 1, device="cuda", dtype=torch.int32)
    total_k_tokens = 0
    for i in range(batch_size):
        cu_seqlens_k[i] = total_k_tokens
...
@@ -386,9 +366,9 @@ def test_varlen_decode_main(args):
    print(f"cu_seqlens_k: {cu_seqlens_k}")

    # Generate tensors - Q is [batch_size, q_heads, head_size] for decode
-   q_decode = torch.randn(batch_size, q_heads, head_size, device='cuda', dtype=dtype)
-   k_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device='cuda', dtype=dtype)
-   v_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device='cuda', dtype=dtype)
+   q_decode = torch.randn(batch_size, q_heads, head_size, device="cuda", dtype=dtype)
+   k_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype)
+   v_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype)
    softmax_scale = 1.0 / math.sqrt(head_size)
    max_seqlen_k = int(k_seqlens.max())
...
@@ -401,11 +381,9 @@ def test_varlen_decode_main(args):
    num_tokens, q_h, head_size = q_decode.shape
    batch = cu_seqlens_k.size(0) - 1
    k_h = k_varlen.size(1)
    tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size, args.test_sink, page_block_size)
-   block_table = torch.zeros(batch, math.ceil(real_max_k_seqlen / page_block_size), device='cuda', dtype=torch.int32)
+   block_table = torch.zeros(batch, math.ceil(real_max_k_seqlen / page_block_size), device="cuda", dtype=torch.int32)
    block_cnt = 0
    for i in range(batch):
        cur_seqlen = cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item()
...
@@ -425,7 +403,8 @@ def test_varlen_decode_main(args):
        args.num_split,
        softmax_scale,
        s_aux=sink,
-       block_size=block_size)
+       block_size=block_size,
+   )
    O_tilelang, S_tilelang = flash_attn_with_attn_pool_decode_tilelang(
        q_decode,
        k_varlen,
...
@@ -441,9 +420,7 @@ def test_varlen_decode_main(args):
        block_table=block_table,
    )
    for i in range(batch_size):
-       S_tilelang[i, :, math.ceil((cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item()) / block_size):] = 0
+       S_tilelang[i, :, math.ceil((cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item()) / block_size) :] = 0

    # Create torch reference - pad tensors for comparison
    k_padded_list = []
...
@@ -457,8 +434,8 @@ def test_varlen_decode_main(args):
        k_end = cu_seqlens_k[i + 1]
        # Pad to max_seqlen_k
-       k_padded = torch.zeros(max_seqlen_k, kv_heads, head_size, device='cuda', dtype=dtype)
-       v_padded = torch.zeros(max_seqlen_k, kv_heads, head_size, device='cuda', dtype=dtype)
+       k_padded = torch.zeros(max_seqlen_k, kv_heads, head_size, device="cuda", dtype=dtype)
+       v_padded = torch.zeros(max_seqlen_k, kv_heads, head_size, device="cuda", dtype=dtype)
        k_padded[:actual_k_len] = k_varlen[k_start:k_end]
        v_padded[:actual_k_len] = v_varlen[k_start:k_end]
...
@@ -467,10 +444,8 @@ def test_varlen_decode_main(args):
        v_padded_list.append(v_padded)

    # Stack to create batched tensors [b, max_seqlen, kv_heads, head_size]
    k_padded_batched = torch.stack(k_padded_list, dim=0).transpose(1, 2)  # [b, kv_heads, max_seqlen, head_size]
    v_padded_batched = torch.stack(v_padded_list, dim=0).transpose(1, 2)  # [b, kv_heads, max_seqlen, head_size]

    # Expand q to match kv heads: [b, q_heads, 1, head_size]
    q_expanded = q_decode.unsqueeze(2)  # [b, q_heads, 1, head_size]
...
@@ -480,20 +455,17 @@ def test_varlen_decode_main(args):
    print(f"v_padded_batched shape: {v_padded_batched.shape}")

    # Compute torch reference
    k_repeat = repeat_kv(k_padded_batched, q_heads // kv_heads)  # [b, q_heads, max_seqlen, head_size]
    v_repeat = repeat_kv(v_padded_batched, q_heads // kv_heads)  # [b, q_heads, max_seqlen, head_size]

    if sink is None:
        # Standard attention computation: [b, q_heads, 1, head_size] @ [b, q_heads, head_size, max_seqlen]
        attn_score = torch.matmul(q_expanded, k_repeat.transpose(-2, -1)) * softmax_scale  # [b, q_heads, 1, max_seqlen]

        # Apply sequence length masking
        for i in range(batch_size):
            actual_k_len = k_seqlens[i]
-           attn_score[i, :, :, actual_k_len:] = float('-inf')
+           attn_score[i, :, :, actual_k_len:] = float("-inf")

        attn_weights = attn_score.softmax(dim=-1)  # [b, q_heads, 1, max_seqlen]
...
@@ -506,13 +478,12 @@ def test_varlen_decode_main(args):
        O_torch = torch.matmul(attn_weights, v_repeat)  # [b, q_heads, 1, head_size]
    else:
        # s_aux attention
        logits = torch.matmul(q_expanded, k_repeat.transpose(-2, -1)) * softmax_scale  # [b, q_heads, 1, max_seqlen]

        # Apply sequence length masking
        for i in range(batch_size):
            actual_k_len = k_seqlens[i]
-           logits[i, :, :, actual_k_len:] = float('-inf')
+           logits[i, :, :, actual_k_len:] = float("-inf")

        sink_expanded = sink.view(1, q_heads, 1, 1)  # [1, q_heads, 1, 1]
        logits_max = torch.max(logits, dim=-1, keepdim=True).values
...
@@ -528,8 +499,7 @@ def test_varlen_decode_main(args):
            attn_weights[i, :, :, actual_k_len:] = 0.0

        # Compute output: [b, q_heads, 1, max_seqlen] @ [b, q_heads, max_seqlen, head_size]
        O_torch = torch.matmul(attn_weights.to(v_repeat.dtype), v_repeat)  # [b, q_heads, 1, head_size]

    O_torch = O_torch.squeeze(2)  # [b, q_heads, head_size]
...
@@ -538,7 +508,8 @@ def test_varlen_decode_main(args):
        attn_weights.squeeze(2),  # [b, q_heads, max_seqlen]
        kernel_size=(q_heads, block_size),
        stride=(q_heads, block_size),
-       ceil_mode=True).to(dtype=torch.float16)  # [b, 1, ceil(max_seqlen/block_size)]
+       ceil_mode=True,
+   ).to(dtype=torch.float16)  # [b, 1, ceil(max_seqlen/block_size)]

    print(f"O_triton shape: {O_triton.shape}")
    print(f"O_tilelang shape: {O_tilelang.shape}")
...
@@ -554,22 +525,16 @@ def test_varlen_decode_main(args):
    print(f"Max difference in O_tilelang: {max_diff_o_tl.item()}")
    max_diff_s = torch.max(torch.abs(S_triton - attn_score_pooled))
-   max_diff_s_tl = torch.max(torch.abs(S_tilelang[:, :, :math.ceil(max_seqlen_k / block_size)] - attn_score_pooled))
+   max_diff_s_tl = torch.max(torch.abs(S_tilelang[:, :, : math.ceil(max_seqlen_k / block_size)] - attn_score_pooled))
    print(f"Max difference in S: {max_diff_s.item()}")
    print(f"Max difference in S_tilelang: {max_diff_s_tl.item()}")
    assert torch.allclose(O_triton, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o.item()}"
    assert torch.allclose(S_triton, attn_score_pooled, atol=1e-2, rtol=1e-2), f"Score mismatch: {max_diff_s.item()}"
    assert torch.allclose(O_tilelang, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o_tl.item()}"
-   assert torch.allclose(S_tilelang[:, :, :math.ceil(max_seqlen_k / block_size)], attn_score_pooled, atol=1e-2, rtol=1e-2), f"Score mismatch: {max_diff_s_tl.item()}"
+   assert torch.allclose(S_tilelang[:, :, : math.ceil(max_seqlen_k / block_size)], attn_score_pooled, atol=1e-2, rtol=1e-2), (
+       f"Score mismatch: {max_diff_s_tl.item()}"
+   )

    print("✅ All tests passed!")
...
@@ -605,7 +570,7 @@ def speed_benchmark_decode_comparison(args):
    k_seqlens = torch.full((batch_size,), max_k_seqlen, dtype=int)

    # Generate cumulative sequence lengths for k
-   cu_seqlens_k = torch.zeros(batch_size + 1, device='cuda', dtype=torch.int32)
+   cu_seqlens_k = torch.zeros(batch_size + 1, device="cuda", dtype=torch.int32)
    total_k_tokens = 0
    for i in range(batch_size):
        cu_seqlens_k[i] = total_k_tokens
...
@@ -613,9 +578,9 @@ def speed_benchmark_decode_comparison(args):
    cu_seqlens_k[batch_size] = total_k_tokens

    # Generate tensors
-   q_decode = torch.randn(batch_size, q_heads, head_size, device='cuda', dtype=dtype)
-   k_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device='cuda', dtype=dtype)
-   v_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device='cuda', dtype=dtype)
+   q_decode = torch.randn(batch_size, q_heads, head_size, device="cuda", dtype=dtype)
+   k_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype)
+   v_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype)
    softmax_scale = 1.0 / math.sqrt(head_size)
    max_seqlen_k = int(k_seqlens.max())
...
@@ -623,7 +588,7 @@ def speed_benchmark_decode_comparison(args):
    # Generate sink values if needed
    sink = None
    if args.test_sink:
-       sink = torch.randn(q_heads, device='cuda', dtype=torch.float32) * 0.1  # Small sink values
+       sink = torch.randn(q_heads, device="cuda", dtype=torch.float32) * 0.1  # Small sink values
        print(" Using sink attention with sink values")

    print("Setup complete:")
...
@@ -636,11 +601,9 @@ def speed_benchmark_decode_comparison(args):
    num_tokens, q_h, head_size = q_decode.shape
    batch = cu_seqlens_k.size(0) - 1
    k_h = k_varlen.size(1)
    tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size, args.test_sink, page_block_size)
-   block_table = torch.zeros(batch, math.ceil(real_max_k_seqlen / page_block_size), device='cuda', dtype=torch.int32)
+   block_table = torch.zeros(batch, math.ceil(real_max_k_seqlen / page_block_size), device="cuda", dtype=torch.int32)
    block_cnt = 0
    for i in range(batch):
        cur_seqlen = cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item()
...
@@ -671,36 +634,41 @@ def speed_benchmark_decode_comparison(args):
    # Benchmark
    print("⚡ Benchmarking Triton kernel (100 iterations)...")
-   triton_time = do_bench(flash_attn_with_attn_pool_decode, q_decode, k_varlen, v_varlen, cu_seqlens_k, max_seqlen_k, args.k_seqlen, 1, softmax_scale, sink, block_size)
+   triton_time = do_bench(
+       flash_attn_with_attn_pool_decode,
+       q_decode,
+       k_varlen,
+       v_varlen,
+       cu_seqlens_k,
+       max_seqlen_k,
+       args.k_seqlen,
+       1,
+       softmax_scale,
+       sink,
+       block_size,
+   )
    print(f"Average decode kernel time Triton: {triton_time:.3f} ms")
    print(f"Speedup: {(triton_time / tilelang_time):.3f}")


if __name__ == "__main__":
-   parser = argparse.ArgumentParser(description='Flash Attention Decode with Attention Pooling')
-   parser.add_argument('--batch_size', type=int, default=1, help='Batch size')
-   parser.add_argument('--q_heads', type=int, default=32, help='Number of query heads')
-   parser.add_argument('--kv_heads', type=int, default=8, help='Number of key-value heads')
-   parser.add_argument('--k_seqlen', type=int, default=8192, help='Key sequence length')
-   parser.add_argument('--head_size', type=int, default=128, choices=[64, 128, 256], help='Head dimension')
-   parser.add_argument('--block_size', type=int, default=128, help='Block size for computation')
-   parser.add_argument('--dtype', type=str, default='bfloat16', choices=['float16', 'bfloat16'], help='Data type')
-   parser.add_argument('--test_varlen', action='store_true', help='Test with truly variable sequence lengths')
-   parser.add_argument('--test_sink', action='store_true', help='Test with sink attention mechanism')
-   parser.add_argument('--benchmark', action='store_true', help='Run speed benchmark')
-   parser.add_argument('--num_split', type=int, default=1, choices=[1, 16], help='Number of splits')
-   parser.add_argument('--page_block_size', type=int, default=128, help='Page block size')
+   parser = argparse.ArgumentParser(description="Flash Attention Decode with Attention Pooling")
+   parser.add_argument("--batch_size", type=int, default=1, help="Batch size")
+   parser.add_argument("--q_heads", type=int, default=32, help="Number of query heads")
+   parser.add_argument("--kv_heads", type=int, default=8, help="Number of key-value heads")
+   parser.add_argument("--k_seqlen", type=int, default=8192, help="Key sequence length")
+   parser.add_argument("--head_size", type=int, default=128, choices=[64, 128, 256], help="Head dimension")
+   parser.add_argument("--block_size", type=int, default=128, help="Block size for computation")
+   parser.add_argument("--dtype", type=str, default="bfloat16", choices=["float16", "bfloat16"], help="Data type")
+   parser.add_argument("--test_varlen", action="store_true", help="Test with truly variable sequence lengths")
+   parser.add_argument("--test_sink", action="store_true", help="Test with sink attention mechanism")
+   parser.add_argument("--benchmark", action="store_true", help="Run speed benchmark")
+   parser.add_argument("--num_split", type=int, default=1, choices=[1, 16], help="Number of splits")
+   parser.add_argument("--page_block_size", type=int, default=128, help="Page block size")
    args = parser.parse_args()
    args.test_sink = True
    args.test_varlen = True
-   args.dtype = 'float16'
+   args.dtype = "float16"
    args.num_split = 1
    if args.benchmark:
...
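Note: the paged kernel above resolves a logical KV offset `k * block_N` within a request into a physical row via the block table: page index `(k * block_N) // page_block_size`, scaled back by `page_block_size`, plus the in-page remainder, all offset by `cu_seqlens_k[bid]`. A host-side sketch of the same address computation, using a hypothetical helper name purely to make the indexing explicit:

def physical_kv_row(block_table_row, cur_start_k, logical_offset, page_block_size):
    # block_table_row[p] holds the physical page id assigned to logical page p of this request.
    page = logical_offset // page_block_size
    within_page = logical_offset % page_block_size
    return cur_start_k + block_table_row[page] * page_block_size + within_page

This mirrors the kernel's `k_start` / `v_start` expressions, and the `page_block_size % block_N == 0` assert guarantees a `block_N`-sized copy never crosses a page boundary.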
examples/flash_decoding/example_mha_inference.py  View file @ 29051439
...
@@ -10,7 +10,7 @@ num_split = 4
@tilelang.jit(out_idx=[5])
def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_N):
-   scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
+   scale = (1.0 / dim) ** 0.5 * 1.44269504  # log2(e)
    shape_q = [batch, seqlen_q, heads, dim]
    shape_kv = [batch, seqlen_kv, heads, dim]
    part_shape = [batch, seqlen_q, heads, num_split, dim]
...
@@ -29,14 +29,11 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_
        bid: T.int32,
        sid: T.int32,
    ):
-       T.copy(K[bid, (seqlen_kv // num_split) * sid + k * block_N:(seqlen_kv // num_split) * sid + (k + 1) * block_N, hid, :], K_shared)
+       T.copy(K[bid, (seqlen_kv // num_split) * sid + k * block_N : (seqlen_kv // num_split) * sid + (k + 1) * block_N, hid, :], K_shared)
        # TODO: Handle causal split case
        if is_causal:
            for i, j in T.Parallel(block_M, block_N):
                acc_s[i, j] = T.if_then_else(mid * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype))
        else:
            T.clear(acc_s)
        T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
...
@@ -52,9 +49,7 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_
        bid: T.int32,
        sid: T.int32,
    ):
-       T.copy(V[bid, (seqlen_kv // num_split) * sid + k * block_N:(seqlen_kv // num_split) * sid + (k + 1) * block_N, hid, :], V_shared)
+       T.copy(V[bid, (seqlen_kv // num_split) * sid + k * block_N : (seqlen_kv // num_split) * sid + (k + 1) * block_N, hid, :], V_shared)
        T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)

    @T.macro
...
@@ -105,9 +100,7 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_
        glse: T.Tensor([batch, heads, num_split, seqlen_q], dtype),
        Output_partial: T.Tensor(part_shape, dtype),
    ):
        with T.Kernel(T.ceildiv(seqlen_q, block_M), heads * batch, num_split, threads=128) as (bx, by, bz):
            Q_shared = T.alloc_shared([block_M, dim], dtype)
            K_shared = T.alloc_shared([block_N, dim], dtype)
            V_shared = T.alloc_shared([block_N, dim], dtype)
...
@@ -128,33 +121,30 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_
            # NOTE(wt): tma barrier has some problems with padded dimensions (seq_q here) currently
            # disable relevant tma copy and use SIMT as fallback for now
            T.copy(Q[bid, mid * block_M : (mid + 1) * block_M, hid, :], Q_shared, disable_tma=True)
            T.fill(acc_o, 0)
            T.fill(logsum, 0)
            T.fill(scores_max, -T.infinity(accum_dtype))

            # TODO: Handle causal split case
-           loop_range = (
-               T.min(T.ceildiv(seqlen_kv, block_N), T.ceildiv(
-                   (mid + 1) * block_M, block_N)) if is_causal else T.ceildiv((seqlen_kv // num_split), block_N))
+           loop_range = (
+               T.min(T.ceildiv(seqlen_kv, block_N), T.ceildiv((mid + 1) * block_M, block_N))
+               if is_causal
+               else T.ceildiv((seqlen_kv // num_split), block_N)
+           )

            for k in T.Pipelined(loop_range, num_stages=2):
                MMA0(K, Q_shared, K_shared, acc_s, k, mid, hid, bid, sid)
                Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum)
                Rescale(acc_o, scores_scale)
                MMA1(V, V_shared, acc_s_cast, acc_o, k, hid, bid, sid)
            for i, j in T.Parallel(block_M, dim):
                acc_o[i, j] /= logsum[i]
            for i in T.Parallel(block_M):
                logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
            T.copy(logsum, glse[bid, hid, sid, mid * block_M : (mid + 1) * block_M])
            T.copy(acc_o, O_shared)
            T.copy(O_shared, Output_partial[bid, mid * block_M : (mid + 1) * block_M, hid, sid, :], disable_tma=True)

    @T.macro
    def combine(
...
@@ -173,20 +163,25 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_
            lse_max_local = T.alloc_fragment([block_M], accum_dtype)
            scale_local = T.alloc_fragment([block_M], accum_dtype)

-           T.annotate_layout({
-               o_accum_local: T.Fragment(o_accum_local.shape, forward_thread_fn=lambda i, j: i),
-               o_shared: tilelang.layout.make_swizzled_layout(o_shared),
-               po_shared: tilelang.layout.make_swizzled_layout(po_shared),
-           })
+           T.annotate_layout(
+               {
+                   o_accum_local: T.Fragment(o_accum_local.shape, forward_thread_fn=lambda i, j: i),
+                   o_shared: tilelang.layout.make_swizzled_layout(o_shared),
+                   po_shared: tilelang.layout.make_swizzled_layout(po_shared),
+               }
+           )

            T.clear(lse_logsum_local)
            T.clear(o_accum_local)
-           T.copy(glse[
-               bz,
-               by,
-               :,
-               bx * block_M:(bx + 1) * block_M,
-           ], lse_local)
+           T.copy(
+               glse[
+                   bz,
+                   by,
+                   :,
+                   bx * block_M : (bx + 1) * block_M,
+               ],
+               lse_local,
+           )
            T.reduce_max(lse_local, lse_max_local, dim=0, clear=False)
            for k in T.Pipelined(num_split):
                T.copy(lse_local[k, :], lse_local_split)
...
@@ -195,10 +190,7 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_
            for i in T.Parallel(block_M):
                lse_logsum_local[i] = T.log2(lse_logsum_local[i]) + lse_max_local[i]
            for k in T.Pipelined(num_split, num_stages=2):
-               T.copy(Output_partial[bz, bx * block_M:(bx + 1) * block_M, by, k, :], po_shared, disable_tma=True)
+               T.copy(Output_partial[bz, bx * block_M : (bx + 1) * block_M, by, k, :], po_shared, disable_tma=True)
                T.copy(po_shared, po_local)
                for i in T.Parallel(block_M):
                    lse_local_split[i] = lse_local[k, i]
...
@@ -207,7 +199,7 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_
                for i, j in T.Parallel(block_M, dim):
                    o_accum_local[i, j] += po_local[i, j] * scale_local[i]
            T.copy(o_accum_local, o_shared)
            T.copy(o_shared, Output[bz, bx * block_M : (bx + 1) * block_M, by, :], disable_tma=True)

    @T.prim_func
    def flashattn_mha_inference(
...
@@ -227,10 +219,10 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_
def ref_program(Q, K, V, glse, Output_partial, causal):
    assert causal is False
    dim = Q.size(-1)
-   scores = torch.einsum('bqhd,bkhd->bhqk', Q, K)
+   scores = torch.einsum("bqhd,bkhd->bhqk", Q, K)
    scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
    attention_weights = F.softmax(scores, dim=-1)
-   output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V)
+   output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V)
    return output
...
@@ -258,7 +250,7 @@ def flash_split_ref(Q, K, V, causal):
    block_N = 128
    seqlen_kv = K.size(1)
-   scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
+   scale = (1.0 / dim) ** 0.5 * 1.44269504  # log2(e)
    acc_s = torch.empty((batch, nheads, block_M, block_N), device="cuda", dtype=torch.float)
    acc_s_cast = torch.empty((batch, nheads, block_M, block_N), device="cuda", dtype=torch.float16)
    acc_o = torch.empty((batch, block_M, nheads, dim), device="cuda", dtype=torch.float)
...
@@ -275,14 +267,15 @@ def flash_split_ref(Q, K, V, causal):
    for ks in range(num_split):
        acc_o.fill_(0)
        logsum.fill_(0)
-       scores_max.fill_(float('-inf'))
-       scores_max_prev.fill_(float('-inf'))
+       scores_max.fill_(float("-inf"))
+       scores_max_prev.fill_(float("-inf"))
        for i in range(int((seqlen_kv // num_split) / block_N)):
            acc_s.fill_(0)
-           acc_s = torch.einsum('bqhd,bkhd->bhqk', Q_, K[:, (seqlen_kv // num_split) * ks + i * block_N:(seqlen_kv // num_split) * ks + (i + 1) * block_N, :, :])  # [batch, seqlen, nheads, block_N]
+           acc_s = torch.einsum(
+               "bqhd,bkhd->bhqk",
+               Q_,
+               K[:, (seqlen_kv // num_split) * ks + i * block_N : (seqlen_kv // num_split) * ks + (i + 1) * block_N, :, :],
+           )  # [batch, seqlen, nheads, block_N]
            scores_max_prev = scores_max
            scores_max = acc_s.max(dim=-1, keepdim=False).values  # [blockM]
            scores_scale = torch.exp2(scores_max_prev - scores_max)
...
@@ -290,9 +283,10 @@ def flash_split_ref(Q, K, V, causal):
            acc_s = torch.exp2(acc_s - scores_max[:, :, :, None])
            acc_s_cast = acc_s.to(torch.float16)
-           acc_o += torch.einsum('bhqk,bkhd->bqhd', acc_s_cast, V[:, (seqlen_kv // num_split) * ks + i * block_N:(seqlen_kv // num_split) * ks + (i + 1) * block_N, :, :])
+           acc_o += torch.einsum(
+               "bhqk,bkhd->bqhd",
+               acc_s_cast,
+               V[:, (seqlen_kv // num_split) * ks + i * block_N : (seqlen_kv // num_split) * ks + (i + 1) * block_N, :, :],
+           )
            scores_sum = acc_s.sum(dim=-1, keepdim=False)
            logsum = logsum * scores_scale + scores_sum
        acc_o /= logsum[:, :, :, None].transpose(1, 2)
...
@@ -300,8 +294,7 @@ def flash_split_ref(Q, K, V, causal):
        gacc_o[ks, :, :, :, :] = acc_o
        glogsum[ks, :, :, :] = logsum

    return glogsum.to(torch.float16).permute(1, 2, 0, 3), gacc_o.to(torch.float16).permute(1, 2, 3, 0, 4)


def main(BATCH=1, H=32, Q_CTX=128, KV_CTX=8192, D_HEAD=128, causal=False):
...
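Note: the `combine` stage above merges the `num_split` partial results using their per-split log-sum-exp values, which are kept in base 2 to match the log2(e)-scaled softmax. A compact reference for that reduction, assuming `glse` has shape [batch, heads, num_split, seqlen_q] and `Output_partial` has shape [batch, seqlen_q, heads, num_split, dim] as declared in the kernel (a sketch of the math, not the kernel's code path):

import torch

def combine_splits(glse: torch.Tensor, o_partial: torch.Tensor) -> torch.Tensor:
    lse_max = glse.amax(dim=2, keepdim=True)           # [b, h, 1, q]
    w = torch.exp2(glse - lse_max)                     # per-split weights, [b, h, s, q]
    w = w / w.sum(dim=2, keepdim=True)                 # normalize across splits
    w = w.permute(0, 3, 1, 2).unsqueeze(-1)            # [b, q, h, s, 1]
    return (o_partial * w).sum(dim=3)                  # [b, q, h, d]

Subtracting the per-row maximum before `exp2` is what keeps the reduction numerically stable, exactly as the kernel does with `lse_max_local`.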
examples/fusedmoe/example_fusedmoe_tilelang.py  View file @ 29051439
...
@@ -9,7 +9,8 @@ from example_fusedmoe_torch import *
@tilelang.jit(pass_configs={"tl.disable_tma_lower": True, "tl.disable_warp_specialized": True})
-def moe_forward_tilelang_shared(d_hidden,
+def moe_forward_tilelang_shared(
+   d_hidden,
    d_expert,
    n_shared_experts,
    dtype,
...
@@ -18,8 +19,8 @@ def moe_forward_tilelang_shared(d_hidden,
    block_dhidden=128,
    block_dexpert=128,
    threads=256,
-   num_stages=1):
+   num_stages=1,
+):
    scale = 1.44269504  # log2(e)
    # Parameters
...
@@ -44,9 +45,7 @@ def moe_forward_tilelang_shared(d_hidden,
        output: T.Tensor(input_shape, dtype),  # type: ignore
    ):
        # Step 1: Compute gate and up logits
        with T.Kernel(T.ceildiv(num_tokens, block_token), T.ceildiv(dexpert, block_dexpert), threads=threads) as (bx, by):
            # Split the block to shared experts and routed experts
            input_shared = T.alloc_fragment((block_token, block_dhidden), dtype=dtype)
            W_gate_shared = T.alloc_shared((block_dexpert, block_dhidden), dtype=dtype)
...
@@ -70,16 +69,13 @@ def moe_forward_tilelang_shared(d_hidden,
            # Fuse with SiLU and element-wise product
            for i, j in T.Parallel(block_token, block_dexpert):
                gate_logits_local[i, j] = gate_logits_local[i, j] * (1.0 / (1.0 + T.exp2(-gate_logits_local[i, j] * scale)))
                up_logits_local[i, j] = up_logits_local[i, j] * gate_logits_local[i, j]
            T.copy(up_logits_local, up_logits[bx * block_token, by * block_dexpert])

        # Step 2: Compute down logits
        with T.Kernel(T.ceildiv(num_tokens, block_token), T.ceildiv(dhidden, block_dhidden), threads=threads) as (bx, by):
            up_logits_shared = T.alloc_fragment((block_token, block_dexpert), dtype=dtype)
            W_down_shared = T.alloc_shared((block_dhidden, block_dexpert), dtype=dtype)
            output_local = T.alloc_fragment((block_token, block_dhidden), dtype=accum_type)
...
@@ -98,7 +94,8 @@ def moe_forward_tilelang_shared(d_hidden,
@tilelang.jit(pass_configs={"tl.disable_tma_lower": True, "tl.disable_warp_specialized": True})
-def moe_forward_tilelang_routed(d_hidden,
+def moe_forward_tilelang_routed(
+   d_hidden,
    d_expert,
    n_routed_experts,
    dtype,
...
@@ -110,8 +107,8 @@ def moe_forward_tilelang_routed(d_hidden,
    threads=256,
    num_stages=1,
    k_pack=1,
-   coalesced_width=None):
+   coalesced_width=None,
+):
    scale = 1.44269504  # log2(e)
    # Parameters
...
@@ -132,8 +129,8 @@ def moe_forward_tilelang_routed(d_hidden,
    routed_expert_gate_shape = (n_routed_experts, dexpert, dhidden)
    routed_expert_up_shape = (n_routed_experts, dexpert, dhidden)
    routed_expert_down_shape = (n_routed_experts, dhidden, dexpert)
-   routed_expert_weights_shape = (group_sum)
-   group_sizes_shape = (n_routed_experts)
+   routed_expert_weights_shape = group_sum
+   group_sizes_shape = n_routed_experts

    @T.prim_func
    def kernel(
...
@@ -168,48 +165,37 @@ def moe_forward_tilelang_routed(d_hidden,
            cur_group_idx[0] = group_idx_for_bx[bx]
            cur_group_size[0] = group_sizes[cur_group_idx[0]]
            m_start = m_start_padded - group_padded_offsets[cur_group_idx[0]] + group_offsets[cur_group_idx[0]]
            actual_rows = T.max(0, T.min(block_token, cur_group_size[0] - (m_start_padded - group_padded_offsets[cur_group_idx[0]])))

            T.clear(gate_logits_local)
            T.clear(up_logits_local)
            for k in T.Pipelined(T.ceildiv(dhidden, block_dhidden), num_stages=num_stages):
-               T.copy(input[m_start:m_start + block_token, k * block_dhidden:(k + 1) * block_dhidden], input_shared, coalesced_width=coalesced_width)
-               T.copy(routed_expert_gate[cur_group_idx[0], by * block_dexpert:(by + 1) * block_dexpert, k * block_dhidden:(k + 1) * block_dhidden], routed_expert_gate_shared, coalesced_width=coalesced_width)
-               T.gemm(input_shared, routed_expert_gate_shared, gate_logits_local, k_pack=k_pack, transpose_B=True)
-               T.copy(routed_expert_up[cur_group_idx[0], by * block_dexpert:(by + 1) * block_dexpert, k * block_dhidden:(k + 1) * block_dhidden], routed_expert_up_shared, coalesced_width=coalesced_width)
-               T.gemm(input_shared, routed_expert_up_shared, up_logits_local, k_pack=k_pack, transpose_B=True)
+               T.copy(
+                   input[m_start : m_start + block_token, k * block_dhidden : (k + 1) * block_dhidden],
+                   input_shared,
+                   coalesced_width=coalesced_width,
+               )
+               T.copy(
+                   routed_expert_gate[cur_group_idx[0], by * block_dexpert : (by + 1) * block_dexpert, k * block_dhidden : (k + 1) * block_dhidden],
+                   routed_expert_gate_shared,
+                   coalesced_width=coalesced_width,
+               )
+               T.gemm(input_shared, routed_expert_gate_shared, gate_logits_local, k_pack=k_pack, transpose_B=True)
+               T.copy(
+                   routed_expert_up[cur_group_idx[0], by * block_dexpert : (by + 1) * block_dexpert, k * block_dhidden : (k + 1) * block_dhidden],
+                   routed_expert_up_shared,
+                   coalesced_width=coalesced_width,
+               )
+               T.gemm(input_shared, routed_expert_up_shared, up_logits_local, k_pack=k_pack, transpose_B=True)

            for i, j in T.Parallel(block_token, block_dexpert):
                gate_logits_local[i, j] = gate_logits_local[i, j] * (1.0 / (1.0 + T.exp2(-gate_logits_local[i, j] * scale)))
                up_logits_local[i, j] = up_logits_local[i, j] * gate_logits_local[i, j]

            for i, j in T.Parallel(block_token, block_dexpert):
...
@@ -232,50 +218,35 @@ def moe_forward_tilelang_routed(d_hidden,
            cur_group_idx[0] = group_idx_for_bx[bx]
            cur_group_size[0] = group_sizes[cur_group_idx[0]]
            m_start = m_start_padded - group_padded_offsets[cur_group_idx[0]] + group_offsets[cur_group_idx[0]]
            actual_rows = T.max(0, T.min(block_token, cur_group_size[0] - (m_start_padded - group_padded_offsets[cur_group_idx[0]])))

            T.clear(output_local)
            for k in T.Pipelined(T.ceildiv(dexpert, block_dexpert), num_stages=num_stages):
-               T.copy(up_logits[m_start:m_start + block_token, k * block_dexpert:(k + 1) * block_dexpert], up_logits_shared, coalesced_width=coalesced_width)
-               T.copy(routed_expert_down[cur_group_idx[0], by * block_dhidden:(by + 1) * block_dhidden, k * block_dexpert:(k + 1) * block_dexpert], routed_expert_down_shared, coalesced_width=coalesced_width)
-               T.gemm(up_logits_shared, routed_expert_down_shared, output_local, k_pack=k_pack, transpose_B=True)
+               T.copy(
+                   up_logits[m_start : m_start + block_token, k * block_dexpert : (k + 1) * block_dexpert],
+                   up_logits_shared,
+                   coalesced_width=coalesced_width,
+               )
+               T.copy(
+                   routed_expert_down[cur_group_idx[0], by * block_dhidden : (by + 1) * block_dhidden, k * block_dexpert : (k + 1) * block_dexpert],
+                   routed_expert_down_shared,
+                   coalesced_width=coalesced_width,
+               )
+               T.gemm(up_logits_shared, routed_expert_down_shared, output_local, k_pack=k_pack, transpose_B=True)

            for i, j in T.Parallel(block_token, block_dhidden):
                if i < actual_rows:
                    output[m_start + i, by * block_dhidden + j] = output_local[i, j] * routed_expert_weights[m_start + i]

    return kernel
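Note: both kernels above fuse SiLU into the gate-GEMM epilogue through `exp2` with `scale = 1.44269504` (log2 e), using silu(x) = x * sigmoid(x) = x / (1 + 2^(-x * log2(e))). A quick numerical check of that identity in plain PyTorch:

import torch

x = torch.randn(1 << 12, dtype=torch.float32)
log2_e = 1.44269504
silu_exp2 = x * (1.0 / (1.0 + torch.exp2(-x * log2_e)))  # the kernel's formulation
torch.testing.assert_close(silu_exp2, torch.nn.functional.silu(x), atol=1e-6, rtol=1e-6)

Using `exp2` instead of `exp` matches the hardware-friendly base-2 exponential that the rest of the kernel relies on.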
class
Expert
(
nn
.
Module
):
def
__init__
(
self
,
config
:
Dict
,
gate
:
torch
.
Tensor
,
up
:
torch
.
Tensor
,
down
:
torch
.
Tensor
,
d_expert
:
Optional
[
int
]
=
None
):
def
__init__
(
self
,
config
:
Dict
,
gate
:
torch
.
Tensor
,
up
:
torch
.
Tensor
,
down
:
torch
.
Tensor
,
d_expert
:
Optional
[
int
]
=
None
):
super
().
__init__
()
self
.
config
=
config
self
.
act_fn
=
nn
.
SiLU
()
...
...
@@ -294,14 +265,13 @@ class Expert(nn.Module):
class
MoEGate
(
nn
.
Module
):
def
__init__
(
self
,
config
:
Dict
,
weights
:
Dict
):
super
().
__init__
()
self
.
top_k
:
int
=
config
[
"n_experts_per_token"
]
self
.
num_experts
:
int
=
config
[
"n_routed_experts"
]
self
.
d_hidden
:
int
=
config
[
"d_hidden"
]
self
.
W_g_weight
=
weights
[
'
router.weight
'
].
t
()
self
.
W_g_weight
=
weights
[
"
router.weight
"
].
t
()
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
logits
=
x
@
self
.
W_g_weight
...
...
@@ -312,76 +282,69 @@ class MoEGate(nn.Module):
class
MoE
(
nn
.
Module
):
def
__init__
(
self
,
config
:
Dict
,
shared_kernel
:
tilelang
.
JITKernel
,
routed_kernel
:
tilelang
.
JITKernel
,
weights
:
Dict
,
padding_M
:
int
=
128
):
def
__init__
(
self
,
config
:
Dict
,
shared_kernel
:
tilelang
.
JITKernel
,
routed_kernel
:
tilelang
.
JITKernel
,
weights
:
Dict
,
padding_M
:
int
=
128
):
super
().
__init__
()
self
.
config
=
config
self
.
shared_kernel
=
shared_kernel
self
.
routed_kernel
=
routed_kernel
self
.
padding_M
=
padding_M
self
.
experts
=
nn
.
ModuleList
([
self
.
experts
=
nn
.
ModuleList
(
[
Expert
(
config
,
gate
=
weights
[
f
'experts.
{
i
}
.0.weight'
],
up
=
weights
[
f
'experts.
{
i
}
.1.weight'
],
down
=
weights
[
f
'experts.
{
i
}
.2.weight'
])
for
i
in
range
(
config
[
"n_routed_experts"
])
])
gate
=
weights
[
f
"experts.
{
i
}
.0.weight"
],
up
=
weights
[
f
"experts.
{
i
}
.1.weight"
],
down
=
weights
[
f
"experts.
{
i
}
.2.weight"
],
)
for
i
in
range
(
config
[
"n_routed_experts"
])
]
)
self
.
device
=
torch
.
device
(
"cuda"
)
self
.
gating_network
=
MoEGate
(
config
,
weights
).
to
(
self
.
device
)
shared_expert_dim
=
config
[
"d_expert"
]
*
config
[
"n_shared_experts"
]
self
.
shared_expert
=
Expert
(
config
=
config
,
gate
=
weights
[
'shared_experts.0.weight'
],
up
=
weights
[
'shared_experts.1.weight'
],
down
=
weights
[
'shared_experts.2.weight'
],
d_expert
=
shared_expert_dim
).
to
(
self
.
device
)
gate
=
weights
[
"shared_experts.0.weight"
],
up
=
weights
[
"shared_experts.1.weight"
],
down
=
weights
[
"shared_experts.2.weight"
],
d_expert
=
shared_expert_dim
,
).
to
(
self
.
device
)
self
.
expert_cache
=
torch
.
zeros
(
(
config
[
"batch_size"
]
*
config
[
"seq_len"
],
config
[
"d_hidden"
]),
dtype
=
torch
.
float16
,
device
=
self
.
device
)
self
.
stacked_expert_w_gate
=
torch
.
stack
([
expert
.
W_gate_weight
for
expert
in
self
.
experts
],
dim
=
0
)
self
.
stacked_expert_w_up
=
torch
.
stack
([
expert
.
W_up_weight
for
expert
in
self
.
experts
],
dim
=
0
)
self
.
stacked_expert_w_down
=
torch
.
stack
([
expert
.
W_down_weight
for
expert
in
self
.
experts
],
dim
=
0
)
(
config
[
"batch_size"
]
*
config
[
"seq_len"
],
config
[
"d_hidden"
]),
dtype
=
torch
.
float16
,
device
=
self
.
device
)
self
.
stacked_expert_w_gate
=
torch
.
stack
([
expert
.
W_gate_weight
for
expert
in
self
.
experts
],
dim
=
0
)
self
.
stacked_expert_w_up
=
torch
.
stack
([
expert
.
W_up_weight
for
expert
in
self
.
experts
],
dim
=
0
)
self
.
stacked_expert_w_down
=
torch
.
stack
([
expert
.
W_down_weight
for
expert
in
self
.
experts
],
dim
=
0
)
self
.
stacked_expert_tokens
=
torch
.
empty
(
(
config
[
"batch_size"
]
*
config
[
"seq_len"
]
*
config
[
"n_experts_per_token"
],
self
.
config
[
"d_hidden"
]),
(
config
[
"batch_size"
]
*
config
[
"seq_len"
]
*
config
[
"n_experts_per_token"
],
self
.
config
[
"d_hidden"
]),
dtype
=
torch
.
float16
,
device
=
self
.
device
)
device
=
self
.
device
,
)
self
.
stacked_expert_weights
=
torch
.
empty
(
(
config
[
"batch_size"
]
*
config
[
"seq_len"
]
*
config
[
"n_experts_per_token"
]),
dtype
=
torch
.
float16
,
device
=
self
.
device
)
(
config
[
"batch_size"
]
*
config
[
"seq_len"
]
*
config
[
"n_experts_per_token"
]),
dtype
=
torch
.
float16
,
device
=
self
.
device
)
self
.
stacked_expert_tokens_idxs
=
torch
.
empty
(
(
config
[
"batch_size"
]
*
config
[
"seq_len"
]
*
config
[
"n_experts_per_token"
]),
dtype
=
torch
.
int64
,
device
=
self
.
device
)
(
config
[
"batch_size"
]
*
config
[
"seq_len"
]
*
config
[
"n_experts_per_token"
]),
dtype
=
torch
.
int64
,
device
=
self
.
device
)
self
.
up_logits_shared
=
torch
.
empty
(
(
config
[
"batch_size"
]
*
config
[
"seq_len"
],
self
.
config
[
"d_expert"
]),
dtype
=
torch
.
float16
,
device
=
self
.
device
)
(
config
[
"batch_size"
]
*
config
[
"seq_len"
],
self
.
config
[
"d_expert"
]),
dtype
=
torch
.
float16
,
device
=
self
.
device
)
self
.
expert_output_shared
=
torch
.
empty
(
(
config
[
"batch_size"
]
*
config
[
"seq_len"
],
self
.
config
[
"d_hidden"
]),
dtype
=
torch
.
float16
,
device
=
self
.
device
)
(
config
[
"batch_size"
]
*
config
[
"seq_len"
],
self
.
config
[
"d_hidden"
]),
dtype
=
torch
.
float16
,
device
=
self
.
device
)
self
.
up_logits_routed
=
torch
.
empty
(
(
config
[
"batch_size"
]
*
config
[
"seq_len"
]
*
config
[
"n_experts_per_token"
],
self
.
config
[
"d_expert"
]),
(
config
[
"batch_size"
]
*
config
[
"seq_len"
]
*
config
[
"n_experts_per_token"
],
self
.
config
[
"d_expert"
]),
dtype
=
torch
.
float16
,
device
=
self
.
device
)
device
=
self
.
device
,
)
self
.
expert_output_routed
=
torch
.
empty
(
(
config
[
"batch_size"
]
*
config
[
"seq_len"
]
*
config
[
"n_experts_per_token"
],
self
.
config
[
"d_hidden"
]),
(
config
[
"batch_size"
]
*
config
[
"seq_len"
]
*
config
[
"n_experts_per_token"
],
self
.
config
[
"d_hidden"
]),
dtype
=
torch
.
float16
,
device
=
self
.
device
)
device
=
self
.
device
,
)
@
torch
.
no_grad
()
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
...
...
@@ -413,22 +376,20 @@ class MoE(nn.Module):
self
.
stacked_expert_tokens
[
start_idx
:
end_idx
]
=
expert_tokens
self
.
stacked_expert_tokens_idxs
[
start_idx
:
end_idx
]
=
exp_token_idxs
self
.
stacked_expert_weights
[
start_idx
:
end_idx
]
=
flat_expert_weights
[
idxs
[
start_idx
:
end_idx
]]
self
.
stacked_expert_weights
[
start_idx
:
end_idx
]
=
flat_expert_weights
[
idxs
[
start_idx
:
end_idx
]]
group_sizes
=
torch
.
tensor
(
counts
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
group_offset
=
torch
.
tensor
(
tokens_per_expert
-
counts
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
group_offset
=
torch
.
tensor
(
tokens_per_expert
-
counts
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
group_padded_offsets
=
[
0
for
_
in
range
(
len
(
group_sizes
))]
for
i
in
range
(
1
,
len
(
group_sizes
)):
group_padded_offsets
[
i
]
=
group_padded_offsets
[
i
-
1
]
+
math
.
ceil
(
(
counts
[
i
-
1
]
+
1
)
/
self
.
padding_M
)
*
self
.
padding_M
group_padded_offsets
[
i
]
=
group_padded_offsets
[
i
-
1
]
+
math
.
ceil
((
counts
[
i
-
1
]
+
1
)
/
self
.
padding_M
)
*
self
.
padding_M
block_token
=
128
M
=
math
.
ceil
(
self
.
config["batch_size"] * self.config["seq_len"] * self.config["n_experts_per_token"] / block_token) + self.config["n_routed_experts"]
        M = (
            math.ceil(self.config["batch_size"] * self.config["seq_len"] * self.config["n_experts_per_token"] / block_token)
            + self.config["n_routed_experts"]
        )
        group_idx_for_bx = [0 for _ in range(M)]
        for bx in range(M):
...
...
@@ -437,8 +398,7 @@ class MoE(nn.Module):
            if m_start_padded >= group_padded_offsets[i]:
                group_idx_for_bx[bx] = i
        group_padded_offsets = torch.tensor(group_padded_offsets, dtype=torch.int32, device=self.device)
        group_idx_for_bx = torch.tensor(group_idx_for_bx, dtype=torch.int32, device=self.device)
        # Multi-stream execution
...
...
@@ -448,11 +408,19 @@ class MoE(nn.Module):
        with torch.cuda.stream(routed_stream):
            # Tilelang version: Grouped GEMM
            self.routed_kernel(self.stacked_expert_tokens, self.stacked_expert_w_gate, self.stacked_expert_w_up, self.stacked_expert_w_down, self.stacked_expert_weights, group_sizes, group_offset, group_padded_offsets, group_idx_for_bx, self.up_logits_routed, self.expert_output_routed)
            self.routed_kernel(
                self.stacked_expert_tokens,
                self.stacked_expert_w_gate,
                self.stacked_expert_w_up,
                self.stacked_expert_w_down,
                self.stacked_expert_weights,
                group_sizes,
                group_offset,
                group_padded_offsets,
                group_idx_for_bx,
                self.up_logits_routed,
                self.expert_output_routed,
            )
        # Scatter reduce
        self.expert_cache = torch.scatter_reduce(
...
...
@@ -460,14 +428,19 @@ class MoE(nn.Module):
            0,
            self.stacked_expert_tokens_idxs.view(-1, 1).repeat(1, x_flat.shape[-1]),
            self.expert_output_routed,
            reduce='sum')
            reduce="sum",
        )
        routed_output = self.expert_cache.view(*orig_shape)
        with torch.cuda.stream(shared_stream):
            self.shared_kernel(x_flat, self.shared_expert.W_gate_weight, self.shared_expert.W_up_weight, self.shared_expert.W_down_weight, self.up_logits_shared, self.expert_output_shared)
            self.shared_kernel(
                x_flat,
                self.shared_expert.W_gate_weight,
                self.shared_expert.W_up_weight,
                self.shared_expert.W_down_weight,
                self.up_logits_shared,
                self.expert_output_shared,
            )
        shared_output = self.expert_output_shared.view(*orig_shape)
        torch.cuda.synchronize()
...
...
@@ -498,7 +471,8 @@ def custom_kernel(data: Tuple[torch.Tensor, Dict, Dict]) -> torch.Tensor:
        config["d_expert"],
        config["n_shared_experts"],
        dtype=dtype_str,
        num_tokens=config["batch_size"] * config["seq_len"])
        num_tokens=config["batch_size"] * config["seq_len"],
    )
    routed_kernel = moe_forward_tilelang_routed(
        config["d_hidden"],
        config["d_expert"],
...
...
@@ -512,7 +486,8 @@ def custom_kernel(data: Tuple[torch.Tensor, Dict, Dict]) -> torch.Tensor:
        threads=256,
        num_stages=1,
        k_pack=1,
        coalesced_width=2)
        coalesced_width=2,
    )
    moe = MoE(config, shared_kernel, routed_kernel, weights, padding_M=128)
...
...
@@ -521,13 +496,7 @@ def custom_kernel(data: Tuple[torch.Tensor, Dict, Dict]) -> torch.Tensor:
    return output


def main(d_hidden=7168, d_expert=2048, n_routed_experts=8, n_shared_experts=1, n_experts_per_token=4, batch_size=1, seq_len=8192):
    config = {
        "dhidden": d_hidden,
        "dexpert": d_expert,
...
...
@@ -536,7 +505,7 @@ def main(d_hidden=7168,
        "nexpertspertoken": n_experts_per_token,
        "bs": batch_size,
        "seqlen": seq_len,
        "seed": 81394
        "seed": 81394,
    }
    data = generate_input(**config)
...
...
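Most hunks in this commit reduce to two mechanical patterns: string literals switch from single to double quotes, and call sites that carry a trailing comma are expanded to one argument per line with the closing parenthesis on its own line. As a rough illustration only (assuming ruff format's Black-compatible "magic trailing comma" behaviour; `launch_kernel` below is a made-up stand-in, not code from this repository):

```python
def launch_kernel(**kwargs):
    # stub so the snippet runs; real kernels obviously do more than this
    return kwargs


# Without a trailing comma the formatter may pack the arguments onto one line:
launch_kernel(batch_size=1, seq_len=8192, threads=256, num_stages=1)

# With a trailing comma it keeps one argument per line and dedents the paren:
launch_kernel(
    batch_size=1,
    seq_len=8192,
    threads=256,
    num_stages=1,
)
```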
examples/fusedmoe/example_fusedmoe_torch.py View file @ 29051439
...
...
@@ -6,7 +6,6 @@ from typing import Dict, Tuple, Optional
# Reference code in PyTorch
class ExpertTorch(nn.Module):
    def __init__(self, config: Dict, d_expert: Optional[int] = None):
        super().__init__()
        self.config = config
...
...
@@ -25,7 +24,6 @@ class ExpertTorch(nn.Module):
class MoEGateTorch(nn.Module):
    def __init__(self, config: Dict):
        super().__init__()
        self.top_k: int = config["n_experts_per_token"]
...
...
@@ -43,12 +41,10 @@ class MoEGateTorch(nn.Module):
class MoETorch(nn.Module):
    def __init__(self, config: Dict):
        super().__init__()
        self.config = config
        self.experts = nn.ModuleList([ExpertTorch(config) for _ in range(config["n_routed_experts"])])
        self.gating_network = MoEGateTorch(config)
        shared_expert_dim = config["d_expert"] * config["n_shared_experts"]
        self.shared_expert = ExpertTorch(config=config, d_expert=shared_expert_dim)
...
...
@@ -67,8 +63,7 @@ class MoETorch(nn.Module):
        return routed_output + shared_output

    @torch.no_grad()
    def moe_infer(self, x: torch.Tensor, flat_expert_indices: torch.Tensor, flat_expert_weights: torch.Tensor) -> torch.Tensor:
        expert_cache = torch.zeros_like(x)
        # test_expert_cache = torch.zeros((x.shape[0] * self.config["n_experts_per_token"], self.config["d_hidden"]))
        # test_expert_tokens = torch.zeros((x.shape[0] * self.config["n_experts_per_token"], self.config["d_hidden"]))
...
...
@@ -91,8 +86,7 @@ class MoETorch(nn.Module):
            expert_out = expert(expert_tokens)
            expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
            expert_cache.scatter_reduce_(0, exp_token_idxs.view(-1, 1).repeat(1, x.shape[-1]), expert_out, reduce='sum')
            expert_cache.scatter_reduce_(0, exp_token_idxs.view(-1, 1).repeat(1, x.shape[-1]), expert_out, reduce="sum")
        return expert_cache
...
...
@@ -116,21 +110,21 @@ def ref_kernel(data: Tuple[torch.Tensor, Dict, Dict]) -> torch.Tensor:
    moe = MoETorch(config)
    # Fill in the given weights of the model
    moe.gating_network.W_g.weight = nn.Parameter(weights['router.weight'])
    moe.gating_network.W_g.weight = nn.Parameter(weights["router.weight"])
    for i in range(num_experts):
        gate_proj_weight = weights[f'experts.{i}.0.weight']
        up_proj_weight = weights[f'experts.{i}.1.weight']
        down_proj_weight = weights[f'experts.{i}.2.weight']
        gate_proj_weight = weights[f"experts.{i}.0.weight"]
        up_proj_weight = weights[f"experts.{i}.1.weight"]
        down_proj_weight = weights[f"experts.{i}.2.weight"]
        # Transpose weights to match expected shape for nn.Linear
        moe.experts[i].W_gate.weight = nn.Parameter(gate_proj_weight.t())
        moe.experts[i].W_up.weight = nn.Parameter(up_proj_weight.t())
        moe.experts[i].W_down.weight = nn.Parameter(down_proj_weight.t())
    moe.shared_expert.W_gate.weight = nn.Parameter(weights['shared_experts.0.weight'].t())
    moe.shared_expert.W_up.weight = nn.Parameter(weights['shared_experts.1.weight'].t())
    moe.shared_expert.W_down.weight = nn.Parameter(weights['shared_experts.2.weight'].t())
    moe.shared_expert.W_gate.weight = nn.Parameter(weights["shared_experts.0.weight"].t())
    moe.shared_expert.W_up.weight = nn.Parameter(weights["shared_experts.1.weight"].t())
    moe.shared_expert.W_down.weight = nn.Parameter(weights["shared_experts.2.weight"].t())
    output = moe(input_tensor)
...
...
@@ -140,10 +134,9 @@ def ref_kernel(data: Tuple[torch.Tensor, Dict, Dict]) -> torch.Tensor:
# Input generation for the reference code
def generate_input(dhidden: int, dexpert: int, nroutedexperts: int, nsharedexperts: int, nexpertspertoken: int, bs: int, seqlen: int, seed: int) -> Tuple[torch.Tensor, Dict, Dict]:
    # Really dumb but for now _ isn't parsing correctly.
    d_hidden = dhidden
    d_expert = dexpert
...
...
@@ -163,50 +156,40 @@ def generate_input(dhidden: int, dexpert: int, nroutedexperts: int, nsharedexper
        "seq_len": seq_len,
    }
    gen = torch.Generator(device='cuda')
    gen = torch.Generator(device="cuda")
    gen.manual_seed(seed)
    num_experts = n_routed_experts
    expert_dim = d_expert
    weights = {}
    input_tensor = torch.randn((batch_size, seq_len, d_hidden), device='cuda', dtype=torch.float16, generator=gen).contiguous()
    input_tensor = torch.randn((batch_size, seq_len, d_hidden), device="cuda", dtype=torch.float16, generator=gen).contiguous()
    # Initialize router weights
    weights['router.weight'] = torch.randn((num_experts, d_hidden), device="cuda", dtype=torch.float16, generator=gen) / math.sqrt(d_hidden)
    weights["router.weight"] = torch.randn((num_experts, d_hidden), device="cuda", dtype=torch.float16, generator=gen) / math.sqrt(d_hidden)
    for i in range(num_experts):
        weights[f'experts.{i}.0.weight'] = torch.randn((d_hidden, expert_dim), device='cuda', dtype=torch.float16, generator=gen) / math.sqrt(expert_dim)
        weights[f'experts.{i}.1.weight'] = torch.randn((d_hidden, expert_dim), device='cuda', dtype=torch.float16, generator=gen) / math.sqrt(expert_dim)
        weights[f'experts.{i}.2.weight'] = torch.randn((expert_dim, d_hidden), device='cuda', dtype=torch.float16, generator=gen) / math.sqrt(d_hidden)
    weights['shared_experts.0.weight'] = torch.randn((d_hidden, expert_dim * n_shared_experts), device='cuda', dtype=torch.float16, generator=gen) / math.sqrt(expert_dim * n_shared_experts)
    weights['shared_experts.1.weight'] = torch.randn((d_hidden, expert_dim * n_shared_experts), device='cuda', dtype=torch.float16, generator=gen) / math.sqrt(expert_dim * n_shared_experts)
    weights['shared_experts.2.weight'] = torch.randn((expert_dim * n_shared_experts, d_hidden), device='cuda', dtype=torch.float16, generator=gen) / math.sqrt(d_hidden)
        weights[f"experts.{i}.0.weight"] = torch.randn((d_hidden, expert_dim), device="cuda", dtype=torch.float16, generator=gen) / math.sqrt(expert_dim)
        weights[f"experts.{i}.1.weight"] = torch.randn((d_hidden, expert_dim), device="cuda", dtype=torch.float16, generator=gen) / math.sqrt(expert_dim)
        weights[f"experts.{i}.2.weight"] = torch.randn((expert_dim, d_hidden), device="cuda", dtype=torch.float16, generator=gen) / math.sqrt(d_hidden)
    weights["shared_experts.0.weight"] = torch.randn((d_hidden, expert_dim * n_shared_experts), device="cuda", dtype=torch.float16, generator=gen) / math.sqrt(expert_dim * n_shared_experts)
    weights["shared_experts.1.weight"] = torch.randn((d_hidden, expert_dim * n_shared_experts), device="cuda", dtype=torch.float16, generator=gen) / math.sqrt(expert_dim * n_shared_experts)
    weights["shared_experts.2.weight"] = torch.randn((expert_dim * n_shared_experts, d_hidden), device="cuda", dtype=torch.float16, generator=gen) / math.sqrt(d_hidden)
    return (input_tensor, weights, config)
...
...
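The torch reference above accumulates each expert's output back into a shared cache with `Tensor.scatter_reduce_`. A minimal, self-contained sketch of that accumulation pattern, using toy shapes rather than the repository's real buffers:

```python
import torch

# Toy stand-ins: a cache of 4 token rows (width 3) and 5 expert outputs that
# should be summed back into the rows named by token_idxs.
expert_cache = torch.zeros(4, 3)
expert_out = torch.ones(5, 3)
token_idxs = torch.tensor([0, 2, 2, 3, 1])

# Broadcast the row indices across the feature dimension, then sum-reduce.
index = token_idxs.view(-1, 1).repeat(1, expert_cache.shape[-1])
expert_cache.scatter_reduce_(0, index, expert_out, reduce="sum")

print(expert_cache)  # row 2 receives two contributions, rows 0, 1, 3 receive one
```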
examples/fusedmoe/test_example_fusedmoe.py View file @ 29051439
...
...
@@ -4,13 +4,8 @@ import example_fusedmoe_tilelang
def test_example_fusedmoe_tilelang():
    example_fusedmoe_tilelang.main(d_hidden=1024, d_expert=256, n_routed_experts=8, n_shared_experts=1, n_experts_per_token=4, batch_size=1, seq_len=1024)
        d_hidden=1024, d_expert=256, n_routed_experts=8, n_shared_experts=1, n_experts_per_token=4, batch_size=1, seq_len=1024
    )


if __name__ == "__main__":
...
...
examples/gdn/example_chunk_delta_bwd.py View file @ 29051439
...
...
@@ -12,6 +12,7 @@ print(tilelang.__file__, flush=True)
# sys.path.insert(0, "/home/tzj/flash-linear-attention")
try:
    import fla

    print(fla.__file__, flush=True)
    from fla.ops.common.chunk_delta_h import chunk_gated_delta_rule_bwd_dhu
except ImportError:
...
...
@@ -49,6 +50,7 @@ def prepare_input(
    G = F.logsigmoid(G)
    try:
        from fla.ops.utils.cumsum import chunk_local_cumsum

        G = chunk_local_cumsum(G, chunk_size)
    except ImportError:
        print("fla not found, skip cumsum")
...
...
@@ -125,8 +127,11 @@ def torch_chunk_gated_delta_rule_bwd_dhu(
    DV = dv.shape[-1]
    block_S = 64
    BS = S // block_S
    dh, dh0, dv2 = torch.empty((B, BS, H, DK, DV), dtype=output_dtype), torch.empty((B, H, DK, DV), dtype=state_dtype), torch.empty((B, S, H, DV), dtype=output_dtype)
    dh, dh0, dv2 = (
        torch.empty((B, BS, H, DK, DV), dtype=output_dtype),
        torch.empty((B, H, DK, DV), dtype=state_dtype),
        torch.empty((B, S, H, DV), dtype=output_dtype),
    )
    dh_tmp = torch.empty((B, H, DK, DV), dtype=accum_dtype)
    dv_tmp = torch.empty((B, S, H, DV), dtype=accum_dtype)
    Q_tmp = torch.empty((B, S, H, DK), dtype=accum_dtype)
...
...
@@ -138,34 +143,30 @@ def torch_chunk_gated_delta_rule_bwd_dhu(
    for i_s in range(BS - 1, -1, -1):
        dh[:, i_s, :, :, :] = dh_tmp
        dv_tmp = torch.matmul(K[:, i_s * block_S : (i_s + 1) * block_S, :, :].permute(0, 2, 1, 3), dh_tmp.to(K.dtype)).permute(0, 2, 1, 3)
        if use_g:
            for i_bh in range(B * H):
                i_b, i_h = i_bh // H, i_bh % H
                for i_s2 in range(block_S):
                    if G[i_b, i_s * block_S + block_S - 1, i_h] - G[i_b, i_s * block_S + i_s2, i_h] <= 0:
                        dv_tmp[i_b, i_s2, i_h, :] *= torch.exp(G[i_b, i_s * block_S + block_S - 1, i_h] - G[i_b, i_s * block_S + i_s2, i_h])
                    else:
                        dv_tmp[i_b, i_s2, i_h, :] = 0
        dv_tmp += dv[:, i_s * block_S : (i_s + 1) * block_S, :, :]
        dv2[:, i_s * block_S : (i_s + 1) * block_S, :, :] = dv_tmp
        if use_g:
            G_last = G[:, i_s * block_S + block_S - 1, :]
            for i_bh in range(B * H):
                i_b, i_h = i_bh // H, i_bh % H
                dh_tmp[i_b, i_h, :, :] *= torch.exp(G_last[i_b, i_h])
        Q_tmp = Q[:, i_s * block_S : (i_s + 1) * block_S, :, :]
        for i_s2 in range(block_S):
            for i_k in range(DK):
                Q_tmp[:, i_s2, :, i_k] *= torch.exp(G[:, i_s * block_S + i_s2, :])
        Q_tmp *= scale
        W_tmp = W[:, i_s * block_S : (i_s + 1) * block_S, :, :]
        dO_tmp = dO[:, i_s * block_S : (i_s + 1) * block_S, :, :]
        torch.backends.cuda.matmul.allow_tf32 = True
        dh_tmp += torch.matmul(Q_tmp.permute(0, 2, 3, 1), dO_tmp.permute(0, 2, 1, 3))
...
...
@@ -269,7 +270,8 @@ def tilelang_chunk_gated_delta_rule_bwd_dhu(
            T.use_swizzle(10)
            T.annotate_layout({
            T.annotate_layout(
                {
                    b_dh_shared: tilelang.layout.make_swizzled_layout(b_dh_shared),
                    b_dh_shared_fp32: tilelang.layout.make_swizzled_layout(b_dh_shared_fp32),
                    dv_shared: tilelang.layout.make_swizzled_layout(dv_shared),
...
...
@@ -279,10 +281,11 @@ def tilelang_chunk_gated_delta_rule_bwd_dhu(
                    Q_shared: tilelang.layout.make_swizzled_layout(Q_shared),
                    Q_shared_fp32: tilelang.layout.make_swizzled_layout(Q_shared_fp32),
                    W_shared: tilelang.layout.make_swizzled_layout(W_shared),
                })
                }
            )
            if use_final_state_gradient:
                T.copy(dht[bb, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV], b_dh_shared)
                T.copy(b_dh_shared, b_dh_fragment)
            else:
                T.clear(b_dh_fragment)
...
...
@@ -293,17 +296,14 @@ def tilelang_chunk_gated_delta_rule_bwd_dhu(
                # Store the updated dh
                T.copy(b_dh_fragment, b_dh_shared)
                T.copy(b_dh_shared, dh[bb, i_s_inv, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV])
                # Update dv
                T.copy(K[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh, 0:DK], K_shared)
                T.gemm(K_shared, b_dh_shared, dv_fragment, clear_accum=True)
                if use_g:
                    T.copy(G[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh], G_shared, disable_tma=True)
                    T.copy(G_shared, G_fragment)
                    G_last_local[0] = G_shared[block_S - 1]
                    G_last_local_exp[0] = T.exp(G_last_local[0])
...
...
@@ -313,27 +313,22 @@ def tilelang_chunk_gated_delta_rule_bwd_dhu(
                        # with T.If(G_last_local[0] - G_shared[i_s2] <= 0):
                        with T.If(G_last_local[0] - G_fragment[i_s2] <= 0):
                            with T.Then():
                                dv_fragment[i_s2, i_v] = dv_fragment[i_s2, i_v] * G_fragment_post[i_s2]
                            with T.Else():
                                dv_fragment[i_s2, i_v] = 0
                T.copy(dv[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV], dv_shared)
                T.copy(dv_shared, dv_fragment_2)
                for i_s2, i_v in T.Parallel(block_S, block_DV):
                    dv_fragment[i_s2, i_v] = dv_fragment[i_s2, i_v] + dv_fragment_2[i_s2, i_v]
                # Store the updated dv
                T.copy(dv_fragment, dv_shared)
                T.copy(dv_shared, dv2[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV])
                # Update dh
                T.copy(Q[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh, 0:DK], Q_shared)
                T.copy(W[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh, 0:DK], W_shared)
                T.clear(Q_fragment)
                if use_g:
...
...
@@ -353,9 +348,7 @@ def tilelang_chunk_gated_delta_rule_bwd_dhu(
                for i_s2, i_k in T.Parallel(block_S, DK):
                    Q_fragment_t[i_k, i_s2] = Q_fragment[i_s2, i_k]
                T.copy(dO[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV], dO_shared)
                T.copy(dO_shared, dO_fragment)
                for i_s2, i_v in T.Parallel(block_S, block_DV):
                    dO_fragment_t[i_v, i_s2] = dO_fragment[i_s2, i_v]
...
...
@@ -369,7 +362,7 @@ def tilelang_chunk_gated_delta_rule_bwd_dhu(
                    b_dh_fragment[i_k, i_v] += b_dh_fragment_1[i_k, i_v] - b_dh_fragment_2[i_k, i_v]
            if use_initial_state:
                T.copy(b_dh_fragment, dh0[bb, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV])
    return kernel
...
...
@@ -444,44 +437,61 @@ def run_test(
    num_stages=0,
    use_torch=False,
):
    Q, K, W, G, h0, dht, dO, dv = prepare_input(
        B, S, H, DK, DV, chunk_size,
        getattr(torch, input_dtype),
        getattr(torch, output_dtype),
        getattr(torch, accum_dtype),
        getattr(torch, gate_dtype),
        getattr(torch, state_dtype),
    )
    dh_ref, dh0_ref, dv2_ref = prepare_output(
        B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), getattr(torch, state_dtype)
    )
    dh_tilelang, dh0_tilelang, dv2_tilelang = prepare_output(
        B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), getattr(torch, state_dtype)
    )
    # fla ref
    print("fla running...", flush=True)
    if use_g:
        dh_ref, dh0_ref, dv2_ref = chunk_gated_delta_rule_bwd_dhu(Q, K, W, G, h0, dht, dO, dv, scale)
    else:
        G = G.fill_(0)
        dh_ref, dh0_ref, dv2_ref = chunk_gated_delta_rule_bwd_dhu(Q, K, W, G, h0, dht, dO, dv, scale)
    # tilelang
    print("tilelang running...", flush=True)
    kernel = tilelang_chunk_gated_delta_rule_bwd_dhu(
        B, S, H, DK, DV, input_dtype, output_dtype, accum_dtype, gate_dtype, state_dtype,
        chunk_size, scale, use_g, use_initial_state, use_final_state_gradient, block_DV, threads, num_stages,
    )
    # kernel = tilelang.compile(program)
    print(kernel.get_kernel_source())
    dh_tilelang, dh0_tilelang, dv2_tilelang = kernel(Q, K, W, G, h0, dht, dO, dv)
    fla_time = do_bench(chunk_gated_delta_rule_bwd_dhu, Q, K, W, G, h0, dht, dO, dv, scale, chunk_size=chunk_size)
    tilelang_time = do_bench(kernel, Q, K, W, G, h0, dht, dO, dv)
    print(f"fla time: {fla_time} ms")
...
...
@@ -496,19 +506,47 @@ def run_test(
        print("torch running...", flush=True)
        if use_g:
            dh_ref_torch, dh0_ref_torch, dv2_ref_torch = torch_chunk_gated_delta_rule_bwd_dhu(
                Q, K, W, G, h0, dht, dO, dv, scale, use_g, use_initial_state, use_final_state_gradient,
                getattr(torch, input_dtype), getattr(torch, output_dtype), getattr(torch, accum_dtype),
                getattr(torch, gate_dtype), getattr(torch, state_dtype),
            )
            dh_ref_torch = dh_ref_torch.cuda()
            dh0_ref_torch = dh0_ref_torch.cuda()
            dv2_ref_torch = dv2_ref_torch.cuda()
        else:
            dh_ref_torch, dh0_ref_torch, dv2_ref_torch = torch_chunk_gated_delta_rule_bwd_dhu(
                Q, K, W, None, h0, dht, dO, dv, scale, use_g, use_initial_state, use_final_state_gradient,
                getattr(torch, input_dtype), getattr(torch, output_dtype), getattr(torch, accum_dtype),
                getattr(torch, gate_dtype), getattr(torch, state_dtype),
            )
            dh_ref_torch = dh_ref_torch.cuda()
            dh0_ref_torch = dh0_ref_torch.cuda()
            dv2_ref_torch = dv2_ref_torch.cuda()
...
...
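The torch reference above builds `dv_tmp` by permuting the K chunk to (B, H, S, DK), batch-multiplying with `dh_tmp` of shape (B, H, DK, DV), and permuting back. A small illustrative check (toy sizes, not the kernel's real shapes) that this permute/matmul/permute sequence matches a single einsum contraction:

```python
import torch

B, S, H, DK, DV = 2, 4, 3, 8, 16
K_chunk = torch.randn(B, S, H, DK)
dh_tmp = torch.randn(B, H, DK, DV)

# permute -> batched matmul -> permute back, as in the reference code above
out_ref = torch.matmul(K_chunk.permute(0, 2, 1, 3), dh_tmp).permute(0, 2, 1, 3)

# the same contraction written as one einsum over the head and key dimensions
out_einsum = torch.einsum("bshk,bhkv->bshv", K_chunk, dh_tmp)

assert torch.allclose(out_ref, out_einsum, atol=1e-5)
```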
examples/gdn/example_chunk_delta_h.py View file @ 29051439
...
...
@@ -10,6 +10,7 @@ from tilelang.autotuner import autotune
# sys.path.insert(0, "/home/tzj/flash-linear-attention")
try:
    import fla

    print(fla.__file__)
    from fla.ops.common.chunk_delta_h import chunk_gated_delta_rule_fwd_h
except ImportError:
...
...
@@ -56,6 +57,7 @@ def prepare_input(
    G = F.logsigmoid(G)
    try:
        from fla.ops.utils.cumsum import chunk_local_cumsum

        G = chunk_local_cumsum(G, chunk_size)
    except ImportError:
        print("fla not found, skip cumsum")
...
...
@@ -83,18 +85,14 @@ def prepare_output(
def get_configs():
    import itertools

    block_DK = [32, 64, 128]
    block_DV = [32, 64, 128]
    threads = [128, 256]
    num_stages = [1, 2, 3]
    _configs = list(itertools.product(block_DK, block_DV, threads, num_stages))
    configs = [{'block_DK': c[0], 'block_DV': c[1], 'threads': c[2], 'num_stages': c[3]} for c in _configs]
    configs = [{"block_DK": c[0], "block_DV": c[1], "threads": c[2], "num_stages": c[3]} for c in _configs]
    return configs
...
...
@@ -162,35 +160,35 @@ def tilelang_chunk_gated_delta_rule_fwd_h(
            G_shared = T.alloc_shared((block_S, block_DV), dtype=gate_dtype)
            G_fragment = T.alloc_fragment((block_S, block_DV), dtype=gate_dtype)
            T.annotate_layout({
            T.annotate_layout(
                {
                    b_h_shared: tilelang.layout.make_swizzled_layout(b_h_shared),
                    U_shared: tilelang.layout.make_swizzled_layout(U_shared),
                    W_shared: tilelang.layout.make_swizzled_layout(W_shared),
                    V_new_shared: tilelang.layout.make_swizzled_layout(V_new_shared),
                    K_shared: tilelang.layout.make_swizzled_layout(K_shared),
                    G_shared: tilelang.layout.make_swizzled_layout(G_shared),
                })
                }
            )
            T.use_swizzle(10)
            if use_initial_state:
                T.copy(initial_state[bb, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV], b_h_shared)
                T.copy(b_h_shared, b_h_fragment)
            else:
                T.clear(b_h_fragment)
            for i_s in T.Pipelined(T.ceildiv(S, block_S), num_stages=num_stages):
                # Store previous result to the hidden tensor, like the epilogue
                T.copy(b_h_shared, h[bb, i_s, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV])
                # Recurrence
                T.copy(W[bb, i_s * block_S : (i_s + 1) * block_S, bh, 0:DK], W_shared)
                T.gemm(W_shared, b_h_shared, V_new_fragment, clear_accum=True)
                # U - W * S
                T.copy(U[bb, i_s * block_S : (i_s + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV], U_shared)
                T.copy(U_shared, U_fragment)
                for i_s2, i_v in T.Parallel(block_S, block_DV):
                    V_new_fragment[i_s2, i_v] = -V_new_fragment[i_s2, i_v] + U_fragment[i_s2, i_v]
...
...
@@ -198,11 +196,9 @@ def tilelang_chunk_gated_delta_rule_fwd_h(
                # Save V_new
                if save_new_value:
                    T.copy(V_new_fragment, dst=V_new_shared)
                    T.copy(V_new_shared, V_new[bb, i_s * block_S : (i_s + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV])
                T.copy(K[bb, i_s * block_S : (i_s + 1) * block_S, bh, 0:DK], K_shared)
                # use_g
                if use_g:
                    G_last_local[0] = G[bb, (i_s + 1) * block_S - 1, bh]
...
...
@@ -213,7 +209,8 @@ def tilelang_chunk_gated_delta_rule_fwd_h(
                        with T.If(G_last_local[0] - G_fragment[i_s2, i_v] <= 0):
                            with T.Then():
                                V_new_fragment[i_s2, i_v] = V_new_fragment[i_s2, i_v] * T.exp2((G_last_local[0] - G_fragment[i_s2, i_v]) * 1.442695)
                                V_new_fragment[i_s2, i_v] = V_new_fragment[i_s2, i_v] * T.exp2(
                                    (G_last_local[0] - G_fragment[i_s2, i_v]) * 1.442695
                                )
                            with T.Else():
                                V_new_fragment[i_s2, i_v] = 0
                    G_last_local[0] = T.exp2(G_last_local[0] * 1.442695)
...
...
@@ -228,7 +225,7 @@ def tilelang_chunk_gated_delta_rule_fwd_h(
            # Save final state
            if store_final_state:
                T.copy(b_h_fragment, final_state[bb, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV])
    return kernel
...
...
@@ -279,17 +276,24 @@ def run_test(
    threads=128,
    num_stages=0,
):
    K, W, U, G, initial_state = prepare_input(
        B, S, H, DK, DV, chunk_size,
        getattr(torch, input_dtype), getattr(torch, output_dtype), getattr(torch, accum_dtype), getattr(torch, gate_dtype),
    )
    h_ref, final_state_ref, V_new_ref = prepare_output(
        B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, state_dtype)
    )
    h_tilelang, final_state_tilelang, V_new_tilelang = prepare_output(
        B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, state_dtype)
    )
    # fla ref
    h_ref, V_new_ref, final_state_ref = chunk_gated_delta_rule_fwd_h(
...
...
@@ -300,13 +304,27 @@ def run_test(
        initial_state=initial_state,
        output_final_state=store_final_state,
        chunk_size=chunk_size,
        save_new_value=save_new_value)
        save_new_value=save_new_value,
    )
    # tilelang
    kernel = tilelang_chunk_gated_delta_rule_fwd_h(
        B, S, H, DK, DV, input_dtype, output_dtype, accum_dtype, gate_dtype, state_dtype,
        chunk_size, use_g, use_initial_state, store_final_state, save_new_value,
    )
    h_tilelang, final_state_tilelang, V_new_tilelang = kernel(K, W, U, G, initial_state)
    # (zhengju) If you want to print the generated cuda code, you can uncomment the following line
    # print("CUDA Code:\n", kernel.get_kernel_source())
...
...
@@ -320,19 +338,15 @@ def run_test(
        initial_state=initial_state,
        output_final_state=store_final_state,
        chunk_size=chunk_size,
        save_new_value=save_new_value)
        save_new_value=save_new_value,
    )
    tilelang_time = do_bench(kernel, K, W, U, G, initial_state)
    # check correctness
    try:
        h_ref_fp32 = h_ref.to(torch.float32)
        h_tilelang_fp32 = h_tilelang.to(torch.float32)
        assert_similar(h_ref_fp32, h_tilelang_fp32, eps=1e-5, name="tilelang chunk gated delta rule fwd h", raise_assert=False)
        print("tilelang chunk gated delta rule fwd h passed √")
    except Exception as e:
        print("tilelang chunk gated delta rule fwd h failed ✗")
...
...
@@ -346,7 +360,8 @@ def run_test(
            final_state_tilelang_fp32,
            eps=1e-5,
            name="tilelang chunk gated delta rule fwd final_state",
            raise_assert=False)
            raise_assert=False,
        )
        print("tilelang chunk gated delta rule fwd final_state passed √")
    except Exception as e:
        print("tilelang chunk gated delta rule fwd final_state failed ✗")
...
...
@@ -355,12 +370,7 @@ def run_test(
    try:
        V_new_ref_fp32 = V_new_ref.to(torch.float32)
        V_new_tilelang_fp32 = V_new_tilelang.to(torch.float32)
        assert_similar(V_new_ref_fp32, V_new_tilelang_fp32, eps=1e-5, name="tilelang chunk gated delta rule fwd V_new", raise_assert=False)
        print("tilelang chunk gated delta rule fwd V_new passed √")
    except Exception as e:
        print("tilelang chunk gated delta rule fwd V_new failed ✗")
...
...
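The chunk kernels above rescale gate terms with `T.exp2(x * 1.442695)` instead of a plain exponential. The constant is log2(e), so the two forms agree: exp(x) = 2^(x * log2 e). A tiny sketch verifying the identity with plain Python math (the tilelang intrinsics themselves are not exercised here):

```python
import math

LOG2_E = 1.4426950408889634  # the 1.442695 constant used in the kernels, i.e. log2(e)

for x in (-3.0, -0.5, 0.0, 0.75, 2.0):
    via_exp = math.exp(x)
    via_exp2 = 2.0 ** (x * LOG2_E)
    assert math.isclose(via_exp, via_exp2, rel_tol=1e-12)
```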
examples/gdn/example_chunk_o.py View file @ 29051439
...
...
@@ -9,6 +9,7 @@ import sys # noqa: F401
# sys.path.insert(0, "/home/tzj/flash-linear-attention")
try:
    import fla

    print(fla.__file__)
    from fla.ops.common.chunk_o import chunk_fwd_o
except ImportError:
...
...
@@ -94,9 +95,7 @@ def tilelang_chunk_fwd_o(
        G: T.Tensor(G_shape, dtype=gate_dtype),
        O: T.Tensor(O_shape, dtype=output_dtype),
    ):
        with T.Kernel(T.ceildiv(DV, block_DV), T.ceildiv(S, block_S), B * H, threads=threads) as (bv, bs, bbh):
            bb, bh = bbh // H, bbh % H
            Q_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype)
            K_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype)
...
...
@@ -109,28 +108,24 @@ def tilelang_chunk_fwd_o(
            G_shared = T.alloc_shared((block_S,), dtype=gate_dtype, scope="shared")
            G_diff_local = T.alloc_fragment((block_S, block_S), dtype=gate_dtype)
            T.annotate_layout({
            T.annotate_layout(
                {
                    Q_shared: tilelang.layout.make_swizzled_layout(Q_shared),
                    K_shared: tilelang.layout.make_swizzled_layout(K_shared),
                    V_shared: tilelang.layout.make_swizzled_layout(V_shared),
                    H_shared: tilelang.layout.make_swizzled_layout(H_shared),
                    A_shared: tilelang.layout.make_swizzled_layout(A_shared),
                    O_shared: tilelang.layout.make_swizzled_layout(O_shared),
                })
                }
            )
            T.clear(A_fragment)
            T.clear(O_fragment)
            T.disable_warp_group_reg_alloc()
            for i_k in T.Pipelined(T.ceildiv(DK, block_DK), num_stages=num_stages):
                T.copy(Q[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK], Q_shared)
                T.copy(K[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK], K_shared)
                T.copy(HIDDEN[bb, bs, bh, i_k * block_DK : (i_k + 1) * block_DK, bv * block_DV : (bv + 1) * block_DV], H_shared)
                T.gemm(Q_shared, H_shared, O_fragment)
                T.gemm(Q_shared, K_shared, A_fragment, transpose_B=True)
...
...
@@ -145,8 +140,7 @@ def tilelang_chunk_fwd_o(
                for i_s1, i_s2 in T.Parallel(block_S, block_S):
                    with T.If(G_diff_local[i_s1, i_s2] <= 0):
                        with T.Then():
                            A_fragment[i_s1, i_s2] = A_fragment[i_s1, i_s2] * T.exp(G_diff_local[i_s1, i_s2])
                        with T.Else():
                            A_fragment[i_s1, i_s2] = 0
...
...
@@ -155,8 +149,7 @@ def tilelang_chunk_fwd_o(
                    with T.Then():
                        A_fragment[i_s1, i_s2] = 0
            T.copy(V[bb, bs * block_S : (bs + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV], V_shared)
            T.copy(A_fragment, A_shared)
            T.gemm(A_shared, V_shared, O_fragment)
...
...
@@ -164,8 +157,7 @@ def tilelang_chunk_fwd_o(
                O_fragment[i_s, i_v] = O_fragment[i_s, i_v] * scale
            T.copy(O_fragment, O_shared)
            T.copy(O_shared, O[bb, bs * block_S : (bs + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV])
    return kernel
...
...
@@ -191,8 +183,9 @@ def run_test(
    output_dtype_torch = getattr(torch, output_dtype)
    accum_dtype_torch = getattr(torch, accum_dtype)
    gate_dtype_torch = getattr(torch, gate_dtype)
    Q, K, V, HIDDEN, G = prepare_input(
        B, S, H, DK, DV, chunk_size, input_dtype_torch, output_dtype_torch, accum_dtype_torch, gate_dtype_torch
    )
    scale = 1.0 / DK**0.5
    O_ref = prepare_output(B, S, H, DK, DV, chunk_size, output_dtype_torch)
...
...
@@ -200,9 +193,25 @@ def run_test(
    block_S = chunk_size
    O_tilelang = prepare_output(B, S, H, DK, DV, chunk_size, output_dtype_torch)
    kernel = tilelang_chunk_fwd_o(
        B, S, H, DK, DV, input_dtype, output_dtype, accum_dtype, gate_dtype,
        chunk_size, scale, use_g, block_S, block_DK, block_DV, threads, num_stages,
    )
    O_tilelang = kernel(Q, K, V, HIDDEN, G)
    try:
...
...
examples/gdn/example_chunk_o_bwd.py View file @ 29051439
...
...
@@ -12,6 +12,7 @@ from tilelang.engine.callback import register_cuda_postproc_callback # noqa: F4
# sys.path.insert(0, "/home/tzj/flash-linear-attention")
try:
    import fla

    print(fla.__file__)
    from fla.ops.common.chunk_o import chunk_bwd_dqkwg
except ImportError:
...
...
@@ -108,10 +109,8 @@ def prepare_output(
@tilelang.jit(
    out_idx=[-4, -3, -2, -1],
    pass_configs={tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True})
    pass_configs={tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True},
)
def tilelang_chunk_o_bwd_dqkwg(
    # task config
    B,
...
...
@@ -171,9 +170,7 @@ def tilelang_chunk_o_bwd_dqkwg(
        dw: T.Tensor(dw_shape, dtype=output_dtype),
        dg: T.Tensor(dg_shape, dtype=gate_dtype),
    ):
        with T.Kernel(T.ceildiv(DK, block_DK), T.ceildiv(S, block_S), B * H, threads=threads) as (bk, bs, bbh):
            bb, bh = bbh // H, bbh % H
            V_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype)
...
...
@@ -212,7 +209,8 @@ def tilelang_chunk_o_bwd_dqkwg(
            T.use_swizzle(10)
            T.annotate_layout({
            T.annotate_layout(
                {
                    V_shared: tilelang.layout.make_swizzled_layout(V_shared),
                    dO_shared: tilelang.layout.make_swizzled_layout(dO_shared),
                    h_shared: tilelang.layout.make_swizzled_layout(h_shared),
...
...
@@ -220,7 +218,8 @@ def tilelang_chunk_o_bwd_dqkwg(
                    dv_shared: tilelang.layout.make_swizzled_layout(dv_shared),
                    q_shared: tilelang.layout.make_swizzled_layout(q_shared),
                    k_shared: tilelang.layout.make_swizzled_layout(k_shared),
                })
                }
            )
            T.clear(dg_last_local)
            T.clear(G_last_local)
...
...
@@ -235,18 +234,10 @@ def tilelang_chunk_o_bwd_dqkwg(
            T.clear(dw_fragment)
            for i_v in T.Pipelined(T.ceildiv(DV, block_DV), num_stages=num_stages):
                T.copy(V[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV], V_shared)
                T.copy(dO[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV], dO_shared)
                T.copy(h[bb, bs, bh, bk * block_DK : (bk + 1) * block_DK, i_v * block_DV : (i_v + 1) * block_DV], h_shared)
                T.copy(dh[bb, bs, bh, bk * block_DK : (bk + 1) * block_DK, i_v * block_DV : (i_v + 1) * block_DV], dh_shared)
                if use_g:
                    T.clear(dg_last_fragment_scalar)
...
...
@@ -254,9 +245,7 @@ def tilelang_chunk_o_bwd_dqkwg(
                    # for i_kv in T.Parallel(block_DK * block_DV):
                    #     dg_last_fragment[i_kv] = h_shared[i_kv // block_DV, i_kv % block_DV] * dh_shared[i_kv // block_DV, i_kv % block_DV]
                    for i_kv in T.Parallel(block_DK * block_DV):
                        dg_last_fragment[i_kv] = h_shared[i_kv // block_DV, i_kv % block_DV] * dh_shared[i_kv // block_DV, i_kv % block_DV]
                    T.reduce_sum(dg_last_fragment, dg_last_fragment_scalar, dim=-1, clear=False)
                    dg_last_local[0] += dg_last_fragment_scalar[0]
...
...
@@ -265,22 +254,16 @@ def tilelang_chunk_o_bwd_dqkwg(
                T.gemm(V_shared, dh_shared, dk_fragment, transpose_B=True)
                if use_dw:
                    T.copy(dv[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV], dv_shared)
                    T.gemm(dv_shared, h_shared, dw_fragment, transpose_B=True)
            if use_dw:
                for i_s, i_k in T.Parallel(block_S, block_DK):
                    dw_fragment[i_s, i_k] = -dw_fragment[i_s, i_k]
                T.copy(dw_fragment, dw[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK])
            T.copy(Q[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK], q_shared)
            T.copy(K[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK], k_shared)
            T.copy(q_shared, q_fragment)
            T.copy(k_shared, k_fragment)
...
...
@@ -294,8 +277,7 @@ def tilelang_chunk_o_bwd_dqkwg(
                dg_last_local[0] = dg_last_local[0] * T.exp(G[bb, bs * block_S + block_S - 1, bh])
                for i_s, i_k in T.Parallel(block_S, block_DK):
                    dq_fragment[i_s, i_k] = dq_fragment[i_s, i_k] * T.exp(G[bb, bs * block_S + i_s, bh]) * scale
                T.clear(dg_fragment_reduce_tmp)
                for i_s, i_k in T.Parallel(block_S, block_DK):
                    dg_fragment_reduce_tmp[i_s, i_k] = dq_fragment[i_s, i_k] * q_shared[i_s, i_k]
...
...
@@ -305,8 +287,7 @@ def tilelang_chunk_o_bwd_dqkwg(
                for i_s, i_k in T.Parallel(block_S, block_DK):
                    with T.If(G_last_local[0] - G[bb, bs * block_S + i_s, bh] <= 0):
                        with T.Then():
                            dk_fragment[i_s, i_k] = dk_fragment[i_s, i_k] * T.exp(G_last_local[0] - G[bb, bs * block_S + i_s, bh])
                        with T.Else():
                            dk_fragment[i_s, i_k] = 0
                T.clear(dg_fragment_reduce_tmp)
...
...
@@ -325,12 +306,11 @@ def tilelang_chunk_o_bwd_dqkwg(
                dg_last_local[1] = dg_last_fragment_scalar_2[0]
                for i_s1, i_s2 in T.Parallel(block_S, block_S):
                    with T.If(i_s1 >= i_s2 and G[bb, bs * block_S + i_s1, bh] - G[bb, bs * block_S + i_s2, bh] <= 0):
                        with T.Then():
                            ds_fragment[i_s1, i_s2] = ds_fragment[i_s1, i_s2] * T.exp(G[bb, bs * block_S + i_s1, bh] - G[bb, bs * block_S + i_s2, bh]) * scale
                            ds_fragment[i_s1, i_s2] = (
                                ds_fragment[i_s1, i_s2] * T.exp(G[bb, bs * block_S + i_s1, bh] - G[bb, bs * block_S + i_s2, bh]) * scale
                            )
                        with T.Else():
                            ds_fragment[i_s1, i_s2] = 0
...
...
@@ -338,8 +318,7 @@ def tilelang_chunk_o_bwd_dqkwg(
                T.clear(ds_fragment_positive_transpose)
                T.gemm(q_shared, k_shared, ds_fragment_positive, transpose_B=True)
                for i_s1, i_s2 in T.Parallel(block_S, block_S):
                    ds_fragment_positive[i_s1, i_s2] = ds_fragment[i_s1, i_s2] * ds_fragment_positive[i_s1, i_s2]
                # FIXME: The reduce_sum statement with clear=True will cause an error of warp specialized pass
                T.reduce_sum(ds_fragment_positive, dg_fragment, dim=1, clear=False)
...
...
@@ -363,15 +342,10 @@ def tilelang_chunk_o_bwd_dqkwg(
                for i_s in T.Parallel(block_S):
                    with T.If(i_s >= block_S - 1):  # noqa: SIM117
                        with T.Then():
                            dg_fragment_final[i_s] = dg_fragment_final[i_s] + dg_last_local[0] + dg_last_local[1]
                T.copy(dq_fragment, dq[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK])
                T.copy(dk_fragment, dk[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK])
                for i_s in T.Parallel(block_S):
                    dg[bk, bb, bs * block_S + i_s, bh] = dg_fragment_final[i_s]
...
...
@@ -387,12 +361,8 @@ def tilelang_chunk_o_bwd_dqkwg(
                for i_s, i_k in T.Parallel(block_S, block_DK):
                    dq_fragment[i_s, i_k] = dq_fragment[i_s, i_k] * scale
                    dk_fragment[i_s, i_k] = dk_fragment[i_s, i_k] + dk_fragment_2[i_s, i_k] * scale
                T.copy(dq_fragment, dq[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK])
                T.copy(dk_fragment, dk[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK])
    return kernel
...
...
@@ -442,32 +412,53 @@ def run_test(
    threads=256,
    num_stages=0,
):
    Q, K, V, h, G, dO, dh, dv, W = prepare_input(
        B, S, H, DK, DV, chunk_size,
        getattr(torch, input_dtype), getattr(torch, output_dtype), getattr(torch, accum_dtype),
        getattr(torch, gate_dtype), getattr(torch, state_dtype),
    )
    dq_ref, dk_ref, dw_ref, dg_ref = prepare_output(
        B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), getattr(torch, state_dtype), block_DK
    )
    dq_tilelang, dk_tilelang, dw_tilelang, dg_tilelang = prepare_output(
        B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), getattr(torch, state_dtype), block_DK
    )
    # ref
    if use_g:
        dq_ref, dk_ref, dw_ref, dg_ref = chunk_bwd_dqkwg(Q, K, V, G, dO, h, dh, dv, W, chunk_size=chunk_size, scale=scale)
    else:
        dq_ref, dk_ref, dw_ref, dg_ref = chunk_bwd_dqkwg(Q, K, V, None, dO, h, dh, dv, W, chunk_size=chunk_size, scale=scale)
    # tilelang
    kernel = tilelang_chunk_o_bwd_dqkwg(
        B, S, H, DK, DV, input_dtype, output_dtype, accum_dtype, gate_dtype, state_dtype,
        chunk_size, scale, use_g, use_dw, block_DK, block_DV, threads, num_stages,
    )
    dq_tilelang, dk_tilelang, dw_tilelang, dg_tilelang = kernel(Q, K, V, h, G, dO, dh, dv, W)
    if use_g:
...
...
examples/gdn/example_chunk_scaled_dot_kkt.py View file @ 29051439
...
...
@@ -9,6 +9,7 @@ import sys # noqa: F401
# sys.path.insert(0, "/home/tzj/flash-linear-attention")
try:
    import fla

    print(fla.__file__)
    from fla.ops.common.chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd
except ImportError:
...
...
@@ -93,10 +94,12 @@ def tilelang_chunk_scaled_dot_kkt_fwd(
            G_shared = T.alloc_shared((block_S,), dtype=accum_dtype, scope="shared")
            G_diff_local = T.alloc_fragment((block_S, block_S), dtype=accum_dtype)
            T.annotate_layout({
            T.annotate_layout(
                {
                    K_shared: tilelang.layout.make_swizzled_layout(K_shared),
                    A_shared: tilelang.layout.make_swizzled_layout(A_shared),
                })
                }
            )
            T.fill(A_fragment, 0)
            T.disable_warp_group_reg_alloc()
...
...
@@ -104,9 +107,7 @@ def tilelang_chunk_scaled_dot_kkt_fwd(
                Beta_shared[i_s] = Beta[bb, bs * block_S + i_s, bh]
            for i_k in T.Pipelined(T.ceildiv(DK, block_DK), num_stages=num_stages):
                T.copy(K[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK], K_shared)
                for i_s, i_k2 in T.Parallel(block_S, block_DK):
                    Beta_K_fragment[i_s, i_k2] = K_shared[i_s, i_k2] * Beta_shared[i_s]
                T.gemm(Beta_K_fragment, K_shared, A_fragment, transpose_B=True)
...
...
@@ -119,8 +120,7 @@ def tilelang_chunk_scaled_dot_kkt_fwd(
                for i_s1, i_s2 in T.Parallel(block_S, block_S):
                    with T.If(G_diff_local[i_s1, i_s2] <= 0 and i_s1 > i_s2):
                        with T.Then():
                            A_fragment[i_s1, i_s2] = A_fragment[i_s1, i_s2] * T.exp(G_diff_local[i_s1, i_s2])
                        with T.Else():
                            A_fragment[i_s1, i_s2] = 0
            else:
...
...
@@ -130,7 +130,7 @@ def tilelang_chunk_scaled_dot_kkt_fwd(
                        A_fragment[i_s1, i_s2] = 0
            T.copy(A_fragment, A_shared)
            T.copy(A_shared, A[bb, bs * block_S : (bs + 1) * block_S, bh, :])
    return kernel
...
...
@@ -149,24 +149,21 @@ def run_test(
    threads,
    num_stages,
):
    K, Beta, G = prepare_input(B, S, H, DK, getattr(torch, input_dtype), getattr(torch, output_dtype), getattr(torch, accum_dtype))
    A_ref = prepare_output(B, S, H, chunk_size, getattr(torch, output_dtype))
    A_tilelang = prepare_output(B, S, H, chunk_size, getattr(torch, output_dtype))
    # reference
    if use_g:
        A_ref = chunk_scaled_dot_kkt_fwd(K, Beta, G, chunk_size=chunk_size, output_dtype=getattr(torch, output_dtype))
    else:
        A_ref = chunk_scaled_dot_kkt_fwd(K, Beta, None, chunk_size=chunk_size, output_dtype=getattr(torch, output_dtype))
    # tilelang
    block_S = chunk_size
    kernel = tilelang_chunk_scaled_dot_kkt_fwd(
        B, S, H, DK, chunk_size, input_dtype, output_dtype, accum_dtype, use_g, block_S, block_DK, threads, num_stages
    )
    A_tilelang = kernel(K, Beta, G)
    try:
...
...
@@ -192,7 +189,8 @@ def main():
        use_g=True,
        block_DK=64,
        threads=128,
        num_stages=2)
        num_stages=2,
    )


if __name__ == "__main__":
...
...
examples/gdn/example_cumsum.py View file @ 29051439
...
...
@@ -10,6 +10,7 @@ import sys # noqa: F401
# sys.path.insert(0, "/home/tzj/flash-linear-attention")
try:
    import fla

    print(fla.__file__)
    from fla.ops.utils.cumsum import chunk_local_cumsum_scalar
except ImportError:
...
...
@@ -20,11 +21,8 @@ import torch
@tilelang.jit(
    out_idx=[-1],
    pass_configs={tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True})
    out_idx=[-1], pass_configs={tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}
)
def tilelang_chunk_local_cumsum_scalar(
    # task config
    B,
...
...
@@ -42,7 +40,7 @@ def tilelang_chunk_local_cumsum_scalar(
    use_fragment=False,
):
    G_shape = (B, H, S) if head_first else (B, S, H)
    assert chunk_size == 2**(chunk_size.bit_length() - 1), "chunk_size must be a power of 2"
    assert chunk_size == 2 ** (chunk_size.bit_length() - 1), "chunk_size must be a power of 2"
    assert chunk_size == block_S, "chunk_size must be equal to block_S"

    @T.prim_func
...
...
@@ -54,23 +52,23 @@ def tilelang_chunk_local_cumsum_scalar(
            bb, bh = bbh // H, bbh % H
            G_shared = T.alloc_shared((1, block_S), dtype=output_dtype, scope="shared")
            if head_first:
                T.copy(G[bb, bh, bs * block_S : (bs + 1) * block_S], G_shared)
            else:
                T.copy(G[bb, bs * block_S : (bs + 1) * block_S, bh], G_shared)
            if use_fragment:
                G_fragment = T.alloc_fragment((1, block_S), dtype=output_dtype, scope="shared")
                T.copy(G_shared, G_fragment)
                T.cumsum(G_fragment, dim=1, reverse=reverse)
                if head_first:
                    T.copy(G_fragment, G_new[bb, bh, bs * block_S : (bs + 1) * block_S])
                else:
                    T.copy(G_fragment, G_new[bb, bs * block_S : (bs + 1) * block_S, bh])
            else:
                T.cumsum(G_shared, dim=1, reverse=reverse)
                if head_first:
                    T.copy(G_shared, G_new[bb, bh, bs * block_S : (bs + 1) * block_S])
                else:
                    T.copy(G_shared, G_new[bb, bs * block_S : (bs + 1) * block_S, bh])
    return kernel
...
...
@@ -113,11 +111,8 @@ def run_test(
    # reference cumsum
    G_new_ref = chunk_local_cumsum_scalar(
        g=G, chunk_size=chunk_size, reverse=reverse, head_first=head_first, output_dtype=getattr(torch, output_dtype))
        g=G, chunk_size=chunk_size, reverse=reverse, head_first=head_first, output_dtype=getattr(torch, output_dtype)
    )
    # tilelang cumsum
    block_S = chunk_size
...
...
@@ -162,7 +157,8 @@ def main():
        input_dtype="float32",
        output_dtype="float32",
        threads=256,
        use_fragment=False)
        use_fragment=False,
    )


if __name__ == "__main__":
...
...
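The cumsum kernel above asserts `chunk_size == 2 ** (chunk_size.bit_length() - 1)`, which is simply a power-of-two check. A short sketch (illustrative only, not repository code) showing it agrees with the more common bit-mask form for positive integers:

```python
def is_pow2_bit_length(n: int) -> bool:
    # form used in the example: a power of two equals 2**(bit_length - 1)
    return n > 0 and n == 2 ** (n.bit_length() - 1)


def is_pow2_mask(n: int) -> bool:
    # equivalent bit trick: a power of two has exactly one bit set
    return n > 0 and (n & (n - 1)) == 0


assert all(is_pow2_bit_length(n) == is_pow2_mask(n) for n in range(1, 1025))
```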
examples/gdn/example_wy_fast.py
View file @
29051439
...
...
@@ -9,6 +9,7 @@ import sys # noqa: F401
# sys.path.insert(0, "/home/tzj/flash-linear-attention")
try
:
import
fla
print
(
fla
.
__file__
)
from
fla.ops.gated_delta_rule.wy_fast
import
recompute_w_u_fwd
except
ImportError
:
...
...
@@ -95,7 +96,8 @@ def tilelang_recompute_w_u_fwd(
W_Beta_shared
=
T
.
alloc_shared
((
block_S
,
block_DK
),
dtype
=
input_dtype
)
U_Beta_shared
=
T
.
alloc_shared
((
block_S
,
block_DV
),
dtype
=
input_dtype
)
T
.
annotate_layout
({
T
.
annotate_layout
(
{
K_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
K_shared
),
V_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
V_shared
),
A_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
A_shared
),
...
...
@@ -103,41 +105,33 @@ def tilelang_recompute_w_u_fwd(
U_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
U_shared
),
W_Beta_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
W_Beta_shared
),
U_Beta_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
U_Beta_shared
),
})
}
)
T
.
disable_warp_group_reg_alloc
()
for
i_s
in
T
.
Parallel
(
block_S
):
Beta_shared
[
i_s
]
=
Beta
[
bb
,
bs
*
block_S
+
i_s
,
bh
]
G_shared
[
i_s
]
=
T
.
exp
(
G
[
bb
,
bs
*
block_S
+
i_s
,
bh
])
T
.
copy
(
A
[
bb
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
,
bh
,
:],
A_shared
)
T
.
copy
(
A
[
bb
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
,
bh
,
:],
A_shared
)
for
i_v
in
T
.
Pipelined
(
T
.
ceildiv
(
DV
,
block_DV
),
num_stages
=
num_stages
):
T
.
copy
(
V
[
bb
,
bs
*
block_S
:(
bs
+
1
)
*
block_S
,
bh
,
i_v
*
block_DV
:(
i_v
+
1
)
*
block_DV
],
V_shared
)
T
.
copy
(
V
[
bb
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
,
bh
,
i_v
*
block_DV
:
(
i_v
+
1
)
*
block_DV
],
V_shared
)
for
i_s
,
i_v2
in
T
.
Parallel
(
block_S
,
block_DV
):
U_Beta_shared
[
i_s
,
i_v2
]
=
V_shared
[
i_s
,
i_v2
]
*
Beta_shared
[
i_s
]
T
.
gemm
(
A_shared
,
U_Beta_shared
,
U_fragment
,
clear_accum
=
True
)
# First copy to smem, then copy to gmem to reduce U2RU instructions
T
.
copy
(
U_fragment
,
U_shared
)
T
.
copy
(
U_shared
,
U
[
bb
,
bs
*
block_S
:(
bs
+
1
)
*
block_S
,
bh
,
i_v
*
block_DV
:(
i_v
+
1
)
*
block_DV
])
T
.
copy
(
U_shared
,
U
[
bb
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
,
bh
,
i_v
*
block_DV
:
(
i_v
+
1
)
*
block_DV
])
for
i_k
in
T
.
Pipelined
(
T
.
ceildiv
(
DK
,
block_DK
),
num_stages
=
num_stages
):
T
.
copy
(
K
[
bb
,
bs
*
block_S
:(
bs
+
1
)
*
block_S
,
bh
,
i_k
*
block_DK
:(
i_k
+
1
)
*
block_DK
],
K_shared
)
T
.
copy
(
K
[
bb
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
,
bh
,
i_k
*
block_DK
:
(
i_k
+
1
)
*
block_DK
],
K_shared
)
for
i_s
,
i_k2
in
T
.
Parallel
(
block_S
,
block_DK
):
W_Beta_shared
[
i_s
,
i_k2
]
=
K_shared
[
i_s
,
i_k2
]
*
Beta_shared
[
i_s
]
*
G_shared
[
i_s
]
W_Beta_shared
[
i_s
,
i_k2
]
=
K_shared
[
i_s
,
i_k2
]
*
Beta_shared
[
i_s
]
*
G_shared
[
i_s
]
T
.
gemm
(
A_shared
,
W_Beta_shared
,
W_fragment
,
clear_accum
=
True
)
# First copy to smem, then copy to gmem to reduce U2RU instructions
T
.
copy
(
W_fragment
,
W_shared
)
T
.
copy
(
W_shared
,
W
[
bb
,
bs
*
block_S
:(
bs
+
1
)
*
block_S
,
bh
,
i_k
*
block_DK
:(
i_k
+
1
)
*
block_DK
])
T
.
copy
(
W_shared
,
W
[
bb
,
bs
*
block_S
:
(
bs
+
1
)
*
block_S
,
bh
,
i_k
*
block_DK
:
(
i_k
+
1
)
*
block_DK
])
return
kernel
...
...
@@ -159,15 +153,8 @@ def run_test(
    num_stages,
):
    K, V, Beta, G, A = prepare_input(
        B, S, H, DK, DV, chunk_size, getattr(torch, input_dtype), getattr(torch, output_dtype), gate_dtype=getattr(torch, gate_dtype)
    )
    W_ref, U_ref = prepare_output(B, S, H, DK, DV, getattr(torch, output_dtype))
    W_tilelang, U_tilelang = prepare_output(B, S, H, DK, DV, getattr(torch, output_dtype))
...
@@ -191,7 +178,8 @@ def run_test(
        block_DK=block_DK,
        block_DV=block_DV,
        threads=threads,
        num_stages=num_stages,
    )
    print(kernel.get_kernel_source())
    W_tilelang, U_tilelang = kernel(K, V, Beta, G, A)
...
@@ -224,7 +212,8 @@ def main():
        block_DK=64,
        block_DV=32,
        threads=128,
        num_stages=3,
    )


if __name__ == "__main__":
...
examples/gdn/example_wy_fast_bwd_split.py
View file @
29051439
...
...
@@ -10,6 +10,7 @@ import tilelang.language as T
# sys.path.insert(0, "/home/tzj/flash-linear-attention")
try:
    import fla

    print(fla.__file__)
    from fla.ops.gated_delta_rule.wy_fast import bwd_prepare_wy_repr
except ImportError:
...
@@ -93,10 +94,8 @@ def prepare_output(
@tilelang.jit(
    out_idx=[-5, -4, -3, -2, -1],
    pass_configs={
        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
    },
)
def tilelang_wy_fast_bwd(
    # task config
    B,
...
@@ -187,7 +186,7 @@ def tilelang_wy_fast_bwd(
            T.clear(dbeta_fragment_v)
            T.clear(dg_fragment)
            T.copy(A[bb, bs * block_S : (bs + 1) * block_S, bh, :], A_shared)
            for i_s in T.Parallel(block_S):
                Beta_shared[i_s] = Beta[bb, bs * block_S + i_s, bh]
                G_shared[i_s] = G[bb, bs * block_S + i_s, bh]
...
@@ -195,51 +194,37 @@ def tilelang_wy_fast_bwd(
            # Update dk
            for i_k in T.Pipelined(T.ceildiv(DK, block_DK), num_stages=num_stages):
                T.copy(K[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK], K_shared)
                for i_s, i_k2 in T.Parallel(block_S, block_DK):
                    K_shared_beta_g[i_s, i_k2] = K_shared[i_s, i_k2] * Beta_shared[i_s] * G_shared_exp[i_s]
                T.copy(dw[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK], dw_shared)
                T.gemm(dw_shared, K_shared_beta_g, dA_fragment, transpose_B=True)
                T.gemm(A_shared, dw_shared, dk_fragment_beta_g, clear_accum=True, transpose_A=True)
                for i_s, i_k2 in T.Parallel(block_S, block_DK):
                    dk_fragment[i_s, i_k2] = dk_fragment_beta_g[i_s, i_k2] * Beta_shared[i_s] * G_shared_exp[i_s]
                # for i_s, i_k2 in T.Parallel(block_S, block_DK):
                #     dbeta_fragment[i_s] = dbeta_fragment[i_s] + dk_fragment_beta_g[i_s, i_k2] * K_shared[i_s, i_k2] * G_shared_exp[i_s]
                for i_s, i_k2 in T.Parallel(block_S, block_DK):
                    dbeta_fragment_reduce_tmpk[i_s, i_k2] = dk_fragment_beta_g[i_s, i_k2] * K_shared[i_s, i_k2] * G_shared_exp[i_s]
                T.reduce_sum(dbeta_fragment_reduce_tmpk, dbeta_fragment_k, dim=1, clear=False)
                # for i_s, i_k2 in T.Parallel(block_S, block_DK):
                #     dg_fragment[i_s] = dg_fragment[i_s] + dk_fragment_beta_g[i_s, i_k2] * K_shared[i_s, i_k2] * G_shared_exp[i_s] * Beta_shared[i_s]
                for i_s, i_k2 in T.Parallel(block_S, block_DK):
                    dg_fragment_reduce_tmp[i_s, i_k2] = (
                        dk_fragment_beta_g[i_s, i_k2] * K_shared[i_s, i_k2] * G_shared_exp[i_s] * Beta_shared[i_s]
                    )
                T.reduce_sum(dg_fragment_reduce_tmp, dg_fragment, dim=1, clear=False)
                # correct dk
                T.copy(dk_fragment, dk[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK])
            # Update dv
            for i_v in T.Pipelined(T.ceildiv(DV, block_DV), num_stages=num_stages):
                T.copy(V[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV], V_shared)
                for i_s, i_v2 in T.Parallel(block_S, block_DV):
                    V_shared_beta[i_s, i_v2] = V_shared[i_s, i_v2] * Beta_shared[i_s]
                T.copy(du[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV], du_shared)
                T.gemm(du_shared, V_shared_beta, dA_fragment, transpose_B=True)
                T.gemm(A_shared, du_shared, dv_fragment_beta, clear_accum=True, transpose_A=True)
                for i_s, i_v2 in T.Parallel(block_S, block_DV):
...
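The active code in this hunk replaces the commented-out scalar accumulation with a two-step pattern: fill a [block_S, block_DK] temporary elementwise, then let T.reduce_sum collapse it along dim=1. A rough PyTorch analogue of those two row reductions, with illustrative names and block sizes:

import torch

BS, DK_blk = 64, 64
dk_beta_g = torch.randn(BS, DK_blk)   # stands in for dk_fragment_beta_g
k_blk = torch.randn(BS, DK_blk)       # stands in for K_shared
beta = torch.randn(BS)                # stands in for Beta_shared
g_exp = torch.rand(BS).exp()          # stands in for G_shared_exp

# dbeta_fragment_k  <-  reduce_sum(dk_beta_g * K * exp(g), dim=1)
dbeta_k = (dk_beta_g * k_blk * g_exp[:, None]).sum(dim=1)
# dg_fragment       <-  reduce_sum(dk_beta_g * K * exp(g) * beta, dim=1)
dg = (dk_beta_g * k_blk * g_exp[:, None] * beta[:, None]).sum(dim=1)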
@@ -247,30 +232,22 @@ def tilelang_wy_fast_bwd(
                # for i_s, i_v2 in T.Parallel(block_S, block_DV):
                #     dbeta_fragment[i_s] = dbeta_fragment[i_s] + dv_fragment_beta[i_s, i_v2] * V_shared[i_s, i_v2]
                for i_s, i_v2 in T.Parallel(block_S, block_DV):
                    dbeta_fragment_reduce_tmpv[i_s, i_v2] = dv_fragment_beta[i_s, i_v2] * V_shared[i_s, i_v2]
                T.reduce_sum(dbeta_fragment_reduce_tmpv, dbeta_fragment_v, dim=1, clear=False)
                T.copy(dv_fragment, dv[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV])
            # Temporary store dbeta, dg and dA
            for i_s in T.Parallel(block_S):
                dbeta[bb, bs * block_S + i_s, bh] = dbeta_fragment_k[i_s] + dbeta_fragment_v[i_s]
                dg[bb, bs * block_S + i_s, bh] = dg_fragment[i_s]
            # correct dA
            T.copy(dA_fragment, dA[bb, bs * block_S : (bs + 1) * block_S, bh, :])
    return kernel


@tilelang.jit(
    pass_configs={
        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
    }
)
def tilelang_wy_fast_bwd_split(
    # task config
    B,
...
@@ -350,7 +327,7 @@ def tilelang_wy_fast_bwd_split(
            T.clear(dA_A_fragment_1)
            T.clear(dA_A_fragment_2)
            T.copy(A[bb, bs * block_S : (bs + 1) * block_S, bh, :], A_shared)
            for i_s in T.Parallel(block_S):
                Beta_shared[i_s] = Beta[bb, bs * block_S + i_s, bh]
                G_shared[i_s] = G[bb, bs * block_S + i_s, bh]
...
@@ -361,7 +338,7 @@ def tilelang_wy_fast_bwd_split(
            # for i_s in T.Parallel(block_S):
            #     dbeta_fragment[i_s] = dbeta[bb, bs * block_S + i_s, bh]
            #     dg_fragment[i_s] = dg[bb, bs * block_S + i_s, bh]
            T.copy(dA[bb, bs * block_S : (bs + 1) * block_S, bh, :], dA_shared)
            # T.copy(dA_shared, dA[bb, bs * block_S:(bs + 1) * block_S, bh, :])
            # Update dA
...
@@ -385,8 +362,7 @@ def tilelang_wy_fast_bwd_split(
            for i_s1, i_s2 in T.Parallel(block_S, block_S):
                with T.If(G[bb, bs * block_S + i_s1, bh] - G[bb, bs * block_S + i_s2, bh] <= 0):
                    with T.Then():
                        dA_fragment[i_s1, i_s2] *= T.exp(G[bb, bs * block_S + i_s1, bh] - G[bb, bs * block_S + i_s2, bh])
                    with T.Else():
                        dA_fragment[i_s1, i_s2] = 0
            T.copy(dA_fragment, dA_shared)
...
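The T.If/T.Then/T.Else block above rescales dA by the gate difference and zeroes the entries where that difference is positive. Written densely over one chunk in PyTorch (shapes illustrative), the same masking looks like this:

import torch

BS = 64
g = torch.randn(BS)        # per-position gate for one (batch, head, chunk)
dA = torch.randn(BS, BS)   # the gradient block being masked

diff_g = g[:, None] - g[None, :]   # G[..., i_s1] - G[..., i_s2]
dA = torch.where(diff_g <= 0, dA * torch.exp(diff_g), torch.zeros_like(dA))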
@@ -397,12 +373,8 @@ def tilelang_wy_fast_bwd_split(
            # Update dk using previous dk
            T.clear(A_fragment)
            for i_k in T.Pipelined(T.ceildiv(DK, block_DK), num_stages=num_stages):
                T.copy(K[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK], K_shared)
                T.copy(dk[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK], dk_shared)
                T.copy(dk_shared, dk_fragment)
                for i_s, i_k2 in T.Parallel(block_S, block_DK):
                    K_shared_beta[i_s, i_k2] = K_shared[i_s, i_k2] * Beta_shared[i_s]
...
@@ -411,18 +383,14 @@ def tilelang_wy_fast_bwd_split(
                # for i_s, i_k2 in T.Parallel(block_S, block_DK):
                #     dbeta_fragment[i_s] = dbeta_fragment[i_s] + dk_fragment_beta[i_s, i_k2] * K_shared[i_s, i_k2]
                for i_s, i_k2 in T.Parallel(block_S, block_DK):
                    dbeta_fragment_reduce_tmpk[i_s, i_k2] = dk_fragment_beta[i_s, i_k2] * K_shared[i_s, i_k2]
                T.reduce_sum(dbeta_fragment_reduce_tmpk, dbeta_fragment_k, dim=1, clear=False)
                T.gemm(dA_shared, K_shared_beta, dk_fragment, transpose_A=True)
                for i_s, i_k2 in T.Parallel(block_S, block_DK):
                    dk_shared_beta[i_s, i_k2] = dk_fragment_beta[i_s, i_k2] * Beta_shared[i_s]
                for i_s, i_k2 in T.Parallel(block_S, block_DK):
                    dk_fragment[i_s, i_k2] = dk_fragment[i_s, i_k2] + dk_shared_beta[i_s, i_k2]
                T.copy(dk_fragment, dk[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK])
            # Update dg and dbeta
            T.copy(A_fragment, A_shared)
...
@@ -460,19 +428,25 @@ def run_test(
    threads=128,
    num_stages=0,
):
    K, V, Beta, G, A, dw, du = prepare_input(
        B, S, H, DK, DV, chunk_size,
        getattr(torch, input_dtype),
        getattr(torch, output_dtype),
        getattr(torch, accum_dtype),
        getattr(torch, gate_dtype),
        getattr(torch, state_dtype),
    )
    dk_ref, dv_ref, dbeta_ref, dg_ref = prepare_output(
        B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), getattr(torch, state_dtype)
    )
    dk_tilelang, dv_tilelang, dbeta_tilelang, dg_tilelang = prepare_output(
        B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), getattr(torch, state_dtype)
    )
    BS = chunk_size
    dA_tilelang = torch.empty(B, S, H, BS, dtype=getattr(torch, input_dtype)).cuda()
    dbeta_tilelang_k = torch.empty(B, S, H, dtype=getattr(torch, output_dtype)).cuda()
...
@@ -480,28 +454,55 @@ def run_test(
    dg_tilelang_A_negative = torch.empty(B, S, H, BS, dtype=getattr(torch, gate_dtype)).cuda()
    # ref
    dk_ref, dv_ref, dbeta_ref, dg_ref = bwd_prepare_wy_repr(K, V, G, Beta, A, dw, du, cu_seqlens=None)
    # tilelang
    kernel = tilelang_wy_fast_bwd(
        B, S, H, DK, DV,
        input_dtype, output_dtype, accum_dtype, gate_dtype, state_dtype,
        chunk_size, block_DK, block_DV, threads, num_stages,
    )
    dA_tilelang, dk_tilelang, dv_tilelang, dbeta_tilelang, dg_tilelang = kernel(K, V, Beta, G, A, dw, du)
    torch.cuda.synchronize()
    kernel_split = tilelang_wy_fast_bwd_split(
        B, S, H, DK, DV,
        input_dtype, output_dtype, accum_dtype, gate_dtype, state_dtype,
        chunk_size, block_DK, block_DV, threads, num_stages,
    )
    kernel_split(
        K, V, Beta, G, A, dw, du,
        dA_tilelang, dk_tilelang, dv_tilelang, dbeta_tilelang_k, dg_tilelang_A_positive, dg_tilelang_A_negative,
    )
    torch.cuda.synchronize()
    dbeta_tilelang = dbeta_tilelang_k + dbeta_tilelang
    dg_tilelang = dg_tilelang + dg_tilelang_A_positive.sum(dim=-1) - dg_tilelang_A_negative.sum(dim=-1)
    from test_utils import assert_similar

    assert_similar(dk_ref, dk_tilelang, eps=1e-5, name="dk", raise_assert=False)
    assert_similar(dv_ref, dv_tilelang, eps=1e-5, name="dv", raise_assert=False)
    assert_similar(dbeta_ref, dbeta_tilelang, eps=1e-5, name="dbeta", raise_assert=False)
...
examples/gdn/test_example_gdn_compilation.py
View file @
29051439
...
...
@@ -25,16 +25,10 @@ num_stages = 1
def test_example_wy_fast_compilation():
    from example_wy_fast import tilelang_recompute_w_u_fwd, prepare_input

    K, V, Beta, G, A = prepare_input(
        B, S, H, DK, DV, chunk_size, getattr(torch, input_dtype), getattr(torch, output_dtype), gate_dtype=getattr(torch, gate_dtype)
    )
    # tilelang
    block_S = chunk_size
    kernel = tilelang_recompute_w_u_fwd(
...
@@ -52,22 +46,31 @@ def test_example_wy_fast_compilation():
        block_DK=block_DK,
        block_DV=block_DV,
        threads=threads,
        num_stages=num_stages,
    )
    print(kernel.get_kernel_source())
    W_tilelang, U_tilelang = kernel(K, V, Beta, G, A)


def test_example_wy_fast_bwd_split_compilation():
    from example_wy_fast_bwd_split import tilelang_wy_fast_bwd, tilelang_wy_fast_bwd_split, prepare_input, prepare_output

    K, V, Beta, G, A, dw, du = prepare_input(
        B, S, H, DK, DV, chunk_size,
        getattr(torch, input_dtype),
        getattr(torch, output_dtype),
        getattr(torch, accum_dtype),
        getattr(torch, gate_dtype),
        getattr(torch, state_dtype),
    )
    dk_tilelang, dv_tilelang, dbeta_tilelang, dg_tilelang = prepare_output(
        B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), getattr(torch, state_dtype)
    )
    BS = chunk_size
    dA_tilelang = torch.empty(B, S, H, BS, dtype=getattr(torch, input_dtype)).cuda()
    dbeta_tilelang_k = torch.empty(B, S, H, dtype=getattr(torch, output_dtype)).cuda()
...
@@ -75,68 +78,146 @@ def test_example_wy_fast_bwd_split_compilation():
    dg_tilelang_A_negative = torch.empty(B, S, H, BS, dtype=getattr(torch, gate_dtype)).cuda()
    # tilelang
    kernel = tilelang_wy_fast_bwd(
        B, S, H, DK, DV,
        input_dtype, output_dtype, accum_dtype, gate_dtype, state_dtype,
        chunk_size, block_DK, block_DV, threads, num_stages,
    )
    dA_tilelang, dk_tilelang, dv_tilelang, dbeta_tilelang, dg_tilelang = kernel(K, V, Beta, G, A, dw, du)
    torch.cuda.synchronize()
    kernel_split = tilelang_wy_fast_bwd_split(
        B, S, H, DK, DV,
        input_dtype, output_dtype, accum_dtype, gate_dtype, state_dtype,
        chunk_size, block_DK, block_DV, threads, num_stages,
    )
    kernel_split(
        K, V, Beta, G, A, dw, du,
        dA_tilelang, dk_tilelang, dv_tilelang, dbeta_tilelang_k, dg_tilelang_A_positive, dg_tilelang_A_negative,
    )
    torch.cuda.synchronize()
    dbeta_tilelang = dbeta_tilelang_k + dbeta_tilelang
    dg_tilelang = dg_tilelang + dg_tilelang_A_positive.sum(dim=-1) - dg_tilelang_A_negative.sum(dim=-1)


def test_example_chunk_o_compilation():
    from example_chunk_o import tilelang_chunk_fwd_o, prepare_input

    Q, K, V, HIDDEN, G = prepare_input(
        B, S, H, DK, DV, chunk_size,
        getattr(torch, input_dtype),
        getattr(torch, output_dtype),
        getattr(torch, accum_dtype),
        getattr(torch, gate_dtype),
    )
    scale = 1.0 / DK**0.5
    block_S = chunk_size
    kernel = tilelang_chunk_fwd_o(
        B, S, H, DK, DV,
        input_dtype, output_dtype, accum_dtype, gate_dtype,
        chunk_size, scale, use_g, block_S, block_DK, block_DV, threads, num_stages,
    )
    O_tilelang = kernel(Q, K, V, HIDDEN, G)  # noqa: F841


def test_example_chunk_o_bwd_compilation():
    from example_chunk_o_bwd import tilelang_chunk_o_bwd_dqkwg, prepare_input

    Q, K, V, h, G, dO, dh, dv, W = prepare_input(
        B, S, H, DK, DV, chunk_size,
        getattr(torch, input_dtype),
        getattr(torch, output_dtype),
        getattr(torch, accum_dtype),
        getattr(torch, gate_dtype),
        getattr(torch, state_dtype),
    )
    kernel = tilelang_chunk_o_bwd_dqkwg(
        B, S, H, DK, DV,
        input_dtype, output_dtype, accum_dtype, gate_dtype, state_dtype,
        chunk_size, 1.0, use_g, True, block_DK, block_DV, threads, num_stages,
    )
    dq_tilelang, dk_tilelang, dw_tilelang, dg_tilelang = kernel(Q, K, V, h, G, dO, dh, dv, W)  # noqa: F841
    if use_g:
        dg_tilelang = dg_tilelang.sum(dim=0)


def test_example_chunk_scaled_dot_kkt_compilation():
    from example_chunk_scaled_dot_kkt import tilelang_chunk_scaled_dot_kkt_fwd, prepare_input

    K, Beta, G = prepare_input(B, S, H, DK, getattr(torch, input_dtype), getattr(torch, output_dtype), getattr(torch, accum_dtype))
    block_S = chunk_size
    kernel = tilelang_chunk_scaled_dot_kkt_fwd(
        B, S, H, DK, chunk_size, input_dtype, output_dtype, accum_dtype, use_g, block_S, block_DK, threads, num_stages
    )
    A_tilelang = kernel(K, Beta, G)  # noqa: F841


def test_example_cumsum_compilation():
    from example_cumsum import tilelang_chunk_local_cumsum_scalar, prepare_cumsum_input, prepare_cumsum_output

    G = prepare_cumsum_input(B, S, H, getattr(torch, gate_dtype))
    G_new_tilelang = prepare_cumsum_output(B, S, H, getattr(torch, gate_dtype))
    block_S = chunk_size
...
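These compilation checks are plain test functions, so they can be driven through pytest like any other module. One way to run just a subset of them programmatically, assuming pytest is installed and the working directory is examples/gdn/:

import pytest

raise SystemExit(pytest.main(["-x", "test_example_gdn_compilation.py", "-k", "chunk_o or chunk_scaled_dot_kkt"]))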
@@ -158,33 +239,79 @@ def test_example_cumsum_compilation():
def test_example_chunk_delta_h_compilation():
    from example_chunk_delta_h import tilelang_chunk_gated_delta_rule_fwd_h, prepare_input

    K, W, U, G, initial_state = prepare_input(
        B, S, H, DK, DV, chunk_size,
        getattr(torch, input_dtype),
        getattr(torch, output_dtype),
        getattr(torch, accum_dtype),
        getattr(torch, gate_dtype),
    )
    kernel = tilelang_chunk_gated_delta_rule_fwd_h(
        B, S, H, DK, DV,
        input_dtype, output_dtype, accum_dtype, gate_dtype, state_dtype,
        chunk_size, use_g, use_initial_state, store_final_state, save_new_value,
        block_DK, block_DV, threads, num_stages,
    )
    h_tilelang, final_state_tilelang, V_new_tilelang = kernel(K, W, U, G, initial_state)  # noqa: F841


def test_example_chunk_delta_bwd_compilation():
    from example_chunk_delta_bwd import tilelang_chunk_gated_delta_rule_bwd_dhu, prepare_input

    Q, K, W, G, h0, dht, dO, dv = prepare_input(
        B, S, H, DK, DV, chunk_size,
        getattr(torch, input_dtype),
        getattr(torch, output_dtype),
        getattr(torch, accum_dtype),
        getattr(torch, gate_dtype),
        getattr(torch, state_dtype),
    )
    kernel = tilelang_chunk_gated_delta_rule_bwd_dhu(
        B, S, H, DK, DV,
        input_dtype, output_dtype, accum_dtype, gate_dtype, state_dtype,
        chunk_size, 1.0, use_g, use_initial_state, use_final_state_gradient,
        block_DV, threads, num_stages,
    )
    dh_tilelang, dh0_tilelang, dv2_tilelang = kernel(Q, K, W, G, h0, dht, dO, dv)  # noqa: F841
...
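A side note on the "# noqa: F841" markers that survive the reformat: F841 is the ruff/flake8 rule for a local variable that is assigned but never used, which is exactly what these compilation smoke tests do on purpose, since they only care that the kernel builds and launches. A minimal, self-contained illustration:

def smoke_test_only_cares_about_side_effects():
    # Assigned but intentionally unused; the suppression keeps ruff quiet.
    result = sum(range(10))  # noqa: F841


smoke_test_only_cares_about_side_effects()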
examples/gdn/test_utils.py
View file @
29051439
...
...
@@ -9,7 +9,7 @@ def calc_sim(x, y, name="tensor"):
    x, y = x.data.double(), y.data.double()
    denominator = (x * x + y * y).sum()
    if denominator == 0:
        print_red_warning(f"{name} all zero")
        return 1
    sim = 2 * (x * y).sum() / denominator
    return sim
...
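calc_sim computes 2 * <x, y> / (|x|^2 + |y|^2), which equals 1 exactly when x and y coincide and drops toward 0 as they diverge. A quick numeric check with illustrative values:

import torch

x = torch.tensor([1.0, 2.0, 3.0]).double()
y = x.clone()
sim_equal = 2 * (x * y).sum() / (x * x + y * y).sum()   # 1.0 when x == y

y = x + 0.01
sim_close = 2 * (x * y).sum() / (x * x + y * y).sum()   # slightly below 1.0
print(float(sim_equal), float(sim_close))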
@@ -19,21 +19,19 @@ def assert_similar(x, y, eps=1e-8, name="tensor", data="", raise_assert=True):
    x_mask = torch.isfinite(x)
    y_mask = torch.isfinite(y)
    if not torch.all(x_mask == y_mask):
        print_red_warning(f"{name} Error: isfinite mask mismatch")
        if raise_assert:
            raise AssertionError
    if not torch.isclose(x.masked_fill(x_mask, 0), y.masked_fill(y_mask, 0), rtol=0, atol=0, equal_nan=True).all():
        print_red_warning(f"{name} Error: nonfinite value mismatch")
        if raise_assert:
            raise AssertionError
    x = x.masked_fill(~x_mask, 0)
    y = y.masked_fill(~y_mask, 0)
    sim = calc_sim(x, y, name)
    diff = 1.0 - sim
    if not (0 <= diff <= eps):
        print_red_warning(f"{name} Error: {diff}")
        if raise_assert:
            raise AssertionError
    else:
...
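A minimal usage sketch of assert_similar as the GDN examples call it above; the tensors here are synthetic and the import assumes examples/gdn/ is on the path:

import torch

from test_utils import assert_similar

dk_ref = torch.randn(1, 128, 4, 64)
dk_tilelang = dk_ref + 1e-4 * torch.randn_like(dk_ref)

# Warns (and, with raise_assert=True, raises) when 1 - sim exceeds eps.
assert_similar(dk_ref, dk_tilelang, eps=1e-5, name="dk", raise_assert=False)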