resource.go 11.8 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
/*
 * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package controller_common

import (
	"context"
	"crypto/sha256"
	"encoding/json"
	"fmt"
25
	"reflect"
26
27
	"sort"

28
	corev1 "k8s.io/api/core/v1"
29
	"k8s.io/apimachinery/pkg/api/errors"
30
31
	"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
	"k8s.io/apimachinery/pkg/runtime"
32
	"k8s.io/apimachinery/pkg/types"
33
34
	"k8s.io/client-go/tools/record"
	ctrl "sigs.k8s.io/controller-runtime"
35
36
	"sigs.k8s.io/controller-runtime/pkg/client"
	"sigs.k8s.io/controller-runtime/pkg/client/apiutil"
37
	"sigs.k8s.io/controller-runtime/pkg/log"
38
39
40
41
42
43
44
)

const (
	// NvidiaAnnotationHashKey is the name of the annotation in which the operator
	// stores the hash of the spec it last applied to a resource.
	NvidiaAnnotationHashKey = "nvidia.com/last-applied-hash"
)

45
46
47
type Reconciler interface {
	client.Client
	GetRecorder() record.EventRecorder
48
49
}

50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
// ResourceGenerator is a function that generates a resource.
// It must return the resource, a boolean indicating whether the resource
// should be deleted, and an error.
// If the resource should be deleted, the returned resource must still contain
// the information necessary to delete it (name and namespace).
type ResourceGenerator[T client.Object] func(ctx context.Context) (T, bool, error)

// SyncResource brings one child resource of parentResource in line with the
// state produced by generateResource.
//
// The generator yields the desired object and a flag telling whether the object
// should be deleted. SyncResource then looks up the existing object with the
// same name and namespace and:
//   - creates the desired object (with a controller reference to parentResource
//     and a spec-hash annotation) when none exists and deletion is not requested;
//   - does nothing when none exists and deletion is requested;
//   - deletes the existing object when deletion is requested;
//   - otherwise copies the desired spec onto the existing object and updates it,
//     but only when the desired spec hash differs from the annotated one.
//
// Returns:
//   - modified: true when an object was created, updated or deleted.
//   - res: the created object, or the existing object after an update or a
//     skipped update; the zero value of T after a delete, a no-op, or most
//     errors (a spec-comparison failure returns the desired object instead).
//   - err: the first error encountered; failures are also logged and/or
//     reported as warning events on parentResource.
func SyncResource[T client.Object](ctx context.Context, r Reconciler, parentResource client.Object, generateResource ResourceGenerator[T]) (modified bool, res T, err error) {
	logs := log.FromContext(ctx)

	resource, toDelete, err := generateResource(ctx)
	if err != nil {
		return false, res, err
	}
	resourceNamespace := resource.GetNamespace()
	resourceName := resource.GetName()
	resourceType := reflect.TypeOf(resource).Elem().Name()
	logs = logs.WithValues("namespace", resourceNamespace, "resourceName", resourceName, "resourceType", resourceType)

	// Retrieve the GroupVersionKind (GVK) of the desired object
	gvk, err := apiutil.GVKForObject(resource, r.Scheme())
	if err != nil {
		logs.Error(err, "Failed to get GVK for object")
		return false, res, err
	}

	// Create a new, empty instance of the same kind to receive the current state
	obj, err := r.Scheme().New(gvk)
	if err != nil {
		logs.Error(err, "Failed to create a new object for GVK")
		return false, res, err
	}

	// The scheme hands back a runtime.Object; it must be usable as T. Report a
	// mismatch as an error instead of silently returning success.
	oldResource, ok := obj.(T)
	if !ok {
		err = fmt.Errorf("object of type %T created for %s cannot be used as %T", obj, gvk, resource)
		logs.Error(err, "Failed to convert the new object to the expected resource type")
		return false, res, err
	}

	err = r.Get(ctx, types.NamespacedName{Name: resourceName, Namespace: resourceNamespace}, oldResource)
	oldResourceIsNotFound := errors.IsNotFound(err)
	if err != nil && !oldResourceIsNotFound {
		// NOTE(review): the event messages in this function print the namespace
		// after the resource type; the resource name may have been intended — confirm.
		r.GetRecorder().Eventf(parentResource, corev1.EventTypeWarning, fmt.Sprintf("Get%s", resourceType), "Failed to get %s %s: %s", resourceType, resourceNamespace, err)
		logs.Error(err, "Failed to get resource.")
		return false, res, err
	}
	// A NotFound error is an expected outcome, not a failure.
	err = nil

	if oldResourceIsNotFound {
		if toDelete {
			logs.Info("Resource not found. Nothing to do.")
			return false, res, nil
		}
		logs.Info("Resource not found. Creating a new one.")

		err = ctrl.SetControllerReference(parentResource, resource, r.Scheme())
		if err != nil {
			logs.Error(err, "Failed to set controller reference.")
			r.GetRecorder().Eventf(parentResource, corev1.EventTypeWarning, "SetControllerReference", "Failed to set controller reference for %s %s: %s", resourceType, resourceNamespace, err)
			return false, res, err
		}

		var hash string
		hash, err = GetSpecHash(resource)
		if err != nil {
			logs.Error(err, "Failed to get spec hash.")
			r.GetRecorder().Eventf(parentResource, corev1.EventTypeWarning, "GetSpecHash", "Failed to get spec hash for %s %s: %s", resourceType, resourceNamespace, err)
			return false, res, err
		}

		// Record the hash so later reconciliations can detect spec changes.
		updateHashAnnotation(resource, hash)

		r.GetRecorder().Eventf(parentResource, corev1.EventTypeNormal, fmt.Sprintf("Create%s", resourceType), "Creating a new %s %s", resourceType, resourceNamespace)
		err = r.Create(ctx, resource)
		if err != nil {
			logs.Error(err, "Failed to create Resource.")
			r.GetRecorder().Eventf(parentResource, corev1.EventTypeWarning, fmt.Sprintf("Create%s", resourceType), "Failed to create %s %s: %s", resourceType, resourceNamespace, err)
			return false, res, err
		}
		logs.Info(fmt.Sprintf("%s created.", resourceType))
		r.GetRecorder().Eventf(parentResource, corev1.EventTypeNormal, fmt.Sprintf("Create%s", resourceType), "Created %s %s", resourceType, resourceNamespace)
		return true, resource, nil
	}

	logs.Info(fmt.Sprintf("%s found.", resourceType))
	if toDelete {
		logs.Info(fmt.Sprintf("%s is marked for deletion. Deleting the existing one.", resourceType))
		err = r.Delete(ctx, oldResource)
		if err != nil {
			logs.Error(err, fmt.Sprintf("Failed to delete %s.", resourceType))
			r.GetRecorder().Eventf(parentResource, corev1.EventTypeWarning, fmt.Sprintf("Delete%s", resourceType), "Failed to delete %s %s: %s", resourceType, resourceNamespace, err)
			return false, res, err
		}
		logs.Info(fmt.Sprintf("%s deleted.", resourceType))
		r.GetRecorder().Eventf(parentResource, corev1.EventTypeNormal, fmt.Sprintf("Delete%s", resourceType), "Deleted %s %s", resourceType, resourceNamespace)
		return true, res, nil
	}

	// Check if the Spec has changed and update if necessary
	var newHash *string
	newHash, err = IsSpecChanged(oldResource, resource)
	if err != nil {
		r.GetRecorder().Eventf(parentResource, corev1.EventTypeWarning, fmt.Sprintf("CalculatePatch%s", resourceType), "Failed to calculate patch for %s %s: %s", resourceType, resourceNamespace, err)
		return false, resource, fmt.Errorf("failed to check if spec has changed: %w", err)
	}

	if newHash == nil {
		logs.Info(fmt.Sprintf("%s spec is the same. Skipping update.", resourceType))
		r.GetRecorder().Eventf(parentResource, corev1.EventTypeNormal, fmt.Sprintf("Update%s", resourceType), "Skipping update %s %s", resourceType, resourceNamespace)
		return false, oldResource, nil
	}

	// Update the spec of the current object with the desired spec
	err = CopySpec(resource, oldResource)
	if err != nil {
		logs.Error(err, fmt.Sprintf("Failed to copy spec for %s.", resourceType))
		r.GetRecorder().Eventf(parentResource, corev1.EventTypeWarning, fmt.Sprintf("CopySpec%s", resourceType), "Failed to copy spec for %s %s: %s", resourceType, resourceNamespace, err)
		return false, res, err
	}

	updateHashAnnotation(oldResource, *newHash)

	err = r.Update(ctx, oldResource)
	if err != nil {
		logs.Error(err, fmt.Sprintf("Failed to update %s.", resourceType))
		r.GetRecorder().Eventf(parentResource, corev1.EventTypeWarning, fmt.Sprintf("Update%s", resourceType), "Failed to update %s %s: %s", resourceType, resourceNamespace, err)
		return false, res, err
	}
	logs.Info(fmt.Sprintf("%s updated.", resourceType))
	r.GetRecorder().Eventf(parentResource, corev1.EventTypeNormal, fmt.Sprintf("Update%s", resourceType), "Updated %s %s", resourceType, resourceNamespace)
	return true, oldResource, nil
}

// CopySpec replaces the top-level "spec" field of destination with a copy of
// the one from source, leaving every other field of destination as it was.
//
// Both objects are converted to their unstructured (map) form, the spec is
// transplanted, and the result is converted back into destination.
//
// It returns an error when either conversion fails, when the spec cannot be
// read or written, or when source has no "spec" field.
func CopySpec(source, destination client.Object) error {
	sourceMap, err := runtime.DefaultUnstructuredConverter.ToUnstructured(source)
	if err != nil {
		return fmt.Errorf("converting source to unstructured: %w", err)
	}

	destMap, err := runtime.DefaultUnstructuredConverter.ToUnstructured(destination)
	if err != nil {
		return fmt.Errorf("converting destination to unstructured: %w", err)
	}

	// Extract only the spec from source
	sourceSpec, found, err := unstructured.NestedFieldCopy(sourceMap, "spec")
	if err != nil {
		return fmt.Errorf("reading spec from source: %w", err)
	}
	if !found {
		return fmt.Errorf("spec not found in source object")
	}

	// Set the spec in the destination
	if err := unstructured.SetNestedField(destMap, sourceSpec, "spec"); err != nil {
		return fmt.Errorf("setting spec on destination: %w", err)
	}

	// Convert back to the original object
	if err := runtime.DefaultUnstructuredConverter.FromUnstructured(destMap, destination); err != nil {
		return fmt.Errorf("converting destination from unstructured: %w", err)
	}
	return nil
}

// getSpec returns a copy of the top-level "spec" field of obj, read from the
// object's unstructured (map) representation.
// It returns (nil, nil) when obj has no "spec" field.
func getSpec(obj client.Object) (any, error) {
	content, err := runtime.DefaultUnstructuredConverter.ToUnstructured(obj)
	if err != nil {
		return nil, err
	}

	spec, found, err := unstructured.NestedFieldCopy(content, "spec")
	switch {
	case err != nil:
		return nil, err
	case !found:
		return nil, nil
	default:
		return spec, nil
	}
}

// IsSpecChanged compares the hash of the desired spec with the hash recorded in
// the NvidiaAnnotationHashKey annotation of the current object.
//
// It returns a pointer to the desired hash when the spec has changed (the
// annotation is missing or holds a different value), and nil when the hashes
// match. An error is returned only when the desired hash cannot be computed.
func IsSpecChanged(current client.Object, desired client.Object) (*string, error) {
	desiredHash, err := GetSpecHash(desired)
	if err != nil {
		return nil, err
	}
	if appliedHash, ok := current.GetAnnotations()[NvidiaAnnotationHashKey]; ok && appliedHash == desiredHash {
		return nil, nil
	}
	return &desiredHash, nil
}
251

252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
// GetSpecHash returns the hash of the "spec" field of obj, as computed by
// GetResourceHash. It fails when the spec cannot be extracted or hashed.
func GetSpecHash(obj client.Object) (string, error) {
	objSpec, specErr := getSpec(obj)
	if specErr != nil {
		return "", specErr
	}
	return GetResourceHash(objSpec)
}

// updateHashAnnotation stores hash under NvidiaAnnotationHashKey in the
// annotations of obj, creating the annotation map when obj has none.
func updateHashAnnotation(obj client.Object, hash string) {
	annotations := obj.GetAnnotations()
	if annotations == nil {
		annotations = make(map[string]string, 1)
	}
	annotations[NvidiaAnnotationHashKey] = hash
	obj.SetAnnotations(annotations)
}

// GetResourceHash returns a consistent hash for the given object spec
270
func GetResourceHash(obj any) (string, error) {
271
272
273
	// Convert obj to a map[string]interface{}
	objMap, err := json.Marshal(obj)
	if err != nil {
274
		return "", err
275
276
277
278
	}

	var objData map[string]interface{}
	if err := json.Unmarshal(objMap, &objData); err != nil {
279
		return "", err
280
281
282
283
284
285
286
287
	}

	// Sort keys to ensure consistent serialization
	sortedObjData := SortKeys(objData)

	// Serialize to JSON
	serialized, err := json.Marshal(sortedObjData)
	if err != nil {
288
		return "", err
289
290
291
292
293
	}

	// Compute the hash
	hasher := sha256.New()
	hasher.Write(serialized)
294
	return fmt.Sprintf("%x", hasher.Sum(nil)), nil
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
}

// SortKeys recursively normalizes a decoded JSON value so that it serializes
// consistently.
//
// Maps are rebuilt with every value normalized. Slices whose first element is a
// map are stably sorted by the elements' "name" string when both compared
// elements have one, otherwise by the alphabetically first key of each map;
// elements that cannot be compared keep their relative order. Slices of other
// values keep their order. Any other value is returned unchanged.
//
// The input is never modified: maps and non-empty slices are copied before
// being reordered or rewritten.
func SortKeys(obj interface{}) interface{} {
	switch v := obj.(type) {
	case map[string]interface{}:
		// Go maps are unordered (encoding/json sorts keys itself when
		// marshaling), so only the values need normalizing here.
		normalized := make(map[string]interface{}, len(v))
		for key, value := range v {
			normalized[key] = SortKeys(value)
		}
		return normalized
	case []interface{}:
		if len(v) == 0 {
			// Nothing to reorder; return as-is so nil stays nil.
			return obj
		}
		// Work on a copy so the caller's slice keeps its order and elements.
		items := make([]interface{}, len(v))
		copy(items, v)

		if _, ok := items[0].(map[string]interface{}); ok {
			sort.SliceStable(items, func(i, j int) bool {
				iMap, iOk := items[i].(map[string]interface{})
				jMap, jOk := items[j].(map[string]interface{})
				if !iOk || !jOk {
					// No valid comparison is possible: maintain the original order
					return false
				}

				// Try to sort by "name" if present
				iName, iNameOk := iMap["name"].(string)
				jName, jNameOk := jMap["name"].(string)
				if iNameOk && jNameOk {
					return iName < jName
				}

				// If "name" is not available, sort by the first key in each map
				if len(iMap) > 0 && len(jMap) > 0 {
					return firstKey(iMap) < firstKey(jMap)
				}
				return false
			})
		}
		for i, item := range items {
			items[i] = SortKeys(item)
		}
		return items
	}
	return obj
}

// firstKey returns the alphabetically smallest key of m, or "" when m is empty.
func firstKey(m map[string]interface{}) string {
	smallest, seen := "", false
	for key := range m {
		if !seen || key < smallest {
			smallest, seen = key, true
		}
	}
	return smallest
}