#!/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Disaggregated prefill/decode on a SINGLE GPU.
# Per-worker VRAM is controlled via build_sglang_gpu_mem_args (see gpu_utils.sh).
# Override individual knobs via env vars: CONTEXT_LENGTH, MAX_RUNNING_REQUESTS,
# MAX_TOTAL_TOKENS (fallback when gpu_utils.sh yields no memory args),
# CUDA_VISIBLE_DEVICES, DYN_HTTP_PORT, DYN_SYSTEM_PORT1, DYN_SYSTEM_PORT2 and
# DYN_DISAGG_BOOTSTRAP_PORT.
#
# Measured reference (Qwen/Qwen3-0.6B, --context-length 4096, RTX 6000 Ada 48 GiB):
#   estimate (from gpu_utils.sh)  : ~5.7 GiB per worker (w=1.1 + kv=0.9 + oh=3.7)
#   actual (nvidia-smi)           : ~5.3 GiB per worker (~10.9 GiB total)
#   fraction per worker (48 GiB)  : 0.12
#   KV cache                      : 25,536-29,712 tokens per worker
#   Handles full 4096-token context with --max-running-requests 2.

set -e
# On any exit (normal end, error under set -e, or a worker dying) signal every
# process in this script's process group, i.e. the frontend and both workers
# launched below in the background.
trap 'echo Cleaning up...; kill 0' EXIT

SCRIPT_DIR="$(dirname "$(readlink -f "$0")")"
source "$SCRIPT_DIR/../../../common/gpu_utils.sh"

MODEL="Qwen/Qwen3-0.6B"

# ---- Tunable (override via env vars) ----
CONTEXT_LENGTH="${CONTEXT_LENGTH:-4096}"
MAX_RUNNING_REQUESTS="${MAX_RUNNING_REQUESTS:-2}"
# Fallback token budget (--max-total-tokens), used only when
# build_sglang_gpu_mem_args produces no memory flags.
MAX_TOTAL_TOKENS="${MAX_TOTAL_TOKENS:-25000}"
# The one GPU that both the prefill and the decode worker are pinned to.
CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0}"

# Per-worker memory flags from gpu_utils.sh. The same string is passed to each
# worker, so any fraction it sets applies to each worker separately.
GPU_MEM_ARGS=$(build_sglang_gpu_mem_args)
if [[ -z "$GPU_MEM_ARGS" ]]; then
    GPU_MEM_ARGS="--max-total-tokens $MAX_TOTAL_TOKENS"
fi

source "$SCRIPT_DIR/../../../common/launch_utils.sh"

# Bootstrap port handed to both workers via --disaggregation-bootstrap-port.
DISAGG_BOOTSTRAP_PORT="${DYN_DISAGG_BOOTSTRAP_PORT:-12345}"

# Frontend HTTP port as shown in the banner; the frontend process itself reads
# DYN_HTTP_PORT from the environment (both default to 8000).
HTTP_PORT="${DYN_HTTP_PORT:-8000}"

print_launch_banner "Launching Disaggregated (same GPU)" "$MODEL" "$HTTP_PORT" \
    "Workers:     2 (prefill + decode, fraction is per worker)"

# run ingress
# dynamo.frontend accepts either --http-port flag or DYN_HTTP_PORT env var (defaults to 8000)
python3 -m dynamo.frontend &

# NOTE: Each worker picks a random NCCL port (get_free_port) for torch.distributed.
# This has a TOCTOU race — the port can be grabbed before init_process_group binds it,
# causing sporadic EADDRINUSE.  Pass --nccl-port <unique_port> per worker to avoid this.

# run prefill worker with metrics on port 8081 (override with DYN_SYSTEM_PORT1).
# The same port serves the /health endpoint polled below, so compute it once.
PREFILL_SYSTEM_PORT="${DYN_SYSTEM_PORT1:-8081}"

# $GPU_MEM_ARGS is deliberately unquoted: it may hold several flags that must
# be word-split into separate arguments.
# shellcheck disable=SC2086
CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES \
DYN_SYSTEM_PORT="$PREFILL_SYSTEM_PORT" \
python3 -m dynamo.sglang \
  --model-path "$MODEL" \
  --served-model-name "$MODEL" \
  --page-size 16 \
  --tp 1 \
  --trust-remote-code \
  --disaggregation-mode prefill \
  --disaggregation-bootstrap-port "$DISAGG_BOOTSTRAP_PORT" \
  --host 0.0.0.0 \
  --disaggregation-transfer-backend nixl \
  $GPU_MEM_ARGS \
  --context-length "$CONTEXT_LENGTH" \
  --chunked-prefill-size "$CONTEXT_LENGTH" \
  --max-prefill-tokens "$CONTEXT_LENGTH" \
  --enable-memory-saver \
  --delete-ckpt-after-loading \
  --max-running-requests "$MAX_RUNNING_REQUESTS" \
  --enable-metrics &

# Wait for prefill worker to initialize before starting decode worker.
# Both workers share one GPU with --delete-ckpt-after-loading; without this
# wait they compete for GPU memory during model loading and the scheduler OOMs.
# Best effort: if wait_for_ready gives up (non-zero return), keep going instead
# of letting set -e kill the script, but say so on stderr. A failure inside an
# `if` condition does not trigger set -e.
if ! wait_for_ready "http://localhost:${PREFILL_SYSTEM_PORT}/health" 45; then
  echo "Warning: prefill worker not healthy on port ${PREFILL_SYSTEM_PORT}; starting decode worker anyway" >&2
fi

# run decode worker with metrics on port 8082 (override with DYN_SYSTEM_PORT2).
# Same GPU, model and memory flags as the prefill worker; only the
# disaggregation mode and the system port differ.
# $GPU_MEM_ARGS is deliberately unquoted so its flags word-split.
# shellcheck disable=SC2086
CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES \
DYN_SYSTEM_PORT="${DYN_SYSTEM_PORT2:-8082}" \
python3 -m dynamo.sglang \
  --model-path "$MODEL" \
  --served-model-name "$MODEL" \
  --page-size 16 \
  --tp 1 \
  --trust-remote-code \
  --disaggregation-mode decode \
  --disaggregation-bootstrap-port "$DISAGG_BOOTSTRAP_PORT" \
  --host 0.0.0.0 \
  --disaggregation-transfer-backend nixl \
  $GPU_MEM_ARGS \
  --context-length "$CONTEXT_LENGTH" \
  --chunked-prefill-size "$CONTEXT_LENGTH" \
  --max-prefill-tokens "$CONTEXT_LENGTH" \
  --enable-memory-saver \
  --delete-ckpt-after-loading \
  --max-running-requests "$MAX_RUNNING_REQUESTS" \
  --enable-metrics &

# Exit on first worker failure; kill 0 in the EXIT trap tears down the rest
wait_any_exit