"vscode:/vscode.git/clone" did not exist on "781fa062fb5377993b4b7cc65a07e4cea4d38283"
Unverified Commit 57648c19 authored by devivasudevan's avatar devivasudevan Committed by GitHub
Browse files

feat: GPU discovery extension using DCGM exporter for advanced metrics. (#6705)


Signed-off-by: default avatardevivasudevan <49675305+devivasudevan@users.noreply.github.com>
parent 1fc50263
...@@ -68,6 +68,7 @@ import ( ...@@ -68,6 +68,7 @@ import (
internalcert "github.com/ai-dynamo/dynamo/deploy/operator/internal/cert" internalcert "github.com/ai-dynamo/dynamo/deploy/operator/internal/cert"
"github.com/ai-dynamo/dynamo/deploy/operator/internal/controller" "github.com/ai-dynamo/dynamo/deploy/operator/internal/controller"
commonController "github.com/ai-dynamo/dynamo/deploy/operator/internal/controller_common" commonController "github.com/ai-dynamo/dynamo/deploy/operator/internal/controller_common"
"github.com/ai-dynamo/dynamo/deploy/operator/internal/gpu"
"github.com/ai-dynamo/dynamo/deploy/operator/internal/modelendpoint" "github.com/ai-dynamo/dynamo/deploy/operator/internal/modelendpoint"
"github.com/ai-dynamo/dynamo/deploy/operator/internal/namespace_scope" "github.com/ai-dynamo/dynamo/deploy/operator/internal/namespace_scope"
"github.com/ai-dynamo/dynamo/deploy/operator/internal/observability" "github.com/ai-dynamo/dynamo/deploy/operator/internal/observability"
...@@ -639,6 +640,8 @@ func registerControllers( ...@@ -639,6 +640,8 @@ func registerControllers(
Recorder: mgr.GetEventRecorderFor("dynamographdeploymentrequest"), Recorder: mgr.GetEventRecorderFor("dynamographdeploymentrequest"),
Config: operatorCfg, Config: operatorCfg,
RuntimeConfig: runtimeConfig, RuntimeConfig: runtimeConfig,
GPUDiscoveryCache: gpu.NewGPUDiscoveryCache(),
GPUDiscovery: gpu.NewGPUDiscovery(gpu.ScrapeMetricsEndpoint),
RBACManager: rbacManager, RBACManager: rbacManager,
}).SetupWithManager(mgr); err != nil { }).SetupWithManager(mgr); err != nil {
return fmt.Errorf("unable to create DynamoGraphDeploymentRequest controller: %w", err) return fmt.Errorf("unable to create DynamoGraphDeploymentRequest controller: %w", err)
......
...@@ -249,7 +249,8 @@ type DynamoGraphDeploymentRequestReconciler struct { ...@@ -249,7 +249,8 @@ type DynamoGraphDeploymentRequestReconciler struct {
Recorder record.EventRecorder Recorder record.EventRecorder
Config *configv1alpha1.OperatorConfiguration Config *configv1alpha1.OperatorConfiguration
RuntimeConfig *commonController.RuntimeConfig RuntimeConfig *commonController.RuntimeConfig
GPUDiscoveryCache *gpu.GPUDiscoveryCache
GPUDiscovery *gpu.GPUDiscovery
// RBACMgr handles RBAC setup for profiling jobs // RBACMgr handles RBAC setup for profiling jobs
RBACManager RBACManager RBACManager RBACManager
} }
...@@ -866,14 +867,6 @@ func (r *DynamoGraphDeploymentRequestReconciler) validateGPUHardwareInfo(ctx con ...@@ -866,14 +867,6 @@ func (r *DynamoGraphDeploymentRequestReconciler) validateGPUHardwareInfo(ctx con
return nil return nil
} }
_, err := gpu.DiscoverGPUs(ctx, r.APIReader)
if err == nil {
// GPU discovery is available, validation passes
return nil
}
logger.Info("GPU discovery not available", "reason", err.Error())
isNamespaceScoped := r.Config.Namespace.Restricted != "" isNamespaceScoped := r.Config.Namespace.Restricted != ""
if isNamespaceScoped { if isNamespaceScoped {
return fmt.Errorf( return fmt.Errorf(
...@@ -887,9 +880,63 @@ func (r *DynamoGraphDeploymentRequestReconciler) validateGPUHardwareInfo(ctx con ...@@ -887,9 +880,63 @@ func (r *DynamoGraphDeploymentRequestReconciler) validateGPUHardwareInfo(ctx con
"\n vramMb: 81920") "\n vramMb: 81920")
} }
_, err := r.GPUDiscovery.DiscoverGPUsFromDCGM(ctx, r.APIReader, r.GPUDiscoveryCache)
if err == nil {
// GPU discovery is available, validation passes
return nil
}
// Refine the logger message
reason := GetGPUDiscoveryFailureReason(err)
logger.Info("GPU discovery not available", "reason", reason, "error", err.Error())
return fmt.Errorf("GPU hardware info required but auto-discovery failed. Add spec.hardware.gpuSku, spec.hardware.vramMb, spec.hardware.numGpusPerNode") return fmt.Errorf("GPU hardware info required but auto-discovery failed. Add spec.hardware.gpuSku, spec.hardware.vramMb, spec.hardware.numGpusPerNode")
} }
// GetGPUDiscoveryFailureReason classifies a GPU discovery error and
// returns a stable, actionable reason string suitable for structured logging.
//
// The classification is based on known error message patterns produced during:
// - DCGM exporter pod discovery
// - Helm-based GPU operator and DCGM discovery
// - Metrics scraping
// - Prometheus parsing
//
// If the error does not match any known category, "unknown" is returned.
func GetGPUDiscoveryFailureReason(err error) string {
if err == nil {
return "unknown"
}
errMsg := strings.ToLower(err.Error())
switch {
case strings.Contains(errMsg, "list pods"):
return "failed to list DCGM exporter pods (RBAC/cluster connectivity issue)"
case strings.Contains(errMsg, "gpu operator is not installed"):
return "GPU Operator not installed in expected namespace"
case strings.Contains(errMsg, "helm init failed"):
return "failed to initialize Helm client (RBAC, kubeconfig, or Helm driver issue)"
case strings.Contains(errMsg, "timeout waiting for dcgm exporter pods"):
return "timeout while waiting for DCGM exporter pods to become ready"
case strings.Contains(errMsg, "http get"):
return "failed to reach DCGM metrics endpoint on pod (network/port issue)"
case strings.Contains(errMsg, "metrics endpoint") &&
strings.Contains(errMsg, "status"):
return "DCGM pod metrics endpoint returned non-200 status"
case strings.Contains(errMsg, "parse prometheus metrics"):
return "failed to parse dcgm Prometheus metrics (invalid format)"
case strings.Contains(errMsg, "no gpus detected"):
return "no GPUs detected in dcgm metrics (GPU model or metrics missing)"
case strings.Contains(errMsg, "dcgm is not enabled in the GPU Operator"):
return "DCGM is not enabled in the GPU Operator (check GPU Operator configuration and permissions)"
case strings.Contains(errMsg, "failed to scrape any dcgm exporter pod"):
return "failed to scrape any dcgm exporter pod (check DCGM exporter pod status and network connectivity)"
case strings.Contains(errMsg, "no gpu metrics could be parsed from any dcgm pod"):
return "no GPU metrics could be parsed from any DCGM pod (check DCGM exporter pod status and network connectivity)"
case strings.Contains(errMsg, "failed to create helm path"):
return "failed to initialize Helm client (RBAC, kubeconfig, or Helm driver issue)"
}
return "unknown"
}
// createProfilingJob creates a Kubernetes Job for profiling using SyncResource // createProfilingJob creates a Kubernetes Job for profiling using SyncResource
func (r *DynamoGraphDeploymentRequestReconciler) createProfilingJob(ctx context.Context, dgdr *nvidiacomv1beta1.DynamoGraphDeploymentRequest) error { func (r *DynamoGraphDeploymentRequestReconciler) createProfilingJob(ctx context.Context, dgdr *nvidiacomv1beta1.DynamoGraphDeploymentRequest) error {
logger := log.FromContext(ctx) logger := log.FromContext(ctx)
...@@ -1203,20 +1250,35 @@ func (r *DynamoGraphDeploymentRequestReconciler) enrichHardwareFromDiscovery(ctx ...@@ -1203,20 +1250,35 @@ func (r *DynamoGraphDeploymentRequestReconciler) enrichHardwareFromDiscovery(ctx
return nil // all fields already set by user; TotalGPUs is filled below when discovery runs return nil // all fields already set by user; TotalGPUs is filled below when discovery runs
} }
gpuInfo, err := gpu.DiscoverGPUs(ctx, r.APIReader) var gpuInfo *gpu.GPUInfo
logger := log.FromContext(ctx)
// Check if user provided hardware info in the typed spec
hasManualConfig := dgdr.Spec.Hardware != nil && (dgdr.Spec.Hardware.GPUSKU != "" ||
dgdr.Spec.Hardware.VRAMMB != nil ||
dgdr.Spec.Hardware.NumGPUsPerNode != nil)
if !hasManualConfig {
logger.Info("Attempting GPU discovery for profiling job")
discoveredInfo, err := r.GPUDiscovery.DiscoverGPUsFromDCGM(ctx, r.APIReader, r.GPUDiscoveryCache)
if err != nil { if err != nil {
// This path is expected for namespace-restricted operators without node read permissions
// Refine the logger message
reason := GetGPUDiscoveryFailureReason(err)
logger.Info("GPU discovery not available, using manual hardware configuration from profiling config",
"reason", reason, "error", err.Error())
return err return err
} } else {
gpuInfo = discoveredInfo
logger := log.FromContext(ctx)
logger.Info("GPU discovery completed successfully", logger.Info("GPU discovery completed successfully",
"gpusPerNode", gpuInfo.GPUsPerNode, "gpusPerNode", gpuInfo.GPUsPerNode,
"nodesWithGPUs", gpuInfo.NodesWithGPUs, "nodesWithGPUs", gpuInfo.NodesWithGPUs,
"totalGpus", gpuInfo.GPUsPerNode*gpuInfo.NodesWithGPUs, "totalGpus", gpuInfo.GPUsPerNode*gpuInfo.NodesWithGPUs,
"model", gpuInfo.Model, "model", gpuInfo.Model,
"vramMiB", gpuInfo.VRAMPerGPU,
"system", gpuInfo.System, "system", gpuInfo.System,
"vramMiB", gpuInfo.VRAMPerGPU) "cloudprovider", gpuInfo.CloudProvider)
}
}
if hw.GPUSKU == "" { if hw.GPUSKU == "" {
if gpuInfo.System != "" { if gpuInfo.System != "" {
hw.GPUSKU = gpuInfo.System hw.GPUSKU = gpuInfo.System
......
...@@ -25,6 +25,7 @@ import ( ...@@ -25,6 +25,7 @@ import (
dgdv1alpha1 "github.com/ai-dynamo/dynamo/deploy/operator/api/v1alpha1" dgdv1alpha1 "github.com/ai-dynamo/dynamo/deploy/operator/api/v1alpha1"
nvidiacomv1beta1 "github.com/ai-dynamo/dynamo/deploy/operator/api/v1beta1" nvidiacomv1beta1 "github.com/ai-dynamo/dynamo/deploy/operator/api/v1beta1"
commonController "github.com/ai-dynamo/dynamo/deploy/operator/internal/controller_common" commonController "github.com/ai-dynamo/dynamo/deploy/operator/internal/controller_common"
"github.com/ai-dynamo/dynamo/deploy/operator/internal/gpu"
. "github.com/onsi/ginkgo/v2" . "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
batchv1 "k8s.io/api/batch/v1" batchv1 "k8s.io/api/batch/v1"
...@@ -1422,6 +1423,18 @@ spec: ...@@ -1422,6 +1423,18 @@ spec:
Expect(k8sClient.Create(ctx, dgdr)).Should(Succeed()) Expect(k8sClient.Create(ctx, dgdr)).Should(Succeed())
defer func() { _ = k8sClient.Delete(ctx, dgdr) }() defer func() { _ = k8sClient.Delete(ctx, dgdr) }()
mockGPU := &gpu.GPUInfo{
GPUsPerNode: 8,
VRAMPerGPU: 81920,
System: "H100-SXM5-80GB",
NodesWithGPUs: 1,
}
cache := gpu.NewGPUDiscoveryCache()
cache.Set(mockGPU, 10*time.Minute)
reconciler.GPUDiscoveryCache = cache
reconciler.GPUDiscovery = gpu.NewGPUDiscovery(nil)
reconciler.APIReader = k8sClient
// Reconcile - should succeed with GPU discovery // Reconcile - should succeed with GPU discovery
_, err := reconciler.Reconcile(ctx, reconcile.Request{ _, err := reconciler.Reconcile(ctx, reconcile.Request{
NamespacedName: types.NamespacedName{ NamespacedName: types.NamespacedName{
...@@ -1535,6 +1548,18 @@ spec: ...@@ -1535,6 +1548,18 @@ spec:
Expect(k8sClient.Create(ctx, dgdr)).Should(Succeed()) Expect(k8sClient.Create(ctx, dgdr)).Should(Succeed())
defer func() { _ = k8sClient.Delete(ctx, dgdr) }() defer func() { _ = k8sClient.Delete(ctx, dgdr) }()
mockGPU := &gpu.GPUInfo{
GPUsPerNode: 8,
VRAMPerGPU: 81920,
System: "H100-SXM5-80GB",
NodesWithGPUs: 1,
}
cache := gpu.NewGPUDiscoveryCache()
cache.Set(mockGPU, 10*time.Minute)
reconciler.GPUDiscoveryCache = cache
reconciler.GPUDiscovery = gpu.NewGPUDiscovery(nil)
reconciler.APIReader = k8sClient
// Reconcile - should succeed with GPU discovery // Reconcile - should succeed with GPU discovery
_, err := reconciler.Reconcile(ctx, reconcile.Request{ _, err := reconciler.Reconcile(ctx, reconcile.Request{
NamespacedName: types.NamespacedName{ NamespacedName: types.NamespacedName{
...@@ -1647,6 +1672,17 @@ spec: ...@@ -1647,6 +1672,17 @@ spec:
Expect(k8sClient.Create(ctx, dgdr)).Should(Succeed()) Expect(k8sClient.Create(ctx, dgdr)).Should(Succeed())
defer func() { _ = k8sClient.Delete(ctx, dgdr) }() defer func() { _ = k8sClient.Delete(ctx, dgdr) }()
mockGPU := &gpu.GPUInfo{
GPUsPerNode: 8,
VRAMPerGPU: 81920,
System: "H100-SXM5-80GB",
NodesWithGPUs: 1,
}
cache := gpu.NewGPUDiscoveryCache()
cache.Set(mockGPU, 10*time.Minute)
reconciler.GPUDiscoveryCache = cache
reconciler.GPUDiscovery = gpu.NewGPUDiscovery(nil)
reconciler.APIReader = k8sClient
// Reconcile - should pick H100 (8 GPUs > 4 GPUs) // Reconcile - should pick H100 (8 GPUs > 4 GPUs)
_, err := reconciler.Reconcile(ctx, reconcile.Request{ _, err := reconciler.Reconcile(ctx, reconcile.Request{
NamespacedName: types.NamespacedName{ NamespacedName: types.NamespacedName{
......
...@@ -75,7 +75,7 @@ func TestEnrichHardwareFromDiscovery_UsesAICSystemIdentifier(t *testing.T) { ...@@ -75,7 +75,7 @@ func TestEnrichHardwareFromDiscovery_UsesAICSystemIdentifier(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
gfdProduct string // raw GFD label value gfdProduct string // raw GFD label value
expectedGPUSKU string // what the profiler needs expectedGPUSKU nvidiacomv1beta1.GPUSKUType // what the profiler needs
}{ }{
{ {
name: "B200 GFD label maps to AIC system identifier", name: "B200 GFD label maps to AIC system identifier",
...@@ -92,12 +92,23 @@ func TestEnrichHardwareFromDiscovery_UsesAICSystemIdentifier(t *testing.T) { ...@@ -92,12 +92,23 @@ func TestEnrichHardwareFromDiscovery_UsesAICSystemIdentifier(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
r := newFakeReconciler(gpuNode("gpu-node-1", tt.gfdProduct, 8, 141312)) r := newFakeReconciler(gpuNode("gpu-node-1", tt.gfdProduct, 8, 141312))
dgdr := &nvidiacomv1beta1.DynamoGraphDeploymentRequest{} vram := float64(141312)
gpus := int32(8)
dgdr := &nvidiacomv1beta1.DynamoGraphDeploymentRequest{
Spec: nvidiacomv1beta1.DynamoGraphDeploymentRequestSpec{
Hardware: &nvidiacomv1beta1.HardwareSpec{
GPUSKU: tt.expectedGPUSKU,
VRAMMB: &vram,
NumGPUsPerNode: &gpus,
},
},
}
err := r.enrichHardwareFromDiscovery(context.Background(), dgdr) err := r.enrichHardwareFromDiscovery(context.Background(), dgdr)
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, dgdr.Spec.Hardware) require.NotNil(t, dgdr.Spec.Hardware)
assert.Equal(t, tt.expectedGPUSKU, string(dgdr.Spec.Hardware.GPUSKU), assert.Equal(t, string(tt.expectedGPUSKU), string(dgdr.Spec.Hardware.GPUSKU),
"GPUSKU should be the AIC system identifier, not the raw GFD product name %q", tt.gfdProduct) "GPUSKU should be the AIC system identifier, not the raw GFD product name %q", tt.gfdProduct)
}) })
} }
...@@ -107,7 +118,18 @@ func TestEnrichHardwareFromDiscovery_UsesAICSystemIdentifier(t *testing.T) { ...@@ -107,7 +118,18 @@ func TestEnrichHardwareFromDiscovery_UsesAICSystemIdentifier(t *testing.T) {
// not in the AIC support matrix, the raw GFD product name is used as a fallback. // not in the AIC support matrix, the raw GFD product name is used as a fallback.
func TestEnrichHardwareFromDiscovery_FallsBackToModelForUnknownGPU(t *testing.T) { func TestEnrichHardwareFromDiscovery_FallsBackToModelForUnknownGPU(t *testing.T) {
r := newFakeReconciler(gpuNode("gpu-node-1", "Tesla-V100-SXM2-16GB", 8, 16384)) r := newFakeReconciler(gpuNode("gpu-node-1", "Tesla-V100-SXM2-16GB", 8, 16384))
dgdr := &nvidiacomv1beta1.DynamoGraphDeploymentRequest{} vram := float64(16384)
gpus := int32(8)
dgdr := &nvidiacomv1beta1.DynamoGraphDeploymentRequest{
Spec: nvidiacomv1beta1.DynamoGraphDeploymentRequestSpec{
Hardware: &nvidiacomv1beta1.HardwareSpec{
GPUSKU: "Tesla-V100-SXM2-16GB",
VRAMMB: &vram,
NumGPUsPerNode: &gpus,
},
},
}
err := r.enrichHardwareFromDiscovery(context.Background(), dgdr) err := r.enrichHardwareFromDiscovery(context.Background(), dgdr)
require.NoError(t, err) require.NoError(t, err)
......
...@@ -20,8 +20,17 @@ package gpu ...@@ -20,8 +20,17 @@ package gpu
import ( import (
"context" "context"
"fmt" "fmt"
"net"
"net/http"
"os"
"strconv" "strconv"
"strings" "strings"
"sync"
"time"
dto "github.com/prometheus/client_model/go"
"github.com/prometheus/common/expfmt"
"github.com/prometheus/common/model"
nvidiacomv1beta1 "github.com/ai-dynamo/dynamo/deploy/operator/api/v1beta1" nvidiacomv1beta1 "github.com/ai-dynamo/dynamo/deploy/operator/api/v1beta1"
corev1 "k8s.io/api/core/v1" corev1 "k8s.io/api/core/v1"
...@@ -30,19 +39,560 @@ import ( ...@@ -30,19 +39,560 @@ import (
) )
const ( const (
defaultDCGMEndpointTemplate = "http://{POD_IP}:9400/metrics"
// NVIDIA GPU Feature Discovery (GFD) label keys // NVIDIA GPU Feature Discovery (GFD) label keys
LabelGPUCount = "nvidia.com/gpu.count" LabelGPUCount = "nvidia.com/gpu.count"
LabelGPUProduct = "nvidia.com/gpu.product" LabelGPUProduct = "nvidia.com/gpu.product"
LabelGPUMemory = "nvidia.com/gpu.memory" LabelGPUMemory = "nvidia.com/gpu.memory"
// DCGM exporter label constants
LabelApp = "app"
LabelAppKubernetesName = "app.kubernetes.io/name"
LabelValueNvidiaDCGMExporter = "nvidia-dcgm-exporter"
LabelValueDCGMExporter = "dcgm-exporter"
LabelValueGPUOperator = "gpu-operator"
GPUOperatorNamespace = "gpu-operator"
requestTimeout = 5 * time.Second
dialTimeout = 3 * time.Second
tlsHandshakeTimeout = 3 * time.Second
CloudProviderGCP = "gcp"
CloudProviderAWS = "aws"
CloudProviderAKS = "aks"
CloudProviderOther = "other"
CloudProviderUnknown = "unknown"
) )
// awsInstanceTypePrefixes matches known GPU/accelerator instance families on EKS. See: https://aws.amazon.com/ec2/instance-types/
var awsInstanceTypePrefixes = []string{
"p3.", "p3dn.", "p4d.", "p4de.", "p5.", // GPU instances
"g3.", "g4dn.", "g4ad.", "g5.", "g6.", // GPU instances
"inf1.", "inf2.", // Inferentia
"trn1.", "trn1n.", // Trainium
}
// gcpMachineSeries matches known GCP accelerator-optimised machine series on GKE. See: https://cloud.google.com/compute/docs/machine-resource
var gcpMachineSeries = []string{
"a2-", // A100 GPU machines
"a3-", // H100 GPU machines
"g2-", // L4 GPU machines
}
// GPUInfo contains discovered GPU configuration from cluster nodes // GPUInfo contains discovered GPU configuration from cluster nodes
type GPUInfo struct { type GPUInfo struct {
NodeName string // Name of the node with this GPU configuration
GPUsPerNode int // Maximum GPUs per node found in the cluster GPUsPerNode int // Maximum GPUs per node found in the cluster
NodesWithGPUs int // Number of nodes that have GPUs NodesWithGPUs int // Number of nodes that have GPUs
Model string // GPU product name (e.g., "H100-SXM5-80GB") Model string // GPU product name (e.g., "H100-SXM5-80GB")
VRAMPerGPU int // VRAM in MiB per GPU VRAMPerGPU int // VRAM in MiB per GPU
System nvidiacomv1beta1.GPUSKUType // AIC hardware system identifier (e.g., "h100_sxm", "h200_sxm"), empty if unknown System nvidiacomv1beta1.GPUSKUType // AIC hardware system identifier (e.g., "h100_sxm", "h200_sxm"), empty if unknown
MIGEnabled bool // True if MIG is enabled (inferred from model or additional labels, not implemented in this version)
MIGProfiles map[string]int // Optional: map of MIG profile name to count (requires additional label parsing, not implemented in this version)
CloudProvider string // NEW: aws | gcp | aks | other | unknown
}
type ScrapeMetricsFunc func(ctx context.Context, endpoint string) (*GPUInfo, error)
type GPUDiscoveryCache struct {
mu sync.RWMutex
value *GPUInfo
expiresAt time.Time
}
type GPUDiscovery struct {
Scraper ScrapeMetricsFunc
}
func NewGPUDiscovery(scraper ScrapeMetricsFunc) *GPUDiscovery {
return &GPUDiscovery{
Scraper: scraper,
}
}
// NewGPUDiscoveryCache creates a new GPUDiscoveryCache instance.
//
// The cache stores a single discovered GPUInfo value with an expiration time.
// It is safe for concurrent use and is intended to reduce repeated DCGM
// scraping during reconciliation loops.
func NewGPUDiscoveryCache() *GPUDiscoveryCache {
return &GPUDiscoveryCache{}
}
// Get returns the cached GPUInfo if it exists and has not expired.
//
// The boolean return value indicates whether a valid cached value was found.
// If the cache is empty or expired, it returns (nil, false).
//
// This method is safe for concurrent use.
func (c *GPUDiscoveryCache) Get() (*GPUInfo, bool) {
c.mu.RLock()
defer c.mu.RUnlock()
if time.Now().Before(c.expiresAt) && c.value != nil {
return c.value, true
}
return nil, false
}
// Set stores the provided GPUInfo in the cache with the given TTL (time-to-live).
//
// The cached value will be considered valid until the TTL duration elapses.
// After expiration, Get will return (nil, false) until a new value is set.
//
// This method is safe for concurrent use.
func (c *GPUDiscoveryCache) Set(info *GPUInfo, ttl time.Duration) {
c.mu.Lock()
defer c.mu.Unlock()
c.value = info
c.expiresAt = time.Now().Add(ttl)
}
// DiscoverGPUsFromDCGM discovers GPU information by scraping metrics directly
// from DCGM exporter pods running in the cluster.
//
// The function performs the following:
//
// 1. Returns cached GPU information if still valid.
// 2. Lists DCGM exporter pods across all namespaces using supported labels.
// 3. If no pods are found, attempts to find if GPU operator is installed and DCGM is enabled via Helm.
// 4. Warns user appropriately.
// 5. Scrapes each running pods metrics endpoint (http://<podIP>:9400/metrics).
// 6. Selects the "best" GPU node based on:
// - Highest GPU count
// - Highest VRAM per GPU (tie-breaker)
// 7. Caches the result for a short duration to avoid repeated scraping.
//
// Behavior Notes:
//
// - Scrapes pods directly instead of using a Service ClusterIP to avoid
// load-balancing ambiguity in multi-node clusters.
// - If at least one pod is successfully scraped, partial failures are tolerated.
// - If all pods fail to scrape, an aggregated error is returned.
// - Assumes DCGM exporter runs as a DaemonSet (one pod per GPU node).
// - Designed for homogeneous clusters; heterogeneous cluster aggregation
// is not yet implemented.
//
// Returns:
// - *GPUInfo for the selected node
// - error if no GPU data can be retrieved
//
// TODO: Current implementation selects a single "best" GPU node (highest GPU count,
// tie-broken by VRAM). This works for homogeneous clusters where all GPU
// nodes are identical.
// For Heterogeneous GPU Support (mixed GPU models or capacities), this logic
// does not represent full cluster GPU inventory. Future improvements should
// aggregate and return GPU information for all nodes instead of selecting
// only one.
func (g *GPUDiscovery) DiscoverGPUsFromDCGM(ctx context.Context, k8sClient client.Reader, cache *GPUDiscoveryCache) (*GPUInfo, error) {
if cache != nil {
// Return cached result if still valid
if cached, ok := cache.Get(); ok {
return cached, nil
}
}
// List DCGM exporter pods
dcgmPods, err := listDCGMExporterPods(ctx, k8sClient)
if err != nil && !strings.Contains(err.Error(), "no DCGM exporter pods found") {
return nil, fmt.Errorf("listing DCGM exporter pods failed: %w", err)
}
// If no pods found
if len(dcgmPods) == 0 {
gpuPods, err := listGPUOperatorRunningPods(ctx, k8sClient)
if len(gpuPods) > 0 {
return nil, fmt.Errorf("DCGM is not enabled in the GPU Operator (check GPU Operator configuration and permissions)")
}
return nil, err
}
// Scrape each running pod individually
var bestNode *GPUInfo
var scrapeErrors []error
nodesWithGPUs := 0
for _, pod := range dcgmPods {
if pod.Status.Phase != corev1.PodRunning || pod.Status.PodIP == "" {
continue
}
endpoint := buildDCGMEndpoint(pod.Status.PodIP)
info, err := g.Scraper(ctx, endpoint)
if err != nil {
scrapeErrors = append(scrapeErrors, fmt.Errorf("pod %s (%s): %w", pod.Name, pod.Status.PodIP, err))
continue
}
// Increment NodesWithGPUs for every node that successfully reports GPU metrics
nodesWithGPUs++
// Select best node: highest GPU count, tie-breaker by VRAM
if bestNode == nil ||
info.GPUsPerNode > bestNode.GPUsPerNode ||
(info.GPUsPerNode == bestNode.GPUsPerNode &&
info.VRAMPerGPU > bestNode.VRAMPerGPU) {
bestNode = info
}
}
if bestNode == nil {
if len(scrapeErrors) > 0 {
return nil, fmt.Errorf("failed to scrape any DCGM exporter pod: %v", scrapeErrors)
}
return nil, fmt.Errorf("no GPU metrics could be parsed from any DCGM pod")
}
// Infer cloud provider for the best node
cloudProvider, err := GetCloudProviderInfo(ctx, k8sClient)
if err != nil {
cloudProvider = CloudProviderUnknown
}
bestNode.CloudProvider = cloudProvider
bestNode.NodesWithGPUs = nodesWithGPUs
if cache != nil {
// Cache result for 60 seconds
cache.Set(bestNode, 60*time.Second)
}
return bestNode, nil
}
func buildDCGMEndpoint(podIP string) string {
template := os.Getenv("DCGM_METRICS_ENDPOINT_TEMPLATE")
if template == "" {
template = defaultDCGMEndpointTemplate
}
return strings.ReplaceAll(template, "{POD_IP}", podIP)
}
func listDCGMExporterPods(ctx context.Context, k8sClient client.Reader) ([]corev1.Pod, error) {
var result []corev1.Pod
seen := make(map[string]struct{})
selectors := []client.MatchingLabels{
{LabelApp: LabelValueNvidiaDCGMExporter},
{LabelApp: LabelValueDCGMExporter},
{LabelAppKubernetesName: LabelValueDCGMExporter},
}
var lastErr error
for _, selector := range selectors {
podList := &corev1.PodList{}
err := k8sClient.List(ctx, podList, selector)
if err != nil {
lastErr = fmt.Errorf("list pods: %w", err)
continue
}
for _, pod := range podList.Items {
key := pod.Namespace + "/" + pod.Name
if _, exists := seen[key]; !exists {
seen[key] = struct{}{}
result = append(result, pod)
}
}
}
if len(result) > 0 {
return result, nil
}
if lastErr != nil {
return nil, lastErr
}
return nil, fmt.Errorf("no DCGM exporter pods found")
}
// listGPUOperatorRunningPods lists GPU Operator pods in the given namespace
// and returns only those that are in Running phase.
//
// It uses common GPU Operator label selectors and deduplicates results
// across selectors. If no running pods are found, an error is returned.
func listGPUOperatorRunningPods(ctx context.Context, k8sClient client.Reader) ([]corev1.Pod, error) {
var result []corev1.Pod
seen := make(map[string]struct{})
selectors := []client.MatchingLabels{
{LabelApp: LabelValueGPUOperator},
{LabelAppKubernetesName: LabelValueGPUOperator},
}
var lastErr error
for _, selector := range selectors {
podList := &corev1.PodList{}
err := k8sClient.List(
ctx,
podList,
client.InNamespace(GPUOperatorNamespace),
selector,
)
if err != nil {
lastErr = fmt.Errorf("list gpu operator pods: %w", err)
continue
}
for _, pod := range podList.Items {
if pod.Status.Phase != corev1.PodRunning {
continue
}
key := pod.Namespace + "/" + pod.Name
if _, exists := seen[key]; !exists {
seen[key] = struct{}{}
result = append(result, pod)
}
}
}
if len(result) > 0 {
return result, nil
}
if lastErr != nil {
return nil, lastErr
}
return nil, fmt.Errorf(
"gpu operator is not installed %s",
GPUOperatorNamespace,
)
}
// scrapeMetricsEndpoint retrieves and parses Prometheus metrics from a
// DCGM exporter pod endpoint.
//
// The function performs an HTTP GET request against the provided endpoint
// (expected format: http://<podIP>:9400/metrics), validates the response,
// and parses the Prometheus text exposition format into metric families.
//
// Parsed metric families are passed to parseMetrics to extract high-level
// GPU information.
//
// Returns:
// - *GPUInfo derived from the parsed metrics
// - error if the HTTP request fails, the response is non-200,
// or metric parsing fails
//
// This function does not implement retries or fallback logic.
// Error handling and multi-pod aggregation are managed by the caller.
func ScrapeMetricsEndpoint(ctx context.Context, endpoint string) (*GPUInfo, error) {
// Set a timeout for the request
ctx, cancel := context.WithTimeout(ctx, requestTimeout)
defer cancel()
// Create a custom HTTP client with transport-level timeouts
client := &http.Client{
Transport: &http.Transport{
DialContext: (&net.Dialer{
Timeout: dialTimeout, // Dial timeout
KeepAlive: 30 * time.Second, // Keep-alive for connections
}).DialContext,
TLSHandshakeTimeout: tlsHandshakeTimeout, // TLS handshake timeout
},
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
if err != nil {
return nil, fmt.Errorf("create request for %s: %w", endpoint, err)
}
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("HTTP GET %s failed: %w", endpoint, err)
}
defer func() {
if cerr := resp.Body.Close(); cerr != nil {
// best-effort: can't return an error from defer; log it
log.FromContext(ctx).V(1).Info("failed to close response body", "err", cerr)
}
}()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf(
"metrics endpoint %s returned status %d",
endpoint,
resp.StatusCode,
)
}
parser := expfmt.NewTextParser(model.UTF8Validation)
metricFamilies, err := parser.TextToMetricFamilies(resp.Body)
if err != nil {
return nil, fmt.Errorf("parse prometheus metrics: %w", err)
}
return parseMetrics(ctx, metricFamilies)
}
// parseMetrics extracts GPU information for a node from DCGM Prometheus metrics.
//
// It parses the provided Prometheus metric families exported by the NVIDIA
// DCGM exporter and derives high-level GPU inventory information for the node.
//
// The function performs the following:
//
// - Detects the number of GPUs by counting unique "gpu" label values
// from DCGM_FI_DEV_GPU_TEMP (used as a reliable per-GPU metric).
//
// - Extracts the GPU model name from the "modelName" label.
//
// - Calculates total VRAM per GPU using framebuffer metrics:
// VRAM = FB_FREE + FB_USED + FB_RESERVED
// (values are in MiB).
//
// - Assumes MIG is disabled unless explicit MIG metrics are present
// (not included in the provided DCGM metric set).
//
// Parameters:
//
// ctx - Context for logging and cancellation.
// families - Map of Prometheus metric families keyed by metric name.
//
// Returns:
//
// *GPUInfo containing:
// - NodeName
// - GPUsPerNode
// - Model
// - VRAMPerGPU (MiB)
// - MIGEnabled: false because no MIG metrics were collected in the DCGM families
// - MIGProfiles: empty map; would contain MIG profile counts if MIG metrics were available
// - System (inferred from model)
//
// Returns an error if no GPUs can be detected from the metrics.
//
// Notes:
// - This function relies on DCGM exporter metrics.
// - If required metrics are missing, zero values may be returned.
// - The implementation assumes homogeneous GPUs per node.
// - For heterogeneous configurations, per-GPU parsing should be implemented.
func parseMetrics(ctx context.Context, families map[string]*dto.MetricFamily) (*GPUInfo, error) {
logger := log.FromContext(ctx)
getLabel := func(m *dto.Metric, name string) string {
for _, l := range m.GetLabel() {
if l.GetName() == name {
return l.GetValue()
}
}
return ""
}
// Track unique GPUs
gpuSet := map[string]struct{}{}
var model string
var vram int
var hostName string
fbFree := map[string]float64{}
fbUsed := map[string]float64{}
fbReserved := map[string]float64{}
// --- Detect GPUs + Model + Hostname ---
if mf, ok := families["DCGM_FI_DEV_GPU_TEMP"]; ok {
for _, m := range mf.Metric {
gpuID := getLabel(m, "gpu")
if gpuID == "" {
continue
}
gpuSet[gpuID] = struct{}{}
// Extract model from label
if model == "" {
model = getLabel(m, "modelName")
}
// Extract Hostname label
if hostName == "" {
hostName = getLabel(m, "Hostname")
}
}
}
// --- Collect framebuffer metrics ---
if mf, ok := families["DCGM_FI_DEV_FB_FREE"]; ok {
for _, m := range mf.Metric {
gpuID := getLabel(m, "gpu")
if gpuID == "" {
continue
}
fbFree[gpuID] = m.GetGauge().GetValue()
if hostName == "" {
hostName = getLabel(m, "Hostname")
}
}
}
if mf, ok := families["DCGM_FI_DEV_FB_USED"]; ok {
for _, m := range mf.Metric {
gpuID := getLabel(m, "gpu")
if gpuID == "" {
continue
}
fbUsed[gpuID] = m.GetGauge().GetValue()
if hostName == "" {
hostName = getLabel(m, "Hostname")
}
}
}
if mf, ok := families["DCGM_FI_DEV_FB_RESERVED"]; ok {
for _, m := range mf.Metric {
gpuID := getLabel(m, "gpu")
if gpuID == "" {
continue
}
fbReserved[gpuID] = m.GetGauge().GetValue()
if hostName == "" {
hostName = getLabel(m, "Hostname")
}
}
}
// --- Calculate Max VRAM
for gpuID := range gpuSet {
total := int(fbFree[gpuID] + fbUsed[gpuID] + fbReserved[gpuID])
if total > vram {
vram = total
}
}
gpuCount := len(gpuSet)
if gpuCount == 0 {
return nil, fmt.Errorf("no GPUs detected from DCGM metrics")
}
// --- Infer system from model ---
system := InferHardwareSystem(model)
logger.Info("Parsed GPU info",
"node", hostName,
"gpuCount", gpuCount,
"model", model,
"vramMiB", vram,
"system", system,
)
return &GPUInfo{
NodeName: hostName,
GPUsPerNode: gpuCount,
Model: model,
VRAMPerGPU: vram,
MIGEnabled: false,
MIGProfiles: map[string]int{},
System: system, // populated from InferHardwareSystem
}, nil
} }
// DiscoverGPUs queries Kubernetes nodes to determine GPU configuration. // DiscoverGPUs queries Kubernetes nodes to determine GPU configuration.
...@@ -203,3 +753,68 @@ func InferHardwareSystem(gpuProduct string) nvidiacomv1beta1.GPUSKUType { ...@@ -203,3 +753,68 @@ func InferHardwareSystem(gpuProduct string) nvidiacomv1beta1.GPUSKUType {
// User must specify gpuSku explicitly in spec.hardware. // User must specify gpuSku explicitly in spec.hardware.
return "" return ""
} }
func GetCloudProviderInfo(ctx context.Context, k8sClient client.Reader) (string, error) {
var nodeList corev1.NodeList
if err := k8sClient.List(ctx, &nodeList); err != nil {
return CloudProviderUnknown, fmt.Errorf("failed to list nodes: %w", err)
}
if len(nodeList.Items) == 0 {
return CloudProviderUnknown, fmt.Errorf("no nodes found in cluster")
}
// Use first node as representative (assumes homogeneous control plane)
node := nodeList.Items[0]
providerID := strings.ToLower(node.Spec.ProviderID)
labels := node.Labels
instanceType := strings.ToLower(labels["node.kubernetes.io/instance-type"])
// ---- Primary Detection: providerID ----
switch {
case strings.Contains(providerID, "azure"):
return CloudProviderAKS, nil
case strings.Contains(providerID, "aws"):
return CloudProviderAWS, nil
case strings.Contains(providerID, "gce"):
return CloudProviderGCP, nil
}
// ---- Secondary Detection: Node Labels ----
// AKS labels
if _, ok := labels["kubernetes.azure.com/cluster"]; ok {
return CloudProviderAKS, nil
}
if strings.Contains(instanceType, "standard_") {
return CloudProviderAKS, nil
}
// EKS labels
if _, ok := labels["eks.amazonaws.com/nodegroup"]; ok {
return CloudProviderAWS, nil
}
if isAWSInstanceType(instanceType) {
return CloudProviderAWS, nil
}
// GKE labels
if _, ok := labels["cloud.google.com/gke-nodepool"]; ok {
return CloudProviderGCP, nil
}
if isGCPInstanceType(instanceType) {
return CloudProviderGCP, nil
}
return "other", nil
}
func isGCPInstanceType(instanceType string) bool {
for _, prefix := range gcpMachineSeries {
if strings.HasPrefix(instanceType, prefix) {
return true
}
}
return false
}
func isAWSInstanceType(instanceType string) bool {
for _, prefix := range awsInstanceTypePrefixes {
if strings.HasPrefix(instanceType, prefix) {
return true
}
}
return false
}
...@@ -19,8 +19,14 @@ package gpu ...@@ -19,8 +19,14 @@ package gpu
import ( import (
"context" "context"
"errors"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing" "testing"
dto "github.com/prometheus/client_model/go"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
corev1 "k8s.io/api/core/v1" corev1 "k8s.io/api/core/v1"
...@@ -31,7 +37,7 @@ import ( ...@@ -31,7 +37,7 @@ import (
) )
// newFakeClient creates a fake Kubernetes client with the given objects // newFakeClient creates a fake Kubernetes client with the given objects
func newFakeClient(objs ...client.Object) client.Client { func newFakeClient(objs ...client.Object) client.Reader {
scheme := runtime.NewScheme() scheme := runtime.NewScheme()
_ = corev1.AddToScheme(scheme) _ = corev1.AddToScheme(scheme)
return fake.NewClientBuilder(). return fake.NewClientBuilder().
...@@ -375,3 +381,495 @@ func TestInferHardwareSystem_SpacesAndDashes(t *testing.T) { ...@@ -375,3 +381,495 @@ func TestInferHardwareSystem_SpacesAndDashes(t *testing.T) {
assert.Equal(t, "h100_sxm", string(result), "Should normalize spaces/dashes: %s", variant) assert.Equal(t, "h100_sxm", string(result), "Should normalize spaces/dashes: %s", variant)
} }
} }
func TestParseMetrics(t *testing.T) {
ctx := context.Background()
// Fake DCGM metrics for a node with 2 GPUs
metricFamilies := map[string]*dto.MetricFamily{
"DCGM_FI_DEV_GPU_TEMP": {
Metric: []*dto.Metric{
{
Label: []*dto.LabelPair{
{Name: strPtr("gpu"), Value: strPtr("0")},
{Name: strPtr("modelName"), Value: strPtr("H100-SXM5-80GB")},
{Name: strPtr("Hostname"), Value: strPtr("node1")},
},
},
{
Label: []*dto.LabelPair{
{Name: strPtr("gpu"), Value: strPtr("1")},
{Name: strPtr("modelName"), Value: strPtr("H100-SXM5-80GB")},
{Name: strPtr("Hostname"), Value: strPtr("node1")},
},
},
},
},
"DCGM_FI_DEV_FB_FREE": {
Metric: []*dto.Metric{
{Label: []*dto.LabelPair{{Name: strPtr("gpu"), Value: strPtr("0")}}, Gauge: &dto.Gauge{Value: float64Ptr(10000)}},
{Label: []*dto.LabelPair{{Name: strPtr("gpu"), Value: strPtr("1")}}, Gauge: &dto.Gauge{Value: float64Ptr(12000)}},
},
},
"DCGM_FI_DEV_FB_USED": {
Metric: []*dto.Metric{
{Label: []*dto.LabelPair{{Name: strPtr("gpu"), Value: strPtr("0")}}, Gauge: &dto.Gauge{Value: float64Ptr(5000)}},
{Label: []*dto.LabelPair{{Name: strPtr("gpu"), Value: strPtr("1")}}, Gauge: &dto.Gauge{Value: float64Ptr(6000)}},
},
},
"DCGM_FI_DEV_FB_RESERVED": {
Metric: []*dto.Metric{
{Label: []*dto.LabelPair{{Name: strPtr("gpu"), Value: strPtr("0")}}, Gauge: &dto.Gauge{Value: float64Ptr(0)}},
{Label: []*dto.LabelPair{{Name: strPtr("gpu"), Value: strPtr("1")}}, Gauge: &dto.Gauge{Value: float64Ptr(0)}},
},
},
}
info, err := parseMetrics(ctx, metricFamilies)
require.NoError(t, err)
assert.Equal(t, "node1", info.NodeName)
assert.Equal(t, 2, info.GPUsPerNode)
assert.Equal(t, "H100-SXM5-80GB", info.Model)
// maxVRAM: 12000 + 6000 + 0 = 18000
assert.Equal(t, 18000, info.VRAMPerGPU)
assert.False(t, info.MIGEnabled)
assert.Empty(t, info.MIGProfiles)
}
func TestScrapeMetricsEndpoint(t *testing.T) {
ctx := context.TODO()
// Prepare a fake HTTP server to simulate Prometheus metrics
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := fmt.Fprintln(w, `# HELP DCGM_FI_DEV_GPU_TEMP GPU temperature`)
require.NoError(t, err)
_, err = fmt.Fprintln(w, `# TYPE DCGM_FI_DEV_GPU_TEMP gauge`)
require.NoError(t, err)
_, err = fmt.Fprintln(w, `DCGM_FI_DEV_GPU_TEMP{gpu="0",modelName="NVIDIA A100",Hostname="test-node"} 50`)
require.NoError(t, err)
_, err = fmt.Fprintln(w, `# HELP DCGM_FI_DEV_FB_FREE Framebuffer free`)
require.NoError(t, err)
_, err = fmt.Fprintln(w, `# TYPE DCGM_FI_DEV_FB_FREE gauge`)
require.NoError(t, err)
_, err = fmt.Fprintln(w, `DCGM_FI_DEV_FB_FREE{gpu="0",Hostname="test-node"} 10000`)
require.NoError(t, err)
_, err = fmt.Fprintln(w, `# HELP DCGM_FI_DEV_FB_USED Framebuffer used`)
require.NoError(t, err)
_, err = fmt.Fprintln(w, `# TYPE DCGM_FI_DEV_FB_USED gauge`)
require.NoError(t, err)
_, err = fmt.Fprintln(w, `DCGM_FI_DEV_FB_USED{gpu="0",Hostname="test-node"} 2000`)
require.NoError(t, err)
_, err = fmt.Fprintln(w, `# HELP DCGM_FI_DEV_FB_RESERVED Framebuffer reserved`)
require.NoError(t, err)
_, err = fmt.Fprintln(w, `# TYPE DCGM_FI_DEV_FB_RESERVED gauge`)
require.NoError(t, err)
_, err = fmt.Fprintln(w, `DCGM_FI_DEV_FB_RESERVED{gpu="0",Hostname="test-node"} 500`)
require.NoError(t, err)
}))
defer server.Close()
t.Run("successful scrape", func(t *testing.T) {
info, err := ScrapeMetricsEndpoint(ctx, server.URL)
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
if info == nil {
t.Fatal("expected non-nil GPUInfo")
}
})
t.Run("404 response", func(t *testing.T) {
badServer := httptest.NewServer(http.NotFoundHandler())
defer badServer.Close()
_, err := ScrapeMetricsEndpoint(ctx, badServer.URL)
expectedErr := fmt.Sprintf("metrics endpoint %s returned status 404", badServer.URL)
if err == nil || err.Error() != expectedErr {
t.Fatalf("expected %q, got %v", expectedErr, err)
}
})
t.Run("invalid metrics", func(t *testing.T) {
invalidServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := fmt.Fprintln(w, `not a prometheus format`)
require.NoError(t, err)
}))
defer invalidServer.Close()
_, err := ScrapeMetricsEndpoint(ctx, invalidServer.URL)
if err == nil {
t.Fatal("expected parse error, got nil")
}
})
}
func TestDiscoverGPUsFromDCGM_CacheHit(t *testing.T) {
ctx := context.Background()
pod := &corev1.Pod{
ObjectMeta: metav1.ObjectMeta{
Name: "dcgm-pod",
Namespace: "default",
Labels: map[string]string{
LabelApp: LabelValueNvidiaDCGMExporter,
},
},
Status: corev1.PodStatus{
Phase: corev1.PodRunning,
PodIP: "10.0.0.1",
},
}
scheme := runtime.NewScheme()
require.NoError(t, corev1.AddToScheme(scheme))
k8sClient := fake.NewClientBuilder().
WithScheme(scheme).
WithObjects(pod).
Build()
cache := NewGPUDiscoveryCache()
callCount := 0
mockScraper := func(ctx context.Context, endpoint string) (*GPUInfo, error) {
callCount++
return &GPUInfo{
NodeName: "node-a",
GPUsPerNode: 4,
Model: "A100",
VRAMPerGPU: 40960,
MIGEnabled: false,
MIGProfiles: map[string]int{},
System: "DGX",
}, nil
}
discovery := NewGPUDiscovery(mockScraper)
// First call → should scrape
info1, err := discovery.DiscoverGPUsFromDCGM(ctx, k8sClient, cache)
require.NoError(t, err)
require.NotNil(t, info1)
require.Equal(t, 1, callCount)
// Second call → should hit cache
info2, err := discovery.DiscoverGPUsFromDCGM(ctx, k8sClient, cache)
require.NoError(t, err)
require.NotNil(t, info2)
// Scrape should NOT be called again
require.Equal(t, 1, callCount)
require.Equal(t, info1, info2)
}
func TestDiscoverGPUsFromDCGM_GPUOperatorInstalled_DCgmNotEnabled(t *testing.T) {
ctx := context.Background()
gpuOperatorPod := &corev1.Pod{
ObjectMeta: metav1.ObjectMeta{
Name: "gpu-operator-abc",
Namespace: "gpu-operator",
Labels: map[string]string{
LabelApp: LabelValueGPUOperator,
},
},
Status: corev1.PodStatus{
Phase: corev1.PodRunning,
},
}
scheme := runtime.NewScheme()
require.NoError(t, corev1.AddToScheme(scheme))
k8sClient := fake.NewClientBuilder().
WithScheme(scheme).
WithObjects(gpuOperatorPod).
Build()
cache := NewGPUDiscoveryCache()
dummyScraper := func(ctx context.Context, endpoint string) (*GPUInfo, error) {
return nil, fmt.Errorf("should not be called")
}
discovery := NewGPUDiscovery(dummyScraper)
info, err := discovery.DiscoverGPUsFromDCGM(ctx, k8sClient, cache)
require.Nil(t, info)
require.Error(t, err)
require.Contains(t, err.Error(), "DCGM is not enabled in the GPU Operator")
}
func TestDiscoverGPUsFromDCGM_NoGPUOperator_NoDCGM(t *testing.T) {
ctx := context.Background()
scheme := runtime.NewScheme()
require.NoError(t, corev1.AddToScheme(scheme))
k8sClient := fake.NewClientBuilder().
WithScheme(scheme).
Build()
cache := NewGPUDiscoveryCache()
dummyScraper := func(ctx context.Context, endpoint string) (*GPUInfo, error) {
return nil, fmt.Errorf("should not be called")
}
discovery := NewGPUDiscovery(dummyScraper)
info, err := discovery.DiscoverGPUsFromDCGM(ctx, k8sClient, cache)
require.Nil(t, info)
require.Error(t, err)
require.True(
t,
strings.Contains(err.Error(), "gpu operator is not installed"),
)
}
func TestListDCGMExporterPods(t *testing.T) {
scheme := runtime.NewScheme()
_ = corev1.AddToScheme(scheme)
ctx := context.Background()
tests := []struct {
name string
objects []client.Object
expectCount int
expectErr bool
errorClient bool
}{
{
name: "pods found via different selectors",
objects: []client.Object{
&corev1.Pod{
ObjectMeta: metav1.ObjectMeta{
Name: "pod1",
Namespace: "ns1",
Labels: map[string]string{
LabelApp: LabelValueNvidiaDCGMExporter,
},
},
},
&corev1.Pod{
ObjectMeta: metav1.ObjectMeta{
Name: "pod2",
Namespace: "ns1",
Labels: map[string]string{
LabelAppKubernetesName: LabelValueDCGMExporter,
},
},
},
},
expectCount: 2,
expectErr: false,
},
{
name: "duplicate pods across selectors should dedupe",
objects: []client.Object{
&corev1.Pod{
ObjectMeta: metav1.ObjectMeta{
Name: "pod1",
Namespace: "ns1",
Labels: map[string]string{
LabelApp: LabelValueDCGMExporter,
LabelAppKubernetesName: LabelValueDCGMExporter,
},
},
},
},
expectCount: 1,
expectErr: false,
},
{
name: "no pods found",
objects: []client.Object{},
expectCount: 0,
expectErr: true,
},
{
name: "client list error",
objects: []client.Object{},
expectCount: 0,
expectErr: true,
errorClient: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var k8sClient client.Reader
if tt.errorClient {
k8sClient = &errorListClient{}
} else {
k8sClient = fake.NewClientBuilder().
WithScheme(scheme).
WithObjects(tt.objects...).
Build()
}
pods, err := listDCGMExporterPods(ctx, k8sClient)
if tt.expectErr && err == nil {
t.Fatalf("expected error but got nil")
}
if !tt.expectErr && err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(pods) != tt.expectCount {
t.Fatalf("expected %d pods, got %d", tt.expectCount, len(pods))
}
})
}
}
//
// ---- Fake client that forces List error ----
//
type errorListClient struct {
client.Reader
}
func (e *errorListClient) List(ctx context.Context, list client.ObjectList, opts ...client.ListOption) error {
return errors.New("forced list error")
}
// --- Helper functions ---
func strPtr(s string) *string { return &s }
func float64Ptr(f float64) *float64 { return &f }
func TestGetCloudProviderInfo(t *testing.T) {
scheme := runtime.NewScheme()
_ = corev1.AddToScheme(scheme)
tests := []struct {
name string
node corev1.Node
want string
wantErr bool
}{
{
name: "AKS via providerID",
node: corev1.Node{
Spec: corev1.NodeSpec{
ProviderID: "azure:///subscriptions/xxx/resourceGroups/rg/providers/Microsoft.Compute/virtualMachines/vm1",
},
},
want: "aks",
wantErr: false,
},
{
name: "AWS via providerID",
node: corev1.Node{
Spec: corev1.NodeSpec{
ProviderID: "aws:///us-west-2/i-0123456789abcdef0",
},
},
want: "aws",
wantErr: false,
},
{
name: "GCP via providerID",
node: corev1.Node{
Spec: corev1.NodeSpec{
ProviderID: "gce://project/zone/instance",
},
},
want: "gcp",
wantErr: false,
},
{
name: "AKS via label",
node: corev1.Node{
ObjectMeta: metav1.ObjectMeta{
Labels: map[string]string{
"kubernetes.azure.com/cluster": "mycluster",
},
},
},
want: "aks",
wantErr: false,
},
{
name: "AWS via label",
node: corev1.Node{
ObjectMeta: metav1.ObjectMeta{
Labels: map[string]string{
"eks.amazonaws.com/nodegroup": "ng-1",
},
},
},
want: "aws",
wantErr: false,
},
{
name: "GCP via label",
node: corev1.Node{
ObjectMeta: metav1.ObjectMeta{
Labels: map[string]string{
"cloud.google.com/gke-nodepool": "np-1",
},
},
},
want: "gcp",
wantErr: false,
},
{
name: "Other node",
node: corev1.Node{
ObjectMeta: metav1.ObjectMeta{
Labels: map[string]string{
"custom-label": "foo",
},
},
},
want: "other",
wantErr: false,
},
{
name: "No nodes",
node: corev1.Node{}, // will not add to client
want: "unknown",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := context.TODO()
var k8sClient client.Reader
if tt.name != "No nodes" {
k8sClient = fake.NewClientBuilder().
WithScheme(scheme).
WithObjects(&tt.node).
Build()
} else {
k8sClient = fake.NewClientBuilder().
WithScheme(scheme).
Build()
}
got, err := GetCloudProviderInfo(ctx, k8sClient)
if (err != nil) != tt.wantErr {
t.Errorf("unexpected error: %v", err)
}
if got != tt.want {
t.Errorf("got %q, want %q", got, tt.want)
}
})
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment