Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
39a6a240
"lib/bindings/python/vscode:/vscode.git/clone" did not exist on "aea57caeb5186a75d374d8de0b1e4e165a4413e7"
Unverified
Commit
39a6a240
authored
Apr 09, 2026
by
Schwinn Saereesitthipitak
Committed by
GitHub
Apr 09, 2026
Browse files
refactor: simplify GPU Memory Service integrations and module boundaries (#7875)
parent
02666f04
Changes
51
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
427 additions
and
452 deletions
+427
-452
lib/gpu_memory_service/integrations/sglang/memory_saver.py
lib/gpu_memory_service/integrations/sglang/memory_saver.py
+128
-151
lib/gpu_memory_service/integrations/sglang/model_loader.py
lib/gpu_memory_service/integrations/sglang/model_loader.py
+11
-15
lib/gpu_memory_service/integrations/sglang/patches.py
lib/gpu_memory_service/integrations/sglang/patches.py
+12
-22
lib/gpu_memory_service/integrations/vllm/model_loader.py
lib/gpu_memory_service/integrations/vllm/model_loader.py
+5
-3
lib/gpu_memory_service/integrations/vllm/patches.py
lib/gpu_memory_service/integrations/vllm/patches.py
+2
-2
lib/gpu_memory_service/integrations/vllm/worker.py
lib/gpu_memory_service/integrations/vllm/worker.py
+6
-6
lib/gpu_memory_service/server/__init__.py
lib/gpu_memory_service/server/__init__.py
+0
-39
lib/gpu_memory_service/server/fsm.py
lib/gpu_memory_service/server/fsm.py
+171
-0
lib/gpu_memory_service/server/gms.py
lib/gpu_memory_service/server/gms.py
+3
-7
lib/gpu_memory_service/server/rpc.py
lib/gpu_memory_service/server/rpc.py
+2
-1
lib/gpu_memory_service/server/session.py
lib/gpu_memory_service/server/session.py
+56
-187
pyproject.toml
pyproject.toml
+1
-0
tests/gms/common/__init__.py
tests/gms/common/__init__.py
+6
-0
tests/gms/common/test_failover_lock.py
tests/gms/common/test_failover_lock.py
+4
-2
tests/gms/common/test_gms_client_memory_manager.py
tests/gms/common/test_gms_client_memory_manager.py
+3
-2
tests/gms/common/test_gms_client_session.py
tests/gms/common/test_gms_client_session.py
+2
-1
tests/gms/common/test_gms_client_transport.py
tests/gms/common/test_gms_client_transport.py
+1
-0
tests/gms/common/test_gms_harness.py
tests/gms/common/test_gms_harness.py
+6
-1
tests/gms/common/test_gms_runtime_flows.py
tests/gms/common/test_gms_runtime_flows.py
+4
-6
tests/gms/common/test_gms_server_transport_failures.py
tests/gms/common/test_gms_server_transport_failures.py
+4
-7
No files found.
lib/gpu_memory_service/integrations/sglang/memory_saver.py
View file @
39a6a240
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""
Hybrid
torch_memory_saver implementation for GPU Memory Service.
"""torch_memory_saver implementation for GPU Memory Service.
This module uses:
1. GPU Memory Service for "weights" (shared RO/RW publish flow)
2. GPU Memory Service for "kv_cache" (RW-only failover flow)
3. torch_memory_saver for any remaining tags
SGLang with GMS owns exactly two memory classes:
1. "weights" via the shared RO/RW publish flow
2. "kv_cache" via the RW failover flow
Unsupported release/resume tags stay no-ops with a warning so the generic
SGLang memory-control API can still pass broader tag sets without reintroducing
the old torch-memory-saver fallback. `cuda_graph` is a hard error because the
pauseable CUDA-graph path depends on the LD_PRELOAD torch allocator hooks that
GMS intentionally does not use.
"""
from
__future__
import
annotations
import
logging
from
contextlib
import
contextmanager
from
typing
import
TYPE_CHECKING
,
Optional
from
typing
import
Optional
import
torch
from
gpu_memory_service
import
get_or_create_gms_client_memory_manager
from
gpu_memory_service.client.torch.allocator
import
gms_use_mem_pool
from
gpu_memory_service.common.types
import
GrantedLockType
,
RequestedLockType
from
gpu_memory_service.client.torch.allocator
import
(
get_or_create_gms_client_memory_manager
,
gms_use_mem_pool
,
)
from
gpu_memory_service.common.locks
import
GrantedLockType
,
RequestedLockType
from
gpu_memory_service.common.utils
import
get_socket_path
if
TYPE_CHECKING
:
from
gpu_memory_service.client.memory_manager
import
GMSClientMemoryManager
from
torch_memory_saver.entrypoint
import
_TorchMemorySaverImpl
from
gpu_memory_service.integrations.common.utils
import
GMS_TAGS
,
finalize_gms_write
logger
=
logging
.
getLogger
(
__name__
)
# Published weights must come back RO, while KV cache always resumes in a fresh
# RW epoch so the restored engine can rebuild mutable cache state.
_TAG_LOCK_TYPES
=
{
"weights"
:
RequestedLockType
.
RO
,
"kv_cache"
:
RequestedLockType
.
RW
}
def
_pause_resume_tags
(
tag
:
Optional
[
str
])
->
tuple
[
str
,
...]:
if
tag
is
None
:
return
GMS_TAGS
if
tag
in
_TAG_LOCK_TYPES
:
return
(
tag
,)
logger
.
warning
(
"[GMS] Ignoring unsupported torch_memory_saver tag %r; supported tags are %s"
,
tag
,
list
(
GMS_TAGS
),
)
return
()
def
get_gms_memory_saver_impl
()
->
Optional
[
"GMSMemorySaverImpl"
]:
"""Get the GMS memory saver impl from the torch_memory_saver singleton."""
...
...
@@ -39,170 +60,126 @@ def get_gms_memory_saver_impl() -> Optional["GMSMemorySaverImpl"]:
class
GMSMemorySaverImpl
:
"""
Hybrid implementation: GMS for weights and KV cache
."""
"""
SGLang memory saver implementation backed only by GMS
."""
def
__init__
(
self
,
torch_impl
:
"_TorchMemorySaverImpl"
,
device_index
:
int
,
mode
=
None
,
):
self
.
_
torch_impl
=
torch_impl
self
.
_device_index
=
device_index
self
.
_
requested_mode
=
mode
self
.
_disabled
=
False
self
.
_imported_weights_bytes
:
int
=
0
self
.
_weights_allocator
:
Optional
[
"GMSClientMemoryManager"
]
self
.
_kv_cache_allocator
:
"GMSClientMemoryManager"
self
.
_mode
:
str
(
self
.
_weights_allocator
,
self
.
_kv_cache_allocator
,
self
.
_mode
,
)
=
self
.
_init_allocators
()
self
.
_
device
=
torch
.
device
(
"cuda"
,
device_index
)
self
.
imported_weights_bytes
=
0
requested_mode
=
mode
or
RequestedLockType
.
RW_OR_RO
self
.
allocators
=
{
tag
:
get_or_create_gms_client_memory_manager
(
get_socket_path
(
device_index
,
tag
),
device_index
,
# weights follow the configured publish/import mode; kv_cache is
# always mutable and therefore always needs an RW session.
mode
=
requested_mode
if
tag
==
"weights"
else
RequestedLockType
.
RW
,
tag
=
tag
,
)
for
tag
in
GMS_TAGS
}
logger
.
info
(
"[GMS] Initialized weights=%s mode, kv_cache=RW (device=%d)"
,
self
.
_mode
.
upper
(),
"[GMS] Initialized weights: requested=%s granted=%s (device=%d)"
,
requested_mode
.
name
,
self
.
allocators
[
"weights"
].
granted_lock_type
.
name
,
device_index
,
)
def
_init_allocators
(
self
,
)
->
tuple
[
Optional
[
"GMSClientMemoryManager"
],
"GMSClientMemoryManager"
,
str
,]:
"""Create allocator with mode from config (default: RW_OR_RO)."""
mode
=
self
.
_requested_mode
or
RequestedLockType
.
RW_OR_RO
weights_allocator
=
get_or_create_gms_client_memory_manager
(
get_socket_path
(
self
.
_device_index
,
"weights"
),
self
.
_device_index
,
mode
=
mode
,
tag
=
"weights"
,
)
kv_cache_allocator
=
get_or_create_gms_client_memory_manager
(
get_socket_path
(
self
.
_device_index
,
"kv_cache"
),
self
.
_device_index
,
mode
=
RequestedLockType
.
RW
,
tag
=
"kv_cache"
,
)
granted_mode
=
weights_allocator
.
granted_lock_type
if
granted_mode
==
GrantedLockType
.
RW
:
actual_mode
=
"write"
else
:
actual_mode
=
"read"
logger
.
info
(
"[GMS] Initialized in AUTO mode, granted=%s (device=%d)"
,
actual_mode
.
upper
(),
self
.
_device_index
,
)
return
weights_allocator
,
kv_cache_allocator
,
actual_mode
def
_is_weights_tag
(
self
,
tag
:
Optional
[
str
])
->
bool
:
return
tag
in
(
"weights"
,
"model_weights"
)
def
get_mode
(
self
)
->
str
:
return
self
.
_mode
def
get_allocator
(
self
)
->
Optional
[
"GMSClientMemoryManager"
]:
return
self
.
_weights_allocator
@
contextmanager
def
region
(
self
,
tag
:
str
,
enable_cpu_backup
:
bool
):
"""Mark allocation region with tag."""
if
self
.
_is_weights_tag
(
tag
):
if
self
.
_mode
==
"read"
:
yield
return
target_device
=
torch
.
device
(
"cuda"
,
self
.
_device_index
)
with
gms_use_mem_pool
(
"weights"
,
target_device
):
yield
if
enable_cpu_backup
:
raise
ValueError
(
"SGLang with GMS does not support CPU backup for allocations."
)
if
tag
not
in
_TAG_LOCK_TYPES
:
logger
.
warning
(
"[GMS] Ignoring unsupported torch_memory_saver region tag %r; "
"supported tags are %s"
,
tag
,
list
(
GMS_TAGS
),
)
yield
return
if
tag
==
"kv_cache"
:
target_device
=
torch
.
device
(
"cuda"
,
self
.
_device_index
)
with
gms_use_mem_pool
(
"kv_cache"
,
target_device
):
yield
if
(
tag
==
"weights"
and
self
.
allocators
[
"weights"
].
granted_lock_type
==
GrantedLockType
.
RO
):
# Imported weights are already mapped and immutable in RO mode, so
# there is no allocator swap to install for this region.
yield
return
with
self
.
_torch_impl
.
region
(
tag
=
tag
,
enable_cpu_backup
=
enable_cpu_backup
):
allocator
=
self
.
allocators
[
tag
]
if
allocator
.
granted_lock_type
!=
GrantedLockType
.
RW
:
mode
=
(
allocator
.
granted_lock_type
.
name
if
allocator
.
granted_lock_type
is
not
None
else
"DISCONNECTED"
)
# The server would reject writes on a non-RW session too, but we
# fail before entering the allocation path so SGLang never starts a
# partial region with the wrong lock state.
raise
RuntimeError
(
f
"SGLang with GMS requires
{
tag
!
r
}
to be RW for allocations; got
{
mode
}
"
)
with
gms_use_mem_pool
(
tag
,
self
.
_device
):
yield
@
contextmanager
def
cuda_graph
(
self
,
cuda_graph
,
pool
,
stream
,
capture_error_mode
,
tag
:
str
,
enable_cpu_backup
:
bool
,
):
# The old hybrid path could delegate this to torch_memory_saver, but
# strict GMS mode has no compatible pauseable CUDA-graph allocator hook.
raise
RuntimeError
(
"SGLang with GMS does not support pauseable CUDA graphs. "
"torch_memory_saver only supports cuda_graph in hook_mode=preload, "
"and GMS does not use the LD_PRELOAD path."
)
def
pause
(
self
,
tag
:
Optional
[
str
]
=
None
)
->
None
:
if
self
.
_disabled
:
return
if
tag
is
None
or
self
.
_is_weights_tag
(
tag
):
self
.
_pause_weights
(
)
if
tag
is
None
or
tag
==
"kv_cache"
:
self
.
_pause_kv_cache
()
if
tag
is
None
or
(
not
self
.
_is_weights_tag
(
tag
)
and
tag
!=
"kv_cache"
):
self
.
_torch_impl
.
pause
(
tag
=
tag
)
for
target_tag
in
_pause_resume_tags
(
tag
)
:
if
self
.
allocators
[
target_tag
].
is_unmapped
:
continue
logger
.
info
(
"[GMS] Unmapping %s"
,
target_tag
)
self
.
allocators
[
target_tag
].
unmap_all_vas
()
# abort() drops the current session after unmapping while keeping
# the VA reservation alive for the next resume().
self
.
allocators
[
target_tag
].
abort
(
)
def
resume
(
self
,
tag
:
Optional
[
str
]
=
None
)
->
None
:
if
self
.
_disabled
:
return
if
tag
is
None
or
self
.
_is_weights_tag
(
tag
):
self
.
_resume_weights
()
if
tag
is
None
or
tag
==
"kv_cache"
:
self
.
_resume_kv_cache
()
if
tag
is
None
or
(
not
self
.
_is_weights_tag
(
tag
)
and
tag
!=
"kv_cache"
):
self
.
_torch_impl
.
resume
(
tag
=
tag
)
def
_pause_weights
(
self
)
->
None
:
if
self
.
_weights_allocator
is
None
:
return
if
self
.
_weights_allocator
.
is_unmapped
:
return
logger
.
info
(
"[GMS] Unmapping weights (VA-stable)"
)
self
.
_weights_allocator
.
unmap_all_vas
()
self
.
_weights_allocator
.
abort
()
def
_resume_weights
(
self
)
->
None
:
if
self
.
_weights_allocator
is
None
:
return
if
not
self
.
_weights_allocator
.
is_unmapped
:
return
logger
.
info
(
"[GMS] Remapping weights (VA-stable)"
)
self
.
_weights_allocator
.
connect
(
RequestedLockType
.
RO
)
self
.
_weights_allocator
.
remap_all_vas
()
def
_pause_kv_cache
(
self
)
->
None
:
if
self
.
_kv_cache_allocator
.
is_unmapped
:
return
logger
.
info
(
"[GMS] Unmapping KV cache"
)
self
.
_kv_cache_allocator
.
unmap_all_vas
()
self
.
_kv_cache_allocator
.
abort
()
def
_resume_kv_cache
(
self
)
->
None
:
if
not
self
.
_kv_cache_allocator
.
is_unmapped
:
return
logger
.
info
(
"[GMS] Remapping KV cache"
)
self
.
_kv_cache_allocator
.
connect
(
RequestedLockType
.
RW
)
self
.
_kv_cache_allocator
.
reallocate_all_handles
(
tag
=
"kv_cache"
)
self
.
_kv_cache_allocator
.
remap_all_vas
()
for
target_tag
in
_pause_resume_tags
(
tag
):
if
not
self
.
allocators
[
target_tag
].
is_unmapped
:
continue
logger
.
info
(
"[GMS] Remapping %s"
,
target_tag
)
self
.
allocators
[
target_tag
].
connect
(
_TAG_LOCK_TYPES
[
target_tag
])
if
target_tag
==
"kv_cache"
:
# KV cache resumes into a new RW layout epoch, so the handles
# must be re-created before the VA range is mapped again.
self
.
allocators
[
target_tag
].
reallocate_all_handles
(
tag
=
target_tag
)
self
.
allocators
[
target_tag
].
remap_all_vas
()
def
finalize_write_mode
(
self
,
model
:
torch
.
nn
.
Module
)
->
None
:
"""Finalize write mode: register tensors, commit, and switch to read."""
if
self
.
_mode
!=
"write"
:
if
self
.
allocators
[
"weights"
].
granted_lock_type
!=
GrantedLockType
.
RW
:
# Read-only import mode never republishes weights.
return
if
self
.
_weights_allocator
is
None
:
raise
RuntimeError
(
"Allocator is None in WRITE mode"
)
from
gpu_memory_service.integrations.common.utils
import
finalize_gms_write
self
.
_
imported_weights_bytes
=
finalize_gms_write
(
self
.
_weights_
allocator
,
model
self
.
imported_weights_bytes
=
finalize_gms_write
(
self
.
allocator
s
[
"weights"
]
,
model
)
self
.
_mode
=
"read"
def
set_imported_weights_bytes
(
self
,
bytes_count
:
int
)
->
None
:
self
.
_imported_weights_bytes
=
bytes_count
def
get_imported_weights_bytes
(
self
)
->
int
:
return
self
.
_imported_weights_bytes
def
disable
(
self
)
->
None
:
self
.
_disabled
=
True
def
enable
(
self
)
->
None
:
self
.
_disabled
=
False
lib/gpu_memory_service/integrations/sglang/model_loader.py
View file @
39a6a240
...
...
@@ -16,11 +16,16 @@ from __future__ import annotations
import
logging
import
torch
from
gpu_memory_service.client.torch.module
import
materialize_module_from_gms
from
gpu_memory_service.common.locks
import
GrantedLockType
from
gpu_memory_service.integrations.common
import
patch_empty_cache
from
gpu_memory_service.integrations.common.utils
import
(
setup_meta_tensor_workaround
,
strip_gms_model_loader_config
,
)
from
gpu_memory_service.integrations.sglang.memory_saver
import
(
get_gms_memory_saver_impl
,
)
from
gpu_memory_service.integrations.sglang.patches
import
(
patch_model_runner
,
patch_static_state_for_gms
,
...
...
@@ -66,10 +71,6 @@ class GMSModelLoader:
device_config
,
)
->
torch
.
nn
.
Module
:
"""Load or import model weights."""
from
gpu_memory_service.integrations.sglang.memory_saver
import
(
get_gms_memory_saver_impl
,
)
impl
=
get_gms_memory_saver_impl
()
if
impl
is
None
:
raise
RuntimeError
(
...
...
@@ -77,13 +78,12 @@ class GMSModelLoader:
"Ensure torch_memory_saver patch was applied before model loading."
)
mode
=
impl
.
get_mode
()
logger
.
info
(
"[GMS] Loading model in %s mode"
,
mode
.
upper
()
)
mode
=
impl
.
allocators
[
"weights"
].
granted_lock_type
logger
.
info
(
"[GMS] Loading model in %s mode"
,
mode
.
name
)
if
mode
==
"read"
:
if
mode
==
GrantedLockType
.
RO
:
return
self
.
_load_import_only
(
model_config
,
device_config
,
impl
)
else
:
return
self
.
_load_write_mode
(
model_config
,
device_config
,
impl
)
return
self
.
_load_write_mode
(
model_config
,
device_config
,
impl
)
def
_load_write_mode
(
self
,
model_config
,
device_config
,
impl
)
->
torch
.
nn
.
Module
:
"""Load model from disk and register with GMS (WRITE mode)."""
...
...
@@ -99,17 +99,13 @@ class GMSModelLoader:
def
_load_import_only
(
self
,
model_config
,
device_config
,
impl
)
->
torch
.
nn
.
Module
:
"""Import model weights from GMS metadata (READ mode)."""
from
gpu_memory_service.client.torch.module
import
materialize_module_from_gms
allocator
=
impl
.
get_allocator
()
if
allocator
is
None
:
raise
RuntimeError
(
"GMS allocator is None in READ mode"
)
allocator
=
impl
.
allocators
[
"weights"
]
device_index
=
torch
.
cuda
.
current_device
()
model
=
self
.
_create_meta_model
(
model_config
,
device_config
)
materialize_module_from_gms
(
allocator
,
model
,
device_index
=
device_index
)
impl
.
set_
imported_weights_bytes
(
allocator
.
total_bytes
)
impl
.
imported_weights_bytes
=
allocator
.
total_bytes
logger
.
info
(
"[GMS] READ mode: imported %.2f GiB from metadata"
,
...
...
lib/gpu_memory_service/integrations/sglang/patches.py
View file @
39a6a240
...
...
@@ -3,7 +3,7 @@
"""SGLang-specific patches for GPU Memory Service integration.
- patch_torch_memory_saver: Routes
to GMS hybrid implementation
- patch_torch_memory_saver: Routes
weights and kv_cache to GMS
- patch_model_runner: Fixes memory accounting with pre-loaded weights
- patch_static_state_for_gms: No-ops named-buffer export/import (GMS preserves them)
"""
...
...
@@ -15,7 +15,12 @@ import logging
from
contextlib
import
contextmanager
from
typing
import
Optional
import
gpu_memory_service.integrations.sglang
as
gms_sglang
import
torch
from
gpu_memory_service.integrations.sglang.memory_saver
import
(
GMSMemorySaverImpl
,
get_gms_memory_saver_impl
,
)
logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -57,25 +62,16 @@ def patch_torch_memory_saver() -> None:
logger
.
info
(
f
"[GMS] TorchMemorySaver initializing with hook_mode=
{
hook_mode
}
"
)
if
hook_mode
is
None
or
hook_mode
==
"gms"
:
# Use our GPU Memory Service implementation
from
gpu_memory_service.integrations.sglang.memory_saver
import
(
GMSMemorySaverImpl
,
)
from
torch_memory_saver.entrypoint
import
_TorchMemorySaverImpl
# In GMS mode we install only the strict GMS implementation:
# weights + kv_cache go through GMS, generic unsupported tags stay
# no-ops/warnings, and cuda_graph remains unsupported.
# Get device from torch.cuda.current_device() (already set by SGLang)
device_index
=
torch
.
cuda
.
current_device
()
# Create underlying torch impl for non-GMS tags.
torch_impl
=
_TorchMemorySaverImpl
(
hook_mode
=
"torch"
)
# Read lock mode set by setup_gms() (defaults to RW_OR_RO)
from
gpu_memory_service.integrations.sglang
import
_gms_lock_mode
gms_impl
=
GMSMemorySaverImpl
(
torch_impl
=
torch_impl
,
device_index
=
device_index
,
mode
=
_gms_lock_mode
,
mode
=
gms_sglang
.
_gms_lock_mode
,
)
# Set _impl directly (accessible via gms_impl property)
...
...
@@ -83,7 +79,7 @@ def patch_torch_memory_saver() -> None:
logger
.
info
(
"[GMS] Using GMS mode (device=%d, mode=%s)"
,
device_index
,
gms_impl
.
get_mode
()
,
gms_impl
.
allocators
[
"weights"
].
granted_lock_type
.
name
,
)
del
self
.
_impl_ctor_kwargs
else
:
...
...
@@ -111,8 +107,6 @@ def patch_torch_memory_saver() -> None:
torch_memory_saver
.
configure_subprocess
=
patched_configure_subprocess
# Add property to access GMS impl directly from the singleton
from
gpu_memory_service.integrations.sglang.memory_saver
import
GMSMemorySaverImpl
@
property
def
gms_impl
(
self
)
->
Optional
[
GMSMemorySaverImpl
]:
"""Get the GMS impl if installed, None otherwise."""
...
...
@@ -185,12 +179,8 @@ def patch_model_runner() -> None:
weights are already resident. Newer SGLang versions changed this API, so
only rewrite the old total_gpu_memory parameter shape.
"""
from
gpu_memory_service.integrations.sglang.memory_saver
import
(
get_gms_memory_saver_impl
,
)
impl
=
get_gms_memory_saver_impl
()
if
impl
is
not
None
and
impl
.
get_
imported_weights_bytes
()
>
0
:
if
impl
is
not
None
and
impl
.
imported_weights_bytes
>
0
:
total_memory_gib
=
torch
.
cuda
.
get_device_properties
(
torch
.
cuda
.
current_device
()
).
total_memory
/
(
1
<<
30
)
...
...
lib/gpu_memory_service/integrations/vllm/model_loader.py
View file @
39a6a240
...
...
@@ -14,10 +14,12 @@ import logging
from
typing
import
TYPE_CHECKING
import
torch
from
gpu_memory_service
import
get_or_create_gms_client_memory_manager
from
gpu_memory_service.client.torch.allocator
import
gms_use_mem_pool
from
gpu_memory_service.client.torch.allocator
import
(
get_or_create_gms_client_memory_manager
,
gms_use_mem_pool
,
)
from
gpu_memory_service.client.torch.module
import
materialize_module_from_gms
from
gpu_memory_service.common.
type
s
import
GrantedLockType
from
gpu_memory_service.common.
lock
s
import
GrantedLockType
from
gpu_memory_service.common.utils
import
get_socket_path
from
gpu_memory_service.integrations.common.utils
import
(
finalize_gms_write
,
...
...
lib/gpu_memory_service/integrations/vllm/patches.py
View file @
39a6a240
...
...
@@ -14,8 +14,8 @@ from __future__ import annotations
import
logging
from
gpu_memory_service
import
get_gms_client_memory_manager
from
gpu_memory_service.common.
type
s
import
GrantedLockType
from
gpu_memory_service
.client.torch.allocator
import
get_gms_client_memory_manager
from
gpu_memory_service.common.
lock
s
import
GrantedLockType
from
gpu_memory_service.integrations.vllm.utils
import
is_shadow_mode
logger
=
logging
.
getLogger
(
__name__
)
...
...
lib/gpu_memory_service/integrations/vllm/worker.py
View file @
39a6a240
...
...
@@ -18,16 +18,16 @@ from contextlib import nullcontext
from
typing
import
List
,
Optional
import
torch
from
gpu_memory_service
import
(
from
gpu_memory_service.client.memory_manager
import
StaleMemoryLayoutError
from
gpu_memory_service.client.torch.allocator
import
(
get_gms_client_memory_manager
,
get_or_create_gms_client_memory_manager
,
gms_use_mem_pool
,
)
from
gpu_memory_service.client.memory_manager
import
StaleMemoryLayoutError
from
gpu_memory_service.client.torch.allocator
import
gms_use_mem_pool
from
gpu_memory_service.common.types
import
RequestedLockType
from
gpu_memory_service.common.locks
import
RequestedLockType
from
gpu_memory_service.common.utils
import
get_socket_path
from
gpu_memory_service.integrations.common
import
patch_empty_cache
from
gpu_memory_service.integrations.common.utils
import
get_gms_lock_mode
from
gpu_memory_service.integrations.common.utils
import
GMS_TAGS
,
get_gms_lock_mode
from
gpu_memory_service.integrations.vllm.model_loader
import
register_gms_loader
from
gpu_memory_service.integrations.vllm.patches
import
(
apply_shadow_mode_patches
,
...
...
@@ -264,7 +264,7 @@ class GMSWorker(Worker):
self
.
model_runner
.
exit_shadow_init
()
if
tags
is
None
:
tags
=
[
"weights"
,
"kv_cache"
]
tags
=
list
(
GMS_TAGS
)
if
"weights"
in
tags
:
weights_manager
=
get_gms_client_memory_manager
(
"weights"
)
...
...
lib/gpu_memory_service/server/__init__.py
View file @
39a6a240
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""GPU Memory Service server components."""
from
gpu_memory_service.common.types
import
(
GrantedLockType
,
RequestedLockType
,
ServerState
,
StateSnapshot
,
)
from
gpu_memory_service.server.allocations
import
(
AllocationInfo
,
AllocationNotFoundError
,
GMSAllocationManager
,
)
from
gpu_memory_service.server.gms
import
GMS
,
MetadataEntry
from
gpu_memory_service.server.rpc
import
GMSRPCServer
from
gpu_memory_service.server.session
import
(
Connection
,
GMSSessionManager
,
InvalidTransition
,
OperationNotAllowed
,
)
__all__
=
[
"GMSRPCServer"
,
"GMS"
,
"GMSSessionManager"
,
"GMSAllocationManager"
,
"AllocationInfo"
,
"AllocationNotFoundError"
,
"MetadataEntry"
,
"Connection"
,
"GrantedLockType"
,
"RequestedLockType"
,
"ServerState"
,
"StateSnapshot"
,
"InvalidTransition"
,
"OperationNotAllowed"
,
]
lib/gpu_memory_service/server/fsm.py
0 → 100644
View file @
39a6a240
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from
__future__
import
annotations
import
asyncio
from
dataclasses
import
dataclass
,
field
from
enum
import
Enum
,
auto
from
typing
import
Optional
,
Set
from
gpu_memory_service.common.locks
import
GrantedLockType
class
ServerState
(
str
,
Enum
):
EMPTY
=
"EMPTY"
RW
=
"RW"
COMMITTED
=
"COMMITTED"
RO
=
"RO"
class
StateEvent
(
Enum
):
RW_CONNECT
=
auto
()
RW_COMMIT
=
auto
()
RW_ABORT
=
auto
()
RO_CONNECT
=
auto
()
RO_DISCONNECT
=
auto
()
@
dataclass
(
eq
=
False
)
class
Connection
:
reader
:
asyncio
.
StreamReader
writer
:
asyncio
.
StreamWriter
mode
:
GrantedLockType
session_id
:
str
recv_buffer
:
bytearray
=
field
(
default_factory
=
bytearray
)
def
__hash__
(
self
)
->
int
:
return
hash
(
self
.
session_id
)
async
def
close
(
self
)
->
None
:
self
.
writer
.
close
()
try
:
await
self
.
writer
.
wait_closed
()
except
Exception
:
pass
class
InvalidTransition
(
Exception
):
pass
@
dataclass
(
frozen
=
True
)
class
Transition
:
from_states
:
frozenset
[
ServerState
]
event
:
StateEvent
to_state
:
Optional
[
ServerState
]
condition
:
Optional
[
str
]
=
None
TRANSITIONS
:
list
[
Transition
]
=
[
Transition
(
from_states
=
frozenset
({
ServerState
.
EMPTY
,
ServerState
.
COMMITTED
}),
event
=
StateEvent
.
RW_CONNECT
,
to_state
=
ServerState
.
RW
,
),
Transition
(
from_states
=
frozenset
({
ServerState
.
RW
}),
event
=
StateEvent
.
RW_COMMIT
,
to_state
=
ServerState
.
COMMITTED
,
),
Transition
(
from_states
=
frozenset
({
ServerState
.
RW
}),
event
=
StateEvent
.
RW_ABORT
,
to_state
=
ServerState
.
EMPTY
,
),
Transition
(
from_states
=
frozenset
({
ServerState
.
COMMITTED
,
ServerState
.
RO
}),
event
=
StateEvent
.
RO_CONNECT
,
to_state
=
ServerState
.
RO
,
),
Transition
(
from_states
=
frozenset
({
ServerState
.
RO
}),
event
=
StateEvent
.
RO_DISCONNECT
,
to_state
=
ServerState
.
RO
,
condition
=
"has_remaining_readers"
,
),
Transition
(
from_states
=
frozenset
({
ServerState
.
RO
}),
event
=
StateEvent
.
RO_DISCONNECT
,
to_state
=
ServerState
.
COMMITTED
,
condition
=
"is_last_reader"
,
),
]
class
GMSFSM
:
def
__init__
(
self
):
self
.
_rw_conn
:
Optional
[
Connection
]
=
None
self
.
_ro_conns
:
Set
[
Connection
]
=
set
()
self
.
_committed
=
False
@
property
def
state
(
self
)
->
ServerState
:
if
self
.
_rw_conn
is
not
None
:
return
ServerState
.
RW
if
self
.
_ro_conns
:
return
ServerState
.
RO
if
self
.
_committed
:
return
ServerState
.
COMMITTED
return
ServerState
.
EMPTY
@
property
def
rw_conn
(
self
)
->
Optional
[
Connection
]:
return
self
.
_rw_conn
@
property
def
ro_conns
(
self
)
->
Set
[
Connection
]:
return
self
.
_ro_conns
@
property
def
ro_count
(
self
)
->
int
:
return
len
(
self
.
_ro_conns
)
@
property
def
committed
(
self
)
->
bool
:
return
self
.
_committed
def
_check_condition
(
self
,
condition
:
Optional
[
str
],
conn
:
Connection
)
->
bool
:
if
condition
is
None
:
return
True
if
condition
==
"has_remaining_readers"
:
return
len
(
self
.
_ro_conns
)
>
1
or
conn
not
in
self
.
_ro_conns
if
condition
==
"is_last_reader"
:
return
len
(
self
.
_ro_conns
)
==
1
and
conn
in
self
.
_ro_conns
raise
ValueError
(
f
"Unknown condition:
{
condition
}
"
)
def
transition
(
self
,
event
:
StateEvent
,
conn
:
Connection
)
->
ServerState
:
from_state
=
self
.
state
for
transition
in
TRANSITIONS
:
if
from_state
not
in
transition
.
from_states
:
continue
if
transition
.
event
!=
event
:
continue
if
not
self
.
_check_condition
(
transition
.
condition
,
conn
):
continue
break
else
:
raise
InvalidTransition
(
f
"No transition for
{
event
.
name
}
from state
{
from_state
.
name
}
"
f
"(session=
{
conn
.
session_id
}
)"
)
if
event
==
StateEvent
.
RW_CONNECT
:
self
.
_rw_conn
=
conn
self
.
_committed
=
False
elif
event
==
StateEvent
.
RW_COMMIT
:
self
.
_committed
=
True
self
.
_rw_conn
=
None
elif
event
==
StateEvent
.
RW_ABORT
:
self
.
_rw_conn
=
None
elif
event
==
StateEvent
.
RO_CONNECT
:
self
.
_ro_conns
.
add
(
conn
)
elif
event
==
StateEvent
.
RO_DISCONNECT
:
self
.
_ro_conns
.
discard
(
conn
)
return
self
.
state
def
can_acquire_rw
(
self
)
->
bool
:
return
self
.
_rw_conn
is
None
and
not
self
.
_ro_conns
def
can_acquire_ro
(
self
,
waiting_writers
:
int
)
->
bool
:
return
self
.
_committed
and
self
.
_rw_conn
is
None
and
waiting_writers
==
0
lib/gpu_memory_service/server/gms.py
View file @
39a6a240
...
...
@@ -11,6 +11,7 @@ from collections import deque
from
dataclasses
import
dataclass
from
typing
import
Callable
,
Optional
from
gpu_memory_service.common.locks
import
GrantedLockType
,
RequestedLockType
from
gpu_memory_service.common.protocol.messages
import
(
AllocateRequest
,
AllocateResponse
,
...
...
@@ -42,15 +43,10 @@ from gpu_memory_service.common.protocol.messages import (
MetadataPutRequest
,
MetadataPutResponse
,
)
from
gpu_memory_service.common.types
import
(
GrantedLockType
,
RequestedLockType
,
ServerState
,
StateEvent
,
)
from
.allocations
import
AllocationInfo
,
GMSAllocationManager
from
.session
import
Connection
,
GMSSessionManager
from
.fsm
import
Connection
,
ServerState
,
StateEvent
from
.session
import
GMSSessionManager
logger
=
logging
.
getLogger
(
__name__
)
...
...
lib/gpu_memory_service/server/rpc.py
View file @
39a6a240
...
...
@@ -23,8 +23,9 @@ from gpu_memory_service.common.protocol.wire import recv_message, send_message
from
gpu_memory_service.common.utils
import
fail
from
.allocations
import
AllocationNotFoundError
from
.fsm
import
Connection
,
InvalidTransition
from
.gms
import
GMS
from
.session
import
Connection
,
InvalidTransition
,
OperationNotAllowed
from
.session
import
OperationNotAllowed
logger
=
logging
.
getLogger
(
__name__
)
...
...
lib/gpu_memory_service/server/session.py
View file @
39a6a240
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Server-side
connection, FSM, and waiter state
."""
"""Server-side
lock acquisition and cleanup
."""
from
__future__
import
annotations
import
asyncio
from
dataclasses
import
dataclass
,
field
from
typing
import
Optional
,
Set
from
gpu_memory_service.common.types
import
(
RO_ALLOWED
,
RW_ALLOWED
,
RW_REQUIRED
,
GrantedLockType
,
RequestedLockType
,
ServerState
,
StateEvent
,
from
dataclasses
import
dataclass
from
typing
import
Optional
from
gpu_memory_service.common.locks
import
GrantedLockType
,
RequestedLockType
from
gpu_memory_service.common.protocol.messages
import
(
AllocateRequest
,
CommitRequest
,
ExportAllocationRequest
,
FreeAllocationRequest
,
GetAllocationRequest
,
GetAllocationStateRequest
,
GetLockStateRequest
,
GetStateHashRequest
,
ListAllocationsRequest
,
MetadataDeleteRequest
,
MetadataGetRequest
,
MetadataListRequest
,
MetadataPutRequest
,
)
@
dataclass
(
eq
=
False
)
class
Connection
:
reader
:
asyncio
.
StreamReader
writer
:
asyncio
.
StreamWriter
mode
:
GrantedLockType
session_id
:
str
recv_buffer
:
bytearray
=
field
(
default_factory
=
bytearray
)
def
__hash__
(
self
)
->
int
:
return
hash
(
self
.
session_id
)
async
def
close
(
self
)
->
None
:
self
.
writer
.
close
()
try
:
await
self
.
writer
.
wait_closed
()
except
Exception
:
pass
class
InvalidTransition
(
Exception
):
"""Raised when an invalid state transition is attempted."""
from
.fsm
import
GMSFSM
,
Connection
,
ServerState
,
StateEvent
class
OperationNotAllowed
(
Exception
):
"""Raised when an operation is not allowed in the current state/mode."""
@
dataclass
(
frozen
=
True
)
class
Transition
:
from_states
:
frozenset
[
ServerState
]
event
:
StateEvent
to_state
:
Optional
[
ServerState
]
condition
:
Optional
[
str
]
=
None
TRANSITIONS
:
list
[
Transition
]
=
[
Transition
(
from_states
=
frozenset
({
ServerState
.
EMPTY
,
ServerState
.
COMMITTED
}),
event
=
StateEvent
.
RW_CONNECT
,
to_state
=
ServerState
.
RW
,
),
Transition
(
from_states
=
frozenset
({
ServerState
.
RW
}),
event
=
StateEvent
.
RW_COMMIT
,
to_state
=
ServerState
.
COMMITTED
,
),
Transition
(
from_states
=
frozenset
({
ServerState
.
RW
}),
event
=
StateEvent
.
RW_ABORT
,
to_state
=
ServerState
.
EMPTY
,
),
Transition
(
from_states
=
frozenset
({
ServerState
.
COMMITTED
,
ServerState
.
RO
}),
event
=
StateEvent
.
RO_CONNECT
,
to_state
=
ServerState
.
RO
,
),
Transition
(
from_states
=
frozenset
({
ServerState
.
RO
}),
event
=
StateEvent
.
RO_DISCONNECT
,
to_state
=
ServerState
.
RO
,
condition
=
"has_remaining_readers"
,
),
Transition
(
from_states
=
frozenset
({
ServerState
.
RO
}),
event
=
StateEvent
.
RO_DISCONNECT
,
to_state
=
ServerState
.
COMMITTED
,
condition
=
"is_last_reader"
,
),
]
class
GMSLocalFSM
:
"""Explicit connection/lock state machine."""
def
__init__
(
self
):
self
.
_rw_conn
:
Optional
[
Connection
]
=
None
self
.
_ro_conns
:
Set
[
Connection
]
=
set
()
self
.
_committed
=
False
pass
@
property
def
state
(
self
)
->
ServerState
:
if
self
.
_rw_conn
is
not
None
:
return
ServerState
.
RW
if
self
.
_ro_conns
:
return
ServerState
.
RO
if
self
.
_committed
:
return
ServerState
.
COMMITTED
return
ServerState
.
EMPTY
@
property
def
rw_conn
(
self
)
->
Optional
[
Connection
]:
return
self
.
_rw_conn
@
property
def
ro_conns
(
self
)
->
Set
[
Connection
]:
return
self
.
_ro_conns
@
property
def
ro_count
(
self
)
->
int
:
return
len
(
self
.
_ro_conns
)
@
property
def
committed
(
self
)
->
bool
:
return
self
.
_committed
def
_has_remaining_readers
(
self
,
conn
:
Connection
)
->
bool
:
return
len
(
self
.
_ro_conns
)
>
1
or
conn
not
in
self
.
_ro_conns
def
_is_last_reader
(
self
,
conn
:
Connection
)
->
bool
:
return
len
(
self
.
_ro_conns
)
==
1
and
conn
in
self
.
_ro_conns
def
_check_condition
(
self
,
condition
:
Optional
[
str
],
conn
:
Connection
)
->
bool
:
if
condition
is
None
:
return
True
if
condition
==
"has_remaining_readers"
:
return
self
.
_has_remaining_readers
(
conn
)
if
condition
==
"is_last_reader"
:
return
self
.
_is_last_reader
(
conn
)
raise
ValueError
(
f
"Unknown condition:
{
condition
}
"
)
def
_find_transition
(
self
,
from_state
:
ServerState
,
event
:
StateEvent
,
conn
:
Connection
,
)
->
Optional
[
Transition
]:
for
transition
in
TRANSITIONS
:
if
from_state
not
in
transition
.
from_states
:
continue
if
transition
.
event
!=
event
:
continue
if
not
self
.
_check_condition
(
transition
.
condition
,
conn
):
continue
return
transition
return
None
def
_apply_event
(
self
,
event
:
StateEvent
,
conn
:
Connection
)
->
None
:
if
event
==
StateEvent
.
RW_CONNECT
:
self
.
_rw_conn
=
conn
self
.
_committed
=
False
elif
event
==
StateEvent
.
RW_COMMIT
:
self
.
_committed
=
True
self
.
_rw_conn
=
None
elif
event
==
StateEvent
.
RW_ABORT
:
self
.
_rw_conn
=
None
elif
event
==
StateEvent
.
RO_CONNECT
:
self
.
_ro_conns
.
add
(
conn
)
elif
event
==
StateEvent
.
RO_DISCONNECT
:
self
.
_ro_conns
.
discard
(
conn
)
def
transition
(
self
,
event
:
StateEvent
,
conn
:
Connection
)
->
ServerState
:
transition
=
self
.
_find_transition
(
self
.
state
,
event
,
conn
)
if
transition
is
None
:
raise
InvalidTransition
(
f
"No transition for
{
event
.
name
}
from state
{
self
.
state
.
name
}
"
f
"(session=
{
conn
.
session_id
}
)"
)
self
.
_apply_event
(
event
,
conn
)
return
self
.
state
def
check_operation
(
self
,
msg_type
:
type
,
conn
:
Connection
)
->
None
:
if
conn
.
mode
==
GrantedLockType
.
RW
and
msg_type
not
in
RW_ALLOWED
:
raise
OperationNotAllowed
(
f
"
{
msg_type
.
__name__
}
not allowed for RW session in state
{
self
.
state
.
name
}
"
)
if
conn
.
mode
==
GrantedLockType
.
RO
and
msg_type
not
in
RO_ALLOWED
:
raise
OperationNotAllowed
(
f
"
{
msg_type
.
__name__
}
not allowed for RO session in state
{
self
.
state
.
name
}
"
)
if
msg_type
in
RW_REQUIRED
and
conn
.
mode
!=
GrantedLockType
.
RW
:
raise
OperationNotAllowed
(
f
"
{
msg_type
.
__name__
}
requires RW session, got
{
conn
.
mode
.
value
}
"
)
RW_REQUIRED
:
frozenset
[
type
]
=
frozenset
(
{
AllocateRequest
,
FreeAllocationRequest
,
MetadataPutRequest
,
MetadataDeleteRequest
,
CommitRequest
,
}
)
def
can_acquire_rw
(
self
)
->
bool
:
return
self
.
_rw_conn
is
None
and
not
self
.
_ro_conns
RO_ALLOWED
:
frozenset
[
type
]
=
frozenset
(
{
ExportAllocationRequest
,
GetAllocationRequest
,
ListAllocationsRequest
,
MetadataGetRequest
,
MetadataListRequest
,
GetLockStateRequest
,
GetAllocationStateRequest
,
GetStateHashRequest
,
}
)
def
can_acquire_ro
(
self
,
waiting_writers
:
int
)
->
bool
:
return
self
.
_committed
and
self
.
_rw_conn
is
None
and
waiting_writers
==
0
RW_ALLOWED
:
frozenset
[
type
]
=
RW_REQUIRED
|
RO_ALLOWED
@
dataclass
(
frozen
=
True
)
...
...
@@ -215,7 +73,7 @@ class GMSSessionManager:
"""Owns lock transitions, waiter coordination, and cleanup."""
def
__init__
(
self
):
self
.
_locking
=
GMS
Local
FSM
()
self
.
_locking
=
GMSFSM
()
self
.
_waiting_writers
=
0
self
.
_reserved_rw_session_id
:
Optional
[
str
]
=
None
self
.
_condition
=
asyncio
.
Condition
()
...
...
@@ -336,7 +194,18 @@ class GMSSessionManager:
self
.
_locking
.
transition
(
StateEvent
.
RW_COMMIT
,
conn
)
def
check_operation
(
self
,
msg_type
:
type
,
conn
:
Connection
)
->
None
:
self
.
_locking
.
check_operation
(
msg_type
,
conn
)
if
conn
.
mode
==
GrantedLockType
.
RW
and
msg_type
not
in
RW_ALLOWED
:
raise
OperationNotAllowed
(
f
"
{
msg_type
.
__name__
}
not allowed for RW session in state
{
self
.
state
.
name
}
"
)
if
conn
.
mode
==
GrantedLockType
.
RO
and
msg_type
not
in
RO_ALLOWED
:
raise
OperationNotAllowed
(
f
"
{
msg_type
.
__name__
}
not allowed for RO session in state
{
self
.
state
.
name
}
"
)
if
msg_type
in
RW_REQUIRED
and
conn
.
mode
!=
GrantedLockType
.
RW
:
raise
OperationNotAllowed
(
f
"
{
msg_type
.
__name__
}
requires RW session, got
{
conn
.
mode
.
value
}
"
)
def
begin_cleanup
(
self
,
conn
:
Optional
[
Connection
])
->
StateEvent
|
None
:
if
conn
is
None
:
...
...
pyproject.toml
View file @
39a6a240
...
...
@@ -246,6 +246,7 @@ markers = [
"stress: marks tests as stress tests"
,
"performance: marks tests as performance tests"
,
"benchmark: marks tests as benchmark tests"
,
"none: marks tests that do not require a framework-specific runtime"
,
"vllm: marks tests as requiring vllm"
,
"trtllm: marks tests as requiring trtllm"
,
"sglang: marks tests as requiring sglang"
,
...
...
tests/gms/common/__init__.py
View file @
39a6a240
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import
pytest
pytest
.
importorskip
(
"gpu_memory_service"
,
reason
=
"gpu_memory_service is required"
)
tests/gms/common/test_failover_lock.py
View file @
39a6a240
...
...
@@ -3,8 +3,9 @@
"""Tests for the flock-based failover lock.
No GPU required — these are pure Python/OS tests exercising flock
semantics across asyncio tasks and child processes.
These are pure Python/OS tests exercising flock semantics across asyncio
tasks and child processes, so they stay on the generic cpu-style pre-merge
lane instead of the dedicated GPU job.
"""
import
asyncio
...
...
@@ -19,6 +20,7 @@ from gpu_memory_service.failover_lock.flock import FlockFailoverLock
pytestmark
=
[
pytest
.
mark
.
pre_merge
,
pytest
.
mark
.
unit
,
pytest
.
mark
.
none
,
pytest
.
mark
.
gpu_0
,
]
...
...
tests/gms/common/test_gms_client_memory_manager.py
View file @
39a6a240
...
...
@@ -9,12 +9,13 @@ from gpu_memory_service.client.memory_manager import (
GMSClientMemoryManager
,
LocalMapping
,
)
from
gpu_memory_service.common.
type
s
import
GrantedLockType
,
RequestedLockType
from
gpu_memory_service.common.
lock
s
import
GrantedLockType
,
RequestedLockType
pytestmark
=
[
pytest
.
mark
.
pre_merge
,
pytest
.
mark
.
unit
,
pytest
.
mark
.
gpu_0
,
pytest
.
mark
.
none
,
pytest
.
mark
.
gpu_1
,
]
...
...
tests/gms/common/test_gms_client_session.py
View file @
39a6a240
...
...
@@ -6,15 +6,16 @@ from __future__ import annotations
import
pytest
from
gpu_memory_service.client.rpc
import
_GMSRPCTransport
from
gpu_memory_service.client.session
import
_GMSClientSession
from
gpu_memory_service.common.locks
import
GrantedLockType
,
RequestedLockType
from
gpu_memory_service.common.protocol.messages
import
(
CommitResponse
,
HandshakeResponse
,
)
from
gpu_memory_service.common.types
import
GrantedLockType
,
RequestedLockType
pytestmark
=
[
pytest
.
mark
.
pre_merge
,
pytest
.
mark
.
unit
,
pytest
.
mark
.
none
,
pytest
.
mark
.
gpu_0
,
]
...
...
tests/gms/common/test_gms_client_transport.py
View file @
39a6a240
...
...
@@ -15,6 +15,7 @@ from gpu_memory_service.common.protocol.messages import (
pytestmark
=
[
pytest
.
mark
.
pre_merge
,
pytest
.
mark
.
unit
,
pytest
.
mark
.
none
,
pytest
.
mark
.
gpu_0
,
]
...
...
tests/gms/common/test_gms_harness.py
View file @
39a6a240
...
...
@@ -10,7 +10,12 @@ import pytest
from
tests.gms.harness.gms
import
GMSServerProcess
from
tests.utils.managed_process
import
ManagedProcess
pytestmark
=
[
pytest
.
mark
.
pre_merge
,
pytest
.
mark
.
unit
,
pytest
.
mark
.
gpu_0
]
pytestmark
=
[
pytest
.
mark
.
pre_merge
,
pytest
.
mark
.
unit
,
pytest
.
mark
.
none
,
pytest
.
mark
.
gpu_1
,
]
@
pytest
.
fixture
...
...
tests/gms/common/test_gms_runtime_flows.py
View file @
39a6a240
...
...
@@ -24,19 +24,16 @@ from gpu_memory_service.client.memory_manager import (
from
gpu_memory_service.client.rpc
import
_GMSRPCTransport
from
gpu_memory_service.client.session
import
_GMSClientSession
from
gpu_memory_service.common
import
cuda_utils
from
gpu_memory_service.common.locks
import
GrantedLockType
,
RequestedLockType
from
gpu_memory_service.common.protocol.messages
import
(
GetEventHistoryRequest
,
GetEventHistoryResponse
,
GetRuntimeStateRequest
,
GetRuntimeStateResponse
,
)
from
gpu_memory_service.common.types
import
(
GrantedLockType
,
RequestedLockType
,
ServerState
,
)
from
gpu_memory_service.server
import
allocations
as
server_allocations
from
gpu_memory_service.server.allocations
import
GMSAllocationManager
from
gpu_memory_service.server.fsm
import
ServerState
from
gpu_memory_service.server.rpc
import
GMSRPCServer
from
tests.gms.harness.gms
import
ServerThread
...
...
@@ -44,7 +41,8 @@ from tests.gms.harness.gms import ServerThread
pytestmark
=
[
pytest
.
mark
.
pre_merge
,
pytest
.
mark
.
unit
,
pytest
.
mark
.
gpu_0
,
pytest
.
mark
.
none
,
pytest
.
mark
.
gpu_1
,
]
...
...
tests/gms/common/test_gms_server_transport_failures.py
View file @
39a6a240
...
...
@@ -15,6 +15,7 @@ from dataclasses import dataclass
import
pytest
from
gpu_memory_service.common
import
cuda_utils
from
gpu_memory_service.common.locks
import
GrantedLockType
,
RequestedLockType
from
gpu_memory_service.common.protocol.messages
import
(
CommitRequest
,
CommitResponse
,
...
...
@@ -24,13 +25,8 @@ from gpu_memory_service.common.protocol.messages import (
GetRuntimeStateRequest
,
HandshakeRequest
,
)
from
gpu_memory_service.common.types
import
(
GrantedLockType
,
RequestedLockType
,
ServerState
,
StateEvent
,
)
from
gpu_memory_service.server.allocations
import
GMSAllocationManager
from
gpu_memory_service.server.fsm
import
ServerState
,
StateEvent
from
gpu_memory_service.server.gms
import
GMS
from
gpu_memory_service.server.rpc
import
GMSRPCServer
,
_is_connection_alive
from
gpu_memory_service.server.session
import
(
...
...
@@ -46,7 +42,8 @@ from cuda.bindings import driver as cuda # noqa: E402
pytestmark
=
[
pytest
.
mark
.
pre_merge
,
pytest
.
mark
.
unit
,
pytest
.
mark
.
gpu_0
,
pytest
.
mark
.
none
,
pytest
.
mark
.
gpu_1
,
]
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment