Unverified commit 66a22096, authored by Kunshang Ji, committed by GitHub
Browse files

[Hardware] Replace `torch.cuda.synchronize()` api with `torch.accelerator.synchronize` (#36085)


Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
parent 0bfa229b
...@@ -109,16 +109,16 @@ def run_benchmark( ...@@ -109,16 +109,16 @@ def run_benchmark(
g = torch.cuda.CUDAGraph() g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g): with torch.cuda.graph(g):
function_under_test() function_under_test()
torch.cuda.synchronize() torch.accelerator.synchronize()
function_under_test = lambda: g.replay() function_under_test = lambda: g.replay()
def run_cuda_benchmark(n_iters: int) -> float: def run_cuda_benchmark(n_iters: int) -> float:
nonlocal key, value, key_cache, value_cache, slot_mapping nonlocal key, value, key_cache, value_cache, slot_mapping
torch.cuda.synchronize() torch.accelerator.synchronize()
start = time.perf_counter() start = time.perf_counter()
for _ in range(n_iters): for _ in range(n_iters):
function_under_test() function_under_test()
torch.cuda.synchronize() torch.accelerator.synchronize()
end = time.perf_counter() end = time.perf_counter()
return (end - start) / n_iters return (end - start) / n_iters
......
...@@ -251,7 +251,7 @@ def benchmark( ...@@ -251,7 +251,7 @@ def benchmark(
kernel( kernel(
y, tokens_per_expert, num_parallel_tokens=num_parallel_tokens, group_size=G y, tokens_per_expert, num_parallel_tokens=num_parallel_tokens, group_size=G
) )
torch.cuda.synchronize() torch.accelerator.synchronize()
start_event = torch.Event(enable_timing=True) start_event = torch.Event(enable_timing=True)
end_event = torch.Event(enable_timing=True) end_event = torch.Event(enable_timing=True)
...@@ -259,7 +259,7 @@ def benchmark( ...@@ -259,7 +259,7 @@ def benchmark(
# Benchmark # Benchmark
latencies: list[float] = [] latencies: list[float] = []
for _ in range(runs): for _ in range(runs):
torch.cuda.synchronize() torch.accelerator.synchronize()
start_event.record() start_event.record()
for i in range(iterations_per_run): for i in range(iterations_per_run):
......
...@@ -126,7 +126,7 @@ def benchmark_decode( ...@@ -126,7 +126,7 @@ def benchmark_decode(
) )
def time_fn(fn, warmup=10, trials=20): def time_fn(fn, warmup=10, trials=20):
torch.cuda.synchronize() torch.accelerator.synchronize()
start = torch.Event(enable_timing=True) start = torch.Event(enable_timing=True)
end = torch.Event(enable_timing=True) end = torch.Event(enable_timing=True)
times = [] times = []
...@@ -136,7 +136,7 @@ def benchmark_decode( ...@@ -136,7 +136,7 @@ def benchmark_decode(
start.record() start.record()
fn() fn()
end.record() end.record()
torch.cuda.synchronize() torch.accelerator.synchronize()
times.append(start.elapsed_time(end)) # ms times.append(start.elapsed_time(end)) # ms
return sum(times) / len(times), torch.std(torch.tensor(times)) return sum(times) / len(times), torch.std(torch.tensor(times))
......
...@@ -138,7 +138,7 @@ def benchmark_prefill( ...@@ -138,7 +138,7 @@ def benchmark_prefill(
) )
def time_fn(fn, warmup=10, trials=20): def time_fn(fn, warmup=10, trials=20):
torch.cuda.synchronize() torch.accelerator.synchronize()
start = torch.Event(enable_timing=True) start = torch.Event(enable_timing=True)
end = torch.Event(enable_timing=True) end = torch.Event(enable_timing=True)
times = [] times = []
...@@ -148,7 +148,7 @@ def benchmark_prefill( ...@@ -148,7 +148,7 @@ def benchmark_prefill(
start.record() start.record()
fn() fn()
end.record() end.record()
torch.cuda.synchronize() torch.accelerator.synchronize()
times.append(start.elapsed_time(end)) # ms times.append(start.elapsed_time(end)) # ms
return sum(times) / len(times), torch.std(torch.tensor(times)) return sum(times) / len(times), torch.std(torch.tensor(times))
......
...@@ -177,18 +177,18 @@ def benchmark_config( ...@@ -177,18 +177,18 @@ def benchmark_config(
def run(): def run():
w8a8_block_matmul(A, B, As, Bs, block_size, config, out_dtype) w8a8_block_matmul(A, B, As, Bs, block_size, config, out_dtype)
torch.cuda.synchronize() torch.accelerator.synchronize()
# JIT compilation & warmup # JIT compilation & warmup
for _ in range(5): for _ in range(5):
run() run()
torch.cuda.synchronize() torch.accelerator.synchronize()
start_event = torch.Event(enable_timing=True) start_event = torch.Event(enable_timing=True)
end_event = torch.Event(enable_timing=True) end_event = torch.Event(enable_timing=True)
latencies: list[float] = [] latencies: list[float] = []
for i in range(num_iters): for i in range(num_iters):
torch.cuda.synchronize() torch.accelerator.synchronize()
start_event.record() start_event.record()
run() run()
end_event.record() end_event.record()
......
...@@ -35,7 +35,7 @@ def benchmark_shape( ...@@ -35,7 +35,7 @@ def benchmark_shape(
B = torch.randn((n, k), device="cuda", dtype=torch.bfloat16) B = torch.randn((n, k), device="cuda", dtype=torch.bfloat16)
# Reference result in BF16 # Reference result in BF16
torch.cuda.synchronize() torch.accelerator.synchronize()
C_ref = A @ B.t() C_ref = A @ B.t()
# Pre-quantize B for all implementations # Pre-quantize B for all implementations
...@@ -121,14 +121,14 @@ def benchmark_shape( ...@@ -121,14 +121,14 @@ def benchmark_shape(
# Warmup # Warmup
for _ in range(warmup): for _ in range(warmup):
func() func()
torch.cuda.synchronize() torch.accelerator.synchronize()
# Timing loop # Timing loop
torch.cuda.synchronize() torch.accelerator.synchronize()
start = time.time() start = time.time()
for _ in range(repeat): for _ in range(repeat):
func() func()
torch.cuda.synchronize() torch.accelerator.synchronize()
end = time.time() end = time.time()
# Calculate timing and TFLOPS # Calculate timing and TFLOPS
......
...@@ -50,7 +50,7 @@ V1 was not originally designed with async scheduling in mind, and support requir ...@@ -50,7 +50,7 @@ V1 was not originally designed with async scheduling in mind, and support requir
## 3. Removing Async Barrier ## 3. Removing Async Barrier
A key requirement for async execution is that CPU operations remain non-blocking. Both explicit sync (for example, `torch.cuda.synchronize`) and implicit sync (for example, unpinned `.to("cuda")`) must be avoided. A key requirement for async execution is that CPU operations remain non-blocking. Both explicit sync (for example, `torch.accelerator.synchronize`) and implicit sync (for example, unpinned `.to("cuda")`) must be avoided.
However, async execution can introduce race conditions when CPU and GPU concurrently touch the same memory. However, async execution can introduce race conditions when CPU and GPU concurrently touch the same memory.
......
...@@ -95,7 +95,7 @@ If GPU/CPU communication cannot be established, you can use the following Python ...@@ -95,7 +95,7 @@ If GPU/CPU communication cannot be established, you can use the following Python
torch.cuda.set_device(local_rank) torch.cuda.set_device(local_rank)
data = torch.FloatTensor([1,] * 128).to("cuda") data = torch.FloatTensor([1,] * 128).to("cuda")
dist.all_reduce(data, op=dist.ReduceOp.SUM) dist.all_reduce(data, op=dist.ReduceOp.SUM)
torch.cuda.synchronize() torch.accelerator.synchronize()
value = data.mean().item() value = data.mean().item()
world_size = dist.get_world_size() world_size = dist.get_world_size()
assert value == world_size, f"Expected {world_size}, got {value}" assert value == world_size, f"Expected {world_size}, got {value}"
......
...@@ -88,7 +88,7 @@ class RayTrainingActor: ...@@ -88,7 +88,7 @@ class RayTrainingActor:
# Zero out all the parameters. # Zero out all the parameters.
for name, p in self.model.named_parameters(): for name, p in self.model.named_parameters():
p.data.zero_() p.data.zero_()
torch.cuda.synchronize() torch.accelerator.synchronize()
# The argument for `get_device_uuid` is the index of the GPU in the # The argument for `get_device_uuid` is the index of the GPU in the
# list of visible devices. # list of visible devices.
from vllm.platforms import current_platform from vllm.platforms import current_platform
...@@ -151,7 +151,7 @@ class RayTrainingActor: ...@@ -151,7 +151,7 @@ class RayTrainingActor:
p.data.view(-1).view(dtype=torch.uint8), non_blocking=True p.data.view(-1).view(dtype=torch.uint8), non_blocking=True
) )
offset += get_size(p) offset += get_size(p)
torch.cuda.synchronize() torch.accelerator.synchronize()
s.send_pyobj(named_tensors) s.send_pyobj(named_tensors)
s.recv() s.recv()
s.send_pyobj(None) s.send_pyobj(None)
......
...@@ -120,7 +120,7 @@ class ColocateWorkerExtension: ...@@ -120,7 +120,7 @@ class ColocateWorkerExtension:
process_weights_after_loading( process_weights_after_loading(
self.model_runner.model, self.model_config, self.device self.model_runner.model, self.model_config, self.device
) )
torch.cuda.synchronize() torch.accelerator.synchronize()
socket.send(b"") socket.send(b"")
break break
if isinstance(payload, tuple): if isinstance(payload, tuple):
...@@ -144,7 +144,7 @@ class ColocateWorkerExtension: ...@@ -144,7 +144,7 @@ class ColocateWorkerExtension:
weights.append((item["name"], tensor)) weights.append((item["name"], tensor))
self.model_runner.model.load_weights(weights=weights) self.model_runner.model.load_weights(weights=weights)
del weights del weights
torch.cuda.synchronize() torch.accelerator.synchronize()
socket.send(b"") socket.send(b"")
socket.close() socket.close()
......
...@@ -100,7 +100,7 @@ def test_dynamic_shapes_compilation( ...@@ -100,7 +100,7 @@ def test_dynamic_shapes_compilation(
del model del model
gc.collect() gc.collect()
torch.accelerator.empty_cache() torch.accelerator.empty_cache()
torch.cuda.synchronize() torch.accelerator.synchronize()
print("GPU memory cleared") print("GPU memory cleared")
......
...@@ -32,7 +32,7 @@ pointers = CustomAllreduce.create_shared_buffer(buffer_size_in_bytes) ...@@ -32,7 +32,7 @@ pointers = CustomAllreduce.create_shared_buffer(buffer_size_in_bytes)
print(f"Rank {rank} has pointers {pointers}") print(f"Rank {rank} has pointers {pointers}")
dist.barrier() dist.barrier()
torch.cuda.synchronize() torch.accelerator.synchronize()
if rank == 0: if rank == 0:
# the first rank tries to write to all buffers # the first rank tries to write to all buffers
...@@ -41,7 +41,7 @@ if rank == 0: ...@@ -41,7 +41,7 @@ if rank == 0:
lib.cudaMemset(pointer, byte_value, buffer_size_in_bytes) lib.cudaMemset(pointer, byte_value, buffer_size_in_bytes)
dist.barrier() dist.barrier()
torch.cuda.synchronize() torch.accelerator.synchronize()
host_data = (ctypes.c_char * buffer_size_in_bytes)() host_data = (ctypes.c_char * buffer_size_in_bytes)()
...@@ -59,6 +59,6 @@ for p in pointers: ...@@ -59,6 +59,6 @@ for p in pointers:
print(f"Rank {rank} verified all buffers") print(f"Rank {rank} verified all buffers")
dist.barrier() dist.barrier()
torch.cuda.synchronize() torch.accelerator.synchronize()
CustomAllreduce.free_shared_buffer(pointers) CustomAllreduce.free_shared_buffer(pointers)
...@@ -48,7 +48,7 @@ def graph_allreduce( ...@@ -48,7 +48,7 @@ def graph_allreduce(
data = torch.zeros(1) data = torch.zeros(1)
data = data.to(device=device) data = data.to(device=device)
torch.distributed.all_reduce(data, group=group) torch.distributed.all_reduce(data, group=group)
torch.cuda.synchronize() torch.accelerator.synchronize()
del data del data
# we use the first group to communicate once # we use the first group to communicate once
...@@ -68,7 +68,7 @@ def graph_allreduce( ...@@ -68,7 +68,7 @@ def graph_allreduce(
inp2 = torch.randint( inp2 = torch.randint(
1, 16, (sz,), dtype=dtype, device=torch.cuda.current_device() 1, 16, (sz,), dtype=dtype, device=torch.cuda.current_device()
) )
torch.cuda.synchronize() torch.accelerator.synchronize()
graph = torch.cuda.CUDAGraph() graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, stream=graph_capture_context.stream): with torch.cuda.graph(graph, stream=graph_capture_context.stream):
for i in range(num_communication): for i in range(num_communication):
......
...@@ -68,7 +68,7 @@ def worker_fn(): ...@@ -68,7 +68,7 @@ def worker_fn():
) )
tensor = torch.ones(16, 1024, 1024, dtype=torch.float32).cuda(pynccl_comm.rank) tensor = torch.ones(16, 1024, 1024, dtype=torch.float32).cuda(pynccl_comm.rank)
tensor = pynccl_comm.all_reduce(tensor) tensor = pynccl_comm.all_reduce(tensor)
torch.cuda.synchronize() torch.accelerator.synchronize()
assert torch.all(tensor == pynccl_comm.world_size).cpu().item() assert torch.all(tensor == pynccl_comm.world_size).cpu().item()
...@@ -93,11 +93,11 @@ def multiple_allreduce_worker_fn(): ...@@ -93,11 +93,11 @@ def multiple_allreduce_worker_fn():
if torch.distributed.get_rank() in [0, 1]: if torch.distributed.get_rank() in [0, 1]:
tensor = pynccl_comm.all_reduce(tensor) tensor = pynccl_comm.all_reduce(tensor)
tensor = pynccl_comm.all_reduce(tensor) tensor = pynccl_comm.all_reduce(tensor)
torch.cuda.synchronize() torch.accelerator.synchronize()
assert torch.all(tensor == 4).cpu().item() assert torch.all(tensor == 4).cpu().item()
else: else:
tensor = pynccl_comm.all_reduce(tensor) tensor = pynccl_comm.all_reduce(tensor)
torch.cuda.synchronize() torch.accelerator.synchronize()
assert torch.all(tensor == 2).cpu().item() assert torch.all(tensor == 2).cpu().item()
...@@ -121,11 +121,11 @@ def multiple_allreduce_with_vllm_worker_fn(): ...@@ -121,11 +121,11 @@ def multiple_allreduce_with_vllm_worker_fn():
if torch.distributed.get_rank() in [0, 1]: if torch.distributed.get_rank() in [0, 1]:
tensor = tensor_model_parallel_all_reduce(tensor) tensor = tensor_model_parallel_all_reduce(tensor)
tensor = tensor_model_parallel_all_reduce(tensor) tensor = tensor_model_parallel_all_reduce(tensor)
torch.cuda.synchronize() torch.accelerator.synchronize()
assert torch.all(tensor == 4).cpu().item() assert torch.all(tensor == 4).cpu().item()
else: else:
tensor = tensor_model_parallel_all_reduce(tensor) tensor = tensor_model_parallel_all_reduce(tensor)
torch.cuda.synchronize() torch.accelerator.synchronize()
assert torch.all(tensor == 2).cpu().item() assert torch.all(tensor == 2).cpu().item()
...@@ -147,12 +147,12 @@ def worker_fn_with_cudagraph(): ...@@ -147,12 +147,12 @@ def worker_fn_with_cudagraph():
) )
# run something in the default stream to initialize torch engine # run something in the default stream to initialize torch engine
a = torch.ones((4, 4), device=f"cuda:{pynccl_comm.rank}") a = torch.ones((4, 4), device=f"cuda:{pynccl_comm.rank}")
torch.cuda.synchronize() torch.accelerator.synchronize()
with torch.cuda.graph(graph): with torch.cuda.graph(graph):
a_out = pynccl_comm.all_reduce(a) a_out = pynccl_comm.all_reduce(a)
torch.cuda.synchronize() torch.accelerator.synchronize()
graph.replay() graph.replay()
torch.cuda.synchronize() torch.accelerator.synchronize()
assert torch.all(a_out == pynccl_comm.world_size).cpu().item() assert torch.all(a_out == pynccl_comm.world_size).cpu().item()
...@@ -180,7 +180,7 @@ def all_gather_worker_fn(): ...@@ -180,7 +180,7 @@ def all_gather_worker_fn():
).to(device) ).to(device)
pynccl_comm.all_gather(result, tensor) pynccl_comm.all_gather(result, tensor)
torch.cuda.synchronize() torch.accelerator.synchronize()
torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8) torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)
...@@ -215,7 +215,7 @@ def all_gatherv_worker_fn(): ...@@ -215,7 +215,7 @@ def all_gatherv_worker_fn():
).to(device) ).to(device)
pynccl_comm.all_gatherv(result, tensor, sizes=sizes) pynccl_comm.all_gatherv(result, tensor, sizes=sizes)
torch.cuda.synchronize() torch.accelerator.synchronize()
torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8) torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)
...@@ -255,7 +255,7 @@ def reduce_scatter_worker_fn(): ...@@ -255,7 +255,7 @@ def reduce_scatter_worker_fn():
).to(device) ).to(device)
pynccl_comm.reduce_scatter(result, tensor) pynccl_comm.reduce_scatter(result, tensor)
torch.cuda.synchronize() torch.accelerator.synchronize()
torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8) torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)
...@@ -293,7 +293,7 @@ def reduce_scatterv_worker_fn(): ...@@ -293,7 +293,7 @@ def reduce_scatterv_worker_fn():
expected = sum(tensor[start:end] for tensor in all_tensors).to(device) expected = sum(tensor[start:end] for tensor in all_tensors).to(device)
pynccl_comm.reduce_scatterv(result, tensor, sizes=sizes) pynccl_comm.reduce_scatterv(result, tensor, sizes=sizes)
torch.cuda.synchronize() torch.accelerator.synchronize()
torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8) torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)
...@@ -325,7 +325,7 @@ def send_recv_worker_fn(): ...@@ -325,7 +325,7 @@ def send_recv_worker_fn():
pynccl_comm.send(tensor, dst=(pynccl_comm.rank + 1) % pynccl_comm.world_size) pynccl_comm.send(tensor, dst=(pynccl_comm.rank + 1) % pynccl_comm.world_size)
else: else:
pynccl_comm.recv(tensor, src=(pynccl_comm.rank - 1) % pynccl_comm.world_size) pynccl_comm.recv(tensor, src=(pynccl_comm.rank - 1) % pynccl_comm.world_size)
torch.cuda.synchronize() torch.accelerator.synchronize()
assert torch.all(tensor == 1).cpu().item() assert torch.all(tensor == 1).cpu().item()
...@@ -355,7 +355,7 @@ def multiple_send_recv_worker_fn(): ...@@ -355,7 +355,7 @@ def multiple_send_recv_worker_fn():
pynccl_comm.send(tensor, dst=(pynccl_comm.rank + 1) % pynccl_comm.world_size) pynccl_comm.send(tensor, dst=(pynccl_comm.rank + 1) % pynccl_comm.world_size)
else: else:
pynccl_comm.recv(tensor, src=(pynccl_comm.rank - 1) % pynccl_comm.world_size) pynccl_comm.recv(tensor, src=(pynccl_comm.rank - 1) % pynccl_comm.world_size)
torch.cuda.synchronize() torch.accelerator.synchronize()
if torch.distributed.get_rank() in [0, 2]: if torch.distributed.get_rank() in [0, 2]:
assert torch.all(tensor == 1).cpu().item() assert torch.all(tensor == 1).cpu().item()
else: else:
...@@ -396,7 +396,7 @@ def broadcast_worker_fn(): ...@@ -396,7 +396,7 @@ def broadcast_worker_fn():
pynccl_comm.broadcast(recv_tensors[i], src=i) pynccl_comm.broadcast(recv_tensors[i], src=i)
# the broadcast op might be launched in a different stream # the broadcast op might be launched in a different stream
# need to synchronize to make sure the tensor is ready # need to synchronize to make sure the tensor is ready
torch.cuda.synchronize() torch.accelerator.synchronize()
assert torch.all(recv_tensors[i] == i).cpu().item() assert torch.all(recv_tensors[i] == i).cpu().item()
......
...@@ -52,7 +52,7 @@ def graph_quickreduce( ...@@ -52,7 +52,7 @@ def graph_quickreduce(
data = torch.zeros(1) data = torch.zeros(1)
data = data.to(device=device) data = data.to(device=device)
torch.distributed.all_reduce(data, group=group) torch.distributed.all_reduce(data, group=group)
torch.cuda.synchronize() torch.accelerator.synchronize()
del data del data
# we use the first group to communicate once # we use the first group to communicate once
...@@ -71,7 +71,7 @@ def graph_quickreduce( ...@@ -71,7 +71,7 @@ def graph_quickreduce(
inp2 = torch.randint( inp2 = torch.randint(
-23, 1, (sz,), dtype=dtype, device=torch.cuda.current_device() -23, 1, (sz,), dtype=dtype, device=torch.cuda.current_device()
) )
torch.cuda.synchronize() torch.accelerator.synchronize()
graph = torch.cuda.CUDAGraph() graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, stream=graph_capture_context.stream): with torch.cuda.graph(graph, stream=graph_capture_context.stream):
for _ in range(num_communication): for _ in range(num_communication):
......
...@@ -79,11 +79,11 @@ def gpu_worker(rank, WORLD_SIZE, port1, port2): ...@@ -79,11 +79,11 @@ def gpu_worker(rank, WORLD_SIZE, port1, port2):
data = torch.tensor([rank]).cuda() data = torch.tensor([rank]).cuda()
pynccl1.all_reduce(data) pynccl1.all_reduce(data)
pg1.barrier() pg1.barrier()
torch.cuda.synchronize() torch.accelerator.synchronize()
if rank <= 2: if rank <= 2:
pynccl2.all_reduce(data) pynccl2.all_reduce(data)
pg2.barrier() pg2.barrier()
torch.cuda.synchronize() torch.accelerator.synchronize()
item = data[0].item() item = data[0].item()
print(f"rank: {rank}, item: {item}") print(f"rank: {rank}, item: {item}")
if rank == 3: if rank == 3:
......
...@@ -251,7 +251,7 @@ def trainer_broadcast_tensor( ...@@ -251,7 +251,7 @@ def trainer_broadcast_tensor(
dtype = getattr(torch, tensor_dtype) dtype = getattr(torch, tensor_dtype)
tensor_to_send = torch.ones(tensor_shape, dtype=dtype, device="cuda:0") tensor_to_send = torch.ones(tensor_shape, dtype=dtype, device="cuda:0")
comm.broadcast(tensor_to_send, src=0, stream=torch.cuda.current_stream()) comm.broadcast(tensor_to_send, src=0, stream=torch.cuda.current_stream())
torch.cuda.synchronize() torch.accelerator.synchronize()
return True return True
...@@ -309,7 +309,7 @@ def inference_receive_tensor( ...@@ -309,7 +309,7 @@ def inference_receive_tensor(
shapes=[tensor_shape], shapes=[tensor_shape],
) )
engine.receive_weights(update_info, noop_load_weights) engine.receive_weights(update_info, noop_load_weights)
torch.cuda.synchronize() torch.accelerator.synchronize()
# Verify we received the tensor # Verify we received the tensor
success = False success = False
...@@ -630,7 +630,7 @@ class TrainerActor: ...@@ -630,7 +630,7 @@ class TrainerActor:
ipc_handle = reduce_tensor(self.tensor) ipc_handle = reduce_tensor(self.tensor)
gpu_uuid = get_physical_gpu_id(0) gpu_uuid = get_physical_gpu_id(0)
torch.cuda.synchronize() torch.accelerator.synchronize()
self.ipc_handle_dict = { self.ipc_handle_dict = {
"ipc_handle": ipc_handle, "ipc_handle": ipc_handle,
...@@ -704,7 +704,7 @@ def inference_receive_ipc_tensor( ...@@ -704,7 +704,7 @@ def inference_receive_ipc_tensor(
update_info = engine.parse_update_info(update_dict) update_info = engine.parse_update_info(update_dict)
engine.receive_weights(update_info, noop_load_weights) engine.receive_weights(update_info, noop_load_weights)
torch.cuda.synchronize() torch.accelerator.synchronize()
# Verify we received the tensor # Verify we received the tensor
success = False success = False
......
...@@ -165,7 +165,7 @@ def test_merge_attn_states( ...@@ -165,7 +165,7 @@ def test_merge_attn_states(
suffix_lse_torch, suffix_lse_torch,
output_lse_torch, output_lse_torch,
) )
torch.cuda.synchronize() torch.accelerator.synchronize()
for _ in range(repeat_times): for _ in range(repeat_times):
start.record() start.record()
...@@ -178,7 +178,7 @@ def test_merge_attn_states( ...@@ -178,7 +178,7 @@ def test_merge_attn_states(
output_lse_torch, output_lse_torch,
) )
end.record() end.record()
torch.cuda.synchronize() torch.accelerator.synchronize()
total_time_torch_kernel += start.elapsed_time(end) total_time_torch_kernel += start.elapsed_time(end)
avg_time_torch_kernel = total_time_torch_kernel / repeat_times avg_time_torch_kernel = total_time_torch_kernel / repeat_times
...@@ -200,7 +200,7 @@ def test_merge_attn_states( ...@@ -200,7 +200,7 @@ def test_merge_attn_states(
suffix_lse, suffix_lse,
output_lse_ref_triton, output_lse_ref_triton,
) )
torch.cuda.synchronize() torch.accelerator.synchronize()
for _ in range(repeat_times): for _ in range(repeat_times):
start.record() start.record()
...@@ -213,7 +213,7 @@ def test_merge_attn_states( ...@@ -213,7 +213,7 @@ def test_merge_attn_states(
output_lse_ref_triton, output_lse_ref_triton,
) )
end.record() end.record()
torch.cuda.synchronize() torch.accelerator.synchronize()
total_time_triton_kernel += start.elapsed_time(end) total_time_triton_kernel += start.elapsed_time(end)
avg_time_triton_kernel = total_time_triton_kernel / repeat_times avg_time_triton_kernel = total_time_triton_kernel / repeat_times
...@@ -232,7 +232,7 @@ def test_merge_attn_states( ...@@ -232,7 +232,7 @@ def test_merge_attn_states(
suffix_lse, suffix_lse,
output_lse_cuda, output_lse_cuda,
) )
torch.cuda.synchronize() torch.accelerator.synchronize()
for _ in range(repeat_times): for _ in range(repeat_times):
start.record() start.record()
...@@ -245,7 +245,7 @@ def test_merge_attn_states( ...@@ -245,7 +245,7 @@ def test_merge_attn_states(
output_lse_cuda, output_lse_cuda,
) )
end.record() end.record()
torch.cuda.synchronize() torch.accelerator.synchronize()
total_time_cuda_kernel += start.elapsed_time(end) total_time_cuda_kernel += start.elapsed_time(end)
avg_time_cuda_kernel = total_time_cuda_kernel / repeat_times avg_time_cuda_kernel = total_time_cuda_kernel / repeat_times
......
...@@ -239,7 +239,7 @@ def test_contexted_kv_attention( ...@@ -239,7 +239,7 @@ def test_contexted_kv_attention(
v_scale, v_scale,
sliding_window=sliding_window, sliding_window=sliding_window,
) )
torch.cuda.synchronize() torch.accelerator.synchronize()
start_time = time.time() start_time = time.time()
op( op(
query, query,
...@@ -258,7 +258,7 @@ def test_contexted_kv_attention( ...@@ -258,7 +258,7 @@ def test_contexted_kv_attention(
v_scale, v_scale,
sliding_window=sliding_window, sliding_window=sliding_window,
) )
torch.cuda.synchronize() torch.accelerator.synchronize()
end_time = time.time() end_time = time.time()
print(f"triton Time: {(end_time - start_time) * 1000:.2f} ms") print(f"triton Time: {(end_time - start_time) * 1000:.2f} ms")
...@@ -298,7 +298,7 @@ def test_contexted_kv_attention( ...@@ -298,7 +298,7 @@ def test_contexted_kv_attention(
dropout_p=0.0, dropout_p=0.0,
scale=scale, scale=scale,
) )
torch.cuda.synchronize() torch.accelerator.synchronize()
start_time = time.time() start_time = time.time()
output_ref = F.scaled_dot_product_attention( output_ref = F.scaled_dot_product_attention(
query_sdpa, query_sdpa,
...@@ -308,7 +308,7 @@ def test_contexted_kv_attention( ...@@ -308,7 +308,7 @@ def test_contexted_kv_attention(
dropout_p=0.0, dropout_p=0.0,
scale=scale, scale=scale,
) )
torch.cuda.synchronize() torch.accelerator.synchronize()
end_time = time.time() end_time = time.time()
print(f"PyTorch SDPA Time: {(end_time - start_time) * 1000:.2f} ms") print(f"PyTorch SDPA Time: {(end_time - start_time) * 1000:.2f} ms")
...@@ -482,7 +482,7 @@ def test_contexted_kv_attention_alibi( ...@@ -482,7 +482,7 @@ def test_contexted_kv_attention_alibi(
v_scale, v_scale,
alibi_slopes=alibi_slopes, alibi_slopes=alibi_slopes,
) )
torch.cuda.synchronize() torch.accelerator.synchronize()
start_time = time.time() start_time = time.time()
op( op(
query, query,
...@@ -501,7 +501,7 @@ def test_contexted_kv_attention_alibi( ...@@ -501,7 +501,7 @@ def test_contexted_kv_attention_alibi(
v_scale, v_scale,
alibi_slopes=alibi_slopes, alibi_slopes=alibi_slopes,
) )
torch.cuda.synchronize() torch.accelerator.synchronize()
end_time = time.time() end_time = time.time()
print(f"triton Time: {(end_time - start_time) * 1000:.2f} ms") print(f"triton Time: {(end_time - start_time) * 1000:.2f} ms")
scale = float(1.0 / (head_size**0.5)) scale = float(1.0 / (head_size**0.5))
...@@ -517,7 +517,7 @@ def test_contexted_kv_attention_alibi( ...@@ -517,7 +517,7 @@ def test_contexted_kv_attention_alibi(
output_ref = torch.empty_like(output) output_ref = torch.empty_like(output)
torch.cuda.synchronize() torch.accelerator.synchronize()
start_time = time.time() start_time = time.time()
query_start = 0 query_start = 0
...@@ -572,7 +572,7 @@ def test_contexted_kv_attention_alibi( ...@@ -572,7 +572,7 @@ def test_contexted_kv_attention_alibi(
query_start = query_end query_start = query_end
key_start = key_end key_start = key_end
torch.cuda.synchronize() torch.accelerator.synchronize()
end_time = time.time() end_time = time.time()
print(f"PyTorch SDPA Time: {(end_time - start_time) * 1000:.2f} ms") print(f"PyTorch SDPA Time: {(end_time - start_time) * 1000:.2f} ms")
atol = 1e-3 if "fp8" in kv_cache_dtype else 1e-6 atol = 1e-3 if "fp8" in kv_cache_dtype else 1e-6
......
...@@ -127,7 +127,7 @@ def test_fused_rms_norm_quant( ...@@ -127,7 +127,7 @@ def test_fused_rms_norm_quant(
out_quant, x_unfused.contiguous(), quant_scale_t out_quant, x_unfused.contiguous(), quant_scale_t
) )
torch.cuda.synchronize() torch.accelerator.synchronize()
torch.testing.assert_close(residual_fused, residual, atol=1e-2, rtol=1e-2) torch.testing.assert_close(residual_fused, residual, atol=1e-2, rtol=1e-2)
opcheck( opcheck(
torch.ops._C.fused_add_rms_norm_static_fp8_quant, torch.ops._C.fused_add_rms_norm_static_fp8_quant,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment