Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
tilelang
Commits
29051439
Unverified
Commit
29051439
authored
Dec 12, 2025
by
Lei Wang
Committed by
GitHub
Dec 12, 2025
Browse files
[Lint] Phaseout Yapf format and embrace ruff format (#1417)
parent
e84b24bc
Changes
426
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
239 additions
and
296 deletions
+239
-296
testing/python/transform/test_tilelang_transform_warp_specialized.py
...hon/transform/test_tilelang_transform_warp_specialized.py
+37
-35
testing/python/utils/test_compress_utils.py
testing/python/utils/test_compress_utils.py
+1
-1
testing/python/webgpu/test_webgpu_codegen.py
testing/python/webgpu/test_webgpu_codegen.py
+3
-4
tilelang/__init__.py
tilelang/__init__.py
+2
-0
tilelang/analysis/fragment_loop_checker.py
tilelang/analysis/fragment_loop_checker.py
+9
-9
tilelang/analysis/layout_visual.py
tilelang/analysis/layout_visual.py
+1
-5
tilelang/analysis/nested_loop_checker.py
tilelang/analysis/nested_loop_checker.py
+8
-15
tilelang/autotuner/capture.py
tilelang/autotuner/capture.py
+1
-2
tilelang/autotuner/param.py
tilelang/autotuner/param.py
+33
-40
tilelang/autotuner/tuner.py
tilelang/autotuner/tuner.py
+77
-94
tilelang/cache/__init__.py
tilelang/cache/__init__.py
+9
-6
tilelang/cache/kernel_cache.py
tilelang/cache/kernel_cache.py
+19
-35
tilelang/carver/__init__.py
tilelang/carver/__init__.py
+1
-0
tilelang/carver/analysis.py
tilelang/carver/analysis.py
+10
-12
tilelang/carver/arch/__init__.py
tilelang/carver/arch/__init__.py
+14
-14
tilelang/carver/arch/arch_base.py
tilelang/carver/arch/arch_base.py
+2
-6
tilelang/carver/arch/cdna.py
tilelang/carver/arch/cdna.py
+2
-3
tilelang/carver/arch/cpu.py
tilelang/carver/arch/cpu.py
+2
-3
tilelang/carver/arch/cuda.py
tilelang/carver/arch/cuda.py
+7
-10
tilelang/carver/arch/driver/cuda_driver.py
tilelang/carver/arch/driver/cuda_driver.py
+1
-2
No files found.
Too many changes to show.
To preserve performance only
426 of 426+
files are displayed.
Plain diff
Email patch
testing/python/transform/test_tilelang_transform_warp_specialized.py
View file @
29051439
...
...
@@ -32,7 +32,6 @@ block_K = 32
def
test_warp_specialized
():
@
T
.
prim_func
def
before
(
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
)):
bx
=
T
.
launch_thread
(
"blockIdx.x"
,
8
)
...
...
@@ -47,25 +46,27 @@ def test_warp_specialized():
for
k
in
T
.
serial
(
16
,
annotations
=
{
"num_stages"
:
T
.
int32
(
3
)}):
if
v
==
0
:
T
.
tma_load
(
T
.
create_tma_descriptor
(
6
,
2
,
A
.
data
,
512
,
512
,
2
,
1024
,
32
,
64
,
1
,
1
,
0
,
2
,
2
,
0
),
0
,
T
.
tvm_access_ptr
(
T
.
type_annotation
(
"float16"
),
A_shared
.
data
,
k
%
3
*
2048
,
2048
,
2
),
k
*
32
,
by
*
64
)
T
.
create_tma_descriptor
(
6
,
2
,
A
.
data
,
512
,
512
,
2
,
1024
,
32
,
64
,
1
,
1
,
0
,
2
,
2
,
0
),
0
,
T
.
tvm_access_ptr
(
T
.
type_annotation
(
"float16"
),
A_shared
.
data
,
k
%
3
*
2048
,
2048
,
2
),
k
*
32
,
by
*
64
,
)
if
v
==
0
:
T
.
tma_load
(
T
.
create_tma_descriptor
(
6
,
2
,
B
.
data
,
512
,
512
,
2
,
1024
,
64
,
32
,
1
,
1
,
0
,
3
,
2
,
0
),
0
,
T
.
tvm_access_ptr
(
T
.
type_annotation
(
"float16"
),
B_shared
.
data
,
k
%
3
*
2048
,
2048
,
2
),
bx
*
64
,
k
*
32
)
T
.
create_tma_descriptor
(
6
,
2
,
B
.
data
,
512
,
512
,
2
,
1024
,
64
,
32
,
1
,
1
,
0
,
3
,
2
,
0
),
0
,
T
.
tvm_access_ptr
(
T
.
type_annotation
(
"float16"
),
B_shared
.
data
,
k
%
3
*
2048
,
2048
,
2
),
bx
*
64
,
k
*
32
,
)
T
.
call_extern
(
"handle"
,
"tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>"
,
T
.
tvm_access_ptr
(
T
.
type_annotation
(
"float16"
),
A_shared
.
data
,
k
%
3
*
2048
,
2048
,
1
),
T
.
tvm_access_ptr
(
T
.
type_annotation
(
"float
16
"
),
B_shared
.
data
,
k
%
3
*
2048
,
2048
,
1
),
T
.
tvm_access_ptr
(
T
.
type_annotation
(
"float32"
),
C_local
.
data
,
0
,
32
,
3
)
)
"handle"
,
"tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>"
,
T
.
tvm_access_ptr
(
T
.
type_annotation
(
"float16"
),
A_shared
.
data
,
k
%
3
*
2048
,
2048
,
1
),
T
.
tvm_access_ptr
(
T
.
type_annotation
(
"float16"
),
B_shared
.
data
,
k
%
3
*
2048
,
2048
,
1
),
T
.
tvm_access_ptr
(
T
.
type_annotation
(
"float
32
"
),
C_local
.
data
,
0
,
32
,
3
),
)
@
T
.
prim_func
def
after
(
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
)):
...
...
@@ -85,34 +86,35 @@ def test_warp_specialized():
T
.
mbarrier_expect_tx
(
T
.
get_mbarrier
(
k
%
3
),
4096
)
if
v
-
128
==
0
:
T
.
tma_load
(
T
.
create_tma_descriptor
(
6
,
2
,
A
.
data
,
512
,
512
,
2
,
1024
,
32
,
64
,
1
,
1
,
0
,
2
,
2
,
0
),
T
.
get_mbarrier
(
k
%
3
),
T
.
tvm_access_ptr
(
T
.
type_annotation
(
"float16"
),
A_shared
.
data
,
k
%
3
*
2048
,
2048
,
2
),
k
*
32
,
by
*
64
)
T
.
create_tma_descriptor
(
6
,
2
,
A
.
data
,
512
,
512
,
2
,
1024
,
32
,
64
,
1
,
1
,
0
,
2
,
2
,
0
),
T
.
get_mbarrier
(
k
%
3
),
T
.
tvm_access_ptr
(
T
.
type_annotation
(
"float16"
),
A_shared
.
data
,
k
%
3
*
2048
,
2048
,
2
),
k
*
32
,
by
*
64
,
)
if
v
-
128
==
0
:
T
.
mbarrier_expect_tx
(
T
.
get_mbarrier
(
k
%
3
),
4096
)
if
v
-
128
==
0
:
T
.
tma_load
(
T
.
create_tma_descriptor
(
6
,
2
,
B
.
data
,
512
,
512
,
2
,
1024
,
64
,
32
,
1
,
1
,
0
,
3
,
2
,
0
),
T
.
get_mbarrier
(
k
%
3
),
T
.
tvm_access_ptr
(
T
.
type_annotation
(
"float16"
),
B_shared
.
data
,
k
%
3
*
2048
,
2048
,
2
),
bx
*
64
,
k
*
32
)
T
.
create_tma_descriptor
(
6
,
2
,
B
.
data
,
512
,
512
,
2
,
1024
,
64
,
32
,
1
,
1
,
0
,
3
,
2
,
0
),
T
.
get_mbarrier
(
k
%
3
),
T
.
tvm_access_ptr
(
T
.
type_annotation
(
"float16"
),
B_shared
.
data
,
k
%
3
*
2048
,
2048
,
2
),
bx
*
64
,
k
*
32
,
)
T
.
evaluate
(
tir
.
Call
(
"handle"
,
"tir.ptx_arrive_barrier"
,
[
T
.
get_mbarrier
(
k
%
3
)]))
else
:
T
.
set_max_nreg
(
240
,
1
)
for
k
in
range
(
16
):
T
.
mbarrier_wait_parity
(
T
.
get_mbarrier
(
k
%
3
),
k
//
3
%
2
)
T
.
call_extern
(
"handle"
,
"tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>"
,
T
.
tvm_access_ptr
(
T
.
type_annotation
(
"float16"
),
A_shared
.
data
,
k
%
3
*
2048
,
2048
,
1
),
T
.
tvm_access_ptr
(
T
.
type_annotation
(
"float16"
),
B_shared
.
data
,
k
%
3
*
2048
,
2048
,
1
),
T
.
tvm_access_ptr
(
T
.
type_annotation
(
"float32"
),
C_local
.
data
,
0
,
32
,
3
))
T
.
evaluate
(
tir
.
Call
(
"handle"
,
"tir.ptx_arrive_barrier"
,
[
T
.
get_mbarrier
(
k
%
3
+
3
)]))
"handle"
,
"tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>"
,
T
.
tvm_access_ptr
(
T
.
type_annotation
(
"float16"
),
A_shared
.
data
,
k
%
3
*
2048
,
2048
,
1
),
T
.
tvm_access_ptr
(
T
.
type_annotation
(
"float16"
),
B_shared
.
data
,
k
%
3
*
2048
,
2048
,
1
),
T
.
tvm_access_ptr
(
T
.
type_annotation
(
"float32"
),
C_local
.
data
,
0
,
32
,
3
),
)
T
.
evaluate
(
tir
.
Call
(
"handle"
,
"tir.ptx_arrive_barrier"
,
[
T
.
get_mbarrier
(
k
%
3
+
3
)]))
_check
(
before
,
after
)
...
...
testing/python/utils/test_compress_utils.py
View file @
29051439
...
...
@@ -6,7 +6,7 @@ from tilelang.utils.sparse import compress_sm90, randn_semi_sparse
def
_test_compress_sm90
(
M
,
K
,
block_k
,
dtype
):
A
=
randn_semi_sparse
(
M
,
K
,
dtype
=
dtype
,
device
=
'
cuda
'
)
A
=
randn_semi_sparse
(
M
,
K
,
dtype
=
dtype
,
device
=
"
cuda
"
)
A_sparse
,
E
=
compress_sm90
(
A
,
block_k
,
False
)
...
...
testing/python/webgpu/test_webgpu_codegen.py
View file @
29051439
...
...
@@ -5,12 +5,11 @@ import tilelang.language as T
def
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
):
# Initialize Kernel Context
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
128
)
as
(
bx
,
by
):
...
...
tilelang/__init__.py
View file @
29051439
...
...
@@ -23,6 +23,7 @@ def _compute_version() -> str:
if
version_file
.
is_file
():
try
:
from
version_provider
import
dynamic_metadata
# type: ignore
return
dynamic_metadata
(
"version"
)
except
Exception
:
# Fall back to the raw VERSION file if provider isn't available.
...
...
@@ -33,6 +34,7 @@ def _compute_version() -> str:
try
:
from
importlib.metadata
import
version
as
_dist_version
# py3.8+
return
_dist_version
(
"tilelang"
)
except
Exception
as
exc
:
warnings
.
warn
(
...
...
tilelang/analysis/fragment_loop_checker.py
View file @
29051439
from
__future__
import
annotations
from
tvm
import
tir
from
tvm.tir
import
(
PyStmtExprVisitor
,
BufferStore
,
For
,
Var
,
PrimFunc
,
BufferLoad
,
IntImm
)
from
tvm.tir
import
PyStmtExprVisitor
,
BufferStore
,
For
,
Var
,
PrimFunc
,
BufferLoad
,
IntImm
from
tvm.tir.transform
import
prim_func_pass
from
tvm.tir.stmt_functor
import
post_order_visit
...
...
@@ -22,14 +22,14 @@ class _LoopVarUseAnalyzer(PyStmtExprVisitor):
def
collect_local_buffer_accesses
(
statement
)
->
list
[
BufferLoad
|
BufferStore
]:
"""
Collect local buffer accesses in the loop body.
Collect local buffer accesses in the loop body.
Args:
statement: The TIR statement to analyze
Args:
statement: The TIR statement to analyze
Returns:
Tuple of buffer accesses in the loop body.
"""
Returns:
Tuple of buffer accesses in the loop body.
"""
buffer_accesses
=
[]
...
...
@@ -44,7 +44,6 @@ def collect_local_buffer_accesses(statement) -> list[BufferLoad | BufferStore]:
@
tir
.
functor
.
visitor
class
_FragmentLoopCheckVisitor
(
PyStmtExprVisitor
):
def
__init__
(
self
)
->
None
:
super
().
__init__
()
...
...
@@ -75,7 +74,8 @@ class _FragmentLoopCheckVisitor(PyStmtExprVisitor):
raise
ValueError
(
"[Tilelang Semantic Check] "
f
"Loop variable
{
loop
.
loop_var
}
in a T.Parallel loop with symbolic range (min=
{
loop
.
min
}
, extent=
{
loop
.
extent
}
) is used to index "
"a local/fragment buffer, which is not allowed in Tilelang."
)
"a local/fragment buffer, which is not allowed in Tilelang."
)
return
...
...
tilelang/analysis/layout_visual.py
View file @
29051439
...
...
@@ -23,10 +23,7 @@ def print_fragment_format(layout: T.Fragment) -> str:
if
isinstance
(
layout
,
T
.
Fragment
):
input_shape
=
layout
.
get_input_shape
()
output_shape
=
layout
.
get_output_shape
()
lines
=
[
f
" Shape:
{
input_shape
}
->
{
output_shape
}
"
,
f
" Thread:
{
layout
.
forward_thread
}
"
,
f
" Index:
{
layout
.
forward_index
}
"
]
lines
=
[
f
" Shape:
{
input_shape
}
->
{
output_shape
}
"
,
f
" Thread:
{
layout
.
forward_thread
}
"
,
f
" Index:
{
layout
.
forward_index
}
"
]
print
(
"
\n
"
.
join
(
lines
))
else
:
raise
ValueError
(
f
"Expected T.Fragment, but got
{
type
(
layout
).
__name__
}
"
)
...
...
@@ -82,7 +79,6 @@ class _LayoutVisualVisitor(PyStmtExprVisitor):
def
LayoutVisual
(
formats
:
str
=
""
):
def
pass_fn
(
func
:
tir
.
PrimFunc
,
mod
,
ctx
):
_LayoutVisualVisitor
(
formats
=
formats
).
visit_stmt
(
func
.
body
)
return
func
...
...
tilelang/analysis/nested_loop_checker.py
View file @
29051439
...
...
@@ -11,10 +11,7 @@ from tvm.tir.transform import prim_func_pass
def
is_pipelined_for
(
op
:
For
)
->
bool
:
"""Check if a for loop is pipelined."""
anno_keys
=
[
"num_stages"
,
"tl_pipeline_order"
,
"tl_pipeline_stage"
,
"tl_pipeline_sync"
,
"tl_pipeline_group"
]
anno_keys
=
[
"num_stages"
,
"tl_pipeline_order"
,
"tl_pipeline_stage"
,
"tl_pipeline_sync"
,
"tl_pipeline_group"
]
return
any
(
key
in
op
.
annotations
for
key
in
anno_keys
)
...
...
@@ -26,7 +23,6 @@ def is_tile_op(op: Call) -> bool:
@
tir
.
functor
.
visitor
class
_NestedLoopCheckVisitor
(
PyStmtExprVisitor
):
def
__init__
(
self
)
->
None
:
super
().
__init__
()
self
.
in_parallel_context
=
False
...
...
@@ -42,27 +38,24 @@ class _NestedLoopCheckVisitor(PyStmtExprVisitor):
# Otherwise
if
self
.
in_parallel_context
:
raise
ValueError
(
"[Tilelang Semantic Check] "
"Nested parallel loops are not allowed. "
"Please check your loop structure."
)
raise
ValueError
(
"[Tilelang Semantic Check] Nested parallel loops are not allowed. Please check your loop structure."
)
self
.
in_parallel_context
=
True
super
().
visit_for_
(
op
)
self
.
in_parallel_context
=
False
return
elif
is_pipelined_for
(
op
):
if
self
.
in_parallel_context
:
raise
ValueError
(
"[Tilelang Semantic Check] "
"
Pipelined loop cannot be nested inside a parallel loop. "
"Please check your loop structure."
)
raise
ValueError
(
"[Tilelang Semantic Check]
Pipelined loop cannot be nested inside a parallel loop.
Please check your loop structure.
"
)
super
().
visit_for_
(
op
)
def
visit_call_
(
self
,
op
:
Call
)
->
None
:
if
self
.
in_parallel_context
and
is_tile_op
(
op
):
raise
ValueError
(
"[Tilelang Semantic Check] "
"Only elementwise operations are allowed inside a parallel loop. "
\
f
"Got a tile-op
\"
{
op
.
op
}
\"
."
)
raise
ValueError
(
f
'[Tilelang Semantic Check] Only elementwise operations are allowed inside a parallel loop. Got a tile-op "
{
op
.
op
}
".'
)
def
NestedLoopChecker
():
...
...
tilelang/autotuner/capture.py
View file @
29051439
...
...
@@ -85,8 +85,7 @@ def _get_current_stack() -> CaptureStack:
class
AutotuneInputsCapture
:
__slots__
=
(
"tensors"
)
__slots__
=
"tensors"
def
__init__
(
self
,
tensors
:
list
[
Any
]):
self
.
tensors
=
tensors
...
...
tilelang/autotuner/param.py
View file @
29051439
"""The auto-tune parameters.
"""
"""The auto-tune parameters.
"""
from
__future__
import
annotations
import
tilelang
...
...
@@ -50,7 +50,7 @@ class CompileArgs:
out_idx
:
list
[
int
]
|
int
|
None
=
None
execution_backend
:
Literal
[
"auto"
,
"tvm_ffi"
,
"ctypes"
,
"cython"
,
"nvrtc"
,
"torch"
]
=
"auto"
target
:
Literal
[
'
auto
'
,
'
cuda
'
,
'
hip
'
]
=
'
auto
'
target
:
Literal
[
"
auto
"
,
"
cuda
"
,
"
hip
"
]
=
"
auto
"
target_host
:
str
|
Target
=
None
verbose
:
bool
=
False
pass_configs
:
dict
[
str
,
Any
]
|
None
=
None
...
...
@@ -62,24 +62,20 @@ class CompileArgs:
target
=
self
.
target
,
target_host
=
self
.
target_host
,
verbose
=
self
.
verbose
,
pass_configs
=
self
.
pass_configs
)
pass_configs
=
self
.
pass_configs
,
)
def
__hash__
(
self
):
data
=
{
"execution_backend"
:
self
.
execution_backend
,
"target"
:
str
(
self
.
target
),
"target_host"
:
str
(
self
.
target_host
)
if
self
.
target_host
else
None
,
"verbose"
:
self
.
verbose
,
"pass_configs"
:
json
.
dumps
(
self
.
pass_configs
,
sort_keys
=
True
)
if
self
.
pass_configs
else
None
,
"execution_backend"
:
self
.
execution_backend
,
"target"
:
str
(
self
.
target
),
"target_host"
:
str
(
self
.
target_host
)
if
self
.
target_host
else
None
,
"verbose"
:
self
.
verbose
,
"pass_configs"
:
json
.
dumps
(
self
.
pass_configs
,
sort_keys
=
True
)
if
self
.
pass_configs
else
None
,
}
hash_obj
=
hashlib
.
sha256
(
json
.
dumps
(
data
,
sort_keys
=
True
).
encode
(
'
utf-8
'
))
return
int
.
from_bytes
(
hash_obj
.
digest
(),
byteorder
=
'
big
'
)
hash_obj
=
hashlib
.
sha256
(
json
.
dumps
(
data
,
sort_keys
=
True
).
encode
(
"
utf-8
"
))
return
int
.
from_bytes
(
hash_obj
.
digest
(),
byteorder
=
"
big
"
)
@
dataclass
(
frozen
=
True
)
...
...
@@ -104,6 +100,7 @@ class ProfileArgs:
manual_check_prog: Callable = None
cache_input_tensors: bool = True
"""
warmup
:
int
=
25
rep
:
int
=
100
timeout
:
int
=
30
...
...
@@ -127,8 +124,8 @@ class ProfileArgs:
"atol"
:
self
.
atol
,
"max_mismatched_ratio"
:
self
.
max_mismatched_ratio
,
}
hash_obj
=
hashlib
.
sha256
(
json
.
dumps
(
data
,
sort_keys
=
True
).
encode
(
'
utf-8
'
))
return
int
.
from_bytes
(
hash_obj
.
digest
(),
byteorder
=
'
big
'
)
hash_obj
=
hashlib
.
sha256
(
json
.
dumps
(
data
,
sort_keys
=
True
).
encode
(
"
utf-8
"
))
return
int
.
from_bytes
(
hash_obj
.
digest
(),
byteorder
=
"
big
"
)
@
dataclass
(
frozen
=
True
)
...
...
@@ -143,6 +140,7 @@ class AutotuneResult:
func: Optimized function.
kernel: Compiled kernel function.
"""
latency
:
float
|
None
=
None
config
:
dict
|
None
=
None
ref_latency
:
float
|
None
=
None
...
...
@@ -199,8 +197,7 @@ class AutotuneResult:
if
verbose
:
logger
.
debug
(
f
"Saving kernel source code to file:
{
device_kernel_path
}
"
)
if
kernel
.
kernel_source
is
not
None
:
self
.
_safe_write_file
(
device_kernel_path
,
"w"
,
lambda
f
:
f
.
write
(
kernel
.
kernel_source
))
self
.
_safe_write_file
(
device_kernel_path
,
"w"
,
lambda
f
:
f
.
write
(
kernel
.
kernel_source
))
except
Exception
as
e
:
logger
.
error
(
f
"Error saving kernel source code to disk:
{
e
}
"
)
...
...
@@ -211,11 +208,9 @@ class AutotuneResult:
logger
.
debug
(
f
"Saving wrapped kernel source code to file:
{
host_kernel_path
}
"
)
# Match kernel_cache behavior: use host source for tvm_ffi, otherwise wrapped kernel
if
kernel
.
execution_backend
==
"tvm_ffi"
:
self
.
_safe_write_file
(
host_kernel_path
,
"w"
,
lambda
f
:
f
.
write
(
kernel
.
adapter
.
get_host_source
()))
self
.
_safe_write_file
(
host_kernel_path
,
"w"
,
lambda
f
:
f
.
write
(
kernel
.
adapter
.
get_host_source
()))
else
:
self
.
_safe_write_file
(
host_kernel_path
,
"w"
,
lambda
f
:
f
.
write
(
kernel
.
adapter
.
get_kernel_source
()))
self
.
_safe_write_file
(
host_kernel_path
,
"w"
,
lambda
f
:
f
.
write
(
kernel
.
adapter
.
get_kernel_source
()))
except
Exception
as
e
:
logger
.
error
(
f
"Error saving wrapped kernel source code to disk:
{
e
}
"
)
...
...
@@ -237,12 +232,10 @@ class AutotuneResult:
py_src_path
=
src_lib_path
.
replace
(
".cubin"
,
".py"
)
if
verbose
:
logger
.
debug
(
f
"Saving kernel nvrtc python code to file:
{
kernel_py_path
}
"
)
self
.
_safe_write_file
(
kernel_py_path
,
"wb"
,
lambda
f
:
f
.
write
(
self
.
_load_binary
(
py_src_path
)))
self
.
_safe_write_file
(
kernel_py_path
,
"wb"
,
lambda
f
:
f
.
write
(
self
.
_load_binary
(
py_src_path
)))
if
verbose
:
logger
.
debug
(
f
"Saving kernel library to file:
{
kernel_lib_path
}
"
)
self
.
_safe_write_file
(
kernel_lib_path
,
"wb"
,
lambda
f
:
f
.
write
(
self
.
_load_binary
(
src_lib_path
)))
self
.
_safe_write_file
(
kernel_lib_path
,
"wb"
,
lambda
f
:
f
.
write
(
self
.
_load_binary
(
src_lib_path
)))
elif
kernel
.
execution_backend
==
"tvm_ffi"
:
executable
=
kernel
.
adapter
.
executable
if
verbose
:
...
...
@@ -252,8 +245,7 @@ class AutotuneResult:
src_lib_path
=
kernel
.
adapter
.
libpath
if
verbose
:
logger
.
debug
(
f
"Saving kernel library to file:
{
kernel_lib_path
}
"
)
self
.
_safe_write_file
(
kernel_lib_path
,
"wb"
,
lambda
f
:
f
.
write
(
self
.
_load_binary
(
src_lib_path
)))
self
.
_safe_write_file
(
kernel_lib_path
,
"wb"
,
lambda
f
:
f
.
write
(
self
.
_load_binary
(
src_lib_path
)))
except
Exception
as
e
:
logger
.
error
(
f
"Error saving kernel library to disk:
{
e
}
"
)
...
...
@@ -370,14 +362,12 @@ class AutotuneResult:
# save best config (atomic)
if
verbose
:
logger
.
debug
(
f
"Saving best config to file:
{
path
/
BEST_CONFIG_PATH
}
"
)
self
.
_safe_write_file
(
str
(
path
/
BEST_CONFIG_PATH
),
"w"
,
lambda
f
:
json
.
dump
(
self
.
config
,
f
))
self
.
_safe_write_file
(
str
(
path
/
BEST_CONFIG_PATH
),
"w"
,
lambda
f
:
json
.
dump
(
self
.
config
,
f
))
# save function (atomic)
if
verbose
:
logger
.
debug
(
f
"Saving function to file:
{
path
/
FUNCTION_PATH
}
"
)
self
.
_safe_write_file
(
str
(
path
/
FUNCTION_PATH
),
"wb"
,
lambda
f
:
cloudpickle
.
dump
(
self
.
func
,
f
))
self
.
_safe_write_file
(
str
(
path
/
FUNCTION_PATH
),
"wb"
,
lambda
f
:
cloudpickle
.
dump
(
self
.
func
,
f
))
# save ref latency (atomic)
if
verbose
:
...
...
@@ -385,10 +375,13 @@ class AutotuneResult:
self
.
_safe_write_file
(
str
(
path
/
LATENCY_PATH
),
"w"
,
lambda
f
:
json
.
dump
({
"latency"
:
self
.
latency
,
"ref_latency"
:
self
.
ref_latency
,
},
f
),
lambda
f
:
json
.
dump
(
{
"latency"
:
self
.
latency
,
"ref_latency"
:
self
.
ref_latency
,
},
f
,
),
)
# save kernel
...
...
@@ -403,8 +396,8 @@ class AutotuneResult:
# Normalize target and resolve execution backend for loading
from
tilelang.utils.target
import
determine_target
as
_determine_target
from
tilelang.jit.execution_backend
import
resolve_execution_backend
norm_target
=
Target
(
_determine_target
(
compile_args
.
target
))
if
isinstance
(
compile_args
.
target
,
str
)
else
compile_args
.
target
norm_target
=
Target
(
_determine_target
(
compile_args
.
target
))
if
isinstance
(
compile_args
.
target
,
str
)
else
compile_args
.
target
requested_backend
=
compile_args
.
execution_backend
resolved_backend
=
resolve_execution_backend
(
requested_backend
,
norm_target
)
# load best config
...
...
tilelang/autotuner/tuner.py
View file @
29051439
...
...
@@ -3,6 +3,7 @@
This module provides functionality for auto-tuning tilelang programs, including JIT compilation
and performance optimization through configuration search.
"""
from
__future__
import
annotations
from
dataclasses
import
dataclass
...
...
@@ -14,7 +15,8 @@ from tvm.tir import PrimFunc, Var
from
tvm.target
import
Target
import
inspect
from
functools
import
partial
from
typing
import
(
Callable
,
Generic
,
Literal
,
Any
,
TypeVar
)
from
typing
import
Callable
,
Generic
,
Literal
,
Any
,
TypeVar
# Python 3.9 compatibility for ParamSpec
try
:
from
typing
import
ParamSpec
...
...
@@ -74,8 +76,8 @@ def _init_logger_handlers():
global
_logger_handlers_initialized
if
_logger_handlers_initialized
:
return
formatter
=
logging
.
Formatter
(
'
%(asctime)s %(levelname)s:%(message)s
'
)
file_handler
=
logging
.
FileHandler
(
'
autotuner.log
'
,
mode
=
'w'
)
formatter
=
logging
.
Formatter
(
"
%(asctime)s %(levelname)s:%(message)s
"
)
file_handler
=
logging
.
FileHandler
(
"
autotuner.log
"
,
mode
=
"w"
)
file_handler
.
setLevel
(
logging
.
DEBUG
)
file_handler
.
setFormatter
(
formatter
)
console_handler
=
logging
.
StreamHandler
(
sys
.
stdout
)
...
...
@@ -87,8 +89,7 @@ def _init_logger_handlers():
def
get_available_cpu_count
()
->
int
:
"""Gets the number of CPU cores available to the current process.
"""
"""Gets the number of CPU cores available to the current process."""
try
:
cpu_count
=
len
(
os
.
sched_getaffinity
(
0
))
except
AttributeError
:
...
...
@@ -107,6 +108,7 @@ class AutoTuner:
fn: The function to be auto-tuned.
configs: List of configurations to try during auto-tuning.
"""
compile_args
=
CompileArgs
()
profile_args
=
ProfileArgs
()
...
...
@@ -137,14 +139,15 @@ class AutoTuner:
"""
return
cls
(
kernel
,
configs
)
def
set_compile_args
(
self
,
out_idx
:
list
[
int
]
|
int
|
None
=
None
,
target
:
Literal
[
'auto'
,
'cuda'
,
'hip'
,
'metal'
]
=
'auto'
,
execution_backend
:
Literal
[
"auto"
,
"tvm_ffi"
,
"ctypes"
,
"cython"
,
"nvrtc"
,
"torch"
]
=
"auto"
,
target_host
:
str
|
Target
=
None
,
verbose
:
bool
=
False
,
pass_configs
:
dict
[
str
,
Any
]
|
None
=
None
):
def
set_compile_args
(
self
,
out_idx
:
list
[
int
]
|
int
|
None
=
None
,
target
:
Literal
[
"auto"
,
"cuda"
,
"hip"
,
"metal"
]
=
"auto"
,
execution_backend
:
Literal
[
"auto"
,
"tvm_ffi"
,
"ctypes"
,
"cython"
,
"nvrtc"
,
"torch"
]
=
"auto"
,
target_host
:
str
|
Target
=
None
,
verbose
:
bool
=
False
,
pass_configs
:
dict
[
str
,
Any
]
|
None
=
None
,
):
"""Set compilation arguments for the auto-tuner.
Args:
...
...
@@ -161,6 +164,7 @@ class AutoTuner:
# Normalize target to a concrete TVM Target and resolve execution backend
t
=
Target
(
determine_target
(
target
))
from
tilelang.jit.execution_backend
import
resolve_execution_backend
resolved_backend
=
resolve_execution_backend
(
execution_backend
,
t
)
self
.
compile_args
=
CompileArgs
(
...
...
@@ -169,23 +173,26 @@ class AutoTuner:
execution_backend
=
resolved_backend
,
target_host
=
target_host
,
verbose
=
verbose
,
pass_configs
=
pass_configs
)
pass_configs
=
pass_configs
,
)
return
self
def
set_profile_args
(
self
,
warmup
:
int
=
25
,
rep
:
int
=
100
,
timeout
:
int
=
30
,
supply_type
:
tilelang
.
TensorSupplyType
=
tilelang
.
TensorSupplyType
.
Auto
,
ref_prog
:
Callable
=
None
,
supply_prog
:
Callable
=
None
,
rtol
:
float
=
1e-2
,
atol
:
float
=
1e-2
,
max_mismatched_ratio
:
float
=
0.01
,
skip_check
:
bool
=
False
,
manual_check_prog
:
Callable
=
None
,
cache_input_tensors
:
bool
=
False
):
def
set_profile_args
(
self
,
warmup
:
int
=
25
,
rep
:
int
=
100
,
timeout
:
int
=
30
,
supply_type
:
tilelang
.
TensorSupplyType
=
tilelang
.
TensorSupplyType
.
Auto
,
ref_prog
:
Callable
=
None
,
supply_prog
:
Callable
=
None
,
rtol
:
float
=
1e-2
,
atol
:
float
=
1e-2
,
max_mismatched_ratio
:
float
=
0.01
,
skip_check
:
bool
=
False
,
manual_check_prog
:
Callable
=
None
,
cache_input_tensors
:
bool
=
False
,
):
"""Set profiling arguments for the auto-tuner.
Args:
...
...
@@ -209,9 +216,7 @@ class AutoTuner:
# the `supply_prog` will be ignored and the `get_autotune_inputs` will be used instead.
if
get_autotune_inputs
()
is
not
None
:
if
supply_prog
is
not
None
:
logger
.
warning
(
"`supply_prog` will be ignored as this program is under `with set_autotune_inputs` context."
)
logger
.
warning
(
"`supply_prog` will be ignored as this program is under `with set_autotune_inputs` context."
)
supply_prog
=
lambda
_
:
get_autotune_inputs
()
# noqa: E731
self
.
profile_args
=
ProfileArgs
(
...
...
@@ -226,13 +231,13 @@ class AutoTuner:
cache_input_tensors
=
cache_input_tensors
,
warmup
=
warmup
,
rep
=
rep
,
timeout
=
timeout
)
timeout
=
timeout
,
)
# If a custom `supply_prog` is provided, the profiler's `supply_type` setting
# becomes ineffective. The custom supply program will be used instead.
if
supply_prog
is
not
None
and
supply_type
!=
tilelang
.
TensorSupplyType
.
Auto
:
logger
.
warning
(
"Ignoring `supply_type` passed to `set_profile_args` because "
"`supply_prog` is not None."
)
logger
.
warning
(
"Ignoring `supply_type` passed to `set_profile_args` because `supply_prog` is not None."
)
return
self
...
...
@@ -241,10 +246,8 @@ class AutoTuner:
self
.
_kernel_parameters
=
k_parameters
self
.
_function_parameters
=
f_parameters
def
generate_cache_key
(
self
,
parameters
:
dict
[
str
,
Any
],
extra_parameters
:
dict
[
str
,
Any
])
->
AutotuneResult
|
None
:
"""Generate a cache key for the auto-tuning process.
"""
def
generate_cache_key
(
self
,
parameters
:
dict
[
str
,
Any
],
extra_parameters
:
dict
[
str
,
Any
])
->
AutotuneResult
|
None
:
"""Generate a cache key for the auto-tuning process."""
def
_normalize_param
(
value
):
if
isinstance
(
value
,
Var
):
...
...
@@ -315,8 +318,9 @@ class AutoTuner:
if
var_name
in
parameters
:
continue
# Cell content must be serializable
assert
isinstance
(
cell
.
cell_contents
,
(
int
,
float
,
str
,
bool
,
type
(
None
))),
\
assert
isinstance
(
cell
.
cell_contents
,
(
int
,
float
,
str
,
bool
,
type
(
None
))),
(
f
"Cell contents
{
cell
.
cell_contents
}
is not serializable:
{
type
(
cell
.
cell_contents
)
}
"
)
extra_parameters
[
var_name
]
=
cell
.
cell_contents
if
isinstance
(
self
.
configs
,
Callable
):
...
...
@@ -328,8 +332,10 @@ class AutoTuner:
if
env
.
is_cache_enabled
()
and
not
env
.
is_autotune_cache_disabled
():
# First check in-memory cache
if
key
in
self
.
_memory_cache
:
logger
.
warning
(
"Found kernel in memory cache. For better performance,"
\
" consider using `@tilelang.autotune` instead of direct AutoTuner.from_kernel."
)
logger
.
warning
(
"Found kernel in memory cache. For better performance,"
" consider using `@tilelang.autotune` instead of direct AutoTuner.from_kernel."
)
return
self
.
_memory_cache
[
key
]
# Then check disk cache
...
...
@@ -369,7 +375,6 @@ class AutoTuner:
# This encapsulates the logic of using either a custom supply program (`supply_prog`)
# or the default profiler input generation (`profiler._get_inputs`).
def
get_input_tensors_supply
(
with_output
:
bool
):
def
func
():
if
supply_prog
is
not
None
:
return
supply_prog
(
profiler
.
_get_params
(
with_output
=
with_output
))
...
...
@@ -387,8 +392,7 @@ class AutoTuner:
self
.
jit_input_tensors
=
jit_input_tensors_supply
()
else
:
# check if the cached tensors are compatible with the current configuration
assert
len
(
params
)
==
len
(
self
.
jit_input_tensors
),
"len(params) != len(self.jit_input_tensors)"
assert
len
(
params
)
==
len
(
self
.
jit_input_tensors
),
"len(params) != len(self.jit_input_tensors)"
for
p
,
c
in
zip
(
params
,
self
.
jit_input_tensors
):
if
not
isinstance
(
c
,
torch
.
Tensor
):
# skip non-tensor inputs checking
...
...
@@ -397,8 +401,8 @@ class AutoTuner:
# Check tensor compatibility using generator expression
def
shape_equal
(
a
,
b
):
return
all
(
a_dim
==
b_dim
or
isinstance
(
a_dim
,
Var
)
or
isinstance
(
b_dim
,
Var
)
for
a_dim
,
b_dim
in
zip
(
a
.
shape
,
b
.
shape
)
)
a_dim
==
b_dim
or
isinstance
(
a_dim
,
Var
)
or
isinstance
(
b_dim
,
Var
)
for
a_dim
,
b_dim
in
zip
(
a
.
shape
,
b
.
shape
)
)
if
p
.
dtype
!=
c
.
dtype
or
not
shape_equal
(
p
,
c
):
logger
.
warning
(
...
...
@@ -409,7 +413,8 @@ class AutoTuner:
"To ensure fresh, compatible inputs are generated for every trial "
"you can disable caching by setting:
\n
"
" `cache_input_tensors=False`
\n
"
"within your `.set_compile_args(...)` call.
\n
"
)
"within your `.set_compile_args(...)` call.
\n
"
)
# otherwise, regenerate the input tensors for safety
self
.
jit_input_tensors
=
jit_input_tensors_supply
()
break
...
...
@@ -418,24 +423,16 @@ class AutoTuner:
if
(
not
skip_check
)
and
(
ref_prog
is
not
None
):
if
manual_check_prog
is
not
None
:
profiler
.
manual_assert_close
(
ref_prog
,
input_tensors
=
self
.
jit_input_tensors
,
manual_check_prog
=
manual_check_prog
)
profiler
.
manual_assert_close
(
ref_prog
,
input_tensors
=
self
.
jit_input_tensors
,
manual_check_prog
=
manual_check_prog
)
else
:
profiler
.
assert_allclose
(
ref_prog
,
input_tensors
=
self
.
jit_input_tensors
,
rtol
=
rtol
,
atol
=
atol
,
max_mismatched_ratio
=
max_mismatched_ratio
)
latency
=
profiler
.
do_bench
(
warmup
=
warmup
,
rep
=
rep
,
input_tensors
=
self
.
jit_input_tensors
)
ref_prog
,
input_tensors
=
self
.
jit_input_tensors
,
rtol
=
rtol
,
atol
=
atol
,
max_mismatched_ratio
=
max_mismatched_ratio
)
latency
=
profiler
.
do_bench
(
warmup
=
warmup
,
rep
=
rep
,
input_tensors
=
self
.
jit_input_tensors
)
if
self
.
ref_latency_cache
is
None
and
ref_prog
is
not
None
:
self
.
ref_input_tensors
=
ref_input_tensors_supply
()
self
.
ref_latency_cache
=
profiler
.
do_bench
(
ref_prog
,
n_warmup
=
warmup
,
n_repeat
=
rep
,
input_tensors
=
self
.
ref_input_tensors
)
self
.
ref_latency_cache
=
profiler
.
do_bench
(
ref_prog
,
n_warmup
=
warmup
,
n_repeat
=
rep
,
input_tensors
=
self
.
ref_input_tensors
)
return
latency
,
self
.
ref_latency_cache
...
...
@@ -469,17 +466,14 @@ class AutoTuner:
# Check if all tunable arguments have been tuned by comparing config keys with key_kwargs_tuple
if
any
(
key
in
top_config
for
key
,
_
in
key_kwargs_tuple
)
or
any
(
check_tunable_argument_value
(
key
,
self
.
_function_parameters
,
key_args_tuple
)
for
key
in
tunable_arguments
):
check_tunable_argument_value
(
key
,
self
.
_function_parameters
,
key_args_tuple
)
for
key
in
tunable_arguments
):
logger
.
warning
(
f
"Tunable parameters
{
tunable_arguments
}
already provided during auto-tuning. Skipping compilation and using direct JIT"
)
# compile the kernel with the provided parameters
jit_kernel
=
self
.
jit_compile
()
autotuner_result
=
AutotuneResult
(
libcode
=
jit_kernel
.
get_kernel_source
(),
func
=
jit_kernel
.
prim_func
,
kernel
=
jit_kernel
)
autotuner_result
=
AutotuneResult
(
libcode
=
jit_kernel
.
get_kernel_source
(),
func
=
jit_kernel
.
prim_func
,
kernel
=
jit_kernel
)
self
.
_memory_cache
[
key
]
=
autotuner_result
return
autotuner_result
# get the cpu count
...
...
@@ -489,9 +483,7 @@ class AutoTuner:
max_cpu_count
=
int
(
env
.
TILELANG_AUTO_TUNING_MAX_CPU_COUNT
)
if
cpu_counts
>
0
:
num_workers
=
min
(
cpu_counts
,
available_cpu_count
)
logger
.
info
(
f
"Auto-tuning with
{
cpu_counts
}
CPU counts,
{
available_cpu_count
}
CPUs available,
{
num_workers
}
CPUs will be used"
)
logger
.
info
(
f
"Auto-tuning with
{
cpu_counts
}
CPU counts,
{
available_cpu_count
}
CPUs available,
{
num_workers
}
CPUs will be used"
)
else
:
num_workers
=
max
(
1
,
int
(
available_cpu_count
*
cpu_utilizations
))
logger
.
info
(
...
...
@@ -509,7 +501,6 @@ class AutoTuner:
future_to_index
=
{}
def
cuda_device_wrapper
(
func
,
device
):
def
inner
(
**
config_arg
):
torch
.
cuda
.
set_device
(
device
)
return
func
(
**
config_arg
)
...
...
@@ -532,18 +523,14 @@ class AutoTuner:
future_to_index
[
future
]
=
i
results_with_configs
=
[]
for
future
in
tqdm
(
concurrent
.
futures
.
as_completed
(
futures
),
total
=
len
(
futures
),
desc
=
"Compiling configurations"
):
for
future
in
tqdm
(
concurrent
.
futures
.
as_completed
(
futures
),
total
=
len
(
futures
),
desc
=
"Compiling configurations"
):
idx
=
future_to_index
[
future
]
config
=
config_args
[
idx
]
try
:
result
=
future
.
result
()
results_with_configs
.
append
((
result
,
config
))
except
Exception
as
e
:
logger
.
debug
(
f
"Compilation failed for config
{
config
}
at index
{
idx
}
with error:
{
e
}
"
)
logger
.
debug
(
f
"Compilation failed for config
{
config
}
at index
{
idx
}
with error:
{
e
}
"
)
continue
ref_latency
=
None
...
...
@@ -556,14 +543,10 @@ class AutoTuner:
# latency, ref_latency = target_fn(jit_kernel)
latency
,
ref_latency
=
run_with_timeout
(
target_fn
,
timeout
,
jit_kernel
)
except
TimeoutException
:
logger
.
warning
(
f
"A timeout occurred while testing config
{
config
}
, checkout autotuner.log for more details"
)
logger
.
warning
(
f
"A timeout occurred while testing config
{
config
}
, checkout autotuner.log for more details"
)
continue
except
Exception
:
logger
.
warning
(
f
"An error occurred while testing config
{
config
}
, checkout autotuner.log for more details"
)
logger
.
warning
(
f
"An error occurred while testing config
{
config
}
, checkout autotuner.log for more details"
)
logger
.
debug
(
f
"Error:
{
traceback
.
format_exc
()
}
"
)
continue
...
...
@@ -578,8 +561,7 @@ class AutoTuner:
pool
.
shutdown
()
if
best_kernel
is
None
:
error_msg
=
(
"Auto-tuning failed: No configuration successfully "
"compiled and passed benchmarking/validation."
)
error_msg
=
"Auto-tuning failed: No configuration successfully compiled and passed benchmarking/validation."
logger
.
error
(
error_msg
)
raise
RuntimeError
(
error_msg
)
...
...
@@ -595,7 +577,8 @@ class AutoTuner:
ref_latency
=
ref_latency
,
libcode
=
best_kernel
.
get_kernel_source
(),
func
=
best_kernel
.
prim_func
,
kernel
=
best_kernel
)
kernel
=
best_kernel
,
)
if
self
.
compile_args
.
execution_backend
in
(
"torch"
):
logger
.
warning
(
"DLPack backend does not support cache saving to disk."
)
...
...
@@ -617,8 +600,8 @@ class AutoTuner:
return
self
.
run
()
_P
=
ParamSpec
(
'
_P
'
)
_T
=
TypeVar
(
'
_T
'
)
_P
=
ParamSpec
(
"
_P
"
)
_T
=
TypeVar
(
"
_T
"
)
@
dataclass
...
...
@@ -643,8 +626,9 @@ class AutoTuneImpl(Generic[_P, _T]):
self
.
_tuner_cache
=
{}
def
get_tunner
(
self
):
autotuner
=
AutoTuner
(
self
.
jit_impl
.
func
,
configs
=
self
.
configs
).
set_profile_args
(
autotuner
=
(
AutoTuner
(
self
.
jit_impl
.
func
,
configs
=
self
.
configs
)
.
set_profile_args
(
supply_type
=
self
.
supply_type
,
ref_prog
=
self
.
ref_prog
,
supply_prog
=
self
.
supply_prog
,
...
...
@@ -654,7 +638,8 @@ class AutoTuneImpl(Generic[_P, _T]):
skip_check
=
self
.
skip_check
,
manual_check_prog
=
self
.
manual_check_prog
,
cache_input_tensors
=
self
.
cache_input_tensors
,
).
set_compile_args
(
)
.
set_compile_args
(
out_idx
=
self
.
jit_impl
.
out_idx
,
execution_backend
=
self
.
jit_impl
.
execution_backend
,
target
=
self
.
jit_impl
.
target
,
...
...
@@ -662,6 +647,7 @@ class AutoTuneImpl(Generic[_P, _T]):
verbose
=
self
.
jit_impl
.
verbose
,
pass_configs
=
self
.
jit_impl
.
pass_configs
,
)
)
autotuner
.
run
=
partial
(
autotuner
.
run
,
self
.
warmup
,
self
.
rep
,
self
.
timeout
)
return
autotuner
...
...
@@ -753,16 +739,13 @@ def autotune( # This is the new public interface
if
callable
(
func
):
# Case 1: Used as @autotune (func_or_out_idx is the function, others are defaults)
# This is a placeholder for a real auto tuner implementation
raise
ValueError
(
"Use tilelang.autotune to decorate func without arguments is not supported yet."
)
raise
ValueError
(
"Use tilelang.autotune to decorate func without arguments is not supported yet."
)
elif
isinstance
(
func
,
PrimFunc
):
raise
ValueError
(
"Use tilelang.jit to decorate prim_func is not supported yet."
)
else
:
def
decorator
(
impl
):
assert
isinstance
(
impl
,
JITImpl
),
"The @autotune decorator can only be applied to @tilelang.jit decorated instances."
assert
isinstance
(
impl
,
JITImpl
),
"The @autotune decorator can only be applied to @tilelang.jit decorated instances."
return
AutoTuneImpl
(
jit_impl
=
impl
,
configs
=
configs
,
...
...
tilelang/cache/__init__.py
View file @
29051439
"""The cache utils with class and database persistence - Init file"""
from
__future__
import
annotations
from
typing
import
Literal
...
...
@@ -18,8 +19,7 @@ def cached(
*
args
,
target
:
str
|
Target
=
"auto"
,
target_host
:
str
|
Target
=
None
,
execution_backend
:
Literal
[
"auto"
,
"tvm_ffi"
,
"ctypes"
,
"cython"
,
"nvrtc"
,
"torch"
]
|
None
=
"auto"
,
execution_backend
:
Literal
[
"auto"
,
"tvm_ffi"
,
"ctypes"
,
"cython"
,
"nvrtc"
,
"torch"
]
|
None
=
"auto"
,
verbose
:
bool
|
None
=
False
,
pass_configs
:
dict
|
None
=
None
,
compile_flags
:
list
[
str
]
|
str
|
None
=
None
,
...
...
@@ -36,7 +36,8 @@ def cached(
execution_backend
=
execution_backend
,
verbose
=
verbose
,
pass_configs
=
pass_configs
,
compile_flags
=
compile_flags
)
compile_flags
=
compile_flags
,
)
def
clear_cache
():
...
...
@@ -47,9 +48,11 @@ def clear_cache():
RuntimeError: Always raised to warn users to clear the cache manually.
"""
cache_dir
=
env
.
TILELANG_CACHE_DIR
raise
RuntimeError
(
"tilelang.clear_cache() is disabled because deleting the cache directory "
"is dangerous. If you accept the risk, remove it manually with "
f
"`rm -rf '
{
cache_dir
}
'`."
)
raise
RuntimeError
(
"tilelang.clear_cache() is disabled because deleting the cache directory "
"is dangerous. If you accept the risk, remove it manually with "
f
"`rm -rf '
{
cache_dir
}
'`."
)
if
env
.
TILELANG_CLEAR_CACHE
.
lower
()
in
(
"1"
,
"true"
,
"yes"
,
"on"
):
...
...
tilelang/cache/kernel_cache.py
View file @
29051439
"""The cache utils with class and database persistence - KernelCache Class"""
from
__future__
import
annotations
import
json
...
...
@@ -97,9 +98,7 @@ class KernelCache:
"version"
:
__version__
,
"func"
:
sha256
(
func_binary
).
hexdigest
(),
# Use SHA256 to generate hash key
"out_idx"
:
(
tuple
(
out_idx
)
if
isinstance
(
out_idx
,
(
list
,
tuple
))
else
[
out_idx
]),
"args_repr"
:
tuple
(
repr
(
arg
)
for
arg
in
args
),
# Use repr to serialize arguments, may need more robust serialization
"args_repr"
:
tuple
(
repr
(
arg
)
for
arg
in
args
),
# Use repr to serialize arguments, may need more robust serialization
"target"
:
str
(
target
),
"target_host"
:
str
(
target_host
)
if
target_host
else
None
,
"execution_backend"
:
execution_backend
,
...
...
@@ -118,8 +117,7 @@ class KernelCache:
*
args
,
target
:
str
|
Target
=
"auto"
,
target_host
:
str
|
Target
=
None
,
execution_backend
:
Literal
[
"auto"
,
"tvm_ffi"
,
"ctypes"
,
"cython"
,
"nvrtc"
,
"torch"
]
=
"auto"
,
execution_backend
:
Literal
[
"auto"
,
"tvm_ffi"
,
"ctypes"
,
"cython"
,
"nvrtc"
,
"torch"
]
=
"auto"
,
verbose
:
bool
=
False
,
pass_configs
:
dict
=
None
,
compile_flags
:
list
[
str
]
|
str
|
None
=
None
,
...
...
@@ -140,6 +138,7 @@ class KernelCache:
# Normalize target and resolve execution backend before proceeding
from
tilelang.utils.target
import
determine_target
as
_determine_target
from
tilelang.jit.execution_backend
import
resolve_execution_backend
,
allowed_backends_for_target
norm_target
=
Target
(
_determine_target
(
target
))
if
isinstance
(
target
,
str
)
else
target
requested_backend
=
execution_backend
execution_backend
=
resolve_execution_backend
(
requested_backend
,
norm_target
)
...
...
@@ -180,21 +179,21 @@ class KernelCache:
with
self
.
_lock
:
# First check in-memory cache
if
key
in
self
.
_memory_cache
:
self
.
logger
.
warning
(
"Found kernel in memory cache. For better performance,"
\
" consider using `@tilelang.jit` instead of direct kernel caching."
)
self
.
logger
.
warning
(
"Found kernel in memory cache. For better performance, consider using `@tilelang.jit` instead of direct kernel caching."
)
return
self
.
_memory_cache
[
key
]
if
verbose
:
self
.
logger
.
debug
(
f
"Checking disk cache for kernel
{
func
.
attrs
[
'global_symbol'
]
}
"
)
# Then check disk cache
kernel
=
self
.
_load_kernel_from_disk
(
key
,
norm_target
,
target_host
,
out_idx
,
execution_backend
,
pass_configs
,
compile_flags
,
func
,
verbose
)
kernel
=
self
.
_load_kernel_from_disk
(
key
,
norm_target
,
target_host
,
out_idx
,
execution_backend
,
pass_configs
,
compile_flags
,
func
,
verbose
)
if
kernel
is
not
None
:
if
verbose
:
self
.
logger
.
debug
(
f
"Found kernel in disk cache for
{
func
.
attrs
[
'global_symbol'
]
}
"
)
self
.
logger
.
debug
(
f
"Found kernel in disk cache for
{
func
.
attrs
[
'global_symbol'
]
}
"
)
# Populate memory cache with disk result
self
.
_memory_cache
[
key
]
=
kernel
return
kernel
...
...
@@ -262,11 +261,7 @@ class KernelCache:
executable
.
export_library
(
temp_path
)
os
.
replace
(
temp_path
,
path
)
def
_save_kernel_to_disk
(
self
,
key
:
str
,
kernel
:
JITKernel
,
func
:
Callable
=
None
,
verbose
:
bool
=
False
):
def
_save_kernel_to_disk
(
self
,
key
:
str
,
kernel
:
JITKernel
,
func
:
Callable
=
None
,
verbose
:
bool
=
False
):
"""
Persists a compiled kernel to disk cache.
...
...
@@ -292,8 +287,7 @@ class KernelCache:
if
verbose
:
self
.
logger
.
debug
(
f
"Saving kernel source code to file:
{
device_kernel_path
}
"
)
if
kernel
.
kernel_source
is
not
None
:
KernelCache
.
_safe_write_file
(
device_kernel_path
,
"w"
,
lambda
file
:
file
.
write
(
kernel
.
kernel_source
))
KernelCache
.
_safe_write_file
(
device_kernel_path
,
"w"
,
lambda
file
:
file
.
write
(
kernel
.
kernel_source
))
except
Exception
as
e
:
self
.
logger
.
error
(
f
"Error saving kernel source code to disk:
{
e
}
"
)
...
...
@@ -303,13 +297,9 @@ class KernelCache:
if
verbose
:
self
.
logger
.
debug
(
f
"Saving wrapped kernel source code to file:
{
host_kernel_path
}
"
)
if
self
.
execution_backend
==
"tvm_ffi"
:
KernelCache
.
_safe_write_file
(
host_kernel_path
,
"w"
,
lambda
file
:
file
.
write
(
kernel
.
adapter
.
get_host_source
()))
KernelCache
.
_safe_write_file
(
host_kernel_path
,
"w"
,
lambda
file
:
file
.
write
(
kernel
.
adapter
.
get_host_source
()))
else
:
KernelCache
.
_safe_write_file
(
host_kernel_path
,
"w"
,
lambda
file
:
file
.
write
(
kernel
.
adapter
.
get_kernel_source
()))
KernelCache
.
_safe_write_file
(
host_kernel_path
,
"w"
,
lambda
file
:
file
.
write
(
kernel
.
adapter
.
get_kernel_source
()))
except
Exception
as
e
:
self
.
logger
.
error
(
f
"Error saving host kernel source code to disk:
{
e
}
"
)
...
...
@@ -332,9 +322,7 @@ class KernelCache:
src_lib_path
=
src_lib_path
.
replace
(
".cubin"
,
".py"
)
if
verbose
:
self
.
logger
.
debug
(
f
"Saving kernel nvrtc python code to file:
{
kernel_py_path
}
"
)
KernelCache
.
_safe_write_file
(
kernel_py_path
,
"wb"
,
lambda
file
:
file
.
write
(
KernelCache
.
_load_binary
(
src_lib_path
)))
KernelCache
.
_safe_write_file
(
kernel_py_path
,
"wb"
,
lambda
file
:
file
.
write
(
KernelCache
.
_load_binary
(
src_lib_path
)))
elif
self
.
execution_backend
==
"tvm_ffi"
:
executable
=
kernel
.
adapter
.
executable
if
verbose
:
...
...
@@ -344,9 +332,7 @@ class KernelCache:
src_lib_path
=
kernel
.
adapter
.
libpath
if
verbose
:
self
.
logger
.
debug
(
f
"Saving kernel library to file:
{
kernel_lib_path
}
"
)
KernelCache
.
_safe_write_file
(
kernel_lib_path
,
"wb"
,
lambda
file
:
file
.
write
(
KernelCache
.
_load_binary
(
src_lib_path
)))
KernelCache
.
_safe_write_file
(
kernel_lib_path
,
"wb"
,
lambda
file
:
file
.
write
(
KernelCache
.
_load_binary
(
src_lib_path
)))
except
Exception
as
e
:
self
.
logger
.
error
(
f
"Error saving kernel library to disk:
{
e
}
"
)
...
...
@@ -356,8 +342,7 @@ class KernelCache:
params_path
=
os
.
path
.
join
(
cache_path
,
PARAMS_PATH
)
if
verbose
:
self
.
logger
.
debug
(
f
"Saving kernel parameters to disk:
{
params_path
}
"
)
KernelCache
.
_safe_write_file
(
params_path
,
"wb"
,
lambda
file
:
cloudpickle
.
dump
(
kernel
.
params
,
file
))
KernelCache
.
_safe_write_file
(
params_path
,
"wb"
,
lambda
file
:
cloudpickle
.
dump
(
kernel
.
params
,
file
))
except
Exception
as
e
:
self
.
logger
.
error
(
f
"Error saving kernel parameters to disk:
{
e
}
"
)
...
...
@@ -417,8 +402,7 @@ class KernelCache:
self
.
logger
.
error
(
f
"Error loading kernel source code from disk:
{
e
}
"
)
try
:
if
verbose
:
self
.
logger
.
debug
(
f
"Loading wrapped kernel source code from file:
{
host_kernel_path
}
"
)
self
.
logger
.
debug
(
f
"Loading wrapped kernel source code from file:
{
host_kernel_path
}
"
)
with
open
(
host_kernel_path
)
as
f
:
host_kernel_source
=
f
.
read
()
except
Exception
as
e
:
...
...
tilelang/carver/__init__.py
View file @
29051439
"""Base infra"""
from
.analysis
import
(
BlockInfo
,
# noqa: F401
IterInfo
,
# noqa: F401
...
...
tilelang/carver/analysis.py
View file @
29051439
"""Analysis on TIR blocks, loops and functions."""
from
__future__
import
annotations
from
typing_extensions
import
Literal
...
...
@@ -144,11 +145,13 @@ def normalize_prim_func(sch: tir.Schedule) -> list[BlockInfo] | None:
var
=
iter
.
var
,
dom
=
iter
.
dom
,
loop_rv
=
loop
,
)
for
loop
,
iter
in
zip
(
loops
,
iters
)
)
for
loop
,
iter
in
zip
(
loops
,
iters
)
],
block_rv
=
block
,
reduction_block
=
is_reduction
,
))
)
)
return
blocks
...
...
@@ -188,8 +191,7 @@ def get_max_shared_memory_per_block(target: Target) -> int:
_assert_gpu_target
(
target
)
max_shared_memory_per_block
=
target
.
attrs
.
get
(
"max_shared_memory_per_block"
,
None
)
if
max_shared_memory_per_block
is
None
:
raise
ValueError
(
f
"Cannot find `max_shared_memory_per_block` in
{
target
}
, please specify it manually"
)
raise
ValueError
(
f
"Cannot find `max_shared_memory_per_block` in
{
target
}
, please specify it manually"
)
return
int
(
max_shared_memory_per_block
)
...
...
@@ -197,13 +199,11 @@ def get_root_block(sch: Schedule, func_name: str = "main") -> BlockRV:
try
:
block
=
sch
.
mod
[
func_name
].
body
.
block
except
Exception
:
raise
ValueError
(
f
"The function body is expected to be the root block, but got:
\n
"
f
"
{
sch
.
mod
[
func_name
].
body
}
"
)
from
None
raise
ValueError
(
f
"The function body is expected to be the root block, but got:
\n
{
sch
.
mod
[
func_name
].
body
}
"
)
from
None
return
sch
.
get_block
(
block
.
name_hint
)
def
collect_block_iter_vars_used_in_access_region
(
block
:
tir
.
Block
,
region
:
list
[
ir
.
Range
])
->
set
[
tir
.
Var
]:
def
collect_block_iter_vars_used_in_access_region
(
block
:
tir
.
Block
,
region
:
list
[
ir
.
Range
])
->
set
[
tir
.
Var
]:
"""Collect the block iter variables used in the access region of a buffer region."""
tir_vars
=
set
()
for
expr
in
region
:
...
...
@@ -251,15 +251,13 @@ def is_broadcast_epilogue(
for
buffer_region
in
sch
.
get
(
epilogue
).
reads
:
if
buffer_region
.
buffer
not
in
write_buffers
:
continue
tir_vars
=
collect_block_iter_vars_used_in_access_region
(
sch
.
get
(
epilogue
),
buffer_region
.
region
)
tir_vars
=
collect_block_iter_vars_used_in_access_region
(
sch
.
get
(
epilogue
),
buffer_region
.
region
)
if
len
(
tir_vars
)
<
len
(
epilogue_iters
):
return
True
return
False
def
get_reduction_blocks
(
sch
:
tir
.
Schedule
,
blocks
:
list
[
tir
.
schedule
.
BlockRV
])
->
list
[
tir
.
schedule
.
BlockRV
]:
def
get_reduction_blocks
(
sch
:
tir
.
Schedule
,
blocks
:
list
[
tir
.
schedule
.
BlockRV
])
->
list
[
tir
.
schedule
.
BlockRV
]:
# Get the main computation block
def
is_reduction
(
block
:
BlockRV
)
->
bool
:
block_stmt
=
sch
.
get
(
block
)
...
...
tilelang/carver/arch/__init__.py
View file @
29051439
...
...
@@ -39,18 +39,18 @@ def auto_infer_current_arch() -> TileDevice:
__all__
=
[
'
is_cpu_arch
'
,
'
is_cuda_arch
'
,
'
is_volta_arch
'
,
'
is_ampere_arch
'
,
'
is_ada_arch
'
,
'
is_hopper_arch
'
,
'
is_tensorcore_supported_precision
'
,
'
has_mma_support
'
,
'
is_cdna_arch
'
,
'
is_metal_arch
'
,
'
CUDA
'
,
'
CDNA
'
,
'
METAL
'
,
'
CPU
'
,
"
is_cpu_arch
"
,
"
is_cuda_arch
"
,
"
is_volta_arch
"
,
"
is_ampere_arch
"
,
"
is_ada_arch
"
,
"
is_hopper_arch
"
,
"
is_tensorcore_supported_precision
"
,
"
has_mma_support
"
,
"
is_cdna_arch
"
,
"
is_metal_arch
"
,
"
CUDA
"
,
"
CDNA
"
,
"
METAL
"
,
"
CPU
"
,
]
tilelang/carver/arch/arch_base.py
View file @
29051439
...
...
@@ -7,9 +7,7 @@ class TileDevice:
self
.
reg_cap
:
int
=
0
# Register capacity: The amount of register memory available
self
.
smem_cap
:
int
=
0
# Shared memory capacity: The amount of shared memory available
self
.
compute_max_core
:
int
=
0
# The maximum number of computing cores
self
.
warp_size
:
int
=
(
0
# The size of a warp, a group of threads that execute instructions in lockstep
)
self
.
warp_size
:
int
=
0
# The size of a warp, a group of threads that execute instructions in lockstep
self
.
sm_partition
:
int
=
0
# The number of streaming multiprocessor partitions
self
.
transaction_size
:
list
[
int
]
=
[
0
,
...
...
@@ -21,9 +19,7 @@ class TileDevice:
0
,
]
# Bandwidth specifications, possibly including peak and sustained rates
self
.
platform
:
str
=
"unknown"
# The platform or manufacturer of the device
self
.
compute_capability
:
str
=
(
"unknown"
# The compute capability, indicating the feature set and performance level
)
self
.
compute_capability
:
str
=
"unknown"
# The compute capability, indicating the feature set and performance level
self
.
l2_cache_size_bytes
:
int
=
0
# the number of transaction size in bytes
self
.
transaction_size
:
list
[
int
]
=
[
0
,
0
]
# in bytes
...
...
tilelang/carver/arch/cdna.py
View file @
29051439
...
...
@@ -9,7 +9,6 @@ def is_cdna_arch(arch: TileDevice) -> bool:
class
CDNA
(
TileDevice
):
def
__init__
(
self
,
target
:
Target
|
str
):
if
isinstance
(
target
,
str
):
target
=
tvm
.
target
.
Target
(
target
)
...
...
@@ -33,6 +32,6 @@ class CDNA(TileDevice):
__all__
=
[
'
is_cdna_arch
'
,
'
CDNA
'
,
"
is_cdna_arch
"
,
"
CDNA
"
,
]
tilelang/carver/arch/cpu.py
View file @
29051439
...
...
@@ -10,7 +10,6 @@ def is_cpu_arch(arch: TileDevice) -> bool:
# For LLVM Backend, we do not provide the detailed information of the CPU
# As the LLVM backend do not required tuning, just maintain the consistency
class
CPU
(
TileDevice
):
def
__init__
(
self
,
target
:
Target
):
self
.
target
=
target
device
=
tvm
.
runtime
.
cpu
(
0
)
...
...
@@ -21,6 +20,6 @@ class CPU(TileDevice):
__all__
=
[
'
is_cpu_arch
'
,
'
CPU
'
,
"
is_cpu_arch
"
,
"
CPU
"
,
]
tilelang/carver/arch/cuda.py
View file @
29051439
...
...
@@ -78,7 +78,6 @@ hopper_tensorcore_supported = ada_tensorcore_supported
# instead of assuming both a and b share the same dtype.
# As the tensorcore may supports float8_e4m3 * float8_e5m2
def
is_tensorcore_supported_precision
(
in_dtype
:
str
,
accum_dtype
:
str
,
arch
:
TileDevice
)
->
bool
:
if
is_volta_arch
(
arch
):
return
(
in_dtype
,
accum_dtype
)
in
volta_tensorcore_supported
elif
is_ampere_arch
(
arch
):
...
...
@@ -92,7 +91,6 @@ def is_tensorcore_supported_precision(in_dtype: str, accum_dtype: str, arch: Til
class
TensorInstruction
:
def
__init__
(
self
,
name
:
str
,
...
...
@@ -104,7 +102,6 @@ class TensorInstruction:
class
CUDA
(
TileDevice
):
def
__init__
(
self
,
target
:
Target
|
str
):
if
isinstance
(
target
,
str
):
target
=
tvm
.
target
.
Target
(
target
)
...
...
@@ -148,12 +145,12 @@ class CUDA(TileDevice):
__all__
=
[
'
is_cuda_arch
'
,
'
is_volta_arch
'
,
'
is_ampere_arch
'
,
'
is_ada_arch
'
,
'
is_hopper_arch
'
,
'
is_tensorcore_supported_precision
'
,
'
has_mma_support
'
,
"
is_cuda_arch
"
,
"
is_volta_arch
"
,
"
is_ampere_arch
"
,
"
is_ada_arch
"
,
"
is_hopper_arch
"
,
"
is_tensorcore_supported_precision
"
,
"
has_mma_support
"
,
"CUDA"
,
]
tilelang/carver/arch/driver/cuda_driver.py
View file @
29051439
...
...
@@ -83,8 +83,7 @@ def get_max_dynamic_shared_size_bytes(device_id: int = 0, format: str = "bytes")
Get the maximum dynamic shared memory size in bytes, kilobytes, or megabytes.
"""
assert
format
in
[
"bytes"
,
"kb"
,
"mb"
],
"Invalid format. Must be one of: bytes, kb, mb"
shared_mem
=
get_device_attribute
(
cudaDeviceAttrNames
.
cudaDevAttrMaxSharedMemoryPerMultiprocessor
,
device_id
)
shared_mem
=
get_device_attribute
(
cudaDeviceAttrNames
.
cudaDevAttrMaxSharedMemoryPerMultiprocessor
,
device_id
)
if
format
==
"bytes"
:
return
shared_mem
elif
format
==
"kb"
:
...
...
Prev
1
…
13
14
15
16
17
18
19
20
21
22
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment