Unverified Commit ccd12d1c authored by Schwinn Saereesitthipitak's avatar Schwinn Saereesitthipitak Committed by GitHub
Browse files

feat: vLLM integrations for GPU Memory Service (#5615)

parent 89e135b9
......@@ -304,6 +304,10 @@ def setup_vllm_engine(config, stat_logger=None):
os.environ["VLLM_ALLOW_RUNTIME_LORA_UPDATING"] = "True"
if "VLLM_LORA_MODULES_LOADING_TIMEOUT" not in os.environ:
os.environ["VLLM_LORA_MODULES_LOADING_TIMEOUT"] = "600"
if engine_args.load_format == "gms":
engine_args.worker_cls = "gpu_memory_service.vllm_integration.worker.GMSWorker"
# Load default sampling params from `generation_config.json`
default_sampling_params = (
engine_args.create_model_config().get_diff_sampling_param()
......
......@@ -7,6 +7,8 @@ import argparse
import logging
from dataclasses import dataclass
from gpu_memory_service.common.utils import get_socket_path
logger = logging.getLogger(__name__)
......@@ -14,7 +16,6 @@ logger = logging.getLogger(__name__)
class Config:
"""Configuration for GPU Memory Service server."""
# GPU Memory Service specific
device: int
socket_path: str
verbose: bool
......@@ -26,7 +27,6 @@ def parse_args() -> Config:
description="GPU Memory Service allocation server."
)
# GPU Memory Service specific arguments
parser.add_argument(
"--device",
type=int,
......@@ -37,8 +37,7 @@ def parse_args() -> Config:
"--socket-path",
type=str,
default=None,
help="Path for Unix domain socket. Default: /tmp/gpu_memory_service_{device}.sock. "
"Supports {device} placeholder for multi-GPU setups.",
help="Path for Unix domain socket. Default uses GPU UUID for stability.",
)
parser.add_argument(
"--verbose",
......@@ -49,18 +48,11 @@ def parse_args() -> Config:
args = parser.parse_args()
# Generate default socket path if not provided
socket_path = args.socket_path
if socket_path is None:
socket_path = f"/tmp/gpu_memory_service_{args.device}.sock"
else:
# Expand {device} placeholder
socket_path = socket_path.format(device=args.device)
# Use UUID-based socket path by default (stable across CUDA_VISIBLE_DEVICES)
socket_path = args.socket_path or get_socket_path(args.device)
config = Config(
return Config(
device=args.device,
socket_path=socket_path,
verbose=args.verbose,
)
return config
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Shared utilities for GPU Memory Service."""
import pynvml
def get_socket_path(device: int) -> str:
    """Get GMS socket path for the given device.

    The path embeds the GPU UUID reported by NVML instead of the device
    index, so the same physical GPU always maps to the same socket file.

    Args:
        device: Device index handed to ``nvmlDeviceGetHandleByIndex``.
            NOTE(review): NVML enumerates every GPU on the host and does not
            apply CUDA_VISIBLE_DEVICES; callers passing a CUDA ordinal from a
            restricted or reordered device list should confirm the index
            refers to the GPU they expect.

    Returns:
        Socket path (e.g., "/tmp/gms_GPU-12345678-1234-1234-1234-123456789abc.sock").

    Raises:
        Any error raised by the NVML calls propagates to the caller; NVML is
        shut down again before it does.
    """
    pynvml.nvmlInit()
    try:
        handle = pynvml.nvmlDeviceGetHandleByIndex(device)
        uuid = pynvml.nvmlDeviceGetUUID(handle)
    finally:
        pynvml.nvmlShutdown()

    # Some pynvml releases return the UUID as bytes; formatting that directly
    # would embed "b'GPU-...'" in the file name.
    if isinstance(uuid, bytes):
        uuid = uuid.decode("utf-8")
    return f"/tmp/gms_{uuid}.sock"
......@@ -71,6 +71,7 @@ setup(
"gpu_memory_service.client",
"gpu_memory_service.client.torch",
"gpu_memory_service.client.torch.extensions",
"gpu_memory_service.vllm_integration",
],
package_dir={
"gpu_memory_service": ".",
......@@ -80,6 +81,7 @@ setup(
"gpu_memory_service.client": "client",
"gpu_memory_service.client.torch": "client/torch",
"gpu_memory_service.client.torch.extensions": "client/torch/extensions",
"gpu_memory_service.vllm_integration": "vllm_integration",
},
package_data={
"gpu_memory_service.client.torch.extensions": ["*.cpp"],
......
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""vLLM model loader for GPU Memory Service integration.
Provides a model loader that loads weights via GMS for cross-process sharing.
The loader uses RW_OR_RO mode: first process loads from disk (RW), subsequent
processes import from GMS metadata (RO).
"""
from __future__ import annotations
import logging
from dataclasses import replace
from typing import TYPE_CHECKING
import torch
from gpu_memory_service import get_or_create_gms_client_memory_manager
from gpu_memory_service.client.torch.module import (
materialize_module_from_gms,
register_module_tensors,
)
from gpu_memory_service.common.types import GrantedLockType, RequestedLockType
from gpu_memory_service.common.utils import get_socket_path
if TYPE_CHECKING:
from gpu_memory_service.client.memory_manager import GMSClientMemoryManager
logger = logging.getLogger(__name__)
# Size in bytes recorded by the most recent load (read or write mode);
# consumed by callers for memory accounting.
_last_imported_weights_bytes: int = 0


def get_imported_weights_bytes() -> int:
    """Report the weight byte count recorded by the last load_model call."""
    imported_bytes = _last_imported_weights_bytes
    return imported_bytes
def register_gms_loader(load_format: str = "gms") -> None:
    """Install the GMS-backed model loader in vLLM's loader registry.

    The vLLM imports happen inside this function, so vLLM is only imported
    when registration is actually requested.

    Args:
        load_format: Name under which the loader is registered.
    """
    from vllm.model_executor.model_loader import register_model_loader
    from vllm.model_executor.model_loader.base_loader import BaseModelLoader
    from vllm.model_executor.model_loader.default_loader import DefaultModelLoader

    @register_model_loader(load_format)
    class GMSModelLoader(BaseModelLoader):
        """vLLM model loader that routes weights through GPU Memory Service."""

        def __init__(self, load_config):
            super().__init__(load_config)
            # Disk access is delegated to vLLM's stock loader, configured
            # like ours except for load_format="auto".
            auto_config = replace(load_config, load_format="auto")
            self.default_loader = DefaultModelLoader(auto_config)

        def download_model(self, model_config) -> None:
            """Forward to the default loader."""
            self.default_loader.download_model(model_config)

        def load_weights(self, model: torch.nn.Module, model_config) -> None:
            """Forward to the default loader."""
            self.default_loader.load_weights(model, model_config)

        def load_model(self, vllm_config, model_config) -> torch.nn.Module:
            """Load via GMS, importing (RO) or publishing (RW) as granted."""
            device_index = torch.cuda.current_device()
            socket_path = get_socket_path(device_index)
            gms_client, pool = get_or_create_gms_client_memory_manager(
                socket_path,
                device_index,
                mode=RequestedLockType.RW_OR_RO,
                tag="weights",
            )

            # RO grant: weights are already in GMS, import them.
            if gms_client.mode == GrantedLockType.RO:
                return _load_read_mode(
                    gms_client, vllm_config, model_config, device_index
                )

            # Otherwise load from disk and publish.
            return _load_write_mode(
                gms_client,
                pool,
                vllm_config,
                model_config,
                self.default_loader,
                torch.device("cuda", device_index),
            )
# =============================================================================
# Helper functions
# =============================================================================
def _load_read_mode(
    gms_client: "GMSClientMemoryManager",
    vllm_config,
    model_config,
    device_index: int,
) -> torch.nn.Module:
    """Build the model from weights already held by GMS (RO mode).

    A meta-device model is created and then materialized from GMS. The
    client's total byte count is stored in the module-level counter. If any
    step raises, the client is closed and the exception is re-raised.

    Returns:
        The materialized model, switched to eval mode.
    """
    global _last_imported_weights_bytes

    try:
        model = _create_meta_model(vllm_config, model_config)
        materialize_module_from_gms(gms_client, model, device_index=device_index)

        imported_bytes = gms_client.total_bytes
        _last_imported_weights_bytes = imported_bytes
        imported_gib = imported_bytes / (1 << 30)
        logger.info("[GMS] Read mode: imported %.2f GiB", imported_gib)

        return model.eval()
    except Exception:
        gms_client.close()
        raise
def _load_write_mode(
gms_client: "GMSClientMemoryManager",
pool,
vllm_config,
model_config,
default_loader,
target_device: torch.device,
) -> torch.nn.Module:
"""Load model from disk and publish weights to GMS (RW mode).
Initializes model using GMS memory pool, loads weights from disk,
registers tensors with GMS, and commits for cross-process sharing.
Args:
gms_client: GMS client; cleared, populated, committed and switched to read.
pool: Memory pool passed to ``use_mem_pool`` for model allocation.
vllm_config: Forwarded to ``initialize_model``.
model_config: Supplies the dtype and is forwarded to the loader helpers.
default_loader: Loader whose ``load_weights`` reads the weights.
target_device: Device used for allocation and weight post-processing.
Returns:
The loaded model, switched to eval mode.
Raises:
RuntimeError: If ``gms_client.commit()`` returns a falsy value.
"""
global _last_imported_weights_bytes
# torch/vLLM helpers are imported lazily, inside the function.
from torch.cuda.memory import use_mem_pool
from vllm.model_executor.model_loader.utils import (
initialize_model,
process_weights_after_loading,
)
from vllm.utils.torch_utils import set_default_torch_dtype
# NOTE(review): presumably discards state left by an earlier writer before
# republishing - confirm against GMSClientMemoryManager.clear_all.
gms_client.clear_all()
# Allocate model tensors using GMS memory pool
with set_default_torch_dtype(model_config.dtype):
with use_mem_pool(pool, device=target_device):
with target_device:
model = initialize_model(
vllm_config=vllm_config, model_config=model_config
)
default_loader.load_weights(model, model_config)
process_weights_after_loading(model, model_config, target_device)
torch.cuda.empty_cache()
# Update GMS metadata store with model tensors
register_module_tensors(gms_client, model)
# Record the published size; get_imported_weights_bytes() exposes it.
_last_imported_weights_bytes = gms_client.total_bytes
# Ensure all writes to GPU memory are finished before we unmap
torch.cuda.synchronize()
# A falsy commit() result is treated as a hard failure.
if not gms_client.commit():
raise RuntimeError("Allocation Server commit failed")
gms_client.switch_to_read()
logger.info(
"[GMS] Write mode: published %.2f GiB (%d mappings)",
_last_imported_weights_bytes / (1 << 30),
# NOTE(review): reads the private `_mappings`; patches.py uses the public
# `mappings` attribute - consider using it here as well.
len(gms_client._mappings),
)
return model.eval()
def _create_meta_model(vllm_config, model_config) -> torch.nn.Module:
"""Create model on meta device for RO mode materialization.
Args:
vllm_config: Forwarded to ``initialize_model``.
model_config: Supplies the default dtype and is forwarded to the helpers.
Returns:
The model built under the ``meta`` device context. Weight
post-processing is attempted but its failures are only logged.
"""
from vllm.model_executor.model_loader.utils import (
initialize_model,
process_weights_after_loading,
)
from vllm.utils.torch_utils import set_default_torch_dtype
meta_device = torch.device("meta")
# Enable meta tensor workaround for torch.nonzero() etc.
# Best-effort: skipped silently when the private config module or the
# attribute is unavailable.
try:
import torch.fx.experimental._config as fx_config
fx_config.meta_nonzero_assume_all_nonzero = True
except (ImportError, AttributeError):
pass
with set_default_torch_dtype(model_config.dtype):
with meta_device:
model = initialize_model(vllm_config=vllm_config, model_config=model_config)
# Post-processing errors are deliberately non-fatal here: they are logged at
# debug level and the model is returned regardless.
try:
process_weights_after_loading(model, model_config, meta_device)
except Exception as e:
logger.debug("[GMS] Post-processing on meta tensors: %s", e)
return model
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Utility patches for GPU Memory Service vLLM integration.
This module contains non-Worker patches that are applied when the GMSWorker
module is imported:
- torch.cuda.empty_cache patch (prevents segfaults with VMM allocations)
- MemorySnapshot.measure patch (adjusts free memory for read mode)
"""
from __future__ import annotations
import logging
import torch
from gpu_memory_service import get_gms_client_memory_manager
from gpu_memory_service.common.types import GrantedLockType
logger = logging.getLogger(__name__)
_empty_cache_patched = False
_memory_snapshot_patched = False
def patch_empty_cache() -> None:
    """Swap torch.cuda.empty_cache for a VMM-aware wrapper.

    The wrapper prevents segfaults with VMM allocations. The patch is applied
    at most once per process; later calls return immediately.
    Must be called at module import time before any empty_cache calls.
    """
    global _empty_cache_patched
    if _empty_cache_patched:
        return

    native_empty_cache = torch.cuda.empty_cache

    def safe_empty_cache() -> None:
        """Call the real empty_cache only when no GMS mappings exist.

        When weights are allocated through our VMM-based pluggable allocator,
        calling torch.cuda.empty_cache() causes segfaults because the native
        caching allocator tries to release blocks that were allocated through
        VMM APIs.
        """
        manager = get_gms_client_memory_manager()
        has_vmm_mappings = manager is not None and len(manager.mappings) > 0
        if not has_vmm_mappings:
            native_empty_cache()

    torch.cuda.empty_cache = safe_empty_cache
    _empty_cache_patched = True
    logger.info("[GMS Patch] Patched torch.cuda.empty_cache")
def patch_memory_snapshot() -> None:
    """Wrap MemorySnapshot.measure so GMS-committed bytes count as free.

    Applied at most once per process. If vLLM's MemorySnapshot cannot be
    imported the function logs at debug level and leaves everything untouched.
    """
    global _memory_snapshot_patched
    if _memory_snapshot_patched:
        return

    try:
        from vllm.utils.mem_utils import MemorySnapshot
    except ImportError:
        logger.debug("[GMS Patch] MemorySnapshot not available")
        return

    unpatched_measure = MemorySnapshot.measure

    def patched_measure(self):
        """Run the original measurement, then add committed bytes in RO mode."""
        unpatched_measure(self)

        manager = get_gms_client_memory_manager()
        assert manager is not None, "GMS client is not initialized"

        committed_bytes = 0
        if manager.mode == GrantedLockType.RO:
            for alloc in manager.list_allocations():
                committed_bytes += alloc.get("aligned_size", 0)
        else:
            # NOTE: by design, we want to assume we have the whole GPU when
            # writing weights for the first time, so no adjustment is made.
            logger.info("[GMS] RW mode - skipping committed memory adjustment")

        free_before = self.free_memory
        self.free_memory += committed_bytes

        if committed_bytes > 0:
            gib = 1 << 30
            logger.info(
                "[GMS Patch] Adjusted free_memory: %.2f GiB + %.2f GiB = %.2f GiB",
                free_before / gib,
                committed_bytes / gib,
                self.free_memory / gib,
            )

    MemorySnapshot.measure = patched_measure
    _memory_snapshot_patched = True
    logger.info("[GMS Patch] Patched MemorySnapshot.measure")
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""GPU Memory Service Worker subclass for vLLM integration.
This module provides a custom Worker class that properly integrates with
GPU Memory Service for VA-stable weight sharing and unmap/remap functionality.
Usage:
Set --worker-cls=gpu_memory_service.vllm_integration.worker.GMSWorker
"""
from __future__ import annotations
import logging
from contextlib import nullcontext
from typing import List, Optional
import torch
from gpu_memory_service import (
get_gms_client_memory_manager,
get_or_create_gms_client_memory_manager,
)
from gpu_memory_service.common.types import RequestedLockType
from gpu_memory_service.common.utils import get_socket_path
logger = logging.getLogger(__name__)
# Trigger model loader registration and utility patches on import
from gpu_memory_service.vllm_integration.model_loader import ( # noqa: E402
register_gms_loader,
)
from gpu_memory_service.vllm_integration.patches import ( # noqa: E402
patch_empty_cache,
patch_memory_snapshot,
)
# Register model loader
register_gms_loader()
# Apply utility patches
patch_empty_cache()
patch_memory_snapshot()
logger.info(
"[GMS] Worker module loaded - model loader registered, utility patches applied"
)
# Import Worker after patches are applied
from vllm.v1.worker.gpu_worker import Worker # noqa: E402
class GMSWorker(Worker):
"""vLLM Worker subclass with GMS integration.
Overrides init_device, load_model, sleep, wake_up and
_maybe_get_memory_pool_context of the vLLM Worker.
"""
def init_device(self) -> None:
"""Initialize device with early GMS connection.
We set CUDA device and establish GMS connection BEFORE calling super()
so that MemorySnapshot.measure can query committed bytes.
"""
from vllm.platforms import current_platform
# Set CUDA device first (vLLM provides self.local_rank)
device = self.local_rank
current_platform.set_device(torch.device(f"cuda:{device}"))
# Establish GMS connection (so MemorySnapshot can query committed bytes)
socket_path = get_socket_path(device)
# The return value is discarded; presumably the client is cached so later
# get_gms_client_memory_manager() calls find it - TODO confirm.
get_or_create_gms_client_memory_manager(
socket_path, device, mode=RequestedLockType.RW_OR_RO, tag="weights"
)
# Parent will set device again (harmless) and do memory checks
super().init_device()
def load_model(self, *args, **kwargs) -> None:
"""Load model with corrected memory accounting.
After the parent loads the model, we correct the model_memory_usage
to reflect the actual bytes imported from GMS (not the delta measured
by vLLM's memory tracking).
All arguments are forwarded unchanged to the parent implementation.
"""
super().load_model(*args, **kwargs)
# Correct memory accounting for GMS-imported weights
# Best-effort: any exception raised below is logged at debug level and
# otherwise ignored, leaving vLLM's own figure in place.
try:
from gpu_memory_service.vllm_integration.model_loader import (
get_imported_weights_bytes,
)
imported_bytes = int(get_imported_weights_bytes())
# A zero count or a missing model_runner leaves the value untouched.
if (
imported_bytes > 0
and hasattr(self, "model_runner")
and self.model_runner is not None
):
old_usage = getattr(self.model_runner, "model_memory_usage", 0)
self.model_runner.model_memory_usage = imported_bytes
logger.info(
"[GMS] Corrected model_memory_usage: %.2f GiB -> %.2f GiB",
old_usage / (1 << 30),
imported_bytes / (1 << 30),
)
except Exception as e:
logger.debug("[GMS] Could not correct memory accounting: %s", e)
def sleep(self, level: int = 1) -> None:
"""
vLLM sleep implementation with GMS integration.
NOTE: `level` is a no-op here: weights are only unmapped (but remain in GPU memory).
NOTE: We do NOT call super().sleep() because it tries to copy GPU buffers to CPU,
which segfaults on already-unmapped GMS memory.
Args:
level: Accepted for interface compatibility; never read by this method.
"""
from vllm.device_allocator.cumem import CuMemAllocator
# Free-memory reading taken before anything is released, for the log below.
free_bytes_before = torch.cuda.mem_get_info()[0]
# Unmap GMS weights (VA-stable unmap, no CPU backup needed)
manager = get_gms_client_memory_manager()
assert manager is not None, "GMS client is not initialized"
assert not manager.is_unmapped, "GMS weights are already unmapped"
manager.unmap()
# Sleep KV cache via CuMemAllocator (discard, no CPU backup)
allocator = CuMemAllocator.get_instance()
allocator.sleep(offload_tags=tuple())
free_bytes_after, total = torch.cuda.mem_get_info()
freed_bytes = free_bytes_after - free_bytes_before
used_bytes = total - free_bytes_after
logger.info(
"Sleep freed %.2f GiB, %.2f GiB still in use.",
freed_bytes / (1 << 30),
used_bytes / (1 << 30),
)
def wake_up(self, tags: Optional[List[str]] = None) -> None:
"""vLLM wake implementation with GMS integration.
Args:
tags: Which allocations to restore; ``None`` means both "weights"
and "kv_cache".
"""
from vllm.device_allocator.cumem import CuMemAllocator
if tags is None:
tags = ["weights", "kv_cache"]
if "weights" in tags:
manager = get_gms_client_memory_manager()
assert manager is not None, "GMS client is not initialized"
assert manager.is_unmapped, "GMS weights are not unmapped"
manager.remap()
torch.cuda.synchronize()
if "kv_cache" in tags:
allocator = CuMemAllocator.get_instance()
allocator.wake_up(tags=["kv_cache"])
# Reinitialize FP8 KV scales if needed
# Only done when the cache dtype starts with "fp8" and the model runner
# provides init_fp8_kv_scales.
if self.cache_config.cache_dtype.startswith("fp8") and hasattr(
self.model_runner, "init_fp8_kv_scales"
):
self.model_runner.init_fp8_kv_scales()
def _maybe_get_memory_pool_context(self, tag: str):
"""Skip CuMemAllocator for weights when using GMS.
GMS manages its own memory pool for weights, so we don't want
vLLM's CuMemAllocator to interfere.
Every other tag is handled by the parent implementation.
"""
if tag == "weights":
logger.debug("[GMS] Skipping CuMemAllocator for weights")
return nullcontext()
return super()._maybe_get_memory_pool_context(tag)
......@@ -108,6 +108,10 @@ STUB_MODULES = [
"botocore",
"botocore.client",
"botocore.exceptions",
"pynvml",
"gpu_memory_service",
"gpu_memory_service.common",
"gpu_memory_service.common.utils",
]
# Project paths for local imports
......
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Pytest configuration for GPU Memory Service tests."""
import pytest
# Skip collection entirely if gpu_memory_service is not installed
try:
import gpu_memory_service # noqa: F401
except ImportError:
collect_ignore_glob = ["test_*.py"]
from tests.utils.port_utils import allocate_port, deallocate_ports
@pytest.fixture
def gms_ports():
    """Allocate the ports used by GMS tests and release them afterwards.

    Yields a dict keyed by role:
    - frontend: Frontend HTTP port
    - shadow_system: System port for shadow/primary engine
    - primary_system: System port for primary engine (failover test only)
    - shadow_kv_event: KV event port for shadow engine
    - primary_kv_event: KV event port for primary engine
    - shadow_nixl: NIXL side channel port for shadow engine
    - primary_nixl: NIXL side channel port for primary engine
    """
    # Role names and the numbers handed to allocate_port, in matching order.
    roles = (
        "frontend",
        "shadow_system",
        "primary_system",
        "shadow_kv_event",
        "primary_kv_event",
        "shadow_nixl",
        "primary_nixl",
    )
    requested = (8200, 8100, 8101, 20080, 20081, 20096, 20097)

    ports = [allocate_port(number) for number in requested]
    yield dict(zip(roles, ports))
    deallocate_ports(ports)
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""
GPU Memory Service Shadow Engine Failover Test for vLLM.
Tests the shadow engine failover scenario where a sleeping shadow engine can
wake up and take over when the primary engine fails.
"""
import logging
import pytest
from tests.utils.constants import FAULT_TOLERANCE_MODEL_NAME
from tests.utils.managed_process import DynamoFrontendProcess
from .utils.common import GMSServerProcess, get_gpu_memory_used, send_completion
from .utils.vllm import VLLMWithGMSProcess
logger = logging.getLogger(__name__)
@pytest.mark.vllm
@pytest.mark.e2e
@pytest.mark.gpu_1
@pytest.mark.fault_tolerance
@pytest.mark.nightly
@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME)
@pytest.mark.timeout(600)
def test_gms_shadow_engine_failover(
request, runtime_services, gms_ports, predownload_models
):
"""Test shadow engine failover with GPU Memory Service.
1. Start shadow engine and put it to sleep
2. Start primary engine and serve inference
3. Kill primary engine
4. Wake shadow engine and verify it handles inference
Both engines are started against the same frontend port; the primary is
stopped by leaving its context manager.
"""
ports = gms_ports
# One GMS server for GPU 0 backs both engines (the test is marked gpu_1).
with GMSServerProcess(request, device=0):
with DynamoFrontendProcess(request, frontend_port=ports["frontend"]):
# Start shadow engine
with VLLMWithGMSProcess(
request,
"shadow",
ports["shadow_system"],
ports["shadow_kv_event"],
ports["shadow_nixl"],
ports["frontend"],
) as shadow:
# Verify shadow works
result = send_completion(ports["frontend"])
logger.info(f"Shadow inference result: {result}")
assert result["choices"]
logger.info("Shadow inference OK")
# Sleep shadow
mem_before = get_gpu_memory_used()
sleep_result = shadow.sleep()
assert sleep_result["status"] == "ok"
mem_after_sleep = get_gpu_memory_used()
logger.info(
f"Shadow sleep freed {(mem_before - mem_after_sleep) / (1 << 20):.0f} MB"
)
# Sleeping must lower the device-wide used-memory reading.
assert mem_after_sleep < mem_before
# Start primary engine
with VLLMWithGMSProcess(
request,
"primary",
ports["primary_system"],
ports["primary_kv_event"],
ports["primary_nixl"],
ports["frontend"],
):
result = send_completion(ports["frontend"], "Primary test")
logger.info(f"Primary inference result: {result}")
assert result["choices"]
logger.info("Primary inference OK")
# Primary is dead (exited context manager)
# Wake shadow
wake_result = shadow.wake()
assert wake_result["status"] == "ok"
# Verify shadow handles failover
result = send_completion(ports["frontend"], "After failover")
logger.info(f"Failover inference result: {result}")
assert result["choices"]
logger.info("Shadow handles failover OK")
# A few more requests to confirm the woken shadow keeps serving.
for i in range(3):
result = send_completion(ports["frontend"], f"Verify {i}")
logger.info(f"Verification {i} result: {result}")
assert result["choices"]
logger.info("All verification passed")
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""
GPU Memory Service Basic Sleep/Wake Test for vLLM.
Tests the basic sleep/wake cycle of a single vLLM engine using the GPU Memory
Service for VA-stable weight offloading.
"""
import logging
import pytest
from tests.utils.constants import FAULT_TOLERANCE_MODEL_NAME
from tests.utils.managed_process import DynamoFrontendProcess
from .utils.common import GMSServerProcess, get_gpu_memory_used, send_completion
from .utils.vllm import VLLMWithGMSProcess
logger = logging.getLogger(__name__)
@pytest.mark.vllm
@pytest.mark.e2e
@pytest.mark.gpu_1
@pytest.mark.fault_tolerance
@pytest.mark.nightly
@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME)
@pytest.mark.timeout(300)
def test_gms_basic_sleep_wake(request, runtime_services, gms_ports, predownload_models):
"""Test basic sleep/wake with GPU Memory Service.
1. Start GMS server and vLLM engine with GMS integration
2. Run initial inference to verify engine works
3. Put engine to sleep and verify GPU memory is freed
4. Wake engine and verify inference still works
"""
ports = gms_ports
with GMSServerProcess(request, device=0):
with DynamoFrontendProcess(request, frontend_port=ports["frontend"]):
# The single engine reuses the "shadow_*" entries of the port fixture.
with VLLMWithGMSProcess(
request,
"engine",
ports["shadow_system"],
ports["shadow_kv_event"],
ports["shadow_nixl"],
ports["frontend"],
) as engine:
# Initial inference
result = send_completion(ports["frontend"])
logger.info(f"Initial inference result: {result}")
assert result["choices"]
mem_before = get_gpu_memory_used()
logger.info(f"Memory before sleep: {mem_before / (1 << 20):.0f} MB")
# Sleep
sleep_result = engine.sleep()
assert sleep_result["status"] == "ok"
mem_after_sleep = get_gpu_memory_used()
logger.info(f"Memory after sleep: {mem_after_sleep / (1 << 20):.0f} MB")
assert mem_after_sleep < mem_before, "Sleep should reduce memory"
# Wake
wake_result = engine.wake()
assert wake_result["status"] == "ok"
# Inference after wake
result = send_completion(ports["frontend"], "Goodbye")
logger.info(f"Post-wake inference result: {result}")
assert result["choices"]
logger.info(
f"Memory freed: {(mem_before - mem_after_sleep) / (1 << 20):.0f} MB"
)
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""
Shared utilities for GPU Memory Service tests.
This module provides process managers and helper functions that are
backend-agnostic and can be used by vLLM, SGLang, or other backends.
"""
import logging
import os
import shutil
import time
import pynvml
import requests
from gpu_memory_service.common.utils import get_socket_path
from tests.utils.constants import FAULT_TOLERANCE_MODEL_NAME
from tests.utils.managed_process import ManagedProcess
logger = logging.getLogger(__name__)
def get_gpu_memory_used(device: int = 0) -> int:
    """Return the ``used`` memory figure NVML reports for ``device``, in bytes.

    NVML is initialised for the query and shut down again afterwards, even
    when the query raises.
    """
    pynvml.nvmlInit()
    try:
        gpu = pynvml.nvmlDeviceGetHandleByIndex(device)
        memory = pynvml.nvmlDeviceGetMemoryInfo(gpu)
        used_bytes = memory.used
    finally:
        pynvml.nvmlShutdown()
    return used_bytes
def send_completion(
    port: int, prompt: str = "Hello", max_retries: int = 3, retry_delay: float = 1.0
) -> dict:
    """Send a completion request to the frontend.

    Includes retry logic to handle transient failures from stale routing
    (e.g., after failover when etcd still has dead instance entries).

    Args:
        port: The frontend HTTP port.
        prompt: The prompt to send.
        max_retries: Total number of attempts; must be at least 1.
        retry_delay: Delay between attempts in seconds.

    Returns:
        The decoded JSON response, guaranteed to have a non-empty "choices".

    Raises:
        ValueError: If ``max_retries`` is less than 1.
        requests.exceptions.RequestException: If the last attempt failed at
            the HTTP level.
        AssertionError: If the last attempt returned no choices.
    """
    # Without at least one attempt there would be no error to re-raise below.
    if max_retries < 1:
        raise ValueError(f"max_retries must be >= 1, got {max_retries}")

    last_error = None
    for attempt in range(1, max_retries + 1):
        try:
            r = requests.post(
                f"http://localhost:{port}/v1/completions",
                json={
                    "model": FAULT_TOLERANCE_MODEL_NAME,
                    "prompt": prompt,
                    "max_tokens": 20,
                },
                timeout=120,
            )
            r.raise_for_status()
            result = r.json()
            # Raised explicitly (not via `assert`) so the check, and the
            # retry it triggers, survive `python -O`.
            if not result.get("choices"):
                raise AssertionError("No choices in response")
        except (requests.exceptions.RequestException, AssertionError) as e:
            last_error = e
            # No point sleeping after the final attempt.
            if attempt < max_retries:
                logger.debug(
                    "send_completion attempt %d/%d failed: %s",
                    attempt,
                    max_retries,
                    e,
                )
                time.sleep(retry_delay)
            continue

        if attempt > 1:
            logger.info("send_completion succeeded after %d attempts", attempt)
        return result

    raise last_error  # type: ignore
class GMSServerProcess(ManagedProcess):
    """
    Manages GMS server lifecycle for tests. Starts server, waits for socket, cleans up on exit.
    Runs only for the specified GPU device.
    """

    def __init__(self, request, device: int):
        """Prepare the server command for ``device``.

        Args:
            request: pytest request object; its node name prefixes the log dir.
            device: GPU index passed to the server via ``--device``.
        """
        self.device = device
        self.socket_path = get_socket_path(device)
        # Drop any socket file left behind by an earlier run, so the
        # readiness check below is not satisfied by a stale file.
        self._remove_socket()

        log_dir = f"{request.node.name}_gms_{device}"
        shutil.rmtree(log_dir, ignore_errors=True)

        super().__init__(
            command=["python3", "-m", "gpu_memory_service", "--device", str(device)],
            env={**os.environ, "DYN_LOG": "debug"},
            timeout=60,
            display_output=True,
            terminate_existing=False,
            log_dir=log_dir,
            health_check_funcs=[self._socket_ready],
        )

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Stop the server via the parent, then remove the socket file."""
        try:
            return super().__exit__(exc_type, exc_val, exc_tb)
        finally:
            self._remove_socket()

    def _remove_socket(self) -> None:
        """Delete the socket file if present; a missing file is not an error."""
        # EAFP instead of exists()+unlink(): the file can disappear between
        # the check and the unlink (e.g. removed by the exiting server).
        try:
            os.unlink(self.socket_path)
        except FileNotFoundError:
            pass

    def _socket_ready(self, timeout: float = 30) -> bool:
        """Poll until the socket path exists or ``timeout`` seconds elapse."""
        # Monotonic clock: unaffected by wall-clock adjustments.
        deadline = time.monotonic() + timeout
        while time.monotonic() < deadline:
            if os.path.exists(self.socket_path):
                return True
            time.sleep(0.1)
        return False
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""vLLM-specific utilities for GPU Memory Service tests."""
import logging
import os
import shutil
import requests
from tests.utils.constants import FAULT_TOLERANCE_MODEL_NAME
from tests.utils.managed_process import ManagedProcess
from tests.utils.payloads import check_health_generate, check_models_api
logger = logging.getLogger(__name__)
class VLLMWithGMSProcess(ManagedProcess):
    """Test harness for a vLLM engine launched with ``--load-format gms``."""

    def __init__(
        self,
        request,
        engine_id: str,
        system_port: int,
        kv_event_port: int,
        nixl_port: int,
        frontend_port: int,
    ):
        """Assemble the command, environment and health checks for the engine.

        Args:
            request: pytest request object; its node name prefixes the log dir.
            engine_id: Label used in the log directory name and log messages.
            system_port: Exported as DYN_SYSTEM_PORT; also used for the
                health, sleep and wake endpoints.
            kv_event_port: Exported as DYN_VLLM_KV_EVENT_PORT.
            nixl_port: Exported as VLLM_NIXL_SIDE_CHANNEL_PORT.
            frontend_port: Frontend port probed by the health checks.
        """
        self.engine_id = engine_id
        self.system_port = system_port

        log_dir = f"{request.node.name}_{engine_id}"
        shutil.rmtree(log_dir, ignore_errors=True)

        command = [
            "python3",
            "-m",
            "dynamo.vllm",
            "--model",
            FAULT_TOLERANCE_MODEL_NAME,
            "--load-format",
            "gms",
            "--enable-sleep-mode",
            "--gpu-memory-utilization",
            "0.8",
        ]

        # Inherit the caller's environment, then layer the per-engine ports.
        env = dict(os.environ)
        env.update(
            {
                "DYN_LOG": "debug",
                "DYN_SYSTEM_PORT": str(system_port),
                "DYN_VLLM_KV_EVENT_PORT": str(kv_event_port),
                "VLLM_NIXL_SIDE_CHANNEL_PORT": str(nixl_port),
            }
        )

        frontend_url = f"http://localhost:{frontend_port}"
        health_checks = [
            (f"http://localhost:{system_port}/health", self._is_ready),
            (f"{frontend_url}/v1/models", check_models_api),
            (f"{frontend_url}/health", check_health_generate),
        ]

        super().__init__(
            command=command,
            env=env,
            health_check_urls=health_checks,
            timeout=300,
            display_output=True,
            terminate_existing=False,
            stragglers=[],
            log_dir=log_dir,
        )

    def _is_ready(self, response) -> bool:
        """Return True when the health response body reports status "ready"."""
        try:
            body = response.json()
        except ValueError:
            # Body was not valid JSON: treat as not ready.
            return False
        return body.get("status") == "ready"

    def sleep(self) -> dict:
        """POST to the engine's sleep endpoint (level 1) and return the JSON reply."""
        url = f"http://localhost:{self.system_port}/engine/sleep"
        response = requests.post(url, json={"level": 1}, timeout=30)
        response.raise_for_status()
        logger.info(f"{self.engine_id} sleep: {response.json()}")
        return response.json()

    def wake(self) -> dict:
        """POST to the engine's wake_up endpoint and return the JSON reply."""
        url = f"http://localhost:{self.system_port}/engine/wake_up"
        response = requests.post(url, json={}, timeout=30)
        response.raise_for_status()
        logger.info(f"{self.engine_id} wake: {response.json()}")
        return response.json()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment