Unverified commit a55b2433, authored by Schwinn Saereesitthipitak, committed by GitHub
Browse files

fix: remove unnecessary cuda synchronize calls in GMS adapters (#6362)

parent c5c6a551
......@@ -314,7 +314,8 @@ class BaseWorkerHandler(ABC):
Order of operations:
1. Unregister from discovery - stop accepting new requests
2. Sleep engine - safe now that no new requests will arrive
2. Abort and drain in-flight requests
3. Sleep engine - safe now that GPU is quiesced
"""
level = body.get("level", 1)
try:
......@@ -329,7 +330,11 @@ class BaseWorkerHandler(ABC):
f"[Sleep] Failed to unregister endpoint from discovery: {unreg_err}"
)
# Step 2: Now safe to sleep - no new requests will be routed here
# Step 2: Abort in-flight requests and wait for them to drain so the
# GPU is fully quiesced before unmapping memory.
await self.engine_client.pause_generation()
# Step 3: Now safe to sleep - no in-flight GPU work
await self.engine_client.sleep(level)
return {"status": "ok", "message": f"Engine slept (level={level})"}
......@@ -352,7 +357,10 @@ class BaseWorkerHandler(ABC):
# Step 1: Wake engine first - must be ready before accepting requests
await self.engine_client.wake_up(tags)
# Step 2: Re-register endpoint instance to discovery so frontend can route to us again
# Step 2: Resume generation so new requests can be processed
await self.engine_client.resume_generation()
# Step 3: Re-register endpoint instance to discovery so frontend can route to us again
try:
await self.generate_endpoint.register_endpoint_instance()
logger.info(
......
......@@ -403,11 +403,6 @@ class GMSClientMemoryManager:
self._unmap_preserving_va()
self._va_preserved = True
# Ensure all CUDA VMM unmap operations complete before releasing the lock.
# This prevents race conditions where remap() may be called before
# physical memory is fully released.
synchronize()
self._client_rpc.close()
self._client = None
self._unmapped = True
......
......@@ -136,8 +136,6 @@ class GMSMemorySaverImpl:
self._pause_weights()
if tag is None or not self._is_weights_tag(tag):
self._torch_impl.pause(tag=tag)
# Ensure KV cache unmap operations complete before returning.
torch.cuda.synchronize()
def resume(self, tag: Optional[str] = None) -> None:
if self._disabled:
......@@ -146,9 +144,6 @@ class GMSMemorySaverImpl:
self._resume_weights()
if tag is None or not self._is_weights_tag(tag):
self._torch_impl.resume(tag=tag)
# Ensure KV cache mappings are complete before returning.
# Without this sync, inference may start before mappings are ready.
torch.cuda.synchronize()
def _pause_weights(self) -> None:
if self._allocator is None:
......@@ -157,10 +152,6 @@ class GMSMemorySaverImpl:
return
logger.info("[GMS] Unmapping weights (VA-stable)")
self._allocator.unmap()
# Ensure all CUDA VMM unmap operations complete before returning.
# Without this sync, resume() may race with pending unmaps, causing OOM
# when it tries to allocate new memory while old memory is still mapped.
torch.cuda.synchronize()
def _resume_weights(self) -> None:
if self._allocator is None:
......@@ -169,7 +160,6 @@ class GMSMemorySaverImpl:
return
logger.info("[GMS] Remapping weights (VA-stable)")
self._allocator.remap()
torch.cuda.synchronize()
def finalize_write_mode(self, model: torch.nn.Module) -> None:
"""Finalize write mode: register tensors, commit, and switch to read."""
......
......@@ -119,11 +119,6 @@ class GMSWorker(Worker):
allocator = CuMemAllocator.get_instance()
allocator.sleep(offload_tags=tuple())
# Ensure all CUDA VMM unmap operations complete before returning.
# Without this sync, wake_up() may race with pending unmaps, causing OOM
# when it tries to allocate new memory while old memory is still mapped.
torch.cuda.synchronize()
free_bytes_after, total = torch.cuda.mem_get_info()
freed_bytes = free_bytes_after - free_bytes_before
used_bytes = total - free_bytes_after
......@@ -145,16 +140,11 @@ class GMSWorker(Worker):
assert manager is not None, "GMS client is not initialized"
assert manager.is_unmapped, "GMS weights are not unmapped"
manager.remap()
torch.cuda.synchronize()
if "kv_cache" in tags:
allocator = CuMemAllocator.get_instance()
allocator.wake_up(tags=["kv_cache"])
# Ensure KV cache mappings are complete before returning.
# Without this sync, inference may start before mappings are ready.
torch.cuda.synchronize()
# Reinitialize FP8 KV scales if needed
if self.cache_config.cache_dtype.startswith("fp8") and hasattr(
self.model_runner, "init_fp8_kv_scales"
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment