Unverified commit a91e6348, authored by Julien Mancuso, committed by GitHub
Browse files

feat: add dynamoModel CRD (#4166)


Signed-off-by: Julien Mancuso <jmancuso@nvidia.com>
parent b2f1defe
/*
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package controller
import (
"context"
"fmt"
"time"
corev1 "k8s.io/api/core/v1"
discoveryv1 "k8s.io/api/discovery/v1"
k8serrors "k8s.io/apimachinery/pkg/api/errors"
"k8s.io/apimachinery/pkg/api/meta"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/client-go/tools/record"
ctrl "sigs.k8s.io/controller-runtime"
"sigs.k8s.io/controller-runtime/pkg/builder"
"sigs.k8s.io/controller-runtime/pkg/client"
"sigs.k8s.io/controller-runtime/pkg/event"
"sigs.k8s.io/controller-runtime/pkg/handler"
"sigs.k8s.io/controller-runtime/pkg/log"
"sigs.k8s.io/controller-runtime/pkg/predicate"
"sigs.k8s.io/controller-runtime/pkg/reconcile"
"github.com/ai-dynamo/dynamo/deploy/cloud/operator/api/v1alpha1"
"github.com/ai-dynamo/dynamo/deploy/cloud/operator/internal/consts"
commoncontroller "github.com/ai-dynamo/dynamo/deploy/cloud/operator/internal/controller_common"
"github.com/ai-dynamo/dynamo/deploy/cloud/operator/internal/dynamo"
"github.com/ai-dynamo/dynamo/deploy/cloud/operator/internal/modelendpoint"
)
// Constants used by the DynamoModel controller: status condition types and
// reasons, the field index name, and the retry interval.
const (
// Condition types
// ConditionTypeEndpointsReady is the status condition describing whether the
// model's endpoints were found (base models) or are all ready (LoRA models).
ConditionTypeEndpointsReady = "EndpointsReady"
// ConditionTypeServicesFound is the status condition describing whether any
// service was associated with the discovered endpoint slices.
ConditionTypeServicesFound = "ServicesFound"
// Condition reasons
// ReasonAllEndpointsReady is set on EndpointsReady=True for LoRA models when
// every endpoint reports ready.
ReasonAllEndpointsReady = "AllEndpointsReady"
// ReasonEndpointsDiscovered is set on EndpointsReady=True for base models
// when at least one endpoint exists.
ReasonEndpointsDiscovered = "EndpointsDiscovered"
// ReasonNotReady is set on EndpointsReady=False for LoRA models when only
// some endpoints are ready.
ReasonNotReady = "NotReady"
// ReasonNoEndpoints is set on EndpointsReady=False when no endpoints exist.
ReasonNoEndpoints = "NoEndpoints"
// ReasonServicesFound is set on ServicesFound=True.
ReasonServicesFound = "ServicesFound"
// ReasonNoServicesFound is set on ServicesFound=False.
ReasonNoServicesFound = "NoServicesFound"
// Field index names
// dynamoModelBaseModelHashIndex is the name of the field index registered in
// SetupWithManager; the indexed value is the hash of Spec.BaseModelName.
dynamoModelBaseModelHashIndex = ".spec.baseModelNameHash"
// Requeue duration for retries when endpoints are not ready
requeueAfterDuration = 30 * time.Second
)
// DynamoModelReconciler reconciles a DynamoModel object.
// It discovers the endpoints serving a model's base model via EndpointSlices
// and records the result in the DynamoModel status.
type DynamoModelReconciler struct {
// Client is the embedded Kubernetes client used to read DynamoModels and
// EndpointSlices and to update the DynamoModel status.
client.Client
// Recorder emits Kubernetes events attached to the DynamoModel.
Recorder record.EventRecorder
// EndpointClient is used to load and unload LoRAs on discovered endpoints
// (see the LoadLoRA and UnloadLoRA calls in this file).
EndpointClient *modelendpoint.Client
}
// +kubebuilder:rbac:groups=nvidia.com,resources=dynamomodels,verbs=get;list;watch;create;update;patch;delete
// +kubebuilder:rbac:groups=nvidia.com,resources=dynamomodels/status,verbs=get;update;patch
// +kubebuilder:rbac:groups=nvidia.com,resources=dynamomodels/finalizers,verbs=update
// +kubebuilder:rbac:groups=core,resources=services,verbs=get;list;watch
// +kubebuilder:rbac:groups=discovery.k8s.io,resources=endpointslices,verbs=get;list;watch
// Reconcile handles the reconciliation loop for DynamoModel resources.
//
// It fetches the DynamoModel, runs the common finalizer handling, discovers
// endpoint candidates through EndpointSlices labelled with the hash of the
// base model name, asks the endpoint client to load the LoRA on those
// candidates, and writes endpoints, counters and conditions to the status
// subresource.
//
// The returned result requests a requeue after requeueAfterDuration when
// probing reported an error or, for LoRA models only, when not every endpoint
// is ready. Errors from the API server are returned unchanged so that
// controller-runtime retries with backoff.
func (r *DynamoModelReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) {
	logs := log.FromContext(ctx)

	// Fetch the DynamoModel
	model := &v1alpha1.DynamoModel{}
	if err := r.Get(ctx, req.NamespacedName, model); err != nil {
		if k8serrors.IsNotFound(err) {
			logs.Info("DynamoModel resource not found. Ignoring since object must be deleted")
			return ctrl.Result{}, nil
		}
		logs.Error(err, "Failed to get DynamoModel")
		return ctrl.Result{}, err
	}

	logs = logs.WithValues("dynamoModel", model.Name, "namespace", model.Namespace, "baseModelName", model.Spec.BaseModelName)
	logs.Info("Reconciling DynamoModel")

	// Handle finalizer using common handler
	finalized, err := commoncontroller.HandleFinalizer(ctx, model, r.Client, r)
	if err != nil {
		return ctrl.Result{}, err
	}
	if finalized {
		// Object was being deleted and finalizer has been called
		return ctrl.Result{}, nil
	}

	// Get endpoint candidates (common logic)
	candidates, serviceNames, err := r.getEndpointCandidates(ctx, model)
	if err != nil {
		// Error already logged and event recorded in helper
		// Let controller-runtime handle retry with exponential backoff
		return ctrl.Result{}, err
	}

	if len(candidates) == 0 {
		msg := fmt.Sprintf("No endpoint slices found for base model %s", model.Spec.BaseModelName)
		logs.Info(msg)
		r.Recorder.Event(model, corev1.EventTypeWarning, "NoEndpointsFound", msg)

		// Derive ServicesFound from the discovered service names, exactly like
		// the main path does, instead of unconditionally reporting False:
		// services may exist while none of their endpoints is usable yet.
		r.setServicesFoundCondition(model, len(serviceNames), msg)
		r.updateCondition(model, ConditionTypeEndpointsReady, metav1.ConditionFalse, ReasonNoEndpoints, msg)

		model.Status.Endpoints = nil
		model.Status.TotalEndpoints = 0
		model.Status.ReadyEndpoints = 0
		if err := r.Status().Update(ctx, model); err != nil {
			return ctrl.Result{}, err
		}
		// Don't requeue - we're watching EndpointSlices, so we'll be notified when they appear
		return ctrl.Result{}, nil
	}

	// Load LoRA on all endpoints in parallel with bounded concurrency
	allEndpoints, probeErr := r.EndpointClient.LoadLoRA(ctx, candidates, model)
	readyCount := countReadyEndpoints(allEndpoints)

	// Determine if we need to requeue based on model type
	// For LoRA models: requeue if there were probe errors OR if not all endpoints are ready
	// For base models: only requeue if there were probe errors (Ready is expected to be false)
	hasFailures := probeErr != nil
	if model.IsLoRA() {
		hasFailures = hasFailures || readyCount < len(allEndpoints)
	}

	if probeErr != nil {
		logs.Error(probeErr, "Some endpoints failed during probing")
		r.Recorder.Event(model, corev1.EventTypeWarning, "PartialEndpointFailure",
			fmt.Sprintf("Some endpoints failed to load LoRA: %v", probeErr))
	}

	// Update service found condition based on whether we found any services
	r.setServicesFoundCondition(model, len(serviceNames), "No services associated with endpoint slices")

	// Update status
	model.Status.Endpoints = allEndpoints
	model.Status.TotalEndpoints = len(allEndpoints)
	model.Status.ReadyEndpoints = readyCount

	// Update conditions based on model type
	r.setEndpointsReadyCondition(model)

	if err := r.Status().Update(ctx, model); err != nil {
		logs.Error(err, "Failed to update DynamoModel status")
		return ctrl.Result{}, err
	}

	logs.Info("Successfully reconciled DynamoModel",
		"totalEndpoints", model.Status.TotalEndpoints,
		"readyEndpoints", model.Status.ReadyEndpoints)

	// Requeue if there were probe failures to retry loading LoRAs
	if hasFailures {
		logs.Info("Requeuing due to endpoint probe failures",
			"ready", model.Status.ReadyEndpoints,
			"total", model.Status.TotalEndpoints)
		return ctrl.Result{RequeueAfter: requeueAfterDuration}, nil
	}

	return ctrl.Result{}, nil
}

// setServicesFoundCondition records the ServicesFound condition on the model:
// True when serviceCount is positive, otherwise False with notFoundMsg as the
// condition message.
func (r *DynamoModelReconciler) setServicesFoundCondition(model *v1alpha1.DynamoModel, serviceCount int, notFoundMsg string) {
	if serviceCount > 0 {
		r.updateCondition(model, ConditionTypeServicesFound, metav1.ConditionTrue, ReasonServicesFound,
			fmt.Sprintf("Found %d service(s)", serviceCount))
		return
	}
	r.updateCondition(model, ConditionTypeServicesFound, metav1.ConditionFalse, ReasonNoServicesFound, notFoundMsg)
}

// setEndpointsReadyCondition derives the EndpointsReady condition (and the
// matching event) from the endpoint counters already stored in model.Status.
// LoRA models are ready only when every endpoint is ready; base models are
// ready as soon as at least one endpoint exists.
func (r *DynamoModelReconciler) setEndpointsReadyCondition(model *v1alpha1.DynamoModel) {
	total := model.Status.TotalEndpoints
	ready := model.Status.ReadyEndpoints

	switch {
	case total <= 0:
		r.updateCondition(model, ConditionTypeEndpointsReady, metav1.ConditionFalse, ReasonNoEndpoints, "No endpoints found")
	case !model.IsLoRA():
		// For base models, just check that endpoints exist (readiness doesn't apply)
		r.updateCondition(model, ConditionTypeEndpointsReady, metav1.ConditionTrue, ReasonEndpointsDiscovered,
			fmt.Sprintf("Found %d endpoint(s) for base model", total))
		r.Recorder.Eventf(model, corev1.EventTypeNormal, "EndpointsDiscovered",
			"Discovered %d endpoints for base model %s", total, model.Spec.BaseModelName)
	case ready == total:
		r.updateCondition(model, ConditionTypeEndpointsReady, metav1.ConditionTrue, ReasonAllEndpointsReady,
			fmt.Sprintf("All %d endpoint(s) are ready", total))
		r.Recorder.Eventf(model, corev1.EventTypeNormal, "EndpointsReady",
			"All %d endpoints ready for base model %s", total, model.Spec.BaseModelName)
	default:
		r.updateCondition(model, ConditionTypeEndpointsReady, metav1.ConditionFalse, ReasonNotReady,
			fmt.Sprintf("Found %d ready endpoint(s) out of %d total", ready, total))
		r.Recorder.Eventf(model, corev1.EventTypeWarning, "NotReady",
			"Only %d of %d endpoints ready for base model %s", ready, total, model.Spec.BaseModelName)
	}
}
// countReadyEndpoints returns the number of entries in endpoints whose Ready
// flag is set. A nil or empty slice yields 0.
func countReadyEndpoints(endpoints []v1alpha1.EndpointInfo) int {
	ready := 0
	for i := range endpoints {
		if !endpoints[i].Ready {
			continue
		}
		ready++
	}
	return ready
}
// updateCondition sets the condition of type condType on the model's in-memory
// status, adding it when absent. The change is not persisted here; callers
// write the status afterwards.
func (r *DynamoModelReconciler) updateCondition(model *v1alpha1.DynamoModel, condType string, status metav1.ConditionStatus, reason, message string) {
	conditions := &model.Status.Conditions
	meta.SetStatusCondition(conditions, metav1.Condition{
		Type:               condType,
		Status:             status,
		ObservedGeneration: model.Generation,
		LastTransitionTime: metav1.Now(),
		Reason:             reason,
		Message:            message,
	})
}
// SetupWithManager registers the base-model-hash field index for DynamoModels
// and wires the controller: it reconciles DynamoModels on generation changes
// and on EndpointSlice events mapped through findModelsForEndpointSlice.
func (r *DynamoModelReconciler) SetupWithManager(mgr ctrl.Manager) error {
	// Index DynamoModels by the hash of their base model name. The hash is
	// produced by the same function used for the EndpointSlice label
	// (nvidia.com/dynamo-base-model-hash), so a slice's label value can be
	// used directly as the index key.
	hashOfBaseModel := func(obj client.Object) []string {
		dynamoModel := obj.(*v1alpha1.DynamoModel)
		return []string{dynamo.HashModelName(dynamoModel.Spec.BaseModelName)}
	}
	indexErr := mgr.GetFieldIndexer().IndexField(
		context.Background(),
		&v1alpha1.DynamoModel{},
		dynamoModelBaseModelHashIndex,
		hashOfBaseModel,
	)
	if indexErr != nil {
		return indexErr
	}

	// Generic events on EndpointSlices never trigger a reconcile.
	endpointSliceEvents := predicate.Funcs{
		GenericFunc: func(e event.GenericEvent) bool { return false },
	}
	sliceToModels := handler.EnqueueRequestsFromMapFunc(r.findModelsForEndpointSlice)

	// Watch EndpointSlices - reconcile when endpoints change (Service changes trigger EndpointSlice updates)
	return ctrl.NewControllerManagedBy(mgr).
		For(&v1alpha1.DynamoModel{}, builder.WithPredicates(predicate.GenerationChangedPredicate{})).
		Watches(&discoveryv1.EndpointSlice{}, sliceToModels, builder.WithPredicates(endpointSliceEvents)).
		Complete(r)
}
// findModelsForEndpointSlice maps an EndpointSlice event to reconcile requests
// for the DynamoModels in the same namespace whose base model hash equals the
// slice's base-model-hash label.
//
// It returns nil (no requests) when obj is not an EndpointSlice, when the
// slice does not carry the label, or when the lookup fails. A map function
// cannot return an error, so lookup failures are logged here; otherwise the
// dropped event would leave no trace.
func (r *DynamoModelReconciler) findModelsForEndpointSlice(ctx context.Context, obj client.Object) []reconcile.Request {
	// Checked assertion: a panic inside a watch handler would take down the
	// controller, so an unexpected type is reported and ignored instead.
	slice, ok := obj.(*discoveryv1.EndpointSlice)
	if !ok {
		log.FromContext(ctx).Info("Ignoring unexpected object in EndpointSlice watch",
			"type", fmt.Sprintf("%T", obj))
		return nil
	}

	logs := log.FromContext(ctx).WithValues("endpointSlice", slice.Name, "namespace", slice.Namespace)

	// Get the base model hash from the EndpointSlice label
	// This hash is set when the Service is created and matches our index
	baseModelHash, ok := slice.Labels[consts.KubeLabelDynamoBaseModelHash]
	if !ok {
		return nil
	}

	// Find all DynamoModels with this base model hash using field indexer
	// The indexer hashes each model's BaseModelName and we query by that hash
	requests, err := modelendpoint.FindModelsForBaseModel(ctx, r.Client, slice.Namespace, baseModelHash, dynamoModelBaseModelHashIndex)
	if err != nil {
		logs.Error(err, "Failed to find DynamoModels for EndpointSlice",
			"baseModelHash", baseModelHash)
		return nil
	}

	if len(requests) > 0 {
		logs.V(1).Info("EndpointSlice change triggered DynamoModel reconciliation",
			"modelCount", len(requests),
			"baseModelHash", baseModelHash)
	}

	return requests
}
// FinalizeResource implements the Finalizer interface.
// For LoRA models it makes a best-effort attempt to unload the LoRA from every
// discovered endpoint; failures are logged and recorded as events but never
// block deletion. Non-LoRA models need no cleanup. It always returns nil.
func (r *DynamoModelReconciler) FinalizeResource(ctx context.Context, model *v1alpha1.DynamoModel) error {
	logs := log.FromContext(ctx)
	logs.Info("Finalizing DynamoModel", "modelType", model.Spec.ModelType)

	if !model.IsLoRA() {
		logs.Info("Skipping cleanup for non-LoRA model")
	} else {
		// Reuse the discovery logic shared with Reconcile.
		candidates, _, lookupErr := r.getEndpointCandidates(ctx, model)
		switch {
		case lookupErr != nil:
			// Deletion proceeds even when endpoints cannot be listed.
			logs.Info("Failed to get endpoints during deletion, continuing with resource deletion",
				"error", lookupErr.Error())
			r.Recorder.Event(model, corev1.EventTypeWarning, "CleanupFailed", lookupErr.Error())
		case len(candidates) == 0:
			logs.Info("No endpoints found for cleanup")
		default:
			logs.Info("Unloading LoRA from endpoints", "endpointCount", len(candidates))
			// Unload LoRA from all endpoints in parallel
			unloadErr := r.EndpointClient.UnloadLoRA(ctx, candidates, model.Spec.ModelName)
			if unloadErr == nil {
				logs.Info("Successfully unloaded LoRA from all endpoints")
				r.Recorder.Event(model, corev1.EventTypeNormal, "LoRAUnloaded",
					fmt.Sprintf("Unloaded LoRA from %d endpoint(s)", len(candidates)))
			} else {
				// Logged at Info level because deletion continues regardless;
				// detailed failure information is already logged by the prober.
				logs.Info("Some endpoints failed to unload LoRA, continuing with deletion",
					"error", unloadErr.Error())
				r.Recorder.Event(model, corev1.EventTypeWarning, "LoRAUnloadFailed",
					fmt.Sprintf("Failed to unload LoRA from some endpoints: %v", unloadErr))
			}
		}
	}

	logs.Info("Finalization completed successfully")
	return nil
}
// getEndpointCandidates lists the EndpointSlices in the model's namespace that
// carry the base-model-hash label for the model's base model and extracts
// endpoint candidates from them.
//
// It returns the candidates, the set of service names reported by the
// extraction, and an error. When listing fails the error is logged, recorded
// as an event and returned; when no slice matches, all results are nil.
func (r *DynamoModelReconciler) getEndpointCandidates(
	ctx context.Context,
	model *v1alpha1.DynamoModel,
) ([]modelendpoint.Candidate, map[string]bool, error) {
	logs := log.FromContext(ctx)

	// The label propagates from the Service to its EndpointSlices, so slices
	// can be selected directly by the hash of the base model name.
	byBaseModelHash := client.MatchingLabels{
		consts.KubeLabelDynamoBaseModelHash: dynamo.HashModelName(model.Spec.BaseModelName),
	}

	sliceList := &discoveryv1.EndpointSliceList{}
	listErr := r.List(ctx, sliceList, client.InNamespace(model.Namespace), byBaseModelHash)
	if listErr != nil {
		logs.Error(listErr, "Failed to list endpoint slices for model")
		r.Recorder.Event(model, corev1.EventTypeWarning, "EndpointDiscoveryFailed", listErr.Error())
		return nil, nil, listErr
	}

	sliceCount := len(sliceList.Items)
	if sliceCount == 0 {
		return nil, nil, nil
	}
	logs.Info("Found endpoint slices for model", "count", sliceCount)

	// Extract pod-ready endpoint candidates from all EndpointSlices
	candidates, serviceNames := modelendpoint.ExtractCandidates(sliceList, int32(consts.DynamoSystemPort))
	return candidates, serviceNames, nil
}
/*
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package controller
import (
"context"
"time"
"github.com/ai-dynamo/dynamo/deploy/cloud/operator/api/v1alpha1"
"github.com/ai-dynamo/dynamo/deploy/cloud/operator/internal/consts"
"github.com/ai-dynamo/dynamo/deploy/cloud/operator/internal/dynamo"
"github.com/ai-dynamo/dynamo/deploy/cloud/operator/internal/modelendpoint"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
corev1 "k8s.io/api/core/v1"
discoveryv1 "k8s.io/api/discovery/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/types"
"k8s.io/client-go/tools/record"
"sigs.k8s.io/controller-runtime/pkg/reconcile"
)
// Specs for DynamoModelReconciler. Each spec creates a DynamoModel (and, where
// needed, a labelled EndpointSlice) through k8sClient, calls Reconcile
// directly, and then polls the stored object's status.
// NOTE(review): k8sClient and defaultNamespace are not defined in this file;
// presumably they come from the suite setup - confirm.
var _ = Describe("DynamoModel Controller", func() {
const (
// Upper bound and polling interval for the Eventually assertions below.
timeout = time.Second * 10
interval = time.Millisecond * 250
)
var (
reconciler *DynamoModelReconciler
recorder *record.FakeRecorder
)
// Build a fresh reconciler with a buffered fake event recorder for every spec.
BeforeEach(func() {
recorder = record.NewFakeRecorder(100)
reconciler = &DynamoModelReconciler{
Client: k8sClient,
Recorder: recorder,
EndpointClient: modelendpoint.NewClient(),
}
})
Context("When reconciling LoRA model", func() {
It("Should discover endpoints and set conditions", func() {
ctx := context.Background()
namespace := defaultNamespace
modelName := "test-lora-model"
baseModelName := "base-model-lora"
// Create the DynamoModel
model := &v1alpha1.DynamoModel{
ObjectMeta: metav1.ObjectMeta{
Name: modelName,
Namespace: namespace,
},
Spec: v1alpha1.DynamoModelSpec{
ModelName: modelName,
BaseModelName: baseModelName,
ModelType: "lora",
Source: &v1alpha1.ModelSource{
URI: "s3://bucket/model",
},
},
}
Expect(k8sClient.Create(ctx, model)).Should(Succeed())
// Best-effort cleanup; the delete error is intentionally ignored.
defer func() { _ = k8sClient.Delete(ctx, model) }()
// Create EndpointSlice with ready Pod endpoints
trueVal := true
// The slice is matched to the model through this hash label.
modelHash := dynamo.HashModelName(baseModelName)
endpointSlice := &discoveryv1.EndpointSlice{
ObjectMeta: metav1.ObjectMeta{
Name: "test-lora-endpoints",
Namespace: namespace,
Labels: map[string]string{
consts.KubeLabelDynamoBaseModelHash: modelHash,
discoveryv1.LabelServiceName: "test-service",
},
},
AddressType: discoveryv1.AddressTypeIPv4,
Endpoints: []discoveryv1.Endpoint{
{
Addresses: []string{"10.0.1.5"},
Conditions: discoveryv1.EndpointConditions{
Ready: &trueVal,
},
TargetRef: &corev1.ObjectReference{
Kind: "Pod",
Name: "worker-0",
},
},
{
Addresses: []string{"10.0.1.6"},
Conditions: discoveryv1.EndpointConditions{
Ready: &trueVal,
},
TargetRef: &corev1.ObjectReference{
Kind: "Pod",
Name: "worker-1",
},
},
},
Ports: []discoveryv1.EndpointPort{
{
// NOTE(review): presumably 9090 matches consts.DynamoSystemPort, which the
// controller passes to ExtractCandidates - confirm.
Port: func() *int32 { p := int32(9090); return &p }(),
},
},
}
Expect(k8sClient.Create(ctx, endpointSlice)).Should(Succeed())
defer func() { _ = k8sClient.Delete(ctx, endpointSlice) }()
// Reconcile
_, err := reconciler.Reconcile(ctx, reconcile.Request{
NamespacedName: types.NamespacedName{
Name: modelName,
Namespace: namespace,
},
})
Expect(err).NotTo(HaveOccurred())
// Verify endpoints were discovered
Eventually(func() int {
var updated v1alpha1.DynamoModel
// A failed Get leaves updated zero-valued, so the poll simply retries.
_ = k8sClient.Get(ctx, types.NamespacedName{Name: modelName, Namespace: namespace}, &updated)
return updated.Status.TotalEndpoints
}, timeout, interval).Should(Equal(2))
// Verify condition is set (will be False since LoRA load will fail without real service)
// Only the presence of the condition is asserted, not its status.
Eventually(func() bool {
var updated v1alpha1.DynamoModel
_ = k8sClient.Get(ctx, types.NamespacedName{Name: modelName, Namespace: namespace}, &updated)
for _, cond := range updated.Status.Conditions {
if cond.Type == ConditionTypeEndpointsReady {
return true
}
}
return false
}, timeout, interval).Should(BeTrue())
})
})
Context("When reconciling with non-Pod endpoints", func() {
It("Should skip endpoints without Pod TargetRef", func() {
ctx := context.Background()
namespace := defaultNamespace
modelName := "test-non-pod-model"
baseModelName := "base-model-non-pod"
// Create the DynamoModel
model := &v1alpha1.DynamoModel{
ObjectMeta: metav1.ObjectMeta{
Name: modelName,
Namespace: namespace,
},
Spec: v1alpha1.DynamoModelSpec{
ModelName: modelName,
BaseModelName: baseModelName,
ModelType: "base",
},
}
Expect(k8sClient.Create(ctx, model)).Should(Succeed())
defer func() { _ = k8sClient.Delete(ctx, model) }()
// Create EndpointSlice with mixed endpoints (some Pod, some not)
trueVal := true
modelHash := dynamo.HashModelName(baseModelName)
endpointSlice := &discoveryv1.EndpointSlice{
ObjectMeta: metav1.ObjectMeta{
Name: "test-mixed-endpoints",
Namespace: namespace,
Labels: map[string]string{
consts.KubeLabelDynamoBaseModelHash: modelHash,
},
},
AddressType: discoveryv1.AddressTypeIPv4,
Endpoints: []discoveryv1.Endpoint{
{
Addresses: []string{"10.0.1.7"},
Conditions: discoveryv1.EndpointConditions{
Ready: &trueVal,
},
TargetRef: &corev1.ObjectReference{
Kind: "Pod",
Name: "worker-0",
},
},
{
Addresses: []string{"10.0.1.8"},
Conditions: discoveryv1.EndpointConditions{
Ready: &trueVal,
},
TargetRef: &corev1.ObjectReference{
Kind: "Node", // Not a Pod - should be skipped
Name: "node-1",
},
},
{
Addresses: []string{"10.0.1.9"},
Conditions: discoveryv1.EndpointConditions{
Ready: &trueVal,
},
TargetRef: nil, // Nil TargetRef - should be skipped
},
},
Ports: []discoveryv1.EndpointPort{
{
Port: func() *int32 { p := int32(9090); return &p }(),
},
},
}
Expect(k8sClient.Create(ctx, endpointSlice)).Should(Succeed())
defer func() { _ = k8sClient.Delete(ctx, endpointSlice) }()
// Reconcile
_, err := reconciler.Reconcile(ctx, reconcile.Request{
NamespacedName: types.NamespacedName{
Name: modelName,
Namespace: namespace,
},
})
Expect(err).NotTo(HaveOccurred())
// Should only discover 1 endpoint (the Pod), not the Node or nil TargetRef
Eventually(func() int {
var updated v1alpha1.DynamoModel
_ = k8sClient.Get(ctx, types.NamespacedName{Name: modelName, Namespace: namespace}, &updated)
return updated.Status.TotalEndpoints
}, timeout, interval).Should(Equal(1))
// Verify only the Pod endpoint was included
Eventually(func() string {
var updated v1alpha1.DynamoModel
_ = k8sClient.Get(ctx, types.NamespacedName{Name: modelName, Namespace: namespace}, &updated)
if len(updated.Status.Endpoints) > 0 {
return updated.Status.Endpoints[0].PodName
}
return ""
}, timeout, interval).Should(Equal("worker-0"))
})
})
Context("When reconciling base model", func() {
It("Should set EndpointsReady=True when endpoints exist", func() {
ctx := context.Background()
namespace := defaultNamespace
modelName := "test-base-model"
baseModelName := "base-model-base"
// Create the DynamoModel
model := &v1alpha1.DynamoModel{
ObjectMeta: metav1.ObjectMeta{
Name: modelName,
Namespace: namespace,
},
Spec: v1alpha1.DynamoModelSpec{
ModelName: modelName,
BaseModelName: baseModelName,
ModelType: "base",
},
}
Expect(k8sClient.Create(ctx, model)).Should(Succeed())
defer func() { _ = k8sClient.Delete(ctx, model) }()
// Create EndpointSlice
trueVal := true
modelHash := dynamo.HashModelName(baseModelName)
endpointSlice := &discoveryv1.EndpointSlice{
ObjectMeta: metav1.ObjectMeta{
Name: "test-base-endpoints",
Namespace: namespace,
Labels: map[string]string{
consts.KubeLabelDynamoBaseModelHash: modelHash,
},
},
AddressType: discoveryv1.AddressTypeIPv4,
Endpoints: []discoveryv1.Endpoint{
{
Addresses: []string{"10.0.1.10"},
Conditions: discoveryv1.EndpointConditions{
Ready: &trueVal,
},
TargetRef: &corev1.ObjectReference{
Kind: "Pod",
Name: "worker-0",
},
},
},
Ports: []discoveryv1.EndpointPort{
{
Port: func() *int32 { p := int32(9090); return &p }(),
},
},
}
Expect(k8sClient.Create(ctx, endpointSlice)).Should(Succeed())
defer func() { _ = k8sClient.Delete(ctx, endpointSlice) }()
// Reconcile
_, err := reconciler.Reconcile(ctx, reconcile.Request{
NamespacedName: types.NamespacedName{
Name: modelName,
Namespace: namespace,
},
})
Expect(err).NotTo(HaveOccurred())
// For base models, EndpointsReady should be True when endpoints exist
Eventually(func() bool {
var updated v1alpha1.DynamoModel
_ = k8sClient.Get(ctx, types.NamespacedName{Name: modelName, Namespace: namespace}, &updated)
for _, cond := range updated.Status.Conditions {
if cond.Type == ConditionTypeEndpointsReady {
return cond.Status == metav1.ConditionTrue && cond.Reason == ReasonEndpointsDiscovered
}
}
return false
}, timeout, interval).Should(BeTrue())
})
It("Should set EndpointsReady=False when no endpoints exist", func() {
ctx := context.Background()
namespace := defaultNamespace
modelName := "test-base-model-no-endpoints"
// No EndpointSlice is created for this base model name in this spec.
baseModelName := "base-model-none"
// Create the DynamoModel
model := &v1alpha1.DynamoModel{
ObjectMeta: metav1.ObjectMeta{
Name: modelName,
Namespace: namespace,
},
Spec: v1alpha1.DynamoModelSpec{
ModelName: modelName,
BaseModelName: baseModelName,
ModelType: "base",
},
}
Expect(k8sClient.Create(ctx, model)).Should(Succeed())
defer func() { _ = k8sClient.Delete(ctx, model) }()
// Reconcile (no endpoints created)
_, err := reconciler.Reconcile(ctx, reconcile.Request{
NamespacedName: types.NamespacedName{
Name: modelName,
Namespace: namespace,
},
})
Expect(err).NotTo(HaveOccurred())
// Should have condition set to False with NoEndpoints reason
Eventually(func() bool {
var updated v1alpha1.DynamoModel
_ = k8sClient.Get(ctx, types.NamespacedName{Name: modelName, Namespace: namespace}, &updated)
for _, cond := range updated.Status.Conditions {
if cond.Type == ConditionTypeEndpointsReady {
return cond.Status == metav1.ConditionFalse && cond.Reason == ReasonNoEndpoints
}
}
return false
}, timeout, interval).Should(BeTrue())
})
})
})
...@@ -35,9 +35,7 @@ import ( ...@@ -35,9 +35,7 @@ import (
"emperror.dev/errors" "emperror.dev/errors"
"github.com/ai-dynamo/dynamo/deploy/cloud/operator/api/dynamo/schemas" "github.com/ai-dynamo/dynamo/deploy/cloud/operator/api/dynamo/schemas"
"github.com/ai-dynamo/dynamo/deploy/cloud/operator/api/v1alpha1" "github.com/ai-dynamo/dynamo/deploy/cloud/operator/api/v1alpha1"
"github.com/ai-dynamo/dynamo/deploy/cloud/operator/internal/consts"
commonconsts "github.com/ai-dynamo/dynamo/deploy/cloud/operator/internal/consts" commonconsts "github.com/ai-dynamo/dynamo/deploy/cloud/operator/internal/consts"
"github.com/ai-dynamo/dynamo/deploy/cloud/operator/internal/controller_common"
commonController "github.com/ai-dynamo/dynamo/deploy/cloud/operator/internal/controller_common" commonController "github.com/ai-dynamo/dynamo/deploy/cloud/operator/internal/controller_common"
"github.com/ai-dynamo/dynamo/deploy/cloud/operator/internal/dynamo" "github.com/ai-dynamo/dynamo/deploy/cloud/operator/internal/dynamo"
networkingv1beta1 "istio.io/client-go/pkg/apis/networking/v1beta1" networkingv1beta1 "istio.io/client-go/pkg/apis/networking/v1beta1"
...@@ -79,7 +77,7 @@ const ( ...@@ -79,7 +77,7 @@ const (
type DynamoComponentDeploymentReconciler struct { type DynamoComponentDeploymentReconciler struct {
client.Client client.Client
Recorder record.EventRecorder Recorder record.EventRecorder
Config controller_common.Config Config commonController.Config
EtcdStorage etcdStorage EtcdStorage etcdStorage
DockerSecretRetriever dockerSecretRetriever DockerSecretRetriever dockerSecretRetriever
} }
...@@ -327,6 +325,21 @@ func (r *DynamoComponentDeploymentReconciler) Reconcile(ctx context.Context, req ...@@ -327,6 +325,21 @@ func (r *DynamoComponentDeploymentReconciler) Reconcile(ctx context.Context, req
modified = true modified = true
} }
// create or update headless service for model endpoint discovery
componentMap := map[string]*v1alpha1.DynamoComponentDeploymentSharedSpec{
dynamoComponentDeployment.Name: &dynamoComponentDeployment.Spec.DynamoComponentDeploymentSharedSpec,
}
if err := dynamo.ReconcileModelServicesForComponents(
ctx,
r,
dynamoComponentDeployment,
componentMap,
dynamoComponentDeployment.Namespace,
); err != nil {
logs.Error(err, "Failed to reconcile model service")
return ctrl.Result{}, err
}
// create or update api-server ingresses // create or update api-server ingresses
modified_, err = r.createOrUpdateOrDeleteIngress(ctx, generateResourceOption{ modified_, err = r.createOrUpdateOrDeleteIngress(ctx, generateResourceOption{
dynamoComponentDeployment: dynamoComponentDeployment, dynamoComponentDeployment: dynamoComponentDeployment,
...@@ -926,22 +939,29 @@ func (r *DynamoComponentDeploymentReconciler) getGenericServiceName(dynamoCompon ...@@ -926,22 +939,29 @@ func (r *DynamoComponentDeploymentReconciler) getGenericServiceName(dynamoCompon
} }
func (r *DynamoComponentDeploymentReconciler) getKubeLabels(dynamoComponentDeployment *v1alpha1.DynamoComponentDeployment) map[string]string { func (r *DynamoComponentDeploymentReconciler) getKubeLabels(dynamoComponentDeployment *v1alpha1.DynamoComponentDeployment) map[string]string {
if dynamoComponentDeployment != nil && dynamoComponentDeployment.Labels != nil { labels := map[string]string{}
return dynamoComponentDeployment.Labels if dynamoComponentDeployment != nil {
if dynamoComponentDeployment.Spec.Labels != nil {
maps.Copy(labels, dynamoComponentDeployment.Spec.Labels)
} }
return map[string]string{} if dynamoComponentDeployment.Labels != nil {
maps.Copy(labels, dynamoComponentDeployment.Labels)
}
dynamo.AddBaseModelLabel(labels, dynamoComponentDeployment.Spec.ModelRef)
}
return labels
} }
func (r *DynamoComponentDeploymentReconciler) getKubeAnnotations(dynamoComponentDeployment *v1alpha1.DynamoComponentDeployment) map[string]string { func (r *DynamoComponentDeploymentReconciler) getKubeAnnotations(dynamoComponentDeployment *v1alpha1.DynamoComponentDeployment) map[string]string {
annotations := map[string]string{} annotations := map[string]string{}
var extraAnnotations map[string]string if dynamoComponentDeployment != nil {
if dynamoComponentDeployment.Spec.ExtraPodMetadata != nil { if dynamoComponentDeployment.Spec.Annotations != nil {
extraAnnotations = dynamoComponentDeployment.Spec.ExtraPodMetadata.Annotations maps.Copy(annotations, dynamoComponentDeployment.Spec.Annotations)
} else { }
extraAnnotations = map[string]string{} if dynamoComponentDeployment.Spec.ExtraPodMetadata != nil && dynamoComponentDeployment.Spec.ExtraPodMetadata.Annotations != nil {
maps.Copy(annotations, dynamoComponentDeployment.Spec.ExtraPodMetadata.Annotations)
} }
for k, v := range extraAnnotations { dynamo.AddBaseModelAnnotation(annotations, dynamoComponentDeployment.Spec.ModelRef)
annotations[k] = v
} }
return annotations return annotations
} }
...@@ -1154,7 +1174,7 @@ func (r *DynamoComponentDeploymentReconciler) generatePodTemplateSpec(ctx contex ...@@ -1154,7 +1174,7 @@ func (r *DynamoComponentDeploymentReconciler) generatePodTemplateSpec(ctx contex
isDebugModeEnabled := checkIfIsDebugModeEnabled(resourceAnnotations) isDebugModeEnabled := checkIfIsDebugModeEnabled(resourceAnnotations)
podSpec, err := dynamo.GenerateBasePodSpecForController(opt.dynamoComponentDeployment, r.DockerSecretRetriever, r.Config, role, consts.MultinodeDeploymentTypeLWS) podSpec, err := dynamo.GenerateBasePodSpecForController(opt.dynamoComponentDeployment, r.DockerSecretRetriever, r.Config, role, commonconsts.MultinodeDeploymentTypeLWS)
if err != nil { if err != nil {
err = errors.Wrap(err, "failed to generate base pod spec") err = errors.Wrap(err, "failed to generate base pod spec")
return nil, err return nil, err
...@@ -1332,7 +1352,7 @@ func (r *DynamoComponentDeploymentReconciler) SetupWithManager(mgr ctrl.Manager) ...@@ -1332,7 +1352,7 @@ func (r *DynamoComponentDeploymentReconciler) SetupWithManager(mgr ctrl.Manager)
Owns(&corev1.Service{}, builder.WithPredicates(predicate.GenerationChangedPredicate{})). Owns(&corev1.Service{}, builder.WithPredicates(predicate.GenerationChangedPredicate{})).
Owns(&networkingv1.Ingress{}, builder.WithPredicates(predicate.GenerationChangedPredicate{})). Owns(&networkingv1.Ingress{}, builder.WithPredicates(predicate.GenerationChangedPredicate{})).
Owns(&corev1.PersistentVolumeClaim{}, builder.WithPredicates(predicate.GenerationChangedPredicate{})). Owns(&corev1.PersistentVolumeClaim{}, builder.WithPredicates(predicate.GenerationChangedPredicate{})).
WithEventFilter(controller_common.EphemeralDeploymentEventFilter(r.Config)) WithEventFilter(commonController.EphemeralDeploymentEventFilter(r.Config))
if r.Config.LWS.Enabled { if r.Config.LWS.Enabled {
m.Owns(&leaderworkersetv1.LeaderWorkerSet{}, builder.WithPredicates(predicate.Funcs{ m.Owns(&leaderworkersetv1.LeaderWorkerSet{}, builder.WithPredicates(predicate.Funcs{
......
...@@ -341,6 +341,18 @@ func (r *DynamoGraphDeploymentReconciler) reconcileGroveResources(ctx context.Co ...@@ -341,6 +341,18 @@ func (r *DynamoGraphDeploymentReconciler) reconcileGroveResources(ctx context.Co
return "", "", "", fmt.Errorf("failed to reconcile Grove scaling: %w", err) return "", "", "", fmt.Errorf("failed to reconcile Grove scaling: %w", err)
} }
// Reconcile headless services for model endpoint discovery
if err := dynamo.ReconcileModelServicesForComponents(
ctx,
r,
dynamoDeployment,
dynamoDeployment.Spec.Services,
dynamoDeployment.Namespace,
); err != nil {
logger.Error(err, "failed to reconcile model services")
return "", "", "", fmt.Errorf("failed to reconcile model services: %w", err)
}
resources := []Resource{groveGangSetAsResource} resources := []Resource{groveGangSetAsResource}
for componentName, component := range dynamoDeployment.Spec.Services { for componentName, component := range dynamoDeployment.Spec.Services {
if component.ComponentType == consts.ComponentTypeFrontend { if component.ComponentType == consts.ComponentTypeFrontend {
......
...@@ -33,6 +33,7 @@ import ( ...@@ -33,6 +33,7 @@ import (
appsv1 "k8s.io/api/apps/v1" appsv1 "k8s.io/api/apps/v1"
autoscalingv2 "k8s.io/api/autoscaling/v2" autoscalingv2 "k8s.io/api/autoscaling/v2"
corev1 "k8s.io/api/core/v1" corev1 "k8s.io/api/core/v1"
discoveryv1 "k8s.io/api/discovery/v1"
networkingv1 "k8s.io/api/networking/v1" networkingv1 "k8s.io/api/networking/v1"
k8sruntime "k8s.io/apimachinery/pkg/runtime" k8sruntime "k8s.io/apimachinery/pkg/runtime"
...@@ -109,6 +110,8 @@ var _ = BeforeSuite(func() { ...@@ -109,6 +110,8 @@ var _ = BeforeSuite(func() {
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
err = appsv1.AddToScheme(scheme) err = appsv1.AddToScheme(scheme)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
err = discoveryv1.AddToScheme(scheme)
Expect(err).NotTo(HaveOccurred())
err = monitoringv1.AddToScheme(scheme) err = monitoringv1.AddToScheme(scheme)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
err = admissionregistrationv1.AddToScheme(scheme) err = admissionregistrationv1.AddToScheme(scheme)
......
...@@ -1005,6 +1005,8 @@ func generateLabels(component *v1alpha1.DynamoComponentDeploymentSharedSpec, dyn ...@@ -1005,6 +1005,8 @@ func generateLabels(component *v1alpha1.DynamoComponentDeploymentSharedSpec, dyn
if component.SubComponentType != "" { if component.SubComponentType != "" {
labels[commonconsts.KubeLabelDynamoSubComponentType] = component.SubComponentType labels[commonconsts.KubeLabelDynamoSubComponentType] = component.SubComponentType
} }
// Add base model label if modelRef is specified
AddBaseModelLabel(labels, component.ModelRef)
setMetricsLabels(labels, dynamoDeployment) setMetricsLabels(labels, dynamoDeployment)
if component.Labels != nil { if component.Labels != nil {
err := mergo.Merge(&labels, component.Labels, mergo.WithOverride) err := mergo.Merge(&labels, component.Labels, mergo.WithOverride)
......
/*
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package dynamo
import (
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"github.com/ai-dynamo/dynamo/deploy/cloud/operator/api/v1alpha1"
commonconsts "github.com/ai-dynamo/dynamo/deploy/cloud/operator/internal/consts"
commonController "github.com/ai-dynamo/dynamo/deploy/cloud/operator/internal/controller_common"
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/util/intstr"
"sigs.k8s.io/controller-runtime/pkg/client"
"sigs.k8s.io/controller-runtime/pkg/log"
)
// ReconcileModelServicesForComponents syncs one headless Service per distinct
// base model referenced by the given components.
//
// Components that are nil, have no modelRef, or have an empty modelRef name are
// ignored. Components sharing the same base model name yield a single Service.
// Each Service is handed to commonController.SyncResource together with
// reconciler and owner. The function stops at the first sync failure and
// returns that error wrapped with the base model name; it returns nil when
// every Service was synced.
//
// This is common logic used by both the DynamoGraphDeployment and
// DynamoComponentDeployment controllers; reconciler must implement the
// commonController.Reconciler interface.
func ReconcileModelServicesForComponents(
	ctx context.Context,
	reconciler commonController.Reconciler,
	owner client.Object,
	components map[string]*v1alpha1.DynamoComponentDeploymentSharedSpec,
	namespace string,
) error {
	logger := log.FromContext(ctx)

	// Track unique base models to avoid creating duplicate services.
	seenBaseModels := make(map[string]bool)

	for componentName, component := range components {
		// The map stores pointers, so a nil entry must be skipped before it is
		// dereferenced; a component without a modelRef needs no service either.
		if component == nil || component.ModelRef == nil || component.ModelRef.Name == "" {
			continue
		}

		baseModelName := component.ModelRef.Name

		// Skip if we've already created a service for this base model.
		if seenBaseModels[baseModelName] {
			logger.V(1).Info("Skipping duplicate headless service for base model",
				"componentName", componentName,
				"baseModelName", baseModelName)
			continue
		}
		seenBaseModels[baseModelName] = true

		// Generate the headless service with a deterministic name derived from
		// the model name.
		headlessService := generateHeadlessServiceForModel(
			namespace,
			baseModelName,
		)

		// Sync the service (create or update).
		_, syncedService, err := commonController.SyncResource(
			ctx,
			reconciler,
			owner,
			func(ctx context.Context) (*corev1.Service, bool, error) {
				return headlessService, false, nil
			},
		)
		if err != nil {
			logger.Error(err, "Failed to sync headless service for model",
				"baseModelName", baseModelName,
				"componentName", componentName)
			return fmt.Errorf("failed to sync headless service for model %s: %w", baseModelName, err)
		}

		logger.Info("Synced headless service for model",
			"serviceName", syncedService.GetName(),
			"baseModelName", baseModelName,
			"namespace", namespace)
	}

	return nil
}
// generateHeadlessServiceForModel builds the headless Service used for model
// endpoint discovery in the given namespace.
//
// The Service name comes from GenerateServiceName and is therefore derived
// deterministically from the base model name. The hash from HashModelName is
// set both as a label on the Service and as its pod selector, while the
// original base model name is kept in an annotation for human readability.
func generateHeadlessServiceForModel(
	namespace string,
	baseModelName string,
) *corev1.Service {
	// One hash value serves as the Service label and as the pod selector.
	hashValue := HashModelName(baseModelName)

	// System port for model HTTP APIs.
	systemPort := corev1.ServicePort{
		Name:       commonconsts.DynamoSystemPortName,
		Port:       commonconsts.DynamoSystemPort,
		TargetPort: intstr.FromInt32(commonconsts.DynamoSystemPort),
		Protocol:   corev1.ProtocolTCP,
	}

	meta := metav1.ObjectMeta{
		Name:      GenerateServiceName(baseModelName),
		Namespace: namespace,
		Labels: map[string]string{
			commonconsts.KubeLabelDynamoBaseModelHash: hashValue,
			"nvidia.com/managed-by":                   "dynamo-operator",
		},
		Annotations: map[string]string{
			// Original, unhashed name for humans.
			commonconsts.KubeAnnotationDynamoBaseModel: baseModelName,
		},
	}

	spec := corev1.ServiceSpec{
		// Headless service: no ClusterIP is allocated.
		ClusterIP: corev1.ClusterIPNone,
		// Select pods carrying the base model hash label.
		Selector: map[string]string{
			commonconsts.KubeLabelDynamoBaseModelHash: hashValue,
		},
		// Not-ready addresses are explicitly not published.
		PublishNotReadyAddresses: false,
		Ports:                    []corev1.ServicePort{systemPort},
	}

	return &corev1.Service{ObjectMeta: meta, Spec: spec}
}
// HashModelName derives a short, deterministic identifier from a base model
// name for use in labels.
//
// It returns the first 8 lowercase hexadecimal characters of the SHA-256
// digest of baseModelName, so the result is always 8 characters drawn from
// [0-9a-f] regardless of the input's length or characters.
func HashModelName(baseModelName string) string {
	digest := sha256.Sum256([]byte(baseModelName))
	// Four digest bytes encode to exactly eight hex characters.
	return hex.EncodeToString(digest[:4])
}
// GenerateServiceName creates a deterministic, DNS-safe service name from a base model name
// Format: dynamo-model-{8-char-hash}
func GenerateServiceName(baseModelName string) string {
return fmt.Sprintf("dynamo-model-%s", HashModelName(baseModelName))
}
// AddBaseModelLabel sets the base model hash label on labels when modelRef
// names a model.
//
// The label value is HashModelName(modelRef.Name) rather than the raw name, so
// it is not subject to the raw name's length or character set. The call is a
// no-op when labels is nil, modelRef is nil, or modelRef.Name is empty.
func AddBaseModelLabel(labels map[string]string, modelRef *v1alpha1.ModelReference) {
	if labels == nil {
		return
	}
	if modelRef == nil {
		return
	}
	if name := modelRef.Name; name != "" {
		labels[commonconsts.KubeLabelDynamoBaseModelHash] = HashModelName(name)
	}
}
// AddBaseModelAnnotation records the original (unhashed) base model name in
// annotations under the base model annotation key.
//
// The call is a no-op when annotations is nil, modelRef is nil, or
// modelRef.Name is empty.
func AddBaseModelAnnotation(annotations map[string]string, modelRef *v1alpha1.ModelReference) {
	namesModel := modelRef != nil && modelRef.Name != ""
	if annotations != nil && namesModel {
		annotations[commonconsts.KubeAnnotationDynamoBaseModel] = modelRef.Name
	}
}
/*
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package modelendpoint
import (
"context"
"fmt"
"net/http"
"time"
"sigs.k8s.io/controller-runtime/pkg/log"
"github.com/ai-dynamo/dynamo/deploy/cloud/operator/api/v1alpha1"
"github.com/ai-dynamo/dynamo/deploy/cloud/operator/internal/workerpool"
)
const (
// MaxConcurrentOperations is the maximum number of concurrent endpoint operations.
// LoadLoRA and UnloadLoRA pass it to workerpool.Execute as the concurrency bound.
MaxConcurrentOperations = 10
// RequestTimeout is the timeout for individual HTTP requests.
// NewClient installs it as the Timeout of the underlying http.Client.
RequestTimeout = 15 * time.Second
// TotalTimeout is the timeout for all operations of one LoadLoRA/UnloadLoRA
// call to complete; it is passed to workerpool.Execute.
// NOTE(review): how the pool reports tasks that exceed this budget is defined
// in the workerpool package, which is not visible here.
TotalTimeout = 30 * time.Second
)
// Client handles HTTP communication with model endpoint control APIs.
// Construct it with NewClient, which sets the request timeout on the wrapped
// http.Client.
type Client struct {
// httpClient is the HTTP client created by NewClient with RequestTimeout.
// NOTE(review): presumably used by the loadLoRA/unloadLoRA helpers, which are
// defined outside this view.
httpClient *http.Client
}
// NewClient returns a model endpoint Client whose underlying http.Client
// applies RequestTimeout to every request.
func NewClient() *Client {
	httpClient := new(http.Client)
	httpClient.Timeout = RequestTimeout
	return &Client{httpClient: httpClient}
}
// LoadLoRA loads a LoRA model on all candidate endpoints in parallel with
// bounded concurrency and reports per-endpoint readiness.
//
// For non-LoRA models nothing is loaded: one EndpointInfo per candidate is
// returned with Ready set to false and a nil error.
//
// For LoRA models a source URI is required; when model.Spec.Source is nil or
// its URI is empty, LoadLoRA returns a nil slice and an error. Otherwise it
// runs one load per candidate through workerpool.Execute (bounded by
// MaxConcurrentOperations and TotalTimeout) and returns one EndpointInfo per
// candidate, in candidate order. Each entry carries the candidate's Address
// and PodName; Ready is true only for candidates whose load result reports
// success. Candidates for which the pool produced no result stay not ready.
// The error returned by workerpool.Execute is passed through alongside the
// (possibly partial) endpoint list.
func (c *Client) LoadLoRA(
	ctx context.Context,
	candidates []Candidate,
	model *v1alpha1.DynamoModel,
) ([]v1alpha1.EndpointInfo, error) {
	logs := log.FromContext(ctx)

	// Seed one not-ready entry per candidate. Sizing by candidates (not by the
	// number of results) keeps result.Index in range and guarantees every
	// entry has an address and pod name even if a result is missing.
	endpoints := make([]v1alpha1.EndpointInfo, len(candidates))
	for i := range candidates {
		endpoints[i] = v1alpha1.EndpointInfo{
			Address: candidates[i].Address,
			PodName: candidates[i].PodName,
			Ready:   false,
		}
	}

	// Skip loading for non-LoRA models.
	if !model.IsLoRA() {
		logs.V(1).Info("Skipping LoRA load for non-LoRA model", "modelType", model.Spec.ModelType)
		return endpoints, nil
	}

	// Get source URI for LoRA loading.
	sourceURI := ""
	if model.Spec.Source != nil {
		sourceURI = model.Spec.Source.URI
	}
	if sourceURI == "" {
		logs.Error(nil, "Source URI is required for LoRA models")
		return nil, fmt.Errorf("source URI is required for LoRA models")
	}

	// Build tasks for the worker pool.
	tasks := make([]workerpool.Task[v1alpha1.EndpointInfo], len(candidates))
	for i, candidate := range candidates {
		tasks[i] = workerpool.Task[v1alpha1.EndpointInfo]{
			Index: i,
			Work: func(ctx context.Context) (v1alpha1.EndpointInfo, error) {
				// Load the LoRA on this endpoint (idempotent operation).
				err := c.loadLoRA(ctx, candidate.Address, model.Spec.ModelName, sourceURI)
				ready := err == nil
				return v1alpha1.EndpointInfo{
					Address: candidate.Address,
					PodName: candidate.PodName,
					Ready:   ready,
				}, err
			},
		}
	}

	// Execute all load operations in parallel with bounded concurrency.
	results, err := workerpool.Execute(ctx, MaxConcurrentOperations, TotalTimeout, tasks)

	// Merge readiness from the results into the pre-seeded endpoint list.
	for _, result := range results {
		// Ignore a result that does not map back to a candidate instead of
		// panicking on an out-of-range index.
		if result.Index < 0 || result.Index >= len(endpoints) {
			continue
		}
		endpoints[result.Index].Ready = result.Value.Ready
		if !result.Value.Ready && result.Err != nil {
			logs.Info("Endpoint load operation failed",
				"address", endpoints[result.Index].Address,
				"podName", endpoints[result.Index].PodName,
				"error", result.Err)
		}
	}

	// Tally over the full endpoint list so candidates without a result are
	// counted as not ready.
	readyCount := 0
	var notReadyEndpoints []string
	for i := range endpoints {
		if endpoints[i].Ready {
			readyCount++
			continue
		}
		notReadyEndpoints = append(notReadyEndpoints, endpoints[i].Address)
	}

	logs.Info("Completed parallel LoRA load operations",
		"total", len(endpoints),
		"ready", readyCount,
		"notReady", len(notReadyEndpoints),
		"notReadyEndpoints", notReadyEndpoints)

	return endpoints, err
}
// UnloadLoRA unloads the LoRA model named modelName from every candidate
// endpoint in parallel.
//
// With no candidates it logs and returns nil. Otherwise it submits one unload
// task per candidate to workerpool.Execute (bounded by MaxConcurrentOperations
// and TotalTimeout), logs each endpoint whose result is not successful, logs a
// summary, and returns the error reported by workerpool.Execute.
func (c *Client) UnloadLoRA(ctx context.Context, candidates []Candidate, modelName string) error {
	logger := log.FromContext(ctx)

	total := len(candidates)
	if total == 0 {
		logger.Info("No candidates to unload LoRA from")
		return nil
	}

	logger.Info("Starting parallel LoRA unload", "endpointCount", total, "modelName", modelName)

	// One task per candidate; the task index ties each result back to its
	// candidate.
	tasks := make([]workerpool.Task[bool], 0, total)
	for idx, target := range candidates {
		tasks = append(tasks, workerpool.Task[bool]{
			Index: idx,
			Work: func(ctx context.Context) (bool, error) {
				// Unload the LoRA from this endpoint.
				if unloadErr := c.unloadLoRA(ctx, target.Address, modelName); unloadErr != nil {
					return false, unloadErr
				}
				return true, nil
			},
		})
	}

	// Run all unload operations in parallel with bounded concurrency.
	results, execErr := workerpool.Execute(ctx, MaxConcurrentOperations, TotalTimeout, tasks)

	// Separate successes from failures and log the details of each failure.
	succeeded := 0
	var failedAddrs []string
	for _, res := range results {
		if res.Value {
			succeeded++
			continue
		}
		failed := candidates[res.Index]
		failedAddrs = append(failedAddrs, failed.Address)
		logger.Info("Failed to unload LoRA from endpoint",
			"address", failed.Address,
			"podName", failed.PodName,
			"error", res.Err)
	}

	logger.Info("Completed parallel LoRA unload",
		"total", total,
		"successful", succeeded,
		"failed", len(failedAddrs),
		"failedEndpoints", failedAddrs)

	return execErr
}
/*
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package modelendpoint
import (
"context"
"net/http"
"net/http/httptest"
"strings"
"testing"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"github.com/ai-dynamo/dynamo/deploy/cloud/operator/api/v1alpha1"
)
// TestLoadLoRA exercises Client.LoadLoRA against local httptest servers: one
// that accepts load requests and one that always answers 500. It covers the
// non-LoRA short-circuit, the missing-source validation error, full success
// and partial failure.
func TestLoadLoRA(t *testing.T) {
	// requirePost reports whether the request used POST; otherwise it records a
	// test failure and answers 405.
	requirePost := func(w http.ResponseWriter, r *http.Request) bool {
		if r.Method == http.MethodPost {
			return true
		}
		t.Errorf("expected POST method, got %s", r.Method)
		w.WriteHeader(http.StatusMethodNotAllowed)
		return false
	}

	// okServer accepts POST requests and additionally checks the Content-Type.
	okServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if !requirePost(w, r) {
			return
		}
		if contentType := r.Header.Get("Content-Type"); contentType != "application/json" {
			t.Errorf("expected Content-Type application/json, got %s", contentType)
		}
		w.WriteHeader(http.StatusOK)
	}))
	defer okServer.Close()

	// brokenServer still verifies the method but always fails the request.
	brokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if !requirePost(w, r) {
			return
		}
		w.WriteHeader(http.StatusInternalServerError)
	}))
	defer brokenServer.Close()

	type testCase struct {
		name               string
		modelType          string
		sourceURI          string
		candidates         []Candidate
		expectError        bool
		errorContains      string
		expectedCount      int
		expectedReadyCount int
	}

	cases := []testCase{
		{
			name:               "non-lora model - skips loading",
			modelType:          "base",
			candidates:         []Candidate{{Address: "http://10.0.1.5:9090", PodName: "pod-1"}},
			expectError:        false,
			expectedCount:      1,
			expectedReadyCount: 0,
		},
		{
			name:               "empty candidates",
			modelType:          "base",
			candidates:         []Candidate{},
			expectError:        false,
			expectedCount:      0,
			expectedReadyCount: 0,
		},
		{
			name:          "lora with nil source",
			modelType:     "lora",
			sourceURI:     "",
			candidates:    []Candidate{{Address: "http://10.0.1.5:9090", PodName: "pod-1"}},
			expectError:   true,
			errorContains: "source URI is required",
		},
		{
			name:      "lora with valid source - all success",
			modelType: "lora",
			sourceURI: "s3://bucket/model",
			candidates: []Candidate{
				{Address: okServer.URL, PodName: "pod-1"},
				{Address: okServer.URL, PodName: "pod-2"},
			},
			expectError:        false,
			expectedCount:      2,
			expectedReadyCount: 2,
		},
		{
			name:      "lora with valid source - partial failure",
			modelType: "lora",
			sourceURI: "s3://bucket/model",
			candidates: []Candidate{
				{Address: okServer.URL, PodName: "pod-1"},
				{Address: brokenServer.URL, PodName: "pod-2"},
			},
			expectError:        true, // workerpool returns error on any failure
			expectedCount:      2,
			expectedReadyCount: 1,
		},
		{
			name:      "lora with huggingface source",
			modelType: "lora",
			sourceURI: "hf://org/model@v1.0",
			candidates: []Candidate{
				{Address: okServer.URL, PodName: "pod-1"},
			},
			expectError:        false,
			expectedCount:      1,
			expectedReadyCount: 1,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			// An empty sourceURI means the model has no Source at all.
			var source *v1alpha1.ModelSource
			if tc.sourceURI != "" {
				source = &v1alpha1.ModelSource{URI: tc.sourceURI}
			}

			model := &v1alpha1.DynamoModel{
				ObjectMeta: metav1.ObjectMeta{
					Name:      "test-model",
					Namespace: "default",
				},
				Spec: v1alpha1.DynamoModelSpec{
					ModelName: "test-model",
					ModelType: tc.modelType,
					Source:    source,
				},
			}

			endpoints, err := NewClient().LoadLoRA(context.Background(), tc.candidates, model)

			// Validation errors (like a missing source URI) return early, so
			// only the error itself is checked for those cases.
			if tc.expectError && tc.errorContains != "" {
				switch {
				case err == nil:
					t.Error("expected error but got none")
				case !strings.Contains(err.Error(), tc.errorContains):
					t.Errorf("expected error to contain %q, got %v", tc.errorContains, err)
				}
				return
			}

			// For partial failures an error is expected, but endpoints are
			// still returned and checked below.
			switch {
			case tc.expectError && err == nil:
				t.Error("expected error for partial failure but got none")
			case !tc.expectError && err != nil:
				t.Fatalf("unexpected error: %v", err)
			}

			if got := len(endpoints); got != tc.expectedCount {
				t.Errorf("expected %d endpoints, got %d", tc.expectedCount, got)
			}

			ready := 0
			for i := range endpoints {
				if endpoints[i].Ready {
					ready++
				}
			}
			if ready != tc.expectedReadyCount {
				t.Errorf("expected %d ready endpoints, got %d", tc.expectedReadyCount, ready)
			}
		})
	}
}
// TestUnloadLoRA exercises Client.UnloadLoRA against local httptest servers:
// one that accepts DELETE requests under a /loras/ path and one that always
// answers 500. It covers empty input, full success and partial failure.
func TestUnloadLoRA(t *testing.T) {
	// requireDelete reports whether the request used DELETE; otherwise it
	// records a test failure and answers 405.
	requireDelete := func(w http.ResponseWriter, r *http.Request) bool {
		if r.Method == http.MethodDelete {
			return true
		}
		t.Errorf("expected DELETE method, got %s", r.Method)
		w.WriteHeader(http.StatusMethodNotAllowed)
		return false
	}

	// okServer accepts DELETE requests and checks the URL path shape.
	okServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if !requireDelete(w, r) {
			return
		}
		if path := r.URL.Path; !strings.Contains(path, "/loras/") {
			t.Errorf("expected URL path to contain /loras/, got %s", path)
		}
		w.WriteHeader(http.StatusOK)
	}))
	defer okServer.Close()

	// brokenServer still verifies the method but always fails the request.
	brokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if !requireDelete(w, r) {
			return
		}
		w.WriteHeader(http.StatusInternalServerError)
	}))
	defer brokenServer.Close()

	type testCase struct {
		name        string
		candidates  []Candidate
		modelName   string
		expectError bool
	}

	cases := []testCase{
		{
			name:        "empty candidates",
			candidates:  []Candidate{},
			modelName:   "test-model",
			expectError: false,
		},
		{
			name: "single endpoint success",
			candidates: []Candidate{
				{Address: okServer.URL, PodName: "pod-1"},
			},
			modelName:   "test-model",
			expectError: false,
		},
		{
			name: "multiple endpoints success",
			candidates: []Candidate{
				{Address: okServer.URL, PodName: "pod-1"},
				{Address: okServer.URL, PodName: "pod-2"},
			},
			modelName:   "test-model",
			expectError: false,
		},
		{
			name: "partial failure",
			candidates: []Candidate{
				{Address: okServer.URL, PodName: "pod-1"},
				{Address: brokenServer.URL, PodName: "pod-2"},
			},
			modelName:   "test-model",
			expectError: true, // workerpool returns error on any failure
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			err := NewClient().UnloadLoRA(context.Background(), tc.candidates, tc.modelName)

			switch {
			case tc.expectError && err == nil:
				t.Error("expected error but got none")
			case !tc.expectError && err != nil:
				t.Errorf("unexpected error: %v", err)
			}
		})
	}
}
/*
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package modelendpoint
import (
"context"
"net"
"strconv"
discoveryv1 "k8s.io/api/discovery/v1"
"k8s.io/apimachinery/pkg/types"
"sigs.k8s.io/controller-runtime/pkg/client"
"sigs.k8s.io/controller-runtime/pkg/log"
"sigs.k8s.io/controller-runtime/pkg/reconcile"
"github.com/ai-dynamo/dynamo/deploy/cloud/operator/api/v1alpha1"
)
// ExtractCandidates collects endpoint candidates from the given EndpointSlices.
//
// It returns one Candidate per address of every endpoint whose TargetRef is a
// Pod, with Address formatted as "http://<host>:<port>" and PodName taken from
// the TargetRef. Endpoint readiness conditions are not consulted, so ready and
// not-ready pod endpoints are both returned. Endpoints without addresses, with
// a nil TargetRef, or targeting a non-Pod kind contribute nothing.
//
// The second return value is the set of service names found in the
// kubernetes.io/service-name label of the slices (empty label values are
// ignored), recorded even for slices that yield no candidates.
func ExtractCandidates(endpointSlices *discoveryv1.EndpointSliceList, port int32) ([]Candidate, map[string]bool) {
	portText := strconv.Itoa(int(port))

	var found []Candidate
	services := map[string]bool{}

	for i := range endpointSlices.Items {
		slice := &endpointSlices.Items[i]

		if owner := slice.Labels[discoveryv1.LabelServiceName]; owner != "" {
			services[owner] = true
		}

		for j := range slice.Endpoints {
			ep := &slice.Endpoints[j]

			// Only pod-backed endpoints are candidates.
			podBacked := ep.TargetRef != nil && ep.TargetRef.Kind == "Pod"
			if !podBacked {
				continue
			}

			// An endpoint without addresses simply yields no iterations here.
			for _, host := range ep.Addresses {
				found = append(found, Candidate{
					Address: "http://" + net.JoinHostPort(host, portText),
					PodName: ep.TargetRef.Name,
				})
			}
		}
	}

	return found, services
}
// FindModelsForBaseModel lists the DynamoModels in namespace whose indexField
// equals indexValue and returns one reconcile.Request per match.
//
// The lookup is expressed with client.MatchingFields, so indexField is expected
// to be a field registered with the client's field indexer; indexValue may be
// a base model name or a hash, depending on which index is queried. A list
// failure is logged and returned as-is. When nothing matches, an empty
// (non-nil) slice and a nil error are returned.
func FindModelsForBaseModel(
	ctx context.Context,
	c client.Client,
	namespace string,
	indexValue string,
	indexField string,
) ([]reconcile.Request, error) {
	logger := log.FromContext(ctx)

	var matched v1alpha1.DynamoModelList
	listOpts := []client.ListOption{
		client.InNamespace(namespace),
		client.MatchingFields{indexField: indexValue},
	}
	if err := c.List(ctx, &matched, listOpts...); err != nil {
		logger.Error(err, "Failed to list DynamoModels", "indexField", indexField, "indexValue", indexValue)
		return nil, err
	}

	// Translate every matching model into a reconcile request keyed by its
	// namespace and name.
	requests := make([]reconcile.Request, len(matched.Items))
	for i := range matched.Items {
		requests[i] = reconcile.Request{
			NamespacedName: types.NamespacedName{
				Name:      matched.Items[i].Name,
				Namespace: matched.Items[i].Namespace,
			},
		}
	}

	if len(requests) == 0 {
		return requests, nil
	}

	logger.V(1).Info("Found DynamoModels for index value",
		"indexField", indexField,
		"indexValue", indexValue,
		"count", len(requests))
	return requests, nil
}
/*
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package modelendpoint
import (
"context"
"testing"
corev1 "k8s.io/api/core/v1"
discoveryv1 "k8s.io/api/discovery/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime"
"sigs.k8s.io/controller-runtime/pkg/client"
"sigs.k8s.io/controller-runtime/pkg/client/fake"
"github.com/ai-dynamo/dynamo/deploy/cloud/operator/api/v1alpha1"
)
// Pod names shared by the endpoint fixtures in this test file.
const (
testPodWorker0 = "worker-0"
testPodWorker1 = "worker-1"
testPodWorker2 = "worker-2"
)
func TestExtractCandidates(t *testing.T) {
trueVal := true
falseVal := false
tests := []struct {
name string
endpointSlices *discoveryv1.EndpointSliceList
port int32
expectedCandidates int
expectedServiceNames map[string]bool
validateCandidates func(t *testing.T, candidates []Candidate)
}{
{
name: "empty endpoint slice list",
endpointSlices: &discoveryv1.EndpointSliceList{
Items: []discoveryv1.EndpointSlice{},
},
port: 9090,
expectedCandidates: 0,
expectedServiceNames: map[string]bool{},
},
{
name: "endpoint with pod target ref - included",
endpointSlices: &discoveryv1.EndpointSliceList{
Items: []discoveryv1.EndpointSlice{
{
ObjectMeta: metav1.ObjectMeta{
Labels: map[string]string{
discoveryv1.LabelServiceName: "my-service",
},
},
Endpoints: []discoveryv1.Endpoint{
{
Addresses: []string{"10.0.1.5"},
Conditions: discoveryv1.EndpointConditions{
Ready: &trueVal,
},
TargetRef: &corev1.ObjectReference{
Kind: "Pod",
Name: testPodWorker0,
},
},
},
},
},
},
port: 9090,
expectedCandidates: 1,
expectedServiceNames: map[string]bool{
"my-service": true,
},
validateCandidates: func(t *testing.T, candidates []Candidate) {
if len(candidates) != 1 {
t.Fatalf("expected 1 candidate, got %d", len(candidates))
}
if candidates[0].Address != "http://10.0.1.5:9090" {
t.Errorf("expected address http://10.0.1.5:9090, got %s", candidates[0].Address)
}
if candidates[0].PodName != testPodWorker0 {
t.Errorf("expected podName %s, got %s", testPodWorker0, candidates[0].PodName)
}
},
},
{
name: "endpoint with nil target ref - skipped",
endpointSlices: &discoveryv1.EndpointSliceList{
Items: []discoveryv1.EndpointSlice{
{
Endpoints: []discoveryv1.Endpoint{
{
Addresses: []string{"10.0.1.5"},
Conditions: discoveryv1.EndpointConditions{
Ready: &trueVal,
},
TargetRef: nil,
},
},
},
},
},
port: 9090,
expectedCandidates: 0,
expectedServiceNames: map[string]bool{},
},
{
name: "endpoint with non-pod target ref - skipped",
endpointSlices: &discoveryv1.EndpointSliceList{
Items: []discoveryv1.EndpointSlice{
{
Endpoints: []discoveryv1.Endpoint{
{
Addresses: []string{"10.0.1.5"},
Conditions: discoveryv1.EndpointConditions{
Ready: &trueVal,
},
TargetRef: &corev1.ObjectReference{
Kind: "Node",
Name: "node-1",
},
},
},
},
},
},
port: 9090,
expectedCandidates: 0,
expectedServiceNames: map[string]bool{},
},
{
name: "endpoint not ready - included",
endpointSlices: &discoveryv1.EndpointSliceList{
Items: []discoveryv1.EndpointSlice{
{
Endpoints: []discoveryv1.Endpoint{
{
Addresses: []string{"10.0.1.5"},
Conditions: discoveryv1.EndpointConditions{
Ready: &falseVal,
},
TargetRef: &corev1.ObjectReference{
Kind: "Pod",
Name: testPodWorker0,
},
},
},
},
},
},
port: 9090,
expectedCandidates: 1,
expectedServiceNames: map[string]bool{},
},
{
name: "endpoint with nil ready condition - included",
endpointSlices: &discoveryv1.EndpointSliceList{
Items: []discoveryv1.EndpointSlice{
{
Endpoints: []discoveryv1.Endpoint{
{
Addresses: []string{"10.0.1.5"},
Conditions: discoveryv1.EndpointConditions{
Ready: nil,
},
TargetRef: &corev1.ObjectReference{
Kind: "Pod",
Name: testPodWorker0,
},
},
},
},
},
},
port: 9090,
expectedCandidates: 1,
expectedServiceNames: map[string]bool{},
},
{
name: "endpoint with no addresses - skipped",
endpointSlices: &discoveryv1.EndpointSliceList{
Items: []discoveryv1.EndpointSlice{
{
Endpoints: []discoveryv1.Endpoint{
{
Addresses: []string{},
Conditions: discoveryv1.EndpointConditions{
Ready: &trueVal,
},
TargetRef: &corev1.ObjectReference{
Kind: "Pod",
Name: testPodWorker0,
},
},
},
},
},
},
port: 9090,
expectedCandidates: 0,
expectedServiceNames: map[string]bool{},
},
{
name: "multiple valid endpoints",
endpointSlices: &discoveryv1.EndpointSliceList{
Items: []discoveryv1.EndpointSlice{
{
ObjectMeta: metav1.ObjectMeta{
Labels: map[string]string{
discoveryv1.LabelServiceName: "service-1",
},
},
Endpoints: []discoveryv1.Endpoint{
{
Addresses: []string{"10.0.1.5"},
Conditions: discoveryv1.EndpointConditions{
Ready: &trueVal,
},
TargetRef: &corev1.ObjectReference{
Kind: "Pod",
Name: testPodWorker0,
},
},
{
Addresses: []string{"10.0.1.6"},
Conditions: discoveryv1.EndpointConditions{
Ready: &trueVal,
},
TargetRef: &corev1.ObjectReference{
Kind: "Pod",
Name: testPodWorker1,
},
},
},
},
},
},
port: 9090,
expectedCandidates: 2,
expectedServiceNames: map[string]bool{
"service-1": true,
},
validateCandidates: func(t *testing.T, candidates []Candidate) {
if len(candidates) != 2 {
t.Fatalf("expected 2 candidates, got %d", len(candidates))
}
},
},
{
name: "mixed valid and invalid endpoints",
endpointSlices: &discoveryv1.EndpointSliceList{
Items: []discoveryv1.EndpointSlice{
{
ObjectMeta: metav1.ObjectMeta{
Labels: map[string]string{
discoveryv1.LabelServiceName: "my-service",
},
},
Endpoints: []discoveryv1.Endpoint{
{
Addresses: []string{"10.0.1.5"},
Conditions: discoveryv1.EndpointConditions{
Ready: &trueVal,
},
TargetRef: &corev1.ObjectReference{
Kind: "Pod",
Name: testPodWorker0,
},
},
{
Addresses: []string{"10.0.1.6"},
Conditions: discoveryv1.EndpointConditions{
Ready: &falseVal, // Not ready - now included (readiness determined by probing)
},
TargetRef: &corev1.ObjectReference{
Kind: "Pod",
Name: testPodWorker1,
},
},
{
Addresses: []string{"10.0.1.7"},
Conditions: discoveryv1.EndpointConditions{
Ready: &trueVal,
},
TargetRef: &corev1.ObjectReference{
Kind: "Node", // Not a Pod - should be skipped
Name: "node-1",
},
},
{
Addresses: []string{"10.0.1.8"},
Conditions: discoveryv1.EndpointConditions{
Ready: &trueVal,
},
TargetRef: nil, // Nil TargetRef - should be skipped
},
{
Addresses: []string{"10.0.1.9"},
Conditions: discoveryv1.EndpointConditions{
Ready: &trueVal,
},
TargetRef: &corev1.ObjectReference{
Kind: "Pod",
Name: testPodWorker2,
},
},
},
},
},
},
port: 9090,
expectedCandidates: 3, // testPodWorker0, testPodWorker1 (unready), and testPodWorker2
expectedServiceNames: map[string]bool{
"my-service": true,
},
validateCandidates: func(t *testing.T, candidates []Candidate) {
if len(candidates) != 3 {
t.Fatalf("expected 3 candidates, got %d", len(candidates))
}
// Verify only valid pods are included (all 3 pod-backed endpoints)
validPods := map[string]bool{testPodWorker0: false, testPodWorker1: false, testPodWorker2: false}
for _, c := range candidates {
if _, exists := validPods[c.PodName]; exists {
validPods[c.PodName] = true
} else {
t.Errorf("unexpected pod in candidates: %s", c.PodName)
}
}
for pod, found := range validPods {
if !found {
t.Errorf("expected pod %s not found in candidates", pod)
}
}
},
},
{
name: "endpoint with multiple addresses",
endpointSlices: &discoveryv1.EndpointSliceList{
Items: []discoveryv1.EndpointSlice{
{
Endpoints: []discoveryv1.Endpoint{
{
Addresses: []string{"10.0.1.5", "10.0.2.5"},
Conditions: discoveryv1.EndpointConditions{
Ready: &trueVal,
},
TargetRef: &corev1.ObjectReference{
Kind: "Pod",
Name: testPodWorker0,
},
},
},
},
},
},
port: 9090,
expectedCandidates: 2, // One candidate per address
expectedServiceNames: map[string]bool{},
validateCandidates: func(t *testing.T, candidates []Candidate) {
if len(candidates) != 2 {
t.Fatalf("expected 2 candidates, got %d", len(candidates))
}
// Both should have the same pod name
if candidates[0].PodName != testPodWorker0 || candidates[1].PodName != testPodWorker0 {
t.Errorf("expected both candidates to have podName %s", testPodWorker0)
}
},
},
{
name: "multiple services",
endpointSlices: &discoveryv1.EndpointSliceList{
Items: []discoveryv1.EndpointSlice{
{
ObjectMeta: metav1.ObjectMeta{
Labels: map[string]string{
discoveryv1.LabelServiceName: "service-1",
},
},
Endpoints: []discoveryv1.Endpoint{
{
Addresses: []string{"10.0.1.5"},
Conditions: discoveryv1.EndpointConditions{
Ready: &trueVal,
},
TargetRef: &corev1.ObjectReference{
Kind: "Pod",
Name: testPodWorker0,
},
},
},
},
{
ObjectMeta: metav1.ObjectMeta{
Labels: map[string]string{
discoveryv1.LabelServiceName: "service-2",
},
},
Endpoints: []discoveryv1.Endpoint{
{
Addresses: []string{"10.0.1.6"},
Conditions: discoveryv1.EndpointConditions{
Ready: &trueVal,
},
TargetRef: &corev1.ObjectReference{
Kind: "Pod",
Name: testPodWorker1,
},
},
},
},
},
},
port: 9090,
expectedCandidates: 2,
expectedServiceNames: map[string]bool{
"service-1": true,
"service-2": true,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
candidates, serviceNames := ExtractCandidates(tt.endpointSlices, tt.port)
// Check candidate count
if len(candidates) != tt.expectedCandidates {
t.Errorf("expected %d candidates, got %d", tt.expectedCandidates, len(candidates))
}
// Check service names
if len(serviceNames) != len(tt.expectedServiceNames) {
t.Errorf("expected %d service names, got %d", len(tt.expectedServiceNames), len(serviceNames))
}
for name := range tt.expectedServiceNames {
if !serviceNames[name] {
t.Errorf("expected service name %s not found", name)
}
}
// Run additional validation if provided
if tt.validateCandidates != nil {
tt.validateCandidates(t, candidates)
}
})
}
}
// TestFindModelsForBaseModel exercises FindModelsForBaseModel against a fake
// client that has a field index registered on the base model name. Each case
// seeds the client with DynamoModels and checks which request names come back
// for a given namespace and base model.
func TestFindModelsForBaseModel(t *testing.T) {
	const baseModelIndex = ".spec.baseModelName"

	// loraModel builds a LoRA DynamoModel whose spec model name equals its object name.
	loraModel := func(name, namespace, baseModel string) v1alpha1.DynamoModel {
		return v1alpha1.DynamoModel{
			ObjectMeta: metav1.ObjectMeta{
				Name:      name,
				Namespace: namespace,
			},
			Spec: v1alpha1.DynamoModelSpec{
				ModelName:     name,
				BaseModelName: baseModel,
				ModelType:     "lora",
			},
		}
	}

	cases := []struct {
		name           string
		namespace      string
		baseModelName  string
		indexField     string
		existingModels []v1alpha1.DynamoModel
		expectedCount  int
		expectedNames  []string
		expectError    bool
	}{
		{
			name:          "finds multiple models for base model",
			namespace:     "default",
			baseModelName: "llama-2-7b",
			indexField:    baseModelIndex,
			existingModels: []v1alpha1.DynamoModel{
				loraModel("lora-1", "default", "llama-2-7b"),
				loraModel("lora-2", "default", "llama-2-7b"),
				loraModel("different-base", "default", "gpt-3"),
			},
			expectedCount: 2,
			expectedNames: []string{"lora-1", "lora-2"},
		},
		{
			name:          "finds no models for base model",
			namespace:     "default",
			baseModelName: "non-existent-base",
			indexField:    baseModelIndex,
			existingModels: []v1alpha1.DynamoModel{
				loraModel("lora-1", "default", "llama-2-7b"),
			},
			expectedCount: 0,
			expectedNames: []string{},
		},
		{
			name:          "filters by namespace",
			namespace:     "ns1",
			baseModelName: "llama-2-7b",
			indexField:    baseModelIndex,
			existingModels: []v1alpha1.DynamoModel{
				loraModel("lora-ns1", "ns1", "llama-2-7b"),
				loraModel("lora-ns2", "ns2", "llama-2-7b"),
			},
			expectedCount: 1,
			expectedNames: []string{"lora-ns1"},
		},
		{
			name:           "handles empty model list",
			namespace:      "default",
			baseModelName:  "any-base",
			indexField:     baseModelIndex,
			existingModels: []v1alpha1.DynamoModel{},
			expectedCount:  0,
			expectedNames:  []string{},
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			// Scheme must know both the CRD types and the core types.
			scheme := runtime.NewScheme()
			_ = v1alpha1.AddToScheme(scheme)
			_ = corev1.AddToScheme(scheme)

			seed := make([]client.Object, 0, len(tc.existingModels))
			for i := range tc.existingModels {
				seed = append(seed, &tc.existingModels[i])
			}

			// The index mirrors what the lookup filters on: the spec's base model name.
			byBaseModel := func(obj client.Object) []string {
				return []string{obj.(*v1alpha1.DynamoModel).Spec.BaseModelName}
			}
			fakeClient := fake.NewClientBuilder().
				WithScheme(scheme).
				WithObjects(seed...).
				WithIndex(&v1alpha1.DynamoModel{}, tc.indexField, byBaseModel).
				Build()

			requests, err := FindModelsForBaseModel(context.Background(), fakeClient, tc.namespace, tc.baseModelName, tc.indexField)

			switch {
			case tc.expectError:
				if err == nil {
					t.Error("expected error but got none")
				}
				return
			case err != nil:
				t.Fatalf("unexpected error: %v", err)
			}

			if got := len(requests); got != tc.expectedCount {
				t.Errorf("expected %d requests, got %d", tc.expectedCount, got)
			}

			// Collect the distinct names that were returned.
			returned := make(map[string]bool, len(requests))
			for _, req := range requests {
				returned[req.Name] = true
			}
			for _, want := range tc.expectedNames {
				if !returned[want] {
					t.Errorf("expected to find model %s, but it was not in the results", want)
				}
			}
			// Every expected name was found above, so a size mismatch means extras.
			if len(returned) != len(tc.expectedNames) {
				t.Errorf("found unexpected models in results")
			}
		})
	}
}
/*
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package modelendpoint
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"sigs.k8s.io/controller-runtime/pkg/log"
)
// loadLoRA loads a LoRA model on a single endpoint.
//
// It POSTs a JSON body of the form
//
//	{"lora_name": modelName, "source": {"uri": sourceURI}}
//
// to <address>/v1/loras. Any 2xx status counts as success. Any other status
// produces an error containing the status code and the response body
// (truncated to a fixed cap so a misbehaving endpoint cannot force an
// unbounded read). Marshalling, URL construction, request creation and
// transport failures are returned wrapped.
func (c *Client) loadLoRA(ctx context.Context, address, modelName, sourceURI string) error {
	// Upper bound on how much of a response body is read, whether to embed it
	// in an error or to drain it before closing.
	const maxResponseBodyBytes = 64 << 10

	logs := log.FromContext(ctx)

	// Build request body with source object
	loadBody, err := json.Marshal(map[string]any{
		"lora_name": modelName,
		"source":    map[string]any{"uri": sourceURI},
	})
	if err != nil {
		return fmt.Errorf("failed to marshal load LoRA request: %w", err)
	}

	// url.JoinPath appends the segments to whatever path address already
	// carries, so a trailing slash or a prefix such as /api is handled.
	apiURL, err := url.JoinPath(address, "v1", "loras")
	if err != nil {
		return fmt.Errorf("failed to construct load LoRA URL: %w", err)
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewReader(loadBody))
	if err != nil {
		return fmt.Errorf("failed to create load LoRA request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := c.httpClient.Do(req)
	if err != nil {
		return fmt.Errorf("failed to call load LoRA endpoint: %w", err)
	}
	defer func() {
		// Drain what is left (bounded) so the transport can reuse the
		// connection; the outcome of the drain itself is irrelevant.
		_, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, maxResponseBodyBytes))
		if closeErr := resp.Body.Close(); closeErr != nil {
			logs.V(1).Info("Failed to close response body", "error", closeErr)
		}
	}()

	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		body, readErr := io.ReadAll(io.LimitReader(resp.Body, maxResponseBodyBytes))
		if readErr != nil {
			// Still report the status; whatever was read is included below.
			logs.V(1).Info("Failed to read load LoRA error response body", "error", readErr)
		}
		logs.V(1).Info("Load LoRA failed", "address", address, "status", resp.StatusCode, "body", string(body))
		return fmt.Errorf("load LoRA failed with status %d: %s", resp.StatusCode, string(body))
	}

	logs.Info("Successfully loaded LoRA", "address", address, "modelName", modelName, "sourceURI", sourceURI)
	return nil
}
// unloadLoRA unloads a LoRA model from a single endpoint.
//
// It sends DELETE <address>/v1/loras/<modelName>. Any 2xx status counts as
// success. Any other status produces an error containing the status code and
// the response body (truncated to a fixed cap). An empty modelName is
// rejected up front, because it would otherwise address the /v1/loras
// collection itself rather than a single adapter.
func (c *Client) unloadLoRA(ctx context.Context, address, modelName string) error {
	// Upper bound on how much of a response body is read, whether to embed it
	// in an error or to drain it before closing.
	const maxResponseBodyBytes = 64 << 10

	logs := log.FromContext(ctx)

	if modelName == "" {
		logs.V(1).Info("Refusing to unload LoRA with empty model name", "address", address)
		return fmt.Errorf("failed to construct unload LoRA URL: model name is empty")
	}

	// url.JoinPath appends the segments to whatever path address already
	// carries and normalises slashes. It does not escape "/" inside
	// modelName: a name such as "org/model" becomes two path segments.
	apiURL, err := url.JoinPath(address, "v1", "loras", modelName)
	if err != nil {
		logs.V(1).Info("Failed to construct unload LoRA URL", "error", err)
		return fmt.Errorf("failed to construct unload LoRA URL: %w", err)
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodDelete, apiURL, nil)
	if err != nil {
		logs.V(1).Info("Failed to create unload LoRA request", "error", err)
		return fmt.Errorf("failed to create unload LoRA request: %w", err)
	}

	resp, err := c.httpClient.Do(req)
	if err != nil {
		logs.V(1).Info("Failed to call unload LoRA endpoint", "address", address, "error", err)
		return fmt.Errorf("failed to call unload LoRA endpoint: %w", err)
	}
	defer func() {
		// Drain what is left (bounded) so the transport can reuse the
		// connection; the outcome of the drain itself is irrelevant.
		_, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, maxResponseBodyBytes))
		if closeErr := resp.Body.Close(); closeErr != nil {
			logs.V(1).Info("Failed to close response body", "error", closeErr)
		}
	}()

	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		body, readErr := io.ReadAll(io.LimitReader(resp.Body, maxResponseBodyBytes))
		if readErr != nil {
			// Still report the status; whatever was read is included below.
			logs.V(1).Info("Failed to read unload LoRA error response body", "error", readErr)
		}
		logs.V(1).Info("Unload LoRA endpoint returned error status",
			"address", address,
			"status", resp.StatusCode,
			"body", string(body))
		return fmt.Errorf("unload LoRA failed with status %d: %s", resp.StatusCode, string(body))
	}

	logs.V(1).Info("Successfully unloaded LoRA", "address", address, "modelName", modelName)
	return nil
}
/*
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package modelendpoint
import (
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
)
// TestLoadLoRA_URLConstruction verifies the request path loadLoRA produces
// for base addresses with no trailing slash, a trailing slash, and an
// existing path prefix.
func TestLoadLoRA_URLConstruction(t *testing.T) {
	// placeholderHost is the host part used in the table's baseAddress
	// values; it is swapped for the live test server URL at run time.
	const placeholderHost = "http://10.0.1.5:9090"

	tests := []struct {
		name            string
		baseAddress     string
		expectedURLPath string
	}{
		{
			name:            "address without trailing slash",
			baseAddress:     "http://10.0.1.5:9090",
			expectedURLPath: "/v1/loras",
		},
		{
			name:            "address with trailing slash",
			baseAddress:     "http://10.0.1.5:9090/",
			expectedURLPath: "/v1/loras",
		},
		{
			name:            "address with path",
			baseAddress:     "http://10.0.1.5:9090/api",
			expectedURLPath: "/api/v1/loras",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Create a test server that captures the request
			var capturedPath string
			server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				capturedPath = r.URL.Path
				w.WriteHeader(http.StatusOK)
			}))
			defer server.Close()

			client := NewClient()
			ctx := context.Background()

			// Keep only the suffix after the placeholder host and attach it
			// to the test server's URL.
			address := server.URL + strings.TrimPrefix(tt.baseAddress, placeholderHost)

			// Fail here on a request error, otherwise it would only show up
			// below as a confusing empty-path mismatch.
			if err := client.loadLoRA(ctx, address, "test-model", "s3://bucket/model"); err != nil {
				t.Fatalf("unexpected error: %v", err)
			}
			if capturedPath != tt.expectedURLPath {
				t.Errorf("expected URL path %s, got %s", tt.expectedURLPath, capturedPath)
			}
		})
	}
}
// TestLoadLoRA_RequestBody checks the JSON payload and Content-Type header
// that loadLoRA sends: the model name under "lora_name" and the source URI
// nested under "source"."uri".
func TestLoadLoRA_RequestBody(t *testing.T) {
	cases := []struct {
		name              string
		modelName         string
		sourceURI         string
		expectedLoraName  string
		expectedSourceURI string
	}{
		{
			name:              "basic lora load",
			modelName:         "my-lora",
			sourceURI:         "s3://bucket/model",
			expectedLoraName:  "my-lora",
			expectedSourceURI: "s3://bucket/model",
		},
		{
			name:              "huggingface lora",
			modelName:         "hf-lora",
			sourceURI:         "hf://org/model@v1.0",
			expectedLoraName:  "hf-lora",
			expectedSourceURI: "hf://org/model@v1.0",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			// The handler records the decoded payload for inspection below.
			var payload map[string]interface{}
			record := func(w http.ResponseWriter, r *http.Request) {
				raw, _ := io.ReadAll(r.Body)
				_ = json.Unmarshal(raw, &payload)
				if contentType := r.Header.Get("Content-Type"); contentType != "application/json" {
					t.Errorf("expected Content-Type application/json, got %s", contentType)
				}
				w.WriteHeader(http.StatusOK)
			}
			srv := httptest.NewServer(http.HandlerFunc(record))
			defer srv.Close()

			if err := NewClient().loadLoRA(context.Background(), srv.URL, tc.modelName, tc.sourceURI); err != nil {
				t.Fatalf("unexpected error: %v", err)
			}

			if loraName := payload["lora_name"]; loraName != tc.expectedLoraName {
				t.Errorf("expected lora_name %s, got %v", tc.expectedLoraName, loraName)
			}

			source, isMap := payload["source"].(map[string]interface{})
			if !isMap {
				t.Fatal("expected source to be a map")
			}
			if uri := source["uri"]; uri != tc.expectedSourceURI {
				t.Errorf("expected source URI %s, got %v", tc.expectedSourceURI, uri)
			}
		})
	}
}
// TestLoadLoRA_ResponseHandling checks how loadLoRA maps HTTP status codes to
// results: 2xx responses succeed, anything else returns an error whose text
// includes the status code.
func TestLoadLoRA_ResponseHandling(t *testing.T) {
	cases := []struct {
		name          string
		statusCode    int
		responseBody  string
		expectError   bool
		errorContains string
	}{
		{name: "success - 200 OK", statusCode: http.StatusOK, expectError: false},
		{name: "success - 201 Created", statusCode: http.StatusCreated, expectError: false},
		{
			name:          "failure - 400 Bad Request",
			statusCode:    http.StatusBadRequest,
			responseBody:  "Invalid LoRA",
			expectError:   true,
			errorContains: "400",
		},
		{
			name:          "failure - 404 Not Found",
			statusCode:    http.StatusNotFound,
			responseBody:  "Endpoint not found",
			expectError:   true,
			errorContains: "404",
		},
		{
			name:          "failure - 500 Internal Server Error",
			statusCode:    http.StatusInternalServerError,
			responseBody:  "Server error",
			expectError:   true,
			errorContains: "500",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			// Reply with the canned status and, when set, the canned body.
			reply := func(w http.ResponseWriter, r *http.Request) {
				w.WriteHeader(tc.statusCode)
				if tc.responseBody != "" {
					_, _ = w.Write([]byte(tc.responseBody))
				}
			}
			srv := httptest.NewServer(http.HandlerFunc(reply))
			defer srv.Close()

			err := NewClient().loadLoRA(context.Background(), srv.URL, "test-model", "s3://bucket/model")

			switch {
			case !tc.expectError:
				if err != nil {
					t.Errorf("expected no error but got: %v", err)
				}
			case err == nil:
				t.Error("expected error but got none")
			case tc.errorContains != "" && !strings.Contains(err.Error(), tc.errorContains):
				t.Errorf("expected error to contain %q, got %v", tc.errorContains, err)
			}
		})
	}
}
// TestUnloadLoRA_URLConstruction checks that unloadLoRA issues a DELETE to
// /v1/loras/<modelName> and records the decoded path the server observes for
// several model name shapes.
func TestUnloadLoRA_URLConstruction(t *testing.T) {
	cases := []struct {
		name            string
		modelName       string
		expectedURLPath string
	}{
		{name: "simple model name", modelName: "my-lora", expectedURLPath: "/v1/loras/my-lora"},
		{name: "model name with special chars", modelName: "my-lora-v1.0", expectedURLPath: "/v1/loras/my-lora-v1.0"},
		{name: "model name with slashes (URL encoded)", modelName: "org/model", expectedURLPath: "/v1/loras/org/model"},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			// seen holds what the server observed for the single request.
			var seen struct {
				method string
				path   string
			}
			capture := func(w http.ResponseWriter, r *http.Request) {
				seen.path = r.URL.Path
				seen.method = r.Method
				w.WriteHeader(http.StatusOK)
			}
			srv := httptest.NewServer(http.HandlerFunc(capture))
			defer srv.Close()

			if err := NewClient().unloadLoRA(context.Background(), srv.URL, tc.modelName); err != nil {
				t.Fatalf("unexpected error: %v", err)
			}

			if seen.method != "DELETE" {
				t.Errorf("expected DELETE method, got %s", seen.method)
			}
			if seen.path != tc.expectedURLPath {
				t.Errorf("expected URL path %s, got %s", tc.expectedURLPath, seen.path)
			}
		})
	}
}
// TestUnloadLoRA_ResponseHandling checks how unloadLoRA maps HTTP status
// codes to results: 2xx responses succeed, anything else returns an error
// whose text includes the status code.
func TestUnloadLoRA_ResponseHandling(t *testing.T) {
	cases := []struct {
		name          string
		statusCode    int
		responseBody  string
		expectError   bool
		errorContains string
	}{
		{name: "success - 200 OK", statusCode: http.StatusOK, expectError: false},
		{name: "success - 204 No Content", statusCode: http.StatusNoContent, expectError: false},
		{
			name:          "failure - 404 Not Found",
			statusCode:    http.StatusNotFound,
			responseBody:  "LoRA not found",
			expectError:   true,
			errorContains: "404",
		},
		{
			name:          "failure - 500 Internal Server Error",
			statusCode:    http.StatusInternalServerError,
			responseBody:  "Failed to unload",
			expectError:   true,
			errorContains: "500",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			// Reply with the canned status and, when set, the canned body.
			reply := func(w http.ResponseWriter, r *http.Request) {
				w.WriteHeader(tc.statusCode)
				if tc.responseBody != "" {
					_, _ = w.Write([]byte(tc.responseBody))
				}
			}
			srv := httptest.NewServer(http.HandlerFunc(reply))
			defer srv.Close()

			err := NewClient().unloadLoRA(context.Background(), srv.URL, "test-model")

			switch {
			case !tc.expectError:
				if err != nil {
					t.Errorf("expected no error but got: %v", err)
				}
			case err == nil:
				t.Error("expected error but got none")
			case tc.errorContains != "" && !strings.Contains(err.Error(), tc.errorContains):
				t.Errorf("expected error to contain %q, got %v", tc.errorContains, err)
			}
		})
	}
}
/*
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package modelendpoint
// Candidate represents an endpoint candidate for operations.
// It pairs an address with the name of the pod it belongs to.
type Candidate struct {
// Address is the candidate endpoint's address.
// NOTE(review): presumably the base URL handed to the per-endpoint LoRA
// load/unload calls — confirm against the code that builds Candidates.
Address string
// PodName is the name of the pod associated with this candidate.
PodName string
}
/*
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package workerpool
import (
"context"
"fmt"
"sync"
"time"
)
// Task represents a unit of work to be executed
type Task[T any] struct {
Index int
Work func(ctx context.Context) (T, error)
}
// Result represents the outcome of executing a task
type Result[T any] struct {
Index int
Value T
Err error
}
// Execute runs all tasks in parallel with bounded concurrency using a worker pool
// Returns results in the same order as input tasks, even if execution order differs
// Continues executing all tasks even if some fail
// Spawns exactly maxWorkers goroutines regardless of task count
func Execute[T any](ctx context.Context, maxWorkers int, timeout time.Duration, tasks []Task[T]) ([]Result[T], error) {
// Validate maxWorkers to prevent panics or hangs
if maxWorkers < 1 {
return nil, fmt.Errorf("maxWorkers must be at least 1, got %d", maxWorkers)
}
if len(tasks) == 0 {
return nil, nil
}
// Create context with timeout
execCtx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
// Create channels
taskChan := make(chan Task[T])
results := make(chan Result[T], len(tasks))
// Start exactly maxWorkers worker goroutines
var wg sync.WaitGroup
for range maxWorkers {
wg.Add(1)
go func() {
defer wg.Done()
// Each worker pulls tasks from the channel until it's closed
for task := range taskChan {
// Execute the task
value, err := task.Work(execCtx)
// Send result through channel
results <- Result[T]{
Index: task.Index,
Value: value,
Err: err,
}
}
}()
}
// Feed tasks to workers in a separate goroutine to avoid blocking
go func() {
for _, task := range tasks {
taskChan <- task
}
close(taskChan) // Signal workers that no more tasks are coming
}()
// Close results channel when all workers complete
go func() {
wg.Wait()
close(results)
}()
// Collect results from channel
collectedResults := make([]Result[T], len(tasks))
var errorCount int
for result := range results {
collectedResults[result.Index] = result
if result.Err != nil {
errorCount++
}
}
// Return error if any tasks failed
if errorCount > 0 {
return collectedResults, fmt.Errorf("%d task(s) failed", errorCount)
}
return collectedResults, nil
}
/*
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package workerpool
import (
"context"
"errors"
"fmt"
"strings"
"sync/atomic"
"testing"
"time"
)
// TestExecute runs Execute over a table of worker counts, task counts and
// failure patterns. It checks the aggregate error, the number of results, and
// that each result sits at its task's position with the expected value or
// error.
func TestExecute(t *testing.T) {
	cases := []struct {
		name               string
		maxWorkers         int
		timeout            time.Duration
		taskCount          int
		taskDuration       time.Duration
		failingTaskIndices []int
		expectError        bool
		errorContains      string
	}{
		{name: "empty task list", maxWorkers: 5, timeout: time.Second, taskCount: 0, expectError: false},
		{name: "single task success", maxWorkers: 1, timeout: time.Second, taskCount: 1, taskDuration: 10 * time.Millisecond, expectError: false},
		{name: "multiple tasks success", maxWorkers: 5, timeout: time.Second, taskCount: 10, taskDuration: 10 * time.Millisecond, expectError: false},
		{
			name:               "single task failure",
			maxWorkers:         5,
			timeout:            time.Second,
			taskCount:          5,
			taskDuration:       10 * time.Millisecond,
			failingTaskIndices: []int{2},
			expectError:        true,
			errorContains:      "1 task(s) failed",
		},
		{
			name:               "multiple task failures",
			maxWorkers:         5,
			timeout:            time.Second,
			taskCount:          10,
			taskDuration:       10 * time.Millisecond,
			failingTaskIndices: []int{1, 3, 5},
			expectError:        true,
			errorContains:      "3 task(s) failed",
		},
		{name: "more tasks than workers", maxWorkers: 3, timeout: time.Second, taskCount: 10, taskDuration: 10 * time.Millisecond, expectError: false},
		{name: "more workers than tasks", maxWorkers: 10, timeout: time.Second, taskCount: 3, taskDuration: 10 * time.Millisecond, expectError: false},
		{name: "single worker multiple tasks", maxWorkers: 1, timeout: time.Second, taskCount: 5, taskDuration: 10 * time.Millisecond, expectError: false},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			// Positions listed in failingTaskIndices return an error.
			shouldFail := make(map[int]bool, len(tc.failingTaskIndices))
			for _, pos := range tc.failingTaskIndices {
				shouldFail[pos] = true
			}

			// Each task sleeps for taskDuration, then yields twice its index or fails.
			tasks := make([]Task[int], tc.taskCount)
			for i := range tasks {
				pos := i
				tasks[i] = Task[int]{
					Index: pos,
					Work: func(ctx context.Context) (int, error) {
						if tc.taskDuration > 0 {
							time.Sleep(tc.taskDuration)
						}
						if shouldFail[pos] {
							return 0, fmt.Errorf("task %d failed", pos)
						}
						return pos * 2, nil
					},
				}
			}

			results, err := Execute(context.Background(), tc.maxWorkers, tc.timeout, tasks)

			// Aggregate error expectation.
			switch {
			case !tc.expectError:
				if err != nil {
					t.Errorf("unexpected error: %v", err)
				}
			case err == nil:
				t.Error("expected error but got none")
			case tc.errorContains != "" && !strings.Contains(err.Error(), tc.errorContains):
				t.Errorf("expected error to contain %q, got %v", tc.errorContains, err)
			}

			if len(results) != tc.taskCount {
				t.Errorf("expected %d results, got %d", tc.taskCount, len(results))
			}

			// Per-result expectations.
			for i, res := range results {
				if res.Index != i {
					t.Errorf("result %d has wrong index: expected %d, got %d", i, i, res.Index)
				}
				if shouldFail[i] {
					if res.Err == nil {
						t.Errorf("result %d should have error but got none", i)
					}
					continue
				}
				if want := i * 2; res.Value != want {
					t.Errorf("result %d has wrong value: expected %d, got %d", i, want, res.Value)
				}
				if res.Err != nil {
					t.Errorf("result %d has unexpected error: %v", i, res.Err)
				}
			}
		})
	}
}
// TestExecute_InvalidMaxWorkers checks that Execute rejects worker counts
// below one with a descriptive error.
func TestExecute_InvalidMaxWorkers(t *testing.T) {
	cases := []struct {
		name          string
		maxWorkers    int
		errorContains string
	}{
		{name: "zero workers", maxWorkers: 0, errorContains: "maxWorkers must be at least 1"},
		{name: "negative workers", maxWorkers: -1, errorContains: "maxWorkers must be at least 1"},
	}

	// A single trivially succeeding task; it is supplied only so the task
	// list is non-empty.
	noop := Task[int]{
		Index: 0,
		Work: func(ctx context.Context) (int, error) {
			return 0, nil
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			_, err := Execute(context.Background(), tc.maxWorkers, time.Second, []Task[int]{noop})
			switch {
			case err == nil:
				t.Error("expected error but got none")
			case !strings.Contains(err.Error(), tc.errorContains):
				t.Errorf("expected error to contain %q, got %v", tc.errorContains, err)
			}
		})
	}
}
// TestExecute_Timeout checks that tasks which outlive the Execute timeout
// observe the expired context: Execute returns an error, still returns one
// result per task, and every result carries context.DeadlineExceeded.
func TestExecute_Timeout(t *testing.T) {
	ctx := context.Background()

	// slowTask blocks for two seconds unless its context ends first, in which
	// case it returns the context's error.
	slowTask := func(index int) Task[int] {
		return Task[int]{
			Index: index,
			Work: func(ctx context.Context) (int, error) {
				select {
				case <-time.After(2 * time.Second):
					return index, nil
				case <-ctx.Done():
					return 0, ctx.Err()
				}
			},
		}
	}
	tasks := []Task[int]{slowTask(0), slowTask(1)}

	// Execute with short timeout
	results, err := Execute(ctx, 2, 100*time.Millisecond, tasks)

	// Should get error because tasks timed out
	if err == nil {
		t.Error("expected timeout error but got none")
	}
	// Should still get results (with errors)
	if len(results) != 2 {
		t.Errorf("expected 2 results, got %d", len(results))
	}
	// All results should have context deadline exceeded error
	for i, result := range results {
		switch {
		case result.Err == nil:
			t.Errorf("result %d should have timeout error but got none", i)
		case !errors.Is(result.Err, context.DeadlineExceeded):
			t.Errorf("result %d expected context.DeadlineExceeded, got %v", i, result.Err)
		}
	}
}
// TestExecute_Concurrency checks that Execute never runs more than maxWorkers
// tasks at once, and that it does overlap at least two of them.
func TestExecute_Concurrency(t *testing.T) {
	const (
		maxWorkers = 5
		taskCount  = 20
	)

	// inFlight counts tasks currently running; peak is the highest value
	// inFlight has reached.
	var inFlight, peak atomic.Int32

	tasks := make([]Task[int], taskCount)
	for i := range tasks {
		pos := i
		tasks[i] = Task[int]{
			Index: pos,
			Work: func(ctx context.Context) (int, error) {
				running := inFlight.Add(1)

				// Raise peak to running unless another task already pushed it higher.
				for {
					seen := peak.Load()
					if running <= seen || peak.CompareAndSwap(seen, running) {
						break
					}
				}

				// Hold the slot long enough for other workers to overlap.
				time.Sleep(50 * time.Millisecond)

				inFlight.Add(-1)
				return pos, nil
			},
		}
	}

	if _, err := Execute(context.Background(), maxWorkers, 5*time.Second, tasks); err != nil {
		t.Errorf("unexpected error: %v", err)
	}

	observed := peak.Load()
	// Verify concurrency stayed within bounds
	if observed > maxWorkers {
		t.Errorf("expected max concurrent workers <= %d, got %d", maxWorkers, observed)
	}
	// Verify we actually used concurrency (should be at least 2 concurrent)
	if observed < 2 {
		t.Errorf("expected concurrent execution, but maxConcurrent was only %d", observed)
	}
}
// TestExecute_ContextCancellation cancels the parent context shortly after
// Execute starts. It expects an aggregate error, one result per task, and
// context.Canceled on every result.
func TestExecute_ContextCancellation(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())

	// Every task waits two seconds unless its context ends first.
	tasks := make([]Task[int], 5)
	for i := range tasks {
		pos := i
		tasks[i] = Task[int]{
			Index: pos,
			Work: func(ctx context.Context) (int, error) {
				select {
				case <-ctx.Done():
					return 0, ctx.Err()
				case <-time.After(2 * time.Second):
					return pos, nil
				}
			},
		}
	}

	// Cancel the parent context after a short delay, while tasks are in flight.
	time.AfterFunc(100*time.Millisecond, cancel)

	results, err := Execute(ctx, 3, 5*time.Second, tasks)

	if err == nil {
		t.Error("expected cancellation error but got none")
	}
	if len(results) != 5 {
		t.Errorf("expected 5 results, got %d", len(results))
	}
	for i, res := range results {
		switch {
		case res.Err == nil:
			t.Errorf("result %d should have cancellation error but got none", i)
		case !errors.Is(res.Err, context.Canceled):
			t.Errorf("result %d expected context.Canceled, got %v", i, res.Err)
		}
	}
}
// TestExecute_ResultOrdering checks that results come back in input order
// even when later tasks sleep for less time than earlier ones.
func TestExecute_ResultOrdering(t *testing.T) {
	ctx := context.Background()
	taskCount := 10

	// Create tasks whose sleep shrinks as the index grows
	tasks := make([]Task[int], taskCount)
	for i := range tasks {
		taskIndex := i
		tasks[i] = Task[int]{
			Index: taskIndex,
			Work: func(ctx context.Context) (int, error) {
				// Later tasks sleep less (complete faster)
				sleepDuration := time.Duration(taskCount-taskIndex) * 10 * time.Millisecond
				time.Sleep(sleepDuration)
				return taskIndex * 10, nil
			},
		}
	}

	results, err := Execute(ctx, 5, 5*time.Second, tasks)
	if err != nil {
		t.Errorf("unexpected error: %v", err)
	}

	// Without this check an empty result slice would skip the loop below and
	// let the test pass without verifying anything.
	if len(results) != taskCount {
		t.Fatalf("expected %d results, got %d", taskCount, len(results))
	}

	// Verify results are in original order despite the differing sleep times
	for i, result := range results {
		if result.Index != i {
			t.Errorf("result %d has wrong index: expected %d, got %d", i, i, result.Index)
		}
		expectedValue := i * 10
		if result.Value != expectedValue {
			t.Errorf("result %d has wrong value: expected %d, got %d", i, expectedValue, result.Value)
		}
	}
}
...@@ -217,6 +217,7 @@ Key customization points include: ...@@ -217,6 +217,7 @@ Key customization points include:
- **[Examples](../examples/README.md)** - Complete working examples - **[Examples](../examples/README.md)** - Complete working examples
- **[Create Custom Deployments](./deployment/create_deployment.md)** - Build your own CRDs - **[Create Custom Deployments](./deployment/create_deployment.md)** - Build your own CRDs
- **[Managing Models with DynamoModel](./deployment/dynamomodel-guide.md)** - Deploy LoRA adapters and manage models
- **[Operator Documentation](./dynamo_operator.md)** - How the platform works - **[Operator Documentation](./dynamo_operator.md)** - How the platform works
- **[Helm Charts](../../deploy/helm/README.md)** - For advanced users - **[Helm Charts](../../deploy/helm/README.md)** - For advanced users
- **[GitOps Deployment with FluxCD](./fluxcd.md)** - For advanced users - **[GitOps Deployment with FluxCD](./fluxcd.md)** - For advanced users
......
...@@ -37,6 +37,7 @@ Package v1alpha1 contains API Schema definitions for the nvidia.com v1alpha1 API ...@@ -37,6 +37,7 @@ Package v1alpha1 contains API Schema definitions for the nvidia.com v1alpha1 API
- [DynamoComponentDeployment](#dynamocomponentdeployment) - [DynamoComponentDeployment](#dynamocomponentdeployment)
- [DynamoGraphDeployment](#dynamographdeployment) - [DynamoGraphDeployment](#dynamographdeployment)
- [DynamoGraphDeploymentRequest](#dynamographdeploymentrequest) - [DynamoGraphDeploymentRequest](#dynamographdeploymentrequest)
- [DynamoModel](#dynamomodel)
...@@ -167,6 +168,7 @@ _Appears in:_ ...@@ -167,6 +168,7 @@ _Appears in:_
| `envFromSecret` _string_ | EnvFromSecret references a Secret whose key/value pairs will be exposed as<br />environment variables in the component containers. | | | | `envFromSecret` _string_ | EnvFromSecret references a Secret whose key/value pairs will be exposed as<br />environment variables in the component containers. | | |
| `volumeMounts` _[VolumeMount](#volumemount) array_ | VolumeMounts references PVCs defined at the top level for volumes to be mounted by the component. | | | | `volumeMounts` _[VolumeMount](#volumemount) array_ | VolumeMounts references PVCs defined at the top level for volumes to be mounted by the component. | | |
| `ingress` _[IngressSpec](#ingressspec)_ | Ingress config to expose the component outside the cluster (or through a service mesh). | | | | `ingress` _[IngressSpec](#ingressspec)_ | Ingress config to expose the component outside the cluster (or through a service mesh). | | |
| `modelRef` _[ModelReference](#modelreference)_ | ModelRef references a model that this component serves.<br />When specified, a headless service will be created for endpoint discovery. | | |
| `sharedMemory` _[SharedMemorySpec](#sharedmemoryspec)_ | SharedMemory controls the tmpfs mounted at /dev/shm (enable/disable and size). | | | | `sharedMemory` _[SharedMemorySpec](#sharedmemoryspec)_ | SharedMemory controls the tmpfs mounted at /dev/shm (enable/disable and size). | | |
| `extraPodMetadata` _[ExtraPodMetadata](#extrapodmetadata)_ | ExtraPodMetadata adds labels/annotations to the created Pods. | | | | `extraPodMetadata` _[ExtraPodMetadata](#extrapodmetadata)_ | ExtraPodMetadata adds labels/annotations to the created Pods. | | |
| `extraPodSpec` _[ExtraPodSpec](#extrapodspec)_ | ExtraPodSpec allows to override the main pod spec configuration.<br />It is a k8s standard PodSpec. It also contains a MainContainer (standard k8s Container) field<br />that allows overriding the main container configuration. | | | | `extraPodSpec` _[ExtraPodSpec](#extrapodspec)_ | ExtraPodSpec allows to override the main pod spec configuration.<br />It is a k8s standard PodSpec. It also contains a MainContainer (standard k8s Container) field<br />that allows overriding the main container configuration. | | |
...@@ -203,6 +205,7 @@ _Appears in:_ ...@@ -203,6 +205,7 @@ _Appears in:_
| `envFromSecret` _string_ | EnvFromSecret references a Secret whose key/value pairs will be exposed as<br />environment variables in the component containers. | | | | `envFromSecret` _string_ | EnvFromSecret references a Secret whose key/value pairs will be exposed as<br />environment variables in the component containers. | | |
| `volumeMounts` _[VolumeMount](#volumemount) array_ | VolumeMounts references PVCs defined at the top level for volumes to be mounted by the component. | | | | `volumeMounts` _[VolumeMount](#volumemount) array_ | VolumeMounts references PVCs defined at the top level for volumes to be mounted by the component. | | |
| `ingress` _[IngressSpec](#ingressspec)_ | Ingress config to expose the component outside the cluster (or through a service mesh). | | | | `ingress` _[IngressSpec](#ingressspec)_ | Ingress config to expose the component outside the cluster (or through a service mesh). | | |
| `modelRef` _[ModelReference](#modelreference)_ | ModelRef references a model that this component serves.<br />When specified, a headless service will be created for endpoint discovery. | | |
| `sharedMemory` _[SharedMemorySpec](#sharedmemoryspec)_ | SharedMemory controls the tmpfs mounted at /dev/shm (enable/disable and size). | | | | `sharedMemory` _[SharedMemorySpec](#sharedmemoryspec)_ | SharedMemory controls the tmpfs mounted at /dev/shm (enable/disable and size). | | |
| `extraPodMetadata` _[ExtraPodMetadata](#extrapodmetadata)_ | ExtraPodMetadata adds labels/annotations to the created Pods. | | | | `extraPodMetadata` _[ExtraPodMetadata](#extrapodmetadata)_ | ExtraPodMetadata adds labels/annotations to the created Pods. | | |
| `extraPodSpec` _[ExtraPodSpec](#extrapodspec)_ | ExtraPodSpec allows to override the main pod spec configuration.<br />It is a k8s standard PodSpec. It also contains a MainContainer (standard k8s Container) field<br />that allows overriding the main container configuration. | | | | `extraPodSpec` _[ExtraPodSpec](#extrapodspec)_ | ExtraPodSpec allows to override the main pod spec configuration.<br />It is a k8s standard PodSpec. It also contains a MainContainer (standard k8s Container) field<br />that allows overriding the main container configuration. | | |
...@@ -345,6 +348,81 @@ _Appears in:_ ...@@ -345,6 +348,81 @@ _Appears in:_
| `conditions` _[Condition](https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.28/#condition-v1-meta) array_ | Conditions contains the latest observed conditions of the graph deployment.<br />The slice is merged by type on patch updates. | | | | `conditions` _[Condition](https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.28/#condition-v1-meta) array_ | Conditions contains the latest observed conditions of the graph deployment.<br />The slice is merged by type on patch updates. | | |
#### DynamoModel
DynamoModel is the Schema for the dynamo models API
| Field | Description | Default | Validation |
| --- | --- | --- | --- |
| `apiVersion` _string_ | `nvidia.com/v1alpha1` | | |
| `kind` _string_ | `DynamoModel` | | |
| `metadata` _[ObjectMeta](https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.28/#objectmeta-v1-meta)_ | Refer to Kubernetes API documentation for fields of `metadata`. | | |
| `spec` _[DynamoModelSpec](#dynamomodelspec)_ | | | |
| `status` _[DynamoModelStatus](#dynamomodelstatus)_ | | | |
#### DynamoModelSpec
DynamoModelSpec defines the desired state of DynamoModel
_Appears in:_
- [DynamoModel](#dynamomodel)
| Field | Description | Default | Validation |
| --- | --- | --- | --- |
| `modelName` _string_ | ModelName is the full model identifier (e.g., "meta-llama/Llama-3.3-70B-Instruct-lora") | | Required: \{\} <br /> |
| `baseModelName` _string_ | BaseModelName is the base model identifier that matches the service label.<br />This is used to discover endpoints via headless services. | | Required: \{\} <br /> |
| `modelType` _string_ | ModelType specifies the type of model (e.g., "base", "lora", "adapter") | base | Enum: [base lora adapter] <br /> |
| `source` _[ModelSource](#modelsource)_ | Source specifies the model source location (only applicable for lora model type) | | |
#### DynamoModelStatus
DynamoModelStatus defines the observed state of DynamoModel
_Appears in:_
- [DynamoModel](#dynamomodel)
| Field | Description | Default | Validation |
| --- | --- | --- | --- |
| `endpoints` _[EndpointInfo](#endpointinfo) array_ | Endpoints is the current list of all endpoints for this model | | |
| `readyEndpoints` _integer_ | ReadyEndpoints is the count of endpoints that are ready | | |
| `totalEndpoints` _integer_ | TotalEndpoints is the total count of endpoints | | |
| `conditions` _[Condition](https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.28/#condition-v1-meta) array_ | Conditions represents the latest available observations of the model's state | | |
#### EndpointInfo
EndpointInfo represents a single endpoint (pod) serving the model
_Appears in:_
- [DynamoModelStatus](#dynamomodelstatus)
| Field | Description | Default | Validation |
| --- | --- | --- | --- |
| `address` _string_ | Address is the full address of the endpoint (e.g., "http://10.0.1.5:9090") | | |
| `podName` _string_ | PodName is the name of the pod serving this endpoint | | |
| `ready` _boolean_ | Ready indicates whether the endpoint is ready to serve traffic<br />For LoRA models: true if the POST /loras request succeeded with a 2xx status code<br />For base models: always false (no probing performed) | | |
#### IngressSpec #### IngressSpec
...@@ -387,6 +465,40 @@ _Appears in:_ ...@@ -387,6 +465,40 @@ _Appears in:_
| `secretName` _string_ | SecretName is the name of a Kubernetes Secret containing the TLS certificate and key. | | | | `secretName` _string_ | SecretName is the name of a Kubernetes Secret containing the TLS certificate and key. | | |
#### ModelReference
ModelReference identifies a model served by this component
_Appears in:_
- [DynamoComponentDeploymentSharedSpec](#dynamocomponentdeploymentsharedspec)
- [DynamoComponentDeploymentSpec](#dynamocomponentdeploymentspec)
| Field | Description | Default | Validation |
| --- | --- | --- | --- |
| `name` _string_ | Name is the base model identifier (e.g., "llama-3-70b-instruct-v1") | | Required: \{\} <br /> |
| `revision` _string_ | Revision is the model revision/version (optional) | | |
#### ModelSource
ModelSource defines the source location of a model
_Appears in:_
- [DynamoModelSpec](#dynamomodelspec)
| Field | Description | Default | Validation |
| --- | --- | --- | --- |
| `uri` _string_ | URI is the model source URI<br />Supported formats:<br />- S3: s3://bucket/path/to/model<br />- HuggingFace: hf://org/model@revision_sha | | Required: \{\} <br /> |
#### MultinodeSpec #### MultinodeSpec
......
...@@ -219,3 +219,41 @@ When disabled, you can manually specify secrets as you would for a normal pod sp ...@@ -219,3 +219,41 @@ When disabled, you can manually specify secrets as you would for a normal pod sp
``` ```
This automatic discovery eliminates the need to manually configure image pull secrets for each deployment. This automatic discovery eliminates the need to manually configure image pull secrets for each deployment.
## Step 6: Deploy LoRA Adapters (Optional)
After your base model deployment is running, you can deploy LoRA adapters using the `DynamoModel` custom resource. This allows you to serve fine-tuned variants of your models without modifying the base deployment.
To add a LoRA adapter to your deployment, link it using `modelRef` in your worker configuration:
```yaml
apiVersion: nvidia.com/v1alpha1
kind: DynamoGraphDeployment
metadata:
name: my-deployment
spec:
services:
Worker:
modelRef:
name: Qwen/Qwen3-0.6B # Base model identifier
componentType: worker
# ... rest of worker config
```
Then create a `DynamoModel` resource for your LoRA:
```yaml
apiVersion: nvidia.com/v1alpha1
kind: DynamoModel
metadata:
name: my-lora
spec:
modelName: my-custom-lora
baseModelName: Qwen/Qwen3-0.6B # Must match modelRef.name above
modelType: lora
source:
uri: s3://my-bucket/loras/my-lora
```
**For complete details on managing models and LoRA adapters, see:**
📖 **[Managing Models with DynamoModel Guide](./dynamomodel-guide.md)**
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment