"...ssh:/git@developer.sourcefind.cn:2222/OpenDAS/dynamo.git" did not exist on "3f53a78e036721d367f8cbf9b3087de8b8666059"
Unverified Commit b2aefc53 authored by Janelle Cai's avatar Janelle Cai Committed by GitHub
Browse files

chore(snapshot): gms + snapshot support (#7026)


Signed-off-by: Janelle Cai <jcai18@mit.edu>
parent 9616c86f
...@@ -4,6 +4,9 @@ ...@@ -4,6 +4,9 @@
"""Shared Dynamo snapshot helpers for checkpoint lifecycle.""" """Shared Dynamo snapshot helpers for checkpoint lifecycle."""
import asyncio import asyncio
import ctypes
import ctypes.util
import gc
import logging import logging
import os import os
import signal import signal
...@@ -234,3 +237,48 @@ def reload_snapshot_restore_identity() -> tuple[str, str]: ...@@ -234,3 +237,48 @@ def reload_snapshot_restore_identity() -> tuple[str, str]:
# Snapshot restore only runs in Kubernetes-managed pods, so discovery resets here. # Snapshot restore only runs in Kubernetes-managed pods, so discovery resets here.
os.environ["DYN_DISCOVERY_BACKEND"] = "kubernetes" os.environ["DYN_DISCOVERY_BACKEND"] = "kubernetes"
return get_worker_namespace(), "kubernetes" return get_worker_namespace(), "kubernetes"
def _try_release_memory(label: str) -> None:
    """Collect garbage, ask libc to trim its heap, and log the RSS at each step.

    Runs ``gc.collect()`` followed by a best-effort ``malloc_trim(0)`` on the
    C library located through ``ctypes.util.find_library("c")``. The process
    VmRSS is sampled from ``/proc/<pid>/status`` before the GC pass, after it,
    and after the trim; all three readings plus the overall difference are
    emitted in one INFO log line.

    Args:
        label: Tag embedded in the log prefix to identify the call site.
    """
    pid = os.getpid()

    def _get_rss_kb() -> int:
        """Return the VmRSS figure from /proc status, or 0 if it cannot be read."""
        try:
            with open(f"/proc/{pid}/status") as status_file:
                return next(
                    (
                        int(row.split()[1])
                        for row in status_file
                        if row.startswith("VmRSS:")
                    ),
                    0,
                )
        except Exception:
            # Best effort: an unreadable status file or unparsable line reads as 0.
            return 0

    rss_before = _get_rss_kb()
    collected = gc.collect()
    rss_after_gc = _get_rss_kb()

    try:
        libc_name = ctypes.util.find_library("c")
        if libc_name:
            # Any failure here (library load, missing symbol) is only logged
            # at DEBUG; the RSS report below is still produced.
            ctypes.CDLL(libc_name).malloc_trim(0)
    except Exception as exc:
        logger.debug("[MemRelease:%s] malloc_trim failed: %s", label, exc)

    rss_after_trim = _get_rss_kb()

    # The sampled values are divided by 1024 to report MiB.
    divisor = 1024
    logger.info(
        "[MemRelease:%s] gc.collect freed %d objects, "
        "RSS: %.2f MiB -> %.2f MiB (gc) -> %.2f MiB (malloc_trim), "
        "reclaimed=%.2f MiB",
        label,
        collected,
        rss_before / divisor,
        rss_after_gc / divisor,
        rss_after_trim / divisor,
        (rss_before - rss_after_trim) / divisor,
    )
...@@ -3,12 +3,17 @@ ...@@ -3,12 +3,17 @@
"""Dynamo Snapshot integration for SGLang workers.""" """Dynamo Snapshot integration for SGLang workers."""
import logging import logging
import time import time
import sglang as sgl import sglang as sgl
from dynamo.common.utils.snapshot import CheckpointConfig, EngineSnapshotController from dynamo.common.utils.snapshot import (
CheckpointConfig,
EngineSnapshotController,
_try_release_memory,
)
from .request_handlers.handler_base import SGLangEngineQuiesceController from .request_handlers.handler_base import SGLangEngineQuiesceController
...@@ -37,10 +42,18 @@ async def prepare_snapshot_engine( ...@@ -37,10 +42,18 @@ async def prepare_snapshot_engine(
logger.info("Checkpoint mode enabled (watcher-driven signals)") logger.info("Checkpoint mode enabled (watcher-driven signals)")
# Enable memory_saver + weights CPU backup so weights survive CRIU # Enable memory_saver so GPU memory can be released for CRIU.
# (mirrors vLLM's enable_sleep_mode = True) # When using GMS, weights use VA-stable unmap/remap (no CPU backup); GMS
# forbids enable_weights_cpu_backup. Otherwise use CPU backup for weights.
server_args.enable_memory_saver = True server_args.enable_memory_saver = True
server_args.enable_weights_cpu_backup = True try:
from gpu_memory_service.integrations.sglang import is_gms_active
_using_gms = is_gms_active()
except ImportError:
_using_gms = False
if not _using_gms:
server_args.enable_weights_cpu_backup = True
start_time = time.time() start_time = time.time()
engine = sgl.Engine(server_args=server_args) engine = sgl.Engine(server_args=server_args)
...@@ -48,6 +61,8 @@ async def prepare_snapshot_engine( ...@@ -48,6 +61,8 @@ async def prepare_snapshot_engine(
f"SGLang engine loaded in {time.time() - start_time:.2f}s (checkpoint mode)" f"SGLang engine loaded in {time.time() - start_time:.2f}s (checkpoint mode)"
) )
_try_release_memory("after_engine_load")
snapshot_controller = EngineSnapshotController( snapshot_controller = EngineSnapshotController(
engine=engine, engine=engine,
quiesce_controller=SGLangEngineQuiesceController(engine), quiesce_controller=SGLangEngineQuiesceController(engine),
......
...@@ -197,7 +197,15 @@ def setup_metrics_collection( ...@@ -197,7 +197,15 @@ def setup_metrics_collection(
registry=DYNAMO_COMPONENT_REGISTRY, registry=DYNAMO_COMPONENT_REGISTRY,
) )
if os.environ.get("PROMETHEUS_MULTIPROC_DIR"): multiproc_dir = os.environ.get("PROMETHEUS_MULTIPROC_DIR")
# After CRIU restore to another node, env still has the checkpoint pod's path
# but that directory exists only on the checkpoint node; create it here if missing.
if multiproc_dir and not os.path.isdir(multiproc_dir):
try:
os.makedirs(multiproc_dir, exist_ok=True)
except OSError:
pass
if multiproc_dir and os.path.isdir(multiproc_dir):
try: try:
# MultiProcessCollector reads metrics from .db files in PROMETHEUS_MULTIPROC_DIR # MultiProcessCollector reads metrics from .db files in PROMETHEUS_MULTIPROC_DIR
# Adding it to REGISTRY allows collecting both in-memory and .db file metrics # Adding it to REGISTRY allows collecting both in-memory and .db file metrics
...@@ -243,6 +251,11 @@ def setup_metrics_collection( ...@@ -243,6 +251,11 @@ def setup_metrics_collection(
model_name=config.model, model_name=config.model,
) )
else: else:
if multiproc_dir:
logger.warning(
f"PROMETHEUS_MULTIPROC_DIR={multiproc_dir} is not a valid directory, "
"falling back to single-process metrics"
)
# No multiprocess mode # No multiprocess mode
register_engine_metrics_callback( register_engine_metrics_callback(
endpoint=generate_endpoint, endpoint=generate_endpoint,
...@@ -387,6 +400,12 @@ def setup_vllm_engine( ...@@ -387,6 +400,12 @@ def setup_vllm_engine(
# instead of .name string, causing false error on exit. Set PROMETHEUS_MULTIPROC_DIR # instead of .name string, causing false error on exit. Set PROMETHEUS_MULTIPROC_DIR
# ourselves to avoid this and handle cleanup properly. # ourselves to avoid this and handle cleanup properly.
prometheus_temp_dir = None prometheus_temp_dir = None
existing_dir = os.environ.get("PROMETHEUS_MULTIPROC_DIR")
if existing_dir and not os.path.isdir(existing_dir):
logger.warning(
f"PROMETHEUS_MULTIPROC_DIR={existing_dir} does not exist, recreating"
)
os.makedirs(existing_dir, exist_ok=True)
if "PROMETHEUS_MULTIPROC_DIR" not in os.environ: if "PROMETHEUS_MULTIPROC_DIR" not in os.environ:
prometheus_temp_dir = tempfile.TemporaryDirectory(prefix="vllm_prometheus_") prometheus_temp_dir = tempfile.TemporaryDirectory(prefix="vllm_prometheus_")
os.environ["PROMETHEUS_MULTIPROC_DIR"] = prometheus_temp_dir.name os.environ["PROMETHEUS_MULTIPROC_DIR"] = prometheus_temp_dir.name
......
...@@ -4,7 +4,11 @@ ...@@ -4,7 +4,11 @@
import logging import logging
from collections.abc import Callable from collections.abc import Callable
from dynamo.common.utils.snapshot import CheckpointConfig, EngineSnapshotController from dynamo.common.utils.snapshot import (
CheckpointConfig,
EngineSnapshotController,
_try_release_memory,
)
from .args import Config from .args import Config
from .handlers import VllmEngineQuiesceController from .handlers import VllmEngineQuiesceController
...@@ -32,6 +36,7 @@ async def prepare_snapshot_engine( ...@@ -32,6 +36,7 @@ async def prepare_snapshot_engine(
config.engine_args.enable_sleep_mode = True config.engine_args.enable_sleep_mode = True
engine = setup_vllm_engine(config) engine = setup_vllm_engine(config)
_try_release_memory("after_engine_load")
snapshot_controller = EngineSnapshotController( snapshot_controller = EngineSnapshotController(
engine=engine, engine=engine,
quiesce_controller=VllmEngineQuiesceController(engine[0]), quiesce_controller=VllmEngineQuiesceController(engine[0]),
......
...@@ -28,6 +28,10 @@ rules: ...@@ -28,6 +28,10 @@ rules:
- apiGroups: [""] - apiGroups: [""]
resources: ["events"] resources: ["events"]
verbs: ["create"] verbs: ["create"]
# Resolve DRA GPU UUIDs via ResourceClaim allocation (namespace-scoped)
- apiGroups: ["resource.k8s.io"]
resources: ["resourceclaims"]
verbs: ["get", "list"]
{{- else }} {{- else }}
apiVersion: rbac.authorization.k8s.io/v1 apiVersion: rbac.authorization.k8s.io/v1
kind: ClusterRole kind: ClusterRole
...@@ -53,5 +57,25 @@ rules: ...@@ -53,5 +57,25 @@ rules:
- apiGroups: [""] - apiGroups: [""]
resources: ["events"] resources: ["events"]
verbs: ["create"] verbs: ["create"]
# Resolve DRA GPU UUIDs via ResourceClaim and ResourceSlice
- apiGroups: ["resource.k8s.io"]
resources: ["resourceclaims", "resourceslices"]
verbs: ["get", "list"]
{{- end }} {{- end }}
{{- end }} {{- end }}
{{- if and .Values.rbac.create .Values.rbac.namespaceRestricted }}
---
# ResourceSlices are cluster-scoped; agent needs this when using a namespace-restricted Role
apiVersion: rbac.authorization.k8s.io/v1
kind: ClusterRole
metadata:
name: {{ include "snapshot.fullname" . }}-agent-resourceslices
labels:
{{- include "snapshot.labels" . | nindent 4 }}
app.kubernetes.io/component: checkpoint-agent
rules:
- apiGroups: ["resource.k8s.io"]
resources: ["resourceslices"]
verbs: ["get", "list"]
{{- end }}
...@@ -19,6 +19,23 @@ subjects: ...@@ -19,6 +19,23 @@ subjects:
- kind: ServiceAccount - kind: ServiceAccount
name: {{ include "snapshot.serviceAccountName" . }} name: {{ include "snapshot.serviceAccountName" . }}
namespace: {{ .Release.Namespace }} namespace: {{ .Release.Namespace }}
---
# Bind agent to ClusterRole for cluster-scoped ResourceSlices (DRA GPU UUID lookup)
apiVersion: rbac.authorization.k8s.io/v1
kind: ClusterRoleBinding
metadata:
name: {{ include "snapshot.fullname" . }}-agent-resourceslices
labels:
{{- include "snapshot.labels" . | nindent 4 }}
app.kubernetes.io/component: checkpoint-agent
roleRef:
apiGroup: rbac.authorization.k8s.io
kind: ClusterRole
name: {{ include "snapshot.fullname" . }}-agent-resourceslices
subjects:
- kind: ServiceAccount
name: {{ include "snapshot.serviceAccountName" . }}
namespace: {{ .Release.Namespace }}
{{- else }} {{- else }}
apiVersion: rbac.authorization.k8s.io/v1 apiVersion: rbac.authorization.k8s.io/v1
kind: ClusterRoleBinding kind: ClusterRoleBinding
...@@ -37,4 +54,3 @@ subjects: ...@@ -37,4 +54,3 @@ subjects:
namespace: {{ .Release.Namespace }} namespace: {{ .Release.Namespace }}
{{- end }} {{- end }}
{{- end }} {{- end }}
...@@ -414,6 +414,7 @@ func (w *NodeController) runCheckpoint(ctx context.Context, pod *corev1.Pod, job ...@@ -414,6 +414,7 @@ func (w *NodeController) runCheckpoint(ctx context.Context, pod *corev1.Pod, job
NodeName: w.config.NodeName, NodeName: w.config.NodeName,
PodName: pod.Name, PodName: pod.Name,
PodNamespace: pod.Namespace, PodNamespace: pod.Namespace,
Clientset: w.clientset,
} }
if err := executor.Checkpoint(leaseCtx, w.containerd, log, req, w.config); err != nil { if err := executor.Checkpoint(leaseCtx, w.containerd, log, req, w.config); err != nil {
if cause := context.Cause(leaseCtx); cause != nil && cause != context.Canceled { if cause := context.Cause(leaseCtx); cause != nil && cause != context.Canceled {
...@@ -512,6 +513,7 @@ func (w *NodeController) runRestore(ctx context.Context, pod *corev1.Pod, contai ...@@ -512,6 +513,7 @@ func (w *NodeController) runRestore(ctx context.Context, pod *corev1.Pod, contai
PodName: pod.Name, PodName: pod.Name,
PodNamespace: pod.Namespace, PodNamespace: pod.Namespace,
ContainerName: containerName, ContainerName: containerName,
Clientset: w.clientset,
} }
restoredPID, err := executor.Restore(ctx, w.containerd, log, req) restoredPID, err := executor.Restore(ctx, w.containerd, log, req)
if err != nil { if err != nil {
......
...@@ -5,6 +5,7 @@ import ( ...@@ -5,6 +5,7 @@ import (
"context" "context"
"fmt" "fmt"
"os/exec" "os/exec"
"regexp"
"strconv" "strconv"
"strings" "strings"
...@@ -15,13 +16,17 @@ import ( ...@@ -15,13 +16,17 @@ import (
podresourcesv1 "k8s.io/kubelet/pkg/apis/podresources/v1" podresourcesv1 "k8s.io/kubelet/pkg/apis/podresources/v1"
) )
const nvidiaGPUResource = "nvidia.com/gpu" const (
nvidiaGPUResource = "nvidia.com/gpu"
nvidiaGPUDRADriver = "gpu.nvidia.com"
)
var podResourcesSocketPath = "/var/lib/kubelet/pod-resources/kubelet.sock" var podResourcesSocketPath = "/var/lib/kubelet/pod-resources/kubelet.sock"
// GetPodGPUUUIDs resolves GPU UUIDs for a pod/container from the kubelet PodResources API. var gpuUUIDPattern = regexp.MustCompile(`^GPU-[a-fA-F0-9]{8}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{12}$`)
// All nvidia.com/gpu device entries are accumulated in case the kubelet splits them
// across multiple entries (observed in some runtimes with multi-GPU pods). // GetPodGPUUUIDs resolves GPU UUIDs for a pod/container from kubelet
// PodResources (nvidia.com/gpu entries in GetDevices()).
func GetPodGPUUUIDs(ctx context.Context, podName, podNamespace, containerName string) ([]string, error) { func GetPodGPUUUIDs(ctx context.Context, podName, podNamespace, containerName string) ([]string, error) {
if podName == "" || podNamespace == "" { if podName == "" || podNamespace == "" {
return nil, nil return nil, nil
...@@ -56,12 +61,40 @@ func GetPodGPUUUIDs(ctx context.Context, podName, podNamespace, containerName st ...@@ -56,12 +61,40 @@ func GetPodGPUUUIDs(ctx context.Context, podName, podNamespace, containerName st
uuids = append(uuids, device.GetDeviceIds()...) uuids = append(uuids, device.GetDeviceIds()...)
} }
} }
} }
} }
return uuids, nil return uuids, nil
} }
// GetGPUUUIDsViaNvidiaSmi discovers GPU UUIDs by running nvidia-smi inside the
// container's mount namespace. This is the fallback path when the kubelet
// PodResources API does not report GPU devices (e.g. when GPUs are allocated
// via DRA instead of the NVIDIA device plugin).
func GetGPUUUIDsViaNvidiaSmi(ctx context.Context, hostProcPath string, pid int) ([]string, error) {
mountPath := fmt.Sprintf("%s/%d/ns/mnt", strings.TrimRight(hostProcPath, "/"), pid)
cmd := exec.CommandContext(
ctx,
"nsenter",
fmt.Sprintf("--mount=%s", mountPath),
"--",
"nvidia-smi", "--query-gpu=gpu_uuid", "--format=csv,noheader",
)
output, err := cmd.Output()
if err != nil {
return nil, fmt.Errorf("nvidia-smi via nsenter (pid %d) failed: %w", pid, err)
}
var uuids []string
for _, line := range strings.Split(strings.TrimSpace(string(output)), "\n") {
line = strings.TrimSpace(line)
if line != "" {
uuids = append(uuids, line)
}
}
return uuids, nil
}
// FilterProcesses returns the subset of candidate PIDs that hold actual CUDA contexts. // FilterProcesses returns the subset of candidate PIDs that hold actual CUDA contexts.
// Uses --get-restore-tid (the same technique as the CRIU CUDA plugin) instead of // Uses --get-restore-tid (the same technique as the CRIU CUDA plugin) instead of
// --get-state, because --get-state incorrectly matches coordinator processes like // --get-state, because --get-state incorrectly matches coordinator processes like
...@@ -93,13 +126,14 @@ func FilterProcesses(ctx context.Context, allPIDs []int, log logr.Logger) []int ...@@ -93,13 +126,14 @@ func FilterProcesses(ctx context.Context, allPIDs []int, log logr.Logger) []int
// When a source UUID exists in the target set, it maps to itself (identity mapping) to avoid // When a source UUID exists in the target set, it maps to itself (identity mapping) to avoid
// unnecessary cross-GPU restore on same-node restores where kubelet returns GPUs in different order. // unnecessary cross-GPU restore on same-node restores where kubelet returns GPUs in different order.
// Remaining unmatched source UUIDs are paired with remaining unmatched target UUIDs positionally. // Remaining unmatched source UUIDs are paired with remaining unmatched target UUIDs positionally.
func BuildDeviceMap(sourceUUIDs, targetUUIDs []string) (string, error) { func BuildDeviceMap(sourceUUIDs, targetUUIDs []string, log logr.Logger) (string, error) {
if len(sourceUUIDs) != len(targetUUIDs) { if len(sourceUUIDs) != len(targetUUIDs) {
return "", fmt.Errorf("GPU count mismatch: source has %d, target has %d", len(sourceUUIDs), len(targetUUIDs)) return "", fmt.Errorf("GPU count mismatch: source has %d, target has %d", len(sourceUUIDs), len(targetUUIDs))
} }
if len(sourceUUIDs) == 0 { if len(sourceUUIDs) == 0 {
return "", fmt.Errorf("GPU UUID list is empty") return "", fmt.Errorf("GPU UUID list is empty")
} }
log.V(1).Info("BuildDeviceMap inputs", "source_uuids", sourceUUIDs, "target_uuids", targetUUIDs)
targetSet := make(map[string]bool, len(targetUUIDs)) targetSet := make(map[string]bool, len(targetUUIDs))
for _, t := range targetUUIDs { for _, t := range targetUUIDs {
......
...@@ -9,6 +9,7 @@ import ( ...@@ -9,6 +9,7 @@ import (
"testing" "testing"
"time" "time"
"github.com/go-logr/logr"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
...@@ -58,7 +59,7 @@ func TestBuildDeviceMap(t *testing.T) { ...@@ -58,7 +59,7 @@ func TestBuildDeviceMap(t *testing.T) {
for _, tc := range tests { for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
got, err := BuildDeviceMap(tc.source, tc.target) got, err := BuildDeviceMap(tc.source, tc.target, logr.Discard())
if tc.wantErr { if tc.wantErr {
if err == nil { if err == nil {
t.Errorf("expected error, got %q", got) t.Errorf("expected error, got %q", got)
...@@ -176,7 +177,7 @@ func TestGetPodGPUUUIDs(t *testing.T) { ...@@ -176,7 +177,7 @@ func TestGetPodGPUUUIDs(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel() defer cancel()
got, err := GetPodGPUUUIDs(ctx, "test-pod", "default", "main") got, err := GetPodGPUUUIDs(ctx, nil, "test-pod", "default", "main", logr.Discard())
if err != nil { if err != nil {
t.Fatalf("GetPodGPUUUIDs: %v", err) t.Fatalf("GetPodGPUUUIDs: %v", err)
} }
......
package cuda
import (
"context"
"fmt"
"github.com/go-logr/logr"
resourcev1 "k8s.io/api/resource/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/client-go/kubernetes"
)
const (
resourceAttributeUUID = "uuid"
)
// GetGPUUUIDsViaDRAAPI resolves the GPU UUIDs allocated to a pod through
// Dynamic Resource Allocation by walking the Kubernetes API:
//
//	Pod (resource claim refs) -> ResourceClaim (allocation results) -> ResourceSlice (device attributes).
//
// Both kinds of pod claim reference are followed: claims named directly in
// spec.resourceClaims[].resourceClaimName, and claims generated from a
// ResourceClaimTemplate, whose generated ResourceClaim name is published in
// status.resourceClaimStatuses under the pod-local claim name.
//
// Only allocation results whose driver is gpu.nvidia.com are considered, and
// only ResourceSlices published by that driver for the pod's node are
// consulted. UUIDs are returned in allocation order. Allocated devices that
// cannot be matched to a "uuid" attribute accepted by gpuUUIDPattern are
// skipped with a V(1) log entry.
//
// Returns (nil, nil) when clientset is nil, when podName or podNamespace is
// empty, when the pod declares no resource claims or has no node assigned, or
// when none of its claims hold a gpu.nvidia.com allocation. Failures of the
// pod, claim, or slice API calls are returned as wrapped errors.
func GetGPUUUIDsViaDRAAPI(ctx context.Context, clientset kubernetes.Interface, podName, podNamespace string, log logr.Logger) ([]string, error) {
	if clientset == nil {
		return nil, nil
	}
	if podName == "" || podNamespace == "" {
		return nil, nil
	}

	pod, err := clientset.CoreV1().Pods(podNamespace).Get(ctx, podName, metav1.GetOptions{})
	if err != nil {
		return nil, fmt.Errorf("get pod %s/%s: %w", podNamespace, podName, err)
	}
	if len(pod.Spec.ResourceClaims) == 0 {
		return nil, nil
	}
	nodeName := pod.Spec.NodeName
	if nodeName == "" {
		log.V(1).Info("pod has no node name, skipping DRA API lookup")
		return nil, nil
	}

	// Claims created from a ResourceClaimTemplate carry no ResourceClaim name
	// in the pod spec; the generated name is reported in the pod status,
	// keyed by the pod-local claim name.
	generatedClaimNames := make(map[string]string, len(pod.Status.ResourceClaimStatuses))
	for _, st := range pod.Status.ResourceClaimStatuses {
		if st.ResourceClaimName != nil && *st.ResourceClaimName != "" {
			generatedClaimNames[st.Name] = *st.ResourceClaimName
		}
	}

	// allocatedDevice identifies one allocated GPU by its pool and device name.
	type allocatedDevice struct {
		pool   string
		device string
	}
	var allocated []allocatedDevice
	for _, ref := range pod.Spec.ResourceClaims {
		claimName := ""
		if ref.ResourceClaimName != nil {
			claimName = *ref.ResourceClaimName
		}
		if claimName == "" {
			claimName = generatedClaimNames[ref.Name]
		}
		if claimName == "" {
			log.V(1).Info("pod resource claim has no resolvable ResourceClaim name, skipping", "claim", ref.Name)
			continue
		}

		claim, err := clientset.ResourceV1().ResourceClaims(podNamespace).Get(ctx, claimName, metav1.GetOptions{})
		if err != nil {
			return nil, fmt.Errorf("get resource claim %s/%s: %w", podNamespace, claimName, err)
		}
		if claim.Status.Allocation == nil {
			// Not allocated (yet): nothing to resolve for this claim.
			continue
		}
		for _, r := range claim.Status.Allocation.Devices.Results {
			if r.Driver == nvidiaGPUDRADriver {
				allocated = append(allocated, allocatedDevice{pool: r.Pool, device: r.Device})
			}
		}
	}
	if len(allocated) == 0 {
		return nil, nil
	}

	// Restrict the (cluster-scoped) slice listing to this driver and node.
	slices, err := clientset.ResourceV1().ResourceSlices().List(ctx, metav1.ListOptions{
		FieldSelector: fmt.Sprintf("spec.driver=%s,spec.nodeName=%s", nvidiaGPUDRADriver, nodeName),
	})
	if err != nil {
		return nil, fmt.Errorf("list resource slices for node %s: %w", nodeName, err)
	}

	// pool name -> device name -> GPU UUID, merged across all slices of a pool.
	poolDeviceToUUID := make(map[string]map[string]string)
	for i := range slices.Items {
		s := &slices.Items[i]
		poolName := s.Spec.Pool.Name
		if poolDeviceToUUID[poolName] == nil {
			poolDeviceToUUID[poolName] = make(map[string]string)
		}
		for _, dev := range s.Spec.Devices {
			// The pattern cannot match an empty string, so a missing
			// attribute is rejected here as well.
			if uuid := deviceUUIDFromAttributes(dev.Attributes); gpuUUIDPattern.MatchString(uuid) {
				poolDeviceToUUID[poolName][dev.Name] = uuid
			}
		}
	}

	var uuids []string
	for _, a := range allocated {
		devMap := poolDeviceToUUID[a.pool]
		if devMap == nil {
			log.V(1).Info("no ResourceSlice found for pool", "pool", a.pool, "device", a.device)
			continue
		}
		uuid, ok := devMap[a.device]
		if !ok || uuid == "" {
			log.V(1).Info("device has no UUID in ResourceSlice", "pool", a.pool, "device", a.device)
			continue
		}
		uuids = append(uuids, uuid)
	}
	if len(uuids) > 0 {
		log.Info("resolved GPU UUIDs via DRA API", "uuids", uuids)
	}
	return uuids, nil
}
// deviceUUIDFromAttributes returns the string value stored under the "uuid"
// attribute key, or "" when the key is absent or carries no string value.
// Indexing a nil map is safe, so a nil attrs yields "".
func deviceUUIDFromAttributes(attrs map[resourcev1.QualifiedName]resourcev1.DeviceAttribute) string {
	if attr, found := attrs[resourcev1.QualifiedName(resourceAttributeUUID)]; found && attr.StringValue != nil {
		return *attr.StringValue
	}
	return ""
}
package cuda
import (
"context"
"testing"
"github.com/go-logr/logr"
corev1 "k8s.io/api/core/v1"
resourcev1 "k8s.io/api/resource/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/client-go/kubernetes/fake"
)
// TestDeviceUUIDFromAttributes checks that deviceUUIDFromAttributes returns
// the "uuid" string attribute when present, and "" for a nil map, an empty
// map, or a map that only holds other attributes.
func TestDeviceUUIDFromAttributes(t *testing.T) {
	type attrMap = map[resourcev1.QualifiedName]resourcev1.DeviceAttribute

	uuidVal := "GPU-f8ddcf75-4014-85da-28da-9dc4de19d997"

	// Omitted fields rely on zero values: nil attrs and "" expected.
	cases := []struct {
		name     string
		attrs    attrMap
		expected string
	}{
		{name: "nil map"},
		{name: "empty map", attrs: attrMap{}},
		{
			name:     "uuid present",
			attrs:    attrMap{"uuid": {StringValue: &uuidVal}},
			expected: uuidVal,
		},
		{
			name:  "uuid missing, other attr present",
			attrs: attrMap{"productName": {StringValue: ptr("NVIDIA A100")}},
		},
	}

	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			if got := deviceUUIDFromAttributes(c.attrs); got != c.expected {
				t.Errorf("deviceUUIDFromAttributes() = %q, want %q", got, c.expected)
			}
		})
	}
}
// ptr returns a pointer to its own copy of s, for building fixtures whose
// fields are *string.
func ptr(s string) *string {
	return &s
}
// TestGetGPUUUIDsViaDRAAPI exercises GetGPUUUIDsViaDRAAPI against a fake
// clientset: the short-circuit paths that return nil without error, the error
// path for a missing pod, and the full Pod -> ResourceClaim -> ResourceSlice
// resolution.
func TestGetGPUUUIDsViaDRAAPI(t *testing.T) {
	ctx := context.Background()
	log := logr.Discard()

	// requireNilResult asserts the "nothing to resolve" outcome: no error and
	// a nil slice.
	requireNilResult := func(t *testing.T, got []string, err error) {
		t.Helper()
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		if got != nil {
			t.Errorf("got %v, want nil", got)
		}
	}

	t.Run("nil clientset returns nil without error", func(t *testing.T) {
		got, err := GetGPUUUIDsViaDRAAPI(ctx, nil, "pod", "ns", log)
		requireNilResult(t, got, err)
	})

	t.Run("empty pod name returns nil", func(t *testing.T) {
		got, err := GetGPUUUIDsViaDRAAPI(ctx, fake.NewSimpleClientset(), "", "ns", log)
		requireNilResult(t, got, err)
	})

	t.Run("pod not found returns error", func(t *testing.T) {
		if _, err := GetGPUUUIDsViaDRAAPI(ctx, fake.NewSimpleClientset(), "missing", "default", log); err == nil {
			t.Fatal("expected error when pod not found")
		}
	})

	t.Run("pod with DRA claims resolves UUIDs", func(t *testing.T) {
		const (
			node      = "node-1"
			pool      = "pool-node-1"
			claimRef  = "gpu-claim"
			ns        = "default"
			name      = "test-pod"
			firstGPU  = "GPU-aaaaaaaa-1111-2222-3333-444444444444"
			secondGPU = "GPU-bbbbbbbb-5555-6666-7777-888888888888"
		)

		// gpuDevice builds a ResourceSlice device exposing only a "uuid" attribute.
		gpuDevice := func(devName, uuid string) resourcev1.Device {
			return resourcev1.Device{
				Name: devName,
				Attributes: map[resourcev1.QualifiedName]resourcev1.DeviceAttribute{
					resourcev1.QualifiedName("uuid"): {StringValue: ptr(uuid)},
				},
			}
		}

		pod := &corev1.Pod{
			ObjectMeta: metav1.ObjectMeta{Name: name, Namespace: ns},
			Spec: corev1.PodSpec{
				NodeName: node,
				ResourceClaims: []corev1.PodResourceClaim{
					{Name: "gpu", ResourceClaimName: ptr(claimRef)},
				},
			},
		}
		claim := &resourcev1.ResourceClaim{
			ObjectMeta: metav1.ObjectMeta{Name: claimRef, Namespace: ns},
			Status: resourcev1.ResourceClaimStatus{
				Allocation: &resourcev1.AllocationResult{
					Devices: resourcev1.DeviceAllocationResult{
						Results: []resourcev1.DeviceRequestAllocationResult{
							{Driver: nvidiaGPUDRADriver, Pool: pool, Device: "gpu-0", Request: "gpu"},
							{Driver: nvidiaGPUDRADriver, Pool: pool, Device: "gpu-1", Request: "gpu"},
						},
					},
				},
			},
		}
		slice := &resourcev1.ResourceSlice{
			ObjectMeta: metav1.ObjectMeta{Name: pool + "-gpu.nvidia.com-xxx"},
			Spec: resourcev1.ResourceSliceSpec{
				Driver:   nvidiaGPUDRADriver,
				NodeName: ptr(node),
				Pool:     resourcev1.ResourcePool{Name: pool},
				Devices: []resourcev1.Device{
					gpuDevice("gpu-0", firstGPU),
					gpuDevice("gpu-1", secondGPU),
				},
			},
		}

		got, err := GetGPUUUIDsViaDRAAPI(ctx, fake.NewSimpleClientset(pod, claim, slice), name, ns, log)
		if err != nil {
			t.Fatalf("GetGPUUUIDsViaDRAAPI: %v", err)
		}

		// UUIDs are expected in the order of the claim's allocation results.
		want := []string{firstGPU, secondGPU}
		if len(got) != len(want) {
			t.Fatalf("got %v (len %d), want %v (len %d)", got, len(got), want, len(want))
		}
		for i, expected := range want {
			if got[i] != expected {
				t.Errorf("got[%d] = %q, want %q", i, got[i], expected)
			}
		}
	})

	t.Run("pod with no resource claims returns nil", func(t *testing.T) {
		pod := &corev1.Pod{
			ObjectMeta: metav1.ObjectMeta{Name: "pod", Namespace: "default"},
			Spec:       corev1.PodSpec{NodeName: "node-1"},
		}
		got, err := GetGPUUUIDsViaDRAAPI(ctx, fake.NewSimpleClientset(pod), "pod", "default", log)
		requireNilResult(t, got, err)
	})
}
...@@ -12,6 +12,7 @@ import ( ...@@ -12,6 +12,7 @@ import (
criurpc "github.com/checkpoint-restore/go-criu/v8/rpc" criurpc "github.com/checkpoint-restore/go-criu/v8/rpc"
"github.com/containerd/containerd" "github.com/containerd/containerd"
"github.com/go-logr/logr" "github.com/go-logr/logr"
"k8s.io/client-go/kubernetes"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/ai-dynamo/dynamo/deploy/snapshot/pkg/common" "github.com/ai-dynamo/dynamo/deploy/snapshot/pkg/common"
...@@ -30,6 +31,7 @@ type CheckpointRequest struct { ...@@ -30,6 +31,7 @@ type CheckpointRequest struct {
NodeName string NodeName string
PodName string PodName string
PodNamespace string PodNamespace string
Clientset kubernetes.Interface
} }
// Checkpoint performs a CRIU dump of a container. // Checkpoint performs a CRIU dump of a container.
...@@ -162,6 +164,14 @@ func inspectContainer(ctx context.Context, ctrd *containerd.Client, log logr.Log ...@@ -162,6 +164,14 @@ func inspectContainer(ctx context.Context, ctrd *containerd.Client, log logr.Log
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to discover source GPU UUIDs: %w", err) return nil, fmt.Errorf("failed to discover source GPU UUIDs: %w", err)
} }
if len(gpuUUIDs) == 0 {
log.Info("PodResources API returned no GPU UUIDs, falling back to nvidia-smi", "pid", pid)
gpuUUIDs, err = cuda.GetGPUUUIDsViaNvidiaSmi(ctx, common.HostProcPath, pid)
if err != nil {
return nil, fmt.Errorf("nvidia-smi GPU UUID fallback failed: %w", err)
}
log.Info("nvidia-smi fallback discovered GPU UUIDs", "uuids", gpuUUIDs)
}
} }
return &types.CheckpointContainerSnapshot{ return &types.CheckpointContainerSnapshot{
......
...@@ -14,6 +14,7 @@ import ( ...@@ -14,6 +14,7 @@ import (
"github.com/containerd/containerd" "github.com/containerd/containerd"
"github.com/go-logr/logr" "github.com/go-logr/logr"
"k8s.io/client-go/kubernetes"
"github.com/ai-dynamo/dynamo/deploy/snapshot/pkg/common" "github.com/ai-dynamo/dynamo/deploy/snapshot/pkg/common"
"github.com/ai-dynamo/dynamo/deploy/snapshot/pkg/criu" "github.com/ai-dynamo/dynamo/deploy/snapshot/pkg/criu"
...@@ -31,6 +32,7 @@ type RestoreRequest struct { ...@@ -31,6 +32,7 @@ type RestoreRequest struct {
PodName string PodName string
PodNamespace string PodNamespace string
ContainerName string ContainerName string
Clientset kubernetes.Interface
} }
// Restore performs external restore for the given request. // Restore performs external restore for the given request.
...@@ -124,13 +126,26 @@ func inspectRestore(ctx context.Context, ctrd *containerd.Client, log logr.Logge ...@@ -124,13 +126,26 @@ func inspectRestore(ctx context.Context, ctrd *containerd.Client, log logr.Logge
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get target GPU UUIDs: %w", err) return nil, fmt.Errorf("failed to get target GPU UUIDs: %w", err)
} }
if len(targetGPUUUIDs) == 0 {
log.Info("PodResources API returned no target GPU UUIDs, falling back to nvidia-smi", "pid", placeholderPID)
targetGPUUUIDs, err = cuda.GetGPUUUIDsViaNvidiaSmi(ctx, common.HostProcPath, placeholderPID)
if err != nil {
return nil, fmt.Errorf("nvidia-smi GPU UUID fallback failed for restore target: %w", err)
}
log.Info("nvidia-smi fallback discovered target GPU UUIDs", "uuids", targetGPUUUIDs)
}
if len(targetGPUUUIDs) == 0 { if len(targetGPUUUIDs) == 0 {
return nil, fmt.Errorf("missing target GPU UUIDs for %s/%s container %s", req.PodNamespace, req.PodName, containerName) return nil, fmt.Errorf("missing target GPU UUIDs for %s/%s container %s", req.PodNamespace, req.PodName, containerName)
} }
cudaDeviceMap, err = cuda.BuildDeviceMap(m.CUDA.SourceGPUUUIDs, targetGPUUUIDs) cudaDeviceMap, err = cuda.BuildDeviceMap(m.CUDA.SourceGPUUUIDs, targetGPUUUIDs, log)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to build CUDA device map: %w", err) return nil, fmt.Errorf("failed to build CUDA device map: %w", err)
} }
log.Info("GPU UUIDs for device map",
"source_uuids", m.CUDA.SourceGPUUUIDs,
"target_uuids", targetGPUUUIDs,
"device_map", cudaDeviceMap,
)
} }
return &types.RestoreContainerSnapshot{ return &types.RestoreContainerSnapshot{
......
...@@ -205,7 +205,12 @@ def configure_sglang_logging(dyn_level: int) -> None: ...@@ -205,7 +205,12 @@ def configure_sglang_logging(dyn_level: int) -> None:
"handlers": ["dynamo"], "handlers": ["dynamo"],
"level": sglang_level, "level": sglang_level,
"propagate": False, "propagate": False,
} },
"gpu_memory_service": {
"handlers": ["dynamo"],
"level": sglang_level,
"propagate": False,
},
}, },
"version": 1, "version": 1,
"disable_existing_loggers": False, "disable_existing_loggers": False,
...@@ -260,7 +265,12 @@ def configure_vllm_logging(dyn_level: int) -> None: ...@@ -260,7 +265,12 @@ def configure_vllm_logging(dyn_level: int) -> None:
"handlers": ["vllm_stderr"], "handlers": ["vllm_stderr"],
"level": vllm_level, "level": vllm_level,
"propagate": False, "propagate": False,
} },
"gpu_memory_service": {
"handlers": ["vllm_stderr"],
"level": vllm_level,
"propagate": False,
},
}, },
"version": 1, "version": 1,
"disable_existing_loggers": False, "disable_existing_loggers": False,
......
...@@ -23,6 +23,12 @@ logger = logging.getLogger(__name__) ...@@ -23,6 +23,12 @@ logger = logging.getLogger(__name__)
# Module-level GMS lock mode, set by setup_gms() before loader is instantiated. # Module-level GMS lock mode, set by setup_gms() before loader is instantiated.
# Read by patches.py when creating GMSMemorySaverImpl. # Read by patches.py when creating GMSMemorySaverImpl.
_gms_lock_mode = None _gms_lock_mode = None
_gms_initialized = False
def is_gms_active() -> bool:
    """Report whether GMS has been set up in this module.

    Returns the module-level ``_gms_initialized`` flag, which starts out
    ``False`` and is set to ``True`` by ``setup_gms()``.
    """
    return _gms_initialized
def setup_gms(server_args) -> Type["GMSModelLoader"]: def setup_gms(server_args) -> Type["GMSModelLoader"]:
...@@ -66,5 +72,8 @@ def setup_gms(server_args) -> Type["GMSModelLoader"]: ...@@ -66,5 +72,8 @@ def setup_gms(server_args) -> Type["GMSModelLoader"]:
# Import triggers patches at module level # Import triggers patches at module level
from gpu_memory_service.integrations.sglang.model_loader import GMSModelLoader from gpu_memory_service.integrations.sglang.model_loader import GMSModelLoader
global _gms_initialized
_gms_initialized = True
logger.info("[GMS] Using GMSModelLoader...") logger.info("[GMS] Using GMSModelLoader...")
return GMSModelLoader return GMSModelLoader
...@@ -21,6 +21,7 @@ from gpu_memory_service.integrations.common import patch_empty_cache ...@@ -21,6 +21,7 @@ from gpu_memory_service.integrations.common import patch_empty_cache
from gpu_memory_service.integrations.common.utils import setup_meta_tensor_workaround from gpu_memory_service.integrations.common.utils import setup_meta_tensor_workaround
from gpu_memory_service.integrations.sglang.patches import ( from gpu_memory_service.integrations.sglang.patches import (
patch_model_runner, patch_model_runner,
patch_static_state_for_gms,
patch_torch_memory_saver, patch_torch_memory_saver,
) )
...@@ -28,9 +29,13 @@ logger = logging.getLogger(__name__) ...@@ -28,9 +29,13 @@ logger = logging.getLogger(__name__)
# Apply patches at module import time. # Apply patches at module import time.
# This module is only imported when load_format="gms" is used. # This module is only imported when load_format="gms" is used.
# Because SGLang scheduler processes use multiprocessing spawn, these patches
# must run inside the child process. The import chain that triggers this is:
# child unpickles server_args.load_format -> imports GMSModelLoader -> here.
patch_empty_cache() patch_empty_cache()
patch_torch_memory_saver() patch_torch_memory_saver()
patch_model_runner() patch_model_runner()
patch_static_state_for_gms()
logger.info("[GMS] Applied patches") logger.info("[GMS] Applied patches")
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
- patch_torch_memory_saver: Routes to GMS hybrid implementation - patch_torch_memory_saver: Routes to GMS hybrid implementation
- patch_model_runner: Fixes memory accounting with pre-loaded weights - patch_model_runner: Fixes memory accounting with pre-loaded weights
- patch_static_state_for_gms: No-ops named-buffer export/import (GMS preserves them)
""" """
from __future__ import annotations from __future__ import annotations
...@@ -19,6 +20,7 @@ logger = logging.getLogger(__name__) ...@@ -19,6 +20,7 @@ logger = logging.getLogger(__name__)
_torch_memory_saver_patched = False _torch_memory_saver_patched = False
_model_runner_patched = False _model_runner_patched = False
_static_state_patched = False
def patch_torch_memory_saver() -> None: def patch_torch_memory_saver() -> None:
...@@ -105,6 +107,24 @@ def patch_torch_memory_saver() -> None: ...@@ -105,6 +107,24 @@ def patch_torch_memory_saver() -> None:
entrypoint_module.TorchMemorySaver.gms_impl = gms_impl entrypoint_module.TorchMemorySaver.gms_impl = gms_impl
# If the singleton was already initialized before this patch ran (e.g.,
# due to import ordering in multiprocessing spawn), reset _impl so the
# next call to _ensure_initialized goes through the patched version and
# creates GMSMemorySaverImpl instead of the default _TorchMemorySaverImpl.
import torch_memory_saver
singleton = torch_memory_saver.torch_memory_saver
if singleton._impl is not None:
logger.debug(
"[GMS] TorchMemorySaver singleton already initialized, "
"resetting to force GMS re-init on next use"
)
singleton._impl = None
# The original _ensure_initialized deletes _impl_ctor_kwargs after
# creating _impl. Restore it so the patched version can read it.
if not hasattr(singleton, "_impl_ctor_kwargs"):
singleton._impl_ctor_kwargs = {}
_torch_memory_saver_patched = True _torch_memory_saver_patched = True
logger.debug("[GMS] Patched torch_memory_saver") logger.debug("[GMS] Patched torch_memory_saver")
...@@ -158,3 +178,49 @@ def patch_model_runner() -> None: ...@@ -158,3 +178,49 @@ def patch_model_runner() -> None:
ModelRunner._gms_patched = True ModelRunner._gms_patched = True
_model_runner_patched = True _model_runner_patched = True
logger.info("[GMS] Patched ModelRunner.init_memory_pool") logger.info("[GMS] Patched ModelRunner.init_memory_pool")
def patch_static_state_for_gms() -> None:
    """Replace SGLang's static-state export/import hooks with no-ops for GMS.

    SGLang's release_memory_occupation clones every named buffer
    (buffer.detach().clone()) through the default CUDA allocator and restores
    the clones in resume_memory_occupation. With GMS the buffers are preserved
    via VA-stable unmap/remap, so both hooks are swapped for no-ops.

    The patch has to be applied inside the scheduler child process (started
    with multiprocessing spawn); it is triggered by the module-level
    GMSModelLoader import in model_loader.py, which executes in the child.

    The call is idempotent: once the module-level ``_static_state_patched``
    flag is set, later calls only log and return. A failure to import the
    SGLang mixin module is logged as a warning and otherwise ignored.
    """
    import os

    global _static_state_patched

    current_pid = os.getpid()
    logger.info(
        "[GMS] patch_static_state_for_gms called (pid=%d, already_patched=%s)",
        current_pid,
        _static_state_patched,
    )

    # Guard: nothing to do when a previous call already installed the no-ops.
    if _static_state_patched:
        return

    try:
        from sglang.srt.managers import (
            scheduler_update_weights_mixin as mixin_module,
        )
    except Exception:
        # Best effort: leave SGLang unpatched but record why.
        logger.warning(
            "[GMS] Could not patch scheduler_update_weights_mixin: ",
            exc_info=True,
        )
        return

    def _export_noop(model):
        """Skip the export: GMS keeps buffers via VA-stable unmap/remap."""
        return {"buffers": []}

    def _import_noop(model, static_params):
        """Skip the import: GMS keeps buffers via VA-stable unmap/remap."""
        return None

    mixin_module._export_static_state = _export_noop
    mixin_module._import_static_state = _import_noop
    _static_state_patched = True

    logger.info(
        "[GMS] Patched _export/_import_static_state -> no-op (pid=%d)",
        current_pid,
    )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment