Commit 853898a7 authored by yyttt6, committed by LeiWang1999

[Tools] Summarize TFLOPS Information from a tilelang program (#321)

* refactor autotune

* refactor autotune

* refactor autotune

* refactor autotune

* format init.py

* add tutorial for autotune

* merge

* merge

* format analyzer

* add readme for analyzer

* format

* [Tools] Summarize TFLOPS Information from a tilelang program

* Summarize TFLOPS Information from a tilelang program
parent 4b705eb2
# TVM IR Performance Analyzer
A performance analysis toolkit for TVM IR modules. It provides hardware-aware performance metrics, including FLOPs, global memory traffic, memory bandwidth, and estimated execution time.
## Features
- **Operation Analysis**: Supports arbitrary operations expressed in TVM IR (including GEMM and convolution)
- **Memory Traffic Calculation**: Tracks global memory transfers
- **Architecture-aware Metrics**: Pre-configured with NVIDIA GPU architectures (Ampere, Ada Lovelace)
- **Performance Estimation**: Predicts execution time using a roofline model
- **TVM Integration**: Works with TVM IRModule and PrimFunc
## Quick Start
### GEMM Analysis Example
```python
import tilelang.language as T
from tilelang.tools import Analyzer
from tilelang.carver.arch import CUDA

M = N = K = 1024


def kernel(block_M=128, block_N=128, block_K=32, num_stages=3, thread_num=128):

    @T.prim_func
    def main(A: T.Tensor((M, K), "float16"),
             B: T.Tensor((N, K), "float16"),
             C: T.Tensor((M, N), "float")):
        ...  # kernel definition elided

    return main


cuda_device = CUDA("cuda")
result = Analyzer.analysis(kernel(), cuda_device)
print(result)
```
### Convolution Analysis Example
```python
import tilelang.language as T
from tilelang.tools import Analyzer
from tilelang.carver.arch import CUDA


def kernel(N=64, C=256, H=512, W=512, F=512, K=3, block_M=64, block_N=128):

    @T.prim_func
    def main(data: T.Tensor((N, H, W, C), "float16"),
             kernel: T.Tensor((K, K, C, F), "float16"),
             out: T.Tensor((N, (H - K + 1), (W - K + 1), F), "float")):
        ...  # convolution kernel definition elided

    return main


cuda_device = CUDA("cuda")
result = Analyzer.analysis(kernel(), cuda_device)
print(result)
```
## API Documentation
### `AnalysisResult` Class
```python
@dataclass(frozen=True)
class AnalysisResult:
    total_flops: int          # Total floating-point operations
    total_global_bytes: int   # Global memory traffic in bytes
    estimated_time: float     # Predicted execution time (seconds)
    tflops: float             # Achieved TFLOPS
    bandwidth_GBps: float     # Achieved memory bandwidth (GB/s)
```
### `Analyzer` Class Methods
#### `analysis(fn, device)`
* Parameters:
  * `fn`: TVM IRModule or PrimFunc
  * `device`: Device configuration object (e.g., `CUDA("cuda")`)
* Returns: `AnalysisResult` (see the example below)
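The returned `AnalysisResult` is a frozen dataclass, so its fields can be read directly. Continuing the GEMM Quick Start above:

```python
result = Analyzer.analysis(kernel(), cuda_device)
print(f"Total FLOPs:        {result.total_flops}")
print(f"Global bytes:       {result.total_global_bytes}")
print(f"Estimated time (s): {result.estimated_time}")
print(f"Achieved TFLOPS:    {result.tflops}")
print(f"Bandwidth (GB/s):   {result.bandwidth_GBps}")
```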
#### Supported Architectures
```python
# Extendable to custom hardware via:
# "compute_capability": (cores_per_SM, clock_GHz, flops_per_cycle, max_SM_count)
ARCH_CONFIGS = {
    "80": (128, 1.41, 2, 108),  # A100
    "86": (128, 1.70, 2, 84),   # RTX 3080
    "89": (128, 2.52, 2, 128),  # RTX 4090
}
```
## Implementation Details
### Performance Model
Uses a roofline model with two bounds:
1. **Compute Bound**: `Time = Total FLOPs / (SM Count × Cores/SM × Clock × FLOPs/Cycle)`
2. **Memory Bound**: `Time = Total Global Bytes / Peak Bandwidth`

The estimated execution time is the larger of the two, as sketched below.
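The sketch below mirrors that arithmetic with illustrative placeholder numbers (the FLOP and byte counts and the bandwidth figure are assumptions; the peak-FLOPs term uses the `"80"` entry from `ARCH_CONFIGS`):

```python
# Roofline sketch with placeholder inputs, not measured values.
total_flops = 4.0e12           # assumed FLOP count of the kernel
total_global_bytes = 1.0e9     # assumed global-memory traffic in bytes

peak_flops = 108 * 128 * 1.41e9 * 2   # SMs x cores/SM x clock x FLOPs/cycle ("80" entry)
peak_bandwidth = 2.0e12               # bytes/s, placeholder for the device's bandwidth

compute_time = total_flops / peak_flops          # compute-bound estimate
mem_time = total_global_bytes / peak_bandwidth   # memory-bound estimate
estimated_time = max(compute_time, mem_time)     # the binding constraint wins
print(estimated_time)
```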
### IR Analysis Pass
1. **Traversal**: Walks the TVM IR using `ir_transform` (a minimal sketch of this pattern follows the list)
2. **Operation Detection**:
   - Counts FLOPs for compute operations
   - Calculates memory traffic for global memory operations
3. **Loop Handling**:
   - Tracks nested loops so per-call counts are scaled by loop extents
   - Accounts for block/grid dimensions
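As a minimal sketch of the traversal pattern (assuming TVM is installed and `mod` is an IRModule; the statistic collected here is just a call count, while the full implementation appears in the analyzer source below):

```python
import tvm
from tvm.tir.stmt_functor import ir_transform

gemm_calls = []

def _pre_visit(stmt):
    # Record every tl.gemm intrinsic encountered while walking the body.
    if isinstance(stmt, tvm.tir.Evaluate) and isinstance(stmt.value, tvm.tir.Call):
        if stmt.value.op.name == "tl.gemm":
            gemm_calls.append(stmt.value)
    return None  # returning None leaves the node unchanged

def _post_visit(stmt):
    return None

def _ftransform(f, mod, ctx):
    return f.with_body(ir_transform(f.body, _pre_visit, _post_visit))

# Run as a PrimFunc pass over the module; the "transform" only collects statistics.
tvm.tir.transform.prim_func_pass(_ftransform, opt_level=0)(mod)
print(f"Found {len(gemm_calls)} tl.gemm call(s)")
```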
## Key Metrics Calculation
| Metric | Formula |
|-------------------------|-----------------------------------------|
| FLOPs per GEMM | `2 × M × N × K` |
| Memory Traffic per Copy | `elements × dtype_size × loop_product` |
| Achieved TFLOPS | `total_flops / estimated_time / 1e12` |
| Memory Bandwidth | `total_global_bytes / estimated_time` |
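Worked example for the 1024 × 1024 × 1024 GEMM from the Quick Start, assuming it lands in the compute-bound regime on the `"80"` (A100) configuration (the memory-bound term, which depends on the bandwidth reported by the device object, is omitted here):

```python
M = N = K = 1024
total_flops = 2 * M * N * K                    # 2,147,483,648 FLOPs
peak_flops = 108 * 128 * 1.41e9 * 2            # ~3.9e13 FLOP/s ("80" entry)
estimated_time = total_flops / peak_flops      # ~5.5e-5 s if compute-bound
tflops = total_flops / estimated_time / 1e12   # ~39 TFLOPS, i.e. the peak, when compute-bound
```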
## Limitations
1. Requires memory operations to be properly annotated in the IR
2. Assumes perfect memory coalescing and no bank conflicts
## Supported Operations
Any kernel expressed in TVM IR can be analyzed. In the current pass, FLOPs are counted from `tl.gemm` calls and global-memory traffic from `tl.copy` calls (see the implementation below).
import tilelang.language as T
from tilelang.tools import Analyzer
from tilelang.carver.arch import CUDA
from tilelang.layout import make_swizzled_layout

N = 64
C = 256
H = 512
W = 512
F = 512
K = 3
S = 1
D = 1
P = 1


def check_hopper():
    # if not torch.cuda.is_available():
    #     return None
    # props = torch.cuda.get_device_properties(0)
    # compute_capability = props.major, props.minor
    # return compute_capability == (9, 0)
    return False


def kernel(N,
           C,
           H,
           W,
           F,
           K,
           S,
           D,
           P,
           block_M,
           block_N,
           block_K,
           num_stages,
           threads,
           dtype="float16",
           accum_dtype="float"):
    KH, KW = K, K
    OH = (H + 2 * P - D * (K - 1) - 1) // S + 1
    OW = (W + 2 * P - D * (K - 1) - 1) // S + 1
    dtype = "float16"
    accum_dtype = "float"
    is_hopper = check_hopper()

    @T.prim_func
    def main(
            data: T.Tensor((N, H, W, C), dtype),
            kernel: T.Tensor((KH, KW, C, F), dtype),
            out: T.Tensor((N, OH, OW, F), dtype),
    ):
        with T.Kernel(
                T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M),
                threads=threads) as (bx, by):
            data_shared = T.alloc_shared((block_M, block_K), dtype)
            kernel_shared = T.alloc_shared((block_K, block_N), dtype)
            out_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            out_shared = T.alloc_shared((block_M, block_N), dtype)

            kernel_flat = T.Tensor((KH * KW * C, F), dtype, kernel.data)
            out_flat = T.Tensor((N * OH * OW, F), dtype, out.data)

            T.annotate_layout({
                out_shared: make_swizzled_layout(out_shared),
                data_shared: make_swizzled_layout(data_shared),
                kernel_shared: make_swizzled_layout(kernel_shared),
            })

            T.clear(out_local)
            for k_iter in T.Pipelined(T.ceildiv(KH * KW * C, block_K), num_stages=num_stages):
                if is_hopper:
                    T.c2d_im2col(data, data_shared, by, k_iter, KH, S, D, P)
                else:
                    for i, j in T.Parallel(block_M, block_K):
                        k = k_iter * block_K + j
                        m = by * block_M + i
                        access_h = m % (OH * OW) // OW * S + k // (KW * C) * D - P
                        access_w = m % OW * S + k // C % KW * D - P
                        in_bound = ((access_h >= 0) and (access_w >= 0) and (access_h < H) and
                                    (access_w < W))
                        data_shared[i, j] = T.if_then_else(
                            in_bound, data[m // (OH * OW), access_h, access_w, k % C], 0)
                T.copy(kernel_flat[k_iter * block_K, bx * block_N], kernel_shared)
                T.gemm(data_shared, kernel_shared, out_local)
            T.copy(out_local, out_shared)
            T.copy(out_shared, out_flat[by * block_M, bx * block_N])

    return main


my_func = kernel(N, C, H, W, F, K, S, D, P, 64, 128, 32, 3, 256)
cuda_device = CUDA("cuda")
result = Analyzer.analysis(my_func, cuda_device)
print(result)

import tilelang.language as T
from tilelang.tools import Analyzer
from tilelang.carver.arch import CUDA

M = N = K = 1024


def kernel(
    block_M=None,
    block_N=None,
    block_K=None,
    num_stages=None,
    thread_num=None,
    enable_rasteration=None,
):
    dtype = "float16"
    accum_dtype = "float"

    @T.prim_func
    def main(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((N, K), dtype),
            C: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_N, block_K), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            C_shared = T.alloc_shared((block_M, block_N), dtype)
            T.use_swizzle(panel_size=10, enable=enable_rasteration)
            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                T.copy(A[by * block_M, k * block_K], A_shared)
                T.copy(B[bx * block_N, k * block_K], B_shared)
                T.gemm(
                    A_shared,
                    B_shared,
                    C_local,
                    transpose_B=True,
                )
            T.copy(C_local, C_shared)
            T.copy(C_shared, C[by * block_M, bx * block_N])

    return main


my_func = kernel(128, 128, 32, 3, 128, True)
cuda_device = CUDA("cuda")
result = Analyzer.analysis(my_func, cuda_device)
print(result)

import numpy as np
from dataclasses import dataclass
from tilelang import tvm
from tvm.tir.stmt_functor import ir_transform

# Configuration for different hardware architectures.
# Each entry contains: (cores per SM, default clock (GHz), FLOPs per cycle, max SM count)
ARCH_CONFIGS = {"80": (128, 1.41, 2, 108), "86": (128, 1.70, 2, 84), "89": (128, 2.52, 2, 128)}


@dataclass(frozen=True)
class AnalysisResult:
    """
    A data class to store the results of the analysis.

    Attributes:
        total_flops: Total floating-point operations.
        total_global_bytes: Total bytes transferred to/from global memory.
        estimated_time: Estimated execution time (seconds).
        tflops: Achieved TFLOPS (trillions of FLOPs per second).
        bandwidth_GBps: Achieved memory bandwidth in GB/s.
    """
    total_flops: int
    total_global_bytes: int
    estimated_time: float
    tflops: float
    bandwidth_GBps: float


class Analyzer:
    """
    A class to analyze the performance of a TVM IR module.
    It calculates metrics such as FLOPs, memory bandwidth, and estimated execution time.
    """

    def __init__(self, fn, device):
        """
        Initialize the Analyzer.

        Args:
            fn: A TVM IRModule or PrimFunc to analyze.
            device: The target device information.
        """
        if isinstance(fn, tvm.tir.function.PrimFunc):
            self.fn = tvm.IRModule({"main": fn})
        else:
            self.fn = fn
        self.device = device
        self.total_flops = 0  # Total floating-point operations
        self.total_global_bytes = 0  # Total global memory bytes
        self.block_counts = {"blockIdx.x": 1, "blockIdx.y": 1}  # Block dimensions
        self.loop_stack = []  # Stack to track nested loops
        self.global_buffers = set()  # Set of global memory buffers

    def _analyze_copy(self, call):
        """
        Analyze memory copy operations (e.g., tl.copy).

        Args:
            call: A TVM Call node representing the copy operation.
        """
        src_buffer = call.args[0].args[0].buffer
        dst_buffer = call.args[1].args[0].buffer

        # Determine if the source or destination is a global buffer
        if src_buffer in self.global_buffers:
            buffer_region = call.args[0]
        elif dst_buffer in self.global_buffers:
            buffer_region = call.args[1]
        else:
            return

        # Calculate the number of elements being copied
        elements = 1
        for r in range(2, len(buffer_region.args)):
            elements *= buffer_region.args[r]
        dtype_size = np.dtype(buffer_region.args[0].buffer.dtype).itemsize  # Size of the data type
        bytes_transferred = elements * dtype_size  # Total bytes transferred

        # Account for loop and block dimensions
        loop_product = 1
        for extent in self.loop_stack:
            loop_product *= extent.value if hasattr(extent, 'value') else extent
        total_blocks = self.block_counts["blockIdx.x"] * self.block_counts["blockIdx.y"]
        total_bytes = bytes_transferred * loop_product * total_blocks
        self.total_global_bytes += total_bytes

    def _analyze_gemm(self, call):
        """
        Analyze matrix multiplication (GEMM) operations (e.g., tl.gemm).

        Args:
            call: A TVM Call node representing the GEMM operation.
        """
        M = call.args[5].value
        N = call.args[6].value
        K = call.args[7].value
        flops_per_call = 2 * M * N * K  # FLOPs for one GEMM operation

        # Account for loop and block dimensions
        loop_product = 1
        for extent in self.loop_stack:
            loop_product *= extent.value if hasattr(extent, 'value') else extent
        total_blocks = self.block_counts["blockIdx.x"] * self.block_counts["blockIdx.y"]
        self.total_flops += flops_per_call * loop_product * total_blocks

    def ir_pass(self):
        """
        Traverse and transform the IR module to extract performance-related information.

        Returns:
            self: The Analyzer instance.
        """

        def _ftransform(f, mod, ctx):
            # Initialize the set of global buffers
            self.global_buffers = set(f.buffer_map.values())

            def _pre_visit(stmt):
                """
                Pre-visit callback for IR nodes.

                Args:
                    stmt: The current IR node being visited.
                """
                if isinstance(stmt, tvm.tir.AttrStmt):
                    # Handle thread extent attributes
                    if stmt.attr_key == "thread_extent":
                        iter_var = stmt.node
                        thread_tag = iter_var.thread_tag
                        if thread_tag in self.block_counts:
                            extent = stmt.value.value if hasattr(stmt.value, 'value') else stmt.value
                            self.block_counts[thread_tag] = extent
                elif isinstance(stmt, tvm.tir.For):
                    # Push loop extent onto the stack
                    self.loop_stack.append(stmt.extent)
                elif isinstance(stmt, tvm.tir.Evaluate):
                    # Handle Evaluate nodes containing calls
                    value = stmt.value
                    if isinstance(value, tvm.tir.Call):
                        if value.op.name == "tl.copy":
                            self._analyze_copy(value)
                        elif value.op.name == "tl.gemm":
                            self._analyze_gemm(value)
                return None

            def _post_visit(stmt):
                """
                Post-visit callback for IR nodes.

                Args:
                    stmt: The current IR node being visited.
                """
                if isinstance(stmt, tvm.tir.For) and self.loop_stack:
                    self.loop_stack.pop()
                return None

            # Use IR transformation to traverse and modify the function body
            new_body = ir_transform(f.body, _pre_visit, _post_visit)
            return f.with_body(new_body)

        # Apply the custom PrimFunc pass
        tvm.tir.transform.prim_func_pass(_ftransform, opt_level=0)(self.fn)
        return self

    def calculate(self) -> AnalysisResult:
        """
        Calculate performance metrics based on the analysis.

        Returns:
            AnalysisResult: The calculated performance metrics.
        """

        def get_peak_tflops(device) -> float:
            """
            Get the peak TFLOPS for the target device.

            Args:
                device: The target device information.

            Returns:
                float: The peak TFLOPS.
            """
            arch_key = device.compute_capability[:2]
            if arch_key not in ARCH_CONFIGS:
                raise ValueError(f"Unsupported compute capability: {device.compute_capability}")
            cores_per_sm, default_clock, flops_per_cycle, compute_max_core = ARCH_CONFIGS[arch_key]
            total_cores = compute_max_core * cores_per_sm
            tflops = (total_cores * default_clock * flops_per_cycle) / 1e3
            return round(tflops, 1)

        # Calculate memory bandwidth and peak TFLOPS
        bandwidth_GBps = self.device.bandwidth[1] / 1000
        peak_tflops = get_peak_tflops(self.device)

        # Estimate memory and compute times
        mem_time = self.total_global_bytes / (bandwidth_GBps * 1e9)
        compute_time = self.total_flops / (peak_tflops * 1e12)
        estimated_time = max(mem_time, compute_time)  # Use the larger of the two times

        # Return the analysis results
        return AnalysisResult(
            total_flops=self.total_flops,
            total_global_bytes=self.total_global_bytes,
            estimated_time=float(estimated_time),
            tflops=float(self.total_flops / estimated_time / 1e12),
            bandwidth_GBps=bandwidth_GBps,
        )

    @classmethod
    def analysis(cls, fn, device):
        """
        Perform a full analysis of the given IR module or PrimFunc.

        Args:
            fn: A TVM IRModule or PrimFunc to analyze.
            device: The target device information.

        Returns:
            AnalysisResult: The calculated performance metrics.
        """
        return cls(fn, device).ir_pass().calculate()

from .plot_layout import plot_layout  # noqa: F401
from .Analyzer import *