Unverified commit b2f7f220, authored by Schwinn Saereesitthipitak and committed by GitHub
Browse files

fix(gms): rewrite GMS checkpoint/restore operator support (#8194)


Co-authored-by: Dmitry Tokarev <dtokarev@nvidia.com>
parent 2d86b81d
......@@ -98,12 +98,12 @@ func TestNewCheckpointJob(t *testing.T) {
}
}
func TestNewCheckpointJobPrefersContainerNamedMain(t *testing.T) {
func TestNewCheckpointJobWrapsFirstContainer(t *testing.T) {
job, err := NewCheckpointJob(&corev1.PodTemplateSpec{
Spec: corev1.PodSpec{
Containers: []corev1.Container{
{Name: "worker", Command: []string{"python3", "-m", "dynamo.vllm"}, Args: []string{"--model", "Qwen"}},
{Name: "sidecar", Command: []string{"sleep"}, Args: []string{"infinity"}},
{Name: "main", Command: []string{"python3", "-m", "dynamo.vllm"}, Args: []string{"--model", "Qwen"}},
},
},
}, CheckpointJobOptions{
......@@ -118,17 +118,17 @@ func TestNewCheckpointJobPrefersContainerNamedMain(t *testing.T) {
t.Fatalf("expected checkpoint job, got error: %v", err)
}
main := requireCheckpointContainer(t, job.Spec.Template.Spec.Containers, "main")
if len(main.Command) != 1 || main.Command[0] != "cuda-checkpoint" {
t.Fatalf("expected main container to be wrapped, got %#v", main.Command)
worker := requireCheckpointContainer(t, job.Spec.Template.Spec.Containers, "worker")
if len(worker.Command) != 1 || worker.Command[0] != "cuda-checkpoint" {
t.Fatalf("expected first container to be wrapped, got %#v", worker.Command)
}
expectedArgs := []string{"--launch-job", "python3", "-m", "dynamo.vllm", "--model", "Qwen"}
if len(main.Args) != len(expectedArgs) {
t.Fatalf("expected launch-job args %#v, got %#v", expectedArgs, main.Args)
if len(worker.Args) != len(expectedArgs) {
t.Fatalf("expected launch-job args %#v, got %#v", expectedArgs, worker.Args)
}
for i := range expectedArgs {
if main.Args[i] != expectedArgs[i] {
t.Fatalf("expected launch-job args %#v, got %#v", expectedArgs, main.Args)
if worker.Args[i] != expectedArgs[i] {
t.Fatalf("expected launch-job args %#v, got %#v", expectedArgs, worker.Args)
}
}
......@@ -168,26 +168,7 @@ func TestNewCheckpointJobAllowsSingleNonMainContainer(t *testing.T) {
}
}
func TestNewCheckpointJobRejectsMultiContainerPodWithoutMain(t *testing.T) {
_, err := NewCheckpointJob(&corev1.PodTemplateSpec{
Spec: corev1.PodSpec{
Containers: []corev1.Container{
{Name: "sidecar", Command: []string{"sleep"}, Args: []string{"infinity"}},
{Name: "worker", Command: []string{"python3", "-m", "dynamo.vllm"}},
},
},
}, CheckpointJobOptions{
Namespace: "test-ns",
CheckpointID: "hash",
ArtifactVersion: "2",
Name: "test-job",
TTLSecondsAfterFinish: ptr.To(int32(300)),
WrapLaunchJob: true,
})
if err == nil || err.Error() != "checkpoint job requires a container named \"main\" when multiple containers are present" {
t.Fatalf("expected missing main container error, got %v", err)
}
}
func TestGetCheckpointJobName(t *testing.T) {
name := GetCheckpointJobName("abc123def4567890", "2")
......
......@@ -49,19 +49,13 @@ func NewRestorePod(pod *corev1.Pod, opts PodOptions) *corev1.Pod {
return pod
}
// resolveWorkerContainer returns the workload container, which is always
// Containers[0]. GMS sidecars are appended after the workload.
func resolveWorkerContainer(podSpec *corev1.PodSpec) *corev1.Container {
if podSpec == nil {
if podSpec == nil || len(podSpec.Containers) == 0 {
return nil
}
if len(podSpec.Containers) == 1 {
return &podSpec.Containers[0]
}
for index := range podSpec.Containers {
if podSpec.Containers[index].Name == "main" {
return &podSpec.Containers[index]
}
}
return nil
return &podSpec.Containers[0]
}
func PrepareRestorePodSpec(
......@@ -73,7 +67,7 @@ func PrepareRestorePodSpec(
) {
EnsureLocalhostSeccompProfile(podSpec, seccompProfile)
if storage.PVCName != "" {
injectCheckpointVolume(podSpec, storage.PVCName)
InjectCheckpointVolume(podSpec, storage.PVCName)
}
if storage.BasePath != "" {
injectCheckpointVolumeMount(container, storage.BasePath)
......@@ -113,7 +107,7 @@ func ValidateRestorePodSpec(
}
container := resolveWorkerContainer(podSpec)
if container == nil {
return fmt.Errorf("restore target must include a worker container named main")
return fmt.Errorf("restore target must have at least one container")
}
if storage.PVCName != "" {
hasVolume := false
......@@ -209,19 +203,18 @@ func DiscoverStorageFromDaemonSets(namespace string, daemonSets []appsv1.DaemonS
)
}
func PrepareRestorePodSpecForCheckpoint(
// DiscoverAndResolveStorage lists snapshot-agent DaemonSets in the given
// namespace, discovers the shared storage configuration, and resolves the
// checkpoint-specific path for the given checkpoint ID and artifact version.
func DiscoverAndResolveStorage(
ctx context.Context,
reader ctrlclient.Reader,
namespace string,
podSpec *corev1.PodSpec,
container *corev1.Container,
checkpointID string,
artifactVersion string,
seccompProfile string,
isCheckpointReady bool,
) error {
) (Storage, error) {
if reader == nil {
return fmt.Errorf("snapshot client is required")
return Storage{}, fmt.Errorf("snapshot client is required")
}
daemonSets := &appsv1.DaemonSetList{}
......@@ -231,24 +224,41 @@ func PrepareRestorePodSpecForCheckpoint(
ctrlclient.InNamespace(namespace),
ctrlclient.MatchingLabels{SnapshotAgentLabelKey: SnapshotAgentLabelValue},
); err != nil {
return fmt.Errorf("list snapshot-agent daemonsets in %s: %w", namespace, err)
return Storage{}, fmt.Errorf("list snapshot-agent daemonsets in %s: %w", namespace, err)
}
storage, err := DiscoverStorageFromDaemonSets(namespace, daemonSets.Items)
if err != nil {
return err
return Storage{}, err
}
resolvedStorage, err := ResolveCheckpointStorage(checkpointID, artifactVersion, storage)
return ResolveCheckpointStorage(checkpointID, artifactVersion, storage)
}
func PrepareRestorePodSpecForCheckpoint(
ctx context.Context,
reader ctrlclient.Reader,
namespace string,
podSpec *corev1.PodSpec,
container *corev1.Container,
checkpointID string,
artifactVersion string,
seccompProfile string,
isCheckpointReady bool,
) error {
storage, err := DiscoverAndResolveStorage(ctx, reader, namespace, checkpointID, artifactVersion)
if err != nil {
return err
}
PrepareRestorePodSpec(podSpec, container, resolvedStorage, seccompProfile, isCheckpointReady)
PrepareRestorePodSpec(podSpec, container, storage, seccompProfile, isCheckpointReady)
return nil
}
func injectCheckpointVolume(podSpec *corev1.PodSpec, pvcName string) {
// InjectCheckpointVolume adds the checkpoint PVC volume to the pod spec if
// not already present. Used by both the snapshot protocol and the operator's
// GMS checkpoint wiring.
func InjectCheckpointVolume(podSpec *corev1.PodSpec, pvcName string) {
for _, volume := range podSpec.Volumes {
if volume.Name == CheckpointVolumeName {
return
......
......@@ -187,13 +187,13 @@ func TestPrepareRestorePodSpecSynthesizesStartupProbeFromLiveness(t *testing.T)
}
}
func TestNewRestorePodTargetsMainContainerWhenSidecarsPresent(t *testing.T) {
func TestNewRestorePodTargetsFirstContainerWhenSidecarsPresent(t *testing.T) {
restorePod := NewRestorePod(&corev1.Pod{
ObjectMeta: metav1.ObjectMeta{Name: "worker"},
Spec: corev1.PodSpec{
Containers: []corev1.Container{
{Name: "worker", Image: "test:latest", Command: []string{"python3"}, Args: []string{"-m", "dynamo.vllm"}},
{Name: "sidecar", Image: "sidecar:latest", Command: []string{"sidecar"}, Args: []string{"run"}},
{Name: "main", Image: "test:latest", Command: []string{"python3"}, Args: []string{"-m", "dynamo.vllm"}},
},
},
}, PodOptions{
......@@ -208,14 +208,14 @@ func TestNewRestorePodTargetsMainContainerWhenSidecarsPresent(t *testing.T) {
SeccompProfile: DefaultSeccompLocalhostProfile,
})
if got := restorePod.Spec.Containers[0].Command; len(got) != 1 || got[0] != "sidecar" {
t.Fatalf("expected sidecar command to remain unchanged, got %#v", got)
if got := restorePod.Spec.Containers[0].Command; len(got) != 2 || got[0] != "sleep" || got[1] != "infinity" {
t.Fatalf("expected first container placeholder command, got %#v", got)
}
if got := restorePod.Spec.Containers[1].Command; len(got) != 2 || got[0] != "sleep" || got[1] != "infinity" {
t.Fatalf("expected main container placeholder command, got %#v", got)
if restorePod.Spec.Containers[0].Args != nil {
t.Fatalf("expected first container args to be cleared: %#v", restorePod.Spec.Containers[0].Args)
}
if restorePod.Spec.Containers[1].Args != nil {
t.Fatalf("expected main container args to be cleared: %#v", restorePod.Spec.Containers[1].Args)
if got := restorePod.Spec.Containers[1].Command; len(got) != 1 || got[0] != "sidecar" {
t.Fatalf("expected sidecar command to remain unchanged, got %#v", got)
}
}
......@@ -311,7 +311,7 @@ func TestValidateRestorePodSpec(t *testing.T) {
}
}
func TestValidateRestorePodSpecRequiresMainContainerWhenMultiContainer(t *testing.T) {
func TestValidateRestorePodSpecAcceptsFirstContainerAsWorker(t *testing.T) {
profile := DefaultSeccompLocalhostProfile
podSpec := &corev1.PodSpec{
SecurityContext: &corev1.PodSecurityContext{
......@@ -346,12 +346,13 @@ func TestValidateRestorePodSpecRequiresMainContainerWhenMultiContainer(t *testin
BasePath: "/checkpoints",
}
if err := ValidateRestorePodSpec(podSpec, storage, DefaultSeccompLocalhostProfile); err == nil || err.Error() != "restore target must include a worker container named main" {
t.Fatalf("expected multi-container restore target without main to be rejected, got %v", err)
// Containers[0] is always the worker, regardless of name
if err := ValidateRestorePodSpec(podSpec, storage, DefaultSeccompLocalhostProfile); err != nil {
t.Fatalf("expected validation to pass for first container as worker, got %v", err)
}
}
func TestValidateRestorePodSpecAllowsMainContainerWithSidecars(t *testing.T) {
func TestValidateRestorePodSpecAllowsWorkerWithSidecars(t *testing.T) {
profile := DefaultSeccompLocalhostProfile
podSpec := &corev1.PodSpec{
SecurityContext: &corev1.PodSecurityContext{
......@@ -369,14 +370,14 @@ func TestValidateRestorePodSpecAllowsMainContainerWithSidecars(t *testing.T) {
},
}},
Containers: []corev1.Container{
{Name: "sidecar"},
{
Name: "main",
Name: "worker",
VolumeMounts: []corev1.VolumeMount{{
Name: CheckpointVolumeName,
MountPath: "/checkpoints",
}},
},
{Name: "sidecar"},
},
}
......@@ -387,7 +388,7 @@ func TestValidateRestorePodSpecAllowsMainContainerWithSidecars(t *testing.T) {
}
if err := ValidateRestorePodSpec(podSpec, storage, DefaultSeccompLocalhostProfile); err != nil {
t.Fatalf("expected main container with sidecars to validate, got %v", err)
t.Fatalf("expected worker with sidecars to validate, got %v", err)
}
}
......
......@@ -4,8 +4,8 @@
"""GMS server entry point.
Launches two GMS server processes per GPU (one for weights, one for kv_cache).
Writes a ready file once all expected UDS sockets are present. Monitors an
optional checkpoint stop file and shuts down cleanly when it appears.
Writes a ready file once all expected UDS sockets are present. Runs until
SIGTERM (pod termination kills it).
"""
from __future__ import annotations
......@@ -18,6 +18,7 @@ import sys
import time
from pathlib import Path
from gpu_memory_service.common.cuda_utils import list_devices
from gpu_memory_service.common.utils import get_socket_path
logging.basicConfig(
......@@ -30,36 +31,11 @@ _TAGS = ("weights", "kv_cache")
_READY_FILE = "gms-ready"
def _ready_file_path() -> Path:
return Path(os.environ.get("GMS_SOCKET_DIR", "/tmp")) / _READY_FILE
def _list_devices() -> list[int]:
import pynvml
pynvml.nvmlInit()
try:
count = pynvml.nvmlDeviceGetCount()
finally:
pynvml.nvmlShutdown()
if count == 0:
raise SystemExit("no nvidia devices found")
return list(range(count))
def _optional_checkpoint_stop_file() -> Path | None:
control_dir = os.environ.get("GMS_CONTROL_DIR")
if not control_dir:
return None
return Path(control_dir) / "checkpoint-done"
def main() -> None:
ready_file = _ready_file_path()
ready_file = Path(os.environ.get("GMS_SOCKET_DIR", "/tmp")) / _READY_FILE
ready_file.unlink(missing_ok=True)
devices = _list_devices()
devices = list_devices()
processes = []
for device in devices:
for tag in _TAGS:
......@@ -89,14 +65,8 @@ def main() -> None:
signal.signal(signal.SIGTERM, terminate)
signal.signal(signal.SIGINT, terminate)
stop_file = _optional_checkpoint_stop_file()
ready_written = False
while True:
stop_requested = stop_file is not None and stop_file.exists()
if stop_requested:
logger.info("checkpoint stop requested; shutting down GMS servers")
shutdown()
if not ready_written:
sockets_ready = all(
os.path.exists(get_socket_path(device, tag))
......@@ -113,8 +83,6 @@ def main() -> None:
if exit_code is None:
running = True
continue
if stop_requested:
continue
shutdown()
raise SystemExit(exit_code)
......
......@@ -12,10 +12,13 @@ from __future__ import annotations
import logging
import os
import signal
import sys
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from gpu_memory_service.common.utils import get_socket_path
from gpu_memory_service.common.cuda_utils import list_devices
from gpu_memory_service.common.utils import get_socket_path, wait_for_weights_socket
from gpu_memory_service.snapshot.storage_client import GMSStorageClient
logging.basicConfig(
......@@ -24,32 +27,9 @@ logging.basicConfig(
)
logger = logging.getLogger(__name__)
_WEIGHTS_TAG = "weights"
_DEFAULT_MAX_WORKERS = 8
def _list_devices() -> list[int]:
import pynvml
pynvml.nvmlInit()
try:
count = pynvml.nvmlDeviceGetCount()
finally:
pynvml.nvmlShutdown()
if count == 0:
raise SystemExit("no nvidia devices found")
return list(range(count))
def _wait_for_weights_socket(device: int) -> None:
socket_path = get_socket_path(device, _WEIGHTS_TAG)
while not os.path.exists(socket_path):
time.sleep(1)
def _load_device(checkpoint_dir: str, device: int, max_workers: int) -> None:
_wait_for_weights_socket(device)
wait_for_weights_socket(device)
input_dir = os.path.join(checkpoint_dir, f"device-{device}")
logger.info("Loading GMS checkpoint: device=%d input_dir=%s", device, input_dir)
t0 = time.monotonic()
......@@ -68,8 +48,8 @@ def _load_device(checkpoint_dir: str, device: int, max_workers: int) -> None:
def main() -> None:
checkpoint_dir = os.environ["GMS_CHECKPOINT_DIR"]
max_workers = int(os.environ.get("GMS_LOAD_WORKERS", str(_DEFAULT_MAX_WORKERS)))
devices = _list_devices()
max_workers = int(os.environ.get("GMS_LOAD_WORKERS", "8"))
devices = list_devices()
t0 = time.monotonic()
with ThreadPoolExecutor(max_workers=len(devices)) as pool:
......@@ -83,6 +63,7 @@ def main() -> None:
logger.info("Device %d load complete", dev)
elapsed = time.monotonic() - t0
logger.info("All %d devices loaded in %.2fs", len(devices), elapsed)
signal.signal(signal.SIGTERM, lambda *_: sys.exit(0))
while True:
time.sleep(3600)
......
......@@ -3,25 +3,22 @@
"""GMS checkpoint saver entry point.
Waits for the checkpoint pod to reach Ready=True, then saves GMS state from
each device in parallel. Writes a stop file to signal the GMS server to shut
down after save completes.
Waits for committed GMS weights on each device, then saves GPU memory state
to the checkpoint directory. Runs as an init sidecar — sleeps after saving
until the pod terminates.
"""
from __future__ import annotations
import json
import logging
import os
import ssl
import signal
import sys
import time
import urllib.error
import urllib.request
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import Any
from gpu_memory_service.common.utils import get_socket_path
from gpu_memory_service.common.cuda_utils import list_devices
from gpu_memory_service.common.utils import get_socket_path, wait_for_weights_socket
from gpu_memory_service.snapshot.storage_client import GMSStorageClient
logging.basicConfig(
......@@ -30,127 +27,42 @@ logging.basicConfig(
)
logger = logging.getLogger(__name__)
_WEIGHTS_TAG = "weights"
_SERVICE_ACCOUNT_TOKEN = Path("/var/run/secrets/kubernetes.io/serviceaccount/token")
_SERVICE_ACCOUNT_CA = "/var/run/secrets/kubernetes.io/serviceaccount/ca.crt"
def _list_devices() -> list[int]:
import pynvml
pynvml.nvmlInit()
try:
count = pynvml.nvmlDeviceGetCount()
finally:
pynvml.nvmlShutdown()
if count == 0:
raise SystemExit("no nvidia devices found")
return list(range(count))
def _wait_for_weights_socket(device: int) -> None:
socket_path = get_socket_path(device, _WEIGHTS_TAG)
while not os.path.exists(socket_path):
time.sleep(1)
def _checkpoint_pod_ready(pod: dict[str, Any]) -> bool:
status = pod.get("status") or {}
if str(status.get("phase", "")).strip() != "Running":
return False
for condition in status.get("conditions") or []:
if (
condition.get("type") == "Ready"
and str(condition.get("status", "")).strip() == "True"
):
return True
return False
def _main_terminated(pod: dict[str, Any]) -> bool:
status = pod.get("status") or {}
for container in status.get("containerStatuses") or []:
if container.get("name") != "main":
continue
return bool((container.get("state") or {}).get("terminated"))
return False
def _save_device(checkpoint_dir: str, device: int, max_workers: int) -> None:
wait_for_weights_socket(device)
output_dir = os.path.join(checkpoint_dir, f"device-{device}")
logger.info("Saving GMS checkpoint: device=%d output_dir=%s", device, output_dir)
t0 = time.monotonic()
GMSStorageClient(
output_dir,
socket_path=get_socket_path(device),
device=device,
).save(max_workers=max_workers)
elapsed = time.monotonic() - t0
logger.info("GMS checkpoint saved: device=%d elapsed=%.2fs", device, elapsed)
def main() -> None:
service_token = _SERVICE_ACCOUNT_TOKEN.read_text(encoding="utf-8").strip()
ssl_context = ssl.create_default_context(cafile=_SERVICE_ACCOUNT_CA)
pod_api_url = (
"https://"
+ os.environ["KUBERNETES_SERVICE_HOST"]
+ ":"
+ os.environ.get("KUBERNETES_SERVICE_PORT_HTTPS", "443")
+ f"/api/v1/namespaces/{os.environ['POD_NAMESPACE']}/pods/{os.environ['POD_NAME']}"
)
checkpoint_dir = os.environ["GMS_CHECKPOINT_DIR"]
max_workers = int(os.environ.get("GMS_SAVE_WORKERS", "8"))
def checkpoint_pod() -> dict[str, Any]:
request = urllib.request.Request(
pod_api_url,
headers={"Authorization": f"Bearer {service_token}"},
)
with urllib.request.urlopen(
request,
context=ssl_context,
timeout=5,
) as response:
return json.load(response)
logger.info("Waiting for checkpoint pod Ready=True before GMS save")
devices = list_devices()
logger.info("Starting GMS save for %d devices", len(devices))
t0 = time.monotonic()
with ThreadPoolExecutor(max_workers=len(devices)) as pool:
futures = {
pool.submit(_save_device, checkpoint_dir, dev, max_workers): dev
for dev in devices
}
for future in as_completed(futures):
future.result()
elapsed = time.monotonic() - t0
logger.info("All %d devices saved in %.2fs", len(devices), elapsed)
logger.info("Save complete; sleeping until pod terminates")
signal.signal(signal.SIGTERM, lambda *_: sys.exit(0))
while True:
try:
pod = checkpoint_pod()
except (urllib.error.URLError, TimeoutError, ssl.SSLError, OSError):
time.sleep(1)
continue
if _checkpoint_pod_ready(pod):
break
if _main_terminated(pod):
raise SystemExit("main container terminated before GMS save could start")
time.sleep(1)
def _save_device(device: int, max_workers: int) -> None:
_wait_for_weights_socket(device)
output_dir = os.path.join(checkpoint_dir, f"device-{device}")
logger.info(
"Saving GMS checkpoint: device=%d output_dir=%s",
device,
output_dir,
)
t0 = time.monotonic()
client = GMSStorageClient(
output_dir,
socket_path=get_socket_path(device),
device=device,
)
client.save(max_workers=max_workers)
elapsed = time.monotonic() - t0
logger.info("GMS checkpoint saved: device=%d elapsed=%.2fs", device, elapsed)
max_workers = int(os.environ.get("GMS_SAVE_WORKERS", "8"))
logger.info("Checkpoint pod is Ready; starting GMS save")
try:
devices = _list_devices()
t0 = time.monotonic()
with ThreadPoolExecutor(max_workers=len(devices)) as pool:
futures = {
pool.submit(_save_device, dev, max_workers): dev for dev in devices
}
for future in as_completed(futures):
future.result()
elapsed = time.monotonic() - t0
logger.info("All %d devices saved in %.2fs", len(devices), elapsed)
finally:
(Path(os.environ["GMS_CONTROL_DIR"]) / "checkpoint-done").write_text(
"done",
encoding="utf-8",
)
time.sleep(3600)
if __name__ == "__main__":
......
......@@ -24,6 +24,20 @@ except ImportError:
cuda = _MissingCuda()
def list_devices() -> list[int]:
    """Return device indices ``0..count-1`` for the device count reported by NVML.

    NVML is initialised for the query and shut down afterwards, including
    when the query raises.

    Raises:
        SystemExit: If NVML reports zero devices.
    """
    import pynvml

    pynvml.nvmlInit()
    try:
        device_count = pynvml.nvmlDeviceGetCount()
    finally:
        pynvml.nvmlShutdown()

    if device_count != 0:
        return list(range(device_count))
    raise SystemExit("no nvidia devices found")
def cuda_check_result(result: cuda.CUresult, name: str) -> None:
if result != cuda.CUresult.CUDA_SUCCESS:
err_result, err_str = cuda.cuGetErrorString(result)
......
......@@ -51,3 +51,12 @@ def get_socket_path(device: int, tag: str = "weights") -> str:
_uuid_cache[device] = uuid
socket_dir = os.environ.get("GMS_SOCKET_DIR") or tempfile.gettempdir()
return os.path.join(socket_dir, f"gms_{uuid}_{tag}.sock")
def wait_for_weights_socket(device: int, timeout: "float | None" = None) -> None:
    """Block until the GMS weights socket for the given device exists.

    Args:
        device: Device index passed to ``get_socket_path``.
        timeout: Maximum number of seconds to wait. ``None`` (the default)
            waits indefinitely, which is the original behaviour.

    Raises:
        TimeoutError: If ``timeout`` is set and elapses before the socket
            path appears.
    """
    import time

    path = get_socket_path(device, "weights")
    # Use a monotonic deadline so wall-clock adjustments cannot stretch or
    # shorten the wait.
    deadline = None if timeout is None else time.monotonic() + timeout
    while not os.path.exists(path):
        if deadline is not None and time.monotonic() >= deadline:
            raise TimeoutError(
                f"GMS weights socket {path} did not appear within {timeout}s"
            )
        time.sleep(0.1)
......@@ -66,6 +66,7 @@ setup(
packages=[
"gpu_memory_service",
"gpu_memory_service.cli",
"gpu_memory_service.cli.snapshot",
"gpu_memory_service.common",
"gpu_memory_service.common.protocol",
"gpu_memory_service.server",
......@@ -79,10 +80,12 @@ setup(
"gpu_memory_service.integrations.sglang",
"gpu_memory_service.integrations.trtllm",
"gpu_memory_service.integrations.vllm",
"gpu_memory_service.snapshot",
],
package_dir={
"gpu_memory_service": ".",
"gpu_memory_service.cli": "cli",
"gpu_memory_service.cli.snapshot": "cli/snapshot",
"gpu_memory_service.common": "common",
"gpu_memory_service.common.protocol": "common/protocol",
"gpu_memory_service.server": "server",
......@@ -96,6 +99,7 @@ setup(
"gpu_memory_service.integrations.sglang": "integrations/sglang",
"gpu_memory_service.integrations.trtllm": "integrations/trtllm",
"gpu_memory_service.integrations.vllm": "integrations/vllm",
"gpu_memory_service.snapshot": "snapshot",
},
package_data={
"gpu_memory_service.client.torch.extensions": ["*.cpp"],
......
......@@ -245,6 +245,12 @@ def running_gms(monkeypatch, tmp_path):
server_allocations, "cumem_export_to_shareable_handle", export_fd
)
monkeypatch.setattr(
client_memory_manager,
"cuda_set_current_device",
lambda device: None,
raising=False,
)
monkeypatch.setattr(
client_memory_manager,
"cumem_get_allocation_granularity",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment