Unverified Commit b4483090 authored by Lei Wang, committed by GitHub

[Autotune][Conv] optimize convolution examples to use autotune (#866)

parent 9cbbbbc6
@@ -3,11 +3,6 @@ import argparse
 import itertools
 import tilelang
 import tilelang.language as T
-from tilelang.autotuner import AutoTuner
-from tilelang.carver.template import ConvTemplate
-from tilelang.carver.arch import CUDA
-from tilelang.carver.arch import CDNA
-from tilelang.carver.roller.rasterization import NoRasterization


 def check_hopper():
@@ -30,149 +25,36 @@ def ref_program(stride, padding, dilation):
     return main


-def get_configs(N, C, H, W, F, K, S, D, P, with_roller=False, topk=15):
-    if with_roller:
-        arch = CDNA("hip") if torch.version.hip is not None else CUDA("cuda")
-        carve_template = ConvTemplate(
-            N=N,
-            C=C,
-            H=H,
-            W=W,
-            F=F,
-            K=K,
-            S=S,
-            D=D,
-            P=P,
-            in_dtype="float16",
-            out_dtype="float16",
-            accum_dtype="float",
-        ).with_arch(arch)
-
-        func = carve_template.equivalent_function()
-        assert func is not None, "Function is None"
-        roller_hints = carve_template.recommend_hints(topk=topk)
-        if roller_hints is None:
-            raise ValueError("No Roller Hints Found for TensorCore Scheduling")
-        configs = []
-        for hint in roller_hints:
-            config = {}
-            block_m, block_n = hint.block
-            warp_m, warp_n = hint.warp
-            # block_rows, block_cols represents warp partitioning
-            block_rows, block_cols = block_m // warp_m, block_n // warp_n
-            config["block_M"] = block_m
-            config["block_N"] = block_n
-            config["block_K"] = hint.rstep[0]
-            config["num_stages"] = hint.pipeline_stage if hint.pipeline_stage > 1 else 0
-            config["thread_num"] = block_rows * block_cols * 32
-            config["enable_rasteration"] = hint.rasterization_plan is not NoRasterization
-            configs.append(config)
-    else:
-        block_M = [64, 128, 256]
-        block_N = [64, 128, 256]
-        block_K = [32, 64]
-        num_stages = [0, 1, 2, 3]
-        thread_num = [128, 256]
-        enable_rasterization = [True, False]
-        _configs = list(
-            itertools.product(
-                block_M,
-                block_N,
-                block_K,
-                num_stages,
-                thread_num,
-                enable_rasterization,
-            ))
-
-        configs = [
-            {
-                "block_M": c[0],
-                "block_N": c[1],
-                "block_K": c[2],
-                "num_stages": c[3],
-                "thread_num": c[4],
-                "enable_rasteration": c[5],  # keep param name for backward-compat
-            } for c in _configs
-        ]
+def get_configs():
+    block_M = [64, 128, 256]
+    block_N = [64, 128, 256]
+    block_K = [32, 64]
+    num_stages = [0, 1, 2, 3]
+    thread_num = [128, 256]
+    enable_rasterization = [True, False]
+    _configs = list(
+        itertools.product(
+            block_M,
+            block_N,
+            block_K,
+            num_stages,
+            thread_num,
+            enable_rasterization,
+        ))
+
+    configs = [
+        {
+            "block_M": c[0],
+            "block_N": c[1],
+            "block_K": c[2],
+            "num_stages": c[3],
+            "thread_num": c[4],
+            "enable_rasteration": c[5],  # keep param name for backward-compat
+        } for c in _configs
+    ]
     return configs

-def get_best_config(N, C, H, W, F, K, S, D, P, ref_prog, with_roller=False):
-
-    @tilelang.jit(out_idx=[2])
-    def kernel(
-        block_M=None,
-        block_N=None,
-        block_K=None,
-        num_stages=None,
-        thread_num=None,
-        enable_rasteration=None,
-    ):
-        dtype = "float16"
-        accum_dtype = "float"
-        KH, KW = K, K
-        OH = (H + 2 * P - D * (K - 1) - 1) // S + 1
-        OW = (W + 2 * P - D * (K - 1) - 1) // S + 1
-        is_hopper = check_hopper()
-
-        @T.prim_func
-        def main(
-                data: T.Tensor((N, H, W, C), dtype),
-                kernel: T.Tensor((KH, KW, C, F), dtype),
-                out: T.Tensor((N, OH, OW, F), dtype),
-        ):
-            with T.Kernel(
-                    T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M),
-                    threads=thread_num) as (bx, by):
-                data_shared = T.alloc_shared((block_M, block_K), dtype)
-                kernel_shared = T.alloc_shared((block_K, block_N), dtype)
-                out_local = T.alloc_fragment((block_M, block_N), accum_dtype)
-                out_shared = T.alloc_shared((block_M, block_N), dtype)
-
-                kernel_flat = T.Tensor((KH * KW * C, F), dtype, kernel.data)
-                out_flat = T.Tensor((N * OH * OW, F), dtype, out.data)
-
-                T.annotate_layout({
-                    out_shared: tilelang.layout.make_swizzled_layout(out_shared),
-                    data_shared: tilelang.layout.make_swizzled_layout(data_shared),
-                    kernel_shared: tilelang.layout.make_swizzled_layout(kernel_shared),
-                })
-
-                T.clear(out_local)
-                for k_iter in T.Pipelined(T.ceildiv(KH * KW * C, block_K), num_stages=num_stages):
-                    if is_hopper:
-                        T.c2d_im2col(data, data_shared, by, k_iter, KH, S, D, P)
-                    else:
-                        for i, j in T.Parallel(block_M, block_K):
-                            k = k_iter * block_K + j
-                            m = by * block_M + i
-                            access_h = m % (OH * OW) // OW * S + k // (KW * C) * D - P
-                            access_w = m % OW * S + k // C % KW * D - P
-                            in_bound = ((access_h >= 0) and (access_w >= 0) and (access_h < H) and
-                                        (access_w < W))
-                            data_shared[i, j] = T.if_then_else(
-                                in_bound, data[m // (OH * OW), access_h, access_w, k % C], 0)
-                    T.copy(kernel_flat[k_iter * block_K, bx * block_N], kernel_shared)
-                    T.gemm(data_shared, kernel_shared, out_local)
-
-                T.copy(out_local, out_shared)
-                T.copy(out_shared, out_flat[by * block_M, bx * block_N])
-
-        return main
-
-    autotuner = AutoTuner.from_kernel(
-        kernel=kernel, configs=get_configs(N, C, H, W, F, K, S, D, P,
-                                           with_roller)).set_compile_args(
-                                               out_idx=[2],
-                                               target="auto",
-                                           ).set_profile_args(
-                                               supply_type=tilelang.TensorSupplyType.Integer,
-                                               ref_prog=ref_prog,
-                                               skip_check=False,
-                                           )
-    return autotuner.run(warmup=3, rep=20)
-
-
 def get_heuristic_config() -> dict:
     # Get CUDA device properties
     if not torch.cuda.is_available():
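Note: the new get_configs() above simply enumerates the Cartesian product of the tile parameters, so the autotuner sweeps 3 x 3 x 2 x 4 x 2 x 2 = 288 candidate configurations. A minimal standalone sketch of the resulting list (plain Python, for illustration only; the variable name `space` is not part of the example script):

import itertools

space = {
    "block_M": [64, 128, 256],
    "block_N": [64, 128, 256],
    "block_K": [32, 64],
    "num_stages": [0, 1, 2, 3],
    "thread_num": [128, 256],
    "enable_rasteration": [True, False],  # spelling kept for backward-compat
}
configs = [dict(zip(space, values)) for values in itertools.product(*space.values())]
print(len(configs))  # 288 candidate tile configurations
print(configs[0])    # {'block_M': 64, 'block_N': 64, 'block_K': 32, 'num_stages': 0, ...}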
@@ -210,6 +92,7 @@ def get_heuristic_config() -> dict:
     }


+@tilelang.autotune(configs=get_configs())
 @tilelang.jit(out_idx=[2])
 def convolution(N,
                 C,
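With @tilelang.autotune(configs=get_configs()) stacked on top of @tilelang.jit(out_idx=[2]), the example no longer needs the separate get_best_config() driver: calling the decorated convolution with only the problem shape triggers the sweep, while supplying explicit tile parameters (as the heuristic path in main() below does) compiles a single fixed configuration. A hedged usage sketch; the shape values here are chosen purely for illustration and are not taken from the example defaults:

N, C, H, W, F, K, S, D, P = 128, 128, 64, 64, 256, 3, 1, 1, 1  # example shapes (assumed)

tuned_kernel = convolution(N, C, H, W, F, K, S, D, P)            # sweeps get_configs()
fixed_kernel = convolution(N, C, H, W, F, K, S, D, P,            # fixed tile parameters
                           **get_heuristic_config())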
@@ -252,11 +135,10 @@ def convolution(N,
             kernel_flat = T.Tensor((KH * KW * C, F), dtype, kernel.data)
             out_flat = T.Tensor((N * OH * OW, F), dtype, out.data)

-            T.annotate_layout({
-                out_shared: tilelang.layout.make_swizzled_layout(out_shared),
-                data_shared: tilelang.layout.make_swizzled_layout(data_shared),
-                kernel_shared: tilelang.layout.make_swizzled_layout(kernel_shared),
-            })
+            if is_hopper:
+                T.annotate_layout({
+                    out_shared: tilelang.layout.make_swizzled_layout(out_shared),
+                })

             T.clear(out_local)
             for k_iter in T.Pipelined(T.ceildiv(KH * KW * C, block_K), num_stages=num_stages):
@@ -275,8 +157,11 @@ def convolution(N,
                 T.copy(kernel_flat[k_iter * block_K, bx * block_N], kernel_shared)
                 T.gemm(data_shared, kernel_shared, out_local)

-            T.copy(out_local, out_shared)
-            T.copy(out_shared, out_flat[by * block_M, bx * block_N])
+            if is_hopper:
+                T.copy(out_local, out_shared)
+                T.copy(out_shared, out_flat[by * block_M, bx * block_N])
+            else:
+                T.copy(out_local, out_flat[by * block_M, bx * block_N])

     return main
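For readers following the index arithmetic in the kernel body (the access_h/access_w computation and the kernel_flat/out_flat views), the kernel is an implicit-GEMM convolution: each output pixel m gathers a KH*KW*C patch, and the reduction index k walks that patch in the same order as the flattened weights. A minimal NumPy sketch of the same formulation (illustration only, not part of the example script; conv2d_im2col is a hypothetical helper name):

import numpy as np

def conv2d_im2col(data, weight, S=1, P=0, D=1):
    # data: (N, H, W, C), weight: (KH, KW, C, F), NHWC convolution via im2col + GEMM
    N, H, W, C = data.shape
    KH, KW, _, F = weight.shape
    OH = (H + 2 * P - D * (KH - 1) - 1) // S + 1
    OW = (W + 2 * P - D * (KW - 1) - 1) // S + 1
    cols = np.zeros((N * OH * OW, KH * KW * C), dtype=data.dtype)
    for m in range(N * OH * OW):                      # one row per output pixel
        n, oh, ow = m // (OH * OW), m % (OH * OW) // OW, m % OW
        for k in range(KH * KW * C):                  # reduction axis, same order as kernel_flat
            kh, kw, c = k // (KW * C), k // C % KW, k % C
            h, w = oh * S + kh * D - P, ow * S + kw * D - P
            if 0 <= h < H and 0 <= w < W:             # out-of-bounds taps act as zero padding
                cols[m, k] = data[n, h, w, c]
    out_flat = cols @ weight.reshape(KH * KW * C, F)  # the GEMM the kernel tiles
    return out_flat.reshape(N, OH, OW, F)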
@@ -296,9 +181,7 @@ def main(n: int = 128,
     ref_prog = ref_program(S, P, D)

     if use_autotune:
-        result = get_best_config(N, C, H, W, F, K, S, D, P, ref_prog, with_roller)
-        print(result.config)
-        kernel = result.kernel
+        kernel = convolution(N, C, H, W, F, K, S, D, P)
     else:
         config = get_heuristic_config()
         kernel = convolution(N, C, H, W, F, K, S, D, P, **config)
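A hedged end-to-end check of the compiled kernel, assuming the shapes and dtypes from the prim_func signature (NHWC float16 input, (K, K, C, F) float16 weights), that out_idx=[2] makes the JIT allocate and return the output tensor, and that the closure returned by ref_program(S, P, D) accepts (data, weight); all three assumptions should be verified against the full example script:

import torch

data = torch.randn(N, H, W, C, dtype=torch.float16, device="cuda")
weight = torch.randn(K, K, C, F, dtype=torch.float16, device="cuda")

out = kernel(data, weight)    # output allocated and returned because out_idx=[2]
ref = ref_prog(data, weight)  # reference convolution (assumed call signature)
torch.testing.assert_close(out, ref, rtol=1e-2, atol=1e-2)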