Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
f460f4e2
Unverified
Commit
f460f4e2
authored
Apr 20, 2026
by
Tzu-Ling Kan
Committed by
GitHub
Apr 20, 2026
Browse files
feat: vLLM multinode elastic EP scaling support test (#8183)
Signed-off-by:
Tzu-Ling
<
tzulingk@nvidia.com
>
parent
1d154dcb
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
430 additions
and
0 deletions
+430
-0
tests/fault_tolerance/deploy/templates/vllm/bare_multinode_elastic_ep.yaml
...ance/deploy/templates/vllm/bare_multinode_elastic_ep.yaml
+217
-0
tests/fault_tolerance/deploy/templates/vllm/run_bare_multinode_elastic_ep_scale_test.sh
...emplates/vllm/run_bare_multinode_elastic_ep_scale_test.sh
+213
-0
No files found.
tests/fault_tolerance/deploy/templates/vllm/bare_multinode_elastic_ep.yaml
0 → 100644
View file @
f460f4e2
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Bare vLLM multinode elastic EP test — no Dynamo operator involvement.
#
# Tests whether vLLM itself supports cross-node elastic EP via a single Ray cluster.
# Scale API is vLLM's native endpoint: POST /scale_elastic_ep
#
# Warm-standby topology:
# - Leader pod (Node 1): 2 GPUs active (dp=2). Starts Ray head, then vLLM.
# - Worker pod (Node 2): 2 GPUs idle. Waits for leader vLLM /health before joining Ray.
#
# The worker deliberately delays joining Ray until vLLM is fully serving. This ensures
# vLLM initialises its dp=2 workers only on Node 1. Once the worker joins, its GPUs
# appear idle in ray.available_resources_per_node() and are claimed on scale-up.
#
# Scale sequence (run_bare_multinode_elastic_ep_scale_test.sh):
# Baseline dp=2 → dp=3 → dp=4 → dp=3 → dp=2 → dp=4 → dp=2
#
# If scale-up succeeds and nvidia-smi on worker shows GPU memory used → multinode works.
# If scale-up fails or worker GPUs stay idle → vLLM does not support cross-node elastic EP.
---
# Headless service so leader is reachable by DNS from worker pod.
# `clusterIP: None` is what makes the service headless; the selector matches
# the `app: vllm-ep-leader` label carried by the leader pod in this manifest.
apiVersion: v1
kind: Service
metadata:
  name: vllm-ep-leader
spec:
  clusterIP: None
  selector:
    app: vllm-ep-leader
  ports:
    # Same port the leader passes to `ray start --head --port=6379`.
    - name: ray
      port: 6379
    # Same port the leader passes to `vllm serve --port 8000`.
    - name: serve
      port: 8000
---
# Leader pod: starts the Ray head and then `vllm serve` with dp=2 (see args).
apiVersion: v1
kind: Pod
metadata:
  name: vllm-ep-leader
  labels:
    # Matched by the selector of the vllm-ep-leader headless service.
    app: vllm-ep-leader
spec:
  tolerations:
    - key: nvidia.com/gpu
      operator: Exists
      effect: NoSchedule
  # Pinned to one specific node by hostname. This value is cluster-specific
  # and must be edited before applying the manifest to a different cluster.
  nodeSelector:
    kubernetes.io/hostname: aks-a100a-36888584-vmss000003
  imagePullSecrets:
    - name: nvcr-imagepullsecret
  volumes:
    # Backed by an existing PVC named model-cache; mounted at /model-cache,
    # which is also the value of HF_HOME below.
    - name: model-cache
      persistentVolumeClaim:
        claimName: model-cache
    # Memory-backed emptyDir mounted over /dev/shm.
    - name: shm
      emptyDir:
        medium: Memory
        sizeLimit: 8Gi
  containers:
    - name: main
      image: nvcr.io/nvidian/dynamo-dev/tzulingk-vllm:main
      imagePullPolicy: Always
      resources:
        limits:
          nvidia.com/gpu: "2"
      env:
        - name: HF_HOME
          value: /model-cache
        - name: HF_HUB_ENABLE_HF_TRANSFER
          value: "1"
        - name: VLLM_ALL2ALL_BACKEND
          value: allgather_reducescatter
        - name: VLLM_USE_ELASTIC_EP
          value: "1"
        - name: VLLM_USE_V1
          value: "1"
        - name: VLLM_WORKER_MULTIPROC_METHOD
          value: spawn
        - name: CUDA_VISIBLE_DEVICES
          value: "0,1"
        # ── NCCL cross-node configuration ──────────────────────────────────────
        # NCCL_DEBUG/NCCL_DEBUG_SUBSYS: log transport selection and communicator
        # init so we can confirm TCP socket is chosen. Remove once confirmed.
        - name: NCCL_DEBUG
          value: "INFO"
        - name: NCCL_DEBUG_SUBSYS
          value: "NET,INIT"
        # NCCL_IB_DISABLE=1: disable InfiniBand even though /dev/infiniband/
        # uverbs0-8 are mounted in both pods. Inspecting the GID table on
        # 2026-04-14 via /sys/class/infiniband/*/ports/*/gids/* showed that every
        # mlx5_* device only has a valid (non-zero) GID at index 0, and it is a
        # fe80:: link-local IPv6 address. Link-local addresses are not routable
        # between pods on different Kubernetes nodes — they only work within a
        # single L2 segment. No RoCE v2 or globally routable GIDs exist, so NCCL
        # cannot establish cross-node IB connections here. Without this flag, NCCL
        # tries IB by default, fails with "unhandled system error", and aborts.
        - name: NCCL_IB_DISABLE
          value: "1"
        # NCCL_SOCKET_IFNAME=eth0: with IB disabled NCCL falls back to TCP socket
        # transport. This tells it which interface to bind to. Confirmed via
        # /proc/net/if_inet6 inside both pods: eth0 is the only non-loopback
        # interface and carries the routable pod IPs (leader: 10.244.3.73,
        # worker: 10.244.7.32) that the Kubernetes CNI overlay network can route.
        - name: NCCL_SOCKET_IFNAME
          value: "eth0"
      envFrom:
        - secretRef:
            name: hf-token-secret
      volumeMounts:
        - name: model-cache
          mountPath: /model-cache
        - name: shm
          mountPath: /dev/shm
      command: ["/bin/sh", "-c"]
      # Startup script, in order: install Ray with pip, start the Ray head in
      # the background, poll localhost:6379 with `nc -z` every 2s (giving up
      # after 120s), then run `vllm serve` in the foreground on port 8000.
      # NOTE(review): the script relies on `nc` and `pip` being present in the
      # image — confirm.
      args:
        - |
          set -e
          echo "=== Installing Ray ==="
          pip install -q "ray[default]"
          echo "=== Starting Ray head ==="
          ray start --head --port=6379 --block &
          RAY_HEAD_PID=$!
          # Wait until Ray head is actually accepting connections before launching vLLM.
          # A fixed sleep is a race — slow startups can leave vLLM unable to find Ray.
          RAY_ELAPSED=0
          until nc -z localhost 6379 2>/dev/null; do
            if [ "$RAY_ELAPSED" -ge 120 ]; then
              echo "ERROR: Ray head did not start within 120s (pid=$RAY_HEAD_PID)" >&2
              exit 1
            fi
            echo " Ray head not ready yet, retrying in 2s... (${RAY_ELAPSED}s elapsed)"
            sleep 2
            RAY_ELAPSED=$((RAY_ELAPSED + 2))
          done
          echo "=== Ray head started (pid=$RAY_HEAD_PID), launching vllm serve ==="
          # --data-parallel-size-local is intentionally omitted. The worker pod delays
          # joining Ray until vLLM is ready, so only this node is in the Ray cluster
          # when create_dp_placement_groups runs. Both dp=2 workers land here.
          vllm serve deepseek-ai/DeepSeek-V2-Lite \
            --trust-remote-code \
            --tensor-parallel-size 1 \
            --data-parallel-size 2 \
            --data-parallel-backend ray \
            --gpu-memory-utilization 0.8 \
            --max-model-len 4096 \
            --enable-expert-parallel \
            --enable-elastic-ep \
            --enable-eplb \
            --eplb-config.num_redundant_experts 0 \
            --no-enable-prefix-caching \
            --enforce-eager \
            --port 8000
---
# Worker pod: waits for the leader's vLLM /health endpoint, then joins the
# leader's Ray cluster with `ray start --address=...` (see args).
apiVersion: v1
kind: Pod
metadata:
  name: vllm-ep-worker
spec:
  tolerations:
    - key: nvidia.com/gpu
      operator: Exists
      effect: NoSchedule
  # Pinned by hostname to a different node than the leader pod. This value is
  # cluster-specific and must be edited for a different cluster.
  nodeSelector:
    kubernetes.io/hostname: aks-a100b-22138447-vmss000000
  imagePullSecrets:
    - name: nvcr-imagepullsecret
  volumes:
    # Memory-backed emptyDir mounted over /dev/shm.
    - name: shm
      emptyDir:
        medium: Memory
        sizeLimit: 8Gi
  containers:
    - name: main
      image: nvcr.io/nvidian/dynamo-dev/tzulingk-vllm:main
      imagePullPolicy: Always
      resources:
        limits:
          nvidia.com/gpu: "2"
      env:
        - name: CUDA_VISIBLE_DEVICES
          value: "0,1"
        # ── NCCL cross-node configuration ──────────────────────────────────────
        # Must match the leader pod exactly — NCCL requires all participating
        # processes to use the same transport. See leader pod env block for the
        # full rationale behind each variable.
        - name: NCCL_DEBUG
          value: "INFO"
        - name: NCCL_DEBUG_SUBSYS
          value: "NET,INIT"
        # IB GIDs are fe80:: link-local only — not routable cross-node. Disable IB.
        - name: NCCL_IB_DISABLE
          value: "1"
        # eth0 confirmed as the only non-loopback interface in this pod (10.244.7.32).
        - name: NCCL_SOCKET_IFNAME
          value: "eth0"
      volumeMounts:
        - name: shm
          mountPath: /dev/shm
      command: ["/bin/sh", "-c"]
      # Startup script, in order: install Ray with pip, poll
      # http://vllm-ep-leader:8000/health every 15s until curl succeeds, then
      # run `ray start --address=vllm-ep-leader:6379 --block` in the foreground.
      # The polling loop has no upper bound: if the leader never becomes
      # healthy, this container waits indefinitely.
      args:
        - |
          set -e
          LEADER_URL="http://vllm-ep-leader:8000"
          echo "=== Installing Ray ==="
          pip install -q "ray[default]"
          echo "=== Waiting for leader vLLM to be ready before joining Ray ==="
          # Deliberately wait until vLLM is fully serving on the leader. This ensures
          # vLLM initialises dp=2 with only Node 1 in the Ray cluster, so both initial
          # DP workers land on Node 1. Our GPUs then appear idle for elastic EP scale-up.
          until curl -sf "${LEADER_URL}/health" > /dev/null 2>&1; do
            echo " leader not ready yet, retrying in 15s..."
            sleep 15
          done
          echo "=== Leader vLLM ready — joining Ray cluster ==="
          ray start --address=vllm-ep-leader:6379 --block
tests/fault_tolerance/deploy/templates/vllm/run_bare_multinode_elastic_ep_scale_test.sh
0 → 100755
View file @
f460f4e2
#!/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Bare vLLM Multi-Node Elastic EP Scale Test
#
# Companion script for bare_multinode_elastic_ep.yaml.
# Tests cross-node elastic EP scaling using vLLM's native API — no Dynamo involved.
#
# Warm-standby topology at baseline:
# vllm-ep-leader (Node 1): 2 GPUs active (dp ranks 0, 1)
# vllm-ep-worker (Node 2): 2 GPUs idle in Ray cluster, claimed on scale-up
#
# Scale sequence:
# Baseline dp=2 → dp=3 → dp=4 → dp=3 → dp=2 → dp=4 → dp=2
#
# Key differences from dynamo scripts:
#   - Single port-forward, localhost:8001 → pod port 8000 (vLLM serves both inference AND scale API)
# - Scale API path: POST /scale_elastic_ep (no /engine/ prefix, no port 9090)
# - Pods addressed by name, not label selector
# - No frontend pod, no dynamoworkermetadata patch
# - Inference response parsed as plain OpenAI (no nvext.timing field)
#
# Usage:
# ./run_bare_multinode_elastic_ep_scale_test.sh [NAMESPACE]
#
# Defaults:
# NAMESPACE = tzulingk-multinode-elastic
#
# Prerequisites:
# - kubectl configured and pointing at the right cluster
# - Deployment already applied: kubectl apply -f bare_multinode_elastic_ep.yaml -n <NS>
# - Port 8001 free on localhost
# -e is deliberately not set: the polling loops below expect commands to fail
# while the deployment is still starting. Critical steps are checked explicitly.
set -uo pipefail

NS="${1:-tzulingk-multinode-elastic}"
LEADER_POD="vllm-ep-leader"
WORKER_POD="vllm-ep-worker"
MODEL="deepseek-ai/DeepSeek-V2-Lite"
PF=""  # pid of the port-forward restart loop; empty until the loop is started

# Stops the port-forward restart loop and any kubectl port-forward it spawned.
# Registered as an EXIT trap so every exit path (including `exit 1` from the
# helpers and scale steps) tears the forward down instead of leaking it.
# The loop is killed first so it cannot restart the kubectl child.
cleanup_port_forward() {
  if [ -n "${PF:-}" ]; then
    kill "$PF" 2>/dev/null || true
  fi
  pkill -f "port-forward.*8001:8000" 2>/dev/null || true
}
trap cleanup_port_forward EXIT

echo "Namespace: $NS"
echo "Leader pod: $LEADER_POD"
echo "Worker pod: $WORKER_POD"
echo "Model: $MODEL"
echo ""

# ── Verify pods exist ─────────────────────────────────────────────────────────
# An empty phase means kubectl could not read the pod (stderr is suppressed),
# which is reported as "not found".
for pod in "$LEADER_POD" "$WORKER_POD"; do
  phase=$(kubectl get pod "$pod" -n "$NS" -o jsonpath='{.status.phase}' 2>/dev/null)
  if [ -z "$phase" ]; then
    echo "ERROR: pod $pod not found in namespace $NS" >&2
    exit 1
  fi
  echo "Pod $pod: phase=$phase"
done
echo ""

# ── Wait for leader pod ready ─────────────────────────────────────────────────
echo "=== Waiting for leader pod to be Ready ==="
# Abort if the wait fails or times out; continuing would report a misleading
# "Ready at ..." line and then spend another 600s polling the endpoint.
if ! kubectl wait pod/"$LEADER_POD" -n "$NS" --for=condition=Ready --timeout=900s; then
  echo "ERROR: leader pod $LEADER_POD did not become Ready within 900s" >&2
  exit 1
fi
echo "Ready at $(date -u +%Y-%m-%dT%H:%M:%SZ)"

# ── Port-forward (auto-restarting) ────────────────────────────────────────────
# Port 8000 on the leader serves both inference (/v1/completions) and the elastic
# EP scale API (/scale_elastic_ep). A single port-forward covers both.
# vLLM is not on port 8000 until the model finishes loading, so the port-forward
# will fail and restart repeatedly until vLLM is up — that is expected.
# Kill any stale forward from a previous run that still holds local port 8001.
pkill -f "port-forward.*8001:8000" 2>/dev/null || true
sleep 1
(
  while true; do
    kubectl port-forward pod/"$LEADER_POD" 8001:8000 -n "$NS" 2>&1
    sleep 2
  done
) &
PF=$!
echo "Port-forward: auto-restarting loop pid=$PF localhost:8001 → $LEADER_POD:8000"
sleep 3

# ── Wait for inference endpoint ───────────────────────────────────────────────
# Polls /health up to 120 times, sleeping 5s between attempts (~600s total).
echo "=== Waiting for inference endpoint ==="
ENDPOINT_READY=0
for ((i = 1; i <= 120; i++)); do
  CODE=$(curl -s -o /dev/null -w "%{http_code}" -m 5 http://localhost:8001/health 2>/dev/null)
  if [ "$CODE" = "200" ]; then
    echo "Endpoint ready (checked after ~$((i * 5))s)"
    ENDPOINT_READY=1
    break
  fi
  sleep 5
done
if [ "$ENDPOINT_READY" = "0" ]; then
  echo "ERROR: inference endpoint never became ready after 600s" >&2
  # The EXIT trap stops the port-forward.
  exit 1
fi
# ── Helpers ───────────────────────────────────────────────────────────────────
#######################################
# Blocks until the worker node appears in the Ray cluster (2 active nodes).
# The worker pod runs `ray start --address=...` only after vLLM health=200.
# On first pod start, Python bytecode compilation can silently delay this by
# up to 10 minutes — scaling before the worker is in Ray guarantees failure.
#
# Polls `ray status` inside the leader pod every 15s and counts the `node_`
# lines between the "Active:" and "Pending:" headings.
# Globals:   LEADER_POD, NS (read)
# Arguments: $1 - timeout in seconds (default 900); only the sleep intervals
#                 are counted toward it, not the time spent in kubectl
# Outputs:   progress lines on stdout; the timeout error on stderr
# Returns:   0 once at least 2 active nodes are seen. On timeout it calls
#            `exit 1` (not `return`), because the caller does not check the
#            return status.
#######################################
wait_worker_in_ray() {
  local timeout="${1:-900}"
  local interval=15
  local elapsed=0
  # Kept local so the helper does not leak scratch variables into the script.
  local status nodes

  echo ""
  echo "=== Waiting for worker node to join Ray cluster (need 2 active nodes) ==="
  while [ "$elapsed" -lt "$timeout" ]; do
    # A failed exec leaves status empty, which counts as 0 nodes and retries.
    status=$(kubectl exec pod/"$LEADER_POD" -n "$NS" -- ray status 2>/dev/null)
    nodes=$(printf '%s\n' "$status" \
      | awk '/^Active:/{p=1;next} /^Pending:/{p=0} p && /node_/{c++} END{print c+0}')
    if [ "${nodes:-0}" -ge 2 ]; then
      echo "Worker joined Ray ($nodes active nodes) at $(date -u +%Y-%m-%dT%H:%M:%SZ)"
      # Show the first lines of the Resources section of `ray status`.
      printf '%s\n' "$status" | awk '/Resources/,/Pending Demands/' | head -6
      return 0
    fi
    echo "${nodes:-0}/2 nodes in Ray, retrying in ${interval}s... (${elapsed}s elapsed)"
    sleep "$interval"
    elapsed=$((elapsed + interval))
  done
  echo "ERROR: worker never joined Ray after ${timeout}s" >&2
  exit 1
}
#######################################
# Captures nvidia-smi from both pods so we can see which node's GPUs are active.
# For each of the leader and worker pods it prints the node name, a CSV
# nvidia-smi query (index, memory.used, memory.free, utilization.gpu) and the
# processes in `ps aux` whose line matches DPMoEEngineCoreActor or
# RayWorkerWrapper.
# Globals:   LEADER_POD, WORKER_POD, NS (read)
# Arguments: $1 - free-form label included in the section headings
# Outputs:   everything on stdout (kubectl stderr is merged in via 2>&1)
# Returns:   status of the last pipeline; failures are not treated as fatal
#######################################
snapshot() {
  local label="$1"
  # Local so the loop does not clobber the script-level `pod` variable.
  local pod node

  echo ""
  for pod in "$LEADER_POD" "$WORKER_POD"; do
    node=$(kubectl get pod "$pod" -n "$NS" -o jsonpath='{.spec.nodeName}' 2>/dev/null)
    echo "--- nvidia-smi ($label) pod=$pod node=$node ---"
    kubectl exec "$pod" -n "$NS" -- \
      nvidia-smi --query-gpu=index,memory.used,memory.free,utilization.gpu --format=csv 2>&1
    echo "--- Ray actors ($label) pod=$pod ---"
    # $2 is the PID column and $11 the first word of the command in `ps aux`.
    kubectl exec "$pod" -n "$NS" -- ps aux 2>&1 \
      | awk '/DPMoEEngineCoreActor|RayWorkerWrapper/{printf "PID=%-8s CMD=%s\n", $2, $11}'
  done
}
#######################################
# Sends one small completion request ("2+2=", max_tokens=5, temperature=0)
# through the local port-forward and prints the generated text and usage.
# Globals:   MODEL (read)
# Arguments: $1 - free-form label included in the section heading
# Outputs:   heading and parsed result on stdout; errors on stderr
# Returns:   0 on success; 1 if the HTTP request fails or the response cannot
#            be parsed as a completion with at least one choice
#######################################
infer() {
  local label="$1"
  # Local so this helper does not overwrite a RESP set elsewhere in the script.
  local resp

  echo ""
  echo "--- inference ($label) ---"
  resp=$(curl -fsS -m 30 http://localhost:8001/v1/completions \
    -H "Content-Type: application/json" \
    -d "{\"model\": \"$MODEL\", \"prompt\": \"2+2=\", \"max_tokens\":5, \"temperature\":0}") || {
    echo "ERROR: inference request failed" >&2
    return 1
  }
  # Plain OpenAI response — no nvext field in bare vLLM
  # printf (not echo) so a response starting with '-' or containing
  # backslashes is passed to the parser unchanged.
  printf '%s\n' "$resp" | python3 -c "
import sys, json
d = json.load(sys.stdin)
text = d['choices'][0]['text'].strip()
tokens = d.get('usage', {})
print('text:', repr(text), ' usage:', tokens)
" 2>/dev/null || {
    echo "ERROR: invalid inference response: $resp" >&2
    return 1
  }
}
#######################################
# Calls vLLM's native scale endpoint on port 8000 (same port as inference).
# NOTE: bare vLLM uses /scale_elastic_ep directly, NOT /engine/scale_elastic_ep.
# After a successful request it takes a GPU/process snapshot and runs one
# inference request against the new size.
# Arguments: $1 - current data-parallel size (used for logging only)
#            $2 - target data-parallel size; must be a positive integer
#            $3 - curl --max-time for the scale request, seconds (default 700)
# Outputs:   banner, raw response, snapshot and inference on stdout; errors
#            on stderr
# Returns:   1 if the target is not a positive integer or the request fails;
#            otherwise the return status of the inference check
#######################################
scale() {
  local from_dp="$1"
  local to_dp="$2"
  local timeout="${3:-700}"
  # Local so this helper does not overwrite a RESP set elsewhere in the script.
  local resp

  # to_dp is interpolated unquoted into the JSON body, so anything that is not
  # a plain positive integer would produce a malformed request.
  if ! [[ "$to_dp" =~ ^[1-9][0-9]*$ ]]; then
    echo "ERROR: scale target must be a positive integer, got '$to_dp'" >&2
    return 1
  fi

  echo ""
  echo "=========================================="
  echo "SCALE dp=$from_dp → dp=$to_dp at $(date -u +%Y-%m-%dT%H:%M:%SZ)"
  echo "=========================================="
  echo "--- POST /scale_elastic_ep {\"new_data_parallel_size\": $to_dp} ---"
  resp=$(curl -fsS -X POST http://localhost:8001/scale_elastic_ep \
    -H "Content-Type: application/json" \
    -d "{\"new_data_parallel_size\": $to_dp}" \
    --max-time "$timeout") || {
    echo "ERROR: scale_elastic_ep request failed" >&2
    return 1
  }
  echo "--- response ---"
  printf '%s\n' "$resp"
  snapshot "after dp=$to_dp"
  infer "dp=$to_dp"
}
# ── Baseline ──────────────────────────────────────────────────────────────────
# Expected: leader pod shows 2 GPUs active, worker pod shows 2 GPUs idle (low memory.used)
printf '%s\n' \
  "" \
  "==========================================" \
  "BASELINE dp=2 at $(date -u +%Y-%m-%dT%H:%M:%SZ)" \
  "=========================================="
snapshot "baseline dp=2"
if ! infer "dp=2"; then
  exit 1
fi

# ── Wait for worker in Ray before any scale step ──────────────────────────────
# Do this after baseline so the baseline snapshot/inference runs while the
# worker's `ray start` is still warming up in the background.
wait_worker_in_ray 900

# ── 6 scale steps ─────────────────────────────────────────────────────────────
# Each entry is "<from dp> <to dp> <request timeout in seconds>". The steps run
# in order and the first failing step aborts the script with status 1.
scale_plan=(
  "2 3 700" # step 1: dp=2 → dp=3  (Ray places 1 actor on worker node)
  "3 4 700" # step 2: dp=3 → dp=4  (Ray places 1 more actor on worker node)
  "4 3 300" # step 3: dp=4 → dp=3  (removes highest rank from worker node)
  "3 2 300" # step 4: dp=3 → dp=2  (worker node back to idle)
  "2 4 700" # step 5: dp=2 → dp=4  (both worker node GPUs claimed)
  "4 2 300" # step 6: dp=4 → dp=2  (worker node back to warm standby)
)
for step in "${scale_plan[@]}"; do
  read -r step_from step_to step_timeout <<<"$step"
  scale "$step_from" "$step_to" "$step_timeout" || exit 1
done

printf '%s\n' "" "=== ALL STEPS COMPLETE at $(date -u +%Y-%m-%dT%H:%M:%SZ) ==="
# Stop the port-forward restart loop; a failure here is not an error.
kill "$PF" 2>/dev/null || true
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment