Unverified commit 17e22476, authored by Julien Mancuso and committed by GitHub
Browse files

fix: fix trtllm multinode deployment with LWS (#4477)


Signed-off-by: Julien Mancuso <jmancuso@nvidia.com>
parent 781331c6
...@@ -952,7 +952,7 @@ func TestDynamoComponentDeploymentReconciler_generateLeaderWorkerSet(t *testing. ...@@ -952,7 +952,7 @@ func TestDynamoComponentDeploymentReconciler_generateLeaderWorkerSet(t *testing.
Name: commonconsts.MainContainerName, Name: commonconsts.MainContainerName,
Image: "test-image:latest", Image: "test-image:latest",
Command: []string{"/bin/sh", "-c"}, Command: []string{"/bin/sh", "-c"},
Args: []string{"ray start --address=$(LWS_LEADER_ADDRESS):6379 --block"}, Args: []string{"ray start --address=$LWS_LEADER_ADDRESS:6379 --block"},
Env: []corev1.EnvVar{ Env: []corev1.EnvVar{
{Name: commonconsts.DynamoComponentEnvVar, Value: commonconsts.ComponentTypeWorker}, {Name: commonconsts.DynamoComponentEnvVar, Value: commonconsts.ComponentTypeWorker},
{Name: commonconsts.DynamoNamespaceEnvVar, Value: "default"}, {Name: commonconsts.DynamoNamespaceEnvVar, Value: "default"},
......
...@@ -28,6 +28,10 @@ func (m *MockSimpleDeployer) GetNodeRank() (string, bool) { ...@@ -28,6 +28,10 @@ func (m *MockSimpleDeployer) GetNodeRank() (string, bool) {
return "1", false // simple rank, no shell interpretation needed return "1", false // simple rank, no shell interpretation needed
} }
// NeedsDNSWait reports whether a DNS wait is needed before launching
// multinode components. The simple mock always reports false.
func (m *MockSimpleDeployer) NeedsDNSWait() bool {
	return false
}
// Mock MultinodeDeployer for testing with shell interpretation needed // Mock MultinodeDeployer for testing with shell interpretation needed
type MockShellDeployer struct{} type MockShellDeployer struct{}
...@@ -48,6 +52,10 @@ func (m *MockShellDeployer) GetNodeRank() (string, bool) { ...@@ -48,6 +52,10 @@ func (m *MockShellDeployer) GetNodeRank() (string, bool) {
return "$(WORKER_INDEX)", true // needs shell interpretation return "$(WORKER_INDEX)", true // needs shell interpretation
} }
// NeedsDNSWait reports whether a DNS wait is needed before launching
// multinode components. The shell mock always reports true.
func (m *MockShellDeployer) NeedsDNSWait() bool {
	return true
}
func TestSGLangBackend_PythonCommandInjection(t *testing.T) { func TestSGLangBackend_PythonCommandInjection(t *testing.T) {
backend := &SGLangBackend{} backend := &SGLangBackend{}
...@@ -265,7 +273,7 @@ func TestSGLangBackend_ShellCommandInjection(t *testing.T) { ...@@ -265,7 +273,7 @@ func TestSGLangBackend_ShellCommandInjection(t *testing.T) {
multinodeDeployer: &LWSMultinodeDeployer{}, multinodeDeployer: &LWSMultinodeDeployer{},
initialCommand: []string{"sh", "-c"}, initialCommand: []string{"sh", "-c"},
initialArgs: []string{"python -m dynamo.sglang"}, initialArgs: []string{"python -m dynamo.sglang"},
expectedArgs: []string{"python -m dynamo.sglang --dist-init-addr $(LWS_LEADER_ADDRESS):29500 --nnodes 2 --node-rank 0"}, expectedArgs: []string{"python -m dynamo.sglang --dist-init-addr $LWS_LEADER_ADDRESS:29500 --nnodes 2 --node-rank 0"},
description: "LWS shell commands should use LWS variables", description: "LWS shell commands should use LWS variables",
}, },
{ {
......
...@@ -154,8 +154,16 @@ func (b *TRTLLMBackend) setupLeaderContainer(container *corev1.Container, number ...@@ -154,8 +154,16 @@ func (b *TRTLLMBackend) setupLeaderContainer(container *corev1.Container, number
envVarsStr, envVarsStr,
wrappedCommand) wrappedCommand)
// Combine SSH setup and mpirun command // Combine SSH setup and mpirun command, optionally adding DNS wait for deployers that need it
fullCommand := strings.Join(append(sshSetupCommands, mpirunCmd), " && ") var allCommands []string
if multinodeDeployer.NeedsDNSWait() {
// Wait for DNS resolution of all worker nodes (needed for LWS)
dnsWaitCmd := fmt.Sprintf(`TIMEOUT=300; START_TIME=$(date +%%s); for worker in $(echo "%s" | tr ',' ' '); do echo "Waiting for DNS: $worker"; until getent hosts $worker >/dev/null 2>&1; do CURRENT_TIME=$(date +%%s); if [ $((CURRENT_TIME - START_TIME)) -gt $TIMEOUT ]; then echo "ERROR: Timeout waiting for DNS: $worker"; exit 1; fi; echo "DNS not ready for $worker, retrying..."; sleep 2; done; echo "✓ DNS resolved: $worker"; done; echo "All workers DNS ready"`, workerHosts)
allCommands = append(sshSetupCommands, dnsWaitCmd, mpirunCmd)
} else {
allCommands = append(sshSetupCommands, mpirunCmd)
}
fullCommand := strings.Join(allCommands, " && ")
// Update container to use bash with the full command // Update container to use bash with the full command
container.Command = []string{"/bin/sh", "-c"} container.Command = []string{"/bin/sh", "-c"}
......
...@@ -115,7 +115,7 @@ func TestTRTLLMBackend_UpdateContainer(t *testing.T) { ...@@ -115,7 +115,7 @@ func TestTRTLLMBackend_UpdateContainer(t *testing.T) {
{Name: mpiRunSecretName, MountPath: "/ssh-pk", ReadOnly: true}, {Name: mpiRunSecretName, MountPath: "/ssh-pk", ReadOnly: true},
}, },
expectedCommand: []string{"/bin/sh", "-c"}, expectedCommand: []string{"/bin/sh", "-c"},
expectedArgs: []string{"mkdir -p ~/.ssh && ls -la /ssh-pk/ && cp /ssh-pk/private.key ~/.ssh/id_rsa && cp /ssh-pk/private.key.pub ~/.ssh/id_rsa.pub && cp /ssh-pk/private.key.pub ~/.ssh/authorized_keys && chmod 600 ~/.ssh/id_rsa ~/.ssh/authorized_keys && chmod 644 ~/.ssh/id_rsa.pub ~/.ssh/authorized_keys && printf 'Host *\\nIdentityFile ~/.ssh/id_rsa\\nStrictHostKeyChecking no\\nPort 2222\\n' > ~/.ssh/config && mpirun --allow-run-as-root --oversubscribe -n 2 -H $(LWS_LEADER_ADDRESS),$(LWS_WORKER_1_ADDRESS) --mca pml ob1 --mca plm_rsh_args \"-p 2222 -o StrictHostKeyChecking=no -i ~/.ssh/id_rsa\" -x CUDA_VISIBLE_DEVICES -x HF_DATASETS_CACHE -x HF_ENDPOINT -x HF_HOME -x HF_TOKEN -x HOME -x HUGGING_FACE_HUB_TOKEN -x LD_LIBRARY_PATH -x MODEL_PATH -x NCCL_DEBUG -x NCCL_IB_DISABLE -x NCCL_P2P_DISABLE -x OMPI_MCA_orte_keep_fqdn_hostnames -x PATH -x PYTHONPATH -x TENSORRT_LLM_CACHE_DIR -x TOKENIZERS_PARALLELISM -x TRANSFORMERS_CACHE -x USER bash -c 'trtllm-llmapi-launch python3 --model test'"}, expectedArgs: []string{"mkdir -p ~/.ssh && ls -la /ssh-pk/ && cp /ssh-pk/private.key ~/.ssh/id_rsa && cp /ssh-pk/private.key.pub ~/.ssh/id_rsa.pub && cp /ssh-pk/private.key.pub ~/.ssh/authorized_keys && chmod 600 ~/.ssh/id_rsa ~/.ssh/authorized_keys && chmod 644 ~/.ssh/id_rsa.pub ~/.ssh/authorized_keys && printf 'Host *\\nIdentityFile ~/.ssh/id_rsa\\nStrictHostKeyChecking no\\nPort 2222\\n' > ~/.ssh/config && TIMEOUT=300; START_TIME=$(date +%s); for worker in $(echo \"$LWS_LEADER_ADDRESS,$(echo \"$LWS_LEADER_ADDRESS\" | sed 's/\\./-1\\./')\" | tr ',' ' '); do echo \"Waiting for DNS: $worker\"; until getent hosts $worker >/dev/null 2>&1; do CURRENT_TIME=$(date +%s); if [ $((CURRENT_TIME - START_TIME)) -gt $TIMEOUT ]; then echo \"ERROR: Timeout waiting for DNS: $worker\"; exit 1; fi; echo \"DNS not ready for $worker, retrying...\"; sleep 2; done; echo \"✓ DNS resolved: $worker\"; done; echo \"All workers DNS ready\" && mpirun --allow-run-as-root --oversubscribe -n 2 -H 
$LWS_LEADER_ADDRESS,$(echo \"$LWS_LEADER_ADDRESS\" | sed 's/\\./-1\\./') --mca pml ob1 --mca plm_rsh_args \"-p 2222 -o StrictHostKeyChecking=no -i ~/.ssh/id_rsa\" -x CUDA_VISIBLE_DEVICES -x HF_DATASETS_CACHE -x HF_ENDPOINT -x HF_HOME -x HF_TOKEN -x HOME -x HUGGING_FACE_HUB_TOKEN -x LD_LIBRARY_PATH -x MODEL_PATH -x NCCL_DEBUG -x NCCL_IB_DISABLE -x NCCL_P2P_DISABLE -x OMPI_MCA_orte_keep_fqdn_hostnames -x PATH -x PYTHONPATH -x TENSORRT_LLM_CACHE_DIR -x TOKENIZERS_PARALLELISM -x TRANSFORMERS_CACHE -x USER bash -c 'trtllm-llmapi-launch python3 --model test'"},
expectedEnv: []corev1.EnvVar{ expectedEnv: []corev1.EnvVar{
{Name: "OMPI_MCA_orte_keep_fqdn_hostnames", Value: "1"}, {Name: "OMPI_MCA_orte_keep_fqdn_hostnames", Value: "1"},
}, },
...@@ -416,8 +416,8 @@ func TestTRTLLMBackend_generateWorkerHostnames(t *testing.T) { ...@@ -416,8 +416,8 @@ func TestTRTLLMBackend_generateWorkerHostnames(t *testing.T) {
multinodeDeployer: &LWSMultinodeDeployer{}, multinodeDeployer: &LWSMultinodeDeployer{},
serviceName: "test-service", serviceName: "test-service",
expectedContains: []string{ expectedContains: []string{
"$(LWS_LEADER_ADDRESS)", "$LWS_LEADER_ADDRESS",
"$(LWS_WORKER_1_ADDRESS)", "$(echo \"$LWS_LEADER_ADDRESS\" | sed 's/\\./-1\\./')",
}, },
expectedNodeCount: 2, expectedNodeCount: 2,
}, },
...@@ -441,10 +441,10 @@ func TestTRTLLMBackend_generateWorkerHostnames(t *testing.T) { ...@@ -441,10 +441,10 @@ func TestTRTLLMBackend_generateWorkerHostnames(t *testing.T) {
multinodeDeployer: &LWSMultinodeDeployer{}, multinodeDeployer: &LWSMultinodeDeployer{},
serviceName: "worker", serviceName: "worker",
expectedContains: []string{ expectedContains: []string{
"$(LWS_LEADER_ADDRESS)", "$LWS_LEADER_ADDRESS",
"$(LWS_WORKER_1_ADDRESS)", "$(echo \"$LWS_LEADER_ADDRESS\" | sed 's/\\./-1\\./')",
"$(LWS_WORKER_2_ADDRESS)", "$(echo \"$LWS_LEADER_ADDRESS\" | sed 's/\\./-2\\./')",
"$(LWS_WORKER_3_ADDRESS)", "$(echo \"$LWS_LEADER_ADDRESS\" | sed 's/\\./-3\\./')",
}, },
expectedNodeCount: 4, expectedNodeCount: 4,
}, },
...@@ -574,7 +574,7 @@ func TestTRTLLMBackend_setupLeaderContainer(t *testing.T) { ...@@ -574,7 +574,7 @@ func TestTRTLLMBackend_setupLeaderContainer(t *testing.T) {
component: &v1alpha1.DynamoComponentDeploymentSharedSpec{}, component: &v1alpha1.DynamoComponentDeploymentSharedSpec{},
initialArgs: []string{}, initialArgs: []string{},
initialCommand: []string{"python", "-m", "worker"}, initialCommand: []string{"python", "-m", "worker"},
expected: "mkdir -p ~/.ssh && ls -la /ssh-pk/ && cp /ssh-pk/private.key ~/.ssh/id_rsa && cp /ssh-pk/private.key.pub ~/.ssh/id_rsa.pub && cp /ssh-pk/private.key.pub ~/.ssh/authorized_keys && chmod 600 ~/.ssh/id_rsa ~/.ssh/authorized_keys && chmod 644 ~/.ssh/id_rsa.pub ~/.ssh/authorized_keys && printf 'Host *\\nIdentityFile ~/.ssh/id_rsa\\nStrictHostKeyChecking no\\nPort 2222\\n' > ~/.ssh/config && mpirun --allow-run-as-root --oversubscribe -n 0 -H $(LWS_LEADER_ADDRESS),$(LWS_WORKER_1_ADDRESS) --mca pml ob1 --mca plm_rsh_args \"-p 2222 -o StrictHostKeyChecking=no -i ~/.ssh/id_rsa\" -x CUDA_VISIBLE_DEVICES -x HF_DATASETS_CACHE -x HF_ENDPOINT -x HF_HOME -x HF_TOKEN -x HOME -x HUGGING_FACE_HUB_TOKEN -x LD_LIBRARY_PATH -x MODEL_PATH -x NCCL_DEBUG -x NCCL_IB_DISABLE -x NCCL_P2P_DISABLE -x PATH -x PYTHONPATH -x TENSORRT_LLM_CACHE_DIR -x TOKENIZERS_PARALLELISM -x TRANSFORMERS_CACHE -x USER bash -c 'trtllm-llmapi-launch python -m worker'", expected: "mkdir -p ~/.ssh && ls -la /ssh-pk/ && cp /ssh-pk/private.key ~/.ssh/id_rsa && cp /ssh-pk/private.key.pub ~/.ssh/id_rsa.pub && cp /ssh-pk/private.key.pub ~/.ssh/authorized_keys && chmod 600 ~/.ssh/id_rsa ~/.ssh/authorized_keys && chmod 644 ~/.ssh/id_rsa.pub ~/.ssh/authorized_keys && printf 'Host *\\nIdentityFile ~/.ssh/id_rsa\\nStrictHostKeyChecking no\\nPort 2222\\n' > ~/.ssh/config && TIMEOUT=300; START_TIME=$(date +%s); for worker in $(echo \"$LWS_LEADER_ADDRESS,$(echo \"$LWS_LEADER_ADDRESS\" | sed 's/\\./-1\\./')\" | tr ',' ' '); do echo \"Waiting for DNS: $worker\"; until getent hosts $worker >/dev/null 2>&1; do CURRENT_TIME=$(date +%s); if [ $((CURRENT_TIME - START_TIME)) -gt $TIMEOUT ]; then echo \"ERROR: Timeout waiting for DNS: $worker\"; exit 1; fi; echo \"DNS not ready for $worker, retrying...\"; sleep 2; done; echo \"✓ DNS resolved: $worker\"; done; echo \"All workers DNS ready\" && mpirun --allow-run-as-root --oversubscribe -n 0 -H $LWS_LEADER_ADDRESS,$(echo \"$LWS_LEADER_ADDRESS\" | sed 's/\\./-1\\./') --mca pml ob1 
--mca plm_rsh_args \"-p 2222 -o StrictHostKeyChecking=no -i ~/.ssh/id_rsa\" -x CUDA_VISIBLE_DEVICES -x HF_DATASETS_CACHE -x HF_ENDPOINT -x HF_HOME -x HF_TOKEN -x HOME -x HUGGING_FACE_HUB_TOKEN -x LD_LIBRARY_PATH -x MODEL_PATH -x NCCL_DEBUG -x NCCL_IB_DISABLE -x NCCL_P2P_DISABLE -x PATH -x PYTHONPATH -x TENSORRT_LLM_CACHE_DIR -x TOKENIZERS_PARALLELISM -x TRANSFORMERS_CACHE -x USER bash -c 'trtllm-llmapi-launch python -m worker'",
}, },
{ {
name: "Leader with both command and args (shell command - args take precedence)", name: "Leader with both command and args (shell command - args take precedence)",
......
...@@ -66,7 +66,7 @@ func TestVLLMBackend_UpdateContainer(t *testing.T) { ...@@ -66,7 +66,7 @@ func TestVLLMBackend_UpdateContainer(t *testing.T) {
multinodeDeployer: &LWSMultinodeDeployer{}, multinodeDeployer: &LWSMultinodeDeployer{},
initialContainer: &corev1.Container{Args: []string{"python3", "-m", "dynamo.vllm", tensorParallelSizeFlag, "8"}}, initialContainer: &corev1.Container{Args: []string{"python3", "-m", "dynamo.vllm", tensorParallelSizeFlag, "8"}},
gpuCount: 4, gpuCount: 4,
expectedArgs: []string{"ray start --address=$(LWS_LEADER_ADDRESS):6379 --block"}, expectedArgs: []string{"ray start --address=$LWS_LEADER_ADDRESS:6379 --block"},
expectProbesRemoved: true, expectProbesRemoved: true,
}, },
{ {
...@@ -176,7 +176,7 @@ func TestVLLMBackend_ShellCommandInjection(t *testing.T) { ...@@ -176,7 +176,7 @@ func TestVLLMBackend_ShellCommandInjection(t *testing.T) {
multinodeDeployer: &LWSMultinodeDeployer{}, multinodeDeployer: &LWSMultinodeDeployer{},
initialContainer: &corev1.Container{Command: []string{"sh", "-c"}, Args: []string{fmt.Sprintf("python3 -m dynamo.vllm %s 8", dataParallelSizeFlag)}}, initialContainer: &corev1.Container{Command: []string{"sh", "-c"}, Args: []string{fmt.Sprintf("python3 -m dynamo.vllm %s 8", dataParallelSizeFlag)}},
gpuCount: 4, gpuCount: 4,
expectedArgs: []string{"python3 -m dynamo.vllm --data-parallel-address $(LWS_LEADER_ADDRESS) --data-parallel-size-local 4 --data-parallel-rpc-port 13445 --data-parallel-start-rank 0 --data-parallel-size 8"}, expectedArgs: []string{"python3 -m dynamo.vllm --data-parallel-address $LWS_LEADER_ADDRESS --data-parallel-size-local 4 --data-parallel-rpc-port 13445 --data-parallel-start-rank 0 --data-parallel-size 8"},
description: "LWS shell commands should use LWS variables", description: "LWS shell commands should use LWS variables",
}, },
{ {
...@@ -386,7 +386,7 @@ func TestUpdateVLLMMultinodeArgs(t *testing.T) { ...@@ -386,7 +386,7 @@ func TestUpdateVLLMMultinodeArgs(t *testing.T) {
multinodeDeployer: &LWSMultinodeDeployer{}, multinodeDeployer: &LWSMultinodeDeployer{},
initialContainer: &corev1.Container{Args: []string{"python3", "-m", "dynamo.vllm", tensorParallelSizeFlag, "16"}}, initialContainer: &corev1.Container{Args: []string{"python3", "-m", "dynamo.vllm", tensorParallelSizeFlag, "16"}},
gpuCount: 8, gpuCount: 8,
expectedArgs: []string{"ray start --address=$(LWS_LEADER_ADDRESS):6379 --block"}, expectedArgs: []string{"ray start --address=$LWS_LEADER_ADDRESS:6379 --block"},
}, },
{ {
name: "main role does not modify args", name: "main role does not modify args",
......
...@@ -625,6 +625,7 @@ type MultinodeDeployer interface { ...@@ -625,6 +625,7 @@ type MultinodeDeployer interface {
GetLeaderHostname(serviceName string) string GetLeaderHostname(serviceName string) string
GetHostNames(serviceName string, numberOfNodes int32) []string GetHostNames(serviceName string, numberOfNodes int32) []string
GetNodeRank() (string, bool) // returns (rank, needsShellInterpretation) GetNodeRank() (string, bool) // returns (rank, needsShellInterpretation)
NeedsDNSWait() bool // returns true if DNS wait is needed to launch multinode components
} }
// BackendFactory creates backend instances based on the framework type // BackendFactory creates backend instances based on the framework type
......
...@@ -33,6 +33,11 @@ func (d *GroveMultinodeDeployer) GetNodeRank() (string, bool) { ...@@ -33,6 +33,11 @@ func (d *GroveMultinodeDeployer) GetNodeRank() (string, bool) {
return "$((GROVE_PCLQ_POD_INDEX + 1))", true return "$((GROVE_PCLQ_POD_INDEX + 1))", true
} }
// NeedsDNSWait reports whether a DNS wait is needed before launching
// multinode components. It is always false for Grove: Grove doesn't need
// DNS wait - it handles startup coordination differently.
func (d *GroveMultinodeDeployer) NeedsDNSWait() bool {
	return false
}
func (d *GroveMultinodeDeployer) GetHostNames(serviceName string, numberOfNodes int32) []string { func (d *GroveMultinodeDeployer) GetHostNames(serviceName string, numberOfNodes int32) []string {
hostnames := make([]string, 0, numberOfNodes) hostnames := make([]string, 0, numberOfNodes)
leaderHostname := d.GetLeaderHostname(serviceName) leaderHostname := d.GetLeaderHostname(serviceName)
......
...@@ -7,7 +7,7 @@ type LWSMultinodeDeployer struct { ...@@ -7,7 +7,7 @@ type LWSMultinodeDeployer struct {
} }
func (d *LWSMultinodeDeployer) GetLeaderHostname(serviceName string) string { func (d *LWSMultinodeDeployer) GetLeaderHostname(serviceName string) string {
return "$(LWS_LEADER_ADDRESS)" return "$LWS_LEADER_ADDRESS"
} }
func (d *LWSMultinodeDeployer) GetNodeRank() (string, bool) { func (d *LWSMultinodeDeployer) GetNodeRank() (string, bool) {
...@@ -15,11 +15,23 @@ func (d *LWSMultinodeDeployer) GetNodeRank() (string, bool) { ...@@ -15,11 +15,23 @@ func (d *LWSMultinodeDeployer) GetNodeRank() (string, bool) {
return "$(LWS_WORKER_INDEX)", true return "$(LWS_WORKER_INDEX)", true
} }
// NeedsDNSWait reports whether a DNS wait is needed before launching
// multinode components. It is always true for LWS, because pods start
// simultaneously and DNS may not be ready.
func (d *LWSMultinodeDeployer) NeedsDNSWait() bool {
	return true
}
func (d *LWSMultinodeDeployer) GetHostNames(serviceName string, numberOfNodes int32) []string { func (d *LWSMultinodeDeployer) GetHostNames(serviceName string, numberOfNodes int32) []string {
hostnames := make([]string, numberOfNodes) hostnames := make([]string, numberOfNodes)
hostnames[0] = d.GetLeaderHostname(serviceName) hostnames[0] = d.GetLeaderHostname(serviceName)
// LWS only provides LWS_LEADER_ADDRESS, LWS_GROUP_SIZE, and LWS_WORKER_INDEX
// LWS_LEADER_ADDRESS format: <lws-name>-<group-index>-<leader-pod-index>.<service-name>.<namespace>
// Example: trtllm-disagg-tp8-decode-0-0.trtllm-disagg-tp8-decode-0.jsm
// Worker pods append their index: trtllm-disagg-tp8-decode-0-0-1, trtllm-disagg-tp8-decode-0-0-2, etc.
// We derive worker addresses by inserting -{i} before the first dot
for i := int32(1); i < numberOfNodes; i++ { for i := int32(1); i < numberOfNodes; i++ {
hostnames[i] = fmt.Sprintf("$(LWS_WORKER_%d_ADDRESS)", i) // Use sed to replace first "." with "-{i}." to append worker index
hostnames[i] = fmt.Sprintf("$(echo \"$LWS_LEADER_ADDRESS\" | sed 's/\\./-%d\\./')", i)
} }
return hostnames return hostnames
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment