Unverified commit 66a22096, authored by Kunshang Ji, committed by GitHub
Browse files

[Hardware] Replace `torch.cuda.synchronize()` api with `torch.accelerator.synchronize` (#36085)


Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
parent 0bfa229b
...@@ -701,7 +701,7 @@ def _run_single_benchmark( ...@@ -701,7 +701,7 @@ def _run_single_benchmark(
# Warmup # Warmup
for _ in range(config.warmup_iters): for _ in range(config.warmup_iters):
forward_fn() forward_fn()
torch.cuda.synchronize() torch.accelerator.synchronize()
# Benchmark # Benchmark
times = [] times = []
...@@ -714,7 +714,7 @@ def _run_single_benchmark( ...@@ -714,7 +714,7 @@ def _run_single_benchmark(
forward_fn() forward_fn()
end.record() end.record()
torch.cuda.synchronize() torch.accelerator.synchronize()
elapsed_ms = start.elapsed_time(end) elapsed_ms = start.elapsed_time(end)
times.append(elapsed_ms / 1000.0 / config.num_layers) times.append(elapsed_ms / 1000.0 / config.num_layers)
......
...@@ -391,7 +391,7 @@ def _run_single_benchmark( ...@@ -391,7 +391,7 @@ def _run_single_benchmark(
attn_metadata, attn_metadata,
output=out, output=out,
) )
torch.cuda.synchronize() torch.accelerator.synchronize()
# Benchmark # Benchmark
times = [] times = []
...@@ -412,7 +412,7 @@ def _run_single_benchmark( ...@@ -412,7 +412,7 @@ def _run_single_benchmark(
) )
end.record() end.record()
torch.cuda.synchronize() torch.accelerator.synchronize()
elapsed_ms = start.elapsed_time(end) elapsed_ms = start.elapsed_time(end)
times.append(elapsed_ms / 1000.0 / config.num_layers) # seconds per layer times.append(elapsed_ms / 1000.0 / config.num_layers) # seconds per layer
......
...@@ -94,7 +94,7 @@ def create_logits( ...@@ -94,7 +94,7 @@ def create_logits(
def measure_memory() -> tuple[int, int]: def measure_memory() -> tuple[int, int]:
"""Return (allocated, reserved) memory in bytes.""" """Return (allocated, reserved) memory in bytes."""
torch.cuda.synchronize() torch.accelerator.synchronize()
return torch.cuda.memory_allocated(), torch.cuda.max_memory_allocated() return torch.cuda.memory_allocated(), torch.cuda.max_memory_allocated()
...@@ -123,7 +123,7 @@ def benchmark_function( ...@@ -123,7 +123,7 @@ def benchmark_function(
for _ in range(warmup_iters): for _ in range(warmup_iters):
logits_copy = logits.clone() logits_copy = logits.clone()
func(logits_copy, k, p) func(logits_copy, k, p)
torch.cuda.synchronize() torch.accelerator.synchronize()
# Reset memory stats before benchmark # Reset memory stats before benchmark
reset_memory_stats() reset_memory_stats()
...@@ -140,7 +140,7 @@ def benchmark_function( ...@@ -140,7 +140,7 @@ def benchmark_function(
func(logits_copy, k, p) func(logits_copy, k, p)
end_events[i].record() end_events[i].record()
torch.cuda.synchronize() torch.accelerator.synchronize()
# Calculate timing # Calculate timing
times = [ times = [
......
...@@ -168,7 +168,7 @@ def bench_impl( ...@@ -168,7 +168,7 @@ def bench_impl(
# warmup # warmup
for kwargs in kwargs_list: for kwargs in kwargs_list:
impl_type.get_impl()(**kwargs) impl_type.get_impl()(**kwargs)
torch.cuda.synchronize() torch.accelerator.synchronize()
# Merge into a single kwargs and qualify arguments as ArgPool # Merge into a single kwargs and qualify arguments as ArgPool
kwargs = {k: ArgPool([]) for k in kwargs_list[0]} kwargs = {k: ArgPool([]) for k in kwargs_list[0]}
......
...@@ -171,7 +171,7 @@ def bench_run( ...@@ -171,7 +171,7 @@ def bench_run(
activation=MoEActivation.SILU, activation=MoEActivation.SILU,
global_num_experts=num_experts, global_num_experts=num_experts,
) )
torch.cuda.synchronize() torch.accelerator.synchronize()
# Create CUDA graphs for Triton (match benchmark_moe.py pattern exactly) # Create CUDA graphs for Triton (match benchmark_moe.py pattern exactly)
triton_stream = torch.cuda.Stream() triton_stream = torch.cuda.Stream()
...@@ -187,14 +187,14 @@ def bench_run( ...@@ -187,14 +187,14 @@ def bench_run(
topk_ids, topk_ids,
quant_config=quant_config, quant_config=quant_config,
) )
torch.cuda.synchronize() torch.accelerator.synchronize()
def bench_cuda_graph(graph, num_warmup=5, num_iters=100): def bench_cuda_graph(graph, num_warmup=5, num_iters=100):
"""Benchmark CUDA graph using events like benchmark_moe.py""" """Benchmark CUDA graph using events like benchmark_moe.py"""
# Warmup # Warmup
for _ in range(num_warmup): for _ in range(num_warmup):
graph.replay() graph.replay()
torch.cuda.synchronize() torch.accelerator.synchronize()
# Timing # Timing
start_event = torch.Event(enable_timing=True) start_event = torch.Event(enable_timing=True)
...@@ -202,7 +202,7 @@ def bench_run( ...@@ -202,7 +202,7 @@ def bench_run(
latencies = [] latencies = []
for _ in range(num_iters): for _ in range(num_iters):
torch.cuda.synchronize() torch.accelerator.synchronize()
start_event.record() start_event.record()
graph.replay() graph.replay()
end_event.record() end_event.record()
......
...@@ -307,7 +307,7 @@ def bench_run( ...@@ -307,7 +307,7 @@ def bench_run(
def replay_graph(graph, num_repeats): def replay_graph(graph, num_repeats):
for _ in range(num_repeats): for _ in range(num_repeats):
graph.replay() graph.replay()
torch.cuda.synchronize() torch.accelerator.synchronize()
cutlass_stream = torch.cuda.Stream() cutlass_stream = torch.cuda.Stream()
cutlass_graph = torch.cuda.CUDAGraph() cutlass_graph = torch.cuda.CUDAGraph()
...@@ -330,7 +330,7 @@ def bench_run( ...@@ -330,7 +330,7 @@ def bench_run(
e=num_experts, e=num_experts,
device=device, device=device,
) )
torch.cuda.synchronize() torch.accelerator.synchronize()
triton_stream = torch.cuda.Stream() triton_stream = torch.cuda.Stream()
triton_graph = torch.cuda.CUDAGraph() triton_graph = torch.cuda.CUDAGraph()
...@@ -345,7 +345,7 @@ def bench_run( ...@@ -345,7 +345,7 @@ def bench_run(
w2_fp8scale, w2_fp8scale,
a_fp8_scale, a_fp8_scale,
) )
torch.cuda.synchronize() torch.accelerator.synchronize()
min_run_time = 5 min_run_time = 5
num_warmup = 5 num_warmup = 5
......
...@@ -342,7 +342,7 @@ class CommunicatorBenchmark: ...@@ -342,7 +342,7 @@ class CommunicatorBenchmark:
if not should_use_fn(tensor): if not should_use_fn(tensor):
return None return None
torch.cuda.synchronize() torch.accelerator.synchronize()
stream = torch.cuda.Stream() stream = torch.cuda.Stream()
with torch.cuda.stream(stream): with torch.cuda.stream(stream):
graph_input = tensor.clone() graph_input = tensor.clone()
...@@ -360,17 +360,17 @@ class CommunicatorBenchmark: ...@@ -360,17 +360,17 @@ class CommunicatorBenchmark:
for _ in range(CUDA_GRAPH_CAPTURE_CYCLES): for _ in range(CUDA_GRAPH_CAPTURE_CYCLES):
allreduce_fn(graph_input) allreduce_fn(graph_input)
torch.cuda.synchronize() torch.accelerator.synchronize()
for _ in range(num_warmup): for _ in range(num_warmup):
graph.replay() graph.replay()
torch.cuda.synchronize() torch.accelerator.synchronize()
torch.cuda.synchronize() torch.accelerator.synchronize()
start_time = time.perf_counter() start_time = time.perf_counter()
for _ in range(num_trials): for _ in range(num_trials):
graph.replay() graph.replay()
torch.cuda.synchronize() torch.accelerator.synchronize()
end_time = time.perf_counter() end_time = time.perf_counter()
......
...@@ -385,7 +385,7 @@ def benchmark_operation( ...@@ -385,7 +385,7 @@ def benchmark_operation(
# Warmup before graph capture # Warmup before graph capture
for _ in range(warmup): for _ in range(warmup):
operation_func(*args, **kwargs) operation_func(*args, **kwargs)
torch.cuda.synchronize() torch.accelerator.synchronize()
# Create CUDA graph # Create CUDA graph
graph = torch.cuda.CUDAGraph() graph = torch.cuda.CUDAGraph()
...@@ -398,19 +398,19 @@ def benchmark_operation( ...@@ -398,19 +398,19 @@ def benchmark_operation(
operation_func(*args, **kwargs) operation_func(*args, **kwargs)
# Graph warmup # Graph warmup
torch.cuda.synchronize() torch.accelerator.synchronize()
for _ in range(warmup): for _ in range(warmup):
graph.replay() graph.replay()
# Benchmark with CUDA graph # Benchmark with CUDA graph
torch.cuda.synchronize() torch.accelerator.synchronize()
start_time = time.perf_counter() start_time = time.perf_counter()
for _ in range(trials // num_op_per_cudagraph): for _ in range(trials // num_op_per_cudagraph):
# operation_func(*args, **kwargs) # operation_func(*args, **kwargs)
graph.replay() graph.replay()
torch.cuda.synchronize() torch.accelerator.synchronize()
end_time = time.perf_counter() end_time = time.perf_counter()
avg_time_ms = ((end_time - start_time) / trials) * 1000 avg_time_ms = ((end_time - start_time) / trials) * 1000
......
...@@ -224,7 +224,7 @@ def bench_run( ...@@ -224,7 +224,7 @@ def bench_run(
def replay_graph(graph, num_repeats): def replay_graph(graph, num_repeats):
for _ in range(num_repeats): for _ in range(num_repeats):
graph.replay() graph.replay()
torch.cuda.synchronize() torch.accelerator.synchronize()
cutlass_stream = torch.cuda.Stream() cutlass_stream = torch.cuda.Stream()
cutlass_graph = torch.cuda.CUDAGraph() cutlass_graph = torch.cuda.CUDAGraph()
...@@ -239,7 +239,7 @@ def bench_run( ...@@ -239,7 +239,7 @@ def bench_run(
topk_weights, topk_weights,
topk_ids, topk_ids,
) )
torch.cuda.synchronize() torch.accelerator.synchronize()
triton_stream = torch.cuda.Stream() triton_stream = torch.cuda.Stream()
triton_graph = torch.cuda.CUDAGraph() triton_graph = torch.cuda.CUDAGraph()
...@@ -254,7 +254,7 @@ def bench_run( ...@@ -254,7 +254,7 @@ def bench_run(
w2_scale, w2_scale,
a_scale, a_scale,
) )
torch.cuda.synchronize() torch.accelerator.synchronize()
min_run_time = 5 min_run_time = 5
num_warmup = 5 num_warmup = 5
......
...@@ -34,14 +34,14 @@ def main( ...@@ -34,14 +34,14 @@ def main(
residual = torch.randn_like(x) * scale if add_residual else None residual = torch.randn_like(x) * scale if add_residual else None
def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
torch.cuda.synchronize() torch.accelerator.synchronize()
if profile: if profile:
torch.cuda.cudart().cudaProfilerStart() torch.cuda.cudart().cudaProfilerStart()
start_time = time.perf_counter() start_time = time.perf_counter()
for _ in range(num_iters): for _ in range(num_iters):
layer(x, residual) layer(x, residual)
torch.cuda.synchronize() torch.accelerator.synchronize()
end_time = time.perf_counter() end_time = time.perf_counter()
if profile: if profile:
......
...@@ -1035,7 +1035,7 @@ def bench_optype( ...@@ -1035,7 +1035,7 @@ def bench_optype(
# Run bench function so that _LORA_A_PTR_DICT and _LORA_B_PTR_DICT are set up # Run bench function so that _LORA_A_PTR_DICT and _LORA_B_PTR_DICT are set up
for kwargs in kwargs_list: for kwargs in kwargs_list:
op_type.bench_fn()(**kwargs) op_type.bench_fn()(**kwargs)
torch.cuda.synchronize() torch.accelerator.synchronize()
# Merge into a single kwargs and qualify arguments as ArgPool # Merge into a single kwargs and qualify arguments as ArgPool
kwargs = {k: ArgPool([]) for k in kwargs_list[0]} kwargs = {k: ArgPool([]) for k in kwargs_list[0]}
......
...@@ -47,13 +47,13 @@ def benchmark_method( ...@@ -47,13 +47,13 @@ def benchmark_method(
# Warmup # Warmup
for _ in range(num_warmup): for _ in range(num_warmup):
_ = method(k_nope, k_pe) _ = method(k_nope, k_pe)
torch.cuda.synchronize() torch.accelerator.synchronize()
# Benchmark # Benchmark
start = time.perf_counter() start = time.perf_counter()
for _ in range(num_iters): for _ in range(num_iters):
_ = method(k_nope, k_pe) _ = method(k_nope, k_pe)
torch.cuda.synchronize() torch.accelerator.synchronize()
end = time.perf_counter() end = time.perf_counter()
return (end - start) / num_iters * 1000 # Convert to ms return (end - start) / num_iters * 1000 # Convert to ms
......
...@@ -304,19 +304,19 @@ def benchmark_config( ...@@ -304,19 +304,19 @@ def benchmark_config(
# JIT compilation & warmup # JIT compilation & warmup
run() run()
torch.cuda.synchronize() torch.accelerator.synchronize()
# Capture 10 invocations with CUDA graph # Capture 10 invocations with CUDA graph
graph = torch.cuda.CUDAGraph() graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph): with torch.cuda.graph(graph):
for _ in range(10): for _ in range(10):
run() run()
torch.cuda.synchronize() torch.accelerator.synchronize()
# Warmup # Warmup
for _ in range(5): for _ in range(5):
graph.replay() graph.replay()
torch.cuda.synchronize() torch.accelerator.synchronize()
start_event = torch.Event(enable_timing=True) start_event = torch.Event(enable_timing=True)
end_event = torch.Event(enable_timing=True) end_event = torch.Event(enable_timing=True)
...@@ -324,7 +324,7 @@ def benchmark_config( ...@@ -324,7 +324,7 @@ def benchmark_config(
latencies: list[float] = [] latencies: list[float] = []
for i in range(num_iters): for i in range(num_iters):
prepare(i) prepare(i)
torch.cuda.synchronize() torch.accelerator.synchronize()
start_event.record() start_event.record()
graph.replay() graph.replay()
......
...@@ -131,7 +131,7 @@ def benchmark_config( ...@@ -131,7 +131,7 @@ def benchmark_config(
topk_ids, topk_ids,
quant_config=quant_config, quant_config=quant_config,
) )
torch.cuda.synchronize() torch.accelerator.synchronize()
# Benchmark # Benchmark
start = torch.cuda.Event(enable_timing=True) start = torch.cuda.Event(enable_timing=True)
...@@ -149,7 +149,7 @@ def benchmark_config( ...@@ -149,7 +149,7 @@ def benchmark_config(
quant_config=quant_config, quant_config=quant_config,
) )
end.record() end.record()
torch.cuda.synchronize() torch.accelerator.synchronize()
return start.elapsed_time(end) / num_iters * 1000 # ms -> us return start.elapsed_time(end) / num_iters * 1000 # ms -> us
......
...@@ -69,19 +69,19 @@ def benchmark_permute( ...@@ -69,19 +69,19 @@ def benchmark_permute(
# JIT compilation & warmup # JIT compilation & warmup
run() run()
torch.cuda.synchronize() torch.accelerator.synchronize()
# Capture 10 invocations with CUDA graph # Capture 10 invocations with CUDA graph
graph = torch.cuda.CUDAGraph() graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph): with torch.cuda.graph(graph):
for _ in range(10): for _ in range(10):
run() run()
torch.cuda.synchronize() torch.accelerator.synchronize()
# Warmup # Warmup
for _ in range(5): for _ in range(5):
graph.replay() graph.replay()
torch.cuda.synchronize() torch.accelerator.synchronize()
start_event = torch.Event(enable_timing=True) start_event = torch.Event(enable_timing=True)
end_event = torch.Event(enable_timing=True) end_event = torch.Event(enable_timing=True)
...@@ -89,7 +89,7 @@ def benchmark_permute( ...@@ -89,7 +89,7 @@ def benchmark_permute(
latencies: list[float] = [] latencies: list[float] = []
for i in range(num_iters): for i in range(num_iters):
prepare(i) prepare(i)
torch.cuda.synchronize() torch.accelerator.synchronize()
start_event.record() start_event.record()
graph.replay() graph.replay()
...@@ -159,26 +159,26 @@ def benchmark_unpermute( ...@@ -159,26 +159,26 @@ def benchmark_unpermute(
# JIT compilation & warmup # JIT compilation & warmup
input = prepare() input = prepare()
run(input) run(input)
torch.cuda.synchronize() torch.accelerator.synchronize()
# Capture 10 invocations with CUDA graph # Capture 10 invocations with CUDA graph
graph = torch.cuda.CUDAGraph() graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph): with torch.cuda.graph(graph):
for _ in range(10): for _ in range(10):
run(input) run(input)
torch.cuda.synchronize() torch.accelerator.synchronize()
# Warmup # Warmup
for _ in range(5): for _ in range(5):
graph.replay() graph.replay()
torch.cuda.synchronize() torch.accelerator.synchronize()
start_event = torch.Event(enable_timing=True) start_event = torch.Event(enable_timing=True)
end_event = torch.Event(enable_timing=True) end_event = torch.Event(enable_timing=True)
latencies: list[float] = [] latencies: list[float] = []
for i in range(num_iters): for i in range(num_iters):
torch.cuda.synchronize() torch.accelerator.synchronize()
start_event.record() start_event.record()
graph.replay() graph.replay()
end_event.record() end_event.record()
......
...@@ -135,14 +135,14 @@ def benchmark_mrope( ...@@ -135,14 +135,14 @@ def benchmark_mrope(
key.clone(), key.clone(),
) )
torch.cuda.synchronize() torch.accelerator.synchronize()
# Time reference implementation # Time reference implementation
torch_times = [] torch_times = []
for _ in range(benchmark_iter): for _ in range(benchmark_iter):
query_clone = query.clone() query_clone = query.clone()
key_clone = key.clone() key_clone = key.clone()
torch.cuda.synchronize() torch.accelerator.synchronize()
start_time = time.time() start_time = time.time()
mrope_helper_class.forward_native( mrope_helper_class.forward_native(
...@@ -151,7 +151,7 @@ def benchmark_mrope( ...@@ -151,7 +151,7 @@ def benchmark_mrope(
key_clone, key_clone,
) )
torch.cuda.synchronize() torch.accelerator.synchronize()
torch_times.append(time.time() - start_time) torch_times.append(time.time() - start_time)
# Time triton kernel implementation # Time triton kernel implementation
...@@ -159,14 +159,14 @@ def benchmark_mrope( ...@@ -159,14 +159,14 @@ def benchmark_mrope(
for _ in range(benchmark_iter): for _ in range(benchmark_iter):
query_clone = query.clone() query_clone = query.clone()
key_clone = key.clone() key_clone = key.clone()
torch.cuda.synchronize() torch.accelerator.synchronize()
start_time = time.time() start_time = time.time()
mrope_helper_class.forward_cuda( mrope_helper_class.forward_cuda(
positions, positions,
query_clone, query_clone,
key_clone, key_clone,
) )
torch.cuda.synchronize() torch.accelerator.synchronize()
triton_times.append(time.time() - start_time) triton_times.append(time.time() - start_time)
# Calculate statistics # Calculate statistics
......
...@@ -103,7 +103,7 @@ def main( ...@@ -103,7 +103,7 @@ def main(
max_logits = torch.empty_like(exp_sums) max_logits = torch.empty_like(exp_sums)
def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
torch.cuda.synchronize() torch.accelerator.synchronize()
if profile: if profile:
torch.cuda.cudart().cudaProfilerStart() torch.cuda.cudart().cudaProfilerStart()
start_time = time.perf_counter() start_time = time.perf_counter()
...@@ -173,7 +173,7 @@ def main( ...@@ -173,7 +173,7 @@ def main(
) )
else: else:
raise ValueError(f"Invalid version: {version}") raise ValueError(f"Invalid version: {version}")
torch.cuda.synchronize() torch.accelerator.synchronize()
end_time = time.perf_counter() end_time = time.perf_counter()
if profile: if profile:
......
...@@ -28,7 +28,7 @@ def _time_cuda( ...@@ -28,7 +28,7 @@ def _time_cuda(
# warmup # warmup
for _ in range(warmup_iters): for _ in range(warmup_iters):
fn() fn()
torch.cuda.synchronize() torch.accelerator.synchronize()
start = torch.Event(enable_timing=True) start = torch.Event(enable_timing=True)
end = torch.Event(enable_timing=True) end = torch.Event(enable_timing=True)
...@@ -37,7 +37,7 @@ def _time_cuda( ...@@ -37,7 +37,7 @@ def _time_cuda(
for _ in range(bench_iters): for _ in range(bench_iters):
fn() fn()
end.record() end.record()
torch.cuda.synchronize() torch.accelerator.synchronize()
return start.elapsed_time(end) / bench_iters # ms/iter return start.elapsed_time(end) / bench_iters # ms/iter
......
...@@ -29,7 +29,7 @@ def main( ...@@ -29,7 +29,7 @@ def main(
scale = torch.randn(1, 1, dtype=torch.float32) if static_scale else None scale = torch.randn(1, 1, dtype=torch.float32) if static_scale else None
def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
torch.cuda.synchronize() torch.accelerator.synchronize()
if profile: if profile:
torch.cuda.cudart().cudaProfilerStart() torch.cuda.cudart().cudaProfilerStart()
start_time = time.perf_counter() start_time = time.perf_counter()
...@@ -39,7 +39,7 @@ def main( ...@@ -39,7 +39,7 @@ def main(
ops.scaled_int8_quant(x, scale) ops.scaled_int8_quant(x, scale)
else: else:
ops.scaled_fp8_quant(x, scale) ops.scaled_fp8_quant(x, scale)
torch.cuda.synchronize() torch.accelerator.synchronize()
end_time = time.perf_counter() end_time = time.perf_counter()
if profile: if profile:
......
...@@ -84,16 +84,16 @@ def run_benchmark( ...@@ -84,16 +84,16 @@ def run_benchmark(
g = torch.cuda.CUDAGraph() g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g): with torch.cuda.graph(g):
function_under_test() function_under_test()
torch.cuda.synchronize() torch.accelerator.synchronize()
function_under_test = lambda: g.replay() function_under_test = lambda: g.replay()
def run_cuda_benchmark(n_iters: int) -> float: def run_cuda_benchmark(n_iters: int) -> float:
nonlocal key, value, key_cache, value_cache, slot_mapping nonlocal key, value, key_cache, value_cache, slot_mapping
torch.cuda.synchronize() torch.accelerator.synchronize()
start = time.perf_counter() start = time.perf_counter()
for _ in range(n_iters): for _ in range(n_iters):
function_under_test() function_under_test()
torch.cuda.synchronize() torch.accelerator.synchronize()
end = time.perf_counter() end = time.perf_counter()
return (end - start) / n_iters return (end - start) / n_iters
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment