Unverified Commit b2aefc53 authored by Janelle Cai's avatar Janelle Cai Committed by GitHub
Browse files

chore(snapshot): gms + snapshot support (#7026)


Signed-off-by: Janelle Cai <jcai18@mit.edu>
parent 9616c86f
......@@ -4,6 +4,9 @@
"""Shared Dynamo snapshot helpers for checkpoint lifecycle."""
import asyncio
import ctypes
import ctypes.util
import gc
import logging
import os
import signal
......@@ -234,3 +237,48 @@ def reload_snapshot_restore_identity() -> tuple[str, str]:
# Snapshot restore only runs in Kubernetes-managed pods, so discovery resets here.
os.environ["DYN_DISCOVERY_BACKEND"] = "kubernetes"
return get_worker_namespace(), "kubernetes"
def _try_release_memory(label: str) -> None:
    """Run a full GC pass, then ask libc to trim its heap, and log the effect.

    The resident set size (RSS) of the current process is sampled three
    times -- before anything happens, after ``gc.collect()``, and after
    ``malloc_trim(0)`` -- and a single INFO line tagged with ``label``
    reports all three values plus the total reclaimed amount.

    Every step is best-effort: an unreadable ``/proc`` entry yields an RSS
    of 0, and a failing ``malloc_trim`` call is only logged at DEBUG level.

    Args:
        label: Short tag embedded in the log messages to identify the call site.
    """
    pid = os.getpid()

    def _read_rss_kb() -> int:
        # Returns the integer on the VmRSS line of /proc/<pid>/status,
        # or 0 when the line is absent or the file cannot be read/parsed.
        try:
            with open(f"/proc/{pid}/status") as status_file:
                for entry in status_file:
                    if entry.startswith("VmRSS:"):
                        return int(entry.split()[1])
        except Exception:
            pass
        return 0

    rss_start_kb = _read_rss_kb()
    n_collected = gc.collect()
    rss_post_gc_kb = _read_rss_kb()

    try:
        libc_path = ctypes.util.find_library("c")
        if libc_path:
            # malloc_trim(0) asks the allocator to hand free heap pages back to the OS.
            ctypes.CDLL(libc_path).malloc_trim(0)
    except Exception as exc:
        logger.debug("[MemRelease:%s] malloc_trim failed: %s", label, exc)

    rss_post_trim_kb = _read_rss_kb()

    # Values are divided by 1024 before being printed as MiB.
    start_mib = rss_start_kb / 1024
    post_gc_mib = rss_post_gc_kb / 1024
    post_trim_mib = rss_post_trim_kb / 1024
    reclaimed_mib = (rss_start_kb - rss_post_trim_kb) / 1024

    logger.info(
        "[MemRelease:%s] gc.collect freed %d objects, "
        "RSS: %.2f MiB -> %.2f MiB (gc) -> %.2f MiB (malloc_trim), "
        "reclaimed=%.2f MiB",
        label,
        n_collected,
        start_mib,
        post_gc_mib,
        post_trim_mib,
        reclaimed_mib,
    )
......@@ -3,12 +3,17 @@
"""Dynamo Snapshot integration for SGLang workers."""
import logging
import time
import sglang as sgl
from dynamo.common.utils.snapshot import CheckpointConfig, EngineSnapshotController
from dynamo.common.utils.snapshot import (
CheckpointConfig,
EngineSnapshotController,
_try_release_memory,
)
from .request_handlers.handler_base import SGLangEngineQuiesceController
......@@ -37,10 +42,18 @@ async def prepare_snapshot_engine(
logger.info("Checkpoint mode enabled (watcher-driven signals)")
# Enable memory_saver + weights CPU backup so weights survive CRIU
# (mirrors vLLM's enable_sleep_mode = True)
# Enable memory_saver so GPU memory can be released for CRIU.
# When using GMS, weights use VA-stable unmap/remap (no CPU backup); GMS
# forbids enable_weights_cpu_backup. Otherwise use CPU backup for weights.
server_args.enable_memory_saver = True
server_args.enable_weights_cpu_backup = True
try:
from gpu_memory_service.integrations.sglang import is_gms_active
_using_gms = is_gms_active()
except ImportError:
_using_gms = False
if not _using_gms:
server_args.enable_weights_cpu_backup = True
start_time = time.time()
engine = sgl.Engine(server_args=server_args)
......@@ -48,6 +61,8 @@ async def prepare_snapshot_engine(
f"SGLang engine loaded in {time.time() - start_time:.2f}s (checkpoint mode)"
)
_try_release_memory("after_engine_load")
snapshot_controller = EngineSnapshotController(
engine=engine,
quiesce_controller=SGLangEngineQuiesceController(engine),
......
......@@ -197,7 +197,15 @@ def setup_metrics_collection(
registry=DYNAMO_COMPONENT_REGISTRY,
)
if os.environ.get("PROMETHEUS_MULTIPROC_DIR"):
multiproc_dir = os.environ.get("PROMETHEUS_MULTIPROC_DIR")
# After CRIU restore to another node, env still has the checkpoint pod's path
# but that directory exists only on the checkpoint node; create it here if missing.
if multiproc_dir and not os.path.isdir(multiproc_dir):
try:
os.makedirs(multiproc_dir, exist_ok=True)
except OSError:
pass
if multiproc_dir and os.path.isdir(multiproc_dir):
try:
# MultiProcessCollector reads metrics from .db files in PROMETHEUS_MULTIPROC_DIR
# Adding it to REGISTRY allows collecting both in-memory and .db file metrics
......@@ -243,6 +251,11 @@ def setup_metrics_collection(
model_name=config.model,
)
else:
if multiproc_dir:
logger.warning(
f"PROMETHEUS_MULTIPROC_DIR={multiproc_dir} is not a valid directory, "
"falling back to single-process metrics"
)
# No multiprocess mode
register_engine_metrics_callback(
endpoint=generate_endpoint,
......@@ -387,6 +400,12 @@ def setup_vllm_engine(
# instead of .name string, causing false error on exit. Set PROMETHEUS_MULTIPROC_DIR
# ourselves to avoid this and handle cleanup properly.
prometheus_temp_dir = None
existing_dir = os.environ.get("PROMETHEUS_MULTIPROC_DIR")
if existing_dir and not os.path.isdir(existing_dir):
logger.warning(
f"PROMETHEUS_MULTIPROC_DIR={existing_dir} does not exist, recreating"
)
os.makedirs(existing_dir, exist_ok=True)
if "PROMETHEUS_MULTIPROC_DIR" not in os.environ:
prometheus_temp_dir = tempfile.TemporaryDirectory(prefix="vllm_prometheus_")
os.environ["PROMETHEUS_MULTIPROC_DIR"] = prometheus_temp_dir.name
......
......@@ -4,7 +4,11 @@
import logging
from collections.abc import Callable
from dynamo.common.utils.snapshot import CheckpointConfig, EngineSnapshotController
from dynamo.common.utils.snapshot import (
CheckpointConfig,
EngineSnapshotController,
_try_release_memory,
)
from .args import Config
from .handlers import VllmEngineQuiesceController
......@@ -32,6 +36,7 @@ async def prepare_snapshot_engine(
config.engine_args.enable_sleep_mode = True
engine = setup_vllm_engine(config)
_try_release_memory("after_engine_load")
snapshot_controller = EngineSnapshotController(
engine=engine,
quiesce_controller=VllmEngineQuiesceController(engine[0]),
......
......@@ -28,6 +28,10 @@ rules:
- apiGroups: [""]
resources: ["events"]
verbs: ["create"]
# Resolve DRA GPU UUIDs via ResourceClaim allocation (namespace-scoped)
- apiGroups: ["resource.k8s.io"]
resources: ["resourceclaims"]
verbs: ["get", "list"]
{{- else }}
apiVersion: rbac.authorization.k8s.io/v1
kind: ClusterRole
......@@ -53,5 +57,25 @@ rules:
- apiGroups: [""]
resources: ["events"]
verbs: ["create"]
# Resolve DRA GPU UUIDs via ResourceClaim and ResourceSlice
- apiGroups: ["resource.k8s.io"]
resources: ["resourceclaims", "resourceslices"]
verbs: ["get", "list"]
{{- end }}
{{- end }}
{{- if and .Values.rbac.create .Values.rbac.namespaceRestricted }}
---
# ResourceSlices are cluster-scoped; agent needs this when using a namespace-restricted Role
# Rendered only when both rbac.create and rbac.namespaceRestricted are set.
# Grants read-only access (get, list) to resourceslices in the resource.k8s.io API group.
apiVersion: rbac.authorization.k8s.io/v1
kind: ClusterRole
metadata:
name: {{ include "snapshot.fullname" . }}-agent-resourceslices
labels:
{{- include "snapshot.labels" . | nindent 4 }}
app.kubernetes.io/component: checkpoint-agent
# Only ResourceSlices are listed here; no write verbs are granted.
rules:
- apiGroups: ["resource.k8s.io"]
resources: ["resourceslices"]
verbs: ["get", "list"]
{{- end }}
......@@ -19,6 +19,23 @@ subjects:
- kind: ServiceAccount
name: {{ include "snapshot.serviceAccountName" . }}
namespace: {{ .Release.Namespace }}
---
# Bind agent to ClusterRole for cluster-scoped ResourceSlices (DRA GPU UUID lookup)
# The binding and the ClusterRole it references share the same
# "<fullname>-agent-resourceslices" name; keep the two in sync when renaming.
apiVersion: rbac.authorization.k8s.io/v1
kind: ClusterRoleBinding
metadata:
name: {{ include "snapshot.fullname" . }}-agent-resourceslices
labels:
{{- include "snapshot.labels" . | nindent 4 }}
app.kubernetes.io/component: checkpoint-agent
roleRef:
apiGroup: rbac.authorization.k8s.io
kind: ClusterRole
name: {{ include "snapshot.fullname" . }}-agent-resourceslices
# Subject is the chart's service account in the release namespace.
subjects:
- kind: ServiceAccount
name: {{ include "snapshot.serviceAccountName" . }}
namespace: {{ .Release.Namespace }}
{{- else }}
apiVersion: rbac.authorization.k8s.io/v1
kind: ClusterRoleBinding
......@@ -37,4 +54,3 @@ subjects:
namespace: {{ .Release.Namespace }}
{{- end }}
{{- end }}
......@@ -414,6 +414,7 @@ func (w *NodeController) runCheckpoint(ctx context.Context, pod *corev1.Pod, job
NodeName: w.config.NodeName,
PodName: pod.Name,
PodNamespace: pod.Namespace,
Clientset: w.clientset,
}
if err := executor.Checkpoint(leaseCtx, w.containerd, log, req, w.config); err != nil {
if cause := context.Cause(leaseCtx); cause != nil && cause != context.Canceled {
......@@ -512,6 +513,7 @@ func (w *NodeController) runRestore(ctx context.Context, pod *corev1.Pod, contai
PodName: pod.Name,
PodNamespace: pod.Namespace,
ContainerName: containerName,
Clientset: w.clientset,
}
restoredPID, err := executor.Restore(ctx, w.containerd, log, req)
if err != nil {
......
......@@ -5,6 +5,7 @@ import (
"context"
"fmt"
"os/exec"
"regexp"
"strconv"
"strings"
......@@ -15,13 +16,17 @@ import (
podresourcesv1 "k8s.io/kubelet/pkg/apis/podresources/v1"
)
const nvidiaGPUResource = "nvidia.com/gpu"
const (
nvidiaGPUResource = "nvidia.com/gpu"
nvidiaGPUDRADriver = "gpu.nvidia.com"
)
var podResourcesSocketPath = "/var/lib/kubelet/pod-resources/kubelet.sock"
// GetPodGPUUUIDs resolves GPU UUIDs for a pod/container from the kubelet PodResources API.
// All nvidia.com/gpu device entries are accumulated in case the kubelet splits them
// across multiple entries (observed in some runtimes with multi-GPU pods).
var gpuUUIDPattern = regexp.MustCompile(`^GPU-[a-fA-F0-9]{8}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{12}$`)
// GetPodGPUUUIDs resolves GPU UUIDs for a pod/container from kubelet
// PodResources (nvidia.com/gpu entries in GetDevices()).
func GetPodGPUUUIDs(ctx context.Context, podName, podNamespace, containerName string) ([]string, error) {
if podName == "" || podNamespace == "" {
return nil, nil
......@@ -56,12 +61,40 @@ func GetPodGPUUUIDs(ctx context.Context, podName, podNamespace, containerName st
uuids = append(uuids, device.GetDeviceIds()...)
}
}
}
}
return uuids, nil
}
// GetGPUUUIDsViaNvidiaSmi discovers GPU UUIDs by running nvidia-smi inside the
// container's mount namespace. This is the fallback path when the kubelet
// PodResources API does not report GPU devices (e.g. when GPUs are allocated
// via DRA instead of the NVIDIA device plugin).
//
// hostProcPath is the location of the host's /proc tree as seen by this
// process, and pid selects the process whose mount namespace is entered
// (<hostProcPath>/<pid>/ns/mnt).
//
// It returns one entry per non-empty line printed by nvidia-smi, in output
// order. When the command cannot be started or exits non-zero, the returned
// error wraps the underlying error and, if the command wrote anything to
// stderr, includes that text so the failure can be diagnosed from logs.
func GetGPUUUIDsViaNvidiaSmi(ctx context.Context, hostProcPath string, pid int) ([]string, error) {
	mountPath := fmt.Sprintf("%s/%d/ns/mnt", strings.TrimRight(hostProcPath, "/"), pid)
	cmd := exec.CommandContext(
		ctx,
		"nsenter",
		fmt.Sprintf("--mount=%s", mountPath),
		"--",
		"nvidia-smi", "--query-gpu=gpu_uuid", "--format=csv,noheader",
	)

	output, err := cmd.Output()
	if err != nil {
		// Output() stores the child's stderr on *exec.ExitError when cmd.Stderr
		// is nil. Without it the error is just "exit status N", which does not
		// say whether nsenter or nvidia-smi failed, or why.
		if exitErr, ok := err.(*exec.ExitError); ok {
			if stderr := strings.TrimSpace(string(exitErr.Stderr)); stderr != "" {
				return nil, fmt.Errorf("nvidia-smi via nsenter (pid %d) failed: %w: %s", pid, err, stderr)
			}
		}
		return nil, fmt.Errorf("nvidia-smi via nsenter (pid %d) failed: %w", pid, err)
	}

	var uuids []string
	for _, line := range strings.Split(strings.TrimSpace(string(output)), "\n") {
		// Drop surrounding whitespace (including a trailing '\r') and blank lines.
		if line = strings.TrimSpace(line); line != "" {
			uuids = append(uuids, line)
		}
	}
	return uuids, nil
}
// FilterProcesses returns the subset of candidate PIDs that hold actual CUDA contexts.
// Uses --get-restore-tid (the same technique as the CRIU CUDA plugin) instead of
// --get-state, because --get-state incorrectly matches coordinator processes like
......@@ -93,13 +126,14 @@ func FilterProcesses(ctx context.Context, allPIDs []int, log logr.Logger) []int
// When a source UUID exists in the target set, it maps to itself (identity mapping) to avoid
// unnecessary cross-GPU restore on same-node restores where kubelet returns GPUs in different order.
// Remaining unmatched source UUIDs are paired with remaining unmatched target UUIDs positionally.
func BuildDeviceMap(sourceUUIDs, targetUUIDs []string) (string, error) {
func BuildDeviceMap(sourceUUIDs, targetUUIDs []string, log logr.Logger) (string, error) {
if len(sourceUUIDs) != len(targetUUIDs) {
return "", fmt.Errorf("GPU count mismatch: source has %d, target has %d", len(sourceUUIDs), len(targetUUIDs))
}
if len(sourceUUIDs) == 0 {
return "", fmt.Errorf("GPU UUID list is empty")
}
log.V(1).Info("BuildDeviceMap inputs", "source_uuids", sourceUUIDs, "target_uuids", targetUUIDs)
targetSet := make(map[string]bool, len(targetUUIDs))
for _, t := range targetUUIDs {
......
......@@ -9,6 +9,7 @@ import (
"testing"
"time"
"github.com/go-logr/logr"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
......@@ -58,7 +59,7 @@ func TestBuildDeviceMap(t *testing.T) {
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got, err := BuildDeviceMap(tc.source, tc.target)
got, err := BuildDeviceMap(tc.source, tc.target, logr.Discard())
if tc.wantErr {
if err == nil {
t.Errorf("expected error, got %q", got)
......@@ -176,7 +177,7 @@ func TestGetPodGPUUUIDs(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
got, err := GetPodGPUUUIDs(ctx, "test-pod", "default", "main")
got, err := GetPodGPUUUIDs(ctx, nil, "test-pod", "default", "main", logr.Discard())
if err != nil {
t.Fatalf("GetPodGPUUUIDs: %v", err)
}
......
package cuda
import (
"context"
"fmt"
"github.com/go-logr/logr"
resourcev1 "k8s.io/api/resource/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/client-go/kubernetes"
)
const (
resourceAttributeUUID = "uuid"
)
// GetGPUUUIDsViaDRAAPI resolves GPU UUIDs for a pod by querying the Kubernetes API:
// Pod (resource claim refs) -> ResourceClaim (allocation results) -> ResourceSlice (device attributes).
//
// A claim is located by the name given directly in the pod spec
// (ResourceClaimName) or, when the spec entry carries no name (as for claims
// generated from a ResourceClaimTemplate), by the generated name recorded under
// the same entry name in pod.Status.ResourceClaimStatuses.
//
// Returns nil without error if the clientset is nil, the pod name or namespace
// is empty, the pod has no DRA claims, the pod has no node name yet, or no
// allocated device belongs to the gpu.nvidia.com driver.
//
// Allocated devices whose pool or UUID cannot be found in the node's
// ResourceSlices are skipped (logged at V(1)), so the result may contain fewer
// UUIDs than allocated devices. Errors from the API server (pod, claim or slice
// lookup) are returned wrapped.
func GetGPUUUIDsViaDRAAPI(ctx context.Context, clientset kubernetes.Interface, podName, podNamespace string, log logr.Logger) ([]string, error) {
	if clientset == nil {
		return nil, nil
	}
	if podName == "" || podNamespace == "" {
		return nil, nil
	}

	pod, err := clientset.CoreV1().Pods(podNamespace).Get(ctx, podName, metav1.GetOptions{})
	if err != nil {
		return nil, fmt.Errorf("get pod %s/%s: %w", podNamespace, podName, err)
	}
	if len(pod.Spec.ResourceClaims) == 0 {
		return nil, nil
	}
	nodeName := pod.Spec.NodeName
	if nodeName == "" {
		log.V(1).Info("pod has no node name, skipping DRA API lookup")
		return nil, nil
	}

	// Names of claims generated for this pod, keyed by the pod-local claim
	// entry name. Used when the spec entry itself does not name a claim.
	generatedClaimNames := make(map[string]string, len(pod.Status.ResourceClaimStatuses))
	for _, st := range pod.Status.ResourceClaimStatuses {
		if st.ResourceClaimName != nil && *st.ResourceClaimName != "" {
			generatedClaimNames[st.Name] = *st.ResourceClaimName
		}
	}

	// allocatedDevice identifies one allocated NVIDIA GPU by pool and device name.
	type allocatedDevice struct {
		pool   string
		device string
	}
	var allocated []allocatedDevice

	for _, ref := range pod.Spec.ResourceClaims {
		claimName := ""
		if ref.ResourceClaimName != nil {
			claimName = *ref.ResourceClaimName
		}
		if claimName == "" {
			claimName = generatedClaimNames[ref.Name]
		}
		if claimName == "" {
			log.V(1).Info("pod resource claim has no resolvable ResourceClaim name, skipping", "claim", ref.Name)
			continue
		}

		claim, err := clientset.ResourceV1().ResourceClaims(podNamespace).Get(ctx, claimName, metav1.GetOptions{})
		if err != nil {
			return nil, fmt.Errorf("get resource claim %s/%s: %w", podNamespace, claimName, err)
		}
		if claim.Status.Allocation == nil {
			continue
		}
		for _, r := range claim.Status.Allocation.Devices.Results {
			// Only devices handed out by the NVIDIA GPU DRA driver are of interest.
			if r.Driver == nvidiaGPUDRADriver {
				allocated = append(allocated, allocatedDevice{pool: r.Pool, device: r.Device})
			}
		}
	}
	if len(allocated) == 0 {
		return nil, nil
	}

	slices, err := clientset.ResourceV1().ResourceSlices().List(ctx, metav1.ListOptions{
		FieldSelector: fmt.Sprintf("spec.driver=%s,spec.nodeName=%s", nvidiaGPUDRADriver, nodeName),
	})
	if err != nil {
		return nil, fmt.Errorf("list resource slices for node %s: %w", nodeName, err)
	}

	// pool name -> device name -> GPU UUID, for every device that publishes a
	// well-formed UUID attribute.
	poolDeviceToUUID := make(map[string]map[string]string)
	for i := range slices.Items {
		s := &slices.Items[i]
		poolName := s.Spec.Pool.Name
		if poolDeviceToUUID[poolName] == nil {
			poolDeviceToUUID[poolName] = make(map[string]string)
		}
		for _, dev := range s.Spec.Devices {
			uuid := deviceUUIDFromAttributes(dev.Attributes)
			if uuid != "" && gpuUUIDPattern.MatchString(uuid) {
				poolDeviceToUUID[poolName][dev.Name] = uuid
			}
		}
	}

	// Translate allocations to UUIDs, preserving allocation order.
	var uuids []string
	for _, a := range allocated {
		devMap := poolDeviceToUUID[a.pool]
		if devMap == nil {
			log.V(1).Info("no ResourceSlice found for pool", "pool", a.pool, "device", a.device)
			continue
		}
		uuid, ok := devMap[a.device]
		if !ok || uuid == "" {
			log.V(1).Info("device has no UUID in ResourceSlice", "pool", a.pool, "device", a.device)
			continue
		}
		uuids = append(uuids, uuid)
	}
	if len(uuids) > 0 {
		log.Info("resolved GPU UUIDs via DRA API", "uuids", uuids)
	}
	return uuids, nil
}
// deviceUUIDFromAttributes returns the string value stored under the "uuid"
// attribute key, or "" when the key is absent or holds no string value.
// A nil map is handled like an empty one.
func deviceUUIDFromAttributes(attrs map[resourcev1.QualifiedName]resourcev1.DeviceAttribute) string {
	if attr, found := attrs[resourcev1.QualifiedName(resourceAttributeUUID)]; found && attr.StringValue != nil {
		return *attr.StringValue
	}
	return ""
}
package cuda
import (
"context"
"testing"
"github.com/go-logr/logr"
corev1 "k8s.io/api/core/v1"
resourcev1 "k8s.io/api/resource/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/client-go/kubernetes/fake"
)
// TestDeviceUUIDFromAttributes checks that deviceUUIDFromAttributes returns the
// "uuid" string attribute when present and "" in every other situation,
// including a "uuid" key whose attribute carries no string value.
func TestDeviceUUIDFromAttributes(t *testing.T) {
	uuidVal := "GPU-f8ddcf75-4014-85da-28da-9dc4de19d997"
	tests := []struct {
		name  string
		attrs map[resourcev1.QualifiedName]resourcev1.DeviceAttribute
		want  string
	}{
		{
			name:  "nil map",
			attrs: nil,
			want:  "",
		},
		{
			name:  "empty map",
			attrs: map[resourcev1.QualifiedName]resourcev1.DeviceAttribute{},
			want:  "",
		},
		{
			name: "uuid present",
			attrs: map[resourcev1.QualifiedName]resourcev1.DeviceAttribute{
				resourcev1.QualifiedName("uuid"): {StringValue: &uuidVal},
			},
			want: uuidVal,
		},
		{
			name: "uuid missing, other attr present",
			attrs: map[resourcev1.QualifiedName]resourcev1.DeviceAttribute{
				resourcev1.QualifiedName("productName"): {StringValue: ptr("NVIDIA A100")},
			},
			want: "",
		},
		{
			// Exercises the StringValue == nil guard: the key exists but the
			// attribute holds no string.
			name: "uuid key present without string value",
			attrs: map[resourcev1.QualifiedName]resourcev1.DeviceAttribute{
				resourcev1.QualifiedName("uuid"): {},
			},
			want: "",
		},
	}
	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			got := deviceUUIDFromAttributes(tc.attrs)
			if got != tc.want {
				t.Errorf("deviceUUIDFromAttributes() = %q, want %q", got, tc.want)
			}
		})
	}
}
// ptr returns a pointer to a copy of s; the tests use it to fill optional
// *string fields from string literals.
func ptr(s string) *string {
	value := s
	return &value
}
// TestGetGPUUUIDsViaDRAAPI exercises GetGPUUUIDsViaDRAAPI against a fake
// clientset: the early-return paths that yield (nil, nil), the error path for a
// missing pod, and the full Pod -> ResourceClaim -> ResourceSlice resolution.
func TestGetGPUUUIDsViaDRAAPI(t *testing.T) {
	ctx := context.Background()
	log := logr.Discard()

	// requireNil fails the subtest unless the lookup returned (nil, nil).
	requireNil := func(t *testing.T, got []string, err error) {
		t.Helper()
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		if got != nil {
			t.Errorf("got %v, want nil", got)
		}
	}

	t.Run("nil clientset returns nil without error", func(t *testing.T) {
		got, err := GetGPUUUIDsViaDRAAPI(ctx, nil, "pod", "ns", log)
		requireNil(t, got, err)
	})

	t.Run("empty pod name returns nil", func(t *testing.T) {
		client := fake.NewSimpleClientset()
		got, err := GetGPUUUIDsViaDRAAPI(ctx, client, "", "ns", log)
		requireNil(t, got, err)
	})

	t.Run("pod not found returns error", func(t *testing.T) {
		client := fake.NewSimpleClientset()
		_, err := GetGPUUUIDsViaDRAAPI(ctx, client, "missing", "default", log)
		if err == nil {
			t.Fatal("expected error when pod not found")
		}
	})

	t.Run("pod with DRA claims resolves UUIDs", func(t *testing.T) {
		nodeName := "node-1"
		poolName := "pool-node-1"
		claimName := "gpu-claim"
		namespace := "default"
		podName := "test-pod"
		uuid1 := "GPU-aaaaaaaa-1111-2222-3333-444444444444"
		uuid2 := "GPU-bbbbbbbb-5555-6666-7777-888888888888"

		pod := &corev1.Pod{
			ObjectMeta: metav1.ObjectMeta{Name: podName, Namespace: namespace},
			Spec: corev1.PodSpec{
				NodeName: nodeName,
				ResourceClaims: []corev1.PodResourceClaim{
					{
						Name:              "gpu",
						ResourceClaimName: &claimName,
					},
				},
			},
		}
		claim := &resourcev1.ResourceClaim{
			ObjectMeta: metav1.ObjectMeta{Name: claimName, Namespace: namespace},
			Status: resourcev1.ResourceClaimStatus{
				Allocation: &resourcev1.AllocationResult{
					Devices: resourcev1.DeviceAllocationResult{
						Results: []resourcev1.DeviceRequestAllocationResult{
							{Driver: nvidiaGPUDRADriver, Pool: poolName, Device: "gpu-0", Request: "gpu"},
							{Driver: nvidiaGPUDRADriver, Pool: poolName, Device: "gpu-1", Request: "gpu"},
						},
					},
				},
			},
		}
		slice := &resourcev1.ResourceSlice{
			ObjectMeta: metav1.ObjectMeta{Name: poolName + "-gpu.nvidia.com-xxx"},
			Spec: resourcev1.ResourceSliceSpec{
				Driver:   nvidiaGPUDRADriver,
				NodeName: &nodeName,
				Pool:     resourcev1.ResourcePool{Name: poolName},
				Devices: []resourcev1.Device{
					{
						Name: "gpu-0",
						Attributes: map[resourcev1.QualifiedName]resourcev1.DeviceAttribute{
							resourcev1.QualifiedName("uuid"): {StringValue: &uuid1},
						},
					},
					{
						Name: "gpu-1",
						Attributes: map[resourcev1.QualifiedName]resourcev1.DeviceAttribute{
							resourcev1.QualifiedName("uuid"): {StringValue: &uuid2},
						},
					},
				},
			},
		}

		client := fake.NewSimpleClientset(pod, claim, slice)
		got, err := GetGPUUUIDsViaDRAAPI(ctx, client, podName, namespace, log)
		if err != nil {
			t.Fatalf("GetGPUUUIDsViaDRAAPI: %v", err)
		}
		// UUIDs must come back in allocation order.
		want := []string{uuid1, uuid2}
		if len(got) != len(want) {
			t.Fatalf("got %v (len %d), want %v (len %d)", got, len(got), want, len(want))
		}
		for i := range want {
			if got[i] != want[i] {
				t.Errorf("got[%d] = %q, want %q", i, got[i], want[i])
			}
		}
	})

	t.Run("pod with no resource claims returns nil", func(t *testing.T) {
		pod := &corev1.Pod{
			ObjectMeta: metav1.ObjectMeta{Name: "pod", Namespace: "default"},
			Spec:       corev1.PodSpec{NodeName: "node-1"},
		}
		client := fake.NewSimpleClientset(pod)
		got, err := GetGPUUUIDsViaDRAAPI(ctx, client, "pod", "default", log)
		requireNil(t, got, err)
	})

	t.Run("pod without node name returns nil", func(t *testing.T) {
		claimName := "gpu-claim"
		// Claims are present but Spec.NodeName is empty, so the lookup must
		// stop before touching ResourceClaims or ResourceSlices.
		pod := &corev1.Pod{
			ObjectMeta: metav1.ObjectMeta{Name: "pod", Namespace: "default"},
			Spec: corev1.PodSpec{
				ResourceClaims: []corev1.PodResourceClaim{
					{Name: "gpu", ResourceClaimName: &claimName},
				},
			},
		}
		client := fake.NewSimpleClientset(pod)
		got, err := GetGPUUUIDsViaDRAAPI(ctx, client, "pod", "default", log)
		requireNil(t, got, err)
	})

	t.Run("claim without allocation returns nil", func(t *testing.T) {
		claimName := "gpu-claim"
		pod := &corev1.Pod{
			ObjectMeta: metav1.ObjectMeta{Name: "pod", Namespace: "default"},
			Spec: corev1.PodSpec{
				NodeName: "node-1",
				ResourceClaims: []corev1.PodResourceClaim{
					{Name: "gpu", ResourceClaimName: &claimName},
				},
			},
		}
		// The claim exists but carries no allocation result.
		claim := &resourcev1.ResourceClaim{
			ObjectMeta: metav1.ObjectMeta{Name: claimName, Namespace: "default"},
		}
		client := fake.NewSimpleClientset(pod, claim)
		got, err := GetGPUUUIDsViaDRAAPI(ctx, client, "pod", "default", log)
		requireNil(t, got, err)
	})
}
......@@ -12,6 +12,7 @@ import (
criurpc "github.com/checkpoint-restore/go-criu/v8/rpc"
"github.com/containerd/containerd"
"github.com/go-logr/logr"
"k8s.io/client-go/kubernetes"
"github.com/google/uuid"
"github.com/ai-dynamo/dynamo/deploy/snapshot/pkg/common"
......@@ -30,6 +31,7 @@ type CheckpointRequest struct {
NodeName string
PodName string
PodNamespace string
Clientset kubernetes.Interface
}
// Checkpoint performs a CRIU dump of a container.
......@@ -162,6 +164,14 @@ func inspectContainer(ctx context.Context, ctrd *containerd.Client, log logr.Log
if err != nil {
return nil, fmt.Errorf("failed to discover source GPU UUIDs: %w", err)
}
if len(gpuUUIDs) == 0 {
log.Info("PodResources API returned no GPU UUIDs, falling back to nvidia-smi", "pid", pid)
gpuUUIDs, err = cuda.GetGPUUUIDsViaNvidiaSmi(ctx, common.HostProcPath, pid)
if err != nil {
return nil, fmt.Errorf("nvidia-smi GPU UUID fallback failed: %w", err)
}
log.Info("nvidia-smi fallback discovered GPU UUIDs", "uuids", gpuUUIDs)
}
}
return &types.CheckpointContainerSnapshot{
......
......@@ -14,6 +14,7 @@ import (
"github.com/containerd/containerd"
"github.com/go-logr/logr"
"k8s.io/client-go/kubernetes"
"github.com/ai-dynamo/dynamo/deploy/snapshot/pkg/common"
"github.com/ai-dynamo/dynamo/deploy/snapshot/pkg/criu"
......@@ -31,6 +32,7 @@ type RestoreRequest struct {
PodName string
PodNamespace string
ContainerName string
Clientset kubernetes.Interface
}
// Restore performs external restore for the given request.
......@@ -124,13 +126,26 @@ func inspectRestore(ctx context.Context, ctrd *containerd.Client, log logr.Logge
if err != nil {
return nil, fmt.Errorf("failed to get target GPU UUIDs: %w", err)
}
if len(targetGPUUUIDs) == 0 {
log.Info("PodResources API returned no target GPU UUIDs, falling back to nvidia-smi", "pid", placeholderPID)
targetGPUUUIDs, err = cuda.GetGPUUUIDsViaNvidiaSmi(ctx, common.HostProcPath, placeholderPID)
if err != nil {
return nil, fmt.Errorf("nvidia-smi GPU UUID fallback failed for restore target: %w", err)
}
log.Info("nvidia-smi fallback discovered target GPU UUIDs", "uuids", targetGPUUUIDs)
}
if len(targetGPUUUIDs) == 0 {
return nil, fmt.Errorf("missing target GPU UUIDs for %s/%s container %s", req.PodNamespace, req.PodName, containerName)
}
cudaDeviceMap, err = cuda.BuildDeviceMap(m.CUDA.SourceGPUUUIDs, targetGPUUUIDs)
cudaDeviceMap, err = cuda.BuildDeviceMap(m.CUDA.SourceGPUUUIDs, targetGPUUUIDs, log)
if err != nil {
return nil, fmt.Errorf("failed to build CUDA device map: %w", err)
}
log.Info("GPU UUIDs for device map",
"source_uuids", m.CUDA.SourceGPUUUIDs,
"target_uuids", targetGPUUUIDs,
"device_map", cudaDeviceMap,
)
}
return &types.RestoreContainerSnapshot{
......
......@@ -205,7 +205,12 @@ def configure_sglang_logging(dyn_level: int) -> None:
"handlers": ["dynamo"],
"level": sglang_level,
"propagate": False,
}
},
"gpu_memory_service": {
"handlers": ["dynamo"],
"level": sglang_level,
"propagate": False,
},
},
"version": 1,
"disable_existing_loggers": False,
......@@ -260,7 +265,12 @@ def configure_vllm_logging(dyn_level: int) -> None:
"handlers": ["vllm_stderr"],
"level": vllm_level,
"propagate": False,
}
},
"gpu_memory_service": {
"handlers": ["vllm_stderr"],
"level": vllm_level,
"propagate": False,
},
},
"version": 1,
"disable_existing_loggers": False,
......
......@@ -23,6 +23,12 @@ logger = logging.getLogger(__name__)
# Module-level GMS lock mode, set by setup_gms() before loader is instantiated.
# Read by patches.py when creating GMSMemorySaverImpl.
_gms_lock_mode = None
_gms_initialized = False
def is_gms_active() -> bool:
    """Report whether GMS has been set up in this module.

    Returns:
        The current value of the module-level ``_gms_initialized`` flag,
        which starts out ``False`` and is set to ``True`` by ``setup_gms()``.
    """
    return _gms_initialized
def setup_gms(server_args) -> Type["GMSModelLoader"]:
......@@ -66,5 +72,8 @@ def setup_gms(server_args) -> Type["GMSModelLoader"]:
# Import triggers patches at module level
from gpu_memory_service.integrations.sglang.model_loader import GMSModelLoader
global _gms_initialized
_gms_initialized = True
logger.info("[GMS] Using GMSModelLoader...")
return GMSModelLoader
......@@ -21,6 +21,7 @@ from gpu_memory_service.integrations.common import patch_empty_cache
from gpu_memory_service.integrations.common.utils import setup_meta_tensor_workaround
from gpu_memory_service.integrations.sglang.patches import (
patch_model_runner,
patch_static_state_for_gms,
patch_torch_memory_saver,
)
......@@ -28,9 +29,13 @@ logger = logging.getLogger(__name__)
# Apply patches at module import time.
# This module is only imported when load_format="gms" is used.
# Because SGLang scheduler processes use multiprocessing spawn, these patches
# must run inside the child process. The import chain that triggers this is:
# child unpickles server_args.load_format -> imports GMSModelLoader -> here.
patch_empty_cache()
patch_torch_memory_saver()
patch_model_runner()
patch_static_state_for_gms()
logger.info("[GMS] Applied patches")
......
......@@ -5,6 +5,7 @@
- patch_torch_memory_saver: Routes to GMS hybrid implementation
- patch_model_runner: Fixes memory accounting with pre-loaded weights
- patch_static_state_for_gms: No-ops named-buffer export/import (GMS preserves them)
"""
from __future__ import annotations
......@@ -19,6 +20,7 @@ logger = logging.getLogger(__name__)
_torch_memory_saver_patched = False
_model_runner_patched = False
_static_state_patched = False
def patch_torch_memory_saver() -> None:
......@@ -105,6 +107,24 @@ def patch_torch_memory_saver() -> None:
entrypoint_module.TorchMemorySaver.gms_impl = gms_impl
# If the singleton was already initialized before this patch ran (e.g.,
# due to import ordering in multiprocessing spawn), reset _impl so the
# next call to _ensure_initialized goes through the patched version and
# creates GMSMemorySaverImpl instead of the default _TorchMemorySaverImpl.
import torch_memory_saver
singleton = torch_memory_saver.torch_memory_saver
if singleton._impl is not None:
logger.debug(
"[GMS] TorchMemorySaver singleton already initialized, "
"resetting to force GMS re-init on next use"
)
singleton._impl = None
# The original _ensure_initialized deletes _impl_ctor_kwargs after
# creating _impl. Restore it so the patched version can read it.
if not hasattr(singleton, "_impl_ctor_kwargs"):
singleton._impl_ctor_kwargs = {}
_torch_memory_saver_patched = True
logger.debug("[GMS] Patched torch_memory_saver")
......@@ -158,3 +178,49 @@ def patch_model_runner() -> None:
ModelRunner._gms_patched = True
_model_runner_patched = True
logger.info("[GMS] Patched ModelRunner.init_memory_pool")
def patch_static_state_for_gms() -> None:
    """Swap SGLang's static-state export/import helpers for no-ops under GMS.

    SGLang's release_memory_occupation clones every named buffer via
    buffer.detach().clone() through the default CUDA allocator, then restores
    them during resume_memory_occupation. With GMS the buffers are preserved
    via VA-stable unmap/remap, so both helpers are replaced by functions that
    do nothing.

    The patch is applied at most once per process (guarded by the module-level
    ``_static_state_patched`` flag). A failure to import or patch the SGLang
    module is logged as a warning and otherwise ignored.

    This patch must run inside the scheduler child process (which uses
    multiprocessing spawn). It is triggered by the GMSModelLoader import
    in model_loader.py, which executes at module level in the child.
    """
    import os

    global _static_state_patched

    pid = os.getpid()
    logger.info(
        "[GMS] patch_static_state_for_gms called (pid=%d, already_patched=%s)",
        pid,
        _static_state_patched,
    )
    if _static_state_patched:
        return

    def _export_noop(model):
        """NO-OP: GMS preserves buffers via VA-stable unmap/remap."""
        # Same shape as a regular export, just with nothing to restore.
        return {"buffers": []}

    def _import_noop(model, static_params):
        """NO-OP: GMS preserves buffers via VA-stable unmap/remap."""
        return None

    try:
        from sglang.srt.managers import (
            scheduler_update_weights_mixin as mixin_module,
        )

        # Rebind the module-level helpers so later lookups hit the no-ops.
        mixin_module._export_static_state = _export_noop
        mixin_module._import_static_state = _import_noop
        _static_state_patched = True
        logger.info(
            "[GMS] Patched _export/_import_static_state -> no-op (pid=%d)",
            pid,
        )
    except Exception:
        logger.warning(
            "[GMS] Could not patch scheduler_update_weights_mixin: ",
            exc_info=True,
        )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment