Unverified Commit 66a22096 authored by Kunshang Ji's avatar Kunshang Ji Committed by GitHub
Browse files

[Hardware] Replace `torch.cuda.synchronize()` api with `torch.accelerator.synchronize` (#36085)


Signed-off-by: default avatarKunshang Ji <kunshang.ji@intel.com>
parent 0bfa229b
......@@ -109,16 +109,16 @@ def run_benchmark(
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
function_under_test()
torch.cuda.synchronize()
torch.accelerator.synchronize()
function_under_test = lambda: g.replay()
def run_cuda_benchmark(n_iters: int) -> float:
nonlocal key, value, key_cache, value_cache, slot_mapping
torch.cuda.synchronize()
torch.accelerator.synchronize()
start = time.perf_counter()
for _ in range(n_iters):
function_under_test()
torch.cuda.synchronize()
torch.accelerator.synchronize()
end = time.perf_counter()
return (end - start) / n_iters
......
......@@ -251,7 +251,7 @@ def benchmark(
kernel(
y, tokens_per_expert, num_parallel_tokens=num_parallel_tokens, group_size=G
)
torch.cuda.synchronize()
torch.accelerator.synchronize()
start_event = torch.Event(enable_timing=True)
end_event = torch.Event(enable_timing=True)
......@@ -259,7 +259,7 @@ def benchmark(
# Benchmark
latencies: list[float] = []
for _ in range(runs):
torch.cuda.synchronize()
torch.accelerator.synchronize()
start_event.record()
for i in range(iterations_per_run):
......
......@@ -126,7 +126,7 @@ def benchmark_decode(
)
def time_fn(fn, warmup=10, trials=20):
torch.cuda.synchronize()
torch.accelerator.synchronize()
start = torch.Event(enable_timing=True)
end = torch.Event(enable_timing=True)
times = []
......@@ -136,7 +136,7 @@ def benchmark_decode(
start.record()
fn()
end.record()
torch.cuda.synchronize()
torch.accelerator.synchronize()
times.append(start.elapsed_time(end)) # ms
return sum(times) / len(times), torch.std(torch.tensor(times))
......
......@@ -138,7 +138,7 @@ def benchmark_prefill(
)
def time_fn(fn, warmup=10, trials=20):
torch.cuda.synchronize()
torch.accelerator.synchronize()
start = torch.Event(enable_timing=True)
end = torch.Event(enable_timing=True)
times = []
......@@ -148,7 +148,7 @@ def benchmark_prefill(
start.record()
fn()
end.record()
torch.cuda.synchronize()
torch.accelerator.synchronize()
times.append(start.elapsed_time(end)) # ms
return sum(times) / len(times), torch.std(torch.tensor(times))
......
......@@ -177,18 +177,18 @@ def benchmark_config(
def run():
w8a8_block_matmul(A, B, As, Bs, block_size, config, out_dtype)
torch.cuda.synchronize()
torch.accelerator.synchronize()
# JIT complication & warmup
for _ in range(5):
run()
torch.cuda.synchronize()
torch.accelerator.synchronize()
start_event = torch.Event(enable_timing=True)
end_event = torch.Event(enable_timing=True)
latencies: list[float] = []
for i in range(num_iters):
torch.cuda.synchronize()
torch.accelerator.synchronize()
start_event.record()
run()
end_event.record()
......
......@@ -35,7 +35,7 @@ def benchmark_shape(
B = torch.randn((n, k), device="cuda", dtype=torch.bfloat16)
# Reference result in BF16
torch.cuda.synchronize()
torch.accelerator.synchronize()
C_ref = A @ B.t()
# Pre-quantize B for all implementations
......@@ -121,14 +121,14 @@ def benchmark_shape(
# Warmup
for _ in range(warmup):
func()
torch.cuda.synchronize()
torch.accelerator.synchronize()
# Timing loop
torch.cuda.synchronize()
torch.accelerator.synchronize()
start = time.time()
for _ in range(repeat):
func()
torch.cuda.synchronize()
torch.accelerator.synchronize()
end = time.time()
# Calculate timing and TFLOPS
......
......@@ -50,7 +50,7 @@ V1 was not originally designed with async scheduling in mind, and support requir
## 3. Removing Async Barrier
A key requirement for async execution is that CPU operations remain non-blocking. Both explicit sync (for example, `torch.cuda.synchronize`) and implicit sync (for example, unpinned `.to("cuda")`) must be avoided.
A key requirement for async execution is that CPU operations remain non-blocking. Both explicit sync (for example, `torch.accelerator.synchronize`) and implicit sync (for example, unpinned `.to("cuda")`) must be avoided.
However, async execution can introduce race conditions when CPU and GPU concurrently touch the same memory.
......
......@@ -95,7 +95,7 @@ If GPU/CPU communication cannot be established, you can use the following Python
torch.cuda.set_device(local_rank)
data = torch.FloatTensor([1,] * 128).to("cuda")
dist.all_reduce(data, op=dist.ReduceOp.SUM)
torch.cuda.synchronize()
torch.accelerator.synchronize()
value = data.mean().item()
world_size = dist.get_world_size()
assert value == world_size, f"Expected {world_size}, got {value}"
......
......@@ -88,7 +88,7 @@ class RayTrainingActor:
# Zero out all the parameters.
for name, p in self.model.named_parameters():
p.data.zero_()
torch.cuda.synchronize()
torch.accelerator.synchronize()
# The argument for `get_device_uuid` is the index of the GPU in the
# list of visible devices.
from vllm.platforms import current_platform
......@@ -151,7 +151,7 @@ class RayTrainingActor:
p.data.view(-1).view(dtype=torch.uint8), non_blocking=True
)
offset += get_size(p)
torch.cuda.synchronize()
torch.accelerator.synchronize()
s.send_pyobj(named_tensors)
s.recv()
s.send_pyobj(None)
......
......@@ -120,7 +120,7 @@ class ColocateWorkerExtension:
process_weights_after_loading(
self.model_runner.model, self.model_config, self.device
)
torch.cuda.synchronize()
torch.accelerator.synchronize()
socket.send(b"")
break
if isinstance(payload, tuple):
......@@ -144,7 +144,7 @@ class ColocateWorkerExtension:
weights.append((item["name"], tensor))
self.model_runner.model.load_weights(weights=weights)
del weights
torch.cuda.synchronize()
torch.accelerator.synchronize()
socket.send(b"")
socket.close()
......
......@@ -100,7 +100,7 @@ def test_dynamic_shapes_compilation(
del model
gc.collect()
torch.accelerator.empty_cache()
torch.cuda.synchronize()
torch.accelerator.synchronize()
print("GPU memory cleared")
......
......@@ -32,7 +32,7 @@ pointers = CustomAllreduce.create_shared_buffer(buffer_size_in_bytes)
print(f"Rank {rank} has pointers {pointers}")
dist.barrier()
torch.cuda.synchronize()
torch.accelerator.synchronize()
if rank == 0:
# the first rank tries to write to all buffers
......@@ -41,7 +41,7 @@ if rank == 0:
lib.cudaMemset(pointer, byte_value, buffer_size_in_bytes)
dist.barrier()
torch.cuda.synchronize()
torch.accelerator.synchronize()
host_data = (ctypes.c_char * buffer_size_in_bytes)()
......@@ -59,6 +59,6 @@ for p in pointers:
print(f"Rank {rank} verified all buffers")
dist.barrier()
torch.cuda.synchronize()
torch.accelerator.synchronize()
CustomAllreduce.free_shared_buffer(pointers)
......@@ -48,7 +48,7 @@ def graph_allreduce(
data = torch.zeros(1)
data = data.to(device=device)
torch.distributed.all_reduce(data, group=group)
torch.cuda.synchronize()
torch.accelerator.synchronize()
del data
# we use the first group to communicate once
......@@ -68,7 +68,7 @@ def graph_allreduce(
inp2 = torch.randint(
1, 16, (sz,), dtype=dtype, device=torch.cuda.current_device()
)
torch.cuda.synchronize()
torch.accelerator.synchronize()
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, stream=graph_capture_context.stream):
for i in range(num_communication):
......
......@@ -68,7 +68,7 @@ def worker_fn():
)
tensor = torch.ones(16, 1024, 1024, dtype=torch.float32).cuda(pynccl_comm.rank)
tensor = pynccl_comm.all_reduce(tensor)
torch.cuda.synchronize()
torch.accelerator.synchronize()
assert torch.all(tensor == pynccl_comm.world_size).cpu().item()
......@@ -93,11 +93,11 @@ def multiple_allreduce_worker_fn():
if torch.distributed.get_rank() in [0, 1]:
tensor = pynccl_comm.all_reduce(tensor)
tensor = pynccl_comm.all_reduce(tensor)
torch.cuda.synchronize()
torch.accelerator.synchronize()
assert torch.all(tensor == 4).cpu().item()
else:
tensor = pynccl_comm.all_reduce(tensor)
torch.cuda.synchronize()
torch.accelerator.synchronize()
assert torch.all(tensor == 2).cpu().item()
......@@ -121,11 +121,11 @@ def multiple_allreduce_with_vllm_worker_fn():
if torch.distributed.get_rank() in [0, 1]:
tensor = tensor_model_parallel_all_reduce(tensor)
tensor = tensor_model_parallel_all_reduce(tensor)
torch.cuda.synchronize()
torch.accelerator.synchronize()
assert torch.all(tensor == 4).cpu().item()
else:
tensor = tensor_model_parallel_all_reduce(tensor)
torch.cuda.synchronize()
torch.accelerator.synchronize()
assert torch.all(tensor == 2).cpu().item()
......@@ -147,12 +147,12 @@ def worker_fn_with_cudagraph():
)
# run something in the default stream to initialize torch engine
a = torch.ones((4, 4), device=f"cuda:{pynccl_comm.rank}")
torch.cuda.synchronize()
torch.accelerator.synchronize()
with torch.cuda.graph(graph):
a_out = pynccl_comm.all_reduce(a)
torch.cuda.synchronize()
torch.accelerator.synchronize()
graph.replay()
torch.cuda.synchronize()
torch.accelerator.synchronize()
assert torch.all(a_out == pynccl_comm.world_size).cpu().item()
......@@ -180,7 +180,7 @@ def all_gather_worker_fn():
).to(device)
pynccl_comm.all_gather(result, tensor)
torch.cuda.synchronize()
torch.accelerator.synchronize()
torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)
......@@ -215,7 +215,7 @@ def all_gatherv_worker_fn():
).to(device)
pynccl_comm.all_gatherv(result, tensor, sizes=sizes)
torch.cuda.synchronize()
torch.accelerator.synchronize()
torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)
......@@ -255,7 +255,7 @@ def reduce_scatter_worker_fn():
).to(device)
pynccl_comm.reduce_scatter(result, tensor)
torch.cuda.synchronize()
torch.accelerator.synchronize()
torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)
......@@ -293,7 +293,7 @@ def reduce_scatterv_worker_fn():
expected = sum(tensor[start:end] for tensor in all_tensors).to(device)
pynccl_comm.reduce_scatterv(result, tensor, sizes=sizes)
torch.cuda.synchronize()
torch.accelerator.synchronize()
torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)
......@@ -325,7 +325,7 @@ def send_recv_worker_fn():
pynccl_comm.send(tensor, dst=(pynccl_comm.rank + 1) % pynccl_comm.world_size)
else:
pynccl_comm.recv(tensor, src=(pynccl_comm.rank - 1) % pynccl_comm.world_size)
torch.cuda.synchronize()
torch.accelerator.synchronize()
assert torch.all(tensor == 1).cpu().item()
......@@ -355,7 +355,7 @@ def multiple_send_recv_worker_fn():
pynccl_comm.send(tensor, dst=(pynccl_comm.rank + 1) % pynccl_comm.world_size)
else:
pynccl_comm.recv(tensor, src=(pynccl_comm.rank - 1) % pynccl_comm.world_size)
torch.cuda.synchronize()
torch.accelerator.synchronize()
if torch.distributed.get_rank() in [0, 2]:
assert torch.all(tensor == 1).cpu().item()
else:
......@@ -396,7 +396,7 @@ def broadcast_worker_fn():
pynccl_comm.broadcast(recv_tensors[i], src=i)
# the broadcast op might be launched in a different stream
# need to synchronize to make sure the tensor is ready
torch.cuda.synchronize()
torch.accelerator.synchronize()
assert torch.all(recv_tensors[i] == i).cpu().item()
......
......@@ -52,7 +52,7 @@ def graph_quickreduce(
data = torch.zeros(1)
data = data.to(device=device)
torch.distributed.all_reduce(data, group=group)
torch.cuda.synchronize()
torch.accelerator.synchronize()
del data
# we use the first group to communicate once
......@@ -71,7 +71,7 @@ def graph_quickreduce(
inp2 = torch.randint(
-23, 1, (sz,), dtype=dtype, device=torch.cuda.current_device()
)
torch.cuda.synchronize()
torch.accelerator.synchronize()
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, stream=graph_capture_context.stream):
for _ in range(num_communication):
......
......@@ -79,11 +79,11 @@ def gpu_worker(rank, WORLD_SIZE, port1, port2):
data = torch.tensor([rank]).cuda()
pynccl1.all_reduce(data)
pg1.barrier()
torch.cuda.synchronize()
torch.accelerator.synchronize()
if rank <= 2:
pynccl2.all_reduce(data)
pg2.barrier()
torch.cuda.synchronize()
torch.accelerator.synchronize()
item = data[0].item()
print(f"rank: {rank}, item: {item}")
if rank == 3:
......
......@@ -251,7 +251,7 @@ def trainer_broadcast_tensor(
dtype = getattr(torch, tensor_dtype)
tensor_to_send = torch.ones(tensor_shape, dtype=dtype, device="cuda:0")
comm.broadcast(tensor_to_send, src=0, stream=torch.cuda.current_stream())
torch.cuda.synchronize()
torch.accelerator.synchronize()
return True
......@@ -309,7 +309,7 @@ def inference_receive_tensor(
shapes=[tensor_shape],
)
engine.receive_weights(update_info, noop_load_weights)
torch.cuda.synchronize()
torch.accelerator.synchronize()
# Verify we received the tensor
success = False
......@@ -630,7 +630,7 @@ class TrainerActor:
ipc_handle = reduce_tensor(self.tensor)
gpu_uuid = get_physical_gpu_id(0)
torch.cuda.synchronize()
torch.accelerator.synchronize()
self.ipc_handle_dict = {
"ipc_handle": ipc_handle,
......@@ -704,7 +704,7 @@ def inference_receive_ipc_tensor(
update_info = engine.parse_update_info(update_dict)
engine.receive_weights(update_info, noop_load_weights)
torch.cuda.synchronize()
torch.accelerator.synchronize()
# Verify we received the tensor
success = False
......
......@@ -165,7 +165,7 @@ def test_merge_attn_states(
suffix_lse_torch,
output_lse_torch,
)
torch.cuda.synchronize()
torch.accelerator.synchronize()
for _ in range(repeat_times):
start.record()
......@@ -178,7 +178,7 @@ def test_merge_attn_states(
output_lse_torch,
)
end.record()
torch.cuda.synchronize()
torch.accelerator.synchronize()
total_time_torch_kernel += start.elapsed_time(end)
avg_time_torch_kernel = total_time_torch_kernel / repeat_times
......@@ -200,7 +200,7 @@ def test_merge_attn_states(
suffix_lse,
output_lse_ref_triton,
)
torch.cuda.synchronize()
torch.accelerator.synchronize()
for _ in range(repeat_times):
start.record()
......@@ -213,7 +213,7 @@ def test_merge_attn_states(
output_lse_ref_triton,
)
end.record()
torch.cuda.synchronize()
torch.accelerator.synchronize()
total_time_triton_kernel += start.elapsed_time(end)
avg_time_triton_kernel = total_time_triton_kernel / repeat_times
......@@ -232,7 +232,7 @@ def test_merge_attn_states(
suffix_lse,
output_lse_cuda,
)
torch.cuda.synchronize()
torch.accelerator.synchronize()
for _ in range(repeat_times):
start.record()
......@@ -245,7 +245,7 @@ def test_merge_attn_states(
output_lse_cuda,
)
end.record()
torch.cuda.synchronize()
torch.accelerator.synchronize()
total_time_cuda_kernel += start.elapsed_time(end)
avg_time_cuda_kernel = total_time_cuda_kernel / repeat_times
......
......@@ -239,7 +239,7 @@ def test_contexted_kv_attention(
v_scale,
sliding_window=sliding_window,
)
torch.cuda.synchronize()
torch.accelerator.synchronize()
start_time = time.time()
op(
query,
......@@ -258,7 +258,7 @@ def test_contexted_kv_attention(
v_scale,
sliding_window=sliding_window,
)
torch.cuda.synchronize()
torch.accelerator.synchronize()
end_time = time.time()
print(f"triton Time: {(end_time - start_time) * 1000:.2f} ms")
......@@ -298,7 +298,7 @@ def test_contexted_kv_attention(
dropout_p=0.0,
scale=scale,
)
torch.cuda.synchronize()
torch.accelerator.synchronize()
start_time = time.time()
output_ref = F.scaled_dot_product_attention(
query_sdpa,
......@@ -308,7 +308,7 @@ def test_contexted_kv_attention(
dropout_p=0.0,
scale=scale,
)
torch.cuda.synchronize()
torch.accelerator.synchronize()
end_time = time.time()
print(f"PyTorch SDPA Time: {(end_time - start_time) * 1000:.2f} ms")
......@@ -482,7 +482,7 @@ def test_contexted_kv_attention_alibi(
v_scale,
alibi_slopes=alibi_slopes,
)
torch.cuda.synchronize()
torch.accelerator.synchronize()
start_time = time.time()
op(
query,
......@@ -501,7 +501,7 @@ def test_contexted_kv_attention_alibi(
v_scale,
alibi_slopes=alibi_slopes,
)
torch.cuda.synchronize()
torch.accelerator.synchronize()
end_time = time.time()
print(f"triton Time: {(end_time - start_time) * 1000:.2f} ms")
scale = float(1.0 / (head_size**0.5))
......@@ -517,7 +517,7 @@ def test_contexted_kv_attention_alibi(
output_ref = torch.empty_like(output)
torch.cuda.synchronize()
torch.accelerator.synchronize()
start_time = time.time()
query_start = 0
......@@ -572,7 +572,7 @@ def test_contexted_kv_attention_alibi(
query_start = query_end
key_start = key_end
torch.cuda.synchronize()
torch.accelerator.synchronize()
end_time = time.time()
print(f"PyTorch SDPA Time: {(end_time - start_time) * 1000:.2f} ms")
atol = 1e-3 if "fp8" in kv_cache_dtype else 1e-6
......
......@@ -127,7 +127,7 @@ def test_fused_rms_norm_quant(
out_quant, x_unfused.contiguous(), quant_scale_t
)
torch.cuda.synchronize()
torch.accelerator.synchronize()
torch.testing.assert_close(residual_fused, residual, atol=1e-2, rtol=1e-2)
opcheck(
torch.ops._C.fused_add_rms_norm_static_fp8_quant,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment