Unverified commit 91a8c6af, authored by Schwinn Saereesitthipitak and committed by GitHub
Browse files

fix: simplify GMS layout state and harden GPU-backed flows (#7006)


Signed-off-by: mohammedabdulwahhab <furkhan324@berkeley.edu>
Signed-off-by: Hannah Zhang <hannahz@nvidia.com>
Co-authored-by: mohammedabdulwahhab <furkhan324@berkeley.edu>
Co-authored-by: hhzhang16 <54051230+hhzhang16@users.noreply.github.com>
parent dd7ceb4a
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import logging
import time
from contextlib import ExitStack
from typing import Callable
import pytest
from gpu_memory_service.common.types import ServerState
from tests.utils.constants import FAULT_TOLERANCE_MODEL_NAME
from tests.utils.managed_process import DynamoFrontendProcess, ManagedProcess
from ..harness.gms import GMSServerProcess
from ..harness.runtime import (
MIN_EXPECTED_MEMORY_RETURN_FRACTION,
get_gpu_memory_used,
send_completion,
)
from ..harness.sglang import SGLangWithGMSProcess
from ..harness.vllm import VLLMWithGMSProcess
# Module-level pytest markers: every test collected from this file is tagged
# ``nightly``.
pytestmark = [pytest.mark.nightly]
# Event flow under test:
# 1. Weights are published once as a committed layout.
# 2. KV cache starts as a live RW layout build.
# 3. Sleep keeps weights committed but aborts and clears the KV layout.
# 4. Wake reconnects weights as RO to the same committed layout.
# 5. Wake recreates KV cache in a fresh RW layout after the old one was cleared.
# Module logger; the sleep/wake scenario logs inference results and GPU memory
# readings through it.
logger = logging.getLogger(__name__)
def _wait_for_gms_states(
    weights_gms: GMSServerProcess,
    kv_cache_gms: GMSServerProcess,
    is_expected: Callable[..., bool],
    timeout_message: str,
    *,
    timeout_s: float = 30.0,
    poll_interval_s: float = 0.1,
) -> tuple:
    """Poll both GMS servers until ``is_expected(weights, kv)`` is truthy.

    Args:
        weights_gms: server process holding the weights layout.
        kv_cache_gms: server process holding the KV-cache layout.
        is_expected: predicate called with the two freshly fetched runtime
            states (weights first, KV cache second).
        timeout_message: leading text of the ``TimeoutError`` message.
        timeout_s: how long to keep polling before giving up.
        poll_interval_s: pause between two polls.

    Returns:
        The ``(weights_state, kv_state)`` pair that satisfied the predicate.

    Raises:
        TimeoutError: the predicate was still false after ``timeout_s``. The
            message carries the last observed states so a failure can be
            diagnosed from the test log alone.
    """
    deadline = time.monotonic() + timeout_s
    while True:
        weights_state = weights_gms.get_runtime_state()
        kv_state = kv_cache_gms.get_runtime_state()
        if is_expected(weights_state, kv_state):
            return weights_state, kv_state
        # The deadline is checked only after a poll, so the states are always
        # fetched at least once even with a tiny timeout.
        if time.monotonic() > deadline:
            raise TimeoutError(
                f"{timeout_message} "
                f"(last weights state: {weights_state!r}; "
                f"last kv_cache state: {kv_state!r})"
            )
        time.sleep(poll_interval_s)


def _run_sleep_wake_test(
    request,
    ports: dict,
    make_engine: Callable[[], ManagedProcess],
) -> None:
    """Drive one engine through sleep and wake and check GMS state at each step.

    Two GMS servers (tags ``weights`` and ``kv_cache``) and a frontend are
    started, then the engine produced by ``make_engine`` serves a completion,
    is put to sleep, woken up, and serves a second completion. After every
    step the runtime state, event history, and GPU memory usage are checked.

    Args:
        request: pytest request object, forwarded to every managed process.
        ports: port mapping; only ``ports["frontend"]`` is read here.
        make_engine: zero-argument factory returning the engine process. The
            result is used as a context manager and must provide ``sleep()``
            and ``wake()``.

    Raises:
        TimeoutError: a GMS server did not reach the expected state in time.
        AssertionError: an inference, memory, or event-history check failed.
    """
    # Weights are immutable across sleep/wake, so their event history should
    # always be the original publish: connect once, commit once.
    expected_weights_kinds = ["rw_connected", "committed"]

    with ExitStack() as stack:
        weights_gms = stack.enter_context(
            GMSServerProcess(request, device=0, tag="weights")
        )
        kv_cache_gms = stack.enter_context(
            GMSServerProcess(request, device=0, tag="kv_cache")
        )
        stack.enter_context(
            DynamoFrontendProcess(request, frontend_port=ports["frontend"])
        )

        with make_engine() as engine:
            result = send_completion(ports["frontend"])
            logger.info("Initial inference result: %s", result)
            assert result["choices"]

            # Before sleep, weights must already be published and visible to RO
            # readers while KV cache remains a live RW layout owned by the engine.
            def published_and_live(weights, kv) -> bool:
                return bool(
                    weights.state == ServerState.RO
                    and weights.allocation_count > 0
                    and weights.memory_layout_hash
                    and kv.state == ServerState.RW
                    and kv.allocation_count > 0
                )

            weights_before_sleep, _ = _wait_for_gms_states(
                weights_gms,
                kv_cache_gms,
                published_and_live,
                "initial GMS state did not stabilize",
            )

            def weights_layout_preserved(weights) -> bool:
                # Same allocation count and layout hash as the pre-sleep publish.
                return bool(
                    weights.allocation_count == weights_before_sleep.allocation_count
                    and weights.memory_layout_hash
                    == weights_before_sleep.memory_layout_hash
                )

            mem_before = get_gpu_memory_used()
            logger.info("Memory before sleep: %.0f MB", mem_before / (1 << 20))

            sleep_result = engine.sleep()
            assert sleep_result["status"] == "ok"

            mem_after_sleep = get_gpu_memory_used()
            released_bytes = mem_before - mem_after_sleep
            logger.info("Memory after sleep: %.0f MB", mem_after_sleep / (1 << 20))
            # Equivalent to ``released_bytes > 0``.
            assert mem_after_sleep < mem_before, "Sleep should reduce memory"

            # Sleep preserves the committed weights layout but aborts and clears the
            # mutable KV-cache layout, which is what should release GPU memory.
            def asleep(weights, kv) -> bool:
                return bool(
                    weights.state == ServerState.COMMITTED
                    and weights_layout_preserved(weights)
                    and kv.state == ServerState.EMPTY
                    and kv.allocation_count == 0
                )

            _wait_for_gms_states(
                weights_gms,
                kv_cache_gms,
                asleep,
                "sleep did not drive GMS into the expected state",
            )

            weights_events = weights_gms.get_event_history().events
            assert [event.kind for event in weights_events] == expected_weights_kinds

            # KV cache is different: sleep must abort the old RW layout and clear its
            # server-owned allocations before wake can start a new RW layout.
            kv_events = kv_cache_gms.get_event_history().events
            assert [event.kind for event in kv_events] == [
                "rw_connected",
                "rw_aborted",
                "allocations_cleared",
            ]
            assert kv_events[-1].allocation_count > 0

            wake_result = engine.wake()
            assert wake_result["status"] == "ok"

            mem_after_wake = get_gpu_memory_used()
            reacquired_bytes = mem_after_wake - mem_after_sleep
            logger.info("Memory after wake: %.0f MB", mem_after_wake / (1 << 20))
            assert mem_after_wake > mem_after_sleep, "Wake should reacquire memory"
            assert (
                reacquired_bytes >= released_bytes * MIN_EXPECTED_MEMORY_RETURN_FRACTION
            )

            # Wake reconnects weights as RO to the same committed layout, but KV cache
            # must come back as a fresh RW layout with new allocations.
            def awake(weights, kv) -> bool:
                return bool(
                    weights.state == ServerState.RO
                    and weights_layout_preserved(weights)
                    and kv.state == ServerState.RW
                    and kv.allocation_count > 0
                )

            _wait_for_gms_states(
                weights_gms,
                kv_cache_gms,
                awake,
                "wake did not restore the expected GMS state",
            )

            weights_events_after_wake = weights_gms.get_event_history().events
            assert [
                event.kind for event in weights_events_after_wake
            ] == expected_weights_kinds

            # The wake history should therefore extend the old KV sequence with one
            # new RW connect after the previous layout was fully cleared.
            kv_events_after_wake = kv_cache_gms.get_event_history().events
            assert [event.kind for event in kv_events_after_wake] == [
                "rw_connected",
                "rw_aborted",
                "allocations_cleared",
                "rw_connected",
            ]
            # Index 2 is the "allocations_cleared" event recorded during sleep.
            assert kv_events_after_wake[2].allocation_count > 0

            result = send_completion(ports["frontend"], "Goodbye")
            logger.info("Post-wake inference result: %s", result)
            assert result["choices"]

            logger.info("Memory freed: %.0f MB", released_bytes / (1 << 20))
@pytest.mark.vllm
@pytest.mark.e2e
@pytest.mark.gpu_1
@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME)
@pytest.mark.timeout(300)
def test_gms_basic_sleep_wake_vllm(
    request,
    runtime_services_dynamic_ports,
    gms_ports,
    predownload_models,
):
    """Run the shared sleep/wake scenario with a vLLM engine backed by GMS.

    ``runtime_services_dynamic_ports`` and ``predownload_models`` are not
    referenced in the body; presumably they are requested only for their
    setup side effects (verify against the fixture definitions).
    """

    def launch_vllm_engine():
        # Built lazily: the scenario calls this only after the GMS servers and
        # the frontend are already running.
        return VLLMWithGMSProcess(
            request,
            "engine",
            gms_ports["shadow_system"],
            gms_ports["shadow_kv_event"],
            gms_ports["shadow_nixl"],
            gms_ports["frontend"],
        )

    _run_sleep_wake_test(request, gms_ports, make_engine=launch_vllm_engine)
@pytest.mark.sglang
@pytest.mark.e2e
@pytest.mark.gpu_1
@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME)
@pytest.mark.timeout(300)
def test_gms_basic_sleep_wake_sglang(
    request,
    runtime_services_dynamic_ports,
    gms_ports,
    predownload_models,
):
    """Run the shared sleep/wake scenario with an SGLang engine backed by GMS.

    ``runtime_services_dynamic_ports`` and ``predownload_models`` are not
    referenced in the body; presumably they are requested only for their
    setup side effects (verify against the fixture definitions).
    """

    def launch_sglang_engine():
        # Built lazily: the scenario calls this only after the GMS servers and
        # the frontend are already running.
        return SGLangWithGMSProcess(
            request,
            "engine",
            gms_ports["shadow_system"],
            gms_ports["shadow_sglang"],
            gms_ports["frontend"],
        )

    _run_sleep_wake_test(request, gms_ports, make_engine=launch_sglang_engine)
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
from typing import cast
import pytest
from gpu_memory_service.client.memory_manager import GMSClientMemoryManager
from gpu_memory_service.client.torch.module import (
materialize_module_from_gms,
register_module_tensors,
)
from gpu_memory_service.client.torch.tensor import _tensor_from_pointer
from gpu_memory_service.common.types import RequestedLockType
from tests.gms.harness.gms import GMSServerProcess
# Markers applied to every test collected from this module.
pytestmark = [
    pytest.mark.pre_merge,
    pytest.mark.unit,
    pytest.mark.gpu_1,
]

# Import torch through pytest so the whole module is skipped, rather than
# failing at import time, when torch is not installed.
torch = pytest.importorskip("torch", reason="torch is required")

# Every test below allocates CUDA tensors, so skip the module without a GPU.
if not torch.cuda.is_available():
    pytest.skip(
        "CUDA is required for torch GMS integration tests",
        allow_module_level=True,
    )
class _TinyModule(torch.nn.Module):
    """Minimal CUDA module with three kinds of tensor state.

    It holds a parameter (``linear.weight``), a registered buffer (``scale``),
    and a plain tensor attribute (``extra``) that is neither of those.
    """

    def __init__(self) -> None:
        super().__init__()
        # Parameter: bias-free 8 -> 4 linear projection.
        self.linear = torch.nn.Linear(8, 4, bias=False, device="cuda")
        # Registered buffer: four evenly spaced values from 0.5 to 2.0.
        scale_values = torch.linspace(
            0.5, 2.0, steps=4, device="cuda", dtype=torch.float32
        )
        self.register_buffer("scale", scale_values)
        # Plain attribute tensor holding [1, 2, 3, 4].
        self.extra = torch.arange(1, 5, device="cuda", dtype=torch.float32)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``relu((linear(x) + scale) * extra)``."""
        projected = self.linear(x)
        shifted = projected + self.scale
        return torch.relu(shifted * self.extra)
@pytest.fixture
def running_gms(request):
    """Start a ``weights`` GMS server on device 0 and yield its socket path.

    The server context is exited when the requesting test finishes.
    """
    gms_server = GMSServerProcess(request, device=0, tag="weights")
    with gms_server as started:
        yield started.socket_path
def _make_gms_tensor(
    manager: GMSClientMemoryManager,
    tensor: torch.Tensor,
    *,
    tag: str,
) -> tuple[str, torch.Tensor]:
    """Copy *tensor* into a newly created GMS mapping.

    Args:
        manager: client memory manager used to create the mapping.
        tensor: source tensor; its shape, strides, dtype, and device index are
            reused for the GMS-backed view.
        tag: tag forwarded to ``manager.create_mapping``.

    Returns:
        ``(allocation_id, gms_tensor)`` where ``gms_tensor`` views the new
        mapping and already holds a copy of ``tensor``'s values.
    """
    # The mapping is sized from the whole underlying storage, not just the
    # elements visible through this tensor.
    nbytes = tensor.untyped_storage().nbytes()
    mapped_va = manager.create_mapping(size=nbytes, tag=tag)
    allocation_id = manager.mappings[mapped_va].allocation_id

    # ``device.index`` may be None; device 0 is used in that case.
    device_index = tensor.device.index or 0
    shape = list(tensor.shape)
    strides = list(tensor.stride())
    gms_view = _tensor_from_pointer(
        mapped_va, shape, strides, tensor.dtype, device_index
    )
    gms_view.copy_(tensor)
    return allocation_id, gms_view
def _assert_exact_tensor_equal(expected: torch.Tensor, actual: torch.Tensor) -> None:
torch.testing.assert_close(expected, actual, rtol=0, atol=0)
def test_gms_tensor_matches_plain_torch_ops(running_gms):
    """A tensor read back through GMS must behave exactly like a plain tensor.

    A writer publishes ``baseline`` and commits it; a read-only client then
    maps the same allocation, and several torch operations are compared
    bit-for-bit between the plain tensor and the GMS-backed one.
    """
    socket_path = running_gms
    baseline = torch.arange(64, device="cuda", dtype=torch.float32).reshape(8, 8)
    rhs = torch.arange(32, device="cuda", dtype=torch.float32).reshape(8, 4)

    writer = GMSClientMemoryManager(socket_path, device=0)
    writer.connect(RequestedLockType.RW)
    allocation_id, _writer_tensor = _make_gms_tensor(writer, baseline, tag="weights")
    assert writer.commit()

    reader = GMSClientMemoryManager(socket_path, device=0)
    reader.connect(RequestedLockType.RO)
    # Close the reader even when an assertion fails, so a failing test does not
    # leave a client connection open.
    try:
        va = reader.create_mapping(allocation_id=allocation_id)
        gms_tensor = _tensor_from_pointer(
            va,
            list(baseline.shape),
            list(baseline.stride()),
            baseline.dtype,
            0,
        )

        # Each operation is applied identically to both tensors.
        operations = (
            lambda t: torch.relu((t + 3.0) @ rhs),
            lambda t: t.transpose(0, 1).contiguous(),
            lambda t: t[:, 2:6].sum(dim=1),
            lambda t: (t * 2.0 - 5.0).square(),
        )
        for operation in operations:
            _assert_exact_tensor_equal(operation(baseline), operation(gms_tensor))
    finally:
        reader.close()
def test_live_gms_tensor_survives_unmap_and_remap(running_gms):
    """A live GMS tensor must stay usable across unmap, abort, and remap.

    After the reader unmaps all VAs, aborts, reconnects read-only, and remaps,
    the existing tensor object must keep the same data pointer and still hold
    the committed values.
    """
    socket_path = running_gms
    baseline = torch.arange(64, device="cuda", dtype=torch.float32).reshape(8, 8)

    writer = GMSClientMemoryManager(socket_path, device=0)
    writer.connect(RequestedLockType.RW)
    allocation_id, _ = _make_gms_tensor(writer, baseline, tag="weights")
    assert writer.commit()

    reader = GMSClientMemoryManager(socket_path, device=0)
    reader.connect(RequestedLockType.RO)
    # Close the reader even when an assertion fails, so a failing test does not
    # leave a client connection open.
    try:
        va = reader.create_mapping(allocation_id=allocation_id)
        gms_tensor = _tensor_from_pointer(
            va,
            list(baseline.shape),
            list(baseline.stride()),
            baseline.dtype,
            0,
        )
        pointer_before = gms_tensor.data_ptr()
        # Computed from the plain tensor, so it does not depend on GMS at all.
        expected = torch.relu((baseline + 7.0).square())

        reader.unmap_all_vas()
        reader.abort()
        reader.connect(RequestedLockType.RO)
        reader.remap_all_vas()

        assert gms_tensor.data_ptr() == pointer_before
        _assert_exact_tensor_equal(expected, torch.relu((gms_tensor + 7.0).square()))
    finally:
        reader.close()
def test_materialized_module_from_gms_matches_plain_module_forward(running_gms):
    """A module materialized from GMS must match the plain module exactly.

    The writer replaces the parameter, buffer, and plain tensor attribute of a
    copy of the baseline module with GMS-backed tensors, registers them, and
    commits. A read-only client then materializes a fresh module from GMS, and
    its forward output and tensors are compared bit-for-bit with the baseline.
    """
    socket_path = running_gms
    torch.manual_seed(7)
    baseline = _TinyModule().cuda()
    gms_model = _TinyModule().cuda()
    gms_model.load_state_dict(baseline.state_dict())
    inputs = torch.randn(3, 8, device="cuda", dtype=torch.float32)
    expected = baseline(inputs).detach().clone()

    writer = GMSClientMemoryManager(socket_path, device=0)
    writer.connect(RequestedLockType.RW)
    baseline_weight = cast(torch.Tensor, baseline.linear.weight)
    baseline_scale = cast(torch.Tensor, baseline.scale)
    baseline_extra = cast(torch.Tensor, baseline.extra)

    # Swap each kind of tensor state for a GMS-backed copy.
    _, gms_weight = _make_gms_tensor(writer, baseline_weight, tag="weights")
    gms_model.linear.weight = torch.nn.Parameter(
        gms_weight, requires_grad=baseline_weight.requires_grad
    )
    _, gms_scale = _make_gms_tensor(writer, baseline_scale, tag="weights")
    gms_model._buffers["scale"] = gms_scale
    _, gms_extra = _make_gms_tensor(writer, baseline_extra, tag="weights")
    gms_model.extra = gms_extra

    register_module_tensors(writer, gms_model)
    # The writer-side module must already behave like the baseline.
    _assert_exact_tensor_equal(expected, gms_model(inputs))
    assert writer.commit()

    reader = GMSClientMemoryManager(socket_path, device=0)
    reader.connect(RequestedLockType.RO)
    # Close the reader even when an assertion fails, so a failing test does not
    # leave a client connection open.
    try:
        materialized = _TinyModule().cuda()
        materialize_module_from_gms(reader, materialized, device_index=0)

        _assert_exact_tensor_equal(expected, materialized(inputs))
        _assert_exact_tensor_equal(
            baseline_scale, cast(torch.Tensor, materialized.scale)
        )
        _assert_exact_tensor_equal(
            baseline_extra, cast(torch.Tensor, materialized.extra)
        )
        _assert_exact_tensor_equal(
            baseline_weight,
            cast(torch.Tensor, materialized.linear.weight),
        )
    finally:
        reader.close()
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
from contextlib import contextmanager
from types import SimpleNamespace
import pytest
# Module-level pytest markers: every test collected from this file carries all
# four of these marks.
pytestmark = [
pytest.mark.pre_merge,
pytest.mark.unit,
pytest.mark.gpu_0,
pytest.mark.vllm,
]
class _FakeManager:
def __init__(self, *, is_unmapped: bool = False) -> None:
self.is_unmapped = is_unmapped
self.calls: list[object] = []
def unmap_all_vas(self) -> None:
self.calls.append("unmap_all_vas")
self.is_unmapped = True
def abort(self) -> None:
self.calls.append("abort")
def connect(self, lock_type, timeout_ms=None) -> None:
self.calls.append(("connect", lock_type.value))
self.is_unmapped = False
def reallocate_all_handles(self, *, tag: str) -> None:
self.calls.append(("reallocate_all_handles", tag))
def remap_all_vas(self) -> None:
self.calls.append("remap_all_vas")
self.is_unmapped = False
def test_initialize_from_config_uses_kv_cache_gms_tag(monkeypatch):
    """``initialize_from_config`` must set up KV cache under the ``kv_cache`` tag.

    The worker's collaborators are replaced with recorders, then the test
    checks that the manager is requested in ``rw`` mode for the ``kv_cache``
    socket, that the ``kv_cache`` memory pool is entered on the worker's
    device, and that the config is forwarded to both KV initializers.
    """
    import gpu_memory_service.integrations.vllm.worker as worker_module
    import vllm.distributed.kv_transfer as kv_transfer
    from gpu_memory_service.integrations.vllm.worker import GMSWorker

    manager_requests: list[tuple[object, ...]] = []
    pool_entries: list[tuple[str, str]] = []
    transfer_configs: list[object] = []
    runner_configs: list[object] = []

    @contextmanager
    def fake_use_mem_pool(tag, device):
        pool_entries.append((tag, str(device)))
        yield

    def fake_get_or_create(socket_path, device, mode, *, tag, timeout_ms=None):
        manager_requests.append((socket_path, device, mode.value, tag, timeout_ms))
        return object()

    def fake_get_socket_path(device, tag):
        return f"/tmp/{tag}-{device}.sock"

    def fake_ensure_kv_transfer_initialized(vllm_config, kv_cache_config):
        transfer_configs.append(kv_cache_config)

    def fake_initialize_kv_cache(kv_cache_config):
        runner_configs.append(kv_cache_config)

    monkeypatch.setattr(worker_module, "gms_use_mem_pool", fake_use_mem_pool)
    monkeypatch.setattr(
        worker_module,
        "get_or_create_gms_client_memory_manager",
        fake_get_or_create,
    )
    monkeypatch.setattr(worker_module, "get_socket_path", fake_get_socket_path)
    monkeypatch.setattr(
        kv_transfer,
        "ensure_kv_transfer_initialized",
        fake_ensure_kv_transfer_initialized,
    )

    # Bypass GMSWorker.__init__ and set only the attributes assigned here.
    worker = object.__new__(GMSWorker)
    worker.local_rank = 3
    worker.vllm_config = SimpleNamespace(
        model_config=SimpleNamespace(enable_sleep_mode=True)
    )
    worker.model_runner = SimpleNamespace(
        initialize_kv_cache=fake_initialize_kv_cache
    )

    worker.initialize_from_config("kv-config")

    assert manager_requests == [("/tmp/kv_cache-3.sock", 3, "rw", "kv_cache", None)]
    assert pool_entries == [("kv_cache", "cuda:3")]
    assert transfer_configs == ["kv-config"]
    assert runner_configs == ["kv-config"]
def test_sleep_level_2_unmaps_weights_and_kv_cache(monkeypatch):
    """``sleep(level=2)`` must unmap and then abort both GMS managers."""
    import gpu_memory_service.integrations.vllm.worker as worker_module
    from gpu_memory_service.integrations.vllm.worker import GMSWorker

    weights = _FakeManager()
    kv_cache = _FakeManager()

    def fake_get_manager(tag):
        # Any tag other than "weights" resolves to the KV-cache fake.
        if tag == "weights":
            return weights
        return kv_cache

    def fake_mem_get_info():
        # Fixed 2 GiB / 8 GiB reading so no real CUDA query is made.
        return (2 << 30, 8 << 30)

    monkeypatch.setattr(
        worker_module, "get_gms_client_memory_manager", fake_get_manager
    )
    monkeypatch.setattr(worker_module.torch.cuda, "mem_get_info", fake_mem_get_info)

    # Bypass GMSWorker.__init__; no instance attributes are set for this test.
    worker = object.__new__(GMSWorker)
    worker.sleep(level=2)

    expected_calls = ["unmap_all_vas", "abort"]
    assert weights.calls == expected_calls
    assert kv_cache.calls == expected_calls
def test_wake_up_remaps_weights_and_reallocates_kv_cache(monkeypatch):
    """``wake_up`` must remap weights read-only and rebuild KV cache read-write.

    Weights are expected to reconnect as ``ro`` and remap. KV cache is expected
    to reconnect as ``rw``, reallocate its handles under the ``kv_cache`` tag,
    and remap. With an fp8 cache dtype the FP8 KV scales must be initialized
    once.
    """
    import gpu_memory_service.integrations.vllm.worker as worker_module
    from gpu_memory_service.integrations.vllm.worker import GMSWorker

    weights = _FakeManager(is_unmapped=True)
    kv_cache = _FakeManager(is_unmapped=True)
    fp8_calls: list[str] = []

    def fake_get_manager(tag):
        # Any tag other than "weights" resolves to the KV-cache fake.
        if tag == "weights":
            return weights
        return kv_cache

    def record_fp8_init():
        fp8_calls.append("fp8")

    monkeypatch.setattr(
        worker_module, "get_gms_client_memory_manager", fake_get_manager
    )

    # Bypass GMSWorker.__init__ and set only the attributes assigned here.
    worker = object.__new__(GMSWorker)
    worker.local_rank = 0
    worker.cache_config = SimpleNamespace(cache_dtype="fp8_e4m3")
    worker.model_runner = SimpleNamespace(
        kv_caches={"layer_0": True},
        init_fp8_kv_scales=record_fp8_init,
    )

    worker.wake_up(["weights", "kv_cache"])

    assert weights.calls == [("connect", "ro"), "remap_all_vas"]
    assert kv_cache.calls == [
        ("connect", "rw"),
        ("reallocate_all_handles", "kv_cache"),
        "remap_all_vas",
    ]
    assert fp8_calls == ["fp8"]
def test_maybe_get_memory_pool_context_routes_tags(monkeypatch):
    """``_maybe_get_memory_pool_context`` must route each tag to the right pool.

    - ``"weights"`` yields a usable context manager without touching the GMS
      pool or the parent ``Worker`` implementation.
    - ``"kv_cache"`` returns whatever ``gms_use_mem_pool`` returns for the
      worker's device.
    - any other tag is delegated to the parent ``Worker`` implementation.
    """
    import gpu_memory_service.integrations.vllm.worker as worker_module
    from gpu_memory_service.integrations.vllm.worker import GMSWorker, Worker

    kv_cache_context = object()
    parent_tags: list[str] = []
    pool_requests: list[tuple[str, str]] = []

    def fake_use_mem_pool(tag, device):
        pool_requests.append((tag, str(device)))
        return kv_cache_context

    def fake_super_context(self, tag):
        del self
        parent_tags.append(tag)
        return f"super:{tag}"

    monkeypatch.setattr(worker_module, "gms_use_mem_pool", fake_use_mem_pool)
    monkeypatch.setattr(Worker, "_maybe_get_memory_pool_context", fake_super_context)

    # Bypass GMSWorker.__init__ and set only the attribute assigned here.
    worker = object.__new__(GMSWorker)
    worker.local_rank = 2

    # "weights": the returned object must work in a ``with`` statement, and
    # neither fake may have been called.
    weights_context = worker._maybe_get_memory_pool_context("weights")
    with weights_context:
        pass
    assert pool_requests == []
    assert parent_tags == []

    # "kv_cache": the GMS pool context, bound to this worker's device.
    assert worker._maybe_get_memory_pool_context("kv_cache") is kv_cache_context
    assert pool_requests == [("kv_cache", "cuda:2")]
    assert parent_tags == []

    # Anything else: falls through to the parent implementation.
    assert worker._maybe_get_memory_pool_context("other") == "super:other"
    assert parent_tags == ["other"]
......@@ -120,10 +120,35 @@ STUB_MODULES = [
"botocore.exceptions",
"pynvml",
"gpu_memory_service",
"gpu_memory_service.client",
"gpu_memory_service.client.memory_manager",
"gpu_memory_service.client.rpc",
"gpu_memory_service.client.session",
"gpu_memory_service.client.torch",
"gpu_memory_service.client.torch.allocator",
"gpu_memory_service.client.torch.module",
"gpu_memory_service.client.torch.tensor",
"gpu_memory_service.common",
"gpu_memory_service.common.cuda_utils",
"gpu_memory_service.common.protocol",
"gpu_memory_service.common.protocol.messages",
"gpu_memory_service.common.protocol.wire",
"gpu_memory_service.common.types",
"gpu_memory_service.common.utils",
"gpu_memory_service.failover_lock",
"gpu_memory_service.failover_lock.flock",
"gpu_memory_service.integrations",
"gpu_memory_service.integrations.common",
"gpu_memory_service.integrations.common.utils",
"gpu_memory_service.integrations.sglang",
"gpu_memory_service.integrations.sglang.memory_saver",
"gpu_memory_service.integrations.vllm",
"gpu_memory_service.integrations.vllm.worker",
"gpu_memory_service.server",
"gpu_memory_service.server.allocations",
"gpu_memory_service.server.gms",
"gpu_memory_service.server.rpc",
"gpu_memory_service.server.session",
"prometheus_client",
"prometheus_client.parser",
"sklearn",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment