Unverified Commit debda332 authored by Schwinn Saereesitthipitak's avatar Schwinn Saereesitthipitak Committed by GitHub
Browse files

refactor(snapshot): tighten operator protocol boundary (#8018)

parent f49df565
......@@ -132,7 +132,7 @@ func TestCreateOrGetAutoCheckpointDeduplicatesConcurrentSameHashCheckpoint(t *te
Name: "friendly-checkpoint",
Namespace: testNamespace,
Labels: map[string]string{
consts.KubeLabelCheckpointID: hash,
snapshotprotocol.CheckpointIDLabel: hash,
},
},
Spec: nvidiacomv1alpha1.DynamoCheckpointSpec{
......@@ -177,7 +177,7 @@ func TestCreateOrGetAutoCheckpointSetsDefaultArtifactVersion(t *testing.T) {
ckpt, err := CreateOrGetAutoCheckpoint(ctx, c, testNamespace, testIdentity(), corev1.PodTemplateSpec{})
require.NoError(t, err)
require.NotNil(t, ckpt.Annotations)
assert.Equal(t, consts.DefaultCheckpointArtifactVersion, ckpt.Annotations[consts.KubeAnnotationCheckpointArtifactVersion])
assert.Equal(t, snapshotprotocol.DefaultCheckpointArtifactVersion, ckpt.Annotations[snapshotprotocol.CheckpointArtifactVersionAnnotation])
}
// --- InjectCheckpointIntoPodSpec tests ---
......
......@@ -86,7 +86,7 @@ func InjectCheckpointIntoPodSpec(
mainContainer,
info.Hash,
info.ArtifactVersion,
commonconsts.SeccompProfilePath,
snapshotprotocol.DefaultSeccompLocalhostProfile,
info.Ready,
); err != nil {
return err
......
......@@ -22,7 +22,6 @@ import (
"fmt"
nvidiacomv1alpha1 "github.com/ai-dynamo/dynamo/deploy/operator/api/v1alpha1"
"github.com/ai-dynamo/dynamo/deploy/operator/internal/consts"
snapshotprotocol "github.com/ai-dynamo/dynamo/deploy/snapshot/protocol"
corev1 "k8s.io/api/core/v1"
apierrors "k8s.io/apimachinery/pkg/api/errors"
......@@ -56,7 +55,7 @@ func FindCheckpointByIdentityHash(
ctx,
checkpoints,
client.InNamespace(namespace),
client.MatchingLabels{consts.KubeLabelCheckpointID: hash},
client.MatchingLabels{snapshotprotocol.CheckpointIDLabel: hash},
); err != nil {
return nil, fmt.Errorf("failed to list checkpoints by hash label: %w", err)
}
......@@ -119,7 +118,7 @@ func CreateOrGetAutoCheckpoint(
Name: fmt.Sprintf("checkpoint-%s", hash),
Namespace: namespace,
Labels: map[string]string{
consts.KubeLabelCheckpointID: hash,
snapshotprotocol.CheckpointIDLabel: hash,
},
Annotations: map[string]string{
snapshotprotocol.CheckpointArtifactVersionAnnotation: snapshotprotocol.DefaultCheckpointArtifactVersion,
......
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
package checkpointjob
import (
"testing"
snapshotprotocol "github.com/ai-dynamo/dynamo/deploy/snapshot/protocol"
)
func TestDesiredCheckpointJobName(t *testing.T) {
name := DesiredCheckpointJobName("abc123def4567890", map[string]string{
snapshotprotocol.CheckpointArtifactVersionAnnotation: "2",
})
if name != "checkpoint-job-abc123def4567890-2" {
t.Fatalf("unexpected checkpoint job name: %s", name)
}
defaultName := DesiredCheckpointJobName("abc123def4567890", nil)
if defaultName != "checkpoint-job-abc123def4567890-"+snapshotprotocol.DefaultCheckpointArtifactVersion {
t.Fatalf("unexpected default checkpoint job name: %s", defaultName)
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
package checkpointjob
import (
snapshotprotocol "github.com/ai-dynamo/dynamo/deploy/snapshot/protocol"
batchv1 "k8s.io/api/batch/v1"
corev1 "k8s.io/api/core/v1"
)
type ObservationPhase string
const (
ObservationPhaseRunning ObservationPhase = "running"
ObservationPhaseWaitingForConfirmation ObservationPhase = "waiting_for_confirmation"
ObservationPhaseReady ObservationPhase = "ready"
ObservationPhaseFailed ObservationPhase = "failed"
)
type Observation struct {
Phase ObservationPhase
Reason string
Message string
}
func Observe(job *batchv1.Job, checkpointWorkerActive bool) Observation {
jobComplete := false
jobFailed := false
for _, condition := range job.Status.Conditions {
if condition.Status != corev1.ConditionTrue {
continue
}
if condition.Type == batchv1.JobComplete {
jobComplete = true
continue
}
if condition.Type == batchv1.JobFailed {
jobFailed = true
}
}
status := job.Annotations[snapshotprotocol.CheckpointStatusAnnotation]
if status == snapshotprotocol.CheckpointStatusFailed {
observation := Observation{
Phase: ObservationPhaseFailed,
Reason: "JobFailed",
Message: "Checkpoint job failed",
}
if jobComplete {
observation.Reason = "CheckpointVerificationFailed"
observation.Message = "Checkpoint job completed but snapshot-agent reported checkpoint failure"
}
return observation
}
if jobComplete {
if status == snapshotprotocol.CheckpointStatusCompleted {
return Observation{
Phase: ObservationPhaseReady,
Reason: "JobSucceeded",
Message: "Checkpoint job completed successfully",
}
}
if checkpointWorkerActive {
return Observation{Phase: ObservationPhaseWaitingForConfirmation}
}
return Observation{
Phase: ObservationPhaseFailed,
Reason: "CheckpointVerificationFailed",
Message: "Checkpoint job completed without snapshot-agent completion confirmation",
}
}
if jobFailed {
return Observation{
Phase: ObservationPhaseFailed,
Reason: "JobFailed",
Message: "Checkpoint job failed",
}
}
return Observation{Phase: ObservationPhaseRunning}
}
......@@ -139,26 +139,8 @@ const (
ResourceStateNotReady = "not_ready"
ResourceStateUnknown = "unknown"
// Checkpoint/restore constants
// CROSS-REFERENCE: Some constants below are duplicated in deploy/snapshot/protocol.
// If you change a value here, update there too.
// Kubernetes labels
KubeLabelIsCheckpointSource = "nvidia.com/snapshot-is-checkpoint-source" // Pod label that triggers DaemonSet auto-checkpoint
KubeLabelCheckpointID = "nvidia.com/snapshot-checkpoint-id" // Checkpoint identity label; the operator stores the resolved identity hash as the value
KubeLabelIsRestoreTarget = "nvidia.com/snapshot-is-restore-target" // Pod label that triggers DaemonSet auto-restore
KubeAnnotationCheckpointArtifactVersion = "nvidia.com/snapshot-artifact-version" // Checkpoint artifact generation; changing it triggers a new immutable capture attempt
DefaultCheckpointArtifactVersion = "1"
DefaultCheckpointJobTTLSeconds = int32(300)
// Environment variables injected into pods
EnvReadyForCheckpointFile = "DYN_READY_FOR_CHECKPOINT_FILE" // Ready-for-checkpoint file path — checkpoint job pods
// Checkpoint pod-internal constants
CheckpointVolumeName = "checkpoint-storage" // Pod-internal volume name for checkpoint PVC
// SeccompProfilePath is the localhost seccomp profile that blocks io_uring syscalls.
// Deployed to nodes by the snapshot DaemonSet init container.
SeccompProfilePath = "profiles/block-iouring.json"
// Pod identity (Downward API) ---
// After CRIU restore, env vars contain stale values from the checkpoint pod.
......
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
package checkpointjob
package controller
import (
"fmt"
......@@ -17,10 +17,6 @@ import (
"k8s.io/apimachinery/pkg/api/resource"
)
func DesiredCheckpointJobName(identityHash string, annotations map[string]string) string {
return "checkpoint-job-" + identityHash + "-" + snapshotprotocol.ArtifactVersion(annotations[snapshotprotocol.CheckpointArtifactVersionAnnotation])
}
func buildCheckpointWorkerDefaultEnv(
ckpt *nvidiacomv1alpha1.DynamoCheckpoint,
podTemplate *corev1.PodTemplateSpec,
......@@ -50,7 +46,7 @@ func buildCheckpointWorkerDefaultEnv(
return defaultContainer.Env
}
func BuildCheckpointJob(
func buildCheckpointJob(
config *configv1alpha1.OperatorConfiguration,
ckpt *nvidiacomv1alpha1.DynamoCheckpoint,
jobName string,
......@@ -118,7 +114,7 @@ func BuildCheckpointJob(
Namespace: ckpt.Namespace,
CheckpointID: hash,
ArtifactVersion: snapshotprotocol.ArtifactVersion(ckpt.Annotations[snapshotprotocol.CheckpointArtifactVersionAnnotation]),
SeccompProfile: consts.SeccompProfilePath,
SeccompProfile: snapshotprotocol.DefaultSeccompLocalhostProfile,
Name: jobName,
ActiveDeadlineSeconds: activeDeadlineSeconds,
TTLSecondsAfterFinish: &ttlSecondsAfterFinish,
......
......@@ -39,9 +39,8 @@ import (
configv1alpha1 "github.com/ai-dynamo/dynamo/deploy/operator/api/config/v1alpha1"
nvidiacomv1alpha1 "github.com/ai-dynamo/dynamo/deploy/operator/api/v1alpha1"
"github.com/ai-dynamo/dynamo/deploy/operator/internal/checkpoint"
"github.com/ai-dynamo/dynamo/deploy/operator/internal/checkpointjob"
"github.com/ai-dynamo/dynamo/deploy/operator/internal/consts"
commonController "github.com/ai-dynamo/dynamo/deploy/operator/internal/controller_common"
snapshotprotocol "github.com/ai-dynamo/dynamo/deploy/snapshot/protocol"
)
// CheckpointReconciler reconciles a DynamoCheckpoint object
......@@ -86,8 +85,8 @@ func (r *CheckpointReconciler) Reconcile(ctx context.Context, req ctrl.Request)
if ckpt.Labels == nil {
ckpt.Labels = map[string]string{}
}
if ckpt.Labels[consts.KubeLabelCheckpointID] != identityHash {
ckpt.Labels[consts.KubeLabelCheckpointID] = identityHash
if ckpt.Labels[snapshotprotocol.CheckpointIDLabel] != identityHash {
ckpt.Labels[snapshotprotocol.CheckpointIDLabel] = identityHash
if err := r.Update(ctx, ckpt); err != nil {
return ctrl.Result{}, err
}
......@@ -117,7 +116,10 @@ func (r *CheckpointReconciler) Reconcile(ctx context.Context, req ctrl.Request)
}
return ctrl.Result{}, nil
}
desiredJobName := checkpointjob.DesiredCheckpointJobName(identityHash, ckpt.Annotations)
desiredJobName := snapshotprotocol.GetCheckpointJobName(
identityHash,
ckpt.Annotations[snapshotprotocol.CheckpointArtifactVersionAnnotation],
)
switch ckpt.Status.Phase {
case "", nvidiacomv1alpha1.DynamoCheckpointPhasePending, nvidiacomv1alpha1.DynamoCheckpointPhaseCreating, nvidiacomv1alpha1.DynamoCheckpointPhaseReady, nvidiacomv1alpha1.DynamoCheckpointPhaseFailed:
default:
......@@ -181,11 +183,14 @@ func (r *CheckpointReconciler) handlePending(ctx context.Context, ckpt *nvidiaco
return ctrl.Result{}, fmt.Errorf("failed to compute checkpoint identity hash: %w", err)
}
}
jobName := checkpointjob.DesiredCheckpointJobName(hash, ckpt.Annotations)
jobName := snapshotprotocol.GetCheckpointJobName(
hash,
ckpt.Annotations[snapshotprotocol.CheckpointArtifactVersionAnnotation],
)
// Use SyncResource to create/update the checkpoint Job
modified, _, err := commonController.SyncResource(ctx, r, ckpt, func(ctx context.Context) (*batchv1.Job, bool, error) {
job, err := checkpointjob.BuildCheckpointJob(r.Config, ckpt, jobName)
job, err := buildCheckpointJob(r.Config, ckpt, jobName)
return job, false, err
})
if err != nil {
......@@ -276,12 +281,12 @@ func (r *CheckpointReconciler) handleCreating(ctx context.Context, ckpt *nvidiac
}
}
observation := checkpointjob.Observe(job, checkpointWorkerActive)
observation := snapshotprotocol.ObserveCheckpointJob(job, checkpointWorkerActive)
switch observation.Phase {
case checkpointjob.ObservationPhaseWaitingForConfirmation:
case snapshotprotocol.CheckpointObservationPhaseWaitingForConfirmation:
logger.V(1).Info("Checkpoint job is complete but checkpoint worker is still active; waiting for terminal watcher status", "job", job.Name)
return ctrl.Result{RequeueAfter: time.Second}, nil
case checkpointjob.ObservationPhaseReady:
case snapshotprotocol.CheckpointObservationPhaseReady:
logger.Info("Checkpoint Job succeeded", "job", job.Name)
r.Recorder.Event(ckpt, corev1.EventTypeNormal, "CheckpointReady", observation.Message)
......@@ -300,7 +305,7 @@ func (r *CheckpointReconciler) handleCreating(ctx context.Context, ckpt *nvidiac
return ctrl.Result{}, err
}
return ctrl.Result{}, nil
case checkpointjob.ObservationPhaseFailed:
case snapshotprotocol.CheckpointObservationPhaseFailed:
logger.Info("Checkpoint Job failed", "job", job.Name, "message", observation.Message)
r.Recorder.Event(ckpt, corev1.EventTypeWarning, "CheckpointFailed", observation.Message)
......
......@@ -25,7 +25,6 @@ import (
configv1alpha1 "github.com/ai-dynamo/dynamo/deploy/operator/api/config/v1alpha1"
nvidiacomv1alpha1 "github.com/ai-dynamo/dynamo/deploy/operator/api/v1alpha1"
"github.com/ai-dynamo/dynamo/deploy/operator/internal/checkpoint"
"github.com/ai-dynamo/dynamo/deploy/operator/internal/checkpointjob"
"github.com/ai-dynamo/dynamo/deploy/operator/internal/consts"
snapshotprotocol "github.com/ai-dynamo/dynamo/deploy/snapshot/protocol"
"github.com/stretchr/testify/assert"
......@@ -60,9 +59,7 @@ var testHash = func() string {
return hash
}()
var defaultCheckpointJobName = checkpointjob.DesiredCheckpointJobName(testHash, map[string]string{
snapshotprotocol.CheckpointArtifactVersionAnnotation: snapshotprotocol.DefaultCheckpointArtifactVersion,
})
var defaultCheckpointJobName = snapshotprotocol.GetCheckpointJobName(testHash, snapshotprotocol.DefaultCheckpointArtifactVersion)
func checkpointTestScheme() *runtime.Scheme {
s := runtime.NewScheme()
......@@ -140,15 +137,15 @@ func TestBuildCheckpointJob(t *testing.T) {
}
r := makeCheckpointReconciler(s, ckpt)
job, err := checkpointjob.BuildCheckpointJob(r.Config, ckpt, defaultCheckpointJobName)
job, err := buildCheckpointJob(r.Config, ckpt, defaultCheckpointJobName)
require.NoError(t, err)
podSpec := job.Spec.Template.Spec
main := podSpec.Containers[0]
// Job and pod template labels
assert.Equal(t, testHash, job.Labels[consts.KubeLabelCheckpointID])
assert.Equal(t, "true", job.Spec.Template.Labels[consts.KubeLabelIsCheckpointSource])
assert.Equal(t, testHash, job.Spec.Template.Labels[consts.KubeLabelCheckpointID])
assert.Equal(t, testHash, job.Labels[snapshotprotocol.CheckpointIDLabel])
assert.Equal(t, "true", job.Spec.Template.Labels[snapshotprotocol.CheckpointSourceLabel])
assert.Equal(t, testHash, job.Spec.Template.Labels[snapshotprotocol.CheckpointIDLabel])
// Env vars (checkpoint-specific + user-provided preserved)
envMap := make(map[string]string, len(main.Env))
......@@ -180,7 +177,7 @@ func TestBuildCheckpointJob(t *testing.T) {
require.NotNil(t, podSpec.SecurityContext)
require.NotNil(t, podSpec.SecurityContext.SeccompProfile)
assert.Equal(t, corev1.SeccompProfileTypeLocalhost, podSpec.SecurityContext.SeccompProfile.Type)
assert.Equal(t, consts.SeccompProfilePath, *podSpec.SecurityContext.SeccompProfile.LocalhostProfile)
assert.Equal(t, snapshotprotocol.DefaultSeccompLocalhostProfile, *podSpec.SecurityContext.SeccompProfile.LocalhostProfile)
require.NotNil(t, podSpec.SecurityContext.RunAsUser)
assert.Equal(t, int64(1234), *podSpec.SecurityContext.RunAsUser)
require.NotNil(t, podSpec.SecurityContext.FSGroup)
......@@ -197,14 +194,14 @@ func TestBuildCheckpointJob(t *testing.T) {
for _, v := range podSpec.Volumes {
volNames[v.Name] = true
}
assert.False(t, volNames[consts.CheckpointVolumeName])
assert.False(t, volNames[snapshotprotocol.CheckpointVolumeName])
assert.True(t, volNames[consts.PodInfoVolumeName])
mountPaths := make(map[string]string)
for _, m := range main.VolumeMounts {
mountPaths[m.Name] = m.MountPath
}
_, hasCheckpointMount := mountPaths[consts.CheckpointVolumeName]
_, hasCheckpointMount := mountPaths[snapshotprotocol.CheckpointVolumeName]
assert.False(t, hasCheckpointMount)
assert.Equal(t, consts.PodInfoMountPath, mountPaths[consts.PodInfoVolumeName])
assert.Equal(t, consts.DefaultSharedMemoryMountPath, mountPaths[consts.KubeValueNameSharedMemory])
......@@ -237,7 +234,7 @@ func TestBuildCheckpointJob(t *testing.T) {
backoff := int32(5)
ckpt.Spec.Job.ActiveDeadlineSeconds = &deadline
ckpt.Spec.Job.BackoffLimit = &backoff //nolint:staticcheck // Compatibility test: deprecated field must remain ignored by checkpoint Jobs.
job, err = checkpointjob.BuildCheckpointJob(r.Config, ckpt, defaultCheckpointJobName)
job, err = buildCheckpointJob(r.Config, ckpt, defaultCheckpointJobName)
require.NoError(t, err)
assert.Equal(t, int64(7200), *job.Spec.ActiveDeadlineSeconds)
assert.Equal(t, int32(0), *job.Spec.BackoffLimit)
......@@ -248,7 +245,7 @@ func TestBuildCheckpointJob(t *testing.T) {
corev1.ResourceName("nvidia.com/gpu"): resource.MustParse("2"),
},
}
job, err = checkpointjob.BuildCheckpointJob(r.Config, ckpt, defaultCheckpointJobName)
job, err = buildCheckpointJob(r.Config, ckpt, defaultCheckpointJobName)
require.NoError(t, err)
assert.Equal(t, []string{"cuda-checkpoint"}, job.Spec.Template.Spec.Containers[0].Command)
assert.Equal(t, []string{"--launch-job", "python3", "-m", "dynamo.vllm"}, job.Spec.Template.Spec.Containers[0].Args)
......@@ -273,7 +270,7 @@ func TestBuildCheckpointJobInjectsStandardEnvVars(t *testing.T) {
customShmSize := resource.MustParse("16Gi")
ckpt.Spec.Job.SharedMemory = &nvidiacomv1alpha1.SharedMemorySpec{Size: customShmSize}
job, err := checkpointjob.BuildCheckpointJob(r.Config, ckpt, defaultCheckpointJobName)
job, err := buildCheckpointJob(r.Config, ckpt, defaultCheckpointJobName)
require.NoError(t, err)
foundCustomShmVolume := false
for _, v := range job.Spec.Template.Spec.Volumes {
......@@ -326,7 +323,7 @@ func TestCheckpointReconciler_Reconcile(t *testing.T) {
assert.Equal(t, nvidiacomv1alpha1.DynamoCheckpointPhasePending, updated.Status.Phase)
assert.Equal(t, testHash, updated.Status.IdentityHash)
assert.Empty(t, updated.Status.Message)
assert.Equal(t, testHash, updated.Labels[consts.KubeLabelCheckpointID])
assert.Equal(t, testHash, updated.Labels[snapshotprotocol.CheckpointIDLabel])
})
t.Run("Ready phase is a no-op", func(t *testing.T) {
......@@ -352,7 +349,7 @@ func TestCheckpointReconciler_Reconcile(t *testing.T) {
updated := &nvidiacomv1alpha1.DynamoCheckpoint{}
require.NoError(t, r.Get(ctx, types.NamespacedName{Name: friendlyCheckpointName, Namespace: testNamespace}, updated))
assert.Equal(t, testHash, updated.Labels[consts.KubeLabelCheckpointID])
assert.Equal(t, testHash, updated.Labels[snapshotprotocol.CheckpointIDLabel])
assert.Equal(t, testHash, updated.Status.IdentityHash)
})
......@@ -375,7 +372,7 @@ func TestCheckpointReconciler_Reconcile(t *testing.T) {
ckpt := makeTestCheckpoint(nvidiacomv1alpha1.DynamoCheckpointPhaseReady)
ckpt.Status.IdentityHash = testHash
ckpt.Status.JobName = defaultCheckpointJobName
ckpt.Annotations = map[string]string{consts.KubeAnnotationCheckpointArtifactVersion: "2"}
ckpt.Annotations = map[string]string{snapshotprotocol.CheckpointArtifactVersionAnnotation: "2"}
r := makeCheckpointReconciler(s, ckpt)
_, err := r.Reconcile(ctx, ctrl.Request{
......@@ -584,7 +581,7 @@ func TestCheckpointReconciler_HandleCreating(t *testing.T) {
t.Run("in-flight version changes do not relabel the running job's artifact", func(t *testing.T) {
ckpt := makeCreatingCkpt(testHash, defaultCheckpointJobName)
ckpt.Annotations = map[string]string{consts.KubeAnnotationCheckpointArtifactVersion: "2"}
ckpt.Annotations = map[string]string{snapshotprotocol.CheckpointArtifactVersionAnnotation: "2"}
job := &batchv1.Job{
ObjectMeta: metav1.ObjectMeta{
Name: defaultCheckpointJobName,
......
......@@ -41,6 +41,7 @@ import (
commonController "github.com/ai-dynamo/dynamo/deploy/operator/internal/controller_common"
"github.com/ai-dynamo/dynamo/deploy/operator/internal/dynamo"
"github.com/ai-dynamo/dynamo/deploy/operator/internal/observability"
snapshotprotocol "github.com/ai-dynamo/dynamo/deploy/snapshot/protocol"
networkingv1beta1 "istio.io/client-go/pkg/apis/networking/v1beta1"
k8serrors "k8s.io/apimachinery/pkg/api/errors"
"k8s.io/apimachinery/pkg/api/meta"
......@@ -961,7 +962,7 @@ func (r *DynamoComponentDeploymentReconciler) generateDeployment(ctx context.Con
// the old pod is terminated before the restore placeholder is started.
if podTemplateSpec != nil &&
podTemplateSpec.Labels != nil &&
podTemplateSpec.Labels[commonconsts.KubeLabelIsRestoreTarget] == commonconsts.KubeLabelValueTrue {
podTemplateSpec.Labels[snapshotprotocol.RestoreTargetLabel] == commonconsts.KubeLabelValueTrue {
strategy = appsv1.DeploymentStrategy{
Type: appsv1.RecreateDeploymentStrategyType,
}
......
......@@ -1306,7 +1306,7 @@ func TestDynamoComponentDeploymentReconciler_generatePodTemplateSpec_RestoreLabe
Labels: map[string]string{
commonconsts.KubeLabelDynamoGraphDeploymentName: "test-dgd",
commonconsts.KubeLabelDynamoWorkerHash: "workerhash",
commonconsts.KubeLabelIsRestoreTarget: commonconsts.KubeLabelValueTrue,
snapshotprotocol.RestoreTargetLabel: commonconsts.KubeLabelValueTrue,
},
Checkpoint: &v1alpha1.ServiceCheckpointConfig{
Enabled: true,
......@@ -1368,11 +1368,11 @@ func TestDynamoComponentDeploymentReconciler_generatePodTemplateSpec_RestoreLabe
t.Fatalf("generatePodTemplateSpec failed: %v", err)
}
if got := podTemplateSpec.Labels[commonconsts.KubeLabelIsRestoreTarget]; got != commonconsts.KubeLabelValueTrue {
t.Fatalf("expected %s label to be true, got %q", commonconsts.KubeLabelIsRestoreTarget, got)
if got := podTemplateSpec.Labels[snapshotprotocol.RestoreTargetLabel]; got != commonconsts.KubeLabelValueTrue {
t.Fatalf("expected %s label to be true, got %q", snapshotprotocol.RestoreTargetLabel, got)
}
if got := podTemplateSpec.Labels[commonconsts.KubeLabelCheckpointID]; got != checkpointName {
t.Fatalf("expected %s to be checkpoint id, got %q", commonconsts.KubeLabelCheckpointID, got)
if got := podTemplateSpec.Labels[snapshotprotocol.CheckpointIDLabel]; got != checkpointName {
t.Fatalf("expected %s to be checkpoint id, got %q", snapshotprotocol.CheckpointIDLabel, got)
}
})
......@@ -1454,11 +1454,11 @@ func TestDynamoComponentDeploymentReconciler_generatePodTemplateSpec_RestoreLabe
t.Fatalf("generatePodTemplateSpec failed: %v", err)
}
if _, ok := podTemplateSpec.Labels[commonconsts.KubeLabelIsRestoreTarget]; ok {
t.Fatalf("did not expect %s label when checkpoint is not ready", commonconsts.KubeLabelIsRestoreTarget)
if _, ok := podTemplateSpec.Labels[snapshotprotocol.RestoreTargetLabel]; ok {
t.Fatalf("did not expect %s label when checkpoint is not ready", snapshotprotocol.RestoreTargetLabel)
}
if _, ok := podTemplateSpec.Labels[commonconsts.KubeLabelCheckpointID]; ok {
t.Fatalf("did not expect %s label when checkpoint is not ready", commonconsts.KubeLabelCheckpointID)
if _, ok := podTemplateSpec.Labels[snapshotprotocol.CheckpointIDLabel]; ok {
t.Fatalf("did not expect %s label when checkpoint is not ready", snapshotprotocol.CheckpointIDLabel)
}
})
}
......
......@@ -31,6 +31,7 @@ import (
"github.com/ai-dynamo/dynamo/deploy/operator/internal/checkpoint"
commonconsts "github.com/ai-dynamo/dynamo/deploy/operator/internal/consts"
"github.com/ai-dynamo/dynamo/deploy/operator/internal/controller_common"
snapshotprotocol "github.com/ai-dynamo/dynamo/deploy/snapshot/protocol"
grovev1alpha1 "github.com/ai-dynamo/grove/operator/api/core/v1alpha1"
"github.com/google/go-cmp/cmp"
"github.com/stretchr/testify/assert"
......@@ -6860,12 +6861,12 @@ func TestGenerateLabels_RemovesStaleRestoreLabelsWhenCheckpointNotReady(t *testi
DynamoNamespace: ptr.To("default-test-dgd"),
Labels: map[string]string{
"user-label": "keep",
commonconsts.KubeLabelIsRestoreTarget: commonconsts.KubeLabelValueTrue,
snapshotprotocol.RestoreTargetLabel: commonconsts.KubeLabelValueTrue,
},
ExtraPodMetadata: &v1alpha1.ExtraPodMetadata{
Labels: map[string]string{
"extra-label": "keep-too",
commonconsts.KubeLabelCheckpointID: "stale-hash",
snapshotprotocol.CheckpointIDLabel: "stale-hash",
},
},
},
......@@ -6883,8 +6884,8 @@ func TestGenerateLabels_RemovesStaleRestoreLabelsWhenCheckpointNotReady(t *testi
})
assert.Equal(t, "keep", labels["user-label"])
assert.Equal(t, "keep-too", labels["extra-label"])
_, hasRestoreTarget := labels[commonconsts.KubeLabelIsRestoreTarget]
_, hasCheckpointHash := labels[commonconsts.KubeLabelCheckpointID]
_, hasRestoreTarget := labels[snapshotprotocol.RestoreTargetLabel]
_, hasCheckpointHash := labels[snapshotprotocol.CheckpointIDLabel]
assert.False(t, hasRestoreTarget)
assert.False(t, hasCheckpointHash)
}
......@@ -6895,11 +6896,11 @@ func TestGenerateLabels_OverwritesStaleRestoreLabelsWhenCheckpointReady(t *testi
ComponentType: commonconsts.ComponentTypeWorker,
DynamoNamespace: ptr.To("default-test-dgd"),
Labels: map[string]string{
commonconsts.KubeLabelIsRestoreTarget: "false",
snapshotprotocol.RestoreTargetLabel: "false",
},
ExtraPodMetadata: &v1alpha1.ExtraPodMetadata{
Labels: map[string]string{
commonconsts.KubeLabelCheckpointID: "stale-hash",
snapshotprotocol.CheckpointIDLabel: "stale-hash",
},
},
},
......@@ -6915,8 +6916,8 @@ func TestGenerateLabels_OverwritesStaleRestoreLabelsWhenCheckpointReady(t *testi
Ready: true,
Hash: "resolved-hash",
})
assert.Equal(t, commonconsts.KubeLabelValueTrue, labels[commonconsts.KubeLabelIsRestoreTarget])
assert.Equal(t, "resolved-hash", labels[commonconsts.KubeLabelCheckpointID])
assert.Equal(t, commonconsts.KubeLabelValueTrue, labels[snapshotprotocol.RestoreTargetLabel])
assert.Equal(t, "resolved-hash", labels[snapshotprotocol.CheckpointIDLabel])
}
func TestGenerateLabels_ReassertsRestoreIdentityLabelsAfterMetadataMerge(t *testing.T) {
......
......@@ -23,6 +23,25 @@ type CheckpointJobOptions struct {
WrapLaunchJob bool
}
type CheckpointObservationPhase string
const (
CheckpointObservationPhaseRunning CheckpointObservationPhase = "running"
CheckpointObservationPhaseWaitingForConfirmation CheckpointObservationPhase = "waiting_for_confirmation"
CheckpointObservationPhaseReady CheckpointObservationPhase = "ready"
CheckpointObservationPhaseFailed CheckpointObservationPhase = "failed"
)
type CheckpointObservation struct {
Phase CheckpointObservationPhase
Reason string
Message string
}
func GetCheckpointJobName(checkpointID string, artifactVersion string) string {
return "checkpoint-job-" + checkpointID + "-" + ArtifactVersion(artifactVersion)
}
func NewCheckpointJob(podTemplate *corev1.PodTemplateSpec, opts CheckpointJobOptions) (*batchv1.Job, error) {
podTemplate = podTemplate.DeepCopy()
if podTemplate.Labels == nil {
......@@ -67,6 +86,65 @@ func NewCheckpointJob(podTemplate *corev1.PodTemplateSpec, opts CheckpointJobOpt
}, nil
}
func ObserveCheckpointJob(job *batchv1.Job, checkpointWorkerActive bool) CheckpointObservation {
jobComplete := false
jobFailed := false
for _, condition := range job.Status.Conditions {
if condition.Status != corev1.ConditionTrue {
continue
}
if condition.Type == batchv1.JobComplete {
jobComplete = true
continue
}
if condition.Type == batchv1.JobFailed {
jobFailed = true
}
}
status := job.Annotations[CheckpointStatusAnnotation]
if status == CheckpointStatusFailed {
observation := CheckpointObservation{
Phase: CheckpointObservationPhaseFailed,
Reason: "JobFailed",
Message: "Checkpoint job failed",
}
if jobComplete {
observation.Reason = "CheckpointVerificationFailed"
observation.Message = "Checkpoint job completed but snapshot-agent reported checkpoint failure"
}
return observation
}
if jobComplete {
if status == CheckpointStatusCompleted {
return CheckpointObservation{
Phase: CheckpointObservationPhaseReady,
Reason: "JobSucceeded",
Message: "Checkpoint job completed successfully",
}
}
if checkpointWorkerActive {
return CheckpointObservation{Phase: CheckpointObservationPhaseWaitingForConfirmation}
}
return CheckpointObservation{
Phase: CheckpointObservationPhaseFailed,
Reason: "CheckpointVerificationFailed",
Message: "Checkpoint job completed without snapshot-agent completion confirmation",
}
}
if jobFailed {
return CheckpointObservation{
Phase: CheckpointObservationPhaseFailed,
Reason: "JobFailed",
Message: "Checkpoint job failed",
}
}
return CheckpointObservation{Phase: CheckpointObservationPhaseRunning}
}
func EnsureLocalhostSeccompProfile(podSpec *corev1.PodSpec, profile string) {
if podSpec.SecurityContext == nil {
podSpec.SecurityContext = &corev1.PodSecurityContext{}
......
......@@ -86,3 +86,15 @@ func TestNewCheckpointJob(t *testing.T) {
t.Fatalf("unexpected ttlSecondsAfterFinished: %#v", job.Spec.TTLSecondsAfterFinished)
}
}
func TestGetCheckpointJobName(t *testing.T) {
name := GetCheckpointJobName("abc123def4567890", "2")
if name != "checkpoint-job-abc123def4567890-2" {
t.Fatalf("unexpected checkpoint job name: %s", name)
}
defaultName := GetCheckpointJobName("abc123def4567890", "")
if defaultName != "checkpoint-job-abc123def4567890-"+DefaultCheckpointArtifactVersion {
t.Fatalf("unexpected default checkpoint job name: %s", defaultName)
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
package checkpointjob
package protocol
import (
"testing"
snapshotprotocol "github.com/ai-dynamo/dynamo/deploy/snapshot/protocol"
batchv1 "k8s.io/api/batch/v1"
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
)
func TestObserve(t *testing.T) {
func TestObserveCheckpointJob(t *testing.T) {
makeJob := func(annotation string, conditions ...batchv1.JobCondition) *batchv1.Job {
job := &batchv1.Job{
ObjectMeta: metav1.ObjectMeta{
......@@ -23,7 +22,7 @@ func TestObserve(t *testing.T) {
},
}
if annotation != "" {
job.Annotations[snapshotprotocol.CheckpointStatusAnnotation] = annotation
job.Annotations[CheckpointStatusAnnotation] = annotation
}
return job
}
......@@ -32,22 +31,22 @@ func TestObserve(t *testing.T) {
name string
job *batchv1.Job
checkpointWorkerActive bool
wantPhase ObservationPhase
wantPhase CheckpointObservationPhase
wantReason string
wantMessage string
}{
{
name: "running job stays running",
job: makeJob(""),
wantPhase: ObservationPhaseRunning,
wantPhase: CheckpointObservationPhaseRunning,
},
{
name: "completed job with completion annotation is ready",
job: makeJob(
snapshotprotocol.CheckpointStatusCompleted,
CheckpointStatusCompleted,
batchv1.JobCondition{Type: batchv1.JobComplete, Status: corev1.ConditionTrue},
),
wantPhase: ObservationPhaseReady,
wantPhase: CheckpointObservationPhaseReady,
wantReason: "JobSucceeded",
wantMessage: "Checkpoint job completed successfully",
},
......@@ -58,7 +57,7 @@ func TestObserve(t *testing.T) {
batchv1.JobCondition{Type: batchv1.JobComplete, Status: corev1.ConditionTrue},
),
checkpointWorkerActive: true,
wantPhase: ObservationPhaseWaitingForConfirmation,
wantPhase: CheckpointObservationPhaseWaitingForConfirmation,
},
{
name: "completed job fails without confirmation once worker is inactive",
......@@ -66,18 +65,18 @@ func TestObserve(t *testing.T) {
"",
batchv1.JobCondition{Type: batchv1.JobComplete, Status: corev1.ConditionTrue},
),
wantPhase: ObservationPhaseFailed,
wantPhase: CheckpointObservationPhaseFailed,
wantReason: "CheckpointVerificationFailed",
wantMessage: "Checkpoint job completed without snapshot-agent completion confirmation",
},
{
name: "failed checkpoint annotation wins over completed job",
job: makeJob(
snapshotprotocol.CheckpointStatusFailed,
CheckpointStatusFailed,
batchv1.JobCondition{Type: batchv1.JobComplete, Status: corev1.ConditionTrue},
),
checkpointWorkerActive: true,
wantPhase: ObservationPhaseFailed,
wantPhase: CheckpointObservationPhaseFailed,
wantReason: "CheckpointVerificationFailed",
wantMessage: "Checkpoint job completed but snapshot-agent reported checkpoint failure",
},
......@@ -85,7 +84,7 @@ func TestObserve(t *testing.T) {
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
observation := Observe(tc.job, tc.checkpointWorkerActive)
observation := ObserveCheckpointJob(tc.job, tc.checkpointWorkerActive)
if observation.Phase != tc.wantPhase {
t.Fatalf("phase = %q, want %q", observation.Phase, tc.wantPhase)
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment