Unverified Commit b78ec99a authored by Schwinn Saereesitthipitak's avatar Schwinn Saereesitthipitak Committed by GitHub
Browse files

test: reorganize and prune GPU Memory Service tests (#7828)


Signed-off-by: default avatarSchwinn Saereesitthipitak <schwinns@nvidia.com>
parent 11bb8498
This diff is collapsed.
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import logging
import time
from contextlib import ExitStack
from typing import Callable
import pytest
from gpu_memory_service.server.fsm import ServerState
from tests.utils.constants import FAULT_TOLERANCE_MODEL_NAME
from tests.utils.managed_process import DynamoFrontendProcess, ManagedProcess
from ..harness.gms import GMSServerProcess
from ..harness.runtime import (
MIN_EXPECTED_MEMORY_RETURN_FRACTION,
get_gpu_memory_used,
send_completion,
)
from ..harness.sglang import SGLangWithGMSProcess
from ..harness.vllm import VLLMWithGMSProcess
pytestmark = [pytest.mark.nightly]
# Event flow under test:
# 1. Weights are published once as a committed layout.
# 2. KV cache starts as a live RW layout build.
# 3. Sleep keeps weights committed but aborts and clears the KV layout.
# 4. Wake reconnects weights as RO to the same committed layout.
# 5. Wake recreates KV cache in a fresh RW layout after the old one was cleared.
logger = logging.getLogger(__name__)
def _run_sleep_wake_test(
request,
ports: dict,
make_engine: Callable[[], ManagedProcess],
) -> None:
with ExitStack() as stack:
weights_gms = stack.enter_context(
GMSServerProcess(request, device=0, tag="weights")
)
kv_cache_gms = stack.enter_context(
GMSServerProcess(request, device=0, tag="kv_cache")
)
stack.enter_context(
DynamoFrontendProcess(request, frontend_port=ports["frontend"])
)
with make_engine() as engine:
result = send_completion(ports["frontend"])
logger.info("Initial inference result: %s", result)
assert result["choices"]
# Before sleep, weights must already be published and visible to RO
# readers while KV cache remains a live RW layout owned by the engine.
deadline = time.monotonic() + 30.0
while True:
weights_before_sleep = weights_gms.get_runtime_state()
kv_before_sleep = kv_cache_gms.get_runtime_state()
if (
weights_before_sleep.state == ServerState.RO
and weights_before_sleep.allocation_count > 0
and weights_before_sleep.memory_layout_hash
and kv_before_sleep.state == ServerState.RW
and kv_before_sleep.allocation_count > 0
):
break
if time.monotonic() > deadline:
raise TimeoutError("initial GMS state did not stabilize")
time.sleep(0.1)
mem_before = get_gpu_memory_used()
logger.info("Memory before sleep: %.0f MB", mem_before / (1 << 20))
sleep_result = engine.sleep()
assert sleep_result["status"] == "ok"
mem_after_sleep = get_gpu_memory_used()
released_bytes = mem_before - mem_after_sleep
logger.info("Memory after sleep: %.0f MB", mem_after_sleep / (1 << 20))
assert mem_after_sleep < mem_before, "Sleep should reduce memory"
assert released_bytes > 0
# Sleep preserves the committed weights layout but aborts and clears the
# mutable KV-cache layout, which is what should release GPU memory.
deadline = time.monotonic() + 30.0
while True:
weights_after_sleep = weights_gms.get_runtime_state()
kv_after_sleep = kv_cache_gms.get_runtime_state()
if (
weights_after_sleep.state == ServerState.COMMITTED
and weights_after_sleep.allocation_count
== weights_before_sleep.allocation_count
and weights_after_sleep.memory_layout_hash
== weights_before_sleep.memory_layout_hash
and kv_after_sleep.state == ServerState.EMPTY
and kv_after_sleep.allocation_count == 0
):
break
if time.monotonic() > deadline:
raise TimeoutError(
"sleep did not drive GMS into the expected state"
)
time.sleep(0.1)
# Weights are immutable across sleep/wake, so their event history should
# still be the original publish: connect once, commit once.
weights_events = weights_gms.get_event_history().events
assert [event.kind for event in weights_events] == [
"rw_connected",
"committed",
]
# KV cache is different: sleep must abort the old RW layout and clear its
# server-owned allocations before wake can start a new RW layout.
kv_events = kv_cache_gms.get_event_history().events
assert [event.kind for event in kv_events] == [
"rw_connected",
"rw_aborted",
"allocations_cleared",
]
assert kv_events[-1].allocation_count > 0
wake_result = engine.wake()
assert wake_result["status"] == "ok"
mem_after_wake = get_gpu_memory_used()
reacquired_bytes = mem_after_wake - mem_after_sleep
logger.info("Memory after wake: %.0f MB", mem_after_wake / (1 << 20))
assert mem_after_wake > mem_after_sleep, "Wake should reacquire memory"
assert (
reacquired_bytes
) >= released_bytes * MIN_EXPECTED_MEMORY_RETURN_FRACTION
# Wake reconnects weights as RO to the same committed layout, but KV cache
# must come back as a fresh RW layout with new allocations.
deadline = time.monotonic() + 30.0
while True:
weights_after_wake = weights_gms.get_runtime_state()
kv_after_wake = kv_cache_gms.get_runtime_state()
if (
weights_after_wake.state == ServerState.RO
and weights_after_wake.allocation_count
== weights_before_sleep.allocation_count
and weights_after_wake.memory_layout_hash
== weights_before_sleep.memory_layout_hash
and kv_after_wake.state == ServerState.RW
and kv_after_wake.allocation_count > 0
):
break
if time.monotonic() > deadline:
raise TimeoutError("wake did not restore the expected GMS state")
time.sleep(0.1)
weights_events_after_wake = weights_gms.get_event_history().events
assert [event.kind for event in weights_events_after_wake] == [
"rw_connected",
"committed",
]
# The wake history should therefore extend the old KV sequence with one
# new RW connect after the previous layout was fully cleared.
kv_events_after_wake = kv_cache_gms.get_event_history().events
assert [event.kind for event in kv_events_after_wake] == [
"rw_connected",
"rw_aborted",
"allocations_cleared",
"rw_connected",
]
assert kv_events_after_wake[2].allocation_count > 0
result = send_completion(ports["frontend"], "Goodbye")
logger.info("Post-wake inference result: %s", result)
assert result["choices"]
logger.info(
"Memory freed: %.0f MB", (mem_before - mem_after_sleep) / (1 << 20)
)
@pytest.mark.vllm
@pytest.mark.e2e
@pytest.mark.gpu_1
@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME)
@pytest.mark.timeout(300)
def test_gms_basic_sleep_wake_vllm(
request,
runtime_services_dynamic_ports,
gms_ports,
predownload_models,
):
ports = gms_ports
_run_sleep_wake_test(
request,
ports,
make_engine=lambda: VLLMWithGMSProcess(
request,
"engine",
ports["shadow_system"],
ports["shadow_kv_event"],
ports["shadow_nixl"],
ports["frontend"],
),
)
@pytest.mark.skip(reason="Nightly CI failure: https://linear.app/nvidia/issue/DYN-2567")
@pytest.mark.sglang
@pytest.mark.e2e
@pytest.mark.gpu_1
@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME)
@pytest.mark.timeout(300)
def test_gms_basic_sleep_wake_sglang(
request,
runtime_services_dynamic_ports,
gms_ports,
predownload_models,
):
ports = gms_ports
_run_sleep_wake_test(
request,
ports,
make_engine=lambda: SGLangWithGMSProcess(
request,
"engine",
ports["shadow_system"],
ports["shadow_sglang"],
ports["frontend"],
),
)
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
from contextlib import contextmanager
from types import SimpleNamespace
import pytest
pytestmark = [
pytest.mark.pre_merge,
pytest.mark.unit,
pytest.mark.gpu_0,
pytest.mark.vllm,
]
class _FakeManager:
def __init__(self, *, is_unmapped: bool = False) -> None:
self.is_unmapped = is_unmapped
self.calls: list[object] = []
def unmap_all_vas(self) -> None:
self.calls.append("unmap_all_vas")
self.is_unmapped = True
def abort(self) -> None:
self.calls.append("abort")
def connect(self, lock_type, timeout_ms=None) -> None:
self.calls.append(("connect", lock_type.value))
self.is_unmapped = False
def reallocate_all_handles(self, *, tag: str) -> None:
self.calls.append(("reallocate_all_handles", tag))
def remap_all_vas(self) -> None:
self.calls.append("remap_all_vas")
self.is_unmapped = False
def test_initialize_from_config_uses_kv_cache_gms_tag(monkeypatch):
import gpu_memory_service.integrations.vllm.worker as worker_module
import vllm.distributed.kv_transfer as kv_transfer
from gpu_memory_service.integrations.vllm.worker import GMSWorker
create_calls: list[tuple[object, ...]] = []
pool_calls: list[tuple[str, str]] = []
kv_transfer_calls: list[object] = []
kv_init_calls: list[object] = []
@contextmanager
def fake_use_mem_pool(tag, device):
pool_calls.append((tag, str(device)))
yield
def fake_get_or_create(socket_path, device, mode, *, tag, timeout_ms=None):
create_calls.append((socket_path, device, mode.value, tag, timeout_ms))
return object()
monkeypatch.setattr(worker_module, "gms_use_mem_pool", fake_use_mem_pool)
monkeypatch.setattr(
worker_module,
"get_or_create_gms_client_memory_manager",
fake_get_or_create,
)
monkeypatch.setattr(
worker_module,
"get_socket_path",
lambda device, tag: f"/tmp/{tag}-{device}.sock",
)
monkeypatch.setattr(
kv_transfer,
"ensure_kv_transfer_initialized",
lambda vllm_config, kv_cache_config: kv_transfer_calls.append(kv_cache_config),
)
worker = object.__new__(GMSWorker)
worker.local_rank = 3
worker.vllm_config = SimpleNamespace(
model_config=SimpleNamespace(enable_sleep_mode=True)
)
worker.model_runner = SimpleNamespace(
initialize_kv_cache=lambda kv_cache_config: kv_init_calls.append(
kv_cache_config
)
)
worker.initialize_from_config("kv-config")
assert create_calls == [("/tmp/kv_cache-3.sock", 3, "rw", "kv_cache", None)]
assert pool_calls == [("kv_cache", "cuda:3")]
assert kv_transfer_calls == ["kv-config"]
assert kv_init_calls == ["kv-config"]
def test_sleep_level_2_unmaps_weights_and_kv_cache(monkeypatch):
import gpu_memory_service.integrations.vllm.worker as worker_module
from gpu_memory_service.integrations.vllm.worker import GMSWorker
weights = _FakeManager()
kv_cache = _FakeManager()
monkeypatch.setattr(
worker_module,
"get_gms_client_memory_manager",
lambda tag: weights if tag == "weights" else kv_cache,
)
monkeypatch.setattr(
worker_module.torch.cuda,
"mem_get_info",
lambda: (2 << 30, 8 << 30),
)
worker = object.__new__(GMSWorker)
worker.sleep(level=2)
assert weights.calls == ["unmap_all_vas", "abort"]
assert kv_cache.calls == ["unmap_all_vas", "abort"]
def test_wake_up_remaps_weights_and_reallocates_kv_cache(monkeypatch):
import gpu_memory_service.integrations.vllm.worker as worker_module
from gpu_memory_service.integrations.vllm.worker import GMSWorker
weights = _FakeManager(is_unmapped=True)
kv_cache = _FakeManager(is_unmapped=True)
fp8_calls: list[str] = []
monkeypatch.setattr(
worker_module,
"get_gms_client_memory_manager",
lambda tag: weights if tag == "weights" else kv_cache,
)
worker = object.__new__(GMSWorker)
worker.local_rank = 0
worker.cache_config = SimpleNamespace(cache_dtype="fp8_e4m3")
worker.model_runner = SimpleNamespace(
kv_caches={"layer_0": True},
init_fp8_kv_scales=lambda: fp8_calls.append("fp8"),
)
worker.wake_up(["weights", "kv_cache"])
assert weights.calls == [
("connect", "ro"),
"remap_all_vas",
]
assert kv_cache.calls == [
("connect", "rw"),
("reallocate_all_handles", "kv_cache"),
"remap_all_vas",
]
assert fp8_calls == ["fp8"]
def test_maybe_get_memory_pool_context_routes_tags(monkeypatch):
import gpu_memory_service.integrations.vllm.worker as worker_module
from gpu_memory_service.integrations.vllm.worker import GMSWorker, Worker
kv_cache_context = object()
super_calls: list[str] = []
mem_pool_calls: list[tuple[str, str]] = []
def fake_use_mem_pool(tag, device):
mem_pool_calls.append((tag, str(device)))
return kv_cache_context
def fake_super_context(self, tag):
del self
super_calls.append(tag)
return f"super:{tag}"
monkeypatch.setattr(worker_module, "gms_use_mem_pool", fake_use_mem_pool)
monkeypatch.setattr(Worker, "_maybe_get_memory_pool_context", fake_super_context)
worker = object.__new__(GMSWorker)
worker.local_rank = 2
weights_context = worker._maybe_get_memory_pool_context("weights")
with weights_context:
pass
assert mem_pool_calls == []
assert super_calls == []
assert worker._maybe_get_memory_pool_context("kv_cache") is kv_cache_context
assert mem_pool_calls == [("kv_cache", "cuda:2")]
assert super_calls == []
assert worker._maybe_get_memory_pool_context("other") == "super:other"
assert super_calls == ["other"]
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Socket-level GMS helpers for the cross-component test suite."""
from __future__ import annotations
import asyncio
import os
import socket
import threading
import time
from typing import TYPE_CHECKING
from gpu_memory_service.common.locks import RequestedLockType
from gpu_memory_service.common.utils import get_socket_path
if TYPE_CHECKING:
from gpu_memory_service.common.protocol.messages import (
GetEventHistoryResponse,
GetRuntimeStateResponse,
ListAllocationsResponse,
)
_SERVER_START_TIMEOUT_SECONDS = 30.0
_SERVER_STOP_TIMEOUT_SECONDS = 5.0
_POLL_INTERVAL_SECONDS = 0.1
def _request_gms(
socket_path: str,
request,
response_type,
*,
lock_type: RequestedLockType | None = None,
timeout_ms: int | None = None,
):
"""Send one raw request over a Unix socket, with optional lock handshake."""
from gpu_memory_service.common.protocol.messages import (
ErrorResponse,
HandshakeRequest,
HandshakeResponse,
)
from gpu_memory_service.common.protocol.wire import (
recv_message_sync,
send_message_sync,
)
recv_buffer = bytearray()
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
try:
sock.connect(socket_path)
if lock_type is not None:
send_message_sync(
sock,
HandshakeRequest(lock_type=lock_type, timeout_ms=timeout_ms),
)
handshake, fd, recv_buffer = recv_message_sync(sock, recv_buffer)
if fd >= 0:
os.close(fd)
raise RuntimeError("GMS handshake returned an unexpected FD")
if isinstance(handshake, ErrorResponse):
raise RuntimeError(f"GMS handshake error: {handshake.error}")
if not isinstance(handshake, HandshakeResponse):
raise RuntimeError(
f"Unexpected handshake response: {type(handshake).__name__}"
)
if not handshake.success:
raise TimeoutError("Timeout waiting for GMS lock")
send_message_sync(sock, request)
response, fd, recv_buffer = recv_message_sync(sock, recv_buffer)
if isinstance(response, ErrorResponse):
raise RuntimeError(f"GMS request error: {response.error}")
if not isinstance(response, response_type):
raise RuntimeError(f"Unexpected response type: {type(response).__name__}")
if fd >= 0:
os.close(fd)
raise RuntimeError(
f"GMS request {type(request).__name__} returned an unexpected FD"
)
return response
finally:
sock.close()
def list_allocations(socket_path: str) -> ListAllocationsResponse:
from gpu_memory_service.common.protocol.messages import (
ListAllocationsRequest,
ListAllocationsResponse,
)
return _request_gms(
socket_path,
ListAllocationsRequest(),
ListAllocationsResponse,
lock_type=RequestedLockType.RO,
)
class GMSServer:
"""In-process GMS server wrapper."""
def __init__(self, device: int, tag: str = "weights"):
from gpu_memory_service.server.rpc import GMSRPCServer
self.socket_path = get_socket_path(device, tag)
self.server = GMSRPCServer(self.socket_path, device=device)
self._loop: asyncio.AbstractEventLoop | None = None
self._task: asyncio.Task[None] | None = None
self._thread = threading.Thread(target=self._run, daemon=True)
self._exception: BaseException | None = None
def _run(self) -> None:
loop = asyncio.new_event_loop()
self._loop = loop
asyncio.set_event_loop(loop)
self._task = loop.create_task(self.server.serve())
try:
loop.run_until_complete(self._task)
except asyncio.CancelledError:
pass
except BaseException as exc:
self._exception = exc
finally:
pending = [task for task in asyncio.all_tasks(loop) if not task.done()]
for task in pending:
task.cancel()
if pending:
loop.run_until_complete(
asyncio.gather(*pending, return_exceptions=True)
)
loop.close()
def __enter__(self):
if os.path.exists(self.socket_path):
try:
self.get_runtime_state()
except OSError:
if os.path.exists(self.socket_path):
os.unlink(self.socket_path)
else:
raise RuntimeError(f"GMS already active at {self.socket_path}")
self._thread.start()
deadline = time.monotonic() + _SERVER_START_TIMEOUT_SECONDS
last_probe_error: OSError | None = None
while True:
if self._exception is not None:
raise self._exception
if self.server._server is not None and os.path.exists(self.socket_path):
try:
self.get_runtime_state()
return self
except OSError as exc:
last_probe_error = exc
if time.monotonic() > deadline:
timeout_error = TimeoutError(
f"GMS socket did not appear at {self.socket_path}"
)
if last_probe_error is not None:
raise timeout_error from last_probe_error
raise timeout_error
time.sleep(_POLL_INTERVAL_SECONDS)
def __exit__(self, exc_type, exc_val, exc_tb):
if self._loop is not None:
def cancel() -> None:
if self.server._server is not None:
self.server._server.close()
if self._task is not None:
self._task.cancel()
self._loop.call_soon_threadsafe(cancel)
self._thread.join(timeout=_SERVER_STOP_TIMEOUT_SECONDS)
if self._thread.is_alive():
raise RuntimeError(
f"GMS server thread failed to stop for {self.socket_path}"
)
if self._exception is not None:
raise self._exception
if os.path.exists(self.socket_path):
os.unlink(self.socket_path)
def get_runtime_state(self) -> GetRuntimeStateResponse:
from gpu_memory_service.common.protocol.messages import (
GetRuntimeStateRequest,
GetRuntimeStateResponse,
)
return _request_gms(
self.socket_path,
GetRuntimeStateRequest(),
GetRuntimeStateResponse,
)
def get_event_history(self) -> GetEventHistoryResponse:
from gpu_memory_service.common.protocol.messages import (
GetEventHistoryRequest,
GetEventHistoryResponse,
)
return _request_gms(
self.socket_path,
GetEventHistoryRequest(),
GetEventHistoryResponse,
)
This diff is collapsed.
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import logging
import time
import requests
from gpu_memory_service.server.fsm import ServerState
from tests.gpu_memory_service.common.runtime import get_gpu_memory_used
from tests.utils.client import send_request
from tests.utils.constants import FAULT_TOLERANCE_MODEL_NAME
from tests.utils.payloads import CompletionPayload
logger = logging.getLogger(__name__)
MIN_EXPECTED_MEMORY_RESTORE_FRACTION = 0.9
def assert_completion_ok(
frontend_port: int,
prompt: str,
*,
failure_message: str,
success_message: str,
retry_timeout: float = 0.0,
retry_interval: float = 1.0,
):
completion = CompletionPayload(
body={
"model": FAULT_TOLERANCE_MODEL_NAME,
"prompt": prompt,
"max_tokens": 20,
},
expected_response=[],
expected_log=[],
timeout=120,
port=frontend_port,
)
deadline = time.monotonic() + retry_timeout
while True:
response = send_request(
url=completion.url(),
payload=completion.body,
timeout=completion.timeout,
method=completion.method,
)
try:
completion.process_response(response)
result = response.json()
if not isinstance(result, dict) or not result.get("choices"):
raise AssertionError(failure_message)
logger.info("%s: %s", success_message, result)
return
except (AssertionError, KeyError, requests.RequestException, ValueError):
if time.monotonic() >= deadline:
raise
time.sleep(retry_interval)
def assert_memory_restored_after_quiesce(
label: str,
quiesced_memory: int,
active_memory: int,
released_bytes: int,
*,
min_fraction: float = MIN_EXPECTED_MEMORY_RESTORE_FRACTION,
) -> None:
restored_bytes = active_memory - quiesced_memory
logger.info(
"%s: %.2f GiB (restored %.0f MB)",
label,
active_memory / (1 << 30),
restored_bytes / (1 << 20),
)
assert active_memory > quiesced_memory
assert restored_bytes >= released_bytes * min_fraction
def quiesce_engine(
weights_gms,
kv_cache_gms,
engine,
*,
quiesce_label: str,
expected_weights_hash: str | None = None,
):
weights_state, _ = wait_for_active_layout(
weights_gms,
kv_cache_gms,
expected_weights_hash=expected_weights_hash,
)
memory_before_quiesce = get_gpu_memory_used()
assert engine.quiesce()["status"] == "ok"
memory_after_quiesce = get_gpu_memory_used()
released_bytes = memory_before_quiesce - memory_after_quiesce
logger.info(
"%s: %.2f -> %.2f GiB (freed %.0f MB)",
quiesce_label,
memory_before_quiesce / (1 << 30),
memory_after_quiesce / (1 << 30),
released_bytes / (1 << 20),
)
assert memory_after_quiesce < memory_before_quiesce
assert released_bytes > 0
wait_for_quiesced_layout(weights_gms, kv_cache_gms, weights_state)
return weights_state, released_bytes, memory_after_quiesce
def wait_for_active_layout(
weights_gms,
kv_cache_gms,
*,
expected_weights_hash: str | None = None,
min_weight_ro_sessions: int = 0,
timeout: float = 30.0,
):
deadline = time.monotonic() + timeout
while True:
weights_state = weights_gms.get_runtime_state()
kv_state = kv_cache_gms.get_runtime_state()
if (
weights_state.state == ServerState.RO
and weights_state.ro_session_count >= min_weight_ro_sessions
and weights_state.allocation_count > 0
and weights_state.memory_layout_hash
and kv_state.state == ServerState.RW
and kv_state.allocation_count > 0
):
if (
expected_weights_hash is None
or weights_state.memory_layout_hash == expected_weights_hash
):
return weights_state, kv_state
if time.monotonic() > deadline:
raise TimeoutError("GMS state did not reach the active layout")
time.sleep(0.1)
def wait_for_quiesced_layout(
weights_gms,
kv_cache_gms,
weights_state_before_quiesce,
*,
require_no_ro_sessions: bool = False,
timeout: float = 30.0,
):
deadline = time.monotonic() + timeout
while True:
weights_after_quiesce = weights_gms.get_runtime_state()
kv_after_quiesce = kv_cache_gms.get_runtime_state()
if (
weights_after_quiesce.state == ServerState.COMMITTED
and weights_after_quiesce.allocation_count
== weights_state_before_quiesce.allocation_count
and weights_after_quiesce.memory_layout_hash
== weights_state_before_quiesce.memory_layout_hash
and kv_after_quiesce.state == ServerState.EMPTY
and kv_after_quiesce.allocation_count == 0
):
if (
not require_no_ro_sessions
or weights_after_quiesce.ro_session_count == 0
):
return weights_after_quiesce, kv_after_quiesce
if time.monotonic() > deadline:
raise TimeoutError("GMS state did not reach the quiesced layout")
time.sleep(0.1)
def wait_for_resumed_layout(
weights_gms,
kv_cache_gms,
weights_state_before_quiesce,
*,
min_weight_ro_sessions: int = 0,
timeout: float = 30.0,
):
deadline = time.monotonic() + timeout
while True:
weights_after_resume = weights_gms.get_runtime_state()
kv_after_resume = kv_cache_gms.get_runtime_state()
if (
weights_after_resume.state == ServerState.RO
and weights_after_resume.ro_session_count >= min_weight_ro_sessions
and weights_after_resume.allocation_count
== weights_state_before_quiesce.allocation_count
and weights_after_resume.memory_layout_hash
== weights_state_before_quiesce.memory_layout_hash
and kv_after_resume.state == ServerState.RW
and kv_after_resume.allocation_count > 0
):
return weights_after_resume, kv_after_resume
if time.monotonic() > deadline:
raise TimeoutError("GMS state did not reach the resumed layout")
time.sleep(0.1)
def assert_weights_published_once(events) -> None:
assert [event.kind for event in events] == ["rw_connected", "committed"]
def assert_kv_history(
events,
*,
cleared_layouts: int,
suffix: list[str] | None = None,
) -> None:
expected_kinds = [
"rw_connected",
"rw_aborted",
"allocations_cleared",
] * cleared_layouts
if suffix is not None:
expected_kinds.extend(suffix)
assert [event.kind for event in events] == expected_kinds
clear_counts = [
event.allocation_count
for event in events
if event.kind == "allocations_cleared"
]
assert len(clear_counts) >= cleared_layouts
assert all(count > 0 for count in clear_counts[:cleared_layouts])
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment