Unverified Commit 39a6a240 authored by Schwinn Saereesitthipitak's avatar Schwinn Saereesitthipitak Committed by GitHub
Browse files

refactor: simplify GPU Memory Service integrations and module boundaries (#7875)

parent 02666f04
......@@ -50,7 +50,7 @@ jobs:
dynamo-status-check:
runs-on: ubuntu-latest
needs: [changed-files, build, rust-checks, mypy, test-parallel, test-sequential]
needs: [changed-files, build, rust-checks, mypy, test-parallel, test-sequential, test-generic-gpu]
if: always()
steps:
- name: "Check all dependent jobs"
......@@ -307,3 +307,35 @@ jobs:
hf_token: ${{ secrets.HF_TOKEN }}
parallel_mode: 'none'
dind_as_sidecar: 'false'
test-generic-gpu:
needs: [changed-files, build, mypy]
if: needs.changed-files.outputs.core == 'true'
runs-on: prod-tester-amd-gpu-v1
name: Pytest (GPU)
timeout-minutes: 30
env:
IMAGE_TAG: ${{ secrets.AWS_ACCOUNT_ID }}.dkr.ecr.${{ secrets.AWS_DEFAULT_REGION }}.amazonaws.com/ai-dynamo/dynamo:${{ needs.build.outputs.test_tag_suffix }}
steps:
- name: Checkout repository
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4.3.0
- name: Docker Login
uses: ./.github/actions/docker-login
with:
aws_default_region: ${{ secrets.AWS_DEFAULT_REGION }}
aws_account_id: ${{ secrets.AWS_ACCOUNT_ID }}
- name: Pull test image
run: |
source ./.github/scripts/retry_docker.sh
retry_pull ${{ env.IMAGE_TAG }}
- name: Run pytest (gpu)
uses: ./.github/actions/pytest
with:
image_tag: ${{ env.IMAGE_TAG }}
pytest_marks: "pre_merge and none and gpu_1"
framework: dynamo
test_type: "pre_merge_gpu"
platform_arch: amd64
hf_token: ${{ secrets.HF_TOKEN }}
parallel_mode: 'none'
dind_as_sidecar: 'true'
......@@ -34,7 +34,6 @@ from dynamo.sglang.publisher import DynamoSglangPublisher
class SGLangEngineQuiesceController:
def __init__(self, engine: sgl.Engine):
self._engine = engine
self._quiesced_tags: Optional[list[str]] = None
self._is_quiesced = False
@property
......@@ -55,7 +54,6 @@ class SGLangEngineQuiesceController:
ReleaseMemoryOccupationReqInput(tags=tags),
None,
)
self._quiesced_tags = None if tags is None else list(tags)
self._is_quiesced = True
return True
......@@ -68,9 +66,8 @@ class SGLangEngineQuiesceController:
ResumeMemoryOccupationReqInput,
)
request_tags = self._quiesced_tags if tags is None else list(tags)
await self._engine.tokenizer_manager.resume_memory_occupation(
ResumeMemoryOccupationReqInput(tags=request_tags),
ResumeMemoryOccupationReqInput(tags=tags),
None,
)
await self._engine.tokenizer_manager.continue_generation(
......@@ -79,7 +76,6 @@ class SGLangEngineQuiesceController:
return True
def mark_resumed(self) -> None:
self._quiesced_tags = None
self._is_quiesced = False
......
......@@ -6,11 +6,19 @@ from __future__ import annotations
from contextlib import contextmanager
import pytest
from gpu_memory_service.common.types import RequestedLockType
from gpu_memory_service.integrations.sglang.memory_saver import GMSMemorySaverImpl
pytest.importorskip("gpu_memory_service", reason="gpu_memory_service is required")
torch = pytest.importorskip("torch", reason="torch is required")
import gpu_memory_service.integrations.sglang.memory_saver as gms_memory_saver # noqa: E402
from gpu_memory_service.common.locks import ( # noqa: E402
GrantedLockType,
RequestedLockType,
)
from gpu_memory_service.integrations.sglang.memory_saver import ( # noqa: E402
GMSMemorySaverImpl,
)
pytestmark = [
pytest.mark.pre_merge,
pytest.mark.unit,
......@@ -20,8 +28,14 @@ pytestmark = [
class _FakeManager:
def __init__(self, *, is_unmapped: bool = False):
def __init__(
self,
*,
is_unmapped: bool = False,
granted_lock_type: GrantedLockType | None = None,
):
self.is_unmapped = is_unmapped
self.granted_lock_type = granted_lock_type
self.calls: list[object] = []
def unmap_all_vas(self) -> None:
......@@ -30,9 +44,11 @@ class _FakeManager:
def abort(self) -> None:
self.calls.append("abort")
self.granted_lock_type = None
def connect(self, lock_type) -> None:
self.calls.append(("connect", lock_type))
self.granted_lock_type = GrantedLockType(lock_type.value)
self.is_unmapped = False
def reallocate_all_handles(self, *, tag: str) -> None:
......@@ -43,28 +59,21 @@ class _FakeManager:
self.is_unmapped = False
class _FakeTorchImpl:
def __init__(self):
self.region_calls: list[tuple[str, bool]] = []
self.pause_calls: list[object] = []
self.resume_calls: list[object] = []
@contextmanager
def region(self, tag: str, enable_cpu_backup: bool):
self.region_calls.append((tag, enable_cpu_backup))
yield
def pause(self, tag=None) -> None:
self.pause_calls.append(tag)
def resume(self, tag=None) -> None:
self.resume_calls.append(tag)
@pytest.fixture
def build_impl(monkeypatch, tmp_path):
monkeypatch.setattr(
gms_memory_saver,
"get_socket_path",
lambda device_index, tag: str(tmp_path / f"gms-test-{device_index}-{tag}.sock"),
)
def test_region_routes_weights_and_kv_cache_to_gms(monkeypatch):
weights = _FakeManager()
kv_cache = _FakeManager()
torch_impl = _FakeTorchImpl()
def build(
*,
weights_lock: GrantedLockType = GrantedLockType.RW,
kv_cache_lock: GrantedLockType = GrantedLockType.RW,
):
weights = _FakeManager(granted_lock_type=weights_lock)
kv_cache = _FakeManager(granted_lock_type=kv_cache_lock)
pool_calls: list[tuple[str, torch.device]] = []
@contextmanager
......@@ -73,43 +82,58 @@ def test_region_routes_weights_and_kv_cache_to_gms(monkeypatch):
yield
monkeypatch.setattr(
GMSMemorySaverImpl,
"_init_allocators",
lambda self: (weights, kv_cache, "write"),
gms_memory_saver,
"get_or_create_gms_client_memory_manager",
lambda socket_path, device, mode, tag: {
"weights": weights,
"kv_cache": kv_cache,
}[tag],
)
monkeypatch.setattr(
"gpu_memory_service.integrations.sglang.memory_saver.gms_use_mem_pool",
fake_use_mem_pool,
monkeypatch.setattr(gms_memory_saver, "gms_use_mem_pool", fake_use_mem_pool)
return (
GMSMemorySaverImpl(device_index=0, mode=None),
weights,
kv_cache,
pool_calls,
)
impl = GMSMemorySaverImpl(torch_impl=torch_impl, device_index=2, mode=None)
return build
@pytest.mark.parametrize(
("tag", "weights_lock", "expected_pool_calls"),
[
("weights", GrantedLockType.RW, [("weights", torch.device("cuda", 0))]),
("weights", GrantedLockType.RO, []),
("kv_cache", GrantedLockType.RW, [("kv_cache", torch.device("cuda", 0))]),
("cuda_graph", GrantedLockType.RW, []),
],
)
def test_region_uses_gms_pool_only_for_rw_managed_tags(
build_impl,
tag,
weights_lock,
expected_pool_calls,
):
impl, _, _, pool_calls = build_impl(
weights_lock=weights_lock,
kv_cache_lock=GrantedLockType.RW,
)
with impl.region("weights", enable_cpu_backup=False):
with impl.region(tag, enable_cpu_backup=False):
pass
with impl.region("kv_cache", enable_cpu_backup=False):
pass
with impl.region("cuda_graph", enable_cpu_backup=True):
pass
assert pool_calls == [
("weights", torch.device("cuda", 2)),
("kv_cache", torch.device("cuda", 2)),
]
assert torch_impl.region_calls == [("cuda_graph", True)]
assert pool_calls == expected_pool_calls
def test_pause_resume_routes_kv_cache_to_gms(monkeypatch):
weights = _FakeManager()
kv_cache = _FakeManager()
torch_impl = _FakeTorchImpl()
monkeypatch.setattr(
GMSMemorySaverImpl,
"_init_allocators",
lambda self: (weights, kv_cache, "read"),
def test_pause_resume_routes_only_managed_tags(build_impl):
impl, weights, kv_cache, _ = build_impl(
weights_lock=GrantedLockType.RO,
kv_cache_lock=GrantedLockType.RW,
)
impl = GMSMemorySaverImpl(torch_impl=torch_impl, device_index=0, mode=None)
impl.pause("model_weights")
impl.resume("anything_else")
impl.pause()
impl.resume()
......@@ -127,62 +151,13 @@ def test_pause_resume_routes_kv_cache_to_gms(monkeypatch):
("reallocate_all_handles", "kv_cache"),
"remap_all_vas",
]
assert torch_impl.pause_calls == [None]
assert torch_impl.resume_calls == [None]
def test_region_treats_model_weights_as_weights(monkeypatch):
weights = _FakeManager()
kv_cache = _FakeManager()
torch_impl = _FakeTorchImpl()
pool_calls: list[tuple[str, torch.device]] = []
@contextmanager
def fake_use_mem_pool(tag: str, device: torch.device):
pool_calls.append((tag, device))
yield
monkeypatch.setattr(
GMSMemorySaverImpl,
"_init_allocators",
lambda self: (weights, kv_cache, "write"),
)
monkeypatch.setattr(
"gpu_memory_service.integrations.sglang.memory_saver.gms_use_mem_pool",
fake_use_mem_pool,
)
impl = GMSMemorySaverImpl(torch_impl=torch_impl, device_index=1, mode=None)
@pytest.mark.parametrize("tag", ["weights", "kv_cache"])
def test_region_requires_rw_allocator(build_impl, tag):
impl, _, _, _ = build_impl()
impl.allocators[tag].abort()
with impl.region("model_weights", enable_cpu_backup=False):
with pytest.raises(RuntimeError, match=rf"requires '{tag}' to be RW"):
with impl.region(tag, enable_cpu_backup=False):
pass
assert pool_calls == [("weights", torch.device("cuda", 1))]
assert torch_impl.region_calls == []
def test_pause_resume_model_weights_only_routes_weights(monkeypatch):
weights = _FakeManager()
kv_cache = _FakeManager()
torch_impl = _FakeTorchImpl()
monkeypatch.setattr(
GMSMemorySaverImpl,
"_init_allocators",
lambda self: (weights, kv_cache, "read"),
)
impl = GMSMemorySaverImpl(torch_impl=torch_impl, device_index=0, mode=None)
impl.pause("model_weights")
impl.resume("model_weights")
assert weights.calls == [
"unmap_all_vas",
"abort",
("connect", RequestedLockType.RO),
"remap_all_vas",
]
assert kv_cache.calls == []
assert torch_impl.pause_calls == []
assert torch_impl.resume_calls == []
......@@ -45,7 +45,8 @@ class _TestWorkerHandler(BaseWorkerHandler):
yield {}
def _make_handler() -> _TestWorkerHandler:
@pytest.fixture
def handler():
handler = _TestWorkerHandler.__new__(_TestWorkerHandler)
handler.engine = SimpleNamespace(
tokenizer_manager=SimpleNamespace(
......@@ -65,22 +66,7 @@ def _make_handler() -> _TestWorkerHandler:
@pytest.mark.asyncio
async def test_resume_before_release_is_noop():
handler = _make_handler()
result = await handler.resume_memory_occupation({})
assert result["status"] == "ok"
assert result["message"] == "Memory already resumed"
handler.engine.tokenizer_manager.resume_memory_occupation.assert_not_awaited()
handler.engine.tokenizer_manager.continue_generation.assert_not_awaited()
handler.generate_endpoint.register_endpoint_instance.assert_not_awaited()
@pytest.mark.asyncio
async def test_release_and_resume_are_idempotent():
handler = _make_handler()
async def test_release_and_resume_are_idempotent(handler):
first_release = await handler.release_memory_occupation({})
second_release = await handler.release_memory_occupation({})
......@@ -113,11 +99,9 @@ async def test_release_and_resume_are_idempotent():
@pytest.mark.asyncio
async def test_release_and_resume_use_explicit_request_tags():
handler = _make_handler()
await handler.release_memory_occupation({"tags": ["weights"]})
resume_result = await handler.resume_memory_occupation({"tags": ["weights"]})
async def test_memory_occupation_handlers_forward_tags_exactly(handler):
await handler.release_memory_occupation({"tags": []})
resume_result = await handler.resume_memory_occupation({"tags": []})
assert resume_result["status"] == "ok"
release_req = (
......@@ -126,74 +110,59 @@ async def test_release_and_resume_use_explicit_request_tags():
resume_req = (
handler.engine.tokenizer_manager.resume_memory_occupation.await_args.args[0]
)
assert release_req.tags == ["weights"]
assert resume_req.tags == ["weights"]
handler.engine.tokenizer_manager.continue_generation.assert_awaited_once()
handler.generate_endpoint.register_endpoint_instance.assert_awaited_once()
assert release_req.tags == []
assert resume_req.tags == []
@pytest.mark.asyncio
async def test_resume_reuses_release_tags_when_request_omits_them():
handler = _make_handler()
handler.engine.tokenizer_manager.pause_generation.reset_mock()
handler.engine.tokenizer_manager.release_memory_occupation.reset_mock()
handler.engine.tokenizer_manager.resume_memory_occupation.reset_mock()
handler.engine.tokenizer_manager.continue_generation.reset_mock()
handler.generate_endpoint.unregister_endpoint_instance.reset_mock()
handler.generate_endpoint.register_endpoint_instance.reset_mock()
await handler.release_memory_occupation({"tags": ["weights"]})
resume_result = await handler.resume_memory_occupation({})
assert resume_result["status"] == "ok"
release_req = (
handler.engine.tokenizer_manager.release_memory_occupation.await_args.args[0]
)
resume_req = (
handler.engine.tokenizer_manager.resume_memory_occupation.await_args.args[0]
)
assert resume_req.tags == ["weights"]
assert release_req.tags == ["weights"]
assert resume_req.tags is None
handler.engine.tokenizer_manager.continue_generation.assert_awaited_once()
handler.generate_endpoint.register_endpoint_instance.assert_awaited_once()
@pytest.mark.asyncio
async def test_resume_with_no_sleeping_state_is_noop():
handler = _make_handler()
result = await handler.resume_memory_occupation({})
assert result["status"] == "ok"
assert result["message"] == "Memory already resumed"
handler.engine.tokenizer_manager.resume_memory_occupation.assert_not_awaited()
handler.engine.tokenizer_manager.continue_generation.assert_not_awaited()
handler.generate_endpoint.register_endpoint_instance.assert_not_awaited()
@pytest.mark.asyncio
async def test_release_returns_error_when_worker_has_no_tokenizer_manager():
handler = _make_handler()
handler.engine = None
handler._quiesce_controller = None
result = await handler.release_memory_occupation({})
assert result == {
"status": "error",
"message": "memory control not supported on this worker",
}
handler.generate_endpoint.unregister_endpoint_instance.assert_not_awaited()
@pytest.mark.asyncio
async def test_resume_returns_error_when_worker_has_no_tokenizer_manager():
handler = _make_handler()
@pytest.mark.parametrize(
("method_name", "endpoint_method"),
[
("release_memory_occupation", "unregister_endpoint_instance"),
("resume_memory_occupation", "register_endpoint_instance"),
],
)
async def test_memory_control_returns_error_without_quiesce_controller(
handler,
method_name,
endpoint_method,
):
handler.engine = None
handler._quiesce_controller = None
result = await handler.resume_memory_occupation({})
result = await getattr(handler, method_name)({})
assert result == {
"status": "error",
"message": "memory control not supported on this worker",
}
handler.generate_endpoint.register_endpoint_instance.assert_not_awaited()
getattr(handler.generate_endpoint, endpoint_method).assert_not_awaited()
@pytest.mark.asyncio
async def test_resume_keeps_quiesced_state_when_register_fails():
handler = _make_handler()
async def test_resume_keeps_quiesced_state_when_register_fails(handler):
await handler.release_memory_occupation({})
handler.generate_endpoint.register_endpoint_instance = AsyncMock(
side_effect=RuntimeError("discovery write timeout")
......
......@@ -34,7 +34,7 @@ dynamo:
nixl_libfabric_ref: v2.3.0
enable_kvbm: "true"
enable_media_ffmpeg: "false"
enable_gpu_memory_service: "false"
enable_gpu_memory_service: "true"
ffmpeg_version: "7.1"
efa_version: 1.45.1
......
......@@ -565,9 +565,9 @@ python -m dynamo.sglang \
```
The integration patches `torch_memory_saver` to route both weight and KV-cache operations through GMS:
- Weights (`"weights"` / `"model_weights"` tags) use the `weights` GMS tag
- Weights (`"weights"`) use the `weights` GMS tag
- KV cache (`"kv_cache"`) uses a separate RW-only `kv_cache` GMS tag
- Other tags still use the default torch mempool implementation
- Other tags are not supported in GMS mode
- The `--enable-memory-saver` flag is required to activate the memory saver pathway
### Shadow Engine Failover (Sleep / Wake)
......
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""GPU Memory Service - out-of-process GPU memory manager.
The GPU Memory Service decouples ownership of GPU memory from the processes
that use it, enabling zero-copy sharing and data survival across process crashes.
Package structure:
- common/: Shared types and protocol (used by both server and client)
- server/: Allocation server daemon (no CUDA context required)
- client/: Client library for memory management
- client/torch/: PyTorch integration (allocator, tensor, module, extensions)
Primary client API:
from gpu_memory_service import (
GMSClientMemoryManager,
get_or_create_gms_client_memory_manager,
get_gms_client_memory_manager,
)
Server API:
from gpu_memory_service.server import GMSRPCServer
"""
# Primary client exports
from gpu_memory_service.client.memory_manager import (
GMSClientMemoryManager,
StaleMemoryLayoutError,
)
# PyTorch integration (GMS client memory manager)
from gpu_memory_service.client.torch.allocator import (
get_gms_client_memory_manager,
get_gms_client_memory_managers,
get_or_create_gms_client_memory_manager,
gms_use_mem_pool,
)
__all__ = [
# Client
"GMSClientMemoryManager",
"StaleMemoryLayoutError",
# GMS client memory manager
"get_or_create_gms_client_memory_manager",
"get_gms_client_memory_manager",
"get_gms_client_memory_managers",
"gms_use_mem_pool",
]
......@@ -15,7 +15,7 @@ import asyncio
import logging
import uvloop
from gpu_memory_service.server import GMSRPCServer
from gpu_memory_service.server.rpc import GMSRPCServer
from .args import parse_args
......
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""GPU Memory Service client library.
This module provides the client-side components for interacting with the
GPU Memory Service:
- GMSClientMemoryManager: Manages local VA mappings of remote GPU memory
For PyTorch integration (MemPool, tensor utilities), see gpu_memory_service.client.torch.
"""
from gpu_memory_service.client.memory_manager import (
GMSClientMemoryManager,
StaleMemoryLayoutError,
)
__all__ = [
"GMSClientMemoryManager",
"StaleMemoryLayoutError",
]
......@@ -49,8 +49,8 @@ from gpu_memory_service.common.cuda_utils import (
cumem_set_access,
cumem_unmap,
)
from gpu_memory_service.common.locks import GrantedLockType, RequestedLockType
from gpu_memory_service.common.protocol.messages import GetAllocationResponse
from gpu_memory_service.common.types import GrantedLockType, RequestedLockType
logger = logging.getLogger(__name__)
......
......@@ -14,13 +14,13 @@ import os
import socket
from typing import Optional, Tuple, Type, TypeVar
from gpu_memory_service.common.locks import RequestedLockType
from gpu_memory_service.common.protocol.messages import (
ErrorResponse,
HandshakeRequest,
HandshakeResponse,
)
from gpu_memory_service.common.protocol.wire import recv_message_sync, send_message_sync
from gpu_memory_service.common.types import RequestedLockType
T = TypeVar("T")
......
......@@ -9,6 +9,7 @@ import logging
from typing import List, Optional, Tuple
from gpu_memory_service.client.rpc import _GMSRPCTransport
from gpu_memory_service.common.locks import GrantedLockType, RequestedLockType
from gpu_memory_service.common.protocol.messages import (
AllocateRequest,
AllocateResponse,
......@@ -38,7 +39,6 @@ from gpu_memory_service.common.protocol.messages import (
MetadataPutRequest,
MetadataPutResponse,
)
from gpu_memory_service.common.types import GrantedLockType, RequestedLockType
logger = logging.getLogger(__name__)
......
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""PyTorch integration for GPU Memory Service.
This module provides PyTorch-specific functionality:
- Memory manager singleton management
- Tensor utilities (metadata, registration, materialization)
- C++ extension for CUDAPluggableAllocator
"""
from gpu_memory_service.client.torch.allocator import (
get_gms_client_memory_manager,
get_gms_client_memory_managers,
get_or_create_gms_client_memory_manager,
gms_use_mem_pool,
)
from gpu_memory_service.client.torch.module import (
materialize_module_from_gms,
register_module_tensors,
)
__all__ = [
# GMS client memory manager
"get_or_create_gms_client_memory_manager",
"get_gms_client_memory_manager",
"get_gms_client_memory_managers",
"gms_use_mem_pool",
# Tensor operations (public API)
"register_module_tensors",
"materialize_module_from_gms",
]
......@@ -11,7 +11,7 @@ from contextvars import ContextVar
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Iterator, Optional
from gpu_memory_service.common.types import GrantedLockType, RequestedLockType
from gpu_memory_service.common.locks import GrantedLockType, RequestedLockType
if TYPE_CHECKING:
import torch
......
......@@ -8,10 +8,22 @@ from __future__ import annotations
import atexit
import os
from cuda.bindings import driver as cuda
from gpu_memory_service.common.types import GrantedLockType
from gpu_memory_service.common.locks import GrantedLockType
from gpu_memory_service.common.utils import fail
try:
    from cuda.bindings import driver as cuda
except ImportError as exc:
    # Keep import-time collection working in CPU-only environments and let the
    # first real CUDA call fail with a targeted message instead.
    class _MissingCuda:
        """Placeholder bound to ``cuda`` when cuda-python cannot be imported.

        Any ordinary attribute access raises ``RuntimeError`` with the
        original ``ImportError`` chained as its cause. Dunder lookups raise
        ``AttributeError`` instead, so protocol probes (``hasattr``,
        ``copy``, ``pickle``, ``inspect.unwrap``) see a normal
        missing-attribute answer rather than crashing.
        """

        # ``exc`` is unbound as soon as the ``except`` block exits, so keep
        # the failure on the class to chain from later.
        _import_error = exc

        def __getattr__(self, name):
            # Only reached for names that normal lookup did not find.
            if name.startswith("__") and name.endswith("__"):
                raise AttributeError(name)
            raise RuntimeError(
                "cuda-python is required for GPU Memory Service CUDA operations"
            ) from self._import_error

    cuda = _MissingCuda()
_primary_contexts: dict[int, object] = {}
_primary_context_release_registered = False
......
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from enum import Enum
class RequestedLockType(str, Enum):
    """Lock type requested by client."""

    # Members mix in ``str``, so each one compares equal to its string value.
    RW = "rw"
    RO = "ro"
    # NOTE(review): presumably "grant RW when possible, otherwise RO"; the
    # grant logic is not visible here - confirm against the server.
    RW_OR_RO = "rw_or_ro"
class GrantedLockType(str, Enum):
    """Lock type actually granted by server."""

    # Only two members exist: there is no "rw_or_ro" value here, so
    # ``GrantedLockType("rw_or_ro")`` raises ValueError.
    RW = "rw"
    RO = "ro"
......@@ -3,25 +3,10 @@
"""Message types for GPU Memory Service RPC protocol."""
from enum import Enum
from typing import List, Optional, Union
import msgspec
class RequestedLockType(str, Enum):
"""Lock type requested by client."""
RW = "rw"
RO = "ro"
RW_OR_RO = "rw_or_ro"
class GrantedLockType(str, Enum):
"""Lock type actually granted by server."""
RW = "rw"
RO = "ro"
from gpu_memory_service.common.locks import GrantedLockType, RequestedLockType
class HandshakeRequest(msgspec.Struct, tag="handshake_request"):
......
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Shared types for GPU Memory Service."""
from dataclasses import dataclass
from enum import Enum, auto
from gpu_memory_service.common.protocol.messages import (
AllocateRequest,
CommitRequest,
ExportAllocationRequest,
FreeAllocationRequest,
GetAllocationRequest,
GetAllocationStateRequest,
GetLockStateRequest,
GetStateHashRequest,
GrantedLockType,
ListAllocationsRequest,
MetadataDeleteRequest,
MetadataGetRequest,
MetadataListRequest,
MetadataPutRequest,
RequestedLockType,
)
# Re-export lock types for convenience
__all__ = [
"GrantedLockType",
"RequestedLockType",
"ServerState",
"StateEvent",
"StateSnapshot",
"derive_state",
"RW_REQUIRED",
"RO_ALLOWED",
"RW_ALLOWED",
]
class ServerState(str, Enum):
    """Server state - derived from actual connections."""

    # The per-member notes describe how derive_state() in this module picks
    # each value; it checks RW first, then RO, then COMMITTED, then EMPTY.
    EMPTY = "EMPTY"  # no RW connection, no RO connections, not committed
    RW = "RW"  # an RW connection is present
    COMMITTED = "COMMITTED"  # committed, with no RW or RO connections
    RO = "RO"  # at least one RO connection and no RW connection
class StateEvent(Enum):
    """Events that trigger state transitions."""

    # Values come from auto(); members are meant to be compared by identity.
    # NOTE(review): the transition logic that consumes these events is not
    # visible in this module - confirm the exact semantics there.
    RW_CONNECT = auto()
    RW_COMMIT = auto()
    RW_ABORT = auto()
    RO_CONNECT = auto()
    RO_DISCONNECT = auto()
@dataclass
class StateSnapshot:
    """Current server state snapshot."""

    state: ServerState
    has_rw: bool
    ro_count: int
    # NOTE(review): presumably the number of writers queued for the RW lock;
    # it is not read in the visible part of this module - confirm.
    waiting_writers: int
    committed: bool

    @property
    def is_ready(self) -> bool:
        """Ready = committed and no RW connection."""
        return self.committed and not self.has_rw
def derive_state(has_rw: bool, ro_count: int, committed: bool) -> ServerState:
    """Collapse connection bookkeeping into a single ``ServerState``.

    Precedence, highest first: an RW connection, then any RO connections,
    then the committed flag. With none of those the state is EMPTY.
    """
    if has_rw:
        derived = ServerState.RW
    elif ro_count > 0:
        derived = ServerState.RO
    elif committed:
        derived = ServerState.COMMITTED
    else:
        derived = ServerState.EMPTY
    return derived
# Permission sets: which message types require which connection mode
# NOTE(review): the code that enforces these sets is not visible here; the
# descriptions below are based on the set names and their members only.

# Request types listed as needing an RW connection.
RW_REQUIRED: frozenset[type] = frozenset(
{
AllocateRequest,
FreeAllocationRequest,
MetadataPutRequest,
MetadataDeleteRequest,
CommitRequest,
}
)
# Request types listed as permitted on an RO connection.
RO_ALLOWED: frozenset[type] = frozenset(
{
ExportAllocationRequest,
GetAllocationRequest,
ListAllocationsRequest,
MetadataGetRequest,
MetadataListRequest,
GetLockStateRequest,
GetAllocationStateRequest,
GetStateHashRequest,
}
)
# Union of the two sets above: everything in RW_REQUIRED plus everything in
# RO_ALLOWED.
RW_ALLOWED: frozenset[type] = RW_REQUIRED | RO_ALLOWED
......@@ -5,11 +5,13 @@
from gpu_memory_service.integrations.common.patches import patch_empty_cache
from gpu_memory_service.integrations.common.utils import (
GMS_TAGS,
finalize_gms_write,
setup_meta_tensor_workaround,
)
__all__ = [
"GMS_TAGS",
"patch_empty_cache",
"setup_meta_tensor_workaround",
"finalize_gms_write",
......
......@@ -10,11 +10,14 @@ from dataclasses import replace
from typing import TYPE_CHECKING
import torch
from gpu_memory_service.client.torch.module import register_module_tensors
from gpu_memory_service.common.locks import RequestedLockType
if TYPE_CHECKING:
from gpu_memory_service.client.memory_manager import GMSClientMemoryManager
logger = logging.getLogger(__name__)
GMS_TAGS = ("weights", "kv_cache")
def get_gms_lock_mode(extra_config: dict):
......@@ -22,8 +25,6 @@ def get_gms_lock_mode(extra_config: dict):
Returns RO if gms_read_only=True, otherwise RW_OR_RO (default).
"""
from gpu_memory_service.common.types import RequestedLockType
if extra_config.get("gms_read_only", False):
logger.info("[GMS] gms_read_only=True, forcing RO mode")
return RequestedLockType.RO
......@@ -68,9 +69,6 @@ def finalize_gms_write(
Returns:
Total bytes committed.
"""
from gpu_memory_service.client.torch.module import register_module_tensors
from gpu_memory_service.common.types import RequestedLockType
register_module_tensors(allocator, model)
total_bytes = allocator.total_bytes
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment