#!/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Disaggregated prefill/decode on a SINGLE GPU.
#
# Starts a dynamo frontend (KV router mode) plus one SGLang prefill worker and
# one SGLang decode worker, all sharing the same GPU.
# Per-worker VRAM is estimated from model parameters below. Override individual
# knobs (CONTEXT_LENGTH, MAX_RUNNING_REQUESTS) via env vars, or set
# DYN_GPU_MEMORY_FRACTION_OVERRIDE to bypass the calculation entirely.
#
# Environment overrides (defaults in parentheses):
#   CONTEXT_LENGTH                    context length per worker (4096)
#   MAX_RUNNING_REQUESTS              max running requests per worker (2)
#   DYN_GPU_MEMORY_FRACTION_OVERRIDE  combined GPU memory fraction for both
#                                     workers; split evenly between them
#   DYN_HTTP_PORT                     frontend HTTP port (8000)
#   DYN_SYSTEM_PORT1                  prefill worker system/metrics port (8081)
#   DYN_SYSTEM_PORT2                  decode worker system/metrics port (8082)
#
# Measured reference (Qwen/Qwen3-0.6B, --context-length 4096, RTX 6000 Ada 48 GiB):
#   estimate (from gpu_utils.sh) : ~5.7 GiB per worker (w=1.1 + kv=0.9 + oh=3.7)
#   actual (nvidia-smi)          : ~5.3 GiB per worker (~10.9 GiB total)
#   fraction per worker (48 GiB) : 0.12
#   KV cache                     : 25,536-29,712 tokens per worker
#   Handles full 4096-token context with --max-running-requests 2.

set -e
# On any exit, signal every process in this script's process group (kill 0),
# which includes the frontend and workers launched in the background.
trap 'echo Cleaning up...; kill 0' EXIT

# Resolve helper paths relative to this script, independent of the caller's cwd.
SCRIPT_DIR="$(dirname "$(readlink -f "$0")")"
source "$SCRIPT_DIR/../../../common/gpu_utils.sh"

MODEL="Qwen/Qwen3-0.6B"

# ---- Tunable (override via env vars) ----
CONTEXT_LENGTH="${CONTEXT_LENGTH:-4096}"
MAX_RUNNING_REQUESTS="${MAX_RUNNING_REQUESTS:-2}"

# ---- Estimate per-worker VRAM (see examples/common/gpu_utils.md) ----
# Sets _EW_WEIGHTS_GIB, _EW_KV_GIB, _EW_OVERHEAD_GIB, _EW_TOTAL_GIB
estimate_worker_vram "$MODEL" "$CONTEXT_LENGTH" "$MAX_RUNNING_REQUESTS" sglang

# DYN_GPU_MEMORY_FRACTION_OVERRIDE takes precedence (profiler binary search).
# In single-GPU mode the override is treated as the combined fraction for both
# workers, so each worker gets half of it (formatted to two decimals).
if [[ -n "${DYN_GPU_MEMORY_FRACTION_OVERRIDE:-}" ]]; then
    GPU_MEM_FRACTION=$(awk -v f="$DYN_GPU_MEMORY_FRACTION_OVERRIDE" 'BEGIN { printf "%.2f", f / 2 }')
else
    GPU_MEM_FRACTION=$(gpu_worker_fraction sglang)
fi

source "$SCRIPT_DIR/../../../common/launch_utils.sh"

HTTP_PORT="${DYN_HTTP_PORT:-8000}"
print_launch_banner "Launching Disaggregated on Same GPU" "$MODEL" "$HTTP_PORT" \
    "Context len: $CONTEXT_LENGTH" \
    "GPU Mem:     ${GPU_MEM_FRACTION} per worker (~${_EW_TOTAL_GIB} GiB each)" \
    "  estimate:  weights=${_EW_WEIGHTS_GIB} + kv=${_EW_KV_GIB} + overhead=${_EW_OVERHEAD_GIB} GiB"

# run ingress with KV router mode for disaggregated setup
# dynamo.frontend accepts either --http-port flag or DYN_HTTP_PORT env var (defaults to 8000).
# Pass the port explicitly so the frontend always listens on the port the banner reports.
python3 -m dynamo.frontend --router-mode kv --http-port "$HTTP_PORT" &

# ---- SGLang workers ----
# Both workers use identical settings except --disaggregation-mode and their
# system/metrics port (DYN_SYSTEM_PORT). Keeping the shared flags in one array
# prevents the prefill and decode command lines from drifting apart.
SGLANG_WORKER_ARGS=(
  --model-path "$MODEL"
  --served-model-name "$MODEL"
  --page-size 16
  --tp 1
  --trust-remote-code
  --disaggregation-bootstrap-port 12345
  --host 0.0.0.0
  --disaggregation-transfer-backend nixl
  --mem-fraction-static "${GPU_MEM_FRACTION}"
  --context-length "$CONTEXT_LENGTH"
  --chunked-prefill-size "$CONTEXT_LENGTH"
  --max-prefill-tokens "$CONTEXT_LENGTH"
  --enable-memory-saver
  --delete-ckpt-after-loading
  --max-running-requests "$MAX_RUNNING_REQUESTS"
  --enable-metrics
)

# run prefill worker with metrics on port 8081 (override via DYN_SYSTEM_PORT1)
DYN_SYSTEM_PORT=${DYN_SYSTEM_PORT1:-8081} \
python3 -m dynamo.sglang "${SGLANG_WORKER_ARGS[@]}" --disaggregation-mode prefill &

# Wait for prefill worker to initialize before starting decode worker
# This prevents both workers from competing for GPU memory simultaneously, which can cause OOM.
# The prefill worker needs time to:
# 1. Load model weights and allocate its memory fraction
# 2. Initialize KV cache with --delete-ckpt-after-loading to free checkpoint memory
# 3. Register with NATS service discovery so decode worker can find it
# NOTE: this is a fixed delay, not a readiness check; if the prefill worker is
# slower to start, the two workers may still overlap during initialization.
echo "Waiting for prefill worker to initialize..."
sleep 5

# run decode worker with metrics on port 8082 (override via DYN_SYSTEM_PORT2)
DYN_SYSTEM_PORT=${DYN_SYSTEM_PORT2:-8082} \
python3 -m dynamo.sglang "${SGLANG_WORKER_ARGS[@]}" --disaggregation-mode decode &

# Exit on first worker failure; kill 0 in the EXIT trap tears down the rest
wait_any_exit