Unverified Commit b78ec99a authored by Schwinn Saereesitthipitak's avatar Schwinn Saereesitthipitak Committed by GitHub
Browse files

test: reorganize and prune GPU Memory Service tests (#7828)


Signed-off-by: default avatarSchwinn Saereesitthipitak <schwinns@nvidia.com>
parent 11bb8498
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
"""GMS runtime flow coverage.
This module exercises lock handoff, committed-layout publication and remap,
reader/writer state transitions, and allocation retry behavior against one
in-process GMS server.
"""
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
...@@ -10,9 +17,9 @@ import signal ...@@ -10,9 +17,9 @@ import signal
import socket import socket
import subprocess import subprocess
import sys import sys
import textwrap
import threading import threading
import time import time
from concurrent.futures import TimeoutError as FutureTimeoutError
import pynvml import pynvml
import pytest import pytest
...@@ -23,7 +30,6 @@ from gpu_memory_service.client.memory_manager import ( ...@@ -23,7 +30,6 @@ from gpu_memory_service.client.memory_manager import (
) )
from gpu_memory_service.client.rpc import _GMSRPCTransport from gpu_memory_service.client.rpc import _GMSRPCTransport
from gpu_memory_service.client.session import _GMSClientSession from gpu_memory_service.client.session import _GMSClientSession
from gpu_memory_service.common import cuda_utils
from gpu_memory_service.common.locks import GrantedLockType, RequestedLockType from gpu_memory_service.common.locks import GrantedLockType, RequestedLockType
from gpu_memory_service.common.protocol.messages import ( from gpu_memory_service.common.protocol.messages import (
GetEventHistoryRequest, GetEventHistoryRequest,
...@@ -36,15 +42,27 @@ from gpu_memory_service.server.allocations import GMSAllocationManager ...@@ -36,15 +42,27 @@ from gpu_memory_service.server.allocations import GMSAllocationManager
from gpu_memory_service.server.fsm import ServerState from gpu_memory_service.server.fsm import ServerState
from gpu_memory_service.server.rpc import GMSRPCServer from gpu_memory_service.server.rpc import GMSRPCServer
from tests.gms.harness.gms import ServerThread
pytestmark = [ pytestmark = [
pytest.mark.pre_merge, pytest.mark.pre_merge,
pytest.mark.unit, pytest.mark.integration,
pytest.mark.none, pytest.mark.none,
pytest.mark.gpu_1, pytest.mark.gpu_1,
] ]
_SOCKET_TEST_TIMEOUT_SECONDS = 60
_DEFAULT_WAIT_TIMEOUT_SECONDS = 2.0
_SERVER_START_TIMEOUT_SECONDS = 5.0
_SERVER_STOP_TIMEOUT_SECONDS = 5.0
_RW_DISCONNECT_TIMEOUT_SECONDS = 5.0
_BLOCKED_WRITER_JOIN_TIMEOUT_SECONDS = 2.0
_ALLOCATION_RETRY_INTERVAL_SECONDS = 0.1
_ALLOCATION_RETRY_TIMEOUT_SECONDS = 120.0
_EXPORT_HOLDER_READY_TIMEOUT_SECONDS = 30.0
_ALLOCATION_BLOCK_ASSERTION_SECONDS = 5.0
_GPU_MEMORY_RECOVERY_TIMEOUT_SECONDS = 30.0
_FAST_POLL_INTERVAL_SECONDS = 0.01
_SLOW_POLL_INTERVAL_SECONDS = 0.1
def _gpu_memory_free_bytes(device: int = 0) -> int: def _gpu_memory_free_bytes(device: int = 0) -> int:
pynvml.nvmlInit() pynvml.nvmlInit()
...@@ -71,41 +89,136 @@ def _drop_connection(session: _GMSClientSession) -> None: ...@@ -71,41 +89,136 @@ def _drop_connection(session: _GMSClientSession) -> None:
def _wait_for_server_state( def _wait_for_server_state(
server: GMSRPCServer, server: GMSRPCServer,
expected: ServerState, expected: ServerState,
timeout: float = 2.0, timeout: float = _DEFAULT_WAIT_TIMEOUT_SECONDS,
) -> None: ) -> None:
deadline = time.monotonic() + timeout deadline = time.monotonic() + timeout
while server.state != expected: while server.state != expected:
if time.monotonic() > deadline: if time.monotonic() > deadline:
raise TimeoutError(f"server did not reach {expected.name}") raise TimeoutError(f"server did not reach {expected.name}")
time.sleep(0.01) time.sleep(_FAST_POLL_INTERVAL_SECONDS)
def _wait_for_waiting_writers( def _wait_for_waiting_writers(
server: GMSRPCServer, server: GMSRPCServer,
expected: int, expected: int,
timeout: float = 2.0, timeout: float = _DEFAULT_WAIT_TIMEOUT_SECONDS,
) -> None: ) -> None:
deadline = time.monotonic() + timeout deadline = time.monotonic() + timeout
while server._gms._sessions.snapshot().waiting_writers != expected: while server._gms._sessions.snapshot().waiting_writers != expected:
if time.monotonic() > deadline: if time.monotonic() > deadline:
raise TimeoutError(f"waiting writers did not reach {expected}") raise TimeoutError(f"waiting writers did not reach {expected}")
time.sleep(0.01) time.sleep(_FAST_POLL_INTERVAL_SECONDS)
def _wait_for_ro_session_count( def _wait_for_ro_session_count(
server: GMSRPCServer, server: GMSRPCServer,
expected: int, expected: int,
timeout: float = 2.0, timeout: float = _DEFAULT_WAIT_TIMEOUT_SECONDS,
) -> None: ) -> None:
deadline = time.monotonic() + timeout deadline = time.monotonic() + timeout
while server._gms._sessions.snapshot().ro_session_count != expected: while server._gms._sessions.snapshot().ro_session_count != expected:
if time.monotonic() > deadline: if time.monotonic() > deadline:
raise TimeoutError(f"RO session count did not reach {expected}") raise TimeoutError(f"RO session count did not reach {expected}")
time.sleep(0.01) time.sleep(_FAST_POLL_INTERVAL_SECONDS)
class _WhiteBoxServerThread:
"""Threaded in-process server helper."""
def __init__(self, server, socket_path: str):
self.server = server
self.socket_path = socket_path
self._loop: asyncio.AbstractEventLoop | None = None
self._task: asyncio.Task[None] | None = None
self._thread = threading.Thread(target=self._run, daemon=True)
self._exception: BaseException | None = None
def _run(self) -> None:
loop = asyncio.new_event_loop()
self._loop = loop
asyncio.set_event_loop(loop)
self._task = loop.create_task(self.server.serve())
try:
loop.run_until_complete(self._task)
except asyncio.CancelledError:
pass
except BaseException as exc:
self._exception = exc
finally:
pending = [task for task in asyncio.all_tasks(loop) if not task.done()]
for task in pending:
task.cancel()
if pending:
loop.run_until_complete(
asyncio.gather(*pending, return_exceptions=True)
)
loop.close()
def start(self) -> None:
self._thread.start()
deadline = time.monotonic() + _SERVER_START_TIMEOUT_SECONDS
last_probe_error: Exception | None = None
while True:
if self._exception is not None:
raise self._exception
if self.server._server is not None and os.path.exists(self.socket_path):
try:
self.server._gms.get_runtime_state()
return
except Exception as exc:
last_probe_error = exc
if time.monotonic() > deadline:
timeout_error = TimeoutError(
f"GMS socket did not appear at {self.socket_path}"
)
if last_probe_error is not None:
raise timeout_error from last_probe_error
raise timeout_error
time.sleep(_FAST_POLL_INTERVAL_SECONDS)
def stop(self) -> None:
if self._loop is not None:
def cancel() -> None:
if self.server._server is not None:
self.server._server.close()
if self._task is not None:
self._task.cancel()
self._loop.call_soon_threadsafe(cancel)
self._thread.join(timeout=_SERVER_STOP_TIMEOUT_SECONDS)
if self._thread.is_alive():
raise RuntimeError(
f"GMS server thread failed to stop for {self.socket_path}"
)
if self._exception is not None:
raise self._exception
if os.path.exists(self.socket_path):
os.unlink(self.socket_path)
def disconnect_rw_session(
self,
timeout: float = _RW_DISCONNECT_TIMEOUT_SECONDS,
) -> None:
if self._loop is None:
raise RuntimeError("GMS server thread is not running")
future = asyncio.run_coroutine_threadsafe(
self._disconnect_rw_session(), self._loop
)
try:
future.result(timeout=timeout)
except FutureTimeoutError as exc:
raise TimeoutError("Timed out disconnecting RW session") from exc
async def _disconnect_rw_session(self) -> None:
conn = self.server._gms._sessions._locking.rw_conn
if conn is None:
raise RuntimeError("No active RW session to disconnect")
await self.server._gms.cleanup_connection(conn)
@pytest.fixture @pytest.fixture
def real_gms(monkeypatch, tmp_path): def running_gms(monkeypatch, tmp_path):
server_handles = itertools.count(1000) server_handles = itertools.count(1000)
client_handles = itertools.count(10000) client_handles = itertools.count(10000)
next_va = itertools.count(0x100000, 0x10000) next_va = itertools.count(0x100000, 0x10000)
...@@ -175,7 +288,7 @@ def real_gms(monkeypatch, tmp_path): ...@@ -175,7 +288,7 @@ def real_gms(monkeypatch, tmp_path):
socket_path = str(tmp_path / "gms.sock") socket_path = str(tmp_path / "gms.sock")
server = GMSRPCServer(socket_path, device=0, allocation_retry_interval=0.01) server = GMSRPCServer(socket_path, device=0, allocation_retry_interval=0.01)
thread = ServerThread(server, socket_path) thread = _WhiteBoxServerThread(server, socket_path)
thread.start() thread.start()
try: try:
yield server, socket_path yield server, socket_path
...@@ -183,33 +296,38 @@ def real_gms(monkeypatch, tmp_path): ...@@ -183,33 +296,38 @@ def real_gms(monkeypatch, tmp_path):
thread.stop() thread.stop()
def test_rw_commit_publishes_allocations_metadata_and_layout_hash(real_gms): @pytest.mark.timeout(_SOCKET_TEST_TIMEOUT_SECONDS)
server, socket_path = real_gms def test_rw_commit_publishes_allocations_metadata_and_layout_hash(running_gms):
server, socket_path = running_gms
writer = GMSClientMemoryManager(socket_path, device=0) writer = GMSClientMemoryManager(socket_path, device=0)
writer.connect(RequestedLockType.RW)
va = writer.create_mapping(size=4096, tag="weights")
allocation_id = writer.mappings[va].allocation_id
writer.metadata_put("tensor.0", allocation_id, 0, b"weights")
assert writer.commit()
reader = _GMSClientSession(socket_path, RequestedLockType.RO, None)
try: try:
assert reader.lock_type == GrantedLockType.RO writer.connect(RequestedLockType.RW)
assert reader.committed va = writer.create_mapping(size=4096, tag="weights")
assert len(reader.list_allocations()) == 1 allocation_id = writer.mappings[va].allocation_id
assert reader.metadata_get("tensor.0") == (allocation_id, 0, b"weights") writer.metadata_put("tensor.0", allocation_id, 0, b"weights")
assert reader.get_memory_layout_hash() assert writer.commit()
finally:
reader.close()
assert writer.is_unmapped reader = _GMSClientSession(socket_path, RequestedLockType.RO, None)
assert not writer.is_connected try:
_wait_for_server_state(server, ServerState.COMMITTED) assert reader.lock_type == GrantedLockType.RO
assert reader.committed
assert len(reader.list_allocations()) == 1
assert reader.metadata_get("tensor.0") == (allocation_id, 0, b"weights")
assert reader.get_memory_layout_hash()
finally:
reader.close()
assert writer.is_unmapped
assert not writer.is_connected
_wait_for_server_state(server, ServerState.COMMITTED)
finally:
writer.close()
def test_rw_disconnect_aborts_layout_and_next_writer_starts_clean(real_gms): @pytest.mark.timeout(_SOCKET_TEST_TIMEOUT_SECONDS)
server, socket_path = real_gms def test_rw_disconnect_aborts_layout_and_next_writer_starts_clean(running_gms):
server, socket_path = running_gms
writer = _GMSClientSession(socket_path, RequestedLockType.RW, None) writer = _GMSClientSession(socket_path, RequestedLockType.RW, None)
allocation_id, _ = writer.allocate(4096, "weights") allocation_id, _ = writer.allocate(4096, "weights")
...@@ -226,8 +344,9 @@ def test_rw_disconnect_aborts_layout_and_next_writer_starts_clean(real_gms): ...@@ -226,8 +344,9 @@ def test_rw_disconnect_aborts_layout_and_next_writer_starts_clean(real_gms):
next_writer.close() next_writer.close()
def test_rw_or_ro_grants_rw_from_empty_and_ro_from_committed(real_gms): @pytest.mark.timeout(_SOCKET_TEST_TIMEOUT_SECONDS)
server, socket_path = real_gms def test_rw_or_ro_grants_rw_from_empty_and_ro_from_committed(running_gms):
server, socket_path = running_gms
session = _GMSClientSession(socket_path, RequestedLockType.RW_OR_RO, 100) session = _GMSClientSession(socket_path, RequestedLockType.RW_OR_RO, 100)
assert session.lock_type == GrantedLockType.RW assert session.lock_type == GrantedLockType.RW
...@@ -243,116 +362,134 @@ def test_rw_or_ro_grants_rw_from_empty_and_ro_from_committed(real_gms): ...@@ -243,116 +362,134 @@ def test_rw_or_ro_grants_rw_from_empty_and_ro_from_committed(real_gms):
session.close() session.close()
def test_runtime_state_and_event_history_are_side_effect_free(real_gms): @pytest.mark.timeout(_SOCKET_TEST_TIMEOUT_SECONDS)
server, socket_path = real_gms def test_runtime_state_and_event_history_are_side_effect_free(running_gms):
server, socket_path = running_gms
writer = GMSClientMemoryManager(socket_path, device=0) writer = GMSClientMemoryManager(socket_path, device=0)
writer.connect(RequestedLockType.RW) try:
writer.create_mapping(size=4096, tag="weights") writer.connect(RequestedLockType.RW)
assert writer.commit() writer.create_mapping(size=4096, tag="weights")
assert writer.commit()
assert server._gms._sessions.snapshot().ro_session_count == 0 assert server._gms._sessions.snapshot().ro_session_count == 0
with _GMSRPCTransport(socket_path) as transport: with _GMSRPCTransport(socket_path) as transport:
transport.connect() transport.connect()
state = transport.request( state = transport.request(
GetRuntimeStateRequest(), GetRuntimeStateRequest(),
GetRuntimeStateResponse, GetRuntimeStateResponse,
) )
with _GMSRPCTransport(socket_path) as transport: with _GMSRPCTransport(socket_path) as transport:
transport.connect() transport.connect()
history = transport.request( history = transport.request(
GetEventHistoryRequest(), GetEventHistoryRequest(),
GetEventHistoryResponse, GetEventHistoryResponse,
) )
assert state.state == ServerState.COMMITTED.name assert state.state == ServerState.COMMITTED.name
assert state.committed assert state.committed
assert state.is_ready assert state.is_ready
assert state.ro_session_count == 0 assert state.ro_session_count == 0
assert state.waiting_writers == 0 assert state.waiting_writers == 0
assert state.allocation_count == 1 assert state.allocation_count == 1
assert state.memory_layout_hash assert state.memory_layout_hash
assert [event.kind for event in history.events] == ["rw_connected", "committed"] assert [event.kind for event in history.events] == ["rw_connected", "committed"]
assert server._gms._sessions.snapshot().ro_session_count == 0 assert server._gms._sessions.snapshot().ro_session_count == 0
finally:
writer.close()
def test_committed_layout_is_replaced_when_new_writer_connects(real_gms): @pytest.mark.timeout(_SOCKET_TEST_TIMEOUT_SECONDS)
server, socket_path = real_gms def test_committed_layout_is_replaced_when_new_writer_connects(running_gms):
server, socket_path = running_gms
first_writer = GMSClientMemoryManager(socket_path, device=0) first_writer = GMSClientMemoryManager(socket_path, device=0)
first_writer.connect(RequestedLockType.RW) try:
first_writer.create_mapping(size=4096, tag="weights") first_writer.connect(RequestedLockType.RW)
assert first_writer.commit() first_writer.create_mapping(size=4096, tag="weights")
assert first_writer.commit()
_wait_for_server_state(server, ServerState.COMMITTED) _wait_for_server_state(server, ServerState.COMMITTED)
assert server._gms.allocation_count == 1 assert server._gms.allocation_count == 1
second_writer = _GMSClientSession(socket_path, RequestedLockType.RW, None) second_writer = _GMSClientSession(socket_path, RequestedLockType.RW, None)
try: try:
assert second_writer.lock_type == GrantedLockType.RW assert second_writer.lock_type == GrantedLockType.RW
assert second_writer.list_allocations() == [] assert second_writer.list_allocations() == []
assert second_writer.metadata_list() == [] assert second_writer.metadata_list() == []
assert server._gms.allocation_count == 0 assert server._gms.allocation_count == 0
assert server.state == ServerState.RW assert server.state == ServerState.RW
assert not server._gms.committed assert not server._gms.committed
finally:
second_writer.close()
finally: finally:
second_writer.close() first_writer.close()
def test_reader_mapping_disconnect_then_next_writer_clears_old_layout(real_gms): @pytest.mark.timeout(_SOCKET_TEST_TIMEOUT_SECONDS)
server, socket_path = real_gms def test_reader_mapping_disconnect_then_next_writer_clears_old_layout(
running_gms,
):
server, socket_path = running_gms
writer = GMSClientMemoryManager(socket_path, device=0) writer = GMSClientMemoryManager(socket_path, device=0)
writer.connect(RequestedLockType.RW)
va = writer.create_mapping(size=4096, tag="weights")
allocation_id = writer.mappings[va].allocation_id
assert writer.commit()
reader = GMSClientMemoryManager(socket_path, device=0) reader = GMSClientMemoryManager(socket_path, device=0)
reader.connect(RequestedLockType.RO) try:
imported_va = reader.create_mapping(allocation_id=allocation_id) writer.connect(RequestedLockType.RW)
assert reader.mappings[imported_va].handle != 0 va = writer.create_mapping(size=4096, tag="weights")
allocation_id = writer.mappings[va].allocation_id
next_writer_result: dict[str, object] = {} assert writer.commit()
def open_writer() -> None:
try:
next_writer_result["session"] = _GMSClientSession(
socket_path,
RequestedLockType.RW,
500,
)
except Exception as exc:
next_writer_result["error"] = exc
thread = threading.Thread(target=open_writer)
thread.start()
_wait_for_waiting_writers(server, 1)
assert thread.is_alive() reader.connect(RequestedLockType.RO)
assert server.state == ServerState.RO imported_va = reader.create_mapping(allocation_id=allocation_id)
assert server._gms.allocation_count == 1 assert reader.mappings[imported_va].handle != 0
reader.unmap_all_vas() next_writer_result: dict[str, object] = {}
reader.abort()
thread.join(timeout=2)
next_writer = next_writer_result.get("session") def open_writer() -> None:
assert isinstance(next_writer, _GMSClientSession) try:
try: next_writer_result["session"] = _GMSClientSession(
assert next_writer.lock_type == GrantedLockType.RW socket_path,
assert next_writer.list_allocations() == [] RequestedLockType.RW,
assert server._gms.allocation_count == 0 500,
assert server.state == ServerState.RW )
assert not server._gms.committed except Exception as exc:
next_writer_result["error"] = exc
thread = threading.Thread(target=open_writer)
thread.start()
_wait_for_waiting_writers(server, 1)
assert thread.is_alive()
assert server.state == ServerState.RO
assert server._gms.allocation_count == 1
reader.unmap_all_vas()
reader.abort()
thread.join(timeout=_BLOCKED_WRITER_JOIN_TIMEOUT_SECONDS)
next_writer = next_writer_result.get("session")
assert isinstance(next_writer, _GMSClientSession)
try:
assert next_writer.lock_type == GrantedLockType.RW
assert next_writer.list_allocations() == []
assert server._gms.allocation_count == 0
assert server.state == ServerState.RW
assert not server._gms.committed
finally:
next_writer.close()
finally: finally:
next_writer.close() reader.close()
writer.close()
def test_waiting_writer_blocks_new_readers_until_last_reader_disconnects(real_gms): @pytest.mark.timeout(_SOCKET_TEST_TIMEOUT_SECONDS)
server, socket_path = real_gms def test_waiting_writer_blocks_new_readers_until_last_reader_disconnects(
running_gms,
):
server, socket_path = running_gms
writer = _GMSClientSession(socket_path, RequestedLockType.RW, None) writer = _GMSClientSession(socket_path, RequestedLockType.RW, None)
writer.commit() writer.commit()
...@@ -378,7 +515,7 @@ def test_waiting_writer_blocks_new_readers_until_last_reader_disconnects(real_gm ...@@ -378,7 +515,7 @@ def test_waiting_writer_blocks_new_readers_until_last_reader_disconnects(real_gm
_GMSClientSession(socket_path, RequestedLockType.RO, 100) _GMSClientSession(socket_path, RequestedLockType.RO, 100)
reader.close() reader.close()
thread.join(timeout=2) thread.join(timeout=_BLOCKED_WRITER_JOIN_TIMEOUT_SECONDS)
waiting_writer = writer_result.get("session") waiting_writer = writer_result.get("session")
assert isinstance(waiting_writer, _GMSClientSession) assert isinstance(waiting_writer, _GMSClientSession)
...@@ -388,8 +525,9 @@ def test_waiting_writer_blocks_new_readers_until_last_reader_disconnects(real_gm ...@@ -388,8 +525,9 @@ def test_waiting_writer_blocks_new_readers_until_last_reader_disconnects(real_gm
waiting_writer.close() waiting_writer.close()
def test_rw_or_ro_times_out_while_writer_waits_behind_reader(real_gms): @pytest.mark.timeout(_SOCKET_TEST_TIMEOUT_SECONDS)
server, socket_path = real_gms def test_rw_or_ro_times_out_while_writer_waits_behind_reader(running_gms):
server, socket_path = running_gms
writer = _GMSClientSession(socket_path, RequestedLockType.RW, None) writer = _GMSClientSession(socket_path, RequestedLockType.RW, None)
writer.commit() writer.commit()
...@@ -415,14 +553,16 @@ def test_rw_or_ro_times_out_while_writer_waits_behind_reader(real_gms): ...@@ -415,14 +553,16 @@ def test_rw_or_ro_times_out_while_writer_waits_behind_reader(real_gms):
_GMSClientSession(socket_path, RequestedLockType.RW_OR_RO, 100) _GMSClientSession(socket_path, RequestedLockType.RW_OR_RO, 100)
reader.close() reader.close()
thread.join(timeout=2) thread.join(timeout=_BLOCKED_WRITER_JOIN_TIMEOUT_SECONDS)
granted_writer = waiting_writer.get("session") granted_writer = waiting_writer.get("session")
assert isinstance(granted_writer, _GMSClientSession) assert isinstance(granted_writer, _GMSClientSession)
granted_writer.close() granted_writer.close()
def test_reader_can_acquire_after_waiting_writer_times_out(real_gms): @pytest.mark.timeout(_SOCKET_TEST_TIMEOUT_SECONDS)
server, socket_path = real_gms def test_reader_can_acquire_after_waiting_writer_times_out(running_gms):
server, socket_path = running_gms
writer = _GMSClientSession(socket_path, RequestedLockType.RW, None) writer = _GMSClientSession(socket_path, RequestedLockType.RW, None)
writer.commit() writer.commit()
...@@ -439,7 +579,7 @@ def test_reader_can_acquire_after_waiting_writer_times_out(real_gms): ...@@ -439,7 +579,7 @@ def test_reader_can_acquire_after_waiting_writer_times_out(real_gms):
thread = threading.Thread(target=timeout_writer) thread = threading.Thread(target=timeout_writer)
thread.start() thread.start()
_wait_for_waiting_writers(server, 1) _wait_for_waiting_writers(server, 1)
thread.join(timeout=2) thread.join(timeout=_BLOCKED_WRITER_JOIN_TIMEOUT_SECONDS)
assert isinstance(writer_result["error"], TimeoutError) assert isinstance(writer_result["error"], TimeoutError)
_wait_for_waiting_writers(server, 0) _wait_for_waiting_writers(server, 0)
...@@ -452,8 +592,11 @@ def test_reader_can_acquire_after_waiting_writer_times_out(real_gms): ...@@ -452,8 +592,11 @@ def test_reader_can_acquire_after_waiting_writer_times_out(real_gms):
reader.close() reader.close()
def test_multiple_readers_hold_committed_state_until_last_disconnect(real_gms): @pytest.mark.timeout(_SOCKET_TEST_TIMEOUT_SECONDS)
server, socket_path = real_gms def test_multiple_readers_hold_committed_state_until_last_disconnect(
running_gms,
):
server, socket_path = running_gms
writer = _GMSClientSession(socket_path, RequestedLockType.RW, None) writer = _GMSClientSession(socket_path, RequestedLockType.RW, None)
writer.commit() writer.commit()
...@@ -472,8 +615,9 @@ def test_multiple_readers_hold_committed_state_until_last_disconnect(real_gms): ...@@ -472,8 +615,9 @@ def test_multiple_readers_hold_committed_state_until_last_disconnect(real_gms):
_wait_for_server_state(server, ServerState.COMMITTED) _wait_for_server_state(server, ServerState.COMMITTED)
def test_ro_session_rejects_rw_only_requests(real_gms): @pytest.mark.timeout(_SOCKET_TEST_TIMEOUT_SECONDS)
_, socket_path = real_gms def test_ro_session_rejects_rw_only_requests(running_gms):
_, socket_path = running_gms
writer = _GMSClientSession(socket_path, RequestedLockType.RW, None) writer = _GMSClientSession(socket_path, RequestedLockType.RW, None)
writer.commit() writer.commit()
...@@ -488,8 +632,11 @@ def test_ro_session_rejects_rw_only_requests(real_gms): ...@@ -488,8 +632,11 @@ def test_ro_session_rejects_rw_only_requests(real_gms):
reader.close() reader.close()
def test_lock_and_allocation_state_requests_reflect_real_server_state(real_gms): @pytest.mark.timeout(_SOCKET_TEST_TIMEOUT_SECONDS)
_, socket_path = real_gms def test_lock_and_allocation_state_requests_reflect_real_server_state(
running_gms,
):
_, socket_path = running_gms
writer = _GMSClientSession(socket_path, RequestedLockType.RW, None) writer = _GMSClientSession(socket_path, RequestedLockType.RW, None)
allocation_id, _ = writer.allocate(4096, "weights") allocation_id, _ = writer.allocate(4096, "weights")
...@@ -517,120 +664,152 @@ def test_lock_and_allocation_state_requests_reflect_real_server_state(real_gms): ...@@ -517,120 +664,152 @@ def test_lock_and_allocation_state_requests_reflect_real_server_state(real_gms):
reader.close() reader.close()
def test_invalid_metadata_offset_is_rejected_without_mutating_state(real_gms): @pytest.mark.timeout(_SOCKET_TEST_TIMEOUT_SECONDS)
_, socket_path = real_gms def test_invalid_metadata_offset_is_rejected_without_mutating_state(
running_gms,
):
_, socket_path = running_gms
writer = _GMSClientSession(socket_path, RequestedLockType.RW, None) writer = _GMSClientSession(socket_path, RequestedLockType.RW, None)
try: try:
allocation_id, _ = writer.allocate(4096, "weights") allocation_id, aligned_size = writer.allocate(4096, "weights")
with pytest.raises(RuntimeError, match="out of range"): with pytest.raises(RuntimeError, match="out of range"):
writer.metadata_put("tensor.bad", allocation_id, 4096, b"x") writer.metadata_put("tensor.bad", allocation_id, aligned_size, b"x")
assert writer.metadata_list() == [] assert writer.metadata_list() == []
finally: finally:
writer.close() writer.close()
def test_destroy_mapping_frees_allocation_and_metadata(real_gms): @pytest.mark.timeout(_SOCKET_TEST_TIMEOUT_SECONDS)
_, socket_path = real_gms def test_destroy_mapping_frees_allocation_and_metadata(running_gms):
_, socket_path = running_gms
writer = GMSClientMemoryManager(socket_path, device=0) writer = GMSClientMemoryManager(socket_path, device=0)
writer.connect(RequestedLockType.RW) try:
va = writer.create_mapping(size=4096, tag="weights") writer.connect(RequestedLockType.RW)
allocation_id = writer.mappings[va].allocation_id va = writer.create_mapping(size=4096, tag="weights")
writer.metadata_put("tensor.0", allocation_id, 0, b"payload") allocation_id = writer.mappings[va].allocation_id
writer.metadata_put("tensor.0", allocation_id, 0, b"payload")
writer.destroy_mapping(va) writer.destroy_mapping(va)
assert writer.list_handles() == [] assert writer.list_handles() == []
assert writer.metadata_list() == [] assert writer.metadata_list() == []
writer.abort() finally:
writer.close()
def test_remap_all_vas_succeeds_when_committed_layout_is_unchanged(real_gms): @pytest.mark.timeout(_SOCKET_TEST_TIMEOUT_SECONDS)
_, socket_path = real_gms def test_remap_all_vas_succeeds_when_committed_layout_is_unchanged(
running_gms,
):
_, socket_path = running_gms
writer = GMSClientMemoryManager(socket_path, device=0) writer = GMSClientMemoryManager(socket_path, device=0)
writer.connect(RequestedLockType.RW)
va = writer.create_mapping(size=4096, tag="weights")
allocation_id = writer.mappings[va].allocation_id
assert writer.commit()
reader = GMSClientMemoryManager(socket_path, device=0) reader = GMSClientMemoryManager(socket_path, device=0)
reader.connect(RequestedLockType.RO) try:
imported_va = reader.create_mapping(allocation_id=allocation_id) writer.connect(RequestedLockType.RW)
imported_mapping = reader.mappings[imported_va] va = writer.create_mapping(size=4096, tag="weights")
reader.unmap_all_vas() allocation_id = writer.mappings[va].allocation_id
reader.abort() assert writer.commit()
reader.connect(RequestedLockType.RO) reader.connect(RequestedLockType.RO)
reader.remap_all_vas() imported_va = reader.create_mapping(allocation_id=allocation_id)
imported_mapping = reader.mappings[imported_va]
reader.unmap_all_vas()
reader.abort()
reader.connect(RequestedLockType.RO)
reader.remap_all_vas()
assert reader.mappings[imported_va].handle != 0 assert reader.mappings[imported_va].handle != 0
assert reader.mappings[imported_va].allocation_id == imported_mapping.allocation_id assert (
reader.close() reader.mappings[imported_va].allocation_id == imported_mapping.allocation_id
)
finally:
reader.close()
writer.close()
def test_remap_all_vas_rejects_stale_layout_after_new_layout_commit(real_gms): @pytest.mark.timeout(_SOCKET_TEST_TIMEOUT_SECONDS)
_, socket_path = real_gms def test_remap_all_vas_rejects_stale_layout_after_new_layout_commit(
running_gms,
):
_, socket_path = running_gms
writer = GMSClientMemoryManager(socket_path, device=0) writer = GMSClientMemoryManager(socket_path, device=0)
writer.connect(RequestedLockType.RW)
va = writer.create_mapping(size=4096, tag="weights")
allocation_id = writer.mappings[va].allocation_id
assert writer.commit()
reader = GMSClientMemoryManager(socket_path, device=0) reader = GMSClientMemoryManager(socket_path, device=0)
reader.connect(RequestedLockType.RO)
reader.create_mapping(allocation_id=allocation_id)
reader.unmap_all_vas()
reader.abort()
next_writer = GMSClientMemoryManager(socket_path, device=0) next_writer = GMSClientMemoryManager(socket_path, device=0)
next_writer.connect(RequestedLockType.RW) try:
next_writer.create_mapping(size=8192, tag="weights") writer.connect(RequestedLockType.RW)
assert next_writer.commit() va = writer.create_mapping(size=4096, tag="weights")
allocation_id = writer.mappings[va].allocation_id
reader.connect(RequestedLockType.RO) assert writer.commit()
with pytest.raises(StaleMemoryLayoutError, match="Layout changed"):
reader.remap_all_vas() reader.connect(RequestedLockType.RO)
reader.abort() reader.create_mapping(allocation_id=allocation_id)
reader.unmap_all_vas()
reader.abort()
next_writer.connect(RequestedLockType.RW)
next_writer.create_mapping(size=writer.granularity + 4096, tag="weights")
assert next_writer.commit()
reader.connect(RequestedLockType.RO)
with pytest.raises(StaleMemoryLayoutError, match="Layout changed"):
reader.remap_all_vas()
reader.abort()
finally:
next_writer.close()
reader.close()
writer.close()
def test_remap_all_vas_accepts_new_layout_with_same_structural_layout(real_gms): @pytest.mark.timeout(_SOCKET_TEST_TIMEOUT_SECONDS)
_, socket_path = real_gms def test_remap_all_vas_accepts_new_layout_with_same_structural_layout(
running_gms,
):
_, socket_path = running_gms
first_writer = GMSClientMemoryManager(socket_path, device=0) first_writer = GMSClientMemoryManager(socket_path, device=0)
first_writer.connect(RequestedLockType.RW)
va = first_writer.create_mapping(size=4096, tag="weights")
first_allocation_id = first_writer.mappings[va].allocation_id
first_writer.metadata_put("tensor.0", first_allocation_id, 0, b"shape")
assert first_writer.commit()
reader = GMSClientMemoryManager(socket_path, device=0) reader = GMSClientMemoryManager(socket_path, device=0)
reader.connect(RequestedLockType.RO)
imported_va = reader.create_mapping(allocation_id=first_allocation_id)
reader.unmap_all_vas()
reader.abort()
second_writer = GMSClientMemoryManager(socket_path, device=0) second_writer = GMSClientMemoryManager(socket_path, device=0)
second_writer.connect(RequestedLockType.RW) try:
second_va = second_writer.create_mapping(size=4096, tag="weights") first_writer.connect(RequestedLockType.RW)
second_allocation_id = second_writer.mappings[second_va].allocation_id va = first_writer.create_mapping(size=4096, tag="weights")
assert second_allocation_id != first_allocation_id first_allocation_id = first_writer.mappings[va].allocation_id
second_writer.metadata_put("tensor.0", second_allocation_id, 0, b"shape") first_writer.metadata_put("tensor.0", first_allocation_id, 0, b"shape")
assert second_writer.commit() assert first_writer.commit()
reader.connect(RequestedLockType.RO) reader.connect(RequestedLockType.RO)
reader.remap_all_vas() imported_va = reader.create_mapping(allocation_id=first_allocation_id)
reader.unmap_all_vas()
assert reader.mappings[imported_va].va == imported_va reader.abort()
assert reader.mappings[imported_va].allocation_id == second_allocation_id
assert reader.metadata_get("tensor.0") == (second_allocation_id, 0, b"shape") second_writer.connect(RequestedLockType.RW)
reader.close() second_va = second_writer.create_mapping(size=4096, tag="weights")
second_allocation_id = second_writer.mappings[second_va].allocation_id
assert second_allocation_id != first_allocation_id
second_writer.metadata_put("tensor.0", second_allocation_id, 0, b"shape")
assert second_writer.commit()
reader.connect(RequestedLockType.RO)
reader.remap_all_vas()
assert reader.mappings[imported_va].va == imported_va
assert reader.mappings[imported_va].allocation_id == second_allocation_id
assert reader.metadata_get("tensor.0") == (second_allocation_id, 0, b"shape")
finally:
second_writer.close()
reader.close()
first_writer.close()
def test_reallocate_all_handles_reuses_preserved_vas_in_new_layout(real_gms): @pytest.mark.timeout(_SOCKET_TEST_TIMEOUT_SECONDS)
server, socket_path = real_gms def test_reallocate_all_handles_reuses_preserved_vas_in_new_layout(
running_gms,
):
server, socket_path = running_gms
manager = GMSClientMemoryManager(socket_path, device=0) manager = GMSClientMemoryManager(socket_path, device=0)
manager.connect(RequestedLockType.RW) manager.connect(RequestedLockType.RW)
...@@ -653,95 +832,6 @@ def test_reallocate_all_handles_reuses_preserved_vas_in_new_layout(real_gms): ...@@ -653,95 +832,6 @@ def test_reallocate_all_handles_reuses_preserved_vas_in_new_layout(real_gms):
_wait_for_server_state(server, ServerState.EMPTY) _wait_for_server_state(server, ServerState.EMPTY)
def test_same_process_republish_remaps_against_new_committed_hash(real_gms):
_, socket_path = real_gms
manager = GMSClientMemoryManager(socket_path, device=0)
manager.connect(RequestedLockType.RW)
va = manager.create_mapping(size=4096, tag="weights")
first_allocation_id = manager.mappings[va].allocation_id
manager.metadata_put("tensor", first_allocation_id, 0, b"publish-1")
assert manager.commit()
manager.connect(RequestedLockType.RO)
manager.remap_all_vas()
manager.unmap_all_vas()
manager.abort()
manager.connect(RequestedLockType.RW)
manager.reallocate_all_handles(tag="weights")
second_allocation_id = manager.mappings[va].allocation_id
assert second_allocation_id != first_allocation_id
manager.remap_all_vas()
manager.metadata_put("tensor", second_allocation_id, 0, b"publish-2")
assert manager.commit()
manager.connect(RequestedLockType.RO)
manager.remap_all_vas()
assert manager.mappings[va].va == va
assert manager.mappings[va].allocation_id == second_allocation_id
assert manager.metadata_get("tensor") == (second_allocation_id, 0, b"publish-2")
manager.close()
def test_disconnect_during_allocation_retry_aborts_writer_and_unblocks_next_writer(
real_gms,
monkeypatch,
):
server, socket_path = real_gms
oom_attempts = 0
allow_allocation = False
def always_oom(size: int, device: int) -> tuple[bool, int]:
nonlocal oom_attempts
nonlocal allow_allocation
if allow_allocation:
return True, 4242
oom_attempts += 1
return False, 0
monkeypatch.setattr(
"gpu_memory_service.server.allocations.cumem_create_tolerate_oom",
always_oom,
)
writer = _GMSClientSession(socket_path, RequestedLockType.RW, None)
result: dict[str, BaseException] = {}
def allocate() -> None:
try:
writer.allocate(4096, "weights")
except BaseException as exc:
result["error"] = exc
thread = threading.Thread(target=allocate)
thread.start()
deadline = time.monotonic() + 2.0
while oom_attempts == 0:
if time.monotonic() > deadline:
raise TimeoutError("allocation retry never reached CUDA OOM")
time.sleep(0.01)
_drop_connection(writer)
thread.join(timeout=2)
_wait_for_server_state(server, ServerState.EMPTY)
allow_allocation = True
next_writer = _GMSClientSession(socket_path, RequestedLockType.RW, 200)
try:
assert next_writer.lock_type == GrantedLockType.RW
allocation_id, aligned_size = next_writer.allocate(4096, "weights")
assert allocation_id
assert aligned_size == 4096
finally:
next_writer.close()
assert isinstance(result.get("error"), ConnectionError)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_allocation_manager_caches_exported_fd(monkeypatch): async def test_allocation_manager_caches_exported_fd(monkeypatch):
export_calls = 0 export_calls = 0
...@@ -793,32 +883,31 @@ async def test_allocation_manager_caches_exported_fd(monkeypatch): ...@@ -793,32 +883,31 @@ async def test_allocation_manager_caches_exported_fd(monkeypatch):
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.timeout(180) @pytest.mark.timeout(180)
async def test_new_layout_large_allocation_waits_for_dead_writer_process( async def test_large_allocation_unblocks_after_export_fd_holder_dies(
tmp_path, tmp_path,
monkeypatch,
): ):
"""Allocation retries should unblock after the last exported FD holder dies."""
_HOLD_EXPORT_FD_PROGRAM = (
"import sys\n"
"import time\n"
"from pathlib import Path\n"
"\n"
"fd = int(sys.argv[1])\n"
"ready_path = Path(sys.argv[2])\n"
"ready_path.write_text(str(fd), encoding='utf-8')\n"
"while True:\n"
" time.sleep(1.0)\n"
)
free_before = _gpu_memory_free_bytes() free_before = _gpu_memory_free_bytes()
size = int(free_before * 0.9) size = int(free_before * 0.9)
assert size > 0 assert size > 0
oom_failures = 0
def count_oom(size: int, device: int) -> tuple[bool, int]:
nonlocal oom_failures
allocated, handle = cuda_utils.cumem_create_tolerate_oom(size, device)
if not allocated:
oom_failures += 1
return allocated, handle
monkeypatch.setattr(
"gpu_memory_service.server.allocations.cumem_create_tolerate_oom",
count_oom,
)
allocations = GMSAllocationManager( allocations = GMSAllocationManager(
device=0, device=0,
allocation_retry_interval=0.1, allocation_retry_interval=_ALLOCATION_RETRY_INTERVAL_SECONDS,
allocation_retry_timeout=120.0, allocation_retry_timeout=_ALLOCATION_RETRY_TIMEOUT_SECONDS,
) )
holder = None holder = None
allocation_task = None allocation_task = None
...@@ -837,29 +926,15 @@ async def test_new_layout_large_allocation_waits_for_dead_writer_process( ...@@ -837,29 +926,15 @@ async def test_new_layout_large_allocation_waits_for_dead_writer_process(
exported_fd = allocations.export_allocation(first.allocation_id) exported_fd = allocations.export_allocation(first.allocation_id)
holder_ready = tmp_path / "holder.ready" holder_ready = tmp_path / "holder.ready"
holder_log = tmp_path / "holder.log" holder_log = tmp_path / "holder.log"
holder_script = tmp_path / "hold_import.py" # Hold the exported FD in another process so `clear_all()` drops the
holder_script.write_text( # server's bookkeeping first, then the allocator retry path waits for
textwrap.dedent( # the last external reference to die before reusing GPU memory.
"""
import sys
import time
from pathlib import Path
fd = int(sys.argv[1])
Path(sys.argv[2]).write_text(str(fd))
while True:
time.sleep(1.0)
"""
),
encoding="utf-8",
)
with holder_log.open("w", encoding="utf-8") as log_file: with holder_log.open("w", encoding="utf-8") as log_file:
holder = subprocess.Popen( holder = subprocess.Popen(
[ [
sys.executable, sys.executable,
str(holder_script), "-c",
_HOLD_EXPORT_FD_PROGRAM,
str(exported_fd), str(exported_fd),
str(holder_ready), str(holder_ready),
], ],
...@@ -870,11 +945,11 @@ async def test_new_layout_large_allocation_waits_for_dead_writer_process( ...@@ -870,11 +945,11 @@ async def test_new_layout_large_allocation_waits_for_dead_writer_process(
) )
os.close(exported_fd) os.close(exported_fd)
deadline = time.monotonic() + 30.0 deadline = time.monotonic() + _EXPORT_HOLDER_READY_TIMEOUT_SECONDS
while not holder_ready.exists(): while not holder_ready.exists():
assert holder.poll() is None, holder_log.read_text(encoding="utf-8") assert holder.poll() is None, holder_log.read_text(encoding="utf-8")
assert time.monotonic() < deadline, holder_log.read_text(encoding="utf-8") assert time.monotonic() < deadline, holder_log.read_text(encoding="utf-8")
await asyncio.sleep(0.1) await asyncio.sleep(_SLOW_POLL_INTERVAL_SECONDS)
allocations.clear_all() allocations.clear_all()
assert allocations.allocation_count == 0 assert allocations.allocation_count == 0
...@@ -887,30 +962,31 @@ async def test_new_layout_large_allocation_waits_for_dead_writer_process( ...@@ -887,30 +962,31 @@ async def test_new_layout_large_allocation_waits_for_dead_writer_process(
) )
) )
deadline = time.monotonic() + 30.0 deadline = time.monotonic() + _ALLOCATION_BLOCK_ASSERTION_SECONDS
while oom_failures == 0: while time.monotonic() < deadline:
assert holder.poll() is None, holder_log.read_text(encoding="utf-8") assert holder.poll() is None, holder_log.read_text(encoding="utf-8")
assert not allocation_task.done() assert not allocation_task.done()
assert time.monotonic() < deadline await asyncio.sleep(_SLOW_POLL_INTERVAL_SECONDS)
await asyncio.sleep(0.1)
assert oom_failures > 0
assert not allocation_task.done() assert not allocation_task.done()
os.killpg(os.getpgid(holder.pid), signal.SIGKILL) os.killpg(os.getpgid(holder.pid), signal.SIGKILL)
holder.wait(timeout=30.0) holder.wait(timeout=_EXPORT_HOLDER_READY_TIMEOUT_SECONDS)
second = await asyncio.wait_for(allocation_task, timeout=120.0) second = await asyncio.wait_for(
allocation_task,
timeout=_ALLOCATION_RETRY_TIMEOUT_SECONDS,
)
assert second.layout_slot == 0 assert second.layout_slot == 0
assert allocations.allocation_count == 1 assert allocations.allocation_count == 1
allocations.clear_all() allocations.clear_all()
assert allocations.allocation_count == 0 assert allocations.allocation_count == 0
deadline = time.monotonic() + 30.0 deadline = time.monotonic() + _GPU_MEMORY_RECOVERY_TIMEOUT_SECONDS
while _gpu_memory_free_bytes() < free_before - (1 << 30): while _gpu_memory_free_bytes() < free_before - (1 << 30):
assert time.monotonic() < deadline assert time.monotonic() < deadline
await asyncio.sleep(0.1) await asyncio.sleep(_SLOW_POLL_INTERVAL_SECONDS)
finally: finally:
if allocation_task is not None and not allocation_task.done(): if allocation_task is not None and not allocation_task.done():
allocation_task.cancel() allocation_task.cancel()
...@@ -922,4 +998,4 @@ async def test_new_layout_large_allocation_waits_for_dead_writer_process( ...@@ -922,4 +998,4 @@ async def test_new_layout_large_allocation_waits_for_dead_writer_process(
allocations.clear_all() allocations.clear_all()
if holder is not None and holder.poll() is None: if holder is not None and holder.poll() is None:
os.killpg(os.getpgid(holder.pid), signal.SIGKILL) os.killpg(os.getpgid(holder.pid), signal.SIGKILL)
holder.wait(timeout=30.0) holder.wait(timeout=_EXPORT_HOLDER_READY_TIMEOUT_SECONDS)
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
"""Torch integration coverage for GMS-backed tensors and modules.
This module exercises tensor remap after unmap/remap cycles and module
materialization from committed GMS-backed weights.
"""
from __future__ import annotations from __future__ import annotations
import asyncio
import os
import threading
import time
from typing import cast from typing import cast
import pytest import pytest
...@@ -13,18 +23,21 @@ from gpu_memory_service.client.torch.module import ( ...@@ -13,18 +23,21 @@ from gpu_memory_service.client.torch.module import (
) )
from gpu_memory_service.client.torch.tensor import _tensor_from_pointer from gpu_memory_service.client.torch.tensor import _tensor_from_pointer
from gpu_memory_service.common.locks import RequestedLockType from gpu_memory_service.common.locks import RequestedLockType
from gpu_memory_service.server.rpc import GMSRPCServer
from tests.gms.harness.gms import GMSServerProcess
torch = pytest.importorskip("torch", reason="torch is required") torch = pytest.importorskip("torch", reason="torch is required")
pytestmark = [ pytestmark = [
pytest.mark.pre_merge, pytest.mark.pre_merge,
pytest.mark.unit, pytest.mark.integration,
pytest.mark.none, pytest.mark.none,
pytest.mark.gpu_1, pytest.mark.gpu_1,
] ]
_SERVER_START_TIMEOUT_SECONDS = 5.0
_SERVER_STOP_TIMEOUT_SECONDS = 5.0
_POLL_INTERVAL_SECONDS = 0.01
if not torch.cuda.is_available(): if not torch.cuda.is_available():
pytest.skip( pytest.skip(
...@@ -50,9 +63,70 @@ class _TinyModule(torch.nn.Module): ...@@ -50,9 +63,70 @@ class _TinyModule(torch.nn.Module):
@pytest.fixture @pytest.fixture
def running_gms(request): def running_gms(tmp_path):
with GMSServerProcess(request, device=0, tag="weights") as server: socket_path = str(tmp_path / "gms.sock")
yield server.socket_path server = GMSRPCServer(socket_path, device=0)
loop: asyncio.AbstractEventLoop | None = None
task: asyncio.Task[None] | None = None
thread_error: BaseException | None = None
def run() -> None:
nonlocal loop, task, thread_error
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
task = loop.create_task(server.serve())
try:
loop.run_until_complete(task)
except asyncio.CancelledError:
pass
except BaseException as exc:
thread_error = exc
finally:
pending = [
pending_task
for pending_task in asyncio.all_tasks(loop)
if not pending_task.done()
]
for pending_task in pending:
pending_task.cancel()
if pending:
loop.run_until_complete(
asyncio.gather(*pending, return_exceptions=True)
)
loop.close()
thread = threading.Thread(target=run, daemon=True)
thread.start()
deadline = time.monotonic() + _SERVER_START_TIMEOUT_SECONDS
while True:
if thread_error is not None:
raise thread_error
if server._server is not None and os.path.exists(socket_path):
break
if time.monotonic() > deadline:
raise TimeoutError(f"GMS socket did not appear at {socket_path}")
time.sleep(_POLL_INTERVAL_SECONDS)
try:
yield socket_path
finally:
if loop is not None:
def cancel() -> None:
if server._server is not None:
server._server.close()
if task is not None:
task.cancel()
loop.call_soon_threadsafe(cancel)
thread.join(timeout=_SERVER_STOP_TIMEOUT_SECONDS)
if thread.is_alive():
raise RuntimeError(f"GMS server thread failed to stop for {socket_path}")
if thread_error is not None:
raise thread_error
if os.path.exists(socket_path):
os.unlink(socket_path)
def _make_gms_tensor( def _make_gms_tensor(
...@@ -86,8 +160,10 @@ def test_gms_tensor_matches_plain_torch_ops(running_gms): ...@@ -86,8 +160,10 @@ def test_gms_tensor_matches_plain_torch_ops(running_gms):
writer = GMSClientMemoryManager(socket_path, device=0) writer = GMSClientMemoryManager(socket_path, device=0)
writer.connect(RequestedLockType.RW) writer.connect(RequestedLockType.RW)
allocation_id, _writer_tensor = _make_gms_tensor(writer, baseline, tag="weights") allocation_id, writer_tensor = _make_gms_tensor(writer, baseline, tag="weights")
assert writer.commit() assert writer.commit()
del writer_tensor
writer.close()
reader = GMSClientMemoryManager(socket_path, device=0) reader = GMSClientMemoryManager(socket_path, device=0)
reader.connect(RequestedLockType.RO) reader.connect(RequestedLockType.RO)
...@@ -122,8 +198,10 @@ def test_live_gms_tensor_survives_unmap_and_remap(running_gms): ...@@ -122,8 +198,10 @@ def test_live_gms_tensor_survives_unmap_and_remap(running_gms):
writer = GMSClientMemoryManager(socket_path, device=0) writer = GMSClientMemoryManager(socket_path, device=0)
writer.connect(RequestedLockType.RW) writer.connect(RequestedLockType.RW)
allocation_id, _ = _make_gms_tensor(writer, baseline, tag="weights") allocation_id, writer_tensor = _make_gms_tensor(writer, baseline, tag="weights")
del writer_tensor
assert writer.commit() assert writer.commit()
writer.close()
reader = GMSClientMemoryManager(socket_path, device=0) reader = GMSClientMemoryManager(socket_path, device=0)
reader.connect(RequestedLockType.RO) reader.connect(RequestedLockType.RO)
...@@ -177,6 +255,11 @@ def test_materialized_module_from_gms_matches_plain_module_forward(running_gms): ...@@ -177,6 +255,11 @@ def test_materialized_module_from_gms_matches_plain_module_forward(running_gms):
register_module_tensors(writer, gms_model) register_module_tensors(writer, gms_model)
_assert_exact_tensor_equal(expected, gms_model(inputs)) _assert_exact_tensor_equal(expected, gms_model(inputs))
assert writer.commit() assert writer.commit()
del gms_weight
del gms_scale
del gms_extra
del gms_model
writer.close()
reader = GMSClientMemoryManager(socket_path, device=0) reader = GMSClientMemoryManager(socket_path, device=0)
reader.connect(RequestedLockType.RO) reader.connect(RequestedLockType.RO)
......
...@@ -55,6 +55,7 @@ dynamo/ ...@@ -55,6 +55,7 @@ dynamo/
├── tests/ # End-to-end and cross-component tests ├── tests/ # End-to-end and cross-component tests
│ ├── serve/ # Serve E2E tests (vllm, sglang, trtllm) │ ├── serve/ # Serve E2E tests (vllm, sglang, trtllm)
│ ├── kvbm_integration/ # KVBM integration tests │ ├── kvbm_integration/ # KVBM integration tests
│ ├── gpu_memory_service/ # GPU Memory Service E2E tests
│ ├── fault_tolerance/ # Fault tolerance, migration, cancellation │ ├── fault_tolerance/ # Fault tolerance, migration, cancellation
│ ├── deploy/ # Deployment tests │ ├── deploy/ # Deployment tests
│ ├── frontend/ # Frontend HTTP/gRPC tests │ ├── frontend/ # Frontend HTTP/gRPC tests
...@@ -83,20 +84,21 @@ dynamo/ ...@@ -83,20 +84,21 @@ dynamo/
**Python tests** (`pytest`): **Python tests** (`pytest`):
| Type | Description | Location | | Type | Description | Location |
|-------------------|------------------------------------------|----------------------------------------------| |--------------------|---------------------------------------|-----------------------------------------------|
| Unit | Single function/class, isolated | `components/src/dynamo/<component>/tests/` | | Unit | Single function/class, isolated | `components/src/dynamo/<component>/tests/` |
| Integration | Interactions between modules/services | `components/src/dynamo/<component>/tests/` | | Integration | Interactions between modules/services | `components/src/dynamo/<component>/tests/` |
| End-to-End | User workflows, CLI, API | `tests/serve/`, `tests/deploy/`, etc. | | End-to-End | User workflows, CLI, API | `tests/serve/`, `tests/deploy/`, etc. |
| KVBM Integration | KV block manager integration | `tests/kvbm_integration/` | | KVBM Integration | KV block manager integration | `tests/kvbm_integration/` |
| Router | Router E2E with backends | `tests/router/` | | GPU Memory Service | GPU Memory Service E2E | `tests/gpu_memory_service/` |
| Planner | Planner unit + integration tests | `components/src/dynamo/planner/tests/` | | Router | Router E2E with backends | `tests/router/` |
| Frontend | Frontend HTTP/gRPC tests | `tests/frontend/` | | Planner | Planner unit + integration tests | `components/src/dynamo/planner/tests/` |
| Profiler | Profiler unit + integration tests | `components/src/dynamo/profiler/tests/` | | Frontend | Frontend HTTP/gRPC tests | `tests/frontend/` |
| Global Planner | Global planner unit tests | `components/src/dynamo/global_planner/tests/`| | Profiler | Profiler unit + integration tests | `components/src/dynamo/profiler/tests/` |
| Fault Tolerance | Chaos, migration, cancellation | `tests/fault_tolerance/` | | Global Planner | Global planner unit tests | `components/src/dynamo/global_planner/tests/` |
| Deployment | Deployment validation | `tests/deploy/` | | Fault Tolerance | Chaos, migration, cancellation | `tests/fault_tolerance/` |
| Benchmark | Performance/load | `benchmarks/` | | Deployment | Deployment validation | `tests/deploy/` |
| Benchmark | Performance/load | `benchmarks/` |
--- ---
......
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import pytest
from gpu_memory_service.client import memory_manager as memory_manager_module
from gpu_memory_service.client.memory_manager import (
GMSClientMemoryManager,
LocalMapping,
)
from gpu_memory_service.common.locks import GrantedLockType, RequestedLockType
pytestmark = [
pytest.mark.pre_merge,
pytest.mark.unit,
pytest.mark.none,
pytest.mark.gpu_1,
]
class _FakeSession:
def __init__(self):
self.lock_type = GrantedLockType.RW
self.committed = False
self.closed = False
@property
def is_connected(self) -> bool:
return not self.closed
def get_memory_layout_hash(self) -> str:
return ""
def commit(self) -> bool:
self.closed = True
return True
def close(self) -> None:
self.closed = True
class _TrackingSession:
def __init__(self) -> None:
self.closed = False
def close(self) -> None:
self.closed = True
class _FailingCommitSession:
is_connected = True
def commit(self) -> bool:
raise ConnectionError("commit failed after local unmap")
class _SuccessfulCommitSession:
is_connected = True
def commit(self) -> bool:
return True
class _CloseFailingSession:
def close(self) -> None:
raise ConnectionError("close failed")
def _make_mapping(
allocation_id: str,
va: int,
*,
handle: int,
tag: str = "weights",
layout_slot: int = 0,
) -> LocalMapping:
return LocalMapping(
allocation_id=allocation_id,
va=va,
size=4096,
aligned_size=4096,
handle=handle,
tag=tag,
layout_slot=layout_slot,
)
@pytest.fixture
def manager(monkeypatch):
monkeypatch.setattr(
memory_manager_module, "cuda_set_current_device", lambda _device: None
)
monkeypatch.setattr(
memory_manager_module, "cumem_get_allocation_granularity", lambda _device: 65536
)
monkeypatch.setattr(memory_manager_module, "cuda_synchronize", lambda: None)
return GMSClientMemoryManager("/tmp/gms-test.sock", device=0)
def _make_manager(
monkeypatch,
*,
client: object | None,
lock_type: GrantedLockType | None,
mappings: list[LocalMapping] | None = None,
unmapped: bool = False,
va_preserved: bool = False,
layout_hash: str = "",
) -> GMSClientMemoryManager:
monkeypatch.setattr(
memory_manager_module, "cuda_set_current_device", lambda _device: None
)
manager = object.__new__(GMSClientMemoryManager)
manager.socket_path = "/tmp/gms-test.sock"
manager.device = 0
manager._client = client
manager._granted_lock_type = lock_type
manager._mappings = {}
manager._inverse_mapping = {}
manager._unmapped = unmapped
manager._va_preserved = va_preserved
manager._last_memory_layout_hash = layout_hash
manager.granularity = 4096
for mapping in mappings or []:
manager._track_mapping(mapping)
return manager
def test_commit_clears_client_lock_state(manager):
manager._client = _FakeSession()
manager._granted_lock_type = GrantedLockType.RW
assert manager.commit()
assert manager.granted_lock_type is None
assert not manager.is_connected
assert manager.is_unmapped
def test_abort_clears_client_lock_state(manager):
manager._client = _FakeSession()
manager._granted_lock_type = GrantedLockType.RW
manager.abort()
assert manager.granted_lock_type is None
assert not manager.is_connected
def test_connect_rejects_double_connect(monkeypatch):
manager = _make_manager(
monkeypatch,
client=object(),
lock_type=GrantedLockType.RO,
)
with pytest.raises(RuntimeError, match="already connected"):
manager.connect(RequestedLockType.RO)
def test_abort_drops_session_with_live_mappings(monkeypatch):
session = _TrackingSession()
manager = _make_manager(
monkeypatch,
client=session,
lock_type=GrantedLockType.RO,
mappings=[_make_mapping("alloc-1", 0x1000, handle=1234)],
)
manager.abort()
assert session.closed
assert manager.granted_lock_type is None
assert not manager.is_connected
assert manager.mappings[0x1000].handle == 1234
def test_commit_failure_after_local_unmap_keeps_preserved_unmapped_state(monkeypatch):
manager = _make_manager(
monkeypatch,
client=_FailingCommitSession(),
lock_type=GrantedLockType.RW,
mappings=[
_make_mapping("alloc_1", 0x1000, handle=11),
_make_mapping("alloc_2", 0x2000, handle=0),
],
)
unmapped_vas: list[int] = []
monkeypatch.setattr(memory_manager_module, "cuda_synchronize", lambda: None)
def fake_unmap_va(self, va: int) -> None:
unmapped_vas.append(va)
self._mappings[va] = self._mappings[va].with_handle(0)
monkeypatch.setattr(GMSClientMemoryManager, "unmap_va", fake_unmap_va)
with pytest.raises(ConnectionError, match="commit failed after local unmap"):
manager.commit()
assert unmapped_vas == [0x1000]
assert manager._mappings[0x1000].handle == 0
assert manager._mappings[0x2000].handle == 0
assert manager._va_preserved
assert manager._unmapped
assert manager._client is not None
def test_successful_commit_clears_rw_mode_before_local_cleanup(monkeypatch):
manager = _make_manager(
monkeypatch,
client=_SuccessfulCommitSession(),
lock_type=GrantedLockType.RW,
mappings=[_make_mapping("alloc_1", 0x1000, handle=11)],
)
free_handle_calls: list[str] = []
monkeypatch.setattr(memory_manager_module, "cuda_synchronize", lambda: None)
def fake_unmap_va(self, va: int) -> None:
self._mappings[va] = self._mappings[va].with_handle(0)
def fake_free_va(self, va: int) -> None:
mapping = self._mappings.pop(va)
self._inverse_mapping.pop(mapping.allocation_id, None)
def fake_free_handle(self, allocation_id: str) -> bool:
free_handle_calls.append(allocation_id)
return True
monkeypatch.setattr(GMSClientMemoryManager, "unmap_va", fake_unmap_va)
monkeypatch.setattr(GMSClientMemoryManager, "free_va", fake_free_va)
monkeypatch.setattr(GMSClientMemoryManager, "free_handle", fake_free_handle)
assert manager.commit()
assert manager._client is None
assert manager._granted_lock_type is None
assert manager._unmapped
assert manager._va_preserved
manager.destroy_mapping(0x1000)
assert free_handle_calls == []
def test_destroy_mapping_keeps_local_state_when_server_free_fails(monkeypatch):
mapping = _make_mapping("alloc_1", 0x1000, handle=77)
manager = _make_manager(
monkeypatch,
client=None,
lock_type=GrantedLockType.RW,
mappings=[mapping],
)
monkeypatch.setattr(
manager,
"free_handle",
lambda allocation_id: (_ for _ in ()).throw(RuntimeError("server free failed")),
)
with pytest.raises(RuntimeError, match="server free failed"):
manager.destroy_mapping(mapping.va)
assert manager.mappings[mapping.va] == mapping
def test_disconnect_clears_local_state_even_if_close_fails(monkeypatch):
manager = _make_manager(
monkeypatch,
client=_CloseFailingSession(),
lock_type=GrantedLockType.RO,
)
with pytest.raises(ConnectionError, match="close failed"):
manager.abort()
assert manager._client is None
assert manager._granted_lock_type is None
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import pytest
from gpu_memory_service.client.rpc import _GMSRPCTransport
from gpu_memory_service.client.session import _GMSClientSession
from gpu_memory_service.common.locks import GrantedLockType, RequestedLockType
from gpu_memory_service.common.protocol.messages import (
CommitResponse,
HandshakeResponse,
)
pytestmark = [
pytest.mark.pre_merge,
pytest.mark.unit,
pytest.mark.none,
pytest.mark.gpu_0,
]
def _patch_handshake(
monkeypatch,
*,
response: HandshakeResponse | None = None,
error: Exception | None = None,
) -> dict[str, bool]:
closed = {"value": False}
monkeypatch.setattr(_GMSRPCTransport, "connect", lambda self: None)
if error is None:
monkeypatch.setattr(
_GMSRPCTransport,
"handshake",
lambda self, lock_type, timeout_ms: response,
)
else:
monkeypatch.setattr(
_GMSRPCTransport,
"handshake",
lambda self, lock_type, timeout_ms: (_ for _ in ()).throw(error),
)
monkeypatch.setattr(
_GMSRPCTransport,
"close",
lambda self: closed.__setitem__("value", True),
)
return closed
def test_client_session_timeout_closes_transport(monkeypatch):
closed = _patch_handshake(
monkeypatch,
response=HandshakeResponse(success=False, committed=False),
)
with pytest.raises(TimeoutError, match="Timeout waiting for lock"):
_GMSClientSession("/tmp/gms-test.sock", RequestedLockType.RO, 1000)
assert closed["value"]
def test_client_session_handshake_failure_closes_transport(monkeypatch):
closed = _patch_handshake(monkeypatch, error=RuntimeError("handshake failed"))
with pytest.raises(RuntimeError, match="handshake failed"):
_GMSClientSession("/tmp/gms-test.sock", RequestedLockType.RO, 1000)
assert closed["value"]
def test_client_session_records_granted_lock_and_committed(monkeypatch):
_patch_handshake(
monkeypatch,
response=HandshakeResponse(
success=True,
committed=True,
granted_lock_type=GrantedLockType.RO,
),
)
session = _GMSClientSession("/tmp/gms-test.sock", RequestedLockType.RW_OR_RO, None)
assert session.committed
assert session.lock_type == GrantedLockType.RO
assert session.is_ready()
def test_client_session_requires_granted_lock_type(monkeypatch):
closed = _patch_handshake(
monkeypatch,
response=HandshakeResponse(
success=True,
committed=False,
granted_lock_type=None,
),
)
with pytest.raises(RuntimeError, match="granted_lock_type"):
_GMSClientSession("/tmp/gms-test.sock", RequestedLockType.RW_OR_RO, None)
assert closed["value"]
def test_client_session_commit_marks_committed_and_closes_transport(monkeypatch):
closed = _patch_handshake(
monkeypatch,
response=HandshakeResponse(
success=True,
committed=False,
granted_lock_type=GrantedLockType.RW,
),
)
monkeypatch.setattr(
_GMSRPCTransport,
"request",
lambda self, request, response_type: CommitResponse(success=True),
)
session = _GMSClientSession("/tmp/gms-test.sock", RequestedLockType.RW, None)
assert not session.committed
assert session.commit()
assert session.committed
assert closed["value"]
def test_client_session_commit_tolerates_close_failure_after_success(monkeypatch):
monkeypatch.setattr(_GMSRPCTransport, "connect", lambda self: None)
monkeypatch.setattr(
_GMSRPCTransport,
"handshake",
lambda self, lock_type, timeout_ms: HandshakeResponse(
success=True,
committed=False,
granted_lock_type=GrantedLockType.RW,
),
)
monkeypatch.setattr(
_GMSRPCTransport,
"request",
lambda self, request, response_type: CommitResponse(success=True),
)
monkeypatch.setattr(
_GMSRPCTransport,
"close",
lambda self: (_ for _ in ()).throw(ConnectionError("close failed")),
)
session = _GMSClientSession("/tmp/gms-test.sock", RequestedLockType.RW, None)
assert session.commit()
assert session.committed
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import pytest
from gpu_memory_service.client.rpc import _GMSRPCTransport
from gpu_memory_service.common.protocol import wire
from gpu_memory_service.common.protocol.messages import (
CommitResponse,
ErrorResponse,
HandshakeResponse,
)
pytestmark = [
pytest.mark.pre_merge,
pytest.mark.unit,
pytest.mark.none,
pytest.mark.gpu_0,
]
class _DummySocket:
def __init__(self) -> None:
self.closed = False
def close(self) -> None:
self.closed = True
def test_transport_failure_closes_socket_and_marks_disconnected(monkeypatch):
transport = _GMSRPCTransport("/tmp/gms-test.sock")
transport._socket = _DummySocket()
monkeypatch.setattr(
"gpu_memory_service.client.rpc.send_message_sync",
lambda sock, request: None,
)
monkeypatch.setattr(
"gpu_memory_service.client.rpc.recv_message_sync",
lambda sock, buffer: (_ for _ in ()).throw(BrokenPipeError("boom")),
)
with pytest.raises(ConnectionError, match="failed: boom"):
transport.request(CommitResponse(success=True), HandshakeResponse)
assert not transport.is_connected
assert transport._socket is None
def test_request_with_fd_closes_fd_on_unexpected_response_type(monkeypatch):
transport = _GMSRPCTransport("/tmp/gms-test.sock")
closed_fds: list[int] = []
monkeypatch.setattr(
transport,
"_send_recv",
lambda request, error_prefix=None: (CommitResponse(success=True), 37),
)
monkeypatch.setattr("gpu_memory_service.client.rpc.os.close", closed_fds.append)
with pytest.raises(RuntimeError, match="unexpected response type"):
transport.request_with_fd(
CommitResponse(success=True),
HandshakeResponse,
)
assert closed_fds == [37]
def test_request_closes_fd_on_error_response(monkeypatch):
transport = _GMSRPCTransport("/tmp/gms-test.sock")
transport._socket = _DummySocket()
closed_fds: list[int] = []
monkeypatch.setattr(
"gpu_memory_service.client.rpc.send_message_sync",
lambda sock, request: None,
)
monkeypatch.setattr(
"gpu_memory_service.client.rpc.recv_message_sync",
lambda sock, buffer: (ErrorResponse(error="boom"), 41, bytearray()),
)
monkeypatch.setattr("gpu_memory_service.client.rpc.os.close", closed_fds.append)
with pytest.raises(RuntimeError, match="error: boom"):
transport.request(CommitResponse(success=True), HandshakeResponse)
assert closed_fds == [41]
def test_request_closes_fd_on_unexpected_success_fd(monkeypatch):
transport = _GMSRPCTransport("/tmp/gms-test.sock")
closed_fds: list[int] = []
monkeypatch.setattr(
transport,
"request_with_fd",
lambda request, response_type: (CommitResponse(success=True), 43),
)
monkeypatch.setattr("gpu_memory_service.client.rpc.os.close", closed_fds.append)
with pytest.raises(RuntimeError, match="unexpected FD"):
transport.request(CommitResponse(success=True), CommitResponse)
assert closed_fds == [43]
def test_recv_message_sync_closes_fd_on_decode_failure(monkeypatch):
closed_fds: list[int] = []
monkeypatch.setattr(
wire.socket,
"recv_fds",
lambda sock, size, maxfds: (b"\x00\x00\x00\x01x", [53], 0, None),
)
monkeypatch.setattr(
wire,
"decode_message",
lambda payload: (_ for _ in ()).throw(ValueError("bad frame")),
)
monkeypatch.setattr(
"gpu_memory_service.common.protocol.wire.os.close",
closed_fds.append,
)
with pytest.raises(ValueError, match="bad frame"):
wire.recv_message_sync(object(), bytearray())
assert closed_fds == [53]
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
from types import SimpleNamespace
import pytest
from tests.gms.harness.gms import GMSServerProcess
from tests.utils.managed_process import ManagedProcess
pytestmark = [
pytest.mark.pre_merge,
pytest.mark.unit,
pytest.mark.none,
pytest.mark.gpu_1,
]
@pytest.fixture
def request_stub():
return SimpleNamespace(node=SimpleNamespace(name="gms-harness"))
def test_server_process_refuses_foreign_live_socket(monkeypatch, request_stub):
server = GMSServerProcess(request_stub, device=0, tag="weights")
monkeypatch.setattr("tests.gms.harness.gms.os.path.exists", lambda path: True)
monkeypatch.setattr("tests.gms.harness.gms._socket_has_live_gms", lambda path: True)
with pytest.raises(RuntimeError, match="already active"):
server.__enter__()
def test_server_process_unlinks_only_stale_socket_on_exit(monkeypatch, request_stub):
server = GMSServerProcess(request_stub, device=0, tag="weights")
unlinked: list[str] = []
monkeypatch.setattr(
ManagedProcess,
"__exit__",
lambda self, exc_type, exc_val, exc_tb: False,
)
monkeypatch.setattr("tests.gms.harness.gms.os.path.exists", lambda path: True)
monkeypatch.setattr("tests.gms.harness.gms.os.unlink", unlinked.append)
monkeypatch.setattr("tests.gms.harness.gms._socket_has_live_gms", lambda path: True)
server.__exit__(None, None, None)
assert unlinked == []
monkeypatch.setattr(
"tests.gms.harness.gms._socket_has_live_gms", lambda path: False
)
server.__exit__(None, None, None)
assert unlinked == [server.socket_path]
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Targeted GMS fault-tolerance unit tests."""
from __future__ import annotations
import asyncio
import os
import socket
import subprocess
import sys
import time
from dataclasses import dataclass
import pytest
from gpu_memory_service.common import cuda_utils
from gpu_memory_service.common.locks import GrantedLockType, RequestedLockType
from gpu_memory_service.common.protocol.messages import (
CommitRequest,
CommitResponse,
GetEventHistoryRequest,
GetLockStateRequest,
GetLockStateResponse,
GetRuntimeStateRequest,
HandshakeRequest,
)
from gpu_memory_service.server.allocations import GMSAllocationManager
from gpu_memory_service.server.fsm import ServerState, StateEvent
from gpu_memory_service.server.gms import GMS
from gpu_memory_service.server.rpc import GMSRPCServer, _is_connection_alive
from gpu_memory_service.server.session import (
Connection,
GMSSessionManager,
OperationNotAllowed,
)
# Skip entire module if cuda.bindings is not installed
pytest.importorskip("cuda.bindings", reason="cuda.bindings is required")
from cuda.bindings import driver as cuda # noqa: E402
pytestmark = [
pytest.mark.pre_merge,
pytest.mark.unit,
pytest.mark.none,
pytest.mark.gpu_1,
]
def test_cumem_create_tolerate_oom_returns_handle_on_success(monkeypatch):
monkeypatch.setattr(
cuda_utils.cuda,
"cuMemCreate",
lambda size, prop, flags: (cuda.CUresult.CUDA_SUCCESS, 1234),
)
allocated, handle = cuda_utils.cumem_create_tolerate_oom(4096, 0)
assert allocated
assert handle == 1234
def test_cumem_create_tolerate_oom_returns_false_on_oom(monkeypatch):
monkeypatch.setattr(
cuda_utils.cuda,
"cuMemCreate",
lambda size, prop, flags: (cuda.CUresult.CUDA_ERROR_OUT_OF_MEMORY, 0),
)
allocated, handle = cuda_utils.cumem_create_tolerate_oom(4096, 0)
assert not allocated
assert handle == 0
def test_cumem_export_to_shareable_handle_returns_fd(monkeypatch):
monkeypatch.setattr(
cuda_utils.cuda,
"cuMemExportToShareableHandle",
lambda handle, handle_type, flags: (cuda.CUresult.CUDA_SUCCESS, 77),
)
fd = cuda_utils.cumem_export_to_shareable_handle(1234)
assert fd == 77
def test_non_oom_cuda_error_exits_process() -> None:
script = """
from cuda.bindings import driver as cuda
from gpu_memory_service.common import cuda_utils
cuda_utils.cuda_check_result(cuda.CUresult.CUDA_ERROR_INVALID_VALUE, "synthetic")
"""
result = subprocess.run([sys.executable, "-c", script], check=False)
assert result.returncode == 1
class _DummyReader:
def at_eof(self) -> bool:
return False
def exception(self):
return None
class _DummyWriter:
def __init__(self, sock: socket.socket | None = None) -> None:
self.closed = False
self._socket = sock
def close(self) -> None:
self.closed = True
async def wait_closed(self) -> None:
return None
def is_closing(self) -> bool:
return self.closed
def get_extra_info(self, _name: str):
return self._socket
def test_is_connection_alive_detects_dead_peer() -> None:
local_sock, peer_sock = socket.socketpair()
local_sock.setblocking(False)
try:
conn = Connection(
reader=_DummyReader(),
writer=_DummyWriter(local_sock),
mode=GrantedLockType.RW,
session_id="session_1",
recv_buffer=bytearray(),
)
assert _is_connection_alive(conn)
peer_sock.close()
deadline = time.monotonic() + 1.0
while _is_connection_alive(conn):
if time.monotonic() > deadline:
raise TimeoutError("peer disconnect was not detected")
time.sleep(0.01)
finally:
peer_sock.close()
local_sock.close()
def _make_connection(
mode: GrantedLockType,
session_id: str,
) -> tuple[Connection, _DummyWriter]:
writer = _DummyWriter()
return (
Connection(
reader=_DummyReader(),
writer=writer,
mode=mode,
session_id=session_id,
recv_buffer=bytearray(),
),
writer,
)
def _make_allocation_manager() -> GMSAllocationManager:
manager = object.__new__(GMSAllocationManager)
manager._device = 0
manager._allocations = {}
manager._next_layout_slot = 0
manager._granularity = 1
manager._allocation_retry_interval = 0.0
manager._allocation_retry_timeout = None
return manager
@dataclass
class _FakeHandler:
has_committed_layout: bool = False
rw_connect_calls: int = 0
rw_abort_calls: int = 0
commit_calls: int = 0
def on_rw_connect(self) -> None:
self.rw_connect_calls += 1
self.has_committed_layout = False
def on_rw_abort(self) -> None:
self.rw_abort_calls += 1
def on_commit(self) -> None:
self.commit_calls += 1
self.has_committed_layout = True
def handle_get_lock_state(
self,
has_rw: bool,
ro_count: int,
waiting_writers: int,
committed: bool,
) -> GetLockStateResponse:
return GetLockStateResponse(
state=(
"RW"
if has_rw
else "RO"
if ro_count
else "COMMITTED"
if committed
else "EMPTY"
),
has_rw_session=has_rw,
ro_session_count=ro_count,
waiting_writers=waiting_writers,
committed=committed,
is_ready=committed and not has_rw,
)
class _FakeGMS:
def __init__(self, handler: _FakeHandler | None = None):
self.handler = handler or _FakeHandler()
self._sessions = GMSSessionManager()
if self.handler.has_committed_layout:
self._sessions._locking._committed = True
@property
def committed(self) -> bool:
return self._sessions.snapshot().committed
async def acquire_lock(self, mode, timeout_ms, session_id):
return await self._sessions.acquire_lock(mode, timeout_ms, session_id)
async def cancel_connect(self, session_id, mode):
await self._sessions.cancel_connect(session_id, mode)
def on_connect(self, conn: Connection) -> None:
if conn.mode == GrantedLockType.RW:
self.handler.on_rw_connect()
self._sessions.on_connect(conn)
async def cleanup_connection(self, conn: Connection | None) -> None:
event = self._sessions.begin_cleanup(conn)
if event == StateEvent.RW_ABORT:
self.handler.on_rw_abort()
await self._sessions.finish_cleanup(conn)
async def handle_request(self, conn: Connection, msg, _is_connected):
if isinstance(msg, GetLockStateRequest):
snapshot = self._sessions.snapshot()
return (
self.handler.handle_get_lock_state(
snapshot.has_rw_session,
snapshot.ro_session_count,
snapshot.waiting_writers,
snapshot.committed,
),
-1,
False,
)
if isinstance(msg, CommitRequest):
self.handler.on_commit()
self._sessions.on_commit(conn)
return CommitResponse(success=True), -1, True
raise AssertionError(f"Unexpected request type in test: {type(msg)}")
def _make_server(handler: _FakeHandler | None = None) -> GMSRPCServer:
server = object.__new__(GMSRPCServer)
server.socket_path = "/tmp/gms-test.sock"
server.device = 0
server._gms = _FakeGMS(handler)
server._server = None
return server
@pytest.mark.asyncio
async def test_handshake_success_send_failure_cleans_up_rw_state(monkeypatch):
server = _make_server()
reader = _DummyReader()
writer = _DummyWriter()
async def fake_recv_message(_reader, _buffer):
return HandshakeRequest(lock_type=RequestedLockType.RW), -1, bytearray()
async def fake_acquire_lock(_mode, _timeout_ms, _session_id):
server._gms._sessions._reserved_rw_session_id = _session_id
return GrantedLockType.RW
async def fake_send_message(_writer, _msg, _fd=-1):
raise BrokenPipeError("handshake reply failed")
monkeypatch.setattr("gpu_memory_service.server.rpc.recv_message", fake_recv_message)
monkeypatch.setattr(server._gms, "acquire_lock", fake_acquire_lock)
monkeypatch.setattr("gpu_memory_service.server.rpc.send_message", fake_send_message)
conn = await server._do_handshake(reader, writer, "session_1")
assert conn is None
assert server._gms._sessions._locking.rw_conn is None
assert server._gms._sessions.state == ServerState.EMPTY
assert server._gms.handler.rw_connect_calls == 1
assert server._gms.handler.rw_abort_calls == 1
assert writer.closed
@pytest.mark.asyncio
async def test_rw_lock_is_reserved_until_connect():
sessions = GMSSessionManager()
first = await sessions.acquire_lock(
RequestedLockType.RW,
timeout_ms=50,
session_id="session_1",
)
second = await sessions.acquire_lock(
RequestedLockType.RW,
timeout_ms=50,
session_id="session_2",
)
assert first == GrantedLockType.RW
assert second is None
await sessions.cancel_connect("session_1", GrantedLockType.RW)
@pytest.mark.asyncio
async def test_reader_cannot_acquire_while_rw_is_reserved_before_connect():
sessions = GMSSessionManager()
sessions._locking._committed = True
granted = await sessions.acquire_lock(
RequestedLockType.RW,
timeout_ms=50,
session_id="writer_1",
)
assert granted == GrantedLockType.RW
reader = await sessions.acquire_lock(
RequestedLockType.RO,
timeout_ms=50,
session_id="reader_1",
)
assert reader is None
await sessions.cancel_connect("writer_1", GrantedLockType.RW)
@pytest.mark.asyncio
async def test_reader_waiter_wakes_when_waiting_writer_times_out():
sessions = GMSSessionManager()
sessions._locking._committed = True
existing_reader = Connection(
reader=_DummyReader(),
writer=_DummyWriter(),
mode=GrantedLockType.RO,
session_id="reader_1",
recv_buffer=bytearray(),
)
sessions.on_connect(existing_reader)
writer_task = asyncio.create_task(
sessions.acquire_lock(
RequestedLockType.RW,
timeout_ms=50,
session_id="writer_1",
)
)
await asyncio.sleep(0)
reader_task = asyncio.create_task(
sessions.acquire_lock(
RequestedLockType.RO,
timeout_ms=200,
session_id="reader_2",
)
)
assert await writer_task is None
assert await reader_task == GrantedLockType.RO
event = sessions.begin_cleanup(existing_reader)
assert event == StateEvent.RO_DISCONNECT
await sessions.finish_cleanup(existing_reader)
@pytest.mark.asyncio
async def test_rw_or_ro_waiter_becomes_rw_after_writer_abort():
sessions = GMSSessionManager()
writer_mode = await sessions.acquire_lock(
RequestedLockType.RW,
timeout_ms=50,
session_id="writer_1",
)
assert writer_mode == GrantedLockType.RW
writer = Connection(
reader=_DummyReader(),
writer=_DummyWriter(),
mode=GrantedLockType.RW,
session_id="writer_1",
recv_buffer=bytearray(),
)
sessions.on_connect(writer)
waiter = asyncio.create_task(
sessions.acquire_lock(
RequestedLockType.RW_OR_RO,
timeout_ms=200,
session_id="waiter_1",
)
)
await asyncio.sleep(0)
assert not waiter.done()
event = sessions.begin_cleanup(writer)
assert event == StateEvent.RW_ABORT
await sessions.finish_cleanup(writer)
assert await waiter == GrantedLockType.RW
await sessions.cancel_connect("waiter_1", GrantedLockType.RW)
@pytest.mark.asyncio
async def test_gms_clears_aborted_rw_layout_before_waking_waiters():
gms = object.__new__(GMS)
cleanup_order: list[str] = []
conn, _ = _make_connection(GrantedLockType.RW, "session_1")
gms._events = []
def begin_cleanup(self, cleanup_conn):
cleanup_order.append("begin_cleanup")
return StateEvent.RW_ABORT
async def finish_cleanup(self, cleanup_conn):
cleanup_order.append("finish_cleanup")
def clear_layout_state():
cleanup_order.append("clear_layout_state")
return 3
gms._sessions = type(
"_Sessions",
(),
{
"begin_cleanup": begin_cleanup,
"finish_cleanup": finish_cleanup,
},
)()
gms._clear_layout_state = clear_layout_state
await gms.cleanup_connection(conn)
assert cleanup_order == [
"begin_cleanup",
"clear_layout_state",
"finish_cleanup",
]
@pytest.mark.asyncio
async def test_request_response_send_failure_disconnects_without_error_response(
monkeypatch,
):
handler = _FakeHandler(has_committed_layout=True)
server = _make_server(handler)
server._gms._sessions._locking._committed = True
conn, writer = _make_connection(GrantedLockType.RO, "session_2")
server._gms._sessions._locking.transition(StateEvent.RO_CONNECT, conn)
recv_calls = 0
sent_messages: list[object] = []
async def fake_recv_message(_reader, _buffer):
nonlocal recv_calls
recv_calls += 1
return GetLockStateRequest(), -1, bytearray()
async def fake_send_message(_writer, msg, _fd=-1):
sent_messages.append(msg)
raise BrokenPipeError("response send failed")
monkeypatch.setattr("gpu_memory_service.server.rpc.recv_message", fake_recv_message)
monkeypatch.setattr("gpu_memory_service.server.rpc.send_message", fake_send_message)
await server._request_loop(conn)
await server._gms.cleanup_connection(conn)
assert recv_calls == 1
assert len(sent_messages) == 1
assert isinstance(sent_messages[0], GetLockStateResponse)
assert conn not in server._gms._sessions._locking.ro_conns
assert server._gms._sessions.state == ServerState.COMMITTED
assert writer.closed
@pytest.mark.asyncio
async def test_post_commit_response_send_failure_stays_committed(monkeypatch):
handler = _FakeHandler()
server = _make_server(handler)
conn, writer = _make_connection(GrantedLockType.RW, "session_3")
server._gms._sessions._locking.transition(StateEvent.RW_CONNECT, conn)
recv_calls = 0
sent_messages: list[object] = []
async def fake_recv_message(_reader, _buffer):
nonlocal recv_calls
recv_calls += 1
return CommitRequest(), -1, bytearray()
async def fake_send_message(_writer, msg, _fd=-1):
sent_messages.append(msg)
raise BrokenPipeError("commit reply failed")
monkeypatch.setattr("gpu_memory_service.server.rpc.recv_message", fake_recv_message)
monkeypatch.setattr("gpu_memory_service.server.rpc.send_message", fake_send_message)
await server._request_loop(conn)
await server._gms.cleanup_connection(conn)
assert recv_calls == 1
assert len(sent_messages) == 1
assert handler.commit_calls == 1
assert server._gms._sessions._locking.rw_conn is None
assert server._gms._sessions.snapshot().committed
assert server._gms._sessions.state == ServerState.COMMITTED
assert writer.closed
@pytest.mark.asyncio
async def test_runtime_state_handshake_send_failure_does_not_fail_server(monkeypatch):
server = _make_server()
reader = _DummyReader()
writer = _DummyWriter()
async def fake_recv_message(_reader, _buffer):
return GetRuntimeStateRequest(), -1, bytearray()
async def fake_send_message(_writer, _msg, _fd=-1):
raise BrokenPipeError("runtime-state send failed")
monkeypatch.setattr("gpu_memory_service.server.rpc.recv_message", fake_recv_message)
monkeypatch.setattr("gpu_memory_service.server.rpc.send_message", fake_send_message)
conn = await server._do_handshake(reader, writer, "session_diag")
assert conn is None
assert server._gms._sessions.state == ServerState.EMPTY
assert writer.closed
@pytest.mark.asyncio
async def test_runtime_state_request_is_rejected_on_live_session():
gms = GMS()
conn, _ = _make_connection(GrantedLockType.RW, "session_4")
gms._sessions._reserved_rw_session_id = conn.session_id
gms.on_connect(conn)
with pytest.raises(OperationNotAllowed):
await gms.handle_request(
conn,
GetRuntimeStateRequest(),
lambda: True,
)
@pytest.mark.asyncio
async def test_event_history_request_is_rejected_on_live_session():
gms = GMS()
conn, _ = _make_connection(GrantedLockType.RW, "session_5")
gms._sessions._reserved_rw_session_id = conn.session_id
gms.on_connect(conn)
with pytest.raises(OperationNotAllowed):
await gms.handle_request(
conn,
GetEventHistoryRequest(),
lambda: True,
)
@pytest.mark.asyncio
async def test_server_refuses_to_bind_over_live_socket(monkeypatch, tmp_path):
socket_path = str(tmp_path / "gms.sock")
listener = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
listener.bind(socket_path)
listener.listen(1)
monkeypatch.setattr(
"gpu_memory_service.server.allocations.cuda_ensure_initialized",
lambda: None,
)
monkeypatch.setattr(
"gpu_memory_service.server.allocations.cumem_get_allocation_granularity",
lambda device: 4096,
)
server = GMSRPCServer(socket_path, device=0)
try:
with pytest.raises(RuntimeError, match="already running"):
await asyncio.wait_for(server.serve(), timeout=0.1)
finally:
listener.close()
if os.path.exists(socket_path):
os.unlink(socket_path)
@pytest.mark.asyncio
async def test_allocate_rejects_non_positive_size_before_cuda():
manager = _make_allocation_manager()
with pytest.raises(ValueError, match="size must be > 0"):
await manager.allocate(0, tag="weights", is_connected=None)
@pytest.mark.asyncio
async def test_allocate_aborts_retry_when_writer_disconnects(monkeypatch):
manager = _make_allocation_manager()
checks = 0
def is_connected() -> bool:
nonlocal checks
checks += 1
return checks < 2
monkeypatch.setattr(
"gpu_memory_service.server.allocations.cumem_create_tolerate_oom",
lambda size, device: (False, 0),
)
with pytest.raises(
ConnectionAbortedError, match="RW client disconnected during allocation retry"
):
await manager.allocate(1, tag="weights", is_connected=is_connected)
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import sys
import types
import gpu_memory_service.integrations.sglang.patches as sglang_patches
import pytest
torch = pytest.importorskip("torch", reason="torch is required")
pytestmark = [
pytest.mark.pre_merge,
pytest.mark.unit,
pytest.mark.gpu_0,
pytest.mark.sglang,
]
def test_patch_model_runner_rewrites_total_gpu_memory(monkeypatch):
fake_sglang = types.ModuleType("sglang")
fake_srt = types.ModuleType("sglang.srt")
fake_executor = types.ModuleType("sglang.srt.model_executor")
fake_model_runner = types.ModuleType("sglang.srt.model_executor.model_runner")
class FakeModelRunner:
def init_memory_pool(self, total_gpu_memory):
self.seen_total_gpu_memory = total_gpu_memory
return total_gpu_memory
fake_model_runner.ModelRunner = FakeModelRunner
fake_sglang.srt = fake_srt
fake_srt.model_executor = fake_executor
fake_executor.model_runner = fake_model_runner
fake_memory_saver = types.ModuleType(
"gpu_memory_service.integrations.sglang.memory_saver"
)
class FakeImpl:
imported_weights_bytes = 8 << 30
fake_memory_saver.get_gms_memory_saver_impl = lambda: FakeImpl()
monkeypatch.setitem(sys.modules, "sglang", fake_sglang)
monkeypatch.setitem(sys.modules, "sglang.srt", fake_srt)
monkeypatch.setitem(sys.modules, "sglang.srt.model_executor", fake_executor)
monkeypatch.setitem(
sys.modules,
"sglang.srt.model_executor.model_runner",
fake_model_runner,
)
monkeypatch.setitem(
sys.modules,
"gpu_memory_service.integrations.sglang.memory_saver",
fake_memory_saver,
)
monkeypatch.setattr(
sglang_patches,
"get_gms_memory_saver_impl",
lambda: FakeImpl(),
)
monkeypatch.setattr(sglang_patches, "_model_runner_patched", False)
monkeypatch.delattr(FakeModelRunner, "_gms_patched", raising=False)
monkeypatch.setattr(
sglang_patches.torch.cuda,
"current_device",
lambda: 0,
)
monkeypatch.setattr(
sglang_patches.torch.cuda,
"get_device_properties",
lambda device: types.SimpleNamespace(total_memory=80 * (1 << 30)),
)
sglang_patches.patch_model_runner()
runner = FakeModelRunner()
assert runner.init_memory_pool(40.0) == pytest.approx(80.0)
assert runner.seen_total_gpu_memory == pytest.approx(80.0)
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import pytest
from gpu_memory_service.client.torch import allocator as allocator_module
from gpu_memory_service.common.locks import GrantedLockType, RequestedLockType
pytestmark = [
pytest.mark.pre_merge,
pytest.mark.unit,
pytest.mark.none,
pytest.mark.gpu_0,
]
class _FakeManager:
def __init__(self, socket_path: str, *, device: int):
self.socket_path = socket_path
self.device = device
self._connected = False
self._granted_lock_type: GrantedLockType | None = None
self._mappings: dict[int, object] = {}
self._unmapped = False
@property
def granted_lock_type(self) -> GrantedLockType | None:
return self._granted_lock_type
@property
def is_connected(self) -> bool:
return self._connected
@property
def mappings(self) -> dict[int, object]:
return self._mappings
@property
def is_unmapped(self) -> bool:
return self._unmapped
def connect(self, mode: RequestedLockType, timeout_ms: int | None = None) -> None:
del timeout_ms
self._connected = True
if mode == RequestedLockType.RW:
self._granted_lock_type = GrantedLockType.RW
return
self._granted_lock_type = GrantedLockType.RO
@property
def _client_rpc(self):
return self
def get_lock_state(self) -> object:
return object()
@pytest.fixture(autouse=True)
def clear_tag_states():
allocator_module._tag_states.clear()
yield
allocator_module._tag_states.clear()
@pytest.fixture
def fake_allocator(monkeypatch):
monkeypatch.setattr(
"gpu_memory_service.client.memory_manager.GMSClientMemoryManager",
_FakeManager,
)
monkeypatch.setattr(
allocator_module,
"_ensure_callbacks_initialized",
lambda: None,
)
monkeypatch.setattr(allocator_module, "_create_mem_pool", lambda: object())
def test_tag_registry_rejects_socket_or_device_mismatch(fake_allocator):
manager = allocator_module.get_or_create_gms_client_memory_manager(
"/tmp/weights.sock",
0,
RequestedLockType.RO,
tag="weights",
)
with pytest.raises(RuntimeError, match="initialized for"):
allocator_module.get_or_create_gms_client_memory_manager(
"/tmp/other.sock",
0,
RequestedLockType.RO,
tag="weights",
)
with pytest.raises(RuntimeError, match="initialized for"):
allocator_module.get_or_create_gms_client_memory_manager(
"/tmp/weights.sock",
1,
RequestedLockType.RO,
tag="weights",
)
assert manager.is_connected
def test_tag_registry_recreates_disconnected_empty_manager(fake_allocator):
first = allocator_module.get_or_create_gms_client_memory_manager(
"/tmp/weights.sock",
0,
RequestedLockType.RO,
tag="weights",
)
first._connected = False
first._granted_lock_type = None
second = allocator_module.get_or_create_gms_client_memory_manager(
"/tmp/weights.sock",
0,
RequestedLockType.RO,
tag="weights",
)
assert second is not first
assert second.is_connected
assert second.granted_lock_type == GrantedLockType.RO
def test_tag_registry_rejects_disconnected_manager_with_preserved_state(
fake_allocator,
):
manager = allocator_module.get_or_create_gms_client_memory_manager(
"/tmp/weights.sock",
0,
RequestedLockType.RO,
tag="weights",
)
manager._connected = False
manager._mappings[0x1000] = object()
with pytest.raises(RuntimeError, match="preserved state"):
allocator_module.get_or_create_gms_client_memory_manager(
"/tmp/weights.sock",
0,
RequestedLockType.RO,
tag="weights",
)
def test_close_evicts_manager_from_tag_registry(fake_allocator):
manager = allocator_module.get_or_create_gms_client_memory_manager(
"/tmp/weights.sock",
0,
RequestedLockType.RO,
tag="weights",
)
allocator_module.evict_gms_client_memory_manager(manager)
assert allocator_module.get_gms_client_memory_manager("weights") is None
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Pytest configuration for GPU Memory Service tests."""
import pytest
from tests.utils.port_utils import allocate_port, deallocate_ports # noqa: E402
@pytest.fixture
def gms_ports():
"""Allocate ports for GMS tests.
Returns a dict with ports for:
- frontend: Frontend HTTP port
- shadow_system: System port for the first shadow engine
- shadow2_system: System port for the second shadow engine
- primary_system: System port for primary engine (failover test only)
- shadow_kv_event: KV event port for the first shadow engine (vLLM)
- shadow2_kv_event: KV event port for the second shadow engine (vLLM)
- primary_kv_event: KV event port for primary engine (vLLM)
- shadow_nixl: NIXL side channel port for the first shadow engine (vLLM)
- shadow2_nixl: NIXL side channel port for the second shadow engine (vLLM)
- primary_nixl: NIXL side channel port for primary engine (vLLM)
- shadow_sglang: SGLang HTTP port for the first shadow engine
- shadow2_sglang: SGLang HTTP port for the second shadow engine
- primary_sglang: SGLang HTTP port for primary engine
"""
ports = [
allocate_port(p)
for p in [
8200,
8100,
8101,
8102,
20080,
20081,
20082,
20096,
20097,
20098,
30000,
30001,
30002,
]
]
yield {
"frontend": ports[0],
"shadow_system": ports[1],
"primary_system": ports[2],
"shadow2_system": ports[3],
"shadow_kv_event": ports[4],
"primary_kv_event": ports[5],
"shadow2_kv_event": ports[6],
"shadow_nixl": ports[7],
"primary_nixl": ports[8],
"shadow2_nixl": ports[9],
"shadow_sglang": ports[10],
"primary_sglang": ports[11],
"shadow2_sglang": ports[12],
}
deallocate_ports(ports)
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import argparse
import os
import socket
import subprocess
from contextlib import contextmanager
import torch
from gpu_memory_service.client.memory_manager import GMSClientMemoryManager
from gpu_memory_service.client.torch.allocator import (
get_or_create_gms_client_memory_manager,
gms_use_mem_pool,
)
from gpu_memory_service.client.torch.module import register_module_tensors
from gpu_memory_service.common.locks import RequestedLockType
from gpu_memory_service.common.utils import get_socket_path
from tests.utils.constants import FAULT_TOLERANCE_MODEL_NAME
from .runtime import DYNAMO_BIN, REPO_ROOT
SGLANG_BIN = REPO_ROOT / "dynamo-sglang" / "bin"
def get_external_weight_writer_command(backend: str) -> list[str]:
return [
"python",
"-m",
"tests.gms.harness.external_weight_writer",
"--backend",
backend,
]
def get_external_weight_writer_env(backend: str) -> dict[str, str]:
if backend == "sglang":
return {
**os.environ,
"PATH": f"/usr/local/cuda/bin:{SGLANG_BIN}:{os.environ.get('PATH', '')}",
"CC": "/usr/bin/gcc",
"CXX": "/usr/bin/g++",
"PYTHONPATH": str(REPO_ROOT),
}
return {
**os.environ,
"PATH": f"{DYNAMO_BIN}:{os.environ.get('PATH', '')}",
"PYTHONPATH": str(REPO_ROOT),
}
def run_external_weight_writer(backend: str) -> None:
command = get_external_weight_writer_command(backend)
subprocess.run(
command,
cwd=REPO_ROOT,
env=get_external_weight_writer_env(backend),
check=True,
)
def _get_writer_manager(device: int, tag: str) -> GMSClientMemoryManager:
return get_or_create_gms_client_memory_manager(
get_socket_path(device, tag),
device,
RequestedLockType.RW,
tag=tag,
)
def _publish_model(manager: GMSClientMemoryManager, model: torch.nn.Module) -> None:
register_module_tensors(manager, model)
torch.cuda.synchronize()
manager.commit()
manager.close()
@contextmanager
def _vllm_single_rank_distributed(device: int):
from vllm.distributed.parallel_state import (
destroy_distributed_environment,
destroy_model_parallel,
ensure_model_parallel_initialized,
init_distributed_environment,
)
probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
probe.bind(("127.0.0.1", 0))
host, port = probe.getsockname()
probe.close()
init_distributed_environment(
world_size=1,
rank=0,
local_rank=device,
distributed_init_method=f"tcp://{host}:{port}",
backend="gloo",
)
ensure_model_parallel_initialized(1, 1)
try:
yield
finally:
destroy_model_parallel()
destroy_distributed_environment()
@contextmanager
def _sglang_single_rank_distributed(device: int):
from sglang.srt.distributed.parallel_state import (
destroy_distributed_environment,
destroy_model_parallel,
init_distributed_environment,
initialize_model_parallel,
)
probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
probe.bind(("127.0.0.1", 0))
host, port = probe.getsockname()
probe.close()
init_distributed_environment(
world_size=1,
rank=0,
local_rank=device,
distributed_init_method=f"tcp://{host}:{port}",
backend="gloo",
)
initialize_model_parallel(1, 1, 1, backend="gloo")
try:
yield
finally:
destroy_model_parallel()
destroy_distributed_environment()
def _publish_vllm_dummy_weights(device: int, tag: str) -> None:
from vllm.config import (
DeviceConfig,
LoadConfig,
ModelConfig,
VllmConfig,
set_current_vllm_config,
)
from vllm.model_executor.model_loader.dummy_loader import DummyModelLoader
from vllm.model_executor.model_loader.utils import (
initialize_model,
process_weights_after_loading,
)
from vllm.utils.torch_utils import set_default_torch_dtype
torch.cuda.set_device(device)
model_config = ModelConfig(model=FAULT_TOLERANCE_MODEL_NAME, enforce_eager=True)
load_config = LoadConfig(load_format="dummy")
device_config = DeviceConfig(device="cuda")
vllm_config = VllmConfig(
model_config=model_config,
device_config=device_config,
load_config=load_config,
)
target_device = torch.device("cuda", device)
manager = _get_writer_manager(device, tag)
with set_current_vllm_config(vllm_config):
with _vllm_single_rank_distributed(device):
with set_default_torch_dtype(model_config.dtype):
with gms_use_mem_pool(tag, target_device):
with target_device:
model = initialize_model(
vllm_config=vllm_config,
model_config=model_config,
)
DummyModelLoader(load_config).load_weights(model, model_config)
process_weights_after_loading(model, model_config, target_device)
_publish_model(manager, model)
def _publish_sglang_dummy_weights(device: int, tag: str) -> None:
from sglang.srt.configs.device_config import DeviceConfig
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.layers.dp_attention import initialize_dp_attention
from sglang.srt.model_loader.loader import LoadConfig, get_model_loader
from sglang.srt.server_args import ServerArgs, set_global_server_args_for_scheduler
torch.cuda.set_device(device)
model_config = ModelConfig(FAULT_TOLERANCE_MODEL_NAME)
device_config = DeviceConfig(device="cuda", gpu_id=device)
load_config = LoadConfig(load_format="dummy")
loader = get_model_loader(load_config, model_config)
manager = _get_writer_manager(device, tag)
server_args = ServerArgs(
model_path=FAULT_TOLERANCE_MODEL_NAME,
load_format="dummy",
device="cuda",
)
set_global_server_args_for_scheduler(server_args)
with _sglang_single_rank_distributed(device):
initialize_dp_attention(server_args, model_config)
with gms_use_mem_pool(tag, torch.device("cuda", device)):
model = loader.load_model(
model_config=model_config,
device_config=device_config,
)
_publish_model(manager, model)
def main() -> None:
parser = argparse.ArgumentParser()
parser.add_argument("--backend", choices=("vllm", "sglang"), required=True)
parser.add_argument("--device", type=int, default=0)
parser.add_argument("--tag", default="weights")
args = parser.parse_args()
if args.backend == "vllm":
_publish_vllm_dummy_weights(args.device, args.tag)
return
_publish_sglang_dummy_weights(args.device, args.tag)
if __name__ == "__main__":
main()
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import asyncio
import os
import shutil
import threading
import time
from concurrent.futures import TimeoutError as FutureTimeoutError
from gpu_memory_service.client.rpc import _GMSRPCTransport
from gpu_memory_service.common.protocol.messages import (
GetEventHistoryRequest,
GetEventHistoryResponse,
GetRuntimeStateRequest,
GetRuntimeStateResponse,
)
from gpu_memory_service.common.utils import get_socket_path
from gpu_memory_service.server.rpc import GMSRPCServer
from tests.utils.managed_process import ManagedProcess
from .runtime import DYNAMO_BIN
def _socket_has_live_gms(socket_path: str) -> bool:
if not os.path.exists(socket_path):
return False
try:
with _GMSRPCTransport(socket_path) as transport:
transport.connect()
transport.request(
GetRuntimeStateRequest(),
GetRuntimeStateResponse,
)
except Exception:
return False
return True
def _prepare_socket_path_for_launch(socket_path: str) -> None:
if not os.path.exists(socket_path):
return
if _socket_has_live_gms(socket_path):
raise RuntimeError(f"GMS already active at {socket_path}")
os.unlink(socket_path)
class GMSServerProcess(ManagedProcess):
def __init__(self, request, device: int, tag: str = "weights"):
self.device = device
self.tag = tag
self.socket_path = get_socket_path(device, tag)
log_dir = f"{request.node.name}_gms_{tag}_{device}"
shutil.rmtree(log_dir, ignore_errors=True)
super().__init__(
command=[
"python",
"-m",
"gpu_memory_service",
"--device",
str(device),
"--tag",
tag,
],
env={
**os.environ,
"PATH": f"{DYNAMO_BIN}:{os.environ.get('PATH', '')}",
"DYN_LOG": "debug",
},
timeout=60,
display_output=True,
terminate_all_matching_process_names=False,
log_dir=log_dir,
display_name=f"gms_{tag}",
health_check_funcs=[self._runtime_state_ready],
)
def __enter__(self):
_prepare_socket_path_for_launch(self.socket_path)
return super().__enter__()
def __exit__(self, exc_type, exc_val, exc_tb):
try:
return super().__exit__(exc_type, exc_val, exc_tb)
finally:
if os.path.exists(self.socket_path) and not _socket_has_live_gms(
self.socket_path
):
os.unlink(self.socket_path)
def _socket_has_live_gms(self) -> bool:
return _socket_has_live_gms(self.socket_path)
def _prepare_socket_path_for_launch(self) -> None:
_prepare_socket_path_for_launch(self.socket_path)
def _runtime_state_ready(self, timeout: float = 30) -> bool:
deadline = time.monotonic() + timeout
while time.monotonic() < deadline:
if not os.path.exists(self.socket_path):
time.sleep(0.1)
continue
try:
self.get_runtime_state()
return True
except Exception:
time.sleep(0.1)
continue
time.sleep(0.1)
return False
def get_runtime_state(self) -> GetRuntimeStateResponse:
with _GMSRPCTransport(self.socket_path) as transport:
transport.connect()
return transport.request(
GetRuntimeStateRequest(),
GetRuntimeStateResponse,
)
def get_event_history(self) -> GetEventHistoryResponse:
with _GMSRPCTransport(self.socket_path) as transport:
transport.connect()
return transport.request(
GetEventHistoryRequest(),
GetEventHistoryResponse,
)
class ServerThread:
def __init__(self, server, socket_path: str):
self.server = server
self.socket_path = socket_path
self._loop: asyncio.AbstractEventLoop | None = None
self._task: asyncio.Task[None] | None = None
self._thread = threading.Thread(target=self._run, daemon=True)
self._exception: BaseException | None = None
def _run(self) -> None:
loop = asyncio.new_event_loop()
self._loop = loop
asyncio.set_event_loop(loop)
self._task = loop.create_task(self.server.serve())
try:
loop.run_until_complete(self._task)
except asyncio.CancelledError:
pass
except BaseException as exc:
self._exception = exc
finally:
pending = [task for task in asyncio.all_tasks(loop) if not task.done()]
for task in pending:
task.cancel()
if pending:
loop.run_until_complete(
asyncio.gather(*pending, return_exceptions=True)
)
loop.close()
def start(self) -> None:
self._thread.start()
deadline = time.monotonic() + 5.0
while True:
if self._exception is not None:
raise self._exception
if self.server._server is not None and os.path.exists(self.socket_path):
try:
with _GMSRPCTransport(self.socket_path) as transport:
transport.connect()
transport.request(
GetRuntimeStateRequest(),
GetRuntimeStateResponse,
)
return
except Exception:
pass
if time.monotonic() > deadline:
raise TimeoutError(f"GMS socket did not appear at {self.socket_path}")
time.sleep(0.01)
def stop(self) -> None:
if self._loop is not None:
def cancel() -> None:
if self.server._server is not None:
self.server._server.close()
if self._task is not None:
self._task.cancel()
self._loop.call_soon_threadsafe(cancel)
self._thread.join(timeout=5)
if self._exception is not None:
raise self._exception
if os.path.exists(self.socket_path):
os.unlink(self.socket_path)
def disconnect_rw_session(self, timeout: float = 5.0) -> None:
if self._loop is None:
raise RuntimeError("GMS server thread is not running")
future = asyncio.run_coroutine_threadsafe(
self._disconnect_rw_session(), self._loop
)
try:
future.result(timeout=timeout)
except FutureTimeoutError as exc:
raise TimeoutError("Timed out disconnecting RW session") from exc
async def _disconnect_rw_session(self) -> None:
conn = self.server._gms._sessions._locking.rw_conn
if conn is None:
raise RuntimeError("No active RW session to disconnect")
await self.server._gms.cleanup_connection(conn)
class ThreadedGMSServer:
def __init__(self, device: int, tag: str = "weights"):
self.device = device
self.tag = tag
self.socket_path = get_socket_path(device, tag)
self.server = GMSRPCServer(self.socket_path, device)
self._thread = ServerThread(self.server, self.socket_path)
def __enter__(self) -> "ThreadedGMSServer":
_prepare_socket_path_for_launch(self.socket_path)
self._thread.start()
return self
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
self._thread.stop()
def get_runtime_state(self) -> GetRuntimeStateResponse:
with _GMSRPCTransport(self.socket_path) as transport:
transport.connect()
return transport.request(
GetRuntimeStateRequest(),
GetRuntimeStateResponse,
)
def get_event_history(self) -> GetEventHistoryResponse:
with _GMSRPCTransport(self.socket_path) as transport:
transport.connect()
return transport.request(
GetEventHistoryRequest(),
GetEventHistoryResponse,
)
def disconnect_rw_session(self) -> None:
self._thread.disconnect_rw_session()
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import logging
import time
from pathlib import Path
import requests
from tests.utils.constants import FAULT_TOLERANCE_MODEL_NAME
logger = logging.getLogger(__name__)
REPO_ROOT = Path(__file__).resolve().parents[3]
DYNAMO_BIN = REPO_ROOT / "dynamo" / "bin"
MIN_EXPECTED_MEMORY_RETURN_FRACTION = 0.6
def get_gpu_memory_used(device: int = 0) -> int:
import pynvml
pynvml.nvmlInit()
try:
handle = pynvml.nvmlDeviceGetHandleByIndex(device)
return pynvml.nvmlDeviceGetMemoryInfo(handle).used
finally:
pynvml.nvmlShutdown()
def send_completion(
port: int,
prompt: str = "Hello",
max_retries: int = 3,
retry_delay: float = 1.0,
) -> dict:
last_error = None
for attempt in range(max_retries):
try:
response = requests.post(
f"http://localhost:{port}/v1/completions",
json={
"model": FAULT_TOLERANCE_MODEL_NAME,
"prompt": prompt,
"max_tokens": 20,
},
timeout=120,
)
response.raise_for_status()
result = response.json()
assert result.get("choices"), "No choices in response"
if attempt > 0:
logger.info("send_completion succeeded after %d attempts", attempt + 1)
return result
except (requests.exceptions.RequestException, AssertionError) as exc:
last_error = exc
if attempt < max_retries - 1:
logger.debug(
"send_completion attempt %d/%d failed: %s",
attempt + 1,
max_retries,
exc,
)
time.sleep(retry_delay)
raise last_error # type: ignore[misc]
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""SGLang-specific utilities for GPU Memory Service tests."""
import logging
import os
import shutil
import requests
from tests.utils.constants import FAULT_TOLERANCE_MODEL_NAME
from tests.utils.managed_process import ManagedProcess
from tests.utils.payloads import check_health_generate, check_models_api
from .runtime import REPO_ROOT
logger = logging.getLogger(__name__)
SGLANG_BIN = REPO_ROOT / "dynamo-sglang" / "bin"
class SGLangWithGMSProcess(ManagedProcess):
"""SGLang engine with GPU Memory Service integration."""
def __init__(
self,
request,
engine_id: str,
system_port: int,
sglang_port: int,
frontend_port: int,
*,
read_only_weights: bool = False,
):
self.engine_id = engine_id
self.system_port = system_port
log_dir = f"{request.node.name}_{engine_id}"
shutil.rmtree(log_dir, ignore_errors=True)
command = [
"python",
"-m",
"dynamo.sglang",
"--model-path",
FAULT_TOLERANCE_MODEL_NAME,
"--load-format",
"gms",
"--enable-memory-saver",
"--mem-fraction-static",
"0.9",
"--port",
str(sglang_port),
]
if read_only_weights:
command.extend(
[
"--model-loader-extra-config",
'{"gms_read_only": true}',
]
)
super().__init__(
command=command,
env={
**os.environ,
"PATH": f"/usr/local/cuda/bin:{SGLANG_BIN}:{os.environ.get('PATH', '')}",
"CC": "/usr/bin/gcc",
"CXX": "/usr/bin/g++",
"DYN_LOG": "debug",
"DYN_SYSTEM_PORT": str(system_port),
},
health_check_urls=[
(f"http://localhost:{system_port}/health", self._is_ready),
(f"http://localhost:{frontend_port}/v1/models", check_models_api),
(f"http://localhost:{frontend_port}/health", check_health_generate),
],
timeout=300,
display_output=True,
terminate_all_matching_process_names=False,
stragglers=[],
log_dir=log_dir,
display_name=engine_id,
)
def _is_ready(self, response) -> bool:
try:
return response.json().get("status") == "ready"
except ValueError:
return False
def sleep(self) -> dict:
"""Put the engine to sleep, offloading weights and KV cache."""
r = requests.post(
f"http://localhost:{self.system_port}/engine/release_memory_occupation",
json={"tags": ["weights", "kv_cache"]},
timeout=30,
)
r.raise_for_status()
logger.info(f"{self.engine_id} release_memory_occupation: {r.json()}")
return r.json()
def wake(self, timeout: int = 30) -> dict:
"""Wake the engine, restoring weights and KV cache."""
r = requests.post(
f"http://localhost:{self.system_port}/engine/resume_memory_occupation",
json={"tags": ["weights", "kv_cache"]},
timeout=timeout,
)
r.raise_for_status()
logger.info(f"{self.engine_id} resume_memory_occupation: {r.json()}")
return r.json()
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""vLLM-specific utilities for GPU Memory Service tests."""
import json
import logging
import os
import shutil
import requests
from tests.utils.constants import FAULT_TOLERANCE_MODEL_NAME
from tests.utils.managed_process import ManagedProcess
from tests.utils.payloads import check_health_generate, check_models_api
from .runtime import DYNAMO_BIN
logger = logging.getLogger(__name__)
class VLLMWithGMSProcess(ManagedProcess):
"""vLLM engine with GPU Memory Service integration."""
def __init__(
self,
request,
engine_id: str,
system_port: int,
kv_event_port: int,
nixl_port: int,
frontend_port: int,
*,
read_only_weights: bool = False,
):
self.engine_id = engine_id
self.system_port = system_port
log_dir = f"{request.node.name}_{engine_id}"
shutil.rmtree(log_dir, ignore_errors=True)
kv_events_cfg = json.dumps(
{
"publisher": "zmq",
"topic": "kv-events",
"endpoint": f"tcp://*:{kv_event_port}",
"enable_kv_cache_events": True,
}
)
command = [
"python",
"-m",
"dynamo.vllm",
"--model",
FAULT_TOLERANCE_MODEL_NAME,
"--load-format",
"gms",
"--enforce-eager",
"--enable-sleep-mode",
"--gpu-memory-utilization",
"0.9",
"--kv-events-config",
kv_events_cfg,
]
if read_only_weights:
command.extend(
[
"--model-loader-extra-config",
json.dumps({"gms_read_only": True}),
]
)
super().__init__(
command=command,
env={
**os.environ,
"PATH": f"{DYNAMO_BIN}:{os.environ.get('PATH', '')}",
"DYN_LOG": "debug",
"DYN_SYSTEM_PORT": str(system_port),
"VLLM_NIXL_SIDE_CHANNEL_PORT": str(nixl_port),
},
health_check_urls=[
(f"http://localhost:{system_port}/health", self._is_ready),
(f"http://localhost:{frontend_port}/v1/models", check_models_api),
(f"http://localhost:{frontend_port}/health", check_health_generate),
],
timeout=300,
display_output=True,
terminate_all_matching_process_names=False,
stragglers=[],
log_dir=log_dir,
display_name=engine_id,
)
def _is_ready(self, response) -> bool:
try:
return response.json().get("status") == "ready"
except ValueError:
return False
def sleep(self) -> dict:
"""Put the engine to sleep, offloading weights and KV cache."""
r = requests.post(
f"http://localhost:{self.system_port}/engine/sleep",
json={"level": 2},
timeout=30,
)
r.raise_for_status()
logger.info(f"{self.engine_id} sleep: {r.json()}")
return r.json()
def wake(self, timeout: int = 30) -> dict:
"""Wake the engine, restoring weights and KV cache."""
r = requests.post(
f"http://localhost:{self.system_port}/engine/wake_up",
json={"tags": ["weights", "kv_cache"]},
timeout=timeout,
)
r.raise_for_status()
logger.info(f"{self.engine_id} wake: {r.json()}")
return r.json()
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import pytest
pytest.importorskip("gpu_memory_service", reason="gpu_memory_service is required")
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import time
from concurrent.futures import ThreadPoolExecutor
from contextlib import ExitStack
from typing import Callable, Protocol
import pytest
from gpu_memory_service.client.session import _GMSClientSession
from gpu_memory_service.common.locks import RequestedLockType
from gpu_memory_service.server.fsm import ServerState
from tests.utils.constants import FAULT_TOLERANCE_MODEL_NAME
from tests.utils.managed_process import DynamoFrontendProcess
from ..harness.gms import GMSServerProcess
from ..harness.runtime import send_completion
from ..harness.sglang import SGLangWithGMSProcess
from ..harness.vllm import VLLMWithGMSProcess
# guard: external_weight_writer imports torch at module level
pytest.importorskip("torch", reason="torch is required")
from ..harness.external_weight_writer import run_external_weight_writer # noqa: E402
pytestmark = [pytest.mark.nightly]
# Event flow under test:
# 1. The engine starts in read-only mode and waits for a committed weights layout.
# 2. An external writer acquires RW on the weights GMS, loads dummy weights, commits, and exits.
# 3. The engine comes online with those committed weights while owning its own KV-cache RW layout.
# 4. The engine sleeps, preserving its weight VAs but dropping the KV-cache layout.
# 5. A second external writer acquires RW on the weights GMS, creates a fresh committed layout with
# different allocation IDs but the same structural layout, and exits.
# 6. The engine wakes, remaps the preserved weight VAs into the new committed layout, recreates its
# KV cache in a new RW layout, and serves inference without a stale-layout error.
class _SleepWakeEngine(Protocol):
def __enter__(self):
...
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
...
def sleep(self) -> dict:
...
def wake(self) -> dict:
...
def _list_committed_weight_allocations(
socket_path: str,
) -> list[tuple[int, str, int, int, str]]:
with _GMSClientSession(
socket_path, RequestedLockType.RO, timeout_ms=None
) as reader:
return [
(
int(info.layout_slot),
str(info.allocation_id),
int(info.size),
int(info.aligned_size),
str(info.tag),
)
for info in reader.list_allocations()
]
def _run_external_weight_mgr_test(
request,
ports: dict,
backend: str,
make_engine: Callable[[], _SleepWakeEngine],
) -> None:
with ExitStack() as stack:
weights_gms = stack.enter_context(
GMSServerProcess(request, device=0, tag="weights")
)
kv_cache_gms = stack.enter_context(
GMSServerProcess(request, device=0, tag="kv_cache")
)
stack.enter_context(
DynamoFrontendProcess(request, frontend_port=ports["frontend"])
)
engine = make_engine()
stack.callback(engine.__exit__, None, None, None)
with ThreadPoolExecutor(max_workers=1) as executor:
start_future = executor.submit(engine.__enter__)
try:
# The read-only engine must stall until some external writer
# publishes the first committed weights layout.
time.sleep(2.0)
assert (
not start_future.done()
), "read-only engine should still be waiting for committed weights"
assert weights_gms.get_runtime_state().state == ServerState.EMPTY
assert kv_cache_gms.get_runtime_state().state == ServerState.EMPTY
# Publish the first weights layout out-of-process, then let the
# engine finish importing those committed weights.
run_external_weight_writer(backend)
start_future.result(timeout=300)
first_weights_state = weights_gms.get_runtime_state()
first_kv_state = kv_cache_gms.get_runtime_state()
first_allocations = _list_committed_weight_allocations(
weights_gms.socket_path
)
assert first_weights_state.state == ServerState.RO
assert first_weights_state.allocation_count > 0
assert first_weights_state.memory_layout_hash
assert first_kv_state.state == ServerState.RW
assert first_allocations
result = send_completion(ports["frontend"])
assert result["choices"]
# Sleep preserves the engine's weight VAs but tears down the KV
# cache layout so the process can wake against a later weights layout.
assert engine.sleep()["status"] == "ok"
deadline = time.monotonic() + 30.0
while True:
weights_after_sleep = weights_gms.get_runtime_state()
kv_after_sleep = kv_cache_gms.get_runtime_state()
if (
weights_after_sleep.state == ServerState.COMMITTED
and weights_after_sleep.memory_layout_hash
== first_weights_state.memory_layout_hash
and weights_after_sleep.ro_session_count == 0
and kv_after_sleep.state == ServerState.EMPTY
and kv_after_sleep.allocation_count == 0
):
break
if time.monotonic() > deadline:
raise TimeoutError(
"engine sleep did not settle the expected GMS state"
)
time.sleep(0.1)
# Publish a second committed weights layout with the same logical
# layout but fresh allocation IDs.
run_external_weight_writer(backend)
deadline = time.monotonic() + 30.0
while True:
second_weights_state = weights_gms.get_runtime_state()
second_kv_state = kv_cache_gms.get_runtime_state()
if (
second_weights_state.state == ServerState.COMMITTED
and second_weights_state.memory_layout_hash
== first_weights_state.memory_layout_hash
and second_weights_state.allocation_count
== first_weights_state.allocation_count
and second_kv_state.state == ServerState.EMPTY
and second_kv_state.allocation_count == 0
):
break
if time.monotonic() > deadline:
raise TimeoutError(
"external writer did not publish a new committed weights layout"
)
time.sleep(0.1)
# The second publish must reuse the same layout slots and sizes so
# RO remap can bind the preserved VAs to new allocations.
second_allocations = _list_committed_weight_allocations(
weights_gms.socket_path
)
assert len(second_allocations) == len(first_allocations)
assert [item[0] for item in second_allocations] == [
item[0] for item in first_allocations
]
assert [item[2:] for item in second_allocations] == [
item[2:] for item in first_allocations
]
assert [item[1] for item in second_allocations] != [
item[1] for item in first_allocations
]
# The weights GMS should show the expected publish progression:
# first publish, old layout cleanup, second publish.
weights_events = weights_gms.get_event_history().events
assert [event.kind for event in weights_events] == [
"rw_connected",
"committed",
"allocations_cleared",
"rw_connected",
"committed",
]
assert (
weights_events[2].allocation_count
== first_weights_state.allocation_count
)
# Wake should remap the preserved RO weight VAs into the new
# committed layout and recreate KV cache in a new RW layout.
assert engine.wake()["status"] == "ok"
deadline = time.monotonic() + 30.0
while True:
weights_after_wake = weights_gms.get_runtime_state()
kv_after_wake = kv_cache_gms.get_runtime_state()
if (
weights_after_wake.state == ServerState.RO
and weights_after_wake.memory_layout_hash
== first_weights_state.memory_layout_hash
and kv_after_wake.state == ServerState.RW
and kv_after_wake.allocation_count > 0
):
break
if time.monotonic() > deadline:
raise TimeoutError(
"engine wake did not restore the expected GMS state"
)
time.sleep(0.1)
# A normal inference after wake proves the remapped weights and
# recreated KV cache are usable end to end.
result = send_completion(ports["frontend"], "updated weights")
assert result["choices"]
finally:
if start_future.done():
try:
start_future.result()
except Exception:
pass
@pytest.mark.vllm
@pytest.mark.e2e
@pytest.mark.gpu_1
@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME)
@pytest.mark.timeout(600)
def test_external_weight_mgr_vllm(
request,
runtime_services_dynamic_ports,
gms_ports,
predownload_models,
):
ports = gms_ports
_run_external_weight_mgr_test(
request,
ports,
"vllm",
make_engine=lambda: VLLMWithGMSProcess(
request,
"engine",
ports["shadow_system"],
ports["shadow_kv_event"],
ports["shadow_nixl"],
ports["frontend"],
read_only_weights=True,
),
)
@pytest.mark.skip(reason="Nightly CI failure: https://linear.app/nvidia/issue/DYN-2567")
@pytest.mark.sglang
@pytest.mark.e2e
@pytest.mark.gpu_1
@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME)
@pytest.mark.timeout(600)
def test_external_weight_mgr_sglang(
request,
runtime_services_dynamic_ports,
gms_ports,
predownload_models,
):
ports = gms_ports
_run_external_weight_mgr_test(
request,
ports,
"sglang",
make_engine=lambda: SGLangWithGMSProcess(
request,
"engine",
ports["shadow_system"],
ports["shadow_sglang"],
ports["frontend"],
read_only_weights=True,
),
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment