Unverified Commit 7d5d1d3d authored by Chayenne's avatar Chayenne Committed by GitHub
Browse files

update weights from disk (#2265)

parent b53d6cbd
...@@ -82,7 +82,8 @@ ...@@ -82,7 +82,8 @@
"Get the information of the model.\n", "Get the information of the model.\n",
"\n", "\n",
"- `model_path`: The path/name of the model.\n", "- `model_path`: The path/name of the model.\n",
"- `is_generation`: Whether the model is used as generation model or embedding model." "- `is_generation`: Whether the model is used as generation model or embedding model.\n",
"- `tokenizer_path`: The path/name of the tokenizer."
] ]
}, },
{ {
...@@ -98,7 +99,8 @@ ...@@ -98,7 +99,8 @@
"print_highlight(response_json)\n", "print_highlight(response_json)\n",
"assert response_json[\"model_path\"] == \"meta-llama/Llama-3.2-1B-Instruct\"\n", "assert response_json[\"model_path\"] == \"meta-llama/Llama-3.2-1B-Instruct\"\n",
"assert response_json[\"is_generation\"] is True\n", "assert response_json[\"is_generation\"] is True\n",
"assert response_json.keys() == {\"model_path\", \"is_generation\"}" "assert response_json[\"tokenizer_path\"] == \"meta-llama/Llama-3.2-1B-Instruct\"\n",
"assert response_json.keys() == {\"model_path\", \"is_generation\", \"tokenizer_path\"}"
] ]
}, },
{ {
...@@ -187,9 +189,11 @@ ...@@ -187,9 +189,11 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"## Update Weights\n", "## Update Weights From Disk\n",
"\n", "\n",
"Update model weights without restarting the server. Use for continuous evaluation during training. Only applicable for models with the same architecture and parameter size." "Update model weights from disk without restarting the server. Only applicable for models with the same architecture and parameter size.\n",
"\n",
"SGLang support `update_weights_from_disk` API for continuous evaluation during training (save checkpoint to disk and update weights from disk).\n"
] ]
}, },
{ {
...@@ -200,7 +204,7 @@ ...@@ -200,7 +204,7 @@
"source": [ "source": [
"# successful update with same architecture and size\n", "# successful update with same architecture and size\n",
"\n", "\n",
"url = \"http://localhost:30010/update_weights\"\n", "url = \"http://localhost:30010/update_weights_from_disk\"\n",
"data = {\"model_path\": \"meta-llama/Llama-3.2-1B\"}\n", "data = {\"model_path\": \"meta-llama/Llama-3.2-1B\"}\n",
"\n", "\n",
"response = requests.post(url, json=data)\n", "response = requests.post(url, json=data)\n",
...@@ -218,7 +222,7 @@ ...@@ -218,7 +222,7 @@
"source": [ "source": [
"# failed update with different parameter size\n", "# failed update with different parameter size\n",
"\n", "\n",
"url = \"http://localhost:30010/update_weights\"\n", "url = \"http://localhost:30010/update_weights_from_disk\"\n",
"data = {\"model_path\": \"meta-llama/Llama-3.2-3B\"}\n", "data = {\"model_path\": \"meta-llama/Llama-3.2-3B\"}\n",
"\n", "\n",
"response = requests.post(url, json=data)\n", "response = requests.post(url, json=data)\n",
......
...@@ -352,7 +352,7 @@ class FlushCacheReq: ...@@ -352,7 +352,7 @@ class FlushCacheReq:
@dataclass @dataclass
class UpdateWeightReqInput: class UpdateWeightFromDiskReqInput:
# The model path with the new weights # The model path with the new weights
model_path: str model_path: str
# The format to load the weights # The format to load the weights
...@@ -360,7 +360,7 @@ class UpdateWeightReqInput: ...@@ -360,7 +360,7 @@ class UpdateWeightReqInput:
@dataclass @dataclass
class UpdateWeightReqOutput: class UpdateWeightFromDiskReqOutput:
success: bool success: bool
message: str message: str
......
...@@ -43,8 +43,8 @@ from sglang.srt.managers.io_struct import ( ...@@ -43,8 +43,8 @@ from sglang.srt.managers.io_struct import (
ProfileReq, ProfileReq,
TokenizedEmbeddingReqInput, TokenizedEmbeddingReqInput,
TokenizedGenerateReqInput, TokenizedGenerateReqInput,
UpdateWeightReqInput, UpdateWeightFromDiskReqInput,
UpdateWeightReqOutput, UpdateWeightFromDiskReqOutput,
) )
from sglang.srt.managers.schedule_batch import ( from sglang.srt.managers.schedule_batch import (
FINISH_ABORT, FINISH_ABORT,
...@@ -506,10 +506,10 @@ class Scheduler: ...@@ -506,10 +506,10 @@ class Scheduler:
self.flush_cache() self.flush_cache()
elif isinstance(recv_req, AbortReq): elif isinstance(recv_req, AbortReq):
self.abort_request(recv_req) self.abort_request(recv_req)
elif isinstance(recv_req, UpdateWeightReqInput): elif isinstance(recv_req, UpdateWeightFromDiskReqInput):
success, message = self.update_weights(recv_req) success, message = self.update_weights_from_disk(recv_req)
self.send_to_tokenizer.send_pyobj( self.send_to_tokenizer.send_pyobj(
UpdateWeightReqOutput(success, message) UpdateWeightFromDiskReqOutput(success, message)
) )
elif isinstance(recv_req, ProfileReq): elif isinstance(recv_req, ProfileReq):
if recv_req == ProfileReq.START_PROFILE: if recv_req == ProfileReq.START_PROFILE:
...@@ -1363,9 +1363,9 @@ class Scheduler: ...@@ -1363,9 +1363,9 @@ class Scheduler:
req.to_abort = True req.to_abort = True
break break
def update_weights(self, recv_req: UpdateWeightReqInput): def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
"""In-place update of the weights.""" """In-place update of the weights from disk."""
success, message = self.tp_worker.update_weights(recv_req) success, message = self.tp_worker.update_weights_from_disk(recv_req)
if success: if success:
flash_cache_success = self.flush_cache() flash_cache_success = self.flush_cache()
assert flash_cache_success, "Cache flush failed after updating weights" assert flash_cache_success, "Cache flush failed after updating weights"
......
...@@ -25,6 +25,7 @@ import uuid ...@@ -25,6 +25,7 @@ import uuid
from typing import Dict, List, Optional, Tuple, Union from typing import Dict, List, Optional, Tuple, Union
import fastapi import fastapi
import torch
import uvloop import uvloop
import zmq import zmq
import zmq.asyncio import zmq.asyncio
...@@ -50,8 +51,8 @@ from sglang.srt.managers.io_struct import ( ...@@ -50,8 +51,8 @@ from sglang.srt.managers.io_struct import (
ProfileReq, ProfileReq,
TokenizedEmbeddingReqInput, TokenizedEmbeddingReqInput,
TokenizedGenerateReqInput, TokenizedGenerateReqInput,
UpdateWeightReqInput, UpdateWeightFromDiskReqInput,
UpdateWeightReqOutput, UpdateWeightFromDiskReqOutput,
) )
from sglang.srt.metrics.collector import TokenizerMetricsCollector from sglang.srt.metrics.collector import TokenizerMetricsCollector
from sglang.srt.sampling.sampling_params import SamplingParams from sglang.srt.sampling.sampling_params import SamplingParams
...@@ -405,8 +406,10 @@ class TokenizerManager: ...@@ -405,8 +406,10 @@ class TokenizerManager:
req = ProfileReq.STOP_PROFILE req = ProfileReq.STOP_PROFILE
self.send_to_scheduler.send_pyobj(req) self.send_to_scheduler.send_pyobj(req)
async def update_weights( async def update_weights_from_disk(
self, obj: UpdateWeightReqInput, request: Optional[fastapi.Request] = None self,
obj: UpdateWeightFromDiskReqInput,
request: Optional[fastapi.Request] = None,
): ):
if self.to_create_loop: if self.to_create_loop:
self.create_handle_loop() self.create_handle_loop()
...@@ -520,10 +523,13 @@ class TokenizerManager: ...@@ -520,10 +523,13 @@ class TokenizerManager:
while True: while True:
recv_obj: Union[ recv_obj: Union[
BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut, UpdateWeightReqOutput BatchStrOut,
BatchEmbeddingOut,
BatchTokenIDOut,
UpdateWeightFromDiskReqOutput,
] = await self.recv_from_detokenizer.recv_pyobj() ] = await self.recv_from_detokenizer.recv_pyobj()
if isinstance(recv_obj, UpdateWeightReqOutput): if isinstance(recv_obj, UpdateWeightFromDiskReqOutput):
if self.server_args.dp_size == 1: if self.server_args.dp_size == 1:
self.model_update_result.set_result(recv_obj) self.model_update_result.set_result(recv_obj)
else: # self.server_args.dp_size > 1 else: # self.server_args.dp_size > 1
......
...@@ -19,7 +19,7 @@ from typing import Optional ...@@ -19,7 +19,7 @@ from typing import Optional
from sglang.srt.configs.model_config import ModelConfig from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.managers.io_struct import UpdateWeightReqInput from sglang.srt.managers.io_struct import UpdateWeightFromDiskReqInput
from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.model_executor.model_runner import ModelRunner
...@@ -155,8 +155,8 @@ class TpModelWorker: ...@@ -155,8 +155,8 @@ class TpModelWorker:
embeddings = logits_output.embeddings embeddings = logits_output.embeddings
return embeddings return embeddings
def update_weights(self, recv_req: UpdateWeightReqInput): def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
success, message = self.model_runner.update_weights( success, message = self.model_runner.update_weights_from_disk(
recv_req.model_path, recv_req.load_format recv_req.model_path, recv_req.load_format
) )
return success, message return success, message
...@@ -23,7 +23,7 @@ from typing import Optional ...@@ -23,7 +23,7 @@ from typing import Optional
import psutil import psutil
import torch import torch
from sglang.srt.managers.io_struct import UpdateWeightReqInput from sglang.srt.managers.io_struct import UpdateWeightFromDiskReqInput
from sglang.srt.managers.schedule_batch import ModelWorkerBatch from sglang.srt.managers.schedule_batch import ModelWorkerBatch
from sglang.srt.managers.tp_worker import TpModelWorker from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.server_args import ServerArgs from sglang.srt.server_args import ServerArgs
...@@ -204,8 +204,8 @@ class TpModelWorkerClient: ...@@ -204,8 +204,8 @@ class TpModelWorkerClient:
) % self.future_token_ids_limit ) % self.future_token_ids_limit
return None, future_next_token_ids return None, future_next_token_ids
def update_weights(self, recv_req: UpdateWeightReqInput): def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
success, message = self.worker.update_weights(recv_req) success, message = self.worker.update_weights_from_disk(recv_req)
return success, message return success, message
def __delete__(self): def __delete__(self):
......
...@@ -20,10 +20,13 @@ import inspect ...@@ -20,10 +20,13 @@ import inspect
import json import json
import logging import logging
import pkgutil import pkgutil
import time
from functools import lru_cache from functools import lru_cache
from typing import Optional, Type from tokenize import tabsize
from typing import Any, Optional, Type, Union
import torch import torch
import torch.distributed as dist
import torch.nn as nn import torch.nn as nn
from vllm.config import DeviceConfig, LoadConfig from vllm.config import DeviceConfig, LoadConfig
from vllm.config import ModelConfig as VllmModelConfig from vllm.config import ModelConfig as VllmModelConfig
...@@ -319,8 +322,8 @@ class ModelRunner: ...@@ -319,8 +322,8 @@ class ModelRunner:
f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB" f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
) )
def update_weights(self, model_path: str, load_format: str): def update_weights_from_disk(self, model_path: str, load_format: str):
"""Update weights in-place.""" """Update engine weights online from disk."""
from vllm.model_executor.model_loader.loader import ( from vllm.model_executor.model_loader.loader import (
DefaultModelLoader, DefaultModelLoader,
device_loading_context, device_loading_context,
...@@ -329,7 +332,7 @@ class ModelRunner: ...@@ -329,7 +332,7 @@ class ModelRunner:
from vllm.model_executor.model_loader.utils import set_default_torch_dtype from vllm.model_executor.model_loader.utils import set_default_torch_dtype
logger.info( logger.info(
f"Update weights begin. " f"Update engine weights online from disk begin. "
f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB" f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
) )
......
...@@ -53,7 +53,7 @@ from sglang.srt.managers.io_struct import ( ...@@ -53,7 +53,7 @@ from sglang.srt.managers.io_struct import (
EmbeddingReqInput, EmbeddingReqInput,
GenerateReqInput, GenerateReqInput,
OpenSessionReqInput, OpenSessionReqInput,
UpdateWeightReqInput, UpdateWeightFromDiskReqInput,
) )
from sglang.srt.managers.scheduler import run_scheduler_process from sglang.srt.managers.scheduler import run_scheduler_process
from sglang.srt.managers.tokenizer_manager import TokenizerManager from sglang.srt.managers.tokenizer_manager import TokenizerManager
...@@ -192,11 +192,11 @@ async def stop_profile_async(): ...@@ -192,11 +192,11 @@ async def stop_profile_async():
) )
@app.post("/update_weights") @app.post("/update_weights_from_disk")
@time_func_latency @time_func_latency
async def update_weights(obj: UpdateWeightReqInput, request: Request): async def update_weights_from_disk(obj: UpdateWeightFromDiskReqInput, request: Request):
"""Update the weights inplace without re-launching the server.""" """Update the weights from disk inplace without re-launching the server."""
success, message = await tokenizer_manager.update_weights(obj, request) success, message = await tokenizer_manager.update_weights_from_disk(obj, request)
content = {"success": success, "message": message} content = {"success": success, "message": message}
if success: if success:
return ORJSONResponse( return ORJSONResponse(
......
...@@ -424,6 +424,7 @@ def popen_launch_server( ...@@ -424,6 +424,7 @@ def popen_launch_server(
port, port,
*other_args, *other_args,
] ]
if api_key: if api_key:
command += ["--api-key", api_key] command += ["--api-key", api_key]
......
...@@ -44,7 +44,7 @@ class TestDataParallelism(unittest.TestCase): ...@@ -44,7 +44,7 @@ class TestDataParallelism(unittest.TestCase):
def test_update_weight(self): def test_update_weight(self):
response = requests.post( response = requests.post(
self.base_url + "/update_weights", self.base_url + "/update_weights_from_disk",
json={"model_path": DEFAULT_MODEL_NAME_FOR_TEST}, json={"model_path": DEFAULT_MODEL_NAME_FOR_TEST},
) )
...@@ -55,7 +55,7 @@ class TestDataParallelism(unittest.TestCase): ...@@ -55,7 +55,7 @@ class TestDataParallelism(unittest.TestCase):
time.sleep(5) time.sleep(5)
response = requests.post( response = requests.post(
self.base_url + "/update_weights", self.base_url + "/update_weights_from_disk",
json={"model_path": DEFAULT_MODEL_NAME_FOR_TEST}, json={"model_path": DEFAULT_MODEL_NAME_FOR_TEST},
) )
......
...@@ -49,7 +49,7 @@ class TestUpdateWeights(unittest.TestCase): ...@@ -49,7 +49,7 @@ class TestUpdateWeights(unittest.TestCase):
def run_update_weights(self, model_path): def run_update_weights(self, model_path):
response = requests.post( response = requests.post(
self.base_url + "/update_weights", self.base_url + "/update_weights_from_disk",
json={ json={
"model_path": model_path, "model_path": model_path,
}, },
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment