Unverified Commit dabd2267 authored by julienmancuso's avatar julienmancuso Committed by GitHub
Browse files

feat: add grove multinode support (#2269)

parent d51580a4
package dynamo
import (
"strings"
"testing"
"github.com/ai-dynamo/dynamo/deploy/cloud/operator/api/v1alpha1"
"github.com/ai-dynamo/dynamo/deploy/cloud/operator/internal/consts"
"github.com/onsi/gomega"
corev1 "k8s.io/api/core/v1"
)
// TestVLLMBackend_UpdateContainer is a table-driven test for
// VLLMBackend.UpdateContainer. Each case builds a container from the case's
// initial args and probes, calls UpdateContainer with the service name
// "test-service", and then verifies the resulting args in exactly one of three
// modes (checked in this order): unchanged, exact match, or substring
// containment. Cases may additionally assert that all three probes are nil.
func TestVLLMBackend_UpdateContainer(t *testing.T) {
backend := &VLLMBackend{}
tests := []struct {
name string
numberOfNodes int32
role Role
component *v1alpha1.DynamoComponentDeploymentOverridesSpec
multinodeDeploymentType consts.MultinodeDeploymentType
initialArgs []string
initialLivenessProbe *corev1.Probe
initialReadinessProbe *corev1.Probe
initialStartupProbe *corev1.Probe
// expectedArgs, when non-nil, must equal the resulting args exactly.
expectedArgs []string
// expectContains lists substrings that must each appear in the
// space-joined resulting args; only consulted when expectedArgs is nil.
expectContains []string
expectNotModified bool // If true, container args should not change
expectProbesRemoved bool // If true, probes should be nil
}{
{
name: "single node does not modify args",
numberOfNodes: 1,
role: RoleMain,
component: &v1alpha1.DynamoComponentDeploymentOverridesSpec{},
multinodeDeploymentType: consts.MultinodeDeploymentTypeGrove,
initialArgs: []string{"python3", "-m", "dynamo.vllm"},
expectNotModified: true,
},
{
name: "multinode leader prepends ray start --head",
numberOfNodes: 3,
role: RoleLeader,
component: &v1alpha1.DynamoComponentDeploymentOverridesSpec{},
multinodeDeploymentType: consts.MultinodeDeploymentTypeGrove,
initialArgs: []string{"python3", "-m", "dynamo.vllm", "--model", "test"},
expectContains: []string{"ray start --head --port=6379 &&", "python3", "-m", "dynamo.vllm", "--model", "test"},
expectProbesRemoved: true,
},
{
name: "multinode worker replaces args with ray start --block",
numberOfNodes: 3,
role: RoleWorker,
component: &v1alpha1.DynamoComponentDeploymentOverridesSpec{},
multinodeDeploymentType: consts.MultinodeDeploymentTypeGrove,
initialArgs: []string{"python3", "-m", "dynamo.vllm", "--model", "test"},
expectedArgs: []string{"ray start --address=${GROVE_PCSG_NAME}-${GROVE_PCSG_INDEX}-test-service-ldr-0.${GROVE_HEADLESS_SERVICE}:6379 --block"},
expectProbesRemoved: true,
},
{
name: "multinode worker with LWS deployment type",
numberOfNodes: 2,
role: RoleWorker,
component: &v1alpha1.DynamoComponentDeploymentOverridesSpec{},
multinodeDeploymentType: consts.MultinodeDeploymentTypeLWS,
initialArgs: []string{"python3", "-m", "dynamo.vllm"},
expectedArgs: []string{"ray start --address=${LWS_LEADER_ADDRESS}:6379 --block"},
expectProbesRemoved: true,
},
{
name: "multinode leader with no initial args",
numberOfNodes: 2,
role: RoleLeader,
component: &v1alpha1.DynamoComponentDeploymentOverridesSpec{},
multinodeDeploymentType: consts.MultinodeDeploymentTypeGrove,
initialArgs: []string{},
expectNotModified: true, // Should not modify empty args
},
{
name: "multinode main role (non-leader/worker) does not modify args",
numberOfNodes: 3,
role: RoleMain,
component: &v1alpha1.DynamoComponentDeploymentOverridesSpec{},
multinodeDeploymentType: consts.MultinodeDeploymentTypeGrove,
initialArgs: []string{"python3", "-m", "dynamo.frontend"},
expectNotModified: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
g := gomega.NewGomegaWithT(t)
// Create a container with initial state
container := &corev1.Container{
Args: append([]string{}, tt.initialArgs...), // Copy slice to avoid modifying original
LivenessProbe: tt.initialLivenessProbe,
ReadinessProbe: tt.initialReadinessProbe,
StartupProbe: tt.initialStartupProbe,
}
// Call UpdateContainer
backend.UpdateContainer(container, tt.numberOfNodes, tt.role, tt.component, tt.multinodeDeploymentType, "test-service")
if tt.expectNotModified {
// Args should not have changed
g.Expect(container.Args).To(gomega.Equal(tt.initialArgs))
} else if tt.expectedArgs != nil {
// Check exact match
g.Expect(container.Args).To(gomega.Equal(tt.expectedArgs))
} else if tt.expectContains != nil {
// Check that expected strings are contained in the result.
// The args are joined with single spaces first, so an expected
// substring may span several adjacent args.
argsStr := strings.Join(container.Args, " ")
for _, expected := range tt.expectContains {
if !strings.Contains(argsStr, expected) {
t.Errorf("UpdateContainer() args = %v, should contain %s", container.Args, expected)
}
}
}
// NOTE(review): no case in the table above sets any of the initial
// probes, so the container starts with nil probes and these assertions
// hold regardless of what UpdateContainer does. Consider seeding
// non-nil probes once the intended removal behavior per role is confirmed.
if tt.expectProbesRemoved {
g.Expect(container.LivenessProbe).To(gomega.BeNil())
g.Expect(container.ReadinessProbe).To(gomega.BeNil())
g.Expect(container.StartupProbe).To(gomega.BeNil())
}
})
}
}
// TestUpdateVLLMMultinodeArgs is a table-driven test for
// updateVLLMMultinodeArgs. Each case builds a container from the case's initial
// args, calls the function with the service name "test-service", and verifies
// the resulting args in exactly one mode (checked in this order): unchanged,
// exact match, or substring containment over the space-joined args.
func TestUpdateVLLMMultinodeArgs(t *testing.T) {
tests := []struct {
name string
role Role
multinodeDeploymentType consts.MultinodeDeploymentType
initialArgs []string
// expectedArgs, when non-nil, must equal the resulting args exactly.
expectedArgs []string
// expectContains lists substrings that must each appear in the
// space-joined resulting args; only consulted when expectedArgs is nil.
expectContains []string
// expectNotModified, when true, requires the args to equal initialArgs.
expectNotModified bool
}{
{
name: "leader prepends ray start --head",
role: RoleLeader,
multinodeDeploymentType: consts.MultinodeDeploymentTypeGrove,
initialArgs: []string{"python3", "-m", "dynamo.vllm"},
expectContains: []string{"ray start --head --port=6379 &&", "python3", "-m", "dynamo.vllm"},
},
{
name: "leader with empty args does not modify",
role: RoleLeader,
multinodeDeploymentType: consts.MultinodeDeploymentTypeGrove,
initialArgs: []string{},
expectNotModified: true,
},
{
name: "worker with Grove deployment",
role: RoleWorker,
multinodeDeploymentType: consts.MultinodeDeploymentTypeGrove,
initialArgs: []string{"python3", "-m", "dynamo.vllm"},
expectedArgs: []string{"ray start --address=${GROVE_PCSG_NAME}-${GROVE_PCSG_INDEX}-test-service-ldr-0.${GROVE_HEADLESS_SERVICE}:6379 --block"},
},
{
name: "worker with LWS deployment",
role: RoleWorker,
multinodeDeploymentType: consts.MultinodeDeploymentTypeLWS,
initialArgs: []string{"python3", "-m", "dynamo.vllm"},
expectedArgs: []string{"ray start --address=${LWS_LEADER_ADDRESS}:6379 --block"},
},
{
name: "main role does not modify args",
role: RoleMain,
multinodeDeploymentType: consts.MultinodeDeploymentTypeGrove,
initialArgs: []string{"python3", "-m", "dynamo.frontend"},
expectNotModified: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
g := gomega.NewGomegaWithT(t)
// Create a container with initial args
container := &corev1.Container{
Args: append([]string{}, tt.initialArgs...), // Copy slice to avoid modifying original
}
// Call updateVLLMMultinodeArgs
updateVLLMMultinodeArgs(container, tt.role, tt.multinodeDeploymentType, "test-service")
if tt.expectNotModified {
// Args should not have changed
g.Expect(container.Args).To(gomega.Equal(tt.initialArgs))
} else if tt.expectedArgs != nil {
// Check exact match
g.Expect(container.Args).To(gomega.Equal(tt.expectedArgs))
} else if tt.expectContains != nil {
// Check that expected strings are contained in the result.
// The args are joined with single spaces first, so an expected
// substring may span several adjacent args.
argsStr := strings.Join(container.Args, " ")
for _, expected := range tt.expectContains {
if !strings.Contains(argsStr, expected) {
t.Errorf("updateVLLMMultinodeArgs() args = %v, should contain %s", container.Args, expected)
}
}
}
})
}
}
......@@ -21,12 +21,14 @@ import (
"context"
"encoding/json"
"fmt"
"regexp"
"sort"
"strconv"
"strings"
istioNetworking "istio.io/api/networking/v1beta1"
"k8s.io/apimachinery/pkg/api/resource"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/util/intstr"
......@@ -49,13 +51,6 @@ type DynamoConfig struct {
ComponentType string `yaml:"component_type,omitempty"`
}
type Resources struct {
CPU *string `yaml:"cpu,omitempty" json:"cpu,omitempty"`
Memory *string `yaml:"memory,omitempty" json:"memory,omitempty"`
GPU *string `yaml:"gpu,omitempty" json:"gpu,omitempty"`
Custom map[string]string `yaml:"custom,omitempty" json:"custom,omitempty"`
}
// Traffic holds traffic-related settings read from YAML.
type Traffic struct {
// Timeout is read from the "timeout" key.
// NOTE(review): the unit is not stated here — presumably seconds; confirm.
Timeout int `yaml:"timeout"`
}
......@@ -83,6 +78,13 @@ type ServiceConfig struct {
Config Config `yaml:"config"`
}
// Resources describes optional resource amounts for a service. Every field is
// optional and is omitted from YAML/JSON output when unset.
type Resources struct {
// CPU, Memory and GPU are kept as strings; a nil pointer means "not set".
CPU *string `yaml:"cpu,omitempty" json:"cpu,omitempty"`
Memory *string `yaml:"memory,omitempty" json:"memory,omitempty"`
GPU *string `yaml:"gpu,omitempty" json:"gpu,omitempty"`
// Custom carries additional name/value pairs.
// NOTE(review): presumably extended resource names mapped to quantities — confirm.
Custom map[string]string `yaml:"custom,omitempty" json:"custom,omitempty"`
}
type DynDeploymentConfig = map[string]*DynDeploymentServiceConfig
// ServiceConfig represents the configuration for a specific service
......@@ -147,6 +149,7 @@ func GenerateDynamoComponentsDeployments(ctx context.Context, parentDynamoGraphD
deployment := &v1alpha1.DynamoComponentDeployment{}
deployment.Spec.DynamoComponentDeploymentSharedSpec = component.DynamoComponentDeploymentSharedSpec
deployment.Name = GetDynamoComponentName(parentDynamoGraphDeployment, componentName)
deployment.Spec.BackendFramework = parentDynamoGraphDeployment.Spec.BackendFramework
deployment.Namespace = parentDynamoGraphDeployment.Namespace
deployment.Spec.ServiceName = componentName
dynamoNamespace := GetDefaultDynamoNamespace(ctx, parentDynamoGraphDeployment)
......@@ -328,152 +331,108 @@ type SecretsRetriever interface {
GetSecrets(namespace, registry string) ([]string, error)
}
func GenerateGrovePodGangSet(ctx context.Context, dynamoDeployment *v1alpha1.DynamoGraphDeployment, controllerConfig controller_common.Config, secretsRetriever SecretsRetriever) (*grovev1alpha1.PodGangSet, error) {
gangSet := &grovev1alpha1.PodGangSet{}
gangSet.Name = dynamoDeployment.Name
gangSet.Namespace = dynamoDeployment.Namespace
gangSet.Spec.Replicas = 1
if controllerConfig.Grove.TerminationDelay > 0 {
gangSet.Spec.Template.TerminationDelay = &metav1.Duration{Duration: controllerConfig.Grove.TerminationDelay}
}
for componentName, component := range dynamoDeployment.Spec.Services {
container := corev1.Container{
Name: "main",
LivenessProbe: component.LivenessProbe,
ReadinessProbe: component.ReadinessProbe,
Env: component.Envs,
Ports: []corev1.ContainerPort{
{
Protocol: corev1.ProtocolTCP,
Name: commonconsts.DynamoContainerPortName,
ContainerPort: int32(commonconsts.DynamoServicePort),
},
},
// getNumberOfNodes extracts the numberOfNodes from resources.nodes
func getNumberOfNodes(resources *common.Resources) int32 {
if resources != nil && resources.Requests != nil && resources.Requests.Nodes != "" {
if nodes, err := strconv.ParseInt(resources.Requests.Nodes, 10, 32); err == nil {
return int32(nodes)
}
// Add system port for worker components
if component.ComponentType == commonconsts.ComponentTypeWorker {
container.Ports = append(container.Ports, corev1.ContainerPort{
Protocol: corev1.ProtocolTCP,
Name: commonconsts.DynamoSystemPortName,
ContainerPort: int32(commonconsts.DynamoSystemPort),
})
}
if resources != nil && resources.Limits != nil && resources.Limits.Nodes != "" {
if nodes, err := strconv.ParseInt(resources.Limits.Nodes, 10, 32); err == nil {
return int32(nodes)
}
}
return 1 // Default to single node
}
resourcesConfig, err := controller_common.GetResourcesConfig(component.Resources)
if err != nil {
return nil, fmt.Errorf("failed to get resources config: %w", err)
}
container.Resources = *resourcesConfig
if component.ExtraPodSpec != nil && component.ExtraPodSpec.MainContainer != nil {
// merge the extraPodSpec from the parent deployment with the extraPodSpec from the service
err := mergo.Merge(&container, *component.ExtraPodSpec.MainContainer.DeepCopy(), mergo.WithOverride)
if err != nil {
return nil, fmt.Errorf("failed to merge extraPodSpec: %w", err)
}
}
// retrieve the image pull secrets for the container
imagePullSecrets := []corev1.LocalObjectReference{}
if secretsRetriever != nil {
secretsName, err := secretsRetriever.GetSecrets(dynamoDeployment.Namespace, container.Image)
if err != nil {
return nil, fmt.Errorf("failed to get secrets for component %s and image %s: %w", componentName, container.Image, err)
}
for _, secretName := range secretsName {
imagePullSecrets = append(imagePullSecrets, corev1.LocalObjectReference{
Name: secretName,
})
}
}
// merge the envs from the parent deployment with the envs from the service
if len(dynamoDeployment.Spec.Envs) > 0 {
container.Env = MergeEnvs(dynamoDeployment.Spec.Envs, container.Env)
}
container.Env = append(container.Env, corev1.EnvVar{
Name: commonconsts.EnvDynamoServicePort,
Value: fmt.Sprintf("%d", commonconsts.DynamoServicePort),
})
if controllerConfig.NatsAddress != "" {
container.Env = append(container.Env, corev1.EnvVar{
Name: "NATS_SERVER",
Value: controllerConfig.NatsAddress,
})
}
if controllerConfig.EtcdAddress != "" {
container.Env = append(container.Env, corev1.EnvVar{
Name: "ETCD_ENDPOINTS",
Value: controllerConfig.EtcdAddress,
})
}
if component.EnvFromSecret != nil {
container.EnvFrom = append(container.EnvFrom, corev1.EnvFromSource{
SecretRef: &corev1.SecretEnvSource{
LocalObjectReference: corev1.LocalObjectReference{Name: *component.EnvFromSecret},
},
})
}
gangSet.Spec.Template.Cliques = append(gangSet.Spec.Template.Cliques, &grovev1alpha1.PodCliqueTemplateSpec{
Name: strings.ToLower(componentName),
Labels: map[string]string{
commonconsts.KubeLabelDynamoSelector: GetDynamoComponentName(dynamoDeployment, componentName),
},
Spec: grovev1alpha1.PodCliqueSpec{
RoleName: strings.ToLower(componentName),
Replicas: func() int32 {
if component.Replicas != nil {
return *component.Replicas
}
return 1
}(),
PodSpec: corev1.PodSpec{
Containers: []corev1.Container{container},
ImagePullSecrets: imagePullSecrets,
},
},
})
// applyCliqueStartupDependencies configures StartsAfter dependencies for cliques in a PodGangSet
// based on the backend framework and multinode deployment patterns.
//
// Rules:
// - For VLLM and SGLang: worker cliques start after leader clique
// - For TRTLLM: leader clique starts after worker cliques
// - Only applies to multinode deployments (numberOfNodes > 1)
// - Sets the PodGangSet StartupType to Explicit if any dependencies are configured
func applyCliqueStartupDependencies(
gangSet *grovev1alpha1.PodGangSet,
roles []ServiceRole,
backendFramework BackendFramework,
numberOfNodes int32,
) {
if numberOfNodes <= 1 {
return // No dependencies for single-node deployments
}
// Add metrics labels if not disabled
cliqueIndex := len(gangSet.Spec.Template.Cliques) - 1
labels := gangSet.Spec.Template.Cliques[cliqueIndex].Labels
// Build maps of leader and worker clique names
var leaderCliqueName string
var workerCliqueNames []string
// Convert user-provided metrics annotation into controller-managed label
// By default (no annotation), metrics are enabled
metricsAnnotationValue := ""
if dynamoDeployment.Annotations != nil {
metricsAnnotationValue = dynamoDeployment.Annotations[commonconsts.KubeAnnotationEnableMetrics]
for _, r := range roles {
cliqueName := strings.ToLower(r.Name)
switch r.Role {
case RoleLeader:
leaderCliqueName = cliqueName
case RoleWorker:
workerCliqueNames = append(workerCliqueNames, cliqueName)
}
switch metricsAnnotationValue {
case commonconsts.KubeLabelValueFalse:
// Explicitly disabled, don't add the label
default:
// Any other value (including empty) enables metrics
labels[commonconsts.KubeLabelMetricsEnabled] = commonconsts.KubeLabelValueTrue
}
// Apply dependencies to cliques
hasDependencies := false
for _, clique := range gangSet.Spec.Template.Cliques {
// Find the corresponding role for this clique
var cliqueRole Role
for _, r := range roles {
if strings.ToLower(r.Name) == clique.Name {
cliqueRole = r.Role
break
}
}
// Add component type label if specified
if component.ComponentType != "" {
labels[commonconsts.KubeLabelDynamoComponentType] = component.ComponentType
// Determine dependencies for this clique
startsAfter := getCliqueStartupDependencies(cliqueRole, backendFramework, leaderCliqueName, workerCliqueNames)
if len(startsAfter) > 0 {
clique.Spec.StartsAfter = startsAfter
hasDependencies = true
}
}
gangSet.Spec.Template.Cliques[cliqueIndex].Labels = labels
// Set explicit startup type if we have any dependencies
if hasDependencies {
explicitStartupType := grovev1alpha1.CliqueStartupTypeExplicit
gangSet.Spec.Template.StartupType = &explicitStartupType
}
}
if component.PVC != nil {
cliqueIndex := len(gangSet.Spec.Template.Cliques) - 1
gangSet.Spec.Template.Cliques[cliqueIndex].Spec.PodSpec.Volumes = append(gangSet.Spec.Template.Cliques[cliqueIndex].Spec.PodSpec.Volumes, corev1.Volume{
Name: *component.PVC.Name,
VolumeSource: corev1.VolumeSource{
PersistentVolumeClaim: &corev1.PersistentVolumeClaimVolumeSource{
ClaimName: *component.PVC.Name,
},
},
})
gangSet.Spec.Template.Cliques[cliqueIndex].Spec.PodSpec.Containers[0].VolumeMounts = append(gangSet.Spec.Template.Cliques[cliqueIndex].Spec.PodSpec.Containers[0].VolumeMounts, corev1.VolumeMount{
Name: *component.PVC.Name,
MountPath: *component.PVC.MountPoint,
})
// getCliqueStartupDependencies determines the StartsAfter dependencies for a clique
// based on its role, backend framework, and available leader/worker clique names.
//
// Rules:
// - For VLLM and SGLang: worker cliques start after leader clique
// - For TRTLLM: leader clique starts after worker cliques
// - For other backends or single-node deployments: no dependencies
func getCliqueStartupDependencies(
role Role,
backendFramework BackendFramework,
leaderCliqueName string,
workerCliqueNames []string,
) []string {
switch backendFramework {
case BackendFrameworkVLLM, BackendFrameworkSGLang:
// For vllm and sglang: worker cliques start after leader clique
if role == RoleWorker && leaderCliqueName != "" {
return []string{leaderCliqueName}
}
case BackendFrameworkTRTLLM:
// For trtllm: leader clique starts after worker cliques
if role == RoleLeader && len(workerCliqueNames) > 0 {
return workerCliqueNames
}
}
return gangSet, nil
// No dependencies for other cases
return nil
}
func GenerateComponentService(ctx context.Context, componentName, componentNamespace string) (*corev1.Service, error) {
......@@ -564,32 +523,34 @@ func GenerateComponentVirtualService(ctx context.Context, componentName, compone
Namespace: componentNamespace,
},
}
vs.Spec = istioNetworking.VirtualService{
Hosts: []string{
getIngressHost(ingressSpec),
},
Gateways: []string{*ingressSpec.VirtualServiceGateway},
Http: []*istioNetworking.HTTPRoute{
{
Match: []*istioNetworking.HTTPMatchRequest{
{
Uri: &istioNetworking.StringMatch{
MatchType: &istioNetworking.StringMatch_Prefix{Prefix: "/"},
if ingressSpec.IsVirtualServiceEnabled() {
vs.Spec = istioNetworking.VirtualService{
Hosts: []string{
getIngressHost(ingressSpec),
},
Gateways: []string{*ingressSpec.VirtualServiceGateway},
Http: []*istioNetworking.HTTPRoute{
{
Match: []*istioNetworking.HTTPMatchRequest{
{
Uri: &istioNetworking.StringMatch{
MatchType: &istioNetworking.StringMatch_Prefix{Prefix: "/"},
},
},
},
},
Route: []*istioNetworking.HTTPRouteDestination{
{
Destination: &istioNetworking.Destination{
Host: componentName,
Port: &istioNetworking.PortSelector{
Number: commonconsts.DynamoServicePort,
Route: []*istioNetworking.HTTPRouteDestination{
{
Destination: &istioNetworking.Destination{
Host: componentName,
Port: &istioNetworking.PortSelector{
Number: commonconsts.DynamoServicePort,
},
},
},
},
},
},
},
}
}
return vs
}
......@@ -616,3 +577,602 @@ func GenerateDefaultIngressSpec(dynamoDeployment *v1alpha1.DynamoGraphDeployment
}
return res
}
// mergeContainerCommand picks the command to run in a container: the
// user-supplied command wins whenever it has at least one element, otherwise
// the default command is returned. The chosen slice is returned as-is (no copy).
func mergeContainerCommand(defaultCmd, userCmd []string) []string {
	if len(userCmd) == 0 {
		return defaultCmd
	}
	return userCmd
}
// Role identifies the part a set of pods plays within a service.
// Use this type instead of a bare string wherever a role is passed around.
type Role string
const (
// RoleLeader and RoleWorker are the two roles a multinode service is split into.
RoleLeader Role = "leader"
RoleWorker Role = "worker"
// RoleMain is the single role of a service that is not split across nodes.
RoleMain Role = "main"
)
// ServiceRole describes one role that a service expands into, as produced by
// expandRolesForService.
type ServiceRole struct {
// Name is the role-qualified service name (the service name, optionally
// followed by a "-<suffix>" for leader/worker roles).
Name string
// Role is the kind of role (leader, worker or main).
Role Role
// Replicas is the number of pods requested for this role.
Replicas int32
}
// expandRolesForService returns the roles a service is split into.
//
// When numberOfNodes is greater than 1 the service becomes a leader role with a
// single replica plus a worker role with numberOfNodes-1 replicas; the role
// names are the service name followed by "-" and the Grove leader/worker
// suffix. Otherwise a single main role named after the service is returned,
// using *serviceReplicas as its replica count.
//
// serviceReplicas may be nil; in that case the main role defaults to one
// replica instead of dereferencing a nil pointer.
func expandRolesForService(serviceName string, serviceReplicas *int32, numberOfNodes int32) []ServiceRole {
	if numberOfNodes > 1 {
		return []ServiceRole{
			{Name: serviceName + "-" + commonconsts.GroveRoleSuffixLeader, Role: RoleLeader, Replicas: 1},
			{Name: serviceName + "-" + commonconsts.GroveRoleSuffixWorker, Role: RoleWorker, Replicas: numberOfNodes - 1},
		}
	}

	// Replicas is optional on the service spec: fall back to a single replica.
	replicas := int32(1)
	if serviceReplicas != nil {
		replicas = *serviceReplicas
	}
	return []ServiceRole{{Name: serviceName, Role: RoleMain, Replicas: replicas}}
}
// BackendFramework names the inference backend a component runs on.
// The values below are the frameworks with a dedicated Backend implementation.
type BackendFramework string
const (
BackendFrameworkSGLang BackendFramework = "sglang"
BackendFrameworkVLLM BackendFramework = "vllm"
BackendFrameworkTRTLLM BackendFramework = "trtllm"
)
// Backend is the extension point for backend-specific pod customization.
// Each backend (SGLang, VLLM, etc.) implements this interface.
//
// Both methods receive the object to modify in place together with the node
// count, the role being generated, the component spec, the multinode
// deployment type and the service name.
type Backend interface {
// UpdateContainer applies backend-specific changes to the main container.
UpdateContainer(container *corev1.Container, numberOfNodes int32, role Role, component *v1alpha1.DynamoComponentDeploymentOverridesSpec, multinodeDeploymentType commonconsts.MultinodeDeploymentType, serviceName string)
// UpdatePodSpec applies backend-specific changes to the assembled pod spec.
UpdatePodSpec(podSpec *corev1.PodSpec, numberOfNodes int32, role Role, component *v1alpha1.DynamoComponentDeploymentOverridesSpec, multinodeDeploymentType commonconsts.MultinodeDeploymentType, serviceName string)
}
// NoopBackend is a Backend whose methods leave their inputs untouched. It is
// used for non-worker components like frontend, planner, router.
type NoopBackend struct{}
// UpdateContainer implements Backend and intentionally does nothing.
func (b *NoopBackend) UpdateContainer(container *corev1.Container, numberOfNodes int32, role Role, component *v1alpha1.DynamoComponentDeploymentOverridesSpec, multinodeDeploymentType commonconsts.MultinodeDeploymentType, serviceName string) {
// No-op: frontend, planner, router, etc. don't need backend-specific processing
}
// UpdatePodSpec implements Backend and intentionally does nothing.
func (b *NoopBackend) UpdatePodSpec(podSpec *corev1.PodSpec, numberOfNodes int32, role Role, component *v1alpha1.DynamoComponentDeploymentOverridesSpec, multinodeDeploymentType commonconsts.MultinodeDeploymentType, serviceName string) {
// No-op: frontend, planner, router, etc. don't need backend-specific processing
}
// BackendFactory returns a freshly allocated Backend for the given framework.
// An unrecognized framework yields a nil Backend, which callers must check for.
func BackendFactory(backendFramework BackendFramework) Backend {
	// Zero value of the interface is nil; it stays nil for unknown frameworks.
	var selected Backend
	switch backendFramework {
	case BackendFrameworkNoop:
		selected = &NoopBackend{}
	case BackendFrameworkTRTLLM:
		selected = &TRTLLMBackend{}
	case BackendFrameworkVLLM:
		selected = &VLLMBackend{}
	case BackendFrameworkSGLang:
		selected = &SGLangBackend{}
	}
	return selected
}
// isWorkerComponent reports whether componentType is the worker component
// type, i.e. a component that needs backend framework detection.
func isWorkerComponent(componentType string) bool {
	switch componentType {
	case commonconsts.ComponentTypeWorker:
		return true
	default:
		return false
	}
}
// addStandardEnvVars appends the environment variables shared by the Grove and
// controller code paths to container.Env, in this order:
//   - the Dynamo service port (always);
//   - NATS_SERVER, only when controllerConfig.NatsAddress is non-empty;
//   - ETCD_ENDPOINTS, only when controllerConfig.EtcdAddress is non-empty.
//
// Existing entries in container.Env are kept and are not de-duplicated.
func addStandardEnvVars(container *corev1.Container, controllerConfig controller_common.Config) {
	standard := []corev1.EnvVar{{
		Name:  commonconsts.EnvDynamoServicePort,
		Value: fmt.Sprintf("%d", commonconsts.DynamoServicePort),
	}}
	if natsAddr := controllerConfig.NatsAddress; natsAddr != "" {
		standard = append(standard, corev1.EnvVar{Name: "NATS_SERVER", Value: natsAddr})
	}
	if etcdAddr := controllerConfig.EtcdAddress; etcdAddr != "" {
		standard = append(standard, corev1.EnvVar{Name: "ETCD_ENDPOINTS", Value: etcdAddr})
	}
	container.Env = append(container.Env, standard...)
}
// GenerateBasePodSpec creates a basic PodSpec with common logic shared between controller and grove
// Includes standard environment variables (DYNAMO_PORT, NATS_SERVER, ETCD_ENDPOINTS)
// Deployment-specific environment merging should be handled by the caller
//
// The steps, in order, are:
//  1. build the "main" container from the component's probes, envs and the
//     Dynamo service port (plus the system port for worker components);
//  2. overlay component.ExtraPodSpec.MainContainer onto it, then re-merge envs;
//  3. apply the resources config, image pull secrets, EnvFromSecret and the
//     standard env vars;
//  4. add the optional PVC volume/mount and the shared-memory volume/mount;
//  5. let the backend selected by backendFramework adjust the container;
//  6. assemble the pod spec on top of component.ExtraPodSpec.PodSpec (if any)
//     and let the backend adjust the pod spec.
//
// An error is returned when the extraPodSpec merge fails, the resources config
// cannot be built, or backendFramework has no registered backend; in those
// cases the returned PodSpec is the zero value.
func GenerateBasePodSpec(
component *v1alpha1.DynamoComponentDeploymentOverridesSpec,
backendFramework BackendFramework,
secretsRetriever SecretsRetriever,
namespace string,
role Role,
numberOfNodes int32,
controllerConfig controller_common.Config,
multinodeDeploymentType commonconsts.MultinodeDeploymentType,
serviceName string,
) (corev1.PodSpec, error) {
container := corev1.Container{
Name: "main",
LivenessProbe: component.LivenessProbe,
ReadinessProbe: component.ReadinessProbe,
Env: component.Envs,
Ports: []corev1.ContainerPort{
{
Protocol: corev1.ProtocolTCP,
Name: commonconsts.DynamoContainerPortName,
ContainerPort: int32(commonconsts.DynamoServicePort),
},
},
}
// Add system port for worker components
if component.ComponentType == commonconsts.ComponentTypeWorker {
container.Ports = append(container.Ports, corev1.ContainerPort{
Protocol: corev1.ProtocolTCP,
Name: commonconsts.DynamoSystemPortName,
ContainerPort: int32(commonconsts.DynamoSystemPort),
})
}
// First merge the mainContainer from extraPodSpec to get the base command and args
if component.ExtraPodSpec != nil && component.ExtraPodSpec.MainContainer != nil {
// Merge from a deep copy so the component's own MainContainer is not aliased.
main := component.ExtraPodSpec.MainContainer.DeepCopy()
if main != nil {
// merge the extraPodSpec from the parent deployment with the extraPodSpec from the service
err := mergo.Merge(&container, *main, mergo.WithOverride)
if err != nil {
return corev1.PodSpec{}, fmt.Errorf("failed to merge extraPodSpec: %w", err)
}
// The override merge may have replaced container.Env wholesale, so the
// component-level envs are merged back in afterwards.
// NOTE(review): precedence between the two lists depends on MergeEnvs,
// which is not visible here — confirm.
container.Env = MergeEnvs(component.Envs, container.Env)
}
}
resourcesConfig, err := controller_common.GetResourcesConfig(component.Resources)
if err != nil {
return corev1.PodSpec{}, fmt.Errorf("failed to get resources config: %w", err)
}
if resourcesConfig != nil {
container.Resources = *resourcesConfig
}
// Image pull secrets are looked up only when a retriever is supplied and the
// extraPodSpec main container names an image.
imagePullSecrets := []corev1.LocalObjectReference{}
if secretsRetriever != nil && component.ExtraPodSpec != nil && component.ExtraPodSpec.MainContainer != nil && component.ExtraPodSpec.MainContainer.Image != "" {
secretsName, err := secretsRetriever.GetSecrets(namespace, component.ExtraPodSpec.MainContainer.Image)
// NOTE(review): a GetSecrets error is ignored here and the pod spec is
// built without pull secrets — confirm this best-effort behavior is intended.
if err == nil {
for _, secretName := range secretsName {
imagePullSecrets = append(imagePullSecrets, corev1.LocalObjectReference{Name: secretName})
}
}
}
if component.EnvFromSecret != nil {
container.EnvFrom = append(container.EnvFrom, corev1.EnvFromSource{
SecretRef: &corev1.SecretEnvSource{
LocalObjectReference: corev1.LocalObjectReference{Name: *component.EnvFromSecret},
},
})
}
addStandardEnvVars(&container, controllerConfig)
var volumes []corev1.Volume
if component.PVC != nil {
// NOTE(review): PVC.Name and PVC.MountPoint are dereferenced without a nil
// check — presumably validated upstream; confirm.
volumes = append(volumes, corev1.Volume{
Name: *component.PVC.Name,
VolumeSource: corev1.VolumeSource{
PersistentVolumeClaim: &corev1.PersistentVolumeClaimVolumeSource{
ClaimName: *component.PVC.Name,
},
},
})
container.VolumeMounts = append(container.VolumeMounts, corev1.VolumeMount{
Name: *component.PVC.Name,
MountPath: *component.PVC.MountPoint,
})
}
// The shared-memory volume and mount are always added.
shmVolume, shmVolumeMount := generateSharedMemoryVolumeAndMount(&container.Resources)
volumes = append(volumes, shmVolume)
container.VolumeMounts = append(container.VolumeMounts, shmVolumeMount)
// Apply backend-specific container modifications
backend := BackendFactory(backendFramework)
if backend == nil {
return corev1.PodSpec{}, fmt.Errorf("unsupported backend framework: %s", backendFramework)
}
backend.UpdateContainer(&container, numberOfNodes, role, component, multinodeDeploymentType, serviceName)
// Start from a deep copy of the user-provided pod spec (if any); the generated
// container, volumes and pull secrets are appended after whatever it already holds.
var podSpec corev1.PodSpec
if component.ExtraPodSpec != nil && component.ExtraPodSpec.PodSpec != nil {
podSpec = *component.ExtraPodSpec.PodSpec.DeepCopy()
}
podSpec.Containers = append(podSpec.Containers, container)
podSpec.Volumes = append(podSpec.Volumes, volumes...)
podSpec.ImagePullSecrets = append(podSpec.ImagePullSecrets, imagePullSecrets...)
backend.UpdatePodSpec(&podSpec, numberOfNodes, role, component, multinodeDeploymentType, serviceName)
return podSpec, nil
}
// setMetricsLabels translates the user-facing metrics annotation on the
// deployment into the controller-managed metrics label.
//
// Metrics are enabled by default: the label is set to the "true" value unless
// the annotation is present and equal to the "false" value, in which case
// labels is left untouched. labels must be non-nil.
func setMetricsLabels(labels map[string]string, dynamoGraphDeployment *v1alpha1.DynamoGraphDeployment) {
	annotation, present := dynamoGraphDeployment.Annotations[commonconsts.KubeAnnotationEnableMetrics]
	explicitlyDisabled := present && annotation == commonconsts.KubeLabelValueFalse
	if !explicitlyDisabled {
		// Any other value (including empty or missing) enables metrics.
		labels[commonconsts.KubeLabelMetricsEnabled] = commonconsts.KubeLabelValueTrue
	}
}
// GeneratePodSpecForComponent creates a PodSpec for Grove deployments.
//
// It is a thin wrapper around GenerateBasePodSpec that first merges the
// deployment-level envs (dynamoDeployment.Spec.Envs) with the component's own
// envs, and uses the deployment's namespace for the secret lookup.
//
// The merge is applied to a shallow copy of component, so the caller's spec
// (typically a pointer into dynamoDeployment.Spec.Services) is not rewritten
// as a side effect of generating a pod spec.
//
// On error the returned PodSpec is the zero value.
func GeneratePodSpecForComponent(
	component *v1alpha1.DynamoComponentDeploymentOverridesSpec,
	backendFramework BackendFramework,
	secretsRetriever SecretsRetriever,
	dynamoDeployment *v1alpha1.DynamoGraphDeployment,
	role Role,
	numberOfNodes int32,
	controllerConfig controller_common.Config,
	multinodeDeploymentType commonconsts.MultinodeDeploymentType,
	serviceName string,
) (corev1.PodSpec, error) {
	if len(dynamoDeployment.Spec.Envs) > 0 {
		// Only the Envs field is replaced on the copy; every other field still
		// shares its underlying data with the original, which is only read.
		withDeploymentEnvs := *component
		withDeploymentEnvs.Envs = MergeEnvs(dynamoDeployment.Spec.Envs, component.Envs)
		component = &withDeploymentEnvs
	}
	podSpec, err := GenerateBasePodSpec(component, backendFramework, secretsRetriever, dynamoDeployment.Namespace, role, numberOfNodes, controllerConfig, multinodeDeploymentType, serviceName)
	if err != nil {
		return corev1.PodSpec{}, err
	}
	return podSpec, nil
}
// GenerateGrovePodGangSet generates a Grove PodGangSet for the given deployment, supporting both single-node and multinode cases.
//
// The PodGangSet takes its name and namespace from dynamoDeployment, has a
// single replica, and publishes not-ready addresses on its headless service.
// A termination delay is set only when controllerConfig.Grove.TerminationDelay
// is positive.
//
// For every service in dynamoDeployment.Spec.Services:
//   - the backend framework and node count are determined;
//   - the service is expanded into roles (leader/worker when multinode, a
//     single main role otherwise) and one clique is generated per role, with
//     lower-cased names, generated labels and annotations;
//   - startup dependencies between the service's cliques are applied;
//   - multinode services additionally get a scaling group covering their
//     cliques.
//
// Services are processed in sorted name order so that the resulting cliques
// and scaling groups are identical from one call to the next; ranging over the
// map directly would yield a different order on each call.
//
// The first error encountered aborts generation and is returned wrapped with
// the service or role it relates to.
func GenerateGrovePodGangSet(
	ctx context.Context,
	dynamoDeployment *v1alpha1.DynamoGraphDeployment,
	controllerConfig controller_common.Config,
	secretsRetriever SecretsRetriever,
) (*grovev1alpha1.PodGangSet, error) {
	gangSet := &grovev1alpha1.PodGangSet{}
	gangSet.Name = dynamoDeployment.Name
	gangSet.Namespace = dynamoDeployment.Namespace
	gangSet.Spec.Replicas = 1
	gangSet.Spec.Template.HeadlessServiceConfig = &grovev1alpha1.HeadlessServiceConfig{
		PublishNotReadyAddresses: true,
	}
	if controllerConfig.Grove.TerminationDelay > 0 {
		gangSet.Spec.Template.TerminationDelay = &metav1.Duration{Duration: controllerConfig.Grove.TerminationDelay}
	}

	// Go randomizes map iteration order; sort the keys for stable output.
	serviceNames := make([]string, 0, len(dynamoDeployment.Spec.Services))
	for serviceName := range dynamoDeployment.Spec.Services {
		serviceNames = append(serviceNames, serviceName)
	}
	sort.Strings(serviceNames)

	var scalingGroups []grovev1alpha1.PodCliqueScalingGroupConfig
	for _, serviceName := range serviceNames {
		component := dynamoDeployment.Spec.Services[serviceName]
		// Determine backend framework using hybrid approach
		backendFramework, err := getBackendFrameworkFromComponent(component, dynamoDeployment)
		if err != nil {
			return nil, fmt.Errorf("failed to determine backend framework for service %s: %w", serviceName, err)
		}
		numberOfNodes := getNumberOfNodes(component.Resources)
		isMultinode := numberOfNodes > 1
		roles := expandRolesForService(serviceName, component.Replicas, numberOfNodes)
		var cliqueNames []string
		for _, r := range roles {
			podSpec, err := GeneratePodSpecForComponent(
				component,
				backendFramework,
				secretsRetriever,
				dynamoDeployment,
				r.Role,
				numberOfNodes,
				controllerConfig,
				commonconsts.MultinodeDeploymentTypeGrove,
				serviceName,
			)
			if err != nil {
				return nil, fmt.Errorf("failed to generate podSpec for role %s: %w", r.Name, err)
			}
			clique := &grovev1alpha1.PodCliqueTemplateSpec{
				Name: strings.ToLower(r.Name),
				Spec: grovev1alpha1.PodCliqueSpec{
					RoleName: strings.ToLower(r.Name),
					Replicas: r.Replicas,
					PodSpec:  podSpec,
				},
			}
			labels, err := generateLabels(component, dynamoDeployment, r.Name)
			if err != nil {
				return nil, fmt.Errorf("failed to generate labels: %w", err)
			}
			clique.Labels = labels
			annotations, err := generateAnnotations(component)
			if err != nil {
				return nil, fmt.Errorf("failed to generate annotations: %w", err)
			}
			clique.Annotations = annotations
			gangSet.Spec.Template.Cliques = append(gangSet.Spec.Template.Cliques, clique)
			cliqueNames = append(cliqueNames, strings.ToLower(r.Name))
		}
		// Apply startup dependencies for this service
		applyCliqueStartupDependencies(gangSet, roles, backendFramework, numberOfNodes)
		if isMultinode {
			// A multinode service scales as a unit: its leader and worker cliques
			// are grouped so the service's replica count applies to the group.
			scalingGroups = append(scalingGroups, grovev1alpha1.PodCliqueScalingGroupConfig{
				Name:        strings.ToLower(serviceName),
				CliqueNames: cliqueNames,
				Replicas:    component.Replicas,
			})
		}
	}
	if len(scalingGroups) > 0 {
		gangSet.Spec.Template.PodCliqueScalingGroupConfigs = scalingGroups
	}
	return gangSet, nil
}
// generateLabels builds the label set for a component's pods.
//
// It starts with the Dynamo selector label (derived from the deployment and
// componentName), adds the component-type label when the component declares a
// type, and applies the metrics label via setMetricsLabels. The component's
// own labels and then its extra pod metadata labels are merged on top, each
// overriding keys set earlier. A merge failure is returned wrapped, with a nil
// map.
func generateLabels(component *v1alpha1.DynamoComponentDeploymentOverridesSpec, dynamoDeployment *v1alpha1.DynamoGraphDeployment, componentName string) (map[string]string, error) {
	labels := map[string]string{}
	labels[commonconsts.KubeLabelDynamoSelector] = GetDynamoComponentName(dynamoDeployment, componentName)
	if componentType := component.ComponentType; componentType != "" {
		labels[commonconsts.KubeLabelDynamoComponentType] = componentType
	}
	setMetricsLabels(labels, dynamoDeployment)

	if component.Labels != nil {
		if err := mergo.Merge(&labels, component.Labels, mergo.WithOverride); err != nil {
			return nil, fmt.Errorf("failed to merge labels: %w", err)
		}
	}
	if extra := component.ExtraPodMetadata; extra != nil {
		if err := mergo.Merge(&labels, extra.Labels, mergo.WithOverride); err != nil {
			return nil, fmt.Errorf("failed to merge extraPodMetadata labels: %w", err)
		}
	}
	return labels, nil
}
// generateAnnotations assembles the annotations attached to a component's pod
// clique.
//
// It starts from an empty map, merges in the component's Annotations and then
// the annotations from its ExtraPodMetadata, with later sources overriding
// earlier ones. The returned map is never nil on success.
//
// It returns an error only when one of the merges fails.
func generateAnnotations(component *v1alpha1.DynamoComponentDeploymentOverridesSpec) (map[string]string, error) {
	result := make(map[string]string)

	if component.Annotations != nil {
		if err := mergo.Merge(&result, component.Annotations, mergo.WithOverride); err != nil {
			return nil, fmt.Errorf("failed to merge annotations: %w", err)
		}
	}

	// Extra pod metadata is merged last, so it wins over component annotations.
	if extra := component.ExtraPodMetadata; extra != nil {
		if err := mergo.Merge(&result, extra.Annotations, mergo.WithOverride); err != nil {
			return nil, fmt.Errorf("failed to merge extraPodMetadata annotations: %w", err)
		}
	}
	return result, nil
}
// backendFrameworkModulePatterns pairs each backend framework with the regular
// expression that recognizes a `python -m dynamo.<backend>...` invocation.
// The character class [^|&;] keeps a single match from spanning shell
// separators, so `-m` and the module name must belong to the same command.
//
// The patterns are compiled once, and the slice (rather than a map) fixes the
// order in which frameworks are checked and reported.
var backendFrameworkModulePatterns = []struct {
	framework BackendFramework
	pattern   *regexp.Regexp
}{
	{BackendFrameworkVLLM, regexp.MustCompile(`python[0-9.]*\s+[^|&;]*-m\s+[^|&;]*dynamo\.vllm[^|&;]*`)},
	{BackendFrameworkSGLang, regexp.MustCompile(`python[0-9.]*\s+[^|&;]*-m\s+[^|&;]*dynamo\.sglang[^|&;]*`)},
	{BackendFrameworkTRTLLM, regexp.MustCompile(`python[0-9.]*\s+[^|&;]*-m\s+[^|&;]*dynamo\.trtllm[^|&;]*`)},
}

// detectBackendFrameworkFromArgs detects the backend framework from command/args.
//
// command and args are joined with single spaces into one string, which is
// then matched against the pattern of every known backend framework.
//
// It returns the framework when exactly one pattern matches. It returns an
// error when no pattern matches, or when more than one does; in the latter
// case the frameworks are listed in a fixed order (vLLM, SGLang, TRT-LLM).
//
// Neither input slice is modified.
func detectBackendFrameworkFromArgs(command []string, args []string) (BackendFramework, error) {
	// Copy into a fresh slice: append(command, args...) could write into spare
	// capacity of the caller's command slice and corrupt data sharing its
	// backing array.
	allParts := make([]string, 0, len(command)+len(args))
	allParts = append(allParts, command...)
	allParts = append(allParts, args...)
	fullCommand := strings.Join(allParts, " ")

	var detected []BackendFramework
	for _, candidate := range backendFrameworkModulePatterns {
		if candidate.pattern.MatchString(fullCommand) {
			detected = append(detected, candidate.framework)
		}
	}

	switch len(detected) {
	case 0:
		return "", fmt.Errorf("no backend framework detected from command: %q", fullCommand)
	case 1:
		return detected[0], nil
	default:
		return "", fmt.Errorf("multiple backend frameworks detected from command: %v in %q", detected, fullCommand)
	}
}
// BackendFrameworkNoop is the backend framework value used when no
// backend-specific processing is needed. determineBackendFramework returns it
// for every component that is not a worker.
const BackendFrameworkNoop BackendFramework = "noop"
// determineBackendFramework resolves the backend framework of a component from
// already-extracted inputs, combining detection with explicit configuration.
//
// Resolution rules:
//   - Non-worker components always resolve to BackendFrameworkNoop.
//   - For workers, detection from command/args is attempted whenever either is
//     non-empty; a detection failure is remembered but is not fatal by itself.
//   - If both a detected and an explicit framework exist and they differ, an
//     error is returned.
//   - Otherwise the detected framework wins, then the explicit one.
//   - With neither available, the remembered detection error is wrapped and
//     returned, or a generic error when detection was never attempted.
func determineBackendFramework(
	componentType string,
	command []string,
	args []string,
	explicitBackendFramework string,
) (BackendFramework, error) {
	// Only worker components carry a backend framework.
	if !isWorkerComponent(componentType) {
		return BackendFrameworkNoop, nil
	}

	// Attempt detection only when there is something to inspect.
	var (
		detected  BackendFramework
		detectErr error
	)
	if len(command)+len(args) > 0 {
		detected, detectErr = detectBackendFrameworkFromArgs(command, args)
		if detectErr != nil {
			// A failed detection contributes no framework.
			detected = ""
		}
	}

	// An empty explicit setting converts to the empty framework, i.e. "unset".
	explicit := BackendFramework(explicitBackendFramework)

	switch {
	case detected != "" && explicit != "" && detected != explicit:
		return "", fmt.Errorf("backend framework mismatch: detected %q from command but explicitly configured as %q",
			detected, explicit)
	case detected != "":
		return detected, nil
	case explicit != "":
		return explicit, nil
	case detectErr != nil:
		return "", fmt.Errorf("could not determine backend framework: %w", detectErr)
	default:
		// Nothing to detect from and nothing configured.
		return "", fmt.Errorf("backend framework must be specified explicitly or detectable from command/args")
	}
}
// getBackendFrameworkFromComponent resolves the backend framework for one
// service of a DynamoGraphDeployment.
//
// It gathers the inputs and delegates the decision to
// determineBackendFramework:
//   - the component type, taken from the component;
//   - the main container's command and args, when the component defines an
//     ExtraPodSpec with a MainContainer (otherwise both are nil);
//   - the explicit backend framework, taken from the deployment spec.
func getBackendFrameworkFromComponent(
	component *v1alpha1.DynamoComponentDeploymentOverridesSpec,
	dynamoDeployment *v1alpha1.DynamoGraphDeployment,
) (BackendFramework, error) {
	var containerCommand, containerArgs []string
	if extra := component.ExtraPodSpec; extra != nil {
		if container := extra.MainContainer; container != nil {
			containerCommand, containerArgs = container.Command, container.Args
		}
	}

	return determineBackendFramework(
		component.ComponentType,
		containerCommand,
		containerArgs,
		dynamoDeployment.Spec.BackendFramework,
	)
}
// ConvertDynamoComponentDeploymentToSpec builds a
// DynamoComponentDeploymentOverridesSpec from a DynamoComponentDeployment so
// that the controller can reuse the backend logic written against that type.
//
// The shared spec is deep-copied, so the returned value does not alias
// dynComponent's shared spec. Only the shared spec is populated.
func ConvertDynamoComponentDeploymentToSpec(dynComponent *v1alpha1.DynamoComponentDeployment) *v1alpha1.DynamoComponentDeploymentOverridesSpec {
	sharedCopy := dynComponent.Spec.DynamoComponentDeploymentSharedSpec.DeepCopy()

	overrides := new(v1alpha1.DynamoComponentDeploymentOverridesSpec)
	overrides.DynamoComponentDeploymentSharedSpec = *sharedCopy
	return overrides
}
// getBackendFrameworkFromDynamoComponent resolves the backend framework for a
// standalone DynamoComponentDeployment.
//
// The component type, the main container's command/args (nil when the
// component has no ExtraPodSpec or no MainContainer) and the explicit backend
// framework are all read from the component's own spec and handed to
// determineBackendFramework.
func getBackendFrameworkFromDynamoComponent(dynComponent *v1alpha1.DynamoComponentDeployment) (BackendFramework, error) {
	var containerCommand, containerArgs []string
	if extra := dynComponent.Spec.ExtraPodSpec; extra != nil {
		if container := extra.MainContainer; container != nil {
			containerCommand, containerArgs = container.Command, container.Args
		}
	}

	return determineBackendFramework(
		dynComponent.Spec.ComponentType,
		containerCommand,
		containerArgs,
		dynComponent.Spec.BackendFramework,
	)
}
// GenerateBasePodSpecForController produces the base PodSpec for a
// DynamoComponentDeployment by reusing GenerateBasePodSpec, leaving any
// controller-specific additions to the caller.
//
// The component is first converted to the overrides-spec form, the node count
// is derived from its resources, and its backend framework is resolved. These
// are passed to GenerateBasePodSpec together with the component's namespace
// and the caller-supplied role, controller config and multinode deployment
// type. The component's own name is used as the service name.
//
// On any failure an empty PodSpec is returned alongside the error; a backend
// framework failure is wrapped with additional context.
func GenerateBasePodSpecForController(
	dynComponent *v1alpha1.DynamoComponentDeployment,
	secretsRetriever SecretsRetriever,
	controllerConfig controller_common.Config,
	role Role,
	multinodeDeploymentType commonconsts.MultinodeDeploymentType,
) (corev1.PodSpec, error) {
	overridesSpec := ConvertDynamoComponentDeploymentToSpec(dynComponent)
	nodeCount := getNumberOfNodes(dynComponent.Spec.DynamoComponentDeploymentSharedSpec.Resources)

	framework, frameworkErr := getBackendFrameworkFromDynamoComponent(dynComponent)
	if frameworkErr != nil {
		return corev1.PodSpec{}, fmt.Errorf("failed to determine backend framework: %w", frameworkErr)
	}

	generated, genErr := GenerateBasePodSpec(
		overridesSpec,
		framework,
		secretsRetriever,
		dynComponent.Namespace,
		role,
		nodeCount,
		controllerConfig,
		multinodeDeploymentType,
		// No separate service name is available here; fall back to the component name.
		dynComponent.Name,
	)
	if genErr != nil {
		// Never hand back a partially built PodSpec together with an error.
		return corev1.PodSpec{}, genErr
	}
	return generated, nil
}
// generateSharedMemoryVolumeAndMount returns a memory-backed emptyDir volume
// and the matching volume mount that exposes it at /dev/shm.
//
// The volume's size limit is one quarter of the container's memory limit,
// clamped to the range [512Mi, 8Gi]. When resources is nil, or no non-zero
// memory limit is set, the size limit is 512Mi.
func generateSharedMemoryVolumeAndMount(resources *corev1.ResourceRequirements) (corev1.Volume, corev1.VolumeMount) {
	minSize := resource.MustParse("512Mi")
	maxSize := resource.MustParse("8Gi")

	// Start from the lower bound; it is only raised when a memory limit justifies it.
	sharedMemorySizeLimit := minSize

	// Guard against a nil pointer: without it, reading Limits would panic.
	if resources != nil {
		memoryLimit := resources.Limits[corev1.ResourceMemory]
		if !memoryLimit.IsZero() {
			quarter := resource.NewQuantity(memoryLimit.Value()/4, resource.BinarySI)
			switch {
			case quarter.Cmp(maxSize) >= 0:
				sharedMemorySizeLimit = maxSize // cap at the upper bound
			case quarter.Cmp(minSize) > 0:
				sharedMemorySizeLimit = *quarter
			}
			// Otherwise quarter <= minSize and the 512Mi floor stays in effect.
		}
	}

	volume := corev1.Volume{
		Name: commonconsts.KubeValueNameSharedMemory,
		VolumeSource: corev1.VolumeSource{
			EmptyDir: &corev1.EmptyDirVolumeSource{
				Medium:    corev1.StorageMediumMemory,
				SizeLimit: &sharedMemorySizeLimit,
			},
		},
	}
	volumeMount := corev1.VolumeMount{
		Name:      commonconsts.KubeValueNameSharedMemory,
		MountPath: "/dev/shm",
	}
	return volume, volumeMount
}
......@@ -22,14 +22,13 @@ import (
"fmt"
"reflect"
"sort"
"strings"
"testing"
"time"
grovev1alpha1 "github.com/NVIDIA/grove/operator/api/core/v1alpha1"
"github.com/ai-dynamo/dynamo/deploy/cloud/operator/api/dynamo/common"
compounaiCommon "github.com/ai-dynamo/dynamo/deploy/cloud/operator/api/dynamo/common"
"github.com/ai-dynamo/dynamo/deploy/cloud/operator/api/v1alpha1"
nvidiacomv1alpha1 "github.com/ai-dynamo/dynamo/deploy/cloud/operator/api/v1alpha1"
commonconsts "github.com/ai-dynamo/dynamo/deploy/cloud/operator/internal/consts"
"github.com/ai-dynamo/dynamo/deploy/cloud/operator/internal/controller_common"
"github.com/google/go-cmp/cmp"
......@@ -37,6 +36,7 @@ import (
"k8s.io/apimachinery/pkg/api/resource"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/util/intstr"
ptr "k8s.io/utils/ptr"
)
func TestGenerateDynamoComponentsDeployments(t *testing.T) {
......@@ -65,8 +65,8 @@ func TestGenerateDynamoComponentsDeployments(t *testing.T) {
DynamoNamespace: &[]string{"default"}[0],
ComponentType: "main",
Replicas: &[]int32{3}[0],
Resources: &compounaiCommon.Resources{
Requests: &compounaiCommon.ResourceItem{
Resources: &common.Resources{
Requests: &common.ResourceItem{
CPU: "1",
Memory: "1Gi",
GPU: "0",
......@@ -79,8 +79,8 @@ func TestGenerateDynamoComponentsDeployments(t *testing.T) {
DynamoComponentDeploymentSharedSpec: v1alpha1.DynamoComponentDeploymentSharedSpec{
DynamoNamespace: &[]string{"default"}[0],
Replicas: &[]int32{3}[0],
Resources: &compounaiCommon.Resources{
Requests: &compounaiCommon.ResourceItem{
Resources: &common.Resources{
Requests: &common.ResourceItem{
CPU: "1",
Memory: "1Gi",
GPU: "0",
......@@ -109,8 +109,8 @@ func TestGenerateDynamoComponentsDeployments(t *testing.T) {
DynamoNamespace: &[]string{"default"}[0],
ComponentType: "main",
Replicas: &[]int32{3}[0],
Resources: &compounaiCommon.Resources{
Requests: &compounaiCommon.ResourceItem{
Resources: &common.Resources{
Requests: &common.ResourceItem{
CPU: "1",
Memory: "1Gi",
GPU: "0",
......@@ -143,8 +143,8 @@ func TestGenerateDynamoComponentsDeployments(t *testing.T) {
commonconsts.KubeLabelDynamoComponent: "service2",
commonconsts.KubeLabelDynamoNamespace: "default",
},
Resources: &compounaiCommon.Resources{
Requests: &compounaiCommon.ResourceItem{
Resources: &common.Resources{
Requests: &common.ResourceItem{
CPU: "1",
Memory: "1Gi",
GPU: "0",
......@@ -173,8 +173,8 @@ func TestGenerateDynamoComponentsDeployments(t *testing.T) {
DynamoNamespace: nil,
ComponentType: "main",
Replicas: &[]int32{3}[0],
Resources: &compounaiCommon.Resources{
Requests: &compounaiCommon.ResourceItem{
Resources: &common.Resources{
Requests: &common.ResourceItem{
CPU: "1",
Memory: "1Gi",
GPU: "0",
......@@ -187,8 +187,8 @@ func TestGenerateDynamoComponentsDeployments(t *testing.T) {
DynamoComponentDeploymentSharedSpec: v1alpha1.DynamoComponentDeploymentSharedSpec{
DynamoNamespace: nil,
Replicas: &[]int32{3}[0],
Resources: &compounaiCommon.Resources{
Requests: &compounaiCommon.ResourceItem{
Resources: &common.Resources{
Requests: &common.ResourceItem{
CPU: "1",
Memory: "1Gi",
GPU: "0",
......@@ -217,8 +217,8 @@ func TestGenerateDynamoComponentsDeployments(t *testing.T) {
DynamoNamespace: &[]string{"dynamo-test-dynamographdeployment"}[0],
ComponentType: "main",
Replicas: &[]int32{3}[0],
Resources: &compounaiCommon.Resources{
Requests: &compounaiCommon.ResourceItem{
Resources: &common.Resources{
Requests: &common.ResourceItem{
CPU: "1",
Memory: "1Gi",
GPU: "0",
......@@ -251,8 +251,8 @@ func TestGenerateDynamoComponentsDeployments(t *testing.T) {
commonconsts.KubeLabelDynamoComponent: "service2",
commonconsts.KubeLabelDynamoNamespace: "dynamo-test-dynamographdeployment",
},
Resources: &compounaiCommon.Resources{
Requests: &compounaiCommon.ResourceItem{
Resources: &common.Resources{
Requests: &common.ResourceItem{
CPU: "1",
Memory: "1Gi",
GPU: "0",
......@@ -281,8 +281,8 @@ func TestGenerateDynamoComponentsDeployments(t *testing.T) {
DynamoNamespace: &[]string{"default"}[0],
ComponentType: "main",
Replicas: &[]int32{3}[0],
Resources: &compounaiCommon.Resources{
Requests: &compounaiCommon.ResourceItem{
Resources: &common.Resources{
Requests: &common.ResourceItem{
CPU: "1",
Memory: "1Gi",
GPU: "0",
......@@ -295,8 +295,8 @@ func TestGenerateDynamoComponentsDeployments(t *testing.T) {
DynamoComponentDeploymentSharedSpec: v1alpha1.DynamoComponentDeploymentSharedSpec{
DynamoNamespace: &[]string{"another"}[0],
Replicas: &[]int32{3}[0],
Resources: &compounaiCommon.Resources{
Requests: &compounaiCommon.ResourceItem{
Resources: &common.Resources{
Requests: &common.ResourceItem{
CPU: "1",
Memory: "1Gi",
GPU: "0",
......@@ -327,8 +327,8 @@ func TestGenerateDynamoComponentsDeployments(t *testing.T) {
DynamoNamespace: nil,
ComponentType: "main",
Replicas: &[]int32{3}[0],
Resources: &compounaiCommon.Resources{
Requests: &compounaiCommon.ResourceItem{
Resources: &common.Resources{
Requests: &common.ResourceItem{
CPU: "1",
Memory: "1Gi",
GPU: "0",
......@@ -341,8 +341,8 @@ func TestGenerateDynamoComponentsDeployments(t *testing.T) {
DynamoComponentDeploymentSharedSpec: v1alpha1.DynamoComponentDeploymentSharedSpec{
DynamoNamespace: nil,
Replicas: &[]int32{3}[0],
Resources: &compounaiCommon.Resources{
Requests: &compounaiCommon.ResourceItem{
Resources: &common.Resources{
Requests: &common.ResourceItem{
CPU: "1",
Memory: "1Gi",
GPU: "0",
......@@ -375,8 +375,8 @@ func TestGenerateDynamoComponentsDeployments(t *testing.T) {
DynamoNamespace: &[]string{"dynamo-test-dynamographdeployment"}[0],
ComponentType: "main",
Replicas: &[]int32{3}[0],
Resources: &compounaiCommon.Resources{
Requests: &compounaiCommon.ResourceItem{
Resources: &common.Resources{
Requests: &common.ResourceItem{
CPU: "1",
Memory: "1Gi",
GPU: "0",
......@@ -413,8 +413,8 @@ func TestGenerateDynamoComponentsDeployments(t *testing.T) {
commonconsts.KubeLabelDynamoComponent: "service2",
commonconsts.KubeLabelDynamoNamespace: "dynamo-test-dynamographdeployment",
},
Resources: &compounaiCommon.Resources{
Requests: &compounaiCommon.ResourceItem{
Resources: &common.Resources{
Requests: &common.ResourceItem{
CPU: "1",
Memory: "1Gi",
GPU: "0",
......@@ -449,8 +449,8 @@ func TestGenerateDynamoComponentsDeployments(t *testing.T) {
DynamoNamespace: nil,
ComponentType: "main",
Replicas: &[]int32{3}[0],
Resources: &compounaiCommon.Resources{
Requests: &compounaiCommon.ResourceItem{
Resources: &common.Resources{
Requests: &common.ResourceItem{
CPU: "1",
Memory: "1Gi",
GPU: "0",
......@@ -463,8 +463,8 @@ func TestGenerateDynamoComponentsDeployments(t *testing.T) {
DynamoComponentDeploymentSharedSpec: v1alpha1.DynamoComponentDeploymentSharedSpec{
DynamoNamespace: nil,
Replicas: &[]int32{3}[0],
Resources: &compounaiCommon.Resources{
Requests: &compounaiCommon.ResourceItem{
Resources: &common.Resources{
Requests: &common.ResourceItem{
CPU: "1",
Memory: "1Gi",
GPU: "0",
......@@ -493,14 +493,14 @@ func TestGenerateDynamoComponentsDeployments(t *testing.T) {
DynamoNamespace: &[]string{"dynamo-test-dynamographdeployment"}[0],
ComponentType: "main",
Replicas: &[]int32{3}[0],
Resources: &compounaiCommon.Resources{
Requests: &compounaiCommon.ResourceItem{
Resources: &common.Resources{
Requests: &common.ResourceItem{
CPU: "2",
Memory: "2Gi",
GPU: "2",
Custom: map[string]string{},
},
Limits: &compounaiCommon.ResourceItem{
Limits: &common.ResourceItem{
CPU: "2",
Memory: "2Gi",
GPU: "2",
......@@ -539,8 +539,8 @@ func TestGenerateDynamoComponentsDeployments(t *testing.T) {
commonconsts.KubeLabelDynamoComponent: "service2",
commonconsts.KubeLabelDynamoNamespace: "dynamo-test-dynamographdeployment",
},
Resources: &compounaiCommon.Resources{
Requests: &compounaiCommon.ResourceItem{
Resources: &common.Resources{
Requests: &common.ResourceItem{
CPU: "1",
Memory: "1Gi",
GPU: "0",
......@@ -569,21 +569,22 @@ func TestGenerateDynamoComponentsDeployments(t *testing.T) {
Namespace: "default",
},
Spec: v1alpha1.DynamoGraphDeploymentSpec{
BackendFramework: string(BackendFrameworkSGLang),
Services: map[string]*v1alpha1.DynamoComponentDeploymentOverridesSpec{
"service1": {
DynamoComponentDeploymentSharedSpec: v1alpha1.DynamoComponentDeploymentSharedSpec{
DynamoNamespace: &[]string{"default"}[0],
ComponentType: "main",
Replicas: &[]int32{3}[0],
Resources: &compounaiCommon.Resources{
Requests: &compounaiCommon.ResourceItem{
Resources: &common.Resources{
Requests: &common.ResourceItem{
CPU: "1",
Memory: "1Gi",
GPU: "0",
Custom: map[string]string{},
},
},
ExtraPodSpec: &compounaiCommon.ExtraPodSpec{
ExtraPodSpec: &common.ExtraPodSpec{
MainContainer: &corev1.Container{
Command: []string{"sh", "-c"},
Args: []string{"echo hello world", "sleep 99999"},
......@@ -595,8 +596,8 @@ func TestGenerateDynamoComponentsDeployments(t *testing.T) {
DynamoComponentDeploymentSharedSpec: v1alpha1.DynamoComponentDeploymentSharedSpec{
DynamoNamespace: &[]string{"default"}[0],
Replicas: &[]int32{3}[0],
Resources: &compounaiCommon.Resources{
Requests: &compounaiCommon.ResourceItem{
Resources: &common.Resources{
Requests: &common.ResourceItem{
CPU: "1",
Memory: "1Gi",
GPU: "0",
......@@ -620,13 +621,14 @@ func TestGenerateDynamoComponentsDeployments(t *testing.T) {
},
},
Spec: v1alpha1.DynamoComponentDeploymentSpec{
BackendFramework: string(BackendFrameworkSGLang),
DynamoComponentDeploymentSharedSpec: v1alpha1.DynamoComponentDeploymentSharedSpec{
ServiceName: "service1",
DynamoNamespace: &[]string{"default"}[0],
ComponentType: "main",
Replicas: &[]int32{3}[0],
Resources: &compounaiCommon.Resources{
Requests: &compounaiCommon.ResourceItem{
Resources: &common.Resources{
Requests: &common.ResourceItem{
CPU: "1",
Memory: "1Gi",
GPU: "0",
......@@ -638,7 +640,7 @@ func TestGenerateDynamoComponentsDeployments(t *testing.T) {
commonconsts.KubeLabelDynamoNamespace: "default",
},
Autoscaling: nil,
ExtraPodSpec: &compounaiCommon.ExtraPodSpec{
ExtraPodSpec: &common.ExtraPodSpec{
MainContainer: &corev1.Container{
Command: []string{"sh", "-c"},
Args: []string{"echo hello world", "sleep 99999"},
......@@ -657,6 +659,7 @@ func TestGenerateDynamoComponentsDeployments(t *testing.T) {
},
},
Spec: v1alpha1.DynamoComponentDeploymentSpec{
BackendFramework: string(BackendFrameworkSGLang),
DynamoComponentDeploymentSharedSpec: v1alpha1.DynamoComponentDeploymentSharedSpec{
ServiceName: "service2",
DynamoNamespace: &[]string{"default"}[0],
......@@ -665,8 +668,8 @@ func TestGenerateDynamoComponentsDeployments(t *testing.T) {
commonconsts.KubeLabelDynamoComponent: "service2",
commonconsts.KubeLabelDynamoNamespace: "default",
},
Resources: &compounaiCommon.Resources{
Requests: &compounaiCommon.ResourceItem{
Resources: &common.Resources{
Requests: &common.ResourceItem{
CPU: "1",
Memory: "1Gi",
GPU: "0",
......@@ -741,7 +744,7 @@ func TestSetLwsAnnotations(t *testing.T) {
func Test_updateDynDeploymentConfig(t *testing.T) {
type args struct {
dynamoDeploymentComponent *nvidiacomv1alpha1.DynamoComponentDeployment
dynamoDeploymentComponent *v1alpha1.DynamoComponentDeployment
newPort int
}
tests := []struct {
......@@ -753,10 +756,10 @@ func Test_updateDynDeploymentConfig(t *testing.T) {
{
name: "main component",
args: args{
dynamoDeploymentComponent: &nvidiacomv1alpha1.DynamoComponentDeployment{
Spec: nvidiacomv1alpha1.DynamoComponentDeploymentSpec{
dynamoDeploymentComponent: &v1alpha1.DynamoComponentDeployment{
Spec: v1alpha1.DynamoComponentDeploymentSpec{
DynamoTag: "graphs.agg:Frontend",
DynamoComponentDeploymentSharedSpec: nvidiacomv1alpha1.DynamoComponentDeploymentSharedSpec{
DynamoComponentDeploymentSharedSpec: v1alpha1.DynamoComponentDeploymentSharedSpec{
ServiceName: "Frontend",
Envs: []corev1.EnvVar{
{
......@@ -779,10 +782,10 @@ func Test_updateDynDeploymentConfig(t *testing.T) {
{
name: "not main component",
args: args{
dynamoDeploymentComponent: &nvidiacomv1alpha1.DynamoComponentDeployment{
Spec: nvidiacomv1alpha1.DynamoComponentDeploymentSpec{
dynamoDeploymentComponent: &v1alpha1.DynamoComponentDeployment{
Spec: v1alpha1.DynamoComponentDeploymentSpec{
DynamoTag: "graphs.agg:Frontend",
DynamoComponentDeploymentSharedSpec: nvidiacomv1alpha1.DynamoComponentDeploymentSharedSpec{
DynamoComponentDeploymentSharedSpec: v1alpha1.DynamoComponentDeploymentSharedSpec{
ServiceName: "Other",
Envs: []corev1.EnvVar{
{
......@@ -801,10 +804,10 @@ func Test_updateDynDeploymentConfig(t *testing.T) {
{
name: "no config variable",
args: args{
dynamoDeploymentComponent: &nvidiacomv1alpha1.DynamoComponentDeployment{
Spec: nvidiacomv1alpha1.DynamoComponentDeploymentSpec{
dynamoDeploymentComponent: &v1alpha1.DynamoComponentDeployment{
Spec: v1alpha1.DynamoComponentDeploymentSpec{
DynamoTag: "graphs.agg:Frontend",
DynamoComponentDeploymentSharedSpec: nvidiacomv1alpha1.DynamoComponentDeploymentSharedSpec{
DynamoComponentDeploymentSharedSpec: v1alpha1.DynamoComponentDeploymentSharedSpec{
ServiceName: "Frontend",
Envs: []corev1.EnvVar{
{
......@@ -838,21 +841,21 @@ func Test_updateDynDeploymentConfig(t *testing.T) {
func Test_overrideWithDynDeploymentConfig(t *testing.T) {
type args struct {
ctx context.Context
dynamoDeploymentComponent *nvidiacomv1alpha1.DynamoComponentDeployment
dynamoDeploymentComponent *v1alpha1.DynamoComponentDeployment
}
tests := []struct {
name string
args args
wantErr bool
expected *nvidiacomv1alpha1.DynamoComponentDeployment
expected *v1alpha1.DynamoComponentDeployment
}{
{
name: "no env var",
args: args{
ctx: context.Background(),
dynamoDeploymentComponent: &nvidiacomv1alpha1.DynamoComponentDeployment{
Spec: nvidiacomv1alpha1.DynamoComponentDeploymentSpec{
DynamoComponentDeploymentSharedSpec: nvidiacomv1alpha1.DynamoComponentDeploymentSharedSpec{
dynamoDeploymentComponent: &v1alpha1.DynamoComponentDeployment{
Spec: v1alpha1.DynamoComponentDeploymentSpec{
DynamoComponentDeploymentSharedSpec: v1alpha1.DynamoComponentDeploymentSharedSpec{
ServiceName: "Frontend",
Replicas: &[]int32{1}[0],
Resources: &common.Resources{
......@@ -867,9 +870,9 @@ func Test_overrideWithDynDeploymentConfig(t *testing.T) {
},
},
wantErr: false,
expected: &nvidiacomv1alpha1.DynamoComponentDeployment{
Spec: nvidiacomv1alpha1.DynamoComponentDeploymentSpec{
DynamoComponentDeploymentSharedSpec: nvidiacomv1alpha1.DynamoComponentDeploymentSharedSpec{
expected: &v1alpha1.DynamoComponentDeployment{
Spec: v1alpha1.DynamoComponentDeploymentSpec{
DynamoComponentDeploymentSharedSpec: v1alpha1.DynamoComponentDeploymentSharedSpec{
ServiceName: "Frontend",
Replicas: &[]int32{1}[0],
Resources: &common.Resources{
......@@ -887,9 +890,9 @@ func Test_overrideWithDynDeploymentConfig(t *testing.T) {
name: "override workers and resources",
args: args{
ctx: context.Background(),
dynamoDeploymentComponent: &nvidiacomv1alpha1.DynamoComponentDeployment{
Spec: nvidiacomv1alpha1.DynamoComponentDeploymentSpec{
DynamoComponentDeploymentSharedSpec: nvidiacomv1alpha1.DynamoComponentDeploymentSharedSpec{
dynamoDeploymentComponent: &v1alpha1.DynamoComponentDeployment{
Spec: v1alpha1.DynamoComponentDeploymentSpec{
DynamoComponentDeploymentSharedSpec: v1alpha1.DynamoComponentDeploymentSharedSpec{
ServiceName: "Frontend",
Replicas: &[]int32{1}[0],
Resources: &common.Resources{
......@@ -910,9 +913,9 @@ func Test_overrideWithDynDeploymentConfig(t *testing.T) {
},
},
wantErr: false,
expected: &nvidiacomv1alpha1.DynamoComponentDeployment{
Spec: nvidiacomv1alpha1.DynamoComponentDeploymentSpec{
DynamoComponentDeploymentSharedSpec: nvidiacomv1alpha1.DynamoComponentDeploymentSharedSpec{
expected: &v1alpha1.DynamoComponentDeployment{
Spec: v1alpha1.DynamoComponentDeploymentSpec{
DynamoComponentDeploymentSharedSpec: v1alpha1.DynamoComponentDeploymentSharedSpec{
ServiceName: "Frontend",
Replicas: &[]int32{3}[0],
Resources: &common.Resources{
......@@ -941,9 +944,9 @@ func Test_overrideWithDynDeploymentConfig(t *testing.T) {
name: "override workers and resources with gpusPerNode",
args: args{
ctx: context.Background(),
dynamoDeploymentComponent: &nvidiacomv1alpha1.DynamoComponentDeployment{
Spec: nvidiacomv1alpha1.DynamoComponentDeploymentSpec{
DynamoComponentDeploymentSharedSpec: nvidiacomv1alpha1.DynamoComponentDeploymentSharedSpec{
dynamoDeploymentComponent: &v1alpha1.DynamoComponentDeployment{
Spec: v1alpha1.DynamoComponentDeploymentSpec{
DynamoComponentDeploymentSharedSpec: v1alpha1.DynamoComponentDeploymentSharedSpec{
ServiceName: "Frontend",
Replicas: nil,
Resources: &common.Resources{
......@@ -964,9 +967,9 @@ func Test_overrideWithDynDeploymentConfig(t *testing.T) {
},
},
wantErr: false,
expected: &nvidiacomv1alpha1.DynamoComponentDeployment{
Spec: nvidiacomv1alpha1.DynamoComponentDeploymentSpec{
DynamoComponentDeploymentSharedSpec: nvidiacomv1alpha1.DynamoComponentDeploymentSharedSpec{
expected: &v1alpha1.DynamoComponentDeployment{
Spec: v1alpha1.DynamoComponentDeploymentSpec{
DynamoComponentDeploymentSharedSpec: v1alpha1.DynamoComponentDeploymentSharedSpec{
ServiceName: "Frontend",
Replicas: &[]int32{3}[0],
Resources: &common.Resources{
......@@ -999,9 +1002,9 @@ func Test_overrideWithDynDeploymentConfig(t *testing.T) {
name: "override subset of resources",
args: args{
ctx: context.Background(),
dynamoDeploymentComponent: &nvidiacomv1alpha1.DynamoComponentDeployment{
Spec: nvidiacomv1alpha1.DynamoComponentDeploymentSpec{
DynamoComponentDeploymentSharedSpec: nvidiacomv1alpha1.DynamoComponentDeploymentSharedSpec{
dynamoDeploymentComponent: &v1alpha1.DynamoComponentDeployment{
Spec: v1alpha1.DynamoComponentDeploymentSpec{
DynamoComponentDeploymentSharedSpec: v1alpha1.DynamoComponentDeploymentSharedSpec{
ServiceName: "Frontend",
Replicas: nil,
Resources: &common.Resources{
......@@ -1022,9 +1025,9 @@ func Test_overrideWithDynDeploymentConfig(t *testing.T) {
},
},
wantErr: false,
expected: &nvidiacomv1alpha1.DynamoComponentDeployment{
Spec: nvidiacomv1alpha1.DynamoComponentDeploymentSpec{
DynamoComponentDeploymentSharedSpec: nvidiacomv1alpha1.DynamoComponentDeploymentSharedSpec{
expected: &v1alpha1.DynamoComponentDeployment{
Spec: v1alpha1.DynamoComponentDeploymentSpec{
DynamoComponentDeploymentSharedSpec: v1alpha1.DynamoComponentDeploymentSharedSpec{
ServiceName: "Frontend",
Replicas: &[]int32{3}[0],
Resources: &common.Resources{
......@@ -1131,7 +1134,7 @@ func TestGenerateGrovePodGangSet(t *testing.T) {
wantErr bool
}{
{
name: "test_generate_grove_pod_gang_set",
name: "test_generate_grove_pod_gang_set_single_node",
args: args{
ctx: context.Background(),
controllerConfig: controller_common.Config{
......@@ -1156,6 +1159,17 @@ func TestGenerateGrovePodGangSet(t *testing.T) {
Services: map[string]*v1alpha1.DynamoComponentDeploymentOverridesSpec{
"Frontend": {
DynamoComponentDeploymentSharedSpec: v1alpha1.DynamoComponentDeploymentSharedSpec{
ComponentType: "main", // Frontend component
ExtraPodMetadata: &common.ExtraPodMetadata{
Annotations: map[string]string{
"nvidia.com/annotation1": "annotation1",
"nvidia.com/annotation2": "annotation2",
},
Labels: map[string]string{
"nvidia.com/label1": "label1",
"nvidia.com/label2": "label2",
},
},
Replicas: &[]int32{1}[0],
Resources: &common.Resources{
Requests: &common.ResourceItem{
......@@ -1192,6 +1206,14 @@ func TestGenerateGrovePodGangSet(t *testing.T) {
},
},
ExtraPodSpec: &common.ExtraPodSpec{
PodSpec: &corev1.PodSpec{
TerminationGracePeriodSeconds: ptr.To(int64(10)),
ImagePullSecrets: []corev1.LocalObjectReference{
{
Name: "frontend-secret",
},
},
},
MainContainer: &corev1.Container{
Command: []string{
"/bin/sh",
......@@ -1276,19 +1298,45 @@ func TestGenerateGrovePodGangSet(t *testing.T) {
Spec: grovev1alpha1.PodGangSetSpec{
Replicas: 1,
Template: grovev1alpha1.PodGangSetTemplateSpec{
HeadlessServiceConfig: &grovev1alpha1.HeadlessServiceConfig{
PublishNotReadyAddresses: true,
},
TerminationDelay: &metav1.Duration{Duration: 15 * time.Minute},
Cliques: []*grovev1alpha1.PodCliqueTemplateSpec{
{
Name: "frontend",
Labels: map[string]string{
commonconsts.KubeLabelDynamoSelector: "test-dynamo-graph-deployment-frontend",
commonconsts.KubeLabelMetricsEnabled: commonconsts.KubeLabelValueTrue,
commonconsts.KubeLabelDynamoSelector: "test-dynamo-graph-deployment-frontend",
commonconsts.KubeLabelMetricsEnabled: commonconsts.KubeLabelValueTrue,
commonconsts.KubeLabelDynamoComponentType: commonconsts.ComponentTypeMain,
"nvidia.com/label1": "label1",
"nvidia.com/label2": "label2",
},
Annotations: map[string]string{
"nvidia.com/annotation1": "annotation1",
"nvidia.com/annotation2": "annotation2",
},
Spec: grovev1alpha1.PodCliqueSpec{
RoleName: "frontend",
Replicas: 1,
PodSpec: corev1.PodSpec{
ImagePullSecrets: []corev1.LocalObjectReference{},
Volumes: []corev1.Volume{
{
Name: "shared-memory",
VolumeSource: corev1.VolumeSource{
EmptyDir: &corev1.EmptyDirVolumeSource{
Medium: corev1.StorageMediumMemory,
SizeLimit: resource.NewQuantity(536870912, resource.BinarySI),
},
},
},
},
TerminationGracePeriodSeconds: ptr.To(int64(10)),
ImagePullSecrets: []corev1.LocalObjectReference{
{
Name: "frontend-secret",
},
},
Containers: []corev1.Container{
{
Name: "main",
......@@ -1360,6 +1408,12 @@ func TestGenerateGrovePodGangSet(t *testing.T) {
corev1.ResourceName("nvidia.com/gpu"): resource.MustParse("1"),
},
},
VolumeMounts: []corev1.VolumeMount{
{
Name: "shared-memory",
MountPath: "/dev/shm",
},
},
Ports: []corev1.ContainerPort{
{
Protocol: corev1.ProtocolTCP,
......@@ -1375,14 +1429,14 @@ func TestGenerateGrovePodGangSet(t *testing.T) {
{
Name: "planner",
Labels: map[string]string{
commonconsts.KubeLabelDynamoSelector: "test-dynamo-graph-deployment-planner",
commonconsts.KubeLabelMetricsEnabled: commonconsts.KubeLabelValueTrue,
commonconsts.KubeLabelDynamoSelector: "test-dynamo-graph-deployment-planner",
},
Annotations: map[string]string{},
Spec: grovev1alpha1.PodCliqueSpec{
RoleName: "planner",
Replicas: 2,
PodSpec: corev1.PodSpec{
ImagePullSecrets: []corev1.LocalObjectReference{},
Volumes: []corev1.Volume{
{
Name: "planner-pvc",
......@@ -1392,6 +1446,15 @@ func TestGenerateGrovePodGangSet(t *testing.T) {
},
},
},
{
Name: "shared-memory",
VolumeSource: corev1.VolumeSource{
EmptyDir: &corev1.EmptyDirVolumeSource{
Medium: corev1.StorageMediumMemory,
SizeLimit: resource.NewQuantity(536870912, resource.BinarySI),
},
},
},
},
Containers: []corev1.Container{
{
......@@ -1469,6 +1532,10 @@ func TestGenerateGrovePodGangSet(t *testing.T) {
Name: "planner-pvc",
MountPath: "/planner",
},
{
Name: "shared-memory",
MountPath: "/dev/shm",
},
},
Ports: []corev1.ContainerPort{
{
......@@ -1488,22 +1555,2592 @@ func TestGenerateGrovePodGangSet(t *testing.T) {
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := GenerateGrovePodGangSet(tt.args.ctx, tt.args.dynamoDeployment, tt.args.controllerConfig, nil)
if (err != nil) != tt.wantErr {
t.Errorf("GenerateGrovePodGangSet() error = %v, wantErr %v", err, tt.wantErr)
return
}
sort.Slice(got.Spec.Template.Cliques, func(i, j int) bool {
return got.Spec.Template.Cliques[i].Name < got.Spec.Template.Cliques[j].Name
})
sort.Slice(tt.want.Spec.Template.Cliques, func(i, j int) bool {
return tt.want.Spec.Template.Cliques[i].Name < tt.want.Spec.Template.Cliques[j].Name
})
if diff := cmp.Diff(got, tt.want); diff != "" {
t.Errorf("GenerateGrovePodGangSet() mismatch (-want +got):\n%s", diff)
{
name: "test_generate_grove_pod_gang_set_multinode sglang",
args: args{
ctx: context.Background(),
controllerConfig: controller_common.Config{
EtcdAddress: "etcd-address",
NatsAddress: "nats-address",
Grove: controller_common.GroveConfig{
TerminationDelay: 15 * time.Minute,
},
},
dynamoDeployment: &v1alpha1.DynamoGraphDeployment{
ObjectMeta: metav1.ObjectMeta{
Name: "test-dynamo-graph-deployment",
Namespace: "test-namespace",
},
Spec: v1alpha1.DynamoGraphDeploymentSpec{
Envs: []corev1.EnvVar{
{
Name: "DYNAMO_POD_GANG_SET_REPLICAS",
Value: "1",
},
},
BackendFramework: string(BackendFrameworkSGLang),
Services: map[string]*v1alpha1.DynamoComponentDeploymentOverridesSpec{
"Frontend": {
DynamoComponentDeploymentSharedSpec: v1alpha1.DynamoComponentDeploymentSharedSpec{
Replicas: &[]int32{1}[0],
Resources: &common.Resources{
Requests: &common.ResourceItem{
CPU: "1",
Memory: "1Gi",
},
Limits: &common.ResourceItem{
CPU: "1",
Memory: "1Gi",
GPU: "1",
},
},
Envs: []corev1.EnvVar{
{
Name: "FRONTEND_ENV_1",
Value: "1",
},
},
EnvFromSecret: &[]string{"frontend-secret"}[0],
LivenessProbe: &corev1.Probe{
ProbeHandler: corev1.ProbeHandler{
HTTPGet: &corev1.HTTPGetAction{
Path: "/health",
Port: intstr.FromInt(8080),
},
},
},
ReadinessProbe: &corev1.Probe{
ProbeHandler: corev1.ProbeHandler{
HTTPGet: &corev1.HTTPGetAction{
Path: "/ready",
Port: intstr.FromInt(8080),
},
},
},
ExtraPodSpec: &common.ExtraPodSpec{
PodSpec: &corev1.PodSpec{
ImagePullSecrets: []corev1.LocalObjectReference{
{
Name: "frontend-secret",
},
},
TerminationGracePeriodSeconds: ptr.To(int64(10)),
},
MainContainer: &corev1.Container{
Command: []string{
"/bin/sh",
"-c",
"echo $FRONTEND_ENV_1",
},
Args: []string{
"--frontend-env-1",
"1",
},
Image: "frontend-image",
},
},
},
},
"worker": {
DynamoComponentDeploymentSharedSpec: v1alpha1.DynamoComponentDeploymentSharedSpec{
ExtraPodMetadata: &common.ExtraPodMetadata{
Annotations: map[string]string{
"nvidia.com/annotation1": "annotation1",
"nvidia.com/annotation2": "annotation2",
},
Labels: map[string]string{
"nvidia.com/label1": "label1",
"nvidia.com/label2": "label2",
},
},
Replicas: &[]int32{5}[0],
ComponentType: commonconsts.ComponentTypeWorker,
ExtraPodSpec: &common.ExtraPodSpec{
MainContainer: &corev1.Container{
Image: "worker-image",
Command: []string{
"/bin/sh",
"-c",
},
Args: []string{
"python3 -m dynamo.sglang.worker --custom-flag custom-value",
},
},
},
Resources: &common.Resources{
Requests: &common.ResourceItem{
CPU: "2",
Memory: "2Gi",
Nodes: "3",
},
Limits: &common.ResourceItem{
CPU: "2",
Memory: "2Gi",
GPU: "2",
Nodes: "3",
},
},
Envs: []corev1.EnvVar{
{
Name: "WORKER_ENV_1",
Value: "1",
},
},
},
},
"Planner": {
DynamoComponentDeploymentSharedSpec: v1alpha1.DynamoComponentDeploymentSharedSpec{
Replicas: &[]int32{2}[0],
Resources: &common.Resources{
Requests: &common.ResourceItem{
CPU: "2",
Memory: "2Gi",
},
Limits: &common.ResourceItem{
CPU: "2",
Memory: "2Gi",
GPU: "2",
},
},
Envs: []corev1.EnvVar{
{
Name: "PLANNER_ENV_1",
Value: "2",
},
},
PVC: &v1alpha1.PVC{
Name: &[]string{"planner-pvc"}[0],
MountPoint: &[]string{"/planner"}[0],
},
EnvFromSecret: &[]string{"planner-secret"}[0],
LivenessProbe: &corev1.Probe{
ProbeHandler: corev1.ProbeHandler{
HTTPGet: &corev1.HTTPGetAction{
Path: "/health",
Port: intstr.FromInt(8080),
},
},
},
ReadinessProbe: &corev1.Probe{
ProbeHandler: corev1.ProbeHandler{
HTTPGet: &corev1.HTTPGetAction{
Path: "/ready",
Port: intstr.FromInt(8080),
},
},
},
ExtraPodSpec: &common.ExtraPodSpec{
MainContainer: &corev1.Container{
Command: []string{
"/bin/sh",
"-c",
"echo $PLANNER_ENV_1",
},
Args: []string{
"--planner-env-1",
"1",
},
Image: "planner-image",
},
},
},
},
},
},
},
},
want: &grovev1alpha1.PodGangSet{
ObjectMeta: metav1.ObjectMeta{
Name: "test-dynamo-graph-deployment",
Namespace: "test-namespace",
},
Spec: grovev1alpha1.PodGangSetSpec{
Replicas: 1,
Template: grovev1alpha1.PodGangSetTemplateSpec{
HeadlessServiceConfig: &grovev1alpha1.HeadlessServiceConfig{
PublishNotReadyAddresses: true,
},
TerminationDelay: &metav1.Duration{Duration: 15 * time.Minute},
PodCliqueScalingGroupConfigs: []grovev1alpha1.PodCliqueScalingGroupConfig{
{
Name: "worker",
CliqueNames: []string{
"worker-ldr",
"worker-wkr",
},
Replicas: ptr.To(int32(5)),
},
},
StartupType: ptr.To(grovev1alpha1.CliqueStartupTypeExplicit),
Cliques: []*grovev1alpha1.PodCliqueTemplateSpec{
{
Name: "worker-ldr",
Labels: map[string]string{
commonconsts.KubeLabelDynamoComponentType: commonconsts.ComponentTypeWorker,
commonconsts.KubeLabelMetricsEnabled: commonconsts.KubeLabelValueTrue,
commonconsts.KubeLabelDynamoSelector: "test-dynamo-graph-deployment-worker-ldr",
"nvidia.com/label1": "label1",
"nvidia.com/label2": "label2",
},
Annotations: map[string]string{
"nvidia.com/annotation1": "annotation1",
"nvidia.com/annotation2": "annotation2",
},
Spec: grovev1alpha1.PodCliqueSpec{
RoleName: "worker-ldr",
Replicas: 1,
PodSpec: corev1.PodSpec{
Volumes: []corev1.Volume{
{
Name: "shared-memory",
VolumeSource: corev1.VolumeSource{
EmptyDir: &corev1.EmptyDirVolumeSource{
Medium: corev1.StorageMediumMemory,
SizeLimit: resource.NewQuantity(512*1024*1024, resource.BinarySI),
},
},
},
},
Containers: []corev1.Container{
{
Name: "main",
Image: "worker-image",
Command: []string{
"/bin/sh",
"-c",
},
Args: []string{
"python3 -m dynamo.sglang.worker --dist-init-addr ${GROVE_PCSG_NAME}-${GROVE_PCSG_INDEX}-worker-ldr-0.${GROVE_HEADLESS_SERVICE}:29500 --nnodes 3 --node-rank 0 --custom-flag custom-value",
},
Ports: []corev1.ContainerPort{
{
Protocol: corev1.ProtocolTCP,
Name: commonconsts.DynamoContainerPortName,
ContainerPort: int32(commonconsts.DynamoServicePort),
},
{
Protocol: corev1.ProtocolTCP,
Name: commonconsts.DynamoSystemPortName,
ContainerPort: int32(commonconsts.DynamoSystemPort),
},
},
Env: []corev1.EnvVar{
{
Name: "DYNAMO_POD_GANG_SET_REPLICAS",
Value: "1",
},
{
Name: "WORKER_ENV_1",
Value: "1",
},
{
Name: "DYNAMO_PORT",
Value: fmt.Sprintf("%d", commonconsts.DynamoServicePort),
},
{
Name: "NATS_SERVER",
Value: "nats-address",
},
{
Name: "ETCD_ENDPOINTS",
Value: "etcd-address",
},
},
Resources: corev1.ResourceRequirements{
Requests: corev1.ResourceList{
corev1.ResourceCPU: resource.MustParse("2"),
corev1.ResourceMemory: resource.MustParse("2Gi"),
},
Limits: corev1.ResourceList{
corev1.ResourceCPU: resource.MustParse("2"),
corev1.ResourceMemory: resource.MustParse("2Gi"),
corev1.ResourceName("nvidia.com/gpu"): resource.MustParse("2"),
},
},
VolumeMounts: []corev1.VolumeMount{
{
Name: commonconsts.KubeValueNameSharedMemory,
MountPath: "/dev/shm",
},
},
},
},
},
},
},
{
Name: "worker-wkr",
Labels: map[string]string{
commonconsts.KubeLabelDynamoComponentType: commonconsts.ComponentTypeWorker,
commonconsts.KubeLabelMetricsEnabled: commonconsts.KubeLabelValueTrue,
commonconsts.KubeLabelDynamoSelector: "test-dynamo-graph-deployment-worker-wkr",
"nvidia.com/label1": "label1",
"nvidia.com/label2": "label2",
},
Annotations: map[string]string{
"nvidia.com/annotation1": "annotation1",
"nvidia.com/annotation2": "annotation2",
},
Spec: grovev1alpha1.PodCliqueSpec{
RoleName: "worker-wkr",
Replicas: 2,
StartsAfter: []string{"worker-ldr"},
PodSpec: corev1.PodSpec{
Volumes: []corev1.Volume{
{
Name: "shared-memory",
VolumeSource: corev1.VolumeSource{
EmptyDir: &corev1.EmptyDirVolumeSource{
Medium: corev1.StorageMediumMemory,
SizeLimit: resource.NewQuantity(512*1024*1024, resource.BinarySI),
},
},
},
},
Containers: []corev1.Container{
{
Name: "main",
Image: "worker-image",
Command: []string{
"/bin/sh",
"-c",
},
Args: []string{
"python3 -m dynamo.sglang.worker --dist-init-addr ${GROVE_PCSG_NAME}-${GROVE_PCSG_INDEX}-worker-ldr-0.${GROVE_HEADLESS_SERVICE}:29500 --nnodes 3 --node-rank $((GROVE_PCLQ_POD_INDEX + 1)) --custom-flag custom-value",
},
Ports: []corev1.ContainerPort{
{
Protocol: corev1.ProtocolTCP,
Name: commonconsts.DynamoContainerPortName,
ContainerPort: int32(commonconsts.DynamoServicePort),
},
{
Protocol: corev1.ProtocolTCP,
Name: commonconsts.DynamoSystemPortName,
ContainerPort: int32(commonconsts.DynamoSystemPort),
},
},
Env: []corev1.EnvVar{
{
Name: "DYNAMO_POD_GANG_SET_REPLICAS",
Value: "1",
},
{
Name: "WORKER_ENV_1",
Value: "1",
},
{
Name: "DYNAMO_PORT",
Value: fmt.Sprintf("%d", commonconsts.DynamoServicePort),
},
{
Name: "NATS_SERVER",
Value: "nats-address",
},
{
Name: "ETCD_ENDPOINTS",
Value: "etcd-address",
},
},
Resources: corev1.ResourceRequirements{
Requests: corev1.ResourceList{
corev1.ResourceCPU: resource.MustParse("2"),
corev1.ResourceMemory: resource.MustParse("2Gi"),
},
Limits: corev1.ResourceList{
corev1.ResourceCPU: resource.MustParse("2"),
corev1.ResourceMemory: resource.MustParse("2Gi"),
corev1.ResourceName("nvidia.com/gpu"): resource.MustParse("2"),
},
},
VolumeMounts: []corev1.VolumeMount{
{
Name: commonconsts.KubeValueNameSharedMemory,
MountPath: "/dev/shm",
},
},
},
},
},
},
},
{
Name: "frontend",
Labels: map[string]string{
commonconsts.KubeLabelMetricsEnabled: commonconsts.KubeLabelValueTrue,
commonconsts.KubeLabelDynamoSelector: "test-dynamo-graph-deployment-frontend",
},
Annotations: map[string]string{},
Spec: grovev1alpha1.PodCliqueSpec{
RoleName: "frontend",
Replicas: 1,
PodSpec: corev1.PodSpec{
Volumes: []corev1.Volume{
{
Name: "shared-memory",
VolumeSource: corev1.VolumeSource{
EmptyDir: &corev1.EmptyDirVolumeSource{
Medium: corev1.StorageMediumMemory,
SizeLimit: resource.NewQuantity(512*1024*1024, resource.BinarySI),
},
},
},
},
ImagePullSecrets: []corev1.LocalObjectReference{
{
Name: "frontend-secret",
},
},
TerminationGracePeriodSeconds: ptr.To(int64(10)),
Containers: []corev1.Container{
{
Name: "main",
Image: "frontend-image",
Command: []string{
"/bin/sh",
"-c",
"echo $FRONTEND_ENV_1",
},
Args: []string{
"--frontend-env-1",
"1",
},
EnvFrom: []corev1.EnvFromSource{
{
SecretRef: &corev1.SecretEnvSource{
LocalObjectReference: corev1.LocalObjectReference{
Name: "frontend-secret",
},
},
},
},
LivenessProbe: &corev1.Probe{
ProbeHandler: corev1.ProbeHandler{
HTTPGet: &corev1.HTTPGetAction{
Path: "/health",
Port: intstr.FromInt(8080),
},
},
},
ReadinessProbe: &corev1.Probe{
ProbeHandler: corev1.ProbeHandler{
HTTPGet: &corev1.HTTPGetAction{
Path: "/ready",
Port: intstr.FromInt(8080),
},
},
},
Env: []corev1.EnvVar{
{
Name: "DYNAMO_POD_GANG_SET_REPLICAS",
Value: "1",
},
{
Name: "FRONTEND_ENV_1",
Value: "1",
},
{
Name: "DYNAMO_PORT",
Value: fmt.Sprintf("%d", commonconsts.DynamoServicePort),
},
{
Name: "NATS_SERVER",
Value: "nats-address",
},
{
Name: "ETCD_ENDPOINTS",
Value: "etcd-address",
},
},
Resources: corev1.ResourceRequirements{
Requests: corev1.ResourceList{
corev1.ResourceCPU: resource.MustParse("1"),
corev1.ResourceMemory: resource.MustParse("1Gi"),
},
Limits: corev1.ResourceList{
corev1.ResourceCPU: resource.MustParse("1"),
corev1.ResourceMemory: resource.MustParse("1Gi"),
corev1.ResourceName("nvidia.com/gpu"): resource.MustParse("1"),
},
},
Ports: []corev1.ContainerPort{
{
Protocol: corev1.ProtocolTCP,
Name: commonconsts.DynamoContainerPortName,
ContainerPort: int32(commonconsts.DynamoServicePort),
},
},
VolumeMounts: []corev1.VolumeMount{
{
Name: commonconsts.KubeValueNameSharedMemory,
MountPath: "/dev/shm",
},
},
},
},
},
},
},
{
Name: "planner",
Labels: map[string]string{
commonconsts.KubeLabelDynamoSelector: "test-dynamo-graph-deployment-planner",
commonconsts.KubeLabelMetricsEnabled: commonconsts.KubeLabelValueTrue,
},
Annotations: map[string]string{},
Spec: grovev1alpha1.PodCliqueSpec{
RoleName: "planner",
Replicas: 2,
PodSpec: corev1.PodSpec{
Volumes: []corev1.Volume{
{
Name: "planner-pvc",
VolumeSource: corev1.VolumeSource{
PersistentVolumeClaim: &corev1.PersistentVolumeClaimVolumeSource{
ClaimName: "planner-pvc",
},
},
},
{
Name: "shared-memory",
VolumeSource: corev1.VolumeSource{
EmptyDir: &corev1.EmptyDirVolumeSource{
Medium: corev1.StorageMediumMemory,
SizeLimit: resource.NewQuantity(512*1024*1024, resource.BinarySI),
},
},
},
},
Containers: []corev1.Container{
{
Name: "main",
Image: "planner-image",
Command: []string{
"/bin/sh",
"-c",
"echo $PLANNER_ENV_1",
},
Args: []string{
"--planner-env-1",
"1",
},
EnvFrom: []corev1.EnvFromSource{
{
SecretRef: &corev1.SecretEnvSource{
LocalObjectReference: corev1.LocalObjectReference{
Name: "planner-secret",
},
},
},
},
LivenessProbe: &corev1.Probe{
ProbeHandler: corev1.ProbeHandler{
HTTPGet: &corev1.HTTPGetAction{
Path: "/health",
Port: intstr.FromInt(8080),
},
},
},
ReadinessProbe: &corev1.Probe{
ProbeHandler: corev1.ProbeHandler{
HTTPGet: &corev1.HTTPGetAction{
Path: "/ready",
Port: intstr.FromInt(8080),
},
},
},
Env: []corev1.EnvVar{
{
Name: "DYNAMO_POD_GANG_SET_REPLICAS",
Value: "1",
},
{
Name: "PLANNER_ENV_1",
Value: "2",
},
{
Name: "DYNAMO_PORT",
Value: fmt.Sprintf("%d", commonconsts.DynamoServicePort),
},
{
Name: "NATS_SERVER",
Value: "nats-address",
},
{
Name: "ETCD_ENDPOINTS",
Value: "etcd-address",
},
},
Resources: corev1.ResourceRequirements{
Requests: corev1.ResourceList{
corev1.ResourceCPU: resource.MustParse("2"),
corev1.ResourceMemory: resource.MustParse("2Gi"),
},
Limits: corev1.ResourceList{
corev1.ResourceCPU: resource.MustParse("2"),
corev1.ResourceMemory: resource.MustParse("2Gi"),
corev1.ResourceName("nvidia.com/gpu"): resource.MustParse("2"),
},
},
VolumeMounts: []corev1.VolumeMount{
{
Name: "planner-pvc",
MountPath: "/planner",
},
{
Name: "shared-memory",
MountPath: "/dev/shm",
},
},
Ports: []corev1.ContainerPort{
{
Protocol: corev1.ProtocolTCP,
Name: commonconsts.DynamoContainerPortName,
ContainerPort: int32(commonconsts.DynamoServicePort),
},
},
},
},
},
},
},
},
},
},
},
wantErr: false,
},
{
name: "test_generate_grove_pod_gang_set_multinode vllm",
args: args{
ctx: context.Background(),
controllerConfig: controller_common.Config{
EtcdAddress: "etcd-address",
NatsAddress: "nats-address",
Grove: controller_common.GroveConfig{
TerminationDelay: 15 * time.Minute,
},
},
dynamoDeployment: &v1alpha1.DynamoGraphDeployment{
ObjectMeta: metav1.ObjectMeta{
Name: "test-dynamo-graph-deployment",
Namespace: "test-namespace",
},
Spec: v1alpha1.DynamoGraphDeploymentSpec{
Envs: []corev1.EnvVar{
{
Name: "DYNAMO_POD_GANG_SET_REPLICAS",
Value: "1",
},
},
BackendFramework: string(BackendFrameworkVLLM),
Services: map[string]*v1alpha1.DynamoComponentDeploymentOverridesSpec{
"Frontend": {
DynamoComponentDeploymentSharedSpec: v1alpha1.DynamoComponentDeploymentSharedSpec{
Replicas: &[]int32{1}[0],
Resources: &common.Resources{
Requests: &common.ResourceItem{
CPU: "1",
Memory: "1Gi",
},
Limits: &common.ResourceItem{
CPU: "1",
Memory: "1Gi",
GPU: "1",
},
},
Envs: []corev1.EnvVar{
{
Name: "FRONTEND_ENV_1",
Value: "1",
},
},
EnvFromSecret: &[]string{"frontend-secret"}[0],
LivenessProbe: &corev1.Probe{
ProbeHandler: corev1.ProbeHandler{
HTTPGet: &corev1.HTTPGetAction{
Path: "/health",
Port: intstr.FromInt(8080),
},
},
},
ReadinessProbe: &corev1.Probe{
ProbeHandler: corev1.ProbeHandler{
HTTPGet: &corev1.HTTPGetAction{
Path: "/ready",
Port: intstr.FromInt(8080),
},
},
},
ExtraPodSpec: &common.ExtraPodSpec{
PodSpec: &corev1.PodSpec{
ImagePullSecrets: []corev1.LocalObjectReference{
{
Name: "frontend-secret",
},
},
TerminationGracePeriodSeconds: ptr.To(int64(10)),
},
MainContainer: &corev1.Container{
Command: []string{
"/bin/sh",
"-c",
"echo $FRONTEND_ENV_1",
},
Args: []string{
"--frontend-env-1",
"1",
},
Image: "frontend-image",
},
},
},
},
"worker": {
DynamoComponentDeploymentSharedSpec: v1alpha1.DynamoComponentDeploymentSharedSpec{
ExtraPodMetadata: &common.ExtraPodMetadata{
Annotations: map[string]string{
"nvidia.com/annotation1": "annotation1",
"nvidia.com/annotation2": "annotation2",
},
Labels: map[string]string{
"nvidia.com/label1": "label1",
"nvidia.com/label2": "label2",
},
},
Replicas: &[]int32{5}[0],
ComponentType: commonconsts.ComponentTypeWorker,
ExtraPodSpec: &common.ExtraPodSpec{
MainContainer: &corev1.Container{
Image: "worker-image",
Command: []string{
"/bin/sh",
"-c",
},
Args: []string{
"python3 -m dynamo.vllm --custom-flag custom-value",
},
StartupProbe: &corev1.Probe{
ProbeHandler: corev1.ProbeHandler{
HTTPGet: &corev1.HTTPGetAction{
Path: "/startup",
Port: intstr.FromInt(8080),
},
},
},
},
},
ReadinessProbe: &corev1.Probe{
ProbeHandler: corev1.ProbeHandler{
HTTPGet: &corev1.HTTPGetAction{
Path: "/ready",
Port: intstr.FromInt(8080),
},
},
},
LivenessProbe: &corev1.Probe{
ProbeHandler: corev1.ProbeHandler{
HTTPGet: &corev1.HTTPGetAction{
Path: "/health",
Port: intstr.FromInt(8080),
},
},
},
Resources: &common.Resources{
Requests: &common.ResourceItem{
CPU: "2",
Memory: "2Gi",
Nodes: "3",
},
Limits: &common.ResourceItem{
CPU: "2",
Memory: "2Gi",
GPU: "2",
Nodes: "3",
},
},
Envs: []corev1.EnvVar{
{
Name: "WORKER_ENV_1",
Value: "1",
},
},
},
},
"Planner": {
DynamoComponentDeploymentSharedSpec: v1alpha1.DynamoComponentDeploymentSharedSpec{
Replicas: &[]int32{2}[0],
Resources: &common.Resources{
Requests: &common.ResourceItem{
CPU: "2",
Memory: "2Gi",
},
Limits: &common.ResourceItem{
CPU: "2",
Memory: "2Gi",
GPU: "2",
},
},
Envs: []corev1.EnvVar{
{
Name: "PLANNER_ENV_1",
Value: "2",
},
},
PVC: &v1alpha1.PVC{
Name: &[]string{"planner-pvc"}[0],
MountPoint: &[]string{"/planner"}[0],
},
EnvFromSecret: &[]string{"planner-secret"}[0],
LivenessProbe: &corev1.Probe{
ProbeHandler: corev1.ProbeHandler{
HTTPGet: &corev1.HTTPGetAction{
Path: "/health",
Port: intstr.FromInt(8080),
},
},
},
ReadinessProbe: &corev1.Probe{
ProbeHandler: corev1.ProbeHandler{
HTTPGet: &corev1.HTTPGetAction{
Path: "/ready",
Port: intstr.FromInt(8080),
},
},
},
ExtraPodSpec: &common.ExtraPodSpec{
MainContainer: &corev1.Container{
Command: []string{
"/bin/sh",
"-c",
"echo $PLANNER_ENV_1",
},
Args: []string{
"--planner-env-1",
"1",
},
Image: "planner-image",
},
},
},
},
},
},
},
},
want: &grovev1alpha1.PodGangSet{
ObjectMeta: metav1.ObjectMeta{
Name: "test-dynamo-graph-deployment",
Namespace: "test-namespace",
},
Spec: grovev1alpha1.PodGangSetSpec{
Replicas: 1,
Template: grovev1alpha1.PodGangSetTemplateSpec{
HeadlessServiceConfig: &grovev1alpha1.HeadlessServiceConfig{
PublishNotReadyAddresses: true,
},
TerminationDelay: &metav1.Duration{Duration: 15 * time.Minute},
PodCliqueScalingGroupConfigs: []grovev1alpha1.PodCliqueScalingGroupConfig{
{
Name: "worker",
CliqueNames: []string{
"worker-ldr",
"worker-wkr",
},
Replicas: ptr.To(int32(5)),
},
},
StartupType: ptr.To(grovev1alpha1.CliqueStartupTypeExplicit),
Cliques: []*grovev1alpha1.PodCliqueTemplateSpec{
{
Name: "worker-ldr",
Labels: map[string]string{
commonconsts.KubeLabelDynamoSelector: "test-dynamo-graph-deployment-worker-ldr",
commonconsts.KubeLabelMetricsEnabled: commonconsts.KubeLabelValueTrue,
commonconsts.KubeLabelDynamoComponentType: commonconsts.ComponentTypeWorker,
"nvidia.com/label1": "label1",
"nvidia.com/label2": "label2",
},
Annotations: map[string]string{
"nvidia.com/annotation1": "annotation1",
"nvidia.com/annotation2": "annotation2",
},
Spec: grovev1alpha1.PodCliqueSpec{
RoleName: "worker-ldr",
Replicas: 1,
PodSpec: corev1.PodSpec{
Volumes: []corev1.Volume{
{
Name: "shared-memory",
VolumeSource: corev1.VolumeSource{
EmptyDir: &corev1.EmptyDirVolumeSource{
Medium: corev1.StorageMediumMemory,
SizeLimit: resource.NewQuantity(512*1024*1024, resource.BinarySI),
},
},
},
},
Containers: []corev1.Container{
{
Name: "main",
Image: "worker-image",
Command: []string{
"/bin/sh",
"-c",
},
Args: []string{
"ray start --head --port=6379 && python3 -m dynamo.vllm --custom-flag custom-value",
},
Ports: []corev1.ContainerPort{
{
Protocol: corev1.ProtocolTCP,
Name: commonconsts.DynamoContainerPortName,
ContainerPort: int32(commonconsts.DynamoServicePort),
},
{
Protocol: corev1.ProtocolTCP,
Name: commonconsts.DynamoSystemPortName,
ContainerPort: int32(commonconsts.DynamoSystemPort),
},
},
Env: []corev1.EnvVar{
{
Name: "DYNAMO_POD_GANG_SET_REPLICAS",
Value: "1",
},
{
Name: "WORKER_ENV_1",
Value: "1",
},
{
Name: "DYNAMO_PORT",
Value: fmt.Sprintf("%d", commonconsts.DynamoServicePort),
},
{
Name: "NATS_SERVER",
Value: "nats-address",
},
{
Name: "ETCD_ENDPOINTS",
Value: "etcd-address",
},
},
Resources: corev1.ResourceRequirements{
Requests: corev1.ResourceList{
corev1.ResourceCPU: resource.MustParse("2"),
corev1.ResourceMemory: resource.MustParse("2Gi"),
},
Limits: corev1.ResourceList{
corev1.ResourceCPU: resource.MustParse("2"),
corev1.ResourceMemory: resource.MustParse("2Gi"),
corev1.ResourceName("nvidia.com/gpu"): resource.MustParse("2"),
},
},
VolumeMounts: []corev1.VolumeMount{
{
Name: commonconsts.KubeValueNameSharedMemory,
MountPath: "/dev/shm",
},
},
ReadinessProbe: nil,
LivenessProbe: nil,
StartupProbe: nil,
},
},
},
},
},
{
Name: "worker-wkr",
Labels: map[string]string{
commonconsts.KubeLabelDynamoComponentType: commonconsts.ComponentTypeWorker,
commonconsts.KubeLabelMetricsEnabled: commonconsts.KubeLabelValueTrue,
commonconsts.KubeLabelDynamoSelector: "test-dynamo-graph-deployment-worker-wkr",
"nvidia.com/label1": "label1",
"nvidia.com/label2": "label2",
},
Annotations: map[string]string{
"nvidia.com/annotation1": "annotation1",
"nvidia.com/annotation2": "annotation2",
},
Spec: grovev1alpha1.PodCliqueSpec{
RoleName: "worker-wkr",
Replicas: 2,
StartsAfter: []string{"worker-ldr"},
PodSpec: corev1.PodSpec{
Volumes: []corev1.Volume{
{
Name: "shared-memory",
VolumeSource: corev1.VolumeSource{
EmptyDir: &corev1.EmptyDirVolumeSource{
Medium: corev1.StorageMediumMemory,
SizeLimit: resource.NewQuantity(512*1024*1024, resource.BinarySI),
},
},
},
},
Containers: []corev1.Container{
{
Name: "main",
Image: "worker-image",
Command: []string{
"/bin/sh",
"-c",
},
Args: []string{
"ray start --address=${GROVE_PCSG_NAME}-${GROVE_PCSG_INDEX}-worker-ldr-0.${GROVE_HEADLESS_SERVICE}:6379 --block",
},
Ports: []corev1.ContainerPort{
{
Protocol: corev1.ProtocolTCP,
Name: commonconsts.DynamoContainerPortName,
ContainerPort: int32(commonconsts.DynamoServicePort),
},
{
Protocol: corev1.ProtocolTCP,
Name: commonconsts.DynamoSystemPortName,
ContainerPort: int32(commonconsts.DynamoSystemPort),
},
},
Env: []corev1.EnvVar{
{
Name: "DYNAMO_POD_GANG_SET_REPLICAS",
Value: "1",
},
{
Name: "WORKER_ENV_1",
Value: "1",
},
{
Name: "DYNAMO_PORT",
Value: fmt.Sprintf("%d", commonconsts.DynamoServicePort),
},
{
Name: "NATS_SERVER",
Value: "nats-address",
},
{
Name: "ETCD_ENDPOINTS",
Value: "etcd-address",
},
},
Resources: corev1.ResourceRequirements{
Requests: corev1.ResourceList{
corev1.ResourceCPU: resource.MustParse("2"),
corev1.ResourceMemory: resource.MustParse("2Gi"),
},
Limits: corev1.ResourceList{
corev1.ResourceCPU: resource.MustParse("2"),
corev1.ResourceMemory: resource.MustParse("2Gi"),
corev1.ResourceName("nvidia.com/gpu"): resource.MustParse("2"),
},
},
VolumeMounts: []corev1.VolumeMount{
{
Name: commonconsts.KubeValueNameSharedMemory,
MountPath: "/dev/shm",
},
},
},
},
},
},
},
{
Name: "frontend",
Labels: map[string]string{
commonconsts.KubeLabelMetricsEnabled: commonconsts.KubeLabelValueTrue,
commonconsts.KubeLabelDynamoSelector: "test-dynamo-graph-deployment-frontend",
},
Annotations: map[string]string{},
Spec: grovev1alpha1.PodCliqueSpec{
RoleName: "frontend",
Replicas: 1,
PodSpec: corev1.PodSpec{
Volumes: []corev1.Volume{
{
Name: "shared-memory",
VolumeSource: corev1.VolumeSource{
EmptyDir: &corev1.EmptyDirVolumeSource{
Medium: corev1.StorageMediumMemory,
SizeLimit: resource.NewQuantity(512*1024*1024, resource.BinarySI),
},
},
},
},
ImagePullSecrets: []corev1.LocalObjectReference{
{
Name: "frontend-secret",
},
},
TerminationGracePeriodSeconds: ptr.To(int64(10)),
Containers: []corev1.Container{
{
Name: "main",
Image: "frontend-image",
Command: []string{
"/bin/sh",
"-c",
"echo $FRONTEND_ENV_1",
},
Args: []string{
"--frontend-env-1",
"1",
},
EnvFrom: []corev1.EnvFromSource{
{
SecretRef: &corev1.SecretEnvSource{
LocalObjectReference: corev1.LocalObjectReference{
Name: "frontend-secret",
},
},
},
},
LivenessProbe: &corev1.Probe{
ProbeHandler: corev1.ProbeHandler{
HTTPGet: &corev1.HTTPGetAction{
Path: "/health",
Port: intstr.FromInt(8080),
},
},
},
ReadinessProbe: &corev1.Probe{
ProbeHandler: corev1.ProbeHandler{
HTTPGet: &corev1.HTTPGetAction{
Path: "/ready",
Port: intstr.FromInt(8080),
},
},
},
Env: []corev1.EnvVar{
{
Name: "DYNAMO_POD_GANG_SET_REPLICAS",
Value: "1",
},
{
Name: "FRONTEND_ENV_1",
Value: "1",
},
{
Name: "DYNAMO_PORT",
Value: fmt.Sprintf("%d", commonconsts.DynamoServicePort),
},
{
Name: "NATS_SERVER",
Value: "nats-address",
},
{
Name: "ETCD_ENDPOINTS",
Value: "etcd-address",
},
},
Resources: corev1.ResourceRequirements{
Requests: corev1.ResourceList{
corev1.ResourceCPU: resource.MustParse("1"),
corev1.ResourceMemory: resource.MustParse("1Gi"),
},
Limits: corev1.ResourceList{
corev1.ResourceCPU: resource.MustParse("1"),
corev1.ResourceMemory: resource.MustParse("1Gi"),
corev1.ResourceName("nvidia.com/gpu"): resource.MustParse("1"),
},
},
Ports: []corev1.ContainerPort{
{
Protocol: corev1.ProtocolTCP,
Name: commonconsts.DynamoContainerPortName,
ContainerPort: int32(commonconsts.DynamoServicePort),
},
},
VolumeMounts: []corev1.VolumeMount{
{
Name: commonconsts.KubeValueNameSharedMemory,
MountPath: "/dev/shm",
},
},
},
},
},
},
},
{
Name: "planner",
Labels: map[string]string{
commonconsts.KubeLabelMetricsEnabled: commonconsts.KubeLabelValueTrue,
commonconsts.KubeLabelDynamoSelector: "test-dynamo-graph-deployment-planner",
},
Annotations: map[string]string{},
Spec: grovev1alpha1.PodCliqueSpec{
RoleName: "planner",
Replicas: 2,
PodSpec: corev1.PodSpec{
Volumes: []corev1.Volume{
{
Name: "planner-pvc",
VolumeSource: corev1.VolumeSource{
PersistentVolumeClaim: &corev1.PersistentVolumeClaimVolumeSource{
ClaimName: "planner-pvc",
},
},
},
{
Name: "shared-memory",
VolumeSource: corev1.VolumeSource{
EmptyDir: &corev1.EmptyDirVolumeSource{
Medium: corev1.StorageMediumMemory,
SizeLimit: resource.NewQuantity(512*1024*1024, resource.BinarySI),
},
},
},
},
Containers: []corev1.Container{
{
Name: "main",
Image: "planner-image",
Command: []string{
"/bin/sh",
"-c",
"echo $PLANNER_ENV_1",
},
Args: []string{
"--planner-env-1",
"1",
},
EnvFrom: []corev1.EnvFromSource{
{
SecretRef: &corev1.SecretEnvSource{
LocalObjectReference: corev1.LocalObjectReference{
Name: "planner-secret",
},
},
},
},
LivenessProbe: &corev1.Probe{
ProbeHandler: corev1.ProbeHandler{
HTTPGet: &corev1.HTTPGetAction{
Path: "/health",
Port: intstr.FromInt(8080),
},
},
},
ReadinessProbe: &corev1.Probe{
ProbeHandler: corev1.ProbeHandler{
HTTPGet: &corev1.HTTPGetAction{
Path: "/ready",
Port: intstr.FromInt(8080),
},
},
},
Env: []corev1.EnvVar{
{
Name: "DYNAMO_POD_GANG_SET_REPLICAS",
Value: "1",
},
{
Name: "PLANNER_ENV_1",
Value: "2",
},
{
Name: "DYNAMO_PORT",
Value: fmt.Sprintf("%d", commonconsts.DynamoServicePort),
},
{
Name: "NATS_SERVER",
Value: "nats-address",
},
{
Name: "ETCD_ENDPOINTS",
Value: "etcd-address",
},
},
Resources: corev1.ResourceRequirements{
Requests: corev1.ResourceList{
corev1.ResourceCPU: resource.MustParse("2"),
corev1.ResourceMemory: resource.MustParse("2Gi"),
},
Limits: corev1.ResourceList{
corev1.ResourceCPU: resource.MustParse("2"),
corev1.ResourceMemory: resource.MustParse("2Gi"),
corev1.ResourceName("nvidia.com/gpu"): resource.MustParse("2"),
},
},
VolumeMounts: []corev1.VolumeMount{
{
Name: "planner-pvc",
MountPath: "/planner",
},
{
Name: "shared-memory",
MountPath: "/dev/shm",
},
},
Ports: []corev1.ContainerPort{
{
Protocol: corev1.ProtocolTCP,
Name: commonconsts.DynamoContainerPortName,
ContainerPort: int32(commonconsts.DynamoServicePort),
},
},
},
},
},
},
},
},
},
},
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := GenerateGrovePodGangSet(tt.args.ctx, tt.args.dynamoDeployment, tt.args.controllerConfig, nil)
if (err != nil) != tt.wantErr {
t.Errorf("GenerateGrovePodGangSet() error = %v, wantErr %v", err, tt.wantErr)
return
}
sort.Slice(got.Spec.Template.Cliques, func(i, j int) bool {
return got.Spec.Template.Cliques[i].Name < got.Spec.Template.Cliques[j].Name
})
sort.Slice(tt.want.Spec.Template.Cliques, func(i, j int) bool {
return tt.want.Spec.Template.Cliques[i].Name < tt.want.Spec.Template.Cliques[j].Name
})
if diff := cmp.Diff(got, tt.want); diff != "" {
t.Errorf("GenerateGrovePodGangSet() mismatch (-want +got):\n%s", diff)
}
})
}
}
// mockSecretsRetriever is a stateless test double for the secrets retriever
// argument of GeneratePodSpecForComponent. Its methods ignore their inputs and
// return empty results with a nil error.
type mockSecretsRetriever struct{}
// RetrieveImagePullSecrets is a stub: it ignores ctx and deployment and always
// reports an empty (non-nil) list of image pull secrets with no error.
func (m *mockSecretsRetriever) RetrieveImagePullSecrets(ctx context.Context, deployment *v1alpha1.DynamoGraphDeployment) ([]corev1.LocalObjectReference, error) {
	noSecrets := make([]corev1.LocalObjectReference, 0)
	return noSecrets, nil
}
// GetSecrets is a stub: it ignores namespace and registry and always reports
// an empty (non-nil) list of secret names with no error.
func (m *mockSecretsRetriever) GetSecrets(namespace, registry string) ([]string, error) {
	noSecrets := make([]string, 0)
	return noSecrets, nil
}
// TestGeneratePodSpecForComponent_SGLang exercises GeneratePodSpecForComponent
// with the SGLang backend framework and the Grove multinode deployment type,
// across the main, leader and worker roles.
//
// For every case it checks that generation succeeds (or fails when
// expectError is set), that at least one container is produced, that the first
// container's command and args (joined with single spaces) contain every
// string in expectContains and none in expectNotContains, and that the first
// container is named "main".
func TestGeneratePodSpecForComponent_SGLang(t *testing.T) {
	secretsRetriever := &mockSecretsRetriever{}
	dynamoDeployment := &v1alpha1.DynamoGraphDeployment{
		ObjectMeta: metav1.ObjectMeta{
			Name:      "test-deployment",
			Namespace: "default",
		},
	}
	controllerConfig := controller_common.Config{}

	tests := []struct {
		name              string
		component         *v1alpha1.DynamoComponentDeploymentOverridesSpec
		backendFramework  BackendFramework
		role              Role
		numberOfNodes     int32
		expectError       bool
		expectContains    []string // substrings that must appear in "command args"
		expectNotContains []string // substrings that must not appear in "command args"
	}{
		{
			name: "SGLang single node worker",
			component: &v1alpha1.DynamoComponentDeploymentOverridesSpec{
				DynamoComponentDeploymentSharedSpec: v1alpha1.DynamoComponentDeploymentSharedSpec{
					ComponentType: commonconsts.ComponentTypeWorker,
					ExtraPodSpec: &common.ExtraPodSpec{
						MainContainer: &corev1.Container{
							Args: []string{"python3", "-m", "dynamo.sglang.worker"},
						},
					},
				},
			},
			backendFramework:  BackendFrameworkSGLang,
			role:              RoleMain,
			numberOfNodes:     1,
			expectError:       false,
			expectContains:    []string{"python3", "-m", "dynamo.sglang.worker"},
			expectNotContains: []string{"dist-init-addr", "nnodes", "tp-size"},
		},
		{
			name: "SGLang multinode leader",
			component: &v1alpha1.DynamoComponentDeploymentOverridesSpec{
				DynamoComponentDeploymentSharedSpec: v1alpha1.DynamoComponentDeploymentSharedSpec{
					ComponentType: commonconsts.ComponentTypeWorker,
					ExtraPodSpec: &common.ExtraPodSpec{
						MainContainer: &corev1.Container{
							Args: []string{"python3", "-m", "dynamo.sglang.worker"},
						},
					},
				},
			},
			backendFramework: BackendFrameworkSGLang,
			role:             RoleLeader,
			numberOfNodes:    3,
			expectError:      false,
			expectContains:   []string{"python3", "-m", "dynamo.sglang.worker", "dist-init-addr", "nnodes", "node-rank"},
		},
		{
			name: "SGLang multinode worker",
			component: &v1alpha1.DynamoComponentDeploymentOverridesSpec{
				DynamoComponentDeploymentSharedSpec: v1alpha1.DynamoComponentDeploymentSharedSpec{
					ComponentType: commonconsts.ComponentTypeWorker,
					ExtraPodSpec: &common.ExtraPodSpec{
						MainContainer: &corev1.Container{
							Args: []string{"python3", "-m", "dynamo.sglang.worker"},
						},
					},
				},
			},
			backendFramework: BackendFrameworkSGLang,
			role:             RoleWorker,
			numberOfNodes:    3,
			expectError:      false,
			expectContains:   []string{"python3", "-m", "dynamo.sglang.worker", "dist-init-addr", "nnodes", "node-rank"},
		},
		{
			// No substring expectations here: this case only checks that
			// generation succeeds and the first container is named "main".
			name: "SGLang with user command override",
			component: &v1alpha1.DynamoComponentDeploymentOverridesSpec{
				DynamoComponentDeploymentSharedSpec: v1alpha1.DynamoComponentDeploymentSharedSpec{
					ComponentType: commonconsts.ComponentTypeWorker,
					ExtraPodSpec: &common.ExtraPodSpec{
						MainContainer: &corev1.Container{
							Command: []string{"custom", "command"},
						},
					},
				},
			},
			backendFramework: BackendFrameworkSGLang,
			role:             RoleMain,
			numberOfNodes:    1,
			expectError:      false,
			expectContains:   []string{},
		},
		{
			name: "SGLang with resources",
			component: &v1alpha1.DynamoComponentDeploymentOverridesSpec{
				DynamoComponentDeploymentSharedSpec: v1alpha1.DynamoComponentDeploymentSharedSpec{
					ComponentType: commonconsts.ComponentTypeWorker,
					ExtraPodSpec: &common.ExtraPodSpec{
						MainContainer: &corev1.Container{
							Args: []string{"python3", "-m", "dynamo.sglang.worker"},
							Resources: corev1.ResourceRequirements{
								Requests: corev1.ResourceList{
									corev1.ResourceCPU:    resource.MustParse("1"),
									corev1.ResourceMemory: resource.MustParse("2Gi"),
								},
							},
						},
					},
				},
			},
			backendFramework: BackendFrameworkSGLang,
			role:             RoleMain,
			numberOfNodes:    1,
			expectError:      false,
			expectContains:   []string{"python3 -m dynamo.sglang.worker"},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			podSpec, err := GeneratePodSpecForComponent(
				tt.component,
				tt.backendFramework,
				secretsRetriever,
				dynamoDeployment,
				tt.role,
				tt.numberOfNodes,
				controllerConfig,
				commonconsts.MultinodeDeploymentTypeGrove,
				"worker",
			)

			if tt.expectError {
				if err == nil {
					t.Errorf("GeneratePodSpecForComponent() expected error, got nil")
				}
				return
			}
			// Nothing below is meaningful without a pod spec, so abort the
			// subtest on the first hard failure.
			if err != nil {
				t.Fatalf("GeneratePodSpecForComponent() unexpected error: %v", err)
			}
			if len(podSpec.Containers) == 0 {
				t.Fatalf("GeneratePodSpecForComponent() no containers in podSpec")
			}

			container := podSpec.Containers[0]

			// Combine command and args in freshly allocated storage. Appending
			// Args directly onto container.Command could write into spare
			// capacity of the backing array that the pod spec still references.
			allArgs := make([]string, 0, len(container.Command)+len(container.Args))
			allArgs = append(allArgs, container.Command...)
			allArgs = append(allArgs, container.Args...)
			allArgsStr := strings.Join(allArgs, " ")

			for _, expected := range tt.expectContains {
				if !strings.Contains(allArgsStr, expected) {
					t.Errorf("GeneratePodSpecForComponent() args = %v, should contain %s", allArgs, expected)
				}
			}

			for _, notExpected := range tt.expectNotContains {
				if strings.Contains(allArgsStr, notExpected) {
					t.Errorf("GeneratePodSpecForComponent() args = %v, should NOT contain %s", allArgs, notExpected)
				}
			}

			// The generated main container is expected to be named "main".
			if container.Name != "main" {
				t.Errorf("GeneratePodSpecForComponent() container name = %s, want main", container.Name)
			}
		})
	}
}
// TestGeneratePodSpecForComponent_VLLM calls GeneratePodSpecForComponent with
// the vLLM backend and the Grove multinode deployment type, then inspects the
// command line of the first container in the returned pod spec.
//
// Each case lists substrings that must (expectContains) or must not
// (expectNotContains) appear in the space-joined Command+Args of that
// container.
func TestGeneratePodSpecForComponent_VLLM(t *testing.T) {
	secretsRetriever := &mockSecretsRetriever{}
	dynamoDeployment := &v1alpha1.DynamoGraphDeployment{
		ObjectMeta: metav1.ObjectMeta{
			Name:      "test-deployment",
			Namespace: "default",
		},
	}
	controllerConfig := controller_common.Config{}

	tests := []struct {
		name              string
		component         *v1alpha1.DynamoComponentDeploymentOverridesSpec
		backendFramework  BackendFramework
		role              Role
		numberOfNodes     int32
		expectError       bool
		expectContains    []string
		expectNotContains []string
	}{
		{
			name: "VLLM single node worker",
			component: &v1alpha1.DynamoComponentDeploymentOverridesSpec{
				DynamoComponentDeploymentSharedSpec: v1alpha1.DynamoComponentDeploymentSharedSpec{
					ComponentType: commonconsts.ComponentTypeWorker,
					ExtraPodSpec: &common.ExtraPodSpec{
						MainContainer: &corev1.Container{
							Args: []string{"python3", "-m", "dynamo.vllm"},
						},
					},
				},
			},
			backendFramework:  BackendFrameworkVLLM,
			role:              RoleMain,
			numberOfNodes:     1,
			expectError:       false,
			expectContains:    []string{"python3", "-m", "dynamo.vllm"},
			expectNotContains: []string{"ray start"},
		},
		{
			name: "VLLM multinode leader",
			component: &v1alpha1.DynamoComponentDeploymentOverridesSpec{
				DynamoComponentDeploymentSharedSpec: v1alpha1.DynamoComponentDeploymentSharedSpec{
					ComponentType: commonconsts.ComponentTypeWorker,
					ExtraPodSpec: &common.ExtraPodSpec{
						MainContainer: &corev1.Container{
							Args: []string{"python3", "-m", "dynamo.vllm"},
						},
					},
				},
			},
			backendFramework: BackendFrameworkVLLM,
			role:             RoleLeader,
			numberOfNodes:    3,
			expectError:      false,
			expectContains:   []string{"ray start --head --port=6379", "python3", "-m", "dynamo.vllm"},
		},
		{
			// No user-supplied container here: the expectations only look
			// at the generated ray worker command.
			name: "VLLM multinode worker",
			component: &v1alpha1.DynamoComponentDeploymentOverridesSpec{
				DynamoComponentDeploymentSharedSpec: v1alpha1.DynamoComponentDeploymentSharedSpec{
					ComponentType: commonconsts.ComponentTypeWorker,
				},
			},
			backendFramework:  BackendFrameworkVLLM,
			role:              RoleWorker,
			numberOfNodes:     3,
			expectError:       false,
			expectContains:    []string{"ray start --address=${GROVE_PCSG_NAME}-${GROVE_PCSG_INDEX}-worker-ldr-0.${GROVE_HEADLESS_SERVICE}:6379 --block"},
			expectNotContains: []string{"python3 -m dynamo.vllm"},
		},
		{
			name: "VLLM worker single node",
			component: &v1alpha1.DynamoComponentDeploymentOverridesSpec{
				DynamoComponentDeploymentSharedSpec: v1alpha1.DynamoComponentDeploymentSharedSpec{
					ComponentType: commonconsts.ComponentTypeWorker,
					ExtraPodSpec: &common.ExtraPodSpec{
						MainContainer: &corev1.Container{
							Args: []string{"python3", "-m", "dynamo.vllm", "--is-prefill-worker"},
						},
					},
				},
			},
			backendFramework:  BackendFrameworkVLLM,
			role:              RoleMain,
			numberOfNodes:     1,
			expectError:       false,
			expectContains:    []string{"python3", "-m", "dynamo.vllm", "--is-prefill-worker"},
			expectNotContains: []string{"ray start"},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			podSpec, err := GeneratePodSpecForComponent(
				tt.component,
				tt.backendFramework,
				secretsRetriever,
				dynamoDeployment,
				tt.role,
				tt.numberOfNodes,
				controllerConfig,
				commonconsts.MultinodeDeploymentTypeGrove,
				"worker", // presumably the service name; it shows up as "worker-ldr-0" in the expected ray address
			)
			if tt.expectError {
				if err == nil {
					t.Errorf("GeneratePodSpecForComponent() expected error, got nil")
				}
				return
			}
			if err != nil {
				t.Fatalf("GeneratePodSpecForComponent() unexpected error: %v", err)
			}

			// Nothing below makes sense without a container to inspect.
			if len(podSpec.Containers) == 0 {
				t.Fatalf("GeneratePodSpecForComponent() no containers in podSpec")
			}
			container := podSpec.Containers[0]

			// Build the combined command line in a fresh slice. Appending
			// directly onto container.Command could write into its backing
			// array when it has spare capacity.
			allArgs := make([]string, 0, len(container.Command)+len(container.Args))
			allArgs = append(allArgs, container.Command...)
			allArgs = append(allArgs, container.Args...)
			allArgsStr := strings.Join(allArgs, " ")

			for _, expected := range tt.expectContains {
				if !strings.Contains(allArgsStr, expected) {
					t.Errorf("GeneratePodSpecForComponent() args = %v, should contain %s", allArgs, expected)
				}
			}
			for _, notExpected := range tt.expectNotContains {
				if strings.Contains(allArgsStr, notExpected) {
					t.Errorf("GeneratePodSpecForComponent() args = %v, should NOT contain %s", allArgs, notExpected)
				}
			}
		})
	}
}
func TestGeneratePodSpecForComponent_UnsupportedBackend(t *testing.T) {
secretsRetriever := &mockSecretsRetriever{}
dynamoDeployment := &v1alpha1.DynamoGraphDeployment{
ObjectMeta: metav1.ObjectMeta{
Name: "test-deployment",
Namespace: "default",
},
}
controllerConfig := controller_common.Config{}
component := &v1alpha1.DynamoComponentDeploymentOverridesSpec{
DynamoComponentDeploymentSharedSpec: v1alpha1.DynamoComponentDeploymentSharedSpec{
ComponentType: commonconsts.ComponentTypeWorker,
},
}
tests := []struct {
name string
backendFramework BackendFramework
expectError bool
errorContains string
}{
{
name: "TRTLLM backend implemented",
backendFramework: BackendFrameworkTRTLLM,
expectError: false,
},
{
name: "unknown backend",
backendFramework: BackendFramework("unknown"),
expectError: true,
errorContains: "unsupported backend framework",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := GeneratePodSpecForComponent(
component,
tt.backendFramework,
secretsRetriever,
dynamoDeployment,
RoleMain,
1,
controllerConfig,
commonconsts.MultinodeDeploymentTypeGrove,
"worker",
)
if tt.expectError {
if err == nil {
t.Errorf("GeneratePodSpecForComponent() expected error, got nil")
return
}
if !strings.Contains(err.Error(), tt.errorContains) {
t.Errorf("GeneratePodSpecForComponent() error = %v, should contain %s", err, tt.errorContains)
}
} else {
if err != nil {
t.Errorf("GeneratePodSpecForComponent() unexpected error: %v", err)
}
}
})
}
}
// TestMergeContainerCommand verifies that mergeContainerCommand returns the
// user command when one is given and falls back to the default command when
// the user command is nil or empty.
func TestMergeContainerCommand(t *testing.T) {
	cases := []struct {
		name     string
		fallback []string
		override []string
		want     []string
	}{
		{
			name:     "user command overrides default",
			fallback: []string{"python", "default.py"},
			override: []string{"python", "custom.py"},
			want:     []string{"python", "custom.py"},
		},
		{
			name:     "empty user command returns default",
			fallback: []string{"python", "default.py"},
			override: []string{},
			want:     []string{"python", "default.py"},
		},
		{
			name:     "nil user command returns default",
			fallback: []string{"python", "default.py"},
			override: nil,
			want:     []string{"python", "default.py"},
		},
		{
			// reflect.DeepEqual distinguishes nil from an empty slice, so
			// the expectation here is deliberately the empty (non-nil) slice.
			name:     "both empty returns empty",
			fallback: []string{},
			override: []string{},
			want:     []string{},
		},
		{
			name:     "default empty user provided",
			fallback: []string{},
			override: []string{"python", "user.py"},
			want:     []string{"python", "user.py"},
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := mergeContainerCommand(tc.fallback, tc.override)
			if reflect.DeepEqual(got, tc.want) {
				return
			}
			t.Errorf("mergeContainerCommand() = %v, want %v", got, tc.want)
		})
	}
}
// TestExpandRolesForService checks the role list that expandRolesForService
// produces for a service name, replica count and node count. Per the
// expectations below, node counts of 0 or 1 yield a single main role carrying
// the service replicas, and larger counts yield one "-ldr" leader plus a
// "-wkr" worker role with numberOfNodes-1 replicas.
func TestExpandRolesForService(t *testing.T) {
	type testCase struct {
		name            string
		serviceName     string
		numberOfNodes   int32
		serviceReplicas int32
		expected        []ServiceRole
	}
	cases := []testCase{
		{
			name:            "single node",
			serviceName:     "test-service",
			numberOfNodes:   1,
			serviceReplicas: 2,
			expected: []ServiceRole{
				{Name: "test-service", Role: RoleMain, Replicas: 2},
			},
		},
		{
			name:          "multinode 2 nodes",
			serviceName:   "test-service",
			numberOfNodes: 2,
			expected: []ServiceRole{
				{Name: "test-service-ldr", Role: RoleLeader, Replicas: 1},
				{Name: "test-service-wkr", Role: RoleWorker, Replicas: 1},
			},
		},
		{
			name:          "multinode 5 nodes",
			serviceName:   "test-service",
			numberOfNodes: 5,
			expected: []ServiceRole{
				{Name: "test-service-ldr", Role: RoleLeader, Replicas: 1},
				{Name: "test-service-wkr", Role: RoleWorker, Replicas: 4},
			},
		},
		{
			name:            "zero nodes should return main",
			serviceName:     "test-service",
			numberOfNodes:   0,
			serviceReplicas: 1,
			expected: []ServiceRole{
				{Name: "test-service", Role: RoleMain, Replicas: 1},
			},
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			// The function takes the replica count by pointer.
			replicas := tc.serviceReplicas
			got := expandRolesForService(tc.serviceName, &replicas, tc.numberOfNodes)
			if reflect.DeepEqual(got, tc.expected) {
				return
			}
			t.Errorf("expandRolesForService() = %v, want %v", got, tc.expected)
		})
	}
}
func TestRoleEnum(t *testing.T) {
// Test that role constants are defined correctly
if RoleLeader != "leader" {
t.Errorf("RoleLeader = %v, want \"leader\"", RoleLeader)
}
if RoleWorker != "worker" {
t.Errorf("RoleWorker = %v, want \"worker\"", RoleWorker)
}
if RoleMain != "main" {
t.Errorf("RoleMain = %v, want \"main\"", RoleMain)
}
// Test that roles can be compared
roles := []Role{RoleLeader, RoleWorker, RoleMain}
for _, role := range roles {
switch role {
case RoleLeader, RoleWorker, RoleMain:
// Expected
default:
t.Errorf("Unexpected role value: %v", role)
}
}
}
func TestBackendFrameworkEnum(t *testing.T) {
// Test that backend framework constants are defined correctly
if BackendFrameworkSGLang != "sglang" {
t.Errorf("BackendFrameworkSGLang = %v, want \"sglang\"", BackendFrameworkSGLang)
}
if BackendFrameworkVLLM != "vllm" {
t.Errorf("BackendFrameworkVLLM = %v, want \"vllm\"", BackendFrameworkVLLM)
}
if BackendFrameworkTRTLLM != "trtllm" {
t.Errorf("BackendFrameworkTRTLLM = %v, want \"trtllm\"", BackendFrameworkTRTLLM)
}
// Test that frameworks can be compared
frameworks := []BackendFramework{BackendFrameworkSGLang, BackendFrameworkVLLM, BackendFrameworkTRTLLM}
for _, framework := range frameworks {
switch framework {
case BackendFrameworkSGLang, BackendFrameworkVLLM, BackendFrameworkTRTLLM:
// Expected
default:
t.Errorf("Unexpected framework value: %v", framework)
}
}
}
// TestServiceRoleStruct builds a ServiceRole literal and reads back each of
// its three fields (Name, Role, Replicas).
func TestServiceRoleStruct(t *testing.T) {
	svc := ServiceRole{
		Replicas: 3,
		Role:     RoleLeader,
		Name:     "test-service",
	}

	if got := svc.Name; got != "test-service" {
		t.Errorf("ServiceRole.Name = %v, want \"test-service\"", got)
	}
	if got := svc.Role; got != RoleLeader {
		t.Errorf("ServiceRole.Role = %v, want %v", got, RoleLeader)
	}
	if got := svc.Replicas; got != 3 {
		t.Errorf("ServiceRole.Replicas = %v, want 3", got)
	}
}
// TestDetectBackendFrameworkFromArgs feeds container command/args pairs to
// detectBackendFrameworkFromArgs and checks which backend framework is
// reported. Inputs that mention no backend, or more than one, must produce an
// error.
func TestDetectBackendFrameworkFromArgs(t *testing.T) {
	type testCase struct {
		name    string
		command []string
		args    []string
		want    BackendFramework
		wantErr bool
	}
	cases := []testCase{
		{
			name:    "detect VLLM from args",
			command: []string{"/bin/sh", "-c"},
			args:    []string{"python -m dynamo.vllm.worker --model test"},
			want:    BackendFrameworkVLLM,
		},
		{
			name:    "detect SGLang from args",
			command: []string{"/bin/sh", "-c"},
			args:    []string{"python -m dynamo.sglang.worker --model test"},
			want:    BackendFrameworkSGLang,
		},
		{
			name:    "detect TRTLLM from args",
			command: []string{"/bin/sh", "-c"},
			args:    []string{"python -m dynamo.trtllm.worker --model test"},
			want:    BackendFrameworkTRTLLM,
		},
		{
			name:    "detect from complex command with pipes",
			command: []string{},
			args:    []string{"echo start && python -m dynamo.vllm.worker --model test | tee /tmp/log"},
			want:    BackendFrameworkVLLM,
		},
		{
			name:    "detect from python3.11",
			command: []string{},
			args:    []string{"python3.11 -m dynamo.sglang.decode_worker"},
			want:    BackendFrameworkSGLang,
		},
		{
			name:    "no backend detected",
			command: []string{"/bin/sh", "-c"},
			args:    []string{"echo hello world"},
			wantErr: true,
		},
		{
			name:    "multiple backends detected",
			command: []string{},
			args:    []string{"python -m dynamo.vllm.worker && python -m dynamo.sglang.worker"},
			wantErr: true,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got, err := detectBackendFrameworkFromArgs(tc.command, tc.args)
			switch {
			case tc.wantErr:
				// Only the presence of an error is checked, not its text.
				if err == nil {
					t.Errorf("detectBackendFrameworkFromArgs() expected error, got none")
				}
			case err != nil:
				t.Errorf("detectBackendFrameworkFromArgs() unexpected error: %v", err)
			case got != tc.want:
				t.Errorf("detectBackendFrameworkFromArgs() = %v, want %v", got, tc.want)
			}
		})
	}
}
// TestDetermineBackendFramework covers determineBackendFramework, which
// combines the component type, the container command/args and an optional
// explicitly configured framework. The cases cover: non-worker components
// (noop), detection from args, explicit-only configuration, agreement and
// conflict between the two, and the error paths when nothing can be determined.
func TestDetermineBackendFramework(t *testing.T) {
	type scenario struct {
		name          string
		componentType string
		command       []string
		args          []string
		explicit      string
		want          BackendFramework
		wantErr       bool
		errSubstring  string
	}
	scenarios := []scenario{
		{
			name:          "non-worker component returns noop",
			componentType: "main",
			command:       []string{"/bin/sh", "-c"},
			args:          []string{"echo hello world"},
			want:          BackendFrameworkNoop,
		},
		{
			name:          "worker with VLLM detection",
			componentType: "worker",
			command:       []string{},
			args:          []string{"python -m dynamo.vllm.worker --model test"},
			want:          BackendFrameworkVLLM,
		},
		{
			name:          "worker with explicit framework only",
			componentType: "worker",
			explicit:      "sglang",
			want:          BackendFrameworkSGLang,
		},
		{
			name:          "worker with detected matching explicit",
			componentType: "worker",
			args:          []string{"python -m dynamo.sglang.worker"},
			explicit:      "sglang",
			want:          BackendFrameworkSGLang,
		},
		{
			name:          "worker with detected conflicting explicit",
			componentType: "worker",
			args:          []string{"python -m dynamo.vllm.worker"},
			explicit:      "sglang",
			wantErr:       true,
			errSubstring:  "backend framework mismatch",
		},
		{
			name:          "worker with no detection, no explicit - returns error",
			componentType: "worker",
			wantErr:       true,
			errSubstring:  "backend framework must be specified explicitly or detectable from command/args",
		},
		{
			name:          "worker with detection failure, no explicit - returns error",
			componentType: "worker",
			args:          []string{"echo hello world"},
			wantErr:       true,
			errSubstring:  "could not determine backend framework",
		},
	}

	for _, sc := range scenarios {
		t.Run(sc.name, func(t *testing.T) {
			got, err := determineBackendFramework(sc.componentType, sc.command, sc.args, sc.explicit)

			switch {
			case sc.wantErr && err == nil:
				t.Errorf("determineBackendFramework() expected error, got none")
			case sc.wantErr:
				// An empty errSubstring means any error is acceptable.
				if sc.errSubstring != "" && !strings.Contains(err.Error(), sc.errSubstring) {
					t.Errorf("determineBackendFramework() error = %v, should contain %q", err, sc.errSubstring)
				}
			case err != nil:
				t.Errorf("determineBackendFramework() unexpected error: %v", err)
			case got != sc.want:
				t.Errorf("determineBackendFramework() = %v, want %v", got, sc.want)
			}
		})
	}
}
// TestGetBackendFrameworkFromComponent drives getBackendFrameworkFromComponent
// with component specs and deployments, covering detection from the main
// container args, the deployment-level BackendFramework setting, agreement and
// conflict between the two, non-worker components, and the error paths.
func TestGetBackendFrameworkFromComponent(t *testing.T) {
	// workerSpec builds a worker component. When args are given they become
	// the main container's Args; otherwise no ExtraPodSpec is set at all.
	workerSpec := func(args ...string) *v1alpha1.DynamoComponentDeploymentOverridesSpec {
		spec := &v1alpha1.DynamoComponentDeploymentOverridesSpec{
			DynamoComponentDeploymentSharedSpec: v1alpha1.DynamoComponentDeploymentSharedSpec{
				ComponentType: "worker",
			},
		}
		if len(args) > 0 {
			spec.DynamoComponentDeploymentSharedSpec.ExtraPodSpec = &common.ExtraPodSpec{
				MainContainer: &corev1.Container{
					Args: args,
				},
			}
		}
		return spec
	}
	// sglangDeployment builds a deployment that explicitly selects SGLang.
	sglangDeployment := func() *v1alpha1.DynamoGraphDeployment {
		return &v1alpha1.DynamoGraphDeployment{
			Spec: v1alpha1.DynamoGraphDeploymentSpec{
				BackendFramework: "sglang",
			},
		}
	}

	cases := []struct {
		name         string
		component    *v1alpha1.DynamoComponentDeploymentOverridesSpec
		deployment   *v1alpha1.DynamoGraphDeployment
		want         BackendFramework
		wantErr      bool
		errSubstring string
	}{
		{
			name:       "detect from args - VLLM",
			component:  workerSpec("python -m dynamo.vllm.worker --model test"),
			deployment: &v1alpha1.DynamoGraphDeployment{},
			want:       BackendFrameworkVLLM,
		},
		{
			name:       "explicit framework only",
			component:  workerSpec(),
			deployment: sglangDeployment(),
			want:       BackendFrameworkSGLang,
		},
		{
			name:       "detected matches explicit",
			component:  workerSpec("python -m dynamo.sglang.worker"),
			deployment: sglangDeployment(),
			want:       BackendFrameworkSGLang,
		},
		{
			name:         "detected conflicts with explicit",
			component:    workerSpec("python -m dynamo.vllm.worker"),
			deployment:   sglangDeployment(),
			wantErr:      true,
			errSubstring: "backend framework mismatch",
		},
		{
			name: "non-worker component returns noop",
			component: &v1alpha1.DynamoComponentDeploymentOverridesSpec{
				DynamoComponentDeploymentSharedSpec: v1alpha1.DynamoComponentDeploymentSharedSpec{
					ComponentType: "main", // Frontend component
				},
			},
			deployment: &v1alpha1.DynamoGraphDeployment{},
			want:       BackendFrameworkNoop,
		},
		{
			name:         "worker with no detection, no explicit - returns error",
			component:    workerSpec(),
			deployment:   &v1alpha1.DynamoGraphDeployment{},
			wantErr:      true,
			errSubstring: "backend framework must be specified explicitly or detectable from command/args",
		},
		{
			name:         "worker with detection failure, no explicit - returns error",
			component:    workerSpec("echo hello world"),
			deployment:   &v1alpha1.DynamoGraphDeployment{},
			wantErr:      true,
			errSubstring: "could not determine backend framework",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got, err := getBackendFrameworkFromComponent(tc.component, tc.deployment)

			switch {
			case tc.wantErr && err == nil:
				t.Errorf("getBackendFrameworkFromComponent() expected error, got none")
			case tc.wantErr:
				// An empty errSubstring means any error is acceptable.
				if tc.errSubstring != "" && !strings.Contains(err.Error(), tc.errSubstring) {
					t.Errorf("getBackendFrameworkFromComponent() error = %v, should contain %q", err, tc.errSubstring)
				}
			case err != nil:
				t.Errorf("getBackendFrameworkFromComponent() unexpected error: %v", err)
			case got != tc.want:
				t.Errorf("getBackendFrameworkFromComponent() = %v, want %v", got, tc.want)
			}
		})
	}
}
// TestApplyCliqueStartupDependencies builds a PodGangSet with one clique per
// role, runs applyCliqueStartupDependencies on it, and checks both the
// template's StartupType and every clique's StartsAfter list.
//
// Expectations encoded below: for multinode vLLM and SGLang the worker clique
// starts after the leader; for multinode TRT-LLM the leader starts after the
// worker; single-node and noop-backend inputs leave everything untouched.
func TestApplyCliqueStartupDependencies(t *testing.T) {
	// leaderAndWorkers returns a fresh leader/worker role pair so that no
	// two cases share a backing slice.
	leaderAndWorkers := func() []ServiceRole {
		return []ServiceRole{
			{Name: "service-ldr", Role: RoleLeader, Replicas: 1},
			{Name: "service-wkr", Role: RoleWorker, Replicas: 2},
		}
	}

	type testCase struct {
		name              string
		roles             []ServiceRole
		backendFramework  BackendFramework
		numberOfNodes     int32
		expectedDeps      map[string][]string // clique name -> expected StartsAfter dependencies
		expectStartupType bool
	}
	cases := []testCase{
		{
			name:             "vllm_multinode_applies_dependencies",
			roles:            leaderAndWorkers(),
			backendFramework: BackendFrameworkVLLM,
			numberOfNodes:    3,
			expectedDeps: map[string][]string{
				"service-ldr": nil,
				"service-wkr": {"service-ldr"},
			},
			expectStartupType: true,
		},
		{
			name:             "sglang_multinode_applies_dependencies",
			roles:            leaderAndWorkers(),
			backendFramework: BackendFrameworkSGLang,
			numberOfNodes:    3,
			expectedDeps: map[string][]string{
				"service-ldr": nil,
				"service-wkr": {"service-ldr"},
			},
			expectStartupType: true,
		},
		{
			name:             "trtllm_multinode_applies_dependencies",
			roles:            leaderAndWorkers(),
			backendFramework: BackendFrameworkTRTLLM,
			numberOfNodes:    3,
			expectedDeps: map[string][]string{
				"service-ldr": {"service-wkr"},
				"service-wkr": nil,
			},
			expectStartupType: true,
		},
		{
			name: "single_node_no_dependencies",
			roles: []ServiceRole{
				{Name: "service", Role: RoleMain, Replicas: 1},
			},
			backendFramework: BackendFrameworkVLLM,
			numberOfNodes:    1,
			expectedDeps: map[string][]string{
				"service": nil,
			},
			expectStartupType: false,
		},
		{
			name:             "noop_backend_no_dependencies",
			roles:            leaderAndWorkers(),
			backendFramework: BackendFrameworkNoop,
			numberOfNodes:    3,
			expectedDeps: map[string][]string{
				"service-ldr": nil,
				"service-wkr": nil,
			},
			expectStartupType: false,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			// One clique per role, named after the lower-cased role name.
			gangSet := &grovev1alpha1.PodGangSet{
				Spec: grovev1alpha1.PodGangSetSpec{
					Template: grovev1alpha1.PodGangSetTemplateSpec{
						Cliques: []*grovev1alpha1.PodCliqueTemplateSpec{},
					},
				},
			}
			for _, role := range tc.roles {
				cliqueName := strings.ToLower(role.Name)
				gangSet.Spec.Template.Cliques = append(gangSet.Spec.Template.Cliques, &grovev1alpha1.PodCliqueTemplateSpec{
					Name: cliqueName,
					Spec: grovev1alpha1.PodCliqueSpec{
						RoleName: cliqueName,
						Replicas: role.Replicas,
					},
				})
			}

			applyCliqueStartupDependencies(gangSet, tc.roles, tc.backendFramework, tc.numberOfNodes)

			// StartupType must be Explicit exactly when dependencies are expected.
			startup := gangSet.Spec.Template.StartupType
			switch {
			case tc.expectStartupType && (startup == nil || *startup != grovev1alpha1.CliqueStartupTypeExplicit):
				t.Errorf("Expected StartupType to be CliqueStartupTypeExplicit, got %v", startup)
			case !tc.expectStartupType && startup != nil:
				t.Errorf("Expected StartupType to be nil, got %v", *startup)
			}

			// Per-clique StartsAfter lists; DeepEqual keeps nil and empty distinct.
			for _, clique := range gangSet.Spec.Template.Cliques {
				want, known := tc.expectedDeps[clique.Name]
				if !known {
					t.Errorf("Unexpected clique %s", clique.Name)
					continue
				}
				if got := clique.Spec.StartsAfter; !reflect.DeepEqual(got, want) {
					t.Errorf("Clique %s: expected StartsAfter %v, got %v", clique.Name, want, got)
				}
			}
		})
	}
}
// TestGetCliqueStartupDependencies checks the dependency list returned by
// getCliqueStartupDependencies for a given role, backend framework, leader
// clique name and worker clique names.
//
// Expectations encoded below: vLLM/SGLang workers depend on the leader,
// TRT-LLM leaders depend on all workers, and every other combination
// (including missing leader or worker names) yields nil.
func TestGetCliqueStartupDependencies(t *testing.T) {
	type testCase struct {
		name             string
		role             Role
		backendFramework BackendFramework
		leader           string
		workers          []string
		want             []string
	}
	cases := []testCase{
		{
			name:             "vllm_worker_depends_on_leader",
			role:             RoleWorker,
			backendFramework: BackendFrameworkVLLM,
			leader:           "service-ldr",
			workers:          []string{"service-wkr"},
			want:             []string{"service-ldr"},
		},
		{
			name:             "vllm_leader_has_no_dependencies",
			role:             RoleLeader,
			backendFramework: BackendFrameworkVLLM,
			leader:           "service-ldr",
			workers:          []string{"service-wkr"},
			want:             nil,
		},
		{
			name:             "sglang_worker_depends_on_leader",
			role:             RoleWorker,
			backendFramework: BackendFrameworkSGLang,
			leader:           "service-ldr",
			workers:          []string{"service-wkr"},
			want:             []string{"service-ldr"},
		},
		{
			name:             "sglang_leader_has_no_dependencies",
			role:             RoleLeader,
			backendFramework: BackendFrameworkSGLang,
			leader:           "service-ldr",
			workers:          []string{"service-wkr"},
			want:             nil,
		},
		{
			name:             "trtllm_leader_depends_on_workers",
			role:             RoleLeader,
			backendFramework: BackendFrameworkTRTLLM,
			leader:           "service-ldr",
			workers:          []string{"service-wkr1", "service-wkr2"},
			want:             []string{"service-wkr1", "service-wkr2"},
		},
		{
			name:             "trtllm_worker_has_no_dependencies",
			role:             RoleWorker,
			backendFramework: BackendFrameworkTRTLLM,
			leader:           "service-ldr",
			workers:          []string{"service-wkr"},
			want:             nil,
		},
		{
			name:             "noop_backend_has_no_dependencies",
			role:             RoleWorker,
			backendFramework: BackendFrameworkNoop,
			leader:           "service-ldr",
			workers:          []string{"service-wkr"},
			want:             nil,
		},
		{
			name:             "main_role_has_no_dependencies",
			role:             RoleMain,
			backendFramework: BackendFrameworkVLLM,
			leader:           "",
			workers:          nil,
			want:             nil,
		},
		{
			name:             "worker_with_empty_leader_name",
			role:             RoleWorker,
			backendFramework: BackendFrameworkVLLM,
			leader:           "",
			workers:          []string{"service-wkr"},
			want:             nil,
		},
		{
			name:             "leader_with_empty_worker_names",
			role:             RoleLeader,
			backendFramework: BackendFrameworkTRTLLM,
			leader:           "service-ldr",
			workers:          nil,
			want:             nil,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := getCliqueStartupDependencies(tc.role, tc.backendFramework, tc.leader, tc.workers)
			// DeepEqual keeps nil and empty slices distinct, so "no
			// dependencies" must come back as nil.
			if reflect.DeepEqual(got, tc.want) {
				return
			}
			t.Errorf("getCliqueStartupDependencies() = %v, want %v", got, tc.want)
		})
	}
}
func TestGenerateGrovePodGangSet_StartsAfterDependencies(t *testing.T) {
secretsRetriever := &mockSecretsRetriever{}
tests := []struct {
name string
backendFramework string
expectedDeps map[string][]string // clique name -> expected StartsAfter dependencies
}{
{
name: "vllm_worker_starts_after_leader",
backendFramework: string(BackendFrameworkVLLM),
expectedDeps: map[string][]string{
"main-wkr": {"main-ldr"}, // worker starts after leader
"main-ldr": nil, // leader has no dependencies
},
},
{
name: "sglang_worker_starts_after_leader",
backendFramework: string(BackendFrameworkSGLang),
expectedDeps: map[string][]string{
"main-wkr": {"main-ldr"}, // worker starts after leader
"main-ldr": nil, // leader has no dependencies
},
},
{
name: "trtllm_leader_starts_after_worker",
backendFramework: string(BackendFrameworkTRTLLM),
expectedDeps: map[string][]string{
"main-ldr": {"main-wkr"}, // leader starts after worker
"main-wkr": nil, // worker has no dependencies
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
dynamoDeployment := &v1alpha1.DynamoGraphDeployment{
ObjectMeta: metav1.ObjectMeta{
Name: "test-deployment",
Namespace: "default",
},
Spec: v1alpha1.DynamoGraphDeploymentSpec{
BackendFramework: tt.backendFramework,
Services: map[string]*v1alpha1.DynamoComponentDeploymentOverridesSpec{
"main": {
DynamoComponentDeploymentSharedSpec: v1alpha1.DynamoComponentDeploymentSharedSpec{
ComponentType: "worker", // Must be worker to trigger backend detection
Replicas: ptr.To(int32(1)),
Resources: &common.Resources{
Requests: &common.ResourceItem{
GPU: "1", // 1 GPU per node
Nodes: "2", // Set to 2 nodes to trigger multinode
},
},
},
},
},
},
}
controllerConfig := controller_common.Config{
EtcdAddress: "etcd-address",
NatsAddress: "nats-address",
}
got, err := GenerateGrovePodGangSet(context.Background(), dynamoDeployment, controllerConfig, secretsRetriever)
if err != nil {
t.Errorf("GenerateGrovePodGangSet() error = %v", err)
return
}
// Verify that StartupType is set to Explicit
if got.Spec.Template.StartupType == nil || *got.Spec.Template.StartupType != grovev1alpha1.CliqueStartupTypeExplicit {
t.Errorf("Expected StartupType to be CliqueStartupTypeExplicit, got %v", got.Spec.Template.StartupType)
}
// Verify StartsAfter dependencies for each clique
cliqueMap := make(map[string]*grovev1alpha1.PodCliqueTemplateSpec)
for _, clique := range got.Spec.Template.Cliques {
cliqueMap[clique.Name] = clique
}
for cliqueName, expectedDeps := range tt.expectedDeps {
clique, exists := cliqueMap[cliqueName]
if !exists {
t.Errorf("Expected clique %s not found", cliqueName)
continue
}
if expectedDeps == nil {
if len(clique.Spec.StartsAfter) != 0 {
t.Errorf("Clique %s should have no StartsAfter dependencies, but has %v", cliqueName, clique.Spec.StartsAfter)
}
} else {
if len(clique.Spec.StartsAfter) != len(expectedDeps) {
t.Errorf("Clique %s expected %d StartsAfter dependencies, got %d", cliqueName, len(expectedDeps), len(clique.Spec.StartsAfter))
continue
}
for i, expectedDep := range expectedDeps {
if i >= len(clique.Spec.StartsAfter) || clique.Spec.StartsAfter[i] != expectedDep {
t.Errorf("Clique %s expected StartsAfter[%d] = %s, got %v", cliqueName, i, expectedDep, clique.Spec.StartsAfter)
}
}
}
}
})
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment