Unverified commit 0488e1b2, authored by Thomas Montfort, committed by GitHub
Browse files

fix(operator): prevent orphaned old worker DCDs after rolling update (#7939)


Signed-off-by: tmontfort <tmontfort@nvidia.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
parent 7f8e88fb
...@@ -210,10 +210,8 @@ func (r *DynamoGraphDeploymentReconciler) reconcileRollingUpdate( ...@@ -210,10 +210,8 @@ func (r *DynamoGraphDeploymentReconciler) reconcileRollingUpdate(
) error { ) error {
logger := log.FromContext(ctx) logger := log.FromContext(ctx)
// Get or create rollingUpdate status
rollingUpdateStatus := r.getOrCreateRollingUpdateStatus(dgd) rollingUpdateStatus := r.getOrCreateRollingUpdateStatus(dgd)
// Compute hash information
newWorkerHash := dynamo.ComputeDGDWorkersSpecHash(dgd) newWorkerHash := dynamo.ComputeDGDWorkersSpecHash(dgd)
prevWorkerHash := r.getCurrentWorkerHash(dgd) prevWorkerHash := r.getCurrentWorkerHash(dgd)
...@@ -247,25 +245,21 @@ func (r *DynamoGraphDeploymentReconciler) reconcileRollingUpdate( ...@@ -247,25 +245,21 @@ func (r *DynamoGraphDeploymentReconciler) reconcileRollingUpdate(
logger.Info("Detected stuck rolling update: hashes match but phase is InProgress", logger.Info("Detected stuck rolling update: hashes match but phase is InProgress",
"hash", newWorkerHash, "hash", newWorkerHash,
"phase", rollingUpdateStatus.Phase) "phase", rollingUpdateStatus.Phase)
return r.completeRollingUpdate(ctx, dgd, rollingUpdateStatus, newWorkerHash) return r.completeRollingUpdate(ctx, dgd, newWorkerHash)
} }
switch rollingUpdateStatus.Phase { switch rollingUpdateStatus.Phase {
case nvidiacomv1alpha1.RollingUpdatePhaseNone: case nvidiacomv1alpha1.RollingUpdatePhaseNone:
return r.startRollingUpdate(ctx, dgd, rollingUpdateStatus, newWorkerHash) return r.startRollingUpdate(ctx, dgd, newWorkerHash)
case nvidiacomv1alpha1.RollingUpdatePhasePending: case nvidiacomv1alpha1.RollingUpdatePhasePending:
rollingUpdateStatus.Phase = nvidiacomv1alpha1.RollingUpdatePhaseInProgress rollingUpdateStatus.Phase = nvidiacomv1alpha1.RollingUpdatePhaseInProgress
if err := r.Status().Update(ctx, dgd); err != nil { return nil // deferred function in Reconcile() persists status
return fmt.Errorf("failed to update rolling update status to InProgress: %w", err)
}
return nil
case nvidiacomv1alpha1.RollingUpdatePhaseInProgress: case nvidiacomv1alpha1.RollingUpdatePhaseInProgress:
return r.continueRollingUpdate(ctx, dgd, rollingUpdateStatus, newWorkerHash) return r.continueRollingUpdate(ctx, dgd, newWorkerHash)
case nvidiacomv1alpha1.RollingUpdatePhaseCompleted: case nvidiacomv1alpha1.RollingUpdatePhaseCompleted:
// Cleanup is now done atomically in completeRollingUpdate, nothing to do here
logger.Info("Rolling update already completed") logger.Info("Rolling update already completed")
return nil return nil
} }
...@@ -277,7 +271,6 @@ func (r *DynamoGraphDeploymentReconciler) reconcileRollingUpdate( ...@@ -277,7 +271,6 @@ func (r *DynamoGraphDeploymentReconciler) reconcileRollingUpdate(
func (r *DynamoGraphDeploymentReconciler) startRollingUpdate( func (r *DynamoGraphDeploymentReconciler) startRollingUpdate(
ctx context.Context, ctx context.Context,
dgd *nvidiacomv1alpha1.DynamoGraphDeployment, dgd *nvidiacomv1alpha1.DynamoGraphDeployment,
rollingUpdateStatus *nvidiacomv1alpha1.RollingUpdateStatus,
newWorkerHash string, newWorkerHash string,
) error { ) error {
logger := log.FromContext(ctx) logger := log.FromContext(ctx)
...@@ -289,6 +282,7 @@ func (r *DynamoGraphDeploymentReconciler) startRollingUpdate( ...@@ -289,6 +282,7 @@ func (r *DynamoGraphDeploymentReconciler) startRollingUpdate(
"newHash", newWorkerHash) "newHash", newWorkerHash)
now := metav1.Now() now := metav1.Now()
rollingUpdateStatus := r.getOrCreateRollingUpdateStatus(dgd)
rollingUpdateStatus.Phase = nvidiacomv1alpha1.RollingUpdatePhasePending rollingUpdateStatus.Phase = nvidiacomv1alpha1.RollingUpdatePhasePending
rollingUpdateStatus.StartTime = &now rollingUpdateStatus.StartTime = &now
rollingUpdateStatus.UpdatedServices = nil rollingUpdateStatus.UpdatedServices = nil
...@@ -296,18 +290,13 @@ func (r *DynamoGraphDeploymentReconciler) startRollingUpdate( ...@@ -296,18 +290,13 @@ func (r *DynamoGraphDeploymentReconciler) startRollingUpdate(
r.Recorder.Eventf(dgd, corev1.EventTypeNormal, "RollingUpdateStarted", r.Recorder.Eventf(dgd, corev1.EventTypeNormal, "RollingUpdateStarted",
"Starting rolling update from worker hash %s to %s", prevWorkerHash, newWorkerHash) "Starting rolling update from worker hash %s to %s", prevWorkerHash, newWorkerHash)
if err := r.Status().Update(ctx, dgd); err != nil { return nil // deferred function in Reconcile() persists status
return fmt.Errorf("failed to initialize rolling update status: %w", err)
}
return nil
} }
// continueRollingUpdate handles the in-progress phase of a rolling update. // continueRollingUpdate handles the in-progress phase of a rolling update.
func (r *DynamoGraphDeploymentReconciler) continueRollingUpdate( func (r *DynamoGraphDeploymentReconciler) continueRollingUpdate(
ctx context.Context, ctx context.Context,
dgd *nvidiacomv1alpha1.DynamoGraphDeployment, dgd *nvidiacomv1alpha1.DynamoGraphDeployment,
rollingUpdateStatus *nvidiacomv1alpha1.RollingUpdateStatus,
newWorkerHash string, newWorkerHash string,
) error { ) error {
logger := log.FromContext(ctx) logger := log.FromContext(ctx)
...@@ -355,6 +344,7 @@ func (r *DynamoGraphDeploymentReconciler) continueRollingUpdate( ...@@ -355,6 +344,7 @@ func (r *DynamoGraphDeploymentReconciler) continueRollingUpdate(
} }
} }
sort.Strings(updatedServices) sort.Strings(updatedServices)
rollingUpdateStatus := r.getOrCreateRollingUpdateStatus(dgd)
rollingUpdateStatus.UpdatedServices = updatedServices rollingUpdateStatus.UpdatedServices = updatedServices
// Count total worker services // Count total worker services
...@@ -367,15 +357,10 @@ func (r *DynamoGraphDeploymentReconciler) continueRollingUpdate( ...@@ -367,15 +357,10 @@ func (r *DynamoGraphDeploymentReconciler) continueRollingUpdate(
// Rolling update is complete when every worker service is individually updated // Rolling update is complete when every worker service is individually updated
if len(updatedServices) == totalWorkerServices && totalWorkerServices > 0 { if len(updatedServices) == totalWorkerServices && totalWorkerServices > 0 {
return r.completeRollingUpdate(ctx, dgd, rollingUpdateStatus, newWorkerHash) return r.completeRollingUpdate(ctx, dgd, newWorkerHash)
}
// Persist updated services list mid-rolling update
if err := r.Status().Update(ctx, dgd); err != nil {
return fmt.Errorf("failed to update rolling update status with updated services: %w", err)
} }
return nil return nil // deferred function in Reconcile() persists UpdatedServices
} }
// completeRollingUpdate marks the rolling update as completed, cleans up old resources, and updates status. // completeRollingUpdate marks the rolling update as completed, cleans up old resources, and updates status.
...@@ -383,22 +368,21 @@ func (r *DynamoGraphDeploymentReconciler) continueRollingUpdate( ...@@ -383,22 +368,21 @@ func (r *DynamoGraphDeploymentReconciler) continueRollingUpdate(
func (r *DynamoGraphDeploymentReconciler) completeRollingUpdate( func (r *DynamoGraphDeploymentReconciler) completeRollingUpdate(
ctx context.Context, ctx context.Context,
dgd *nvidiacomv1alpha1.DynamoGraphDeployment, dgd *nvidiacomv1alpha1.DynamoGraphDeployment,
rollingUpdateStatus *nvidiacomv1alpha1.RollingUpdateStatus,
newWorkerHash string, newWorkerHash string,
) error { ) error {
logger := log.FromContext(ctx) logger := log.FromContext(ctx)
// Delete all non-current worker DCDs (any number of old generations) // Delete all non-current worker DCDs (any number of old generations)
if err := r.deleteOldWorkerDCDs(ctx, dgd, newWorkerHash); err != nil { if err := r.deleteOldWorkerDCDs(ctx, dgd, newWorkerHash); err != nil {
logger.Error(err, "Failed to delete non-current worker DCDs", "newWorkerHash", newWorkerHash) return fmt.Errorf("failed to delete old worker DCDs: %w", err)
r.Recorder.Eventf(dgd, corev1.EventTypeWarning, "CleanupPartialFailure", }
"Failed to delete some old worker DCDs: %v", err)
// Continue anyway - we don't want cleanup failures to block the rolling update completion r.setCurrentWorkerHash(dgd, newWorkerHash)
} else { if err := r.Update(ctx, dgd); err != nil {
logger.Info("Old resources cleaned up", "newWorkerHash", newWorkerHash) return fmt.Errorf("failed to update current worker hash: %w", err)
} }
// Update rolling update status to Completed rollingUpdateStatus := r.getOrCreateRollingUpdateStatus(dgd)
rollingUpdateStatus.Phase = nvidiacomv1alpha1.RollingUpdatePhaseCompleted rollingUpdateStatus.Phase = nvidiacomv1alpha1.RollingUpdatePhaseCompleted
now := metav1.Now() now := metav1.Now()
rollingUpdateStatus.EndTime = &now rollingUpdateStatus.EndTime = &now
...@@ -416,16 +400,6 @@ func (r *DynamoGraphDeploymentReconciler) completeRollingUpdate( ...@@ -416,16 +400,6 @@ func (r *DynamoGraphDeploymentReconciler) completeRollingUpdate(
r.Recorder.Eventf(dgd, corev1.EventTypeNormal, "RollingUpdateCompleted", r.Recorder.Eventf(dgd, corev1.EventTypeNormal, "RollingUpdateCompleted",
"Rolling update completed, worker hash %s", newWorkerHash) "Rolling update completed, worker hash %s", newWorkerHash)
if err := r.Status().Update(ctx, dgd); err != nil {
return fmt.Errorf("failed to update rolling update status: %w", err)
}
// Update the current worker hash to the new hash
r.setCurrentWorkerHash(dgd, newWorkerHash)
if err := r.Update(ctx, dgd); err != nil {
return fmt.Errorf("failed to update current worker hash: %w", err)
}
logger.Info("Rolling update finalized", "newWorkerHash", newWorkerHash) logger.Info("Rolling update finalized", "newWorkerHash", newWorkerHash)
return nil return nil
......
...@@ -653,7 +653,7 @@ func TestContinueRollingUpdate_UpdatedServicesPartialCompletion(t *testing.T) { ...@@ -653,7 +653,7 @@ func TestContinueRollingUpdate_UpdatedServicesPartialCompletion(t *testing.T) {
ctx := context.Background() ctx := context.Background()
rollingUpdateStatus := dgd.Status.RollingUpdate rollingUpdateStatus := dgd.Status.RollingUpdate
err := r.continueRollingUpdate(ctx, dgd, rollingUpdateStatus, newWorkerHash) err := r.continueRollingUpdate(ctx, dgd, newWorkerHash)
require.NoError(t, err) require.NoError(t, err)
// Prefill is updated (new ready >= desired, old gone), decode is not // Prefill is updated (new ready >= desired, old gone), decode is not
...@@ -736,8 +736,8 @@ func TestContinueRollingUpdate_AggregateReadyButPerServiceNot(t *testing.T) { ...@@ -736,8 +736,8 @@ func TestContinueRollingUpdate_AggregateReadyButPerServiceNot(t *testing.T) {
r := createTestReconcilerWithStatus(dgd, newPrefillDCD, newDecodeDCD) r := createTestReconcilerWithStatus(dgd, newPrefillDCD, newDecodeDCD)
ctx := context.Background() ctx := context.Background()
rollingUpdateStatus := dgd.Status.RollingUpdate rollingUpdateStatus := r.getOrCreateRollingUpdateStatus(dgd)
err := r.continueRollingUpdate(ctx, dgd, rollingUpdateStatus, newWorkerHash) err := r.continueRollingUpdate(ctx, dgd, newWorkerHash)
require.NoError(t, err) require.NoError(t, err)
// Only prefill is updated; decode has 0 ready replicas // Only prefill is updated; decode has 0 ready replicas
...@@ -765,10 +765,10 @@ func TestStartRollingUpdate_UpdatedServicesInitializedToNil(t *testing.T) { ...@@ -765,10 +765,10 @@ func TestStartRollingUpdate_UpdatedServicesInitializedToNil(t *testing.T) {
r := createTestReconcilerWithStatus(dgd) r := createTestReconcilerWithStatus(dgd)
ctx := context.Background() ctx := context.Background()
rollingUpdateStatus := dgd.Status.RollingUpdate err := r.startRollingUpdate(ctx, dgd, testNewWorkerHash)
err := r.startRollingUpdate(ctx, dgd, rollingUpdateStatus, testNewWorkerHash)
require.NoError(t, err) require.NoError(t, err)
rollingUpdateStatus := r.getOrCreateRollingUpdateStatus(dgd)
assert.Nil(t, rollingUpdateStatus.UpdatedServices) assert.Nil(t, rollingUpdateStatus.UpdatedServices)
assert.Equal(t, nvidiacomv1alpha1.RollingUpdatePhasePending, rollingUpdateStatus.Phase) assert.Equal(t, nvidiacomv1alpha1.RollingUpdatePhasePending, rollingUpdateStatus.Phase)
} }
...@@ -801,14 +801,14 @@ func TestCompleteRollingUpdate_UpdatedServicesContainsAllWorkers(t *testing.T) { ...@@ -801,14 +801,14 @@ func TestCompleteRollingUpdate_UpdatedServicesContainsAllWorkers(t *testing.T) {
r := createTestReconcilerWithStatus(dgd) r := createTestReconcilerWithStatus(dgd)
ctx := context.Background() ctx := context.Background()
rollingUpdateStatus := dgd.Status.RollingUpdate err := r.completeRollingUpdate(ctx, dgd, newWorkerHash)
err := r.completeRollingUpdate(ctx, dgd, rollingUpdateStatus, newWorkerHash)
require.NoError(t, err) require.NoError(t, err)
// Should contain all worker services (sorted), but not frontend // Check dgd.Status.RollingUpdate directly because r.Update() inside completeRollingUpdate
assert.Equal(t, []string{"decode", "prefill"}, rollingUpdateStatus.UpdatedServices) // decodes the API server response back into dgd, and status is re-set after the update.
assert.Equal(t, nvidiacomv1alpha1.RollingUpdatePhaseCompleted, rollingUpdateStatus.Phase) assert.Equal(t, []string{"decode", "prefill"}, dgd.Status.RollingUpdate.UpdatedServices)
assert.NotNil(t, rollingUpdateStatus.EndTime) assert.Equal(t, nvidiacomv1alpha1.RollingUpdatePhaseCompleted, dgd.Status.RollingUpdate.Phase)
assert.NotNil(t, dgd.Status.RollingUpdate.EndTime)
} }
func TestContinueRollingUpdate_AllServicesUpdated(t *testing.T) { func TestContinueRollingUpdate_AllServicesUpdated(t *testing.T) {
...@@ -884,13 +884,14 @@ func TestContinueRollingUpdate_AllServicesUpdated(t *testing.T) { ...@@ -884,13 +884,14 @@ func TestContinueRollingUpdate_AllServicesUpdated(t *testing.T) {
r := createTestReconcilerWithStatus(dgd, newPrefillDCD, newDecodeDCD) r := createTestReconcilerWithStatus(dgd, newPrefillDCD, newDecodeDCD)
ctx := context.Background() ctx := context.Background()
rollingUpdateStatus := dgd.Status.RollingUpdate err := r.continueRollingUpdate(ctx, dgd, newWorkerHash)
err := r.continueRollingUpdate(ctx, dgd, rollingUpdateStatus, newWorkerHash)
require.NoError(t, err) require.NoError(t, err)
// Rolling update should complete, and all services should be listed // Rolling update should complete, and all services should be listed.
assert.Equal(t, nvidiacomv1alpha1.RollingUpdatePhaseCompleted, rollingUpdateStatus.Phase) // Check dgd.Status.RollingUpdate directly because r.Update() inside completeRollingUpdate
assert.Equal(t, []string{"decode", "prefill"}, rollingUpdateStatus.UpdatedServices) // decodes the API server response back into dgd, and status is re-set after the update.
assert.Equal(t, nvidiacomv1alpha1.RollingUpdatePhaseCompleted, dgd.Status.RollingUpdate.Phase)
assert.Equal(t, []string{"decode", "prefill"}, dgd.Status.RollingUpdate.UpdatedServices)
} }
func TestGetWorkerInfoForWorkerHash(t *testing.T) { func TestGetWorkerInfoForWorkerHash(t *testing.T) {
...@@ -2440,11 +2441,11 @@ func TestContinueRollingUpdate_CascadingSpecChange(t *testing.T) { ...@@ -2440,11 +2441,11 @@ func TestContinueRollingUpdate_CascadingSpecChange(t *testing.T) {
r := createTestReconcilerWithStatus(dgd, genADCD, genBDCD, genCDCD) r := createTestReconcilerWithStatus(dgd, genADCD, genBDCD, genCDCD)
ctx := context.Background() ctx := context.Background()
rollingUpdateStatus := dgd.Status.RollingUpdate err := r.continueRollingUpdate(ctx, dgd, newWorkerHash)
err := r.continueRollingUpdate(ctx, dgd, rollingUpdateStatus, newWorkerHash)
require.NoError(t, err) require.NoError(t, err)
// Both A and B have ready replicas, C has 0 — rolling update not complete // Both A and B have ready replicas, C has 0 — rolling update not complete
rollingUpdateStatus := r.getOrCreateRollingUpdateStatus(dgd)
assert.Equal(t, nvidiacomv1alpha1.RollingUpdatePhaseInProgress, rollingUpdateStatus.Phase) assert.Equal(t, nvidiacomv1alpha1.RollingUpdatePhaseInProgress, rollingUpdateStatus.Phase)
assert.Empty(t, rollingUpdateStatus.UpdatedServices, "No services should be fully updated yet") assert.Empty(t, rollingUpdateStatus.UpdatedServices, "No services should be fully updated yet")
} }
...@@ -2686,3 +2687,32 @@ func TestReconcileRollingUpdate_NonePhaseStartsRollout(t *testing.T) { ...@@ -2686,3 +2687,32 @@ func TestReconcileRollingUpdate_NonePhaseStartsRollout(t *testing.T) {
assert.NotNil(t, dgd.Status.RollingUpdate.StartTime) assert.NotNil(t, dgd.Status.RollingUpdate.StartTime)
assert.Nil(t, dgd.Status.RollingUpdate.UpdatedServices) assert.Nil(t, dgd.Status.RollingUpdate.UpdatedServices)
} }
func TestReconcileRollingUpdate_StuckDetection_CompletesViaCompleteRollingUpdate(t *testing.T) {
// Stuck case: hashes match but phase is InProgress (e.g., operator restarted between
// annotation write and status persistence). Should call completeRollingUpdate which
// cleans up old DCDs, updates annotation, and sets Completed.
dgd := createTestDGD("test-dgd", map[string]*nvidiacomv1alpha1.DynamoComponentDeploymentSharedSpec{
"prefill": {ComponentType: consts.ComponentTypePrefill},
"decode": {ComponentType: consts.ComponentTypeDecode},
})
hash := dynamo.ComputeDGDWorkersSpecHash(dgd)
dgd.Annotations = map[string]string{consts.AnnotationCurrentWorkerHash: hash}
dgd.Status.RollingUpdate = &nvidiacomv1alpha1.RollingUpdateStatus{
Phase: nvidiacomv1alpha1.RollingUpdatePhaseInProgress,
}
r := createTestReconcilerWithStatus(dgd)
err := r.reconcileRollingUpdate(context.Background(), dgd)
require.NoError(t, err)
// Phase should be Completed
assert.Equal(t, nvidiacomv1alpha1.RollingUpdatePhaseCompleted, dgd.Status.RollingUpdate.Phase)
// EndTime should be set
assert.NotNil(t, dgd.Status.RollingUpdate.EndTime)
// UpdatedServices should contain all worker services
assert.Contains(t, dgd.Status.RollingUpdate.UpdatedServices, "prefill")
assert.Contains(t, dgd.Status.RollingUpdate.UpdatedServices, "decode")
// Annotation should still have the correct hash
assert.Equal(t, hash, dgd.Annotations[consts.AnnotationCurrentWorkerHash])
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment