"...ssh:/git@developer.sourcefind.cn:2222/OpenDAS/dynamo.git" did not exist on "dea0b201d6679a013f1238839591b0806130c29c"
Unverified Commit 39a6a240 authored by Schwinn Saereesitthipitak's avatar Schwinn Saereesitthipitak Committed by GitHub
Browse files

refactor: simplify GPU Memory Service integrations and module boundaries (#7875)

parent 02666f04
...@@ -50,7 +50,7 @@ jobs: ...@@ -50,7 +50,7 @@ jobs:
dynamo-status-check: dynamo-status-check:
runs-on: ubuntu-latest runs-on: ubuntu-latest
needs: [changed-files, build, rust-checks, mypy, test-parallel, test-sequential] needs: [changed-files, build, rust-checks, mypy, test-parallel, test-sequential, test-generic-gpu]
if: always() if: always()
steps: steps:
- name: "Check all dependent jobs" - name: "Check all dependent jobs"
...@@ -307,3 +307,35 @@ jobs: ...@@ -307,3 +307,35 @@ jobs:
hf_token: ${{ secrets.HF_TOKEN }} hf_token: ${{ secrets.HF_TOKEN }}
parallel_mode: 'none' parallel_mode: 'none'
dind_as_sidecar: 'false' dind_as_sidecar: 'false'
test-generic-gpu:
needs: [changed-files, build, mypy]
if: needs.changed-files.outputs.core == 'true'
runs-on: prod-tester-amd-gpu-v1
name: Pytest (GPU)
timeout-minutes: 30
env:
IMAGE_TAG: ${{ secrets.AWS_ACCOUNT_ID }}.dkr.ecr.${{ secrets.AWS_DEFAULT_REGION }}.amazonaws.com/ai-dynamo/dynamo:${{ needs.build.outputs.test_tag_suffix }}
steps:
- name: Checkout repository
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4.3.0
- name: Docker Login
uses: ./.github/actions/docker-login
with:
aws_default_region: ${{ secrets.AWS_DEFAULT_REGION }}
aws_account_id: ${{ secrets.AWS_ACCOUNT_ID }}
- name: Pull test image
run: |
source ./.github/scripts/retry_docker.sh
retry_pull ${{ env.IMAGE_TAG }}
- name: Run pytest (gpu)
uses: ./.github/actions/pytest
with:
image_tag: ${{ env.IMAGE_TAG }}
pytest_marks: "pre_merge and none and gpu_1"
framework: dynamo
test_type: "pre_merge_gpu"
platform_arch: amd64
hf_token: ${{ secrets.HF_TOKEN }}
parallel_mode: 'none'
dind_as_sidecar: 'true'
...@@ -34,7 +34,6 @@ from dynamo.sglang.publisher import DynamoSglangPublisher ...@@ -34,7 +34,6 @@ from dynamo.sglang.publisher import DynamoSglangPublisher
class SGLangEngineQuiesceController: class SGLangEngineQuiesceController:
def __init__(self, engine: sgl.Engine): def __init__(self, engine: sgl.Engine):
self._engine = engine self._engine = engine
self._quiesced_tags: Optional[list[str]] = None
self._is_quiesced = False self._is_quiesced = False
@property @property
...@@ -55,7 +54,6 @@ class SGLangEngineQuiesceController: ...@@ -55,7 +54,6 @@ class SGLangEngineQuiesceController:
ReleaseMemoryOccupationReqInput(tags=tags), ReleaseMemoryOccupationReqInput(tags=tags),
None, None,
) )
self._quiesced_tags = None if tags is None else list(tags)
self._is_quiesced = True self._is_quiesced = True
return True return True
...@@ -68,9 +66,8 @@ class SGLangEngineQuiesceController: ...@@ -68,9 +66,8 @@ class SGLangEngineQuiesceController:
ResumeMemoryOccupationReqInput, ResumeMemoryOccupationReqInput,
) )
request_tags = self._quiesced_tags if tags is None else list(tags)
await self._engine.tokenizer_manager.resume_memory_occupation( await self._engine.tokenizer_manager.resume_memory_occupation(
ResumeMemoryOccupationReqInput(tags=request_tags), ResumeMemoryOccupationReqInput(tags=tags),
None, None,
) )
await self._engine.tokenizer_manager.continue_generation( await self._engine.tokenizer_manager.continue_generation(
...@@ -79,7 +76,6 @@ class SGLangEngineQuiesceController: ...@@ -79,7 +76,6 @@ class SGLangEngineQuiesceController:
return True return True
def mark_resumed(self) -> None: def mark_resumed(self) -> None:
self._quiesced_tags = None
self._is_quiesced = False self._is_quiesced = False
......
...@@ -6,11 +6,19 @@ from __future__ import annotations ...@@ -6,11 +6,19 @@ from __future__ import annotations
from contextlib import contextmanager from contextlib import contextmanager
import pytest import pytest
from gpu_memory_service.common.types import RequestedLockType
from gpu_memory_service.integrations.sglang.memory_saver import GMSMemorySaverImpl
pytest.importorskip("gpu_memory_service", reason="gpu_memory_service is required")
torch = pytest.importorskip("torch", reason="torch is required") torch = pytest.importorskip("torch", reason="torch is required")
import gpu_memory_service.integrations.sglang.memory_saver as gms_memory_saver # noqa: E402
from gpu_memory_service.common.locks import ( # noqa: E402
GrantedLockType,
RequestedLockType,
)
from gpu_memory_service.integrations.sglang.memory_saver import ( # noqa: E402
GMSMemorySaverImpl,
)
pytestmark = [ pytestmark = [
pytest.mark.pre_merge, pytest.mark.pre_merge,
pytest.mark.unit, pytest.mark.unit,
...@@ -20,8 +28,14 @@ pytestmark = [ ...@@ -20,8 +28,14 @@ pytestmark = [
class _FakeManager: class _FakeManager:
def __init__(self, *, is_unmapped: bool = False): def __init__(
self,
*,
is_unmapped: bool = False,
granted_lock_type: GrantedLockType | None = None,
):
self.is_unmapped = is_unmapped self.is_unmapped = is_unmapped
self.granted_lock_type = granted_lock_type
self.calls: list[object] = [] self.calls: list[object] = []
def unmap_all_vas(self) -> None: def unmap_all_vas(self) -> None:
...@@ -30,9 +44,11 @@ class _FakeManager: ...@@ -30,9 +44,11 @@ class _FakeManager:
def abort(self) -> None: def abort(self) -> None:
self.calls.append("abort") self.calls.append("abort")
self.granted_lock_type = None
def connect(self, lock_type) -> None: def connect(self, lock_type) -> None:
self.calls.append(("connect", lock_type)) self.calls.append(("connect", lock_type))
self.granted_lock_type = GrantedLockType(lock_type.value)
self.is_unmapped = False self.is_unmapped = False
def reallocate_all_handles(self, *, tag: str) -> None: def reallocate_all_handles(self, *, tag: str) -> None:
...@@ -43,73 +59,81 @@ class _FakeManager: ...@@ -43,73 +59,81 @@ class _FakeManager:
self.is_unmapped = False self.is_unmapped = False
class _FakeTorchImpl: @pytest.fixture
def __init__(self): def build_impl(monkeypatch, tmp_path):
self.region_calls: list[tuple[str, bool]] = []
self.pause_calls: list[object] = []
self.resume_calls: list[object] = []
@contextmanager
def region(self, tag: str, enable_cpu_backup: bool):
self.region_calls.append((tag, enable_cpu_backup))
yield
def pause(self, tag=None) -> None:
self.pause_calls.append(tag)
def resume(self, tag=None) -> None:
self.resume_calls.append(tag)
def test_region_routes_weights_and_kv_cache_to_gms(monkeypatch):
weights = _FakeManager()
kv_cache = _FakeManager()
torch_impl = _FakeTorchImpl()
pool_calls: list[tuple[str, torch.device]] = []
@contextmanager
def fake_use_mem_pool(tag: str, device: torch.device):
pool_calls.append((tag, device))
yield
monkeypatch.setattr( monkeypatch.setattr(
GMSMemorySaverImpl, gms_memory_saver,
"_init_allocators", "get_socket_path",
lambda self: (weights, kv_cache, "write"), lambda device_index, tag: str(tmp_path / f"gms-test-{device_index}-{tag}.sock"),
)
monkeypatch.setattr(
"gpu_memory_service.integrations.sglang.memory_saver.gms_use_mem_pool",
fake_use_mem_pool,
) )
impl = GMSMemorySaverImpl(torch_impl=torch_impl, device_index=2, mode=None) def build(
*,
weights_lock: GrantedLockType = GrantedLockType.RW,
kv_cache_lock: GrantedLockType = GrantedLockType.RW,
):
weights = _FakeManager(granted_lock_type=weights_lock)
kv_cache = _FakeManager(granted_lock_type=kv_cache_lock)
pool_calls: list[tuple[str, torch.device]] = []
@contextmanager
def fake_use_mem_pool(tag: str, device: torch.device):
pool_calls.append((tag, device))
yield
monkeypatch.setattr(
gms_memory_saver,
"get_or_create_gms_client_memory_manager",
lambda socket_path, device, mode, tag: {
"weights": weights,
"kv_cache": kv_cache,
}[tag],
)
monkeypatch.setattr(gms_memory_saver, "gms_use_mem_pool", fake_use_mem_pool)
return (
GMSMemorySaverImpl(device_index=0, mode=None),
weights,
kv_cache,
pool_calls,
)
return build
@pytest.mark.parametrize(
("tag", "weights_lock", "expected_pool_calls"),
[
("weights", GrantedLockType.RW, [("weights", torch.device("cuda", 0))]),
("weights", GrantedLockType.RO, []),
("kv_cache", GrantedLockType.RW, [("kv_cache", torch.device("cuda", 0))]),
("cuda_graph", GrantedLockType.RW, []),
],
)
def test_region_uses_gms_pool_only_for_rw_managed_tags(
build_impl,
tag,
weights_lock,
expected_pool_calls,
):
impl, _, _, pool_calls = build_impl(
weights_lock=weights_lock,
kv_cache_lock=GrantedLockType.RW,
)
with impl.region("weights", enable_cpu_backup=False): with impl.region(tag, enable_cpu_backup=False):
pass
with impl.region("kv_cache", enable_cpu_backup=False):
pass
with impl.region("cuda_graph", enable_cpu_backup=True):
pass pass
assert pool_calls == [ assert pool_calls == expected_pool_calls
("weights", torch.device("cuda", 2)),
("kv_cache", torch.device("cuda", 2)),
]
assert torch_impl.region_calls == [("cuda_graph", True)]
def test_pause_resume_routes_kv_cache_to_gms(monkeypatch): def test_pause_resume_routes_only_managed_tags(build_impl):
weights = _FakeManager() impl, weights, kv_cache, _ = build_impl(
kv_cache = _FakeManager() weights_lock=GrantedLockType.RO,
torch_impl = _FakeTorchImpl() kv_cache_lock=GrantedLockType.RW,
monkeypatch.setattr(
GMSMemorySaverImpl,
"_init_allocators",
lambda self: (weights, kv_cache, "read"),
) )
impl = GMSMemorySaverImpl(torch_impl=torch_impl, device_index=0, mode=None) impl.pause("model_weights")
impl.resume("anything_else")
impl.pause() impl.pause()
impl.resume() impl.resume()
...@@ -127,62 +151,13 @@ def test_pause_resume_routes_kv_cache_to_gms(monkeypatch): ...@@ -127,62 +151,13 @@ def test_pause_resume_routes_kv_cache_to_gms(monkeypatch):
("reallocate_all_handles", "kv_cache"), ("reallocate_all_handles", "kv_cache"),
"remap_all_vas", "remap_all_vas",
] ]
assert torch_impl.pause_calls == [None]
assert torch_impl.resume_calls == [None]
def test_region_treats_model_weights_as_weights(monkeypatch):
weights = _FakeManager()
kv_cache = _FakeManager()
torch_impl = _FakeTorchImpl()
pool_calls: list[tuple[str, torch.device]] = []
@contextmanager @pytest.mark.parametrize("tag", ["weights", "kv_cache"])
def fake_use_mem_pool(tag: str, device: torch.device): def test_region_requires_rw_allocator(build_impl, tag):
pool_calls.append((tag, device)) impl, _, _, _ = build_impl()
yield impl.allocators[tag].abort()
monkeypatch.setattr( with pytest.raises(RuntimeError, match=rf"requires '{tag}' to be RW"):
GMSMemorySaverImpl, with impl.region(tag, enable_cpu_backup=False):
"_init_allocators", pass
lambda self: (weights, kv_cache, "write"),
)
monkeypatch.setattr(
"gpu_memory_service.integrations.sglang.memory_saver.gms_use_mem_pool",
fake_use_mem_pool,
)
impl = GMSMemorySaverImpl(torch_impl=torch_impl, device_index=1, mode=None)
with impl.region("model_weights", enable_cpu_backup=False):
pass
assert pool_calls == [("weights", torch.device("cuda", 1))]
assert torch_impl.region_calls == []
def test_pause_resume_model_weights_only_routes_weights(monkeypatch):
weights = _FakeManager()
kv_cache = _FakeManager()
torch_impl = _FakeTorchImpl()
monkeypatch.setattr(
GMSMemorySaverImpl,
"_init_allocators",
lambda self: (weights, kv_cache, "read"),
)
impl = GMSMemorySaverImpl(torch_impl=torch_impl, device_index=0, mode=None)
impl.pause("model_weights")
impl.resume("model_weights")
assert weights.calls == [
"unmap_all_vas",
"abort",
("connect", RequestedLockType.RO),
"remap_all_vas",
]
assert kv_cache.calls == []
assert torch_impl.pause_calls == []
assert torch_impl.resume_calls == []
...@@ -45,7 +45,8 @@ class _TestWorkerHandler(BaseWorkerHandler): ...@@ -45,7 +45,8 @@ class _TestWorkerHandler(BaseWorkerHandler):
yield {} yield {}
def _make_handler() -> _TestWorkerHandler: @pytest.fixture
def handler():
handler = _TestWorkerHandler.__new__(_TestWorkerHandler) handler = _TestWorkerHandler.__new__(_TestWorkerHandler)
handler.engine = SimpleNamespace( handler.engine = SimpleNamespace(
tokenizer_manager=SimpleNamespace( tokenizer_manager=SimpleNamespace(
...@@ -65,22 +66,7 @@ def _make_handler() -> _TestWorkerHandler: ...@@ -65,22 +66,7 @@ def _make_handler() -> _TestWorkerHandler:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_resume_before_release_is_noop(): async def test_release_and_resume_are_idempotent(handler):
handler = _make_handler()
result = await handler.resume_memory_occupation({})
assert result["status"] == "ok"
assert result["message"] == "Memory already resumed"
handler.engine.tokenizer_manager.resume_memory_occupation.assert_not_awaited()
handler.engine.tokenizer_manager.continue_generation.assert_not_awaited()
handler.generate_endpoint.register_endpoint_instance.assert_not_awaited()
@pytest.mark.asyncio
async def test_release_and_resume_are_idempotent():
handler = _make_handler()
first_release = await handler.release_memory_occupation({}) first_release = await handler.release_memory_occupation({})
second_release = await handler.release_memory_occupation({}) second_release = await handler.release_memory_occupation({})
...@@ -113,11 +99,9 @@ async def test_release_and_resume_are_idempotent(): ...@@ -113,11 +99,9 @@ async def test_release_and_resume_are_idempotent():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_release_and_resume_use_explicit_request_tags(): async def test_memory_occupation_handlers_forward_tags_exactly(handler):
handler = _make_handler() await handler.release_memory_occupation({"tags": []})
resume_result = await handler.resume_memory_occupation({"tags": []})
await handler.release_memory_occupation({"tags": ["weights"]})
resume_result = await handler.resume_memory_occupation({"tags": ["weights"]})
assert resume_result["status"] == "ok" assert resume_result["status"] == "ok"
release_req = ( release_req = (
...@@ -126,74 +110,59 @@ async def test_release_and_resume_use_explicit_request_tags(): ...@@ -126,74 +110,59 @@ async def test_release_and_resume_use_explicit_request_tags():
resume_req = ( resume_req = (
handler.engine.tokenizer_manager.resume_memory_occupation.await_args.args[0] handler.engine.tokenizer_manager.resume_memory_occupation.await_args.args[0]
) )
assert release_req.tags == ["weights"] assert release_req.tags == []
assert resume_req.tags == ["weights"] assert resume_req.tags == []
handler.engine.tokenizer_manager.continue_generation.assert_awaited_once()
handler.generate_endpoint.register_endpoint_instance.assert_awaited_once()
handler.engine.tokenizer_manager.pause_generation.reset_mock()
@pytest.mark.asyncio handler.engine.tokenizer_manager.release_memory_occupation.reset_mock()
async def test_resume_reuses_release_tags_when_request_omits_them(): handler.engine.tokenizer_manager.resume_memory_occupation.reset_mock()
handler = _make_handler() handler.engine.tokenizer_manager.continue_generation.reset_mock()
handler.generate_endpoint.unregister_endpoint_instance.reset_mock()
handler.generate_endpoint.register_endpoint_instance.reset_mock()
await handler.release_memory_occupation({"tags": ["weights"]}) await handler.release_memory_occupation({"tags": ["weights"]})
resume_result = await handler.resume_memory_occupation({}) resume_result = await handler.resume_memory_occupation({})
assert resume_result["status"] == "ok" assert resume_result["status"] == "ok"
release_req = (
handler.engine.tokenizer_manager.release_memory_occupation.await_args.args[0]
)
resume_req = ( resume_req = (
handler.engine.tokenizer_manager.resume_memory_occupation.await_args.args[0] handler.engine.tokenizer_manager.resume_memory_occupation.await_args.args[0]
) )
assert resume_req.tags == ["weights"] assert release_req.tags == ["weights"]
assert resume_req.tags is None
handler.engine.tokenizer_manager.continue_generation.assert_awaited_once() handler.engine.tokenizer_manager.continue_generation.assert_awaited_once()
handler.generate_endpoint.register_endpoint_instance.assert_awaited_once() handler.generate_endpoint.register_endpoint_instance.assert_awaited_once()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_resume_with_no_sleeping_state_is_noop(): @pytest.mark.parametrize(
handler = _make_handler() ("method_name", "endpoint_method"),
[
result = await handler.resume_memory_occupation({}) ("release_memory_occupation", "unregister_endpoint_instance"),
("resume_memory_occupation", "register_endpoint_instance"),
assert result["status"] == "ok" ],
assert result["message"] == "Memory already resumed" )
handler.engine.tokenizer_manager.resume_memory_occupation.assert_not_awaited() async def test_memory_control_returns_error_without_quiesce_controller(
handler.engine.tokenizer_manager.continue_generation.assert_not_awaited() handler,
handler.generate_endpoint.register_endpoint_instance.assert_not_awaited() method_name,
endpoint_method,
):
@pytest.mark.asyncio
async def test_release_returns_error_when_worker_has_no_tokenizer_manager():
handler = _make_handler()
handler.engine = None
handler._quiesce_controller = None
result = await handler.release_memory_occupation({})
assert result == {
"status": "error",
"message": "memory control not supported on this worker",
}
handler.generate_endpoint.unregister_endpoint_instance.assert_not_awaited()
@pytest.mark.asyncio
async def test_resume_returns_error_when_worker_has_no_tokenizer_manager():
handler = _make_handler()
handler.engine = None handler.engine = None
handler._quiesce_controller = None handler._quiesce_controller = None
result = await handler.resume_memory_occupation({}) result = await getattr(handler, method_name)({})
assert result == { assert result == {
"status": "error", "status": "error",
"message": "memory control not supported on this worker", "message": "memory control not supported on this worker",
} }
handler.generate_endpoint.register_endpoint_instance.assert_not_awaited() getattr(handler.generate_endpoint, endpoint_method).assert_not_awaited()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_resume_keeps_quiesced_state_when_register_fails(): async def test_resume_keeps_quiesced_state_when_register_fails(handler):
handler = _make_handler()
await handler.release_memory_occupation({}) await handler.release_memory_occupation({})
handler.generate_endpoint.register_endpoint_instance = AsyncMock( handler.generate_endpoint.register_endpoint_instance = AsyncMock(
side_effect=RuntimeError("discovery write timeout") side_effect=RuntimeError("discovery write timeout")
......
...@@ -34,7 +34,7 @@ dynamo: ...@@ -34,7 +34,7 @@ dynamo:
nixl_libfabric_ref: v2.3.0 nixl_libfabric_ref: v2.3.0
enable_kvbm: "true" enable_kvbm: "true"
enable_media_ffmpeg: "false" enable_media_ffmpeg: "false"
enable_gpu_memory_service: "false" enable_gpu_memory_service: "true"
ffmpeg_version: "7.1" ffmpeg_version: "7.1"
efa_version: 1.45.1 efa_version: 1.45.1
......
...@@ -565,9 +565,9 @@ python -m dynamo.sglang \ ...@@ -565,9 +565,9 @@ python -m dynamo.sglang \
``` ```
The integration patches `torch_memory_saver` to route both weight and KV-cache operations through GMS: The integration patches `torch_memory_saver` to route both weight and KV-cache operations through GMS:
- Weights (`"weights"` / `"model_weights"` tags) use the `weights` GMS tag - Weights (`"weights"`) use the `weights` GMS tag
- KV cache (`"kv_cache"`) uses a separate RW-only `kv_cache` GMS tag - KV cache (`"kv_cache"`) uses a separate RW-only `kv_cache` GMS tag
- Other tags still use the default torch mempool implementation - Other tags are not supported in GMS mode
- The `--enable-memory-saver` flag is required to activate the memory saver pathway - The `--enable-memory-saver` flag is required to activate the memory saver pathway
### Shadow Engine Failover (Sleep / Wake) ### Shadow Engine Failover (Sleep / Wake)
......
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
"""GPU Memory Service - out-of-process GPU memory manager.
The GPU Memory Service decouples ownership of GPU memory from the processes
that use it, enabling zero-copy sharing and data survival across process crashes.
Package structure:
- common/: Shared types and protocol (used by both server and client)
- server/: Allocation server daemon (no CUDA context required)
- client/: Client library for memory management
- client/torch/: PyTorch integration (allocator, tensor, module, extensions)
Primary client API:
from gpu_memory_service import (
GMSClientMemoryManager,
get_or_create_gms_client_memory_manager,
get_gms_client_memory_manager,
)
Server API:
from gpu_memory_service.server import GMSRPCServer
"""
# Primary client exports
from gpu_memory_service.client.memory_manager import (
GMSClientMemoryManager,
StaleMemoryLayoutError,
)
# PyTorch integration (GMS client memory manager)
from gpu_memory_service.client.torch.allocator import (
get_gms_client_memory_manager,
get_gms_client_memory_managers,
get_or_create_gms_client_memory_manager,
gms_use_mem_pool,
)
__all__ = [
# Client
"GMSClientMemoryManager",
"StaleMemoryLayoutError",
# GMS client memory manager
"get_or_create_gms_client_memory_manager",
"get_gms_client_memory_manager",
"get_gms_client_memory_managers",
"gms_use_mem_pool",
]
...@@ -15,7 +15,7 @@ import asyncio ...@@ -15,7 +15,7 @@ import asyncio
import logging import logging
import uvloop import uvloop
from gpu_memory_service.server import GMSRPCServer from gpu_memory_service.server.rpc import GMSRPCServer
from .args import parse_args from .args import parse_args
......
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
"""GPU Memory Service client library.
This module provides the client-side components for interacting with the
GPU Memory Service:
- GMSClientMemoryManager: Manages local VA mappings of remote GPU memory
For PyTorch integration (MemPool, tensor utilities), see gpu_memory_service.client.torch.
"""
from gpu_memory_service.client.memory_manager import (
GMSClientMemoryManager,
StaleMemoryLayoutError,
)
__all__ = [
"GMSClientMemoryManager",
"StaleMemoryLayoutError",
]
...@@ -49,8 +49,8 @@ from gpu_memory_service.common.cuda_utils import ( ...@@ -49,8 +49,8 @@ from gpu_memory_service.common.cuda_utils import (
cumem_set_access, cumem_set_access,
cumem_unmap, cumem_unmap,
) )
from gpu_memory_service.common.locks import GrantedLockType, RequestedLockType
from gpu_memory_service.common.protocol.messages import GetAllocationResponse from gpu_memory_service.common.protocol.messages import GetAllocationResponse
from gpu_memory_service.common.types import GrantedLockType, RequestedLockType
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
......
...@@ -14,13 +14,13 @@ import os ...@@ -14,13 +14,13 @@ import os
import socket import socket
from typing import Optional, Tuple, Type, TypeVar from typing import Optional, Tuple, Type, TypeVar
from gpu_memory_service.common.locks import RequestedLockType
from gpu_memory_service.common.protocol.messages import ( from gpu_memory_service.common.protocol.messages import (
ErrorResponse, ErrorResponse,
HandshakeRequest, HandshakeRequest,
HandshakeResponse, HandshakeResponse,
) )
from gpu_memory_service.common.protocol.wire import recv_message_sync, send_message_sync from gpu_memory_service.common.protocol.wire import recv_message_sync, send_message_sync
from gpu_memory_service.common.types import RequestedLockType
T = TypeVar("T") T = TypeVar("T")
......
...@@ -9,6 +9,7 @@ import logging ...@@ -9,6 +9,7 @@ import logging
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
from gpu_memory_service.client.rpc import _GMSRPCTransport from gpu_memory_service.client.rpc import _GMSRPCTransport
from gpu_memory_service.common.locks import GrantedLockType, RequestedLockType
from gpu_memory_service.common.protocol.messages import ( from gpu_memory_service.common.protocol.messages import (
AllocateRequest, AllocateRequest,
AllocateResponse, AllocateResponse,
...@@ -38,7 +39,6 @@ from gpu_memory_service.common.protocol.messages import ( ...@@ -38,7 +39,6 @@ from gpu_memory_service.common.protocol.messages import (
MetadataPutRequest, MetadataPutRequest,
MetadataPutResponse, MetadataPutResponse,
) )
from gpu_memory_service.common.types import GrantedLockType, RequestedLockType
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
......
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
"""PyTorch integration for GPU Memory Service.
This module provides PyTorch-specific functionality:
- Memory manager singleton management
- Tensor utilities (metadata, registration, materialization)
- C++ extension for CUDAPluggableAllocator
"""
from gpu_memory_service.client.torch.allocator import (
get_gms_client_memory_manager,
get_gms_client_memory_managers,
get_or_create_gms_client_memory_manager,
gms_use_mem_pool,
)
from gpu_memory_service.client.torch.module import (
materialize_module_from_gms,
register_module_tensors,
)
__all__ = [
# GMS client memory manager
"get_or_create_gms_client_memory_manager",
"get_gms_client_memory_manager",
"get_gms_client_memory_managers",
"gms_use_mem_pool",
# Tensor operations (public API)
"register_module_tensors",
"materialize_module_from_gms",
]
...@@ -11,7 +11,7 @@ from contextvars import ContextVar ...@@ -11,7 +11,7 @@ from contextvars import ContextVar
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Iterator, Optional from typing import TYPE_CHECKING, Any, Iterator, Optional
from gpu_memory_service.common.types import GrantedLockType, RequestedLockType from gpu_memory_service.common.locks import GrantedLockType, RequestedLockType
if TYPE_CHECKING: if TYPE_CHECKING:
import torch import torch
......
...@@ -8,10 +8,22 @@ from __future__ import annotations ...@@ -8,10 +8,22 @@ from __future__ import annotations
import atexit import atexit
import os import os
from cuda.bindings import driver as cuda from gpu_memory_service.common.locks import GrantedLockType
from gpu_memory_service.common.types import GrantedLockType
from gpu_memory_service.common.utils import fail from gpu_memory_service.common.utils import fail
try:
from cuda.bindings import driver as cuda
except ImportError:
# Keep import-time collection working in CPU-only environments and let the
# first real CUDA call fail with a targeted message instead.
class _MissingCuda:
def __getattr__(self, name):
raise RuntimeError(
"cuda-python is required for GPU Memory Service CUDA operations"
)
cuda = _MissingCuda()
_primary_contexts: dict[int, object] = {} _primary_contexts: dict[int, object] = {}
_primary_context_release_registered = False _primary_context_release_registered = False
......
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from enum import Enum
class RequestedLockType(str, Enum):
RW = "rw"
RO = "ro"
RW_OR_RO = "rw_or_ro"
class GrantedLockType(str, Enum):
RW = "rw"
RO = "ro"
...@@ -3,25 +3,10 @@ ...@@ -3,25 +3,10 @@
"""Message types for GPU Memory Service RPC protocol.""" """Message types for GPU Memory Service RPC protocol."""
from enum import Enum
from typing import List, Optional, Union from typing import List, Optional, Union
import msgspec import msgspec
from gpu_memory_service.common.locks import GrantedLockType, RequestedLockType
class RequestedLockType(str, Enum):
"""Lock type requested by client."""
RW = "rw"
RO = "ro"
RW_OR_RO = "rw_or_ro"
class GrantedLockType(str, Enum):
"""Lock type actually granted by server."""
RW = "rw"
RO = "ro"
class HandshakeRequest(msgspec.Struct, tag="handshake_request"): class HandshakeRequest(msgspec.Struct, tag="handshake_request"):
......
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Shared types for GPU Memory Service."""
from dataclasses import dataclass
from enum import Enum, auto
from gpu_memory_service.common.protocol.messages import (
AllocateRequest,
CommitRequest,
ExportAllocationRequest,
FreeAllocationRequest,
GetAllocationRequest,
GetAllocationStateRequest,
GetLockStateRequest,
GetStateHashRequest,
GrantedLockType,
ListAllocationsRequest,
MetadataDeleteRequest,
MetadataGetRequest,
MetadataListRequest,
MetadataPutRequest,
RequestedLockType,
)
# Re-export lock types for convenience
__all__ = [
"GrantedLockType",
"RequestedLockType",
"ServerState",
"StateEvent",
"StateSnapshot",
"derive_state",
"RW_REQUIRED",
"RO_ALLOWED",
"RW_ALLOWED",
]
class ServerState(str, Enum):
"""Server state - derived from actual connections."""
EMPTY = "EMPTY"
RW = "RW"
COMMITTED = "COMMITTED"
RO = "RO"
class StateEvent(Enum):
"""Events that trigger state transitions."""
RW_CONNECT = auto()
RW_COMMIT = auto()
RW_ABORT = auto()
RO_CONNECT = auto()
RO_DISCONNECT = auto()
@dataclass
class StateSnapshot:
"""Current server state snapshot."""
state: ServerState
has_rw: bool
ro_count: int
waiting_writers: int
committed: bool
@property
def is_ready(self) -> bool:
"""Ready = committed and no RW connection."""
return self.committed and not self.has_rw
def derive_state(has_rw: bool, ro_count: int, committed: bool) -> ServerState:
"""Derive server state from connection info."""
if has_rw:
return ServerState.RW
if ro_count > 0:
return ServerState.RO
if committed:
return ServerState.COMMITTED
return ServerState.EMPTY
# Permission sets: which message types require which connection mode
RW_REQUIRED: frozenset[type] = frozenset(
{
AllocateRequest,
FreeAllocationRequest,
MetadataPutRequest,
MetadataDeleteRequest,
CommitRequest,
}
)
RO_ALLOWED: frozenset[type] = frozenset(
{
ExportAllocationRequest,
GetAllocationRequest,
ListAllocationsRequest,
MetadataGetRequest,
MetadataListRequest,
GetLockStateRequest,
GetAllocationStateRequest,
GetStateHashRequest,
}
)
RW_ALLOWED: frozenset[type] = RW_REQUIRED | RO_ALLOWED
...@@ -5,11 +5,13 @@ ...@@ -5,11 +5,13 @@
from gpu_memory_service.integrations.common.patches import patch_empty_cache from gpu_memory_service.integrations.common.patches import patch_empty_cache
from gpu_memory_service.integrations.common.utils import ( from gpu_memory_service.integrations.common.utils import (
GMS_TAGS,
finalize_gms_write, finalize_gms_write,
setup_meta_tensor_workaround, setup_meta_tensor_workaround,
) )
__all__ = [ __all__ = [
"GMS_TAGS",
"patch_empty_cache", "patch_empty_cache",
"setup_meta_tensor_workaround", "setup_meta_tensor_workaround",
"finalize_gms_write", "finalize_gms_write",
......
...@@ -10,11 +10,14 @@ from dataclasses import replace ...@@ -10,11 +10,14 @@ from dataclasses import replace
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
import torch import torch
from gpu_memory_service.client.torch.module import register_module_tensors
from gpu_memory_service.common.locks import RequestedLockType
if TYPE_CHECKING: if TYPE_CHECKING:
from gpu_memory_service.client.memory_manager import GMSClientMemoryManager from gpu_memory_service.client.memory_manager import GMSClientMemoryManager
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
GMS_TAGS = ("weights", "kv_cache")
def get_gms_lock_mode(extra_config: dict): def get_gms_lock_mode(extra_config: dict):
...@@ -22,8 +25,6 @@ def get_gms_lock_mode(extra_config: dict): ...@@ -22,8 +25,6 @@ def get_gms_lock_mode(extra_config: dict):
Returns RO if gms_read_only=True, otherwise RW_OR_RO (default). Returns RO if gms_read_only=True, otherwise RW_OR_RO (default).
""" """
from gpu_memory_service.common.types import RequestedLockType
if extra_config.get("gms_read_only", False): if extra_config.get("gms_read_only", False):
logger.info("[GMS] gms_read_only=True, forcing RO mode") logger.info("[GMS] gms_read_only=True, forcing RO mode")
return RequestedLockType.RO return RequestedLockType.RO
...@@ -68,9 +69,6 @@ def finalize_gms_write( ...@@ -68,9 +69,6 @@ def finalize_gms_write(
Returns: Returns:
Total bytes committed. Total bytes committed.
""" """
from gpu_memory_service.client.torch.module import register_module_tensors
from gpu_memory_service.common.types import RequestedLockType
register_module_tensors(allocator, model) register_module_tensors(allocator, model)
total_bytes = allocator.total_bytes total_bytes = allocator.total_bytes
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment