Unverified Commit 43e810a4 authored by Schwinn Saereesitthipitak's avatar Schwinn Saereesitthipitak Committed by GitHub
Browse files

refactor(snapshot): add manifest-based snapshotctl flow and shared workload builders (#7671)


Signed-off-by: default avatarSchwinn Saereesitthipitak <schwinns@nvidia.com>
parent 23144df5
......@@ -25,7 +25,9 @@ import (
configv1alpha1 "github.com/ai-dynamo/dynamo/deploy/operator/api/config/v1alpha1"
nvidiacomv1alpha1 "github.com/ai-dynamo/dynamo/deploy/operator/api/v1alpha1"
"github.com/ai-dynamo/dynamo/deploy/operator/internal/checkpoint"
"github.com/ai-dynamo/dynamo/deploy/operator/internal/checkpointjob"
"github.com/ai-dynamo/dynamo/deploy/operator/internal/consts"
snapshotprotocol "github.com/ai-dynamo/dynamo/deploy/snapshot/protocol"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
batchv1 "k8s.io/api/batch/v1"
......@@ -58,7 +60,9 @@ var testHash = func() string {
return hash
}()
var defaultCheckpointJobName = "checkpoint-job-" + testHash + "-" + consts.DefaultCheckpointArtifactVersion
var defaultCheckpointJobName = checkpointjob.DesiredCheckpointJobName(testHash, map[string]string{
snapshotprotocol.CheckpointArtifactVersionAnnotation: snapshotprotocol.DefaultCheckpointArtifactVersion,
})
func checkpointTestScheme() *runtime.Scheme {
s := runtime.NewScheme()
......@@ -74,13 +78,6 @@ func checkpointTestConfig() *configv1alpha1.OperatorConfiguration {
Checkpoint: configv1alpha1.CheckpointConfiguration{
Enabled: true,
ReadyForCheckpointFilePath: "/tmp/ready-for-checkpoint",
Storage: configv1alpha1.CheckpointStorageConfiguration{
Type: configv1alpha1.CheckpointStorageTypePVC,
PVC: configv1alpha1.CheckpointPVCConfig{
PVCName: "snapshot-pvc",
BasePath: "/checkpoints",
},
},
},
}
}
......@@ -143,14 +140,15 @@ func TestBuildCheckpointJob(t *testing.T) {
}
r := makeCheckpointReconciler(s, ckpt)
job := r.buildCheckpointJob(ckpt, defaultCheckpointJobName)
job, err := checkpointjob.BuildCheckpointJob(r.Config, ckpt, defaultCheckpointJobName)
require.NoError(t, err)
podSpec := job.Spec.Template.Spec
main := podSpec.Containers[0]
// Job and pod template labels
assert.Equal(t, testHash, job.Labels[consts.KubeLabelCheckpointHash])
assert.Equal(t, testHash, job.Labels[consts.KubeLabelCheckpointID])
assert.Equal(t, "true", job.Spec.Template.Labels[consts.KubeLabelIsCheckpointSource])
assert.Equal(t, testHash, job.Spec.Template.Labels[consts.KubeLabelCheckpointHash])
assert.Equal(t, testHash, job.Spec.Template.Labels[consts.KubeLabelCheckpointID])
// Env vars (checkpoint-specific + user-provided preserved)
envMap := make(map[string]string, len(main.Env))
......@@ -234,25 +232,26 @@ func TestBuildCheckpointJob(t *testing.T) {
assert.Equal(t, int32(0), *job.Spec.BackoffLimit)
assert.Equal(t, int32(300), *job.Spec.TTLSecondsAfterFinished)
// Custom deadlines override defaults, but checkpoint jobs never retry.
// Custom active deadlines override defaults, but checkpoint jobs never retry and keep a fixed TTL.
deadline := int64(7200)
backoff := int32(5)
ttl := int32(600)
ckpt.Spec.Job.ActiveDeadlineSeconds = &deadline
ckpt.Spec.Job.BackoffLimit = &backoff //nolint:staticcheck // Compatibility test: deprecated field must remain ignored by checkpoint Jobs.
ckpt.Spec.Job.TTLSecondsAfterFinished = &ttl
job = r.buildCheckpointJob(ckpt, defaultCheckpointJobName)
job, err = checkpointjob.BuildCheckpointJob(r.Config, ckpt, defaultCheckpointJobName)
require.NoError(t, err)
assert.Equal(t, int64(7200), *job.Spec.ActiveDeadlineSeconds)
assert.Equal(t, int32(0), *job.Spec.BackoffLimit)
assert.Equal(t, int32(600), *job.Spec.TTLSecondsAfterFinished)
assert.Equal(t, int32(300), *job.Spec.TTLSecondsAfterFinished)
ckpt.Spec.Job.PodTemplateSpec.Spec.Containers[0].Resources = corev1.ResourceRequirements{
Limits: corev1.ResourceList{
corev1.ResourceName("nvidia.com/gpu"): resource.MustParse("2"),
},
}
job = r.buildCheckpointJob(ckpt, defaultCheckpointJobName)
assert.Equal(t, []string{"cuda-checkpoint", "--launch-job", "python3", "-m", "dynamo.vllm"}, job.Spec.Template.Spec.Containers[0].Command)
job, err = checkpointjob.BuildCheckpointJob(r.Config, ckpt, defaultCheckpointJobName)
require.NoError(t, err)
assert.Equal(t, []string{"cuda-checkpoint"}, job.Spec.Template.Spec.Containers[0].Command)
assert.Equal(t, []string{"--launch-job", "python3", "-m", "dynamo.vllm"}, job.Spec.Template.Spec.Containers[0].Args)
}
func TestBuildCheckpointJobInjectsStandardEnvVars(t *testing.T) {
......@@ -274,7 +273,8 @@ func TestBuildCheckpointJobInjectsStandardEnvVars(t *testing.T) {
customShmSize := resource.MustParse("16Gi")
ckpt.Spec.Job.SharedMemory = &nvidiacomv1alpha1.SharedMemorySpec{Size: customShmSize}
job := r.buildCheckpointJob(ckpt, defaultCheckpointJobName)
job, err := checkpointjob.BuildCheckpointJob(r.Config, ckpt, defaultCheckpointJobName)
require.NoError(t, err)
foundCustomShmVolume := false
for _, v := range job.Spec.Template.Spec.Volumes {
if v.Name == consts.KubeValueNameSharedMemory {
......@@ -326,7 +326,7 @@ func TestCheckpointReconciler_Reconcile(t *testing.T) {
assert.Equal(t, nvidiacomv1alpha1.DynamoCheckpointPhasePending, updated.Status.Phase)
assert.Equal(t, testHash, updated.Status.IdentityHash)
assert.Empty(t, updated.Status.Message)
assert.Equal(t, testHash, updated.Labels[consts.KubeLabelCheckpointHash])
assert.Equal(t, testHash, updated.Labels[consts.KubeLabelCheckpointID])
})
t.Run("Ready phase is a no-op", func(t *testing.T) {
......@@ -352,7 +352,7 @@ func TestCheckpointReconciler_Reconcile(t *testing.T) {
updated := &nvidiacomv1alpha1.DynamoCheckpoint{}
require.NoError(t, r.Get(ctx, types.NamespacedName{Name: friendlyCheckpointName, Namespace: testNamespace}, updated))
assert.Equal(t, testHash, updated.Labels[consts.KubeLabelCheckpointHash])
assert.Equal(t, testHash, updated.Labels[consts.KubeLabelCheckpointID])
assert.Equal(t, testHash, updated.Status.IdentityHash)
})
......@@ -375,7 +375,6 @@ func TestCheckpointReconciler_Reconcile(t *testing.T) {
ckpt := makeTestCheckpoint(nvidiacomv1alpha1.DynamoCheckpointPhaseReady)
ckpt.Status.IdentityHash = testHash
ckpt.Status.JobName = defaultCheckpointJobName
ckpt.Status.Location = "/checkpoints/" + testHash + "/versions/" + consts.DefaultCheckpointArtifactVersion
ckpt.Annotations = map[string]string{consts.KubeAnnotationCheckpointArtifactVersion: "2"}
r := makeCheckpointReconciler(s, ckpt)
......@@ -388,7 +387,6 @@ func TestCheckpointReconciler_Reconcile(t *testing.T) {
require.NoError(t, r.Get(ctx, types.NamespacedName{Name: ckpt.Name, Namespace: testNamespace}, updated))
assert.Equal(t, nvidiacomv1alpha1.DynamoCheckpointPhaseCreating, updated.Status.Phase)
assert.Equal(t, "checkpoint-job-"+testHash+"-2", updated.Status.JobName)
assert.Equal(t, "/checkpoints/"+testHash+"/versions/2", updated.Status.Location)
})
t.Run("duplicate identity hash is rejected even with a readable name", func(t *testing.T) {
......@@ -431,13 +429,11 @@ func TestCheckpointReconciler_HandleCreating(t *testing.T) {
t.Run("succeeded job transitions to Ready", func(t *testing.T) {
ckpt := makeCreatingCkpt(testHash, defaultCheckpointJobName)
ckpt.Status.Location = "/checkpoints/" + testHash + "/versions/" + consts.DefaultCheckpointArtifactVersion
ckpt.Status.StorageType = "pvc"
job := &batchv1.Job{
ObjectMeta: metav1.ObjectMeta{
Name: defaultCheckpointJobName,
Namespace: testNamespace,
Annotations: map[string]string{checkpointStatusAnnotation: checkpointStatusCompleted},
Annotations: map[string]string{snapshotprotocol.CheckpointStatusAnnotation: snapshotprotocol.CheckpointStatusCompleted},
},
Status: batchv1.JobStatus{
Succeeded: 1,
......@@ -454,8 +450,6 @@ func TestCheckpointReconciler_HandleCreating(t *testing.T) {
updated := &nvidiacomv1alpha1.DynamoCheckpoint{}
require.NoError(t, r.Get(ctx, types.NamespacedName{Name: testHash, Namespace: testNamespace}, updated))
assert.Equal(t, nvidiacomv1alpha1.DynamoCheckpointPhaseReady, updated.Status.Phase)
assert.Equal(t, "/checkpoints/"+testHash+"/versions/"+consts.DefaultCheckpointArtifactVersion, updated.Status.Location)
assert.Equal(t, nvidiacomv1alpha1.DynamoCheckpointStorageType("pvc"), updated.Status.StorageType)
assert.NotNil(t, updated.Status.CreatedAt)
})
......@@ -531,7 +525,7 @@ func TestCheckpointReconciler_HandleCreating(t *testing.T) {
ObjectMeta: metav1.ObjectMeta{
Name: "job-agent-failed",
Namespace: testNamespace,
Annotations: map[string]string{checkpointStatusAnnotation: checkpointStatusFailed},
Annotations: map[string]string{snapshotprotocol.CheckpointStatusAnnotation: snapshotprotocol.CheckpointStatusFailed},
},
Status: batchv1.JobStatus{
Succeeded: 1,
......@@ -557,7 +551,7 @@ func TestCheckpointReconciler_HandleCreating(t *testing.T) {
ObjectMeta: metav1.ObjectMeta{
Name: "job-running-agent-failed",
Namespace: testNamespace,
Annotations: map[string]string{checkpointStatusAnnotation: checkpointStatusFailed},
Annotations: map[string]string{snapshotprotocol.CheckpointStatusAnnotation: snapshotprotocol.CheckpointStatusFailed},
},
Status: batchv1.JobStatus{Active: 1},
}
......@@ -590,14 +584,12 @@ func TestCheckpointReconciler_HandleCreating(t *testing.T) {
t.Run("in-flight version changes do not relabel the running job's artifact", func(t *testing.T) {
ckpt := makeCreatingCkpt(testHash, defaultCheckpointJobName)
ckpt.Status.Location = "/checkpoints/" + testHash + "/versions/" + consts.DefaultCheckpointArtifactVersion
ckpt.Status.StorageType = "pvc"
ckpt.Annotations = map[string]string{consts.KubeAnnotationCheckpointArtifactVersion: "2"}
job := &batchv1.Job{
ObjectMeta: metav1.ObjectMeta{
Name: defaultCheckpointJobName,
Namespace: testNamespace,
Annotations: map[string]string{checkpointStatusAnnotation: checkpointStatusCompleted},
Annotations: map[string]string{snapshotprotocol.CheckpointStatusAnnotation: snapshotprotocol.CheckpointStatusCompleted},
},
Status: batchv1.JobStatus{
Succeeded: 1,
......@@ -614,7 +606,6 @@ func TestCheckpointReconciler_HandleCreating(t *testing.T) {
updated := &nvidiacomv1alpha1.DynamoCheckpoint{}
require.NoError(t, r.Get(ctx, types.NamespacedName{Name: testHash, Namespace: testNamespace}, updated))
assert.Equal(t, nvidiacomv1alpha1.DynamoCheckpointPhaseReady, updated.Status.Phase)
assert.Equal(t, "/checkpoints/"+testHash+"/versions/"+consts.DefaultCheckpointArtifactVersion, updated.Status.Location)
})
t.Run("succeeded count without complete condition keeps Creating phase", func(t *testing.T) {
......
......@@ -80,6 +80,7 @@ type DynamoComponentDeploymentReconciler struct {
// +kubebuilder:rbac:groups=nvidia.com,resources=dynamocomponentdeployments/status,verbs=get;update;patch
// +kubebuilder:rbac:groups=nvidia.com,resources=dynamocomponentdeployments/finalizers,verbs=update
// +kubebuilder:rbac:groups=nvidia.com,resources=dynamocheckpoints,verbs=get;list
// +kubebuilder:rbac:groups=apps,resources=daemonsets,verbs=get;list;watch
//+kubebuilder:rbac:groups=apps,resources=deployments,verbs=get;list;watch;create;update;patch;delete
//+kubebuilder:rbac:groups=core,resources=pods,verbs=get;list;watch
......@@ -1041,7 +1042,9 @@ func (r *DynamoComponentDeploymentReconciler) generatePodTemplateSpec(ctx contex
// Resolve checkpoint for this component
var checkpointInfo *checkpoint.CheckpointInfo
if opt.dynamoComponentDeployment.Spec.Checkpoint != nil && opt.dynamoComponentDeployment.Spec.Checkpoint.Enabled {
if r.Config.Checkpoint.Enabled &&
opt.dynamoComponentDeployment.Spec.Checkpoint != nil &&
opt.dynamoComponentDeployment.Spec.Checkpoint.Enabled {
info, err := checkpoint.ResolveCheckpointForService(ctx, r.Client, opt.dynamoComponentDeployment.Namespace, opt.dynamoComponentDeployment.Spec.Checkpoint)
if err != nil {
return nil, errors.Wrap(err, "failed to resolve checkpoint")
......@@ -1054,6 +1057,17 @@ func (r *DynamoComponentDeploymentReconciler) generatePodTemplateSpec(ctx contex
err = errors.Wrap(err, "failed to generate base pod spec")
return nil, err
}
if r.Config.Checkpoint.Enabled {
if err := checkpoint.InjectCheckpointIntoPodSpec(
ctx,
r.Client,
opt.dynamoComponentDeployment.Namespace,
podSpec,
checkpointInfo,
); err != nil {
return nil, errors.Wrap(err, "failed to inject checkpoint config")
}
}
// Ensure we have at least one container (the main container should be there from GenerateBasePodSpec)
if len(podSpec.Containers) == 0 {
......
......@@ -29,6 +29,7 @@ import (
commonconsts "github.com/ai-dynamo/dynamo/deploy/operator/internal/consts"
"github.com/ai-dynamo/dynamo/deploy/operator/internal/controller_common"
"github.com/ai-dynamo/dynamo/deploy/operator/internal/dynamo"
snapshotprotocol "github.com/ai-dynamo/dynamo/deploy/snapshot/protocol"
"github.com/google/go-cmp/cmp"
"github.com/onsi/gomega"
"github.com/onsi/gomega/format"
......@@ -1255,6 +1256,40 @@ func TestDynamoComponentDeploymentReconciler_generatePodTemplateSpec_RestoreLabe
if err := corev1.AddToScheme(s); err != nil {
t.Fatalf("Failed to add corev1 to scheme: %v", err)
}
if err := appsv1.AddToScheme(s); err != nil {
t.Fatalf("Failed to add appsv1 to scheme: %v", err)
}
snapshotAgentDaemonSet := &appsv1.DaemonSet{
ObjectMeta: metav1.ObjectMeta{
Name: "snapshot-agent",
Namespace: "default",
Labels: map[string]string{
snapshotprotocol.SnapshotAgentLabelKey: snapshotprotocol.SnapshotAgentLabelValue,
},
},
Spec: appsv1.DaemonSetSpec{
Template: corev1.PodTemplateSpec{
Spec: corev1.PodSpec{
Containers: []corev1.Container{{
Name: snapshotprotocol.SnapshotAgentContainerName,
VolumeMounts: []corev1.VolumeMount{{
Name: "checkpoints",
MountPath: "/checkpoints",
}},
}},
Volumes: []corev1.Volume{{
Name: "checkpoints",
VolumeSource: corev1.VolumeSource{
PersistentVolumeClaim: &corev1.PersistentVolumeClaimVolumeSource{
ClaimName: "snapshot-pvc",
},
},
}},
},
},
},
}
makeDCD := func(checkpointRef string) *v1alpha1.DynamoComponentDeployment {
return &v1alpha1.DynamoComponentDeployment{
......@@ -1291,6 +1326,7 @@ func TestDynamoComponentDeploymentReconciler_generatePodTemplateSpec_RestoreLabe
}
makeReconciler := func(objs ...client.Object) *DynamoComponentDeploymentReconciler {
objs = append(objs, snapshotAgentDaemonSet.DeepCopy())
return &DynamoComponentDeploymentReconciler{
Client: fake.NewClientBuilder().
WithScheme(s).
......@@ -1299,13 +1335,6 @@ func TestDynamoComponentDeploymentReconciler_generatePodTemplateSpec_RestoreLabe
Config: &configv1alpha1.OperatorConfiguration{
Checkpoint: configv1alpha1.CheckpointConfiguration{
Enabled: true,
Storage: configv1alpha1.CheckpointStorageConfiguration{
Type: configv1alpha1.CheckpointStorageTypePVC,
PVC: configv1alpha1.CheckpointPVCConfig{
PVCName: "snapshot-pvc",
BasePath: "/checkpoints",
},
},
},
},
}
......@@ -1342,8 +1371,8 @@ func TestDynamoComponentDeploymentReconciler_generatePodTemplateSpec_RestoreLabe
if got := podTemplateSpec.Labels[commonconsts.KubeLabelIsRestoreTarget]; got != commonconsts.KubeLabelValueTrue {
t.Fatalf("expected %s label to be true, got %q", commonconsts.KubeLabelIsRestoreTarget, got)
}
if got := podTemplateSpec.Labels[commonconsts.KubeLabelCheckpointHash]; got != checkpointName {
t.Fatalf("expected %s to be checkpoint hash, got %q", commonconsts.KubeLabelCheckpointHash, got)
if got := podTemplateSpec.Labels[commonconsts.KubeLabelCheckpointID]; got != checkpointName {
t.Fatalf("expected %s to be checkpoint id, got %q", commonconsts.KubeLabelCheckpointID, got)
}
})
......@@ -1428,8 +1457,8 @@ func TestDynamoComponentDeploymentReconciler_generatePodTemplateSpec_RestoreLabe
if _, ok := podTemplateSpec.Labels[commonconsts.KubeLabelIsRestoreTarget]; ok {
t.Fatalf("did not expect %s label when checkpoint is not ready", commonconsts.KubeLabelIsRestoreTarget)
}
if _, ok := podTemplateSpec.Labels[commonconsts.KubeLabelCheckpointHash]; ok {
t.Fatalf("did not expect %s label when checkpoint is not ready", commonconsts.KubeLabelCheckpointHash)
if _, ok := podTemplateSpec.Labels[commonconsts.KubeLabelCheckpointID]; ok {
t.Fatalf("did not expect %s label when checkpoint is not ready", commonconsts.KubeLabelCheckpointID)
}
})
}
......@@ -1481,6 +1510,36 @@ func TestDynamoComponentDeploymentReconciler_generateDeployment_RestoreStrategy(
}
makeReconciler := func(objs ...client.Object) *DynamoComponentDeploymentReconciler {
objs = append(objs, &appsv1.DaemonSet{
ObjectMeta: metav1.ObjectMeta{
Name: "snapshot-agent",
Namespace: "default",
Labels: map[string]string{
snapshotprotocol.SnapshotAgentLabelKey: snapshotprotocol.SnapshotAgentLabelValue,
},
},
Spec: appsv1.DaemonSetSpec{
Template: corev1.PodTemplateSpec{
Spec: corev1.PodSpec{
Containers: []corev1.Container{{
Name: snapshotprotocol.SnapshotAgentContainerName,
VolumeMounts: []corev1.VolumeMount{{
Name: "checkpoints",
MountPath: "/checkpoints",
}},
}},
Volumes: []corev1.Volume{{
Name: "checkpoints",
VolumeSource: corev1.VolumeSource{
PersistentVolumeClaim: &corev1.PersistentVolumeClaimVolumeSource{
ClaimName: "snapshot-pvc",
},
},
}},
},
},
},
})
return &DynamoComponentDeploymentReconciler{
Client: fake.NewClientBuilder().
WithScheme(s).
......@@ -1489,13 +1548,6 @@ func TestDynamoComponentDeploymentReconciler_generateDeployment_RestoreStrategy(
Config: &configv1alpha1.OperatorConfiguration{
Checkpoint: configv1alpha1.CheckpointConfiguration{
Enabled: true,
Storage: configv1alpha1.CheckpointStorageConfiguration{
Type: configv1alpha1.CheckpointStorageTypePVC,
PVC: configv1alpha1.CheckpointPVCConfig{
PVCName: "snapshot-pvc",
BasePath: "/checkpoints",
},
},
},
},
}
......
......@@ -91,6 +91,7 @@ type DynamoGraphDeploymentReconciler struct {
// +kubebuilder:rbac:groups=scheduling.run.ai,resources=queues,verbs=get;list
// +kubebuilder:rbac:groups=inference.networking.k8s.io,resources=inferencepools,verbs=get;list;watch;create;update;patch;delete
// +kubebuilder:rbac:groups=core,resources=pods,verbs=get;list;watch
// +kubebuilder:rbac:groups=apps,resources=daemonsets,verbs=get;list;watch
// Reconcile is part of the main kubernetes reconciliation loop which aims to
// move the current state of the cluster closer to the desired state.
......@@ -547,7 +548,7 @@ func (r *DynamoGraphDeploymentReconciler) reconcileGrovePodCliqueSet(ctx context
}
// generate the dynamoComponentsDeployments from the config
grovePodCliqueSet, err := dynamo.GenerateGrovePodCliqueSet(ctx, dynamoDeployment, r.Config, r.RuntimeConfig, r.DockerSecretRetriever, restartState, existingRestartAnnotations, checkpointInfos)
grovePodCliqueSet, err := dynamo.GenerateGrovePodCliqueSet(ctx, dynamoDeployment, r.Config, r.RuntimeConfig, r.Client, r.DockerSecretRetriever, restartState, existingRestartAnnotations, checkpointInfos)
if err != nil {
logger.Error(err, "failed to generate the Grove GangSet")
return nil, fmt.Errorf("failed to generate the Grove GangSet: %w", err)
......
......@@ -43,6 +43,7 @@ import (
networkingv1beta1 "istio.io/client-go/pkg/apis/networking/v1beta1"
corev1 "k8s.io/api/core/v1"
networkingv1 "k8s.io/api/networking/v1"
ctrlclient "sigs.k8s.io/controller-runtime/pkg/client"
)
// RestartState holds the restart state for DGD services.
......@@ -1165,22 +1166,6 @@ func GenerateBasePodSpec(
backend.UpdatePodSpec(&podSpec, numberOfNodes, role, component, serviceName, multinodeDeployer)
// Inject checkpoint configuration if enabled
// This handles ALL checkpoint-related modifications:
// - Command/Args transformation (moves Command to Args to respect image ENTRYPOINT)
// - Security context (hostIPC, privileged mode)
// - Restore/checkpoint pod metadata (labels/annotations)
// - Storage configuration (volumes, mounts)
// CheckpointInfo should have been resolved by ResolveCheckpointForService before calling this function
// Checkpoint config comes from the operator's controller config (Helm values)
var checkpointConfig *configv1alpha1.CheckpointConfiguration
if operatorConfig.Checkpoint.Enabled {
checkpointConfig = &operatorConfig.Checkpoint
}
if err := checkpoint.InjectCheckpointIntoPodSpec(&podSpec, checkpointInfo, checkpointConfig); err != nil {
return nil, fmt.Errorf("failed to inject checkpoint config: %w", err)
}
// Inject auto-generated frontend sidecar if configured
if component.FrontendSidecar != nil {
sidecar, err := generateFrontendSidecar(component.FrontendSidecar, componentContext, operatorConfig)
......@@ -1359,6 +1344,7 @@ func GenerateGrovePodCliqueSet(
dynamoDeployment *v1alpha1.DynamoGraphDeployment,
operatorConfig *configv1alpha1.OperatorConfiguration,
runtimeConfig *controller_common.RuntimeConfig,
kubeClient ctrlclient.Reader,
secretsRetriever SecretsRetriever,
restartState *RestartState,
existingRestartAnnotations map[string]string,
......@@ -1438,6 +1424,17 @@ func GenerateGrovePodCliqueSet(
if err != nil {
return nil, fmt.Errorf("failed to generate podSpec for role %s: %w", r.Name, err)
}
if operatorConfig.Checkpoint.Enabled {
if err := checkpoint.InjectCheckpointIntoPodSpec(
ctx,
kubeClient,
dynamoDeployment.Namespace,
podSpec,
checkpointInfo,
); err != nil {
return nil, fmt.Errorf("failed to inject checkpoint config for role %s: %w", r.Name, err)
}
}
clique := &grovev1alpha1.PodCliqueTemplateSpec{
Name: strings.ToLower(r.Name),
......
......@@ -3787,7 +3787,7 @@ func TestGenerateGrovePodCliqueSet(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := GenerateGrovePodCliqueSet(tt.args.ctx, tt.args.dynamoDeployment, tt.args.controllerConfig, &controller_common.RuntimeConfig{}, nil, nil, nil, nil)
got, err := GenerateGrovePodCliqueSet(tt.args.ctx, tt.args.dynamoDeployment, tt.args.controllerConfig, &controller_common.RuntimeConfig{}, nil, nil, nil, nil, nil)
if (err != nil) != tt.wantErr {
t.Errorf("GenerateGrovePodCliqueSet() error = %v, wantErr %v", err, tt.wantErr)
return
......@@ -3848,7 +3848,7 @@ func Test_GeneratePodCliqueSetGlobalDynamoNamespace(t *testing.T) {
},
}
got, err := GenerateGrovePodCliqueSet(context.Background(), dynamoDeployment, &configv1alpha1.OperatorConfiguration{}, &controller_common.RuntimeConfig{}, nil, nil, nil, nil)
got, err := GenerateGrovePodCliqueSet(context.Background(), dynamoDeployment, &configv1alpha1.OperatorConfiguration{}, &controller_common.RuntimeConfig{}, nil, nil, nil, nil, nil)
if !assert.NoError(t, err) {
return
}
......@@ -4953,7 +4953,7 @@ func TestGenerateGrovePodCliqueSet_StartsAfterDependencies(t *testing.T) {
},
}
got, err := GenerateGrovePodCliqueSet(context.Background(), dynamoDeployment, controllerConfig, &controller_common.RuntimeConfig{}, secretsRetriever, nil, nil, nil)
got, err := GenerateGrovePodCliqueSet(context.Background(), dynamoDeployment, controllerConfig, &controller_common.RuntimeConfig{}, nil, secretsRetriever, nil, nil, nil)
if err != nil {
t.Errorf("GenerateGrovePodCliqueSet() error = %v", err)
return
......@@ -6783,7 +6783,7 @@ func TestGenerateGrovePodCliqueSet_RestartAnnotations(t *testing.T) {
},
}
got, err := GenerateGrovePodCliqueSet(context.Background(), dgd, controllerConfig, &controller_common.RuntimeConfig{}, nil, tt.restartState, nil, nil)
got, err := GenerateGrovePodCliqueSet(context.Background(), dgd, controllerConfig, &controller_common.RuntimeConfig{}, nil, nil, tt.restartState, nil, nil)
if err != nil {
t.Fatalf("GenerateGrovePodCliqueSet() error = %v", err)
}
......@@ -6865,7 +6865,7 @@ func TestGenerateLabels_RemovesStaleRestoreLabelsWhenCheckpointNotReady(t *testi
ExtraPodMetadata: &v1alpha1.ExtraPodMetadata{
Labels: map[string]string{
"extra-label": "keep-too",
commonconsts.KubeLabelCheckpointHash: "stale-hash",
commonconsts.KubeLabelCheckpointID: "stale-hash",
},
},
},
......@@ -6884,7 +6884,7 @@ func TestGenerateLabels_RemovesStaleRestoreLabelsWhenCheckpointNotReady(t *testi
assert.Equal(t, "keep", labels["user-label"])
assert.Equal(t, "keep-too", labels["extra-label"])
_, hasRestoreTarget := labels[commonconsts.KubeLabelIsRestoreTarget]
_, hasCheckpointHash := labels[commonconsts.KubeLabelCheckpointHash]
_, hasCheckpointHash := labels[commonconsts.KubeLabelCheckpointID]
assert.False(t, hasRestoreTarget)
assert.False(t, hasCheckpointHash)
}
......@@ -6899,7 +6899,7 @@ func TestGenerateLabels_OverwritesStaleRestoreLabelsWhenCheckpointReady(t *testi
},
ExtraPodMetadata: &v1alpha1.ExtraPodMetadata{
Labels: map[string]string{
commonconsts.KubeLabelCheckpointHash: "stale-hash",
commonconsts.KubeLabelCheckpointID: "stale-hash",
},
},
},
......@@ -6916,7 +6916,7 @@ func TestGenerateLabels_OverwritesStaleRestoreLabelsWhenCheckpointReady(t *testi
Hash: "resolved-hash",
})
assert.Equal(t, commonconsts.KubeLabelValueTrue, labels[commonconsts.KubeLabelIsRestoreTarget])
assert.Equal(t, "resolved-hash", labels[commonconsts.KubeLabelCheckpointHash])
assert.Equal(t, "resolved-hash", labels[commonconsts.KubeLabelCheckpointID])
}
func TestGenerateLabels_ReassertsRestoreIdentityLabelsAfterMetadataMerge(t *testing.T) {
......@@ -7459,7 +7459,7 @@ func TestGenerateGrovePodCliqueSet_SpecMetadataPropagation(t *testing.T) {
},
}
pcs, err := GenerateGrovePodCliqueSet(context.Background(), dgd, &configv1alpha1.OperatorConfiguration{}, &controller_common.RuntimeConfig{}, nil, nil, nil, nil)
pcs, err := GenerateGrovePodCliqueSet(context.Background(), dgd, &configv1alpha1.OperatorConfiguration{}, &controller_common.RuntimeConfig{}, nil, nil, nil, nil, nil)
require.NoError(t, err)
// PCS object-level metadata
......@@ -7627,6 +7627,7 @@ func TestGenerateGrovePodCliqueSet_TopologyConstraints(t *testing.T) {
tt.deployment,
operatorConfig,
&controller_common.RuntimeConfig{},
nil,
secretsRetriever,
&RestartState{},
nil,
......
......@@ -8,7 +8,7 @@ import (
"gopkg.in/yaml.v3"
"github.com/ai-dynamo/dynamo/deploy/snapshot/pkg/types"
"github.com/ai-dynamo/dynamo/deploy/snapshot/internal/types"
)
// ConfigMapPath is the default path where the ConfigMap is mounted.
......
......@@ -12,9 +12,9 @@ import (
"github.com/containerd/containerd"
"github.com/go-logr/logr"
"github.com/ai-dynamo/dynamo/deploy/snapshot/pkg/common"
"github.com/ai-dynamo/dynamo/deploy/snapshot/pkg/controller"
"github.com/ai-dynamo/dynamo/deploy/snapshot/pkg/logging"
"github.com/ai-dynamo/dynamo/deploy/snapshot/internal/controller"
"github.com/ai-dynamo/dynamo/deploy/snapshot/internal/logging"
snapshotruntime "github.com/ai-dynamo/dynamo/deploy/snapshot/internal/runtime"
)
func main() {
......@@ -29,7 +29,7 @@ func main() {
fatal(agentLog, err, "Invalid configuration")
}
ctrd, err := containerd.New(common.ContainerdSocket)
ctrd, err := containerd.New(snapshotruntime.ContainerdSocket)
if err != nil {
fatal(agentLog, err, "Failed to connect to containerd")
}
......
......@@ -8,8 +8,8 @@ import (
"github.com/go-logr/logr"
"github.com/ai-dynamo/dynamo/deploy/snapshot/pkg/executor"
"github.com/ai-dynamo/dynamo/deploy/snapshot/pkg/logging"
"github.com/ai-dynamo/dynamo/deploy/snapshot/internal/executor"
"github.com/ai-dynamo/dynamo/deploy/snapshot/internal/logging"
)
func main() {
......
# `snapshotctl`
`snapshotctl` is a lower-level snapshot utility for developers and operators.
It is not the primary Dynamo user workflow. The normal user-facing path is:
```text
DynamoCheckpoint CR -> operator -> snapshot-agent
```
Use `snapshotctl` when you want to exercise checkpoint or restore behavior
directly from a worker pod manifest without going through the operator.
## Requirements
- the snapshot Helm chart must already be installed in the target namespace
- a `snapshot-agent` DaemonSet must be running in that namespace
- the namespace must already have the checkpoint PVC mounted by the agent
## Manifest requirements
`snapshotctl checkpoint --manifest ...` and `snapshotctl restore --manifest ...`
accept a Kubernetes `Pod` manifest, not a Deployment or Job manifest.
That pod manifest must:
- describe the worker pod you want to checkpoint or restore
- contain exactly one worker container
- use the placeholder image for checkpoint-aware flows
- match the runtime-relevant worker settings you care about preserving
In practice, start from the real worker pod spec you would normally run, then
keep only the pod-level fields needed to recreate that worker accurately.
## Commands
Checkpoint from a manifest:
```bash
snapshotctl checkpoint \
--manifest ./worker-pod.yaml \
--namespace ${NAMESPACE}
```
If `--checkpoint-id` is omitted, `snapshotctl` generates one.
Restore by creating a new pod from a manifest:
```bash
snapshotctl restore \
--manifest ./worker-pod.yaml \
--namespace ${NAMESPACE} \
--checkpoint-id manual-snapshot-123
```
Restore an existing snapshot-compatible pod in place:
```bash
snapshotctl restore \
--pod existing-restore-target \
--namespace ${NAMESPACE} \
--checkpoint-id manual-snapshot-123
```
## Notes
- `restore --pod` expects a pod that is already compatible with snapshot restore
- `restore --manifest` creates a new restore target pod from the manifest you provide
- `snapshotctl` is useful for debugging and lower-level validation, but it does
not replace the operator-managed checkpoint flow
package main
import (
"context"
"fmt"
"strings"
"time"
batchv1 "k8s.io/api/batch/v1"
corev1 "k8s.io/api/core/v1"
apierrors "k8s.io/apimachinery/pkg/api/errors"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/util/wait"
"k8s.io/client-go/kubernetes"
snapshotprotocol "github.com/ai-dynamo/dynamo/deploy/snapshot/protocol"
)
const defaultGeneratedCheckpointIDPrefix = "manual-snapshot"
type checkpointOptions struct {
ManifestPath string
Namespace string
KubeContext string
CheckpointID string
DisableCudaCheckpointJobFile bool
Timeout time.Duration
}
type result struct {
Name string
Namespace string
CheckpointID string
CheckpointLocation string
CheckpointJob string
RestorePod string
Status string
}
func runCheckpointFlow(ctx context.Context, opts checkpointOptions) (*result, error) {
if strings.TrimSpace(opts.ManifestPath) == "" {
return nil, fmt.Errorf("missing required flags: --manifest")
}
if opts.Timeout <= 0 {
return nil, fmt.Errorf("--timeout must be greater than zero")
}
pod, clientset, namespace, storage, err := loadRunContext(ctx, opts.ManifestPath, opts.Namespace, opts.KubeContext)
if err != nil {
return nil, err
}
checkpointID := strings.TrimSpace(opts.CheckpointID)
if checkpointID == "" {
checkpointID = fmt.Sprintf("%s-%d", defaultGeneratedCheckpointIDPrefix, time.Now().UTC().UnixNano())
}
resolvedStorage, err := snapshotprotocol.ResolveCheckpointStorage(checkpointID, "", snapshotprotocol.Storage{
Type: snapshotprotocol.StorageTypePVC,
PVCName: storage.PVCName,
BasePath: storage.BasePath,
})
if err != nil {
return nil, err
}
checkpointJobName := pod.Name + "-checkpoint"
job, err := snapshotprotocol.NewCheckpointJob(&corev1.PodTemplateSpec{
ObjectMeta: metav1.ObjectMeta{
Labels: pod.Labels,
Annotations: pod.Annotations,
},
Spec: *pod.Spec.DeepCopy(),
}, snapshotprotocol.CheckpointJobOptions{
Namespace: namespace,
CheckpointID: checkpointID,
ArtifactVersion: snapshotprotocol.DefaultCheckpointArtifactVersion,
SeccompProfile: snapshotprotocol.DefaultSeccompLocalhostProfile,
Name: checkpointJobName,
WrapLaunchJob: !opts.DisableCudaCheckpointJobFile,
})
if err != nil {
return nil, err
}
_, err = clientset.BatchV1().Jobs(namespace).Create(ctx, job, metav1.CreateOptions{})
if apierrors.IsAlreadyExists(err) {
return nil, fmt.Errorf("checkpoint job %s/%s already exists", namespace, checkpointJobName)
}
if err != nil {
return nil, err
}
waitCtx, cancel := context.WithTimeout(ctx, opts.Timeout)
defer cancel()
status, err := waitForCheckpoint(waitCtx, clientset, namespace, checkpointJobName)
if err != nil {
return nil, err
}
return &result{
Name: pod.Name,
Namespace: namespace,
CheckpointID: checkpointID,
CheckpointLocation: resolvedStorage.Location,
CheckpointJob: checkpointJobName,
Status: status,
}, nil
}
func waitForCheckpoint(ctx context.Context, clientset kubernetes.Interface, namespace string, jobName string) (string, error) {
var status string
if err := wait.PollUntilContextCancel(ctx, 2*time.Second, true, func(ctx context.Context) (bool, error) {
job, err := clientset.BatchV1().Jobs(namespace).Get(ctx, jobName, metav1.GetOptions{})
if err != nil {
if apierrors.IsNotFound(err) {
return false, nil
}
return false, fmt.Errorf("get checkpoint job %s/%s: %w", namespace, jobName, err)
}
status = strings.TrimSpace(job.Annotations[snapshotprotocol.CheckpointStatusAnnotation])
if status == snapshotprotocol.CheckpointStatusCompleted {
return true, nil
}
if status == snapshotprotocol.CheckpointStatusFailed {
return false, fmt.Errorf("checkpoint job %s/%s failed", namespace, jobName)
}
if job.Status.Failed > 0 {
return false, fmt.Errorf("checkpoint job %s/%s failed", namespace, jobName)
}
for _, condition := range job.Status.Conditions {
if condition.Status != corev1.ConditionTrue {
continue
}
if condition.Type == batchv1.JobFailed {
return false, fmt.Errorf("checkpoint job %s/%s failed: %s", namespace, jobName, strings.TrimSpace(condition.Message))
}
}
return false, nil
}); err != nil {
if !wait.Interrupted(err) {
return "", err
}
summaryCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
pods, err := clientset.CoreV1().Pods(namespace).List(summaryCtx, metav1.ListOptions{
LabelSelector: "batch.kubernetes.io/job-name=" + jobName,
})
summary := "no checkpoint pod created yet"
if err != nil {
summary = "unable to list checkpoint pod: " + err.Error()
} else if len(pods.Items) != 0 {
pod := pods.Items[0]
parts := []string{
fmt.Sprintf("job_status=%q", status),
fmt.Sprintf("pod=%s phase=%s", pod.Name, pod.Status.Phase),
}
for _, condition := range pod.Status.Conditions {
if condition.Status == corev1.ConditionTrue || condition.Status == corev1.ConditionFalse {
parts = append(parts, fmt.Sprintf("%s=%s", condition.Type, condition.Status))
}
}
for _, status := range pod.Status.ContainerStatuses {
if status.State.Waiting != nil {
parts = append(parts, fmt.Sprintf("container=%s waiting=%s", status.Name, status.State.Waiting.Reason))
}
if status.State.Terminated != nil {
parts = append(parts, fmt.Sprintf("container=%s terminated=%s", status.Name, status.State.Terminated.Reason))
}
}
summary = strings.Join(parts, " ")
}
return "", fmt.Errorf("checkpoint job %s/%s timed out: %s", namespace, jobName, summary)
}
return status, nil
}
package main
import (
"context"
"fmt"
"os"
"strings"
"time"
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/client-go/kubernetes"
"k8s.io/client-go/tools/clientcmd"
"sigs.k8s.io/yaml"
snapshotprotocol "github.com/ai-dynamo/dynamo/deploy/snapshot/protocol"
)
func loadRunContext(ctx context.Context, manifestPath string, namespaceOverride string, kubeContext string) (*corev1.Pod, kubernetes.Interface, string, snapshotprotocol.Storage, error) {
pod, err := loadPod(manifestPath)
if err != nil {
return nil, nil, "", snapshotprotocol.Storage{}, err
}
clientset, currentNamespace, err := loadClientset(kubeContext)
if err != nil {
return nil, nil, "", snapshotprotocol.Storage{}, err
}
namespace := currentNamespace
if namespace == "" {
namespace = corev1.NamespaceDefault
}
if pod.Namespace != "" {
namespace = pod.Namespace
}
if namespaceOverride != "" {
namespace = namespaceOverride
}
storage, err := discoverSnapshotStorage(ctx, clientset, namespace)
if err != nil {
return nil, nil, "", snapshotprotocol.Storage{}, err
}
return pod, clientset, namespace, storage, nil
}
func loadClientset(kubeContext string) (kubernetes.Interface, string, error) {
loadingRules := clientcmd.NewDefaultClientConfigLoadingRules()
clientConfig := clientcmd.NewNonInteractiveDeferredLoadingClientConfig(loadingRules, &clientcmd.ConfigOverrides{
CurrentContext: strings.TrimSpace(kubeContext),
})
restConfig, err := clientConfig.ClientConfig()
if err != nil {
return nil, "", fmt.Errorf("load kubeconfig: %w", err)
}
restConfig.Timeout = 30 * time.Second
namespace, _, err := clientConfig.Namespace()
if err != nil {
return nil, "", fmt.Errorf("resolve current namespace: %w", err)
}
if strings.TrimSpace(namespace) == "" {
namespace = corev1.NamespaceDefault
}
clientset, err := kubernetes.NewForConfig(restConfig)
if err != nil {
return nil, "", fmt.Errorf("create kubernetes client: %w", err)
}
return clientset, namespace, nil
}
func discoverSnapshotStorage(ctx context.Context, clientset kubernetes.Interface, namespace string) (snapshotprotocol.Storage, error) {
daemonSets, err := clientset.AppsV1().DaemonSets(namespace).List(ctx, metav1.ListOptions{
LabelSelector: snapshotprotocol.SnapshotAgentLabelSelector,
})
if err != nil {
return snapshotprotocol.Storage{}, fmt.Errorf("list snapshot-agent daemonsets in %s: %w", namespace, err)
}
return snapshotprotocol.DiscoverStorageFromDaemonSets(namespace, daemonSets.Items)
}
func loadPod(manifestPath string) (*corev1.Pod, error) {
content, err := os.ReadFile(manifestPath)
if err != nil {
return nil, fmt.Errorf("read manifest %s: %w", manifestPath, err)
}
var pod corev1.Pod
if err := yaml.Unmarshal(content, &pod); err != nil {
return nil, fmt.Errorf("parse manifest %s: %w", manifestPath, err)
}
if kind := strings.TrimSpace(pod.Kind); kind != "" && kind != "Pod" {
return nil, fmt.Errorf("manifest %s is kind %q, expected Pod", manifestPath, kind)
}
if len(pod.Spec.Containers) != 1 {
return nil, fmt.Errorf(
"manifest %s has %d containers; snapshotctl requires exactly one worker container",
manifestPath,
len(pod.Spec.Containers),
)
}
if strings.TrimSpace(pod.Spec.Containers[0].Image) == "" {
return nil, fmt.Errorf("manifest %s: worker container image is required", manifestPath)
}
if strings.TrimSpace(pod.Name) == "" {
return nil, fmt.Errorf("manifest %s: metadata.name is required", manifestPath)
}
pod.Namespace = strings.TrimSpace(pod.Namespace)
return &pod, nil
}
package main
import (
"context"
"flag"
"fmt"
"os"
"time"
"github.com/ai-dynamo/dynamo/deploy/snapshot/internal/logging"
)
var snapshotctlLog = logging.ConfigureLogger("stderr").WithName("snapshotctl")
func main() {
if err := run(os.Args[1:]); err != nil {
snapshotctlLog.Error(err, "snapshotctl failed")
os.Exit(1)
}
}
func run(args []string) error {
if len(args) == 0 {
printUsage()
return nil
}
switch args[0] {
case "checkpoint":
return runCheckpoint(args[1:])
case "restore":
return runRestore(args[1:])
case "help", "-h", "--help":
printUsage()
return nil
default:
return fmt.Errorf("unknown subcommand %q", args[0])
}
}
func runCheckpoint(args []string) error {
flags := flag.NewFlagSet("checkpoint", flag.ContinueOnError)
flags.SetOutput(os.Stderr)
manifest := flags.String("manifest", "", "Path to a worker Pod manifest")
namespace := flags.String("namespace", "", "Namespace override; defaults to the manifest namespace or current kube context namespace")
kubeContext := flags.String("kube-context", "", "Kubernetes context override")
checkpointID := flags.String("checkpoint-id", "", "Explicit checkpoint ID; defaults to a generated value")
disableCudaCheckpointJobFile := flags.Bool("disable-cuda-checkpoint-job-file", false, "Preserve the manifest command instead of wrapping it with cuda-checkpoint --launch-job")
timeout := flags.Duration("timeout", 45*time.Minute, "Maximum time to wait for checkpoint completion")
if err := flags.Parse(args); err != nil {
return err
}
if len(flags.Args()) != 0 {
return fmt.Errorf("unexpected arguments: %v", flags.Args())
}
if *manifest == "" {
return fmt.Errorf("--manifest is required")
}
snapshotctlLog.Info("Running checkpoint", "manifest", *manifest, "namespace", *namespace)
result, err := runCheckpointFlow(context.Background(), checkpointOptions{
ManifestPath: *manifest,
Namespace: *namespace,
KubeContext: *kubeContext,
CheckpointID: *checkpointID,
DisableCudaCheckpointJobFile: *disableCudaCheckpointJobFile,
Timeout: *timeout,
})
if err != nil {
return err
}
snapshotctlLog.Info("Checkpoint completed", "job", result.CheckpointJob, "checkpoint_id", result.CheckpointID)
fmt.Printf("status=%s\n", result.Status)
fmt.Printf("namespace=%s\n", result.Namespace)
fmt.Printf("name=%s\n", result.Name)
fmt.Printf("checkpoint_job=%s\n", result.CheckpointJob)
fmt.Printf("checkpoint_id=%s\n", result.CheckpointID)
fmt.Printf("checkpoint_location=%s\n", result.CheckpointLocation)
return nil
}
func runRestore(args []string) error {
flags := flag.NewFlagSet("restore", flag.ContinueOnError)
flags.SetOutput(os.Stderr)
manifest := flags.String("manifest", "", "Path to a worker Pod manifest used to create a new restore pod")
podName := flags.String("pod", "", "Existing restore target pod name")
namespace := flags.String("namespace", "", "Namespace override; defaults to the manifest namespace or current kube context namespace")
kubeContext := flags.String("kube-context", "", "Kubernetes context override")
checkpointID := flags.String("checkpoint-id", "", "Checkpoint ID to restore")
timeout := flags.Duration("timeout", 45*time.Minute, "Maximum time to wait for restore completion")
if err := flags.Parse(args); err != nil {
return err
}
if len(flags.Args()) != 0 {
return fmt.Errorf("unexpected arguments: %v", flags.Args())
}
if (*manifest == "") == (*podName == "") {
return fmt.Errorf("must specify exactly one of --manifest or --pod")
}
snapshotctlLog.Info("Running restore", "manifest", *manifest, "pod", *podName, "namespace", *namespace, "checkpoint_id", *checkpointID)
result, err := runRestoreFlow(context.Background(), restoreOptions{
ManifestPath: *manifest,
PodName: *podName,
Namespace: *namespace,
KubeContext: *kubeContext,
CheckpointID: *checkpointID,
Timeout: *timeout,
})
if err != nil {
return err
}
snapshotctlLog.Info("Restore completed", "pod", result.RestorePod, "checkpoint_id", result.CheckpointID)
fmt.Printf("status=%s\n", result.Status)
fmt.Printf("namespace=%s\n", result.Namespace)
fmt.Printf("name=%s\n", result.Name)
fmt.Printf("restore_pod=%s\n", result.RestorePod)
fmt.Printf("checkpoint_id=%s\n", result.CheckpointID)
fmt.Printf("checkpoint_location=%s\n", result.CheckpointLocation)
return nil
}
func printUsage() {
fmt.Fprintf(os.Stderr, `snapshotctl runs snapshot checkpoint and restore flows from a worker Pod manifest.
Subcommands:
checkpoint
restore
Examples:
snapshotctl checkpoint --manifest /tmp/vllm-worker-pod.yaml
snapshotctl restore --manifest /tmp/sglang-worker-pod.yaml --checkpoint-id manual-snapshot-123
snapshotctl restore --pod existing-restore-target --checkpoint-id manual-snapshot-123
`)
}
package main
import (
"context"
"encoding/json"
"fmt"
"strings"
"time"
corev1 "k8s.io/api/core/v1"
apierrors "k8s.io/apimachinery/pkg/api/errors"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/types"
"k8s.io/apimachinery/pkg/util/wait"
"k8s.io/client-go/kubernetes"
snapshotprotocol "github.com/ai-dynamo/dynamo/deploy/snapshot/protocol"
)
type restoreOptions struct {
ManifestPath string
PodName string
Namespace string
KubeContext string
CheckpointID string
Timeout time.Duration
}
func runRestoreFlow(ctx context.Context, opts restoreOptions) (*result, error) {
createPodFromManifest := strings.TrimSpace(opts.ManifestPath) != ""
targetExistingPod := strings.TrimSpace(opts.PodName) != ""
if createPodFromManifest == targetExistingPod {
return nil, fmt.Errorf("restore requires exactly one of --manifest or --pod")
}
if strings.TrimSpace(opts.CheckpointID) == "" {
return nil, fmt.Errorf("missing required flags: --checkpoint-id")
}
if opts.Timeout <= 0 {
return nil, fmt.Errorf("--timeout must be greater than zero")
}
checkpointID := strings.TrimSpace(opts.CheckpointID)
clientset, currentNamespace, err := loadClientset(opts.KubeContext)
if err != nil {
return nil, err
}
namespace := currentNamespace
if namespace == "" {
namespace = corev1.NamespaceDefault
}
if strings.TrimSpace(opts.Namespace) != "" {
namespace = strings.TrimSpace(opts.Namespace)
}
podName := strings.TrimSpace(opts.PodName)
pod := &corev1.Pod{}
if createPodFromManifest {
pod, err = loadPod(opts.ManifestPath)
if err != nil {
return nil, err
}
if strings.TrimSpace(pod.Namespace) != "" && strings.TrimSpace(opts.Namespace) == "" {
namespace = strings.TrimSpace(pod.Namespace)
}
podName = pod.Name
}
storage, err := discoverSnapshotStorage(ctx, clientset, namespace)
if err != nil {
return nil, err
}
resolvedStorage, err := snapshotprotocol.ResolveRestoreStorage(checkpointID, snapshotprotocol.DefaultCheckpointArtifactVersion, "", snapshotprotocol.Storage{
Type: snapshotprotocol.StorageTypePVC,
PVCName: storage.PVCName,
BasePath: storage.BasePath,
})
if err != nil {
return nil, err
}
if createPodFromManifest {
restorePod := snapshotprotocol.NewRestorePod(&corev1.Pod{
TypeMeta: metav1.TypeMeta{APIVersion: "v1", Kind: "Pod"},
ObjectMeta: metav1.ObjectMeta{
Name: pod.Name,
Labels: pod.Labels,
Annotations: pod.Annotations,
},
Spec: *pod.Spec.DeepCopy(),
}, snapshotprotocol.PodOptions{
Namespace: namespace,
CheckpointID: checkpointID,
ArtifactVersion: snapshotprotocol.DefaultCheckpointArtifactVersion,
Storage: resolvedStorage,
SeccompProfile: snapshotprotocol.DefaultSeccompLocalhostProfile,
})
_, err = clientset.CoreV1().Pods(namespace).Create(ctx, restorePod, metav1.CreateOptions{})
if apierrors.IsAlreadyExists(err) {
return nil, fmt.Errorf("restore pod %s/%s already exists", namespace, pod.Name)
}
if err != nil {
return nil, err
}
} else {
pod, err = clientset.CoreV1().Pods(namespace).Get(ctx, podName, metav1.GetOptions{})
if err != nil {
return nil, fmt.Errorf("get restore target pod %s/%s: %w", namespace, podName, err)
}
if len(pod.Spec.Containers) == 0 {
return nil, fmt.Errorf("restore target pod %s/%s has no containers", namespace, podName)
}
if err := snapshotprotocol.ValidateRestorePodSpec(&pod.Spec, resolvedStorage, snapshotprotocol.DefaultSeccompLocalhostProfile); err != nil {
return nil, fmt.Errorf("restore target pod %s/%s is not snapshot-compatible: %w", namespace, podName, err)
}
labels := map[string]string{}
for key, value := range pod.Labels {
labels[key] = value
}
annotations := map[string]string{}
for key, value := range pod.Annotations {
annotations[key] = value
}
snapshotprotocol.ApplyRestoreTargetMetadata(labels, annotations, true, checkpointID, snapshotprotocol.DefaultCheckpointArtifactVersion)
patch, err := json.Marshal(map[string]any{
"metadata": map[string]any{
"labels": labels,
"annotations": annotations,
},
})
if err != nil {
return nil, fmt.Errorf("encode restore target metadata patch: %w", err)
}
if _, err := clientset.CoreV1().Pods(namespace).Patch(ctx, podName, types.MergePatchType, patch, metav1.PatchOptions{}); err != nil {
return nil, fmt.Errorf("patch restore target pod %s/%s: %w", namespace, podName, err)
}
}
waitCtx, cancel := context.WithTimeout(ctx, opts.Timeout)
defer cancel()
status, err := waitForRestore(waitCtx, clientset, namespace, podName)
if err != nil {
return nil, err
}
return &result{
Name: podName,
Namespace: namespace,
CheckpointID: checkpointID,
CheckpointLocation: resolvedStorage.Location,
RestorePod: podName,
Status: status,
}, nil
}
func waitForRestore(ctx context.Context, clientset kubernetes.Interface, namespace string, podName string) (string, error) {
var status string
if err := wait.PollUntilContextCancel(ctx, 2*time.Second, true, func(ctx context.Context) (bool, error) {
pod, err := clientset.CoreV1().Pods(namespace).Get(ctx, podName, metav1.GetOptions{})
if err != nil {
return false, fmt.Errorf("get restore pod %s/%s: %w", namespace, podName, err)
}
status = strings.TrimSpace(pod.Annotations[snapshotprotocol.RestoreStatusAnnotation])
if status == snapshotprotocol.RestoreStatusCompleted {
return true, nil
}
if status == snapshotprotocol.RestoreStatusFailed {
return false, fmt.Errorf("restore pod %s/%s failed", namespace, podName)
}
if pod.Status.Phase == corev1.PodFailed {
return false, fmt.Errorf("restore pod %s/%s entered phase Failed (%s)", namespace, podName, pod.Status.Reason)
}
return false, nil
}); err != nil {
if !wait.Interrupted(err) {
return "", err
}
return "", fmt.Errorf("restore pod %s/%s timed out: status=%q", namespace, podName, status)
}
return status, nil
}
......@@ -8,6 +8,7 @@ require (
github.com/cyphar/filepath-securejoin v0.5.1
github.com/go-logr/logr v1.4.3
github.com/go-logr/zapr v1.3.0
github.com/google/uuid v1.6.0
github.com/moby/sys/mountinfo v0.7.1
github.com/opencontainers/runtime-spec v1.2.0
github.com/prometheus/procfs v0.16.1
......@@ -16,10 +17,13 @@ require (
google.golang.org/grpc v1.72.2
google.golang.org/protobuf v1.36.11
gopkg.in/yaml.v3 v3.0.1
k8s.io/api v0.35.0
k8s.io/apimachinery v0.35.0
k8s.io/client-go v0.35.0
k8s.io/kubelet v0.35.0
k8s.io/api v0.34.3
k8s.io/apimachinery v0.34.3
k8s.io/client-go v0.34.3
k8s.io/kubelet v0.34.3
k8s.io/utils v0.0.0-20251002143259-bc988d571ff4
sigs.k8s.io/controller-runtime v0.22.4
sigs.k8s.io/yaml v1.6.0
)
require (
......@@ -41,6 +45,7 @@ require (
github.com/distribution/reference v0.6.0 // indirect
github.com/docker/go-events v0.0.0-20190806004212-e31b211e4f1c // indirect
github.com/emicklei/go-restful/v3 v3.12.2 // indirect
github.com/evanphx/json-patch/v5 v5.9.11 // indirect
github.com/felixge/httpsnoop v1.0.4 // indirect
github.com/fxamacker/cbor/v2 v2.9.0 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
......@@ -51,7 +56,6 @@ require (
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
github.com/google/gnostic-models v0.7.0 // indirect
github.com/google/go-cmp v0.7.0 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/josharian/intern v1.0.0 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/compress v1.18.0 // indirect
......@@ -70,6 +74,7 @@ require (
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/sirupsen/logrus v1.9.3 // indirect
github.com/spf13/pflag v1.0.9 // indirect
github.com/x448/float16 v0.8.4 // indirect
go.opencensus.io v0.24.0 // indirect
go.opentelemetry.io/auto/sdk v1.2.1 // indirect
......@@ -93,10 +98,8 @@ require (
gopkg.in/evanphx/json-patch.v4 v4.13.0 // indirect
gopkg.in/inf.v0 v0.9.1 // indirect
k8s.io/klog/v2 v2.130.1 // indirect
k8s.io/kube-openapi v0.0.0-20250910181357-589584f1c912 // indirect
k8s.io/utils v0.0.0-20251002143259-bc988d571ff4 // indirect
k8s.io/kube-openapi v0.0.0-20250814151709-d7b6acb124c3 // indirect
sigs.k8s.io/json v0.0.0-20250730193827-2d320260d730 // indirect
sigs.k8s.io/randfill v1.0.0 // indirect
sigs.k8s.io/structured-merge-diff/v6 v6.3.2-0.20260122202528-d9cc6641c482 // indirect
sigs.k8s.io/yaml v1.6.0 // indirect
sigs.k8s.io/structured-merge-diff/v6 v6.3.1 // indirect
)
......@@ -4,12 +4,12 @@ github.com/AdaLogics/go-fuzz-headers v0.0.0-20230811130428-ced1acdcaa24/go.mod h
github.com/AdamKorcz/go-118-fuzz-build v0.0.0-20230306123547-8075edf89bb0 h1:59MxjQVfjXsBpLy+dbd2/ELV5ofnUkUZBvWSC85sheA=
github.com/AdamKorcz/go-118-fuzz-build v0.0.0-20230306123547-8075edf89bb0/go.mod h1:OahwfttHWG6eJ0clwcfBAHoDI6X/LV/15hx/wlMZSrU=
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/Masterminds/semver/v3 v3.4.0 h1:Zog+i5UMtVoCU8oKka5P7i9q9HgrJeGzI9SA1Xbatp0=
github.com/Masterminds/semver/v3 v3.4.0/go.mod h1:4V+yj/TJE1HU9XfppCwVMZq3I84lprf4nC11bSS5beM=
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
github.com/Microsoft/hcsshim v0.11.7 h1:vl/nj3Bar/CvJSYo7gIQPyRWc9f3c6IeSNavBTSZNZQ=
github.com/Microsoft/hcsshim v0.11.7/go.mod h1:MV8xMfmECjl5HdO7U/3/hFVnkmSBjAjmA09d4bExKcU=
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
......@@ -53,6 +53,8 @@ github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymF
github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98=
github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c=
github.com/evanphx/json-patch/v5 v5.9.11 h1:/8HVnzMq13/3x9TPvjG08wUGqBTmZBsCWzjTM0wiaDU=
github.com/evanphx/json-patch/v5 v5.9.11/go.mod h1:3j+LviiESTElxA4p3EMKAB9HXj3/XEtnUf6OZxqIQTM=
github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg=
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
github.com/fxamacker/cbor/v2 v2.9.0 h1:NpKPmjDBgUfBms6tr6JZkTHtfFGcMKsw3eGcmD/sapM=
......@@ -105,8 +107,8 @@ github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/pprof v0.0.0-20250403155104-27863c87afa6 h1:BHT72Gu3keYf3ZEu2J0b1vyeLSOYI8bm5wbJM/8yDe8=
github.com/google/pprof v0.0.0-20250403155104-27863c87afa6/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
github.com/google/pprof v0.0.0-20241029153458-d1b30febd7db h1:097atOisP2aRj7vFgYQBbFN4U4JNXUNYpxael3UzMyo=
github.com/google/pprof v0.0.0-20241029153458-d1b30febd7db/go.mod h1:vavhavw2zAxS5dIdcRluK6cSGGPlZynqzFM8NdvU144=
github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
......@@ -147,10 +149,10 @@ github.com/modern-go/reflect2 v1.0.3-0.20250322232337-35a7c28c31ee h1:W5t00kpgFd
github.com/modern-go/reflect2 v1.0.3-0.20250322232337-35a7c28c31ee/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
github.com/onsi/ginkgo/v2 v2.27.2 h1:LzwLj0b89qtIy6SSASkzlNvX6WktqurSHwkk2ipF/Ns=
github.com/onsi/ginkgo/v2 v2.27.2/go.mod h1:ArE1D/XhNXBXCBkKOLkbsb2c81dQHCRcF5zwn/ykDRo=
github.com/onsi/gomega v1.38.2 h1:eZCjf2xjZAqe+LeWvKb5weQ+NcPwX84kqJ0cZNxok2A=
github.com/onsi/gomega v1.38.2/go.mod h1:W2MJcYxRGV63b418Ai34Ud0hEdTVXq9NW9+Sx6uXf3k=
github.com/onsi/ginkgo/v2 v2.22.0 h1:Yed107/8DjTr0lKCNt7Dn8yQ6ybuDRQoMGrNFKzMfHg=
github.com/onsi/ginkgo/v2 v2.22.0/go.mod h1:7Du3c42kxCUegi0IImZ1wUQzMBVecgIHjR1C+NkhLQo=
github.com/onsi/gomega v1.36.1 h1:bJDPBO7ibjxcbHMgSCoo4Yj18UWbKDlLwX1x9sybDcw=
github.com/onsi/gomega v1.36.1/go.mod h1:PvZbdDc8J6XJEpDK4HCuRBm8a6Fzp9/DmhC9C7yFlog=
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
github.com/opencontainers/image-spec v1.1.0 h1:8SG7/vwALn54lVB/0yZ/MMwhFrPYtpEHQb2IpWsCzug=
......@@ -163,7 +165,13 @@ github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/prometheus/client_golang v1.22.0 h1:rb93p9lokFEsctTys46VnV1kLCDpVZ0a/Y92Vm0Zc6Q=
github.com/prometheus/client_golang v1.22.0/go.mod h1:R7ljNsLXhuQXYZYtw6GAE9AZg8Y7vEW5scdCXrWRXC0=
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E=
github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY=
github.com/prometheus/common v0.62.0 h1:xasJaQlnWAeyHdUBeGjXmutelfJHWMRr+Fg4QszZ2Io=
github.com/prometheus/common v0.62.0/go.mod h1:vyBcEuLSvWos9B1+CyL7JZ2up+uFzXhkqml0W5zIY1I=
github.com/prometheus/procfs v0.16.1 h1:hZ15bTNuirocR6u0JZ6BAHHmwS1p8B4P6MRqxtzMyRg=
github.com/prometheus/procfs v0.16.1/go.mod h1:teAbpZRB1iIAJYREa1LsoWUXykVXA1KlTmWl8x/U+Is=
github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
......@@ -223,8 +231,6 @@ golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvx
golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc=
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.29.0 h1:HV8lRxZC4l2cr3Zq1LvtOsi/ThTgWnUk/y64QSs8GwA=
golang.org/x/mod v0.29.0/go.mod h1:NyhrlYXJ2H4eJiRy/WDBO6HMqZQ6q9nk4JzS3NuCK+w=
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
......@@ -318,25 +324,29 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
k8s.io/api v0.35.0 h1:iBAU5LTyBI9vw3L5glmat1njFK34srdLmktWwLTprlY=
k8s.io/api v0.35.0/go.mod h1:AQ0SNTzm4ZAczM03QH42c7l3bih1TbAXYo0DkF8ktnA=
k8s.io/apimachinery v0.35.0 h1:Z2L3IHvPVv/MJ7xRxHEtk6GoJElaAqDCCU0S6ncYok8=
k8s.io/apimachinery v0.35.0/go.mod h1:jQCgFZFR1F4Ik7hvr2g84RTJSZegBc8yHgFWKn//hns=
k8s.io/client-go v0.35.0 h1:IAW0ifFbfQQwQmga0UdoH0yvdqrbwMdq9vIFEhRpxBE=
k8s.io/client-go v0.35.0/go.mod h1:q2E5AAyqcbeLGPdoRB+Nxe3KYTfPce1Dnu1myQdqz9o=
k8s.io/api v0.34.3 h1:D12sTP257/jSH2vHV2EDYrb16bS7ULlHpdNdNhEw2S4=
k8s.io/api v0.34.3/go.mod h1:PyVQBF886Q5RSQZOim7DybQjAbVs8g7gwJNhGtY5MBk=
k8s.io/apiextensions-apiserver v0.34.1 h1:NNPBva8FNAPt1iSVwIE0FsdrVriRXMsaWFMqJbII2CI=
k8s.io/apiextensions-apiserver v0.34.1/go.mod h1:hP9Rld3zF5Ay2Of3BeEpLAToP+l4s5UlxiHfqRaRcMc=
k8s.io/apimachinery v0.34.3 h1:/TB+SFEiQvN9HPldtlWOTp0hWbJ+fjU+wkxysf/aQnE=
k8s.io/apimachinery v0.34.3/go.mod h1:/GwIlEcWuTX9zKIg2mbw0LRFIsXwrfoVxn+ef0X13lw=
k8s.io/client-go v0.34.3 h1:wtYtpzy/OPNYf7WyNBTj3iUA0XaBHVqhv4Iv3tbrF5A=
k8s.io/client-go v0.34.3/go.mod h1:OxxeYagaP9Kdf78UrKLa3YZixMCfP6bgPwPwNBQBzpM=
k8s.io/klog/v2 v2.130.1 h1:n9Xl7H1Xvksem4KFG4PYbdQCQxqc/tTUyrgXaOhHSzk=
k8s.io/klog/v2 v2.130.1/go.mod h1:3Jpz1GvMt720eyJH1ckRHK1EDfpxISzJ7I9OYgaDtPE=
k8s.io/kube-openapi v0.0.0-20250910181357-589584f1c912 h1:Y3gxNAuB0OBLImH611+UDZcmKS3g6CthxToOb37KgwE=
k8s.io/kube-openapi v0.0.0-20250910181357-589584f1c912/go.mod h1:kdmbQkyfwUagLfXIad1y2TdrjPFWp2Q89B3qkRwf/pQ=
k8s.io/kubelet v0.35.0 h1:8cgJHCBCKLYuuQ7/Pxb/qWbJfX1LXIw7790ce9xHq7c=
k8s.io/kubelet v0.35.0/go.mod h1:ciRzAXn7C4z5iB7FhG1L2CGPPXLTVCABDlbXt/Zz8YA=
k8s.io/kube-openapi v0.0.0-20250814151709-d7b6acb124c3 h1:liMHz39T5dJO1aOKHLvwaCjDbf07wVh6yaUlTpunnkE=
k8s.io/kube-openapi v0.0.0-20250814151709-d7b6acb124c3/go.mod h1:UZ2yyWbFTpuhSbFhv24aGNOdoRdJZgsIObGBUaYVsts=
k8s.io/kubelet v0.34.3 h1:8QRev2FmasZ05yCC774qn6ULche72PYM7AQv0CVt9CM=
k8s.io/kubelet v0.34.3/go.mod h1:pMgblr+nVQ02UkyaTcgqzS3AIYVQkjlMFg1Pd5rGC1Q=
k8s.io/utils v0.0.0-20251002143259-bc988d571ff4 h1:SjGebBtkBqHFOli+05xYbK8YF1Dzkbzn+gDM4X9T4Ck=
k8s.io/utils v0.0.0-20251002143259-bc988d571ff4/go.mod h1:OLgZIPagt7ERELqWJFomSt595RzquPNLL48iOWgYOg0=
sigs.k8s.io/controller-runtime v0.22.4 h1:GEjV7KV3TY8e+tJ2LCTxUTanW4z/FmNB7l327UfMq9A=
sigs.k8s.io/controller-runtime v0.22.4/go.mod h1:+QX1XUpTXN4mLoblf4tqr5CQcyHPAki2HLXqQMY6vh8=
sigs.k8s.io/json v0.0.0-20250730193827-2d320260d730 h1:IpInykpT6ceI+QxKBbEflcR5EXP7sU1kvOlxwZh5txg=
sigs.k8s.io/json v0.0.0-20250730193827-2d320260d730/go.mod h1:mdzfpAEoE6DHQEN0uh9ZbOCuHbLK5wOm7dK4ctXE9Tg=
sigs.k8s.io/randfill v1.0.0 h1:JfjMILfT8A6RbawdsK2JXGBR5AQVfd+9TbzrlneTyrU=
sigs.k8s.io/randfill v1.0.0/go.mod h1:XeLlZ/jmk4i1HRopwe7/aU3H5n1zNUcX6TM94b3QxOY=
sigs.k8s.io/structured-merge-diff/v6 v6.3.2-0.20260122202528-d9cc6641c482 h1:2WOzJpHUBVrrkDjU4KBT8n5LDcj824eX0I5UKcgeRUs=
sigs.k8s.io/structured-merge-diff/v6 v6.3.2-0.20260122202528-d9cc6641c482/go.mod h1:M3W8sfWvn2HhQDIbGWj3S099YozAsymCo/wrT5ohRUE=
sigs.k8s.io/structured-merge-diff/v6 v6.3.1 h1:JrhdFMqOd/+3ByqlP2I45kTOZmTRLBUm5pvRjeheg7E=
sigs.k8s.io/structured-merge-diff/v6 v6.3.1/go.mod h1:M3W8sfWvn2HhQDIbGWj3S099YozAsymCo/wrT5ohRUE=
sigs.k8s.io/yaml v1.6.0 h1:G8fkbMSAFqgEFgh4b1wmtzDnioxFCUgTZhlbj5P9QYs=
sigs.k8s.io/yaml v1.6.0/go.mod h1:796bPqUfzR/0jLAl6XjHl3Ck7MiyVv8dbTdyT3/pMf4=
......@@ -26,20 +26,10 @@ import (
"k8s.io/client-go/rest"
"k8s.io/client-go/tools/cache"
"github.com/ai-dynamo/dynamo/deploy/snapshot/pkg/common"
"github.com/ai-dynamo/dynamo/deploy/snapshot/pkg/executor"
"github.com/ai-dynamo/dynamo/deploy/snapshot/pkg/types"
)
const (
kubeLabelIsCheckpointSource = "nvidia.com/snapshot-is-checkpoint-source"
kubeLabelCheckpointHash = "nvidia.com/snapshot-checkpoint-hash"
kubeLabelIsRestoreTarget = "nvidia.com/snapshot-is-restore-target"
kubeAnnotationCheckpointLocation = "nvidia.com/snapshot-checkpoint-location"
kubeAnnotationCheckpointStorageType = "nvidia.com/snapshot-checkpoint-storage-type"
kubeAnnotationCheckpointStatus = "nvidia.com/snapshot-checkpoint-status"
kubeAnnotationRestoreStatus = "nvidia.com/snapshot-restore-status"
kubeAnnotationRestoreContainerID = "nvidia.com/snapshot-restore-container-id"
"github.com/ai-dynamo/dynamo/deploy/snapshot/internal/executor"
snapshotruntime "github.com/ai-dynamo/dynamo/deploy/snapshot/internal/runtime"
"github.com/ai-dynamo/dynamo/deploy/snapshot/internal/types"
snapshotprotocol "github.com/ai-dynamo/dynamo/deploy/snapshot/protocol"
)
// NodeController watches local-node pods with checkpoint metadata and reconciles
......@@ -88,8 +78,8 @@ func NewNodeController(
func (w *NodeController) Run(ctx context.Context) error {
w.log.Info("Starting snapshot node controller",
"node", w.config.NodeName,
"checkpoint", kubeLabelIsCheckpointSource,
"restore", kubeLabelIsRestoreTarget,
"checkpoint", snapshotprotocol.CheckpointSourceLabel,
"restore", snapshotprotocol.RestoreTargetLabel,
)
var nsOptions []informers.SharedInformerOption
......@@ -104,7 +94,7 @@ func (w *NodeController) Run(ctx context.Context) error {
// Checkpoint informer
checkpointSelector := labels.SelectorFromSet(labels.Set{
kubeLabelIsCheckpointSource: "true",
snapshotprotocol.CheckpointSourceLabel: "true",
}).String()
ckptFactoryOpts := append([]informers.SharedInformerOption{
......@@ -141,7 +131,7 @@ func (w *NodeController) Run(ctx context.Context) error {
// Restore informer
restoreSelector := labels.SelectorFromSet(labels.Set{
kubeLabelIsRestoreTarget: "true",
snapshotprotocol.RestoreTargetLabel: "true",
}).String()
restoreFactoryOpts := append([]informers.SharedInformerOption{
......@@ -196,9 +186,9 @@ func (w *NodeController) reconcileCheckpointPod(ctx context.Context, pod *corev1
podKey := fmt.Sprintf("%s/%s", pod.Namespace, pod.Name)
checkpointHash, ok := pod.Labels[kubeLabelCheckpointHash]
if !ok || checkpointHash == "" {
w.log.Info("Pod has checkpoint label but no checkpoint-hash label", "pod", podKey)
checkpointID, ok := pod.Labels[snapshotprotocol.CheckpointIDLabel]
if !ok || checkpointID == "" {
w.log.Info("Pod has checkpoint label but no checkpoint-id label", "pod", podKey)
return
}
......@@ -208,8 +198,8 @@ func (w *NodeController) reconcileCheckpointPod(ctx context.Context, pod *corev1
return
}
jobStatus := job.Annotations[kubeAnnotationCheckpointStatus]
if jobStatus == "completed" || jobStatus == "failed" {
jobStatus := job.Annotations[snapshotprotocol.CheckpointStatusAnnotation]
if jobStatus == snapshotprotocol.CheckpointStatusCompleted || jobStatus == snapshotprotocol.CheckpointStatusFailed {
return
}
......@@ -217,17 +207,17 @@ func (w *NodeController) reconcileCheckpointPod(ctx context.Context, pod *corev1
return
}
checkpointLocation, checkpointStorageType, err := checkpointStorageFromPod(pod)
checkpointLocation, err := w.checkpointLocationFromPod(pod, checkpointID)
if err != nil {
w.release(podKey)
w.log.Error(err, "Checkpoint pod is missing storage metadata", "pod", podKey, "checkpoint_hash", checkpointHash)
w.log.Error(err, "Checkpoint pod is missing storage metadata", "pod", podKey, "checkpoint_id", checkpointID)
return
}
acquiredLease, err := acquireCheckpointLease(ctx, w.clientset, w.log, job, w.holderID)
if err != nil {
w.release(podKey)
w.log.Error(err, "Failed to acquire checkpoint lease", "pod", podKey, "checkpoint_hash", checkpointHash)
w.log.Error(err, "Failed to acquire checkpoint lease", "pod", podKey, "checkpoint_id", checkpointID)
return
}
if !acquiredLease {
......@@ -235,12 +225,12 @@ func (w *NodeController) reconcileCheckpointPod(ctx context.Context, pod *corev1
return
}
w.log.Info("Pod ready, triggering checkpoint", "pod", podKey, "checkpoint_hash", checkpointHash)
emitPodEvent(ctx, w.clientset, w.log, pod, "snapshot", corev1.EventTypeNormal, "CheckpointRequested", fmt.Sprintf("Checkpoint requested: %s", checkpointHash))
w.log.Info("Pod ready, triggering checkpoint", "pod", podKey, "checkpoint_id", checkpointID)
emitPodEvent(ctx, w.clientset, w.log, pod, "snapshot", corev1.EventTypeNormal, "CheckpointRequested", fmt.Sprintf("Checkpoint requested: %s", checkpointID))
go func() {
if err := w.runCheckpoint(ctx, pod, job, checkpointHash, checkpointLocation, checkpointStorageType, podKey); err != nil {
opLog := w.log.WithValues("pod", podKey, "checkpoint_hash", checkpointHash)
if err := w.runCheckpoint(ctx, pod, job, checkpointID, checkpointLocation, podKey); err != nil {
opLog := w.log.WithValues("pod", podKey, "checkpoint_id", checkpointID)
opLog.Error(err, "Checkpoint controller worker failed")
emitPodEvent(ctx, w.clientset, opLog, pod, "snapshot", corev1.EventTypeWarning, "CheckpointWorkerFailed", err.Error())
}
......@@ -258,28 +248,24 @@ func (w *NodeController) reconcileRestorePod(ctx context.Context, pod *corev1.Po
return
}
if isPodReady(pod) {
return
}
checkpointHash, ok := pod.Labels[kubeLabelCheckpointHash]
if !ok || checkpointHash == "" {
w.log.Info("Restore pod has no checkpoint-hash label", "pod", podKey)
checkpointID, ok := pod.Labels[snapshotprotocol.CheckpointIDLabel]
if !ok || checkpointID == "" {
w.log.Info("Restore pod has no checkpoint-id label", "pod", podKey)
return
}
if strings.ContainsAny(checkpointHash, "/\\") || strings.Contains(checkpointHash, "..") || filepath.Clean(checkpointHash) != checkpointHash {
w.log.Error(fmt.Errorf("invalid checkpoint hash %q", checkpointHash), "Invalid checkpoint hash on restore pod", "pod", podKey)
if strings.ContainsAny(checkpointID, "/\\") || strings.Contains(checkpointID, "..") || filepath.Clean(checkpointID) != checkpointID {
w.log.Error(fmt.Errorf("invalid checkpoint id %q", checkpointID), "Invalid checkpoint id on restore pod", "pod", podKey)
return
}
checkpointLocation, checkpointStorageType, err := checkpointStorageFromPod(pod)
checkpointLocation, err := w.checkpointLocationFromPod(pod, checkpointID)
if err != nil {
w.log.Error(err, "Restore pod is missing storage metadata", "pod", podKey, "checkpoint_hash", checkpointHash)
w.log.Error(err, "Restore pod is missing storage metadata", "pod", podKey, "checkpoint_id", checkpointID)
return
}
if _, err := os.Stat(checkpointLocation); os.IsNotExist(err) {
w.log.V(1).Info("Checkpoint not ready on disk, skipping restore", "pod", podKey, "checkpoint_hash", checkpointHash, "checkpoint_location", checkpointLocation)
w.log.V(1).Info("Checkpoint not ready on disk, skipping restore", "pod", podKey, "checkpoint_id", checkpointID, "checkpoint_location", checkpointLocation)
return
}
......@@ -302,9 +288,9 @@ func (w *NodeController) reconcileRestorePod(ctx context.Context, pod *corev1.Po
return
}
annotationStatus := pod.Annotations[kubeAnnotationRestoreStatus]
annotationContainerID := pod.Annotations[kubeAnnotationRestoreContainerID]
if annotationContainerID == containerID && (annotationStatus == "completed" || annotationStatus == "in_progress") {
annotationStatus := pod.Annotations[snapshotprotocol.RestoreStatusAnnotation]
annotationContainerID := pod.Annotations[snapshotprotocol.RestoreContainerIDAnnotation]
if annotationContainerID == containerID && (annotationStatus == snapshotprotocol.RestoreStatusCompleted || annotationStatus == snapshotprotocol.RestoreStatusInProgress) {
return
}
......@@ -313,12 +299,12 @@ func (w *NodeController) reconcileRestorePod(ctx context.Context, pod *corev1.Po
return
}
w.log.Info("Restore pod running, triggering external restore", "pod", podKey, "checkpoint_hash", checkpointHash)
emitPodEvent(ctx, w.clientset, w.log, pod, "snapshot", corev1.EventTypeNormal, "RestoreRequested", fmt.Sprintf("Restore requested from checkpoint %s", checkpointHash))
w.log.Info("Restore pod running, triggering external restore", "pod", podKey, "checkpoint_id", checkpointID)
emitPodEvent(ctx, w.clientset, w.log, pod, "snapshot", corev1.EventTypeNormal, "RestoreRequested", fmt.Sprintf("Restore requested from checkpoint %s", checkpointID))
go func() {
if err := w.runRestore(ctx, pod, containerName, containerID, checkpointHash, checkpointLocation, checkpointStorageType, restoreAttemptKey); err != nil {
opLog := w.log.WithValues("pod", podKey, "checkpoint_hash", checkpointHash)
if err := w.runRestore(ctx, pod, containerName, containerID, checkpointID, checkpointLocation, restoreAttemptKey); err != nil {
opLog := w.log.WithValues("pod", podKey, "checkpoint_id", checkpointID)
opLog.Error(err, "Restore controller worker failed")
emitPodEvent(ctx, w.clientset, opLog, pod, "snapshot", corev1.EventTypeWarning, "RestoreWorkerFailed", err.Error())
}
......@@ -331,14 +317,14 @@ func (w *NodeController) reconcileRestorePod(ctx context.Context, pod *corev1.Po
// 3. Call executor.Checkpoint (inspect → configure → CUDA lock/checkpoint → CRIU dump → rootfs diff)
// 4. SIGUSR1 the process on success (notify workload), SIGKILL on failure (terminate immediately)
// 5. Mark job as completed or failed
func (w *NodeController) runCheckpoint(ctx context.Context, pod *corev1.Pod, job *batchv1.Job, checkpointHash, checkpointLocation, checkpointStorageType, podKey string) error {
func (w *NodeController) runCheckpoint(ctx context.Context, pod *corev1.Pod, job *batchv1.Job, checkpointID, checkpointLocation, podKey string) error {
releasePodOnExit := true
defer func() {
if releasePodOnExit {
w.release(podKey)
}
}()
log := w.log.WithValues("pod", podKey, "checkpoint_hash", checkpointHash)
log := w.log.WithValues("pod", podKey, "checkpoint_id", checkpointID)
leaseCtx, stopLease := context.WithCancelCause(ctx)
defer stopLease(nil)
......@@ -358,7 +344,7 @@ func (w *NodeController) runCheckpoint(ctx context.Context, pod *corev1.Pod, job
setCheckpointStatus := func(value string) error {
if err := annotateJob(ctx, w.clientset, log, job, map[string]string{
kubeAnnotationCheckpointStatus: value,
snapshotprotocol.CheckpointStatusAnnotation: value,
}); err != nil {
releasePodOnExit = false
releaseLeaseOnExit = false
......@@ -373,7 +359,7 @@ func (w *NodeController) runCheckpoint(ctx context.Context, pod *corev1.Pod, job
err := fmt.Errorf("no containers found in pod spec")
log.Error(err, "Checkpoint failed")
emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeWarning, "CheckpointFailed", err.Error())
if statusErr := setCheckpointStatus("failed"); statusErr != nil {
if statusErr := setCheckpointStatus(snapshotprotocol.CheckpointStatusFailed); statusErr != nil {
return statusErr
}
return nil
......@@ -387,18 +373,18 @@ func (w *NodeController) runCheckpoint(ctx context.Context, pod *corev1.Pod, job
}
if containerID == "" {
emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeWarning, "CheckpointFailed", "Could not resolve target container ID")
if statusErr := setCheckpointStatus("failed"); statusErr != nil {
if statusErr := setCheckpointStatus(snapshotprotocol.CheckpointStatusFailed); statusErr != nil {
return statusErr
}
return nil
}
// Resolve the container's host PID (needed for signaling after checkpoint)
containerPID, _, err := common.ResolveContainer(ctx, w.containerd, containerID)
containerPID, _, err := snapshotruntime.ResolveContainer(ctx, w.containerd, containerID)
if err != nil {
log.Error(err, "Failed to resolve container")
emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeWarning, "CheckpointFailed", fmt.Sprintf("Container resolve failed: %v", err))
if statusErr := setCheckpointStatus("failed"); statusErr != nil {
if statusErr := setCheckpointStatus(snapshotprotocol.CheckpointStatusFailed); statusErr != nil {
return statusErr
}
return nil
......@@ -408,9 +394,8 @@ func (w *NodeController) runCheckpoint(ctx context.Context, pod *corev1.Pod, job
req := executor.CheckpointRequest{
ContainerID: containerID,
ContainerName: containerName,
CheckpointHash: checkpointHash,
CheckpointID: checkpointID,
CheckpointLocation: checkpointLocation,
CheckpointStorageType: checkpointStorageType,
NodeName: w.config.NodeName,
PodName: pod.Name,
PodNamespace: pod.Namespace,
......@@ -423,10 +408,10 @@ func (w *NodeController) runCheckpoint(ctx context.Context, pod *corev1.Pod, job
log.Error(err, "Checkpoint failed")
emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeWarning, "CheckpointFailed", err.Error())
// SIGKILL on failure: process is unrecoverable (CUDA locked), terminate immediately
if signalErr := common.SendSignalToPID(log, containerPID, syscall.SIGKILL, "checkpoint failed"); signalErr != nil {
if signalErr := snapshotruntime.SendSignalToPID(log, containerPID, syscall.SIGKILL, "checkpoint failed"); signalErr != nil {
log.Error(signalErr, "Failed to signal checkpoint failure to runtime process")
}
if statusErr := setCheckpointStatus("failed"); statusErr != nil {
if statusErr := setCheckpointStatus(snapshotprotocol.CheckpointStatusFailed); statusErr != nil {
return statusErr
}
return nil
......@@ -441,27 +426,27 @@ func (w *NodeController) runCheckpoint(ctx context.Context, pod *corev1.Pod, job
}
log.Error(err, "Checkpoint failed verification")
emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeWarning, "CheckpointFailed", err.Error())
if signalErr := common.SendSignalToPID(log, containerPID, syscall.SIGKILL, "checkpoint verification failed"); signalErr != nil {
if signalErr := snapshotruntime.SendSignalToPID(log, containerPID, syscall.SIGKILL, "checkpoint verification failed"); signalErr != nil {
log.Error(signalErr, "Failed to signal checkpoint verification failure to runtime process")
}
if statusErr := setCheckpointStatus("failed"); statusErr != nil {
if statusErr := setCheckpointStatus(snapshotprotocol.CheckpointStatusFailed); statusErr != nil {
return statusErr
}
return nil
}
// Step 2: SIGUSR1 on success: notify the workload that checkpoint completed
emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeNormal, "CheckpointSucceeded", fmt.Sprintf("Checkpoint completed: %s", checkpointHash))
if err := common.SendSignalToPID(log, containerPID, syscall.SIGUSR1, "checkpoint complete"); err != nil {
emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeNormal, "CheckpointSucceeded", fmt.Sprintf("Checkpoint completed: %s", checkpointID))
if err := snapshotruntime.SendSignalToPID(log, containerPID, syscall.SIGUSR1, "checkpoint complete"); err != nil {
log.Error(err, "Failed to signal checkpoint completion to runtime process")
emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeWarning, "CheckpointFailed", err.Error())
if statusErr := setCheckpointStatus("failed"); statusErr != nil {
if statusErr := setCheckpointStatus(snapshotprotocol.CheckpointStatusFailed); statusErr != nil {
return statusErr
}
return nil
}
if err := setCheckpointStatus("completed"); err != nil {
if err := setCheckpointStatus(snapshotprotocol.CheckpointStatusCompleted); err != nil {
return err
}
return nil
......@@ -473,7 +458,7 @@ func (w *NodeController) runCheckpoint(ctx context.Context, pod *corev1.Pod, job
// 3. SIGCONT the restored process to wake it up
// 4. Wait for the pod to become Ready
// 5. Mark the container instance as completed
func (w *NodeController) runRestore(ctx context.Context, pod *corev1.Pod, containerName, containerID, checkpointHash, checkpointLocation, checkpointStorageType, restoreAttemptKey string) error {
func (w *NodeController) runRestore(ctx context.Context, pod *corev1.Pod, containerName, containerID, checkpointID, checkpointLocation, restoreAttemptKey string) error {
releaseOnExit := true
defer func() {
if releaseOnExit {
......@@ -481,14 +466,14 @@ func (w *NodeController) runRestore(ctx context.Context, pod *corev1.Pod, contai
}
}()
podKey := fmt.Sprintf("%s/%s", pod.Namespace, pod.Name)
log := w.log.WithValues("pod", podKey, "checkpoint_hash", checkpointHash, "container_id", containerID)
log := w.log.WithValues("pod", podKey, "checkpoint_id", checkpointID, "container_id", containerID)
setRestoreStatus := func(value string) error {
annotations := map[string]string{
kubeAnnotationRestoreStatus: value,
kubeAnnotationRestoreContainerID: containerID,
snapshotprotocol.RestoreStatusAnnotation: value,
snapshotprotocol.RestoreContainerIDAnnotation: containerID,
}
if err := annotatePod(ctx, w.clientset, log, pod, annotations); err != nil {
if value == "completed" {
if value == snapshotprotocol.RestoreStatusCompleted {
releaseOnExit = false
return fmt.Errorf("failed to persist terminal restore status %q: %w", value, err)
}
......@@ -498,17 +483,16 @@ func (w *NodeController) runRestore(ctx context.Context, pod *corev1.Pod, contai
}
if err := annotatePod(ctx, w.clientset, log, pod, map[string]string{
kubeAnnotationRestoreStatus: "in_progress",
kubeAnnotationRestoreContainerID: containerID,
snapshotprotocol.RestoreStatusAnnotation: snapshotprotocol.RestoreStatusInProgress,
snapshotprotocol.RestoreContainerIDAnnotation: containerID,
}); err != nil {
return fmt.Errorf("failed to annotate pod with restore in_progress: %w", err)
}
// Step 1: Run the restore orchestrator (inspect + nsrestore)
req := executor.RestoreRequest{
CheckpointHash: checkpointHash,
CheckpointID: checkpointID,
CheckpointLocation: checkpointLocation,
CheckpointStorageType: checkpointStorageType,
NSRestorePath: w.config.Restore.NSRestorePath,
PodName: pod.Name,
PodNamespace: pod.Namespace,
......@@ -519,12 +503,12 @@ func (w *NodeController) runRestore(ctx context.Context, pod *corev1.Pod, contai
if err != nil {
log.Error(err, "External restore failed")
emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeWarning, "RestoreFailed", err.Error())
placeholderHostPID, _, pidErr := common.ResolveContainerByPod(ctx, w.containerd, pod.Name, pod.Namespace, containerName)
placeholderHostPID, _, pidErr := snapshotruntime.ResolveContainerByPod(ctx, w.containerd, pod.Name, pod.Namespace, containerName)
if pidErr != nil {
releaseOnExit = false
return fmt.Errorf("restore failed and placeholder PID could not be resolved: %w", pidErr)
}
if killErr := common.SendSignalToPID(log, placeholderHostPID, syscall.SIGKILL, "restore failed"); killErr != nil {
if killErr := snapshotruntime.SendSignalToPID(log, placeholderHostPID, syscall.SIGKILL, "restore failed"); killErr != nil {
releaseOnExit = false
return fmt.Errorf("restore failed and placeholder could not be killed: %w", killErr)
}
......@@ -532,17 +516,17 @@ func (w *NodeController) runRestore(ctx context.Context, pod *corev1.Pod, contai
}
// Step 2: SIGCONT the restored process via PID namespace
placeholderHostPID, _, err := common.ResolveContainerByPod(ctx, w.containerd, pod.Name, pod.Namespace, containerName)
placeholderHostPID, _, err := snapshotruntime.ResolveContainerByPod(ctx, w.containerd, pod.Name, pod.Namespace, containerName)
if err != nil {
log.Error(err, "Failed to resolve placeholder host PID for signaling")
emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeWarning, "RestoreFailed", err.Error())
releaseOnExit = false
return fmt.Errorf("failed to resolve placeholder host PID for signaling: %w", err)
}
if err := common.SendSignalViaPIDNamespace(ctx, log, placeholderHostPID, restoredPID, syscall.SIGCONT, "restore complete"); err != nil {
if err := snapshotruntime.SendSignalViaPIDNamespace(ctx, log, placeholderHostPID, restoredPID, syscall.SIGCONT, "restore complete"); err != nil {
log.Error(err, "Failed to signal restored runtime process")
emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeWarning, "RestoreFailed", err.Error())
if killErr := common.SendSignalToPID(log, placeholderHostPID, syscall.SIGKILL, "restore signaling failed"); killErr != nil {
if killErr := snapshotruntime.SendSignalToPID(log, placeholderHostPID, syscall.SIGKILL, "restore signaling failed"); killErr != nil {
log.Error(killErr, "Failed to kill placeholder after restore signaling failure")
}
releaseOnExit = false
......@@ -559,15 +543,15 @@ func (w *NodeController) runRestore(ctx context.Context, pod *corev1.Pod, contai
if err := waitForPodReady(readyCtx, w.clientset, pod.Namespace, pod.Name, containerName); err != nil {
log.Error(err, "Restore post-signal readiness check failed")
emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeWarning, "RestoreFailed", err.Error())
if killErr := common.SendSignalToPID(log, placeholderHostPID, syscall.SIGKILL, "restore readiness failed"); killErr != nil {
if killErr := snapshotruntime.SendSignalToPID(log, placeholderHostPID, syscall.SIGKILL, "restore readiness failed"); killErr != nil {
log.Error(killErr, "Failed to kill placeholder after restore readiness failure")
}
releaseOnExit = false
return fmt.Errorf("restore post-signal readiness check failed: %w", err)
}
emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeNormal, "RestoreSucceeded", fmt.Sprintf("Restore completed from checkpoint %s", checkpointHash))
if err := setRestoreStatus("completed"); err != nil {
emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeNormal, "RestoreSucceeded", fmt.Sprintf("Restore completed from checkpoint %s", checkpointID))
if err := setRestoreStatus(snapshotprotocol.RestoreStatusCompleted); err != nil {
return err
}
return nil
......@@ -589,19 +573,17 @@ func (w *NodeController) release(podKey string) {
delete(w.inFlight, podKey)
}
func checkpointStorageFromPod(pod *corev1.Pod) (string, string, error) {
checkpointLocation := strings.TrimSpace(pod.Annotations[kubeAnnotationCheckpointLocation])
if checkpointLocation == "" {
return "", "", fmt.Errorf("missing %s annotation", kubeAnnotationCheckpointLocation)
}
checkpointStorageType := strings.TrimSpace(pod.Annotations[kubeAnnotationCheckpointStorageType])
if checkpointStorageType == "" {
return "", "", fmt.Errorf("missing %s annotation", kubeAnnotationCheckpointStorageType)
}
if checkpointStorageType != "pvc" {
return "", "", fmt.Errorf("checkpoint storage type %q is not supported", checkpointStorageType)
func (w *NodeController) checkpointLocationFromPod(pod *corev1.Pod, checkpointID string) (string, error) {
resolvedStorage, err := snapshotprotocol.ResolveCheckpointStorage(
checkpointID,
strings.TrimSpace(pod.Annotations[snapshotprotocol.CheckpointArtifactVersionAnnotation]),
snapshotprotocol.Storage{
Type: w.config.Storage.Type,
BasePath: w.config.Storage.BasePath,
},
)
if err != nil {
return "", err
}
return checkpointLocation, checkpointStorageType, nil
return resolvedStorage.Location, nil
}
......@@ -17,7 +17,8 @@ import (
"k8s.io/client-go/kubernetes/fake"
clientgotesting "k8s.io/client-go/testing"
"github.com/ai-dynamo/dynamo/deploy/snapshot/pkg/types"
"github.com/ai-dynamo/dynamo/deploy/snapshot/internal/types"
snapshotprotocol "github.com/ai-dynamo/dynamo/deploy/snapshot/protocol"
)
const testNodeName = "test-node"
......@@ -31,6 +32,10 @@ func makeTestController(t *testing.T, objs ...runtime.Object) *NodeController {
return &NodeController{
config: &types.AgentConfig{
NodeName: testNodeName,
Storage: types.StorageSpec{
Type: snapshotprotocol.StorageTypePVC,
BasePath: t.TempDir(),
},
},
clientset: fake.NewClientset(objs...),
log: testr.New(t),
......@@ -187,11 +192,11 @@ func TestReconcileCheckpointPod(t *testing.T) {
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
labels := map[string]string{
kubeLabelIsCheckpointSource: "true",
snapshotprotocol.CheckpointSourceLabel: "true",
"batch.kubernetes.io/job-name": "checkpoint-job",
}
if tc.hash != "" {
labels[kubeLabelCheckpointHash] = tc.hash
labels[snapshotprotocol.CheckpointIDLabel] = tc.hash
}
job := &batchv1.Job{
......@@ -202,18 +207,11 @@ func TestReconcileCheckpointPod(t *testing.T) {
}
if tc.annotation != "" {
job.Annotations = map[string]string{
kubeAnnotationCheckpointStatus: tc.annotation,
snapshotprotocol.CheckpointStatusAnnotation: tc.annotation,
}
}
var annotations map[string]string
if tc.hash != "" {
annotations = map[string]string{
kubeAnnotationCheckpointLocation: "/checkpoints/" + tc.hash,
kubeAnnotationCheckpointStorageType: "pvc",
}
}
pod := makePod("test-pod", "default", tc.nodeName, tc.phase, tc.ready, labels, annotations)
pod := makePod("test-pod", "default", tc.nodeName, tc.phase, tc.ready, labels, nil)
objs := []runtime.Object{job}
if tc.lease != nil {
objs = append(objs, tc.lease)
......@@ -289,13 +287,13 @@ func TestReconcileRestorePod(t *testing.T) {
want: false,
},
{
name: "already ready",
name: "ready placeholder still restores",
nodeName: testNodeName,
phase: corev1.PodRunning,
ready: true,
hash: "abc123",
createDir: true,
want: false,
want: true,
},
{
name: "missing hash",
......@@ -382,29 +380,20 @@ func TestReconcileRestorePod(t *testing.T) {
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
labels := map[string]string{
kubeLabelIsRestoreTarget: "true",
snapshotprotocol.RestoreTargetLabel: "true",
}
if tc.hash != "" {
labels[kubeLabelCheckpointHash] = tc.hash
labels[snapshotprotocol.CheckpointIDLabel] = tc.hash
}
w := makeTestController(t)
checkpointDir := t.TempDir()
var annotations map[string]string
if tc.annotationStatus != "" {
annotations = map[string]string{
kubeAnnotationRestoreStatus: tc.annotationStatus,
kubeAnnotationRestoreContainerID: tc.annotationContainerID,
snapshotprotocol.RestoreStatusAnnotation: tc.annotationStatus,
snapshotprotocol.RestoreContainerIDAnnotation: tc.annotationContainerID,
}
}
if tc.hash != "" {
if annotations == nil {
annotations = make(map[string]string)
}
annotations[kubeAnnotationCheckpointLocation] = filepath.Join(checkpointDir, tc.hash)
annotations[kubeAnnotationCheckpointStorageType] = "pvc"
}
pod := makePod("test-pod", "default", tc.nodeName, tc.phase, tc.ready, labels, annotations)
pod.Status.ContainerStatuses = []corev1.ContainerStatus{{
......@@ -414,7 +403,7 @@ func TestReconcileRestorePod(t *testing.T) {
}}
if tc.createDir && tc.hash != "" {
dir := filepath.Join(checkpointDir, tc.hash)
dir := filepath.Join(w.config.Storage.BasePath, tc.hash, "versions", snapshotprotocol.DefaultCheckpointArtifactVersion)
if err := os.MkdirAll(dir, 0o755); err != nil {
t.Fatalf("failed to create checkpoint dir: %v", err)
}
......@@ -473,6 +462,10 @@ func TestRunCheckpointKeepsLeaseAndInFlightOnTerminalStatusPatchFailure(t *testi
w := &NodeController{
config: &types.AgentConfig{
NodeName: testNodeName,
Storage: types.StorageSpec{
Type: snapshotprotocol.StorageTypePVC,
BasePath: t.TempDir(),
},
},
clientset: clientset,
log: testr.New(t),
......@@ -483,7 +476,7 @@ func TestRunCheckpointKeepsLeaseAndInFlightOnTerminalStatusPatchFailure(t *testi
stopCh: make(chan struct{}),
}
err := w.runCheckpoint(context.Background(), pod, job, "abc123", filepath.Join(t.TempDir(), "abc123"), "pvc", "default/test-pod")
err := w.runCheckpoint(context.Background(), pod, job, "abc123", filepath.Join(t.TempDir(), "abc123"), "default/test-pod")
if err == nil {
t.Fatal("expected terminal checkpoint status update to fail")
}
......
......@@ -22,6 +22,20 @@ const (
checkpointLeaseRenewInterval = 10 * time.Second
)
func checkpointLeaseExpired(lease *coordinationv1.Lease, now time.Time) bool {
if lease == nil || lease.Spec.LeaseDurationSeconds == nil {
return true
}
last := lease.Spec.RenewTime
if last == nil {
last = lease.Spec.AcquireTime
}
if last == nil {
return true
}
return now.After(last.Time.Add(time.Duration(*lease.Spec.LeaseDurationSeconds) * time.Second))
}
func podFromInformerObj(obj interface{}) (*corev1.Pod, bool) {
if pod, ok := obj.(*corev1.Pod); ok {
return pod, true
......@@ -94,20 +108,6 @@ func getCheckpointJob(ctx context.Context, clientset kubernetes.Interface, pod *
return job, nil
}
func isLeaseExpired(lease *coordinationv1.Lease, now time.Time) bool {
if lease == nil || lease.Spec.LeaseDurationSeconds == nil {
return true
}
last := lease.Spec.RenewTime
if last == nil {
last = lease.Spec.AcquireTime
}
if last == nil {
return true
}
return now.After(last.Time.Add(time.Duration(*lease.Spec.LeaseDurationSeconds) * time.Second))
}
func acquireCheckpointLease(ctx context.Context, clientset kubernetes.Interface, log logr.Logger, job *batchv1.Job, holderIdentity string) (bool, error) {
leaseName := job.Name
now := metav1.NewMicroTime(time.Now())
......@@ -142,7 +142,7 @@ func acquireCheckpointLease(ctx context.Context, clientset kubernetes.Interface,
return true, nil
}
if !isLeaseExpired(existingLease, now.Time) &&
if !checkpointLeaseExpired(existingLease, now.Time) &&
existingLease.Spec.HolderIdentity != nil &&
*existingLease.Spec.HolderIdentity != holderIdentity {
return false, nil
......@@ -150,7 +150,7 @@ func acquireCheckpointLease(ctx context.Context, clientset kubernetes.Interface,
existingLease.Spec.HolderIdentity = &holderIdentity
existingLease.Spec.LeaseDurationSeconds = &leaseDurationSeconds
if existingLease.Spec.AcquireTime == nil || isLeaseExpired(existingLease, now.Time) {
if existingLease.Spec.AcquireTime == nil || checkpointLeaseExpired(existingLease, now.Time) {
existingLease.Spec.AcquireTime = &now
}
existingLease.Spec.RenewTime = &now
......
......@@ -12,8 +12,8 @@ import (
"github.com/go-logr/logr"
"google.golang.org/protobuf/proto"
"github.com/ai-dynamo/dynamo/deploy/snapshot/pkg/common"
"github.com/ai-dynamo/dynamo/deploy/snapshot/pkg/types"
snapshotruntime "github.com/ai-dynamo/dynamo/deploy/snapshot/internal/runtime"
"github.com/ai-dynamo/dynamo/deploy/snapshot/internal/types"
)
const (
......@@ -35,7 +35,7 @@ func BuildDumpOptions(
maskedPaths = state.OCISpec.Linux.MaskedPaths
}
externalized, skipped := common.BuildMountPolicy(state.Mounts, state.RootFS, maskedPaths)
externalized, skipped := snapshotruntime.BuildMountPolicy(state.Mounts, state.RootFS, maskedPaths)
log.V(1).Info("Resolved mount policy for CRIU dump",
"externalized_count", len(externalized),
"skipped_count", len(skipped),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment