"docs/backends/vscode:/vscode.git/clone" did not exist on "0a2a820bcacda705d927c6fdfcf37ec076e4e3fd"
Unverified commit debda332, authored by Schwinn Saereesitthipitak, committed by GitHub
Browse files

refactor(snapshot): tighten operator protocol boundary (#8018)

parent f49df565
...@@ -132,7 +132,7 @@ func TestCreateOrGetAutoCheckpointDeduplicatesConcurrentSameHashCheckpoint(t *te ...@@ -132,7 +132,7 @@ func TestCreateOrGetAutoCheckpointDeduplicatesConcurrentSameHashCheckpoint(t *te
Name: "friendly-checkpoint", Name: "friendly-checkpoint",
Namespace: testNamespace, Namespace: testNamespace,
Labels: map[string]string{ Labels: map[string]string{
consts.KubeLabelCheckpointID: hash, snapshotprotocol.CheckpointIDLabel: hash,
}, },
}, },
Spec: nvidiacomv1alpha1.DynamoCheckpointSpec{ Spec: nvidiacomv1alpha1.DynamoCheckpointSpec{
...@@ -177,7 +177,7 @@ func TestCreateOrGetAutoCheckpointSetsDefaultArtifactVersion(t *testing.T) { ...@@ -177,7 +177,7 @@ func TestCreateOrGetAutoCheckpointSetsDefaultArtifactVersion(t *testing.T) {
ckpt, err := CreateOrGetAutoCheckpoint(ctx, c, testNamespace, testIdentity(), corev1.PodTemplateSpec{}) ckpt, err := CreateOrGetAutoCheckpoint(ctx, c, testNamespace, testIdentity(), corev1.PodTemplateSpec{})
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, ckpt.Annotations) require.NotNil(t, ckpt.Annotations)
assert.Equal(t, consts.DefaultCheckpointArtifactVersion, ckpt.Annotations[consts.KubeAnnotationCheckpointArtifactVersion]) assert.Equal(t, snapshotprotocol.DefaultCheckpointArtifactVersion, ckpt.Annotations[snapshotprotocol.CheckpointArtifactVersionAnnotation])
} }
// --- InjectCheckpointIntoPodSpec tests --- // --- InjectCheckpointIntoPodSpec tests ---
......
...@@ -86,7 +86,7 @@ func InjectCheckpointIntoPodSpec( ...@@ -86,7 +86,7 @@ func InjectCheckpointIntoPodSpec(
mainContainer, mainContainer,
info.Hash, info.Hash,
info.ArtifactVersion, info.ArtifactVersion,
commonconsts.SeccompProfilePath, snapshotprotocol.DefaultSeccompLocalhostProfile,
info.Ready, info.Ready,
); err != nil { ); err != nil {
return err return err
......
...@@ -22,7 +22,6 @@ import ( ...@@ -22,7 +22,6 @@ import (
"fmt" "fmt"
nvidiacomv1alpha1 "github.com/ai-dynamo/dynamo/deploy/operator/api/v1alpha1" nvidiacomv1alpha1 "github.com/ai-dynamo/dynamo/deploy/operator/api/v1alpha1"
"github.com/ai-dynamo/dynamo/deploy/operator/internal/consts"
snapshotprotocol "github.com/ai-dynamo/dynamo/deploy/snapshot/protocol" snapshotprotocol "github.com/ai-dynamo/dynamo/deploy/snapshot/protocol"
corev1 "k8s.io/api/core/v1" corev1 "k8s.io/api/core/v1"
apierrors "k8s.io/apimachinery/pkg/api/errors" apierrors "k8s.io/apimachinery/pkg/api/errors"
...@@ -56,7 +55,7 @@ func FindCheckpointByIdentityHash( ...@@ -56,7 +55,7 @@ func FindCheckpointByIdentityHash(
ctx, ctx,
checkpoints, checkpoints,
client.InNamespace(namespace), client.InNamespace(namespace),
client.MatchingLabels{consts.KubeLabelCheckpointID: hash}, client.MatchingLabels{snapshotprotocol.CheckpointIDLabel: hash},
); err != nil { ); err != nil {
return nil, fmt.Errorf("failed to list checkpoints by hash label: %w", err) return nil, fmt.Errorf("failed to list checkpoints by hash label: %w", err)
} }
...@@ -119,7 +118,7 @@ func CreateOrGetAutoCheckpoint( ...@@ -119,7 +118,7 @@ func CreateOrGetAutoCheckpoint(
Name: fmt.Sprintf("checkpoint-%s", hash), Name: fmt.Sprintf("checkpoint-%s", hash),
Namespace: namespace, Namespace: namespace,
Labels: map[string]string{ Labels: map[string]string{
consts.KubeLabelCheckpointID: hash, snapshotprotocol.CheckpointIDLabel: hash,
}, },
Annotations: map[string]string{ Annotations: map[string]string{
snapshotprotocol.CheckpointArtifactVersionAnnotation: snapshotprotocol.DefaultCheckpointArtifactVersion, snapshotprotocol.CheckpointArtifactVersionAnnotation: snapshotprotocol.DefaultCheckpointArtifactVersion,
......
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
package checkpointjob
import (
"testing"
snapshotprotocol "github.com/ai-dynamo/dynamo/deploy/snapshot/protocol"
)
func TestDesiredCheckpointJobName(t *testing.T) {
name := DesiredCheckpointJobName("abc123def4567890", map[string]string{
snapshotprotocol.CheckpointArtifactVersionAnnotation: "2",
})
if name != "checkpoint-job-abc123def4567890-2" {
t.Fatalf("unexpected checkpoint job name: %s", name)
}
defaultName := DesiredCheckpointJobName("abc123def4567890", nil)
if defaultName != "checkpoint-job-abc123def4567890-"+snapshotprotocol.DefaultCheckpointArtifactVersion {
t.Fatalf("unexpected default checkpoint job name: %s", defaultName)
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
package checkpointjob
import (
snapshotprotocol "github.com/ai-dynamo/dynamo/deploy/snapshot/protocol"
batchv1 "k8s.io/api/batch/v1"
corev1 "k8s.io/api/core/v1"
)
// ObservationPhase is the coarse state that Observe derives from a checkpoint
// Job's conditions and its checkpoint status annotation.
type ObservationPhase string
// Phases reported by Observe.
const (
// ObservationPhaseRunning is reported when the Job has neither a true
// Complete nor a true Failed condition and no failed status annotation.
ObservationPhaseRunning ObservationPhase = "running"
// ObservationPhaseWaitingForConfirmation is reported when the Job is complete,
// no terminal status annotation matched, and the checkpoint worker is still active.
ObservationPhaseWaitingForConfirmation ObservationPhase = "waiting_for_confirmation"
// ObservationPhaseReady is reported when the Job is complete and the status
// annotation reports completion.
ObservationPhaseReady ObservationPhase = "ready"
// ObservationPhaseFailed is reported when the status annotation reports failure,
// when the Job completed without confirmation, or when the Job itself failed.
ObservationPhaseFailed ObservationPhase = "failed"
)
// Observation is the result of Observe: the derived phase plus a reason and
// message describing it.
type Observation struct {
// Phase is the derived state of the checkpoint Job.
Phase ObservationPhase
// Reason is a short machine-style cause; Observe sets it only for the
// ready and failed phases and leaves it empty otherwise.
Reason string
// Message is the human-readable explanation; Observe sets it only for the
// ready and failed phases and leaves it empty otherwise.
Message string
}
// Observe derives an Observation for a checkpoint Job from two inputs read off
// the Job: its status conditions (Complete / Failed with status True) and the
// value of the snapshotprotocol.CheckpointStatusAnnotation annotation.
//
// checkpointWorkerActive is only consulted when the Job is complete but the
// annotation does not yet carry a terminal status; in that case an active
// worker yields the waiting-for-confirmation phase instead of a failure.
//
// Decision order (first match wins):
//  1. annotation reports failure            -> failed
//  2. Job complete, annotation completed    -> ready
//  3. Job complete, worker still active     -> waiting for confirmation
//  4. Job complete, nothing confirmed       -> failed (verification)
//  5. Job failed                            -> failed
//  6. otherwise                             -> running
func Observe(job *batchv1.Job, checkpointWorkerActive bool) Observation {
	var completed, failed bool
	for i := range job.Status.Conditions {
		cond := &job.Status.Conditions[i]
		// Only conditions that are currently true contribute.
		if cond.Status != corev1.ConditionTrue {
			continue
		}
		switch cond.Type {
		case batchv1.JobComplete:
			completed = true
		case batchv1.JobFailed:
			failed = true
		}
	}

	reported := job.Annotations[snapshotprotocol.CheckpointStatusAnnotation]

	switch {
	case reported == snapshotprotocol.CheckpointStatusFailed:
		// A failure annotation takes precedence over the Job conditions; the
		// reason and message differ depending on whether the Job itself completed.
		if completed {
			return Observation{
				Phase:   ObservationPhaseFailed,
				Reason:  "CheckpointVerificationFailed",
				Message: "Checkpoint job completed but snapshot-agent reported checkpoint failure",
			}
		}
		return Observation{
			Phase:   ObservationPhaseFailed,
			Reason:  "JobFailed",
			Message: "Checkpoint job failed",
		}

	case completed && reported == snapshotprotocol.CheckpointStatusCompleted:
		return Observation{
			Phase:   ObservationPhaseReady,
			Reason:  "JobSucceeded",
			Message: "Checkpoint job completed successfully",
		}

	case completed && checkpointWorkerActive:
		// Job is done but no terminal status has been recorded yet.
		return Observation{Phase: ObservationPhaseWaitingForConfirmation}

	case completed:
		return Observation{
			Phase:   ObservationPhaseFailed,
			Reason:  "CheckpointVerificationFailed",
			Message: "Checkpoint job completed without snapshot-agent completion confirmation",
		}

	case failed:
		return Observation{
			Phase:   ObservationPhaseFailed,
			Reason:  "JobFailed",
			Message: "Checkpoint job failed",
		}

	default:
		return Observation{Phase: ObservationPhaseRunning}
	}
}
...@@ -139,26 +139,8 @@ const ( ...@@ -139,26 +139,8 @@ const (
ResourceStateNotReady = "not_ready" ResourceStateNotReady = "not_ready"
ResourceStateUnknown = "unknown" ResourceStateUnknown = "unknown"
// Checkpoint/restore constants
// CROSS-REFERENCE: Some constants below are duplicated in deploy/snapshot/protocol.
// If you change a value here, update there too.
// Kubernetes labels
KubeLabelIsCheckpointSource = "nvidia.com/snapshot-is-checkpoint-source" // Pod label that triggers DaemonSet auto-checkpoint
KubeLabelCheckpointID = "nvidia.com/snapshot-checkpoint-id" // Checkpoint identity label; the operator stores the resolved identity hash as the value
KubeLabelIsRestoreTarget = "nvidia.com/snapshot-is-restore-target" // Pod label that triggers DaemonSet auto-restore
KubeAnnotationCheckpointArtifactVersion = "nvidia.com/snapshot-artifact-version" // Checkpoint artifact generation; changing it triggers a new immutable capture attempt
DefaultCheckpointArtifactVersion = "1"
DefaultCheckpointJobTTLSeconds = int32(300)
// Environment variables injected into pods // Environment variables injected into pods
EnvReadyForCheckpointFile = "DYN_READY_FOR_CHECKPOINT_FILE" // Ready-for-checkpoint file path — checkpoint job pods EnvReadyForCheckpointFile = "DYN_READY_FOR_CHECKPOINT_FILE" // Ready-for-checkpoint file path — checkpoint job pods
// Checkpoint pod-internal constants
CheckpointVolumeName = "checkpoint-storage" // Pod-internal volume name for checkpoint PVC
// SeccompProfilePath is the localhost seccomp profile that blocks io_uring syscalls.
// Deployed to nodes by the snapshot DaemonSet init container.
SeccompProfilePath = "profiles/block-iouring.json"
// Pod identity (Downward API) --- // Pod identity (Downward API) ---
// After CRIU restore, env vars contain stale values from the checkpoint pod. // After CRIU restore, env vars contain stale values from the checkpoint pod.
......
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
package checkpointjob package controller
import ( import (
"fmt" "fmt"
...@@ -17,10 +17,6 @@ import ( ...@@ -17,10 +17,6 @@ import (
"k8s.io/apimachinery/pkg/api/resource" "k8s.io/apimachinery/pkg/api/resource"
) )
func DesiredCheckpointJobName(identityHash string, annotations map[string]string) string {
return "checkpoint-job-" + identityHash + "-" + snapshotprotocol.ArtifactVersion(annotations[snapshotprotocol.CheckpointArtifactVersionAnnotation])
}
func buildCheckpointWorkerDefaultEnv( func buildCheckpointWorkerDefaultEnv(
ckpt *nvidiacomv1alpha1.DynamoCheckpoint, ckpt *nvidiacomv1alpha1.DynamoCheckpoint,
podTemplate *corev1.PodTemplateSpec, podTemplate *corev1.PodTemplateSpec,
...@@ -50,7 +46,7 @@ func buildCheckpointWorkerDefaultEnv( ...@@ -50,7 +46,7 @@ func buildCheckpointWorkerDefaultEnv(
return defaultContainer.Env return defaultContainer.Env
} }
func BuildCheckpointJob( func buildCheckpointJob(
config *configv1alpha1.OperatorConfiguration, config *configv1alpha1.OperatorConfiguration,
ckpt *nvidiacomv1alpha1.DynamoCheckpoint, ckpt *nvidiacomv1alpha1.DynamoCheckpoint,
jobName string, jobName string,
...@@ -118,7 +114,7 @@ func BuildCheckpointJob( ...@@ -118,7 +114,7 @@ func BuildCheckpointJob(
Namespace: ckpt.Namespace, Namespace: ckpt.Namespace,
CheckpointID: hash, CheckpointID: hash,
ArtifactVersion: snapshotprotocol.ArtifactVersion(ckpt.Annotations[snapshotprotocol.CheckpointArtifactVersionAnnotation]), ArtifactVersion: snapshotprotocol.ArtifactVersion(ckpt.Annotations[snapshotprotocol.CheckpointArtifactVersionAnnotation]),
SeccompProfile: consts.SeccompProfilePath, SeccompProfile: snapshotprotocol.DefaultSeccompLocalhostProfile,
Name: jobName, Name: jobName,
ActiveDeadlineSeconds: activeDeadlineSeconds, ActiveDeadlineSeconds: activeDeadlineSeconds,
TTLSecondsAfterFinish: &ttlSecondsAfterFinish, TTLSecondsAfterFinish: &ttlSecondsAfterFinish,
......
...@@ -39,9 +39,8 @@ import ( ...@@ -39,9 +39,8 @@ import (
configv1alpha1 "github.com/ai-dynamo/dynamo/deploy/operator/api/config/v1alpha1" configv1alpha1 "github.com/ai-dynamo/dynamo/deploy/operator/api/config/v1alpha1"
nvidiacomv1alpha1 "github.com/ai-dynamo/dynamo/deploy/operator/api/v1alpha1" nvidiacomv1alpha1 "github.com/ai-dynamo/dynamo/deploy/operator/api/v1alpha1"
"github.com/ai-dynamo/dynamo/deploy/operator/internal/checkpoint" "github.com/ai-dynamo/dynamo/deploy/operator/internal/checkpoint"
"github.com/ai-dynamo/dynamo/deploy/operator/internal/checkpointjob"
"github.com/ai-dynamo/dynamo/deploy/operator/internal/consts"
commonController "github.com/ai-dynamo/dynamo/deploy/operator/internal/controller_common" commonController "github.com/ai-dynamo/dynamo/deploy/operator/internal/controller_common"
snapshotprotocol "github.com/ai-dynamo/dynamo/deploy/snapshot/protocol"
) )
// CheckpointReconciler reconciles a DynamoCheckpoint object // CheckpointReconciler reconciles a DynamoCheckpoint object
...@@ -86,8 +85,8 @@ func (r *CheckpointReconciler) Reconcile(ctx context.Context, req ctrl.Request) ...@@ -86,8 +85,8 @@ func (r *CheckpointReconciler) Reconcile(ctx context.Context, req ctrl.Request)
if ckpt.Labels == nil { if ckpt.Labels == nil {
ckpt.Labels = map[string]string{} ckpt.Labels = map[string]string{}
} }
if ckpt.Labels[consts.KubeLabelCheckpointID] != identityHash { if ckpt.Labels[snapshotprotocol.CheckpointIDLabel] != identityHash {
ckpt.Labels[consts.KubeLabelCheckpointID] = identityHash ckpt.Labels[snapshotprotocol.CheckpointIDLabel] = identityHash
if err := r.Update(ctx, ckpt); err != nil { if err := r.Update(ctx, ckpt); err != nil {
return ctrl.Result{}, err return ctrl.Result{}, err
} }
...@@ -117,7 +116,10 @@ func (r *CheckpointReconciler) Reconcile(ctx context.Context, req ctrl.Request) ...@@ -117,7 +116,10 @@ func (r *CheckpointReconciler) Reconcile(ctx context.Context, req ctrl.Request)
} }
return ctrl.Result{}, nil return ctrl.Result{}, nil
} }
desiredJobName := checkpointjob.DesiredCheckpointJobName(identityHash, ckpt.Annotations) desiredJobName := snapshotprotocol.GetCheckpointJobName(
identityHash,
ckpt.Annotations[snapshotprotocol.CheckpointArtifactVersionAnnotation],
)
switch ckpt.Status.Phase { switch ckpt.Status.Phase {
case "", nvidiacomv1alpha1.DynamoCheckpointPhasePending, nvidiacomv1alpha1.DynamoCheckpointPhaseCreating, nvidiacomv1alpha1.DynamoCheckpointPhaseReady, nvidiacomv1alpha1.DynamoCheckpointPhaseFailed: case "", nvidiacomv1alpha1.DynamoCheckpointPhasePending, nvidiacomv1alpha1.DynamoCheckpointPhaseCreating, nvidiacomv1alpha1.DynamoCheckpointPhaseReady, nvidiacomv1alpha1.DynamoCheckpointPhaseFailed:
default: default:
...@@ -181,11 +183,14 @@ func (r *CheckpointReconciler) handlePending(ctx context.Context, ckpt *nvidiaco ...@@ -181,11 +183,14 @@ func (r *CheckpointReconciler) handlePending(ctx context.Context, ckpt *nvidiaco
return ctrl.Result{}, fmt.Errorf("failed to compute checkpoint identity hash: %w", err) return ctrl.Result{}, fmt.Errorf("failed to compute checkpoint identity hash: %w", err)
} }
} }
jobName := checkpointjob.DesiredCheckpointJobName(hash, ckpt.Annotations) jobName := snapshotprotocol.GetCheckpointJobName(
hash,
ckpt.Annotations[snapshotprotocol.CheckpointArtifactVersionAnnotation],
)
// Use SyncResource to create/update the checkpoint Job // Use SyncResource to create/update the checkpoint Job
modified, _, err := commonController.SyncResource(ctx, r, ckpt, func(ctx context.Context) (*batchv1.Job, bool, error) { modified, _, err := commonController.SyncResource(ctx, r, ckpt, func(ctx context.Context) (*batchv1.Job, bool, error) {
job, err := checkpointjob.BuildCheckpointJob(r.Config, ckpt, jobName) job, err := buildCheckpointJob(r.Config, ckpt, jobName)
return job, false, err return job, false, err
}) })
if err != nil { if err != nil {
...@@ -276,12 +281,12 @@ func (r *CheckpointReconciler) handleCreating(ctx context.Context, ckpt *nvidiac ...@@ -276,12 +281,12 @@ func (r *CheckpointReconciler) handleCreating(ctx context.Context, ckpt *nvidiac
} }
} }
observation := checkpointjob.Observe(job, checkpointWorkerActive) observation := snapshotprotocol.ObserveCheckpointJob(job, checkpointWorkerActive)
switch observation.Phase { switch observation.Phase {
case checkpointjob.ObservationPhaseWaitingForConfirmation: case snapshotprotocol.CheckpointObservationPhaseWaitingForConfirmation:
logger.V(1).Info("Checkpoint job is complete but checkpoint worker is still active; waiting for terminal watcher status", "job", job.Name) logger.V(1).Info("Checkpoint job is complete but checkpoint worker is still active; waiting for terminal watcher status", "job", job.Name)
return ctrl.Result{RequeueAfter: time.Second}, nil return ctrl.Result{RequeueAfter: time.Second}, nil
case checkpointjob.ObservationPhaseReady: case snapshotprotocol.CheckpointObservationPhaseReady:
logger.Info("Checkpoint Job succeeded", "job", job.Name) logger.Info("Checkpoint Job succeeded", "job", job.Name)
r.Recorder.Event(ckpt, corev1.EventTypeNormal, "CheckpointReady", observation.Message) r.Recorder.Event(ckpt, corev1.EventTypeNormal, "CheckpointReady", observation.Message)
...@@ -300,7 +305,7 @@ func (r *CheckpointReconciler) handleCreating(ctx context.Context, ckpt *nvidiac ...@@ -300,7 +305,7 @@ func (r *CheckpointReconciler) handleCreating(ctx context.Context, ckpt *nvidiac
return ctrl.Result{}, err return ctrl.Result{}, err
} }
return ctrl.Result{}, nil return ctrl.Result{}, nil
case checkpointjob.ObservationPhaseFailed: case snapshotprotocol.CheckpointObservationPhaseFailed:
logger.Info("Checkpoint Job failed", "job", job.Name, "message", observation.Message) logger.Info("Checkpoint Job failed", "job", job.Name, "message", observation.Message)
r.Recorder.Event(ckpt, corev1.EventTypeWarning, "CheckpointFailed", observation.Message) r.Recorder.Event(ckpt, corev1.EventTypeWarning, "CheckpointFailed", observation.Message)
......
...@@ -25,7 +25,6 @@ import ( ...@@ -25,7 +25,6 @@ import (
configv1alpha1 "github.com/ai-dynamo/dynamo/deploy/operator/api/config/v1alpha1" configv1alpha1 "github.com/ai-dynamo/dynamo/deploy/operator/api/config/v1alpha1"
nvidiacomv1alpha1 "github.com/ai-dynamo/dynamo/deploy/operator/api/v1alpha1" nvidiacomv1alpha1 "github.com/ai-dynamo/dynamo/deploy/operator/api/v1alpha1"
"github.com/ai-dynamo/dynamo/deploy/operator/internal/checkpoint" "github.com/ai-dynamo/dynamo/deploy/operator/internal/checkpoint"
"github.com/ai-dynamo/dynamo/deploy/operator/internal/checkpointjob"
"github.com/ai-dynamo/dynamo/deploy/operator/internal/consts" "github.com/ai-dynamo/dynamo/deploy/operator/internal/consts"
snapshotprotocol "github.com/ai-dynamo/dynamo/deploy/snapshot/protocol" snapshotprotocol "github.com/ai-dynamo/dynamo/deploy/snapshot/protocol"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
...@@ -60,9 +59,7 @@ var testHash = func() string { ...@@ -60,9 +59,7 @@ var testHash = func() string {
return hash return hash
}() }()
var defaultCheckpointJobName = checkpointjob.DesiredCheckpointJobName(testHash, map[string]string{ var defaultCheckpointJobName = snapshotprotocol.GetCheckpointJobName(testHash, snapshotprotocol.DefaultCheckpointArtifactVersion)
snapshotprotocol.CheckpointArtifactVersionAnnotation: snapshotprotocol.DefaultCheckpointArtifactVersion,
})
func checkpointTestScheme() *runtime.Scheme { func checkpointTestScheme() *runtime.Scheme {
s := runtime.NewScheme() s := runtime.NewScheme()
...@@ -140,15 +137,15 @@ func TestBuildCheckpointJob(t *testing.T) { ...@@ -140,15 +137,15 @@ func TestBuildCheckpointJob(t *testing.T) {
} }
r := makeCheckpointReconciler(s, ckpt) r := makeCheckpointReconciler(s, ckpt)
job, err := checkpointjob.BuildCheckpointJob(r.Config, ckpt, defaultCheckpointJobName) job, err := buildCheckpointJob(r.Config, ckpt, defaultCheckpointJobName)
require.NoError(t, err) require.NoError(t, err)
podSpec := job.Spec.Template.Spec podSpec := job.Spec.Template.Spec
main := podSpec.Containers[0] main := podSpec.Containers[0]
// Job and pod template labels // Job and pod template labels
assert.Equal(t, testHash, job.Labels[consts.KubeLabelCheckpointID]) assert.Equal(t, testHash, job.Labels[snapshotprotocol.CheckpointIDLabel])
assert.Equal(t, "true", job.Spec.Template.Labels[consts.KubeLabelIsCheckpointSource]) assert.Equal(t, "true", job.Spec.Template.Labels[snapshotprotocol.CheckpointSourceLabel])
assert.Equal(t, testHash, job.Spec.Template.Labels[consts.KubeLabelCheckpointID]) assert.Equal(t, testHash, job.Spec.Template.Labels[snapshotprotocol.CheckpointIDLabel])
// Env vars (checkpoint-specific + user-provided preserved) // Env vars (checkpoint-specific + user-provided preserved)
envMap := make(map[string]string, len(main.Env)) envMap := make(map[string]string, len(main.Env))
...@@ -180,7 +177,7 @@ func TestBuildCheckpointJob(t *testing.T) { ...@@ -180,7 +177,7 @@ func TestBuildCheckpointJob(t *testing.T) {
require.NotNil(t, podSpec.SecurityContext) require.NotNil(t, podSpec.SecurityContext)
require.NotNil(t, podSpec.SecurityContext.SeccompProfile) require.NotNil(t, podSpec.SecurityContext.SeccompProfile)
assert.Equal(t, corev1.SeccompProfileTypeLocalhost, podSpec.SecurityContext.SeccompProfile.Type) assert.Equal(t, corev1.SeccompProfileTypeLocalhost, podSpec.SecurityContext.SeccompProfile.Type)
assert.Equal(t, consts.SeccompProfilePath, *podSpec.SecurityContext.SeccompProfile.LocalhostProfile) assert.Equal(t, snapshotprotocol.DefaultSeccompLocalhostProfile, *podSpec.SecurityContext.SeccompProfile.LocalhostProfile)
require.NotNil(t, podSpec.SecurityContext.RunAsUser) require.NotNil(t, podSpec.SecurityContext.RunAsUser)
assert.Equal(t, int64(1234), *podSpec.SecurityContext.RunAsUser) assert.Equal(t, int64(1234), *podSpec.SecurityContext.RunAsUser)
require.NotNil(t, podSpec.SecurityContext.FSGroup) require.NotNil(t, podSpec.SecurityContext.FSGroup)
...@@ -197,14 +194,14 @@ func TestBuildCheckpointJob(t *testing.T) { ...@@ -197,14 +194,14 @@ func TestBuildCheckpointJob(t *testing.T) {
for _, v := range podSpec.Volumes { for _, v := range podSpec.Volumes {
volNames[v.Name] = true volNames[v.Name] = true
} }
assert.False(t, volNames[consts.CheckpointVolumeName]) assert.False(t, volNames[snapshotprotocol.CheckpointVolumeName])
assert.True(t, volNames[consts.PodInfoVolumeName]) assert.True(t, volNames[consts.PodInfoVolumeName])
mountPaths := make(map[string]string) mountPaths := make(map[string]string)
for _, m := range main.VolumeMounts { for _, m := range main.VolumeMounts {
mountPaths[m.Name] = m.MountPath mountPaths[m.Name] = m.MountPath
} }
_, hasCheckpointMount := mountPaths[consts.CheckpointVolumeName] _, hasCheckpointMount := mountPaths[snapshotprotocol.CheckpointVolumeName]
assert.False(t, hasCheckpointMount) assert.False(t, hasCheckpointMount)
assert.Equal(t, consts.PodInfoMountPath, mountPaths[consts.PodInfoVolumeName]) assert.Equal(t, consts.PodInfoMountPath, mountPaths[consts.PodInfoVolumeName])
assert.Equal(t, consts.DefaultSharedMemoryMountPath, mountPaths[consts.KubeValueNameSharedMemory]) assert.Equal(t, consts.DefaultSharedMemoryMountPath, mountPaths[consts.KubeValueNameSharedMemory])
...@@ -237,7 +234,7 @@ func TestBuildCheckpointJob(t *testing.T) { ...@@ -237,7 +234,7 @@ func TestBuildCheckpointJob(t *testing.T) {
backoff := int32(5) backoff := int32(5)
ckpt.Spec.Job.ActiveDeadlineSeconds = &deadline ckpt.Spec.Job.ActiveDeadlineSeconds = &deadline
ckpt.Spec.Job.BackoffLimit = &backoff //nolint:staticcheck // Compatibility test: deprecated field must remain ignored by checkpoint Jobs. ckpt.Spec.Job.BackoffLimit = &backoff //nolint:staticcheck // Compatibility test: deprecated field must remain ignored by checkpoint Jobs.
job, err = checkpointjob.BuildCheckpointJob(r.Config, ckpt, defaultCheckpointJobName) job, err = buildCheckpointJob(r.Config, ckpt, defaultCheckpointJobName)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, int64(7200), *job.Spec.ActiveDeadlineSeconds) assert.Equal(t, int64(7200), *job.Spec.ActiveDeadlineSeconds)
assert.Equal(t, int32(0), *job.Spec.BackoffLimit) assert.Equal(t, int32(0), *job.Spec.BackoffLimit)
...@@ -248,7 +245,7 @@ func TestBuildCheckpointJob(t *testing.T) { ...@@ -248,7 +245,7 @@ func TestBuildCheckpointJob(t *testing.T) {
corev1.ResourceName("nvidia.com/gpu"): resource.MustParse("2"), corev1.ResourceName("nvidia.com/gpu"): resource.MustParse("2"),
}, },
} }
job, err = checkpointjob.BuildCheckpointJob(r.Config, ckpt, defaultCheckpointJobName) job, err = buildCheckpointJob(r.Config, ckpt, defaultCheckpointJobName)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, []string{"cuda-checkpoint"}, job.Spec.Template.Spec.Containers[0].Command) assert.Equal(t, []string{"cuda-checkpoint"}, job.Spec.Template.Spec.Containers[0].Command)
assert.Equal(t, []string{"--launch-job", "python3", "-m", "dynamo.vllm"}, job.Spec.Template.Spec.Containers[0].Args) assert.Equal(t, []string{"--launch-job", "python3", "-m", "dynamo.vllm"}, job.Spec.Template.Spec.Containers[0].Args)
...@@ -273,7 +270,7 @@ func TestBuildCheckpointJobInjectsStandardEnvVars(t *testing.T) { ...@@ -273,7 +270,7 @@ func TestBuildCheckpointJobInjectsStandardEnvVars(t *testing.T) {
customShmSize := resource.MustParse("16Gi") customShmSize := resource.MustParse("16Gi")
ckpt.Spec.Job.SharedMemory = &nvidiacomv1alpha1.SharedMemorySpec{Size: customShmSize} ckpt.Spec.Job.SharedMemory = &nvidiacomv1alpha1.SharedMemorySpec{Size: customShmSize}
job, err := checkpointjob.BuildCheckpointJob(r.Config, ckpt, defaultCheckpointJobName) job, err := buildCheckpointJob(r.Config, ckpt, defaultCheckpointJobName)
require.NoError(t, err) require.NoError(t, err)
foundCustomShmVolume := false foundCustomShmVolume := false
for _, v := range job.Spec.Template.Spec.Volumes { for _, v := range job.Spec.Template.Spec.Volumes {
...@@ -326,7 +323,7 @@ func TestCheckpointReconciler_Reconcile(t *testing.T) { ...@@ -326,7 +323,7 @@ func TestCheckpointReconciler_Reconcile(t *testing.T) {
assert.Equal(t, nvidiacomv1alpha1.DynamoCheckpointPhasePending, updated.Status.Phase) assert.Equal(t, nvidiacomv1alpha1.DynamoCheckpointPhasePending, updated.Status.Phase)
assert.Equal(t, testHash, updated.Status.IdentityHash) assert.Equal(t, testHash, updated.Status.IdentityHash)
assert.Empty(t, updated.Status.Message) assert.Empty(t, updated.Status.Message)
assert.Equal(t, testHash, updated.Labels[consts.KubeLabelCheckpointID]) assert.Equal(t, testHash, updated.Labels[snapshotprotocol.CheckpointIDLabel])
}) })
t.Run("Ready phase is a no-op", func(t *testing.T) { t.Run("Ready phase is a no-op", func(t *testing.T) {
...@@ -352,7 +349,7 @@ func TestCheckpointReconciler_Reconcile(t *testing.T) { ...@@ -352,7 +349,7 @@ func TestCheckpointReconciler_Reconcile(t *testing.T) {
updated := &nvidiacomv1alpha1.DynamoCheckpoint{} updated := &nvidiacomv1alpha1.DynamoCheckpoint{}
require.NoError(t, r.Get(ctx, types.NamespacedName{Name: friendlyCheckpointName, Namespace: testNamespace}, updated)) require.NoError(t, r.Get(ctx, types.NamespacedName{Name: friendlyCheckpointName, Namespace: testNamespace}, updated))
assert.Equal(t, testHash, updated.Labels[consts.KubeLabelCheckpointID]) assert.Equal(t, testHash, updated.Labels[snapshotprotocol.CheckpointIDLabel])
assert.Equal(t, testHash, updated.Status.IdentityHash) assert.Equal(t, testHash, updated.Status.IdentityHash)
}) })
...@@ -375,7 +372,7 @@ func TestCheckpointReconciler_Reconcile(t *testing.T) { ...@@ -375,7 +372,7 @@ func TestCheckpointReconciler_Reconcile(t *testing.T) {
ckpt := makeTestCheckpoint(nvidiacomv1alpha1.DynamoCheckpointPhaseReady) ckpt := makeTestCheckpoint(nvidiacomv1alpha1.DynamoCheckpointPhaseReady)
ckpt.Status.IdentityHash = testHash ckpt.Status.IdentityHash = testHash
ckpt.Status.JobName = defaultCheckpointJobName ckpt.Status.JobName = defaultCheckpointJobName
ckpt.Annotations = map[string]string{consts.KubeAnnotationCheckpointArtifactVersion: "2"} ckpt.Annotations = map[string]string{snapshotprotocol.CheckpointArtifactVersionAnnotation: "2"}
r := makeCheckpointReconciler(s, ckpt) r := makeCheckpointReconciler(s, ckpt)
_, err := r.Reconcile(ctx, ctrl.Request{ _, err := r.Reconcile(ctx, ctrl.Request{
...@@ -584,7 +581,7 @@ func TestCheckpointReconciler_HandleCreating(t *testing.T) { ...@@ -584,7 +581,7 @@ func TestCheckpointReconciler_HandleCreating(t *testing.T) {
t.Run("in-flight version changes do not relabel the running job's artifact", func(t *testing.T) { t.Run("in-flight version changes do not relabel the running job's artifact", func(t *testing.T) {
ckpt := makeCreatingCkpt(testHash, defaultCheckpointJobName) ckpt := makeCreatingCkpt(testHash, defaultCheckpointJobName)
ckpt.Annotations = map[string]string{consts.KubeAnnotationCheckpointArtifactVersion: "2"} ckpt.Annotations = map[string]string{snapshotprotocol.CheckpointArtifactVersionAnnotation: "2"}
job := &batchv1.Job{ job := &batchv1.Job{
ObjectMeta: metav1.ObjectMeta{ ObjectMeta: metav1.ObjectMeta{
Name: defaultCheckpointJobName, Name: defaultCheckpointJobName,
......
...@@ -41,6 +41,7 @@ import ( ...@@ -41,6 +41,7 @@ import (
commonController "github.com/ai-dynamo/dynamo/deploy/operator/internal/controller_common" commonController "github.com/ai-dynamo/dynamo/deploy/operator/internal/controller_common"
"github.com/ai-dynamo/dynamo/deploy/operator/internal/dynamo" "github.com/ai-dynamo/dynamo/deploy/operator/internal/dynamo"
"github.com/ai-dynamo/dynamo/deploy/operator/internal/observability" "github.com/ai-dynamo/dynamo/deploy/operator/internal/observability"
snapshotprotocol "github.com/ai-dynamo/dynamo/deploy/snapshot/protocol"
networkingv1beta1 "istio.io/client-go/pkg/apis/networking/v1beta1" networkingv1beta1 "istio.io/client-go/pkg/apis/networking/v1beta1"
k8serrors "k8s.io/apimachinery/pkg/api/errors" k8serrors "k8s.io/apimachinery/pkg/api/errors"
"k8s.io/apimachinery/pkg/api/meta" "k8s.io/apimachinery/pkg/api/meta"
...@@ -961,7 +962,7 @@ func (r *DynamoComponentDeploymentReconciler) generateDeployment(ctx context.Con ...@@ -961,7 +962,7 @@ func (r *DynamoComponentDeploymentReconciler) generateDeployment(ctx context.Con
// the old pod is terminated before the restore placeholder is started. // the old pod is terminated before the restore placeholder is started.
if podTemplateSpec != nil && if podTemplateSpec != nil &&
podTemplateSpec.Labels != nil && podTemplateSpec.Labels != nil &&
podTemplateSpec.Labels[commonconsts.KubeLabelIsRestoreTarget] == commonconsts.KubeLabelValueTrue { podTemplateSpec.Labels[snapshotprotocol.RestoreTargetLabel] == commonconsts.KubeLabelValueTrue {
strategy = appsv1.DeploymentStrategy{ strategy = appsv1.DeploymentStrategy{
Type: appsv1.RecreateDeploymentStrategyType, Type: appsv1.RecreateDeploymentStrategyType,
} }
......
...@@ -1306,7 +1306,7 @@ func TestDynamoComponentDeploymentReconciler_generatePodTemplateSpec_RestoreLabe ...@@ -1306,7 +1306,7 @@ func TestDynamoComponentDeploymentReconciler_generatePodTemplateSpec_RestoreLabe
Labels: map[string]string{ Labels: map[string]string{
commonconsts.KubeLabelDynamoGraphDeploymentName: "test-dgd", commonconsts.KubeLabelDynamoGraphDeploymentName: "test-dgd",
commonconsts.KubeLabelDynamoWorkerHash: "workerhash", commonconsts.KubeLabelDynamoWorkerHash: "workerhash",
commonconsts.KubeLabelIsRestoreTarget: commonconsts.KubeLabelValueTrue, snapshotprotocol.RestoreTargetLabel: commonconsts.KubeLabelValueTrue,
}, },
Checkpoint: &v1alpha1.ServiceCheckpointConfig{ Checkpoint: &v1alpha1.ServiceCheckpointConfig{
Enabled: true, Enabled: true,
...@@ -1368,11 +1368,11 @@ func TestDynamoComponentDeploymentReconciler_generatePodTemplateSpec_RestoreLabe ...@@ -1368,11 +1368,11 @@ func TestDynamoComponentDeploymentReconciler_generatePodTemplateSpec_RestoreLabe
t.Fatalf("generatePodTemplateSpec failed: %v", err) t.Fatalf("generatePodTemplateSpec failed: %v", err)
} }
if got := podTemplateSpec.Labels[commonconsts.KubeLabelIsRestoreTarget]; got != commonconsts.KubeLabelValueTrue { if got := podTemplateSpec.Labels[snapshotprotocol.RestoreTargetLabel]; got != commonconsts.KubeLabelValueTrue {
t.Fatalf("expected %s label to be true, got %q", commonconsts.KubeLabelIsRestoreTarget, got) t.Fatalf("expected %s label to be true, got %q", snapshotprotocol.RestoreTargetLabel, got)
} }
if got := podTemplateSpec.Labels[commonconsts.KubeLabelCheckpointID]; got != checkpointName { if got := podTemplateSpec.Labels[snapshotprotocol.CheckpointIDLabel]; got != checkpointName {
t.Fatalf("expected %s to be checkpoint id, got %q", commonconsts.KubeLabelCheckpointID, got) t.Fatalf("expected %s to be checkpoint id, got %q", snapshotprotocol.CheckpointIDLabel, got)
} }
}) })
...@@ -1454,11 +1454,11 @@ func TestDynamoComponentDeploymentReconciler_generatePodTemplateSpec_RestoreLabe ...@@ -1454,11 +1454,11 @@ func TestDynamoComponentDeploymentReconciler_generatePodTemplateSpec_RestoreLabe
t.Fatalf("generatePodTemplateSpec failed: %v", err) t.Fatalf("generatePodTemplateSpec failed: %v", err)
} }
if _, ok := podTemplateSpec.Labels[commonconsts.KubeLabelIsRestoreTarget]; ok { if _, ok := podTemplateSpec.Labels[snapshotprotocol.RestoreTargetLabel]; ok {
t.Fatalf("did not expect %s label when checkpoint is not ready", commonconsts.KubeLabelIsRestoreTarget) t.Fatalf("did not expect %s label when checkpoint is not ready", snapshotprotocol.RestoreTargetLabel)
} }
if _, ok := podTemplateSpec.Labels[commonconsts.KubeLabelCheckpointID]; ok { if _, ok := podTemplateSpec.Labels[snapshotprotocol.CheckpointIDLabel]; ok {
t.Fatalf("did not expect %s label when checkpoint is not ready", commonconsts.KubeLabelCheckpointID) t.Fatalf("did not expect %s label when checkpoint is not ready", snapshotprotocol.CheckpointIDLabel)
} }
}) })
} }
......
...@@ -31,6 +31,7 @@ import ( ...@@ -31,6 +31,7 @@ import (
"github.com/ai-dynamo/dynamo/deploy/operator/internal/checkpoint" "github.com/ai-dynamo/dynamo/deploy/operator/internal/checkpoint"
commonconsts "github.com/ai-dynamo/dynamo/deploy/operator/internal/consts" commonconsts "github.com/ai-dynamo/dynamo/deploy/operator/internal/consts"
"github.com/ai-dynamo/dynamo/deploy/operator/internal/controller_common" "github.com/ai-dynamo/dynamo/deploy/operator/internal/controller_common"
snapshotprotocol "github.com/ai-dynamo/dynamo/deploy/snapshot/protocol"
grovev1alpha1 "github.com/ai-dynamo/grove/operator/api/core/v1alpha1" grovev1alpha1 "github.com/ai-dynamo/grove/operator/api/core/v1alpha1"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
...@@ -6860,12 +6861,12 @@ func TestGenerateLabels_RemovesStaleRestoreLabelsWhenCheckpointNotReady(t *testi ...@@ -6860,12 +6861,12 @@ func TestGenerateLabels_RemovesStaleRestoreLabelsWhenCheckpointNotReady(t *testi
DynamoNamespace: ptr.To("default-test-dgd"), DynamoNamespace: ptr.To("default-test-dgd"),
Labels: map[string]string{ Labels: map[string]string{
"user-label": "keep", "user-label": "keep",
commonconsts.KubeLabelIsRestoreTarget: commonconsts.KubeLabelValueTrue, snapshotprotocol.RestoreTargetLabel: commonconsts.KubeLabelValueTrue,
}, },
ExtraPodMetadata: &v1alpha1.ExtraPodMetadata{ ExtraPodMetadata: &v1alpha1.ExtraPodMetadata{
Labels: map[string]string{ Labels: map[string]string{
"extra-label": "keep-too", "extra-label": "keep-too",
commonconsts.KubeLabelCheckpointID: "stale-hash", snapshotprotocol.CheckpointIDLabel: "stale-hash",
}, },
}, },
}, },
...@@ -6883,8 +6884,8 @@ func TestGenerateLabels_RemovesStaleRestoreLabelsWhenCheckpointNotReady(t *testi ...@@ -6883,8 +6884,8 @@ func TestGenerateLabels_RemovesStaleRestoreLabelsWhenCheckpointNotReady(t *testi
}) })
assert.Equal(t, "keep", labels["user-label"]) assert.Equal(t, "keep", labels["user-label"])
assert.Equal(t, "keep-too", labels["extra-label"]) assert.Equal(t, "keep-too", labels["extra-label"])
_, hasRestoreTarget := labels[commonconsts.KubeLabelIsRestoreTarget] _, hasRestoreTarget := labels[snapshotprotocol.RestoreTargetLabel]
_, hasCheckpointHash := labels[commonconsts.KubeLabelCheckpointID] _, hasCheckpointHash := labels[snapshotprotocol.CheckpointIDLabel]
assert.False(t, hasRestoreTarget) assert.False(t, hasRestoreTarget)
assert.False(t, hasCheckpointHash) assert.False(t, hasCheckpointHash)
} }
...@@ -6895,11 +6896,11 @@ func TestGenerateLabels_OverwritesStaleRestoreLabelsWhenCheckpointReady(t *testi ...@@ -6895,11 +6896,11 @@ func TestGenerateLabels_OverwritesStaleRestoreLabelsWhenCheckpointReady(t *testi
ComponentType: commonconsts.ComponentTypeWorker, ComponentType: commonconsts.ComponentTypeWorker,
DynamoNamespace: ptr.To("default-test-dgd"), DynamoNamespace: ptr.To("default-test-dgd"),
Labels: map[string]string{ Labels: map[string]string{
commonconsts.KubeLabelIsRestoreTarget: "false", snapshotprotocol.RestoreTargetLabel: "false",
}, },
ExtraPodMetadata: &v1alpha1.ExtraPodMetadata{ ExtraPodMetadata: &v1alpha1.ExtraPodMetadata{
Labels: map[string]string{ Labels: map[string]string{
commonconsts.KubeLabelCheckpointID: "stale-hash", snapshotprotocol.CheckpointIDLabel: "stale-hash",
}, },
}, },
}, },
...@@ -6915,8 +6916,8 @@ func TestGenerateLabels_OverwritesStaleRestoreLabelsWhenCheckpointReady(t *testi ...@@ -6915,8 +6916,8 @@ func TestGenerateLabels_OverwritesStaleRestoreLabelsWhenCheckpointReady(t *testi
Ready: true, Ready: true,
Hash: "resolved-hash", Hash: "resolved-hash",
}) })
assert.Equal(t, commonconsts.KubeLabelValueTrue, labels[commonconsts.KubeLabelIsRestoreTarget]) assert.Equal(t, commonconsts.KubeLabelValueTrue, labels[snapshotprotocol.RestoreTargetLabel])
assert.Equal(t, "resolved-hash", labels[commonconsts.KubeLabelCheckpointID]) assert.Equal(t, "resolved-hash", labels[snapshotprotocol.CheckpointIDLabel])
} }
func TestGenerateLabels_ReassertsRestoreIdentityLabelsAfterMetadataMerge(t *testing.T) { func TestGenerateLabels_ReassertsRestoreIdentityLabelsAfterMetadataMerge(t *testing.T) {
......
...@@ -23,6 +23,25 @@ type CheckpointJobOptions struct { ...@@ -23,6 +23,25 @@ type CheckpointJobOptions struct {
WrapLaunchJob bool WrapLaunchJob bool
} }
// CheckpointObservationPhase is the coarse state that ObserveCheckpointJob
// reports for a checkpoint job.
type CheckpointObservationPhase string
// Phases that ObserveCheckpointJob can report.
const (
// CheckpointObservationPhaseRunning is reported when the job has neither a
// true Complete nor a true Failed condition and no failed status annotation.
CheckpointObservationPhaseRunning CheckpointObservationPhase = "running"
// CheckpointObservationPhaseWaitingForConfirmation is reported when the job
// is complete but carries no completed status annotation yet, while the
// checkpoint worker is still reported as active.
CheckpointObservationPhaseWaitingForConfirmation CheckpointObservationPhase = "waiting_for_confirmation"
// CheckpointObservationPhaseReady is reported when the job is complete and
// its status annotation says the checkpoint completed.
CheckpointObservationPhaseReady CheckpointObservationPhase = "ready"
// CheckpointObservationPhaseFailed is reported when the status annotation
// says the checkpoint failed, when the job failed, or when the job completed
// without confirmation and the checkpoint worker is no longer active.
CheckpointObservationPhaseFailed CheckpointObservationPhase = "failed"
)
// CheckpointObservation is the result of ObserveCheckpointJob: the observed
// phase plus an optional reason/message pair describing it.
type CheckpointObservation struct {
// Phase is the observed state of the checkpoint job.
Phase CheckpointObservationPhase
// Reason is a short CamelCase identifier (for example "JobSucceeded" or
// "JobFailed"). ObserveCheckpointJob leaves it empty for the running and
// waiting-for-confirmation phases.
Reason string
// Message is a human-readable sentence accompanying Reason; it is empty
// whenever Reason is empty.
Message string
}
// GetCheckpointJobName returns the name of the checkpoint job for the given
// checkpoint ID and artifact version, in the form
// "checkpoint-job-<checkpointID>-<version>".
//
// The version is passed through ArtifactVersion before it is appended; the
// package tests expect an empty artifactVersion to resolve to
// DefaultCheckpointArtifactVersion.
func GetCheckpointJobName(checkpointID string, artifactVersion string) string {
	const prefix = "checkpoint-job-"

	version := ArtifactVersion(artifactVersion)
	return prefix + checkpointID + "-" + version
}
func NewCheckpointJob(podTemplate *corev1.PodTemplateSpec, opts CheckpointJobOptions) (*batchv1.Job, error) { func NewCheckpointJob(podTemplate *corev1.PodTemplateSpec, opts CheckpointJobOptions) (*batchv1.Job, error) {
podTemplate = podTemplate.DeepCopy() podTemplate = podTemplate.DeepCopy()
if podTemplate.Labels == nil { if podTemplate.Labels == nil {
...@@ -67,6 +86,65 @@ func NewCheckpointJob(podTemplate *corev1.PodTemplateSpec, opts CheckpointJobOpt ...@@ -67,6 +86,65 @@ func NewCheckpointJob(podTemplate *corev1.PodTemplateSpec, opts CheckpointJobOpt
}, nil }, nil
} }
// ObserveCheckpointJob derives a CheckpointObservation from a checkpoint job's
// conditions and its CheckpointStatusAnnotation value.
//
// Only conditions whose status is True are considered. checkpointWorkerActive
// tells the function whether the checkpoint worker is still active; it only
// matters when the job is complete but has no completed status annotation.
//
// The cases are evaluated in order, and the first match wins:
//   - a failed status annotation always yields the failed phase (with a
//     verification-failure reason if the job itself completed);
//   - a complete job is ready when the annotation says completed, waits for
//     confirmation while the worker is active, and otherwise fails
//     verification;
//   - a failed job yields the failed phase;
//   - anything else is still running.
//
// job must be non-nil.
func ObserveCheckpointJob(job *batchv1.Job, checkpointWorkerActive bool) CheckpointObservation {
	var completed, failed bool
	for _, cond := range job.Status.Conditions {
		if cond.Status != corev1.ConditionTrue {
			continue
		}
		switch cond.Type {
		case batchv1.JobComplete:
			completed = true
		case batchv1.JobFailed:
			failed = true
		}
	}

	reported := job.Annotations[CheckpointStatusAnnotation]

	switch {
	case reported == CheckpointStatusFailed && completed:
		// The job exited cleanly but the reported checkpoint status is failed.
		return CheckpointObservation{
			Phase:   CheckpointObservationPhaseFailed,
			Reason:  "CheckpointVerificationFailed",
			Message: "Checkpoint job completed but snapshot-agent reported checkpoint failure",
		}
	case reported == CheckpointStatusFailed:
		return CheckpointObservation{
			Phase:   CheckpointObservationPhaseFailed,
			Reason:  "JobFailed",
			Message: "Checkpoint job failed",
		}
	case completed && reported == CheckpointStatusCompleted:
		return CheckpointObservation{
			Phase:   CheckpointObservationPhaseReady,
			Reason:  "JobSucceeded",
			Message: "Checkpoint job completed successfully",
		}
	case completed && checkpointWorkerActive:
		// Confirmation may still arrive while the worker is active.
		return CheckpointObservation{Phase: CheckpointObservationPhaseWaitingForConfirmation}
	case completed:
		return CheckpointObservation{
			Phase:   CheckpointObservationPhaseFailed,
			Reason:  "CheckpointVerificationFailed",
			Message: "Checkpoint job completed without snapshot-agent completion confirmation",
		}
	case failed:
		return CheckpointObservation{
			Phase:   CheckpointObservationPhaseFailed,
			Reason:  "JobFailed",
			Message: "Checkpoint job failed",
		}
	default:
		return CheckpointObservation{Phase: CheckpointObservationPhaseRunning}
	}
}
func EnsureLocalhostSeccompProfile(podSpec *corev1.PodSpec, profile string) { func EnsureLocalhostSeccompProfile(podSpec *corev1.PodSpec, profile string) {
if podSpec.SecurityContext == nil { if podSpec.SecurityContext == nil {
podSpec.SecurityContext = &corev1.PodSecurityContext{} podSpec.SecurityContext = &corev1.PodSecurityContext{}
......
...@@ -86,3 +86,15 @@ func TestNewCheckpointJob(t *testing.T) { ...@@ -86,3 +86,15 @@ func TestNewCheckpointJob(t *testing.T) {
t.Fatalf("unexpected ttlSecondsAfterFinished: %#v", job.Spec.TTLSecondsAfterFinished) t.Fatalf("unexpected ttlSecondsAfterFinished: %#v", job.Spec.TTLSecondsAfterFinished)
} }
} }
// TestGetCheckpointJobName checks the job-name format for an explicit artifact
// version and for an empty version, which must fall back to
// DefaultCheckpointArtifactVersion.
func TestGetCheckpointJobName(t *testing.T) {
	// An explicit version is appended verbatim.
	if got, want := GetCheckpointJobName("abc123def4567890", "2"), "checkpoint-job-abc123def4567890-2"; got != want {
		t.Fatalf("unexpected checkpoint job name: %s", got)
	}

	// An empty version resolves to the package default.
	wantDefault := "checkpoint-job-abc123def4567890-" + DefaultCheckpointArtifactVersion
	if got := GetCheckpointJobName("abc123def4567890", ""); got != wantDefault {
		t.Fatalf("unexpected default checkpoint job name: %s", got)
	}
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
package checkpointjob package protocol
import ( import (
"testing" "testing"
snapshotprotocol "github.com/ai-dynamo/dynamo/deploy/snapshot/protocol"
batchv1 "k8s.io/api/batch/v1" batchv1 "k8s.io/api/batch/v1"
corev1 "k8s.io/api/core/v1" corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
) )
func TestObserve(t *testing.T) { func TestObserveCheckpointJob(t *testing.T) {
makeJob := func(annotation string, conditions ...batchv1.JobCondition) *batchv1.Job { makeJob := func(annotation string, conditions ...batchv1.JobCondition) *batchv1.Job {
job := &batchv1.Job{ job := &batchv1.Job{
ObjectMeta: metav1.ObjectMeta{ ObjectMeta: metav1.ObjectMeta{
...@@ -23,7 +22,7 @@ func TestObserve(t *testing.T) { ...@@ -23,7 +22,7 @@ func TestObserve(t *testing.T) {
}, },
} }
if annotation != "" { if annotation != "" {
job.Annotations[snapshotprotocol.CheckpointStatusAnnotation] = annotation job.Annotations[CheckpointStatusAnnotation] = annotation
} }
return job return job
} }
...@@ -32,22 +31,22 @@ func TestObserve(t *testing.T) { ...@@ -32,22 +31,22 @@ func TestObserve(t *testing.T) {
name string name string
job *batchv1.Job job *batchv1.Job
checkpointWorkerActive bool checkpointWorkerActive bool
wantPhase ObservationPhase wantPhase CheckpointObservationPhase
wantReason string wantReason string
wantMessage string wantMessage string
}{ }{
{ {
name: "running job stays running", name: "running job stays running",
job: makeJob(""), job: makeJob(""),
wantPhase: ObservationPhaseRunning, wantPhase: CheckpointObservationPhaseRunning,
}, },
{ {
name: "completed job with completion annotation is ready", name: "completed job with completion annotation is ready",
job: makeJob( job: makeJob(
snapshotprotocol.CheckpointStatusCompleted, CheckpointStatusCompleted,
batchv1.JobCondition{Type: batchv1.JobComplete, Status: corev1.ConditionTrue}, batchv1.JobCondition{Type: batchv1.JobComplete, Status: corev1.ConditionTrue},
), ),
wantPhase: ObservationPhaseReady, wantPhase: CheckpointObservationPhaseReady,
wantReason: "JobSucceeded", wantReason: "JobSucceeded",
wantMessage: "Checkpoint job completed successfully", wantMessage: "Checkpoint job completed successfully",
}, },
...@@ -58,7 +57,7 @@ func TestObserve(t *testing.T) { ...@@ -58,7 +57,7 @@ func TestObserve(t *testing.T) {
batchv1.JobCondition{Type: batchv1.JobComplete, Status: corev1.ConditionTrue}, batchv1.JobCondition{Type: batchv1.JobComplete, Status: corev1.ConditionTrue},
), ),
checkpointWorkerActive: true, checkpointWorkerActive: true,
wantPhase: ObservationPhaseWaitingForConfirmation, wantPhase: CheckpointObservationPhaseWaitingForConfirmation,
}, },
{ {
name: "completed job fails without confirmation once worker is inactive", name: "completed job fails without confirmation once worker is inactive",
...@@ -66,18 +65,18 @@ func TestObserve(t *testing.T) { ...@@ -66,18 +65,18 @@ func TestObserve(t *testing.T) {
"", "",
batchv1.JobCondition{Type: batchv1.JobComplete, Status: corev1.ConditionTrue}, batchv1.JobCondition{Type: batchv1.JobComplete, Status: corev1.ConditionTrue},
), ),
wantPhase: ObservationPhaseFailed, wantPhase: CheckpointObservationPhaseFailed,
wantReason: "CheckpointVerificationFailed", wantReason: "CheckpointVerificationFailed",
wantMessage: "Checkpoint job completed without snapshot-agent completion confirmation", wantMessage: "Checkpoint job completed without snapshot-agent completion confirmation",
}, },
{ {
name: "failed checkpoint annotation wins over completed job", name: "failed checkpoint annotation wins over completed job",
job: makeJob( job: makeJob(
snapshotprotocol.CheckpointStatusFailed, CheckpointStatusFailed,
batchv1.JobCondition{Type: batchv1.JobComplete, Status: corev1.ConditionTrue}, batchv1.JobCondition{Type: batchv1.JobComplete, Status: corev1.ConditionTrue},
), ),
checkpointWorkerActive: true, checkpointWorkerActive: true,
wantPhase: ObservationPhaseFailed, wantPhase: CheckpointObservationPhaseFailed,
wantReason: "CheckpointVerificationFailed", wantReason: "CheckpointVerificationFailed",
wantMessage: "Checkpoint job completed but snapshot-agent reported checkpoint failure", wantMessage: "Checkpoint job completed but snapshot-agent reported checkpoint failure",
}, },
...@@ -85,7 +84,7 @@ func TestObserve(t *testing.T) { ...@@ -85,7 +84,7 @@ func TestObserve(t *testing.T) {
for _, tc := range tests { for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
observation := Observe(tc.job, tc.checkpointWorkerActive) observation := ObserveCheckpointJob(tc.job, tc.checkpointWorkerActive)
if observation.Phase != tc.wantPhase { if observation.Phase != tc.wantPhase {
t.Fatalf("phase = %q, want %q", observation.Phase, tc.wantPhase) t.Fatalf("phase = %q, want %q", observation.Phase, tc.wantPhase)
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment