// dra.go: resolution of NVIDIA GPU UUIDs for pods through Kubernetes DRA API objects.

package cuda

import (
	"context"
	"fmt"

	"github.com/go-logr/logr"
	resourcev1 "k8s.io/api/resource/v1"
	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
	"k8s.io/client-go/kubernetes"
)

// resourceAttributeUUID is the ResourceSlice device attribute key that is
// looked up to obtain a device's UUID string.
const resourceAttributeUUID = "uuid"

// DRA allocation lookup types and helpers.

// allocatedDRADevice identifies one device taken from a ResourceClaim
// allocation result: the pool it belongs to and the device name within that
// pool.
type allocatedDRADevice struct {
	pool, device string
}

// getAllocatedNVIDIADRADevices fetches the pod, resolves the ResourceClaim
// behind each of its resource claim references, and collects every allocation
// result whose driver equals nvidiaGPUDRADriver.
//
// It returns, in order:
//   - the matching (pool, device) pairs,
//   - the pod's node name,
//   - whether at least one matching allocation result was seen,
//   - an error if the pod or one of its claims could not be fetched.
//
// A nil clientset, an empty pod name or namespace, or a pod with no node name
// yields empty results and a nil error. When fetching a claim fails, the node
// name and the flag accumulated so far are still returned alongside the error.
func getAllocatedNVIDIADRADevices(ctx context.Context, clientset kubernetes.Interface, podName, podNamespace string, log logr.Logger) ([]allocatedDRADevice, string, bool, error) {
	// Nothing can be looked up without a client and a fully specified pod.
	if clientset == nil || podName == "" || podNamespace == "" {
		return nil, "", false, nil
	}

	pod, err := clientset.CoreV1().Pods(podNamespace).Get(ctx, podName, metav1.GetOptions{})
	if err != nil {
		return nil, "", false, fmt.Errorf("get pod %s/%s: %w", podNamespace, podName, err)
	}

	nodeName := pod.Spec.NodeName
	podClaims := pod.Spec.ResourceClaims
	switch {
	case len(podClaims) == 0:
		return nil, nodeName, false, nil
	case nodeName == "":
		log.V(1).Info("pod has no node name, skipping DRA API lookup")
		return nil, "", false, nil
	}

	// Map each pod-level claim reference to a concrete ResourceClaim name.
	// Names given directly in the pod spec win; names reported in the pod
	// status only fill in references the spec left unresolved.
	resolved := make(map[string]string, len(podClaims))
	for i := range podClaims {
		if name := podClaims[i].ResourceClaimName; name != nil && *name != "" {
			resolved[podClaims[i].Name] = *name
		}
	}
	for _, st := range pod.Status.ResourceClaimStatuses {
		name := st.ResourceClaimName
		if name == nil || *name == "" {
			continue
		}
		if _, taken := resolved[st.Name]; taken {
			continue
		}
		resolved[st.Name] = *name
	}

	var (
		devices   []allocatedDRADevice
		sawNVIDIA bool
	)
	for _, podClaim := range podClaims {
		// Only non-empty names are ever stored, so a missing key is the
		// only way a reference can be unresolved.
		name, found := resolved[podClaim.Name]
		if !found {
			log.V(1).Info("pod resource claim has no resolved claim name", "pod_claim", podClaim.Name)
			continue
		}

		claim, err := clientset.ResourceV1().ResourceClaims(podNamespace).Get(ctx, name, metav1.GetOptions{})
		if err != nil {
			return nil, nodeName, sawNVIDIA, fmt.Errorf("get resource claim %s/%s: %w", podNamespace, name, err)
		}

		allocation := claim.Status.Allocation
		if allocation == nil {
			continue
		}
		for _, res := range allocation.Devices.Results {
			if res.Driver == nvidiaGPUDRADriver {
				sawNVIDIA = true
				devices = append(devices, allocatedDRADevice{pool: res.Pool, device: res.Device})
			}
		}
	}

	return devices, nodeName, sawNVIDIA, nil
}

// GetGPUUUIDsViaDRAAPI resolves GPU UUIDs for a pod by querying the Kubernetes API:
// Pod (resource claim refs) -> ResourceClaim (allocation results) -> ResourceSlice (device attributes).
// It also reports whether the pod is using NVIDIA DRA GPU allocations at all.
//
// Returns:
//   - the UUIDs of the allocated devices that could be matched to a
//     ResourceSlice device whose UUID attribute is accepted by gpuUUIDPattern;
//     devices that cannot be matched are logged at V(1) and skipped,
//   - whether at least one allocation result for nvidiaGPUDRADriver was seen,
//   - an error if the pod, a claim, or the ResourceSlice list could not be
//     fetched. The bool may still be true when an error is returned.
func GetGPUUUIDsViaDRAAPI(ctx context.Context, clientset kubernetes.Interface, podName, podNamespace string, log logr.Logger) ([]string, bool, error) {
	allocated, nodeName, hasNVIDIADRAAllocation, err := getAllocatedNVIDIADRADevices(ctx, clientset, podName, podNamespace, log)
	if err != nil {
		return nil, hasNVIDIADRAAllocation, err
	}
	if !hasNVIDIADRAAllocation || len(allocated) == 0 {
		return nil, hasNVIDIADRAAllocation, nil
	}

	slices, err := clientset.ResourceV1().ResourceSlices().List(ctx, metav1.ListOptions{
		FieldSelector: fmt.Sprintf("spec.driver=%s,spec.nodeName=%s", nvidiaGPUDRADriver, nodeName),
	})
	if err != nil {
		return nil, true, fmt.Errorf("list resource slices for node %s: %w", nodeName, err)
	}

	// Build a pool -> device -> UUID index from the slices. A pool may be
	// published across several ResourceSlices, and while a driver is updating
	// a pool, slices of an older generation can coexist with newer ones. Only
	// the highest generation of a pool is used, so stale slices can never
	// contribute an outdated device-to-UUID mapping.
	poolGeneration := make(map[string]int64)
	poolDeviceToUUID := make(map[string]map[string]string)
	for i := range slices.Items {
		spec := &slices.Items[i].Spec
		poolName := spec.Pool.Name
		generation := spec.Pool.Generation

		known, seen := poolGeneration[poolName]
		if seen && generation < known {
			// Stale slice: a newer generation of this pool was already indexed.
			continue
		}
		if !seen || generation > known {
			// First slice of this pool, or a newer generation that
			// supersedes everything collected for the pool so far.
			poolGeneration[poolName] = generation
			poolDeviceToUUID[poolName] = make(map[string]string, len(spec.Devices))
		}

		for _, dev := range spec.Devices {
			uuid := deviceUUIDFromAttributes(dev.Attributes)
			if uuid != "" && gpuUUIDPattern.MatchString(uuid) {
				poolDeviceToUUID[poolName][dev.Name] = uuid
			}
		}
	}

	// Kept as a nil slice when nothing resolves, so callers observe the same
	// "no UUIDs" value as before.
	var uuids []string
	for _, device := range allocated {
		devMap := poolDeviceToUUID[device.pool]
		if devMap == nil {
			log.V(1).Info("no ResourceSlice found for pool", "pool", device.pool, "device", device.device)
			continue
		}
		uuid, ok := devMap[device.device]
		if !ok || uuid == "" {
			log.V(1).Info("device has no UUID in ResourceSlice", "pool", device.pool, "device", device.device)
			continue
		}
		uuids = append(uuids, uuid)
	}
	if len(uuids) > 0 {
		log.Info("resolved GPU UUIDs via DRA API", "uuids", uuids)
	}
	return uuids, true, nil
}

// deviceUUIDFromAttributes returns the string value stored under the
// resourceAttributeUUID key of attrs. It returns "" when the key is absent or
// the attribute carries no string value.
func deviceUUIDFromAttributes(attrs map[resourcev1.QualifiedName]resourcev1.DeviceAttribute) string {
	if attr, found := attrs[resourcev1.QualifiedName(resourceAttributeUUID)]; found && attr.StringValue != nil {
		return *attr.StringValue
	}
	return ""
}