Unverified Commit b78ec99a authored by Schwinn Saereesitthipitak's avatar Schwinn Saereesitthipitak Committed by GitHub
Browse files

test: reorganize and prune GPU Memory Service tests (#7828)


Signed-off-by: default avatarSchwinn Saereesitthipitak <schwinns@nvidia.com>
parent 11bb8498
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import logging
import os
import signal
import time
from concurrent.futures import ThreadPoolExecutor
from contextlib import ExitStack
from typing import Callable
import pytest
from gpu_memory_service.server.fsm import ServerState
from tests.utils.constants import FAULT_TOLERANCE_MODEL_NAME
from tests.utils.managed_process import DynamoFrontendProcess, ManagedProcess
from ..harness.gms import ThreadedGMSServer
from ..harness.runtime import (
MIN_EXPECTED_MEMORY_RETURN_FRACTION,
get_gpu_memory_used,
send_completion,
)
from ..harness.sglang import SGLangWithGMSProcess
from ..harness.vllm import VLLMWithGMSProcess
pytestmark = [pytest.mark.nightly]
# Event flow under test:
# 1. Shadow A starts with committed weights and a live RW KV layout, then sleeps.
# 2. Shadow B starts from the same committed weights layout, then sleeps as well.
# 3. Primary wakes and owns the next RW KV layout.
# 4. Shadow A wakes after a forced primary disconnect and enters a new RW layout.
# 5. Shadow A blocks on allocation_oom until the still-alive primary is killed.
# 6. After primary death, the old KV layout clears and Shadow A finishes wake.
logger = logging.getLogger(__name__)
def _kill_process_group(process: ManagedProcess) -> None:
pid = process.get_pid()
if pid is None:
logger.warning("kill process group: no PID available")
return
try:
os.killpg(os.getpgid(pid), signal.SIGKILL)
except ProcessLookupError:
logger.warning("kill process group: process %d already dead", pid)
return
try:
os.waitpid(pid, 0)
except ChildProcessError:
pass
def _is_process_alive(process: ManagedProcess) -> bool:
pid = process.get_pid()
if pid is None:
return False
try:
os.kill(pid, 0)
except ProcessLookupError:
return False
return True
def _assert_weights_published_once(events) -> None:
assert [event.kind for event in events] == ["rw_connected", "committed"]
def _assert_cleared_rw_layout_prefix(events, cleared_layouts: int) -> None:
expected_prefix = [
"rw_connected",
"rw_aborted",
"allocations_cleared",
] * cleared_layouts
assert [event.kind for event in events[: len(expected_prefix)]] == expected_prefix
clear_counts = [
event.allocation_count
for event in events
if event.kind == "allocations_cleared"
]
assert len(clear_counts) >= cleared_layouts
assert all(count > 0 for count in clear_counts[:cleared_layouts])
def _sleep_shadow(
frontend_port: int,
weights_gms: ThreadedGMSServer,
kv_cache_gms: ThreadedGMSServer,
shadow: ManagedProcess,
expected_weights_hash: str | None = None,
) -> tuple[str, int, int]:
result = send_completion(frontend_port)
assert result["choices"], "Shadow inference failed"
logger.info("Shadow inference OK: %s", result)
deadline = time.monotonic() + 30.0
while True:
weights_state = weights_gms.get_runtime_state()
kv_state = kv_cache_gms.get_runtime_state()
if (
weights_state.state == ServerState.RO
and weights_state.allocation_count > 0
and weights_state.memory_layout_hash
and kv_state.state == ServerState.RW
and kv_state.allocation_count > 0
):
break
if time.monotonic() > deadline:
raise TimeoutError("shadow startup did not stabilize GMS state")
time.sleep(0.1)
if expected_weights_hash is not None:
assert weights_state.memory_layout_hash == expected_weights_hash
shadow_memory_before_sleep = get_gpu_memory_used()
assert shadow.sleep()["status"] == "ok"
shadow_memory_after_sleep = get_gpu_memory_used()
shadow_released_bytes = shadow_memory_before_sleep - shadow_memory_after_sleep
logger.info(
"Shadow sleep: %.2f -> %.2f GiB (freed %.0f MB)",
shadow_memory_before_sleep / (1 << 30),
shadow_memory_after_sleep / (1 << 30),
shadow_released_bytes / (1 << 20),
)
assert shadow_memory_after_sleep < shadow_memory_before_sleep
assert shadow_released_bytes > 0
deadline = time.monotonic() + 30.0
while True:
weights_after_sleep = weights_gms.get_runtime_state()
kv_after_sleep = kv_cache_gms.get_runtime_state()
if (
weights_after_sleep.state == ServerState.COMMITTED
and weights_after_sleep.allocation_count == weights_state.allocation_count
and weights_after_sleep.memory_layout_hash
== weights_state.memory_layout_hash
and kv_after_sleep.state == ServerState.EMPTY
and kv_after_sleep.allocation_count == 0
):
break
if time.monotonic() > deadline:
raise TimeoutError("shadow sleep did not clear GMS state")
time.sleep(0.1)
return (
weights_state.memory_layout_hash,
shadow_released_bytes,
shadow_memory_after_sleep,
)
def _run_shadow_failover_test(
request,
ports: dict,
make_shadow_a: Callable[[], ManagedProcess],
make_shadow_b: Callable[[], ManagedProcess],
make_primary: Callable[[], ManagedProcess],
) -> None:
frontend_port = ports["frontend"]
with ExitStack() as stack:
weights_gms = stack.enter_context(ThreadedGMSServer(device=0, tag="weights"))
kv_cache_gms = stack.enter_context(ThreadedGMSServer(device=0, tag="kv_cache"))
stack.enter_context(
DynamoFrontendProcess(
request,
frontend_port=frontend_port,
display_name="frontend",
)
)
with make_shadow_a() as shadow_a:
(
weights_hash,
shadow_a_released_bytes,
_shadow_a_memory_after_sleep,
) = _sleep_shadow(frontend_port, weights_gms, kv_cache_gms, shadow_a)
with make_shadow_b() as shadow_b:
(
sleeping_weights_hash,
_shadow_b_released_bytes,
sleeping_memory_after_sleep,
) = _sleep_shadow(
frontend_port,
weights_gms,
kv_cache_gms,
shadow_b,
expected_weights_hash=weights_hash,
)
assert sleeping_weights_hash == weights_hash
weights_events_after_shadow_sleep = (
weights_gms.get_event_history().events
)
_assert_weights_published_once(weights_events_after_shadow_sleep)
kv_events_after_shadow_sleep = kv_cache_gms.get_event_history().events
_assert_cleared_rw_layout_prefix(kv_events_after_shadow_sleep, 2)
with make_primary() as primary:
result = send_completion(frontend_port, "Primary test")
assert result["choices"], "Primary inference failed"
logger.info("Primary inference OK: %s", result)
primary_memory_in_use = get_gpu_memory_used()
logger.info(
"Primary active memory: %.2f GiB",
primary_memory_in_use / (1 << 30),
)
assert primary_memory_in_use > sleeping_memory_after_sleep
assert (
(primary_memory_in_use - sleeping_memory_after_sleep)
>= shadow_a_released_bytes * MIN_EXPECTED_MEMORY_RETURN_FRACTION
)
deadline = time.monotonic() + 30.0
while True:
weights_with_primary = weights_gms.get_runtime_state()
kv_with_primary = kv_cache_gms.get_runtime_state()
if (
weights_with_primary.state == ServerState.RO
and weights_with_primary.ro_session_count >= 1
and weights_with_primary.allocation_count > 0
and weights_with_primary.memory_layout_hash == weights_hash
and kv_with_primary.state == ServerState.RW
and kv_with_primary.allocation_count > 0
):
break
if time.monotonic() > deadline:
raise TimeoutError(
"primary did not acquire KV cache GMS state"
)
time.sleep(0.1)
expected_kv_kinds_before_disconnect = [
"rw_connected",
"rw_aborted",
"allocations_cleared",
"rw_connected",
"rw_aborted",
"allocations_cleared",
"rw_connected",
]
assert [
event.kind for event in kv_cache_gms.get_event_history().events
] == expected_kv_kinds_before_disconnect
with ThreadPoolExecutor(max_workers=1) as executor:
# Shadow A wakes while Shadow B remains asleep. After we
# force-disconnect the primary from GMS, Shadow A should enter
# a new RW layout but block on real CUDA OOM until the primary dies.
wake_future = executor.submit(shadow_a.wake, 180)
deadline = time.monotonic() + 10.0
while time.monotonic() < deadline:
if wake_future.done():
break
time.sleep(0.2)
assert not wake_future.done(), (
"Shadow wake completed before the primary died; "
"KV cache RW handoff did not block as expected"
)
kv_while_blocked = kv_cache_gms.get_runtime_state()
assert kv_while_blocked.state == ServerState.RW
assert kv_while_blocked.allocation_count > 0
kv_cache_gms.disconnect_rw_session()
expected_kv_kinds_while_blocked = (
expected_kv_kinds_before_disconnect
+ [
"rw_aborted",
"allocations_cleared",
"rw_connected",
"allocation_oom",
]
)
blocked_allocation_count: int | None = None
deadline = time.monotonic() + 30.0
while time.monotonic() < deadline:
kv_after_forced_disconnect = (
kv_cache_gms.get_runtime_state()
)
kv_events_after_forced_disconnect = (
kv_cache_gms.get_event_history().events
)
if (
kv_after_forced_disconnect.state == ServerState.RW
and [
event.kind
for event in kv_events_after_forced_disconnect
]
== expected_kv_kinds_while_blocked
and not wake_future.done()
):
blocked_allocation_count = (
kv_after_forced_disconnect.allocation_count
)
if (
blocked_allocation_count
< kv_while_blocked.allocation_count
and blocked_allocation_count
== kv_events_after_forced_disconnect[
-1
].allocation_count
):
break
time.sleep(0.2)
else:
raise TimeoutError(
"shadow never entered a new KV-cache layout blocked on allocation"
)
assert blocked_allocation_count is not None
linger_deadline = time.monotonic() + 3.0
while time.monotonic() < linger_deadline:
kv_while_lingering = kv_cache_gms.get_runtime_state()
kv_events_while_lingering = (
kv_cache_gms.get_event_history().events
)
assert kv_while_lingering.state == ServerState.RW
assert (
kv_while_lingering.allocation_count
== blocked_allocation_count
)
assert [
event.kind for event in kv_events_while_lingering
] == expected_kv_kinds_while_blocked
assert _is_process_alive(
primary
), "primary died before the linger window completed"
assert (
not wake_future.done()
), "shadow wake completed while the primary was still alive"
time.sleep(0.2)
primary_memory_before_kill = get_gpu_memory_used()
_kill_process_group(primary)
primary_memory_after_kill = get_gpu_memory_used()
logger.info(
"Primary kill snapshot: %.2f -> %.2f GiB",
primary_memory_before_kill / (1 << 30),
primary_memory_after_kill / (1 << 30),
)
deadline = time.monotonic() + 30.0
while time.monotonic() < deadline:
kv_after_primary_kill = kv_cache_gms.get_runtime_state()
if (
kv_after_primary_kill.state == ServerState.RW
and kv_after_primary_kill.allocation_count > 0
):
break
time.sleep(0.2)
else:
raise TimeoutError(
"shadow did not reacquire KV cache after failover"
)
wake_result = wake_future.result(timeout=180)
assert wake_result["status"] == "ok"
shadow_memory_after_wake = get_gpu_memory_used()
shadow_reacquired_bytes = (
shadow_memory_after_wake - sleeping_memory_after_sleep
)
logger.info(
"Shadow wake memory: %.2f GiB (reacquired %.0f MB)",
shadow_memory_after_wake / (1 << 30),
shadow_reacquired_bytes / (1 << 20),
)
assert shadow_memory_after_wake > sleeping_memory_after_sleep
assert (
shadow_reacquired_bytes
) >= shadow_a_released_bytes * MIN_EXPECTED_MEMORY_RETURN_FRACTION
# Once the primary is gone, the failover shadow should finish wake
# with the same committed weights layout and a new live RW KV-cache layout.
deadline = time.monotonic() + 30.0
while True:
weights_after_wake = weights_gms.get_runtime_state()
kv_after_wake = kv_cache_gms.get_runtime_state()
if (
weights_after_wake.state == ServerState.RO
and weights_after_wake.ro_session_count >= 1
and weights_after_wake.allocation_count > 0
and weights_after_wake.memory_layout_hash
== weights_with_primary.memory_layout_hash
and kv_after_wake.state == ServerState.RW
and kv_after_wake.allocation_count > 0
):
break
if time.monotonic() > deadline:
raise TimeoutError(
"shadow wake did not restore the expected GMS state"
)
time.sleep(0.1)
# The final KV history should show the full handoff:
# shadow A slept -> shadow B slept -> primary layout ->
# primary abort/clear -> shadow A reconnects -> shadow A sees OOM.
weights_events_after_wake = weights_gms.get_event_history().events
_assert_weights_published_once(weights_events_after_wake)
kv_events_after_wake = kv_cache_gms.get_event_history().events
_assert_cleared_rw_layout_prefix(kv_events_after_wake, 3)
assert [event.kind for event in kv_events_after_wake] == [
"rw_connected",
"rw_aborted",
"allocations_cleared",
"rw_connected",
"rw_aborted",
"allocations_cleared",
"rw_connected",
"rw_aborted",
"allocations_cleared",
"rw_connected",
"allocation_oom",
]
result = send_completion(frontend_port, "Post failover")
assert result["choices"], "Shadow inference after failover failed"
logger.info("Shadow inference after failover OK: %s", result)
@pytest.mark.vllm
@pytest.mark.e2e
@pytest.mark.gpu_1
@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME)
@pytest.mark.timeout(600)
def test_gms_shadow_engine_failover_vllm(
request, runtime_services_dynamic_ports, gms_ports, predownload_models
):
ports = gms_ports
_run_shadow_failover_test(
request,
ports,
make_shadow_a=lambda: VLLMWithGMSProcess(
request,
"shadow-a",
ports["shadow_system"],
ports["shadow_kv_event"],
ports["shadow_nixl"],
ports["frontend"],
),
make_shadow_b=lambda: VLLMWithGMSProcess(
request,
"shadow-b",
ports["shadow2_system"],
ports["shadow2_kv_event"],
ports["shadow2_nixl"],
ports["frontend"],
),
make_primary=lambda: VLLMWithGMSProcess(
request,
"primary",
ports["primary_system"],
ports["primary_kv_event"],
ports["primary_nixl"],
ports["frontend"],
),
)
@pytest.mark.skip(reason="Nightly CI failure: https://linear.app/nvidia/issue/DYN-2567")
@pytest.mark.sglang
@pytest.mark.e2e
@pytest.mark.gpu_1
@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME)
@pytest.mark.timeout(600)
def test_gms_shadow_engine_failover_sglang(
request, runtime_services_dynamic_ports, gms_ports, predownload_models
):
ports = gms_ports
_run_shadow_failover_test(
request,
ports,
make_shadow_a=lambda: SGLangWithGMSProcess(
request,
"shadow-a",
ports["shadow_system"],
ports["shadow_sglang"],
ports["frontend"],
),
make_shadow_b=lambda: SGLangWithGMSProcess(
request,
"shadow-b",
ports["shadow2_system"],
ports["shadow2_sglang"],
ports["frontend"],
),
make_primary=lambda: SGLangWithGMSProcess(
request,
"primary",
ports["primary_system"],
ports["primary_sglang"],
ports["frontend"],
),
)
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import logging
import time
from contextlib import ExitStack
from typing import Callable
import pytest
from gpu_memory_service.server.fsm import ServerState
from tests.utils.constants import FAULT_TOLERANCE_MODEL_NAME
from tests.utils.managed_process import DynamoFrontendProcess, ManagedProcess
from ..harness.gms import GMSServerProcess
from ..harness.runtime import (
MIN_EXPECTED_MEMORY_RETURN_FRACTION,
get_gpu_memory_used,
send_completion,
)
from ..harness.sglang import SGLangWithGMSProcess
from ..harness.vllm import VLLMWithGMSProcess
pytestmark = [pytest.mark.nightly]
# Event flow under test:
# 1. Weights are published once as a committed layout.
# 2. KV cache starts as a live RW layout build.
# 3. Sleep keeps weights committed but aborts and clears the KV layout.
# 4. Wake reconnects weights as RO to the same committed layout.
# 5. Wake recreates KV cache in a fresh RW layout after the old one was cleared.
logger = logging.getLogger(__name__)
def _run_sleep_wake_test(
request,
ports: dict,
make_engine: Callable[[], ManagedProcess],
) -> None:
with ExitStack() as stack:
weights_gms = stack.enter_context(
GMSServerProcess(request, device=0, tag="weights")
)
kv_cache_gms = stack.enter_context(
GMSServerProcess(request, device=0, tag="kv_cache")
)
stack.enter_context(
DynamoFrontendProcess(request, frontend_port=ports["frontend"])
)
with make_engine() as engine:
result = send_completion(ports["frontend"])
logger.info("Initial inference result: %s", result)
assert result["choices"]
# Before sleep, weights must already be published and visible to RO
# readers while KV cache remains a live RW layout owned by the engine.
deadline = time.monotonic() + 30.0
while True:
weights_before_sleep = weights_gms.get_runtime_state()
kv_before_sleep = kv_cache_gms.get_runtime_state()
if (
weights_before_sleep.state == ServerState.RO
and weights_before_sleep.allocation_count > 0
and weights_before_sleep.memory_layout_hash
and kv_before_sleep.state == ServerState.RW
and kv_before_sleep.allocation_count > 0
):
break
if time.monotonic() > deadline:
raise TimeoutError("initial GMS state did not stabilize")
time.sleep(0.1)
mem_before = get_gpu_memory_used()
logger.info("Memory before sleep: %.0f MB", mem_before / (1 << 20))
sleep_result = engine.sleep()
assert sleep_result["status"] == "ok"
mem_after_sleep = get_gpu_memory_used()
released_bytes = mem_before - mem_after_sleep
logger.info("Memory after sleep: %.0f MB", mem_after_sleep / (1 << 20))
assert mem_after_sleep < mem_before, "Sleep should reduce memory"
assert released_bytes > 0
# Sleep preserves the committed weights layout but aborts and clears the
# mutable KV-cache layout, which is what should release GPU memory.
deadline = time.monotonic() + 30.0
while True:
weights_after_sleep = weights_gms.get_runtime_state()
kv_after_sleep = kv_cache_gms.get_runtime_state()
if (
weights_after_sleep.state == ServerState.COMMITTED
and weights_after_sleep.allocation_count
== weights_before_sleep.allocation_count
and weights_after_sleep.memory_layout_hash
== weights_before_sleep.memory_layout_hash
and kv_after_sleep.state == ServerState.EMPTY
and kv_after_sleep.allocation_count == 0
):
break
if time.monotonic() > deadline:
raise TimeoutError(
"sleep did not drive GMS into the expected state"
)
time.sleep(0.1)
# Weights are immutable across sleep/wake, so their event history should
# still be the original publish: connect once, commit once.
weights_events = weights_gms.get_event_history().events
assert [event.kind for event in weights_events] == [
"rw_connected",
"committed",
]
# KV cache is different: sleep must abort the old RW layout and clear its
# server-owned allocations before wake can start a new RW layout.
kv_events = kv_cache_gms.get_event_history().events
assert [event.kind for event in kv_events] == [
"rw_connected",
"rw_aborted",
"allocations_cleared",
]
assert kv_events[-1].allocation_count > 0
wake_result = engine.wake()
assert wake_result["status"] == "ok"
mem_after_wake = get_gpu_memory_used()
reacquired_bytes = mem_after_wake - mem_after_sleep
logger.info("Memory after wake: %.0f MB", mem_after_wake / (1 << 20))
assert mem_after_wake > mem_after_sleep, "Wake should reacquire memory"
assert (
reacquired_bytes
) >= released_bytes * MIN_EXPECTED_MEMORY_RETURN_FRACTION
# Wake reconnects weights as RO to the same committed layout, but KV cache
# must come back as a fresh RW layout with new allocations.
deadline = time.monotonic() + 30.0
while True:
weights_after_wake = weights_gms.get_runtime_state()
kv_after_wake = kv_cache_gms.get_runtime_state()
if (
weights_after_wake.state == ServerState.RO
and weights_after_wake.allocation_count
== weights_before_sleep.allocation_count
and weights_after_wake.memory_layout_hash
== weights_before_sleep.memory_layout_hash
and kv_after_wake.state == ServerState.RW
and kv_after_wake.allocation_count > 0
):
break
if time.monotonic() > deadline:
raise TimeoutError("wake did not restore the expected GMS state")
time.sleep(0.1)
weights_events_after_wake = weights_gms.get_event_history().events
assert [event.kind for event in weights_events_after_wake] == [
"rw_connected",
"committed",
]
# The wake history should therefore extend the old KV sequence with one
# new RW connect after the previous layout was fully cleared.
kv_events_after_wake = kv_cache_gms.get_event_history().events
assert [event.kind for event in kv_events_after_wake] == [
"rw_connected",
"rw_aborted",
"allocations_cleared",
"rw_connected",
]
assert kv_events_after_wake[2].allocation_count > 0
result = send_completion(ports["frontend"], "Goodbye")
logger.info("Post-wake inference result: %s", result)
assert result["choices"]
logger.info(
"Memory freed: %.0f MB", (mem_before - mem_after_sleep) / (1 << 20)
)
@pytest.mark.vllm
@pytest.mark.e2e
@pytest.mark.gpu_1
@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME)
@pytest.mark.timeout(300)
def test_gms_basic_sleep_wake_vllm(
request,
runtime_services_dynamic_ports,
gms_ports,
predownload_models,
):
ports = gms_ports
_run_sleep_wake_test(
request,
ports,
make_engine=lambda: VLLMWithGMSProcess(
request,
"engine",
ports["shadow_system"],
ports["shadow_kv_event"],
ports["shadow_nixl"],
ports["frontend"],
),
)
@pytest.mark.skip(reason="Nightly CI failure: https://linear.app/nvidia/issue/DYN-2567")
@pytest.mark.sglang
@pytest.mark.e2e
@pytest.mark.gpu_1
@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME)
@pytest.mark.timeout(300)
def test_gms_basic_sleep_wake_sglang(
request,
runtime_services_dynamic_ports,
gms_ports,
predownload_models,
):
ports = gms_ports
_run_sleep_wake_test(
request,
ports,
make_engine=lambda: SGLangWithGMSProcess(
request,
"engine",
ports["shadow_system"],
ports["shadow_sglang"],
ports["frontend"],
),
)
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
from contextlib import contextmanager
from types import SimpleNamespace
import pytest
pytestmark = [
pytest.mark.pre_merge,
pytest.mark.unit,
pytest.mark.gpu_0,
pytest.mark.vllm,
]
class _FakeManager:
def __init__(self, *, is_unmapped: bool = False) -> None:
self.is_unmapped = is_unmapped
self.calls: list[object] = []
def unmap_all_vas(self) -> None:
self.calls.append("unmap_all_vas")
self.is_unmapped = True
def abort(self) -> None:
self.calls.append("abort")
def connect(self, lock_type, timeout_ms=None) -> None:
self.calls.append(("connect", lock_type.value))
self.is_unmapped = False
def reallocate_all_handles(self, *, tag: str) -> None:
self.calls.append(("reallocate_all_handles", tag))
def remap_all_vas(self) -> None:
self.calls.append("remap_all_vas")
self.is_unmapped = False
def test_initialize_from_config_uses_kv_cache_gms_tag(monkeypatch):
import gpu_memory_service.integrations.vllm.worker as worker_module
import vllm.distributed.kv_transfer as kv_transfer
from gpu_memory_service.integrations.vllm.worker import GMSWorker
create_calls: list[tuple[object, ...]] = []
pool_calls: list[tuple[str, str]] = []
kv_transfer_calls: list[object] = []
kv_init_calls: list[object] = []
@contextmanager
def fake_use_mem_pool(tag, device):
pool_calls.append((tag, str(device)))
yield
def fake_get_or_create(socket_path, device, mode, *, tag, timeout_ms=None):
create_calls.append((socket_path, device, mode.value, tag, timeout_ms))
return object()
monkeypatch.setattr(worker_module, "gms_use_mem_pool", fake_use_mem_pool)
monkeypatch.setattr(
worker_module,
"get_or_create_gms_client_memory_manager",
fake_get_or_create,
)
monkeypatch.setattr(
worker_module,
"get_socket_path",
lambda device, tag: f"/tmp/{tag}-{device}.sock",
)
monkeypatch.setattr(
kv_transfer,
"ensure_kv_transfer_initialized",
lambda vllm_config, kv_cache_config: kv_transfer_calls.append(kv_cache_config),
)
worker = object.__new__(GMSWorker)
worker.local_rank = 3
worker.vllm_config = SimpleNamespace(
model_config=SimpleNamespace(enable_sleep_mode=True)
)
worker.model_runner = SimpleNamespace(
initialize_kv_cache=lambda kv_cache_config: kv_init_calls.append(
kv_cache_config
)
)
worker.initialize_from_config("kv-config")
assert create_calls == [("/tmp/kv_cache-3.sock", 3, "rw", "kv_cache", None)]
assert pool_calls == [("kv_cache", "cuda:3")]
assert kv_transfer_calls == ["kv-config"]
assert kv_init_calls == ["kv-config"]
def test_sleep_level_2_unmaps_weights_and_kv_cache(monkeypatch):
import gpu_memory_service.integrations.vllm.worker as worker_module
from gpu_memory_service.integrations.vllm.worker import GMSWorker
weights = _FakeManager()
kv_cache = _FakeManager()
monkeypatch.setattr(
worker_module,
"get_gms_client_memory_manager",
lambda tag: weights if tag == "weights" else kv_cache,
)
monkeypatch.setattr(
worker_module.torch.cuda,
"mem_get_info",
lambda: (2 << 30, 8 << 30),
)
worker = object.__new__(GMSWorker)
worker.sleep(level=2)
assert weights.calls == ["unmap_all_vas", "abort"]
assert kv_cache.calls == ["unmap_all_vas", "abort"]
def test_wake_up_remaps_weights_and_reallocates_kv_cache(monkeypatch):
import gpu_memory_service.integrations.vllm.worker as worker_module
from gpu_memory_service.integrations.vllm.worker import GMSWorker
weights = _FakeManager(is_unmapped=True)
kv_cache = _FakeManager(is_unmapped=True)
fp8_calls: list[str] = []
monkeypatch.setattr(
worker_module,
"get_gms_client_memory_manager",
lambda tag: weights if tag == "weights" else kv_cache,
)
worker = object.__new__(GMSWorker)
worker.local_rank = 0
worker.cache_config = SimpleNamespace(cache_dtype="fp8_e4m3")
worker.model_runner = SimpleNamespace(
kv_caches={"layer_0": True},
init_fp8_kv_scales=lambda: fp8_calls.append("fp8"),
)
worker.wake_up(["weights", "kv_cache"])
assert weights.calls == [
("connect", "ro"),
"remap_all_vas",
]
assert kv_cache.calls == [
("connect", "rw"),
("reallocate_all_handles", "kv_cache"),
"remap_all_vas",
]
assert fp8_calls == ["fp8"]
def test_maybe_get_memory_pool_context_routes_tags(monkeypatch):
import gpu_memory_service.integrations.vllm.worker as worker_module
from gpu_memory_service.integrations.vllm.worker import GMSWorker, Worker
kv_cache_context = object()
super_calls: list[str] = []
mem_pool_calls: list[tuple[str, str]] = []
def fake_use_mem_pool(tag, device):
mem_pool_calls.append((tag, str(device)))
return kv_cache_context
def fake_super_context(self, tag):
del self
super_calls.append(tag)
return f"super:{tag}"
monkeypatch.setattr(worker_module, "gms_use_mem_pool", fake_use_mem_pool)
monkeypatch.setattr(Worker, "_maybe_get_memory_pool_context", fake_super_context)
worker = object.__new__(GMSWorker)
worker.local_rank = 2
weights_context = worker._maybe_get_memory_pool_context("weights")
with weights_context:
pass
assert mem_pool_calls == []
assert super_calls == []
assert worker._maybe_get_memory_pool_context("kv_cache") is kv_cache_context
assert mem_pool_calls == [("kv_cache", "cuda:2")]
assert super_calls == []
assert worker._maybe_get_memory_pool_context("other") == "super:other"
assert super_calls == ["other"]
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Socket-level GMS helpers for the cross-component test suite."""
from __future__ import annotations
import asyncio
import os
import socket
import threading
import time
from typing import TYPE_CHECKING
from gpu_memory_service.common.locks import RequestedLockType
from gpu_memory_service.common.utils import get_socket_path
if TYPE_CHECKING:
from gpu_memory_service.common.protocol.messages import (
GetEventHistoryResponse,
GetRuntimeStateResponse,
ListAllocationsResponse,
)
_SERVER_START_TIMEOUT_SECONDS = 30.0
_SERVER_STOP_TIMEOUT_SECONDS = 5.0
_POLL_INTERVAL_SECONDS = 0.1
def _request_gms(
socket_path: str,
request,
response_type,
*,
lock_type: RequestedLockType | None = None,
timeout_ms: int | None = None,
):
"""Send one raw request over a Unix socket, with optional lock handshake."""
from gpu_memory_service.common.protocol.messages import (
ErrorResponse,
HandshakeRequest,
HandshakeResponse,
)
from gpu_memory_service.common.protocol.wire import (
recv_message_sync,
send_message_sync,
)
recv_buffer = bytearray()
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
try:
sock.connect(socket_path)
if lock_type is not None:
send_message_sync(
sock,
HandshakeRequest(lock_type=lock_type, timeout_ms=timeout_ms),
)
handshake, fd, recv_buffer = recv_message_sync(sock, recv_buffer)
if fd >= 0:
os.close(fd)
raise RuntimeError("GMS handshake returned an unexpected FD")
if isinstance(handshake, ErrorResponse):
raise RuntimeError(f"GMS handshake error: {handshake.error}")
if not isinstance(handshake, HandshakeResponse):
raise RuntimeError(
f"Unexpected handshake response: {type(handshake).__name__}"
)
if not handshake.success:
raise TimeoutError("Timeout waiting for GMS lock")
send_message_sync(sock, request)
response, fd, recv_buffer = recv_message_sync(sock, recv_buffer)
if isinstance(response, ErrorResponse):
raise RuntimeError(f"GMS request error: {response.error}")
if not isinstance(response, response_type):
raise RuntimeError(f"Unexpected response type: {type(response).__name__}")
if fd >= 0:
os.close(fd)
raise RuntimeError(
f"GMS request {type(request).__name__} returned an unexpected FD"
)
return response
finally:
sock.close()
def list_allocations(socket_path: str) -> ListAllocationsResponse:
from gpu_memory_service.common.protocol.messages import (
ListAllocationsRequest,
ListAllocationsResponse,
)
return _request_gms(
socket_path,
ListAllocationsRequest(),
ListAllocationsResponse,
lock_type=RequestedLockType.RO,
)
class GMSServer:
"""In-process GMS server wrapper."""
def __init__(self, device: int, tag: str = "weights"):
from gpu_memory_service.server.rpc import GMSRPCServer
self.socket_path = get_socket_path(device, tag)
self.server = GMSRPCServer(self.socket_path, device=device)
self._loop: asyncio.AbstractEventLoop | None = None
self._task: asyncio.Task[None] | None = None
self._thread = threading.Thread(target=self._run, daemon=True)
self._exception: BaseException | None = None
def _run(self) -> None:
loop = asyncio.new_event_loop()
self._loop = loop
asyncio.set_event_loop(loop)
self._task = loop.create_task(self.server.serve())
try:
loop.run_until_complete(self._task)
except asyncio.CancelledError:
pass
except BaseException as exc:
self._exception = exc
finally:
pending = [task for task in asyncio.all_tasks(loop) if not task.done()]
for task in pending:
task.cancel()
if pending:
loop.run_until_complete(
asyncio.gather(*pending, return_exceptions=True)
)
loop.close()
def __enter__(self):
if os.path.exists(self.socket_path):
try:
self.get_runtime_state()
except OSError:
if os.path.exists(self.socket_path):
os.unlink(self.socket_path)
else:
raise RuntimeError(f"GMS already active at {self.socket_path}")
self._thread.start()
deadline = time.monotonic() + _SERVER_START_TIMEOUT_SECONDS
last_probe_error: OSError | None = None
while True:
if self._exception is not None:
raise self._exception
if self.server._server is not None and os.path.exists(self.socket_path):
try:
self.get_runtime_state()
return self
except OSError as exc:
last_probe_error = exc
if time.monotonic() > deadline:
timeout_error = TimeoutError(
f"GMS socket did not appear at {self.socket_path}"
)
if last_probe_error is not None:
raise timeout_error from last_probe_error
raise timeout_error
time.sleep(_POLL_INTERVAL_SECONDS)
def __exit__(self, exc_type, exc_val, exc_tb):
if self._loop is not None:
def cancel() -> None:
if self.server._server is not None:
self.server._server.close()
if self._task is not None:
self._task.cancel()
self._loop.call_soon_threadsafe(cancel)
self._thread.join(timeout=_SERVER_STOP_TIMEOUT_SECONDS)
if self._thread.is_alive():
raise RuntimeError(
f"GMS server thread failed to stop for {self.socket_path}"
)
if self._exception is not None:
raise self._exception
if os.path.exists(self.socket_path):
os.unlink(self.socket_path)
def get_runtime_state(self) -> GetRuntimeStateResponse:
from gpu_memory_service.common.protocol.messages import (
GetRuntimeStateRequest,
GetRuntimeStateResponse,
)
return _request_gms(
self.socket_path,
GetRuntimeStateRequest(),
GetRuntimeStateResponse,
)
def get_event_history(self) -> GetEventHistoryResponse:
from gpu_memory_service.common.protocol.messages import (
GetEventHistoryRequest,
GetEventHistoryResponse,
)
return _request_gms(
self.socket_path,
GetEventHistoryRequest(),
GetEventHistoryResponse,
)
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Shared process orchestration for the cross-component GMS scenarios."""
from __future__ import annotations
import json
import logging
import os
import sys
from abc import ABC, abstractmethod
from contextlib import ExitStack
import pynvml
import requests
from tests.gpu_memory_service.common.gms import GMSServer
from tests.utils.constants import FAULT_TOLERANCE_MODEL_NAME, DefaultPort
from tests.utils.engine_process import EngineProcess
from tests.utils.managed_process import DynamoFrontendProcess
from tests.utils.payloads import check_health_generate, check_models_api
from tests.utils.port_utils import allocate_ports, deallocate_ports
logger = logging.getLogger(__name__)
def get_gpu_memory_used(device: int = 0) -> int:
pynvml.nvmlInit()
try:
handle = pynvml.nvmlDeviceGetHandleByIndex(device)
return pynvml.nvmlDeviceGetMemoryInfo(handle).used
finally:
pynvml.nvmlShutdown()
class GMSProcessManager:
"""Start the shared GMS daemons and frontend for one test scenario."""
def __init__(
self,
request,
engine_cls,
*,
read_only_weights: bool = False,
):
self._request = request
self._engine_cls = engine_cls
self._read_only_weights = read_only_weights
self._stack: ExitStack | None = None
self.frontend_port: int | None = None
self.weights_gms = None
self.kv_cache_gms = None
self._engine_ids: set[str] = set()
self.engines: dict[str, GMSEngineProcess] = {}
def __enter__(self):
stack = ExitStack()
try:
self.weights_gms = stack.enter_context(GMSServer(device=0, tag="weights"))
self.kv_cache_gms = stack.enter_context(GMSServer(device=0, tag="kv_cache"))
frontend = stack.enter_context(
DynamoFrontendProcess(
self._request,
frontend_port=0,
display_name="frontend",
)
)
except Exception:
stack.close()
raise
self._stack = stack
self.frontend_port = frontend.frontend_port
return self
def __exit__(self, exc_type, exc_val, exc_tb):
stack = self._stack
self._stack = None
self.frontend_port = None
self.weights_gms = None
self.kv_cache_gms = None
self._engine_ids.clear()
self.engines.clear()
if stack is None:
return False
return stack.__exit__(exc_type, exc_val, exc_tb)
def create_engine(
self,
engine_id: str,
*,
read_only_weights: bool | None = None,
):
if self._stack is None or self.frontend_port is None:
raise RuntimeError(
"GMSProcessManager must be entered before creating engines"
)
if engine_id in self._engine_ids:
raise ValueError(f"engine {engine_id!r} already requested")
if read_only_weights is None:
read_only_weights = self._read_only_weights
engine = self._engine_cls(
self._request,
self.frontend_port,
engine_id=engine_id,
read_only_weights=read_only_weights,
)
self._engine_ids.add(engine_id)
return engine
def start_engine(
self,
engine_id: str,
*,
read_only_weights: bool | None = None,
):
if self._stack is None:
raise RuntimeError(
"GMSProcessManager must be entered before starting engines"
)
engine = self._stack.enter_context(
self.create_engine(engine_id, read_only_weights=read_only_weights)
)
self.engines[engine_id] = engine
return engine
class GMSEngineProcess(EngineProcess, ABC):
"""Backend process wrapper with a common quiesce/resume surface."""
quiesce_route: str
resume_route: str
def __init__(
self,
request,
engine_id: str,
system_port: int,
frontend_port: int,
reserved_ports: list[int],
*,
read_only_weights: bool = False,
):
self.engine_id = engine_id
self.system_port = system_port
self._reserved_ports = reserved_ports
self.read_only_weights = read_only_weights
super().__init__(
command=self.command(),
env={
**os.environ,
"DYN_LOG": "debug",
"DYN_SYSTEM_PORT": str(system_port),
**self.env_updates(),
},
health_check_urls=[
(f"http://localhost:{system_port}/health", self._is_ready),
(f"http://localhost:{frontend_port}/v1/models", check_models_api),
(f"http://localhost:{frontend_port}/health", check_health_generate),
],
timeout=300,
display_output=True,
terminate_all_matching_process_names=False,
stragglers=[],
log_dir=f"{request.node.name}_{engine_id}",
display_name=engine_id,
)
@abstractmethod
def command(self) -> list[str]:
raise NotImplementedError
def env_updates(self) -> dict[str, str]:
return {}
def model_loader_extra_config(self) -> str | None:
if not self.read_only_weights:
return None
return json.dumps({"gms_read_only": True})
@abstractmethod
def quiesce_payload(self) -> dict:
raise NotImplementedError
def resume_payload(self) -> dict:
return {}
def _is_ready(self, response) -> bool:
try:
return response.json().get("status") == "ready"
except ValueError:
return False
def _request_engine(
self,
route: str,
payload: dict,
timeout: int,
action: str,
) -> dict:
response = requests.post(
f"http://localhost:{self.system_port}/engine/{route}",
json=payload,
timeout=timeout,
)
response.raise_for_status()
result = response.json()
logger.info("%s %s: %s", self.engine_id, action, result)
return result
def quiesce(self) -> dict:
return self._request_engine(
self.quiesce_route,
self.quiesce_payload(),
30,
"quiesce",
)
def resume(self, timeout: int = 30) -> dict:
return self._request_engine(
self.resume_route,
self.resume_payload(),
timeout,
"resume",
)
def __exit__(self, exc_type, exc_val, exc_tb):
try:
return super().__exit__(exc_type, exc_val, exc_tb)
finally:
deallocate_ports(self._reserved_ports)
class VLLMWithGMSProcess(GMSEngineProcess):
quiesce_route = "sleep"
resume_route = "wake_up"
def __init__(
self,
request,
frontend_port: int,
*,
engine_id: str,
read_only_weights: bool = False,
):
reserved_ports = allocate_ports(3, DefaultPort.SYSTEM1.value)
self.kv_event_port = reserved_ports[1]
self.nixl_port = reserved_ports[2]
try:
super().__init__(
request,
engine_id,
reserved_ports[0],
frontend_port,
reserved_ports,
read_only_weights=read_only_weights,
)
except Exception:
deallocate_ports(reserved_ports)
raise
def env_updates(self) -> dict[str, str]:
return {"VLLM_NIXL_SIDE_CHANNEL_PORT": str(self.nixl_port)}
def command(self) -> list[str]:
kv_events_cfg = json.dumps(
{
"publisher": "zmq",
"topic": "kv-events",
"endpoint": f"tcp://*:{self.kv_event_port}",
"enable_kv_cache_events": True,
}
)
command = [
sys.executable,
"-m",
"dynamo.vllm",
"--model",
FAULT_TOLERANCE_MODEL_NAME,
"--load-format",
"gms",
"--enforce-eager",
"--enable-sleep-mode",
"--max-num-seqs",
"1",
"--gpu-memory-utilization",
"0.9",
"--kv-events-config",
kv_events_cfg,
]
extra_config = self.model_loader_extra_config()
if extra_config is not None:
command.extend(
[
"--model-loader-extra-config",
extra_config,
]
)
return command
def quiesce_payload(self) -> dict:
return {"level": 2}
class SGLangWithGMSProcess(GMSEngineProcess):
quiesce_route = "release_memory_occupation"
resume_route = "resume_memory_occupation"
def __init__(
self,
request,
frontend_port: int,
*,
engine_id: str,
read_only_weights: bool = False,
):
reserved_ports = allocate_ports(2, DefaultPort.SYSTEM1.value)
self.serve_port = reserved_ports[1]
try:
super().__init__(
request,
engine_id,
reserved_ports[0],
frontend_port,
reserved_ports,
read_only_weights=read_only_weights,
)
except Exception:
deallocate_ports(reserved_ports)
raise
def command(self) -> list[str]:
command = [
sys.executable,
"-m",
"dynamo.sglang",
"--model-path",
FAULT_TOLERANCE_MODEL_NAME,
"--load-format",
"gms",
"--enable-memory-saver",
"--disable-cuda-graph",
"--mem-fraction-static",
"0.9",
"--port",
str(self.serve_port),
]
extra_config = self.model_loader_extra_config()
if extra_config is not None:
command.extend(
[
"--model-loader-extra-config",
extra_config,
]
)
return command
def env_updates(self) -> dict[str, str]:
return {"NVCC_PREPEND_FLAGS": "-ccbin /usr/bin/g++"}
def quiesce_payload(self) -> dict:
return {}
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import logging
import time
import requests
from gpu_memory_service.server.fsm import ServerState
from tests.gpu_memory_service.common.runtime import get_gpu_memory_used
from tests.utils.client import send_request
from tests.utils.constants import FAULT_TOLERANCE_MODEL_NAME
from tests.utils.payloads import CompletionPayload
logger = logging.getLogger(__name__)
MIN_EXPECTED_MEMORY_RESTORE_FRACTION = 0.9
def assert_completion_ok(
frontend_port: int,
prompt: str,
*,
failure_message: str,
success_message: str,
retry_timeout: float = 0.0,
retry_interval: float = 1.0,
):
completion = CompletionPayload(
body={
"model": FAULT_TOLERANCE_MODEL_NAME,
"prompt": prompt,
"max_tokens": 20,
},
expected_response=[],
expected_log=[],
timeout=120,
port=frontend_port,
)
deadline = time.monotonic() + retry_timeout
while True:
response = send_request(
url=completion.url(),
payload=completion.body,
timeout=completion.timeout,
method=completion.method,
)
try:
completion.process_response(response)
result = response.json()
if not isinstance(result, dict) or not result.get("choices"):
raise AssertionError(failure_message)
logger.info("%s: %s", success_message, result)
return
except (AssertionError, KeyError, requests.RequestException, ValueError):
if time.monotonic() >= deadline:
raise
time.sleep(retry_interval)
def assert_memory_restored_after_quiesce(
label: str,
quiesced_memory: int,
active_memory: int,
released_bytes: int,
*,
min_fraction: float = MIN_EXPECTED_MEMORY_RESTORE_FRACTION,
) -> None:
restored_bytes = active_memory - quiesced_memory
logger.info(
"%s: %.2f GiB (restored %.0f MB)",
label,
active_memory / (1 << 30),
restored_bytes / (1 << 20),
)
assert active_memory > quiesced_memory
assert restored_bytes >= released_bytes * min_fraction
def quiesce_engine(
weights_gms,
kv_cache_gms,
engine,
*,
quiesce_label: str,
expected_weights_hash: str | None = None,
):
weights_state, _ = wait_for_active_layout(
weights_gms,
kv_cache_gms,
expected_weights_hash=expected_weights_hash,
)
memory_before_quiesce = get_gpu_memory_used()
assert engine.quiesce()["status"] == "ok"
memory_after_quiesce = get_gpu_memory_used()
released_bytes = memory_before_quiesce - memory_after_quiesce
logger.info(
"%s: %.2f -> %.2f GiB (freed %.0f MB)",
quiesce_label,
memory_before_quiesce / (1 << 30),
memory_after_quiesce / (1 << 30),
released_bytes / (1 << 20),
)
assert memory_after_quiesce < memory_before_quiesce
assert released_bytes > 0
wait_for_quiesced_layout(weights_gms, kv_cache_gms, weights_state)
return weights_state, released_bytes, memory_after_quiesce
def wait_for_active_layout(
weights_gms,
kv_cache_gms,
*,
expected_weights_hash: str | None = None,
min_weight_ro_sessions: int = 0,
timeout: float = 30.0,
):
deadline = time.monotonic() + timeout
while True:
weights_state = weights_gms.get_runtime_state()
kv_state = kv_cache_gms.get_runtime_state()
if (
weights_state.state == ServerState.RO
and weights_state.ro_session_count >= min_weight_ro_sessions
and weights_state.allocation_count > 0
and weights_state.memory_layout_hash
and kv_state.state == ServerState.RW
and kv_state.allocation_count > 0
):
if (
expected_weights_hash is None
or weights_state.memory_layout_hash == expected_weights_hash
):
return weights_state, kv_state
if time.monotonic() > deadline:
raise TimeoutError("GMS state did not reach the active layout")
time.sleep(0.1)
def wait_for_quiesced_layout(
weights_gms,
kv_cache_gms,
weights_state_before_quiesce,
*,
require_no_ro_sessions: bool = False,
timeout: float = 30.0,
):
deadline = time.monotonic() + timeout
while True:
weights_after_quiesce = weights_gms.get_runtime_state()
kv_after_quiesce = kv_cache_gms.get_runtime_state()
if (
weights_after_quiesce.state == ServerState.COMMITTED
and weights_after_quiesce.allocation_count
== weights_state_before_quiesce.allocation_count
and weights_after_quiesce.memory_layout_hash
== weights_state_before_quiesce.memory_layout_hash
and kv_after_quiesce.state == ServerState.EMPTY
and kv_after_quiesce.allocation_count == 0
):
if (
not require_no_ro_sessions
or weights_after_quiesce.ro_session_count == 0
):
return weights_after_quiesce, kv_after_quiesce
if time.monotonic() > deadline:
raise TimeoutError("GMS state did not reach the quiesced layout")
time.sleep(0.1)
def wait_for_resumed_layout(
weights_gms,
kv_cache_gms,
weights_state_before_quiesce,
*,
min_weight_ro_sessions: int = 0,
timeout: float = 30.0,
):
deadline = time.monotonic() + timeout
while True:
weights_after_resume = weights_gms.get_runtime_state()
kv_after_resume = kv_cache_gms.get_runtime_state()
if (
weights_after_resume.state == ServerState.RO
and weights_after_resume.ro_session_count >= min_weight_ro_sessions
and weights_after_resume.allocation_count
== weights_state_before_quiesce.allocation_count
and weights_after_resume.memory_layout_hash
== weights_state_before_quiesce.memory_layout_hash
and kv_after_resume.state == ServerState.RW
and kv_after_resume.allocation_count > 0
):
return weights_after_resume, kv_after_resume
if time.monotonic() > deadline:
raise TimeoutError("GMS state did not reach the resumed layout")
time.sleep(0.1)
def assert_weights_published_once(events) -> None:
assert [event.kind for event in events] == ["rw_connected", "committed"]
def assert_kv_history(
events,
*,
cleared_layouts: int,
suffix: list[str] | None = None,
) -> None:
expected_kinds = [
"rw_connected",
"rw_aborted",
"allocations_cleared",
] * cleared_layouts
if suffix is not None:
expected_kinds.extend(suffix)
assert [event.kind for event in events] == expected_kinds
clear_counts = [
event.allocation_count
for event in events
if event.kind == "allocations_cleared"
]
assert len(clear_counts) >= cleared_layouts
assert all(count > 0 for count in clear_counts[:cleared_layouts])
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import logging
import pytest
from tests.gpu_memory_service.common.runtime import (
GMSProcessManager,
SGLangWithGMSProcess,
VLLMWithGMSProcess,
get_gpu_memory_used,
)
from tests.gpu_memory_service.flow_assertions import (
assert_completion_ok,
assert_kv_history,
assert_memory_restored_after_quiesce,
assert_weights_published_once,
quiesce_engine,
wait_for_resumed_layout,
)
from tests.utils.constants import FAULT_TOLERANCE_MODEL_NAME
pytestmark = [pytest.mark.nightly, pytest.mark.fault_tolerance]
# Event flow under test:
# 1. Weights are published once as a committed layout.
# 2. KV cache starts as a live RW layout build.
# 3. Quiesce keeps weights committed but aborts and clears the KV layout.
# 4. Resume reconnects weights as RO to the same committed layout.
# 5. Resume recreates KV cache in a fresh RW layout after the old one was cleared.
logger = logging.getLogger(__name__)
def _run_quiesce_resume_test(
request,
engine_cls,
) -> None:
with GMSProcessManager(request, engine_cls) as manager:
frontend_port = manager.frontend_port
weights_gms = manager.weights_gms
kv_cache_gms = manager.kv_cache_gms
engine = manager.start_engine("engine")
assert_completion_ok(
frontend_port,
"Hello",
failure_message="Initial inference failed",
success_message="Initial inference result",
)
# Before quiesce, weights must already be published and visible to RO
# readers while KV cache remains a live RW layout owned by the engine.
weights_before_quiesce, released_bytes, mem_after_quiesce = quiesce_engine(
weights_gms,
kv_cache_gms,
engine,
quiesce_label="Engine quiesce",
)
# Weights are immutable across quiesce/resume, so their event history should
# still be the original publish: connect once, commit once.
weights_events = weights_gms.get_event_history().events
assert_weights_published_once(weights_events)
# KV cache is different: quiesce must abort the old RW layout and clear
# its server-owned allocations before resume can start a new RW layout.
kv_events = kv_cache_gms.get_event_history().events
assert_kv_history(kv_events, cleared_layouts=1)
assert kv_events[-1].allocation_count > 0
resume_result = engine.resume()
assert resume_result["status"] == "ok"
mem_after_resume = get_gpu_memory_used()
assert_memory_restored_after_quiesce(
"Memory after resume",
mem_after_quiesce,
mem_after_resume,
released_bytes,
)
# Resume reconnects weights as RO to the same committed layout, but KV cache
# must come back as a fresh RW layout with new allocations.
wait_for_resumed_layout(
weights_gms,
kv_cache_gms,
weights_before_quiesce,
)
weights_events_after_resume = weights_gms.get_event_history().events
assert_weights_published_once(weights_events_after_resume)
# The resume history should therefore extend the old KV sequence with one
# new RW connect after the previous layout was fully cleared.
kv_events_after_resume = kv_cache_gms.get_event_history().events
assert_kv_history(
kv_events_after_resume,
cleared_layouts=1,
suffix=["rw_connected"],
)
assert kv_events_after_resume[2].allocation_count > 0
assert_completion_ok(
frontend_port,
"Goodbye",
failure_message="Post-resume inference failed",
success_message="Post-resume inference result",
)
logger.info("Memory freed: %.0f MB", released_bytes / (1 << 20))
@pytest.mark.e2e
@pytest.mark.gpu_1
@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME)
@pytest.mark.timeout(300)
@pytest.mark.vllm
def test_gms_basic_quiesce_resume_vllm(
request,
runtime_services_dynamic_ports,
predownload_models,
):
_run_quiesce_resume_test(request, VLLMWithGMSProcess)
@pytest.mark.e2e
@pytest.mark.gpu_1
@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME)
@pytest.mark.timeout(300)
@pytest.mark.sglang
def test_gms_basic_quiesce_resume_sglang(
request,
runtime_services_dynamic_ports,
predownload_models,
):
_run_quiesce_resume_test(request, SGLangWithGMSProcess)
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import logging
import os
import signal
import time
from concurrent.futures import ThreadPoolExecutor
import pytest
from gpu_memory_service.server.fsm import ServerState
from tests.gpu_memory_service.common.runtime import (
GMSProcessManager,
SGLangWithGMSProcess,
VLLMWithGMSProcess,
get_gpu_memory_used,
)
from tests.gpu_memory_service.flow_assertions import (
assert_completion_ok,
assert_kv_history,
assert_memory_restored_after_quiesce,
assert_weights_published_once,
quiesce_engine,
wait_for_active_layout,
wait_for_resumed_layout,
)
from tests.utils.constants import FAULT_TOLERANCE_MODEL_NAME
from tests.utils.managed_process import ManagedProcess
pytestmark = [pytest.mark.nightly, pytest.mark.fault_tolerance]
# Event flow under test:
# 1. Shadow A starts as the initial weights publisher, then quiesces.
# 2. Shadow B starts in read-only mode from the committed weights layout, then quiesces.
# 3. Primary starts in read-only mode and owns the next RW KV layout.
# 4. Shadow A tries to resume while primary still owns the KV-cache RW layout.
# 5. Primary is SIGKILLed; the old KV session clears before its GPU memory is reclaimed.
# 6. Shadow A enters a new RW KV layout, hits allocation_oom, then finishes resume.
logger = logging.getLogger(__name__)
def _kill_process_group(process: ManagedProcess) -> None:
pid = process.get_pid()
if pid is None:
logger.warning("kill process group: no PID available")
return
memory_before_kill = get_gpu_memory_used()
try:
os.killpg(os.getpgid(pid), signal.SIGKILL)
except ProcessLookupError:
logger.warning("kill process group: process %d already dead", pid)
return
try:
os.waitpid(pid, 0)
except ChildProcessError:
pass
memory_after_kill = get_gpu_memory_used()
logger.info(
"Primary kill snapshot: %.2f -> %.2f GiB",
memory_before_kill / (1 << 30),
memory_after_kill / (1 << 30),
)
def _start_primary(
manager,
frontend_port: int,
weights_gms,
kv_cache_gms,
*,
weights_hash: str,
quiesced_memory_after_shadow_b: int,
shadow_a_released_bytes: int,
):
primary = manager.start_engine("primary", read_only_weights=True)
assert_completion_ok(
frontend_port,
"Primary test",
failure_message="Primary inference failed",
success_message="Primary inference OK",
)
primary_memory_in_use = get_gpu_memory_used()
assert_memory_restored_after_quiesce(
"Primary active memory",
quiesced_memory_after_shadow_b,
primary_memory_in_use,
shadow_a_released_bytes,
)
weights_with_primary, _ = wait_for_active_layout(
weights_gms,
kv_cache_gms,
expected_weights_hash=weights_hash,
min_weight_ro_sessions=1,
)
assert_kv_history(
kv_cache_gms.get_event_history().events,
cleared_layouts=2,
suffix=["rw_connected"],
)
return primary, weights_with_primary
def _wait_for_blocked_resume_layout(
kv_cache_gms,
resume_future,
previous_allocation_count: int,
expected_kinds: list[str],
) -> int:
deadline = time.monotonic() + 30.0
while time.monotonic() < deadline:
kv_runtime_state = kv_cache_gms.get_runtime_state()
kv_events = kv_cache_gms.get_event_history().events
if (
kv_runtime_state.state == ServerState.RW
and [event.kind for event in kv_events] == expected_kinds
and not resume_future.done()
):
blocked_allocation_count = kv_runtime_state.allocation_count
if (
blocked_allocation_count < previous_allocation_count
and blocked_allocation_count == kv_events[-1].allocation_count
):
return blocked_allocation_count
time.sleep(0.2)
raise TimeoutError(
"shadow never entered a new KV-cache layout blocked on allocation"
)
def _resume_shadow_after_primary_failover(
shadow: ManagedProcess,
kv_cache_gms,
primary: ManagedProcess,
):
expected_kv_kinds_while_blocked = [
"rw_connected",
"rw_aborted",
"allocations_cleared",
] * 3 + ["rw_connected", "allocation_oom"]
with ThreadPoolExecutor(max_workers=1) as executor:
resume_future = executor.submit(shadow.resume, 180)
deadline = time.monotonic() + 10.0
while time.monotonic() < deadline:
if resume_future.done():
break
time.sleep(0.2)
assert not resume_future.done(), (
"Shadow resume completed before the primary died; "
"KV cache RW handoff did not block as expected"
)
kv_with_primary = kv_cache_gms.get_runtime_state()
assert kv_with_primary.state == ServerState.RW
assert kv_with_primary.allocation_count > 0
_kill_process_group(primary)
_wait_for_blocked_resume_layout(
kv_cache_gms,
resume_future,
kv_with_primary.allocation_count,
expected_kv_kinds_while_blocked,
)
deadline = time.monotonic() + 30.0
while time.monotonic() < deadline:
kv_after_primary_kill = kv_cache_gms.get_runtime_state()
if (
kv_after_primary_kill.state == ServerState.RW
and kv_after_primary_kill.allocation_count > 0
):
break
time.sleep(0.2)
else:
raise TimeoutError("shadow did not reacquire KV cache after failover")
return resume_future.result(timeout=180)
def _run_shadow_failover_test(
request,
engine_cls,
) -> None:
with GMSProcessManager(request, engine_cls) as manager:
frontend_port = manager.frontend_port
weights_gms = manager.weights_gms
kv_cache_gms = manager.kv_cache_gms
shadow_a = manager.start_engine(
"shadow-a",
)
assert_completion_ok(
frontend_port,
"Hello",
failure_message="Shadow inference failed",
success_message="Shadow inference OK",
)
(
weights_state_after_shadow_a,
shadow_a_released_bytes,
_,
) = quiesce_engine(
weights_gms,
kv_cache_gms,
shadow_a,
quiesce_label="Shadow quiesce",
)
weights_hash = weights_state_after_shadow_a.memory_layout_hash
shadow_b = manager.start_engine(
"shadow-b",
read_only_weights=True,
)
assert_completion_ok(
frontend_port,
"Hello",
failure_message="Shadow inference failed",
success_message="Shadow inference OK",
)
(
weights_state_after_shadow_b,
_,
quiesced_memory_after_shadow_b,
) = quiesce_engine(
weights_gms,
kv_cache_gms,
shadow_b,
quiesce_label="Shadow quiesce",
expected_weights_hash=weights_hash,
)
assert weights_state_after_shadow_b.memory_layout_hash == weights_hash
weights_events_after_shadow_quiesce = weights_gms.get_event_history().events
assert_weights_published_once(weights_events_after_shadow_quiesce)
kv_events_after_shadow_quiesce = kv_cache_gms.get_event_history().events
assert_kv_history(kv_events_after_shadow_quiesce, cleared_layouts=2)
primary, weights_with_primary = _start_primary(
manager,
frontend_port,
weights_gms,
kv_cache_gms,
weights_hash=weights_hash,
quiesced_memory_after_shadow_b=quiesced_memory_after_shadow_b,
shadow_a_released_bytes=shadow_a_released_bytes,
)
resume_result = _resume_shadow_after_primary_failover(
shadow_a,
kv_cache_gms,
primary,
)
assert resume_result["status"] == "ok"
shadow_memory_after_resume = get_gpu_memory_used()
assert_memory_restored_after_quiesce(
"Shadow resume memory",
quiesced_memory_after_shadow_b,
shadow_memory_after_resume,
shadow_a_released_bytes,
)
# Once the primary is gone, the failover shadow should finish resume
# with the same committed weights layout and a new live RW KV-cache layout.
wait_for_resumed_layout(
weights_gms,
kv_cache_gms,
weights_with_primary,
min_weight_ro_sessions=1,
)
# The final KV history should show the full handoff:
# shadow A quiesced -> shadow B quiesced -> primary layout ->
# primary abort/clear -> shadow A reconnects -> shadow A sees OOM.
weights_events_after_resume = weights_gms.get_event_history().events
assert_weights_published_once(weights_events_after_resume)
kv_events_after_resume = kv_cache_gms.get_event_history().events
assert_kv_history(
kv_events_after_resume,
cleared_layouts=3,
suffix=["rw_connected", "allocation_oom"],
)
assert_completion_ok(
frontend_port,
"Post failover",
failure_message="Shadow inference after failover failed",
success_message="Shadow inference after failover OK",
retry_timeout=30.0,
)
@pytest.mark.e2e
@pytest.mark.gpu_1
@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME)
@pytest.mark.timeout(600)
@pytest.mark.vllm
def test_gms_shadow_engine_failover_vllm(
request, runtime_services_dynamic_ports, predownload_models
):
_run_shadow_failover_test(request, VLLMWithGMSProcess)
@pytest.mark.e2e
@pytest.mark.gpu_1
@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME)
@pytest.mark.timeout(600)
@pytest.mark.sglang
def test_gms_shadow_engine_failover_sglang(
request, runtime_services_dynamic_ports, predownload_models
):
_run_shadow_failover_test(request, SGLangWithGMSProcess)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment