# GPU Memory Control

How vLLM, SGLang, and TensorRT-LLM allocate GPU memory, and how we override
it for deterministic parallel test execution.

---

## Why absolute caps, not fractions

Memory fractions (`--gpu-memory-utilization`, `--mem-fraction-static`) are
unreliable for parallel / CI workloads:

- **Non-deterministic**: the same fraction produces different KV cache sizes
  depending on what else is on the GPU at init time.
- **Profiling race**: concurrent engines each see "nearly all memory free",
  size their KV caches on that basis, and OOM.
- **Not portable**: a fraction tuned for 48 GiB is wrong on 24 or 80 GiB.
- **Different semantics**: vLLM and SGLang use a fraction of *total* VRAM;
  TensorRT-LLM uses a fraction of *free* VRAM after model load.
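
A toy model makes the portability point concrete: with a fixed 16 GiB of weights, the same 0.90 setting leaves very different KV caches on 24, 48, and 80 GiB cards. This is a simplified sketch with a hypothetical helper and illustrative numbers (MiB units, activations and runtime overhead ignored), not how any engine actually measures memory:

```bash
# Toy model (hypothetical helper, MiB units): KV cache left by a fraction-of-total
# setting after model weights. Ignores activations, CUDA graphs, and runtime overhead.
kv_from_total_fraction() {  # args: total_mib fraction weights_mib
  awk -v t="$1" -v f="$2" -v w="$3" 'BEGIN { printf "%.0f\n", t * f - w }'
}

weights_mib=16384                        # ~16 GiB of weights (illustrative)
for total_mib in 24576 49152 81920; do   # 24, 48, and 80 GiB cards
  echo "${total_mib} MiB card at 0.90 -> $(kv_from_total_fraction "$total_mib" 0.90 "$weights_mib") MiB of KV cache"
done
# Prints 5734, 27853, and 57344 MiB respectively.
```

That is roughly a 10x spread from one setting, whereas an absolute cap asks for the same KV size on every card it fits on.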

Instead, we use **absolute KV cache caps**:

| Engine | Deterministic override | Env var |
|--------|----------------------|---------|
| vLLM | `--kv-cache-memory-bytes N` | `_PROFILE_OVERRIDE_VLLM_KV_CACHE_BYTES` |
| SGLang | `--max-total-tokens N` | `_PROFILE_OVERRIDE_SGLANG_MAX_TOTAL_TOKENS` |
| TensorRT-LLM | `--override-engine-args '{"kv_cache_config":{"max_tokens":N}}'` | `_PROFILE_OVERRIDE_TRTLLM_MAX_TOTAL_TOKENS` |

---

## Quick Reference

| Setting | vLLM | SGLang | TensorRT-LLM |
|---|---|---|---|
| Fraction flag | `--gpu-memory-utilization` | `--mem-fraction-static` | `free_gpu_memory_fraction` |
| Fraction base | Total VRAM | Total VRAM | Free VRAM (post-load) |
| Default | 0.90 | 0.90 | 0.90 |
| Max seq len | `--max-model-len` | `--context-length` | `max_seq_len` |
| KV cache override | `--kv-cache-memory-bytes` | `--max-total-tokens` | `KvCacheConfig.max_tokens` via `--override-engine-args` |

---

## Per-Engine Notes

### vLLM

`--gpu-memory-utilization` sets a budget as a fraction of total VRAM.
The KV cache gets what remains: KV cache = budget - weights - activations - overhead.
The pool is fixed at startup.
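
As a worked example of that subtraction, with illustrative numbers rather than measurements: a 48 GiB card at 0.90 gives a 43.2 GiB budget, and after 16 GiB of weights, 2 GiB of peak activations, and 1 GiB of other overhead, about 24.2 GiB is left for the KV cache. The helper is hypothetical and only does the arithmetic; vLLM measures the activation and overhead terms itself during profiling.

```bash
# Hypothetical helper: vLLM-style KV budget in MiB.
# KV = total * gpu_memory_utilization - weights - peak activations - other overhead
vllm_kv_budget_mib() {  # args: total_mib utilization weights_mib activations_mib overhead_mib
  awk -v t="$1" -v u="$2" -v w="$3" -v a="$4" -v o="$5" \
    'BEGIN { printf "%.0f\n", t * u - w - a - o }'
}

# 48 GiB card, 0.90 utilization, 16 GiB weights, 2 GiB activations, 1 GiB overhead
vllm_kv_budget_mib 49152 0.90 16384 2048 1024   # -> 24781 (MiB, about 24.2 GiB)
```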

`--kv-cache-memory-bytes` overrides automatic sizing and **skips memory
profiling** ([PR #21489]). The KV cache is pinned to the exact byte value —
no profiling race, no CUDAGraph estimation errors, safe for concurrent
instances ([#10643]). When set, `--gpu-memory-utilization` only affects
headroom for activations, not KV cache size.
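
Because vLLM takes a byte cap while SGLang and TensorRT-LLM take token caps, it helps to convert between them. For standard multi-head or grouped-query attention on a single GPU, each token stores one key and one value vector per layer per KV head, so bytes per token = 2 × layers × KV heads × head dim × bytes per element. A minimal sketch under that assumption (hypothetical helper names; MLA models, quantized KV caches, tensor parallelism, and block-size rounding all change the numbers):

```bash
# Hypothetical helpers: per-token KV footprint for standard MHA/GQA attention, one GPU.
kv_bytes_per_token() {  # args: num_layers num_kv_heads head_dim bytes_per_elem
  echo $(( 2 * $1 * $2 * $3 * $4 ))   # factor 2 = one K and one V vector
}

tokens_to_kv_bytes() {  # args: tokens bytes_per_token
  echo $(( $1 * $2 ))
}

# Llama-3.1-8B-style shape: 32 layers, 8 KV heads, head_dim 128, bf16 (2 bytes)
per_tok=$(kv_bytes_per_token 32 8 128 2)      # 131072 bytes (128 KiB) per token
cap=$(tokens_to_kv_bytes 32768 "$per_tok")    # 32k tokens -> 4294967296 bytes (4 GiB)
echo "--kv-cache-memory-bytes $cap"
```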

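`--max-model-len` caps sequence length (prompt plus output). vLLM refuses to
start unless the KV cache can hold at least one sequence of this length, so
reducing it is the fastest fix when the model weights fit but the KV cache doesn't.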
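
A quick pre-flight check along those lines: a byte cap is only usable if it holds one full `--max-model-len` sequence. This is a hypothetical helper and a lower bound only, since vLLM allocates the cache in fixed-size blocks; the 131072 bytes per token below is the illustrative figure for a Llama-3.1-8B-style model in bf16.

```bash
# Hypothetical pre-flight check: does a KV byte cap hold one full-length sequence?
# Returns 0 if it does; otherwise prints the shortfall to stderr and returns 1.
fits_max_model_len() {  # args: kv_cap_bytes bytes_per_token max_model_len
  local cap=$1 per_tok=$2 max_len=$3
  local needed=$(( per_tok * max_len ))
  if (( cap < needed )); then
    echo "KV cap ${cap} B < ${needed} B needed for one ${max_len}-token sequence" >&2
    return 1
  fi
}

fits_max_model_len 4294967296 131072 32768 && echo "4 GiB holds a 32k-token sequence"
fits_max_model_len 1073741824 131072 32768 || echo "1 GiB does not: lower --max-model-len or raise the cap"
```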

[PR #21489]: https://github.com/vllm-project/vllm/pull/21489
[#10643]: https://github.com/vllm-project/vllm/issues/10643

### SGLang

`--mem-fraction-static` sets a budget as a fraction of total VRAM.
KV cache pool = budget - weights. Activations and CUDA graph buffers are
*outside* this budget (unlike vLLM).

`--max-total-tokens` caps the KV token pool directly: the pool is the smaller of
this cap and the size the fraction allows. When set below that size, the token
cap is the binding constraint.
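
A rough model of how the two settings interact, in tokens. This is a hypothetical sketch with illustrative numbers: it takes the pool as (fraction × total - weights) / bytes per token, assumes the cap can only lower that pool, never raise it, and ignores SGLang's own reserved memory and per-layer details.

```bash
# Hypothetical model of SGLang's KV token pool (MiB in, tokens out).
sglang_kv_tokens() {  # args: total_mib mem_fraction_static weights_mib bytes_per_token [max_total_tokens]
  awk -v t="$1" -v f="$2" -v w="$3" -v b="$4" -v cap="${5:-}" 'BEGIN {
    pool = int((t * f - w) * 1048576 / b)              # tokens the fraction-derived budget holds
    if (cap != "" && cap + 0 < pool) pool = cap + 0     # the token cap only ever lowers the pool
    printf "%.0f\n", pool
  }'
}

sglang_kv_tokens 49152 0.90 16384 131072          # fraction alone:           222822
sglang_kv_tokens 49152 0.90 16384 131072 100000   # cap below the fraction:   100000
sglang_kv_tokens 49152 0.90 16384 131072 500000   # cap above the fraction:   222822
```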

`--context-length` and `--max-running-requests` affect request scheduling
only — they do **not** change KV cache allocation.

### TensorRT-LLM

`free_gpu_memory_fraction` is a fraction of **free** VRAM after model load.
Set via YAML or `--override-engine-args '{"kv_cache_config":{"free_gpu_memory_fraction": 0.24}}'`.
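
Because the base is free memory measured after the model loads, anything else resident on the card shrinks the KV cache. A toy model with a hypothetical helper and illustrative numbers (MiB units, runtime overhead ignored):

```bash
# Toy model: TensorRT-LLM-style KV budget = fraction * (VRAM still free after model load).
kv_from_free_fraction() {  # args: total_mib cotenant_mib weights_mib fraction
  awk -v t="$1" -v c="$2" -v w="$3" -v f="$4" 'BEGIN { printf "%.0f\n", (t - c - w) * f }'
}

# 48 GiB card, 16 GiB of weights, free_gpu_memory_fraction = 0.90
kv_from_free_fraction 49152 0     16384 0.90   # alone on the GPU:    29491 MiB
kv_from_free_fraction 49152 20480 16384 0.90   # next to a 20 GiB co-tenant: 11059 MiB
```

The same 0.90 therefore means "0.90 of what is left", which is why it is not comparable to the vLLM and SGLang values.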

Deterministic KV cache control uses `build_trtllm_override_args_with_mem` in
`gpu_utils.sh`, which builds JSON for `--override-engine-args`. Token-based
(`_PROFILE_OVERRIDE_TRTLLM_MAX_TOTAL_TOKENS`) or byte-based
(`_PROFILE_OVERRIDE_TRTLLM_MAX_GPU_TOTAL_BYTES`) caps are supported. If the
launch script already passes `--override-engine-args`, the function merges
the GPU config into the existing JSON via `--merge-with-json`.
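
A rough sketch of the JSON-building half of that function. It is not the code in `gpu_utils.sh`: the name is hypothetical, it uses only `printf`, and it does not implement the merge with JSON a launch script already passes. If both variables are set it simply emits both keys; how the real function and TensorRT-LLM resolve that case is not covered here.

```bash
# Hypothetical sketch: build the kv_cache_config override from the profiling env vars.
# Prints nothing when neither variable is set, so the engine keeps its defaults.
sketch_trtllm_kv_override() {
  local fields=()
  [[ -n "${_PROFILE_OVERRIDE_TRTLLM_MAX_TOTAL_TOKENS:-}" ]] &&
    fields+=("\"max_tokens\": ${_PROFILE_OVERRIDE_TRTLLM_MAX_TOTAL_TOKENS}")
  [[ -n "${_PROFILE_OVERRIDE_TRTLLM_MAX_GPU_TOTAL_BYTES:-}" ]] &&
    fields+=("\"max_gpu_total_bytes\": ${_PROFILE_OVERRIDE_TRTLLM_MAX_GPU_TOTAL_BYTES}")
  (( ${#fields[@]} )) || return 0       # no override requested
  local IFS=','                          # join fields with commas
  printf '{"kv_cache_config": {%s}}\n' "${fields[*]}"
}

_PROFILE_OVERRIDE_TRTLLM_MAX_TOTAL_TOKENS=8192 sketch_trtllm_kv_override
# {"kv_cache_config": {"max_tokens": 8192}}
```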

---

## Engine-Specific GPU Memory Functions

Launch scripts source `gpu_utils.sh` and call engine-specific functions to pick
up env-var overrides during profiling and parallel execution:

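```bash
# Directory containing this launch script (standard bash idiom)
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
source "$SCRIPT_DIR/../../../common/gpu_utils.sh"

# vLLM: prints KV-cap flags, or nothing when no override is set.
# $GPU_MEM_ARGS is left unquoted on purpose so multiple flags split into separate args.
GPU_MEM_ARGS=$(build_vllm_gpu_mem_args)
python -m dynamo.vllm --model "$MODEL" $GPU_MEM_ARGS &

# SGLang
GPU_MEM_ARGS=$(build_sglang_gpu_mem_args)
python -m dynamo.sglang --model-path "$MODEL" $GPU_MEM_ARGS &

# TRT-LLM: separate function that returns JSON for --override-engine-args (or nothing).
# ${VAR:+...} adds the flag only when the JSON is non-empty.
OVERRIDE_JSON=$(build_trtllm_override_args_with_mem)
python -m dynamo.trtllm --model-path "$MODEL" ${OVERRIDE_JSON:+--override-engine-args "$OVERRIDE_JSON"} &
```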

When its env var is set, each function prints the corresponding override
(flags, or JSON for TensorRT-LLM). Otherwise it prints nothing and the engine
uses its default allocation.

| Env var | Function | Output |
|---------|----------|--------|
| `_PROFILE_OVERRIDE_VLLM_KV_CACHE_BYTES` | `build_vllm_gpu_mem_args` | `--kv-cache-memory-bytes N --gpu-memory-utilization 0.01` |
| `_PROFILE_OVERRIDE_SGLANG_MAX_TOTAL_TOKENS` | `build_sglang_gpu_mem_args` | `--max-total-tokens N` |
| `_PROFILE_OVERRIDE_TRTLLM_MAX_TOTAL_TOKENS` | `build_trtllm_override_args_with_mem` | `{"kv_cache_config": {"max_tokens": N}}` (JSON) |
| `_PROFILE_OVERRIDE_TRTLLM_MAX_GPU_TOTAL_BYTES` | `build_trtllm_override_args_with_mem` | `{"kv_cache_config": {"max_gpu_total_bytes": N}}` (JSON) |
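
The env-var-to-flag contract for vLLM and SGLang is small enough to sketch. These are hypothetical stand-ins reconstructed from the table, not the code in `gpu_utils.sh`, which may validate inputs or differ in detail:

```bash
# Hypothetical stand-ins for the vLLM and SGLang helpers, written from the table above.
# Each prints its override flags when the env var is set, and nothing otherwise.
sketch_vllm_gpu_mem_args() {
  local bytes="${_PROFILE_OVERRIDE_VLLM_KV_CACHE_BYTES:-}"
  [[ -n "$bytes" ]] || return 0
  printf '%s\n' "--kv-cache-memory-bytes ${bytes} --gpu-memory-utilization 0.01"
}

sketch_sglang_gpu_mem_args() {
  local tokens="${_PROFILE_OVERRIDE_SGLANG_MAX_TOTAL_TOKENS:-}"
  [[ -n "$tokens" ]] || return 0
  printf '%s\n' "--max-total-tokens ${tokens}"
}

_PROFILE_OVERRIDE_SGLANG_MAX_TOTAL_TOKENS=65536 sketch_sglang_gpu_mem_args
# --max-total-tokens 65536
```

Printing nothing when the variable is unset is what lets a launch script splice `$GPU_MEM_ARGS` into the command line unconditionally.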

All functions return per-process args. In multi-worker-per-GPU setups
(e.g. `disagg_same_gpu.sh`), each worker gets the same override value, so the
GPU holds one capped KV cache per worker. The profiler searches over this
shared value, so the cap it finds is already a per-worker budget.

**Profiler** (`profile_pytest.py`): binary-searches the KV cap to find the
minimum passing value, applies a 2x safety factor, and outputs pytest markers
(`@pytest.mark.requested_vllm_kv_cache_bytes(N)`,
`@pytest.mark.requested_sglang_kv_tokens(N)`, or
`@pytest.mark.requested_trtllm_kv_tokens(N)`).
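
The search is a standard lower-bound binary search over a monotone pass/fail predicate. A minimal sketch, assuming each run passes or fails deterministically at a given cap and that a larger cap never turns a pass into a fail. `kv_test_passes` is a hypothetical stand-in for running one test under one env-var value, and the real profiler's bounds, step size, and rounding may differ:

```bash
# Hypothetical stand-in for "run the test with this KV cap"; here it simply
# pretends the test needs at least 3000 units of KV cache.
kv_test_passes() {  # arg: cap
  (( $1 >= 3000 ))
}

# Smallest cap in [lo, hi] for which kv_test_passes succeeds (assumes hi passes).
find_min_passing_cap() {  # args: lo hi
  local lo=$1 hi=$2 mid
  while (( lo < hi )); do
    mid=$(( (lo + hi) / 2 ))
    if kv_test_passes "$mid"; then
      hi=$mid              # mid passes: the answer is mid or lower
    else
      lo=$(( mid + 1 ))    # mid fails: the answer is above mid
    fi
  done
  echo "$lo"
}

min_cap=$(find_min_passing_cap 1 131072)
echo "minimum passing cap: ${min_cap}, with 2x safety factor: $(( min_cap * 2 ))"
# minimum passing cap: 3000, with 2x safety factor: 6000
```

Each probe is a full test run, so the search costs about log2(hi - lo) runs; a real search should also confirm that the upper bound passes before narrowing.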

**Scheduler** (`pytest_parallel_gpu.py`): reads the markers at runtime and
sets the matching env var for each test. See `tests/README.md` for details.
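
The per-test step amounts to a lookup from marker to env var. The shell sketch below is for illustration only: the real scheduler is the Python plugin named above, and the helper names are hypothetical. The marker and env-var names come from this document, and the sketch assumes each marker pairs with the env var for the same engine and unit.

```bash
# Hypothetical sketch: which env var each profiler marker drives.
env_var_for_marker() {  # arg: marker name
  case "$1" in
    requested_vllm_kv_cache_bytes) echo "_PROFILE_OVERRIDE_VLLM_KV_CACHE_BYTES" ;;
    requested_sglang_kv_tokens)    echo "_PROFILE_OVERRIDE_SGLANG_MAX_TOTAL_TOKENS" ;;
    requested_trtllm_kv_tokens)    echo "_PROFILE_OVERRIDE_TRTLLM_MAX_TOTAL_TOKENS" ;;
    *) echo "unknown marker: $1" >&2; return 1 ;;
  esac
}

# Run one command with the override implied by a single test's marker value.
run_with_marker() {  # args: marker value command...
  local var
  var=$(env_var_for_marker "$1") || return 1
  env "${var}=$2" "${@:3}"
}

run_with_marker requested_sglang_kv_tokens 65536 printenv _PROFILE_OVERRIDE_SGLANG_MAX_TOTAL_TOKENS
# 65536
```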