Unverified commit 37bc8444, authored by Julien Mancuso, committed by GitHub
Browse files

feat: add trtllm and vllm multinode k8s examples (#3100)


Signed-off-by: Julien Mancuso <jmancuso@nvidia.com>
parent 9b893c93
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

# Engine configuration files for the disaggregated TensorRT-LLM workers.
# Each key under `data` becomes one file when the ConfigMap is mounted; the
# workers in this example select theirs with
# `--extra-engine-args engine_configs/<key>`.
apiVersion: v1
kind: ConfigMap
metadata:
  name: nvidia-config
data:
  # Engine settings for the prefill worker.
  prefill.yaml: |
    # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
    # SPDX-License-Identifier: Apache-2.0
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    # http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    tensor_parallel_size: 8
    moe_expert_parallel_size: 1
    enable_attention_dp: false
    max_num_tokens: 8192
    trust_remote_code: true
    backend: pytorch
    enable_chunked_prefill: true
    # Overlap scheduler not currently supported in prefill only workers.
    disable_overlap_scheduler: true
    kv_cache_config:
      free_gpu_memory_fraction: 0.80
    cache_transceiver_config:
      backend: default
  # Engine settings for the decode worker; differs from prefill.yaml only in
  # leaving the overlap scheduler enabled.
  decode.yaml: |
    # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
    # SPDX-License-Identifier: Apache-2.0
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    # http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    tensor_parallel_size: 8
    moe_expert_parallel_size: 1
    enable_attention_dp: false
    max_num_tokens: 8192
    trust_remote_code: true
    backend: pytorch
    enable_chunked_prefill: true
    disable_overlap_scheduler: false
    kv_cache_config:
      free_gpu_memory_fraction: 0.80
    cache_transceiver_config:
      backend: default
---
# Storage claim named `models`; the prefill and decode workers of the
# trtllm-disagg-tp8 deployment both mount it at /models.
apiVersion: v1
kind: PersistentVolumeClaim
metadata:
  name: models
spec:
  # More than one worker pod mounts the claim, hence ReadWriteMany.
  accessModes: [ReadWriteMany]
  resources:
    requests:
      storage: 100Gi
---
# Disaggregated TensorRT-LLM serving of Qwen/Qwen3-0.6B: one HTTP frontend
# plus separate multinode prefill and decode workers.
apiVersion: nvidia.com/v1alpha1
kind: DynamoGraphDeployment
metadata:
  name: trtllm-disagg-tp8
spec:
  backendFramework: trtllm
  envs:
    # Both variables are set so that Open MPI accepts being launched as root.
    - name: OMPI_ALLOW_RUN_AS_ROOT
      value: "1"
    - name: OMPI_ALLOW_RUN_AS_ROOT_CONFIRM
      value: "1"
    # Same path as the mountPoint of the `models` PVC on the workers.
    - name: HF_HOME
      value: "/models"
  services:
    Frontend:
      componentType: frontend
      dynamoNamespace: trtllm-disagg
      replicas: 1
      extraPodSpec:
        mainContainer:
          image: my-registry/trtllm-runtime:my-tag
          workingDir: /workspace/components/backends/trtllm
          command: ["/bin/sh", "-c"]
          args:
            - "python3 -m dynamo.frontend --http-port 8000"
    prefill:
      componentType: worker
      dynamoNamespace: trtllm-disagg
      envFromSecret: hf-token-secret
      replicas: 1
      pvc:
        name: models
        mountPoint: /models
      # 2 nodes x 4 GPUs = 8 GPUs, in line with tensor_parallel_size: 8 in
      # prefill.yaml (assumes the gpu limit applies per node - confirm).
      multinode:
        nodeCount: 2
      resources:
        limits:
          gpu: "4"
      extraPodSpec:
        volumes:
          - name: nvidia-config
            configMap:
              name: nvidia-config
        mainContainer:
          image: my-registry/trtllm-runtime:my-tag
          workingDir: /workspace/components/backends/trtllm
          # The ConfigMap is mounted under workingDir so the relative path
          # engine_configs/prefill.yaml in args resolves.
          volumeMounts:
            - name: nvidia-config
              mountPath: /workspace/components/backends/trtllm/engine_configs
              readOnly: true
          command: ["/bin/sh", "-c"]
          args:
            - "python3 -m dynamo.trtllm --model-path Qwen/Qwen3-0.6B --served-model-name Qwen/Qwen3-0.6B --extra-engine-args engine_configs/prefill.yaml --disaggregation-mode prefill --disaggregation-strategy decode_first"
    decode:
      componentType: worker
      dynamoNamespace: trtllm-disagg
      envFromSecret: hf-token-secret
      replicas: 1
      pvc:
        name: models
        mountPoint: /models
      # 2 nodes x 4 GPUs = 8 GPUs, in line with tensor_parallel_size: 8 in
      # decode.yaml (assumes the gpu limit applies per node - confirm).
      multinode:
        nodeCount: 2
      resources:
        limits:
          gpu: "4"
      extraPodSpec:
        volumes:
          - name: nvidia-config
            configMap:
              name: nvidia-config
        mainContainer:
          image: my-registry/trtllm-runtime:my-tag
          workingDir: /workspace/components/backends/trtllm
          # The ConfigMap is mounted under workingDir so the relative path
          # engine_configs/decode.yaml in args resolves.
          volumeMounts:
            - name: nvidia-config
              mountPath: /workspace/components/backends/trtllm/engine_configs
              readOnly: true
          command: ["/bin/sh", "-c"]
          args:
            - "python3 -m dynamo.trtllm --model-path Qwen/Qwen3-0.6B --served-model-name Qwen/Qwen3-0.6B --extra-engine-args engine_configs/decode.yaml --disaggregation-mode decode --disaggregation-strategy decode_first"
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

# Disaggregated vLLM serving of Qwen/Qwen3-0.6B: one HTTP frontend plus
# separate multinode decode and prefill workers.
apiVersion: nvidia.com/v1alpha1
kind: DynamoGraphDeployment
metadata:
  name: vllm-disagg
spec:
  services:
    Frontend:
      componentType: frontend
      dynamoNamespace: vllm-disagg
      replicas: 1
      extraPodSpec:
        mainContainer:
          image: my-registry/vllm-runtime:my-tag
          workingDir: /workspace/components/backends/vllm
          command: ["/bin/sh", "-c"]
          args:
            - "python3 -m dynamo.frontend --http-port 8000"
    decode:
      componentType: worker
      dynamoNamespace: vllm-disagg
      envFromSecret: hf-token-secret
      replicas: 1
      # 2 nodes x 1 GPU = 2 GPUs, in line with --tensor-parallel-size 2
      # (assumes the gpu limit applies per node - confirm).
      multinode:
        nodeCount: 2
      resources:
        limits:
          gpu: "1"
      extraPodSpec:
        mainContainer:
          image: my-registry/vllm-runtime:my-tag
          workingDir: /workspace/components/backends/vllm
          command: ["/bin/sh", "-c"]
          args:
            - "python3 -m dynamo.vllm --model Qwen/Qwen3-0.6B --tensor-parallel-size 2"
    prefill:
      componentType: worker
      dynamoNamespace: vllm-disagg
      envFromSecret: hf-token-secret
      replicas: 1
      # 2 nodes x 1 GPU = 2 GPUs, in line with --tensor-parallel-size 2
      # (assumes the gpu limit applies per node - confirm).
      multinode:
        nodeCount: 2
      resources:
        limits:
          gpu: "1"
      extraPodSpec:
        mainContainer:
          image: my-registry/vllm-runtime:my-tag
          workingDir: /workspace/components/backends/vllm
          command: ["/bin/sh", "-c"]
          args:
            - "python3 -m dynamo.vllm --model Qwen/Qwen3-0.6B --is-prefill-worker --tensor-parallel-size 2"
...@@ -47,8 +47,8 @@ The Dynamo Platform Helm chart deploys the complete Dynamo Cloud infrastructure ...@@ -47,8 +47,8 @@ The Dynamo Platform Helm chart deploys the complete Dynamo Cloud infrastructure
| file://components/operator | dynamo-operator | 0.5.0 | | file://components/operator | dynamo-operator | 0.5.0 |
| https://charts.bitnami.com/bitnami | etcd | 12.0.18 | | https://charts.bitnami.com/bitnami | etcd | 12.0.18 |
| https://nats-io.github.io/k8s/helm/charts/ | nats | 1.3.2 | | https://nats-io.github.io/k8s/helm/charts/ | nats | 1.3.2 |
| oci://ghcr.io/nvidia/grove | grove(grove-charts) | v0.0.0-6e30275 | | oci://ghcr.io/nvidia/grove | grove(grove-charts) | v0.1.0-alpha.1 |
| oci://ghcr.io/nvidia/kai-scheduler | kai-scheduler | v0.8.4 | | oci://ghcr.io/nvidia/kai-scheduler | kai-scheduler | v0.9.2 |
## Values ## Values
...@@ -85,6 +85,8 @@ The Dynamo Platform Helm chart deploys the complete Dynamo Cloud infrastructure ...@@ -85,6 +85,8 @@ The Dynamo Platform Helm chart deploys the complete Dynamo Cloud infrastructure
| dynamo-operator.dynamo.ingressHostSuffix | string | `""` | Host suffix for generated ingress hostnames | | dynamo-operator.dynamo.ingressHostSuffix | string | `""` | Host suffix for generated ingress hostnames |
| dynamo-operator.dynamo.virtualServiceSupportsHTTPS | bool | `false` | Whether VirtualServices should support HTTPS routing | | dynamo-operator.dynamo.virtualServiceSupportsHTTPS | bool | `false` | Whether VirtualServices should support HTTPS routing |
| dynamo-operator.dynamo.metrics.prometheusEndpoint | string | `""` | Endpoint that services can use to retrieve metrics. If set, dynamo operator will automatically inject the PROMETHEUS_ENDPOINT environment variable into services it manages. Users can override the value of the PROMETHEUS_ENDPOINT environment variable by modifying the corresponding deployment's environment variables | | dynamo-operator.dynamo.metrics.prometheusEndpoint | string | `""` | Endpoint that services can use to retrieve metrics. If set, dynamo operator will automatically inject the PROMETHEUS_ENDPOINT environment variable into services it manages. Users can override the value of the PROMETHEUS_ENDPOINT environment variable by modifying the corresponding deployment's environment variables |
| dynamo-operator.dynamo.mpiRun.secretName | string | `"mpi-run-ssh-secret"` | Name of the secret containing the SSH key for MPI Run |
| dynamo-operator.dynamo.mpiRun.sshKeygen.enabled | bool | `true` | Whether to enable SSH key generation for MPI Run |
| grove.enabled | bool | `false` | Whether to enable Grove for multi-node inference coordination, if enabled, the Grove operator will be deployed cluster-wide | | grove.enabled | bool | `false` | Whether to enable Grove for multi-node inference coordination, if enabled, the Grove operator will be deployed cluster-wide |
| kai-scheduler.enabled | bool | `false` | Whether to enable Kai Scheduler for intelligent resource allocation, if enabled, the Kai Scheduler operator will be deployed cluster-wide | | kai-scheduler.enabled | bool | `false` | Whether to enable Kai Scheduler for intelligent resource allocation, if enabled, the Kai Scheduler operator will be deployed cluster-wide |
| etcd.enabled | bool | `true` | Whether to enable etcd deployment, disable if you want to use an external etcd instance | | etcd.enabled | bool | `true` | Whether to enable etcd deployment, disable if you want to use an external etcd instance |
......
...@@ -110,6 +110,10 @@ spec: ...@@ -110,6 +110,10 @@ spec:
{{- if .Values.dynamo.metrics.prometheusEndpoint }} {{- if .Values.dynamo.metrics.prometheusEndpoint }}
- --prometheus-endpoint={{ .Values.dynamo.metrics.prometheusEndpoint }} - --prometheus-endpoint={{ .Values.dynamo.metrics.prometheusEndpoint }}
{{- end }} {{- end }}
{{- if .Values.dynamo.mpiRun.secretName }}
- --mpi-run-ssh-secret-name={{ .Values.dynamo.mpiRun.secretName }}
- --mpi-run-ssh-secret-namespace={{ .Release.Namespace }}
{{- end }}
command: command:
- /manager - /manager
env: env:
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This job is used to generate an SSH key pair and create a Kubernetes secret with the key pair.
# The secret is used when mpi is in use by dynamo workers.
{{- if .Values.dynamo.mpiRun.sshKeygen.enabled }}
# Hook Job that provisions the MPI SSH key secret.
# An init container generates an RSA key pair into a shared emptyDir; the main
# container then stores both halves in a Secret unless one with the configured
# name already exists in the release namespace.
apiVersion: batch/v1
kind: Job
metadata:
  name: {{ include "dynamo-operator.fullname" . }}-ssh-keygen
  # Same standard labels as the ServiceAccount, Role and RoleBinding that
  # accompany this Job.
  labels:
    {{- include "dynamo-operator.labels" . | nindent 4 }}
  annotations:
    "helm.sh/hook": pre-install,pre-upgrade
    # Weight -5 sorts after the RBAC hook resources (weight -10), so the
    # service account and its permissions exist before the Job starts.
    "helm.sh/hook-weight": "-5"
    "helm.sh/hook-delete-policy": before-hook-creation,hook-succeeded
spec:
  backoffLimit: 1
  activeDeadlineSeconds: 300
  template:
    spec:
      restartPolicy: Never
      serviceAccountName: {{ include "dynamo-operator.fullname" . }}-ssh-keygen
      securityContext:
        runAsNonRoot: true
        runAsUser: 65534
        fsGroup: 65534
      initContainers:
        # Writes /shared/private.key and /shared/private.key.pub. The script
        # reads no environment variables, so none are injected here.
        - name: keygen
          image: bitnamisecure/git:latest
          volumeMounts:
            - name: shared
              mountPath: /shared
          command:
            - /bin/bash
            - -e
            - -c
            - |
              echo "Generating SSH key pair with ssh-keygen..."
              ssh-keygen -t rsa -b 2048 -f /shared/private.key -N ""
              echo "SSH keys generated and saved to shared volume"
      containers:
        # Creates the Secret from the generated files; an existing Secret is
        # left untouched so keys are not rotated on upgrade.
        - name: kubectl-create-secret
          image: bitnamisecure/kubectl:latest
          volumeMounts:
            - name: shared
              mountPath: /shared
          env:
            - name: SECRET_NAME
              value: {{ .Values.dynamo.mpiRun.secretName | quote }}
            - name: NAMESPACE
              value: {{ .Release.Namespace | quote }}
          command:
            - /bin/bash
            - -e
            - -c
            - |
              # Check if secret already exists
              if kubectl get secret "$SECRET_NAME" -n "$NAMESPACE" &>/dev/null; then
                echo "Secret $SECRET_NAME already exists, skipping creation"
                exit 0
              fi
              echo "Creating Kubernetes secret..."
              kubectl create secret generic "$SECRET_NAME" \
                --from-file=private.key=/shared/private.key \
                --from-file=private.key.pub=/shared/private.key.pub \
                -n "$NAMESPACE"
              echo "SSH key secret created successfully"
      volumes:
        # Scratch space used to hand the key pair from the init container to
        # the main container.
        - name: shared
          emptyDir: {}
---
# Identity under which the ssh-keygen hook Job runs.
apiVersion: v1
kind: ServiceAccount
metadata:
  name: {{ include "dynamo-operator.fullname" . }}-ssh-keygen
  annotations:
    # Weight -10 sorts before the Job hook (weight -5).
    "helm.sh/hook-weight": "-10"
    "helm.sh/hook": pre-install,pre-upgrade
  labels:
    {{- include "dynamo-operator.labels" . | nindent 4 }}
---
# Permissions for the ssh-keygen hook Job, limited to what its script runs:
# `kubectl get secret` (existence check) and `kubectl create secret`.
apiVersion: rbac.authorization.k8s.io/v1
kind: Role
metadata:
  name: {{ include "dynamo-operator.fullname" . }}-ssh-keygen
  labels:
    {{- include "dynamo-operator.labels" . | nindent 4 }}
  annotations:
    "helm.sh/hook": pre-install,pre-upgrade
    # Weight -10 sorts before the Job hook (weight -5).
    "helm.sh/hook-weight": "-10"
rules:
  # No `update`: the Job never modifies an existing secret, it exits early
  # when the secret is already present.
  - apiGroups: [""]
    resources: ["secrets"]
    verbs: ["get", "create"]
---
# Grants the ssh-keygen Role to the ssh-keygen ServiceAccount in the release
# namespace.
apiVersion: rbac.authorization.k8s.io/v1
kind: RoleBinding
metadata:
  name: {{ include "dynamo-operator.fullname" . }}-ssh-keygen
  labels:
    {{- include "dynamo-operator.labels" . | nindent 4 }}
  annotations:
    "helm.sh/hook": pre-install,pre-upgrade
    # Weight -10 sorts before the Job hook (weight -5).
    "helm.sh/hook-weight": "-10"
roleRef:
  apiGroup: rbac.authorization.k8s.io
  kind: Role
  name: {{ include "dynamo-operator.fullname" . }}-ssh-keygen
subjects:
  - kind: ServiceAccount
    name: {{ include "dynamo-operator.fullname" . }}-ssh-keygen
    # Quoted so that an all-digit namespace name still renders as a string.
    namespace: {{ .Release.Namespace | quote }}
---
{{- end }}
...@@ -99,6 +99,11 @@ dynamo: ...@@ -99,6 +99,11 @@ dynamo:
metrics: metrics:
prometheusEndpoint: "" prometheusEndpoint: ""
mpiRun:
secretName: "mpi-run-ssh-secret"
sshKeygen:
enabled: true
#imagePullSecrets: [] #imagePullSecrets: []
kubernetesClusterDomain: cluster.local kubernetesClusterDomain: cluster.local
......
...@@ -116,6 +116,15 @@ dynamo-operator: ...@@ -116,6 +116,15 @@ dynamo-operator:
# -- Endpoint that services can use to retrieve metrics. If set, dynamo operator will automatically inject the PROMETHEUS_ENDPOINT environment variable into services it manages. Users can override the value of the PROMETHEUS_ENDPOINT environment variable by modifying the corresponding deployment's environment variables # -- Endpoint that services can use to retrieve metrics. If set, dynamo operator will automatically inject the PROMETHEUS_ENDPOINT environment variable into services it manages. Users can override the value of the PROMETHEUS_ENDPOINT environment variable by modifying the corresponding deployment's environment variables
prometheusEndpoint: "" prometheusEndpoint: ""
# MPI Run configuration
mpiRun:
# -- Name of the secret containing the SSH key for MPI Run
secretName: "mpi-run-ssh-secret"
# SSH key generation configuration
sshKeygen:
# -- Whether to enable SSH key generation for MPI Run
enabled: true
# Grove component - distributed inference orchestration # Grove component - distributed inference orchestration
grove: grove:
......
...@@ -60,6 +60,7 @@ import ( ...@@ -60,6 +60,7 @@ import (
"github.com/ai-dynamo/dynamo/deploy/cloud/operator/internal/controller" "github.com/ai-dynamo/dynamo/deploy/cloud/operator/internal/controller"
commonController "github.com/ai-dynamo/dynamo/deploy/cloud/operator/internal/controller_common" commonController "github.com/ai-dynamo/dynamo/deploy/cloud/operator/internal/controller_common"
"github.com/ai-dynamo/dynamo/deploy/cloud/operator/internal/etcd" "github.com/ai-dynamo/dynamo/deploy/cloud/operator/internal/etcd"
"github.com/ai-dynamo/dynamo/deploy/cloud/operator/internal/secret"
"github.com/ai-dynamo/dynamo/deploy/cloud/operator/internal/secrets" "github.com/ai-dynamo/dynamo/deploy/cloud/operator/internal/secrets"
istioclientsetscheme "istio.io/client-go/pkg/clientset/versioned/scheme" istioclientsetscheme "istio.io/client-go/pkg/clientset/versioned/scheme"
//+kubebuilder:scaffold:imports //+kubebuilder:scaffold:imports
...@@ -133,6 +134,8 @@ func main() { ...@@ -133,6 +134,8 @@ func main() {
var groveTerminationDelay time.Duration var groveTerminationDelay time.Duration
var modelExpressURL string var modelExpressURL string
var prometheusEndpoint string var prometheusEndpoint string
var mpiRunSecretName string
var mpiRunSecretNamespace string
flag.StringVar(&metricsAddr, "metrics-bind-address", ":8080", "The address the metric endpoint binds to.") flag.StringVar(&metricsAddr, "metrics-bind-address", ":8080", "The address the metric endpoint binds to.")
flag.StringVar(&probeAddr, "health-probe-bind-address", ":8081", "The address the probe endpoint binds to.") flag.StringVar(&probeAddr, "health-probe-bind-address", ":8081", "The address the probe endpoint binds to.")
flag.BoolVar(&enableLeaderElection, "leader-elect", false, flag.BoolVar(&enableLeaderElection, "leader-elect", false,
...@@ -164,6 +167,10 @@ func main() { ...@@ -164,6 +167,10 @@ func main() {
"URL of the Model Express server to inject into all pods") "URL of the Model Express server to inject into all pods")
flag.StringVar(&prometheusEndpoint, "prometheus-endpoint", "", flag.StringVar(&prometheusEndpoint, "prometheus-endpoint", "",
"URL of the Prometheus endpoint to use for metrics") "URL of the Prometheus endpoint to use for metrics")
flag.StringVar(&mpiRunSecretName, "mpi-run-ssh-secret-name", "",
"Name of the secret containing the SSH key for MPI Run (required)")
flag.StringVar(&mpiRunSecretNamespace, "mpi-run-ssh-secret-namespace", "",
"Namespace where the MPI SSH secret is located (required)")
opts := zap.Options{ opts := zap.Options{
Development: true, Development: true,
} }
...@@ -179,6 +186,16 @@ func main() { ...@@ -179,6 +186,16 @@ func main() {
setupLog.Info("Model Express URL configured", "url", modelExpressURL) setupLog.Info("Model Express URL configured", "url", modelExpressURL)
} }
if mpiRunSecretName == "" {
setupLog.Error(nil, "mpi-run-ssh-secret-name is required")
os.Exit(1)
}
if mpiRunSecretNamespace == "" {
setupLog.Error(nil, "mpi-run-ssh-secret-namespace is required")
os.Exit(1)
}
ctrlConfig := commonController.Config{ ctrlConfig := commonController.Config{
RestrictedNamespace: restrictedNamespace, RestrictedNamespace: restrictedNamespace,
Grove: commonController.GroveConfig{ Grove: commonController.GroveConfig{
...@@ -201,6 +218,9 @@ func main() { ...@@ -201,6 +218,9 @@ func main() {
}, },
ModelExpressURL: modelExpressURL, ModelExpressURL: modelExpressURL,
PrometheusEndpoint: prometheusEndpoint, PrometheusEndpoint: prometheusEndpoint,
MpiRun: commonController.MpiRunConfig{
SecretName: mpiRunSecretName,
},
} }
mainCtx := ctrl.SetupSignalHandler() mainCtx := ctrl.SetupSignalHandler()
...@@ -371,6 +391,14 @@ func main() { ...@@ -371,6 +391,14 @@ func main() {
} }
} }
}() }()
// Create MPI SSH SecretReplicator for cross-namespace secret replication
mpiSecretReplicator := secret.NewSecretReplicator(
mgr.GetClient(),
mpiRunSecretNamespace,
mpiRunSecretName,
)
if err = (&controller.DynamoComponentDeploymentReconciler{ if err = (&controller.DynamoComponentDeploymentReconciler{
Client: mgr.GetClient(), Client: mgr.GetClient(),
Recorder: mgr.GetEventRecorderFor("dynamocomponentdeployment"), Recorder: mgr.GetEventRecorderFor("dynamocomponentdeployment"),
...@@ -394,6 +422,7 @@ func main() { ...@@ -394,6 +422,7 @@ func main() {
Config: ctrlConfig, Config: ctrlConfig,
DockerSecretRetriever: dockerSecretRetriever, DockerSecretRetriever: dockerSecretRetriever,
ScaleClient: scaleClient, ScaleClient: scaleClient,
MPISecretReplicator: mpiSecretReplicator,
}).SetupWithManager(mgr); err != nil { }).SetupWithManager(mgr); err != nil {
setupLog.Error(err, "unable to create controller", "controller", "DynamoGraphDeployment") setupLog.Error(err, "unable to create controller", "controller", "DynamoGraphDeployment")
os.Exit(1) os.Exit(1)
......
...@@ -73,8 +73,6 @@ const ( ...@@ -73,8 +73,6 @@ const (
// Grove multinode role suffixes // Grove multinode role suffixes
GroveRoleSuffixLeader = "ldr" GroveRoleSuffixLeader = "ldr"
GroveRoleSuffixWorker = "wkr" GroveRoleSuffixWorker = "wkr"
MpiRunSshSecretName = "mpi-run-ssh-secret"
) )
type MultinodeDeploymentType string type MultinodeDeploymentType string
......
...@@ -25,6 +25,8 @@ import ( ...@@ -25,6 +25,8 @@ import (
grovev1alpha1 "github.com/NVIDIA/grove/operator/api/core/v1alpha1" grovev1alpha1 "github.com/NVIDIA/grove/operator/api/core/v1alpha1"
"k8s.io/apimachinery/pkg/api/errors" "k8s.io/apimachinery/pkg/api/errors"
"github.com/ai-dynamo/dynamo/deploy/cloud/operator/internal/secret"
networkingv1beta1 "istio.io/client-go/pkg/apis/networking/v1beta1" networkingv1beta1 "istio.io/client-go/pkg/apis/networking/v1beta1"
corev1 "k8s.io/api/core/v1" corev1 "k8s.io/api/core/v1"
networkingv1 "k8s.io/api/networking/v1" networkingv1 "k8s.io/api/networking/v1"
...@@ -66,6 +68,7 @@ type DynamoGraphDeploymentReconciler struct { ...@@ -66,6 +68,7 @@ type DynamoGraphDeploymentReconciler struct {
Recorder record.EventRecorder Recorder record.EventRecorder
DockerSecretRetriever dockerSecretRetriever DockerSecretRetriever dockerSecretRetriever
ScaleClient scale.ScalesGetter ScaleClient scale.ScalesGetter
MPISecretReplicator *secret.SecretReplicator
} }
// +kubebuilder:rbac:groups=nvidia.com,resources=dynamographdeployments,verbs=get;list;watch;create;update;patch;delete // +kubebuilder:rbac:groups=nvidia.com,resources=dynamographdeployments,verbs=get;list;watch;create;update;patch;delete
...@@ -167,6 +170,15 @@ func (r *DynamoGraphDeploymentReconciler) reconcileResources(ctx context.Context ...@@ -167,6 +170,15 @@ func (r *DynamoGraphDeploymentReconciler) reconcileResources(ctx context.Context
// Determine if any service is multinode // Determine if any service is multinode
hasMultinode := dynamoDeployment.HasAnyMultinodeService() hasMultinode := dynamoDeployment.HasAnyMultinodeService()
// Always ensure MPI SSH secret is available in this namespace
if r.MPISecretReplicator != nil {
err := r.MPISecretReplicator.Replicate(ctx, dynamoDeployment.Namespace)
if err != nil {
logger.Error(err, "Failed to replicate MPI secret", "namespace", dynamoDeployment.Namespace)
return "", "", "", fmt.Errorf("failed to replicate MPI secret: %w", err)
}
}
if enableGrove && r.Config.Grove.Enabled { if enableGrove && r.Config.Grove.Enabled {
logger.Info("Reconciling Grove resources", "enableGrove", enableGrove, "groveEnabled", r.Config.Grove.Enabled, "hasMultinode", hasMultinode, "lwsEnabled", r.Config.LWS.Enabled) logger.Info("Reconciling Grove resources", "enableGrove", enableGrove, "groveEnabled", r.Config.Grove.Enabled, "hasMultinode", hasMultinode, "lwsEnabled", r.Config.LWS.Enabled)
return r.reconcileGroveResources(ctx, dynamoDeployment) return r.reconcileGroveResources(ctx, dynamoDeployment)
......
...@@ -47,6 +47,11 @@ type KaiSchedulerConfig struct { ...@@ -47,6 +47,11 @@ type KaiSchedulerConfig struct {
Enabled bool Enabled bool
} }
type MpiRunConfig struct {
// SecretName is the name of the secret containing the SSH key for MPI Run
SecretName string
}
type Config struct { type Config struct {
// Enable resources filtering, only the resources belonging to the given namespace will be handled. // Enable resources filtering, only the resources belonging to the given namespace will be handled.
RestrictedNamespace string RestrictedNamespace string
...@@ -60,6 +65,7 @@ type Config struct { ...@@ -60,6 +65,7 @@ type Config struct {
ModelExpressURL string ModelExpressURL string
// PrometheusEndpoint is the URL of the Prometheus endpoint to use for metrics // PrometheusEndpoint is the URL of the Prometheus endpoint to use for metrics
PrometheusEndpoint string PrometheusEndpoint string
MpiRun MpiRunConfig
} }
type IngressConfig struct { type IngressConfig struct {
......
...@@ -13,7 +13,9 @@ import ( ...@@ -13,7 +13,9 @@ import (
"k8s.io/apimachinery/pkg/util/intstr" "k8s.io/apimachinery/pkg/util/intstr"
) )
type TRTLLMBackend struct{} type TRTLLMBackend struct {
MpiRunSecretName string
}
func (b *TRTLLMBackend) UpdateContainer(container *corev1.Container, numberOfNodes int32, role Role, component *v1alpha1.DynamoComponentDeploymentOverridesSpec, serviceName string, multinodeDeployer MultinodeDeployer) { func (b *TRTLLMBackend) UpdateContainer(container *corev1.Container, numberOfNodes int32, role Role, component *v1alpha1.DynamoComponentDeploymentOverridesSpec, serviceName string, multinodeDeployer MultinodeDeployer) {
// For single node, nothing to do // For single node, nothing to do
...@@ -63,10 +65,10 @@ func (b *TRTLLMBackend) UpdatePodSpec(podSpec *corev1.PodSpec, numberOfNodes int ...@@ -63,10 +65,10 @@ func (b *TRTLLMBackend) UpdatePodSpec(podSpec *corev1.PodSpec, numberOfNodes int
// Add SSH keypair volume for TRTLLM multinode deployments // Add SSH keypair volume for TRTLLM multinode deployments
if numberOfNodes > 1 { if numberOfNodes > 1 {
sshVolume := corev1.Volume{ sshVolume := corev1.Volume{
Name: commonconsts.MpiRunSshSecretName, Name: b.MpiRunSecretName,
VolumeSource: corev1.VolumeSource{ VolumeSource: corev1.VolumeSource{
Secret: &corev1.SecretVolumeSource{ Secret: &corev1.SecretVolumeSource{
SecretName: commonconsts.MpiRunSshSecretName, SecretName: b.MpiRunSecretName,
DefaultMode: func() *int32 { mode := int32(0644); return &mode }(), DefaultMode: func() *int32 { mode := int32(0644); return &mode }(),
}, },
}, },
...@@ -78,7 +80,7 @@ func (b *TRTLLMBackend) UpdatePodSpec(podSpec *corev1.PodSpec, numberOfNodes int ...@@ -78,7 +80,7 @@ func (b *TRTLLMBackend) UpdatePodSpec(podSpec *corev1.PodSpec, numberOfNodes int
// addSSHVolumeMount adds the SSH keypair secret volume mount to the container // addSSHVolumeMount adds the SSH keypair secret volume mount to the container
func (b *TRTLLMBackend) addSSHVolumeMount(container *corev1.Container) { func (b *TRTLLMBackend) addSSHVolumeMount(container *corev1.Container) {
sshVolumeMount := corev1.VolumeMount{ sshVolumeMount := corev1.VolumeMount{
Name: commonconsts.MpiRunSshSecretName, Name: b.MpiRunSecretName,
MountPath: "/ssh-pk", MountPath: "/ssh-pk",
ReadOnly: true, ReadOnly: true,
} }
......
...@@ -11,6 +11,10 @@ import ( ...@@ -11,6 +11,10 @@ import (
"k8s.io/apimachinery/pkg/util/intstr" "k8s.io/apimachinery/pkg/util/intstr"
) )
const (
mpiRunSecretName = "mpi-run-ssh-secret"
)
func TestTRTLLMBackend_UpdateContainer(t *testing.T) { func TestTRTLLMBackend_UpdateContainer(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
...@@ -57,7 +61,7 @@ func TestTRTLLMBackend_UpdateContainer(t *testing.T) { ...@@ -57,7 +61,7 @@ func TestTRTLLMBackend_UpdateContainer(t *testing.T) {
}, },
}, },
expectedVolumeMounts: []corev1.VolumeMount{ expectedVolumeMounts: []corev1.VolumeMount{
{Name: commonconsts.MpiRunSshSecretName, MountPath: "/ssh-pk", ReadOnly: true}, {Name: mpiRunSecretName, MountPath: "/ssh-pk", ReadOnly: true},
}, },
expectedCommand: []string{"/bin/sh", "-c"}, expectedCommand: []string{"/bin/sh", "-c"},
expectedArgs: []string{"mkdir -p ~/.ssh && ls -la /ssh-pk/ && cp /ssh-pk/private.key ~/.ssh/id_rsa && cp /ssh-pk/private.key.pub ~/.ssh/id_rsa.pub && cp /ssh-pk/private.key.pub ~/.ssh/authorized_keys && chmod 600 ~/.ssh/id_rsa ~/.ssh/authorized_keys && chmod 644 ~/.ssh/id_rsa.pub ~/.ssh/authorized_keys && printf 'Host *\\nIdentityFile ~/.ssh/id_rsa\\nStrictHostKeyChecking no\\nPort 2222\\n' > ~/.ssh/config && mpirun --oversubscribe -n 6 -H $(GROVE_PCSG_NAME)-$(GROVE_PCSG_INDEX)-test-service-ldr-0.$(GROVE_HEADLESS_SERVICE),$(GROVE_PCSG_NAME)-$(GROVE_PCSG_INDEX)-test-service-wkr-0.$(GROVE_HEADLESS_SERVICE),$(GROVE_PCSG_NAME)-$(GROVE_PCSG_INDEX)-test-service-wkr-1.$(GROVE_HEADLESS_SERVICE) --mca pml ob1 --mca plm_rsh_args \"-p 2222 -o StrictHostKeyChecking=no -i ~/.ssh/id_rsa\" -x CUDA_VISIBLE_DEVICES -x HF_DATASETS_CACHE -x HF_ENDPOINT -x HF_HOME -x HF_TOKEN -x HOME -x HUGGING_FACE_HUB_TOKEN -x LD_LIBRARY_PATH -x MODEL_PATH -x NCCL_DEBUG -x NCCL_IB_DISABLE -x NCCL_P2P_DISABLE -x OMPI_MCA_orte_keep_fqdn_hostnames -x PATH -x PYTHONPATH -x TENSORRT_LLM_CACHE_DIR -x TOKENIZERS_PARALLELISM -x TRANSFORMERS_CACHE -x USER bash -c 'source /opt/dynamo/venv/bin/activate && trtllm-llmapi-launch python3 --model test'"}, expectedArgs: []string{"mkdir -p ~/.ssh && ls -la /ssh-pk/ && cp /ssh-pk/private.key ~/.ssh/id_rsa && cp /ssh-pk/private.key.pub ~/.ssh/id_rsa.pub && cp /ssh-pk/private.key.pub ~/.ssh/authorized_keys && chmod 600 ~/.ssh/id_rsa ~/.ssh/authorized_keys && chmod 644 ~/.ssh/id_rsa.pub ~/.ssh/authorized_keys && printf 'Host *\\nIdentityFile ~/.ssh/id_rsa\\nStrictHostKeyChecking no\\nPort 2222\\n' > ~/.ssh/config && mpirun --oversubscribe -n 6 -H $(GROVE_PCSG_NAME)-$(GROVE_PCSG_INDEX)-test-service-ldr-0.$(GROVE_HEADLESS_SERVICE),$(GROVE_PCSG_NAME)-$(GROVE_PCSG_INDEX)-test-service-wkr-0.$(GROVE_HEADLESS_SERVICE),$(GROVE_PCSG_NAME)-$(GROVE_PCSG_INDEX)-test-service-wkr-1.$(GROVE_HEADLESS_SERVICE) --mca pml ob1 --mca plm_rsh_args \"-p 2222 -o StrictHostKeyChecking=no -i 
~/.ssh/id_rsa\" -x CUDA_VISIBLE_DEVICES -x HF_DATASETS_CACHE -x HF_ENDPOINT -x HF_HOME -x HF_TOKEN -x HOME -x HUGGING_FACE_HUB_TOKEN -x LD_LIBRARY_PATH -x MODEL_PATH -x NCCL_DEBUG -x NCCL_IB_DISABLE -x NCCL_P2P_DISABLE -x OMPI_MCA_orte_keep_fqdn_hostnames -x PATH -x PYTHONPATH -x TENSORRT_LLM_CACHE_DIR -x TOKENIZERS_PARALLELISM -x TRANSFORMERS_CACHE -x USER bash -c 'source /opt/dynamo/venv/bin/activate && trtllm-llmapi-launch python3 --model test'"},
...@@ -76,7 +80,7 @@ func TestTRTLLMBackend_UpdateContainer(t *testing.T) { ...@@ -76,7 +80,7 @@ func TestTRTLLMBackend_UpdateContainer(t *testing.T) {
multinodeDeployer: &GroveMultinodeDeployer{}, multinodeDeployer: &GroveMultinodeDeployer{},
component: &v1alpha1.DynamoComponentDeploymentOverridesSpec{}, component: &v1alpha1.DynamoComponentDeploymentOverridesSpec{},
expectedVolumeMounts: []corev1.VolumeMount{ expectedVolumeMounts: []corev1.VolumeMount{
{Name: commonconsts.MpiRunSshSecretName, MountPath: "/ssh-pk", ReadOnly: true}, {Name: mpiRunSecretName, MountPath: "/ssh-pk", ReadOnly: true},
}, },
expectedCommand: []string{"/bin/sh", "-c"}, expectedCommand: []string{"/bin/sh", "-c"},
expectedArgs: []string{"mkdir -p ~/.ssh ~/.ssh/host_keys ~/.ssh/run && ls -la /ssh-pk/ && cp /ssh-pk/private.key ~/.ssh/id_rsa && cp /ssh-pk/private.key.pub ~/.ssh/id_rsa.pub && cp /ssh-pk/private.key.pub ~/.ssh/authorized_keys && chmod 600 ~/.ssh/id_rsa ~/.ssh/authorized_keys && chmod 644 ~/.ssh/id_rsa.pub ~/.ssh/authorized_keys && printf 'Host *\\nIdentityFile ~/.ssh/id_rsa\\nStrictHostKeyChecking no\\nPort 2222\\n' > ~/.ssh/config && ssh-keygen -t rsa -f ~/.ssh/host_keys/ssh_host_rsa_key -N '' && ssh-keygen -t ecdsa -f ~/.ssh/host_keys/ssh_host_ecdsa_key -N '' && ssh-keygen -t ed25519 -f ~/.ssh/host_keys/ssh_host_ed25519_key -N '' && printf 'Port 2222\\nHostKey ~/.ssh/host_keys/ssh_host_rsa_key\\nHostKey ~/.ssh/host_keys/ssh_host_ecdsa_key\\nHostKey ~/.ssh/host_keys/ssh_host_ed25519_key\\nPidFile ~/.ssh/run/sshd.pid\\nPermitRootLogin yes\\nPasswordAuthentication no\\nPubkeyAuthentication yes\\nAuthorizedKeysFile ~/.ssh/authorized_keys\\n' > ~/.ssh/sshd_config && mkdir -p /run/sshd && /usr/sbin/sshd -D -f ~/.ssh/sshd_config"}, expectedArgs: []string{"mkdir -p ~/.ssh ~/.ssh/host_keys ~/.ssh/run && ls -la /ssh-pk/ && cp /ssh-pk/private.key ~/.ssh/id_rsa && cp /ssh-pk/private.key.pub ~/.ssh/id_rsa.pub && cp /ssh-pk/private.key.pub ~/.ssh/authorized_keys && chmod 600 ~/.ssh/id_rsa ~/.ssh/authorized_keys && chmod 644 ~/.ssh/id_rsa.pub ~/.ssh/authorized_keys && printf 'Host *\\nIdentityFile ~/.ssh/id_rsa\\nStrictHostKeyChecking no\\nPort 2222\\n' > ~/.ssh/config && ssh-keygen -t rsa -f ~/.ssh/host_keys/ssh_host_rsa_key -N '' && ssh-keygen -t ecdsa -f ~/.ssh/host_keys/ssh_host_ecdsa_key -N '' && ssh-keygen -t ed25519 -f ~/.ssh/host_keys/ssh_host_ed25519_key -N '' && printf 'Port 2222\\nHostKey ~/.ssh/host_keys/ssh_host_rsa_key\\nHostKey ~/.ssh/host_keys/ssh_host_ecdsa_key\\nHostKey ~/.ssh/host_keys/ssh_host_ed25519_key\\nPidFile ~/.ssh/run/sshd.pid\\nPermitRootLogin yes\\nPasswordAuthentication no\\nPubkeyAuthentication yes\\nAuthorizedKeysFile 
~/.ssh/authorized_keys\\n' > ~/.ssh/sshd_config && mkdir -p /run/sshd && /usr/sbin/sshd -D -f ~/.ssh/sshd_config"},
...@@ -113,7 +117,7 @@ func TestTRTLLMBackend_UpdateContainer(t *testing.T) { ...@@ -113,7 +117,7 @@ func TestTRTLLMBackend_UpdateContainer(t *testing.T) {
}, },
}, },
expectedVolumeMounts: []corev1.VolumeMount{ expectedVolumeMounts: []corev1.VolumeMount{
{Name: commonconsts.MpiRunSshSecretName, MountPath: "/ssh-pk", ReadOnly: true}, {Name: mpiRunSecretName, MountPath: "/ssh-pk", ReadOnly: true},
}, },
expectedCommand: []string{"/bin/sh", "-c"}, expectedCommand: []string{"/bin/sh", "-c"},
expectedArgs: []string{"mkdir -p ~/.ssh && ls -la /ssh-pk/ && cp /ssh-pk/private.key ~/.ssh/id_rsa && cp /ssh-pk/private.key.pub ~/.ssh/id_rsa.pub && cp /ssh-pk/private.key.pub ~/.ssh/authorized_keys && chmod 600 ~/.ssh/id_rsa ~/.ssh/authorized_keys && chmod 644 ~/.ssh/id_rsa.pub ~/.ssh/authorized_keys && printf 'Host *\\nIdentityFile ~/.ssh/id_rsa\\nStrictHostKeyChecking no\\nPort 2222\\n' > ~/.ssh/config && mpirun --oversubscribe -n 2 -H $(LWS_LEADER_ADDRESS),$(LWS_WORKER_1_ADDRESS) --mca pml ob1 --mca plm_rsh_args \"-p 2222 -o StrictHostKeyChecking=no -i ~/.ssh/id_rsa\" -x CUDA_VISIBLE_DEVICES -x HF_DATASETS_CACHE -x HF_ENDPOINT -x HF_HOME -x HF_TOKEN -x HOME -x HUGGING_FACE_HUB_TOKEN -x LD_LIBRARY_PATH -x MODEL_PATH -x NCCL_DEBUG -x NCCL_IB_DISABLE -x NCCL_P2P_DISABLE -x OMPI_MCA_orte_keep_fqdn_hostnames -x PATH -x PYTHONPATH -x TENSORRT_LLM_CACHE_DIR -x TOKENIZERS_PARALLELISM -x TRANSFORMERS_CACHE -x USER bash -c 'source /opt/dynamo/venv/bin/activate && trtllm-llmapi-launch python3 --model test'"}, expectedArgs: []string{"mkdir -p ~/.ssh && ls -la /ssh-pk/ && cp /ssh-pk/private.key ~/.ssh/id_rsa && cp /ssh-pk/private.key.pub ~/.ssh/id_rsa.pub && cp /ssh-pk/private.key.pub ~/.ssh/authorized_keys && chmod 600 ~/.ssh/id_rsa ~/.ssh/authorized_keys && chmod 644 ~/.ssh/id_rsa.pub ~/.ssh/authorized_keys && printf 'Host *\\nIdentityFile ~/.ssh/id_rsa\\nStrictHostKeyChecking no\\nPort 2222\\n' > ~/.ssh/config && mpirun --oversubscribe -n 2 -H $(LWS_LEADER_ADDRESS),$(LWS_WORKER_1_ADDRESS) --mca pml ob1 --mca plm_rsh_args \"-p 2222 -o StrictHostKeyChecking=no -i ~/.ssh/id_rsa\" -x CUDA_VISIBLE_DEVICES -x HF_DATASETS_CACHE -x HF_ENDPOINT -x HF_HOME -x HF_TOKEN -x HOME -x HUGGING_FACE_HUB_TOKEN -x LD_LIBRARY_PATH -x MODEL_PATH -x NCCL_DEBUG -x NCCL_IB_DISABLE -x NCCL_P2P_DISABLE -x OMPI_MCA_orte_keep_fqdn_hostnames -x PATH -x PYTHONPATH -x TENSORRT_LLM_CACHE_DIR -x TOKENIZERS_PARALLELISM -x TRANSFORMERS_CACHE -x USER bash -c 'source /opt/dynamo/venv/bin/activate && 
trtllm-llmapi-launch python3 --model test'"},
...@@ -129,7 +133,9 @@ func TestTRTLLMBackend_UpdateContainer(t *testing.T) { ...@@ -129,7 +133,9 @@ func TestTRTLLMBackend_UpdateContainer(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
backend := &TRTLLMBackend{} backend := &TRTLLMBackend{
MpiRunSecretName: mpiRunSecretName,
}
container := &corev1.Container{ container := &corev1.Container{
Args: []string{"python3", "--model", "test"}, Args: []string{"python3", "--model", "test"},
LivenessProbe: &corev1.Probe{}, LivenessProbe: &corev1.Probe{},
...@@ -334,7 +340,9 @@ func TestTRTLLMBackend_UpdatePodSpec(t *testing.T) { ...@@ -334,7 +340,9 @@ func TestTRTLLMBackend_UpdatePodSpec(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
backend := &TRTLLMBackend{} backend := &TRTLLMBackend{
MpiRunSecretName: mpiRunSecretName,
}
podSpec := &corev1.PodSpec{ podSpec := &corev1.PodSpec{
Volumes: tt.initialVolumes, Volumes: tt.initialVolumes,
Containers: []corev1.Container{ Containers: []corev1.Container{
...@@ -357,14 +365,14 @@ func TestTRTLLMBackend_UpdatePodSpec(t *testing.T) { ...@@ -357,14 +365,14 @@ func TestTRTLLMBackend_UpdatePodSpec(t *testing.T) {
// Check for SSH volume // Check for SSH volume
hasSSHVolume := false hasSSHVolume := false
for _, volume := range podSpec.Volumes { for _, volume := range podSpec.Volumes {
if volume.Name == commonconsts.MpiRunSshSecretName { if volume.Name == mpiRunSecretName {
hasSSHVolume = true hasSSHVolume = true
// Verify volume configuration // Verify volume configuration
if volume.VolumeSource.Secret == nil { if volume.VolumeSource.Secret == nil {
t.Errorf("UpdatePodSpec() SSH volume should use Secret volume source") t.Errorf("UpdatePodSpec() SSH volume should use Secret volume source")
} else { } else {
if volume.VolumeSource.Secret.SecretName != commonconsts.MpiRunSshSecretName { if volume.VolumeSource.Secret.SecretName != mpiRunSecretName {
t.Errorf("UpdatePodSpec() SSH volume secret name = %s, want %s", volume.VolumeSource.Secret.SecretName, commonconsts.MpiRunSshSecretName) t.Errorf("UpdatePodSpec() SSH volume secret name = %s, want %s", volume.VolumeSource.Secret.SecretName, mpiRunSecretName)
} }
if volume.VolumeSource.Secret.DefaultMode == nil || *volume.VolumeSource.Secret.DefaultMode != 0644 { if volume.VolumeSource.Secret.DefaultMode == nil || *volume.VolumeSource.Secret.DefaultMode != 0644 {
t.Errorf("UpdatePodSpec() SSH volume should have DefaultMode 0644") t.Errorf("UpdatePodSpec() SSH volume should have DefaultMode 0644")
...@@ -478,7 +486,7 @@ func TestTRTLLMBackend_generateWorkerHostnames(t *testing.T) { ...@@ -478,7 +486,7 @@ func TestTRTLLMBackend_generateWorkerHostnames(t *testing.T) {
func TestTRTLLMBackend_addSSHVolumeMount(t *testing.T) { func TestTRTLLMBackend_addSSHVolumeMount(t *testing.T) {
expectedSSHVolumeMount := corev1.VolumeMount{ expectedSSHVolumeMount := corev1.VolumeMount{
Name: commonconsts.MpiRunSshSecretName, Name: mpiRunSecretName,
MountPath: "/ssh-pk", MountPath: "/ssh-pk",
ReadOnly: true, ReadOnly: true,
} }
...@@ -507,7 +515,9 @@ func TestTRTLLMBackend_addSSHVolumeMount(t *testing.T) { ...@@ -507,7 +515,9 @@ func TestTRTLLMBackend_addSSHVolumeMount(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
backend := &TRTLLMBackend{} backend := &TRTLLMBackend{
MpiRunSecretName: mpiRunSecretName,
}
container := &corev1.Container{ container := &corev1.Container{
VolumeMounts: tt.initialVolumeMounts, VolumeMounts: tt.initialVolumeMounts,
} }
......
...@@ -553,14 +553,6 @@ func GenerateDefaultIngressSpec(dynamoDeployment *v1alpha1.DynamoGraphDeployment ...@@ -553,14 +553,6 @@ func GenerateDefaultIngressSpec(dynamoDeployment *v1alpha1.DynamoGraphDeployment
return res return res
} }
// mergeContainerCommand picks the command a container should run.
// A non-empty userCmd takes precedence; a nil or empty userCmd falls back to
// defaultCmd, which is returned as-is (not copied).
func mergeContainerCommand(defaultCmd, userCmd []string) []string {
	if len(userCmd) == 0 {
		return defaultCmd
	}
	return userCmd
}
// Define Role enum for leader/worker/main // Define Role enum for leader/worker/main
// Use this type everywhere instead of string for role // Use this type everywhere instead of string for role
...@@ -627,14 +619,16 @@ type MultinodeDeployer interface { ...@@ -627,14 +619,16 @@ type MultinodeDeployer interface {
} }
// BackendFactory creates backend instances based on the framework type // BackendFactory creates backend instances based on the framework type
func BackendFactory(backendFramework BackendFramework) Backend { func BackendFactory(backendFramework BackendFramework, controllerConfig controller_common.Config) Backend {
switch backendFramework { switch backendFramework {
case BackendFrameworkSGLang: case BackendFrameworkSGLang:
return &SGLangBackend{} return &SGLangBackend{}
case BackendFrameworkVLLM: case BackendFrameworkVLLM:
return &VLLMBackend{} return &VLLMBackend{}
case BackendFrameworkTRTLLM: case BackendFrameworkTRTLLM:
return &TRTLLMBackend{} return &TRTLLMBackend{
MpiRunSecretName: controllerConfig.MpiRun.SecretName,
}
case BackendFrameworkNoop: case BackendFrameworkNoop:
return &NoopBackend{} return &NoopBackend{}
default: default:
...@@ -811,7 +805,7 @@ func GenerateBasePodSpec( ...@@ -811,7 +805,7 @@ func GenerateBasePodSpec(
if multinodeDeployer == nil { if multinodeDeployer == nil {
return nil, fmt.Errorf("unsupported multinode deployment type: %s", multinodeDeploymentType) return nil, fmt.Errorf("unsupported multinode deployment type: %s", multinodeDeploymentType)
} }
backend := BackendFactory(backendFramework) backend := BackendFactory(backendFramework, controllerConfig)
if backend == nil { if backend == nil {
return nil, fmt.Errorf("unsupported backend framework: %s", backendFramework) return nil, fmt.Errorf("unsupported backend framework: %s", backendFramework)
} }
......
...@@ -3520,55 +3520,6 @@ func TestGeneratePodSpecForComponent_UnsupportedBackend(t *testing.T) { ...@@ -3520,55 +3520,6 @@ func TestGeneratePodSpecForComponent_UnsupportedBackend(t *testing.T) {
} }
} }
// TestMergeContainerCommand checks that mergeContainerCommand prefers a
// non-empty user command and otherwise falls back to the default command.
func TestMergeContainerCommand(t *testing.T) {
	type mergeCase struct {
		name       string
		defaultCmd []string
		userCmd    []string
		expected   []string
	}

	cases := []mergeCase{
		{
			"user command overrides default",
			[]string{"python", "default.py"},
			[]string{"python", "custom.py"},
			[]string{"python", "custom.py"},
		},
		{
			"empty user command returns default",
			[]string{"python", "default.py"},
			[]string{},
			[]string{"python", "default.py"},
		},
		{
			"nil user command returns default",
			[]string{"python", "default.py"},
			nil,
			[]string{"python", "default.py"},
		},
		{
			"both empty returns empty",
			[]string{},
			[]string{},
			[]string{},
		},
		{
			"default empty user provided",
			[]string{},
			[]string{"python", "user.py"},
			[]string{"python", "user.py"},
		},
	}

	for i := range cases {
		tc := cases[i]
		t.Run(tc.name, func(t *testing.T) {
			if got := mergeContainerCommand(tc.defaultCmd, tc.userCmd); !reflect.DeepEqual(got, tc.expected) {
				t.Errorf("mergeContainerCommand() = %v, want %v", got, tc.expected)
			}
		})
	}
}
func TestExpandRolesForService(t *testing.T) { func TestExpandRolesForService(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
......
/*
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package secret
import (
"context"
"fmt"
corev1 "k8s.io/api/core/v1"
k8serrors "k8s.io/apimachinery/pkg/api/errors"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/types"
"sigs.k8s.io/controller-runtime/pkg/client"
)
// SecretReplicator handles replication of secrets across namespaces.
// Each instance is bound to one secret name and one source namespace.
type SecretReplicator struct {
	// Client is the embedded Kubernetes client used for the Get and Create calls.
	client.Client
	// sourceNamespace is the namespace the secret is read from.
	sourceNamespace string
	// secretName is the secret's name; the same name is used in the source
	// and in every target namespace.
	secretName string
}
// NewSecretReplicator builds a SecretReplicator that copies the secret called
// secretName out of sourceNamespace, using the given client for all API calls.
func NewSecretReplicator(client client.Client, sourceNamespace, secretName string) *SecretReplicator {
	replicator := new(SecretReplicator)
	replicator.Client = client
	replicator.sourceNamespace = sourceNamespace
	replicator.secretName = secretName
	return replicator
}
// Replicate makes sure the configured secret is present in targetNamespace,
// copying it from the source namespace when it is missing.
//
// A secret that already exists in the target namespace is left untouched.
// Otherwise the source secret's Type and Data are copied into a new secret of
// the same name. An AlreadyExists error from the create call is treated as
// success, since the secret is then present either way.
//
// It returns a wrapped error when the target lookup fails for any reason other
// than NotFound, when the source secret cannot be read, or when the create
// call fails with anything other than AlreadyExists.
func (r *SecretReplicator) Replicate(ctx context.Context, targetNamespace string) error {
	targetKey := types.NamespacedName{
		Name:      r.secretName,
		Namespace: targetNamespace,
	}
	existing := &corev1.Secret{}
	switch err := r.Get(ctx, targetKey, existing); {
	case err == nil:
		// Nothing to do: the target namespace already has the secret.
		return nil
	case !k8serrors.IsNotFound(err):
		return fmt.Errorf("failed to check target secret: %w", err)
	}

	sourceKey := types.NamespacedName{
		Name:      r.secretName,
		Namespace: r.sourceNamespace,
	}
	source := &corev1.Secret{}
	if err := r.Get(ctx, sourceKey, source); err != nil {
		return fmt.Errorf("error getting source secret: %w", err)
	}

	replica := &corev1.Secret{
		ObjectMeta: metav1.ObjectMeta{
			Name:      r.secretName,
			Namespace: targetNamespace,
		},
		Type: source.Type,
		Data: source.Data,
	}
	if err := r.Create(ctx, replica); err != nil && !k8serrors.IsAlreadyExists(err) {
		return fmt.Errorf("failed to create replica: %w", err)
	}
	return nil
}
/*
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package secret
import (
"context"
"strings"
"testing"
corev1 "k8s.io/api/core/v1"
k8serrors "k8s.io/apimachinery/pkg/api/errors"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/apimachinery/pkg/types"
"sigs.k8s.io/controller-runtime/pkg/client"
"sigs.k8s.io/controller-runtime/pkg/client/fake"
)
// TestSecretReplicator_Replicate runs SecretReplicator.Replicate against a fake
// controller-runtime client seeded with the secrets each case lists.
//
// Cases that set mockCreateError wrap the fake client in errorInjectingClient so
// that Create returns that error. validateResult, when set, inspects the
// underlying fake client after Replicate has returned.
func TestSecretReplicator_Replicate(t *testing.T) {
	sourceSecret := &corev1.Secret{
		ObjectMeta: metav1.ObjectMeta{
			Name:      "test-secret",
			Namespace: "source-ns",
		},
		Type: corev1.SecretTypeOpaque,
		Data: map[string][]byte{
			"private.key":     []byte("private-key-content"),
			"private.key.pub": []byte("public-key-content"),
		},
	}

	existingTargetSecret := &corev1.Secret{
		ObjectMeta: metav1.ObjectMeta{
			Name:      "test-secret",
			Namespace: "target-ns",
		},
		Type: corev1.SecretTypeOpaque,
		Data: map[string][]byte{
			"private.key":     []byte("existing-private-key"),
			"private.key.pub": []byte("existing-public-key"),
		},
	}

	tests := []struct {
		name              string
		sourceNamespace   string
		secretName        string
		targetNamespace   string
		existingSecrets   []client.Object
		mockCreateError   error
		wantError         bool
		wantErrorContains string
		validateResult    func(t *testing.T, c client.Client)
	}{
		{
			name:            "secret already exists in target namespace - does nothing",
			sourceNamespace: "source-ns",
			secretName:      "test-secret",
			targetNamespace: "target-ns",
			existingSecrets: []client.Object{sourceSecret, existingTargetSecret},
			wantError:       false,
			validateResult: func(t *testing.T, c client.Client) {
				// Should not have modified existing secret
				var secret corev1.Secret
				err := c.Get(context.Background(), types.NamespacedName{
					Name:      "test-secret",
					Namespace: "target-ns",
				}, &secret)
				if err != nil {
					// Stop here: the data check below is meaningless on a zero-value secret.
					t.Fatalf("Expected secret to exist in target namespace: %v", err)
				}
				if string(secret.Data["private.key"]) != "existing-private-key" {
					t.Errorf("Expected existing secret to remain unchanged")
				}
			},
		},
		{
			name:              "source secret does not exist - returns error",
			sourceNamespace:   "source-ns",
			secretName:        "missing-secret",
			targetNamespace:   "target-ns",
			existingSecrets:   []client.Object{},
			wantError:         true,
			wantErrorContains: "error getting source secret",
		},
		{
			name:            "successful replication",
			sourceNamespace: "source-ns",
			secretName:      "test-secret",
			targetNamespace: "target-ns",
			existingSecrets: []client.Object{sourceSecret},
			wantError:       false,
			validateResult: func(t *testing.T, c client.Client) {
				var secret corev1.Secret
				err := c.Get(context.Background(), types.NamespacedName{
					Name:      "test-secret",
					Namespace: "target-ns",
				}, &secret)
				if err != nil {
					// Stop here: the field checks below would only add noise.
					t.Fatalf("Expected secret to be created in target namespace: %v", err)
				}
				if secret.Type != corev1.SecretTypeOpaque {
					t.Errorf("Expected secret type %v, got %v", corev1.SecretTypeOpaque, secret.Type)
				}
				if string(secret.Data["private.key"]) != "private-key-content" {
					t.Errorf("Expected private key data to be copied")
				}
				if string(secret.Data["private.key.pub"]) != "public-key-content" {
					t.Errorf("Expected public key data to be copied")
				}
			},
		},
		{
			name:            "race condition - AlreadyExists error is ignored",
			sourceNamespace: "source-ns",
			secretName:      "test-secret",
			targetNamespace: "target-ns",
			existingSecrets: []client.Object{sourceSecret},
			mockCreateError: k8serrors.NewAlreadyExists(schema.GroupResource{Resource: "secrets"}, "test-secret"),
			wantError:       false,
		},
		{
			name:              "create error other than AlreadyExists - returns error",
			sourceNamespace:   "source-ns",
			secretName:        "test-secret",
			targetNamespace:   "target-ns",
			existingSecrets:   []client.Object{sourceSecret},
			mockCreateError:   k8serrors.NewServiceUnavailable("mock error"),
			wantError:         true,
			wantErrorContains: "failed to create replica",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Create fake client with existing secrets
			scheme := runtime.NewScheme()
			if err := corev1.AddToScheme(scheme); err != nil {
				t.Fatalf("failed to register corev1 types in scheme: %v", err)
			}

			clientBuilder := fake.NewClientBuilder().WithScheme(scheme)
			if len(tt.existingSecrets) > 0 {
				clientBuilder = clientBuilder.WithObjects(tt.existingSecrets...)
			}
			fakeClient := clientBuilder.Build()

			// Wrap client to inject errors if needed
			var testClient client.Client = fakeClient
			if tt.mockCreateError != nil {
				testClient = &errorInjectingClient{
					Client:      fakeClient,
					createError: tt.mockCreateError,
				}
			}

			replicator := NewSecretReplicator(testClient, tt.sourceNamespace, tt.secretName)
			err := replicator.Replicate(context.Background(), tt.targetNamespace)

			if tt.wantError {
				if err == nil {
					t.Errorf("Replicate() expected error, got nil")
				} else if tt.wantErrorContains != "" && !strings.Contains(err.Error(), tt.wantErrorContains) {
					t.Errorf("Replicate() error = %v, want error containing %v", err, tt.wantErrorContains)
				}
			} else if err != nil {
				t.Errorf("Replicate() unexpected error = %v", err)
			}

			// Validate against the unwrapped fake client so injected errors
			// cannot interfere with reading back the resulting state.
			if tt.validateResult != nil {
				tt.validateResult(t, fakeClient)
			}
		})
	}
}
// errorInjectingClient wraps a client to inject specific errors for testing.
// Only Create is overridden; all other methods come from the embedded client.
type errorInjectingClient struct {
	// Client is the wrapped client that serves every call that is not intercepted.
	client.Client
	// createError, when non-nil, is returned by Create instead of delegating.
	createError error
}
// Create returns the configured createError when one is set; otherwise it
// forwards the call unchanged to the wrapped client.
func (c *errorInjectingClient) Create(ctx context.Context, obj client.Object, opts ...client.CreateOption) error {
	if injected := c.createError; injected != nil {
		return injected
	}
	return c.Client.Create(ctx, obj, opts...)
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment