OpenDAS / tilelang / Commits / 541e1685

Commit 541e1685, authored Mar 25, 2025 by yyttt6; committed by LeiWang1999 on Mar 25, 2025.
Parent: 8ad53855

[Refactor] Enhance Autotune (#266)

* add autotune to example_gemm.py
* format init.py

Showing 4 changed files with 369 additions and 106 deletions (+369 -106):
examples/gemm/example_gemm.py                                  +15  -30
testing/python/autotune/test_tilelang_autotune.py              +12  -22
testing/python/autotune/test_tilelang_autotune_decorator.py    +269 -0
tilelang/autotuner/__init__.py                                 +73  -54
examples/gemm/example_gemm.py

@@ -3,7 +3,7 @@ import torch
 import itertools
 import tilelang as tl
 import tilelang.language as T
-from tilelang.autotuner import autotune, jit
+from tilelang.autotuner import AutoTuner
 from tilelang.carver.template import MatmulTemplate
 from tilelang.carver.arch import CUDA
 from tilelang.carver.roller.rasterization import NoRasterization
@@ -79,26 +79,6 @@ def get_configs(M, N, K, with_roller=False):
 def get_best_config(M, N, K, with_roller=False):
 
-    @autotune(
-        configs=get_configs(M, N, K, with_roller),
-        keys=[
-            "block_M",
-            "block_N",
-            "block_K",
-            "num_stages",
-            "thread_num",
-            "enable_rasteration",
-        ],
-        warmup=3,
-        rep=20,
-    )
-    @jit(
-        out_idx=[-1],
-        supply_type=tl.TensorSupplyType.Integer,
-        ref_prog=ref_program,
-        skip_check=False,
-        target="auto",
-    )
     def kernel(
         block_M=None,
         block_N=None,
@@ -138,7 +118,15 @@ def get_best_config(M, N, K, with_roller=False):
         return main
 
-    return kernel()
+    autotuner = AutoTuner.from_kernel(
+        kernel=kernel, configs=get_configs(M, N, K, with_roller)).set_compile_args(
+            out_idx=[-1],
+            supply_type=tl.TensorSupplyType.Integer,
+            ref_prog=ref_program,
+            skip_check=False,
+            target="auto",
+        )
+    return autotuner.run(warmup=3, rep=20)
 
 
 def matmul(M,
@@ -200,19 +188,16 @@ if __name__ == "__main__":
     M, N, K = args.m, args.n, args.k
     a = torch.randn(M, K).cuda().half()
     b = torch.randn(N, K).cuda().half()
-    c = torch.zeros(M, N).cuda().half()
     configs = []
     use_autotune = args.use_autotune
     with_roller = args.with_roller
 
     if use_autotune:
-        best_latency, best_config, ref_latency = get_best_config(M, N, K, with_roller)
-        func = matmul(M, N, K, *best_config)
+        result = get_best_config(M, N, K, with_roller)
+        print(f"best latency {result.latency}")
+        kernel = result.kernel
    else:
-        func = matmul(M, N, K, 128, 128, 32, 3, 128, True)
-    # print(func)
-    kernel = tl.compile(func, out_idx=-1)
+        kernel = tl.compile(matmul(M, N, K, 128, 128, 32, 3, 128, True), out_idx=-1)
     out_c = kernel(a, b)
-    ref_c = a @ b.T + c
+    ref_c = ref_program(a, b)
     torch.testing.assert_close(out_c, ref_c, rtol=1e-2, atol=1e-2)
-    # print(kernel.get_kernel_source())
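In short, the example no longer wraps get_best_config in the @autotune/@jit decorators; it builds the tuner explicitly. A minimal sketch of the new call chain, using only names defined in the diff above (kernel, get_configs, ref_program) with argument values copied from it, so it is a sketch rather than a self-contained script:

import tilelang as tl
from tilelang.autotuner import AutoTuner

# Sketch of the flow introduced above; `kernel`, `get_configs` and `ref_program`
# are the functions defined in example_gemm.py, as are M, N, K and with_roller.
autotuner = AutoTuner.from_kernel(
    kernel=kernel,
    configs=get_configs(M, N, K, with_roller),
).set_compile_args(
    out_idx=[-1],                             # the last buffer (C) is treated as the output
    supply_type=tl.TensorSupplyType.Integer,  # integer-filled inputs for the correctness check
    ref_prog=ref_program,                     # reference program used when skip_check=False
    skip_check=False,
    target="auto",
)
result = autotuner.run(warmup=3, rep=20)      # returns an AutotuneResult (see tilelang/autotuner/__init__.py)
print(f"best latency {result.latency}")
compiled_kernel = result.kernel               # best compiled kernel, directly callable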
testing/python/autotune/test_tilelang_autotune.py

@@ -4,7 +4,7 @@ import logging
 import tilelang as tl
 import tilelang.testing
 import tilelang.language as T
-from tilelang.autotuner import autotune, jit
+from tilelang.autotuner import AutoTuner
 
 # Configure logger
 logger = logging.getLogger(__name__)
@@ -151,26 +151,6 @@ def matmul(M, N, K, with_roller):
     # - The "tvm" profiler backend
     # - HIP as the compilation target (modify as needed for your hardware)
-    @autotune(
-        configs=get_configs(M, N, K, with_roller),
-        keys=[
-            "block_M",
-            "block_N",
-            "block_K",
-            "num_stages",
-            "thread_num",
-            "enable_rasteration",
-        ],
-        warmup=3,
-        rep=5,
-    )
-    @jit(
-        out_idx=[2],
-        supply_type=tl.TensorSupplyType.Integer,
-        ref_prog=ref_program,
-        skip_check=True,
-        target="auto",
-    )
     def kernel(
         block_M=None,
         block_N=None,
@@ -268,14 +248,24 @@ def matmul(M, N, K, with_roller):
         return main
 
-    return kernel()
+    autotuner = AutoTuner.from_kernel(
+        kernel=kernel, configs=get_configs(M, N, K, with_roller)).set_compile_args(
+            out_idx=[-1],
+            supply_type=tl.TensorSupplyType.Integer,
+            ref_prog=ref_program,
+            skip_check=False,
+            target="auto",
+        )
+    return autotuner.run(warmup=3, rep=20)
 
 
 def test_autotune_get_configs():
+    get_configs(8192, 8192, 8192, with_roller=True)
     get_configs(8192, 8192, 8192, with_roller=False)
 
 
 def test_autotune_matmul():
+    matmul(8192, 8192, 8192, with_roller=True)
     matmul(8192, 8192, 8192, with_roller=False)
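Because matmul() now ends in autotuner.run(...), it returns an AutotuneResult object rather than the old (best_latency, best_config, ref_latency) tuple. A small, illustrative sketch of consuming that return value; the field names come from the AutotuneResult dataclass added in tilelang/autotuner/__init__.py below, while the print statements themselves are not part of this commit:

# Illustrative only: consuming the result of the refactored matmul() above.
result = matmul(8192, 8192, 8192, with_roller=False)
print(result.latency)      # best measured latency across the tried configs
print(result.config)       # the config dict that produced it
print(result.ref_latency)  # latency of ref_program, when it was benchmarked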
testing/python/autotune/test_tilelang_autotune_decorator.py  (new file, 0 → 100644; all 269 lines added)

import itertools
import logging
import tilelang as tl
import tilelang.testing
import tilelang.language as T
from tilelang.autotuner import jit, autotune

# Configure logger
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)


def ref_program(A, B):
    """
    A reference matrix multiplication program, used to compare performance.

    Parameters
    ----------
    A : numpy.ndarray
        The matrix with shape (M, K).
    B : numpy.ndarray
        The matrix with shape (N, K).

    Returns
    -------
    np.ndarray
        The result of A @ B.T, shape (M, N).
    """
    return A @ B.T


def get_configs(M, N, K, with_roller=False):
    """
    Generate a list of configuration dictionaries that will be used for tuning.

    Parameters
    ----------
    with_roller : bool
        Whether to enable bitblas roller to deduce search spaces

    Returns
    -------
    list of dict
        Each configuration dict includes various block sizes, pipeline stages,
        thread numbers, and other parameters to explore during autotuning.
    """
    if with_roller:
        from tilelang.carver.template import MatmulTemplate
        from tilelang.carver.arch import CUDA
        from tilelang.carver.roller.rasterization import NoRasterization
        arch = CUDA("cuda")
        topk = 20

        # Simple TIR Compute Expression
        carve_template = MatmulTemplate(
            M=M,
            N=N,
            K=K,
            in_dtype="float16",
            out_dtype="float16",
            accum_dtype="float16",
        ).with_arch(arch)

        func = carve_template.equivalent_function()
        assert func is not None, "Function is None"
        roller_hints = carve_template.recommend_hints(topk=topk)
        if roller_hints is None:
            raise ValueError("No Roller Hints Found for TensorCore Scheduling")
        configs = []
        for hint in roller_hints:
            config = {}
            block_m, block_n = hint.block
            warp_m, warp_n = hint.warp
            config["block_M"] = block_m
            config["block_N"] = block_n
            config["block_K"] = hint.rstep[0]
            config["num_stages"] = 0
            config["thread_num"] = (block_m * block_n) // (warp_m * warp_n) * 32
            config["enable_rasteration"] = hint.rasterization_plan is not NoRasterization
            configs.append(config)
        for config in configs:
            print(config)
    else:
        block_M = [64]
        block_N = [64]
        block_K = [32]
        num_stages = [0, 1]
        thread_num = [128]
        enable_rasterization = [False]
        _configs = list(
            itertools.product(
                block_M,
                block_N,
                block_K,
                num_stages,
                thread_num,
                enable_rasterization,
            ))

        configs = [{
            "block_M": c[0],
            "block_N": c[1],
            "block_K": c[2],
            "num_stages": c[3],
            "thread_num": c[4],
            "enable_rasteration": c[5],  # keep param name for backward-compat
        } for c in _configs]
    return configs


def matmul(M, N, K, with_roller):
    """
    Create an autotuned matrix multiplication kernel for matrices of shape:
      - A: (M, K)
      - B: (N, K)
      - C: (M, N)

    Parameters
    ----------
    M : int
        The dimension M of the matrix multiplication.
    N : int
        The dimension N of the matrix multiplication.
    K : int
        The dimension K of the matrix multiplication.

    Returns
    -------
    (best_latency, best_config, ref_latency)
        best_latency : float
            The best latency found among the tuned configurations.
        best_config : dict
            The parameter configuration that yielded best_latency.
        ref_latency : float
            The baseline latency of the reference program (for computing speedup).
    """

    @autotune(
        configs=get_configs(M, N, K, with_roller),
        warmup=3,
        rep=20,
    )
    @jit(
        out_idx=[-1],
        supply_type=tl.TensorSupplyType.Integer,
        ref_prog=ref_program,
        skip_check=False,
        target="auto",
    )
    def kernel(
        block_M=None,
        block_N=None,
        block_K=None,
        num_stages=None,
        thread_num=None,
        enable_rasteration=None,
    ):
        """
        The actual kernel to compute C = A @ B^T.

        Parameters
        ----------
        block_M : int
            Block size in M dimension.
        block_N : int
            Block size in N dimension.
        block_K : int
            Block size in K dimension.
        num_stages : int
            Number of pipelined stages (for asynchronous load).
        thread_num : int
            Number of threads to use per block.
        enable_rasteration : bool
            Whether to enable rasterization (swizzling) optimization.
        k_pack : int
            K dimension packing factor to improve memory coalescing.

        Returns
        -------
        Function
            A TVM Tensor Language function (T.prim_func) that computes matmul.
        """
        # Use half-precision for input data to reduce memory bandwidth,
        # accumulate in float for better numerical accuracy
        dtype = "float16"
        accum_dtype = "float"

        @T.prim_func
        def main(
                A: T.Buffer((M, K), dtype),
                B: T.Buffer((N, K), dtype),
                C: T.Buffer((M, N), dtype),
        ):
            """
            The compiled TVM function for block-level matrix multiplication.

            - We divide the entire (M, N) domain into blocks of shape
              (block_M, block_N).
            - Each block has its own allocated shared memory for sub-blocks
              of A and B.
            - The partial results go into C_local, and then we copy them back
              to global memory C.
            """
            # Bind x-dimension to block index in N,
            # y-dimension to block index in M.
            with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):

                # Allocate shared memory for A sub-block of shape (block_M, block_K)
                A_shared = T.alloc_shared((block_M, block_K), dtype)
                # Allocate shared memory for B sub-block of shape (block_N, block_K)
                B_shared = T.alloc_shared((block_N, block_K), dtype)
                # Allocate a local fragment for intermediate accumulation
                C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

                # Enable (or disable) swizzling optimization
                T.use_swizzle(panel_size=10, enable=enable_rasteration)

                # Clear out the accumulation buffer
                T.clear(C_local)

                # Loop over sub-blocks in K dimension, pipelined by num_stages
                for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                    # Load a sub-block of A from global memory into A_shared
                    T.copy(
                        A[by * block_M, k * block_K],
                        A_shared,
                    )
                    # Load a sub-block of B from global memory into B_shared
                    T.copy(
                        B[bx * block_N, k * block_K],
                        B_shared,
                    )
                    # Perform a partial matrix multiplication:
                    #    C_local += A_shared @ B_shared^T
                    T.gemm(
                        A_shared,
                        B_shared,
                        C_local,
                        transpose_B=True,
                    )
                # Write back the results from C_local to the global memory C
                T.copy(C_local, C[by * block_M, bx * block_N])

        return main

    return kernel()


def test_autotune_get_configs():
    get_configs(8192, 8192, 8192, with_roller=True)
    get_configs(8192, 8192, 8192, with_roller=False)


def test_autotune_matmul():
    matmul(8192, 8192, 8192, with_roller=True)
    matmul(8192, 8192, 8192, with_roller=False)


if __name__ == "__main__":
    tilelang.testing.main()
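This new test pins down the decorator path after the refactor: @autotune is now called with only configs, warmup, and rep (no keys list), stacked on @jit, and calling the decorated kernel() triggers compilation and benchmarking. A cut-down, hypothetical sketch of that pattern; build_matmul is a placeholder for any factory returning a T.prim_func and is not part of this commit:

import tilelang as tl
from tilelang.autotuner import autotune, jit


def ref_program(A, B):
    # Same reference as in the test above: C = A @ B^T.
    return A @ B.T


@autotune(configs=[{"block_M": 64}, {"block_M": 128}], warmup=3, rep=20)
@jit(
    out_idx=[-1],
    supply_type=tl.TensorSupplyType.Integer,
    ref_prog=ref_program,
    skip_check=False,
    target="auto",
)
def tunable(block_M=None):
    return build_matmul(block_M)  # hypothetical helper, not defined in this commit


result = tunable()  # runs the tuning loop over the configs; returns an AutotuneResult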
tilelang/autotuner/__init__.py

@@ -3,14 +3,13 @@
 import tilelang
 from tilelang import tvm as tvm
 import inspect
-from functools import wraps
-from typing import Any, Callable, List, Literal
+from functools import wraps, partial
+from typing import Callable, List, Literal, Any
 from tqdm import tqdm
 import logging
 from dataclasses import dataclass
 import concurrent.futures
 import os
-from functools import partial
 
 logger = logging.getLogger(__name__)
@@ -34,40 +33,65 @@ class JITContext:
     target: Literal['cuda', 'hip']
 
 
-class Autotuner:
-
-    def __init__(
-        self,
-        fn: Callable,
-        configs: Any,
-        keys: List[str],
-        warmup: int = 25,
-        rep: int = 100,
-        timeout: int = 30,
-    ):
+@dataclass(frozen=True)
+class AutotuneResult:
+    latency: float
+    config: dict
+    ref_latency: float
+    libcode: str
+    func: Callable
+    kernel: Callable
+
+
+class AutoTuner:
+
+    def __init__(self, fn: Callable, configs):
         self.fn = fn
         self.configs = configs
-        self.keys = keys
-        self.warmup = warmup
-        self.rep = rep
-        self.timeout = timeout
 
-        # Precompute cached variables
        self.ref_latency_cache = None
        self.jit_input_tensors = None
        self.ref_input_tensors = None
 
-    def jit_compile(self, config_arg) -> JITContext:
-        jit_context = self.fn(*config_arg)
-        return jit_context
+    @classmethod
+    def from_kernel(cls, kernel: Callable, configs):
+        return cls(kernel, configs)
+
+    def set_compile_args(self,
+                         out_idx: List[int],
+                         supply_type: tilelang.TensorSupplyType = tilelang.TensorSupplyType.Normal,
+                         ref_prog: Callable = None,
+                         rtol: float = 1e-2,
+                         atol: float = 1e-2,
+                         max_mismatched_ratio: float = 0.01,
+                         skip_check: bool = False,
+                         target: Literal['auto', 'cuda', 'hip'] = 'auto'):
+
+        def _compile(*config_arg):
+            kernel = tilelang.compile(self.fn(*config_arg), out_idx=out_idx, target=target)
+            profiler = kernel.get_profiler()
+            jit_context = JITContext(
+                out_idx=out_idx,
+                supply_type=supply_type,
+                ref_prog=ref_prog,
+                rtol=rtol,
+                atol=atol,
+                max_mismatched_ratio=max_mismatched_ratio,
+                skip_check=skip_check,
+                profiler=profiler,
+                target=target)
+            return jit_context
+
+        self.jit_compile = _compile
+        return self
 
-    def run(self, *args: Any, **kwds: Any) -> Any:
+    def run(self, warmup: int = 25, rep: int = 100, timeout: int = 100):
         sig = inspect.signature(self.fn)
-        bound_args = sig.bind(*args, **kwds)
+        keys = list(sig.parameters.keys())
+        bound_args = sig.bind()
         bound_args.apply_defaults()
 
         best_latency = 1e8
         best_config = None
+        best_jit_context = None
 
         def target_fn(jit_context):
             # Unpack the context
@@ -87,49 +111,38 @@ class Autotuner:
                     ref_prog, rtol=rtol, atol=atol, max_mismatched_ratio=max_mismatched_ratio)
 
             latency = profiler.do_bench(
-                profiler.func,
-                n_warmup=self.warmup,
-                n_repeat=self.rep,
-                input_tensors=self.jit_input_tensors)
+                profiler.func, n_warmup=warmup, n_repeat=rep, input_tensors=self.jit_input_tensors)
             if self.ref_latency_cache is None and ref_prog is not None:
                 self.ref_input_tensors = profiler._get_inputs(
                     with_output=False) if self.ref_input_tensors is None else self.ref_input_tensors
                 self.ref_latency_cache = profiler.do_bench(
-                    ref_prog,
-                    n_warmup=self.warmup,
-                    n_repeat=self.rep,
-                    input_tensors=self.ref_input_tensors)
+                    ref_prog, n_warmup=warmup, n_repeat=rep, input_tensors=self.ref_input_tensors)
             return latency, self.ref_latency_cache
 
-        # Parallel compilation
         config_args = []
         for config in self.configs:
             new_args = []
             for name, value in bound_args.arguments.items():
-                if name not in self.keys:
+                if name not in keys:
                     new_args.append(value)
                 else:
                     new_args.append(config[name])
             new_args = tuple(new_args)
             config_args.append(new_args)
 
-        worker = partial(self.jit_compile, **kwds)
-        # 90% utilization
         num_workers = max(1, int(os.cpu_count() * 0.9))
         pool = concurrent.futures.ThreadPoolExecutor(max_workers=num_workers)
 
-        # Submit all compilation jobs
         futures = []
-        future_to_index = {}  # Track which future corresponds to which config
+        future_to_index = {}
         for i, config_arg in enumerate(config_args):
-            future = pool.submit(worker, config_arg)
+            future = pool.submit(self.jit_compile, *config_arg)
             futures.append(future)
             future_to_index[future] = i
 
-        # Process results with error handling
         results_with_configs = []
         for future in tqdm(
                 concurrent.futures.as_completed(futures),
@@ -164,28 +177,34 @@ class Autotuner:
             if latency < best_latency:
                 best_latency = latency
                 best_config = config
+                best_jit_context = jit_context
 
             progress_bar.set_postfix({"best_latency": best_latency})
             tqdm.write(f"Tuned Latency {latency} with config {config} at index {i}")
 
         pool.shutdown()
 
-        return best_latency, best_config, ref_latency
+        return AutotuneResult(
+            latency=best_latency,
+            config=best_config,
+            ref_latency=ref_latency,
+            libcode=best_jit_context.profiler.func.lib_code,
+            func=self.fn(*best_config),
+            kernel=best_jit_context.profiler.func)
 
-    def __call__(self, *args: Any, **kwds: Any) -> Any:
-        return self.run(*args, **kwds)
+    def __call__(self) -> Any:
+        return self.run()
 
 
-def autotune(configs: Any,
-             keys: List[str],
-             warmup: int = 25,
-             rep: int = 100,
-             timeout: int = 100) -> Callable:
+def autotune(configs: Any, warmup: int = 25, rep: int = 100, timeout: int = 100) -> Callable:
     """
     Decorator for tilelang program
     """
 
-    def decorator(fn: Callable) -> Autotuner:
-        return Autotuner(fn, configs=configs, keys=keys, warmup=warmup, rep=rep, timeout=timeout)
+    def decorator(fn: Callable) -> AutoTuner:
+        autotuner = AutoTuner(fn, configs=configs)
+        autotuner.jit_compile = fn
+        autotuner.run = partial(autotuner.run, warmup, rep, timeout)
+        return autotuner
 
     return decorator
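Two behavioural changes stand out in this file: run() now returns a frozen AutotuneResult dataclass instead of a tuple, and the tunable parameter names are derived from the kernel function's signature via inspect.signature rather than passed in as an explicit keys list. A standard-library sketch of that key discovery, mirroring the two lines at the top of the new run():

import inspect


def kernel(block_M=None, block_N=None, block_K=None,
           num_stages=None, thread_num=None, enable_rasteration=None):
    ...


# Same discovery the new AutoTuner.run() performs before building config_args:
sig = inspect.signature(kernel)
keys = list(sig.parameters.keys())
print(keys)
# ['block_M', 'block_N', 'block_K', 'num_stages', 'thread_num', 'enable_rasteration']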