Commit 1873dc00 authored by yyttt6, committed by LeiWang1999

[Example] Add autotune to conv example (#301)



* add autotune to example_gemm.py

* add autotune to conv

* still coding ...

* version 0

* version 0

* version 0

* refactor autotune

* refactor autotune

* add autotune to conv example

* add conv template to carver

* add conv template to carver

* add conv template to carver

* Update num_stages configuration values

---------
Co-authored-by: Lei Wang <34334180+LeiWang1999@users.noreply.github.com>
parent d3bf4fe1
@@ -4,7 +4,9 @@ from tilelang.autotuner import *
import tilelang.language as T
import itertools
import argparse
from functools import partial
from tilelang.carver.template import ConvTemplate
from tilelang.carver.arch import CUDA
from tilelang.carver.roller.rasterization import NoRasterization
def check_hopper():
@@ -16,11 +18,50 @@ def check_hopper():
return False
def get_configs():
def get_configs(N, C, H, W, F, K, S, D, P, with_roller=False):
if with_roller:
arch = CUDA("cuda")
topk = 10
carve_template = ConvTemplate(
N=N,
C=C,
H=H,
W=W,
F=F,
K=K,
S=S,
D=D,
P=P,
in_dtype="float16",
out_dtype="float16",
accum_dtype="float",
).with_arch(arch)
func = carve_template.equivalent_function()
assert func is not None, "Function is None"
roller_hints = carve_template.recommend_hints(topk=topk)
if roller_hints is None:
raise ValueError("No Roller Hints Found for TensorCore Scheduling")
configs = []
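# Translate each roller hint into a tunable kernel config: block tile sizes come from
# hint.block, the reduction tile width from hint.rstep[0], the pipeline depth from
# hint.pipeline_stage, and the thread count is 32 threads per warp tile.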
for hint in roller_hints:
config = {}
block_m, block_n = hint.block
warp_m, warp_n = hint.warp
block_rows, block_cols = block_m // warp_m, block_n // warp_n
config["block_M"] = block_m
config["block_N"] = block_n
config["block_K"] = hint.rstep[0]
config["num_stages"] = hint.pipeline_stage
config["thread_num"] = block_rows * block_cols * 32
config["enable_rasteration"] = hint.rasterization_plan is not NoRasterization
configs.append(config)
for config in configs:
print(config)
else:
block_M = [64, 128, 256]
block_N = [64, 128, 256]
block_K = [32, 64]
num_stages = [1, 2, 3, 4]
num_stages = [0, 1, 2, 3]
threads = [128, 256]
_configs = list(itertools.product(block_M, block_N, block_K, num_stages, threads))
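# Full Cartesian product of the hand-written search space: 3 * 3 * 2 * 4 * 2 = 144 candidate configs.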
@@ -29,12 +70,24 @@ def get_configs():
'block_N': c[1],
'block_K': c[2],
'num_stages': c[3],
'threads': c[4]
'thread_num': c[4]
} for c in _configs]
return configs
def convolution(N, C, H, W, F, K, S, D, P, tune=False):
def ref_program(stride, padding, dilation):
def main(A, B):
A = A.permute(0, 3, 1, 2) # N, H, W, C -> N, C, H, W
B = B.permute(3, 2, 0, 1) # H, W, C, F -> F, C, H, W
C = torch.conv2d(A, B, stride=stride, padding=padding, dilation=dilation)
C = C.permute(0, 2, 3, 1) # N, C, H, W -> N, H, W, C
return C
return main
def get_best_config(N, C, H, W, F, K, S, D, P, with_roller):
KH, KW = K, K
OH = (H + 2 * P - D * (K - 1) - 1) // S + 1
OW = (W + 2 * P - D * (K - 1) - 1) // S + 1
@@ -43,7 +96,13 @@ def convolution(N, C, H, W, F, K, S, D, P, tune=False):
accum_dtype = "float"
is_hopper = check_hopper()
def kernel_func(block_M, block_N, block_K, num_stages, threads):
def kernel(
block_M=None,
block_N=None,
block_K=None,
num_stages=None,
thread_num=None,
):
@T.prim_func
def main(
@@ -53,7 +112,7 @@ def convolution(N, C, H, W, F, K, S, D, P, tune=False):
):
with T.Kernel(
T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M),
threads=threads) as (bx, by):
threads=thread_num) as (bx, by):
data_shared = T.alloc_shared((block_M, block_K), dtype)
kernel_shared = T.alloc_shared((block_K, block_N), dtype)
out_local = T.alloc_fragment((block_M, block_N), accum_dtype)
@@ -90,32 +149,85 @@ def convolution(N, C, H, W, F, K, S, D, P, tune=False):
return main
if tune:
autotuner = AutoTuner.from_kernel(
kernel=kernel, configs=get_configs(N, C, H, W, F, K, S, D, P,
with_roller)).set_compile_args(
out_idx=[2],
supply_type=tilelang.TensorSupplyType.Integer,
ref_prog=ref_program(S, P, D),
skip_check=False,
target="auto",
)
return autotuner.run(warmup=10, rep=10)
@autotune(
configs=get_configs(),
keys=["block_M", "block_N", "block_K", "num_stages", "threads"],
warmup=10,
rep=10)
@jit(out_idx=[2], supply_type=tilelang.TensorSupplyType.Integer, ref_prog=None)
def kernel(block_M=None, block_N=None, block_K=None, num_stages=None, threads=None):
return kernel_func(block_M, block_N, block_K, num_stages, threads)
return kernel()
else:
def convolution(N,
C,
H,
W,
F,
K,
S,
D,
P,
block_M,
block_N,
block_K,
num_stages,
threads,
dtype="float16",
accum_dtype="float"):
KH, KW = K, K
OH = (H + 2 * P - D * (K - 1) - 1) // S + 1
OW = (W + 2 * P - D * (K - 1) - 1) // S + 1
dtype = "float16"
accum_dtype = "float"
is_hopper = check_hopper()
def kernel(block_M, block_N, block_K, num_stages, threads):
return kernel_func(block_M, block_N, block_K, num_stages, threads)
@T.prim_func
def main(
data: T.Tensor((N, H, W, C), dtype),
kernel: T.Tensor((KH, KW, C, F), dtype),
out: T.Tensor((N, OH, OW, F), dtype),
):
with T.Kernel(
T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M),
threads=threads) as (bx, by):
data_shared = T.alloc_shared((block_M, block_K), dtype)
kernel_shared = T.alloc_shared((block_K, block_N), dtype)
out_local = T.alloc_fragment((block_M, block_N), accum_dtype)
out_shared = T.alloc_shared((block_M, block_N), dtype)
return kernel
kernel_flat = T.Tensor((KH * KW * C, F), dtype, kernel.data)
out_flat = T.Tensor((N * OH * OW, F), dtype, out.data)
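# Swizzle the shared-memory layouts so tile copies avoid shared-memory bank conflicts.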
T.annotate_layout({
out_shared: tilelang.layout.make_swizzled_layout(out_shared),
data_shared: tilelang.layout.make_swizzled_layout(data_shared),
kernel_shared: tilelang.layout.make_swizzled_layout(kernel_shared),
})
def ref_program(A, B, stride, padding, dilation):
A = A.permute(0, 3, 1, 2) # N, H, W, C -> N, C, H, W
B = B.permute(3, 2, 0, 1) # H, W, C, F -> F, C, H, W
C = torch.conv2d(A, B, stride=stride, padding=padding, dilation=dilation)
C = C.permute(0, 2, 3, 1) # N, C, H, W -> N, H, W, C
return C
T.clear(out_local)
for k_iter in T.Pipelined(T.ceildiv(KH * KW * C, block_K), num_stages=num_stages):
if is_hopper:
T.c2d_im2col(data, data_shared, by, k_iter, KH, S, D, P)
else:
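# Manual im2col gather: the flattened output row m encodes (n, oh, ow) and the flattened
# reduction index k encodes (kh, kw, c); recover the input coordinates using stride S,
# dilation D, and padding P, and zero-fill out-of-bounds reads.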
for i, j in T.Parallel(block_M, block_K):
k = k_iter * block_K + j
m = by * block_M + i
access_h = m % (OH * OW) // OW * S + k // (KW * C) * D - P
access_w = m % OW * S + k // C % KW * D - P
in_bound = ((access_h >= 0) and (access_w >= 0) and (access_h < H) and
(access_w < W))
data_shared[i, j] = T.if_then_else(
in_bound, data[m // (OH * OW), access_h, access_w, k % C], 0)
T.copy(kernel_flat[k_iter * block_K, bx * block_N], kernel_shared)
T.gemm(data_shared, kernel_shared, out_local)
T.copy(out_local, out_shared)
T.copy(out_shared, out_flat[by * block_M, bx * block_N])
return main
if __name__ == "__main__":
@@ -129,31 +241,31 @@ if __name__ == "__main__":
parser.add_argument('--s', type=int, default=1, help='s')
parser.add_argument('--d', type=int, default=1, help='d')
parser.add_argument('--p', type=int, default=1, help='p')
parser.add_argument('--tune', action='store_true', help='tune configs')
parser.add_argument(
"--use_autotune",
action="store_true",
default=True,
help="Whether to use autotune for matmul configs")
parser.add_argument(
"--with_roller",
action="store_true",
default=True,
help="Whether to enable BitBLAS roller for search space")
args = parser.parse_args()
N, C, H, W, F, K, S, D, P = args.n, args.c, args.h, args.w, args.f, args.k, args.s, args.d, args.p
OH = (H + 2 * P - D * (K - 1) - 1) // S + 1
OW = (W + 2 * P - D * (K - 1) - 1) // S + 1
total_flops = 2 * N * C * OH * OW * F * K * K
if (not args.tune):
program = convolution(
N, C, H, W, F, K, S, D, P, tune=args.tune)(
block_M=256, block_N=128, block_K=64, num_stages=4, threads=256)
ref_program = partial(ref_program, stride=S, padding=P, dilation=D)
kernel = tilelang.compile(program, out_idx=[2])
profiler = kernel.get_profiler(tilelang.TensorSupplyType.Normal)
profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
print("All checks pass.")
latency = profiler.do_bench(ref_program, warmup=500)
print("Ref: {:.2f} ms".format(latency))
print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9))
latency = profiler.do_bench(warmup=500)
print("Tile-lang: {:.2f} ms".format(latency))
print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
a = torch.randn(N, H, W, C).cuda().half()
b = torch.randn(K, K, C, F).cuda().half()
use_autotune = args.use_autotune
with_roller = args.with_roller
if use_autotune:
result = get_best_config(N, C, H, W, F, K, S, D, P, with_roller)
print(f"best latency {result.latency}")
kernel = result.kernel
else:
best_latency, best_config, ref_latency = convolution(
N, C, H, W, F, K, S, D, P, tune=args.tune)
print(f"Best latency: {best_latency}")
print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
print(f"Best config: {best_config}")
kernel = tilelang.compile(
convolution(N, C, H, W, F, K, S, D, P, 64, 128, 32, 3, 256), out_idx=[2])
out_c = kernel(a, b)
ref_c = ref_program(S, P, D)(a, b)
torch.testing.assert_close(out_c, ref_c, rtol=1e-2, atol=1e-2)
@@ -188,7 +188,6 @@ if __name__ == "__main__":
M, N, K = args.m, args.n, args.k
a = torch.randn(M, K).cuda().half()
b = torch.randn(N, K).cuda().half()
configs = []
use_autotune = args.use_autotune
with_roller = args.with_roller
if use_autotune:
......
@@ -6,3 +6,4 @@ from .gemv import GEMVTemplate # noqa: F401
from .elementwise import ElementwiseTemplate # noqa: F401
from .general_reduce import GeneralReductionTemplate # noqa: F401
from .flashattention import FlashAttentionTemplate # noqa: F401
from .conv import ConvTemplate # noqa: F401
\ No newline at end of file
from dataclasses import dataclass
from .base import BaseTemplate
from tvm import te, tir
from ..roller import Hint
from typing import List
from ..utils import get_roller_hints_from_func
@dataclass
class ConvTemplate(BaseTemplate):
"""
A template for convolution (Conv).
This class defines the computation for a 2D convolution
with configurable parameters such as stride, dilation, padding, data types, and bias addition.
Attributes:
N (int): The number of input samples processed simultaneously in a batch.
C (int): The number of input feature maps.
H (int): The height of the input feature maps.
W (int): The width of the input feature maps.
F (int): The number of filters (kernels) applied, determining output depth.
K (int): The spatial size (height and width) of each square convolutional filter.
S (int): The step size by which the kernel slides across the input.
D (int): The spacing between kernel elements, controlling receptive field expansion.
P (int): The number of pixels added to input borders to control output spatial dimensions.
in_dtype (str): Data type of the input and weight tensors.
out_dtype (str): Data type of the output tensor.
accum_dtype (str): Data type used for accumulation.
with_bias (bool): Whether to add a bias term.
"""
# Operation-related configuration parameters
N: int # The number of input samples processed simultaneously in a batch.
C: int # The number of input feature maps.
H: int # The height of the input feature maps.
W: int # The width of the input feature maps.
F: int # The number of filters (kernels) applied, determining output depth.
K: int # The spatial size (height and width) of each square convolutional filter.
S: int # The step size by which the kernel slides across the input.
D: int # The spacing between kernel elements, controlling receptive field expansion.
P: int # The number of pixels added to input borders to control output spatial dimensions.
in_dtype: str = "float16" # Data type of input matrices
out_dtype: str = "float16" # Data type of output matrix
accum_dtype: str = "float16" # Data type for accumulation
with_bias: bool = False # Whether to add a bias term
def get_hardware_aware_configs(self, arch=None, topk=10) -> List[Hint]:
"""
Retrieves optimized hardware-aware configurations.
Args:
arch (TileDevice, optional): The target hardware architecture.
topk (int, optional): Number of top configurations to consider.
Returns:
List[Hint]: A list of optimization hints for hardware acceleration.
"""
roller_hints = get_roller_hints_from_func(self._func, arch=arch, topk=topk, allow_gemv=True)
return roller_hints
def initialize_function(self) -> None:
"""
Defines and initializes the convolution computation.
This method sets up placeholders for the input, weight, and bias tensors, computes
the convolution using TVM's compute API,
and optionally applies bias addition and casting to the output dtype.
Raises:
AssertionError: If N, C, H, W, F, K, S, D, P are not positive integers.
"""
N, C, H, W, F, K, S, D, P = self.N, self.C, self.H, self.W, self.F, self.K, self.S, self.D, self.P
assert (isinstance(N, int) and isinstance(C, int) and isinstance(H, int) and
isinstance(W, int) and isinstance(F, int) and isinstance(K, int) and
isinstance(S, int) and isinstance(D, int) and
isinstance(P, int)), "Only Support Integer Params"
assert (N > 0 and C > 0 and H > 0 and W > 0 and F > 0 and K > 0 and S > 0 and D > 0 and
P > 0), "Params should be positive"
# Load configuration parameters
in_dtype, out_dtype, accum_dtype = self.in_dtype, self.out_dtype, self.accum_dtype
with_bias = self.with_bias
# Calculate kernel dimensions and output dimensions
KH, KW = K, K
OH = (H + 2 * P - D * (KH - 1) - 1) // S + 1
OW = (W + 2 * P - D * (KW - 1) - 1) // S + 1
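# For example, H = W = 64, K = 3, S = 1, D = 1, P = 1 gives OH = OW = (64 + 2 - 2 - 1) // 1 + 1 = 64 (sizes chosen purely for illustration).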
# Define tensor shapes
input_shape = (N, H, W, C) # NHWC format input tensor
weight_shape = (KH, KW, C, F) # HWCF format weight tensor
output_shape = (N, OH, OW, F) # NHWC format output tensor
bias_shape = (F,) # Bias vector shape
# Create TVM placeholders for input tensors
A = te.placeholder(input_shape, name="A", dtype=in_dtype) # Input tensor
B = te.placeholder(weight_shape, name="B", dtype=in_dtype) # Weight tensor
Bias = te.placeholder(bias_shape, name="Bias", dtype=accum_dtype) # Bias vector
# Define reduction axes for convolution
kh = te.reduce_axis((0, KH), name="kh")
kw = te.reduce_axis((0, KW), name="kw")
c = te.reduce_axis((0, C), name="c")
def _compute_conv(n, h, w, f):
"""
Compute function for convolution.
Args:
n (int): Batch index.
h (int): Output height index.
w (int): Output width index.
f (int): Output channel index.
Returns:
Computed value for output[n, h, w, f] as a sum over reduction axes.
"""
# Calculate input positions considering stride and dilation
h_in = h * S - P + kh * D
w_in = w * S - P + kw * D
# Check if the input position is within bounds (implicit padding with 0)
return te.sum(
te.if_then_else(
te.all(h_in >= 0, h_in < H, w_in >= 0, w_in < W),
A[n, h_in, w_in, c].astype(accum_dtype) * B[kh, kw, c, f].astype(accum_dtype),
tir.const(0, accum_dtype)),
axis=[kh, kw, c])
# Compute convolution result
C = te.compute(
output_shape,
fcompute=_compute_conv,
name="C",
)
# Optionally apply bias addition
if with_bias:
C = te.compute(
output_shape,
lambda n, h, w, f: C[n, h, w, f] + Bias[f],
name="Bias",
)
# Optionally cast the output to a different type
if out_dtype != accum_dtype:
C = te.compute(
output_shape,
lambda n, h, w, f: C[n, h, w, f].astype(out_dtype),
name="D",
)
# Set function arguments (including bias if used)
args = [A, B, Bias, C] if self.with_bias else [A, B, C]
self.set_function(te.create_prim_func(args))
def params_as_dict(self):
"""
Returns the template parameters as a dictionary.
Returns:
dict: Dictionary containing template parameter values.
"""
return {
"N": self.N,
"C": self.C,
"H": self.H,
"W": self.W,
"F": self.F,
"K": self.K,
"S": self.S,
"D": self.D,
"P": self.P,
"in_dtype": self.in_dtype,
"out_dtype": self.out_dtype,
"accum_dtype": self.accum_dtype,
"with_bias": self.with_bias,
}
@property
def class_attributes(self):
"""
Returns the class attributes in dictionary form.
Returns:
dict: Dictionary of class attributes.
"""
return self.params_as_dict()
def __repr__(self) -> str:
"""
Returns a string representation of the class instance.
Returns:
str: A formatted string representation of the class.
"""
cls_name = self.__class__.__name__
fields = self.class_attributes
field_str = ", ".join(f"{key}={value!r}" for key, value in fields.items())
return f"{cls_name}({field_str})"