Commit c0378aa9 authored by Lei Wang, committed by LeiWang1999

[Feat] Enhance CUDA Property Handling (#322)



* [Enhancement] Introduce CUDA driver module and refactor CUDA device handling

- Added a new `cuda_driver` module to encapsulate CUDA device properties and functionalities.
- Updated `CUDA` class in `cuda.py` to utilize the new driver for fetching device name and shared memory capabilities.
- Introduced `get_device_name` and `get_shared_memory_per_block` functions in the `cuda_driver` module for improved device property management (a usage sketch follows below).
- This refactor enhances code organization and maintainability while improving the handling of CUDA device attributes.
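A minimal usage sketch of the new helpers, assuming the module is importable as shown (the exact import path is not spelled out in this commit and may differ):

from tilelang.carver.arch import cuda_driver  # assumed location of the new module

# Query basic properties of device 0 via the CUDA runtime wrappers
print(cuda_driver.get_device_name(0))
print(cuda_driver.get_shared_memory_per_block(0, format="kb"))  # per-block shared memory in KiB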

* [Refactor] Clean up whitespace in CUDA-related files

- Removed unnecessary blank lines in `cuda.py`, `__init__.py`, and `cuda_driver.py` to improve code readability and maintainability.
- This change enhances the overall organization of the codebase without altering functionality.

* [Benchmark] Add FP8 Matrix Multiplication Benchmark Script

- Introduced a new benchmark script for FP8 matrix multiplication in `benchmark/matmul_fp8/benchmark_matmul.py`.
- The script includes functions for reference matrix multiplication, configuration generation for autotuning, and an autotuned kernel for performance measurement.
- Added command-line argument parsing for matrix dimensions and the option to enable BitBLAS roller for search space exploration.
- The benchmark computes and prints the best latency and performance metrics, enhancing the benchmarking capabilities for FP8 operations (an example invocation is shown below).
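For illustration only (the dimensions are arbitrary; the flag names come from the script's argument parser), a typical invocation looks like:

python benchmark/matmul_fp8/benchmark_matmul.py --m 4096 --n 4096 --k 4096 --with_roller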

* lint fix

---------
Co-authored-by: LeiWang1999 <wyatuestc@gmail.com>
parent d2f59cfa
import argparse
import itertools
import logging
import tilelang as tl
import tilelang.language as T
from tilelang.autotuner import autotune, jit
# Configure logger
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
def ref_program(A, B):
"""
A reference matrix multiplication program, used to compare performance.
Parameters
----------
A : torch.Tensor
The matrix with shape (M, K).
B : torch.Tensor
The matrix with shape (N, K).
Returns
-------
torch.Tensor
The result of A @ B.T (computed in float32), shape (M, N).
"""
return A.float() @ B.T.float()
def get_configs(M, N, K, with_roller=False):
"""
Generate a list of configuration dictionaries that will be used for tuning.
Parameters
----------
M, N, K : int
The GEMM problem sizes, used by the roller to derive tuning hints.
with_roller : bool
Whether to enable the BitBLAS roller to deduce the search space.
Returns
-------
list of dict
Each configuration dict includes various block sizes, pipeline stages,
thread numbers, and other parameters to explore during autotuning.
"""
if with_roller:
from tilelang.carver.template import MatmulTemplate
from tilelang.carver.arch import CUDA
from tilelang.carver.roller.rasterization import NoRasterization
arch = CUDA("cuda")
topk = 10
carve_template = MatmulTemplate(
M=M,
N=N,
K=K,
in_dtype="float16",
out_dtype="float16",
accum_dtype="float",
).with_arch(arch)
func = carve_template.equivalent_function()
assert func is not None, "Function is None"
roller_hints = carve_template.recommend_hints(topk=topk)
if roller_hints is None:
raise ValueError("No Roller Hints Found for TensorCore Scheduling")
configs = []
for hint in roller_hints:
config = {}
block_m, block_n = hint.block
warp_m, warp_n = hint.warp
# block_rows, block_cols represent the warp partitioning of the thread block
block_rows, block_cols = block_m // warp_m, block_n // warp_n
config["block_M"] = block_m
config["block_N"] = block_n
config["block_K"] = hint.rstep[0]
config["num_stages"] = hint.pipeline_stage
config["thread_num"] = block_rows * block_cols * 32
config["policy"] = T.GemmWarpPolicy.from_warp_partition(block_rows, block_cols)
config["enable_rasteration"] = hint.rasterization_plan is not NoRasterization
configs.append(config)
for config in configs:
print(config)
else:
block_M = [64, 128, 256]
block_N = [64, 128, 256]
block_K = [64, 128]
num_stages = [0, 1, 2, 3]
thread_num = [128, 256]
policy = [T.GemmWarpPolicy.Square]
enable_rasterization = [True, False]
_configs = list(
itertools.product(
block_M,
block_N,
block_K,
num_stages,
thread_num,
policy,
enable_rasterization,
))
configs = [
{
"block_M": c[0],
"block_N": c[1],
"block_K": c[2],
"num_stages": c[3],
"thread_num": c[4],
"policy": c[5],
"enable_rasteration": c[6], # keep param name for backward-compat
} for c in _configs
]
return configs
def matmul(M, N, K, with_roller):
"""
Create an autotuned matrix multiplication kernel for matrices of shape:
- A: (M, K)
- B: (N, K)
- C: (M, N)
Parameters
----------
M : int
The dimension M of the matrix multiplication.
N : int
The dimension N of the matrix multiplication.
K : int
The dimension K of the matrix multiplication.
with_roller : bool
Whether to use the BitBLAS roller to generate the tuning search space.
Returns
-------
The autotuning result, which exposes `latency`, `config`, and `ref_latency`
attributes for the best configuration found and the reference baseline.
"""
# Decorate the kernel with autotune & jit, specifying:
# - The tuning config list
# - Warmup and repetition counts for stable measurement
# - A reference program for correctness verification (skipped here via skip_check)
# - "auto" as the compilation target, resolved from the local hardware
@autotune(
configs=get_configs(M, N, K, with_roller),
warmup=3,
rep=20,
)
@jit(
out_idx=[2],
supply_type=tl.TensorSupplyType.Integer,
ref_prog=ref_program,
skip_check=True,
target="auto",
)
def kernel(
block_M=None,
block_N=None,
block_K=None,
num_stages=None,
thread_num=None,
policy=None,
enable_rasteration=None,
):
"""
The actual kernel to compute C = A @ B^T.
Parameters
----------
block_M : int
Block size in M dimension.
block_N : int
Block size in N dimension.
block_K : int
Block size in K dimension.
num_stages : int
Number of pipelined stages (for asynchronous load).
thread_num : int
Number of threads to use per block.
policy : T.GemmWarpPolicy
Warp partitioning policy used by the GEMM.
enable_rasteration : bool
Whether to enable rasterization (swizzling) optimization.
Returns
-------
Function
A TVM Tensor Language function (T.prim_func) that computes matmul.
"""
# Use FP8 (e4m3) inputs to reduce memory bandwidth,
# accumulate in float32 for better numerical accuracy
dtype = "e4m3_float8"
accum_dtype = "float"
@T.prim_func
def main(
A: T.Tensor((M, K), dtype),
B: T.Tensor((N, K), dtype),
C: T.Tensor((M, N), dtype),
):
"""
The compiled TVM function for block-level matrix multiplication.
- We divide the entire (M, N) domain into blocks of shape
(block_M, block_N).
- Each block has its own allocated shared memory for sub-blocks
of A and B.
- The partial results go into C_local, and then we copy them back
to global memory C.
"""
# Bind x-dimension to block index in N,
# y-dimension to block index in M.
with T.Kernel(
T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
# Allocate shared memory for A sub-block of shape (block_M, block_K)
A_shared = T.alloc_shared((block_M, block_K), dtype)
# Allocate shared memory for B sub-block of shape (block_N, block_K)
B_shared = T.alloc_shared((block_N, block_K), dtype)
# Allocate a local fragment for intermediate accumulation
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
# Allocate shared memory for the C sub-block of shape (block_M, block_N)
C_shared = T.alloc_shared((block_M, block_N), dtype)
# Enable (or disable) swizzling optimization
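# (rasterization reorders the traversal of output tiles so that neighboring blocks reuse data resident in L2)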
T.use_swizzle(panel_size=10, enable=enable_rasteration)
# Clear out the accumulation buffer
T.clear(C_local)
# Loop over sub-blocks in K dimension, pipelined by num_stages
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
# Load a sub-block of A from global memory into A_shared
T.copy(A[by * block_M, k * block_K], A_shared)
# Load a sub-block of B from global memory into B_shared
T.copy(B[bx * block_N, k * block_K], B_shared)
# Perform a partial matrix multiplication:
# C_local += A_shared @ B_shared^T
T.gemm(
A_shared,
B_shared,
C_local,
transpose_B=True,
policy=policy,
)
# Write back the results from C_local to the global memory C
T.copy(C_local, C_shared)
T.copy(C_shared, C[by * block_M, bx * block_N])
return main
return kernel()
if __name__ == "__main__":
# Parse command-line arguments for matrix dimensions
parser = argparse.ArgumentParser(description="Autotuned MatMul Benchmark")
parser.add_argument("--m", type=int, default=16384, help="Matrix dimension M")
parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N")
parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K")
parser.add_argument(
"--with_roller",
action="store_true",
help="Whether to enable BitBLAS roller for search space",
)
args = parser.parse_args()
M, N, K = args.m, args.n, args.k
with_roller = args.with_roller
# Compute total floating-point operations to measure throughput
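# Each of the M*N outputs takes K multiply-add pairs, i.e. 2*K floating-point operations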
total_flops = 2 * M * N * K
# matmul(...) returns (best_latency, best_config, ref_latency)
best_result = matmul(M, N, K, with_roller)
best_latency = best_result.latency
best_config = best_result.config
ref_latency = best_result.ref_latency
# Print out the benchmark results
print(f"Best latency (s): {best_latency}")
print(f"Best TFlops: {total_flops / best_latency * 1e-9:.3f}")
print(f"Best config: {best_config}")
print(f"Reference TFlops: {total_flops / ref_latency * 1e-9:.3f}")
@@ -96,3 +96,4 @@ my_func = kernel(N, C, H, W, F, K, S, D, P, 64, 128, 32, 3, 256)
cuda_device = CUDA("cuda")
result = Analyzer.analysis(my_func, cuda_device)
print(result)
print(f"Analyzed FLOPs: {result.total_flops}")
@@ -45,6 +45,9 @@ def kernel(
my_func = kernel(128, 128, 32, 3, 128, True)
cuda_device = CUDA("cuda")
result = Analyzer.analysis(my_func, cuda_device)
print(result)
print(f"Analyzed FLOPs: {result.total_flops}")
print(f"Expected FLOPs: {2 * M * N * K}")
@@ -129,6 +129,43 @@ def get_best_config(M, N, K, with_roller=False):
return autotuner.run(warmup=3, rep=20)
def get_heuristic_config() -> dict:
# Get CUDA device properties
if not torch.cuda.is_available():
raise RuntimeError("CUDA is not available")
device = torch.cuda.current_device()
sm_major, sm_minor = torch.cuda.get_device_capability(device)
sm_version = sm_major * 10 + sm_minor
print(f"CUDA device capability: {sm_version}")
if sm_version in {80}:
return {
"block_M": 128,
"block_N": 256,
"block_K": 32,
"num_stages": 2,
"thread_num": 128,
"enable_rasteration": True
}
elif sm_version in {90}:
return {
"block_M": 128,
"block_N": 256,
"block_K": 64,
"num_stages": 3,
"thread_num": 256,
"enable_rasteration": True
}
else:
return {
"block_M": 128,
"block_N": 256,
"block_K": 32,
"num_stages": 0,
"thread_num": 128,
"enable_rasteration": True
}
def matmul(M,
N,
K,
@@ -171,13 +208,13 @@ def matmul(M,
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Autotuned MatMul Benchmark")
parser.add_argument("--m", type=int, default=1024, help="Matrix dimension M")
parser.add_argument("--n", type=int, default=1024, help="Matrix dimension N")
parser.add_argument("--k", type=int, default=1024, help="Matrix dimension K")
parser.add_argument("--m", type=int, default=16384, help="Matrix dimension M")
parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N")
parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K")
parser.add_argument(
"--use_autotune",
action="store_true",
default=True,
default=False,
help="Whether to use autotune for matmul configs")
parser.add_argument(
"--with_roller",
@@ -192,11 +229,21 @@ if __name__ == "__main__":
with_roller = args.with_roller
if use_autotune:
result = get_best_config(M, N, K, with_roller)
print(f"best latency {result.latency}")
kernel = result.kernel
else:
kernel = tl.compile(matmul(M, N, K, 128, 128, 32, 3, 128, True), out_idx=-1)
config = get_heuristic_config()
kernel = tl.compile(matmul(M, N, K, **config), out_idx=-1)
out_c = kernel(a, b)
ref_c = ref_program(a, b)
torch.testing.assert_close(out_c, ref_c, rtol=1e-2, atol=1e-2)
# benchmark
profiler = kernel.get_profiler(tensor_supply_type=tl.TensorSupplyType.Auto)
tilelang_latency = profiler.do_bench()
ref_latency = profiler.do_bench(ref_program)
print(f"TileLang latency: {tilelang_latency}")
print(f"Ref latency: {ref_latency}")
print(f"TileLang TFlops: {2 * M * N * K / tilelang_latency * 1e-9}")
print(f"Ref TFlops: {2 * M * N * K / ref_latency * 1e-9}")
@@ -40,7 +40,7 @@ def _init_logger():
logger = logging.getLogger(__name__)
handler = TqdmLoggingHandler()
formatter = logging.Formatter(
fmt="%(asctime)s [TileLang:%(levelname)s]: %(message)s",
fmt="%(asctime)s [TileLang:%(name)s:%(levelname)s]: %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
handler.setFormatter(formatter)
@@ -92,7 +92,7 @@ def get_cuda_device_properties(device_id: int = 0) -> Optional[cudaDeviceProp]:
if ret == 0:
return prop
else:
return None
raise RuntimeError(f"cudaGetDeviceProperties failed with error {ret}")
def get_device_name(device_id: int = 0) -> Optional[str]:
@@ -112,9 +112,9 @@ def get_shared_memory_per_block(device_id: int = 0, format: str = "bytes") -> Op
if format == "bytes":
return shared_mem
elif format == "kb":
return shared_mem // 1024  # integer division
return shared_mem // 1024
elif format == "mb":
return shared_mem // (1024 * 1024)  # integer division
return shared_mem // (1024 * 1024)
else:
raise RuntimeError("Invalid format. Must be one of: bytes, kb, mb")
else:
@@ -144,7 +144,9 @@ def get_device_attribute(attr: int, device_id: int = 0) -> int:
def get_max_dynamic_shared_size_bytes(device_id: int = 0, format: str = "bytes") -> Optional[int]:
"""获取设备支持的最大动态共享内存大小"""
"""
Get the maximum dynamic shared memory size in bytes, kilobytes, or megabytes.
"""
assert format in ["bytes", "kb", "mb"], "Invalid format. Must be one of: bytes, kb, mb"
prop = get_cuda_device_properties(device_id)
if prop:
@@ -153,9 +155,9 @@ def get_max_dynamic_shared_size_bytes(device_id: int = 0, format: str = "bytes")
if format == "bytes":
return shared_mem
elif format == "kb":
return shared_mem // 1024  # integer division
return shared_mem // 1024
elif format == "mb":
return shared_mem // (1024 * 1024)  # integer division
return shared_mem // (1024 * 1024)
else:
raise RuntimeError("Invalid format. Must be one of: bytes, kb, mb")
else:
@@ -2,11 +2,15 @@ import numpy as np
from dataclasses import dataclass
from tilelang import tvm
from tvm.tir.stmt_functor import ir_transform
import logging
from typing import Optional
# Configuration for different hardware architectures.
# Each entry contains: (cores per SM, default clock (GHz), FLOPs per cycle, max SM count)
ARCH_CONFIGS = {"80": (128, 1.41, 2, 108), "86": (128, 1.70, 2, 84), "89": (128, 2.52, 2, 128)}
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@dataclass(frozen=True)
class AnalysisResult:
@@ -22,8 +26,8 @@ class AnalysisResult:
total_flops: int
total_global_bytes: int
estimated_time: float
tflops: float
bandwidth_GBps: float
expected_tflops: float
expected_bandwidth_GBps: float
class Analyzer:
@@ -164,7 +168,7 @@ class Analyzer:
AnalysisResult: The calculated performance metrics.
"""
def get_peak_tflops(device) -> float:
def get_peak_tflops(device) -> Optional[float]:
"""
Get the peak TFLOPS for the target device.
Args:
@@ -174,7 +178,10 @@
"""
arch_key = device.compute_capability[:2]
if arch_key not in ARCH_CONFIGS:
raise ValueError(f"Unsupported compute capability: {device.compute_capability}")
logger.info(
f"Unsupported compute capability: {device.compute_capability}, theoretical peak tflops will be None"
)
return None
cores_per_sm, default_clock, flops_per_cycle, compute_max_core = ARCH_CONFIGS[arch_key]
total_cores = compute_max_core * cores_per_sm
@@ -187,16 +194,16 @@
# Estimate memory and compute times
mem_time = self.total_global_bytes / (bandwidth_GBps * 1e9)
compute_time = self.total_flops / (peak_tflops * 1e12)
estimated_time = max(mem_time, compute_time) # Use the larger of the two times
compute_time = self.total_flops / (peak_tflops * 1e12) if peak_tflops else None
estimated_time = max(mem_time, compute_time) if peak_tflops else mem_time
# Return the analysis results
return AnalysisResult(
total_flops=self.total_flops,
total_global_bytes=self.total_global_bytes,
estimated_time=float(estimated_time),
tflops=float(self.total_flops / estimated_time / 1e12),
bandwidth_GBps=bandwidth_GBps)
estimated_time=estimated_time,
expected_tflops=peak_tflops,
expected_bandwidth_GBps=bandwidth_GBps)
@classmethod
def analysis(cls, fn, device):