Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
tilelang
Commits
29051439
Unverified
Commit
29051439
authored
Dec 12, 2025
by
Lei Wang
Committed by
GitHub
Dec 12, 2025
Browse files
[Lint] Phaseout Yapf format and embrace ruff format (#1417)
parent
e84b24bc
Changes
426
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
247 additions
and
340 deletions
+247
-340
testing/python/language/test_tilelang_language_annot.py
testing/python/language/test_tilelang_language_annot.py
+19
-16
testing/python/language/test_tilelang_language_annotate_safe_value.py
...on/language/test_tilelang_language_annotate_safe_value.py
+4
-10
testing/python/language/test_tilelang_language_any_of.py
testing/python/language/test_tilelang_language_any_of.py
+20
-22
testing/python/language/test_tilelang_language_assume.py
testing/python/language/test_tilelang_language_assume.py
+11
-14
testing/python/language/test_tilelang_language_atomic_add.py
testing/python/language/test_tilelang_language_atomic_add.py
+14
-31
testing/python/language/test_tilelang_language_ceildiv.py
testing/python/language/test_tilelang_language_ceildiv.py
+0
-2
testing/python/language/test_tilelang_language_chain_equal.py
...ing/python/language/test_tilelang_language_chain_equal.py
+5
-5
testing/python/language/test_tilelang_language_clamp.py
testing/python/language/test_tilelang_language_clamp.py
+4
-4
testing/python/language/test_tilelang_language_clear.py
testing/python/language/test_tilelang_language_clear.py
+5
-6
testing/python/language/test_tilelang_language_composable_index.py
...ython/language/test_tilelang_language_composable_index.py
+4
-4
testing/python/language/test_tilelang_language_copy.py
testing/python/language/test_tilelang_language_copy.py
+18
-36
testing/python/language/test_tilelang_language_cumsum.py
testing/python/language/test_tilelang_language_cumsum.py
+17
-14
testing/python/language/test_tilelang_language_frontend_v2.py
...ing/python/language/test_tilelang_language_frontend_v2.py
+39
-48
testing/python/language/test_tilelang_language_get_warp_info.py
...g/python/language/test_tilelang_language_get_warp_info.py
+0
-5
testing/python/language/test_tilelang_language_if_range.py
testing/python/language/test_tilelang_language_if_range.py
+5
-4
testing/python/language/test_tilelang_language_infinity.py
testing/python/language/test_tilelang_language_infinity.py
+1
-2
testing/python/language/test_tilelang_language_intrinsics_codegen.py
...hon/language/test_tilelang_language_intrinsics_codegen.py
+2
-2
testing/python/language/test_tilelang_language_lazy_jit.py
testing/python/language/test_tilelang_language_lazy_jit.py
+62
-69
testing/python/language/test_tilelang_language_let.py
testing/python/language/test_tilelang_language_let.py
+0
-1
testing/python/language/test_tilelang_language_mask_op.py
testing/python/language/test_tilelang_language_mask_op.py
+17
-45
No files found.
Too many changes to show.
To preserve performance only
426 of 426+
files are displayed.
Plain diff
Email patch
testing/python/language/test_tilelang_language_annot.py
View file @
29051439
...
...
@@ -5,13 +5,14 @@ import torch
def
test_tensor_annot_mul
():
@
tilelang
.
jit
def
example_tensor_annot
():
n
=
T
.
symbolic
(
'n'
)
n
=
T
.
symbolic
(
"n"
)
@
T
.
prim_func
def
kernel
(
A
:
T
.
Tensor
((
n
*
4
,),
T
.
int32
),):
def
kernel
(
A
:
T
.
Tensor
((
n
*
4
,),
T
.
int32
),
):
with
T
.
Kernel
(
1
)
as
_
:
for
i
in
range
(
n
*
4
):
A
[
i
]
=
0
...
...
@@ -19,20 +20,21 @@ def test_tensor_annot_mul():
return
kernel
ker
=
example_tensor_annot
()
A
=
torch
.
arange
(
16
,
dtype
=
torch
.
int32
,
device
=
'
cuda
'
)
A
=
torch
.
arange
(
16
,
dtype
=
torch
.
int32
,
device
=
"
cuda
"
)
ker
(
A
)
expected
=
torch
.
zeros
(
16
,
dtype
=
torch
.
int32
,
device
=
'
cuda
'
)
expected
=
torch
.
zeros
(
16
,
dtype
=
torch
.
int32
,
device
=
"
cuda
"
)
assert
torch
.
equal
(
A
,
expected
)
def
test_tensor_annot_add
():
@
tilelang
.
jit
def
example_tensor_annot
():
n
=
T
.
symbolic
(
'n'
)
n
=
T
.
symbolic
(
"n"
)
@
T
.
prim_func
def
kernel
(
A
:
T
.
Tensor
((
n
+
1
,),
T
.
int32
),):
def
kernel
(
A
:
T
.
Tensor
((
n
+
1
,),
T
.
int32
),
):
with
T
.
Kernel
(
1
)
as
_
:
for
i
in
range
(
n
+
1
):
A
[
i
]
=
0
...
...
@@ -40,20 +42,21 @@ def test_tensor_annot_add():
return
kernel
ker
=
example_tensor_annot
()
A
=
torch
.
arange
(
16
,
dtype
=
torch
.
int32
,
device
=
'
cuda
'
)
A
=
torch
.
arange
(
16
,
dtype
=
torch
.
int32
,
device
=
"
cuda
"
)
ker
(
A
)
expected
=
torch
.
zeros
(
16
,
dtype
=
torch
.
int32
,
device
=
'
cuda
'
)
expected
=
torch
.
zeros
(
16
,
dtype
=
torch
.
int32
,
device
=
"
cuda
"
)
assert
torch
.
equal
(
A
,
expected
)
def
test_tensor_annot_mul_add
():
@
tilelang
.
jit
def
example_tensor_annot
():
n
=
T
.
symbolic
(
'n'
)
n
=
T
.
symbolic
(
"n"
)
@
T
.
prim_func
def
kernel
(
A
:
T
.
Tensor
((
n
*
3
+
1
,),
T
.
int32
),):
def
kernel
(
A
:
T
.
Tensor
((
n
*
3
+
1
,),
T
.
int32
),
):
with
T
.
Kernel
(
1
)
as
_
:
for
i
in
range
(
n
*
3
+
1
):
A
[
i
]
=
0
...
...
@@ -61,11 +64,11 @@ def test_tensor_annot_mul_add():
return
kernel
ker
=
example_tensor_annot
()
A
=
torch
.
arange
(
16
,
dtype
=
torch
.
int32
,
device
=
'
cuda
'
)
A
=
torch
.
arange
(
16
,
dtype
=
torch
.
int32
,
device
=
"
cuda
"
)
ker
(
A
)
expected
=
torch
.
zeros
(
16
,
dtype
=
torch
.
int32
,
device
=
'
cuda
'
)
expected
=
torch
.
zeros
(
16
,
dtype
=
torch
.
int32
,
device
=
"
cuda
"
)
assert
torch
.
equal
(
A
,
expected
)
if
__name__
==
'
__main__
'
:
if
__name__
==
"
__main__
"
:
tilelang
.
testing
.
main
()
testing/python/language/test_tilelang_language_annotate_safe_value.py
View file @
29051439
...
...
@@ -7,11 +7,10 @@ import torch
# add decorator @tilelang.jit if you want to return a torch function
# @tilelang.jit
def
tilelang_copy
(
M
,
N
,
block_M
,
block_N
,
dtype
=
"float16"
,
pad_value
=
0
):
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
((
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
),
A
:
T
.
Tensor
((
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
),
):
# Initialize Kernel Context
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
128
)
as
(
bx
,
by
):
...
...
@@ -30,13 +29,8 @@ def tilelang_copy(M, N, block_M, block_N, dtype="float16", pad_value=0):
def
run_tilelang_copy
(
M
=
1024
,
N
=
1024
,
block_M
=
128
,
block_N
=
128
,
dtype
=
"float16"
,
pad_value
=
0
):
program
=
tilelang_copy
(
M
,
N
,
block_M
,
block_N
,
dtype
,
pad_value
=
pad_value
)
kernel
=
tilelang
.
compile
(
program
,
out_idx
=
[
1
],
target
=
"cuda"
,
pass_configs
=
{
"tl.disable_warp_specialized"
:
True
,
"tl.disable_tma_lower"
:
True
})
program
,
out_idx
=
[
1
],
target
=
"cuda"
,
pass_configs
=
{
"tl.disable_warp_specialized"
:
True
,
"tl.disable_tma_lower"
:
True
}
)
a
=
torch
.
randn
(
M
,
N
,
device
=
"cuda"
,
dtype
=
getattr
(
torch
,
dtype
))
b
=
kernel
(
a
)
ref_b
=
torch
.
zeros_like
(
a
)
...
...
testing/python/language/test_tilelang_language_any_of.py
View file @
29051439
...
...
@@ -13,11 +13,10 @@ def ref_program(A, B, BlockMask, block_M, block_N, block_K):
accu
=
torch
.
zeros
((
block_M
,
block_N
),
dtype
=
torch
.
float32
,
device
=
A
.
device
)
for
k
in
range
(
K
//
block_K
):
if
torch
.
any
(
BlockMask
[
i
,
j
,
k
]):
accu
+=
A
[
i
*
block_M
:(
i
+
1
)
*
block_M
,
k
*
block_K
:(
k
+
1
)
*
block_K
].
to
(
torch
.
float32
)
@
B
[
k
*
block_K
:(
k
+
1
)
*
block_K
,
j
*
block_N
:(
j
+
1
)
*
block_N
].
to
(
torch
.
float32
)
ref_c
[
i
*
block_M
:(
i
+
1
)
*
block_M
,
j
*
block_N
:(
j
+
1
)
*
block_N
]
=
(
accu
.
to
(
torch
.
float16
))
accu
+=
A
[
i
*
block_M
:
(
i
+
1
)
*
block_M
,
k
*
block_K
:
(
k
+
1
)
*
block_K
].
to
(
torch
.
float32
)
@
B
[
k
*
block_K
:
(
k
+
1
)
*
block_K
,
j
*
block_N
:
(
j
+
1
)
*
block_N
].
to
(
torch
.
float32
)
ref_c
[
i
*
block_M
:
(
i
+
1
)
*
block_M
,
j
*
block_N
:
(
j
+
1
)
*
block_N
]
=
accu
.
to
(
torch
.
float16
)
return
ref_c
...
...
@@ -35,15 +34,14 @@ def blocksparse_matmul_global(
dtype
=
"float16"
,
accum_dtype
=
"float"
,
):
block_mask_shape
=
(
M
//
block_M
,
N
//
block_N
,
K
//
block_K
,
condition_dim
)
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
BlockMask
:
T
.
Tensor
(
block_mask_shape
,
"bool"
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
BlockMask
:
T
.
Tensor
(
block_mask_shape
,
"bool"
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
thread_num
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
dtype
)
...
...
@@ -80,15 +78,14 @@ def blocksparse_matmul_shared(
dtype
=
"float16"
,
accum_dtype
=
"float"
,
):
block_mask_shape
=
(
M
//
block_M
,
N
//
block_N
,
K
//
block_K
,
condition_dim
)
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
BlockMask
:
T
.
Tensor
(
block_mask_shape
,
"bool"
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
BlockMask
:
T
.
Tensor
(
block_mask_shape
,
"bool"
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
thread_num
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
dtype
)
...
...
@@ -130,15 +127,14 @@ def blocksparse_matmul_local(
dtype
=
"float16"
,
accum_dtype
=
"float"
,
):
block_mask_shape
=
(
M
//
block_M
,
N
//
block_N
,
K
//
block_K
,
condition_dim
)
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
BlockMask
:
T
.
Tensor
(
block_mask_shape
,
"bool"
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
BlockMask
:
T
.
Tensor
(
block_mask_shape
,
"bool"
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
thread_num
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
dtype
)
...
...
@@ -237,7 +233,8 @@ def run_block_sparse_matmul_shared(M=1024, N=1024, K=1024, sparsity=0.5, conditi
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
})
},
)
# Create block mask with desired sparsity
mask_shape
=
(
M
//
block_M
,
N
//
block_N
,
K
//
block_K
)
block_mask
=
torch
.
rand
(
mask_shape
).
cuda
()
>
sparsity
...
...
@@ -284,7 +281,8 @@ def run_block_sparse_matmul_local(M=1024, N=1024, K=1024, sparsity=0.5, conditio
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
})
},
)
# Create block mask with desired sparsity
mask_shape
=
(
M
//
block_M
,
N
//
block_N
,
K
//
block_K
)
block_mask
=
torch
.
rand
(
mask_shape
).
cuda
()
>
sparsity
...
...
testing/python/language/test_tilelang_language_assume.py
View file @
29051439
...
...
@@ -4,10 +4,9 @@ import tilelang.testing
def
test_assume_remove_boundary_check
():
@
tilelang
.
jit
def
kernel_with_assume
():
N
=
T
.
dynamic
(
'N'
)
N
=
T
.
dynamic
(
"N"
)
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
((
N
,),
"float32"
),
l
:
T
.
int32
,
r
:
T
.
int32
):
...
...
@@ -21,20 +20,19 @@ def test_assume_remove_boundary_check():
jit_kernel
=
kernel_with_assume
()
source
=
jit_kernel
.
get_kernel_source
()
assert
(
"if ("
not
in
source
)
assert
"if ("
not
in
source
def
test_assume_enable_vectorization
():
@
tilelang
.
jit
def
kernel_vectorize
(
M
):
N
=
T
.
dynamic
(
'N'
)
N
=
T
.
dynamic
(
"N"
)
vectorize_size
=
4
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
((
M
,
N
),
"float32"
),
B
:
T
.
Tensor
((
M
,
N
),
"float32"
),
A
:
T
.
Tensor
((
M
,
N
),
"float32"
),
B
:
T
.
Tensor
((
M
,
N
),
"float32"
),
):
with
T
.
Kernel
(
1
,
threads
=
32
)
as
_
:
tid
=
T
.
get_thread_binding
()
...
...
@@ -55,16 +53,15 @@ def test_assume_enable_vectorization():
def
test_assume_complex_indexing
():
@
tilelang
.
jit
def
kernel_complex
():
M
=
T
.
dynamic
(
'M'
)
N
=
T
.
dynamic
(
'N'
)
M
=
T
.
dynamic
(
"M"
)
N
=
T
.
dynamic
(
"N"
)
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
((
M
,
N
),
"float32"
),
B
:
T
.
Tensor
((
M
,
N
),
"float32"
),
A
:
T
.
Tensor
((
M
,
N
),
"float32"
),
B
:
T
.
Tensor
((
M
,
N
),
"float32"
),
):
with
T
.
Kernel
(
1
,
threads
=
32
)
as
_
:
tid
=
T
.
get_thread_binding
()
...
...
@@ -82,8 +79,8 @@ def test_assume_complex_indexing():
jit_kernel
=
kernel_complex
()
source
=
jit_kernel
.
get_kernel_source
()
assert
(
"if ("
not
in
source
)
assert
"if ("
not
in
source
if
__name__
==
'
__main__
'
:
if
__name__
==
"
__main__
"
:
tilelang
.
testing
.
main
()
testing/python/language/test_tilelang_language_atomic_add.py
View file @
29051439
...
...
@@ -4,14 +4,12 @@ import tilelang.language as T
@
tilelang
.
jit
def
atomic_add_program
(
K
,
M
,
N
,
block_M
,
block_N
,
dtype
=
"float"
):
@
T
.
prim_func
def
atomic_add
(
A
:
T
.
Tensor
((
K
,
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
)):
with
T
.
Kernel
(
T
.
ceildiv
(
M
,
block_M
),
T
.
ceildiv
(
N
,
block_N
),
K
,
threads
=
32
)
as
(
bx
,
by
,
bz
):
A_shared
=
T
.
alloc_shared
((
block_M
,
block_N
),
dtype
)
T
.
copy
(
A
[
bz
,
bx
*
block_M
:(
bx
+
1
)
*
block_M
,
by
*
block_N
:(
by
+
1
)
*
block_N
],
A_shared
)
T
.
copy
(
A
[
bz
,
bx
*
block_M
:
(
bx
+
1
)
*
block_M
,
by
*
block_N
:
(
by
+
1
)
*
block_N
],
A_shared
)
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
T
.
atomic_add
(
B
[
bx
*
block_M
+
i
,
by
*
block_N
+
j
],
A_shared
[
i
,
j
])
...
...
@@ -39,14 +37,12 @@ def run_atomic_add(K, M, N, block_M, block_N, dtype="float32"):
@
tilelang
.
jit
def
tile_atomic_add_program
(
K
,
M
,
N
,
block_M
,
block_N
,
dtype
=
"float"
):
@
T
.
prim_func
def
atomic_add
(
A
:
T
.
Tensor
((
K
,
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
)):
with
T
.
Kernel
(
T
.
ceildiv
(
M
,
block_M
),
T
.
ceildiv
(
N
,
block_N
),
K
,
threads
=
32
)
as
(
bx
,
by
,
bz
):
A_shared
=
T
.
alloc_shared
((
block_M
,
block_N
),
dtype
)
T
.
copy
(
A
[
bz
,
bx
*
block_M
:(
bx
+
1
)
*
block_M
,
by
*
block_N
:(
by
+
1
)
*
block_N
],
A_shared
)
T
.
copy
(
A
[
bz
,
bx
*
block_M
:
(
bx
+
1
)
*
block_M
,
by
*
block_N
:
(
by
+
1
)
*
block_N
],
A_shared
)
T
.
atomic_add
(
B
[
bx
*
block_M
,
by
*
block_N
],
A_shared
)
...
...
@@ -76,14 +72,12 @@ def run_tile_atomic_add(K, M, N, block_M, block_N, dtype="float32"):
@
tilelang
.
jit
def
atomic_max_program
(
K
,
M
,
N
,
block_M
,
block_N
,
dtype
=
"float"
):
@
T
.
prim_func
def
atomic_max
(
A
:
T
.
Tensor
((
K
,
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
)):
with
T
.
Kernel
(
T
.
ceildiv
(
M
,
block_M
),
T
.
ceildiv
(
N
,
block_N
),
K
,
threads
=
32
)
as
(
bx
,
by
,
bz
):
A_shared
=
T
.
alloc_shared
((
block_M
,
block_N
),
dtype
)
T
.
copy
(
A
[
bz
,
bx
*
block_M
:(
bx
+
1
)
*
block_M
,
by
*
block_N
:(
by
+
1
)
*
block_N
],
A_shared
)
T
.
copy
(
A
[
bz
,
bx
*
block_M
:
(
bx
+
1
)
*
block_M
,
by
*
block_N
:
(
by
+
1
)
*
block_N
],
A_shared
)
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
T
.
atomic_max
(
B
[
bx
*
block_M
+
i
,
by
*
block_N
+
j
],
A_shared
[
i
,
j
])
...
...
@@ -111,14 +105,12 @@ def run_atomic_max(K, M, N, block_M, block_N, dtype="float32"):
@
tilelang
.
jit
def
atomic_min_program
(
K
,
M
,
N
,
block_M
,
block_N
,
dtype
=
"float"
):
@
T
.
prim_func
def
atomic_min
(
A
:
T
.
Tensor
((
K
,
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
)):
with
T
.
Kernel
(
T
.
ceildiv
(
M
,
block_M
),
T
.
ceildiv
(
N
,
block_N
),
K
,
threads
=
32
)
as
(
bx
,
by
,
bz
):
A_shared
=
T
.
alloc_shared
((
block_M
,
block_N
),
dtype
)
T
.
copy
(
A
[
bz
,
bx
*
block_M
:(
bx
+
1
)
*
block_M
,
by
*
block_N
:(
by
+
1
)
*
block_N
],
A_shared
)
T
.
copy
(
A
[
bz
,
bx
*
block_M
:
(
bx
+
1
)
*
block_M
,
by
*
block_N
:
(
by
+
1
)
*
block_N
],
A_shared
)
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
T
.
atomic_min
(
B
[
bx
*
block_M
+
i
,
by
*
block_N
+
j
],
A_shared
[
i
,
j
])
...
...
@@ -137,7 +129,7 @@ def run_atomic_min(K, M, N, block_M, block_N, dtype="float32"):
B
[
i
,
j
]
=
min
(
B
[
i
,
j
],
A
[
k
,
i
,
j
])
A
=
torch
.
randn
(
K
,
M
,
N
,
dtype
=
getattr
(
torch
,
dtype
)).
cuda
()
B
=
torch
.
full
((
M
,
N
),
float
(
'
inf
'
),
dtype
=
getattr
(
torch
,
dtype
)).
cuda
()
B
=
torch
.
full
((
M
,
N
),
float
(
"
inf
"
),
dtype
=
getattr
(
torch
,
dtype
)).
cuda
()
ref_B
=
B
.
clone
()
ref_program
(
A
,
ref_B
)
kernel
(
A
,
B
)
...
...
@@ -146,7 +138,6 @@ def run_atomic_min(K, M, N, block_M, block_N, dtype="float32"):
@
tilelang
.
jit
def
atomic_load_store_program
(
M
,
N
,
block_M
,
block_N
,
dtype
=
"float"
):
@
T
.
prim_func
def
atomic_load_store
(
A
:
T
.
Tensor
((
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
)):
with
T
.
Kernel
(
T
.
ceildiv
(
M
,
block_M
),
T
.
ceildiv
(
N
,
block_N
),
threads
=
32
)
as
(
bx
,
by
):
...
...
@@ -172,18 +163,15 @@ def run_atomic_load_store(M, N, block_M, block_N, dtype="float32"):
@
tilelang
.
jit
def
atomic_memory_order_program
(
K
,
M
,
N
,
block_M
,
block_N
,
dtype
=
"float"
):
@
T
.
prim_func
def
atomic_with_memory_order
(
A
:
T
.
Tensor
((
K
,
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
)):
with
T
.
Kernel
(
T
.
ceildiv
(
M
,
block_M
),
T
.
ceildiv
(
N
,
block_N
),
K
,
threads
=
32
)
as
(
bx
,
by
,
bz
):
A_shared
=
T
.
alloc_shared
((
block_M
,
block_N
),
dtype
)
T
.
copy
(
A
[
bz
,
bx
*
block_M
:(
bx
+
1
)
*
block_M
,
by
*
block_N
:(
by
+
1
)
*
block_N
],
A_shared
)
T
.
copy
(
A
[
bz
,
bx
*
block_M
:
(
bx
+
1
)
*
block_M
,
by
*
block_N
:
(
by
+
1
)
*
block_N
],
A_shared
)
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
T
.
atomic_add
(
B
[
bx
*
block_M
+
i
,
by
*
block_N
+
j
],
A_shared
[
i
,
j
],
memory_order
=
"relaxed"
)
T
.
atomic_add
(
B
[
bx
*
block_M
+
i
,
by
*
block_N
+
j
],
A_shared
[
i
,
j
],
memory_order
=
"relaxed"
)
return
atomic_with_memory_order
...
...
@@ -208,7 +196,6 @@ def run_atomic_memory_order(K, M, N, block_M, block_N, dtype="float32"):
@
tilelang
.
jit
def
atomic_addx2_program
(
M
,
N
,
block_M
,
block_N
):
@
T
.
prim_func
def
atomic_addx2
(
A
:
T
.
Tensor
((
M
,
N
),
"float16"
),
B
:
T
.
Tensor
((
M
,
N
),
"float16"
)):
with
T
.
Kernel
(
T
.
ceildiv
(
M
,
block_M
),
T
.
ceildiv
(
N
,
block_N
),
threads
=
32
)
as
(
bx
,
by
):
...
...
@@ -262,10 +249,10 @@ def test_atomic_addx2():
@
tilelang
.
jit
def
atomic_different_memory_orders_program
(
M
,
N
,
block_M
,
block_N
,
dtype
=
"float"
):
@
T
.
prim_func
def
atomic_different_orders
(
A
:
T
.
Tensor
((
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
),
C
:
T
.
Tensor
(
(
M
,
N
),
dtype
),
D
:
T
.
Tensor
((
M
,
N
),
dtype
)):
def
atomic_different_orders
(
A
:
T
.
Tensor
((
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
D
:
T
.
Tensor
((
M
,
N
),
dtype
)
):
with
T
.
Kernel
(
T
.
ceildiv
(
M
,
block_M
),
T
.
ceildiv
(
N
,
block_N
),
threads
=
32
)
as
(
bx
,
by
):
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
idx_i
=
bx
*
block_M
+
i
...
...
@@ -286,18 +273,17 @@ def run_atomic_different_memory_orders(M, N, block_M, block_N, dtype="float32"):
A
=
torch
.
randn
(
M
,
N
,
dtype
=
getattr
(
torch
,
dtype
)).
cuda
()
B
=
torch
.
zeros
(
M
,
N
,
dtype
=
getattr
(
torch
,
dtype
)).
cuda
()
C
=
torch
.
zeros
(
M
,
N
,
dtype
=
getattr
(
torch
,
dtype
)).
cuda
()
D
=
torch
.
full
((
M
,
N
),
float
(
'
inf
'
),
dtype
=
getattr
(
torch
,
dtype
)).
cuda
()
D
=
torch
.
full
((
M
,
N
),
float
(
"
inf
"
),
dtype
=
getattr
(
torch
,
dtype
)).
cuda
()
kernel
(
A
,
B
,
C
,
D
)
torch
.
testing
.
assert_close
(
B
,
A
,
atol
=
1e-3
,
rtol
=
1e-3
)
torch
.
testing
.
assert_close
(
C
,
torch
.
maximum
(
torch
.
zeros_like
(
A
),
A
))
torch
.
testing
.
assert_close
(
D
,
torch
.
minimum
(
torch
.
full_like
(
A
,
float
(
'
inf
'
)),
A
))
torch
.
testing
.
assert_close
(
D
,
torch
.
minimum
(
torch
.
full_like
(
A
,
float
(
"
inf
"
)),
A
))
@
tilelang
.
jit
def
atomic_addx4_program
(
M
,
N
,
block_M
,
block_N
):
@
T
.
prim_func
def
atomic_addx4
(
A
:
T
.
Tensor
((
M
,
N
),
"float32"
),
B
:
T
.
Tensor
((
M
,
N
),
"float32"
)):
with
T
.
Kernel
(
T
.
ceildiv
(
M
,
block_M
),
T
.
ceildiv
(
N
,
block_N
),
threads
=
32
)
as
(
bx
,
by
):
...
...
@@ -330,17 +316,14 @@ def run_atomic_addx4(M, N, block_M, block_N):
@
tilelang
.
jit
def
atomic_return_prev_program
(
M
,
N
,
block_M
,
block_N
,
dtype
=
"float"
):
@
T
.
prim_func
def
atomic_with_return_prev
(
A
:
T
.
Tensor
((
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
),
old_vals
:
T
.
Tensor
((
M
,
N
),
dtype
)):
def
atomic_with_return_prev
(
A
:
T
.
Tensor
((
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
),
old_vals
:
T
.
Tensor
((
M
,
N
),
dtype
)):
with
T
.
Kernel
(
T
.
ceildiv
(
M
,
block_M
),
T
.
ceildiv
(
N
,
block_N
),
threads
=
32
)
as
(
bx
,
by
):
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
idx_i
=
bx
*
block_M
+
i
idx_j
=
by
*
block_N
+
j
if
idx_i
<
M
and
idx_j
<
N
:
old_vals
[
idx_i
,
idx_j
]
=
T
.
atomic_add
(
B
[
idx_i
,
idx_j
],
A
[
idx_i
,
idx_j
],
return_prev
=
True
)
old_vals
[
idx_i
,
idx_j
]
=
T
.
atomic_add
(
B
[
idx_i
,
idx_j
],
A
[
idx_i
,
idx_j
],
return_prev
=
True
)
return
atomic_with_return_prev
...
...
testing/python/language/test_tilelang_language_ceildiv.py
View file @
29051439
...
...
@@ -5,7 +5,6 @@ import torch
@
tilelang
.
jit
(
out_idx
=
[
-
1
])
def
_ceildiv_kernel
(
a
:
int
,
b
:
int
):
@
T
.
prim_func
def
ceildiv_kernel
(
A
:
T
.
Tensor
((
1
,),
"int32"
)):
with
T
.
Kernel
(
1
,
threads
=
1
)
as
_
:
...
...
@@ -30,7 +29,6 @@ def test_ceildiv():
@
tilelang
.
jit
def
_ceildiv_kernel_dyn
(
b
:
int
):
@
T
.
prim_func
def
ceildiv_kernel
(
A
:
T
.
Tensor
((
1
,),
"int32"
),
a
:
T
.
int32
):
with
T
.
Kernel
(
1
,
threads
=
1
)
as
_
:
...
...
testing/python/language/test_tilelang_language_chain_equal.py
View file @
29051439
...
...
@@ -8,14 +8,14 @@ import torch
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
},)
},
)
def
chain_equal
(
N
,
block_size
,
dtype
=
"float32"
):
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
((
N
,),
dtype
),
B
:
T
.
Tensor
((
N
,),
dtype
),
C
:
T
.
Tensor
((
N
,),
dtype
),
A
:
T
.
Tensor
((
N
,),
dtype
),
B
:
T
.
Tensor
((
N
,),
dtype
),
C
:
T
.
Tensor
((
N
,),
dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_size
),
threads
=
block_size
)
as
bx
:
for
lane
in
T
.
Parallel
(
block_size
):
...
...
testing/python/language/test_tilelang_language_clamp.py
View file @
29051439
...
...
@@ -13,8 +13,8 @@ def clamp_within_bounds(
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
((
N
,),
dtype
),
B
:
T
.
Tensor
((
N
,),
dtype
),
A
:
T
.
Tensor
((
N
,),
dtype
),
B
:
T
.
Tensor
((
N
,),
dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
threads
=
block_N
)
as
bx
:
A_shared
=
T
.
alloc_shared
([
block_N
],
dtype
)
...
...
@@ -56,8 +56,8 @@ def clamp_value_range(
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
((
1
,
N
),
dtype
),
B
:
T
.
Tensor
((
1
,
N
),
dtype
),
A
:
T
.
Tensor
((
1
,
N
),
dtype
),
B
:
T
.
Tensor
((
1
,
N
),
dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
threads
=
block_N
)
as
bx
:
# A_shared = T.alloc_shared([1, block_N], dtype=dtype)
...
...
testing/python/language/test_tilelang_language_clear.py
View file @
29051439
...
...
@@ -5,12 +5,11 @@ import tilelang.language as T
# add decorator @tilelang.jit if you want to return a torch function
# @tilelang.jit
def
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
B
:
T
.
Tensor
((
N
,
K
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
B
:
T
.
Tensor
((
N
,
K
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
):
# Initialize Kernel Context
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
128
)
as
(
bx
,
by
):
...
...
@@ -42,10 +41,10 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
def
run_matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
program
=
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
,
accum_dtype
)
kernel
=
tilelang
.
compile
(
program
,
out_idx
=
[
2
],
target
=
"cuda"
,
pass_configs
=
{
"tl.disable_tma_lower"
:
True
})
kernel
=
tilelang
.
compile
(
program
,
out_idx
=
[
2
],
target
=
"cuda"
,
pass_configs
=
{
"tl.disable_tma_lower"
:
True
})
import
torch
from
tilelang.utils
import
map_torch_type
a
=
torch
.
randn
((
M
,
K
),
dtype
=
map_torch_type
(
dtype
)).
cuda
()
b
=
torch
.
randn
((
N
,
K
),
dtype
=
map_torch_type
(
dtype
)).
cuda
()
c
=
kernel
(
a
,
b
)
...
...
testing/python/language/test_tilelang_language_composable_index.py
View file @
29051439
...
...
@@ -7,11 +7,10 @@ import torch
# add decorator @tilelang.jit if you want to return a torch function
# @tilelang.jit
def
tilelang_composable_copy
(
M
,
N
,
block_M
,
block_N
,
dtype
=
"float16"
):
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
((
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
*
N
),
dtype
),
A
:
T
.
Tensor
((
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
*
N
),
dtype
),
):
# Initialize Kernel Context
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
128
)
as
(
bx
,
by
):
...
...
@@ -35,7 +34,8 @@ def run_tilelang_composable_copy(M=1024, N=1024, block_M=128, block_N=128, dtype
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
})
},
)
a
=
torch
.
randn
(
M
,
N
,
device
=
"cuda"
,
dtype
=
getattr
(
torch
,
dtype
))
b
=
kernel
(
a
)
torch
.
testing
.
assert_close
(
b
.
flatten
(),
a
.
flatten
(),
rtol
=
1e-2
,
atol
=
1e-2
)
...
...
testing/python/language/test_tilelang_language_copy.py
View file @
29051439
...
...
@@ -7,11 +7,10 @@ import tilelang.testing
# add decorator @tilelang.jit if you want to return a torch function
# @tilelang.jit
def
tilelang_copy
(
M
,
N
,
block_M
,
block_N
,
dtype
=
"float16"
):
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
((
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
),
A
:
T
.
Tensor
((
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
),
):
# Initialize Kernel Context
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
128
)
as
(
bx
,
by
):
...
...
@@ -27,10 +26,8 @@ def run_tilelang_copy(M=1024, N=1024, block_M=128, block_N=128, dtype="float16")
program
,
out_idx
=
[
1
],
target
=
"cuda"
,
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
})
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
},
)
a
=
torch
.
randn
(
M
,
N
,
device
=
"cuda"
,
dtype
=
getattr
(
torch
,
dtype
))
b
=
kernel
(
a
)
torch
.
testing
.
assert_close
(
b
,
a
,
rtol
=
1e-2
,
atol
=
1e-2
)
...
...
@@ -43,11 +40,10 @@ def test_tilelang_copy():
def
tilelang_copy_with_stride
(
M
,
N
,
NN
,
block_M
,
block_N
,
dtype
=
"float16"
):
@
T
.
prim_func
def
main
(
A
:
T
.
StridedTensor
((
M
,
N
),
(
NN
,
1
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
),
A
:
T
.
StridedTensor
((
M
,
N
),
(
NN
,
1
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
),
):
# Initialize Kernel Context
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
128
)
as
(
bx
,
by
):
...
...
@@ -57,12 +53,7 @@ def tilelang_copy_with_stride(M, N, NN, block_M, block_N, dtype="float16"):
return
main
def
run_tilelang_copy_with_stride
(
M
=
1024
,
N
=
1024
,
NN
=
2048
,
block_M
=
128
,
block_N
=
128
,
dtype
=
"float16"
):
def
run_tilelang_copy_with_stride
(
M
=
1024
,
N
=
1024
,
NN
=
2048
,
block_M
=
128
,
block_N
=
128
,
dtype
=
"float16"
):
if
isinstance
(
NN
,
int
):
assert
NN
>
N
,
"NN must be greater than N"
program
=
tilelang_copy_with_stride
(
M
,
N
,
NN
,
block_M
,
block_N
,
dtype
)
...
...
@@ -73,7 +64,8 @@ def run_tilelang_copy_with_stride(M=1024,
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
})
},
)
if
isinstance
(
NN
,
T
.
Var
):
NN
=
N
*
2
a
=
torch
.
randn
(
M
,
NN
,
device
=
"cuda"
,
dtype
=
getattr
(
torch
,
dtype
))
...
...
@@ -87,11 +79,10 @@ def test_tilelang_copy_with_stride():
def
tilelang_copy_bufferload
(
num_tokens
,
dtype
=
"float16"
):
@
T
.
prim_func
def
main
(
indices
:
T
.
Tensor
((
num_tokens
,),
"int32"
),
x
:
T
.
Tensor
((
num_tokens
,),
dtype
),
indices
:
T
.
Tensor
((
num_tokens
,),
"int32"
),
x
:
T
.
Tensor
((
num_tokens
,),
dtype
),
):
with
T
.
Kernel
(
num_tokens
,
threads
=
32
)
as
pid
:
idx
=
T
.
alloc_local
([
1
],
"int32"
)
...
...
@@ -107,10 +98,8 @@ def run_tilelang_copy_bufferload(num_tokens=128, dtype="float16"):
tilelang
.
compile
(
program
,
out_idx
=
[
1
],
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
})
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
},
)
def
test_tilelang_copy_bufferload
():
...
...
@@ -118,11 +107,10 @@ def test_tilelang_copy_bufferload():
def
tilelang_copy_buffer_load_with_parallel
(
M
,
N
,
block_M
,
block_N
,
dtype
=
"float16"
):
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
((
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
),
A
:
T
.
Tensor
((
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
),
):
# Initialize Kernel Context
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
128
)
as
(
bx
,
by
):
...
...
@@ -132,20 +120,14 @@ def tilelang_copy_buffer_load_with_parallel(M, N, block_M, block_N, dtype="float
return
main
def
run_tilelang_copy_buffer_load_with_parallel
(
M
=
1024
,
N
=
1024
,
block_M
=
128
,
block_N
=
128
,
dtype
=
"float16"
):
def
run_tilelang_copy_buffer_load_with_parallel
(
M
=
1024
,
N
=
1024
,
block_M
=
128
,
block_N
=
128
,
dtype
=
"float16"
):
program
=
tilelang_copy_buffer_load_with_parallel
(
M
,
N
,
block_M
,
block_N
,
dtype
)
kernel
=
tilelang
.
compile
(
program
,
out_idx
=
[
1
],
target
=
"cuda"
,
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
})
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
},
)
a
=
torch
.
randn
(
M
,
N
,
device
=
"cuda"
,
dtype
=
getattr
(
torch
,
dtype
))
b
=
kernel
(
a
)
torch
.
testing
.
assert_close
(
b
,
a
,
rtol
=
1e-2
,
atol
=
1e-2
)
...
...
testing/python/language/test_tilelang_language_cumsum.py
View file @
29051439
...
...
@@ -9,8 +9,8 @@ def cumsum_smem_test(M, N, block_M, block_N, dim=0, reverse=False, dtype="float3
@
T
.
prim_func
def
cumsum
(
A
:
T
.
Tensor
((
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
),
A
:
T
.
Tensor
((
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
),
):
# Initialize Kernel Context
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
256
)
as
(
bx
,
by
):
...
...
@@ -28,8 +28,8 @@ def cumsum_fragment_test(M, N, block_M, block_N, dim=0, reverse=False, dtype="fl
@
T
.
prim_func
def
cumsum
(
A
:
T
.
Tensor
((
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
),
A
:
T
.
Tensor
((
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
),
):
# Initialize Kernel Context
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
256
)
as
(
bx
,
by
):
...
...
@@ -57,13 +57,16 @@ def run_cumsum(M, N, block_M, block_N, dim=0, reverse=False, dtype="float32", sc
ref_b
=
torch
.
empty_like
(
A
)
for
i
in
range
(
M
//
block_M
):
for
j
in
range
(
N
//
block_N
):
ref_b
[
i
*
block_M
:
(
i
+
1
)
*
block_M
,
j
*
block_
N
:(
j
+
1
)
*
block_
N
]
=
A
[
i
*
block_
M
:(
i
+
1
)
*
block_
M
,
j
*
block_N
:(
j
+
1
)
*
block_N
].
cumsum
(
dim
=
dim
)
ref_b
[
i
*
block_M
:
(
i
+
1
)
*
block_M
,
j
*
block_N
:
(
j
+
1
)
*
block_N
]
=
A
[
i
*
block_
M
:
(
i
+
1
)
*
block_
M
,
j
*
block_
N
:
(
j
+
1
)
*
block_
N
].
cumsum
(
dim
=
dim
)
if
reverse
:
ref_b
[
i
*
block_M
:(
i
+
1
)
*
block_M
,
j
*
block_N
:(
j
+
1
)
*
block_N
]
=
A
[
i
*
block_M
:(
i
+
1
)
*
block_M
,
j
*
block_N
:(
j
+
1
)
*
block_N
].
flip
(
dims
=
[
dim
]).
cumsum
(
dim
=
dim
).
flip
(
dims
=
[
dim
])
ref_b
[
i
*
block_M
:
(
i
+
1
)
*
block_M
,
j
*
block_N
:
(
j
+
1
)
*
block_N
]
=
(
A
[
i
*
block_M
:
(
i
+
1
)
*
block_M
,
j
*
block_N
:
(
j
+
1
)
*
block_N
]
.
flip
(
dims
=
[
dim
])
.
cumsum
(
dim
=
dim
)
.
flip
(
dims
=
[
dim
])
)
return
ref_b
tilelang_res
=
jit_kernel
(
A
)
...
...
@@ -76,8 +79,8 @@ def cumsum_smem_test_1d(N, block_N, reverse=False, dtype="float32"):
@
T
.
prim_func
def
cumsum
(
A
:
T
.
Tensor
((
N
,),
dtype
),
B
:
T
.
Tensor
((
N
,),
dtype
),
A
:
T
.
Tensor
((
N
,),
dtype
),
B
:
T
.
Tensor
((
N
,),
dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
threads
=
block_N
)
as
bx
:
A_shared
=
T
.
alloc_shared
((
block_N
,),
dtype
)
...
...
@@ -94,8 +97,8 @@ def cumsum_fragment_test_1d(N, block_N, reverse=False, dtype="float32"):
@
T
.
prim_func
def
cumsum
(
A
:
T
.
Tensor
((
N
,),
dtype
),
B
:
T
.
Tensor
((
N
,),
dtype
),
A
:
T
.
Tensor
((
N
,),
dtype
),
B
:
T
.
Tensor
((
N
,),
dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
threads
=
block_N
)
as
bx
:
A_shared
=
T
.
alloc_shared
((
block_N
,),
dtype
)
...
...
testing/python/language/test_tilelang_language_frontend_v2.py
View file @
29051439
...
...
@@ -8,7 +8,6 @@ from tvm.tir.expr import IntImm, Var
def
test_argument
():
@
T
.
prim_func
def
test_argument
(
t_1
:
T
.
bool
,
...
...
@@ -41,6 +40,7 @@ def test_argument():
def
test_expr
():
from
tilelang.language.v2.dtypes
import
_all_dtypes
errors
=
[]
for
name
in
_all_dtypes
:
dtype
=
getattr
(
T
,
name
)
...
...
@@ -116,33 +116,32 @@ def test_expr():
def
test_dtype_str_repr
():
@
T
.
prim_func
def
test_str_repr
():
buf_1
=
T
.
alloc_buffer
((
1
,),
dtype
=
T
.
bool
,
scope
=
'
shared
'
)
# noqa F841
buf_2
=
T
.
alloc_buffer
((
1
,),
dtype
=
T
.
short
,
scope
=
'
shared
'
)
# noqa F841
buf_3
=
T
.
alloc_buffer
((
1
,),
dtype
=
T
.
int
,
scope
=
'
shared
'
)
# noqa F841
buf_4
=
T
.
alloc_buffer
((
1
,),
dtype
=
T
.
long
,
scope
=
'
shared
'
)
# noqa F841
buf_5
=
T
.
alloc_buffer
((
1
,),
dtype
=
T
.
half
,
scope
=
'
shared
'
)
# noqa F841
buf_6
=
T
.
alloc_buffer
((
1
,),
dtype
=
T
.
float
,
scope
=
'
shared
'
)
# noqa F841
buf_7
=
T
.
alloc_buffer
((
1
,),
dtype
=
T
.
long
,
scope
=
'
shared
'
)
# noqa F841
buf_8
=
T
.
alloc_buffer
((
1
,),
dtype
=
T
.
int8
,
scope
=
'
shared
'
)
# noqa F841
buf_9
=
T
.
alloc_buffer
((
1
,),
dtype
=
T
.
int16
,
scope
=
'
shared
'
)
# noqa F841
buf_10
=
T
.
alloc_buffer
((
1
,),
dtype
=
T
.
int32
,
scope
=
'
shared
'
)
# noqa F841
buf_11
=
T
.
alloc_buffer
((
1
,),
dtype
=
T
.
int64
,
scope
=
'
shared
'
)
# noqa F841
buf_12
=
T
.
alloc_buffer
((
1
,),
dtype
=
T
.
uint8
,
scope
=
'
shared
'
)
# noqa F841
buf_13
=
T
.
alloc_buffer
((
1
,),
dtype
=
T
.
uint16
,
scope
=
'
shared
'
)
# noqa F841
buf_14
=
T
.
alloc_buffer
((
1
,),
dtype
=
T
.
uint32
,
scope
=
'
shared
'
)
# noqa F841
buf_15
=
T
.
alloc_buffer
((
1
,),
dtype
=
T
.
uint64
,
scope
=
'
shared
'
)
# noqa F841
buf_16
=
T
.
alloc_buffer
((
1
,),
dtype
=
T
.
float8_e4m3fn
,
scope
=
'
shared
'
)
# noqa F841
buf_17
=
T
.
alloc_buffer
((
1
,),
dtype
=
T
.
float8_e4m3fnuz
,
scope
=
'
shared
'
)
# noqa F841
buf_18
=
T
.
alloc_buffer
((
1
,),
dtype
=
T
.
float8_e5m2
,
scope
=
'
shared
'
)
# noqa F841
buf_19
=
T
.
alloc_buffer
((
1
,),
dtype
=
T
.
float8_e5m2fnuz
,
scope
=
'
shared
'
)
# noqa F841
buf_20
=
T
.
alloc_buffer
((
1
,),
dtype
=
T
.
float8_e8m0fnu
,
scope
=
'
shared
'
)
# noqa F841
buf_21
=
T
.
alloc_buffer
((
1
,),
dtype
=
T
.
float16
,
scope
=
'
shared
'
)
# noqa F841
buf_22
=
T
.
alloc_buffer
((
1
,),
dtype
=
T
.
bfloat16
,
scope
=
'
shared
'
)
# noqa F841
buf_23
=
T
.
alloc_buffer
((
1
,),
dtype
=
T
.
float32
,
scope
=
'
shared
'
)
# noqa F841
buf_24
=
T
.
alloc_buffer
((
1
,),
dtype
=
T
.
float64
,
scope
=
'
shared
'
)
# noqa F841
buf_1
=
T
.
alloc_buffer
((
1
,),
dtype
=
T
.
bool
,
scope
=
"
shared
"
)
# noqa F841
buf_2
=
T
.
alloc_buffer
((
1
,),
dtype
=
T
.
short
,
scope
=
"
shared
"
)
# noqa F841
buf_3
=
T
.
alloc_buffer
((
1
,),
dtype
=
T
.
int
,
scope
=
"
shared
"
)
# noqa F841
buf_4
=
T
.
alloc_buffer
((
1
,),
dtype
=
T
.
long
,
scope
=
"
shared
"
)
# noqa F841
buf_5
=
T
.
alloc_buffer
((
1
,),
dtype
=
T
.
half
,
scope
=
"
shared
"
)
# noqa F841
buf_6
=
T
.
alloc_buffer
((
1
,),
dtype
=
T
.
float
,
scope
=
"
shared
"
)
# noqa F841
buf_7
=
T
.
alloc_buffer
((
1
,),
dtype
=
T
.
long
,
scope
=
"
shared
"
)
# noqa F841
buf_8
=
T
.
alloc_buffer
((
1
,),
dtype
=
T
.
int8
,
scope
=
"
shared
"
)
# noqa F841
buf_9
=
T
.
alloc_buffer
((
1
,),
dtype
=
T
.
int16
,
scope
=
"
shared
"
)
# noqa F841
buf_10
=
T
.
alloc_buffer
((
1
,),
dtype
=
T
.
int32
,
scope
=
"
shared
"
)
# noqa F841
buf_11
=
T
.
alloc_buffer
((
1
,),
dtype
=
T
.
int64
,
scope
=
"
shared
"
)
# noqa F841
buf_12
=
T
.
alloc_buffer
((
1
,),
dtype
=
T
.
uint8
,
scope
=
"
shared
"
)
# noqa F841
buf_13
=
T
.
alloc_buffer
((
1
,),
dtype
=
T
.
uint16
,
scope
=
"
shared
"
)
# noqa F841
buf_14
=
T
.
alloc_buffer
((
1
,),
dtype
=
T
.
uint32
,
scope
=
"
shared
"
)
# noqa F841
buf_15
=
T
.
alloc_buffer
((
1
,),
dtype
=
T
.
uint64
,
scope
=
"
shared
"
)
# noqa F841
buf_16
=
T
.
alloc_buffer
((
1
,),
dtype
=
T
.
float8_e4m3fn
,
scope
=
"
shared
"
)
# noqa F841
buf_17
=
T
.
alloc_buffer
((
1
,),
dtype
=
T
.
float8_e4m3fnuz
,
scope
=
"
shared
"
)
# noqa F841
buf_18
=
T
.
alloc_buffer
((
1
,),
dtype
=
T
.
float8_e5m2
,
scope
=
"
shared
"
)
# noqa F841
buf_19
=
T
.
alloc_buffer
((
1
,),
dtype
=
T
.
float8_e5m2fnuz
,
scope
=
"
shared
"
)
# noqa F841
buf_20
=
T
.
alloc_buffer
((
1
,),
dtype
=
T
.
float8_e8m0fnu
,
scope
=
"
shared
"
)
# noqa F841
buf_21
=
T
.
alloc_buffer
((
1
,),
dtype
=
T
.
float16
,
scope
=
"
shared
"
)
# noqa F841
buf_22
=
T
.
alloc_buffer
((
1
,),
dtype
=
T
.
bfloat16
,
scope
=
"
shared
"
)
# noqa F841
buf_23
=
T
.
alloc_buffer
((
1
,),
dtype
=
T
.
float32
,
scope
=
"
shared
"
)
# noqa F841
buf_24
=
T
.
alloc_buffer
((
1
,),
dtype
=
T
.
float64
,
scope
=
"
shared
"
)
# noqa F841
# not supported now
...
...
@@ -205,7 +204,6 @@ def test_dtype_str_repr():
def
test_var_assign
():
@
tilelang
.
jit
(
out_idx
=-
1
)
@
T
.
prim_func
def
test_var_assign
(
A
:
T
.
Tensor
((
2
,),
T
.
int32
)):
...
...
@@ -223,7 +221,6 @@ def test_var_assign():
def
test_marco_return
():
@
T
.
macro
def
macro_return_constant
():
return
0
...
...
@@ -258,11 +255,10 @@ def test_marco_return():
def
test_prim_func_generator
():
@
T
.
prim_func
(
generator
=
True
)
def
prim_func_gen
(
A
=
T
.
Tensor
((
128
,),
T
.
float32
),
# noqa: B008
B
=
T
.
Tensor
((
128
,),
T
.
float32
),
# noqa: B008
A
=
T
.
Tensor
((
128
,),
T
.
float32
),
# noqa: B008
B
=
T
.
Tensor
((
128
,),
T
.
float32
),
# noqa: B008
):
with
T
.
Kernel
(
128
)
as
(
tx
,):
T
.
copy
(
A
[
tx
],
B
[
tx
])
...
...
@@ -277,7 +273,6 @@ def test_prim_func_generator():
def
test_serial_for_with_step
():
@
tilelang
.
jit
(
out_idx
=-
1
)
@
T
.
prim_func
def
test_stepped_serial
(
A
:
T
.
Tensor
((
10
,),
T
.
int32
)):
...
...
@@ -291,7 +286,7 @@ def test_serial_for_with_step():
ker
=
test_stepped_serial
()
res
=
ker
()
ref
=
torch
.
tensor
([
1
,
2
,
1
,
2
,
1
,
2
,
1
,
2
,
1
,
2
],
dtype
=
torch
.
int32
,
device
=
'
cuda
'
)
ref
=
torch
.
tensor
([
1
,
2
,
1
,
2
,
1
,
2
,
1
,
2
,
1
,
2
],
dtype
=
torch
.
int32
,
device
=
"
cuda
"
)
assert
torch
.
all
(
res
==
ref
),
f
"Expected
{
ref
}
, but got
{
res
}
"
@
tilelang
.
jit
(
out_idx
=-
1
)
...
...
@@ -304,17 +299,16 @@ def test_serial_for_with_step():
ker
=
test_serial_step_neg
()
res
=
ker
()
ref
=
torch
.
tensor
([
10
,
9
,
8
,
7
,
6
,
5
,
4
,
3
,
2
,
1
],
dtype
=
torch
.
int32
,
device
=
'
cuda
'
)
ref
=
torch
.
tensor
([
10
,
9
,
8
,
7
,
6
,
5
,
4
,
3
,
2
,
1
],
dtype
=
torch
.
int32
,
device
=
"
cuda
"
)
assert
torch
.
all
(
res
==
ref
),
f
"Expected
{
ref
}
, but got
{
res
}
"
assert
isinstance
(
T
.
serial
(
1
,
10
,
1
),
IRBuilderFrame
)
assert
isinstance
(
T
.
serial
(
1
,
10
,
IntImm
(
'
int32
'
,
1
)),
IRBuilderFrame
)
assert
not
isinstance
(
T
.
serial
(
1
,
10
,
Var
(
'
tmp
'
,
'
int32
'
)),
IRBuilderFrame
)
assert
isinstance
(
T
.
serial
(
1
,
10
,
IntImm
(
"
int32
"
,
1
)),
IRBuilderFrame
)
assert
not
isinstance
(
T
.
serial
(
1
,
10
,
Var
(
"
tmp
"
,
"
int32
"
)),
IRBuilderFrame
)
assert
not
isinstance
(
T
.
serial
(
10
,
-
1
,
-
1
),
IRBuilderFrame
)
def
test_swap_logic
():
@
tilelang
.
jit
@
T
.
prim_func
def
swap_var
(
A
:
T
.
Tensor
[(
2
,),
T
.
float32
]):
...
...
@@ -344,7 +338,6 @@ def test_swap_logic():
def
test_while_loop
():
@
tilelang
.
jit
(
out_idx
=-
1
)
@
T
.
prim_func
def
test_while_loop
(
A
:
T
.
Tensor
((
1
,),
T
.
int32
)):
...
...
@@ -374,7 +367,7 @@ def test_var_macro():
x
=
T
.
alloc_var
(
T
.
int32
)
macro_with_var
(
x
)
assert
'
x[0] = 1
'
in
prim_call_macro
.
script
()
assert
"
x[0] = 1
"
in
prim_call_macro
.
script
()
finally
:
pass
...
...
@@ -406,7 +399,7 @@ def test_var_macro():
x
=
T
.
alloc_var
(
T
.
int32
)
macro_with_var
(
x
)
assert
'
x[0] = 1
'
in
prim_call_macro
.
script
()
assert
"
x[0] = 1
"
in
prim_call_macro
.
script
()
finally
:
pass
...
...
@@ -428,10 +421,8 @@ def test_var_macro():
def
test_frame_inside_macro
():
@
tilelang
.
jit
def
get_sample_kernel
():
@
T
.
macro
def
transform
(
x
):
return
x
+
1
...
...
@@ -442,7 +433,7 @@ def test_frame_inside_macro():
idx_out
:
T
.
Tensor
[(
32
,),
T
.
int32
],
):
with
T
.
Kernel
(
num_blocks
,
threads
=
32
)
as
block_idx
:
# noqa: F841
fragment
=
T
.
alloc_fragment
(
32
,
'
int32
'
)
fragment
=
T
.
alloc_fragment
(
32
,
"
int32
"
)
T
.
copy
(
idx_out
,
fragment
)
for
i
in
T
.
Parallel
(
32
):
...
...
@@ -467,10 +458,10 @@ def test_buffer_slice_step():
def
test_boolop
():
a
=
Var
(
'a'
,
'
int32
'
)
b
=
Var
(
'b'
,
'
int32
'
)
c
=
Var
(
'c'
,
'
int32
'
)
d
=
Var
(
'd'
,
'
int32
'
)
a
=
Var
(
"a"
,
"
int32
"
)
b
=
Var
(
"b"
,
"
int32
"
)
c
=
Var
(
"c"
,
"
int32
"
)
d
=
Var
(
"d"
,
"
int32
"
)
@
T
.
macro
def
cond
():
...
...
@@ -479,5 +470,5 @@ def test_boolop():
cond
()
if
__name__
==
'
__main__
'
:
if
__name__
==
"
__main__
"
:
tilelang
.
testing
.
main
()
testing/python/language/test_tilelang_language_get_warp_info.py
View file @
29051439
...
...
@@ -23,7 +23,6 @@ def _resolve_warps_per_group(warps_per_group: Optional[int]) -> int:
@
tilelang
.
jit
(
out_idx
=
[
-
1
])
def
_get_laneid_kernel
(
num_threads
:
int
=
128
,
warp_size
:
Optional
[
int
]
=
None
):
@
T
.
prim_func
def
laneid_kernel
(
A
:
T
.
Tensor
((
num_threads
,),
"int32"
)):
with
T
.
Kernel
(
1
,
threads
=
num_threads
)
as
_
:
...
...
@@ -35,7 +34,6 @@ def _get_laneid_kernel(num_threads: int = 128, warp_size: Optional[int] = None):
@
tilelang
.
jit
(
out_idx
=
[
-
1
])
def
_get_warp_idx_sync_kernel
(
num_threads
:
int
=
128
,
warp_size
:
Optional
[
int
]
=
None
):
@
T
.
prim_func
def
warp_idx_sync_kernel
(
A
:
T
.
Tensor
((
num_threads
,),
"int32"
)):
with
T
.
Kernel
(
1
,
threads
=
num_threads
)
as
_
:
...
...
@@ -47,7 +45,6 @@ def _get_warp_idx_sync_kernel(num_threads: int = 128, warp_size: Optional[int] =
@
tilelang
.
jit
(
out_idx
=
[
-
1
])
def
_get_warp_idx_kernel
(
num_threads
:
int
=
128
,
warp_size
:
Optional
[
int
]
=
None
):
@
T
.
prim_func
def
warp_idx_kernel
(
A
:
T
.
Tensor
((
num_threads
,),
"int32"
)):
with
T
.
Kernel
(
1
,
threads
=
num_threads
)
as
_
:
...
...
@@ -63,7 +60,6 @@ def _get_warp_group_idx_kernel(
warp_size
:
Optional
[
int
]
=
None
,
warps_per_group
:
Optional
[
int
]
=
None
,
):
@
T
.
prim_func
def
warp_group_idx_kernel
(
A
:
T
.
Tensor
((
num_threads
,),
"int32"
)):
with
T
.
Kernel
(
1
,
threads
=
num_threads
)
as
_
:
...
...
@@ -75,7 +71,6 @@ def _get_warp_group_idx_kernel(
@
tilelang
.
jit
(
out_idx
=
[
-
1
])
def
_shuffle_elect_kernel
(
num_threads
:
int
=
128
,
thread_extent
:
int
=
64
):
@
T
.
prim_func
def
shuffle_elect_kernel
(
A
:
T
.
Tensor
((
num_threads
,),
"int32"
)):
with
T
.
Kernel
(
1
,
threads
=
num_threads
)
as
_
:
...
...
testing/python/language/test_tilelang_language_if_range.py
View file @
29051439
...
...
@@ -4,13 +4,14 @@ import torch
import
tilelang.testing
@
tilelang
.
jit
(
out_idx
=
[
1
],)
@
tilelang
.
jit
(
out_idx
=
[
1
],
)
def
tilelang_if_range
(
M
,
N
,
block_M
,
block_N
,
dtype
=
"float16"
):
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
((
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
),
A
:
T
.
Tensor
((
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
),
):
# Initialize Kernel Context
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
128
)
as
(
bx
,
by
):
...
...
testing/python/language/test_tilelang_language_infinity.py
View file @
29051439
...
...
@@ -5,7 +5,6 @@ import tilelang.language as T
@
tilelang
.
jit
(
out_idx
=-
1
)
def
get_inf_kernel
(
dtype
:
str
):
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
((
32
,),
dtype
)):
with
T
.
Kernel
(
1
,
threads
=
32
):
...
...
@@ -18,7 +17,7 @@ def _test_infinity(dtype: str):
kernel
=
get_inf_kernel
(
dtype
)
output
=
kernel
()
assert
torch
.
all
(
output
==
torch
.
inf
),
f
'
check failed for
{
dtype
=
}
'
assert
torch
.
all
(
output
==
torch
.
inf
),
f
"
check failed for
{
dtype
=
}
"
@
tilelang
.
testing
.
requires_cuda
...
...
testing/python/language/test_tilelang_language_intrinsics_codegen.py
View file @
29051439
...
...
@@ -9,8 +9,8 @@ def test_language_ldg_codegen():
@
T
.
prim_func
def
main
(
x
:
T
.
Tensor
((
N
,),
"float32"
),
y
:
T
.
Tensor
((
N
,),
"float32"
),
x
:
T
.
Tensor
((
N
,),
"float32"
),
y
:
T
.
Tensor
((
N
,),
"float32"
),
):
with
T
.
Kernel
(
N
,
threads
=
32
)
as
pid
:
# Explicitly request read-only cache load for x[pid]
...
...
testing/python/language/test_tilelang_language_lazy_jit.py
View file @
29051439
...
...
@@ -8,7 +8,6 @@ import torch
def
_gemm_impl
():
@
T
.
macro
def
gemm_impl
(
A
:
T
.
Tensor
[[
int
,
int
],
Any
],
...
...
@@ -37,7 +36,6 @@ def _gemm_impl():
def
test_jit2_gemm_annot
():
@
tilelang
.
lazy_jit
def
gemm
(
A
:
T
.
Tensor
[[
int
,
int
],
Any
],
...
...
@@ -54,24 +52,24 @@ def test_jit2_gemm_annot():
return
C
prod
=
product
([
T
.
float16
,
T
.
float32
],
[
T
.
float32
])
gemm
.
par_compile
([{
'A'
:
T
.
Tensor
((
1024
,
1024
),
dtype
=
in_dtype
),
'B'
:
T
.
Tensor
((
1024
,
1024
),
dtype
=
in_dtype
),
'out_dtype'
:
out_dtype
}
for
in_dtype
,
out_dtype
in
prod
])
gemm
.
par_compile
(
[
{
"A"
:
T
.
Tensor
((
1024
,
1024
),
dtype
=
in_dtype
),
"B"
:
T
.
Tensor
((
1024
,
1024
),
dtype
=
in_dtype
),
"out_dtype"
:
out_dtype
}
for
in_dtype
,
out_dtype
in
prod
]
)
for
in_dtype
,
out_dtype
in
prod
:
in_dtype
=
in_dtype
.
torch
()
out_dtype
=
out_dtype
.
torch
()
A
=
torch
.
randn
(
1024
,
1024
,
dtype
=
in_dtype
,
device
=
'
cuda
'
)
B
=
torch
.
randn
(
1024
,
1024
,
dtype
=
in_dtype
,
device
=
'
cuda
'
)
A
=
torch
.
randn
(
1024
,
1024
,
dtype
=
in_dtype
,
device
=
"
cuda
"
)
B
=
torch
.
randn
(
1024
,
1024
,
dtype
=
in_dtype
,
device
=
"
cuda
"
)
C_ref
=
out_dtype
(
A
@
B
)
C
=
gemm
(
A
,
B
)
torch
.
testing
.
assert_close
(
C
,
C_ref
,
rtol
=
1e-2
,
atol
=
1e-2
)
def
test_jit2_gemm_ptr
():
@
tilelang
.
lazy_jit
def
gemm_ptr
(
A
:
T
.
ptr
,
...
...
@@ -92,23 +90,19 @@ def test_jit2_gemm_ptr():
_gemm_impl
()(
A
,
B
,
C
,
out_dtype
,
block_M
,
block_N
,
block_K
)
prod
=
product
([
T
.
float16
,
T
.
float32
],
[
T
.
float32
])
gemm_ptr
.
par_compile
([{
'A'
:
T
.
ptr
(),
'B'
:
T
.
ptr
(),
'C'
:
T
.
ptr
(),
'M'
:
1024
,
'N'
:
1024
,
'K'
:
1024
,
'dtype'
:
in_dtype
,
'out_dtype'
:
out_dtype
}
for
in_dtype
,
out_dtype
in
prod
])
gemm_ptr
.
par_compile
(
[
{
"A"
:
T
.
ptr
(),
"B"
:
T
.
ptr
(),
"C"
:
T
.
ptr
(),
"M"
:
1024
,
"N"
:
1024
,
"K"
:
1024
,
"dtype"
:
in_dtype
,
"out_dtype"
:
out_dtype
}
for
in_dtype
,
out_dtype
in
prod
]
)
for
in_dtype
,
out_dtype
in
prod
:
in_dtype
=
in_dtype
.
torch
()
out_dtype
=
out_dtype
.
torch
()
A
=
torch
.
randn
(
1024
,
1024
,
dtype
=
in_dtype
,
device
=
'
cuda
'
)
B
=
torch
.
randn
(
1024
,
1024
,
dtype
=
in_dtype
,
device
=
'
cuda
'
)
A
=
torch
.
randn
(
1024
,
1024
,
dtype
=
in_dtype
,
device
=
"
cuda
"
)
B
=
torch
.
randn
(
1024
,
1024
,
dtype
=
in_dtype
,
device
=
"
cuda
"
)
C_ref
=
out_dtype
(
A
@
B
)
C
=
torch
.
empty
(
1024
,
1024
,
dtype
=
out_dtype
,
device
=
'
cuda
'
)
C
=
torch
.
empty
(
1024
,
1024
,
dtype
=
out_dtype
,
device
=
"
cuda
"
)
gemm_ptr
(
A
,
B
,
C
,
1024
,
1024
,
1024
,
in_dtype
,
out_dtype
)
torch
.
testing
.
assert_close
(
C
,
C_ref
,
atol
=
1e-2
,
rtol
=
1e-2
)
...
...
@@ -129,8 +123,7 @@ def test_jit2_annot():
AnnotTest
(
annot
=
T
.
Tensor
[[
int
,
int
],
T
.
float32
],
promote
=
False
,
match_ok
=
[
torch
.
randn
(
1
,
1
,
dtype
=
torch
.
float32
),
T
.
Tensor
((
1
,
1
),
dtype
=
T
.
float32
)],
match_ok
=
[
torch
.
randn
(
1
,
1
,
dtype
=
torch
.
float32
),
T
.
Tensor
((
1
,
1
),
dtype
=
T
.
float32
)],
match_ng
=
[
torch
.
randn
(
1
,
1
,
dtype
=
torch
.
float16
),
T
.
Tensor
(
1
,
dtype
=
T
.
float32
),
...
...
@@ -146,8 +139,8 @@ def test_jit2_annot():
T
.
Tensor
((
1
,),
dtype
=
T
.
float32
),
T
.
Tensor
((
1
,),
dtype
=
T
.
float16
),
],
match_ng
=
[
torch
.
randn
((
1
,
1
),
dtype
=
torch
.
float32
),
T
.
Tensor
((
1
,
1
),
dtype
=
T
.
float16
)]
),
match_ng
=
[
torch
.
randn
((
1
,
1
),
dtype
=
torch
.
float32
),
T
.
Tensor
((
1
,
1
),
dtype
=
T
.
float16
)],
),
AnnotTest
(
annot
=
T
.
Tensor
[[
int
,
1
],
Any
],
promote
=
False
,
...
...
@@ -157,8 +150,8 @@ def test_jit2_annot():
T
.
Tensor
((
12
,
1
),
T
.
float32
),
T
.
Tensor
((
12
,
1
),
T
.
float16
),
],
match_ng
=
[
torch
.
randn
(
12
,
12
,
dtype
=
torch
.
float32
),
T
.
Tensor
((
12
,
12
),
T
.
float32
)]
),
match_ng
=
[
torch
.
randn
(
12
,
12
,
dtype
=
torch
.
float32
),
T
.
Tensor
((
12
,
12
),
T
.
float32
)],
),
AnnotTest
(
annot
=
T
.
Tensor
[[
T
.
dyn
,
1
],
Any
],
promote
=
False
,
...
...
@@ -168,43 +161,39 @@ def test_jit2_annot():
T
.
Tensor
((
12
,
1
),
T
.
float32
),
T
.
Tensor
((
12
,
1
),
T
.
float16
),
],
match_ng
=
[
torch
.
randn
(
12
,
12
,
dtype
=
torch
.
float32
),
T
.
Tensor
((
12
,
12
),
T
.
float32
)]
),
match_ng
=
[
torch
.
randn
(
12
,
12
,
dtype
=
torch
.
float32
),
T
.
Tensor
((
12
,
12
),
T
.
float32
)],
),
AnnotTest
(
annot
=
T
.
Tensor
[[
1024
,
1024
],
T
.
float32
],
promote
=
True
,
),
AnnotTest
(
annot
=
T
.
dyn
[
int
,
'X'
],
promote
=
False
,
match_ok
=
[
1
,
2
,
3
,
4
]),
AnnotTest
(
annot
=
T
.
dyn
,
promote
=
False
,
match_ok
=
[
1
,
2
,
3
,
4
])
AnnotTest
(
annot
=
T
.
dyn
[
int
,
"X"
],
promote
=
False
,
match_ok
=
[
1
,
2
,
3
,
4
]),
AnnotTest
(
annot
=
T
.
dyn
,
promote
=
False
,
match_ok
=
[
1
,
2
,
3
,
4
])
,
]
for
test
in
tests
:
promote
=
test
.
annot
.
promote
()
promoted
=
promote
is
not
None
if
promoted
!=
test
.
promote
:
raise
AssertionError
(
f
'Promote mismatch for
{
test
.
annot
}
: expected
{
test
.
promote
}
, got
{
promoted
}
'
)
with
Builder
().
prim_func
(
'_test'
):
raise
AssertionError
(
f
"Promote mismatch for
{
test
.
annot
}
: expected
{
test
.
promote
}
, got
{
promoted
}
"
)
with
Builder
().
prim_func
(
"_test"
):
for
match_ok
in
test
.
match_ok
:
try
:
vt
=
ArgVarTable
()
test
.
annot
.
create_prim_func_arg
(
'
arg
'
,
match_ok
,
vt
)
test
.
annot
.
create_prim_func_arg
(
"
arg
"
,
match_ok
,
vt
)
except
Exception
as
e
:
traceback
.
print_exc
()
raise
AssertionError
(
f
'Match failed for
{
test
.
annot
}
with value
{
match_ok
}
:
{
e
}
'
)
from
e
raise
AssertionError
(
f
"Match failed for
{
test
.
annot
}
with value
{
match_ok
}
:
{
e
}
"
)
from
e
for
match_ng
in
test
.
match_ng
:
try
:
vt
=
ArgVarTable
()
test
.
annot
.
create_prim_func_arg
(
'arg'
,
match_ng
,
vt
)
raise
AssertionError
(
f
'Match unexpectedly succeeded for
{
test
.
annot
}
with value
{
match_ng
}
'
)
test
.
annot
.
create_prim_func_arg
(
"arg"
,
match_ng
,
vt
)
raise
AssertionError
(
f
"Match unexpectedly succeeded for
{
test
.
annot
}
with value
{
match_ng
}
"
)
except
Exception
:
pass
def
test_jit2_many_annot
():
@
T
.
macro
def
copy_impl
(
A
,
B
):
M
,
N
=
A
.
shape
...
...
@@ -213,8 +202,7 @@ def test_jit2_many_annot():
assert
N
==
N_
,
f
"N mismatch
{
N
}
{
N_
}
"
# assert tuple(A.shape) == tuple(B.shape), f"Invalid tensor shape: {A.shape}, {B.shape}"
with
T
.
Kernel
(
T
.
ceildiv
(
M
,
128
),
T
.
ceildiv
(
N
,
128
),
threads
=
128
)
as
(
bx
,
by
):
T
.
copy
(
A
[
bx
*
128
:
bx
*
128
+
128
,
by
*
128
:
by
*
128
+
128
],
B
[
bx
*
128
:
bx
*
128
+
128
,
by
*
128
:
by
*
128
+
128
])
T
.
copy
(
A
[
bx
*
128
:
bx
*
128
+
128
,
by
*
128
:
by
*
128
+
128
],
B
[
bx
*
128
:
bx
*
128
+
128
,
by
*
128
:
by
*
128
+
128
])
@
tilelang
.
lazy_jit
def
copy1
(
...
...
@@ -259,20 +247,19 @@ def test_jit2_many_annot():
copy_impl
(
A
,
B
)
for
copy
in
[
copy1
,
copy2
,
copy3
,
copy4
]:
A
=
torch
.
randn
(
128
,
128
,
device
=
'
cuda
'
)
B
=
torch
.
empty
(
128
,
128
,
device
=
'
cuda
'
)
A
=
torch
.
randn
(
128
,
128
,
device
=
"
cuda
"
)
B
=
torch
.
empty
(
128
,
128
,
device
=
"
cuda
"
)
copy
(
A
,
B
)
assert
torch
.
equal
(
B
,
A
)
for
copy
in
[
copy5
,
copy6
]:
A
=
torch
.
randn
(
128
,
2
,
128
,
2
,
device
=
'
cuda
'
)
B
=
torch
.
randn
(
128
,
2
,
128
,
2
,
device
=
'
cuda
'
)
A
=
torch
.
randn
(
128
,
2
,
128
,
2
,
device
=
"
cuda
"
)
B
=
torch
.
randn
(
128
,
2
,
128
,
2
,
device
=
"
cuda
"
)
copy
(
A
[:,
0
,
:,
0
],
B
[:,
0
,
:,
0
])
assert
torch
.
equal
(
A
[:,
0
,
:,
0
],
B
[:,
0
,
:,
0
])
def
test_jit2_return
():
@
T
.
macro
def
copy_impl
(
A
):
M
,
N
=
A
.
shape
...
...
@@ -283,8 +270,7 @@ def test_jit2_return():
assert
N
==
N_
,
f
"N mismatch
{
N
}
{
N_
}
"
# assert tuple(A.shape) == tuple(B.shape), f"Invalid tensor shape: {A.shape}, {B.shape}"
with
T
.
Kernel
(
T
.
ceildiv
(
M
,
128
),
T
.
ceildiv
(
N
,
128
),
threads
=
128
)
as
(
bx
,
by
):
T
.
copy
(
A
[
bx
*
128
:
bx
*
128
+
128
,
by
*
128
:
by
*
128
+
128
],
B
[
bx
*
128
:
bx
*
128
+
128
,
by
*
128
:
by
*
128
+
128
])
T
.
copy
(
A
[
bx
*
128
:
bx
*
128
+
128
,
by
*
128
:
by
*
128
+
128
],
B
[
bx
*
128
:
bx
*
128
+
128
,
by
*
128
:
by
*
128
+
128
])
return
B
@
tilelang
.
lazy_jit
...
...
@@ -292,41 +278,52 @@ def test_jit2_return():
return
copy_impl
(
A
)
@
tilelang
.
lazy_jit
def
copy1
(
A
:
T
.
Tensor
[[
int
,
int
],
T
.
float32
],):
def
copy1
(
A
:
T
.
Tensor
[[
int
,
int
],
T
.
float32
],
):
return
copy_impl
(
A
)
@
tilelang
.
lazy_jit
def
copy2
(
A
:
T
.
Tensor
[[
128
,
128
],
T
.
float32
],):
def
copy2
(
A
:
T
.
Tensor
[[
128
,
128
],
T
.
float32
],
):
return
copy_impl
(
A
)
@
tilelang
.
lazy_jit
def
copy3
(
A
:
T
.
Tensor
[[
int
,
128
],
T
.
float32
],):
def
copy3
(
A
:
T
.
Tensor
[[
int
,
128
],
T
.
float32
],
):
return
copy_impl
(
A
)
@
tilelang
.
lazy_jit
def
copy4
(
A
:
T
.
Tensor
[[
T
.
dyn
,
int
],
T
.
float32
],):
def
copy4
(
A
:
T
.
Tensor
[[
T
.
dyn
,
int
],
T
.
float32
],
):
return
copy_impl
(
A
)
@
tilelang
.
lazy_jit
def
copy5
(
A
:
T
.
StridedTensor
[[
int
,
int
],
[
int
,
int
],
T
.
float32
],):
def
copy5
(
A
:
T
.
StridedTensor
[[
int
,
int
],
[
int
,
int
],
T
.
float32
],
):
return
copy_impl
(
A
)
@
tilelang
.
lazy_jit
def
copy6
(
A
:
T
.
StridedTensor
[[
T
.
dyn
,
int
],
[
int
,
int
],
T
.
float32
],):
def
copy6
(
A
:
T
.
StridedTensor
[[
T
.
dyn
,
int
],
[
int
,
int
],
T
.
float32
],
):
return
copy_impl
(
A
)
for
copy
in
[
copy0
,
copy1
,
copy2
,
copy3
,
copy4
]:
A
=
torch
.
randn
(
128
,
128
,
device
=
'
cuda
'
)
A
=
torch
.
randn
(
128
,
128
,
device
=
"
cuda
"
)
B
=
copy
(
A
)
assert
torch
.
equal
(
B
,
A
)
for
copy
in
[
copy5
,
copy6
]:
A
=
torch
.
randn
(
128
,
2
,
128
,
2
,
device
=
'
cuda
'
)
A
=
torch
.
randn
(
128
,
2
,
128
,
2
,
device
=
"
cuda
"
)
B
=
copy
(
A
[:,
0
,
:,
0
])
assert
torch
.
equal
(
A
[:,
0
,
:,
0
],
B
)
def
test_jit2_deepseek_deepgemm
():
@
tilelang
.
lazy_jit
def
deep_gemm
(
A
:
T
.
Tensor
[[
int
,
int
],
T
.
float8_e4m3
],
...
...
@@ -351,13 +348,9 @@ def test_jit2_deepseek_deepgemm():
N
,
K
=
B
.
shape
C
=
T
.
empty
(
M
,
N
,
dtype
=
out_dtype
)
assert
out_dtype
in
[
T
.
bfloat16
,
T
.
float32
],
f
"Expect out_dtype to be one of [T.float16, T.float32], got
{
out_dtype
}
"
assert
scales_a
.
shape
==
[
M
,
T
.
ceildiv
(
K
,
group_size
)
],
f
"Expect scales_a shape to be f
{
[
M
,
T
.
ceildiv
(
K
,
group_size
)]
}
"
assert
scales_b
.
shape
==
[
N
,
T
.
ceildiv
(
K
,
group_size
)
],
f
"Expect scales_b shape to be f
{
[
N
,
T
.
ceildiv
(
K
,
group_size
)]
}
"
assert
out_dtype
in
[
T
.
bfloat16
,
T
.
float32
],
f
"Expect out_dtype to be one of [T.float16, T.float32], got
{
out_dtype
}
"
assert
scales_a
.
shape
==
[
M
,
T
.
ceildiv
(
K
,
group_size
)],
f
"Expect scales_a shape to be f
{
[
M
,
T
.
ceildiv
(
K
,
group_size
)]
}
"
assert
scales_b
.
shape
==
[
N
,
T
.
ceildiv
(
K
,
group_size
)],
f
"Expect scales_b shape to be f
{
[
N
,
T
.
ceildiv
(
K
,
group_size
)]
}
"
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
128
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
in_dtype
)
...
...
@@ -421,5 +414,5 @@ def test_jit2_deepseek_deepgemm():
# M, N, K = 1024, 1024, 8192
# A = torch.randn((M, K), dtype=torch.float8_e4m3fn, )
if
__name__
==
'
__main__
'
:
if
__name__
==
"
__main__
"
:
tilelang
.
testing
.
main
()
testing/python/language/test_tilelang_language_let.py
View file @
29051439
...
...
@@ -4,7 +4,6 @@ from tilelang import language as T
def
test_let_vectorize_load
():
@
T
.
prim_func
def
main
(
A_ptr
:
T
.
handle
):
A
=
T
.
match_buffer
(
A_ptr
,
(
16
,
16
),
dtype
=
"float32"
,
align
=
16
)
...
...
testing/python/language/test_tilelang_language_mask_op.py
View file @
29051439
...
...
@@ -6,11 +6,10 @@ import torch
# add decorator @tilelang.jit if you want to return a torch function
# @tilelang.jit
def
tilelang_copy_mask_parallel
(
M
,
N
,
block_M
,
block_N
,
dtype
=
"float16"
):
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
((
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
),
A
:
T
.
Tensor
((
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
),
):
# Initialize Kernel Context
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
256
)
as
(
bx
,
by
):
...
...
@@ -30,13 +29,8 @@ def tilelang_copy_mask_parallel(M, N, block_M, block_N, dtype="float16"):
def
run_tilelang_copy_mask_parallel
(
M
=
1024
,
N
=
1024
,
block_M
=
128
,
block_N
=
128
,
dtype
=
"float16"
):
program
=
tilelang_copy_mask_parallel
(
M
,
N
,
block_M
,
block_N
,
dtype
)
kernel
=
tilelang
.
compile
(
program
,
out_idx
=
[
1
],
target
=
"cuda"
,
pass_configs
=
{
"tl.disable_warp_specialized"
:
True
,
"tl.disable_tma_lower"
:
True
})
program
,
out_idx
=
[
1
],
target
=
"cuda"
,
pass_configs
=
{
"tl.disable_warp_specialized"
:
True
,
"tl.disable_tma_lower"
:
True
}
)
a
=
torch
.
randn
(
M
,
N
,
device
=
"cuda"
,
dtype
=
getattr
(
torch
,
dtype
))
b
=
kernel
(
a
)
torch
.
testing
.
assert_close
(
b
,
a
,
rtol
=
1e-2
,
atol
=
1e-2
)
...
...
@@ -49,11 +43,10 @@ def test_tilelang_copy_mask_parallel():
# add decorator @tilelang.jit if you want to return a torch function
# @tilelang.jit
def
tilelang_copy_mask_copy
(
M
,
N
,
block_M
,
block_N
,
dtype
=
"float16"
):
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
((
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
),
A
:
T
.
Tensor
((
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
),
):
# Initialize Kernel Context
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
256
)
as
(
bx
,
by
):
...
...
@@ -72,13 +65,8 @@ def tilelang_copy_mask_copy(M, N, block_M, block_N, dtype="float16"):
def
run_tilelang_copy_mask_copy
(
M
=
1024
,
N
=
1024
,
block_M
=
128
,
block_N
=
128
,
dtype
=
"float16"
):
program
=
tilelang_copy_mask_copy
(
M
,
N
,
block_M
,
block_N
,
dtype
)
kernel
=
tilelang
.
compile
(
program
,
out_idx
=
[
1
],
target
=
"cuda"
,
pass_configs
=
{
"tl.disable_warp_specialized"
:
True
,
"tl.disable_tma_lower"
:
True
})
program
,
out_idx
=
[
1
],
target
=
"cuda"
,
pass_configs
=
{
"tl.disable_warp_specialized"
:
True
,
"tl.disable_tma_lower"
:
True
}
)
a
=
torch
.
randn
(
M
,
N
,
device
=
"cuda"
,
dtype
=
getattr
(
torch
,
dtype
))
b
=
kernel
(
a
)
torch
.
testing
.
assert_close
(
b
,
a
,
rtol
=
1e-2
,
atol
=
1e-2
)
...
...
@@ -91,11 +79,10 @@ def test_tilelang_copy_mask_copy():
# add decorator @tilelang.jit if you want to return a torch function
# @tilelang.jit
def
tilelang_copy_mask_parallel_range
(
M
,
N
,
block_M
,
block_N
,
dtype
=
"float16"
):
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
((
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
),
A
:
T
.
Tensor
((
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
),
):
# Initialize Kernel Context
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
256
)
as
(
bx
,
by
):
...
...
@@ -112,20 +99,11 @@ def tilelang_copy_mask_parallel_range(M, N, block_M, block_N, dtype="float16"):
return
main
def
run_tilelang_copy_mask_parallel_range
(
M
=
1024
,
N
=
1024
,
block_M
=
128
,
block_N
=
128
,
dtype
=
"float16"
):
def
run_tilelang_copy_mask_parallel_range
(
M
=
1024
,
N
=
1024
,
block_M
=
128
,
block_N
=
128
,
dtype
=
"float16"
):
program
=
tilelang_copy_mask_parallel_range
(
M
,
N
,
block_M
,
block_N
,
dtype
)
kernel
=
tilelang
.
compile
(
program
,
out_idx
=
[
1
],
target
=
"cuda"
,
pass_configs
=
{
"tl.disable_warp_specialized"
:
True
,
"tl.disable_tma_lower"
:
True
})
program
,
out_idx
=
[
1
],
target
=
"cuda"
,
pass_configs
=
{
"tl.disable_warp_specialized"
:
True
,
"tl.disable_tma_lower"
:
True
}
)
a
=
torch
.
randn
(
M
,
N
,
device
=
"cuda"
,
dtype
=
getattr
(
torch
,
dtype
))
b
=
kernel
(
a
)
torch
.
testing
.
assert_close
(
b
,
a
,
rtol
=
1e-2
,
atol
=
1e-2
)
...
...
@@ -138,11 +116,10 @@ def test_tilelang_copy_mask_parallel_range():
# add decorator @tilelang.jit if you want to return a torch function
# @tilelang.jit
def
tilelang_copy_mask_copy_range
(
M
,
N
,
block_M
,
block_N
,
dtype
=
"float16"
):
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
((
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
),
A
:
T
.
Tensor
((
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
),
):
# Initialize Kernel Context
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
256
)
as
(
bx
,
by
):
...
...
@@ -161,13 +138,8 @@ def tilelang_copy_mask_copy_range(M, N, block_M, block_N, dtype="float16"):
def
run_tilelang_copy_mask_copy_range
(
M
=
1024
,
N
=
1024
,
block_M
=
128
,
block_N
=
128
,
dtype
=
"float16"
):
program
=
tilelang_copy_mask_copy_range
(
M
,
N
,
block_M
,
block_N
,
dtype
)
kernel
=
tilelang
.
compile
(
program
,
out_idx
=
[
1
],
target
=
"cuda"
,
pass_configs
=
{
"tl.disable_warp_specialized"
:
True
,
"tl.disable_tma_lower"
:
True
})
program
,
out_idx
=
[
1
],
target
=
"cuda"
,
pass_configs
=
{
"tl.disable_warp_specialized"
:
True
,
"tl.disable_tma_lower"
:
True
}
)
a
=
torch
.
randn
(
M
,
N
,
device
=
"cuda"
,
dtype
=
getattr
(
torch
,
dtype
))
b
=
kernel
(
a
)
torch
.
testing
.
assert_close
(
b
,
a
,
rtol
=
1e-2
,
atol
=
1e-2
)
...
...
Prev
1
…
10
11
12
13
14
15
16
17
18
…
22
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment