Unverified commit a55b2433, authored by Schwinn Saereesitthipitak, committed by GitHub
Browse files

fix: remove unnecessary cuda synchronize calls in GMS adapters (#6362)

parent c5c6a551
...@@ -314,7 +314,8 @@ class BaseWorkerHandler(ABC): ...@@ -314,7 +314,8 @@ class BaseWorkerHandler(ABC):
Order of operations: Order of operations:
1. Unregister from discovery - stop accepting new requests 1. Unregister from discovery - stop accepting new requests
2. Sleep engine - safe now that no new requests will arrive 2. Abort and drain in-flight requests
3. Sleep engine - safe now that GPU is quiesced
""" """
level = body.get("level", 1) level = body.get("level", 1)
try: try:
...@@ -329,7 +330,11 @@ class BaseWorkerHandler(ABC): ...@@ -329,7 +330,11 @@ class BaseWorkerHandler(ABC):
f"[Sleep] Failed to unregister endpoint from discovery: {unreg_err}" f"[Sleep] Failed to unregister endpoint from discovery: {unreg_err}"
) )
# Step 2: Now safe to sleep - no new requests will be routed here # Step 2: Abort in-flight requests and wait for them to drain so the
# GPU is fully quiesced before unmapping memory.
await self.engine_client.pause_generation()
# Step 3: Now safe to sleep - no in-flight GPU work
await self.engine_client.sleep(level) await self.engine_client.sleep(level)
return {"status": "ok", "message": f"Engine slept (level={level})"} return {"status": "ok", "message": f"Engine slept (level={level})"}
...@@ -352,7 +357,10 @@ class BaseWorkerHandler(ABC): ...@@ -352,7 +357,10 @@ class BaseWorkerHandler(ABC):
# Step 1: Wake engine first - must be ready before accepting requests # Step 1: Wake engine first - must be ready before accepting requests
await self.engine_client.wake_up(tags) await self.engine_client.wake_up(tags)
# Step 2: Re-register endpoint instance to discovery so frontend can route to us again # Step 2: Resume generation so new requests can be processed
await self.engine_client.resume_generation()
# Step 3: Re-register endpoint instance to discovery so frontend can route to us again
try: try:
await self.generate_endpoint.register_endpoint_instance() await self.generate_endpoint.register_endpoint_instance()
logger.info( logger.info(
......
...@@ -403,11 +403,6 @@ class GMSClientMemoryManager: ...@@ -403,11 +403,6 @@ class GMSClientMemoryManager:
self._unmap_preserving_va() self._unmap_preserving_va()
self._va_preserved = True self._va_preserved = True
# Ensure all CUDA VMM unmap operations complete before releasing the lock.
# This prevents race conditions where remap() may be called before
# physical memory is fully released.
synchronize()
self._client_rpc.close() self._client_rpc.close()
self._client = None self._client = None
self._unmapped = True self._unmapped = True
......
...@@ -136,8 +136,6 @@ class GMSMemorySaverImpl: ...@@ -136,8 +136,6 @@ class GMSMemorySaverImpl:
self._pause_weights() self._pause_weights()
if tag is None or not self._is_weights_tag(tag): if tag is None or not self._is_weights_tag(tag):
self._torch_impl.pause(tag=tag) self._torch_impl.pause(tag=tag)
# Ensure KV cache unmap operations complete before returning.
torch.cuda.synchronize()
def resume(self, tag: Optional[str] = None) -> None: def resume(self, tag: Optional[str] = None) -> None:
if self._disabled: if self._disabled:
...@@ -146,9 +144,6 @@ class GMSMemorySaverImpl: ...@@ -146,9 +144,6 @@ class GMSMemorySaverImpl:
self._resume_weights() self._resume_weights()
if tag is None or not self._is_weights_tag(tag): if tag is None or not self._is_weights_tag(tag):
self._torch_impl.resume(tag=tag) self._torch_impl.resume(tag=tag)
# Ensure KV cache mappings are complete before returning.
# Without this sync, inference may start before mappings are ready.
torch.cuda.synchronize()
def _pause_weights(self) -> None: def _pause_weights(self) -> None:
if self._allocator is None: if self._allocator is None:
...@@ -157,10 +152,6 @@ class GMSMemorySaverImpl: ...@@ -157,10 +152,6 @@ class GMSMemorySaverImpl:
return return
logger.info("[GMS] Unmapping weights (VA-stable)") logger.info("[GMS] Unmapping weights (VA-stable)")
self._allocator.unmap() self._allocator.unmap()
# Ensure all CUDA VMM unmap operations complete before returning.
# Without this sync, resume() may race with pending unmaps, causing OOM
# when it tries to allocate new memory while old memory is still mapped.
torch.cuda.synchronize()
def _resume_weights(self) -> None: def _resume_weights(self) -> None:
if self._allocator is None: if self._allocator is None:
...@@ -169,7 +160,6 @@ class GMSMemorySaverImpl: ...@@ -169,7 +160,6 @@ class GMSMemorySaverImpl:
return return
logger.info("[GMS] Remapping weights (VA-stable)") logger.info("[GMS] Remapping weights (VA-stable)")
self._allocator.remap() self._allocator.remap()
torch.cuda.synchronize()
def finalize_write_mode(self, model: torch.nn.Module) -> None: def finalize_write_mode(self, model: torch.nn.Module) -> None:
"""Finalize write mode: register tensors, commit, and switch to read.""" """Finalize write mode: register tensors, commit, and switch to read."""
......
...@@ -119,11 +119,6 @@ class GMSWorker(Worker): ...@@ -119,11 +119,6 @@ class GMSWorker(Worker):
allocator = CuMemAllocator.get_instance() allocator = CuMemAllocator.get_instance()
allocator.sleep(offload_tags=tuple()) allocator.sleep(offload_tags=tuple())
# Ensure all CUDA VMM unmap operations complete before returning.
# Without this sync, wake_up() may race with pending unmaps, causing OOM
# when it tries to allocate new memory while old memory is still mapped.
torch.cuda.synchronize()
free_bytes_after, total = torch.cuda.mem_get_info() free_bytes_after, total = torch.cuda.mem_get_info()
freed_bytes = free_bytes_after - free_bytes_before freed_bytes = free_bytes_after - free_bytes_before
used_bytes = total - free_bytes_after used_bytes = total - free_bytes_after
...@@ -145,16 +140,11 @@ class GMSWorker(Worker): ...@@ -145,16 +140,11 @@ class GMSWorker(Worker):
assert manager is not None, "GMS client is not initialized" assert manager is not None, "GMS client is not initialized"
assert manager.is_unmapped, "GMS weights are not unmapped" assert manager.is_unmapped, "GMS weights are not unmapped"
manager.remap() manager.remap()
torch.cuda.synchronize()
if "kv_cache" in tags: if "kv_cache" in tags:
allocator = CuMemAllocator.get_instance() allocator = CuMemAllocator.get_instance()
allocator.wake_up(tags=["kv_cache"]) allocator.wake_up(tags=["kv_cache"])
# Ensure KV cache mappings are complete before returning.
# Without this sync, inference may start before mappings are ready.
torch.cuda.synchronize()
# Reinitialize FP8 KV scales if needed # Reinitialize FP8 KV scales if needed
if self.cache_config.cache_dtype.startswith("fp8") and hasattr( if self.cache_config.cache_dtype.startswith("fp8") and hasattr(
self.model_runner, "init_fp8_kv_scales" self.model_runner, "init_fp8_kv_scales"
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment