Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
tilelang
Commits
29051439
Unverified
Commit
29051439
authored
Dec 12, 2025
by
Lei Wang
Committed by
GitHub
Dec 12, 2025
Browse files
[Lint] Phaseout Yapf format and embrace ruff format (#1417)
parent
e84b24bc
Changes
426
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
287 additions
and
466 deletions
+287
-466
testing/python/jit/test_tilelang_jit_nullptr.py
testing/python/jit/test_tilelang_jit_nullptr.py
+6
-17
testing/python/jit/test_tilelang_jit_nvrtc.py
testing/python/jit/test_tilelang_jit_nvrtc.py
+47
-126
testing/python/jit/test_tilelang_jit_parcompile.py
testing/python/jit/test_tilelang_jit_parcompile.py
+6
-6
testing/python/jit/test_tilelang_jit_tvm_ffi.py
testing/python/jit/test_tilelang_jit_tvm_ffi.py
+54
-127
testing/python/kernel/test_tilelang_kernel_bf16_gemm_mma.py
testing/python/kernel/test_tilelang_kernel_bf16_gemm_mma.py
+11
-11
testing/python/kernel/test_tilelang_kernel_element_wise_add.py
...ng/python/kernel/test_tilelang_kernel_element_wise_add.py
+4
-4
testing/python/kernel/test_tilelang_kernel_fp8_gemm.py
testing/python/kernel/test_tilelang_kernel_fp8_gemm.py
+4
-6
testing/python/kernel/test_tilelang_kernel_fp8_gemm_mma.py
testing/python/kernel/test_tilelang_kernel_fp8_gemm_mma.py
+11
-11
testing/python/kernel/test_tilelang_kernel_fp8_gemv_simt.py
testing/python/kernel/test_tilelang_kernel_fp8_gemv_simt.py
+17
-19
testing/python/kernel/test_tilelang_kernel_gemm.py
testing/python/kernel/test_tilelang_kernel_gemm.py
+11
-11
testing/python/kernel/test_tilelang_kernel_gemm_mma_intrinsic.py
.../python/kernel/test_tilelang_kernel_gemm_mma_intrinsic.py
+11
-11
testing/python/kernel/test_tilelang_kernel_gemm_simt.py
testing/python/kernel/test_tilelang_kernel_gemm_simt.py
+10
-17
testing/python/kernel/test_tilelang_kernel_gemm_with_stride.py
...ng/python/kernel/test_tilelang_kernel_gemm_with_stride.py
+5
-5
testing/python/kernel/test_tilelang_kernel_gemv_simt.py
testing/python/kernel/test_tilelang_kernel_gemv_simt.py
+17
-19
testing/python/kernel/test_tilelang_kernel_int4_gemm_mma.py
testing/python/kernel/test_tilelang_kernel_int4_gemm_mma.py
+25
-26
testing/python/language/test_tilelang_capture.py
testing/python/language/test_tilelang_capture.py
+6
-5
testing/python/language/test_tilelang_intimm.py
testing/python/language/test_tilelang_intimm.py
+11
-11
testing/python/language/test_tilelang_language_alias.py
testing/python/language/test_tilelang_language_alias.py
+3
-4
testing/python/language/test_tilelang_language_all_of.py
testing/python/language/test_tilelang_language_all_of.py
+20
-22
testing/python/language/test_tilelang_language_alloc.py
testing/python/language/test_tilelang_language_alloc.py
+8
-8
No files found.
Too many changes to show.
To preserve performance only
426 of 426+
files are displayed.
Plain diff
Email patch
testing/python/jit/test_tilelang_jit_nullptr.py
View file @
29051439
...
@@ -7,22 +7,13 @@ from tilelang.utils import map_torch_type
...
@@ -7,22 +7,13 @@ from tilelang.utils import map_torch_type
@
tl
.
jit
@
tl
.
jit
def
tensor_null_test
(
M
,
def
tensor_null_test
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
=
"float16"
,
accum_dtype
=
"float"
,
with_bias
=
False
):
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
=
"float16"
,
accum_dtype
=
"float"
,
with_bias
=
False
):
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
accum_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
accum_dtype
),
Bias
:
T
.
Tensor
((
N
),
accum_dtype
),
Bias
:
T
.
Tensor
((
N
),
accum_dtype
),
):
):
# Initialize Kernel Context
# Initialize Kernel Context
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
128
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
128
)
as
(
bx
,
by
):
...
@@ -48,12 +39,10 @@ def tensor_null_test(M,
...
@@ -48,12 +39,10 @@ def tensor_null_test(M,
def
run_test
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
def
run_test
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
a
=
torch
.
randn
(
M
,
K
,
device
=
"cuda"
,
dtype
=
map_torch_type
(
dtype
))
a
=
torch
.
randn
(
M
,
K
,
device
=
"cuda"
,
dtype
=
map_torch_type
(
dtype
))
b
=
torch
.
randn
(
N
,
K
,
device
=
"cuda"
,
dtype
=
map_torch_type
(
dtype
))
b
=
torch
.
randn
(
N
,
K
,
device
=
"cuda"
,
dtype
=
map_torch_type
(
dtype
))
c
=
torch
.
zeros
(
M
,
N
,
device
=
"cuda"
,
dtype
=
map_torch_type
(
accum_dtype
))
c
=
torch
.
zeros
(
M
,
N
,
device
=
"cuda"
,
dtype
=
map_torch_type
(
accum_dtype
))
kernel
=
tensor_null_test
(
kernel
=
tensor_null_test
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
,
accum_dtype
,
with_bias
=
False
)
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
,
accum_dtype
,
with_bias
=
False
)
kernel
(
a
,
b
,
c
,
None
)
kernel
(
a
,
b
,
c
,
None
)
...
...
testing/python/jit/test_tilelang_jit_nvrtc.py
View file @
29051439
...
@@ -28,9 +28,9 @@ def matmul(
...
@@ -28,9 +28,9 @@ def matmul(
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
...
@@ -136,9 +136,9 @@ def matmu_jit_kernel(
...
@@ -136,9 +136,9 @@ def matmu_jit_kernel(
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
...
@@ -206,6 +206,7 @@ def run_gemm_jit_kernel(
...
@@ -206,6 +206,7 @@ def run_gemm_jit_kernel(
def
ref_program
(
A
,
B
):
def
ref_program
(
A
,
B
):
import
torch
import
torch
C
=
torch
.
matmul
(
A
.
to
(
torch
.
float
),
B
.
to
(
torch
.
float
))
C
=
torch
.
matmul
(
A
.
to
(
torch
.
float
),
B
.
to
(
torch
.
float
))
C
=
C
.
to
(
out_dtype
)
C
=
C
.
to
(
out_dtype
)
return
C
return
C
...
@@ -233,19 +234,9 @@ def test_gemm_jit_kernel():
...
@@ -233,19 +234,9 @@ def test_gemm_jit_kernel():
)
)
def
run_nvrtc_kernel_do_bench
(
M
,
def
run_nvrtc_kernel_do_bench
(
N
,
M
,
N
,
K
,
trans_A
,
trans_B
,
in_dtype
,
out_dtype
,
dtypeAccum
,
block_M
,
block_N
,
block_K
,
num_stages
=
3
,
num_threads
=
128
K
,
):
trans_A
,
trans_B
,
in_dtype
,
out_dtype
,
dtypeAccum
,
block_M
,
block_N
,
block_K
,
num_stages
=
3
,
num_threads
=
128
):
program
=
matmul
(
program
=
matmul
(
M
,
M
,
N
,
N
,
...
@@ -278,23 +269,12 @@ def run_nvrtc_kernel_do_bench(M,
...
@@ -278,23 +269,12 @@ def run_nvrtc_kernel_do_bench(M,
def
test_nvrtc_kernel_do_bench
():
def
test_nvrtc_kernel_do_bench
():
run_nvrtc_kernel_do_bench
(
512
,
1024
,
768
,
False
,
False
,
"float16"
,
"float16"
,
"float16"
,
128
,
run_nvrtc_kernel_do_bench
(
512
,
1024
,
768
,
False
,
False
,
"float16"
,
"float16"
,
"float16"
,
128
,
256
,
32
,
2
)
256
,
32
,
2
)
def
run_nvrtc_kernel_multi_stream
(
def
run_nvrtc_kernel_multi_stream
(
M
,
M
,
N
,
K
,
trans_A
,
trans_B
,
in_dtype
,
out_dtype
,
dtypeAccum
,
block_M
,
block_N
,
block_K
,
num_stages
=
3
,
num_threads
=
128
N
,
):
K
,
trans_A
,
trans_B
,
in_dtype
,
out_dtype
,
dtypeAccum
,
block_M
,
block_N
,
block_K
,
num_stages
=
3
,
num_threads
=
128
):
program
=
matmul
(
program
=
matmul
(
M
,
M
,
N
,
N
,
...
@@ -331,23 +311,12 @@ def run_nvrtc_kernel_multi_stream(M,
...
@@ -331,23 +311,12 @@ def run_nvrtc_kernel_multi_stream(M,
def
test_nvrtc_kernel_multi_stream
():
def
test_nvrtc_kernel_multi_stream
():
run_nvrtc_kernel_multi_stream
(
512
,
1024
,
768
,
False
,
False
,
"float16"
,
"float16"
,
"float16"
,
run_nvrtc_kernel_multi_stream
(
512
,
1024
,
768
,
False
,
False
,
"float16"
,
"float16"
,
"float16"
,
128
,
256
,
32
,
2
)
128
,
256
,
32
,
2
)
def
run_nvrtc_dynamic_shape
(
def
run_nvrtc_dynamic_shape
(
M
,
M
,
N
,
K
,
trans_A
,
trans_B
,
in_dtype
,
out_dtype
,
dtypeAccum
,
block_M
,
block_N
,
block_K
,
num_stages
=
3
,
num_threads
=
128
N
,
):
K
,
trans_A
,
trans_B
,
in_dtype
,
out_dtype
,
dtypeAccum
,
block_M
,
block_N
,
block_K
,
num_stages
=
3
,
num_threads
=
128
):
program
=
matmul
(
program
=
matmul
(
M
,
M
,
N
,
N
,
...
@@ -387,21 +356,15 @@ def run_nvrtc_dynamic_shape(M,
...
@@ -387,21 +356,15 @@ def run_nvrtc_dynamic_shape(M,
matmul_kernel
(
tensor_a
,
tensor_b
,
tensor_c
)
matmul_kernel
(
tensor_a
,
tensor_b
,
tensor_c
)
tensor_ref_c
=
torch
.
matmul
(
tensor_a
.
to
(
torch
.
float
),
tensor_b
.
to
(
torch
.
float
)).
to
(
out_dtype
)
tensor_ref_c
=
torch
.
matmul
(
tensor_a
.
to
(
torch
.
float
),
tensor_b
.
to
(
torch
.
float
)).
to
(
out_dtype
)
tilelang
.
testing
.
torch_assert_close
(
tilelang
.
testing
.
torch_assert_close
(
tensor_c
,
tensor_ref_c
,
atol
=
1e-2
,
rtol
=
1e-2
,
max_mismatched_ratio
=
0.05
)
tensor_c
,
tensor_ref_c
,
atol
=
1e-2
,
rtol
=
1e-2
,
max_mismatched_ratio
=
0.05
)
def
test_nvrtc_dynamic_shape
():
def
test_nvrtc_dynamic_shape
():
run_nvrtc_dynamic_shape
(
run_nvrtc_dynamic_shape
(
T
.
dynamic
(
"m"
),
1024
,
768
,
False
,
False
,
"float16"
,
"float16"
,
"float16"
,
128
,
256
,
32
,
2
)
T
.
dynamic
(
"m"
),
1024
,
768
,
False
,
False
,
"float16"
,
"float16"
,
"float16"
,
128
,
256
,
32
,
2
)
run_nvrtc_dynamic_shape
(
run_nvrtc_dynamic_shape
(
T
.
dynamic
(
"m"
),
T
.
dynamic
(
"n"
),
768
,
False
,
False
,
"float16"
,
"float16"
,
"float16"
,
128
,
256
,
32
,
2
)
T
.
dynamic
(
"m"
),
T
.
dynamic
(
"n"
),
768
,
False
,
False
,
"float16"
,
"float16"
,
"float16"
,
128
,
256
,
32
,
2
)
run_nvrtc_dynamic_shape
(
run_nvrtc_dynamic_shape
(
T
.
dynamic
(
"m"
),
T
.
dynamic
(
"n"
),
T
.
dynamic
(
"k"
),
False
,
False
,
"float16"
,
"float16"
,
"float16"
,
128
,
256
,
32
,
2
)
T
.
dynamic
(
"m"
),
T
.
dynamic
(
"n"
),
T
.
dynamic
(
"k"
),
False
,
False
,
"float16"
,
"float16"
,
"float16"
,
128
,
256
,
32
,
2
)
def
check_hopper
():
def
check_hopper
():
...
@@ -412,35 +375,18 @@ def check_hopper():
...
@@ -412,35 +375,18 @@ def check_hopper():
return
compute_capability
==
(
9
,
0
)
return
compute_capability
==
(
9
,
0
)
def
convolution_im2col
(
N
,
def
convolution_im2col
(
N
,
C
,
H
,
W
,
F
,
K
,
S
,
D
,
P
,
block_M
,
block_N
,
block_K
,
num_stages
,
threads
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
C
,
H
,
W
,
F
,
K
,
S
,
D
,
P
,
block_M
,
block_N
,
block_K
,
num_stages
,
threads
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
KH
,
KW
=
K
,
K
KH
,
KW
=
K
,
K
OH
=
(
H
+
2
*
P
-
D
*
(
K
-
1
)
-
1
)
//
S
+
1
OH
=
(
H
+
2
*
P
-
D
*
(
K
-
1
)
-
1
)
//
S
+
1
OW
=
(
W
+
2
*
P
-
D
*
(
K
-
1
)
-
1
)
//
S
+
1
OW
=
(
W
+
2
*
P
-
D
*
(
K
-
1
)
-
1
)
//
S
+
1
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
data
:
T
.
Tensor
((
N
,
H
,
W
,
C
),
dtype
),
data
:
T
.
Tensor
((
N
,
H
,
W
,
C
),
dtype
),
kernel
:
T
.
Tensor
((
KH
,
KW
,
C
,
F
),
dtype
),
kernel
:
T
.
Tensor
((
KH
,
KW
,
C
,
F
),
dtype
),
out
:
T
.
Tensor
((
N
,
OH
,
OW
,
F
),
dtype
),
out
:
T
.
Tensor
((
N
,
OH
,
OW
,
F
),
dtype
),
):
):
with
T
.
Kernel
(
with
T
.
Kernel
(
T
.
ceildiv
(
F
,
block_N
),
T
.
ceildiv
(
N
*
OH
*
OW
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
T
.
ceildiv
(
F
,
block_N
),
T
.
ceildiv
(
N
*
OH
*
OW
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
data_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
dtype
)
data_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
dtype
)
kernel_shared
=
T
.
alloc_shared
((
block_K
,
block_N
),
dtype
)
kernel_shared
=
T
.
alloc_shared
((
block_K
,
block_N
),
dtype
)
out_local
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
out_local
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
...
@@ -449,11 +395,13 @@ def convolution_im2col(N,
...
@@ -449,11 +395,13 @@ def convolution_im2col(N,
kernel_flat
=
T
.
Tensor
((
KH
*
KW
*
C
,
F
),
dtype
,
kernel
.
data
)
kernel_flat
=
T
.
Tensor
((
KH
*
KW
*
C
,
F
),
dtype
,
kernel
.
data
)
out_flat
=
T
.
Tensor
((
N
*
OH
*
OW
,
F
),
dtype
,
out
.
data
)
out_flat
=
T
.
Tensor
((
N
*
OH
*
OW
,
F
),
dtype
,
out
.
data
)
T
.
annotate_layout
({
T
.
annotate_layout
(
out_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
out_shared
),
{
data_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
data_shared
),
out_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
out_shared
),
kernel_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
kernel_shared
),
data_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
data_shared
),
})
kernel_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
kernel_shared
),
}
)
T
.
clear
(
out_local
)
T
.
clear
(
out_local
)
for
k_iter
in
T
.
Pipelined
(
T
.
ceildiv
(
KH
*
KW
*
C
,
block_K
),
num_stages
=
num_stages
):
for
k_iter
in
T
.
Pipelined
(
T
.
ceildiv
(
KH
*
KW
*
C
,
block_K
),
num_stages
=
num_stages
):
...
@@ -467,23 +415,9 @@ def convolution_im2col(N,
...
@@ -467,23 +415,9 @@ def convolution_im2col(N,
return
main
return
main
def
run_nvrtc_im2col_tma_desc
(
N
,
def
run_nvrtc_im2col_tma_desc
(
N
,
C
,
H
,
W
,
F
,
K
,
S
,
D
,
P
,
block_M
,
block_N
,
block_K
,
num_stages
=
3
,
num_threads
=
256
):
C
,
H
,
W
,
F
,
K
,
S
,
D
,
P
,
block_M
,
block_N
,
block_K
,
num_stages
=
3
,
num_threads
=
256
):
"""Test im2col TMA descriptor functionality in NVRTC backend."""
"""Test im2col TMA descriptor functionality in NVRTC backend."""
program
=
convolution_im2col
(
N
,
C
,
H
,
W
,
F
,
K
,
S
,
D
,
P
,
block_M
,
block_N
,
block_K
,
num_stages
,
program
=
convolution_im2col
(
N
,
C
,
H
,
W
,
F
,
K
,
S
,
D
,
P
,
block_M
,
block_N
,
block_K
,
num_stages
,
num_threads
)
num_threads
)
conv_kernel
=
tilelang
.
compile
(
program
,
out_idx
=-
1
,
execution_backend
=
"nvrtc"
)
conv_kernel
=
tilelang
.
compile
(
program
,
out_idx
=-
1
,
execution_backend
=
"nvrtc"
)
...
@@ -501,32 +435,20 @@ def run_nvrtc_im2col_tma_desc(N,
...
@@ -501,32 +435,20 @@ def run_nvrtc_im2col_tma_desc(N,
return
C
return
C
ref_c
=
ref_program
(
a
,
b
)
ref_c
=
ref_program
(
a
,
b
)
tilelang
.
testing
.
torch_assert_close
(
tilelang
.
testing
.
torch_assert_close
(
out_c
,
ref_c
,
atol
=
1e-2
,
rtol
=
1e-2
,
max_mismatched_ratio
=
0.05
)
out_c
,
ref_c
,
atol
=
1e-2
,
rtol
=
1e-2
,
max_mismatched_ratio
=
0.05
)
def
test_nvrtc_im2col_tma_desc
():
def
test_nvrtc_im2col_tma_desc
():
"""Test im2col TMA descriptor with NVRTC backend."""
"""Test im2col TMA descriptor with NVRTC backend."""
if
not
check_hopper
():
if
not
check_hopper
():
import
pytest
import
pytest
pytest
.
skip
(
"Test requires Hopper GPU (compute capability 9.0)"
)
pytest
.
skip
(
"Test requires Hopper GPU (compute capability 9.0)"
)
# Small test case for im2col TMA descriptor
# Small test case for im2col TMA descriptor
run_nvrtc_im2col_tma_desc
(
run_nvrtc_im2col_tma_desc
(
N
=
4
,
N
=
4
,
C
=
64
,
H
=
32
,
W
=
32
,
F
=
64
,
K
=
3
,
S
=
1
,
D
=
1
,
P
=
1
,
block_M
=
64
,
block_N
=
128
,
block_K
=
32
,
num_stages
=
3
,
num_threads
=
256
C
=
64
,
)
H
=
32
,
W
=
32
,
F
=
64
,
K
=
3
,
S
=
1
,
D
=
1
,
P
=
1
,
block_M
=
64
,
block_N
=
128
,
block_K
=
32
,
num_stages
=
3
,
num_threads
=
256
)
def
test_nvrtc_l2_persistent_map
():
def
test_nvrtc_l2_persistent_map
():
...
@@ -543,12 +465,11 @@ def test_nvrtc_l2_persistent_map():
...
@@ -543,12 +465,11 @@ def test_nvrtc_l2_persistent_map():
block_size
=
256
,
block_size
=
256
,
dtype
=
"float32"
,
dtype
=
"float32"
,
):
):
@
T
.
prim_func
@
T
.
prim_func
def
kernel
(
def
kernel
(
A
:
T
.
Tensor
((
M
,
N
),
dtype
),
A
:
T
.
Tensor
((
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
):
):
with
T
.
Kernel
(
M
*
N
//
block_size
,
threads
=
block_size
)
as
bx
:
with
T
.
Kernel
(
M
*
N
//
block_size
,
threads
=
block_size
)
as
bx
:
# Annotate L2 persistent cache for buffer B
# Annotate L2 persistent cache for buffer B
...
...
testing/python/jit/test_tilelang_jit_parcompile.py
View file @
29051439
...
@@ -16,9 +16,9 @@ def matmul_kernel_jit(
...
@@ -16,9 +16,9 @@ def matmul_kernel_jit(
block_K
,
block_K
,
trans_A
=
False
,
trans_A
=
False
,
trans_B
=
True
,
trans_B
=
True
,
in_dtype
=
'
float16
'
,
in_dtype
=
"
float16
"
,
out_dtype
=
'
float32
'
,
out_dtype
=
"
float32
"
,
accum_dtype
=
'
float32
'
,
accum_dtype
=
"
float32
"
,
num_stages
=
2
,
num_stages
=
2
,
threads
=
128
,
threads
=
128
,
):
):
...
@@ -31,9 +31,9 @@ def matmul_kernel_jit(
...
@@ -31,9 +31,9 @@ def matmul_kernel_jit(
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
...
...
testing/python/jit/test_tilelang_jit_tvm_ffi.py
View file @
29051439
...
@@ -28,9 +28,9 @@ def matmul(
...
@@ -28,9 +28,9 @@ def matmul(
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
...
@@ -74,9 +74,9 @@ def matmu_jit_kernel(
...
@@ -74,9 +74,9 @@ def matmu_jit_kernel(
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
...
@@ -144,6 +144,7 @@ def run_gemm_jit_kernel(
...
@@ -144,6 +144,7 @@ def run_gemm_jit_kernel(
def
ref_program
(
A
,
B
):
def
ref_program
(
A
,
B
):
import
torch
import
torch
C
=
torch
.
matmul
(
A
.
to
(
torch
.
float
),
B
.
to
(
torch
.
float
))
C
=
torch
.
matmul
(
A
.
to
(
torch
.
float
),
B
.
to
(
torch
.
float
))
C
=
C
.
to
(
out_dtype
)
C
=
C
.
to
(
out_dtype
)
return
C
return
C
...
@@ -171,19 +172,9 @@ def test_gemm_jit_kernel():
...
@@ -171,19 +172,9 @@ def test_gemm_jit_kernel():
)
)
def
run_tvm_ffi_kernel_do_bench
(
M
,
def
run_tvm_ffi_kernel_do_bench
(
N
,
M
,
N
,
K
,
trans_A
,
trans_B
,
in_dtype
,
out_dtype
,
dtypeAccum
,
block_M
,
block_N
,
block_K
,
num_stages
=
3
,
num_threads
=
128
K
,
):
trans_A
,
trans_B
,
in_dtype
,
out_dtype
,
dtypeAccum
,
block_M
,
block_N
,
block_K
,
num_stages
=
3
,
num_threads
=
128
):
program
=
matmul
(
program
=
matmul
(
M
,
M
,
N
,
N
,
...
@@ -216,23 +207,12 @@ def run_tvm_ffi_kernel_do_bench(M,
...
@@ -216,23 +207,12 @@ def run_tvm_ffi_kernel_do_bench(M,
def
test_tvm_ffi_kernel_do_bench
():
def
test_tvm_ffi_kernel_do_bench
():
run_tvm_ffi_kernel_do_bench
(
512
,
1024
,
768
,
False
,
False
,
"float16"
,
"float16"
,
"float16"
,
128
,
run_tvm_ffi_kernel_do_bench
(
512
,
1024
,
768
,
False
,
False
,
"float16"
,
"float16"
,
"float16"
,
128
,
256
,
32
,
2
)
256
,
32
,
2
)
def
run_tvm_ffi_kernel_multi_stream
(
def
run_tvm_ffi_kernel_multi_stream
(
M
,
M
,
N
,
K
,
trans_A
,
trans_B
,
in_dtype
,
out_dtype
,
dtypeAccum
,
block_M
,
block_N
,
block_K
,
num_stages
=
3
,
num_threads
=
128
N
,
):
K
,
trans_A
,
trans_B
,
in_dtype
,
out_dtype
,
dtypeAccum
,
block_M
,
block_N
,
block_K
,
num_stages
=
3
,
num_threads
=
128
):
program
=
matmul
(
program
=
matmul
(
M
,
M
,
N
,
N
,
...
@@ -269,23 +249,12 @@ def run_tvm_ffi_kernel_multi_stream(M,
...
@@ -269,23 +249,12 @@ def run_tvm_ffi_kernel_multi_stream(M,
def
test_tvm_ffi_kernel_multi_stream
():
def
test_tvm_ffi_kernel_multi_stream
():
run_tvm_ffi_kernel_multi_stream
(
512
,
1024
,
768
,
False
,
False
,
"float16"
,
"float16"
,
"float16"
,
run_tvm_ffi_kernel_multi_stream
(
512
,
1024
,
768
,
False
,
False
,
"float16"
,
"float16"
,
"float16"
,
128
,
256
,
32
,
2
)
128
,
256
,
32
,
2
)
def
run_tvm_ffi_dynamic_shape
(
def
run_tvm_ffi_dynamic_shape
(
M
,
M
,
N
,
K
,
trans_A
,
trans_B
,
in_dtype
,
out_dtype
,
dtypeAccum
,
block_M
,
block_N
,
block_K
,
num_stages
=
3
,
num_threads
=
128
N
,
):
K
,
trans_A
,
trans_B
,
in_dtype
,
out_dtype
,
dtypeAccum
,
block_M
,
block_N
,
block_K
,
num_stages
=
3
,
num_threads
=
128
):
program
=
matmul
(
program
=
matmul
(
M
,
M
,
N
,
N
,
...
@@ -325,21 +294,17 @@ def run_tvm_ffi_dynamic_shape(M,
...
@@ -325,21 +294,17 @@ def run_tvm_ffi_dynamic_shape(M,
matmul_kernel
(
tensor_a
,
tensor_b
,
tensor_c
)
matmul_kernel
(
tensor_a
,
tensor_b
,
tensor_c
)
tensor_ref_c
=
torch
.
matmul
(
tensor_a
.
to
(
torch
.
float
),
tensor_b
.
to
(
torch
.
float
)).
to
(
out_dtype
)
tensor_ref_c
=
torch
.
matmul
(
tensor_a
.
to
(
torch
.
float
),
tensor_b
.
to
(
torch
.
float
)).
to
(
out_dtype
)
tilelang
.
testing
.
torch_assert_close
(
tilelang
.
testing
.
torch_assert_close
(
tensor_c
,
tensor_ref_c
,
atol
=
1e-2
,
rtol
=
1e-2
,
max_mismatched_ratio
=
0.05
)
tensor_c
,
tensor_ref_c
,
atol
=
1e-2
,
rtol
=
1e-2
,
max_mismatched_ratio
=
0.05
)
def
test_tvm_ffi_dynamic_shape
():
def
test_tvm_ffi_dynamic_shape
():
run_tvm_ffi_dynamic_shape
(
run_tvm_ffi_dynamic_shape
(
T
.
dynamic
(
"m"
),
1024
,
768
,
False
,
False
,
"float16"
,
"float16"
,
"float16"
,
128
,
256
,
32
,
2
)
T
.
dynamic
(
"m"
),
1024
,
768
,
False
,
False
,
"float16"
,
"float16"
,
"float16"
,
128
,
256
,
32
,
2
)
run_tvm_ffi_dynamic_shape
(
run_tvm_ffi_dynamic_shape
(
T
.
dynamic
(
"m"
),
T
.
dynamic
(
"n"
),
768
,
False
,
False
,
"float16"
,
"float16"
,
"float16"
,
128
,
256
,
32
,
2
)
T
.
dynamic
(
"m"
),
T
.
dynamic
(
"n"
),
768
,
False
,
False
,
"float16"
,
"float16"
,
"float16"
,
128
,
256
,
32
,
2
)
run_tvm_ffi_dynamic_shape
(
run_tvm_ffi_dynamic_shape
(
T
.
dynamic
(
"m"
),
T
.
dynamic
(
"n"
),
T
.
dynamic
(
"k"
),
False
,
False
,
"float16"
,
"float16"
,
T
.
dynamic
(
"m"
),
T
.
dynamic
(
"n"
),
T
.
dynamic
(
"k"
),
False
,
False
,
"float16"
,
"float16"
,
"float16"
,
128
,
256
,
32
,
2
"float16"
,
128
,
256
,
32
,
2
)
)
def
check_hopper
():
def
check_hopper
():
...
@@ -350,35 +315,18 @@ def check_hopper():
...
@@ -350,35 +315,18 @@ def check_hopper():
return
compute_capability
==
(
9
,
0
)
return
compute_capability
==
(
9
,
0
)
def
convolution_im2col
(
N
,
def
convolution_im2col
(
N
,
C
,
H
,
W
,
F
,
K
,
S
,
D
,
P
,
block_M
,
block_N
,
block_K
,
num_stages
,
threads
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
C
,
H
,
W
,
F
,
K
,
S
,
D
,
P
,
block_M
,
block_N
,
block_K
,
num_stages
,
threads
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
KH
,
KW
=
K
,
K
KH
,
KW
=
K
,
K
OH
=
(
H
+
2
*
P
-
D
*
(
K
-
1
)
-
1
)
//
S
+
1
OH
=
(
H
+
2
*
P
-
D
*
(
K
-
1
)
-
1
)
//
S
+
1
OW
=
(
W
+
2
*
P
-
D
*
(
K
-
1
)
-
1
)
//
S
+
1
OW
=
(
W
+
2
*
P
-
D
*
(
K
-
1
)
-
1
)
//
S
+
1
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
data
:
T
.
Tensor
((
N
,
H
,
W
,
C
),
dtype
),
data
:
T
.
Tensor
((
N
,
H
,
W
,
C
),
dtype
),
kernel
:
T
.
Tensor
((
KH
,
KW
,
C
,
F
),
dtype
),
kernel
:
T
.
Tensor
((
KH
,
KW
,
C
,
F
),
dtype
),
out
:
T
.
Tensor
((
N
,
OH
,
OW
,
F
),
dtype
),
out
:
T
.
Tensor
((
N
,
OH
,
OW
,
F
),
dtype
),
):
):
with
T
.
Kernel
(
with
T
.
Kernel
(
T
.
ceildiv
(
F
,
block_N
),
T
.
ceildiv
(
N
*
OH
*
OW
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
T
.
ceildiv
(
F
,
block_N
),
T
.
ceildiv
(
N
*
OH
*
OW
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
data_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
dtype
)
data_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
dtype
)
kernel_shared
=
T
.
alloc_shared
((
block_K
,
block_N
),
dtype
)
kernel_shared
=
T
.
alloc_shared
((
block_K
,
block_N
),
dtype
)
out_local
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
out_local
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
...
@@ -387,11 +335,13 @@ def convolution_im2col(N,
...
@@ -387,11 +335,13 @@ def convolution_im2col(N,
kernel_flat
=
T
.
Tensor
((
KH
*
KW
*
C
,
F
),
dtype
,
kernel
.
data
)
kernel_flat
=
T
.
Tensor
((
KH
*
KW
*
C
,
F
),
dtype
,
kernel
.
data
)
out_flat
=
T
.
Tensor
((
N
*
OH
*
OW
,
F
),
dtype
,
out
.
data
)
out_flat
=
T
.
Tensor
((
N
*
OH
*
OW
,
F
),
dtype
,
out
.
data
)
T
.
annotate_layout
({
T
.
annotate_layout
(
out_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
out_shared
),
{
data_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
data_shared
),
out_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
out_shared
),
kernel_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
kernel_shared
),
data_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
data_shared
),
})
kernel_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
kernel_shared
),
}
)
T
.
clear
(
out_local
)
T
.
clear
(
out_local
)
for
k_iter
in
T
.
Pipelined
(
T
.
ceildiv
(
KH
*
KW
*
C
,
block_K
),
num_stages
=
num_stages
):
for
k_iter
in
T
.
Pipelined
(
T
.
ceildiv
(
KH
*
KW
*
C
,
block_K
),
num_stages
=
num_stages
):
...
@@ -405,23 +355,9 @@ def convolution_im2col(N,
...
@@ -405,23 +355,9 @@ def convolution_im2col(N,
return
main
return
main
def
run_tvm_ffi_im2col_tma_desc
(
N
,
def
run_tvm_ffi_im2col_tma_desc
(
N
,
C
,
H
,
W
,
F
,
K
,
S
,
D
,
P
,
block_M
,
block_N
,
block_K
,
num_stages
=
3
,
num_threads
=
256
):
C
,
H
,
W
,
F
,
K
,
S
,
D
,
P
,
block_M
,
block_N
,
block_K
,
num_stages
=
3
,
num_threads
=
256
):
"""Test im2col TMA descriptor functionality in tvm_ffi backend."""
"""Test im2col TMA descriptor functionality in tvm_ffi backend."""
program
=
convolution_im2col
(
N
,
C
,
H
,
W
,
F
,
K
,
S
,
D
,
P
,
block_M
,
block_N
,
block_K
,
num_stages
,
program
=
convolution_im2col
(
N
,
C
,
H
,
W
,
F
,
K
,
S
,
D
,
P
,
block_M
,
block_N
,
block_K
,
num_stages
,
num_threads
)
num_threads
)
conv_kernel
=
tilelang
.
compile
(
program
,
out_idx
=-
1
,
execution_backend
=
"tvm_ffi"
)
conv_kernel
=
tilelang
.
compile
(
program
,
out_idx
=-
1
,
execution_backend
=
"tvm_ffi"
)
...
@@ -439,32 +375,20 @@ def run_tvm_ffi_im2col_tma_desc(N,
...
@@ -439,32 +375,20 @@ def run_tvm_ffi_im2col_tma_desc(N,
return
C
return
C
ref_c
=
ref_program
(
a
,
b
)
ref_c
=
ref_program
(
a
,
b
)
tilelang
.
testing
.
torch_assert_close
(
tilelang
.
testing
.
torch_assert_close
(
out_c
,
ref_c
,
atol
=
1e-2
,
rtol
=
1e-2
,
max_mismatched_ratio
=
0.05
)
out_c
,
ref_c
,
atol
=
1e-2
,
rtol
=
1e-2
,
max_mismatched_ratio
=
0.05
)
def
test_tvm_ffi_im2col_tma_desc
():
def
test_tvm_ffi_im2col_tma_desc
():
"""Test im2col TMA descriptor with tvm_ffi backend."""
"""Test im2col TMA descriptor with tvm_ffi backend."""
if
not
check_hopper
():
if
not
check_hopper
():
import
pytest
import
pytest
pytest
.
skip
(
"Test requires Hopper GPU (compute capability 9.0)"
)
pytest
.
skip
(
"Test requires Hopper GPU (compute capability 9.0)"
)
# Small test case for im2col TMA descriptor
# Small test case for im2col TMA descriptor
run_tvm_ffi_im2col_tma_desc
(
run_tvm_ffi_im2col_tma_desc
(
N
=
4
,
N
=
4
,
C
=
64
,
H
=
32
,
W
=
32
,
F
=
64
,
K
=
3
,
S
=
1
,
D
=
1
,
P
=
1
,
block_M
=
64
,
block_N
=
128
,
block_K
=
32
,
num_stages
=
3
,
num_threads
=
256
C
=
64
,
)
H
=
32
,
W
=
32
,
F
=
64
,
K
=
3
,
S
=
1
,
D
=
1
,
P
=
1
,
block_M
=
64
,
block_N
=
128
,
block_K
=
32
,
num_stages
=
3
,
num_threads
=
256
)
def
test_tvm_ffi_l2_persistent_map
():
def
test_tvm_ffi_l2_persistent_map
():
...
@@ -481,12 +405,11 @@ def test_tvm_ffi_l2_persistent_map():
...
@@ -481,12 +405,11 @@ def test_tvm_ffi_l2_persistent_map():
block_size
=
256
,
block_size
=
256
,
dtype
=
"float32"
,
dtype
=
"float32"
,
):
):
@
T
.
prim_func
@
T
.
prim_func
def
kernel
(
def
kernel
(
A
:
T
.
Tensor
((
M
,
N
),
dtype
),
A
:
T
.
Tensor
((
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
):
):
with
T
.
Kernel
(
M
*
N
//
block_size
,
threads
=
block_size
)
as
bx
:
with
T
.
Kernel
(
M
*
N
//
block_size
,
threads
=
block_size
)
as
bx
:
# Annotate L2 persistent cache for buffer B
# Annotate L2 persistent cache for buffer B
...
@@ -506,8 +429,12 @@ def test_tvm_ffi_l2_persistent_map():
...
@@ -506,8 +429,12 @@ def test_tvm_ffi_l2_persistent_map():
kernel
=
elementwise_add_with_l2_cache
(
M
,
N
)
kernel
=
elementwise_add_with_l2_cache
(
M
,
N
)
source
=
kernel
.
get_host_source
()
source
=
kernel
.
get_host_source
()
assert
"__tvm_cuda_stream_set_access_policy_window_packed"
in
source
,
"Expected __tvm_cuda_stream_set_access_policy_window_packed in the kernel source"
assert
"__tvm_cuda_stream_set_access_policy_window_packed"
in
source
,
(
assert
"__tvm_cuda_stream_reset_access_policy_window_packed"
in
source
,
"Expected __tvm_cuda_stream_reset_access_policy_window_packed in the kernel source"
"Expected __tvm_cuda_stream_set_access_policy_window_packed in the kernel source"
)
assert
"__tvm_cuda_stream_reset_access_policy_window_packed"
in
source
,
(
"Expected __tvm_cuda_stream_reset_access_policy_window_packed in the kernel source"
)
# Create test tensors
# Create test tensors
a
=
torch
.
randn
(
M
,
N
,
dtype
=
torch
.
float32
).
cuda
()
a
=
torch
.
randn
(
M
,
N
,
dtype
=
torch
.
float32
).
cuda
()
...
...
testing/python/kernel/test_tilelang_kernel_bf16_gemm_mma.py
View file @
29051439
...
@@ -6,7 +6,8 @@ from tvm import DataType
...
@@ -6,7 +6,8 @@ from tvm import DataType
import
tilelang.language
as
T
import
tilelang.language
as
T
from
tilelang.intrinsics
import
get_swizzle_layout
from
tilelang.intrinsics
import
get_swizzle_layout
from
tilelang.intrinsics.mma_macro_generator
import
(
from
tilelang.intrinsics.mma_macro_generator
import
(
TensorCoreIntrinEmitter
,)
TensorCoreIntrinEmitter
,
)
from
tilelang.transform
import
simplify_prim_func
from
tilelang.transform
import
simplify_prim_func
from
tilelang.utils.tensor
import
map_torch_type
from
tilelang.utils.tensor
import
map_torch_type
...
@@ -111,12 +112,11 @@ def tl_matmul(
...
@@ -111,12 +112,11 @@ def tl_matmul(
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
,
scope
=
shared_scope
)
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
,
scope
=
shared_scope
)
B_shared
=
T
.
alloc_shared
(
B_shared_shape
,
in_dtype
,
scope
=
shared_scope
)
B_shared
=
T
.
alloc_shared
(
B_shared_shape
,
in_dtype
,
scope
=
shared_scope
)
C_shared
=
T
.
alloc_shared
(
C_shared_shape
,
out_dtype
,
scope
=
shared_scope
)
C_shared
=
T
.
alloc_shared
(
C_shared_shape
,
out_dtype
,
scope
=
shared_scope
)
...
@@ -124,10 +124,12 @@ def tl_matmul(
...
@@ -124,10 +124,12 @@ def tl_matmul(
B_local
=
T
.
alloc_local
((
warp_cols
*
local_size_b
),
in_dtype
)
B_local
=
T
.
alloc_local
((
warp_cols
*
local_size_b
),
in_dtype
)
C_local
=
T
.
alloc_local
((
warp_rows
*
warp_cols
*
local_size_c
),
accum_dtype
)
C_local
=
T
.
alloc_local
((
warp_rows
*
warp_cols
*
local_size_c
),
accum_dtype
)
T
.
annotate_layout
({
T
.
annotate_layout
(
A_shared
:
make_swizzle_layout
(
A_shared
),
{
B_shared
:
make_swizzle_layout
(
B_shared
),
A_shared
:
make_swizzle_layout
(
A_shared
),
})
B_shared
:
make_swizzle_layout
(
B_shared
),
}
)
# Improve L2 Cache
# Improve L2 Cache
T
.
use_swizzle
(
panel_size
=
10
)
T
.
use_swizzle
(
panel_size
=
10
)
...
@@ -135,7 +137,6 @@ def tl_matmul(
...
@@ -135,7 +137,6 @@ def tl_matmul(
T
.
clear
(
C_local
)
T
.
clear
(
C_local
)
for
ko
in
T
.
Pipelined
((
K
//
block_K
),
num_stages
=
stage
):
for
ko
in
T
.
Pipelined
((
K
//
block_K
),
num_stages
=
stage
):
# Load A into shared memory
# Load A into shared memory
for
i
,
k
in
T
.
Parallel
(
block_M
,
block_K
):
for
i
,
k
in
T
.
Parallel
(
block_M
,
block_K
):
A_shared
[
i
,
k
]
=
A
[
by
*
block_M
+
i
,
ko
*
block_K
+
k
]
A_shared
[
i
,
k
]
=
A
[
by
*
block_M
+
i
,
ko
*
block_K
+
k
]
...
@@ -145,7 +146,6 @@ def tl_matmul(
...
@@ -145,7 +146,6 @@ def tl_matmul(
B_shared
[
j
,
k
]
=
B
[
bx
*
block_N
+
j
,
ko
*
block_K
+
k
]
B_shared
[
j
,
k
]
=
B
[
bx
*
block_N
+
j
,
ko
*
block_K
+
k
]
for
ki
in
T
.
serial
(
0
,
(
block_K
//
micro_size_k
)):
for
ki
in
T
.
serial
(
0
,
(
block_K
//
micro_size_k
)):
# Load A into fragment
# Load A into fragment
mma_emitter
.
ldmatrix_a
(
mma_emitter
.
ldmatrix_a
(
A_local
,
A_local
,
...
...
testing/python/kernel/test_tilelang_kernel_element_wise_add.py
View file @
29051439
...
@@ -16,15 +16,15 @@ def elementwise_add(
...
@@ -16,15 +16,15 @@ def elementwise_add(
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Tensor
((
M
,
N
),
in_dtype
),
A
:
T
.
Tensor
((
M
,
N
),
in_dtype
),
B
:
T
.
Tensor
((
M
,
N
),
in_dtype
),
B
:
T
.
Tensor
((
M
,
N
),
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
start_x
=
bx
*
block_N
start_x
=
bx
*
block_N
start_y
=
by
*
block_M
start_y
=
by
*
block_M
for
(
local_y
,
local_x
)
in
T
.
Parallel
(
block_M
,
block_N
):
for
local_y
,
local_x
in
T
.
Parallel
(
block_M
,
block_N
):
y
=
start_y
+
local_y
y
=
start_y
+
local_y
x
=
start_x
+
local_x
x
=
start_x
+
local_x
...
...
testing/python/kernel/test_tilelang_kernel_fp8_gemm.py
View file @
29051439
...
@@ -12,12 +12,11 @@ def calc_diff(x, y):
...
@@ -12,12 +12,11 @@ def calc_diff(x, y):
def
matmul_nt
(
M
,
N
,
K
,
bM
,
bN
,
bK
,
in_dtype
,
out_dtype
,
accum_dtype
):
def
matmul_nt
(
M
,
N
,
K
,
bM
,
bN
,
bK
,
in_dtype
,
out_dtype
,
accum_dtype
):
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Tensor
((
M
,
K
),
in_dtype
),
A
:
T
.
Tensor
((
M
,
K
),
in_dtype
),
B
:
T
.
Tensor
((
N
,
K
),
in_dtype
),
B
:
T
.
Tensor
((
N
,
K
),
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
bN
),
T
.
ceildiv
(
M
,
bM
),
threads
=
128
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
bN
),
T
.
ceildiv
(
M
,
bM
),
threads
=
128
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
((
bM
,
bK
),
in_dtype
)
A_shared
=
T
.
alloc_shared
((
bM
,
bK
),
in_dtype
)
...
@@ -44,8 +43,7 @@ def assert_matmul_correctness(M, N, K, block_M, block_N, block_K, in_dtype, out_
...
@@ -44,8 +43,7 @@ def assert_matmul_correctness(M, N, K, block_M, block_N, block_K, in_dtype, out_
C
=
kernel
(
A
,
B
)
C
=
kernel
(
A
,
B
)
ref_c
=
torch
.
matmul
(
A
.
to
(
map_torch_type
(
accum_dtype
)),
ref_c
=
torch
.
matmul
(
A
.
to
(
map_torch_type
(
accum_dtype
)),
B
.
T
.
to
(
map_torch_type
(
accum_dtype
))).
to
(
map_torch_type
(
out_dtype
))
B
.
T
.
to
(
map_torch_type
(
accum_dtype
))).
to
(
map_torch_type
(
out_dtype
))
print
(
C
)
print
(
C
)
print
(
ref_c
)
print
(
ref_c
)
diff
=
calc_diff
(
C
,
ref_c
)
diff
=
calc_diff
(
C
,
ref_c
)
...
...
testing/python/kernel/test_tilelang_kernel_fp8_gemm_mma.py
View file @
29051439
...
@@ -6,7 +6,8 @@ from tvm import DataType
...
@@ -6,7 +6,8 @@ from tvm import DataType
import
tilelang.language
as
T
import
tilelang.language
as
T
from
tilelang.intrinsics
import
get_swizzle_layout
from
tilelang.intrinsics
import
get_swizzle_layout
from
tilelang.intrinsics.mma_macro_generator
import
(
from
tilelang.intrinsics.mma_macro_generator
import
(
TensorCoreIntrinEmitter
,)
TensorCoreIntrinEmitter
,
)
from
tilelang.transform
import
simplify_prim_func
from
tilelang.transform
import
simplify_prim_func
from
tilelang.utils.tensor
import
map_torch_type
from
tilelang.utils.tensor
import
map_torch_type
...
@@ -110,12 +111,11 @@ def tl_matmul(
...
@@ -110,12 +111,11 @@ def tl_matmul(
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
,
scope
=
shared_scope
)
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
,
scope
=
shared_scope
)
B_shared
=
T
.
alloc_shared
(
B_shared_shape
,
in_dtype
,
scope
=
shared_scope
)
B_shared
=
T
.
alloc_shared
(
B_shared_shape
,
in_dtype
,
scope
=
shared_scope
)
C_shared
=
T
.
alloc_shared
(
C_shared_shape
,
out_dtype
,
scope
=
shared_scope
)
C_shared
=
T
.
alloc_shared
(
C_shared_shape
,
out_dtype
,
scope
=
shared_scope
)
...
@@ -123,10 +123,12 @@ def tl_matmul(
...
@@ -123,10 +123,12 @@ def tl_matmul(
B_local
=
T
.
alloc_local
((
warp_cols
*
local_size_b
),
in_dtype
)
B_local
=
T
.
alloc_local
((
warp_cols
*
local_size_b
),
in_dtype
)
C_local
=
T
.
alloc_local
((
warp_rows
*
warp_cols
*
local_size_c
),
accum_dtype
)
C_local
=
T
.
alloc_local
((
warp_rows
*
warp_cols
*
local_size_c
),
accum_dtype
)
T
.
annotate_layout
({
T
.
annotate_layout
(
A_shared
:
make_swizzle_layout
(
A_shared
),
{
B_shared
:
make_swizzle_layout
(
B_shared
),
A_shared
:
make_swizzle_layout
(
A_shared
),
})
B_shared
:
make_swizzle_layout
(
B_shared
),
}
)
# Improve L2 Cache
# Improve L2 Cache
T
.
use_swizzle
(
panel_size
=
10
)
T
.
use_swizzle
(
panel_size
=
10
)
...
@@ -134,7 +136,6 @@ def tl_matmul(
...
@@ -134,7 +136,6 @@ def tl_matmul(
T
.
clear
(
C_local
)
T
.
clear
(
C_local
)
for
ko
in
T
.
Pipelined
((
K
//
block_K
),
num_stages
=
stage
):
for
ko
in
T
.
Pipelined
((
K
//
block_K
),
num_stages
=
stage
):
# Load A into shared memory
# Load A into shared memory
for
i
,
k
in
T
.
Parallel
(
block_M
,
block_K
):
for
i
,
k
in
T
.
Parallel
(
block_M
,
block_K
):
A_shared
[
i
,
k
]
=
A
[
by
*
block_M
+
i
,
ko
*
block_K
+
k
]
A_shared
[
i
,
k
]
=
A
[
by
*
block_M
+
i
,
ko
*
block_K
+
k
]
...
@@ -144,7 +145,6 @@ def tl_matmul(
...
@@ -144,7 +145,6 @@ def tl_matmul(
B_shared
[
j
,
k
]
=
B
[
bx
*
block_N
+
j
,
ko
*
block_K
+
k
]
B_shared
[
j
,
k
]
=
B
[
bx
*
block_N
+
j
,
ko
*
block_K
+
k
]
for
ki
in
T
.
serial
(
0
,
(
block_K
//
micro_size_k
)):
for
ki
in
T
.
serial
(
0
,
(
block_K
//
micro_size_k
)):
# Load A into fragment
# Load A into fragment
mma_emitter
.
ldmatrix_a
(
mma_emitter
.
ldmatrix_a
(
A_local
,
A_local
,
...
...
testing/python/kernel/test_tilelang_kernel_fp8_gemv_simt.py
View file @
29051439
...
@@ -27,8 +27,8 @@ def gemv_simt(
...
@@ -27,8 +27,8 @@ def gemv_simt(
):
):
assert
n_partition
is
not
None
,
"n_partition must be provided"
assert
n_partition
is
not
None
,
"n_partition must be provided"
assert
reduce_thread
is
not
None
,
(
assert
reduce_thread
is
not
None
,
(
"reduce_thread must be provided currently, as related bitblas.gpu.gemv.GEMV"
"reduce_thread must be provided currently, as related bitblas.gpu.gemv.GEMV
sch_outer_reduction_with_config is not implemented
"
"sch_outer_reduction_with_config is not implemented"
)
)
assert
isinstance
(
N
,
int
)
and
isinstance
(
K
,
int
),
"Do not support dynamic N and K Currently"
assert
isinstance
(
N
,
int
)
and
isinstance
(
K
,
int
),
"Do not support dynamic N and K Currently"
...
@@ -50,16 +50,15 @@ def gemv_simt(
...
@@ -50,16 +50,15 @@ def gemv_simt(
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
Bias
:
T
.
Tensor
(
Bias_shape
,
out_dtype
),
Bias
:
T
.
Tensor
(
Bias_shape
,
out_dtype
),
C
:
T
.
Tensor
(
C_shape
,
out_dtype
),
C
:
T
.
Tensor
(
C_shape
,
out_dtype
),
):
):
with
T
.
Kernel
(
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
n_partition
),
M
,
threads
=
(
reduce_thread
,
n_partition
))
as
(
T
.
ceildiv
(
N
,
n_partition
),
M
,
threads
=
(
reduce_thread
,
n_partition
))
as
(
bx
,
bx
,
by
,
by
,
):
):
A_local
=
T
.
alloc_local
((
micro_size_k
,),
in_dtype
)
A_local
=
T
.
alloc_local
((
micro_size_k
,),
in_dtype
)
B_local
=
T
.
alloc_local
((
micro_size_k
,),
in_dtype
)
B_local
=
T
.
alloc_local
((
micro_size_k
,),
in_dtype
)
accum_res
=
T
.
alloc_local
((
1
,),
accum_dtype
)
accum_res
=
T
.
alloc_local
((
1
,),
accum_dtype
)
...
@@ -88,13 +87,12 @@ def gemv_simt(
...
@@ -88,13 +87,12 @@ def gemv_simt(
)
)
else
:
else
:
for
ki
in
T
.
serial
(
micro_size_k
):
for
ki
in
T
.
serial
(
micro_size_k
):
accum_res
[
0
]
+=
A_local
[
ki
].
astype
(
accum_dtype
)
*
B_local
[
ki
].
astype
(
accum_res
[
0
]
+=
A_local
[
ki
].
astype
(
accum_dtype
)
*
B_local
[
ki
].
astype
(
accum_dtype
)
accum_dtype
)
with
T
.
attr
(
with
T
.
attr
(
T
.
comm_reducer
(
lambda
x
,
y
:
x
+
y
,
[
T
.
Cast
(
accum_dtype
,
0
)]),
T
.
comm_reducer
(
lambda
x
,
y
:
x
+
y
,
[
T
.
Cast
(
accum_dtype
,
0
)]),
"reduce_scope"
,
"reduce_scope"
,
T
.
reinterpret
(
T
.
uint64
(
0
),
dtype
=
"handle"
),
T
.
reinterpret
(
T
.
uint64
(
0
),
dtype
=
"handle"
),
):
):
T
.
evaluate
(
T
.
evaluate
(
T
.
tvm_thread_allreduce
(
T
.
tvm_thread_allreduce
(
...
@@ -104,11 +102,11 @@ def gemv_simt(
...
@@ -104,11 +102,11 @@ def gemv_simt(
reduced_accum_res
[
0
],
reduced_accum_res
[
0
],
kr
,
kr
,
dtype
=
"handle"
,
dtype
=
"handle"
,
))
)
)
if
kr
==
0
:
if
kr
==
0
:
if
with_bias
:
if
with_bias
:
C
[
by
,
C
[
by
,
bx
*
n_partition
+
ni
]
=
reduced_accum_res
[
0
]
+
Bias
[
bx
*
n_partition
+
ni
]
bx
*
n_partition
+
ni
]
=
reduced_accum_res
[
0
]
+
Bias
[
bx
*
n_partition
+
ni
]
else
:
else
:
C
[
by
,
bx
*
n_partition
+
ni
]
=
reduced_accum_res
[
0
]
C
[
by
,
bx
*
n_partition
+
ni
]
=
reduced_accum_res
[
0
]
...
...
testing/python/kernel/test_tilelang_kernel_gemm.py
View file @
29051439
...
@@ -26,9 +26,9 @@ def matmul(
...
@@ -26,9 +26,9 @@ def matmul(
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
...
@@ -95,8 +95,8 @@ def run_gemm(
...
@@ -95,8 +95,8 @@ def run_gemm(
if
in_dtype
==
"float32"
:
if
in_dtype
==
"float32"
:
# Convert float32 to tfloat32 because tfloat32 mma cannot truncate
# Convert float32 to tfloat32 because tfloat32 mma cannot truncate
# float32 automatically, -0x1000 meas
# float32 automatically, -0x1000 meas
A
=
(
(
A
.
view
(
torch
.
int32
)
-
0x1000
)
)
.
view
(
torch
.
float32
)
A
=
(
A
.
view
(
torch
.
int32
)
-
0x1000
).
view
(
torch
.
float32
)
B
=
(
(
B
.
view
(
torch
.
int32
)
-
0x1000
)
)
.
view
(
torch
.
float32
)
B
=
(
B
.
view
(
torch
.
int32
)
-
0x1000
).
view
(
torch
.
float32
)
C
=
torch
.
matmul
(
A
.
to
(
torch
.
float
),
B
.
to
(
torch
.
float
))
C
=
torch
.
matmul
(
A
.
to
(
torch
.
float
),
B
.
to
(
torch
.
float
))
C
=
C
.
to
(
torch
.
__getattribute__
(
out_dtype
))
C
=
C
.
to
(
torch
.
__getattribute__
(
out_dtype
))
return
C
return
C
...
@@ -321,9 +321,9 @@ def matmul_sr(
...
@@ -321,9 +321,9 @@ def matmul_sr(
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
...
@@ -441,9 +441,9 @@ def matmul_rs(
...
@@ -441,9 +441,9 @@ def matmul_rs(
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
,
scope
=
"shared"
)
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
,
scope
=
"shared"
)
...
...
testing/python/kernel/test_tilelang_kernel_gemm_mma_intrinsic.py
View file @
29051439
...
@@ -6,7 +6,8 @@ from tvm import DataType
...
@@ -6,7 +6,8 @@ from tvm import DataType
import
tilelang.language
as
T
import
tilelang.language
as
T
from
tilelang.intrinsics
import
get_swizzle_layout
from
tilelang.intrinsics
import
get_swizzle_layout
from
tilelang.intrinsics.mma_macro_generator
import
(
from
tilelang.intrinsics.mma_macro_generator
import
(
TensorCoreIntrinEmitter
,)
TensorCoreIntrinEmitter
,
)
from
tilelang.transform
import
simplify_prim_func
from
tilelang.transform
import
simplify_prim_func
from
tilelang.utils.tensor
import
map_torch_type
from
tilelang.utils.tensor
import
map_torch_type
...
@@ -111,12 +112,11 @@ def tl_matmul(
...
@@ -111,12 +112,11 @@ def tl_matmul(
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
,
scope
=
shared_scope
)
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
,
scope
=
shared_scope
)
B_shared
=
T
.
alloc_shared
(
B_shared_shape
,
in_dtype
,
scope
=
shared_scope
)
B_shared
=
T
.
alloc_shared
(
B_shared_shape
,
in_dtype
,
scope
=
shared_scope
)
C_shared
=
T
.
alloc_shared
(
C_shared_shape
,
out_dtype
,
scope
=
shared_scope
)
C_shared
=
T
.
alloc_shared
(
C_shared_shape
,
out_dtype
,
scope
=
shared_scope
)
...
@@ -124,10 +124,12 @@ def tl_matmul(
...
@@ -124,10 +124,12 @@ def tl_matmul(
B_local
=
T
.
alloc_local
((
warp_cols
*
local_size_b
),
in_dtype
)
B_local
=
T
.
alloc_local
((
warp_cols
*
local_size_b
),
in_dtype
)
C_local
=
T
.
alloc_local
((
warp_rows
*
warp_cols
*
local_size_c
),
accum_dtype
)
C_local
=
T
.
alloc_local
((
warp_rows
*
warp_cols
*
local_size_c
),
accum_dtype
)
T
.
annotate_layout
({
T
.
annotate_layout
(
A_shared
:
make_swizzle_layout
(
A_shared
),
{
B_shared
:
make_swizzle_layout
(
B_shared
),
A_shared
:
make_swizzle_layout
(
A_shared
),
})
B_shared
:
make_swizzle_layout
(
B_shared
),
}
)
# Improve L2 Cache
# Improve L2 Cache
T
.
use_swizzle
(
panel_size
=
10
)
T
.
use_swizzle
(
panel_size
=
10
)
...
@@ -135,7 +137,6 @@ def tl_matmul(
...
@@ -135,7 +137,6 @@ def tl_matmul(
T
.
clear
(
C_local
)
T
.
clear
(
C_local
)
for
ko
in
T
.
Pipelined
((
K
//
block_K
),
num_stages
=
stage
):
for
ko
in
T
.
Pipelined
((
K
//
block_K
),
num_stages
=
stage
):
# Load A into shared memory
# Load A into shared memory
for
i
,
k
in
T
.
Parallel
(
block_M
,
block_K
):
for
i
,
k
in
T
.
Parallel
(
block_M
,
block_K
):
A_shared
[
i
,
k
]
=
A
[
by
*
block_M
+
i
,
ko
*
block_K
+
k
]
A_shared
[
i
,
k
]
=
A
[
by
*
block_M
+
i
,
ko
*
block_K
+
k
]
...
@@ -145,7 +146,6 @@ def tl_matmul(
...
@@ -145,7 +146,6 @@ def tl_matmul(
B_shared
[
j
,
k
]
=
B
[
bx
*
block_N
+
j
,
ko
*
block_K
+
k
]
B_shared
[
j
,
k
]
=
B
[
bx
*
block_N
+
j
,
ko
*
block_K
+
k
]
for
ki
in
T
.
serial
(
0
,
(
block_K
//
micro_size_k
)):
for
ki
in
T
.
serial
(
0
,
(
block_K
//
micro_size_k
)):
# Load A into fragment
# Load A into fragment
mma_emitter
.
ldmatrix_a
(
mma_emitter
.
ldmatrix_a
(
A_local
,
A_local
,
...
...
testing/python/kernel/test_tilelang_kernel_gemm_simt.py
View file @
29051439
...
@@ -76,12 +76,11 @@ def tl_matmul_simt(
...
@@ -76,12 +76,11 @@ def tl_matmul_simt(
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
(
C_shape
,
out_dtype
),
C
:
T
.
Tensor
(
C_shape
,
out_dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
,
scope
=
shared_scope
)
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
,
scope
=
shared_scope
)
B_shared
=
T
.
alloc_shared
(
B_shared_shape
,
in_dtype
,
scope
=
shared_scope
)
B_shared
=
T
.
alloc_shared
(
B_shared_shape
,
in_dtype
,
scope
=
shared_scope
)
...
@@ -97,7 +96,6 @@ def tl_matmul_simt(
...
@@ -97,7 +96,6 @@ def tl_matmul_simt(
T
.
clear
(
C_local
)
T
.
clear
(
C_local
)
for
ko
in
T
.
serial
(
K
//
block_K
):
for
ko
in
T
.
serial
(
K
//
block_K
):
# Load A into shared memory
# Load A into shared memory
for
i
,
k
in
T
.
Parallel
(
block_M
,
block_K
):
for
i
,
k
in
T
.
Parallel
(
block_M
,
block_K
):
A_shared
[
i
,
k
]
=
A
[
by
*
block_M
+
i
,
ko
*
block_K
+
k
]
A_shared
[
i
,
k
]
=
A
[
by
*
block_M
+
i
,
ko
*
block_K
+
k
]
...
@@ -109,29 +107,24 @@ def tl_matmul_simt(
...
@@ -109,29 +107,24 @@ def tl_matmul_simt(
for
ki
in
T
.
serial
((
block_K
//
micro_size_k
)):
for
ki
in
T
.
serial
((
block_K
//
micro_size_k
)):
for
i
in
T
.
serial
(
local_size_a
):
for
i
in
T
.
serial
(
local_size_a
):
for
mk
in
T
.
vectorized
(
micro_size_k
):
for
mk
in
T
.
vectorized
(
micro_size_k
):
A_local
[
i
,
mk
]
=
A_shared
[
warp_m
*
local_size_a
+
i
,
A_local
[
i
,
mk
]
=
A_shared
[
warp_m
*
local_size_a
+
i
,
ki
*
micro_size_k
+
mk
]
ki
*
micro_size_k
+
mk
]
for
i
in
T
.
serial
(
local_size_b
):
for
i
in
T
.
serial
(
local_size_b
):
for
mk
in
T
.
vectorized
(
micro_size_k
):
for
mk
in
T
.
vectorized
(
micro_size_k
):
B_local
[
i
,
mk
]
=
B_shared
[
warp_n
*
local_size_b
+
i
,
B_local
[
i
,
mk
]
=
B_shared
[
warp_n
*
local_size_b
+
i
,
ki
*
micro_size_k
+
mk
]
ki
*
micro_size_k
+
mk
]
for
i
,
j
in
T
.
grid
(
local_size_a
,
local_size_b
):
for
i
,
j
in
T
.
grid
(
local_size_a
,
local_size_b
):
for
mk
in
T
.
serial
(
micro_size_k
//
dp4a_size
):
for
mk
in
T
.
serial
(
micro_size_k
//
dp4a_size
):
if
use_dp4a
:
if
use_dp4a
:
T
.
dp4a
(
A_local
[
i
,
mk
*
dp4a_size
],
B_local
[
j
,
mk
*
dp4a_size
],
T
.
dp4a
(
A_local
[
i
,
mk
*
dp4a_size
],
B_local
[
j
,
mk
*
dp4a_size
],
C_local
[
i
*
local_size_b
+
j
])
C_local
[
i
*
local_size_b
+
j
])
else
:
else
:
for
dp4a_idx
in
T
.
serial
(
dp4a_size
):
for
dp4a_idx
in
T
.
serial
(
dp4a_size
):
C_local
[
i
*
local_size_b
+
C_local
[
i
*
local_size_b
+
j
]
+=
(
j
]
+=
A_local
[
i
,
mk
*
dp4a_size
+
A_local
[
i
,
mk
*
dp4a_size
+
dp4a_idx
]
*
B_local
[
j
,
mk
*
dp4a_size
+
dp4a_idx
]
dp4a_idx
]
*
B_local
[
j
,
mk
*
dp4a_size
+
)
dp4a_idx
]
for
i
,
j
in
T
.
grid
(
local_size_a
,
local_size_b
):
for
i
,
j
in
T
.
grid
(
local_size_a
,
local_size_b
):
C
[
by
*
block_M
+
warp_m
*
local_size_a
+
i
,
C
[
by
*
block_M
+
warp_m
*
local_size_a
+
i
,
bx
*
block_N
+
warp_n
*
local_size_b
+
j
]
=
C_local
[
i
*
local_size_b
+
j
]
bx
*
block_N
+
warp_n
*
local_size_b
+
j
]
=
C_local
[
i
*
local_size_b
+
j
]
return
main
return
main
...
...
testing/python/kernel/test_tilelang_kernel_gemm_with_stride.py
View file @
29051439
...
@@ -5,12 +5,11 @@ import torch
...
@@ -5,12 +5,11 @@ import torch
def
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
def
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
):
):
# Initialize Kernel Context
# Initialize Kernel Context
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
128
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
128
)
as
(
bx
,
by
):
...
@@ -59,7 +58,8 @@ def run_gemm_with_stride_ss(M: int, N: int, K: int, block_M: int, block_N: int,
...
@@ -59,7 +58,8 @@ def run_gemm_with_stride_ss(M: int, N: int, K: int, block_M: int, block_N: int,
pass_configs
=
{
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
})
},
)
# Create random input tensors on the GPU
# Create random input tensors on the GPU
a
=
torch
.
randn
(
M
,
K
,
device
=
"cuda"
,
dtype
=
torch
.
float16
)
a
=
torch
.
randn
(
M
,
K
,
device
=
"cuda"
,
dtype
=
torch
.
float16
)
b
=
torch
.
randn
(
K
,
N
,
device
=
"cuda"
,
dtype
=
torch
.
float16
)
b
=
torch
.
randn
(
K
,
N
,
device
=
"cuda"
,
dtype
=
torch
.
float16
)
...
...
testing/python/kernel/test_tilelang_kernel_gemv_simt.py
View file @
29051439
...
@@ -27,8 +27,8 @@ def gemv_simt(
...
@@ -27,8 +27,8 @@ def gemv_simt(
):
):
assert
n_partition
is
not
None
,
"n_partition must be provided"
assert
n_partition
is
not
None
,
"n_partition must be provided"
assert
reduce_thread
is
not
None
,
(
assert
reduce_thread
is
not
None
,
(
"reduce_thread must be provided currently, as related bitblas.gpu.gemv.GEMV"
"reduce_thread must be provided currently, as related bitblas.gpu.gemv.GEMV
sch_outer_reduction_with_config is not implemented
"
"sch_outer_reduction_with_config is not implemented"
)
)
assert
isinstance
(
N
,
int
)
and
isinstance
(
K
,
int
),
"Do not support dynamic N and K Currently"
assert
isinstance
(
N
,
int
)
and
isinstance
(
K
,
int
),
"Do not support dynamic N and K Currently"
...
@@ -50,16 +50,15 @@ def gemv_simt(
...
@@ -50,16 +50,15 @@ def gemv_simt(
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
Bias
:
T
.
Tensor
(
Bias_shape
,
out_dtype
),
Bias
:
T
.
Tensor
(
Bias_shape
,
out_dtype
),
C
:
T
.
Tensor
(
C_shape
,
out_dtype
),
C
:
T
.
Tensor
(
C_shape
,
out_dtype
),
):
):
with
T
.
Kernel
(
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
n_partition
),
M
,
threads
=
(
reduce_thread
,
n_partition
))
as
(
T
.
ceildiv
(
N
,
n_partition
),
M
,
threads
=
(
reduce_thread
,
n_partition
))
as
(
bx
,
bx
,
by
,
by
,
):
):
A_local
=
T
.
alloc_local
((
micro_size_k
,),
in_dtype
)
A_local
=
T
.
alloc_local
((
micro_size_k
,),
in_dtype
)
B_local
=
T
.
alloc_local
((
micro_size_k
,),
in_dtype
)
B_local
=
T
.
alloc_local
((
micro_size_k
,),
in_dtype
)
accum_res
=
T
.
alloc_local
((
1
,),
accum_dtype
)
accum_res
=
T
.
alloc_local
((
1
,),
accum_dtype
)
...
@@ -88,13 +87,12 @@ def gemv_simt(
...
@@ -88,13 +87,12 @@ def gemv_simt(
)
)
else
:
else
:
for
ki
in
T
.
serial
(
micro_size_k
):
for
ki
in
T
.
serial
(
micro_size_k
):
accum_res
[
0
]
+=
A_local
[
ki
].
astype
(
accum_dtype
)
*
B_local
[
ki
].
astype
(
accum_res
[
0
]
+=
A_local
[
ki
].
astype
(
accum_dtype
)
*
B_local
[
ki
].
astype
(
accum_dtype
)
accum_dtype
)
with
T
.
attr
(
with
T
.
attr
(
T
.
comm_reducer
(
lambda
x
,
y
:
x
+
y
,
[
T
.
Cast
(
accum_dtype
,
0
)]),
T
.
comm_reducer
(
lambda
x
,
y
:
x
+
y
,
[
T
.
Cast
(
accum_dtype
,
0
)]),
"reduce_scope"
,
"reduce_scope"
,
T
.
reinterpret
(
T
.
uint64
(
0
),
dtype
=
"handle"
),
T
.
reinterpret
(
T
.
uint64
(
0
),
dtype
=
"handle"
),
):
):
T
.
evaluate
(
T
.
evaluate
(
T
.
tvm_thread_allreduce
(
T
.
tvm_thread_allreduce
(
...
@@ -104,11 +102,11 @@ def gemv_simt(
...
@@ -104,11 +102,11 @@ def gemv_simt(
reduced_accum_res
[
0
],
reduced_accum_res
[
0
],
kr
,
kr
,
dtype
=
"handle"
,
dtype
=
"handle"
,
))
)
)
if
kr
==
0
:
if
kr
==
0
:
if
with_bias
:
if
with_bias
:
C
[
by
,
C
[
by
,
bx
*
n_partition
+
ni
]
=
reduced_accum_res
[
0
]
+
Bias
[
bx
*
n_partition
+
ni
]
bx
*
n_partition
+
ni
]
=
reduced_accum_res
[
0
]
+
Bias
[
bx
*
n_partition
+
ni
]
else
:
else
:
C
[
by
,
bx
*
n_partition
+
ni
]
=
reduced_accum_res
[
0
]
C
[
by
,
bx
*
n_partition
+
ni
]
=
reduced_accum_res
[
0
]
...
...
testing/python/kernel/test_tilelang_kernel_int4_gemm_mma.py
View file @
29051439
...
@@ -4,7 +4,8 @@ from tilelang import tvm as tvm
...
@@ -4,7 +4,8 @@ from tilelang import tvm as tvm
import
tilelang.testing
import
tilelang.testing
import
tilelang.language
as
T
import
tilelang.language
as
T
from
tilelang.intrinsics
import
(
from
tilelang.intrinsics
import
(
make_mma_swizzle_layout
as
make_swizzle_layout
,)
make_mma_swizzle_layout
as
make_swizzle_layout
,
)
from
tilelang.intrinsics.mma_macro_generator
import
(
from
tilelang.intrinsics.mma_macro_generator
import
(
INT4TensorCoreIntrinEmitter
,
INT4TensorCoreIntrinEmitter
,
...
@@ -91,12 +92,11 @@ def tl_matmul(
...
@@ -91,12 +92,11 @@ def tl_matmul(
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
,
scope
=
shared_scope
)
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
,
scope
=
shared_scope
)
B_shared
=
T
.
alloc_shared
(
B_shared_shape
,
in_dtype
,
scope
=
shared_scope
)
B_shared
=
T
.
alloc_shared
(
B_shared_shape
,
in_dtype
,
scope
=
shared_scope
)
C_shared
=
T
.
alloc_shared
(
C_shared_shape
,
out_dtype
,
scope
=
shared_scope
)
C_shared
=
T
.
alloc_shared
(
C_shared_shape
,
out_dtype
,
scope
=
shared_scope
)
...
@@ -104,10 +104,12 @@ def tl_matmul(
...
@@ -104,10 +104,12 @@ def tl_matmul(
B_local
=
T
.
alloc_local
((
warp_cols
*
local_size_b
),
in_dtype
)
B_local
=
T
.
alloc_local
((
warp_cols
*
local_size_b
),
in_dtype
)
C_local
=
T
.
alloc_local
((
warp_rows
*
warp_cols
*
local_size_c
),
accum_dtype
)
C_local
=
T
.
alloc_local
((
warp_rows
*
warp_cols
*
local_size_c
),
accum_dtype
)
T
.
annotate_layout
({
T
.
annotate_layout
(
A_shared
:
make_swizzle_layout
(
A_shared
),
{
B_shared
:
make_swizzle_layout
(
B_shared
),
A_shared
:
make_swizzle_layout
(
A_shared
),
})
B_shared
:
make_swizzle_layout
(
B_shared
),
}
)
# Improve L2 Cache
# Improve L2 Cache
T
.
use_swizzle
(
panel_size
=
10
)
T
.
use_swizzle
(
panel_size
=
10
)
...
@@ -115,7 +117,6 @@ def tl_matmul(
...
@@ -115,7 +117,6 @@ def tl_matmul(
T
.
clear
(
C_local
)
T
.
clear
(
C_local
)
for
ko
in
T
.
Pipelined
(
T
.
ceildiv
(
K
,
block_K
),
num_stages
=
stage
):
for
ko
in
T
.
Pipelined
(
T
.
ceildiv
(
K
,
block_K
),
num_stages
=
stage
):
# Load A into shared memory
# Load A into shared memory
for
i
,
k
in
T
.
Parallel
(
block_M
,
block_K
):
for
i
,
k
in
T
.
Parallel
(
block_M
,
block_K
):
A_shared
[
i
,
k
]
=
A
[
by
*
block_M
+
i
,
ko
*
block_K
+
k
]
A_shared
[
i
,
k
]
=
A
[
by
*
block_M
+
i
,
ko
*
block_K
+
k
]
...
@@ -125,7 +126,6 @@ def tl_matmul(
...
@@ -125,7 +126,6 @@ def tl_matmul(
B_shared
[
j
,
k
]
=
B
[
bx
*
block_N
+
j
,
ko
*
block_K
+
k
]
B_shared
[
j
,
k
]
=
B
[
bx
*
block_N
+
j
,
ko
*
block_K
+
k
]
for
ki
in
T
.
serial
(
0
,
(
block_K
//
micro_size_k
)):
for
ki
in
T
.
serial
(
0
,
(
block_K
//
micro_size_k
)):
# Load A into fragment
# Load A into fragment
mma_emitter
.
ldmatrix_a
(
mma_emitter
.
ldmatrix_a
(
A_local
,
A_local
,
...
@@ -168,7 +168,8 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
...
@@ -168,7 +168,8 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
out_idx
=
[
2
],
out_idx
=
[
2
],
pass_configs
=
{
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_DEBUG_MERGE_SHARED_MEMORY_ALLOCATIONS
:
True
,
tilelang
.
PassConfigKey
.
TL_DEBUG_MERGE_SHARED_MEMORY_ALLOCATIONS
:
True
,
})
},
)
print
(
kernel
.
get_kernel_source
())
print
(
kernel
.
get_kernel_source
())
profiler
=
kernel
.
get_profiler
()
profiler
=
kernel
.
get_profiler
()
...
@@ -285,12 +286,11 @@ def tl_matmul_weight_only_transform(
...
@@ -285,12 +286,11 @@ def tl_matmul_weight_only_transform(
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
,
scope
=
shared_scope
)
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
,
scope
=
shared_scope
)
B_shared
=
T
.
alloc_shared
(
B_shared_shape
,
in_dtype
,
scope
=
shared_scope
)
B_shared
=
T
.
alloc_shared
(
B_shared_shape
,
in_dtype
,
scope
=
shared_scope
)
C_shared
=
T
.
alloc_shared
(
C_shared_shape
,
out_dtype
,
scope
=
shared_scope
)
C_shared
=
T
.
alloc_shared
(
C_shared_shape
,
out_dtype
,
scope
=
shared_scope
)
...
@@ -298,10 +298,12 @@ def tl_matmul_weight_only_transform(
...
@@ -298,10 +298,12 @@ def tl_matmul_weight_only_transform(
B_local
=
T
.
alloc_local
((
warp_cols
*
local_size_b
),
in_dtype
)
B_local
=
T
.
alloc_local
((
warp_cols
*
local_size_b
),
in_dtype
)
C_local
=
T
.
alloc_local
((
warp_rows
*
warp_cols
*
local_size_c
),
accum_dtype
)
C_local
=
T
.
alloc_local
((
warp_rows
*
warp_cols
*
local_size_c
),
accum_dtype
)
T
.
annotate_layout
({
T
.
annotate_layout
(
A_shared
:
make_swizzle_layout
(
A_shared
),
{
B_shared
:
make_swizzle_layout
(
B_shared
),
A_shared
:
make_swizzle_layout
(
A_shared
),
})
B_shared
:
make_swizzle_layout
(
B_shared
),
}
)
# Improve L2 Cache
# Improve L2 Cache
T
.
use_swizzle
(
panel_size
=
10
)
T
.
use_swizzle
(
panel_size
=
10
)
...
@@ -309,19 +311,15 @@ def tl_matmul_weight_only_transform(
...
@@ -309,19 +311,15 @@ def tl_matmul_weight_only_transform(
T
.
clear
(
C_local
)
T
.
clear
(
C_local
)
for
ko
in
T
.
Pipelined
(
T
.
ceildiv
(
K
,
block_K
),
num_stages
=
stage
):
for
ko
in
T
.
Pipelined
(
T
.
ceildiv
(
K
,
block_K
),
num_stages
=
stage
):
# Load A into shared memory
# Load A into shared memory
for
i
,
k
in
T
.
Parallel
(
block_M
,
block_K
):
for
i
,
k
in
T
.
Parallel
(
block_M
,
block_K
):
A_shared
[
i
,
k
]
=
A
[
by
*
block_M
+
i
,
ko
*
block_K
+
k
]
A_shared
[
i
,
k
]
=
A
[
by
*
block_M
+
i
,
ko
*
block_K
+
k
]
# Load B into shared memory
# Load B into shared memory
for
j
,
k
,
jj
,
kk
in
T
.
Parallel
(
block_N
//
micro_size_y
,
block_K
//
micro_size_k
,
for
j
,
k
,
jj
,
kk
in
T
.
Parallel
(
block_N
//
micro_size_y
,
block_K
//
micro_size_k
,
micro_size_y
,
micro_size_k
):
micro_size_y
,
micro_size_k
):
B_shared
[
j
,
k
,
jj
,
kk
]
=
B
[
bx
*
(
block_N
//
micro_size_y
)
+
j
,
ko
*
(
block_K
//
micro_size_k
)
+
k
,
jj
,
kk
]
B_shared
[
j
,
k
,
jj
,
kk
]
=
B
[
bx
*
(
block_N
//
micro_size_y
)
+
j
,
ko
*
(
block_K
//
micro_size_k
)
+
k
,
jj
,
kk
]
for
ki
in
T
.
serial
(
0
,
(
block_K
//
micro_size_k
)):
for
ki
in
T
.
serial
(
0
,
(
block_K
//
micro_size_k
)):
# Load A into fragment
# Load A into fragment
mma_emitter
.
ldmatrix_a
(
mma_emitter
.
ldmatrix_a
(
A_local
,
A_local
,
...
@@ -359,6 +357,7 @@ def tl_matmul_weight_only_transform(
...
@@ -359,6 +357,7 @@ def tl_matmul_weight_only_transform(
def
assert_tl_matmul_weight_only_transform_correctness
(
M
,
N
,
K
,
in_dtype
,
out_dtype
,
accum_dtype
):
def
assert_tl_matmul_weight_only_transform_correctness
(
M
,
N
,
K
,
in_dtype
,
out_dtype
,
accum_dtype
):
import
bitblas
import
bitblas
matmul
=
tl_matmul_weight_only_transform
(
M
,
N
,
K
,
in_dtype
,
out_dtype
,
accum_dtype
)
matmul
=
tl_matmul_weight_only_transform
(
M
,
N
,
K
,
in_dtype
,
out_dtype
,
accum_dtype
)
kernel
=
tilelang
.
compile
(
matmul
,
out_idx
=
[
2
])
kernel
=
tilelang
.
compile
(
matmul
,
out_idx
=
[
2
])
profiler
=
kernel
.
get_profiler
()
profiler
=
kernel
.
get_profiler
()
...
...
testing/python/language/test_tilelang_capture.py
View file @
29051439
...
@@ -6,16 +6,17 @@ import gc
...
@@ -6,16 +6,17 @@ import gc
def
test_tilelang_capture
():
def
test_tilelang_capture
():
@
tilelang
.
jit
(
@
tilelang
.
jit
(
pass_configs
=
{
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
},)
},
)
def
get_dummy_kernel
():
def
get_dummy_kernel
():
@
T
.
prim_func
@
T
.
prim_func
def
dummy_kernel
(
a
:
T
.
Tensor
[(
1
,),
T
.
float32
],):
def
dummy_kernel
(
a
:
T
.
Tensor
[(
1
,),
T
.
float32
],
):
with
T
.
Kernel
(
1
)
as
_
:
with
T
.
Kernel
(
1
)
as
_
:
a
[
0
]
=
1
a
[
0
]
=
1
...
@@ -36,5 +37,5 @@ def test_tilelang_capture():
...
@@ -36,5 +37,5 @@ def test_tilelang_capture():
# objgraph.show_backrefs([a_upgrade], max_depth=5)
# objgraph.show_backrefs([a_upgrade], max_depth=5)
if
__name__
==
'
__main__
'
:
if
__name__
==
"
__main__
"
:
tilelang
.
testing
.
main
()
tilelang
.
testing
.
main
()
testing/python/language/test_tilelang_intimm.py
View file @
29051439
...
@@ -4,25 +4,25 @@ import tilelang.language as T
...
@@ -4,25 +4,25 @@ import tilelang.language as T
def
test_tilelang_intimm
():
def
test_tilelang_intimm
():
T
.
int32
(
0x7
fffffff
)
T
.
int32
(
0x7
FFFFFFF
)
T
.
int32
(
-
0x7
fffffff
-
1
)
T
.
int32
(
-
0x7
FFFFFFF
-
1
)
T
.
uint32
(
0x
ffffffff
)
T
.
uint32
(
0x
FFFFFFFF
)
T
.
int64
(
0x7
fffffffffffffff
)
T
.
int64
(
0x7
FFFFFFFFFFFFFFF
)
T
.
int64
(
-
0x7
fffffffffffffff
-
1
)
T
.
int64
(
-
0x7
FFFFFFFFFFFFFFF
-
1
)
T
.
uint64
(
0x
ffffffffffffffff
)
T
.
uint64
(
0x
FFFFFFFFFFFFFFFF
)
a
=
T
.
int32
()
a
=
T
.
int32
()
a
&
0x7
fffffff
a
&
0x7
FFFFFFF
a
=
T
.
uint32
()
a
=
T
.
uint32
()
a
&
0x
ffffffff
a
&
0x
FFFFFFFF
a
=
T
.
int64
()
a
=
T
.
int64
()
a
&
0x7
fffffffffffffff
a
&
0x7
FFFFFFFFFFFFFFF
a
=
T
.
uint64
()
a
=
T
.
uint64
()
a
&
T
.
uint64
(
0x
ffffffffffffffff
)
a
&
T
.
uint64
(
0x
FFFFFFFFFFFFFFFF
)
if
__name__
==
'
__main__
'
:
if
__name__
==
"
__main__
"
:
tilelang
.
testing
.
main
()
tilelang
.
testing
.
main
()
testing/python/language/test_tilelang_language_alias.py
View file @
29051439
...
@@ -5,12 +5,11 @@ import tilelang.language as T
...
@@ -5,12 +5,11 @@ import tilelang.language as T
# add decorator @tilelang.jit if you want to return a torch function
# add decorator @tilelang.jit if you want to return a torch function
# @tilelang.jit
# @tilelang.jit
def
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
def
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
B
:
T
.
Tensor
((
N
,
K
),
dtype
),
B
:
T
.
Tensor
((
N
,
K
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
):
):
# Initialize Kernel Context
# Initialize Kernel Context
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
128
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
128
)
as
(
bx
,
by
):
...
...
testing/python/language/test_tilelang_language_all_of.py
View file @
29051439
...
@@ -13,11 +13,10 @@ def ref_program(A, B, BlockMask, block_M, block_N, block_K):
...
@@ -13,11 +13,10 @@ def ref_program(A, B, BlockMask, block_M, block_N, block_K):
accu
=
torch
.
zeros
((
block_M
,
block_N
),
dtype
=
torch
.
float32
,
device
=
A
.
device
)
accu
=
torch
.
zeros
((
block_M
,
block_N
),
dtype
=
torch
.
float32
,
device
=
A
.
device
)
for
k
in
range
(
K
//
block_K
):
for
k
in
range
(
K
//
block_K
):
if
torch
.
all
(
BlockMask
[
i
,
j
,
k
]):
if
torch
.
all
(
BlockMask
[
i
,
j
,
k
]):
accu
+=
A
[
i
*
block_M
:(
i
+
1
)
*
block_M
,
k
*
block_K
:(
k
+
1
)
*
block_K
].
to
(
accu
+=
A
[
i
*
block_M
:
(
i
+
1
)
*
block_M
,
k
*
block_K
:
(
k
+
1
)
*
block_K
].
to
(
torch
.
float32
)
@
B
[
torch
.
float32
)
@
B
[
k
*
block_K
:(
k
+
1
)
*
block_K
,
k
*
block_K
:
(
k
+
1
)
*
block_K
,
j
*
block_N
:
(
j
+
1
)
*
block_N
j
*
block_N
:(
j
+
1
)
*
block_N
].
to
(
torch
.
float32
)
].
to
(
torch
.
float32
)
ref_c
[
i
*
block_M
:(
i
+
1
)
*
block_M
,
j
*
block_N
:(
j
+
1
)
*
block_N
]
=
(
ref_c
[
i
*
block_M
:
(
i
+
1
)
*
block_M
,
j
*
block_N
:
(
j
+
1
)
*
block_N
]
=
accu
.
to
(
torch
.
float16
)
accu
.
to
(
torch
.
float16
))
return
ref_c
return
ref_c
...
@@ -35,15 +34,14 @@ def blocksparse_matmul_global(
...
@@ -35,15 +34,14 @@ def blocksparse_matmul_global(
dtype
=
"float16"
,
dtype
=
"float16"
,
accum_dtype
=
"float"
,
accum_dtype
=
"float"
,
):
):
block_mask_shape
=
(
M
//
block_M
,
N
//
block_N
,
K
//
block_K
,
condition_dim
)
block_mask_shape
=
(
M
//
block_M
,
N
//
block_N
,
K
//
block_K
,
condition_dim
)
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
BlockMask
:
T
.
Tensor
(
block_mask_shape
,
"bool"
),
BlockMask
:
T
.
Tensor
(
block_mask_shape
,
"bool"
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
thread_num
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
thread_num
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
dtype
)
A_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
dtype
)
...
@@ -80,15 +78,14 @@ def blocksparse_matmul_shared(
...
@@ -80,15 +78,14 @@ def blocksparse_matmul_shared(
dtype
=
"float16"
,
dtype
=
"float16"
,
accum_dtype
=
"float"
,
accum_dtype
=
"float"
,
):
):
block_mask_shape
=
(
M
//
block_M
,
N
//
block_N
,
K
//
block_K
,
condition_dim
)
block_mask_shape
=
(
M
//
block_M
,
N
//
block_N
,
K
//
block_K
,
condition_dim
)
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
BlockMask
:
T
.
Tensor
(
block_mask_shape
,
"bool"
),
BlockMask
:
T
.
Tensor
(
block_mask_shape
,
"bool"
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
thread_num
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
thread_num
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
dtype
)
A_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
dtype
)
...
@@ -130,15 +127,14 @@ def blocksparse_matmul_local(
...
@@ -130,15 +127,14 @@ def blocksparse_matmul_local(
dtype
=
"float16"
,
dtype
=
"float16"
,
accum_dtype
=
"float"
,
accum_dtype
=
"float"
,
):
):
block_mask_shape
=
(
M
//
block_M
,
N
//
block_N
,
K
//
block_K
,
condition_dim
)
block_mask_shape
=
(
M
//
block_M
,
N
//
block_N
,
K
//
block_K
,
condition_dim
)
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
BlockMask
:
T
.
Tensor
(
block_mask_shape
,
"bool"
),
BlockMask
:
T
.
Tensor
(
block_mask_shape
,
"bool"
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
thread_num
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
thread_num
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
dtype
)
A_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
dtype
)
...
@@ -237,7 +233,8 @@ def run_block_sparse_matmul_shared(M=1024, N=1024, K=1024, sparsity=0.5, conditi
...
@@ -237,7 +233,8 @@ def run_block_sparse_matmul_shared(M=1024, N=1024, K=1024, sparsity=0.5, conditi
pass_configs
=
{
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
})
},
)
# Create block mask with desired sparsity
# Create block mask with desired sparsity
mask_shape
=
(
M
//
block_M
,
N
//
block_N
,
K
//
block_K
)
mask_shape
=
(
M
//
block_M
,
N
//
block_N
,
K
//
block_K
)
block_mask
=
torch
.
rand
(
mask_shape
).
cuda
()
>
sparsity
block_mask
=
torch
.
rand
(
mask_shape
).
cuda
()
>
sparsity
...
@@ -284,7 +281,8 @@ def run_block_sparse_matmul_local(M=1024, N=1024, K=1024, sparsity=0.5, conditio
...
@@ -284,7 +281,8 @@ def run_block_sparse_matmul_local(M=1024, N=1024, K=1024, sparsity=0.5, conditio
pass_configs
=
{
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
})
},
)
# Create block mask with desired sparsity
# Create block mask with desired sparsity
mask_shape
=
(
M
//
block_M
,
N
//
block_N
,
K
//
block_K
)
mask_shape
=
(
M
//
block_M
,
N
//
block_N
,
K
//
block_K
)
block_mask
=
torch
.
rand
(
mask_shape
).
cuda
()
>
sparsity
block_mask
=
torch
.
rand
(
mask_shape
).
cuda
()
>
sparsity
...
...
testing/python/language/test_tilelang_language_alloc.py
View file @
29051439
...
@@ -10,8 +10,8 @@ def alloc_var(
...
@@ -10,8 +10,8 @@ def alloc_var(
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Tensor
((
N
,),
dtype
),
A
:
T
.
Tensor
((
N
,),
dtype
),
B
:
T
.
Tensor
((
N
,),
dtype
),
B
:
T
.
Tensor
((
N
,),
dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
threads
=
block_N
)
as
bx
:
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
threads
=
block_N
)
as
bx
:
A_shared
=
T
.
alloc_shared
([
block_N
],
dtype
)
A_shared
=
T
.
alloc_shared
([
block_N
],
dtype
)
...
@@ -50,8 +50,8 @@ def alloc_var_add(
...
@@ -50,8 +50,8 @@ def alloc_var_add(
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Tensor
((
N
,),
dtype
),
A
:
T
.
Tensor
((
N
,),
dtype
),
B
:
T
.
Tensor
((
N
,),
dtype
),
B
:
T
.
Tensor
((
N
,),
dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
threads
=
block_N
)
as
bx
:
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
threads
=
block_N
)
as
bx
:
A_shared
=
T
.
alloc_shared
([
block_N
],
dtype
)
A_shared
=
T
.
alloc_shared
([
block_N
],
dtype
)
...
@@ -91,8 +91,8 @@ def alloc_var_with_initializer(
...
@@ -91,8 +91,8 @@ def alloc_var_with_initializer(
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Tensor
((
N
,),
dtype
),
A
:
T
.
Tensor
((
N
,),
dtype
),
B
:
T
.
Tensor
((
N
,),
dtype
),
B
:
T
.
Tensor
((
N
,),
dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
threads
=
block_N
)
as
bx
:
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
threads
=
block_N
)
as
bx
:
tmp
=
T
.
alloc_var
(
dtype
,
init_value
)
tmp
=
T
.
alloc_var
(
dtype
,
init_value
)
...
@@ -129,8 +129,8 @@ def alloc_multi_vars_with_initializer(
...
@@ -129,8 +129,8 @@ def alloc_multi_vars_with_initializer(
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Tensor
((
N
,),
dtype
),
A
:
T
.
Tensor
((
N
,),
dtype
),
B
:
T
.
Tensor
((
N
,),
dtype
),
B
:
T
.
Tensor
((
N
,),
dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
threads
=
block_N
)
as
bx
:
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
threads
=
block_N
)
as
bx
:
tmp0
=
T
.
alloc_var
(
dtype
,
1
)
tmp0
=
T
.
alloc_var
(
dtype
,
1
)
...
...
Prev
1
…
9
10
11
12
13
14
15
16
17
…
22
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment