Unverified Commit b4483090 authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[Autotune][Conv] optimize convolution examples to use autotune (#866)

parent 9cbbbbc6
...@@ -3,11 +3,6 @@ import argparse ...@@ -3,11 +3,6 @@ import argparse
import itertools import itertools
import tilelang import tilelang
import tilelang.language as T import tilelang.language as T
from tilelang.autotuner import AutoTuner
from tilelang.carver.template import ConvTemplate
from tilelang.carver.arch import CUDA
from tilelang.carver.arch import CDNA
from tilelang.carver.roller.rasterization import NoRasterization
def check_hopper(): def check_hopper():
...@@ -30,44 +25,7 @@ def ref_program(stride, padding, dilation): ...@@ -30,44 +25,7 @@ def ref_program(stride, padding, dilation):
return main return main
def get_configs(N, C, H, W, F, K, S, D, P, with_roller=False, topk=15): def get_configs():
if with_roller:
arch = CDNA("hip") if torch.version.hip is not None else CUDA("cuda")
carve_template = ConvTemplate(
N=N,
C=C,
H=H,
W=W,
F=F,
K=K,
S=S,
D=D,
P=P,
in_dtype="float16",
out_dtype="float16",
accum_dtype="float",
).with_arch(arch)
func = carve_template.equivalent_function()
assert func is not None, "Function is None"
roller_hints = carve_template.recommend_hints(topk=topk)
if roller_hints is None:
raise ValueError("No Roller Hints Found for TensorCore Scheduling")
configs = []
for hint in roller_hints:
config = {}
block_m, block_n = hint.block
warp_m, warp_n = hint.warp
# block_rows, block_cols represents warp partitioning
block_rows, block_cols = block_m // warp_m, block_n // warp_n
config["block_M"] = block_m
config["block_N"] = block_n
config["block_K"] = hint.rstep[0]
config["num_stages"] = hint.pipeline_stage if hint.pipeline_stage > 1 else 0
config["thread_num"] = block_rows * block_cols * 32
config["enable_rasteration"] = hint.rasterization_plan is not NoRasterization
configs.append(config)
else:
block_M = [64, 128, 256] block_M = [64, 128, 256]
block_N = [64, 128, 256] block_N = [64, 128, 256]
block_K = [32, 64] block_K = [32, 64]
...@@ -97,82 +55,6 @@ def get_configs(N, C, H, W, F, K, S, D, P, with_roller=False, topk=15): ...@@ -97,82 +55,6 @@ def get_configs(N, C, H, W, F, K, S, D, P, with_roller=False, topk=15):
return configs return configs
def get_best_config(N, C, H, W, F, K, S, D, P, ref_prog, with_roller=False):
@tilelang.jit(out_idx=[2])
def kernel(
block_M=None,
block_N=None,
block_K=None,
num_stages=None,
thread_num=None,
enable_rasteration=None,
):
dtype = "float16"
accum_dtype = "float"
KH, KW = K, K
OH = (H + 2 * P - D * (K - 1) - 1) // S + 1
OW = (W + 2 * P - D * (K - 1) - 1) // S + 1
is_hopper = check_hopper()
@T.prim_func
def main(
data: T.Tensor((N, H, W, C), dtype),
kernel: T.Tensor((KH, KW, C, F), dtype),
out: T.Tensor((N, OH, OW, F), dtype),
):
with T.Kernel(
T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M),
threads=thread_num) as (bx, by):
data_shared = T.alloc_shared((block_M, block_K), dtype)
kernel_shared = T.alloc_shared((block_K, block_N), dtype)
out_local = T.alloc_fragment((block_M, block_N), accum_dtype)
out_shared = T.alloc_shared((block_M, block_N), dtype)
kernel_flat = T.Tensor((KH * KW * C, F), dtype, kernel.data)
out_flat = T.Tensor((N * OH * OW, F), dtype, out.data)
T.annotate_layout({
out_shared: tilelang.layout.make_swizzled_layout(out_shared),
data_shared: tilelang.layout.make_swizzled_layout(data_shared),
kernel_shared: tilelang.layout.make_swizzled_layout(kernel_shared),
})
T.clear(out_local)
for k_iter in T.Pipelined(T.ceildiv(KH * KW * C, block_K), num_stages=num_stages):
if is_hopper:
T.c2d_im2col(data, data_shared, by, k_iter, KH, S, D, P)
else:
for i, j in T.Parallel(block_M, block_K):
k = k_iter * block_K + j
m = by * block_M + i
access_h = m % (OH * OW) // OW * S + k // (KW * C) * D - P
access_w = m % OW * S + k // C % KW * D - P
in_bound = ((access_h >= 0) and (access_w >= 0) and (access_h < H) and
(access_w < W))
data_shared[i, j] = T.if_then_else(
in_bound, data[m // (OH * OW), access_h, access_w, k % C], 0)
T.copy(kernel_flat[k_iter * block_K, bx * block_N], kernel_shared)
T.gemm(data_shared, kernel_shared, out_local)
T.copy(out_local, out_shared)
T.copy(out_shared, out_flat[by * block_M, bx * block_N])
return main
autotuner = AutoTuner.from_kernel(
kernel=kernel, configs=get_configs(N, C, H, W, F, K, S, D, P,
with_roller)).set_compile_args(
out_idx=[2],
target="auto",
).set_profile_args(
supply_type=tilelang.TensorSupplyType.Integer,
ref_prog=ref_prog,
skip_check=False,
)
return autotuner.run(warmup=3, rep=20)
def get_heuristic_config() -> dict: def get_heuristic_config() -> dict:
# Get CUDA device properties # Get CUDA device properties
if not torch.cuda.is_available(): if not torch.cuda.is_available():
...@@ -210,6 +92,7 @@ def get_heuristic_config() -> dict: ...@@ -210,6 +92,7 @@ def get_heuristic_config() -> dict:
} }
@tilelang.autotune(configs=get_configs())
@tilelang.jit(out_idx=[2]) @tilelang.jit(out_idx=[2])
def convolution(N, def convolution(N,
C, C,
...@@ -252,10 +135,9 @@ def convolution(N, ...@@ -252,10 +135,9 @@ def convolution(N,
kernel_flat = T.Tensor((KH * KW * C, F), dtype, kernel.data) kernel_flat = T.Tensor((KH * KW * C, F), dtype, kernel.data)
out_flat = T.Tensor((N * OH * OW, F), dtype, out.data) out_flat = T.Tensor((N * OH * OW, F), dtype, out.data)
if is_hopper:
T.annotate_layout({ T.annotate_layout({
out_shared: tilelang.layout.make_swizzled_layout(out_shared), out_shared: tilelang.layout.make_swizzled_layout(out_shared),
data_shared: tilelang.layout.make_swizzled_layout(data_shared),
kernel_shared: tilelang.layout.make_swizzled_layout(kernel_shared),
}) })
T.clear(out_local) T.clear(out_local)
...@@ -275,8 +157,11 @@ def convolution(N, ...@@ -275,8 +157,11 @@ def convolution(N,
T.copy(kernel_flat[k_iter * block_K, bx * block_N], kernel_shared) T.copy(kernel_flat[k_iter * block_K, bx * block_N], kernel_shared)
T.gemm(data_shared, kernel_shared, out_local) T.gemm(data_shared, kernel_shared, out_local)
if is_hopper:
T.copy(out_local, out_shared) T.copy(out_local, out_shared)
T.copy(out_shared, out_flat[by * block_M, bx * block_N]) T.copy(out_shared, out_flat[by * block_M, bx * block_N])
else:
T.copy(out_local, out_flat[by * block_M, bx * block_N])
return main return main
...@@ -296,9 +181,7 @@ def main(n: int = 128, ...@@ -296,9 +181,7 @@ def main(n: int = 128,
ref_prog = ref_program(S, P, D) ref_prog = ref_program(S, P, D)
if use_autotune: if use_autotune:
result = get_best_config(N, C, H, W, F, K, S, D, P, ref_prog, with_roller) kernel = convolution(N, C, H, W, F, K, S, D, P)
print(result.config)
kernel = result.kernel
else: else:
config = get_heuristic_config() config = get_heuristic_config()
kernel = convolution(N, C, H, W, F, K, S, D, P, **config) kernel = convolution(N, C, H, W, F, K, S, D, P, **config)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment