OpenDAS / tilelang · Commits

Commit bbbf4207 (unverified)
Authored by guchaoyang on Nov 14, 2025; committed via GitHub on Nov 14, 2025.

    Merge branch 'main' into dcu

Parents: 8f4628e0, 5eb30a4f

Changes: 286 files changed in the full commit; this page shows 20 changed files with 842 additions and 43 deletions (+842, -43).
Files changed on this page:

testing/python/language/test_tilelang_language_chain_equal.py              +0    -0
testing/python/language/test_tilelang_language_frontend_v2.py              +364  -0
testing/python/language/test_tilelang_language_get_warp_info.py            +0    -1
testing/python/language/test_tilelang_language_infinity.py                 +33   -0
testing/python/language/test_tilelang_language_let.py                      +23   -0
testing/python/language/test_tilelang_language_negative_index.py           +60   -0
testing/python/language/test_tilelang_language_reduce.py                   +0    -2
testing/python/language/test_tilelang_language_reshape.py                  +133  -0
testing/python/language/test_tilelang_language_tma_1d.py                   +16   -11
testing/python/language/test_tilelang_language_var_init.py                 +32   -0
testing/python/language/test_tilelang_language_vectorized_cast.py          +29   -3
testing/python/layout/test_tilelang_layout_fused_replicate.py              +63   -0
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py               +1    -0
testing/python/transform/test_tilelang_transform_inject_fence_proxy.py     +2    -2
testing/python/transform/test_tilelang_transform_layout_inference.py       +8    -6
testing/python/transform/test_tilelang_transform_lower_tile_op.py          +8    -6
testing/python/transform/test_tilelang_transform_multi_version_buffer.py   +2    -2
testing/python/transform/test_tilelang_transform_thread_sync.py            +36   -0
tilelang/__init__.py                                                       +30   -8
tilelang/_ffi_api.py                                                       +2    -2
testing/python/language/test_tilelang_laguange_chain_equal.py → testing/python/language/test_tilelang_language_chain_equal.py
File moved (no content changes).

testing/python/language/test_tilelang_language_frontend_v2.py (new file, mode 100644)
import tilelang
import tilelang.language as T
import torch
import tilelang.testing
import tvm
from tvm.script.ir_builder.base import IRBuilderFrame
from tvm.tir.expr import IntImm, Var


def test_argument():

    @T.prim_func
    def test_argument(
            t_1: T.bool,
            t_2: T.short,
            t_3: T.int,
            t_4: T.long,
            t_5: T.half,
            t_6: T.float,
            t_7: T.long,
            t_8: T.int8,
            t_9: T.int16,
            t_10: T.int32,
            t_11: T.int64,
            t_12: T.uint8,
            t_13: T.uint16,
            t_14: T.uint32,
            t_15: T.uint64,
            t_16: T.float8_e4m3fn,
            t_17: T.float8_e4m3fnuz,
            t_18: T.float8_e5m2,
            t_19: T.float8_e5m2fnuz,
            t_20: T.float8_e8m0fnu,
            t_21: T.float16,
            t_22: T.bfloat16,
            t_23: T.float32,
            t_24: T.float64,
    ):
        pass


def test_expr():
    from tilelang.language.v2.dtypes import _all_dtypes
    errors = []
    for name in _all_dtypes:
        dtype = getattr(T, name)
        assert isinstance(dtype, tvm.DataType), f"{dtype} is not tvm.DataType"
        try:
            dtype(1.0)
            dtype()
        except TypeError:
            pass
        except Exception:
            errors.append(name)
    assert not errors

# def test_var_decl_sugar():
# @T.prim_func
# def test_var_decl_sugar():
# with T.Kernel(128, 128) as (bx, by):
# var_1: T.bool = 1.0
# var_2: T.short = 1.0
# var_3: T.int = 1.0
# var_4: T.long = 1.0
# var_5: T.half = 1.0
# var_6: T.float = 1.0
# var_7: T.long = 1.0
# var_8: T.int8 = 1.0
# var_9: T.int16 = 1.0
# var_10: T.int32 = 1.0
# var_11: T.int64 = 1.0
# var_12: T.uint8 = 1.0
# var_13: T.uint16 = 1.0
# var_14: T.uint32 = 1.0
# var_15: T.uint64 = 1.0
# var_16: T.float8_e4m3fn = 1.0
# var_17: T.float8_e4m3fnuz = 1.0
# var_18: T.float8_e5m2 = 1.0
# var_19: T.float8_e5m2fnuz = 1.0
# var_20: T.float8_e8m0fnu = 1.0
# var_21: T.float16 = 1.0
# var_22: T.bfloat16 = 1.0
# var_23: T.float32 = 1.0
# var_24: T.float64 = 1.0
# var_1: T.bool = var_1
# var_2: T.short = var_2
# var_3: T.int = var_3
# var_4: T.long = var_4
# var_5: T.half = var_5
# var_6: T.float = var_6
# var_7: T.long = var_7
# var_8: T.int8 = var_8
# var_9: T.int16 = var_9
# var_10: T.int32 = var_10
# var_11: T.int64 = var_11
# var_12: T.uint8 = var_12
# var_13: T.uint16 = var_13
# var_14: T.uint32 = var_14
# var_15: T.uint64 = var_15
# var_16: T.float8_e4m3fn = var_16
# var_17: T.float8_e4m3fnuz = var_17
# var_18: T.float8_e5m2 = var_18
# var_19: T.float8_e5m2fnuz = var_19
# var_20: T.float8_e8m0fnu = var_20
# var_21: T.float16 = var_21
# var_22: T.bfloat16 = var_22
# var_23: T.float32 = var_23
# var_24: T.float64 = var_24
# s = test_var_decl_sugar.script()
# for i in range(1, 25):
# assert f'var_{i}_1' in s
# assert 'tl.local_var_init' in s
def test_dtype_str_repr():

    @T.prim_func
    def test_str_repr():
        buf_1 = T.alloc_buffer((1,), dtype=T.bool, scope='shared')  # noqa F841
        buf_2 = T.alloc_buffer((1,), dtype=T.short, scope='shared')  # noqa F841
        buf_3 = T.alloc_buffer((1,), dtype=T.int, scope='shared')  # noqa F841
        buf_4 = T.alloc_buffer((1,), dtype=T.long, scope='shared')  # noqa F841
        buf_5 = T.alloc_buffer((1,), dtype=T.half, scope='shared')  # noqa F841
        buf_6 = T.alloc_buffer((1,), dtype=T.float, scope='shared')  # noqa F841
        buf_7 = T.alloc_buffer((1,), dtype=T.long, scope='shared')  # noqa F841
        buf_8 = T.alloc_buffer((1,), dtype=T.int8, scope='shared')  # noqa F841
        buf_9 = T.alloc_buffer((1,), dtype=T.int16, scope='shared')  # noqa F841
        buf_10 = T.alloc_buffer((1,), dtype=T.int32, scope='shared')  # noqa F841
        buf_11 = T.alloc_buffer((1,), dtype=T.int64, scope='shared')  # noqa F841
        buf_12 = T.alloc_buffer((1,), dtype=T.uint8, scope='shared')  # noqa F841
        buf_13 = T.alloc_buffer((1,), dtype=T.uint16, scope='shared')  # noqa F841
        buf_14 = T.alloc_buffer((1,), dtype=T.uint32, scope='shared')  # noqa F841
        buf_15 = T.alloc_buffer((1,), dtype=T.uint64, scope='shared')  # noqa F841
        buf_16 = T.alloc_buffer((1,), dtype=T.float8_e4m3fn, scope='shared')  # noqa F841
        buf_17 = T.alloc_buffer((1,), dtype=T.float8_e4m3fnuz, scope='shared')  # noqa F841
        buf_18 = T.alloc_buffer((1,), dtype=T.float8_e5m2, scope='shared')  # noqa F841
        buf_19 = T.alloc_buffer((1,), dtype=T.float8_e5m2fnuz, scope='shared')  # noqa F841
        buf_20 = T.alloc_buffer((1,), dtype=T.float8_e8m0fnu, scope='shared')  # noqa F841
        buf_21 = T.alloc_buffer((1,), dtype=T.float16, scope='shared')  # noqa F841
        buf_22 = T.alloc_buffer((1,), dtype=T.bfloat16, scope='shared')  # noqa F841
        buf_23 = T.alloc_buffer((1,), dtype=T.float32, scope='shared')  # noqa F841
        buf_24 = T.alloc_buffer((1,), dtype=T.float64, scope='shared')  # noqa F841


def test_torch_eq():
    dtypes = [
        T.bool,
        T.short,
        T.int,
        T.long,
        T.half,
        T.float,
        T.long,
        T.int8,
        T.int16,
        T.int32,
        T.int64,
        T.uint8,
        T.uint16,
        T.uint32,
        T.uint64,
        T.float8_e4m3fn,
        T.float8_e4m3fnuz,
        T.float8_e5m2,
        T.float8_e5m2fnuz,
        T.float8_e8m0fnu,
        T.float16,
        T.bfloat16,
        T.float32,
        T.float64,
    ]
    torch_dtypes = [
        torch.bool,
        torch.short,
        torch.int,
        torch.long,
        torch.half,
        torch.float,
        torch.long,
        torch.int8,
        torch.int16,
        torch.int32,
        torch.int64,
        torch.uint8,
        torch.uint16,
        torch.uint32,
        torch.uint64,
        torch.float8_e4m3fn,
        torch.float8_e4m3fnuz,
        torch.float8_e5m2,
        torch.float8_e5m2fnuz,
        torch.float8_e8m0fnu,
        torch.float16,
        torch.bfloat16,
        torch.float32,
        torch.float64,
    ]
    for a, b in zip(dtypes, torch_dtypes):
        assert a == b, f"{a} and {b} are not equal"
        assert T.dtype(b) == a, "dtype conversion error"


def test_var_assign():

    @tilelang.jit(out_idx=-1)
    @T.prim_func
    def test_var_assign(A: T.Tensor((2,), T.int32)):
        with T.Kernel(1) as _:
            a: T.int32 = 1
            b: T.int32 = a
            a = 2
            d: T.int32 = a
            A[0] = b
            A[1] = d

    res = test_var_assign()()
    assert res[0] == 1
    assert res[1] == 2


def test_marco_return():

    @T.macro
    def macro_return_constant():
        return 0

    @T.macro
    def macro_return_frame(x):
        return T.alloc_var(T.float32, init=x)

    @T.macro
    def macro_return_expr(x):
        y = x + 1.0
        return y

    @T.macro
    def macro_apply_func(x, fn):
        return fn(x)

    def check(x, ty):
        assert isinstance(x, ty)

    @T.prim_func
    def test_macro_return():
        with T.Kernel(1) as _:
            a = macro_return_constant()
            b = macro_return_frame(3.0)
            c = macro_return_expr(4.0)
            d = macro_apply_func(5.0, lambda x: x * 2.0)
            check(a, (int, float, T.PrimExpr))
            check(b, T.PrimExpr)
            check(c, T.PrimExpr)
            check(d, T.PrimExpr)


def test_prim_func_generator():

    @T.prim_func(generator=True)
    def prim_func_gen(
            A=T.Tensor((128,), T.float32),  # noqa: B008
            B=T.Tensor((128,), T.float32),  # noqa: B008
    ):
        with T.Kernel(128) as (tx,):
            T.copy(A[tx], B[tx])

    prim_func_gen()

    @T.prim_func
    def foo() -> T.Tensor((128,), T.float32):
        pass

    assert isinstance(foo, T.PrimFunc)


def test_serial_for_with_step():

    @tilelang.jit(out_idx=-1)
    @T.prim_func
    def test_stepped_serial(A: T.Tensor((10,), T.int32)):
        with T.Kernel(1) as _:
            for i in range(0, 10, 2):
                T.device_assert(0 <= i < 10 and i % 2 == 0, "i out of range")
                A[i] = 1.0
            for i in range(1, 10, 2):
                T.device_assert(1 <= i < 10 and i % 2 == 1, "i out of range")
                A[i] = 2.0

    ker = test_stepped_serial()
    res = ker()
    ref = torch.tensor([1, 2, 1, 2, 1, 2, 1, 2, 1, 2], dtype=torch.int32, device='cuda')
    assert torch.all(res == ref), f"Expected {ref}, but got {res}"

    @tilelang.jit(out_idx=-1)
    @T.prim_func
    def test_serial_step_neg(A: T.Tensor((10,), T.int32)):
        with T.Kernel(1) as _:
            for i in range(10, 0, -1):
                T.device_assert(0 < i <= 10, "i out of range")
                A[10 - i] = i

    ker = test_serial_step_neg()
    res = ker()
    ref = torch.tensor([10, 9, 8, 7, 6, 5, 4, 3, 2, 1], dtype=torch.int32, device='cuda')
    assert torch.all(res == ref), f"Expected {ref}, but got {res}"

    assert isinstance(T.serial(1, 10, 1), IRBuilderFrame)
    assert isinstance(T.serial(1, 10, IntImm('int32', 1)), IRBuilderFrame)
    assert not isinstance(T.serial(1, 10, Var('tmp', 'int32')), IRBuilderFrame)
    assert not isinstance(T.serial(10, -1, -1), IRBuilderFrame)


def test_swap_logic():

    @tilelang.jit
    @T.prim_func
    def swap_var(A: T.Tensor[(2,), T.float32]):
        with T.Kernel(1, threads=1) as _:
            a = T.alloc_var(T.float32, A[0])
            b = T.alloc_var(T.float32, A[1])
            a, b = b, a
            A[0], A[1] = a, b

    @tilelang.jit
    @T.prim_func
    def swap_idx(A: T.Tensor[(2,), T.float32]):
        with T.Kernel(1, threads=1) as _:
            A[0], A[1] = A[1], A[0]

    k_swap_var = swap_var()
    data = torch.tensor([1.0, 2.0], dtype=torch.float32).cuda()
    k_swap_var(data)
    ref = torch.tensor([2.0, 1.0], dtype=torch.float32).cuda()
    torch.testing.assert_close(data, ref)

    k_swap_idx = swap_idx()
    data = torch.tensor([1.0, 2.0], dtype=torch.float32).cuda()
    k_swap_idx(data)
    ref = torch.tensor([2.0, 1.0], dtype=torch.float32).cuda()
    torch.testing.assert_close(data, ref)


def test_while_loop():

    @tilelang.jit(out_idx=-1)
    @T.prim_func
    def test_while_loop(A: T.Tensor((1,), T.int32)):
        with T.Kernel(1) as _:
            i = T.alloc_var(T.int32, 0)
            sum = T.alloc_var(T.int32)
            while i < 10:
                sum += i
                i += 1
            A[0] = sum

    ker = test_while_loop()
    A = ker()
    assert A[0].item() == sum(range(10)), f"Expected {sum(range(10))}, but got {A[0].item()}"


if __name__ == '__main__':
    tilelang.testing.main()
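
Note: the frontend v2 tests above exercise torch-style dtype aliases exposed directly on tilelang.language and their conversion via T.dtype. A minimal sketch of that correspondence, reusing only aliases and calls shown in the test file above (treat the exact alias surface as an assumption taken from this diff, not from separate documentation):

import tilelang.language as T
import torch

# Aliases compared in test_torch_eq are equal to their torch counterparts,
# and T.dtype() maps a torch dtype back to the tilelang-side dtype object.
assert T.float16 == torch.float16
assert T.bfloat16 == torch.bfloat16
assert T.dtype(torch.int32) == T.int32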
testing/python/language/test_tilelang_language_get_warp_info.py

@@ -209,4 +209,3 @@ def test_shuffle_elect_block_leader():

 if __name__ == "__main__":
     tilelang.testing.main()
-    # run_get_lane_id()
testing/python/language/test_tilelang_language_infinity.py (new file, mode 100644)

import torch
import tilelang
import tilelang.language as T


@tilelang.jit(out_idx=-1)
def get_inf_kernel(dtype: str):

    @T.prim_func
    def main(A: T.Tensor((32,), dtype)):
        with T.Kernel(1, threads=32):
            T.fill(A, T.infinity(dtype))

    return main


def _test_infinity(dtype: str):
    kernel = get_inf_kernel(dtype)
    output = kernel()
    assert torch.all(output == torch.inf), f'check failed for {dtype=}'


@tilelang.testing.requires_cuda
def test_infinity():
    _test_infinity("float16")
    _test_infinity("bfloat16")
    _test_infinity("float32")
    _test_infinity("float64")


if __name__ == "__main__":
    tilelang.testing.main()
testing/python/language/test_tilelang_language_let.py (new file, mode 100644)

import tilelang.testing
from tilelang import tvm as tvm
from tilelang import language as T


def test_let_vectorize_load():

    @T.prim_func
    def main(A_ptr: T.handle):
        A = T.match_buffer(A_ptr, (16, 16), dtype="float32", align=16)
        for _blockIdx in T.thread_binding(1, thread="blockIdx.x"):
            for _threadIdx in T.thread_binding(128, thread="threadIdx.x"):
                b = A[0, 0:4]
                A[0, 4:8] = b

    mod = tvm.IRModule({"main": main})
    mod = tvm.compile(mod, target="cuda")
    assert "float4 b" in mod.mod.imports[0].inspect_source()


if __name__ == "__main__":
    tilelang.testing.main()
testing/python/language/test_tilelang_language_negative_index.py (new file, mode 100644)

from tilelang import tvm
import tilelang as tl
import tilelang.testing
from tvm.script import tir as T


@T.prim_func
def negative_index_before(A: T.Buffer((16,), "float32"), B: T.Buffer((16,), "float32")):
    T.func_attr({"tir.noalias": True})
    B[0] = A[T.int32(-1)]


@T.prim_func
def negative_index_expected(A: T.Buffer((16,), "float32"), B: T.Buffer((16,), "float32")):
    T.func_attr({"tir.noalias": True})
    B[0] = A[T.int32(15)]


@T.prim_func
def negative_index_loop_before(A: T.Buffer((16,), "float32"), B: T.Buffer((4,), "float32")):
    T.func_attr({"tir.noalias": True})
    for i in T.serial(4):
        B[i] = A[-i - 1]


@T.prim_func
def negative_index_loop_expected(A: T.Buffer((16,), "float32"), B: T.Buffer((4,), "float32")):
    T.func_attr({"tir.noalias": True})
    for i in T.serial(4):
        B[i] = A[15 - i]


@T.prim_func
def negative_index_symbolic_before(shift: T.int32, A: T.Buffer((16,), "float32"),
                                   B: T.Buffer((16,), "float32")):
    T.func_attr({"tir.noalias": True})
    for i in T.serial(16):
        B[i] = A[shift + i]


def test_legalize_negative_index_scalar():
    mod = tvm.IRModule({"main": negative_index_before})
    transformed = tl.transform.LegalizeNegativeIndex()(mod)
    tvm.ir.assert_structural_equal(transformed["main"].body, negative_index_expected.body)


def test_legalize_negative_index_affine_expr():
    mod = tvm.IRModule({"main": negative_index_loop_before})
    transformed = tl.transform.LegalizeNegativeIndex()(mod)
    tvm.ir.assert_structural_equal(transformed["main"].body, negative_index_loop_expected.body)


def test_legalize_negative_index_symbolic_passthrough():
    mod = tvm.IRModule({"main": negative_index_symbolic_before})
    transformed = tl.transform.LegalizeNegativeIndex()(mod)
    tvm.ir.assert_structural_equal(transformed["main"].body, negative_index_symbolic_before.body)


if __name__ == "__main__":
    tilelang.testing.main()
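
Note: the pass exercised above rewrites constant (and affine) negative buffer indices into their wrapped positive form while leaving symbolic indices untouched. A minimal usage sketch built from the same objects defined in this file (the pass and prim_func names come from the tests above, not from separate documentation):

# Assumes the prim_funcs above are in scope; LegalizeNegativeIndex is the
# tilelang.transform pass exercised by the tests in this file.
mod = tvm.IRModule({"main": negative_index_loop_before})
legalized = tl.transform.LegalizeNegativeIndex()(mod)
print(legalized)  # B[i] = A[-i - 1] should now read B[i] = A[15 - i]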
testing/python/language/test_tilelang_language_reduce.py

@@ -116,7 +116,6 @@ def test_reduce_sum():

 def test_reduce_sum_shared():
     run_reduce_sum(64, 64, mode="ss")
-    run_reduce_sum(32, 96, mode="ss")


 def test_reduce_max():
@@ -127,7 +126,6 @@ def test_reduce_max():

 def test_reduce_max_shared():
     run_shared_reduce(reduce_max_ss, lambda A: A.max(dim=1).values, 64, 64, "float32")
-    run_shared_reduce(reduce_max_ss, lambda A: A.max(dim=1).values, 96, 48, "float32")


 def test_reduce_min_shared():
testing/python/language/test_tilelang_language_reshape.py

 from tilelang import tvm as tvm
 import tilelang.testing
 import tilelang as tl
+import torch


 def reshape_test(N, M, dtype):
...
@@ -129,5 +130,137 @@ def test_reshape_smem_2d_2_1d():
     run_reshape_smem_2d_2_1d(2048, 64, "float16")


+def reshape_fragment_test(N, M, dtype):
+    import tilelang.language as T
+
+    @T.prim_func
+    def main(
+            A: T.Tensor((N // M, M), dtype),
+            B: T.Tensor((N,), dtype),
+    ):
+        with T.Kernel(1, threads=32) as _:
+            A_shared = T.alloc_shared((N // M, M), dtype, scope="shared")
+            A_local = T.alloc_fragment((N // M, M), dtype)
+            B_shared = T.alloc_shared((N,), dtype, scope="shared")
+            T.copy(A, A_shared)
+            T.copy(A_shared, A_local)
+            A_local_reshape = T.reshape(A_local, [N])
+            T.copy(A_local_reshape, B_shared)
+            T.copy(B_shared, B)
+
+    return main
+
+
+def run_reshape_fragment(N, M, dtype):
+    program = reshape_fragment_test(N, M, dtype)
+    jit_kernel = tl.compile(
+        program,
+        out_idx=-1,
+        pass_configs={
+            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
+            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
+        })
+    profiler = jit_kernel.get_profiler()
+
+    def ref_program(A):
+        return A.reshape(N)
+
+    profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
+
+
+def test_reshape_fragment():
+    run_reshape_fragment(1024, 32, "float32")
+    run_reshape_fragment(2048, 64, "float16")
+
+
+def reshape_layout_transform_shared(N, M, dtype):
+    import tilelang.language as T
+    from tilelang.intrinsics.mma_layout import make_mma_swizzle_layout
+
+    @T.prim_func
+    def main(
+            A: T.Tensor((N // M, M), dtype),
+            B: T.Tensor((N,), dtype),
+    ):
+        with T.Kernel(1, threads=32) as _:
+            A_shared = T.alloc_shared((N // M, M), dtype, scope="shared")
+            T.annotate_layout({
+                A_shared: make_mma_swizzle_layout(A_shared),
+            })
+            T.copy(A, A_shared)
+            A_shared_reshape = T.reshape(A_shared, [N])
+            T.copy(A_shared_reshape, B)
+
+    return main
+
+
+def run_reshape_layout_transform_shared(N, M, dtype):
+    program = reshape_layout_transform_shared(N, M, dtype)
+    jit_kernel = tl.compile(
+        program,
+        out_idx=-1,
+        pass_configs={
+            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
+            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
+        })
+    profiler = jit_kernel.get_profiler()
+
+    def ref_program(A):
+        return A.reshape(N)
+
+    profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
+
+
+def test_reshape_layout_transform_shared():
+    run_reshape_layout_transform_shared(1024, 32, "float32")
+    run_reshape_layout_transform_shared(2048, 64, "float16")
+
+
+def reduce_after_reshape_test(N, M, dtype):
+    import tilelang.language as T
+
+    @T.prim_func
+    def main(
+            A: T.Tensor((N,), dtype),
+            B: T.Tensor((N // M,), dtype),
+    ):
+        with T.Kernel(1, threads=32) as _:
+            A_shared = T.alloc_shared((N,), dtype, scope="shared")
+            A_local = T.alloc_fragment((N,), dtype)
+            B_local = T.alloc_fragment((N // M,), dtype)
+            T.copy(A, A_shared)
+            T.copy(A_shared, A_local)
+            A_local_reshape = T.reshape(A_local, [N // M, M])
+            T.reduce_max(A_local_reshape, B_local, dim=1)
+            T.copy(B_local, B)
+
+    return main
+
+
+def run_reduce_after_reshape(N, M, dtype):
+    program = reduce_after_reshape_test(N, M, dtype)
+    jit_kernel = tl.compile(
+        program,
+        out_idx=-1,
+        pass_configs={
+            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
+            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
+        })
+    profiler = jit_kernel.get_profiler()
+
+    def ref_program(A):
+        return torch.max(A.reshape(N // M, M), dim=1).values
+
+    profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
+
+
+def test_reduce_after_reshape():
+    run_reduce_after_reshape(1024, 32, "float32")
+    run_reduce_after_reshape(2048, 64, "float16")
+
+
 if __name__ == "__main__":
     tilelang.testing.main()
examples/elementwise/example_elementwise_add_tma_1d.py → testing/python/language/test_tilelang_language_tma_1d.py

-import argparse
-import torch
 import tilelang
 import tilelang.language as T
+import torch


 def ref_program(x, y):
...
@@ -30,23 +29,29 @@ def elementwise_add(M, N, block_M, block_N, in_dtype, out_dtype, threads):
     return elem_add


-def main():
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--m", type=int, default=128)
-    parser.add_argument("--n", type=int, default=128)
-    args, _ = parser.parse_known_args()
-    M, N = args.m, args.n
+def run_elementwise_add(M, N):
     a = torch.randn(M, N, dtype=torch.float32, device="cuda")
     b = torch.randn(M, N, dtype=torch.float32, device="cuda")

     # Default config
-    config = {"block_M": 128, "block_N": 128, "threads": 128}
+    block_M, block_N = 128, 128
+    config = {"block_M": block_M, "block_N": block_N, "threads": 128}
     kernel = elementwise_add(M, N, **config, in_dtype="float32", out_dtype="float32")
     out = kernel(a, b)
     torch.testing.assert_close(out, ref_program(a, b), rtol=1e-2, atol=1e-2)
-    print("All passed!")
+    code = kernel.get_kernel_source()
+    if block_N == N:
+        assert "tma_load" in code and "CUtensorMap" not in code
+    else:
+        assert "tma_load" in code and "CUtensorMap" in code
+
+
+def main():
+    run_elementwise_add(128, 128)
+    run_elementwise_add(256, 128)
+    run_elementwise_add(256, 256)


 if __name__ == "__main__":
...
testing/python/language/test_tilelang_language_var_init.py (new file, mode 100644)

import tilelang
import tilelang.language as T
import tilelang.testing


def test_var_assign() -> None:

    @tilelang.jit(out_idx=-1)
    def jit_kernel():

        @T.prim_func
        def test_var_assign(A: T.Tensor((2,), 'int32')):
            with T.Kernel(1) as _:
                a = T.alloc_var('int32', init=1)
                b = T.alloc_var('int32', init=a)  # b gets value of a
                a = 2
                d = T.alloc_var('int32', init=a)  # c gets new value of a
                A[0] = b
                A[1] = d

        print(test_var_assign)
        return test_var_assign

    kernel = jit_kernel()
    print(kernel.get_kernel_source())
    res = kernel()
    assert res[0] == 1
    assert res[1] == 2


if __name__ == '__main__':
    tilelang.testing.main()
testing/python/language/test_tilelang_language_vectorized_cast.py

@@ -17,8 +17,8 @@ def vectorized_cast_kernel(M: int, dtype_A: str, dtype_B: str):

     @T.prim_func
     def main(
-            A: T.Tensor[(M), dtype_A],  # noqa: F821
-            B: T.Tensor[(M), dtype_B],  # noqa: F821
+            A: T.Tensor[(M,), dtype_A],  # noqa: F821
+            B: T.Tensor[(M,), dtype_B],  # noqa: F821
     ):
         with T.Kernel(1, threads=128):
             T.copy(A, B)
@@ -26,6 +26,27 @@ def vectorized_cast_kernel(M: int, dtype_A: str, dtype_B: str):
     return main


+@tilelang.jit
+def parallel_vectorized_cast_kernel(M: int, dtype_A: str, dtype_B: str):
+    assert M % 256 == 0
+
+    @T.prim_func
+    def main(
+            A: T.Tensor[(M,), dtype_A],  # noqa: F821
+            B: T.Tensor[(M,), dtype_B],  # noqa: F821
+    ):
+        with T.Kernel(1, threads=128):
+            A_local = T.alloc_fragment((M,), dtype_A)
+            B_local = T.alloc_fragment((M,), dtype_B)
+            T.copy(A, A_local)
+            for i in T.Parallel(M):
+                B_local[i] = A_local[i]
+            T.copy(B_local, B)
+
+    return main
+
+
 def run_vectorized_cast(src_dtype_str: str, dst_dtype_str: str, check_str: str, lanes: int = 2):
     """Run the vectorized cast kernel and check the correctness.

     Args:
@@ -37,17 +58,22 @@ def run_vectorized_cast(src_dtype_str: str, dst_dtype_str: str, check_str: str,
     M = 128 * lanes
     kernel = vectorized_cast_kernel(M, src_dtype_str, dst_dtype_str)
+    kernel_parallel = parallel_vectorized_cast_kernel(M, src_dtype_str, dst_dtype_str)
     A = torch.randn(M, dtype=str2dtype[src_dtype_str]).cuda()
     B = torch.zeros(M, dtype=str2dtype[dst_dtype_str]).cuda()
+    C = torch.zeros(M, dtype=str2dtype[dst_dtype_str]).cuda()
     kernel(A, B)
+    kernel_parallel(A, C)
     torch.testing.assert_close(A.to(str2dtype[dst_dtype_str]), B)
+    torch.testing.assert_close(A.to(str2dtype[dst_dtype_str]), C)
     code = kernel.get_kernel_source()
+    code_parallel = kernel_parallel.get_kernel_source()
-    assert check_str in code, \
-        f"Cast {src_dtype_str} to {dst_dtype_str} with {lanes=} is not vectorized!"
+    assert check_str in code and check_str in code_parallel, \
+        f"Cast {src_dtype_str} to {dst_dtype_str} with {lanes=} is not vectorized!"
...
testing/python/layout/test_tilelang_layout_fused_replicate.py (new file, mode 100644)

import pytest
import torch

import tilelang
import tilelang.testing
import tilelang.language as T

tilelang.testing.set_random_seed()

VEC_SIZE = 32


@tilelang.jit
def fused_index_kernel(B: int, M: int, N: int, BLOCK_MN: int, BLOCK_K: int):

    @T.prim_func
    def main(
            a: T.Buffer((B, M, N), "bfloat16"),
            a_out: T.Buffer((B, M, N), "float32"),
    ):
        with T.Kernel(
                T.ceildiv(M, BLOCK_MN),
                T.ceildiv(N, BLOCK_K),
                B,
                threads=128,
        ) as (pid_m, pid_n, pid_b):
            a_fp32_local = T.alloc_fragment((BLOCK_MN * BLOCK_K // VEC_SIZE, VEC_SIZE), "float32")
            offs_m = pid_m * BLOCK_MN
            offs_n = pid_n * BLOCK_K
            for i, j in T.Parallel(BLOCK_MN, BLOCK_K):
                idx = i * BLOCK_K + j
                a_out[pid_b, offs_m + i, offs_n + j] = a_fp32_local[idx // VEC_SIZE, idx % VEC_SIZE]

    return main


def _require_cuda_tensor(shape, dtype):
    if not torch.cuda.is_available():
        pytest.skip("CUDA not available")
    try:
        return torch.randn(*shape, device="cuda", dtype=dtype)
    except RuntimeError as err:
        pytest.skip(f"CUDA runtime unavailable: {err}")


def test_layout_infer_compiles_and_runs():
    B, M, N = 1, 32, 64
    BLOCK_MN, BLOCK_K = 32, 64
    kernel = fused_index_kernel(B, M, N, BLOCK_MN, BLOCK_K)
    a = _require_cuda_tensor((B, M, N), torch.bfloat16)
    a_out = torch.empty((B, M, N), dtype=torch.float32, device=a.device)
    # Ensure kernel compiles and executes without layout inversion failure
    kernel(a, a_out)
    assert a_out.shape == a.shape
    assert a_out.dtype == torch.float32


if __name__ == "__main__":
    tilelang.testing.main()
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py

@@ -397,6 +397,7 @@ def test_gemm_sr():
     run_gemm_sr(128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 32, 2)
     # float32 tests
+    # TODO(lei): fix in future
     run_gemm_sr(128, 128, 128, False, False, "float", "float", "float32", 128, 128, 32, 2)
     run_gemm_sr(128, 128, 128, False, True, "float", "float", "float32", 128, 128, 32, 2)
     run_gemm_sr(128, 128, 128, True, False, "float", "float", "float32", 128, 128, 32, 2)
...
testing/python/transform/test_tilelang_transform_inject_fence_proxy.py

@@ -159,8 +159,8 @@ def test_wgmma_marked_async():

     def before():
         with T.Kernel(1):
             A_shared = T.decl_buffer((1,), "float16", scope="shared")
-            desc_a = T.decl_buffer((1,), "uint64", scope="local.descriptor")
-            desc_b = T.decl_buffer((1,), "uint64", scope="local.descriptor")
+            desc_a = T.decl_buffer((1,), "uint64", scope="local.descriptor.wgmma")
+            desc_b = T.decl_buffer((1,), "uint64", scope="local.descriptor.wgmma")
             C_local = T.decl_buffer((32,), "float16", scope="local")
             A_shared[0] = T.float16(0)
             T.warpgroup_arrive()
...
testing/python/transform/test_tilelang_transform_layout_inference.py

@@ -15,8 +15,7 @@ def test_loop_tail_split(block_M, block_N, block_K, threads, vec_load_b, dtype):
     N = tvm.te.var("n")
     K = tvm.te.var("k")

-    @tvm.script.ir.ir_module
-    class Before:
+    def before():

         @T.prim_func
         def main(B: T.Tensor((K, N), dtype),):
@@ -38,8 +37,9 @@ def test_loop_tail_split(block_M, block_N, block_K, threads, vec_load_b, dtype):
                               (block_N // vec_load_b) * (block_N // vec_load_b) + vec],
                           T.float16(0))

-    @tvm.script.ir.ir_module
-    class After:
+        return tvm.IRModule({'main': main})
+
+    def after():

         @T.prim_func
         def main(B: T.Tensor((K, N), dtype),):
@@ -77,11 +77,13 @@ def test_loop_tail_split(block_M, block_N, block_K, threads, vec_load_b, dtype):
                               bx * block_N + t % (block_N // vec_load_b) *
                               (block_N // vec_load_b) + vec], T.float16(0))

+        return tvm.IRModule({'main': main})
+
     with tvm.target.Target(auto_target):
-        mod = tvm.tir.transform.BindTarget(auto_target)(Before)
+        mod = tvm.tir.transform.BindTarget(auto_target)(before())
         mod = tl.transform.LayoutInference()(mod)
         mod = tvm.tir.transform.Simplify()(mod)
-        ref_mod = tvm.tir.transform.BindTarget(auto_target)(After)
+        ref_mod = tvm.tir.transform.BindTarget(auto_target)(after())
         ref_mod = tvm.tir.transform.Simplify()(ref_mod)
     # Note(tzj): The structures are equal except one more "for" loop after the LayoutInference pass
     # This loop is "for vec in T.parallel(1)",
...
testing/python/transform/test_tilelang_transform_lower_tile_op.py

@@ -15,8 +15,7 @@ def test_loop_tail_split(block_M, block_N, block_K, threads, vec_load_b, dtype):
     N = tvm.te.var("n")
     K = tvm.te.var("k")

-    @tvm.script.ir.ir_module
-    class Before:
+    def before():

         @T.prim_func
         def main(B: T.Tensor((K, N), dtype),):
@@ -25,8 +24,9 @@ def test_loop_tail_split(block_M, block_N, block_K, threads, vec_load_b, dtype):
             for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                 T.copy(B[k * block_K, bx * block_N], B_shared)

-    @tvm.script.ir.ir_module
-    class After:
+        return tvm.IRModule({'main': main})
+
+    def after():

         @T.prim_func
         def main(B: T.Tensor((K, N), dtype),):
@@ -64,11 +64,13 @@ def test_loop_tail_split(block_M, block_N, block_K, threads, vec_load_b, dtype):
                               bx * block_N + t % (block_N // vec_load_b) *
                               (block_N // vec_load_b) + vec], T.float16(0))

+        return tvm.IRModule({'main': main})
+
     with tvm.transform.PassContext():
-        mod = tvm.tir.transform.BindTarget(auto_target)(Before)
+        mod = tvm.tir.transform.BindTarget(auto_target)(before())
         mod = tl.transform.LowerTileOp()(mod)
         mod = tvm.tir.transform.Simplify()(mod)
-        ref_mod = tvm.tir.transform.BindTarget(auto_target)(After)
+        ref_mod = tvm.tir.transform.BindTarget(auto_target)(after())
         ref_mod = tvm.tir.transform.Simplify()(ref_mod)
     # Note(tzj): The structures are equal except the argument in "T.reads" function.
     # The difference is just between the first index and the indices range, which is totally equivalent
...
testing/python/transform/test_tilelang_transform_multi_version_buffer.py

@@ -113,7 +113,7 @@ def test_multi_version_buffer_with_let():
         shared = T.alloc_buffer((8,), "float32", scope="shared.dyn")
         accum = T.alloc_buffer((8,), "float32", scope="local")
         for k in T.serial(4, annotations={"num_stages": T.int32(2)}):
-            value: T.float32 = scales[k]
+            value = scales[k]
             for i in T.serial(8):
                 shared[i] = value
             for i in T.serial(8):
@@ -125,7 +125,7 @@ def test_multi_version_buffer_with_let():
         shared = T.alloc_buffer((2, 8), "float32", scope="shared.dyn")
         accum = T.alloc_buffer((8,), "float32", scope="local")
         for k in T.serial(4, annotations={"num_stages": T.int32(2)}):
-            value: T.float32 = scales[k]
+            value = scales[k]
             for i in T.serial(8):
                 shared[k % 2, i] = value
             for i in T.serial(8):
...
testing/python/transform/test_tilelang_transform_thread_sync.py

@@ -188,5 +188,41 @@ def test_sync_let_stmt():
     tvm.ir.assert_structural_equal(mod["main"], expected)


+@tilelang.testing.requires_cuda
+def test_sync_shared_dyn_stmatrix_loop_hoist():
+
+    @T.prim_func
+    def func():
+        buf_dyn_shmem = T.alloc_buffer((98304,), "uint8", scope="shared.dyn")
+        tx = T.launch_thread("threadIdx.x", 384)
+        for i in T.unroll(8):
+            off = (
+                i // 4 * 8192 + tx // 32 * 1024 + tx % 16 * 64 +
+                (tx % 8 // 4 + i % 4 // 2) % 2 * 32 +
+                (tx % 4 // 2 + i % 2) % 2 * 16 +
+                (tx % 32 // 16 + tx % 2) % 2 * 8)
+            T.evaluate(
+                T.call_intrin(
+                    "handle",
+                    tvm.tir.op.Op.get("tl.ptx_stmatrix"),
+                    T.int32(0),
+                    T.int32(4),
+                    T.tvm_access_ptr(
+                        T.type_annotation("uint8"),
+                        buf_dyn_shmem.data,
+                        off,
+                        98304 - off,
+                        2,
+                    ),
+                    T.int32(2),
+                ))
+
+    mod = tvm.IRModule({"main": func})
+    mod = tilelang.transform.ThreadSync("shared.dyn")(mod)
+    s = str(mod)
+    assert 'T.tvm_storage_sync("shared.dyn")' in s
+    # Ensure the sync appears before the unrolled loop
+    assert s.index('T.tvm_storage_sync("shared.dyn")') < s.index("for i in T.unroll(8)")
+
+
 if __name__ == "__main__":
     tilelang.testing.main()
tilelang/__init__.py

@@ -4,24 +4,46 @@ import ctypes
 import logging
 import warnings
-from tqdm import tqdm
-from importlib.metadata import PackageNotFoundError, version
+from pathlib import Path
+from tqdm.auto import tqdm

-try:
-    __version__ = version('tilelang')
-except PackageNotFoundError:
+
+def _compute_version() -> str:
+    """Return the package version without being polluted by unrelated installs.
+
+    Preference order:
+    1) If running from a source checkout (VERSION file present at repo root),
+       use the dynamic version from version_provider (falls back to plain VERSION).
+    2) Otherwise, use importlib.metadata for the installed distribution.
+    3) As a last resort, return a dev sentinel.
+    """
     try:
-        from version_provider import dynamic_metadata
-        __version__ = dynamic_metadata('version')
+        repo_root = Path(__file__).resolve().parent.parent
+        version_file = repo_root / "VERSION"
+        if version_file.is_file():
+            try:
+                from version_provider import dynamic_metadata  # type: ignore
+                return dynamic_metadata("version")
+            except Exception:
+                # Fall back to the raw VERSION file if provider isn't available.
+                return version_file.read_text().strip()
+    except Exception:
+        # If any of the above fails, fall through to installed metadata.
+        pass
+
+    try:
+        from importlib.metadata import version as _dist_version  # py3.8+
+        return _dist_version("tilelang")
     except Exception as exc:
         warnings.warn(
             f"tilelang version metadata unavailable ({exc!r}); using development version.",
             RuntimeWarning,
             stacklevel=2,
         )
-        __version__ = "0.0.dev0"
+        return "0.0.dev0"
+
+
+__version__ = _compute_version()


 class TqdmLoggingHandler(logging.Handler):
...
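
Note: the new _compute_version helper above prefers a source checkout's VERSION file over installed-package metadata, and only then falls back to a dev sentinel. As a rough, standard-library-only illustration of that fallback order (the helper name below is hypothetical and this is not the code in this commit):

from importlib.metadata import PackageNotFoundError, version
from pathlib import Path


def resolve_version(package: str, repo_root: Path) -> str:
    version_file = repo_root / "VERSION"
    if version_file.is_file():
        # Source checkout: trust the repo's VERSION file over any installed copy.
        return version_file.read_text().strip()
    try:
        # Installed distribution: ask importlib.metadata.
        return version(package)
    except PackageNotFoundError:
        # Neither a checkout nor an installed package: fall back to a sentinel.
        return "0.0.dev0"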
tilelang/_ffi_api.py

 """FFI APIs for tilelang"""
-import tvm.ffi
+import tvm_ffi

 # TVM_REGISTER_GLOBAL("tl.name").set_body_typed(func);
-tvm.ffi._init_api("tl", __name__)  # pylint: disable=protected-access
+tvm_ffi.init_ffi_api("tl", __name__)