Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
39a6a240
Unverified
Commit
39a6a240
authored
Apr 09, 2026
by
Schwinn Saereesitthipitak
Committed by
GitHub
Apr 09, 2026
Browse files
refactor: simplify GPU Memory Service integrations and module boundaries (#7875)
parent
02666f04
Changes
51
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
427 additions
and
452 deletions
+427
-452
lib/gpu_memory_service/integrations/sglang/memory_saver.py
lib/gpu_memory_service/integrations/sglang/memory_saver.py
+128
-151
lib/gpu_memory_service/integrations/sglang/model_loader.py
lib/gpu_memory_service/integrations/sglang/model_loader.py
+11
-15
lib/gpu_memory_service/integrations/sglang/patches.py
lib/gpu_memory_service/integrations/sglang/patches.py
+12
-22
lib/gpu_memory_service/integrations/vllm/model_loader.py
lib/gpu_memory_service/integrations/vllm/model_loader.py
+5
-3
lib/gpu_memory_service/integrations/vllm/patches.py
lib/gpu_memory_service/integrations/vllm/patches.py
+2
-2
lib/gpu_memory_service/integrations/vllm/worker.py
lib/gpu_memory_service/integrations/vllm/worker.py
+6
-6
lib/gpu_memory_service/server/__init__.py
lib/gpu_memory_service/server/__init__.py
+0
-39
lib/gpu_memory_service/server/fsm.py
lib/gpu_memory_service/server/fsm.py
+171
-0
lib/gpu_memory_service/server/gms.py
lib/gpu_memory_service/server/gms.py
+3
-7
lib/gpu_memory_service/server/rpc.py
lib/gpu_memory_service/server/rpc.py
+2
-1
lib/gpu_memory_service/server/session.py
lib/gpu_memory_service/server/session.py
+56
-187
pyproject.toml
pyproject.toml
+1
-0
tests/gms/common/__init__.py
tests/gms/common/__init__.py
+6
-0
tests/gms/common/test_failover_lock.py
tests/gms/common/test_failover_lock.py
+4
-2
tests/gms/common/test_gms_client_memory_manager.py
tests/gms/common/test_gms_client_memory_manager.py
+3
-2
tests/gms/common/test_gms_client_session.py
tests/gms/common/test_gms_client_session.py
+2
-1
tests/gms/common/test_gms_client_transport.py
tests/gms/common/test_gms_client_transport.py
+1
-0
tests/gms/common/test_gms_harness.py
tests/gms/common/test_gms_harness.py
+6
-1
tests/gms/common/test_gms_runtime_flows.py
tests/gms/common/test_gms_runtime_flows.py
+4
-6
tests/gms/common/test_gms_server_transport_failures.py
tests/gms/common/test_gms_server_transport_failures.py
+4
-7
No files found.
lib/gpu_memory_service/integrations/sglang/memory_saver.py
View file @
39a6a240
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
"""
Hybrid
torch_memory_saver implementation for GPU Memory Service.
"""torch_memory_saver implementation for GPU Memory Service.
This module uses:
SGLang with GMS owns exactly two memory classes:
1. GPU Memory Service for "weights" (shared RO/RW publish flow)
1. "weights" via the shared RO/RW publish flow
2. GPU Memory Service for "kv_cache" (RW-only failover flow)
2. "kv_cache" via the RW failover flow
3. torch_memory_saver for any remaining tags
Unsupported release/resume tags stay no-ops with a warning so the generic
SGLang memory-control API can still pass broader tag sets without reintroducing
the old torch-memory-saver fallback. `cuda_graph` is a hard error because the
pauseable CUDA-graph path depends on the LD_PRELOAD torch allocator hooks that
GMS intentionally does not use.
"""
"""
from
__future__
import
annotations
from
__future__
import
annotations
import
logging
import
logging
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
typing
import
TYPE_CHECKING
,
Optional
from
typing
import
Optional
import
torch
import
torch
from
gpu_memory_service
import
get_or_create_gms_client_memory_manager
from
gpu_memory_service.client.torch.allocator
import
(
from
gpu_memory_service.client.torch.allocator
import
gms_use_mem_pool
get_or_create_gms_client_memory_manager
,
from
gpu_memory_service.common.types
import
GrantedLockType
,
RequestedLockType
gms_use_mem_pool
,
)
from
gpu_memory_service.common.locks
import
GrantedLockType
,
RequestedLockType
from
gpu_memory_service.common.utils
import
get_socket_path
from
gpu_memory_service.common.utils
import
get_socket_path
from
gpu_memory_service.integrations.common.utils
import
GMS_TAGS
,
finalize_gms_write
if
TYPE_CHECKING
:
from
gpu_memory_service.client.memory_manager
import
GMSClientMemoryManager
from
torch_memory_saver.entrypoint
import
_TorchMemorySaverImpl
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
# Published weights must come back RO, while KV cache always resumes in a fresh
# RW epoch so the restored engine can rebuild mutable cache state.
_TAG_LOCK_TYPES
=
{
"weights"
:
RequestedLockType
.
RO
,
"kv_cache"
:
RequestedLockType
.
RW
}
def
_pause_resume_tags
(
tag
:
Optional
[
str
])
->
tuple
[
str
,
...]:
if
tag
is
None
:
return
GMS_TAGS
if
tag
in
_TAG_LOCK_TYPES
:
return
(
tag
,)
logger
.
warning
(
"[GMS] Ignoring unsupported torch_memory_saver tag %r; supported tags are %s"
,
tag
,
list
(
GMS_TAGS
),
)
return
()
def
get_gms_memory_saver_impl
()
->
Optional
[
"GMSMemorySaverImpl"
]:
def
get_gms_memory_saver_impl
()
->
Optional
[
"GMSMemorySaverImpl"
]:
"""Get the GMS memory saver impl from the torch_memory_saver singleton."""
"""Get the GMS memory saver impl from the torch_memory_saver singleton."""
...
@@ -39,170 +60,126 @@ def get_gms_memory_saver_impl() -> Optional["GMSMemorySaverImpl"]:
...
@@ -39,170 +60,126 @@ def get_gms_memory_saver_impl() -> Optional["GMSMemorySaverImpl"]:
class
GMSMemorySaverImpl
:
class
GMSMemorySaverImpl
:
"""
Hybrid implementation: GMS for weights and KV cache
."""
"""
SGLang memory saver implementation backed only by GMS
."""
def
__init__
(
def
__init__
(
self
,
self
,
torch_impl
:
"_TorchMemorySaverImpl"
,
device_index
:
int
,
device_index
:
int
,
mode
=
None
,
mode
=
None
,
):
):
self
.
_torch_impl
=
torch_impl
self
.
_device
=
torch
.
device
(
"cuda"
,
device_index
)
self
.
_device_index
=
device_index
self
.
imported_weights_bytes
=
0
self
.
_requested_mode
=
mode
requested_mode
=
mode
or
RequestedLockType
.
RW_OR_RO
self
.
_disabled
=
False
self
.
allocators
=
{
self
.
_imported_weights_bytes
:
int
=
0
tag
:
get_or_create_gms_client_memory_manager
(
get_socket_path
(
device_index
,
tag
),
self
.
_weights_allocator
:
Optional
[
"GMSClientMemoryManager"
]
self
.
_kv_cache_allocator
:
"GMSClientMemoryManager"
self
.
_mode
:
str
(
self
.
_weights_allocator
,
self
.
_kv_cache_allocator
,
self
.
_mode
,
)
=
self
.
_init_allocators
()
logger
.
info
(
"[GMS] Initialized weights=%s mode, kv_cache=RW (device=%d)"
,
self
.
_mode
.
upper
(),
device_index
,
device_index
,
# weights follow the configured publish/import mode; kv_cache is
# always mutable and therefore always needs an RW session.
mode
=
requested_mode
if
tag
==
"weights"
else
RequestedLockType
.
RW
,
tag
=
tag
,
)
)
for
tag
in
GMS_TAGS
}
def
_init_allocators
(
self
,
)
->
tuple
[
Optional
[
"GMSClientMemoryManager"
],
"GMSClientMemoryManager"
,
str
,]:
"""Create allocator with mode from config (default: RW_OR_RO)."""
mode
=
self
.
_requested_mode
or
RequestedLockType
.
RW_OR_RO
weights_allocator
=
get_or_create_gms_client_memory_manager
(
get_socket_path
(
self
.
_device_index
,
"weights"
),
self
.
_device_index
,
mode
=
mode
,
tag
=
"weights"
,
)
kv_cache_allocator
=
get_or_create_gms_client_memory_manager
(
get_socket_path
(
self
.
_device_index
,
"kv_cache"
),
self
.
_device_index
,
mode
=
RequestedLockType
.
RW
,
tag
=
"kv_cache"
,
)
granted_mode
=
weights_allocator
.
granted_lock_type
if
granted_mode
==
GrantedLockType
.
RW
:
actual_mode
=
"write"
else
:
actual_mode
=
"read"
logger
.
info
(
logger
.
info
(
"[GMS] Initialized in AUTO mode, granted=%s (device=%d)"
,
"[GMS] Initialized weights: requested=%s granted=%s (device=%d)"
,
actual_mode
.
upper
(),
requested_mode
.
name
,
self
.
_device_index
,
self
.
allocators
[
"weights"
].
granted_lock_type
.
name
,
device_index
,
)
)
return
weights_allocator
,
kv_cache_allocator
,
actual_mode
def
_is_weights_tag
(
self
,
tag
:
Optional
[
str
])
->
bool
:
return
tag
in
(
"weights"
,
"model_weights"
)
def
get_mode
(
self
)
->
str
:
return
self
.
_mode
def
get_allocator
(
self
)
->
Optional
[
"GMSClientMemoryManager"
]:
return
self
.
_weights_allocator
@
contextmanager
@
contextmanager
def
region
(
self
,
tag
:
str
,
enable_cpu_backup
:
bool
):
def
region
(
self
,
tag
:
str
,
enable_cpu_backup
:
bool
):
"""Mark allocation region with tag."""
"""Mark allocation region with tag."""
if
self
.
_is_weights_tag
(
tag
)
:
if
enable_cpu_backup
:
if
self
.
_mode
==
"read"
:
raise
ValueError
(
yield
"SGLang with GMS does not support CPU backup for allocations."
return
)
target_device
=
torch
.
device
(
"cuda"
,
self
.
_device_index
)
if
tag
not
in
_TAG_LOCK_TYPES
:
with
gms_use_mem_pool
(
"weights"
,
target_device
):
logger
.
warning
(
"[GMS] Ignoring unsupported torch_memory_saver region tag %r; "
"supported tags are %s"
,
tag
,
list
(
GMS_TAGS
),
)
yield
yield
return
return
if
tag
==
"kv_cache"
:
if
(
target_device
=
torch
.
device
(
"cuda"
,
self
.
_device_index
)
tag
==
"weights"
with
gms_use_mem_pool
(
"kv_cache"
,
target_device
):
and
self
.
allocators
[
"weights"
].
granted_lock_type
==
GrantedLockType
.
RO
):
# Imported weights are already mapped and immutable in RO mode, so
# there is no allocator swap to install for this region.
yield
yield
return
return
with
self
.
_torch_impl
.
region
(
tag
=
tag
,
enable_cpu_backup
=
enable_cpu_backup
):
allocator
=
self
.
allocators
[
tag
]
if
allocator
.
granted_lock_type
!=
GrantedLockType
.
RW
:
mode
=
(
allocator
.
granted_lock_type
.
name
if
allocator
.
granted_lock_type
is
not
None
else
"DISCONNECTED"
)
# The server would reject writes on a non-RW session too, but we
# fail before entering the allocation path so SGLang never starts a
# partial region with the wrong lock state.
raise
RuntimeError
(
f
"SGLang with GMS requires
{
tag
!
r
}
to be RW for allocations; got
{
mode
}
"
)
with
gms_use_mem_pool
(
tag
,
self
.
_device
):
yield
yield
@
contextmanager
def
cuda_graph
(
self
,
cuda_graph
,
pool
,
stream
,
capture_error_mode
,
tag
:
str
,
enable_cpu_backup
:
bool
,
):
# The old hybrid path could delegate this to torch_memory_saver, but
# strict GMS mode has no compatible pauseable CUDA-graph allocator hook.
raise
RuntimeError
(
"SGLang with GMS does not support pauseable CUDA graphs. "
"torch_memory_saver only supports cuda_graph in hook_mode=preload, "
"and GMS does not use the LD_PRELOAD path."
)
def
pause
(
self
,
tag
:
Optional
[
str
]
=
None
)
->
None
:
def
pause
(
self
,
tag
:
Optional
[
str
]
=
None
)
->
None
:
if
self
.
_disabled
:
for
target_tag
in
_pause_resume_tags
(
tag
)
:
return
if
self
.
allocators
[
target_tag
].
is_unmapped
:
if
tag
is
None
or
self
.
_is_weights_tag
(
tag
):
continue
self
.
_pause_weights
(
)
logger
.
info
(
"[GMS] Unmapping %s"
,
target_tag
)
if
tag
is
None
or
tag
==
"kv_cache"
:
self
.
allocators
[
target_tag
].
unmap_all_vas
()
self
.
_pause_kv_cache
()
# abort() drops the current session after unmapping while keeping
if
tag
is
None
or
(
not
self
.
_is_weights_tag
(
tag
)
and
tag
!=
"kv_cache"
):
# the VA reservation alive for the next resume().
self
.
_torch_impl
.
pause
(
tag
=
tag
)
self
.
allocators
[
target_tag
].
abort
(
)
def
resume
(
self
,
tag
:
Optional
[
str
]
=
None
)
->
None
:
def
resume
(
self
,
tag
:
Optional
[
str
]
=
None
)
->
None
:
if
self
.
_disabled
:
for
target_tag
in
_pause_resume_tags
(
tag
):
return
if
not
self
.
allocators
[
target_tag
].
is_unmapped
:
if
tag
is
None
or
self
.
_is_weights_tag
(
tag
):
continue
self
.
_resume_weights
()
if
tag
is
None
or
tag
==
"kv_cache"
:
logger
.
info
(
"[GMS] Remapping %s"
,
target_tag
)
self
.
_resume_kv_cache
()
self
.
allocators
[
target_tag
].
connect
(
_TAG_LOCK_TYPES
[
target_tag
])
if
tag
is
None
or
(
not
self
.
_is_weights_tag
(
tag
)
and
tag
!=
"kv_cache"
):
if
target_tag
==
"kv_cache"
:
self
.
_torch_impl
.
resume
(
tag
=
tag
)
# KV cache resumes into a new RW layout epoch, so the handles
# must be re-created before the VA range is mapped again.
def
_pause_weights
(
self
)
->
None
:
self
.
allocators
[
target_tag
].
reallocate_all_handles
(
tag
=
target_tag
)
if
self
.
_weights_allocator
is
None
:
self
.
allocators
[
target_tag
].
remap_all_vas
()
return
if
self
.
_weights_allocator
.
is_unmapped
:
return
logger
.
info
(
"[GMS] Unmapping weights (VA-stable)"
)
self
.
_weights_allocator
.
unmap_all_vas
()
self
.
_weights_allocator
.
abort
()
def
_resume_weights
(
self
)
->
None
:
if
self
.
_weights_allocator
is
None
:
return
if
not
self
.
_weights_allocator
.
is_unmapped
:
return
logger
.
info
(
"[GMS] Remapping weights (VA-stable)"
)
self
.
_weights_allocator
.
connect
(
RequestedLockType
.
RO
)
self
.
_weights_allocator
.
remap_all_vas
()
def
_pause_kv_cache
(
self
)
->
None
:
if
self
.
_kv_cache_allocator
.
is_unmapped
:
return
logger
.
info
(
"[GMS] Unmapping KV cache"
)
self
.
_kv_cache_allocator
.
unmap_all_vas
()
self
.
_kv_cache_allocator
.
abort
()
def
_resume_kv_cache
(
self
)
->
None
:
if
not
self
.
_kv_cache_allocator
.
is_unmapped
:
return
logger
.
info
(
"[GMS] Remapping KV cache"
)
self
.
_kv_cache_allocator
.
connect
(
RequestedLockType
.
RW
)
self
.
_kv_cache_allocator
.
reallocate_all_handles
(
tag
=
"kv_cache"
)
self
.
_kv_cache_allocator
.
remap_all_vas
()
def
finalize_write_mode
(
self
,
model
:
torch
.
nn
.
Module
)
->
None
:
def
finalize_write_mode
(
self
,
model
:
torch
.
nn
.
Module
)
->
None
:
"""Finalize write mode: register tensors, commit, and switch to read."""
"""Finalize write mode: register tensors, commit, and switch to read."""
if
self
.
_mode
!=
"write"
:
if
self
.
allocators
[
"weights"
].
granted_lock_type
!=
GrantedLockType
.
RW
:
# Read-only import mode never republishes weights.
return
return
if
self
.
_weights_allocator
is
None
:
raise
RuntimeError
(
"Allocator is None in WRITE mode"
)
from
gpu_memory_service.integrations.common.utils
import
finalize_gms_write
self
.
imported_weights_bytes
=
finalize_gms_write
(
self
.
allocators
[
"weights"
],
model
self
.
_imported_weights_bytes
=
finalize_gms_write
(
self
.
_weights_allocator
,
model
)
)
self
.
_mode
=
"read"
def
set_imported_weights_bytes
(
self
,
bytes_count
:
int
)
->
None
:
self
.
_imported_weights_bytes
=
bytes_count
def
get_imported_weights_bytes
(
self
)
->
int
:
return
self
.
_imported_weights_bytes
def
disable
(
self
)
->
None
:
self
.
_disabled
=
True
def
enable
(
self
)
->
None
:
self
.
_disabled
=
False
lib/gpu_memory_service/integrations/sglang/model_loader.py
View file @
39a6a240
...
@@ -16,11 +16,16 @@ from __future__ import annotations
...
@@ -16,11 +16,16 @@ from __future__ import annotations
import
logging
import
logging
import
torch
import
torch
from
gpu_memory_service.client.torch.module
import
materialize_module_from_gms
from
gpu_memory_service.common.locks
import
GrantedLockType
from
gpu_memory_service.integrations.common
import
patch_empty_cache
from
gpu_memory_service.integrations.common
import
patch_empty_cache
from
gpu_memory_service.integrations.common.utils
import
(
from
gpu_memory_service.integrations.common.utils
import
(
setup_meta_tensor_workaround
,
setup_meta_tensor_workaround
,
strip_gms_model_loader_config
,
strip_gms_model_loader_config
,
)
)
from
gpu_memory_service.integrations.sglang.memory_saver
import
(
get_gms_memory_saver_impl
,
)
from
gpu_memory_service.integrations.sglang.patches
import
(
from
gpu_memory_service.integrations.sglang.patches
import
(
patch_model_runner
,
patch_model_runner
,
patch_static_state_for_gms
,
patch_static_state_for_gms
,
...
@@ -66,10 +71,6 @@ class GMSModelLoader:
...
@@ -66,10 +71,6 @@ class GMSModelLoader:
device_config
,
device_config
,
)
->
torch
.
nn
.
Module
:
)
->
torch
.
nn
.
Module
:
"""Load or import model weights."""
"""Load or import model weights."""
from
gpu_memory_service.integrations.sglang.memory_saver
import
(
get_gms_memory_saver_impl
,
)
impl
=
get_gms_memory_saver_impl
()
impl
=
get_gms_memory_saver_impl
()
if
impl
is
None
:
if
impl
is
None
:
raise
RuntimeError
(
raise
RuntimeError
(
...
@@ -77,12 +78,11 @@ class GMSModelLoader:
...
@@ -77,12 +78,11 @@ class GMSModelLoader:
"Ensure torch_memory_saver patch was applied before model loading."
"Ensure torch_memory_saver patch was applied before model loading."
)
)
mode
=
impl
.
get_mode
()
mode
=
impl
.
allocators
[
"weights"
].
granted_lock_type
logger
.
info
(
"[GMS] Loading model in %s mode"
,
mode
.
upper
()
)
logger
.
info
(
"[GMS] Loading model in %s mode"
,
mode
.
name
)
if
mode
==
"read"
:
if
mode
==
GrantedLockType
.
RO
:
return
self
.
_load_import_only
(
model_config
,
device_config
,
impl
)
return
self
.
_load_import_only
(
model_config
,
device_config
,
impl
)
else
:
return
self
.
_load_write_mode
(
model_config
,
device_config
,
impl
)
return
self
.
_load_write_mode
(
model_config
,
device_config
,
impl
)
def
_load_write_mode
(
self
,
model_config
,
device_config
,
impl
)
->
torch
.
nn
.
Module
:
def
_load_write_mode
(
self
,
model_config
,
device_config
,
impl
)
->
torch
.
nn
.
Module
:
...
@@ -99,17 +99,13 @@ class GMSModelLoader:
...
@@ -99,17 +99,13 @@ class GMSModelLoader:
def
_load_import_only
(
self
,
model_config
,
device_config
,
impl
)
->
torch
.
nn
.
Module
:
def
_load_import_only
(
self
,
model_config
,
device_config
,
impl
)
->
torch
.
nn
.
Module
:
"""Import model weights from GMS metadata (READ mode)."""
"""Import model weights from GMS metadata (READ mode)."""
from
gpu_memory_service.client.torch.module
import
materialize_module_from_gms
allocator
=
impl
.
allocators
[
"weights"
]
allocator
=
impl
.
get_allocator
()
if
allocator
is
None
:
raise
RuntimeError
(
"GMS allocator is None in READ mode"
)
device_index
=
torch
.
cuda
.
current_device
()
device_index
=
torch
.
cuda
.
current_device
()
model
=
self
.
_create_meta_model
(
model_config
,
device_config
)
model
=
self
.
_create_meta_model
(
model_config
,
device_config
)
materialize_module_from_gms
(
allocator
,
model
,
device_index
=
device_index
)
materialize_module_from_gms
(
allocator
,
model
,
device_index
=
device_index
)
impl
.
set_
imported_weights_bytes
(
allocator
.
total_bytes
)
impl
.
imported_weights_bytes
=
allocator
.
total_bytes
logger
.
info
(
logger
.
info
(
"[GMS] READ mode: imported %.2f GiB from metadata"
,
"[GMS] READ mode: imported %.2f GiB from metadata"
,
...
...
lib/gpu_memory_service/integrations/sglang/patches.py
View file @
39a6a240
...
@@ -3,7 +3,7 @@
...
@@ -3,7 +3,7 @@
"""SGLang-specific patches for GPU Memory Service integration.
"""SGLang-specific patches for GPU Memory Service integration.
- patch_torch_memory_saver: Routes
to GMS hybrid implementation
- patch_torch_memory_saver: Routes
weights and kv_cache to GMS
- patch_model_runner: Fixes memory accounting with pre-loaded weights
- patch_model_runner: Fixes memory accounting with pre-loaded weights
- patch_static_state_for_gms: No-ops named-buffer export/import (GMS preserves them)
- patch_static_state_for_gms: No-ops named-buffer export/import (GMS preserves them)
"""
"""
...
@@ -15,7 +15,12 @@ import logging
...
@@ -15,7 +15,12 @@ import logging
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
typing
import
Optional
from
typing
import
Optional
import
gpu_memory_service.integrations.sglang
as
gms_sglang
import
torch
import
torch
from
gpu_memory_service.integrations.sglang.memory_saver
import
(
GMSMemorySaverImpl
,
get_gms_memory_saver_impl
,
)
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
@@ -57,25 +62,16 @@ def patch_torch_memory_saver() -> None:
...
@@ -57,25 +62,16 @@ def patch_torch_memory_saver() -> None:
logger
.
info
(
f
"[GMS] TorchMemorySaver initializing with hook_mode=
{
hook_mode
}
"
)
logger
.
info
(
f
"[GMS] TorchMemorySaver initializing with hook_mode=
{
hook_mode
}
"
)
if
hook_mode
is
None
or
hook_mode
==
"gms"
:
if
hook_mode
is
None
or
hook_mode
==
"gms"
:
# Use our GPU Memory Service implementation
# In GMS mode we install only the strict GMS implementation:
from
gpu_memory_service.integrations.sglang.memory_saver
import
(
# weights + kv_cache go through GMS, generic unsupported tags stay
GMSMemorySaverImpl
,
# no-ops/warnings, and cuda_graph remains unsupported.
)
from
torch_memory_saver.entrypoint
import
_TorchMemorySaverImpl
# Get device from torch.cuda.current_device() (already set by SGLang)
# Get device from torch.cuda.current_device() (already set by SGLang)
device_index
=
torch
.
cuda
.
current_device
()
device_index
=
torch
.
cuda
.
current_device
()
# Create underlying torch impl for non-GMS tags.
torch_impl
=
_TorchMemorySaverImpl
(
hook_mode
=
"torch"
)
# Read lock mode set by setup_gms() (defaults to RW_OR_RO)
# Read lock mode set by setup_gms() (defaults to RW_OR_RO)
from
gpu_memory_service.integrations.sglang
import
_gms_lock_mode
gms_impl
=
GMSMemorySaverImpl
(
gms_impl
=
GMSMemorySaverImpl
(
torch_impl
=
torch_impl
,
device_index
=
device_index
,
device_index
=
device_index
,
mode
=
_gms_lock_mode
,
mode
=
gms_sglang
.
_gms_lock_mode
,
)
)
# Set _impl directly (accessible via gms_impl property)
# Set _impl directly (accessible via gms_impl property)
...
@@ -83,7 +79,7 @@ def patch_torch_memory_saver() -> None:
...
@@ -83,7 +79,7 @@ def patch_torch_memory_saver() -> None:
logger
.
info
(
logger
.
info
(
"[GMS] Using GMS mode (device=%d, mode=%s)"
,
"[GMS] Using GMS mode (device=%d, mode=%s)"
,
device_index
,
device_index
,
gms_impl
.
get_mode
()
,
gms_impl
.
allocators
[
"weights"
].
granted_lock_type
.
name
,
)
)
del
self
.
_impl_ctor_kwargs
del
self
.
_impl_ctor_kwargs
else
:
else
:
...
@@ -111,8 +107,6 @@ def patch_torch_memory_saver() -> None:
...
@@ -111,8 +107,6 @@ def patch_torch_memory_saver() -> None:
torch_memory_saver
.
configure_subprocess
=
patched_configure_subprocess
torch_memory_saver
.
configure_subprocess
=
patched_configure_subprocess
# Add property to access GMS impl directly from the singleton
# Add property to access GMS impl directly from the singleton
from
gpu_memory_service.integrations.sglang.memory_saver
import
GMSMemorySaverImpl
@
property
@
property
def
gms_impl
(
self
)
->
Optional
[
GMSMemorySaverImpl
]:
def
gms_impl
(
self
)
->
Optional
[
GMSMemorySaverImpl
]:
"""Get the GMS impl if installed, None otherwise."""
"""Get the GMS impl if installed, None otherwise."""
...
@@ -185,12 +179,8 @@ def patch_model_runner() -> None:
...
@@ -185,12 +179,8 @@ def patch_model_runner() -> None:
weights are already resident. Newer SGLang versions changed this API, so
weights are already resident. Newer SGLang versions changed this API, so
only rewrite the old total_gpu_memory parameter shape.
only rewrite the old total_gpu_memory parameter shape.
"""
"""
from
gpu_memory_service.integrations.sglang.memory_saver
import
(
get_gms_memory_saver_impl
,
)
impl
=
get_gms_memory_saver_impl
()
impl
=
get_gms_memory_saver_impl
()
if
impl
is
not
None
and
impl
.
get_
imported_weights_bytes
()
>
0
:
if
impl
is
not
None
and
impl
.
imported_weights_bytes
>
0
:
total_memory_gib
=
torch
.
cuda
.
get_device_properties
(
total_memory_gib
=
torch
.
cuda
.
get_device_properties
(
torch
.
cuda
.
current_device
()
torch
.
cuda
.
current_device
()
).
total_memory
/
(
1
<<
30
)
).
total_memory
/
(
1
<<
30
)
...
...
lib/gpu_memory_service/integrations/vllm/model_loader.py
View file @
39a6a240
...
@@ -14,10 +14,12 @@ import logging
...
@@ -14,10 +14,12 @@ import logging
from
typing
import
TYPE_CHECKING
from
typing
import
TYPE_CHECKING
import
torch
import
torch
from
gpu_memory_service
import
get_or_create_gms_client_memory_manager
from
gpu_memory_service.client.torch.allocator
import
(
from
gpu_memory_service.client.torch.allocator
import
gms_use_mem_pool
get_or_create_gms_client_memory_manager
,
gms_use_mem_pool
,
)
from
gpu_memory_service.client.torch.module
import
materialize_module_from_gms
from
gpu_memory_service.client.torch.module
import
materialize_module_from_gms
from
gpu_memory_service.common.
type
s
import
GrantedLockType
from
gpu_memory_service.common.
lock
s
import
GrantedLockType
from
gpu_memory_service.common.utils
import
get_socket_path
from
gpu_memory_service.common.utils
import
get_socket_path
from
gpu_memory_service.integrations.common.utils
import
(
from
gpu_memory_service.integrations.common.utils
import
(
finalize_gms_write
,
finalize_gms_write
,
...
...
lib/gpu_memory_service/integrations/vllm/patches.py
View file @
39a6a240
...
@@ -14,8 +14,8 @@ from __future__ import annotations
...
@@ -14,8 +14,8 @@ from __future__ import annotations
import
logging
import
logging
from
gpu_memory_service
import
get_gms_client_memory_manager
from
gpu_memory_service
.client.torch.allocator
import
get_gms_client_memory_manager
from
gpu_memory_service.common.
type
s
import
GrantedLockType
from
gpu_memory_service.common.
lock
s
import
GrantedLockType
from
gpu_memory_service.integrations.vllm.utils
import
is_shadow_mode
from
gpu_memory_service.integrations.vllm.utils
import
is_shadow_mode
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
...
lib/gpu_memory_service/integrations/vllm/worker.py
View file @
39a6a240
...
@@ -18,16 +18,16 @@ from contextlib import nullcontext
...
@@ -18,16 +18,16 @@ from contextlib import nullcontext
from
typing
import
List
,
Optional
from
typing
import
List
,
Optional
import
torch
import
torch
from
gpu_memory_service
import
(
from
gpu_memory_service.client.memory_manager
import
StaleMemoryLayoutError
from
gpu_memory_service.client.torch.allocator
import
(
get_gms_client_memory_manager
,
get_gms_client_memory_manager
,
get_or_create_gms_client_memory_manager
,
get_or_create_gms_client_memory_manager
,
gms_use_mem_pool
,
)
)
from
gpu_memory_service.client.memory_manager
import
StaleMemoryLayoutError
from
gpu_memory_service.common.locks
import
RequestedLockType
from
gpu_memory_service.client.torch.allocator
import
gms_use_mem_pool
from
gpu_memory_service.common.types
import
RequestedLockType
from
gpu_memory_service.common.utils
import
get_socket_path
from
gpu_memory_service.common.utils
import
get_socket_path
from
gpu_memory_service.integrations.common
import
patch_empty_cache
from
gpu_memory_service.integrations.common
import
patch_empty_cache
from
gpu_memory_service.integrations.common.utils
import
get_gms_lock_mode
from
gpu_memory_service.integrations.common.utils
import
GMS_TAGS
,
get_gms_lock_mode
from
gpu_memory_service.integrations.vllm.model_loader
import
register_gms_loader
from
gpu_memory_service.integrations.vllm.model_loader
import
register_gms_loader
from
gpu_memory_service.integrations.vllm.patches
import
(
from
gpu_memory_service.integrations.vllm.patches
import
(
apply_shadow_mode_patches
,
apply_shadow_mode_patches
,
...
@@ -264,7 +264,7 @@ class GMSWorker(Worker):
...
@@ -264,7 +264,7 @@ class GMSWorker(Worker):
self
.
model_runner
.
exit_shadow_init
()
self
.
model_runner
.
exit_shadow_init
()
if
tags
is
None
:
if
tags
is
None
:
tags
=
[
"weights"
,
"kv_cache"
]
tags
=
list
(
GMS_TAGS
)
if
"weights"
in
tags
:
if
"weights"
in
tags
:
weights_manager
=
get_gms_client_memory_manager
(
"weights"
)
weights_manager
=
get_gms_client_memory_manager
(
"weights"
)
...
...
lib/gpu_memory_service/server/__init__.py
View file @
39a6a240
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
"""GPU Memory Service server components."""
from
gpu_memory_service.common.types
import
(
GrantedLockType
,
RequestedLockType
,
ServerState
,
StateSnapshot
,
)
from
gpu_memory_service.server.allocations
import
(
AllocationInfo
,
AllocationNotFoundError
,
GMSAllocationManager
,
)
from
gpu_memory_service.server.gms
import
GMS
,
MetadataEntry
from
gpu_memory_service.server.rpc
import
GMSRPCServer
from
gpu_memory_service.server.session
import
(
Connection
,
GMSSessionManager
,
InvalidTransition
,
OperationNotAllowed
,
)
__all__
=
[
"GMSRPCServer"
,
"GMS"
,
"GMSSessionManager"
,
"GMSAllocationManager"
,
"AllocationInfo"
,
"AllocationNotFoundError"
,
"MetadataEntry"
,
"Connection"
,
"GrantedLockType"
,
"RequestedLockType"
,
"ServerState"
,
"StateSnapshot"
,
"InvalidTransition"
,
"OperationNotAllowed"
,
]
lib/gpu_memory_service/server/fsm.py
0 → 100644
View file @
39a6a240
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from
__future__
import
annotations
import
asyncio
from
dataclasses
import
dataclass
,
field
from
enum
import
Enum
,
auto
from
typing
import
Optional
,
Set
from
gpu_memory_service.common.locks
import
GrantedLockType
class ServerState(str, Enum):
    """Lock state reported by ``GMSFSM.state``.

    Members are also ``str`` instances, so each compares equal to its
    own name.
    """

    # No RW connection, no RO connections, nothing committed.
    EMPTY = "EMPTY"
    # A RW connection currently holds the lock.
    RW = "RW"
    # Committed, with no connection attached.
    COMMITTED = "COMMITTED"
    # At least one RO connection is attached (and no RW connection).
    RO = "RO"
class StateEvent(Enum):
    """Events accepted by ``GMSFSM.transition``."""

    # The connection becomes the RW holder; the committed flag is cleared.
    RW_CONNECT = auto()
    # The committed flag is set and the RW holder is released.
    RW_COMMIT = auto()
    # The RW holder is released without committing.
    RW_ABORT = auto()
    # The connection is added to the set of RO connections.
    RO_CONNECT = auto()
    # The connection is removed from the set of RO connections.
    RO_DISCONNECT = auto()
@dataclass(eq=False)
class Connection:
    """A client connection: its stream pair, granted mode and session id.

    ``eq=False`` stops the dataclass from generating ``__eq__``, so two
    instances are equal only when they are the same object. Hashing is
    by ``session_id``.
    """

    reader: asyncio.StreamReader
    writer: asyncio.StreamWriter
    mode: GrantedLockType
    session_id: str
    # Fresh, empty buffer per instance.
    recv_buffer: bytearray = field(default_factory=bytearray)

    def __hash__(self) -> int:
        """Return the hash of ``session_id``."""
        return hash(self.session_id)

    async def close(self) -> None:
        """Close the writer and wait for it to finish closing.

        Any ``Exception`` raised while waiting is swallowed, so teardown
        is best-effort.
        """
        stream = self.writer
        stream.close()
        try:
            await stream.wait_closed()
        except Exception:
            # Deliberately ignored: the connection is going away anyway.
            pass
class InvalidTransition(Exception):
    """Raised by ``GMSFSM.transition`` when no transition rule matches."""
@dataclass(frozen=True)
class Transition:
    """One immutable row of the ``TRANSITIONS`` table."""

    # States in which this row may be selected.
    from_states: frozenset[ServerState]
    # Event that selects this row.
    event: StateEvent
    # Declared target state. NOTE(review): GMSFSM.transition does not read
    # this field; it reports the state derived from its own bookkeeping.
    to_state: Optional[ServerState]
    # Name of an extra predicate checked by GMSFSM._check_condition;
    # None means the row has no extra condition.
    condition: Optional[str] = None
# Transition table scanned in order by GMSFSM.transition. The first row whose
# from_states contains the current state, whose event matches, and whose
# condition (if any) holds is selected; InvalidTransition is raised when no
# row qualifies.
TRANSITIONS: list[Transition] = [
    # A writer may connect when nothing is attached (EMPTY or COMMITTED).
    Transition(
        from_states=frozenset({ServerState.EMPTY, ServerState.COMMITTED}),
        event=StateEvent.RW_CONNECT,
        to_state=ServerState.RW,
    ),
    # The writer commits.
    Transition(
        from_states=frozenset({ServerState.RW}),
        event=StateEvent.RW_COMMIT,
        to_state=ServerState.COMMITTED,
    ),
    # The writer aborts.
    Transition(
        from_states=frozenset({ServerState.RW}),
        event=StateEvent.RW_ABORT,
        to_state=ServerState.EMPTY,
    ),
    # A reader connects, either the first (COMMITTED) or an additional one (RO).
    Transition(
        from_states=frozenset({ServerState.COMMITTED, ServerState.RO}),
        event=StateEvent.RO_CONNECT,
        to_state=ServerState.RO,
    ),
    # A reader disconnects while other readers remain.
    Transition(
        from_states=frozenset({ServerState.RO}),
        event=StateEvent.RO_DISCONNECT,
        to_state=ServerState.RO,
        condition="has_remaining_readers",
    ),
    # The last reader disconnects.
    Transition(
        from_states=frozenset({ServerState.RO}),
        event=StateEvent.RO_DISCONNECT,
        to_state=ServerState.COMMITTED,
        condition="is_last_reader",
    ),
]
class GMSFSM:
    """State machine tracking one RW holder, a set of RO holders and a committed flag.

    The current ``ServerState`` is never stored. It is recomputed from
    those three pieces of bookkeeping each time ``state`` is read.
    """

    def __init__(self):
        self._rw_conn: Optional[Connection] = None
        self._ro_conns: Set[Connection] = set()
        self._committed = False

    @property
    def state(self) -> ServerState:
        """Derive the state; a RW holder wins over readers, readers over the flag."""
        if self._rw_conn is not None:
            return ServerState.RW
        if self._ro_conns:
            return ServerState.RO
        return ServerState.COMMITTED if self._committed else ServerState.EMPTY

    @property
    def rw_conn(self) -> Optional[Connection]:
        """The connection holding the RW lock, or None."""
        return self._rw_conn

    @property
    def ro_conns(self) -> Set[Connection]:
        """The live (not copied) set of RO connections."""
        return self._ro_conns

    @property
    def ro_count(self) -> int:
        """Number of RO connections currently tracked."""
        return len(self._ro_conns)

    @property
    def committed(self) -> bool:
        """Whether the committed flag is set."""
        return self._committed

    def _check_condition(self, condition: Optional[str], conn: Connection) -> bool:
        """Evaluate a named transition condition for ``conn``.

        Returns True when ``condition`` is None. Raises ValueError for a
        name that is not one of the two known conditions.
        """
        if condition is None:
            return True
        if condition == "has_remaining_readers":
            # True unless conn is the single tracked reader.
            return not (len(self._ro_conns) <= 1 and conn in self._ro_conns)
        if condition == "is_last_reader":
            return len(self._ro_conns) == 1 and conn in self._ro_conns
        raise ValueError(f"Unknown condition: {condition}")

    def transition(self, event: StateEvent, conn: Connection) -> ServerState:
        """Apply ``event`` for ``conn`` and return the resulting state.

        Raises InvalidTransition when no row of TRANSITIONS matches the
        current state, the event and the row's condition.
        """
        current = self.state
        # First matching row wins; the condition is only evaluated for rows
        # whose state and event already match.
        matched = next(
            (
                rule
                for rule in TRANSITIONS
                if current in rule.from_states
                and rule.event == event
                and self._check_condition(rule.condition, conn)
            ),
            None,
        )
        if matched is None:
            raise InvalidTransition(
                f"No transition for {event.name} from state {current.name} "
                f"(session={conn.session_id})"
            )

        # Update the bookkeeping that ``state`` is derived from.
        if event == StateEvent.RW_CONNECT:
            self._rw_conn = conn
            self._committed = False
        elif event == StateEvent.RW_COMMIT:
            self._committed = True
            self._rw_conn = None
        elif event == StateEvent.RW_ABORT:
            self._rw_conn = None
        elif event == StateEvent.RO_CONNECT:
            self._ro_conns.add(conn)
        elif event == StateEvent.RO_DISCONNECT:
            self._ro_conns.discard(conn)

        return self.state

    def can_acquire_rw(self) -> bool:
        """True when there is no RW holder and no RO connection."""
        return self._rw_conn is None and not self._ro_conns

    def can_acquire_ro(self, waiting_writers: int) -> bool:
        """True when committed, no RW holder exists and no writers are waiting."""
        return self._committed and self._rw_conn is None and waiting_writers == 0
lib/gpu_memory_service/server/gms.py
View file @
39a6a240
...
@@ -11,6 +11,7 @@ from collections import deque
...
@@ -11,6 +11,7 @@ from collections import deque
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Callable
,
Optional
from
typing
import
Callable
,
Optional
from
gpu_memory_service.common.locks
import
GrantedLockType
,
RequestedLockType
from
gpu_memory_service.common.protocol.messages
import
(
from
gpu_memory_service.common.protocol.messages
import
(
AllocateRequest
,
AllocateRequest
,
AllocateResponse
,
AllocateResponse
,
...
@@ -42,15 +43,10 @@ from gpu_memory_service.common.protocol.messages import (
...
@@ -42,15 +43,10 @@ from gpu_memory_service.common.protocol.messages import (
MetadataPutRequest
,
MetadataPutRequest
,
MetadataPutResponse
,
MetadataPutResponse
,
)
)
from
gpu_memory_service.common.types
import
(
GrantedLockType
,
RequestedLockType
,
ServerState
,
StateEvent
,
)
from
.allocations
import
AllocationInfo
,
GMSAllocationManager
from
.allocations
import
AllocationInfo
,
GMSAllocationManager
from
.session
import
Connection
,
GMSSessionManager
from
.fsm
import
Connection
,
ServerState
,
StateEvent
from
.session
import
GMSSessionManager
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
...
lib/gpu_memory_service/server/rpc.py
View file @
39a6a240
...
@@ -23,8 +23,9 @@ from gpu_memory_service.common.protocol.wire import recv_message, send_message
...
@@ -23,8 +23,9 @@ from gpu_memory_service.common.protocol.wire import recv_message, send_message
from
gpu_memory_service.common.utils
import
fail
from
gpu_memory_service.common.utils
import
fail
from
.allocations
import
AllocationNotFoundError
from
.allocations
import
AllocationNotFoundError
from
.fsm
import
Connection
,
InvalidTransition
from
.gms
import
GMS
from
.gms
import
GMS
from
.session
import
Connection
,
InvalidTransition
,
OperationNotAllowed
from
.session
import
OperationNotAllowed
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
...
lib/gpu_memory_service/server/session.py
View file @
39a6a240
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
"""Server-side
connection, FSM, and waiter state
."""
"""Server-side
lock acquisition and cleanup
."""
from
__future__
import
annotations
from
__future__
import
annotations
import
asyncio
import
asyncio
from
dataclasses
import
dataclass
,
field
from
dataclasses
import
dataclass
from
typing
import
Optional
,
Set
from
typing
import
Optional
from
gpu_memory_service.common.types
import
(
from
gpu_memory_service.common.locks
import
GrantedLockType
,
RequestedLockType
RO_ALLOWED
,
from
gpu_memory_service.common.protocol.messages
import
(
RW_ALLOWED
,
AllocateRequest
,
RW_REQUIRED
,
CommitRequest
,
GrantedLockType
,
ExportAllocationRequest
,
RequestedLockType
,
FreeAllocationRequest
,
ServerState
,
GetAllocationRequest
,
StateEvent
,
GetAllocationStateRequest
,
GetLockStateRequest
,
GetStateHashRequest
,
ListAllocationsRequest
,
MetadataDeleteRequest
,
MetadataGetRequest
,
MetadataListRequest
,
MetadataPutRequest
,
)
)
from
.fsm
import
GMSFSM
,
Connection
,
ServerState
,
StateEvent
@
dataclass
(
eq
=
False
)
class
Connection
:
reader
:
asyncio
.
StreamReader
writer
:
asyncio
.
StreamWriter
mode
:
GrantedLockType
session_id
:
str
recv_buffer
:
bytearray
=
field
(
default_factory
=
bytearray
)
def
__hash__
(
self
)
->
int
:
return
hash
(
self
.
session_id
)
async
def
close
(
self
)
->
None
:
self
.
writer
.
close
()
try
:
await
self
.
writer
.
wait_closed
()
except
Exception
:
pass
class
InvalidTransition
(
Exception
):
"""Raised when an invalid state transition is attempted."""
class
OperationNotAllowed
(
Exception
):
class
OperationNotAllowed
(
Exception
):
"""Raised when an operation is not allowed in the current state/mode."""
pass
@
dataclass
(
frozen
=
True
)
class
Transition
:
from_states
:
frozenset
[
ServerState
]
event
:
StateEvent
to_state
:
Optional
[
ServerState
]
condition
:
Optional
[
str
]
=
None
TRANSITIONS
:
list
[
Transition
]
=
[
Transition
(
from_states
=
frozenset
({
ServerState
.
EMPTY
,
ServerState
.
COMMITTED
}),
event
=
StateEvent
.
RW_CONNECT
,
to_state
=
ServerState
.
RW
,
),
Transition
(
from_states
=
frozenset
({
ServerState
.
RW
}),
event
=
StateEvent
.
RW_COMMIT
,
to_state
=
ServerState
.
COMMITTED
,
),
Transition
(
from_states
=
frozenset
({
ServerState
.
RW
}),
event
=
StateEvent
.
RW_ABORT
,
to_state
=
ServerState
.
EMPTY
,
),
Transition
(
from_states
=
frozenset
({
ServerState
.
COMMITTED
,
ServerState
.
RO
}),
event
=
StateEvent
.
RO_CONNECT
,
to_state
=
ServerState
.
RO
,
),
Transition
(
from_states
=
frozenset
({
ServerState
.
RO
}),
event
=
StateEvent
.
RO_DISCONNECT
,
to_state
=
ServerState
.
RO
,
condition
=
"has_remaining_readers"
,
),
Transition
(
from_states
=
frozenset
({
ServerState
.
RO
}),
event
=
StateEvent
.
RO_DISCONNECT
,
to_state
=
ServerState
.
COMMITTED
,
condition
=
"is_last_reader"
,
),
]
class
GMSLocalFSM
:
"""Explicit connection/lock state machine."""
def
__init__
(
self
):
self
.
_rw_conn
:
Optional
[
Connection
]
=
None
self
.
_ro_conns
:
Set
[
Connection
]
=
set
()
self
.
_committed
=
False
@
property
def
state
(
self
)
->
ServerState
:
if
self
.
_rw_conn
is
not
None
:
return
ServerState
.
RW
if
self
.
_ro_conns
:
return
ServerState
.
RO
if
self
.
_committed
:
return
ServerState
.
COMMITTED
return
ServerState
.
EMPTY
@
property
def
rw_conn
(
self
)
->
Optional
[
Connection
]:
return
self
.
_rw_conn
@
property
def
ro_conns
(
self
)
->
Set
[
Connection
]:
return
self
.
_ro_conns
@
property
def
ro_count
(
self
)
->
int
:
return
len
(
self
.
_ro_conns
)
@
property
def
committed
(
self
)
->
bool
:
return
self
.
_committed
def
_has_remaining_readers
(
self
,
conn
:
Connection
)
->
bool
:
return
len
(
self
.
_ro_conns
)
>
1
or
conn
not
in
self
.
_ro_conns
def
_is_last_reader
(
self
,
conn
:
Connection
)
->
bool
:
return
len
(
self
.
_ro_conns
)
==
1
and
conn
in
self
.
_ro_conns
def
_check_condition
(
self
,
condition
:
Optional
[
str
],
conn
:
Connection
)
->
bool
:
if
condition
is
None
:
return
True
if
condition
==
"has_remaining_readers"
:
return
self
.
_has_remaining_readers
(
conn
)
if
condition
==
"is_last_reader"
:
return
self
.
_is_last_reader
(
conn
)
raise
ValueError
(
f
"Unknown condition:
{
condition
}
"
)
def
_find_transition
(
self
,
from_state
:
ServerState
,
event
:
StateEvent
,
conn
:
Connection
,
)
->
Optional
[
Transition
]:
for
transition
in
TRANSITIONS
:
if
from_state
not
in
transition
.
from_states
:
continue
if
transition
.
event
!=
event
:
continue
if
not
self
.
_check_condition
(
transition
.
condition
,
conn
):
continue
return
transition
return
None
def
_apply_event
(
self
,
event
:
StateEvent
,
conn
:
Connection
)
->
None
:
if
event
==
StateEvent
.
RW_CONNECT
:
self
.
_rw_conn
=
conn
self
.
_committed
=
False
elif
event
==
StateEvent
.
RW_COMMIT
:
self
.
_committed
=
True
self
.
_rw_conn
=
None
elif
event
==
StateEvent
.
RW_ABORT
:
self
.
_rw_conn
=
None
elif
event
==
StateEvent
.
RO_CONNECT
:
self
.
_ro_conns
.
add
(
conn
)
elif
event
==
StateEvent
.
RO_DISCONNECT
:
self
.
_ro_conns
.
discard
(
conn
)
def
transition
(
self
,
event
:
StateEvent
,
conn
:
Connection
)
->
ServerState
:
transition
=
self
.
_find_transition
(
self
.
state
,
event
,
conn
)
if
transition
is
None
:
raise
InvalidTransition
(
f
"No transition for
{
event
.
name
}
from state
{
self
.
state
.
name
}
"
f
"(session=
{
conn
.
session_id
}
)"
)
self
.
_apply_event
(
event
,
conn
)
return
self
.
state
def
check_operation
(
self
,
msg_type
:
type
,
conn
:
Connection
)
->
None
:
RW_REQUIRED
:
frozenset
[
type
]
=
frozenset
(
if
conn
.
mode
==
GrantedLockType
.
RW
and
msg_type
not
in
RW_ALLOWED
:
{
raise
OperationNotAllowed
(
AllocateRequest
,
f
"
{
msg_type
.
__name__
}
not allowed for RW session in state
{
self
.
state
.
name
}
"
FreeAllocationRequest
,
)
MetadataPutRequest
,
if
conn
.
mode
==
GrantedLockType
.
RO
and
msg_type
not
in
RO_ALLOWED
:
MetadataDeleteRequest
,
raise
OperationNotAllowed
(
CommitRequest
,
f
"
{
msg_type
.
__name__
}
not allowed for RO session in state
{
self
.
state
.
name
}
"
}
)
)
if
msg_type
in
RW_REQUIRED
and
conn
.
mode
!=
GrantedLockType
.
RW
:
raise
OperationNotAllowed
(
f
"
{
msg_type
.
__name__
}
requires RW session, got
{
conn
.
mode
.
value
}
"
)
def
can_acquire_rw
(
self
)
->
bool
:
RO_ALLOWED
:
frozenset
[
type
]
=
frozenset
(
return
self
.
_rw_conn
is
None
and
not
self
.
_ro_conns
{
ExportAllocationRequest
,
GetAllocationRequest
,
ListAllocationsRequest
,
MetadataGetRequest
,
MetadataListRequest
,
GetLockStateRequest
,
GetAllocationStateRequest
,
GetStateHashRequest
,
}
)
def
can_acquire_ro
(
self
,
waiting_writers
:
int
)
->
bool
:
RW_ALLOWED
:
frozenset
[
type
]
=
RW_REQUIRED
|
RO_ALLOWED
return
self
.
_committed
and
self
.
_rw_conn
is
None
and
waiting_writers
==
0
@
dataclass
(
frozen
=
True
)
@
dataclass
(
frozen
=
True
)
...
@@ -215,7 +73,7 @@ class GMSSessionManager:
...
@@ -215,7 +73,7 @@ class GMSSessionManager:
"""Owns lock transitions, waiter coordination, and cleanup."""
"""Owns lock transitions, waiter coordination, and cleanup."""
def
__init__
(
self
):
def
__init__
(
self
):
self
.
_locking
=
GMS
Local
FSM
()
self
.
_locking
=
GMSFSM
()
self
.
_waiting_writers
=
0
self
.
_waiting_writers
=
0
self
.
_reserved_rw_session_id
:
Optional
[
str
]
=
None
self
.
_reserved_rw_session_id
:
Optional
[
str
]
=
None
self
.
_condition
=
asyncio
.
Condition
()
self
.
_condition
=
asyncio
.
Condition
()
...
@@ -336,7 +194,18 @@ class GMSSessionManager:
...
@@ -336,7 +194,18 @@ class GMSSessionManager:
self
.
_locking
.
transition
(
StateEvent
.
RW_COMMIT
,
conn
)
self
.
_locking
.
transition
(
StateEvent
.
RW_COMMIT
,
conn
)
def
check_operation
(
self
,
msg_type
:
type
,
conn
:
Connection
)
->
None
:
def
check_operation
(
self
,
msg_type
:
type
,
conn
:
Connection
)
->
None
:
self
.
_locking
.
check_operation
(
msg_type
,
conn
)
if
conn
.
mode
==
GrantedLockType
.
RW
and
msg_type
not
in
RW_ALLOWED
:
raise
OperationNotAllowed
(
f
"
{
msg_type
.
__name__
}
not allowed for RW session in state
{
self
.
state
.
name
}
"
)
if
conn
.
mode
==
GrantedLockType
.
RO
and
msg_type
not
in
RO_ALLOWED
:
raise
OperationNotAllowed
(
f
"
{
msg_type
.
__name__
}
not allowed for RO session in state
{
self
.
state
.
name
}
"
)
if
msg_type
in
RW_REQUIRED
and
conn
.
mode
!=
GrantedLockType
.
RW
:
raise
OperationNotAllowed
(
f
"
{
msg_type
.
__name__
}
requires RW session, got
{
conn
.
mode
.
value
}
"
)
def
begin_cleanup
(
self
,
conn
:
Optional
[
Connection
])
->
StateEvent
|
None
:
def
begin_cleanup
(
self
,
conn
:
Optional
[
Connection
])
->
StateEvent
|
None
:
if
conn
is
None
:
if
conn
is
None
:
...
...
pyproject.toml
View file @
39a6a240
...
@@ -246,6 +246,7 @@ markers = [
...
@@ -246,6 +246,7 @@ markers = [
"stress: marks tests as stress tests"
,
"stress: marks tests as stress tests"
,
"performance: marks tests as performance tests"
,
"performance: marks tests as performance tests"
,
"benchmark: marks tests as benchmark tests"
,
"benchmark: marks tests as benchmark tests"
,
"none: marks tests that do not require a framework-specific runtime"
,
"vllm: marks tests as requiring vllm"
,
"vllm: marks tests as requiring vllm"
,
"trtllm: marks tests as requiring trtllm"
,
"trtllm: marks tests as requiring trtllm"
,
"sglang: marks tests as requiring sglang"
,
"sglang: marks tests as requiring sglang"
,
...
...
tests/gms/common/__init__.py
View file @
39a6a240
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import
pytest
pytest
.
importorskip
(
"gpu_memory_service"
,
reason
=
"gpu_memory_service is required"
)
tests/gms/common/test_failover_lock.py
View file @
39a6a240
...
@@ -3,8 +3,9 @@
...
@@ -3,8 +3,9 @@
"""Tests for the flock-based failover lock.
"""Tests for the flock-based failover lock.
No GPU required — these are pure Python/OS tests exercising flock
These are pure Python/OS tests exercising flock semantics across asyncio
semantics across asyncio tasks and child processes.
tasks and child processes, so they stay on the generic cpu-style pre-merge
lane instead of the dedicated GPU job.
"""
"""
import
asyncio
import
asyncio
...
@@ -19,6 +20,7 @@ from gpu_memory_service.failover_lock.flock import FlockFailoverLock
...
@@ -19,6 +20,7 @@ from gpu_memory_service.failover_lock.flock import FlockFailoverLock
pytestmark
=
[
pytestmark
=
[
pytest
.
mark
.
pre_merge
,
pytest
.
mark
.
pre_merge
,
pytest
.
mark
.
unit
,
pytest
.
mark
.
unit
,
pytest
.
mark
.
none
,
pytest
.
mark
.
gpu_0
,
pytest
.
mark
.
gpu_0
,
]
]
...
...
tests/gms/common/test_gms_client_memory_manager.py
View file @
39a6a240
...
@@ -9,12 +9,13 @@ from gpu_memory_service.client.memory_manager import (
...
@@ -9,12 +9,13 @@ from gpu_memory_service.client.memory_manager import (
GMSClientMemoryManager
,
GMSClientMemoryManager
,
LocalMapping
,
LocalMapping
,
)
)
from
gpu_memory_service.common.
type
s
import
GrantedLockType
,
RequestedLockType
from
gpu_memory_service.common.
lock
s
import
GrantedLockType
,
RequestedLockType
pytestmark
=
[
pytestmark
=
[
pytest
.
mark
.
pre_merge
,
pytest
.
mark
.
pre_merge
,
pytest
.
mark
.
unit
,
pytest
.
mark
.
unit
,
pytest
.
mark
.
gpu_0
,
pytest
.
mark
.
none
,
pytest
.
mark
.
gpu_1
,
]
]
...
...
tests/gms/common/test_gms_client_session.py
View file @
39a6a240
...
@@ -6,15 +6,16 @@ from __future__ import annotations
...
@@ -6,15 +6,16 @@ from __future__ import annotations
import
pytest
import
pytest
from
gpu_memory_service.client.rpc
import
_GMSRPCTransport
from
gpu_memory_service.client.rpc
import
_GMSRPCTransport
from
gpu_memory_service.client.session
import
_GMSClientSession
from
gpu_memory_service.client.session
import
_GMSClientSession
from
gpu_memory_service.common.locks
import
GrantedLockType
,
RequestedLockType
from
gpu_memory_service.common.protocol.messages
import
(
from
gpu_memory_service.common.protocol.messages
import
(
CommitResponse
,
CommitResponse
,
HandshakeResponse
,
HandshakeResponse
,
)
)
from
gpu_memory_service.common.types
import
GrantedLockType
,
RequestedLockType
pytestmark
=
[
pytestmark
=
[
pytest
.
mark
.
pre_merge
,
pytest
.
mark
.
pre_merge
,
pytest
.
mark
.
unit
,
pytest
.
mark
.
unit
,
pytest
.
mark
.
none
,
pytest
.
mark
.
gpu_0
,
pytest
.
mark
.
gpu_0
,
]
]
...
...
tests/gms/common/test_gms_client_transport.py
View file @
39a6a240
...
@@ -15,6 +15,7 @@ from gpu_memory_service.common.protocol.messages import (
...
@@ -15,6 +15,7 @@ from gpu_memory_service.common.protocol.messages import (
pytestmark
=
[
pytestmark
=
[
pytest
.
mark
.
pre_merge
,
pytest
.
mark
.
pre_merge
,
pytest
.
mark
.
unit
,
pytest
.
mark
.
unit
,
pytest
.
mark
.
none
,
pytest
.
mark
.
gpu_0
,
pytest
.
mark
.
gpu_0
,
]
]
...
...
tests/gms/common/test_gms_harness.py
View file @
39a6a240
...
@@ -10,7 +10,12 @@ import pytest
...
@@ -10,7 +10,12 @@ import pytest
from
tests.gms.harness.gms
import
GMSServerProcess
from
tests.gms.harness.gms
import
GMSServerProcess
from
tests.utils.managed_process
import
ManagedProcess
from
tests.utils.managed_process
import
ManagedProcess
pytestmark
=
[
pytest
.
mark
.
pre_merge
,
pytest
.
mark
.
unit
,
pytest
.
mark
.
gpu_0
]
pytestmark
=
[
pytest
.
mark
.
pre_merge
,
pytest
.
mark
.
unit
,
pytest
.
mark
.
none
,
pytest
.
mark
.
gpu_1
,
]
@
pytest
.
fixture
@
pytest
.
fixture
...
...
tests/gms/common/test_gms_runtime_flows.py
View file @
39a6a240
...
@@ -24,19 +24,16 @@ from gpu_memory_service.client.memory_manager import (
...
@@ -24,19 +24,16 @@ from gpu_memory_service.client.memory_manager import (
from
gpu_memory_service.client.rpc
import
_GMSRPCTransport
from
gpu_memory_service.client.rpc
import
_GMSRPCTransport
from
gpu_memory_service.client.session
import
_GMSClientSession
from
gpu_memory_service.client.session
import
_GMSClientSession
from
gpu_memory_service.common
import
cuda_utils
from
gpu_memory_service.common
import
cuda_utils
from
gpu_memory_service.common.locks
import
GrantedLockType
,
RequestedLockType
from
gpu_memory_service.common.protocol.messages
import
(
from
gpu_memory_service.common.protocol.messages
import
(
GetEventHistoryRequest
,
GetEventHistoryRequest
,
GetEventHistoryResponse
,
GetEventHistoryResponse
,
GetRuntimeStateRequest
,
GetRuntimeStateRequest
,
GetRuntimeStateResponse
,
GetRuntimeStateResponse
,
)
)
from
gpu_memory_service.common.types
import
(
GrantedLockType
,
RequestedLockType
,
ServerState
,
)
from
gpu_memory_service.server
import
allocations
as
server_allocations
from
gpu_memory_service.server
import
allocations
as
server_allocations
from
gpu_memory_service.server.allocations
import
GMSAllocationManager
from
gpu_memory_service.server.allocations
import
GMSAllocationManager
from
gpu_memory_service.server.fsm
import
ServerState
from
gpu_memory_service.server.rpc
import
GMSRPCServer
from
gpu_memory_service.server.rpc
import
GMSRPCServer
from
tests.gms.harness.gms
import
ServerThread
from
tests.gms.harness.gms
import
ServerThread
...
@@ -44,7 +41,8 @@ from tests.gms.harness.gms import ServerThread
...
@@ -44,7 +41,8 @@ from tests.gms.harness.gms import ServerThread
pytestmark
=
[
pytestmark
=
[
pytest
.
mark
.
pre_merge
,
pytest
.
mark
.
pre_merge
,
pytest
.
mark
.
unit
,
pytest
.
mark
.
unit
,
pytest
.
mark
.
gpu_0
,
pytest
.
mark
.
none
,
pytest
.
mark
.
gpu_1
,
]
]
...
...
tests/gms/common/test_gms_server_transport_failures.py
View file @
39a6a240
...
@@ -15,6 +15,7 @@ from dataclasses import dataclass
...
@@ -15,6 +15,7 @@ from dataclasses import dataclass
import
pytest
import
pytest
from
gpu_memory_service.common
import
cuda_utils
from
gpu_memory_service.common
import
cuda_utils
from
gpu_memory_service.common.locks
import
GrantedLockType
,
RequestedLockType
from
gpu_memory_service.common.protocol.messages
import
(
from
gpu_memory_service.common.protocol.messages
import
(
CommitRequest
,
CommitRequest
,
CommitResponse
,
CommitResponse
,
...
@@ -24,13 +25,8 @@ from gpu_memory_service.common.protocol.messages import (
...
@@ -24,13 +25,8 @@ from gpu_memory_service.common.protocol.messages import (
GetRuntimeStateRequest
,
GetRuntimeStateRequest
,
HandshakeRequest
,
HandshakeRequest
,
)
)
from
gpu_memory_service.common.types
import
(
GrantedLockType
,
RequestedLockType
,
ServerState
,
StateEvent
,
)
from
gpu_memory_service.server.allocations
import
GMSAllocationManager
from
gpu_memory_service.server.allocations
import
GMSAllocationManager
from
gpu_memory_service.server.fsm
import
ServerState
,
StateEvent
from
gpu_memory_service.server.gms
import
GMS
from
gpu_memory_service.server.gms
import
GMS
from
gpu_memory_service.server.rpc
import
GMSRPCServer
,
_is_connection_alive
from
gpu_memory_service.server.rpc
import
GMSRPCServer
,
_is_connection_alive
from
gpu_memory_service.server.session
import
(
from
gpu_memory_service.server.session
import
(
...
@@ -46,7 +42,8 @@ from cuda.bindings import driver as cuda # noqa: E402
...
@@ -46,7 +42,8 @@ from cuda.bindings import driver as cuda # noqa: E402
pytestmark
=
[
pytestmark
=
[
pytest
.
mark
.
pre_merge
,
pytest
.
mark
.
pre_merge
,
pytest
.
mark
.
unit
,
pytest
.
mark
.
unit
,
pytest
.
mark
.
gpu_0
,
pytest
.
mark
.
none
,
pytest
.
mark
.
gpu_1
,
]
]
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment