Unverified Commit b78ec99a authored by Schwinn Saereesitthipitak's avatar Schwinn Saereesitthipitak Committed by GitHub
Browse files

test: reorganize and prune GPU Memory Service tests (#7828)


Signed-off-by: default avatarSchwinn Saereesitthipitak <schwinns@nvidia.com>
parent 11bb8498
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Torch integration coverage for GMS-backed tensors and modules.
This module exercises tensor remap after unmap/remap cycles and module
materialization from committed GMS-backed weights.
"""
from __future__ import annotations
import asyncio
import os
import threading
import time
from typing import cast
import pytest
......@@ -13,18 +23,21 @@ from gpu_memory_service.client.torch.module import (
)
from gpu_memory_service.client.torch.tensor import _tensor_from_pointer
from gpu_memory_service.common.locks import RequestedLockType
from tests.gms.harness.gms import GMSServerProcess
from gpu_memory_service.server.rpc import GMSRPCServer
torch = pytest.importorskip("torch", reason="torch is required")
pytestmark = [
pytest.mark.pre_merge,
pytest.mark.unit,
pytest.mark.integration,
pytest.mark.none,
pytest.mark.gpu_1,
]
_SERVER_START_TIMEOUT_SECONDS = 5.0
_SERVER_STOP_TIMEOUT_SECONDS = 5.0
_POLL_INTERVAL_SECONDS = 0.01
if not torch.cuda.is_available():
pytest.skip(
......@@ -50,9 +63,70 @@ class _TinyModule(torch.nn.Module):
@pytest.fixture
def running_gms(request):
with GMSServerProcess(request, device=0, tag="weights") as server:
yield server.socket_path
def running_gms(tmp_path):
socket_path = str(tmp_path / "gms.sock")
server = GMSRPCServer(socket_path, device=0)
loop: asyncio.AbstractEventLoop | None = None
task: asyncio.Task[None] | None = None
thread_error: BaseException | None = None
def run() -> None:
nonlocal loop, task, thread_error
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
task = loop.create_task(server.serve())
try:
loop.run_until_complete(task)
except asyncio.CancelledError:
pass
except BaseException as exc:
thread_error = exc
finally:
pending = [
pending_task
for pending_task in asyncio.all_tasks(loop)
if not pending_task.done()
]
for pending_task in pending:
pending_task.cancel()
if pending:
loop.run_until_complete(
asyncio.gather(*pending, return_exceptions=True)
)
loop.close()
thread = threading.Thread(target=run, daemon=True)
thread.start()
deadline = time.monotonic() + _SERVER_START_TIMEOUT_SECONDS
while True:
if thread_error is not None:
raise thread_error
if server._server is not None and os.path.exists(socket_path):
break
if time.monotonic() > deadline:
raise TimeoutError(f"GMS socket did not appear at {socket_path}")
time.sleep(_POLL_INTERVAL_SECONDS)
try:
yield socket_path
finally:
if loop is not None:
def cancel() -> None:
if server._server is not None:
server._server.close()
if task is not None:
task.cancel()
loop.call_soon_threadsafe(cancel)
thread.join(timeout=_SERVER_STOP_TIMEOUT_SECONDS)
if thread.is_alive():
raise RuntimeError(f"GMS server thread failed to stop for {socket_path}")
if thread_error is not None:
raise thread_error
if os.path.exists(socket_path):
os.unlink(socket_path)
def _make_gms_tensor(
......@@ -86,8 +160,10 @@ def test_gms_tensor_matches_plain_torch_ops(running_gms):
writer = GMSClientMemoryManager(socket_path, device=0)
writer.connect(RequestedLockType.RW)
allocation_id, _writer_tensor = _make_gms_tensor(writer, baseline, tag="weights")
allocation_id, writer_tensor = _make_gms_tensor(writer, baseline, tag="weights")
assert writer.commit()
del writer_tensor
writer.close()
reader = GMSClientMemoryManager(socket_path, device=0)
reader.connect(RequestedLockType.RO)
......@@ -122,8 +198,10 @@ def test_live_gms_tensor_survives_unmap_and_remap(running_gms):
writer = GMSClientMemoryManager(socket_path, device=0)
writer.connect(RequestedLockType.RW)
allocation_id, _ = _make_gms_tensor(writer, baseline, tag="weights")
allocation_id, writer_tensor = _make_gms_tensor(writer, baseline, tag="weights")
del writer_tensor
assert writer.commit()
writer.close()
reader = GMSClientMemoryManager(socket_path, device=0)
reader.connect(RequestedLockType.RO)
......@@ -177,6 +255,11 @@ def test_materialized_module_from_gms_matches_plain_module_forward(running_gms):
register_module_tensors(writer, gms_model)
_assert_exact_tensor_equal(expected, gms_model(inputs))
assert writer.commit()
del gms_weight
del gms_scale
del gms_extra
del gms_model
writer.close()
reader = GMSClientMemoryManager(socket_path, device=0)
reader.connect(RequestedLockType.RO)
......
......@@ -55,6 +55,7 @@ dynamo/
├── tests/ # End-to-end and cross-component tests
│ ├── serve/ # Serve E2E tests (vllm, sglang, trtllm)
│ ├── kvbm_integration/ # KVBM integration tests
│ ├── gpu_memory_service/ # GPU Memory Service E2E tests
│ ├── fault_tolerance/ # Fault tolerance, migration, cancellation
│ ├── deploy/ # Deployment tests
│ ├── frontend/ # Frontend HTTP/gRPC tests
......@@ -83,20 +84,21 @@ dynamo/
**Python tests** (`pytest`):
| Type | Description | Location |
|-------------------|------------------------------------------|----------------------------------------------|
| Unit | Single function/class, isolated | `components/src/dynamo/<component>/tests/` |
| Integration | Interactions between modules/services | `components/src/dynamo/<component>/tests/` |
| End-to-End | User workflows, CLI, API | `tests/serve/`, `tests/deploy/`, etc. |
| KVBM Integration | KV block manager integration | `tests/kvbm_integration/` |
| Router | Router E2E with backends | `tests/router/` |
| Planner | Planner unit + integration tests | `components/src/dynamo/planner/tests/` |
| Frontend | Frontend HTTP/gRPC tests | `tests/frontend/` |
| Profiler | Profiler unit + integration tests | `components/src/dynamo/profiler/tests/` |
| Global Planner | Global planner unit tests | `components/src/dynamo/global_planner/tests/`|
| Fault Tolerance | Chaos, migration, cancellation | `tests/fault_tolerance/` |
| Deployment | Deployment validation | `tests/deploy/` |
| Benchmark | Performance/load | `benchmarks/` |
| Type | Description | Location |
|--------------------|---------------------------------------|-----------------------------------------------|
| Unit | Single function/class, isolated | `components/src/dynamo/<component>/tests/` |
| Integration | Interactions between modules/services | `components/src/dynamo/<component>/tests/` |
| End-to-End | User workflows, CLI, API | `tests/serve/`, `tests/deploy/`, etc. |
| KVBM Integration | KV block manager integration | `tests/kvbm_integration/` |
| GPU Memory Service | GPU Memory Service E2E | `tests/gpu_memory_service/` |
| Router | Router E2E with backends | `tests/router/` |
| Planner | Planner unit + integration tests | `components/src/dynamo/planner/tests/` |
| Frontend | Frontend HTTP/gRPC tests | `tests/frontend/` |
| Profiler | Profiler unit + integration tests | `components/src/dynamo/profiler/tests/` |
| Global Planner | Global planner unit tests | `components/src/dynamo/global_planner/tests/` |
| Fault Tolerance | Chaos, migration, cancellation | `tests/fault_tolerance/` |
| Deployment | Deployment validation | `tests/deploy/` |
| Benchmark | Performance/load | `benchmarks/` |
---
......
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import pytest
from gpu_memory_service.client import memory_manager as memory_manager_module
from gpu_memory_service.client.memory_manager import (
GMSClientMemoryManager,
LocalMapping,
)
from gpu_memory_service.common.locks import GrantedLockType, RequestedLockType
pytestmark = [
pytest.mark.pre_merge,
pytest.mark.unit,
pytest.mark.none,
pytest.mark.gpu_1,
]
class _FakeSession:
def __init__(self):
self.lock_type = GrantedLockType.RW
self.committed = False
self.closed = False
@property
def is_connected(self) -> bool:
return not self.closed
def get_memory_layout_hash(self) -> str:
return ""
def commit(self) -> bool:
self.closed = True
return True
def close(self) -> None:
self.closed = True
class _TrackingSession:
def __init__(self) -> None:
self.closed = False
def close(self) -> None:
self.closed = True
class _FailingCommitSession:
is_connected = True
def commit(self) -> bool:
raise ConnectionError("commit failed after local unmap")
class _SuccessfulCommitSession:
is_connected = True
def commit(self) -> bool:
return True
class _CloseFailingSession:
def close(self) -> None:
raise ConnectionError("close failed")
def _make_mapping(
allocation_id: str,
va: int,
*,
handle: int,
tag: str = "weights",
layout_slot: int = 0,
) -> LocalMapping:
return LocalMapping(
allocation_id=allocation_id,
va=va,
size=4096,
aligned_size=4096,
handle=handle,
tag=tag,
layout_slot=layout_slot,
)
@pytest.fixture
def manager(monkeypatch):
monkeypatch.setattr(
memory_manager_module, "cuda_set_current_device", lambda _device: None
)
monkeypatch.setattr(
memory_manager_module, "cumem_get_allocation_granularity", lambda _device: 65536
)
monkeypatch.setattr(memory_manager_module, "cuda_synchronize", lambda: None)
return GMSClientMemoryManager("/tmp/gms-test.sock", device=0)
def _make_manager(
monkeypatch,
*,
client: object | None,
lock_type: GrantedLockType | None,
mappings: list[LocalMapping] | None = None,
unmapped: bool = False,
va_preserved: bool = False,
layout_hash: str = "",
) -> GMSClientMemoryManager:
monkeypatch.setattr(
memory_manager_module, "cuda_set_current_device", lambda _device: None
)
manager = object.__new__(GMSClientMemoryManager)
manager.socket_path = "/tmp/gms-test.sock"
manager.device = 0
manager._client = client
manager._granted_lock_type = lock_type
manager._mappings = {}
manager._inverse_mapping = {}
manager._unmapped = unmapped
manager._va_preserved = va_preserved
manager._last_memory_layout_hash = layout_hash
manager.granularity = 4096
for mapping in mappings or []:
manager._track_mapping(mapping)
return manager
def test_commit_clears_client_lock_state(manager):
manager._client = _FakeSession()
manager._granted_lock_type = GrantedLockType.RW
assert manager.commit()
assert manager.granted_lock_type is None
assert not manager.is_connected
assert manager.is_unmapped
def test_abort_clears_client_lock_state(manager):
manager._client = _FakeSession()
manager._granted_lock_type = GrantedLockType.RW
manager.abort()
assert manager.granted_lock_type is None
assert not manager.is_connected
def test_connect_rejects_double_connect(monkeypatch):
manager = _make_manager(
monkeypatch,
client=object(),
lock_type=GrantedLockType.RO,
)
with pytest.raises(RuntimeError, match="already connected"):
manager.connect(RequestedLockType.RO)
def test_abort_drops_session_with_live_mappings(monkeypatch):
session = _TrackingSession()
manager = _make_manager(
monkeypatch,
client=session,
lock_type=GrantedLockType.RO,
mappings=[_make_mapping("alloc-1", 0x1000, handle=1234)],
)
manager.abort()
assert session.closed
assert manager.granted_lock_type is None
assert not manager.is_connected
assert manager.mappings[0x1000].handle == 1234
def test_commit_failure_after_local_unmap_keeps_preserved_unmapped_state(monkeypatch):
manager = _make_manager(
monkeypatch,
client=_FailingCommitSession(),
lock_type=GrantedLockType.RW,
mappings=[
_make_mapping("alloc_1", 0x1000, handle=11),
_make_mapping("alloc_2", 0x2000, handle=0),
],
)
unmapped_vas: list[int] = []
monkeypatch.setattr(memory_manager_module, "cuda_synchronize", lambda: None)
def fake_unmap_va(self, va: int) -> None:
unmapped_vas.append(va)
self._mappings[va] = self._mappings[va].with_handle(0)
monkeypatch.setattr(GMSClientMemoryManager, "unmap_va", fake_unmap_va)
with pytest.raises(ConnectionError, match="commit failed after local unmap"):
manager.commit()
assert unmapped_vas == [0x1000]
assert manager._mappings[0x1000].handle == 0
assert manager._mappings[0x2000].handle == 0
assert manager._va_preserved
assert manager._unmapped
assert manager._client is not None
def test_successful_commit_clears_rw_mode_before_local_cleanup(monkeypatch):
manager = _make_manager(
monkeypatch,
client=_SuccessfulCommitSession(),
lock_type=GrantedLockType.RW,
mappings=[_make_mapping("alloc_1", 0x1000, handle=11)],
)
free_handle_calls: list[str] = []
monkeypatch.setattr(memory_manager_module, "cuda_synchronize", lambda: None)
def fake_unmap_va(self, va: int) -> None:
self._mappings[va] = self._mappings[va].with_handle(0)
def fake_free_va(self, va: int) -> None:
mapping = self._mappings.pop(va)
self._inverse_mapping.pop(mapping.allocation_id, None)
def fake_free_handle(self, allocation_id: str) -> bool:
free_handle_calls.append(allocation_id)
return True
monkeypatch.setattr(GMSClientMemoryManager, "unmap_va", fake_unmap_va)
monkeypatch.setattr(GMSClientMemoryManager, "free_va", fake_free_va)
monkeypatch.setattr(GMSClientMemoryManager, "free_handle", fake_free_handle)
assert manager.commit()
assert manager._client is None
assert manager._granted_lock_type is None
assert manager._unmapped
assert manager._va_preserved
manager.destroy_mapping(0x1000)
assert free_handle_calls == []
def test_destroy_mapping_keeps_local_state_when_server_free_fails(monkeypatch):
mapping = _make_mapping("alloc_1", 0x1000, handle=77)
manager = _make_manager(
monkeypatch,
client=None,
lock_type=GrantedLockType.RW,
mappings=[mapping],
)
monkeypatch.setattr(
manager,
"free_handle",
lambda allocation_id: (_ for _ in ()).throw(RuntimeError("server free failed")),
)
with pytest.raises(RuntimeError, match="server free failed"):
manager.destroy_mapping(mapping.va)
assert manager.mappings[mapping.va] == mapping
def test_disconnect_clears_local_state_even_if_close_fails(monkeypatch):
manager = _make_manager(
monkeypatch,
client=_CloseFailingSession(),
lock_type=GrantedLockType.RO,
)
with pytest.raises(ConnectionError, match="close failed"):
manager.abort()
assert manager._client is None
assert manager._granted_lock_type is None
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import pytest
from gpu_memory_service.client.rpc import _GMSRPCTransport
from gpu_memory_service.client.session import _GMSClientSession
from gpu_memory_service.common.locks import GrantedLockType, RequestedLockType
from gpu_memory_service.common.protocol.messages import (
CommitResponse,
HandshakeResponse,
)
pytestmark = [
pytest.mark.pre_merge,
pytest.mark.unit,
pytest.mark.none,
pytest.mark.gpu_0,
]
def _patch_handshake(
monkeypatch,
*,
response: HandshakeResponse | None = None,
error: Exception | None = None,
) -> dict[str, bool]:
closed = {"value": False}
monkeypatch.setattr(_GMSRPCTransport, "connect", lambda self: None)
if error is None:
monkeypatch.setattr(
_GMSRPCTransport,
"handshake",
lambda self, lock_type, timeout_ms: response,
)
else:
monkeypatch.setattr(
_GMSRPCTransport,
"handshake",
lambda self, lock_type, timeout_ms: (_ for _ in ()).throw(error),
)
monkeypatch.setattr(
_GMSRPCTransport,
"close",
lambda self: closed.__setitem__("value", True),
)
return closed
def test_client_session_timeout_closes_transport(monkeypatch):
closed = _patch_handshake(
monkeypatch,
response=HandshakeResponse(success=False, committed=False),
)
with pytest.raises(TimeoutError, match="Timeout waiting for lock"):
_GMSClientSession("/tmp/gms-test.sock", RequestedLockType.RO, 1000)
assert closed["value"]
def test_client_session_handshake_failure_closes_transport(monkeypatch):
closed = _patch_handshake(monkeypatch, error=RuntimeError("handshake failed"))
with pytest.raises(RuntimeError, match="handshake failed"):
_GMSClientSession("/tmp/gms-test.sock", RequestedLockType.RO, 1000)
assert closed["value"]
def test_client_session_records_granted_lock_and_committed(monkeypatch):
_patch_handshake(
monkeypatch,
response=HandshakeResponse(
success=True,
committed=True,
granted_lock_type=GrantedLockType.RO,
),
)
session = _GMSClientSession("/tmp/gms-test.sock", RequestedLockType.RW_OR_RO, None)
assert session.committed
assert session.lock_type == GrantedLockType.RO
assert session.is_ready()
def test_client_session_requires_granted_lock_type(monkeypatch):
closed = _patch_handshake(
monkeypatch,
response=HandshakeResponse(
success=True,
committed=False,
granted_lock_type=None,
),
)
with pytest.raises(RuntimeError, match="granted_lock_type"):
_GMSClientSession("/tmp/gms-test.sock", RequestedLockType.RW_OR_RO, None)
assert closed["value"]
def test_client_session_commit_marks_committed_and_closes_transport(monkeypatch):
closed = _patch_handshake(
monkeypatch,
response=HandshakeResponse(
success=True,
committed=False,
granted_lock_type=GrantedLockType.RW,
),
)
monkeypatch.setattr(
_GMSRPCTransport,
"request",
lambda self, request, response_type: CommitResponse(success=True),
)
session = _GMSClientSession("/tmp/gms-test.sock", RequestedLockType.RW, None)
assert not session.committed
assert session.commit()
assert session.committed
assert closed["value"]
def test_client_session_commit_tolerates_close_failure_after_success(monkeypatch):
monkeypatch.setattr(_GMSRPCTransport, "connect", lambda self: None)
monkeypatch.setattr(
_GMSRPCTransport,
"handshake",
lambda self, lock_type, timeout_ms: HandshakeResponse(
success=True,
committed=False,
granted_lock_type=GrantedLockType.RW,
),
)
monkeypatch.setattr(
_GMSRPCTransport,
"request",
lambda self, request, response_type: CommitResponse(success=True),
)
monkeypatch.setattr(
_GMSRPCTransport,
"close",
lambda self: (_ for _ in ()).throw(ConnectionError("close failed")),
)
session = _GMSClientSession("/tmp/gms-test.sock", RequestedLockType.RW, None)
assert session.commit()
assert session.committed
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import pytest
from gpu_memory_service.client.rpc import _GMSRPCTransport
from gpu_memory_service.common.protocol import wire
from gpu_memory_service.common.protocol.messages import (
CommitResponse,
ErrorResponse,
HandshakeResponse,
)
pytestmark = [
pytest.mark.pre_merge,
pytest.mark.unit,
pytest.mark.none,
pytest.mark.gpu_0,
]
class _DummySocket:
def __init__(self) -> None:
self.closed = False
def close(self) -> None:
self.closed = True
def test_transport_failure_closes_socket_and_marks_disconnected(monkeypatch):
transport = _GMSRPCTransport("/tmp/gms-test.sock")
transport._socket = _DummySocket()
monkeypatch.setattr(
"gpu_memory_service.client.rpc.send_message_sync",
lambda sock, request: None,
)
monkeypatch.setattr(
"gpu_memory_service.client.rpc.recv_message_sync",
lambda sock, buffer: (_ for _ in ()).throw(BrokenPipeError("boom")),
)
with pytest.raises(ConnectionError, match="failed: boom"):
transport.request(CommitResponse(success=True), HandshakeResponse)
assert not transport.is_connected
assert transport._socket is None
def test_request_with_fd_closes_fd_on_unexpected_response_type(monkeypatch):
transport = _GMSRPCTransport("/tmp/gms-test.sock")
closed_fds: list[int] = []
monkeypatch.setattr(
transport,
"_send_recv",
lambda request, error_prefix=None: (CommitResponse(success=True), 37),
)
monkeypatch.setattr("gpu_memory_service.client.rpc.os.close", closed_fds.append)
with pytest.raises(RuntimeError, match="unexpected response type"):
transport.request_with_fd(
CommitResponse(success=True),
HandshakeResponse,
)
assert closed_fds == [37]
def test_request_closes_fd_on_error_response(monkeypatch):
transport = _GMSRPCTransport("/tmp/gms-test.sock")
transport._socket = _DummySocket()
closed_fds: list[int] = []
monkeypatch.setattr(
"gpu_memory_service.client.rpc.send_message_sync",
lambda sock, request: None,
)
monkeypatch.setattr(
"gpu_memory_service.client.rpc.recv_message_sync",
lambda sock, buffer: (ErrorResponse(error="boom"), 41, bytearray()),
)
monkeypatch.setattr("gpu_memory_service.client.rpc.os.close", closed_fds.append)
with pytest.raises(RuntimeError, match="error: boom"):
transport.request(CommitResponse(success=True), HandshakeResponse)
assert closed_fds == [41]
def test_request_closes_fd_on_unexpected_success_fd(monkeypatch):
transport = _GMSRPCTransport("/tmp/gms-test.sock")
closed_fds: list[int] = []
monkeypatch.setattr(
transport,
"request_with_fd",
lambda request, response_type: (CommitResponse(success=True), 43),
)
monkeypatch.setattr("gpu_memory_service.client.rpc.os.close", closed_fds.append)
with pytest.raises(RuntimeError, match="unexpected FD"):
transport.request(CommitResponse(success=True), CommitResponse)
assert closed_fds == [43]
def test_recv_message_sync_closes_fd_on_decode_failure(monkeypatch):
closed_fds: list[int] = []
monkeypatch.setattr(
wire.socket,
"recv_fds",
lambda sock, size, maxfds: (b"\x00\x00\x00\x01x", [53], 0, None),
)
monkeypatch.setattr(
wire,
"decode_message",
lambda payload: (_ for _ in ()).throw(ValueError("bad frame")),
)
monkeypatch.setattr(
"gpu_memory_service.common.protocol.wire.os.close",
closed_fds.append,
)
with pytest.raises(ValueError, match="bad frame"):
wire.recv_message_sync(object(), bytearray())
assert closed_fds == [53]
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
from types import SimpleNamespace
import pytest
from tests.gms.harness.gms import GMSServerProcess
from tests.utils.managed_process import ManagedProcess
pytestmark = [
pytest.mark.pre_merge,
pytest.mark.unit,
pytest.mark.none,
pytest.mark.gpu_1,
]
@pytest.fixture
def request_stub():
return SimpleNamespace(node=SimpleNamespace(name="gms-harness"))
def test_server_process_refuses_foreign_live_socket(monkeypatch, request_stub):
server = GMSServerProcess(request_stub, device=0, tag="weights")
monkeypatch.setattr("tests.gms.harness.gms.os.path.exists", lambda path: True)
monkeypatch.setattr("tests.gms.harness.gms._socket_has_live_gms", lambda path: True)
with pytest.raises(RuntimeError, match="already active"):
server.__enter__()
def test_server_process_unlinks_only_stale_socket_on_exit(monkeypatch, request_stub):
server = GMSServerProcess(request_stub, device=0, tag="weights")
unlinked: list[str] = []
monkeypatch.setattr(
ManagedProcess,
"__exit__",
lambda self, exc_type, exc_val, exc_tb: False,
)
monkeypatch.setattr("tests.gms.harness.gms.os.path.exists", lambda path: True)
monkeypatch.setattr("tests.gms.harness.gms.os.unlink", unlinked.append)
monkeypatch.setattr("tests.gms.harness.gms._socket_has_live_gms", lambda path: True)
server.__exit__(None, None, None)
assert unlinked == []
monkeypatch.setattr(
"tests.gms.harness.gms._socket_has_live_gms", lambda path: False
)
server.__exit__(None, None, None)
assert unlinked == [server.socket_path]
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Targeted GMS fault-tolerance unit tests."""
from __future__ import annotations
import asyncio
import os
import socket
import subprocess
import sys
import time
from dataclasses import dataclass
import pytest
from gpu_memory_service.common import cuda_utils
from gpu_memory_service.common.locks import GrantedLockType, RequestedLockType
from gpu_memory_service.common.protocol.messages import (
CommitRequest,
CommitResponse,
GetEventHistoryRequest,
GetLockStateRequest,
GetLockStateResponse,
GetRuntimeStateRequest,
HandshakeRequest,
)
from gpu_memory_service.server.allocations import GMSAllocationManager
from gpu_memory_service.server.fsm import ServerState, StateEvent
from gpu_memory_service.server.gms import GMS
from gpu_memory_service.server.rpc import GMSRPCServer, _is_connection_alive
from gpu_memory_service.server.session import (
Connection,
GMSSessionManager,
OperationNotAllowed,
)
# Skip entire module if cuda.bindings is not installed
pytest.importorskip("cuda.bindings", reason="cuda.bindings is required")
from cuda.bindings import driver as cuda # noqa: E402
pytestmark = [
pytest.mark.pre_merge,
pytest.mark.unit,
pytest.mark.none,
pytest.mark.gpu_1,
]
def test_cumem_create_tolerate_oom_returns_handle_on_success(monkeypatch):
monkeypatch.setattr(
cuda_utils.cuda,
"cuMemCreate",
lambda size, prop, flags: (cuda.CUresult.CUDA_SUCCESS, 1234),
)
allocated, handle = cuda_utils.cumem_create_tolerate_oom(4096, 0)
assert allocated
assert handle == 1234
def test_cumem_create_tolerate_oom_returns_false_on_oom(monkeypatch):
monkeypatch.setattr(
cuda_utils.cuda,
"cuMemCreate",
lambda size, prop, flags: (cuda.CUresult.CUDA_ERROR_OUT_OF_MEMORY, 0),
)
allocated, handle = cuda_utils.cumem_create_tolerate_oom(4096, 0)
assert not allocated
assert handle == 0
def test_cumem_export_to_shareable_handle_returns_fd(monkeypatch):
monkeypatch.setattr(
cuda_utils.cuda,
"cuMemExportToShareableHandle",
lambda handle, handle_type, flags: (cuda.CUresult.CUDA_SUCCESS, 77),
)
fd = cuda_utils.cumem_export_to_shareable_handle(1234)
assert fd == 77
def test_non_oom_cuda_error_exits_process() -> None:
script = """
from cuda.bindings import driver as cuda
from gpu_memory_service.common import cuda_utils
cuda_utils.cuda_check_result(cuda.CUresult.CUDA_ERROR_INVALID_VALUE, "synthetic")
"""
result = subprocess.run([sys.executable, "-c", script], check=False)
assert result.returncode == 1
class _DummyReader:
def at_eof(self) -> bool:
return False
def exception(self):
return None
class _DummyWriter:
def __init__(self, sock: socket.socket | None = None) -> None:
self.closed = False
self._socket = sock
def close(self) -> None:
self.closed = True
async def wait_closed(self) -> None:
return None
def is_closing(self) -> bool:
return self.closed
def get_extra_info(self, _name: str):
return self._socket
def test_is_connection_alive_detects_dead_peer() -> None:
local_sock, peer_sock = socket.socketpair()
local_sock.setblocking(False)
try:
conn = Connection(
reader=_DummyReader(),
writer=_DummyWriter(local_sock),
mode=GrantedLockType.RW,
session_id="session_1",
recv_buffer=bytearray(),
)
assert _is_connection_alive(conn)
peer_sock.close()
deadline = time.monotonic() + 1.0
while _is_connection_alive(conn):
if time.monotonic() > deadline:
raise TimeoutError("peer disconnect was not detected")
time.sleep(0.01)
finally:
peer_sock.close()
local_sock.close()
def _make_connection(
mode: GrantedLockType,
session_id: str,
) -> tuple[Connection, _DummyWriter]:
writer = _DummyWriter()
return (
Connection(
reader=_DummyReader(),
writer=writer,
mode=mode,
session_id=session_id,
recv_buffer=bytearray(),
),
writer,
)
def _make_allocation_manager() -> GMSAllocationManager:
manager = object.__new__(GMSAllocationManager)
manager._device = 0
manager._allocations = {}
manager._next_layout_slot = 0
manager._granularity = 1
manager._allocation_retry_interval = 0.0
manager._allocation_retry_timeout = None
return manager
@dataclass
class _FakeHandler:
has_committed_layout: bool = False
rw_connect_calls: int = 0
rw_abort_calls: int = 0
commit_calls: int = 0
def on_rw_connect(self) -> None:
self.rw_connect_calls += 1
self.has_committed_layout = False
def on_rw_abort(self) -> None:
self.rw_abort_calls += 1
def on_commit(self) -> None:
self.commit_calls += 1
self.has_committed_layout = True
def handle_get_lock_state(
self,
has_rw: bool,
ro_count: int,
waiting_writers: int,
committed: bool,
) -> GetLockStateResponse:
return GetLockStateResponse(
state=(
"RW"
if has_rw
else "RO"
if ro_count
else "COMMITTED"
if committed
else "EMPTY"
),
has_rw_session=has_rw,
ro_session_count=ro_count,
waiting_writers=waiting_writers,
committed=committed,
is_ready=committed and not has_rw,
)
class _FakeGMS:
def __init__(self, handler: _FakeHandler | None = None):
self.handler = handler or _FakeHandler()
self._sessions = GMSSessionManager()
if self.handler.has_committed_layout:
self._sessions._locking._committed = True
@property
def committed(self) -> bool:
return self._sessions.snapshot().committed
async def acquire_lock(self, mode, timeout_ms, session_id):
return await self._sessions.acquire_lock(mode, timeout_ms, session_id)
async def cancel_connect(self, session_id, mode):
await self._sessions.cancel_connect(session_id, mode)
def on_connect(self, conn: Connection) -> None:
if conn.mode == GrantedLockType.RW:
self.handler.on_rw_connect()
self._sessions.on_connect(conn)
async def cleanup_connection(self, conn: Connection | None) -> None:
event = self._sessions.begin_cleanup(conn)
if event == StateEvent.RW_ABORT:
self.handler.on_rw_abort()
await self._sessions.finish_cleanup(conn)
async def handle_request(self, conn: Connection, msg, _is_connected):
if isinstance(msg, GetLockStateRequest):
snapshot = self._sessions.snapshot()
return (
self.handler.handle_get_lock_state(
snapshot.has_rw_session,
snapshot.ro_session_count,
snapshot.waiting_writers,
snapshot.committed,
),
-1,
False,
)
if isinstance(msg, CommitRequest):
self.handler.on_commit()
self._sessions.on_commit(conn)
return CommitResponse(success=True), -1, True
raise AssertionError(f"Unexpected request type in test: {type(msg)}")
def _make_server(handler: _FakeHandler | None = None) -> GMSRPCServer:
server = object.__new__(GMSRPCServer)
server.socket_path = "/tmp/gms-test.sock"
server.device = 0
server._gms = _FakeGMS(handler)
server._server = None
return server
@pytest.mark.asyncio
async def test_handshake_success_send_failure_cleans_up_rw_state(monkeypatch):
server = _make_server()
reader = _DummyReader()
writer = _DummyWriter()
async def fake_recv_message(_reader, _buffer):
return HandshakeRequest(lock_type=RequestedLockType.RW), -1, bytearray()
async def fake_acquire_lock(_mode, _timeout_ms, _session_id):
server._gms._sessions._reserved_rw_session_id = _session_id
return GrantedLockType.RW
async def fake_send_message(_writer, _msg, _fd=-1):
raise BrokenPipeError("handshake reply failed")
monkeypatch.setattr("gpu_memory_service.server.rpc.recv_message", fake_recv_message)
monkeypatch.setattr(server._gms, "acquire_lock", fake_acquire_lock)
monkeypatch.setattr("gpu_memory_service.server.rpc.send_message", fake_send_message)
conn = await server._do_handshake(reader, writer, "session_1")
assert conn is None
assert server._gms._sessions._locking.rw_conn is None
assert server._gms._sessions.state == ServerState.EMPTY
assert server._gms.handler.rw_connect_calls == 1
assert server._gms.handler.rw_abort_calls == 1
assert writer.closed
@pytest.mark.asyncio
async def test_rw_lock_is_reserved_until_connect():
sessions = GMSSessionManager()
first = await sessions.acquire_lock(
RequestedLockType.RW,
timeout_ms=50,
session_id="session_1",
)
second = await sessions.acquire_lock(
RequestedLockType.RW,
timeout_ms=50,
session_id="session_2",
)
assert first == GrantedLockType.RW
assert second is None
await sessions.cancel_connect("session_1", GrantedLockType.RW)
@pytest.mark.asyncio
async def test_reader_cannot_acquire_while_rw_is_reserved_before_connect():
sessions = GMSSessionManager()
sessions._locking._committed = True
granted = await sessions.acquire_lock(
RequestedLockType.RW,
timeout_ms=50,
session_id="writer_1",
)
assert granted == GrantedLockType.RW
reader = await sessions.acquire_lock(
RequestedLockType.RO,
timeout_ms=50,
session_id="reader_1",
)
assert reader is None
await sessions.cancel_connect("writer_1", GrantedLockType.RW)
@pytest.mark.asyncio
async def test_reader_waiter_wakes_when_waiting_writer_times_out():
sessions = GMSSessionManager()
sessions._locking._committed = True
existing_reader = Connection(
reader=_DummyReader(),
writer=_DummyWriter(),
mode=GrantedLockType.RO,
session_id="reader_1",
recv_buffer=bytearray(),
)
sessions.on_connect(existing_reader)
writer_task = asyncio.create_task(
sessions.acquire_lock(
RequestedLockType.RW,
timeout_ms=50,
session_id="writer_1",
)
)
await asyncio.sleep(0)
reader_task = asyncio.create_task(
sessions.acquire_lock(
RequestedLockType.RO,
timeout_ms=200,
session_id="reader_2",
)
)
assert await writer_task is None
assert await reader_task == GrantedLockType.RO
event = sessions.begin_cleanup(existing_reader)
assert event == StateEvent.RO_DISCONNECT
await sessions.finish_cleanup(existing_reader)
@pytest.mark.asyncio
async def test_rw_or_ro_waiter_becomes_rw_after_writer_abort():
sessions = GMSSessionManager()
writer_mode = await sessions.acquire_lock(
RequestedLockType.RW,
timeout_ms=50,
session_id="writer_1",
)
assert writer_mode == GrantedLockType.RW
writer = Connection(
reader=_DummyReader(),
writer=_DummyWriter(),
mode=GrantedLockType.RW,
session_id="writer_1",
recv_buffer=bytearray(),
)
sessions.on_connect(writer)
waiter = asyncio.create_task(
sessions.acquire_lock(
RequestedLockType.RW_OR_RO,
timeout_ms=200,
session_id="waiter_1",
)
)
await asyncio.sleep(0)
assert not waiter.done()
event = sessions.begin_cleanup(writer)
assert event == StateEvent.RW_ABORT
await sessions.finish_cleanup(writer)
assert await waiter == GrantedLockType.RW
await sessions.cancel_connect("waiter_1", GrantedLockType.RW)
@pytest.mark.asyncio
async def test_gms_clears_aborted_rw_layout_before_waking_waiters():
gms = object.__new__(GMS)
cleanup_order: list[str] = []
conn, _ = _make_connection(GrantedLockType.RW, "session_1")
gms._events = []
def begin_cleanup(self, cleanup_conn):
cleanup_order.append("begin_cleanup")
return StateEvent.RW_ABORT
async def finish_cleanup(self, cleanup_conn):
cleanup_order.append("finish_cleanup")
def clear_layout_state():
cleanup_order.append("clear_layout_state")
return 3
gms._sessions = type(
"_Sessions",
(),
{
"begin_cleanup": begin_cleanup,
"finish_cleanup": finish_cleanup,
},
)()
gms._clear_layout_state = clear_layout_state
await gms.cleanup_connection(conn)
assert cleanup_order == [
"begin_cleanup",
"clear_layout_state",
"finish_cleanup",
]
@pytest.mark.asyncio
async def test_request_response_send_failure_disconnects_without_error_response(
monkeypatch,
):
handler = _FakeHandler(has_committed_layout=True)
server = _make_server(handler)
server._gms._sessions._locking._committed = True
conn, writer = _make_connection(GrantedLockType.RO, "session_2")
server._gms._sessions._locking.transition(StateEvent.RO_CONNECT, conn)
recv_calls = 0
sent_messages: list[object] = []
async def fake_recv_message(_reader, _buffer):
nonlocal recv_calls
recv_calls += 1
return GetLockStateRequest(), -1, bytearray()
async def fake_send_message(_writer, msg, _fd=-1):
sent_messages.append(msg)
raise BrokenPipeError("response send failed")
monkeypatch.setattr("gpu_memory_service.server.rpc.recv_message", fake_recv_message)
monkeypatch.setattr("gpu_memory_service.server.rpc.send_message", fake_send_message)
await server._request_loop(conn)
await server._gms.cleanup_connection(conn)
assert recv_calls == 1
assert len(sent_messages) == 1
assert isinstance(sent_messages[0], GetLockStateResponse)
assert conn not in server._gms._sessions._locking.ro_conns
assert server._gms._sessions.state == ServerState.COMMITTED
assert writer.closed
@pytest.mark.asyncio
async def test_post_commit_response_send_failure_stays_committed(monkeypatch):
handler = _FakeHandler()
server = _make_server(handler)
conn, writer = _make_connection(GrantedLockType.RW, "session_3")
server._gms._sessions._locking.transition(StateEvent.RW_CONNECT, conn)
recv_calls = 0
sent_messages: list[object] = []
async def fake_recv_message(_reader, _buffer):
nonlocal recv_calls
recv_calls += 1
return CommitRequest(), -1, bytearray()
async def fake_send_message(_writer, msg, _fd=-1):
sent_messages.append(msg)
raise BrokenPipeError("commit reply failed")
monkeypatch.setattr("gpu_memory_service.server.rpc.recv_message", fake_recv_message)
monkeypatch.setattr("gpu_memory_service.server.rpc.send_message", fake_send_message)
await server._request_loop(conn)
await server._gms.cleanup_connection(conn)
assert recv_calls == 1
assert len(sent_messages) == 1
assert handler.commit_calls == 1
assert server._gms._sessions._locking.rw_conn is None
assert server._gms._sessions.snapshot().committed
assert server._gms._sessions.state == ServerState.COMMITTED
assert writer.closed
@pytest.mark.asyncio
async def test_runtime_state_handshake_send_failure_does_not_fail_server(monkeypatch):
server = _make_server()
reader = _DummyReader()
writer = _DummyWriter()
async def fake_recv_message(_reader, _buffer):
return GetRuntimeStateRequest(), -1, bytearray()
async def fake_send_message(_writer, _msg, _fd=-1):
raise BrokenPipeError("runtime-state send failed")
monkeypatch.setattr("gpu_memory_service.server.rpc.recv_message", fake_recv_message)
monkeypatch.setattr("gpu_memory_service.server.rpc.send_message", fake_send_message)
conn = await server._do_handshake(reader, writer, "session_diag")
assert conn is None
assert server._gms._sessions.state == ServerState.EMPTY
assert writer.closed
@pytest.mark.asyncio
async def test_runtime_state_request_is_rejected_on_live_session():
gms = GMS()
conn, _ = _make_connection(GrantedLockType.RW, "session_4")
gms._sessions._reserved_rw_session_id = conn.session_id
gms.on_connect(conn)
with pytest.raises(OperationNotAllowed):
await gms.handle_request(
conn,
GetRuntimeStateRequest(),
lambda: True,
)
@pytest.mark.asyncio
async def test_event_history_request_is_rejected_on_live_session():
gms = GMS()
conn, _ = _make_connection(GrantedLockType.RW, "session_5")
gms._sessions._reserved_rw_session_id = conn.session_id
gms.on_connect(conn)
with pytest.raises(OperationNotAllowed):
await gms.handle_request(
conn,
GetEventHistoryRequest(),
lambda: True,
)
@pytest.mark.asyncio
async def test_server_refuses_to_bind_over_live_socket(monkeypatch, tmp_path):
socket_path = str(tmp_path / "gms.sock")
listener = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
listener.bind(socket_path)
listener.listen(1)
monkeypatch.setattr(
"gpu_memory_service.server.allocations.cuda_ensure_initialized",
lambda: None,
)
monkeypatch.setattr(
"gpu_memory_service.server.allocations.cumem_get_allocation_granularity",
lambda device: 4096,
)
server = GMSRPCServer(socket_path, device=0)
try:
with pytest.raises(RuntimeError, match="already running"):
await asyncio.wait_for(server.serve(), timeout=0.1)
finally:
listener.close()
if os.path.exists(socket_path):
os.unlink(socket_path)
@pytest.mark.asyncio
async def test_allocate_rejects_non_positive_size_before_cuda():
manager = _make_allocation_manager()
with pytest.raises(ValueError, match="size must be > 0"):
await manager.allocate(0, tag="weights", is_connected=None)
@pytest.mark.asyncio
async def test_allocate_aborts_retry_when_writer_disconnects(monkeypatch):
manager = _make_allocation_manager()
checks = 0
def is_connected() -> bool:
nonlocal checks
checks += 1
return checks < 2
monkeypatch.setattr(
"gpu_memory_service.server.allocations.cumem_create_tolerate_oom",
lambda size, device: (False, 0),
)
with pytest.raises(
ConnectionAbortedError, match="RW client disconnected during allocation retry"
):
await manager.allocate(1, tag="weights", is_connected=is_connected)
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import sys
import types
import gpu_memory_service.integrations.sglang.patches as sglang_patches
import pytest
torch = pytest.importorskip("torch", reason="torch is required")
pytestmark = [
pytest.mark.pre_merge,
pytest.mark.unit,
pytest.mark.gpu_0,
pytest.mark.sglang,
]
def test_patch_model_runner_rewrites_total_gpu_memory(monkeypatch):
fake_sglang = types.ModuleType("sglang")
fake_srt = types.ModuleType("sglang.srt")
fake_executor = types.ModuleType("sglang.srt.model_executor")
fake_model_runner = types.ModuleType("sglang.srt.model_executor.model_runner")
class FakeModelRunner:
def init_memory_pool(self, total_gpu_memory):
self.seen_total_gpu_memory = total_gpu_memory
return total_gpu_memory
fake_model_runner.ModelRunner = FakeModelRunner
fake_sglang.srt = fake_srt
fake_srt.model_executor = fake_executor
fake_executor.model_runner = fake_model_runner
fake_memory_saver = types.ModuleType(
"gpu_memory_service.integrations.sglang.memory_saver"
)
class FakeImpl:
imported_weights_bytes = 8 << 30
fake_memory_saver.get_gms_memory_saver_impl = lambda: FakeImpl()
monkeypatch.setitem(sys.modules, "sglang", fake_sglang)
monkeypatch.setitem(sys.modules, "sglang.srt", fake_srt)
monkeypatch.setitem(sys.modules, "sglang.srt.model_executor", fake_executor)
monkeypatch.setitem(
sys.modules,
"sglang.srt.model_executor.model_runner",
fake_model_runner,
)
monkeypatch.setitem(
sys.modules,
"gpu_memory_service.integrations.sglang.memory_saver",
fake_memory_saver,
)
monkeypatch.setattr(
sglang_patches,
"get_gms_memory_saver_impl",
lambda: FakeImpl(),
)
monkeypatch.setattr(sglang_patches, "_model_runner_patched", False)
monkeypatch.delattr(FakeModelRunner, "_gms_patched", raising=False)
monkeypatch.setattr(
sglang_patches.torch.cuda,
"current_device",
lambda: 0,
)
monkeypatch.setattr(
sglang_patches.torch.cuda,
"get_device_properties",
lambda device: types.SimpleNamespace(total_memory=80 * (1 << 30)),
)
sglang_patches.patch_model_runner()
runner = FakeModelRunner()
assert runner.init_memory_pool(40.0) == pytest.approx(80.0)
assert runner.seen_total_gpu_memory == pytest.approx(80.0)
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import pytest
from gpu_memory_service.client.torch import allocator as allocator_module
from gpu_memory_service.common.locks import GrantedLockType, RequestedLockType
pytestmark = [
pytest.mark.pre_merge,
pytest.mark.unit,
pytest.mark.none,
pytest.mark.gpu_0,
]
class _FakeManager:
def __init__(self, socket_path: str, *, device: int):
self.socket_path = socket_path
self.device = device
self._connected = False
self._granted_lock_type: GrantedLockType | None = None
self._mappings: dict[int, object] = {}
self._unmapped = False
@property
def granted_lock_type(self) -> GrantedLockType | None:
return self._granted_lock_type
@property
def is_connected(self) -> bool:
return self._connected
@property
def mappings(self) -> dict[int, object]:
return self._mappings
@property
def is_unmapped(self) -> bool:
return self._unmapped
def connect(self, mode: RequestedLockType, timeout_ms: int | None = None) -> None:
del timeout_ms
self._connected = True
if mode == RequestedLockType.RW:
self._granted_lock_type = GrantedLockType.RW
return
self._granted_lock_type = GrantedLockType.RO
@property
def _client_rpc(self):
return self
def get_lock_state(self) -> object:
return object()
@pytest.fixture(autouse=True)
def clear_tag_states():
allocator_module._tag_states.clear()
yield
allocator_module._tag_states.clear()
@pytest.fixture
def fake_allocator(monkeypatch):
monkeypatch.setattr(
"gpu_memory_service.client.memory_manager.GMSClientMemoryManager",
_FakeManager,
)
monkeypatch.setattr(
allocator_module,
"_ensure_callbacks_initialized",
lambda: None,
)
monkeypatch.setattr(allocator_module, "_create_mem_pool", lambda: object())
def test_tag_registry_rejects_socket_or_device_mismatch(fake_allocator):
manager = allocator_module.get_or_create_gms_client_memory_manager(
"/tmp/weights.sock",
0,
RequestedLockType.RO,
tag="weights",
)
with pytest.raises(RuntimeError, match="initialized for"):
allocator_module.get_or_create_gms_client_memory_manager(
"/tmp/other.sock",
0,
RequestedLockType.RO,
tag="weights",
)
with pytest.raises(RuntimeError, match="initialized for"):
allocator_module.get_or_create_gms_client_memory_manager(
"/tmp/weights.sock",
1,
RequestedLockType.RO,
tag="weights",
)
assert manager.is_connected
def test_tag_registry_recreates_disconnected_empty_manager(fake_allocator):
first = allocator_module.get_or_create_gms_client_memory_manager(
"/tmp/weights.sock",
0,
RequestedLockType.RO,
tag="weights",
)
first._connected = False
first._granted_lock_type = None
second = allocator_module.get_or_create_gms_client_memory_manager(
"/tmp/weights.sock",
0,
RequestedLockType.RO,
tag="weights",
)
assert second is not first
assert second.is_connected
assert second.granted_lock_type == GrantedLockType.RO
def test_tag_registry_rejects_disconnected_manager_with_preserved_state(
fake_allocator,
):
manager = allocator_module.get_or_create_gms_client_memory_manager(
"/tmp/weights.sock",
0,
RequestedLockType.RO,
tag="weights",
)
manager._connected = False
manager._mappings[0x1000] = object()
with pytest.raises(RuntimeError, match="preserved state"):
allocator_module.get_or_create_gms_client_memory_manager(
"/tmp/weights.sock",
0,
RequestedLockType.RO,
tag="weights",
)
def test_close_evicts_manager_from_tag_registry(fake_allocator):
manager = allocator_module.get_or_create_gms_client_memory_manager(
"/tmp/weights.sock",
0,
RequestedLockType.RO,
tag="weights",
)
allocator_module.evict_gms_client_memory_manager(manager)
assert allocator_module.get_gms_client_memory_manager("weights") is None
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Pytest configuration for GPU Memory Service tests."""
import pytest
from tests.utils.port_utils import allocate_port, deallocate_ports # noqa: E402
@pytest.fixture
def gms_ports():
"""Allocate ports for GMS tests.
Returns a dict with ports for:
- frontend: Frontend HTTP port
- shadow_system: System port for the first shadow engine
- shadow2_system: System port for the second shadow engine
- primary_system: System port for primary engine (failover test only)
- shadow_kv_event: KV event port for the first shadow engine (vLLM)
- shadow2_kv_event: KV event port for the second shadow engine (vLLM)
- primary_kv_event: KV event port for primary engine (vLLM)
- shadow_nixl: NIXL side channel port for the first shadow engine (vLLM)
- shadow2_nixl: NIXL side channel port for the second shadow engine (vLLM)
- primary_nixl: NIXL side channel port for primary engine (vLLM)
- shadow_sglang: SGLang HTTP port for the first shadow engine
- shadow2_sglang: SGLang HTTP port for the second shadow engine
- primary_sglang: SGLang HTTP port for primary engine
"""
ports = [
allocate_port(p)
for p in [
8200,
8100,
8101,
8102,
20080,
20081,
20082,
20096,
20097,
20098,
30000,
30001,
30002,
]
]
yield {
"frontend": ports[0],
"shadow_system": ports[1],
"primary_system": ports[2],
"shadow2_system": ports[3],
"shadow_kv_event": ports[4],
"primary_kv_event": ports[5],
"shadow2_kv_event": ports[6],
"shadow_nixl": ports[7],
"primary_nixl": ports[8],
"shadow2_nixl": ports[9],
"shadow_sglang": ports[10],
"primary_sglang": ports[11],
"shadow2_sglang": ports[12],
}
deallocate_ports(ports)
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import argparse
import os
import socket
import subprocess
from contextlib import contextmanager
import torch
from gpu_memory_service.client.memory_manager import GMSClientMemoryManager
from gpu_memory_service.client.torch.allocator import (
get_or_create_gms_client_memory_manager,
gms_use_mem_pool,
)
from gpu_memory_service.client.torch.module import register_module_tensors
from gpu_memory_service.common.locks import RequestedLockType
from gpu_memory_service.common.utils import get_socket_path
from tests.utils.constants import FAULT_TOLERANCE_MODEL_NAME
from .runtime import DYNAMO_BIN, REPO_ROOT
SGLANG_BIN = REPO_ROOT / "dynamo-sglang" / "bin"
def get_external_weight_writer_command(backend: str) -> list[str]:
return [
"python",
"-m",
"tests.gms.harness.external_weight_writer",
"--backend",
backend,
]
def get_external_weight_writer_env(backend: str) -> dict[str, str]:
if backend == "sglang":
return {
**os.environ,
"PATH": f"/usr/local/cuda/bin:{SGLANG_BIN}:{os.environ.get('PATH', '')}",
"CC": "/usr/bin/gcc",
"CXX": "/usr/bin/g++",
"PYTHONPATH": str(REPO_ROOT),
}
return {
**os.environ,
"PATH": f"{DYNAMO_BIN}:{os.environ.get('PATH', '')}",
"PYTHONPATH": str(REPO_ROOT),
}
def run_external_weight_writer(backend: str) -> None:
command = get_external_weight_writer_command(backend)
subprocess.run(
command,
cwd=REPO_ROOT,
env=get_external_weight_writer_env(backend),
check=True,
)
def _get_writer_manager(device: int, tag: str) -> GMSClientMemoryManager:
return get_or_create_gms_client_memory_manager(
get_socket_path(device, tag),
device,
RequestedLockType.RW,
tag=tag,
)
def _publish_model(manager: GMSClientMemoryManager, model: torch.nn.Module) -> None:
register_module_tensors(manager, model)
torch.cuda.synchronize()
manager.commit()
manager.close()
@contextmanager
def _vllm_single_rank_distributed(device: int):
from vllm.distributed.parallel_state import (
destroy_distributed_environment,
destroy_model_parallel,
ensure_model_parallel_initialized,
init_distributed_environment,
)
probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
probe.bind(("127.0.0.1", 0))
host, port = probe.getsockname()
probe.close()
init_distributed_environment(
world_size=1,
rank=0,
local_rank=device,
distributed_init_method=f"tcp://{host}:{port}",
backend="gloo",
)
ensure_model_parallel_initialized(1, 1)
try:
yield
finally:
destroy_model_parallel()
destroy_distributed_environment()
@contextmanager
def _sglang_single_rank_distributed(device: int):
from sglang.srt.distributed.parallel_state import (
destroy_distributed_environment,
destroy_model_parallel,
init_distributed_environment,
initialize_model_parallel,
)
probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
probe.bind(("127.0.0.1", 0))
host, port = probe.getsockname()
probe.close()
init_distributed_environment(
world_size=1,
rank=0,
local_rank=device,
distributed_init_method=f"tcp://{host}:{port}",
backend="gloo",
)
initialize_model_parallel(1, 1, 1, backend="gloo")
try:
yield
finally:
destroy_model_parallel()
destroy_distributed_environment()
def _publish_vllm_dummy_weights(device: int, tag: str) -> None:
from vllm.config import (
DeviceConfig,
LoadConfig,
ModelConfig,
VllmConfig,
set_current_vllm_config,
)
from vllm.model_executor.model_loader.dummy_loader import DummyModelLoader
from vllm.model_executor.model_loader.utils import (
initialize_model,
process_weights_after_loading,
)
from vllm.utils.torch_utils import set_default_torch_dtype
torch.cuda.set_device(device)
model_config = ModelConfig(model=FAULT_TOLERANCE_MODEL_NAME, enforce_eager=True)
load_config = LoadConfig(load_format="dummy")
device_config = DeviceConfig(device="cuda")
vllm_config = VllmConfig(
model_config=model_config,
device_config=device_config,
load_config=load_config,
)
target_device = torch.device("cuda", device)
manager = _get_writer_manager(device, tag)
with set_current_vllm_config(vllm_config):
with _vllm_single_rank_distributed(device):
with set_default_torch_dtype(model_config.dtype):
with gms_use_mem_pool(tag, target_device):
with target_device:
model = initialize_model(
vllm_config=vllm_config,
model_config=model_config,
)
DummyModelLoader(load_config).load_weights(model, model_config)
process_weights_after_loading(model, model_config, target_device)
_publish_model(manager, model)
def _publish_sglang_dummy_weights(device: int, tag: str) -> None:
from sglang.srt.configs.device_config import DeviceConfig
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.layers.dp_attention import initialize_dp_attention
from sglang.srt.model_loader.loader import LoadConfig, get_model_loader
from sglang.srt.server_args import ServerArgs, set_global_server_args_for_scheduler
torch.cuda.set_device(device)
model_config = ModelConfig(FAULT_TOLERANCE_MODEL_NAME)
device_config = DeviceConfig(device="cuda", gpu_id=device)
load_config = LoadConfig(load_format="dummy")
loader = get_model_loader(load_config, model_config)
manager = _get_writer_manager(device, tag)
server_args = ServerArgs(
model_path=FAULT_TOLERANCE_MODEL_NAME,
load_format="dummy",
device="cuda",
)
set_global_server_args_for_scheduler(server_args)
with _sglang_single_rank_distributed(device):
initialize_dp_attention(server_args, model_config)
with gms_use_mem_pool(tag, torch.device("cuda", device)):
model = loader.load_model(
model_config=model_config,
device_config=device_config,
)
_publish_model(manager, model)
def main() -> None:
parser = argparse.ArgumentParser()
parser.add_argument("--backend", choices=("vllm", "sglang"), required=True)
parser.add_argument("--device", type=int, default=0)
parser.add_argument("--tag", default="weights")
args = parser.parse_args()
if args.backend == "vllm":
_publish_vllm_dummy_weights(args.device, args.tag)
return
_publish_sglang_dummy_weights(args.device, args.tag)
if __name__ == "__main__":
main()
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import asyncio
import os
import shutil
import threading
import time
from concurrent.futures import TimeoutError as FutureTimeoutError
from gpu_memory_service.client.rpc import _GMSRPCTransport
from gpu_memory_service.common.protocol.messages import (
GetEventHistoryRequest,
GetEventHistoryResponse,
GetRuntimeStateRequest,
GetRuntimeStateResponse,
)
from gpu_memory_service.common.utils import get_socket_path
from gpu_memory_service.server.rpc import GMSRPCServer
from tests.utils.managed_process import ManagedProcess
from .runtime import DYNAMO_BIN
def _socket_has_live_gms(socket_path: str) -> bool:
if not os.path.exists(socket_path):
return False
try:
with _GMSRPCTransport(socket_path) as transport:
transport.connect()
transport.request(
GetRuntimeStateRequest(),
GetRuntimeStateResponse,
)
except Exception:
return False
return True
def _prepare_socket_path_for_launch(socket_path: str) -> None:
if not os.path.exists(socket_path):
return
if _socket_has_live_gms(socket_path):
raise RuntimeError(f"GMS already active at {socket_path}")
os.unlink(socket_path)
class GMSServerProcess(ManagedProcess):
def __init__(self, request, device: int, tag: str = "weights"):
self.device = device
self.tag = tag
self.socket_path = get_socket_path(device, tag)
log_dir = f"{request.node.name}_gms_{tag}_{device}"
shutil.rmtree(log_dir, ignore_errors=True)
super().__init__(
command=[
"python",
"-m",
"gpu_memory_service",
"--device",
str(device),
"--tag",
tag,
],
env={
**os.environ,
"PATH": f"{DYNAMO_BIN}:{os.environ.get('PATH', '')}",
"DYN_LOG": "debug",
},
timeout=60,
display_output=True,
terminate_all_matching_process_names=False,
log_dir=log_dir,
display_name=f"gms_{tag}",
health_check_funcs=[self._runtime_state_ready],
)
def __enter__(self):
_prepare_socket_path_for_launch(self.socket_path)
return super().__enter__()
def __exit__(self, exc_type, exc_val, exc_tb):
try:
return super().__exit__(exc_type, exc_val, exc_tb)
finally:
if os.path.exists(self.socket_path) and not _socket_has_live_gms(
self.socket_path
):
os.unlink(self.socket_path)
def _socket_has_live_gms(self) -> bool:
return _socket_has_live_gms(self.socket_path)
def _prepare_socket_path_for_launch(self) -> None:
_prepare_socket_path_for_launch(self.socket_path)
def _runtime_state_ready(self, timeout: float = 30) -> bool:
deadline = time.monotonic() + timeout
while time.monotonic() < deadline:
if not os.path.exists(self.socket_path):
time.sleep(0.1)
continue
try:
self.get_runtime_state()
return True
except Exception:
time.sleep(0.1)
continue
time.sleep(0.1)
return False
def get_runtime_state(self) -> GetRuntimeStateResponse:
with _GMSRPCTransport(self.socket_path) as transport:
transport.connect()
return transport.request(
GetRuntimeStateRequest(),
GetRuntimeStateResponse,
)
def get_event_history(self) -> GetEventHistoryResponse:
with _GMSRPCTransport(self.socket_path) as transport:
transport.connect()
return transport.request(
GetEventHistoryRequest(),
GetEventHistoryResponse,
)
class ServerThread:
def __init__(self, server, socket_path: str):
self.server = server
self.socket_path = socket_path
self._loop: asyncio.AbstractEventLoop | None = None
self._task: asyncio.Task[None] | None = None
self._thread = threading.Thread(target=self._run, daemon=True)
self._exception: BaseException | None = None
def _run(self) -> None:
loop = asyncio.new_event_loop()
self._loop = loop
asyncio.set_event_loop(loop)
self._task = loop.create_task(self.server.serve())
try:
loop.run_until_complete(self._task)
except asyncio.CancelledError:
pass
except BaseException as exc:
self._exception = exc
finally:
pending = [task for task in asyncio.all_tasks(loop) if not task.done()]
for task in pending:
task.cancel()
if pending:
loop.run_until_complete(
asyncio.gather(*pending, return_exceptions=True)
)
loop.close()
def start(self) -> None:
self._thread.start()
deadline = time.monotonic() + 5.0
while True:
if self._exception is not None:
raise self._exception
if self.server._server is not None and os.path.exists(self.socket_path):
try:
with _GMSRPCTransport(self.socket_path) as transport:
transport.connect()
transport.request(
GetRuntimeStateRequest(),
GetRuntimeStateResponse,
)
return
except Exception:
pass
if time.monotonic() > deadline:
raise TimeoutError(f"GMS socket did not appear at {self.socket_path}")
time.sleep(0.01)
def stop(self) -> None:
if self._loop is not None:
def cancel() -> None:
if self.server._server is not None:
self.server._server.close()
if self._task is not None:
self._task.cancel()
self._loop.call_soon_threadsafe(cancel)
self._thread.join(timeout=5)
if self._exception is not None:
raise self._exception
if os.path.exists(self.socket_path):
os.unlink(self.socket_path)
def disconnect_rw_session(self, timeout: float = 5.0) -> None:
if self._loop is None:
raise RuntimeError("GMS server thread is not running")
future = asyncio.run_coroutine_threadsafe(
self._disconnect_rw_session(), self._loop
)
try:
future.result(timeout=timeout)
except FutureTimeoutError as exc:
raise TimeoutError("Timed out disconnecting RW session") from exc
async def _disconnect_rw_session(self) -> None:
conn = self.server._gms._sessions._locking.rw_conn
if conn is None:
raise RuntimeError("No active RW session to disconnect")
await self.server._gms.cleanup_connection(conn)
class ThreadedGMSServer:
def __init__(self, device: int, tag: str = "weights"):
self.device = device
self.tag = tag
self.socket_path = get_socket_path(device, tag)
self.server = GMSRPCServer(self.socket_path, device)
self._thread = ServerThread(self.server, self.socket_path)
def __enter__(self) -> "ThreadedGMSServer":
_prepare_socket_path_for_launch(self.socket_path)
self._thread.start()
return self
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
self._thread.stop()
def get_runtime_state(self) -> GetRuntimeStateResponse:
with _GMSRPCTransport(self.socket_path) as transport:
transport.connect()
return transport.request(
GetRuntimeStateRequest(),
GetRuntimeStateResponse,
)
def get_event_history(self) -> GetEventHistoryResponse:
with _GMSRPCTransport(self.socket_path) as transport:
transport.connect()
return transport.request(
GetEventHistoryRequest(),
GetEventHistoryResponse,
)
def disconnect_rw_session(self) -> None:
self._thread.disconnect_rw_session()
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import logging
import time
from pathlib import Path
import requests
from tests.utils.constants import FAULT_TOLERANCE_MODEL_NAME
logger = logging.getLogger(__name__)
REPO_ROOT = Path(__file__).resolve().parents[3]
DYNAMO_BIN = REPO_ROOT / "dynamo" / "bin"
MIN_EXPECTED_MEMORY_RETURN_FRACTION = 0.6
def get_gpu_memory_used(device: int = 0) -> int:
import pynvml
pynvml.nvmlInit()
try:
handle = pynvml.nvmlDeviceGetHandleByIndex(device)
return pynvml.nvmlDeviceGetMemoryInfo(handle).used
finally:
pynvml.nvmlShutdown()
def send_completion(
port: int,
prompt: str = "Hello",
max_retries: int = 3,
retry_delay: float = 1.0,
) -> dict:
last_error = None
for attempt in range(max_retries):
try:
response = requests.post(
f"http://localhost:{port}/v1/completions",
json={
"model": FAULT_TOLERANCE_MODEL_NAME,
"prompt": prompt,
"max_tokens": 20,
},
timeout=120,
)
response.raise_for_status()
result = response.json()
assert result.get("choices"), "No choices in response"
if attempt > 0:
logger.info("send_completion succeeded after %d attempts", attempt + 1)
return result
except (requests.exceptions.RequestException, AssertionError) as exc:
last_error = exc
if attempt < max_retries - 1:
logger.debug(
"send_completion attempt %d/%d failed: %s",
attempt + 1,
max_retries,
exc,
)
time.sleep(retry_delay)
raise last_error # type: ignore[misc]
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""SGLang-specific utilities for GPU Memory Service tests."""
import logging
import os
import shutil
import requests
from tests.utils.constants import FAULT_TOLERANCE_MODEL_NAME
from tests.utils.managed_process import ManagedProcess
from tests.utils.payloads import check_health_generate, check_models_api
from .runtime import REPO_ROOT
logger = logging.getLogger(__name__)
SGLANG_BIN = REPO_ROOT / "dynamo-sglang" / "bin"
class SGLangWithGMSProcess(ManagedProcess):
"""SGLang engine with GPU Memory Service integration."""
def __init__(
self,
request,
engine_id: str,
system_port: int,
sglang_port: int,
frontend_port: int,
*,
read_only_weights: bool = False,
):
self.engine_id = engine_id
self.system_port = system_port
log_dir = f"{request.node.name}_{engine_id}"
shutil.rmtree(log_dir, ignore_errors=True)
command = [
"python",
"-m",
"dynamo.sglang",
"--model-path",
FAULT_TOLERANCE_MODEL_NAME,
"--load-format",
"gms",
"--enable-memory-saver",
"--mem-fraction-static",
"0.9",
"--port",
str(sglang_port),
]
if read_only_weights:
command.extend(
[
"--model-loader-extra-config",
'{"gms_read_only": true}',
]
)
super().__init__(
command=command,
env={
**os.environ,
"PATH": f"/usr/local/cuda/bin:{SGLANG_BIN}:{os.environ.get('PATH', '')}",
"CC": "/usr/bin/gcc",
"CXX": "/usr/bin/g++",
"DYN_LOG": "debug",
"DYN_SYSTEM_PORT": str(system_port),
},
health_check_urls=[
(f"http://localhost:{system_port}/health", self._is_ready),
(f"http://localhost:{frontend_port}/v1/models", check_models_api),
(f"http://localhost:{frontend_port}/health", check_health_generate),
],
timeout=300,
display_output=True,
terminate_all_matching_process_names=False,
stragglers=[],
log_dir=log_dir,
display_name=engine_id,
)
def _is_ready(self, response) -> bool:
try:
return response.json().get("status") == "ready"
except ValueError:
return False
def sleep(self) -> dict:
"""Put the engine to sleep, offloading weights and KV cache."""
r = requests.post(
f"http://localhost:{self.system_port}/engine/release_memory_occupation",
json={"tags": ["weights", "kv_cache"]},
timeout=30,
)
r.raise_for_status()
logger.info(f"{self.engine_id} release_memory_occupation: {r.json()}")
return r.json()
def wake(self, timeout: int = 30) -> dict:
"""Wake the engine, restoring weights and KV cache."""
r = requests.post(
f"http://localhost:{self.system_port}/engine/resume_memory_occupation",
json={"tags": ["weights", "kv_cache"]},
timeout=timeout,
)
r.raise_for_status()
logger.info(f"{self.engine_id} resume_memory_occupation: {r.json()}")
return r.json()
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""vLLM-specific utilities for GPU Memory Service tests."""
import json
import logging
import os
import shutil
import requests
from tests.utils.constants import FAULT_TOLERANCE_MODEL_NAME
from tests.utils.managed_process import ManagedProcess
from tests.utils.payloads import check_health_generate, check_models_api
from .runtime import DYNAMO_BIN
logger = logging.getLogger(__name__)
class VLLMWithGMSProcess(ManagedProcess):
"""vLLM engine with GPU Memory Service integration."""
def __init__(
self,
request,
engine_id: str,
system_port: int,
kv_event_port: int,
nixl_port: int,
frontend_port: int,
*,
read_only_weights: bool = False,
):
self.engine_id = engine_id
self.system_port = system_port
log_dir = f"{request.node.name}_{engine_id}"
shutil.rmtree(log_dir, ignore_errors=True)
kv_events_cfg = json.dumps(
{
"publisher": "zmq",
"topic": "kv-events",
"endpoint": f"tcp://*:{kv_event_port}",
"enable_kv_cache_events": True,
}
)
command = [
"python",
"-m",
"dynamo.vllm",
"--model",
FAULT_TOLERANCE_MODEL_NAME,
"--load-format",
"gms",
"--enforce-eager",
"--enable-sleep-mode",
"--gpu-memory-utilization",
"0.9",
"--kv-events-config",
kv_events_cfg,
]
if read_only_weights:
command.extend(
[
"--model-loader-extra-config",
json.dumps({"gms_read_only": True}),
]
)
super().__init__(
command=command,
env={
**os.environ,
"PATH": f"{DYNAMO_BIN}:{os.environ.get('PATH', '')}",
"DYN_LOG": "debug",
"DYN_SYSTEM_PORT": str(system_port),
"VLLM_NIXL_SIDE_CHANNEL_PORT": str(nixl_port),
},
health_check_urls=[
(f"http://localhost:{system_port}/health", self._is_ready),
(f"http://localhost:{frontend_port}/v1/models", check_models_api),
(f"http://localhost:{frontend_port}/health", check_health_generate),
],
timeout=300,
display_output=True,
terminate_all_matching_process_names=False,
stragglers=[],
log_dir=log_dir,
display_name=engine_id,
)
def _is_ready(self, response) -> bool:
try:
return response.json().get("status") == "ready"
except ValueError:
return False
def sleep(self) -> dict:
"""Put the engine to sleep, offloading weights and KV cache."""
r = requests.post(
f"http://localhost:{self.system_port}/engine/sleep",
json={"level": 2},
timeout=30,
)
r.raise_for_status()
logger.info(f"{self.engine_id} sleep: {r.json()}")
return r.json()
def wake(self, timeout: int = 30) -> dict:
"""Wake the engine, restoring weights and KV cache."""
r = requests.post(
f"http://localhost:{self.system_port}/engine/wake_up",
json={"tags": ["weights", "kv_cache"]},
timeout=timeout,
)
r.raise_for_status()
logger.info(f"{self.engine_id} wake: {r.json()}")
return r.json()
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import pytest
pytest.importorskip("gpu_memory_service", reason="gpu_memory_service is required")
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import time
from concurrent.futures import ThreadPoolExecutor
from contextlib import ExitStack
from typing import Callable, Protocol
import pytest
from gpu_memory_service.client.session import _GMSClientSession
from gpu_memory_service.common.locks import RequestedLockType
from gpu_memory_service.server.fsm import ServerState
from tests.utils.constants import FAULT_TOLERANCE_MODEL_NAME
from tests.utils.managed_process import DynamoFrontendProcess
from ..harness.gms import GMSServerProcess
from ..harness.runtime import send_completion
from ..harness.sglang import SGLangWithGMSProcess
from ..harness.vllm import VLLMWithGMSProcess
# guard: external_weight_writer imports torch at module level
pytest.importorskip("torch", reason="torch is required")
from ..harness.external_weight_writer import run_external_weight_writer # noqa: E402
pytestmark = [pytest.mark.nightly]
# Event flow under test:
# 1. The engine starts in read-only mode and waits for a committed weights layout.
# 2. An external writer acquires RW on the weights GMS, loads dummy weights, commits, and exits.
# 3. The engine comes online with those committed weights while owning its own KV-cache RW layout.
# 4. The engine sleeps, preserving its weight VAs but dropping the KV-cache layout.
# 5. A second external writer acquires RW on the weights GMS, creates a fresh committed layout with
# different allocation IDs but the same structural layout, and exits.
# 6. The engine wakes, remaps the preserved weight VAs into the new committed layout, recreates its
# KV cache in a new RW layout, and serves inference without a stale-layout error.
class _SleepWakeEngine(Protocol):
def __enter__(self):
...
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
...
def sleep(self) -> dict:
...
def wake(self) -> dict:
...
def _list_committed_weight_allocations(
socket_path: str,
) -> list[tuple[int, str, int, int, str]]:
with _GMSClientSession(
socket_path, RequestedLockType.RO, timeout_ms=None
) as reader:
return [
(
int(info.layout_slot),
str(info.allocation_id),
int(info.size),
int(info.aligned_size),
str(info.tag),
)
for info in reader.list_allocations()
]
def _run_external_weight_mgr_test(
request,
ports: dict,
backend: str,
make_engine: Callable[[], _SleepWakeEngine],
) -> None:
with ExitStack() as stack:
weights_gms = stack.enter_context(
GMSServerProcess(request, device=0, tag="weights")
)
kv_cache_gms = stack.enter_context(
GMSServerProcess(request, device=0, tag="kv_cache")
)
stack.enter_context(
DynamoFrontendProcess(request, frontend_port=ports["frontend"])
)
engine = make_engine()
stack.callback(engine.__exit__, None, None, None)
with ThreadPoolExecutor(max_workers=1) as executor:
start_future = executor.submit(engine.__enter__)
try:
# The read-only engine must stall until some external writer
# publishes the first committed weights layout.
time.sleep(2.0)
assert (
not start_future.done()
), "read-only engine should still be waiting for committed weights"
assert weights_gms.get_runtime_state().state == ServerState.EMPTY
assert kv_cache_gms.get_runtime_state().state == ServerState.EMPTY
# Publish the first weights layout out-of-process, then let the
# engine finish importing those committed weights.
run_external_weight_writer(backend)
start_future.result(timeout=300)
first_weights_state = weights_gms.get_runtime_state()
first_kv_state = kv_cache_gms.get_runtime_state()
first_allocations = _list_committed_weight_allocations(
weights_gms.socket_path
)
assert first_weights_state.state == ServerState.RO
assert first_weights_state.allocation_count > 0
assert first_weights_state.memory_layout_hash
assert first_kv_state.state == ServerState.RW
assert first_allocations
result = send_completion(ports["frontend"])
assert result["choices"]
# Sleep preserves the engine's weight VAs but tears down the KV
# cache layout so the process can wake against a later weights layout.
assert engine.sleep()["status"] == "ok"
deadline = time.monotonic() + 30.0
while True:
weights_after_sleep = weights_gms.get_runtime_state()
kv_after_sleep = kv_cache_gms.get_runtime_state()
if (
weights_after_sleep.state == ServerState.COMMITTED
and weights_after_sleep.memory_layout_hash
== first_weights_state.memory_layout_hash
and weights_after_sleep.ro_session_count == 0
and kv_after_sleep.state == ServerState.EMPTY
and kv_after_sleep.allocation_count == 0
):
break
if time.monotonic() > deadline:
raise TimeoutError(
"engine sleep did not settle the expected GMS state"
)
time.sleep(0.1)
# Publish a second committed weights layout with the same logical
# layout but fresh allocation IDs.
run_external_weight_writer(backend)
deadline = time.monotonic() + 30.0
while True:
second_weights_state = weights_gms.get_runtime_state()
second_kv_state = kv_cache_gms.get_runtime_state()
if (
second_weights_state.state == ServerState.COMMITTED
and second_weights_state.memory_layout_hash
== first_weights_state.memory_layout_hash
and second_weights_state.allocation_count
== first_weights_state.allocation_count
and second_kv_state.state == ServerState.EMPTY
and second_kv_state.allocation_count == 0
):
break
if time.monotonic() > deadline:
raise TimeoutError(
"external writer did not publish a new committed weights layout"
)
time.sleep(0.1)
# The second publish must reuse the same layout slots and sizes so
# RO remap can bind the preserved VAs to new allocations.
second_allocations = _list_committed_weight_allocations(
weights_gms.socket_path
)
assert len(second_allocations) == len(first_allocations)
assert [item[0] for item in second_allocations] == [
item[0] for item in first_allocations
]
assert [item[2:] for item in second_allocations] == [
item[2:] for item in first_allocations
]
assert [item[1] for item in second_allocations] != [
item[1] for item in first_allocations
]
# The weights GMS should show the expected publish progression:
# first publish, old layout cleanup, second publish.
weights_events = weights_gms.get_event_history().events
assert [event.kind for event in weights_events] == [
"rw_connected",
"committed",
"allocations_cleared",
"rw_connected",
"committed",
]
assert (
weights_events[2].allocation_count
== first_weights_state.allocation_count
)
# Wake should remap the preserved RO weight VAs into the new
# committed layout and recreate KV cache in a new RW layout.
assert engine.wake()["status"] == "ok"
deadline = time.monotonic() + 30.0
while True:
weights_after_wake = weights_gms.get_runtime_state()
kv_after_wake = kv_cache_gms.get_runtime_state()
if (
weights_after_wake.state == ServerState.RO
and weights_after_wake.memory_layout_hash
== first_weights_state.memory_layout_hash
and kv_after_wake.state == ServerState.RW
and kv_after_wake.allocation_count > 0
):
break
if time.monotonic() > deadline:
raise TimeoutError(
"engine wake did not restore the expected GMS state"
)
time.sleep(0.1)
# A normal inference after wake proves the remapped weights and
# recreated KV cache are usable end to end.
result = send_completion(ports["frontend"], "updated weights")
assert result["choices"]
finally:
if start_future.done():
try:
start_future.result()
except Exception:
pass
@pytest.mark.vllm
@pytest.mark.e2e
@pytest.mark.gpu_1
@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME)
@pytest.mark.timeout(600)
def test_external_weight_mgr_vllm(
request,
runtime_services_dynamic_ports,
gms_ports,
predownload_models,
):
ports = gms_ports
_run_external_weight_mgr_test(
request,
ports,
"vllm",
make_engine=lambda: VLLMWithGMSProcess(
request,
"engine",
ports["shadow_system"],
ports["shadow_kv_event"],
ports["shadow_nixl"],
ports["frontend"],
read_only_weights=True,
),
)
@pytest.mark.skip(reason="Nightly CI failure: https://linear.app/nvidia/issue/DYN-2567")
@pytest.mark.sglang
@pytest.mark.e2e
@pytest.mark.gpu_1
@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME)
@pytest.mark.timeout(600)
def test_external_weight_mgr_sglang(
request,
runtime_services_dynamic_ports,
gms_ports,
predownload_models,
):
ports = gms_ports
_run_external_weight_mgr_test(
request,
ports,
"sglang",
make_engine=lambda: SGLangWithGMSProcess(
request,
"engine",
ports["shadow_system"],
ports["shadow_sglang"],
ports["frontend"],
read_only_weights=True,
),
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment