Commit 3c53297b authored by Lei Wang, committed by GitHub

[Carver] Enhance Carver Adaptation for MatMul Benchmarking (#153)

* [Refactor] Consolidate GemmWarpPolicy Enum and Add Utility Method

- Move GemmWarpPolicy from copy.py and gemm.py to primitives/gemm/base.py
- Implement from_warp_partition class method to determine warp policy (see the usage sketch below)
- Add docstring with examples for policy determination
- Remove duplicate GemmWarpPolicy class definitions
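
A minimal usage sketch of the new class method, mirroring the implementation and docstring added to `primitives/gemm/base.py` further down in this diff (assumes tilelang is installed):

```python
from tilelang.primitives.gemm.base import GemmWarpPolicy

# 4 warps along M, 1 along N -> all warps stacked on rows
assert GemmWarpPolicy.from_warp_partition(4, 1) == GemmWarpPolicy.FullRow
# 1 warp along M, 4 along N -> all warps spread over columns
assert GemmWarpPolicy.from_warp_partition(1, 4) == GemmWarpPolicy.FullCol
# Balanced 2x2 partition -> square policy
assert GemmWarpPolicy.from_warp_partition(2, 2) == GemmWarpPolicy.Square
```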

* [Enhancement] Add TensorCore Intrinsic Matrix Multiplication Benchmarks

- Implement two new matrix multiplication benchmark scripts:
  1. `benchmark_matmul_intrinsic.py`: Uses TensorCore intrinsics with advanced configuration
  2. `benchmark_matmul.py`: Provides a more generic matrix multiplication benchmark

- Add support for roller-based configuration generation in both benchmarks (see the sketch after this list)
- Enhance MMA macro generator to handle 2D and 4D output buffer layouts
- Implement flexible autotuning configurations with multiple parameters
- Support different data types and accumulation modes
- Add command-line arguments for matrix dimensions and roller configuration
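
As a reference for the roller path, a hedged sketch of how one carver hint is translated into one autotuning config in `benchmark_matmul.py`; the helper name `hint_to_config` is only for illustration, while the `hint` fields and config keys follow the diff below:

```python
import tilelang.language as T
from tilelang.carver.roller.rasterization import NoRasterization


def hint_to_config(hint):
    """Translate a single carver roller hint into an autotuner config dict."""
    block_m, block_n = hint.block
    warp_m, warp_n = hint.warp
    # Warp partitioning: how many warps tile the M and N extents of one block
    block_rows, block_cols = block_m // warp_m, block_n // warp_n
    return {
        "block_M": block_m,
        "block_N": block_n,
        "block_K": hint.rstep[0],
        "num_stages": hint.pipeline_stage,
        "thread_num": block_rows * block_cols * 32,  # one warp = 32 threads
        "policy": T.GemmWarpPolicy.from_warp_partition(block_rows, block_cols),
        "enable_rasteration": hint.rasterization_plan is not NoRasterization,
    }
```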

* lint fix

* Fix roller hints generation in get_roller_hints_from_func

- Simplify roller hints generation logic
- Ensure policy-based configuration is always emitted when a policy is available
- Remove redundant None check for roller hints

* Add shared memory for matrix multiplication in benchmark and quickstart examples

- Modify benchmark_matmul.py and quickstart.py to include C_shared allocation
- Change accumulation dtype from float16 to float in benchmark_matmul.py
- Update matrix multiplication kernels to use shared memory for result storage (see the excerpt below)
- Enable CUDA kernel source printing in quickstart example
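
The write-back change is a small epilogue tweak; a minimal excerpt of the new pattern (not standalone, it uses the kernel's existing `block_M`, `block_N`, `dtype`, `C_local`, `by`, `bx`, and `C` from the diff below):

```python
# Epilogue of the matmul kernel: stage the accumulator through shared memory
C_shared = T.alloc_shared((block_M, block_N), dtype)  # shared tile for the C sub-block
T.copy(C_local, C_shared)                             # register fragment -> shared memory
T.copy(C_shared, C[by * block_M, bx * block_N])       # shared memory -> global C
```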
parent e945dae2
......@@ -50,7 +50,7 @@ def get_configs(M, N, K, with_roller=False):
from tilelang.carver.arch import CUDA
from tilelang.carver.roller.rasterization import NoRasterization
arch = CUDA("cuda")
topk = 20
topk = 10
carve_template = MatmulTemplate(
M=M,
......@@ -58,7 +58,7 @@ def get_configs(M, N, K, with_roller=False):
K=K,
in_dtype="float16",
out_dtype="float16",
accum_dtype="float16",
accum_dtype="float",
).with_arch(arch)
func = carve_template.equivalent_function()
......@@ -74,11 +74,14 @@ def get_configs(M, N, K, with_roller=False):
config = {}
block_m, block_n = hint.block
warp_m, warp_n = hint.warp
# block_rows, block_cols represent the warp partitioning along M and N
block_rows, block_cols = block_m // warp_m, block_n // warp_n
config["block_M"] = block_m
config["block_N"] = block_n
config["block_K"] = hint.rstep[0]
config["num_stages"] = 0
config["thread_num"] = (block_m * block_n) // (warp_m * warp_n) * 32
config["num_stages"] = hint.pipeline_stage
config["thread_num"] = block_rows * block_cols * 32
config["policy"] = T.GemmWarpPolicy.from_warp_partition(block_rows, block_cols)
config["enable_rasteration"] = hint.rasterization_plan is not NoRasterization
configs.append(config)
for config in configs:
......@@ -90,8 +93,8 @@ def get_configs(M, N, K, with_roller=False):
block_K = [32, 64]
num_stages = [0, 1, 2, 3]
thread_num = [128, 256]
policy = [T.GemmWarpPolicy.Square]
enable_rasterization = [True, False]
_configs = list(
itertools.product(
block_M,
......@@ -99,6 +102,7 @@ def get_configs(M, N, K, with_roller=False):
block_K,
num_stages,
thread_num,
policy,
enable_rasterization,
))
......@@ -109,7 +113,8 @@ def get_configs(M, N, K, with_roller=False):
"block_K": c[2],
"num_stages": c[3],
"thread_num": c[4],
"enable_rasteration": c[5], # keep param name for backward-compat
"policy": c[5],
"enable_rasteration": c[6], # keep param name for backward-compat
} for c in _configs
]
return configs
......@@ -158,6 +163,7 @@ def matmul(M, N, K, with_roller):
"block_K",
"num_stages",
"thread_num",
"policy",
"enable_rasteration",
],
warmup=3,
......@@ -177,6 +183,7 @@ def matmul(M, N, K, with_roller):
block_K=None,
num_stages=None,
thread_num=None,
policy=None,
enable_rasteration=None,
):
"""
......@@ -236,6 +243,8 @@ def matmul(M, N, K, with_roller):
B_shared = T.alloc_shared((block_N, block_K), dtype)
# Allocate a local fragment for intermediate accumulation
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
# Allocate shared memory for the C sub-block of shape (block_M, block_N)
C_shared = T.alloc_shared((block_M, block_N), dtype)
# Enable (or disable) swizzling optimization
T.use_swizzle(panel_size=10, enable=enable_rasteration)
......@@ -246,15 +255,9 @@ def matmul(M, N, K, with_roller):
# Loop over sub-blocks in K dimension, pipelined by num_stages
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
# Load a sub-block of A from global memory into A_shared
T.copy(
A[by * block_M, k * block_K],
A_shared,
)
T.copy(A[by * block_M, k * block_K], A_shared)
# Load a sub-block of B from global memory into B_shared
T.copy(
B[bx * block_N, k * block_K],
B_shared,
)
T.copy(B[bx * block_N, k * block_K], B_shared)
# Perform a partial matrix multiplication:
# C_local += A_shared @ B_shared^T
T.gemm(
......@@ -262,9 +265,11 @@ def matmul(M, N, K, with_roller):
B_shared,
C_local,
transpose_B=True,
policy=policy,
)
# Write back the results from C_local to the global memory C
T.copy(C_local, C[by * block_M, bx * block_N])
T.copy(C_local, C_shared)
T.copy(C_shared, C[by * block_M, bx * block_N])
return main
......@@ -274,9 +279,9 @@ def matmul(M, N, K, with_roller):
if __name__ == "__main__":
# Parse command-line arguments for matrix dimensions
parser = argparse.ArgumentParser(description="Autotuned MatMul Benchmark")
parser.add_argument("--m", type=int, default=8192, help="Matrix dimension M")
parser.add_argument("--n", type=int, default=8192, help="Matrix dimension N")
parser.add_argument("--k", type=int, default=8192, help="Matrix dimension K")
parser.add_argument("--m", type=int, default=16384, help="Matrix dimension M")
parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N")
parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K")
parser.add_argument(
"--with_roller",
action="store_true",
......@@ -285,7 +290,8 @@ if __name__ == "__main__":
args = parser.parse_args()
M, N, K = args.m, args.n, args.k
with_roller = args.with_roller
# Compute total floating-point operations to measure throughput
total_flops = 2 * M * N * K
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import argparse
import logging
import torch
import torch.backends
from tilelang import tvm as tvm
from tvm import DataType
import tilelang as tl
import tilelang.language as T
from tilelang.intrinsics import get_swizzle_layout
from tilelang.intrinsics.mma_macro_generator import (
TensorCoreIntrinEmitter,)
from tilelang.transform import simplify_prim_func
from tilelang.autotuner import autotune, jit
import itertools
# Configure logger
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
def make_swizzle_layout(shared_buf):
dtype = shared_buf.dtype
shape = shared_buf.shape
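# Swizzle only when one row of the shared tile spans exactly 512 bits (e.g. 32 float16 elements); otherwise keep the identity layout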
can_swizzle = shape[-1] * DataType(dtype).bits == 512
if not can_swizzle:
return T.Layout(shape, lambda *args: args)
def transform_func(i, j):
new_warp_i, new_warp_j = get_swizzle_layout(i, j, shape[-1], dtype)
return [new_warp_i, new_warp_j]
return T.Layout(shape, transform_func)
@simplify_prim_func
def tl_matmul(
M,
N,
K,
in_dtype,
out_dtype,
accum_dtype,
block_row_warps=1,
block_col_warps=1,
warp_row_tiles=16,
warp_col_tiles=16,
chunk=32,
stage=2,
enable_rasteration=False,
):
assert in_dtype in [
"float16",
"int8",
], "Currently only float16 and int8 are supported"
assert out_dtype in [
"float16",
"float32",
"int32",
], "Currently only float16, float32 and int32 are supported"
micro_size_x = micro_size_y = micro_size_k = 16
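# int8 inputs (int32 accumulation) use MMA tiles with k = 32, versus k = 16 for float16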
if out_dtype == "int32":
micro_size_k = 32
# This is a debug config
# chunk = 32 if in_dtype == "float16" else 64
shared_scope = "shared.dyn"
block_M = block_row_warps * warp_row_tiles
block_N = block_col_warps * warp_col_tiles
block_K = chunk
A_shape = (M, K)
B_shape = (N, K)
A_shared_shape = (block_M, block_K)
B_shared_shape = (block_N, block_K)
C_shared_shape = (
block_M,
block_N,
)
warp_size = 32
threads = warp_size * (block_row_warps * block_col_warps)
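# Per-thread fragment sizes: each micro tile (e.g. 16x16) is distributed across the 32 lanes of a warp, so every lane holds (16 * 16) / 32 = 8 elements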
local_size_a = (micro_size_x * micro_size_k) // warp_size
local_size_b = (micro_size_y * micro_size_k) // warp_size
local_size_c = (micro_size_x * micro_size_y) // warp_size
warp_rows = warp_row_tiles // micro_size_x
warp_cols = warp_col_tiles // micro_size_y
# MMA Wrapper to Auto Generate Code for MMA
mma_emitter = TensorCoreIntrinEmitter(
a_dtype=in_dtype,
b_dtype=in_dtype,
accum_dtype=accum_dtype,
a_transposed=False,
b_transposed=True,
block_row_warps=block_row_warps,
block_col_warps=block_col_warps,
warp_row_tiles=warp_row_tiles,
warp_col_tiles=warp_col_tiles,
chunk=chunk,
)
@T.prim_func
def main(
A: T.Buffer(A_shape, in_dtype),
B: T.Buffer(B_shape, in_dtype),
C: T.Buffer((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)
T.annotate_layout({
A_shared: make_swizzle_layout(A_shared),
B_shared: make_swizzle_layout(B_shared),
})
# Improve L2 Cache
T.use_swizzle(panel_size=10, enable=enable_rasteration)
T.clear(C_local)
for ko in T.Pipelined((K // block_K), num_stages=stage):
# Load A into shared memory
for i, k in T.Parallel(block_M, block_K):
A_shared[i, k] = A[by * block_M + i, ko * block_K + k]
# Load B into shared memory
for j, k in T.Parallel(block_N, block_K):
B_shared[j, k] = B[bx * block_N + j, ko * block_K + k]
for ki in T.serial(0, (block_K // micro_size_k)):
# Load A into fragment
mma_emitter.ldmatrix_a(A_local, A_shared, ki)
# Load B into fragment
mma_emitter.ldmatrix_b(B_local, B_shared, ki)
# Perform Matrix Multiplication
mma_emitter.mma(A_local, B_local, C_local)
# Perform STMatrix
mma_emitter.stmatrix(C_local, C_shared)
# Store shared into global
for i, j in T.Parallel(block_M, block_N):
C[by * block_M + i, bx * block_N + j] = C_shared[i, j]
return main
def ref_program(A, B):
"""Reference matrix multiplication program."""
return torch.matmul(A, B.T)
def get_configs(M, N, K, with_roller=False):
"""
Generate a list of configuration dictionaries that will be used for tuning.
Parameters
----------
with_roller : bool
Whether to enable the carver roller to deduce the search space
Returns
-------
list of dict
Each configuration dict includes various block sizes, pipeline stages,
thread numbers, and other parameters to explore during autotuning.
"""
if with_roller:
from tilelang.carver.template import MatmulTemplate
from tilelang.carver.arch import CUDA
from tilelang.carver.roller.rasterization import NoRasterization
arch = CUDA("cuda")
topk = 10
carve_template = MatmulTemplate(
M=M,
N=N,
K=K,
in_dtype="float16",
out_dtype="float16",
accum_dtype="float16",
).with_arch(arch)
func = carve_template.equivalent_function()
assert func is not None, "Function is None"
roller_hints = carve_template.recommend_hints(topk=topk)
if roller_hints is None:
raise ValueError("No Roller Hints Found for TensorCore Scheduling")
configs = []
for hint in roller_hints:
config = {}
block_m, block_n = hint.block
warp_m, warp_n = hint.warp
config["block_row_warps"] = block_m // warp_m
config["block_col_warps"] = block_n // warp_n
config["warp_row_tiles"] = warp_m
config["warp_col_tiles"] = warp_n
config["chunk"] = hint.rstep[0]
config["stage"] = hint.pipeline_stage
config["enable_rasteration"] = hint.rasterization_plan is not NoRasterization
configs.append(config)
for config in configs:
print(config)
else:
block_row_warps = [1, 2, 4]
block_col_warps = [1, 2, 4]
warp_row_tiles = [16, 32, 64, 128]
warp_col_tiles = [16, 32, 64, 128]
chunk = [32, 64, 128, 256]
stage = [0, 2]
enable_rasteration = [True, False]
_configs = list(
itertools.product(block_row_warps, block_col_warps, warp_row_tiles, warp_col_tiles,
chunk, stage, enable_rasteration))
configs = [{
"block_row_warps": c[0],
"block_col_warps": c[1],
"warp_row_tiles": c[2],
"warp_col_tiles": c[3],
"chunk": c[4],
"stage": c[5],
"enable_rasteration": c[6],
} for c in _configs]
return configs
def matmul(M,
N,
K,
in_dtype="float16",
out_dtype="float16",
accum_dtype="float16",
with_roller=False):
"""Create an autotuned tensor core matrix multiplication kernel."""
@autotune(
configs=get_configs(M, N, K, with_roller),
keys=[
"block_row_warps",
"block_col_warps",
"warp_row_tiles",
"warp_col_tiles",
"chunk",
"enable_rasteration",
"stage",
],
warmup=3,
rep=5,
)
@jit(
out_idx=[2],
supply_type=tl.TensorSupplyType.Integer,
ref_prog=ref_program,
skip_check=True,
profiler="auto",
target="auto",
)
def kernel(
block_row_warps=None,
block_col_warps=None,
warp_row_tiles=None,
warp_col_tiles=None,
chunk=None,
stage=None,
enable_rasteration=None,
):
return tl_matmul(
M,
N,
K,
in_dtype=in_dtype,
out_dtype=out_dtype,
accum_dtype=accum_dtype,
block_row_warps=block_row_warps,
block_col_warps=block_col_warps,
warp_row_tiles=warp_row_tiles,
warp_col_tiles=warp_col_tiles,
chunk=chunk,
stage=stage,
enable_rasteration=enable_rasteration,
)
return kernel()
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Autotuned TensorCore MatMul Benchmark")
parser.add_argument("--m", type=int, default=16384, help="Matrix dimension M")
parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N")
parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K")
parser.add_argument(
"--with_roller",
action="store_true",
help="Whether to use roller to deduce search spaces")
parser.add_argument(
"--dtype", type=str, default="float16", choices=["float16", "int8"], help="Input data type")
args = parser.parse_args()
M, N, K = args.m, args.n, args.k
in_dtype = args.dtype
out_dtype = "float32" if in_dtype == "int8" else "float16"
accum_dtype = "float32" if in_dtype == "int8" else "float16"
with_roller = args.with_roller
# Compute total floating-point operations
total_flops = 2 * M * N * K
# Run autotuning
best_latency, best_config, ref_latency = matmul(M, N, K, in_dtype, out_dtype, accum_dtype,
with_roller)
# Print benchmark results
print(f"Best latency (s): {best_latency}")
print(f"Best TFlops: {total_flops / best_latency * 1e-9:.3f}")
print(f"Best config: {best_config}")
print(f"Reference TFlops: {total_flops / ref_latency * 1e-9:.3f}")
......@@ -68,9 +68,7 @@ def get_roller_hints_from_func(func_or_module: Union[tir.PrimFunc, IRModule],
tags = None
if tags and tensorized_func:
policy = TensorCorePolicy.from_prim_func(func=tensorized_func, arch=arch, tags=tags)
roller_hints = policy.emit_config(topk)
else:
roller_hints = None
roller_hints = policy.emit_config(topk)
return roller_hints
......
......@@ -318,6 +318,8 @@ class TensorCoreIntrinEmitter(object):
BLOCK_M = block_row_warps * warp_rows
BLOCK_N = block_col_warps * warp_cols
M_DIM, N_DIM = self.M_DIM, self.N_DIM
C_buf_dims = len(C_buf.shape)
assert C_buf_dims in {2, 4}, "C_buf should be 2D or 4D"
current_frame = T.KernelLaunchFrame.Current()
thread_binding = current_frame.get_thread_binding()
......@@ -334,9 +336,15 @@ class TensorCoreIntrinEmitter(object):
for local_id_i in T.vectorized(2):
local_id = local_id_o * 2 + local_id_i
row, col = T.meta_var(mma_store_index_map(tx, local_id))
C_buf[warp_m * warp_rows + i, warp_n * warp_cols + j, row,
col] = C_local_buf[i * (warp_cols * local_size_out) +
j * local_size_out + local_id]
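# Dispatch on the output buffer rank: 2D is a flat (block_M, block_N) tile addressed by absolute row/col, 4D keeps the (tile_i, tile_j, M_DIM, N_DIM) micro-tile layout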
if C_buf_dims == 2:
C_buf[(warp_m * warp_rows + i) * M_DIM + row,
(warp_n * warp_cols + j) * N_DIM +
col] = C_local_buf[i * (warp_cols * local_size_out) +
j * local_size_out + local_id]
else:
C_buf[warp_m * warp_rows + i, warp_n * warp_cols + j, row,
col] = C_local_buf[i * (warp_cols * local_size_out) +
j * local_size_out + local_id]
@T.macro
def _warp_stmatrix_global(C_local_buf, C_buf, thread_binding):
......
......@@ -91,9 +91,3 @@ def c2d_im2col(
dilation,
pad,
)
class GemmWarpPolicy:
Square = 0
FullRow = 1
FullCol = 2
......@@ -2,15 +2,10 @@
# Licensed under the MIT License.
"""The language interface for tl programs."""
from tilelang.primitives.gemm.base import GemmWarpPolicy
from tvm import tir
class GemmWarpPolicy:
Square = 0
FullRow = 1
FullCol = 2
def gemm(
A: tir.Buffer,
B: tir.Buffer,
......
......@@ -126,6 +126,33 @@ class GemmWarpPolicy(IntEnum):
return m_warp, n_warp
@classmethod
def from_warp_partition(cls, m_warp: int, n_warp: int) -> 'GemmWarpPolicy':
"""
Determine the warp policy based on the given warp partitioning.
Args:
m_warp (int): Number of warps in the row dimension
n_warp (int): Number of warps in the column dimension
Returns:
GemmWarpPolicy: The corresponding warp policy
Examples:
>>> GemmWarpPolicy.from_warp_partition(4, 1)  # All warps in rows
GemmWarpPolicy.FullRow
>>> GemmWarpPolicy.from_warp_partition(1, 4)  # All warps in columns
GemmWarpPolicy.FullCol
>>> GemmWarpPolicy.from_warp_partition(2, 2)  # Balanced distribution
GemmWarpPolicy.Square
"""
if n_warp == 1 and m_warp > 1:
return cls.FullRow
elif m_warp == 1 and n_warp > 1:
return cls.FullCol
else:
return cls.Square
@dataclass
class GemmBaseParams:
......