Unverified commit 3a09a559, authored by Schwinn Saereesitthipitak and committed by GitHub
Browse files

fix: make snapshot CI image linter pass (#7124)

parent f4e20810
......@@ -47,26 +47,17 @@ test: fmt vet ## Run tests.
.PHONY: lint
lint: golangci-lint ## Run golangci-lint linter.
$(GOLANGCI_LINT) run --timeout=5m
$(GOLANGCI_LINT) run
.PHONY: lint-fix
lint-fix: golangci-lint ## Run golangci-lint linter and perform fixes.
$(GOLANGCI_LINT) run --fix --timeout=5m
$(GOLANGCI_LINT) run --fix
##@ Build
.PHONY: build
build: fmt vet ## Build snapshot-agent binary.
CGO_ENABLED=0 go build -ldflags="-w -s" -o bin/snapshot-agent ./cmd/agent
.PHONY: run
run: build ## Run snapshot-agent from your host.
./bin/snapshot-agent
.PHONY: clean
clean: ## Remove build artifacts.
rm -rf bin/
rm -f cover.out
go build -o bin/snapshot-agent ./cmd/agent
##@ Docker
......@@ -74,14 +65,6 @@ clean: ## Remove build artifacts.
docker-build-agent: ## Build snapshot-agent docker image (linux/amd64 only).
$(CONTAINER_TOOL) build --platform ${RUNTIME_IMAGE_PLATFORM} --target agent -t ${IMG} .
.PHONY: docker-build-agent-lint
docker-build-agent-lint: ## Build snapshot-agent docker image up to lint stage.
$(CONTAINER_TOOL) build --target linter -t ${IMG}-lint .
.PHONY: docker-build-agent-test
docker-build-agent-test: ## Build snapshot-agent docker image up to test stage.
$(CONTAINER_TOOL) build --target tester -t ${IMG}-test .
.PHONY: docker-build-placeholder
docker-build-placeholder: ## Build placeholder image for checkpoint restore (linux/amd64 only). Requires PLACEHOLDER_BASE_IMG.
ifndef PLACEHOLDER_BASE_IMG
......@@ -139,5 +122,5 @@ mv "$$(echo "$(1)" | sed "s/-$(3)$$//")" $(1) ;\
endef
.PHONY: coverage
coverage: test ## Show test coverage report.
coverage: test
go tool cover -func=cover.out
......@@ -6,7 +6,6 @@ import (
"path/filepath"
"strconv"
"strings"
)
const HostCgroupPath = "/sys/fs/cgroup"
......
......@@ -190,7 +190,9 @@ func TestBuildMountPolicy(t *testing.T) {
mounts: []types.MountInfo{},
rootFS: func() string {
dir := t.TempDir()
os.WriteFile(filepath.Join(dir, "proc"), []byte("x"), 0644)
if err := os.WriteFile(filepath.Join(dir, "proc"), []byte("x"), 0644); err != nil {
t.Fatalf("write masked file: %v", err)
}
return dir
}(),
maskedPaths: []string{"/proc"},
......@@ -201,7 +203,9 @@ func TestBuildMountPolicy(t *testing.T) {
mounts: []types.MountInfo{},
rootFS: func() string {
dir := t.TempDir()
os.MkdirAll(filepath.Join(dir, "proc"), 0755)
if err := os.MkdirAll(filepath.Join(dir, "proc"), 0755); err != nil {
t.Fatalf("mkdir masked dir: %v", err)
}
return dir
}(),
maskedPaths: []string{"/proc"},
......
......@@ -58,4 +58,3 @@ func SendSignalViaPIDNamespace(ctx context.Context, log logr.Logger, referenceHo
)
return nil
}
......@@ -81,33 +81,46 @@ func TestBuildExclusions(t *testing.T) {
func TestFindWhiteoutFiles(t *testing.T) {
tests := []struct {
name string
setup func(dir string) // create files in temp dir
setup func(t *testing.T, dir string) // create files in temp dir
want []string
}{
{
name: "top-level whiteout",
setup: func(dir string) {
os.WriteFile(filepath.Join(dir, ".wh.somefile"), nil, 0644)
setup: func(t *testing.T, dir string) {
t.Helper()
if err := os.WriteFile(filepath.Join(dir, ".wh.somefile"), nil, 0644); err != nil {
t.Fatalf("write whiteout: %v", err)
}
},
want: []string{"somefile"},
},
{
name: "nested whiteout returns relative path",
setup: func(dir string) {
setup: func(t *testing.T, dir string) {
t.Helper()
sub := filepath.Join(dir, "subdir")
os.MkdirAll(sub, 0755)
os.WriteFile(filepath.Join(sub, ".wh.nested"), nil, 0644)
if err := os.MkdirAll(sub, 0755); err != nil {
t.Fatalf("mkdir subdir: %v", err)
}
if err := os.WriteFile(filepath.Join(sub, ".wh.nested"), nil, 0644); err != nil {
t.Fatalf("write nested whiteout: %v", err)
}
},
want: []string{"subdir/nested"},
},
{
name: "no whiteouts returns empty",
setup: func(dir string) { os.WriteFile(filepath.Join(dir, "regular"), nil, 0644) },
setup: func(t *testing.T, dir string) {
t.Helper()
if err := os.WriteFile(filepath.Join(dir, "regular"), nil, 0644); err != nil {
t.Fatalf("write regular file: %v", err)
}
},
want: nil,
},
{
name: "empty dir returns empty",
setup: func(dir string) {},
setup: func(*testing.T, string) {},
want: nil,
},
}
......@@ -115,7 +128,7 @@ func TestFindWhiteoutFiles(t *testing.T) {
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
dir := t.TempDir()
tc.setup(dir)
tc.setup(t, dir)
got, err := findWhiteoutFiles(dir)
if err != nil {
t.Fatalf("findWhiteoutFiles: %v", err)
......@@ -136,7 +149,9 @@ func TestCaptureDeletedFiles(t *testing.T) {
t.Run("dir with whiteouts writes JSON and returns true", func(t *testing.T) {
upperDir := t.TempDir()
checkpointDir := t.TempDir()
os.WriteFile(filepath.Join(upperDir, ".wh.removed"), nil, 0644)
if err := os.WriteFile(filepath.Join(upperDir, ".wh.removed"), nil, 0644); err != nil {
t.Fatalf("write whiteout: %v", err)
}
found, err := CaptureDeletedFiles(upperDir, checkpointDir)
if err != nil {
......@@ -162,7 +177,9 @@ func TestCaptureDeletedFiles(t *testing.T) {
t.Run("dir with no whiteouts returns false and no file", func(t *testing.T) {
upperDir := t.TempDir()
checkpointDir := t.TempDir()
os.WriteFile(filepath.Join(upperDir, "normalfile"), nil, 0644)
if err := os.WriteFile(filepath.Join(upperDir, "normalfile"), nil, 0644); err != nil {
t.Fatalf("write regular file: %v", err)
}
found, err := CaptureDeletedFiles(upperDir, checkpointDir)
if err != nil {
......@@ -195,11 +212,18 @@ func TestApplyDeletedFiles(t *testing.T) {
targetRoot := t.TempDir()
// Create target file that should be deleted
os.WriteFile(filepath.Join(targetRoot, "old-cache"), []byte("data"), 0644)
if err := os.WriteFile(filepath.Join(targetRoot, "old-cache"), []byte("data"), 0644); err != nil {
t.Fatalf("write target file: %v", err)
}
// Write deleted-files.json
data, _ := json.Marshal([]string{"old-cache"})
os.WriteFile(filepath.Join(checkpointDir, deletedFilesFilename), data, 0644)
data, err := json.Marshal([]string{"old-cache"})
if err != nil {
t.Fatalf("marshal deleted files: %v", err)
}
if err := os.WriteFile(filepath.Join(checkpointDir, deletedFilesFilename), data, 0644); err != nil {
t.Fatalf("write deleted-files.json: %v", err)
}
if err := ApplyDeletedFiles(checkpointDir, targetRoot, log); err != nil {
t.Fatalf("ApplyDeletedFiles: %v", err)
......@@ -223,12 +247,22 @@ func TestApplyDeletedFiles(t *testing.T) {
// Create a file outside targetRoot that the traversal would try to delete
outsideDir := t.TempDir()
secretFile := filepath.Join(outsideDir, "passwd")
os.WriteFile(secretFile, []byte("secret"), 0644)
if err := os.WriteFile(secretFile, []byte("secret"), 0644); err != nil {
t.Fatalf("write secret file: %v", err)
}
// Construct a relative path that escapes targetRoot
rel, _ := filepath.Rel(targetRoot, secretFile)
data, _ := json.Marshal([]string{rel})
os.WriteFile(filepath.Join(checkpointDir, deletedFilesFilename), data, 0644)
rel, err := filepath.Rel(targetRoot, secretFile)
if err != nil {
t.Fatalf("build relative path: %v", err)
}
data, err := json.Marshal([]string{rel})
if err != nil {
t.Fatalf("marshal deleted files: %v", err)
}
if err := os.WriteFile(filepath.Join(checkpointDir, deletedFilesFilename), data, 0644); err != nil {
t.Fatalf("write deleted-files.json: %v", err)
}
if err := ApplyDeletedFiles(checkpointDir, targetRoot, log); err != nil {
t.Fatalf("ApplyDeletedFiles: %v", err)
......@@ -244,8 +278,13 @@ func TestApplyDeletedFiles(t *testing.T) {
checkpointDir := t.TempDir()
targetRoot := t.TempDir()
data, _ := json.Marshal([]string{"nonexistent"})
os.WriteFile(filepath.Join(checkpointDir, deletedFilesFilename), data, 0644)
data, err := json.Marshal([]string{"nonexistent"})
if err != nil {
t.Fatalf("marshal deleted files: %v", err)
}
if err := os.WriteFile(filepath.Join(checkpointDir, deletedFilesFilename), data, 0644); err != nil {
t.Fatalf("write deleted-files.json: %v", err)
}
if err := ApplyDeletedFiles(checkpointDir, targetRoot, log); err != nil {
t.Fatalf("ApplyDeletedFiles: %v", err)
......@@ -256,8 +295,13 @@ func TestApplyDeletedFiles(t *testing.T) {
checkpointDir := t.TempDir()
targetRoot := t.TempDir()
data, _ := json.Marshal([]string{""})
os.WriteFile(filepath.Join(checkpointDir, deletedFilesFilename), data, 0644)
data, err := json.Marshal([]string{""})
if err != nil {
t.Fatalf("marshal deleted files: %v", err)
}
if err := os.WriteFile(filepath.Join(checkpointDir, deletedFilesFilename), data, 0644); err != nil {
t.Fatalf("write deleted-files.json: %v", err)
}
if err := ApplyDeletedFiles(checkpointDir, targetRoot, log); err != nil {
t.Fatalf("ApplyDeletedFiles: %v", err)
......
......@@ -93,4 +93,3 @@ func toExtMountMaps(extMap map[string]string) []*criurpc.ExtMountMap {
}
return entries
}
......@@ -15,10 +15,9 @@ import (
podresourcesv1 "k8s.io/kubelet/pkg/apis/podresources/v1"
)
const (
podResourcesSocket = "/var/lib/kubelet/pod-resources/kubelet.sock"
nvidiaGPUResource = "nvidia.com/gpu"
)
const nvidiaGPUResource = "nvidia.com/gpu"
var podResourcesSocketPath = "/var/lib/kubelet/pod-resources/kubelet.sock"
// GetPodGPUUUIDs resolves GPU UUIDs for a pod/container from the kubelet PodResources API.
// All nvidia.com/gpu device entries are accumulated in case the kubelet splits them
......@@ -28,11 +27,9 @@ func GetPodGPUUUIDs(ctx context.Context, podName, podNamespace, containerName st
return nil, nil
}
conn, err := grpc.DialContext(
ctx,
"unix://"+podResourcesSocket,
conn, err := grpc.NewClient(
"unix://"+podResourcesSocketPath,
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithBlock(),
)
if err != nil {
return nil, err
......
package cuda
import (
"context"
"errors"
"net"
"path/filepath"
"strings"
"testing"
"time"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
podresourcesv1 "k8s.io/kubelet/pkg/apis/podresources/v1"
)
func TestBuildDeviceMap(t *testing.T) {
......@@ -62,3 +74,120 @@ func TestBuildDeviceMap(t *testing.T) {
})
}
}
// testPodResourcesServer is a PodResourcesLister test double: List answers
// with the canned resp, while GetAllocatableResources and Get report
// codes.Unimplemented.
type testPodResourcesServer struct {
podresourcesv1.UnimplementedPodResourcesListerServer
// resp is handed back unchanged by List.
resp *podresourcesv1.ListPodResourcesResponse
}
// List ignores its context and request and returns the response configured on
// the server together with a nil error.
func (s *testPodResourcesServer) List(context.Context, *podresourcesv1.ListPodResourcesRequest) (*podresourcesv1.ListPodResourcesResponse, error) {
return s.resp, nil
}
// GetAllocatableResources always fails with a codes.Unimplemented status; the
// test double does not provide this RPC.
func (s *testPodResourcesServer) GetAllocatableResources(context.Context, *podresourcesv1.AllocatableResourcesRequest) (*podresourcesv1.AllocatableResourcesResponse, error) {
return nil, status.Error(codes.Unimplemented, "not implemented in test")
}
// Get always fails with a codes.Unimplemented status; the test double does
// not provide this RPC.
func (s *testPodResourcesServer) Get(context.Context, *podresourcesv1.GetPodResourcesRequest) (*podresourcesv1.GetPodResourcesResponse, error) {
return nil, status.Error(codes.Unimplemented, "not implemented in test")
}
// TestGetPodGPUUUIDs runs GetPodGPUUUIDs against an in-process PodResources
// gRPC server listening on a temporary unix socket.
//
// The canned response contains another pod, a sidecar container and a non-GPU
// resource that must all be ignored; only the nvidiaGPUResource device IDs of
// container "main" in pod default/test-pod are expected, accumulated in order
// across its two GPU device entries.
func TestGetPodGPUUUIDs(t *testing.T) {
	socketPath := filepath.Join(t.TempDir(), "kubelet.sock")
	listener, err := net.Listen("unix", socketPath)
	if err != nil {
		t.Fatalf("listen unix socket: %v", err)
	}

	server := grpc.NewServer()
	podresourcesv1.RegisterPodResourcesListerServer(server, &testPodResourcesServer{
		resp: &podresourcesv1.ListPodResourcesResponse{
			PodResources: []*podresourcesv1.PodResources{
				{
					Name:      "other-pod",
					Namespace: "default",
					Containers: []*podresourcesv1.ContainerResources{
						{
							Name: "main",
							Devices: []*podresourcesv1.ContainerDevices{
								{
									ResourceName: nvidiaGPUResource,
									DeviceIds:    []string{"GPU-ignore"},
								},
							},
						},
					},
				},
				{
					Name:      "test-pod",
					Namespace: "default",
					Containers: []*podresourcesv1.ContainerResources{
						{
							Name: "sidecar",
							Devices: []*podresourcesv1.ContainerDevices{
								{
									ResourceName: nvidiaGPUResource,
									DeviceIds:    []string{"GPU-sidecar"},
								},
							},
						},
						{
							Name: "main",
							Devices: []*podresourcesv1.ContainerDevices{
								{
									ResourceName: nvidiaGPUResource,
									DeviceIds:    []string{"GPU-a", "GPU-b"},
								},
								{
									ResourceName: "example.com/fpga",
									DeviceIds:    []string{"FPGA-ignore"},
								},
								{
									ResourceName: nvidiaGPUResource,
									DeviceIds:    []string{"GPU-c"},
								},
							},
						},
					},
				},
			},
		},
	})

	// serveDone is closed when the Serve goroutine exits, so cleanup can wait
	// for it. Without the wait the goroutine could call t.Errorf after the
	// test has completed, which makes the testing package panic.
	serveDone := make(chan struct{})
	go func() {
		defer close(serveDone)
		if serveErr := server.Serve(listener); serveErr != nil {
			if errors.Is(serveErr, grpc.ErrServerStopped) ||
				errors.Is(serveErr, net.ErrClosed) ||
				strings.Contains(serveErr.Error(), "use of closed network connection") {
				return
			}
			t.Errorf("serve test pod-resources gRPC server: %v", serveErr)
		}
	}()
	// Stop closes the listener handed to Serve, so no separate
	// listener.Close is needed.
	t.Cleanup(func() {
		server.Stop()
		<-serveDone
	})

	// Point the package at the test socket for the duration of this test.
	previousSocketPath := podResourcesSocketPath
	podResourcesSocketPath = socketPath
	t.Cleanup(func() {
		podResourcesSocketPath = previousSocketPath
	})

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	got, err := GetPodGPUUUIDs(ctx, "test-pod", "default", "main")
	if err != nil {
		t.Fatalf("GetPodGPUUUIDs: %v", err)
	}

	want := []string{"GPU-a", "GPU-b", "GPU-c"}
	if len(got) != len(want) {
		t.Fatalf("got %v, want %v", got, want)
	}
	for i := range want {
		if got[i] != want[i] {
			t.Fatalf("got %v, want %v", got, want)
		}
	}
}
......@@ -14,6 +14,11 @@ import (
"k8s.io/client-go/tools/cache"
)
const (
// terminalStatusPatchRetryAttempts is the maximum number of annotatePod
// calls annotatePodRetry makes before giving up.
terminalStatusPatchRetryAttempts = 3
// terminalStatusPatchRetryDelay is the wait before the first retry;
// annotatePodRetry doubles it after each further failed attempt.
terminalStatusPatchRetryDelay = 10 * time.Millisecond
)
func podFromInformerObj(obj interface{}) (*corev1.Pod, bool) {
if pod, ok := obj.(*corev1.Pod); ok {
return pod, true
......@@ -73,6 +78,32 @@ func annotatePod(ctx context.Context, clientset kubernetes.Interface, log logr.L
return err
}
// annotatePodRetry applies annotations to pod through annotatePod, making at
// most terminalStatusPatchRetryAttempts attempts. The wait between attempts
// starts at terminalStatusPatchRetryDelay and doubles after every failure.
//
// It returns nil as soon as one attempt succeeds. If ctx is done while waiting
// for the next attempt, the returned error wraps ctx.Err() and mentions the
// most recent annotatePod failure. Once every attempt has failed, the returned
// error wraps the error of the final attempt.
func annotatePodRetry(ctx context.Context, clientset kubernetes.Interface, log logr.Logger, pod *corev1.Pod, annotations map[string]string) error {
	delay := terminalStatusPatchRetryDelay
	var lastErr error

	for attempt := 1; attempt <= terminalStatusPatchRetryAttempts; attempt++ {
		lastErr = annotatePod(ctx, clientset, log, pod, annotations)
		if lastErr == nil {
			return nil
		}
		// No point in sleeping after the last attempt.
		if attempt == terminalStatusPatchRetryAttempts {
			break
		}

		// An explicit timer (rather than time.After) lets us release it
		// immediately when the context ends first.
		timer := time.NewTimer(delay)
		select {
		case <-ctx.Done():
			timer.Stop()
			return fmt.Errorf("pod annotation retry interrupted: %w (last error: %v)", ctx.Err(), lastErr)
		case <-timer.C:
		}
		delay *= 2
	}

	return fmt.Errorf("failed to annotate pod after %d attempts: %w", terminalStatusPatchRetryAttempts, lastErr)
}
func waitForPodReady(ctx context.Context, clientset kubernetes.Interface, namespace, podName, containerName string) error {
lastPhase := ""
......
......@@ -109,7 +109,7 @@ func (w *Watcher) Start(ctx context.Context) error {
)
ckptInformer := ckptFactory.Core().V1().Pods().Informer()
ckptInformer.AddEventHandler(cache.ResourceEventHandlerFuncs{
if _, err := ckptInformer.AddEventHandler(cache.ResourceEventHandlerFuncs{
AddFunc: func(obj interface{}) {
pod, ok := podFromInformerObj(obj)
if !ok {
......@@ -124,7 +124,9 @@ func (w *Watcher) Start(ctx context.Context) error {
}
w.handleCheckpointPodEvent(ctx, pod)
},
})
}); err != nil {
return fmt.Errorf("failed to add checkpoint informer handler: %w", err)
}
go ckptFactory.Start(w.stopCh)
syncFuncs = append(syncFuncs, ckptInformer.HasSynced)
......@@ -144,7 +146,7 @@ func (w *Watcher) Start(ctx context.Context) error {
)
restoreInformer := restoreFactory.Core().V1().Pods().Informer()
restoreInformer.AddEventHandler(cache.ResourceEventHandlerFuncs{
if _, err := restoreInformer.AddEventHandler(cache.ResourceEventHandlerFuncs{
AddFunc: func(obj interface{}) {
pod, ok := podFromInformerObj(obj)
if !ok {
......@@ -159,7 +161,9 @@ func (w *Watcher) Start(ctx context.Context) error {
}
w.handleRestorePodEvent(ctx, pod)
},
})
}); err != nil {
return fmt.Errorf("failed to add restore informer handler: %w", err)
}
go restoreFactory.Start(w.stopCh)
syncFuncs = append(syncFuncs, restoreInformer.HasSynced)
......@@ -201,10 +205,15 @@ func (w *Watcher) handleCheckpointPodEvent(ctx context.Context, pod *corev1.Pod)
w.log.Info("Pod ready, triggering checkpoint", "pod", podKey, "checkpoint_hash", checkpointHash)
emitPodEvent(ctx, w.clientset, w.log, pod, "snapshot", corev1.EventTypeNormal, "CheckpointRequested", fmt.Sprintf("Checkpoint requested: %s", checkpointHash))
go w.doCheckpoint(ctx, pod, checkpointHash, podKey)
go func() {
if err := w.doCheckpoint(ctx, pod, checkpointHash, podKey); err != nil {
opLog := w.log.WithValues("pod", podKey, "checkpoint_hash", checkpointHash)
opLog.Error(err, "Checkpoint worker failed")
emitPodEvent(ctx, w.clientset, opLog, pod, "snapshot", corev1.EventTypeWarning, "CheckpointWorkerFailed", err.Error())
}
}()
}
func (w *Watcher) handleRestorePodEvent(ctx context.Context, pod *corev1.Pod) {
if pod.Spec.NodeName != w.config.NodeName {
return
......@@ -251,7 +260,13 @@ func (w *Watcher) handleRestorePodEvent(ctx context.Context, pod *corev1.Pod) {
w.log.Info("Restore pod running, triggering external restore", "pod", podKey, "checkpoint_hash", checkpointHash)
emitPodEvent(ctx, w.clientset, w.log, pod, "snapshot", corev1.EventTypeNormal, "RestoreRequested", fmt.Sprintf("Restore requested from checkpoint %s", checkpointHash))
go w.doRestore(ctx, pod, checkpointHash, podKey)
go func() {
if err := w.doRestore(ctx, pod, checkpointHash, podKey); err != nil {
opLog := w.log.WithValues("pod", podKey, "checkpoint_hash", checkpointHash)
opLog.Error(err, "Restore worker failed")
emitPodEvent(ctx, w.clientset, opLog, pod, "snapshot", corev1.EventTypeWarning, "RestoreWorkerFailed", err.Error())
}
}()
}
// doCheckpoint runs the full checkpoint workflow for a pod:
......@@ -260,15 +275,37 @@ func (w *Watcher) handleRestorePodEvent(ctx context.Context, pod *corev1.Pod) {
// 3. Call orchestrate.Checkpoint (inspect → configure → CUDA lock/checkpoint → CRIU dump → rootfs diff)
// 4. SIGUSR1 the process on success (notify workload), SIGKILL on failure (terminate immediately)
// 5. Mark pod as completed or failed
func (w *Watcher) doCheckpoint(ctx context.Context, pod *corev1.Pod, checkpointHash, podKey string) {
defer w.release(podKey)
func (w *Watcher) doCheckpoint(ctx context.Context, pod *corev1.Pod, checkpointHash, podKey string) error {
releaseOnExit := true
defer func() {
if releaseOnExit {
w.release(podKey)
}
}()
log := w.log.WithValues("pod", podKey, "checkpoint_hash", checkpointHash)
setCheckpointStatus := func(value string) error {
annotations := map[string]string{
kubeAnnotationCheckpointStatus: value,
}
if value == "failed" || value == "completed" {
if err := annotatePodRetry(ctx, w.clientset, log, pod, annotations); err != nil {
releaseOnExit = false
return fmt.Errorf("failed to persist terminal checkpoint status %q: %w", value, err)
}
return nil
}
if err := annotatePod(ctx, w.clientset, log, pod, annotations); err != nil {
return fmt.Errorf("failed to update checkpoint status %q: %w", value, err)
}
return nil
}
if err := annotatePod(ctx, w.clientset, log, pod, map[string]string{
kubeAnnotationCheckpointStatus: "in_progress",
}); err != nil {
log.Error(err, "Failed to annotate pod with checkpoint in_progress")
return
return fmt.Errorf("failed to annotate pod with checkpoint in_progress: %w", err)
}
// Resolve the target container
......@@ -277,8 +314,10 @@ func (w *Watcher) doCheckpoint(ctx context.Context, pod *corev1.Pod, checkpointH
err := fmt.Errorf("no containers found in pod spec")
log.Error(err, "Checkpoint failed")
emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeWarning, "CheckpointFailed", err.Error())
annotatePod(ctx, w.clientset, log, pod, map[string]string{kubeAnnotationCheckpointStatus: "failed"})
return
if statusErr := setCheckpointStatus("failed"); statusErr != nil {
return statusErr
}
return nil
}
var containerID string
for _, cs := range pod.Status.ContainerStatuses {
......@@ -289,8 +328,10 @@ func (w *Watcher) doCheckpoint(ctx context.Context, pod *corev1.Pod, checkpointH
}
if containerID == "" {
emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeWarning, "CheckpointFailed", "Could not resolve target container ID")
annotatePod(ctx, w.clientset, log, pod, map[string]string{kubeAnnotationCheckpointStatus: "failed"})
return
if statusErr := setCheckpointStatus("failed"); statusErr != nil {
return statusErr
}
return nil
}
// Resolve the container's host PID (needed for signaling after checkpoint)
......@@ -298,8 +339,10 @@ func (w *Watcher) doCheckpoint(ctx context.Context, pod *corev1.Pod, checkpointH
if err != nil {
log.Error(err, "Failed to resolve container")
emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeWarning, "CheckpointFailed", fmt.Sprintf("Container resolve failed: %v", err))
annotatePod(ctx, w.clientset, log, pod, map[string]string{kubeAnnotationCheckpointStatus: "failed"})
return
if statusErr := setCheckpointStatus("failed"); statusErr != nil {
return statusErr
}
return nil
}
// Step 1: Run the checkpoint orchestrator
......@@ -319,8 +362,10 @@ func (w *Watcher) doCheckpoint(ctx context.Context, pod *corev1.Pod, checkpointH
if signalErr := common.SendSignalToPID(log, containerPID, syscall.SIGKILL, "checkpoint failed"); signalErr != nil {
log.Error(signalErr, "Failed to signal checkpoint failure to runtime process")
}
annotatePod(ctx, w.clientset, log, pod, map[string]string{kubeAnnotationCheckpointStatus: "failed"})
return
if statusErr := setCheckpointStatus("failed"); statusErr != nil {
return statusErr
}
return nil
}
// Step 2: SIGUSR1 on success: notify the workload that checkpoint completed
......@@ -328,11 +373,16 @@ func (w *Watcher) doCheckpoint(ctx context.Context, pod *corev1.Pod, checkpointH
if err := common.SendSignalToPID(log, containerPID, syscall.SIGUSR1, "checkpoint complete"); err != nil {
log.Error(err, "Failed to signal checkpoint completion to runtime process")
emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeWarning, "CheckpointFailed", err.Error())
annotatePod(ctx, w.clientset, log, pod, map[string]string{kubeAnnotationCheckpointStatus: "failed"})
return
if statusErr := setCheckpointStatus("failed"); statusErr != nil {
return statusErr
}
return nil
}
annotatePod(ctx, w.clientset, log, pod, map[string]string{kubeAnnotationCheckpointStatus: "completed"})
if err := setCheckpointStatus("completed"); err != nil {
return err
}
return nil
}
// doRestore runs the full restore workflow for a pod:
......@@ -341,15 +391,37 @@ func (w *Watcher) doCheckpoint(ctx context.Context, pod *corev1.Pod, checkpointH
// 3. SIGCONT the restored process to wake it up
// 4. Wait for the pod to become Ready
// 5. Mark pod as completed or failed
func (w *Watcher) doRestore(ctx context.Context, pod *corev1.Pod, checkpointHash, podKey string) {
defer w.release(podKey)
func (w *Watcher) doRestore(ctx context.Context, pod *corev1.Pod, checkpointHash, podKey string) error {
releaseOnExit := true
defer func() {
if releaseOnExit {
w.release(podKey)
}
}()
log := w.log.WithValues("pod", podKey, "checkpoint_hash", checkpointHash)
setRestoreStatus := func(value string) error {
annotations := map[string]string{
kubeAnnotationRestoreStatus: value,
}
if value == "failed" || value == "completed" {
if err := annotatePodRetry(ctx, w.clientset, log, pod, annotations); err != nil {
releaseOnExit = false
return fmt.Errorf("failed to persist terminal restore status %q: %w", value, err)
}
return nil
}
if err := annotatePod(ctx, w.clientset, log, pod, annotations); err != nil {
return fmt.Errorf("failed to update restore status %q: %w", value, err)
}
return nil
}
if err := annotatePod(ctx, w.clientset, log, pod, map[string]string{
kubeAnnotationRestoreStatus: "in_progress",
}); err != nil {
log.Error(err, "Failed to annotate pod with restore in_progress")
return
return fmt.Errorf("failed to annotate pod with restore in_progress: %w", err)
}
containerName := resolveMainContainerName(pod)
......@@ -357,8 +429,10 @@ func (w *Watcher) doRestore(ctx context.Context, pod *corev1.Pod, checkpointHash
err := fmt.Errorf("no containers found in pod spec")
log.Error(err, "Restore failed")
emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeWarning, "RestoreFailed", err.Error())
annotatePod(ctx, w.clientset, log, pod, map[string]string{kubeAnnotationRestoreStatus: "failed"})
return
if statusErr := setRestoreStatus("failed"); statusErr != nil {
return statusErr
}
return nil
}
// Step 1: Run the restore orchestrator (inspect + nsrestore)
......@@ -374,8 +448,10 @@ func (w *Watcher) doRestore(ctx context.Context, pod *corev1.Pod, checkpointHash
if err != nil {
log.Error(err, "External restore failed")
emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeWarning, "RestoreFailed", err.Error())
annotatePod(ctx, w.clientset, log, pod, map[string]string{kubeAnnotationRestoreStatus: "failed"})
return
if statusErr := setRestoreStatus("failed"); statusErr != nil {
return statusErr
}
return nil
}
// Step 2: SIGCONT the restored process via PID namespace
......@@ -383,14 +459,18 @@ func (w *Watcher) doRestore(ctx context.Context, pod *corev1.Pod, checkpointHash
if err != nil {
log.Error(err, "Failed to resolve placeholder host PID for signaling")
emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeWarning, "RestoreFailed", err.Error())
annotatePod(ctx, w.clientset, log, pod, map[string]string{kubeAnnotationRestoreStatus: "failed"})
return
if statusErr := setRestoreStatus("failed"); statusErr != nil {
return statusErr
}
return nil
}
if err := common.SendSignalViaPIDNamespace(ctx, log, placeholderHostPID, restoredPID, syscall.SIGCONT, "restore complete"); err != nil {
log.Error(err, "Failed to signal restored runtime process")
emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeWarning, "RestoreFailed", err.Error())
annotatePod(ctx, w.clientset, log, pod, map[string]string{kubeAnnotationRestoreStatus: "failed"})
return
if statusErr := setRestoreStatus("failed"); statusErr != nil {
return statusErr
}
return nil
}
// Step 3: Wait for the pod to become Ready
......@@ -403,12 +483,17 @@ func (w *Watcher) doRestore(ctx context.Context, pod *corev1.Pod, checkpointHash
if err := waitForPodReady(readyCtx, w.clientset, pod.Namespace, pod.Name, containerName); err != nil {
log.Error(err, "Restore post-signal readiness check failed")
emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeWarning, "RestoreFailed", err.Error())
annotatePod(ctx, w.clientset, log, pod, map[string]string{kubeAnnotationRestoreStatus: "failed"})
return
if statusErr := setRestoreStatus("failed"); statusErr != nil {
return statusErr
}
return nil
}
emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeNormal, "RestoreSucceeded", fmt.Sprintf("Restore completed from checkpoint %s", checkpointHash))
annotatePod(ctx, w.clientset, log, pod, map[string]string{kubeAnnotationRestoreStatus: "completed"})
if err := setRestoreStatus("completed"); err != nil {
return err
}
return nil
}
func (w *Watcher) tryAcquire(podKey string) bool {
......@@ -426,6 +511,3 @@ func (w *Watcher) release(podKey string) {
defer w.inFlightMu.Unlock()
delete(w.inFlight, podKey)
}
......@@ -2,6 +2,7 @@ package watcher
import (
"context"
"errors"
"os"
"path/filepath"
"testing"
......@@ -10,7 +11,9 @@ import (
"github.com/go-logr/logr/testr"
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/client-go/kubernetes/fake"
clientgotesting "k8s.io/client-go/testing"
"github.com/ai-dynamo/dynamo/deploy/snapshot/pkg/types"
)
......@@ -27,7 +30,7 @@ func makeTestWatcher(t *testing.T) *Watcher {
NodeName: testNodeName,
BasePath: t.TempDir(),
},
clientset: fake.NewSimpleClientset(),
clientset: fake.NewClientset(),
log: testr.New(t),
inFlight: make(map[string]struct{}),
stopCh: make(chan struct{}),
......@@ -354,3 +357,92 @@ func TestHandleRestorePodEvent(t *testing.T) {
})
}
}
// TestDoCheckpointKeepsInFlightOnTerminalStatusPatchFailure verifies that
// doCheckpoint returns an error and leaves the pod key in the in-flight set
// when every patch after the first one is rejected by the fake clientset.
func TestDoCheckpointKeepsInFlightOnTerminalStatusPatchFailure(t *testing.T) {
	const podKey = "default/test-pod"

	pod := &corev1.Pod{
		ObjectMeta: metav1.ObjectMeta{Name: "test-pod", Namespace: "default"},
	}

	// The reactor reports the first patch as unhandled so it falls through to
	// the fake clientset's default handling; every later patch is failed here.
	var patchCalls int
	rejectAfterFirst := func(clientgotesting.Action) (bool, runtime.Object, error) {
		patchCalls++
		if patchCalls > 1 {
			return true, nil, errors.New("terminal patch failed")
		}
		return false, nil, nil
	}
	clientset := fake.NewClientset(pod.DeepCopy())
	clientset.PrependReactor("patch", "pods", rejectAfterFirst)

	w := &Watcher{
		config: &types.AgentConfig{
			NodeName: testNodeName,
			BasePath: t.TempDir(),
		},
		clientset: clientset,
		log:       testr.New(t),
		inFlight:  map[string]struct{}{podKey: {}},
		stopCh:    make(chan struct{}),
	}

	if err := w.doCheckpoint(context.Background(), pod, "abc123", podKey); err == nil {
		t.Fatal("expected terminal checkpoint status update to fail")
	}
	if _, stillInFlight := w.inFlight[podKey]; !stillInFlight {
		t.Fatal("checkpoint terminal status failure should keep pod in-flight")
	}

	// One passing patch plus the full retry budget for the terminal status.
	wantCalls := 1 + terminalStatusPatchRetryAttempts
	if patchCalls != wantCalls {
		t.Fatalf("patchCalls = %d, want %d", patchCalls, wantCalls)
	}
}
// TestDoRestoreKeepsInFlightOnTerminalStatusPatchFailure verifies that
// doRestore returns an error and leaves the pod key in the in-flight set when
// every patch after the first one is rejected by the fake clientset.
func TestDoRestoreKeepsInFlightOnTerminalStatusPatchFailure(t *testing.T) {
	const podKey = "default/test-pod"

	pod := &corev1.Pod{
		ObjectMeta: metav1.ObjectMeta{Name: "test-pod", Namespace: "default"},
		Status:     corev1.PodStatus{Phase: corev1.PodRunning},
	}

	// The reactor reports the first patch as unhandled so it falls through to
	// the fake clientset's default handling; every later patch is failed here.
	var patchCalls int
	rejectAfterFirst := func(clientgotesting.Action) (bool, runtime.Object, error) {
		patchCalls++
		if patchCalls > 1 {
			return true, nil, errors.New("terminal patch failed")
		}
		return false, nil, nil
	}
	clientset := fake.NewClientset(pod.DeepCopy())
	clientset.PrependReactor("patch", "pods", rejectAfterFirst)

	w := &Watcher{
		config: &types.AgentConfig{
			NodeName: testNodeName,
			BasePath: t.TempDir(),
		},
		clientset: clientset,
		log:       testr.New(t),
		inFlight:  map[string]struct{}{podKey: {}},
		stopCh:    make(chan struct{}),
	}

	if err := w.doRestore(context.Background(), pod, "abc123", podKey); err == nil {
		t.Fatal("expected terminal restore status update to fail")
	}
	if _, stillInFlight := w.inFlight[podKey]; !stillInFlight {
		t.Fatal("restore terminal status failure should keep pod in-flight")
	}

	// One passing patch plus the full retry budget for the terminal status.
	wantCalls := 1 + terminalStatusPatchRetryAttempts
	if patchCalls != wantCalls {
		t.Fatalf("patchCalls = %d, want %d", patchCalls, wantCalls)
	}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment