Unverified Commit 29bd4c81 authored by Keyang Ru's avatar Keyang Ru Committed by GitHub
Browse files

[CI] Add CI Testing for Prefill-Decode Disaggregation with Router (#7540)

parent 031f64aa
name: Test Disaggregation Mode
on:
push:
branches: [ main ]
paths:
- 'python/sglang/srt/disaggregation/**'
- 'scripts/ci_start_disaggregation_servers.sh'
- 'sgl-router/**'
pull_request:
branches: [ main ]
paths:
- 'python/sglang/srt/disaggregation/**'
- 'scripts/ci_start_disaggregation_servers.sh'
- 'sgl-router/**'
workflow_dispatch:
concurrency:
group: test-disaggregation-${{ github.ref }}
cancel-in-progress: true
permissions:
contents: read
pull-requests: write
issues: write
jobs:
test-disaggregation:
if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
runs-on: [h200]
timeout-minutes: 45
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
fetch-depth: 10
- name: Setup Python
uses: actions/setup-python@v4
with:
python-version: '3.11'
- name: Setup Rust
run: |
bash scripts/ci_install_rust.sh
- name: Cache Rust dependencies
uses: actions/cache@v4
with:
path: |
~/.cargo/bin/
~/.cargo/registry/index/
~/.cargo/registry/cache/
~/.cargo/git/db/
sgl-router/target/
key: ${{ runner.os }}-cargo-${{ hashFiles('sgl-router/Cargo.lock') }}
restore-keys: |
${{ runner.os }}-cargo-
- name: Cache pip dependencies
uses: actions/cache@v4
with:
path: ~/.cache/pip
key: ${{ runner.os }}-pip-${{ hashFiles('python/pyproject.toml') }}
restore-keys: |
${{ runner.os }}-pip-
- name: Validate environment
run: |
echo "=== System Validation ==="
nvidia-smi
echo "GPU count: $(nvidia-smi -L | wc -l)"
if [ $(nvidia-smi -L | wc -l) -lt 8 ]; then
echo "Error: This test requires at least 8 GPUs"
exit 1
fi
echo "=== RDMA Validation ==="
if ! command -v ibv_devices >/dev/null 2>&1; then
echo "Error: InfiniBand tools not found"
exit 1
fi
# Check for active IB devices
found_active_device=false
for device in mlx5_{0..11}; do
if ibv_devinfo $device >/dev/null 2>&1; then
state=$(ibv_devinfo $device | grep "state:" | head -1 | awk '{print $2}')
if [[ "$state" == "PORT_ACTIVE" ]]; then
echo "✓ Found active device: $device"
found_active_device=true
break
fi
fi
done
if [ "$found_active_device" = false ]; then
echo "Error: No active IB devices found"
echo "Available devices:"
ibv_devices || true
exit 1
fi
echo "=== Model Validation ==="
if [ ! -d "/raid/models/meta-llama/Llama-3.1-8B-Instruct" ]; then
echo "Error: Model not found"
ls -la /raid/models/ || echo "No models directory"
exit 1
fi
echo "✓ Model found"
- name: Install SGLang dependencies
run: |
echo "Installing SGLang with all extras..."
python3 -m pip --no-cache-dir install -e "python[all]" --break-system-packages
python3 -m pip --no-cache-dir install mooncake-transfer-engine==0.3.4.post1
- name: Build and install sgl-router
run: |
source "$HOME/.cargo/env"
echo "Building sgl-router..."
cd sgl-router
cargo build && python3 -m build && pip install --force-reinstall dist/*.whl
- name: Start disaggregation servers
id: start_servers
run: |
echo "Starting disaggregation servers..."
bash scripts/ci_start_disaggregation_servers.sh &
SERVER_PID=$!
echo "server_pid=$SERVER_PID" >> $GITHUB_OUTPUT
echo "Waiting for router to become healthy..."
TIMEOUT=300
ELAPSED=0
while [ $ELAPSED -lt $TIMEOUT ]; do
if curl --connect-timeout 5 --silent http://127.0.0.9:8000 > /dev/null 2>&1; then
echo "✓ Router is reachable"
break
fi
if ! ps -p $SERVER_PID > /dev/null; then
echo "Error: Server processes failed to start"
exit 1
fi
echo "Waiting for router... (${ELAPSED}s/${TIMEOUT}s)"
sleep 10
ELAPSED=$((ELAPSED + 10))
done
if [ $ELAPSED -ge $TIMEOUT ]; then
echo "Error: Router health check timeout after ${TIMEOUT}s"
exit 1
fi
echo "✓ Servers started and healthy (PID: $SERVER_PID)"
- name: Test API functionality
timeout-minutes: 5
run: |
BASE_URL="http://127.0.0.9:8000"
echo "Testing API completions..."
response=$(curl -s -X POST "$BASE_URL/v1/chat/completions" \
-H "Content-Type: application/json" \
-H "Authorization: Bearer test-token" \
-d '{
"model": "/raid/models/meta-llama/Llama-3.1-8B-Instruct",
"messages": [
{"role": "user", "content": "Write a Python function to calculate fibonacci numbers recursively"}
],
"stream": false,
"max_tokens": 100
}')
if echo "$response" | jq -e '.choices[0].message.content' > /dev/null 2>&1; then
echo "✓ API test passed"
else
echo "✗ API test failed: $response"
exit 1
fi
echo "Testing streaming API..."
stream_response=$(timeout 30 curl -s -X POST "$BASE_URL/v1/chat/completions" \
-H "Content-Type: application/json" \
-H "Authorization: Bearer test-token" \
-d '{
"model": "/raid/models/meta-llama/Llama-3.1-8B-Instruct",
"messages": [
{"role": "user", "content": "Count from 1 to 5"}
],
"stream": true,
"max_tokens": 50
}')
if echo "$stream_response" | grep -q "data:"; then
echo "✓ Streaming API test passed"
else
echo "✗ Streaming API test failed"
exit 1
fi
- name: Run benchmark test
timeout-minutes: 5
run: |
echo "Running benchmark test..."
benchmark_output=$(python3 -m sglang.bench_one_batch_server \
--model-path "/raid/models/meta-llama/Llama-3.1-8B-Instruct" \
--base-url "http://127.0.0.9:8000" \
--batch-size 8 \
--input-len 4096 \
--output-len 5 \
--skip-warmup)
echo "$benchmark_output"
# Extract metrics from output
latency=$(echo "$benchmark_output" | grep "latency:" | awk '{print $2}' | sed 's/s//')
input_throughput=$(echo "$benchmark_output" | grep "input throughput:" | awk '{print $3}')
output_throughput=$(echo "$benchmark_output" | grep "output throughput:" | awk '{print $3}')
# Validate performance (latency<1.5s, input>20k, output>1k)
command -v bc >/dev/null || (apt-get update && apt-get install -y bc)
echo "Performance: ${latency}s | ${input_throughput} | ${output_throughput} tok/s"
fail=""
(( $(echo "$latency > 1.5" | bc -l) )) && fail="Latency too high (${latency}s>1.5s) "
(( $(echo "$input_throughput < 20000" | bc -l) )) && fail="${fail}Input too low (${input_throughput}<20k) "
(( $(echo "$output_throughput < 1000" | bc -l) )) && fail="${fail}Output too low (${output_throughput}<1k) "
if [ -n "$fail" ]; then
echo "✗ Benchmark failed: $fail"
exit 1
else
echo "✓ Performance validation passed"
fi
- name: Cleanup servers
if: always()
run: |
if [ -n "${{ steps.start_servers.outputs.server_pid }}" ]; then
pkill -P ${{ steps.start_servers.outputs.server_pid }} || true
kill ${{ steps.start_servers.outputs.server_pid }} || true
fi
pkill -f "sglang.launch_server" || true
sleep 5
remaining=$(ps aux | grep -c "sglang.launch_server" || echo "0")
echo "Cleanup completed. Remaining processes: $remaining"
...@@ -18,6 +18,7 @@ dependencies = ["aiohttp", "requests", "tqdm", "numpy", "IPython", "setproctitle ...@@ -18,6 +18,7 @@ dependencies = ["aiohttp", "requests", "tqdm", "numpy", "IPython", "setproctitle
[project.optional-dependencies] [project.optional-dependencies]
runtime_common = [ runtime_common = [
"blobfile==3.0.0", "blobfile==3.0.0",
"build",
"compressed-tensors", "compressed-tensors",
"datasets", "datasets",
"fastapi", "fastapi",
......
#!/bin/bash
MODEL_PATH="/raid/models/meta-llama/Llama-3.1-8B-Instruct"
# Function to find the first available active IB device
find_active_ib_device() {
for device in mlx5_{0..11}; do
if ibv_devinfo $device >/dev/null 2>&1; then
state=$(ibv_devinfo $device | grep "state:" | head -1 | awk '{print $2}')
if [[ "$state" == "PORT_ACTIVE" ]]; then
echo "$device"
return 0
fi
fi
done
echo "No active IB device found" >&2
return 1
}
# Get the first available active IB device
DEVICE=$(find_active_ib_device)
echo "Using IB device: $DEVICE"
# Launch prefill servers on GPU 0–3
for i in {0..3}; do
PORT=$((30001 + i))
BOOTSTRAP_PORT=$((9001 + i))
HOST="127.0.0.$((i + 1))"
echo "Launching PREFILL server on GPU $i at $HOST:$PORT (bootstrap: $BOOTSTRAP_PORT)"
CUDA_VISIBLE_DEVICES=$i \
python3 -m sglang.launch_server \
--model-path "$MODEL_PATH" \
--disaggregation-mode prefill \
--host "$HOST" \
--port "$PORT" \
--disaggregation-ib-device "$DEVICE" \
--disaggregation-bootstrap-port "$BOOTSTRAP_PORT" &
done
# Launch decode servers on GPU 4–7
for i in {4..7}; do
PORT=$((30001 + i))
HOST="127.0.0.$((i + 1))"
echo "Launching DECODE server on GPU $i at $HOST:$PORT"
CUDA_VISIBLE_DEVICES=$i \
python3 -m sglang.launch_server \
--model-path "$MODEL_PATH" \
--disaggregation-mode decode \
--host "$HOST" \
--port "$PORT" \
--disaggregation-ib-device "$DEVICE" \
--base-gpu-id 0 &
done
# Wait for disaggregation servers to initialize
echo "Waiting for disaggregation servers to initialize..."
# Health check with 5-minute timeout
TIMEOUT=300
START_TIME=$(date +%s)
echo "Checking health of all 8 servers..."
while true; do
CURRENT_TIME=$(date +%s)
ELAPSED=$((CURRENT_TIME - START_TIME))
if [ $ELAPSED -ge $TIMEOUT ]; then
echo "❌ Timeout: Servers did not become healthy within 5 minutes"
exit 1
fi
HEALTHY_COUNT=0
# Check all 8 servers (127.0.0.1-8:30001-30008)
for i in {1..8}; do
if curl -s -f "http://127.0.0.$i:$((30000 + i))/health" >/dev/null 2>&1; then
HEALTHY_COUNT=$((HEALTHY_COUNT + 1))
fi
done
echo "Healthy servers: $HEALTHY_COUNT/8 (elapsed: ${ELAPSED}s)"
if [ $HEALTHY_COUNT -eq 8 ]; then
echo "✅ All 8 servers are healthy!"
break
else
sleep 10 # Wait 10 seconds before next check
fi
done
# Launch the router
echo "Launching router at 127.0.0.9:8000..."
python3 -m sglang_router.launch_router \
--pd-disaggregation \
--policy power_of_two \
--prefill http://127.0.0.1:30001 9001 \
--prefill http://127.0.0.2:30002 9002 \
--prefill http://127.0.0.3:30003 9003 \
--prefill http://127.0.0.4:30004 9004 \
--decode http://127.0.0.5:30005 \
--decode http://127.0.0.6:30006 \
--decode http://127.0.0.7:30007 \
--decode http://127.0.0.8:30008 \
--host 127.0.0.9 \
--port 8000 &
wait # Wait for all background jobs to finish
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment