"tests/vscode:/vscode.git/clone" did not exist on "2bd85c059829c0939489e4dff3418420fbc52d30"
Unverified Commit cd10654e authored by Shan Yu, committed by GitHub

[Feat] Support update weights without restart server (#1157)

parent 350a8160
@@ -28,6 +28,7 @@ from sglang.srt.managers.io_struct import (
     BatchEmbeddingOut,
     BatchStrOut,
     BatchTokenIDOut,
+    UpdateWeightReqOutput,
 )
 from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR
 from sglang.srt.server_args import PortArgs, ServerArgs
@@ -84,6 +85,10 @@ class DetokenizerManager:
                 )
                 continue
 
+            if isinstance(recv_obj, UpdateWeightReqOutput):
+                self.send_to_tokenizer.send_pyobj(recv_obj)
+                continue
+
             assert isinstance(recv_obj, BatchTokenIDOut)
             bs = len(recv_obj.rids)
......
@@ -278,6 +278,20 @@ class FlushCacheReq:
     pass
 
 
+@dataclass
+class UpdateWeightReqInput:
+    # The model path with the new weights
+    model_path: str
+    # The format to load the weights
+    load_format: Optional[str] = None
+
+
+@dataclass
+class UpdateWeightReqOutput:
+    success: bool
+    message: str
+
+
 @dataclass
 class AbortReq:
     # The request id
......
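The two new structs above are plain dataclasses, so a control message can be constructed and inspected directly. A minimal sketch (the model path is only an example value; load_format is left unset so the server's existing load format is used):

from sglang.srt.managers.io_struct import UpdateWeightReqInput, UpdateWeightReqOutput

# load_format is optional; when it is None the TokenizerManager falls back to
# the load format the server was started with.
req = UpdateWeightReqInput(model_path="meta-llama/Meta-Llama-3.1-8B")
out = UpdateWeightReqOutput(success=True, message="Succeeded to update model weights")
print(req, out)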
@@ -46,6 +46,8 @@ from sglang.srt.managers.io_struct import (
     GenerateReqInput,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
+    UpdateWeightReqInput,
+    UpdateWeightReqOutput,
 )
 from sglang.srt.mm_utils import expand2square, process_anyres_image
 from sglang.srt.sampling_params import SamplingParams
@@ -121,6 +123,10 @@ class TokenizerManager:
         self.to_create_loop = True
         self.rid_to_state: Dict[str, ReqState] = {}
 
+        # for update model weights
+        self.model_update_lock = asyncio.Lock()
+        self.model_update_result = None
+
     async def get_pixel_values(self, image_data):
         aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None)
         grid_pinpoints = (
@@ -146,6 +152,9 @@ class TokenizerManager:
         if self.to_create_loop:
             self.create_handle_loop()
 
+        while self.model_update_lock.locked():
+            await asyncio.sleep(0)
+
         obj.post_init()
         is_single = obj.is_single
@@ -513,6 +522,30 @@ class TokenizerManager:
         req = FlushCacheReq()
         self.send_to_router.send_pyobj(req)
 
+    async def update_weights(self, obj: UpdateWeightReqInput, request):
+        if self.to_create_loop:
+            self.create_handle_loop()
+
+        # default the load format to the server_args
+        if obj.load_format is None:
+            obj.load_format = self.server_args.load_format
+
+        if not self.model_update_lock.locked():
+            async with self.model_update_lock:
+                # wait for the previous generation requests to finish
+                while len(self.rid_to_state) > 0:
+                    await asyncio.sleep(0)
+                self.send_to_router.send_pyobj(obj)
+                self.model_update_result = asyncio.Future()
+                result = await self.model_update_result
+                if result.success:
+                    self.server_args.model_path = obj.model_path
+                    self.server_args.load_format = obj.load_format
+                    self.model_path = obj.model_path
+                return result.success, result.message
+        else:
+            return False, "Another update is in progress. Please try again later."
+
     def abort_request(self, rid: str):
         if rid not in self.rid_to_state:
             return
@@ -541,12 +574,18 @@ class TokenizerManager:
     async def handle_loop(self):
         while True:
-            recv_obj: Union[BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut] = (
-                await self.recv_from_detokenizer.recv_pyobj()
-            )
+            recv_obj: Union[
+                BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut, UpdateWeightReqOutput
+            ] = await self.recv_from_detokenizer.recv_pyobj()
+
+            if isinstance(recv_obj, UpdateWeightReqOutput):
+                self.model_update_result.set_result(recv_obj)
+                continue
+
             assert isinstance(
                 recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)
             ), f"Unexpected obj received: {type(recv_obj)}"
             for i, rid in enumerate(recv_obj.rids):
                 state = self.rid_to_state.get(rid, None)
                 if state is None:
......
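The TokenizerManager changes above implement a simple handshake: an asyncio.Lock blocks new generation requests while an update is in flight, the update waits for in-flight requests to drain, and an asyncio.Future is resolved by handle_loop once the worker reports back. A minimal, self-contained sketch of that pattern (UpdateCoordinator, inflight, and deliver are illustrative names, not part of this patch):

import asyncio


class UpdateCoordinator:
    def __init__(self):
        self.update_lock = asyncio.Lock()  # mirrors model_update_lock
        self.update_result = None          # becomes a Future while an update is pending
        self.inflight = 0                  # stands in for len(self.rid_to_state)

    async def generate(self):
        # New requests spin until no update holds the lock.
        while self.update_lock.locked():
            await asyncio.sleep(0)
        self.inflight += 1
        await asyncio.sleep(0.01)          # pretend to decode
        self.inflight -= 1
        return "generated"

    async def update(self, deliver):
        if self.update_lock.locked():
            return False, "Another update is in progress. Please try again later."
        async with self.update_lock:
            while self.inflight > 0:       # wait for in-flight requests to finish
                await asyncio.sleep(0)
            self.update_result = asyncio.Future()
            deliver(self.update_result)    # in SGLang this is the ZMQ round trip to the worker
            result = await self.update_result
            return result, "done"


async def main():
    c = UpdateCoordinator()
    out = await asyncio.gather(
        c.generate(),
        c.update(lambda fut: fut.set_result(True)),
    )
    print(out)


asyncio.run(main())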
@@ -39,6 +39,8 @@ from sglang.srt.managers.io_struct import (
     FlushCacheReq,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
+    UpdateWeightReqInput,
+    UpdateWeightReqOutput,
 )
 from sglang.srt.managers.policy_scheduler import PolicyScheduler, PrefillAdder
 from sglang.srt.managers.schedule_batch import (
@@ -214,6 +216,9 @@ class ModelTpServer:
                 self.flush_cache()
             elif isinstance(recv_req, AbortReq):
                 self.abort_request(recv_req)
+            elif isinstance(recv_req, UpdateWeightReqInput):
+                success, message = self.update_weights(recv_req)
+                self.out_pyobjs.append(UpdateWeightReqOutput(success, message))
             else:
                 raise ValueError(f"Invalid request: {recv_req}")
@@ -773,12 +778,15 @@ class ModelTpServer:
             self.token_to_kv_pool.clear()
             torch.cuda.empty_cache()
             logger.info("Cache flushed successfully!")
+            if_success = True
         else:
             logging.warning(
                 f"Cache not flushed because there are pending requests. "
                 f"#queue-req: {len(self.waiting_queue)}, "
                 f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
             )
+            if_success = False
+        return if_success
 
     def abort_request(self, recv_req):
         # Delete requests in the waiting queue
@@ -798,6 +806,15 @@ class ModelTpServer:
                 req.finished_reason = FINISH_ABORT()
                 break
 
+    def update_weights(self, recv_req):
+        success, message = self.model_runner.update_weights(
+            recv_req.model_path, recv_req.load_format
+        )
+        if success:
+            flash_cache_success = self.flush_cache()
+            assert flash_cache_success, "Cache flush failed after updating weights"
+        return success, message
+
 
 def run_tp_server(
     gpu_id: int,
......
@@ -15,6 +15,7 @@ limitations under the License.
 
 """ModelRunner runs the forward passes of the models."""
 
+import gc
 import importlib
 import importlib.resources
 import logging
@@ -157,9 +158,9 @@ class ModelRunner:
             self.server_args.dtype = "float16"
 
         monkey_patch_vllm_dummy_weight_loader()
-        device_config = DeviceConfig()
-        load_config = LoadConfig(load_format=self.server_args.load_format)
-        vllm_model_config = VllmModelConfig(
+        self.device_config = DeviceConfig()
+        self.load_config = LoadConfig(load_format=self.server_args.load_format)
+        self.vllm_model_config = VllmModelConfig(
             model=self.server_args.model_path,
             quantization=self.server_args.quantization,
             tokenizer=None,
@@ -173,17 +174,19 @@ class ModelRunner:
         if is_llama3_405b_fp8_head_16(self.model_config) and self.tp_size <= 8:
             # A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints
             self.model_config.hf_config.num_key_value_heads = 8
-            vllm_model_config.hf_config.num_key_value_heads = 8
+            self.vllm_model_config.hf_config.num_key_value_heads = 8
             monkey_patch_vllm_qvk_linear_loader()
 
-        self.dtype = vllm_model_config.dtype
+        self.dtype = self.vllm_model_config.dtype
         if self.model_config.model_overide_args is not None:
-            vllm_model_config.hf_config.update(self.model_config.model_overide_args)
+            self.vllm_model_config.hf_config.update(
+                self.model_config.model_overide_args
+            )
         self.model = get_model(
-            model_config=vllm_model_config,
-            device_config=device_config,
-            load_config=load_config,
+            model_config=self.vllm_model_config,
+            device_config=self.device_config,
+            load_config=self.load_config,
             lora_config=None,
             multimodal_config=None,
             parallel_config=None,
@@ -206,6 +209,91 @@ class ModelRunner:
             f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
         )
 
+    def update_weights(self, model_path, load_format):
+        from vllm.model_executor.model_loader.loader import (
+            DefaultModelLoader,
+            device_loading_context,
+            get_model_loader,
+        )
+        from vllm.model_executor.model_loader.utils import set_default_torch_dtype
+
+        logger.info(
+            f"[gpu={self.gpu_id}] Update weights begin. "
+            f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
+        )
+
+        target_device = torch.device(self.device_config.device)
+
+        try:
+            vllm_model_config = VllmModelConfig(
+                model=model_path,
+                quantization=self.server_args.quantization,
+                tokenizer=None,
+                tokenizer_mode=None,
+                trust_remote_code=self.server_args.trust_remote_code,
+                dtype=self.server_args.dtype,
+                seed=42,
+                skip_tokenizer_init=True,
+            )
+        except Exception as e:
+            logger.error(f"Failed to load model config: {e}")
+            return False, "Failed to update model weights"
+
+        load_config = LoadConfig(load_format=load_format)
+
+        # Only support vllm DefaultModelLoader for now
+        loader = get_model_loader(load_config)
+        if not isinstance(loader, DefaultModelLoader):
+            logger.error("Failed to get weights iterator: Unsupported loader")
+            return False, "Failed to update model weights"
+
+        def get_weight_iter(config):
+            iter = loader._get_weights_iterator(
+                config.model,
+                config.revision,
+                fall_back_to_pt=getattr(
+                    self.model, "fall_back_to_pt_during_load", True
+                ),
+            )
+            return iter
+
+        def model_load_weights(model, iter):
+            model.load_weights(iter)
+            for _, module in self.model.named_modules():
+                quant_method = getattr(module, "quant_method", None)
+                if quant_method is not None:
+                    with device_loading_context(module, target_device):
+                        quant_method.process_weights_after_loading(module)
+            return model
+
+        with set_default_torch_dtype(vllm_model_config.dtype):
+            try:
+                iter = get_weight_iter(vllm_model_config)
+            except Exception as e:
+                message = f"Failed to get weights iterator: {e}"
+                logger.error(message)
+                return False, message
+            try:
+                model = model_load_weights(self.model, iter)
+            except Exception as e:
+                message = f"Failed to update weights: {e}. \n Rolling back to original weights"
+                logger.error(message)
+                del iter
+                gc.collect()
+                iter = get_weight_iter(self.vllm_model_config)
+                self.model = model_load_weights(self.model, iter)
+                return False, message
+
+        self.model = model
+        self.server_args.model_path = model_path
+        self.server_args.load_format = load_format
+        self.vllm_model_config = vllm_model_config
+        self.load_config = load_config
+        self.model_config.path = model_path
+
+        logger.info(f"[gpu={self.gpu_id}] Update weights end.")
+        return True, "Succeeded to update model weights"
+
     def profile_max_num_token(self, total_gpu_memory):
         available_gpu_memory = get_available_gpu_memory(
             self.gpu_id, distributed=self.tp_size > 1
......
@@ -51,7 +51,11 @@ from sglang.srt.managers.controller_single import (
     start_controller_process as start_controller_process_single,
 )
 from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
-from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput
+from sglang.srt.managers.io_struct import (
+    EmbeddingReqInput,
+    GenerateReqInput,
+    UpdateWeightReqInput,
+)
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
 from sglang.srt.openai_api.adapter import (
     load_chat_template_for_openai_api,
@@ -136,6 +140,23 @@ async def flush_cache():
     )
 
 
+@app.post("/update_weights")
+async def update_weights(obj: UpdateWeightReqInput, request: Request):
+    success, message = await tokenizer_manager.update_weights(obj, request)
+    content = {"message": message, "success": str(success)}
+    if success:
+        return JSONResponse(
+            content,
+            status_code=HTTPStatus.OK,
+        )
+    else:
+        return JSONResponse(
+            content,
+            status_code=HTTPStatus.BAD_REQUEST,
+        )
+
+
 async def generate_request(obj: GenerateReqInput, request: Request):
     """Handle a generate request."""
     if obj.stream:
......
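With the endpoint in place, a running server can be asked to hot-swap weights with a plain HTTP POST, exactly as the test below does. A minimal sketch assuming a local server (the host, port, and model path are placeholder values):

import requests

# The JSON body mirrors UpdateWeightReqInput: model_path is required,
# load_format is optional and defaults to the server's current load format.
response = requests.post(
    "http://localhost:30000/update_weights",
    json={"model_path": "meta-llama/Meta-Llama-3.1-8B"},
)
print(response.status_code)  # 200 on success, 400 if the update was rejected or failed
print(response.json())       # {"message": ..., "success": "True" or "False"}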
import json
import unittest

import requests

from sglang.srt.utils import kill_child_process
from sglang.test.test_utils import (
    DEFAULT_MODEL_NAME_FOR_TEST,
    DEFAULT_URL_FOR_UNIT_TEST,
    popen_launch_server,
)


class TestReplaceWeights(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_UNIT_TEST
        cls.process = popen_launch_server(cls.model, cls.base_url, timeout=300)

    @classmethod
    def tearDownClass(cls):
        kill_child_process(cls.process.pid)

    def run_decode(self):
        response = requests.post(
            self.base_url + "/generate",
            json={
                "text": "The capital of France is",
                "sampling_params": {
                    "temperature": 0,
                    "max_new_tokens": 32,
                    "n": 1,
                },
                "stream": False,
                "return_logprob": False,
                "top_logprobs_num": 0,
                "return_text_in_logprobs": False,
                "logprob_start_len": 0,
            },
        )
        print(json.dumps(response.json()))
        print("=" * 100)
        # return the "text" in response
        text = response.json()["text"]
        return text

    def get_model_info(self):
        response = requests.get(self.base_url + "/get_model_info")
        model_path = response.json()["model_path"]
        print(json.dumps(response.json()))
        return model_path

    def run_update_weights(self, model_path):
        response = requests.post(
            self.base_url + "/update_weights",
            json={
                "model_path": model_path,
            },
        )
        print(json.dumps(response.json()))

    def test_replace_weights(self):
        origin_model_path = self.get_model_info()
        print(f"origin_model_path: {origin_model_path}")
        origin_response = self.run_decode()

        # update weights
        new_model_path = "meta-llama/Meta-Llama-3.1-8B"
        self.run_update_weights(new_model_path)
        updated_model_path = self.get_model_info()
        print(f"updated_model_path: {updated_model_path}")
        assert updated_model_path == new_model_path
        assert updated_model_path != origin_model_path

        updated_response = self.run_decode()
        assert origin_response[:32] != updated_response[:32]

        # update weights back
        self.run_update_weights(origin_model_path)
        updated_model_path = self.get_model_info()
        assert updated_model_path == origin_model_path

        updated_response = self.run_decode()
        assert origin_response[:32] == updated_response[:32]

    def test_replace_weights_unexist_model(self):
        origin_model_path = self.get_model_info()
        print(f"origin_model_path: {origin_model_path}")
        origin_response = self.run_decode()

        # update weights
        new_model_path = "meta-llama/Meta-Llama-3.1-8B-1"
        self.run_update_weights(new_model_path)
        updated_model_path = self.get_model_info()
        print(f"updated_model_path: {updated_model_path}")
        assert updated_model_path == origin_model_path

        updated_response = self.run_decode()
        assert origin_response[:32] == updated_response[:32]


if __name__ == "__main__":
    unittest.main()