Unverified Commit 7d5d1d3d authored by Chayenne's avatar Chayenne Committed by GitHub
Browse files

udate weights from disk (#2265)

parent b53d6cbd
......@@ -82,7 +82,8 @@
"Get the information of the model.\n",
"\n",
"- `model_path`: The path/name of the model.\n",
"- `is_generation`: Whether the model is used as generation model or embedding model."
"- `is_generation`: Whether the model is used as generation model or embedding model.\n",
"- `tokenizer_path`: The path/name of the tokenizer."
]
},
{
......@@ -98,7 +99,8 @@
"print_highlight(response_json)\n",
"assert response_json[\"model_path\"] == \"meta-llama/Llama-3.2-1B-Instruct\"\n",
"assert response_json[\"is_generation\"] is True\n",
"assert response_json.keys() == {\"model_path\", \"is_generation\"}"
"assert response_json[\"tokenizer_path\"] == \"meta-llama/Llama-3.2-1B-Instruct\"\n",
"assert response_json.keys() == {\"model_path\", \"is_generation\", \"tokenizer_path\"}"
]
},
{
......@@ -187,9 +189,11 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## Update Weights\n",
"## Update Weights From Disk\n",
"\n",
"Update model weights without restarting the server. Use for continuous evaluation during training. Only applicable for models with the same architecture and parameter size."
"Update model weights from disk without restarting the server. Only applicable for models with the same architecture and parameter size.\n",
"\n",
"SGLang support `update_weights_from_disk` API for continuous evaluation during training (save checkpoint to disk and update weights from disk).\n"
]
},
{
......@@ -200,7 +204,7 @@
"source": [
"# successful update with same architecture and size\n",
"\n",
"url = \"http://localhost:30010/update_weights\"\n",
"url = \"http://localhost:30010/update_weights_from_disk\"\n",
"data = {\"model_path\": \"meta-llama/Llama-3.2-1B\"}\n",
"\n",
"response = requests.post(url, json=data)\n",
......@@ -218,7 +222,7 @@
"source": [
"# failed update with different parameter size\n",
"\n",
"url = \"http://localhost:30010/update_weights\"\n",
"url = \"http://localhost:30010/update_weights_from_disk\"\n",
"data = {\"model_path\": \"meta-llama/Llama-3.2-3B\"}\n",
"\n",
"response = requests.post(url, json=data)\n",
......
......@@ -352,7 +352,7 @@ class FlushCacheReq:
@dataclass
class UpdateWeightReqInput:
class UpdateWeightFromDiskReqInput:
# The model path with the new weights
model_path: str
# The format to load the weights
......@@ -360,7 +360,7 @@ class UpdateWeightReqInput:
@dataclass
class UpdateWeightReqOutput:
class UpdateWeightFromDiskReqOutput:
success: bool
message: str
......
......@@ -43,8 +43,8 @@ from sglang.srt.managers.io_struct import (
ProfileReq,
TokenizedEmbeddingReqInput,
TokenizedGenerateReqInput,
UpdateWeightReqInput,
UpdateWeightReqOutput,
UpdateWeightFromDiskReqInput,
UpdateWeightFromDiskReqOutput,
)
from sglang.srt.managers.schedule_batch import (
FINISH_ABORT,
......@@ -506,10 +506,10 @@ class Scheduler:
self.flush_cache()
elif isinstance(recv_req, AbortReq):
self.abort_request(recv_req)
elif isinstance(recv_req, UpdateWeightReqInput):
success, message = self.update_weights(recv_req)
elif isinstance(recv_req, UpdateWeightFromDiskReqInput):
success, message = self.update_weights_from_disk(recv_req)
self.send_to_tokenizer.send_pyobj(
UpdateWeightReqOutput(success, message)
UpdateWeightFromDiskReqOutput(success, message)
)
elif isinstance(recv_req, ProfileReq):
if recv_req == ProfileReq.START_PROFILE:
......@@ -1363,9 +1363,9 @@ class Scheduler:
req.to_abort = True
break
def update_weights(self, recv_req: UpdateWeightReqInput):
"""In-place update of the weights."""
success, message = self.tp_worker.update_weights(recv_req)
def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
"""In-place update of the weights from disk."""
success, message = self.tp_worker.update_weights_from_disk(recv_req)
if success:
flash_cache_success = self.flush_cache()
assert flash_cache_success, "Cache flush failed after updating weights"
......
......@@ -25,6 +25,7 @@ import uuid
from typing import Dict, List, Optional, Tuple, Union
import fastapi
import torch
import uvloop
import zmq
import zmq.asyncio
......@@ -50,8 +51,8 @@ from sglang.srt.managers.io_struct import (
ProfileReq,
TokenizedEmbeddingReqInput,
TokenizedGenerateReqInput,
UpdateWeightReqInput,
UpdateWeightReqOutput,
UpdateWeightFromDiskReqInput,
UpdateWeightFromDiskReqOutput,
)
from sglang.srt.metrics.collector import TokenizerMetricsCollector
from sglang.srt.sampling.sampling_params import SamplingParams
......@@ -405,8 +406,10 @@ class TokenizerManager:
req = ProfileReq.STOP_PROFILE
self.send_to_scheduler.send_pyobj(req)
async def update_weights(
self, obj: UpdateWeightReqInput, request: Optional[fastapi.Request] = None
async def update_weights_from_disk(
self,
obj: UpdateWeightFromDiskReqInput,
request: Optional[fastapi.Request] = None,
):
if self.to_create_loop:
self.create_handle_loop()
......@@ -520,10 +523,13 @@ class TokenizerManager:
while True:
recv_obj: Union[
BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut, UpdateWeightReqOutput
BatchStrOut,
BatchEmbeddingOut,
BatchTokenIDOut,
UpdateWeightFromDiskReqOutput,
] = await self.recv_from_detokenizer.recv_pyobj()
if isinstance(recv_obj, UpdateWeightReqOutput):
if isinstance(recv_obj, UpdateWeightFromDiskReqOutput):
if self.server_args.dp_size == 1:
self.model_update_result.set_result(recv_obj)
else: # self.server_args.dp_size > 1
......
......@@ -19,7 +19,7 @@ from typing import Optional
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.managers.io_struct import UpdateWeightReqInput
from sglang.srt.managers.io_struct import UpdateWeightFromDiskReqInput
from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_executor.model_runner import ModelRunner
......@@ -155,8 +155,8 @@ class TpModelWorker:
embeddings = logits_output.embeddings
return embeddings
def update_weights(self, recv_req: UpdateWeightReqInput):
success, message = self.model_runner.update_weights(
def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
success, message = self.model_runner.update_weights_from_disk(
recv_req.model_path, recv_req.load_format
)
return success, message
......@@ -23,7 +23,7 @@ from typing import Optional
import psutil
import torch
from sglang.srt.managers.io_struct import UpdateWeightReqInput
from sglang.srt.managers.io_struct import UpdateWeightFromDiskReqInput
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.server_args import ServerArgs
......@@ -204,8 +204,8 @@ class TpModelWorkerClient:
) % self.future_token_ids_limit
return None, future_next_token_ids
def update_weights(self, recv_req: UpdateWeightReqInput):
success, message = self.worker.update_weights(recv_req)
def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
success, message = self.worker.update_weights_from_disk(recv_req)
return success, message
def __delete__(self):
......
......@@ -20,10 +20,13 @@ import inspect
import json
import logging
import pkgutil
import time
from functools import lru_cache
from typing import Optional, Type
from tokenize import tabsize
from typing import Any, Optional, Type, Union
import torch
import torch.distributed as dist
import torch.nn as nn
from vllm.config import DeviceConfig, LoadConfig
from vllm.config import ModelConfig as VllmModelConfig
......@@ -319,8 +322,8 @@ class ModelRunner:
f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
)
def update_weights(self, model_path: str, load_format: str):
"""Update weights in-place."""
def update_weights_from_disk(self, model_path: str, load_format: str):
"""Update engine weights online from disk."""
from vllm.model_executor.model_loader.loader import (
DefaultModelLoader,
device_loading_context,
......@@ -329,7 +332,7 @@ class ModelRunner:
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
logger.info(
f"Update weights begin. "
f"Update engine weights online from disk begin. "
f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
)
......
......@@ -53,7 +53,7 @@ from sglang.srt.managers.io_struct import (
EmbeddingReqInput,
GenerateReqInput,
OpenSessionReqInput,
UpdateWeightReqInput,
UpdateWeightFromDiskReqInput,
)
from sglang.srt.managers.scheduler import run_scheduler_process
from sglang.srt.managers.tokenizer_manager import TokenizerManager
......@@ -192,11 +192,11 @@ async def stop_profile_async():
)
@app.post("/update_weights")
@app.post("/update_weights_from_disk")
@time_func_latency
async def update_weights(obj: UpdateWeightReqInput, request: Request):
"""Update the weights inplace without re-launching the server."""
success, message = await tokenizer_manager.update_weights(obj, request)
async def update_weights_from_disk(obj: UpdateWeightFromDiskReqInput, request: Request):
"""Update the weights from disk inplace without re-launching the server."""
success, message = await tokenizer_manager.update_weights_from_disk(obj, request)
content = {"success": success, "message": message}
if success:
return ORJSONResponse(
......
......@@ -424,6 +424,7 @@ def popen_launch_server(
port,
*other_args,
]
if api_key:
command += ["--api-key", api_key]
......
......@@ -44,7 +44,7 @@ class TestDataParallelism(unittest.TestCase):
def test_update_weight(self):
response = requests.post(
self.base_url + "/update_weights",
self.base_url + "/update_weights_from_disk",
json={"model_path": DEFAULT_MODEL_NAME_FOR_TEST},
)
......@@ -55,7 +55,7 @@ class TestDataParallelism(unittest.TestCase):
time.sleep(5)
response = requests.post(
self.base_url + "/update_weights",
self.base_url + "/update_weights_from_disk",
json={"model_path": DEFAULT_MODEL_NAME_FOR_TEST},
)
......
......@@ -49,7 +49,7 @@ class TestUpdateWeights(unittest.TestCase):
def run_update_weights(self, model_path):
response = requests.post(
self.base_url + "/update_weights",
self.base_url + "/update_weights_from_disk",
json={
"model_path": model_path,
},
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment