/*
 * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package controller

import (
	"context"
	"fmt"
	"time"

	batchv1 "k8s.io/api/batch/v1"
	coordinationv1 "k8s.io/api/coordination/v1"
	corev1 "k8s.io/api/core/v1"
	apierrors "k8s.io/apimachinery/pkg/api/errors"
	"k8s.io/apimachinery/pkg/api/meta"
	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
	"k8s.io/client-go/tools/record"
	ctrl "sigs.k8s.io/controller-runtime"
	"sigs.k8s.io/controller-runtime/pkg/builder"
	"sigs.k8s.io/controller-runtime/pkg/client"
	"sigs.k8s.io/controller-runtime/pkg/event"
	"sigs.k8s.io/controller-runtime/pkg/log"
	"sigs.k8s.io/controller-runtime/pkg/predicate"

	configv1alpha1 "github.com/ai-dynamo/dynamo/deploy/operator/api/config/v1alpha1"
	nvidiacomv1alpha1 "github.com/ai-dynamo/dynamo/deploy/operator/api/v1alpha1"
	"github.com/ai-dynamo/dynamo/deploy/operator/internal/checkpoint"
	commonController "github.com/ai-dynamo/dynamo/deploy/operator/internal/controller_common"
	snapshotprotocol "github.com/ai-dynamo/dynamo/deploy/snapshot/protocol"
)

// CheckpointReconciler reconciles a DynamoCheckpoint object
type CheckpointReconciler struct {
	client.Client
49
50
51
	Config        *configv1alpha1.OperatorConfiguration
	RuntimeConfig *commonController.RuntimeConfig
	Recorder      record.EventRecorder
52
53
54
55
56
57
58
59
60
61
62
}

// GetRecorder exposes the reconciler's event recorder so that it satisfies the
// controller_common.Reconciler interface.
func (r *CheckpointReconciler) GetRecorder() record.EventRecorder {
	recorder := r.Recorder
	return recorder
}

// +kubebuilder:rbac:groups=nvidia.com,resources=dynamocheckpoints,verbs=get;list;watch;create;update;patch;delete
// +kubebuilder:rbac:groups=nvidia.com,resources=dynamocheckpoints/status,verbs=get;update;patch
// +kubebuilder:rbac:groups=nvidia.com,resources=dynamocheckpoints/finalizers,verbs=update
// +kubebuilder:rbac:groups=batch,resources=jobs,verbs=get;list;watch;create;update;patch;delete
63
// +kubebuilder:rbac:groups=coordination.k8s.io,resources=leases,verbs=get;list;watch
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78

func (r *CheckpointReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) {
	logger := log.FromContext(ctx)

	// Fetch the DynamoCheckpoint instance
	ckpt := &nvidiacomv1alpha1.DynamoCheckpoint{}
	if err := r.Get(ctx, req.NamespacedName, ckpt); err != nil {
		if apierrors.IsNotFound(err) {
			return ctrl.Result{}, nil
		}
		return ctrl.Result{}, err
	}

	logger.Info("Reconciling DynamoCheckpoint", "name", ckpt.Name, "phase", ckpt.Status.Phase)

79
80
81
82
83
84
85
86
87
	identityHash, err := checkpoint.ComputeIdentityHash(ckpt.Spec.Identity)
	if err != nil {
		logger.Error(err, "Failed to compute checkpoint identity hash")
		return ctrl.Result{}, fmt.Errorf("failed to compute checkpoint identity hash: %w", err)
	}

	if ckpt.Labels == nil {
		ckpt.Labels = map[string]string{}
	}
88
89
	if ckpt.Labels[snapshotprotocol.CheckpointIDLabel] != identityHash {
		ckpt.Labels[snapshotprotocol.CheckpointIDLabel] = identityHash
90
91
92
93
94
		if err := r.Update(ctx, ckpt); err != nil {
			return ctrl.Result{}, err
		}
		if err := r.Get(ctx, req.NamespacedName, ckpt); err != nil {
			return ctrl.Result{}, err
95
		}
96
	}
97

98
99
100
101
102
103
	needsStatusUpdate := false
	phaseWasEmpty := ckpt.Status.Phase == ""
	if ckpt.Status.IdentityHash != identityHash {
		ckpt.Status.IdentityHash = identityHash
		needsStatusUpdate = true
	}
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
	existing, err := checkpoint.FindCheckpointByIdentityHash(ctx, r.Client, ckpt.Namespace, identityHash, ckpt.Name)
	if err != nil {
		return ctrl.Result{}, err
	}
	if existing != nil {
		ckpt.Status.Phase = nvidiacomv1alpha1.DynamoCheckpointPhaseFailed
		ckpt.Status.JobName = ""
		ckpt.Status.CreatedAt = nil
		ckpt.Status.Message = fmt.Sprintf("checkpoint identity hash %s is already owned by %s", identityHash, existing.Name)
		if err := r.Status().Update(ctx, ckpt); err != nil {
			logger.Error(err, "Failed to mark duplicate DynamoCheckpoint as failed")
			return ctrl.Result{}, err
		}
		return ctrl.Result{}, nil
	}
119
120
121
122
	desiredJobName := snapshotprotocol.GetCheckpointJobName(
		identityHash,
		ckpt.Annotations[snapshotprotocol.CheckpointArtifactVersionAnnotation],
	)
123
124
125
	switch ckpt.Status.Phase {
	case "", nvidiacomv1alpha1.DynamoCheckpointPhasePending, nvidiacomv1alpha1.DynamoCheckpointPhaseCreating, nvidiacomv1alpha1.DynamoCheckpointPhaseReady, nvidiacomv1alpha1.DynamoCheckpointPhaseFailed:
	default:
126
		ckpt.Status.Phase = nvidiacomv1alpha1.DynamoCheckpointPhasePending
127
128
129
130
131
132
133
134
		ckpt.Status.Message = ""
		needsStatusUpdate = true
	}
	if ckpt.Status.Phase == "" {
		ckpt.Status.Phase = nvidiacomv1alpha1.DynamoCheckpointPhasePending
		ckpt.Status.Message = ""
		needsStatusUpdate = true
	}
135
136
137
138
139
140
141
142
143
	if ckpt.Status.Phase != nvidiacomv1alpha1.DynamoCheckpointPhaseCreating &&
		ckpt.Status.JobName != "" &&
		ckpt.Status.JobName != desiredJobName {
		ckpt.Status.Phase = nvidiacomv1alpha1.DynamoCheckpointPhasePending
		ckpt.Status.JobName = ""
		ckpt.Status.CreatedAt = nil
		ckpt.Status.Message = ""
		needsStatusUpdate = true
	}
144
	if needsStatusUpdate {
145
		if err := r.Status().Update(ctx, ckpt); err != nil {
146
			logger.Error(err, "Failed to initialize DynamoCheckpoint status")
147
148
			return ctrl.Result{}, err
		}
149
150
151
		if phaseWasEmpty {
			return ctrl.Result{}, nil
		}
152
153
154
155
156
157
158
159
160
161
162
163
	}

	// Handle based on current phase
	switch ckpt.Status.Phase {
	case nvidiacomv1alpha1.DynamoCheckpointPhasePending:
		return r.handlePending(ctx, ckpt)
	case nvidiacomv1alpha1.DynamoCheckpointPhaseCreating:
		return r.handleCreating(ctx, ckpt)
	case nvidiacomv1alpha1.DynamoCheckpointPhaseReady:
		// Nothing to do, checkpoint is ready
		return ctrl.Result{}, nil
	case nvidiacomv1alpha1.DynamoCheckpointPhaseFailed:
164
		return ctrl.Result{}, nil
165
166
167
168
169
170
171
172
173
174
175
176
177
	default:
		// Unknown phase, reset to Pending
		ckpt.Status.Phase = nvidiacomv1alpha1.DynamoCheckpointPhasePending
		if err := r.Status().Update(ctx, ckpt); err != nil {
			return ctrl.Result{}, err
		}
		return ctrl.Result{}, nil
	}
}

// handlePending creates (or updates) the checkpoint Job for a Pending
// DynamoCheckpoint and then moves the object to the Creating phase.
//
// The Job name is derived from the identity hash and the artifact-version
// annotation. The hash is taken from Status.IdentityHash when present and
// recomputed from the spec otherwise. On success the status records the Job
// name, clears CreatedAt and Message, and sets the JobCreated condition to True.
//
// It returns an error if the hash cannot be computed, the Job cannot be
// synced, or the status update fails.
func (r *CheckpointReconciler) handlePending(ctx context.Context, ckpt *nvidiacomv1alpha1.DynamoCheckpoint) (ctrl.Result, error) {
	logger := log.FromContext(ctx)

	hash := ckpt.Status.IdentityHash
	if hash == "" {
		// Status has not recorded the hash yet; derive it from the spec.
		var err error
		hash, err = checkpoint.ComputeIdentityHash(ckpt.Spec.Identity)
		if err != nil {
			return ctrl.Result{}, fmt.Errorf("failed to compute checkpoint identity hash: %w", err)
		}
	}

	jobName := snapshotprotocol.GetCheckpointJobName(
		hash,
		ckpt.Annotations[snapshotprotocol.CheckpointArtifactVersionAnnotation],
	)

	// Use SyncResource to create/update the checkpoint Job. The generator
	// always returns false as its second result.
	modified, _, err := commonController.SyncResource(ctx, r, ckpt, func(ctx context.Context) (*batchv1.Job, bool, error) {
		job, err := buildCheckpointJob(r.Config, ckpt, jobName)
		return job, false, err
	})
	if err != nil {
		logger.Error(err, "Failed to sync checkpoint Job")
		return ctrl.Result{}, err
	}

	if modified {
		logger.Info("Created/updated checkpoint Job", "job", jobName)
	}

	// Update status to Creating phase.
	ckpt.Status.Phase = nvidiacomv1alpha1.DynamoCheckpointPhaseCreating
	ckpt.Status.JobName = jobName
	ckpt.Status.CreatedAt = nil
	ckpt.Status.Message = ""
	meta.SetStatusCondition(&ckpt.Status.Conditions, metav1.Condition{
		Type:               string(nvidiacomv1alpha1.DynamoCheckpointConditionJobCreated),
		Status:             metav1.ConditionTrue,
		Reason:             "JobCreated",
		Message:            fmt.Sprintf("Checkpoint job %s created", jobName),
		LastTransitionTime: metav1.Now(),
	})

	if err := r.Status().Update(ctx, ckpt); err != nil {
		return ctrl.Result{}, err
	}

	// Status update will trigger next reconcile via watch.
	return ctrl.Result{}, nil
}

func (r *CheckpointReconciler) handleCreating(ctx context.Context, ckpt *nvidiacomv1alpha1.DynamoCheckpoint) (ctrl.Result, error) {
	logger := log.FromContext(ctx)

229
230
231
232
233
234
235
236
237
	if ckpt.Status.JobName == "" {
		ckpt.Status.Phase = nvidiacomv1alpha1.DynamoCheckpointPhasePending
		ckpt.Status.Message = "checkpoint job is missing from status"
		if err := r.Status().Update(ctx, ckpt); err != nil {
			return ctrl.Result{}, err
		}
		return ctrl.Result{}, nil
	}

238
239
240
241
	// Check Job status
	job := &batchv1.Job{}
	if err := r.Get(ctx, client.ObjectKey{Namespace: ckpt.Namespace, Name: ckpt.Status.JobName}, job); err != nil {
		if apierrors.IsNotFound(err) {
242
			ckpt.Status.Phase = nvidiacomv1alpha1.DynamoCheckpointPhaseFailed
243
			ckpt.Status.Message = "checkpoint job was deleted"
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
			meta.SetStatusCondition(&ckpt.Status.Conditions, metav1.Condition{
				Type:               string(nvidiacomv1alpha1.DynamoCheckpointConditionJobCreated),
				Status:             metav1.ConditionFalse,
				Reason:             "JobDeleted",
				Message:            "Checkpoint job was deleted",
				LastTransitionTime: metav1.Now(),
			})
			if err := r.Status().Update(ctx, ckpt); err != nil {
				return ctrl.Result{}, err
			}
			return ctrl.Result{}, nil
		}
		return ctrl.Result{}, err
	}

259
260
261
262
263
264
	var lease *coordinationv1.Lease
	leaseKey := client.ObjectKey{Namespace: job.Namespace, Name: job.Name}
	lease = &coordinationv1.Lease{}
	if err := r.Get(ctx, leaseKey, lease); err != nil {
		if !apierrors.IsNotFound(err) {
			return ctrl.Result{}, err
265
		}
266
267
268
269
270
271
272
273
274
275
276
277
		lease = nil
	}

	now := time.Now()
	checkpointWorkerActive := false
	if lease != nil && lease.Spec.LeaseDurationSeconds != nil {
		// The snapshot-agent owns and renews this lease while it is still finalizing
		// checkpoint state. A Job can complete before the agent writes the terminal
		// checkpoint annotation, so we keep requeuing until the lease is no longer active.
		lastRenewal := lease.Spec.RenewTime
		if lastRenewal == nil {
			lastRenewal = lease.Spec.AcquireTime
278
		}
279
280
		if lastRenewal != nil {
			checkpointWorkerActive = !now.After(lastRenewal.Time.Add(time.Duration(*lease.Spec.LeaseDurationSeconds) * time.Second))
281
282
283
		}
	}

284
	observation := snapshotprotocol.ObserveCheckpointJob(job, checkpointWorkerActive)
285
	switch observation.Phase {
286
	case snapshotprotocol.CheckpointObservationPhaseWaitingForConfirmation:
287
288
		logger.V(1).Info("Checkpoint job is complete but checkpoint worker is still active; waiting for terminal watcher status", "job", job.Name)
		return ctrl.Result{RequeueAfter: time.Second}, nil
289
	case snapshotprotocol.CheckpointObservationPhaseReady:
290
		logger.Info("Checkpoint Job succeeded", "job", job.Name)
291
		r.Recorder.Event(ckpt, corev1.EventTypeNormal, "CheckpointReady", observation.Message)
292
293

		now := metav1.Now()
294
295
		ckpt.Status.Phase = nvidiacomv1alpha1.DynamoCheckpointPhaseReady
		ckpt.Status.CreatedAt = &now
296
		ckpt.Status.Message = ""
297
298
299
		meta.SetStatusCondition(&ckpt.Status.Conditions, metav1.Condition{
			Type:               string(nvidiacomv1alpha1.DynamoCheckpointConditionJobCompleted),
			Status:             metav1.ConditionTrue,
300
301
			Reason:             observation.Reason,
			Message:            observation.Message,
302
303
304
305
306
307
			LastTransitionTime: metav1.Now(),
		})
		if err := r.Status().Update(ctx, ckpt); err != nil {
			return ctrl.Result{}, err
		}
		return ctrl.Result{}, nil
308
	case snapshotprotocol.CheckpointObservationPhaseFailed:
309
310
		logger.Info("Checkpoint Job failed", "job", job.Name, "message", observation.Message)
		r.Recorder.Event(ckpt, corev1.EventTypeWarning, "CheckpointFailed", observation.Message)
311
312

		ckpt.Status.Phase = nvidiacomv1alpha1.DynamoCheckpointPhaseFailed
313
		ckpt.Status.Message = observation.Message
314
315
316
		meta.SetStatusCondition(&ckpt.Status.Conditions, metav1.Condition{
			Type:               string(nvidiacomv1alpha1.DynamoCheckpointConditionJobCompleted),
			Status:             metav1.ConditionFalse,
317
318
			Reason:             observation.Reason,
			Message:            observation.Message,
319
320
321
322
323
324
			LastTransitionTime: metav1.Now(),
		})
		if err := r.Status().Update(ctx, ckpt); err != nil {
			return ctrl.Result{}, err
		}
		return ctrl.Result{}, nil
325
326
	default:
		return ctrl.Result{}, nil
327
328
329
330
331
332
333
334
335
336
337
338
339
340
	}
}

// SetupWithManager sets up the controller with the Manager.
func (r *CheckpointReconciler) SetupWithManager(mgr ctrl.Manager) error {
	return ctrl.NewControllerManagedBy(mgr).
		For(&nvidiacomv1alpha1.DynamoCheckpoint{}).
		Owns(&batchv1.Job{}, builder.WithPredicates(predicate.Funcs{
			// Ignore creation - we don't need to reconcile when we just created the Job
			CreateFunc:  func(ce event.CreateEvent) bool { return false },
			DeleteFunc:  func(de event.DeleteEvent) bool { return true },
			UpdateFunc:  func(ue event.UpdateEvent) bool { return true },
			GenericFunc: func(ge event.GenericEvent) bool { return true },
		})).
341
		WithEventFilter(commonController.EphemeralDeploymentEventFilter(r.Config, r.RuntimeConfig)).
342
343
		Complete(r)
}