"examples/vscode:/vscode.git/clone" did not exist on "43877a620bf629d3625c870ef787e590101e0518"
Unverified Commit a2fbda3e authored by Schwinn Saereesitthipitak's avatar Schwinn Saereesitthipitak Committed by GitHub
Browse files

feat(sglang): integration with GPU Memory Service (#5664)


Signed-off-by: default avatarSchwinn Saereesitthipitak <17022745+galletas1712@users.noreply.github.com>
parent 2c36a588
......@@ -367,6 +367,13 @@ async def parse_args(args: list[str]) -> Config:
bootstrap_port = _reserve_disaggregation_bootstrap_port()
ServerArgs.add_cli_args(parser)
# Add "gms" to --load-format choices so it passes argparse validation.
# The actual loader class is set in main.py when load_format == "gms".
for action in parser._actions:
if getattr(action, "dest", None) == "load_format" and action.choices:
action.choices = list(action.choices) + ["gms"]
break
# Handle config file if present
temp_config_file = None # Track temp file for cleanup
if "--config" in args:
......
......@@ -70,6 +70,12 @@ async def worker():
config = await parse_args(sys.argv[1:])
dump_config(config.dynamo_args.dump_config_to, config)
# Setup GPU Memory Service if --load-format gms is used
if config.server_args.load_format == "gms":
from gpu_memory_service.integrations.sglang import setup_gms
config.server_args.load_format = setup_gms(config.server_args)
loop = asyncio.get_running_loop()
# Set DYN_EVENT_PLANE environment variable based on config
......
......@@ -315,7 +315,7 @@ def setup_vllm_engine(config, stat_logger=None):
os.environ["VLLM_LORA_MODULES_LOADING_TIMEOUT"] = "600"
if engine_args.load_format == "gms":
engine_args.worker_cls = "gpu_memory_service.vllm_integration.worker.GMSWorker"
engine_args.worker_cls = "gpu_memory_service.integrations.vllm.worker.GMSWorker"
# Load default sampling params from `generation_config.json`
default_sampling_params = (
......
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Common utilities shared across GMS integrations."""
from gpu_memory_service.integrations.common.patches import patch_empty_cache
from gpu_memory_service.integrations.common.utils import (
finalize_gms_write,
setup_meta_tensor_workaround,
)
__all__ = [
"patch_empty_cache",
"setup_meta_tensor_workaround",
"finalize_gms_write",
]
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Common patches shared across GMS integrations."""
from __future__ import annotations
import logging
import torch
from gpu_memory_service import get_gms_client_memory_manager
logger = logging.getLogger(__name__)
_empty_cache_patched = False
def patch_empty_cache() -> None:
    """Replace ``torch.cuda.empty_cache`` with a GMS-aware wrapper.

    Weights allocated through the VMM-based pluggable allocator must not be
    released by the native caching allocator: calling
    ``torch.cuda.empty_cache()`` while such allocations exist causes segfaults.
    The installed wrapper therefore skips the call whenever the GMS client
    memory manager reports at least one active mapping, and otherwise forwards
    to the original function.

    The patch is applied at most once per process; repeated calls are no-ops.
    """
    global _empty_cache_patched

    if not _empty_cache_patched:
        # Capture the real implementation before it is replaced below.
        wrapped_empty_cache = torch.cuda.empty_cache

        def safe_empty_cache() -> None:
            manager = get_gms_client_memory_manager()
            active_mappings = 0 if manager is None else len(manager.mappings)
            if active_mappings > 0:
                logger.debug(
                    "[GMS] Skipping torch.cuda.empty_cache() - %d VMM allocations active",
                    active_mappings,
                )
            else:
                wrapped_empty_cache()

        torch.cuda.empty_cache = safe_empty_cache
        _empty_cache_patched = True
        logger.info("[GMS] Patched torch.cuda.empty_cache")
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Common utilities shared across GMS integrations."""
from __future__ import annotations
import logging
from typing import TYPE_CHECKING
import torch
if TYPE_CHECKING:
from gpu_memory_service.client.memory_manager import GMSClientMemoryManager
logger = logging.getLogger(__name__)
def setup_meta_tensor_workaround() -> None:
"""Enable workaround for meta tensor operations like torch.nonzero()."""
try:
import torch.fx.experimental._config as fx_config
fx_config.meta_nonzero_assume_all_nonzero = True
except (ImportError, AttributeError):
pass
def finalize_gms_write(
    allocator: "GMSClientMemoryManager", model: torch.nn.Module
) -> int:
    """Finalize GMS write mode: register tensors, commit, switch to read.

    This is typically called when the (writing) model loader finishes, and
    is ready to commit the weights so that other engines can import these
    weights and read them.

    Args:
        allocator: The GMS client memory manager in write mode.
        model: The loaded model with weights to register.

    Returns:
        Total bytes committed (``allocator.total_bytes`` sampled after tensor
        registration and before the commit).

    Raises:
        RuntimeError: If commit fails.
    """
    from gpu_memory_service.client.torch.module import register_module_tensors

    register_module_tensors(allocator, model)
    total_bytes = allocator.total_bytes

    # Wait for all writes to weights (from caller) to complete before mode switch
    torch.cuda.synchronize()

    if not allocator.commit():
        raise RuntimeError("GMS commit failed")
    allocator.switch_to_read()

    # Use the public ``mappings`` accessor (as the empty_cache patch does)
    # rather than reaching into the manager's private ``_mappings`` attribute.
    logger.info(
        "[GMS] Committed %.2f GiB, switched to read mode with %d mappings",
        total_bytes / (1 << 30),
        len(allocator.mappings),
    )
    return int(total_bytes)
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""GPU Memory Service integration for SGLang.
Usage:
from gpu_memory_service.integrations.sglang import setup_gms
if server_args.load_format == "gms":
server_args.load_format = setup_gms(server_args)
"""
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Type
if TYPE_CHECKING:
from gpu_memory_service.integrations.sglang.model_loader import GMSModelLoader
logger = logging.getLogger(__name__)
def setup_gms(server_args) -> Type["GMSModelLoader"]:
    """Validate SGLang server args for GMS and hand back the loader class.

    Importing ``GMSModelLoader`` applies the GMS patches as a module-level
    side effect, so the import is deferred until validation has passed.

    Args:
        server_args: SGLang ServerArgs instance.

    Returns:
        GMSModelLoader class to use as load_format.

    Raises:
        ValueError: If incompatible options are enabled.
    """
    # GMS provides its own VA-stable unmap/remap for weights, so SGLang's CPU
    # backup options are rejected. Checked in order; the first hit raises.
    incompatible_options = (
        (
            "enable_weights_cpu_backup",
            "Cannot use --enable-weights-cpu-backup with --load-format gms.",
        ),
        (
            "enable_draft_weights_cpu_backup",
            "Cannot use --enable-draft-weights-cpu-backup with --load-format gms.",
        ),
    )
    for attr_name, message in incompatible_options:
        if getattr(server_args, attr_name, False):
            raise ValueError(message)

    # Import triggers patches at module level
    from gpu_memory_service.integrations.sglang.model_loader import GMSModelLoader

    logger.info("[GMS] Using GMSModelLoader...")
    return GMSModelLoader
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Hybrid torch_memory_saver implementation for GPU Memory Service.
This module provides a hybrid implementation that combines:
1. GPU Memory Service allocator for "weights" tag (VA-stable unmap/remap, shared)
2. Torch mempool mode for other tags like "kv_cache" (CPU backup, per-instance)
The impl uses RW_OR_RO mode to connect to GMS:
- First process gets RW lock and loads weights from disk
- Subsequent processes get RO lock and import weights from metadata
"""
from __future__ import annotations
import logging
from contextlib import contextmanager
from typing import TYPE_CHECKING, Optional
import torch
if TYPE_CHECKING:
from gpu_memory_service.client.memory_manager import GMSClientMemoryManager
from torch.cuda.memory import MemPool
from torch_memory_saver.entrypoint import _TorchMemorySaverImpl
logger = logging.getLogger(__name__)
def get_gms_memory_saver_impl() -> Optional["GMSMemorySaverImpl"]:
    """Return the ``gms_impl`` of the torch_memory_saver singleton, if any.

    Returns None when ``torch_memory_saver`` cannot be imported or when the
    singleton (or its ``gms_impl`` attribute) is missing.
    """
    impl = None
    try:
        import torch_memory_saver as tms_package

        impl = tms_package.torch_memory_saver.gms_impl
    except (ImportError, AttributeError):
        impl = None
    return impl
class GMSMemorySaverImpl:
    """Memory saver backend that splits work between GMS and torch mempools.

    Every operation is dispatched on the allocation tag:
    - "weights" or "model_weights": served by the GMS allocator (VA-stable)
    - any other tag (e.g., "kv_cache"): forwarded to the wrapped torch
      mempool implementation
    """

    def __init__(
        self,
        torch_impl: "_TorchMemorySaverImpl",
        socket_path: str,
        device_index: int,
    ):
        """Store the collaborators and connect to GMS.

        Args:
            torch_impl: Implementation that handles all non-weights tags.
            socket_path: Socket path passed to the GMS client memory manager.
            device_index: CUDA device index used for the GMS connection and
                for the weights mempool region.
        """
        self._torch_impl = torch_impl
        self._socket_path = socket_path
        self._device_index = device_index
        self._disabled = False
        self._imported_weights_bytes: int = 0

        # Declared up front; all three are produced by one _init_allocator() call.
        self._allocator: Optional["GMSClientMemoryManager"]
        self._mem_pool: Optional["MemPool"]
        self._mode: str
        connection = self._init_allocator()
        self._allocator, self._mem_pool, self._mode = connection

        logger.info(
            "[GMS] Initialized: weights=%s mode (device=%d, socket=%s)",
            self._mode.upper(),
            device_index,
            socket_path,
        )

    def _init_allocator(
        self,
    ) -> tuple[Optional["GMSClientMemoryManager"], Optional["MemPool"], str]:
        """Connect to GMS requesting RW_OR_RO and report what was granted.

        Returns:
            ``(allocator, mem_pool, mode)`` where ``mode`` is "write" when the
            RW lock was granted and "read" otherwise. ``mem_pool`` is only
            returned for the RW case; it is None in read mode.
        """
        from gpu_memory_service import get_or_create_gms_client_memory_manager
        from gpu_memory_service.common.types import GrantedLockType, RequestedLockType

        allocator, mem_pool = get_or_create_gms_client_memory_manager(
            self._socket_path,
            self._device_index,
            mode=RequestedLockType.RW_OR_RO,
            tag="weights",
        )

        holds_write_lock = allocator.mode == GrantedLockType.RW
        if holds_write_lock:
            # Writer path: clear whatever the allocator currently holds first.
            allocator.clear_all()
        mode_name = "write" if holds_write_lock else "read"

        logger.info(
            "[GMS] Initialized in AUTO mode, granted=%s (device=%d)",
            mode_name.upper(),
            self._device_index,
        )
        pool_for_mode = mem_pool if holds_write_lock else None
        return allocator, pool_for_mode, mode_name

    def _is_weights_tag(self, tag: Optional[str]) -> bool:
        """Return True when ``tag`` names the GMS-managed weights region."""
        return tag in ("weights", "model_weights")

    def get_mode(self) -> str:
        """Return the current mode string ("write" or "read")."""
        return self._mode

    def get_allocator(self) -> Optional["GMSClientMemoryManager"]:
        """Return the GMS client memory manager held by this impl."""
        return self._allocator

    @contextmanager
    def region(self, tag: str, enable_cpu_backup: bool):
        """Context manager marking an allocation region for ``tag``.

        Non-weights tags are delegated to the torch implementation. For
        weights tags, read mode yields without changing allocation behavior,
        while write mode routes allocations through the GMS mempool.

        Raises:
            RuntimeError: If a weights region is opened in write mode but no
                mempool is available.
        """
        if self._is_weights_tag(tag):
            if self._mode == "read":
                yield
            elif self._mem_pool is None:
                raise RuntimeError("GMS mempool is None in WRITE mode")
            else:
                weights_device = torch.device("cuda", self._device_index)
                with torch.cuda.use_mem_pool(self._mem_pool, device=weights_device):
                    yield
        else:
            with self._torch_impl.region(tag=tag, enable_cpu_backup=enable_cpu_backup):
                yield

    def pause(self, tag: Optional[str] = None) -> None:
        """Pause memory for ``tag``; ``None`` pauses weights and all other tags."""
        if self._disabled:
            return
        applies_to_all = tag is None
        is_weights = self._is_weights_tag(tag)
        if applies_to_all or is_weights:
            self._pause_weights()
        if applies_to_all or not is_weights:
            self._torch_impl.pause(tag=tag)

    def resume(self, tag: Optional[str] = None) -> None:
        """Resume memory for ``tag``; ``None`` resumes weights and all other tags."""
        if self._disabled:
            return
        applies_to_all = tag is None
        is_weights = self._is_weights_tag(tag)
        if applies_to_all or is_weights:
            self._resume_weights()
        if applies_to_all or not is_weights:
            self._torch_impl.resume(tag=tag)

    def _pause_weights(self) -> None:
        """Unmap the weights unless there is no allocator or they are already unmapped."""
        allocator = self._allocator
        if allocator is None or allocator.is_unmapped:
            return
        logger.info("[GMS] Unmapping weights (VA-stable)")
        allocator.unmap()

    def _resume_weights(self) -> None:
        """Remap the weights if they are currently unmapped, then synchronize CUDA."""
        allocator = self._allocator
        if allocator is None or not allocator.is_unmapped:
            return
        logger.info("[GMS] Remapping weights (VA-stable)")
        allocator.remap()
        torch.cuda.synchronize()

    def finalize_write_mode(self, model: torch.nn.Module) -> None:
        """Commit the loaded weights and flip this impl to read mode.

        Does nothing unless the impl is in write mode. The committed byte
        count is recorded as the imported weights size.

        Raises:
            RuntimeError: If the impl is in write mode without an allocator.
        """
        if self._mode != "write":
            return
        if self._allocator is None:
            raise RuntimeError("Allocator is None in WRITE mode")

        from gpu_memory_service.integrations.common.utils import finalize_gms_write

        committed_bytes = finalize_gms_write(self._allocator, model)
        self._imported_weights_bytes = committed_bytes
        self._mode = "read"

    def set_imported_weights_bytes(self, bytes_count: int) -> None:
        """Record the number of weight bytes imported from GMS."""
        self._imported_weights_bytes = bytes_count

    def get_imported_weights_bytes(self) -> int:
        """Return the recorded number of imported weight bytes."""
        return self._imported_weights_bytes

    def disable(self) -> None:
        """Turn ``pause``/``resume`` into no-ops."""
        self._disabled = True

    def enable(self) -> None:
        """Re-enable ``pause``/``resume`` after ``disable``."""
        self._disabled = False
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""SGLang model loader for GPU Memory Service integration.
Provides a model loader that loads weights via GMS for cross-process sharing.
The loader uses RW_OR_RO mode: first process loads from disk (RW), subsequent
processes import from GMS metadata (RO).
Usage:
Set --load-format gms when launching SGLang.
"""
from __future__ import annotations
import logging
from dataclasses import replace
import torch
from gpu_memory_service.integrations.common import patch_empty_cache
from gpu_memory_service.integrations.common.utils import setup_meta_tensor_workaround
from gpu_memory_service.integrations.sglang.patches import (
patch_model_runner,
patch_torch_memory_saver,
)
logger = logging.getLogger(__name__)
# Apply patches at module import time.
# This module is only imported when load_format="gms" is used.
patch_empty_cache()
patch_torch_memory_saver()
patch_model_runner()
logger.info("[GMS] Applied patches")
class GMSModelLoader:
    """Model loader for SGLang that sources weights through GPU Memory Service.

    Depending on the mode reported by the GMS memory saver impl, the model is
    either loaded from disk and published to GMS ("write") or materialized
    from weights already held by GMS ("read").
    """

    def __init__(self, load_config):
        """Keep the load config; the fallback disk loader is created lazily."""
        self._default_loader = None
        self.load_config = load_config

    def _get_default_loader(self):
        """Return the cached SGLang DefaultModelLoader, building it on first use."""
        loader = self._default_loader
        if loader is None:
            from sglang.srt.model_loader.loader import DefaultModelLoader

            # The wrapped loader gets a copy of the config with load_format="auto".
            disk_config = replace(self.load_config, load_format="auto")
            loader = DefaultModelLoader(disk_config)
            self._default_loader = loader
        return loader

    def load_model(
        self,
        *,
        model_config,
        device_config,
    ) -> torch.nn.Module:
        """Load the model from disk or import it from GMS, based on the impl mode.

        Raises:
            RuntimeError: If no GMS memory saver impl is installed.
        """
        from gpu_memory_service.integrations.sglang.memory_saver import (
            get_gms_memory_saver_impl,
        )

        impl = get_gms_memory_saver_impl()
        if impl is None:
            raise RuntimeError(
                "GMS impl not initialized. "
                "Ensure torch_memory_saver patch was applied before model loading."
            )

        mode = impl.get_mode()
        logger.info("[GMS] Loading model in %s mode", mode.upper())

        # "read" imports existing weights; any other mode takes the write path.
        load_for_mode = (
            self._load_import_only if mode == "read" else self._load_write_mode
        )
        return load_for_mode(model_config, device_config, impl)

    def _load_write_mode(self, model_config, device_config, impl) -> torch.nn.Module:
        """Load weights with the default loader, then finalize them via ``impl`` (WRITE mode)."""
        model = self._get_default_loader().load_model(
            model_config=model_config,
            device_config=device_config,
        )
        impl.finalize_write_mode(model)
        return model

    def _load_import_only(self, model_config, device_config, impl) -> torch.nn.Module:
        """Build a meta-device model and materialize its weights from GMS (READ mode).

        Raises:
            RuntimeError: If the impl has no allocator.
        """
        from gpu_memory_service.client.torch.module import materialize_module_from_gms

        allocator = impl.get_allocator()
        if allocator is None:
            raise RuntimeError("GMS allocator is None in READ mode")

        # Sample the current device before the meta model is constructed.
        device_index = torch.cuda.current_device()
        model = self._create_meta_model(model_config, device_config)
        materialize_module_from_gms(allocator, model, device_index=device_index)

        imported_bytes = allocator.total_bytes
        impl.set_imported_weights_bytes(imported_bytes)
        logger.info(
            "[GMS] READ mode: imported %.2f GiB from metadata",
            imported_bytes / (1 << 30),
        )
        return model.eval()

    def _create_meta_model(self, model_config, device_config) -> torch.nn.Module:
        """Instantiate the model on the meta device for import-only mode.

        The model is built with ``load_format="dummy"`` inside a meta-device
        context, after which the previously current CUDA device is set again.
        SGLang's weight post-processing is then attempted; any exception it
        raises is logged at debug level and otherwise ignored.
        """
        from sglang.srt.model_loader import get_model

        setup_meta_tensor_workaround()

        device_before = torch.cuda.current_device()
        dummy_config = replace(self.load_config, load_format="dummy")
        with torch.device("meta"):
            model = get_model(
                model_config=model_config,
                load_config=dummy_config,
                device_config=device_config,
            )
        torch.cuda.set_device(device_before)

        try:
            from sglang.srt.model_loader.utils import (
                process_model_weights_after_loading,
            )

            process_model_weights_after_loading(model, model_config)
        except Exception as e:
            logger.debug("[GMS] Post-processing on meta tensors: %s", e)

        return model
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""SGLang-specific patches for GPU Memory Service integration.
- patch_torch_memory_saver: Routes to GMS hybrid implementation
- patch_model_runner: Fixes memory accounting with pre-loaded weights
"""
from __future__ import annotations
import logging
from typing import Optional
import torch
from gpu_memory_service.common.utils import get_socket_path
logger = logging.getLogger(__name__)
_torch_memory_saver_patched = False
_model_runner_patched = False
def patch_torch_memory_saver() -> None:
    """Patch torch_memory_saver to use GPU Memory Service implementation.

    Two changes are made to ``torch_memory_saver.entrypoint.TorchMemorySaver``:

    1. ``_ensure_initialized`` is replaced so that, when ``hook_mode`` is unset
       or "gms", a ``GMSMemorySaverImpl`` (wrapping a torch-mode impl for
       non-weights tags) is installed as ``_impl``. Any other ``hook_mode``
       falls through to the original method.
    2. A ``gms_impl`` property is added that returns ``_impl`` when it is a
       ``GMSMemorySaverImpl`` and None otherwise.

    This function is idempotent - calling it multiple times has no effect.
    If ``torch_memory_saver`` is not installed the patch is skipped.
    This patch is only applied when GMSModelLoader is imported (load_format="gms").
    """
    global _torch_memory_saver_patched

    if _torch_memory_saver_patched:
        return

    try:
        import torch_memory_saver.entrypoint as entrypoint_module
    except ImportError:
        logger.debug("[GMS] torch_memory_saver not installed, skipping patch")
        return

    # Store reference to original method
    original_ensure_initialized = entrypoint_module.TorchMemorySaver._ensure_initialized

    def patched_ensure_initialized(self):
        """Patched _ensure_initialized that uses GPU Memory Service implementation."""
        # Check if already initialized
        if self._impl is not None:
            logger.debug("[GMS] TorchMemorySaver already initialized, skipping")
            return

        # Check hook_mode - use GMS for None or explicit "gms"
        hook_mode = self._impl_ctor_kwargs.get("hook_mode")
        # Lazy %-style args, consistent with the other log calls in this package.
        logger.info("[GMS] TorchMemorySaver initializing with hook_mode=%s", hook_mode)

        if hook_mode is None or hook_mode == "gms":
            # Use our GPU Memory Service implementation
            from gpu_memory_service.integrations.sglang.memory_saver import (
                GMSMemorySaverImpl,
            )
            from torch_memory_saver.entrypoint import _TorchMemorySaverImpl

            # Get device from torch.cuda.current_device() (already set by SGLang)
            device_index = torch.cuda.current_device()

            # Resolve socket path from env or default
            socket_path = get_socket_path(device_index)

            # Create underlying torch impl for non-weights tags (KV cache etc.)
            torch_impl = _TorchMemorySaverImpl(hook_mode="torch")

            # Create GPU Memory Service impl
            gms_impl = GMSMemorySaverImpl(
                torch_impl=torch_impl,
                socket_path=socket_path,
                device_index=device_index,
            )

            # Set _impl directly (accessible via gms_impl property)
            self._impl = gms_impl
            logger.info(
                "[GMS] Using GMS mode (device=%d, socket=%s, mode=%s)",
                device_index,
                socket_path,
                gms_impl.get_mode(),
            )
            # The constructor kwargs are dropped once the impl has been created.
            del self._impl_ctor_kwargs
        else:
            # Fall back to original implementation
            logger.info("[GMS] Using default torch_memory_saver hook mode")
            original_ensure_initialized(self)

    entrypoint_module.TorchMemorySaver._ensure_initialized = patched_ensure_initialized

    # Add property to access GMS impl directly from the singleton
    from gpu_memory_service.integrations.sglang.memory_saver import GMSMemorySaverImpl

    @property
    def gms_impl(self) -> Optional[GMSMemorySaverImpl]:
        """Get the GMS impl if installed, None otherwise."""
        if isinstance(self._impl, GMSMemorySaverImpl):
            return self._impl
        return None

    entrypoint_module.TorchMemorySaver.gms_impl = gms_impl

    _torch_memory_saver_patched = True
    logger.debug("[GMS] Patched torch_memory_saver")
def patch_model_runner() -> None:
    """Wrap SGLang's ``ModelRunner.init_memory_pool`` to correct memory accounting.

    When weights are pre-loaded via GMS (import-only mode), SGLang's
    min_per_gpu_memory captured before loading is lower than device total,
    which causes under-reservation of overhead memory in the KV cache
    calculation. The wrapper overwrites ``min_per_gpu_memory`` with the
    device's total memory before delegating to the original method, but only
    when the GMS impl reports a positive imported-weights byte count.

    Applied at most once; skipped when ModelRunner cannot be imported.
    """
    global _model_runner_patched

    if _model_runner_patched:
        return

    try:
        from sglang.srt.model_executor.model_runner import ModelRunner
    except ImportError:
        logger.warning("[GMS] Could not import ModelRunner, skipping patch")
        return

    # A marker on the class itself also guards against double wrapping.
    if hasattr(ModelRunner, "_gms_patched"):
        return

    original_init_memory_pool = ModelRunner.init_memory_pool

    def patched_init_memory_pool(self, *args, **kwargs):
        """Patched init_memory_pool that uses device total for overhead calculation."""
        from gpu_memory_service.integrations.sglang.memory_saver import (
            get_gms_memory_saver_impl,
        )

        impl = get_gms_memory_saver_impl()
        weights_imported = impl is not None and impl.get_imported_weights_bytes() > 0
        if weights_imported:
            current_device = torch.cuda.current_device()
            device_total = torch.cuda.get_device_properties(current_device).total_memory
            if hasattr(self, "min_per_gpu_memory"):
                previous_value = self.min_per_gpu_memory
                self.min_per_gpu_memory = device_total
                logger.info(
                    "[GMS] Adjusted min_per_gpu_memory: %.2f GiB -> %.2f GiB",
                    previous_value / (1 << 30),
                    device_total / (1 << 30),
                )

        return original_init_memory_pool(self, *args, **kwargs)

    ModelRunner.init_memory_pool = patched_init_memory_pool
    ModelRunner._gms_patched = True
    _model_runner_patched = True
    logger.info("[GMS] Patched ModelRunner.init_memory_pool")
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""GPU Memory Service integration for vLLM.
Usage:
Set --load-format gms --worker-cls gpu_memory_service.integrations.vllm.worker.GMSWorker
"""
......@@ -16,12 +16,13 @@ from typing import TYPE_CHECKING
import torch
from gpu_memory_service import get_or_create_gms_client_memory_manager
from gpu_memory_service.client.torch.module import (
materialize_module_from_gms,
register_module_tensors,
)
from gpu_memory_service.client.torch.module import materialize_module_from_gms
from gpu_memory_service.common.types import GrantedLockType, RequestedLockType
from gpu_memory_service.common.utils import get_socket_path
from gpu_memory_service.integrations.common.utils import (
finalize_gms_write,
setup_meta_tensor_workaround,
)
if TYPE_CHECKING:
from gpu_memory_service.client.memory_manager import GMSClientMemoryManager
......@@ -146,22 +147,11 @@ def _load_write_mode(
process_weights_after_loading(model, model_config, target_device)
torch.cuda.empty_cache()
# Update GMS metadata store with model tensors
register_module_tensors(gms_client, model)
_last_imported_weights_bytes = gms_client.total_bytes
# Ensure all writes to GPU memory are finished before we unmap
torch.cuda.synchronize()
if not gms_client.commit():
raise RuntimeError("Allocation Server commit failed")
gms_client.switch_to_read()
_last_imported_weights_bytes = finalize_gms_write(gms_client, model)
logger.info(
"[GMS] Write mode: published %.2f GiB (%d mappings)",
"[GMS] Write mode: published %.2f GiB",
_last_imported_weights_bytes / (1 << 30),
len(gms_client._mappings),
)
return model.eval()
......@@ -174,16 +164,9 @@ def _create_meta_model(vllm_config, model_config) -> torch.nn.Module:
)
from vllm.utils.torch_utils import set_default_torch_dtype
setup_meta_tensor_workaround()
meta_device = torch.device("meta")
# Enable meta tensor workaround for torch.nonzero() etc.
try:
import torch.fx.experimental._config as fx_config
fx_config.meta_nonzero_assume_all_nonzero = True
except (ImportError, AttributeError):
pass
with set_default_torch_dtype(model_config.dtype):
with meta_device:
model = initialize_model(vllm_config=vllm_config, model_config=model_config)
......
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Utility patches for GPU Memory Service vLLM integration.
"""vLLM-specific patches for GPU Memory Service integration.
This module contains non-Worker patches that are applied when the GMSWorker
This module contains vLLM-specific patches that are applied when the GMSWorker
module is imported:
- torch.cuda.empty_cache patch (prevents segfaults with VMM allocations)
- MemorySnapshot.measure patch (adjusts free memory for read mode)
Note: The torch.cuda.empty_cache patch is in integrations/common/patches.py
"""
from __future__ import annotations
import logging
import torch
from gpu_memory_service import get_gms_client_memory_manager
from gpu_memory_service.common.types import GrantedLockType
logger = logging.getLogger(__name__)
_empty_cache_patched = False
_memory_snapshot_patched = False
def patch_empty_cache() -> None:
    """Install a guarded ``torch.cuda.empty_cache`` to avoid VMM segfaults.

    Must be called at module import time before any empty_cache calls.
    Calling it again after the patch is in place does nothing.
    """
    global _empty_cache_patched

    if not _empty_cache_patched:
        # Keep a handle on the real implementation before replacing it.
        wrapped_empty_cache = torch.cuda.empty_cache

        def safe_empty_cache() -> None:
            """Forward to the real empty_cache only when no VMM allocations exist.

            When weights are allocated through our VMM-based pluggable
            allocator, calling torch.cuda.empty_cache() causes segfaults
            because the native caching allocator tries to release blocks that
            were allocated through VMM APIs.
            """
            manager = get_gms_client_memory_manager()
            has_vmm_allocations = manager is not None and len(manager.mappings) > 0
            if not has_vmm_allocations:
                wrapped_empty_cache()

        torch.cuda.empty_cache = safe_empty_cache
        _empty_cache_patched = True
        logger.info("[GMS Patch] Patched torch.cuda.empty_cache")
def patch_memory_snapshot() -> None:
"""Patch MemorySnapshot.measure to add committed bytes to free_memory."""
global _memory_snapshot_patched
......
......@@ -7,7 +7,7 @@ This module provides a custom Worker class that properly integrates with
GPU Memory Service for VA-stable weight sharing and unmap/remap functionality.
Usage:
Set --worker-cls=gpu_memory_service.vllm_integration.worker:GMSWorker
Set --worker-cls=gpu_memory_service.integrations.vllm.worker.GMSWorker
"""
from __future__ import annotations
......@@ -23,23 +23,14 @@ from gpu_memory_service import (
)
from gpu_memory_service.common.types import RequestedLockType
from gpu_memory_service.common.utils import get_socket_path
from gpu_memory_service.integrations.common import patch_empty_cache
from gpu_memory_service.integrations.vllm.model_loader import register_gms_loader
from gpu_memory_service.integrations.vllm.patches import patch_memory_snapshot
logger = logging.getLogger(__name__)
# Trigger model loader registration and utility patches on import
from gpu_memory_service.vllm_integration.model_loader import ( # noqa: E402
register_gms_loader,
)
from gpu_memory_service.vllm_integration.patches import ( # noqa: E402
patch_empty_cache,
patch_memory_snapshot,
)
# Register model loader
register_gms_loader()
# Apply utility patches
patch_empty_cache()
patch_memory_snapshot()
......@@ -86,7 +77,7 @@ class GMSWorker(Worker):
# Correct memory accounting for GMS-imported weights
try:
from gpu_memory_service.vllm_integration.model_loader import (
from gpu_memory_service.integrations.vllm.model_loader import (
get_imported_weights_bytes,
)
......
......@@ -72,7 +72,10 @@ setup(
"gpu_memory_service.client",
"gpu_memory_service.client.torch",
"gpu_memory_service.client.torch.extensions",
"gpu_memory_service.vllm_integration",
"gpu_memory_service.integrations",
"gpu_memory_service.integrations.common",
"gpu_memory_service.integrations.sglang",
"gpu_memory_service.integrations.vllm",
],
package_dir={
"gpu_memory_service": ".",
......@@ -83,7 +86,10 @@ setup(
"gpu_memory_service.client": "client",
"gpu_memory_service.client.torch": "client/torch",
"gpu_memory_service.client.torch.extensions": "client/torch/extensions",
"gpu_memory_service.vllm_integration": "vllm_integration",
"gpu_memory_service.integrations": "integrations",
"gpu_memory_service.integrations.common": "integrations/common",
"gpu_memory_service.integrations.sglang": "integrations/sglang",
"gpu_memory_service.integrations.vllm": "integrations/vllm",
},
package_data={
"gpu_memory_service.client.torch.extensions": ["*.cpp"],
......
......@@ -22,12 +22,17 @@ def gms_ports():
- frontend: Frontend HTTP port
- shadow_system: System port for shadow/primary engine
- primary_system: System port for primary engine (failover test only)
- shadow_kv_event: KV event port for shadow engine
- primary_kv_event: KV event port for primary engine
- shadow_nixl: NIXL side channel port for shadow engine
- primary_nixl: NIXL side channel port for primary engine
- shadow_kv_event: KV event port for shadow engine (vLLM)
- primary_kv_event: KV event port for primary engine (vLLM)
- shadow_nixl: NIXL side channel port for shadow engine (vLLM)
- primary_nixl: NIXL side channel port for primary engine (vLLM)
- shadow_sglang: SGLang HTTP port for shadow engine
- primary_sglang: SGLang HTTP port for primary engine
"""
ports = [allocate_port(p) for p in [8200, 8100, 8101, 20080, 20081, 20096, 20097]]
ports = [
allocate_port(p)
for p in [8200, 8100, 8101, 20080, 20081, 20096, 20097, 30000, 30001]
]
yield {
"frontend": ports[0],
"shadow_system": ports[1],
......@@ -36,5 +41,7 @@ def gms_ports():
"primary_kv_event": ports[4],
"shadow_nixl": ports[5],
"primary_nixl": ports[6],
"shadow_sglang": ports[7],
"primary_sglang": ports[8],
}
deallocate_ports(ports)
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""
GPU Memory Service Shadow Engine Failover Test for SGLang.
Tests the shadow engine failover scenario where a sleeping shadow engine can
wake up and take over when the primary engine fails.
"""
import logging
import pytest
from tests.utils.constants import FAULT_TOLERANCE_MODEL_NAME
from tests.utils.managed_process import DynamoFrontendProcess
from .utils.common import GMSServerProcess, get_gpu_memory_used, send_completion
from .utils.sglang import SGLangWithGMSProcess
logger = logging.getLogger(__name__)
@pytest.mark.sglang
@pytest.mark.e2e
@pytest.mark.gpu_1
@pytest.mark.fault_tolerance
@pytest.mark.nightly
@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME)
@pytest.mark.timeout(600)
def test_gms_shadow_engine_failover(
    request, runtime_services, gms_ports, predownload_models
):
    """Exercise shadow-engine failover on top of GPU Memory Service.

    Flow:
    1. Bring up the shadow engine, check it serves, then put it to sleep
    2. Bring up the primary engine and serve inference through it
    3. Let the primary engine exit (leaving its context manager stops it)
    4. Wake the shadow engine and check that it serves inference again
    """
    frontend_port = gms_ports["frontend"]

    with GMSServerProcess(request, device=0), DynamoFrontendProcess(
        request, frontend_port=frontend_port
    ):
        # Shadow engine comes up first
        with SGLangWithGMSProcess(
            request,
            "shadow",
            gms_ports["shadow_system"],
            gms_ports["shadow_sglang"],
            frontend_port,
        ) as shadow:
            # The shadow must serve a request before it is put to sleep
            shadow_response = send_completion(frontend_port)
            logger.info(f"Shadow inference result: {shadow_response}")
            assert shadow_response["choices"]
            logger.info("Shadow inference OK")

            # Sleep the shadow (release memory occupation) and measure the effect
            mem_before = get_gpu_memory_used()
            sleep_response = shadow.sleep()
            assert sleep_response["status"] == "ok"
            mem_after_sleep = get_gpu_memory_used()
            freed_mb = (mem_before - mem_after_sleep) / (1 << 20)
            logger.info(f"Shadow sleep freed {freed_mb:.0f} MB")
            assert mem_after_sleep < mem_before

            # Primary engine lives only for the duration of this block
            with SGLangWithGMSProcess(
                request,
                "primary",
                gms_ports["primary_system"],
                gms_ports["primary_sglang"],
                frontend_port,
            ):
                primary_response = send_completion(frontend_port, "Primary test")
                logger.info(f"Primary inference result: {primary_response}")
                assert primary_response["choices"]
                logger.info("Primary inference OK")

            # Primary is gone now (its context manager has exited).
            # Wake the shadow (resume memory occupation).
            wake_response = shadow.wake()
            assert wake_response["status"] == "ok"

            # The shadow must now take over serving
            failover_response = send_completion(frontend_port, "After failover")
            logger.info(f"Failover inference result: {failover_response}")
            assert failover_response["choices"]
            logger.info("Shadow handles failover OK")

            for attempt in range(3):
                check_response = send_completion(frontend_port, f"Verify {attempt}")
                logger.info(f"Verification {attempt} result: {check_response}")
                assert check_response["choices"]

            logger.info("All verification passed")
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""
GPU Memory Service Basic Sleep/Wake Test for SGLang.
Tests the basic sleep/wake cycle of a single SGLang engine using the GPU Memory
Service for VA-stable weight offloading.
"""
import logging
import pytest
from tests.utils.constants import FAULT_TOLERANCE_MODEL_NAME
from tests.utils.managed_process import DynamoFrontendProcess
from .utils.common import GMSServerProcess, get_gpu_memory_used, send_completion
from .utils.sglang import SGLangWithGMSProcess
logger = logging.getLogger(__name__)
@pytest.mark.sglang
@pytest.mark.e2e
@pytest.mark.gpu_1
@pytest.mark.fault_tolerance
@pytest.mark.nightly
@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME)
@pytest.mark.timeout(300)
def test_gms_basic_sleep_wake(request, runtime_services, gms_ports, predownload_models):
    """Exercise one sleep/wake round trip of a GMS-backed SGLang engine.

    Sequence:
      1. Bring up the GMS server, the Dynamo frontend and a single engine.
      2. Send a completion through the frontend to confirm the engine serves.
      3. Call ``engine.sleep()`` and require that the value reported by
         ``get_gpu_memory_used()`` drops.
      4. Call ``engine.wake()`` and confirm a second completion succeeds.

    ``runtime_services`` and ``predownload_models`` are requested only as
    fixtures; the test body never reads them.
    """
    # One multi-item ``with`` enters the managers left to right and exits them
    # in reverse, exactly as the equivalent nested ``with`` statements would.
    with GMSServerProcess(request, device=0), DynamoFrontendProcess(
        request, frontend_port=gms_ports["frontend"]
    ), SGLangWithGMSProcess(
        request,
        "engine",
        gms_ports["shadow_system"],
        gms_ports["shadow_sglang"],
        gms_ports["frontend"],
    ) as engine:
        # Baseline request: the engine must answer before it is put to sleep.
        result = send_completion(gms_ports["frontend"])
        logger.info(f"Initial inference result: {result}")
        assert result["choices"]

        # Snapshot GPU usage while the engine still holds its memory.
        mem_before = get_gpu_memory_used()
        logger.info(f"Memory before sleep: {mem_before / (1 << 20):.0f} MB")

        # Sleep: the engine is asked to release its memory occupation.
        sleep_reply = engine.sleep()
        assert sleep_reply["status"] == "ok"

        mem_after_sleep = get_gpu_memory_used()
        logger.info(f"Memory after sleep: {mem_after_sleep / (1 << 20):.0f} MB")
        assert mem_after_sleep < mem_before, "Sleep should reduce memory"

        # Wake: the engine is asked to resume its memory occupation.
        wake_reply = engine.wake()
        assert wake_reply["status"] == "ok"

        # The engine must be able to serve again once it is awake.
        result = send_completion(gms_ports["frontend"], "Goodbye")
        logger.info(f"Post-wake inference result: {result}")
        assert result["choices"]

        logger.info(
            f"Memory freed: {(mem_before - mem_after_sleep) / (1 << 20):.0f} MB"
        )
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""SGLang-specific utilities for GPU Memory Service tests."""
import logging
import os
import shutil
import requests
from tests.utils.constants import FAULT_TOLERANCE_MODEL_NAME
from tests.utils.managed_process import ManagedProcess
from tests.utils.payloads import check_health_generate, check_models_api
logger = logging.getLogger(__name__)
class SGLangWithGMSProcess(ManagedProcess):
    """SGLang engine with GPU Memory Service integration.

    Launches ``python3 -m dynamo.sglang`` with ``--load-format gms`` and
    ``--enable-memory-saver``, and exposes ``sleep()`` / ``wake()`` helpers
    that POST to the worker's system port.
    """

    def __init__(
        self,
        request,
        engine_id: str,
        system_port: int,
        sglang_port: int,
        frontend_port: int,
    ):
        """Configure the worker process.

        Args:
            request: pytest ``request`` object; ``request.node.name`` is used
                to build the log directory name.
            engine_id: Label for this engine; used in the log directory name
                and as a prefix in log messages.
            system_port: Port exported as ``DYN_SYSTEM_PORT``; also the port
                used for the ``/health`` and ``/engine/...`` requests.
            sglang_port: Value passed to the worker's ``--port`` flag.
            frontend_port: Port of the frontend whose ``/v1/models`` and
                ``/health`` endpoints are included in the health checks.
        """
        self.engine_id = engine_id
        self.system_port = system_port

        # Start each run with an empty per-test, per-engine log directory.
        log_dir = f"{request.node.name}_{engine_id}"
        shutil.rmtree(log_dir, ignore_errors=True)

        super().__init__(
            command=[
                "python3",
                "-m",
                "dynamo.sglang",
                "--model-path",
                FAULT_TOLERANCE_MODEL_NAME,
                "--load-format",
                "gms",
                "--enable-memory-saver",
                "--mem-fraction-static",
                "0.8",
                "--port",
                str(sglang_port),
            ],
            env={
                **os.environ,
                "DYN_LOG": "debug",
                "DYN_SYSTEM_PORT": str(system_port),
            },
            # Each entry pairs a URL with the callable used to judge its
            # response (presumably all must pass before the process counts as
            # started -- confirm against ManagedProcess).
            health_check_urls=[
                (f"http://localhost:{system_port}/health", self._is_ready),
                (f"http://localhost:{frontend_port}/v1/models", check_models_api),
                (f"http://localhost:{frontend_port}/health", check_health_generate),
            ],
            timeout=300,
            display_output=True,
            terminate_existing=False,
            stragglers=[],
            log_dir=log_dir,
        )

    def _is_ready(self, response) -> bool:
        """Return True when the health response is JSON with ``status == "ready"``.

        Bodies that are not valid JSON, or that decode to something other than
        a JSON object, are reported as not ready instead of raising.
        """
        try:
            body = response.json()
        except ValueError:
            return False
        # A JSON list/string/number has no ``.get``; treat it as "not ready"
        # rather than letting AttributeError escape the health check.
        return isinstance(body, dict) and body.get("status") == "ready"

    def _post_engine_route(self, route: str) -> dict:
        """POST an empty JSON body to ``/engine/<route>`` on the system port.

        Args:
            route: Final path segment of the engine endpoint.

        Returns:
            The decoded JSON response body.

        Raises:
            requests.HTTPError: If the response status is 4xx/5xx.
        """
        r = requests.post(
            f"http://localhost:{self.system_port}/engine/{route}",
            json={},
            timeout=30,
        )
        r.raise_for_status()
        # Decode once and reuse the value for both the log line and the return.
        payload = r.json()
        logger.info(f"{self.engine_id} {route}: {payload}")
        return payload

    def sleep(self) -> dict:
        """Put the engine to sleep by requesting ``release_memory_occupation``.

        Returns:
            The decoded JSON response from the worker.
        """
        return self._post_engine_route("release_memory_occupation")

    def wake(self) -> dict:
        """Wake the engine by requesting ``resume_memory_occupation``.

        Returns:
            The decoded JSON response from the worker.
        """
        return self._post_engine_route("resume_memory_occupation")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment