Unverified Commit dabd2267 authored by julienmancuso's avatar julienmancuso Committed by GitHub
Browse files

feat: add grove multinode support (#2269)

parent d51580a4
......@@ -16,5 +16,5 @@ apiVersion: v2
name: dynamo-crds
description: A Helm chart for dynamo CRDs
type: application
version: 0.4.0
version: 0.4.1
dependencies: []
\ No newline at end of file
......@@ -404,6 +404,12 @@ spec:
minReplicas:
type: integer
type: object
backendFramework:
enum:
- sglang
- vllm
- trtllm
type: string
componentType:
type: string
dynamoComponent:
......@@ -5039,6 +5045,8 @@ spec:
type: string
memory:
type: string
nodes:
type: string
type: object
requests:
properties:
......@@ -5052,6 +5060,8 @@ spec:
type: string
memory:
type: string
nodes:
type: string
type: object
type: object
runMode:
......
......@@ -44,6 +44,12 @@ spec:
type: object
spec:
properties:
backendFramework:
enum:
- sglang
- vllm
- trtllm
type: string
dynamoGraph:
type: string
envs:
......@@ -5094,6 +5100,8 @@ spec:
type: string
memory:
type: string
nodes:
type: string
type: object
requests:
properties:
......@@ -5107,6 +5115,8 @@ spec:
type: string
memory:
type: string
nodes:
type: string
type: object
type: object
runMode:
......
......@@ -26,6 +26,7 @@ type ResourceItem struct {
CPU string `json:"cpu,omitempty"`
Memory string `json:"memory,omitempty"`
GPU string `json:"gpu,omitempty"`
Nodes string `json:"nodes,omitempty"`
Custom map[string]string `json:"custom,omitempty"`
}
......
......@@ -42,6 +42,10 @@ type DynamoComponentDeploymentSpec struct {
// contains the tag of the DynamoComponent: for example, "my_package:MyService"
DynamoTag string `json:"dynamoTag,omitempty"`
// BackendFramework specifies the backend framework (e.g., "sglang", "vllm", "trtllm")
// +kubebuilder:validation:Enum=sglang;vllm;trtllm
BackendFramework string `json:"backendFramework,omitempty"`
DynamoComponentDeploymentSharedSpec `json:",inline"`
}
......@@ -110,6 +114,13 @@ type IngressSpec struct {
IngressControllerClassName *string `json:"ingressControllerClassName,omitempty"`
}
func (i *IngressSpec) IsVirtualServiceEnabled() bool {
if i == nil {
return false
}
return i.Enabled && i.UseVirtualService && i.VirtualServiceGateway != nil
}
// DynamoComponentDeploymentStatus defines the observed state of DynamoComponentDeployment
type DynamoComponentDeploymentStatus struct {
// INSERT ADDITIONAL STATUS FIELD - define observed state of cluster
......@@ -195,11 +206,3 @@ func (s *DynamoComponentDeployment) SetDynamoDeploymentConfig(config []byte) {
Value: string(config),
})
}
// GetImage returns the docker image of the DynamoComponent
func (s *DynamoComponentDeployment) GetImage() string {
if s.Spec.ExtraPodSpec != nil && s.Spec.ExtraPodSpec.MainContainer != nil {
return s.Spec.ExtraPodSpec.MainContainer.Image
}
return ""
}
......@@ -40,6 +40,9 @@ type DynamoGraphDeploymentSpec struct {
// Environment variables to be set in the deployment
// +kubebuilder:validation:Optional
Envs []corev1.EnvVar `json:"envs,omitempty"`
// BackendFramework specifies the backend framework (e.g., "sglang", "vllm", "trtllm")
// +kubebuilder:validation:Enum=sglang;vllm;trtllm
BackendFramework string `json:"backendFramework,omitempty"`
}
// DynamoGraphDeploymentStatus defines the observed state of DynamoGraphDeployment.
......
......@@ -404,6 +404,12 @@ spec:
minReplicas:
type: integer
type: object
backendFramework:
enum:
- sglang
- vllm
- trtllm
type: string
componentType:
type: string
dynamoComponent:
......@@ -5039,6 +5045,8 @@ spec:
type: string
memory:
type: string
nodes:
type: string
type: object
requests:
properties:
......@@ -5052,6 +5060,8 @@ spec:
type: string
memory:
type: string
nodes:
type: string
type: object
type: object
runMode:
......
......@@ -44,6 +44,12 @@ spec:
type: object
spec:
properties:
backendFramework:
enum:
- sglang
- vllm
- trtllm
type: string
dynamoGraph:
type: string
envs:
......@@ -5094,6 +5100,8 @@ spec:
type: string
memory:
type: string
nodes:
type: string
type: object
requests:
properties:
......@@ -5107,6 +5115,8 @@ spec:
type: string
memory:
type: string
nodes:
type: string
type: object
type: object
runMode:
......
......@@ -6,7 +6,7 @@ toolchain go1.24.3
require (
emperror.dev/errors v0.8.1
github.com/NVIDIA/grove/operator/api v0.0.0-20250717114148-daac6e53774f
github.com/NVIDIA/grove/operator/api v0.0.0-20250801123021-8b42bac59ef2
github.com/bsm/gomega v1.27.10
github.com/google/go-cmp v0.7.0
github.com/imdario/mergo v0.3.6
......
emperror.dev/errors v0.8.1 h1:UavXZ5cSX/4u9iyvH6aDcuGkVjeexUGJ7Ij7G4VfQT0=
emperror.dev/errors v0.8.1/go.mod h1:YcRvLPh626Ubn2xqtoprejnA5nFha+TJ+2vew48kWuE=
github.com/NVIDIA/grove/operator/api v0.0.0-20250717114148-daac6e53774f h1:2ePSNDm7/Tep8F99yCQVH8/vmn86L1cUzTbVlyNopmQ=
github.com/NVIDIA/grove/operator/api v0.0.0-20250717114148-daac6e53774f/go.mod h1:nJL33lsBe+9xCcZLYkNYg1wucE4hJfa4ZfHm1zamuG0=
github.com/NVIDIA/grove/operator/api v0.0.0-20250801123021-8b42bac59ef2 h1:JLOj0GiubP3VlR0okIbuqljvl+e2Vccnu6LX6wL34G0=
github.com/NVIDIA/grove/operator/api v0.0.0-20250801123021-8b42bac59ef2/go.mod h1:QlsR2wQLj9m/zVEqv5SsCPzyjN2ykYZ0r/NEnDf4WB4=
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
github.com/blang/semver/v4 v4.0.0 h1:1PFHFE6yCCTv8C1TeyNNarDzntLi7wMI5i/pzqYIsAM=
......
......@@ -15,6 +15,8 @@ const (
DynamoSystemPort = 9090
DynamoSystemPortName = "system"
MpiRunSshPort = 2222
EnvDynamoServicePort = "DYNAMO_PORT"
KubeLabelDynamoSelector = "nvidia.com/selector"
......@@ -47,4 +49,18 @@ const (
// Metrics related constants
KubeAnnotationEnableMetrics = "nvidia.com/enable-metrics" // User-provided annotation to control metrics
KubeLabelMetricsEnabled = "nvidia.com/metrics-enabled" // Controller-managed label for pod selection
KubeValueNameSharedMemory = "shared-memory"
// Grove multinode role suffixes
GroveRoleSuffixLeader = "ldr"
GroveRoleSuffixWorker = "wkr"
MpiRunSshSecretName = "mpi-run-ssh-secret"
)
type MultinodeDeploymentType string
const (
MultinodeDeploymentTypeGrove MultinodeDeploymentType = "grove"
MultinodeDeploymentTypeLWS MultinodeDeploymentType = "lws"
)
......@@ -22,11 +22,11 @@ package controller
import (
"context"
"fmt"
"maps"
"os"
"strconv"
"time"
"github.com/imdario/mergo"
appsv1 "k8s.io/api/apps/v1"
autoscalingv2 "k8s.io/api/autoscaling/v2"
corev1 "k8s.io/api/core/v1"
......@@ -34,9 +34,9 @@ import (
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"emperror.dev/errors"
dynamoCommon "github.com/ai-dynamo/dynamo/deploy/cloud/operator/api/dynamo/common"
"github.com/ai-dynamo/dynamo/deploy/cloud/operator/api/dynamo/schemas"
"github.com/ai-dynamo/dynamo/deploy/cloud/operator/api/v1alpha1"
"github.com/ai-dynamo/dynamo/deploy/cloud/operator/internal/consts"
commonconsts "github.com/ai-dynamo/dynamo/deploy/cloud/operator/internal/consts"
"github.com/ai-dynamo/dynamo/deploy/cloud/operator/internal/controller_common"
commonController "github.com/ai-dynamo/dynamo/deploy/cloud/operator/internal/controller_common"
......@@ -48,7 +48,6 @@ import (
"k8s.io/apimachinery/pkg/types"
"k8s.io/apimachinery/pkg/util/intstr"
"k8s.io/client-go/tools/record"
"k8s.io/utils/ptr"
ctrl "sigs.k8s.io/controller-runtime"
"sigs.k8s.io/controller-runtime/pkg/builder"
"sigs.k8s.io/controller-runtime/pkg/client"
......@@ -64,7 +63,6 @@ import (
const (
DefaultClusterName = "default"
DefaultServiceAccountName = "default"
KubeValueNameSharedMemory = "shared-memory"
KubeAnnotationDeploymentStrategy = "nvidia.com/deployment-strategy"
KubeAnnotationEnableStealingTrafficDebugMode = "nvidia.com/enable-stealing-traffic-debug-mode"
KubeAnnotationEnableDebugMode = "nvidia.com/enable-debug-mode"
......@@ -78,6 +76,7 @@ const (
KubeAnnotationLWSSize = "nvidia.com/lws-size"
DeploymentTypeStandard = "standard"
DeploymentTypeLeaderWorker = "leader-worker"
DeploymentTypeMultinodeGrove = "multinode-grove"
ComponentTypePlanner = "Planner"
)
......@@ -510,11 +509,7 @@ func (r *DynamoComponentDeploymentReconciler) generateLeaderPodTemplateSpec(ctx
return nil, errors.Wrap(err, "failed to generate leader pod template")
}
if labels != nil {
leaderPodTemplateSpec.ObjectMeta.Labels = labels
} else {
leaderPodTemplateSpec.ObjectMeta.Labels = make(map[string]string)
}
maps.Copy(leaderPodTemplateSpec.ObjectMeta.Labels, labels)
leaderPodTemplateSpec.ObjectMeta.Labels["role"] = "leader"
leaderPodTemplateSpec.ObjectMeta.Labels["instance-id"] = fmt.Sprintf("%d", instanceID)
delete(leaderPodTemplateSpec.ObjectMeta.Labels, commonconsts.KubeLabelDynamoSelector)
......@@ -556,11 +551,7 @@ func (r *DynamoComponentDeploymentReconciler) generateWorkerPodTemplateSpec(ctx
return nil, errors.Wrap(err, "failed to generate worker pod template")
}
if labels != nil {
workerPodTemplateSpec.ObjectMeta.Labels = labels
} else {
workerPodTemplateSpec.ObjectMeta.Labels = make(map[string]string)
}
maps.Copy(workerPodTemplateSpec.ObjectMeta.Labels, labels)
workerPodTemplateSpec.ObjectMeta.Labels["role"] = "worker"
workerPodTemplateSpec.ObjectMeta.Labels["instance-id"] = fmt.Sprintf("%d", instanceID)
delete(workerPodTemplateSpec.ObjectMeta.Labels, commonconsts.KubeLabelDynamoSelector)
......@@ -988,8 +979,7 @@ func (r *DynamoComponentDeploymentReconciler) generateVirtualService(ctx context
},
}
vsEnabled := opt.dynamoComponentDeployment.Spec.Ingress != nil && opt.dynamoComponentDeployment.Spec.Ingress.Enabled && opt.dynamoComponentDeployment.Spec.Ingress.UseVirtualService && opt.dynamoComponentDeployment.Spec.Ingress.VirtualServiceGateway != nil
if !vsEnabled {
if !opt.dynamoComponentDeployment.Spec.Ingress.IsVirtualServiceEnabled() {
log.Info("VirtualService is not enabled")
return vs, true, nil
}
......@@ -1231,8 +1221,6 @@ func (r *DynamoComponentDeploymentReconciler) generatePodTemplateSpec(ctx contex
kubeName := r.getKubeName(opt.dynamoComponentDeployment, opt.isStealingTrafficDebugModeEnabled)
containerPort := commonconsts.DynamoServicePort
resourceAnnotations := opt.dynamoComponentDeployment.Spec.Annotations
if resourceAnnotations == nil {
......@@ -1241,192 +1229,22 @@ func (r *DynamoComponentDeploymentReconciler) generatePodTemplateSpec(ctx contex
isDebugModeEnabled := checkIfIsDebugModeEnabled(resourceAnnotations)
defaultEnvs := []corev1.EnvVar{
{
Name: commonconsts.EnvDynamoServicePort,
Value: fmt.Sprintf("%d", containerPort),
},
}
if r.Config.NatsAddress != "" {
defaultEnvs = append(defaultEnvs, corev1.EnvVar{
Name: "NATS_SERVER",
Value: r.Config.NatsAddress,
})
}
if r.Config.EtcdAddress != "" {
defaultEnvs = append(defaultEnvs, corev1.EnvVar{
Name: "ETCD_ENDPOINTS",
Value: r.Config.EtcdAddress,
})
}
envs := dynamo.MergeEnvs(opt.dynamoComponentDeployment.Spec.Envs, defaultEnvs)
var livenessProbe *corev1.Probe
if opt.dynamoComponentDeployment.Spec.LivenessProbe != nil {
livenessProbe = opt.dynamoComponentDeployment.Spec.LivenessProbe
}
var readinessProbe *corev1.Probe
if opt.dynamoComponentDeployment.Spec.ReadinessProbe != nil {
readinessProbe = opt.dynamoComponentDeployment.Spec.ReadinessProbe
}
volumes := make([]corev1.Volume, 0)
volumeMounts := make([]corev1.VolumeMount, 0)
dynamoResources := opt.dynamoComponentDeployment.Spec.Resources
resources, err := getResourcesConfig(dynamoResources)
basePodSpec, err := dynamo.GenerateBasePodSpecForController(opt.dynamoComponentDeployment, r.DockerSecretRetriever, r.Config, dynamo.RoleMain, consts.MultinodeDeploymentTypeLWS)
if err != nil {
err = errors.Wrap(err, "failed to get resources config")
err = errors.Wrap(err, "failed to generate base pod spec")
return nil, err
}
sharedMemorySizeLimit := resource.MustParse("64Mi")
memoryLimit := resources.Limits[corev1.ResourceMemory]
if !memoryLimit.IsZero() {
sharedMemorySizeLimit.SetMilli(memoryLimit.MilliValue() / 2)
}
volumes = append(volumes, corev1.Volume{
Name: KubeValueNameSharedMemory,
VolumeSource: corev1.VolumeSource{
EmptyDir: &corev1.EmptyDirVolumeSource{
Medium: corev1.StorageMediumMemory,
SizeLimit: &sharedMemorySizeLimit,
},
},
})
volumeMounts = append(volumeMounts, corev1.VolumeMount{
Name: KubeValueNameSharedMemory,
MountPath: "/dev/shm",
})
if opt.dynamoComponentDeployment.Spec.PVC != nil {
volumes = append(volumes, corev1.Volume{
Name: getPvcName(opt.dynamoComponentDeployment, opt.dynamoComponentDeployment.Spec.PVC.Name),
VolumeSource: corev1.VolumeSource{
PersistentVolumeClaim: &corev1.PersistentVolumeClaimVolumeSource{
ClaimName: getPvcName(opt.dynamoComponentDeployment, opt.dynamoComponentDeployment.Spec.PVC.Name),
},
},
})
volumeMounts = append(volumeMounts, corev1.VolumeMount{
Name: getPvcName(opt.dynamoComponentDeployment, opt.dynamoComponentDeployment.Spec.PVC.Name),
MountPath: *opt.dynamoComponentDeployment.Spec.PVC.MountPoint,
})
}
imageName := opt.dynamoComponentDeployment.GetImage()
if imageName == "" {
return nil, errors.Errorf("image is not set for component %s", opt.dynamoComponentDeployment.Name)
// Ensure we have at least one container (the main container should be there from GenerateBasePodSpec)
if len(basePodSpec.Containers) == 0 {
return nil, errors.New("no containers found in base pod spec")
}
var securityContext *corev1.SecurityContext
var mainContainerSecurityContext *corev1.SecurityContext
enableRestrictedSecurityContext := os.Getenv("ENABLE_RESTRICTED_SECURITY_CONTEXT") == "true"
if enableRestrictedSecurityContext {
securityContext = &corev1.SecurityContext{
AllowPrivilegeEscalation: ptr.To(false),
RunAsNonRoot: ptr.To(true),
RunAsUser: ptr.To(int64(1000)),
RunAsGroup: ptr.To(int64(1000)),
SeccompProfile: &corev1.SeccompProfile{
Type: corev1.SeccompProfileTypeRuntimeDefault,
},
Capabilities: &corev1.Capabilities{
Drop: []corev1.Capability{"ALL"},
},
}
mainContainerSecurityContext = securityContext.DeepCopy()
mainContainerSecurityContext.RunAsUser = ptr.To(int64(1034))
}
// Get the main container from the base spec
container := basePodSpec.Containers[0]
containers := make([]corev1.Container, 0, 2)
// TODO: Temporarily disabling probes
container := corev1.Container{
Name: "main",
Image: imageName,
LivenessProbe: livenessProbe,
ReadinessProbe: readinessProbe,
Resources: resources,
Env: envs,
TTY: true,
Stdin: true,
VolumeMounts: volumeMounts,
Ports: []corev1.ContainerPort{
{
Protocol: corev1.ProtocolTCP,
Name: commonconsts.DynamoContainerPortName,
ContainerPort: int32(containerPort), // nolint: gosec
},
},
SecurityContext: mainContainerSecurityContext,
}
// Add system port for worker components
if opt.dynamoComponentDeployment.Spec.ComponentType == commonconsts.ComponentTypeWorker {
container.Ports = append(container.Ports, corev1.ContainerPort{
Protocol: corev1.ProtocolTCP,
Name: commonconsts.DynamoSystemPortName,
ContainerPort: int32(commonconsts.DynamoSystemPort),
})
}
if opt.dynamoComponentDeployment.Spec.EnvFromSecret != nil {
container.EnvFrom = []corev1.EnvFromSource{
{
SecretRef: &corev1.SecretEnvSource{
LocalObjectReference: corev1.LocalObjectReference{
Name: *opt.dynamoComponentDeployment.Spec.EnvFromSecret,
},
},
},
}
}
if resourceAnnotations["nvidia.com/enable-container-privileged"] == commonconsts.KubeLabelValueTrue {
if container.SecurityContext == nil {
container.SecurityContext = &corev1.SecurityContext{}
}
container.SecurityContext.Privileged = &[]bool{true}[0]
}
if resourceAnnotations["nvidia.com/enable-container-ptrace"] == commonconsts.KubeLabelValueTrue {
if container.SecurityContext == nil {
container.SecurityContext = &corev1.SecurityContext{}
}
container.SecurityContext.Capabilities = &corev1.Capabilities{
Add: []corev1.Capability{"SYS_PTRACE"},
}
}
if resourceAnnotations["nvidia.com/run-container-as-root"] == commonconsts.KubeLabelValueTrue {
if container.SecurityContext == nil {
container.SecurityContext = &corev1.SecurityContext{}
}
container.SecurityContext.RunAsUser = &[]int64{0}[0]
}
// Merge extraPodSpecMainContainer into container, only overriding empty fields
if opt.dynamoComponentDeployment.Spec.ExtraPodSpec != nil {
extraPodSpecMainContainer := opt.dynamoComponentDeployment.Spec.ExtraPodSpec.MainContainer
if extraPodSpecMainContainer != nil {
// Merge non empty fields from extraPodSpecMainContainer into container, only overriding empty fields
err := mergo.Merge(&container, extraPodSpecMainContainer.DeepCopy())
if err != nil {
err = errors.Wrapf(err, "failed to merge extraPodSpecMainContainer into container")
return nil, err
}
// finally merge the envs from extraPodSpecMainContainer into container
container.Env = dynamo.MergeEnvs(container.Env, extraPodSpecMainContainer.Env)
}
}
containers = append(containers, container)
debuggerImage := "python:3.12-slim"
......@@ -1465,42 +1283,14 @@ func (r *DynamoComponentDeploymentReconciler) generatePodTemplateSpec(ctx contex
podLabels[commonconsts.KubeLabelDynamoSelector] = kubeName
imagePullSecrets := []corev1.LocalObjectReference{}
if r.DockerSecretRetriever == nil {
err = errors.New("DockerSecretRetriever is not initialized")
return
}
secretsName, err := r.DockerSecretRetriever.GetSecrets(opt.dynamoComponentDeployment.Namespace, imageName)
if err != nil {
err = errors.Wrapf(err, "failed to get secrets for component %s and image %s", opt.dynamoComponentDeployment.Name, imageName)
return
}
for _, secretName := range secretsName {
imagePullSecrets = append(imagePullSecrets, corev1.LocalObjectReference{
Name: secretName,
})
}
podSpec := &corev1.PodSpec{}
if opt.dynamoComponentDeployment.Spec.ExtraPodSpec != nil && opt.dynamoComponentDeployment.Spec.ExtraPodSpec.PodSpec != nil {
podSpec = opt.dynamoComponentDeployment.Spec.ExtraPodSpec.PodSpec.DeepCopy()
}
podSpec.Containers = append(podSpec.Containers, containers...)
podSpec.Volumes = append(podSpec.Volumes, volumes...)
podSpec.ImagePullSecrets = append(podSpec.ImagePullSecrets, imagePullSecrets...)
podSpec := &basePodSpec
podSpec.Containers = containers
extraPodMetadata := opt.dynamoComponentDeployment.Spec.ExtraPodMetadata
if extraPodMetadata != nil {
for k, v := range extraPodMetadata.Annotations {
podAnnotations[k] = v
}
for k, v := range extraPodMetadata.Labels {
podLabels[k] = v
}
maps.Copy(podAnnotations, extraPodMetadata.Annotations)
maps.Copy(podLabels, extraPodMetadata.Labels)
}
if podSpec.ServiceAccountName == "" {
......@@ -1519,18 +1309,6 @@ func (r *DynamoComponentDeploymentReconciler) generatePodTemplateSpec(ctx contex
}
}
if resourceAnnotations["nvidia.com/enable-host-ipc"] == commonconsts.KubeLabelValueTrue {
podSpec.HostIPC = true
}
if resourceAnnotations["nvidia.com/enable-host-network"] == commonconsts.KubeLabelValueTrue {
podSpec.HostNetwork = true
}
if resourceAnnotations["nvidia.com/enable-host-pid"] == commonconsts.KubeLabelValueTrue {
podSpec.HostPID = true
}
if opt.isStealingTrafficDebugModeEnabled || isDebugModeEnabled {
podSpec.ShareProcessNamespace = &[]bool{true}[0]
}
......@@ -1546,31 +1324,6 @@ func (r *DynamoComponentDeploymentReconciler) generatePodTemplateSpec(ctx contex
return
}
func getResourcesConfig(resources *dynamoCommon.Resources) (corev1.ResourceRequirements, error) {
defaultResources := corev1.ResourceRequirements{
Requests: corev1.ResourceList{
corev1.ResourceCPU: resource.MustParse("300m"),
corev1.ResourceMemory: resource.MustParse("500Mi"),
},
Limits: corev1.ResourceList{
corev1.ResourceCPU: resource.MustParse("500m"),
corev1.ResourceMemory: resource.MustParse("1Gi"),
},
}
if resources == nil {
return defaultResources, nil
}
resourcesConfig, err := controller_common.GetResourcesConfig(resources)
if err != nil {
return corev1.ResourceRequirements{}, errors.Wrapf(err, "failed to get resources config")
}
err = mergo.Merge(resourcesConfig, defaultResources.DeepCopy())
if err != nil {
return corev1.ResourceRequirements{}, errors.Wrapf(err, "failed to merge resources config")
}
return *resourcesConfig, nil
}
func (r *DynamoComponentDeploymentReconciler) generateService(opt generateResourceOption) (*corev1.Service, bool, error) {
var kubeName string
if opt.isGenericService {
......
......@@ -29,6 +29,7 @@ import (
"github.com/ai-dynamo/dynamo/deploy/cloud/operator/api/v1alpha1"
commonconsts "github.com/ai-dynamo/dynamo/deploy/cloud/operator/internal/consts"
"github.com/ai-dynamo/dynamo/deploy/cloud/operator/internal/controller_common"
"github.com/ai-dynamo/dynamo/deploy/cloud/operator/internal/dynamo"
"github.com/google/go-cmp/cmp"
"github.com/onsi/gomega"
"github.com/onsi/gomega/format"
......@@ -823,6 +824,7 @@ func TestDynamoComponentDeploymentReconciler_generateLeaderWorkerSet(t *testing.
Spec: v1alpha1.DynamoComponentDeploymentSpec{
DynamoComponent: "test-lws-component",
DynamoTag: "test-tag",
BackendFramework: string(dynamo.BackendFrameworkVLLM),
DynamoComponentDeploymentSharedSpec: v1alpha1.DynamoComponentDeploymentSharedSpec{
Envs: []corev1.EnvVar{
{
......@@ -837,10 +839,22 @@ func TestDynamoComponentDeploymentReconciler_generateLeaderWorkerSet(t *testing.
"nvidia.com/lws-size": "2",
},
Resources: &common.Resources{
Requests: &common.ResourceItem{
CPU: "300m",
Memory: "500Mi",
},
Limits: &common.ResourceItem{
GPU: "1",
},
},
ExtraPodMetadata: &common.ExtraPodMetadata{
Annotations: map[string]string{
"nvidia.com/annotation1": "annotation1",
},
Labels: map[string]string{
"nvidia.com/label1": "label1",
},
},
ExtraPodSpec: &dynamoCommon.ExtraPodSpec{
PodSpec: &corev1.PodSpec{
TerminationGracePeriodSeconds: ptr.To(int64(10)),
......@@ -897,48 +911,58 @@ func TestDynamoComponentDeploymentReconciler_generateLeaderWorkerSet(t *testing.
ObjectMeta: metav1.ObjectMeta{
Labels: map[string]string{
"instance-id": "0",
commonconsts.KubeLabelMetricsEnabled: commonconsts.KubeLabelValueTrue,
"role": "leader",
"nvidia.com/label1": "label1",
},
Annotations: map[string]string{
"scheduling.k8s.io/group-name": "test-lws-deploy-0",
"nvidia.com/annotation1": "annotation1",
},
},
Spec: corev1.PodSpec{
SchedulerName: "volcano",
TerminationGracePeriodSeconds: ptr.To(int64(10)),
Volumes: []corev1.Volume{
{
Name: "shared-memory",
VolumeSource: corev1.VolumeSource{
EmptyDir: &corev1.EmptyDirVolumeSource{
Medium: corev1.StorageMediumMemory,
SizeLimit: resource.NewQuantity(512*1024*1024, resource.BinarySI), // 512Mi default (calculated from memory limit)
},
},
},
},
Containers: []corev1.Container{
{
Name: "main",
Image: "test-image:latest",
Command: []string{"sh", "-c"},
Args: []string{"ray start --head --port=6379 && some dynamo command"},
Env: []corev1.EnvVar{{Name: "DYNAMO_PORT", Value: fmt.Sprintf("%d", commonconsts.DynamoServicePort)}, {Name: "TEST_ENV_FROM_DYNAMO_COMPONENT_DEPLOYMENT_SPEC", Value: "test_value_from_dynamo_component_deployment_spec"}, {Name: "TEST_ENV_FROM_EXTRA_POD_SPEC", Value: "test_value_from_extra_pod_spec"}},
VolumeMounts: []corev1.VolumeMount{
Env: []corev1.EnvVar{{Name: "TEST_ENV_FROM_DYNAMO_COMPONENT_DEPLOYMENT_SPEC", Value: "test_value_from_dynamo_component_deployment_spec"}, {Name: "TEST_ENV_FROM_EXTRA_POD_SPEC", Value: "test_value_from_extra_pod_spec"}, {Name: "DYNAMO_PORT", Value: fmt.Sprintf("%d", commonconsts.DynamoServicePort)}},
Ports: []corev1.ContainerPort{
{
Name: "shared-memory", MountPath: "/dev/shm",
Protocol: corev1.ProtocolTCP, Name: commonconsts.DynamoServicePortName, ContainerPort: commonconsts.DynamoServicePort,
},
},
Ports: []corev1.ContainerPort{
VolumeMounts: []corev1.VolumeMount{
{
Protocol: corev1.ProtocolTCP, Name: commonconsts.DynamoServicePortName, ContainerPort: commonconsts.DynamoServicePort,
Name: "shared-memory",
MountPath: "/dev/shm",
},
},
TTY: true,
Stdin: true,
Resources: corev1.ResourceRequirements{
Requests: corev1.ResourceList{
corev1.ResourceCPU: resource.MustParse("300m"),
corev1.ResourceMemory: resource.MustParse("500Mi"),
},
Limits: corev1.ResourceList{
corev1.ResourceCPU: resource.MustParse("500m"),
corev1.ResourceMemory: resource.MustParse("1Gi"),
"nvidia.com/gpu": resource.MustParse("1"),
},
},
},
},
Volumes: []corev1.Volume{{Name: "shared-memory", VolumeSource: corev1.VolumeSource{EmptyDir: &corev1.EmptyDirVolumeSource{Medium: corev1.StorageMediumMemory, SizeLimit: limit}}}},
ImagePullSecrets: nil, // Assuming default config gives empty secret name
ServiceAccountName: "default-test-sa", // Updated to reflect mocked SA
},
......@@ -947,37 +971,52 @@ func TestDynamoComponentDeploymentReconciler_generateLeaderWorkerSet(t *testing.
ObjectMeta: metav1.ObjectMeta{
Labels: map[string]string{
"instance-id": "0",
commonconsts.KubeLabelMetricsEnabled: commonconsts.KubeLabelValueTrue,
"role": "worker",
"nvidia.com/label1": "label1",
},
Annotations: map[string]string{
"scheduling.k8s.io/group-name": "test-lws-deploy-0",
"nvidia.com/annotation1": "annotation1",
},
},
Spec: corev1.PodSpec{
TerminationGracePeriodSeconds: ptr.To(int64(10)),
SchedulerName: "volcano",
Volumes: []corev1.Volume{
{
Name: "shared-memory",
VolumeSource: corev1.VolumeSource{
EmptyDir: &corev1.EmptyDirVolumeSource{
Medium: corev1.StorageMediumMemory,
SizeLimit: resource.NewQuantity(512*1024*1024, resource.BinarySI), // 512Mi default (calculated from memory limit)
},
},
},
},
Containers: []corev1.Container{
{
Name: "main",
Image: "test-image:latest",
Command: []string{"sh", "-c"},
Args: []string{"ray start --address=$(LWS_LEADER_ADDRESS):6379 --block"},
Env: []corev1.EnvVar{{Name: "DYNAMO_PORT", Value: fmt.Sprintf("%d", commonconsts.DynamoServicePort)}, {Name: "TEST_ENV_FROM_DYNAMO_COMPONENT_DEPLOYMENT_SPEC", Value: "test_value_from_dynamo_component_deployment_spec"}, {Name: "TEST_ENV_FROM_EXTRA_POD_SPEC", Value: "test_value_from_extra_pod_spec"}},
VolumeMounts: []corev1.VolumeMount{{Name: "shared-memory", MountPath: "/dev/shm"}},
Ports: []corev1.ContainerPort{
Env: []corev1.EnvVar{{Name: "TEST_ENV_FROM_DYNAMO_COMPONENT_DEPLOYMENT_SPEC", Value: "test_value_from_dynamo_component_deployment_spec"}, {Name: "TEST_ENV_FROM_EXTRA_POD_SPEC", Value: "test_value_from_extra_pod_spec"}, {Name: "DYNAMO_PORT", Value: fmt.Sprintf("%d", commonconsts.DynamoServicePort)}},
Ports: []corev1.ContainerPort{{Protocol: corev1.ProtocolTCP, Name: commonconsts.DynamoServicePortName, ContainerPort: commonconsts.DynamoServicePort}},
VolumeMounts: []corev1.VolumeMount{
{
Protocol: corev1.ProtocolTCP, Name: commonconsts.DynamoServicePortName, ContainerPort: commonconsts.DynamoServicePort,
Name: "shared-memory",
MountPath: "/dev/shm",
},
},
TTY: true,
Stdin: true,
Resources: corev1.ResourceRequirements{
Requests: corev1.ResourceList{corev1.ResourceCPU: resource.MustParse("300m"), corev1.ResourceMemory: resource.MustParse("500Mi")},
Limits: corev1.ResourceList{corev1.ResourceCPU: resource.MustParse("500m"), corev1.ResourceMemory: resource.MustParse("1Gi"), "nvidia.com/gpu": resource.MustParse("1")},
Limits: corev1.ResourceList{"nvidia.com/gpu": resource.MustParse("1")},
Requests: corev1.ResourceList{
corev1.ResourceCPU: resource.MustParse("300m"),
corev1.ResourceMemory: resource.MustParse("500Mi"),
},
},
},
},
Volumes: []corev1.Volume{{Name: "shared-memory", VolumeSource: corev1.VolumeSource{EmptyDir: &corev1.EmptyDirVolumeSource{Medium: corev1.StorageMediumMemory, SizeLimit: limit}}}},
ImagePullSecrets: nil,
ServiceAccountName: "default-test-sa", // Updated to reflect mocked SA
},
......
......@@ -222,14 +222,14 @@ func (r *DynamoGraphDeploymentReconciler) reconcileGroveResources(ctx context.Co
return true
}))
// generate the main component virtual service
if r.Config.IngressConfig.UseVirtualService() {
mainComponentVirtualService := dynamo.GenerateComponentVirtualService(ctx, dynamo.GetDynamoComponentName(dynamoDeployment, componentName), dynamoDeployment.Namespace, ingressSpec)
if err != nil {
logger.Error(err, "failed to generate the main component virtual service")
return "", "", "", fmt.Errorf("failed to generate the main component virtual service: %w", err)
}
_, syncedMainComponentVirtualService, err := commonController.SyncResource(ctx, r, dynamoDeployment, func(ctx context.Context) (*networkingv1beta1.VirtualService, bool, error) {
vsEnabled := ingressSpec.Enabled && ingressSpec.UseVirtualService && ingressSpec.VirtualServiceGateway != nil
if !vsEnabled {
if !ingressSpec.IsVirtualServiceEnabled() {
logger.Info("VirtualService is not enabled")
return mainComponentVirtualService, true, nil
}
......@@ -244,6 +244,7 @@ func (r *DynamoGraphDeploymentReconciler) reconcileGroveResources(ctx context.Co
}))
}
}
}
return r.checkResourcesReadiness(resources)
}
......
package dynamo
import (
"fmt"
commonconsts "github.com/ai-dynamo/dynamo/deploy/cloud/operator/internal/consts"
)
// generateGroveLeaderHostname generates the hostname for the leader pod in Grove multinode deployments
// The leader hostname follows the pattern: {GROVE_PCSG_NAME}-{GROVE_PCSG_INDEX}-serviceName-{GroveRoleSuffixLeader}-0.{GROVE_HEADLESS_SERVICE}
func generateGroveLeaderHostname(serviceName string) string {
return fmt.Sprintf("${GROVE_PCSG_NAME}-${GROVE_PCSG_INDEX}-%s-%s-0.${GROVE_HEADLESS_SERVICE}", serviceName, commonconsts.GroveRoleSuffixLeader)
}
package dynamo
import (
"fmt"
"regexp"
"strings"
"github.com/ai-dynamo/dynamo/deploy/cloud/operator/api/v1alpha1"
commonconsts "github.com/ai-dynamo/dynamo/deploy/cloud/operator/internal/consts"
corev1 "k8s.io/api/core/v1"
)
type SGLangBackend struct{}
func (b *SGLangBackend) UpdateContainer(container *corev1.Container, numberOfNodes int32, role Role, component *v1alpha1.DynamoComponentDeploymentOverridesSpec, multinodeDeploymentType commonconsts.MultinodeDeploymentType, serviceName string) {
// For single node, nothing to do
if numberOfNodes <= 1 {
return
}
// Remove probes for multinode leader and worker
if role == RoleLeader || role == RoleWorker {
container.LivenessProbe = nil
container.ReadinessProbe = nil
container.StartupProbe = nil
}
// Generate the flags to add
flags := b.getMultinodeFlags(numberOfNodes, role, multinodeDeploymentType, serviceName)
if flags == "" {
return
}
// Flatten all args into a single command and inject flags
if len(container.Args) > 0 {
fullCommand := strings.Join(container.Args, " ")
modifiedCommand := b.injectFlagsIntoPythonCommand(fullCommand, flags)
container.Args = []string{modifiedCommand}
}
}
func (b *SGLangBackend) UpdatePodSpec(podSpec *corev1.PodSpec, numberOfNodes int32, role Role, component *v1alpha1.DynamoComponentDeploymentOverridesSpec, multinodeDeploymentType commonconsts.MultinodeDeploymentType, serviceName string) {
// do nothing
}
// getMultinodeFlags returns the multinode flags as a single string
func (b *SGLangBackend) getMultinodeFlags(numberOfNodes int32, role Role, multinodeDeploymentType commonconsts.MultinodeDeploymentType, serviceName string) string {
var distInitAddr, nodeRank string
// Determine dist-init-addr
if multinodeDeploymentType == commonconsts.MultinodeDeploymentTypeGrove {
leaderHostname := generateGroveLeaderHostname(serviceName)
distInitAddr = fmt.Sprintf("%s:29500", leaderHostname)
} else {
distInitAddr = "${LWS_LEADER_ADDRESS}:29500"
}
// Determine node-rank
if role == RoleLeader {
nodeRank = "0"
} else {
if multinodeDeploymentType == commonconsts.MultinodeDeploymentTypeGrove {
nodeRank = "$((GROVE_PCLQ_POD_INDEX + 1))"
} else {
nodeRank = "${LWS_WORKER_INDEX}"
}
}
return fmt.Sprintf("--dist-init-addr %s --nnodes %d --node-rank %s", distInitAddr, numberOfNodes, nodeRank)
}
// injectFlagsIntoPythonCommand finds python sglang commands and adds flags after them
func (b *SGLangBackend) injectFlagsIntoPythonCommand(arg, flags string) string {
// Regex to match python commands that contain sglang
// Matches: python, python3, python3.11, etc. followed by sglang-related modules
pattern := `(python[0-9.]*\s+[^|&;]*sglang[^|&;]*?)(\s|$|[|&;])`
re := regexp.MustCompile(pattern)
// Replace with the command + flags + whatever comes after
result := re.ReplaceAllStringFunc(arg, func(match string) string {
// Extract the python command part and the delimiter
submatches := re.FindStringSubmatch(match)
if len(submatches) >= 3 {
pythonCmd := submatches[1]
delimiter := submatches[2]
return pythonCmd + " " + flags + delimiter
}
return match
})
return result
}
package dynamo
import (
"reflect"
"testing"
"github.com/ai-dynamo/dynamo/deploy/cloud/operator/api/v1alpha1"
"github.com/ai-dynamo/dynamo/deploy/cloud/operator/internal/consts"
corev1 "k8s.io/api/core/v1"
)
func TestSGLangBackend_DirectFlagInjection(t *testing.T) {
backend := &SGLangBackend{}
tests := []struct {
name string
numberOfNodes int32
role Role
multinodeDeploymentType consts.MultinodeDeploymentType
initialArgs []string
expectedArgs []string
description string
}{
{
name: "single node does not modify args",
numberOfNodes: 1,
role: RoleMain,
multinodeDeploymentType: consts.MultinodeDeploymentTypeGrove,
initialArgs: []string{"python -m dynamo.sglang.worker"},
expectedArgs: []string{"python -m dynamo.sglang.worker"},
description: "Single node should not modify anything",
},
{
name: "multinode adds flags to simple python command",
numberOfNodes: 2,
role: RoleLeader,
multinodeDeploymentType: consts.MultinodeDeploymentTypeGrove,
initialArgs: []string{"python -m dynamo.sglang.worker"},
expectedArgs: []string{"python -m dynamo.sglang.worker --dist-init-addr ${GROVE_PCSG_NAME}-${GROVE_PCSG_INDEX}-test-service-ldr-0.${GROVE_HEADLESS_SERVICE}:29500 --nnodes 2 --node-rank 0"},
description: "Should add multinode flags directly to python command",
},
{
name: "multinode with complex command",
numberOfNodes: 2,
role: RoleLeader,
multinodeDeploymentType: consts.MultinodeDeploymentTypeGrove,
initialArgs: []string{"echo blah | wc -l && python -m dynamo.sglang.worker && ls -al"},
expectedArgs: []string{"echo blah | wc -l && python -m dynamo.sglang.worker --dist-init-addr ${GROVE_PCSG_NAME}-${GROVE_PCSG_INDEX}-test-service-ldr-0.${GROVE_HEADLESS_SERVICE}:29500 --nnodes 2 --node-rank 0 && ls -al"},
description: "Should add flags only to python command, not other commands",
},
{
name: "multinode worker with Grove deployment",
numberOfNodes: 3,
role: RoleWorker,
multinodeDeploymentType: consts.MultinodeDeploymentTypeGrove,
initialArgs: []string{"python -m dynamo.sglang.worker"},
expectedArgs: []string{"python -m dynamo.sglang.worker --dist-init-addr ${GROVE_PCSG_NAME}-${GROVE_PCSG_INDEX}-test-service-ldr-0.${GROVE_HEADLESS_SERVICE}:29500 --nnodes 3 --node-rank $((GROVE_PCLQ_POD_INDEX + 1))"},
description: "Worker should get correct node rank",
},
{
name: "LWS deployment uses correct address",
numberOfNodes: 2,
role: RoleLeader,
multinodeDeploymentType: consts.MultinodeDeploymentTypeLWS,
initialArgs: []string{"python -m dynamo.sglang.worker"},
expectedArgs: []string{"python -m dynamo.sglang.worker --dist-init-addr ${LWS_LEADER_ADDRESS}:29500 --nnodes 2 --node-rank 0"},
description: "LWS deployment should use LWS_LEADER_ADDRESS",
},
{
name: "command with pipes gets flags before pipe",
numberOfNodes: 2,
role: RoleLeader,
multinodeDeploymentType: consts.MultinodeDeploymentTypeGrove,
initialArgs: []string{"python -m dynamo.sglang.worker | tee /tmp/log"},
expectedArgs: []string{"python -m dynamo.sglang.worker --dist-init-addr ${GROVE_PCSG_NAME}-${GROVE_PCSG_INDEX}-test-service-ldr-0.${GROVE_HEADLESS_SERVICE}:29500 --nnodes 2 --node-rank 0 | tee /tmp/log"},
description: "Should insert flags before pipe operator",
},
{
name: "multiple args are flattened and processed together",
numberOfNodes: 2,
role: RoleLeader,
multinodeDeploymentType: consts.MultinodeDeploymentTypeGrove,
initialArgs: []string{"echo start", "python -m dynamo.sglang.worker", "echo done"},
expectedArgs: []string{"echo start python -m dynamo.sglang.worker --dist-init-addr ${GROVE_PCSG_NAME}-${GROVE_PCSG_INDEX}-test-service-ldr-0.${GROVE_HEADLESS_SERVICE}:29500 --nnodes 2 --node-rank 0 echo done"},
description: "Multiple args should be flattened and python command gets flags",
},
{
name: "no sglang command means flattened but no changes",
numberOfNodes: 2,
role: RoleLeader,
multinodeDeploymentType: consts.MultinodeDeploymentTypeGrove,
initialArgs: []string{"echo hello", "python -m some.other.module"},
expectedArgs: []string{"echo hello python -m some.other.module"},
description: "Non-sglang commands should be flattened but not modified",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
container := &corev1.Container{
Args: append([]string{}, tt.initialArgs...),
}
backend.UpdateContainer(container, tt.numberOfNodes, tt.role, &v1alpha1.DynamoComponentDeploymentOverridesSpec{}, tt.multinodeDeploymentType, "test-service")
if !reflect.DeepEqual(container.Args, tt.expectedArgs) {
t.Errorf("UpdateContainer() args = %v, want %v", container.Args, tt.expectedArgs)
}
// Verify no environment variables were added
if len(container.Env) > 0 {
t.Errorf("UpdateContainer() should not add environment variables, but added: %v", container.Env)
}
// Verify command was not changed
if len(container.Command) > 0 {
t.Errorf("UpdateContainer() should not modify command, but set: %v", container.Command)
}
})
}
}
func TestSGLangBackend_ProbeRemoval(t *testing.T) {
backend := &SGLangBackend{}
tests := []struct {
name string
numberOfNodes int32
role Role
multinodeDeploymentType consts.MultinodeDeploymentType
expectProbesRemoved bool
}{
{
name: "single node does not remove probes",
numberOfNodes: 1,
role: RoleMain,
multinodeDeploymentType: consts.MultinodeDeploymentTypeGrove,
expectProbesRemoved: false,
},
{
name: "multinode leader removes probes",
numberOfNodes: 2,
role: RoleLeader,
multinodeDeploymentType: consts.MultinodeDeploymentTypeGrove,
expectProbesRemoved: true,
},
{
name: "multinode worker removes probes",
numberOfNodes: 2,
role: RoleWorker,
multinodeDeploymentType: consts.MultinodeDeploymentTypeGrove,
expectProbesRemoved: true,
},
{
name: "multinode main role does not remove probes",
numberOfNodes: 2,
role: RoleMain,
multinodeDeploymentType: consts.MultinodeDeploymentTypeGrove,
expectProbesRemoved: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create initial probes
livenessProbe := &corev1.Probe{InitialDelaySeconds: 30}
readinessProbe := &corev1.Probe{InitialDelaySeconds: 10}
startupProbe := &corev1.Probe{InitialDelaySeconds: 5}
container := &corev1.Container{
Args: []string{"python -m dynamo.sglang.worker"},
LivenessProbe: livenessProbe,
ReadinessProbe: readinessProbe,
StartupProbe: startupProbe,
}
backend.UpdateContainer(container, tt.numberOfNodes, tt.role, &v1alpha1.DynamoComponentDeploymentOverridesSpec{}, tt.multinodeDeploymentType, "test-service")
if tt.expectProbesRemoved {
if container.LivenessProbe != nil {
t.Errorf("Expected LivenessProbe to be removed, but it was not")
}
if container.ReadinessProbe != nil {
t.Errorf("Expected ReadinessProbe to be removed, but it was not")
}
if container.StartupProbe != nil {
t.Errorf("Expected StartupProbe to be removed, but it was not")
}
} else {
if container.LivenessProbe == nil {
t.Errorf("Expected LivenessProbe to be preserved, but it was removed")
}
if container.ReadinessProbe == nil {
t.Errorf("Expected ReadinessProbe to be preserved, but it was removed")
}
if container.StartupProbe == nil {
t.Errorf("Expected StartupProbe to be preserved, but it was removed")
}
}
})
}
}
package dynamo
import (
"fmt"
"sort"
"strconv"
"strings"
"github.com/ai-dynamo/dynamo/deploy/cloud/operator/api/dynamo/common"
"github.com/ai-dynamo/dynamo/deploy/cloud/operator/api/v1alpha1"
commonconsts "github.com/ai-dynamo/dynamo/deploy/cloud/operator/internal/consts"
corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/util/intstr"
)
type TRTLLMBackend struct{}
func (b *TRTLLMBackend) UpdateContainer(container *corev1.Container, numberOfNodes int32, role Role, component *v1alpha1.DynamoComponentDeploymentOverridesSpec, multinodeDeploymentType commonconsts.MultinodeDeploymentType, serviceName string) {
// For single node, nothing to do
if numberOfNodes <= 1 {
return
}
// Configure probes for multinode deployments
if role == RoleWorker {
// For workers: remove liveness and startup probes, set readiness to check SSH port
container.LivenessProbe = nil
container.StartupProbe = nil
container.ReadinessProbe = &corev1.Probe{
ProbeHandler: corev1.ProbeHandler{
TCPSocket: &corev1.TCPSocketAction{
Port: intstr.FromInt(commonconsts.MpiRunSshPort),
},
},
InitialDelaySeconds: 20,
PeriodSeconds: 20,
TimeoutSeconds: 5,
FailureThreshold: 10,
}
}
// For leaders: leave all probes untouched
// Add SSH keypair volume mount for multinode deployments
b.addSSHVolumeMount(container)
// Add OpenMPI environment variable to keep FQDN hostnames
envVar := corev1.EnvVar{
Name: "OMPI_MCA_orte_keep_fqdn_hostnames",
Value: "1",
}
container.Env = append(container.Env, envVar)
// Update container command based on role
switch role {
case RoleLeader:
b.setupLeaderContainer(container, numberOfNodes, multinodeDeploymentType, serviceName, component)
case RoleWorker:
b.setupWorkerContainer(container)
}
}
func (b *TRTLLMBackend) UpdatePodSpec(podSpec *corev1.PodSpec, numberOfNodes int32, role Role, component *v1alpha1.DynamoComponentDeploymentOverridesSpec, multinodeDeploymentType commonconsts.MultinodeDeploymentType, serviceName string) {
// Add SSH keypair volume for TRTLLM multinode deployments
if numberOfNodes > 1 {
sshVolume := corev1.Volume{
Name: commonconsts.MpiRunSshSecretName,
VolumeSource: corev1.VolumeSource{
Secret: &corev1.SecretVolumeSource{
SecretName: commonconsts.MpiRunSshSecretName,
DefaultMode: func() *int32 { mode := int32(0644); return &mode }(),
},
},
}
podSpec.Volumes = append(podSpec.Volumes, sshVolume)
}
}
// addSSHVolumeMount adds the SSH keypair secret volume mount to the container
func (b *TRTLLMBackend) addSSHVolumeMount(container *corev1.Container) {
sshVolumeMount := corev1.VolumeMount{
Name: commonconsts.MpiRunSshSecretName,
MountPath: "/ssh-pk",
ReadOnly: true,
}
container.VolumeMounts = append(container.VolumeMounts, sshVolumeMount)
}
// setupLeaderContainer configures the leader node with SSH setup and mpirun command
func (b *TRTLLMBackend) setupLeaderContainer(container *corev1.Container, numberOfNodes int32, multinodeDeploymentType commonconsts.MultinodeDeploymentType, serviceName string, component *v1alpha1.DynamoComponentDeploymentOverridesSpec) {
// Generate the list of worker hostnames
workerHosts := b.generateWorkerHostnames(numberOfNodes, multinodeDeploymentType, serviceName)
// Store original command/args for later use
var originalCommand string
if len(container.Args) > 0 {
originalCommand = strings.Join(container.Args, " ")
} else if len(container.Command) > 0 {
originalCommand = strings.Join(container.Command, " ")
}
// Setup SSH and run mpirun command
sshSetupCommands := []string{
"mkdir -p ~/.ssh",
"ls -la /ssh-pk/", // Debug: list files in ssh-pk directory
"cp /ssh-pk/private.key ~/.ssh/id_rsa",
"cp /ssh-pk/private.key.pub ~/.ssh/id_rsa.pub",
"cp /ssh-pk/private.key.pub ~/.ssh/authorized_keys",
"chmod 600 ~/.ssh/id_rsa ~/.ssh/authorized_keys",
"chmod 644 ~/.ssh/id_rsa.pub ~/.ssh/authorized_keys",
fmt.Sprintf("printf 'Host *\\nIdentityFile ~/.ssh/id_rsa\\nStrictHostKeyChecking no\\nPort %d\\n' > ~/.ssh/config", commonconsts.MpiRunSshPort),
}
// Calculate total number of GPUs across all nodes
gpusPerNode := getGPUsPerNode(component.Resources)
totalGPUs := numberOfNodes * gpusPerNode
// Build mpirun command with explicit SSH configuration and environment variables
// Wrap the entire command (trtllm-llmapi-launch + original command) in bash -c for proper shell interpretation
wrappedCommand := fmt.Sprintf("bash -c 'source /opt/dynamo/venv/bin/activate && trtllm-llmapi-launch %s'", originalCommand)
// Generate environment variable flags for mpirun
envVarsStr := generateEnvVarFlags(container.Env)
mpirunCmd := fmt.Sprintf("mpirun --oversubscribe -n %d -H %s --mca pml ob1 --mca plm_rsh_args \"-p %d -o StrictHostKeyChecking=no -i ~/.ssh/id_rsa\" %s %s",
totalGPUs,
workerHosts,
commonconsts.MpiRunSshPort,
envVarsStr,
wrappedCommand)
// Combine SSH setup and mpirun command
fullCommand := strings.Join(append(sshSetupCommands, mpirunCmd), " && ")
// Update container to use bash with the full command
container.Command = []string{"/bin/sh", "-c"}
container.Args = []string{fullCommand}
}
// setupWorkerContainer configures worker nodes with SSH setup and daemon
func (b *TRTLLMBackend) setupWorkerContainer(container *corev1.Container) {
// Setup SSH for worker nodes
sshSetupCommands := []string{
"mkdir -p ~/.ssh ~/.ssh/host_keys ~/.ssh/run",
"ls -la /ssh-pk/", // Debug: list files in ssh-pk directory
"cp /ssh-pk/private.key ~/.ssh/id_rsa",
"cp /ssh-pk/private.key.pub ~/.ssh/id_rsa.pub",
"cp /ssh-pk/private.key.pub ~/.ssh/authorized_keys",
"chmod 600 ~/.ssh/id_rsa ~/.ssh/authorized_keys",
"chmod 644 ~/.ssh/id_rsa.pub ~/.ssh/authorized_keys",
fmt.Sprintf("printf 'Host *\\nIdentityFile ~/.ssh/id_rsa\\nStrictHostKeyChecking no\\nPort %d\\n' > ~/.ssh/config", commonconsts.MpiRunSshPort),
// Generate host keys in user writable directory
"ssh-keygen -t rsa -f ~/.ssh/host_keys/ssh_host_rsa_key -N ''",
"ssh-keygen -t ecdsa -f ~/.ssh/host_keys/ssh_host_ecdsa_key -N ''",
"ssh-keygen -t ed25519 -f ~/.ssh/host_keys/ssh_host_ed25519_key -N ''",
// Create SSH daemon config to use custom host keys location and non-privileged port
fmt.Sprintf("printf 'Port %d\\nHostKey ~/.ssh/host_keys/ssh_host_rsa_key\\nHostKey ~/.ssh/host_keys/ssh_host_ecdsa_key\\nHostKey ~/.ssh/host_keys/ssh_host_ed25519_key\\nPidFile ~/.ssh/run/sshd.pid\\nPermitRootLogin yes\\nPasswordAuthentication no\\nPubkeyAuthentication yes\\nAuthorizedKeysFile ~/.ssh/authorized_keys\\n' > ~/.ssh/sshd_config", commonconsts.MpiRunSshPort),
"mkdir -p /run/sshd",
"/usr/sbin/sshd -D -f ~/.ssh/sshd_config",
}
fullCommand := strings.Join(sshSetupCommands, " && ")
// Update container to use bash with the SSH setup and daemon
container.Command = []string{"/bin/sh", "-c"}
container.Args = []string{fullCommand}
}
// generateWorkerHostnames creates a comma-separated list of worker hostnames
func (b *TRTLLMBackend) generateWorkerHostnames(numberOfNodes int32, multinodeDeploymentType commonconsts.MultinodeDeploymentType, serviceName string) string {
var hostnames []string
// Add leader hostname first
if multinodeDeploymentType == commonconsts.MultinodeDeploymentTypeGrove {
leaderHostname := generateGroveLeaderHostname(serviceName)
hostnames = append(hostnames, leaderHostname)
// Add worker hostnames
for i := int32(0); i < numberOfNodes-1; i++ {
workerHostname := fmt.Sprintf("${GROVE_PCSG_NAME}-${GROVE_PCSG_INDEX}-%s-%s-%d.${GROVE_HEADLESS_SERVICE}",
serviceName, commonconsts.GroveRoleSuffixWorker, i)
hostnames = append(hostnames, workerHostname)
}
} else {
// For LWS deployment type - using environment variables
hostnames = append(hostnames, "${LWS_LEADER_ADDRESS}")
for i := int32(1); i < numberOfNodes; i++ {
hostnames = append(hostnames, fmt.Sprintf("${LWS_WORKER_%d_ADDRESS}", i))
}
}
return strings.Join(hostnames, ",")
}
// getGPUsPerNode extracts the number of GPUs per node from resources
func getGPUsPerNode(resources *common.Resources) int32 {
if resources != nil && resources.Requests != nil && resources.Requests.GPU != "" {
if gpus, err := strconv.ParseInt(resources.Requests.GPU, 10, 32); err == nil {
return int32(gpus)
}
}
if resources != nil && resources.Limits != nil && resources.Limits.GPU != "" {
if gpus, err := strconv.ParseInt(resources.Limits.GPU, 10, 32); err == nil {
return int32(gpus)
}
}
return 0 // Default to 0 GPUs if not specified
}
// getCommonTRTLLMEnvVars returns a map of common environment variables for TRTLLM deployments
func getCommonTRTLLMEnvVars() map[string]bool {
return map[string]bool{
"CUDA_VISIBLE_DEVICES": true, "MODEL_PATH": true, "HF_TOKEN": true, "HUGGING_FACE_HUB_TOKEN": true,
"TOKENIZERS_PARALLELISM": true, "NCCL_DEBUG": true, "NCCL_IB_DISABLE": true, "NCCL_P2P_DISABLE": true,
"TENSORRT_LLM_CACHE_DIR": true, "HF_HOME": true, "TRANSFORMERS_CACHE": true, "HF_DATASETS_CACHE": true,
"PATH": true, "LD_LIBRARY_PATH": true, "PYTHONPATH": true, "HOME": true, "USER": true,
}
}
// collectAllEnvVars combines explicit container env vars with common TRTLLM env vars, removing duplicates
func collectAllEnvVars(containerEnvVars []corev1.EnvVar) []string {
// Initialize set with common environment variables
envVarSet := getCommonTRTLLMEnvVars()
// Add explicit environment variables from container
for _, env := range containerEnvVars {
envVarSet[env.Name] = true
}
// Convert set to sorted slice for consistent output
envVarNames := make([]string, 0, len(envVarSet))
for envVar := range envVarSet {
envVarNames = append(envVarNames, envVar)
}
sort.Strings(envVarNames)
return envVarNames
}
// formatEnvVarFlags converts environment variable names to mpirun -x flags
func formatEnvVarFlags(envVarNames []string) string {
envVars := make([]string, 0, len(envVarNames))
for _, envVar := range envVarNames {
envVars = append(envVars, fmt.Sprintf("-x %s", envVar))
}
return strings.Join(envVars, " ")
}
// generateEnvVarFlags generates the complete environment variable flags string for mpirun
func generateEnvVarFlags(containerEnvVars []corev1.EnvVar) string {
envVarNames := collectAllEnvVars(containerEnvVars)
return formatEnvVarFlags(envVarNames)
}
package dynamo
import (
"strings"
"testing"
"github.com/ai-dynamo/dynamo/deploy/cloud/operator/api/dynamo/common"
"github.com/ai-dynamo/dynamo/deploy/cloud/operator/api/v1alpha1"
commonconsts "github.com/ai-dynamo/dynamo/deploy/cloud/operator/internal/consts"
corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/util/intstr"
)
func TestTRTLLMBackend_UpdateContainer(t *testing.T) {
tests := []struct {
name string
numberOfNodes int32
role Role
multinodeDeploymentType commonconsts.MultinodeDeploymentType
component *v1alpha1.DynamoComponentDeploymentOverridesSpec
expectedVolumeMounts []corev1.VolumeMount
expectedCommand []string
expectedArgs []string
expectedEnv []corev1.EnvVar
expectLivenessRemoved bool
expectReadinessRemoved bool
expectStartupRemoved bool
expectedReadinessProbe *corev1.Probe
}{
{
name: "Single node - no changes",
numberOfNodes: 1,
role: RoleMain,
multinodeDeploymentType: commonconsts.MultinodeDeploymentTypeGrove,
component: &v1alpha1.DynamoComponentDeploymentOverridesSpec{},
expectedVolumeMounts: []corev1.VolumeMount{},
expectedCommand: []string{},
expectedArgs: []string{"python3", "--model", "test"},
expectedEnv: []corev1.EnvVar{},
expectLivenessRemoved: false,
expectReadinessRemoved: false,
expectStartupRemoved: false,
expectedReadinessProbe: nil,
},
{
name: "Multinode leader with GPU resources",
numberOfNodes: 3,
role: RoleLeader,
multinodeDeploymentType: commonconsts.MultinodeDeploymentTypeGrove,
component: &v1alpha1.DynamoComponentDeploymentOverridesSpec{
DynamoComponentDeploymentSharedSpec: v1alpha1.DynamoComponentDeploymentSharedSpec{
Resources: &common.Resources{
Requests: &common.ResourceItem{
GPU: "2",
},
},
},
},
expectedVolumeMounts: []corev1.VolumeMount{
{Name: commonconsts.MpiRunSshSecretName, MountPath: "/ssh-pk", ReadOnly: true},
},
expectedCommand: []string{"/bin/sh", "-c"},
expectedArgs: []string{"mkdir -p ~/.ssh && ls -la /ssh-pk/ && cp /ssh-pk/private.key ~/.ssh/id_rsa && cp /ssh-pk/private.key.pub ~/.ssh/id_rsa.pub && cp /ssh-pk/private.key.pub ~/.ssh/authorized_keys && chmod 600 ~/.ssh/id_rsa ~/.ssh/authorized_keys && chmod 644 ~/.ssh/id_rsa.pub ~/.ssh/authorized_keys && printf 'Host *\\nIdentityFile ~/.ssh/id_rsa\\nStrictHostKeyChecking no\\nPort 2222\\n' > ~/.ssh/config && mpirun --oversubscribe -n 6 -H ${GROVE_PCSG_NAME}-${GROVE_PCSG_INDEX}-test-service-ldr-0.${GROVE_HEADLESS_SERVICE},${GROVE_PCSG_NAME}-${GROVE_PCSG_INDEX}-test-service-wkr-0.${GROVE_HEADLESS_SERVICE},${GROVE_PCSG_NAME}-${GROVE_PCSG_INDEX}-test-service-wkr-1.${GROVE_HEADLESS_SERVICE} --mca pml ob1 --mca plm_rsh_args \"-p 2222 -o StrictHostKeyChecking=no -i ~/.ssh/id_rsa\" -x CUDA_VISIBLE_DEVICES -x HF_DATASETS_CACHE -x HF_HOME -x HF_TOKEN -x HOME -x HUGGING_FACE_HUB_TOKEN -x LD_LIBRARY_PATH -x MODEL_PATH -x NCCL_DEBUG -x NCCL_IB_DISABLE -x NCCL_P2P_DISABLE -x OMPI_MCA_orte_keep_fqdn_hostnames -x PATH -x PYTHONPATH -x TENSORRT_LLM_CACHE_DIR -x TOKENIZERS_PARALLELISM -x TRANSFORMERS_CACHE -x USER bash -c 'source /opt/dynamo/venv/bin/activate && trtllm-llmapi-launch python3 --model test'"},
expectedEnv: []corev1.EnvVar{
{Name: "OMPI_MCA_orte_keep_fqdn_hostnames", Value: "1"},
},
expectLivenessRemoved: false,
expectReadinessRemoved: false,
expectStartupRemoved: false,
expectedReadinessProbe: nil,
},
{
name: "Multinode worker",
numberOfNodes: 3,
role: RoleWorker,
multinodeDeploymentType: commonconsts.MultinodeDeploymentTypeGrove,
component: &v1alpha1.DynamoComponentDeploymentOverridesSpec{},
expectedVolumeMounts: []corev1.VolumeMount{
{Name: commonconsts.MpiRunSshSecretName, MountPath: "/ssh-pk", ReadOnly: true},
},
expectedCommand: []string{"/bin/sh", "-c"},
expectedArgs: []string{"mkdir -p ~/.ssh ~/.ssh/host_keys ~/.ssh/run && ls -la /ssh-pk/ && cp /ssh-pk/private.key ~/.ssh/id_rsa && cp /ssh-pk/private.key.pub ~/.ssh/id_rsa.pub && cp /ssh-pk/private.key.pub ~/.ssh/authorized_keys && chmod 600 ~/.ssh/id_rsa ~/.ssh/authorized_keys && chmod 644 ~/.ssh/id_rsa.pub ~/.ssh/authorized_keys && printf 'Host *\\nIdentityFile ~/.ssh/id_rsa\\nStrictHostKeyChecking no\\nPort 2222\\n' > ~/.ssh/config && ssh-keygen -t rsa -f ~/.ssh/host_keys/ssh_host_rsa_key -N '' && ssh-keygen -t ecdsa -f ~/.ssh/host_keys/ssh_host_ecdsa_key -N '' && ssh-keygen -t ed25519 -f ~/.ssh/host_keys/ssh_host_ed25519_key -N '' && printf 'Port 2222\\nHostKey ~/.ssh/host_keys/ssh_host_rsa_key\\nHostKey ~/.ssh/host_keys/ssh_host_ecdsa_key\\nHostKey ~/.ssh/host_keys/ssh_host_ed25519_key\\nPidFile ~/.ssh/run/sshd.pid\\nPermitRootLogin yes\\nPasswordAuthentication no\\nPubkeyAuthentication yes\\nAuthorizedKeysFile ~/.ssh/authorized_keys\\n' > ~/.ssh/sshd_config && mkdir -p /run/sshd && /usr/sbin/sshd -D -f ~/.ssh/sshd_config"},
expectedEnv: []corev1.EnvVar{
{Name: "OMPI_MCA_orte_keep_fqdn_hostnames", Value: "1"},
},
expectLivenessRemoved: true,
expectReadinessRemoved: false,
expectStartupRemoved: true,
expectedReadinessProbe: &corev1.Probe{
ProbeHandler: corev1.ProbeHandler{
TCPSocket: &corev1.TCPSocketAction{
Port: intstr.FromInt(commonconsts.MpiRunSshPort),
},
},
InitialDelaySeconds: 20,
PeriodSeconds: 20,
TimeoutSeconds: 5,
FailureThreshold: 10,
},
},
{
name: "Multinode leader with LWS deployment",
numberOfNodes: 2,
role: RoleLeader,
multinodeDeploymentType: commonconsts.MultinodeDeploymentTypeLWS,
component: &v1alpha1.DynamoComponentDeploymentOverridesSpec{
DynamoComponentDeploymentSharedSpec: v1alpha1.DynamoComponentDeploymentSharedSpec{
Resources: &common.Resources{
Limits: &common.ResourceItem{
GPU: "1",
},
},
},
},
expectedVolumeMounts: []corev1.VolumeMount{
{Name: commonconsts.MpiRunSshSecretName, MountPath: "/ssh-pk", ReadOnly: true},
},
expectedCommand: []string{"/bin/sh", "-c"},
expectedArgs: []string{"mkdir -p ~/.ssh && ls -la /ssh-pk/ && cp /ssh-pk/private.key ~/.ssh/id_rsa && cp /ssh-pk/private.key.pub ~/.ssh/id_rsa.pub && cp /ssh-pk/private.key.pub ~/.ssh/authorized_keys && chmod 600 ~/.ssh/id_rsa ~/.ssh/authorized_keys && chmod 644 ~/.ssh/id_rsa.pub ~/.ssh/authorized_keys && printf 'Host *\\nIdentityFile ~/.ssh/id_rsa\\nStrictHostKeyChecking no\\nPort 2222\\n' > ~/.ssh/config && mpirun --oversubscribe -n 2 -H ${LWS_LEADER_ADDRESS},${LWS_WORKER_1_ADDRESS} --mca pml ob1 --mca plm_rsh_args \"-p 2222 -o StrictHostKeyChecking=no -i ~/.ssh/id_rsa\" -x CUDA_VISIBLE_DEVICES -x HF_DATASETS_CACHE -x HF_HOME -x HF_TOKEN -x HOME -x HUGGING_FACE_HUB_TOKEN -x LD_LIBRARY_PATH -x MODEL_PATH -x NCCL_DEBUG -x NCCL_IB_DISABLE -x NCCL_P2P_DISABLE -x OMPI_MCA_orte_keep_fqdn_hostnames -x PATH -x PYTHONPATH -x TENSORRT_LLM_CACHE_DIR -x TOKENIZERS_PARALLELISM -x TRANSFORMERS_CACHE -x USER bash -c 'source /opt/dynamo/venv/bin/activate && trtllm-llmapi-launch python3 --model test'"},
expectedEnv: []corev1.EnvVar{
{Name: "OMPI_MCA_orte_keep_fqdn_hostnames", Value: "1"},
},
expectLivenessRemoved: false,
expectReadinessRemoved: false,
expectStartupRemoved: false,
expectedReadinessProbe: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
backend := &TRTLLMBackend{}
container := &corev1.Container{
Args: []string{"python3", "--model", "test"},
LivenessProbe: &corev1.Probe{},
ReadinessProbe: &corev1.Probe{},
StartupProbe: &corev1.Probe{},
}
// Call UpdateContainer
backend.UpdateContainer(container, tt.numberOfNodes, tt.role, tt.component, tt.multinodeDeploymentType, "test-service")
// Use helper functions to validate results
validateVolumeMounts(t, container, tt.expectedVolumeMounts)
validateCommand(t, container, tt.expectedCommand)
validateArgs(t, container, tt.expectedArgs)
validateEnvironmentVariables(t, container, tt.expectedEnv)
validateLivenessProbe(t, container, tt.expectLivenessRemoved, tt.role)
validateStartupProbe(t, container, tt.expectStartupRemoved, tt.role)
validateReadinessProbe(t, container, tt.expectReadinessRemoved, tt.expectedReadinessProbe, tt.role)
})
}
}
// Helper functions to reduce cyclomatic complexity of the main test
func validateVolumeMounts(t *testing.T, container *corev1.Container, expected []corev1.VolumeMount) {
if len(container.VolumeMounts) != len(expected) {
t.Errorf("UpdateContainer() volume mounts count = %d, want %d", len(container.VolumeMounts), len(expected))
return
}
for i, expectedVolumeMount := range expected {
actualVolumeMount := container.VolumeMounts[i]
if actualVolumeMount.Name != expectedVolumeMount.Name {
t.Errorf("UpdateContainer() volume mount[%d].Name = %s, want %s", i, actualVolumeMount.Name, expectedVolumeMount.Name)
}
if actualVolumeMount.MountPath != expectedVolumeMount.MountPath {
t.Errorf("UpdateContainer() volume mount[%d].MountPath = %s, want %s", i, actualVolumeMount.MountPath, expectedVolumeMount.MountPath)
}
if actualVolumeMount.ReadOnly != expectedVolumeMount.ReadOnly {
t.Errorf("UpdateContainer() volume mount[%d].ReadOnly = %t, want %t", i, actualVolumeMount.ReadOnly, expectedVolumeMount.ReadOnly)
}
}
}
func validateCommand(t *testing.T, container *corev1.Container, expected []string) {
if len(container.Command) != len(expected) {
t.Errorf("UpdateContainer() command length = %d, want %d", len(container.Command), len(expected))
return
}
for i, expectedCmd := range expected {
if container.Command[i] != expectedCmd {
t.Errorf("UpdateContainer() command[%d] = %s, want %s", i, container.Command[i], expectedCmd)
}
}
}
func validateArgs(t *testing.T, container *corev1.Container, expected []string) {
if len(container.Args) != len(expected) {
t.Errorf("UpdateContainer() args length = %d, want %d", len(container.Args), len(expected))
return
}
for i, expectedArg := range expected {
if container.Args[i] != expectedArg {
t.Errorf("UpdateContainer() args[%d] = %s, want %s", i, container.Args[i], expectedArg)
}
}
}
func validateEnvironmentVariables(t *testing.T, container *corev1.Container, expected []corev1.EnvVar) {
if len(container.Env) != len(expected) {
t.Errorf("UpdateContainer() env count = %d, want %d", len(container.Env), len(expected))
return
}
for i, expectedEnv := range expected {
actualEnv := container.Env[i]
if actualEnv.Name != expectedEnv.Name {
t.Errorf("UpdateContainer() env[%d].Name = %s, want %s", i, actualEnv.Name, expectedEnv.Name)
}
if actualEnv.Value != expectedEnv.Value {
t.Errorf("UpdateContainer() env[%d].Value = %s, want %s", i, actualEnv.Value, expectedEnv.Value)
}
}
}
func validateLivenessProbe(t *testing.T, container *corev1.Container, expectRemoved bool, role Role) {
if expectRemoved {
if container.LivenessProbe != nil {
t.Errorf("UpdateContainer() should remove LivenessProbe for %s", role)
}
} else {
if container.LivenessProbe == nil {
t.Errorf("UpdateContainer() should not remove LivenessProbe for %s", role)
}
}
}
func validateStartupProbe(t *testing.T, container *corev1.Container, expectRemoved bool, role Role) {
if expectRemoved {
if container.StartupProbe != nil {
t.Errorf("UpdateContainer() should remove StartupProbe for %s", role)
}
} else {
if container.StartupProbe == nil {
t.Errorf("UpdateContainer() should not remove StartupProbe for %s", role)
}
}
}
func validateReadinessProbe(t *testing.T, container *corev1.Container, expectRemoved bool, expected *corev1.Probe, role Role) {
if expectRemoved {
if container.ReadinessProbe != nil {
t.Errorf("UpdateContainer() should remove ReadinessProbe for %s", role)
}
} else if expected != nil {
// Check that readiness probe matches expected
if container.ReadinessProbe == nil {
t.Errorf("UpdateContainer() should set ReadinessProbe for %s", role)
} else {
validateProbeDetails(t, container.ReadinessProbe, expected)
}
} else {
// No specific readiness probe expected, should remain as originally set
if container.ReadinessProbe == nil {
t.Errorf("UpdateContainer() should not remove ReadinessProbe for %s", role)
}
}
}
func validateProbeDetails(t *testing.T, actual, expected *corev1.Probe) {
// Compare probe details
if actual.TCPSocket == nil {
t.Errorf("UpdateContainer() ReadinessProbe should have TCPSocket")
} else if actual.TCPSocket.Port.IntVal != expected.TCPSocket.Port.IntVal {
t.Errorf("UpdateContainer() ReadinessProbe port = %d, want %d", actual.TCPSocket.Port.IntVal, expected.TCPSocket.Port.IntVal)
}
if actual.InitialDelaySeconds != expected.InitialDelaySeconds {
t.Errorf("UpdateContainer() ReadinessProbe InitialDelaySeconds = %d, want %d", actual.InitialDelaySeconds, expected.InitialDelaySeconds)
}
if actual.PeriodSeconds != expected.PeriodSeconds {
t.Errorf("UpdateContainer() ReadinessProbe PeriodSeconds = %d, want %d", actual.PeriodSeconds, expected.PeriodSeconds)
}
if actual.TimeoutSeconds != expected.TimeoutSeconds {
t.Errorf("UpdateContainer() ReadinessProbe TimeoutSeconds = %d, want %d", actual.TimeoutSeconds, expected.TimeoutSeconds)
}
if actual.FailureThreshold != expected.FailureThreshold {
t.Errorf("UpdateContainer() ReadinessProbe FailureThreshold = %d, want %d", actual.FailureThreshold, expected.FailureThreshold)
}
}
func TestTRTLLMBackend_UpdatePodSpec(t *testing.T) {
tests := []struct {
name string
numberOfNodes int32
role Role
multinodeDeploymentType commonconsts.MultinodeDeploymentType
initialVolumes []corev1.Volume
expectedVolumeCount int
shouldHaveSSHVolume bool
}{
{
name: "Single node - no SSH volume added",
numberOfNodes: 1,
role: RoleMain,
multinodeDeploymentType: commonconsts.MultinodeDeploymentTypeGrove,
initialVolumes: []corev1.Volume{},
expectedVolumeCount: 0,
shouldHaveSSHVolume: false,
},
{
name: "Multinode leader - SSH volume added",
numberOfNodes: 3,
role: RoleLeader,
multinodeDeploymentType: commonconsts.MultinodeDeploymentTypeGrove,
initialVolumes: []corev1.Volume{},
expectedVolumeCount: 1,
shouldHaveSSHVolume: true,
},
{
name: "Multinode worker - SSH volume added",
numberOfNodes: 2,
role: RoleWorker,
multinodeDeploymentType: commonconsts.MultinodeDeploymentTypeLWS,
initialVolumes: []corev1.Volume{},
expectedVolumeCount: 1,
shouldHaveSSHVolume: true,
},
{
name: "Multinode with existing volumes",
numberOfNodes: 2,
role: RoleLeader,
multinodeDeploymentType: commonconsts.MultinodeDeploymentTypeGrove,
initialVolumes: []corev1.Volume{
{Name: "existing-volume"},
},
expectedVolumeCount: 2,
shouldHaveSSHVolume: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
backend := &TRTLLMBackend{}
podSpec := &corev1.PodSpec{
Volumes: tt.initialVolumes,
Containers: []corev1.Container{
{
Name: "main",
Env: []corev1.EnvVar{},
},
},
}
component := &v1alpha1.DynamoComponentDeploymentOverridesSpec{}
// Call UpdatePodSpec
backend.UpdatePodSpec(podSpec, tt.numberOfNodes, tt.role, component, tt.multinodeDeploymentType, "test-service")
// Check volume count
if len(podSpec.Volumes) != tt.expectedVolumeCount {
t.Errorf("UpdatePodSpec() volume count = %d, want %d", len(podSpec.Volumes), tt.expectedVolumeCount)
}
// Check for SSH volume
hasSSHVolume := false
for _, volume := range podSpec.Volumes {
if volume.Name == commonconsts.MpiRunSshSecretName {
hasSSHVolume = true
// Verify volume configuration
if volume.VolumeSource.Secret == nil {
t.Errorf("UpdatePodSpec() SSH volume should use Secret volume source")
} else {
if volume.VolumeSource.Secret.SecretName != commonconsts.MpiRunSshSecretName {
t.Errorf("UpdatePodSpec() SSH volume secret name = %s, want %s", volume.VolumeSource.Secret.SecretName, commonconsts.MpiRunSshSecretName)
}
if volume.VolumeSource.Secret.DefaultMode == nil || *volume.VolumeSource.Secret.DefaultMode != 0644 {
t.Errorf("UpdatePodSpec() SSH volume should have DefaultMode 0644")
}
}
break
}
}
if tt.shouldHaveSSHVolume && !hasSSHVolume {
t.Errorf("UpdatePodSpec() should add SSH volume for multinode deployment")
}
if !tt.shouldHaveSSHVolume && hasSSHVolume {
t.Errorf("UpdatePodSpec() should not add SSH volume for single node deployment")
}
})
}
}
func TestTRTLLMBackend_generateWorkerHostnames(t *testing.T) {
tests := []struct {
name string
numberOfNodes int32
multinodeDeploymentType commonconsts.MultinodeDeploymentType
serviceName string
expectedContains []string
expectedNodeCount int32
}{
{
name: "Grove deployment with 3 nodes",
numberOfNodes: 3,
multinodeDeploymentType: commonconsts.MultinodeDeploymentTypeGrove,
serviceName: "test-service",
expectedContains: []string{
"test-service-ldr-0",
"test-service-wkr-0",
"test-service-wkr-1",
"GROVE_PCSG_NAME",
"GROVE_HEADLESS_SERVICE",
},
expectedNodeCount: 3,
},
{
name: "LWS deployment with 2 nodes",
numberOfNodes: 2,
multinodeDeploymentType: commonconsts.MultinodeDeploymentTypeLWS,
serviceName: "test-service",
expectedContains: []string{
"${LWS_LEADER_ADDRESS}",
"${LWS_WORKER_1_ADDRESS}",
},
expectedNodeCount: 2,
},
{
name: "Grove deployment with 5 nodes",
numberOfNodes: 5,
multinodeDeploymentType: commonconsts.MultinodeDeploymentTypeGrove,
serviceName: "worker",
expectedContains: []string{
"worker-ldr-0",
"worker-wkr-0",
"worker-wkr-1",
"worker-wkr-2",
"worker-wkr-3",
},
expectedNodeCount: 5,
},
{
name: "LWS deployment with 4 nodes",
numberOfNodes: 4,
multinodeDeploymentType: commonconsts.MultinodeDeploymentTypeLWS,
serviceName: "worker",
expectedContains: []string{
"${LWS_LEADER_ADDRESS}",
"${LWS_WORKER_1_ADDRESS}",
"${LWS_WORKER_2_ADDRESS}",
"${LWS_WORKER_3_ADDRESS}",
},
expectedNodeCount: 4,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
backend := &TRTLLMBackend{}
result := backend.generateWorkerHostnames(tt.numberOfNodes, tt.multinodeDeploymentType, tt.serviceName)
for _, expected := range tt.expectedContains {
if !strings.Contains(result, expected) {
t.Errorf("generateWorkerHostnames() = %s, should contain %s", result, expected)
}
}
// Check that result is comma-separated with correct count
parts := strings.Split(result, ",")
if int32(len(parts)) != tt.expectedNodeCount {
t.Errorf("generateWorkerHostnames() should have %d hostnames, got %d: %v", tt.expectedNodeCount, len(parts), parts)
}
// Verify no empty parts
for i, part := range parts {
if strings.TrimSpace(part) == "" {
t.Errorf("generateWorkerHostnames() has empty hostname at index %d", i)
}
}
})
}
}
func TestTRTLLMBackend_addSSHVolumeMount(t *testing.T) {
expectedSSHVolumeMount := corev1.VolumeMount{
Name: commonconsts.MpiRunSshSecretName,
MountPath: "/ssh-pk",
ReadOnly: true,
}
tests := []struct {
name string
initialVolumeMounts []corev1.VolumeMount
expectedVolumeMounts []corev1.VolumeMount
}{
{
name: "Add SSH volume mount to empty container",
initialVolumeMounts: []corev1.VolumeMount{},
expectedVolumeMounts: []corev1.VolumeMount{expectedSSHVolumeMount},
},
{
name: "Add SSH volume mount to container with existing mounts",
initialVolumeMounts: []corev1.VolumeMount{
{Name: "existing-mount", MountPath: "/existing"},
},
expectedVolumeMounts: []corev1.VolumeMount{
{Name: "existing-mount", MountPath: "/existing"},
expectedSSHVolumeMount,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
backend := &TRTLLMBackend{}
container := &corev1.Container{
VolumeMounts: tt.initialVolumeMounts,
}
backend.addSSHVolumeMount(container)
// Check that volume mounts match expected
if len(container.VolumeMounts) != len(tt.expectedVolumeMounts) {
t.Errorf("addSSHVolumeMount() volume mount count = %d, want %d", len(container.VolumeMounts), len(tt.expectedVolumeMounts))
return
}
for i, expected := range tt.expectedVolumeMounts {
actual := container.VolumeMounts[i]
if actual.Name != expected.Name {
t.Errorf("addSSHVolumeMount() volume mount[%d].Name = %s, want %s", i, actual.Name, expected.Name)
}
if actual.MountPath != expected.MountPath {
t.Errorf("addSSHVolumeMount() volume mount[%d].MountPath = %s, want %s", i, actual.MountPath, expected.MountPath)
}
if actual.ReadOnly != expected.ReadOnly {
t.Errorf("addSSHVolumeMount() volume mount[%d].ReadOnly = %t, want %t", i, actual.ReadOnly, expected.ReadOnly)
}
}
})
}
}
func TestTRTLLMBackend_setupLeaderContainer(t *testing.T) {
tests := []struct {
name string
numberOfNodes int32
multinodeDeploymentType commonconsts.MultinodeDeploymentType
serviceName string
component *v1alpha1.DynamoComponentDeploymentOverridesSpec
initialArgs []string
initialCommand []string
expected string
}{
{
name: "Leader with args and GPU resources",
numberOfNodes: 3,
multinodeDeploymentType: commonconsts.MultinodeDeploymentTypeGrove,
serviceName: "test-service",
component: &v1alpha1.DynamoComponentDeploymentOverridesSpec{
DynamoComponentDeploymentSharedSpec: v1alpha1.DynamoComponentDeploymentSharedSpec{
Resources: &common.Resources{
Requests: &common.ResourceItem{
GPU: "2",
},
},
},
},
initialArgs: []string{"python3", "--model", "test"},
initialCommand: []string{},
expected: "mkdir -p ~/.ssh && ls -la /ssh-pk/ && cp /ssh-pk/private.key ~/.ssh/id_rsa && cp /ssh-pk/private.key.pub ~/.ssh/id_rsa.pub && cp /ssh-pk/private.key.pub ~/.ssh/authorized_keys && chmod 600 ~/.ssh/id_rsa ~/.ssh/authorized_keys && chmod 644 ~/.ssh/id_rsa.pub ~/.ssh/authorized_keys && printf 'Host *\\nIdentityFile ~/.ssh/id_rsa\\nStrictHostKeyChecking no\\nPort 2222\\n' > ~/.ssh/config && mpirun --oversubscribe -n 6 -H ${GROVE_PCSG_NAME}-${GROVE_PCSG_INDEX}-test-service-ldr-0.${GROVE_HEADLESS_SERVICE},${GROVE_PCSG_NAME}-${GROVE_PCSG_INDEX}-test-service-wkr-0.${GROVE_HEADLESS_SERVICE},${GROVE_PCSG_NAME}-${GROVE_PCSG_INDEX}-test-service-wkr-1.${GROVE_HEADLESS_SERVICE} --mca pml ob1 --mca plm_rsh_args \"-p 2222 -o StrictHostKeyChecking=no -i ~/.ssh/id_rsa\" -x CUDA_VISIBLE_DEVICES -x HF_DATASETS_CACHE -x HF_HOME -x HF_TOKEN -x HOME -x HUGGING_FACE_HUB_TOKEN -x LD_LIBRARY_PATH -x MODEL_PATH -x NCCL_DEBUG -x NCCL_IB_DISABLE -x NCCL_P2P_DISABLE -x PATH -x PYTHONPATH -x TENSORRT_LLM_CACHE_DIR -x TOKENIZERS_PARALLELISM -x TRANSFORMERS_CACHE -x USER bash -c 'source /opt/dynamo/venv/bin/activate && trtllm-llmapi-launch python3 --model test'",
},
{
name: "Leader with command and no GPU resources",
numberOfNodes: 2,
multinodeDeploymentType: commonconsts.MultinodeDeploymentTypeLWS,
serviceName: "worker",
component: &v1alpha1.DynamoComponentDeploymentOverridesSpec{},
initialArgs: []string{},
initialCommand: []string{"python", "-m", "worker"},
expected: "mkdir -p ~/.ssh && ls -la /ssh-pk/ && cp /ssh-pk/private.key ~/.ssh/id_rsa && cp /ssh-pk/private.key.pub ~/.ssh/id_rsa.pub && cp /ssh-pk/private.key.pub ~/.ssh/authorized_keys && chmod 600 ~/.ssh/id_rsa ~/.ssh/authorized_keys && chmod 644 ~/.ssh/id_rsa.pub ~/.ssh/authorized_keys && printf 'Host *\\nIdentityFile ~/.ssh/id_rsa\\nStrictHostKeyChecking no\\nPort 2222\\n' > ~/.ssh/config && mpirun --oversubscribe -n 0 -H ${LWS_LEADER_ADDRESS},${LWS_WORKER_1_ADDRESS} --mca pml ob1 --mca plm_rsh_args \"-p 2222 -o StrictHostKeyChecking=no -i ~/.ssh/id_rsa\" -x CUDA_VISIBLE_DEVICES -x HF_DATASETS_CACHE -x HF_HOME -x HF_TOKEN -x HOME -x HUGGING_FACE_HUB_TOKEN -x LD_LIBRARY_PATH -x MODEL_PATH -x NCCL_DEBUG -x NCCL_IB_DISABLE -x NCCL_P2P_DISABLE -x PATH -x PYTHONPATH -x TENSORRT_LLM_CACHE_DIR -x TOKENIZERS_PARALLELISM -x TRANSFORMERS_CACHE -x USER bash -c 'source /opt/dynamo/venv/bin/activate && trtllm-llmapi-launch python -m worker'",
},
{
name: "Leader with both command and args (args take precedence)",
numberOfNodes: 2,
multinodeDeploymentType: commonconsts.MultinodeDeploymentTypeGrove,
serviceName: "test",
component: &v1alpha1.DynamoComponentDeploymentOverridesSpec{
DynamoComponentDeploymentSharedSpec: v1alpha1.DynamoComponentDeploymentSharedSpec{
Resources: &common.Resources{
Limits: &common.ResourceItem{
GPU: "1",
},
},
},
},
initialArgs: []string{"launch", "--config", "test.yaml"},
initialCommand: []string{"ignored-command"},
expected: "mkdir -p ~/.ssh && ls -la /ssh-pk/ && cp /ssh-pk/private.key ~/.ssh/id_rsa && cp /ssh-pk/private.key.pub ~/.ssh/id_rsa.pub && cp /ssh-pk/private.key.pub ~/.ssh/authorized_keys && chmod 600 ~/.ssh/id_rsa ~/.ssh/authorized_keys && chmod 644 ~/.ssh/id_rsa.pub ~/.ssh/authorized_keys && printf 'Host *\\nIdentityFile ~/.ssh/id_rsa\\nStrictHostKeyChecking no\\nPort 2222\\n' > ~/.ssh/config && mpirun --oversubscribe -n 2 -H ${GROVE_PCSG_NAME}-${GROVE_PCSG_INDEX}-test-ldr-0.${GROVE_HEADLESS_SERVICE},${GROVE_PCSG_NAME}-${GROVE_PCSG_INDEX}-test-wkr-0.${GROVE_HEADLESS_SERVICE} --mca pml ob1 --mca plm_rsh_args \"-p 2222 -o StrictHostKeyChecking=no -i ~/.ssh/id_rsa\" -x CUDA_VISIBLE_DEVICES -x HF_DATASETS_CACHE -x HF_HOME -x HF_TOKEN -x HOME -x HUGGING_FACE_HUB_TOKEN -x LD_LIBRARY_PATH -x MODEL_PATH -x NCCL_DEBUG -x NCCL_IB_DISABLE -x NCCL_P2P_DISABLE -x PATH -x PYTHONPATH -x TENSORRT_LLM_CACHE_DIR -x TOKENIZERS_PARALLELISM -x TRANSFORMERS_CACHE -x USER bash -c 'source /opt/dynamo/venv/bin/activate && trtllm-llmapi-launch launch --config test.yaml'",
},
{
name: "Leader with all environment variables forwarded",
numberOfNodes: 2,
multinodeDeploymentType: commonconsts.MultinodeDeploymentTypeGrove,
serviceName: "test",
component: &v1alpha1.DynamoComponentDeploymentOverridesSpec{
DynamoComponentDeploymentSharedSpec: v1alpha1.DynamoComponentDeploymentSharedSpec{
Resources: &common.Resources{
Requests: &common.ResourceItem{
GPU: "1",
},
},
},
},
initialArgs: []string{"serve", "--model", "test"},
initialCommand: []string{},
expected: "mkdir -p ~/.ssh && ls -la /ssh-pk/ && cp /ssh-pk/private.key ~/.ssh/id_rsa && cp /ssh-pk/private.key.pub ~/.ssh/id_rsa.pub && cp /ssh-pk/private.key.pub ~/.ssh/authorized_keys && chmod 600 ~/.ssh/id_rsa ~/.ssh/authorized_keys && chmod 644 ~/.ssh/id_rsa.pub ~/.ssh/authorized_keys && printf 'Host *\\nIdentityFile ~/.ssh/id_rsa\\nStrictHostKeyChecking no\\nPort 2222\\n' > ~/.ssh/config && mpirun --oversubscribe -n 2 -H ${GROVE_PCSG_NAME}-${GROVE_PCSG_INDEX}-test-ldr-0.${GROVE_HEADLESS_SERVICE},${GROVE_PCSG_NAME}-${GROVE_PCSG_INDEX}-test-wkr-0.${GROVE_HEADLESS_SERVICE} --mca pml ob1 --mca plm_rsh_args \"-p 2222 -o StrictHostKeyChecking=no -i ~/.ssh/id_rsa\" -x CUDA_VISIBLE_DEVICES -x HF_DATASETS_CACHE -x HF_HOME -x HF_TOKEN -x HOME -x HUGGING_FACE_HUB_TOKEN -x LD_LIBRARY_PATH -x MODEL_PATH -x NCCL_DEBUG -x NCCL_IB_DISABLE -x NCCL_P2P_DISABLE -x PATH -x PYTHONPATH -x TENSORRT_LLM_CACHE_DIR -x TOKENIZERS_PARALLELISM -x TRANSFORMERS_CACHE -x USER bash -c 'source /opt/dynamo/venv/bin/activate && trtllm-llmapi-launch serve --model test'",
},
{
name: "Leader with overlapping environment variables (deduplication test)",
numberOfNodes: 2,
multinodeDeploymentType: commonconsts.MultinodeDeploymentTypeGrove,
serviceName: "test",
component: &v1alpha1.DynamoComponentDeploymentOverridesSpec{
DynamoComponentDeploymentSharedSpec: v1alpha1.DynamoComponentDeploymentSharedSpec{
Resources: &common.Resources{
Requests: &common.ResourceItem{
GPU: "1",
},
},
},
},
initialArgs: []string{"serve", "--model", "test"},
initialCommand: []string{},
expected: "mkdir -p ~/.ssh && ls -la /ssh-pk/ && cp /ssh-pk/private.key ~/.ssh/id_rsa && cp /ssh-pk/private.key.pub ~/.ssh/id_rsa.pub && cp /ssh-pk/private.key.pub ~/.ssh/authorized_keys && chmod 600 ~/.ssh/id_rsa ~/.ssh/authorized_keys && chmod 644 ~/.ssh/id_rsa.pub ~/.ssh/authorized_keys && printf 'Host *\\nIdentityFile ~/.ssh/id_rsa\\nStrictHostKeyChecking no\\nPort 2222\\n' > ~/.ssh/config && mpirun --oversubscribe -n 2 -H ${GROVE_PCSG_NAME}-${GROVE_PCSG_INDEX}-test-ldr-0.${GROVE_HEADLESS_SERVICE},${GROVE_PCSG_NAME}-${GROVE_PCSG_INDEX}-test-wkr-0.${GROVE_HEADLESS_SERVICE} --mca pml ob1 --mca plm_rsh_args \"-p 2222 -o StrictHostKeyChecking=no -i ~/.ssh/id_rsa\" -x CUDA_VISIBLE_DEVICES -x CUSTOM_VAR -x HF_DATASETS_CACHE -x HF_HOME -x HF_TOKEN -x HOME -x HUGGING_FACE_HUB_TOKEN -x LD_LIBRARY_PATH -x MODEL_PATH -x NCCL_DEBUG -x NCCL_IB_DISABLE -x NCCL_P2P_DISABLE -x PATH -x PYTHONPATH -x TENSORRT_LLM_CACHE_DIR -x TOKENIZERS_PARALLELISM -x TRANSFORMERS_CACHE -x USER bash -c 'source /opt/dynamo/venv/bin/activate && trtllm-llmapi-launch serve --model test'",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
backend := &TRTLLMBackend{}
container := &corev1.Container{
Args: tt.initialArgs,
Command: tt.initialCommand,
}
// Add test environment variables for the deduplication test
if tt.name == "Leader with overlapping environment variables (deduplication test)" {
container.Env = []corev1.EnvVar{
{Name: "CUDA_VISIBLE_DEVICES", Value: "0,1"}, // This should NOT be duplicated
{Name: "CUSTOM_VAR", Value: "test_value"}, // This should be added
{Name: "PATH", Value: "/custom/path"}, // This should NOT be duplicated
}
}
backend.setupLeaderContainer(container, tt.numberOfNodes, tt.multinodeDeploymentType, tt.serviceName, tt.component)
// Check that command is set correctly
expectedCommand := []string{"/bin/sh", "-c"}
if len(container.Command) != len(expectedCommand) {
t.Errorf("setupLeaderContainer() command = %v, want %v", container.Command, expectedCommand)
} else {
for i, cmd := range expectedCommand {
if container.Command[i] != cmd {
t.Errorf("setupLeaderContainer() command[%d] = %s, want %s", i, container.Command[i], cmd)
}
}
}
// Check args content
if len(container.Args) != 1 {
t.Errorf("setupLeaderContainer() should set exactly one arg, got %d", len(container.Args))
} else {
argsStr := container.Args[0]
if argsStr != tt.expected {
t.Errorf("setupLeaderContainer() args = %q, want %q", argsStr, tt.expected)
}
}
})
}
}
func TestTRTLLMBackend_setupWorkerContainer(t *testing.T) {
tests := []struct {
name string
initialArgs []string
initialCommand []string
expected string
}{
{
name: "Worker setup with initial args",
initialArgs: []string{"some", "args"},
initialCommand: []string{},
expected: "mkdir -p ~/.ssh ~/.ssh/host_keys ~/.ssh/run && ls -la /ssh-pk/ && cp /ssh-pk/private.key ~/.ssh/id_rsa && cp /ssh-pk/private.key.pub ~/.ssh/id_rsa.pub && cp /ssh-pk/private.key.pub ~/.ssh/authorized_keys && chmod 600 ~/.ssh/id_rsa ~/.ssh/authorized_keys && chmod 644 ~/.ssh/id_rsa.pub ~/.ssh/authorized_keys && printf 'Host *\\nIdentityFile ~/.ssh/id_rsa\\nStrictHostKeyChecking no\\nPort 2222\\n' > ~/.ssh/config && ssh-keygen -t rsa -f ~/.ssh/host_keys/ssh_host_rsa_key -N '' && ssh-keygen -t ecdsa -f ~/.ssh/host_keys/ssh_host_ecdsa_key -N '' && ssh-keygen -t ed25519 -f ~/.ssh/host_keys/ssh_host_ed25519_key -N '' && printf 'Port 2222\\nHostKey ~/.ssh/host_keys/ssh_host_rsa_key\\nHostKey ~/.ssh/host_keys/ssh_host_ecdsa_key\\nHostKey ~/.ssh/host_keys/ssh_host_ed25519_key\\nPidFile ~/.ssh/run/sshd.pid\\nPermitRootLogin yes\\nPasswordAuthentication no\\nPubkeyAuthentication yes\\nAuthorizedKeysFile ~/.ssh/authorized_keys\\n' > ~/.ssh/sshd_config && mkdir -p /run/sshd && /usr/sbin/sshd -D -f ~/.ssh/sshd_config",
},
{
name: "Worker setup with initial command",
initialArgs: []string{},
initialCommand: []string{"original", "command"},
expected: "mkdir -p ~/.ssh ~/.ssh/host_keys ~/.ssh/run && ls -la /ssh-pk/ && cp /ssh-pk/private.key ~/.ssh/id_rsa && cp /ssh-pk/private.key.pub ~/.ssh/id_rsa.pub && cp /ssh-pk/private.key.pub ~/.ssh/authorized_keys && chmod 600 ~/.ssh/id_rsa ~/.ssh/authorized_keys && chmod 644 ~/.ssh/id_rsa.pub ~/.ssh/authorized_keys && printf 'Host *\\nIdentityFile ~/.ssh/id_rsa\\nStrictHostKeyChecking no\\nPort 2222\\n' > ~/.ssh/config && ssh-keygen -t rsa -f ~/.ssh/host_keys/ssh_host_rsa_key -N '' && ssh-keygen -t ecdsa -f ~/.ssh/host_keys/ssh_host_ecdsa_key -N '' && ssh-keygen -t ed25519 -f ~/.ssh/host_keys/ssh_host_ed25519_key -N '' && printf 'Port 2222\\nHostKey ~/.ssh/host_keys/ssh_host_rsa_key\\nHostKey ~/.ssh/host_keys/ssh_host_ecdsa_key\\nHostKey ~/.ssh/host_keys/ssh_host_ed25519_key\\nPidFile ~/.ssh/run/sshd.pid\\nPermitRootLogin yes\\nPasswordAuthentication no\\nPubkeyAuthentication yes\\nAuthorizedKeysFile ~/.ssh/authorized_keys\\n' > ~/.ssh/sshd_config && mkdir -p /run/sshd && /usr/sbin/sshd -D -f ~/.ssh/sshd_config",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
backend := &TRTLLMBackend{}
container := &corev1.Container{
Args: tt.initialArgs,
Command: tt.initialCommand,
}
backend.setupWorkerContainer(container)
// Check that command is set correctly
expectedCommand := []string{"/bin/sh", "-c"}
if len(container.Command) != len(expectedCommand) {
t.Errorf("setupWorkerContainer() command = %v, want %v", container.Command, expectedCommand)
} else {
for i, cmd := range expectedCommand {
if container.Command[i] != cmd {
t.Errorf("setupWorkerContainer() command[%d] = %s, want %s", i, container.Command[i], cmd)
}
}
}
// Check args content
if len(container.Args) != 1 {
t.Errorf("setupWorkerContainer() should set exactly one arg, got %d", len(container.Args))
} else {
argsStr := container.Args[0]
if argsStr != tt.expected {
t.Errorf("setupWorkerContainer() args = %q, want %q", argsStr, tt.expected)
}
}
})
}
}
func TestTRTLLMBackend_getGPUsPerNode(t *testing.T) {
tests := []struct {
name string
resources *common.Resources
expected int32
}{
{
name: "No resources - default to 0",
resources: nil,
expected: 0,
},
{
name: "Empty resources - default to 0",
resources: &common.Resources{},
expected: 0,
},
{
name: "GPU in requests",
resources: &common.Resources{
Requests: &common.ResourceItem{
GPU: "2",
},
},
expected: 2,
},
{
name: "GPU in limits",
resources: &common.Resources{
Limits: &common.ResourceItem{
GPU: "4",
},
},
expected: 4,
},
{
name: "GPU in both requests and limits - requests takes precedence",
resources: &common.Resources{
Requests: &common.ResourceItem{
GPU: "3",
},
Limits: &common.ResourceItem{
GPU: "8",
},
},
expected: 3,
},
{
name: "Invalid GPU value - default to 0",
resources: &common.Resources{
Requests: &common.ResourceItem{
GPU: "invalid",
},
},
expected: 0,
},
{
name: "Empty GPU string - default to 0",
resources: &common.Resources{
Requests: &common.ResourceItem{
GPU: "",
},
},
expected: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := getGPUsPerNode(tt.resources)
if result != tt.expected {
t.Errorf("getGPUsPerNode() = %d, want %d", result, tt.expected)
}
})
}
}
package dynamo
import (
"fmt"
"strings"
"github.com/ai-dynamo/dynamo/deploy/cloud/operator/api/v1alpha1"
commonconsts "github.com/ai-dynamo/dynamo/deploy/cloud/operator/internal/consts"
corev1 "k8s.io/api/core/v1"
)
type VLLMBackend struct{}
func (b *VLLMBackend) UpdateContainer(container *corev1.Container, numberOfNodes int32, role Role, component *v1alpha1.DynamoComponentDeploymentOverridesSpec, multinodeDeploymentType commonconsts.MultinodeDeploymentType, serviceName string) {
isMultinode := numberOfNodes > 1
if isMultinode {
// Apply multinode-specific argument modifications
updateVLLMMultinodeArgs(container, role, multinodeDeploymentType, serviceName)
// Remove probes for multinode worker and leader
if role == RoleWorker || role == RoleLeader {
container.LivenessProbe = nil
container.ReadinessProbe = nil
container.StartupProbe = nil
}
}
}
func (b *VLLMBackend) UpdatePodSpec(podSpec *corev1.PodSpec, numberOfNodes int32, role Role, component *v1alpha1.DynamoComponentDeploymentOverridesSpec, multinodeDeploymentType commonconsts.MultinodeDeploymentType, serviceName string) {
// do nothing
}
// updateVLLMMultinodeArgs applies Ray-specific modifications for multinode deployments
func updateVLLMMultinodeArgs(container *corev1.Container, role Role, multinodeDeploymentType commonconsts.MultinodeDeploymentType, serviceName string) {
switch role {
case RoleLeader:
if len(container.Args) > 0 {
// Prepend ray start --head command to existing args
container.Args = []string{fmt.Sprintf("ray start --head --port=6379 && %s", strings.Join(container.Args, " "))}
}
case RoleWorker:
// Worker nodes only run Ray, completely replace args
if multinodeDeploymentType == commonconsts.MultinodeDeploymentTypeGrove {
leaderHostname := generateGroveLeaderHostname(serviceName)
container.Args = []string{fmt.Sprintf("ray start --address=%s:6379 --block", leaderHostname)}
} else {
container.Args = []string{"ray start --address=${LWS_LEADER_ADDRESS}:6379 --block"}
}
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment