OpenDAS / tilelang · Commits

Commit 29051439 (unverified)
Authored Dec 12, 2025 by Lei Wang; committed by GitHub on Dec 12, 2025

[Lint] Phaseout Yapf format and embrace ruff format (#1417)
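The diffs on this page are mechanical formatting changes: yapf's style (single quotes, hanging indents, blank lines after signatures) is replaced by ruff's formatter defaults (double quotes, reflowed signatures, trailing commas). For context only, the sketch below shows how such a reformat can typically be reproduced with ruff's standard CLI; the helper function and the target path are illustrative assumptions, not part of this commit.

# Hypothetical helper (not part of the commit): reproduce the reformat locally.
# "ruff format" and "ruff check --fix" are ruff's standard CLI subcommands;
# the target directory is taken from the file list shown in this commit page.
import subprocess


def reformat(path: str = "testing/python/language") -> None:
    # Rewrite files in place with ruff's formatter (the yapf replacement here).
    subprocess.run(["ruff", "format", path], check=True)
    # Optionally apply auto-fixable lint rules as well.
    subprocess.run(["ruff", "check", "--fix", path], check=True)


if __name__ == "__main__":
    reformat()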
Parent: e84b24bc · Changes: 467
Showing 20 changed files with 247 additions and 340 deletions (+247, -340)
testing/python/language/test_tilelang_language_annot.py                 +19  -16
testing/python/language/test_tilelang_language_annotate_safe_value.py    +4  -10
testing/python/language/test_tilelang_language_any_of.py                +20  -22
testing/python/language/test_tilelang_language_assume.py                +11  -14
testing/python/language/test_tilelang_language_atomic_add.py            +14  -31
testing/python/language/test_tilelang_language_ceildiv.py                +0   -2
testing/python/language/test_tilelang_language_chain_equal.py            +5   -5
testing/python/language/test_tilelang_language_clamp.py                  +4   -4
testing/python/language/test_tilelang_language_clear.py                  +5   -6
testing/python/language/test_tilelang_language_composable_index.py       +4   -4
testing/python/language/test_tilelang_language_copy.py                  +18  -36
testing/python/language/test_tilelang_language_cumsum.py                +17  -14
testing/python/language/test_tilelang_language_frontend_v2.py           +39  -48
testing/python/language/test_tilelang_language_get_warp_info.py          +0   -5
testing/python/language/test_tilelang_language_if_range.py               +5   -4
testing/python/language/test_tilelang_language_infinity.py               +1   -2
testing/python/language/test_tilelang_language_intrinsics_codegen.py     +2   -2
testing/python/language/test_tilelang_language_lazy_jit.py              +62  -69
testing/python/language/test_tilelang_language_let.py                    +0   -1
testing/python/language/test_tilelang_language_mask_op.py               +17  -45
testing/python/language/test_tilelang_language_annot.py  View file @ 29051439

@@ -5,13 +5,14 @@ import torch
 def test_tensor_annot_mul():
     @tilelang.jit
     def example_tensor_annot():
-        n = T.symbolic('n')
+        n = T.symbolic("n")

         @T.prim_func
-        def kernel(A: T.Tensor((n * 4,), T.int32),):
+        def kernel(
+            A: T.Tensor((n * 4,), T.int32),
+        ):
             with T.Kernel(1) as _:
                 for i in range(n * 4):
                     A[i] = 0

@@ -19,20 +20,21 @@ def test_tensor_annot_mul():
         return kernel

     ker = example_tensor_annot()
-    A = torch.arange(16, dtype=torch.int32, device='cuda')
+    A = torch.arange(16, dtype=torch.int32, device="cuda")
     ker(A)
-    expected = torch.zeros(16, dtype=torch.int32, device='cuda')
+    expected = torch.zeros(16, dtype=torch.int32, device="cuda")
     assert torch.equal(A, expected)


 def test_tensor_annot_add():
     @tilelang.jit
     def example_tensor_annot():
-        n = T.symbolic('n')
+        n = T.symbolic("n")

         @T.prim_func
-        def kernel(A: T.Tensor((n + 1,), T.int32),):
+        def kernel(
+            A: T.Tensor((n + 1,), T.int32),
+        ):
             with T.Kernel(1) as _:
                 for i in range(n + 1):
                     A[i] = 0

@@ -40,20 +42,21 @@ def test_tensor_annot_add():
         return kernel

     ker = example_tensor_annot()
-    A = torch.arange(16, dtype=torch.int32, device='cuda')
+    A = torch.arange(16, dtype=torch.int32, device="cuda")
     ker(A)
-    expected = torch.zeros(16, dtype=torch.int32, device='cuda')
+    expected = torch.zeros(16, dtype=torch.int32, device="cuda")
     assert torch.equal(A, expected)


 def test_tensor_annot_mul_add():
     @tilelang.jit
     def example_tensor_annot():
-        n = T.symbolic('n')
+        n = T.symbolic("n")

         @T.prim_func
-        def kernel(A: T.Tensor((n * 3 + 1,), T.int32),):
+        def kernel(
+            A: T.Tensor((n * 3 + 1,), T.int32),
+        ):
             with T.Kernel(1) as _:
                 for i in range(n * 3 + 1):
                     A[i] = 0

@@ -61,11 +64,11 @@ def test_tensor_annot_mul_add():
         return kernel

     ker = example_tensor_annot()
-    A = torch.arange(16, dtype=torch.int32, device='cuda')
+    A = torch.arange(16, dtype=torch.int32, device="cuda")
     ker(A)
-    expected = torch.zeros(16, dtype=torch.int32, device='cuda')
+    expected = torch.zeros(16, dtype=torch.int32, device="cuda")
     assert torch.equal(A, expected)


-if __name__ == '__main__':
+if __name__ == "__main__":
     tilelang.testing.main()
testing/python/language/test_tilelang_language_annotate_safe_value.py  View file @ 29051439

@@ -7,7 +7,6 @@ import torch
 # add decorator @tilelang.jit if you want to return a torch function
 # @tilelang.jit
 def tilelang_copy(M, N, block_M, block_N, dtype="float16", pad_value=0):
-
     @T.prim_func
     def main(
         A: T.Tensor((M, N), dtype),

@@ -30,13 +29,8 @@ def tilelang_copy(M, N, block_M, block_N, dtype="float16", pad_value=0):
 def run_tilelang_copy(M=1024, N=1024, block_M=128, block_N=128, dtype="float16", pad_value=0):
     program = tilelang_copy(M, N, block_M, block_N, dtype, pad_value=pad_value)
     kernel = tilelang.compile(
-        program,
-        out_idx=[1],
-        target="cuda",
-        pass_configs={
-            "tl.disable_warp_specialized": True,
-            "tl.disable_tma_lower": True
-        })
+        program, out_idx=[1], target="cuda", pass_configs={"tl.disable_warp_specialized": True, "tl.disable_tma_lower": True}
+    )
     a = torch.randn(M, N, device="cuda", dtype=getattr(torch, dtype))
     b = kernel(a)
     ref_b = torch.zeros_like(a)
testing/python/language/test_tilelang_language_any_of.py  View file @ 29051439

@@ -13,11 +13,10 @@ def ref_program(A, B, BlockMask, block_M, block_N, block_K):
             accu = torch.zeros((block_M, block_N), dtype=torch.float32, device=A.device)
             for k in range(K // block_K):
                 if torch.any(BlockMask[i, j, k]):
-                    accu += A[i * block_M:(i + 1) * block_M, k * block_K:(k + 1) * block_K].to(
-                        torch.float32) @ B[k * block_K:(k + 1) * block_K,
-                                           j * block_N:(j + 1) * block_N].to(torch.float32)
-            ref_c[i * block_M:(i + 1) * block_M, j * block_N:(j + 1) * block_N] = (accu.to(torch.float16))
+                    accu += A[i * block_M : (i + 1) * block_M, k * block_K : (k + 1) * block_K].to(
+                        torch.float32
+                    ) @ B[k * block_K : (k + 1) * block_K, j * block_N : (j + 1) * block_N].to(torch.float32)
+            ref_c[i * block_M : (i + 1) * block_M, j * block_N : (j + 1) * block_N] = accu.to(torch.float16)
     return ref_c

@@ -35,7 +34,6 @@ def blocksparse_matmul_global(
     dtype="float16",
     accum_dtype="float",
 ):
-
     block_mask_shape = (M // block_M, N // block_N, K // block_K, condition_dim)

     @T.prim_func

@@ -80,7 +78,6 @@ def blocksparse_matmul_shared(
     dtype="float16",
     accum_dtype="float",
 ):
-
     block_mask_shape = (M // block_M, N // block_N, K // block_K, condition_dim)

     @T.prim_func

@@ -130,7 +127,6 @@ def blocksparse_matmul_local(
     dtype="float16",
     accum_dtype="float",
 ):
-
     block_mask_shape = (M // block_M, N // block_N, K // block_K, condition_dim)

     @T.prim_func

@@ -237,7 +233,8 @@ def run_block_sparse_matmul_shared(M=1024, N=1024, K=1024, sparsity=0.5, conditi
         pass_configs={
             tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
             tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
-        })
+        },
+    )

     # Create block mask with desired sparsity
     mask_shape = (M // block_M, N // block_N, K // block_K)
     block_mask = torch.rand(mask_shape).cuda() > sparsity

@@ -284,7 +281,8 @@ def run_block_sparse_matmul_local(M=1024, N=1024, K=1024, sparsity=0.5, conditio
         pass_configs={
             tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
             tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
-        })
+        },
+    )

     # Create block mask with desired sparsity
     mask_shape = (M // block_M, N // block_N, K // block_K)
     block_mask = torch.rand(mask_shape).cuda() > sparsity
testing/python/language/test_tilelang_language_assume.py  View file @ 29051439

@@ -4,10 +4,9 @@ import tilelang.testing
 def test_assume_remove_boundary_check():
     @tilelang.jit
     def kernel_with_assume():
-        N = T.dynamic('N')
+        N = T.dynamic("N")

         @T.prim_func
         def main(A: T.Tensor((N,), "float32"), l: T.int32, r: T.int32):

@@ -21,14 +20,13 @@ def test_assume_remove_boundary_check():
     jit_kernel = kernel_with_assume()
     source = jit_kernel.get_kernel_source()
-    assert ("if (" not in source)
+    assert "if (" not in source


 def test_assume_enable_vectorization():
     @tilelang.jit
     def kernel_vectorize(M):
-        N = T.dynamic('N')
+        N = T.dynamic("N")
         vectorize_size = 4

         @T.prim_func

@@ -55,11 +53,10 @@ def test_assume_enable_vectorization():
 def test_assume_complex_indexing():
     @tilelang.jit
     def kernel_complex():
-        M = T.dynamic('M')
-        N = T.dynamic('N')
+        M = T.dynamic("M")
+        N = T.dynamic("N")

         @T.prim_func
         def main(

@@ -82,8 +79,8 @@ def test_assume_complex_indexing():
     jit_kernel = kernel_complex()
     source = jit_kernel.get_kernel_source()
-    assert ("if (" not in source)
+    assert "if (" not in source


-if __name__ == '__main__':
+if __name__ == "__main__":
     tilelang.testing.main()
testing/python/language/test_tilelang_language_atomic_add.py  View file @ 29051439

@@ -4,14 +4,12 @@ import tilelang.language as T
 @tilelang.jit
 def atomic_add_program(K, M, N, block_M, block_N, dtype="float"):
-
     @T.prim_func
     def atomic_add(A: T.Tensor((K, M, N), dtype), B: T.Tensor((M, N), dtype)):
         with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), K, threads=32) as (bx, by, bz):
             A_shared = T.alloc_shared((block_M, block_N), dtype)
-            T.copy(A[bz, bx * block_M:(bx + 1) * block_M, by * block_N:(by + 1) * block_N],
-                   A_shared)
+            T.copy(A[bz, bx * block_M : (bx + 1) * block_M, by * block_N : (by + 1) * block_N], A_shared)
             for i, j in T.Parallel(block_M, block_N):
                 T.atomic_add(B[bx * block_M + i, by * block_N + j], A_shared[i, j])

@@ -39,14 +37,12 @@ def run_atomic_add(K, M, N, block_M, block_N, dtype="float32"):
 @tilelang.jit
 def tile_atomic_add_program(K, M, N, block_M, block_N, dtype="float"):
-
     @T.prim_func
     def atomic_add(A: T.Tensor((K, M, N), dtype), B: T.Tensor((M, N), dtype)):
         with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), K, threads=32) as (bx, by, bz):
             A_shared = T.alloc_shared((block_M, block_N), dtype)
-            T.copy(A[bz, bx * block_M:(bx + 1) * block_M, by * block_N:(by + 1) * block_N],
-                   A_shared)
+            T.copy(A[bz, bx * block_M : (bx + 1) * block_M, by * block_N : (by + 1) * block_N], A_shared)
             T.atomic_add(B[bx * block_M, by * block_N], A_shared)

@@ -76,14 +72,12 @@ def run_tile_atomic_add(K, M, N, block_M, block_N, dtype="float32"):
 @tilelang.jit
 def atomic_max_program(K, M, N, block_M, block_N, dtype="float"):
-
     @T.prim_func
     def atomic_max(A: T.Tensor((K, M, N), dtype), B: T.Tensor((M, N), dtype)):
         with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), K, threads=32) as (bx, by, bz):
             A_shared = T.alloc_shared((block_M, block_N), dtype)
-            T.copy(A[bz, bx * block_M:(bx + 1) * block_M, by * block_N:(by + 1) * block_N],
-                   A_shared)
+            T.copy(A[bz, bx * block_M : (bx + 1) * block_M, by * block_N : (by + 1) * block_N], A_shared)
             for i, j in T.Parallel(block_M, block_N):
                 T.atomic_max(B[bx * block_M + i, by * block_N + j], A_shared[i, j])

@@ -111,14 +105,12 @@ def run_atomic_max(K, M, N, block_M, block_N, dtype="float32"):
 @tilelang.jit
 def atomic_min_program(K, M, N, block_M, block_N, dtype="float"):
-
     @T.prim_func
     def atomic_min(A: T.Tensor((K, M, N), dtype), B: T.Tensor((M, N), dtype)):
         with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), K, threads=32) as (bx, by, bz):
             A_shared = T.alloc_shared((block_M, block_N), dtype)
-            T.copy(A[bz, bx * block_M:(bx + 1) * block_M, by * block_N:(by + 1) * block_N],
-                   A_shared)
+            T.copy(A[bz, bx * block_M : (bx + 1) * block_M, by * block_N : (by + 1) * block_N], A_shared)
             for i, j in T.Parallel(block_M, block_N):
                 T.atomic_min(B[bx * block_M + i, by * block_N + j], A_shared[i, j])

@@ -137,7 +129,7 @@ def run_atomic_min(K, M, N, block_M, block_N, dtype="float32"):
                 B[i, j] = min(B[i, j], A[k, i, j])

     A = torch.randn(K, M, N, dtype=getattr(torch, dtype)).cuda()
-    B = torch.full((M, N), float('inf'), dtype=getattr(torch, dtype)).cuda()
+    B = torch.full((M, N), float("inf"), dtype=getattr(torch, dtype)).cuda()
     ref_B = B.clone()
     ref_program(A, ref_B)
     kernel(A, B)

@@ -146,7 +138,6 @@ def run_atomic_min(K, M, N, block_M, block_N, dtype="float32"):
 @tilelang.jit
 def atomic_load_store_program(M, N, block_M, block_N, dtype="float"):
-
     @T.prim_func
     def atomic_load_store(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype)):
         with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=32) as (bx, by):

@@ -172,18 +163,15 @@ def run_atomic_load_store(M, N, block_M, block_N, dtype="float32"):
 @tilelang.jit
 def atomic_memory_order_program(K, M, N, block_M, block_N, dtype="float"):
-
     @T.prim_func
     def atomic_with_memory_order(A: T.Tensor((K, M, N), dtype), B: T.Tensor((M, N), dtype)):
         with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), K, threads=32) as (bx, by, bz):
             A_shared = T.alloc_shared((block_M, block_N), dtype)
-            T.copy(A[bz, bx * block_M:(bx + 1) * block_M, by * block_N:(by + 1) * block_N],
-                   A_shared)
+            T.copy(A[bz, bx * block_M : (bx + 1) * block_M, by * block_N : (by + 1) * block_N], A_shared)
             for i, j in T.Parallel(block_M, block_N):
-                T.atomic_add(
-                    B[bx * block_M + i, by * block_N + j], A_shared[i, j], memory_order="relaxed")
+                T.atomic_add(B[bx * block_M + i, by * block_N + j], A_shared[i, j], memory_order="relaxed")

     return atomic_with_memory_order

@@ -208,7 +196,6 @@ def run_atomic_memory_order(K, M, N, block_M, block_N, dtype="float32"):
 @tilelang.jit
 def atomic_addx2_program(M, N, block_M, block_N):
-
     @T.prim_func
     def atomic_addx2(A: T.Tensor((M, N), "float16"), B: T.Tensor((M, N), "float16")):
         with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=32) as (bx, by):

@@ -262,10 +249,10 @@ def test_atomic_addx2():
 @tilelang.jit
 def atomic_different_memory_orders_program(M, N, block_M, block_N, dtype="float"):
-
     @T.prim_func
-    def atomic_different_orders(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype),
-                                C: T.Tensor((M, N), dtype), D: T.Tensor((M, N), dtype)):
+    def atomic_different_orders(
+        A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype), C: T.Tensor((M, N), dtype), D: T.Tensor((M, N), dtype)
+    ):
         with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=32) as (bx, by):
             for i, j in T.Parallel(block_M, block_N):
                 idx_i = bx * block_M + i

@@ -286,18 +273,17 @@ def run_atomic_different_memory_orders(M, N, block_M, block_N, dtype="float32"):
     A = torch.randn(M, N, dtype=getattr(torch, dtype)).cuda()
     B = torch.zeros(M, N, dtype=getattr(torch, dtype)).cuda()
     C = torch.zeros(M, N, dtype=getattr(torch, dtype)).cuda()
-    D = torch.full((M, N), float('inf'), dtype=getattr(torch, dtype)).cuda()
+    D = torch.full((M, N), float("inf"), dtype=getattr(torch, dtype)).cuda()

     kernel(A, B, C, D)

     torch.testing.assert_close(B, A, atol=1e-3, rtol=1e-3)
     torch.testing.assert_close(C, torch.maximum(torch.zeros_like(A), A))
-    torch.testing.assert_close(D, torch.minimum(torch.full_like(A, float('inf')), A))
+    torch.testing.assert_close(D, torch.minimum(torch.full_like(A, float("inf")), A))


 @tilelang.jit
 def atomic_addx4_program(M, N, block_M, block_N):
     @T.prim_func
     def atomic_addx4(A: T.Tensor((M, N), "float32"), B: T.Tensor((M, N), "float32")):
         with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=32) as (bx, by):

@@ -330,17 +316,14 @@ def run_atomic_addx4(M, N, block_M, block_N):
 @tilelang.jit
 def atomic_return_prev_program(M, N, block_M, block_N, dtype="float"):
-
     @T.prim_func
-    def atomic_with_return_prev(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype),
-                                old_vals: T.Tensor((M, N), dtype)):
+    def atomic_with_return_prev(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype), old_vals: T.Tensor((M, N), dtype)):
         with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=32) as (bx, by):
             for i, j in T.Parallel(block_M, block_N):
                 idx_i = bx * block_M + i
                 idx_j = by * block_N + j
                 if idx_i < M and idx_j < N:
-                    old_vals[idx_i, idx_j] = T.atomic_add(
-                        B[idx_i, idx_j], A[idx_i, idx_j], return_prev=True)
+                    old_vals[idx_i, idx_j] = T.atomic_add(B[idx_i, idx_j], A[idx_i, idx_j], return_prev=True)

     return atomic_with_return_prev
testing/python/language/test_tilelang_language_ceildiv.py  View file @ 29051439

@@ -5,7 +5,6 @@ import torch
 @tilelang.jit(out_idx=[-1])
 def _ceildiv_kernel(a: int, b: int):
-
     @T.prim_func
     def ceildiv_kernel(A: T.Tensor((1,), "int32")):
         with T.Kernel(1, threads=1) as _:

@@ -30,7 +29,6 @@ def test_ceildiv():
 @tilelang.jit
 def _ceildiv_kernel_dyn(b: int):
-
     @T.prim_func
     def ceildiv_kernel(A: T.Tensor((1,), "int32"), a: T.int32):
         with T.Kernel(1, threads=1) as _:
testing/python/language/test_tilelang_language_chain_equal.py  View file @ 29051439

@@ -8,9 +8,9 @@ import torch
     pass_configs={
         tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
         tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
-    },)
+    },
+)
 def chain_equal(N, block_size, dtype="float32"):
-
     @T.prim_func
     def main(
         A: T.Tensor((N,), dtype),
testing/python/language/test_tilelang_language_clamp.py  View file @ 29051439
testing/python/language/test_tilelang_language_clear.py  View file @ 29051439

@@ -5,7 +5,6 @@ import tilelang.language as T
 # add decorator @tilelang.jit if you want to return a torch function
 # @tilelang.jit
 def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
-
     @T.prim_func
     def main(
         A: T.Tensor((M, K), dtype),

@@ -42,10 +41,10 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
 def run_matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
     program = matmul(M, N, K, block_M, block_N, block_K, dtype, accum_dtype)
-    kernel = tilelang.compile(
-        program, out_idx=[2], target="cuda", pass_configs={"tl.disable_tma_lower": True})
+    kernel = tilelang.compile(program, out_idx=[2], target="cuda", pass_configs={"tl.disable_tma_lower": True})
     import torch
     from tilelang.utils import map_torch_type

     a = torch.randn((M, K), dtype=map_torch_type(dtype)).cuda()
     b = torch.randn((N, K), dtype=map_torch_type(dtype)).cuda()
     c = kernel(a, b)
testing/python/language/test_tilelang_language_composable_index.py  View file @ 29051439

@@ -7,7 +7,6 @@ import torch
 # add decorator @tilelang.jit if you want to return a torch function
 # @tilelang.jit
 def tilelang_composable_copy(M, N, block_M, block_N, dtype="float16"):
-
     @T.prim_func
     def main(
         A: T.Tensor((M, N), dtype),

@@ -35,7 +34,8 @@ def run_tilelang_composable_copy(M=1024, N=1024, block_M=128, block_N=128, dtype
         pass_configs={
             tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
             tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
-        })
+        },
+    )
     a = torch.randn(M, N, device="cuda", dtype=getattr(torch, dtype))
     b = kernel(a)
     torch.testing.assert_close(b.flatten(), a.flatten(), rtol=1e-2, atol=1e-2)
testing/python/language/test_tilelang_language_copy.py  View file @ 29051439

@@ -7,7 +7,6 @@ import tilelang.testing
 # add decorator @tilelang.jit if you want to return a torch function
 # @tilelang.jit
 def tilelang_copy(M, N, block_M, block_N, dtype="float16"):
-
     @T.prim_func
     def main(
         A: T.Tensor((M, N), dtype),

@@ -27,10 +26,8 @@ def run_tilelang_copy(M=1024, N=1024, block_M=128, block_N=128, dtype="float16")
         program,
         out_idx=[1],
         target="cuda",
-        pass_configs={
-            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
-            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True
-        })
+        pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True},
+    )
     a = torch.randn(M, N, device="cuda", dtype=getattr(torch, dtype))
     b = kernel(a)
     torch.testing.assert_close(b, a, rtol=1e-2, atol=1e-2)

@@ -43,7 +40,6 @@ def test_tilelang_copy():
 def tilelang_copy_with_stride(M, N, NN, block_M, block_N, dtype="float16"):
-
     @T.prim_func
     def main(
         A: T.StridedTensor((M, N), (NN, 1), dtype),

@@ -57,12 +53,7 @@ def tilelang_copy_with_stride(M, N, NN, block_M, block_N, dtype="float16"):
     return main


-def run_tilelang_copy_with_stride(M=1024,
-                                  N=1024,
-                                  NN=2048,
-                                  block_M=128,
-                                  block_N=128,
-                                  dtype="float16"):
+def run_tilelang_copy_with_stride(M=1024, N=1024, NN=2048, block_M=128, block_N=128, dtype="float16"):
     if isinstance(NN, int):
         assert NN > N, "NN must be greater than N"
     program = tilelang_copy_with_stride(M, N, NN, block_M, block_N, dtype)

@@ -73,7 +64,8 @@ def run_tilelang_copy_with_stride(M=1024,
         pass_configs={
             tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
             tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
-        })
+        },
+    )
     if isinstance(NN, T.Var):
         NN = N * 2
     a = torch.randn(M, NN, device="cuda", dtype=getattr(torch, dtype))

@@ -87,7 +79,6 @@ def test_tilelang_copy_with_stride():
 def tilelang_copy_bufferload(num_tokens, dtype="float16"):
-
     @T.prim_func
     def main(
         indices: T.Tensor((num_tokens,), "int32"),

@@ -107,10 +98,8 @@ def run_tilelang_copy_bufferload(num_tokens=128, dtype="float16"):
     tilelang.compile(
         program,
         out_idx=[1],
-        pass_configs={
-            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
-            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True
-        })
+        pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True},
+    )


 def test_tilelang_copy_bufferload():

@@ -118,7 +107,6 @@ def test_tilelang_copy_bufferload():
 def tilelang_copy_buffer_load_with_parallel(M, N, block_M, block_N, dtype="float16"):
-
     @T.prim_func
     def main(
         A: T.Tensor((M, N), dtype),

@@ -132,20 +120,14 @@ def tilelang_copy_buffer_load_with_parallel(M, N, block_M, block_N, dtype="float
     return main


-def run_tilelang_copy_buffer_load_with_parallel(M=1024,
-                                                N=1024,
-                                                block_M=128,
-                                                block_N=128,
-                                                dtype="float16"):
+def run_tilelang_copy_buffer_load_with_parallel(M=1024, N=1024, block_M=128, block_N=128, dtype="float16"):
     program = tilelang_copy_buffer_load_with_parallel(M, N, block_M, block_N, dtype)
     kernel = tilelang.compile(
         program,
         out_idx=[1],
         target="cuda",
-        pass_configs={
-            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
-            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True
-        })
+        pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True},
+    )
     a = torch.randn(M, N, device="cuda", dtype=getattr(torch, dtype))
     b = kernel(a)
     torch.testing.assert_close(b, a, rtol=1e-2, atol=1e-2)
testing/python/language/test_tilelang_language_cumsum.py  View file @ 29051439

@@ -57,13 +57,16 @@ def run_cumsum(M, N, block_M, block_N, dim=0, reverse=False, dtype="float32", sc
         ref_b = torch.empty_like(A)
         for i in range(M // block_M):
             for j in range(N // block_N):
-                ref_b[i * block_M:(i + 1) * block_M, j * block_N:(j + 1) * block_N] = A[
-                    i * block_M:(i + 1) * block_M, j * block_N:(j + 1) * block_N].cumsum(dim=dim)
+                ref_b[i * block_M : (i + 1) * block_M, j * block_N : (j + 1) * block_N] = A[
+                    i * block_M : (i + 1) * block_M, j * block_N : (j + 1) * block_N
+                ].cumsum(dim=dim)
                 if reverse:
-                    ref_b[i * block_M:(i + 1) * block_M, j * block_N:(j + 1) * block_N] = A[
-                        i * block_M:(i + 1) * block_M, j * block_N:(j + 1) * block_N].flip(
-                            dims=[dim]).cumsum(dim=dim).flip(dims=[dim])
+                    ref_b[i * block_M : (i + 1) * block_M, j * block_N : (j + 1) * block_N] = (
+                        A[i * block_M : (i + 1) * block_M, j * block_N : (j + 1) * block_N]
+                        .flip(dims=[dim])
+                        .cumsum(dim=dim)
+                        .flip(dims=[dim])
+                    )
         return ref_b

     tilelang_res = jit_kernel(A)
testing/python/language/test_tilelang_language_frontend_v2.py  View file @ 29051439

@@ -8,7 +8,6 @@ from tvm.tir.expr import IntImm, Var
 def test_argument():
-
     @T.prim_func
     def test_argument(
         t_1: T.bool,

@@ -41,6 +40,7 @@ def test_argument():
 def test_expr():
     from tilelang.language.v2.dtypes import _all_dtypes
+
     errors = []
     for name in _all_dtypes:
         dtype = getattr(T, name)

@@ -116,33 +116,32 @@ def test_expr():
 def test_dtype_str_repr():
-
     @T.prim_func
     def test_str_repr():
-        buf_1 = T.alloc_buffer((1,), dtype=T.bool, scope='shared')  # noqa F841
-        buf_2 = T.alloc_buffer((1,), dtype=T.short, scope='shared')  # noqa F841
-        buf_3 = T.alloc_buffer((1,), dtype=T.int, scope='shared')  # noqa F841
-        buf_4 = T.alloc_buffer((1,), dtype=T.long, scope='shared')  # noqa F841
-        buf_5 = T.alloc_buffer((1,), dtype=T.half, scope='shared')  # noqa F841
-        buf_6 = T.alloc_buffer((1,), dtype=T.float, scope='shared')  # noqa F841
-        buf_7 = T.alloc_buffer((1,), dtype=T.long, scope='shared')  # noqa F841
-        buf_8 = T.alloc_buffer((1,), dtype=T.int8, scope='shared')  # noqa F841
-        buf_9 = T.alloc_buffer((1,), dtype=T.int16, scope='shared')  # noqa F841
-        buf_10 = T.alloc_buffer((1,), dtype=T.int32, scope='shared')  # noqa F841
-        buf_11 = T.alloc_buffer((1,), dtype=T.int64, scope='shared')  # noqa F841
-        buf_12 = T.alloc_buffer((1,), dtype=T.uint8, scope='shared')  # noqa F841
-        buf_13 = T.alloc_buffer((1,), dtype=T.uint16, scope='shared')  # noqa F841
-        buf_14 = T.alloc_buffer((1,), dtype=T.uint32, scope='shared')  # noqa F841
-        buf_15 = T.alloc_buffer((1,), dtype=T.uint64, scope='shared')  # noqa F841
-        buf_16 = T.alloc_buffer((1,), dtype=T.float8_e4m3fn, scope='shared')  # noqa F841
-        buf_17 = T.alloc_buffer((1,), dtype=T.float8_e4m3fnuz, scope='shared')  # noqa F841
-        buf_18 = T.alloc_buffer((1,), dtype=T.float8_e5m2, scope='shared')  # noqa F841
-        buf_19 = T.alloc_buffer((1,), dtype=T.float8_e5m2fnuz, scope='shared')  # noqa F841
-        buf_20 = T.alloc_buffer((1,), dtype=T.float8_e8m0fnu, scope='shared')  # noqa F841
-        buf_21 = T.alloc_buffer((1,), dtype=T.float16, scope='shared')  # noqa F841
-        buf_22 = T.alloc_buffer((1,), dtype=T.bfloat16, scope='shared')  # noqa F841
-        buf_23 = T.alloc_buffer((1,), dtype=T.float32, scope='shared')  # noqa F841
-        buf_24 = T.alloc_buffer((1,), dtype=T.float64, scope='shared')  # noqa F841
+        buf_1 = T.alloc_buffer((1,), dtype=T.bool, scope="shared")  # noqa F841
+        buf_2 = T.alloc_buffer((1,), dtype=T.short, scope="shared")  # noqa F841
+        buf_3 = T.alloc_buffer((1,), dtype=T.int, scope="shared")  # noqa F841
+        buf_4 = T.alloc_buffer((1,), dtype=T.long, scope="shared")  # noqa F841
+        buf_5 = T.alloc_buffer((1,), dtype=T.half, scope="shared")  # noqa F841
+        buf_6 = T.alloc_buffer((1,), dtype=T.float, scope="shared")  # noqa F841
+        buf_7 = T.alloc_buffer((1,), dtype=T.long, scope="shared")  # noqa F841
+        buf_8 = T.alloc_buffer((1,), dtype=T.int8, scope="shared")  # noqa F841
+        buf_9 = T.alloc_buffer((1,), dtype=T.int16, scope="shared")  # noqa F841
+        buf_10 = T.alloc_buffer((1,), dtype=T.int32, scope="shared")  # noqa F841
+        buf_11 = T.alloc_buffer((1,), dtype=T.int64, scope="shared")  # noqa F841
+        buf_12 = T.alloc_buffer((1,), dtype=T.uint8, scope="shared")  # noqa F841
+        buf_13 = T.alloc_buffer((1,), dtype=T.uint16, scope="shared")  # noqa F841
+        buf_14 = T.alloc_buffer((1,), dtype=T.uint32, scope="shared")  # noqa F841
+        buf_15 = T.alloc_buffer((1,), dtype=T.uint64, scope="shared")  # noqa F841
+        buf_16 = T.alloc_buffer((1,), dtype=T.float8_e4m3fn, scope="shared")  # noqa F841
+        buf_17 = T.alloc_buffer((1,), dtype=T.float8_e4m3fnuz, scope="shared")  # noqa F841
+        buf_18 = T.alloc_buffer((1,), dtype=T.float8_e5m2, scope="shared")  # noqa F841
+        buf_19 = T.alloc_buffer((1,), dtype=T.float8_e5m2fnuz, scope="shared")  # noqa F841
+        buf_20 = T.alloc_buffer((1,), dtype=T.float8_e8m0fnu, scope="shared")  # noqa F841
+        buf_21 = T.alloc_buffer((1,), dtype=T.float16, scope="shared")  # noqa F841
+        buf_22 = T.alloc_buffer((1,), dtype=T.bfloat16, scope="shared")  # noqa F841
+        buf_23 = T.alloc_buffer((1,), dtype=T.float32, scope="shared")  # noqa F841
+        buf_24 = T.alloc_buffer((1,), dtype=T.float64, scope="shared")  # noqa F841
         # not supported now

@@ -205,7 +204,6 @@ def test_dtype_str_repr():
 def test_var_assign():
-
     @tilelang.jit(out_idx=-1)
     @T.prim_func
     def test_var_assign(A: T.Tensor((2,), T.int32)):

@@ -223,7 +221,6 @@ def test_var_assign():
 def test_marco_return():
-
     @T.macro
     def macro_return_constant():
         return 0

@@ -258,7 +255,6 @@ def test_marco_return():
 def test_prim_func_generator():
-
     @T.prim_func(generator=True)
     def prim_func_gen(
         A=T.Tensor((128,), T.float32),  # noqa: B008

@@ -277,7 +273,6 @@ def test_prim_func_generator():
 def test_serial_for_with_step():
-
     @tilelang.jit(out_idx=-1)
     @T.prim_func
     def test_stepped_serial(A: T.Tensor((10,), T.int32)):

@@ -291,7 +286,7 @@ def test_serial_for_with_step():
     ker = test_stepped_serial()
     res = ker()
-    ref = torch.tensor([1, 2, 1, 2, 1, 2, 1, 2, 1, 2], dtype=torch.int32, device='cuda')
+    ref = torch.tensor([1, 2, 1, 2, 1, 2, 1, 2, 1, 2], dtype=torch.int32, device="cuda")
     assert torch.all(res == ref), f"Expected {ref}, but got {res}"

     @tilelang.jit(out_idx=-1)

@@ -304,17 +299,16 @@ def test_serial_for_with_step():
     ker = test_serial_step_neg()
     res = ker()
-    ref = torch.tensor([10, 9, 8, 7, 6, 5, 4, 3, 2, 1], dtype=torch.int32, device='cuda')
+    ref = torch.tensor([10, 9, 8, 7, 6, 5, 4, 3, 2, 1], dtype=torch.int32, device="cuda")
     assert torch.all(res == ref), f"Expected {ref}, but got {res}"

     assert isinstance(T.serial(1, 10, 1), IRBuilderFrame)
-    assert isinstance(T.serial(1, 10, IntImm('int32', 1)), IRBuilderFrame)
-    assert not isinstance(T.serial(1, 10, Var('tmp', 'int32')), IRBuilderFrame)
+    assert isinstance(T.serial(1, 10, IntImm("int32", 1)), IRBuilderFrame)
+    assert not isinstance(T.serial(1, 10, Var("tmp", "int32")), IRBuilderFrame)
     assert not isinstance(T.serial(10, -1, -1), IRBuilderFrame)


 def test_swap_logic():
     @tilelang.jit
     @T.prim_func
     def swap_var(A: T.Tensor[(2,), T.float32]):

@@ -344,7 +338,6 @@ def test_swap_logic():
 def test_while_loop():
-
     @tilelang.jit(out_idx=-1)
     @T.prim_func
     def test_while_loop(A: T.Tensor((1,), T.int32)):

@@ -374,7 +367,7 @@ def test_var_macro():
             x = T.alloc_var(T.int32)
             macro_with_var(x)

-        assert 'x[0] = 1' in prim_call_macro.script()
+        assert "x[0] = 1" in prim_call_macro.script()
     finally:
         pass

@@ -406,7 +399,7 @@ def test_var_macro():
             x = T.alloc_var(T.int32)
             macro_with_var(x)

-        assert 'x[0] = 1' in prim_call_macro.script()
+        assert "x[0] = 1" in prim_call_macro.script()
     finally:
         pass

@@ -428,10 +421,8 @@ def test_var_macro():
 def test_frame_inside_macro():
-
     @tilelang.jit
     def get_sample_kernel():
-
         @T.macro
         def transform(x):
             return x + 1

@@ -442,7 +433,7 @@ def test_frame_inside_macro():
             idx_out: T.Tensor[(32,), T.int32],
         ):
             with T.Kernel(num_blocks, threads=32) as block_idx:  # noqa: F841
-                fragment = T.alloc_fragment(32, 'int32')
+                fragment = T.alloc_fragment(32, "int32")
                 T.copy(idx_out, fragment)
                 for i in T.Parallel(32):

@@ -467,10 +458,10 @@ def test_buffer_slice_step():
 def test_boolop():
-    a = Var('a', 'int32')
-    b = Var('b', 'int32')
-    c = Var('c', 'int32')
-    d = Var('d', 'int32')
+    a = Var("a", "int32")
+    b = Var("b", "int32")
+    c = Var("c", "int32")
+    d = Var("d", "int32")

     @T.macro
     def cond():

@@ -479,5 +470,5 @@ def test_boolop():
     cond()


-if __name__ == '__main__':
+if __name__ == "__main__":
     tilelang.testing.main()
testing/python/language/test_tilelang_language_get_warp_info.py  View file @ 29051439

@@ -23,7 +23,6 @@ def _resolve_warps_per_group(warps_per_group: Optional[int]) -> int:
 @tilelang.jit(out_idx=[-1])
 def _get_laneid_kernel(num_threads: int = 128, warp_size: Optional[int] = None):
-
     @T.prim_func
     def laneid_kernel(A: T.Tensor((num_threads,), "int32")):
         with T.Kernel(1, threads=num_threads) as _:

@@ -35,7 +34,6 @@ def _get_laneid_kernel(num_threads: int = 128, warp_size: Optional[int] = None):
 @tilelang.jit(out_idx=[-1])
 def _get_warp_idx_sync_kernel(num_threads: int = 128, warp_size: Optional[int] = None):
-
     @T.prim_func
     def warp_idx_sync_kernel(A: T.Tensor((num_threads,), "int32")):
         with T.Kernel(1, threads=num_threads) as _:

@@ -47,7 +45,6 @@ def _get_warp_idx_sync_kernel(num_threads: int = 128, warp_size: Optional[int] =
 @tilelang.jit(out_idx=[-1])
 def _get_warp_idx_kernel(num_threads: int = 128, warp_size: Optional[int] = None):
-
     @T.prim_func
     def warp_idx_kernel(A: T.Tensor((num_threads,), "int32")):
         with T.Kernel(1, threads=num_threads) as _:

@@ -63,7 +60,6 @@ def _get_warp_group_idx_kernel(
     warp_size: Optional[int] = None,
     warps_per_group: Optional[int] = None,
 ):
-
     @T.prim_func
     def warp_group_idx_kernel(A: T.Tensor((num_threads,), "int32")):
         with T.Kernel(1, threads=num_threads) as _:

@@ -75,7 +71,6 @@ def _get_warp_group_idx_kernel(
 @tilelang.jit(out_idx=[-1])
 def _shuffle_elect_kernel(num_threads: int = 128, thread_extent: int = 64):
-
     @T.prim_func
     def shuffle_elect_kernel(A: T.Tensor((num_threads,), "int32")):
         with T.Kernel(1, threads=num_threads) as _:
testing/python/language/test_tilelang_language_if_range.py  View file @ 29051439

@@ -4,9 +4,10 @@ import torch
 import tilelang.testing


-@tilelang.jit(out_idx=[1],)
+@tilelang.jit(
+    out_idx=[1],
+)
 def tilelang_if_range(M, N, block_M, block_N, dtype="float16"):
-
     @T.prim_func
     def main(
         A: T.Tensor((M, N), dtype),
testing/python/language/test_tilelang_language_infinity.py  View file @ 29051439

@@ -5,7 +5,6 @@ import tilelang.language as T
 @tilelang.jit(out_idx=-1)
 def get_inf_kernel(dtype: str):
-
     @T.prim_func
     def main(A: T.Tensor((32,), dtype)):
         with T.Kernel(1, threads=32):

@@ -18,7 +17,7 @@ def _test_infinity(dtype: str):
     kernel = get_inf_kernel(dtype)
     output = kernel()
-    assert torch.all(output == torch.inf), f'check failed for {dtype=}'
+    assert torch.all(output == torch.inf), f"check failed for {dtype=}"


 @tilelang.testing.requires_cuda
testing/python/language/test_tilelang_language_intrinsics_codegen.py  View file @ 29051439
testing/python/language/test_tilelang_language_lazy_jit.py  View file @ 29051439

@@ -8,7 +8,6 @@ import torch
 def _gemm_impl():
-
     @T.macro
     def gemm_impl(
         A: T.Tensor[[int, int], Any],

@@ -37,7 +36,6 @@ def _gemm_impl():
 def test_jit2_gemm_annot():
-
     @tilelang.lazy_jit
     def gemm(
         A: T.Tensor[[int, int], Any],

@@ -54,24 +52,24 @@ def test_jit2_gemm_annot():
         return C

     prod = product([T.float16, T.float32], [T.float32])
-    gemm.par_compile([{
-        'A': T.Tensor((1024, 1024), dtype=in_dtype),
-        'B': T.Tensor((1024, 1024), dtype=in_dtype),
-        'out_dtype': out_dtype
-    } for in_dtype, out_dtype in prod])
+    gemm.par_compile(
+        [
+            {"A": T.Tensor((1024, 1024), dtype=in_dtype), "B": T.Tensor((1024, 1024), dtype=in_dtype), "out_dtype": out_dtype}
+            for in_dtype, out_dtype in prod
+        ]
+    )
     for in_dtype, out_dtype in prod:
         in_dtype = in_dtype.torch()
         out_dtype = out_dtype.torch()
-        A = torch.randn(1024, 1024, dtype=in_dtype, device='cuda')
-        B = torch.randn(1024, 1024, dtype=in_dtype, device='cuda')
+        A = torch.randn(1024, 1024, dtype=in_dtype, device="cuda")
+        B = torch.randn(1024, 1024, dtype=in_dtype, device="cuda")
         C_ref = out_dtype(A @ B)
         C = gemm(A, B)
         torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)


 def test_jit2_gemm_ptr():
-
     @tilelang.lazy_jit
     def gemm_ptr(
         A: T.ptr,

@@ -92,23 +90,19 @@ def test_jit2_gemm_ptr():
         _gemm_impl()(A, B, C, out_dtype, block_M, block_N, block_K)

     prod = product([T.float16, T.float32], [T.float32])
-    gemm_ptr.par_compile([{
-        'A': T.ptr(),
-        'B': T.ptr(),
-        'C': T.ptr(),
-        'M': 1024,
-        'N': 1024,
-        'K': 1024,
-        'dtype': in_dtype,
-        'out_dtype': out_dtype
-    } for in_dtype, out_dtype in prod])
+    gemm_ptr.par_compile(
+        [
+            {"A": T.ptr(), "B": T.ptr(), "C": T.ptr(), "M": 1024, "N": 1024, "K": 1024, "dtype": in_dtype, "out_dtype": out_dtype}
+            for in_dtype, out_dtype in prod
+        ]
+    )
     for in_dtype, out_dtype in prod:
         in_dtype = in_dtype.torch()
         out_dtype = out_dtype.torch()
-        A = torch.randn(1024, 1024, dtype=in_dtype, device='cuda')
-        B = torch.randn(1024, 1024, dtype=in_dtype, device='cuda')
+        A = torch.randn(1024, 1024, dtype=in_dtype, device="cuda")
+        B = torch.randn(1024, 1024, dtype=in_dtype, device="cuda")
         C_ref = out_dtype(A @ B)
-        C = torch.empty(1024, 1024, dtype=out_dtype, device='cuda')
+        C = torch.empty(1024, 1024, dtype=out_dtype, device="cuda")
         gemm_ptr(A, B, C, 1024, 1024, 1024, in_dtype, out_dtype)
         torch.testing.assert_close(C, C_ref, atol=1e-2, rtol=1e-2)

@@ -129,8 +123,7 @@ def test_jit2_annot():
         AnnotTest(
             annot=T.Tensor[[int, int], T.float32],
             promote=False,
-            match_ok=[torch.randn(1, 1, dtype=torch.float32),
-                      T.Tensor((1, 1), dtype=T.float32)],
+            match_ok=[torch.randn(1, 1, dtype=torch.float32), T.Tensor((1, 1), dtype=T.float32)],
             match_ng=[
                 torch.randn(1, 1, dtype=torch.float16),
                 T.Tensor(1, dtype=T.float32),

@@ -146,8 +139,8 @@ def test_jit2_annot():
                 T.Tensor((1,), dtype=T.float32),
                 T.Tensor((1,), dtype=T.float16),
             ],
-            match_ng=[torch.randn((1, 1), dtype=torch.float32),
-                      T.Tensor((1, 1), dtype=T.float16)]),
+            match_ng=[torch.randn((1, 1), dtype=torch.float32), T.Tensor((1, 1), dtype=T.float16)],
+        ),
         AnnotTest(
             annot=T.Tensor[[int, 1], Any],
             promote=False,

@@ -157,8 +150,8 @@ def test_jit2_annot():
                 T.Tensor((12, 1), T.float32),
                 T.Tensor((12, 1), T.float16),
             ],
-            match_ng=[torch.randn(12, 12, dtype=torch.float32),
-                      T.Tensor((12, 12), T.float32)]),
+            match_ng=[torch.randn(12, 12, dtype=torch.float32), T.Tensor((12, 12), T.float32)],
+        ),
         AnnotTest(
             annot=T.Tensor[[T.dyn, 1], Any],
             promote=False,

@@ -168,43 +161,39 @@ def test_jit2_annot():
                 T.Tensor((12, 1), T.float32),
                 T.Tensor((12, 1), T.float16),
             ],
-            match_ng=[torch.randn(12, 12, dtype=torch.float32),
-                      T.Tensor((12, 12), T.float32)]),
+            match_ng=[torch.randn(12, 12, dtype=torch.float32), T.Tensor((12, 12), T.float32)],
+        ),
         AnnotTest(
             annot=T.Tensor[[1024, 1024], T.float32],
             promote=True,
         ),
-        AnnotTest(annot=T.dyn[int, 'X'], promote=False, match_ok=[1, 2, 3, 4]),
-        AnnotTest(annot=T.dyn, promote=False, match_ok=[1, 2, 3, 4])
+        AnnotTest(annot=T.dyn[int, "X"], promote=False, match_ok=[1, 2, 3, 4]),
+        AnnotTest(annot=T.dyn, promote=False, match_ok=[1, 2, 3, 4]),
     ]
     for test in tests:
         promote = test.annot.promote()
         promoted = promote is not None
         if promoted != test.promote:
-            raise AssertionError(
-                f'Promote mismatch for {test.annot}: expected {test.promote}, got {promoted}')
-        with Builder().prim_func('_test'):
+            raise AssertionError(f"Promote mismatch for {test.annot}: expected {test.promote}, got {promoted}")
+        with Builder().prim_func("_test"):
             for match_ok in test.match_ok:
                 try:
                     vt = ArgVarTable()
-                    test.annot.create_prim_func_arg('arg', match_ok, vt)
+                    test.annot.create_prim_func_arg("arg", match_ok, vt)
                 except Exception as e:
                     traceback.print_exc()
-                    raise AssertionError(
-                        f'Match failed for {test.annot} with value {match_ok}: {e}') from e
+                    raise AssertionError(f"Match failed for {test.annot} with value {match_ok}: {e}") from e
             for match_ng in test.match_ng:
                 try:
                     vt = ArgVarTable()
-                    test.annot.create_prim_func_arg('arg', match_ng, vt)
-                    raise AssertionError(
-                        f'Match unexpectedly succeeded for {test.annot} with value {match_ng}')
+                    test.annot.create_prim_func_arg("arg", match_ng, vt)
+                    raise AssertionError(f"Match unexpectedly succeeded for {test.annot} with value {match_ng}")
                 except Exception:
                     pass


 def test_jit2_many_annot():
-
     @T.macro
     def copy_impl(A, B):
         M, N = A.shape

@@ -213,8 +202,7 @@ def test_jit2_many_annot():
         assert N == N_, f"N mismatch {N} {N_}"
         # assert tuple(A.shape) == tuple(B.shape), f"Invalid tensor shape: {A.shape}, {B.shape}"
         with T.Kernel(T.ceildiv(M, 128), T.ceildiv(N, 128), threads=128) as (bx, by):
-            T.copy(A[bx * 128:bx * 128 + 128, by * 128:by * 128 + 128],
-                   B[bx * 128:bx * 128 + 128, by * 128:by * 128 + 128])
+            T.copy(A[bx * 128 : bx * 128 + 128, by * 128 : by * 128 + 128], B[bx * 128 : bx * 128 + 128, by * 128 : by * 128 + 128])

     @tilelang.lazy_jit
     def copy1(

@@ -259,20 +247,19 @@ def test_jit2_many_annot():
             copy_impl(A, B)

     for copy in [copy1, copy2, copy3, copy4]:
-        A = torch.randn(128, 128, device='cuda')
-        B = torch.empty(128, 128, device='cuda')
+        A = torch.randn(128, 128, device="cuda")
+        B = torch.empty(128, 128, device="cuda")
         copy(A, B)
         assert torch.equal(B, A)
     for copy in [copy5, copy6]:
-        A = torch.randn(128, 2, 128, 2, device='cuda')
-        B = torch.randn(128, 2, 128, 2, device='cuda')
+        A = torch.randn(128, 2, 128, 2, device="cuda")
+        B = torch.randn(128, 2, 128, 2, device="cuda")
         copy(A[:, 0, :, 0], B[:, 0, :, 0])
         assert torch.equal(A[:, 0, :, 0], B[:, 0, :, 0])


 def test_jit2_return():
-
     @T.macro
     def copy_impl(A):
         M, N = A.shape

@@ -283,8 +270,7 @@ def test_jit2_return():
         assert N == N_, f"N mismatch {N} {N_}"
         # assert tuple(A.shape) == tuple(B.shape), f"Invalid tensor shape: {A.shape}, {B.shape}"
         with T.Kernel(T.ceildiv(M, 128), T.ceildiv(N, 128), threads=128) as (bx, by):
-            T.copy(A[bx * 128:bx * 128 + 128, by * 128:by * 128 + 128],
-                   B[bx * 128:bx * 128 + 128, by * 128:by * 128 + 128])
+            T.copy(A[bx * 128 : bx * 128 + 128, by * 128 : by * 128 + 128], B[bx * 128 : bx * 128 + 128, by * 128 : by * 128 + 128])
         return B

     @tilelang.lazy_jit

@@ -292,41 +278,52 @@ def test_jit2_return():
         return copy_impl(A)

     @tilelang.lazy_jit
-    def copy1(A: T.Tensor[[int, int], T.float32],):
+    def copy1(
+        A: T.Tensor[[int, int], T.float32],
+    ):
         return copy_impl(A)

     @tilelang.lazy_jit
-    def copy2(A: T.Tensor[[128, 128], T.float32],):
+    def copy2(
+        A: T.Tensor[[128, 128], T.float32],
+    ):
         return copy_impl(A)

     @tilelang.lazy_jit
-    def copy3(A: T.Tensor[[int, 128], T.float32],):
+    def copy3(
+        A: T.Tensor[[int, 128], T.float32],
+    ):
         return copy_impl(A)

     @tilelang.lazy_jit
-    def copy4(A: T.Tensor[[T.dyn, int], T.float32],):
+    def copy4(
+        A: T.Tensor[[T.dyn, int], T.float32],
+    ):
         return copy_impl(A)

     @tilelang.lazy_jit
-    def copy5(A: T.StridedTensor[[int, int], [int, int], T.float32],):
+    def copy5(
+        A: T.StridedTensor[[int, int], [int, int], T.float32],
+    ):
         return copy_impl(A)

     @tilelang.lazy_jit
-    def copy6(A: T.StridedTensor[[T.dyn, int], [int, int], T.float32],):
+    def copy6(
+        A: T.StridedTensor[[T.dyn, int], [int, int], T.float32],
+    ):
         return copy_impl(A)

     for copy in [copy0, copy1, copy2, copy3, copy4]:
-        A = torch.randn(128, 128, device='cuda')
+        A = torch.randn(128, 128, device="cuda")
         B = copy(A)
         assert torch.equal(B, A)
     for copy in [copy5, copy6]:
-        A = torch.randn(128, 2, 128, 2, device='cuda')
+        A = torch.randn(128, 2, 128, 2, device="cuda")
         B = copy(A[:, 0, :, 0])
         assert torch.equal(A[:, 0, :, 0], B)


 def test_jit2_deepseek_deepgemm():
-
     @tilelang.lazy_jit
     def deep_gemm(
         A: T.Tensor[[int, int], T.float8_e4m3],

@@ -351,13 +348,9 @@ def test_jit2_deepseek_deepgemm():
         N, K = B.shape
         C = T.empty(M, N, dtype=out_dtype)

-        assert out_dtype in [T.bfloat16, T.float32
-                            ], f"Expect out_dtype to be one of [T.float16, T.float32], got {out_dtype}"
-        assert scales_a.shape == [M, T.ceildiv(K, group_size)
-                                 ], f"Expect scales_a shape to be f{[M, T.ceildiv(K, group_size)]}"
-        assert scales_b.shape == [N, T.ceildiv(K, group_size)
-                                 ], f"Expect scales_b shape to be f{[N, T.ceildiv(K, group_size)]}"
+        assert out_dtype in [T.bfloat16, T.float32], f"Expect out_dtype to be one of [T.float16, T.float32], got {out_dtype}"
+        assert scales_a.shape == [M, T.ceildiv(K, group_size)], f"Expect scales_a shape to be f{[M, T.ceildiv(K, group_size)]}"
+        assert scales_b.shape == [N, T.ceildiv(K, group_size)], f"Expect scales_b shape to be f{[N, T.ceildiv(K, group_size)]}"

         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
             A_shared = T.alloc_shared((block_M, block_K), in_dtype)

@@ -421,5 +414,5 @@ def test_jit2_deepseek_deepgemm():
     # M, N, K = 1024, 1024, 8192
     # A = torch.randn((M, K), dtype=torch.float8_e4m3fn, )

-if __name__ == '__main__':
+if __name__ == "__main__":
     tilelang.testing.main()
testing/python/language/test_tilelang_language_let.py  View file @ 29051439

@@ -4,7 +4,6 @@ from tilelang import language as T
 def test_let_vectorize_load():
-
     @T.prim_func
     def main(A_ptr: T.handle):
         A = T.match_buffer(A_ptr, (16, 16), dtype="float32", align=16)
testing/python/language/test_tilelang_language_mask_op.py  View file @ 29051439

@@ -6,7 +6,6 @@ import torch
 # add decorator @tilelang.jit if you want to return a torch function
 # @tilelang.jit
 def tilelang_copy_mask_parallel(M, N, block_M, block_N, dtype="float16"):
-
     @T.prim_func
     def main(
         A: T.Tensor((M, N), dtype),

@@ -30,13 +29,8 @@ def tilelang_copy_mask_parallel(M, N, block_M, block_N, dtype="float16"):
 def run_tilelang_copy_mask_parallel(M=1024, N=1024, block_M=128, block_N=128, dtype="float16"):
     program = tilelang_copy_mask_parallel(M, N, block_M, block_N, dtype)
     kernel = tilelang.compile(
-        program,
-        out_idx=[1],
-        target="cuda",
-        pass_configs={
-            "tl.disable_warp_specialized": True,
-            "tl.disable_tma_lower": True
-        })
+        program, out_idx=[1], target="cuda", pass_configs={"tl.disable_warp_specialized": True, "tl.disable_tma_lower": True}
+    )
     a = torch.randn(M, N, device="cuda", dtype=getattr(torch, dtype))
     b = kernel(a)
     torch.testing.assert_close(b, a, rtol=1e-2, atol=1e-2)

@@ -49,7 +43,6 @@ def test_tilelang_copy_mask_parallel():
 # add decorator @tilelang.jit if you want to return a torch function
 # @tilelang.jit
 def tilelang_copy_mask_copy(M, N, block_M, block_N, dtype="float16"):
-
     @T.prim_func
     def main(
         A: T.Tensor((M, N), dtype),

@@ -72,13 +65,8 @@ def tilelang_copy_mask_copy(M, N, block_M, block_N, dtype="float16"):
 def run_tilelang_copy_mask_copy(M=1024, N=1024, block_M=128, block_N=128, dtype="float16"):
     program = tilelang_copy_mask_copy(M, N, block_M, block_N, dtype)
     kernel = tilelang.compile(
-        program,
-        out_idx=[1],
-        target="cuda",
-        pass_configs={
-            "tl.disable_warp_specialized": True,
-            "tl.disable_tma_lower": True
-        })
+        program, out_idx=[1], target="cuda", pass_configs={"tl.disable_warp_specialized": True, "tl.disable_tma_lower": True}
+    )
     a = torch.randn(M, N, device="cuda", dtype=getattr(torch, dtype))
     b = kernel(a)
     torch.testing.assert_close(b, a, rtol=1e-2, atol=1e-2)

@@ -91,7 +79,6 @@ def test_tilelang_copy_mask_copy():
 # add decorator @tilelang.jit if you want to return a torch function
 # @tilelang.jit
 def tilelang_copy_mask_parallel_range(M, N, block_M, block_N, dtype="float16"):
-
     @T.prim_func
     def main(
         A: T.Tensor((M, N), dtype),

@@ -112,20 +99,11 @@ def tilelang_copy_mask_parallel_range(M, N, block_M, block_N, dtype="float16"):
     return main


-def run_tilelang_copy_mask_parallel_range(M=1024,
-                                          N=1024,
-                                          block_M=128,
-                                          block_N=128,
-                                          dtype="float16"):
+def run_tilelang_copy_mask_parallel_range(M=1024, N=1024, block_M=128, block_N=128, dtype="float16"):
     program = tilelang_copy_mask_parallel_range(M, N, block_M, block_N, dtype)
     kernel = tilelang.compile(
-        program,
-        out_idx=[1],
-        target="cuda",
-        pass_configs={
-            "tl.disable_warp_specialized": True,
-            "tl.disable_tma_lower": True
-        })
+        program, out_idx=[1], target="cuda", pass_configs={"tl.disable_warp_specialized": True, "tl.disable_tma_lower": True}
+    )
     a = torch.randn(M, N, device="cuda", dtype=getattr(torch, dtype))
     b = kernel(a)
     torch.testing.assert_close(b, a, rtol=1e-2, atol=1e-2)

@@ -138,7 +116,6 @@ def test_tilelang_copy_mask_parallel_range():
 # add decorator @tilelang.jit if you want to return a torch function
 # @tilelang.jit
 def tilelang_copy_mask_copy_range(M, N, block_M, block_N, dtype="float16"):
-
     @T.prim_func
     def main(
         A: T.Tensor((M, N), dtype),

@@ -161,13 +138,8 @@ def tilelang_copy_mask_copy_range(M, N, block_M, block_N, dtype="float16"):
 def run_tilelang_copy_mask_copy_range(M=1024, N=1024, block_M=128, block_N=128, dtype="float16"):
     program = tilelang_copy_mask_copy_range(M, N, block_M, block_N, dtype)
     kernel = tilelang.compile(
-        program,
-        out_idx=[1],
-        target="cuda",
-        pass_configs={
-            "tl.disable_warp_specialized": True,
-            "tl.disable_tma_lower": True
-        })
+        program, out_idx=[1], target="cuda", pass_configs={"tl.disable_warp_specialized": True, "tl.disable_tma_lower": True}
+    )
     a = torch.randn(M, N, device="cuda", dtype=getattr(torch, dtype))
     b = kernel(a)
     torch.testing.assert_close(b, a, rtol=1e-2, atol=1e-2)