Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
tilelang
Commits
29051439
Unverified
Commit
29051439
authored
Dec 12, 2025
by
Lei Wang
Committed by
GitHub
Dec 12, 2025
Browse files
[Lint] Phaseout Yapf format and embrace ruff format (#1417)
parent
e84b24bc
Changes
426
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
239 additions
and
296 deletions
+239
-296
testing/python/transform/test_tilelang_transform_warp_specialized.py
...hon/transform/test_tilelang_transform_warp_specialized.py
+37
-35
testing/python/utils/test_compress_utils.py
testing/python/utils/test_compress_utils.py
+1
-1
testing/python/webgpu/test_webgpu_codegen.py
testing/python/webgpu/test_webgpu_codegen.py
+3
-4
tilelang/__init__.py
tilelang/__init__.py
+2
-0
tilelang/analysis/fragment_loop_checker.py
tilelang/analysis/fragment_loop_checker.py
+9
-9
tilelang/analysis/layout_visual.py
tilelang/analysis/layout_visual.py
+1
-5
tilelang/analysis/nested_loop_checker.py
tilelang/analysis/nested_loop_checker.py
+8
-15
tilelang/autotuner/capture.py
tilelang/autotuner/capture.py
+1
-2
tilelang/autotuner/param.py
tilelang/autotuner/param.py
+33
-40
tilelang/autotuner/tuner.py
tilelang/autotuner/tuner.py
+77
-94
tilelang/cache/__init__.py
tilelang/cache/__init__.py
+9
-6
tilelang/cache/kernel_cache.py
tilelang/cache/kernel_cache.py
+19
-35
tilelang/carver/__init__.py
tilelang/carver/__init__.py
+1
-0
tilelang/carver/analysis.py
tilelang/carver/analysis.py
+10
-12
tilelang/carver/arch/__init__.py
tilelang/carver/arch/__init__.py
+14
-14
tilelang/carver/arch/arch_base.py
tilelang/carver/arch/arch_base.py
+2
-6
tilelang/carver/arch/cdna.py
tilelang/carver/arch/cdna.py
+2
-3
tilelang/carver/arch/cpu.py
tilelang/carver/arch/cpu.py
+2
-3
tilelang/carver/arch/cuda.py
tilelang/carver/arch/cuda.py
+7
-10
tilelang/carver/arch/driver/cuda_driver.py
tilelang/carver/arch/driver/cuda_driver.py
+1
-2
No files found.
Too many changes to show.
To preserve performance only
426 of 426+
files are displayed.
Plain diff
Email patch
testing/python/transform/test_tilelang_transform_warp_specialized.py
View file @
29051439
...
@@ -32,7 +32,6 @@ block_K = 32
...
@@ -32,7 +32,6 @@ block_K = 32
def
test_warp_specialized
():
def
test_warp_specialized
():
@
T
.
prim_func
@
T
.
prim_func
def
before
(
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
)):
def
before
(
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
)):
bx
=
T
.
launch_thread
(
"blockIdx.x"
,
8
)
bx
=
T
.
launch_thread
(
"blockIdx.x"
,
8
)
...
@@ -47,25 +46,27 @@ def test_warp_specialized():
...
@@ -47,25 +46,27 @@ def test_warp_specialized():
for
k
in
T
.
serial
(
16
,
annotations
=
{
"num_stages"
:
T
.
int32
(
3
)}):
for
k
in
T
.
serial
(
16
,
annotations
=
{
"num_stages"
:
T
.
int32
(
3
)}):
if
v
==
0
:
if
v
==
0
:
T
.
tma_load
(
T
.
tma_load
(
T
.
create_tma_descriptor
(
6
,
2
,
A
.
data
,
512
,
512
,
2
,
1024
,
32
,
64
,
1
,
1
,
0
,
2
,
T
.
create_tma_descriptor
(
6
,
2
,
A
.
data
,
512
,
512
,
2
,
1024
,
32
,
64
,
1
,
1
,
0
,
2
,
2
,
0
),
2
,
0
),
0
,
0
,
T
.
tvm_access_ptr
(
T
.
tvm_access_ptr
(
T
.
type_annotation
(
"float16"
),
A_shared
.
data
,
k
%
3
*
2048
,
2048
,
2
),
T
.
type_annotation
(
"float16"
),
A_shared
.
data
,
k
%
3
*
2048
,
2048
,
2
),
k
*
32
,
k
*
32
,
by
*
64
)
by
*
64
,
)
if
v
==
0
:
if
v
==
0
:
T
.
tma_load
(
T
.
tma_load
(
T
.
create_tma_descriptor
(
6
,
2
,
B
.
data
,
512
,
512
,
2
,
1024
,
64
,
32
,
1
,
1
,
0
,
3
,
T
.
create_tma_descriptor
(
6
,
2
,
B
.
data
,
512
,
512
,
2
,
1024
,
64
,
32
,
1
,
1
,
0
,
3
,
2
,
0
),
2
,
0
),
0
,
0
,
T
.
tvm_access_ptr
(
T
.
tvm_access_ptr
(
T
.
type_annotation
(
"float16"
),
B_shared
.
data
,
k
%
3
*
2048
,
2048
,
2
),
T
.
type_annotation
(
"float16"
),
B_shared
.
data
,
k
%
3
*
2048
,
2048
,
2
),
bx
*
64
,
bx
*
64
,
k
*
32
)
k
*
32
,
)
T
.
call_extern
(
T
.
call_extern
(
"handle"
,
"tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>"
,
"handle"
,
T
.
tvm_access_ptr
(
"tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>"
,
T
.
type_annotation
(
"float16"
),
A_shared
.
data
,
k
%
3
*
2048
,
2048
,
1
),
T
.
tvm_access_ptr
(
T
.
type_annotation
(
"float16"
),
A_shared
.
data
,
k
%
3
*
2048
,
2048
,
1
),
T
.
tvm_access_ptr
(
T
.
tvm_access_ptr
(
T
.
type_annotation
(
"float16"
),
B_shared
.
data
,
k
%
3
*
2048
,
2048
,
1
),
T
.
type_annotation
(
"float
16
"
),
B_shared
.
data
,
k
%
3
*
2048
,
2048
,
1
),
T
.
tvm_access_ptr
(
T
.
type_annotation
(
"float
32
"
),
C_local
.
data
,
0
,
32
,
3
),
T
.
tvm_access_ptr
(
T
.
type_annotation
(
"float32"
),
C_local
.
data
,
0
,
32
,
3
)
)
)
@
T
.
prim_func
@
T
.
prim_func
def
after
(
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
)):
def
after
(
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
)):
...
@@ -85,34 +86,35 @@ def test_warp_specialized():
...
@@ -85,34 +86,35 @@ def test_warp_specialized():
T
.
mbarrier_expect_tx
(
T
.
get_mbarrier
(
k
%
3
),
4096
)
T
.
mbarrier_expect_tx
(
T
.
get_mbarrier
(
k
%
3
),
4096
)
if
v
-
128
==
0
:
if
v
-
128
==
0
:
T
.
tma_load
(
T
.
tma_load
(
T
.
create_tma_descriptor
(
6
,
2
,
A
.
data
,
512
,
512
,
2
,
1024
,
32
,
64
,
1
,
1
,
0
,
2
,
T
.
create_tma_descriptor
(
6
,
2
,
A
.
data
,
512
,
512
,
2
,
1024
,
32
,
64
,
1
,
1
,
0
,
2
,
2
,
0
),
2
,
0
),
T
.
get_mbarrier
(
k
%
3
),
T
.
get_mbarrier
(
k
%
3
),
T
.
tvm_access_ptr
(
T
.
tvm_access_ptr
(
T
.
type_annotation
(
"float16"
),
A_shared
.
data
,
k
%
3
*
2048
,
2048
,
2
),
T
.
type_annotation
(
"float16"
),
A_shared
.
data
,
k
%
3
*
2048
,
2048
,
2
),
k
*
32
,
k
*
32
,
by
*
64
)
by
*
64
,
)
if
v
-
128
==
0
:
if
v
-
128
==
0
:
T
.
mbarrier_expect_tx
(
T
.
get_mbarrier
(
k
%
3
),
4096
)
T
.
mbarrier_expect_tx
(
T
.
get_mbarrier
(
k
%
3
),
4096
)
if
v
-
128
==
0
:
if
v
-
128
==
0
:
T
.
tma_load
(
T
.
tma_load
(
T
.
create_tma_descriptor
(
6
,
2
,
B
.
data
,
512
,
512
,
2
,
1024
,
64
,
32
,
1
,
1
,
0
,
3
,
T
.
create_tma_descriptor
(
6
,
2
,
B
.
data
,
512
,
512
,
2
,
1024
,
64
,
32
,
1
,
1
,
0
,
3
,
2
,
0
),
2
,
0
),
T
.
get_mbarrier
(
k
%
3
),
T
.
get_mbarrier
(
k
%
3
),
T
.
tvm_access_ptr
(
T
.
tvm_access_ptr
(
T
.
type_annotation
(
"float16"
),
B_shared
.
data
,
k
%
3
*
2048
,
2048
,
2
),
T
.
type_annotation
(
"float16"
),
B_shared
.
data
,
k
%
3
*
2048
,
2048
,
2
),
bx
*
64
,
bx
*
64
,
k
*
32
)
k
*
32
,
)
T
.
evaluate
(
tir
.
Call
(
"handle"
,
"tir.ptx_arrive_barrier"
,
[
T
.
get_mbarrier
(
k
%
3
)]))
T
.
evaluate
(
tir
.
Call
(
"handle"
,
"tir.ptx_arrive_barrier"
,
[
T
.
get_mbarrier
(
k
%
3
)]))
else
:
else
:
T
.
set_max_nreg
(
240
,
1
)
T
.
set_max_nreg
(
240
,
1
)
for
k
in
range
(
16
):
for
k
in
range
(
16
):
T
.
mbarrier_wait_parity
(
T
.
get_mbarrier
(
k
%
3
),
k
//
3
%
2
)
T
.
mbarrier_wait_parity
(
T
.
get_mbarrier
(
k
%
3
),
k
//
3
%
2
)
T
.
call_extern
(
T
.
call_extern
(
"handle"
,
"tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>"
,
"handle"
,
T
.
tvm_access_ptr
(
"tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>"
,
T
.
type_annotation
(
"float16"
),
A_shared
.
data
,
k
%
3
*
2048
,
2048
,
1
),
T
.
tvm_access_ptr
(
T
.
type_annotation
(
"float16"
),
A_shared
.
data
,
k
%
3
*
2048
,
2048
,
1
),
T
.
tvm_access_ptr
(
T
.
tvm_access_ptr
(
T
.
type_annotation
(
"float16"
),
B_shared
.
data
,
k
%
3
*
2048
,
2048
,
1
),
T
.
type_annotation
(
"float16"
),
B_shared
.
data
,
k
%
3
*
2048
,
2048
,
1
),
T
.
tvm_access_ptr
(
T
.
type_annotation
(
"float32"
),
C_local
.
data
,
0
,
32
,
3
),
T
.
tvm_access_ptr
(
T
.
type_annotation
(
"float32"
),
C_local
.
data
,
0
,
32
,
3
))
)
T
.
evaluate
(
T
.
evaluate
(
tir
.
Call
(
"handle"
,
"tir.ptx_arrive_barrier"
,
[
T
.
get_mbarrier
(
k
%
3
+
3
)]))
tir
.
Call
(
"handle"
,
"tir.ptx_arrive_barrier"
,
[
T
.
get_mbarrier
(
k
%
3
+
3
)]))
_check
(
before
,
after
)
_check
(
before
,
after
)
...
...
testing/python/utils/test_compress_utils.py
View file @
29051439
...
@@ -6,7 +6,7 @@ from tilelang.utils.sparse import compress_sm90, randn_semi_sparse
...
@@ -6,7 +6,7 @@ from tilelang.utils.sparse import compress_sm90, randn_semi_sparse
def
_test_compress_sm90
(
M
,
K
,
block_k
,
dtype
):
def
_test_compress_sm90
(
M
,
K
,
block_k
,
dtype
):
A
=
randn_semi_sparse
(
M
,
K
,
dtype
=
dtype
,
device
=
'
cuda
'
)
A
=
randn_semi_sparse
(
M
,
K
,
dtype
=
dtype
,
device
=
"
cuda
"
)
A_sparse
,
E
=
compress_sm90
(
A
,
block_k
,
False
)
A_sparse
,
E
=
compress_sm90
(
A
,
block_k
,
False
)
...
...
testing/python/webgpu/test_webgpu_codegen.py
View file @
29051439
...
@@ -5,12 +5,11 @@ import tilelang.language as T
...
@@ -5,12 +5,11 @@ import tilelang.language as T
def
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
def
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
):
):
# Initialize Kernel Context
# Initialize Kernel Context
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
128
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
128
)
as
(
bx
,
by
):
...
...
tilelang/__init__.py
View file @
29051439
...
@@ -23,6 +23,7 @@ def _compute_version() -> str:
...
@@ -23,6 +23,7 @@ def _compute_version() -> str:
if
version_file
.
is_file
():
if
version_file
.
is_file
():
try
:
try
:
from
version_provider
import
dynamic_metadata
# type: ignore
from
version_provider
import
dynamic_metadata
# type: ignore
return
dynamic_metadata
(
"version"
)
return
dynamic_metadata
(
"version"
)
except
Exception
:
except
Exception
:
# Fall back to the raw VERSION file if provider isn't available.
# Fall back to the raw VERSION file if provider isn't available.
...
@@ -33,6 +34,7 @@ def _compute_version() -> str:
...
@@ -33,6 +34,7 @@ def _compute_version() -> str:
try
:
try
:
from
importlib.metadata
import
version
as
_dist_version
# py3.8+
from
importlib.metadata
import
version
as
_dist_version
# py3.8+
return
_dist_version
(
"tilelang"
)
return
_dist_version
(
"tilelang"
)
except
Exception
as
exc
:
except
Exception
as
exc
:
warnings
.
warn
(
warnings
.
warn
(
...
...
tilelang/analysis/fragment_loop_checker.py
View file @
29051439
from
__future__
import
annotations
from
__future__
import
annotations
from
tvm
import
tir
from
tvm
import
tir
from
tvm.tir
import
(
PyStmtExprVisitor
,
BufferStore
,
For
,
Var
,
PrimFunc
,
BufferLoad
,
IntImm
)
from
tvm.tir
import
PyStmtExprVisitor
,
BufferStore
,
For
,
Var
,
PrimFunc
,
BufferLoad
,
IntImm
from
tvm.tir.transform
import
prim_func_pass
from
tvm.tir.transform
import
prim_func_pass
from
tvm.tir.stmt_functor
import
post_order_visit
from
tvm.tir.stmt_functor
import
post_order_visit
...
@@ -22,14 +22,14 @@ class _LoopVarUseAnalyzer(PyStmtExprVisitor):
...
@@ -22,14 +22,14 @@ class _LoopVarUseAnalyzer(PyStmtExprVisitor):
def
collect_local_buffer_accesses
(
statement
)
->
list
[
BufferLoad
|
BufferStore
]:
def
collect_local_buffer_accesses
(
statement
)
->
list
[
BufferLoad
|
BufferStore
]:
"""
"""
Collect local buffer accesses in the loop body.
Collect local buffer accesses in the loop body.
Args:
Args:
statement: The TIR statement to analyze
statement: The TIR statement to analyze
Returns:
Returns:
Tuple of buffer accesses in the loop body.
Tuple of buffer accesses in the loop body.
"""
"""
buffer_accesses
=
[]
buffer_accesses
=
[]
...
@@ -44,7 +44,6 @@ def collect_local_buffer_accesses(statement) -> list[BufferLoad | BufferStore]:
...
@@ -44,7 +44,6 @@ def collect_local_buffer_accesses(statement) -> list[BufferLoad | BufferStore]:
@
tir
.
functor
.
visitor
@
tir
.
functor
.
visitor
class
_FragmentLoopCheckVisitor
(
PyStmtExprVisitor
):
class
_FragmentLoopCheckVisitor
(
PyStmtExprVisitor
):
def
__init__
(
self
)
->
None
:
def
__init__
(
self
)
->
None
:
super
().
__init__
()
super
().
__init__
()
...
@@ -75,7 +74,8 @@ class _FragmentLoopCheckVisitor(PyStmtExprVisitor):
...
@@ -75,7 +74,8 @@ class _FragmentLoopCheckVisitor(PyStmtExprVisitor):
raise
ValueError
(
raise
ValueError
(
"[Tilelang Semantic Check] "
"[Tilelang Semantic Check] "
f
"Loop variable
{
loop
.
loop_var
}
in a T.Parallel loop with symbolic range (min=
{
loop
.
min
}
, extent=
{
loop
.
extent
}
) is used to index "
f
"Loop variable
{
loop
.
loop_var
}
in a T.Parallel loop with symbolic range (min=
{
loop
.
min
}
, extent=
{
loop
.
extent
}
) is used to index "
"a local/fragment buffer, which is not allowed in Tilelang."
)
"a local/fragment buffer, which is not allowed in Tilelang."
)
return
return
...
...
tilelang/analysis/layout_visual.py
View file @
29051439
...
@@ -23,10 +23,7 @@ def print_fragment_format(layout: T.Fragment) -> str:
...
@@ -23,10 +23,7 @@ def print_fragment_format(layout: T.Fragment) -> str:
if
isinstance
(
layout
,
T
.
Fragment
):
if
isinstance
(
layout
,
T
.
Fragment
):
input_shape
=
layout
.
get_input_shape
()
input_shape
=
layout
.
get_input_shape
()
output_shape
=
layout
.
get_output_shape
()
output_shape
=
layout
.
get_output_shape
()
lines
=
[
lines
=
[
f
" Shape:
{
input_shape
}
->
{
output_shape
}
"
,
f
" Thread:
{
layout
.
forward_thread
}
"
,
f
" Index:
{
layout
.
forward_index
}
"
]
f
" Shape:
{
input_shape
}
->
{
output_shape
}
"
,
f
" Thread:
{
layout
.
forward_thread
}
"
,
f
" Index:
{
layout
.
forward_index
}
"
]
print
(
"
\n
"
.
join
(
lines
))
print
(
"
\n
"
.
join
(
lines
))
else
:
else
:
raise
ValueError
(
f
"Expected T.Fragment, but got
{
type
(
layout
).
__name__
}
"
)
raise
ValueError
(
f
"Expected T.Fragment, but got
{
type
(
layout
).
__name__
}
"
)
...
@@ -82,7 +79,6 @@ class _LayoutVisualVisitor(PyStmtExprVisitor):
...
@@ -82,7 +79,6 @@ class _LayoutVisualVisitor(PyStmtExprVisitor):
def
LayoutVisual
(
formats
:
str
=
""
):
def
LayoutVisual
(
formats
:
str
=
""
):
def
pass_fn
(
func
:
tir
.
PrimFunc
,
mod
,
ctx
):
def
pass_fn
(
func
:
tir
.
PrimFunc
,
mod
,
ctx
):
_LayoutVisualVisitor
(
formats
=
formats
).
visit_stmt
(
func
.
body
)
_LayoutVisualVisitor
(
formats
=
formats
).
visit_stmt
(
func
.
body
)
return
func
return
func
...
...
tilelang/analysis/nested_loop_checker.py
View file @
29051439
...
@@ -11,10 +11,7 @@ from tvm.tir.transform import prim_func_pass
...
@@ -11,10 +11,7 @@ from tvm.tir.transform import prim_func_pass
def
is_pipelined_for
(
op
:
For
)
->
bool
:
def
is_pipelined_for
(
op
:
For
)
->
bool
:
"""Check if a for loop is pipelined."""
"""Check if a for loop is pipelined."""
anno_keys
=
[
anno_keys
=
[
"num_stages"
,
"tl_pipeline_order"
,
"tl_pipeline_stage"
,
"tl_pipeline_sync"
,
"tl_pipeline_group"
]
"num_stages"
,
"tl_pipeline_order"
,
"tl_pipeline_stage"
,
"tl_pipeline_sync"
,
"tl_pipeline_group"
]
return
any
(
key
in
op
.
annotations
for
key
in
anno_keys
)
return
any
(
key
in
op
.
annotations
for
key
in
anno_keys
)
...
@@ -26,7 +23,6 @@ def is_tile_op(op: Call) -> bool:
...
@@ -26,7 +23,6 @@ def is_tile_op(op: Call) -> bool:
@
tir
.
functor
.
visitor
@
tir
.
functor
.
visitor
class
_NestedLoopCheckVisitor
(
PyStmtExprVisitor
):
class
_NestedLoopCheckVisitor
(
PyStmtExprVisitor
):
def
__init__
(
self
)
->
None
:
def
__init__
(
self
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
in_parallel_context
=
False
self
.
in_parallel_context
=
False
...
@@ -42,27 +38,24 @@ class _NestedLoopCheckVisitor(PyStmtExprVisitor):
...
@@ -42,27 +38,24 @@ class _NestedLoopCheckVisitor(PyStmtExprVisitor):
# Otherwise
# Otherwise
if
self
.
in_parallel_context
:
if
self
.
in_parallel_context
:
raise
ValueError
(
"[Tilelang Semantic Check] "
raise
ValueError
(
"[Tilelang Semantic Check] Nested parallel loops are not allowed. Please check your loop structure."
)
"Nested parallel loops are not allowed. "
"Please check your loop structure."
)
self
.
in_parallel_context
=
True
self
.
in_parallel_context
=
True
super
().
visit_for_
(
op
)
super
().
visit_for_
(
op
)
self
.
in_parallel_context
=
False
self
.
in_parallel_context
=
False
return
return
elif
is_pipelined_for
(
op
):
elif
is_pipelined_for
(
op
):
if
self
.
in_parallel_context
:
if
self
.
in_parallel_context
:
raise
ValueError
(
"[Tilelang Semantic Check] "
raise
ValueError
(
"
Pipelined loop cannot be nested inside a parallel loop. "
"[Tilelang Semantic Check]
Pipelined loop cannot be nested inside a parallel loop.
Please check your loop structure.
"
"Please check your loop structure."
)
)
super
().
visit_for_
(
op
)
super
().
visit_for_
(
op
)
def
visit_call_
(
self
,
op
:
Call
)
->
None
:
def
visit_call_
(
self
,
op
:
Call
)
->
None
:
if
self
.
in_parallel_context
and
is_tile_op
(
op
):
if
self
.
in_parallel_context
and
is_tile_op
(
op
):
raise
ValueError
(
"[Tilelang Semantic Check] "
raise
ValueError
(
"Only elementwise operations are allowed inside a parallel loop. "
\
f
'[Tilelang Semantic Check] Only elementwise operations are allowed inside a parallel loop. Got a tile-op "
{
op
.
op
}
".'
f
"Got a tile-op
\"
{
op
.
op
}
\"
."
)
)
def
NestedLoopChecker
():
def
NestedLoopChecker
():
...
...
tilelang/autotuner/capture.py
View file @
29051439
...
@@ -85,8 +85,7 @@ def _get_current_stack() -> CaptureStack:
...
@@ -85,8 +85,7 @@ def _get_current_stack() -> CaptureStack:
class
AutotuneInputsCapture
:
class
AutotuneInputsCapture
:
__slots__
=
"tensors"
__slots__
=
(
"tensors"
)
def
__init__
(
self
,
tensors
:
list
[
Any
]):
def
__init__
(
self
,
tensors
:
list
[
Any
]):
self
.
tensors
=
tensors
self
.
tensors
=
tensors
...
...
tilelang/autotuner/param.py
View file @
29051439
"""The auto-tune parameters.
"""The auto-tune parameters.
"""
"""
from
__future__
import
annotations
from
__future__
import
annotations
import
tilelang
import
tilelang
...
@@ -50,7 +50,7 @@ class CompileArgs:
...
@@ -50,7 +50,7 @@ class CompileArgs:
out_idx
:
list
[
int
]
|
int
|
None
=
None
out_idx
:
list
[
int
]
|
int
|
None
=
None
execution_backend
:
Literal
[
"auto"
,
"tvm_ffi"
,
"ctypes"
,
"cython"
,
"nvrtc"
,
"torch"
]
=
"auto"
execution_backend
:
Literal
[
"auto"
,
"tvm_ffi"
,
"ctypes"
,
"cython"
,
"nvrtc"
,
"torch"
]
=
"auto"
target
:
Literal
[
'
auto
'
,
'
cuda
'
,
'
hip
'
]
=
'
auto
'
target
:
Literal
[
"
auto
"
,
"
cuda
"
,
"
hip
"
]
=
"
auto
"
target_host
:
str
|
Target
=
None
target_host
:
str
|
Target
=
None
verbose
:
bool
=
False
verbose
:
bool
=
False
pass_configs
:
dict
[
str
,
Any
]
|
None
=
None
pass_configs
:
dict
[
str
,
Any
]
|
None
=
None
...
@@ -62,24 +62,20 @@ class CompileArgs:
...
@@ -62,24 +62,20 @@ class CompileArgs:
target
=
self
.
target
,
target
=
self
.
target
,
target_host
=
self
.
target_host
,
target_host
=
self
.
target_host
,
verbose
=
self
.
verbose
,
verbose
=
self
.
verbose
,
pass_configs
=
self
.
pass_configs
)
pass_configs
=
self
.
pass_configs
,
)
def
__hash__
(
self
):
def
__hash__
(
self
):
data
=
{
data
=
{
"execution_backend"
:
"execution_backend"
:
self
.
execution_backend
,
self
.
execution_backend
,
"target"
:
str
(
self
.
target
),
"target"
:
"target_host"
:
str
(
self
.
target_host
)
if
self
.
target_host
else
None
,
str
(
self
.
target
),
"verbose"
:
self
.
verbose
,
"target_host"
:
"pass_configs"
:
json
.
dumps
(
self
.
pass_configs
,
sort_keys
=
True
)
if
self
.
pass_configs
else
None
,
str
(
self
.
target_host
)
if
self
.
target_host
else
None
,
"verbose"
:
self
.
verbose
,
"pass_configs"
:
json
.
dumps
(
self
.
pass_configs
,
sort_keys
=
True
)
if
self
.
pass_configs
else
None
,
}
}
hash_obj
=
hashlib
.
sha256
(
json
.
dumps
(
data
,
sort_keys
=
True
).
encode
(
'
utf-8
'
))
hash_obj
=
hashlib
.
sha256
(
json
.
dumps
(
data
,
sort_keys
=
True
).
encode
(
"
utf-8
"
))
return
int
.
from_bytes
(
hash_obj
.
digest
(),
byteorder
=
'
big
'
)
return
int
.
from_bytes
(
hash_obj
.
digest
(),
byteorder
=
"
big
"
)
@
dataclass
(
frozen
=
True
)
@
dataclass
(
frozen
=
True
)
...
@@ -104,6 +100,7 @@ class ProfileArgs:
...
@@ -104,6 +100,7 @@ class ProfileArgs:
manual_check_prog: Callable = None
manual_check_prog: Callable = None
cache_input_tensors: bool = True
cache_input_tensors: bool = True
"""
"""
warmup
:
int
=
25
warmup
:
int
=
25
rep
:
int
=
100
rep
:
int
=
100
timeout
:
int
=
30
timeout
:
int
=
30
...
@@ -127,8 +124,8 @@ class ProfileArgs:
...
@@ -127,8 +124,8 @@ class ProfileArgs:
"atol"
:
self
.
atol
,
"atol"
:
self
.
atol
,
"max_mismatched_ratio"
:
self
.
max_mismatched_ratio
,
"max_mismatched_ratio"
:
self
.
max_mismatched_ratio
,
}
}
hash_obj
=
hashlib
.
sha256
(
json
.
dumps
(
data
,
sort_keys
=
True
).
encode
(
'
utf-8
'
))
hash_obj
=
hashlib
.
sha256
(
json
.
dumps
(
data
,
sort_keys
=
True
).
encode
(
"
utf-8
"
))
return
int
.
from_bytes
(
hash_obj
.
digest
(),
byteorder
=
'
big
'
)
return
int
.
from_bytes
(
hash_obj
.
digest
(),
byteorder
=
"
big
"
)
@
dataclass
(
frozen
=
True
)
@
dataclass
(
frozen
=
True
)
...
@@ -143,6 +140,7 @@ class AutotuneResult:
...
@@ -143,6 +140,7 @@ class AutotuneResult:
func: Optimized function.
func: Optimized function.
kernel: Compiled kernel function.
kernel: Compiled kernel function.
"""
"""
latency
:
float
|
None
=
None
latency
:
float
|
None
=
None
config
:
dict
|
None
=
None
config
:
dict
|
None
=
None
ref_latency
:
float
|
None
=
None
ref_latency
:
float
|
None
=
None
...
@@ -199,8 +197,7 @@ class AutotuneResult:
...
@@ -199,8 +197,7 @@ class AutotuneResult:
if
verbose
:
if
verbose
:
logger
.
debug
(
f
"Saving kernel source code to file:
{
device_kernel_path
}
"
)
logger
.
debug
(
f
"Saving kernel source code to file:
{
device_kernel_path
}
"
)
if
kernel
.
kernel_source
is
not
None
:
if
kernel
.
kernel_source
is
not
None
:
self
.
_safe_write_file
(
device_kernel_path
,
"w"
,
self
.
_safe_write_file
(
device_kernel_path
,
"w"
,
lambda
f
:
f
.
write
(
kernel
.
kernel_source
))
lambda
f
:
f
.
write
(
kernel
.
kernel_source
))
except
Exception
as
e
:
except
Exception
as
e
:
logger
.
error
(
f
"Error saving kernel source code to disk:
{
e
}
"
)
logger
.
error
(
f
"Error saving kernel source code to disk:
{
e
}
"
)
...
@@ -211,11 +208,9 @@ class AutotuneResult:
...
@@ -211,11 +208,9 @@ class AutotuneResult:
logger
.
debug
(
f
"Saving wrapped kernel source code to file:
{
host_kernel_path
}
"
)
logger
.
debug
(
f
"Saving wrapped kernel source code to file:
{
host_kernel_path
}
"
)
# Match kernel_cache behavior: use host source for tvm_ffi, otherwise wrapped kernel
# Match kernel_cache behavior: use host source for tvm_ffi, otherwise wrapped kernel
if
kernel
.
execution_backend
==
"tvm_ffi"
:
if
kernel
.
execution_backend
==
"tvm_ffi"
:
self
.
_safe_write_file
(
host_kernel_path
,
"w"
,
self
.
_safe_write_file
(
host_kernel_path
,
"w"
,
lambda
f
:
f
.
write
(
kernel
.
adapter
.
get_host_source
()))
lambda
f
:
f
.
write
(
kernel
.
adapter
.
get_host_source
()))
else
:
else
:
self
.
_safe_write_file
(
host_kernel_path
,
"w"
,
self
.
_safe_write_file
(
host_kernel_path
,
"w"
,
lambda
f
:
f
.
write
(
kernel
.
adapter
.
get_kernel_source
()))
lambda
f
:
f
.
write
(
kernel
.
adapter
.
get_kernel_source
()))
except
Exception
as
e
:
except
Exception
as
e
:
logger
.
error
(
f
"Error saving wrapped kernel source code to disk:
{
e
}
"
)
logger
.
error
(
f
"Error saving wrapped kernel source code to disk:
{
e
}
"
)
...
@@ -237,12 +232,10 @@ class AutotuneResult:
...
@@ -237,12 +232,10 @@ class AutotuneResult:
py_src_path
=
src_lib_path
.
replace
(
".cubin"
,
".py"
)
py_src_path
=
src_lib_path
.
replace
(
".cubin"
,
".py"
)
if
verbose
:
if
verbose
:
logger
.
debug
(
f
"Saving kernel nvrtc python code to file:
{
kernel_py_path
}
"
)
logger
.
debug
(
f
"Saving kernel nvrtc python code to file:
{
kernel_py_path
}
"
)
self
.
_safe_write_file
(
kernel_py_path
,
"wb"
,
self
.
_safe_write_file
(
kernel_py_path
,
"wb"
,
lambda
f
:
f
.
write
(
self
.
_load_binary
(
py_src_path
)))
lambda
f
:
f
.
write
(
self
.
_load_binary
(
py_src_path
)))
if
verbose
:
if
verbose
:
logger
.
debug
(
f
"Saving kernel library to file:
{
kernel_lib_path
}
"
)
logger
.
debug
(
f
"Saving kernel library to file:
{
kernel_lib_path
}
"
)
self
.
_safe_write_file
(
kernel_lib_path
,
"wb"
,
self
.
_safe_write_file
(
kernel_lib_path
,
"wb"
,
lambda
f
:
f
.
write
(
self
.
_load_binary
(
src_lib_path
)))
lambda
f
:
f
.
write
(
self
.
_load_binary
(
src_lib_path
)))
elif
kernel
.
execution_backend
==
"tvm_ffi"
:
elif
kernel
.
execution_backend
==
"tvm_ffi"
:
executable
=
kernel
.
adapter
.
executable
executable
=
kernel
.
adapter
.
executable
if
verbose
:
if
verbose
:
...
@@ -252,8 +245,7 @@ class AutotuneResult:
...
@@ -252,8 +245,7 @@ class AutotuneResult:
src_lib_path
=
kernel
.
adapter
.
libpath
src_lib_path
=
kernel
.
adapter
.
libpath
if
verbose
:
if
verbose
:
logger
.
debug
(
f
"Saving kernel library to file:
{
kernel_lib_path
}
"
)
logger
.
debug
(
f
"Saving kernel library to file:
{
kernel_lib_path
}
"
)
self
.
_safe_write_file
(
kernel_lib_path
,
"wb"
,
self
.
_safe_write_file
(
kernel_lib_path
,
"wb"
,
lambda
f
:
f
.
write
(
self
.
_load_binary
(
src_lib_path
)))
lambda
f
:
f
.
write
(
self
.
_load_binary
(
src_lib_path
)))
except
Exception
as
e
:
except
Exception
as
e
:
logger
.
error
(
f
"Error saving kernel library to disk:
{
e
}
"
)
logger
.
error
(
f
"Error saving kernel library to disk:
{
e
}
"
)
...
@@ -370,14 +362,12 @@ class AutotuneResult:
...
@@ -370,14 +362,12 @@ class AutotuneResult:
# save best config (atomic)
# save best config (atomic)
if
verbose
:
if
verbose
:
logger
.
debug
(
f
"Saving best config to file:
{
path
/
BEST_CONFIG_PATH
}
"
)
logger
.
debug
(
f
"Saving best config to file:
{
path
/
BEST_CONFIG_PATH
}
"
)
self
.
_safe_write_file
(
self
.
_safe_write_file
(
str
(
path
/
BEST_CONFIG_PATH
),
"w"
,
lambda
f
:
json
.
dump
(
self
.
config
,
f
))
str
(
path
/
BEST_CONFIG_PATH
),
"w"
,
lambda
f
:
json
.
dump
(
self
.
config
,
f
))
# save function (atomic)
# save function (atomic)
if
verbose
:
if
verbose
:
logger
.
debug
(
f
"Saving function to file:
{
path
/
FUNCTION_PATH
}
"
)
logger
.
debug
(
f
"Saving function to file:
{
path
/
FUNCTION_PATH
}
"
)
self
.
_safe_write_file
(
self
.
_safe_write_file
(
str
(
path
/
FUNCTION_PATH
),
"wb"
,
lambda
f
:
cloudpickle
.
dump
(
self
.
func
,
f
))
str
(
path
/
FUNCTION_PATH
),
"wb"
,
lambda
f
:
cloudpickle
.
dump
(
self
.
func
,
f
))
# save ref latency (atomic)
# save ref latency (atomic)
if
verbose
:
if
verbose
:
...
@@ -385,10 +375,13 @@ class AutotuneResult:
...
@@ -385,10 +375,13 @@ class AutotuneResult:
self
.
_safe_write_file
(
self
.
_safe_write_file
(
str
(
path
/
LATENCY_PATH
),
str
(
path
/
LATENCY_PATH
),
"w"
,
"w"
,
lambda
f
:
json
.
dump
({
lambda
f
:
json
.
dump
(
"latency"
:
self
.
latency
,
{
"ref_latency"
:
self
.
ref_latency
,
"latency"
:
self
.
latency
,
},
f
),
"ref_latency"
:
self
.
ref_latency
,
},
f
,
),
)
)
# save kernel
# save kernel
...
@@ -403,8 +396,8 @@ class AutotuneResult:
...
@@ -403,8 +396,8 @@ class AutotuneResult:
# Normalize target and resolve execution backend for loading
# Normalize target and resolve execution backend for loading
from
tilelang.utils.target
import
determine_target
as
_determine_target
from
tilelang.utils.target
import
determine_target
as
_determine_target
from
tilelang.jit.execution_backend
import
resolve_execution_backend
from
tilelang.jit.execution_backend
import
resolve_execution_backend
norm_target
=
Target
(
_determine_target
(
compile_args
.
target
))
if
isinstance
(
compile_args
.
target
,
str
)
else
compile_args
.
target
norm_target
=
Target
(
_determine_target
(
compile_args
.
target
))
if
isinstance
(
compile_args
.
target
,
str
)
else
compile_args
.
target
requested_backend
=
compile_args
.
execution_backend
requested_backend
=
compile_args
.
execution_backend
resolved_backend
=
resolve_execution_backend
(
requested_backend
,
norm_target
)
resolved_backend
=
resolve_execution_backend
(
requested_backend
,
norm_target
)
# load best config
# load best config
...
...
tilelang/autotuner/tuner.py
View file @
29051439
...
@@ -3,6 +3,7 @@
...
@@ -3,6 +3,7 @@
This module provides functionality for auto-tuning tilelang programs, including JIT compilation
This module provides functionality for auto-tuning tilelang programs, including JIT compilation
and performance optimization through configuration search.
and performance optimization through configuration search.
"""
"""
from
__future__
import
annotations
from
__future__
import
annotations
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
...
@@ -14,7 +15,8 @@ from tvm.tir import PrimFunc, Var
...
@@ -14,7 +15,8 @@ from tvm.tir import PrimFunc, Var
from
tvm.target
import
Target
from
tvm.target
import
Target
import
inspect
import
inspect
from
functools
import
partial
from
functools
import
partial
from
typing
import
(
Callable
,
Generic
,
Literal
,
Any
,
TypeVar
)
from
typing
import
Callable
,
Generic
,
Literal
,
Any
,
TypeVar
# Python 3.9 compatibility for ParamSpec
# Python 3.9 compatibility for ParamSpec
try
:
try
:
from
typing
import
ParamSpec
from
typing
import
ParamSpec
...
@@ -74,8 +76,8 @@ def _init_logger_handlers():
...
@@ -74,8 +76,8 @@ def _init_logger_handlers():
global
_logger_handlers_initialized
global
_logger_handlers_initialized
if
_logger_handlers_initialized
:
if
_logger_handlers_initialized
:
return
return
formatter
=
logging
.
Formatter
(
'
%(asctime)s %(levelname)s:%(message)s
'
)
formatter
=
logging
.
Formatter
(
"
%(asctime)s %(levelname)s:%(message)s
"
)
file_handler
=
logging
.
FileHandler
(
'
autotuner.log
'
,
mode
=
'w'
)
file_handler
=
logging
.
FileHandler
(
"
autotuner.log
"
,
mode
=
"w"
)
file_handler
.
setLevel
(
logging
.
DEBUG
)
file_handler
.
setLevel
(
logging
.
DEBUG
)
file_handler
.
setFormatter
(
formatter
)
file_handler
.
setFormatter
(
formatter
)
console_handler
=
logging
.
StreamHandler
(
sys
.
stdout
)
console_handler
=
logging
.
StreamHandler
(
sys
.
stdout
)
...
@@ -87,8 +89,7 @@ def _init_logger_handlers():
...
@@ -87,8 +89,7 @@ def _init_logger_handlers():
def
get_available_cpu_count
()
->
int
:
def
get_available_cpu_count
()
->
int
:
"""Gets the number of CPU cores available to the current process.
"""Gets the number of CPU cores available to the current process."""
"""
try
:
try
:
cpu_count
=
len
(
os
.
sched_getaffinity
(
0
))
cpu_count
=
len
(
os
.
sched_getaffinity
(
0
))
except
AttributeError
:
except
AttributeError
:
...
@@ -107,6 +108,7 @@ class AutoTuner:
...
@@ -107,6 +108,7 @@ class AutoTuner:
fn: The function to be auto-tuned.
fn: The function to be auto-tuned.
configs: List of configurations to try during auto-tuning.
configs: List of configurations to try during auto-tuning.
"""
"""
compile_args
=
CompileArgs
()
compile_args
=
CompileArgs
()
profile_args
=
ProfileArgs
()
profile_args
=
ProfileArgs
()
...
@@ -137,14 +139,15 @@ class AutoTuner:
...
@@ -137,14 +139,15 @@ class AutoTuner:
"""
"""
return
cls
(
kernel
,
configs
)
return
cls
(
kernel
,
configs
)
def
set_compile_args
(
self
,
def
set_compile_args
(
out_idx
:
list
[
int
]
|
int
|
None
=
None
,
self
,
target
:
Literal
[
'auto'
,
'cuda'
,
'hip'
,
'metal'
]
=
'auto'
,
out_idx
:
list
[
int
]
|
int
|
None
=
None
,
execution_backend
:
Literal
[
"auto"
,
"tvm_ffi"
,
"ctypes"
,
"cython"
,
"nvrtc"
,
target
:
Literal
[
"auto"
,
"cuda"
,
"hip"
,
"metal"
]
=
"auto"
,
"torch"
]
=
"auto"
,
execution_backend
:
Literal
[
"auto"
,
"tvm_ffi"
,
"ctypes"
,
"cython"
,
"nvrtc"
,
"torch"
]
=
"auto"
,
target_host
:
str
|
Target
=
None
,
target_host
:
str
|
Target
=
None
,
verbose
:
bool
=
False
,
verbose
:
bool
=
False
,
pass_configs
:
dict
[
str
,
Any
]
|
None
=
None
):
pass_configs
:
dict
[
str
,
Any
]
|
None
=
None
,
):
"""Set compilation arguments for the auto-tuner.
"""Set compilation arguments for the auto-tuner.
Args:
Args:
...
@@ -161,6 +164,7 @@ class AutoTuner:
...
@@ -161,6 +164,7 @@ class AutoTuner:
# Normalize target to a concrete TVM Target and resolve execution backend
# Normalize target to a concrete TVM Target and resolve execution backend
t
=
Target
(
determine_target
(
target
))
t
=
Target
(
determine_target
(
target
))
from
tilelang.jit.execution_backend
import
resolve_execution_backend
from
tilelang.jit.execution_backend
import
resolve_execution_backend
resolved_backend
=
resolve_execution_backend
(
execution_backend
,
t
)
resolved_backend
=
resolve_execution_backend
(
execution_backend
,
t
)
self
.
compile_args
=
CompileArgs
(
self
.
compile_args
=
CompileArgs
(
...
@@ -169,23 +173,26 @@ class AutoTuner:
...
@@ -169,23 +173,26 @@ class AutoTuner:
execution_backend
=
resolved_backend
,
execution_backend
=
resolved_backend
,
target_host
=
target_host
,
target_host
=
target_host
,
verbose
=
verbose
,
verbose
=
verbose
,
pass_configs
=
pass_configs
)
pass_configs
=
pass_configs
,
)
return
self
return
self
def
set_profile_args
(
self
,
def
set_profile_args
(
warmup
:
int
=
25
,
self
,
rep
:
int
=
100
,
warmup
:
int
=
25
,
timeout
:
int
=
30
,
rep
:
int
=
100
,
supply_type
:
tilelang
.
TensorSupplyType
=
tilelang
.
TensorSupplyType
.
Auto
,
timeout
:
int
=
30
,
ref_prog
:
Callable
=
None
,
supply_type
:
tilelang
.
TensorSupplyType
=
tilelang
.
TensorSupplyType
.
Auto
,
supply_prog
:
Callable
=
None
,
ref_prog
:
Callable
=
None
,
rtol
:
float
=
1e-2
,
supply_prog
:
Callable
=
None
,
atol
:
float
=
1e-2
,
rtol
:
float
=
1e-2
,
max_mismatched_ratio
:
float
=
0.01
,
atol
:
float
=
1e-2
,
skip_check
:
bool
=
False
,
max_mismatched_ratio
:
float
=
0.01
,
manual_check_prog
:
Callable
=
None
,
skip_check
:
bool
=
False
,
cache_input_tensors
:
bool
=
False
):
manual_check_prog
:
Callable
=
None
,
cache_input_tensors
:
bool
=
False
,
):
"""Set profiling arguments for the auto-tuner.
"""Set profiling arguments for the auto-tuner.
Args:
Args:
...
@@ -209,9 +216,7 @@ class AutoTuner:
...
@@ -209,9 +216,7 @@ class AutoTuner:
# the `supply_prog` will be ignored and the `get_autotune_inputs` will be used instead.
# the `supply_prog` will be ignored and the `get_autotune_inputs` will be used instead.
if
get_autotune_inputs
()
is
not
None
:
if
get_autotune_inputs
()
is
not
None
:
if
supply_prog
is
not
None
:
if
supply_prog
is
not
None
:
logger
.
warning
(
logger
.
warning
(
"`supply_prog` will be ignored as this program is under `with set_autotune_inputs` context."
)
"`supply_prog` will be ignored as this program is under `with set_autotune_inputs` context."
)
supply_prog
=
lambda
_
:
get_autotune_inputs
()
# noqa: E731
supply_prog
=
lambda
_
:
get_autotune_inputs
()
# noqa: E731
self
.
profile_args
=
ProfileArgs
(
self
.
profile_args
=
ProfileArgs
(
...
@@ -226,13 +231,13 @@ class AutoTuner:
...
@@ -226,13 +231,13 @@ class AutoTuner:
cache_input_tensors
=
cache_input_tensors
,
cache_input_tensors
=
cache_input_tensors
,
warmup
=
warmup
,
warmup
=
warmup
,
rep
=
rep
,
rep
=
rep
,
timeout
=
timeout
)
timeout
=
timeout
,
)
# If a custom `supply_prog` is provided, the profiler's `supply_type` setting
# If a custom `supply_prog` is provided, the profiler's `supply_type` setting
# becomes ineffective. The custom supply program will be used instead.
# becomes ineffective. The custom supply program will be used instead.
if
supply_prog
is
not
None
and
supply_type
!=
tilelang
.
TensorSupplyType
.
Auto
:
if
supply_prog
is
not
None
and
supply_type
!=
tilelang
.
TensorSupplyType
.
Auto
:
logger
.
warning
(
"Ignoring `supply_type` passed to `set_profile_args` because "
logger
.
warning
(
"Ignoring `supply_type` passed to `set_profile_args` because `supply_prog` is not None."
)
"`supply_prog` is not None."
)
return
self
return
self
...
@@ -241,10 +246,8 @@ class AutoTuner:
...
@@ -241,10 +246,8 @@ class AutoTuner:
self
.
_kernel_parameters
=
k_parameters
self
.
_kernel_parameters
=
k_parameters
self
.
_function_parameters
=
f_parameters
self
.
_function_parameters
=
f_parameters
def
generate_cache_key
(
self
,
parameters
:
dict
[
str
,
Any
],
def
generate_cache_key
(
self
,
parameters
:
dict
[
str
,
Any
],
extra_parameters
:
dict
[
str
,
Any
])
->
AutotuneResult
|
None
:
extra_parameters
:
dict
[
str
,
Any
])
->
AutotuneResult
|
None
:
"""Generate a cache key for the auto-tuning process."""
"""Generate a cache key for the auto-tuning process.
"""
def
_normalize_param
(
value
):
def
_normalize_param
(
value
):
if
isinstance
(
value
,
Var
):
if
isinstance
(
value
,
Var
):
...
@@ -315,8 +318,9 @@ class AutoTuner:
...
@@ -315,8 +318,9 @@ class AutoTuner:
if
var_name
in
parameters
:
if
var_name
in
parameters
:
continue
continue
# Cell content must be serializable
# Cell content must be serializable
assert
isinstance
(
cell
.
cell_contents
,
(
int
,
float
,
str
,
bool
,
type
(
None
))),
\
assert
isinstance
(
cell
.
cell_contents
,
(
int
,
float
,
str
,
bool
,
type
(
None
))),
(
f
"Cell contents
{
cell
.
cell_contents
}
is not serializable:
{
type
(
cell
.
cell_contents
)
}
"
f
"Cell contents
{
cell
.
cell_contents
}
is not serializable:
{
type
(
cell
.
cell_contents
)
}
"
)
extra_parameters
[
var_name
]
=
cell
.
cell_contents
extra_parameters
[
var_name
]
=
cell
.
cell_contents
if
isinstance
(
self
.
configs
,
Callable
):
if
isinstance
(
self
.
configs
,
Callable
):
...
@@ -328,8 +332,10 @@ class AutoTuner:
...
@@ -328,8 +332,10 @@ class AutoTuner:
if
env
.
is_cache_enabled
()
and
not
env
.
is_autotune_cache_disabled
():
if
env
.
is_cache_enabled
()
and
not
env
.
is_autotune_cache_disabled
():
# First check in-memory cache
# First check in-memory cache
if
key
in
self
.
_memory_cache
:
if
key
in
self
.
_memory_cache
:
logger
.
warning
(
"Found kernel in memory cache. For better performance,"
\
logger
.
warning
(
" consider using `@tilelang.autotune` instead of direct AutoTuner.from_kernel."
)
"Found kernel in memory cache. For better performance,"
" consider using `@tilelang.autotune` instead of direct AutoTuner.from_kernel."
)
return
self
.
_memory_cache
[
key
]
return
self
.
_memory_cache
[
key
]
# Then check disk cache
# Then check disk cache
...
@@ -369,7 +375,6 @@ class AutoTuner:
...
@@ -369,7 +375,6 @@ class AutoTuner:
# This encapsulates the logic of using either a custom supply program (`supply_prog`)
# This encapsulates the logic of using either a custom supply program (`supply_prog`)
# or the default profiler input generation (`profiler._get_inputs`).
# or the default profiler input generation (`profiler._get_inputs`).
def
get_input_tensors_supply
(
with_output
:
bool
):
def
get_input_tensors_supply
(
with_output
:
bool
):
def
func
():
def
func
():
if
supply_prog
is
not
None
:
if
supply_prog
is
not
None
:
return
supply_prog
(
profiler
.
_get_params
(
with_output
=
with_output
))
return
supply_prog
(
profiler
.
_get_params
(
with_output
=
with_output
))
...
@@ -387,8 +392,7 @@ class AutoTuner:
...
@@ -387,8 +392,7 @@ class AutoTuner:
self
.
jit_input_tensors
=
jit_input_tensors_supply
()
self
.
jit_input_tensors
=
jit_input_tensors_supply
()
else
:
else
:
# check if the cached tensors are compatible with the current configuration
# check if the cached tensors are compatible with the current configuration
assert
len
(
params
)
==
len
(
assert
len
(
params
)
==
len
(
self
.
jit_input_tensors
),
"len(params) != len(self.jit_input_tensors)"
self
.
jit_input_tensors
),
"len(params) != len(self.jit_input_tensors)"
for
p
,
c
in
zip
(
params
,
self
.
jit_input_tensors
):
for
p
,
c
in
zip
(
params
,
self
.
jit_input_tensors
):
if
not
isinstance
(
c
,
torch
.
Tensor
):
if
not
isinstance
(
c
,
torch
.
Tensor
):
# skip non-tensor inputs checking
# skip non-tensor inputs checking
...
@@ -397,8 +401,8 @@ class AutoTuner:
...
@@ -397,8 +401,8 @@ class AutoTuner:
# Check tensor compatibility using generator expression
# Check tensor compatibility using generator expression
def
shape_equal
(
a
,
b
):
def
shape_equal
(
a
,
b
):
return
all
(
return
all
(
a_dim
==
b_dim
or
isinstance
(
a_dim
,
Var
)
or
isinstance
(
b_dim
,
Var
)
a_dim
==
b_dim
or
isinstance
(
a_dim
,
Var
)
or
isinstance
(
b_dim
,
Var
)
for
a_dim
,
b_dim
in
zip
(
a
.
shape
,
b
.
shape
)
for
a_dim
,
b_dim
in
zip
(
a
.
shape
,
b
.
shape
)
)
)
if
p
.
dtype
!=
c
.
dtype
or
not
shape_equal
(
p
,
c
):
if
p
.
dtype
!=
c
.
dtype
or
not
shape_equal
(
p
,
c
):
logger
.
warning
(
logger
.
warning
(
...
@@ -409,7 +413,8 @@ class AutoTuner:
...
@@ -409,7 +413,8 @@ class AutoTuner:
"To ensure fresh, compatible inputs are generated for every trial "
"To ensure fresh, compatible inputs are generated for every trial "
"you can disable caching by setting:
\n
"
"you can disable caching by setting:
\n
"
" `cache_input_tensors=False`
\n
"
" `cache_input_tensors=False`
\n
"
"within your `.set_compile_args(...)` call.
\n
"
)
"within your `.set_compile_args(...)` call.
\n
"
)
# otherwise, regenerate the input tensors for safety
# otherwise, regenerate the input tensors for safety
self
.
jit_input_tensors
=
jit_input_tensors_supply
()
self
.
jit_input_tensors
=
jit_input_tensors_supply
()
break
break
...
@@ -418,24 +423,16 @@ class AutoTuner:
...
@@ -418,24 +423,16 @@ class AutoTuner:
if
(
not
skip_check
)
and
(
ref_prog
is
not
None
):
if
(
not
skip_check
)
and
(
ref_prog
is
not
None
):
if
manual_check_prog
is
not
None
:
if
manual_check_prog
is
not
None
:
profiler
.
manual_assert_close
(
profiler
.
manual_assert_close
(
ref_prog
,
input_tensors
=
self
.
jit_input_tensors
,
manual_check_prog
=
manual_check_prog
)
ref_prog
,
input_tensors
=
self
.
jit_input_tensors
,
manual_check_prog
=
manual_check_prog
)
else
:
else
:
profiler
.
assert_allclose
(
profiler
.
assert_allclose
(
ref_prog
,
ref_prog
,
input_tensors
=
self
.
jit_input_tensors
,
rtol
=
rtol
,
atol
=
atol
,
max_mismatched_ratio
=
max_mismatched_ratio
input_tensors
=
self
.
jit_input_tensors
,
)
rtol
=
rtol
,
latency
=
profiler
.
do_bench
(
warmup
=
warmup
,
rep
=
rep
,
input_tensors
=
self
.
jit_input_tensors
)
atol
=
atol
,
max_mismatched_ratio
=
max_mismatched_ratio
)
latency
=
profiler
.
do_bench
(
warmup
=
warmup
,
rep
=
rep
,
input_tensors
=
self
.
jit_input_tensors
)
if
self
.
ref_latency_cache
is
None
and
ref_prog
is
not
None
:
if
self
.
ref_latency_cache
is
None
and
ref_prog
is
not
None
:
self
.
ref_input_tensors
=
ref_input_tensors_supply
()
self
.
ref_input_tensors
=
ref_input_tensors_supply
()
self
.
ref_latency_cache
=
profiler
.
do_bench
(
self
.
ref_latency_cache
=
profiler
.
do_bench
(
ref_prog
,
n_warmup
=
warmup
,
n_repeat
=
rep
,
input_tensors
=
self
.
ref_input_tensors
)
ref_prog
,
n_warmup
=
warmup
,
n_repeat
=
rep
,
input_tensors
=
self
.
ref_input_tensors
)
return
latency
,
self
.
ref_latency_cache
return
latency
,
self
.
ref_latency_cache
...
@@ -469,17 +466,14 @@ class AutoTuner:
...
@@ -469,17 +466,14 @@ class AutoTuner:
# Check if all tunable arguments have been tuned by comparing config keys with key_kwargs_tuple
# Check if all tunable arguments have been tuned by comparing config keys with key_kwargs_tuple
if
any
(
key
in
top_config
for
key
,
_
in
key_kwargs_tuple
)
or
any
(
if
any
(
key
in
top_config
for
key
,
_
in
key_kwargs_tuple
)
or
any
(
check_tunable_argument_value
(
key
,
self
.
_function_parameters
,
key_args_tuple
)
check_tunable_argument_value
(
key
,
self
.
_function_parameters
,
key_args_tuple
)
for
key
in
tunable_arguments
for
key
in
tunable_arguments
):
):
logger
.
warning
(
logger
.
warning
(
f
"Tunable parameters
{
tunable_arguments
}
already provided during auto-tuning. Skipping compilation and using direct JIT"
f
"Tunable parameters
{
tunable_arguments
}
already provided during auto-tuning. Skipping compilation and using direct JIT"
)
)
# compile the kernel with the provided parameters
# compile the kernel with the provided parameters
jit_kernel
=
self
.
jit_compile
()
jit_kernel
=
self
.
jit_compile
()
autotuner_result
=
AutotuneResult
(
autotuner_result
=
AutotuneResult
(
libcode
=
jit_kernel
.
get_kernel_source
(),
func
=
jit_kernel
.
prim_func
,
kernel
=
jit_kernel
)
libcode
=
jit_kernel
.
get_kernel_source
(),
func
=
jit_kernel
.
prim_func
,
kernel
=
jit_kernel
)
self
.
_memory_cache
[
key
]
=
autotuner_result
self
.
_memory_cache
[
key
]
=
autotuner_result
return
autotuner_result
return
autotuner_result
# get the cpu count
# get the cpu count
...
@@ -489,9 +483,7 @@ class AutoTuner:
...
@@ -489,9 +483,7 @@ class AutoTuner:
max_cpu_count
=
int
(
env
.
TILELANG_AUTO_TUNING_MAX_CPU_COUNT
)
max_cpu_count
=
int
(
env
.
TILELANG_AUTO_TUNING_MAX_CPU_COUNT
)
if
cpu_counts
>
0
:
if
cpu_counts
>
0
:
num_workers
=
min
(
cpu_counts
,
available_cpu_count
)
num_workers
=
min
(
cpu_counts
,
available_cpu_count
)
logger
.
info
(
logger
.
info
(
f
"Auto-tuning with
{
cpu_counts
}
CPU counts,
{
available_cpu_count
}
CPUs available,
{
num_workers
}
CPUs will be used"
)
f
"Auto-tuning with
{
cpu_counts
}
CPU counts,
{
available_cpu_count
}
CPUs available,
{
num_workers
}
CPUs will be used"
)
else
:
else
:
num_workers
=
max
(
1
,
int
(
available_cpu_count
*
cpu_utilizations
))
num_workers
=
max
(
1
,
int
(
available_cpu_count
*
cpu_utilizations
))
logger
.
info
(
logger
.
info
(
...
@@ -509,7 +501,6 @@ class AutoTuner:
...
@@ -509,7 +501,6 @@ class AutoTuner:
future_to_index
=
{}
future_to_index
=
{}
def
cuda_device_wrapper
(
func
,
device
):
def
cuda_device_wrapper
(
func
,
device
):
def
inner
(
**
config_arg
):
def
inner
(
**
config_arg
):
torch
.
cuda
.
set_device
(
device
)
torch
.
cuda
.
set_device
(
device
)
return
func
(
**
config_arg
)
return
func
(
**
config_arg
)
...
@@ -532,18 +523,14 @@ class AutoTuner:
...
@@ -532,18 +523,14 @@ class AutoTuner:
future_to_index
[
future
]
=
i
future_to_index
[
future
]
=
i
results_with_configs
=
[]
results_with_configs
=
[]
for
future
in
tqdm
(
for
future
in
tqdm
(
concurrent
.
futures
.
as_completed
(
futures
),
total
=
len
(
futures
),
desc
=
"Compiling configurations"
):
concurrent
.
futures
.
as_completed
(
futures
),
total
=
len
(
futures
),
desc
=
"Compiling configurations"
):
idx
=
future_to_index
[
future
]
idx
=
future_to_index
[
future
]
config
=
config_args
[
idx
]
config
=
config_args
[
idx
]
try
:
try
:
result
=
future
.
result
()
result
=
future
.
result
()
results_with_configs
.
append
((
result
,
config
))
results_with_configs
.
append
((
result
,
config
))
except
Exception
as
e
:
except
Exception
as
e
:
logger
.
debug
(
logger
.
debug
(
f
"Compilation failed for config
{
config
}
at index
{
idx
}
with error:
{
e
}
"
)
f
"Compilation failed for config
{
config
}
at index
{
idx
}
with error:
{
e
}
"
)
continue
continue
ref_latency
=
None
ref_latency
=
None
...
@@ -556,14 +543,10 @@ class AutoTuner:
...
@@ -556,14 +543,10 @@ class AutoTuner:
# latency, ref_latency = target_fn(jit_kernel)
# latency, ref_latency = target_fn(jit_kernel)
latency
,
ref_latency
=
run_with_timeout
(
target_fn
,
timeout
,
jit_kernel
)
latency
,
ref_latency
=
run_with_timeout
(
target_fn
,
timeout
,
jit_kernel
)
except
TimeoutException
:
except
TimeoutException
:
logger
.
warning
(
logger
.
warning
(
f
"A timeout occurred while testing config
{
config
}
, checkout autotuner.log for more details"
)
f
"A timeout occurred while testing config
{
config
}
, checkout autotuner.log for more details"
)
continue
continue
except
Exception
:
except
Exception
:
logger
.
warning
(
logger
.
warning
(
f
"An error occurred while testing config
{
config
}
, checkout autotuner.log for more details"
)
f
"An error occurred while testing config
{
config
}
, checkout autotuner.log for more details"
)
logger
.
debug
(
f
"Error:
{
traceback
.
format_exc
()
}
"
)
logger
.
debug
(
f
"Error:
{
traceback
.
format_exc
()
}
"
)
continue
continue
...
@@ -578,8 +561,7 @@ class AutoTuner:
...
@@ -578,8 +561,7 @@ class AutoTuner:
pool
.
shutdown
()
pool
.
shutdown
()
if
best_kernel
is
None
:
if
best_kernel
is
None
:
error_msg
=
(
"Auto-tuning failed: No configuration successfully "
error_msg
=
"Auto-tuning failed: No configuration successfully compiled and passed benchmarking/validation."
"compiled and passed benchmarking/validation."
)
logger
.
error
(
error_msg
)
logger
.
error
(
error_msg
)
raise
RuntimeError
(
error_msg
)
raise
RuntimeError
(
error_msg
)
...
@@ -595,7 +577,8 @@ class AutoTuner:
...
@@ -595,7 +577,8 @@ class AutoTuner:
ref_latency
=
ref_latency
,
ref_latency
=
ref_latency
,
libcode
=
best_kernel
.
get_kernel_source
(),
libcode
=
best_kernel
.
get_kernel_source
(),
func
=
best_kernel
.
prim_func
,
func
=
best_kernel
.
prim_func
,
kernel
=
best_kernel
)
kernel
=
best_kernel
,
)
if
self
.
compile_args
.
execution_backend
in
(
"torch"
):
if
self
.
compile_args
.
execution_backend
in
(
"torch"
):
logger
.
warning
(
"DLPack backend does not support cache saving to disk."
)
logger
.
warning
(
"DLPack backend does not support cache saving to disk."
)
...
@@ -617,8 +600,8 @@ class AutoTuner:
...
@@ -617,8 +600,8 @@ class AutoTuner:
return
self
.
run
()
return
self
.
run
()
_P
=
ParamSpec
(
'
_P
'
)
_P
=
ParamSpec
(
"
_P
"
)
_T
=
TypeVar
(
'
_T
'
)
_T
=
TypeVar
(
"
_T
"
)
@
dataclass
@
dataclass
...
@@ -643,8 +626,9 @@ class AutoTuneImpl(Generic[_P, _T]):
...
@@ -643,8 +626,9 @@ class AutoTuneImpl(Generic[_P, _T]):
self
.
_tuner_cache
=
{}
self
.
_tuner_cache
=
{}
def
get_tunner
(
self
):
def
get_tunner
(
self
):
autotuner
=
AutoTuner
(
autotuner
=
(
self
.
jit_impl
.
func
,
configs
=
self
.
configs
).
set_profile_args
(
AutoTuner
(
self
.
jit_impl
.
func
,
configs
=
self
.
configs
)
.
set_profile_args
(
supply_type
=
self
.
supply_type
,
supply_type
=
self
.
supply_type
,
ref_prog
=
self
.
ref_prog
,
ref_prog
=
self
.
ref_prog
,
supply_prog
=
self
.
supply_prog
,
supply_prog
=
self
.
supply_prog
,
...
@@ -654,7 +638,8 @@ class AutoTuneImpl(Generic[_P, _T]):
...
@@ -654,7 +638,8 @@ class AutoTuneImpl(Generic[_P, _T]):
skip_check
=
self
.
skip_check
,
skip_check
=
self
.
skip_check
,
manual_check_prog
=
self
.
manual_check_prog
,
manual_check_prog
=
self
.
manual_check_prog
,
cache_input_tensors
=
self
.
cache_input_tensors
,
cache_input_tensors
=
self
.
cache_input_tensors
,
).
set_compile_args
(
)
.
set_compile_args
(
out_idx
=
self
.
jit_impl
.
out_idx
,
out_idx
=
self
.
jit_impl
.
out_idx
,
execution_backend
=
self
.
jit_impl
.
execution_backend
,
execution_backend
=
self
.
jit_impl
.
execution_backend
,
target
=
self
.
jit_impl
.
target
,
target
=
self
.
jit_impl
.
target
,
...
@@ -662,6 +647,7 @@ class AutoTuneImpl(Generic[_P, _T]):
...
@@ -662,6 +647,7 @@ class AutoTuneImpl(Generic[_P, _T]):
verbose
=
self
.
jit_impl
.
verbose
,
verbose
=
self
.
jit_impl
.
verbose
,
pass_configs
=
self
.
jit_impl
.
pass_configs
,
pass_configs
=
self
.
jit_impl
.
pass_configs
,
)
)
)
autotuner
.
run
=
partial
(
autotuner
.
run
,
self
.
warmup
,
self
.
rep
,
self
.
timeout
)
autotuner
.
run
=
partial
(
autotuner
.
run
,
self
.
warmup
,
self
.
rep
,
self
.
timeout
)
return
autotuner
return
autotuner
...
@@ -753,16 +739,13 @@ def autotune( # This is the new public interface
...
@@ -753,16 +739,13 @@ def autotune( # This is the new public interface
if
callable
(
func
):
if
callable
(
func
):
# Case 1: Used as @autotune (func_or_out_idx is the function, others are defaults)
# Case 1: Used as @autotune (func_or_out_idx is the function, others are defaults)
# This is a placeholder for a real auto tuner implementation
# This is a placeholder for a real auto tuner implementation
raise
ValueError
(
raise
ValueError
(
"Use tilelang.autotune to decorate func without arguments is not supported yet."
)
"Use tilelang.autotune to decorate func without arguments is not supported yet."
)
elif
isinstance
(
func
,
PrimFunc
):
elif
isinstance
(
func
,
PrimFunc
):
raise
ValueError
(
"Use tilelang.jit to decorate prim_func is not supported yet."
)
raise
ValueError
(
"Use tilelang.jit to decorate prim_func is not supported yet."
)
else
:
else
:
def
decorator
(
impl
):
def
decorator
(
impl
):
assert
isinstance
(
assert
isinstance
(
impl
,
JITImpl
),
"The @autotune decorator can only be applied to @tilelang.jit decorated instances."
impl
,
JITImpl
),
"The @autotune decorator can only be applied to @tilelang.jit decorated instances."
return
AutoTuneImpl
(
return
AutoTuneImpl
(
jit_impl
=
impl
,
jit_impl
=
impl
,
configs
=
configs
,
configs
=
configs
,
...
...
tilelang/cache/__init__.py
View file @
29051439
"""The cache utils with class and database persistence - Init file"""
"""The cache utils with class and database persistence - Init file"""
from
__future__
import
annotations
from
__future__
import
annotations
from
typing
import
Literal
from
typing
import
Literal
...
@@ -18,8 +19,7 @@ def cached(
...
@@ -18,8 +19,7 @@ def cached(
*
args
,
*
args
,
target
:
str
|
Target
=
"auto"
,
target
:
str
|
Target
=
"auto"
,
target_host
:
str
|
Target
=
None
,
target_host
:
str
|
Target
=
None
,
execution_backend
:
Literal
[
"auto"
,
"tvm_ffi"
,
"ctypes"
,
"cython"
,
"nvrtc"
,
"torch"
]
execution_backend
:
Literal
[
"auto"
,
"tvm_ffi"
,
"ctypes"
,
"cython"
,
"nvrtc"
,
"torch"
]
|
None
=
"auto"
,
|
None
=
"auto"
,
verbose
:
bool
|
None
=
False
,
verbose
:
bool
|
None
=
False
,
pass_configs
:
dict
|
None
=
None
,
pass_configs
:
dict
|
None
=
None
,
compile_flags
:
list
[
str
]
|
str
|
None
=
None
,
compile_flags
:
list
[
str
]
|
str
|
None
=
None
,
...
@@ -36,7 +36,8 @@ def cached(
...
@@ -36,7 +36,8 @@ def cached(
execution_backend
=
execution_backend
,
execution_backend
=
execution_backend
,
verbose
=
verbose
,
verbose
=
verbose
,
pass_configs
=
pass_configs
,
pass_configs
=
pass_configs
,
compile_flags
=
compile_flags
)
compile_flags
=
compile_flags
,
)
def
clear_cache
():
def
clear_cache
():
...
@@ -47,9 +48,11 @@ def clear_cache():
...
@@ -47,9 +48,11 @@ def clear_cache():
RuntimeError: Always raised to warn users to clear the cache manually.
RuntimeError: Always raised to warn users to clear the cache manually.
"""
"""
cache_dir
=
env
.
TILELANG_CACHE_DIR
cache_dir
=
env
.
TILELANG_CACHE_DIR
raise
RuntimeError
(
"tilelang.clear_cache() is disabled because deleting the cache directory "
raise
RuntimeError
(
"is dangerous. If you accept the risk, remove it manually with "
"tilelang.clear_cache() is disabled because deleting the cache directory "
f
"`rm -rf '
{
cache_dir
}
'`."
)
"is dangerous. If you accept the risk, remove it manually with "
f
"`rm -rf '
{
cache_dir
}
'`."
)
if
env
.
TILELANG_CLEAR_CACHE
.
lower
()
in
(
"1"
,
"true"
,
"yes"
,
"on"
):
if
env
.
TILELANG_CLEAR_CACHE
.
lower
()
in
(
"1"
,
"true"
,
"yes"
,
"on"
):
...
...
tilelang/cache/kernel_cache.py
View file @
29051439
"""The cache utils with class and database persistence - KernelCache Class"""
"""The cache utils with class and database persistence - KernelCache Class"""
from
__future__
import
annotations
from
__future__
import
annotations
import
json
import
json
...
@@ -97,9 +98,7 @@ class KernelCache:
...
@@ -97,9 +98,7 @@ class KernelCache:
"version"
:
__version__
,
"version"
:
__version__
,
"func"
:
sha256
(
func_binary
).
hexdigest
(),
# Use SHA256 to generate hash key
"func"
:
sha256
(
func_binary
).
hexdigest
(),
# Use SHA256 to generate hash key
"out_idx"
:
(
tuple
(
out_idx
)
if
isinstance
(
out_idx
,
(
list
,
tuple
))
else
[
out_idx
]),
"out_idx"
:
(
tuple
(
out_idx
)
if
isinstance
(
out_idx
,
(
list
,
tuple
))
else
[
out_idx
]),
"args_repr"
:
tuple
(
"args_repr"
:
tuple
(
repr
(
arg
)
for
arg
in
args
),
# Use repr to serialize arguments, may need more robust serialization
repr
(
arg
)
for
arg
in
args
),
# Use repr to serialize arguments, may need more robust serialization
"target"
:
str
(
target
),
"target"
:
str
(
target
),
"target_host"
:
str
(
target_host
)
if
target_host
else
None
,
"target_host"
:
str
(
target_host
)
if
target_host
else
None
,
"execution_backend"
:
execution_backend
,
"execution_backend"
:
execution_backend
,
...
@@ -118,8 +117,7 @@ class KernelCache:
...
@@ -118,8 +117,7 @@ class KernelCache:
*
args
,
*
args
,
target
:
str
|
Target
=
"auto"
,
target
:
str
|
Target
=
"auto"
,
target_host
:
str
|
Target
=
None
,
target_host
:
str
|
Target
=
None
,
execution_backend
:
Literal
[
"auto"
,
"tvm_ffi"
,
"ctypes"
,
"cython"
,
"nvrtc"
,
execution_backend
:
Literal
[
"auto"
,
"tvm_ffi"
,
"ctypes"
,
"cython"
,
"nvrtc"
,
"torch"
]
=
"auto"
,
"torch"
]
=
"auto"
,
verbose
:
bool
=
False
,
verbose
:
bool
=
False
,
pass_configs
:
dict
=
None
,
pass_configs
:
dict
=
None
,
compile_flags
:
list
[
str
]
|
str
|
None
=
None
,
compile_flags
:
list
[
str
]
|
str
|
None
=
None
,
...
@@ -140,6 +138,7 @@ class KernelCache:
...
@@ -140,6 +138,7 @@ class KernelCache:
# Normalize target and resolve execution backend before proceeding
# Normalize target and resolve execution backend before proceeding
from
tilelang.utils.target
import
determine_target
as
_determine_target
from
tilelang.utils.target
import
determine_target
as
_determine_target
from
tilelang.jit.execution_backend
import
resolve_execution_backend
,
allowed_backends_for_target
from
tilelang.jit.execution_backend
import
resolve_execution_backend
,
allowed_backends_for_target
norm_target
=
Target
(
_determine_target
(
target
))
if
isinstance
(
target
,
str
)
else
target
norm_target
=
Target
(
_determine_target
(
target
))
if
isinstance
(
target
,
str
)
else
target
requested_backend
=
execution_backend
requested_backend
=
execution_backend
execution_backend
=
resolve_execution_backend
(
requested_backend
,
norm_target
)
execution_backend
=
resolve_execution_backend
(
requested_backend
,
norm_target
)
...
@@ -180,21 +179,21 @@ class KernelCache:
...
@@ -180,21 +179,21 @@ class KernelCache:
with
self
.
_lock
:
with
self
.
_lock
:
# First check in-memory cache
# First check in-memory cache
if
key
in
self
.
_memory_cache
:
if
key
in
self
.
_memory_cache
:
self
.
logger
.
warning
(
"Found kernel in memory cache. For better performance,"
\
self
.
logger
.
warning
(
" consider using `@tilelang.jit` instead of direct kernel caching."
)
"Found kernel in memory cache. For better performance, consider using `@tilelang.jit` instead of direct kernel caching."
)
return
self
.
_memory_cache
[
key
]
return
self
.
_memory_cache
[
key
]
if
verbose
:
if
verbose
:
self
.
logger
.
debug
(
f
"Checking disk cache for kernel
{
func
.
attrs
[
'global_symbol'
]
}
"
)
self
.
logger
.
debug
(
f
"Checking disk cache for kernel
{
func
.
attrs
[
'global_symbol'
]
}
"
)
# Then check disk cache
# Then check disk cache
kernel
=
self
.
_load_kernel_from_disk
(
key
,
norm_target
,
target_host
,
out_idx
,
kernel
=
self
.
_load_kernel_from_disk
(
execution_backend
,
pass_configs
,
compile_flags
,
key
,
norm_target
,
target_host
,
out_idx
,
execution_backend
,
pass_configs
,
compile_flags
,
func
,
verbose
func
,
verbose
)
)
if
kernel
is
not
None
:
if
kernel
is
not
None
:
if
verbose
:
if
verbose
:
self
.
logger
.
debug
(
self
.
logger
.
debug
(
f
"Found kernel in disk cache for
{
func
.
attrs
[
'global_symbol'
]
}
"
)
f
"Found kernel in disk cache for
{
func
.
attrs
[
'global_symbol'
]
}
"
)
# Populate memory cache with disk result
# Populate memory cache with disk result
self
.
_memory_cache
[
key
]
=
kernel
self
.
_memory_cache
[
key
]
=
kernel
return
kernel
return
kernel
...
@@ -262,11 +261,7 @@ class KernelCache:
...
@@ -262,11 +261,7 @@ class KernelCache:
executable
.
export_library
(
temp_path
)
executable
.
export_library
(
temp_path
)
os
.
replace
(
temp_path
,
path
)
os
.
replace
(
temp_path
,
path
)
def
_save_kernel_to_disk
(
self
,
def
_save_kernel_to_disk
(
self
,
key
:
str
,
kernel
:
JITKernel
,
func
:
Callable
=
None
,
verbose
:
bool
=
False
):
key
:
str
,
kernel
:
JITKernel
,
func
:
Callable
=
None
,
verbose
:
bool
=
False
):
"""
"""
Persists a compiled kernel to disk cache.
Persists a compiled kernel to disk cache.
...
@@ -292,8 +287,7 @@ class KernelCache:
...
@@ -292,8 +287,7 @@ class KernelCache:
if
verbose
:
if
verbose
:
self
.
logger
.
debug
(
f
"Saving kernel source code to file:
{
device_kernel_path
}
"
)
self
.
logger
.
debug
(
f
"Saving kernel source code to file:
{
device_kernel_path
}
"
)
if
kernel
.
kernel_source
is
not
None
:
if
kernel
.
kernel_source
is
not
None
:
KernelCache
.
_safe_write_file
(
device_kernel_path
,
"w"
,
KernelCache
.
_safe_write_file
(
device_kernel_path
,
"w"
,
lambda
file
:
file
.
write
(
kernel
.
kernel_source
))
lambda
file
:
file
.
write
(
kernel
.
kernel_source
))
except
Exception
as
e
:
except
Exception
as
e
:
self
.
logger
.
error
(
f
"Error saving kernel source code to disk:
{
e
}
"
)
self
.
logger
.
error
(
f
"Error saving kernel source code to disk:
{
e
}
"
)
...
@@ -303,13 +297,9 @@ class KernelCache:
...
@@ -303,13 +297,9 @@ class KernelCache:
if
verbose
:
if
verbose
:
self
.
logger
.
debug
(
f
"Saving wrapped kernel source code to file:
{
host_kernel_path
}
"
)
self
.
logger
.
debug
(
f
"Saving wrapped kernel source code to file:
{
host_kernel_path
}
"
)
if
self
.
execution_backend
==
"tvm_ffi"
:
if
self
.
execution_backend
==
"tvm_ffi"
:
KernelCache
.
_safe_write_file
(
KernelCache
.
_safe_write_file
(
host_kernel_path
,
"w"
,
lambda
file
:
file
.
write
(
kernel
.
adapter
.
get_host_source
()))
host_kernel_path
,
"w"
,
lambda
file
:
file
.
write
(
kernel
.
adapter
.
get_host_source
()))
else
:
else
:
KernelCache
.
_safe_write_file
(
KernelCache
.
_safe_write_file
(
host_kernel_path
,
"w"
,
lambda
file
:
file
.
write
(
kernel
.
adapter
.
get_kernel_source
()))
host_kernel_path
,
"w"
,
lambda
file
:
file
.
write
(
kernel
.
adapter
.
get_kernel_source
()))
except
Exception
as
e
:
except
Exception
as
e
:
self
.
logger
.
error
(
f
"Error saving host kernel source code to disk:
{
e
}
"
)
self
.
logger
.
error
(
f
"Error saving host kernel source code to disk:
{
e
}
"
)
...
@@ -332,9 +322,7 @@ class KernelCache:
...
@@ -332,9 +322,7 @@ class KernelCache:
src_lib_path
=
src_lib_path
.
replace
(
".cubin"
,
".py"
)
src_lib_path
=
src_lib_path
.
replace
(
".cubin"
,
".py"
)
if
verbose
:
if
verbose
:
self
.
logger
.
debug
(
f
"Saving kernel nvrtc python code to file:
{
kernel_py_path
}
"
)
self
.
logger
.
debug
(
f
"Saving kernel nvrtc python code to file:
{
kernel_py_path
}
"
)
KernelCache
.
_safe_write_file
(
KernelCache
.
_safe_write_file
(
kernel_py_path
,
"wb"
,
lambda
file
:
file
.
write
(
KernelCache
.
_load_binary
(
src_lib_path
)))
kernel_py_path
,
"wb"
,
lambda
file
:
file
.
write
(
KernelCache
.
_load_binary
(
src_lib_path
)))
elif
self
.
execution_backend
==
"tvm_ffi"
:
elif
self
.
execution_backend
==
"tvm_ffi"
:
executable
=
kernel
.
adapter
.
executable
executable
=
kernel
.
adapter
.
executable
if
verbose
:
if
verbose
:
...
@@ -344,9 +332,7 @@ class KernelCache:
...
@@ -344,9 +332,7 @@ class KernelCache:
src_lib_path
=
kernel
.
adapter
.
libpath
src_lib_path
=
kernel
.
adapter
.
libpath
if
verbose
:
if
verbose
:
self
.
logger
.
debug
(
f
"Saving kernel library to file:
{
kernel_lib_path
}
"
)
self
.
logger
.
debug
(
f
"Saving kernel library to file:
{
kernel_lib_path
}
"
)
KernelCache
.
_safe_write_file
(
KernelCache
.
_safe_write_file
(
kernel_lib_path
,
"wb"
,
lambda
file
:
file
.
write
(
KernelCache
.
_load_binary
(
src_lib_path
)))
kernel_lib_path
,
"wb"
,
lambda
file
:
file
.
write
(
KernelCache
.
_load_binary
(
src_lib_path
)))
except
Exception
as
e
:
except
Exception
as
e
:
self
.
logger
.
error
(
f
"Error saving kernel library to disk:
{
e
}
"
)
self
.
logger
.
error
(
f
"Error saving kernel library to disk:
{
e
}
"
)
...
@@ -356,8 +342,7 @@ class KernelCache:
...
@@ -356,8 +342,7 @@ class KernelCache:
params_path
=
os
.
path
.
join
(
cache_path
,
PARAMS_PATH
)
params_path
=
os
.
path
.
join
(
cache_path
,
PARAMS_PATH
)
if
verbose
:
if
verbose
:
self
.
logger
.
debug
(
f
"Saving kernel parameters to disk:
{
params_path
}
"
)
self
.
logger
.
debug
(
f
"Saving kernel parameters to disk:
{
params_path
}
"
)
KernelCache
.
_safe_write_file
(
params_path
,
"wb"
,
KernelCache
.
_safe_write_file
(
params_path
,
"wb"
,
lambda
file
:
cloudpickle
.
dump
(
kernel
.
params
,
file
))
lambda
file
:
cloudpickle
.
dump
(
kernel
.
params
,
file
))
except
Exception
as
e
:
except
Exception
as
e
:
self
.
logger
.
error
(
f
"Error saving kernel parameters to disk:
{
e
}
"
)
self
.
logger
.
error
(
f
"Error saving kernel parameters to disk:
{
e
}
"
)
...
@@ -417,8 +402,7 @@ class KernelCache:
...
@@ -417,8 +402,7 @@ class KernelCache:
self
.
logger
.
error
(
f
"Error loading kernel source code from disk:
{
e
}
"
)
self
.
logger
.
error
(
f
"Error loading kernel source code from disk:
{
e
}
"
)
try
:
try
:
if
verbose
:
if
verbose
:
self
.
logger
.
debug
(
self
.
logger
.
debug
(
f
"Loading wrapped kernel source code from file:
{
host_kernel_path
}
"
)
f
"Loading wrapped kernel source code from file:
{
host_kernel_path
}
"
)
with
open
(
host_kernel_path
)
as
f
:
with
open
(
host_kernel_path
)
as
f
:
host_kernel_source
=
f
.
read
()
host_kernel_source
=
f
.
read
()
except
Exception
as
e
:
except
Exception
as
e
:
...
...
tilelang/carver/__init__.py
View file @
29051439
"""Base infra"""
"""Base infra"""
from
.analysis
import
(
from
.analysis
import
(
BlockInfo
,
# noqa: F401
BlockInfo
,
# noqa: F401
IterInfo
,
# noqa: F401
IterInfo
,
# noqa: F401
...
...
tilelang/carver/analysis.py
View file @
29051439
"""Analysis on TIR blocks, loops and functions."""
"""Analysis on TIR blocks, loops and functions."""
from
__future__
import
annotations
from
__future__
import
annotations
from
typing_extensions
import
Literal
from
typing_extensions
import
Literal
...
@@ -144,11 +145,13 @@ def normalize_prim_func(sch: tir.Schedule) -> list[BlockInfo] | None:
...
@@ -144,11 +145,13 @@ def normalize_prim_func(sch: tir.Schedule) -> list[BlockInfo] | None:
var
=
iter
.
var
,
var
=
iter
.
var
,
dom
=
iter
.
dom
,
dom
=
iter
.
dom
,
loop_rv
=
loop
,
loop_rv
=
loop
,
)
for
loop
,
iter
in
zip
(
loops
,
iters
)
)
for
loop
,
iter
in
zip
(
loops
,
iters
)
],
],
block_rv
=
block
,
block_rv
=
block
,
reduction_block
=
is_reduction
,
reduction_block
=
is_reduction
,
))
)
)
return
blocks
return
blocks
...
@@ -188,8 +191,7 @@ def get_max_shared_memory_per_block(target: Target) -> int:
...
@@ -188,8 +191,7 @@ def get_max_shared_memory_per_block(target: Target) -> int:
_assert_gpu_target
(
target
)
_assert_gpu_target
(
target
)
max_shared_memory_per_block
=
target
.
attrs
.
get
(
"max_shared_memory_per_block"
,
None
)
max_shared_memory_per_block
=
target
.
attrs
.
get
(
"max_shared_memory_per_block"
,
None
)
if
max_shared_memory_per_block
is
None
:
if
max_shared_memory_per_block
is
None
:
raise
ValueError
(
raise
ValueError
(
f
"Cannot find `max_shared_memory_per_block` in
{
target
}
, please specify it manually"
)
f
"Cannot find `max_shared_memory_per_block` in
{
target
}
, please specify it manually"
)
return
int
(
max_shared_memory_per_block
)
return
int
(
max_shared_memory_per_block
)
...
@@ -197,13 +199,11 @@ def get_root_block(sch: Schedule, func_name: str = "main") -> BlockRV:
...
@@ -197,13 +199,11 @@ def get_root_block(sch: Schedule, func_name: str = "main") -> BlockRV:
try
:
try
:
block
=
sch
.
mod
[
func_name
].
body
.
block
block
=
sch
.
mod
[
func_name
].
body
.
block
except
Exception
:
except
Exception
:
raise
ValueError
(
f
"The function body is expected to be the root block, but got:
\n
"
raise
ValueError
(
f
"The function body is expected to be the root block, but got:
\n
{
sch
.
mod
[
func_name
].
body
}
"
)
from
None
f
"
{
sch
.
mod
[
func_name
].
body
}
"
)
from
None
return
sch
.
get_block
(
block
.
name_hint
)
return
sch
.
get_block
(
block
.
name_hint
)
def
collect_block_iter_vars_used_in_access_region
(
block
:
tir
.
Block
,
def
collect_block_iter_vars_used_in_access_region
(
block
:
tir
.
Block
,
region
:
list
[
ir
.
Range
])
->
set
[
tir
.
Var
]:
region
:
list
[
ir
.
Range
])
->
set
[
tir
.
Var
]:
"""Collect the block iter variables used in the access region of a buffer region."""
"""Collect the block iter variables used in the access region of a buffer region."""
tir_vars
=
set
()
tir_vars
=
set
()
for
expr
in
region
:
for
expr
in
region
:
...
@@ -251,15 +251,13 @@ def is_broadcast_epilogue(
...
@@ -251,15 +251,13 @@ def is_broadcast_epilogue(
for
buffer_region
in
sch
.
get
(
epilogue
).
reads
:
for
buffer_region
in
sch
.
get
(
epilogue
).
reads
:
if
buffer_region
.
buffer
not
in
write_buffers
:
if
buffer_region
.
buffer
not
in
write_buffers
:
continue
continue
tir_vars
=
collect_block_iter_vars_used_in_access_region
(
tir_vars
=
collect_block_iter_vars_used_in_access_region
(
sch
.
get
(
epilogue
),
buffer_region
.
region
)
sch
.
get
(
epilogue
),
buffer_region
.
region
)
if
len
(
tir_vars
)
<
len
(
epilogue_iters
):
if
len
(
tir_vars
)
<
len
(
epilogue_iters
):
return
True
return
True
return
False
return
False
def
get_reduction_blocks
(
sch
:
tir
.
Schedule
,
def
get_reduction_blocks
(
sch
:
tir
.
Schedule
,
blocks
:
list
[
tir
.
schedule
.
BlockRV
])
->
list
[
tir
.
schedule
.
BlockRV
]:
blocks
:
list
[
tir
.
schedule
.
BlockRV
])
->
list
[
tir
.
schedule
.
BlockRV
]:
# Get the main computation block
# Get the main computation block
def
is_reduction
(
block
:
BlockRV
)
->
bool
:
def
is_reduction
(
block
:
BlockRV
)
->
bool
:
block_stmt
=
sch
.
get
(
block
)
block_stmt
=
sch
.
get
(
block
)
...
...
tilelang/carver/arch/__init__.py
View file @
29051439
...
@@ -39,18 +39,18 @@ def auto_infer_current_arch() -> TileDevice:
...
@@ -39,18 +39,18 @@ def auto_infer_current_arch() -> TileDevice:
__all__
=
[
__all__
=
[
'
is_cpu_arch
'
,
"
is_cpu_arch
"
,
'
is_cuda_arch
'
,
"
is_cuda_arch
"
,
'
is_volta_arch
'
,
"
is_volta_arch
"
,
'
is_ampere_arch
'
,
"
is_ampere_arch
"
,
'
is_ada_arch
'
,
"
is_ada_arch
"
,
'
is_hopper_arch
'
,
"
is_hopper_arch
"
,
'
is_tensorcore_supported_precision
'
,
"
is_tensorcore_supported_precision
"
,
'
has_mma_support
'
,
"
has_mma_support
"
,
'
is_cdna_arch
'
,
"
is_cdna_arch
"
,
'
is_metal_arch
'
,
"
is_metal_arch
"
,
'
CUDA
'
,
"
CUDA
"
,
'
CDNA
'
,
"
CDNA
"
,
'
METAL
'
,
"
METAL
"
,
'
CPU
'
,
"
CPU
"
,
]
]
tilelang/carver/arch/arch_base.py
View file @
29051439
...
@@ -7,9 +7,7 @@ class TileDevice:
...
@@ -7,9 +7,7 @@ class TileDevice:
self
.
reg_cap
:
int
=
0
# Register capacity: The amount of register memory available
self
.
reg_cap
:
int
=
0
# Register capacity: The amount of register memory available
self
.
smem_cap
:
int
=
0
# Shared memory capacity: The amount of shared memory available
self
.
smem_cap
:
int
=
0
# Shared memory capacity: The amount of shared memory available
self
.
compute_max_core
:
int
=
0
# The maximum number of computing cores
self
.
compute_max_core
:
int
=
0
# The maximum number of computing cores
self
.
warp_size
:
int
=
(
self
.
warp_size
:
int
=
0
# The size of a warp, a group of threads that execute instructions in lockstep
0
# The size of a warp, a group of threads that execute instructions in lockstep
)
self
.
sm_partition
:
int
=
0
# The number of streaming multiprocessor partitions
self
.
sm_partition
:
int
=
0
# The number of streaming multiprocessor partitions
self
.
transaction_size
:
list
[
int
]
=
[
self
.
transaction_size
:
list
[
int
]
=
[
0
,
0
,
...
@@ -21,9 +19,7 @@ class TileDevice:
...
@@ -21,9 +19,7 @@ class TileDevice:
0
,
0
,
]
# Bandwidth specifications, possibly including peak and sustained rates
]
# Bandwidth specifications, possibly including peak and sustained rates
self
.
platform
:
str
=
"unknown"
# The platform or manufacturer of the device
self
.
platform
:
str
=
"unknown"
# The platform or manufacturer of the device
self
.
compute_capability
:
str
=
(
self
.
compute_capability
:
str
=
"unknown"
# The compute capability, indicating the feature set and performance level
"unknown"
# The compute capability, indicating the feature set and performance level
)
self
.
l2_cache_size_bytes
:
int
=
0
self
.
l2_cache_size_bytes
:
int
=
0
# the number of transaction size in bytes
# the number of transaction size in bytes
self
.
transaction_size
:
list
[
int
]
=
[
0
,
0
]
# in bytes
self
.
transaction_size
:
list
[
int
]
=
[
0
,
0
]
# in bytes
...
...
tilelang/carver/arch/cdna.py
View file @
29051439
...
@@ -9,7 +9,6 @@ def is_cdna_arch(arch: TileDevice) -> bool:
...
@@ -9,7 +9,6 @@ def is_cdna_arch(arch: TileDevice) -> bool:
class
CDNA
(
TileDevice
):
class
CDNA
(
TileDevice
):
def
__init__
(
self
,
target
:
Target
|
str
):
def
__init__
(
self
,
target
:
Target
|
str
):
if
isinstance
(
target
,
str
):
if
isinstance
(
target
,
str
):
target
=
tvm
.
target
.
Target
(
target
)
target
=
tvm
.
target
.
Target
(
target
)
...
@@ -33,6 +32,6 @@ class CDNA(TileDevice):
...
@@ -33,6 +32,6 @@ class CDNA(TileDevice):
__all__
=
[
__all__
=
[
'
is_cdna_arch
'
,
"
is_cdna_arch
"
,
'
CDNA
'
,
"
CDNA
"
,
]
]
tilelang/carver/arch/cpu.py
View file @
29051439
...
@@ -10,7 +10,6 @@ def is_cpu_arch(arch: TileDevice) -> bool:
...
@@ -10,7 +10,6 @@ def is_cpu_arch(arch: TileDevice) -> bool:
# For LLVM Backend, we do not provide the detailed information of the CPU
# For LLVM Backend, we do not provide the detailed information of the CPU
# As the LLVM backend do not required tuning, just maintain the consistency
# As the LLVM backend do not required tuning, just maintain the consistency
class
CPU
(
TileDevice
):
class
CPU
(
TileDevice
):
def
__init__
(
self
,
target
:
Target
):
def
__init__
(
self
,
target
:
Target
):
self
.
target
=
target
self
.
target
=
target
device
=
tvm
.
runtime
.
cpu
(
0
)
device
=
tvm
.
runtime
.
cpu
(
0
)
...
@@ -21,6 +20,6 @@ class CPU(TileDevice):
...
@@ -21,6 +20,6 @@ class CPU(TileDevice):
__all__
=
[
__all__
=
[
'
is_cpu_arch
'
,
"
is_cpu_arch
"
,
'
CPU
'
,
"
CPU
"
,
]
]
tilelang/carver/arch/cuda.py
View file @
29051439
...
@@ -78,7 +78,6 @@ hopper_tensorcore_supported = ada_tensorcore_supported
...
@@ -78,7 +78,6 @@ hopper_tensorcore_supported = ada_tensorcore_supported
# instead of assuming both a and b share the same dtype.
# instead of assuming both a and b share the same dtype.
# As the tensorcore may supports float8_e4m3 * float8_e5m2
# As the tensorcore may supports float8_e4m3 * float8_e5m2
def
is_tensorcore_supported_precision
(
in_dtype
:
str
,
accum_dtype
:
str
,
arch
:
TileDevice
)
->
bool
:
def
is_tensorcore_supported_precision
(
in_dtype
:
str
,
accum_dtype
:
str
,
arch
:
TileDevice
)
->
bool
:
if
is_volta_arch
(
arch
):
if
is_volta_arch
(
arch
):
return
(
in_dtype
,
accum_dtype
)
in
volta_tensorcore_supported
return
(
in_dtype
,
accum_dtype
)
in
volta_tensorcore_supported
elif
is_ampere_arch
(
arch
):
elif
is_ampere_arch
(
arch
):
...
@@ -92,7 +91,6 @@ def is_tensorcore_supported_precision(in_dtype: str, accum_dtype: str, arch: Til
...
@@ -92,7 +91,6 @@ def is_tensorcore_supported_precision(in_dtype: str, accum_dtype: str, arch: Til
class
TensorInstruction
:
class
TensorInstruction
:
def
__init__
(
def
__init__
(
self
,
self
,
name
:
str
,
name
:
str
,
...
@@ -104,7 +102,6 @@ class TensorInstruction:
...
@@ -104,7 +102,6 @@ class TensorInstruction:
class
CUDA
(
TileDevice
):
class
CUDA
(
TileDevice
):
def
__init__
(
self
,
target
:
Target
|
str
):
def
__init__
(
self
,
target
:
Target
|
str
):
if
isinstance
(
target
,
str
):
if
isinstance
(
target
,
str
):
target
=
tvm
.
target
.
Target
(
target
)
target
=
tvm
.
target
.
Target
(
target
)
...
@@ -148,12 +145,12 @@ class CUDA(TileDevice):
...
@@ -148,12 +145,12 @@ class CUDA(TileDevice):
__all__
=
[
__all__
=
[
'
is_cuda_arch
'
,
"
is_cuda_arch
"
,
'
is_volta_arch
'
,
"
is_volta_arch
"
,
'
is_ampere_arch
'
,
"
is_ampere_arch
"
,
'
is_ada_arch
'
,
"
is_ada_arch
"
,
'
is_hopper_arch
'
,
"
is_hopper_arch
"
,
'
is_tensorcore_supported_precision
'
,
"
is_tensorcore_supported_precision
"
,
'
has_mma_support
'
,
"
has_mma_support
"
,
"CUDA"
,
"CUDA"
,
]
]
tilelang/carver/arch/driver/cuda_driver.py
View file @
29051439
...
@@ -83,8 +83,7 @@ def get_max_dynamic_shared_size_bytes(device_id: int = 0, format: str = "bytes")
...
@@ -83,8 +83,7 @@ def get_max_dynamic_shared_size_bytes(device_id: int = 0, format: str = "bytes")
Get the maximum dynamic shared memory size in bytes, kilobytes, or megabytes.
Get the maximum dynamic shared memory size in bytes, kilobytes, or megabytes.
"""
"""
assert
format
in
[
"bytes"
,
"kb"
,
"mb"
],
"Invalid format. Must be one of: bytes, kb, mb"
assert
format
in
[
"bytes"
,
"kb"
,
"mb"
],
"Invalid format. Must be one of: bytes, kb, mb"
shared_mem
=
get_device_attribute
(
shared_mem
=
get_device_attribute
(
cudaDeviceAttrNames
.
cudaDevAttrMaxSharedMemoryPerMultiprocessor
,
device_id
)
cudaDeviceAttrNames
.
cudaDevAttrMaxSharedMemoryPerMultiprocessor
,
device_id
)
if
format
==
"bytes"
:
if
format
==
"bytes"
:
return
shared_mem
return
shared_mem
elif
format
==
"kb"
:
elif
format
==
"kb"
:
...
...
Prev
1
…
13
14
15
16
17
18
19
20
21
22
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment