"vscode:/vscode.git/clone" did not exist on "a98406d46cab99f03001fea896cef76201d7cff7"
Unverified Commit 69fdc9dd authored by Schwinn Saereesitthipitak's avatar Schwinn Saereesitthipitak Committed by GitHub
Browse files

fix: CUDA synchronization to prevent sleep/wake race conditions (#5759)


Signed-off-by: default avatarSchwinn Saereesitthipitak <schwinns@nvidia.com>
parent eef5e645
...@@ -109,3 +109,24 @@ def release_handle(handle: int) -> None: ...@@ -109,3 +109,24 @@ def release_handle(handle: int) -> None:
""" """
(result,) = cuda.cuMemRelease(handle) (result,) = cuda.cuMemRelease(handle)
check_cuda_result(result, "cuMemRelease") check_cuda_result(result, "cuMemRelease")
def synchronize() -> None:
"""Synchronize the current CUDA context.
Blocks until all preceding commands in the current context have completed.
"""
(result,) = cuda.cuCtxSynchronize()
check_cuda_result(result, "cuCtxSynchronize")
def set_current_device(device: int) -> None:
"""Set the current CUDA device by activating its primary context.
Args:
device: CUDA device index.
"""
result, ctx = cuda.cuDevicePrimaryCtxRetain(device)
check_cuda_result(result, "cuDevicePrimaryCtxRetain")
(result,) = cuda.cuCtxSetCurrent(ctx)
check_cuda_result(result, "cuCtxSetCurrent")
...@@ -25,7 +25,6 @@ import logging ...@@ -25,7 +25,6 @@ import logging
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, List, Optional from typing import Dict, List, Optional
import torch
from cuda.bindings import driver as cuda from cuda.bindings import driver as cuda
from gpu_memory_service.client.cuda_vmm_utils import ( from gpu_memory_service.client.cuda_vmm_utils import (
free_va, free_va,
...@@ -34,6 +33,8 @@ from gpu_memory_service.client.cuda_vmm_utils import ( ...@@ -34,6 +33,8 @@ from gpu_memory_service.client.cuda_vmm_utils import (
release_handle, release_handle,
reserve_va, reserve_va,
set_access, set_access,
set_current_device,
synchronize,
unmap, unmap,
) )
from gpu_memory_service.client.rpc import GMSRPCClient from gpu_memory_service.client.rpc import GMSRPCClient
...@@ -138,9 +139,8 @@ class GMSClientMemoryManager: ...@@ -138,9 +139,8 @@ class GMSClientMemoryManager:
"" # Hash from server, saved on connect/commit "" # Hash from server, saved on connect/commit
) )
# Ensure torch is on the right device for subsequent CUDA operations. # Set the current CUDA device for subsequent operations.
if torch.cuda.is_available(): set_current_device(self.device)
torch.cuda.set_device(self.device)
# Cache granularity for VA alignment # Cache granularity for VA alignment
self.granularity = get_allocation_granularity(device) self.granularity = get_allocation_granularity(device)
...@@ -342,8 +342,7 @@ class GMSClientMemoryManager: ...@@ -342,8 +342,7 @@ class GMSClientMemoryManager:
""" """
self._require_rw() self._require_rw()
if torch.cuda.is_available(): synchronize()
torch.cuda.synchronize(self.device)
# After publishing, prevent further writes locally. # After publishing, prevent further writes locally.
for va, m in list(self._mappings.items()): for va, m in list(self._mappings.items()):
...@@ -395,8 +394,7 @@ class GMSClientMemoryManager: ...@@ -395,8 +394,7 @@ class GMSClientMemoryManager:
if self.lock_type != GrantedLockType.RO: if self.lock_type != GrantedLockType.RO:
raise RuntimeError("unmap() requires RO mode") raise RuntimeError("unmap() requires RO mode")
if torch.cuda.is_available(): synchronize()
torch.cuda.synchronize(self.device)
# Preserve allocation IDs for remapping on remap # Preserve allocation IDs for remapping on remap
self._preserved_allocation_ids = list(self._allocation_id_to_va.keys()) self._preserved_allocation_ids = list(self._allocation_id_to_va.keys())
...@@ -405,6 +403,11 @@ class GMSClientMemoryManager: ...@@ -405,6 +403,11 @@ class GMSClientMemoryManager:
self._unmap_preserving_va() self._unmap_preserving_va()
self._va_preserved = True self._va_preserved = True
# Ensure all CUDA VMM unmap operations complete before releasing the lock.
# This prevents race conditions where remap() may be called before
# physical memory is fully released.
synchronize()
self._client_rpc.close() self._client_rpc.close()
self._client = None self._client = None
self._unmapped = True self._unmapped = True
...@@ -430,8 +433,7 @@ class GMSClientMemoryManager: ...@@ -430,8 +433,7 @@ class GMSClientMemoryManager:
if not self._unmapped: if not self._unmapped:
return True return True
if torch.cuda.is_available(): set_current_device(self.device)
torch.cuda.set_device(self.device)
eff_timeout = timeout_ms if timeout_ms is not None else self._timeout_ms eff_timeout = timeout_ms if timeout_ms is not None else self._timeout_ms
self._connect( self._connect(
...@@ -489,8 +491,7 @@ class GMSClientMemoryManager: ...@@ -489,8 +491,7 @@ class GMSClientMemoryManager:
return return
# Ensure kernels are done before tearing down mappings. # Ensure kernels are done before tearing down mappings.
if torch.cuda.is_available(): synchronize()
torch.cuda.synchronize(self.device)
# Release all mappings including preserved VA reservations # Release all mappings including preserved VA reservations
self._unmap_all() self._unmap_all()
...@@ -566,8 +567,7 @@ class GMSClientMemoryManager: ...@@ -566,8 +567,7 @@ class GMSClientMemoryManager:
Returns the VA. Returns the VA.
Raises StaleMemoryLayoutError if allocation is missing or size changed. Raises StaleMemoryLayoutError if allocation is missing or size changed.
""" """
if torch.cuda.is_available(): set_current_device(self.device)
torch.cuda.set_device(self.device)
va = self._allocation_id_to_va.get(allocation_id) va = self._allocation_id_to_va.get(allocation_id)
if va is None: if va is None:
...@@ -608,7 +608,7 @@ class GMSClientMemoryManager: ...@@ -608,7 +608,7 @@ class GMSClientMemoryManager:
set_access(va, mapping.aligned_size, self.device, current_access) set_access(va, mapping.aligned_size, self.device, current_access)
# Synchronize to ensure mapping is complete before any access # Synchronize to ensure mapping is complete before any access
cuda.cuCtxSynchronize() synchronize()
# Validate the pointer is accessible (this is what Triton checks) # Validate the pointer is accessible (this is what Triton checks)
result, _dev_ptr = cuda.cuPointerGetAttribute( result, _dev_ptr = cuda.cuPointerGetAttribute(
......
...@@ -136,6 +136,8 @@ class GMSMemorySaverImpl: ...@@ -136,6 +136,8 @@ class GMSMemorySaverImpl:
self._pause_weights() self._pause_weights()
if tag is None or not self._is_weights_tag(tag): if tag is None or not self._is_weights_tag(tag):
self._torch_impl.pause(tag=tag) self._torch_impl.pause(tag=tag)
# Ensure KV cache unmap operations complete before returning.
torch.cuda.synchronize()
def resume(self, tag: Optional[str] = None) -> None: def resume(self, tag: Optional[str] = None) -> None:
if self._disabled: if self._disabled:
...@@ -144,6 +146,9 @@ class GMSMemorySaverImpl: ...@@ -144,6 +146,9 @@ class GMSMemorySaverImpl:
self._resume_weights() self._resume_weights()
if tag is None or not self._is_weights_tag(tag): if tag is None or not self._is_weights_tag(tag):
self._torch_impl.resume(tag=tag) self._torch_impl.resume(tag=tag)
# Ensure KV cache mappings are complete before returning.
# Without this sync, inference may start before mappings are ready.
torch.cuda.synchronize()
def _pause_weights(self) -> None: def _pause_weights(self) -> None:
if self._allocator is None: if self._allocator is None:
...@@ -152,6 +157,10 @@ class GMSMemorySaverImpl: ...@@ -152,6 +157,10 @@ class GMSMemorySaverImpl:
return return
logger.info("[GMS] Unmapping weights (VA-stable)") logger.info("[GMS] Unmapping weights (VA-stable)")
self._allocator.unmap() self._allocator.unmap()
# Ensure all CUDA VMM unmap operations complete before returning.
# Without this sync, resume() may race with pending unmaps, causing OOM
# when it tries to allocate new memory while old memory is still mapped.
torch.cuda.synchronize()
def _resume_weights(self) -> None: def _resume_weights(self) -> None:
if self._allocator is None: if self._allocator is None:
......
...@@ -119,6 +119,11 @@ class GMSWorker(Worker): ...@@ -119,6 +119,11 @@ class GMSWorker(Worker):
allocator = CuMemAllocator.get_instance() allocator = CuMemAllocator.get_instance()
allocator.sleep(offload_tags=tuple()) allocator.sleep(offload_tags=tuple())
# Ensure all CUDA VMM unmap operations complete before returning.
# Without this sync, wake_up() may race with pending unmaps, causing OOM
# when it tries to allocate new memory while old memory is still mapped.
torch.cuda.synchronize()
free_bytes_after, total = torch.cuda.mem_get_info() free_bytes_after, total = torch.cuda.mem_get_info()
freed_bytes = free_bytes_after - free_bytes_before freed_bytes = free_bytes_after - free_bytes_before
used_bytes = total - free_bytes_after used_bytes = total - free_bytes_after
...@@ -146,6 +151,10 @@ class GMSWorker(Worker): ...@@ -146,6 +151,10 @@ class GMSWorker(Worker):
allocator = CuMemAllocator.get_instance() allocator = CuMemAllocator.get_instance()
allocator.wake_up(tags=["kv_cache"]) allocator.wake_up(tags=["kv_cache"])
# Ensure KV cache mappings are complete before returning.
# Without this sync, inference may start before mappings are ready.
torch.cuda.synchronize()
# Reinitialize FP8 KV scales if needed # Reinitialize FP8 KV scales if needed
if self.cache_config.cache_dtype.startswith("fp8") and hasattr( if self.cache_config.cache_dtype.startswith("fp8") and hasattr(
self.model_runner, "init_fp8_kv_scales" self.model_runner, "init_fp8_kv_scales"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment