Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
tilelang
Commits
29051439
Unverified
Commit
29051439
authored
Dec 12, 2025
by
Lei Wang
Committed by
GitHub
Dec 12, 2025
Browse files
[Lint] Phaseout Yapf format and embrace ruff format (#1417)
parent
e84b24bc
Changes
426
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
247 additions
and
340 deletions
+247
-340
testing/python/language/test_tilelang_language_annot.py
testing/python/language/test_tilelang_language_annot.py
+19
-16
testing/python/language/test_tilelang_language_annotate_safe_value.py
...on/language/test_tilelang_language_annotate_safe_value.py
+4
-10
testing/python/language/test_tilelang_language_any_of.py
testing/python/language/test_tilelang_language_any_of.py
+20
-22
testing/python/language/test_tilelang_language_assume.py
testing/python/language/test_tilelang_language_assume.py
+11
-14
testing/python/language/test_tilelang_language_atomic_add.py
testing/python/language/test_tilelang_language_atomic_add.py
+14
-31
testing/python/language/test_tilelang_language_ceildiv.py
testing/python/language/test_tilelang_language_ceildiv.py
+0
-2
testing/python/language/test_tilelang_language_chain_equal.py
...ing/python/language/test_tilelang_language_chain_equal.py
+5
-5
testing/python/language/test_tilelang_language_clamp.py
testing/python/language/test_tilelang_language_clamp.py
+4
-4
testing/python/language/test_tilelang_language_clear.py
testing/python/language/test_tilelang_language_clear.py
+5
-6
testing/python/language/test_tilelang_language_composable_index.py
...ython/language/test_tilelang_language_composable_index.py
+4
-4
testing/python/language/test_tilelang_language_copy.py
testing/python/language/test_tilelang_language_copy.py
+18
-36
testing/python/language/test_tilelang_language_cumsum.py
testing/python/language/test_tilelang_language_cumsum.py
+17
-14
testing/python/language/test_tilelang_language_frontend_v2.py
...ing/python/language/test_tilelang_language_frontend_v2.py
+39
-48
testing/python/language/test_tilelang_language_get_warp_info.py
...g/python/language/test_tilelang_language_get_warp_info.py
+0
-5
testing/python/language/test_tilelang_language_if_range.py
testing/python/language/test_tilelang_language_if_range.py
+5
-4
testing/python/language/test_tilelang_language_infinity.py
testing/python/language/test_tilelang_language_infinity.py
+1
-2
testing/python/language/test_tilelang_language_intrinsics_codegen.py
...hon/language/test_tilelang_language_intrinsics_codegen.py
+2
-2
testing/python/language/test_tilelang_language_lazy_jit.py
testing/python/language/test_tilelang_language_lazy_jit.py
+62
-69
testing/python/language/test_tilelang_language_let.py
testing/python/language/test_tilelang_language_let.py
+0
-1
testing/python/language/test_tilelang_language_mask_op.py
testing/python/language/test_tilelang_language_mask_op.py
+17
-45
No files found.
Too many changes to show.
To preserve performance only
426 of 426+
files are displayed.
Plain diff
Email patch
testing/python/language/test_tilelang_language_annot.py
View file @
29051439
...
@@ -5,13 +5,14 @@ import torch
...
@@ -5,13 +5,14 @@ import torch
def
test_tensor_annot_mul
():
def
test_tensor_annot_mul
():
@
tilelang
.
jit
@
tilelang
.
jit
def
example_tensor_annot
():
def
example_tensor_annot
():
n
=
T
.
symbolic
(
'n'
)
n
=
T
.
symbolic
(
"n"
)
@
T
.
prim_func
@
T
.
prim_func
def
kernel
(
A
:
T
.
Tensor
((
n
*
4
,),
T
.
int32
),):
def
kernel
(
A
:
T
.
Tensor
((
n
*
4
,),
T
.
int32
),
):
with
T
.
Kernel
(
1
)
as
_
:
with
T
.
Kernel
(
1
)
as
_
:
for
i
in
range
(
n
*
4
):
for
i
in
range
(
n
*
4
):
A
[
i
]
=
0
A
[
i
]
=
0
...
@@ -19,20 +20,21 @@ def test_tensor_annot_mul():
...
@@ -19,20 +20,21 @@ def test_tensor_annot_mul():
return
kernel
return
kernel
ker
=
example_tensor_annot
()
ker
=
example_tensor_annot
()
A
=
torch
.
arange
(
16
,
dtype
=
torch
.
int32
,
device
=
'
cuda
'
)
A
=
torch
.
arange
(
16
,
dtype
=
torch
.
int32
,
device
=
"
cuda
"
)
ker
(
A
)
ker
(
A
)
expected
=
torch
.
zeros
(
16
,
dtype
=
torch
.
int32
,
device
=
'
cuda
'
)
expected
=
torch
.
zeros
(
16
,
dtype
=
torch
.
int32
,
device
=
"
cuda
"
)
assert
torch
.
equal
(
A
,
expected
)
assert
torch
.
equal
(
A
,
expected
)
def
test_tensor_annot_add
():
def
test_tensor_annot_add
():
@
tilelang
.
jit
@
tilelang
.
jit
def
example_tensor_annot
():
def
example_tensor_annot
():
n
=
T
.
symbolic
(
'n'
)
n
=
T
.
symbolic
(
"n"
)
@
T
.
prim_func
@
T
.
prim_func
def
kernel
(
A
:
T
.
Tensor
((
n
+
1
,),
T
.
int32
),):
def
kernel
(
A
:
T
.
Tensor
((
n
+
1
,),
T
.
int32
),
):
with
T
.
Kernel
(
1
)
as
_
:
with
T
.
Kernel
(
1
)
as
_
:
for
i
in
range
(
n
+
1
):
for
i
in
range
(
n
+
1
):
A
[
i
]
=
0
A
[
i
]
=
0
...
@@ -40,20 +42,21 @@ def test_tensor_annot_add():
...
@@ -40,20 +42,21 @@ def test_tensor_annot_add():
return
kernel
return
kernel
ker
=
example_tensor_annot
()
ker
=
example_tensor_annot
()
A
=
torch
.
arange
(
16
,
dtype
=
torch
.
int32
,
device
=
'
cuda
'
)
A
=
torch
.
arange
(
16
,
dtype
=
torch
.
int32
,
device
=
"
cuda
"
)
ker
(
A
)
ker
(
A
)
expected
=
torch
.
zeros
(
16
,
dtype
=
torch
.
int32
,
device
=
'
cuda
'
)
expected
=
torch
.
zeros
(
16
,
dtype
=
torch
.
int32
,
device
=
"
cuda
"
)
assert
torch
.
equal
(
A
,
expected
)
assert
torch
.
equal
(
A
,
expected
)
def
test_tensor_annot_mul_add
():
def
test_tensor_annot_mul_add
():
@
tilelang
.
jit
@
tilelang
.
jit
def
example_tensor_annot
():
def
example_tensor_annot
():
n
=
T
.
symbolic
(
'n'
)
n
=
T
.
symbolic
(
"n"
)
@
T
.
prim_func
@
T
.
prim_func
def
kernel
(
A
:
T
.
Tensor
((
n
*
3
+
1
,),
T
.
int32
),):
def
kernel
(
A
:
T
.
Tensor
((
n
*
3
+
1
,),
T
.
int32
),
):
with
T
.
Kernel
(
1
)
as
_
:
with
T
.
Kernel
(
1
)
as
_
:
for
i
in
range
(
n
*
3
+
1
):
for
i
in
range
(
n
*
3
+
1
):
A
[
i
]
=
0
A
[
i
]
=
0
...
@@ -61,11 +64,11 @@ def test_tensor_annot_mul_add():
...
@@ -61,11 +64,11 @@ def test_tensor_annot_mul_add():
return
kernel
return
kernel
ker
=
example_tensor_annot
()
ker
=
example_tensor_annot
()
A
=
torch
.
arange
(
16
,
dtype
=
torch
.
int32
,
device
=
'
cuda
'
)
A
=
torch
.
arange
(
16
,
dtype
=
torch
.
int32
,
device
=
"
cuda
"
)
ker
(
A
)
ker
(
A
)
expected
=
torch
.
zeros
(
16
,
dtype
=
torch
.
int32
,
device
=
'
cuda
'
)
expected
=
torch
.
zeros
(
16
,
dtype
=
torch
.
int32
,
device
=
"
cuda
"
)
assert
torch
.
equal
(
A
,
expected
)
assert
torch
.
equal
(
A
,
expected
)
if
__name__
==
'
__main__
'
:
if
__name__
==
"
__main__
"
:
tilelang
.
testing
.
main
()
tilelang
.
testing
.
main
()
testing/python/language/test_tilelang_language_annotate_safe_value.py
View file @
29051439
...
@@ -7,11 +7,10 @@ import torch
...
@@ -7,11 +7,10 @@ import torch
# add decorator @tilelang.jit if you want to return a torch function
# add decorator @tilelang.jit if you want to return a torch function
# @tilelang.jit
# @tilelang.jit
def
tilelang_copy
(
M
,
N
,
block_M
,
block_N
,
dtype
=
"float16"
,
pad_value
=
0
):
def
tilelang_copy
(
M
,
N
,
block_M
,
block_N
,
dtype
=
"float16"
,
pad_value
=
0
):
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Tensor
((
M
,
N
),
dtype
),
A
:
T
.
Tensor
((
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
),
):
):
# Initialize Kernel Context
# Initialize Kernel Context
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
128
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
128
)
as
(
bx
,
by
):
...
@@ -30,13 +29,8 @@ def tilelang_copy(M, N, block_M, block_N, dtype="float16", pad_value=0):
...
@@ -30,13 +29,8 @@ def tilelang_copy(M, N, block_M, block_N, dtype="float16", pad_value=0):
def
run_tilelang_copy
(
M
=
1024
,
N
=
1024
,
block_M
=
128
,
block_N
=
128
,
dtype
=
"float16"
,
pad_value
=
0
):
def
run_tilelang_copy
(
M
=
1024
,
N
=
1024
,
block_M
=
128
,
block_N
=
128
,
dtype
=
"float16"
,
pad_value
=
0
):
program
=
tilelang_copy
(
M
,
N
,
block_M
,
block_N
,
dtype
,
pad_value
=
pad_value
)
program
=
tilelang_copy
(
M
,
N
,
block_M
,
block_N
,
dtype
,
pad_value
=
pad_value
)
kernel
=
tilelang
.
compile
(
kernel
=
tilelang
.
compile
(
program
,
program
,
out_idx
=
[
1
],
target
=
"cuda"
,
pass_configs
=
{
"tl.disable_warp_specialized"
:
True
,
"tl.disable_tma_lower"
:
True
}
out_idx
=
[
1
],
)
target
=
"cuda"
,
pass_configs
=
{
"tl.disable_warp_specialized"
:
True
,
"tl.disable_tma_lower"
:
True
})
a
=
torch
.
randn
(
M
,
N
,
device
=
"cuda"
,
dtype
=
getattr
(
torch
,
dtype
))
a
=
torch
.
randn
(
M
,
N
,
device
=
"cuda"
,
dtype
=
getattr
(
torch
,
dtype
))
b
=
kernel
(
a
)
b
=
kernel
(
a
)
ref_b
=
torch
.
zeros_like
(
a
)
ref_b
=
torch
.
zeros_like
(
a
)
...
...
testing/python/language/test_tilelang_language_any_of.py
View file @
29051439
...
@@ -13,11 +13,10 @@ def ref_program(A, B, BlockMask, block_M, block_N, block_K):
...
@@ -13,11 +13,10 @@ def ref_program(A, B, BlockMask, block_M, block_N, block_K):
accu
=
torch
.
zeros
((
block_M
,
block_N
),
dtype
=
torch
.
float32
,
device
=
A
.
device
)
accu
=
torch
.
zeros
((
block_M
,
block_N
),
dtype
=
torch
.
float32
,
device
=
A
.
device
)
for
k
in
range
(
K
//
block_K
):
for
k
in
range
(
K
//
block_K
):
if
torch
.
any
(
BlockMask
[
i
,
j
,
k
]):
if
torch
.
any
(
BlockMask
[
i
,
j
,
k
]):
accu
+=
A
[
i
*
block_M
:(
i
+
1
)
*
block_M
,
k
*
block_K
:(
k
+
1
)
*
block_K
].
to
(
accu
+=
A
[
i
*
block_M
:
(
i
+
1
)
*
block_M
,
k
*
block_K
:
(
k
+
1
)
*
block_K
].
to
(
torch
.
float32
)
@
B
[
torch
.
float32
)
@
B
[
k
*
block_K
:(
k
+
1
)
*
block_K
,
k
*
block_K
:
(
k
+
1
)
*
block_K
,
j
*
block_N
:
(
j
+
1
)
*
block_N
j
*
block_N
:(
j
+
1
)
*
block_N
].
to
(
torch
.
float32
)
].
to
(
torch
.
float32
)
ref_c
[
i
*
block_M
:(
i
+
1
)
*
block_M
,
j
*
block_N
:(
j
+
1
)
*
block_N
]
=
(
ref_c
[
i
*
block_M
:
(
i
+
1
)
*
block_M
,
j
*
block_N
:
(
j
+
1
)
*
block_N
]
=
accu
.
to
(
torch
.
float16
)
accu
.
to
(
torch
.
float16
))
return
ref_c
return
ref_c
...
@@ -35,15 +34,14 @@ def blocksparse_matmul_global(
...
@@ -35,15 +34,14 @@ def blocksparse_matmul_global(
dtype
=
"float16"
,
dtype
=
"float16"
,
accum_dtype
=
"float"
,
accum_dtype
=
"float"
,
):
):
block_mask_shape
=
(
M
//
block_M
,
N
//
block_N
,
K
//
block_K
,
condition_dim
)
block_mask_shape
=
(
M
//
block_M
,
N
//
block_N
,
K
//
block_K
,
condition_dim
)
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
BlockMask
:
T
.
Tensor
(
block_mask_shape
,
"bool"
),
BlockMask
:
T
.
Tensor
(
block_mask_shape
,
"bool"
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
thread_num
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
thread_num
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
dtype
)
A_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
dtype
)
...
@@ -80,15 +78,14 @@ def blocksparse_matmul_shared(
...
@@ -80,15 +78,14 @@ def blocksparse_matmul_shared(
dtype
=
"float16"
,
dtype
=
"float16"
,
accum_dtype
=
"float"
,
accum_dtype
=
"float"
,
):
):
block_mask_shape
=
(
M
//
block_M
,
N
//
block_N
,
K
//
block_K
,
condition_dim
)
block_mask_shape
=
(
M
//
block_M
,
N
//
block_N
,
K
//
block_K
,
condition_dim
)
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
BlockMask
:
T
.
Tensor
(
block_mask_shape
,
"bool"
),
BlockMask
:
T
.
Tensor
(
block_mask_shape
,
"bool"
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
thread_num
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
thread_num
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
dtype
)
A_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
dtype
)
...
@@ -130,15 +127,14 @@ def blocksparse_matmul_local(
...
@@ -130,15 +127,14 @@ def blocksparse_matmul_local(
dtype
=
"float16"
,
dtype
=
"float16"
,
accum_dtype
=
"float"
,
accum_dtype
=
"float"
,
):
):
block_mask_shape
=
(
M
//
block_M
,
N
//
block_N
,
K
//
block_K
,
condition_dim
)
block_mask_shape
=
(
M
//
block_M
,
N
//
block_N
,
K
//
block_K
,
condition_dim
)
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
BlockMask
:
T
.
Tensor
(
block_mask_shape
,
"bool"
),
BlockMask
:
T
.
Tensor
(
block_mask_shape
,
"bool"
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
thread_num
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
thread_num
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
dtype
)
A_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
dtype
)
...
@@ -237,7 +233,8 @@ def run_block_sparse_matmul_shared(M=1024, N=1024, K=1024, sparsity=0.5, conditi
...
@@ -237,7 +233,8 @@ def run_block_sparse_matmul_shared(M=1024, N=1024, K=1024, sparsity=0.5, conditi
pass_configs
=
{
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
})
},
)
# Create block mask with desired sparsity
# Create block mask with desired sparsity
mask_shape
=
(
M
//
block_M
,
N
//
block_N
,
K
//
block_K
)
mask_shape
=
(
M
//
block_M
,
N
//
block_N
,
K
//
block_K
)
block_mask
=
torch
.
rand
(
mask_shape
).
cuda
()
>
sparsity
block_mask
=
torch
.
rand
(
mask_shape
).
cuda
()
>
sparsity
...
@@ -284,7 +281,8 @@ def run_block_sparse_matmul_local(M=1024, N=1024, K=1024, sparsity=0.5, conditio
...
@@ -284,7 +281,8 @@ def run_block_sparse_matmul_local(M=1024, N=1024, K=1024, sparsity=0.5, conditio
pass_configs
=
{
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
})
},
)
# Create block mask with desired sparsity
# Create block mask with desired sparsity
mask_shape
=
(
M
//
block_M
,
N
//
block_N
,
K
//
block_K
)
mask_shape
=
(
M
//
block_M
,
N
//
block_N
,
K
//
block_K
)
block_mask
=
torch
.
rand
(
mask_shape
).
cuda
()
>
sparsity
block_mask
=
torch
.
rand
(
mask_shape
).
cuda
()
>
sparsity
...
...
testing/python/language/test_tilelang_language_assume.py
View file @
29051439
...
@@ -4,10 +4,9 @@ import tilelang.testing
...
@@ -4,10 +4,9 @@ import tilelang.testing
def
test_assume_remove_boundary_check
():
def
test_assume_remove_boundary_check
():
@
tilelang
.
jit
@
tilelang
.
jit
def
kernel_with_assume
():
def
kernel_with_assume
():
N
=
T
.
dynamic
(
'N'
)
N
=
T
.
dynamic
(
"N"
)
@
T
.
prim_func
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
((
N
,),
"float32"
),
l
:
T
.
int32
,
r
:
T
.
int32
):
def
main
(
A
:
T
.
Tensor
((
N
,),
"float32"
),
l
:
T
.
int32
,
r
:
T
.
int32
):
...
@@ -21,20 +20,19 @@ def test_assume_remove_boundary_check():
...
@@ -21,20 +20,19 @@ def test_assume_remove_boundary_check():
jit_kernel
=
kernel_with_assume
()
jit_kernel
=
kernel_with_assume
()
source
=
jit_kernel
.
get_kernel_source
()
source
=
jit_kernel
.
get_kernel_source
()
assert
(
"if ("
not
in
source
)
assert
"if ("
not
in
source
def
test_assume_enable_vectorization
():
def
test_assume_enable_vectorization
():
@
tilelang
.
jit
@
tilelang
.
jit
def
kernel_vectorize
(
M
):
def
kernel_vectorize
(
M
):
N
=
T
.
dynamic
(
'N'
)
N
=
T
.
dynamic
(
"N"
)
vectorize_size
=
4
vectorize_size
=
4
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Tensor
((
M
,
N
),
"float32"
),
A
:
T
.
Tensor
((
M
,
N
),
"float32"
),
B
:
T
.
Tensor
((
M
,
N
),
"float32"
),
B
:
T
.
Tensor
((
M
,
N
),
"float32"
),
):
):
with
T
.
Kernel
(
1
,
threads
=
32
)
as
_
:
with
T
.
Kernel
(
1
,
threads
=
32
)
as
_
:
tid
=
T
.
get_thread_binding
()
tid
=
T
.
get_thread_binding
()
...
@@ -55,16 +53,15 @@ def test_assume_enable_vectorization():
...
@@ -55,16 +53,15 @@ def test_assume_enable_vectorization():
def
test_assume_complex_indexing
():
def
test_assume_complex_indexing
():
@
tilelang
.
jit
@
tilelang
.
jit
def
kernel_complex
():
def
kernel_complex
():
M
=
T
.
dynamic
(
'M'
)
M
=
T
.
dynamic
(
"M"
)
N
=
T
.
dynamic
(
'N'
)
N
=
T
.
dynamic
(
"N"
)
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Tensor
((
M
,
N
),
"float32"
),
A
:
T
.
Tensor
((
M
,
N
),
"float32"
),
B
:
T
.
Tensor
((
M
,
N
),
"float32"
),
B
:
T
.
Tensor
((
M
,
N
),
"float32"
),
):
):
with
T
.
Kernel
(
1
,
threads
=
32
)
as
_
:
with
T
.
Kernel
(
1
,
threads
=
32
)
as
_
:
tid
=
T
.
get_thread_binding
()
tid
=
T
.
get_thread_binding
()
...
@@ -82,8 +79,8 @@ def test_assume_complex_indexing():
...
@@ -82,8 +79,8 @@ def test_assume_complex_indexing():
jit_kernel
=
kernel_complex
()
jit_kernel
=
kernel_complex
()
source
=
jit_kernel
.
get_kernel_source
()
source
=
jit_kernel
.
get_kernel_source
()
assert
(
"if ("
not
in
source
)
assert
"if ("
not
in
source
if
__name__
==
'
__main__
'
:
if
__name__
==
"
__main__
"
:
tilelang
.
testing
.
main
()
tilelang
.
testing
.
main
()
testing/python/language/test_tilelang_language_atomic_add.py
View file @
29051439
...
@@ -4,14 +4,12 @@ import tilelang.language as T
...
@@ -4,14 +4,12 @@ import tilelang.language as T
@
tilelang
.
jit
@
tilelang
.
jit
def
atomic_add_program
(
K
,
M
,
N
,
block_M
,
block_N
,
dtype
=
"float"
):
def
atomic_add_program
(
K
,
M
,
N
,
block_M
,
block_N
,
dtype
=
"float"
):
@
T
.
prim_func
@
T
.
prim_func
def
atomic_add
(
A
:
T
.
Tensor
((
K
,
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
)):
def
atomic_add
(
A
:
T
.
Tensor
((
K
,
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
)):
with
T
.
Kernel
(
T
.
ceildiv
(
M
,
block_M
),
T
.
ceildiv
(
N
,
block_N
),
K
,
threads
=
32
)
as
(
bx
,
by
,
bz
):
with
T
.
Kernel
(
T
.
ceildiv
(
M
,
block_M
),
T
.
ceildiv
(
N
,
block_N
),
K
,
threads
=
32
)
as
(
bx
,
by
,
bz
):
A_shared
=
T
.
alloc_shared
((
block_M
,
block_N
),
dtype
)
A_shared
=
T
.
alloc_shared
((
block_M
,
block_N
),
dtype
)
T
.
copy
(
A
[
bz
,
bx
*
block_M
:(
bx
+
1
)
*
block_M
,
by
*
block_N
:(
by
+
1
)
*
block_N
],
T
.
copy
(
A
[
bz
,
bx
*
block_M
:
(
bx
+
1
)
*
block_M
,
by
*
block_N
:
(
by
+
1
)
*
block_N
],
A_shared
)
A_shared
)
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
T
.
atomic_add
(
B
[
bx
*
block_M
+
i
,
by
*
block_N
+
j
],
A_shared
[
i
,
j
])
T
.
atomic_add
(
B
[
bx
*
block_M
+
i
,
by
*
block_N
+
j
],
A_shared
[
i
,
j
])
...
@@ -39,14 +37,12 @@ def run_atomic_add(K, M, N, block_M, block_N, dtype="float32"):
...
@@ -39,14 +37,12 @@ def run_atomic_add(K, M, N, block_M, block_N, dtype="float32"):
@
tilelang
.
jit
@
tilelang
.
jit
def
tile_atomic_add_program
(
K
,
M
,
N
,
block_M
,
block_N
,
dtype
=
"float"
):
def
tile_atomic_add_program
(
K
,
M
,
N
,
block_M
,
block_N
,
dtype
=
"float"
):
@
T
.
prim_func
@
T
.
prim_func
def
atomic_add
(
A
:
T
.
Tensor
((
K
,
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
)):
def
atomic_add
(
A
:
T
.
Tensor
((
K
,
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
)):
with
T
.
Kernel
(
T
.
ceildiv
(
M
,
block_M
),
T
.
ceildiv
(
N
,
block_N
),
K
,
threads
=
32
)
as
(
bx
,
by
,
bz
):
with
T
.
Kernel
(
T
.
ceildiv
(
M
,
block_M
),
T
.
ceildiv
(
N
,
block_N
),
K
,
threads
=
32
)
as
(
bx
,
by
,
bz
):
A_shared
=
T
.
alloc_shared
((
block_M
,
block_N
),
dtype
)
A_shared
=
T
.
alloc_shared
((
block_M
,
block_N
),
dtype
)
T
.
copy
(
A
[
bz
,
bx
*
block_M
:(
bx
+
1
)
*
block_M
,
by
*
block_N
:(
by
+
1
)
*
block_N
],
T
.
copy
(
A
[
bz
,
bx
*
block_M
:
(
bx
+
1
)
*
block_M
,
by
*
block_N
:
(
by
+
1
)
*
block_N
],
A_shared
)
A_shared
)
T
.
atomic_add
(
B
[
bx
*
block_M
,
by
*
block_N
],
A_shared
)
T
.
atomic_add
(
B
[
bx
*
block_M
,
by
*
block_N
],
A_shared
)
...
@@ -76,14 +72,12 @@ def run_tile_atomic_add(K, M, N, block_M, block_N, dtype="float32"):
...
@@ -76,14 +72,12 @@ def run_tile_atomic_add(K, M, N, block_M, block_N, dtype="float32"):
@
tilelang
.
jit
@
tilelang
.
jit
def
atomic_max_program
(
K
,
M
,
N
,
block_M
,
block_N
,
dtype
=
"float"
):
def
atomic_max_program
(
K
,
M
,
N
,
block_M
,
block_N
,
dtype
=
"float"
):
@
T
.
prim_func
@
T
.
prim_func
def
atomic_max
(
A
:
T
.
Tensor
((
K
,
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
)):
def
atomic_max
(
A
:
T
.
Tensor
((
K
,
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
)):
with
T
.
Kernel
(
T
.
ceildiv
(
M
,
block_M
),
T
.
ceildiv
(
N
,
block_N
),
K
,
threads
=
32
)
as
(
bx
,
by
,
bz
):
with
T
.
Kernel
(
T
.
ceildiv
(
M
,
block_M
),
T
.
ceildiv
(
N
,
block_N
),
K
,
threads
=
32
)
as
(
bx
,
by
,
bz
):
A_shared
=
T
.
alloc_shared
((
block_M
,
block_N
),
dtype
)
A_shared
=
T
.
alloc_shared
((
block_M
,
block_N
),
dtype
)
T
.
copy
(
A
[
bz
,
bx
*
block_M
:(
bx
+
1
)
*
block_M
,
by
*
block_N
:(
by
+
1
)
*
block_N
],
T
.
copy
(
A
[
bz
,
bx
*
block_M
:
(
bx
+
1
)
*
block_M
,
by
*
block_N
:
(
by
+
1
)
*
block_N
],
A_shared
)
A_shared
)
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
T
.
atomic_max
(
B
[
bx
*
block_M
+
i
,
by
*
block_N
+
j
],
A_shared
[
i
,
j
])
T
.
atomic_max
(
B
[
bx
*
block_M
+
i
,
by
*
block_N
+
j
],
A_shared
[
i
,
j
])
...
@@ -111,14 +105,12 @@ def run_atomic_max(K, M, N, block_M, block_N, dtype="float32"):
...
@@ -111,14 +105,12 @@ def run_atomic_max(K, M, N, block_M, block_N, dtype="float32"):
@
tilelang
.
jit
@
tilelang
.
jit
def
atomic_min_program
(
K
,
M
,
N
,
block_M
,
block_N
,
dtype
=
"float"
):
def
atomic_min_program
(
K
,
M
,
N
,
block_M
,
block_N
,
dtype
=
"float"
):
@
T
.
prim_func
@
T
.
prim_func
def
atomic_min
(
A
:
T
.
Tensor
((
K
,
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
)):
def
atomic_min
(
A
:
T
.
Tensor
((
K
,
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
)):
with
T
.
Kernel
(
T
.
ceildiv
(
M
,
block_M
),
T
.
ceildiv
(
N
,
block_N
),
K
,
threads
=
32
)
as
(
bx
,
by
,
bz
):
with
T
.
Kernel
(
T
.
ceildiv
(
M
,
block_M
),
T
.
ceildiv
(
N
,
block_N
),
K
,
threads
=
32
)
as
(
bx
,
by
,
bz
):
A_shared
=
T
.
alloc_shared
((
block_M
,
block_N
),
dtype
)
A_shared
=
T
.
alloc_shared
((
block_M
,
block_N
),
dtype
)
T
.
copy
(
A
[
bz
,
bx
*
block_M
:(
bx
+
1
)
*
block_M
,
by
*
block_N
:(
by
+
1
)
*
block_N
],
T
.
copy
(
A
[
bz
,
bx
*
block_M
:
(
bx
+
1
)
*
block_M
,
by
*
block_N
:
(
by
+
1
)
*
block_N
],
A_shared
)
A_shared
)
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
T
.
atomic_min
(
B
[
bx
*
block_M
+
i
,
by
*
block_N
+
j
],
A_shared
[
i
,
j
])
T
.
atomic_min
(
B
[
bx
*
block_M
+
i
,
by
*
block_N
+
j
],
A_shared
[
i
,
j
])
...
@@ -137,7 +129,7 @@ def run_atomic_min(K, M, N, block_M, block_N, dtype="float32"):
...
@@ -137,7 +129,7 @@ def run_atomic_min(K, M, N, block_M, block_N, dtype="float32"):
B
[
i
,
j
]
=
min
(
B
[
i
,
j
],
A
[
k
,
i
,
j
])
B
[
i
,
j
]
=
min
(
B
[
i
,
j
],
A
[
k
,
i
,
j
])
A
=
torch
.
randn
(
K
,
M
,
N
,
dtype
=
getattr
(
torch
,
dtype
)).
cuda
()
A
=
torch
.
randn
(
K
,
M
,
N
,
dtype
=
getattr
(
torch
,
dtype
)).
cuda
()
B
=
torch
.
full
((
M
,
N
),
float
(
'
inf
'
),
dtype
=
getattr
(
torch
,
dtype
)).
cuda
()
B
=
torch
.
full
((
M
,
N
),
float
(
"
inf
"
),
dtype
=
getattr
(
torch
,
dtype
)).
cuda
()
ref_B
=
B
.
clone
()
ref_B
=
B
.
clone
()
ref_program
(
A
,
ref_B
)
ref_program
(
A
,
ref_B
)
kernel
(
A
,
B
)
kernel
(
A
,
B
)
...
@@ -146,7 +138,6 @@ def run_atomic_min(K, M, N, block_M, block_N, dtype="float32"):
...
@@ -146,7 +138,6 @@ def run_atomic_min(K, M, N, block_M, block_N, dtype="float32"):
@
tilelang
.
jit
@
tilelang
.
jit
def
atomic_load_store_program
(
M
,
N
,
block_M
,
block_N
,
dtype
=
"float"
):
def
atomic_load_store_program
(
M
,
N
,
block_M
,
block_N
,
dtype
=
"float"
):
@
T
.
prim_func
@
T
.
prim_func
def
atomic_load_store
(
A
:
T
.
Tensor
((
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
)):
def
atomic_load_store
(
A
:
T
.
Tensor
((
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
)):
with
T
.
Kernel
(
T
.
ceildiv
(
M
,
block_M
),
T
.
ceildiv
(
N
,
block_N
),
threads
=
32
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
M
,
block_M
),
T
.
ceildiv
(
N
,
block_N
),
threads
=
32
)
as
(
bx
,
by
):
...
@@ -172,18 +163,15 @@ def run_atomic_load_store(M, N, block_M, block_N, dtype="float32"):
...
@@ -172,18 +163,15 @@ def run_atomic_load_store(M, N, block_M, block_N, dtype="float32"):
@
tilelang
.
jit
@
tilelang
.
jit
def
atomic_memory_order_program
(
K
,
M
,
N
,
block_M
,
block_N
,
dtype
=
"float"
):
def
atomic_memory_order_program
(
K
,
M
,
N
,
block_M
,
block_N
,
dtype
=
"float"
):
@
T
.
prim_func
@
T
.
prim_func
def
atomic_with_memory_order
(
A
:
T
.
Tensor
((
K
,
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
)):
def
atomic_with_memory_order
(
A
:
T
.
Tensor
((
K
,
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
)):
with
T
.
Kernel
(
T
.
ceildiv
(
M
,
block_M
),
T
.
ceildiv
(
N
,
block_N
),
K
,
threads
=
32
)
as
(
bx
,
by
,
bz
):
with
T
.
Kernel
(
T
.
ceildiv
(
M
,
block_M
),
T
.
ceildiv
(
N
,
block_N
),
K
,
threads
=
32
)
as
(
bx
,
by
,
bz
):
A_shared
=
T
.
alloc_shared
((
block_M
,
block_N
),
dtype
)
A_shared
=
T
.
alloc_shared
((
block_M
,
block_N
),
dtype
)
T
.
copy
(
A
[
bz
,
bx
*
block_M
:(
bx
+
1
)
*
block_M
,
by
*
block_N
:(
by
+
1
)
*
block_N
],
T
.
copy
(
A
[
bz
,
bx
*
block_M
:
(
bx
+
1
)
*
block_M
,
by
*
block_N
:
(
by
+
1
)
*
block_N
],
A_shared
)
A_shared
)
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
T
.
atomic_add
(
T
.
atomic_add
(
B
[
bx
*
block_M
+
i
,
by
*
block_N
+
j
],
A_shared
[
i
,
j
],
memory_order
=
"relaxed"
)
B
[
bx
*
block_M
+
i
,
by
*
block_N
+
j
],
A_shared
[
i
,
j
],
memory_order
=
"relaxed"
)
return
atomic_with_memory_order
return
atomic_with_memory_order
...
@@ -208,7 +196,6 @@ def run_atomic_memory_order(K, M, N, block_M, block_N, dtype="float32"):
...
@@ -208,7 +196,6 @@ def run_atomic_memory_order(K, M, N, block_M, block_N, dtype="float32"):
@
tilelang
.
jit
@
tilelang
.
jit
def
atomic_addx2_program
(
M
,
N
,
block_M
,
block_N
):
def
atomic_addx2_program
(
M
,
N
,
block_M
,
block_N
):
@
T
.
prim_func
@
T
.
prim_func
def
atomic_addx2
(
A
:
T
.
Tensor
((
M
,
N
),
"float16"
),
B
:
T
.
Tensor
((
M
,
N
),
"float16"
)):
def
atomic_addx2
(
A
:
T
.
Tensor
((
M
,
N
),
"float16"
),
B
:
T
.
Tensor
((
M
,
N
),
"float16"
)):
with
T
.
Kernel
(
T
.
ceildiv
(
M
,
block_M
),
T
.
ceildiv
(
N
,
block_N
),
threads
=
32
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
M
,
block_M
),
T
.
ceildiv
(
N
,
block_N
),
threads
=
32
)
as
(
bx
,
by
):
...
@@ -262,10 +249,10 @@ def test_atomic_addx2():
...
@@ -262,10 +249,10 @@ def test_atomic_addx2():
@
tilelang
.
jit
@
tilelang
.
jit
def
atomic_different_memory_orders_program
(
M
,
N
,
block_M
,
block_N
,
dtype
=
"float"
):
def
atomic_different_memory_orders_program
(
M
,
N
,
block_M
,
block_N
,
dtype
=
"float"
):
@
T
.
prim_func
@
T
.
prim_func
def
atomic_different_orders
(
A
:
T
.
Tensor
((
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
),
C
:
T
.
Tensor
(
def
atomic_different_orders
(
(
M
,
N
),
dtype
),
D
:
T
.
Tensor
((
M
,
N
),
dtype
)):
A
:
T
.
Tensor
((
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
D
:
T
.
Tensor
((
M
,
N
),
dtype
)
):
with
T
.
Kernel
(
T
.
ceildiv
(
M
,
block_M
),
T
.
ceildiv
(
N
,
block_N
),
threads
=
32
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
M
,
block_M
),
T
.
ceildiv
(
N
,
block_N
),
threads
=
32
)
as
(
bx
,
by
):
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
idx_i
=
bx
*
block_M
+
i
idx_i
=
bx
*
block_M
+
i
...
@@ -286,18 +273,17 @@ def run_atomic_different_memory_orders(M, N, block_M, block_N, dtype="float32"):
...
@@ -286,18 +273,17 @@ def run_atomic_different_memory_orders(M, N, block_M, block_N, dtype="float32"):
A
=
torch
.
randn
(
M
,
N
,
dtype
=
getattr
(
torch
,
dtype
)).
cuda
()
A
=
torch
.
randn
(
M
,
N
,
dtype
=
getattr
(
torch
,
dtype
)).
cuda
()
B
=
torch
.
zeros
(
M
,
N
,
dtype
=
getattr
(
torch
,
dtype
)).
cuda
()
B
=
torch
.
zeros
(
M
,
N
,
dtype
=
getattr
(
torch
,
dtype
)).
cuda
()
C
=
torch
.
zeros
(
M
,
N
,
dtype
=
getattr
(
torch
,
dtype
)).
cuda
()
C
=
torch
.
zeros
(
M
,
N
,
dtype
=
getattr
(
torch
,
dtype
)).
cuda
()
D
=
torch
.
full
((
M
,
N
),
float
(
'
inf
'
),
dtype
=
getattr
(
torch
,
dtype
)).
cuda
()
D
=
torch
.
full
((
M
,
N
),
float
(
"
inf
"
),
dtype
=
getattr
(
torch
,
dtype
)).
cuda
()
kernel
(
A
,
B
,
C
,
D
)
kernel
(
A
,
B
,
C
,
D
)
torch
.
testing
.
assert_close
(
B
,
A
,
atol
=
1e-3
,
rtol
=
1e-3
)
torch
.
testing
.
assert_close
(
B
,
A
,
atol
=
1e-3
,
rtol
=
1e-3
)
torch
.
testing
.
assert_close
(
C
,
torch
.
maximum
(
torch
.
zeros_like
(
A
),
A
))
torch
.
testing
.
assert_close
(
C
,
torch
.
maximum
(
torch
.
zeros_like
(
A
),
A
))
torch
.
testing
.
assert_close
(
D
,
torch
.
minimum
(
torch
.
full_like
(
A
,
float
(
'
inf
'
)),
A
))
torch
.
testing
.
assert_close
(
D
,
torch
.
minimum
(
torch
.
full_like
(
A
,
float
(
"
inf
"
)),
A
))
@
tilelang
.
jit
@
tilelang
.
jit
def
atomic_addx4_program
(
M
,
N
,
block_M
,
block_N
):
def
atomic_addx4_program
(
M
,
N
,
block_M
,
block_N
):
@
T
.
prim_func
@
T
.
prim_func
def
atomic_addx4
(
A
:
T
.
Tensor
((
M
,
N
),
"float32"
),
B
:
T
.
Tensor
((
M
,
N
),
"float32"
)):
def
atomic_addx4
(
A
:
T
.
Tensor
((
M
,
N
),
"float32"
),
B
:
T
.
Tensor
((
M
,
N
),
"float32"
)):
with
T
.
Kernel
(
T
.
ceildiv
(
M
,
block_M
),
T
.
ceildiv
(
N
,
block_N
),
threads
=
32
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
M
,
block_M
),
T
.
ceildiv
(
N
,
block_N
),
threads
=
32
)
as
(
bx
,
by
):
...
@@ -330,17 +316,14 @@ def run_atomic_addx4(M, N, block_M, block_N):
...
@@ -330,17 +316,14 @@ def run_atomic_addx4(M, N, block_M, block_N):
@
tilelang
.
jit
@
tilelang
.
jit
def
atomic_return_prev_program
(
M
,
N
,
block_M
,
block_N
,
dtype
=
"float"
):
def
atomic_return_prev_program
(
M
,
N
,
block_M
,
block_N
,
dtype
=
"float"
):
@
T
.
prim_func
@
T
.
prim_func
def
atomic_with_return_prev
(
A
:
T
.
Tensor
((
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
),
def
atomic_with_return_prev
(
A
:
T
.
Tensor
((
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
),
old_vals
:
T
.
Tensor
((
M
,
N
),
dtype
)):
old_vals
:
T
.
Tensor
((
M
,
N
),
dtype
)):
with
T
.
Kernel
(
T
.
ceildiv
(
M
,
block_M
),
T
.
ceildiv
(
N
,
block_N
),
threads
=
32
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
M
,
block_M
),
T
.
ceildiv
(
N
,
block_N
),
threads
=
32
)
as
(
bx
,
by
):
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
idx_i
=
bx
*
block_M
+
i
idx_i
=
bx
*
block_M
+
i
idx_j
=
by
*
block_N
+
j
idx_j
=
by
*
block_N
+
j
if
idx_i
<
M
and
idx_j
<
N
:
if
idx_i
<
M
and
idx_j
<
N
:
old_vals
[
idx_i
,
idx_j
]
=
T
.
atomic_add
(
old_vals
[
idx_i
,
idx_j
]
=
T
.
atomic_add
(
B
[
idx_i
,
idx_j
],
A
[
idx_i
,
idx_j
],
return_prev
=
True
)
B
[
idx_i
,
idx_j
],
A
[
idx_i
,
idx_j
],
return_prev
=
True
)
return
atomic_with_return_prev
return
atomic_with_return_prev
...
...
testing/python/language/test_tilelang_language_ceildiv.py
View file @
29051439
...
@@ -5,7 +5,6 @@ import torch
...
@@ -5,7 +5,6 @@ import torch
@
tilelang
.
jit
(
out_idx
=
[
-
1
])
@
tilelang
.
jit
(
out_idx
=
[
-
1
])
def
_ceildiv_kernel
(
a
:
int
,
b
:
int
):
def
_ceildiv_kernel
(
a
:
int
,
b
:
int
):
@
T
.
prim_func
@
T
.
prim_func
def
ceildiv_kernel
(
A
:
T
.
Tensor
((
1
,),
"int32"
)):
def
ceildiv_kernel
(
A
:
T
.
Tensor
((
1
,),
"int32"
)):
with
T
.
Kernel
(
1
,
threads
=
1
)
as
_
:
with
T
.
Kernel
(
1
,
threads
=
1
)
as
_
:
...
@@ -30,7 +29,6 @@ def test_ceildiv():
...
@@ -30,7 +29,6 @@ def test_ceildiv():
@
tilelang
.
jit
@
tilelang
.
jit
def
_ceildiv_kernel_dyn
(
b
:
int
):
def
_ceildiv_kernel_dyn
(
b
:
int
):
@
T
.
prim_func
@
T
.
prim_func
def
ceildiv_kernel
(
A
:
T
.
Tensor
((
1
,),
"int32"
),
a
:
T
.
int32
):
def
ceildiv_kernel
(
A
:
T
.
Tensor
((
1
,),
"int32"
),
a
:
T
.
int32
):
with
T
.
Kernel
(
1
,
threads
=
1
)
as
_
:
with
T
.
Kernel
(
1
,
threads
=
1
)
as
_
:
...
...
testing/python/language/test_tilelang_language_chain_equal.py
View file @
29051439
...
@@ -8,14 +8,14 @@ import torch
...
@@ -8,14 +8,14 @@ import torch
pass_configs
=
{
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
},)
},
)
def
chain_equal
(
N
,
block_size
,
dtype
=
"float32"
):
def
chain_equal
(
N
,
block_size
,
dtype
=
"float32"
):
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Tensor
((
N
,),
dtype
),
A
:
T
.
Tensor
((
N
,),
dtype
),
B
:
T
.
Tensor
((
N
,),
dtype
),
B
:
T
.
Tensor
((
N
,),
dtype
),
C
:
T
.
Tensor
((
N
,),
dtype
),
C
:
T
.
Tensor
((
N
,),
dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_size
),
threads
=
block_size
)
as
bx
:
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_size
),
threads
=
block_size
)
as
bx
:
for
lane
in
T
.
Parallel
(
block_size
):
for
lane
in
T
.
Parallel
(
block_size
):
...
...
testing/python/language/test_tilelang_language_clamp.py
View file @
29051439
...
@@ -13,8 +13,8 @@ def clamp_within_bounds(
...
@@ -13,8 +13,8 @@ def clamp_within_bounds(
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Tensor
((
N
,),
dtype
),
A
:
T
.
Tensor
((
N
,),
dtype
),
B
:
T
.
Tensor
((
N
,),
dtype
),
B
:
T
.
Tensor
((
N
,),
dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
threads
=
block_N
)
as
bx
:
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
threads
=
block_N
)
as
bx
:
A_shared
=
T
.
alloc_shared
([
block_N
],
dtype
)
A_shared
=
T
.
alloc_shared
([
block_N
],
dtype
)
...
@@ -56,8 +56,8 @@ def clamp_value_range(
...
@@ -56,8 +56,8 @@ def clamp_value_range(
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Tensor
((
1
,
N
),
dtype
),
A
:
T
.
Tensor
((
1
,
N
),
dtype
),
B
:
T
.
Tensor
((
1
,
N
),
dtype
),
B
:
T
.
Tensor
((
1
,
N
),
dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
threads
=
block_N
)
as
bx
:
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
threads
=
block_N
)
as
bx
:
# A_shared = T.alloc_shared([1, block_N], dtype=dtype)
# A_shared = T.alloc_shared([1, block_N], dtype=dtype)
...
...
testing/python/language/test_tilelang_language_clear.py
View file @
29051439
...
@@ -5,12 +5,11 @@ import tilelang.language as T
...
@@ -5,12 +5,11 @@ import tilelang.language as T
# add decorator @tilelang.jit if you want to return a torch function
# add decorator @tilelang.jit if you want to return a torch function
# @tilelang.jit
# @tilelang.jit
def
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
def
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
B
:
T
.
Tensor
((
N
,
K
),
dtype
),
B
:
T
.
Tensor
((
N
,
K
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
):
):
# Initialize Kernel Context
# Initialize Kernel Context
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
128
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
128
)
as
(
bx
,
by
):
...
@@ -42,10 +41,10 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
...
@@ -42,10 +41,10 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
def
run_matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
def
run_matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
program
=
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
,
accum_dtype
)
program
=
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
,
accum_dtype
)
kernel
=
tilelang
.
compile
(
kernel
=
tilelang
.
compile
(
program
,
out_idx
=
[
2
],
target
=
"cuda"
,
pass_configs
=
{
"tl.disable_tma_lower"
:
True
})
program
,
out_idx
=
[
2
],
target
=
"cuda"
,
pass_configs
=
{
"tl.disable_tma_lower"
:
True
})
import
torch
import
torch
from
tilelang.utils
import
map_torch_type
from
tilelang.utils
import
map_torch_type
a
=
torch
.
randn
((
M
,
K
),
dtype
=
map_torch_type
(
dtype
)).
cuda
()
a
=
torch
.
randn
((
M
,
K
),
dtype
=
map_torch_type
(
dtype
)).
cuda
()
b
=
torch
.
randn
((
N
,
K
),
dtype
=
map_torch_type
(
dtype
)).
cuda
()
b
=
torch
.
randn
((
N
,
K
),
dtype
=
map_torch_type
(
dtype
)).
cuda
()
c
=
kernel
(
a
,
b
)
c
=
kernel
(
a
,
b
)
...
...
testing/python/language/test_tilelang_language_composable_index.py
View file @
29051439
...
@@ -7,11 +7,10 @@ import torch
...
@@ -7,11 +7,10 @@ import torch
# add decorator @tilelang.jit if you want to return a torch function
# add decorator @tilelang.jit if you want to return a torch function
# @tilelang.jit
# @tilelang.jit
def
tilelang_composable_copy
(
M
,
N
,
block_M
,
block_N
,
dtype
=
"float16"
):
def
tilelang_composable_copy
(
M
,
N
,
block_M
,
block_N
,
dtype
=
"float16"
):
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Tensor
((
M
,
N
),
dtype
),
A
:
T
.
Tensor
((
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
*
N
),
dtype
),
B
:
T
.
Tensor
((
M
*
N
),
dtype
),
):
):
# Initialize Kernel Context
# Initialize Kernel Context
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
128
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
128
)
as
(
bx
,
by
):
...
@@ -35,7 +34,8 @@ def run_tilelang_composable_copy(M=1024, N=1024, block_M=128, block_N=128, dtype
...
@@ -35,7 +34,8 @@ def run_tilelang_composable_copy(M=1024, N=1024, block_M=128, block_N=128, dtype
pass_configs
=
{
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
})
},
)
a
=
torch
.
randn
(
M
,
N
,
device
=
"cuda"
,
dtype
=
getattr
(
torch
,
dtype
))
a
=
torch
.
randn
(
M
,
N
,
device
=
"cuda"
,
dtype
=
getattr
(
torch
,
dtype
))
b
=
kernel
(
a
)
b
=
kernel
(
a
)
torch
.
testing
.
assert_close
(
b
.
flatten
(),
a
.
flatten
(),
rtol
=
1e-2
,
atol
=
1e-2
)
torch
.
testing
.
assert_close
(
b
.
flatten
(),
a
.
flatten
(),
rtol
=
1e-2
,
atol
=
1e-2
)
...
...
testing/python/language/test_tilelang_language_copy.py
View file @
29051439
...
@@ -7,11 +7,10 @@ import tilelang.testing
...
@@ -7,11 +7,10 @@ import tilelang.testing
# add decorator @tilelang.jit if you want to return a torch function
# add decorator @tilelang.jit if you want to return a torch function
# @tilelang.jit
# @tilelang.jit
def
tilelang_copy
(
M
,
N
,
block_M
,
block_N
,
dtype
=
"float16"
):
def
tilelang_copy
(
M
,
N
,
block_M
,
block_N
,
dtype
=
"float16"
):
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Tensor
((
M
,
N
),
dtype
),
A
:
T
.
Tensor
((
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
),
):
):
# Initialize Kernel Context
# Initialize Kernel Context
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
128
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
128
)
as
(
bx
,
by
):
...
@@ -27,10 +26,8 @@ def run_tilelang_copy(M=1024, N=1024, block_M=128, block_N=128, dtype="float16")
...
@@ -27,10 +26,8 @@ def run_tilelang_copy(M=1024, N=1024, block_M=128, block_N=128, dtype="float16")
program
,
program
,
out_idx
=
[
1
],
out_idx
=
[
1
],
target
=
"cuda"
,
target
=
"cuda"
,
pass_configs
=
{
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
},
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
)
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
})
a
=
torch
.
randn
(
M
,
N
,
device
=
"cuda"
,
dtype
=
getattr
(
torch
,
dtype
))
a
=
torch
.
randn
(
M
,
N
,
device
=
"cuda"
,
dtype
=
getattr
(
torch
,
dtype
))
b
=
kernel
(
a
)
b
=
kernel
(
a
)
torch
.
testing
.
assert_close
(
b
,
a
,
rtol
=
1e-2
,
atol
=
1e-2
)
torch
.
testing
.
assert_close
(
b
,
a
,
rtol
=
1e-2
,
atol
=
1e-2
)
...
@@ -43,11 +40,10 @@ def test_tilelang_copy():
...
@@ -43,11 +40,10 @@ def test_tilelang_copy():
def
tilelang_copy_with_stride
(
M
,
N
,
NN
,
block_M
,
block_N
,
dtype
=
"float16"
):
def
tilelang_copy_with_stride
(
M
,
N
,
NN
,
block_M
,
block_N
,
dtype
=
"float16"
):
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
StridedTensor
((
M
,
N
),
(
NN
,
1
),
dtype
),
A
:
T
.
StridedTensor
((
M
,
N
),
(
NN
,
1
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
),
):
):
# Initialize Kernel Context
# Initialize Kernel Context
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
128
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
128
)
as
(
bx
,
by
):
...
@@ -57,12 +53,7 @@ def tilelang_copy_with_stride(M, N, NN, block_M, block_N, dtype="float16"):
...
@@ -57,12 +53,7 @@ def tilelang_copy_with_stride(M, N, NN, block_M, block_N, dtype="float16"):
return
main
return
main
def
run_tilelang_copy_with_stride
(
M
=
1024
,
def
run_tilelang_copy_with_stride
(
M
=
1024
,
N
=
1024
,
NN
=
2048
,
block_M
=
128
,
block_N
=
128
,
dtype
=
"float16"
):
N
=
1024
,
NN
=
2048
,
block_M
=
128
,
block_N
=
128
,
dtype
=
"float16"
):
if
isinstance
(
NN
,
int
):
if
isinstance
(
NN
,
int
):
assert
NN
>
N
,
"NN must be greater than N"
assert
NN
>
N
,
"NN must be greater than N"
program
=
tilelang_copy_with_stride
(
M
,
N
,
NN
,
block_M
,
block_N
,
dtype
)
program
=
tilelang_copy_with_stride
(
M
,
N
,
NN
,
block_M
,
block_N
,
dtype
)
...
@@ -73,7 +64,8 @@ def run_tilelang_copy_with_stride(M=1024,
...
@@ -73,7 +64,8 @@ def run_tilelang_copy_with_stride(M=1024,
pass_configs
=
{
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
})
},
)
if
isinstance
(
NN
,
T
.
Var
):
if
isinstance
(
NN
,
T
.
Var
):
NN
=
N
*
2
NN
=
N
*
2
a
=
torch
.
randn
(
M
,
NN
,
device
=
"cuda"
,
dtype
=
getattr
(
torch
,
dtype
))
a
=
torch
.
randn
(
M
,
NN
,
device
=
"cuda"
,
dtype
=
getattr
(
torch
,
dtype
))
...
@@ -87,11 +79,10 @@ def test_tilelang_copy_with_stride():
...
@@ -87,11 +79,10 @@ def test_tilelang_copy_with_stride():
def
tilelang_copy_bufferload
(
num_tokens
,
dtype
=
"float16"
):
def
tilelang_copy_bufferload
(
num_tokens
,
dtype
=
"float16"
):
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
indices
:
T
.
Tensor
((
num_tokens
,),
"int32"
),
indices
:
T
.
Tensor
((
num_tokens
,),
"int32"
),
x
:
T
.
Tensor
((
num_tokens
,),
dtype
),
x
:
T
.
Tensor
((
num_tokens
,),
dtype
),
):
):
with
T
.
Kernel
(
num_tokens
,
threads
=
32
)
as
pid
:
with
T
.
Kernel
(
num_tokens
,
threads
=
32
)
as
pid
:
idx
=
T
.
alloc_local
([
1
],
"int32"
)
idx
=
T
.
alloc_local
([
1
],
"int32"
)
...
@@ -107,10 +98,8 @@ def run_tilelang_copy_bufferload(num_tokens=128, dtype="float16"):
...
@@ -107,10 +98,8 @@ def run_tilelang_copy_bufferload(num_tokens=128, dtype="float16"):
tilelang
.
compile
(
tilelang
.
compile
(
program
,
program
,
out_idx
=
[
1
],
out_idx
=
[
1
],
pass_configs
=
{
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
},
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
)
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
})
def
test_tilelang_copy_bufferload
():
def
test_tilelang_copy_bufferload
():
...
@@ -118,11 +107,10 @@ def test_tilelang_copy_bufferload():
...
@@ -118,11 +107,10 @@ def test_tilelang_copy_bufferload():
def
tilelang_copy_buffer_load_with_parallel
(
M
,
N
,
block_M
,
block_N
,
dtype
=
"float16"
):
def
tilelang_copy_buffer_load_with_parallel
(
M
,
N
,
block_M
,
block_N
,
dtype
=
"float16"
):
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Tensor
((
M
,
N
),
dtype
),
A
:
T
.
Tensor
((
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
),
):
):
# Initialize Kernel Context
# Initialize Kernel Context
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
128
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
128
)
as
(
bx
,
by
):
...
@@ -132,20 +120,14 @@ def tilelang_copy_buffer_load_with_parallel(M, N, block_M, block_N, dtype="float
...
@@ -132,20 +120,14 @@ def tilelang_copy_buffer_load_with_parallel(M, N, block_M, block_N, dtype="float
return
main
return
main
def
run_tilelang_copy_buffer_load_with_parallel
(
M
=
1024
,
def
run_tilelang_copy_buffer_load_with_parallel
(
M
=
1024
,
N
=
1024
,
block_M
=
128
,
block_N
=
128
,
dtype
=
"float16"
):
N
=
1024
,
block_M
=
128
,
block_N
=
128
,
dtype
=
"float16"
):
program
=
tilelang_copy_buffer_load_with_parallel
(
M
,
N
,
block_M
,
block_N
,
dtype
)
program
=
tilelang_copy_buffer_load_with_parallel
(
M
,
N
,
block_M
,
block_N
,
dtype
)
kernel
=
tilelang
.
compile
(
kernel
=
tilelang
.
compile
(
program
,
program
,
out_idx
=
[
1
],
out_idx
=
[
1
],
target
=
"cuda"
,
target
=
"cuda"
,
pass_configs
=
{
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
},
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
)
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
})
a
=
torch
.
randn
(
M
,
N
,
device
=
"cuda"
,
dtype
=
getattr
(
torch
,
dtype
))
a
=
torch
.
randn
(
M
,
N
,
device
=
"cuda"
,
dtype
=
getattr
(
torch
,
dtype
))
b
=
kernel
(
a
)
b
=
kernel
(
a
)
torch
.
testing
.
assert_close
(
b
,
a
,
rtol
=
1e-2
,
atol
=
1e-2
)
torch
.
testing
.
assert_close
(
b
,
a
,
rtol
=
1e-2
,
atol
=
1e-2
)
...
...
testing/python/language/test_tilelang_language_cumsum.py
View file @
29051439
...
@@ -9,8 +9,8 @@ def cumsum_smem_test(M, N, block_M, block_N, dim=0, reverse=False, dtype="float3
...
@@ -9,8 +9,8 @@ def cumsum_smem_test(M, N, block_M, block_N, dim=0, reverse=False, dtype="float3
@
T
.
prim_func
@
T
.
prim_func
def
cumsum
(
def
cumsum
(
A
:
T
.
Tensor
((
M
,
N
),
dtype
),
A
:
T
.
Tensor
((
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
),
):
):
# Initialize Kernel Context
# Initialize Kernel Context
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
256
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
256
)
as
(
bx
,
by
):
...
@@ -28,8 +28,8 @@ def cumsum_fragment_test(M, N, block_M, block_N, dim=0, reverse=False, dtype="fl
...
@@ -28,8 +28,8 @@ def cumsum_fragment_test(M, N, block_M, block_N, dim=0, reverse=False, dtype="fl
@
T
.
prim_func
@
T
.
prim_func
def
cumsum
(
def
cumsum
(
A
:
T
.
Tensor
((
M
,
N
),
dtype
),
A
:
T
.
Tensor
((
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
),
):
):
# Initialize Kernel Context
# Initialize Kernel Context
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
256
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
256
)
as
(
bx
,
by
):
...
@@ -57,13 +57,16 @@ def run_cumsum(M, N, block_M, block_N, dim=0, reverse=False, dtype="float32", sc
...
@@ -57,13 +57,16 @@ def run_cumsum(M, N, block_M, block_N, dim=0, reverse=False, dtype="float32", sc
ref_b
=
torch
.
empty_like
(
A
)
ref_b
=
torch
.
empty_like
(
A
)
for
i
in
range
(
M
//
block_M
):
for
i
in
range
(
M
//
block_M
):
for
j
in
range
(
N
//
block_N
):
for
j
in
range
(
N
//
block_N
):
ref_b
[
i
*
block_M
:
(
i
+
1
)
*
block_M
,
ref_b
[
i
*
block_M
:
(
i
+
1
)
*
block_M
,
j
*
block_N
:
(
j
+
1
)
*
block_N
]
=
A
[
j
*
block_
N
:(
j
+
1
)
*
block_
N
]
=
A
[
i
*
block_
M
:(
i
+
1
)
*
block_
M
,
j
*
i
*
block_
M
:
(
i
+
1
)
*
block_
M
,
j
*
block_
N
:
(
j
+
1
)
*
block_
N
block_N
:(
j
+
1
)
*
block_N
].
cumsum
(
dim
=
dim
)
].
cumsum
(
dim
=
dim
)
if
reverse
:
if
reverse
:
ref_b
[
i
*
block_M
:(
i
+
1
)
*
block_M
,
j
*
block_N
:(
j
+
1
)
*
ref_b
[
i
*
block_M
:
(
i
+
1
)
*
block_M
,
j
*
block_N
:
(
j
+
1
)
*
block_N
]
=
(
block_N
]
=
A
[
i
*
block_M
:(
i
+
1
)
*
block_M
,
j
*
block_N
:(
j
+
1
)
*
A
[
i
*
block_M
:
(
i
+
1
)
*
block_M
,
j
*
block_N
:
(
j
+
1
)
*
block_N
]
block_N
].
flip
(
dims
=
[
dim
]).
cumsum
(
dim
=
dim
).
flip
(
dims
=
[
dim
])
.
flip
(
dims
=
[
dim
])
.
cumsum
(
dim
=
dim
)
.
flip
(
dims
=
[
dim
])
)
return
ref_b
return
ref_b
tilelang_res
=
jit_kernel
(
A
)
tilelang_res
=
jit_kernel
(
A
)
...
@@ -76,8 +79,8 @@ def cumsum_smem_test_1d(N, block_N, reverse=False, dtype="float32"):
...
@@ -76,8 +79,8 @@ def cumsum_smem_test_1d(N, block_N, reverse=False, dtype="float32"):
@
T
.
prim_func
@
T
.
prim_func
def
cumsum
(
def
cumsum
(
A
:
T
.
Tensor
((
N
,),
dtype
),
A
:
T
.
Tensor
((
N
,),
dtype
),
B
:
T
.
Tensor
((
N
,),
dtype
),
B
:
T
.
Tensor
((
N
,),
dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
threads
=
block_N
)
as
bx
:
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
threads
=
block_N
)
as
bx
:
A_shared
=
T
.
alloc_shared
((
block_N
,),
dtype
)
A_shared
=
T
.
alloc_shared
((
block_N
,),
dtype
)
...
@@ -94,8 +97,8 @@ def cumsum_fragment_test_1d(N, block_N, reverse=False, dtype="float32"):
...
@@ -94,8 +97,8 @@ def cumsum_fragment_test_1d(N, block_N, reverse=False, dtype="float32"):
@
T
.
prim_func
@
T
.
prim_func
def
cumsum
(
def
cumsum
(
A
:
T
.
Tensor
((
N
,),
dtype
),
A
:
T
.
Tensor
((
N
,),
dtype
),
B
:
T
.
Tensor
((
N
,),
dtype
),
B
:
T
.
Tensor
((
N
,),
dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
threads
=
block_N
)
as
bx
:
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
threads
=
block_N
)
as
bx
:
A_shared
=
T
.
alloc_shared
((
block_N
,),
dtype
)
A_shared
=
T
.
alloc_shared
((
block_N
,),
dtype
)
...
...
testing/python/language/test_tilelang_language_frontend_v2.py
View file @
29051439
...
@@ -8,7 +8,6 @@ from tvm.tir.expr import IntImm, Var
...
@@ -8,7 +8,6 @@ from tvm.tir.expr import IntImm, Var
def
test_argument
():
def
test_argument
():
@
T
.
prim_func
@
T
.
prim_func
def
test_argument
(
def
test_argument
(
t_1
:
T
.
bool
,
t_1
:
T
.
bool
,
...
@@ -41,6 +40,7 @@ def test_argument():
...
@@ -41,6 +40,7 @@ def test_argument():
def
test_expr
():
def
test_expr
():
from
tilelang.language.v2.dtypes
import
_all_dtypes
from
tilelang.language.v2.dtypes
import
_all_dtypes
errors
=
[]
errors
=
[]
for
name
in
_all_dtypes
:
for
name
in
_all_dtypes
:
dtype
=
getattr
(
T
,
name
)
dtype
=
getattr
(
T
,
name
)
...
@@ -116,33 +116,32 @@ def test_expr():
...
@@ -116,33 +116,32 @@ def test_expr():
def
test_dtype_str_repr
():
def
test_dtype_str_repr
():
@
T
.
prim_func
@
T
.
prim_func
def
test_str_repr
():
def
test_str_repr
():
buf_1
=
T
.
alloc_buffer
((
1
,),
dtype
=
T
.
bool
,
scope
=
'
shared
'
)
# noqa F841
buf_1
=
T
.
alloc_buffer
((
1
,),
dtype
=
T
.
bool
,
scope
=
"
shared
"
)
# noqa F841
buf_2
=
T
.
alloc_buffer
((
1
,),
dtype
=
T
.
short
,
scope
=
'
shared
'
)
# noqa F841
buf_2
=
T
.
alloc_buffer
((
1
,),
dtype
=
T
.
short
,
scope
=
"
shared
"
)
# noqa F841
buf_3
=
T
.
alloc_buffer
((
1
,),
dtype
=
T
.
int
,
scope
=
'
shared
'
)
# noqa F841
buf_3
=
T
.
alloc_buffer
((
1
,),
dtype
=
T
.
int
,
scope
=
"
shared
"
)
# noqa F841
buf_4
=
T
.
alloc_buffer
((
1
,),
dtype
=
T
.
long
,
scope
=
'
shared
'
)
# noqa F841
buf_4
=
T
.
alloc_buffer
((
1
,),
dtype
=
T
.
long
,
scope
=
"
shared
"
)
# noqa F841
buf_5
=
T
.
alloc_buffer
((
1
,),
dtype
=
T
.
half
,
scope
=
'
shared
'
)
# noqa F841
buf_5
=
T
.
alloc_buffer
((
1
,),
dtype
=
T
.
half
,
scope
=
"
shared
"
)
# noqa F841
buf_6
=
T
.
alloc_buffer
((
1
,),
dtype
=
T
.
float
,
scope
=
'
shared
'
)
# noqa F841
buf_6
=
T
.
alloc_buffer
((
1
,),
dtype
=
T
.
float
,
scope
=
"
shared
"
)
# noqa F841
buf_7
=
T
.
alloc_buffer
((
1
,),
dtype
=
T
.
long
,
scope
=
'
shared
'
)
# noqa F841
buf_7
=
T
.
alloc_buffer
((
1
,),
dtype
=
T
.
long
,
scope
=
"
shared
"
)
# noqa F841
buf_8
=
T
.
alloc_buffer
((
1
,),
dtype
=
T
.
int8
,
scope
=
'
shared
'
)
# noqa F841
buf_8
=
T
.
alloc_buffer
((
1
,),
dtype
=
T
.
int8
,
scope
=
"
shared
"
)
# noqa F841
buf_9
=
T
.
alloc_buffer
((
1
,),
dtype
=
T
.
int16
,
scope
=
'
shared
'
)
# noqa F841
buf_9
=
T
.
alloc_buffer
((
1
,),
dtype
=
T
.
int16
,
scope
=
"
shared
"
)
# noqa F841
buf_10
=
T
.
alloc_buffer
((
1
,),
dtype
=
T
.
int32
,
scope
=
'
shared
'
)
# noqa F841
buf_10
=
T
.
alloc_buffer
((
1
,),
dtype
=
T
.
int32
,
scope
=
"
shared
"
)
# noqa F841
buf_11
=
T
.
alloc_buffer
((
1
,),
dtype
=
T
.
int64
,
scope
=
'
shared
'
)
# noqa F841
buf_11
=
T
.
alloc_buffer
((
1
,),
dtype
=
T
.
int64
,
scope
=
"
shared
"
)
# noqa F841
buf_12
=
T
.
alloc_buffer
((
1
,),
dtype
=
T
.
uint8
,
scope
=
'
shared
'
)
# noqa F841
buf_12
=
T
.
alloc_buffer
((
1
,),
dtype
=
T
.
uint8
,
scope
=
"
shared
"
)
# noqa F841
buf_13
=
T
.
alloc_buffer
((
1
,),
dtype
=
T
.
uint16
,
scope
=
'
shared
'
)
# noqa F841
buf_13
=
T
.
alloc_buffer
((
1
,),
dtype
=
T
.
uint16
,
scope
=
"
shared
"
)
# noqa F841
buf_14
=
T
.
alloc_buffer
((
1
,),
dtype
=
T
.
uint32
,
scope
=
'
shared
'
)
# noqa F841
buf_14
=
T
.
alloc_buffer
((
1
,),
dtype
=
T
.
uint32
,
scope
=
"
shared
"
)
# noqa F841
buf_15
=
T
.
alloc_buffer
((
1
,),
dtype
=
T
.
uint64
,
scope
=
'
shared
'
)
# noqa F841
buf_15
=
T
.
alloc_buffer
((
1
,),
dtype
=
T
.
uint64
,
scope
=
"
shared
"
)
# noqa F841
buf_16
=
T
.
alloc_buffer
((
1
,),
dtype
=
T
.
float8_e4m3fn
,
scope
=
'
shared
'
)
# noqa F841
buf_16
=
T
.
alloc_buffer
((
1
,),
dtype
=
T
.
float8_e4m3fn
,
scope
=
"
shared
"
)
# noqa F841
buf_17
=
T
.
alloc_buffer
((
1
,),
dtype
=
T
.
float8_e4m3fnuz
,
scope
=
'
shared
'
)
# noqa F841
buf_17
=
T
.
alloc_buffer
((
1
,),
dtype
=
T
.
float8_e4m3fnuz
,
scope
=
"
shared
"
)
# noqa F841
buf_18
=
T
.
alloc_buffer
((
1
,),
dtype
=
T
.
float8_e5m2
,
scope
=
'
shared
'
)
# noqa F841
buf_18
=
T
.
alloc_buffer
((
1
,),
dtype
=
T
.
float8_e5m2
,
scope
=
"
shared
"
)
# noqa F841
buf_19
=
T
.
alloc_buffer
((
1
,),
dtype
=
T
.
float8_e5m2fnuz
,
scope
=
'
shared
'
)
# noqa F841
buf_19
=
T
.
alloc_buffer
((
1
,),
dtype
=
T
.
float8_e5m2fnuz
,
scope
=
"
shared
"
)
# noqa F841
buf_20
=
T
.
alloc_buffer
((
1
,),
dtype
=
T
.
float8_e8m0fnu
,
scope
=
'
shared
'
)
# noqa F841
buf_20
=
T
.
alloc_buffer
((
1
,),
dtype
=
T
.
float8_e8m0fnu
,
scope
=
"
shared
"
)
# noqa F841
buf_21
=
T
.
alloc_buffer
((
1
,),
dtype
=
T
.
float16
,
scope
=
'
shared
'
)
# noqa F841
buf_21
=
T
.
alloc_buffer
((
1
,),
dtype
=
T
.
float16
,
scope
=
"
shared
"
)
# noqa F841
buf_22
=
T
.
alloc_buffer
((
1
,),
dtype
=
T
.
bfloat16
,
scope
=
'
shared
'
)
# noqa F841
buf_22
=
T
.
alloc_buffer
((
1
,),
dtype
=
T
.
bfloat16
,
scope
=
"
shared
"
)
# noqa F841
buf_23
=
T
.
alloc_buffer
((
1
,),
dtype
=
T
.
float32
,
scope
=
'
shared
'
)
# noqa F841
buf_23
=
T
.
alloc_buffer
((
1
,),
dtype
=
T
.
float32
,
scope
=
"
shared
"
)
# noqa F841
buf_24
=
T
.
alloc_buffer
((
1
,),
dtype
=
T
.
float64
,
scope
=
'
shared
'
)
# noqa F841
buf_24
=
T
.
alloc_buffer
((
1
,),
dtype
=
T
.
float64
,
scope
=
"
shared
"
)
# noqa F841
# not supported now
# not supported now
...
@@ -205,7 +204,6 @@ def test_dtype_str_repr():
...
@@ -205,7 +204,6 @@ def test_dtype_str_repr():
def
test_var_assign
():
def
test_var_assign
():
@
tilelang
.
jit
(
out_idx
=-
1
)
@
tilelang
.
jit
(
out_idx
=-
1
)
@
T
.
prim_func
@
T
.
prim_func
def
test_var_assign
(
A
:
T
.
Tensor
((
2
,),
T
.
int32
)):
def
test_var_assign
(
A
:
T
.
Tensor
((
2
,),
T
.
int32
)):
...
@@ -223,7 +221,6 @@ def test_var_assign():
...
@@ -223,7 +221,6 @@ def test_var_assign():
def
test_marco_return
():
def
test_marco_return
():
@
T
.
macro
@
T
.
macro
def
macro_return_constant
():
def
macro_return_constant
():
return
0
return
0
...
@@ -258,11 +255,10 @@ def test_marco_return():
...
@@ -258,11 +255,10 @@ def test_marco_return():
def
test_prim_func_generator
():
def
test_prim_func_generator
():
@
T
.
prim_func
(
generator
=
True
)
@
T
.
prim_func
(
generator
=
True
)
def
prim_func_gen
(
def
prim_func_gen
(
A
=
T
.
Tensor
((
128
,),
T
.
float32
),
# noqa: B008
A
=
T
.
Tensor
((
128
,),
T
.
float32
),
# noqa: B008
B
=
T
.
Tensor
((
128
,),
T
.
float32
),
# noqa: B008
B
=
T
.
Tensor
((
128
,),
T
.
float32
),
# noqa: B008
):
):
with
T
.
Kernel
(
128
)
as
(
tx
,):
with
T
.
Kernel
(
128
)
as
(
tx
,):
T
.
copy
(
A
[
tx
],
B
[
tx
])
T
.
copy
(
A
[
tx
],
B
[
tx
])
...
@@ -277,7 +273,6 @@ def test_prim_func_generator():
...
@@ -277,7 +273,6 @@ def test_prim_func_generator():
def
test_serial_for_with_step
():
def
test_serial_for_with_step
():
@
tilelang
.
jit
(
out_idx
=-
1
)
@
tilelang
.
jit
(
out_idx
=-
1
)
@
T
.
prim_func
@
T
.
prim_func
def
test_stepped_serial
(
A
:
T
.
Tensor
((
10
,),
T
.
int32
)):
def
test_stepped_serial
(
A
:
T
.
Tensor
((
10
,),
T
.
int32
)):
...
@@ -291,7 +286,7 @@ def test_serial_for_with_step():
...
@@ -291,7 +286,7 @@ def test_serial_for_with_step():
ker
=
test_stepped_serial
()
ker
=
test_stepped_serial
()
res
=
ker
()
res
=
ker
()
ref
=
torch
.
tensor
([
1
,
2
,
1
,
2
,
1
,
2
,
1
,
2
,
1
,
2
],
dtype
=
torch
.
int32
,
device
=
'
cuda
'
)
ref
=
torch
.
tensor
([
1
,
2
,
1
,
2
,
1
,
2
,
1
,
2
,
1
,
2
],
dtype
=
torch
.
int32
,
device
=
"
cuda
"
)
assert
torch
.
all
(
res
==
ref
),
f
"Expected
{
ref
}
, but got
{
res
}
"
assert
torch
.
all
(
res
==
ref
),
f
"Expected
{
ref
}
, but got
{
res
}
"
@
tilelang
.
jit
(
out_idx
=-
1
)
@
tilelang
.
jit
(
out_idx
=-
1
)
...
@@ -304,17 +299,16 @@ def test_serial_for_with_step():
...
@@ -304,17 +299,16 @@ def test_serial_for_with_step():
ker
=
test_serial_step_neg
()
ker
=
test_serial_step_neg
()
res
=
ker
()
res
=
ker
()
ref
=
torch
.
tensor
([
10
,
9
,
8
,
7
,
6
,
5
,
4
,
3
,
2
,
1
],
dtype
=
torch
.
int32
,
device
=
'
cuda
'
)
ref
=
torch
.
tensor
([
10
,
9
,
8
,
7
,
6
,
5
,
4
,
3
,
2
,
1
],
dtype
=
torch
.
int32
,
device
=
"
cuda
"
)
assert
torch
.
all
(
res
==
ref
),
f
"Expected
{
ref
}
, but got
{
res
}
"
assert
torch
.
all
(
res
==
ref
),
f
"Expected
{
ref
}
, but got
{
res
}
"
assert
isinstance
(
T
.
serial
(
1
,
10
,
1
),
IRBuilderFrame
)
assert
isinstance
(
T
.
serial
(
1
,
10
,
1
),
IRBuilderFrame
)
assert
isinstance
(
T
.
serial
(
1
,
10
,
IntImm
(
'
int32
'
,
1
)),
IRBuilderFrame
)
assert
isinstance
(
T
.
serial
(
1
,
10
,
IntImm
(
"
int32
"
,
1
)),
IRBuilderFrame
)
assert
not
isinstance
(
T
.
serial
(
1
,
10
,
Var
(
'
tmp
'
,
'
int32
'
)),
IRBuilderFrame
)
assert
not
isinstance
(
T
.
serial
(
1
,
10
,
Var
(
"
tmp
"
,
"
int32
"
)),
IRBuilderFrame
)
assert
not
isinstance
(
T
.
serial
(
10
,
-
1
,
-
1
),
IRBuilderFrame
)
assert
not
isinstance
(
T
.
serial
(
10
,
-
1
,
-
1
),
IRBuilderFrame
)
def
test_swap_logic
():
def
test_swap_logic
():
@
tilelang
.
jit
@
tilelang
.
jit
@
T
.
prim_func
@
T
.
prim_func
def
swap_var
(
A
:
T
.
Tensor
[(
2
,),
T
.
float32
]):
def
swap_var
(
A
:
T
.
Tensor
[(
2
,),
T
.
float32
]):
...
@@ -344,7 +338,6 @@ def test_swap_logic():
...
@@ -344,7 +338,6 @@ def test_swap_logic():
def
test_while_loop
():
def
test_while_loop
():
@
tilelang
.
jit
(
out_idx
=-
1
)
@
tilelang
.
jit
(
out_idx
=-
1
)
@
T
.
prim_func
@
T
.
prim_func
def
test_while_loop
(
A
:
T
.
Tensor
((
1
,),
T
.
int32
)):
def
test_while_loop
(
A
:
T
.
Tensor
((
1
,),
T
.
int32
)):
...
@@ -374,7 +367,7 @@ def test_var_macro():
...
@@ -374,7 +367,7 @@ def test_var_macro():
x
=
T
.
alloc_var
(
T
.
int32
)
x
=
T
.
alloc_var
(
T
.
int32
)
macro_with_var
(
x
)
macro_with_var
(
x
)
assert
'
x[0] = 1
'
in
prim_call_macro
.
script
()
assert
"
x[0] = 1
"
in
prim_call_macro
.
script
()
finally
:
finally
:
pass
pass
...
@@ -406,7 +399,7 @@ def test_var_macro():
...
@@ -406,7 +399,7 @@ def test_var_macro():
x
=
T
.
alloc_var
(
T
.
int32
)
x
=
T
.
alloc_var
(
T
.
int32
)
macro_with_var
(
x
)
macro_with_var
(
x
)
assert
'
x[0] = 1
'
in
prim_call_macro
.
script
()
assert
"
x[0] = 1
"
in
prim_call_macro
.
script
()
finally
:
finally
:
pass
pass
...
@@ -428,10 +421,8 @@ def test_var_macro():
...
@@ -428,10 +421,8 @@ def test_var_macro():
def
test_frame_inside_macro
():
def
test_frame_inside_macro
():
@
tilelang
.
jit
@
tilelang
.
jit
def
get_sample_kernel
():
def
get_sample_kernel
():
@
T
.
macro
@
T
.
macro
def
transform
(
x
):
def
transform
(
x
):
return
x
+
1
return
x
+
1
...
@@ -442,7 +433,7 @@ def test_frame_inside_macro():
...
@@ -442,7 +433,7 @@ def test_frame_inside_macro():
idx_out
:
T
.
Tensor
[(
32
,),
T
.
int32
],
idx_out
:
T
.
Tensor
[(
32
,),
T
.
int32
],
):
):
with
T
.
Kernel
(
num_blocks
,
threads
=
32
)
as
block_idx
:
# noqa: F841
with
T
.
Kernel
(
num_blocks
,
threads
=
32
)
as
block_idx
:
# noqa: F841
fragment
=
T
.
alloc_fragment
(
32
,
'
int32
'
)
fragment
=
T
.
alloc_fragment
(
32
,
"
int32
"
)
T
.
copy
(
idx_out
,
fragment
)
T
.
copy
(
idx_out
,
fragment
)
for
i
in
T
.
Parallel
(
32
):
for
i
in
T
.
Parallel
(
32
):
...
@@ -467,10 +458,10 @@ def test_buffer_slice_step():
...
@@ -467,10 +458,10 @@ def test_buffer_slice_step():
def
test_boolop
():
def
test_boolop
():
a
=
Var
(
'a'
,
'
int32
'
)
a
=
Var
(
"a"
,
"
int32
"
)
b
=
Var
(
'b'
,
'
int32
'
)
b
=
Var
(
"b"
,
"
int32
"
)
c
=
Var
(
'c'
,
'
int32
'
)
c
=
Var
(
"c"
,
"
int32
"
)
d
=
Var
(
'd'
,
'
int32
'
)
d
=
Var
(
"d"
,
"
int32
"
)
@
T
.
macro
@
T
.
macro
def
cond
():
def
cond
():
...
@@ -479,5 +470,5 @@ def test_boolop():
...
@@ -479,5 +470,5 @@ def test_boolop():
cond
()
cond
()
if
__name__
==
'
__main__
'
:
if
__name__
==
"
__main__
"
:
tilelang
.
testing
.
main
()
tilelang
.
testing
.
main
()
testing/python/language/test_tilelang_language_get_warp_info.py
View file @
29051439
...
@@ -23,7 +23,6 @@ def _resolve_warps_per_group(warps_per_group: Optional[int]) -> int:
...
@@ -23,7 +23,6 @@ def _resolve_warps_per_group(warps_per_group: Optional[int]) -> int:
@
tilelang
.
jit
(
out_idx
=
[
-
1
])
@
tilelang
.
jit
(
out_idx
=
[
-
1
])
def
_get_laneid_kernel
(
num_threads
:
int
=
128
,
warp_size
:
Optional
[
int
]
=
None
):
def
_get_laneid_kernel
(
num_threads
:
int
=
128
,
warp_size
:
Optional
[
int
]
=
None
):
@
T
.
prim_func
@
T
.
prim_func
def
laneid_kernel
(
A
:
T
.
Tensor
((
num_threads
,),
"int32"
)):
def
laneid_kernel
(
A
:
T
.
Tensor
((
num_threads
,),
"int32"
)):
with
T
.
Kernel
(
1
,
threads
=
num_threads
)
as
_
:
with
T
.
Kernel
(
1
,
threads
=
num_threads
)
as
_
:
...
@@ -35,7 +34,6 @@ def _get_laneid_kernel(num_threads: int = 128, warp_size: Optional[int] = None):
...
@@ -35,7 +34,6 @@ def _get_laneid_kernel(num_threads: int = 128, warp_size: Optional[int] = None):
@
tilelang
.
jit
(
out_idx
=
[
-
1
])
@
tilelang
.
jit
(
out_idx
=
[
-
1
])
def
_get_warp_idx_sync_kernel
(
num_threads
:
int
=
128
,
warp_size
:
Optional
[
int
]
=
None
):
def
_get_warp_idx_sync_kernel
(
num_threads
:
int
=
128
,
warp_size
:
Optional
[
int
]
=
None
):
@
T
.
prim_func
@
T
.
prim_func
def
warp_idx_sync_kernel
(
A
:
T
.
Tensor
((
num_threads
,),
"int32"
)):
def
warp_idx_sync_kernel
(
A
:
T
.
Tensor
((
num_threads
,),
"int32"
)):
with
T
.
Kernel
(
1
,
threads
=
num_threads
)
as
_
:
with
T
.
Kernel
(
1
,
threads
=
num_threads
)
as
_
:
...
@@ -47,7 +45,6 @@ def _get_warp_idx_sync_kernel(num_threads: int = 128, warp_size: Optional[int] =
...
@@ -47,7 +45,6 @@ def _get_warp_idx_sync_kernel(num_threads: int = 128, warp_size: Optional[int] =
@
tilelang
.
jit
(
out_idx
=
[
-
1
])
@
tilelang
.
jit
(
out_idx
=
[
-
1
])
def
_get_warp_idx_kernel
(
num_threads
:
int
=
128
,
warp_size
:
Optional
[
int
]
=
None
):
def
_get_warp_idx_kernel
(
num_threads
:
int
=
128
,
warp_size
:
Optional
[
int
]
=
None
):
@
T
.
prim_func
@
T
.
prim_func
def
warp_idx_kernel
(
A
:
T
.
Tensor
((
num_threads
,),
"int32"
)):
def
warp_idx_kernel
(
A
:
T
.
Tensor
((
num_threads
,),
"int32"
)):
with
T
.
Kernel
(
1
,
threads
=
num_threads
)
as
_
:
with
T
.
Kernel
(
1
,
threads
=
num_threads
)
as
_
:
...
@@ -63,7 +60,6 @@ def _get_warp_group_idx_kernel(
...
@@ -63,7 +60,6 @@ def _get_warp_group_idx_kernel(
warp_size
:
Optional
[
int
]
=
None
,
warp_size
:
Optional
[
int
]
=
None
,
warps_per_group
:
Optional
[
int
]
=
None
,
warps_per_group
:
Optional
[
int
]
=
None
,
):
):
@
T
.
prim_func
@
T
.
prim_func
def
warp_group_idx_kernel
(
A
:
T
.
Tensor
((
num_threads
,),
"int32"
)):
def
warp_group_idx_kernel
(
A
:
T
.
Tensor
((
num_threads
,),
"int32"
)):
with
T
.
Kernel
(
1
,
threads
=
num_threads
)
as
_
:
with
T
.
Kernel
(
1
,
threads
=
num_threads
)
as
_
:
...
@@ -75,7 +71,6 @@ def _get_warp_group_idx_kernel(
...
@@ -75,7 +71,6 @@ def _get_warp_group_idx_kernel(
@
tilelang
.
jit
(
out_idx
=
[
-
1
])
@
tilelang
.
jit
(
out_idx
=
[
-
1
])
def
_shuffle_elect_kernel
(
num_threads
:
int
=
128
,
thread_extent
:
int
=
64
):
def
_shuffle_elect_kernel
(
num_threads
:
int
=
128
,
thread_extent
:
int
=
64
):
@
T
.
prim_func
@
T
.
prim_func
def
shuffle_elect_kernel
(
A
:
T
.
Tensor
((
num_threads
,),
"int32"
)):
def
shuffle_elect_kernel
(
A
:
T
.
Tensor
((
num_threads
,),
"int32"
)):
with
T
.
Kernel
(
1
,
threads
=
num_threads
)
as
_
:
with
T
.
Kernel
(
1
,
threads
=
num_threads
)
as
_
:
...
...
testing/python/language/test_tilelang_language_if_range.py
View file @
29051439
...
@@ -4,13 +4,14 @@ import torch
...
@@ -4,13 +4,14 @@ import torch
import
tilelang.testing
import
tilelang.testing
@
tilelang
.
jit
(
out_idx
=
[
1
],)
@
tilelang
.
jit
(
out_idx
=
[
1
],
)
def
tilelang_if_range
(
M
,
N
,
block_M
,
block_N
,
dtype
=
"float16"
):
def
tilelang_if_range
(
M
,
N
,
block_M
,
block_N
,
dtype
=
"float16"
):
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Tensor
((
M
,
N
),
dtype
),
A
:
T
.
Tensor
((
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
),
):
):
# Initialize Kernel Context
# Initialize Kernel Context
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
128
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
128
)
as
(
bx
,
by
):
...
...
testing/python/language/test_tilelang_language_infinity.py
View file @
29051439
...
@@ -5,7 +5,6 @@ import tilelang.language as T
...
@@ -5,7 +5,6 @@ import tilelang.language as T
@
tilelang
.
jit
(
out_idx
=-
1
)
@
tilelang
.
jit
(
out_idx
=-
1
)
def
get_inf_kernel
(
dtype
:
str
):
def
get_inf_kernel
(
dtype
:
str
):
@
T
.
prim_func
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
((
32
,),
dtype
)):
def
main
(
A
:
T
.
Tensor
((
32
,),
dtype
)):
with
T
.
Kernel
(
1
,
threads
=
32
):
with
T
.
Kernel
(
1
,
threads
=
32
):
...
@@ -18,7 +17,7 @@ def _test_infinity(dtype: str):
...
@@ -18,7 +17,7 @@ def _test_infinity(dtype: str):
kernel
=
get_inf_kernel
(
dtype
)
kernel
=
get_inf_kernel
(
dtype
)
output
=
kernel
()
output
=
kernel
()
assert
torch
.
all
(
output
==
torch
.
inf
),
f
'
check failed for
{
dtype
=
}
'
assert
torch
.
all
(
output
==
torch
.
inf
),
f
"
check failed for
{
dtype
=
}
"
@
tilelang
.
testing
.
requires_cuda
@
tilelang
.
testing
.
requires_cuda
...
...
testing/python/language/test_tilelang_language_intrinsics_codegen.py
View file @
29051439
...
@@ -9,8 +9,8 @@ def test_language_ldg_codegen():
...
@@ -9,8 +9,8 @@ def test_language_ldg_codegen():
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
x
:
T
.
Tensor
((
N
,),
"float32"
),
x
:
T
.
Tensor
((
N
,),
"float32"
),
y
:
T
.
Tensor
((
N
,),
"float32"
),
y
:
T
.
Tensor
((
N
,),
"float32"
),
):
):
with
T
.
Kernel
(
N
,
threads
=
32
)
as
pid
:
with
T
.
Kernel
(
N
,
threads
=
32
)
as
pid
:
# Explicitly request read-only cache load for x[pid]
# Explicitly request read-only cache load for x[pid]
...
...
testing/python/language/test_tilelang_language_lazy_jit.py
View file @
29051439
...
@@ -8,7 +8,6 @@ import torch
...
@@ -8,7 +8,6 @@ import torch
def
_gemm_impl
():
def
_gemm_impl
():
@
T
.
macro
@
T
.
macro
def
gemm_impl
(
def
gemm_impl
(
A
:
T
.
Tensor
[[
int
,
int
],
Any
],
A
:
T
.
Tensor
[[
int
,
int
],
Any
],
...
@@ -37,7 +36,6 @@ def _gemm_impl():
...
@@ -37,7 +36,6 @@ def _gemm_impl():
def
test_jit2_gemm_annot
():
def
test_jit2_gemm_annot
():
@
tilelang
.
lazy_jit
@
tilelang
.
lazy_jit
def
gemm
(
def
gemm
(
A
:
T
.
Tensor
[[
int
,
int
],
Any
],
A
:
T
.
Tensor
[[
int
,
int
],
Any
],
...
@@ -54,24 +52,24 @@ def test_jit2_gemm_annot():
...
@@ -54,24 +52,24 @@ def test_jit2_gemm_annot():
return
C
return
C
prod
=
product
([
T
.
float16
,
T
.
float32
],
[
T
.
float32
])
prod
=
product
([
T
.
float16
,
T
.
float32
],
[
T
.
float32
])
gemm
.
par_compile
([{
gemm
.
par_compile
(
'A'
:
T
.
Tensor
((
1024
,
1024
),
dtype
=
in_dtype
),
[
'B'
:
T
.
Tensor
((
1024
,
1024
),
dtype
=
in_dtype
),
{
"A"
:
T
.
Tensor
((
1024
,
1024
),
dtype
=
in_dtype
),
"B"
:
T
.
Tensor
((
1024
,
1024
),
dtype
=
in_dtype
),
"out_dtype"
:
out_dtype
}
'out_dtype'
:
out_dtype
for
in_dtype
,
out_dtype
in
prod
}
for
in_dtype
,
out_dtype
in
prod
])
]
)
for
in_dtype
,
out_dtype
in
prod
:
for
in_dtype
,
out_dtype
in
prod
:
in_dtype
=
in_dtype
.
torch
()
in_dtype
=
in_dtype
.
torch
()
out_dtype
=
out_dtype
.
torch
()
out_dtype
=
out_dtype
.
torch
()
A
=
torch
.
randn
(
1024
,
1024
,
dtype
=
in_dtype
,
device
=
'
cuda
'
)
A
=
torch
.
randn
(
1024
,
1024
,
dtype
=
in_dtype
,
device
=
"
cuda
"
)
B
=
torch
.
randn
(
1024
,
1024
,
dtype
=
in_dtype
,
device
=
'
cuda
'
)
B
=
torch
.
randn
(
1024
,
1024
,
dtype
=
in_dtype
,
device
=
"
cuda
"
)
C_ref
=
out_dtype
(
A
@
B
)
C_ref
=
out_dtype
(
A
@
B
)
C
=
gemm
(
A
,
B
)
C
=
gemm
(
A
,
B
)
torch
.
testing
.
assert_close
(
C
,
C_ref
,
rtol
=
1e-2
,
atol
=
1e-2
)
torch
.
testing
.
assert_close
(
C
,
C_ref
,
rtol
=
1e-2
,
atol
=
1e-2
)
def
test_jit2_gemm_ptr
():
def
test_jit2_gemm_ptr
():
@
tilelang
.
lazy_jit
@
tilelang
.
lazy_jit
def
gemm_ptr
(
def
gemm_ptr
(
A
:
T
.
ptr
,
A
:
T
.
ptr
,
...
@@ -92,23 +90,19 @@ def test_jit2_gemm_ptr():
...
@@ -92,23 +90,19 @@ def test_jit2_gemm_ptr():
_gemm_impl
()(
A
,
B
,
C
,
out_dtype
,
block_M
,
block_N
,
block_K
)
_gemm_impl
()(
A
,
B
,
C
,
out_dtype
,
block_M
,
block_N
,
block_K
)
prod
=
product
([
T
.
float16
,
T
.
float32
],
[
T
.
float32
])
prod
=
product
([
T
.
float16
,
T
.
float32
],
[
T
.
float32
])
gemm_ptr
.
par_compile
([{
gemm_ptr
.
par_compile
(
'A'
:
T
.
ptr
(),
[
'B'
:
T
.
ptr
(),
{
"A"
:
T
.
ptr
(),
"B"
:
T
.
ptr
(),
"C"
:
T
.
ptr
(),
"M"
:
1024
,
"N"
:
1024
,
"K"
:
1024
,
"dtype"
:
in_dtype
,
"out_dtype"
:
out_dtype
}
'C'
:
T
.
ptr
(),
for
in_dtype
,
out_dtype
in
prod
'M'
:
1024
,
]
'N'
:
1024
,
)
'K'
:
1024
,
'dtype'
:
in_dtype
,
'out_dtype'
:
out_dtype
}
for
in_dtype
,
out_dtype
in
prod
])
for
in_dtype
,
out_dtype
in
prod
:
for
in_dtype
,
out_dtype
in
prod
:
in_dtype
=
in_dtype
.
torch
()
in_dtype
=
in_dtype
.
torch
()
out_dtype
=
out_dtype
.
torch
()
out_dtype
=
out_dtype
.
torch
()
A
=
torch
.
randn
(
1024
,
1024
,
dtype
=
in_dtype
,
device
=
'
cuda
'
)
A
=
torch
.
randn
(
1024
,
1024
,
dtype
=
in_dtype
,
device
=
"
cuda
"
)
B
=
torch
.
randn
(
1024
,
1024
,
dtype
=
in_dtype
,
device
=
'
cuda
'
)
B
=
torch
.
randn
(
1024
,
1024
,
dtype
=
in_dtype
,
device
=
"
cuda
"
)
C_ref
=
out_dtype
(
A
@
B
)
C_ref
=
out_dtype
(
A
@
B
)
C
=
torch
.
empty
(
1024
,
1024
,
dtype
=
out_dtype
,
device
=
'
cuda
'
)
C
=
torch
.
empty
(
1024
,
1024
,
dtype
=
out_dtype
,
device
=
"
cuda
"
)
gemm_ptr
(
A
,
B
,
C
,
1024
,
1024
,
1024
,
in_dtype
,
out_dtype
)
gemm_ptr
(
A
,
B
,
C
,
1024
,
1024
,
1024
,
in_dtype
,
out_dtype
)
torch
.
testing
.
assert_close
(
C
,
C_ref
,
atol
=
1e-2
,
rtol
=
1e-2
)
torch
.
testing
.
assert_close
(
C
,
C_ref
,
atol
=
1e-2
,
rtol
=
1e-2
)
...
@@ -129,8 +123,7 @@ def test_jit2_annot():
...
@@ -129,8 +123,7 @@ def test_jit2_annot():
AnnotTest
(
AnnotTest
(
annot
=
T
.
Tensor
[[
int
,
int
],
T
.
float32
],
annot
=
T
.
Tensor
[[
int
,
int
],
T
.
float32
],
promote
=
False
,
promote
=
False
,
match_ok
=
[
torch
.
randn
(
1
,
1
,
dtype
=
torch
.
float32
),
match_ok
=
[
torch
.
randn
(
1
,
1
,
dtype
=
torch
.
float32
),
T
.
Tensor
((
1
,
1
),
dtype
=
T
.
float32
)],
T
.
Tensor
((
1
,
1
),
dtype
=
T
.
float32
)],
match_ng
=
[
match_ng
=
[
torch
.
randn
(
1
,
1
,
dtype
=
torch
.
float16
),
torch
.
randn
(
1
,
1
,
dtype
=
torch
.
float16
),
T
.
Tensor
(
1
,
dtype
=
T
.
float32
),
T
.
Tensor
(
1
,
dtype
=
T
.
float32
),
...
@@ -146,8 +139,8 @@ def test_jit2_annot():
...
@@ -146,8 +139,8 @@ def test_jit2_annot():
T
.
Tensor
((
1
,),
dtype
=
T
.
float32
),
T
.
Tensor
((
1
,),
dtype
=
T
.
float32
),
T
.
Tensor
((
1
,),
dtype
=
T
.
float16
),
T
.
Tensor
((
1
,),
dtype
=
T
.
float16
),
],
],
match_ng
=
[
torch
.
randn
((
1
,
1
),
dtype
=
torch
.
float32
),
match_ng
=
[
torch
.
randn
((
1
,
1
),
dtype
=
torch
.
float32
),
T
.
Tensor
((
1
,
1
),
dtype
=
T
.
float16
)],
T
.
Tensor
((
1
,
1
),
dtype
=
T
.
float16
)]
),
),
AnnotTest
(
AnnotTest
(
annot
=
T
.
Tensor
[[
int
,
1
],
Any
],
annot
=
T
.
Tensor
[[
int
,
1
],
Any
],
promote
=
False
,
promote
=
False
,
...
@@ -157,8 +150,8 @@ def test_jit2_annot():
...
@@ -157,8 +150,8 @@ def test_jit2_annot():
T
.
Tensor
((
12
,
1
),
T
.
float32
),
T
.
Tensor
((
12
,
1
),
T
.
float32
),
T
.
Tensor
((
12
,
1
),
T
.
float16
),
T
.
Tensor
((
12
,
1
),
T
.
float16
),
],
],
match_ng
=
[
torch
.
randn
(
12
,
12
,
dtype
=
torch
.
float32
),
match_ng
=
[
torch
.
randn
(
12
,
12
,
dtype
=
torch
.
float32
),
T
.
Tensor
((
12
,
12
),
T
.
float32
)],
T
.
Tensor
((
12
,
12
),
T
.
float32
)]
),
),
AnnotTest
(
AnnotTest
(
annot
=
T
.
Tensor
[[
T
.
dyn
,
1
],
Any
],
annot
=
T
.
Tensor
[[
T
.
dyn
,
1
],
Any
],
promote
=
False
,
promote
=
False
,
...
@@ -168,43 +161,39 @@ def test_jit2_annot():
...
@@ -168,43 +161,39 @@ def test_jit2_annot():
T
.
Tensor
((
12
,
1
),
T
.
float32
),
T
.
Tensor
((
12
,
1
),
T
.
float32
),
T
.
Tensor
((
12
,
1
),
T
.
float16
),
T
.
Tensor
((
12
,
1
),
T
.
float16
),
],
],
match_ng
=
[
torch
.
randn
(
12
,
12
,
dtype
=
torch
.
float32
),
match_ng
=
[
torch
.
randn
(
12
,
12
,
dtype
=
torch
.
float32
),
T
.
Tensor
((
12
,
12
),
T
.
float32
)],
T
.
Tensor
((
12
,
12
),
T
.
float32
)]
),
),
AnnotTest
(
AnnotTest
(
annot
=
T
.
Tensor
[[
1024
,
1024
],
T
.
float32
],
annot
=
T
.
Tensor
[[
1024
,
1024
],
T
.
float32
],
promote
=
True
,
promote
=
True
,
),
),
AnnotTest
(
annot
=
T
.
dyn
[
int
,
'X'
],
promote
=
False
,
match_ok
=
[
1
,
2
,
3
,
4
]),
AnnotTest
(
annot
=
T
.
dyn
[
int
,
"X"
],
promote
=
False
,
match_ok
=
[
1
,
2
,
3
,
4
]),
AnnotTest
(
annot
=
T
.
dyn
,
promote
=
False
,
match_ok
=
[
1
,
2
,
3
,
4
])
AnnotTest
(
annot
=
T
.
dyn
,
promote
=
False
,
match_ok
=
[
1
,
2
,
3
,
4
])
,
]
]
for
test
in
tests
:
for
test
in
tests
:
promote
=
test
.
annot
.
promote
()
promote
=
test
.
annot
.
promote
()
promoted
=
promote
is
not
None
promoted
=
promote
is
not
None
if
promoted
!=
test
.
promote
:
if
promoted
!=
test
.
promote
:
raise
AssertionError
(
raise
AssertionError
(
f
"Promote mismatch for
{
test
.
annot
}
: expected
{
test
.
promote
}
, got
{
promoted
}
"
)
f
'Promote mismatch for
{
test
.
annot
}
: expected
{
test
.
promote
}
, got
{
promoted
}
'
)
with
Builder
().
prim_func
(
"_test"
):
with
Builder
().
prim_func
(
'_test'
):
for
match_ok
in
test
.
match_ok
:
for
match_ok
in
test
.
match_ok
:
try
:
try
:
vt
=
ArgVarTable
()
vt
=
ArgVarTable
()
test
.
annot
.
create_prim_func_arg
(
'
arg
'
,
match_ok
,
vt
)
test
.
annot
.
create_prim_func_arg
(
"
arg
"
,
match_ok
,
vt
)
except
Exception
as
e
:
except
Exception
as
e
:
traceback
.
print_exc
()
traceback
.
print_exc
()
raise
AssertionError
(
raise
AssertionError
(
f
"Match failed for
{
test
.
annot
}
with value
{
match_ok
}
:
{
e
}
"
)
from
e
f
'Match failed for
{
test
.
annot
}
with value
{
match_ok
}
:
{
e
}
'
)
from
e
for
match_ng
in
test
.
match_ng
:
for
match_ng
in
test
.
match_ng
:
try
:
try
:
vt
=
ArgVarTable
()
vt
=
ArgVarTable
()
test
.
annot
.
create_prim_func_arg
(
'arg'
,
match_ng
,
vt
)
test
.
annot
.
create_prim_func_arg
(
"arg"
,
match_ng
,
vt
)
raise
AssertionError
(
raise
AssertionError
(
f
"Match unexpectedly succeeded for
{
test
.
annot
}
with value
{
match_ng
}
"
)
f
'Match unexpectedly succeeded for
{
test
.
annot
}
with value
{
match_ng
}
'
)
except
Exception
:
except
Exception
:
pass
pass
def
test_jit2_many_annot
():
def
test_jit2_many_annot
():
@
T
.
macro
@
T
.
macro
def
copy_impl
(
A
,
B
):
def
copy_impl
(
A
,
B
):
M
,
N
=
A
.
shape
M
,
N
=
A
.
shape
...
@@ -213,8 +202,7 @@ def test_jit2_many_annot():
...
@@ -213,8 +202,7 @@ def test_jit2_many_annot():
assert
N
==
N_
,
f
"N mismatch
{
N
}
{
N_
}
"
assert
N
==
N_
,
f
"N mismatch
{
N
}
{
N_
}
"
# assert tuple(A.shape) == tuple(B.shape), f"Invalid tensor shape: {A.shape}, {B.shape}"
# assert tuple(A.shape) == tuple(B.shape), f"Invalid tensor shape: {A.shape}, {B.shape}"
with
T
.
Kernel
(
T
.
ceildiv
(
M
,
128
),
T
.
ceildiv
(
N
,
128
),
threads
=
128
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
M
,
128
),
T
.
ceildiv
(
N
,
128
),
threads
=
128
)
as
(
bx
,
by
):
T
.
copy
(
A
[
bx
*
128
:
bx
*
128
+
128
,
by
*
128
:
by
*
128
+
128
],
B
[
bx
*
128
:
bx
*
128
+
128
,
T
.
copy
(
A
[
bx
*
128
:
bx
*
128
+
128
,
by
*
128
:
by
*
128
+
128
],
B
[
bx
*
128
:
bx
*
128
+
128
,
by
*
128
:
by
*
128
+
128
])
by
*
128
:
by
*
128
+
128
])
@
tilelang
.
lazy_jit
@
tilelang
.
lazy_jit
def
copy1
(
def
copy1
(
...
@@ -259,20 +247,19 @@ def test_jit2_many_annot():
...
@@ -259,20 +247,19 @@ def test_jit2_many_annot():
copy_impl
(
A
,
B
)
copy_impl
(
A
,
B
)
for
copy
in
[
copy1
,
copy2
,
copy3
,
copy4
]:
for
copy
in
[
copy1
,
copy2
,
copy3
,
copy4
]:
A
=
torch
.
randn
(
128
,
128
,
device
=
'
cuda
'
)
A
=
torch
.
randn
(
128
,
128
,
device
=
"
cuda
"
)
B
=
torch
.
empty
(
128
,
128
,
device
=
'
cuda
'
)
B
=
torch
.
empty
(
128
,
128
,
device
=
"
cuda
"
)
copy
(
A
,
B
)
copy
(
A
,
B
)
assert
torch
.
equal
(
B
,
A
)
assert
torch
.
equal
(
B
,
A
)
for
copy
in
[
copy5
,
copy6
]:
for
copy
in
[
copy5
,
copy6
]:
A
=
torch
.
randn
(
128
,
2
,
128
,
2
,
device
=
'
cuda
'
)
A
=
torch
.
randn
(
128
,
2
,
128
,
2
,
device
=
"
cuda
"
)
B
=
torch
.
randn
(
128
,
2
,
128
,
2
,
device
=
'
cuda
'
)
B
=
torch
.
randn
(
128
,
2
,
128
,
2
,
device
=
"
cuda
"
)
copy
(
A
[:,
0
,
:,
0
],
B
[:,
0
,
:,
0
])
copy
(
A
[:,
0
,
:,
0
],
B
[:,
0
,
:,
0
])
assert
torch
.
equal
(
A
[:,
0
,
:,
0
],
B
[:,
0
,
:,
0
])
assert
torch
.
equal
(
A
[:,
0
,
:,
0
],
B
[:,
0
,
:,
0
])
def
test_jit2_return
():
def
test_jit2_return
():
@
T
.
macro
@
T
.
macro
def
copy_impl
(
A
):
def
copy_impl
(
A
):
M
,
N
=
A
.
shape
M
,
N
=
A
.
shape
...
@@ -283,8 +270,7 @@ def test_jit2_return():
...
@@ -283,8 +270,7 @@ def test_jit2_return():
assert
N
==
N_
,
f
"N mismatch
{
N
}
{
N_
}
"
assert
N
==
N_
,
f
"N mismatch
{
N
}
{
N_
}
"
# assert tuple(A.shape) == tuple(B.shape), f"Invalid tensor shape: {A.shape}, {B.shape}"
# assert tuple(A.shape) == tuple(B.shape), f"Invalid tensor shape: {A.shape}, {B.shape}"
with
T
.
Kernel
(
T
.
ceildiv
(
M
,
128
),
T
.
ceildiv
(
N
,
128
),
threads
=
128
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
M
,
128
),
T
.
ceildiv
(
N
,
128
),
threads
=
128
)
as
(
bx
,
by
):
T
.
copy
(
A
[
bx
*
128
:
bx
*
128
+
128
,
by
*
128
:
by
*
128
+
128
],
B
[
bx
*
128
:
bx
*
128
+
128
,
T
.
copy
(
A
[
bx
*
128
:
bx
*
128
+
128
,
by
*
128
:
by
*
128
+
128
],
B
[
bx
*
128
:
bx
*
128
+
128
,
by
*
128
:
by
*
128
+
128
])
by
*
128
:
by
*
128
+
128
])
return
B
return
B
@
tilelang
.
lazy_jit
@
tilelang
.
lazy_jit
...
@@ -292,41 +278,52 @@ def test_jit2_return():
...
@@ -292,41 +278,52 @@ def test_jit2_return():
return
copy_impl
(
A
)
return
copy_impl
(
A
)
@
tilelang
.
lazy_jit
@
tilelang
.
lazy_jit
def
copy1
(
A
:
T
.
Tensor
[[
int
,
int
],
T
.
float32
],):
def
copy1
(
A
:
T
.
Tensor
[[
int
,
int
],
T
.
float32
],
):
return
copy_impl
(
A
)
return
copy_impl
(
A
)
@
tilelang
.
lazy_jit
@
tilelang
.
lazy_jit
def
copy2
(
A
:
T
.
Tensor
[[
128
,
128
],
T
.
float32
],):
def
copy2
(
A
:
T
.
Tensor
[[
128
,
128
],
T
.
float32
],
):
return
copy_impl
(
A
)
return
copy_impl
(
A
)
@
tilelang
.
lazy_jit
@
tilelang
.
lazy_jit
def
copy3
(
A
:
T
.
Tensor
[[
int
,
128
],
T
.
float32
],):
def
copy3
(
A
:
T
.
Tensor
[[
int
,
128
],
T
.
float32
],
):
return
copy_impl
(
A
)
return
copy_impl
(
A
)
@
tilelang
.
lazy_jit
@
tilelang
.
lazy_jit
def
copy4
(
A
:
T
.
Tensor
[[
T
.
dyn
,
int
],
T
.
float32
],):
def
copy4
(
A
:
T
.
Tensor
[[
T
.
dyn
,
int
],
T
.
float32
],
):
return
copy_impl
(
A
)
return
copy_impl
(
A
)
@
tilelang
.
lazy_jit
@
tilelang
.
lazy_jit
def
copy5
(
A
:
T
.
StridedTensor
[[
int
,
int
],
[
int
,
int
],
T
.
float32
],):
def
copy5
(
A
:
T
.
StridedTensor
[[
int
,
int
],
[
int
,
int
],
T
.
float32
],
):
return
copy_impl
(
A
)
return
copy_impl
(
A
)
@
tilelang
.
lazy_jit
@
tilelang
.
lazy_jit
def
copy6
(
A
:
T
.
StridedTensor
[[
T
.
dyn
,
int
],
[
int
,
int
],
T
.
float32
],):
def
copy6
(
A
:
T
.
StridedTensor
[[
T
.
dyn
,
int
],
[
int
,
int
],
T
.
float32
],
):
return
copy_impl
(
A
)
return
copy_impl
(
A
)
for
copy
in
[
copy0
,
copy1
,
copy2
,
copy3
,
copy4
]:
for
copy
in
[
copy0
,
copy1
,
copy2
,
copy3
,
copy4
]:
A
=
torch
.
randn
(
128
,
128
,
device
=
'
cuda
'
)
A
=
torch
.
randn
(
128
,
128
,
device
=
"
cuda
"
)
B
=
copy
(
A
)
B
=
copy
(
A
)
assert
torch
.
equal
(
B
,
A
)
assert
torch
.
equal
(
B
,
A
)
for
copy
in
[
copy5
,
copy6
]:
for
copy
in
[
copy5
,
copy6
]:
A
=
torch
.
randn
(
128
,
2
,
128
,
2
,
device
=
'
cuda
'
)
A
=
torch
.
randn
(
128
,
2
,
128
,
2
,
device
=
"
cuda
"
)
B
=
copy
(
A
[:,
0
,
:,
0
])
B
=
copy
(
A
[:,
0
,
:,
0
])
assert
torch
.
equal
(
A
[:,
0
,
:,
0
],
B
)
assert
torch
.
equal
(
A
[:,
0
,
:,
0
],
B
)
def
test_jit2_deepseek_deepgemm
():
def
test_jit2_deepseek_deepgemm
():
@
tilelang
.
lazy_jit
@
tilelang
.
lazy_jit
def
deep_gemm
(
def
deep_gemm
(
A
:
T
.
Tensor
[[
int
,
int
],
T
.
float8_e4m3
],
A
:
T
.
Tensor
[[
int
,
int
],
T
.
float8_e4m3
],
...
@@ -351,13 +348,9 @@ def test_jit2_deepseek_deepgemm():
...
@@ -351,13 +348,9 @@ def test_jit2_deepseek_deepgemm():
N
,
K
=
B
.
shape
N
,
K
=
B
.
shape
C
=
T
.
empty
(
M
,
N
,
dtype
=
out_dtype
)
C
=
T
.
empty
(
M
,
N
,
dtype
=
out_dtype
)
assert
out_dtype
in
[
assert
out_dtype
in
[
T
.
bfloat16
,
T
.
float32
],
f
"Expect out_dtype to be one of [T.float16, T.float32], got
{
out_dtype
}
"
T
.
bfloat16
,
T
.
float32
assert
scales_a
.
shape
==
[
M
,
T
.
ceildiv
(
K
,
group_size
)],
f
"Expect scales_a shape to be f
{
[
M
,
T
.
ceildiv
(
K
,
group_size
)]
}
"
],
f
"Expect out_dtype to be one of [T.float16, T.float32], got
{
out_dtype
}
"
assert
scales_b
.
shape
==
[
N
,
T
.
ceildiv
(
K
,
group_size
)],
f
"Expect scales_b shape to be f
{
[
N
,
T
.
ceildiv
(
K
,
group_size
)]
}
"
assert
scales_a
.
shape
==
[
M
,
T
.
ceildiv
(
K
,
group_size
)
],
f
"Expect scales_a shape to be f
{
[
M
,
T
.
ceildiv
(
K
,
group_size
)]
}
"
assert
scales_b
.
shape
==
[
N
,
T
.
ceildiv
(
K
,
group_size
)
],
f
"Expect scales_b shape to be f
{
[
N
,
T
.
ceildiv
(
K
,
group_size
)]
}
"
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
128
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
128
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
in_dtype
)
A_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
in_dtype
)
...
@@ -421,5 +414,5 @@ def test_jit2_deepseek_deepgemm():
...
@@ -421,5 +414,5 @@ def test_jit2_deepseek_deepgemm():
# M, N, K = 1024, 1024, 8192
# M, N, K = 1024, 1024, 8192
# A = torch.randn((M, K), dtype=torch.float8_e4m3fn, )
# A = torch.randn((M, K), dtype=torch.float8_e4m3fn, )
if
__name__
==
'
__main__
'
:
if
__name__
==
"
__main__
"
:
tilelang
.
testing
.
main
()
tilelang
.
testing
.
main
()
testing/python/language/test_tilelang_language_let.py
View file @
29051439
...
@@ -4,7 +4,6 @@ from tilelang import language as T
...
@@ -4,7 +4,6 @@ from tilelang import language as T
def
test_let_vectorize_load
():
def
test_let_vectorize_load
():
@
T
.
prim_func
@
T
.
prim_func
def
main
(
A_ptr
:
T
.
handle
):
def
main
(
A_ptr
:
T
.
handle
):
A
=
T
.
match_buffer
(
A_ptr
,
(
16
,
16
),
dtype
=
"float32"
,
align
=
16
)
A
=
T
.
match_buffer
(
A_ptr
,
(
16
,
16
),
dtype
=
"float32"
,
align
=
16
)
...
...
testing/python/language/test_tilelang_language_mask_op.py
View file @
29051439
...
@@ -6,11 +6,10 @@ import torch
...
@@ -6,11 +6,10 @@ import torch
# add decorator @tilelang.jit if you want to return a torch function
# add decorator @tilelang.jit if you want to return a torch function
# @tilelang.jit
# @tilelang.jit
def
tilelang_copy_mask_parallel
(
M
,
N
,
block_M
,
block_N
,
dtype
=
"float16"
):
def
tilelang_copy_mask_parallel
(
M
,
N
,
block_M
,
block_N
,
dtype
=
"float16"
):
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Tensor
((
M
,
N
),
dtype
),
A
:
T
.
Tensor
((
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
),
):
):
# Initialize Kernel Context
# Initialize Kernel Context
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
256
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
256
)
as
(
bx
,
by
):
...
@@ -30,13 +29,8 @@ def tilelang_copy_mask_parallel(M, N, block_M, block_N, dtype="float16"):
...
@@ -30,13 +29,8 @@ def tilelang_copy_mask_parallel(M, N, block_M, block_N, dtype="float16"):
def
run_tilelang_copy_mask_parallel
(
M
=
1024
,
N
=
1024
,
block_M
=
128
,
block_N
=
128
,
dtype
=
"float16"
):
def
run_tilelang_copy_mask_parallel
(
M
=
1024
,
N
=
1024
,
block_M
=
128
,
block_N
=
128
,
dtype
=
"float16"
):
program
=
tilelang_copy_mask_parallel
(
M
,
N
,
block_M
,
block_N
,
dtype
)
program
=
tilelang_copy_mask_parallel
(
M
,
N
,
block_M
,
block_N
,
dtype
)
kernel
=
tilelang
.
compile
(
kernel
=
tilelang
.
compile
(
program
,
program
,
out_idx
=
[
1
],
target
=
"cuda"
,
pass_configs
=
{
"tl.disable_warp_specialized"
:
True
,
"tl.disable_tma_lower"
:
True
}
out_idx
=
[
1
],
)
target
=
"cuda"
,
pass_configs
=
{
"tl.disable_warp_specialized"
:
True
,
"tl.disable_tma_lower"
:
True
})
a
=
torch
.
randn
(
M
,
N
,
device
=
"cuda"
,
dtype
=
getattr
(
torch
,
dtype
))
a
=
torch
.
randn
(
M
,
N
,
device
=
"cuda"
,
dtype
=
getattr
(
torch
,
dtype
))
b
=
kernel
(
a
)
b
=
kernel
(
a
)
torch
.
testing
.
assert_close
(
b
,
a
,
rtol
=
1e-2
,
atol
=
1e-2
)
torch
.
testing
.
assert_close
(
b
,
a
,
rtol
=
1e-2
,
atol
=
1e-2
)
...
@@ -49,11 +43,10 @@ def test_tilelang_copy_mask_parallel():
...
@@ -49,11 +43,10 @@ def test_tilelang_copy_mask_parallel():
# add decorator @tilelang.jit if you want to return a torch function
# add decorator @tilelang.jit if you want to return a torch function
# @tilelang.jit
# @tilelang.jit
def
tilelang_copy_mask_copy
(
M
,
N
,
block_M
,
block_N
,
dtype
=
"float16"
):
def
tilelang_copy_mask_copy
(
M
,
N
,
block_M
,
block_N
,
dtype
=
"float16"
):
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Tensor
((
M
,
N
),
dtype
),
A
:
T
.
Tensor
((
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
),
):
):
# Initialize Kernel Context
# Initialize Kernel Context
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
256
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
256
)
as
(
bx
,
by
):
...
@@ -72,13 +65,8 @@ def tilelang_copy_mask_copy(M, N, block_M, block_N, dtype="float16"):
...
@@ -72,13 +65,8 @@ def tilelang_copy_mask_copy(M, N, block_M, block_N, dtype="float16"):
def
run_tilelang_copy_mask_copy
(
M
=
1024
,
N
=
1024
,
block_M
=
128
,
block_N
=
128
,
dtype
=
"float16"
):
def
run_tilelang_copy_mask_copy
(
M
=
1024
,
N
=
1024
,
block_M
=
128
,
block_N
=
128
,
dtype
=
"float16"
):
program
=
tilelang_copy_mask_copy
(
M
,
N
,
block_M
,
block_N
,
dtype
)
program
=
tilelang_copy_mask_copy
(
M
,
N
,
block_M
,
block_N
,
dtype
)
kernel
=
tilelang
.
compile
(
kernel
=
tilelang
.
compile
(
program
,
program
,
out_idx
=
[
1
],
target
=
"cuda"
,
pass_configs
=
{
"tl.disable_warp_specialized"
:
True
,
"tl.disable_tma_lower"
:
True
}
out_idx
=
[
1
],
)
target
=
"cuda"
,
pass_configs
=
{
"tl.disable_warp_specialized"
:
True
,
"tl.disable_tma_lower"
:
True
})
a
=
torch
.
randn
(
M
,
N
,
device
=
"cuda"
,
dtype
=
getattr
(
torch
,
dtype
))
a
=
torch
.
randn
(
M
,
N
,
device
=
"cuda"
,
dtype
=
getattr
(
torch
,
dtype
))
b
=
kernel
(
a
)
b
=
kernel
(
a
)
torch
.
testing
.
assert_close
(
b
,
a
,
rtol
=
1e-2
,
atol
=
1e-2
)
torch
.
testing
.
assert_close
(
b
,
a
,
rtol
=
1e-2
,
atol
=
1e-2
)
...
@@ -91,11 +79,10 @@ def test_tilelang_copy_mask_copy():
...
@@ -91,11 +79,10 @@ def test_tilelang_copy_mask_copy():
# add decorator @tilelang.jit if you want to return a torch function
# add decorator @tilelang.jit if you want to return a torch function
# @tilelang.jit
# @tilelang.jit
def
tilelang_copy_mask_parallel_range
(
M
,
N
,
block_M
,
block_N
,
dtype
=
"float16"
):
def
tilelang_copy_mask_parallel_range
(
M
,
N
,
block_M
,
block_N
,
dtype
=
"float16"
):
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Tensor
((
M
,
N
),
dtype
),
A
:
T
.
Tensor
((
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
),
):
):
# Initialize Kernel Context
# Initialize Kernel Context
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
256
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
256
)
as
(
bx
,
by
):
...
@@ -112,20 +99,11 @@ def tilelang_copy_mask_parallel_range(M, N, block_M, block_N, dtype="float16"):
...
@@ -112,20 +99,11 @@ def tilelang_copy_mask_parallel_range(M, N, block_M, block_N, dtype="float16"):
return
main
return
main
def
run_tilelang_copy_mask_parallel_range
(
M
=
1024
,
def
run_tilelang_copy_mask_parallel_range
(
M
=
1024
,
N
=
1024
,
block_M
=
128
,
block_N
=
128
,
dtype
=
"float16"
):
N
=
1024
,
block_M
=
128
,
block_N
=
128
,
dtype
=
"float16"
):
program
=
tilelang_copy_mask_parallel_range
(
M
,
N
,
block_M
,
block_N
,
dtype
)
program
=
tilelang_copy_mask_parallel_range
(
M
,
N
,
block_M
,
block_N
,
dtype
)
kernel
=
tilelang
.
compile
(
kernel
=
tilelang
.
compile
(
program
,
program
,
out_idx
=
[
1
],
target
=
"cuda"
,
pass_configs
=
{
"tl.disable_warp_specialized"
:
True
,
"tl.disable_tma_lower"
:
True
}
out_idx
=
[
1
],
)
target
=
"cuda"
,
pass_configs
=
{
"tl.disable_warp_specialized"
:
True
,
"tl.disable_tma_lower"
:
True
})
a
=
torch
.
randn
(
M
,
N
,
device
=
"cuda"
,
dtype
=
getattr
(
torch
,
dtype
))
a
=
torch
.
randn
(
M
,
N
,
device
=
"cuda"
,
dtype
=
getattr
(
torch
,
dtype
))
b
=
kernel
(
a
)
b
=
kernel
(
a
)
torch
.
testing
.
assert_close
(
b
,
a
,
rtol
=
1e-2
,
atol
=
1e-2
)
torch
.
testing
.
assert_close
(
b
,
a
,
rtol
=
1e-2
,
atol
=
1e-2
)
...
@@ -138,11 +116,10 @@ def test_tilelang_copy_mask_parallel_range():
...
@@ -138,11 +116,10 @@ def test_tilelang_copy_mask_parallel_range():
# add decorator @tilelang.jit if you want to return a torch function
# add decorator @tilelang.jit if you want to return a torch function
# @tilelang.jit
# @tilelang.jit
def
tilelang_copy_mask_copy_range
(
M
,
N
,
block_M
,
block_N
,
dtype
=
"float16"
):
def
tilelang_copy_mask_copy_range
(
M
,
N
,
block_M
,
block_N
,
dtype
=
"float16"
):
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Tensor
((
M
,
N
),
dtype
),
A
:
T
.
Tensor
((
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
),
):
):
# Initialize Kernel Context
# Initialize Kernel Context
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
256
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
256
)
as
(
bx
,
by
):
...
@@ -161,13 +138,8 @@ def tilelang_copy_mask_copy_range(M, N, block_M, block_N, dtype="float16"):
...
@@ -161,13 +138,8 @@ def tilelang_copy_mask_copy_range(M, N, block_M, block_N, dtype="float16"):
def
run_tilelang_copy_mask_copy_range
(
M
=
1024
,
N
=
1024
,
block_M
=
128
,
block_N
=
128
,
dtype
=
"float16"
):
def
run_tilelang_copy_mask_copy_range
(
M
=
1024
,
N
=
1024
,
block_M
=
128
,
block_N
=
128
,
dtype
=
"float16"
):
program
=
tilelang_copy_mask_copy_range
(
M
,
N
,
block_M
,
block_N
,
dtype
)
program
=
tilelang_copy_mask_copy_range
(
M
,
N
,
block_M
,
block_N
,
dtype
)
kernel
=
tilelang
.
compile
(
kernel
=
tilelang
.
compile
(
program
,
program
,
out_idx
=
[
1
],
target
=
"cuda"
,
pass_configs
=
{
"tl.disable_warp_specialized"
:
True
,
"tl.disable_tma_lower"
:
True
}
out_idx
=
[
1
],
)
target
=
"cuda"
,
pass_configs
=
{
"tl.disable_warp_specialized"
:
True
,
"tl.disable_tma_lower"
:
True
})
a
=
torch
.
randn
(
M
,
N
,
device
=
"cuda"
,
dtype
=
getattr
(
torch
,
dtype
))
a
=
torch
.
randn
(
M
,
N
,
device
=
"cuda"
,
dtype
=
getattr
(
torch
,
dtype
))
b
=
kernel
(
a
)
b
=
kernel
(
a
)
torch
.
testing
.
assert_close
(
b
,
a
,
rtol
=
1e-2
,
atol
=
1e-2
)
torch
.
testing
.
assert_close
(
b
,
a
,
rtol
=
1e-2
,
atol
=
1e-2
)
...
...
Prev
1
…
10
11
12
13
14
15
16
17
18
…
22
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment