Unverified commit ce2ef42f, authored by Andreas Karatzas and committed by GitHub
Browse files

[CI] Stabilize test_cpu_offloading by waiting for async offload before cache reset (#37335)


Signed-off-by: Andreas Karatzas <akaratza@amd.com>
parent 8b632575
...@@ -22,6 +22,17 @@ if current_platform.is_cuda(): ...@@ -22,6 +22,17 @@ if current_platform.is_cuda():
elif current_platform.is_rocm(): elif current_platform.is_rocm():
ATTN_BACKENDS = ["TRITON_ATTN"] ATTN_BACKENDS = ["TRITON_ATTN"]
# Platform-dependent timing knobs for the CPU offloading test.
#
# _RESET_CACHE_TIMEOUT: upper bound (seconds) on waiting for the async CPU
#     offload transfer to complete before giving up.
# _FIRST_EVENT_POLL_MS: ZMQ poll timeout (milliseconds) used while waiting
#     for the first event.
# ROCm gets the larger value for both.
if current_platform.is_rocm():
    _RESET_CACHE_TIMEOUT = 30
    _FIRST_EVENT_POLL_MS = 10_000
else:
    _RESET_CACHE_TIMEOUT = 10
    _FIRST_EVENT_POLL_MS = 1000

# Hard ceiling (seconds) on how long get_new_cpu_stored_events may keep
# looping; guards against a hang when non-CPU events arrive indefinitely.
_EVENT_DRAIN_TIMEOUT = 60
class MockSubscriber: class MockSubscriber:
"""Helper class to receive and verify published events""" """Helper class to receive and verify published events"""
...@@ -47,9 +58,10 @@ class MockSubscriber: ...@@ -47,9 +58,10 @@ class MockSubscriber:
poller = zmq.Poller() poller = zmq.Poller()
poller.register(self.sub, zmq.POLLIN) poller.register(self.sub, zmq.POLLIN)
timeout = 1000 # 1 second poll_ms = _FIRST_EVENT_POLL_MS
while True: deadline = time.monotonic() + _EVENT_DRAIN_TIMEOUT
events = dict(poller.poll(timeout)) while time.monotonic() < deadline:
events = dict(poller.poll(poll_ms))
if events.get(self.sub) != zmq.POLLIN: if events.get(self.sub) != zmq.POLLIN:
return cpu_stored_events return cpu_stored_events
...@@ -63,13 +75,32 @@ class MockSubscriber: ...@@ -63,13 +75,32 @@ class MockSubscriber:
for event in event_batch.events: for event in event_batch.events:
if isinstance(event, BlockStored) and event.medium == "CPU": if isinstance(event, BlockStored) and event.medium == "CPU":
cpu_stored_events.append(event) cpu_stored_events.append(event)
timeout = 100 poll_ms = 100
return cpu_stored_events
def close(self): def close(self):
"""Clean up resources""" """Clean up resources"""
self.sub.close() self.sub.close()
def _wait_for_prefix_cache_reset(llm: LLM) -> None:
    """Block until ``llm.reset_prefix_cache()`` reports success.

    The GPU-to-CPU offload is asynchronous, and ``reset_prefix_cache``
    returns ``False`` for as long as the offload worker still holds blocks.
    This helper polls it, sleeping 0.1 s between attempts.

    Args:
        llm: Engine whose prefix cache should be reset.

    Raises:
        TimeoutError: If a reset attempt fails after more than
            ``_RESET_CACHE_TIMEOUT`` seconds have elapsed since the call
            started.
    """
    give_up_at = time.monotonic() + _RESET_CACHE_TIMEOUT
    while True:
        if llm.reset_prefix_cache():
            # Offload finished and the cache was cleared.
            return
        # The deadline is only checked after a failed attempt, so at least
        # one reset is always tried.
        if time.monotonic() > give_up_at:
            raise TimeoutError(
                "reset_prefix_cache did not succeed within "
                f"{_RESET_CACHE_TIMEOUT}s - async offload may be stuck"
            )
        time.sleep(0.1)
def _latency_test(llm: LLM, subscriber: MockSubscriber): def _latency_test(llm: LLM, subscriber: MockSubscriber):
sampling_params = SamplingParams(max_tokens=1) sampling_params = SamplingParams(max_tokens=1)
...@@ -95,10 +126,16 @@ def _latency_test(llm: LLM, subscriber: MockSubscriber): ...@@ -95,10 +126,16 @@ def _latency_test(llm: LLM, subscriber: MockSubscriber):
gpu_hit_time = time.time() - start_time gpu_hit_time = time.time() - start_time
total_gpu_hit_time += gpu_hit_time total_gpu_hit_time += gpu_hit_time
# reset prefix cache to avoid GPU hit. # Wait for the async CPU offload to finish, then reset prefix cache
llm.reset_prefix_cache() # so the next generate() must reload from CPU rather than GPU.
_wait_for_prefix_cache_reset(llm)
assert subscriber.get_new_cpu_stored_events() # Verify CPU stored events arrived (offload is done before we
# attempt to load from CPU).
assert subscriber.get_new_cpu_stored_events(), (
f"No CPU stored events received on iteration {i}; "
"async offload may not have completed in time"
)
# run generation again - this should trigger loading from CPU # run generation again - this should trigger loading from CPU
start_time = time.time() start_time = time.time()
...@@ -185,6 +222,8 @@ def test_cpu_offloading(cpu_block_size: int, attn_backend: str) -> None: ...@@ -185,6 +222,8 @@ def test_cpu_offloading(cpu_block_size: int, attn_backend: str) -> None:
kv_events_config=kv_events_config, kv_events_config=kv_events_config,
kv_transfer_config=kv_transfer_config, kv_transfer_config=kv_transfer_config,
attention_config={"backend": attn_backend}, attention_config={"backend": attn_backend},
# ROCm: batch size 1 to reduce variability
**({"max_num_seqs": 1} if current_platform.is_rocm() else {}),
) )
events_endpoint = events_endpoint.replace("*", "127.0.0.1") events_endpoint = events_endpoint.replace("*", "127.0.0.1")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment