Unverified commit 96a5e4dd authored by Teng Ma, committed by GitHub

[Feature] Support loading weights from ckpt engine worker (#11755)


Signed-off-by: Yang Kaiyong <yangkaiyong.yky@antgroup.com>
Signed-off-by: Cruz Zhao <CruzZhao@linux.alibaba.com>
Signed-off-by: Xuchun Shang <xuchun.shang@gmail.com>
Co-authored-by: Yang Kaiyong <yangkaiyong.yky@antgroup.com>
Co-authored-by: Cruz Zhao <CruzZhao@linux.alibaba.com>
Co-authored-by: Xuchun Shang <xuchun.shang@gmail.com>
Co-authored-by: Shangming Cai <csmthu@gmail.com>
parent b0b4f716
......@@ -100,4 +100,5 @@ SGLang supports various environment variables that can be used to configure its
| Environment Variable | Description | Default Value |
| --- | --- | --- |
| `SGLANG_WAIT_WEIGHTS_READY_TIMEOUT` | Timeout period for waiting on weights | `120` |
| `SGLANG_DISABLE_OUTLINES_DISK_CACHE` | Disable Outlines disk cache | `true` |
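
For example, to allow the server up to 10 minutes for initial weights when launching with `--checkpoint-engine-wait-weights-before-ready` (a sketch; the model path matches the example script below):

`SGLANG_WAIT_WEIGHTS_READY_TIMEOUT=600 python -m sglang.launch_server --model-path /workspace/Qwen/Qwen3-4B/ --load-format dummy --checkpoint-engine-wait-weights-before-ready`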
"""
Usage:
1) Launch the server with the wait-for-initial-weights option in one terminal:
python -m sglang.launch_server --model-path /workspace/Qwen/Qwen3-4B/ --tensor-parallel-size 2 --port 19730 --load-format dummy --checkpoint-engine-wait-weights-before-ready --mem-fraction-static 0.7
2) Run this script with torchrun in another terminal:
torchrun --nproc-per-node 2 update.py --update-method broadcast --checkpoint-path /workspace/Qwen/Qwen3-4B/ --inference-parallel-size 2
"""
import argparse
import json
import os
import pickle
import time
from collections import defaultdict
from collections.abc import Callable
from contextlib import contextmanager
from typing import Literal
import httpx
import torch
import torch.distributed as dist
from checkpoint_engine.ps import ParameterServer
from loguru import logger
from safetensors import safe_open
@contextmanager
def timer(msg: str):
start = time.perf_counter()
yield
end = time.perf_counter()
logger.info(f"{msg} duration: {end - start:.2f} seconds")
def check_sglang_ready(
    endpoint: str, inference_parallel_size: int, uds: str | None = None
):
    # Only the first rank of each inference group polls the server.
    rank = int(os.getenv("RANK", 0))
    if rank % inference_parallel_size != 0:
        return
retry_num = 0
transport = None
if uds is not None:
transport = httpx.HTTPTransport(uds=uds)
with httpx.Client(transport=transport) as client:
while True:
try:
response = client.get(f"{endpoint}/ping", timeout=10)
response.raise_for_status()
break
except (httpx.ConnectError, httpx.HTTPStatusError) as e:
if retry_num % 10 == 0:
logger.warning(
f"fail to check sglang ready, retry {retry_num} times, error: {e}"
)
retry_num += 1
time.sleep(0.1)
def split_checkpoint_files(
checkpoint_path: str, rank: int, world_size: int
) -> list[str]:
checkpoint_files = [
os.path.join(checkpoint_path, f)
for f in filter(
lambda x: x.endswith(".safetensors"), os.listdir(checkpoint_path)
)
]
files_per_rank = (len(checkpoint_files) + world_size - 1) // world_size
return checkpoint_files[rank * files_per_rank : (rank + 1) * files_per_rank]
def split_tensors(
checkpoint_path: str, rank: int, world_size: int
) -> dict[str, torch.Tensor]:
index_fn = os.path.join(checkpoint_path, "model.safetensors.index.json")
with open(index_fn) as f:
weight_map: dict[str, str] = json.load(f)["weight_map"]
weights_per_rank = (len(weight_map) + world_size - 1) // world_size
fn_tensors: dict[str, list[str]] = defaultdict(list)
weight_keys = list(weight_map.items())
for name, file in weight_keys[
rank * weights_per_rank : (rank + 1) * weights_per_rank
]:
fn_tensors[file].append(name)
named_tensors = {}
for file, names in fn_tensors.items():
with safe_open(os.path.join(checkpoint_path, file), framework="pt") as f:
for name in names:
named_tensors[name] = f.get_tensor(name)
return named_tensors
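# For reference, split_tensors assumes the standard sharded-safetensors index
# layout (a sketch; tensor and shard names below are illustrative):
#
#   model.safetensors.index.json
#   {
#       "weight_map": {
#           "model.embed_tokens.weight": "model-00001-of-00002.safetensors",
#           "lm_head.weight": "model-00002-of-00002.safetensors"
#       }
#   }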
def req_inference(
endpoint: str,
inference_parallel_size: int,
timeout: float = 300.0,
uds: str | None = None,
weight_version: str | None = None,
) -> Callable[[list[tuple[str, str]]], None]:
rank = int(os.getenv("RANK", 0))
src = rank // inference_parallel_size * inference_parallel_size
def req_func(socket_paths: list[tuple[str, str]]):
if rank == src:
with httpx.Client(transport=httpx.HTTPTransport(uds=uds)) as client:
resp = client.post(
f"{endpoint}/update_weights_from_ipc",
json={
"zmq_handles": dict(
socket_paths[src : src + inference_parallel_size]
),
"flush_cache": True,
"weight_version": weight_version,
},
timeout=timeout,
)
resp.raise_for_status()
return req_func
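# Schematically, the request req_func issues to the server looks like this
# (socket paths are illustrative; checkpoint-engine supplies one ZMQ IPC
# handle per device UUID in the inference group):
#
#   POST {endpoint}/update_weights_from_ipc
#   {
#       "zmq_handles": {"GPU-<uuid-0>": "ipc:///tmp/ps-0.sock",
#                       "GPU-<uuid-1>": "ipc:///tmp/ps-1.sock"},
#       "flush_cache": true,
#       "weight_version": "my-checkpoint-iter-0"
#   }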
def update_weights(
ps: ParameterServer,
checkpoint_name: str,
checkpoint_files: list[str],
named_tensors: dict[str, torch.Tensor],
req_func: Callable[[list[tuple[str, str]]], None],
inference_parallel_size: int,
endpoint: str,
save_metas_file: str | None = None,
update_method: Literal["broadcast", "p2p", "all"] = "broadcast",
uds: str | None = None,
):
ps.register_checkpoint(
checkpoint_name, files=checkpoint_files, named_tensors=named_tensors
)
ps.init_process_group()
check_sglang_ready(endpoint, inference_parallel_size, uds)
dist.barrier()
with timer("Gather metas"):
ps.gather_metas(checkpoint_name)
if save_metas_file and int(os.getenv("RANK")) == 0:
with open(save_metas_file, "wb") as f:
pickle.dump(ps.get_metas(), f)
if update_method == "broadcast" or update_method == "all":
with timer("Update weights without setting ranks"):
ps.update(checkpoint_name, req_func)
if update_method == "p2p" or update_method == "all":
if update_method:
# sleep 2s to wait destroy process group
time.sleep(2)
with timer("Update weights with setting ranks"):
ps.update(
checkpoint_name, req_func, ranks=list(range(inference_parallel_size))
)
def join(
ps: ParameterServer,
checkpoint_name: str,
load_metas_file: str,
req_func: Callable[[list[tuple[str, str]]], None],
inference_parallel_size: int,
endpoint: str,
uds: str | None = None,
):
assert load_metas_file, "load_metas_file is required"
with open(load_metas_file, "rb") as f:
metas = pickle.load(f)
ps.init_process_group()
check_sglang_ready(endpoint, inference_parallel_size, uds)
dist.barrier()
with timer("Gather metas before join"):
ps.gather_metas(checkpoint_name)
ps.load_metas(metas)
with timer(
f"Update weights with setting ranks as range(0, {inference_parallel_size}) by using p2p"
):
ps.update(checkpoint_name, req_func, ranks=list(range(inference_parallel_size)))
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Update weights example")
parser.add_argument("--checkpoint-path", type=str, default=None)
parser.add_argument("--save-metas-file", type=str, default=None)
parser.add_argument("--load-metas-file", type=str, default=None)
parser.add_argument("--sleep-time", type=int, default=0)
parser.add_argument("--endpoint", type=str, default="http://localhost:19730")
parser.add_argument("--inference-parallel-size", type=int, default=8)
parser.add_argument("--checkpoint-name", type=str, default="my-checkpoint-iter-0")
parser.add_argument("--update-method", type=str, default="broadcast")
parser.add_argument("--uds", type=str, default=None)
parser.add_argument("--weight-version", type=str, default=None)
args = parser.parse_args()
rank = int(os.getenv("RANK"))
world_size = int(os.getenv("WORLD_SIZE"))
req_func = req_inference(
args.endpoint,
args.inference_parallel_size,
uds=args.uds,
weight_version=args.weight_version,
)
ps = ParameterServer(auto_pg=True)
ps._p2p_store = None
if args.load_metas_file:
join(
ps,
args.checkpoint_name,
args.load_metas_file,
req_func,
args.inference_parallel_size,
args.endpoint,
args.uds,
)
else:
if os.path.exists(
os.path.join(args.checkpoint_path, "model.safetensors.index.json")
):
named_tensors = split_tensors(args.checkpoint_path, rank, world_size)
checkpoint_files = []
else:
checkpoint_files = split_checkpoint_files(
args.checkpoint_path, rank, world_size
)
named_tensors = {}
update_weights(
ps,
args.checkpoint_name,
checkpoint_files,
named_tensors,
req_func,
args.inference_parallel_size,
args.endpoint,
args.save_metas_file,
args.update_method,
args.uds,
)
time.sleep(args.sleep_time)
......@@ -89,6 +89,7 @@ test = [
"sentence_transformers",
"tabulate",
]
checkpoint-engine = ["checkpoint-engine==0.1.2"]
all = []
dev = ["sglang[test]"]
......
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Checkpoint-engine integration for SGLang.
This module provides weight update functionality via IPC for checkpoint-engine compatibility.
"""
import logging
from typing import Callable, Dict, Optional
import torch
import zmq
try:
    from checkpoint_engine.worker import update_weights_from_ipc
except ImportError as e:
    raise ImportError(
        "checkpoint-engine is not installed. "
        'Please install it with: pip install "sglang[checkpoint-engine]"'
    ) from e
logger = logging.getLogger(__name__)
class SGLangCheckpointEngineWorkerExtension:
"""
Worker extension for SGLang to support checkpoint-engine IPC weight updates.
This class provides the interface needed for checkpoint-engine integration.
"""
def __init__(self):
self._zmq_ctx: Optional[zmq.Context] = None
def get_device_uuid(self) -> str:
"""Get the UUID of current device."""
        # Overridden by the SGLang integration
        # (see SGLangCheckpointEngineWorkerExtensionImpl below).
raise NotImplementedError(
"This method should be overridden by SGLang integration"
)
def get_device_id(self) -> int:
"""Get the device ID."""
raise NotImplementedError(
"This method should be overridden by SGLang integration"
)
def get_model_loader(self) -> Callable:
"""Get the model weight loader function."""
raise NotImplementedError(
"This method should be overridden by SGLang integration"
)
def get_post_hook(self) -> Optional[Callable]:
"""Get the post-processing hook after weight loading."""
return None
def update_weights_from_ipc(self, zmq_handles: Dict[str, str]):
"""
Update weights from IPC communication.
Args:
zmq_handles: Dict mapping device UUID to ZMQ socket path
"""
if self._zmq_ctx is None:
self._zmq_ctx = zmq.Context()
device_uuid = self.get_device_uuid()
device_id = self.get_device_id()
if device_uuid not in zmq_handles:
raise ValueError(
f"Device UUID {device_uuid} not found in zmq_handles: {list(zmq_handles.keys())}"
)
update_weights_from_ipc(
self._zmq_ctx,
zmq_handles[device_uuid],
device_id=device_id,
run=self.get_model_loader(),
post_hook=self.get_post_hook(),
)
class SGLangCheckpointEngineWorkerExtensionImpl(SGLangCheckpointEngineWorkerExtension):
"""
Implementation of SGLangCheckpointEngineWorkerExtension that integrates with SGLang's model runner.
This class provides the concrete implementation for checkpoint-engine IPC weight updates.
"""
def __init__(self, model_runner):
super().__init__()
self.model_runner = model_runner
def get_device_uuid(self) -> str:
"""Get the UUID of current device."""
# Get device UUID for current device
device_id = torch.cuda.current_device()
try:
return f"GPU-{torch.cuda.get_device_properties(device_id).uuid!s}"
except AssertionError as e:
raise ValueError(f"Failed to get GPU UUID for device {device_id}") from e
def get_device_id(self) -> int:
"""Get the device ID."""
return torch.cuda.current_device()
def get_model_loader(self) -> Callable:
"""Get the model weight loader function."""
return self.model_runner.model.load_weights
def get_post_hook(self) -> Optional[Callable]:
"""Get the post-processing hook after weight loading."""
def post_hook():
# Perform post-processing after weight loading similar to DefaultModelLoader
try:
from sglang.srt.model_loader.loader import device_loading_context
# Process quantization methods after loading weights
for _, module in self.model_runner.model.named_modules():
quant_method = getattr(module, "quant_method", None)
if quant_method is not None:
# Move parameters to device if needed for quantization processing
target_device = torch.device(
"cuda", torch.cuda.current_device()
)
with device_loading_context(module, target_device):
quant_method.process_weights_after_loading(module)
# Call model-specific post-loading hook if available
if hasattr(self.model_runner.model, "post_load_weights"):
self.model_runner.model.post_load_weights()
except Exception as e:
logger.warning(f"Post-hook processing failed: {e}")
return post_hook
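# A minimal usage sketch, mirroring ModelRunner.update_weights_from_ipc below
# (`model_runner` is assumed to be an initialized SGLang ModelRunner; the
# socket path is illustrative):
#
#   worker = SGLangCheckpointEngineWorkerExtensionImpl(model_runner)
#   worker.update_weights_from_ipc(
#       {worker.get_device_uuid(): "ipc:///tmp/ps.sock"}
#   )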
......@@ -59,6 +59,7 @@ from sglang.srt.managers.io_struct import (
UnloadLoRAAdapterReqInput,
UpdateWeightFromDiskReqInput,
UpdateWeightsFromDistributedReqInput,
UpdateWeightsFromIPCReqInput,
UpdateWeightsFromTensorReqInput,
)
from sglang.srt.managers.multi_tokenizer_mixin import MultiTokenizerRouter
......@@ -649,6 +650,21 @@ class Engine(EngineBase):
request=None,
)
def update_weights_from_ipc(
self,
zmq_handles: Dict[str, str],
flush_cache: bool = True,
):
"""Update weights from IPC for checkpoint-engine integration."""
obj = UpdateWeightsFromIPCReqInput(
zmq_handles=zmq_handles,
flush_cache=flush_cache,
)
loop = asyncio.get_event_loop()
return loop.run_until_complete(
self.tokenizer_manager.update_weights_from_ipc(obj, None)
)
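    # A minimal usage sketch for the offline Engine API (model path and socket
    # path are illustrative; load_format="dummy" matches the server example in
    # this PR):
    #
    #   import sglang as sgl
    #   engine = sgl.Engine(model_path="/workspace/Qwen/Qwen3-4B/", load_format="dummy")
    #   engine.update_weights_from_ipc({"GPU-<uuid>": "ipc:///tmp/ps.sock"})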
def _set_envs_and_config(server_args: ServerArgs):
# Set global environments
......
......@@ -96,6 +96,7 @@ from sglang.srt.managers.io_struct import (
UnloadLoRAAdapterReqInput,
UpdateWeightFromDiskReqInput,
UpdateWeightsFromDistributedReqInput,
UpdateWeightsFromIPCReqInput,
UpdateWeightsFromTensorReqInput,
UpdateWeightVersionReqInput,
VertexGenerateReqInput,
......@@ -129,6 +130,7 @@ logger = logging.getLogger(__name__)
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
HEALTH_CHECK_TIMEOUT = int(os.getenv("SGLANG_HEALTH_CHECK_TIMEOUT", 20))
WAIT_WEIGHTS_READY_TIMEOUT = int(os.getenv("SGLANG_WAIT_WEIGHTS_READY_TIMEOUT", 120))
# Store global states
......@@ -838,6 +840,27 @@ async def update_weights_from_distributed(
return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST)
@app.post("/update_weights_from_ipc")
async def update_weights_from_ipc(obj: UpdateWeightsFromIPCReqInput, request: Request):
"""Update the weights from IPC (Inter-Process Communication) for checkpoint-engine integration."""
success, message = await _global_state.tokenizer_manager.update_weights_from_ipc(
obj, request
)
# Update weight version if provided and weights update was successful
if success and obj.weight_version is not None:
_update_weight_version_if_provided(obj.weight_version)
message += f" Weight version updated to {obj.weight_version}."
content = {"success": success, "message": message}
    if success:
        # Mark initial weights as loaded so _wait_weights_ready() can proceed.
        _global_state.tokenizer_manager.initial_weights_loaded = True
        return ORJSONResponse(content)
else:
return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST)
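# A hedged client-side sketch for this endpoint (handles are illustrative; in
# practice checkpoint-engine's ParameterServer generates the zmq_handles, as
# in the update.py example above):
#
#   import httpx
#   resp = httpx.post(
#       "http://localhost:19730/update_weights_from_ipc",
#       json={
#           "zmq_handles": {"GPU-<uuid>": "ipc:///tmp/ps.sock"},
#           "flush_cache": True,
#           "weight_version": "iter-1",
#       },
#       timeout=300.0,
#   )
#   resp.raise_for_status()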
@app.post("/update_weight_version")
async def update_weight_version(obj: UpdateWeightVersionReqInput, request: Request):
"""Update the weight version. This operation requires no active requests."""
......@@ -1530,6 +1553,8 @@ def _wait_and_warmup(
pipe_finish_writer: Optional[multiprocessing.connection.Connection],
launch_callback: Optional[Callable[[], None]] = None,
):
if server_args.checkpoint_engine_wait_weights_before_ready:
_wait_weights_ready()
if not server_args.skip_server_warmup:
if not _execute_server_warmup(
server_args,
......@@ -1552,3 +1577,24 @@ def _wait_and_warmup(
if launch_callback is not None:
launch_callback()
def _wait_weights_ready():
"""Wait for weights to be ready within the specified timeout."""
timeout = WAIT_WEIGHTS_READY_TIMEOUT
start_time = time.time()
for _ in range(timeout):
if _global_state.tokenizer_manager.initial_weights_loaded:
logger.info(
f"Weights are ready after {time.time() - start_time:.2f} seconds"
)
return
time.sleep(1)
# Timeout reached without weights being ready
logger.error(
f"Weights are not ready after waiting {timeout} seconds. "
f"Consider increasing SGLANG_WAIT_WEIGHTS_READY_TIMEOUT environment variable. "
f"Current status: initial_weights_loaded={_global_state.tokenizer_manager.initial_weights_loaded}"
)
......@@ -1080,6 +1080,24 @@ class InitWeightsSendGroupForRemoteInstanceReqInput(BaseReq):
backend: str = "nccl"
# Currently, UpdateWeightsFromIPCReqInput and UpdateWeightsFromIPCReqOutput
# are only used by checkpoint-engine (https://github.com/MoonshotAI/checkpoint-engine).
@dataclass
class UpdateWeightsFromIPCReqInput(BaseReq):
# ZMQ socket paths for each device UUID
zmq_handles: Dict[str, str]
# Whether to flush cache after weight update
flush_cache: bool = True
# Optional: Update weight version along with weights
weight_version: Optional[str] = None
@dataclass
class UpdateWeightsFromIPCReqOutput(BaseReq):
success: bool
message: str
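# Construction sketch for the request input (socket path illustrative; BaseReq
# fields are assumed to have defaults):
#
#   req = UpdateWeightsFromIPCReqInput(
#       zmq_handles={"GPU-<uuid>": "ipc:///tmp/ps.sock"},
#       flush_cache=True,
#       weight_version="iter-1",
#   )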
@dataclass
class InitWeightsSendGroupForRemoteInstanceReqOutput(BaseReq):
success: bool
......
......@@ -109,6 +109,7 @@ from sglang.srt.managers.io_struct import (
UnloadLoRAAdapterReqOutput,
UpdateWeightFromDiskReqInput,
UpdateWeightsFromDistributedReqInput,
UpdateWeightsFromIPCReqInput,
UpdateWeightsFromTensorReqInput,
)
from sglang.srt.managers.mm_utils import init_embedding_cache
......@@ -530,6 +531,7 @@ class Scheduler(
self.update_weights_from_distributed,
),
(UpdateWeightsFromTensorReqInput, self.update_weights_from_tensor),
(UpdateWeightsFromIPCReqInput, self.update_weights_from_ipc),
(GetWeightsByNameReqInput, self.get_weights_by_name),
(ReleaseMemoryOccupationReqInput, self.release_memory_occupation),
(ResumeMemoryOccupationReqInput, self.resume_memory_occupation),
......
......@@ -21,6 +21,8 @@ from sglang.srt.managers.io_struct import (
UpdateWeightFromDiskReqOutput,
UpdateWeightsFromDistributedReqInput,
UpdateWeightsFromDistributedReqOutput,
UpdateWeightsFromIPCReqInput,
UpdateWeightsFromIPCReqOutput,
UpdateWeightsFromTensorReqInput,
UpdateWeightsFromTensorReqOutput,
)
......@@ -80,6 +82,18 @@ class SchedulerUpdateWeightsMixin:
torch.distributed.barrier(group=self.tp_cpu_group)
return UpdateWeightsFromTensorReqOutput(success, message)
def update_weights_from_ipc(self, recv_req: UpdateWeightsFromIPCReqInput):
"""Update the online model parameter from IPC for checkpoint-engine integration."""
success, message = self.tp_worker.update_weights_from_ipc(recv_req)
if success:
if recv_req.flush_cache:
flush_cache_success = self.flush_cache()
assert flush_cache_success, "Cache flush failed after updating weights"
else:
logger.error(message)
torch.distributed.barrier(group=self.tp_cpu_group)
return UpdateWeightsFromIPCReqOutput(success, message)
def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
parameter = self.tp_worker.get_weights_by_name(recv_req)
return GetWeightsByNameReqOutput(parameter)
......
......@@ -63,6 +63,8 @@ from sglang.srt.managers.io_struct import (
UnloadLoRAAdapterReqOutput,
UpdateWeightsFromDistributedReqInput,
UpdateWeightsFromDistributedReqOutput,
UpdateWeightsFromIPCReqInput,
UpdateWeightsFromIPCReqOutput,
UpdateWeightsFromTensorReqInput,
UpdateWeightsFromTensorReqOutput,
)
......@@ -169,6 +171,9 @@ class TokenizerCommunicatorMixin:
self.update_weights_from_tensor_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
self.update_weights_from_ipc_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
self.get_weights_by_name_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
......@@ -235,6 +240,10 @@ class TokenizerCommunicatorMixin:
UpdateWeightsFromTensorReqOutput,
self.update_weights_from_tensor_communicator.handle_recv,
),
(
UpdateWeightsFromIPCReqOutput,
self.update_weights_from_ipc_communicator.handle_recv,
),
(
GetWeightsByNameReqOutput,
self.get_weights_by_name_communicator.handle_recv,
......@@ -442,6 +451,28 @@ class TokenizerCommunicatorMixin:
result = (await self.update_weights_from_tensor_communicator(obj))[0]
return result.success, result.message
async def update_weights_from_ipc(
self,
obj: UpdateWeightsFromIPCReqInput,
request: Optional[fastapi.Request] = None,
) -> Tuple[bool, str]:
"""Update weights via IPC for checkpoint-engine integration."""
self.auto_create_handle_loop()
try:
            # Only dp_size == 1 (or DP attention) is supported for now.
            assert (
                self.server_args.dp_size == 1 or self.server_args.enable_dp_attention
            ), "dp_size must be 1, or DP attention must be enabled, to update weights from IPC"
logger.info("Starting IPC weight update")
            # Hold the writer lock so weight sync cannot run while requests are in progress.
            async with self.model_update_lock.writer_lock:
result = (await self.update_weights_from_ipc_communicator(obj))[0]
return result.success, result.message
except Exception as e:
error_msg = f"IPC weight update failed: {str(e)}"
logger.error(error_msg)
return False, error_msg
async def load_lora_adapter(
self: TokenizerManager,
obj: LoadLoRAAdapterReqInput,
......
......@@ -284,6 +284,11 @@ class TokenizerManager(TokenizerCommunicatorMixin):
self.gracefully_exit = False
self.last_receive_tstamp = 0
# Initial weights status
self.initial_weights_loaded = True
if server_args.checkpoint_engine_wait_weights_before_ready:
self.initial_weights_loaded = False
# Dumping
self.dump_requests_folder = "" # By default do not dump
self.dump_requests_threshold = 1000
......
......@@ -32,6 +32,7 @@ from sglang.srt.managers.io_struct import (
UnloadLoRAAdapterReqInput,
UpdateWeightFromDiskReqInput,
UpdateWeightsFromDistributedReqInput,
UpdateWeightsFromIPCReqInput,
UpdateWeightsFromTensorReqInput,
)
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
......@@ -164,6 +165,11 @@ class BaseTpWorker(ABC):
)
return success, message
def update_weights_from_ipc(self, recv_req: UpdateWeightsFromIPCReqInput):
"""Update weights from IPC for checkpoint-engine integration."""
success, message = self.model_runner.update_weights_from_ipc(recv_req)
return success, message
def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
parameter = self.model_runner.get_weights_by_name(
recv_req.name, recv_req.truncate_size
......
......@@ -2387,6 +2387,23 @@ class ModelRunner:
)
ShardedStateLoader.save_model(self.model, path, pattern, max_size)
def update_weights_from_ipc(self, recv_req):
"""Update weights from IPC for checkpoint-engine integration."""
try:
from sglang.srt.checkpoint_engine.checkpoint_engine_worker import (
SGLangCheckpointEngineWorkerExtensionImpl,
)
# Create a worker extension that integrates with SGLang's model
worker = SGLangCheckpointEngineWorkerExtensionImpl(self)
worker.update_weights_from_ipc(recv_req.zmq_handles)
return True, "IPC weight update completed successfully"
        except ImportError as e:
            return False, f"IPC weight update failed due to an import error: {e}"
except Exception as e:
logger.error(f"IPC weight update failed: {e}")
return False, str(e)
def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tensor]]):
params_dict = dict(model.named_parameters())
......
......@@ -208,6 +208,7 @@ class ServerArgs:
skip_server_warmup: bool = False
warmups: Optional[str] = None
nccl_port: Optional[int] = None
checkpoint_engine_wait_weights_before_ready: bool = False
# Quantization and data type
dtype: str = "auto"
......@@ -1704,6 +1705,12 @@ class ServerArgs:
default=ServerArgs.nccl_port,
help="The port for NCCL distributed environment setup. Defaults to a random port.",
)
parser.add_argument(
"--checkpoint-engine-wait-weights-before-ready",
action="store_true",
help="If set, the server will wait for initial weights to be loaded via checkpoint-engine or other update methods "
"before serving inference requests.",
)
# Quantization and data type
parser.add_argument(
......
......@@ -2275,6 +2275,11 @@ def launch_dummy_health_check_server(host, port, enable_metrics):
app = FastAPI()
@app.get("/ping")
async def ping():
"""Could be used by the checkpoint-engine update script to confirm the server is up."""
return Response(status_code=200)
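    # Readiness-probe sketch (the same check the update.py example performs in
    # check_sglang_ready):
    #
    #   import httpx
    #   httpx.get("http://localhost:19730/ping", timeout=10).raise_for_status()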
@app.get("/health")
async def health():
"""Check the health of the http server."""
......