/* * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package modelendpoint import ( "context" "net/http" "net/http/httptest" "strings" "testing" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "github.com/ai-dynamo/dynamo/deploy/cloud/operator/api/v1alpha1" ) func TestLoadLoRA(t *testing.T) { // Create test servers for different scenarios successServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Verify HTTP method if r.Method != http.MethodPost { t.Errorf("expected POST method, got %s", r.Method) w.WriteHeader(http.StatusMethodNotAllowed) return } // Verify Content-Type header if r.Header.Get("Content-Type") != "application/json" { t.Errorf("expected Content-Type application/json, got %s", r.Header.Get("Content-Type")) } w.WriteHeader(http.StatusOK) })) defer successServer.Close() failingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Verify HTTP method even for failing requests if r.Method != http.MethodPost { t.Errorf("expected POST method, got %s", r.Method) w.WriteHeader(http.StatusMethodNotAllowed) return } w.WriteHeader(http.StatusInternalServerError) })) defer failingServer.Close() tests := []struct { name string modelType string sourceURI string candidates []Candidate expectError bool errorContains string expectedCount int expectedReadyCount int }{ { name: "non-lora model - skips loading", modelType: "base", candidates: []Candidate{{Address: "http://10.0.1.5:9090", PodName: "pod-1"}}, expectError: false, expectedCount: 1, expectedReadyCount: 0, }, { name: "empty candidates", modelType: "base", candidates: []Candidate{}, expectError: false, expectedCount: 0, expectedReadyCount: 0, }, { name: "lora with nil source", modelType: "lora", sourceURI: "", candidates: []Candidate{{Address: "http://10.0.1.5:9090", PodName: "pod-1"}}, expectError: true, errorContains: "source URI is required", }, { name: "lora with valid source - all success", modelType: "lora", sourceURI: "s3://bucket/model", candidates: []Candidate{ {Address: successServer.URL, PodName: "pod-1"}, {Address: successServer.URL, PodName: "pod-2"}, }, expectError: false, expectedCount: 2, expectedReadyCount: 2, }, { name: "lora with valid source - partial failure", modelType: "lora", sourceURI: "s3://bucket/model", candidates: []Candidate{ {Address: successServer.URL, PodName: "pod-1"}, {Address: failingServer.URL, PodName: "pod-2"}, }, expectError: true, // workerpool returns error on any failure expectedCount: 2, expectedReadyCount: 1, }, { name: "lora with huggingface source", modelType: "lora", sourceURI: "hf://org/model@v1.0", candidates: []Candidate{ {Address: successServer.URL, PodName: "pod-1"}, }, expectError: false, expectedCount: 1, expectedReadyCount: 1, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { client := NewClient() ctx := context.Background() var source *v1alpha1.ModelSource if tt.sourceURI != "" { source = &v1alpha1.ModelSource{URI: tt.sourceURI} } model := &v1alpha1.DynamoModel{ ObjectMeta: metav1.ObjectMeta{ Name: "test-model", Namespace: "default", }, Spec: v1alpha1.DynamoModelSpec{ ModelName: "test-model", ModelType: tt.modelType, Source: source, }, } endpoints, err := client.LoadLoRA(ctx, tt.candidates, model) // Check error expectation if tt.expectError && tt.errorContains != "" { // For validation errors (like missing source URI), we return early if err == nil { t.Error("expected error but got none") } else if !strings.Contains(err.Error(), tt.errorContains) { t.Errorf("expected error to contain %q, got %v", tt.errorContains, err) } return } // For partial failures, we expect an error but still get endpoints if tt.expectError && err == nil { t.Error("expected error for partial failure but got none") } if !tt.expectError && err != nil { t.Fatalf("unexpected error: %v", err) } // Verify endpoint count if len(endpoints) != tt.expectedCount { t.Errorf("expected %d endpoints, got %d", tt.expectedCount, len(endpoints)) } // Count ready endpoints readyCount := 0 for _, ep := range endpoints { if ep.Ready { readyCount++ } } if readyCount != tt.expectedReadyCount { t.Errorf("expected %d ready endpoints, got %d", tt.expectedReadyCount, readyCount) } }) } } func TestUnloadLoRA(t *testing.T) { successServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Verify HTTP method if r.Method != http.MethodDelete { t.Errorf("expected DELETE method, got %s", r.Method) w.WriteHeader(http.StatusMethodNotAllowed) return } // Verify URL path contains model name if !strings.Contains(r.URL.Path, "/loras/") { t.Errorf("expected URL path to contain /loras/, got %s", r.URL.Path) } w.WriteHeader(http.StatusOK) })) defer successServer.Close() failingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Verify HTTP method even for failing requests if r.Method != http.MethodDelete { t.Errorf("expected DELETE method, got %s", r.Method) w.WriteHeader(http.StatusMethodNotAllowed) return } w.WriteHeader(http.StatusInternalServerError) })) defer failingServer.Close() tests := []struct { name string candidates []Candidate modelName string expectError bool }{ { name: "empty candidates", candidates: []Candidate{}, modelName: "test-model", expectError: false, }, { name: "single endpoint success", candidates: []Candidate{ {Address: successServer.URL, PodName: "pod-1"}, }, modelName: "test-model", expectError: false, }, { name: "multiple endpoints success", candidates: []Candidate{ {Address: successServer.URL, PodName: "pod-1"}, {Address: successServer.URL, PodName: "pod-2"}, }, modelName: "test-model", expectError: false, }, { name: "partial failure", candidates: []Candidate{ {Address: successServer.URL, PodName: "pod-1"}, {Address: failingServer.URL, PodName: "pod-2"}, }, modelName: "test-model", expectError: true, // workerpool returns error on any failure }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { client := NewClient() ctx := context.Background() err := client.UnloadLoRA(ctx, tt.candidates, tt.modelName) if tt.expectError && err == nil { t.Error("expected error but got none") } else if !tt.expectError && err != nil { t.Errorf("unexpected error: %v", err) } }) } }