Unverified commit 6b8bb99a, authored by Xavier Chang, committed by GitHub
Browse files

refactor: make hostnames more descriptive and simplify dns check command (#4551)


Signed-off-by: Xavier Chang <xuzhengc@nvidia.com>
parent ff7aea54
......@@ -102,8 +102,9 @@ func (b *TRTLLMBackend) addSSHVolumeMount(container *corev1.Container) {
// setupLeaderContainer configures the leader node with SSH setup and mpirun command
func (b *TRTLLMBackend) setupLeaderContainer(container *corev1.Container, numberOfNodes int32, serviceName string, component *v1alpha1.DynamoComponentDeploymentSharedSpec, multinodeDeployer MultinodeDeployer) {
// Generate the list of worker hostnames
workerHosts := b.generateWorkerHostnames(numberOfNodes, serviceName, multinodeDeployer)
// Generate the list of all hostnames
hostNamesList := b.hostNamesList(numberOfNodes, serviceName, multinodeDeployer)
allHostnames := strings.Join(hostNamesList, ",")
// Store original command/args for later use
var originalCommand string
......@@ -149,7 +150,7 @@ func (b *TRTLLMBackend) setupLeaderContainer(container *corev1.Container, number
mpirunCmd := fmt.Sprintf("mpirun --allow-run-as-root --oversubscribe -n %d -H %s --mca pml ob1 --mca plm_rsh_args \"-p %d -o StrictHostKeyChecking=no -i ~/.ssh/id_rsa\" %s %s",
totalGPUs,
workerHosts,
allHostnames,
commonconsts.MpiRunSshPort,
envVarsStr,
wrappedCommand)
......@@ -158,7 +159,9 @@ func (b *TRTLLMBackend) setupLeaderContainer(container *corev1.Container, number
var allCommands []string
if multinodeDeployer.NeedsDNSWait() {
// Wait for DNS resolution of all worker nodes (needed for LWS)
dnsWaitCmd := fmt.Sprintf(`TIMEOUT=300; START_TIME=$(date +%%s); for worker in $(echo "%s" | tr ',' ' '); do echo "Waiting for DNS: $worker"; until getent hosts $worker >/dev/null 2>&1; do CURRENT_TIME=$(date +%%s); if [ $((CURRENT_TIME - START_TIME)) -gt $TIMEOUT ]; then echo "ERROR: Timeout waiting for DNS: $worker"; exit 1; fi; echo "DNS not ready for $worker, retrying..."; sleep 2; done; echo "✓ DNS resolved: $worker"; done; echo "All workers DNS ready"`, workerHosts)
workerHosts := strings.Join(hostNamesList[1:], " ")
dnsWaitCmd := fmt.Sprintf(`TIMEOUT=300; START_TIME=$(date +%%s); for worker in %s; do echo "Waiting for DNS: $worker"; until getent hosts $worker >/dev/null 2>&1; do CURRENT_TIME=$(date +%%s); if [ $((CURRENT_TIME - START_TIME)) -gt $TIMEOUT ]; then echo "ERROR: Timeout waiting for DNS: $worker"; exit 1; fi; echo "DNS not ready for $worker, retrying..."; sleep 2; done; echo "✓ DNS resolved: $worker"; done; echo "All workers DNS ready"`, workerHosts)
allCommands = append(sshSetupCommands, dnsWaitCmd, mpirunCmd)
} else {
allCommands = append(sshSetupCommands, mpirunCmd)
......@@ -198,9 +201,9 @@ func (b *TRTLLMBackend) setupWorkerContainer(container *corev1.Container) {
container.Args = []string{fullCommand}
}
// generateWorkerHostnames creates a comma-separated list of worker hostnames
func (b *TRTLLMBackend) generateWorkerHostnames(numberOfNodes int32, serviceName string, multinodeDeployer MultinodeDeployer) string {
return strings.Join(multinodeDeployer.GetHostNames(serviceName, numberOfNodes), ",")
// hostNamesList returns the hostnames of the nodes in the multinode deployment,
// exactly as reported by multinodeDeployer.GetHostNames for the given
// serviceName and numberOfNodes. The slice is passed through unmodified.
func (b *TRTLLMBackend) hostNamesList(numberOfNodes int32, serviceName string, multinodeDeployer MultinodeDeployer) []string {
	hostNames := multinodeDeployer.GetHostNames(serviceName, numberOfNodes)
	return hostNames
}
// getGPUsPerNode extracts the number of GPUs per node from resources
......
package dynamo
import (
"strings"
"reflect"
"testing"
"github.com/ai-dynamo/dynamo/deploy/cloud/operator/api/v1alpha1"
......@@ -115,7 +115,7 @@ func TestTRTLLMBackend_UpdateContainer(t *testing.T) {
{Name: mpiRunSecretName, MountPath: "/ssh-pk", ReadOnly: true},
},
expectedCommand: []string{"/bin/sh", "-c"},
expectedArgs: []string{"mkdir -p ~/.ssh && ls -la /ssh-pk/ && cp /ssh-pk/private.key ~/.ssh/id_rsa && cp /ssh-pk/private.key.pub ~/.ssh/id_rsa.pub && cp /ssh-pk/private.key.pub ~/.ssh/authorized_keys && chmod 600 ~/.ssh/id_rsa ~/.ssh/authorized_keys && chmod 644 ~/.ssh/id_rsa.pub ~/.ssh/authorized_keys && printf 'Host *\\nIdentityFile ~/.ssh/id_rsa\\nStrictHostKeyChecking no\\nPort 2222\\n' > ~/.ssh/config && TIMEOUT=300; START_TIME=$(date +%s); for worker in $(echo \"$LWS_LEADER_ADDRESS,$(echo \"$LWS_LEADER_ADDRESS\" | sed 's/\\./-1\\./')\" | tr ',' ' '); do echo \"Waiting for DNS: $worker\"; until getent hosts $worker >/dev/null 2>&1; do CURRENT_TIME=$(date +%s); if [ $((CURRENT_TIME - START_TIME)) -gt $TIMEOUT ]; then echo \"ERROR: Timeout waiting for DNS: $worker\"; exit 1; fi; echo \"DNS not ready for $worker, retrying...\"; sleep 2; done; echo \"✓ DNS resolved: $worker\"; done; echo \"All workers DNS ready\" && mpirun --allow-run-as-root --oversubscribe -n 2 -H $LWS_LEADER_ADDRESS,$(echo \"$LWS_LEADER_ADDRESS\" | sed 's/\\./-1\\./') --mca pml ob1 --mca plm_rsh_args \"-p 2222 -o StrictHostKeyChecking=no -i ~/.ssh/id_rsa\" -x CUDA_VISIBLE_DEVICES -x HF_DATASETS_CACHE -x HF_ENDPOINT -x HF_HOME -x HF_TOKEN -x HOME -x HUGGING_FACE_HUB_TOKEN -x LD_LIBRARY_PATH -x MODEL_PATH -x NCCL_DEBUG -x NCCL_IB_DISABLE -x NCCL_P2P_DISABLE -x OMPI_MCA_orte_keep_fqdn_hostnames -x PATH -x PYTHONPATH -x TENSORRT_LLM_CACHE_DIR -x TOKENIZERS_PARALLELISM -x TRANSFORMERS_CACHE -x TRTLLM_USE_UCX_KVCACHE -x USER bash -c 'trtllm-llmapi-launch python3 --model test'"},
expectedArgs: []string{"mkdir -p ~/.ssh && ls -la /ssh-pk/ && cp /ssh-pk/private.key ~/.ssh/id_rsa && cp /ssh-pk/private.key.pub ~/.ssh/id_rsa.pub && cp /ssh-pk/private.key.pub ~/.ssh/authorized_keys && chmod 600 ~/.ssh/id_rsa ~/.ssh/authorized_keys && chmod 644 ~/.ssh/id_rsa.pub ~/.ssh/authorized_keys && printf 'Host *\\nIdentityFile ~/.ssh/id_rsa\\nStrictHostKeyChecking no\\nPort 2222\\n' > ~/.ssh/config && TIMEOUT=300; START_TIME=$(date +%s); for worker in $(echo \"$LWS_LEADER_ADDRESS\" | sed 's/\\./-1\\./'); do echo \"Waiting for DNS: $worker\"; until getent hosts $worker >/dev/null 2>&1; do CURRENT_TIME=$(date +%s); if [ $((CURRENT_TIME - START_TIME)) -gt $TIMEOUT ]; then echo \"ERROR: Timeout waiting for DNS: $worker\"; exit 1; fi; echo \"DNS not ready for $worker, retrying...\"; sleep 2; done; echo \"✓ DNS resolved: $worker\"; done; echo \"All workers DNS ready\" && mpirun --allow-run-as-root --oversubscribe -n 2 -H $LWS_LEADER_ADDRESS,$(echo \"$LWS_LEADER_ADDRESS\" | sed 's/\\./-1\\./') --mca pml ob1 --mca plm_rsh_args \"-p 2222 -o StrictHostKeyChecking=no -i ~/.ssh/id_rsa\" -x CUDA_VISIBLE_DEVICES -x HF_DATASETS_CACHE -x HF_ENDPOINT -x HF_HOME -x HF_TOKEN -x HOME -x HUGGING_FACE_HUB_TOKEN -x LD_LIBRARY_PATH -x MODEL_PATH -x NCCL_DEBUG -x NCCL_IB_DISABLE -x NCCL_P2P_DISABLE -x OMPI_MCA_orte_keep_fqdn_hostnames -x PATH -x PYTHONPATH -x TENSORRT_LLM_CACHE_DIR -x TOKENIZERS_PARALLELISM -x TRANSFORMERS_CACHE -x TRTLLM_USE_UCX_KVCACHE -x USER bash -c 'trtllm-llmapi-launch python3 --model test'"},
expectedEnv: []corev1.EnvVar{
{Name: "OMPI_MCA_orte_keep_fqdn_hostnames", Value: "1"},
},
......@@ -387,14 +387,13 @@ func TestTRTLLMBackend_UpdatePodSpec(t *testing.T) {
}
}
func TestTRTLLMBackend_generateWorkerHostnames(t *testing.T) {
func TestTRTLLMBackend_hostNamesList(t *testing.T) {
tests := []struct {
name string
numberOfNodes int32
multinodeDeployer MultinodeDeployer
serviceName string
expectedContains []string
expectedNodeCount int32
}{
{
name: "Grove deployment with 3 nodes",
......@@ -402,13 +401,10 @@ func TestTRTLLMBackend_generateWorkerHostnames(t *testing.T) {
multinodeDeployer: &GroveMultinodeDeployer{},
serviceName: "test-service",
expectedContains: []string{
"test-service-ldr-0",
"test-service-wkr-0",
"test-service-wkr-1",
"GROVE_PCSG_NAME",
"GROVE_HEADLESS_SERVICE",
"$(GROVE_PCSG_NAME)-$(GROVE_PCSG_INDEX)-test-service-ldr-0.$(GROVE_HEADLESS_SERVICE)",
"$(GROVE_PCSG_NAME)-$(GROVE_PCSG_INDEX)-test-service-wkr-0.$(GROVE_HEADLESS_SERVICE)",
"$(GROVE_PCSG_NAME)-$(GROVE_PCSG_INDEX)-test-service-wkr-1.$(GROVE_HEADLESS_SERVICE)",
},
expectedNodeCount: 3,
},
{
name: "LWS deployment with 2 nodes",
......@@ -419,7 +415,6 @@ func TestTRTLLMBackend_generateWorkerHostnames(t *testing.T) {
"$LWS_LEADER_ADDRESS",
"$(echo \"$LWS_LEADER_ADDRESS\" | sed 's/\\./-1\\./')",
},
expectedNodeCount: 2,
},
{
name: "Grove deployment with 5 nodes",
......@@ -427,13 +422,12 @@ func TestTRTLLMBackend_generateWorkerHostnames(t *testing.T) {
multinodeDeployer: &GroveMultinodeDeployer{},
serviceName: "worker",
expectedContains: []string{
"worker-ldr-0",
"worker-wkr-0",
"worker-wkr-1",
"worker-wkr-2",
"worker-wkr-3",
"$(GROVE_PCSG_NAME)-$(GROVE_PCSG_INDEX)-worker-ldr-0.$(GROVE_HEADLESS_SERVICE)",
"$(GROVE_PCSG_NAME)-$(GROVE_PCSG_INDEX)-worker-wkr-0.$(GROVE_HEADLESS_SERVICE)",
"$(GROVE_PCSG_NAME)-$(GROVE_PCSG_INDEX)-worker-wkr-1.$(GROVE_HEADLESS_SERVICE)",
"$(GROVE_PCSG_NAME)-$(GROVE_PCSG_INDEX)-worker-wkr-2.$(GROVE_HEADLESS_SERVICE)",
"$(GROVE_PCSG_NAME)-$(GROVE_PCSG_INDEX)-worker-wkr-3.$(GROVE_HEADLESS_SERVICE)",
},
expectedNodeCount: 5,
},
{
name: "LWS deployment with 4 nodes",
......@@ -446,32 +440,16 @@ func TestTRTLLMBackend_generateWorkerHostnames(t *testing.T) {
"$(echo \"$LWS_LEADER_ADDRESS\" | sed 's/\\./-2\\./')",
"$(echo \"$LWS_LEADER_ADDRESS\" | sed 's/\\./-3\\./')",
},
expectedNodeCount: 4,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
backend := &TRTLLMBackend{}
result := backend.generateWorkerHostnames(tt.numberOfNodes, tt.serviceName, tt.multinodeDeployer)
result := backend.hostNamesList(tt.numberOfNodes, tt.serviceName, tt.multinodeDeployer)
for _, expected := range tt.expectedContains {
if !strings.Contains(result, expected) {
t.Errorf("generateWorkerHostnames() = %s, should contain %s", result, expected)
}
}
// Check that result is comma-separated with correct count
parts := strings.Split(result, ",")
if int32(len(parts)) != tt.expectedNodeCount {
t.Errorf("generateWorkerHostnames() should have %d hostnames, got %d: %v", tt.expectedNodeCount, len(parts), parts)
}
// Verify no empty parts
for i, part := range parts {
if strings.TrimSpace(part) == "" {
t.Errorf("generateWorkerHostnames() has empty hostname at index %d", i)
}
if !reflect.DeepEqual(result, tt.expectedContains) {
t.Errorf("hostNamesList() = %s, should be %s", result, tt.expectedContains)
}
})
}
......@@ -574,7 +552,7 @@ func TestTRTLLMBackend_setupLeaderContainer(t *testing.T) {
component: &v1alpha1.DynamoComponentDeploymentSharedSpec{},
initialArgs: []string{},
initialCommand: []string{"python", "-m", "worker"},
expected: "mkdir -p ~/.ssh && ls -la /ssh-pk/ && cp /ssh-pk/private.key ~/.ssh/id_rsa && cp /ssh-pk/private.key.pub ~/.ssh/id_rsa.pub && cp /ssh-pk/private.key.pub ~/.ssh/authorized_keys && chmod 600 ~/.ssh/id_rsa ~/.ssh/authorized_keys && chmod 644 ~/.ssh/id_rsa.pub ~/.ssh/authorized_keys && printf 'Host *\\nIdentityFile ~/.ssh/id_rsa\\nStrictHostKeyChecking no\\nPort 2222\\n' > ~/.ssh/config && TIMEOUT=300; START_TIME=$(date +%s); for worker in $(echo \"$LWS_LEADER_ADDRESS,$(echo \"$LWS_LEADER_ADDRESS\" | sed 's/\\./-1\\./')\" | tr ',' ' '); do echo \"Waiting for DNS: $worker\"; until getent hosts $worker >/dev/null 2>&1; do CURRENT_TIME=$(date +%s); if [ $((CURRENT_TIME - START_TIME)) -gt $TIMEOUT ]; then echo \"ERROR: Timeout waiting for DNS: $worker\"; exit 1; fi; echo \"DNS not ready for $worker, retrying...\"; sleep 2; done; echo \"✓ DNS resolved: $worker\"; done; echo \"All workers DNS ready\" && mpirun --allow-run-as-root --oversubscribe -n 0 -H $LWS_LEADER_ADDRESS,$(echo \"$LWS_LEADER_ADDRESS\" | sed 's/\\./-1\\./') --mca pml ob1 --mca plm_rsh_args \"-p 2222 -o StrictHostKeyChecking=no -i ~/.ssh/id_rsa\" -x CUDA_VISIBLE_DEVICES -x HF_DATASETS_CACHE -x HF_ENDPOINT -x HF_HOME -x HF_TOKEN -x HOME -x HUGGING_FACE_HUB_TOKEN -x LD_LIBRARY_PATH -x MODEL_PATH -x NCCL_DEBUG -x NCCL_IB_DISABLE -x NCCL_P2P_DISABLE -x PATH -x PYTHONPATH -x TENSORRT_LLM_CACHE_DIR -x TOKENIZERS_PARALLELISM -x TRANSFORMERS_CACHE -x TRTLLM_USE_UCX_KVCACHE -x USER bash -c 'trtllm-llmapi-launch python -m worker'",
expected: "mkdir -p ~/.ssh && ls -la /ssh-pk/ && cp /ssh-pk/private.key ~/.ssh/id_rsa && cp /ssh-pk/private.key.pub ~/.ssh/id_rsa.pub && cp /ssh-pk/private.key.pub ~/.ssh/authorized_keys && chmod 600 ~/.ssh/id_rsa ~/.ssh/authorized_keys && chmod 644 ~/.ssh/id_rsa.pub ~/.ssh/authorized_keys && printf 'Host *\\nIdentityFile ~/.ssh/id_rsa\\nStrictHostKeyChecking no\\nPort 2222\\n' > ~/.ssh/config && TIMEOUT=300; START_TIME=$(date +%s); for worker in $(echo \"$LWS_LEADER_ADDRESS\" | sed 's/\\./-1\\./'); do echo \"Waiting for DNS: $worker\"; until getent hosts $worker >/dev/null 2>&1; do CURRENT_TIME=$(date +%s); if [ $((CURRENT_TIME - START_TIME)) -gt $TIMEOUT ]; then echo \"ERROR: Timeout waiting for DNS: $worker\"; exit 1; fi; echo \"DNS not ready for $worker, retrying...\"; sleep 2; done; echo \"✓ DNS resolved: $worker\"; done; echo \"All workers DNS ready\" && mpirun --allow-run-as-root --oversubscribe -n 0 -H $LWS_LEADER_ADDRESS,$(echo \"$LWS_LEADER_ADDRESS\" | sed 's/\\./-1\\./') --mca pml ob1 --mca plm_rsh_args \"-p 2222 -o StrictHostKeyChecking=no -i ~/.ssh/id_rsa\" -x CUDA_VISIBLE_DEVICES -x HF_DATASETS_CACHE -x HF_ENDPOINT -x HF_HOME -x HF_TOKEN -x HOME -x HUGGING_FACE_HUB_TOKEN -x LD_LIBRARY_PATH -x MODEL_PATH -x NCCL_DEBUG -x NCCL_IB_DISABLE -x NCCL_P2P_DISABLE -x PATH -x PYTHONPATH -x TENSORRT_LLM_CACHE_DIR -x TOKENIZERS_PARALLELISM -x TRANSFORMERS_CACHE -x TRTLLM_USE_UCX_KVCACHE -x USER bash -c 'trtllm-llmapi-launch python -m worker'",
},
{
name: "Leader with both command and args (shell command - args take precedence)",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment