Unverified Commit 91a8c6af authored by Schwinn Saereesitthipitak, committed by GitHub
Browse files

fix: simplify GMS layout state and harden GPU-backed flows (#7006)


Signed-off-by: mohammedabdulwahhab <furkhan324@berkeley.edu>
Signed-off-by: Hannah Zhang <hannahz@nvidia.com>
Co-authored-by: mohammedabdulwahhab <furkhan324@berkeley.edu>
Co-authored-by: hhzhang16 <54051230+hhzhang16@users.noreply.github.com>
parent dd7ceb4a
...@@ -20,7 +20,6 @@ pytestmark = [ ...@@ -20,7 +20,6 @@ pytestmark = [
pytest.mark.pre_merge, pytest.mark.pre_merge,
pytest.mark.unit, pytest.mark.unit,
pytest.mark.gpu_0, pytest.mark.gpu_0,
pytest.mark.fault_tolerance,
] ]
......
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import pytest
from gpu_memory_service.client import memory_manager as memory_manager_module
from gpu_memory_service.client.memory_manager import (
GMSClientMemoryManager,
LocalMapping,
)
from gpu_memory_service.common.types import GrantedLockType, RequestedLockType
# Module-level pytest markers; pytest applies each entry to every test in this module.
pytestmark = [
    pytest.mark.pre_merge,
    pytest.mark.unit,
    pytest.mark.gpu_0,
]
class _FakeSession:
def __init__(self):
self.lock_type = GrantedLockType.RW
self.committed = False
self.closed = False
@property
def is_connected(self) -> bool:
return not self.closed
def get_memory_layout_hash(self) -> str:
return ""
def commit(self) -> bool:
self.closed = True
return True
def close(self) -> None:
self.closed = True
class _TrackingSession:
def __init__(self) -> None:
self.closed = False
def close(self) -> None:
self.closed = True
class _FailingCommitSession:
is_connected = True
def commit(self) -> bool:
raise ConnectionError("commit failed after local unmap")
class _SuccessfulCommitSession:
is_connected = True
def commit(self) -> bool:
return True
class _CloseFailingSession:
def close(self) -> None:
raise ConnectionError("close failed")
def _make_mapping(
    allocation_id: str,
    va: int,
    *,
    handle: int,
    tag: str = "weights",
    layout_slot: int = 0,
) -> LocalMapping:
    """Build a ``LocalMapping`` with a fixed 4096-byte size and aligned size.

    Only the identifying fields vary between tests; ``size`` and
    ``aligned_size`` are always 4096.
    """
    fields = {
        "allocation_id": allocation_id,
        "va": va,
        "size": 4096,
        "aligned_size": 4096,
        "handle": handle,
        "tag": tag,
        "layout_slot": layout_slot,
    }
    return LocalMapping(**fields)
@pytest.fixture
def manager(monkeypatch):
    """Provide a ``GMSClientMemoryManager`` built with the CUDA helpers stubbed out.

    The three CUDA entry points patched on the memory-manager module are
    replaced before construction; the allocation granularity stub reports 65536.
    """
    cuda_stubs = {
        "cuda_set_current_device": lambda _device: None,
        "cumem_get_allocation_granularity": lambda _device: 65536,
        "cuda_synchronize": lambda: None,
    }
    for attr_name, stub in cuda_stubs.items():
        monkeypatch.setattr(memory_manager_module, attr_name, stub)
    return GMSClientMemoryManager("/tmp/gms-test.sock", device=0)
def _make_manager(
    monkeypatch,
    *,
    client: object | None,
    lock_type: GrantedLockType | None,
    mappings: list[LocalMapping] | None = None,
    unmapped: bool = False,
    va_preserved: bool = False,
    layout_hash: str = "",
) -> GMSClientMemoryManager:
    """Assemble a ``GMSClientMemoryManager`` with hand-set internal state.

    The instance is created with ``object.__new__`` so ``__init__`` never
    runs; every attribute is assigned directly from the arguments, and each
    supplied mapping is registered through ``_track_mapping``.
    """
    monkeypatch.setattr(
        memory_manager_module, "cuda_set_current_device", lambda _device: None
    )
    instance = object.__new__(GMSClientMemoryManager)
    initial_state = {
        "socket_path": "/tmp/gms-test.sock",
        "device": 0,
        "_client": client,
        "_granted_lock_type": lock_type,
        "_mappings": {},
        "_inverse_mapping": {},
        "_unmapped": unmapped,
        "_va_preserved": va_preserved,
        "_last_memory_layout_hash": layout_hash,
        "granularity": 4096,
    }
    for attr_name, value in initial_state.items():
        setattr(instance, attr_name, value)
    # Register mappings only after the bookkeeping dicts exist.
    for entry in mappings or []:
        instance._track_mapping(entry)
    return instance
def test_commit_clears_client_lock_state(manager):
    """After a successful commit the manager has no lock, is disconnected, and is unmapped."""
    manager._client = _FakeSession()
    manager._granted_lock_type = GrantedLockType.RW

    commit_ok = manager.commit()

    assert commit_ok
    assert manager.granted_lock_type is None
    assert not manager.is_connected
    assert manager.is_unmapped
def test_abort_clears_client_lock_state(manager):
    """After abort() the manager has no lock and is disconnected."""
    manager._client = _FakeSession()
    manager._granted_lock_type = GrantedLockType.RW

    manager.abort()

    assert manager.granted_lock_type is None
    assert not manager.is_connected
def test_connect_rejects_double_connect(monkeypatch):
    """connect() on a manager that already has a client raises RuntimeError."""
    connected_manager = _make_manager(
        monkeypatch,
        client=object(),
        lock_type=GrantedLockType.RO,
    )

    with pytest.raises(RuntimeError, match="already connected"):
        connected_manager.connect(RequestedLockType.RO)
def test_abort_drops_session_with_live_mappings(monkeypatch):
    """abort() closes the session and clears the lock but keeps the tracked mapping."""
    tracked_session = _TrackingSession()
    live_mapping = _make_mapping("alloc-1", 0x1000, handle=1234)
    mgr = _make_manager(
        monkeypatch,
        client=tracked_session,
        lock_type=GrantedLockType.RO,
        mappings=[live_mapping],
    )

    mgr.abort()

    assert tracked_session.closed
    assert mgr.granted_lock_type is None
    assert not mgr.is_connected
    # The mapping is still tracked with its original handle.
    assert mgr.mappings[0x1000].handle == 1234
def test_commit_failure_after_local_unmap_keeps_preserved_unmapped_state(monkeypatch):
    """A commit that fails after local unmap leaves the unmapped/preserved flags and client set."""
    mgr = _make_manager(
        monkeypatch,
        client=_FailingCommitSession(),
        lock_type=GrantedLockType.RW,
        mappings=[
            _make_mapping("alloc_1", 0x1000, handle=11),
            _make_mapping("alloc_2", 0x2000, handle=0),
        ],
    )
    seen_vas: list[int] = []
    monkeypatch.setattr(memory_manager_module, "cuda_synchronize", lambda: None)

    def record_unmap(self, va: int) -> None:
        # Record the VA and zero the handle in place of a real CUDA unmap.
        seen_vas.append(va)
        self._mappings[va] = self._mappings[va].with_handle(0)

    monkeypatch.setattr(GMSClientMemoryManager, "unmap_va", record_unmap)

    with pytest.raises(ConnectionError, match="commit failed after local unmap"):
        mgr.commit()

    # Only the mapping that started with a non-zero handle went through unmap_va.
    assert seen_vas == [0x1000]
    assert mgr._mappings[0x1000].handle == 0
    assert mgr._mappings[0x2000].handle == 0
    assert mgr._va_preserved
    assert mgr._unmapped
    assert mgr._client is not None
def test_successful_commit_clears_rw_mode_before_local_cleanup(monkeypatch):
    """After a successful commit, destroy_mapping() does not call free_handle."""
    mgr = _make_manager(
        monkeypatch,
        client=_SuccessfulCommitSession(),
        lock_type=GrantedLockType.RW,
        mappings=[_make_mapping("alloc_1", 0x1000, handle=11)],
    )
    freed_allocation_ids: list[str] = []
    monkeypatch.setattr(memory_manager_module, "cuda_synchronize", lambda: None)

    def stub_unmap_va(self, va: int) -> None:
        # Zero the handle in place of a real CUDA unmap.
        self._mappings[va] = self._mappings[va].with_handle(0)

    def stub_free_va(self, va: int) -> None:
        # Drop the mapping from both bookkeeping dicts.
        removed = self._mappings.pop(va)
        self._inverse_mapping.pop(removed.allocation_id, None)

    def stub_free_handle(self, allocation_id: str) -> bool:
        freed_allocation_ids.append(allocation_id)
        return True

    replacements = {
        "unmap_va": stub_unmap_va,
        "free_va": stub_free_va,
        "free_handle": stub_free_handle,
    }
    for method_name, replacement in replacements.items():
        monkeypatch.setattr(GMSClientMemoryManager, method_name, replacement)

    commit_ok = mgr.commit()

    assert commit_ok
    assert mgr._client is None
    assert mgr._granted_lock_type is None
    assert mgr._unmapped
    assert mgr._va_preserved

    mgr.destroy_mapping(0x1000)
    assert freed_allocation_ids == []
def test_destroy_mapping_keeps_local_state_when_server_free_fails(monkeypatch):
    """If free_handle raises, destroy_mapping() propagates it and the mapping stays tracked."""
    tracked = _make_mapping("alloc_1", 0x1000, handle=77)
    mgr = _make_manager(
        monkeypatch,
        client=None,
        lock_type=GrantedLockType.RW,
        mappings=[tracked],
    )

    def failing_free_handle(allocation_id):
        raise RuntimeError("server free failed")

    monkeypatch.setattr(mgr, "free_handle", failing_free_handle)

    with pytest.raises(RuntimeError, match="server free failed"):
        mgr.destroy_mapping(tracked.va)

    assert mgr.mappings[tracked.va] == tracked
def test_disconnect_clears_local_state_even_if_close_fails(monkeypatch):
    """abort() re-raises a close() failure but still clears the client and lock."""
    mgr = _make_manager(
        monkeypatch,
        client=_CloseFailingSession(),
        lock_type=GrantedLockType.RO,
    )

    with pytest.raises(ConnectionError, match="close failed"):
        mgr.abort()

    assert mgr._client is None
    assert mgr._granted_lock_type is None
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import pytest
from gpu_memory_service.client.rpc import _GMSRPCTransport
from gpu_memory_service.client.session import _GMSClientSession
from gpu_memory_service.common.protocol.messages import (
CommitResponse,
HandshakeResponse,
)
from gpu_memory_service.common.types import GrantedLockType, RequestedLockType
# Module-level pytest markers; pytest applies each entry to every test in this module.
pytestmark = [
    pytest.mark.pre_merge,
    pytest.mark.unit,
    pytest.mark.gpu_0,
]
def _patch_handshake(
    monkeypatch,
    *,
    response: HandshakeResponse | None = None,
    error: Exception | None = None,
) -> dict[str, bool]:
    """Stub ``connect``, ``handshake`` and ``close`` on ``_GMSRPCTransport``.

    The patched handshake returns ``response`` when ``error`` is None and
    raises ``error`` otherwise.

    Returns:
        A dict whose ``"value"`` entry becomes True once the patched
        ``close`` has been called.
    """
    close_state = {"value": False}

    def stub_connect(self):
        return None

    def stub_handshake(self, lock_type, timeout_ms):
        if error is None:
            return response
        raise error

    def stub_close(self):
        close_state["value"] = True

    monkeypatch.setattr(_GMSRPCTransport, "connect", stub_connect)
    monkeypatch.setattr(_GMSRPCTransport, "handshake", stub_handshake)
    monkeypatch.setattr(_GMSRPCTransport, "close", stub_close)
    return close_state
def test_client_session_timeout_closes_transport(monkeypatch):
    """An unsuccessful handshake response raises TimeoutError and closes the transport."""
    denied = HandshakeResponse(success=False, committed=False)
    close_state = _patch_handshake(monkeypatch, response=denied)

    with pytest.raises(TimeoutError, match="Timeout waiting for lock"):
        _GMSClientSession("/tmp/gms-test.sock", RequestedLockType.RO, 1000)

    assert close_state["value"]
def test_client_session_handshake_failure_closes_transport(monkeypatch):
    """An exception raised by the handshake propagates and the transport is closed."""
    close_state = _patch_handshake(
        monkeypatch,
        error=RuntimeError("handshake failed"),
    )

    with pytest.raises(RuntimeError, match="handshake failed"):
        _GMSClientSession("/tmp/gms-test.sock", RequestedLockType.RO, 1000)

    assert close_state["value"]
def test_client_session_records_granted_lock_and_committed(monkeypatch):
    """The session exposes the granted lock type and committed flag from the handshake."""
    granted_ro = HandshakeResponse(
        success=True,
        committed=True,
        granted_lock_type=GrantedLockType.RO,
    )
    _patch_handshake(monkeypatch, response=granted_ro)

    ro_session = _GMSClientSession(
        "/tmp/gms-test.sock", RequestedLockType.RW_OR_RO, None
    )

    assert ro_session.committed
    assert ro_session.lock_type == GrantedLockType.RO
    assert ro_session.is_ready()
def test_client_session_requires_granted_lock_type(monkeypatch):
    """A successful handshake without a granted lock type raises and closes the transport."""
    missing_lock = HandshakeResponse(
        success=True,
        committed=False,
        granted_lock_type=None,
    )
    close_state = _patch_handshake(monkeypatch, response=missing_lock)

    with pytest.raises(RuntimeError, match="granted_lock_type"):
        _GMSClientSession("/tmp/gms-test.sock", RequestedLockType.RW_OR_RO, None)

    assert close_state["value"]
def test_client_session_commit_marks_committed_and_closes_transport(monkeypatch):
    """commit() on an RW session sets ``committed`` and closes the transport."""
    granted_rw = HandshakeResponse(
        success=True,
        committed=False,
        granted_lock_type=GrantedLockType.RW,
    )
    close_state = _patch_handshake(monkeypatch, response=granted_rw)

    def stub_request(self, request, response_type):
        return CommitResponse(success=True)

    monkeypatch.setattr(_GMSRPCTransport, "request", stub_request)

    rw_session = _GMSClientSession("/tmp/gms-test.sock", RequestedLockType.RW, None)
    assert not rw_session.committed

    commit_ok = rw_session.commit()

    assert commit_ok
    assert rw_session.committed
    assert close_state["value"]
def test_client_session_commit_tolerates_close_failure_after_success(monkeypatch):
    """commit() still succeeds when the transport's close() raises afterwards."""

    def stub_connect(self):
        return None

    def stub_handshake(self, lock_type, timeout_ms):
        return HandshakeResponse(
            success=True,
            committed=False,
            granted_lock_type=GrantedLockType.RW,
        )

    def stub_request(self, request, response_type):
        return CommitResponse(success=True)

    def failing_close(self):
        raise ConnectionError("close failed")

    monkeypatch.setattr(_GMSRPCTransport, "connect", stub_connect)
    monkeypatch.setattr(_GMSRPCTransport, "handshake", stub_handshake)
    monkeypatch.setattr(_GMSRPCTransport, "request", stub_request)
    monkeypatch.setattr(_GMSRPCTransport, "close", failing_close)

    rw_session = _GMSClientSession("/tmp/gms-test.sock", RequestedLockType.RW, None)
    commit_ok = rw_session.commit()

    assert commit_ok
    assert rw_session.committed
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import pytest
from gpu_memory_service.client.rpc import _GMSRPCTransport
from gpu_memory_service.common.protocol import wire
from gpu_memory_service.common.protocol.messages import (
CommitResponse,
ErrorResponse,
HandshakeResponse,
)
# Module-level pytest markers; pytest applies each entry to every test in this module.
pytestmark = [
    pytest.mark.pre_merge,
    pytest.mark.unit,
    pytest.mark.gpu_0,
]
class _DummySocket:
def __init__(self) -> None:
self.closed = False
def close(self) -> None:
self.closed = True
def test_transport_failure_closes_socket_and_marks_disconnected(monkeypatch):
    """A receive failure surfaces as ConnectionError and leaves the transport disconnected."""
    rpc = _GMSRPCTransport("/tmp/gms-test.sock")
    rpc._socket = _DummySocket()

    def stub_send(sock, request):
        return None

    def failing_recv(sock, buffer):
        raise BrokenPipeError("boom")

    monkeypatch.setattr("gpu_memory_service.client.rpc.send_message_sync", stub_send)
    monkeypatch.setattr(
        "gpu_memory_service.client.rpc.recv_message_sync", failing_recv
    )

    with pytest.raises(ConnectionError, match="failed: boom"):
        rpc.request(CommitResponse(success=True), HandshakeResponse)

    assert not rpc.is_connected
    assert rpc._socket is None
def test_request_with_fd_closes_fd_on_unexpected_response_type(monkeypatch):
    """request_with_fd() closes the received FD when the response type is wrong."""
    rpc = _GMSRPCTransport("/tmp/gms-test.sock")
    fds_closed: list[int] = []

    def stub_send_recv(request, error_prefix=None):
        # Reply with a CommitResponse plus FD 37 although a HandshakeResponse is expected.
        return (CommitResponse(success=True), 37)

    monkeypatch.setattr(rpc, "_send_recv", stub_send_recv)
    monkeypatch.setattr("gpu_memory_service.client.rpc.os.close", fds_closed.append)

    with pytest.raises(RuntimeError, match="unexpected response type"):
        rpc.request_with_fd(
            CommitResponse(success=True),
            HandshakeResponse,
        )

    assert fds_closed == [37]
def test_request_closes_fd_on_error_response(monkeypatch):
    """request() closes an FD that arrives together with an ErrorResponse."""
    rpc = _GMSRPCTransport("/tmp/gms-test.sock")
    rpc._socket = _DummySocket()
    fds_closed: list[int] = []

    def stub_send(sock, request):
        return None

    def stub_recv(sock, buffer):
        # Error reply that carries FD 41.
        return (ErrorResponse(error="boom"), 41, bytearray())

    monkeypatch.setattr("gpu_memory_service.client.rpc.send_message_sync", stub_send)
    monkeypatch.setattr("gpu_memory_service.client.rpc.recv_message_sync", stub_recv)
    monkeypatch.setattr("gpu_memory_service.client.rpc.os.close", fds_closed.append)

    with pytest.raises(RuntimeError, match="error: boom"):
        rpc.request(CommitResponse(success=True), HandshakeResponse)

    assert fds_closed == [41]
def test_request_closes_fd_on_unexpected_success_fd(monkeypatch):
    """request() closes and rejects an FD attached to an otherwise valid response."""
    rpc = _GMSRPCTransport("/tmp/gms-test.sock")
    fds_closed: list[int] = []

    def stub_request_with_fd(request, response_type):
        # Correct response type, but with FD 43 attached.
        return (CommitResponse(success=True), 43)

    monkeypatch.setattr(rpc, "request_with_fd", stub_request_with_fd)
    monkeypatch.setattr("gpu_memory_service.client.rpc.os.close", fds_closed.append)

    with pytest.raises(RuntimeError, match="unexpected FD"):
        rpc.request(CommitResponse(success=True), CommitResponse)

    assert fds_closed == [43]
def test_recv_message_sync_closes_fd_on_decode_failure(monkeypatch):
    """recv_message_sync() closes a received FD when decoding the payload fails."""
    fds_closed: list[int] = []

    def stub_recv_fds(sock, size, maxfds):
        # One frame of bytes accompanied by FD 53.
        return (b"\x00\x00\x00\x01x", [53], 0, None)

    def failing_decode(payload):
        raise ValueError("bad frame")

    monkeypatch.setattr(wire.socket, "recv_fds", stub_recv_fds)
    monkeypatch.setattr(wire, "decode_message", failing_decode)
    monkeypatch.setattr(
        "gpu_memory_service.common.protocol.wire.os.close",
        fds_closed.append,
    )

    with pytest.raises(ValueError, match="bad frame"):
        wire.recv_message_sync(object(), bytearray())

    assert fds_closed == [53]
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
from types import SimpleNamespace
import pytest
from tests.gms.harness.gms import GMSServerProcess
from tests.utils.managed_process import ManagedProcess
# Module-level pytest markers; pytest applies each entry to every test in this module.
pytestmark = [pytest.mark.pre_merge, pytest.mark.unit, pytest.mark.gpu_0]
@pytest.fixture
def request_stub():
    """Provide an object exposing ``node.name == "gms-harness"`` in place of a pytest request."""
    node = SimpleNamespace(name="gms-harness")
    return SimpleNamespace(node=node)
def test_server_process_refuses_foreign_live_socket(monkeypatch, request_stub):
    """Entering the server context raises when the socket path exists and is reported live."""
    gms_server = GMSServerProcess(request_stub, device=0, tag="weights")

    def path_exists(path):
        return True

    def socket_is_live(path):
        return True

    monkeypatch.setattr("tests.gms.harness.gms.os.path.exists", path_exists)
    monkeypatch.setattr("tests.gms.harness.gms._socket_has_live_gms", socket_is_live)

    with pytest.raises(RuntimeError, match="already active"):
        gms_server.__enter__()
def test_server_process_unlinks_only_stale_socket_on_exit(monkeypatch, request_stub):
    """__exit__ unlinks the socket path only when no live GMS is reported on it."""
    gms_server = GMSServerProcess(request_stub, device=0, tag="weights")
    removed_paths: list[str] = []

    def stub_parent_exit(self, exc_type, exc_val, exc_tb):
        return False

    monkeypatch.setattr(ManagedProcess, "__exit__", stub_parent_exit)
    monkeypatch.setattr("tests.gms.harness.gms.os.path.exists", lambda path: True)
    monkeypatch.setattr("tests.gms.harness.gms.os.unlink", removed_paths.append)

    # Socket reported live: exiting must leave it alone.
    monkeypatch.setattr("tests.gms.harness.gms._socket_has_live_gms", lambda path: True)
    gms_server.__exit__(None, None, None)
    assert removed_paths == []

    # Socket reported stale: exiting removes it.
    monkeypatch.setattr(
        "tests.gms.harness.gms._socket_has_live_gms", lambda path: False
    )
    gms_server.__exit__(None, None, None)
    assert removed_paths == [gms_server.socket_path]
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import asyncio
import itertools
import os
import signal
import socket
import subprocess
import sys
import textwrap
import threading
import time
import pynvml
import pytest
from gpu_memory_service.client import memory_manager as client_memory_manager
from gpu_memory_service.client.memory_manager import (
GMSClientMemoryManager,
StaleMemoryLayoutError,
)
from gpu_memory_service.client.rpc import _GMSRPCTransport
from gpu_memory_service.client.session import _GMSClientSession
from gpu_memory_service.common import cuda_utils
from gpu_memory_service.common.protocol.messages import (
GetEventHistoryRequest,
GetEventHistoryResponse,
GetRuntimeStateRequest,
GetRuntimeStateResponse,
)
from gpu_memory_service.common.types import (
GrantedLockType,
RequestedLockType,
ServerState,
)
from gpu_memory_service.server import allocations as server_allocations
from gpu_memory_service.server.allocations import GMSAllocationManager
from gpu_memory_service.server.rpc import GMSRPCServer
from tests.gms.harness.gms import ServerThread
# Module-level pytest markers; pytest applies each entry to every test in this module.
pytestmark = [
    pytest.mark.pre_merge,
    pytest.mark.unit,
    pytest.mark.gpu_0,
]
def _gpu_memory_free_bytes(device: int = 0) -> int:
    """Return the free-memory figure NVML reports for ``device``, as an int.

    NVML is initialised for the duration of the call and shut down again
    even if the query raises.
    """
    pynvml.nvmlInit()
    try:
        device_handle = pynvml.nvmlDeviceGetHandleByIndex(device)
        memory_info = pynvml.nvmlDeviceGetMemoryInfo(device_handle)
        return int(memory_info.free)
    finally:
        pynvml.nvmlShutdown()
def _drop_connection(session: _GMSClientSession) -> None:
# Use a raw transport break here, not abort(), because these tests need to
# simulate an unexpected socket loss while a request is still in flight.
sock = session._transport._socket
assert sock is not None
try:
sock.shutdown(socket.SHUT_RDWR)
except OSError:
pass
sock.close()
session._transport._socket = None
def _wait_for_server_state(
server: GMSRPCServer,
expected: ServerState,
timeout: float = 2.0,
) -> None:
deadline = time.monotonic() + timeout
while server.state != expected:
if time.monotonic() > deadline:
raise TimeoutError(f"server did not reach {expected.name}")
time.sleep(0.01)
def _wait_for_waiting_writers(
server: GMSRPCServer,
expected: int,
timeout: float = 2.0,
) -> None:
deadline = time.monotonic() + timeout
while server._gms._sessions.snapshot().waiting_writers != expected:
if time.monotonic() > deadline:
raise TimeoutError(f"waiting writers did not reach {expected}")
time.sleep(0.01)
def _wait_for_ro_session_count(
server: GMSRPCServer,
expected: int,
timeout: float = 2.0,
) -> None:
deadline = time.monotonic() + timeout
while server._gms._sessions.snapshot().ro_session_count != expected:
if time.monotonic() > deadline:
raise TimeoutError(f"RO session count did not reach {expected}")
time.sleep(0.01)
@pytest.fixture
def real_gms(monkeypatch, tmp_path):
    """Run a ``GMSRPCServer`` on a temp socket with every CUDA helper stubbed out.

    Server-side and client-side CUDA entry points are replaced with stubs
    that hand out counter-based handles and virtual addresses, so the RPC
    and session logic runs without calling the real CUDA helpers.

    Yields:
        ``(server, socket_path)``; the server thread is stopped on teardown.
    """
    server_handle_ids = itertools.count(1000)
    client_handle_ids = itertools.count(10000)
    va_source = itertools.count(0x100000, 0x10000)

    def export_fd(handle: int) -> int:
        # Return the read end of a fresh pipe as the exported FD.
        read_end, write_end = os.pipe()
        os.close(write_end)
        return read_end

    def import_fd(fd: int) -> int:
        # Close the received FD and return the next fake client handle.
        os.close(fd)
        return next(client_handle_ids)

    server_stubs = {
        "cuda_ensure_initialized": lambda: None,
        "cumem_get_allocation_granularity": lambda device: 4096,
        "cumem_create_tolerate_oom": lambda size, device: (
            True,
            next(server_handle_ids),
        ),
        "cumem_release": lambda handle: None,
        "cumem_export_to_shareable_handle": export_fd,
    }
    for attr_name, stub in server_stubs.items():
        monkeypatch.setattr(server_allocations, attr_name, stub)

    client_stubs = {
        "cuda_set_current_device": lambda device: None,
        "cumem_get_allocation_granularity": lambda device: 4096,
        "cuda_synchronize": lambda: None,
        "cumem_address_reserve": lambda size, granularity: next(va_source),
        "cumem_address_free": lambda va, size: None,
        "cumem_map": lambda va, size, handle: None,
        "cumem_set_access": lambda va, size, device, mode: None,
        "cumem_unmap": lambda va, size: None,
        "cumem_release": lambda handle: None,
        "cuda_validate_pointer": lambda va: True,
        "cumem_import_from_shareable_handle_close_fd": import_fd,
    }
    for attr_name, stub in client_stubs.items():
        monkeypatch.setattr(client_memory_manager, attr_name, stub)

    socket_path = str(tmp_path / "gms.sock")
    server = GMSRPCServer(socket_path, device=0, allocation_retry_interval=0.01)
    server_thread = ServerThread(server, socket_path)
    server_thread.start()
    try:
        yield server, socket_path
    finally:
        server_thread.stop()
def test_rw_commit_publishes_allocations_metadata_and_layout_hash(real_gms):
    """A committed RW layout is visible to a later RO session."""
    server, socket_path = real_gms

    rw_manager = GMSClientMemoryManager(socket_path, device=0)
    rw_manager.connect(RequestedLockType.RW)
    mapped_va = rw_manager.create_mapping(size=4096, tag="weights")
    allocation_id = rw_manager.mappings[mapped_va].allocation_id
    rw_manager.metadata_put("tensor.0", allocation_id, 0, b"weights")
    commit_ok = rw_manager.commit()
    assert commit_ok

    ro_session = _GMSClientSession(socket_path, RequestedLockType.RO, None)
    try:
        assert ro_session.lock_type == GrantedLockType.RO
        assert ro_session.committed
        assert len(ro_session.list_allocations()) == 1
        assert ro_session.metadata_get("tensor.0") == (allocation_id, 0, b"weights")
        assert ro_session.get_memory_layout_hash()
    finally:
        ro_session.close()

    assert rw_manager.is_unmapped
    assert not rw_manager.is_connected
    _wait_for_server_state(server, ServerState.COMMITTED)
def test_rw_disconnect_aborts_layout_and_next_writer_starts_clean(real_gms):
    """Dropping an uncommitted writer returns the server to EMPTY with no leftovers."""
    server, socket_path = real_gms

    dropped_writer = _GMSClientSession(socket_path, RequestedLockType.RW, None)
    allocation_id, _ = dropped_writer.allocate(4096, "weights")
    dropped_writer.metadata_put("stale", allocation_id, 0, b"value")
    _drop_connection(dropped_writer)
    _wait_for_server_state(server, ServerState.EMPTY)

    fresh_writer = _GMSClientSession(socket_path, RequestedLockType.RW, None)
    try:
        assert fresh_writer.list_allocations() == []
        assert fresh_writer.metadata_list() == []
    finally:
        fresh_writer.close()
def test_rw_or_ro_grants_rw_from_empty_and_ro_from_committed(real_gms):
    """RW_OR_RO is granted RW on an empty server and RO once a layout is committed."""
    server, socket_path = real_gms

    first_session = _GMSClientSession(socket_path, RequestedLockType.RW_OR_RO, 100)
    assert first_session.lock_type == GrantedLockType.RW
    first_session.commit()
    _wait_for_server_state(server, ServerState.COMMITTED)

    second_session = _GMSClientSession(socket_path, RequestedLockType.RW_OR_RO, 100)
    try:
        assert second_session.lock_type == GrantedLockType.RO
        assert second_session.committed
    finally:
        second_session.close()
def test_runtime_state_and_event_history_are_side_effect_free(real_gms):
    """Runtime-state and event-history requests leave the RO session count at zero."""
    server, socket_path = real_gms

    rw_manager = GMSClientMemoryManager(socket_path, device=0)
    rw_manager.connect(RequestedLockType.RW)
    rw_manager.create_mapping(size=4096, tag="weights")
    commit_ok = rw_manager.commit()
    assert commit_ok
    assert server._gms._sessions.snapshot().ro_session_count == 0

    # Each query uses its own bare transport connection.
    with _GMSRPCTransport(socket_path) as state_transport:
        state_transport.connect()
        state = state_transport.request(
            GetRuntimeStateRequest(),
            GetRuntimeStateResponse,
        )
    with _GMSRPCTransport(socket_path) as history_transport:
        history_transport.connect()
        history = history_transport.request(
            GetEventHistoryRequest(),
            GetEventHistoryResponse,
        )

    assert state.state == ServerState.COMMITTED.name
    assert state.committed
    assert state.is_ready
    assert state.ro_session_count == 0
    assert state.waiting_writers == 0
    assert state.allocation_count == 1
    assert state.memory_layout_hash

    event_kinds = [event.kind for event in history.events]
    assert event_kinds == ["rw_connected", "committed"]
    assert server._gms._sessions.snapshot().ro_session_count == 0
def test_committed_layout_is_replaced_when_new_writer_connects(real_gms):
    """A new RW session after a commit starts from an empty, uncommitted layout."""
    server, socket_path = real_gms

    committing_manager = GMSClientMemoryManager(socket_path, device=0)
    committing_manager.connect(RequestedLockType.RW)
    committing_manager.create_mapping(size=4096, tag="weights")
    commit_ok = committing_manager.commit()
    assert commit_ok
    _wait_for_server_state(server, ServerState.COMMITTED)
    assert server._gms.allocation_count == 1

    replacing_session = _GMSClientSession(socket_path, RequestedLockType.RW, None)
    try:
        assert replacing_session.lock_type == GrantedLockType.RW
        assert replacing_session.list_allocations() == []
        assert replacing_session.metadata_list() == []
        assert server._gms.allocation_count == 0
        assert server.state == ServerState.RW
        assert not server._gms.committed
    finally:
        replacing_session.close()
def test_reader_mapping_disconnect_then_next_writer_clears_old_layout(real_gms):
    """A writer queued behind a mapped reader gets a clean layout once the reader leaves."""
    server, socket_path = real_gms

    rw_manager = GMSClientMemoryManager(socket_path, device=0)
    rw_manager.connect(RequestedLockType.RW)
    source_va = rw_manager.create_mapping(size=4096, tag="weights")
    allocation_id = rw_manager.mappings[source_va].allocation_id
    commit_ok = rw_manager.commit()
    assert commit_ok

    reader = GMSClientMemoryManager(socket_path, device=0)
    reader.connect(RequestedLockType.RO)
    imported_va = reader.create_mapping(allocation_id=allocation_id)
    assert reader.mappings[imported_va].handle != 0

    outcome: dict[str, object] = {}

    def open_writer() -> None:
        # Runs on a worker thread; blocks until the reader releases its lock.
        try:
            outcome["session"] = _GMSClientSession(
                socket_path,
                RequestedLockType.RW,
                500,
            )
        except Exception as exc:
            outcome["error"] = exc

    worker = threading.Thread(target=open_writer)
    worker.start()
    _wait_for_waiting_writers(server, 1)

    # While the reader holds RO, the writer is still waiting and the layout is intact.
    assert worker.is_alive()
    assert server.state == ServerState.RO
    assert server._gms.allocation_count == 1

    reader.unmap_all_vas()
    reader.abort()
    worker.join(timeout=2)

    granted_session = outcome.get("session")
    assert isinstance(granted_session, _GMSClientSession)
    try:
        assert granted_session.lock_type == GrantedLockType.RW
        assert granted_session.list_allocations() == []
        assert server._gms.allocation_count == 0
        assert server.state == ServerState.RW
        assert not server._gms.committed
    finally:
        granted_session.close()
def test_waiting_writer_blocks_new_readers_until_last_reader_disconnects(real_gms):
    """With a writer queued, a new RO request times out; the writer wins when the reader leaves."""
    server, socket_path = real_gms

    initial_writer = _GMSClientSession(socket_path, RequestedLockType.RW, None)
    initial_writer.commit()
    reader = _GMSClientSession(socket_path, RequestedLockType.RO, None)

    outcome: dict[str, object] = {}

    def open_writer() -> None:
        # Runs on a worker thread; blocks behind the active reader.
        try:
            outcome["session"] = _GMSClientSession(
                socket_path,
                RequestedLockType.RW,
                500,
            )
        except Exception as exc:
            outcome["error"] = exc

    worker = threading.Thread(target=open_writer)
    worker.start()
    _wait_for_waiting_writers(server, 1)

    with pytest.raises(TimeoutError, match="Timeout waiting for lock"):
        _GMSClientSession(socket_path, RequestedLockType.RO, 100)

    reader.close()
    worker.join(timeout=2)

    granted_session = outcome.get("session")
    assert isinstance(granted_session, _GMSClientSession)
    try:
        assert granted_session.lock_type == GrantedLockType.RW
    finally:
        granted_session.close()
def test_rw_or_ro_times_out_while_writer_waits_behind_reader(real_gms):
    """With a writer queued behind a reader, an RW_OR_RO request times out."""
    server, socket_path = real_gms

    initial_writer = _GMSClientSession(socket_path, RequestedLockType.RW, None)
    initial_writer.commit()
    reader = _GMSClientSession(socket_path, RequestedLockType.RO, None)

    outcome: dict[str, object] = {}

    def block_writer() -> None:
        # Runs on a worker thread; blocks behind the active reader.
        try:
            outcome["session"] = _GMSClientSession(
                socket_path,
                RequestedLockType.RW,
                500,
            )
        except Exception as exc:
            outcome["error"] = exc

    worker = threading.Thread(target=block_writer)
    worker.start()
    _wait_for_waiting_writers(server, 1)

    with pytest.raises(TimeoutError, match="Timeout waiting for lock"):
        _GMSClientSession(socket_path, RequestedLockType.RW_OR_RO, 100)

    reader.close()
    worker.join(timeout=2)

    granted_session = outcome.get("session")
    assert isinstance(granted_session, _GMSClientSession)
    granted_session.close()
def test_reader_can_acquire_after_waiting_writer_times_out(real_gms):
    """Once a queued writer times out, a new RO session can be granted again.

    The first reader is released in a ``finally`` block so a failing
    assertion or wait does not leave its session open.
    """
    server, socket_path = real_gms
    writer = _GMSClientSession(socket_path, RequestedLockType.RW, None)
    writer.commit()
    reader = _GMSClientSession(socket_path, RequestedLockType.RO, None)
    try:
        writer_result: dict[str, BaseException | None] = {"error": None}

        def timeout_writer() -> None:
            # Runs on a worker thread; expected to time out behind the reader.
            try:
                _GMSClientSession(socket_path, RequestedLockType.RW, 100)
            except BaseException as exc:
                writer_result["error"] = exc

        thread = threading.Thread(target=timeout_writer)
        thread.start()
        _wait_for_waiting_writers(server, 1)
        thread.join(timeout=2)
        assert isinstance(writer_result["error"], TimeoutError)

        # The timed-out writer must no longer be counted as waiting.
        _wait_for_waiting_writers(server, 0)
        second_reader = _GMSClientSession(socket_path, RequestedLockType.RO, 200)
        try:
            assert second_reader.lock_type == GrantedLockType.RO
        finally:
            second_reader.close()
    finally:
        reader.close()
def test_multiple_readers_hold_committed_state_until_last_disconnect(real_gms):
    """The server stays in RO until the last of two readers closes, then reports COMMITTED."""
    server, socket_path = real_gms

    initial_writer = _GMSClientSession(socket_path, RequestedLockType.RW, None)
    initial_writer.commit()
    first_reader = _GMSClientSession(socket_path, RequestedLockType.RO, None)
    second_reader = _GMSClientSession(socket_path, RequestedLockType.RO, None)

    _wait_for_server_state(server, ServerState.RO)
    assert server._gms._sessions.snapshot().ro_session_count == 2

    first_reader.close()
    _wait_for_ro_session_count(server, 1)
    assert server.state == ServerState.RO

    second_reader.close()
    _wait_for_server_state(server, ServerState.COMMITTED)
def test_ro_session_rejects_rw_only_requests(real_gms):
    """allocate() and commit() on an RO session both raise RuntimeError."""
    _, socket_path = real_gms

    initial_writer = _GMSClientSession(socket_path, RequestedLockType.RW, None)
    initial_writer.commit()

    ro_session = _GMSClientSession(socket_path, RequestedLockType.RO, None)
    try:
        with pytest.raises(RuntimeError, match="not allowed for RO session"):
            ro_session.allocate(4096, "weights")
        with pytest.raises(RuntimeError, match="not allowed for RO session"):
            ro_session.commit()
    finally:
        ro_session.close()
def test_lock_and_allocation_state_requests_reflect_real_server_state(real_gms):
    """Lock and allocation state queries match the server both under RW and under RO."""
    _, socket_path = real_gms

    rw_session = _GMSClientSession(socket_path, RequestedLockType.RW, None)
    allocation_id, _ = rw_session.allocate(4096, "weights")
    rw_lock_state = rw_session.get_lock_state()
    rw_allocation_state = rw_session.get_allocation_state()
    assert rw_lock_state.state == ServerState.RW.name
    assert rw_lock_state.has_rw_session
    assert rw_lock_state.ro_session_count == 0
    assert rw_allocation_state.allocation_count == 1

    rw_session.metadata_put("tensor.0", allocation_id, 0, b"x")
    rw_session.commit()

    ro_session = _GMSClientSession(socket_path, RequestedLockType.RO, None)
    try:
        ro_lock_state = ro_session.get_lock_state()
        ro_allocation_state = ro_session.get_allocation_state()
        assert ro_lock_state.state == ServerState.RO.name
        assert not ro_lock_state.has_rw_session
        assert ro_lock_state.ro_session_count == 1
        assert ro_allocation_state.allocation_count == 1
    finally:
        ro_session.close()
def test_invalid_metadata_offset_is_rejected_without_mutating_state(real_gms):
    """metadata_put() at offset 4096 of a 4096-byte allocation raises and stores nothing."""
    _, socket_path = real_gms

    rw_session = _GMSClientSession(socket_path, RequestedLockType.RW, None)
    try:
        allocation_id, _ = rw_session.allocate(4096, "weights")
        with pytest.raises(RuntimeError, match="out of range"):
            rw_session.metadata_put("tensor.bad", allocation_id, 4096, b"x")
        assert rw_session.metadata_list() == []
    finally:
        rw_session.close()
def test_destroy_mapping_frees_allocation_and_metadata(real_gms):
    """destroy_mapping() leaves no handles and no metadata behind."""
    _, socket_path = real_gms

    rw_manager = GMSClientMemoryManager(socket_path, device=0)
    rw_manager.connect(RequestedLockType.RW)
    mapped_va = rw_manager.create_mapping(size=4096, tag="weights")
    allocation_id = rw_manager.mappings[mapped_va].allocation_id
    rw_manager.metadata_put("tensor.0", allocation_id, 0, b"payload")

    rw_manager.destroy_mapping(mapped_va)

    assert rw_manager.list_handles() == []
    assert rw_manager.metadata_list() == []
    rw_manager.abort()
def test_remap_all_vas_succeeds_when_committed_layout_is_unchanged(real_gms):
    """A reader that unmaps, aborts and reconnects can remap the same allocation."""
    _, socket_path = real_gms

    rw_manager = GMSClientMemoryManager(socket_path, device=0)
    rw_manager.connect(RequestedLockType.RW)
    source_va = rw_manager.create_mapping(size=4096, tag="weights")
    allocation_id = rw_manager.mappings[source_va].allocation_id
    commit_ok = rw_manager.commit()
    assert commit_ok

    reader = GMSClientMemoryManager(socket_path, device=0)
    reader.connect(RequestedLockType.RO)
    imported_va = reader.create_mapping(allocation_id=allocation_id)
    original_mapping = reader.mappings[imported_va]
    reader.unmap_all_vas()
    reader.abort()

    # Reconnect against the unchanged committed layout and restore the mapping.
    reader.connect(RequestedLockType.RO)
    reader.remap_all_vas()

    remapped = reader.mappings[imported_va]
    assert remapped.handle != 0
    assert remapped.allocation_id == original_mapping.allocation_id
    reader.close()
def test_remap_all_vas_rejects_stale_layout_after_new_layout_commit(real_gms):
    """Remapping raises StaleMemoryLayoutError once a different layout was committed."""
    _server, sock_path = real_gms

    # First layout: a single size=4096 mapping.
    first_writer = GMSClientMemoryManager(sock_path, device=0)
    first_writer.connect(RequestedLockType.RW)
    first_va = first_writer.create_mapping(size=4096, tag="weights")
    first_id = first_writer.mappings[first_va].allocation_id
    assert first_writer.commit()

    # The reader imports the first layout and then disconnects.
    reader = GMSClientMemoryManager(sock_path, device=0)
    reader.connect(RequestedLockType.RO)
    reader.create_mapping(allocation_id=first_id)
    reader.unmap_all_vas()
    reader.abort()

    # Second layout: the mapping size changes to 8192.
    second_writer = GMSClientMemoryManager(sock_path, device=0)
    second_writer.connect(RequestedLockType.RW)
    second_writer.create_mapping(size=8192, tag="weights")
    assert second_writer.commit()

    # The reader's preserved state no longer matches the committed layout.
    reader.connect(RequestedLockType.RO)
    with pytest.raises(StaleMemoryLayoutError, match="Layout changed"):
        reader.remap_all_vas()
    reader.abort()
def test_remap_all_vas_accepts_new_layout_with_same_structural_layout(real_gms):
    """A republished layout with the same size/tag/metadata remaps onto the old VA."""
    _server, sock_path = real_gms

    # First publish.
    publisher_one = GMSClientMemoryManager(sock_path, device=0)
    publisher_one.connect(RequestedLockType.RW)
    va_one = publisher_one.create_mapping(size=4096, tag="weights")
    id_one = publisher_one.mappings[va_one].allocation_id
    publisher_one.metadata_put("tensor.0", id_one, 0, b"shape")
    assert publisher_one.commit()

    # The reader imports the first publish and then disconnects.
    reader = GMSClientMemoryManager(sock_path, device=0)
    reader.connect(RequestedLockType.RO)
    imported_va = reader.create_mapping(allocation_id=id_one)
    reader.unmap_all_vas()
    reader.abort()

    # Second publish: same size, tag and metadata, but a fresh allocation id.
    publisher_two = GMSClientMemoryManager(sock_path, device=0)
    publisher_two.connect(RequestedLockType.RW)
    va_two = publisher_two.create_mapping(size=4096, tag="weights")
    id_two = publisher_two.mappings[va_two].allocation_id
    assert id_two != id_one
    publisher_two.metadata_put("tensor.0", id_two, 0, b"shape")
    assert publisher_two.commit()

    # The reader keeps its VA but now points at the second allocation.
    reader.connect(RequestedLockType.RO)
    reader.remap_all_vas()
    assert reader.mappings[imported_va].va == imported_va
    assert reader.mappings[imported_va].allocation_id == id_two
    assert reader.metadata_get("tensor.0") == (id_two, 0, b"shape")
    reader.close()
def test_reallocate_all_handles_reuses_preserved_vas_in_new_layout(real_gms):
    """reallocate_all_handles gives the preserved VA a new allocation id, unmapped until remap."""
    gms_server, sock_path = real_gms
    client = GMSClientMemoryManager(sock_path, device=0)

    # Publish once and wait for the server to report the committed state.
    client.connect(RequestedLockType.RW)
    preserved_va = client.create_mapping(size=4096, tag="weights")
    previous_id = client.mappings[preserved_va].allocation_id
    assert client.commit()
    _wait_for_server_state(gms_server, ServerState.COMMITTED)

    # Reconnect as writer and reallocate backing handles for the preserved VAs.
    client.connect(RequestedLockType.RW)
    client.reallocate_all_handles(tag="weights")
    assert client.mappings[preserved_va].allocation_id != previous_id
    assert client.mappings[preserved_va].handle == 0

    # Remapping keeps the VA and produces a non-zero handle.
    client.remap_all_vas()
    assert client.mappings[preserved_va].va == preserved_va
    assert client.mappings[preserved_va].handle != 0

    client.close()
    _wait_for_server_state(gms_server, ServerState.EMPTY)
def test_same_process_republish_remaps_against_new_committed_hash(real_gms):
    """One manager publishes, reads, republishes and reads again on the same VA."""
    _server, sock_path = real_gms
    client = GMSClientMemoryManager(sock_path, device=0)

    # Publish #1.
    client.connect(RequestedLockType.RW)
    shared_va = client.create_mapping(size=4096, tag="weights")
    id_one = client.mappings[shared_va].allocation_id
    client.metadata_put("tensor", id_one, 0, b"publish-1")
    assert client.commit()

    # Read publish #1 back, then release mappings and session.
    client.connect(RequestedLockType.RO)
    client.remap_all_vas()
    client.unmap_all_vas()
    client.abort()

    # Publish #2 on the preserved VA with a fresh allocation.
    client.connect(RequestedLockType.RW)
    client.reallocate_all_handles(tag="weights")
    id_two = client.mappings[shared_va].allocation_id
    assert id_two != id_one
    client.remap_all_vas()
    client.metadata_put("tensor", id_two, 0, b"publish-2")
    assert client.commit()

    # A new reader session must see publish #2.
    client.connect(RequestedLockType.RO)
    client.remap_all_vas()
    assert client.mappings[shared_va].va == shared_va
    assert client.mappings[shared_va].allocation_id == id_two
    assert client.metadata_get("tensor") == (id_two, 0, b"publish-2")
    client.close()
def test_disconnect_during_allocation_retry_aborts_writer_and_unblocks_next_writer(
    real_gms,
    monkeypatch,
):
    """Dropping a writer stuck in allocation retry frees the RW lock for the next writer.

    The server-side ``cumem_create_tolerate_oom`` is replaced with a stub that
    reports OOM until ``allow_allocation`` is flipped. A writer thread blocks in
    ``allocate`` while the stub reports OOM, its connection is dropped, and a
    second writer must then be able to acquire RW and allocate.
    """
    server, socket_path = real_gms
    oom_attempts = 0
    allow_allocation = False

    def oom_until_allowed(size: int, device: int) -> tuple[bool, int]:
        # Reports OOM (and counts the attempt) until the test enables allocation.
        nonlocal oom_attempts
        if allow_allocation:
            return True, 4242
        oom_attempts += 1
        return False, 0

    monkeypatch.setattr(
        "gpu_memory_service.server.allocations.cumem_create_tolerate_oom",
        oom_until_allowed,
    )
    writer = _GMSClientSession(socket_path, RequestedLockType.RW, None)
    result: dict[str, BaseException] = {}

    def allocate() -> None:
        try:
            writer.allocate(4096, "weights")
        except BaseException as exc:
            # Captured (not re-raised) so the main thread can assert on it.
            result["error"] = exc

    # daemon=True: a writer that never unblocks must not hang interpreter exit.
    thread = threading.Thread(target=allocate, daemon=True)
    thread.start()

    # Wait until the server has hit the stubbed OOM at least once.
    deadline = time.monotonic() + 2.0
    while oom_attempts == 0:
        if time.monotonic() > deadline:
            raise TimeoutError("allocation retry never reached CUDA OOM")
        time.sleep(0.01)

    _drop_connection(writer)
    thread.join(timeout=2)
    _wait_for_server_state(server, ServerState.EMPTY)

    allow_allocation = True
    next_writer = _GMSClientSession(socket_path, RequestedLockType.RW, 200)
    try:
        assert next_writer.lock_type == GrantedLockType.RW
        allocation_id, aligned_size = next_writer.allocate(4096, "weights")
        assert allocation_id
        assert aligned_size == 4096
    finally:
        next_writer.close()

    # Make a stuck writer thread an explicit failure rather than an opaque
    # "isinstance(None, ConnectionError)" assertion.
    assert (
        not thread.is_alive()
    ), "writer thread is still blocked in allocate() after its connection was dropped"
    assert isinstance(result.get("error"), ConnectionError)
@pytest.mark.asyncio
@pytest.mark.timeout(180)
async def test_new_layout_large_allocation_waits_for_dead_writer_process(
    tmp_path,
    monkeypatch,
):
    """A large allocation keeps retrying until a process holding the old export dies.

    Flow: allocate ~90% of free GPU memory, export it to a child process that
    just holds the fd, clear the manager, then start a second allocation of the
    same size. The second allocation must observe at least one OOM while the
    holder is alive and complete after the holder is killed.
    """
    free_before = _gpu_memory_free_bytes()
    size = int(free_before * 0.9)
    assert size > 0
    oom_failures = 0

    def count_oom(size: int, device: int) -> tuple[bool, int]:
        # Delegates to the real implementation and counts OOM outcomes.
        nonlocal oom_failures
        allocated, handle = cuda_utils.cumem_create_tolerate_oom(size, device)
        if not allocated:
            oom_failures += 1
        return allocated, handle

    monkeypatch.setattr(
        "gpu_memory_service.server.allocations.cumem_create_tolerate_oom",
        count_oom,
    )
    allocations = GMSAllocationManager(
        device=0,
        allocation_retry_interval=0.1,
        allocation_retry_timeout=120.0,
    )
    holder = None
    allocation_task = None
    try:
        first = await allocations.allocate(
            size=size,
            tag="weights",
            is_connected=lambda: True,
        )
        assert first.layout_slot == 0
        free_after_first = _gpu_memory_free_bytes()
        assert free_after_first < free_before - (size // 2)

        exported_fd = allocations.export_allocation(first.allocation_id)
        try:
            holder_ready = tmp_path / "holder.ready"
            holder_log = tmp_path / "holder.log"
            holder_script = tmp_path / "hold_import.py"
            # The child only inherits the fd, signals readiness via a file,
            # and then sleeps forever until it is killed.
            holder_script.write_text(
                textwrap.dedent(
                    """
                    import sys
                    import time
                    from pathlib import Path
                    fd = int(sys.argv[1])
                    Path(sys.argv[2]).write_text(str(fd))
                    while True:
                        time.sleep(1.0)
                    """
                ),
                encoding="utf-8",
            )
            with holder_log.open("w", encoding="utf-8") as log_file:
                holder = subprocess.Popen(
                    [
                        sys.executable,
                        str(holder_script),
                        str(exported_fd),
                        str(holder_ready),
                    ],
                    pass_fds=[exported_fd],
                    stdout=log_file,
                    stderr=subprocess.STDOUT,
                    # Own session/process group so killpg only hits the holder.
                    start_new_session=True,
                )
        finally:
            # Always drop this process's copy of the fd, even if writing the
            # script or spawning the holder fails; the child has its own copy.
            os.close(exported_fd)

        # Wait for the holder to report that it is up.
        deadline = time.monotonic() + 30.0
        while not holder_ready.exists():
            assert holder.poll() is None, holder_log.read_text(encoding="utf-8")
            assert time.monotonic() < deadline, holder_log.read_text(encoding="utf-8")
            await asyncio.sleep(0.1)

        allocations.clear_all()
        assert allocations.allocation_count == 0

        allocation_task = asyncio.create_task(
            allocations.allocate(
                size=size,
                tag="weights",
                is_connected=lambda: True,
            )
        )

        # The second allocation must hit OOM while the holder is still alive.
        deadline = time.monotonic() + 30.0
        while oom_failures == 0:
            assert holder.poll() is None, holder_log.read_text(encoding="utf-8")
            assert not allocation_task.done()
            assert time.monotonic() < deadline
            await asyncio.sleep(0.1)
        assert oom_failures > 0
        assert not allocation_task.done()

        # Kill the holder; the pending allocation should now complete.
        os.killpg(os.getpgid(holder.pid), signal.SIGKILL)
        holder.wait(timeout=30.0)
        second = await asyncio.wait_for(allocation_task, timeout=120.0)
        assert second.layout_slot == 0
        assert allocations.allocation_count == 1

        allocations.clear_all()
        assert allocations.allocation_count == 0

        # Wait for free memory to come back to within 1 GiB of the start value.
        deadline = time.monotonic() + 30.0
        while _gpu_memory_free_bytes() < free_before - (1 << 30):
            assert time.monotonic() < deadline
            await asyncio.sleep(0.1)
    finally:
        if allocation_task is not None and not allocation_task.done():
            allocation_task.cancel()
            try:
                await allocation_task
            except asyncio.CancelledError:
                pass
        if allocations.allocation_count > 0:
            allocations.clear_all()
        if holder is not None and holder.poll() is None:
            os.killpg(os.getpgid(holder.pid), signal.SIGKILL)
            holder.wait(timeout=30.0)
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Targeted GMS fault-tolerance unit tests."""
from __future__ import annotations
import asyncio
import os
import socket
import subprocess
import sys
import time
from dataclasses import dataclass
import pytest
from gpu_memory_service.common import cuda_utils
from gpu_memory_service.common.protocol.messages import (
CommitRequest,
CommitResponse,
GetEventHistoryRequest,
GetLockStateRequest,
GetLockStateResponse,
GetRuntimeStateRequest,
HandshakeRequest,
)
from gpu_memory_service.common.types import (
GrantedLockType,
RequestedLockType,
ServerState,
StateEvent,
)
from gpu_memory_service.server.allocations import GMSAllocationManager
from gpu_memory_service.server.gms import GMS
from gpu_memory_service.server.rpc import GMSRPCServer, _is_connection_alive
from gpu_memory_service.server.session import (
Connection,
GMSSessionManager,
OperationNotAllowed,
)
# Skip entire module if cuda.bindings is not installed
pytest.importorskip("cuda.bindings", reason="cuda.bindings is required")
from cuda.bindings import driver as cuda # noqa: E402
pytestmark = [
pytest.mark.pre_merge,
pytest.mark.unit,
pytest.mark.gpu_0,
]
def test_cumem_create_tolerate_oom_returns_handle_on_success(monkeypatch):
    """With cuMemCreate stubbed to succeed, the helper reports success and the handle."""

    def fake_cu_mem_create(size, prop, flags):
        return (cuda.CUresult.CUDA_SUCCESS, 1234)

    monkeypatch.setattr(cuda_utils.cuda, "cuMemCreate", fake_cu_mem_create)

    outcome = cuda_utils.cumem_create_tolerate_oom(4096, 0)
    was_allocated, returned_handle = outcome
    assert was_allocated
    assert returned_handle == 1234
def test_cumem_create_tolerate_oom_returns_false_on_oom(monkeypatch):
    """With cuMemCreate stubbed to return OUT_OF_MEMORY, the helper reports failure."""

    def fake_cu_mem_create(size, prop, flags):
        return (cuda.CUresult.CUDA_ERROR_OUT_OF_MEMORY, 0)

    monkeypatch.setattr(cuda_utils.cuda, "cuMemCreate", fake_cu_mem_create)

    outcome = cuda_utils.cumem_create_tolerate_oom(4096, 0)
    was_allocated, returned_handle = outcome
    assert not was_allocated
    assert returned_handle == 0
def test_cumem_export_to_shareable_handle_returns_fd(monkeypatch):
    """The export helper returns the value produced by the stubbed driver call."""

    def fake_export(handle, handle_type, flags):
        return (cuda.CUresult.CUDA_SUCCESS, 77)

    monkeypatch.setattr(cuda_utils.cuda, "cuMemExportToShareableHandle", fake_export)

    exported = cuda_utils.cumem_export_to_shareable_handle(1234)
    assert exported == 77
def test_non_oom_cuda_error_exits_process() -> None:
    """Running cuda_check_result with a non-OOM error in a child exits with code 1."""
    # Executed in a separate interpreter so the exit does not take down pytest.
    script = """
from cuda.bindings import driver as cuda
from gpu_memory_service.common import cuda_utils
cuda_utils.cuda_check_result(cuda.CUresult.CUDA_ERROR_INVALID_VALUE, "synthetic")
"""
    command = [sys.executable, "-c", script]
    completed = subprocess.run(command, check=False)
    assert completed.returncode == 1
class _DummyReader:
    """Reader stand-in that never reports end-of-file and never has an exception."""

    def exception(self):
        """Return None: no error is ever recorded."""
        return None

    def at_eof(self) -> bool:
        """Return False: the stream is never at EOF."""
        return False
class _DummyWriter:
    """Writer stand-in that only tracks whether close() was called.

    An optional socket can be supplied; it is returned for any
    ``get_extra_info`` lookup regardless of the requested name.
    """

    def __init__(self, sock: socket.socket | None = None) -> None:
        self._socket = sock
        self.closed = False

    def is_closing(self) -> bool:
        """Report the closed flag."""
        return self.closed

    def close(self) -> None:
        """Mark the writer as closed; nothing is actually released."""
        self.closed = True

    async def wait_closed(self) -> None:
        """Complete immediately."""
        return None

    def get_extra_info(self, _name: str):
        """Return the socket given at construction (or None)."""
        return self._socket
def test_is_connection_alive_detects_dead_peer() -> None:
    """_is_connection_alive is true on a live socketpair and turns false after the peer closes."""
    near_end, far_end = socket.socketpair()
    near_end.setblocking(False)
    try:
        connection = Connection(
            reader=_DummyReader(),
            writer=_DummyWriter(near_end),
            mode=GrantedLockType.RW,
            session_id="session_1",
            recv_buffer=bytearray(),
        )
        assert _is_connection_alive(connection)

        far_end.close()

        # Poll for up to one second for the disconnect to be noticed.
        give_up_at = time.monotonic() + 1.0
        while _is_connection_alive(connection):
            if time.monotonic() > give_up_at:
                raise TimeoutError("peer disconnect was not detected")
            time.sleep(0.01)
    finally:
        # Closing an already-closed socket is harmless.
        far_end.close()
        near_end.close()
def _make_connection(
    mode: GrantedLockType,
    session_id: str,
) -> tuple[Connection, _DummyWriter]:
    """Build a Connection backed by dummy streams and return it with its writer.

    The writer is returned separately so tests can check whether it was closed.
    """
    dummy_writer = _DummyWriter()
    connection = Connection(
        reader=_DummyReader(),
        writer=dummy_writer,
        mode=mode,
        session_id=session_id,
        recv_buffer=bytearray(),
    )
    return connection, dummy_writer
def _make_allocation_manager() -> GMSAllocationManager:
    """Create a GMSAllocationManager without running its __init__.

    The private fields are filled in by hand: device 0, no allocations,
    granularity 1, zero retry interval and no retry timeout.
    """
    manager = object.__new__(GMSAllocationManager)
    initial_state = (
        ("_device", 0),
        ("_allocations", {}),
        ("_next_layout_slot", 0),
        ("_granularity", 1),
        ("_allocation_retry_interval", 0.0),
        ("_allocation_retry_timeout", None),
    )
    for attribute, value in initial_state:
        setattr(manager, attribute, value)
    return manager
@dataclass
class _FakeHandler:
    """Call-counting handler used by the fake GMS in these tests."""

    has_committed_layout: bool = False
    rw_connect_calls: int = 0
    rw_abort_calls: int = 0
    commit_calls: int = 0

    def on_rw_connect(self) -> None:
        """Count an RW connect and clear the committed-layout flag."""
        self.has_committed_layout = False
        self.rw_connect_calls += 1

    def on_rw_abort(self) -> None:
        """Count an RW abort."""
        self.rw_abort_calls += 1

    def on_commit(self) -> None:
        """Count a commit and set the committed-layout flag."""
        self.has_committed_layout = True
        self.commit_calls += 1

    def handle_get_lock_state(
        self,
        has_rw: bool,
        ro_count: int,
        waiting_writers: int,
        committed: bool,
    ) -> GetLockStateResponse:
        """Build a GetLockStateResponse from the given lock snapshot values.

        The state string is chosen in priority order: "RW", "RO", "COMMITTED",
        then "EMPTY".
        """
        if has_rw:
            state_name = "RW"
        elif ro_count:
            state_name = "RO"
        elif committed:
            state_name = "COMMITTED"
        else:
            state_name = "EMPTY"
        return GetLockStateResponse(
            state=state_name,
            has_rw_session=has_rw,
            ro_session_count=ro_count,
            waiting_writers=waiting_writers,
            committed=committed,
            is_ready=committed and not has_rw,
        )
class _FakeGMS:
    """Minimal GMS stand-in: a real GMSSessionManager plus a counting handler."""

    def __init__(self, handler: _FakeHandler | None = None):
        self.handler = handler or _FakeHandler()
        self._sessions = GMSSessionManager()
        # Mirror a pre-committed handler into the session manager's lock state.
        if self.handler.has_committed_layout:
            self._sessions._locking._committed = True

    @property
    def committed(self) -> bool:
        """Committed flag taken from a fresh session snapshot."""
        return self._sessions.snapshot().committed

    async def acquire_lock(self, mode, timeout_ms, session_id):
        """Forward to the session manager."""
        return await self._sessions.acquire_lock(mode, timeout_ms, session_id)

    async def cancel_connect(self, session_id, mode):
        """Forward to the session manager."""
        await self._sessions.cancel_connect(session_id, mode)

    def on_connect(self, conn: Connection) -> None:
        """Notify the handler for RW connections, then register the connection."""
        if conn.mode == GrantedLockType.RW:
            self.handler.on_rw_connect()
        self._sessions.on_connect(conn)

    async def cleanup_connection(self, conn: Connection | None) -> None:
        """Run the two-phase cleanup, notifying the handler on an RW abort."""
        cleanup_event = self._sessions.begin_cleanup(conn)
        if cleanup_event == StateEvent.RW_ABORT:
            self.handler.on_rw_abort()
        await self._sessions.finish_cleanup(conn)

    async def handle_request(self, conn: Connection, msg, _is_connected):
        """Answer the two request types used by the tests.

        Returns a 3-tuple whose first element is the response; the last element
        is False for lock-state requests and True after a commit. Any other
        message type raises AssertionError.
        """
        if isinstance(msg, GetLockStateRequest):
            snap = self._sessions.snapshot()
            response = self.handler.handle_get_lock_state(
                snap.has_rw_session,
                snap.ro_session_count,
                snap.waiting_writers,
                snap.committed,
            )
            return response, -1, False
        if isinstance(msg, CommitRequest):
            self.handler.on_commit()
            self._sessions.on_commit(conn)
            return CommitResponse(success=True), -1, True
        raise AssertionError(f"Unexpected request type in test: {type(msg)}")
def _make_server(handler: _FakeHandler | None = None) -> GMSRPCServer:
    """Create a GMSRPCServer without running __init__, wired to a _FakeGMS."""
    rpc_server = object.__new__(GMSRPCServer)
    rpc_server.socket_path = "/tmp/gms-test.sock"
    rpc_server.device = 0
    rpc_server._gms = _FakeGMS(handler)
    # No listening server object is created for these tests.
    rpc_server._server = None
    return rpc_server
@pytest.mark.asyncio
async def test_handshake_success_send_failure_cleans_up_rw_state(monkeypatch):
    """If the handshake reply cannot be sent, the granted RW state is rolled back."""
    server = _make_server()
    stream_reader = _DummyReader()
    stream_writer = _DummyWriter()

    async def recv_rw_handshake(_reader, _buffer):
        return HandshakeRequest(lock_type=RequestedLockType.RW), -1, bytearray()

    async def grant_rw(_mode, _timeout_ms, _session_id):
        # Record the reservation directly on the session manager, then grant RW.
        server._gms._sessions._reserved_rw_session_id = _session_id
        return GrantedLockType.RW

    async def failing_send(_writer, _msg, _fd=-1):
        raise BrokenPipeError("handshake reply failed")

    monkeypatch.setattr("gpu_memory_service.server.rpc.recv_message", recv_rw_handshake)
    monkeypatch.setattr(server._gms, "acquire_lock", grant_rw)
    monkeypatch.setattr("gpu_memory_service.server.rpc.send_message", failing_send)

    handshake_result = await server._do_handshake(stream_reader, stream_writer, "session_1")

    assert handshake_result is None
    sessions = server._gms._sessions
    assert sessions._locking.rw_conn is None
    assert sessions.state == ServerState.EMPTY
    # The handler saw exactly one connect and one matching abort.
    assert server._gms.handler.rw_connect_calls == 1
    assert server._gms.handler.rw_abort_calls == 1
    assert stream_writer.closed
@pytest.mark.asyncio
async def test_rw_lock_is_reserved_until_connect():
    """A granted-but-unconnected RW lock blocks a second RW request."""
    manager = GMSSessionManager()

    first_grant = await manager.acquire_lock(
        RequestedLockType.RW,
        timeout_ms=50,
        session_id="session_1",
    )
    # session_1 never connects, so session_2 must time out.
    second_grant = await manager.acquire_lock(
        RequestedLockType.RW,
        timeout_ms=50,
        session_id="session_2",
    )

    assert first_grant == GrantedLockType.RW
    assert second_grant is None
    await manager.cancel_connect("session_1", GrantedLockType.RW)
@pytest.mark.asyncio
async def test_reader_cannot_acquire_while_rw_is_reserved_before_connect():
    """An RO request times out while an RW reservation is outstanding."""
    manager = GMSSessionManager()
    # Start from a committed state.
    manager._locking._committed = True

    writer_grant = await manager.acquire_lock(
        RequestedLockType.RW,
        timeout_ms=50,
        session_id="writer_1",
    )
    assert writer_grant == GrantedLockType.RW

    reader_grant = await manager.acquire_lock(
        RequestedLockType.RO,
        timeout_ms=50,
        session_id="reader_1",
    )
    assert reader_grant is None

    await manager.cancel_connect("writer_1", GrantedLockType.RW)
@pytest.mark.asyncio
async def test_reader_waiter_wakes_when_waiting_writer_times_out():
    """A reader queued behind a waiting writer is granted RO once that writer times out."""
    manager = GMSSessionManager()
    manager._locking._committed = True

    # An already-connected reader keeps the writer waiting.
    connected_reader = Connection(
        reader=_DummyReader(),
        writer=_DummyWriter(),
        mode=GrantedLockType.RO,
        session_id="reader_1",
        recv_buffer=bytearray(),
    )
    manager.on_connect(connected_reader)

    pending_writer = asyncio.create_task(
        manager.acquire_lock(
            RequestedLockType.RW,
            timeout_ms=50,
            session_id="writer_1",
        )
    )
    # Yield once so the writer task starts before the second reader is created.
    await asyncio.sleep(0)
    pending_reader = asyncio.create_task(
        manager.acquire_lock(
            RequestedLockType.RO,
            timeout_ms=200,
            session_id="reader_2",
        )
    )

    assert await pending_writer is None
    assert await pending_reader == GrantedLockType.RO

    cleanup_event = manager.begin_cleanup(connected_reader)
    assert cleanup_event == StateEvent.RO_DISCONNECT
    await manager.finish_cleanup(connected_reader)
@pytest.mark.asyncio
async def test_rw_or_ro_waiter_becomes_rw_after_writer_abort():
    """An RW_OR_RO waiter is granted RW once the active writer aborts."""
    manager = GMSSessionManager()

    granted = await manager.acquire_lock(
        RequestedLockType.RW,
        timeout_ms=50,
        session_id="writer_1",
    )
    assert granted == GrantedLockType.RW

    active_writer = Connection(
        reader=_DummyReader(),
        writer=_DummyWriter(),
        mode=GrantedLockType.RW,
        session_id="writer_1",
        recv_buffer=bytearray(),
    )
    manager.on_connect(active_writer)

    pending = asyncio.create_task(
        manager.acquire_lock(
            RequestedLockType.RW_OR_RO,
            timeout_ms=200,
            session_id="waiter_1",
        )
    )
    # Yield once; the waiter must still be blocked behind the writer.
    await asyncio.sleep(0)
    assert not pending.done()

    cleanup_event = manager.begin_cleanup(active_writer)
    assert cleanup_event == StateEvent.RW_ABORT
    await manager.finish_cleanup(active_writer)

    assert await pending == GrantedLockType.RW
    await manager.cancel_connect("waiter_1", GrantedLockType.RW)
@pytest.mark.asyncio
async def test_gms_clears_aborted_rw_layout_before_waking_waiters():
    """GMS.cleanup_connection clears layout state between begin_cleanup and finish_cleanup."""
    call_log: list[str] = []

    class _Sessions:
        # Records the order of the two cleanup phases; begin always reports RW_ABORT.
        def begin_cleanup(self, cleanup_conn):
            call_log.append("begin_cleanup")
            return StateEvent.RW_ABORT

        async def finish_cleanup(self, cleanup_conn):
            call_log.append("finish_cleanup")

    def record_clear_layout_state():
        call_log.append("clear_layout_state")
        return 3

    # Bypass GMS.__init__ and install only the collaborators the call needs.
    gms = object.__new__(GMS)
    conn, _ = _make_connection(GrantedLockType.RW, "session_1")
    gms._events = []
    gms._sessions = _Sessions()
    gms._clear_layout_state = record_clear_layout_state

    await gms.cleanup_connection(conn)

    expected_order = ["begin_cleanup", "clear_layout_state", "finish_cleanup"]
    assert call_log == expected_order
@pytest.mark.asyncio
async def test_request_response_send_failure_disconnects_without_error_response(
    monkeypatch,
):
    """A failed response send ends the request loop after one attempt, with no follow-up message."""
    server = _make_server(_FakeHandler(has_committed_layout=True))
    locking = server._gms._sessions._locking
    locking._committed = True
    conn, conn_writer = _make_connection(GrantedLockType.RO, "session_2")
    locking.transition(StateEvent.RO_CONNECT, conn)

    received = 0
    attempted_sends: list[object] = []

    async def recv_lock_state_request(_reader, _buffer):
        nonlocal received
        received += 1
        return GetLockStateRequest(), -1, bytearray()

    async def failing_send(_writer, msg, _fd=-1):
        # Record what the server tried to send before failing.
        attempted_sends.append(msg)
        raise BrokenPipeError("response send failed")

    monkeypatch.setattr(
        "gpu_memory_service.server.rpc.recv_message", recv_lock_state_request
    )
    monkeypatch.setattr("gpu_memory_service.server.rpc.send_message", failing_send)

    await server._request_loop(conn)
    await server._gms.cleanup_connection(conn)

    assert received == 1
    assert len(attempted_sends) == 1
    assert isinstance(attempted_sends[0], GetLockStateResponse)
    assert conn not in server._gms._sessions._locking.ro_conns
    assert server._gms._sessions.state == ServerState.COMMITTED
    assert conn_writer.closed
@pytest.mark.asyncio
async def test_post_commit_response_send_failure_stays_committed(monkeypatch):
    """A commit whose reply cannot be delivered still leaves the server COMMITTED."""
    handler = _FakeHandler()
    server = _make_server(handler)
    conn, conn_writer = _make_connection(GrantedLockType.RW, "session_3")
    server._gms._sessions._locking.transition(StateEvent.RW_CONNECT, conn)

    received = 0
    attempted_sends: list[object] = []

    async def recv_commit_request(_reader, _buffer):
        nonlocal received
        received += 1
        return CommitRequest(), -1, bytearray()

    async def failing_send(_writer, msg, _fd=-1):
        attempted_sends.append(msg)
        raise BrokenPipeError("commit reply failed")

    monkeypatch.setattr("gpu_memory_service.server.rpc.recv_message", recv_commit_request)
    monkeypatch.setattr("gpu_memory_service.server.rpc.send_message", failing_send)

    await server._request_loop(conn)
    await server._gms.cleanup_connection(conn)

    assert received == 1
    assert len(attempted_sends) == 1
    assert handler.commit_calls == 1
    sessions = server._gms._sessions
    assert sessions._locking.rw_conn is None
    assert sessions.snapshot().committed
    assert sessions.state == ServerState.COMMITTED
    assert conn_writer.closed
@pytest.mark.asyncio
async def test_runtime_state_handshake_send_failure_does_not_fail_server(monkeypatch):
    """A runtime-state request whose reply fails yields no connection and an EMPTY server."""
    server = _make_server()
    stream_reader = _DummyReader()
    stream_writer = _DummyWriter()

    async def recv_runtime_state_request(_reader, _buffer):
        return GetRuntimeStateRequest(), -1, bytearray()

    async def failing_send(_writer, _msg, _fd=-1):
        raise BrokenPipeError("runtime-state send failed")

    monkeypatch.setattr(
        "gpu_memory_service.server.rpc.recv_message", recv_runtime_state_request
    )
    monkeypatch.setattr("gpu_memory_service.server.rpc.send_message", failing_send)

    handshake_result = await server._do_handshake(
        stream_reader, stream_writer, "session_diag"
    )

    assert handshake_result is None
    assert server._gms._sessions.state == ServerState.EMPTY
    assert stream_writer.closed
@pytest.mark.asyncio
async def test_runtime_state_request_is_rejected_on_live_session():
    """GetRuntimeStateRequest on a connected RW session raises OperationNotAllowed."""
    gms = GMS()
    live_conn, _writer = _make_connection(GrantedLockType.RW, "session_4")
    # Pre-reserve the RW slot for this session before connecting it.
    gms._sessions._reserved_rw_session_id = live_conn.session_id
    gms.on_connect(live_conn)

    request = GetRuntimeStateRequest()
    with pytest.raises(OperationNotAllowed):
        await gms.handle_request(live_conn, request, lambda: True)
@pytest.mark.asyncio
async def test_event_history_request_is_rejected_on_live_session():
    """GetEventHistoryRequest on a connected RW session raises OperationNotAllowed."""
    gms = GMS()
    live_conn, _writer = _make_connection(GrantedLockType.RW, "session_5")
    # Pre-reserve the RW slot for this session before connecting it.
    gms._sessions._reserved_rw_session_id = live_conn.session_id
    gms.on_connect(live_conn)

    request = GetEventHistoryRequest()
    with pytest.raises(OperationNotAllowed):
        await gms.handle_request(live_conn, request, lambda: True)
@pytest.mark.asyncio
async def test_server_refuses_to_bind_over_live_socket(monkeypatch, tmp_path):
    """serve() raises "already running" when another listener owns the socket path."""
    socket_path = str(tmp_path / "gms.sock")

    # Occupy the path with a live UNIX listener.
    existing_listener = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
    existing_listener.bind(socket_path)
    existing_listener.listen(1)

    def skip_cuda_init():
        return None

    def fixed_granularity(device):
        return 4096

    # Stub the CUDA helpers that the allocations module would otherwise call.
    monkeypatch.setattr(
        "gpu_memory_service.server.allocations.cuda_ensure_initialized",
        skip_cuda_init,
    )
    monkeypatch.setattr(
        "gpu_memory_service.server.allocations.cumem_get_allocation_granularity",
        fixed_granularity,
    )

    server = GMSRPCServer(socket_path, device=0)
    try:
        with pytest.raises(RuntimeError, match="already running"):
            await asyncio.wait_for(server.serve(), timeout=0.1)
    finally:
        existing_listener.close()
        if os.path.exists(socket_path):
            os.unlink(socket_path)
@pytest.mark.asyncio
async def test_allocate_rejects_non_positive_size_before_cuda():
    """allocate(0, ...) raises ValueError("size must be > 0")."""
    allocation_manager = _make_allocation_manager()
    with pytest.raises(ValueError, match="size must be > 0"):
        await allocation_manager.allocate(0, tag="weights", is_connected=None)
@pytest.mark.asyncio
async def test_allocate_aborts_retry_when_writer_disconnects(monkeypatch):
    """allocate raises ConnectionAbortedError once is_connected turns false during OOM retry."""
    allocation_manager = _make_allocation_manager()
    probe_count = 0

    def is_connected() -> bool:
        # True on the first probe only; every later probe reports a disconnect.
        nonlocal probe_count
        probe_count += 1
        return probe_count < 2

    def always_oom(size, device):
        return (False, 0)

    monkeypatch.setattr(
        "gpu_memory_service.server.allocations.cumem_create_tolerate_oom",
        always_oom,
    )

    with pytest.raises(
        ConnectionAbortedError, match="RW client disconnected during allocation retry"
    ):
        await allocation_manager.allocate(1, tag="weights", is_connected=is_connected)
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import sys
import types
import pytest
from gpu_memory_service.integrations.sglang import patches as sglang_patches
pytestmark = [
pytest.mark.pre_merge,
pytest.mark.unit,
pytest.mark.gpu_0,
]
def test_patch_model_runner_rewrites_total_gpu_memory(monkeypatch):
    """After patch_model_runner, init_memory_pool(40.0) receives and returns ~80.0.

    A fake ``sglang`` package tree and a fake GMS memory saver are installed in
    ``sys.modules`` so the patch can run without the real dependencies.
    """
    # Fake sglang package hierarchy.
    fake_sglang = types.ModuleType("sglang")
    fake_srt = types.ModuleType("sglang.srt")
    fake_executor = types.ModuleType("sglang.srt.model_executor")
    fake_model_runner = types.ModuleType("sglang.srt.model_executor.model_runner")

    class FakeModelRunner:
        def init_memory_pool(self, total_gpu_memory):
            # Remember the value that actually reached the original method.
            self.seen_total_gpu_memory = total_gpu_memory
            return total_gpu_memory

    fake_model_runner.ModelRunner = FakeModelRunner
    fake_sglang.srt = fake_srt
    fake_srt.model_executor = fake_executor
    fake_executor.model_runner = fake_model_runner

    # Fake memory saver reporting 8 << 30 imported weight bytes.
    fake_memory_saver = types.ModuleType(
        "gpu_memory_service.integrations.sglang.memory_saver"
    )

    class FakeImpl:
        def get_imported_weights_bytes(self):
            return 8 << 30

    def make_fake_impl():
        return FakeImpl()

    fake_memory_saver.get_gms_memory_saver_impl = make_fake_impl

    injected_modules = {
        "sglang": fake_sglang,
        "sglang.srt": fake_srt,
        "sglang.srt.model_executor": fake_executor,
        "sglang.srt.model_executor.model_runner": fake_model_runner,
        "gpu_memory_service.integrations.sglang.memory_saver": fake_memory_saver,
    }
    for module_name, module in injected_modules.items():
        monkeypatch.setitem(sys.modules, module_name, module)

    # Reset the "already patched" markers so the patch is applied here.
    monkeypatch.setattr(sglang_patches, "_model_runner_patched", False)
    monkeypatch.delattr(FakeModelRunner, "_gms_patched", raising=False)

    def fake_current_device():
        return 0

    def fake_device_properties(device):
        return types.SimpleNamespace(total_memory=80 * (1 << 30))

    monkeypatch.setattr(sglang_patches.torch.cuda, "current_device", fake_current_device)
    monkeypatch.setattr(
        sglang_patches.torch.cuda,
        "get_device_properties",
        fake_device_properties,
    )

    sglang_patches.patch_model_runner()

    runner = FakeModelRunner()
    assert runner.init_memory_pool(40.0) == pytest.approx(80.0)
    assert runner.seen_total_gpu_memory == pytest.approx(80.0)
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import pytest
from gpu_memory_service.client.torch import allocator as allocator_module
from gpu_memory_service.common.types import GrantedLockType, RequestedLockType
pytestmark = [
pytest.mark.pre_merge,
pytest.mark.unit,
pytest.mark.gpu_0,
]
class _FakeManager:
    """In-memory replacement for the GMS client memory manager used by allocator tests."""

    def __init__(self, socket_path: str, *, device: int):
        self.socket_path = socket_path
        self.device = device
        self._connected = False
        self._granted_lock_type: GrantedLockType | None = None
        self._mappings: dict[int, object] = {}
        self._unmapped = False

    @property
    def is_connected(self) -> bool:
        return self._connected

    @property
    def granted_lock_type(self) -> GrantedLockType | None:
        return self._granted_lock_type

    @property
    def is_unmapped(self) -> bool:
        return self._unmapped

    @property
    def mappings(self) -> dict[int, object]:
        return self._mappings

    @property
    def _client_rpc(self):
        # The fake doubles as its own RPC client.
        return self

    def connect(self, mode: RequestedLockType, timeout_ms: int | None = None) -> None:
        """Mark the manager connected; RW requests get RW, anything else gets RO."""
        del timeout_ms
        self._connected = True
        if mode == RequestedLockType.RW:
            granted = GrantedLockType.RW
        else:
            granted = GrantedLockType.RO
        self._granted_lock_type = granted

    def get_lock_state(self) -> object:
        """Return an opaque placeholder object."""
        return object()
@pytest.fixture(autouse=True)
def clear_tag_states():
    """Empty the allocator module's tag registry before and after every test."""
    registry = allocator_module._tag_states
    registry.clear()
    yield
    # Look the attribute up again in case a test replaced the registry object.
    allocator_module._tag_states.clear()
@pytest.fixture
def fake_allocator(monkeypatch):
    """Swap the real memory manager and allocator setup hooks for inert fakes."""

    def skip_callback_init():
        return None

    def make_placeholder_pool():
        return object()

    monkeypatch.setattr(
        "gpu_memory_service.client.memory_manager.GMSClientMemoryManager",
        _FakeManager,
    )
    monkeypatch.setattr(
        allocator_module,
        "_ensure_callbacks_initialized",
        skip_callback_init,
    )
    monkeypatch.setattr(allocator_module, "_create_mem_pool", make_placeholder_pool)
def test_tag_registry_rejects_socket_or_device_mismatch(fake_allocator):
    """Reusing a tag with a different socket path or device raises RuntimeError."""
    registered = allocator_module.get_or_create_gms_client_memory_manager(
        "/tmp/weights.sock",
        0,
        RequestedLockType.RO,
        tag="weights",
    )

    # First a different socket, then a different device; both must be refused.
    conflicting_targets = (
        ("/tmp/other.sock", 0),
        ("/tmp/weights.sock", 1),
    )
    for socket_path, device in conflicting_targets:
        with pytest.raises(RuntimeError, match="initialized for"):
            allocator_module.get_or_create_gms_client_memory_manager(
                socket_path,
                device,
                RequestedLockType.RO,
                tag="weights",
            )

    assert registered.is_connected
def test_tag_registry_recreates_disconnected_empty_manager(fake_allocator):
    """A disconnected manager with no mappings is replaced by a fresh connected one."""
    original = allocator_module.get_or_create_gms_client_memory_manager(
        "/tmp/weights.sock",
        0,
        RequestedLockType.RO,
        tag="weights",
    )
    # Simulate a dropped session on the registered manager.
    original._connected = False
    original._granted_lock_type = None

    replacement = allocator_module.get_or_create_gms_client_memory_manager(
        "/tmp/weights.sock",
        0,
        RequestedLockType.RO,
        tag="weights",
    )

    assert replacement is not original
    assert replacement.is_connected
    assert replacement.granted_lock_type == GrantedLockType.RO
def test_tag_registry_rejects_disconnected_manager_with_preserved_state(
    fake_allocator,
):
    """A disconnected manager that still holds mappings is not silently replaced."""
    registered = allocator_module.get_or_create_gms_client_memory_manager(
        "/tmp/weights.sock",
        0,
        RequestedLockType.RO,
        tag="weights",
    )
    # Disconnect while leaving one mapping entry behind.
    registered._connected = False
    registered._mappings[0x1000] = object()

    with pytest.raises(RuntimeError, match="preserved state"):
        allocator_module.get_or_create_gms_client_memory_manager(
            "/tmp/weights.sock",
            0,
            RequestedLockType.RO,
            tag="weights",
        )
def test_close_evicts_manager_from_tag_registry(fake_allocator):
    """After eviction, looking up the tag returns None."""
    registered = allocator_module.get_or_create_gms_client_memory_manager(
        "/tmp/weights.sock",
        0,
        RequestedLockType.RO,
        tag="weights",
    )

    allocator_module.evict_gms_client_memory_manager(registered)

    lookup = allocator_module.get_gms_client_memory_manager("weights")
    assert lookup is None
...@@ -5,11 +5,14 @@ ...@@ -5,11 +5,14 @@
import pytest import pytest
# Skip collection entirely if gpu_memory_service is not installed # Skip collection entirely if gpu_memory_service is not installed.
# This package lives under nested common/ and integration/ subdirectories, so
# we ignore those directories directly instead of only matching test files next
# to this conftest.
try: try:
import gpu_memory_service # noqa: F401 import gpu_memory_service # noqa: F401
except ImportError: except ImportError:
collect_ignore_glob = ["test_*.py"] collect_ignore = ["common", "integration"]
from tests.utils.port_utils import allocate_port, deallocate_ports from tests.utils.port_utils import allocate_port, deallocate_ports
...@@ -20,28 +23,50 @@ def gms_ports(): ...@@ -20,28 +23,50 @@ def gms_ports():
Returns a dict with ports for: Returns a dict with ports for:
- frontend: Frontend HTTP port - frontend: Frontend HTTP port
- shadow_system: System port for shadow/primary engine - shadow_system: System port for the first shadow engine
- shadow2_system: System port for the second shadow engine
- primary_system: System port for primary engine (failover test only) - primary_system: System port for primary engine (failover test only)
- shadow_kv_event: KV event port for shadow engine (vLLM) - shadow_kv_event: KV event port for the first shadow engine (vLLM)
- shadow2_kv_event: KV event port for the second shadow engine (vLLM)
- primary_kv_event: KV event port for primary engine (vLLM) - primary_kv_event: KV event port for primary engine (vLLM)
- shadow_nixl: NIXL side channel port for shadow engine (vLLM) - shadow_nixl: NIXL side channel port for the first shadow engine (vLLM)
- shadow2_nixl: NIXL side channel port for the second shadow engine (vLLM)
- primary_nixl: NIXL side channel port for primary engine (vLLM) - primary_nixl: NIXL side channel port for primary engine (vLLM)
- shadow_sglang: SGLang HTTP port for shadow engine - shadow_sglang: SGLang HTTP port for the first shadow engine
- shadow2_sglang: SGLang HTTP port for the second shadow engine
- primary_sglang: SGLang HTTP port for primary engine - primary_sglang: SGLang HTTP port for primary engine
""" """
ports = [ ports = [
allocate_port(p) allocate_port(p)
for p in [8200, 8100, 8101, 20080, 20081, 20096, 20097, 30000, 30001] for p in [
8200,
8100,
8101,
8102,
20080,
20081,
20082,
20096,
20097,
20098,
30000,
30001,
30002,
]
] ]
yield { yield {
"frontend": ports[0], "frontend": ports[0],
"shadow_system": ports[1], "shadow_system": ports[1],
"primary_system": ports[2], "primary_system": ports[2],
"shadow_kv_event": ports[3], "shadow2_system": ports[3],
"primary_kv_event": ports[4], "shadow_kv_event": ports[4],
"shadow_nixl": ports[5], "primary_kv_event": ports[5],
"primary_nixl": ports[6], "shadow2_kv_event": ports[6],
"shadow_sglang": ports[7], "shadow_nixl": ports[7],
"primary_sglang": ports[8], "primary_nixl": ports[8],
"shadow2_nixl": ports[9],
"shadow_sglang": ports[10],
"primary_sglang": ports[11],
"shadow2_sglang": ports[12],
} }
deallocate_ports(ports) deallocate_ports(ports)
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import argparse
import os
import socket
import subprocess
from contextlib import contextmanager
import torch
from gpu_memory_service import get_or_create_gms_client_memory_manager
from gpu_memory_service.client.memory_manager import GMSClientMemoryManager
from gpu_memory_service.client.torch.allocator import gms_use_mem_pool
from gpu_memory_service.client.torch.module import register_module_tensors
from gpu_memory_service.common.types import RequestedLockType
from gpu_memory_service.common.utils import get_socket_path
from tests.utils.constants import FAULT_TOLERANCE_MODEL_NAME
from .runtime import DYNAMO_BIN, REPO_ROOT
SGLANG_BIN = REPO_ROOT / "dynamo-sglang" / "bin"
def get_external_weight_writer_command(backend: str) -> list[str]:
    """Build the argv that runs the external weight writer module for ``backend``."""
    module_invocation = ["python", "-m", "tests.gms.harness.external_weight_writer"]
    return [*module_invocation, "--backend", backend]
def get_external_weight_writer_env(backend: str) -> dict[str, str]:
    """Return a copy of the current environment adjusted for the writer subprocess.

    For ``"sglang"`` the CUDA and SGLang bin directories are prepended to
    ``PATH`` and ``CC``/``CXX`` are pinned to the system gcc/g++. For any other
    backend only the Dynamo bin directory is prepended. ``PYTHONPATH`` is set
    to the repository root in both cases.
    """
    inherited_path = os.environ.get("PATH", "")
    env = dict(os.environ)
    if backend == "sglang":
        env.update(
            {
                "PATH": f"/usr/local/cuda/bin:{SGLANG_BIN}:{inherited_path}",
                "CC": "/usr/bin/gcc",
                "CXX": "/usr/bin/g++",
            }
        )
    else:
        env["PATH"] = f"{DYNAMO_BIN}:{inherited_path}"
    env["PYTHONPATH"] = str(REPO_ROOT)
    return env
def run_external_weight_writer(backend: str) -> None:
    """Run the writer subprocess for ``backend`` from the repository root.

    ``check=True`` makes a non-zero exit status raise
    ``subprocess.CalledProcessError``.
    """
    subprocess.run(
        get_external_weight_writer_command(backend),
        cwd=REPO_ROOT,
        env=get_external_weight_writer_env(backend),
        check=True,
    )
def _get_writer_manager(device: int, tag: str) -> GMSClientMemoryManager:
    """Get or create the manager for ``tag`` on ``device``, requesting an RW lock."""
    socket_path = get_socket_path(device, tag)
    return get_or_create_gms_client_memory_manager(
        socket_path,
        device,
        RequestedLockType.RW,
        tag=tag,
    )
def _publish_model(manager: GMSClientMemoryManager, model: torch.nn.Module) -> None:
    """Register ``model``'s tensors with ``manager``, then commit and close it.

    The four steps run in a fixed order: register, synchronize, commit, close.
    """
    register_module_tensors(manager, model)
    # Wait for outstanding CUDA work on the current device before committing.
    torch.cuda.synchronize()
    manager.commit()
    manager.close()
@contextmanager
def _vllm_single_rank_distributed(device: int):
    """Set up a world-size-1 vLLM distributed environment for the ``with`` body.

    A free loopback port is discovered by binding a throwaway socket to port 0
    and is then used for the gloo ``tcp://`` init method. On exit the model
    parallel state and the distributed environment are destroyed.

    Args:
        device: Passed through as ``local_rank``.
    """
    from vllm.distributed.parallel_state import (
        destroy_distributed_environment,
        destroy_model_parallel,
        ensure_model_parallel_initialized,
        init_distributed_environment,
    )

    # Using the socket as a context manager closes the probe even when
    # bind()/getsockname() raises, so a failed probe cannot leak a descriptor.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
        probe.bind(("127.0.0.1", 0))
        host, port = probe.getsockname()

    init_distributed_environment(
        world_size=1,
        rank=0,
        local_rank=device,
        distributed_init_method=f"tcp://{host}:{port}",
        backend="gloo",
    )
    ensure_model_parallel_initialized(1, 1)
    try:
        yield
    finally:
        destroy_model_parallel()
        destroy_distributed_environment()
@contextmanager
def _sglang_single_rank_distributed(device: int):
    """Set up a world-size-1 SGLang distributed environment for the ``with`` body.

    A free loopback port is discovered by binding a throwaway socket to port 0
    and is then used for the gloo ``tcp://`` init method. On exit the model
    parallel state and the distributed environment are destroyed.

    Args:
        device: Passed through as ``local_rank``.
    """
    from sglang.srt.distributed.parallel_state import (
        destroy_distributed_environment,
        destroy_model_parallel,
        init_distributed_environment,
        initialize_model_parallel,
    )

    # Using the socket as a context manager closes the probe even when
    # bind()/getsockname() raises, so a failed probe cannot leak a descriptor.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
        probe.bind(("127.0.0.1", 0))
        host, port = probe.getsockname()

    init_distributed_environment(
        world_size=1,
        rank=0,
        local_rank=device,
        distributed_init_method=f"tcp://{host}:{port}",
        backend="gloo",
    )
    initialize_model_parallel(1, 1, 1, backend="gloo")
    try:
        yield
    finally:
        destroy_model_parallel()
        destroy_distributed_environment()
def _publish_vllm_dummy_weights(device: int, tag: str) -> None:
    """Build the test model with vLLM's dummy loader and publish it through GMS.

    Args:
        device: CUDA device index to load the model onto.
        tag: GMS tag used for both the writer manager and the memory pool.
    """
    # vLLM names are imported inside the function rather than at module level.
    from vllm.config import (
        DeviceConfig,
        LoadConfig,
        ModelConfig,
        VllmConfig,
        set_current_vllm_config,
    )
    from vllm.model_executor.model_loader.dummy_loader import DummyModelLoader
    from vllm.model_executor.model_loader.utils import (
        initialize_model,
        process_weights_after_loading,
    )
    from vllm.utils.torch_utils import set_default_torch_dtype

    torch.cuda.set_device(device)
    model_config = ModelConfig(model=FAULT_TOLERANCE_MODEL_NAME, enforce_eager=True)
    load_config = LoadConfig(load_format="dummy")
    device_config = DeviceConfig(device="cuda")
    vllm_config = VllmConfig(
        model_config=model_config,
        device_config=device_config,
        load_config=load_config,
    )
    target_device = torch.device("cuda", device)
    # Acquire the RW writer manager before any model tensors are created.
    manager = _get_writer_manager(device, tag)
    # NOTE(review): the nesting depth of the statements below was reconstructed
    # from a listing that had lost its indentation — confirm which context
    # managers enclose load_weights / process_weights_after_loading /
    # _publish_model against the original file.
    with set_current_vllm_config(vllm_config):
        with _vllm_single_rank_distributed(device):
            with set_default_torch_dtype(model_config.dtype):
                with gms_use_mem_pool(tag, target_device):
                    with target_device:
                        model = initialize_model(
                            vllm_config=vllm_config,
                            model_config=model_config,
                        )
                    DummyModelLoader(load_config).load_weights(model, model_config)
                    process_weights_after_loading(model, model_config, target_device)
                _publish_model(manager, model)
def _publish_sglang_dummy_weights(device: int, tag: str) -> None:
    """Load the test model with SGLang's dummy loader and publish it through GMS.

    Args:
        device: CUDA device index to load the model onto.
        tag: GMS tag used for both the writer manager and the memory pool.
    """
    # SGLang names are imported inside the function rather than at module level.
    from sglang.srt.configs.device_config import DeviceConfig
    from sglang.srt.configs.model_config import ModelConfig
    from sglang.srt.layers.dp_attention import initialize_dp_attention
    from sglang.srt.model_loader.loader import LoadConfig, get_model_loader
    from sglang.srt.server_args import ServerArgs, set_global_server_args_for_scheduler

    torch.cuda.set_device(device)
    model_config = ModelConfig(FAULT_TOLERANCE_MODEL_NAME)
    device_config = DeviceConfig(device="cuda", gpu_id=device)
    load_config = LoadConfig(load_format="dummy")
    loader = get_model_loader(load_config, model_config)
    # Acquire the RW writer manager before any model tensors are created.
    manager = _get_writer_manager(device, tag)
    server_args = ServerArgs(
        model_path=FAULT_TOLERANCE_MODEL_NAME,
        load_format="dummy",
        device="cuda",
    )
    set_global_server_args_for_scheduler(server_args)
    # NOTE(review): the nesting depth of the statements below was reconstructed
    # from a listing that had lost its indentation — confirm which context
    # managers enclose _publish_model against the original file.
    with _sglang_single_rank_distributed(device):
        initialize_dp_attention(server_args, model_config)
        with gms_use_mem_pool(tag, torch.device("cuda", device)):
            model = loader.load_model(
                model_config=model_config,
                device_config=device_config,
            )
        _publish_model(manager, model)
def main() -> None:
    """Parse CLI arguments and publish dummy weights with the chosen backend."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--backend", choices=("vllm", "sglang"), required=True)
    parser.add_argument("--device", type=int, default=0)
    parser.add_argument("--tag", default="weights")
    args = parser.parse_args()

    # argparse restricts --backend to the two choices, so anything that is not
    # "vllm" is "sglang".
    publish = (
        _publish_vllm_dummy_weights
        if args.backend == "vllm"
        else _publish_sglang_dummy_weights
    )
    publish(args.device, args.tag)


if __name__ == "__main__":
    main()
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import asyncio
import os
import shutil
import threading
import time
from concurrent.futures import TimeoutError as FutureTimeoutError
from gpu_memory_service.client.rpc import _GMSRPCTransport
from gpu_memory_service.common.protocol.messages import (
GetEventHistoryRequest,
GetEventHistoryResponse,
GetRuntimeStateRequest,
GetRuntimeStateResponse,
)
from gpu_memory_service.common.utils import get_socket_path
from gpu_memory_service.server.rpc import GMSRPCServer
from tests.utils.managed_process import ManagedProcess
from .runtime import DYNAMO_BIN
def _socket_has_live_gms(socket_path: str) -> bool:
if not os.path.exists(socket_path):
return False
try:
with _GMSRPCTransport(socket_path) as transport:
transport.connect()
transport.request(
GetRuntimeStateRequest(),
GetRuntimeStateResponse,
)
except Exception:
return False
return True
def _prepare_socket_path_for_launch(socket_path: str) -> None:
if not os.path.exists(socket_path):
return
if _socket_has_live_gms(socket_path):
raise RuntimeError(f"GMS already active at {socket_path}")
os.unlink(socket_path)
class GMSServerProcess(ManagedProcess):
    """Managed subprocess running ``python -m gpu_memory_service`` for one tag.

    Entering the context clears a stale socket file (or refuses to start when
    a live GMS already owns it); exiting removes the socket file if nothing is
    answering on it any more.
    """

    def __init__(self, request, device: int, tag: str = "weights"):
        """Configure the server command, environment and health check.

        Args:
            request: Pytest request object; its node name prefixes the log
                directory.
            device: Device index passed to ``--device``.
            tag: GMS tag passed to ``--tag`` and used to derive the socket path.
        """
        self.device = device
        self.tag = tag
        self.socket_path = get_socket_path(device, tag)

        log_dir = f"{request.node.name}_gms_{tag}_{device}"
        # Start from a clean log directory for this test/tag/device.
        shutil.rmtree(log_dir, ignore_errors=True)

        super().__init__(
            command=[
                "python",
                "-m",
                "gpu_memory_service",
                "--device",
                str(device),
                "--tag",
                tag,
            ],
            env={
                **os.environ,
                "PATH": f"{DYNAMO_BIN}:{os.environ.get('PATH', '')}",
                "DYN_LOG": "debug",
            },
            timeout=60,
            display_output=True,
            terminate_all_matching_process_names=False,
            log_dir=log_dir,
            display_name=f"gms_{tag}",
            health_check_funcs=[self._runtime_state_ready],
        )

    def __enter__(self):
        """Clear any stale socket, then start the managed process."""
        _prepare_socket_path_for_launch(self.socket_path)
        return super().__enter__()

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Stop the managed process and remove a leftover socket file."""
        try:
            return super().__exit__(exc_type, exc_val, exc_tb)
        finally:
            # Only delete the file when no live server is answering on it.
            if os.path.exists(self.socket_path) and not _socket_has_live_gms(
                self.socket_path
            ):
                os.unlink(self.socket_path)

    def _socket_has_live_gms(self) -> bool:
        """Instance-level shortcut for the module-level liveness probe."""
        return _socket_has_live_gms(self.socket_path)

    def _prepare_socket_path_for_launch(self) -> None:
        """Instance-level shortcut for the module-level stale-socket cleanup."""
        _prepare_socket_path_for_launch(self.socket_path)

    def _runtime_state_ready(self, timeout: float = 30) -> bool:
        """Poll until a runtime-state request succeeds or ``timeout`` elapses.

        Returns:
            True as soon as one request succeeds, False if the deadline passes
            first.
        """
        deadline = time.monotonic() + timeout
        while time.monotonic() < deadline:
            if os.path.exists(self.socket_path):
                try:
                    self.get_runtime_state()
                    return True
                except Exception:
                    # The server is not answering yet; retry after the pause.
                    pass
            # Single back-off point for both "no socket yet" and "request failed".
            time.sleep(0.1)
        return False

    def get_runtime_state(self) -> GetRuntimeStateResponse:
        """Open a fresh transport and fetch the server's runtime state."""
        with _GMSRPCTransport(self.socket_path) as transport:
            transport.connect()
            return transport.request(
                GetRuntimeStateRequest(),
                GetRuntimeStateResponse,
            )

    def get_event_history(self) -> GetEventHistoryResponse:
        """Open a fresh transport and fetch the server's event history."""
        with _GMSRPCTransport(self.socket_path) as transport:
            transport.connect()
            return transport.request(
                GetEventHistoryRequest(),
                GetEventHistoryResponse,
            )
class ServerThread:
    """Run a server's ``serve()`` coroutine on a private event loop in a daemon thread."""

    def __init__(self, server, socket_path: str):
        """Store the server and socket path; the thread is created but not started."""
        self.server = server
        self.socket_path = socket_path
        self._loop: asyncio.AbstractEventLoop | None = None
        self._task: asyncio.Task[None] | None = None
        self._thread = threading.Thread(target=self._run, daemon=True)
        # Exception raised by serve(), kept so start()/stop() can re-raise it
        # on the caller's thread.
        self._exception: BaseException | None = None

    def _run(self) -> None:
        """Thread body: drive ``serve()`` to completion, then drain and close the loop."""
        loop = asyncio.new_event_loop()
        self._loop = loop
        asyncio.set_event_loop(loop)
        self._task = loop.create_task(self.server.serve())
        try:
            loop.run_until_complete(self._task)
        except asyncio.CancelledError:
            # Cancellation is the normal shutdown path requested by stop().
            pass
        except BaseException as exc:
            self._exception = exc
        finally:
            # Cancel whatever is still scheduled so the loop can close cleanly.
            pending = [task for task in asyncio.all_tasks(loop) if not task.done()]
            for task in pending:
                task.cancel()
            if pending:
                loop.run_until_complete(
                    asyncio.gather(*pending, return_exceptions=True)
                )
            loop.close()

    def start(self) -> None:
        """Start the thread and block until the server answers a runtime-state request.

        Raises:
            TimeoutError: If the server is not answering about 5 seconds after start.
            BaseException: Whatever ``serve()`` raised, if it failed during startup.
        """
        self._thread.start()
        deadline = time.monotonic() + 5.0
        while True:
            if self._exception is not None:
                raise self._exception
            if self.server._server is not None and os.path.exists(self.socket_path):
                try:
                    with _GMSRPCTransport(self.socket_path) as transport:
                        transport.connect()
                        transport.request(
                            GetRuntimeStateRequest(),
                            GetRuntimeStateResponse,
                        )
                    return
                except Exception:
                    # Not accepting requests yet; keep polling until the deadline.
                    pass
            if time.monotonic() > deadline:
                raise TimeoutError(f"GMS socket did not appear at {self.socket_path}")
            time.sleep(0.01)

    def stop(self) -> None:
        """Request shutdown, wait for the thread, and remove the socket file.

        Raises:
            BaseException: Whatever ``serve()`` raised while running, if anything.
        """
        loop = self._loop
        if loop is not None:

            def cancel() -> None:
                if self.server._server is not None:
                    self.server._server.close()
                if self._task is not None:
                    self._task.cancel()

            try:
                loop.call_soon_threadsafe(cancel)
            except RuntimeError:
                # call_soon_threadsafe raises on a closed loop. That happens
                # when serve() already exited (e.g. it crashed), so there is
                # nothing left to cancel; fall through so the recorded server
                # exception is reported instead of "Event loop is closed".
                pass
        self._thread.join(timeout=5)
        if self._exception is not None:
            raise self._exception
        if os.path.exists(self.socket_path):
            os.unlink(self.socket_path)

    def disconnect_rw_session(self, timeout: float = 5.0) -> None:
        """Run the RW-session disconnect on the server loop and wait for it.

        Raises:
            RuntimeError: If the server thread has not created its loop yet, or
                there is no active RW session.
            TimeoutError: If the disconnect does not finish within ``timeout``.
        """
        if self._loop is None:
            raise RuntimeError("GMS server thread is not running")
        future = asyncio.run_coroutine_threadsafe(
            self._disconnect_rw_session(), self._loop
        )
        try:
            future.result(timeout=timeout)
        except FutureTimeoutError as exc:
            raise TimeoutError("Timed out disconnecting RW session") from exc

    async def _disconnect_rw_session(self) -> None:
        """Clean up the connection currently recorded as the RW holder."""
        # Reaches into the server's private locking state to find the RW connection.
        conn = self.server._gms._sessions._locking.rw_conn
        if conn is None:
            raise RuntimeError("No active RW session to disconnect")
        await self.server._gms.cleanup_connection(conn)
class ThreadedGMSServer:
    """In-process GMS RPC server hosted on a ``ServerThread``, usable as a context manager."""

    def __init__(self, device: int, tag: str = "weights"):
        """Create the RPC server and its hosting thread without starting them."""
        socket_path = get_socket_path(device, tag)
        server = GMSRPCServer(socket_path, device)

        self.device = device
        self.tag = tag
        self.socket_path = socket_path
        self.server = server
        self._thread = ServerThread(server, socket_path)

    def __enter__(self) -> "ThreadedGMSServer":
        """Clear any stale socket file and start serving."""
        _prepare_socket_path_for_launch(self.socket_path)
        self._thread.start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        """Stop the hosting thread."""
        self._thread.stop()

    def _request(self, request, response_type):
        """Send one request over a fresh transport and return the typed response."""
        with _GMSRPCTransport(self.socket_path) as transport:
            transport.connect()
            return transport.request(request, response_type)

    def get_runtime_state(self) -> GetRuntimeStateResponse:
        """Fetch the server's runtime state."""
        return self._request(GetRuntimeStateRequest(), GetRuntimeStateResponse)

    def get_event_history(self) -> GetEventHistoryResponse:
        """Fetch the server's event history."""
        return self._request(GetEventHistoryRequest(), GetEventHistoryResponse)

    def disconnect_rw_session(self) -> None:
        """Ask the hosting thread to drop the current RW session."""
        self._thread.disconnect_rw_session()
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import logging
import time
from pathlib import Path
import pynvml
import requests
from tests.utils.constants import FAULT_TOLERANCE_MODEL_NAME
logger = logging.getLogger(__name__)
REPO_ROOT = Path(__file__).resolve().parents[3]
DYNAMO_BIN = REPO_ROOT / "dynamo" / "bin"
MIN_EXPECTED_MEMORY_RETURN_FRACTION = 0.6
def get_gpu_memory_used(device: int = 0) -> int:
    """Return the ``used`` memory figure NVML reports for GPU index ``device``.

    NVML is initialised and shut down around every call.
    """
    pynvml.nvmlInit()
    try:
        memory_info = pynvml.nvmlDeviceGetMemoryInfo(
            pynvml.nvmlDeviceGetHandleByIndex(device)
        )
        return memory_info.used
    finally:
        pynvml.nvmlShutdown()
def send_completion(
    port: int,
    prompt: str = "Hello",
    max_retries: int = 3,
    retry_delay: float = 1.0,
) -> dict:
    """POST a completion request to the local frontend, retrying on failure.

    Args:
        port: Port of the HTTP server on localhost.
        prompt: Prompt text sent in the request body.
        max_retries: Total number of attempts; must be at least 1.
        retry_delay: Seconds to sleep between failed attempts.

    Returns:
        The decoded JSON response, which contains a non-empty ``choices`` entry.

    Raises:
        ValueError: If ``max_retries`` is less than 1.
        requests.exceptions.RequestException: If the final attempt failed at
            the HTTP level.
        AssertionError: If the final attempt returned no choices.
    """
    # Without at least one attempt there is no result to return and no error
    # to re-raise, so reject the call up front instead of failing on `raise None`.
    if max_retries < 1:
        raise ValueError(f"max_retries must be at least 1, got {max_retries}")

    last_error: Exception | None = None
    for attempt in range(max_retries):
        try:
            response = requests.post(
                f"http://localhost:{port}/v1/completions",
                json={
                    "model": FAULT_TOLERANCE_MODEL_NAME,
                    "prompt": prompt,
                    "max_tokens": 20,
                },
                timeout=120,
            )
            response.raise_for_status()
            result = response.json()
            # An empty/missing "choices" is treated as a retryable failure.
            assert result.get("choices"), "No choices in response"
            if attempt > 0:
                logger.info("send_completion succeeded after %d attempts", attempt + 1)
            return result
        except (requests.exceptions.RequestException, AssertionError) as exc:
            last_error = exc
            # Only log and sleep when another attempt will follow.
            if attempt < max_retries - 1:
                logger.debug(
                    "send_completion attempt %d/%d failed: %s",
                    attempt + 1,
                    max_retries,
                    exc,
                )
                time.sleep(retry_delay)
    raise last_error  # type: ignore[misc]
...@@ -13,7 +13,10 @@ from tests.utils.constants import FAULT_TOLERANCE_MODEL_NAME ...@@ -13,7 +13,10 @@ from tests.utils.constants import FAULT_TOLERANCE_MODEL_NAME
from tests.utils.managed_process import ManagedProcess from tests.utils.managed_process import ManagedProcess
from tests.utils.payloads import check_health_generate, check_models_api from tests.utils.payloads import check_health_generate, check_models_api
from .runtime import REPO_ROOT
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
SGLANG_BIN = REPO_ROOT / "dynamo-sglang" / "bin"
class SGLangWithGMSProcess(ManagedProcess): class SGLangWithGMSProcess(ManagedProcess):
...@@ -26,6 +29,8 @@ class SGLangWithGMSProcess(ManagedProcess): ...@@ -26,6 +29,8 @@ class SGLangWithGMSProcess(ManagedProcess):
system_port: int, system_port: int,
sglang_port: int, sglang_port: int,
frontend_port: int, frontend_port: int,
*,
read_only_weights: bool = False,
): ):
self.engine_id = engine_id self.engine_id = engine_id
self.system_port = system_port self.system_port = system_port
...@@ -33,23 +38,34 @@ class SGLangWithGMSProcess(ManagedProcess): ...@@ -33,23 +38,34 @@ class SGLangWithGMSProcess(ManagedProcess):
log_dir = f"{request.node.name}_{engine_id}" log_dir = f"{request.node.name}_{engine_id}"
shutil.rmtree(log_dir, ignore_errors=True) shutil.rmtree(log_dir, ignore_errors=True)
command = [
"python",
"-m",
"dynamo.sglang",
"--model-path",
FAULT_TOLERANCE_MODEL_NAME,
"--load-format",
"gms",
"--enable-memory-saver",
"--mem-fraction-static",
"0.9",
"--port",
str(sglang_port),
]
if read_only_weights:
command.extend(
[
"--model-loader-extra-config",
'{"gms_read_only": true}',
]
)
super().__init__( super().__init__(
command=[ command=command,
"python3",
"-m",
"dynamo.sglang",
"--model-path",
FAULT_TOLERANCE_MODEL_NAME,
"--load-format",
"gms",
"--enable-memory-saver",
"--mem-fraction-static",
"0.9",
"--port",
str(sglang_port),
],
env={ env={
**os.environ, **os.environ,
"PATH": f"/usr/local/cuda/bin:{SGLANG_BIN}:{os.environ.get('PATH', '')}",
"CC": "/usr/bin/gcc",
"CXX": "/usr/bin/g++",
"DYN_LOG": "debug", "DYN_LOG": "debug",
"DYN_SYSTEM_PORT": str(system_port), "DYN_SYSTEM_PORT": str(system_port),
}, },
...@@ -63,6 +79,7 @@ class SGLangWithGMSProcess(ManagedProcess): ...@@ -63,6 +79,7 @@ class SGLangWithGMSProcess(ManagedProcess):
terminate_all_matching_process_names=False, terminate_all_matching_process_names=False,
stragglers=[], stragglers=[],
log_dir=log_dir, log_dir=log_dir,
display_name=engine_id,
) )
def _is_ready(self, response) -> bool: def _is_ready(self, response) -> bool:
...@@ -72,22 +89,22 @@ class SGLangWithGMSProcess(ManagedProcess): ...@@ -72,22 +89,22 @@ class SGLangWithGMSProcess(ManagedProcess):
return False return False
def sleep(self) -> dict: def sleep(self) -> dict:
"""Put the engine to sleep, offloading weights from GPU memory.""" """Put the engine to sleep, offloading weights and KV cache."""
r = requests.post( r = requests.post(
f"http://localhost:{self.system_port}/engine/release_memory_occupation", f"http://localhost:{self.system_port}/engine/release_memory_occupation",
json={}, json={"tags": ["weights", "kv_cache"]},
timeout=30, timeout=30,
) )
r.raise_for_status() r.raise_for_status()
logger.info(f"{self.engine_id} release_memory_occupation: {r.json()}") logger.info(f"{self.engine_id} release_memory_occupation: {r.json()}")
return r.json() return r.json()
def wake(self) -> dict: def wake(self, timeout: int = 30) -> dict:
"""Wake the engine, reloading weights to GPU memory.""" """Wake the engine, restoring weights and KV cache."""
r = requests.post( r = requests.post(
f"http://localhost:{self.system_port}/engine/resume_memory_occupation", f"http://localhost:{self.system_port}/engine/resume_memory_occupation",
json={}, json={"tags": ["weights", "kv_cache"]},
timeout=30, timeout=timeout,
) )
r.raise_for_status() r.raise_for_status()
logger.info(f"{self.engine_id} resume_memory_occupation: {r.json()}") logger.info(f"{self.engine_id} resume_memory_occupation: {r.json()}")
......
...@@ -14,6 +14,8 @@ from tests.utils.constants import FAULT_TOLERANCE_MODEL_NAME ...@@ -14,6 +14,8 @@ from tests.utils.constants import FAULT_TOLERANCE_MODEL_NAME
from tests.utils.managed_process import ManagedProcess from tests.utils.managed_process import ManagedProcess
from tests.utils.payloads import check_health_generate, check_models_api from tests.utils.payloads import check_health_generate, check_models_api
from .runtime import DYNAMO_BIN
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -28,6 +30,8 @@ class VLLMWithGMSProcess(ManagedProcess): ...@@ -28,6 +30,8 @@ class VLLMWithGMSProcess(ManagedProcess):
kv_event_port: int, kv_event_port: int,
nixl_port: int, nixl_port: int,
frontend_port: int, frontend_port: int,
*,
read_only_weights: bool = False,
): ):
self.engine_id = engine_id self.engine_id = engine_id
self.system_port = system_port self.system_port = system_port
...@@ -43,23 +47,33 @@ class VLLMWithGMSProcess(ManagedProcess): ...@@ -43,23 +47,33 @@ class VLLMWithGMSProcess(ManagedProcess):
"enable_kv_cache_events": True, "enable_kv_cache_events": True,
} }
) )
command = [
"python",
"-m",
"dynamo.vllm",
"--model",
FAULT_TOLERANCE_MODEL_NAME,
"--load-format",
"gms",
"--enforce-eager",
"--enable-sleep-mode",
"--gpu-memory-utilization",
"0.9",
"--kv-events-config",
kv_events_cfg,
]
if read_only_weights:
command.extend(
[
"--model-loader-extra-config",
json.dumps({"gms_read_only": True}),
]
)
super().__init__( super().__init__(
command=[ command=command,
"python3",
"-m",
"dynamo.vllm",
"--model",
FAULT_TOLERANCE_MODEL_NAME,
"--load-format",
"gms",
"--enable-sleep-mode",
"--gpu-memory-utilization",
"0.9",
"--kv-events-config",
kv_events_cfg,
],
env={ env={
**os.environ, **os.environ,
"PATH": f"{DYNAMO_BIN}:{os.environ.get('PATH', '')}",
"DYN_LOG": "debug", "DYN_LOG": "debug",
"DYN_SYSTEM_PORT": str(system_port), "DYN_SYSTEM_PORT": str(system_port),
"VLLM_NIXL_SIDE_CHANNEL_PORT": str(nixl_port), "VLLM_NIXL_SIDE_CHANNEL_PORT": str(nixl_port),
...@@ -74,6 +88,7 @@ class VLLMWithGMSProcess(ManagedProcess): ...@@ -74,6 +88,7 @@ class VLLMWithGMSProcess(ManagedProcess):
terminate_all_matching_process_names=False, terminate_all_matching_process_names=False,
stragglers=[], stragglers=[],
log_dir=log_dir, log_dir=log_dir,
display_name=engine_id,
) )
def _is_ready(self, response) -> bool: def _is_ready(self, response) -> bool:
...@@ -83,20 +98,22 @@ class VLLMWithGMSProcess(ManagedProcess): ...@@ -83,20 +98,22 @@ class VLLMWithGMSProcess(ManagedProcess):
return False return False
def sleep(self) -> dict: def sleep(self) -> dict:
"""Put the engine to sleep, offloading weights from GPU memory.""" """Put the engine to sleep, offloading weights and KV cache."""
r = requests.post( r = requests.post(
f"http://localhost:{self.system_port}/engine/sleep", f"http://localhost:{self.system_port}/engine/sleep",
json={"level": 1}, json={"level": 2},
timeout=30, timeout=30,
) )
r.raise_for_status() r.raise_for_status()
logger.info(f"{self.engine_id} sleep: {r.json()}") logger.info(f"{self.engine_id} sleep: {r.json()}")
return r.json() return r.json()
def wake(self) -> dict: def wake(self, timeout: int = 30) -> dict:
"""Wake the engine, reloading weights to GPU memory.""" """Wake the engine, restoring weights and KV cache."""
r = requests.post( r = requests.post(
f"http://localhost:{self.system_port}/engine/wake_up", json={}, timeout=30 f"http://localhost:{self.system_port}/engine/wake_up",
json={"tags": ["weights", "kv_cache"]},
timeout=timeout,
) )
r.raise_for_status() r.raise_for_status()
logger.info(f"{self.engine_id} wake: {r.json()}") logger.info(f"{self.engine_id} wake: {r.json()}")
......
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import time
from concurrent.futures import ThreadPoolExecutor
from contextlib import ExitStack
from typing import Callable, Protocol
import pytest
from gpu_memory_service.client.session import _GMSClientSession
from gpu_memory_service.common.types import RequestedLockType, ServerState
from tests.utils.constants import FAULT_TOLERANCE_MODEL_NAME
from tests.utils.managed_process import DynamoFrontendProcess
from ..harness.gms import GMSServerProcess
from ..harness.runtime import send_completion
from ..harness.sglang import SGLangWithGMSProcess
from ..harness.vllm import VLLMWithGMSProcess
# guard: external_weight_writer imports torch at module level
pytest.importorskip("torch", reason="torch is required")
from ..harness.external_weight_writer import run_external_weight_writer # noqa: E402
pytestmark = [pytest.mark.nightly]
# Event flow under test:
# 1. The engine starts in read-only mode and waits for a committed weights layout.
# 2. An external writer acquires RW on the weights GMS, loads dummy weights, commits, and exits.
# 3. The engine comes online with those committed weights while owning its own KV-cache RW layout.
# 4. The engine sleeps, preserving its weight VAs but dropping the KV-cache layout.
# 5. A second external writer acquires RW on the weights GMS, creates a fresh committed layout with
# different allocation IDs but the same structural layout, and exits.
# 6. The engine wakes, remaps the preserved weight VAs into the new committed layout, recreates its
# KV cache in a new RW layout, and serves inference without a stale-layout error.
class _SleepWakeEngine(Protocol):
def __enter__(self):
...
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
...
def sleep(self) -> dict:
...
def wake(self) -> dict:
...
def _list_committed_weight_allocations(
    socket_path: str,
) -> list[tuple[int, str, int, int, str]]:
    """List allocations visible to an RO session as plain tuples.

    Each tuple is ``(layout_slot, allocation_id, size, aligned_size, tag)``.
    """

    def as_row(info) -> tuple[int, str, int, int, str]:
        return (
            int(info.layout_slot),
            str(info.allocation_id),
            int(info.size),
            int(info.aligned_size),
            str(info.tag),
        )

    with _GMSClientSession(
        socket_path, RequestedLockType.RO, timeout_ms=None
    ) as reader:
        return [as_row(info) for info in reader.list_allocations()]
def _wait_for_gms_state(
    predicate: Callable[[], bool],
    failure_message: str,
    timeout: float = 30.0,
    interval: float = 0.1,
) -> None:
    """Poll ``predicate`` until it is true, raising ``TimeoutError`` after ``timeout`` seconds."""
    deadline = time.monotonic() + timeout
    while True:
        if predicate():
            return
        if time.monotonic() > deadline:
            raise TimeoutError(failure_message)
        time.sleep(interval)


def _run_external_weight_mgr_test(
    request,
    ports: dict,
    backend: str,
    make_engine: Callable[[], _SleepWakeEngine],
) -> None:
    """Drive the external-writer / sleep / republish / wake flow for one backend.

    Args:
        request: Pytest request object forwarded to the managed processes.
        ports: Port mapping; only ``ports["frontend"]`` is read here.
        backend: Backend name passed to the external weight writer.
        make_engine: Factory returning the (not yet entered) engine process.
    """
    with ExitStack() as stack:
        weights_gms = stack.enter_context(
            GMSServerProcess(request, device=0, tag="weights")
        )
        kv_cache_gms = stack.enter_context(
            GMSServerProcess(request, device=0, tag="kv_cache")
        )
        stack.enter_context(
            DynamoFrontendProcess(request, frontend_port=ports["frontend"])
        )
        engine = make_engine()
        # The engine is entered on a worker thread below, so its exit is
        # registered by hand instead of via enter_context().
        stack.callback(engine.__exit__, None, None, None)

        with ThreadPoolExecutor(max_workers=1) as executor:
            start_future = executor.submit(engine.__enter__)
            try:
                # The read-only engine must stall until some external writer
                # publishes the first committed weights layout.
                time.sleep(2.0)
                assert (
                    not start_future.done()
                ), "read-only engine should still be waiting for committed weights"
                assert weights_gms.get_runtime_state().state == ServerState.EMPTY
                assert kv_cache_gms.get_runtime_state().state == ServerState.EMPTY

                # Publish the first weights layout out-of-process, then let the
                # engine finish importing those committed weights.
                run_external_weight_writer(backend)
                start_future.result(timeout=300)

                first_weights_state = weights_gms.get_runtime_state()
                first_kv_state = kv_cache_gms.get_runtime_state()
                first_allocations = _list_committed_weight_allocations(
                    weights_gms.socket_path
                )
                assert first_weights_state.state == ServerState.RO
                assert first_weights_state.allocation_count > 0
                assert first_weights_state.memory_layout_hash
                assert first_kv_state.state == ServerState.RW
                assert first_allocations

                result = send_completion(ports["frontend"])
                assert result["choices"]

                def sleep_settled() -> bool:
                    weights_state = weights_gms.get_runtime_state()
                    kv_state = kv_cache_gms.get_runtime_state()
                    return (
                        weights_state.state == ServerState.COMMITTED
                        and weights_state.memory_layout_hash
                        == first_weights_state.memory_layout_hash
                        and weights_state.ro_session_count == 0
                        and kv_state.state == ServerState.EMPTY
                        and kv_state.allocation_count == 0
                    )

                def second_layout_published() -> bool:
                    weights_state = weights_gms.get_runtime_state()
                    kv_state = kv_cache_gms.get_runtime_state()
                    return (
                        weights_state.state == ServerState.COMMITTED
                        and weights_state.memory_layout_hash
                        == first_weights_state.memory_layout_hash
                        and weights_state.allocation_count
                        == first_weights_state.allocation_count
                        and kv_state.state == ServerState.EMPTY
                        and kv_state.allocation_count == 0
                    )

                def wake_restored() -> bool:
                    weights_state = weights_gms.get_runtime_state()
                    kv_state = kv_cache_gms.get_runtime_state()
                    return (
                        weights_state.state == ServerState.RO
                        and weights_state.memory_layout_hash
                        == first_weights_state.memory_layout_hash
                        and kv_state.state == ServerState.RW
                        and kv_state.allocation_count > 0
                    )

                # Sleep preserves the engine's weight VAs but tears down the KV
                # cache layout so the process can wake against a later weights layout.
                assert engine.sleep()["status"] == "ok"
                _wait_for_gms_state(
                    sleep_settled,
                    "engine sleep did not settle the expected GMS state",
                )

                # Publish a second committed weights layout with the same logical
                # layout but fresh allocation IDs.
                run_external_weight_writer(backend)
                _wait_for_gms_state(
                    second_layout_published,
                    "external writer did not publish a new committed weights layout",
                )

                # The second publish must reuse the same layout slots and sizes so
                # RO remap can bind the preserved VAs to new allocations.
                second_allocations = _list_committed_weight_allocations(
                    weights_gms.socket_path
                )
                assert len(second_allocations) == len(first_allocations)
                assert [item[0] for item in second_allocations] == [
                    item[0] for item in first_allocations
                ]
                assert [item[2:] for item in second_allocations] == [
                    item[2:] for item in first_allocations
                ]
                assert [item[1] for item in second_allocations] != [
                    item[1] for item in first_allocations
                ]

                # The weights GMS should show the expected publish progression:
                # first publish, old layout cleanup, second publish.
                weights_events = weights_gms.get_event_history().events
                assert [event.kind for event in weights_events] == [
                    "rw_connected",
                    "committed",
                    "allocations_cleared",
                    "rw_connected",
                    "committed",
                ]
                assert (
                    weights_events[2].allocation_count
                    == first_weights_state.allocation_count
                )

                # Wake should remap the preserved RO weight VAs into the new
                # committed layout and recreate KV cache in a new RW layout.
                assert engine.wake()["status"] == "ok"
                _wait_for_gms_state(
                    wake_restored,
                    "engine wake did not restore the expected GMS state",
                )

                # A normal inference after wake proves the remapped weights and
                # recreated KV cache are usable end to end.
                result = send_completion(ports["frontend"], "updated weights")
                assert result["choices"]
            finally:
                # Retrieve the startup result if it is available; any startup
                # error is swallowed here so it cannot replace the exception
                # already propagating from the body above.
                if start_future.done():
                    try:
                        start_future.result()
                    except Exception:
                        pass
@pytest.mark.vllm
@pytest.mark.e2e
@pytest.mark.gpu_1
@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME)
@pytest.mark.timeout(600)
def test_external_weight_mgr_vllm(
    request,
    runtime_services_dynamic_ports,
    gms_ports,
    predownload_models,
):
    """Run the external weight manager flow against a read-only vLLM engine."""

    def make_engine() -> VLLMWithGMSProcess:
        return VLLMWithGMSProcess(
            request,
            "engine",
            gms_ports["shadow_system"],
            gms_ports["shadow_kv_event"],
            gms_ports["shadow_nixl"],
            gms_ports["frontend"],
            read_only_weights=True,
        )

    _run_external_weight_mgr_test(
        request,
        gms_ports,
        "vllm",
        make_engine=make_engine,
    )
@pytest.mark.sglang
@pytest.mark.e2e
@pytest.mark.gpu_1
@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME)
@pytest.mark.timeout(600)
def test_external_weight_mgr_sglang(
    request,
    runtime_services_dynamic_ports,
    gms_ports,
    predownload_models,
):
    """Run the external weight manager flow against a read-only SGLang engine."""

    def make_engine() -> SGLangWithGMSProcess:
        return SGLangWithGMSProcess(
            request,
            "engine",
            gms_ports["shadow_system"],
            gms_ports["shadow_sglang"],
            gms_ports["frontend"],
            read_only_weights=True,
        )

    _run_external_weight_mgr_test(
        request,
        gms_ports,
        "sglang",
        make_engine=make_engine,
    )
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
from contextlib import contextmanager
import pytest
from gpu_memory_service.common.types import RequestedLockType
from gpu_memory_service.integrations.sglang.memory_saver import GMSMemorySaverImpl
torch = pytest.importorskip("torch", reason="torch is required")
pytestmark = [
pytest.mark.pre_merge,
pytest.mark.unit,
pytest.mark.gpu_0,
pytest.mark.sglang,
]
class _FakeManager:
def __init__(self, *, is_unmapped: bool = False):
self.is_unmapped = is_unmapped
self.calls: list[object] = []
def unmap_all_vas(self) -> None:
self.calls.append("unmap_all_vas")
self.is_unmapped = True
def abort(self) -> None:
self.calls.append("abort")
def connect(self, lock_type) -> None:
self.calls.append(("connect", lock_type))
self.is_unmapped = False
def reallocate_all_handles(self, *, tag: str) -> None:
self.calls.append(("reallocate_all_handles", tag))
def remap_all_vas(self) -> None:
self.calls.append("remap_all_vas")
self.is_unmapped = False
class _FakeTorchImpl:
def __init__(self):
self.region_calls: list[tuple[str, bool]] = []
self.pause_calls: list[object] = []
self.resume_calls: list[object] = []
@contextmanager
def region(self, tag: str, enable_cpu_backup: bool):
self.region_calls.append((tag, enable_cpu_backup))
yield
def pause(self, tag=None) -> None:
self.pause_calls.append(tag)
def resume(self, tag=None) -> None:
self.resume_calls.append(tag)
def test_region_routes_weights_and_kv_cache_to_gms(monkeypatch):
    """Check which backend each ``region`` tag is routed to.

    Asserts that "weights" and "kv_cache" regions enter the patched
    ``gms_use_mem_pool`` with the configured CUDA device, while a
    "cuda_graph" region is forwarded to the torch implementation.
    """
    weights_mgr = _FakeManager()
    kv_mgr = _FakeManager()
    fake_torch = _FakeTorchImpl()
    pool_calls: list[tuple[str, torch.device]] = []

    @contextmanager
    def fake_use_mem_pool(tag: str, device: torch.device):
        pool_calls.append((tag, device))
        yield

    def stub_init_allocators(self):
        # Replaces the real allocator setup with the recording fakes.
        return weights_mgr, kv_mgr, "write"

    monkeypatch.setattr(GMSMemorySaverImpl, "_init_allocators", stub_init_allocators)
    monkeypatch.setattr(
        "gpu_memory_service.integrations.sglang.memory_saver.gms_use_mem_pool",
        fake_use_mem_pool,
    )

    saver = GMSMemorySaverImpl(torch_impl=fake_torch, device_index=2, mode=None)
    requested_regions = (
        ("weights", False),
        ("kv_cache", False),
        ("cuda_graph", True),
    )
    for region_tag, cpu_backup in requested_regions:
        with saver.region(region_tag, enable_cpu_backup=cpu_backup):
            pass

    assert pool_calls == [
        ("weights", torch.device("cuda", 2)),
        ("kv_cache", torch.device("cuda", 2)),
    ]
    assert fake_torch.region_calls == [("cuda_graph", True)]
def test_pause_resume_routes_kv_cache_to_gms(monkeypatch):
    """Check the manager calls made by an untagged pause/resume cycle.

    Asserts that both managers are unmapped and aborted on pause, that
    resume reconnects weights as RO and KV cache as RW (reallocating the
    KV handles before remapping), and that the torch implementation sees
    one untagged ``pause`` and one untagged ``resume``.
    """
    weights_mgr = _FakeManager()
    kv_mgr = _FakeManager()
    fake_torch = _FakeTorchImpl()

    def stub_init_allocators(self):
        # Replaces the real allocator setup with the recording fakes.
        return weights_mgr, kv_mgr, "read"

    monkeypatch.setattr(GMSMemorySaverImpl, "_init_allocators", stub_init_allocators)

    saver = GMSMemorySaverImpl(torch_impl=fake_torch, device_index=0, mode=None)
    saver.pause()
    saver.resume()

    expected_weights_calls = [
        "unmap_all_vas",
        "abort",
        ("connect", RequestedLockType.RO),
        "remap_all_vas",
    ]
    expected_kv_calls = [
        "unmap_all_vas",
        "abort",
        ("connect", RequestedLockType.RW),
        ("reallocate_all_handles", "kv_cache"),
        "remap_all_vas",
    ]
    assert weights_mgr.calls == expected_weights_calls
    assert kv_mgr.calls == expected_kv_calls
    assert fake_torch.pause_calls == [None]
    assert fake_torch.resume_calls == [None]
def test_region_treats_model_weights_as_weights(monkeypatch):
    """Check that the "model_weights" tag is treated as "weights".

    Asserts that a "model_weights" region enters the patched
    ``gms_use_mem_pool`` under the "weights" tag and that nothing is
    forwarded to the torch implementation.
    """
    weights_mgr = _FakeManager()
    kv_mgr = _FakeManager()
    fake_torch = _FakeTorchImpl()
    pool_calls: list[tuple[str, torch.device]] = []

    @contextmanager
    def fake_use_mem_pool(tag: str, device: torch.device):
        pool_calls.append((tag, device))
        yield

    def stub_init_allocators(self):
        # Replaces the real allocator setup with the recording fakes.
        return weights_mgr, kv_mgr, "write"

    monkeypatch.setattr(GMSMemorySaverImpl, "_init_allocators", stub_init_allocators)
    monkeypatch.setattr(
        "gpu_memory_service.integrations.sglang.memory_saver.gms_use_mem_pool",
        fake_use_mem_pool,
    )

    saver = GMSMemorySaverImpl(torch_impl=fake_torch, device_index=1, mode=None)
    with saver.region("model_weights", enable_cpu_backup=False):
        pass

    expected_device = torch.device("cuda", 1)
    assert pool_calls == [("weights", expected_device)]
    assert fake_torch.region_calls == []
def test_pause_resume_model_weights_only_routes_weights(monkeypatch):
    """Check that a "model_weights" pause/resume touches only weights.

    Asserts that the weights manager sees the unmap/abort/RO-connect/remap
    sequence, while the KV-cache manager and the torch implementation
    receive no calls at all.
    """
    weights_mgr = _FakeManager()
    kv_mgr = _FakeManager()
    fake_torch = _FakeTorchImpl()

    def stub_init_allocators(self):
        # Replaces the real allocator setup with the recording fakes.
        return weights_mgr, kv_mgr, "read"

    monkeypatch.setattr(GMSMemorySaverImpl, "_init_allocators", stub_init_allocators)

    saver = GMSMemorySaverImpl(torch_impl=fake_torch, device_index=0, mode=None)
    saver.pause("model_weights")
    saver.resume("model_weights")

    expected_weights_calls = [
        "unmap_all_vas",
        "abort",
        ("connect", RequestedLockType.RO),
        "remap_all_vas",
    ]
    assert weights_mgr.calls == expected_weights_calls
    assert kv_mgr.calls == []
    assert fake_torch.pause_calls == []
    assert fake_torch.resume_calls == []
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import logging
import os
import signal
import time
from concurrent.futures import ThreadPoolExecutor
from contextlib import ExitStack
from typing import Callable
import pytest
from gpu_memory_service.common.types import ServerState
from tests.utils.constants import FAULT_TOLERANCE_MODEL_NAME
from tests.utils.managed_process import DynamoFrontendProcess, ManagedProcess
from ..harness.gms import ThreadedGMSServer
from ..harness.runtime import (
MIN_EXPECTED_MEMORY_RETURN_FRACTION,
get_gpu_memory_used,
send_completion,
)
from ..harness.sglang import SGLangWithGMSProcess
from ..harness.vllm import VLLMWithGMSProcess
pytestmark = [pytest.mark.nightly]
# Event flow under test:
# 1. Shadow A starts with committed weights and a live RW KV layout, then sleeps.
# 2. Shadow B starts from the same committed weights layout, then sleeps as well.
# 3. Primary starts and owns the next RW KV layout.
# 4. Shadow A's wake is requested while the primary still holds the RW KV layout;
#    after a forced primary disconnect, Shadow A enters a new RW layout.
# 5. Shadow A blocks on allocation_oom until the still-alive primary is killed.
# 6. After primary death, Shadow A finishes wake without any further KV events
#    (the old KV layout was already cleared at the forced disconnect).
logger = logging.getLogger(__name__)
def _kill_process_group(process: ManagedProcess) -> None:
    """SIGKILL the process group of ``process`` and reap the child.

    Logs a warning and returns without doing anything else when the
    process has no PID or when the PID no longer exists.
    """
    pid = process.get_pid()
    if pid is None:
        logger.warning("kill process group: no PID available")
        return

    try:
        group_id = os.getpgid(pid)
        os.killpg(group_id, signal.SIGKILL)
    except ProcessLookupError:
        logger.warning("kill process group: process %d already dead", pid)
    else:
        # Collect the exit status; a ChildProcessError here is deliberately
        # ignored (nothing left for this process to wait on).
        try:
            os.waitpid(pid, 0)
        except ChildProcessError:
            pass
def _is_process_alive(process: ManagedProcess) -> bool:
pid = process.get_pid()
if pid is None:
return False
try:
os.kill(pid, 0)
except ProcessLookupError:
return False
return True
def _assert_weights_published_once(events) -> None:
assert [event.kind for event in events] == ["rw_connected", "committed"]
def _assert_cleared_rw_layout_prefix(events, cleared_layouts: int) -> None:
expected_prefix = [
"rw_connected",
"rw_aborted",
"allocations_cleared",
] * cleared_layouts
assert [event.kind for event in events[: len(expected_prefix)]] == expected_prefix
clear_counts = [
event.allocation_count
for event in events
if event.kind == "allocations_cleared"
]
assert len(clear_counts) >= cleared_layouts
assert all(count > 0 for count in clear_counts[:cleared_layouts])
def _sleep_shadow(
    frontend_port: int,
    weights_gms: ThreadedGMSServer,
    kv_cache_gms: ThreadedGMSServer,
    shadow: ManagedProcess,
    expected_weights_hash: str | None = None,
) -> tuple[str, int, int]:
    """Verify a freshly started shadow engine, then put it to sleep.

    Steps: send one completion through the frontend, wait up to 30 s for
    the weights GMS to be RO with allocations and a layout hash and the
    KV-cache GMS to be RW with allocations, call ``shadow.sleep()``, check
    that GPU memory use dropped, and wait up to 30 s for the weights GMS to
    be COMMITTED with the same allocations/hash and the KV-cache GMS EMPTY.

    Args:
        frontend_port: Port passed to ``send_completion``.
        weights_gms: GMS server queried for the weights runtime state.
        kv_cache_gms: GMS server queried for the KV-cache runtime state.
        shadow: Engine process to put to sleep.
        expected_weights_hash: When given, the weights layout hash observed
            before sleeping must equal it.

    Returns:
        ``(weights layout hash before sleep, bytes released by sleep,
        GPU memory in use after sleep)``.

    Raises:
        TimeoutError: If either GMS state fails to settle within 30 s.
    """
    completion = send_completion(frontend_port)
    assert completion["choices"], "Shadow inference failed"
    logger.info("Shadow inference OK: %s", completion)

    # Phase 1: wait for the running shadow to own both GMS layouts.
    startup_deadline = time.monotonic() + 30.0
    while True:
        awake_weights = weights_gms.get_runtime_state()
        awake_kv = kv_cache_gms.get_runtime_state()
        if (
            awake_weights.state == ServerState.RO
            and awake_weights.allocation_count > 0
            and awake_weights.memory_layout_hash
            and awake_kv.state == ServerState.RW
            and awake_kv.allocation_count > 0
        ):
            break
        if time.monotonic() > startup_deadline:
            raise TimeoutError("shadow startup did not stabilize GMS state")
        time.sleep(0.1)

    if expected_weights_hash is not None:
        assert awake_weights.memory_layout_hash == expected_weights_hash

    # Phase 2: sleep the engine and measure how much GPU memory it gave back.
    used_before = get_gpu_memory_used()
    sleep_reply = shadow.sleep()
    assert sleep_reply["status"] == "ok"
    used_after = get_gpu_memory_used()
    released = used_before - used_after
    gib = 1 << 30
    mib = 1 << 20
    logger.info(
        "Shadow sleep: %.2f -> %.2f GiB (freed %.0f MB)",
        used_before / gib,
        used_after / gib,
        released / mib,
    )
    assert used_after < used_before
    assert released > 0

    # Phase 3: weights must stay committed and unchanged; KV must be gone.
    sleep_deadline = time.monotonic() + 30.0
    while True:
        asleep_weights = weights_gms.get_runtime_state()
        asleep_kv = kv_cache_gms.get_runtime_state()
        if (
            asleep_weights.state == ServerState.COMMITTED
            and asleep_weights.allocation_count == awake_weights.allocation_count
            and asleep_weights.memory_layout_hash
            == awake_weights.memory_layout_hash
            and asleep_kv.state == ServerState.EMPTY
            and asleep_kv.allocation_count == 0
        ):
            break
        if time.monotonic() > sleep_deadline:
            raise TimeoutError("shadow sleep did not clear GMS state")
        time.sleep(0.1)

    return awake_weights.memory_layout_hash, released, used_after
def _run_shadow_failover_test(
    request,
    ports: dict,
    make_shadow_a: Callable[[], ManagedProcess],
    make_shadow_b: Callable[[], ManagedProcess],
    make_primary: Callable[[], ManagedProcess],
) -> None:
    """Run the shadow-engine failover scenario against two GMS servers.

    Starts a "weights" and a "kv_cache" ThreadedGMSServer on device 0 plus
    a frontend, brings up and sleeps two shadow engines, starts a primary,
    then wakes shadow A in a background thread and checks that the wake
    only completes once the primary has been killed. GMS runtime state,
    GMS event histories and GPU memory usage are asserted along the way.

    Args:
        request: Passed through to ``DynamoFrontendProcess``.
        ports: Port mapping; only ``ports["frontend"]`` is read here.
        make_shadow_a: Factory returning the first shadow engine (used as
            a context manager).
        make_shadow_b: Factory returning the second shadow engine.
        make_primary: Factory returning the primary engine.

    Raises:
        TimeoutError: If a polled GMS state is not reached in time.
        AssertionError: If any intermediate expectation does not hold.
    """
    frontend_port = ports["frontend"]

    with ExitStack() as stack:
        weights_gms = stack.enter_context(ThreadedGMSServer(device=0, tag="weights"))
        kv_cache_gms = stack.enter_context(ThreadedGMSServer(device=0, tag="kv_cache"))
        stack.enter_context(
            DynamoFrontendProcess(
                request,
                frontend_port=frontend_port,
                display_name="frontend",
            )
        )

        # Shadow A: verify it serves a request, then put it to sleep.
        with make_shadow_a() as shadow_a:
            (
                weights_hash,
                shadow_a_released_bytes,
                _shadow_a_memory_after_sleep,
            ) = _sleep_shadow(frontend_port, weights_gms, kv_cache_gms, shadow_a)

            # Shadow B: same procedure, and it must see shadow A's weights
            # layout hash.
            with make_shadow_b() as shadow_b:
                (
                    sleeping_weights_hash,
                    _shadow_b_released_bytes,
                    sleeping_memory_after_sleep,
                ) = _sleep_shadow(
                    frontend_port,
                    weights_gms,
                    kv_cache_gms,
                    shadow_b,
                    expected_weights_hash=weights_hash,
                )
                assert sleeping_weights_hash == weights_hash

                # Weights were published exactly once; the KV history holds
                # one connect/abort/clear cycle per sleeping shadow.
                weights_events_after_shadow_sleep = (
                    weights_gms.get_event_history().events
                )
                _assert_weights_published_once(weights_events_after_shadow_sleep)
                kv_events_after_shadow_sleep = kv_cache_gms.get_event_history().events
                _assert_cleared_rw_layout_prefix(kv_events_after_shadow_sleep, 2)

                with make_primary() as primary:
                    result = send_completion(frontend_port, "Primary test")
                    assert result["choices"], "Primary inference failed"
                    logger.info("Primary inference OK: %s", result)

                    # The running primary must use noticeably more GPU memory
                    # than the two sleeping shadows left in use.
                    primary_memory_in_use = get_gpu_memory_used()
                    logger.info(
                        "Primary active memory: %.2f GiB",
                        primary_memory_in_use / (1 << 30),
                    )
                    assert primary_memory_in_use > sleeping_memory_after_sleep
                    assert (
                        (primary_memory_in_use - sleeping_memory_after_sleep)
                        >= shadow_a_released_bytes * MIN_EXPECTED_MEMORY_RETURN_FRACTION
                    )

                    # Wait (up to 30 s) for the primary to hold the weights
                    # read-only and a live RW KV-cache layout.
                    deadline = time.monotonic() + 30.0
                    while True:
                        weights_with_primary = weights_gms.get_runtime_state()
                        kv_with_primary = kv_cache_gms.get_runtime_state()
                        if (
                            weights_with_primary.state == ServerState.RO
                            and weights_with_primary.ro_session_count >= 1
                            and weights_with_primary.allocation_count > 0
                            and weights_with_primary.memory_layout_hash == weights_hash
                            and kv_with_primary.state == ServerState.RW
                            and kv_with_primary.allocation_count > 0
                        ):
                            break
                        if time.monotonic() > deadline:
                            raise TimeoutError(
                                "primary did not acquire KV cache GMS state"
                            )
                        time.sleep(0.1)

                    # Two cleared shadow layouts followed by the primary's
                    # RW connection, and nothing else.
                    expected_kv_kinds_before_disconnect = [
                        "rw_connected",
                        "rw_aborted",
                        "allocations_cleared",
                        "rw_connected",
                        "rw_aborted",
                        "allocations_cleared",
                        "rw_connected",
                    ]
                    assert [
                        event.kind for event in kv_cache_gms.get_event_history().events
                    ] == expected_kv_kinds_before_disconnect

                    with ThreadPoolExecutor(max_workers=1) as executor:
                        # Shadow A wakes while Shadow B remains asleep. After we
                        # force-disconnect the primary from GMS, Shadow A should enter
                        # a new RW layout but block on real CUDA OOM until the primary dies.
                        # NOTE(review): 180 is presumably a timeout in seconds
                        # for wake(); confirm against the harness.
                        wake_future = executor.submit(shadow_a.wake, 180)

                        # Give the wake 10 s to (incorrectly) finish while the
                        # primary still holds the RW KV layout.
                        deadline = time.monotonic() + 10.0
                        while time.monotonic() < deadline:
                            if wake_future.done():
                                break
                            time.sleep(0.2)
                        assert not wake_future.done(), (
                            "Shadow wake completed before the primary died; "
                            "KV cache RW handoff did not block as expected"
                        )

                        kv_while_blocked = kv_cache_gms.get_runtime_state()
                        assert kv_while_blocked.state == ServerState.RW
                        assert kv_while_blocked.allocation_count > 0

                        # Drop the primary's RW session on the GMS side; the
                        # primary process itself keeps running.
                        kv_cache_gms.disconnect_rw_session()
                        expected_kv_kinds_while_blocked = (
                            expected_kv_kinds_before_disconnect
                            + [
                                "rw_aborted",
                                "allocations_cleared",
                                "rw_connected",
                                "allocation_oom",
                            ]
                        )

                        # Poll (up to 30 s) until the history shows shadow A
                        # connected and hit allocation_oom, with fewer
                        # allocations than the primary had and a count that
                        # matches the one recorded on the last event.
                        blocked_allocation_count: int | None = None
                        deadline = time.monotonic() + 30.0
                        while time.monotonic() < deadline:
                            kv_after_forced_disconnect = (
                                kv_cache_gms.get_runtime_state()
                            )
                            kv_events_after_forced_disconnect = (
                                kv_cache_gms.get_event_history().events
                            )
                            if (
                                kv_after_forced_disconnect.state == ServerState.RW
                                and [
                                    event.kind
                                    for event in kv_events_after_forced_disconnect
                                ]
                                == expected_kv_kinds_while_blocked
                                and not wake_future.done()
                            ):
                                blocked_allocation_count = (
                                    kv_after_forced_disconnect.allocation_count
                                )
                                if (
                                    blocked_allocation_count
                                    < kv_while_blocked.allocation_count
                                    and blocked_allocation_count
                                    == kv_events_after_forced_disconnect[
                                        -1
                                    ].allocation_count
                                ):
                                    break
                            time.sleep(0.2)
                        else:
                            # Loop ran out of time without hitting `break`.
                            raise TimeoutError(
                                "shadow never entered a new KV-cache layout blocked on allocation"
                            )

                        assert blocked_allocation_count is not None

                        # For 3 s nothing may change: same state, same
                        # allocation count, same history, primary alive and
                        # wake still pending.
                        linger_deadline = time.monotonic() + 3.0
                        while time.monotonic() < linger_deadline:
                            kv_while_lingering = kv_cache_gms.get_runtime_state()
                            kv_events_while_lingering = (
                                kv_cache_gms.get_event_history().events
                            )
                            assert kv_while_lingering.state == ServerState.RW
                            assert (
                                kv_while_lingering.allocation_count
                                == blocked_allocation_count
                            )
                            assert [
                                event.kind for event in kv_events_while_lingering
                            ] == expected_kv_kinds_while_blocked
                            assert _is_process_alive(
                                primary
                            ), "primary died before the linger window completed"
                            assert (
                                not wake_future.done()
                            ), "shadow wake completed while the primary was still alive"
                            time.sleep(0.2)

                        # Kill the primary; the memory readings are only
                        # logged, not asserted.
                        primary_memory_before_kill = get_gpu_memory_used()
                        _kill_process_group(primary)
                        primary_memory_after_kill = get_gpu_memory_used()
                        logger.info(
                            "Primary kill snapshot: %.2f -> %.2f GiB",
                            primary_memory_before_kill / (1 << 30),
                            primary_memory_after_kill / (1 << 30),
                        )

                        # Wait (up to 30 s) for a live RW KV layout with
                        # allocations, then collect the wake result.
                        deadline = time.monotonic() + 30.0
                        while time.monotonic() < deadline:
                            kv_after_primary_kill = kv_cache_gms.get_runtime_state()
                            if (
                                kv_after_primary_kill.state == ServerState.RW
                                and kv_after_primary_kill.allocation_count > 0
                            ):
                                break
                            time.sleep(0.2)
                        else:
                            raise TimeoutError(
                                "shadow did not reacquire KV cache after failover"
                            )

                        wake_result = wake_future.result(timeout=180)
                        assert wake_result["status"] == "ok"

                    # The woken shadow must have taken back at least the
                    # expected fraction of what shadow A released on sleep.
                    shadow_memory_after_wake = get_gpu_memory_used()
                    shadow_reacquired_bytes = (
                        shadow_memory_after_wake - sleeping_memory_after_sleep
                    )
                    logger.info(
                        "Shadow wake memory: %.2f GiB (reacquired %.0f MB)",
                        shadow_memory_after_wake / (1 << 30),
                        shadow_reacquired_bytes / (1 << 20),
                    )
                    assert shadow_memory_after_wake > sleeping_memory_after_sleep
                    assert (
                        shadow_reacquired_bytes
                    ) >= shadow_a_released_bytes * MIN_EXPECTED_MEMORY_RETURN_FRACTION

                    # Once the primary is gone, the failover shadow should finish wake
                    # with the same committed weights layout and a new live RW KV-cache layout.
                    deadline = time.monotonic() + 30.0
                    while True:
                        weights_after_wake = weights_gms.get_runtime_state()
                        kv_after_wake = kv_cache_gms.get_runtime_state()
                        if (
                            weights_after_wake.state == ServerState.RO
                            and weights_after_wake.ro_session_count >= 1
                            and weights_after_wake.allocation_count > 0
                            and weights_after_wake.memory_layout_hash
                            == weights_with_primary.memory_layout_hash
                            and kv_after_wake.state == ServerState.RW
                            and kv_after_wake.allocation_count > 0
                        ):
                            break
                        if time.monotonic() > deadline:
                            raise TimeoutError(
                                "shadow wake did not restore the expected GMS state"
                            )
                        time.sleep(0.1)

                    # The final KV history should show the full handoff:
                    # shadow A slept -> shadow B slept -> primary layout ->
                    # primary abort/clear -> shadow A reconnects -> shadow A sees OOM.
                    weights_events_after_wake = weights_gms.get_event_history().events
                    _assert_weights_published_once(weights_events_after_wake)
                    kv_events_after_wake = kv_cache_gms.get_event_history().events
                    _assert_cleared_rw_layout_prefix(kv_events_after_wake, 3)
                    assert [event.kind for event in kv_events_after_wake] == [
                        "rw_connected",
                        "rw_aborted",
                        "allocations_cleared",
                        "rw_connected",
                        "rw_aborted",
                        "allocations_cleared",
                        "rw_connected",
                        "rw_aborted",
                        "allocations_cleared",
                        "rw_connected",
                        "allocation_oom",
                    ]

                    # Final end-to-end check through the frontend.
                    result = send_completion(frontend_port, "Post failover")
                    assert result["choices"], "Shadow inference after failover failed"
                    logger.info("Shadow inference after failover OK: %s", result)
@pytest.mark.vllm
@pytest.mark.e2e
@pytest.mark.gpu_1
@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME)
@pytest.mark.timeout(600)
def test_gms_shadow_engine_failover_vllm(
    request, runtime_services_dynamic_ports, gms_ports, predownload_models
):
    """Drive `_run_shadow_failover_test` with three vLLM engine factories.

    Shadow A, shadow B and the primary each get their own
    system/KV-event/NIXL ports from ``gms_ports`` and share the frontend
    port. ``runtime_services_dynamic_ports`` and ``predownload_models``
    are not used in the body; presumably they are requested only for
    their fixture setup.
    """
    ports = gms_ports

    def build_engine(display_name, port_prefix):
        # Ports are looked up at call time, exactly as a lambda would do.
        return VLLMWithGMSProcess(
            request,
            display_name,
            ports[port_prefix + "_system"],
            ports[port_prefix + "_kv_event"],
            ports[port_prefix + "_nixl"],
            ports["frontend"],
        )

    def build_shadow_a():
        return build_engine("shadow-a", "shadow")

    def build_shadow_b():
        return build_engine("shadow-b", "shadow2")

    def build_primary():
        return build_engine("primary", "primary")

    _run_shadow_failover_test(
        request,
        ports,
        make_shadow_a=build_shadow_a,
        make_shadow_b=build_shadow_b,
        make_primary=build_primary,
    )
@pytest.mark.sglang
@pytest.mark.e2e
@pytest.mark.gpu_1
@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME)
@pytest.mark.timeout(600)
def test_gms_shadow_engine_failover_sglang(
    request, runtime_services_dynamic_ports, gms_ports, predownload_models
):
    """Drive `_run_shadow_failover_test` with three SGLang engine factories.

    Shadow A, shadow B and the primary each get their own system/SGLang
    ports from ``gms_ports`` and share the frontend port.
    ``runtime_services_dynamic_ports`` and ``predownload_models`` are not
    used in the body; presumably they are requested only for their
    fixture setup.
    """
    ports = gms_ports

    def build_engine(display_name, port_prefix):
        # Ports are looked up at call time, exactly as a lambda would do.
        return SGLangWithGMSProcess(
            request,
            display_name,
            ports[port_prefix + "_system"],
            ports[port_prefix + "_sglang"],
            ports["frontend"],
        )

    def build_shadow_a():
        return build_engine("shadow-a", "shadow")

    def build_shadow_b():
        return build_engine("shadow-b", "shadow2")

    def build_primary():
        return build_engine("primary", "primary")

    _run_shadow_failover_test(
        request,
        ports,
        make_shadow_a=build_shadow_a,
        make_shadow_b=build_shadow_b,
        make_primary=build_primary,
    )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment