OpenDAS / tilelang · Commits · bbbf4207

Commit bbbf4207 (unverified), authored Nov 14, 2025 by guchaoyang; committed by GitHub on Nov 14, 2025.

    Merge branch 'main' into dcu

Parents: 8f4628e0, 5eb30a4f
Changes: 286. Showing 20 changed files with 2923 additions and 69 deletions (+2923, −69).
Changed files shown in this view:

  examples/flash_decoding/example_gqa_decode.py                                 +5    −4
  examples/flash_decoding/example_gqa_decode_varlen_logits.py                   +960  −0
  examples/flash_decoding/example_mha_inference.py                              +1    −3
  examples/flash_decoding/test_example_flash_decoding.py                        +1    −1
  examples/gdn/example_chunk_o_bwd.py                                           +3    −4
  examples/gemm/README.md                                                       +37   −34
  examples/linear_attention/example_linear_attn_fwd.py                          +1    −1
  examples/linear_attention/example_retention_fwd.py                            +0    −7
  examples/norm/rms_norm.py                                                     +2    −2
  examples/norm/test_rms_norm.py                                                +2    −2
  examples/plot_layout/fragment_mfma_load_a.py                                  +133  −0
  examples/warp_specialize/example_warp_specialize_gemm_barrierpipe_stage2.py   +5    −0
  format.sh                                                                     +18   −8
  maint/gemm_v2/correctness_evaluation.py                                       +735  −0
  maint/gemm_v2/correctness_evaluation_sm70.py                                  +350  −0
  maint/gemm_v2/correctness_evaluation_tcgen05.py                               +226  −0
  maint/gemm_v2/latency.py                                                      +99   −0
  maint/gemm_v2/latency_gemm.py                                                 +99   −0
  maint/gemm_v2/latency_mha_fwd_bhsd.py                                         +246  −0
  maint/scripts/docker_build_all.sh                                             +0    −3
examples/flash_decoding/example_gqa_decode.py  (view file @ bbbf4207)

@@ -40,9 +40,9 @@ def get_heuristic_config() -> Tuple[Dict, int]:
     sm_version = sm_major * 10 + sm_minor
     print(f"CUDA device capability: {sm_version}")
     if sm_version == 89:
-        cfg = dict(block_N=128, block_H=64, num_split=16, num_stages=0, threads=128)
+        cfg = dict(block_N=128, block_H=64, num_split=1, num_stages=0, threads=128)
     else:
-        cfg = dict(block_N=128, block_H=64, num_split=16, num_stages=2, threads=128)
+        cfg = dict(block_N=128, block_H=64, num_split=1, num_stages=2, threads=128)
     return cfg, sm_version

@@ -459,8 +459,9 @@ def main(batch: int = 1,
     k = torch.randn(batch, kv_seqlen, groups, dim, device="cuda", dtype=torch.float16)
     v = torch.randn(batch, kv_seqlen, groups, dim, device="cuda", dtype=torch.float16)
     mask = torch.randint(0, 2, (batch, kv_seqlen, groups), device="cuda", dtype=torch.uint8)
-    glse = torch.empty(batch, heads, 16, device="cuda", dtype=torch.float16)
-    Output_partial = torch.empty(batch, heads, 16, dim, device="cuda", dtype=torch.float16)
+    split = config["num_split"]
+    glse = torch.empty(batch, heads, split, device="cuda", dtype=torch.float16)
+    Output_partial = torch.empty(batch, heads, split, dim, device="cuda", dtype=torch.float16)
     o = kernel(q, k, v, mask, glse, Output_partial)
     o_ref = ref_program(q, k, v, mask, glse, Output_partial)
     o_ref_split = ref_split_program(q, k, v, mask, glse, Output_partial)
examples/flash_decoding/example_gqa_decode_varlen_logits.py  (new file, 0 → 100644, view file @ bbbf4207)

import torch
import triton
import triton.language as tl
import math
import argparse
import tilelang
import tilelang.language as T
from tilelang.autotuner import autotune

torch.manual_seed(0)
tilelang.disable_cache()


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
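Editorial note: the docstring above claims equivalence with torch.repeat_interleave along the KV-head dimension. A minimal sanity check of that claim (illustration only, not part of the commit; sizes are made up):

# Hypothetical check: repeat_kv matches torch.repeat_interleave along dim=1.
import torch

x = torch.randn(2, 4, 8, 16)  # (batch, kv_heads, seqlen, head_dim)
n_rep = 3
expanded = x[:, :, None, :, :].expand(2, 4, n_rep, 8, 16).reshape(2, 4 * n_rep, 8, 16)
assert torch.equal(expanded, torch.repeat_interleave(x, repeats=n_rep, dim=1))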
@triton.jit
def _fwd_inner(
    q,
    k_ptrs,
    v_ptrs,
    s_ptrs,
    m_i,
    l_i,
    acc,
    offs_h,
    mask_h,
    offs_n,
    seqlen,
    softmax_scale,
    lo,
    hi,
    stride_kt,
    stride_vt,
    stride_sh,
    stride_sn,
    BLOCK_N: tl.constexpr,
):
    """Inner loop computation for attention"""
    for blk_idx in tl.range(lo, hi):
        start_n = blk_idx * BLOCK_N
        k = tl.load(k_ptrs + start_n * stride_kt, mask=offs_n[None, :] + start_n < seqlen)
        v = tl.load(v_ptrs + start_n * stride_vt, mask=offs_n[:, None] + start_n < seqlen)
        qk = tl.dot(q, k)
        qk *= softmax_scale
        qk += tl.where(offs_n[None, :] + start_n < seqlen, 0, -1.0e9)
        row_max = tl.max(qk, 1)
        tl.store(s_ptrs + offs_h * stride_sh + blk_idx * stride_sn, row_max, mask=mask_h)
        m_ij = tl.maximum(m_i, row_max)
        qk -= m_ij[:, None]
        p = tl.math.exp(qk)
        l_ij = tl.sum(p, 1)
        alpha = tl.math.exp(m_i - m_ij)
        l_i = l_i * alpha + l_ij
        m_i = m_ij
        acc *= alpha[:, None]
        p = p.to(v.type.element_ty)
        acc += tl.dot(p, v)
    return m_i, l_i, acc
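Editorial note: the per-block update in _fwd_inner is the standard online-softmax recurrence. A plain PyTorch sketch of the same bookkeeping (my own illustration, assuming a single query row and no masking, not part of the commit):

# Hypothetical reference for the online-softmax update used in _fwd_inner.
import torch

def online_softmax_attention(q, K_blocks, V_blocks, scale):
    m = torch.tensor(float("-inf"))      # running row max (m_i)
    l = torch.tensor(0.0)                # running normalizer (l_i)
    acc = torch.zeros(V_blocks[0].shape[-1])
    for K, V in zip(K_blocks, V_blocks):
        qk = (q @ K.T) * scale           # scores for this key/value block
        m_new = torch.maximum(m, qk.max())
        alpha = torch.exp(m - m_new)     # rescale previously accumulated results
        p = torch.exp(qk - m_new)
        l = l * alpha + p.sum()
        acc = acc * alpha + p @ V
        m = m_new
    return acc / l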
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages) \
        for num_warps in [4, 8] \
        for num_stages in [2, 4] \
    ],
    key=['gqa_group_size', 'BLOCK_N', 'BLOCK_D', 'BLOCK_H'],
)
@triton.jit
def _fwd_kernel_varlen(
    Q,  # [token_q = b, h_q, dim]
    K,  # [token_k, h_kv, dim]
    V,
    O,
    S,
    s_aux,
    softmax_scale,
    cu_seqlens_k,
    stride_qt,
    stride_qh,
    stride_qd,
    stride_kt,
    stride_kh,
    stride_kd,
    stride_vt,
    stride_vh,
    stride_vd,
    stride_ot,
    stride_oh,
    stride_od,
    stride_sb,
    stride_sh,
    stride_sn,  # bmask shape [b, q_h, seq/BLOCK_N]
    gqa_group_size: tl.constexpr,
    BLOCK_H: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    off_z = tl.program_id(0)
    off_h_for_kv = tl.program_id(1)
    off_h_q = off_h_for_kv * gqa_group_size

    cu_k_start = tl.load(cu_seqlens_k + off_z)
    cu_k_end = tl.load(cu_seqlens_k + off_z + 1)
    seqlen_k = cu_k_end - cu_k_start

    offs_h = tl.arange(0, BLOCK_H)
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_D)

    Q_ptrs = Q + off_z * stride_qt + off_h_q * stride_qh
    K_ptrs = K + (cu_k_start) * stride_kt + off_h_for_kv * stride_kh
    V_ptrs = V + (cu_k_start) * stride_vt + off_h_for_kv * stride_vh
    O_ptrs = O + off_z * stride_ot + off_h_q * stride_oh
    S_ptrs = S + off_z * stride_sb + off_h_q * stride_sh

    mask_h = offs_h < gqa_group_size
    q = tl.load(Q_ptrs + offs_d[None, :] * stride_qd + offs_h[:, None] * stride_qh, mask=mask_h[:, None])

    if s_aux is not None:
        sink = tl.load(s_aux + off_h_q + offs_h, mask=mask_h).to(tl.float32)
        l_i = tl.zeros([BLOCK_H], dtype=tl.float32)
        m_i = tl.zeros([BLOCK_H], dtype=tl.float32) + sink
    else:
        l_i = tl.full([BLOCK_H], 1.0, dtype=tl.float32)
        m_i = tl.full([BLOCK_H], float("-inf"), dtype=tl.float32)
    acc = tl.zeros([BLOCK_H, BLOCK_D], dtype=tl.float32)

    k_ptrs = K_ptrs + offs_n[None, :] * stride_kt + offs_d[:, None] * stride_kd
    v_ptrs = V_ptrs + offs_n[:, None] * stride_vt + offs_d[None, :] * stride_vd

    lo, hi = 0, tl.cdiv(seqlen_k, BLOCK_N)
    m_i, l_i, acc = _fwd_inner(
        q,
        k_ptrs,
        v_ptrs,
        S_ptrs,
        m_i,
        l_i,
        acc,
        offs_h,
        mask_h,
        offs_n,
        seqlen_k,
        softmax_scale,
        lo,
        hi,
        stride_kt,
        stride_vt,
        stride_sh,
        stride_sn,
        BLOCK_N,
    )

    if s_aux is not None:
        sink = tl.math.exp(sink - m_i)
        l_i = l_i + sink
        acc = acc / l_i[:, None]
    else:
        l_recip = 1 / l_i[:, None]
        acc = acc * l_recip

    for blk_idx in tl.range(lo, hi):
        s = tl.load(S_ptrs + offs_h * stride_sh + blk_idx * stride_sn, mask=mask_h)
        s = tl.exp(s - m_i) / l_i
        tl.store(S_ptrs + offs_h * stride_sh + blk_idx * stride_sn, s, mask=mask_h)

    acc = acc.to(O.dtype.element_ty)
    tl.store(O_ptrs + offs_h[:, None] * stride_oh + offs_d[None, :] * stride_od, acc, mask=mask_h[:, None])
def get_configs():
    import itertools
    block_N = [64, 128]
    block_H = [64]
    num_split = [1]
    num_stages = [1, 2, 3]
    threads = [128]
    _configs = list(itertools.product(block_N, block_H, num_split, num_stages, threads))

    configs = [{
        'block_N': c[0],
        'block_H': c[1],
        'num_split': c[2],
        'num_stages': c[3],
        'threads': c[4]
    } for c in _configs]
    return configs


@autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(out_idx=[-2, -1], debug_root_path="./examples/flash_decoding")
def flashattn(batch,
              heads,
              k_heads,
              max_seqlen_kv,
              total_seqlen_k,
              dim,
              has_sink,
              block_N=128,
              block_H=64,
              num_split=1,
              num_stages=1,
              threads=128):
    scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
    shape_q = [batch, heads, dim]
    shape_k = [total_seqlen_k, k_heads, dim]
    shape_v = [total_seqlen_k, k_heads, dim]
    shape_o = [batch, heads, dim]
    shape_s = [batch, heads, math.ceil(max_seqlen_kv / block_N)]
    dtype = "float16"
    accum_dtype = "float"
    kv_group_num = heads // k_heads
    valid_block_H = min(block_H, kv_group_num)

    # TODO: check if max_seqlen_kv is correct for varlen case
    @T.macro
    def flash_attn(
            Q: T.Tensor(shape_q, dtype),
            K: T.Tensor(shape_k, dtype),
            V: T.Tensor(shape_v, dtype),
            cu_seqlens_k: T.Tensor([batch + 1], "int32"),
            s_aux: T.Tensor([heads], "float32"),
            Output: T.Tensor([batch, heads, dim], dtype),
            S: T.Tensor(shape_s, dtype),
    ):
        with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz):
            Q_shared = T.alloc_shared([block_H, dim], dtype)
            K_shared = T.alloc_shared([block_N, dim], dtype)
            V_shared = T.alloc_shared([block_N, dim], dtype)
            O_shared = T.alloc_shared([valid_block_H, dim], dtype)
            acc_s = T.alloc_fragment([block_H, block_N], accum_dtype)
            acc_s_cast = T.alloc_fragment([block_H, block_N], dtype)
            acc_o = T.alloc_fragment([block_H, dim], accum_dtype)
            scores_max = T.alloc_fragment([block_H], accum_dtype)
            scores_max_prev = T.alloc_fragment([block_H], accum_dtype)
            scores_scale = T.alloc_fragment([block_H], accum_dtype)
            scores_sum = T.alloc_fragment([block_H], accum_dtype)
            logsum = T.alloc_fragment([block_H], accum_dtype)
            S_shared = T.alloc_shared([block_H, math.ceil(max_seqlen_kv / block_N)], dtype)
            # S_fragment = T.alloc_fragment([block_H, math.ceil(max_seqlen_kv / block_N)], accum_dtype)
            s_aux_shared = T.alloc_shared([block_H], "float32")

            T.annotate_layout({
                # Q_shared: tilelang.layout.make_swizzled_layout(Q_shared),
                # K_shared: tilelang.layout.make_swizzled_layout(K_shared),
                # V_shared: tilelang.layout.make_swizzled_layout(V_shared),
                # O_shared: tilelang.layout.make_swizzled_layout(O_shared),
                # S_shared: tilelang.layout.make_swizzled_layout(S_shared),
            })

            bid = bx
            hid = by
            cur_kv_head = hid // (kv_group_num // valid_block_H)

            cur_start_k = cu_seqlens_k[bid]
            cur_end_k = cu_seqlens_k[bid + 1]
            cur_seqlen_k = cur_end_k - cur_start_k

            T.copy(Q[bid, hid * valid_block_H:hid * valid_block_H + block_H, :], Q_shared)
            T.fill(acc_o, 0)
            T.fill(logsum, 0)
            T.fill(scores_max, -T.infinity(accum_dtype))

            # loop_range = T.ceildiv((seqlen_kv // num_split), block_N)
            loop_range = T.ceildiv((cur_seqlen_k // num_split), block_N)

            for k in T.Pipelined(loop_range, num_stages=num_stages):
                T.copy(K[cur_start_k + k * block_N:cur_start_k + (k + 1) * block_N, cur_kv_head, :], K_shared)
                T.clear(acc_s)
                T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                for i, j in T.Parallel(block_H, block_N):
                    # acc_s[i, j] = T.if_then_else(mask_local[j] != 0 and k * block_N + j < cur_seqlen_k, acc_s[i, j],
                    #                              -T.infinity(accum_dtype))
                    acc_s[i, j] = T.if_then_else(k * block_N + j < cur_seqlen_k, acc_s[i, j],
                                                 -T.infinity(accum_dtype))
                T.copy(scores_max, scores_max_prev)
                T.fill(scores_max, -T.infinity(accum_dtype))
                T.reduce_max(acc_s, scores_max, dim=1, clear=False)
                # scores_max_prev is m_i
                # scores_max is row_max->m_ij in triton
                T.copy(scores_max, S_shared[:, k])
                # scores_scale is alpha in triton
                for i in T.Parallel(block_H):
                    scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
                for i, j in T.Parallel(block_H, block_N):
                    acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
                T.reduce_sum(acc_s, scores_sum, dim=1)
                # scores_sum is l_ij in triton
                # logsum is l_i in triton
                for i in T.Parallel(block_H):
                    logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
                T.copy(acc_s, acc_s_cast)
                for i, j in T.Parallel(block_H, dim):
                    acc_o[i, j] *= scores_scale[i]
                T.copy(V[cur_start_k + k * block_N:cur_start_k + (k + 1) * block_N, cur_kv_head, :], V_shared)
                T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)

            if has_sink:
                T.copy(s_aux[hid * valid_block_H:hid * valid_block_H + block_H], s_aux_shared)
                for i in T.Parallel(block_H):
                    logsum[i] += s_aux_shared[i]
            for i, j in T.Parallel(block_H, dim):
                acc_o[i, j] /= logsum[i]
            for h, k in T.Parallel(block_H, math.ceil(max_seqlen_kv / block_N)):
                S_shared[h, k] = T.exp2((S_shared[h, k] - scores_max[h]) * scale) / logsum[h]
            # T.copy(S_shared, S_fragment)
            # for h, k in T.Parallel(block_H, math.ceil(max_seqlen_kv / block_N)):
            #     S_fragment[h, k] = T.exp2((S_fragment[h, k] - scores_max[h]) * scale) / logsum[h]
            for i in T.Parallel(block_H):
                logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
            T.copy(acc_o[:valid_block_H, :], O_shared)
            T.copy(O_shared, Output[bid, hid * valid_block_H:(hid + 1) * valid_block_H, :])
            # T.copy(S_fragment, S_shared)
            T.copy(S_shared[:valid_block_H, :], S[bid, hid * valid_block_H:(hid + 1) * valid_block_H, :])

    @T.prim_func
    def flashattn_gqa_decode_no_split(
            Q: T.Tensor(shape_q, dtype),
            K: T.Tensor(shape_k, dtype),
            V: T.Tensor(shape_v, dtype),
            cu_seqlens_k: T.Tensor([batch + 1], "int32"),
            s_aux: T.Tensor([heads], "float32"),
            Output: T.Tensor(shape_o, dtype),
            S: T.Tensor(shape_s, dtype),
    ):
        flash_attn(Q, K, V, cu_seqlens_k, s_aux, Output, S)

    # TODO: split version
    return flashattn_gqa_decode_no_split
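Editorial note: the tilelang kernel above folds log2(e) into `scale` and works in base-2 (`T.exp2` / `T.log2`) rather than natural exp/log. A quick numerical check of that identity (my own sketch, not part of the commit):

# Hypothetical check of the exp2 trick: exp(x * s) == 2 ** (x * s * log2(e)).
import math
import torch

x = torch.randn(8)
s = 1.0 / math.sqrt(128)        # softmax scale for head_dim = 128
scale = s * 1.44269504          # fold log2(e) into the scale, as the kernel does
assert torch.allclose(torch.exp(x * s), torch.exp2(x * scale), atol=1e-6)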
def flash_attn_with_attn_pool_decode_tilelang(
        Q: torch.Tensor,  ## [tq = b, q_h, q_dim]
        K: torch.Tensor,  ## [tk, k_h, k_dim]
        V: torch.Tensor,
        cu_seqlens_k: torch.Tensor,
        max_seqlen_k: int,
        real_max_k_seqlen: int,
        num_split: int,
        softmax_scale: float,
        s_aux: torch.Tensor = None,
        block_size: int = 64,
        use_per_kv_head_sparse_index: bool = False,
        tl_kernel=None,
):
    num_tokens, q_h, head_size = Q.shape
    batch = cu_seqlens_k.size(0) - 1
    k_h = K.size(1)
    assert Q.dim() == K.dim() == 3
    assert Q.size(2) == K.size(2)
    assert cu_seqlens_k.dim() == 1
    assert head_size in {64, 128, 256}
    assert Q.is_contiguous()
    assert K.is_contiguous()
    assert V.is_contiguous()
    gqa_group_size = q_h // k_h

    O_tl = torch.zeros_like(Q)
    S_tl = torch.zeros((batch, q_h, math.ceil(real_max_k_seqlen / block_size)),
                       dtype=Q.dtype,
                       device=Q.device)
    O_tl, S_tl = tl_kernel(Q, K, V, cu_seqlens_k, s_aux)

    if use_per_kv_head_sparse_index:
        S_tl = torch.max_pool2d(S_tl, kernel_size=(gqa_group_size, 1), stride=(gqa_group_size, 1))
    else:
        S_tl = torch.max_pool2d(S_tl, kernel_size=(q_h, 1), stride=(q_h, 1))

    return O_tl, S_tl


def flash_attn_with_attn_pool_decode(
        Q: torch.Tensor,  ## [tq = b, q_h, q_dim]
        K: torch.Tensor,  ## [tk, k_h, k_dim]
        V: torch.Tensor,
        cu_seqlens_k: torch.Tensor,
        max_seqlen_k: int,
        real_max_k_seqlen: int,
        num_split: int,
        softmax_scale: float,
        s_aux: torch.Tensor = None,
        block_size: int = 64,
        use_per_kv_head_sparse_index: bool = False,
):
    num_tokens, q_h, head_size = Q.shape
    batch = cu_seqlens_k.size(0) - 1
    k_h = K.size(1)
    assert Q.dim() == K.dim() == 3
    assert Q.size(2) == K.size(2)
    assert cu_seqlens_k.dim() == 1
    assert head_size in {64, 128, 256}
    assert Q.is_contiguous()
    assert K.is_contiguous()
    assert V.is_contiguous()
    gqa_group_size = q_h // k_h

    BLOCK_D = head_size
    BLOCK_N = block_size
    BLOCK_H = 64

    O = torch.zeros_like(Q)
    S = torch.zeros((batch, q_h, math.ceil(max_seqlen_k / block_size)), dtype=Q.dtype, device=Q.device)

    def grid(META):
        return (batch, k_h)

    with torch.cuda.device(Q.device.index):
        _fwd_kernel_varlen[grid](
            Q,
            K,
            V,
            O,
            S,
            s_aux,
            softmax_scale,
            cu_seqlens_k,
            *Q.stride(),
            *K.stride(),
            *V.stride(),
            *O.stride(),
            *S.stride(),
            gqa_group_size,
            BLOCK_H=BLOCK_H,
            BLOCK_N=BLOCK_N,
            BLOCK_D=BLOCK_D,
        )

    if use_per_kv_head_sparse_index:
        S = torch.max_pool2d(S, kernel_size=(gqa_group_size, 1), stride=(gqa_group_size, 1))
    else:
        S = torch.max_pool2d(S, kernel_size=(q_h, 1), stride=(q_h, 1))

    return O, S
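Editorial note: the max_pool2d calls above pool the per-block score tensor S over the query-head dimension. A small sketch of what the two pooling modes produce (my own illustration with made-up sizes, not part of the commit):

# Hypothetical illustration of the two pooling modes over S of shape (batch, q_heads, n_blocks).
import torch

batch, q_h, kv_h, n_blocks = 2, 32, 8, 16
S = torch.rand(batch, q_h, n_blocks)
per_kv_head = torch.max_pool2d(S, kernel_size=(q_h // kv_h, 1), stride=(q_h // kv_h, 1))
global_pool = torch.max_pool2d(S, kernel_size=(q_h, 1), stride=(q_h, 1))
print(per_kv_head.shape)   # torch.Size([2, 8, 16]): one pooled score row per KV head
print(global_pool.shape)   # torch.Size([2, 1, 16]): one pooled score row per batch element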
def
test_equal_seqlen_decode_main
(
args
):
"""Test decode kernel with equal sequence lengths"""
print
(
"Testing decode kernel with equal sequence lengths"
)
batch_size
=
args
.
batch_size
q_heads
=
args
.
q_heads
kv_heads
=
args
.
kv_heads
k_seqlen
=
args
.
k_seqlen
real_max_k_seqlen
=
args
.
k_seqlen
head_size
=
args
.
head_size
block_size
=
args
.
block_size
dtype
=
torch
.
bfloat16
if
args
.
dtype
==
"bfloat16"
else
torch
.
float16
# For decode, query is just 1 token per batch
q
=
torch
.
randn
(
batch_size
,
q_heads
,
head_size
,
device
=
'cuda'
,
dtype
=
dtype
)
k
=
torch
.
randn
(
batch_size
,
kv_heads
,
k_seqlen
,
head_size
,
device
=
'cuda'
,
dtype
=
dtype
)
v
=
torch
.
randn
(
batch_size
,
kv_heads
,
k_seqlen
,
head_size
,
device
=
'cuda'
,
dtype
=
dtype
)
softmax_scale
=
1.0
/
math
.
sqrt
(
head_size
)
# Generate sink values if needed
sink
=
None
if
args
.
test_sink
:
sink
=
torch
.
randn
(
q_heads
,
device
=
'cuda'
,
dtype
=
torch
.
float32
)
*
0.1
# Small sink values
print
(
f
"Using sink attention with sink values:
{
sink
}
"
)
# Convert to varlen format for K, V
k_varlen
=
k
.
transpose
(
1
,
2
).
reshape
(
batch_size
*
k_seqlen
,
kv_heads
,
head_size
)
v_varlen
=
v
.
transpose
(
1
,
2
).
reshape
(
batch_size
*
k_seqlen
,
kv_heads
,
head_size
)
# Generate cumulative sequence lengths
cu_seqlens_k
=
torch
.
arange
(
0
,
(
batch_size
+
1
)
*
k_seqlen
,
k_seqlen
,
device
=
'cuda'
,
dtype
=
torch
.
int32
)
max_seqlen_k
=
k_seqlen
print
(
f
"q shape:
{
q
.
shape
}
"
)
print
(
f
"k_varlen shape:
{
k_varlen
.
shape
}
"
)
print
(
f
"v_varlen shape:
{
v_varlen
.
shape
}
"
)
num_tokens
,
q_h
,
head_size
=
q
.
shape
batch
=
cu_seqlens_k
.
size
(
0
)
-
1
k_h
=
k_varlen
.
size
(
1
)
tl_kernel
=
flashattn
(
batch
,
q_h
,
k_h
,
args
.
k_seqlen
,
cu_seqlens_k
[
-
1
].
item
(),
head_size
,
args
.
test_sink
)
# Test our decode kernel
O_triton
,
S_triton
=
flash_attn_with_attn_pool_decode
(
q
,
k_varlen
,
v_varlen
,
cu_seqlens_k
,
max_seqlen_k
,
real_max_k_seqlen
,
args
.
num_split
,
softmax_scale
,
s_aux
=
sink
,
block_size
=
block_size
)
O_tilelang
,
S_tilelang
=
flash_attn_with_attn_pool_decode_tilelang
(
q
,
k_varlen
,
v_varlen
,
cu_seqlens_k
,
max_seqlen_k
,
real_max_k_seqlen
,
args
.
num_split
,
softmax_scale
,
s_aux
=
sink
,
block_size
=
block_size
,
tl_kernel
=
tl_kernel
,
)
for
i
in
range
(
batch_size
):
S_tilelang
[
i
,
:,
math
.
ceil
((
cu_seqlens_k
[
i
+
1
].
item
()
-
cu_seqlens_k
[
i
].
item
())
/
block_size
):]
=
0
# Compute torch reference
q_expanded
=
q
.
unsqueeze
(
2
)
# [b, q_heads, 1, head_size]
k_repeat
=
repeat_kv
(
k
,
q_heads
//
kv_heads
)
# [b, q_heads, k_seqlen, head_size]
v_repeat
=
repeat_kv
(
v
,
q_heads
//
kv_heads
)
# [b, q_heads, k_seqlen, head_size]
if
sink
is
None
:
# Standard scaled dot-product attention
logits
=
torch
.
matmul
(
q_expanded
,
k_repeat
.
transpose
(
-
2
,
-
1
))
*
softmax_scale
# [batch, q_heads, 1, seqlen_k]
attn_weights
=
torch
.
softmax
(
logits
,
dim
=-
1
)
O_torch
=
torch
.
matmul
(
attn_weights
,
v_repeat
).
squeeze
(
2
)
# [batch, q_heads, head_size]
else
:
# s_aux attention
logits
=
torch
.
matmul
(
q_expanded
,
k_repeat
.
transpose
(
-
2
,
-
1
))
*
softmax_scale
# [batch, q_heads, 1, seqlen_k]
sink_expanded
=
sink
.
view
(
1
,
q_heads
,
1
,
1
)
# [1, q_heads, 1, 1]
logits_max
=
torch
.
max
(
logits
,
dim
=-
1
,
keepdim
=
True
).
values
logits_or_sinks_max
=
torch
.
maximum
(
logits_max
,
sink_expanded
)
sinks
=
torch
.
exp
(
sink_expanded
-
logits_or_sinks_max
)
unnormalized_scores
=
torch
.
exp
(
logits
-
logits_or_sinks_max
)
normalizer
=
unnormalized_scores
.
sum
(
dim
=-
1
,
keepdim
=
True
)
+
sinks
attn_weights
=
unnormalized_scores
/
normalizer
O_torch
=
torch
.
matmul
(
attn_weights
.
to
(
v_repeat
.
dtype
),
v_repeat
).
squeeze
(
2
)
# [batch, q_heads, head_size]
# Compute attention score pooling
attn_score_pooled
=
torch
.
max_pool2d
(
attn_weights
.
squeeze
(
2
),
# [b, q_heads, k_seqlen]
kernel_size
=
(
q_heads
,
block_size
),
stride
=
(
q_heads
,
block_size
),
ceil_mode
=
True
).
to
(
torch
.
float16
)
print
(
"S_tilelang"
,
S_tilelang
)
print
(
"attn_score_pooled"
,
attn_score_pooled
)
max_diff_o
=
torch
.
max
(
torch
.
abs
(
O_triton
-
O_torch
))
max_diff_s
=
torch
.
max
(
torch
.
abs
(
S_triton
-
attn_score_pooled
))
max_diff_o_tilelang
=
torch
.
max
(
torch
.
abs
(
O_tilelang
-
O_torch
))
max_diff_s_tilelang
=
torch
.
max
(
torch
.
abs
(
S_tilelang
-
attn_score_pooled
))
print
(
f
"Max difference in O:
{
max_diff_o
.
item
()
}
"
)
print
(
f
"Max difference in S:
{
max_diff_s
.
item
()
}
"
)
print
(
f
"Max difference in O_tilelang:
{
max_diff_o_tilelang
.
item
()
}
"
)
print
(
f
"Max difference in S_tilelang:
{
max_diff_s_tilelang
.
item
()
}
"
)
assert
torch
.
allclose
(
O_triton
,
O_torch
,
atol
=
1e-2
,
rtol
=
1e-2
),
f
"Output mismatch:
{
max_diff_o
.
item
()
}
"
assert
torch
.
allclose
(
S_triton
,
attn_score_pooled
,
atol
=
1e-2
,
rtol
=
1e-2
),
f
"Score mismatch:
{
max_diff_s
.
item
()
}
"
assert
torch
.
allclose
(
O_tilelang
,
O_torch
,
atol
=
1e-2
,
rtol
=
1e-2
),
f
"Output mismatch:
{
max_diff_o_tilelang
.
item
()
}
"
assert
torch
.
allclose
(
S_tilelang
,
attn_score_pooled
,
atol
=
1e-2
,
rtol
=
1e-2
),
f
"Score mismatch:
{
max_diff_s_tilelang
.
item
()
}
"
print
(
"✅ All tests passed!"
)
def
test_varlen_decode_main
(
args
):
"""Test decode kernel with variable sequence lengths"""
batch_size
=
args
.
batch_size
q_heads
=
args
.
q_heads
kv_heads
=
args
.
kv_heads
max_k_seqlen
=
args
.
k_seqlen
# Use as max sequence length
real_max_k_seqlen
=
args
.
k_seqlen
head_size
=
args
.
head_size
block_size
=
args
.
block_size
dtype
=
torch
.
bfloat16
if
args
.
dtype
==
"bfloat16"
else
torch
.
float16
print
(
f
"Testing decode kernel with variable sequence lengths (max_k_seqlen=
{
max_k_seqlen
}
)"
)
# Generate sink values if needed
sink
=
None
if
args
.
test_sink
:
sink
=
torch
.
randn
(
q_heads
,
device
=
'cuda'
,
dtype
=
torch
.
float32
)
*
0.1
# Small sink values
print
(
f
"Using sink attention with sink values:
{
sink
}
"
)
# Generate variable length k sequences
k_seqlens
=
torch
.
randint
(
max_k_seqlen
//
4
,
max_k_seqlen
+
1
,
size
=
(
batch_size
,))
print
(
f
"k_seqlens:
{
k_seqlens
}
"
)
# Generate cumulative sequence lengths for k
cu_seqlens_k
=
torch
.
zeros
(
batch_size
+
1
,
device
=
'cuda'
,
dtype
=
torch
.
int32
)
total_k_tokens
=
0
for
i
in
range
(
batch_size
):
cu_seqlens_k
[
i
]
=
total_k_tokens
total_k_tokens
+=
k_seqlens
[
i
]
cu_seqlens_k
[
batch_size
]
=
total_k_tokens
print
(
f
"cu_seqlens_k:
{
cu_seqlens_k
}
"
)
# Generate tensors - Q is [batch_size, q_heads, head_size] for decode
q_decode
=
torch
.
randn
(
batch_size
,
q_heads
,
head_size
,
device
=
'cuda'
,
dtype
=
dtype
)
k_varlen
=
torch
.
randn
(
total_k_tokens
,
kv_heads
,
head_size
,
device
=
'cuda'
,
dtype
=
dtype
)
v_varlen
=
torch
.
randn
(
total_k_tokens
,
kv_heads
,
head_size
,
device
=
'cuda'
,
dtype
=
dtype
)
softmax_scale
=
1.0
/
math
.
sqrt
(
head_size
)
max_seqlen_k
=
int
(
k_seqlens
.
max
())
print
(
f
"Actual max_seqlen_k:
{
max_seqlen_k
}
"
)
print
(
f
"q_decode shape:
{
q_decode
.
shape
}
"
)
print
(
f
"k_varlen shape:
{
k_varlen
.
shape
}
"
)
print
(
f
"v_varlen shape:
{
v_varlen
.
shape
}
"
)
num_tokens
,
q_h
,
head_size
=
q_decode
.
shape
batch
=
cu_seqlens_k
.
size
(
0
)
-
1
k_h
=
k_varlen
.
size
(
1
)
tl_kernel
=
flashattn
(
batch
,
q_h
,
k_h
,
args
.
k_seqlen
,
cu_seqlens_k
[
-
1
].
item
(),
head_size
,
args
.
test_sink
)
# Test our decode kernel
O_triton
,
S_triton
=
flash_attn_with_attn_pool_decode
(
q_decode
,
k_varlen
,
v_varlen
,
cu_seqlens_k
,
max_seqlen_k
,
real_max_k_seqlen
,
args
.
num_split
,
softmax_scale
,
s_aux
=
sink
,
block_size
=
block_size
)
O_tilelang
,
S_tilelang
=
flash_attn_with_attn_pool_decode_tilelang
(
q_decode
,
k_varlen
,
v_varlen
,
cu_seqlens_k
,
max_seqlen_k
,
real_max_k_seqlen
,
args
.
num_split
,
softmax_scale
,
s_aux
=
sink
,
block_size
=
block_size
,
tl_kernel
=
tl_kernel
,
)
for
i
in
range
(
batch_size
):
S_tilelang
[
i
,
:,
math
.
ceil
((
cu_seqlens_k
[
i
+
1
].
item
()
-
cu_seqlens_k
[
i
].
item
())
/
block_size
):]
=
0
# Create torch reference - pad tensors for comparison
k_padded_list
=
[]
v_padded_list
=
[]
for
i
in
range
(
batch_size
):
actual_k_len
=
k_seqlens
[
i
]
# Extract and pad k, v for this batch
k_start
=
cu_seqlens_k
[
i
]
k_end
=
cu_seqlens_k
[
i
+
1
]
# Pad to max_seqlen_k
k_padded
=
torch
.
zeros
(
max_seqlen_k
,
kv_heads
,
head_size
,
device
=
'cuda'
,
dtype
=
dtype
)
v_padded
=
torch
.
zeros
(
max_seqlen_k
,
kv_heads
,
head_size
,
device
=
'cuda'
,
dtype
=
dtype
)
k_padded
[:
actual_k_len
]
=
k_varlen
[
k_start
:
k_end
]
v_padded
[:
actual_k_len
]
=
v_varlen
[
k_start
:
k_end
]
k_padded_list
.
append
(
k_padded
)
v_padded_list
.
append
(
v_padded
)
# Stack to create batched tensors [b, max_seqlen, kv_heads, head_size]
k_padded_batched
=
torch
.
stack
(
k_padded_list
,
dim
=
0
).
transpose
(
1
,
2
)
# [b, kv_heads, max_seqlen, head_size]
v_padded_batched
=
torch
.
stack
(
v_padded_list
,
dim
=
0
).
transpose
(
1
,
2
)
# [b, kv_heads, max_seqlen, head_size]
# Expand q to match kv heads: [b, q_heads, 1, head_size]
q_expanded
=
q_decode
.
unsqueeze
(
2
)
# [b, q_heads, 1, head_size]
print
(
f
"q_expanded shape:
{
q_expanded
.
shape
}
"
)
print
(
f
"k_padded_batched shape:
{
k_padded_batched
.
shape
}
"
)
print
(
f
"v_padded_batched shape:
{
v_padded_batched
.
shape
}
"
)
# Compute torch reference
k_repeat
=
repeat_kv
(
k_padded_batched
,
q_heads
//
kv_heads
)
# [b, q_heads, max_seqlen, head_size]
v_repeat
=
repeat_kv
(
v_padded_batched
,
q_heads
//
kv_heads
)
# [b, q_heads, max_seqlen, head_size]
if
sink
is
None
:
# Standard attention computation: [b, q_heads, 1, head_size] @ [b, q_heads, head_size, max_seqlen]
attn_score
=
torch
.
matmul
(
q_expanded
,
k_repeat
.
transpose
(
-
2
,
-
1
))
*
softmax_scale
# [b, q_heads, 1, max_seqlen]
# Apply sequence length masking
for
i
in
range
(
batch_size
):
actual_k_len
=
k_seqlens
[
i
]
attn_score
[
i
,
:,
:,
actual_k_len
:]
=
float
(
'-inf'
)
attn_weights
=
attn_score
.
softmax
(
dim
=-
1
)
# [b, q_heads, 1, max_seqlen]
# Mask out invalid positions
for
i
in
range
(
batch_size
):
actual_k_len
=
k_seqlens
[
i
]
attn_weights
[
i
,
:,
:,
actual_k_len
:]
=
0.0
# Compute output: [b, q_heads, 1, max_seqlen] @ [b, q_heads, max_seqlen, head_size]
O_torch
=
torch
.
matmul
(
attn_weights
,
v_repeat
)
# [b, q_heads, 1, head_size]
else
:
# s_aux attention
logits
=
torch
.
matmul
(
q_expanded
,
k_repeat
.
transpose
(
-
2
,
-
1
))
*
softmax_scale
# [b, q_heads, 1, max_seqlen]
# Apply sequence length masking
for
i
in
range
(
batch_size
):
actual_k_len
=
k_seqlens
[
i
]
logits
[
i
,
:,
:,
actual_k_len
:]
=
float
(
'-inf'
)
sink_expanded
=
sink
.
view
(
1
,
q_heads
,
1
,
1
)
# [1, q_heads, 1, 1]
logits_max
=
torch
.
max
(
logits
,
dim
=-
1
,
keepdim
=
True
).
values
logits_or_sinks_max
=
torch
.
maximum
(
logits_max
,
sink_expanded
)
sinks
=
torch
.
exp
(
sink_expanded
-
logits_or_sinks_max
)
unnormalized_scores
=
torch
.
exp
(
logits
-
logits_or_sinks_max
)
normalizer
=
unnormalized_scores
.
sum
(
dim
=-
1
,
keepdim
=
True
)
+
sinks
attn_weights
=
unnormalized_scores
/
normalizer
# Mask out invalid positions
for
i
in
range
(
batch_size
):
actual_k_len
=
k_seqlens
[
i
]
attn_weights
[
i
,
:,
:,
actual_k_len
:]
=
0.0
# Compute output: [b, q_heads, 1, max_seqlen] @ [b, q_heads, max_seqlen, head_size]
O_torch
=
torch
.
matmul
(
attn_weights
.
to
(
v_repeat
.
dtype
),
v_repeat
)
# [b, q_heads, 1, head_size]
O_torch
=
O_torch
.
squeeze
(
2
)
# [b, q_heads, head_size]
# Compute attention score pooling for S
attn_score_pooled
=
torch
.
max_pool2d
(
attn_weights
.
squeeze
(
2
),
# [b, q_heads, max_seqlen]
kernel_size
=
(
q_heads
,
block_size
),
stride
=
(
q_heads
,
block_size
),
ceil_mode
=
True
).
to
(
dtype
=
torch
.
float16
)
# [b, 1, ceil(max_seqlen/block_size)]
print
(
f
"O_triton shape:
{
O_triton
.
shape
}
"
)
print
(
f
"O_tilelang shape:
{
O_tilelang
.
shape
}
"
)
print
(
f
"O_torch shape:
{
O_torch
.
shape
}
"
)
print
(
f
"S_triton shape:
{
S_triton
.
shape
}
"
)
print
(
f
"S_tilelang shape:
{
S_tilelang
.
shape
}
"
)
print
(
f
"attn_score_pooled shape:
{
attn_score_pooled
.
shape
}
"
)
# Compare results
max_diff_o
=
torch
.
max
(
torch
.
abs
(
O_triton
-
O_torch
))
max_diff_o_tl
=
torch
.
max
(
torch
.
abs
(
O_tilelang
-
O_torch
))
print
(
f
"Max difference in O:
{
max_diff_o
.
item
()
}
"
)
print
(
f
"Max difference in O_tilelang:
{
max_diff_o_tl
.
item
()
}
"
)
max_diff_s
=
torch
.
max
(
torch
.
abs
(
S_triton
-
attn_score_pooled
))
max_diff_s_tl
=
torch
.
max
(
torch
.
abs
(
S_tilelang
[:,
:,
:
math
.
ceil
(
max_seqlen_k
/
block_size
)]
-
attn_score_pooled
))
print
(
f
"Max difference in S:
{
max_diff_s
.
item
()
}
"
)
print
(
f
"Max difference in S_tilelang:
{
max_diff_s_tl
.
item
()
}
"
)
assert
torch
.
allclose
(
O_triton
,
O_torch
,
atol
=
1e-2
,
rtol
=
1e-2
),
f
"Output mismatch:
{
max_diff_o
.
item
()
}
"
assert
torch
.
allclose
(
S_triton
,
attn_score_pooled
,
atol
=
1e-2
,
rtol
=
1e-2
),
f
"Score mismatch:
{
max_diff_s
.
item
()
}
"
assert
torch
.
allclose
(
O_tilelang
,
O_torch
,
atol
=
1e-2
,
rtol
=
1e-2
),
f
"Output mismatch:
{
max_diff_o_tl
.
item
()
}
"
assert
torch
.
allclose
(
S_tilelang
[:,
:,
:
math
.
ceil
(
max_seqlen_k
/
block_size
)],
attn_score_pooled
,
atol
=
1e-2
,
rtol
=
1e-2
),
f
"Score mismatch:
{
max_diff_s_tl
.
item
()
}
"
print
(
"✅ All tests passed!"
)
def do_bench(fn, *args, warmup=10, rep=10, **kwargs):
    """
    Do benchmark for a function.
    """
    start_event = [torch.cuda.Event(enable_timing=True) for i in range(rep)]
    end_event = [torch.cuda.Event(enable_timing=True) for i in range(rep)]
    for _ in range(warmup):
        fn(*args, **kwargs)
    torch.cuda.synchronize()
    for i in range(rep):
        start_event[i].record()
        fn(*args, **kwargs)
        end_event[i].record()
    torch.cuda.synchronize()
    # Record clocks
    times = torch.tensor(
        [s.elapsed_time(e) for s, e in zip(start_event, end_event)],
        dtype=torch.float,
    )
    return times.mean().item()
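Editorial note: a minimal usage sketch for the do_bench helper above (my own example; the matmul workload is just a placeholder, not part of the commit):

# Hypothetical usage of do_bench: time an arbitrary CUDA workload in milliseconds.
import torch

a = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
b = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
avg_ms = do_bench(torch.matmul, a, b, warmup=5, rep=20)
print(f"matmul average latency: {avg_ms:.3f} ms")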
def
speed_benchmark_decode_comparison
(
args
):
"""Speed benchmark for decode kernel"""
batch_size
=
args
.
batch_size
q_heads
=
args
.
q_heads
kv_heads
=
args
.
kv_heads
max_k_seqlen
=
args
.
k_seqlen
head_size
=
args
.
head_size
block_size
=
args
.
block_size
dtype
=
torch
.
bfloat16
if
args
.
dtype
==
"bfloat16"
else
torch
.
float16
print
(
"
\n
=== Decode Speed Benchmark Comparison ==="
)
print
(
"Configuration:"
)
print
(
f
" Batch size:
{
batch_size
}
"
)
print
(
f
" Q heads:
{
q_heads
}
, KV heads:
{
kv_heads
}
"
)
print
(
f
" Max K sequence length:
{
max_k_seqlen
}
"
)
print
(
f
" Head size:
{
head_size
}
"
)
print
(
f
" Block size:
{
block_size
}
"
)
print
(
f
" Data type:
{
dtype
}
"
)
print
(
f
" Variable lengths:
{
args
.
test_varlen
}
"
)
print
(
f
" s_aux attention:
{
args
.
test_sink
}
"
)
print
()
# Generate input data
if
args
.
test_varlen
:
k_seqlens
=
torch
.
randint
(
max_k_seqlen
//
4
,
max_k_seqlen
+
1
,
size
=
(
batch_size
,))
else
:
k_seqlens
=
torch
.
full
((
batch_size
,),
max_k_seqlen
,
dtype
=
int
)
# Generate cumulative sequence lengths for k
cu_seqlens_k
=
torch
.
zeros
(
batch_size
+
1
,
device
=
'cuda'
,
dtype
=
torch
.
int32
)
total_k_tokens
=
0
for
i
in
range
(
batch_size
):
cu_seqlens_k
[
i
]
=
total_k_tokens
total_k_tokens
+=
k_seqlens
[
i
]
cu_seqlens_k
[
batch_size
]
=
total_k_tokens
# Generate tensors
q_decode
=
torch
.
randn
(
batch_size
,
q_heads
,
head_size
,
device
=
'cuda'
,
dtype
=
dtype
)
k_varlen
=
torch
.
randn
(
total_k_tokens
,
kv_heads
,
head_size
,
device
=
'cuda'
,
dtype
=
dtype
)
v_varlen
=
torch
.
randn
(
total_k_tokens
,
kv_heads
,
head_size
,
device
=
'cuda'
,
dtype
=
dtype
)
softmax_scale
=
1.0
/
math
.
sqrt
(
head_size
)
max_seqlen_k
=
int
(
k_seqlens
.
max
())
# Generate sink values if needed
sink
=
None
if
args
.
test_sink
:
sink
=
torch
.
randn
(
q_heads
,
device
=
'cuda'
,
dtype
=
torch
.
float32
)
*
0.1
# Small sink values
print
(
" Using sink attention with sink values"
)
print
(
"Setup complete:"
)
print
(
f
" Total K tokens:
{
total_k_tokens
}
"
)
print
(
f
" Actual max K seq len:
{
max_seqlen_k
}
"
)
if
args
.
test_varlen
:
print
(
f
" K sequence lengths:
{
k_seqlens
.
tolist
()
}
"
)
# Warmup
num_tokens
,
q_h
,
head_size
=
q_decode
.
shape
batch
=
cu_seqlens_k
.
size
(
0
)
-
1
k_h
=
k_varlen
.
size
(
1
)
tl_kernel
=
flashattn
(
batch
,
q_h
,
k_h
,
args
.
k_seqlen
,
cu_seqlens_k
[
-
1
].
item
(),
head_size
,
args
.
test_sink
)
# Benchmark
print
(
"⚡ Benchmarking Tilelang kernel (100 iterations)..."
)
tilelang_time
=
do_bench
(
flash_attn_with_attn_pool_decode_tilelang
,
q_decode
,
k_varlen
,
v_varlen
,
cu_seqlens_k
,
max_seqlen_k
,
args
.
k_seqlen
,
1
,
softmax_scale
,
sink
,
block_size
,
False
,
tl_kernel
,
)
print
(
f
"Average decode kernel time Tilelang:
{
tilelang_time
:.
3
f
}
ms"
)
# Benchmark
print
(
"⚡ Benchmarking Triton kernel (100 iterations)..."
)
triton_time
=
do_bench
(
flash_attn_with_attn_pool_decode
,
q_decode
,
k_varlen
,
v_varlen
,
cu_seqlens_k
,
max_seqlen_k
,
args
.
k_seqlen
,
1
,
softmax_scale
,
sink
,
block_size
)
print
(
f
"Average decode kernel time Triton:
{
triton_time
:.
3
f
}
ms"
)
print
(
f
"Speedup:
{
(
triton_time
/
tilelang_time
):.
3
f
}
"
)
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
(
description
=
'Flash Attention Decode with Attention Pooling'
)
parser
.
add_argument
(
'--batch_size'
,
type
=
int
,
default
=
1
,
help
=
'Batch size'
)
parser
.
add_argument
(
'--q_heads'
,
type
=
int
,
default
=
32
,
help
=
'Number of query heads'
)
parser
.
add_argument
(
'--kv_heads'
,
type
=
int
,
default
=
8
,
help
=
'Number of key-value heads'
)
parser
.
add_argument
(
'--k_seqlen'
,
type
=
int
,
default
=
8192
,
help
=
'Key sequence length'
)
parser
.
add_argument
(
'--head_size'
,
type
=
int
,
default
=
128
,
choices
=
[
64
,
128
,
256
],
help
=
'Head dimension'
)
parser
.
add_argument
(
'--block_size'
,
type
=
int
,
default
=
64
,
help
=
'Block size for computation'
)
parser
.
add_argument
(
'--dtype'
,
type
=
str
,
default
=
'bfloat16'
,
choices
=
[
'float16'
,
'bfloat16'
],
help
=
'Data type'
)
parser
.
add_argument
(
'--test_varlen'
,
action
=
'store_true'
,
help
=
'Test with truly variable sequence lengths'
)
parser
.
add_argument
(
'--test_sink'
,
action
=
'store_true'
,
help
=
'Test with sink attention mechanism'
)
parser
.
add_argument
(
'--benchmark'
,
action
=
'store_true'
,
help
=
'Run speed benchmark'
)
parser
.
add_argument
(
'--num_split'
,
type
=
int
,
default
=
1
,
choices
=
[
1
,
16
],
help
=
'Number of splits'
)
args
=
parser
.
parse_args
()
args
.
test_sink
=
True
args
.
test_varlen
=
False
args
.
dtype
=
'float16'
args
.
num_split
=
1
if
args
.
benchmark
:
speed_benchmark_decode_comparison
(
args
)
elif
args
.
test_varlen
:
test_varlen_decode_main
(
args
)
else
:
test_equal_seqlen_decode_main
(
args
)
examples/flash_decoding/example_mha_inference.py  (view file @ bbbf4207)

@@ -302,9 +302,7 @@ def flash_split_ref(Q, K, V, causal):
                             3), gacc_o.to(torch.float16).permute(1, 2, 3, 0, 4)


-def main():
-    BATCH, H, Q_CTX, KV_CTX, D_HEAD = 1, 32, 128, 8192, 128
-    causal = False
+def main(BATCH=1, H=32, Q_CTX=128, KV_CTX=8192, D_HEAD=128, causal=False):
     flops_per_matmul = 2.0 * BATCH * H * Q_CTX * KV_CTX * D_HEAD
     total_flops = 2 * flops_per_matmul
     if causal:
examples/flash_decoding/test_example_flash_decoding.py  (view file @ bbbf4207)

@@ -12,7 +12,7 @@ def test_example_example_gqa_decode():


 def test_example_example_mha_inference():
-    example_mha_inference.main()
+    example_mha_inference.main(BATCH=1, H=32, Q_CTX=128, KV_CTX=2048, D_HEAD=128, causal=False)


 if __name__ == "__main__":
examples/gdn/example_chunk_o_bwd.py  (view file @ bbbf4207)

@@ -7,8 +7,6 @@ import tilelang
 import tilelang.language as T
 from tilelang.engine.callback import register_cuda_postproc_callback  # noqa: F401

-print(tilelang.__file__)
-
 # Add your fla repository path to sys.path
 # Currently we use the fla repository from the flash-linear-attention project at commit id f03cb3ae
 # sys.path.insert(0, "/home/tzj/flash-linear-attention")

@@ -256,8 +254,9 @@ def tilelang_chunk_o_bwd_dqkwg(
             # for i_kv in T.Parallel(block_DK * block_DV):
             #     dg_last_fragment[i_kv] = h_shared[i_kv // block_DV, i_kv % block_DV] * dh_shared[i_kv // block_DV, i_kv % block_DV]
             for i_kv in T.Parallel(block_DK * block_DV):
-                i_k, i_v = i_kv // block_DV, i_kv % block_DV
-                dg_last_fragment[i_kv] = h_shared[i_k, i_v] * dh_shared[i_k, i_v]
+                dg_last_fragment[i_kv] = h_shared[i_kv // block_DV, i_kv % block_DV] * dh_shared[i_kv // block_DV,
+                                                                                                 i_kv % block_DV]
             T.reduce_sum(dg_last_fragment, dg_last_fragment_scalar, dim=-1, clear=False)
             dg_last_local[0] += dg_last_fragment_scalar[0]
examples/gemm/README.md  (view file @ bbbf4207)

@@ -4,20 +4,23 @@ TileLang is a domain-specific language designed to simplify the process of writi
 ## Table of Contents

The numbered top-level entries

 1. [Getting Started](#getting-started)
 2. [Simple GEMM Example](#simple-gemm-example)
 3. [Advanced GEMM Features](#advanced-gemm-features)
 4. [Enhanced GEMM Example with Annotations](#enhanced-gemm-example-with-annotations)
 5. [Verifying Correctness](#verifying-correctness)
 6. [Fine-grained MMA Computations](#fine-grained-mma-computations)
 7. [References](#references)

are replaced by a nested bulleted list:

 - [Table of Contents](#table-of-contents)
 - [Getting Started](#getting-started)
   - [Prerequisites](#prerequisites)
   - [Installation](#installation)
 - [Simple GEMM Example](#simple-gemm-example)
   - [Code Walkthrough](#code-walkthrough)
   - [Compiling and Profiling](#compiling-and-profiling)
 - [Advanced GEMM Features](#advanced-gemm-features)
   - [Custom Memory Layout / Swizzling](#custom-memory-layout--swizzling)
   - [Parallel Copy and Auto-Pipelining](#parallel-copy-and-auto-pipelining)
   - [Rasterization for L2 Cache Locality](#rasterization-for-l2-cache-locality)
 - [Enhanced GEMM Example with Annotations](#enhanced-gemm-example-with-annotations)
 - [Verifying Correctness](#verifying-correctness)
 - [Fine-grained MMA Computations](#fine-grained-mma-computations)
   - [Example Workflow](#example-workflow)
   - [Summary](#summary)
 - [References](#references)

 ---
examples/linear_attention/example_linear_attn_fwd.py  (view file @ bbbf4207)

@@ -80,7 +80,6 @@ def tl_fused_chunk_fwd_kernel(
             T.atomic_add(O[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_v * BV:(i_v + 1) * BV],
                          o_shared)  #TODO: consider using vectorized atomic add or tma reduce for sm90
         # Output final state
         T.copy(h, final_state[i_b, i_h, i_k * BK:(i_k + 1) * BK, i_v * BV:(i_v + 1) * BV])

@@ -91,6 +90,7 @@
 def tl_fused_chunk_fwd(q, k, v):
     B, S, H, D = q.shape
     kernel = tl_fused_chunk_fwd_kernel(B, S, H, D, D)
+    print(kernel.get_kernel_source())
     o = torch.zeros((B, S, H, D), device='cuda', dtype=torch.float32)
     h = kernel(q, k, v, o)
     return o, h
examples/linear_attention/example_retention_fwd.py  (view file @ bbbf4207)

@@ -51,13 +51,6 @@ def chunk_retention_fwd_kernel(
         o = T.alloc_fragment([chunk_size, BV], accum_dtype)
         T.clear(h)
-        T.annotate_layout({
-            q: tl.layout.make_swizzled_layout(q),
-            k: tl.layout.make_swizzled_layout(k),
-            v: tl.layout.make_swizzled_layout(v),
-            h_shared: tl.layout.make_swizzled_layout(h_shared),
-            s_shared: tl.layout.make_swizzled_layout(s_shared),
-        })
         T.use_swizzle(10)
         for i in T.Pipelined(0, NT):
examples/norm/rms_norm.py  (view file @ bbbf4207)

@@ -21,7 +21,7 @@ def rms_norm_splitk(M, N, blk_m, blk_k):
                 A_local[i, j] += A_shared[i, j] * A_shared[i, j]
         T.reduce_sum(A_local, A_powsum, dim=1)
         for i in T.Parallel(blk_m):
-            A_powsum[i] = T.rsqrt(A_powsum[i] / N) + 1e-12
+            A_powsum[i] = T.rsqrt(A_powsum[i] / N + 1e-12)
         for k in range(num_k_step):
             # reverse, better cache hit rate

@@ -51,7 +51,7 @@ def rms_norm(M, N, blk_m):
             A_pow_local[i, j] = A_local[i, j] * A_local[i, j]
         T.reduce_sum(A_pow_local, A_powsum, dim=1)
         for i in T.Parallel(blk_m):
-            A_powsum[i] = T.rsqrt(A_powsum[i] / N) + 1e-12
+            A_powsum[i] = T.rsqrt(A_powsum[i] / N + 1e-12)
         for i, j in T.Parallel(blk_m, N):
             A_local[i, j] *= A_powsum[i]
         T.copy(A_local, B[bx * blk_m:(bx + 1) * blk_m, :])
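Editorial note: the change above moves the epsilon inside the rsqrt, so it actually guards against a zero mean-square instead of being added after the fact. A numerical illustration of the difference (my own sketch, not part of the commit):

# Hypothetical illustration: epsilon outside rsqrt does not prevent inf for an all-zero row.
import torch

mean_sq = torch.tensor([0.0, 1.0])          # per-row mean of squares (A_powsum / N)
old = torch.rsqrt(mean_sq) + 1e-12          # old code: rsqrt(0) -> inf
new = torch.rsqrt(mean_sq + 1e-12)          # new code: finite (about 1e6) for the zero row
print(old)   # tensor([inf, 1.])
print(new)   # tensor([1.0000e+06, 1.0000e+00])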
examples/norm/test_rms_norm.py  (view file @ bbbf4207)

@@ -22,7 +22,7 @@ def rms_norm_splitk(M, N, blk_m, blk_k):
                 A_local[i, j] += A_shared[i, j] * A_shared[i, j]
         T.reduce_sum(A_local, A_powsum, dim=1)
         for i in T.Parallel(blk_m):
-            A_powsum[i] = T.rsqrt(A_powsum[i] / N) + 1e-12
+            A_powsum[i] = T.rsqrt(A_powsum[i] / N + 1e-12)
         for k in range(num_k_step):
             # reverse, better cache hit rate

@@ -51,7 +51,7 @@ def rms_norm(M, N, blk_m):
             A_pow_local[i, j] = A_local[i, j] * A_local[i, j]
         T.reduce_sum(A_pow_local, A_powsum, dim=1)
         for i in T.Parallel(blk_m):
-            A_powsum[i] = T.rsqrt(A_powsum[i] / N) + 1e-12
+            A_powsum[i] = T.rsqrt(A_powsum[i] / N + 1e-12)
         for i, j in T.Parallel(blk_m, N):
             A_local[i, j] *= A_powsum[i]
         T.copy(A_local, B[bx * blk_m:(bx + 1) * blk_m, :])
examples/plot_layout/fragment_mfma_load_a.py
0 → 100644
View file @
bbbf4207
import
tilelang.language
as
T
from
typing
import
Literal
,
Callable
from
tvm.tir
import
IndexMap
from
tilelang.intrinsics.utils
import
get_mma_micro_size
from
tilelang.intrinsics.mfma_layout
import
(
shared_16x4_to_local_64x1_layout_A
,
shared_16x16_to_local_64x4_layout_A
,
shared_16x32_to_local_64x8_layout_A
,
shared_16x64_to_local_64x16_layout_A
,
)
def
make_mfma_load_base_layout
(
dtype
:
str
=
"float16"
,
matrix
:
Literal
[
"A"
,
"B"
]
=
"A"
,
k_dim
:
int
=
16
,
transposed
:
bool
=
False
)
->
T
.
Fragment
:
"""
Create a layout function for storing MFMA results into a fragment buffer.
This layout is used in conjunction with `inverse_mfma_store_layout` to
map fragment indices to threads and local indices.
Parameters
----------
dtype : str
The data type of the matrix.
matrix : Literal["A", "B"]
The mfma operand to be loaded.
k_dim : int
The k dimension of the mfma.
transposed : bool
Whether the matrix is transposed, by default False.
Returns
-------
T.Fragment
Describes how threads and indices in fragment are laid out.
"""
assert
matrix
in
[
"A"
,
"B"
],
"matrix should be either A or B"
# s represents spatial axis
# r represents reduction axis
# sr represents the two dims are spatial + reduction
# rs represents the two dims are reduction + spatial
transform_func_sr_a
:
Callable
=
None
transform_func_sr_b
:
Callable
=
None
if
k_dim
==
4
:
transform_func_sr_a
=
shared_16x4_to_local_64x1_layout_A
transform_func_sr_b
=
shared_16x4_to_local_64x1_layout_A
elif
k_dim
==
16
:
transform_func_sr_a
=
shared_16x16_to_local_64x4_layout_A
transform_func_sr_b
=
shared_16x16_to_local_64x4_layout_A
elif
k_dim
==
32
:
transform_func_sr_a
=
shared_16x32_to_local_64x8_layout_A
transform_func_sr_b
=
shared_16x32_to_local_64x8_layout_A
elif
k_dim
==
64
:
transform_func_sr_a
=
shared_16x64_to_local_64x16_layout_A
transform_func_sr_b
=
shared_16x64_to_local_64x16_layout_A
else
:
raise
ValueError
(
"k_dim must be 4 or 16 or 32 or 64 currently"
)
is_sr_conditions
=
[
False
]
is_sr_conditions
.
append
(
matrix
==
"A"
and
not
transposed
)
is_sr_conditions
.
append
(
matrix
==
"B"
and
transposed
)
is_sr_axis_order
=
any
(
is_sr_conditions
)
micro_size_x
,
micro_size_y
,
micro_size_k
=
get_mma_micro_size
(
dtype
)
# the layout of mma.sync is row.col.
# so the b matrix expected a transposed basic layout
transform_func
:
Callable
=
None
if
matrix
==
"A"
:
transform_func
=
transform_func_sr_a
if
is_sr_axis_order
else
lambda
i
,
j
:
transform_func_sr_a
(
j
,
i
)
micro_size_s
,
micro_size_r
=
micro_size_x
,
micro_size_k
elif
matrix
==
"B"
:
transform_func
=
transform_func_sr_b
if
is_sr_axis_order
else
lambda
i
,
j
:
transform_func_sr_b
(
j
,
i
)
micro_size_s
,
micro_size_r
=
micro_size_k
,
micro_size_y
else
:
raise
ValueError
(
f
"Unsupported matrix
{
matrix
}
"
)
inverse_mma_load_layout
=
IndexMap
.
from_func
(
transform_func
,
index_dtype
=
"int32"
)
def
forward_thread
(
i
:
int
,
j
:
int
)
->
int
:
"""
Given the row index `i` and column index `j` in the fragment,
"""
lane_id
,
_
=
inverse_mma_load_layout
.
map_indices
([
i
,
j
])
return
lane_id
def
forward_index
(
i
:
int
,
j
:
int
)
->
int
:
"""
Given the row index `i` and column index `j` in the fragment,
"""
_
,
local_id
=
inverse_mma_load_layout
.
map_indices
([
i
,
j
])
return
local_id
base_fragment
=
T
.
Fragment
(
[
micro_size_s
,
micro_size_r
]
if
is_sr_axis_order
else
[
micro_size_r
,
micro_size_s
],
forward_thread_fn
=
forward_thread
,
forward_index_fn
=
forward_index
,
)
return
base_fragment
block_rows
=
2
block_cols
=
2
warp_rows
=
2
warp_cols
=
2
chunk
=
2
from
tilelang.tools
import
plot_layout
# ldmatrix layout 16x16
base_layout
=
make_mfma_load_base_layout
(
dtype
=
"float16"
,
matrix
=
"A"
,
transposed
=
False
)
print
(
base_layout
)
plot_layout
(
base_layout
,
name
=
"base_layout"
)
# warp layout 32x32
warp_layout
=
base_layout
.
repeat
([
warp_rows
,
warp_cols
],
repeat_on_thread
=
False
,
lower_dim_first
=
False
)
print
(
warp_layout
)
plot_layout
(
warp_layout
,
name
=
"warp_layout"
)
# block layout 64x32
block_layout
=
warp_layout
.
repeat
([
block_rows
,
1
],
repeat_on_thread
=
True
,
lower_dim_first
=
True
).
replicate
(
block_cols
)
print
(
block_layout
)
plot_layout
(
block_layout
,
name
=
"block_layout"
)
examples/warp_specialize/example_warp_specialize_gemm_barrierpipe_stage2.py  (view file @ bbbf4207)

 import tilelang
 import tilelang.language as T

+tilelang.disable_cache()
+
 # add decorator @tilelang.jit if you want to return a torch function
 # @tilelang.jit

@@ -52,11 +54,14 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
 def main(M=16384, N=16384, K=16384):
+    tilelang.disable_cache()
     block_M = 128
     block_N = 128
     block_K = 64
     jit_kernel = matmul(M, N, K, block_M, block_N, block_K)
+    print(jit_kernel.get_kernel_source())
+
     import torch

     a = torch.randn(M, K, device="cuda", dtype=torch.float16)
format.sh  (view file @ bbbf4207)

@@ -29,10 +29,7 @@ ALL_FILES=''
 ONLY_CHANGED=''
 FILES=()

 if (( $# == 0 )); then
-  if [[ -n "$(git status --porcelain --ignore-submodules --untracked-files=no)" ]]; then
-    echo "Detected uncommitted changes. Please commit or stash them before running $0." >&2
-    exit 1
-  fi
+  # Default: allow dirty workspace; run on changed files (committed + worktree)
   ONLY_CHANGED='true'
 else
   while (( $# > 0 )); do

@@ -78,14 +75,17 @@
   echo "Checking all files..." >&2
 elif [[ -n "${ONLY_CHANGED}" ]]; then
   MERGE_BASE="$(get_merge_base)"
-  echo "Checking changed files compared to merge base (${MERGE_BASE})..." >&2
+  echo "Checking changed files vs merge base (${MERGE_BASE}) and working tree..." >&2
 elif [[ "${#FILES[@]}" -gt 0 ]]; then
   echo "Checking specified files: ${FILES[*]}..." >&2
 fi

+# Some systems set pip's default to --user, which breaks isolated virtualenvs.
+export PIP_USER=0
+
 # If pre-commit is not installed, install it.
 if ! python3 -m pre_commit --version &>/dev/null; then
-  python3 -m pip install pre-commit
+  python3 -m pip install pre-commit --user
 fi

 echo 'tile-lang pre-commit: Check Start'

@@ -93,7 +93,17 @@ echo 'tile-lang pre-commit: Check Start'
 if [[ -n "${ALL_FILES}" ]]; then
   python3 -m pre_commit run --all-files
 elif [[ -n "${ONLY_CHANGED}" ]]; then
-  python3 -m pre_commit run --from-ref "${MERGE_BASE}" --to-ref HEAD
+  # Collect changed files (committed since merge-base + current worktree)
+  CHANGED_FILES="$(git diff --name-only --diff-filter=ACM "${MERGE_BASE}" 2>/dev/null || true)"
+  if [[ -n "${CHANGED_FILES}" ]]; then
+    echo "Running pre-commit on changed files:"
+    echo "${CHANGED_FILES}"
+    # Convert newline-separated files to space-separated and run pre-commit once
+    CHANGED_FILES_SPACE="$(echo "${CHANGED_FILES}" | tr '\n' ' ')"
+    python3 -m pre_commit run --files ${CHANGED_FILES_SPACE}
+  else
+    echo "No files changed relative to merge base and worktree. Skipping pre-commit."
+  fi
 elif [[ "${#FILES[@]}" -gt 0 ]]; then
   python3 -m pre_commit run --files "${FILES[@]}"
 fi

@@ -105,7 +115,7 @@ echo 'tile-lang clang-tidy: Check Start'
 if [[ -x "$(command -v run-clang-tidy)" ]]; then
   # Check if clang-tidy is available
   if [[ ! -x "$(command -v clang-tidy)" ]]; then
-    python3 -m pip install --upgrade --requirements "${ROOT}/requirements-lint.txt"
+    python3 -m pip install --upgrade --requirements "${ROOT}/requirements-lint.txt" --user
   fi
   # Get clang-tidy version
   CLANG_TIDY_VERSION="$(clang-tidy --version | head -n1 | awk '{print $4}')"
maint/gemm_v2/correctness_evaluation.py
0 → 100644
View file @
bbbf4207
# pytest correctness_evaluation.py -n 32
import
pytest
from
tilelang
import
tvm
as
tvm
import
tilelang.testing
def
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
trans_A
,
trans_B
,
in_dtype
,
out_dtype
,
accum_dtype
,
num_stages
,
threads
,
):
A_shape
=
(
K
,
M
)
if
trans_A
else
(
M
,
K
)
B_shape
=
(
N
,
K
)
if
trans_B
else
(
K
,
N
)
A_shared_shape
=
(
block_K
,
block_M
)
if
trans_A
else
(
block_M
,
block_K
)
B_shared_shape
=
(
block_N
,
block_K
)
if
trans_B
else
(
block_K
,
block_N
)
import
tilelang.language
as
T
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
,
scope
=
"shared.dyn"
)
B_shared
=
T
.
alloc_shared
(
B_shared_shape
,
in_dtype
,
scope
=
"shared.dyn"
)
C_local
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
T
.
clear
(
C_local
)
for
k
in
T
.
Pipelined
(
T
.
ceildiv
(
K
,
block_K
),
num_stages
=
num_stages
):
if
trans_A
:
T
.
copy
(
A
[
k
*
block_K
,
by
*
block_M
],
A_shared
)
else
:
T
.
copy
(
A
[
by
*
block_M
,
k
*
block_K
],
A_shared
)
if
trans_B
:
T
.
copy
(
B
[
bx
*
block_N
,
k
*
block_K
],
B_shared
)
else
:
T
.
copy
(
B
[
k
*
block_K
,
bx
*
block_N
],
B_shared
)
T
.
gemm
(
A_shared
,
B_shared
,
C_local
,
trans_A
,
trans_B
)
T
.
copy
(
C_local
,
C
[
by
*
block_M
,
bx
*
block_N
])
return
main
def
_compile_and_check
(
program
,
trans_A
,
trans_B
,
in_dtype
,
out_dtype
,
):
kernel
=
tilelang
.
compile
(
program
,
out_idx
=
[
2
],
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
# tilelang.PassConfigKey.TIR_USE_ASYNC_COPY: False,
})
print
(
kernel
.
get_kernel_source
())
profiler
=
kernel
.
get_profiler
(
tensor_supply_type
=
tilelang
.
TensorSupplyType
.
Normal
)
def
ref_program(A, B):
        import torch

        if trans_A:
            A = A.T
        if trans_B:
            B = B.T
        if in_dtype == "float32":
            A = (A.view(torch.int32) - 0x1000).view(torch.float32)
            B = (B.view(torch.int32) - 0x1000).view(torch.float32)
        C = torch.matmul(A.to(torch.float), B.to(torch.float))
        C = C.to(torch.__getattribute__(out_dtype))
        return C

    profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
    print("assert_allclose")


def run_gemm(
    M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum,
    block_M, block_N, block_K, num_stages=2, num_threads=128,
):
    if block_N >= 256 or block_M >= 256 or block_K >= 256:
        num_stages = 0
    program = matmul(
        M, N, K, block_M, block_N, block_K, trans_A, trans_B,
        in_dtype, out_dtype, dtypeAccum, num_stages, num_threads,
    )
    _compile_and_check(program, trans_A, trans_B, in_dtype, out_dtype)


def matmul_rs(
    M, N, K, block_M, block_N, block_K, trans_A, trans_B,
    in_dtype, out_dtype, accum_dtype, num_stages, threads,
):
    A_shape = (K, M) if trans_A else (M, K)
    B_shape = (N, K) if trans_B else (K, N)
    A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
    B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
    A_frag_shape = A_shared_shape

    import tilelang.language as T

    @T.prim_func
    def main(
        A: T.Tensor(A_shape, in_dtype),
        B: T.Tensor(B_shape, in_dtype),
        C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope="shared.dyn")
            B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope="shared.dyn")
            A_frag = T.alloc_fragment(A_frag_shape, in_dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                if trans_A:
                    T.copy(A[k * block_K, by * block_M], A_shared)
                else:
                    T.copy(A[by * block_M, k * block_K], A_shared)
                if trans_B:
                    T.copy(B[bx * block_N, k * block_K], B_shared)
                else:
                    T.copy(B[k * block_K, bx * block_N], B_shared)
                T.copy(A_shared, A_frag)
                T.gemm_v2(A_frag, B_shared, C_local, trans_A, trans_B)
                # T.gemm(A_frag, B_shared, C_local, trans_A, trans_B)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main


def run_gemm_rs(
    M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum,
    block_M, block_N, block_K, num_stages=2, num_threads=128,
):
    if block_N >= 256 or block_M >= 256 or block_K >= 256:
        num_stages = 0
    program = matmul_rs(
        M, N, K, block_M, block_N, block_K, trans_A, trans_B,
        in_dtype, out_dtype, dtypeAccum, num_stages, num_threads,
    )
    _compile_and_check(program, trans_A, trans_B, in_dtype, out_dtype)


def matmul_sr(
    M, N, K, block_M, block_N, block_K, trans_A, trans_B,
    in_dtype, out_dtype, accum_dtype, num_stages, threads,
):
    A_shape = (K, M) if trans_A else (M, K)
    B_shape = (N, K) if trans_B else (K, N)
    A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
    B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
    B_frag_shape = B_shared_shape

    import tilelang.language as T

    @T.prim_func
    def main(
        A: T.Tensor(A_shape, in_dtype),
        B: T.Tensor(B_shape, in_dtype),
        C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope="shared.dyn")
            B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope="shared.dyn")
            B_frag = T.alloc_fragment(B_frag_shape, in_dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                if trans_A:
                    T.copy(A[k * block_K, by * block_M], A_shared)
                else:
                    T.copy(A[by * block_M, k * block_K], A_shared)
                if trans_B:
                    T.copy(B[bx * block_N, k * block_K], B_shared)
                else:
                    T.copy(B[k * block_K, bx * block_N], B_shared)
                T.copy(B_shared, B_frag)
                T.gemm_v2(A_shared, B_frag, C_local, trans_A, trans_B)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main


def run_gemm_sr(
    M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum,
    block_M, block_N, block_K, num_stages=2, num_threads=128,
):
    if block_N >= 256 or block_M >= 256 or block_K >= 256:
        num_stages = 0
    program = matmul_sr(
        M, N, K, block_M, block_N, block_K, trans_A, trans_B,
        in_dtype, out_dtype, dtypeAccum, num_stages, num_threads,
    )
    _compile_and_check(program, trans_A, trans_B, in_dtype, out_dtype)


def matmul_rr(
    M, N, K, block_M, block_N, block_K, trans_A, trans_B,
    in_dtype, out_dtype, accum_dtype, num_stages, threads,
):
    A_shape = (K, M) if trans_A else (M, K)
    B_shape = (N, K) if trans_B else (K, N)
    A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
    B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
    A_frag_shape = A_shared_shape
    B_frag_shape = B_shared_shape

    import tilelang.language as T

    @T.prim_func
    def main(
        A: T.Tensor(A_shape, in_dtype),
        B: T.Tensor(B_shape, in_dtype),
        C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope="shared.dyn")
            B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope="shared.dyn")
            A_frag = T.alloc_fragment(A_frag_shape, in_dtype)
            B_frag = T.alloc_fragment(B_frag_shape, in_dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                if trans_A:
                    T.copy(A[k * block_K, by * block_M], A_shared)
                else:
                    T.copy(A[by * block_M, k * block_K], A_shared)
                if trans_B:
                    T.copy(B[bx * block_N, k * block_K], B_shared)
                else:
                    T.copy(B[k * block_K, bx * block_N], B_shared)
                T.copy(A_shared, A_frag)
                T.copy(B_shared, B_frag)
                T.gemm_v2(A_frag, B_frag, C_local, trans_A, trans_B)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main


def run_gemm_rr(
    M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum,
    block_M, block_N, block_K, num_stages=2, num_threads=128,
):
    if block_N >= 256 or block_M >= 256 or block_K >= 256:
        num_stages = 0
    program = matmul_rr(
        M, N, K, block_M, block_N, block_K, trans_A, trans_B,
        in_dtype, out_dtype, dtypeAccum, num_stages, num_threads,
    )
    _compile_and_check(program, trans_A, trans_B, in_dtype, out_dtype)


M_VALUES = [64, 128, 256]
N_VALUES = [16, 32, 64, 128, 256, 512]
K_VALUES = [16, 32, 64, 128]
K_VALUES_8Bit = [32, 64, 128]

FALSE_TRUE_CASES = (
    [pytest.param(k, "float16", "float16", "float16", id=f"K{k}-float16-float16-float16") for k in K_VALUES]
    + [pytest.param(k, "int8", "int32", "int32", id="K32-int8-int32-int32") for k in K_VALUES_8Bit]
    + [pytest.param(k, "float8_e5m2", "float32", "float32", id="K32-float8_e5m2-float32-float32") for k in K_VALUES_8Bit]
    + [pytest.param(k, "float8_e4m3", "float32", "float32", id="K32-float8_e4m3-float32-float32") for k in K_VALUES_8Bit]
)


def _ensure_torch_dtypes(*dtype_names):
    import torch

    for name in set(dtype_names):
        if not hasattr(torch, name):
            pytest.skip(f"Torch does not expose dtype {name}")


def run_gemm_rs_false_true(m, n, k, in_dtype, out_dtype, accum_dtype):
    run_gemm_rs(m, n, k * 3, False, True, in_dtype, out_dtype, accum_dtype, m, n, k)


def run_gemm_rs_false_false(m, n, k):
    run_gemm_rs(m, n, k * 3, False, False, "float16", "float16", "float16", m, n, k)


def run_gemm_rs_true_false(m, n, k):
    run_gemm_rs(m, n, k * 3, True, False, "float16", "float16", "float16", m, n, k)


def run_gemm_rs_true_true(m, n, k):
    run_gemm_rs(m, n, k * 3, True, True, "float16", "float16", "float16", m, n, k)


def run_gemm_sr_false_true(m, n, k, in_dtype, out_dtype, accum_dtype):
    run_gemm_sr(m, n, k * 3, False, True, in_dtype, out_dtype, accum_dtype, m, n, k)


def run_gemm_sr_false_false(m, n, k):
    run_gemm_sr(m, n, k * 3, False, False, "float16", "float16", "float16", m, n, k)


def run_gemm_sr_true_false(m, n, k):
    run_gemm_sr(m, n, k * 3, True, False, "float16", "float16", "float16", m, n, k)


def run_gemm_sr_true_true(m, n, k):
    run_gemm_sr(m, n, k * 3, True, True, "float16", "float16", "float16", m, n, k)


def run_gemm_rr_false_true(m, n, k, in_dtype, out_dtype, accum_dtype):
    run_gemm_rr(m, n, k * 3, False, True, in_dtype, out_dtype, accum_dtype, m, n, k)


def run_gemm_rr_false_false(m, n, k):
    run_gemm_rr(m, n, k * 3, False, False, "float16", "float16", "float16", m, n, k)


def run_gemm_rr_true_false(m, n, k):
    run_gemm_rr(m, n, k * 3, True, False, "float16", "float16", "float16", m, n, k)


def run_gemm_rr_true_true(m, n, k):
    run_gemm_rr(m, n, k * 3, True, True, "float16", "float16", "float16", m, n, k)


TRANS_CASES = [
    pytest.param(False, False, id="nn"),
    pytest.param(False, True, id="nt"),
    pytest.param(True, False, id="tn"),
    pytest.param(True, True, id="tt"),
]


@pytest.fixture(scope="module", autouse=True)
def _setup_tilelang_environment():
    tilelang.disable_cache()
    tilelang.testing.set_random_seed(42)


@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}")
@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}")
@pytest.mark.parametrize("k,in_dtype,out_dtype,accum_dtype", FALSE_TRUE_CASES)
def test_gemm_false_true(m, n, k, in_dtype, out_dtype, accum_dtype):
    import torch

    required_torch_attrs = {in_dtype, out_dtype, accum_dtype}
    for attr in required_torch_attrs:
        if not hasattr(torch, attr):
            pytest.skip(f"Torch does not expose dtype {attr}")
    run_gemm(m, n, k * 3, False, True, in_dtype, out_dtype, accum_dtype, m, n, k)


@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}")
@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}")
@pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}")
def test_gemm_false_false(m, n, k):
    run_gemm(m, n, k * 3, False, False, "float16", "float16", "float16", m, n, k)


@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}")
@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}")
@pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}")
def test_gemm_true_false(m, n, k):
    run_gemm(m, n, k * 3, True, False, "float16", "float16", "float16", m, n, k)


@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}")
@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}")
@pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}")
def test_gemm_true_true(m, n, k):
    run_gemm(m, n, k * 3, True, True, "float16", "float16", "float16", m, n, k)


@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}")
@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}")
@pytest.mark.parametrize("k,in_dtype,out_dtype,accum_dtype", FALSE_TRUE_CASES)
def test_gemm_rs_false_true(m, n, k, in_dtype, out_dtype, accum_dtype):
    _ensure_torch_dtypes(in_dtype, out_dtype, accum_dtype)
    run_gemm_rs_false_true(m, n, k, in_dtype, out_dtype, accum_dtype)


@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}")
@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}")
@pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}")
def test_gemm_rs_false_false(m, n, k):
    _ensure_torch_dtypes("float16")
    run_gemm_rs_false_false(m, n, k)


@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}")
@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}")
@pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}")
def test_gemm_rs_true_false(m, n, k):
    _ensure_torch_dtypes("float16")
    run_gemm_rs_true_false(m, n, k)


@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}")
@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}")
@pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}")
def test_gemm_rs_true_true(m, n, k):
    _ensure_torch_dtypes("float16")
    run_gemm_rs_true_true(m, n, k)


@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}")
@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}")
@pytest.mark.parametrize("k,in_dtype,out_dtype,accum_dtype", FALSE_TRUE_CASES)
def test_gemm_sr_false_true(m, n, k, in_dtype, out_dtype, accum_dtype):
    _ensure_torch_dtypes(in_dtype, out_dtype, accum_dtype)
    run_gemm_sr_false_true(m, n, k, in_dtype, out_dtype, accum_dtype)


@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}")
@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}")
@pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}")
def test_gemm_sr_false_false(m, n, k):
    _ensure_torch_dtypes("float16")
    run_gemm_sr_false_false(m, n, k)


@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}")
@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}")
@pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}")
def test_gemm_sr_true_false(m, n, k):
    _ensure_torch_dtypes("float16")
    run_gemm_sr_true_false(m, n, k)


@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}")
@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}")
@pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}")
def test_gemm_sr_true_true(m, n, k):
    _ensure_torch_dtypes("float16")
    run_gemm_sr_true_true(m, n, k)


@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}")
@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}")
@pytest.mark.parametrize("k,in_dtype,out_dtype,accum_dtype", FALSE_TRUE_CASES)
def test_gemm_rr_false_true(m, n, k, in_dtype, out_dtype, accum_dtype):
    _ensure_torch_dtypes(in_dtype, out_dtype, accum_dtype)
    run_gemm_rr_false_true(m, n, k, in_dtype, out_dtype, accum_dtype)


@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}")
@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}")
@pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}")
def test_gemm_rr_false_false(m, n, k):
    _ensure_torch_dtypes("float16")
    run_gemm_rr_false_false(m, n, k)


@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}")
@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}")
@pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}")
def test_gemm_rr_true_false(m, n, k):
    _ensure_torch_dtypes("float16")
    run_gemm_rr_true_false(m, n, k)


@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}")
@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}")
@pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}")
def test_gemm_rr_true_true(m, n, k):
    _ensure_torch_dtypes("float16")
    run_gemm_rr_true_true(m, n, k)


if __name__ == "__main__":
    tilelang.testing.main()
# # Test Pass
# for m in [64, 128, 256]:
# for n in [16, 32, 64, 128]:
# for k in [16, 32, 64, 128]:
# print(f"======================= Test {m} {n} {k} False True =============================")
# run_gemm(m, n, k * 3, False, True, "float16", "float16", "float16", m, n, k, 2, 128)
# print(f"Test {m} {n} {k} Pass")
# # Test Pass
# for m in [64, 128, 256]:
# for n in [16, 32, 64, 128]:
# for k in [16, 32, 64, 128]:
# print(f"======================= Test {m} {n} {k} False False =============================")
# run_gemm(m, n, k * 3, False, False, "float16", "float16", "float16", m, n, k, 2, 128)
# print(f"Test {m} {n} {k} Pass")
# # Test Pass
# for m in [64, 128, 256]:
# for n in [16, 32, 64, 128]:
# for k in [16, 32, 64, 128]:
# print(f"======================= Test {m} {n} {k} True False =============================")
# run_gemm(m, n, k * 3, True, False, "float16", "float16", "float16", m, n, k, 2, 128)
# print(f"Test {m}, {n} {k} Pass")
# print(f"Test {n} Pass")
# # Test Pass
# for m in [64, 128, 256]:
# for n in [16, 32, 64, 128]:
# for k in [16, 32, 64, 128]:
# print(f"======================= Test {m} {n} {k} True True =============================")
# run_gemm(m, n, k * 3, True, True, "float16", "float16", "float16", m, n, k, 2, 128)
# print(f"Test {m}, {n} {k} Pass")
# print(f"Test {n} Pass")
# Test Pass
# for m in [64, 128, 256]:
# for n in [16, 32, 64, 128]:
# for k in [16, 32, 64, 128]:
# print(f"======================= Test {m} {n} {k} False True =============================")
# run_gemm_rs(m, n, k * 3, False, True, "float16", "float16", "float16", m, n, k, 2, 128)
# print(f"Test {m} {n} {k} Pass")
# for n in [16, 32, 64, 128]:
# for k in [16, 32, 64, 128]:
# run_gemm_rs(64, n, k, False, False, "float16", "float16", "float16", 64, n, k, 0, 256)
# print(f"Test {64} {n} {k} Pass")
# for n in [16, 32, 64, 128]:
# for k in [16, 32, 64, 128]:
# run_gemm(64, n, k, False, False, "float16", "float16", "float16", 64, n, k, 0, 256)
# print(f"Test {64} {n} {k} Pass")
maint/gemm_v2/correctness_evaluation_sm70.py (new file, 0 → 100644)
# pytest maint/gemm_v2/correctness_evaluation_sm70.py -n 32
import pytest
from tilelang import tvm as tvm
import tilelang.testing


def matmul(
    M, N, K, block_M, block_N, block_K, trans_A, trans_B,
    in_dtype, out_dtype, accum_dtype, num_stages, threads,
):
    A_shape = (K, M) if trans_A else (M, K)
    B_shape = (N, K) if trans_B else (K, N)
    A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
    B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)

    import tilelang.language as T

    @T.prim_func
    def main(
        A: T.Tensor(A_shape, in_dtype),
        B: T.Tensor(B_shape, in_dtype),
        C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope="shared.dyn")
            B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope="shared.dyn")
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                if trans_A:
                    T.copy(A[k * block_K, by * block_M], A_shared)
                else:
                    T.copy(A[by * block_M, k * block_K], A_shared)
                if trans_B:
                    T.copy(B[bx * block_N, k * block_K], B_shared)
                else:
                    T.copy(B[k * block_K, bx * block_N], B_shared)
                # T.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
                T.gemm_v2(A_shared, B_shared, C_local, trans_A, trans_B)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main


def _compile_and_check(program, trans_A, trans_B, in_dtype, out_dtype):
    kernel = tilelang.compile(
        program,
        out_idx=[2],
        pass_configs={
            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
            # tilelang.PassConfigKey.TIR_USE_ASYNC_COPY: False,
        })
    print(kernel.get_kernel_source())
    profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)

    def ref_program(A, B):
        import torch

        if trans_A:
            A = A.T
        if trans_B:
            B = B.T
        if in_dtype == "float32":
            A = (A.view(torch.int32) - 0x1000).view(torch.float32)
            B = (B.view(torch.int32) - 0x1000).view(torch.float32)
        C = torch.matmul(A.to(torch.float), B.to(torch.float))
        C = C.to(torch.__getattribute__(out_dtype))
        return C

    profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
    print("assert_allclose")


def run_gemm(
    M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum,
    block_M, block_N, block_K, num_stages=3, num_threads=128,
):
    program = matmul(
        M, N, K, block_M, block_N, block_K, trans_A, trans_B,
        in_dtype, out_dtype, dtypeAccum, num_stages, num_threads,
    )
    _compile_and_check(program, trans_A, trans_B, in_dtype, out_dtype)


def matmul_rs(
    M, N, K, block_M, block_N, block_K, trans_A, trans_B,
    in_dtype, out_dtype, accum_dtype, num_stages, threads,
):
    A_shape = (K, M) if trans_A else (M, K)
    B_shape = (N, K) if trans_B else (K, N)
    A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
    B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
    A_frag_shape = A_shared_shape

    import tilelang.language as T

    @T.prim_func
    def main(
        A: T.Tensor(A_shape, in_dtype),
        B: T.Tensor(B_shape, in_dtype),
        C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope="shared.dyn")
            B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope="shared.dyn")
            A_frag = T.alloc_fragment(A_frag_shape, in_dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                if trans_A:
                    T.copy(A[k * block_K, by * block_M], A_shared)
                else:
                    T.copy(A[by * block_M, k * block_K], A_shared)
                if trans_B:
                    T.copy(B[bx * block_N, k * block_K], B_shared)
                else:
                    T.copy(B[k * block_K, bx * block_N], B_shared)
                T.copy(A_shared, A_frag)
                T.gemm_v2(A_frag, B_shared, C_local, trans_A, trans_B)
                # T.gemm(A_frag, B_shared, C_local, trans_A, trans_B)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main


def run_gemm_rs(
    M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum,
    block_M, block_N, block_K, num_stages=3, num_threads=128,
):
    program = matmul_rs(
        M, N, K, block_M, block_N, block_K, trans_A, trans_B,
        in_dtype, out_dtype, dtypeAccum, num_stages, num_threads,
    )
    _compile_and_check(program, trans_A, trans_B, in_dtype, out_dtype)


M_VALUES = [64, 128]
N_VALUES = [32, 64, 128]
K_VALUES = [16, 32, 64]

FALSE_TRUE_CASES = (
    [pytest.param(k, "float16", "float16", "float16", id=f"K{k}-float16-float16-float16") for k in K_VALUES]
    + [pytest.param(k, "float16", "float16", "float32", id=f"K{k}-float16-float16-float32") for k in K_VALUES]
)


def _ensure_torch_dtypes(*dtype_names):
    import torch

    for name in set(dtype_names):
        if not hasattr(torch, name):
            pytest.skip(f"Torch does not expose dtype {name}")


def run_gemm_rs_false_true(m, n, k, in_dtype, out_dtype, accum_dtype):
    run_gemm_rs(m, n, k * 3, False, True, in_dtype, out_dtype, accum_dtype, m, n, k, 2, 128)


def run_gemm_rs_false_false(m, n, k):
    run_gemm_rs(m, n, k * 3, False, False, "float16", "float16", "float16", m, n, k, 2, 128)


TRANS_CASES = [
    pytest.param(False, False, id="nn"),
    pytest.param(False, True, id="nt"),
    pytest.param(True, False, id="tn"),
    pytest.param(True, True, id="tt"),
]


@pytest.fixture(scope="module", autouse=True)
def _setup_tilelang_environment():
    tilelang.disable_cache()
    tilelang.testing.set_random_seed(42)


@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}")
@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}")
@pytest.mark.parametrize("k,in_dtype,out_dtype,accum_dtype", FALSE_TRUE_CASES)
def test_gemm_false_true(m, n, k, in_dtype, out_dtype, accum_dtype):
    import torch

    required_torch_attrs = {in_dtype, out_dtype, accum_dtype}
    for attr in required_torch_attrs:
        if not hasattr(torch, attr):
            pytest.skip(f"Torch does not expose dtype {attr}")
    run_gemm(m, n, k * 3, False, True, in_dtype, out_dtype, accum_dtype, m, n, k, 2, 128)


@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}")
@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}")
@pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}")
def test_gemm_false_false(m, n, k):
    run_gemm(m, n, k * 3, False, False, "float16", "float16", "float16", m, n, k, 2, 128)


@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}")
@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}")
@pytest.mark.parametrize("k,in_dtype,out_dtype,accum_dtype", FALSE_TRUE_CASES)
def test_gemm_rs_false_true(m, n, k, in_dtype, out_dtype, accum_dtype):
    _ensure_torch_dtypes(in_dtype, out_dtype, accum_dtype)
    run_gemm_rs_false_true(m, n, k, in_dtype, out_dtype, accum_dtype)


@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}")
@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}")
@pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}")
def test_gemm_rs_false_false(m, n, k):
    _ensure_torch_dtypes("float16")
    run_gemm_rs_false_false(m, n, k)


if __name__ == "__main__":
    tilelang.testing.main()
# # Test Pass
# for m in [64, 128]:
# for n in [16, 32, 64, 128]:
# for k in [16, 32, 64]:
# print(f"======================= Test {m} {n} {k} False True =============================")
# run_gemm(m, n, k * 3, False, True, "float16", "float16", "float16", m, n, k, 2, 128)
# print(f"Test {m} {n} {k} Pass")
# # Test Pass
# for m in [64, 128]:
# for n in [16, 32, 64, 128]:
# for k in [16, 32, 64]:
# print(f"======================= Test {m} {n} {k} False False =============================")
# run_gemm(m, n, k * 3, False, False, "float16", "float16", "float16", m, n, k, 2, 128)
# print(f"Test {m} {n} {k} Pass")
maint/gemm_v2/correctness_evaluation_tcgen05.py (new file, 0 → 100644)
# pytest correctness_evaluation.py -n 32
import pytest
from tilelang import tvm as tvm
import tilelang.testing
import tilelang.language as T


def matmul(
    M, N, K, block_M, block_N, block_K, trans_A, trans_B,
    in_dtype, out_dtype, accum_dtype, num_stages, threads,
):
    A_shape = (K, M) if trans_A else (M, K)
    B_shape = (N, K) if trans_B else (K, N)
    A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
    B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)

    @T.prim_func
    def main(
        A: T.Tensor(A_shape, in_dtype),
        B: T.Tensor(B_shape, in_dtype),
        C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype)
            C_tmem = T.alloc_tmem([block_M, block_N], accum_dtype)
            mbar = T.alloc_barrier(1)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            C_shared = T.alloc_shared((block_M, block_N), out_dtype)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                T.copy(A[by * block_M, k * block_K], A_shared)
                T.copy(B[bx * block_N, k * block_K], B_shared)
                T.gemm(
                    A_shared,
                    B_shared,
                    C_tmem,
                    trans_A,
                    trans_B,
                    mbar=mbar,
                    wg_wait=-1,
                    clear_accum=k == 0)
                T.mbarrier_wait_parity(mbar, k % 2)
            T.copy(C_tmem, C_local)
            T.copy(C_local, C_shared)
            T.copy(C_shared, C[by * block_M, bx * block_N])

    return main


def _compile_and_check(program, trans_A, trans_B, in_dtype, out_dtype):
    kernel = tilelang.compile(
        program,
        out_idx=[2],
        pass_configs={
            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        })
    print(kernel.get_kernel_source())
    profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)

    def ref_program(A, B):
        import torch

        if trans_A:
            A = A.T
        if trans_B:
            B = B.T
        if in_dtype == "float32":
            A = (A.view(torch.int32) - 0x1000).view(torch.float32)
            B = (B.view(torch.int32) - 0x1000).view(torch.float32)
        C = torch.matmul(A.to(torch.float), B.to(torch.float))
        C = C.to(torch.__getattribute__(out_dtype))
        return C

    profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
    print("assert_allclose")


def run_gemm(
    M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum,
    block_M, block_N, block_K, num_stages=2, num_threads=128,
):
    if block_N >= 256 or block_M >= 256 or block_K >= 256:
        num_stages = 0
    program = matmul(
        M, N, K, block_M, block_N, block_K, trans_A, trans_B,
        in_dtype, out_dtype, dtypeAccum, num_stages, num_threads,
    )
    _compile_and_check(program, trans_A, trans_B, in_dtype, out_dtype)


M_VALUES = [32, 64, 128, 256]
N_VALUES = [64, 128, 256, 512]
K_VALUES = [16, 32, 64, 128]
K_VALUES_8Bit = [32, 64, 128]

FALSE_TRUE_CASES = (
    [pytest.param(k, "float16", "float32", "float32", id=f"K{k}-float16-float-float") for k in K_VALUES]
    + [pytest.param(k, "float8_e5m2", "float32", "float32", id="K32-float8_e5m2-float32-float32") for k in K_VALUES_8Bit]
)

TRANS_CASES = [
    pytest.param(False, True, id="nt"),
]


@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}")
@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}")
@pytest.mark.parametrize("k,in_dtype,out_dtype,accum_dtype", FALSE_TRUE_CASES)
def test_gemm_false_true(m, n, k, in_dtype, out_dtype, accum_dtype):
    import torch

    required_torch_attrs = {in_dtype, out_dtype, accum_dtype}
    for attr in required_torch_attrs:
        if not hasattr(torch, attr):
            pytest.skip(f"Torch does not expose dtype {attr}")
    run_gemm(m, n, k * 3, False, True, in_dtype, out_dtype, accum_dtype, m, n, k)


if __name__ == "__main__":
    # tilelang.testing.main()
    # # Test Pass
    # for m in [32, 64, 128, 256]:
    #     for n in [16, 32, 64, 128]:
    #         for k in [16, 32, 64, 128]:
    #             if m in [32, 64] and (n not in [64, 128, 256]):
    #                 continue
    #             print(f"======================= Test {m} {n} {k} False True =============================")
    #             run_gemm(m, n, k * 3, False, True, "float16", "float", "float", m, n, k, 2, 128)
    #             print(f"Test {m} {n} {k} Pass")
    # # Test Pass
    # for m in [32, 64, 128, 256]:
    #     for n in [16, 32, 64, 128]:
    #         for k in [32, 64, 128]:
    #             if m in [32, 64] and (n not in [64, 128, 256]):
    #                 continue
    #             print(f"======================= Test {m} {n} {k} False True =============================")
    #             run_gemm(m, n, k * 3, False, True, "float8_e5m2", "float", "float", m, n, k, 2, 128)
    #             print(f"Test {m} {n} {k} Pass")
    tilelang.disable_cache()
    run_gemm(32, 512, 16, False, True, "float16", "float32", "float32", 32, 512, 16, 0, 128)
    run_gemm(32, 512, 32, False, True, "float16", "float32", "float32", 32, 512, 32, 0, 128)
    run_gemm(32, 512, 64, False, True, "float16", "float32", "float32", 32, 512, 64, 0, 128)
    run_gemm(64, 512, 16, False, True, "float16", "float32", "float32", 64, 512, 16, 0, 128)
    run_gemm(64, 512, 16, False, True, "float16", "float32", "float32", 32, 512, 16, 0, 128)
    run_gemm(128, 512, 16, False, True, "float16", "float32", "float32", 128, 512, 16, 0, 128)
    # run_gemm(64, 512, 32, False, True, "float16", "float32", "float32", 64, 512, 32, 0, 128)
    # run_gemm(64, 512, 64, False, True, "float16", "float32", "float32", 64, 512, 64, 0, 128)
    # run_gemm(128, 512, 16, False, True, "float16", "float32", "float32", 128, 512, 16, 0, 128)
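A hedged follow-up sketch, not part of the commit: this variant accumulates in tensor memory (T.alloc_tmem) and waits on an mbarrier rather than accumulating in registers, and the __main__ block above drives a few single-tile shapes directly. One more run in that same style, taken from the commented-out lines, would be:

# Hypothetical extra invocation, not in the commit; mirrors the shapes exercised above
# (single K tile, num_stages=0, 128 threads) on a tcgen05-capable device.
tilelang.disable_cache()
run_gemm(64, 512, 32, False, True, "float16", "float32", "float32", 64, 512, 32, 0, 128)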
maint/gemm_v2/latency.py (new file, 0 → 100644)
import tilelang
import tilelang.language as T
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--use_v2", action="store_true")
args = parser.parse_args()
use_v2 = args.use_v2


# @tilelang.jit(target="cuda")
# target currently can be "cuda" or "hip" or "cpu".
# if not specified, it will be inferred from the input tensors during compile time
@tilelang.jit
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):

    @T.prim_func
    def matmul_relu_kernel(
        A: T.Tensor((M, K), dtype),
        B: T.Tensor((K, N), dtype),
        C: T.Tensor((M, N), dtype),
    ):
        # Initialize Kernel Context
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            # Enable rasterization for better L2 cache locality (Optional)
            # T.use_swizzle(panel_size=10, enable=True)

            # Clear local accumulation
            T.clear(C_local)

            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                # Copy tile of A
                # This is a sugar syntax for parallelized copy
                T.copy(A[by * block_M, ko * block_K], A_shared)

                # Copy tile of B
                T.copy(B[ko * block_K, bx * block_N], B_shared)

                # Perform a tile-level GEMM on the shared buffers
                # Currently we dispatch to the cute/hip on Nvidia/AMD GPUs
                if use_v2:
                    T.gemm_v2(A_shared, B_shared, C_local)
                else:
                    T.gemm_v1(A_shared, B_shared, C_local)

            # relu
            for i, j in T.Parallel(block_M, block_N):
                C_local[i, j] = T.max(C_local[i, j], 0)

            # Copy result back to global memory
            T.copy(C_local, C[by * block_M, bx * block_N])

    return matmul_relu_kernel


M = 16384
# M = T.dynamic("m") if you want to use dynamic shape
N = 16384
K = 16384
block_M = 128
block_N = 128
block_K = 32

# 1. Define the kernel (matmul) and compile/lower it into an executable module
matmul_relu_kernel = matmul(M, N, K, block_M, block_N, block_K)

# 3. Test the kernel in Python with PyTorch data
import torch

# Create random input tensors on the GPU
a = torch.randn(M, K, device="cuda", dtype=torch.float16)
b = torch.randn(K, N, device="cuda", dtype=torch.float16)
c = torch.empty(M, N, device="cuda", dtype=torch.float16)

# Run the kernel through the Profiler
matmul_relu_kernel(a, b, c)
print(c)

# Reference multiplication using PyTorch
ref_c = torch.relu(a @ b)

# Validate correctness
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
print("Kernel output matches PyTorch reference.")

# 4. Retrieve and inspect the generated CUDA source (optional)
# cuda_source = jit_kernel.get_kernel_source()
# print("Generated CUDA kernel:\n", cuda_source)

# 5.Profile latency with kernel
profiler = matmul_relu_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
latency = profiler.do_bench()
print(f"Latency: {latency} ms")
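A hedged usage sketch, not part of the commit: since the --use_v2 switch is read at import time, the two lowering paths can be compared by running the script once per flag; file paths below assume the repository root as working directory.

# Hypothetical comparison driver, not in the commit.
import subprocess

subprocess.run(["python", "maint/gemm_v2/latency.py"], check=True)               # T.gemm_v1 path
subprocess.run(["python", "maint/gemm_v2/latency.py", "--use_v2"], check=True)   # T.gemm_v2 path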
maint/gemm_v2/latency_gemm.py (new file, 0 → 100644)
import tilelang
import tilelang.language as T
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--use_v2", action="store_true")
args = parser.parse_args()
use_v2 = args.use_v2


# @tilelang.jit(target="cuda")
# target currently can be "cuda" or "hip" or "cpu".
# if not specified, it will be inferred from the input tensors during compile time
@tilelang.jit
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):

    @T.prim_func
    def matmul_relu_kernel(
        A: T.Tensor((M, K), dtype),
        B: T.Tensor((K, N), dtype),
        C: T.Tensor((M, N), dtype),
    ):
        # Initialize Kernel Context
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            # Enable rasterization for better L2 cache locality (Optional)
            # T.use_swizzle(panel_size=10, enable=True)

            # Clear local accumulation
            T.clear(C_local)

            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                # Copy tile of A
                # This is a sugar syntax for parallelized copy
                T.copy(A[by * block_M, ko * block_K], A_shared)

                # Copy tile of B
                T.copy(B[ko * block_K, bx * block_N], B_shared)

                # Perform a tile-level GEMM on the shared buffers
                # Currently we dispatch to the cute/hip on Nvidia/AMD GPUs
                if use_v2:
                    T.gemm_v2(A_shared, B_shared, C_local)
                else:
                    T.gemm_v1(A_shared, B_shared, C_local)

            # relu
            for i, j in T.Parallel(block_M, block_N):
                C_local[i, j] = T.max(C_local[i, j], 0)

            # Copy result back to global memory
            T.copy(C_local, C[by * block_M, bx * block_N])

    return matmul_relu_kernel


M = 16384
# M = T.dynamic("m") if you want to use dynamic shape
N = 16384
K = 16384
block_M = 128
block_N = 128
block_K = 64

# 1. Define the kernel (matmul) and compile/lower it into an executable module
matmul_relu_kernel = matmul(M, N, K, block_M, block_N, block_K)

# 3. Test the kernel in Python with PyTorch data
import torch

# Create random input tensors on the GPU
a = torch.randn(M, K, device="cuda", dtype=torch.float16)
b = torch.randn(K, N, device="cuda", dtype=torch.float16)
c = torch.empty(M, N, device="cuda", dtype=torch.float16)

# Run the kernel through the Profiler
matmul_relu_kernel(a, b, c)
print(c)

# Reference multiplication using PyTorch
ref_c = torch.relu(a @ b)

# Validate correctness
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
print("Kernel output matches PyTorch reference.")

# 4. Retrieve and inspect the generated CUDA source (optional)
# cuda_source = jit_kernel.get_kernel_source()
# print("Generated CUDA kernel:\n", cuda_source)

# 5.Profile latency with kernel
profiler = matmul_relu_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
latency = profiler.do_bench()
print(f"Latency: {latency} ms")
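This script differs from latency.py above only in the K tile (block_K = 64 instead of 32). A hedged follow-up sketch, not part of the commit: converting the latency printed by do_bench() into throughput for the fixed 16384^3 problem both scripts use, with the usual 2*M*N*K FLOP count.

# Hypothetical post-processing, not in the commit; latency_ms is an example placeholder.
M = N = K = 16384
latency_ms = 20.0  # substitute the value printed by do_bench()
tflops = 2 * M * N * K / (latency_ms * 1e-3) / 1e12
print(f"{tflops:.1f} TFLOPS")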
maint/gemm_v2/latency_mha_fwd_bhsd.py (new file, 0 → 100644)
import torch
import torch.nn.functional as F
import tilelang
from tilelang.autotuner import *
import tilelang.language as T
import itertools
import argparse
from functools import partial

parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=128, help='batch size')
parser.add_argument('--heads', type=int, default=16, help='heads')
parser.add_argument('--seq_q', type=int, default=1024, help='query sequence length')
parser.add_argument('--seq_kv', type=int, default=1024, help='key/value sequence length')
parser.add_argument('--dim', type=int, default=256, help='dim')
parser.add_argument('--is_causal', action='store_true', help='causal')
parser.add_argument('--tune', action='store_true', help='tune configs')
parser.add_argument("--use_v2", action="store_true")
args = parser.parse_args()
use_v2 = args.use_v2


def get_configs():
    iter_params = dict(block_M=[128], block_N=[128], num_stages=[2], threads=[256])
    return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]


@autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(
    out_idx=[3],
    pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
    })
def flashattn(batch, heads, seq_q, seq_kv, dim, is_causal,
              block_M=64, block_N=64, num_stages=0, threads=128):
    scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
    q_shape = [batch, heads, seq_q, dim]
    kv_shape = [batch, heads, seq_kv, dim]
    dtype = "float16"
    accum_dtype = "float"
    past_len = seq_kv - seq_q
    assert past_len >= 0, "seq_kv must be greater than or equal to seq_q"

    @T.macro
    def MMA0(
        K: T.Tensor(kv_shape, dtype),
        Q_shared: T.SharedBuffer([block_M, dim], dtype),
        K_shared: T.SharedBuffer([block_N, dim], dtype),
        acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
        k: T.int32,
        bx: T.int32,
        by: T.int32,
        bz: T.int32,
    ):
        T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared)
        if is_causal:
            for i, j in T.Parallel(block_M, block_N):
                q_idx = bx * block_M + i + past_len
                k_idx = k * block_N + j
                acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype))
        else:
            T.clear(acc_s)
        if use_v2:
            T.gemm_v2(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
        else:
            T.gemm_v1(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)

    @T.macro
    def MMA1(
        V: T.Tensor(kv_shape, dtype),
        V_shared: T.SharedBuffer([block_N, dim], dtype),
        acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
        acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
        k: T.int32,
        by: T.int32,
        bz: T.int32,
    ):
        T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared)
        # T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
        if use_v2:
            T.gemm_v2(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
        else:
            T.gemm_v1(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)

    @T.macro
    def Softmax(
        acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
        acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
        scores_max: T.FragmentBuffer([block_M], accum_dtype),
        scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
        scores_scale: T.FragmentBuffer([block_M], accum_dtype),
        scores_sum: T.FragmentBuffer([block_M], accum_dtype),
        logsum: T.FragmentBuffer([block_M], accum_dtype),
    ):
        T.copy(scores_max, scores_max_prev)
        T.fill(scores_max, -T.infinity(accum_dtype))
        T.reduce_max(acc_s, scores_max, dim=1, clear=False)
        # To do causal softmax, we need to set the scores_max to 0 if it is -inf
        # This process is called Check_inf in FlashAttention3 code, and it only need to be done
        # in the first ceil_div(kBlockM, kBlockN) steps.
        # for i in T.Parallel(block_M):
        #     scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
        for i in T.Parallel(block_M):
            scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
        for i, j in T.Parallel(block_M, block_N):
            # Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
            # max * log_2(e)) This allows the compiler to use the ffma
            # instruction instead of fadd and fmul separately.
            acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
        T.reduce_sum(acc_s, scores_sum, dim=1)
        for i in T.Parallel(block_M):
            logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
        T.copy(acc_s, acc_s_cast)

    @T.macro
    def Rescale(
        acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
        scores_scale: T.FragmentBuffer([block_M], accum_dtype),
    ):
        for i, j in T.Parallel(block_M, dim):
            acc_o[i, j] *= scores_scale[i]

    @T.prim_func
    def main(
        Q: T.Tensor(q_shape, dtype),
        K: T.Tensor(kv_shape, dtype),
        V: T.Tensor(kv_shape, dtype),
        Output: T.Tensor(q_shape, dtype),
    ):
        with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz):
            Q_shared = T.alloc_shared([block_M, dim], dtype)
            K_shared = T.alloc_shared([block_N, dim], dtype)
            V_shared = T.alloc_shared([block_N, dim], dtype)
            O_shared = T.alloc_shared([block_M, dim], dtype)
            acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
            acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
            acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
            scores_max = T.alloc_fragment([block_M], accum_dtype)
            scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
            scores_scale = T.alloc_fragment([block_M], accum_dtype)
            scores_sum = T.alloc_fragment([block_M], accum_dtype)
            logsum = T.alloc_fragment([block_M], accum_dtype)

            T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared)
            T.fill(acc_o, 0)
            T.fill(logsum, 0)
            T.fill(scores_max, -T.infinity(accum_dtype))

            loop_range = (
                T.min(T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N))
                if is_causal else T.ceildiv(seq_kv, block_N))

            for k in T.Pipelined(loop_range, num_stages=num_stages):
                MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
                Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum)
                Rescale(acc_o, scores_scale)
                MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
            for i, j in T.Parallel(block_M, dim):
                acc_o[i, j] /= logsum[i]
            T.copy(acc_o, O_shared)
            T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :])

    return main


def ref_program(Q, K, V, is_causal):
    dim = Q.size(-1)
    scores = torch.einsum('bhqd,bhkd->bhqk', Q, K)
    scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
    if is_causal:
        seq_q = Q.size(2)
        seq_kv = K.size(2)
        mask = torch.tril(torch.ones(seq_q, seq_kv, device=scores.device), seq_kv - seq_q)
        mask = mask.unsqueeze(0).unsqueeze(0)
        scores = scores.masked_fill(mask == 0, float('-inf'))
    attention_weights = F.softmax(scores, dim=-1)
    output = torch.einsum('bhqk,bhkd->bhqd', attention_weights, V)
    return output


def main(
    batch: int = 1,
    heads: int = 1,
    seq_q: int = 256,
    seq_kv: int = 256,
    dim: int = 64,
    is_causal: bool = False,
    tune: bool = False,
):
    flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim
    total_flops = 2 * flops_per_matmul
    if is_causal:
        total_flops *= 0.5

    if (not tune):
        kernel = flashattn(
            batch, heads, seq_q, seq_kv, dim, is_causal,
            block_M=64, block_N=64, num_stages=0, threads=128)
        print(kernel.get_kernel_source())
        ref_program_processed = partial(ref_program, is_causal=is_causal)
        profiler = kernel.get_profiler()
        profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01)
        print("All checks pass.")
        latency = profiler.do_bench(ref_program_processed, warmup=500)
        print(f"Ref: {latency:.2f} ms")
        print(f"Ref: {total_flops / latency * 1e-9:.2f} TFlops")
        latency = profiler.do_bench(warmup=500)
        print(f"Tile-lang: {latency:.2f} ms")
        print(f"Tile-lang: {total_flops / latency * 1e-9:.2f} TFlops")
    else:
        kernel = flashattn(batch, heads, seq_q, seq_kv, dim, is_causal)
        best_latency = kernel.latency
        best_config = kernel.config
        ref_latency = kernel.ref_latency
        print(f"Best latency: {best_latency}")
        print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
        print(f"Best config: {best_config}")
        print(f"Ref latency: {ref_latency}")


if __name__ == "__main__":
    tilelang.disable_cache()
    main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.is_causal, args.tune)
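A hedged usage sketch, not part of the commit: the flags below mirror the argparse definitions at the top of the file; --use_v2 switches both attention GEMMs from T.gemm_v1 to T.gemm_v2, and the defaults already match batch=128, heads=16, seq=1024, dim=256.

# Hypothetical benchmark invocation, not in the commit; path assumes the repository root.
import subprocess

subprocess.run([
    "python", "maint/gemm_v2/latency_mha_fwd_bhsd.py",
    "--batch", "128", "--heads", "16", "--seq_q", "1024", "--seq_kv", "1024",
    "--dim", "256", "--is_causal", "--use_v2",
], check=True)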
maint/scripts/docker_build_all.sh (deleted, 100755 → 0)
./maint/scripts/docker_local_distribute.sh 2>&1 | tee docker_local_distribute.log
./maint/scripts/docker_pypi_distribute.sh 2>&1 | tee docker_pypi_distribute.log