/* * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package workerpool import ( "context" "errors" "fmt" "strings" "sync/atomic" "testing" "time" ) func TestExecute(t *testing.T) { tests := []struct { name string maxWorkers int timeout time.Duration taskCount int taskDuration time.Duration failingTaskIndices []int expectError bool errorContains string }{ { name: "empty task list", maxWorkers: 5, timeout: time.Second, taskCount: 0, expectError: false, }, { name: "single task success", maxWorkers: 1, timeout: time.Second, taskCount: 1, taskDuration: 10 * time.Millisecond, expectError: false, }, { name: "multiple tasks success", maxWorkers: 5, timeout: time.Second, taskCount: 10, taskDuration: 10 * time.Millisecond, expectError: false, }, { name: "single task failure", maxWorkers: 5, timeout: time.Second, taskCount: 5, taskDuration: 10 * time.Millisecond, failingTaskIndices: []int{2}, expectError: true, errorContains: "1 task(s) failed", }, { name: "multiple task failures", maxWorkers: 5, timeout: time.Second, taskCount: 10, taskDuration: 10 * time.Millisecond, failingTaskIndices: []int{1, 3, 5}, expectError: true, errorContains: "3 task(s) failed", }, { name: "more tasks than workers", maxWorkers: 3, timeout: time.Second, taskCount: 10, taskDuration: 10 * time.Millisecond, expectError: false, }, { name: "more workers than tasks", maxWorkers: 10, timeout: time.Second, taskCount: 3, taskDuration: 10 * time.Millisecond, expectError: false, }, { name: "single worker multiple tasks", maxWorkers: 1, timeout: time.Second, taskCount: 5, taskDuration: 10 * time.Millisecond, expectError: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ctx := context.Background() // Create tasks tasks := make([]Task[int], tt.taskCount) failingSet := make(map[int]bool) for _, idx := range tt.failingTaskIndices { failingSet[idx] = true } for i := range tasks { taskIndex := i tasks[i] = Task[int]{ Index: taskIndex, Work: func(ctx context.Context) (int, error) { // Simulate work if tt.taskDuration > 0 { time.Sleep(tt.taskDuration) } // Return error if this task should fail if failingSet[taskIndex] { return 0, fmt.Errorf("task %d failed", taskIndex) } return taskIndex * 2, nil }, } } // Execute tasks results, err := Execute(ctx, tt.maxWorkers, tt.timeout, tasks) // Verify error expectation if tt.expectError { if err == nil { t.Error("expected error but got none") } else if tt.errorContains != "" && !strings.Contains(err.Error(), tt.errorContains) { t.Errorf("expected error to contain %q, got %v", tt.errorContains, err) } } else { if err != nil { t.Errorf("unexpected error: %v", err) } } // Verify result count if len(results) != tt.taskCount { t.Errorf("expected %d results, got %d", tt.taskCount, len(results)) } // Verify successful task results for i, result := range results { if result.Index != i { t.Errorf("result %d has wrong index: expected %d, got %d", i, i, result.Index) } if !failingSet[i] { // Successful tasks should have correct value expectedValue := i * 2 if result.Value != expectedValue { t.Errorf("result %d has wrong value: expected %d, got %d", i, expectedValue, result.Value) } if result.Err != nil { t.Errorf("result %d has unexpected error: %v", i, result.Err) } } else { // Failed tasks should have error if result.Err == nil { t.Errorf("result %d should have error but got none", i) } } } }) } } func TestExecute_InvalidMaxWorkers(t *testing.T) { tests := []struct { name string maxWorkers int errorContains string }{ { name: "zero workers", maxWorkers: 0, errorContains: "maxWorkers must be at least 1", }, { name: "negative workers", maxWorkers: -1, errorContains: "maxWorkers must be at least 1", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ctx := context.Background() tasks := []Task[int]{ { Index: 0, Work: func(ctx context.Context) (int, error) { return 0, nil }, }, } _, err := Execute(ctx, tt.maxWorkers, time.Second, tasks) if err == nil { t.Error("expected error but got none") } else if !strings.Contains(err.Error(), tt.errorContains) { t.Errorf("expected error to contain %q, got %v", tt.errorContains, err) } }) } } func TestExecute_Timeout(t *testing.T) { ctx := context.Background() // Create tasks that take longer than the timeout tasks := []Task[int]{ { Index: 0, Work: func(ctx context.Context) (int, error) { select { case <-time.After(2 * time.Second): return 0, nil case <-ctx.Done(): return 0, ctx.Err() } }, }, { Index: 1, Work: func(ctx context.Context) (int, error) { select { case <-time.After(2 * time.Second): return 1, nil case <-ctx.Done(): return 0, ctx.Err() } }, }, } // Execute with short timeout results, err := Execute(ctx, 2, 100*time.Millisecond, tasks) // Should get error because tasks timed out if err == nil { t.Error("expected timeout error but got none") } // Should still get results (with errors) if len(results) != 2 { t.Errorf("expected 2 results, got %d", len(results)) } // All results should have context deadline exceeded error for i, result := range results { if result.Err == nil { t.Errorf("result %d should have timeout error but got none", i) } } } func TestExecute_Concurrency(t *testing.T) { ctx := context.Background() maxWorkers := 5 taskCount := 20 // Track concurrent execution var currentConcurrent int32 var maxConcurrent int32 tasks := make([]Task[int], taskCount) for i := range tasks { taskIndex := i tasks[i] = Task[int]{ Index: taskIndex, Work: func(ctx context.Context) (int, error) { // Increment counter current := atomic.AddInt32(¤tConcurrent, 1) // Update max if needed for { max := atomic.LoadInt32(&maxConcurrent) if current <= max || atomic.CompareAndSwapInt32(&maxConcurrent, max, current) { break } } // Simulate work time.Sleep(50 * time.Millisecond) // Decrement counter atomic.AddInt32(¤tConcurrent, -1) return taskIndex, nil }, } } _, err := Execute(ctx, maxWorkers, 5*time.Second, tasks) if err != nil { t.Errorf("unexpected error: %v", err) } // Verify concurrency stayed within bounds if maxConcurrent > int32(maxWorkers) { t.Errorf("expected max concurrent workers <= %d, got %d", maxWorkers, maxConcurrent) } // Verify we actually used concurrency (should be at least 2 concurrent) if maxConcurrent < 2 { t.Errorf("expected concurrent execution, but maxConcurrent was only %d", maxConcurrent) } } func TestExecute_ContextCancellation(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) // Create tasks that check for cancellation tasks := make([]Task[int], 5) for i := range tasks { taskIndex := i tasks[i] = Task[int]{ Index: taskIndex, Work: func(ctx context.Context) (int, error) { select { case <-time.After(2 * time.Second): return taskIndex, nil case <-ctx.Done(): return 0, ctx.Err() } }, } } // Cancel context after short delay go func() { time.Sleep(100 * time.Millisecond) cancel() }() results, err := Execute(ctx, 3, 5*time.Second, tasks) // Should get error if err == nil { t.Error("expected cancellation error but got none") } // Should still get results if len(results) != 5 { t.Errorf("expected 5 results, got %d", len(results)) } // All results should have cancellation error for i, result := range results { if result.Err == nil { t.Errorf("result %d should have cancellation error but got none", i) } else if !errors.Is(result.Err, context.Canceled) { t.Errorf("result %d expected context.Canceled, got %v", i, result.Err) } } } func TestExecute_ResultOrdering(t *testing.T) { ctx := context.Background() taskCount := 10 // Create tasks that complete in reverse order tasks := make([]Task[int], taskCount) for i := range tasks { taskIndex := i tasks[i] = Task[int]{ Index: taskIndex, Work: func(ctx context.Context) (int, error) { // Later tasks sleep less (complete faster) sleepDuration := time.Duration(taskCount-taskIndex) * 10 * time.Millisecond time.Sleep(sleepDuration) return taskIndex * 10, nil }, } } results, err := Execute(ctx, 5, 5*time.Second, tasks) if err != nil { t.Errorf("unexpected error: %v", err) } // Verify results are in original order despite reverse completion for i, result := range results { if result.Index != i { t.Errorf("result %d has wrong index: expected %d, got %d", i, i, result.Index) } expectedValue := i * 10 if result.Value != expectedValue { t.Errorf("result %d has wrong value: expected %d, got %d", i, expectedValue, result.Value) } } }