Commit 32438eba by Teng Ma (parent fed02a49)

[Ckpt Engine] feat: new sglang entrypoint support for update (#12216)
# Checkpoint Engine Integration
The SGLang checkpoint engine integration provides an efficient way to load model weights using a distributed checkpoint loading system. This feature significantly reduces model loading time, especially for large models and multi-node setups, by parallelizing the weight loading process across multiple processes and nodes.
## Overview
The checkpoint engine integration allows SGLang to:
- Load model weights in parallel using multiple processes
- Distribute weight loading across multiple nodes to increase effective disk bandwidth
- Overlap weight loading with other initialization tasks like CUDA graph capture
- Support both single-node and multi-node deployments
## Installation
First, install the checkpoint engine package:
```bash
pip install 'checkpoint-engine[p2p]'
```
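To quickly confirm the installation (an optional sanity check; the import path matches the one used by the update entrypoint), run:
```bash
python -c "from checkpoint_engine.ps import ParameterServer; print('checkpoint-engine OK')"
```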
## Architecture
The system consists of two main components:
1. **SGLang Server**: Runs with the `--wait-for-initial-weights` flag so it does not report ready until the initial weights have been received
2. **Checkpoint Engine Workers**: Separate processes (managed by torchrun) that load the checkpoint from disk and distribute the weights to the inference processes
The checkpoint engine uses a parameter server architecture and supports the following transfer modes (selected via `--update-method`; see the example after this list):
- **Broadcast mode**: Weights are broadcast from loading processes to inference processes
- **P2P mode**: Direct peer-to-peer weight transfer between processes
- **All mode**: Combination of both broadcast and P2P methods
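The mode only needs to be chosen on the checkpoint-engine side. The examples in this document all use broadcast; a P2P run follows the same pattern (a minimal sketch, with placeholder paths):
```bash
python -m sglang.srt.checkpoint_engine.update \
    --update-method p2p \
    --checkpoint-path /path/to/Qwen/Qwen3-8B/ \
    --inference-parallel-size 8
```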
## Usage Examples
### Single Node Setup
**Terminal 1 - Launch SGLang Server:**
```bash
python -m sglang.launch_server \
--model-path Qwen/Qwen3-8B \
--tp 8 \
--load-format dummy \
--wait-for-initial-weights
```
**Terminal 2 - Run Checkpoint Engine:**
Using the sglang entrypoint (recommended):
```bash
python -m sglang.srt.checkpoint_engine.update \
--update-method broadcast \
--checkpoint-path /path/to/Qwen/Qwen3-8B/ \
--inference-parallel-size 8
```
Using torchrun directly:
```bash
torchrun --nproc-per-node 8 \
examples/checkpoint_engine/update.py \
--update-method broadcast \
--checkpoint-path /path/to/Qwen/Qwen3-8B/ \
--inference-parallel-size 8
```
### Multi-Node Setup (2 Nodes)
**Node 0:**
Launch SGLang server:
```bash
python -m sglang.launch_server \
--model-path Qwen/Qwen3-8B \
--tp 8 \
--load-format dummy \
--wait-for-initial-weights \
--host [IP]
```
Run checkpoint engine:
Using the sglang entrypoint (recommended):
```bash
python -m sglang.srt.checkpoint_engine.update \
--update-method broadcast \
--checkpoint-path /path/to/Qwen/Qwen3-8B/ \
--inference-parallel-size 8
```
Using torchrun directly:
```bash
torchrun --nproc-per-node 8 \
--nnodes 2 \
--node-rank 0 \
--master-addr [IP] \
--master-port 29500 \
examples/checkpoint_engine/update.py \
--update-method broadcast \
--checkpoint-path /path/to/Qwen/Qwen3-8B/ \
--inference-parallel-size 8
```
**Node 1:**
Launch SGLang server:
```bash
python -m sglang.launch_server \
--model-path Qwen/Qwen3-8B \
--tp 8 \
--load-format dummy \
--wait-for-initial-weights \
--host [IP]
```
Run checkpoint engine:
Using the sglang entrypoint (recommended):
```bash
python -m sglang.srt.checkpoint_engine.update \
--update-method broadcast \
--checkpoint-path /path/to/Qwen/Qwen3-8B/ \
--inference-parallel-size 8
```
Using torchrun directly:
```bash
torchrun --nproc-per-node 8 \
--nnodes 2 \
--node-rank 1 \
--master-addr [IP] \
--master-port 29500 \
examples/checkpoint_engine/update.py \
--update-method broadcast \
--checkpoint-path /path/to/Qwen/Qwen3-8B/ \
--inference-parallel-size 8
```
### Multi-Node Setup with Tensor Parallelism (TP=16)
**Node 0:**
Launch SGLang server:
```bash
python -m sglang.launch_server \
--model-path Qwen/Qwen3-8B \
    --tp 16 \
--load-format dummy \
--wait-for-initial-weights \
--host [IP] \
--dist-init-addr [IP]:9120 \
--nnodes 2 \
--node-rank 0
```
Run checkpoint engine:
Using the sglang entrypoint (recommended):
```bash
python -m sglang.srt.checkpoint_engine.update \
--update-method broadcast \
--checkpoint-path /path/to/Qwen/Qwen3-8B/ \
--inference-parallel-size 16
```
Using torchrun directly:
```bash
torchrun --nproc-per-node 8 \
--nnodes 2 \
--node-rank 0 \
--master-addr [IP] \
--master-port 29500 \
examples/checkpoint_engine/update.py \
--update-method broadcast \
--checkpoint-path /path/to/Qwen/Qwen3-8B/ \
--inference-parallel-size 16
```
**Node 1:**
Launch SGLang server:
```bash
python -m sglang.launch_server \
--model-path Qwen/Qwen3-8B \
    --tp 16 \
--load-format dummy \
--wait-for-initial-weights \
--host [IP] \
--dist-init-addr [IP]:9120 \
--nnodes 2 \
--node-rank 1
```
Run checkpoint engine:
Using the sglang entrypoint (recommended):
```bash
python -m sglang.srt.checkpoint_engine.update \
--update-method broadcast \
--checkpoint-path /path/to/Qwen/Qwen3-8B/ \
--inference-parallel-size 16
```
Using torchrun directly:
```bash
torchrun --nproc-per-node 8 \
--nnodes 2 \
--node-rank 1 \
--master-addr [IP] \
--master-port 29500 \
examples/checkpoint_engine/update.py \
--update-method broadcast \
--checkpoint-path /path/to/Qwen/Qwen3-8B/ \
--inference-parallel-size 16
```
## Configuration Options
### SGLang Server Options
- `--load-format dummy`: Start the server with placeholder (dummy) weights so that initialization tasks such as CUDA graph capture can overlap with the real weight transfer
- `--wait-for-initial-weights`: Wait for checkpoint engine to provide weights before becoming ready
- `--host`: Host address for multi-node setups
- `--dist-init-addr`: Distributed initialization address for tensor parallelism
### Checkpoint Engine Options
- `--update-method`: Weight update method (`broadcast`, `p2p`, or `all`)
- `--checkpoint-path`: Path to model checkpoint directory
- `--inference-parallel-size`: Number of inference parallel processes
- `--endpoint`: SGLang server endpoint (default: `http://localhost:19730`)
- `--checkpoint-name`: Name for the checkpoint (default: `my-checkpoint-iter-0`)
- `--save-metas-file`: File to save checkpoint metadata to (see the example after this list)
- `--load-metas-file`: File to load checkpoint metadata from
- `--uds`: Unix domain socket path for communication
- `--weight-version`: Version identifier for weights
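The metadata options support a two-step flow: one run registers the checkpoint and saves the gathered metadata, and a later run reuses that metadata to push weights via P2P without re-reading the checkpoint from disk. A minimal sketch based on the script's `--save-metas-file`/`--load-metas-file` options (file paths are placeholders):
```bash
# Step 1: update weights and save the gathered checkpoint metadata (rank 0 writes the file)
python -m sglang.srt.checkpoint_engine.update \
    --update-method broadcast \
    --checkpoint-path /path/to/Qwen/Qwen3-8B/ \
    --inference-parallel-size 8 \
    --save-metas-file /tmp/qwen3-8b-metas.pkl

# Step 2: a later run loads the saved metadata and updates the server via P2P
python -m sglang.srt.checkpoint_engine.update \
    --load-metas-file /tmp/qwen3-8b-metas.pkl \
    --inference-parallel-size 8
```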
## Performance Benefits
The checkpoint engine saves time in two main ways:
1. **Multi-node Loading**: Each node loads only a portion of the weights from disk, effectively multiplying the available disk bandwidth; the more nodes participate, the greater the speedup. Preliminary tests show a 20-second speedup when loading DeepSeek-R1 on H20-3e with two nodes.
2. **Single-Process Optimization**: Loading with the dummy format lets the disk-to-CPU transfer overlap with CUDA graph capture and other initialization tasks, providing additional time savings.
## Troubleshooting
- Ensure the checkpoint-engine package is installed: `pip install 'checkpoint-engine[p2p]'`
- Verify network connectivity between nodes in multi-node setups
- Check that the checkpoint path contains valid model files
- Monitor logs for connection errors between the SGLang server and the checkpoint engine (see the probe below)
- Use the `--sleep-time` option to keep the checkpoint-engine processes alive after the update, which can help when debugging
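The update entrypoint polls the server's `/ping` route and defaults to `--endpoint http://localhost:19730`; if the server listens on a different host or port, pass a matching `--endpoint` (or `--uds` for a Unix domain socket). A quick manual probe, adjusting the address to your setup:
```bash
curl -s http://localhost:19730/ping
```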
## References
- [Checkpoint Engine Repository](https://github.com/MoonshotAI/checkpoint-engine)
@@ -53,6 +53,7 @@ Its core features include:
   advanced_features/router.md
   advanced_features/deterministic_inference.md
   advanced_features/observability.md
+  advanced_features/checkpoint_engine.md

.. toctree::
   :maxdepth: 1
...
"""
Checkpoint engine module for SGLang.
This module provides functionality for updating model weights via checkpoint engine.
"""
from sglang.srt.checkpoint_engine.update import main
__all__ = ["main"]
"""
Usage:
1) Launch the server with wait-for-initial-weights option in one terminal:
python -m sglang.launch_server --model-path /workspace/Qwen/Qwen3-4B/ --tensor-parallel-size 2 --port 19730 --load-format dummy --checkpoint-engine-wait-weights-before-ready --mem-fraction-static 0.7
2) Torchrun this script in another terminal:
torchrun --nproc-per-node 2 update.py --update-method broadcast --checkpoint-path /workspace/Qwen/Qwen3-4B/ --inference-parallel-size 2
Or use the integrated entry point:
python -m sglang.srt.checkpoint_engine.update --update-method broadcast --checkpoint-path /workspace/Qwen/Qwen3-4B/ --inference-parallel-size 2
"""
import argparse
import json
import os
import pickle
import subprocess
import sys
import time
from collections import defaultdict
from collections.abc import Callable
from contextlib import contextmanager
from typing import Literal
import httpx
import torch
import torch.distributed as dist
from safetensors import safe_open
try:
from checkpoint_engine.ps import ParameterServer
from loguru import logger
except ImportError:
# Fallback for when checkpoint_engine is not available
ParameterServer = None
import logging
logger = logging.getLogger(__name__)
@contextmanager
def timer(msg: str):
start = time.perf_counter()
yield
end = time.perf_counter()
logger.info(f"{msg} duration: {end - start:.2f} seconds")
def check_sglang_ready(
endpoint: str, inference_parallel_size: int, uds: str | None = None
):
rank = int(os.getenv("RANK", 0))
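    # Only the lead rank of each inference-parallel group polls the server; other ranks return immediately.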
if rank != rank // inference_parallel_size * inference_parallel_size:
return
retry_num = 0
transport = None
if uds is not None:
transport = httpx.HTTPTransport(uds=uds)
with httpx.Client(transport=transport) as client:
while True:
try:
response = client.get(f"{endpoint}/ping", timeout=10)
response.raise_for_status()
break
except (httpx.ConnectError, httpx.HTTPStatusError) as e:
if retry_num % 10 == 0:
logger.warning(
f"fail to check sglang ready, retry {retry_num} times, error: {e}"
)
retry_num += 1
time.sleep(0.1)
def split_checkpoint_files(
checkpoint_path: str, rank: int, world_size: int
) -> list[str]:
checkpoint_files = [
os.path.join(checkpoint_path, f)
for f in filter(
lambda x: x.endswith(".safetensors"), os.listdir(checkpoint_path)
)
]
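    # Distribute the .safetensors files evenly across checkpoint-engine ranks (ceiling division).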
files_per_rank = (len(checkpoint_files) + world_size - 1) // world_size
return checkpoint_files[rank * files_per_rank : (rank + 1) * files_per_rank]
def split_tensors(
checkpoint_path: str, rank: int, world_size: int
) -> dict[str, torch.Tensor]:
index_fn = os.path.join(checkpoint_path, "model.safetensors.index.json")
with open(index_fn) as f:
weight_map: dict[str, str] = json.load(f)["weight_map"]
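    # Shard individual tensors (rather than whole files) across ranks using the safetensors index.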
weights_per_rank = (len(weight_map) + world_size - 1) // world_size
fn_tensors: dict[str, list[str]] = defaultdict(list)
weight_keys = list(weight_map.items())
for name, file in weight_keys[
rank * weights_per_rank : (rank + 1) * weights_per_rank
]:
fn_tensors[file].append(name)
named_tensors = {}
for file, names in fn_tensors.items():
with safe_open(os.path.join(checkpoint_path, file), framework="pt") as f:
for name in names:
named_tensors[name] = f.get_tensor(name)
return named_tensors
def req_inference(
endpoint: str,
inference_parallel_size: int,
timeout: float = 300.0,
uds: str | None = None,
weight_version: str | None = None,
) -> Callable[[list[tuple[str, str]]], None]:
rank = int(os.getenv("RANK", 0))
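    # src is the lead rank of this inference-parallel group; only that rank issues the HTTP update request.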
src = rank // inference_parallel_size * inference_parallel_size
def req_func(socket_paths: list[tuple[str, str]]):
if rank == src:
with httpx.Client(transport=httpx.HTTPTransport(uds=uds)) as client:
resp = client.post(
f"{endpoint}/update_weights_from_ipc",
json={
"zmq_handles": dict(
socket_paths[src : src + inference_parallel_size]
),
"flush_cache": True,
"weight_version": weight_version,
},
timeout=timeout,
)
resp.raise_for_status()
return req_func
def update_weights(
ps,
checkpoint_name: str,
checkpoint_files: list[str],
named_tensors: dict[str, torch.Tensor],
req_func: Callable[[list[tuple[str, str]]], None],
inference_parallel_size: int,
endpoint: str,
save_metas_file: str | None = None,
update_method: Literal["broadcast", "p2p", "all"] = "broadcast",
uds: str | None = None,
):
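    # Register this rank's shard of the checkpoint, sync metadata across all ranks, then push weights to SGLang.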
ps.register_checkpoint(
checkpoint_name, files=checkpoint_files, named_tensors=named_tensors
)
ps.init_process_group()
check_sglang_ready(endpoint, inference_parallel_size, uds)
dist.barrier()
with timer("Gather metas"):
ps.gather_metas(checkpoint_name)
if save_metas_file and int(os.getenv("RANK")) == 0:
with open(save_metas_file, "wb") as f:
pickle.dump(ps.get_metas(), f)
if update_method == "broadcast" or update_method == "all":
with timer("Update weights without setting ranks"):
ps.update(checkpoint_name, req_func)
if update_method == "p2p" or update_method == "all":
        if update_method == "all":
            # sleep 2s to wait for the broadcast update's process group to be destroyed
time.sleep(2)
with timer("Update weights with setting ranks"):
ps.update(
checkpoint_name, req_func, ranks=list(range(inference_parallel_size))
)
def join(
ps: ParameterServer,
checkpoint_name: str,
load_metas_file: str,
req_func: Callable[[list[tuple[str, str]]], None],
inference_parallel_size: int,
endpoint: str,
uds: str | None = None,
):
assert load_metas_file, "load_metas_file is required"
with open(load_metas_file, "rb") as f:
metas = pickle.load(f)
ps.init_process_group()
check_sglang_ready(endpoint, inference_parallel_size, uds)
dist.barrier()
with timer("Gather metas before join"):
ps.gather_metas(checkpoint_name)
ps.load_metas(metas)
with timer(
f"Update weights with setting ranks as range(0, {inference_parallel_size}) by using p2p"
):
ps.update(checkpoint_name, req_func, ranks=list(range(inference_parallel_size)))
def run_with_torchrun():
"""Run the update script with torchrun automatically."""
# Parse inference_parallel_size from command line arguments to determine nproc-per-node
inference_parallel_size = 8 # default
args = sys.argv[1:] # Skip the script name
# Look for --inference-parallel-size in arguments
for i, arg in enumerate(args):
if arg == "--inference-parallel-size" and i + 1 < len(args):
try:
inference_parallel_size = int(args[i + 1])
except ValueError:
pass
break
elif arg.startswith("--inference-parallel-size="):
try:
inference_parallel_size = int(arg.split("=", 1)[1])
except ValueError:
pass
break
# Build torchrun command
cmd = ["torchrun", f"--nproc-per-node={inference_parallel_size}", __file__] + args
print(f"Running: {' '.join(cmd)}", file=sys.stderr)
# Execute torchrun with the original script
try:
result = subprocess.run(cmd, check=False)
sys.exit(result.returncode)
except FileNotFoundError:
print(
"Error: torchrun command not found. Please ensure PyTorch is installed.",
file=sys.stderr,
)
sys.exit(1)
except KeyboardInterrupt:
print("\nInterrupted by user", file=sys.stderr)
sys.exit(130)
def main():
# Check if we're running under torchrun or need to invoke it
if os.getenv("RANK") is None:
# Not running under torchrun, so invoke it
run_with_torchrun()
return
# Running under torchrun, proceed with normal execution
parser = argparse.ArgumentParser(description="Update weights example")
parser.add_argument("--checkpoint-path", type=str, default=None)
parser.add_argument("--save-metas-file", type=str, default=None)
parser.add_argument("--load-metas-file", type=str, default=None)
parser.add_argument("--sleep-time", type=int, default=0)
parser.add_argument("--endpoint", type=str, default="http://localhost:19730")
parser.add_argument("--inference-parallel-size", type=int, default=8)
parser.add_argument("--checkpoint-name", type=str, default="my-checkpoint-iter-0")
parser.add_argument("--update-method", type=str, default="broadcast")
parser.add_argument("--uds", type=str, default=None)
parser.add_argument("--weight-version", type=str, default=None)
args = parser.parse_args()
# Get rank and world_size from environment (set by torchrun)
rank = int(os.getenv("RANK", 0))
world_size = int(os.getenv("WORLD_SIZE", 1))
req_func = req_inference(
args.endpoint,
args.inference_parallel_size,
uds=args.uds,
weight_version=args.weight_version,
)
if ParameterServer is None:
print("Error: checkpoint_engine package not available", file=sys.stderr)
sys.exit(1)
ps = ParameterServer(auto_pg=True)
ps._p2p_store = None
if args.load_metas_file:
join(
ps,
args.checkpoint_name,
args.load_metas_file,
req_func,
args.inference_parallel_size,
args.endpoint,
args.uds,
)
else:
if args.checkpoint_path and os.path.exists(
os.path.join(args.checkpoint_path, "model.safetensors.index.json")
):
named_tensors = split_tensors(args.checkpoint_path, rank, world_size)
checkpoint_files = []
else:
checkpoint_files = (
split_checkpoint_files(args.checkpoint_path, rank, world_size)
if args.checkpoint_path
else []
)
named_tensors = {}
update_weights(
ps,
args.checkpoint_name,
checkpoint_files,
named_tensors,
req_func,
args.inference_parallel_size,
args.endpoint,
args.save_metas_file,
args.update_method,
args.uds,
)
time.sleep(args.sleep_time)
if __name__ == "__main__":
main()