Unverified Commit bb8fc8a4 authored by Schwinn Saereesitthipitak's avatar Schwinn Saereesitthipitak Committed by GitHub
Browse files

feat(chrek): external restore, signal-based IPC, and package refactor (#6286)


Co-authored-by: default avatarDan Feigin <dfeigin@nvidia.com>
parent c8423b57
// storage.go provides checkpoint storage I/O: write/read manifests, listing, deletion.
package checkpoint
import (
"fmt"
"os"
"path/filepath"
"strings"
"gopkg.in/yaml.v3"
)
// WriteCheckpointManifest writes a checkpoint manifest file in the checkpoint directory.
func WriteCheckpointManifest(checkpointDir string, data *CheckpointManifest) error {
content, err := yaml.Marshal(data)
if err != nil {
return fmt.Errorf("failed to marshal checkpoint manifest: %w", err)
}
manifestPath := filepath.Join(checkpointDir, CheckpointManifestFilename)
if err := os.WriteFile(manifestPath, content, 0600); err != nil {
return fmt.Errorf("failed to write checkpoint manifest: %w", err)
}
return nil
}
// ReadCheckpointManifest reads checkpoint manifest from a checkpoint directory.
func ReadCheckpointManifest(checkpointDir string) (*CheckpointManifest, error) {
manifestPath := filepath.Join(checkpointDir, CheckpointManifestFilename)
content, err := os.ReadFile(manifestPath)
if err != nil {
return nil, fmt.Errorf("failed to read checkpoint manifest: %w", err)
}
var data CheckpointManifest
if err := yaml.Unmarshal(content, &data); err != nil {
return nil, fmt.Errorf("failed to unmarshal checkpoint manifest: %w", err)
}
return &data, nil
}
// SaveDescriptors writes file descriptor information to the checkpoint directory.
func SaveDescriptors(checkpointDir string, descriptors []string) error {
content, err := yaml.Marshal(descriptors)
if err != nil {
return fmt.Errorf("failed to marshal descriptors: %w", err)
}
descriptorsPath := filepath.Join(checkpointDir, DescriptorsFilename)
if err := os.WriteFile(descriptorsPath, content, 0600); err != nil {
return fmt.Errorf("failed to write descriptors file: %w", err)
}
return nil
}
// LoadDescriptors reads file descriptor information from checkpoint directory.
func LoadDescriptors(checkpointDir string) ([]string, error) {
descriptorsPath := filepath.Join(checkpointDir, DescriptorsFilename)
content, err := os.ReadFile(descriptorsPath)
if err != nil {
return nil, fmt.Errorf("failed to read descriptors file: %w", err)
}
var descriptors []string
if err := yaml.Unmarshal(content, &descriptors); err != nil {
return nil, fmt.Errorf("failed to unmarshal descriptors: %w", err)
}
return descriptors, nil
}
// ListCheckpoints returns all checkpoint IDs in the base directory.
func ListCheckpoints(baseDir string) ([]string, error) {
entries, err := os.ReadDir(baseDir)
if err != nil {
return nil, fmt.Errorf("failed to read checkpoint directory: %w", err)
}
var checkpoints []string
for _, entry := range entries {
if !entry.IsDir() {
continue
}
// Check if manifest file exists.
manifestPath := filepath.Join(baseDir, entry.Name(), CheckpointManifestFilename)
if _, err := os.Stat(manifestPath); err == nil {
checkpoints = append(checkpoints, entry.Name())
}
}
return checkpoints, nil
}
// DeleteCheckpoint removes a checkpoint directory.
func DeleteCheckpoint(baseDir, checkpointID string) error {
checkpointDir := filepath.Join(baseDir, checkpointID)
// Ensure resolved path is within baseDir to prevent path traversal
absBase, _ := filepath.Abs(baseDir)
absDir, _ := filepath.Abs(checkpointDir)
if !strings.HasPrefix(absDir, absBase+string(filepath.Separator)) && absDir != absBase {
return fmt.Errorf("invalid checkpoint ID: resolved path %s is outside base directory %s", absDir, absBase)
}
return os.RemoveAll(checkpointDir)
}
package common
import (
"fmt"
"os"
"path/filepath"
"strconv"
"strings"
)
const HostCgroupPath = "/sys/fs/cgroup"
// ResolveCgroupRootFromHostPID reads the unified cgroup v2 path for a PID via /host/proc.
func ResolveCgroupRootFromHostPID(pid int) (string, error) {
cgroupFile := filepath.Join(HostProcPath, strconv.Itoa(pid), "cgroup")
data, err := os.ReadFile(cgroupFile)
if err != nil {
return "", fmt.Errorf("failed reading %s: %w", cgroupFile, err)
}
for _, line := range strings.Split(strings.TrimSpace(string(data)), "\n") {
line = strings.TrimSpace(line)
if line == "" || !strings.HasPrefix(line, "0::") {
continue
}
path := strings.TrimPrefix(line, "0::")
if path == "" {
return "/", nil
}
if !strings.HasPrefix(path, "/") {
path = "/" + path
}
return filepath.Clean(path), nil
}
return "", fmt.Errorf("unified cgroup entry not found in %s", cgroupFile)
}
// criu.go provides shared CRIU utilities used by both checkpoint and restore.
package common
import (
"bufio"
"fmt"
"os"
"strings"
"golang.org/x/sys/unix"
)
// OpenPathForCRIU opens a path (directory or file) and clears the CLOEXEC flag
// so the FD can be inherited by CRIU child processes.
// Returns the opened file and its FD. Caller must close the file when done.
func OpenPathForCRIU(path string) (*os.File, int32, error) {
dir, err := os.Open(path)
if err != nil {
return nil, 0, fmt.Errorf("failed to open %s: %w", path, err)
}
// Clear CLOEXEC so the FD is inherited by CRIU child process.
// Go's os.Open() sets O_CLOEXEC by default, but go-criu's swrk mode
// requires the FD to be inherited.
if _, err := unix.FcntlInt(dir.Fd(), unix.F_SETFD, 0); err != nil {
dir.Close()
return nil, 0, fmt.Errorf("failed to clear CLOEXEC on %s: %w", path, err)
}
return dir, int32(dir.Fd()), nil
}
// CRIUMountPoint represents a parsed mount point from /proc/pid/mountinfo.
type CRIUMountPoint struct {
MountID string // Mount ID
ParentID string // Parent mount ID
Path string // Mount point path (container-side)
Root string // Root within filesystem (host-side for bind mounts)
FSType string // Filesystem type
Source string // Mount source
Options string // Mount options
SuperOpts string // Super block options
}
// ParseMountInfoFile parses a mountinfo file and returns all mount points.
// This is used by both checkpoint (to mark mounts as external) and
// restore (to generate external mount mappings).
func ParseMountInfoFile(path string) ([]CRIUMountPoint, error) {
file, err := os.Open(path)
if err != nil {
return nil, fmt.Errorf("failed to open mountinfo: %w", err)
}
defer file.Close()
var mounts []CRIUMountPoint
scanner := bufio.NewScanner(file)
for scanner.Scan() {
mount, err := parseCRIUMountInfoLine(scanner.Text())
if err != nil {
continue // Skip malformed lines
}
mounts = append(mounts, mount)
}
if err := scanner.Err(); err != nil {
return nil, fmt.Errorf("error reading mountinfo: %w", err)
}
return mounts, nil
}
// GetMountPointPaths returns just the mount point paths from a mountinfo file.
// This is a convenience function when you only need the paths.
func GetMountPointPaths(path string) ([]string, error) {
mounts, err := ParseMountInfoFile(path)
if err != nil {
return nil, err
}
paths := make([]string, 0, len(mounts))
for _, m := range mounts {
paths = append(paths, m.Path)
}
return paths, nil
}
// parseCRIUMountInfoLine parses a single line from mountinfo.
// mountinfo format:
// 36 35 98:0 /mnt1 /mnt2 rw,noatime master:1 - ext3 /dev/root rw,errors=continue
// (1)(2)(3) (4) (5) (6) (7) (8) (9) (10) (11)
func parseCRIUMountInfoLine(line string) (CRIUMountPoint, error) {
fields := strings.Fields(line)
if len(fields) < 10 {
return CRIUMountPoint{}, fmt.Errorf("malformed mountinfo line")
}
// Find separator (-) to get fstype and source
sepIdx := -1
for i, f := range fields {
if f == "-" {
sepIdx = i
break
}
}
if sepIdx == -1 || sepIdx+2 >= len(fields) {
return CRIUMountPoint{}, fmt.Errorf("malformed mountinfo line (no separator)")
}
superOpts := ""
if sepIdx+3 < len(fields) {
superOpts = fields[sepIdx+3]
}
return CRIUMountPoint{
MountID: fields[0],
ParentID: fields[1],
Root: fields[3],
Path: fields[4],
Options: fields[5],
FSType: fields[sepIdx+1],
Source: fields[sepIdx+2],
SuperOpts: superOpts,
}, nil
}
package common
import (
"fmt"
"os"
"path/filepath"
"strings"
"syscall"
"github.com/moby/sys/mountinfo"
specs "github.com/opencontainers/runtime-spec/specs-go"
"github.com/ai-dynamo/dynamo/deploy/chrek/pkg/types"
)
// ReadMountInfo reads and parses mountinfo for a container process via /host/proc.
func ReadMountInfo(pid int) ([]types.MountInfo, error) {
mountinfoPath := fmt.Sprintf("%s/%d/mountinfo", HostProcPath, pid)
f, err := os.Open(mountinfoPath)
if err != nil {
return nil, fmt.Errorf("failed to open mountinfo: %w", err)
}
defer f.Close()
infos, err := mountinfo.GetMountsFromReader(f, nil)
if err != nil {
return nil, fmt.Errorf("failed to parse mountinfo: %w", err)
}
mounts := make([]types.MountInfo, 0, len(infos))
for _, info := range infos {
mounts = append(mounts, types.MountInfo{
MountPoint: info.Mountpoint,
FSType: info.FSType,
VFSOptions: info.VFSOptions,
})
}
return mounts, nil
}
// ClassifyMounts sets IsOCIManaged on each mount by matching against the
// container's OCI spec (mounts, masked paths, readonly paths).
// Handles /run/ ↔ /var/run/ aliasing since some images symlink one to the other.
func ClassifyMounts(mounts []types.MountInfo, ociSpec *specs.Spec, rootFS string) []types.MountInfo {
ociSet := collectOCIManagedPaths(ociSpec, rootFS)
for i := range mounts {
mp := mounts[i].MountPoint
if _, ok := ociSet[mp]; ok {
mounts[i].IsOCIManaged = true
continue
}
// /run/ ↔ /var/run/ aliasing
if strings.HasPrefix(mp, "/run/") {
if _, ok := ociSet["/var"+mp]; ok {
mounts[i].IsOCIManaged = true
continue
}
}
if strings.HasPrefix(mp, "/var/run/") {
if _, ok := ociSet[strings.TrimPrefix(mp, "/var")]; ok {
mounts[i].IsOCIManaged = true
}
}
}
return mounts
}
// BuildMountPolicy classifies mounts and masked paths for CRIU dump.
// Mounts must already have IsOCIManaged set by ClassifyMounts.
//
// Policy (evaluated top to bottom):
// 1. Skip: non-OCI /proc/*, /sys/*, /run/* submounts (virtual/runtime, not in placeholder)
// 2. Native: /dev/shm tmpfs (CRIU saves and restores content)
// 3. Masked: OCI masked non-directory paths that exist under rootFS → /dev/null
// 4. Externalize: everything else (OCI mounts the runtime recreates in placeholder)
func BuildMountPolicy(mounts []types.MountInfo, rootFS string, maskedPaths []string) (map[string]string, []string) {
extMap := make(map[string]string, len(mounts))
var skipped []string
for _, m := range mounts {
if m.MountPoint == "" {
continue
}
// Skip non-OCI virtual/runtime mounts — these won't exist in the placeholder
if !m.IsOCIManaged && (strings.HasPrefix(m.MountPoint, "/proc/") || strings.HasPrefix(m.MountPoint, "/sys/") || strings.HasPrefix(m.MountPoint, "/run/")) {
skipped = append(skipped, m.MountPoint)
continue
}
// Let CRIU handle /dev/shm content natively — don't externalize it.
if m.MountPoint == "/dev/shm" && m.FSType == "tmpfs" {
continue
}
extMap[m.MountPoint] = m.MountPoint
}
// Masked paths map to /dev/null. Only non-directory paths that exist under rootFS.
for _, p := range maskedPaths {
hostPath := filepath.Join(rootFS, p)
info, err := os.Lstat(hostPath)
if err != nil || info.IsDir() {
continue
}
extMap[p] = "/dev/null"
}
return extMap, skipped
}
// RemountProcSys remounts /proc/sys read-write or read-only.
func RemountProcSys(rw bool) error {
flags := uintptr(syscall.MS_BIND | syscall.MS_REMOUNT)
if !rw {
flags |= syscall.MS_RDONLY
}
if err := syscall.Mount("proc", "/proc/sys", "", flags, ""); err != nil {
mode := "rw"
if !rw {
mode = "ro"
}
return fmt.Errorf("failed to remount /proc/sys %s: %w", mode, err)
}
return nil
}
package common
import (
"os"
"path/filepath"
"testing"
specs "github.com/opencontainers/runtime-spec/specs-go"
"github.com/ai-dynamo/dynamo/deploy/chrek/pkg/types"
)
func TestClassifyMounts(t *testing.T) {
tests := []struct {
name string
mounts []types.MountInfo
ociSpec *specs.Spec
rootFS string
want map[string]bool // mountpoint → expected IsOCIManaged
}{
{
name: "mount matching OCI destination",
mounts: []types.MountInfo{
{MountPoint: "/etc/hostname"},
},
ociSpec: &specs.Spec{
Mounts: []specs.Mount{{Destination: "/etc/hostname"}},
},
want: map[string]bool{"/etc/hostname": true},
},
{
name: "mount with no OCI match",
mounts: []types.MountInfo{
{MountPoint: "/some/random/path"},
},
ociSpec: &specs.Spec{
Mounts: []specs.Mount{{Destination: "/etc/hostname"}},
},
want: map[string]bool{"/some/random/path": false},
},
{
name: "/run/ mount aliased to /var/run/ in OCI spec",
mounts: []types.MountInfo{
{MountPoint: "/run/secrets"},
},
ociSpec: &specs.Spec{
Mounts: []specs.Mount{{Destination: "/var/run/secrets"}},
},
want: map[string]bool{"/run/secrets": true},
},
{
name: "/var/run/ mount aliased to /run/ in OCI spec",
mounts: []types.MountInfo{
{MountPoint: "/var/run/secrets"},
},
ociSpec: &specs.Spec{
Mounts: []specs.Mount{{Destination: "/run/secrets"}},
},
want: map[string]bool{"/var/run/secrets": true},
},
{
name: "/run/ prefix without alias match stays unmanaged",
mounts: []types.MountInfo{
{MountPoint: "/run/other"},
},
ociSpec: &specs.Spec{
Mounts: []specs.Mount{{Destination: "/var/run/different"}},
},
want: map[string]bool{"/run/other": false},
},
{
name: "nil OCI spec classifies nothing",
mounts: []types.MountInfo{{MountPoint: "/etc/hostname"}},
want: map[string]bool{"/etc/hostname": false},
},
{
name: "masked and readonly paths are OCI-managed",
mounts: []types.MountInfo{
{MountPoint: "/proc/acpi"},
{MountPoint: "/proc/sys"},
},
ociSpec: &specs.Spec{
Linux: &specs.Linux{
MaskedPaths: []string{"/proc/acpi"},
ReadonlyPaths: []string{"/proc/sys"},
},
},
want: map[string]bool{
"/proc/acpi": true,
"/proc/sys": true,
},
},
{
name: "empty mounts slice",
mounts: []types.MountInfo{},
ociSpec: &specs.Spec{
Mounts: []specs.Mount{{Destination: "/etc/hostname"}},
},
want: map[string]bool{},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
result := ClassifyMounts(tc.mounts, tc.ociSpec, tc.rootFS)
for _, m := range result {
expected, ok := tc.want[m.MountPoint]
if !ok {
continue
}
if m.IsOCIManaged != expected {
t.Errorf("mount %q: IsOCIManaged = %v, want %v", m.MountPoint, m.IsOCIManaged, expected)
}
}
})
}
}
func TestBuildMountPolicy(t *testing.T) {
tests := []struct {
name string
mounts []types.MountInfo
rootFS string
maskedPaths []string
wantExt map[string]string // expected entries in extMap
wantSkipped []string // expected entries in skipped
wantNotInExt []string // keys that must NOT be in extMap
}{
{
name: "non-OCI /proc submount is skipped",
mounts: []types.MountInfo{
{MountPoint: "/proc/kcore", IsOCIManaged: false},
},
wantSkipped: []string{"/proc/kcore"},
wantNotInExt: []string{"/proc/kcore"},
},
{
name: "non-OCI /sys submount is skipped",
mounts: []types.MountInfo{
{MountPoint: "/sys/firmware", IsOCIManaged: false},
},
wantSkipped: []string{"/sys/firmware"},
wantNotInExt: []string{"/sys/firmware"},
},
{
name: "non-OCI /run submount is skipped",
mounts: []types.MountInfo{
{MountPoint: "/run/some-daemon", IsOCIManaged: false},
},
wantSkipped: []string{"/run/some-daemon"},
wantNotInExt: []string{"/run/some-daemon"},
},
{
name: "OCI-managed /proc submount is externalized, not skipped",
mounts: []types.MountInfo{
{MountPoint: "/proc/acpi", IsOCIManaged: true},
},
wantExt: map[string]string{"/proc/acpi": "/proc/acpi"},
},
{
name: "/dev/shm tmpfs is not externalized",
mounts: []types.MountInfo{
{MountPoint: "/dev/shm", FSType: "tmpfs"},
},
wantNotInExt: []string{"/dev/shm"},
},
{
name: "/dev/shm non-tmpfs is externalized",
mounts: []types.MountInfo{
{MountPoint: "/dev/shm", FSType: "bind"},
},
wantExt: map[string]string{"/dev/shm": "/dev/shm"},
},
{
name: "normal mount is externalized",
mounts: []types.MountInfo{
{MountPoint: "/etc/hostname", IsOCIManaged: true},
},
wantExt: map[string]string{"/etc/hostname": "/etc/hostname"},
},
{
name: "empty mount point is ignored",
mounts: []types.MountInfo{
{MountPoint: ""},
},
wantExt: map[string]string{},
},
{
name: "masked path non-dir file maps to /dev/null",
mounts: []types.MountInfo{},
rootFS: func() string {
dir := t.TempDir()
os.WriteFile(filepath.Join(dir, "proc"), []byte("x"), 0644)
return dir
}(),
maskedPaths: []string{"/proc"},
wantExt: map[string]string{"/proc": "/dev/null"},
},
{
name: "masked path directory is ignored",
mounts: []types.MountInfo{},
rootFS: func() string {
dir := t.TempDir()
os.MkdirAll(filepath.Join(dir, "proc"), 0755)
return dir
}(),
maskedPaths: []string{"/proc"},
wantNotInExt: []string{"/proc"},
},
{
name: "masked path that doesn't exist is ignored",
mounts: []types.MountInfo{},
rootFS: t.TempDir(),
maskedPaths: []string{"/nonexistent"},
wantNotInExt: []string{"/nonexistent"},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
extMap, skipped := BuildMountPolicy(tc.mounts, tc.rootFS, tc.maskedPaths)
for k, v := range tc.wantExt {
got, ok := extMap[k]
if !ok {
t.Errorf("expected extMap[%q] to exist", k)
continue
}
if got != v {
t.Errorf("extMap[%q] = %q, want %q", k, got, v)
}
}
for _, k := range tc.wantNotInExt {
if _, ok := extMap[k]; ok {
t.Errorf("extMap should not contain %q", k)
}
}
if tc.wantSkipped != nil {
skippedSet := make(map[string]struct{}, len(skipped))
for _, s := range skipped {
skippedSet[s] = struct{}{}
}
for _, want := range tc.wantSkipped {
if _, ok := skippedSet[want]; !ok {
t.Errorf("expected %q in skipped list, got %v", want, skipped)
}
}
}
})
}
}
func TestNormalizeOCIPath(t *testing.T) {
tests := []struct {
name string
raw string
rootFS string
want string
}{
{name: "normal absolute path", raw: "/etc/hostname", want: "/etc/hostname"},
{name: "empty string", raw: "", want: ""},
{name: "whitespace only", raw: " ", want: ""},
{name: "dot path", raw: ".", want: ""},
{name: "path with trailing slashes cleaned", raw: "/etc/hostname///", want: "/etc/hostname"},
{
name: "with rootFS strips prefix via securejoin",
raw: "/etc/hostname",
// SecureJoin(rootFS, "/etc/hostname") → rootFS+"/etc/hostname", then strip rootFS prefix
rootFS: "/tmp/fakefs",
want: "/etc/hostname",
},
{
name: "root path with rootFS returns /",
raw: "/",
rootFS: "/tmp/fakefs",
want: "/",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := normalizeOCIPath(tc.raw, tc.rootFS)
if got != tc.want {
t.Errorf("normalizeOCIPath(%q, %q) = %q, want %q", tc.raw, tc.rootFS, got, tc.want)
}
})
}
}
package common
import (
"context"
"fmt"
"os/exec"
"strconv"
"strings"
"syscall"
"github.com/go-logr/logr"
"golang.org/x/sys/unix"
)
// GetNetNSInode returns the network namespace inode for a container process via /host/proc.
func GetNetNSInode(pid int) (uint64, error) {
nsPath := fmt.Sprintf("%s/%d/ns/net", HostProcPath, pid)
var stat unix.Stat_t
if err := unix.Stat(nsPath, &stat); err != nil {
return 0, fmt.Errorf("failed to stat %s: %w", nsPath, err)
}
return stat.Ino, nil
}
// SendSignalViaPIDNamespace sends a signal to a namespace-relative PID by entering the
// PID namespace of referenceHostPID via nsenter.
func SendSignalViaPIDNamespace(ctx context.Context, log logr.Logger, referenceHostPID, targetNamespacePID int, sig syscall.Signal, reason string) error {
if referenceHostPID <= 0 {
return fmt.Errorf("invalid reference host PID %d for signal %d", referenceHostPID, int(sig))
}
if targetNamespacePID <= 0 {
return fmt.Errorf("invalid namespace PID %d for signal %d", targetNamespacePID, int(sig))
}
cmd := exec.CommandContext(
ctx,
"nsenter",
"-t", strconv.Itoa(referenceHostPID),
"-p",
"--",
"kill",
fmt.Sprintf("-%d", int(sig)),
strconv.Itoa(targetNamespacePID),
)
output, err := cmd.CombinedOutput()
if err != nil {
return fmt.Errorf(
"failed to signal namespace PID %d via reference host PID %d with signal %d (%s): %w (output: %s)",
targetNamespacePID, referenceHostPID, int(sig), reason, err, strings.TrimSpace(string(output)),
)
}
log.Info("Signaled runtime process in PID namespace",
"reference_host_pid", referenceHostPID,
"namespace_pid", targetNamespacePID,
"signal", int(sig),
"reason", reason,
)
return nil
}
// Package common provides low-level container, process, and device primitives
// shared across chrek packages.
package common
import (
"context"
"fmt"
"path/filepath"
"strings"
"github.com/containerd/containerd"
"github.com/containerd/containerd/namespaces"
specs "github.com/opencontainers/runtime-spec/specs-go"
securejoin "github.com/cyphar/filepath-securejoin"
)
const (
// k8sNamespace is the containerd namespace used by Kubernetes.
k8sNamespace = "k8s.io"
// ContainerdSocket is the default containerd socket path.
ContainerdSocket = "/run/containerd/containerd.sock"
)
// ResolveContainer resolves a container by ID and returns its PID and OCI spec.
func ResolveContainer(ctx context.Context, client *containerd.Client, containerID string) (int, *specs.Spec, error) {
ctx = namespaces.WithNamespace(ctx, k8sNamespace)
container, err := client.LoadContainer(ctx, containerID)
if err != nil {
return 0, nil, fmt.Errorf("failed to load container %s: %w", containerID, err)
}
task, err := container.Task(ctx, nil)
if err != nil {
return 0, nil, fmt.Errorf("failed to get task for container %s: %w", containerID, err)
}
spec, err := container.Spec(ctx)
if err != nil {
return 0, nil, fmt.Errorf("failed to get spec for container %s: %w", containerID, err)
}
return int(task.Pid()), spec, nil
}
// ResolveContainerByPod finds a container by pod name, namespace, and container name
// by listing containerd containers and matching CRI labels.
func ResolveContainerByPod(ctx context.Context, client *containerd.Client, podName, podNamespace, containerName string) (int, *specs.Spec, error) {
ctx = namespaces.WithNamespace(ctx, k8sNamespace)
filter := fmt.Sprintf("labels.\"io.kubernetes.pod.name\"==%s,labels.\"io.kubernetes.pod.namespace\"==%s,labels.\"io.kubernetes.container.name\"==%s",
podName, podNamespace, containerName)
containers, err := client.Containers(ctx, filter)
if err != nil {
return 0, nil, fmt.Errorf("failed to list containers for pod %s/%s: %w", podNamespace, podName, err)
}
if len(containers) == 0 {
return 0, nil, fmt.Errorf("no container found for pod %s/%s container %s", podNamespace, podName, containerName)
}
// During container restarts, containerd may transiently expose both the
// old and new container with the same CRI labels. Pick the one with a
// running task; fall back to the first container if none qualify.
for _, c := range containers {
task, err := c.Task(ctx, nil)
if err != nil {
continue
}
spec, err := c.Spec(ctx)
if err != nil {
continue
}
return int(task.Pid()), spec, nil
}
return 0, nil, fmt.Errorf("no running container found for pod %s/%s container %s (%d candidates)", podNamespace, podName, containerName, len(containers))
}
func collectOCIManagedPaths(ociSpec *specs.Spec, rootFS string) map[string]struct{} {
set := map[string]struct{}{}
if ociSpec == nil {
return set
}
paths := make([]string, 0, len(ociSpec.Mounts))
for _, mount := range ociSpec.Mounts {
paths = append(paths, mount.Destination)
}
if ociSpec.Linux != nil {
paths = append(paths, ociSpec.Linux.MaskedPaths...)
paths = append(paths, ociSpec.Linux.ReadonlyPaths...)
}
for _, raw := range paths {
if p := normalizeOCIPath(raw, rootFS); p != "" {
set[p] = struct{}{}
}
}
return set
}
// normalizeOCIPath resolves an OCI spec path relative to rootFS, following
// symlinks within the rootfs boundary (matching runc's addCriuDumpMount pattern).
func normalizeOCIPath(raw, rootFS string) string {
p := filepath.Clean(strings.TrimSpace(raw))
if p == "" || p == "." {
return ""
}
if rootFS == "" {
return p
}
if resolved, err := securejoin.SecureJoin(rootFS, p); err == nil {
p = strings.TrimPrefix(resolved, filepath.Clean(rootFS))
}
if !strings.HasPrefix(p, "/") {
p = "/" + p
}
return p
}
// filesystem.go provides container rootfs introspection, filesystem config/metadata types,
// and rootfs diff capture for CRIU checkpoint.
package checkpoint
package common
import (
"encoding/json"
......@@ -10,109 +8,28 @@ import (
"path/filepath"
"strings"
specs "github.com/opencontainers/runtime-spec/specs-go"
"github.com/sirupsen/logrus"
)
// FilesystemConfig is the static config for rootfs exclusions (from values.yaml).
type FilesystemConfig struct {
// SystemDirs are system directories that should be excluded from rootfs diff.
// These directories are typically injected/bind-mounted by NVIDIA GPU Operator
// at container start time, so they already exist in the restore target.
// Excluding them prevents conflicts (especially socket files which cannot be overwritten).
// Default: ["./usr", "./etc", "./opt", "./var", "./run"]
SystemDirs []string `yaml:"systemDirs"`
// CacheDirs are cache directories that can safely be excluded to reduce checkpoint size.
// Model weights and other cached data are typically re-downloaded if needed.
// Default: ["./.cache/huggingface", "./.cache/torch"]
CacheDirs []string `yaml:"cacheDirs"`
// AdditionalExclusions are custom paths to exclude from the rootfs diff.
// Use this for application-specific exclusions.
// Paths should be relative with "./" prefix (e.g., "./data/temp").
AdditionalExclusions []string `yaml:"additionalExclusions"`
}
// GetAllExclusions returns all exclusion paths combined.
// This is used when building tar arguments for rootfs diff capture.
func (c *FilesystemConfig) GetAllExclusions() []string {
if c == nil {
return nil
}
total := len(c.SystemDirs) + len(c.CacheDirs) + len(c.AdditionalExclusions)
exclusions := make([]string, 0, total)
exclusions = append(exclusions, c.SystemDirs...)
exclusions = append(exclusions, c.CacheDirs...)
exclusions = append(exclusions, c.AdditionalExclusions...)
return exclusions
}
// Validate checks that the FilesystemConfig has valid values.
func (c *FilesystemConfig) Validate() error {
if c == nil {
return nil
}
// All paths should start with "./" for tar relative path handling
for _, path := range c.GetAllExclusions() {
if !strings.HasPrefix(path, "./") {
return &ConfigError{
Field: "rootfsExclusions",
Message: "all exclusion paths must start with './' (got: " + path + ")",
}
}
}
return nil
}
// FilesystemManifest holds runtime filesystem state captured at checkpoint time.
type FilesystemManifest struct {
Exclusions FilesystemConfig `yaml:"exclusions"`
UpperDir string `yaml:"upperDir,omitempty"`
ExternalPaths []string `yaml:"externalPaths,omitempty"`
BindMountDests []string `yaml:"bindMountDests,omitempty"`
HasRootfsDiff bool `yaml:"hasRootfsDiff"`
HasDeletedFiles bool `yaml:"hasDeletedFiles"`
}
"github.com/go-logr/logr"
// NewFilesystemManifest constructs FilesystemManifest from config, overlay state, and OCI spec.
func NewFilesystemManifest(exclusions FilesystemConfig, upperDir string, ociSpec *specs.Spec) FilesystemManifest {
meta := FilesystemManifest{
Exclusions: exclusions,
UpperDir: upperDir,
}
if ociSpec == nil {
return meta
}
"github.com/ai-dynamo/dynamo/deploy/chrek/pkg/types"
)
if ociSpec.Linux != nil {
meta.ExternalPaths = make([]string, 0, len(ociSpec.Linux.MaskedPaths)+len(ociSpec.Linux.ReadonlyPaths))
meta.ExternalPaths = append(meta.ExternalPaths, ociSpec.Linux.MaskedPaths...)
meta.ExternalPaths = append(meta.ExternalPaths, ociSpec.Linux.ReadonlyPaths...)
}
for _, m := range ociSpec.Mounts {
if m.Type == "bind" {
meta.BindMountDests = append(meta.BindMountDests, m.Destination)
}
}
return meta
}
const (
rootfsDiffFilename = "rootfs-diff.tar"
deletedFilesFilename = "deleted-files.json"
)
// GetRootFS returns the container's root filesystem path.
// GetRootFS returns the container's root filesystem path via /host/proc.
func GetRootFS(pid int) (string, error) {
rootPath := fmt.Sprintf("%s/%d/root", HostProcPath, pid)
if _, err := os.Stat(rootPath); err != nil {
return "", fmt.Errorf("rootfs not accessible at %s: %w", rootPath, err)
}
return rootPath, nil
}
// GetOverlayUpperDir extracts the overlay upperdir from mountinfo.
// This is the writable layer of the container's filesystem.
func GetOverlayUpperDir(pid int) (string, error) {
mountInfo, err := ReadMountInfoFromHostProcPath(pid)
mountInfo, err := ReadMountInfo(pid)
if err != nil {
return "", fmt.Errorf("failed to parse mountinfo: %w", err)
}
......@@ -122,8 +39,7 @@ func GetOverlayUpperDir(pid int) (string, error) {
continue
}
// Parse super options to find upperdir
for _, opt := range strings.Split(mount.SuperOptions, ",") {
for _, opt := range strings.Split(mount.VFSOptions, ",") {
if strings.HasPrefix(opt, "upperdir=") {
return strings.TrimPrefix(opt, "upperdir="), nil
}
......@@ -134,29 +50,18 @@ func GetOverlayUpperDir(pid int) (string, error) {
}
// CaptureRootfsDiff captures the overlay upperdir to a tar file.
// The upperdir contains all filesystem modifications made by the container.
// Excludes bind mount destinations and configured directories to avoid conflicts during restore.
// Returns the path to the tar file or empty string if capture failed.
func CaptureRootfsDiff(upperDir, checkpointDir string, exclusions *FilesystemConfig, bindMountDests []string) (string, error) {
func CaptureRootfsDiff(upperDir, checkpointDir string, exclusions types.OverlaySettings, bindMountDests []string) (string, error) {
if upperDir == "" {
return "", fmt.Errorf("upperdir is empty")
}
rootfsDiffPath := filepath.Join(checkpointDir, RootfsDiffFilename)
rootfsDiffPath := filepath.Join(checkpointDir, rootfsDiffFilename)
// Build tar arguments with xattrs and exclusions
tarArgs := []string{"--xattrs"}
// Add configured exclusions (systemDirs, cacheDirs, additionalExclusions from values.yaml)
if exclusions != nil {
for _, excl := range exclusions.GetAllExclusions() {
tarArgs = append(tarArgs, "--exclude="+excl)
}
for _, excl := range buildExclusions(exclusions) {
tarArgs = append(tarArgs, "--exclude="+excl)
}
// Add bind mount exclusions passed from caller
for _, dest := range bindMountDests {
// Convert absolute path to relative for tar (e.g., /etc/hosts -> ./etc/hosts)
tarArgs = append(tarArgs, "--exclude=."+dest)
}
tarArgs = append(tarArgs, "-C", upperDir, "-cf", rootfsDiffPath, ".")
......@@ -170,14 +75,31 @@ func CaptureRootfsDiff(upperDir, checkpointDir string, exclusions *FilesystemCon
return rootfsDiffPath, nil
}
// buildExclusions merges exclusion lists and normalizes paths for tar --exclude patterns.
func buildExclusions(s types.OverlaySettings) []string {
total := len(s.SystemDirs) + len(s.CacheDirs) + len(s.AdditionalExclusions)
exclusions := make([]string, 0, total)
exclusions = append(exclusions, s.SystemDirs...)
exclusions = append(exclusions, s.CacheDirs...)
exclusions = append(exclusions, s.AdditionalExclusions...)
for i, p := range exclusions {
if strings.HasPrefix(p, "*") {
continue
}
p = strings.TrimPrefix(p, ".")
p = strings.TrimPrefix(p, "/")
exclusions[i] = "./" + p
}
return exclusions
}
// CaptureDeletedFiles finds whiteout files and saves them to a JSON file.
// Returns true if deleted files were found and saved.
func CaptureDeletedFiles(upperDir, checkpointDir string) (bool, error) {
if upperDir == "" {
return false, nil
}
whiteouts, err := FindWhiteoutFiles(upperDir)
whiteouts, err := findWhiteoutFiles(upperDir)
if err != nil {
return false, fmt.Errorf("failed to find whiteout files: %w", err)
}
......@@ -186,7 +108,7 @@ func CaptureDeletedFiles(upperDir, checkpointDir string) (bool, error) {
return false, nil
}
deletedFilesPath := filepath.Join(checkpointDir, DeletedFilesFilename)
deletedFilesPath := filepath.Join(checkpointDir, deletedFilesFilename)
data, err := json.Marshal(whiteouts)
if err != nil {
return false, fmt.Errorf("failed to marshal whiteouts: %w", err)
......@@ -199,10 +121,77 @@ func CaptureDeletedFiles(upperDir, checkpointDir string) (bool, error) {
return true, nil
}
// FindWhiteoutFiles finds overlay whiteout files in the upperdir.
// Overlay filesystems use .wh.<filename> to mark deleted files.
// Returns a list of paths that were deleted in the container.
func FindWhiteoutFiles(upperDir string) ([]string, error) {
// ApplyRootfsDiff extracts rootfs-diff.tar into the target root.
func ApplyRootfsDiff(checkpointPath, targetRoot string, log logr.Logger) error {
rootfsDiffPath := filepath.Join(checkpointPath, rootfsDiffFilename)
if _, err := os.Stat(rootfsDiffPath); os.IsNotExist(err) {
log.V(1).Info("No rootfs-diff.tar, skipping")
return nil
}
// --skip-old-files: silently skip files that already exist in the restore target.
// The rootfs diff only contains overlay upperdir changes (runtime-generated files
// like triton caches, tmp files) — base image files should not be overwritten.
log.Info("Applying rootfs diff", "target", targetRoot)
cmd := exec.Command("tar", "--skip-old-files", "-C", targetRoot, "-xf", rootfsDiffPath)
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
if err := cmd.Run(); err != nil {
return fmt.Errorf("tar extract failed: %w", err)
}
return nil
}
// ApplyDeletedFiles removes files marked as deleted in the checkpoint.
func ApplyDeletedFiles(checkpointPath, targetRoot string, log logr.Logger) error {
deletedFilesPath := filepath.Join(checkpointPath, deletedFilesFilename)
data, err := os.ReadFile(deletedFilesPath)
if os.IsNotExist(err) {
return nil
}
if err != nil {
return fmt.Errorf("failed to read deleted files: %w", err)
}
var deletedFiles []string
if err := json.Unmarshal(data, &deletedFiles); err != nil {
return fmt.Errorf("failed to parse deleted files: %w", err)
}
count := 0
targetRootAbs, err := filepath.Abs(targetRoot)
if err != nil {
return fmt.Errorf("failed to resolve target root %s: %w", targetRoot, err)
}
targetRootPrefix := targetRootAbs + string(os.PathSeparator)
for _, f := range deletedFiles {
if f == "" {
continue
}
target := filepath.Join(targetRoot, f)
targetAbs, err := filepath.Abs(target)
if err != nil || (targetAbs != targetRootAbs && !strings.HasPrefix(targetAbs, targetRootPrefix)) {
log.V(1).Info("Skipping out-of-root deleted file entry", "entry", f)
continue
}
if _, err := os.Stat(target); os.IsNotExist(err) {
continue
} else if err != nil {
log.V(1).Info("Could not stat deleted file target", "path", target, "error", err)
continue
}
if err := os.RemoveAll(target); err != nil {
log.V(1).Info("Could not delete file", "path", target, "error", err)
continue
}
count++
}
log.Info("Deleted files applied", "count", count)
return nil
}
// findWhiteoutFiles finds overlay whiteout files in the upperdir.
func findWhiteoutFiles(upperDir string) ([]string, error) {
var whiteouts []string
err := filepath.Walk(upperDir, func(path string, info os.FileInfo, err error) error {
......@@ -212,8 +201,10 @@ func FindWhiteoutFiles(upperDir string) ([]string, error) {
name := info.Name()
if strings.HasPrefix(name, ".wh.") {
// Convert whiteout marker to actual deleted path
relPath, _ := filepath.Rel(upperDir, path)
relPath, err := filepath.Rel(upperDir, path)
if err != nil {
return fmt.Errorf("failed to compute relative path for %s: %w", path, err)
}
dir := filepath.Dir(relPath)
deletedFile := strings.TrimPrefix(name, ".wh.")
deletedPath := deletedFile
......@@ -227,42 +218,3 @@ func FindWhiteoutFiles(upperDir string) ([]string, error) {
return whiteouts, err
}
// CaptureRootfsState captures the overlay upperdir and deleted files after CRIU dump.
// Updates the checkpoint manifest with rootfs diff information and saves it.
func CaptureRootfsState(upperDir, checkpointDir string, data *CheckpointManifest, log *logrus.Entry) {
if upperDir == "" || data == nil {
return
}
// Capture rootfs diff using exclusions from the checkpoint manifest.
configuredExclusions := data.Filesystem.Exclusions.GetAllExclusions()
log.WithFields(logrus.Fields{
"configured_exclusions": configuredExclusions,
"bind_mount_exclusions": data.Filesystem.BindMountDests,
}).Debug("Rootfs diff exclusions")
rootfsDiffPath, err := CaptureRootfsDiff(upperDir, checkpointDir, &data.Filesystem.Exclusions, data.Filesystem.BindMountDests)
if err != nil {
log.WithError(err).Warn("Failed to capture rootfs diff")
} else {
data.Filesystem.HasRootfsDiff = true
log.WithFields(logrus.Fields{
"upperdir": upperDir,
"tar_path": rootfsDiffPath,
}).Info("Captured rootfs diff")
}
// Capture deleted files (whiteouts)
hasDeletedFiles, err := CaptureDeletedFiles(upperDir, checkpointDir)
if err != nil {
log.WithError(err).Warn("Failed to capture deleted files")
} else if hasDeletedFiles {
data.Filesystem.HasDeletedFiles = true
log.Info("Recorded deleted files (whiteouts)")
}
// Update checkpoint manifest with rootfs diff info.
if err := WriteCheckpointManifest(checkpointDir, data); err != nil {
log.WithError(err).Warn("Failed to update checkpoint manifest with rootfs diff info")
}
}
package common
import (
"encoding/json"
"os"
"path/filepath"
"testing"
"github.com/go-logr/logr/testr"
"github.com/ai-dynamo/dynamo/deploy/chrek/pkg/types"
)
func TestBuildExclusions(t *testing.T) {
tests := []struct {
name string
settings types.OverlaySettings
want map[string]bool // expected entries (true = must be present)
}{
{
name: "merges all lists and normalizes paths",
settings: types.OverlaySettings{
SystemDirs: []string{"/proc", "/sys"},
CacheDirs: []string{"/root/.cache"},
AdditionalExclusions: []string{"/tmp"},
},
want: map[string]bool{
"./proc": true,
"./sys": true,
"./root/.cache": true,
"./tmp": true,
},
},
{
name: "strips leading dot and slash before prepending ./",
settings: types.OverlaySettings{
SystemDirs: []string{"./proc", "/sys", "tmp"},
},
want: map[string]bool{
"./proc": true,
"./sys": true,
"./tmp": true,
},
},
{
name: "glob patterns starting with * are untouched",
settings: types.OverlaySettings{
AdditionalExclusions: []string{"*.pyc", "*/__pycache__"},
},
want: map[string]bool{
"*.pyc": true,
"*/__pycache__": true,
},
},
{
name: "empty settings produces empty slice",
settings: types.OverlaySettings{},
want: map[string]bool{},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := buildExclusions(tc.settings)
gotSet := make(map[string]bool, len(got))
for _, v := range got {
gotSet[v] = true
}
for expected := range tc.want {
if !gotSet[expected] {
t.Errorf("expected %q in exclusions, got %v", expected, got)
}
}
if len(got) != len(tc.want) {
t.Errorf("len(exclusions) = %d, want %d; got %v", len(got), len(tc.want), got)
}
})
}
}
func TestFindWhiteoutFiles(t *testing.T) {
tests := []struct {
name string
setup func(dir string) // create files in temp dir
want []string
}{
{
name: "top-level whiteout",
setup: func(dir string) {
os.WriteFile(filepath.Join(dir, ".wh.somefile"), nil, 0644)
},
want: []string{"somefile"},
},
{
name: "nested whiteout returns relative path",
setup: func(dir string) {
sub := filepath.Join(dir, "subdir")
os.MkdirAll(sub, 0755)
os.WriteFile(filepath.Join(sub, ".wh.nested"), nil, 0644)
},
want: []string{"subdir/nested"},
},
{
name: "no whiteouts returns empty",
setup: func(dir string) { os.WriteFile(filepath.Join(dir, "regular"), nil, 0644) },
want: nil,
},
{
name: "empty dir returns empty",
setup: func(dir string) {},
want: nil,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
dir := t.TempDir()
tc.setup(dir)
got, err := findWhiteoutFiles(dir)
if err != nil {
t.Fatalf("findWhiteoutFiles: %v", err)
}
if len(got) != len(tc.want) {
t.Fatalf("got %v, want %v", got, tc.want)
}
for i := range tc.want {
if got[i] != tc.want[i] {
t.Errorf("got[%d] = %q, want %q", i, got[i], tc.want[i])
}
}
})
}
}
func TestCaptureDeletedFiles(t *testing.T) {
t.Run("dir with whiteouts writes JSON and returns true", func(t *testing.T) {
upperDir := t.TempDir()
checkpointDir := t.TempDir()
os.WriteFile(filepath.Join(upperDir, ".wh.removed"), nil, 0644)
found, err := CaptureDeletedFiles(upperDir, checkpointDir)
if err != nil {
t.Fatalf("CaptureDeletedFiles: %v", err)
}
if !found {
t.Fatal("expected found=true")
}
data, err := os.ReadFile(filepath.Join(checkpointDir, deletedFilesFilename))
if err != nil {
t.Fatalf("read deleted-files.json: %v", err)
}
var files []string
if err := json.Unmarshal(data, &files); err != nil {
t.Fatalf("unmarshal: %v", err)
}
if len(files) != 1 || files[0] != "removed" {
t.Errorf("got %v, want [removed]", files)
}
})
t.Run("dir with no whiteouts returns false and no file", func(t *testing.T) {
upperDir := t.TempDir()
checkpointDir := t.TempDir()
os.WriteFile(filepath.Join(upperDir, "normalfile"), nil, 0644)
found, err := CaptureDeletedFiles(upperDir, checkpointDir)
if err != nil {
t.Fatalf("CaptureDeletedFiles: %v", err)
}
if found {
t.Fatal("expected found=false")
}
if _, err := os.Stat(filepath.Join(checkpointDir, deletedFilesFilename)); !os.IsNotExist(err) {
t.Error("deleted-files.json should not exist")
}
})
t.Run("empty upperDir returns false", func(t *testing.T) {
found, err := CaptureDeletedFiles("", t.TempDir())
if err != nil {
t.Fatalf("CaptureDeletedFiles: %v", err)
}
if found {
t.Fatal("expected found=false for empty upperDir")
}
})
}
func TestApplyDeletedFiles(t *testing.T) {
log := testr.New(t)
t.Run("deletes listed files from target", func(t *testing.T) {
checkpointDir := t.TempDir()
targetRoot := t.TempDir()
// Create target file that should be deleted
os.WriteFile(filepath.Join(targetRoot, "old-cache"), []byte("data"), 0644)
// Write deleted-files.json
data, _ := json.Marshal([]string{"old-cache"})
os.WriteFile(filepath.Join(checkpointDir, deletedFilesFilename), data, 0644)
if err := ApplyDeletedFiles(checkpointDir, targetRoot, log); err != nil {
t.Fatalf("ApplyDeletedFiles: %v", err)
}
if _, err := os.Stat(filepath.Join(targetRoot, "old-cache")); !os.IsNotExist(err) {
t.Error("old-cache should have been deleted")
}
})
t.Run("missing deleted-files.json is a no-op", func(t *testing.T) {
if err := ApplyDeletedFiles(t.TempDir(), t.TempDir(), log); err != nil {
t.Fatalf("ApplyDeletedFiles: %v", err)
}
})
t.Run("path traversal entry is skipped", func(t *testing.T) {
checkpointDir := t.TempDir()
targetRoot := t.TempDir()
// Create a file outside targetRoot that the traversal would try to delete
outsideDir := t.TempDir()
secretFile := filepath.Join(outsideDir, "passwd")
os.WriteFile(secretFile, []byte("secret"), 0644)
// Construct a relative path that escapes targetRoot
rel, _ := filepath.Rel(targetRoot, secretFile)
data, _ := json.Marshal([]string{rel})
os.WriteFile(filepath.Join(checkpointDir, deletedFilesFilename), data, 0644)
if err := ApplyDeletedFiles(checkpointDir, targetRoot, log); err != nil {
t.Fatalf("ApplyDeletedFiles: %v", err)
}
// The file outside targetRoot must still exist
if _, err := os.Stat(secretFile); err != nil {
t.Error("path traversal should have been blocked, but file was deleted")
}
})
t.Run("already-missing file causes no error", func(t *testing.T) {
checkpointDir := t.TempDir()
targetRoot := t.TempDir()
data, _ := json.Marshal([]string{"nonexistent"})
os.WriteFile(filepath.Join(checkpointDir, deletedFilesFilename), data, 0644)
if err := ApplyDeletedFiles(checkpointDir, targetRoot, log); err != nil {
t.Fatalf("ApplyDeletedFiles: %v", err)
}
})
t.Run("empty entry is skipped", func(t *testing.T) {
checkpointDir := t.TempDir()
targetRoot := t.TempDir()
data, _ := json.Marshal([]string{""})
os.WriteFile(filepath.Join(checkpointDir, deletedFilesFilename), data, 0644)
if err := ApplyDeletedFiles(checkpointDir, targetRoot, log); err != nil {
t.Fatalf("ApplyDeletedFiles: %v", err)
}
})
}
package common
import (
"fmt"
"os"
"strconv"
"strings"
"syscall"
"github.com/go-logr/logr"
"github.com/prometheus/procfs"
)
// HostProcPath is the mount point for the host's /proc in DaemonSet pods.
const HostProcPath = "/host/proc"
// ProcessTreePIDs walks the process tree rooted at rootPID and returns all PIDs.
// Used by nsrestore for in-namespace CUDA PID discovery.
func ProcessTreePIDs(rootPID int) []int {
if rootPID <= 0 {
return nil
}
queue := []int{rootPID}
seen := map[int]struct{}{}
all := make([]int, 0, 16)
for len(queue) > 0 {
pid := queue[0]
queue = queue[1:]
if _, ok := seen[pid]; ok {
continue
}
seen[pid] = struct{}{}
if _, err := os.Stat(fmt.Sprintf("/proc/%d", pid)); err != nil {
continue
}
all = append(all, pid)
// Iterate all threads — child processes can be spawned from any thread, not just the main thread (tid==pid).
taskDir := fmt.Sprintf("/proc/%d/task", pid)
tids, err := os.ReadDir(taskDir)
if err != nil {
continue
}
for _, tid := range tids {
children, err := os.ReadFile(fmt.Sprintf("%s/%s/children", taskDir, tid.Name()))
if err != nil {
continue
}
for _, child := range strings.Fields(string(children)) {
childPID, err := strconv.Atoi(child)
if err != nil {
continue
}
queue = append(queue, childPID)
}
}
}
return all
}
// ValidateProcessState checks that a process is alive and not a zombie.
func ValidateProcessState(procRoot string, pid int) error {
if pid <= 0 {
return fmt.Errorf("invalid restored PID %d", pid)
}
fs, err := procfs.NewFS(procRoot)
if err != nil {
return fmt.Errorf("failed to open procfs at %s: %w", procRoot, err)
}
proc, err := fs.Proc(pid)
if err != nil {
return fmt.Errorf("process %d exited", pid)
}
stat, err := proc.Stat()
if err != nil {
return fmt.Errorf("failed to inspect process %d: %w", pid, err)
}
if stat.State == "Z" {
return fmt.Errorf("process %d became zombie", pid)
}
return nil
}
// ParseProcExitCode extracts and decodes the exit_code field (field 52) from a /proc/<pid>/stat line.
func ParseProcExitCode(statLine string) (syscall.WaitStatus, error) {
statLine = strings.TrimSpace(statLine)
paren := strings.LastIndex(statLine, ")")
if paren < 0 || paren+2 > len(statLine) {
return 0, fmt.Errorf("malformed stat line")
}
fields := strings.Fields(statLine[paren+2:])
if len(fields) == 0 {
return 0, fmt.Errorf("malformed stat fields")
}
raw, err := strconv.Atoi(fields[len(fields)-1])
if err != nil {
return 0, err
}
return syscall.WaitStatus(raw), nil
}
// SendSignalToPID sends a signal to a host-visible PID via syscall.Kill.
func SendSignalToPID(log logr.Logger, pid int, sig syscall.Signal, reason string) error {
signalID := int(sig)
if pid <= 0 {
return fmt.Errorf("invalid PID %d for signal %d", pid, signalID)
}
if err := syscall.Kill(pid, sig); err != nil {
return fmt.Errorf("failed to signal PID %d with signal %d (%s): %w", pid, signalID, reason, err)
}
log.Info("Signaled runtime process", "pid", pid, "signal", signalID, "reason", reason)
return nil
}
package common
import (
"testing"
)
func TestParseProcExitCode(t *testing.T) {
tests := []struct {
name string
statLine string
wantCode int
wantErr bool
}{
{
// Real /proc/<pid>/stat line (simplified). Fields after ")" start with state.
// The last field (field 52) is exit_code.
name: "normal exit code 0",
statLine: "123 (python3) S 1 123 123 0 -1 4194304 1000 0 0 0 100 50 0 0 20 0 1 0 1000 10000000 500 18446744073709551615 0 0 0 0 0 0 0 0 0 0 0 0 17 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0",
wantCode: 0,
},
{
name: "non-zero exit code",
statLine: "456 (bash) Z 1 456 456 0 -1 4194304 100 0 0 0 10 5 0 0 20 0 1 0 500 0 0 18446744073709551615 0 0 0 0 0 0 0 0 0 0 0 0 17 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 256",
wantCode: 256, // signal 1 encoded as WaitStatus
},
{
// Process names can contain spaces and parentheses.
// The parser must use LastIndex(")") to handle this correctly.
name: "process name with spaces and parens",
statLine: "789 (python3 -m vllm.entrypoints.openai.api_server (worker)) S 1 789 789 0 -1 0 0 0 0 0 0 0 0 0 20 0 1 0 100 0 0 0 0 0 0 0 0 0 0 0 0 0 17 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 42",
wantCode: 42,
},
{
name: "malformed line no closing paren",
statLine: "123 (python3 S 1 123",
wantErr: true,
},
{
name: "empty string",
statLine: "",
wantErr: true,
},
{
name: "only pid and comm, nothing after paren",
statLine: "1 (init)",
wantErr: true,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
ws, err := ParseProcExitCode(tc.statLine)
if tc.wantErr {
if err == nil {
t.Errorf("expected error, got WaitStatus=%d", ws)
}
return
}
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if int(ws) != tc.wantCode {
t.Errorf("exit code = %d, want %d", int(ws), tc.wantCode)
}
})
}
}
package criu
import (
"fmt"
"os"
"path/filepath"
"strings"
"time"
criulib "github.com/checkpoint-restore/go-criu/v8"
criurpc "github.com/checkpoint-restore/go-criu/v8/rpc"
"github.com/go-logr/logr"
"google.golang.org/protobuf/proto"
"github.com/ai-dynamo/dynamo/deploy/chrek/pkg/common"
"github.com/ai-dynamo/dynamo/deploy/chrek/pkg/types"
)
const (
dumpLogFilename = "dump.log"
criuConfFilename = "criu.conf"
)
// BuildDumpOptions creates CRIU options from the container snapshot and settings.
// It also writes the criu.conf file for options that cannot be passed via RPC.
// The ImagesDirFd is left unset — ExecuteDump opens it at dump time.
func BuildDumpOptions(
state *types.CheckpointContainerSnapshot,
settings *types.CRIUSettings,
checkpointDir string,
log logr.Logger,
) (*criurpc.CriuOpts, error) {
var maskedPaths []string
if state.OCISpec != nil && state.OCISpec.Linux != nil {
maskedPaths = state.OCISpec.Linux.MaskedPaths
}
externalized, skipped := common.BuildMountPolicy(state.Mounts, state.RootFS, maskedPaths)
log.V(1).Info("Resolved mount policy for CRIU dump",
"externalized_count", len(externalized),
"skipped_count", len(skipped),
)
criuOpts := &criurpc.CriuOpts{
Pid: proto.Int32(int32(state.PID)),
Root: proto.String(state.RootFS),
LogFile: proto.String(dumpLogFilename),
// Always externalize network namespace
External: []string{fmt.Sprintf("net[%d]:extNetNs", state.NetNSInode)},
}
criuOpts.ExtMnt = toExtMountMaps(externalized)
criuOpts.SkipMnt = skipped
if state.HostCgroupPath != "" {
criuOpts.FreezeCgroup = proto.String(state.HostCgroupPath)
}
if settings == nil {
return criuOpts, nil
}
if err := applyCommonSettings(criuOpts, settings); err != nil {
return nil, err
}
// Dump-only options
criuOpts.LeaveRunning = proto.Bool(settings.LeaveRunning)
criuOpts.OrphanPtsMaster = proto.Bool(settings.OrphanPtsMaster)
criuOpts.ExtMasters = proto.Bool(settings.ExtMasters)
criuOpts.AutoDedup = proto.Bool(settings.AutoDedup)
criuOpts.LazyPages = proto.Bool(settings.LazyPages)
if settings.GhostLimit > 0 {
criuOpts.GhostLimit = proto.Uint32(settings.GhostLimit)
}
// Write criu.conf for options that cannot be passed via RPC.
if confContent := buildCRIUConf(settings); confContent != "" {
confPath := filepath.Join(checkpointDir, criuConfFilename)
if err := os.WriteFile(confPath, []byte(confContent), 0644); err != nil {
return nil, fmt.Errorf("failed to write criu.conf: %w", err)
}
criuOpts.ConfigFile = proto.String(confPath)
}
return criuOpts, nil
}
// ExecuteDump opens the image directory FD, runs the CRIU dump, and cleans up.
func ExecuteDump(
criuOpts *criurpc.CriuOpts,
checkpointDir string,
settings *types.CRIUSettings,
log logr.Logger,
) (time.Duration, error) {
imageDir, imageDirFD, err := openPathForCRIU(checkpointDir)
if err != nil {
return 0, fmt.Errorf("failed to open image directory: %w", err)
}
defer imageDir.Close()
criuOpts.ImagesDirFd = proto.Int32(imageDirFD)
criuDumpStart := time.Now()
criuClient := criulib.MakeCriu()
if settings != nil && strings.TrimSpace(settings.BinaryPath) != "" {
if _, err := os.Stat(settings.BinaryPath); err != nil {
return 0, fmt.Errorf("criu binary not found at %s: %w", settings.BinaryPath, err)
}
criuClient.SetCriuPath(settings.BinaryPath)
}
if err := criuClient.Dump(criuOpts, nil); err != nil {
dumpDuration := time.Since(criuDumpStart)
log.Error(err, "CRIU dump failed",
"duration", dumpDuration,
"checkpoint_dir", checkpointDir,
"dump_log_path", fmt.Sprintf("%s/%s", checkpointDir, dumpLogFilename),
)
return 0, fmt.Errorf("CRIU dump failed: %w", err)
}
criuDumpDuration := time.Since(criuDumpStart)
log.Info("CRIU dump completed", "duration", criuDumpDuration)
return criuDumpDuration, nil
}
func buildCRIUConf(c *types.CRIUSettings) string {
if c == nil {
return ""
}
var content string
if c.LibDir != "" {
content += "libdir " + c.LibDir + "\n"
}
if c.AllowUprobes {
content += "allow-uprobes\n"
}
if c.SkipInFlight {
content += "skip-in-flight\n"
}
return content
}
package criu
import (
"fmt"
"os"
"path/filepath"
"strings"
criulib "github.com/checkpoint-restore/go-criu/v8"
criurpc "github.com/checkpoint-restore/go-criu/v8/rpc"
"github.com/go-logr/logr"
"google.golang.org/protobuf/proto"
"github.com/ai-dynamo/dynamo/deploy/chrek/pkg/logging"
"github.com/ai-dynamo/dynamo/deploy/chrek/pkg/types"
)
// RestoreLogFilename is the CRIU restore log filename (also used by orchestrate/restore.go).
const RestoreLogFilename = "restore.log"
const (
netNsPath = "/proc/1/ns/net"
placeholderFDDir = "/proc/1/fd"
)
// ExecuteRestore opens the image/work directory FDs, configures inherited
// resources, and calls go-criu Restore. Returns the namespace-relative PID.
func ExecuteRestore(
criuOpts *criurpc.CriuOpts,
m *types.CheckpointManifest,
checkpointPath string,
log logr.Logger,
) (int32, error) {
settings := m.CRIUDump.CRIU
// Open image dir FD
imageDir, imageDirFD, err := openPathForCRIU(checkpointPath)
if err != nil {
return 0, fmt.Errorf("failed to open image directory: %w", err)
}
defer imageDir.Close()
criuOpts.ImagesDirFd = proto.Int32(imageDirFD)
// Open work dir FD
if settings.WorkDir != "" {
if err := os.MkdirAll(settings.WorkDir, 0755); err != nil {
return 0, fmt.Errorf("failed to create CRIU work directory: %w", err)
}
workDirFile, workDirFD, err := openPathForCRIU(settings.WorkDir)
if err != nil {
return 0, fmt.Errorf("failed to open CRIU work directory: %w", err)
}
defer workDirFile.Close()
criuOpts.WorkDirFd = proto.Int32(workDirFD)
}
c := criulib.MakeCriu()
if _, err := os.Stat(settings.BinaryPath); err != nil {
return 0, fmt.Errorf("criu binary not found at %s: %w", settings.BinaryPath, err)
}
c.SetCriuPath(settings.BinaryPath)
netNsFile, err := os.Open(netNsPath)
if err != nil {
return 0, fmt.Errorf("failed to open net NS at %s: %w", netNsPath, err)
}
defer netNsFile.Close()
c.AddInheritFd("extNetNs", netNsFile)
inheritedFiles := registerInheritFDs(c, m.K8s.StdioFDs, log)
defer closeFiles(inheritedFiles)
notify := &restoreNotify{log: log}
log.Info("Executing go-criu Restore call")
if err := c.Restore(criuOpts, notify); err != nil {
log.Error(err, "go-criu Restore returned error")
logging.LogRestoreErrors(checkpointPath, settings.WorkDir, log)
return 0, fmt.Errorf("CRIU restore failed: %w", err)
}
return notify.restoredPID, nil
}
// BuildRestoreOpts assembles CriuOpts for a CRIU restore from the checkpoint manifest.
// ImagesDirFd and WorkDirFd are left unset — ExecuteRestore opens them at restore time.
func BuildRestoreOpts(m *types.CheckpointManifest, cgroupRoot string, log logr.Logger) (*criurpc.CriuOpts, error) {
extMounts, err := buildRestoreExtMounts(m)
if err != nil {
return nil, err
}
log.Info("Generated external mount map set", "ext_mount_count", len(extMounts))
settings := m.CRIUDump.CRIU
criuOpts := &criurpc.CriuOpts{
LogFile: proto.String(RestoreLogFilename),
Root: proto.String("/"),
ExtMnt: extMounts,
}
if err := applyCommonSettings(criuOpts, &settings); err != nil {
return nil, err
}
// Restore-only options
criuOpts.RstSibling = proto.Bool(settings.RstSibling)
criuOpts.MntnsCompatMode = proto.Bool(settings.MntnsCompatMode)
criuOpts.EvasiveDevices = proto.Bool(settings.EvasiveDevices)
criuOpts.ForceIrmap = proto.Bool(settings.ForceIrmap)
if cgroupRoot != "" && shouldSetCgroupRoot(criuOpts.GetManageCgroupsMode()) {
criuOpts.CgRoot = []*criurpc.CgroupRoot{
{Path: proto.String(cgroupRoot)},
}
}
criuConfPath := filepath.Join(settings.WorkDir, "..", criuConfFilename)
if _, err := os.Stat(criuConfPath); err == nil {
criuOpts.ConfigFile = proto.String(criuConfPath)
}
return criuOpts, nil
}
func buildRestoreExtMounts(m *types.CheckpointManifest) ([]*criurpc.ExtMountMap, error) {
if len(m.CRIUDump.ExtMnt) == 0 {
return nil, fmt.Errorf("checkpoint manifest is missing criuDump.extMnt")
}
restoreMap := map[string]string{"/": "."}
for _, val := range m.CRIUDump.ExtMnt {
if val == "" || val == "/" {
continue
}
restoreMap[val] = val
}
return toExtMountMaps(restoreMap), nil
}
func registerInheritFDs(c *criulib.Criu, stdioFDs []string, log logr.Logger) []*os.File {
if len(stdioFDs) == 0 {
log.Info("No stdio FD descriptors in manifest, skipping inherit-fd setup")
return nil
}
var openFiles []*os.File
for i, target := range stdioFDs {
if !strings.Contains(target, "pipe:") {
continue
}
// stdin (fd 0) is a read-end pipe; stdout/stderr (fd 1, 2) are write-end
openMode := os.O_WRONLY
if i == 0 {
openMode = os.O_RDONLY
}
fdPath := fmt.Sprintf("%s/%d", placeholderFDDir, i)
f, err := os.OpenFile(fdPath, openMode, 0)
if err != nil {
log.V(1).Info("Failed to open placeholder stdio FD, skipping", "fd", i, "target", target, "error", err)
continue
}
openFiles = append(openFiles, f)
c.AddInheritFd(target, f)
}
log.Info("Registered inherited stdio pipes", "count", len(openFiles))
return openFiles
}
func closeFiles(files []*os.File) {
for _, file := range files {
if file != nil {
file.Close()
}
}
}
type restoreNotify struct {
criulib.NoNotify
restoredPID int32
log logr.Logger
}
func (n *restoreNotify) PreRestore() error {
n.log.V(1).Info("CRIU pre-restore")
return nil
}
func (n *restoreNotify) PostRestore(pid int32) error {
n.restoredPID = pid
n.log.Info("CRIU post-restore: process restored", "pid", pid)
return nil
}
package criu
import (
"fmt"
"os"
"strings"
criurpc "github.com/checkpoint-restore/go-criu/v8/rpc"
"golang.org/x/sys/unix"
"google.golang.org/protobuf/proto"
"github.com/ai-dynamo/dynamo/deploy/chrek/pkg/types"
)
// parseManageCgroupsMode normalizes and validates the CRIU cgroup mode setting.
func parseManageCgroupsMode(raw string) (criurpc.CriuCgMode, string, error) {
mode := strings.ToLower(strings.TrimSpace(raw))
switch mode {
case "":
// Default to SOFT when unset (matches Helm default of "soft")
return criurpc.CriuCgMode_SOFT, "soft", nil
case "ignore":
return criurpc.CriuCgMode_IGNORE, "ignore", nil
case "soft":
return criurpc.CriuCgMode_SOFT, mode, nil
case "full":
return criurpc.CriuCgMode_FULL, mode, nil
case "strict":
return criurpc.CriuCgMode_STRICT, mode, nil
default:
return criurpc.CriuCgMode_IGNORE, "", fmt.Errorf("invalid manageCgroupsMode %q", raw)
}
}
func shouldSetCgroupRoot(cgMode criurpc.CriuCgMode) bool {
switch cgMode {
case criurpc.CriuCgMode_SOFT, criurpc.CriuCgMode_FULL, criurpc.CriuCgMode_STRICT:
return true
default:
return false
}
}
// applyCommonSettings sets CRIU options shared between dump and restore.
func applyCommonSettings(opts *criurpc.CriuOpts, settings *types.CRIUSettings) error {
opts.LogLevel = proto.Int32(settings.LogLevel)
opts.ShellJob = proto.Bool(settings.ShellJob)
opts.TcpClose = proto.Bool(settings.TcpClose)
opts.FileLocks = proto.Bool(settings.FileLocks)
opts.ExtUnixSk = proto.Bool(settings.ExtUnixSk)
opts.LinkRemap = proto.Bool(settings.LinkRemap)
opts.ManageCgroups = proto.Bool(true)
cgMode, _, err := parseManageCgroupsMode(settings.ManageCgroupsMode)
if err != nil {
return fmt.Errorf("invalid cgroup mode: %w", err)
}
opts.ManageCgroupsMode = &cgMode
return nil
}
// openPathForCRIU opens a path (directory or file) and clears the CLOEXEC flag
// so the FD can be inherited by CRIU child processes.
// Returns the opened file and its FD. Caller must close the file when done.
// The caller must also retain the *os.File reference for the entire lifetime the
// raw FD is in use — if the *os.File is garbage collected, Go's finalizer will
// close the underlying FD.
func openPathForCRIU(path string) (*os.File, int32, error) {
dir, err := os.Open(path)
if err != nil {
return nil, 0, fmt.Errorf("failed to open %s: %w", path, err)
}
// Clear CLOEXEC so the FD is inherited by CRIU child process.
// Go's os.Open() sets O_CLOEXEC by default, but go-criu's swrk mode
// requires the FD to be inherited.
if _, err := unix.FcntlInt(dir.Fd(), unix.F_SETFD, 0); err != nil {
dir.Close()
return nil, 0, fmt.Errorf("failed to clear CLOEXEC on %s: %w", path, err)
}
return dir, int32(dir.Fd()), nil
}
// toExtMountMaps converts the mount policy's externalized map to CRIU protobuf entries.
func toExtMountMaps(extMap map[string]string) []*criurpc.ExtMountMap {
entries := make([]*criurpc.ExtMountMap, 0, len(extMap))
for key, val := range extMap {
entries = append(entries, &criurpc.ExtMountMap{
Key: proto.String(key),
Val: proto.String(val),
})
}
return entries
}
package criu
import (
"testing"
criurpc "github.com/checkpoint-restore/go-criu/v8/rpc"
"github.com/ai-dynamo/dynamo/deploy/chrek/pkg/types"
)
func TestParseManageCgroupsMode(t *testing.T) {
tests := []struct {
raw string
wantMode criurpc.CriuCgMode
wantErr bool
}{
{raw: "ignore", wantMode: criurpc.CriuCgMode_IGNORE},
{raw: "soft", wantMode: criurpc.CriuCgMode_SOFT},
{raw: "full", wantMode: criurpc.CriuCgMode_FULL},
{raw: "strict", wantMode: criurpc.CriuCgMode_STRICT},
// Case insensitive + whitespace trimming
{raw: "IGNORE", wantMode: criurpc.CriuCgMode_IGNORE},
{raw: " Soft ", wantMode: criurpc.CriuCgMode_SOFT},
{raw: " FULL ", wantMode: criurpc.CriuCgMode_FULL},
// Empty string defaults to SOFT (matches Helm default)
{raw: "", wantMode: criurpc.CriuCgMode_SOFT},
// Invalid
{raw: "bogus", wantErr: true},
}
for _, tc := range tests {
t.Run(tc.raw, func(t *testing.T) {
mode, _, err := parseManageCgroupsMode(tc.raw)
if tc.wantErr {
if err == nil {
t.Errorf("expected error for %q, got mode=%v", tc.raw, mode)
}
return
}
if err != nil {
t.Fatalf("unexpected error for %q: %v", tc.raw, err)
}
if mode != tc.wantMode {
t.Errorf("mode = %v, want %v", mode, tc.wantMode)
}
})
}
}
func TestApplyCommonSettings(t *testing.T) {
t.Run("valid mode sets all fields", func(t *testing.T) {
opts := &criurpc.CriuOpts{}
settings := &types.CRIUSettings{
LogLevel: 4,
ShellJob: true,
TcpClose: true,
FileLocks: true,
ExtUnixSk: true,
LinkRemap: true,
ManageCgroupsMode: "soft",
}
if err := applyCommonSettings(opts, settings); err != nil {
t.Fatalf("applyCommonSettings: %v", err)
}
if opts.GetLogLevel() != 4 {
t.Errorf("LogLevel = %d", opts.GetLogLevel())
}
if !opts.GetShellJob() {
t.Error("ShellJob should be true")
}
if !opts.GetTcpClose() {
t.Error("TcpClose should be true")
}
if !opts.GetFileLocks() {
t.Error("FileLocks should be true")
}
if !opts.GetExtUnixSk() {
t.Error("ExtUnixSk should be true")
}
if !opts.GetLinkRemap() {
t.Error("LinkRemap should be true")
}
if !opts.GetManageCgroups() {
t.Error("ManageCgroups should be true")
}
if opts.GetManageCgroupsMode() != criurpc.CriuCgMode_SOFT {
t.Errorf("ManageCgroupsMode = %v, want SOFT", opts.GetManageCgroupsMode())
}
})
t.Run("invalid mode returns error", func(t *testing.T) {
opts := &criurpc.CriuOpts{}
settings := &types.CRIUSettings{ManageCgroupsMode: "invalid"}
if err := applyCommonSettings(opts, settings); err == nil {
t.Error("expected error for invalid ManageCgroupsMode")
}
})
}
func TestBuildRestoreExtMounts(t *testing.T) {
t.Run("normal manifest with ExtMnt", func(t *testing.T) {
m := &types.CheckpointManifest{
CRIUDump: types.CRIUDumpManifest{
ExtMnt: map[string]string{
"/etc/hostname": "/etc/hostname",
"/proc/acpi": "/dev/null",
},
},
}
mounts, err := buildRestoreExtMounts(m)
if err != nil {
t.Fatalf("buildRestoreExtMounts: %v", err)
}
// Should contain value→value self-mappings plus "/" → "."
mountMap := make(map[string]string, len(mounts))
for _, em := range mounts {
mountMap[em.GetKey()] = em.GetVal()
}
if mountMap["/"] != "." {
t.Errorf("root mapping: got %q, want %q", mountMap["/"], ".")
}
if mountMap["/etc/hostname"] != "/etc/hostname" {
t.Errorf("/etc/hostname mapping: got %q", mountMap["/etc/hostname"])
}
if mountMap["/dev/null"] != "/dev/null" {
t.Errorf("/dev/null mapping: got %q", mountMap["/dev/null"])
}
})
t.Run("values of / or empty are skipped", func(t *testing.T) {
m := &types.CheckpointManifest{
CRIUDump: types.CRIUDumpManifest{
ExtMnt: map[string]string{
"/root_mount": "/",
"/empty_val": "",
"/good": "/good",
},
},
}
mounts, err := buildRestoreExtMounts(m)
if err != nil {
t.Fatalf("buildRestoreExtMounts: %v", err)
}
mountMap := make(map[string]string, len(mounts))
for _, em := range mounts {
mountMap[em.GetKey()] = em.GetVal()
}
// "/" and "" values should be skipped from the value→value mapping
// but "/" → "." root mapping always exists
if mountMap["/"] != "." {
t.Errorf("root mapping missing")
}
if _, ok := mountMap[""]; ok {
t.Error("empty string should not be a key in restore map")
}
if mountMap["/good"] != "/good" {
t.Errorf("/good mapping missing")
}
})
t.Run("empty ExtMnt returns error", func(t *testing.T) {
m := &types.CheckpointManifest{
CRIUDump: types.CRIUDumpManifest{},
}
_, err := buildRestoreExtMounts(m)
if err == nil {
t.Error("expected error for empty ExtMnt")
}
})
}
// Package cuda provides CUDA checkpoint and restore operations.
package cuda
import (
"context"
"fmt"
"os/exec"
"strconv"
"strings"
"github.com/go-logr/logr"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
podresourcesv1 "k8s.io/kubelet/pkg/apis/podresources/v1"
)
const (
podResourcesSocket = "/var/lib/kubelet/pod-resources/kubelet.sock"
nvidiaGPUResource = "nvidia.com/gpu"
)
// GetPodGPUUUIDs resolves GPU UUIDs for a pod/container from the kubelet PodResources API.
func GetPodGPUUUIDs(ctx context.Context, podName, podNamespace, containerName string) ([]string, error) {
if podName == "" || podNamespace == "" {
return nil, nil
}
conn, err := grpc.DialContext(
ctx,
"unix://"+podResourcesSocket,
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithBlock(),
)
if err != nil {
return nil, err
}
defer conn.Close()
client := podresourcesv1.NewPodResourcesListerClient(conn)
resp, err := client.List(ctx, &podresourcesv1.ListPodResourcesRequest{})
if err != nil {
return nil, err
}
for _, pod := range resp.GetPodResources() {
if pod.GetName() != podName || pod.GetNamespace() != podNamespace {
continue
}
for _, container := range pod.GetContainers() {
if containerName != "" && container.GetName() != containerName {
continue
}
for _, device := range container.GetDevices() {
if device.GetResourceName() == nvidiaGPUResource {
return device.GetDeviceIds(), nil
}
}
}
}
return nil, nil
}
// FilterProcesses returns the subset of candidate PIDs that report CUDA state.
func FilterProcesses(ctx context.Context, allPIDs []int, log logr.Logger) []int {
cudaPIDs := make([]int, 0, len(allPIDs))
for _, pid := range allPIDs {
if pid <= 0 {
continue
}
cmd := exec.CommandContext(ctx, cudaCheckpointBinary, "--get-state", "--pid", strconv.Itoa(pid))
if err := cmd.Run(); err != nil {
if ctx.Err() != nil {
break
}
log.V(1).Info("CUDA state probe failed", "pid", pid, "error", err)
continue
}
cudaPIDs = append(cudaPIDs, pid)
}
return cudaPIDs
}
// BuildDeviceMap creates a cuda-checkpoint --device-map value from source and target GPU UUID lists.
func BuildDeviceMap(sourceUUIDs, targetUUIDs []string) (string, error) {
if len(sourceUUIDs) != len(targetUUIDs) {
return "", fmt.Errorf("GPU count mismatch: source has %d, target has %d", len(sourceUUIDs), len(targetUUIDs))
}
if len(sourceUUIDs) == 0 {
return "", fmt.Errorf("GPU UUID list is empty")
}
pairs := make([]string, len(sourceUUIDs))
for i := range sourceUUIDs {
pairs[i] = sourceUUIDs[i] + "=" + targetUUIDs[i]
}
return strings.Join(pairs, ","), nil
}
// LockAndCheckpointProcessTree locks and checkpoints CUDA state for all given PIDs.
// On partial failure, already-checkpointed PIDs are restored+unlocked.
func LockAndCheckpointProcessTree(ctx context.Context, cudaPIDs []int, log logr.Logger) error {
locked := make([]int, 0, len(cudaPIDs))
for _, pid := range cudaPIDs {
if err := lock(ctx, pid, log); err != nil {
bulkUnlock(context.Background(), locked, log)
return fmt.Errorf("cuda lock failed for PID %d: %w", pid, err)
}
locked = append(locked, pid)
}
checkpointed := make([]int, 0, len(cudaPIDs))
for _, pid := range cudaPIDs {
if err := checkpoint(ctx, pid, log); err != nil {
recoverCheckpointed(context.Background(), checkpointed, locked, log)
return fmt.Errorf("cuda checkpoint failed for PID %d: %w", pid, err)
}
checkpointed = append(checkpointed, pid)
}
return nil
}
// RestoreAndUnlockProcessTree restores and unlocks CUDA state for the given PIDs.
func RestoreAndUnlockProcessTree(ctx context.Context, cudaPIDs []int, deviceMap string, log logr.Logger) error {
for _, pid := range cudaPIDs {
if err := restoreProcess(ctx, pid, deviceMap, log); err != nil {
return fmt.Errorf("cuda restore failed for PID %d: %w", pid, err)
}
}
for _, pid := range cudaPIDs {
if err := unlock(ctx, pid, log); err != nil {
state, stateErr := getState(ctx, pid)
if stateErr == nil && state == "running" {
log.Info("cuda-checkpoint unlock returned error but process is already running", "pid", pid)
continue
}
return fmt.Errorf("failed to unlock CUDA process %d: %w", pid, err)
}
}
return nil
}
// bulkUnlock unlocks a list of CUDA PIDs (best-effort).
func bulkUnlock(ctx context.Context, pids []int, log logr.Logger) {
for _, pid := range pids {
if err := unlock(ctx, pid, log); err != nil {
log.Error(err, "Failed to unlock CUDA process", "pid", pid)
}
}
}
// recoverCheckpointed is best-effort cleanup when checkpoint fails partway.
// Checkpointed PIDs need restore+unlock; locked-only PIDs just need unlock.
func recoverCheckpointed(ctx context.Context, checkpointed, locked []int, log logr.Logger) {
checkpointedSet := make(map[int]struct{}, len(checkpointed))
for _, pid := range checkpointed {
checkpointedSet[pid] = struct{}{}
}
for _, pid := range checkpointed {
if err := restoreProcess(ctx, pid, "", log); err != nil {
log.Error(err, "Failed to restore CUDA process during cleanup", "pid", pid)
continue
}
if err := unlock(ctx, pid, log); err != nil {
log.Error(err, "Failed to unlock CUDA process after restore during cleanup", "pid", pid)
}
}
for _, pid := range locked {
if _, ok := checkpointedSet[pid]; ok {
continue
}
if err := unlock(ctx, pid, log); err != nil {
log.Error(err, "Failed to unlock CUDA process during cleanup", "pid", pid)
}
}
}
package cuda
import (
"testing"
)
func TestBuildDeviceMap(t *testing.T) {
tests := []struct {
name string
source []string
target []string
want string
wantErr bool
}{
{
name: "single GPU",
source: []string{"GPU-aaa"},
target: []string{"GPU-bbb"},
want: "GPU-aaa=GPU-bbb",
},
{
name: "multiple GPUs",
source: []string{"GPU-aaa", "GPU-bbb"},
target: []string{"GPU-ccc", "GPU-ddd"},
want: "GPU-aaa=GPU-ccc,GPU-bbb=GPU-ddd",
},
{
name: "mismatched lengths",
source: []string{"GPU-aaa", "GPU-bbb"},
target: []string{"GPU-ccc"},
wantErr: true,
},
{
name: "both empty",
source: []string{},
target: []string{},
wantErr: true,
},
{
name: "source empty target non-empty",
source: []string{},
target: []string{"GPU-aaa"},
wantErr: true,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got, err := BuildDeviceMap(tc.source, tc.target)
if tc.wantErr {
if err == nil {
t.Errorf("expected error, got %q", got)
}
return
}
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got != tc.want {
t.Errorf("got %q, want %q", got, tc.want)
}
})
}
}
package cuda
import (
"context"
"fmt"
"os/exec"
"strconv"
"strings"
"time"
"github.com/go-logr/logr"
)
const (
cudaCheckpointBinary = "/usr/local/sbin/cuda-checkpoint"
actionLock = "lock"
actionCheckpoint = "checkpoint"
actionRestore = "restore"
actionUnlock = "unlock"
)
func lock(ctx context.Context, pid int, log logr.Logger) error {
return runAction(ctx, pid, actionLock, "", log)
}
func checkpoint(ctx context.Context, pid int, log logr.Logger) error {
return runAction(ctx, pid, actionCheckpoint, "", log)
}
func restoreProcess(ctx context.Context, pid int, deviceMap string, log logr.Logger) error {
return runAction(ctx, pid, actionRestore, deviceMap, log)
}
func unlock(ctx context.Context, pid int, log logr.Logger) error {
return runAction(ctx, pid, actionUnlock, "", log)
}
func getState(ctx context.Context, pid int) (string, error) {
cmd := exec.CommandContext(ctx, cudaCheckpointBinary, "--get-state", "--pid", strconv.Itoa(pid))
output, err := cmd.CombinedOutput()
state := strings.TrimSpace(string(output))
if err != nil {
return "", fmt.Errorf("cuda-checkpoint --get-state failed for pid %d: %w (output: %s)", pid, err, state)
}
if state == "" {
return "", fmt.Errorf("cuda-checkpoint --get-state returned empty state for pid %d", pid)
}
return state, nil
}
func runAction(ctx context.Context, pid int, action, deviceMap string, log logr.Logger) error {
args := []string{"--action", action, "--pid", strconv.Itoa(pid)}
if action == actionRestore && deviceMap != "" {
args = append(args, "--device-map", deviceMap)
}
cmd := exec.CommandContext(ctx, cudaCheckpointBinary, args...)
start := time.Now()
output, err := cmd.CombinedOutput()
duration := time.Since(start)
out := strings.TrimSpace(string(output))
if err != nil {
return fmt.Errorf("cuda-checkpoint %v failed for pid %d after %s: %w (output: %s)", args, pid, duration, err, out)
}
log.Info("cuda-checkpoint command succeeded",
"pid", pid,
"action", action,
"duration", duration,
"output", out,
)
return nil
}
This diff is collapsed.
// middleware.go provides HTTP middleware for the server.
package httpApiServer
import (
"log"
"net/http"
"time"
)
// LoggingMiddleware wraps an HTTP handler and logs request details.
func LoggingMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
start := time.Now()
log.Printf("Started %s %s", r.Method, r.URL.Path)
next.ServeHTTP(w, r)
log.Printf("Completed %s %s in %v", r.Method, r.URL.Path, time.Since(start))
})
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment