Unverified Commit 2a2d5d27 authored by Kunshang Ji's avatar Kunshang Ji Committed by GitHub
Browse files

Replace `torch.cuda.Event` with `torch.Event` for better hardware compatibility (#26985)


Signed-off-by: default avatarKunshang Ji <kunshang.ji@intel.com>
parent c3e29786
...@@ -255,8 +255,8 @@ def bench_run( ...@@ -255,8 +255,8 @@ def bench_run(
torch.cuda.synchronize() torch.cuda.synchronize()
# Timing # Timing
start_event = torch.cuda.Event(enable_timing=True) start_event = torch.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True) end_event = torch.Event(enable_timing=True)
latencies = [] latencies = []
for _ in range(num_iters): for _ in range(num_iters):
......
...@@ -185,8 +185,8 @@ def benchmark_config( ...@@ -185,8 +185,8 @@ def benchmark_config(
graph.replay() graph.replay()
torch.cuda.synchronize() torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True) start_event = torch.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True) end_event = torch.Event(enable_timing=True)
latencies: list[float] = [] latencies: list[float] = []
for i in range(num_iters): for i in range(num_iters):
......
...@@ -105,8 +105,8 @@ def benchmark_permute( ...@@ -105,8 +105,8 @@ def benchmark_permute(
graph.replay() graph.replay()
torch.cuda.synchronize() torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True) start_event = torch.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True) end_event = torch.Event(enable_timing=True)
latencies: list[float] = [] latencies: list[float] = []
for i in range(num_iters): for i in range(num_iters):
...@@ -241,8 +241,8 @@ def benchmark_unpermute( ...@@ -241,8 +241,8 @@ def benchmark_unpermute(
graph.replay() graph.replay()
torch.cuda.synchronize() torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True) start_event = torch.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True) end_event = torch.Event(enable_timing=True)
latencies: list[float] = [] latencies: list[float] = []
for i in range(num_iters): for i in range(num_iters):
......
...@@ -30,8 +30,8 @@ def _time_cuda( ...@@ -30,8 +30,8 @@ def _time_cuda(
fn() fn()
torch.cuda.synchronize() torch.cuda.synchronize()
start = torch.cuda.Event(enable_timing=True) start = torch.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True) end = torch.Event(enable_timing=True)
start.record() start.record()
for _ in range(bench_iters): for _ in range(bench_iters):
......
...@@ -253,8 +253,8 @@ def benchmark( ...@@ -253,8 +253,8 @@ def benchmark(
) )
torch.cuda.synchronize() torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True) start_event = torch.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True) end_event = torch.Event(enable_timing=True)
# Benchmark # Benchmark
latencies: list[float] = [] latencies: list[float] = []
......
...@@ -127,8 +127,8 @@ def benchmark_decode( ...@@ -127,8 +127,8 @@ def benchmark_decode(
def time_fn(fn, warmup=10, trials=20): def time_fn(fn, warmup=10, trials=20):
torch.cuda.synchronize() torch.cuda.synchronize()
start = torch.cuda.Event(enable_timing=True) start = torch.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True) end = torch.Event(enable_timing=True)
times = [] times = []
for i in range(warmup): for i in range(warmup):
fn() fn()
......
...@@ -139,8 +139,8 @@ def benchmark_prefill( ...@@ -139,8 +139,8 @@ def benchmark_prefill(
def time_fn(fn, warmup=10, trials=20): def time_fn(fn, warmup=10, trials=20):
torch.cuda.synchronize() torch.cuda.synchronize()
start = torch.cuda.Event(enable_timing=True) start = torch.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True) end = torch.Event(enable_timing=True)
times = [] times = []
for i in range(warmup): for i in range(warmup):
fn() fn()
......
...@@ -183,8 +183,8 @@ def benchmark_config( ...@@ -183,8 +183,8 @@ def benchmark_config(
run() run()
torch.cuda.synchronize() torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True) start_event = torch.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True) end_event = torch.Event(enable_timing=True)
latencies: list[float] = [] latencies: list[float] = []
for i in range(num_iters): for i in range(num_iters):
......
...@@ -150,8 +150,8 @@ def test_merge_attn_states( ...@@ -150,8 +150,8 @@ def test_merge_attn_states(
output_torch = output.clone() output_torch = output.clone()
output_lse_torch = output_lse.clone() output_lse_torch = output_lse.clone()
total_time_torch_kernel = 0 total_time_torch_kernel = 0
start = torch.cuda.Event(enable_timing=True) start = torch.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True) end = torch.Event(enable_timing=True)
# 0. Run the Torch kernel # 0. Run the Torch kernel
prefix_lse_torch = prefix_lse.clone() prefix_lse_torch = prefix_lse.clone()
...@@ -188,8 +188,8 @@ def test_merge_attn_states( ...@@ -188,8 +188,8 @@ def test_merge_attn_states(
output_lse_ref_triton = output_lse.clone() output_lse_ref_triton = output_lse.clone()
total_time_triton_kernel = 0 total_time_triton_kernel = 0
start = torch.cuda.Event(enable_timing=True) start = torch.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True) end = torch.Event(enable_timing=True)
for _ in range(warmup_times): for _ in range(warmup_times):
merge_attn_states_triton( merge_attn_states_triton(
......
...@@ -68,9 +68,9 @@ class CpuGpuOffloadingHandler(OffloadingHandler): ...@@ -68,9 +68,9 @@ class CpuGpuOffloadingHandler(OffloadingHandler):
self.h2d_stream = torch.cuda.Stream() self.h2d_stream = torch.cuda.Stream()
# job_id -> transfer cuda event # job_id -> transfer cuda event
self.transfer_events: dict[int, torch.cuda.Event] = {} self.transfer_events: dict[int, torch.Event] = {}
# list of cuda events available for re-use # list of cuda events available for re-use
self.events_pool: list[torch.cuda.Event] = [] self.events_pool: list[torch.Event] = []
pin_memory = is_pin_memory_available() pin_memory = is_pin_memory_available()
...@@ -153,7 +153,7 @@ class CpuGpuOffloadingHandler(OffloadingHandler): ...@@ -153,7 +153,7 @@ class CpuGpuOffloadingHandler(OffloadingHandler):
) )
src_to_dst_tensor = torch.from_numpy(src_to_dst) src_to_dst_tensor = torch.from_numpy(src_to_dst)
event = self.events_pool.pop() if self.events_pool else torch.cuda.Event() event = self.events_pool.pop() if self.events_pool else torch.Event()
with torch.cuda.stream(stream): with torch.cuda.stream(stream):
for src_tensor, dst_tensor, kv_dim in zip( for src_tensor, dst_tensor, kv_dim in zip(
src_tensors, dst_tensors, self.kv_dim_before_num_blocks src_tensors, dst_tensors, self.kv_dim_before_num_blocks
......
...@@ -96,14 +96,14 @@ def _torch_cuda_wrapper(): ...@@ -96,14 +96,14 @@ def _torch_cuda_wrapper():
def __init__(self, *args, **kwargs) -> None: def __init__(self, *args, **kwargs) -> None:
pass pass
cuda_event = torch.cuda.Event cuda_event = torch.Event
cuda_stream = torch.cuda.Stream cuda_stream = torch.cuda.Stream
try: try:
torch.cuda.Event = _EventPlaceholder torch.Event = _EventPlaceholder
torch.cuda.Stream = _StreamPlaceholder torch.cuda.Stream = _StreamPlaceholder
yield yield
finally: finally:
torch.cuda.Event = cuda_event torch.Event = cuda_event
torch.cuda.Stream = cuda_stream torch.cuda.Stream = cuda_stream
......
...@@ -265,7 +265,7 @@ class InputBatch: ...@@ -265,7 +265,7 @@ class InputBatch:
# ids from prior step, if required by current sampling params # ids from prior step, if required by current sampling params
# (e.g. penalties). # (e.g. penalties).
self.sampled_token_ids_cpu: torch.Tensor | None = None self.sampled_token_ids_cpu: torch.Tensor | None = None
self.async_copy_ready_event: torch.cuda.Event | None = None self.async_copy_ready_event: torch.Event | None = None
@property @property
def req_ids(self) -> list[str]: def req_ids(self) -> list[str]:
...@@ -891,7 +891,7 @@ class InputBatch: ...@@ -891,7 +891,7 @@ class InputBatch:
def set_async_sampled_token_ids( def set_async_sampled_token_ids(
self, self,
sampled_token_ids_cpu: torch.Tensor, sampled_token_ids_cpu: torch.Tensor,
async_copy_ready_event: torch.cuda.Event, async_copy_ready_event: torch.Event,
) -> None: ) -> None:
""" """
In async scheduling case, store ref to sampled_token_ids_cpu In async scheduling case, store ref to sampled_token_ids_cpu
......
...@@ -185,7 +185,7 @@ class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput): ...@@ -185,7 +185,7 @@ class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput):
self._invalid_req_indices = invalid_req_indices self._invalid_req_indices = invalid_req_indices
# Event on the copy stream so we can synchronize the non-blocking copy. # Event on the copy stream so we can synchronize the non-blocking copy.
self.async_copy_ready_event = torch.cuda.Event() self.async_copy_ready_event = torch.Event()
# Keep a reference to the device tensor to avoid it being # Keep a reference to the device tensor to avoid it being
# deallocated until we finish copying it to the host. # deallocated until we finish copying it to the host.
...@@ -435,10 +435,10 @@ class GPUModelRunner( ...@@ -435,10 +435,10 @@ class GPUModelRunner(
self.async_output_copy_stream: torch.cuda.Stream | None = None self.async_output_copy_stream: torch.cuda.Stream | None = None
# cuda event to synchronize use of reused CPU tensors between steps # cuda event to synchronize use of reused CPU tensors between steps
# when async scheduling is enabled. # when async scheduling is enabled.
self.prepare_inputs_event: torch.cuda.Event | None = None self.prepare_inputs_event: torch.Event | None = None
if self.use_async_scheduling: if self.use_async_scheduling:
self.async_output_copy_stream = torch.cuda.Stream() self.async_output_copy_stream = torch.cuda.Stream()
self.prepare_inputs_event = torch.cuda.Event() self.prepare_inputs_event = torch.Event()
# self.cudagraph_batch_sizes sorts in ascending order. # self.cudagraph_batch_sizes sorts in ascending order.
if ( if (
...@@ -549,7 +549,7 @@ class GPUModelRunner( ...@@ -549,7 +549,7 @@ class GPUModelRunner(
# Cached outputs. # Cached outputs.
self._draft_token_ids: list[list[int]] | torch.Tensor | None = None self._draft_token_ids: list[list[int]] | torch.Tensor | None = None
self.transfer_event = torch.cuda.Event() self.transfer_event = torch.Event()
self.sampled_token_ids_pinned_cpu = torch.empty( self.sampled_token_ids_pinned_cpu = torch.empty(
(self.max_num_reqs, 1), (self.max_num_reqs, 1),
dtype=torch.int64, dtype=torch.int64,
...@@ -559,10 +559,10 @@ class GPUModelRunner( ...@@ -559,10 +559,10 @@ class GPUModelRunner(
# Pre-allocated tensor for copying valid sampled token counts to CPU, # Pre-allocated tensor for copying valid sampled token counts to CPU,
# with dedicated stream for overlapping and event for coordination. # with dedicated stream for overlapping and event for coordination.
self.valid_sampled_token_count_event: torch.cuda.Event | None = None self.valid_sampled_token_count_event: torch.Event | None = None
self.valid_sampled_token_count_copy_stream: torch.cuda.Stream | None = None self.valid_sampled_token_count_copy_stream: torch.cuda.Stream | None = None
if self.use_async_scheduling and self.num_spec_tokens: if self.use_async_scheduling and self.num_spec_tokens:
self.valid_sampled_token_count_event = torch.cuda.Event() self.valid_sampled_token_count_event = torch.Event()
self.valid_sampled_token_count_copy_stream = torch.cuda.Stream() self.valid_sampled_token_count_copy_stream = torch.cuda.Stream()
self.valid_sampled_token_count_cpu = torch.empty( self.valid_sampled_token_count_cpu = torch.empty(
self.max_num_reqs, self.max_num_reqs,
......
...@@ -27,8 +27,8 @@ class UBatchContext: ...@@ -27,8 +27,8 @@ class UBatchContext:
ready_barrier: threading.Barrier, ready_barrier: threading.Barrier,
cpu_wait_event: threading.Event, cpu_wait_event: threading.Event,
cpu_signal_event: threading.Event, cpu_signal_event: threading.Event,
gpu_comm_done_event: torch.cuda.Event, gpu_comm_done_event: torch.Event,
gpu_compute_done_event: torch.cuda.Event, gpu_compute_done_event: torch.Event,
schedule: str = "default", schedule: str = "default",
): ):
self.id = id self.id = id
...@@ -207,8 +207,8 @@ def make_ubatch_contexts( ...@@ -207,8 +207,8 @@ def make_ubatch_contexts(
Create a context manager for micro-batching synchronization. Create a context manager for micro-batching synchronization.
""" """
cpu_events = [threading.Event() for _ in range(num_micro_batches)] cpu_events = [threading.Event() for _ in range(num_micro_batches)]
gpu_comm_done_events = [torch.cuda.Event() for _ in range(num_micro_batches)] gpu_comm_done_events = [torch.Event() for _ in range(num_micro_batches)]
gpu_compute_done_events = [torch.cuda.Event() for _ in range(num_micro_batches)] gpu_compute_done_events = [torch.Event() for _ in range(num_micro_batches)]
assert len(forward_contexts) == 2 assert len(forward_contexts) == 2
......
...@@ -37,19 +37,12 @@ class XPUModelRunner(GPUModelRunner): ...@@ -37,19 +37,12 @@ class XPUModelRunner(GPUModelRunner):
@contextmanager @contextmanager
def _torch_cuda_wrapper(): def _torch_cuda_wrapper():
class _EventPlaceholder:
def __init__(self, *args, **kwargs) -> None:
self.record = lambda: None
self.synchronize = lambda: None
try: try:
# replace cuda APIs with xpu APIs, this should work by default # replace cuda APIs with xpu APIs, this should work by default
torch.cuda.Event = torch.xpu.Event
torch.cuda.Stream = torch.xpu.Stream torch.cuda.Stream = torch.xpu.Stream
torch.cuda.default_stream = torch.xpu.current_stream torch.cuda.default_stream = torch.xpu.current_stream
torch.cuda.current_stream = torch.xpu.current_stream torch.cuda.current_stream = torch.xpu.current_stream
torch.cuda.stream = torch.xpu.stream torch.cuda.stream = torch.xpu.stream
yield yield
finally: finally:
# if anything goes wrong, just patch it with a placeholder pass
torch.cuda.Event = _EventPlaceholder
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment