Unverified Commit b2f7f220 authored by Schwinn Saereesitthipitak's avatar Schwinn Saereesitthipitak Committed by GitHub
Browse files

fix(gms): rewrite GMS checkpoint/restore operator support (#8194)


Co-authored-by: default avatarDmitry Tokarev <dtokarev@nvidia.com>
parent 2d86b81d
...@@ -98,12 +98,12 @@ func TestNewCheckpointJob(t *testing.T) { ...@@ -98,12 +98,12 @@ func TestNewCheckpointJob(t *testing.T) {
} }
} }
func TestNewCheckpointJobPrefersContainerNamedMain(t *testing.T) { func TestNewCheckpointJobWrapsFirstContainer(t *testing.T) {
job, err := NewCheckpointJob(&corev1.PodTemplateSpec{ job, err := NewCheckpointJob(&corev1.PodTemplateSpec{
Spec: corev1.PodSpec{ Spec: corev1.PodSpec{
Containers: []corev1.Container{ Containers: []corev1.Container{
{Name: "worker", Command: []string{"python3", "-m", "dynamo.vllm"}, Args: []string{"--model", "Qwen"}},
{Name: "sidecar", Command: []string{"sleep"}, Args: []string{"infinity"}}, {Name: "sidecar", Command: []string{"sleep"}, Args: []string{"infinity"}},
{Name: "main", Command: []string{"python3", "-m", "dynamo.vllm"}, Args: []string{"--model", "Qwen"}},
}, },
}, },
}, CheckpointJobOptions{ }, CheckpointJobOptions{
...@@ -118,17 +118,17 @@ func TestNewCheckpointJobPrefersContainerNamedMain(t *testing.T) { ...@@ -118,17 +118,17 @@ func TestNewCheckpointJobPrefersContainerNamedMain(t *testing.T) {
t.Fatalf("expected checkpoint job, got error: %v", err) t.Fatalf("expected checkpoint job, got error: %v", err)
} }
main := requireCheckpointContainer(t, job.Spec.Template.Spec.Containers, "main") worker := requireCheckpointContainer(t, job.Spec.Template.Spec.Containers, "worker")
if len(main.Command) != 1 || main.Command[0] != "cuda-checkpoint" { if len(worker.Command) != 1 || worker.Command[0] != "cuda-checkpoint" {
t.Fatalf("expected main container to be wrapped, got %#v", main.Command) t.Fatalf("expected first container to be wrapped, got %#v", worker.Command)
} }
expectedArgs := []string{"--launch-job", "python3", "-m", "dynamo.vllm", "--model", "Qwen"} expectedArgs := []string{"--launch-job", "python3", "-m", "dynamo.vllm", "--model", "Qwen"}
if len(main.Args) != len(expectedArgs) { if len(worker.Args) != len(expectedArgs) {
t.Fatalf("expected launch-job args %#v, got %#v", expectedArgs, main.Args) t.Fatalf("expected launch-job args %#v, got %#v", expectedArgs, worker.Args)
} }
for i := range expectedArgs { for i := range expectedArgs {
if main.Args[i] != expectedArgs[i] { if worker.Args[i] != expectedArgs[i] {
t.Fatalf("expected launch-job args %#v, got %#v", expectedArgs, main.Args) t.Fatalf("expected launch-job args %#v, got %#v", expectedArgs, worker.Args)
} }
} }
...@@ -168,26 +168,7 @@ func TestNewCheckpointJobAllowsSingleNonMainContainer(t *testing.T) { ...@@ -168,26 +168,7 @@ func TestNewCheckpointJobAllowsSingleNonMainContainer(t *testing.T) {
} }
} }
func TestNewCheckpointJobRejectsMultiContainerPodWithoutMain(t *testing.T) {
_, err := NewCheckpointJob(&corev1.PodTemplateSpec{
Spec: corev1.PodSpec{
Containers: []corev1.Container{
{Name: "sidecar", Command: []string{"sleep"}, Args: []string{"infinity"}},
{Name: "worker", Command: []string{"python3", "-m", "dynamo.vllm"}},
},
},
}, CheckpointJobOptions{
Namespace: "test-ns",
CheckpointID: "hash",
ArtifactVersion: "2",
Name: "test-job",
TTLSecondsAfterFinish: ptr.To(int32(300)),
WrapLaunchJob: true,
})
if err == nil || err.Error() != "checkpoint job requires a container named \"main\" when multiple containers are present" {
t.Fatalf("expected missing main container error, got %v", err)
}
}
func TestGetCheckpointJobName(t *testing.T) { func TestGetCheckpointJobName(t *testing.T) {
name := GetCheckpointJobName("abc123def4567890", "2") name := GetCheckpointJobName("abc123def4567890", "2")
......
...@@ -49,19 +49,13 @@ func NewRestorePod(pod *corev1.Pod, opts PodOptions) *corev1.Pod { ...@@ -49,19 +49,13 @@ func NewRestorePod(pod *corev1.Pod, opts PodOptions) *corev1.Pod {
return pod return pod
} }
// resolveWorkerContainer returns the workload container, which is always
// Containers[0]. GMS sidecars are appended after the workload.
func resolveWorkerContainer(podSpec *corev1.PodSpec) *corev1.Container { func resolveWorkerContainer(podSpec *corev1.PodSpec) *corev1.Container {
if podSpec == nil { if podSpec == nil || len(podSpec.Containers) == 0 {
return nil return nil
} }
if len(podSpec.Containers) == 1 {
return &podSpec.Containers[0] return &podSpec.Containers[0]
}
for index := range podSpec.Containers {
if podSpec.Containers[index].Name == "main" {
return &podSpec.Containers[index]
}
}
return nil
} }
func PrepareRestorePodSpec( func PrepareRestorePodSpec(
...@@ -73,7 +67,7 @@ func PrepareRestorePodSpec( ...@@ -73,7 +67,7 @@ func PrepareRestorePodSpec(
) { ) {
EnsureLocalhostSeccompProfile(podSpec, seccompProfile) EnsureLocalhostSeccompProfile(podSpec, seccompProfile)
if storage.PVCName != "" { if storage.PVCName != "" {
injectCheckpointVolume(podSpec, storage.PVCName) InjectCheckpointVolume(podSpec, storage.PVCName)
} }
if storage.BasePath != "" { if storage.BasePath != "" {
injectCheckpointVolumeMount(container, storage.BasePath) injectCheckpointVolumeMount(container, storage.BasePath)
...@@ -113,7 +107,7 @@ func ValidateRestorePodSpec( ...@@ -113,7 +107,7 @@ func ValidateRestorePodSpec(
} }
container := resolveWorkerContainer(podSpec) container := resolveWorkerContainer(podSpec)
if container == nil { if container == nil {
return fmt.Errorf("restore target must include a worker container named main") return fmt.Errorf("restore target must have at least one container")
} }
if storage.PVCName != "" { if storage.PVCName != "" {
hasVolume := false hasVolume := false
...@@ -209,19 +203,18 @@ func DiscoverStorageFromDaemonSets(namespace string, daemonSets []appsv1.DaemonS ...@@ -209,19 +203,18 @@ func DiscoverStorageFromDaemonSets(namespace string, daemonSets []appsv1.DaemonS
) )
} }
func PrepareRestorePodSpecForCheckpoint( // DiscoverAndResolveStorage lists snapshot-agent DaemonSets in the given
// namespace, discovers the shared storage configuration, and resolves the
// checkpoint-specific path for the given checkpoint ID and artifact version.
func DiscoverAndResolveStorage(
ctx context.Context, ctx context.Context,
reader ctrlclient.Reader, reader ctrlclient.Reader,
namespace string, namespace string,
podSpec *corev1.PodSpec,
container *corev1.Container,
checkpointID string, checkpointID string,
artifactVersion string, artifactVersion string,
seccompProfile string, ) (Storage, error) {
isCheckpointReady bool,
) error {
if reader == nil { if reader == nil {
return fmt.Errorf("snapshot client is required") return Storage{}, fmt.Errorf("snapshot client is required")
} }
daemonSets := &appsv1.DaemonSetList{} daemonSets := &appsv1.DaemonSetList{}
...@@ -231,24 +224,41 @@ func PrepareRestorePodSpecForCheckpoint( ...@@ -231,24 +224,41 @@ func PrepareRestorePodSpecForCheckpoint(
ctrlclient.InNamespace(namespace), ctrlclient.InNamespace(namespace),
ctrlclient.MatchingLabels{SnapshotAgentLabelKey: SnapshotAgentLabelValue}, ctrlclient.MatchingLabels{SnapshotAgentLabelKey: SnapshotAgentLabelValue},
); err != nil { ); err != nil {
return fmt.Errorf("list snapshot-agent daemonsets in %s: %w", namespace, err) return Storage{}, fmt.Errorf("list snapshot-agent daemonsets in %s: %w", namespace, err)
} }
storage, err := DiscoverStorageFromDaemonSets(namespace, daemonSets.Items) storage, err := DiscoverStorageFromDaemonSets(namespace, daemonSets.Items)
if err != nil { if err != nil {
return err return Storage{}, err
} }
resolvedStorage, err := ResolveCheckpointStorage(checkpointID, artifactVersion, storage) return ResolveCheckpointStorage(checkpointID, artifactVersion, storage)
}
func PrepareRestorePodSpecForCheckpoint(
ctx context.Context,
reader ctrlclient.Reader,
namespace string,
podSpec *corev1.PodSpec,
container *corev1.Container,
checkpointID string,
artifactVersion string,
seccompProfile string,
isCheckpointReady bool,
) error {
storage, err := DiscoverAndResolveStorage(ctx, reader, namespace, checkpointID, artifactVersion)
if err != nil { if err != nil {
return err return err
} }
PrepareRestorePodSpec(podSpec, container, resolvedStorage, seccompProfile, isCheckpointReady) PrepareRestorePodSpec(podSpec, container, storage, seccompProfile, isCheckpointReady)
return nil return nil
} }
func injectCheckpointVolume(podSpec *corev1.PodSpec, pvcName string) { // InjectCheckpointVolume adds the checkpoint PVC volume to the pod spec if
// not already present. Used by both the snapshot protocol and the operator's
// GMS checkpoint wiring.
func InjectCheckpointVolume(podSpec *corev1.PodSpec, pvcName string) {
for _, volume := range podSpec.Volumes { for _, volume := range podSpec.Volumes {
if volume.Name == CheckpointVolumeName { if volume.Name == CheckpointVolumeName {
return return
......
...@@ -187,13 +187,13 @@ func TestPrepareRestorePodSpecSynthesizesStartupProbeFromLiveness(t *testing.T) ...@@ -187,13 +187,13 @@ func TestPrepareRestorePodSpecSynthesizesStartupProbeFromLiveness(t *testing.T)
} }
} }
func TestNewRestorePodTargetsMainContainerWhenSidecarsPresent(t *testing.T) { func TestNewRestorePodTargetsFirstContainerWhenSidecarsPresent(t *testing.T) {
restorePod := NewRestorePod(&corev1.Pod{ restorePod := NewRestorePod(&corev1.Pod{
ObjectMeta: metav1.ObjectMeta{Name: "worker"}, ObjectMeta: metav1.ObjectMeta{Name: "worker"},
Spec: corev1.PodSpec{ Spec: corev1.PodSpec{
Containers: []corev1.Container{ Containers: []corev1.Container{
{Name: "worker", Image: "test:latest", Command: []string{"python3"}, Args: []string{"-m", "dynamo.vllm"}},
{Name: "sidecar", Image: "sidecar:latest", Command: []string{"sidecar"}, Args: []string{"run"}}, {Name: "sidecar", Image: "sidecar:latest", Command: []string{"sidecar"}, Args: []string{"run"}},
{Name: "main", Image: "test:latest", Command: []string{"python3"}, Args: []string{"-m", "dynamo.vllm"}},
}, },
}, },
}, PodOptions{ }, PodOptions{
...@@ -208,14 +208,14 @@ func TestNewRestorePodTargetsMainContainerWhenSidecarsPresent(t *testing.T) { ...@@ -208,14 +208,14 @@ func TestNewRestorePodTargetsMainContainerWhenSidecarsPresent(t *testing.T) {
SeccompProfile: DefaultSeccompLocalhostProfile, SeccompProfile: DefaultSeccompLocalhostProfile,
}) })
if got := restorePod.Spec.Containers[0].Command; len(got) != 1 || got[0] != "sidecar" { if got := restorePod.Spec.Containers[0].Command; len(got) != 2 || got[0] != "sleep" || got[1] != "infinity" {
t.Fatalf("expected sidecar command to remain unchanged, got %#v", got) t.Fatalf("expected first container placeholder command, got %#v", got)
} }
if got := restorePod.Spec.Containers[1].Command; len(got) != 2 || got[0] != "sleep" || got[1] != "infinity" { if restorePod.Spec.Containers[0].Args != nil {
t.Fatalf("expected main container placeholder command, got %#v", got) t.Fatalf("expected first container args to be cleared: %#v", restorePod.Spec.Containers[0].Args)
} }
if restorePod.Spec.Containers[1].Args != nil { if got := restorePod.Spec.Containers[1].Command; len(got) != 1 || got[0] != "sidecar" {
t.Fatalf("expected main container args to be cleared: %#v", restorePod.Spec.Containers[1].Args) t.Fatalf("expected sidecar command to remain unchanged, got %#v", got)
} }
} }
...@@ -311,7 +311,7 @@ func TestValidateRestorePodSpec(t *testing.T) { ...@@ -311,7 +311,7 @@ func TestValidateRestorePodSpec(t *testing.T) {
} }
} }
func TestValidateRestorePodSpecRequiresMainContainerWhenMultiContainer(t *testing.T) { func TestValidateRestorePodSpecAcceptsFirstContainerAsWorker(t *testing.T) {
profile := DefaultSeccompLocalhostProfile profile := DefaultSeccompLocalhostProfile
podSpec := &corev1.PodSpec{ podSpec := &corev1.PodSpec{
SecurityContext: &corev1.PodSecurityContext{ SecurityContext: &corev1.PodSecurityContext{
...@@ -346,12 +346,13 @@ func TestValidateRestorePodSpecRequiresMainContainerWhenMultiContainer(t *testin ...@@ -346,12 +346,13 @@ func TestValidateRestorePodSpecRequiresMainContainerWhenMultiContainer(t *testin
BasePath: "/checkpoints", BasePath: "/checkpoints",
} }
if err := ValidateRestorePodSpec(podSpec, storage, DefaultSeccompLocalhostProfile); err == nil || err.Error() != "restore target must include a worker container named main" { // Containers[0] is always the worker, regardless of name
t.Fatalf("expected multi-container restore target without main to be rejected, got %v", err) if err := ValidateRestorePodSpec(podSpec, storage, DefaultSeccompLocalhostProfile); err != nil {
t.Fatalf("expected validation to pass for first container as worker, got %v", err)
} }
} }
func TestValidateRestorePodSpecAllowsMainContainerWithSidecars(t *testing.T) { func TestValidateRestorePodSpecAllowsWorkerWithSidecars(t *testing.T) {
profile := DefaultSeccompLocalhostProfile profile := DefaultSeccompLocalhostProfile
podSpec := &corev1.PodSpec{ podSpec := &corev1.PodSpec{
SecurityContext: &corev1.PodSecurityContext{ SecurityContext: &corev1.PodSecurityContext{
...@@ -369,14 +370,14 @@ func TestValidateRestorePodSpecAllowsMainContainerWithSidecars(t *testing.T) { ...@@ -369,14 +370,14 @@ func TestValidateRestorePodSpecAllowsMainContainerWithSidecars(t *testing.T) {
}, },
}}, }},
Containers: []corev1.Container{ Containers: []corev1.Container{
{Name: "sidecar"},
{ {
Name: "main", Name: "worker",
VolumeMounts: []corev1.VolumeMount{{ VolumeMounts: []corev1.VolumeMount{{
Name: CheckpointVolumeName, Name: CheckpointVolumeName,
MountPath: "/checkpoints", MountPath: "/checkpoints",
}}, }},
}, },
{Name: "sidecar"},
}, },
} }
...@@ -387,7 +388,7 @@ func TestValidateRestorePodSpecAllowsMainContainerWithSidecars(t *testing.T) { ...@@ -387,7 +388,7 @@ func TestValidateRestorePodSpecAllowsMainContainerWithSidecars(t *testing.T) {
} }
if err := ValidateRestorePodSpec(podSpec, storage, DefaultSeccompLocalhostProfile); err != nil { if err := ValidateRestorePodSpec(podSpec, storage, DefaultSeccompLocalhostProfile); err != nil {
t.Fatalf("expected main container with sidecars to validate, got %v", err) t.Fatalf("expected worker with sidecars to validate, got %v", err)
} }
} }
......
...@@ -4,8 +4,8 @@ ...@@ -4,8 +4,8 @@
"""GMS server entry point. """GMS server entry point.
Launches two GMS server processes per GPU (one for weights, one for kv_cache). Launches two GMS server processes per GPU (one for weights, one for kv_cache).
Writes a ready file once all expected UDS sockets are present. Monitors an Writes a ready file once all expected UDS sockets are present. Runs until
optional checkpoint stop file and shuts down cleanly when it appears. SIGTERM (pod termination kills it).
""" """
from __future__ import annotations from __future__ import annotations
...@@ -18,6 +18,7 @@ import sys ...@@ -18,6 +18,7 @@ import sys
import time import time
from pathlib import Path from pathlib import Path
from gpu_memory_service.common.cuda_utils import list_devices
from gpu_memory_service.common.utils import get_socket_path from gpu_memory_service.common.utils import get_socket_path
logging.basicConfig( logging.basicConfig(
...@@ -30,36 +31,11 @@ _TAGS = ("weights", "kv_cache") ...@@ -30,36 +31,11 @@ _TAGS = ("weights", "kv_cache")
_READY_FILE = "gms-ready" _READY_FILE = "gms-ready"
def _ready_file_path() -> Path:
return Path(os.environ.get("GMS_SOCKET_DIR", "/tmp")) / _READY_FILE
def _list_devices() -> list[int]:
import pynvml
pynvml.nvmlInit()
try:
count = pynvml.nvmlDeviceGetCount()
finally:
pynvml.nvmlShutdown()
if count == 0:
raise SystemExit("no nvidia devices found")
return list(range(count))
def _optional_checkpoint_stop_file() -> Path | None:
control_dir = os.environ.get("GMS_CONTROL_DIR")
if not control_dir:
return None
return Path(control_dir) / "checkpoint-done"
def main() -> None: def main() -> None:
ready_file = _ready_file_path() ready_file = Path(os.environ.get("GMS_SOCKET_DIR", "/tmp")) / _READY_FILE
ready_file.unlink(missing_ok=True) ready_file.unlink(missing_ok=True)
devices = _list_devices() devices = list_devices()
processes = [] processes = []
for device in devices: for device in devices:
for tag in _TAGS: for tag in _TAGS:
...@@ -89,14 +65,8 @@ def main() -> None: ...@@ -89,14 +65,8 @@ def main() -> None:
signal.signal(signal.SIGTERM, terminate) signal.signal(signal.SIGTERM, terminate)
signal.signal(signal.SIGINT, terminate) signal.signal(signal.SIGINT, terminate)
stop_file = _optional_checkpoint_stop_file()
ready_written = False ready_written = False
while True: while True:
stop_requested = stop_file is not None and stop_file.exists()
if stop_requested:
logger.info("checkpoint stop requested; shutting down GMS servers")
shutdown()
if not ready_written: if not ready_written:
sockets_ready = all( sockets_ready = all(
os.path.exists(get_socket_path(device, tag)) os.path.exists(get_socket_path(device, tag))
...@@ -113,8 +83,6 @@ def main() -> None: ...@@ -113,8 +83,6 @@ def main() -> None:
if exit_code is None: if exit_code is None:
running = True running = True
continue continue
if stop_requested:
continue
shutdown() shutdown()
raise SystemExit(exit_code) raise SystemExit(exit_code)
......
...@@ -12,10 +12,13 @@ from __future__ import annotations ...@@ -12,10 +12,13 @@ from __future__ import annotations
import logging import logging
import os import os
import signal
import sys
import time import time
from concurrent.futures import ThreadPoolExecutor, as_completed from concurrent.futures import ThreadPoolExecutor, as_completed
from gpu_memory_service.common.utils import get_socket_path from gpu_memory_service.common.cuda_utils import list_devices
from gpu_memory_service.common.utils import get_socket_path, wait_for_weights_socket
from gpu_memory_service.snapshot.storage_client import GMSStorageClient from gpu_memory_service.snapshot.storage_client import GMSStorageClient
logging.basicConfig( logging.basicConfig(
...@@ -24,32 +27,9 @@ logging.basicConfig( ...@@ -24,32 +27,9 @@ logging.basicConfig(
) )
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_WEIGHTS_TAG = "weights"
_DEFAULT_MAX_WORKERS = 8
def _list_devices() -> list[int]:
import pynvml
pynvml.nvmlInit()
try:
count = pynvml.nvmlDeviceGetCount()
finally:
pynvml.nvmlShutdown()
if count == 0:
raise SystemExit("no nvidia devices found")
return list(range(count))
def _wait_for_weights_socket(device: int) -> None:
socket_path = get_socket_path(device, _WEIGHTS_TAG)
while not os.path.exists(socket_path):
time.sleep(1)
def _load_device(checkpoint_dir: str, device: int, max_workers: int) -> None: def _load_device(checkpoint_dir: str, device: int, max_workers: int) -> None:
_wait_for_weights_socket(device) wait_for_weights_socket(device)
input_dir = os.path.join(checkpoint_dir, f"device-{device}") input_dir = os.path.join(checkpoint_dir, f"device-{device}")
logger.info("Loading GMS checkpoint: device=%d input_dir=%s", device, input_dir) logger.info("Loading GMS checkpoint: device=%d input_dir=%s", device, input_dir)
t0 = time.monotonic() t0 = time.monotonic()
...@@ -68,8 +48,8 @@ def _load_device(checkpoint_dir: str, device: int, max_workers: int) -> None: ...@@ -68,8 +48,8 @@ def _load_device(checkpoint_dir: str, device: int, max_workers: int) -> None:
def main() -> None: def main() -> None:
checkpoint_dir = os.environ["GMS_CHECKPOINT_DIR"] checkpoint_dir = os.environ["GMS_CHECKPOINT_DIR"]
max_workers = int(os.environ.get("GMS_LOAD_WORKERS", str(_DEFAULT_MAX_WORKERS))) max_workers = int(os.environ.get("GMS_LOAD_WORKERS", "8"))
devices = _list_devices() devices = list_devices()
t0 = time.monotonic() t0 = time.monotonic()
with ThreadPoolExecutor(max_workers=len(devices)) as pool: with ThreadPoolExecutor(max_workers=len(devices)) as pool:
...@@ -83,6 +63,7 @@ def main() -> None: ...@@ -83,6 +63,7 @@ def main() -> None:
logger.info("Device %d load complete", dev) logger.info("Device %d load complete", dev)
elapsed = time.monotonic() - t0 elapsed = time.monotonic() - t0
logger.info("All %d devices loaded in %.2fs", len(devices), elapsed) logger.info("All %d devices loaded in %.2fs", len(devices), elapsed)
signal.signal(signal.SIGTERM, lambda *_: sys.exit(0))
while True: while True:
time.sleep(3600) time.sleep(3600)
......
...@@ -3,25 +3,22 @@ ...@@ -3,25 +3,22 @@
"""GMS checkpoint saver entry point. """GMS checkpoint saver entry point.
Waits for the checkpoint pod to reach Ready=True, then saves GMS state from Waits for committed GMS weights on each device, then saves GPU memory state
each device in parallel. Writes a stop file to signal the GMS server to shut to the checkpoint directory. Runs as an init sidecar — sleeps after saving
down after save completes. until the pod terminates.
""" """
from __future__ import annotations from __future__ import annotations
import json
import logging import logging
import os import os
import ssl import signal
import sys
import time import time
import urllib.error
import urllib.request
from concurrent.futures import ThreadPoolExecutor, as_completed from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import Any
from gpu_memory_service.common.utils import get_socket_path from gpu_memory_service.common.cuda_utils import list_devices
from gpu_memory_service.common.utils import get_socket_path, wait_for_weights_socket
from gpu_memory_service.snapshot.storage_client import GMSStorageClient from gpu_memory_service.snapshot.storage_client import GMSStorageClient
logging.basicConfig( logging.basicConfig(
...@@ -30,127 +27,42 @@ logging.basicConfig( ...@@ -30,127 +27,42 @@ logging.basicConfig(
) )
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_WEIGHTS_TAG = "weights"
_SERVICE_ACCOUNT_TOKEN = Path("/var/run/secrets/kubernetes.io/serviceaccount/token")
_SERVICE_ACCOUNT_CA = "/var/run/secrets/kubernetes.io/serviceaccount/ca.crt"
def _save_device(checkpoint_dir: str, device: int, max_workers: int) -> None:
def _list_devices() -> list[int]: wait_for_weights_socket(device)
import pynvml
pynvml.nvmlInit()
try:
count = pynvml.nvmlDeviceGetCount()
finally:
pynvml.nvmlShutdown()
if count == 0:
raise SystemExit("no nvidia devices found")
return list(range(count))
def _wait_for_weights_socket(device: int) -> None:
socket_path = get_socket_path(device, _WEIGHTS_TAG)
while not os.path.exists(socket_path):
time.sleep(1)
def _checkpoint_pod_ready(pod: dict[str, Any]) -> bool:
status = pod.get("status") or {}
if str(status.get("phase", "")).strip() != "Running":
return False
for condition in status.get("conditions") or []:
if (
condition.get("type") == "Ready"
and str(condition.get("status", "")).strip() == "True"
):
return True
return False
def _main_terminated(pod: dict[str, Any]) -> bool:
status = pod.get("status") or {}
for container in status.get("containerStatuses") or []:
if container.get("name") != "main":
continue
return bool((container.get("state") or {}).get("terminated"))
return False
def main() -> None:
service_token = _SERVICE_ACCOUNT_TOKEN.read_text(encoding="utf-8").strip()
ssl_context = ssl.create_default_context(cafile=_SERVICE_ACCOUNT_CA)
pod_api_url = (
"https://"
+ os.environ["KUBERNETES_SERVICE_HOST"]
+ ":"
+ os.environ.get("KUBERNETES_SERVICE_PORT_HTTPS", "443")
+ f"/api/v1/namespaces/{os.environ['POD_NAMESPACE']}/pods/{os.environ['POD_NAME']}"
)
checkpoint_dir = os.environ["GMS_CHECKPOINT_DIR"]
def checkpoint_pod() -> dict[str, Any]:
request = urllib.request.Request(
pod_api_url,
headers={"Authorization": f"Bearer {service_token}"},
)
with urllib.request.urlopen(
request,
context=ssl_context,
timeout=5,
) as response:
return json.load(response)
logger.info("Waiting for checkpoint pod Ready=True before GMS save")
while True:
try:
pod = checkpoint_pod()
except (urllib.error.URLError, TimeoutError, ssl.SSLError, OSError):
time.sleep(1)
continue
if _checkpoint_pod_ready(pod):
break
if _main_terminated(pod):
raise SystemExit("main container terminated before GMS save could start")
time.sleep(1)
def _save_device(device: int, max_workers: int) -> None:
_wait_for_weights_socket(device)
output_dir = os.path.join(checkpoint_dir, f"device-{device}") output_dir = os.path.join(checkpoint_dir, f"device-{device}")
logger.info( logger.info("Saving GMS checkpoint: device=%d output_dir=%s", device, output_dir)
"Saving GMS checkpoint: device=%d output_dir=%s",
device,
output_dir,
)
t0 = time.monotonic() t0 = time.monotonic()
client = GMSStorageClient( GMSStorageClient(
output_dir, output_dir,
socket_path=get_socket_path(device), socket_path=get_socket_path(device),
device=device, device=device,
) ).save(max_workers=max_workers)
client.save(max_workers=max_workers)
elapsed = time.monotonic() - t0 elapsed = time.monotonic() - t0
logger.info("GMS checkpoint saved: device=%d elapsed=%.2fs", device, elapsed) logger.info("GMS checkpoint saved: device=%d elapsed=%.2fs", device, elapsed)
def main() -> None:
checkpoint_dir = os.environ["GMS_CHECKPOINT_DIR"]
max_workers = int(os.environ.get("GMS_SAVE_WORKERS", "8")) max_workers = int(os.environ.get("GMS_SAVE_WORKERS", "8"))
logger.info("Checkpoint pod is Ready; starting GMS save")
try: devices = list_devices()
devices = _list_devices() logger.info("Starting GMS save for %d devices", len(devices))
t0 = time.monotonic() t0 = time.monotonic()
with ThreadPoolExecutor(max_workers=len(devices)) as pool: with ThreadPoolExecutor(max_workers=len(devices)) as pool:
futures = { futures = {
pool.submit(_save_device, dev, max_workers): dev for dev in devices pool.submit(_save_device, checkpoint_dir, dev, max_workers): dev
for dev in devices
} }
for future in as_completed(futures): for future in as_completed(futures):
future.result() future.result()
elapsed = time.monotonic() - t0 elapsed = time.monotonic() - t0
logger.info("All %d devices saved in %.2fs", len(devices), elapsed) logger.info("All %d devices saved in %.2fs", len(devices), elapsed)
finally:
(Path(os.environ["GMS_CONTROL_DIR"]) / "checkpoint-done").write_text( logger.info("Save complete; sleeping until pod terminates")
"done", signal.signal(signal.SIGTERM, lambda *_: sys.exit(0))
encoding="utf-8", while True:
) time.sleep(3600)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -24,6 +24,20 @@ except ImportError: ...@@ -24,6 +24,20 @@ except ImportError:
cuda = _MissingCuda() cuda = _MissingCuda()
def list_devices() -> list[int]:
"""Return list of CUDA device indices visible to this process via NVML."""
import pynvml
pynvml.nvmlInit()
try:
count = pynvml.nvmlDeviceGetCount()
finally:
pynvml.nvmlShutdown()
if count == 0:
raise SystemExit("no nvidia devices found")
return list(range(count))
def cuda_check_result(result: cuda.CUresult, name: str) -> None: def cuda_check_result(result: cuda.CUresult, name: str) -> None:
if result != cuda.CUresult.CUDA_SUCCESS: if result != cuda.CUresult.CUDA_SUCCESS:
err_result, err_str = cuda.cuGetErrorString(result) err_result, err_str = cuda.cuGetErrorString(result)
......
...@@ -51,3 +51,12 @@ def get_socket_path(device: int, tag: str = "weights") -> str: ...@@ -51,3 +51,12 @@ def get_socket_path(device: int, tag: str = "weights") -> str:
_uuid_cache[device] = uuid _uuid_cache[device] = uuid
socket_dir = os.environ.get("GMS_SOCKET_DIR") or tempfile.gettempdir() socket_dir = os.environ.get("GMS_SOCKET_DIR") or tempfile.gettempdir()
return os.path.join(socket_dir, f"gms_{uuid}_{tag}.sock") return os.path.join(socket_dir, f"gms_{uuid}_{tag}.sock")
def wait_for_weights_socket(device: int) -> None:
"""Block until the GMS weights socket for the given device exists."""
import time
path = get_socket_path(device, "weights")
while not os.path.exists(path):
time.sleep(0.1)
...@@ -66,6 +66,7 @@ setup( ...@@ -66,6 +66,7 @@ setup(
packages=[ packages=[
"gpu_memory_service", "gpu_memory_service",
"gpu_memory_service.cli", "gpu_memory_service.cli",
"gpu_memory_service.cli.snapshot",
"gpu_memory_service.common", "gpu_memory_service.common",
"gpu_memory_service.common.protocol", "gpu_memory_service.common.protocol",
"gpu_memory_service.server", "gpu_memory_service.server",
...@@ -79,10 +80,12 @@ setup( ...@@ -79,10 +80,12 @@ setup(
"gpu_memory_service.integrations.sglang", "gpu_memory_service.integrations.sglang",
"gpu_memory_service.integrations.trtllm", "gpu_memory_service.integrations.trtllm",
"gpu_memory_service.integrations.vllm", "gpu_memory_service.integrations.vllm",
"gpu_memory_service.snapshot",
], ],
package_dir={ package_dir={
"gpu_memory_service": ".", "gpu_memory_service": ".",
"gpu_memory_service.cli": "cli", "gpu_memory_service.cli": "cli",
"gpu_memory_service.cli.snapshot": "cli/snapshot",
"gpu_memory_service.common": "common", "gpu_memory_service.common": "common",
"gpu_memory_service.common.protocol": "common/protocol", "gpu_memory_service.common.protocol": "common/protocol",
"gpu_memory_service.server": "server", "gpu_memory_service.server": "server",
...@@ -96,6 +99,7 @@ setup( ...@@ -96,6 +99,7 @@ setup(
"gpu_memory_service.integrations.sglang": "integrations/sglang", "gpu_memory_service.integrations.sglang": "integrations/sglang",
"gpu_memory_service.integrations.trtllm": "integrations/trtllm", "gpu_memory_service.integrations.trtllm": "integrations/trtllm",
"gpu_memory_service.integrations.vllm": "integrations/vllm", "gpu_memory_service.integrations.vllm": "integrations/vllm",
"gpu_memory_service.snapshot": "snapshot",
}, },
package_data={ package_data={
"gpu_memory_service.client.torch.extensions": ["*.cpp"], "gpu_memory_service.client.torch.extensions": ["*.cpp"],
......
...@@ -245,6 +245,12 @@ def running_gms(monkeypatch, tmp_path): ...@@ -245,6 +245,12 @@ def running_gms(monkeypatch, tmp_path):
server_allocations, "cumem_export_to_shareable_handle", export_fd server_allocations, "cumem_export_to_shareable_handle", export_fd
) )
monkeypatch.setattr(
client_memory_manager,
"cuda_set_current_device",
lambda device: None,
raising=False,
)
monkeypatch.setattr( monkeypatch.setattr(
client_memory_manager, client_memory_manager,
"cumem_get_allocation_granularity", "cumem_get_allocation_granularity",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment