OpenDAS / tilelang · Commits

Unverified commit 667632cc, authored Dec 22, 2025 by guchaoyang, committed by GitHub on Dec 22, 2025.

    Merge branch 'main' into dcu

Parents: d6dd2ddf, a874e4e8
Changes: 313 files in total; showing 20 changed files below, with 452 additions and 411 deletions (+452 -411).
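Most of the churn in this merge is mechanical: dtype arguments that used to be strings ("float16", "float") become tilelang dtype objects (T.float16, T.float32), and single-quoted strings become double-quoted. A minimal sketch of the convention used by the new code, assuming a tilelang installation where tilelang.language exposes these dtype objects (as the added lines in the diff do):

import tilelang.language as T

# Style removed in this merge: dtypes passed as strings.
old_dtype, old_accum = "float16", "float"

# Style added in this merge: dtypes passed as T.* objects.
new_dtype, new_accum = T.float16, T.float32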
examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py            +17  -24
examples/topk/example_topk.py                                               +8   -12
examples/visual_layout_inference/visual_layout_inference.py                 +61  -0
examples/warp_specialize/example_warp_specialize_flashmla.py                +55  -93
examples/warp_specialize/example_warp_specialize_gemm_barrierpipe_stage2.py +6   -13
examples/warp_specialize/example_warp_specialize_gemm_copy_0_gemm_1.py      +4   -12
examples/warp_specialize/example_warp_specialize_gemm_copy_1_gemm_0.py      +4   -12
examples/warp_specialize/example_warp_specialize_gemm_copy_gemm_0_1.py      +8   -14
examples/warp_specialize/example_warp_specialize_gemm_softpipe_stage2.py    +1   -2
format.sh                                                                   +1   -1
maint/gemm_v2/correctness_evaluation.py                                     +93  -89
maint/gemm_v2/correctness_evaluation_sm70.py                                +29  -29
maint/gemm_v2/correctness_evaluation_tcgen05.py                             +32  -40
maint/gemm_v2/latency.py                                                    +4   -5
maint/gemm_v2/latency_gemm.py                                               +4   -5
maint/gemm_v2/latency_mha_fwd_bhsd.py                                       +42  -60
maint/host_checks/01_num_args_mismatch.py                                   +22  -0
maint/host_checks/02_pointer_type_error.py                                  +23  -0
maint/host_checks/03_ndim_mismatch.py                                       +19  -0
maint/host_checks/04_dtype_mismatch.py                                      +19  -0
examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py (view file @ 667632cc)

 import torch
 import tilelang
 from tilelang.utils.sparse import compress_sm90
-from tilelang.layout import make_metadata_layout
+from tilelang.layout import make_cutlass_metadata_layout
 from tilelang import language as T
 import tilelang.testing
...
@@ -24,32 +25,24 @@ def matmul_sp(
     A_shared_shape = (block_M, block_K // 2)
     B_shared_shape = (block_K, block_N)

-    import tilelang.language as T
-
     @T.prim_func
     def main(
-            A_sparse: T.Tensor(A_sparse_shape, in_dtype),
-            E: T.Tensor((M, K // 8), 'uint8'),
-            B: T.Tensor(B_shape, in_dtype),
-            C: T.Tensor((M, N), out_dtype),
+        A_sparse: T.Tensor(A_sparse_shape, in_dtype),
+        E: T.Tensor((M, K // 8), "uint8"),
+        B: T.Tensor(B_shape, in_dtype),
+        C: T.Tensor((M, N), out_dtype),
     ):
         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
             A_shared = T.alloc_shared(A_shared_shape, in_dtype)
             B_shared = T.alloc_shared(B_shared_shape, in_dtype)
-            E_shared = T.alloc_shared((block_M, block_K // 8), 'uint8')
+            E_shared = T.alloc_shared((block_M, block_K // 8), "uint8")
             C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
-            T.annotate_layout({
-                E: make_metadata_layout(E, mma_dtype="float16", arch="9.0", backend="cutlass", block_k=block_K),
-                E_shared: make_metadata_layout(E_shared, mma_dtype="float16", arch="9.0", backend="cutlass", block_k=block_K),
-            })
+            T.annotate_layout(
+                {
+                    E: make_cutlass_metadata_layout(E, mma_dtype=T.float16, arch="9.0", block_k=block_K),
+                    E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=T.float16, arch="9.0", block_k=block_K),
+                }
+            )
             T.clear(C_local)
             for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                 T.copy(E[by * block_M, k * block_K // 8], E_shared)
...
@@ -61,7 +54,7 @@ def matmul_sp(
     return main

-def generate_2_to_4_sparse_tensor(shape, dtype=torch.float32, device='cpu'):
+def generate_2_to_4_sparse_tensor(shape, dtype=torch.float32, device="cpu"):
     if shape[-1] % 4 != 0:
         raise ValueError("Last dimension must be divisible by 4 for 2:4 sparsity.")
...
@@ -106,9 +99,9 @@ def run_gemm_sp(
         num_threads,
     )
-    A = generate_2_to_4_sparse_tensor((M, K), dtype=torch.float16, device='cuda')
+    A = generate_2_to_4_sparse_tensor((M, K), dtype=torch.float16, device="cuda")
     A_sparse, E = compress_sm90(A, block_k=block_K, transposed=False)
-    B = torch.randn((K, N), device='cuda', dtype=torch.float16)
+    B = torch.randn((K, N), device="cuda", dtype=torch.float16)
     C_sp = kernel(A_sparse, E, B).half()
     C = torch.matmul(A, B)
...
@@ -117,7 +110,7 @@ def run_gemm_sp(
 def main():
-    run_gemm_sp(512, 1024, 768, "float16", "float16", "float32", 128, 128, 128, 2, 128)
+    run_gemm_sp(512, 1024, 768, T.float16, T.float16, T.float32, 128, 128, 128, 2, 128)

 if __name__ == "__main__":
...
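This example relies on 2:4 structured sparsity: in every group of four consecutive elements along the last dimension, at most two are nonzero, which is what compress_sm90 packs into A_sparse plus the uint8 metadata E. A small PyTorch sketch of one way to build such a tensor (an illustration only, not the repository's generate_2_to_4_sparse_tensor):

import torch

def make_2_to_4_sparse(shape, dtype=torch.float32, device="cpu"):
    # Keep the two largest-magnitude entries in every group of four along the
    # last dimension; zero the other two (the 2:4 pattern).
    if shape[-1] % 4 != 0:
        raise ValueError("Last dimension must be divisible by 4 for 2:4 sparsity.")
    dense = torch.randn(shape, dtype=dtype, device=device)
    groups = dense.reshape(*shape[:-1], shape[-1] // 4, 4)
    _, idx = groups.abs().topk(2, dim=-1)
    mask = torch.zeros_like(groups, dtype=torch.bool).scatter_(-1, idx, True)
    return (groups * mask).reshape(shape)

a = make_2_to_4_sparse((512, 768))  # each row now satisfies the 2:4 constraint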
examples/topk/example_topk.py (view file @ 667632cc)

...
@@ -22,19 +22,19 @@ def tl_topk(
     blk_m,
     threads=128,
 ):
-    dtype = "float32"
+    dtype = T.float32

     @T.prim_func
     def topk_kernel(
-            logits: T.Tensor([M, N], dtype),
-            topk_gates: T.Tensor([M, topk], dtype),
-            topk_indices: T.Tensor([M, topk], "int32"),
+        logits: T.Tensor([M, N], dtype),
+        topk_gates: T.Tensor([M, topk], dtype),
+        topk_indices: T.Tensor([M, topk], T.int32),
     ):
         with T.Kernel(T.ceildiv(M, blk_m), threads=threads) as bx:
             logits_frag = T.alloc_fragment([blk_m, N], dtype=dtype)
             max_val = T.alloc_fragment([blk_m], dtype=dtype)
-            expand_max_idx = T.alloc_fragment([blk_m, N], "int32")
-            max_idx = T.alloc_fragment([blk_m], "int32")
+            expand_max_idx = T.alloc_fragment([blk_m, N], T.int32)
+            max_idx = T.alloc_fragment([blk_m], T.int32)

             T.copy(logits[bx * blk_m, 0], logits_frag)
...
@@ -43,15 +43,12 @@ def tl_topk(
                 T.reduce_max(logits_frag, max_val, dim=1, clear=True)
                 for i, j in T.Parallel(blk_m, N):
-                    expand_max_idx[i, j] = T.if_then_else(max_val[i] == logits_frag[i, j], j,
-                                                          expand_max_idx[i, j])
+                    expand_max_idx[i, j] = T.if_then_else(max_val[i] == logits_frag[i, j], j, expand_max_idx[i, j])
                 T.reduce_max(expand_max_idx, max_idx, dim=1, clear=True)

                 for i, j in T.Parallel(blk_m, N):
-                    logits_frag[i, j] = T.if_then_else(max_val[i] == logits_frag[i, j], -10000.0,
-                                                       logits_frag[i, j])
+                    logits_frag[i, j] = T.if_then_else(max_val[i] == logits_frag[i, j], -10000.0, logits_frag[i, j])
                 for i in T.Parallel(blk_m):
                     topk_gates[bx * blk_m + i, k] = max_val[i]
...
@@ -61,7 +58,6 @@ def tl_topk(
 def ref_program(logits, top_k):
     top_k_gates, top_k_indices = logits.topk(top_k, dim=1)
     return top_k_gates, top_k_indices.to(torch.int32)
...
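The kernel computes top-k by a repeated argmax-and-mask sweep: find each row's maximum, record its column, then overwrite that position with -10000.0 so the next pass finds the runner-up. A plain PyTorch sketch of the same scheme (a hypothetical helper for intuition, not the kernel), whose gate values should agree with ref_program's logits.topk:

import torch

def iterative_topk(logits: torch.Tensor, top_k: int):
    # Repeatedly take the row-wise max, remember its index, then suppress it.
    work = logits.clone().float()
    gates, indices = [], []
    for _ in range(top_k):
        max_val, max_idx = work.max(dim=1)
        gates.append(max_val)
        indices.append(max_idx.to(torch.int32))
        work.scatter_(1, max_idx.unsqueeze(1), -10000.0)
    return torch.stack(gates, dim=1), torch.stack(indices, dim=1)

x = torch.randn(4, 16)
g, i = iterative_topk(x, 2)
ref_g, ref_i = x.topk(2, dim=1)
assert torch.allclose(g, ref_g)  # gate values match the library top-k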
examples/visual_layout_inference/visual_layout_inference.py (new file, 0 → 100644, view file @ 667632cc)

import tilelang
import tilelang.language as T


# use pass_configs to enable layout visualization
@tilelang.jit(
    out_idx=[-1],
    pass_configs={
        tilelang.PassConfigKey.TL_LAYOUT_VISUALIZATION_ENABLE: True,
        tilelang.PassConfigKey.TL_LAYOUT_VISUALIZATION_FORMATS: "svg",
    },
)
def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32):

    @T.prim_func
    def gemm(
        A: T.Tensor((M, K), dtype),
        B: T.Tensor((K, N), dtype),
        C: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                T.copy(A[by * block_M, k * block_K], A_shared)
                T.copy(B[k * block_K, bx * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return gemm


def main():
    kernel = matmul(128, 128, 128, 32, 32, 32)
    import torch

    a = torch.randn(128, 128).cuda().half()
    b = torch.randn(128, 128).cuda().half()
    c = kernel(a, b)
    ref_c = a @ b
    torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
    print("All check passed.")
    # print the layout visualization result and save figures to ./tmp.
    """
    C_local inferenced layout:
        Shape: [32, 32] -> [8]
        Thread: _j // 16 * 64 + _i // 16 * 32 + _i % 8 * 4 + _j % 8 // 2
        Index: [_j % 16 // 8 * 4 + _i % 16 // 8 * 2 + _j % 2]
    """


if __name__ == "__main__":
    main()
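The docstring above records the layout inferred for C_local: each (i, j) element of the 32x32 accumulator tile maps to a thread id and a per-thread register index. A tiny standalone sketch that simply evaluates those two printed expressions for a few coordinates (plain integer arithmetic, nothing tilelang-specific assumed):

def c_local_thread(i, j):
    return j // 16 * 64 + i // 16 * 32 + i % 8 * 4 + j % 8 // 2

def c_local_index(i, j):
    return j % 16 // 8 * 4 + i % 16 // 8 * 2 + j % 2

for coord in [(0, 0), (0, 1), (8, 0), (16, 16)]:
    # threads fall in 0..127, matching the 128-thread kernel launch above
    print(coord, "-> thread", c_local_thread(*coord), "reg", c_local_index(*coord))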
examples/warp_specialize/example_warp_specialize_flashmla.py (view file @ 667632cc)

...
@@ -9,9 +9,9 @@ import argparse
 @tilelang.jit(out_idx=[6])
 def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split):
-    scale = (1.0 / (dim + pe_dim))**0.5 * 1.44269504  # log2(e)
-    dtype = "float16"
-    accum_dtype = "float"
+    scale = (1.0 / (dim + pe_dim)) ** 0.5 * 1.44269504  # log2(e)
+    dtype = T.float16
+    accum_dtype = T.float32
     kv_group_num = heads // kv_head_num
     VALID_BLOCK_H = min(block_H, kv_group_num)
     assert kv_head_num == 1, "kv_head_num must be 1"
...
@@ -19,11 +19,11 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
     @T.macro
     def flash_attn(
-            Q: T.Tensor([batch, heads, dim], dtype),
-            Q_pe: T.Tensor([batch, heads, pe_dim], dtype),
-            KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype),
-            K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
-            Output: T.Tensor([batch, heads, dim], dtype),
+        Q: T.Tensor([batch, heads, dim], dtype),
+        Q_pe: T.Tensor([batch, heads, pe_dim], dtype),
+        KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype),
+        K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
+        Output: T.Tensor([batch, heads, dim], dtype),
     ):
         with T.Kernel(heads // min(block_H, kv_group_num), batch, threads=256) as (hid, bid):
             # smem_sQ
...
@@ -81,10 +81,12 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
             cur_kv_head = hid // (kv_group_num // block_H)
-            T.annotate_layout({
-                O_shared_l: tilelang.layout.make_swizzled_layout(O_shared_l),
-                O_shared_r: tilelang.layout.make_swizzled_layout(O_shared_r),
-            })
+            T.annotate_layout(
+                {
+                    O_shared_l: tilelang.layout.make_swizzled_layout(O_shared_l),
+                    O_shared_r: tilelang.layout.make_swizzled_layout(O_shared_r),
+                }
+            )

             # barriers_Q
             q_shared_ready_barrier = T.alloc_barrier(arrive_count=256)
...
@@ -108,9 +110,9 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
             tx = T.get_thread_binding()

             T.copy(Q[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :h_dim], Q_shared_l)
             T.copy(Q[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, h_dim:], Q_shared_r)
             T.copy(Q_pe[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :], Q_pe_shared)
             T.barrier_arrive(q_shared_ready_barrier)
             T.barrier_wait(q_shared_ready_barrier, 0)
...
@@ -123,25 +125,18 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
                 T.fill(acc_o_l, 0)
                 T.fill(logsum_0, 0)
                 T.copy(KV[bid, block_N : 2 * block_N, cur_kv_head, :h_dim], KV_shared_1_l)
                 T.barrier_arrive(kv_shared_1_l_is_ready)
                 T.copy(KV[bid, block_N : 2 * block_N, cur_kv_head, h_dim:], KV_shared_1_r)
                 T.barrier_arrive(kv_shared_1_r_is_ready)
                 T.copy(K_pe[bid, block_N : 2 * block_N, cur_kv_head, :], K_pe_shared_1)
                 T.barrier_arrive(kv_shared_1_pe_is_ready)

                 for k in T.serial(loop_range):
                     T.barrier_wait(kv_shared_0_l_is_ready, k % 2)
                     T.gemm(Q_shared_l, KV_shared_0_l, acc_s_0, transpose_B=True, clear_accum=True, wg_wait=-1)
                     T.barrier_wait(kv_shared_0_r_is_ready, k % 2)
                     T.gemm(Q_shared_r, KV_shared_0_r, acc_s_0, transpose_B=True, wg_wait=-1)
...
@@ -161,8 +156,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
                     for i, j in T.Parallel(block_H, block_N):
                         acc_s_0[i, j] = T.exp2(acc_s_0[i, j] * scale - scores_max[i] * scale)
                     for i in T.Parallel(block_H):
                         scores_scale_0[i] = T.exp2(scores_max_prev_0[i] * scale - scores_max[i] * scale)
                     T.reduce_sum(acc_s_0, scores_sum_0, dim=1)
...
@@ -182,9 +176,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
                     T.barrier_wait(scale_1_ready_barrier, k % 2)

                     if k < loop_range - 1:
                         T.copy(KV[bid, (2 * k + 2) * block_N : (2 * k + 3) * block_N, cur_kv_head, :h_dim], KV_shared_0_l)
                         T.barrier_arrive(kv_shared_0_l_is_ready)

                     # Step 11.
...
@@ -204,15 +196,10 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
                     T.gemm(SP1_shared, KV_shared_1_l, acc_o_l)

                     if k < loop_range - 1:
                         T.copy(KV[bid, (2 * k + 3) * block_N : (2 * k + 4) * block_N, cur_kv_head, :h_dim], KV_shared_1_l)
                         T.barrier_arrive(kv_shared_1_l_is_ready)
                         T.copy(K_pe[bid, (2 * k + 3) * block_N : (2 * k + 4) * block_N, cur_kv_head, :], K_pe_shared_1)
                         T.barrier_arrive(kv_shared_1_pe_is_ready)

                 T.copy(logsum_0, logsum)
...
@@ -221,8 +208,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
                 for i, j in T.Parallel(block_H, h_dim):
                     acc_o_l[i, j] /= logsum[i]
                 T.copy(acc_o_l, O_shared_l)
                 T.copy(O_shared_l, Output[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :h_dim])
             else:
                 T.copy(Q_pe_shared, Q_pe_local_1)
...
@@ -237,16 +223,9 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
                 T.barrier_arrive(kv_shared_0_pe_is_ready)

                 for k in T.serial(loop_range):
                     # Step 2.
                     T.barrier_wait(kv_shared_1_l_is_ready, k % 2)
                     T.gemm(Q_shared_l, KV_shared_1_l, acc_s_1, transpose_B=True, clear_accum=True, wg_wait=-1)
                     T.barrier_wait(kv_shared_1_r_is_ready, k % 2)
                     T.gemm(Q_shared_r, KV_shared_1_r, acc_s_1, transpose_B=True, wg_wait=-1)
...
@@ -265,8 +244,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
                     T.copy(scores_max_1, scores_max)
                     for i in T.Parallel(block_H):
                         scores_scale_1[i] = T.exp2(scores_max_prev_1[i] * scale - scores_max[i] * scale)
                     # Step 8.
                     for i, j in T.Parallel(block_H, block_N):
...
@@ -279,8 +257,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
                         acc_o_r[i, j] = acc_o_r[i, j] * (scores_scale_0[i] * scores_scale_1[i])
                     for i in T.Parallel(block_H):
                         logsum_1[i] = logsum_1[i] * scores_scale_1[i] * scores_scale_0[i] + scores_sum_1[i]
                     T.barrier_arrive(scale_1_ready_barrier)
...
@@ -291,9 +268,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
                     T.barrier_arrive(s_shared_ready_barrier)

                     if k < loop_range - 1:
                         T.copy(KV[bid, (2 * k + 3) * block_N : (2 * k + 4) * block_N, cur_kv_head, h_dim:], KV_shared_1_r)
                         T.barrier_arrive(kv_shared_1_r_is_ready)

                     T.barrier_wait(p0_1_1_ready_barrier, k % 2)
...
@@ -301,15 +276,10 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
                     T.gemm(SP0_shared, KV_shared_0_r, acc_o_r)

                     if k < loop_range - 1:
                         T.copy(KV[bid, (2 * k + 2) * block_N : (2 * k + 3) * block_N, cur_kv_head, h_dim:], KV_shared_0_r)
                         T.barrier_arrive(kv_shared_0_r_is_ready)
                         T.copy(K_pe[bid, (2 * k + 2) * block_N : (2 * k + 3) * block_N, cur_kv_head, :], K_pe_shared_0)
                         T.barrier_arrive(kv_shared_0_pe_is_ready)

                 T.barrier_wait(lse_0_ready_barrier, 0)
...
@@ -319,18 +289,17 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
                 for i, j in T.Parallel(block_H, h_dim):
                     acc_o_r[i, j] /= logsum[i]
                 T.copy(acc_o_r, O_shared_r)
                 T.copy(O_shared_r, Output[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, h_dim:])

     @T.prim_func
     def main_no_split(
         Q: T.Tensor([batch, heads, dim], dtype),
         Q_pe: T.Tensor([batch, heads, pe_dim], dtype),
         KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype),
         K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
         glse: T.Tensor([batch, heads, num_split], dtype),
         Output_partial: T.Tensor([batch, heads, num_split, dim], dtype),
         Output: T.Tensor([batch, heads, dim], dtype),
     ):
         flash_attn(Q, Q_pe, KV, K_pe, Output)
...
@@ -352,31 +321,24 @@ def ref_program(q, q_pe, kv, k_pe, glse, Output_partial):
     dim = q.shape[-1]
     pe_dim = q_pe.shape[-1]
     num_head_groups = q.shape[1] // kv.shape[2]
     scale = (dim + pe_dim)**0.5
-    q = rearrange(q, 'b (h g) d -> b g h d', g=num_head_groups)  # [batch_size, num_head_groups, groups, dim]
+    q = rearrange(q, "b (h g) d -> b g h d", g=num_head_groups)  # [batch_size, num_head_groups, groups, dim]
-    q_pe = rearrange(q_pe, 'b (h g) d -> b g h d', g=num_head_groups)  # [batch_size, num_head_groups, groups, pe_dim]
+    q_pe = rearrange(q_pe, "b (h g) d -> b g h d", g=num_head_groups)  # [batch_size, num_head_groups, groups, pe_dim]
-    kv = rearrange(kv, 'b n h d -> b h n d')  # [batch_size, groups, seqlen_kv, dim]
+    kv = rearrange(kv, "b n h d -> b h n d")  # [batch_size, groups, seqlen_kv, dim]
-    k_pe = rearrange(k_pe, 'b n h d -> b h n d')  # [batch_size, num_head_groups, groups, pe_dim]
+    k_pe = rearrange(k_pe, "b n h d -> b h n d")  # [batch_size, num_head_groups, groups, pe_dim]
     query = torch.concat([q, q_pe], dim=-1)
     key = torch.concat([kv, k_pe], dim=-1)
-    scores = einsum(query, key, 'b g h d, b h s d -> b g h s')  # [batch_size, num_head_groups, groups, seqlen_kv]
+    scores = einsum(query, key, "b g h d, b h s d -> b g h s")  # [batch_size, num_head_groups, groups, seqlen_kv]
     attention = F.softmax(scores / scale, dim=-1)  # [batch_size, num_head_groups, groups, seqlen_kv]
-    out = einsum(attention, kv, 'b g h s, b h s d -> b g h d')  # [batch_size, num_head_groups, groups, dim]
-    out = rearrange(out, 'b g h d -> b (h g) d')  # [batch_size, heads, dim]
+    out = einsum(attention, kv, "b g h s, b h s d -> b g h d")  # [batch_size, num_head_groups, groups, dim]
+    out = rearrange(out, "b g h d -> b (h g) d")  # [batch_size, heads, dim]
     return out
...
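For intuition, here is the same grouped-head bookkeeping as ref_program on toy shapes (sizes are illustrative; the rearrange/einsum patterns are the ones used above, and the positional-encoding half is omitted):

import torch
from einops import rearrange, einsum

b, heads, kv_heads, s, d = 1, 8, 2, 16, 32
q = torch.randn(b, heads, d)
kv = torch.randn(b, s, kv_heads, d)
g = heads // kv_heads                                   # num_head_groups
q = rearrange(q, "b (h g) d -> b g h d", g=g)           # (1, 4, 2, 32)
kv = rearrange(kv, "b n h d -> b h n d")                # (1, 2, 16, 32)
scores = einsum(q, kv, "b g h d, b h s d -> b g h s")   # (1, 4, 2, 16)
attn = (scores / d ** 0.5).softmax(dim=-1)
out = einsum(attn, kv, "b g h s, b h s d -> b g h d")   # (1, 4, 2, 32)
out = rearrange(out, "b g h d -> b (h g) d")            # back to (1, 8, 32)
print(out.shape)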
@@ -399,12 +361,12 @@ def main(batch=1, heads=64, kv_heads=1, kv_ctx=1024, dim=512, pe_dim=64):
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument('--batch', type=int, default=1, help='batch size')
-    parser.add_argument('--heads', type=int, default=128, help='q heads number')
-    parser.add_argument('--kv_heads', type=int, default=1, help='kv heads number')
-    parser.add_argument('--kv_ctx', type=int, default=8192, help='kv context length')
-    parser.add_argument('--dim', type=int, default=512, help='head dim')
-    parser.add_argument('--pe_dim', type=int, default=64, help='pe head dim')
+    parser.add_argument("--batch", type=int, default=1, help="batch size")
+    parser.add_argument("--heads", type=int, default=128, help="q heads number")
+    parser.add_argument("--kv_heads", type=int, default=1, help="kv heads number")
+    parser.add_argument("--kv_ctx", type=int, default=8192, help="kv context length")
+    parser.add_argument("--dim", type=int, default=512, help="head dim")
+    parser.add_argument("--pe_dim", type=int, default=64, help="pe head dim")
     args = parser.parse_args()
     batch, heads, kv_heads, kv_ctx, dim, pe_dim = args.batch, args.heads, args.kv_heads, args.kv_ctx, args.dim, args.pe_dim
     main(batch, heads, kv_heads, kv_ctx, dim, pe_dim)
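A rough PyTorch sketch (not the kernel itself, only the math) of the exp2-based online-softmax update that the two consumer loops above perform; it assumes, as in the example, that scale already folds in the log2(e) factor so exp2 can replace exp:

import torch

def online_softmax_step(scores, scores_max_prev, logsum, acc_o, scale):
    # One block update: refresh the running max, rescale the old accumulators,
    # and accumulate the new probabilities.
    scores_max = torch.maximum(scores_max_prev, scores.max(dim=-1).values)
    p = torch.exp2(scores * scale - scores_max.unsqueeze(-1) * scale)
    scores_scale = torch.exp2(scores_max_prev * scale - scores_max * scale)
    logsum = logsum * scores_scale + p.sum(dim=-1)
    acc_o = acc_o * scores_scale.unsqueeze(-1)  # running output is rescaled before the next GEMM
    return p, scores_max, logsum, acc_o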
examples/warp_specialize/example_warp_specialize_gemm_barrierpipe_stage2.py (view file @ 667632cc)

...
@@ -7,8 +7,7 @@ tilelang.disable_cache()
 # add decorator @tilelang.jit if you want to return a torch function
 # @tilelang.jit
 @tilelang.jit(out_idx=[2])
-def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
+def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32):
     num_stages = 2
     mbarrier_list = [128, 128] * num_stages
...
@@ -32,19 +31,13 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
             for ko in range(T.ceildiv(K, block_K)):
                 with T.ws(1):
                     T.mbarrier_wait_parity(mbarrier=ko % num_stages + num_stages, parity=((ko // num_stages) % num_stages) ^ 1)
                     T.copy(A[by * block_M : (by + 1) * block_M, ko * block_K : (ko + 1) * block_K], A_shared[ko % num_stages, :, :])
                     T.copy(B[ko * block_K : (ko + 1) * block_K, bx * block_N : (bx + 1) * block_N], B_shared[ko % num_stages, :, :])
                     T.mbarrier_arrive(mbarrier=ko % num_stages)
                 with T.ws(0):
                     T.mbarrier_wait_parity(mbarrier=ko % num_stages, parity=(ko // num_stages) % num_stages)
                     T.gemm(A_shared[ko % num_stages, :, :], B_shared[ko % num_stages, :, :], C_local)
                     T.mbarrier_arrive(mbarrier=ko % num_stages + num_stages)
             with T.ws(0):
...
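The barrier indexing in this example follows a fixed pattern: buffer ko % num_stages, one "data ready" barrier per buffer, one "buffer free" barrier offset by num_stages, and the producer waiting on the inverted parity. A few iterations printed with plain Python, mirroring the expressions in the kernel:

num_stages = 2
for ko in range(6):
    buf = ko % num_stages
    producer_waits = ((ko // num_stages) % num_stages) ^ 1   # "buffer free" parity
    consumer_waits = (ko // num_stages) % num_stages         # "data ready" parity
    print(f"ko={ko}: buffer {buf}, producer waits parity {producer_waits}, "
          f"consumer waits parity {consumer_waits}")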
examples/warp_specialize/example_warp_specialize_gemm_copy_0_gemm_1.py (view file @ 667632cc)

...
@@ -5,20 +5,12 @@ import tilelang.language as T
 # add decorator @tilelang.jit if you want to return a torch function
 # @tilelang.jit
 @tilelang.jit(out_idx=[2])
-def matmul_warp_specialize_copy_0_gemm_1(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
+def matmul_warp_specialize_copy_0_gemm_1(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32):

     @T.prim_func
     def main(
-            A: T.Tensor((M, K), dtype),
-            B: T.Tensor((K, N), dtype),
-            C: T.Tensor((M, N), dtype),
+        A: T.Tensor((M, K), dtype),
+        B: T.Tensor((K, N), dtype),
+        C: T.Tensor((M, N), dtype),
     ):
         # Initialize Kernel Context
         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by):
...
examples/warp_specialize/example_warp_specialize_gemm_copy_1_gemm_0.py (view file @ 667632cc)

...
@@ -5,20 +5,12 @@ import tilelang.language as T
 # add decorator @tilelang.jit if you want to return a torch function
 # @tilelang.jit
 @tilelang.jit(out_idx=[2])
-def matmul_warp_specialize_copy_1_gemm_0(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
+def matmul_warp_specialize_copy_1_gemm_0(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32):

     @T.prim_func
     def main(
-            A: T.Tensor((M, K), dtype),
-            B: T.Tensor((K, N), dtype),
-            C: T.Tensor((M, N), dtype),
+        A: T.Tensor((M, K), dtype),
+        B: T.Tensor((K, N), dtype),
+        C: T.Tensor((M, N), dtype),
     ):
         # Initialize Kernel Context
         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by):
...
examples/warp_specialize/example_warp_specialize_gemm_copy_gemm_0_1.py (view file @ 667632cc)

...
@@ -5,26 +5,20 @@ import tilelang.language as T
 # add decorator @tilelang.jit if you want to return a torch function
 # @tilelang.jit
 @tilelang.jit(
     out_idx=[2],
     pass_configs={
         tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
-    })
-def matmul_warp_specialize_copy_1_gemm_0(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
+    },
+)
+def matmul_warp_specialize_copy_1_gemm_0(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32):
     warp_group_num = 2
     threads = 128 * warp_group_num

     @T.prim_func
     def main(
-            A: T.Tensor((M, K), dtype),
-            B: T.Tensor((K, N), dtype),
-            C: T.Tensor((M, N), dtype),
+        A: T.Tensor((M, K), dtype),
+        B: T.Tensor((K, N), dtype),
+        C: T.Tensor((M, N), dtype),
     ):
         # Initialize Kernel Context
         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
...
examples/warp_specialize/example_warp_specialize_gemm_softpipe_stage2.py (view file @ 667632cc)

...
@@ -5,8 +5,7 @@ import tilelang.language as T
 # add decorator @tilelang.jit if you want to return a torch function
 # @tilelang.jit
 @tilelang.jit(out_idx=[2])
-def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
+def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32):

     @T.prim_func
     def main(
             A: T.Tensor[(M, K), dtype],
...
format.sh (view file @ 667632cc)

...
@@ -9,7 +9,7 @@
 #    bash format.sh --all
 #
 #
-# YAPF + Clang formatter (if installed). This script formats all changed files from the last mergebase.
+# Ruff (format) + Clang formatter (if installed). This script formats all changed files from the last mergebase.
 # You are encouraged to run this locally before pushing changes for review.

 # Cause the script to exit if a single command fails
...
maint/gemm_v2/correctness_evaluation.py (view file @ 667632cc)

...
@@ -2,6 +2,8 @@
 import pytest
 from tilelang import tvm as tvm
 import tilelang.testing
+from tilelang import language as T
+import torch

 def matmul(
...
@@ -24,13 +26,11 @@ def matmul(
     A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
     B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)

-    import tilelang.language as T
-
     @T.prim_func
     def main(
-            A: T.Tensor(A_shape, in_dtype),
-            B: T.Tensor(B_shape, in_dtype),
-            C: T.Tensor((M, N), out_dtype),
+        A: T.Tensor(A_shape, in_dtype),
+        B: T.Tensor(B_shape, in_dtype),
+        C: T.Tensor((M, N), out_dtype),
     ):
         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
             A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope="shared.dyn")
...
@@ -66,20 +66,19 @@ def _compile_and_check(
             tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
             tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
             # tilelang.PassConfigKey.TIR_USE_ASYNC_COPY: False,
-        })
+        },
+    )
     print(kernel.get_kernel_source())
     profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)

     def ref_program(A, B):
-        import torch
         if trans_A:
             A = A.T
         if trans_B:
             B = B.T
-        if in_dtype == "float32":
+        if in_dtype == T.float32:
             A = (A.view(torch.int32) - 0x1000).view(torch.float32)
             B = (B.view(torch.int32) - 0x1000).view(torch.float32)
         C = torch.matmul(A.to(torch.float), B.to(torch.float))
...
@@ -147,13 +146,11 @@ def matmul_rs(
     B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
     A_frag_shape = A_shared_shape

-    import tilelang.language as T
-
     @T.prim_func
     def main(
-            A: T.Tensor(A_shape, in_dtype),
-            B: T.Tensor(B_shape, in_dtype),
-            C: T.Tensor((M, N), out_dtype),
+        A: T.Tensor(A_shape, in_dtype),
+        B: T.Tensor(B_shape, in_dtype),
+        C: T.Tensor((M, N), out_dtype),
     ):
         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
             A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope="shared.dyn")
...
@@ -234,13 +231,11 @@ def matmul_sr(
     B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
     B_frag_shape = B_shared_shape

-    import tilelang.language as T
-
     @T.prim_func
     def main(
-            A: T.Tensor(A_shape, in_dtype),
-            B: T.Tensor(B_shape, in_dtype),
-            C: T.Tensor((M, N), out_dtype),
+        A: T.Tensor(A_shape, in_dtype),
+        B: T.Tensor(B_shape, in_dtype),
+        C: T.Tensor((M, N), out_dtype),
     ):
         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
             A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope="shared.dyn")
...
@@ -322,13 +317,11 @@ def matmul_rr(
     A_frag_shape = A_shared_shape
     B_frag_shape = B_shared_shape

-    import tilelang.language as T
-
     @T.prim_func
     def main(
-            A: T.Tensor(A_shape, in_dtype),
-            B: T.Tensor(B_shape, in_dtype),
-            C: T.Tensor((M, N), out_dtype),
+        A: T.Tensor(A_shape, in_dtype),
+        B: T.Tensor(B_shape, in_dtype),
+        C: T.Tensor((M, N), out_dtype),
     ):
         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
             A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope="shared.dyn")
...
@@ -394,37 +387,48 @@ M_VALUES = [64, 128, 256]
 N_VALUES = [16, 32, 64, 128, 256, 512]
 K_VALUES = [16, 32, 64, 128]
 K_VALUES_8Bit = [32, 64, 128]

-FALSE_TRUE_CASES = ([
-    pytest.param(k, "float16", "float16", "float16", id=f"K{k}-float16-float16-float16")
-    for k in K_VALUES
-] + [
-    pytest.param(k, "int8", "int32", "int32", id="K32-int8-int32-int32")
-    for k in K_VALUES_8Bit
-] + [
-    pytest.param(k, "float8_e5m2", "float32", "float32", id="K32-float8_e5m2-float32-float32")
-    for k in K_VALUES_8Bit
-] + [
-    pytest.param(k, "float8_e4m3", "float32", "float32", id="K32-float8_e4m3-float32-float32")
-    for k in K_VALUES_8Bit
-])
+FALSE_TRUE_CASES = (
+    [
+        pytest.param(k, T.float16, T.float16, T.float16, id=f"K{k}-float16-float16-float16")
+        for k in K_VALUES
+    ]
+    + [
+        pytest.param(k, T.int8, T.int32, T.int32, id="K32-int8-int32-int32")
+        for k in K_VALUES_8Bit
+    ]
+    + [
+        pytest.param(k, T.float8_e5m2, T.float32, T.float32, id="K32-float8_e5m2-float32-float32")
+        for k in K_VALUES_8Bit
+    ]
+    + [
+        pytest.param(k, T.float8_e4m3fn, T.float32, T.float32, id="K32-float8_e4m3-float32-float32")
+        for k in K_VALUES_8Bit
+    ]
+)

 def _ensure_torch_dtypes(*dtype_names):
...
@@ -440,15 +444,15 @@ def run_gemm_rs_false_true(m, n, k, in_dtype, out_dtype, accum_dtype):
 def run_gemm_rs_false_false(m, n, k):
-    run_gemm_rs(m, n, k * 3, False, False, "float16", "float16", "float16", m, n, k)
+    run_gemm_rs(m, n, k * 3, False, False, T.float16, T.float16, T.float16, m, n, k)

 def run_gemm_rs_true_false(m, n, k):
-    run_gemm_rs(m, n, k * 3, True, False, "float16", "float16", "float16", m, n, k)
+    run_gemm_rs(m, n, k * 3, True, False, T.float16, T.float16, T.float16, m, n, k)

 def run_gemm_rs_true_true(m, n, k):
-    run_gemm_rs(m, n, k * 3, True, True, "float16", "float16", "float16", m, n, k)
+    run_gemm_rs(m, n, k * 3, True, True, T.float16, T.float16, T.float16, m, n, k)

 def run_gemm_sr_false_true(m, n, k, in_dtype, out_dtype, accum_dtype):
...
@@ -456,15 +460,15 @@ def run_gemm_sr_false_true(m, n, k, in_dtype, out_dtype, accum_dtype):
 def run_gemm_sr_false_false(m, n, k):
-    run_gemm_sr(m, n, k * 3, False, False, "float16", "float16", "float16", m, n, k)
+    run_gemm_sr(m, n, k * 3, False, False, T.float16, T.float16, T.float16, m, n, k)

 def run_gemm_sr_true_false(m, n, k):
-    run_gemm_sr(m, n, k * 3, True, False, "float16", "float16", "float16", m, n, k)
+    run_gemm_sr(m, n, k * 3, True, False, T.float16, T.float16, T.float16, m, n, k)

 def run_gemm_sr_true_true(m, n, k):
-    run_gemm_sr(m, n, k * 3, True, True, "float16", "float16", "float16", m, n, k)
+    run_gemm_sr(m, n, k * 3, True, True, T.float16, T.float16, T.float16, m, n, k)

 def run_gemm_rr_false_true(m, n, k, in_dtype, out_dtype, accum_dtype):
...
@@ -472,15 +476,15 @@ def run_gemm_rr_false_true(m, n, k, in_dtype, out_dtype, accum_dtype):
 def run_gemm_rr_false_false(m, n, k):
-    run_gemm_rr(m, n, k * 3, False, False, "float16", "float16", "float16", m, n, k)
+    run_gemm_rr(m, n, k * 3, False, False, T.float16, T.float16, T.float16, m, n, k)

 def run_gemm_rr_true_false(m, n, k):
-    run_gemm_rr(m, n, k * 3, True, False, "float16", "float16", "float16", m, n, k)
+    run_gemm_rr(m, n, k * 3, True, False, T.float16, T.float16, T.float16, m, n, k)

 def run_gemm_rr_true_true(m, n, k):
-    run_gemm_rr(m, n, k * 3, True, True, "float16", "float16", "float16", m, n, k)
+    run_gemm_rr(m, n, k * 3, True, True, T.float16, T.float16, T.float16, m, n, k)

 TRANS_CASES = [
...
@@ -536,9 +540,9 @@ def test_gemm_false_false(m, n, k):
         k * 3,
         False,
         False,
-        "float16",
-        "float16",
-        "float16",
+        T.float16,
+        T.float16,
+        T.float16,
         m,
         n,
         k,
...
@@ -555,9 +559,9 @@ def test_gemm_true_false(m, n, k):
         k * 3,
         True,
         False,
-        "float16",
-        "float16",
-        "float16",
+        T.float16,
+        T.float16,
+        T.float16,
         m,
         n,
         k,
...
@@ -574,9 +578,9 @@ def test_gemm_true_true(m, n, k):
         k * 3,
         True,
         True,
-        "float16",
-        "float16",
-        "float16",
+        T.float16,
+        T.float16,
+        T.float16,
         m,
         n,
         k,
...
@@ -595,7 +599,7 @@ def test_gemm_rs_false_true(m, n, k, in_dtype, out_dtype, accum_dtype):
 @pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}")
 @pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}")
 def test_gemm_rs_false_false(m, n, k):
-    _ensure_torch_dtypes("float16")
+    _ensure_torch_dtypes(T.float16)
     run_gemm_rs_false_false(m, n, k)
...

The same one-line change, _ensure_torch_dtypes("float16") becoming _ensure_torch_dtypes(T.float16), is applied in the analogous hunks (@@ -603, -611, -627, -635, -643, -659, -667, -675) for test_gemm_rs_true_false, test_gemm_rs_true_true, test_gemm_sr_false_false, test_gemm_sr_true_false, test_gemm_sr_true_true, test_gemm_rr_false_false, test_gemm_rr_true_false, and test_gemm_rr_true_true. The remaining hunks (@@ -687, -695, -703, -712, -721) make the same "float16"-to-T.float16 substitution inside commented-out run_gemm and run_gemm_rs invocations under the __main__ block, e.g.:

-    #         run_gemm(m, n, k * 3, False, True, "float16", "float16", "float16", m, n, k, 2, 128)
+    #         run_gemm(m, n, k * 3, False, True, T.float16, T.float16, T.float16, m, n, k, 2, 128)
-    #         run_gemm_rs(64, n, k, False, False, "float16", "float16", "float16", 64, n, k, 0, 256)
+    #         run_gemm_rs(64, n, k, False, False, T.float16, T.float16, T.float16, 64, n, k, 0, 256)
-    #         run_gemm(64, n, k, False, False, "float16", "float16", "float16", 64, n, k, 0, 256)
+    #         run_gemm(64, n, k, False, False, T.float16, T.float16, T.float16, 64, n, k, 0, 256)
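The float32 branch of ref_program nudges the inputs by reinterpreting their bits before the reference matmul. A standalone PyTorch illustration of just that transformation (the 0x1000 constant is copied from the diff; the diff does not state why it is applied, so the comment is an assumption):

import torch

a = torch.randn(4, 4, dtype=torch.float32)
# Reinterpret the float32 bits as int32, subtract a small constant, and view
# back as float32. Presumably this perturbs low mantissa bits so the
# reduced-precision kernel and the float reference agree more closely (assumption).
a_adj = (a.view(torch.int32) - 0x1000).view(torch.float32)
print((a - a_adj).abs().max())  # the perturbation is tiny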
maint/gemm_v2/correctness_evaluation_sm70.py (view file @ 667632cc)

...
@@ -2,6 +2,7 @@
 import pytest
 from tilelang import tvm as tvm
 import tilelang.testing
+from tilelang import language as T

 def matmul(
...
@@ -24,13 +25,11 @@ def matmul(
     A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
     B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)

-    import tilelang.language as T
-
     @T.prim_func
     def main(
-            A: T.Tensor(A_shape, in_dtype),
-            B: T.Tensor(B_shape, in_dtype),
-            C: T.Tensor((M, N), out_dtype),
+        A: T.Tensor(A_shape, in_dtype),
+        B: T.Tensor(B_shape, in_dtype),
+        C: T.Tensor((M, N), out_dtype),
     ):
         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
             A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope="shared.dyn")
...
@@ -67,7 +66,8 @@ def _compile_and_check(
             tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
             tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
             # tilelang.PassConfigKey.TIR_USE_ASYNC_COPY: False,
-        })
+        },
+    )
     print(kernel.get_kernel_source())
...
@@ -80,7 +80,7 @@ def _compile_and_check(
         if trans_A:
             A = A.T
         if trans_B:
             B = B.T
-        if in_dtype == "float32":
+        if in_dtype == T.float32:
             A = (A.view(torch.int32) - 0x1000).view(torch.float32)
             B = (B.view(torch.int32) - 0x1000).view(torch.float32)
         C = torch.matmul(A.to(torch.float), B.to(torch.float))
...
@@ -146,13 +146,11 @@ def matmul_rs(
     B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
     A_frag_shape = A_shared_shape

-    import tilelang.language as T
-
     @T.prim_func
     def main(
-            A: T.Tensor(A_shape, in_dtype),
-            B: T.Tensor(B_shape, in_dtype),
-            C: T.Tensor((M, N), out_dtype),
+        A: T.Tensor(A_shape, in_dtype),
+        B: T.Tensor(B_shape, in_dtype),
+        C: T.Tensor((M, N), out_dtype),
     ):
         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
             A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope="shared.dyn")
...
@@ -213,23 +211,25 @@ def run_gemm_rs(
 M_VALUES = [64, 128]
 N_VALUES = [32, 64, 128]
 K_VALUES = [16, 32, 64]

-FALSE_TRUE_CASES = ([
-    pytest.param(k, "float16", "float16", "float16", id=f"K{k}-float16-float16-float16")
-    for k in K_VALUES
-] + [
-    pytest.param(k, "float16", "float16", "float32", id=f"K{k}-float16-float16-float32")
-    for k in K_VALUES
-])
+FALSE_TRUE_CASES = [
+    pytest.param(k, T.float16, T.float16, T.float16, id=f"K{k}-float16-float16-float16")
+    for k in K_VALUES
+] + [
+    pytest.param(k, T.float16, T.float16, T.float32, id=f"K{k}-float16-float16-float32")
+    for k in K_VALUES
+]

 def _ensure_torch_dtypes(*dtype_names):
...
@@ -245,7 +245,7 @@ def run_gemm_rs_false_true(m, n, k, in_dtype, out_dtype, accum_dtype):
 def run_gemm_rs_false_false(m, n, k):
-    run_gemm_rs(m, n, k * 3, False, False, "float16", "float16", "float16", m, n, k, 2, 128)
+    run_gemm_rs(m, n, k * 3, False, False, T.float16, T.float16, T.float16, m, n, k, 2, 128)

 TRANS_CASES = [
...
@@ -303,9 +303,9 @@ def test_gemm_false_false(m, n, k):
         k * 3,
         False,
         False,
-        "float16",
-        "float16",
-        "float16",
+        T.float16,
+        T.float16,
+        T.float16,
         m,
         n,
         k,
...
@@ -326,7 +326,7 @@ def test_gemm_rs_false_true(m, n, k, in_dtype, out_dtype, accum_dtype):
 @pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}")
 @pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}")
 def test_gemm_rs_false_false(m, n, k):
-    _ensure_torch_dtypes("float16")
+    _ensure_torch_dtypes(T.float16)
     run_gemm_rs_false_false(m, n, k)
...
@@ -338,7 +338,7 @@ if __name__ == "__main__":
     # for n in [16, 32, 64, 128]:
     #     for k in [16, 32, 64]:
     #         print(f"======================= Test {m} {n} {k} False True =============================")
-    #         run_gemm(m, n, k * 3, False, True, "float16", "float16", "float16", m, n, k, 2, 128)
+    #         run_gemm(m, n, k * 3, False, True, T.float16, T.float16, T.float16, m, n, k, 2, 128)
     #         print(f"Test {m} {n} {k} Pass")
     # # Test Pass
...
@@ -346,5 +346,5 @@ if __name__ == "__main__":
     # for n in [16, 32, 64, 128]:
     #     for k in [16, 32, 64]:
     #         print(f"======================= Test {m} {n} {k} False False =============================")
-    #         run_gemm(m, n, k * 3, False, False, "float16", "float16", "float16", m, n, k, 2, 128)
+    #         run_gemm(m, n, k * 3, False, False, T.float16, T.float16, T.float16, m, n, k, 2, 128)
     #         print(f"Test {m} {n} {k} Pass")
maint/gemm_v2/correctness_evaluation_tcgen05.py (view file @ 667632cc)

...
@@ -27,9 +27,9 @@ def matmul(
     @T.prim_func
     def main(
-            A: T.Tensor(A_shape, in_dtype),
-            B: T.Tensor(B_shape, in_dtype),
-            C: T.Tensor((M, N), out_dtype),
+        A: T.Tensor(A_shape, in_dtype),
+        B: T.Tensor(B_shape, in_dtype),
+        C: T.Tensor((M, N), out_dtype),
     ):
         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
             A_shared = T.alloc_shared(A_shared_shape, in_dtype)
...
@@ -42,15 +42,7 @@ def matmul(
             for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                 T.copy(A[by * block_M, k * block_K], A_shared)
                 T.copy(B[bx * block_N, k * block_K], B_shared)
                 T.gemm(A_shared, B_shared, C_tmem, trans_A, trans_B, mbar=mbar, wg_wait=-1, clear_accum=k == 0)
                 T.mbarrier_wait_parity(mbar, k % 2)

             T.copy(C_tmem, C_local)
...
@@ -74,7 +66,8 @@ def _compile_and_check(
         pass_configs={
             tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
             tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
-        })
+        },
+    )
     print(kernel.get_kernel_source())
...
@@ -87,7 +80,7 @@ def _compile_and_check(
         if trans_A:
             A = A.T
         if trans_B:
             B = B.T
-        if in_dtype == "float32":
+        if in_dtype == T.float32:
             A = (A.view(torch.int32) - 0x1000).view(torch.float32)
             B = (B.view(torch.int32) - 0x1000).view(torch.float32)
         C = torch.matmul(A.to(torch.float), B.to(torch.float))
...
@@ -138,23 +131,25 @@ M_VALUES = [32, 64, 128, 256]
 N_VALUES = [64, 128, 256, 512]
 K_VALUES = [16, 32, 64, 128]
 K_VALUES_8Bit = [32, 64, 128]

-FALSE_TRUE_CASES = ([
-    pytest.param(k, "float16", "float32", "float32", id=f"K{k}-float16-float-float")
-    for k in K_VALUES
-] + [
-    pytest.param(k, "float8_e5m2", "float32", "float32", id="K32-float8_e5m2-float32-float32")
-    for k in K_VALUES_8Bit
-])
+FALSE_TRUE_CASES = [
+    pytest.param(k, T.float16, T.float32, T.float32, id=f"K{k}-float16-float-float")
+    for k in K_VALUES
+] + [
+    pytest.param(k, T.float8_e5m2, T.float32, T.float32, id="K32-float8_e5m2-float32-float32")
+    for k in K_VALUES_8Bit
+]

 TRANS_CASES = [
     pytest.param(False, True, id="nt"),
...
@@ -191,7 +186,7 @@ def test_gemm_false_true(m, n, k, in_dtype, out_dtype, accum_dtype):
 if __name__ == "__main__":
-    # tilelang.testing.main()
+    tilelang.testing.main()
     # # Test Pass
     # for m in [32, 64, 128, 256]:
...
@@ -200,27 +195,24 @@ if __name__ == "__main__":
     #         if m in [32, 64] and (n not in [64, 128, 256]):
     #             continue
     #         print(f"======================= Test {m} {n} {k} False True =============================")
-    #         run_gemm(m, n, k * 3, False, True, "float16", "float", "float", m, n, k, 2, 128)
+    #         run_gemm(m, n, k * 3, False, True, T.float16, T.float, T.float, m, n, k, 2, 128)
     #         print(f"Test {m} {n} {k} Pass")
     # # Test Pass
     # for m in [32, 64, 128, 256]:
-    #     for n in [16, 32, 64, 128]:
-    #         for k in [32, 64, 128]:
+    #     for n in [32, 64, 128]:
+    #         for k in [16, 32, 64, 128]:
     #             if m in [32, 64] and (n not in [64, 128, 256]):
     #                 continue
     #             print(f"======================= Test {m} {n} {k} False True =============================")
-    #             run_gemm(m, n, k * 3, False, True, "float8_e5m2", "float", "float", m, n, k, 2, 128)
+    #             run_gemm(m, n, k * 3, False, True, T.float16, T.float, T.float, m, n, k, 2, 256)
     #             print(f"Test {m} {n} {k} Pass")
     tilelang.disable_cache()
     run_gemm(32, 512, 16, False, True, "float16", "float32", "float32", 32, 512, 16, 0, 128)
     run_gemm(32, 512, 32, False, True, "float16", "float32", "float32", 32, 512, 32, 0, 128)
     run_gemm(32, 512, 64, False, True, "float16", "float32", "float32", 32, 512, 64, 0, 128)
     run_gemm(64, 512, 16, False, True, "float16", "float32", "float32", 64, 512, 16, 0, 128)
     run_gemm(64, 512, 16, False, True, "float16", "float32", "float32", 32, 512, 16, 0, 128)
     run_gemm(128, 512, 16, False, True, "float16", "float32", "float32", 128, 512, 16, 0, 128)
     # run_gemm(64, 512, 32, False, True, "float16", "float32", "float32", 64, 512, 32, 0, 128)
     # run_gemm(64, 512, 64, False, True, "float16", "float32", "float32", 64, 512, 64, 0, 128)
     # run_gemm(128, 512, 16, False, True, "float16", "float32", "float32", 128, 512, 16, 0, 128)
     # # Test Pass
     # for m in [32, 64, 128, 256]:
     #     for n in [16, 32, 64, 128]:
     #         for k in [32, 64, 128]:
     #             if m in [32, 64] and (n not in [64, 128, 256]):
     #                 continue
     #             print(f"======================= Test {m} {n} {k} False True =============================")
     #             run_gemm(m, n, k * 3, False, True, T.float8_e5m2, T.float, T.float, m, n, k, 2, 128)
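The FALSE_TRUE_CASES lists in these evaluation scripts are built by concatenating comprehensions of pytest.param entries with human-readable ids. A minimal self-contained sketch of the same pattern (dtype names are plain strings here purely for illustration, not tilelang objects):

import pytest

K_VALUES = [16, 32, 64, 128]
CASES = [
    pytest.param(k, "float16", "float32", "float32", id=f"K{k}-float16-float32-float32")
    for k in K_VALUES
] + [
    pytest.param(k, "float8_e5m2", "float32", "float32", id=f"K{k}-float8_e5m2-float32-float32")
    for k in [32, 64, 128]
]

@pytest.mark.parametrize("k,in_dtype,out_dtype,accum_dtype", CASES)
def test_case_ids(k, in_dtype, out_dtype, accum_dtype):
    # Each case shows up in the pytest report under its readable id.
    assert out_dtype == "float32"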
maint/gemm_v2/latency.py (view file @ 667632cc)

...
@@ -13,13 +13,12 @@ use_v2 = args.use_v2
 # target currently can be "cuda" or "hip" or "cpu".
 # if not specified, it will be inferred from the input tensors during compile time
 @tilelang.jit
-def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
+def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32):

     @T.prim_func
     def matmul_relu_kernel(
-            A: T.Tensor((M, K), dtype),
-            B: T.Tensor((K, N), dtype),
-            C: T.Tensor((M, N), dtype),
+        A: T.Tensor((M, K), dtype),
+        B: T.Tensor((K, N), dtype),
+        C: T.Tensor((M, N), dtype),
     ):
         # Initialize Kernel Context
         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
...
maint/gemm_v2/latency_gemm.py (view file @ 667632cc)

...
@@ -13,13 +13,12 @@ use_v2 = args.use_v2
 # target currently can be "cuda" or "hip" or "cpu".
 # if not specified, it will be inferred from the input tensors during compile time
 @tilelang.jit
-def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
+def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32):

     @T.prim_func
     def matmul_relu_kernel(
-            A: T.Tensor((M, K), dtype),
-            B: T.Tensor((K, N), dtype),
-            C: T.Tensor((M, N), dtype),
+        A: T.Tensor((M, K), dtype),
+        B: T.Tensor((K, N), dtype),
+        C: T.Tensor((M, N), dtype),
     ):
         # Initialize Kernel Context
         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
...
maint/gemm_v2/latency_mha_fwd_bhsd.py
View file @
667632cc
...
...
@@ -8,13 +8,13 @@ import argparse
from
functools
import
partial
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'
--batch
'
,
type
=
int
,
default
=
128
,
help
=
'
batch size
'
)
parser
.
add_argument
(
'
--heads
'
,
type
=
int
,
default
=
16
,
help
=
'
heads
'
)
parser
.
add_argument
(
'
--seq_q
'
,
type
=
int
,
default
=
1024
,
help
=
'
query sequence length
'
)
parser
.
add_argument
(
'
--seq_kv
'
,
type
=
int
,
default
=
1024
,
help
=
'
key/value sequence length
'
)
parser
.
add_argument
(
'
--dim
'
,
type
=
int
,
default
=
256
,
help
=
'
dim
'
)
parser
.
add_argument
(
'
--is_causal
'
,
action
=
'
store_true
'
,
help
=
'
causal
'
)
parser
.
add_argument
(
'
--tune
'
,
action
=
'
store_true
'
,
help
=
'
tune configs
'
)
parser
.
add_argument
(
"
--batch
"
,
type
=
int
,
default
=
128
,
help
=
"
batch size
"
)
parser
.
add_argument
(
"
--heads
"
,
type
=
int
,
default
=
16
,
help
=
"
heads
"
)
parser
.
add_argument
(
"
--seq_q
"
,
type
=
int
,
default
=
1024
,
help
=
"
query sequence length
"
)
parser
.
add_argument
(
"
--seq_kv
"
,
type
=
int
,
default
=
1024
,
help
=
"
key/value sequence length
"
)
parser
.
add_argument
(
"
--dim
"
,
type
=
int
,
default
=
256
,
help
=
"
dim
"
)
parser
.
add_argument
(
"
--is_causal
"
,
action
=
"
store_true
"
,
help
=
"
causal
"
)
parser
.
add_argument
(
"
--tune
"
,
action
=
"
store_true
"
,
help
=
"
tune configs
"
)
parser
.
add_argument
(
"--use_v2"
,
action
=
"store_true"
)
args
=
parser
.
parse_args
()
...
...
@@ -29,24 +29,17 @@ def get_configs():
 @autotune(configs=get_configs(), warmup=10, rep=10)
 @tilelang.jit(
     out_idx=[3],
     pass_configs={
         tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-    })
+    },
+)
 def flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, block_M=64, block_N=64, num_stages=0, threads=128):
     scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
     q_shape = [batch, heads, seq_q, dim]
     kv_shape = [batch, heads, seq_kv, dim]
-    dtype = "float16"
-    accum_dtype = "float"
+    dtype = T.float16
+    accum_dtype = T.float32
     past_len = seq_kv - seq_q
     assert past_len >= 0, "seq_kv must be greater than or equal to seq_q"
...
...
@@ -62,7 +55,7 @@ def flashattn(batch,
         by: T.int32,
         bz: T.int32,
     ):
         T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared)
         if is_causal:
             for i, j in T.Parallel(block_M, block_N):
                 q_idx = bx * block_M + i + past_len
...
...
@@ -85,7 +78,7 @@ def flashattn(batch,
         by: T.int32,
         bz: T.int32,
     ):
         T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared)
         # T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
         if use_v2:
             T.gemm_v2(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
...
...
@@ -94,13 +87,13 @@ def flashattn(batch,
     @T.macro
     def Softmax(
         acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
         acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
         scores_max: T.FragmentBuffer([block_M], accum_dtype),
         scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
         scores_scale: T.FragmentBuffer([block_M], accum_dtype),
         scores_sum: T.FragmentBuffer([block_M], accum_dtype),
         logsum: T.FragmentBuffer([block_M], accum_dtype),
     ):
         T.copy(scores_max, scores_max_prev)
         T.fill(scores_max, -T.infinity(accum_dtype))
...
...
@@ -125,18 +118,18 @@ def flashattn(batch,
     @T.macro
     def Rescale(
         acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
         scores_scale: T.FragmentBuffer([block_M], accum_dtype),
     ):
         for i, j in T.Parallel(block_M, dim):
             acc_o[i, j] *= scores_scale[i]

     @T.prim_func
     def main(
         Q: T.Tensor(q_shape, dtype),
         K: T.Tensor(kv_shape, dtype),
         V: T.Tensor(kv_shape, dtype),
         Output: T.Tensor(q_shape, dtype),
     ):
         with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz):
             Q_shared = T.alloc_shared([block_M, dim], dtype)
...
...
@@ -152,43 +145,42 @@ def flashattn(batch,
             scores_sum = T.alloc_fragment([block_M], accum_dtype)
             logsum = T.alloc_fragment([block_M], accum_dtype)

             T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared)
             T.fill(acc_o, 0)
             T.fill(logsum, 0)
             T.fill(scores_max, -T.infinity(accum_dtype))

             loop_range = (
                 T.min(T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N))
                 if is_causal else T.ceildiv(seq_kv, block_N)
             )

             for k in T.Pipelined(loop_range, num_stages=num_stages):
                 MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
                 Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum)
                 Rescale(acc_o, scores_scale)
                 MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
             for i, j in T.Parallel(block_M, dim):
                 acc_o[i, j] /= logsum[i]
             T.copy(acc_o, O_shared)
             T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :])

     return main


 def ref_program(Q, K, V, is_causal):
     dim = Q.size(-1)
-    scores = torch.einsum('bhqd,bhkd->bhqk', Q, K)
+    scores = torch.einsum("bhqd,bhkd->bhqk", Q, K)
     scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
     if is_causal:
         seq_q = Q.size(2)
         seq_kv = K.size(2)
         mask = torch.tril(torch.ones(seq_q, seq_kv, device=scores.device), seq_kv - seq_q)
         mask = mask.unsqueeze(0).unsqueeze(0)
-        scores = scores.masked_fill(mask == 0, float('-inf'))
+        scores = scores.masked_fill(mask == 0, float("-inf"))
     attention_weights = F.softmax(scores, dim=-1)
-    output = torch.einsum('bhqk,bhkd->bhqd', attention_weights, V)
+    output = torch.einsum("bhqk,bhkd->bhqd", attention_weights, V)
     return output
...
...
@@ -206,18 +198,8 @@ def main(
     if is_causal:
         total_flops *= 0.5

-    if (not tune):
+    if not tune:
         kernel = flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, block_M=64, block_N=64, num_stages=0, threads=128)
         print(kernel.get_kernel_source())
         ref_program_processed = partial(ref_program, is_causal=is_causal)
...
...
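All of the flags touched in this hunk are defined at the top of the script, so a typical invocation of the benchmark (the values here are illustrative, not taken from the commit) looks like:

python maint/gemm_v2/latency_mha_fwd_bhsd.py --batch 8 --heads 16 --seq_q 1024 --seq_kv 1024 --dim 256 --is_causal --use_v2

Passing --tune presumably switches to the autotuned path driven by get_configs(), instead of the fixed block_M=64, block_N=64, num_stages=0, threads=128 configuration used on the `if not tune` branch above.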
maint/host_checks/01_num_args_mismatch.py
0 → 100644
View file @
667632cc
"""Reproduce: Argument count mismatch.
Note: The adapter-level wrapper expects only inputs (A, B) because C is marked as output.
Calling with the wrong number of inputs raises a ValueError before host entry.
"""
import
torch
from
common
import
build_matmul_kernel
def
main
():
M
=
N
=
K
=
256
fn
=
build_matmul_kernel
(
M
,
N
,
K
,
target
=
"cuda"
)
a
=
torch
.
empty
((
M
,
K
),
device
=
"cuda"
,
dtype
=
torch
.
float16
)
# Missing b
# Expected: ValueError with message about expected vs. actual inputs
fn
(
a
)
if
__name__
==
"__main__"
:
main
()
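All four host-check scripts import build_matmul_kernel from a local common module that is not shown on this page. Purely as a hypothetical sketch of what such a helper could look like (every detail below — the block sizes, the out_idx marking C as the output, the kernel body — is an assumption, not code from this commit):

import tilelang
import tilelang.language as T


def build_matmul_kernel(M, N, K, block_M=128, block_N=128, block_K=32, target="cuda"):
    # Hypothetical helper: a plain float16 GEMM whose C buffer is marked as the
    # output (cf. the docstring above: the wrapper then expects only A and B).
    # Forwarding of `target` to the compiler is also an assumption and is omitted here.
    @tilelang.jit(out_idx=[2])
    def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32):
        @T.prim_func
        def gemm(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((K, N), dtype),
            C: T.Tensor((M, N), dtype),
        ):
            with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
                A_shared = T.alloc_shared((block_M, block_K), dtype)
                B_shared = T.alloc_shared((block_K, block_N), dtype)
                C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
                T.clear(C_local)
                for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                    T.copy(A[by * block_M, k * block_K], A_shared)
                    T.copy(B[k * block_K, bx * block_N], B_shared)
                    T.gemm(A_shared, B_shared, C_local)
                T.copy(C_local, C[by * block_M, bx * block_N])

        return gemm

    return matmul(M, N, K, block_M, block_N, block_K)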
maint/host_checks/02_pointer_type_error.py
0 → 100644
View file @
667632cc
"""Reproduce: Pointer-type argument expected but scalar provided.
We pass an integer for A; wrapper forwards it to the host where a pointer is expected.
Expected: error like "Expect buffer A_handle to be pointer or tensor" (exact name depends on kernel param).
"""
import
torch
from
common
import
build_matmul_kernel
def
main
():
M
=
N
=
K
=
256
fn
=
build_matmul_kernel
(
M
,
N
,
K
,
target
=
"cuda"
)
# Wrong type for A (int instead of tensor)
a
=
1
b
=
torch
.
empty
((
K
,
N
),
device
=
"cuda"
,
dtype
=
torch
.
float16
)
fn
(
a
,
b
)
if
__name__
==
"__main__"
:
main
()
maint/host_checks/03_ndim_mismatch.py
0 → 100644
View file @
667632cc
"""Reproduce: ndim (rank) mismatch for A."""
import
torch
from
common
import
build_matmul_kernel
def
main
():
M
=
N
=
K
=
128
fn
=
build_matmul_kernel
(
M
,
N
,
K
,
target
=
"cuda"
)
# A has rank 3 instead of 2
a
=
torch
.
empty
((
M
,
K
,
1
),
device
=
"cuda"
,
dtype
=
torch
.
float16
)
b
=
torch
.
empty
((
K
,
N
),
device
=
"cuda"
,
dtype
=
torch
.
float16
)
fn
(
a
,
b
)
if
__name__
==
"__main__"
:
main
()
maint/host_checks/04_dtype_mismatch.py
0 → 100644
View file @
667632cc
"""Reproduce: dtype mismatch for A (float32 vs expected float16)."""
import
torch
from
common
import
build_matmul_kernel
def
main
():
M
=
N
=
K
=
128
fn
=
build_matmul_kernel
(
M
,
N
,
K
,
target
=
"cuda"
)
print
(
fn
.
get_host_source
())
a
=
torch
.
empty
((
M
,
K
),
device
=
"cuda"
,
dtype
=
torch
.
float32
)
# should be float16
b
=
torch
.
empty
((
K
,
N
),
device
=
"cuda"
,
dtype
=
torch
.
float16
)
fn
(
a
,
b
)
if
__name__
==
"__main__"
:
main
()
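Taken together, the four scripts exercise the host-side argument checks one failure mode at a time. A small driver along these lines (not part of the commit, just a sketch run from maint/host_checks/) can run them in sequence and confirm that each one actually fails:

import subprocess
import sys

checks = [
    "01_num_args_mismatch.py",
    "02_pointer_type_error.py",
    "03_ndim_mismatch.py",
    "04_dtype_mismatch.py",
]

for script in checks:
    # Each script is expected to raise, so a non-zero exit code means the
    # host-side argument check fired as intended.
    result = subprocess.run([sys.executable, script], capture_output=True, text=True)
    status = "raised as expected" if result.returncode != 0 else "unexpectedly succeeded"
    print(f"{script}: {status}")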