Commit 3c53297b authored by Lei Wang, committed by GitHub

[Carver] Enhance Carver Adaptation for MatMul Benchmarking (#153)

* [Refactor] Consolidate GemmWarpPolicy Enum and Add Utility Method

- Move GemmWarpPolicy from copy.py and gemm.py to primitives/gemm/base.py
- Implement a `from_warp_partition` class method to determine the warp policy (see the usage sketch below)
- Add docstring with examples for policy determination
- Remove duplicate GemmWarpPolicy class definitions
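
The consolidated helper can be exercised directly; a minimal sketch, assuming the import path introduced by this commit (`tilelang.primitives.gemm.base`) and the policy mapping documented in the diff further down:

```python
# Hedged usage sketch of the new GemmWarpPolicy.from_warp_partition helper.
from tilelang.primitives.gemm.base import GemmWarpPolicy

# All warps along the row (M) dimension -> FullRow.
assert GemmWarpPolicy.from_warp_partition(4, 1) == GemmWarpPolicy.FullRow
# All warps along the column (N) dimension -> FullCol.
assert GemmWarpPolicy.from_warp_partition(1, 4) == GemmWarpPolicy.FullCol
# Balanced partition (also the 1x1 fallback) -> Square.
assert GemmWarpPolicy.from_warp_partition(2, 2) == GemmWarpPolicy.Square
```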

* [Enhancement] Add TensorCore Intrinsic Matrix Multiplication Benchmarks

- Implement two new matrix multiplication benchmark scripts:
  1. `benchmark_matmul_intrinsic.py`: Uses TensorCore intrinsics with advanced configuration
  2. `benchmark_matmul.py`: Provides a more generic matrix multiplication benchmark

- Add support for roller-based configuration generation in both benchmarks (the hint-to-config mapping is sketched after this list)
- Enhance MMA macro generator to handle 2D and 4D output buffer layouts
- Implement flexible autotuning configurations with multiple parameters
- Support different data types and accumulation modes
- Add command-line arguments for matrix dimensions and roller configuration
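
As referenced above, the roller path turns each carver hint into one autotuning config. A minimal sketch of that mapping, assuming a carver `hint` object with the fields used in the hunks below (`block`, `warp`, `rstep`, `pipeline_stage`, `rasterization_plan`); the helper name `hint_to_config` is illustrative only:

```python
# Hedged sketch of the roller hint -> autotune config mapping (mirrors the
# benchmark hunks below; `hint_to_config` is a hypothetical helper name).
import tilelang.language as T
from tilelang.carver.roller.rasterization import NoRasterization


def hint_to_config(hint):
    block_m, block_n = hint.block
    warp_m, warp_n = hint.warp
    # Warp partitioning of the thread block.
    block_rows, block_cols = block_m // warp_m, block_n // warp_n
    return {
        "block_M": block_m,
        "block_N": block_n,
        "block_K": hint.rstep[0],
        "num_stages": hint.pipeline_stage,
        "thread_num": block_rows * block_cols * 32,  # 32 threads per warp
        "policy": T.GemmWarpPolicy.from_warp_partition(block_rows, block_cols),
        "enable_rasteration": hint.rasterization_plan is not NoRasterization,
    }
```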

* lint fix

* Fix roller hints generation in get_roller_hints_from_func

- Simplify the roller-hints generation logic
- Ensure a policy-based configuration is always emitted when a policy is available
- Remove the redundant None check for roller hints (see the before/after sketch below)
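
A hedged before/after excerpt of the touched control flow (the hunk appears near the bottom of this page; surrounding context, including where `roller_hints` is otherwise initialized, is elided there and here):

```python
# Before: a redundant branch reset roller_hints to None.
if tags and tensorized_func:
    policy = TensorCorePolicy.from_prim_func(func=tensorized_func, arch=arch, tags=tags)
    roller_hints = policy.emit_config(topk)
else:
    roller_hints = None
return roller_hints

# After: hints are emitted whenever a TensorCore policy is available; the
# explicit None assignment is dropped.
if tags and tensorized_func:
    policy = TensorCorePolicy.from_prim_func(func=tensorized_func, arch=arch, tags=tags)
    roller_hints = policy.emit_config(topk)
return roller_hints
```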

* Add shared memory for matrix multiplication in benchmark and quickstart examples

- Modify benchmark_matmul.py and quickstart.py to include C_shared allocation
- Change accumulation dtype from float16 to float in benchmark_matmul.py
- Update the matrix multiplication kernels to use shared memory for result storage (see the kernel sketch below)
- Enable CUDA kernel source printing in quickstart example
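
A minimal, hedged sketch of the updated write-back path, assembled from the benchmark_matmul.py hunks below; block sizes, thread count, and the `matmul_kernel` wrapper name are illustrative placeholders, not part of the commit:

```python
import tilelang.language as T


def matmul_kernel(M, N, K, block_M=128, block_N=128, block_K=32,
                  dtype="float16", accum_dtype="float", num_stages=2, threads=128):

    @T.prim_func
    def main(
            A: T.Buffer((M, K), dtype),
            B: T.Buffer((N, K), dtype),
            C: T.Buffer((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_N, block_K), dtype)
            # Register fragment for accumulation (now in float, not float16).
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            # New in this commit: shared-memory staging buffer for the C tile.
            C_shared = T.alloc_shared((block_M, block_N), dtype)

            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                T.copy(A[by * block_M, k * block_K], A_shared)
                T.copy(B[bx * block_N, k * block_K], B_shared)
                T.gemm(A_shared, B_shared, C_local, transpose_B=True)

            # Write back through shared memory instead of copying C_local
            # directly to global memory.
            T.copy(C_local, C_shared)
            T.copy(C_shared, C[by * block_M, bx * block_N])

    return main
```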
parent e945dae2
@@ -50,7 +50,7 @@ def get_configs(M, N, K, with_roller=False):
         from tilelang.carver.arch import CUDA
         from tilelang.carver.roller.rasterization import NoRasterization
         arch = CUDA("cuda")
-        topk = 20
+        topk = 10
         carve_template = MatmulTemplate(
             M=M,
@@ -58,7 +58,7 @@ def get_configs(M, N, K, with_roller=False):
             K=K,
             in_dtype="float16",
             out_dtype="float16",
-            accum_dtype="float16",
+            accum_dtype="float",
         ).with_arch(arch)
         func = carve_template.equivalent_function()
@@ -74,11 +74,14 @@ def get_configs(M, N, K, with_roller=False):
             config = {}
             block_m, block_n = hint.block
             warp_m, warp_n = hint.warp
+            # block_rows, block_cols represents warp partitioning
+            block_rows, block_cols = block_m // warp_m, block_n // warp_n
             config["block_M"] = block_m
             config["block_N"] = block_n
             config["block_K"] = hint.rstep[0]
-            config["num_stages"] = 0
+            config["num_stages"] = hint.pipeline_stage
-            config["thread_num"] = (block_m * block_n) // (warp_m * warp_n) * 32
+            config["thread_num"] = block_rows * block_cols * 32
+            config["policy"] = T.GemmWarpPolicy.from_warp_partition(block_rows, block_cols)
             config["enable_rasteration"] = hint.rasterization_plan is not NoRasterization
             configs.append(config)
         for config in configs:
@@ -90,8 +93,8 @@ def get_configs(M, N, K, with_roller=False):
         block_K = [32, 64]
         num_stages = [0, 1, 2, 3]
         thread_num = [128, 256]
+        policy = [T.GemmWarpPolicy.Square]
         enable_rasterization = [True, False]
         _configs = list(
             itertools.product(
                 block_M,
@@ -99,6 +102,7 @@ def get_configs(M, N, K, with_roller=False):
                 block_K,
                 num_stages,
                 thread_num,
+                policy,
                 enable_rasterization,
             ))
@@ -109,7 +113,8 @@ def get_configs(M, N, K, with_roller=False):
                 "block_K": c[2],
                 "num_stages": c[3],
                 "thread_num": c[4],
-                "enable_rasteration": c[5],  # keep param name for backward-compat
+                "policy": c[5],
+                "enable_rasteration": c[6],  # keep param name for backward-compat
             } for c in _configs
         ]
     return configs
@@ -158,6 +163,7 @@ def matmul(M, N, K, with_roller):
             "block_K",
             "num_stages",
             "thread_num",
+            "policy",
             "enable_rasteration",
         ],
         warmup=3,
@@ -177,6 +183,7 @@ def matmul(M, N, K, with_roller):
         block_K=None,
         num_stages=None,
         thread_num=None,
+        policy=None,
         enable_rasteration=None,
     ):
         """
@@ -236,6 +243,8 @@ def matmul(M, N, K, with_roller):
             B_shared = T.alloc_shared((block_N, block_K), dtype)
             # Allocate a local fragment for intermediate accumulation
             C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
+            # Allocate a shared memory for C sub-block of shape (block_M, block_N)
+            C_shared = T.alloc_shared((block_M, block_N), dtype)
             # Enable (or disable) swizzling optimization
             T.use_swizzle(panel_size=10, enable=enable_rasteration)
@@ -246,15 +255,9 @@ def matmul(M, N, K, with_roller):
             # Loop over sub-blocks in K dimension, pipelined by num_stages
             for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                 # Load a sub-block of A from global memory into A_shared
-                T.copy(
-                    A[by * block_M, k * block_K],
-                    A_shared,
-                )
+                T.copy(A[by * block_M, k * block_K], A_shared)
                 # Load a sub-block of B from global memory into B_shared
-                T.copy(
-                    B[bx * block_N, k * block_K],
-                    B_shared,
-                )
+                T.copy(B[bx * block_N, k * block_K], B_shared)
                 # Perform a partial matrix multiplication:
                 #   C_local += A_shared @ B_shared^T
                 T.gemm(
@@ -262,9 +265,11 @@ def matmul(M, N, K, with_roller):
                     B_shared,
                     C_local,
                     transpose_B=True,
+                    policy=policy,
                 )
             # Write back the results from C_local to the global memory C
-            T.copy(C_local, C[by * block_M, bx * block_N])
+            T.copy(C_local, C_shared)
+            T.copy(C_shared, C[by * block_M, bx * block_N])
         return main
@@ -274,9 +279,9 @@ def matmul(M, N, K, with_roller):
 if __name__ == "__main__":
     # Parse command-line arguments for matrix dimensions
     parser = argparse.ArgumentParser(description="Autotuned MatMul Benchmark")
-    parser.add_argument("--m", type=int, default=8192, help="Matrix dimension M")
-    parser.add_argument("--n", type=int, default=8192, help="Matrix dimension N")
-    parser.add_argument("--k", type=int, default=8192, help="Matrix dimension K")
+    parser.add_argument("--m", type=int, default=16384, help="Matrix dimension M")
+    parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N")
+    parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K")
     parser.add_argument(
         "--with_roller",
         action="store_true",
@@ -285,7 +290,8 @@ if __name__ == "__main__":
     args = parser.parse_args()
     M, N, K = args.m, args.n, args.k
-    with_roller = args.with_roller
+    # with_roller = args.with_roller
+    with_roller = True
     # Compute total floating-point operations to measure throughput
     total_flops = 2 * M * N * K
...
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import argparse
import logging
import torch
import torch.backends
from tilelang import tvm as tvm
from tvm import DataType
import tilelang as tl
import tilelang.language as T
from tilelang.intrinsics import get_swizzle_layout
from tilelang.intrinsics.mma_macro_generator import (
    TensorCoreIntrinEmitter,)
from tilelang.transform import simplify_prim_func
from tilelang.autotuner import autotune, jit
import itertools

# Configure logger
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
def make_swizzle_layout(shared_buf):
    dtype = shared_buf.dtype
    shape = shared_buf.shape

    can_swizzle = shape[-1] * DataType(dtype).bits == 512
    if not can_swizzle:
        return T.Layout(shape, lambda *args: args)

    def transform_func(i, j):
        new_warp_i, new_warp_j = get_swizzle_layout(i, j, shape[-1], dtype)
        return [new_warp_i, new_warp_j]

    return T.Layout(shape, transform_func)
@simplify_prim_func
def tl_matmul(
    M,
    N,
    K,
    in_dtype,
    out_dtype,
    accum_dtype,
    block_row_warps=1,
    block_col_warps=1,
    warp_row_tiles=16,
    warp_col_tiles=16,
    chunk=32,
    stage=2,
    enable_rasteration=False,
):
    assert in_dtype in [
        "float16",
        "int8",
    ], "Currently only float16 and int8 are supported"
    assert out_dtype in [
        "float16",
        "float32",
        "int32",
    ], "Currently only float16, float32 and int32 are supported"

    micro_size_x = micro_size_y = micro_size_k = 16

    if out_dtype == "int32":
        micro_size_k = 32

    # This is a debug config
    # chunk = 32 if in_dtype == "float16" else 64
    shared_scope = "shared.dyn"

    block_M = block_row_warps * warp_row_tiles
    block_N = block_col_warps * warp_col_tiles
    block_K = chunk

    A_shape = (M, K)
    B_shape = (N, K)
    A_shared_shape = (block_M, block_K)
    B_shared_shape = (block_N, block_K)
    C_shared_shape = (
        block_M,
        block_N,
    )

    warp_size = 32
    threads = warp_size * (block_row_warps * block_col_warps)
    local_size_a = (micro_size_x * micro_size_k) // warp_size
    local_size_b = (micro_size_y * micro_size_k) // warp_size
    local_size_c = (micro_size_x * micro_size_y) // warp_size
    warp_rows = warp_row_tiles // micro_size_x
    warp_cols = warp_col_tiles // micro_size_y

    # MMA Wrapper to Auto Generate Code for MMA
    mma_emitter = TensorCoreIntrinEmitter(
        a_dtype=in_dtype,
        b_dtype=in_dtype,
        accum_dtype=accum_dtype,
        a_transposed=False,
        b_transposed=True,
        block_row_warps=block_row_warps,
        block_col_warps=block_col_warps,
        warp_row_tiles=warp_row_tiles,
        warp_col_tiles=warp_col_tiles,
        chunk=chunk,
    )

    @T.prim_func
    def main(
            A: T.Buffer(A_shape, in_dtype),
            B: T.Buffer(B_shape, in_dtype),
            C: T.Buffer((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):

            A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
            C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
            A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
            B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
            C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)

            T.annotate_layout({
                A_shared: make_swizzle_layout(A_shared),
                B_shared: make_swizzle_layout(B_shared),
            })

            # Improve L2 Cache
            T.use_swizzle(panel_size=10, enable=enable_rasteration)

            T.clear(C_local)

            for ko in T.Pipelined((K // block_K), num_stages=stage):

                # Load A into shared memory
                for i, k in T.Parallel(block_M, block_K):
                    A_shared[i, k] = A[by * block_M + i, ko * block_K + k]

                # Load B into shared memory
                for j, k in T.Parallel(block_N, block_K):
                    B_shared[j, k] = B[bx * block_N + j, ko * block_K + k]

                for ki in T.serial(0, (block_K // micro_size_k)):
                    # Load A into fragment
                    mma_emitter.ldmatrix_a(A_local, A_shared, ki)

                    # Load B into fragment
                    mma_emitter.ldmatrix_b(B_local, B_shared, ki)

                    # Perform Matrix Multiplication
                    mma_emitter.mma(A_local, B_local, C_local)

            # Perform STMatrix
            mma_emitter.stmatrix(C_local, C_shared)

            # Store shared into global
            for i, j in T.Parallel(block_M, block_N):
                C[by * block_M + i, bx * block_N + j] = C_shared[i, j]

    return main
def ref_program(A, B):
    """Reference matrix multiplication program."""
    return torch.matmul(A, B.T)
def get_configs(M, N, K, with_roller=False):
    """
    Generate a list of configuration dictionaries that will be used for tuning.

    Parameters
    ----------
    with_roller : bool
        Whether to enable bitblas roller to deduce search spaces

    Returns
    -------
    list of dict
        Each configuration dict includes various block sizes, pipeline stages,
        thread numbers, and other parameters to explore during autotuning.
    """
    if with_roller:
        from tilelang.carver.template import MatmulTemplate
        from tilelang.carver.arch import CUDA
        from tilelang.carver.roller.rasterization import NoRasterization
        arch = CUDA("cuda")
        topk = 10
        carve_template = MatmulTemplate(
            M=M,
            N=N,
            K=K,
            in_dtype="float16",
            out_dtype="float16",
            accum_dtype="float16",
        ).with_arch(arch)

        func = carve_template.equivalent_function()
        assert func is not None, "Function is None"
        roller_hints = carve_template.recommend_hints(topk=topk)
        if roller_hints is None:
            raise ValueError("No Roller Hints Found for TensorCore Scheduling")

        configs = []
        for hint in roller_hints:
            config = {}
            block_m, block_n = hint.block
            warp_m, warp_n = hint.warp
            config["block_row_warps"] = block_m // warp_m
            config["block_col_warps"] = block_n // warp_n
            config["warp_row_tiles"] = warp_m
            config["warp_col_tiles"] = warp_n
            config["chunk"] = hint.rstep[0]
            config["stage"] = hint.pipeline_stage
            config["enable_rasteration"] = hint.rasterization_plan is not NoRasterization
            configs.append(config)
        for config in configs:
            print(config)
    else:
        block_rows_warps = [1, 2, 4]
        block_col_warps = [1, 2, 4]
        warp_row_tiles = [16, 32, 64, 128]
        warp_col_tiles = [16, 32, 64, 128]
        chunk = [32, 64, 128, 256]
        stage = [0, 2]
        enable_rasteration = [True, False]
        _configs = list(
            itertools.product(block_rows_warps, block_col_warps, warp_row_tiles, warp_col_tiles,
                              chunk, stage, enable_rasteration))
        configs = [{
            "block_row_warps": c[0],
            "block_col_warps": c[1],
            "warp_row_tiles": c[2],
            "warp_col_tiles": c[3],
            "chunk": c[4],
            "stage": c[5],
            "enable_rasteration": c[6],
        } for c in _configs]
    return configs
def matmul(M,
           N,
           K,
           in_dtype="float16",
           out_dtype="float16",
           accum_dtype="float16",
           with_roller=False):
    """Create an autotuned tensor core matrix multiplication kernel."""

    @autotune(
        configs=get_configs(M, N, K, with_roller),
        keys=[
            "block_row_warps",
            "block_col_warps",
            "warp_row_tiles",
            "warp_col_tiles",
            "chunk",
            "enable_rasteration",
            "stage",
        ],
        warmup=3,
        rep=5,
    )
    @jit(
        out_idx=[2],
        supply_type=tl.TensorSupplyType.Integer,
        ref_prog=ref_program,
        skip_check=True,
        profiler="auto",
        target="auto",
    )
    def kernel(
        block_row_warps=None,
        block_col_warps=None,
        warp_row_tiles=None,
        warp_col_tiles=None,
        chunk=None,
        stage=None,
        enable_rasteration=None,
    ):
        return tl_matmul(
            M,
            N,
            K,
            in_dtype=in_dtype,
            out_dtype=out_dtype,
            accum_dtype=accum_dtype,
            block_row_warps=block_row_warps,
            block_col_warps=block_col_warps,
            warp_row_tiles=warp_row_tiles,
            warp_col_tiles=warp_col_tiles,
            chunk=chunk,
            stage=stage,
            enable_rasteration=enable_rasteration,
        )

    return kernel()
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Autotuned TensorCore MatMul Benchmark")
parser.add_argument("--m", type=int, default=16384, help="Matrix dimension M")
parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N")
parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K")
parser.add_argument(
"--with_roller",
type=bool,
default=False,
help="Whether to use roller to deduce search spaces")
parser.add_argument(
"--dtype", type=str, default="float16", choices=["float16", "int8"], help="Input data type")
args = parser.parse_args()
M, N, K = args.m, args.n, args.k
in_dtype = args.dtype
out_dtype = "float32" if in_dtype == "int8" else "float16"
accum_dtype = "float32" if in_dtype == "int8" else "float16"
with_roller = args.with_roller
with_roller = True
# Compute total floating-point operations
total_flops = 2 * M * N * K
# Run autotuning
best_latency, best_config, ref_latency = matmul(M, N, K, in_dtype, out_dtype, accum_dtype,
with_roller)
# Print benchmark results
print(f"Best latency (s): {best_latency}")
print(f"Best TFlops: {total_flops / best_latency * 1e-9:.3f}")
print(f"Best config: {best_config}")
print(f"Reference TFlops: {total_flops / ref_latency * 1e-9:.3f}")
@@ -68,9 +68,7 @@ def get_roller_hints_from_func(func_or_module: Union[tir.PrimFunc, IRModule],
         tags = None
     if tags and tensorized_func:
         policy = TensorCorePolicy.from_prim_func(func=tensorized_func, arch=arch, tags=tags)
         roller_hints = policy.emit_config(topk)
-    else:
-        roller_hints = None
     return roller_hints
...
@@ -318,6 +318,8 @@ class TensorCoreIntrinEmitter(object):
         BLOCK_M = block_row_warps * warp_rows
         BLOCK_N = block_col_warps * warp_cols
         M_DIM, N_DIM = self.M_DIM, self.N_DIM
+        C_buf_dims = len(C_buf.shape)
+        assert C_buf_dims in {2, 4}, "C_buf should be 2D or 4D"
         current_frame = T.KernelLaunchFrame.Current()
         thread_binding = current_frame.get_thread_binding()
@@ -334,9 +336,15 @@ class TensorCoreIntrinEmitter(object):
                     for local_id_i in T.vectorized(2):
                         local_id = local_id_o * 2 + local_id_i
                         row, col = T.meta_var(mma_store_index_map(tx, local_id))
-                        C_buf[warp_m * warp_rows + i, warp_n * warp_cols + j, row,
-                              col] = C_local_buf[i * (warp_cols * local_size_out) +
-                                                 j * local_size_out + local_id]
+                        if C_buf_dims == 2:
+                            C_buf[(warp_m * warp_rows + i) * M_DIM + row,
+                                  (warp_n * warp_cols + j) * N_DIM +
+                                  col] = C_local_buf[i * (warp_cols * local_size_out) +
+                                                     j * local_size_out + local_id]
+                        else:
+                            C_buf[warp_m * warp_rows + i, warp_n * warp_cols + j, row,
+                                  col] = C_local_buf[i * (warp_cols * local_size_out) +
+                                                     j * local_size_out + local_id]

     @T.macro
     def _warp_stmatrix_global(C_local_buf, C_buf, thread_binding):
...
@@ -91,9 +91,3 @@ def c2d_im2col(
         dilation,
         pad,
     )
-
-
-class GemmWarpPolicy:
-    Square = 0
-    FullRow = 1
-    FullCol = 2
@@ -2,15 +2,10 @@
 # Licensed under the MIT License.
 """The language interface for tl programs."""
+from tilelang.primitives.gemm.base import GemmWarpPolicy
 from tvm import tir

-
-class GemmWarpPolicy:
-    Square = 0
-    FullRow = 1
-    FullCol = 2
-

 def gemm(
     A: tir.Buffer,
     B: tir.Buffer,
...
@@ -126,6 +126,33 @@ class GemmWarpPolicy(IntEnum):
         return m_warp, n_warp

+    @classmethod
+    def from_warp_partition(cls, m_warp: int, n_warp: int) -> 'GemmWarpPolicy':
+        """
+        Determine the warp policy based on the given warp partitioning.
+
+        Args:
+            m_warp (int): Number of warps in the row dimension
+            n_warp (int): Number of warps in the column dimension
+
+        Returns:
+            GemmWarpPolicy: The corresponding warp policy
+
+        Examples:
+            >>> GemmWarpPolicy.from_warp_partition(4, 1)  # All warps in rows
+            GemmWarpPolicy.FullRow
+            >>> GemmWarpPolicy.from_warp_partition(1, 4)  # All warps in columns
+            GemmWarpPolicy.FullCol
+            >>> GemmWarpPolicy.from_warp_partition(2, 2)  # Balanced distribution
+            GemmWarpPolicy.Square
+        """
+        if n_warp == 1 and m_warp > 1:
+            return cls.FullRow
+        elif m_warp == 1 and n_warp > 1:
+            return cls.FullCol
+        else:
+            return cls.Square
+

 @dataclass
 class GemmBaseParams:
...