Commit 21e9e63a authored by Lianmin Zheng's avatar Lianmin Zheng
Browse files

Print progress bar during cuda graph capture (#2502)

parent 1fc84cf6
# Enabling cache for torch.compile
SGLang uses `max-autotune-no-cudagraphs` mode of torch.compile. The auto-tuning can be slow.
If you want to deploy a model on many different machines, you can ship the torch.compile cache to these machines and skip the compilation steps.
This is based on https://pytorch.org/tutorials/recipes/torch_compile_caching_tutorial.html
1. Generate the cache by setting TORCHINDUCTOR_CACHE_DIR and running the model once.
```
TORCHINDUCTOR_CACHE_DIR=/root/inductor_root_cache python3 -m sglang.launch_server --model meta-llama/Llama-3.1-8B-Instruct --enable-torch-compile
```
2. Copy the cache folder to other machines and launch the server with `TORCHINDUCTOR_CACHE_DIR`.
......@@ -20,6 +20,8 @@ from contextlib import contextmanager
from typing import TYPE_CHECKING, Callable
import torch
import tqdm
from vllm.distributed import get_tensor_model_parallel_rank
from vllm.distributed.parallel_state import graph_capture
from vllm.model_executor.custom_op import CustomOp
......@@ -255,7 +257,12 @@ class CudaGraphRunner:
def capture(self):
with graph_capture() as graph_capture_context:
self.stream = graph_capture_context.stream
for bs in self.capture_bs:
capture_bs = (
tqdm.tqdm(self.capture_bs)
if get_tensor_model_parallel_rank() == 0
else self.capture_bs
)
for bs in capture_bs:
with patch_model(
self.model_runner.model,
bs in self.compile_bs,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment