# Benchmark and Profiling

## Benchmark

- Benchmark the latency of running a single static batch without a server. The arguments are the same as for `launch_server.py`.
  Note that this is a simplified test script without a dynamic batching server, so it may run out of memory for a batch size that a real server can handle. A real server splits the prefill into several smaller batches, while this simplified script runs the whole prefill at once.

  ```bash
  python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --batch 32 --input-len 256 --output-len 32
  ```

- Benchmark offline processing. This script will start an offline engine and run the benchmark.

  ```bash
  python3 -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --num-prompts 10
  ```

- Benchmark online serving. First launch a server with `sglang.launch_server`, then run the following command.

  ```bash
  python3 -m sglang.bench_serving --backend sglang --num-prompts 10
  ```
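
The `bench_one_batch` note above says a real server splits the prefill into several batches, while the static script runs it in one pass. The sketch below illustrates that difference with a hypothetical `split_prefill` helper that packs prompt tokens into chunks under a fixed token budget. It is a simplified model for intuition, not SGLang's actual scheduler, and the 2048-token budget is an arbitrary assumption.

```python
from typing import List, Tuple


def split_prefill(input_lens: List[int], chunk_tokens: int) -> List[List[Tuple[int, int]]]:
    """Split prefill work into chunks of at most `chunk_tokens` tokens.

    Each chunk is a list of (request_index, num_tokens) pieces; a long prompt
    may be spread across several consecutive chunks.
    """
    if chunk_tokens <= 0:
        raise ValueError("chunk_tokens must be positive")
    chunks: List[List[Tuple[int, int]]] = []
    current: List[Tuple[int, int]] = []
    budget = chunk_tokens
    for req, remaining in enumerate(input_lens):
        while remaining > 0:
            take = min(remaining, budget)  # fill the current chunk as far as possible
            current.append((req, take))
            remaining -= take
            budget -= take
            if budget == 0:  # chunk is full: close it and start a new one
                chunks.append(current)
                current, budget = [], chunk_tokens
    if current:
        chunks.append(current)
    return chunks


# --batch 32 --input-len 256: the static script prefills all 8192 tokens in one pass.
static = split_prefill([256] * 32, chunk_tokens=32 * 256)
# A hypothetical 2048-token budget spreads the same work over several smaller passes.
chunked = split_prefill([256] * 32, chunk_tokens=2048)
print(len(static), len(chunked))  # 1 4
```

Because each pass processes at most `chunk_tokens` tokens, the temporary memory needed by a single prefill forward pass is roughly bounded by the budget rather than by the full batch, which is one reason a server can handle batch sizes that make the static script run out of memory.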

## Profile with PyTorch Profiler

[PyTorch Profiler](https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html) is a convenient basic tool for inspecting kernel execution time, call stacks, kernel overlap, and occupancy.

- To profile a server

  ```bash
  # set trace path
  export SGLANG_TORCH_PROFILER_DIR=/root/sglang/profile_log

  # start server
  python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct

  # send profiling request from client
  python -m sglang.bench_serving --backend sglang --model meta-llama/Llama-3.1-8B-Instruct --num-prompts 10 --sharegpt-output-len 100 --profile
  ```

  Make sure `SGLANG_TORCH_PROFILER_DIR` is set on both the server and the client side; otherwise the trace file cannot be generated correctly. A reliable way to do this is to set `SGLANG_TORCH_PROFILER_DIR` in your shell's rc file (e.g., `~/.bashrc` for bash).

- To profile offline
  ```bash
  export SGLANG_TORCH_PROFILER_DIR=/root/sglang/profile_log

  # profile one batch with bench_one_batch.py
  # batch size can be controlled with --batch argument
  python3 -m sglang.bench_one_batch --model-path meta-llama/Llama-3.1-8B-Instruct --batch 32 --input-len 1024 --output-len 10 --profile

  # profile multiple batches with bench_offline_throughput.py
  python -m sglang.bench_offline_throughput --model-path meta-llama/Llama-3.1-8B-Instruct --dataset-name random --num-prompts 10 --profile --mem-frac=0.8
  ```

- View Traces

  Trace files can be loaded and visualized from:

  1. https://ui.perfetto.dev/ (any browser)
  2. chrome://tracing (Chrome browser only)

  If the browser cannot open a trace file because it is too large, the client can produce a smaller trace file (<100MB) by limiting the number of prompts and the length of their outputs.
  For example, when profiling a server:

  ```bash
  python -m sglang.bench_serving --backend sglang --model meta-llama/Llama-3.1-8B-Instruct --num-prompts 2 --sharegpt-output-len 100 --profile
  ```

  This command sets the number of prompts to 2 with the `--num-prompts` argument and limits the length of output sequences to 100 tokens with the `--sharegpt-output-len` argument, which produces a trace file small enough for the browser to open smoothly.

  Additionally, if you want to trace the CUDA kernels in the trace back to the SGLang Python source code that launched them, disable CUDA Graph when starting the server by adding the `--disable-cuda-graph` argument to the launch command.
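
Two practical points from this section, keeping `SGLANG_TORCH_PROFILER_DIR` identical on the server and client, and keeping trace files under about 100MB, can be checked with a small Python sketch. The helper names `profiling_env` and `oversized_traces` are hypothetical, and the sketch makes no assumption about trace file names or extensions; it simply scans every file under the directory.

```python
import os
from pathlib import Path
from typing import Dict, List, Tuple


def profiling_env(trace_dir: str) -> Dict[str, str]:
    """Return a copy of the current environment with SGLANG_TORCH_PROFILER_DIR set.

    Passing the same dict as `env=` when starting both the server and the client
    (e.g. with subprocess.Popen) guarantees both sides use the same trace path.
    """
    Path(trace_dir).mkdir(parents=True, exist_ok=True)  # create the directory if missing
    env = dict(os.environ)  # copy, so os.environ of this process is left untouched
    env["SGLANG_TORCH_PROFILER_DIR"] = trace_dir
    return env


def oversized_traces(trace_dir: str, limit_mb: float = 100.0) -> List[Tuple[str, float]]:
    """List files under `trace_dir` larger than `limit_mb` (1 MB = 2**20 bytes).

    Returns (path, size_in_mb) pairs sorted by path.
    """
    limit_bytes = limit_mb * 1024 * 1024
    result: List[Tuple[str, float]] = []
    for path in sorted(Path(trace_dir).rglob("*")):
        if path.is_file():
            size = path.stat().st_size
            if size > limit_bytes:
                result.append((str(path), size / (1024 * 1024)))
    return result
```

For example, pass `env=profiling_env("/root/sglang/profile_log")` to `subprocess.Popen` for both the server and the `bench_serving` client, then call `oversized_traces("/root/sglang/profile_log")` to see which files may be too large for the browser and need a run with fewer prompts or shorter outputs.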

## Profile with Nsight

[Nsight Systems](https://docs.nvidia.com/nsight-systems/) is an advanced tool that exposes more profiling details, such as register and shared memory usage, annotated code regions, and low-level CUDA APIs and events.

1. Prerequisite:

   Install Nsight Systems using apt, or run inside an [NVIDIA Docker container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch/tags) or [SGLang Docker container](https://github.com/sgl-project/sglang/tree/main/docker).
   ```bash
   # install nsys
   # https://docs.nvidia.com/nsight-systems/InstallationGuide/index.html
   apt update
   apt install -y --no-install-recommends gnupg
   echo "deb http://developer.download.nvidia.com/devtools/repos/ubuntu$(source /etc/lsb-release; echo "$DISTRIB_RELEASE" | tr -d .)/$(dpkg --print-architecture) /" | tee /etc/apt/sources.list.d/nvidia-devtools.list
   apt-key adv --fetch-keys http://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/7fa2af80.pub
   apt update
   apt install nsight-systems-cli
   ```

2. To profile a single batch, use:

   ```bash
   nsys profile --trace-fork-before-exec=true --cuda-graph-trace=node python3 -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B --batch-size 64 --input-len 512
   ```

3. To profile a server, for example:

   ```bash
   # launch the server; set the --delay and --duration times (in seconds) as needed
   # once the duration has elapsed, nsys stops profiling and kills the server
   nsys profile --trace-fork-before-exec=true --cuda-graph-trace=node -o sglang.out --delay 60 --duration 70 python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disable-radix-cache

   # client
   python3 -m sglang.bench_serving --backend sglang --num-prompts 1000 --dataset-name random --random-input 1024 --random-output 512
   ```

   In practice, we recommend setting the `--duration` argument to a large value. Whenever you want the server to stop profiling, first run:

   ```bash
   nsys sessions list
   ```

   to get the session id in the form of `profile-XXXXX`, then run:

   ```bash
   nsys stop --session=profile-XXXXX
   ```

   to stop the profiler manually and write the `.nsys-rep` file immediately.

4. Use NVTX to annotate code regions, e.g. to see their execution time.

   ```bash
   # install nvtx
   pip install nvtx
   ```

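   ```python
   import time

   import nvtx

   # Annotate a code region; it shows up as a named range in the Nsight Systems timeline.
   with nvtx.annotate("description", color="blue"):
       time.sleep(0.1)  # placeholder for the critical code you want to time
   ```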
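
The manual stop procedure in step 3 (`nsys sessions list`, then `nsys stop --session=...`) can be scripted. Below is a minimal Python sketch; the helper names are hypothetical, and it assumes, as described in step 3, that session names appear in the `nsys sessions list` output as `profile-XXXXX` tokens. It stops every matching session, so on a shared machine you may prefer to pick a single name from `find_profile_sessions` instead.

```python
import re
import subprocess
from typing import List

# Assumption: session names appear in `nsys sessions list` output as
# tokens of the form profile-XXXXX.
SESSION_RE = re.compile(r"\bprofile-[\w-]+")


def find_profile_sessions(listing: str) -> List[str]:
    """Return the unique profile-XXXXX names in `listing`, in order of appearance."""
    names: List[str] = []
    for name in SESSION_RE.findall(listing):
        if name not in names:
            names.append(name)
    return names


def stop_profile_sessions() -> List[str]:
    """List nsys sessions, then stop each profile-XXXXX session found.

    Stopping a session makes nsys write its report right away.
    Returns the names of the sessions that were stopped.
    """
    listing = subprocess.run(
        ["nsys", "sessions", "list"], capture_output=True, text=True, check=True
    ).stdout
    names = find_profile_sessions(listing)
    for name in names:
        subprocess.run(["nsys", "stop", f"--session={name}"], check=True)
    return names
```

Run `stop_profile_sessions()` on the machine where the profiled server is running, once enough requests from the client have been captured.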

## Other tips

1. You can benchmark a model using dummy weights by providing only its `config.json` file. This allows quick testing of model variants without training. To do so, add `--load-format dummy` to the commands above; you then only need a correct `config.json` under the checkpoint folder.
2. You can benchmark a model with modified configs (e.g., fewer layers) by using `--json-model-override-args`. For example, you can benchmark a model with only 1 layer and 1 KV head using:

   ```bash
   python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --batch 32 --input-len 256 --output-len 32 --load-format dummy --json-model-override-args '{"num_hidden_layers": 1, "num_key_value_heads": 1}'
   ```

3. You can use `--python-backtrace=cuda` to see the Python call stack for all CUDA kernels, as in PyTorch Profiler. (Caveat: this can make kernel runtimes measured with CUDA-event-based timing inaccurately long.)
4. For more arguments, see the [Nsight Systems User Guide](https://docs.nvidia.com/nsight-systems/UserGuide/index.html).
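
Tip 2 passes the overrides as an inline JSON string, which is easy to mistype. The sketch below builds that string with `json.dumps` after checking that each key already exists in a local copy of the checkpoint's `config.json` (our own safety check for typos, not something SGLang is known to require), and assembles the tip's `bench_one_batch` command as an argument list so no shell quoting is needed. `build_override_args` and `bench_one_batch_cmd` are hypothetical helpers.

```python
import json
from pathlib import Path
from typing import Dict, List


def build_override_args(config_path: str, overrides: Dict[str, object]) -> str:
    """Return a JSON string for --json-model-override-args.

    Raises KeyError if an override key is missing from config.json, which
    catches typos such as "num_hidden_layer" before launching a benchmark.
    """
    config = json.loads(Path(config_path).read_text())
    unknown = sorted(k for k in overrides if k not in config)
    if unknown:
        raise KeyError(f"keys not found in {config_path}: {unknown}")
    return json.dumps(overrides)


def bench_one_batch_cmd(model_path: str, override_json: str) -> List[str]:
    """Build the dummy-weight bench_one_batch command from tip 2 as an argv list."""
    return [
        "python", "-m", "sglang.bench_one_batch",
        "--model-path", model_path,
        "--batch", "32", "--input-len", "256", "--output-len", "32",
        "--load-format", "dummy",
        "--json-model-override-args", override_json,
    ]
```

For example, `subprocess.run(bench_one_batch_cmd(model, build_override_args("config.json", {"num_hidden_layers": 1})))` launches the benchmark, and `shlex.join(...)` on the same list prints a copy-pasteable shell command.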