OpenDAS / tilelang · Commits

Commit 29051439 (Unverified)
Authored Dec 12, 2025 by Lei Wang; committed by GitHub, Dec 12, 2025
Parent: e84b24bc

[Lint] Phaseout Yapf format and embrace ruff format (#1417)

Changes: 467 · Showing 20 changed files with 239 additions and 296 deletions (+239 -296)
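The hunks below follow a few mechanical patterns: single-quoted strings become double-quoted, redundant parentheses around a single import or a one-element `__slots__` are dropped, implicitly concatenated message fragments are merged where they fit on one line, and long argument lists are split one argument per line and closed with a trailing comma. A small runnable before/after sketch of the dominant pattern (the function names here are illustrative stand-ins, not tilelang APIs):

    # Illustrative stand-ins; these are not tilelang functions.
    def make_descriptor(*args):
        return args

    def access_ptr(dtype, offset, extent):
        return (dtype, offset, extent)

    def launch(*args):
        return args

    k = 3

    # Yapf-era style seen on the old side of the hunks: packed arguments, single quotes.
    old_style = launch(make_descriptor(6, 2, 512, 512), 0, access_ptr('float16', k % 3 * 2048, 2048), k * 32)

    # Ruff-format style on the new side: double quotes and, once a call is split,
    # one argument per line closed with a trailing comma (the "magic trailing comma").
    new_style = launch(
        make_descriptor(6, 2, 512, 512),
        0,
        access_ptr("float16", k % 3 * 2048, 2048),
        k * 32,
    )

    assert old_style == new_style  # only the formatting differs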
Changed files on this page:

    testing/python/transform/test_tilelang_transform_warp_specialized.py   +37 -35
    testing/python/utils/test_compress_utils.py                             +1  -1
    testing/python/webgpu/test_webgpu_codegen.py                            +3  -4
    tilelang/__init__.py                                                    +2  -0
    tilelang/analysis/fragment_loop_checker.py                              +9  -9
    tilelang/analysis/layout_visual.py                                      +1  -5
    tilelang/analysis/nested_loop_checker.py                                +8  -15
    tilelang/autotuner/capture.py                                           +1  -2
    tilelang/autotuner/param.py                                             +33 -40
    tilelang/autotuner/tuner.py                                             +77 -94
    tilelang/cache/__init__.py                                              +9  -6
    tilelang/cache/kernel_cache.py                                          +19 -35
    tilelang/carver/__init__.py                                             +1  -0
    tilelang/carver/analysis.py                                             +10 -12
    tilelang/carver/arch/__init__.py                                        +14 -14
    tilelang/carver/arch/arch_base.py                                       +2  -6
    tilelang/carver/arch/cdna.py                                            +2  -3
    tilelang/carver/arch/cpu.py                                             +2  -3
    tilelang/carver/arch/cuda.py                                            +7  -10
    tilelang/carver/arch/driver/cuda_driver.py                              +1  -2
testing/python/transform/test_tilelang_transform_warp_specialized.py · View file @ 29051439

@@ -32,7 +32,6 @@ block_K = 32
def test_warp_specialized():
    @T.prim_func
    def before(A: T.Tensor((M, K), dtype), B: T.Tensor((K, N), dtype)):
        bx = T.launch_thread("blockIdx.x", 8)

@@ -47,25 +46,27 @@ def test_warp_specialized():
(both sides of this hunk contain the same calls; the change is only line wrapping, shown here in the post-change form)
            for k in T.serial(16, annotations={"num_stages": T.int32(3)}):
                if v == 0:
                    T.tma_load(
                        T.create_tma_descriptor(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1, 0, 2, 2, 0),
                        0,
                        T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 2),
                        k * 32,
                        by * 64,
                    )
                if v == 0:
                    T.tma_load(
                        T.create_tma_descriptor(6, 2, B.data, 512, 512, 2, 1024, 64, 32, 1, 1, 0, 3, 2, 0),
                        0,
                        T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, k % 3 * 2048, 2048, 2),
                        bx * 64,
                        k * 32,
                    )
                T.call_extern(
                    "handle",
                    "tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>",
                    T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 1),
                    T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, k % 3 * 2048, 2048, 1),
                    T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3),
                )

    @T.prim_func
    def after(A: T.Tensor((M, K), dtype), B: T.Tensor((K, N), dtype)):

@@ -85,34 +86,35 @@ def test_warp_specialized():
(again identical calls reflowed; shown in the post-change form)
                T.mbarrier_expect_tx(T.get_mbarrier(k % 3), 4096)
                if v - 128 == 0:
                    T.tma_load(
                        T.create_tma_descriptor(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1, 0, 2, 2, 0),
                        T.get_mbarrier(k % 3),
                        T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 2),
                        k * 32,
                        by * 64,
                    )
                if v - 128 == 0:
                    T.mbarrier_expect_tx(T.get_mbarrier(k % 3), 4096)
                if v - 128 == 0:
                    T.tma_load(
                        T.create_tma_descriptor(6, 2, B.data, 512, 512, 2, 1024, 64, 32, 1, 1, 0, 3, 2, 0),
                        T.get_mbarrier(k % 3),
                        T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, k % 3 * 2048, 2048, 2),
                        bx * 64,
                        k * 32,
                    )
                T.evaluate(tir.Call("handle", "tir.ptx_arrive_barrier", [T.get_mbarrier(k % 3)]))
        else:
            T.set_max_nreg(240, 1)
            for k in range(16):
                T.mbarrier_wait_parity(T.get_mbarrier(k % 3), k // 3 % 2)
                T.call_extern(
                    "handle",
                    "tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>",
                    T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 1),
                    T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, k % 3 * 2048, 2048, 1),
                    T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3),
                )
                T.evaluate(tir.Call("handle", "tir.ptx_arrive_barrier", [T.get_mbarrier(k % 3 + 3)]))

    _check(before, after)
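The test closes with `_check(before, after)`. The `_check` helper itself is defined earlier in the file and is not part of these hunks; as a hypothetical sketch of what such a helper typically does in TVM-based transform tests (the `transform` argument is a placeholder for the warp-specialization pass under test, not the real signature):

    import tvm
    from tvm import tir

    def _check(before: tir.PrimFunc, after: tir.PrimFunc, transform) -> None:
        """Apply `transform` to `before` and compare the result with `after` structurally."""
        mod = transform(tvm.IRModule({"main": before}))  # stand-in for the pass under test
        tvm.ir.assert_structural_equal(mod["main"], after)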
testing/python/utils/test_compress_utils.py · View file @ 29051439

@@ -6,7 +6,7 @@ from tilelang.utils.sparse import compress_sm90, randn_semi_sparse
 def _test_compress_sm90(M, K, block_k, dtype):
-    A = randn_semi_sparse(M, K, dtype=dtype, device='cuda')
+    A = randn_semi_sparse(M, K, dtype=dtype, device="cuda")
     A_sparse, E = compress_sm90(A, block_k, False)
testing/python/webgpu/test_webgpu_codegen.py · View file @ 29051439

@@ -5,7 +5,6 @@ import tilelang.language as T
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
    @T.prim_func
    def main(
            A: T.Tensor((M, K), dtype),
tilelang/__init__.py · View file @ 29051439

@@ -23,6 +23,7 @@ def _compute_version() -> str:
    if version_file.is_file():
        try:
            from version_provider import dynamic_metadata  # type: ignore

            return dynamic_metadata("version")
        except Exception:
            # Fall back to the raw VERSION file if provider isn't available.

@@ -33,6 +34,7 @@ def _compute_version() -> str:
        try:
            from importlib.metadata import version as _dist_version  # py3.8+

            return _dist_version("tilelang")
        except Exception as exc:
            warnings.warn(
tilelang/analysis/fragment_loop_checker.py · View file @ 29051439

 from __future__ import annotations

 from tvm import tir
-from tvm.tir import (PyStmtExprVisitor, BufferStore, For, Var, PrimFunc, BufferLoad, IntImm)
+from tvm.tir import PyStmtExprVisitor, BufferStore, For, Var, PrimFunc, BufferLoad, IntImm
 from tvm.tir.transform import prim_func_pass
 from tvm.tir.stmt_functor import post_order_visit

@@ -44,7 +44,6 @@ def collect_local_buffer_accesses(statement) -> list[BufferLoad | BufferStore]:
@tir.functor.visitor
class _FragmentLoopCheckVisitor(PyStmtExprVisitor):
    def __init__(self) -> None:
        super().__init__()

@@ -75,7 +74,8 @@ class _FragmentLoopCheckVisitor(PyStmtExprVisitor):
            raise ValueError(
                "[Tilelang Semantic Check] "
                f"Loop variable {loop.loop_var} in a T.Parallel loop with symbolic range (min={loop.min}, extent={loop.extent}) is used to index "
                "a local/fragment buffer, which is not allowed in Tilelang."
            )
        return
tilelang/analysis/layout_visual.py · View file @ 29051439

@@ -23,10 +23,7 @@ def print_fragment_format(layout: T.Fragment) -> str:
    if isinstance(layout, T.Fragment):
        input_shape = layout.get_input_shape()
        output_shape = layout.get_output_shape()
        lines = [
            f"  Shape: {input_shape} -> {output_shape}",
            f"  Thread: {layout.forward_thread}",
            f"  Index: {layout.forward_index}",
        ]
        print("\n".join(lines))
    else:
        raise ValueError(f"Expected T.Fragment, but got {type(layout).__name__}")

@@ -82,7 +79,6 @@ class _LayoutVisualVisitor(PyStmtExprVisitor):
def LayoutVisual(formats: str = ""):
    def pass_fn(func: tir.PrimFunc, mod, ctx):
        _LayoutVisualVisitor(formats=formats).visit_stmt(func.body)
        return func
tilelang/analysis/nested_loop_checker.py · View file @ 29051439

@@ -11,10 +11,7 @@ from tvm.tir.transform import prim_func_pass
def is_pipelined_for(op: For) -> bool:
    """Check if a for loop is pipelined."""
    anno_keys = ["num_stages", "tl_pipeline_order", "tl_pipeline_stage", "tl_pipeline_sync", "tl_pipeline_group"]
    return any(key in op.annotations for key in anno_keys)

@@ -26,7 +23,6 @@ def is_tile_op(op: Call) -> bool:
@tir.functor.visitor
class _NestedLoopCheckVisitor(PyStmtExprVisitor):
    def __init__(self) -> None:
        super().__init__()
        self.in_parallel_context = False

@@ -42,26 +38,23 @@ class _NestedLoopCheckVisitor(PyStmtExprVisitor):
            # Otherwise
            if self.in_parallel_context:
                raise ValueError(
                    "[Tilelang Semantic Check] Nested parallel loops are not allowed. Please check your loop structure."
                )
            self.in_parallel_context = True
            super().visit_for_(op)
            self.in_parallel_context = False
            return
        elif is_pipelined_for(op):
            if self.in_parallel_context:
                raise ValueError(
                    "[Tilelang Semantic Check] Pipelined loop cannot be nested inside a parallel loop. Please check your loop structure."
                )
        super().visit_for_(op)

    def visit_call_(self, op: Call) -> None:
        if self.in_parallel_context and is_tile_op(op):
            raise ValueError(
                f'[Tilelang Semantic Check] Only elementwise operations are allowed inside a parallel loop. Got a tile-op "{op.op}".'
            )
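Several of the hunks above merge error messages that were previously split across adjacent string literals into a single literal. Both spellings produce the same string, because Python concatenates adjacent literals at compile time; a small runnable illustration:

    split_form = (
        "[Tilelang Semantic Check] "
        "Nested parallel loops are not allowed. "
        "Please check your loop structure."
    )
    joined_form = "[Tilelang Semantic Check] Nested parallel loops are not allowed. Please check your loop structure."

    # Adjacent string literals are concatenated by the compiler, so the message is identical.
    assert split_form == joined_form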
tilelang/autotuner/capture.py · View file @ 29051439

@@ -85,8 +85,7 @@ def _get_current_stack() -> CaptureStack:
 class AutotuneInputsCapture:
-    __slots__ = ("tensors")
+    __slots__ = "tensors"

     def __init__(self, tensors: list[Any]):
         self.tensors = tensors
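A side note on this hunk: `("tensors")` is just a parenthesized string, not a 1-tuple, so dropping the parentheses does not change behavior; `__slots__` accepts either a single attribute name or an iterable of names. A runnable illustration (the class here is a stand-in, not the real `AutotuneInputsCapture`):

    assert ("tensors") == "tensors"            # parentheses alone do not make a tuple
    assert ("tensors",) == tuple(["tensors"])  # the trailing comma does

    class Capture:
        # A bare string and a 1-tuple are both accepted by __slots__.
        __slots__ = "tensors"

        def __init__(self, tensors):
            self.tensors = tensors

    c = Capture([1, 2, 3])
    assert c.tensors == [1, 2, 3]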
tilelang/autotuner/param.py · View file @ 29051439

"""The auto-tune parameters."""

from __future__ import annotations

import tilelang

@@ -50,7 +50,7 @@ class CompileArgs:
     out_idx: list[int] | int | None = None
     execution_backend: Literal["auto", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"] = "auto"
-    target: Literal['auto', 'cuda', 'hip'] = 'auto'
+    target: Literal["auto", "cuda", "hip"] = "auto"
     target_host: str | Target = None
     verbose: bool = False
     pass_configs: dict[str, Any] | None = None

@@ -62,24 +62,20 @@ class CompileArgs:
            target=self.target,
            target_host=self.target_host,
            verbose=self.verbose,
            pass_configs=self.pass_configs,
        )

    def __hash__(self):
        data = {
            "execution_backend": self.execution_backend,
            "target": str(self.target),
            "target_host": str(self.target_host) if self.target_host else None,
            "verbose": self.verbose,
            "pass_configs": json.dumps(self.pass_configs, sort_keys=True) if self.pass_configs else None,
        }
        hash_obj = hashlib.sha256(json.dumps(data, sort_keys=True).encode("utf-8"))
        return int.from_bytes(hash_obj.digest(), byteorder="big")


@dataclass(frozen=True)

@@ -104,6 +100,7 @@ class ProfileArgs:
    manual_check_prog: Callable = None
    cache_input_tensors: bool = True
    """

    warmup: int = 25
    rep: int = 100
    timeout: int = 30

@@ -127,8 +124,8 @@ class ProfileArgs:
            "atol": self.atol,
            "max_mismatched_ratio": self.max_mismatched_ratio,
        }
        hash_obj = hashlib.sha256(json.dumps(data, sort_keys=True).encode("utf-8"))
        return int.from_bytes(hash_obj.digest(), byteorder="big")


@dataclass(frozen=True)

@@ -143,6 +140,7 @@ class AutotuneResult:
        func: Optimized function.
        kernel: Compiled kernel function.
    """

    latency: float | None = None
    config: dict | None = None
    ref_latency: float | None = None

@@ -199,8 +197,7 @@ class AutotuneResult:
            if verbose:
                logger.debug(f"Saving kernel source code to file: {device_kernel_path}")
            if kernel.kernel_source is not None:
                self._safe_write_file(device_kernel_path, "w", lambda f: f.write(kernel.kernel_source))
        except Exception as e:
            logger.error(f"Error saving kernel source code to disk: {e}")

@@ -211,11 +208,9 @@ class AutotuneResult:
                logger.debug(f"Saving wrapped kernel source code to file: {host_kernel_path}")
            # Match kernel_cache behavior: use host source for tvm_ffi, otherwise wrapped kernel
            if kernel.execution_backend == "tvm_ffi":
                self._safe_write_file(host_kernel_path, "w", lambda f: f.write(kernel.adapter.get_host_source()))
            else:
                self._safe_write_file(host_kernel_path, "w", lambda f: f.write(kernel.adapter.get_kernel_source()))
        except Exception as e:
            logger.error(f"Error saving wrapped kernel source code to disk: {e}")

@@ -237,12 +232,10 @@ class AutotuneResult:
                py_src_path = src_lib_path.replace(".cubin", ".py")
                if verbose:
                    logger.debug(f"Saving kernel nvrtc python code to file: {kernel_py_path}")
                self._safe_write_file(kernel_py_path, "wb", lambda f: f.write(self._load_binary(py_src_path)))
                if verbose:
                    logger.debug(f"Saving kernel library to file: {kernel_lib_path}")
                self._safe_write_file(kernel_lib_path, "wb", lambda f: f.write(self._load_binary(src_lib_path)))
            elif kernel.execution_backend == "tvm_ffi":
                executable = kernel.adapter.executable
                if verbose:

@@ -252,8 +245,7 @@ class AutotuneResult:
                src_lib_path = kernel.adapter.libpath
                if verbose:
                    logger.debug(f"Saving kernel library to file: {kernel_lib_path}")
                self._safe_write_file(kernel_lib_path, "wb", lambda f: f.write(self._load_binary(src_lib_path)))
        except Exception as e:
            logger.error(f"Error saving kernel library to disk: {e}")

@@ -370,14 +362,12 @@ class AutotuneResult:
        # save best config (atomic)
        if verbose:
            logger.debug(f"Saving best config to file: {path / BEST_CONFIG_PATH}")
        self._safe_write_file(str(path / BEST_CONFIG_PATH), "w", lambda f: json.dump(self.config, f))
        # save function (atomic)
        if verbose:
            logger.debug(f"Saving function to file: {path / FUNCTION_PATH}")
        self._safe_write_file(str(path / FUNCTION_PATH), "wb", lambda f: cloudpickle.dump(self.func, f))
        # save ref latency (atomic)
        if verbose:

@@ -385,10 +375,13 @@ class AutotuneResult:
        self._safe_write_file(
            str(path / LATENCY_PATH),
            "w",
            lambda f: json.dump(
                {
                    "latency": self.latency,
                    "ref_latency": self.ref_latency,
                },
                f,
            ),
        )
        # save kernel

@@ -403,8 +396,8 @@ class AutotuneResult:
        # Normalize target and resolve execution backend for loading
        from tilelang.utils.target import determine_target as _determine_target
        from tilelang.jit.execution_backend import resolve_execution_backend

        norm_target = Target(_determine_target(compile_args.target)) if isinstance(compile_args.target, str) else compile_args.target
        requested_backend = compile_args.execution_backend
        resolved_backend = resolve_execution_backend(requested_backend, norm_target)
        # load best config
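Both `CompileArgs.__hash__` and `ProfileArgs.__hash__` above use the same recipe: JSON-serialize the relevant fields with sorted keys, hash the bytes with SHA-256, and fold the digest into an int. A minimal standalone sketch of that recipe (the field values are made up for illustration):

    import hashlib
    import json

    def stable_hash(fields: dict) -> int:
        # Sorted keys make the JSON, and therefore the hash, independent of insertion order.
        payload = json.dumps(fields, sort_keys=True).encode("utf-8")
        return int.from_bytes(hashlib.sha256(payload).digest(), byteorder="big")

    assert stable_hash({"target": "cuda", "verbose": False}) == stable_hash({"verbose": False, "target": "cuda"})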
tilelang/autotuner/tuner.py · View file @ 29051439

@@ -3,6 +3,7 @@
This module provides functionality for auto-tuning tilelang programs, including JIT compilation
and performance optimization through configuration search.
"""

from __future__ import annotations

from dataclasses import dataclass

@@ -14,7 +15,8 @@ from tvm.tir import PrimFunc, Var
 from tvm.target import Target
 import inspect
 from functools import partial
-from typing import (Callable, Generic, Literal, Any, TypeVar)
+from typing import Callable, Generic, Literal, Any, TypeVar

 # Python 3.9 compatibility for ParamSpec
 try:
     from typing import ParamSpec

@@ -74,8 +76,8 @@ def _init_logger_handlers():
     global _logger_handlers_initialized
     if _logger_handlers_initialized:
         return
-    formatter = logging.Formatter('%(asctime)s %(levelname)s:%(message)s')
-    file_handler = logging.FileHandler('autotuner.log', mode='w')
+    formatter = logging.Formatter("%(asctime)s %(levelname)s:%(message)s")
+    file_handler = logging.FileHandler("autotuner.log", mode="w")
     file_handler.setLevel(logging.DEBUG)
     file_handler.setFormatter(formatter)
     console_handler = logging.StreamHandler(sys.stdout)

@@ -87,8 +89,7 @@ def _init_logger_handlers():
 def get_available_cpu_count() -> int:
-    """Gets the number of CPU cores available to the current process.
-    """
+    """Gets the number of CPU cores available to the current process."""
     try:
         cpu_count = len(os.sched_getaffinity(0))
     except AttributeError:

@@ -107,6 +108,7 @@ class AutoTuner:
        fn: The function to be auto-tuned.
        configs: List of configurations to try during auto-tuning.
    """

    compile_args = CompileArgs()
    profile_args = ProfileArgs()

@@ -137,14 +139,15 @@ class AutoTuner:
        """
        return cls(kernel, configs)

    def set_compile_args(
        self,
        out_idx: list[int] | int | None = None,
        target: Literal["auto", "cuda", "hip", "metal"] = "auto",
        execution_backend: Literal["auto", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"] = "auto",
        target_host: str | Target = None,
        verbose: bool = False,
        pass_configs: dict[str, Any] | None = None,
    ):
        """Set compilation arguments for the auto-tuner.

        Args:

@@ -161,6 +164,7 @@ class AutoTuner:
        # Normalize target to a concrete TVM Target and resolve execution backend
        t = Target(determine_target(target))
        from tilelang.jit.execution_backend import resolve_execution_backend

        resolved_backend = resolve_execution_backend(execution_backend, t)
        self.compile_args = CompileArgs(

@@ -169,11 +173,13 @@ class AutoTuner:
            execution_backend=resolved_backend,
            target_host=target_host,
            verbose=verbose,
            pass_configs=pass_configs,
        )
        return self

    def set_profile_args(
        self,
        warmup: int = 25,
        rep: int = 100,
        timeout: int = 30,

@@ -185,7 +191,8 @@ class AutoTuner:
        max_mismatched_ratio: float = 0.01,
        skip_check: bool = False,
        manual_check_prog: Callable = None,
        cache_input_tensors: bool = False,
    ):
        """Set profiling arguments for the auto-tuner.

        Args:

@@ -209,9 +216,7 @@ class AutoTuner:
        # the `supply_prog` will be ignored and the `get_autotune_inputs` will be used instead.
        if get_autotune_inputs() is not None:
            if supply_prog is not None:
                logger.warning("`supply_prog` will be ignored as this program is under `with set_autotune_inputs` context.")
            supply_prog = lambda _: get_autotune_inputs()  # noqa: E731

        self.profile_args = ProfileArgs(

@@ -226,13 +231,13 @@ class AutoTuner:
            cache_input_tensors=cache_input_tensors,
            warmup=warmup,
            rep=rep,
            timeout=timeout,
        )
        # If a custom `supply_prog` is provided, the profiler's `supply_type` setting
        # becomes ineffective. The custom supply program will be used instead.
        if supply_prog is not None and supply_type != tilelang.TensorSupplyType.Auto:
            logger.warning("Ignoring `supply_type` passed to `set_profile_args` because `supply_prog` is not None.")
        return self

@@ -241,10 +246,8 @@ class AutoTuner:
        self._kernel_parameters = k_parameters
        self._function_parameters = f_parameters

    def generate_cache_key(
        self, parameters: dict[str, Any], extra_parameters: dict[str, Any]
    ) -> AutotuneResult | None:
        """Generate a cache key for the auto-tuning process."""

        def _normalize_param(value):
            if isinstance(value, Var):

@@ -315,8 +318,9 @@ class AutoTuner:
            if var_name in parameters:
                continue
            # Cell content must be serializable
            assert isinstance(cell.cell_contents, (int, float, str, bool, type(None))), (
                f"Cell contents {cell.cell_contents} is not serializable: {type(cell.cell_contents)}"
            )
            extra_parameters[var_name] = cell.cell_contents

        if isinstance(self.configs, Callable):

@@ -328,8 +332,10 @@ class AutoTuner:
        if env.is_cache_enabled() and not env.is_autotune_cache_disabled():
            # First check in-memory cache
            if key in self._memory_cache:
                logger.warning(
                    "Found kernel in memory cache. For better performance,"
                    " consider using `@tilelang.autotune` instead of direct AutoTuner.from_kernel."
                )
                return self._memory_cache[key]
            # Then check disk cache

@@ -369,7 +375,6 @@ class AutoTuner:
        # This encapsulates the logic of using either a custom supply program (`supply_prog`)
        # or the default profiler input generation (`profiler._get_inputs`).
        def get_input_tensors_supply(with_output: bool):
            def func():
                if supply_prog is not None:
                    return supply_prog(profiler._get_params(with_output=with_output))

@@ -387,8 +392,7 @@ class AutoTuner:
                self.jit_input_tensors = jit_input_tensors_supply()
            else:
                # check if the cached tensors are compatible with the current configuration
                assert len(params) == len(self.jit_input_tensors), "len(params) != len(self.jit_input_tensors)"
                for p, c in zip(params, self.jit_input_tensors):
                    if not isinstance(c, torch.Tensor):
                        # skip non-tensor inputs checking

@@ -397,8 +401,8 @@ class AutoTuner:
                    # Check tensor compatibility using generator expression
                    def shape_equal(a, b):
                        return all(
                            a_dim == b_dim or isinstance(a_dim, Var) or isinstance(b_dim, Var)
                            for a_dim, b_dim in zip(a.shape, b.shape)
                        )

                    if p.dtype != c.dtype or not shape_equal(p, c):
                        logger.warning(

@@ -409,7 +413,8 @@ class AutoTuner:
                            "To ensure fresh, compatible inputs are generated for every trial "
                            "you can disable caching by setting:\n"
                            "  `cache_input_tensors=False`\n"
                            "within your `.set_compile_args(...)` call.\n"
                        )
                        # otherwise, regenerate the input tensors for safety
                        self.jit_input_tensors = jit_input_tensors_supply()
                        break

@@ -418,24 +423,16 @@ class AutoTuner:
            if (not skip_check) and (ref_prog is not None):
                if manual_check_prog is not None:
                    profiler.manual_assert_close(
                        ref_prog, input_tensors=self.jit_input_tensors, manual_check_prog=manual_check_prog
                    )
                else:
                    profiler.assert_allclose(
                        ref_prog,
                        input_tensors=self.jit_input_tensors,
                        rtol=rtol,
                        atol=atol,
                        max_mismatched_ratio=max_mismatched_ratio,
                    )
            latency = profiler.do_bench(warmup=warmup, rep=rep, input_tensors=self.jit_input_tensors)
            if self.ref_latency_cache is None and ref_prog is not None:
                self.ref_input_tensors = ref_input_tensors_supply()
                self.ref_latency_cache = profiler.do_bench(
                    ref_prog, n_warmup=warmup, n_repeat=rep, input_tensors=self.ref_input_tensors
                )
            return latency, self.ref_latency_cache

@@ -469,17 +466,14 @@ class AutoTuner:
            # Check if all tunable arguments have been tuned by comparing config keys with key_kwargs_tuple
            if any(key in top_config for key, _ in key_kwargs_tuple) or any(
                check_tunable_argument_value(key, self._function_parameters, key_args_tuple) for key in tunable_arguments
            ):
                logger.warning(
                    f"Tunable parameters {tunable_arguments} already provided during auto-tuning. Skipping compilation and using direct JIT"
                )
                # compile the kernel with the provided parameters
                jit_kernel = self.jit_compile()
                autotuner_result = AutotuneResult(
                    libcode=jit_kernel.get_kernel_source(), func=jit_kernel.prim_func, kernel=jit_kernel
                )
                self._memory_cache[key] = autotuner_result
                return autotuner_result

        # get the cpu count

@@ -489,9 +483,7 @@ class AutoTuner:
            max_cpu_count = int(env.TILELANG_AUTO_TUNING_MAX_CPU_COUNT)
        if cpu_counts > 0:
            num_workers = min(cpu_counts, available_cpu_count)
            logger.info(
                f"Auto-tuning with {cpu_counts} CPU counts, {available_cpu_count} CPUs available, {num_workers} CPUs will be used"
            )
        else:
            num_workers = max(1, int(available_cpu_count * cpu_utilizations))
            logger.info(

@@ -509,7 +501,6 @@ class AutoTuner:
        future_to_index = {}

        def cuda_device_wrapper(func, device):
            def inner(**config_arg):
                torch.cuda.set_device(device)
                return func(**config_arg)

@@ -532,18 +523,14 @@ class AutoTuner:
            future_to_index[future] = i

        results_with_configs = []
        for future in tqdm(
            concurrent.futures.as_completed(futures), total=len(futures), desc="Compiling configurations"
        ):
            idx = future_to_index[future]
            config = config_args[idx]
            try:
                result = future.result()
                results_with_configs.append((result, config))
            except Exception as e:
                logger.debug(f"Compilation failed for config {config} at index {idx} with error: {e}")
                continue

        ref_latency = None

@@ -556,14 +543,10 @@ class AutoTuner:
                # latency, ref_latency = target_fn(jit_kernel)
                latency, ref_latency = run_with_timeout(target_fn, timeout, jit_kernel)
            except TimeoutException:
                logger.warning(f"A timeout occurred while testing config {config}, checkout autotuner.log for more details")
                continue
            except Exception:
                logger.warning(f"An error occurred while testing config {config}, checkout autotuner.log for more details")
                logger.debug(f"Error: {traceback.format_exc()}")
                continue

@@ -578,8 +561,7 @@ class AutoTuner:
        pool.shutdown()

        if best_kernel is None:
            error_msg = "Auto-tuning failed: No configuration successfully compiled and passed benchmarking/validation."
            logger.error(error_msg)
            raise RuntimeError(error_msg)

@@ -595,7 +577,8 @@ class AutoTuner:
            ref_latency=ref_latency,
            libcode=best_kernel.get_kernel_source(),
            func=best_kernel.prim_func,
            kernel=best_kernel,
        )
        if self.compile_args.execution_backend in ("torch"):
            logger.warning("DLPack backend does not support cache saving to disk.")

@@ -617,8 +600,8 @@ class AutoTuner:
     return self.run()


-_P = ParamSpec('_P')
-_T = TypeVar('_T')
+_P = ParamSpec("_P")
+_T = TypeVar("_T")


 @dataclass

@@ -643,8 +626,9 @@ class AutoTuneImpl(Generic[_P, _T]):
        self._tuner_cache = {}

    def get_tunner(self):
        autotuner = (
            AutoTuner(self.jit_impl.func, configs=self.configs)
            .set_profile_args(
                supply_type=self.supply_type,
                ref_prog=self.ref_prog,
                supply_prog=self.supply_prog,

@@ -654,7 +638,8 @@ class AutoTuneImpl(Generic[_P, _T]):
                skip_check=self.skip_check,
                manual_check_prog=self.manual_check_prog,
                cache_input_tensors=self.cache_input_tensors,
            )
            .set_compile_args(
                out_idx=self.jit_impl.out_idx,
                execution_backend=self.jit_impl.execution_backend,
                target=self.jit_impl.target,

@@ -662,6 +647,7 @@ class AutoTuneImpl(Generic[_P, _T]):
                verbose=self.jit_impl.verbose,
                pass_configs=self.jit_impl.pass_configs,
            )
        )
        autotuner.run = partial(autotuner.run, self.warmup, self.rep, self.timeout)
        return autotuner

@@ -753,16 +739,13 @@ def autotune(  # This is the new public interface
    if callable(func):
        # Case 1: Used as @autotune (func_or_out_idx is the function, others are defaults)
        # This is a placeholder for a real auto tuner implementation
        raise ValueError("Use tilelang.autotune to decorate func without arguments is not supported yet.")
    elif isinstance(func, PrimFunc):
        raise ValueError("Use tilelang.jit to decorate prim_func is not supported yet.")
    else:

        def decorator(impl):
            assert isinstance(impl, JITImpl), "The @autotune decorator can only be applied to @tilelang.jit decorated instances."
            return AutoTuneImpl(
                jit_impl=impl,
                configs=configs,
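The `shape_equal` helper above treats any symbolic dimension (a TVM `Var`) as a wildcard and only requires equality for concrete dimensions. A self-contained illustration of that rule, using a plain placeholder class instead of `tvm.tir.Var` so it runs without TVM installed:

    class SymbolicDim:
        """Stand-in for tvm.tir.Var in this illustration."""

    def shape_equal(a_shape, b_shape) -> bool:
        # A symbolic dimension matches anything; concrete dimensions must agree.
        return all(
            isinstance(a_dim, SymbolicDim) or isinstance(b_dim, SymbolicDim) or a_dim == b_dim
            for a_dim, b_dim in zip(a_shape, b_shape)
        )

    n = SymbolicDim()
    assert shape_equal((n, 64), (128, 64))        # symbolic dim acts as a wildcard
    assert not shape_equal((32, 64), (128, 64))   # concrete mismatch is rejected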
tilelang/cache/__init__.py · View file @ 29051439

"""The cache utils with class and database persistence - Init file"""

from __future__ import annotations

from typing import Literal

@@ -18,8 +19,7 @@ def cached(
    *args,
    target: str | Target = "auto",
    target_host: str | Target = None,
    execution_backend: Literal["auto", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"] | None = "auto",
    verbose: bool | None = False,
    pass_configs: dict | None = None,
    compile_flags: list[str] | str | None = None,

@@ -36,7 +36,8 @@ def cached(
        execution_backend=execution_backend,
        verbose=verbose,
        pass_configs=pass_configs,
        compile_flags=compile_flags,
    )


def clear_cache():

@@ -47,9 +48,11 @@ def clear_cache():
        RuntimeError: Always raised to warn users to clear the cache manually.
    """
    cache_dir = env.TILELANG_CACHE_DIR
    raise RuntimeError(
        "tilelang.clear_cache() is disabled because deleting the cache directory "
        "is dangerous. If you accept the risk, remove it manually with "
        f"`rm -rf '{cache_dir}'`."
    )


if env.TILELANG_CLEAR_CACHE.lower() in ("1", "true", "yes", "on"):
tilelang/cache/kernel_cache.py · View file @ 29051439

"""The cache utils with class and database persistence - KernelCache Class"""

from __future__ import annotations

import json

@@ -97,9 +98,7 @@ class KernelCache:
            "version": __version__,
            "func": sha256(func_binary).hexdigest(),  # Use SHA256 to generate hash key
            "out_idx": (tuple(out_idx) if isinstance(out_idx, (list, tuple)) else [out_idx]),
            "args_repr": tuple(repr(arg) for arg in args),  # Use repr to serialize arguments, may need more robust serialization
            "target": str(target),
            "target_host": str(target_host) if target_host else None,
            "execution_backend": execution_backend,

@@ -118,8 +117,7 @@ class KernelCache:
        *args,
        target: str | Target = "auto",
        target_host: str | Target = None,
        execution_backend: Literal["auto", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"] = "auto",
        verbose: bool = False,
        pass_configs: dict = None,
        compile_flags: list[str] | str | None = None,

@@ -140,6 +138,7 @@ class KernelCache:
        # Normalize target and resolve execution backend before proceeding
        from tilelang.utils.target import determine_target as _determine_target
        from tilelang.jit.execution_backend import resolve_execution_backend, allowed_backends_for_target

        norm_target = Target(_determine_target(target)) if isinstance(target, str) else target
        requested_backend = execution_backend
        execution_backend = resolve_execution_backend(requested_backend, norm_target)

@@ -180,21 +179,21 @@ class KernelCache:
        with self._lock:
            # First check in-memory cache
            if key in self._memory_cache:
                self.logger.warning(
                    "Found kernel in memory cache. For better performance, consider using `@tilelang.jit` instead of direct kernel caching."
                )
                return self._memory_cache[key]
            if verbose:
                self.logger.debug(f"Checking disk cache for kernel {func.attrs['global_symbol']}")
            # Then check disk cache
            kernel = self._load_kernel_from_disk(
                key, norm_target, target_host, out_idx, execution_backend, pass_configs, compile_flags, func, verbose
            )
            if kernel is not None:
                if verbose:
                    self.logger.debug(f"Found kernel in disk cache for {func.attrs['global_symbol']}")
                # Populate memory cache with disk result
                self._memory_cache[key] = kernel
                return kernel

@@ -262,11 +261,7 @@ class KernelCache:
            executable.export_library(temp_path)
            os.replace(temp_path, path)

    def _save_kernel_to_disk(self, key: str, kernel: JITKernel, func: Callable = None, verbose: bool = False):
        """
        Persists a compiled kernel to disk cache.

@@ -292,8 +287,7 @@ class KernelCache:
            if verbose:
                self.logger.debug(f"Saving kernel source code to file: {device_kernel_path}")
            if kernel.kernel_source is not None:
                KernelCache._safe_write_file(device_kernel_path, "w", lambda file: file.write(kernel.kernel_source))
        except Exception as e:
            self.logger.error(f"Error saving kernel source code to disk: {e}")

@@ -303,13 +297,9 @@ class KernelCache:
            if verbose:
                self.logger.debug(f"Saving wrapped kernel source code to file: {host_kernel_path}")
            if self.execution_backend == "tvm_ffi":
                KernelCache._safe_write_file(host_kernel_path, "w", lambda file: file.write(kernel.adapter.get_host_source()))
            else:
                KernelCache._safe_write_file(host_kernel_path, "w", lambda file: file.write(kernel.adapter.get_kernel_source()))
        except Exception as e:
            self.logger.error(f"Error saving host kernel source code to disk: {e}")

@@ -332,9 +322,7 @@ class KernelCache:
                src_lib_path = src_lib_path.replace(".cubin", ".py")
                if verbose:
                    self.logger.debug(f"Saving kernel nvrtc python code to file: {kernel_py_path}")
                KernelCache._safe_write_file(kernel_py_path, "wb", lambda file: file.write(KernelCache._load_binary(src_lib_path)))
            elif self.execution_backend == "tvm_ffi":
                executable = kernel.adapter.executable
                if verbose:

@@ -344,9 +332,7 @@ class KernelCache:
                src_lib_path = kernel.adapter.libpath
                if verbose:
                    self.logger.debug(f"Saving kernel library to file: {kernel_lib_path}")
                KernelCache._safe_write_file(kernel_lib_path, "wb", lambda file: file.write(KernelCache._load_binary(src_lib_path)))
        except Exception as e:
            self.logger.error(f"Error saving kernel library to disk: {e}")

@@ -356,8 +342,7 @@ class KernelCache:
            params_path = os.path.join(cache_path, PARAMS_PATH)
            if verbose:
                self.logger.debug(f"Saving kernel parameters to disk: {params_path}")
            KernelCache._safe_write_file(params_path, "wb", lambda file: cloudpickle.dump(kernel.params, file))
        except Exception as e:
            self.logger.error(f"Error saving kernel parameters to disk: {e}")

@@ -417,8 +402,7 @@ class KernelCache:
            self.logger.error(f"Error loading kernel source code from disk: {e}")
        try:
            if verbose:
                self.logger.debug(f"Loading wrapped kernel source code from file: {host_kernel_path}")
            with open(host_kernel_path) as f:
                host_kernel_source = f.read()
        except Exception as e:
tilelang/carver/__init__.py · View file @ 29051439

"""Base infra"""

from .analysis import (
    BlockInfo,  # noqa: F401
    IterInfo,  # noqa: F401
tilelang/carver/analysis.py · View file @ 29051439

"""Analysis on TIR blocks, loops and functions."""

from __future__ import annotations

from typing_extensions import Literal

@@ -144,11 +145,13 @@ def normalize_prim_func(sch: tir.Schedule) -> list[BlockInfo] | None:
                        var=iter.var,
                        dom=iter.dom,
                        loop_rv=loop,
                    )
                    for loop, iter in zip(loops, iters)
                ],
                block_rv=block,
                reduction_block=is_reduction,
            )
        )
    return blocks

@@ -188,8 +191,7 @@ def get_max_shared_memory_per_block(target: Target) -> int:
    _assert_gpu_target(target)
    max_shared_memory_per_block = target.attrs.get("max_shared_memory_per_block", None)
    if max_shared_memory_per_block is None:
        raise ValueError(f"Cannot find `max_shared_memory_per_block` in {target}, please specify it manually")
    return int(max_shared_memory_per_block)

@@ -197,13 +199,11 @@ def get_root_block(sch: Schedule, func_name: str = "main") -> BlockRV:
    try:
        block = sch.mod[func_name].body.block
    except Exception:
        raise ValueError(f"The function body is expected to be the root block, but got:\n{sch.mod[func_name].body}") from None
    return sch.get_block(block.name_hint)


def collect_block_iter_vars_used_in_access_region(block: tir.Block, region: list[ir.Range]) -> set[tir.Var]:
    """Collect the block iter variables used in the access region of a buffer region."""
    tir_vars = set()
    for expr in region:

@@ -251,15 +251,13 @@ def is_broadcast_epilogue(
    for buffer_region in sch.get(epilogue).reads:
        if buffer_region.buffer not in write_buffers:
            continue
        tir_vars = collect_block_iter_vars_used_in_access_region(sch.get(epilogue), buffer_region.region)
        if len(tir_vars) < len(epilogue_iters):
            return True
    return False


def get_reduction_blocks(sch: tir.Schedule, blocks: list[tir.schedule.BlockRV]) -> list[tir.schedule.BlockRV]:
    # Get the main computation block
    def is_reduction(block: BlockRV) -> bool:
        block_stmt = sch.get(block)
tilelang/carver/arch/__init__.py · View file @ 29051439

@@ -39,18 +39,18 @@ def auto_infer_current_arch() -> TileDevice:
(the single-quoted __all__ entries are rewritten with double quotes)
__all__ = [
    "is_cpu_arch",
    "is_cuda_arch",
    "is_volta_arch",
    "is_ampere_arch",
    "is_ada_arch",
    "is_hopper_arch",
    "is_tensorcore_supported_precision",
    "has_mma_support",
    "is_cdna_arch",
    "is_metal_arch",
    "CUDA",
    "CDNA",
    "METAL",
    "CPU",
]
tilelang/carver/arch/arch_base.py · View file @ 29051439

@@ -7,9 +7,7 @@ class TileDevice:
         self.reg_cap: int = 0  # Register capacity: The amount of register memory available
         self.smem_cap: int = 0  # Shared memory capacity: The amount of shared memory available
         self.compute_max_core: int = 0  # The maximum number of computing cores
-        self.warp_size: int = (
-            0  # The size of a warp, a group of threads that execute instructions in lockstep
-        )
+        self.warp_size: int = 0  # The size of a warp, a group of threads that execute instructions in lockstep
         self.sm_partition: int = 0  # The number of streaming multiprocessor partitions
         self.transaction_size: list[int] = [
             0,

@@ -21,9 +19,7 @@ class TileDevice:
             0,
         ]  # Bandwidth specifications, possibly including peak and sustained rates
         self.platform: str = "unknown"  # The platform or manufacturer of the device
-        self.compute_capability: str = (
-            "unknown"  # The compute capability, indicating the feature set and performance level
-        )
+        self.compute_capability: str = "unknown"  # The compute capability, indicating the feature set and performance level
         self.l2_cache_size_bytes: int = 0
         # the number of transaction size in bytes
         self.transaction_size: list[int] = [0, 0]  # in bytes
tilelang/carver/arch/cdna.py · View file @ 29051439

@@ -9,7 +9,6 @@ def is_cdna_arch(arch: TileDevice) -> bool:
class CDNA(TileDevice):
    def __init__(self, target: Target | str):
        if isinstance(target, str):
            target = tvm.target.Target(target)

@@ -33,6 +32,6 @@ class CDNA(TileDevice):
 __all__ = [
-    'is_cdna_arch',
-    'CDNA',
+    "is_cdna_arch",
+    "CDNA",
 ]
tilelang/carver/arch/cpu.py · View file @ 29051439

@@ -10,7 +10,6 @@ def is_cpu_arch(arch: TileDevice) -> bool:
# For LLVM Backend, we do not provide the detailed information of the CPU
# As the LLVM backend do not required tuning, just maintain the consistency
class CPU(TileDevice):
    def __init__(self, target: Target):
        self.target = target
        device = tvm.runtime.cpu(0)

@@ -21,6 +20,6 @@ class CPU(TileDevice):
 __all__ = [
-    'is_cpu_arch',
-    'CPU',
+    "is_cpu_arch",
+    "CPU",
 ]
tilelang/carver/arch/cuda.py · View file @ 29051439

@@ -78,7 +78,6 @@ hopper_tensorcore_supported = ada_tensorcore_supported
# instead of assuming both a and b share the same dtype.
# As the tensorcore may supports float8_e4m3 * float8_e5m2
def is_tensorcore_supported_precision(in_dtype: str, accum_dtype: str, arch: TileDevice) -> bool:
    if is_volta_arch(arch):
        return (in_dtype, accum_dtype) in volta_tensorcore_supported
    elif is_ampere_arch(arch):

@@ -92,7 +91,6 @@ def is_tensorcore_supported_precision(in_dtype: str, accum_dtype: str, arch: Til
class TensorInstruction:
    def __init__(
        self,
        name: str,

@@ -104,7 +102,6 @@ class TensorInstruction:
class CUDA(TileDevice):
    def __init__(self, target: Target | str):
        if isinstance(target, str):
            target = tvm.target.Target(target)

@@ -148,12 +145,12 @@ class CUDA(TileDevice):
(the single-quoted __all__ entries are rewritten with double quotes)
__all__ = [
    "is_cuda_arch",
    "is_volta_arch",
    "is_ampere_arch",
    "is_ada_arch",
    "is_hopper_arch",
    "is_tensorcore_supported_precision",
    "has_mma_support",
    "CUDA",
]
tilelang/carver/arch/driver/cuda_driver.py · View file @ 29051439

@@ -83,8 +83,7 @@ def get_max_dynamic_shared_size_bytes(device_id: int = 0, format: str = "bytes")
    Get the maximum dynamic shared memory size in bytes, kilobytes, or megabytes.
    """
    assert format in ["bytes", "kb", "mb"], "Invalid format. Must be one of: bytes, kb, mb"
    shared_mem = get_device_attribute(cudaDeviceAttrNames.cudaDevAttrMaxSharedMemoryPerMultiprocessor, device_id)
    if format == "bytes":
        return shared_mem
    elif format == "kb":