Unverified Commit 91d5ef54 authored by Cunxiao Ni, committed by GitHub

[Profiler] Adds CUPTI profiler support (#936)



* [Profiler] Adds CUPTI profiler support

* format

* refactor cupti profiler

* format

* refactor

* refactor

* fix lint

* fix lint

* refactor

* add profiler tests

---------
Co-authored-by: LeiWang1999 <leiwang1999@outlook.com>
parent ac8c9afc
......@@ -51,6 +51,12 @@ def main():
print("CUDA Source:")
print(kernel.get_kernel_source())
# benchmark
profiler = kernel.get_profiler()
latency = profiler.do_bench(backend="cupti")
# latency = profiler.do_bench()
print(f"tilelang Latency: {latency}ms")
if __name__ == "__main__":
main()
import tilelang
import tilelang.testing
import tilelang.language as T


@tilelang.jit(out_idx=[-1])
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
@T.prim_func
def gemm(
A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype),
C: T.Tensor((M, N), dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_K, block_N), dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(B[k * block_K, bx * block_N], B_shared)
T.gemm(A_shared, B_shared, C_local)
T.copy(C_local, C[by * block_M, bx * block_N])
    return gemm


def test_profiler():
kernel = matmul(1024, 1024, 1024, 128, 128, 32)
import torch
a = torch.randn(1024, 1024).cuda().half()
b = torch.randn(1024, 1024).cuda().half()
c = kernel(a, b)
ref_c = a @ b
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
# benchmark
profiler = kernel.get_profiler()
# use cupti backend
cupti_latency = profiler.do_bench(backend="cupti")
# use event backend
event_latency = profiler.do_bench(backend="event")
print(f"cupti Latency: {cupti_latency}ms")
print(f"event Latency: {event_latency}ms")
if __name__ == "__main__":
tilelang.testing.main()
......@@ -175,7 +175,7 @@ from cython_wrapper import CythonKernelWrapper
class CythonKernelAdapter(BaseKernelAdapter):
"""Adapter class that converts TVM/TIR functions to callable CUDA kernels using ctypes.
"""Adapter class that converts TVM/TIR functions to callable CUDA kernels using cython.
This adapter handles:
1. Converting TIR functions to compiled CUDA libraries
......
"""The profiler and convert to torch utils"""
from typing import List, Optional, Callable, Any
from typing import List, Optional, Callable, Any, Literal
from functools import partial
import torch
from contextlib import suppress
......@@ -223,6 +223,9 @@ class Profiler:
n_warmup: int = 1,
n_repeat: int = 1,
input_tensors: List[torch.Tensor] = None,
backend: Literal["event", "cupti"] = "event",
quantiles: Optional[List[float]] = None,
return_mode: Literal["min", "max", "mean", "median"] = "mean",
) -> float:
"""Benchmarks the execution time of a given function.
......@@ -251,6 +254,9 @@ class Profiler:
rep=rep,
_n_warmup=n_warmup,
_n_repeat=n_repeat,
quantiles=quantiles,
backend=backend,
return_mode=return_mode,
)
elif profiler == "tvm":
assert func is not None, "func should not be None"
......
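A hedged usage sketch of the extended Profiler.do_bench signature above; the profiler handle follows the matmul example earlier in this commit, and the exact return shapes are assumptions, not taken from the diff:

profiler = kernel.get_profiler()
# Single aggregated latency via the CUDA-event backend
median_ms = profiler.do_bench(backend="event", return_mode="median")
# Quantiles instead of a single aggregate (assumed to return one value per quantile)
p50, p95 = profiler.do_bench(backend="event", quantiles=[0.5, 0.95])
# New CUPTI backend, routed through torch.profiler in do_bench below
cupti_ms = profiler.do_bench(backend="cupti")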
"""The profiler and convert to torch utils"""
"""Profiler and benchmarking utilities for PyTorch functions."""
import torch
import os
import sys
from typing import Callable, List, Literal, Optional, Union
import torch
class suppress_stdout_stderr:
"""Context manager to suppress stdout and stderr output.
Source: https://github.com/deepseek-ai/DeepGEMM/blob/main/deep_gemm/testing/bench.py
"""
def __enter__(self):
# Open null device files
self.outnull_file = open(os.devnull, 'w')
self.errnull_file = open(os.devnull, 'w')
# Save original file descriptors
self.old_stdout_fileno_undup = sys.stdout.fileno()
self.old_stderr_fileno_undup = sys.stderr.fileno()
self.old_stdout_fileno = os.dup(sys.stdout.fileno())
self.old_stderr_fileno = os.dup(sys.stderr.fileno())
# Save original stdout/stderr objects
self.old_stdout = sys.stdout
self.old_stderr = sys.stderr
# Redirect file descriptors and streams to null device
os.dup2(self.outnull_file.fileno(), self.old_stdout_fileno_undup)
os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup)
sys.stdout = self.outnull_file
sys.stderr = self.errnull_file
return self
def __exit__(self, *_):
# Restore original stdout/stderr objects
sys.stdout = self.old_stdout
sys.stderr = self.old_stderr
# Restore original file descriptors
os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup)
os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup)
# Close duplicated file descriptors
os.close(self.old_stdout_fileno)
os.close(self.old_stderr_fileno)
# Close null device files
self.outnull_file.close()
self.errnull_file.close()
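# Illustrative usage sketch (not part of this commit): because the context
# manager redirects the underlying file descriptors rather than only the
# Python-level sys.stdout/sys.stderr objects, output from C/CUDA libraries
# (e.g. profiler start-up chatter) is silenced as well, and both streams are
# restored on exit.
#
#     with suppress_stdout_stderr():
#         print("swallowed")                  # Python-level output
#         os.system("echo also swallowed")    # fd-level output from a child process
#     print("visible again")                  # streams restored by __exit__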
def do_bench(
fn: Callable,
......@@ -10,46 +60,47 @@ def do_bench(
rep: float = 100,
_n_warmup: int = 0,
_n_repeat: int = 0,
grad_to_none: Optional[List[torch.Tensor]] = None,
quantiles: Optional[List[float]] = None,
fast_flush: bool = True,
backend: Literal["event", "cupti"] = "event",
return_mode: Literal["min", "max", "mean", "median"] = "mean",
) -> Union[float, List[float]]:
"""Benchmarks the runtime of a PyTorch function.
"""Benchmark the runtime of a PyTorch function with L2 cache management.
This function handles:
- L2 cache flushing between runs for consistent timing
- Automatic warmup and repeat count calculation
- Optional gradient clearing for backward passes
- Multiple measurement modes (mean, median, min, max)
This function provides accurate GPU kernel timing by:
- Clearing L2 cache between runs for consistent measurements
- Auto-calculating warmup and repeat counts based on kernel runtime
- Supporting multiple profiling backends (CUDA events or CUPTI)
- Offering flexible result aggregation (mean/median/min/max/quantiles)
Args:
fn: Function to benchmark
warmup: Target warmup time in milliseconds
rep: Target number of repetitions
_n_warmup: Override for number of warmup iterations
_n_repeat: Override for number of timing iterations
grad_to_none: Tensors whose gradients should be cleared between runs
quantiles: Optional performance percentiles to compute
fast_flush: Whether to use faster L2 cache flushing
return_mode: How to aggregate timing results ("mean", "median", "min", "max")
warmup: Target warmup time in milliseconds (default: 25)
rep: Target total benchmark time in milliseconds (default: 100)
_n_warmup: Manual override for warmup iterations (default: 0 = auto)
_n_repeat: Manual override for benchmark iterations (default: 0 = auto)
quantiles: Performance percentiles to compute (e.g., [0.5, 0.95])
fast_flush: Use faster L2 cache flush with int32 vs int8 (default: True)
backend: Profiler backend - "event" (CUDA events) or "cupti" (default: "event")
return_mode: Result aggregation method - "mean", "median", "min", or "max"
Returns:
float: Aggregated runtime in milliseconds
Runtime in milliseconds (float) or list of quantile values if quantiles specified
"""
assert return_mode in ["min", "max", "mean", "median"]
assert return_mode in ["min", "max", "mean", "median"], \
f"Invalid return_mode: {return_mode}"
# Initial function call and synchronization
fn()
torch.cuda.synchronize()
# We maintain a buffer of 256 MB that we clear
# before each kernel call to make sure that the L2
# doesn't contain any input data before the run
if fast_flush:
cache = torch.empty(int(256e6 // 4), dtype=torch.int, device="cuda")
else:
cache = torch.empty(int(256e6), dtype=torch.int8, device="cuda")
# Create L2 cache flush buffer (256 MB)
# Fast flush uses int32 (4 bytes), regular uses int8 (1 byte)
cache_size = int(256e6 // 4) if fast_flush else int(256e6)
cache_dtype = torch.int if fast_flush else torch.int8
cache = torch.empty(cache_size, dtype=cache_dtype, device="cuda")
# Estimate the runtime of the function
# Estimate kernel runtime with 5 iterations
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
......@@ -60,41 +111,87 @@ def do_bench(
torch.cuda.synchronize()
estimate_ms = start_event.elapsed_time(end_event) / 5
# compute number of warmup and repeat
n_warmup = max(1, int(warmup / estimate_ms))
n_repeat = max(1, int(rep / estimate_ms))
if _n_warmup > 0:
n_warmup = _n_warmup
if _n_repeat > 0:
n_repeat = _n_repeat
start_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)]
end_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)]
# Warm-up
# Calculate warmup and repeat counts (minimum 1 iteration each)
n_warmup = _n_warmup if _n_warmup > 0 else max(1, int(warmup / estimate_ms))
n_repeat = _n_repeat if _n_repeat > 0 else max(1, int(rep / estimate_ms))
# Warmup phase
for _ in range(n_warmup):
fn()
# Benchmark
# Benchmarking phase
if backend == "event":
return _bench_with_cuda_events(fn, cache, n_repeat, quantiles, return_mode)
elif backend == "cupti":
return _bench_with_cupti(fn, cache, n_repeat)
else:
raise ValueError(f"Unknown profiler backend: {backend}")
def _bench_with_cuda_events(
fn: Callable,
cache: torch.Tensor,
n_repeat: int,
quantiles: Optional[List[float]],
return_mode: str,
) -> Union[float, List[float]]:
"""Benchmark using CUDA events for timing."""
# Create timing events
start_events = [torch.cuda.Event(enable_timing=True) for _ in range(n_repeat)]
end_events = [torch.cuda.Event(enable_timing=True) for _ in range(n_repeat)]
# Run benchmark iterations
for i in range(n_repeat):
# we don't want `fn` to accumulate gradient values
# if it contains a backward pass. So we clear the
# provided gradients
if grad_to_none is not None:
for x in grad_to_none:
x.grad = None
# we clear the L2 cache before each run
cache.zero_()
# record time of `fn`
start_event[i].record()
cache.zero_() # Clear L2 cache
start_events[i].record()
fn()
end_event[i].record()
# Record clocks
end_events[i].record()
# Synchronize and collect timings
torch.cuda.synchronize()
times = torch.tensor(
[s.elapsed_time(e) for s, e in zip(start_event, end_event)],
[s.elapsed_time(e) for s, e in zip(start_events, end_events)],
dtype=torch.float,
)
# Return quantiles if requested
if quantiles is not None:
ret = torch.quantile(times, torch.tensor(quantiles, dtype=torch.float)).tolist()
if len(ret) == 1:
ret = ret[0]
return ret
quantile_values = torch.quantile(times, torch.tensor(quantiles, dtype=torch.float)).tolist()
return quantile_values[0] if len(quantile_values) == 1 else quantile_values
# Return aggregated result
return getattr(torch, return_mode)(times).item()
def _bench_with_cupti(
fn: Callable,
cache: torch.Tensor,
n_repeat: int,
) -> float:
"""Benchmark using CUPTI profiler for detailed kernel timing."""
with suppress_stdout_stderr():
schedule = torch.profiler.schedule(wait=1, warmup=0, active=1, repeat=1)
profiler = torch.profiler.profile(
activities=[torch.profiler.ProfilerActivity.CUDA],
schedule=schedule,
)
with profiler:
for _ in range(2):
for _ in range(n_repeat):
cache.zero_()
fn()
profiler.step()
# Calculate average kernel time, excluding cache-clearing overhead
total_cuda_time = 0.0
excluded_time = 0.0
excluded_kernels = "at::native::vectorized_elementwise"
for event in profiler.key_averages():
total_cuda_time += event.self_device_time_total
if excluded_kernels in event.key:
excluded_time += event.self_device_time_total
kernel_time_us = (total_cuda_time - excluded_time) / n_repeat
return kernel_time_us * 1e-3 # Convert microseconds to milliseconds
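A hedged end-to-end sketch of the standalone do_bench helper with both backends; the tensor shapes and the wrapped matmul call are illustrative and not taken from this commit:

import torch
from functools import partial

a = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
b = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
run = partial(torch.matmul, a, b)

event_ms = do_bench(run, backend="event", return_mode="median")   # CUDA-event timing
cupti_ms = do_bench(run, backend="cupti")                         # torch.profiler (CUPTI) timing
p50, p95 = do_bench(run, backend="event", quantiles=[0.5, 0.95])  # quantile aggregation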