"...git@developer.sourcefind.cn:2222/OpenDAS/vllm_cscc.git" did not exist on "067e77f9a87c3466fce41c8fe8710fddc69ec26c"
Unverified commit 69fdc9dd, authored by Schwinn Saereesitthipitak, committed by GitHub
Browse files

fix: CUDA synchronization to prevent sleep/wake race conditions (#5759)


Signed-off-by: Schwinn Saereesitthipitak <schwinns@nvidia.com>
parent eef5e645
......@@ -109,3 +109,24 @@ def release_handle(handle: int) -> None:
"""
(result,) = cuda.cuMemRelease(handle)
check_cuda_result(result, "cuMemRelease")
def synchronize() -> None:
    """Wait for the current CUDA context to finish its outstanding work.

    Issues ``cuCtxSynchronize`` through the driver bindings, which blocks
    until all preceding commands in the current context have completed,
    then passes the returned status to ``check_cuda_result`` along with
    the API name for error reporting.
    """
    # The driver binding returns a one-element tuple holding the status code.
    (status,) = cuda.cuCtxSynchronize()
    check_cuda_result(status, "cuCtxSynchronize")
# Primary contexts already retained by this module, keyed by device index.
# cuDevicePrimaryCtxRetain increments a per-device reference count on every
# call, so the handle is retained once and reused rather than re-retained
# each time the device is selected.
_RETAINED_PRIMARY_CONTEXTS = {}


def set_current_device(device: int) -> None:
    """Set the current CUDA device by activating its primary context.

    The device's primary context is retained the first time the device is
    selected and the resulting handle is cached; later calls for the same
    device only rebind the cached context to the calling thread. This keeps
    the primary-context reference count from growing on every call (the
    function is invoked repeatedly on remap paths).

    Args:
        device: CUDA device index.

    Note:
        Both driver statuses are passed to ``check_cuda_result``; the context
        is cached only after the retain status has been checked.
    """
    ctx = _RETAINED_PRIMARY_CONTEXTS.get(device)
    if ctx is None:
        result, ctx = cuda.cuDevicePrimaryCtxRetain(device)
        check_cuda_result(result, "cuDevicePrimaryCtxRetain")
        _RETAINED_PRIMARY_CONTEXTS[device] = ctx

    # Bind the (possibly cached) primary context to the calling thread.
    (result,) = cuda.cuCtxSetCurrent(ctx)
    check_cuda_result(result, "cuCtxSetCurrent")
......@@ -25,7 +25,6 @@ import logging
from dataclasses import dataclass
from typing import Dict, List, Optional
import torch
from cuda.bindings import driver as cuda
from gpu_memory_service.client.cuda_vmm_utils import (
free_va,
......@@ -34,6 +33,8 @@ from gpu_memory_service.client.cuda_vmm_utils import (
release_handle,
reserve_va,
set_access,
set_current_device,
synchronize,
unmap,
)
from gpu_memory_service.client.rpc import GMSRPCClient
......@@ -138,9 +139,8 @@ class GMSClientMemoryManager:
"" # Hash from server, saved on connect/commit
)
# Ensure torch is on the right device for subsequent CUDA operations.
if torch.cuda.is_available():
torch.cuda.set_device(self.device)
# Set the current CUDA device for subsequent operations.
set_current_device(self.device)
# Cache granularity for VA alignment
self.granularity = get_allocation_granularity(device)
......@@ -342,8 +342,7 @@ class GMSClientMemoryManager:
"""
self._require_rw()
if torch.cuda.is_available():
torch.cuda.synchronize(self.device)
synchronize()
# After publishing, prevent further writes locally.
for va, m in list(self._mappings.items()):
......@@ -395,8 +394,7 @@ class GMSClientMemoryManager:
if self.lock_type != GrantedLockType.RO:
raise RuntimeError("unmap() requires RO mode")
if torch.cuda.is_available():
torch.cuda.synchronize(self.device)
synchronize()
# Preserve allocation IDs for remapping on remap
self._preserved_allocation_ids = list(self._allocation_id_to_va.keys())
......@@ -405,6 +403,11 @@ class GMSClientMemoryManager:
self._unmap_preserving_va()
self._va_preserved = True
# Ensure all CUDA VMM unmap operations complete before releasing the lock.
# This prevents race conditions where remap() may be called before
# physical memory is fully released.
synchronize()
self._client_rpc.close()
self._client = None
self._unmapped = True
......@@ -430,8 +433,7 @@ class GMSClientMemoryManager:
if not self._unmapped:
return True
if torch.cuda.is_available():
torch.cuda.set_device(self.device)
set_current_device(self.device)
eff_timeout = timeout_ms if timeout_ms is not None else self._timeout_ms
self._connect(
......@@ -489,8 +491,7 @@ class GMSClientMemoryManager:
return
# Ensure kernels are done before tearing down mappings.
if torch.cuda.is_available():
torch.cuda.synchronize(self.device)
synchronize()
# Release all mappings including preserved VA reservations
self._unmap_all()
......@@ -566,8 +567,7 @@ class GMSClientMemoryManager:
Returns the VA.
Raises StaleMemoryLayoutError if allocation is missing or size changed.
"""
if torch.cuda.is_available():
torch.cuda.set_device(self.device)
set_current_device(self.device)
va = self._allocation_id_to_va.get(allocation_id)
if va is None:
......@@ -608,7 +608,7 @@ class GMSClientMemoryManager:
set_access(va, mapping.aligned_size, self.device, current_access)
# Synchronize to ensure mapping is complete before any access
cuda.cuCtxSynchronize()
synchronize()
# Validate the pointer is accessible (this is what Triton checks)
result, _dev_ptr = cuda.cuPointerGetAttribute(
......
......@@ -136,6 +136,8 @@ class GMSMemorySaverImpl:
self._pause_weights()
if tag is None or not self._is_weights_tag(tag):
self._torch_impl.pause(tag=tag)
# Ensure KV cache unmap operations complete before returning.
torch.cuda.synchronize()
def resume(self, tag: Optional[str] = None) -> None:
if self._disabled:
......@@ -144,6 +146,9 @@ class GMSMemorySaverImpl:
self._resume_weights()
if tag is None or not self._is_weights_tag(tag):
self._torch_impl.resume(tag=tag)
# Ensure KV cache mappings are complete before returning.
# Without this sync, inference may start before mappings are ready.
torch.cuda.synchronize()
def _pause_weights(self) -> None:
if self._allocator is None:
......@@ -152,6 +157,10 @@ class GMSMemorySaverImpl:
return
logger.info("[GMS] Unmapping weights (VA-stable)")
self._allocator.unmap()
# Ensure all CUDA VMM unmap operations complete before returning.
# Without this sync, resume() may race with pending unmaps, causing OOM
# when it tries to allocate new memory while old memory is still mapped.
torch.cuda.synchronize()
def _resume_weights(self) -> None:
if self._allocator is None:
......
......@@ -119,6 +119,11 @@ class GMSWorker(Worker):
allocator = CuMemAllocator.get_instance()
allocator.sleep(offload_tags=tuple())
# Ensure all CUDA VMM unmap operations complete before returning.
# Without this sync, wake_up() may race with pending unmaps, causing OOM
# when it tries to allocate new memory while old memory is still mapped.
torch.cuda.synchronize()
free_bytes_after, total = torch.cuda.mem_get_info()
freed_bytes = free_bytes_after - free_bytes_before
used_bytes = total - free_bytes_after
......@@ -146,6 +151,10 @@ class GMSWorker(Worker):
allocator = CuMemAllocator.get_instance()
allocator.wake_up(tags=["kv_cache"])
# Ensure KV cache mappings are complete before returning.
# Without this sync, inference may start before mappings are ready.
torch.cuda.synchronize()
# Reinitialize FP8 KV scales if needed
if self.cache_config.cache_dtype.startswith("fp8") and hasattr(
self.model_runner, "init_fp8_kv_scales"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment