Unverified commit 56a1b6e3, authored by William Arnold, committed by GitHub
Browse files

feat: Add SGLang /engine weight update endpoints (#6094)

parent 98eb6b7e
......@@ -250,6 +250,73 @@ class BaseWorkerHandler(BaseGenerativeHandler):
await self.engine.tokenizer_manager.stop_profile()
return {"status": "ok", "message": "Profiling stopped"}
async def update_weights_from_disk(self, body: dict) -> dict:
    """Reload model weights from disk without restarting the server.

    Args:
        body: Mapping expanded as keyword arguments into SGLang's
            ``UpdateWeightFromDiskReqInput``.

    Returns:
        A dict with the ``success``, ``message`` and ``num_paused_requests``
        values reported by the engine's tokenizer manager.
    """
    from sglang.srt.managers.io_struct import UpdateWeightFromDiskReqInput

    request_input = UpdateWeightFromDiskReqInput(**body)
    manager = self.engine.tokenizer_manager
    # The second positional argument (the originating request object) is
    # always passed as None by this handler.
    ok, detail, paused = await manager.update_weights_from_disk(
        request_input, None
    )
    return dict(success=ok, message=detail, num_paused_requests=paused)
async def update_weights_from_tensor(self, body: dict) -> dict:
    """Replace model weights from tensors without restarting the server.

    Args:
        body: Mapping expanded as keyword arguments into SGLang's
            ``UpdateWeightsFromTensorReqInput``.

    Returns:
        A dict with the ``success`` flag and ``message`` reported by the
        engine's tokenizer manager.
    """
    from sglang.srt.managers.io_struct import UpdateWeightsFromTensorReqInput

    request_input = UpdateWeightsFromTensorReqInput(**body)
    manager = self.engine.tokenizer_manager
    # No originating request object is available here, hence the None.
    ok, detail = await manager.update_weights_from_tensor(request_input, None)
    return dict(success=ok, message=detail)
async def update_weights_from_distributed(self, body: dict) -> dict:
    """Update model weights via distributed online synchronization.

    Args:
        body: Mapping expanded as keyword arguments into SGLang's
            ``UpdateWeightsFromDistributedReqInput``.

    Returns:
        A dict with the ``success`` flag and ``message`` reported by the
        engine's tokenizer manager.
    """
    from sglang.srt.managers.io_struct import UpdateWeightsFromDistributedReqInput

    request_input = UpdateWeightsFromDistributedReqInput(**body)
    manager = self.engine.tokenizer_manager
    # No originating request object is available here, hence the None.
    ok, detail = await manager.update_weights_from_distributed(request_input, None)
    return dict(success=ok, message=detail)
async def update_weights_from_ipc(self, body: dict) -> dict:
    """Update model weights over IPC for checkpoint-engine integration.

    Args:
        body: Mapping expanded as keyword arguments into SGLang's
            ``UpdateWeightsFromIPCReqInput``.

    Returns:
        A dict with the ``success`` flag and ``message`` reported by the
        engine's tokenizer manager.
    """
    from sglang.srt.managers.io_struct import UpdateWeightsFromIPCReqInput

    request_input = UpdateWeightsFromIPCReqInput(**body)
    manager = self.engine.tokenizer_manager
    ok, detail = await manager.update_weights_from_ipc(request_input, None)

    # After a successful update, flip the manager's initial_weights_loaded
    # flag to True if it is not already set.
    if ok:
        if not manager.initial_weights_loaded:
            manager.initial_weights_loaded = True

    return dict(success=ok, message=detail)
async def update_weight_version(self, body: dict) -> dict:
    """Set the active weight version label; model weights are left untouched.

    Args:
        body: Mapping expanded as keyword arguments into SGLang's
            ``UpdateWeightVersionReqInput``.

    Returns:
        A dict with ``success`` set to True, a human-readable ``message``,
        and the ``new_version`` that was stored.
    """
    from sglang.srt.managers.io_struct import UpdateWeightVersionReqInput

    request_input = UpdateWeightVersionReqInput(**body)
    abort_requested = request_input.abort_all_requests
    manager = self.engine.tokenizer_manager

    # Optionally abort every request before switching the version label.
    if abort_requested:
        manager.abort_request(abort_all=True)

    version = request_input.new_version
    manager.server_args.weight_version = version

    response = {"success": True}
    response["message"] = f"Weight version updated to {version}"
    response["new_version"] = version
    return response
def register_engine_routes(self, runtime) -> None:
"""Register all engine routes for this handler.
......@@ -264,6 +331,21 @@ class BaseWorkerHandler(BaseGenerativeHandler):
runtime.register_engine_route(
"resume_memory_occupation", self.resume_memory_occupation
)
runtime.register_engine_route(
"update_weights_from_disk", self.update_weights_from_disk
)
runtime.register_engine_route(
"update_weights_from_tensor", self.update_weights_from_tensor
)
runtime.register_engine_route(
"update_weights_from_distributed", self.update_weights_from_distributed
)
runtime.register_engine_route(
"update_weights_from_ipc", self.update_weights_from_ipc
)
runtime.register_engine_route(
"update_weight_version", self.update_weight_version
)
@abstractmethod
async def generate(self, request: Dict[str, Any], context: Context):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment