Unverified commit f976f02b, authored by Schwinn Saereesitthipitak, committed by GitHub
Browse files

refactor: rename sleep/wake endpoints for consistency (#5629)

parent 3d8c497e
......@@ -303,7 +303,7 @@ class BaseWorkerHandler(ABC):
logger.error(f"Failed to sleep engine: {e}")
return {"status": "error", "message": str(e)}
async def wake(self, body: dict) -> dict:
async def wake_up(self, body: dict) -> dict:
"""Wake the engine to restore GPU memory and re-register to discovery.
Args:
......@@ -331,7 +331,7 @@ class BaseWorkerHandler(ABC):
return {"status": "ok", "message": f"Engine woke (tags={tags})"}
except Exception as e:
logger.error(f"Failed to wake engine: {e}")
logger.error(f"Failed to wake up engine: {e}")
return {"status": "error", "message": str(e)}
@abstractmethod
......
......@@ -460,10 +460,10 @@ async def init_prefill(runtime: DistributedRuntime, config: Config):
setup_metrics_collection(config, generate_endpoint, logger)
# Register sleep/wake engine routes
# Register sleep/wake_up engine routes
runtime.register_engine_route("sleep", handler.sleep)
runtime.register_engine_route("wake", handler.wake)
logger.info("Registered engine routes: /engine/sleep, /engine/wake")
runtime.register_engine_route("wake_up", handler.wake_up)
logger.info("Registered engine routes: /engine/sleep, /engine/wake_up")
# Handle non-leader nodes - don't serve endpoints
if config.engine_args.data_parallel_rank:
......@@ -585,10 +585,10 @@ async def init(runtime: DistributedRuntime, config: Config):
setup_metrics_collection(config, generate_endpoint, logger)
# Register sleep/wake engine routes
# Register sleep/wake_up engine routes
runtime.register_engine_route("sleep", handler.sleep)
runtime.register_engine_route("wake", handler.wake)
logger.info("Registered engine routes: /engine/sleep, /engine/wake")
runtime.register_engine_route("wake_up", handler.wake_up)
logger.info("Registered engine routes: /engine/sleep, /engine/wake_up")
# Handle non-leader nodes - don't serve endpoints
if config.engine_args.data_parallel_rank:
......
......@@ -10,7 +10,7 @@ Key properties:
- The socket connection itself is the RW/RO lock.
- In write mode, the manager can allocate + map RW and then publish via commit().
- In read mode, the manager can import + map RO and hold the RO lock during inference.
- sleep()/wake() releases and reacquires the RO lock (and remaps allocations).
- unmap()/remap() releases and reacquires the RO lock (and remaps allocations).
This module uses cuda-python bindings for CUDA driver API calls:
- import FDs (cuMemImportFromShareableHandle)
......@@ -47,20 +47,20 @@ logger = logging.getLogger(__name__)
class StaleMemoryLayoutError(Exception):
"""Raised when memory layout was modified while sleeping.
"""Raised when memory layout was modified while unmapped.
This error indicates that a writer acquired the RW lock and changed the
allocation structure (different sizes, different tensor layouts) while this
reader was sleeping. The caller should re-import the model from scratch.
reader was unmapped. The caller should re-import the model from scratch.
IMPORTANT: This is a LAYOUT check, NOT a CONTENT check.
- Detected: Allocation sizes changed, tensors added/removed, metadata structure changed
- NOT detected: Weight values modified in-place
This design is intentional: sleep/wake enables use cases like RL training
This design is intentional: unmap/remap enables use cases like RL training
where another process can write to the same memory locations (e.g., updating
weights) while preserving the structure. As long as the layout (allocation
and metadata table hashes) remains identical, wake() succeeds.
and metadata table hashes) remains identical, remap() succeeds.
"""
pass
......@@ -106,7 +106,7 @@ class GMSClientMemoryManager:
Modes:
- mode=RequestedLockType.RW: acquire RW lock, allocate/map RW, mutate metadata, commit/publish.
- mode=RequestedLockType.RO: acquire RO lock (READY only), import/map RO, sleep/wake.
- mode=RequestedLockType.RO: acquire RO lock (READY only), import/map RO, unmap/remap.
- mode=RequestedLockType.RW_OR_RO: try RW if available, else wait for RO.
"""
......@@ -126,13 +126,13 @@ class GMSClientMemoryManager:
self._mappings: Dict[int, LocalMapping] = {} # va -> mapping
self._allocation_id_to_va: Dict[str, int] = {}
self._sleeping = False
self._unmapped = False
self._closed = False
self._preserved_allocation_ids: List[str] = []
self._published = False
self._mode: Optional[GrantedLockType] = None # Updated by _connect
# VA-stable sleep/wake state
# VA-stable unmap/remap state
self._va_preserved = False
self._last_memory_layout_hash: str = (
"" # Hash from server, saved on connect/commit
......@@ -157,10 +157,10 @@ class GMSClientMemoryManager:
self._client = GMSRPCClient(
self.socket_path, lock_type=lock_type, timeout_ms=timeout_ms
)
self._sleeping = False
self._unmapped = False
# Update mode based on granted lock type (may differ from requested for rw_or_ro)
self._mode = self._client.lock_type
# Save state hash for stale detection on wake (skip during wake itself)
# Save state hash for stale detection on remap (skip during remap itself)
if update_memory_layout_hash and self._client.committed:
self._last_memory_layout_hash = self._client.get_memory_layout_hash()
......@@ -181,8 +181,8 @@ class GMSClientMemoryManager:
return self._client is not None and self._client.is_connected
@property
def is_sleeping(self) -> bool:
return self._sleeping
def is_unmapped(self) -> bool:
return self._unmapped
@property
def mappings(self) -> Dict[int, LocalMapping]:
......@@ -366,9 +366,9 @@ class GMSClientMemoryManager:
"""
if self._closed:
raise RuntimeError("Memory manager is closed")
if self._sleeping:
if self._unmapped:
raise RuntimeError(
"Cannot switch_to_read() while sleeping; call wake() first"
"Cannot switch_to_read() while unmapped; call remap() first"
)
if self._client is not None:
if self.lock_type == GrantedLockType.RO:
......@@ -380,25 +380,25 @@ class GMSClientMemoryManager:
eff_timeout = timeout_ms if timeout_ms is not None else self._timeout_ms
self._connect(lock_type=RequestedLockType.RO, timeout_ms=eff_timeout)
# ==================== Sleep / wake (read mode) ====================
# ==================== Unmap / remap (read mode) ====================
def sleep(self) -> None:
def unmap(self) -> None:
"""Release RO lock and unmap local allocations (VA-stable).
VAs are preserved during sleep so tensor pointers remain stable.
On wake, allocations are remapped to the same VAs.
VAs are preserved during unmap so tensor pointers remain stable.
On remap, allocations are remapped to the same VAs.
"""
if self._closed:
raise RuntimeError("Memory manager is closed")
if self._sleeping:
if self._unmapped:
return
if self.lock_type != GrantedLockType.RO:
raise RuntimeError("sleep() requires RO mode")
raise RuntimeError("unmap() requires RO mode")
if torch.cuda.is_available():
torch.cuda.synchronize(self.device)
# Preserve allocation IDs for remapping on wake
# Preserve allocation IDs so remap() can restore the same allocations
self._preserved_allocation_ids = list(self._allocation_id_to_va.keys())
# Unmap physical memory but keep VA reservations
......@@ -407,12 +407,12 @@ class GMSClientMemoryManager:
self._client_rpc.close()
self._client = None
self._sleeping = True
self._unmapped = True
def wake(self, timeout_ms: Optional[int] = None) -> bool:
def remap(self, timeout_ms: Optional[int] = None) -> bool:
"""Reacquire RO lock and remap preserved allocations (VA-stable).
Allocations are remapped to the same VAs they had before sleep,
Allocations are remapped to the same VAs they had before unmap,
ensuring tensor pointers remain valid.
Args:
......@@ -423,11 +423,11 @@ class GMSClientMemoryManager:
Raises:
TimeoutError: If timeout_ms expires waiting for RO lock.
StaleMemoryLayoutError: If weights were structurally changed while sleeping.
StaleMemoryLayoutError: If weights were structurally changed while unmapped.
"""
if self._closed:
raise RuntimeError("Memory manager is closed")
if not self._sleeping:
if not self._unmapped:
return True
if torch.cuda.is_available():
......@@ -440,14 +440,14 @@ class GMSClientMemoryManager:
update_memory_layout_hash=False,
)
# Check if memory layout changed while sleeping
# Check if memory layout changed while unmapped
current_hash = self._client_rpc.get_memory_layout_hash()
if (
self._last_memory_layout_hash
and current_hash != self._last_memory_layout_hash
):
raise StaleMemoryLayoutError(
f"State changed while sleeping: hash {self._last_memory_layout_hash[:16]}... -> {current_hash[:16]}..."
f"State changed while unmapped: hash {self._last_memory_layout_hash[:16]}... -> {current_hash[:16]}..."
)
# Remap to preserved VAs
......@@ -469,16 +469,16 @@ class GMSClientMemoryManager:
if failed_count > 0:
raise RuntimeError(
f"Wake failed: {failed_count} of {len(self._preserved_allocation_ids)} "
f"Remap failed: {failed_count} of {len(self._preserved_allocation_ids)} "
f"allocations could not be remapped"
)
logger.info(
f"[GPU Memory Service] Wake complete on device {self.device}: "
f"[GPU Memory Service] Remap complete on device {self.device}: "
f"remapped {remapped_count} allocations ({total_bytes / (1 << 30):.2f} GiB)"
)
self._sleeping = False
self._unmapped = False
self._va_preserved = False
return True
......@@ -499,7 +499,7 @@ class GMSClientMemoryManager:
self._client.close()
self._client = None
self._closed = True
self._sleeping = False
self._unmapped = False
self._va_preserved = False
self._preserved_allocation_ids.clear()
......@@ -515,8 +515,8 @@ class GMSClientMemoryManager:
def _client_rpc(self) -> GMSRPCClient:
"""Get connected client or raise. Use instead of _require_connected() + assert."""
if self._client is None:
if self._sleeping:
raise RuntimeError("Memory manager is sleeping")
if self._unmapped:
raise RuntimeError("Memory manager is unmapped")
raise RuntimeError("Memory manager is not connected")
return self._client
......@@ -530,10 +530,10 @@ class GMSClientMemoryManager:
self._allocation_id_to_va[m.allocation_id] = m.va
def _unmap_preserving_va(self) -> None:
"""Unmap physical memory but PRESERVE VA reservations for sleep/wake.
"""Unmap physical memory but PRESERVE VA reservations for unmap/remap.
This keeps the VA reservation intact so tensors maintain stable pointers.
On wake, we can remap to the same VAs.
On remap, we can remap to the same VAs.
"""
unmapped_count = 0
total_bytes = 0
......@@ -560,7 +560,7 @@ class GMSClientMemoryManager:
def _remap_preserved_va(self, allocation_id: str) -> int:
"""Remap an allocation to its preserved VA.
Requires the VA to already be reserved (from before sleep).
Requires the VA to already be reserved (from before unmap).
Validates allocation still exists and size matches.
Returns the VA.
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment