Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
f3b181a9
Unverified
Commit
f3b181a9
authored
Apr 14, 2026
by
Schwinn Saereesitthipitak
Committed by
GitHub
Apr 14, 2026
Browse files
feat(gms): operator-managed GMS checkpoint/restore support (#8153)
parent
091cdb51
Changes
49
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
1237 additions
and
11 deletions
+1237
-11
lib/gpu_memory_service/integrations/common/patches.py
lib/gpu_memory_service/integrations/common/patches.py
+7
-8
lib/gpu_memory_service/pyproject.toml
lib/gpu_memory_service/pyproject.toml
+1
-0
lib/gpu_memory_service/setup.py
lib/gpu_memory_service/setup.py
+6
-0
lib/gpu_memory_service/snapshot/__init__.py
lib/gpu_memory_service/snapshot/__init__.py
+2
-0
lib/gpu_memory_service/snapshot/disk.py
lib/gpu_memory_service/snapshot/disk.py
+445
-0
lib/gpu_memory_service/snapshot/model.py
lib/gpu_memory_service/snapshot/model.py
+68
-0
lib/gpu_memory_service/snapshot/restore.py
lib/gpu_memory_service/snapshot/restore.py
+69
-0
lib/gpu_memory_service/snapshot/storage_client.py
lib/gpu_memory_service/snapshot/storage_client.py
+639
-0
lib/gpu_memory_service/tests/test_runtime_flows.py
lib/gpu_memory_service/tests/test_runtime_flows.py
+0
-3
No files found.
lib/gpu_memory_service/integrations/common/patches.py
View file @
f3b181a9
...
...
@@ -32,16 +32,15 @@ def patch_empty_cache() -> None:
_original_empty_cache
=
torch
.
cuda
.
empty_cache
def
safe_empty_cache
()
->
None
:
active_mapping_count
=
sum
(
1
# Allow empty_cache when all managers are unmapped (sleep/checkpoint)
# or when there are no active VMM mappings with live handles.
has_live_mappings
=
any
(
any
(
m
.
handle
!=
0
for
m
in
manager
.
mappings
.
values
())
for
manager
in
get_gms_client_memory_managers
()
for
mapping
in
manager
.
mappings
.
values
()
if
mapping
.
handle
!=
0
)
if
active_mapping_count
:
logger
.
warning
(
"[GMS] Skipping torch.cuda.empty_cache() - %d active GMS mappings"
,
active_mapping_count
,
if
has_live_mappings
:
logger
.
debug
(
"[GMS] Skipping torch.cuda.empty_cache() - live VMM mappings active"
,
)
return
_original_empty_cache
()
...
...
lib/gpu_memory_service/pyproject.toml
View file @
f3b181a9
...
...
@@ -36,6 +36,7 @@ keywords = ["llm", "genai", "inference", "nvidia", "gpu", "memory", "dynamo"]
[project.scripts]
gpu-memory-service
=
"gpu_memory_service.cli.runner:main"
gms-storage-client
=
"gpu_memory_service.cli.storage_runner:main"
[project.optional-dependencies]
test
=
[
...
...
lib/gpu_memory_service/setup.py
View file @
f3b181a9
...
...
@@ -98,6 +98,12 @@ setup(
package_data
=
{
"gpu_memory_service.client.torch.extensions"
:
[
"*.cpp"
],
},
entry_points
=
{
"console_scripts"
:
[
"gpu-memory-service=gpu_memory_service.cli.runner:main"
,
"gms-storage-client=gpu_memory_service.cli.storage_runner:main"
,
]
},
ext_modules
=
_create_ext_modules
(),
cmdclass
=
{
"build_ext"
:
BuildExtension
},
)
lib/gpu_memory_service/snapshot/__init__.py
0 → 100644
View file @
f3b181a9
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
lib/gpu_memory_service/snapshot/disk.py
0 → 100644
View file @
f3b181a9
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from
__future__
import
annotations
import
base64
import
errno
import
json
import
os
import
queue
import
threading
from
collections
import
defaultdict
from
concurrent.futures
import
CancelledError
,
ThreadPoolExecutor
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Tuple
if
TYPE_CHECKING
:
import
torch
from
gpu_memory_service.snapshot.model
import
AllocationEntry
,
SaveManifest
class
ShardWriter
:
"""Packs allocation bytes sequentially into large binary shard files.
This is a single-threaded utility for streaming writes. The parallel save
path in GMSStorageClient._write_shards assigns allocations to shards via
plan_shard_layout and writes each shard file concurrently, so it does not
use ShardWriter directly. ShardWriter is kept as a public utility for
callers that want a simple sequential writer.
"""
def
__init__
(
self
,
shards_dir
:
str
,
shard_size_bytes
:
int
=
4
*
1024
**
3
)
->
None
:
self
.
_shards_dir
=
shards_dir
self
.
_shard_size
=
shard_size_bytes
self
.
_shard_idx
=
-
1
self
.
_current_offset
=
0
self
.
_current_file
:
Optional
[
Any
]
=
None
self
.
_current_rel_path
:
str
=
""
os
.
makedirs
(
shards_dir
,
exist_ok
=
True
)
def
_roll_shard
(
self
)
->
None
:
if
self
.
_current_file
is
not
None
:
self
.
_current_file
.
close
()
self
.
_shard_idx
+=
1
filename
=
f
"shard_
{
self
.
_shard_idx
:
04
d
}
.bin"
abs_path
=
os
.
path
.
join
(
self
.
_shards_dir
,
filename
)
self
.
_current_file
=
open
(
abs_path
,
"wb"
)
self
.
_current_rel_path
=
os
.
path
.
join
(
"shards"
,
filename
)
self
.
_current_offset
=
0
def
write
(
self
,
tensor
:
torch
.
Tensor
)
->
Tuple
[
str
,
int
]:
cpu
=
tensor
.
cpu
()
if
hasattr
(
tensor
,
"is_cuda"
)
and
tensor
.
is_cuda
else
tensor
if
hasattr
(
cpu
,
"is_contiguous"
)
and
not
cpu
.
is_contiguous
():
cpu
=
cpu
.
contiguous
()
arr
=
cpu
.
numpy
()
size
=
arr
.
nbytes
if
self
.
_current_file
is
None
or
(
self
.
_current_offset
>
0
and
self
.
_current_offset
+
size
>
self
.
_shard_size
):
self
.
_roll_shard
()
offset
=
self
.
_current_offset
arr
.
tofile
(
self
.
_current_file
)
self
.
_current_offset
+=
size
return
self
.
_current_rel_path
,
offset
def
close
(
self
)
->
None
:
if
self
.
_current_file
is
not
None
:
self
.
_current_file
.
close
()
self
.
_current_file
=
None
def
__enter__
(
self
)
->
"ShardWriter"
:
return
self
def
__exit__
(
self
,
*
_
:
Any
)
->
None
:
self
.
close
()
def
read_shard_sequential
(
abs_path
:
str
,
sorted_entries
:
List
[
AllocationEntry
],
device
:
int
,
*
,
pin_memory
:
bool
=
False
,
os_module
=
os
,
np_module
=
None
,
torch_module
=
None
,
logger
=
None
,
)
->
Dict
[
str
,
torch
.
Tensor
]:
"""Read one shard file front-to-back without seeking."""
if
np_module
is
None
or
torch_module
is
None
:
raise
RuntimeError
(
"numpy and torch modules are required to read shards"
)
result
:
Dict
[
str
,
torch
.
Tensor
]
=
{}
device_str
=
f
"cuda:
{
device
}
"
if
device
>=
0
else
"cpu"
if
abs_path
.
endswith
(
".pt"
):
if
len
(
sorted_entries
)
!=
1
:
raise
RuntimeError
(
f
"Expected exactly 1 entry for legacy .pt file, got "
f
"
{
len
(
sorted_entries
)
}
:
{
abs_path
}
"
)
entry
=
sorted_entries
[
0
]
result
[
entry
.
allocation_id
]
=
torch_module
.
load
(
abs_path
,
weights_only
=
True
,
map_location
=
device_str
,
)
return
result
odirect_flag
=
getattr
(
os_module
,
"O_DIRECT"
,
None
)
if
odirect_flag
is
not
None
:
fd
:
Optional
[
int
]
=
None
done
=
0
try
:
total_size
=
sum
(
entry
.
aligned_size
for
entry
in
sorted_entries
)
# Avoid torch.empty(pin_memory=True): cudaHostAlloc is ~1-3 s/GiB
# and dominates wall time. Plain numpy gives good throughput since
# PCIe H2D bandwidth far exceeds network disk bandwidth.
shard_t
=
None
arr
=
np_module
.
empty
(
total_size
,
dtype
=
np_module
.
uint8
)
fd
=
os_module
.
open
(
abs_path
,
os_module
.
O_RDONLY
|
odirect_flag
)
try
:
mv
=
memoryview
(
arr
)
try
:
while
done
<
total_size
:
read
=
os_module
.
readv
(
fd
,
[
mv
[
done
:]])
if
read
==
0
:
raise
RuntimeError
(
f
"Unexpected EOF in O_DIRECT read from
{
abs_path
}
: "
f
"got
{
done
}
of
{
total_size
}
bytes"
)
done
+=
read
finally
:
mv
.
release
()
finally
:
os_module
.
close
(
fd
)
offset
=
0
for
entry
in
sorted_entries
:
size
=
entry
.
aligned_size
if
shard_t
is
not
None
:
tensor
=
shard_t
[
offset
:
offset
+
size
]
else
:
tensor
=
torch_module
.
from_numpy
(
arr
[
offset
:
offset
+
size
])
if
device
>=
0
:
tensor
=
tensor
.
to
(
device_str
)
result
[
entry
.
allocation_id
]
=
tensor
offset
+=
size
return
result
except
OSError
as
exc
:
fallback_errnos
=
{
errno
.
EINVAL
,
errno
.
EOPNOTSUPP
}
if
fd
is
not
None
and
exc
.
errno
not
in
fallback_errnos
:
raise
result
.
clear
()
if
logger
is
not
None
:
if
fd
is
None
:
logger
.
debug
(
"O_DIRECT unsupported on %s (errno %s); using buffered reads"
,
abs_path
,
exc
.
errno
,
)
else
:
logger
.
debug
(
"O_DIRECT read on %s hit EINVAL after %d/%d bytes; using buffered reads"
,
abs_path
,
done
,
total_size
,
)
if
sorted_entries
and
sorted_entries
[
0
].
tensor_offset
!=
0
:
raise
RuntimeError
(
f
"Buffered shard read requires entries starting at offset 0, "
f
"got
{
sorted_entries
[
0
].
tensor_offset
}
in
{
abs_path
}
"
)
with
open
(
abs_path
,
"rb"
)
as
handle
:
for
entry
in
sorted_entries
:
raw
=
handle
.
read
(
entry
.
aligned_size
)
if
len
(
raw
)
!=
entry
.
aligned_size
:
raise
RuntimeError
(
f
"Short read from
{
abs_path
}
at offset
{
entry
.
tensor_offset
}
: "
f
"expected
{
entry
.
aligned_size
}
bytes, got
{
len
(
raw
)
}
"
)
arr
=
np_module
.
frombuffer
(
raw
,
dtype
=
np_module
.
uint8
).
copy
()
tensor
=
torch_module
.
from_numpy
(
arr
)
if
device
>=
0
:
tensor
=
tensor
.
to
(
device_str
)
result
[
entry
.
allocation_id
]
=
tensor
return
result
def
decode_metadata
(
raw_meta
:
Dict
[
str
,
Any
])
->
Dict
[
str
,
Dict
[
str
,
Any
]]:
return
{
key
:
{
"allocation_id"
:
entry
[
"allocation_id"
],
"offset_bytes"
:
int
(
entry
[
"offset_bytes"
]),
"value"
:
base64
.
b64decode
(
entry
[
"value"
]),
}
for
key
,
entry
in
raw_meta
.
items
()
}
def
group_entries_by_shard
(
allocations
:
List
[
AllocationEntry
],
)
->
Dict
[
str
,
List
[
AllocationEntry
]]:
groups
:
Dict
[
str
,
List
[
AllocationEntry
]]
=
defaultdict
(
list
)
for
entry
in
allocations
:
groups
[
entry
.
tensor_file
].
append
(
entry
)
for
entries
in
groups
.
values
():
entries
.
sort
(
key
=
lambda
entry
:
entry
.
tensor_offset
)
return
dict
(
groups
)
def
plan_shard_layout
(
allocations_info
:
List
[
Dict
[
str
,
Any
]],
shard_size_bytes
:
int
,
)
->
List
[
Tuple
[
int
,
int
]]:
result
:
List
[
Tuple
[
int
,
int
]]
=
[]
shard_idx
=
-
1
current_offset
=
0
started
=
False
for
alloc
in
allocations_info
:
size
=
int
(
alloc
[
"aligned_size"
])
if
not
started
or
(
current_offset
>
0
and
current_offset
+
size
>
shard_size_bytes
):
shard_idx
+=
1
current_offset
=
0
started
=
True
result
.
append
((
shard_idx
,
current_offset
))
current_offset
+=
size
return
result
def
_put_entry
(
work_q
:
queue
.
Queue
[
Optional
[
Tuple
[
AllocationEntry
,
"torch.Tensor"
]]],
entry
:
AllocationEntry
,
tensor
:
"torch.Tensor"
,
cancel_event
:
Optional
[
threading
.
Event
],
abs_path
:
str
,
)
->
None
:
"""Put one entry into the work queue, respecting cancellation."""
while
True
:
if
cancel_event
is
not
None
and
cancel_event
.
is_set
():
raise
CancelledError
(
f
"shard read cancelled:
{
abs_path
}
"
)
try
:
work_q
.
put
((
entry
,
tensor
),
timeout
=
0.1
)
return
except
queue
.
Full
:
pass
# 64 MiB chunks for parallel preadv — gives high effective iodepth on NFS
# while keeping each syscall large enough to amortize overhead.
_CHUNK_SIZE
=
64
*
1024
*
1024
# How many preadv calls to keep in-flight per shard. On Vast NFS each
# outstanding preadv becomes a separate NFS READ RPC, so higher iodepth
# means more network-level parallelism from a single file descriptor.
_IO_DEPTH
=
16
def
_preadv_chunk
(
fd
:
int
,
buf
:
memoryview
,
file_offset
:
int
,
size
:
int
,
os_module
,
)
->
None
:
"""Read exactly *size* bytes from *fd* at *file_offset* into *buf*."""
done
=
0
while
done
<
size
:
n
=
os_module
.
preadv
(
fd
,
[
buf
[
done
:
size
]],
file_offset
+
done
)
if
n
==
0
:
raise
RuntimeError
(
f
"Unexpected EOF in preadv at offset
{
file_offset
+
done
}
"
)
done
+=
n
def
read_shard_streaming_to_queue
(
abs_path
:
str
,
sorted_entries
:
List
[
AllocationEntry
],
work_q
:
queue
.
Queue
[
Optional
[
Tuple
[
AllocationEntry
,
"torch.Tensor"
]]],
*
,
pin_memory
:
bool
,
cancel_event
:
Optional
[
threading
.
Event
]
=
None
,
os_module
=
os
,
np_module
=
None
,
torch_module
=
None
,
logger
=
None
,
)
->
int
:
"""Read a shard via parallel O_DIRECT preadv calls, streaming entries
to *work_q* as they become readable.
Multiple chunks are read concurrently from different file offsets to
achieve high effective I/O depth on network filesystems (e.g. Vast NFS)
where single-threaded synchronous reads severely under-utilize bandwidth.
"""
if
not
sorted_entries
:
return
0
if
np_module
is
None
or
torch_module
is
None
:
raise
RuntimeError
(
"numpy and torch modules are required"
)
total_size
=
sum
(
e
.
aligned_size
for
e
in
sorted_entries
)
# Allocate a buffer for the whole shard. We intentionally avoid
# torch.empty(pin_memory=True) because cudaHostAlloc is extremely
# slow (~1-3 s per GiB) and dominates wall time for large shards.
# A plain numpy buffer still gives good H2D throughput (the copy is
# synchronous but PCIe bandwidth ≫ disk bandwidth).
shard_t
=
None
shard_arr
=
np_module
.
empty
(
total_size
,
dtype
=
np_module
.
uint8
)
odirect_flag
=
getattr
(
os_module
,
"O_DIRECT"
,
None
)
preadv_fn
=
getattr
(
os_module
,
"preadv"
,
None
)
if
odirect_flag
is
not
None
and
preadv_fn
is
not
None
:
fd
:
Optional
[
int
]
=
None
io_pool
:
Optional
[
ThreadPoolExecutor
]
=
None
try
:
fd
=
os_module
.
open
(
abs_path
,
os_module
.
O_RDONLY
|
odirect_flag
)
mv
=
memoryview
(
shard_arr
)
# Build aligned chunk list covering the full shard.
chunk_size
=
_CHUNK_SIZE
chunks
:
List
[
Tuple
[
int
,
int
]]
=
[]
# (offset, size)
off
=
0
while
off
<
total_size
:
sz
=
min
(
chunk_size
,
total_size
-
off
)
chunks
.
append
((
off
,
sz
))
off
+=
sz
# chunks_done[i] is set when chunk i finishes (success or error).
chunks_done
=
[
threading
.
Event
()
for
_
in
chunks
]
chunk_errors
:
List
[
BaseException
]
=
[]
def
_read_chunk
(
idx
:
int
)
->
None
:
try
:
c_off
,
c_sz
=
chunks
[
idx
]
_preadv_chunk
(
fd
,
mv
[
c_off
:
c_off
+
c_sz
],
c_off
,
c_sz
,
os_module
)
except
BaseException
as
exc
:
chunk_errors
.
append
(
exc
)
finally
:
chunks_done
[
idx
].
set
()
# Submit chunk reads with bounded concurrency.
io_pool
=
ThreadPoolExecutor
(
max_workers
=
min
(
_IO_DEPTH
,
len
(
chunks
)))
for
i
in
range
(
len
(
chunks
)):
io_pool
.
submit
(
_read_chunk
,
i
)
# Stream entries to the work queue as their data arrives.
def
_chunk_for_byte
(
byte_off
:
int
)
->
int
:
return
byte_off
//
chunk_size
for
entry_idx
in
range
(
len
(
sorted_entries
)):
if
cancel_event
is
not
None
and
cancel_event
.
is_set
():
raise
CancelledError
(
f
"shard read cancelled:
{
abs_path
}
"
)
entry
=
sorted_entries
[
entry_idx
]
start_chunk
=
_chunk_for_byte
(
entry
.
tensor_offset
)
end_chunk
=
_chunk_for_byte
(
entry
.
tensor_offset
+
entry
.
aligned_size
-
1
)
for
ci
in
range
(
start_chunk
,
end_chunk
+
1
):
chunks_done
[
ci
].
wait
()
if
chunk_errors
:
raise
chunk_errors
[
0
]
eoff
=
entry
.
tensor_offset
if
shard_t
is
not
None
:
tensor
=
shard_t
[
eoff
:
eoff
+
entry
.
aligned_size
]
else
:
tensor
=
torch_module
.
from_numpy
(
shard_arr
[
eoff
:
eoff
+
entry
.
aligned_size
]
)
_put_entry
(
work_q
,
entry
,
tensor
,
cancel_event
,
abs_path
)
if
chunk_errors
:
raise
chunk_errors
[
0
]
return
len
(
sorted_entries
)
except
OSError
as
exc
:
fallback_errnos
=
{
errno
.
EINVAL
,
errno
.
EOPNOTSUPP
}
if
exc
.
errno
not
in
fallback_errnos
:
raise
if
logger
is
not
None
:
logger
.
debug
(
"O_DIRECT preadv failed on %s (errno %s); "
"falling back to buffered read"
,
abs_path
,
exc
.
errno
,
)
finally
:
if
io_pool
is
not
None
:
io_pool
.
shutdown
(
wait
=
False
)
io_pool
=
None
if
fd
is
not
None
:
os_module
.
close
(
fd
)
fd
=
None
# Fallback: buffered full-shard read, then queue all entries.
with
open
(
abs_path
,
"rb"
)
as
handle
:
raw
=
handle
.
read
()
arr
=
np_module
.
frombuffer
(
raw
,
dtype
=
np_module
.
uint8
).
copy
()
for
entry
in
sorted_entries
:
off
=
entry
.
tensor_offset
tensor
=
torch_module
.
from_numpy
(
arr
[
off
:
off
+
entry
.
aligned_size
])
_put_entry
(
work_q
,
entry
,
tensor
,
cancel_event
,
abs_path
)
return
len
(
sorted_entries
)
def
read_shard_to_queue
(
abs_path
:
str
,
sorted_entries
:
List
[
AllocationEntry
],
work_q
:
queue
.
Queue
[
Optional
[
Tuple
[
AllocationEntry
,
torch
.
Tensor
]]],
*
,
pin_memory
:
bool
,
read_shard
,
cancel_event
:
Optional
[
threading
.
Event
]
=
None
,
)
->
int
:
shard_result
=
read_shard
(
abs_path
,
sorted_entries
,
-
1
,
pin_memory
=
pin_memory
,
)
for
entry
in
sorted_entries
:
_put_entry
(
work_q
,
entry
,
shard_result
[
entry
.
allocation_id
],
cancel_event
,
abs_path
)
return
len
(
sorted_entries
)
def
load_manifest_and_metadata
(
input_dir
:
str
,
)
->
Tuple
[
SaveManifest
,
Dict
[
str
,
Dict
[
str
,
Any
]]]:
manifest_path
=
os
.
path
.
join
(
input_dir
,
"manifest.json"
)
with
open
(
manifest_path
,
encoding
=
"utf-8"
)
as
handle
:
manifest
=
SaveManifest
.
from_dict
(
json
.
load
(
handle
))
metadata_path
=
os
.
path
.
join
(
input_dir
,
"gms_metadata.json"
)
raw_meta
:
Dict
[
str
,
Any
]
=
{}
if
os
.
path
.
exists
(
metadata_path
):
with
open
(
metadata_path
,
encoding
=
"utf-8"
)
as
handle
:
raw_meta
=
json
.
load
(
handle
)
return
manifest
,
decode_metadata
(
raw_meta
)
lib/gpu_memory_service/snapshot/model.py
0 → 100644
View file @
f3b181a9
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from
__future__
import
annotations
from
dataclasses
import
asdict
,
dataclass
,
field
from
typing
import
Any
,
Dict
,
List
CURRENT_VERSION
=
"1.0"
@
dataclass
(
frozen
=
True
)
class
AllocationEntry
:
"""Immutable record of one dumped allocation."""
allocation_id
:
str
size
:
int
aligned_size
:
int
tag
:
str
tensor_file
:
str
tensor_offset
:
int
=
0
@
dataclass
class
SaveManifest
:
"""Manifest for a GMS dump directory."""
version
:
str
timestamp
:
float
layout_hash
:
str
device
:
int
allocations
:
List
[
AllocationEntry
]
=
field
(
default_factory
=
list
)
def
to_dict
(
self
)
->
Dict
[
str
,
Any
]:
return
{
"version"
:
self
.
version
,
"timestamp"
:
self
.
timestamp
,
"layout_hash"
:
self
.
layout_hash
,
"device"
:
self
.
device
,
"allocations"
:
[
asdict
(
a
)
for
a
in
self
.
allocations
],
}
@
classmethod
def
from_dict
(
cls
,
payload
:
Dict
[
str
,
Any
])
->
"SaveManifest"
:
version
=
payload
[
"version"
]
if
version
!=
CURRENT_VERSION
:
raise
ValueError
(
f
"Unsupported manifest version
{
version
!
r
}
"
f
"(expected
{
CURRENT_VERSION
!
r
}
)"
)
allocations
=
[
AllocationEntry
(
allocation_id
=
entry
[
"allocation_id"
],
size
=
entry
[
"size"
],
aligned_size
=
entry
[
"aligned_size"
],
tag
=
entry
[
"tag"
],
tensor_file
=
entry
[
"tensor_file"
],
tensor_offset
=
entry
.
get
(
"tensor_offset"
,
0
),
)
for
entry
in
payload
.
get
(
"allocations"
,
[])
]
return
cls
(
version
=
payload
[
"version"
],
timestamp
=
payload
[
"timestamp"
],
layout_hash
=
payload
[
"layout_hash"
],
device
=
payload
[
"device"
],
allocations
=
allocations
,
)
lib/gpu_memory_service/snapshot/restore.py
0 → 100644
View file @
f3b181a9
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from
__future__
import
annotations
import
queue
import
threading
from
concurrent.futures
import
Future
,
ThreadPoolExecutor
from
dataclasses
import
dataclass
,
field
from
typing
import
TYPE_CHECKING
,
Dict
,
List
,
Optional
,
Tuple
if
TYPE_CHECKING
:
import
torch
from
gpu_memory_service.snapshot.model
import
AllocationEntry
WORK_QUEUE_DEPTH_MULTIPLIER
=
4
@
dataclass
class
RestorePipelineContext
:
"""Mutable state shared across disk, copy, and Phase A restore stages."""
worker_count
:
int
use_streams
:
bool
device
:
int
work_q
:
queue
.
Queue
[
Optional
[
Tuple
[
AllocationEntry
,
torch
.
Tensor
]]]
va_events
:
Dict
[
str
,
threading
.
Event
]
streams
:
List
[
torch
.
cuda
.
Stream
]
cancel_event
:
threading
.
Event
=
field
(
default_factory
=
threading
.
Event
)
vas
:
Dict
[
str
,
int
]
=
field
(
default_factory
=
dict
)
staged_srcs
:
List
[
torch
.
Tensor
]
=
field
(
default_factory
=
list
)
copy_errors
:
List
[
BaseException
]
=
field
(
default_factory
=
list
)
lock
:
threading
.
Lock
=
field
(
default_factory
=
threading
.
Lock
)
@
classmethod
def
build
(
cls
,
allocations
:
List
[
AllocationEntry
],
worker_count
:
int
,
*
,
device
:
int
,
use_streams
:
bool
,
torch_module
,
)
->
"RestorePipelineContext"
:
streams
=
(
[
torch_module
.
cuda
.
Stream
(
device
=
device
)
for
_
in
range
(
worker_count
)]
if
use_streams
else
[]
)
return
cls
(
worker_count
=
worker_count
,
use_streams
=
use_streams
,
device
=
device
,
work_q
=
queue
.
Queue
(
maxsize
=
worker_count
*
WORK_QUEUE_DEPTH_MULTIPLIER
),
va_events
=
{
entry
.
allocation_id
:
threading
.
Event
()
for
entry
in
allocations
},
streams
=
streams
,
)
@
dataclass
class
RestorePipelineResources
:
"""Live restore pipeline resources that must be torn down together."""
ctx
:
RestorePipelineContext
disk_pool
:
ThreadPoolExecutor
disk_futures
:
Dict
[
Future
[
int
],
str
]
copy_threads
:
List
[
threading
.
Thread
]
active
:
bool
=
True
lib/gpu_memory_service/snapshot/storage_client.py
0 → 100644
View file @
f3b181a9
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""GMS storage client: save GMS state to disk and load it back."""
from
__future__
import
annotations
import
base64
import
json
import
logging
import
os
import
queue
import
threading
import
time
from
collections
import
defaultdict
from
concurrent.futures
import
CancelledError
,
Future
,
ThreadPoolExecutor
,
as_completed
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
import
numpy
as
np
from
gpu_memory_service.snapshot.disk
import
(
# noqa: F401 re-exported for external callers
ShardWriter
as
_ShardWriter
,
)
from
gpu_memory_service.snapshot.disk
import
decode_metadata
as
_decode_metadata_impl
from
gpu_memory_service.snapshot.disk
import
(
group_entries_by_shard
as
_group_entries_by_shard_impl
,
)
from
gpu_memory_service.snapshot.disk
import
(
load_manifest_and_metadata
as
_load_manifest_and_metadata_impl
,
)
from
gpu_memory_service.snapshot.disk
import
(
plan_shard_layout
as
_plan_shard_layout_impl
,
)
from
gpu_memory_service.snapshot.disk
import
(
read_shard_sequential
as
_read_shard_sequential_impl
,
)
from
gpu_memory_service.snapshot.disk
import
(
read_shard_to_queue
as
_read_shard_to_queue_impl
,
)
from
gpu_memory_service.snapshot.model
import
CURRENT_VERSION
as
_CURRENT_VERSION
from
gpu_memory_service.snapshot.model
import
AllocationEntry
,
SaveManifest
from
gpu_memory_service.snapshot.restore
import
(
RestorePipelineContext
as
_RestorePipelineContext
,
)
from
gpu_memory_service.snapshot.restore
import
(
RestorePipelineResources
as
_RestorePipelineResources
,
)
logger
=
logging
.
getLogger
(
__name__
)
try
:
from
gpu_memory_service.client.memory_manager
import
GMSClientMemoryManager
from
gpu_memory_service.client.torch.tensor
import
_tensor_from_pointer
from
gpu_memory_service.common.locks
import
RequestedLockType
_GMS_IMPORTS_AVAILABLE
=
True
except
ImportError
:
_GMS_IMPORTS_AVAILABLE
=
False
GMSClientMemoryManager
=
None
# type: ignore[assignment,misc]
_tensor_from_pointer
=
None
# type: ignore[assignment]
RequestedLockType
=
None
# type: ignore[assignment]
try
:
import
torch
_TORCH_AVAILABLE
=
True
except
ImportError
:
_TORCH_AVAILABLE
=
False
torch
=
None
# type: ignore[assignment]
def
_read_shard_sequential
(
abs_path
:
str
,
sorted_entries
:
List
[
AllocationEntry
],
device
:
int
,
pin_memory
:
bool
=
False
,
)
->
Dict
[
str
,
"torch.Tensor"
]:
"""Facade wrapper kept for test patchability and backwards compatibility."""
return
_read_shard_sequential_impl
(
abs_path
,
sorted_entries
,
device
,
pin_memory
=
pin_memory
,
os_module
=
os
,
np_module
=
np
,
torch_module
=
torch
,
logger
=
logger
,
)
def
_decode_metadata
(
raw_meta
:
Dict
[
str
,
Any
])
->
Dict
[
str
,
Dict
[
str
,
Any
]]:
# Re-exported for external callers (e.g. multi_ssd_bench.py).
return
_decode_metadata_impl
(
raw_meta
)
def
_group_entries_by_shard
(
allocations
:
List
[
AllocationEntry
],
)
->
Dict
[
str
,
List
[
AllocationEntry
]]:
return
_group_entries_by_shard_impl
(
allocations
)
def
_allocation_record
(
alloc
:
Any
)
->
Dict
[
str
,
Any
]:
if
isinstance
(
alloc
,
dict
):
return
alloc
return
{
"allocation_id"
:
str
(
alloc
.
allocation_id
),
"size"
:
int
(
alloc
.
size
),
"aligned_size"
:
int
(
alloc
.
aligned_size
),
"tag"
:
str
(
alloc
.
tag
),
"layout_slot"
:
int
(
alloc
.
layout_slot
),
}
def
_plan_shard_layout
(
allocations_info
:
List
[
Dict
[
str
,
Any
]],
shard_size_bytes
:
int
,
)
->
List
[
Tuple
[
int
,
int
]]:
return
_plan_shard_layout_impl
(
allocations_info
,
shard_size_bytes
)
def
_read_shard_to_queue
(
abs_path
:
str
,
sorted_entries
:
List
[
AllocationEntry
],
work_q
:
"queue.Queue[Optional[Tuple[AllocationEntry, 'torch.Tensor']]]"
,
*
,
pin_memory
:
bool
,
cancel_event
:
Optional
[
threading
.
Event
]
=
None
,
)
->
int
:
return
_read_shard_to_queue_impl
(
abs_path
,
sorted_entries
,
work_q
,
pin_memory
=
pin_memory
,
read_shard
=
_read_shard_sequential
,
cancel_event
=
cancel_event
,
)
def
_load_manifest_and_metadata
(
input_dir
:
str
,
)
->
Tuple
[
SaveManifest
,
Dict
[
str
,
Dict
[
str
,
Any
]]]:
return
_load_manifest_and_metadata_impl
(
input_dir
)
class
GMSStorageClient
:
"""Dump and restore GMS state to/from disk."""
def
__init__
(
self
,
output_dir
:
Optional
[
str
]
=
None
,
socket_path
:
Optional
[
str
]
=
None
,
device
:
int
=
0
,
*
,
timeout_ms
:
Optional
[
int
]
=
None
,
shard_size_bytes
:
int
=
4
*
1024
**
3
,
)
->
None
:
self
.
output_dir
=
output_dir
self
.
device
=
device
self
.
_timeout_ms
=
timeout_ms
self
.
_shard_size
=
shard_size_bytes
if
socket_path
is
None
:
from
gpu_memory_service.common.utils
import
get_socket_path
socket_path
=
get_socket_path
(
device
)
self
.
_socket_path
=
socket_path
def
save
(
self
,
max_workers
:
int
=
4
)
->
SaveManifest
:
"""Connect to GMS in RO mode and save all allocations + metadata to disk."""
self
.
_validate_save_request
()
output_dir
,
shards_dir
=
self
.
_prepare_output_dir
()
mm
=
GMSClientMemoryManager
(
self
.
_socket_path
,
device
=
self
.
device
)
try
:
mm
.
connect
(
RequestedLockType
.
RO
,
timeout_ms
=
self
.
_timeout_ms
)
layout_hash
=
mm
.
get_memory_layout_hash
()
if
not
layout_hash
:
raise
RuntimeError
(
"GMS server has no committed weights; nothing to dump"
)
allocations_info
=
[
_allocation_record
(
alloc
)
for
alloc
in
mm
.
list_handles
()
]
va_list
=
self
.
_import_source_mappings
(
mm
,
allocations_info
)
entries
=
self
.
_write_shards
(
shards_dir
,
allocations_info
,
va_list
,
max_workers
=
max_workers
,
)
metadata
=
self
.
_save_metadata
(
mm
)
except
Exception
:
mm
.
close
(
best_effort
=
True
)
raise
self
.
_write_json
(
os
.
path
.
join
(
output_dir
,
"gms_metadata.json"
),
metadata
)
manifest
=
SaveManifest
(
version
=
_CURRENT_VERSION
,
timestamp
=
time
.
time
(),
layout_hash
=
layout_hash
,
device
=
self
.
device
,
allocations
=
entries
,
)
self
.
_write_json
(
os
.
path
.
join
(
output_dir
,
"manifest.json"
),
manifest
.
to_dict
())
logger
.
info
(
"Wrote manifest with %d allocations"
,
len
(
entries
))
# Best-effort cleanup; CUDA context may be invalid after
# checkpoint (cuda-checkpoint tears down device state).
mm
.
close
(
best_effort
=
True
)
return
manifest
def
_validate_save_request
(
self
)
->
None
:
if
not
_GMS_IMPORTS_AVAILABLE
:
raise
RuntimeError
(
"GMS client imports unavailable (missing cuda-python or torch)"
)
if
self
.
output_dir
is
None
:
raise
ValueError
(
"output_dir must be set to call save(); pass it to GMSStorageClient()"
)
def
_prepare_output_dir
(
self
)
->
Tuple
[
str
,
str
]:
assert
self
.
output_dir
is
not
None
os
.
makedirs
(
self
.
output_dir
,
exist_ok
=
True
)
shards_dir
=
os
.
path
.
join
(
self
.
output_dir
,
"shards"
)
os
.
makedirs
(
shards_dir
,
exist_ok
=
True
)
for
name
in
os
.
listdir
(
shards_dir
):
if
name
.
startswith
(
"shard_"
)
and
name
.
endswith
(
".bin"
):
os
.
unlink
(
os
.
path
.
join
(
shards_dir
,
name
))
return
self
.
output_dir
,
shards_dir
def
_import_source_mappings
(
self
,
mm
:
Any
,
allocations_info
:
List
[
Dict
[
str
,
Any
]],
)
->
List
[
int
]:
va_list
=
[
mm
.
create_mapping
(
allocation_id
=
alloc
[
"allocation_id"
])
for
alloc
in
allocations_info
]
logger
.
info
(
"Phase A complete: imported %d allocation VAs"
,
len
(
va_list
))
return
va_list
def
_write_shards
(
self
,
shards_dir
:
str
,
allocations_info
:
List
[
Dict
[
str
,
Any
]],
va_list
:
List
[
int
],
*
,
max_workers
:
int
,
)
->
List
[
AllocationEntry
]:
layout
=
_plan_shard_layout
(
allocations_info
,
self
.
_shard_size
)
shard_groups
:
Dict
[
int
,
List
[
Tuple
[
int
,
int
]]]
=
defaultdict
(
list
)
for
index
,
(
shard_idx
,
byte_offset
)
in
enumerate
(
layout
):
shard_groups
[
shard_idx
].
append
((
index
,
byte_offset
))
entries
:
List
[
Optional
[
AllocationEntry
]]
=
[
None
]
*
len
(
allocations_info
)
def
_write_one_shard
(
shard_idx
:
int
,
alloc_pairs
:
List
[
Tuple
[
int
,
int
]]
)
->
None
:
filename
=
f
"shard_
{
shard_idx
:
04
d
}
.bin"
abs_path
=
os
.
path
.
join
(
shards_dir
,
filename
)
rel_path
=
os
.
path
.
join
(
"shards"
,
filename
)
with
open
(
abs_path
,
"wb"
)
as
handle
:
for
index
,
byte_offset
in
alloc_pairs
:
alloc
=
allocations_info
[
index
]
aligned_size
=
int
(
alloc
[
"aligned_size"
])
tensor
=
_tensor_from_pointer
(
va_list
[
index
],
[
aligned_size
],
[
1
],
torch
.
uint8
,
self
.
device
,
)
tensor
.
cpu
().
numpy
().
tofile
(
handle
)
entries
[
index
]
=
AllocationEntry
(
allocation_id
=
alloc
[
"allocation_id"
],
size
=
int
(
alloc
[
"size"
]),
aligned_size
=
aligned_size
,
tag
=
str
(
alloc
.
get
(
"tag"
,
"default"
)),
tensor_file
=
rel_path
,
tensor_offset
=
byte_offset
,
)
with
ThreadPoolExecutor
(
max_workers
=
max_workers
)
as
pool
:
futures
=
{
pool
.
submit
(
_write_one_shard
,
shard_idx
,
alloc_pairs
):
shard_idx
for
shard_idx
,
alloc_pairs
in
shard_groups
.
items
()
}
for
future
in
as_completed
(
futures
):
future
.
result
()
missing
=
sum
(
1
for
entry
in
entries
if
entry
is
None
)
if
missing
:
raise
RuntimeError
(
f
"BUG:
{
missing
}
allocation(s) missing after shard writers completed"
)
logger
.
info
(
"Phase B complete: wrote %d shards"
,
len
(
shard_groups
))
return
[
entry
for
entry
in
entries
if
entry
is
not
None
]
def
_write_json
(
self
,
path
:
str
,
payload
:
Dict
[
str
,
Any
])
->
None
:
with
open
(
path
,
"w"
,
encoding
=
"utf-8"
)
as
handle
:
json
.
dump
(
payload
,
handle
,
indent
=
2
)
def
_run_restore_copy_worker
(
self
,
ctx
:
_RestorePipelineContext
,
stream_idx
:
int
,
)
->
None
:
while
True
:
try
:
item
=
ctx
.
work_q
.
get
(
timeout
=
0.1
)
except
queue
.
Empty
:
if
ctx
.
cancel_event
.
is_set
():
return
continue
if
item
is
None
:
return
entry
,
src
=
item
try
:
while
not
ctx
.
va_events
[
entry
.
allocation_id
].
wait
(
timeout
=
0.1
):
if
ctx
.
cancel_event
.
is_set
():
return
dst
=
_tensor_from_pointer
(
ctx
.
vas
[
entry
.
allocation_id
],
[
entry
.
aligned_size
],
[
1
],
torch
.
uint8
,
self
.
device
,
)
if
ctx
.
streams
:
with
torch
.
cuda
.
stream
(
ctx
.
streams
[
stream_idx
]):
dst
.
copy_
(
src
,
non_blocking
=
src
.
is_pinned
())
else
:
dst
.
copy_
(
src
)
if
ctx
.
use_streams
and
src
.
is_pinned
():
with
ctx
.
lock
:
ctx
.
staged_srcs
.
append
(
src
)
except
Exception
as
exc
:
# noqa: BLE001
with
ctx
.
lock
:
ctx
.
copy_errors
.
append
(
exc
)
def
_start_restore_copy_threads
(
self
,
ctx
:
_RestorePipelineContext
,
)
->
List
[
threading
.
Thread
]:
threads
=
[
threading
.
Thread
(
target
=
self
.
_run_restore_copy_worker
,
args
=
(
ctx
,
index
),
daemon
=
True
,
)
for
index
in
range
(
ctx
.
worker_count
)
]
for
thread
in
threads
:
thread
.
start
()
return
threads
def
_prepare_restore_pipeline
(
self
,
manifest
:
SaveManifest
,
groups
:
Dict
[
str
,
List
[
AllocationEntry
]],
worker_count
:
int
,
input_dir
:
str
,
)
->
_RestorePipelineResources
:
ctx
=
_RestorePipelineContext
.
build
(
manifest
.
allocations
,
worker_count
,
device
=
self
.
device
,
use_streams
=
_TORCH_AVAILABLE
and
torch
.
cuda
.
is_available
(),
torch_module
=
torch
,
)
copy_threads
=
self
.
_start_restore_copy_threads
(
ctx
)
disk_pool
=
ThreadPoolExecutor
(
max_workers
=
worker_count
)
disk_futures
=
{
disk_pool
.
submit
(
_read_shard_to_queue
,
os
.
path
.
join
(
input_dir
,
rel_path
),
sorted_entries
,
ctx
.
work_q
,
pin_memory
=
ctx
.
use_streams
,
cancel_event
=
ctx
.
cancel_event
,
):
rel_path
for
rel_path
,
sorted_entries
in
groups
.
items
()
}
return
_RestorePipelineResources
(
ctx
=
ctx
,
disk_pool
=
disk_pool
,
disk_futures
=
disk_futures
,
copy_threads
=
copy_threads
,
)
def
_allocate_restore_mappings
(
self
,
mm
:
Any
,
manifest
:
SaveManifest
,
ctx
:
_RestorePipelineContext
,
)
->
Dict
[
str
,
str
]:
id_map
:
Dict
[
str
,
str
]
=
{}
for
entry
in
manifest
.
allocations
:
old_id
=
entry
.
allocation_id
va
=
mm
.
create_mapping
(
size
=
entry
.
size
,
tag
=
entry
.
tag
)
id_map
[
old_id
]
=
mm
.
mappings
[
va
].
allocation_id
ctx
.
vas
[
old_id
]
=
va
ctx
.
va_events
[
old_id
].
set
()
logger
.
info
(
"Phase A complete: allocated %d GMS VAs; waiting for disk/copy pipeline"
,
len
(
ctx
.
vas
),
)
return
id_map
def
_await_disk_reads
(
self
,
disk_futures
:
Dict
[
Future
[
int
],
str
])
->
None
:
for
future
in
as_completed
(
disk_futures
):
rel_path
=
disk_futures
[
future
]
try
:
future
.
result
()
except
CancelledError
:
pass
except
Exception
as
exc
:
raise
RuntimeError
(
f
"Failed to load shard
{
rel_path
}
:
{
exc
}
"
)
from
exc
def
_stop_restore_copy_threads
(
self
,
ctx
:
_RestorePipelineContext
,
threads
:
List
[
threading
.
Thread
],
*
,
drain_queue
:
bool
=
False
,
)
->
None
:
if
drain_queue
:
self
.
_drain_restore_queue
(
ctx
)
for
_
in
threads
:
if
drain_queue
:
# Cancel path: workers may have exited, so drain to make room.
while
True
:
try
:
ctx
.
work_q
.
put
(
None
,
timeout
=
0.1
)
break
except
queue
.
Full
:
self
.
_drain_restore_queue
(
ctx
)
else
:
# Normal path: disk reads are done and workers are alive; block
# until a slot opens rather than spinning with a timeout.
ctx
.
work_q
.
put
(
None
)
for
thread
in
threads
:
thread
.
join
()
def
_drain_restore_queue
(
self
,
ctx
:
_RestorePipelineContext
)
->
None
:
while
True
:
try
:
ctx
.
work_q
.
get_nowait
()
except
queue
.
Empty
:
return
def
_cancel_restore_pipeline
(
self
,
ctx
:
_RestorePipelineContext
)
->
None
:
ctx
.
cancel_event
.
set
()
for
event
in
ctx
.
va_events
.
values
():
event
.
set
()
self
.
_drain_restore_queue
(
ctx
)
def
_finalize_restore_pipeline
(
self
,
ctx
:
_RestorePipelineContext
)
->
None
:
if
ctx
.
use_streams
:
torch
.
cuda
.
synchronize
(
device
=
self
.
device
)
ctx
.
staged_srcs
.
clear
()
if
ctx
.
copy_errors
:
raise
RuntimeError
(
f
"Failed to copy restored data to GMS:
{
ctx
.
copy_errors
[
0
]
}
"
)
def
_drain_restore_pipeline
(
self
,
resources
:
_RestorePipelineResources
)
->
None
:
disk_error
:
Optional
[
BaseException
]
=
None
finalize_error
:
Optional
[
BaseException
]
=
None
drain_queue
=
False
try
:
self
.
_await_disk_reads
(
resources
.
disk_futures
)
except
Exception
as
exc
:
disk_error
=
exc
self
.
_cancel_restore_pipeline
(
resources
.
ctx
)
drain_queue
=
True
resources
.
disk_pool
.
shutdown
(
wait
=
True
,
cancel_futures
=
True
)
else
:
resources
.
disk_pool
.
shutdown
(
wait
=
True
)
try
:
self
.
_stop_restore_copy_threads
(
resources
.
ctx
,
resources
.
copy_threads
,
drain_queue
=
drain_queue
,
)
finally
:
resources
.
active
=
False
try
:
self
.
_finalize_restore_pipeline
(
resources
.
ctx
)
except
Exception
as
exc
:
# noqa: BLE001
finalize_error
=
exc
if
disk_error
is
not
None
:
raise
disk_error
if
finalize_error
is
not
None
:
raise
finalize_error
def
_shutdown_restore_pipeline
(
self
,
resources
:
_RestorePipelineResources
,
)
->
None
:
if
not
resources
.
active
:
return
self
.
_cancel_restore_pipeline
(
resources
.
ctx
)
resources
.
disk_pool
.
shutdown
(
wait
=
True
,
cancel_futures
=
True
)
self
.
_stop_restore_copy_threads
(
resources
.
ctx
,
resources
.
copy_threads
,
drain_queue
=
True
,
)
resources
.
active
=
False
# Synchronize async copies to prevent use-after-free of staged pinned
# buffers, but suppress copy errors — the caller already has an error
# to propagate and we must not mask it.
try
:
self
.
_finalize_restore_pipeline
(
resources
.
ctx
)
except
Exception
:
# noqa: BLE001
self
.
_logger
.
warning
(
"cleanup failed during restore error handling"
,
exc_info
=
True
,
)
def
load_to_gms
(
self
,
input_dir
:
str
,
*
,
max_workers
:
int
=
4
,
clear_existing
:
bool
=
True
,
)
->
Dict
[
str
,
str
]:
if
not
_GMS_IMPORTS_AVAILABLE
:
raise
RuntimeError
(
"GMS client imports unavailable (missing cuda-python or torch)"
)
manifest
,
saved_metadata
=
_load_manifest_and_metadata
(
input_dir
)
groups
=
_group_entries_by_shard
(
manifest
.
allocations
)
worker_count
=
max
(
1
,
min
(
max_workers
,
len
(
groups
)
or
1
))
with
GMSClientMemoryManager
(
self
.
_socket_path
,
device
=
self
.
device
)
as
mm
:
mm
.
connect
(
RequestedLockType
.
RW
,
timeout_ms
=
self
.
_timeout_ms
)
if
clear_existing
:
logger
.
info
(
"RW connect cleared any previously committed GMS state"
)
resources
=
self
.
_prepare_restore_pipeline
(
manifest
,
groups
,
worker_count
,
input_dir
,
)
try
:
id_map
=
self
.
_allocate_restore_mappings
(
mm
,
manifest
,
resources
.
ctx
)
self
.
_drain_restore_pipeline
(
resources
)
except
Exception
:
self
.
_shutdown_restore_pipeline
(
resources
)
raise
logger
.
info
(
"Phase B complete: streamed %d allocations to GMS memory"
,
len
(
manifest
.
allocations
),
)
self
.
_restore_metadata
(
mm
,
saved_metadata
,
id_map
)
if
not
mm
.
commit
():
raise
RuntimeError
(
"GMS commit failed after restore"
)
logger
.
info
(
"load_to_gms complete: %d allocations, %d metadata keys"
,
len
(
id_map
),
len
(
saved_metadata
),
)
return
id_map
def
_restore_metadata
(
self
,
mm
:
Any
,
saved_metadata
:
Dict
[
str
,
Dict
[
str
,
Any
]],
id_map
:
Dict
[
str
,
str
],
)
->
None
:
for
key
,
meta
in
saved_metadata
.
items
():
old_alloc_id
=
meta
[
"allocation_id"
]
new_alloc_id
=
id_map
.
get
(
old_alloc_id
,
old_alloc_id
)
ok
=
mm
.
metadata_put
(
key
,
new_alloc_id
,
meta
[
"offset_bytes"
],
meta
[
"value"
])
if
not
ok
:
raise
RuntimeError
(
f
"Failed to write metadata key=
{
key
!
r
}
"
)
logger
.
debug
(
"Restored metadata key=%s -> alloc=%s"
,
key
,
new_alloc_id
)
logger
.
info
(
"Restored %d metadata keys; committing"
,
len
(
saved_metadata
))
@
staticmethod
def
load_tensors
(
input_dir
:
str
,
device
:
int
=
0
,
*
,
max_workers
:
int
=
4
,
)
->
Tuple
[
Dict
[
str
,
"torch.Tensor"
],
Dict
[
str
,
Dict
[
str
,
Any
]]]:
if
not
_TORCH_AVAILABLE
:
raise
RuntimeError
(
"PyTorch is required for load_tensors()"
)
manifest
,
metadata
=
_load_manifest_and_metadata
(
input_dir
)
groups
=
_group_entries_by_shard
(
manifest
.
allocations
)
tensors
:
Dict
[
str
,
"torch.Tensor"
]
=
{}
with
ThreadPoolExecutor
(
max_workers
=
max_workers
)
as
pool
:
futures
=
{
pool
.
submit
(
_read_shard_sequential
,
os
.
path
.
join
(
input_dir
,
rel_path
),
sorted_entries
,
device
,
):
rel_path
for
rel_path
,
sorted_entries
in
groups
.
items
()
}
for
future
in
as_completed
(
futures
):
rel_path
=
futures
[
future
]
try
:
tensors
.
update
(
future
.
result
())
except
Exception
as
exc
:
raise
RuntimeError
(
f
"Failed to load shard
{
rel_path
}
:
{
exc
}
"
)
from
exc
logger
.
info
(
"Loaded %d allocations from %s"
,
len
(
tensors
),
input_dir
)
return
tensors
,
metadata
def
_save_metadata
(
self
,
mm
:
Any
)
->
Dict
[
str
,
Any
]:
result
:
Dict
[
str
,
Any
]
=
{}
for
key
in
mm
.
metadata_list
():
got
=
mm
.
metadata_get
(
key
)
if
got
is
None
:
logger
.
warning
(
"Metadata key disappeared during dump: %s"
,
key
)
continue
allocation_id
,
offset_bytes
,
value
=
got
result
[
key
]
=
{
"allocation_id"
:
str
(
allocation_id
),
"offset_bytes"
:
int
(
offset_bytes
),
"value"
:
base64
.
b64encode
(
value
).
decode
(
"ascii"
),
}
return
result
lib/gpu_memory_service/tests/test_runtime_flows.py
View file @
f3b181a9
...
...
@@ -245,9 +245,6 @@ def running_gms(monkeypatch, tmp_path):
server_allocations
,
"cumem_export_to_shareable_handle"
,
export_fd
)
monkeypatch
.
setattr
(
client_memory_manager
,
"cuda_set_current_device"
,
lambda
device
:
None
)
monkeypatch
.
setattr
(
client_memory_manager
,
"cumem_get_allocation_granularity"
,
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment