""" Usage: 1) Launch the server with wait-for-initial-weights option in one terminal: python -m sglang.launch_server --model-path /workspace/Qwen/Qwen3-4B/ --tensor-parallel-size 2 --port 19730 --load-format dummy --checkpoint-engine-wait-weights-before-ready --mem-fraction-static 0.7 2) Torchrun this script in another terminal: torchrun --nproc-per-node 2 update.py --update-method broadcast --checkpoint-path /workspace/Qwen/Qwen3-4B/ --inference-parallel-size 2 """ import argparse import json import os import pickle import time from collections import defaultdict from collections.abc import Callable from contextlib import contextmanager from typing import Literal import httpx import torch import torch.distributed as dist from checkpoint_engine.ps import ParameterServer from loguru import logger from safetensors import safe_open @contextmanager def timer(msg: str): start = time.perf_counter() yield end = time.perf_counter() logger.info(f"{msg} duration: {end - start:.2f} seconds") def check_sglang_ready( endpoint: str, inference_parallel_size: int, uds: str | None = None ): if rank != rank // inference_parallel_size * inference_parallel_size: return retry_num = 0 transport = None if uds is not None: transport = httpx.HTTPTransport(uds=uds) with httpx.Client(transport=transport) as client: while True: try: response = client.get(f"{endpoint}/ping", timeout=10) response.raise_for_status() break except (httpx.ConnectError, httpx.HTTPStatusError) as e: if retry_num % 10 == 0: logger.warning( f"fail to check sglang ready, retry {retry_num} times, error: {e}" ) retry_num += 1 time.sleep(0.1) def split_checkpoint_files( checkpoint_path: str, rank: int, world_size: int ) -> list[str]: checkpoint_files = [ os.path.join(checkpoint_path, f) for f in filter( lambda x: x.endswith(".safetensors"), os.listdir(checkpoint_path) ) ] files_per_rank = (len(checkpoint_files) + world_size - 1) // world_size return checkpoint_files[rank * files_per_rank : (rank + 1) * files_per_rank] def split_tensors( checkpoint_path: str, rank: int, world_size: int ) -> dict[str, torch.Tensor]: index_fn = os.path.join(checkpoint_path, "model.safetensors.index.json") with open(index_fn) as f: weight_map: dict[str, str] = json.load(f)["weight_map"] weights_per_rank = (len(weight_map) + world_size - 1) // world_size fn_tensors: dict[str, list[str]] = defaultdict(list) weight_keys = list(weight_map.items()) for name, file in weight_keys[ rank * weights_per_rank : (rank + 1) * weights_per_rank ]: fn_tensors[file].append(name) named_tensors = {} for file, names in fn_tensors.items(): with safe_open(os.path.join(checkpoint_path, file), framework="pt") as f: for name in names: named_tensors[name] = f.get_tensor(name) return named_tensors def req_inference( endpoint: str, inference_parallel_size: int, timeout: float = 300.0, uds: str | None = None, weight_version: str | None = None, ) -> Callable[[list[tuple[str, str]]], None]: rank = int(os.getenv("RANK", 0)) src = rank // inference_parallel_size * inference_parallel_size def req_func(socket_paths: list[tuple[str, str]]): if rank == src: with httpx.Client(transport=httpx.HTTPTransport(uds=uds)) as client: resp = client.post( f"{endpoint}/update_weights_from_ipc", json={ "zmq_handles": dict( socket_paths[src : src + inference_parallel_size] ), "flush_cache": True, "weight_version": weight_version, }, timeout=timeout, ) resp.raise_for_status() return req_func def update_weights( ps: ParameterServer, checkpoint_name: str, checkpoint_files: list[str], named_tensors: dict[str, 
def update_weights(
    ps: ParameterServer,
    checkpoint_name: str,
    checkpoint_files: list[str],
    named_tensors: dict[str, torch.Tensor],
    req_func: Callable[[list[tuple[str, str]]], None],
    inference_parallel_size: int,
    endpoint: str,
    save_metas_file: str | None = None,
    update_method: Literal["broadcast", "p2p", "all"] = "broadcast",
    uds: str | None = None,
):
    ps.register_checkpoint(checkpoint_name, files=checkpoint_files, named_tensors=named_tensors)
    ps.init_process_group()
    check_sglang_ready(endpoint, inference_parallel_size, uds)
    dist.barrier()
    with timer("Gather metas"):
        ps.gather_metas(checkpoint_name)
    if save_metas_file and int(os.getenv("RANK", 0)) == 0:
        with open(save_metas_file, "wb") as f:
            pickle.dump(ps.get_metas(), f)
    if update_method in ("broadcast", "all"):
        with timer("Update weights without setting ranks"):
            ps.update(checkpoint_name, req_func)
    if update_method in ("p2p", "all"):
        if update_method == "all":
            # broadcast ran first; sleep 2s to wait for its process group to be destroyed
            time.sleep(2)
        with timer("Update weights with setting ranks"):
            ps.update(checkpoint_name, req_func, ranks=list(range(inference_parallel_size)))


def join(
    ps: ParameterServer,
    checkpoint_name: str,
    load_metas_file: str,
    req_func: Callable[[list[tuple[str, str]]], None],
    inference_parallel_size: int,
    endpoint: str,
    uds: str | None = None,
):
    assert load_metas_file, "load_metas_file is required"
    with open(load_metas_file, "rb") as f:
        metas = pickle.load(f)
    ps.init_process_group()
    check_sglang_ready(endpoint, inference_parallel_size, uds)
    dist.barrier()
    with timer("Gather metas before join"):
        ps.gather_metas(checkpoint_name)
    ps.load_metas(metas)
    with timer(
        f"Update weights with setting ranks as range(0, {inference_parallel_size}) by using p2p"
    ):
        ps.update(checkpoint_name, req_func, ranks=list(range(inference_parallel_size)))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Update weights example")
    parser.add_argument("--checkpoint-path", type=str, default=None)
    parser.add_argument("--save-metas-file", type=str, default=None)
    parser.add_argument("--load-metas-file", type=str, default=None)
    parser.add_argument("--sleep-time", type=int, default=0)
    parser.add_argument("--endpoint", type=str, default="http://localhost:19730")
    parser.add_argument("--inference-parallel-size", type=int, default=8)
    parser.add_argument("--checkpoint-name", type=str, default="my-checkpoint-iter-0")
    parser.add_argument("--update-method", type=str, default="broadcast")
    parser.add_argument("--uds", type=str, default=None)
    parser.add_argument("--weight-version", type=str, default=None)
    args = parser.parse_args()
    rank = int(os.getenv("RANK", 0))
    world_size = int(os.getenv("WORLD_SIZE", 1))
    req_func = req_inference(
        args.endpoint,
        args.inference_parallel_size,
        uds=args.uds,
        weight_version=args.weight_version,
    )
    ps = ParameterServer(auto_pg=True)
    ps._p2p_store = None
    if args.load_metas_file:
        join(
            ps,
            args.checkpoint_name,
            args.load_metas_file,
            req_func,
            args.inference_parallel_size,
            args.endpoint,
            args.uds,
        )
    else:
        if os.path.exists(os.path.join(args.checkpoint_path, "model.safetensors.index.json")):
            # Prefer tensor-level sharding when an index file is available.
            named_tensors = split_tensors(args.checkpoint_path, rank, world_size)
            checkpoint_files = []
        else:
            # Otherwise shard whole .safetensors files across ranks.
            checkpoint_files = split_checkpoint_files(args.checkpoint_path, rank, world_size)
            named_tensors = {}
        update_weights(
            ps,
            args.checkpoint_name,
            checkpoint_files,
            named_tensors,
            req_func,
            args.inference_parallel_size,
            args.endpoint,
            args.save_metas_file,
            args.update_method,
            args.uds,
        )
    time.sleep(args.sleep_time)
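
# A hedged usage sketch, not part of the original docstring: the flags above suggest a
# two-stage flow in which a first run saves the gathered metas and a later run joins via
# p2p by loading them. The metas file name below is a placeholder; other flags keep their
# defaults (endpoint http://localhost:19730, checkpoint name my-checkpoint-iter-0).
#
#   torchrun --nproc-per-node 2 update.py --update-method broadcast \
#       --checkpoint-path /workspace/Qwen/Qwen3-4B/ --inference-parallel-size 2 \
#       --save-metas-file metas.pkl
#
#   torchrun --nproc-per-node 2 update.py --inference-parallel-size 2 \
#       --load-metas-file metas.pkl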