OpenDAS / tilelang

Commit 541e1685
Authored by yyttt6 on Mar 25, 2025; committed by LeiWang1999 on Mar 25, 2025
Parent: 8ad53855

[Refactor] Enhance Autotune (#266)

* add autotune to example_gemm.py
* format init.py

Showing 4 changed files with 369 additions and 106 deletions (+369 -106).
examples/gemm/example_gemm.py                                 +15  -30
testing/python/autotune/test_tilelang_autotune.py             +12  -22
testing/python/autotune/test_tilelang_autotune_decorator.py  +269   -0
tilelang/autotuner/__init__.py                                +73  -54
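
In short, this commit replaces the decorator-driven tuning entry points (@autotune + @jit wrapping the kernel factory) with an explicit AutoTuner builder, and the tuner now returns an AutotuneResult object instead of a (best_latency, best_config, ref_latency) tuple. A condensed sketch of the new call pattern, assembled from the diffs below (kernel, get_configs and ref_program are the definitions from example_gemm.py; argument values are illustrative):

    from tilelang.autotuner import AutoTuner
    import tilelang as tl

    autotuner = AutoTuner.from_kernel(
        kernel=kernel, configs=get_configs(M, N, K, with_roller)).set_compile_args(
            out_idx=[-1],
            supply_type=tl.TensorSupplyType.Integer,
            ref_prog=ref_program,
            skip_check=False,
            target="auto",
        )
    result = autotuner.run(warmup=3, rep=20)
    # result is an AutotuneResult carrying latency, config, ref_latency, libcode, func and kernel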
examples/gemm/example_gemm.py

@@ -3,7 +3,7 @@ import torch
 import itertools
 import tilelang as tl
 import tilelang.language as T
-from tilelang.autotuner import autotune, jit
+from tilelang.autotuner import AutoTuner
 from tilelang.carver.template import MatmulTemplate
 from tilelang.carver.arch import CUDA
 from tilelang.carver.roller.rasterization import NoRasterization
@@ -79,26 +79,6 @@ def get_configs(M, N, K, with_roller=False):
 def get_best_config(M, N, K, with_roller=False):
 
-    @autotune(
-        configs=get_configs(M, N, K, with_roller),
-        keys=[
-            "block_M",
-            "block_N",
-            "block_K",
-            "num_stages",
-            "thread_num",
-            "enable_rasteration",
-        ],
-        warmup=3,
-        rep=20,
-    )
-    @jit(
-        out_idx=[-1],
-        supply_type=tl.TensorSupplyType.Integer,
-        ref_prog=ref_program,
-        skip_check=False,
-        target="auto",
-    )
     def kernel(
         block_M=None,
         block_N=None,
@@ -138,7 +118,15 @@ def get_best_config(M, N, K, with_roller=False):
             return main
 
-    return kernel()
+    autotuner = AutoTuner.from_kernel(
+        kernel=kernel, configs=get_configs(M, N, K, with_roller)).set_compile_args(
+            out_idx=[-1],
+            supply_type=tl.TensorSupplyType.Integer,
+            ref_prog=ref_program,
+            skip_check=False,
+            target="auto",
+        )
+    return autotuner.run(warmup=3, rep=20)
 
 
 def matmul(M,
@@ -200,19 +188,16 @@ if __name__ == "__main__":
     M, N, K = args.m, args.n, args.k
     a = torch.randn(M, K).cuda().half()
     b = torch.randn(N, K).cuda().half()
-    c = torch.zeros(M, N).cuda().half()
-    configs = []
     use_autotune = args.use_autotune
     with_roller = args.with_roller
     if use_autotune:
-        best_latency, best_config, ref_latency = get_best_config(M, N, K, with_roller)
-        func = matmul(M, N, K, *best_config)
+        result = get_best_config(M, N, K, with_roller)
+        print(f"best latency {result.latency}")
+        kernel = result.kernel
     else:
-        func = matmul(M, N, K, 128, 128, 32, 3, 128, True)
+        kernel = tl.compile(matmul(M, N, K, 128, 128, 32, 3, 128, True), out_idx=-1)
 
-    # print(func)
-    kernel = tl.compile(func, out_idx=-1)
     out_c = kernel(a, b)
-    ref_c = a @ b.T + c
+    ref_c = ref_program(a, b)
     torch.testing.assert_close(out_c, ref_c, rtol=1e-2, atol=1e-2)
     # print(kernel.get_kernel_source())
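
The validation path changes as well: the reference result is now computed by ref_program(a, b) rather than a @ b.T + c, so the zero-initialized c buffer is dropped, and the tuned path consumes the returned AutotuneResult directly. A condensed restatement of the new __main__ flow above (the out_idx comment reflects my reading of tilelang's convention that the listed buffer is allocated and returned as the kernel output):

    result = get_best_config(M, N, K, with_roller)   # -> AutotuneResult
    print(f"best latency {result.latency}")
    kernel = result.kernel                           # best kernel, already compiled
    out_c = kernel(a, b)                             # out_idx=[-1]: C is allocated and returned
    torch.testing.assert_close(out_c, ref_program(a, b), rtol=1e-2, atol=1e-2)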
testing/python/autotune/test_tilelang_autotune.py

@@ -4,7 +4,7 @@ import logging
 import tilelang as tl
 import tilelang.testing
 import tilelang.language as T
-from tilelang.autotuner import autotune, jit
+from tilelang.autotuner import AutoTuner
 
 # Configure logger
 logger = logging.getLogger(__name__)
@@ -151,26 +151,6 @@ def matmul(M, N, K, with_roller):
     # - The "tvm" profiler backend
     # - HIP as the compilation target (modify as needed for your hardware)
-    @autotune(
-        configs=get_configs(M, N, K, with_roller),
-        keys=[
-            "block_M",
-            "block_N",
-            "block_K",
-            "num_stages",
-            "thread_num",
-            "enable_rasteration",
-        ],
-        warmup=3,
-        rep=5,
-    )
-    @jit(
-        out_idx=[2],
-        supply_type=tl.TensorSupplyType.Integer,
-        ref_prog=ref_program,
-        skip_check=True,
-        target="auto",
-    )
     def kernel(
         block_M=None,
         block_N=None,
@@ -268,14 +248,24 @@ def matmul(M, N, K, with_roller):
             return main
 
-    return kernel()
+    autotuner = AutoTuner.from_kernel(
+        kernel=kernel, configs=get_configs(M, N, K, with_roller)).set_compile_args(
+            out_idx=[-1],
+            supply_type=tl.TensorSupplyType.Integer,
+            ref_prog=ref_program,
+            skip_check=False,
+            target="auto",
+        )
+    return autotuner.run(warmup=3, rep=20)
 
 
 def test_autotune_get_configs():
     get_configs(8192, 8192, 8192, with_roller=True)
     get_configs(8192, 8192, 8192, with_roller=False)
 
 
 def test_autotune_matmul():
     matmul(8192, 8192, 8192, with_roller=True)
     matmul(8192, 8192, 8192, with_roller=False)
testing/python/autotune/test_tilelang_autotune_decorator.py  (new file, mode 100644)

import itertools
import logging

import tilelang as tl
import tilelang.testing
import tilelang.language as T
from tilelang.autotuner import jit, autotune

# Configure logger
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)


def ref_program(A, B):
    """
    A reference matrix multiplication program, used to compare performance.

    Parameters
    ----------
    A : numpy.ndarray
        The matrix with shape (M, K).
    B : numpy.ndarray
        The matrix with shape (N, K).

    Returns
    -------
    np.ndarray
        The result of A @ B.T, shape (M, N).
    """
    return A @ B.T


def get_configs(M, N, K, with_roller=False):
    """
    Generate a list of configuration dictionaries that will be used for tuning.

    Parameters
    ----------
    with_roller : bool
        Whether to enable bitblas roller to deduce search spaces

    Returns
    -------
    list of dict
        Each configuration dict includes various block sizes, pipeline stages,
        thread numbers, and other parameters to explore during autotuning.
    """
    if with_roller:
        from tilelang.carver.template import MatmulTemplate
        from tilelang.carver.arch import CUDA
        from tilelang.carver.roller.rasterization import NoRasterization
        arch = CUDA("cuda")
        topk = 20

        # Simple TIR Compute Expression
        carve_template = MatmulTemplate(
            M=M,
            N=N,
            K=K,
            in_dtype="float16",
            out_dtype="float16",
            accum_dtype="float16",
        ).with_arch(arch)

        func = carve_template.equivalent_function()
        assert func is not None, "Function is None"

        roller_hints = carve_template.recommend_hints(topk=topk)
        if roller_hints is None:
            raise ValueError("No Roller Hints Found for TensorCore Scheduling")

        configs = []
        for hint in roller_hints:
            config = {}
            block_m, block_n = hint.block
            warp_m, warp_n = hint.warp
            config["block_M"] = block_m
            config["block_N"] = block_n
            config["block_K"] = hint.rstep[0]
            config["num_stages"] = 0
            config["thread_num"] = (block_m * block_n) // (warp_m * warp_n) * 32
            config["enable_rasteration"] = hint.rasterization_plan is not NoRasterization
            configs.append(config)
        for config in configs:
            print(config)
    else:
        block_M = [64]
        block_N = [64]
        block_K = [32]
        num_stages = [0, 1]
        thread_num = [128]
        enable_rasterization = [False]
        _configs = list(
            itertools.product(
                block_M,
                block_N,
                block_K,
                num_stages,
                thread_num,
                enable_rasterization,
            ))

        configs = [{
            "block_M": c[0],
            "block_N": c[1],
            "block_K": c[2],
            "num_stages": c[3],
            "thread_num": c[4],
            "enable_rasteration": c[5],  # keep param name for backward-compat
        } for c in _configs]
    return configs


def matmul(M, N, K, with_roller):
    """
    Create an autotuned matrix multiplication kernel for matrices of shape:
      - A: (M, K)
      - B: (N, K)
      - C: (M, N)

    Parameters
    ----------
    M : int
        The dimension M of the matrix multiplication.
    N : int
        The dimension N of the matrix multiplication.
    K : int
        The dimension K of the matrix multiplication.

    Returns
    -------
    (best_latency, best_config, ref_latency)
        best_latency : float
            The best latency found among the tuned configurations.
        best_config : dict
            The parameter configuration that yielded best_latency.
        ref_latency : float
            The baseline latency of the reference program (for computing speedup).
    """

    @autotune(
        configs=get_configs(M, N, K, with_roller),
        warmup=3,
        rep=20,
    )
    @jit(
        out_idx=[-1],
        supply_type=tl.TensorSupplyType.Integer,
        ref_prog=ref_program,
        skip_check=False,
        target="auto",
    )
    def kernel(
        block_M=None,
        block_N=None,
        block_K=None,
        num_stages=None,
        thread_num=None,
        enable_rasteration=None,
    ):
        """
        The actual kernel to compute C = A @ B^T.

        Parameters
        ----------
        block_M : int
            Block size in M dimension.
        block_N : int
            Block size in N dimension.
        block_K : int
            Block size in K dimension.
        num_stages : int
            Number of pipelined stages (for asynchronous load).
        thread_num : int
            Number of threads to use per block.
        enable_rasteration : bool
            Whether to enable rasterization (swizzling) optimization.
        k_pack : int
            K dimension packing factor to improve memory coalescing.

        Returns
        -------
        Function
            A TVM Tensor Language function (T.prim_func) that computes matmul.
        """
        # Use half-precision for input data to reduce memory bandwidth,
        # accumulate in float for better numerical accuracy
        dtype = "float16"
        accum_dtype = "float"

        @T.prim_func
        def main(
                A: T.Buffer((M, K), dtype),
                B: T.Buffer((N, K), dtype),
                C: T.Buffer((M, N), dtype),
        ):
            """
            The compiled TVM function for block-level matrix multiplication.

            - We divide the entire (M, N) domain into blocks of shape
              (block_M, block_N).
            - Each block has its own allocated shared memory for sub-blocks
              of A and B.
            - The partial results go into C_local, and then we copy them back
              to global memory C.
            """
            # Bind x-dimension to block index in N,
            # y-dimension to block index in M.
            with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):

                # Allocate shared memory for A sub-block of shape (block_M, block_K)
                A_shared = T.alloc_shared((block_M, block_K), dtype)
                # Allocate shared memory for B sub-block of shape (block_N, block_K)
                B_shared = T.alloc_shared((block_N, block_K), dtype)
                # Allocate a local fragment for intermediate accumulation
                C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

                # Enable (or disable) swizzling optimization
                T.use_swizzle(panel_size=10, enable=enable_rasteration)

                # Clear out the accumulation buffer
                T.clear(C_local)

                # Loop over sub-blocks in K dimension, pipelined by num_stages
                for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                    # Load a sub-block of A from global memory into A_shared
                    T.copy(
                        A[by * block_M, k * block_K],
                        A_shared,
                    )
                    # Load a sub-block of B from global memory into B_shared
                    T.copy(
                        B[bx * block_N, k * block_K],
                        B_shared,
                    )
                    # Perform a partial matrix multiplication:
                    #   C_local += A_shared @ B_shared^T
                    T.gemm(
                        A_shared,
                        B_shared,
                        C_local,
                        transpose_B=True,
                    )
                # Write back the results from C_local to the global memory C
                T.copy(C_local, C[by * block_M, bx * block_N])

        return main

    return kernel()


def test_autotune_get_configs():
    get_configs(8192, 8192, 8192, with_roller=True)
    get_configs(8192, 8192, 8192, with_roller=False)


def test_autotune_matmul():
    matmul(8192, 8192, 8192, with_roller=True)
    matmul(8192, 8192, 8192, with_roller=False)


if __name__ == "__main__":
    tilelang.testing.main()
tilelang/autotuner/__init__.py

@@ -3,14 +3,13 @@
 import tilelang
 from tilelang import tvm as tvm
 import inspect
-from functools import wraps
-from typing import Any, Callable, List, Literal
+from functools import wraps, partial
+from typing import Callable, List, Literal, Any
 from tqdm import tqdm
 import logging
 from dataclasses import dataclass
 import concurrent.futures
 import os
-from functools import partial
 
 logger = logging.getLogger(__name__)
@@ -34,40 +33,65 @@ class JITContext:
     target: Literal['cuda', 'hip']
 
 
-class Autotuner:
+@dataclass(frozen=True)
+class AutotuneResult:
+    latency: float
+    config: dict
+    ref_latency: float
+    libcode: str
+    func: Callable
+    kernel: Callable
+
+
+class AutoTuner:
 
-    def __init__(self,
-                 fn: Callable,
-                 configs: Any,
-                 keys: List[str],
-                 warmup: int = 25,
-                 rep: int = 100,
-                 timeout: int = 30):
+    def __init__(self, fn: Callable, configs):
         self.fn = fn
         self.configs = configs
-        self.keys = keys
-        self.warmup = warmup
-        self.rep = rep
-        self.timeout = timeout
 
         # Precompute cached variables
         self.ref_latency_cache = None
         self.jit_input_tensors = None
         self.ref_input_tensors = None
 
     def jit_compile(self, config_arg) -> JITContext:
         jit_context = self.fn(*config_arg)
         return jit_context
 
+    @classmethod
+    def from_kernel(cls, kernel: Callable, configs):
+        return cls(kernel, configs)
+
+    def set_compile_args(self,
+                         out_idx: List[int],
+                         supply_type: tilelang.TensorSupplyType = tilelang.TensorSupplyType.Normal,
+                         ref_prog: Callable = None,
+                         rtol: float = 1e-2,
+                         atol: float = 1e-2,
+                         max_mismatched_ratio: float = 0.01,
+                         skip_check: bool = False,
+                         target: Literal['auto', 'cuda', 'hip'] = 'auto'):
+
+        def _compile(*config_arg):
+            kernel = tilelang.compile(self.fn(*config_arg), out_idx=out_idx, target=target)
+            profiler = kernel.get_profiler()
+            jit_context = JITContext(
+                out_idx=out_idx,
+                supply_type=supply_type,
+                ref_prog=ref_prog,
+                rtol=rtol,
+                atol=atol,
+                max_mismatched_ratio=max_mismatched_ratio,
+                skip_check=skip_check,
+                profiler=profiler,
+                target=target)
+            return jit_context
+
+        self.jit_compile = _compile
+        return self
 
-    def run(self, *args: Any, **kwds: Any) -> Any:
+    def run(self, warmup: int = 25, rep: int = 100, timeout: int = 100):
         sig = inspect.signature(self.fn)
-        bound_args = sig.bind(*args, **kwds)
+        keys = list(sig.parameters.keys())
+        bound_args = sig.bind()
         bound_args.apply_defaults()
 
         best_latency = 1e8
         best_config = None
         best_jit_context = None
 
         def target_fn(jit_context):
             # Unpack the context
@@ -87,49 +111,38 @@ class Autotuner:
                     ref_prog,
                     rtol=rtol,
                     atol=atol,
                     max_mismatched_ratio=max_mismatched_ratio)
 
             latency = profiler.do_bench(
-                profiler.func,
-                n_warmup=self.warmup,
-                n_repeat=self.rep,
-                input_tensors=self.jit_input_tensors)
+                profiler.func,
+                n_warmup=warmup,
+                n_repeat=rep,
+                input_tensors=self.jit_input_tensors)
             if self.ref_latency_cache is None and ref_prog is not None:
                 self.ref_input_tensors = profiler._get_inputs(
                     with_output=False) if self.ref_input_tensors is None else self.ref_input_tensors
                 self.ref_latency_cache = profiler.do_bench(
-                    ref_prog,
-                    n_warmup=self.warmup,
-                    n_repeat=self.rep,
-                    input_tensors=self.ref_input_tensors)
+                    ref_prog,
+                    n_warmup=warmup,
+                    n_repeat=rep,
+                    input_tensors=self.ref_input_tensors)
 
             return latency, self.ref_latency_cache
 
         # Parallel compilation
         config_args = []
         for config in self.configs:
             new_args = []
             for name, value in bound_args.arguments.items():
-                if name not in self.keys:
+                if name not in keys:
                     new_args.append(value)
                 else:
                     new_args.append(config[name])
             new_args = tuple(new_args)
             config_args.append(new_args)
 
-        worker = partial(self.jit_compile, **kwds)
         # 90% utilization
         num_workers = max(1, int(os.cpu_count() * 0.9))
         pool = concurrent.futures.ThreadPoolExecutor(max_workers=num_workers)
 
         # Submit all compilation jobs
         futures = []
-        future_to_index = {}
+        # Track which future corresponds to which config
+        future_to_index = {}
         for i, config_arg in enumerate(config_args):
-            future = pool.submit(worker, config_arg)
+            future = pool.submit(
+                self.jit_compile,
+                *config_arg,
+            )
             futures.append(future)
             future_to_index[future] = i
 
         # Process results with error handling
         results_with_configs = []
         for future in tqdm(concurrent.futures.as_completed(futures),
@@ -164,28 +177,34 @@ class Autotuner:
             if latency < best_latency:
                 best_latency = latency
                 best_config = config
                 best_jit_context = jit_context
             progress_bar.set_postfix({"best_latency": best_latency})
             tqdm.write(f"Tuned Latency {latency} with config {config} at index {i}")
 
         pool.shutdown()
-        return best_latency, best_config, ref_latency
+        return AutotuneResult(
+            latency=best_latency,
+            config=best_config,
+            ref_latency=ref_latency,
+            libcode=best_jit_context.profiler.func.lib_code,
+            func=self.fn(*best_config),
+            kernel=best_jit_context.profiler.func)
 
-    def __call__(self, *args: Any, **kwds: Any) -> Any:
-        return self.run(*args, **kwds)
+    def __call__(self) -> Any:
+        return self.run()
 
 
-def autotune(configs: Any, keys: List[str], warmup: int = 25, rep: int = 100, timeout: int = 100) -> Callable:
+def autotune(configs: Any, warmup: int = 25, rep: int = 100, timeout: int = 100) -> Callable:
     """
     Decorator for tilelang program
     """
 
-    def decorator(fn: Callable) -> Autotuner:
-        return Autotuner(fn, configs=configs, keys=keys, warmup=warmup, rep=rep, timeout=timeout)
+    def decorator(fn: Callable) -> AutoTuner:
+        autotuner = AutoTuner(fn, configs=configs)
+        autotuner.jit_compile = fn
+        autotuner.run = partial(autotuner.run, warmup, rep, timeout)
+        return autotuner
 
     return decorator
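
For completeness, the decorator path after this change (mirroring the new test_tilelang_autotune_decorator.py above): @autotune no longer takes keys, warmup/rep/timeout are forwarded to run() via functools.partial, and each config's values are bound to the kernel's keyword parameters by name. A minimal sketch, assuming the surrounding definitions (get_configs, ref_program) from that test file:

    from tilelang.autotuner import autotune, jit
    import tilelang as tl

    @autotune(configs=get_configs(M, N, K, with_roller), warmup=3, rep=20)
    @jit(out_idx=[-1], supply_type=tl.TensorSupplyType.Integer, ref_prog=ref_program, target="auto")
    def kernel(block_M=None, block_N=None, block_K=None,
               num_stages=None, thread_num=None, enable_rasteration=None):
        ...  # build and return the T.prim_func for this config

    result = kernel()  # AutoTuner.__call__ -> run(); returns an AutotuneResult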