sglang / Commits / cd10654e

Commit cd10654e (unverified), authored Aug 20, 2024 by Shan Yu, committed by GitHub on Aug 20, 2024
Parent: 350a8160

[Feat] Support update weights without restart server (#1157)
Showing 7 changed files with 303 additions and 13 deletions (+303 / -13):
    python/sglang/srt/managers/detokenizer_manager.py   +5    -0
    python/sglang/srt/managers/io_struct.py              +14   -0
    python/sglang/srt/managers/tokenizer_manager.py      +42   -3
    python/sglang/srt/managers/tp_worker.py              +17   -0
    python/sglang/srt/model_executor/model_runner.py     +97   -9
    python/sglang/srt/server.py                          +22   -1
    test/srt/test_update_weights.py                      +106  -0
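In short, the commit adds an /update_weights HTTP endpoint that streams a new checkpoint into the already-built model, so the checkpoint must match the running architecture. A minimal sketch of calling it, mirroring the requests made in test/srt/test_update_weights.py (the URL and model path are placeholders):

    import requests

    # Assumes a running sglang server; URL and model path are placeholders.
    resp = requests.post(
        "http://127.0.0.1:30000/update_weights",
        json={"model_path": "meta-llama/Meta-Llama-3.1-8B"},
    )
    print(resp.status_code)  # 200 if the swap succeeded, 400 otherwise
    print(resp.json())       # {"message": ..., "success": ...}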
python/sglang/srt/managers/detokenizer_manager.py

@@ -28,6 +28,7 @@ from sglang.srt.managers.io_struct import (
     BatchEmbeddingOut,
     BatchStrOut,
     BatchTokenIDOut,
+    UpdateWeightReqOutput,
 )
 from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR
 from sglang.srt.server_args import PortArgs, ServerArgs

@@ -84,6 +85,10 @@ class DetokenizerManager:
                 )
                 continue

+            if isinstance(recv_obj, UpdateWeightReqOutput):
+                self.send_to_tokenizer.send_pyobj(recv_obj)
+                continue
+
             assert isinstance(recv_obj, BatchTokenIDOut)
             bs = len(recv_obj.rids)
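The detokenizer sits between the TP workers and the tokenizer manager, so control replies such as UpdateWeightReqOutput are forwarded untouched while token batches continue to be decoded. A standalone sketch of that pass-through dispatch (plain queues stand in for the real ZMQ sockets; the class is a stand-in for the real dataclass):

    import queue

    class UpdateWeightReqOutput:  # stand-in for the real dataclass
        pass

    def forward_or_detokenize(recv_obj, send_to_tokenizer: queue.Queue):
        # Control replies pass through unchanged; only token batches are decoded.
        if isinstance(recv_obj, UpdateWeightReqOutput):
            send_to_tokenizer.put(recv_obj)
            return "forwarded"
        return "detokenized"  # a BatchTokenIDOut would be decoded here

    q = queue.Queue()
    print(forward_or_detokenize(UpdateWeightReqOutput(), q))  # forwarded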
python/sglang/srt/managers/io_struct.py

@@ -278,6 +278,20 @@ class FlushCacheReq:
     pass


+@dataclass
+class UpdateWeightReqInput:
+    # The model path with the new weights
+    model_path: str
+    # The format to load the weights
+    load_format: Optional[str] = None
+
+
+@dataclass
+class UpdateWeightReqOutput:
+    success: bool
+    message: str
+
+
 @dataclass
 class AbortReq:
     # The request id
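These two dataclasses are the request/response pair that travels through the whole pipeline. A small illustration of the defaulting behavior (the class is re-declared here so the snippet is self-contained):

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class UpdateWeightReqInput:  # re-declared for illustration
        model_path: str
        load_format: Optional[str] = None

    # A client only has to supply model_path; load_format stays None and is
    # later filled in from server_args.load_format by the TokenizerManager.
    req = UpdateWeightReqInput(model_path="meta-llama/Meta-Llama-3.1-8B")
    assert req.load_format is None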
python/sglang/srt/managers/tokenizer_manager.py

@@ -46,6 +46,8 @@ from sglang.srt.managers.io_struct import (
     GenerateReqInput,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
+    UpdateWeightReqInput,
+    UpdateWeightReqOutput,
 )
 from sglang.srt.mm_utils import expand2square, process_anyres_image
 from sglang.srt.sampling_params import SamplingParams

@@ -121,6 +123,10 @@ class TokenizerManager:
         self.to_create_loop = True
         self.rid_to_state: Dict[str, ReqState] = {}

+        # for update model weights
+        self.model_update_lock = asyncio.Lock()
+        self.model_update_result = None
+
     async def get_pixel_values(self, image_data):
         aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None)
         grid_pinpoints = (

@@ -146,6 +152,9 @@ class TokenizerManager:
         if self.to_create_loop:
             self.create_handle_loop()

+        while self.model_update_lock.locked():
+            await asyncio.sleep(0)
+
         obj.post_init()
         is_single = obj.is_single

@@ -513,6 +522,30 @@ class TokenizerManager:
         req = FlushCacheReq()
         self.send_to_router.send_pyobj(req)

+    async def update_weights(self, obj: UpdateWeightReqInput, request):
+        if self.to_create_loop:
+            self.create_handle_loop()
+
+        # default the load format to the server_args
+        if obj.load_format is None:
+            obj.load_format = self.server_args.load_format
+
+        if not self.model_update_lock.locked():
+            async with self.model_update_lock:
+                # wait for the previous generation requests to finish
+                while len(self.rid_to_state) > 0:
+                    await asyncio.sleep(0)
+                self.send_to_router.send_pyobj(obj)
+                self.model_update_result = asyncio.Future()
+                result = await self.model_update_result
+                if result.success:
+                    self.server_args.model_path = obj.model_path
+                    self.server_args.load_format = obj.load_format
+                    self.model_path = obj.model_path
+                return result.success, result.message
+        else:
+            return False, "Another update is in progress. Please try again later."
+
     def abort_request(self, rid: str):
         if rid not in self.rid_to_state:
             return

@@ -541,12 +574,18 @@ class TokenizerManager:
     async def handle_loop(self):
         while True:
-            recv_obj: Union[BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut] = (
-                await self.recv_from_detokenizer.recv_pyobj()
-            )
+            recv_obj: Union[
+                BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut, UpdateWeightReqOutput
+            ] = await self.recv_from_detokenizer.recv_pyobj()
+
+            if isinstance(recv_obj, UpdateWeightReqOutput):
+                self.model_update_result.set_result(recv_obj)
+                continue
+
+            assert isinstance(
+                recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)
+            ), f"Unexpected obj received: {type(recv_obj)}"

             for i, rid in enumerate(recv_obj.rids):
                 state = self.rid_to_state.get(rid, None)
                 if state is None:
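The coordination above rests on two asyncio primitives: a Lock that makes incoming /generate callers spin while an update is in flight, and a Future that handle_loop resolves once the worker's UpdateWeightReqOutput arrives. A self-contained sketch of the same pattern (names are illustrative, not the real classes):

    import asyncio

    class Coordinator:
        def __init__(self):
            self.model_update_lock = asyncio.Lock()
            self.model_update_result = None

        async def generate(self):
            # New generation requests wait until no update holds the lock.
            while self.model_update_lock.locked():
                await asyncio.sleep(0)
            return "generated"

        async def update(self):
            async with self.model_update_lock:
                self.model_update_result = asyncio.Future()
                # In sglang, the worker's reply resolves this future via handle_loop;
                # here a scheduled callback stands in for it.
                asyncio.get_running_loop().call_soon(
                    self.model_update_result.set_result, "ok"
                )
                return await self.model_update_result

    async def main():
        c = Coordinator()
        print(await asyncio.gather(c.update(), c.generate()))  # ['ok', 'generated']

    asyncio.run(main())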
python/sglang/srt/managers/tp_worker.py

@@ -39,6 +39,8 @@ from sglang.srt.managers.io_struct import (
     FlushCacheReq,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
+    UpdateWeightReqInput,
+    UpdateWeightReqOutput,
 )
 from sglang.srt.managers.policy_scheduler import PolicyScheduler, PrefillAdder
 from sglang.srt.managers.schedule_batch import (

@@ -214,6 +216,9 @@ class ModelTpServer:
                 self.flush_cache()
             elif isinstance(recv_req, AbortReq):
                 self.abort_request(recv_req)
+            elif isinstance(recv_req, UpdateWeightReqInput):
+                success, message = self.update_weights(recv_req)
+                self.out_pyobjs.append(UpdateWeightReqOutput(success, message))
             else:
                 raise ValueError(f"Invalid request: {recv_req}")

@@ -773,12 +778,15 @@ class ModelTpServer:
             self.token_to_kv_pool.clear()
             torch.cuda.empty_cache()
             logger.info("Cache flushed successfully!")
+            if_success = True
         else:
             logging.warning(
                 f"Cache not flushed because there are pending requests. "
                 f"#queue-req: {len(self.waiting_queue)}, "
                 f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
             )
+            if_success = False
+        return if_success

     def abort_request(self, recv_req):
         # Delete requests in the waiting queue

@@ -798,6 +806,15 @@ class ModelTpServer:
                     req.finished_reason = FINISH_ABORT()
                     break

+    def update_weights(self, recv_req):
+        success, message = self.model_runner.update_weights(
+            recv_req.model_path, recv_req.load_format
+        )
+        if success:
+            flash_cache_success = self.flush_cache()
+            assert flash_cache_success, "Cache flush failed after updating weights"
+        return success, message
+

 def run_tp_server(
     gpu_id: int,
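A condensed restatement of the worker-side flow, with the reasoning spelled out in comments (bodies elided; method names follow the diff):

    def update_weights(self, recv_req):
        # Swap the weights in place via ModelRunner.
        success, message = self.model_runner.update_weights(
            recv_req.model_path, recv_req.load_format
        )
        if success:
            # KV-cache entries computed under the old weights must not be
            # reused against the new ones, hence the mandatory flush. It can
            # only fail if requests are queued or running, which the
            # TokenizerManager rules out by draining rid_to_state first.
            assert self.flush_cache(), "Cache flush failed after updating weights"
        return success, message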
python/sglang/srt/model_executor/model_runner.py

@@ -15,6 +15,7 @@ limitations under the License.
 """ModelRunner runs the forward passes of the models."""

+import gc
 import importlib
 import importlib.resources
 import logging

@@ -157,9 +158,9 @@ class ModelRunner:
             self.server_args.dtype = "float16"

         monkey_patch_vllm_dummy_weight_loader()
-        device_config = DeviceConfig()
-        load_config = LoadConfig(load_format=self.server_args.load_format)
-        vllm_model_config = VllmModelConfig(
+        self.device_config = DeviceConfig()
+        self.load_config = LoadConfig(load_format=self.server_args.load_format)
+        self.vllm_model_config = VllmModelConfig(
             model=self.server_args.model_path,
             quantization=self.server_args.quantization,
             tokenizer=None,

@@ -173,17 +174,19 @@ class ModelRunner:
         if is_llama3_405b_fp8_head_16(self.model_config) and self.tp_size <= 8:
             # A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints
             self.model_config.hf_config.num_key_value_heads = 8
-            vllm_model_config.hf_config.num_key_value_heads = 8
+            self.vllm_model_config.hf_config.num_key_value_heads = 8
             monkey_patch_vllm_qvk_linear_loader()

-        self.dtype = vllm_model_config.dtype
+        self.dtype = self.vllm_model_config.dtype
         if self.model_config.model_overide_args is not None:
-            vllm_model_config.hf_config.update(self.model_config.model_overide_args)
+            self.vllm_model_config.hf_config.update(
+                self.model_config.model_overide_args
+            )

         self.model = get_model(
-            model_config=vllm_model_config,
-            device_config=device_config,
-            load_config=load_config,
+            model_config=self.vllm_model_config,
+            device_config=self.device_config,
+            load_config=self.load_config,
             lora_config=None,
             multimodal_config=None,
             parallel_config=None,

@@ -206,6 +209,91 @@ class ModelRunner:
             f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
         )

+    def update_weights(self, model_path, load_format):
+        from vllm.model_executor.model_loader.loader import (
+            DefaultModelLoader,
+            device_loading_context,
+            get_model_loader,
+        )
+        from vllm.model_executor.model_loader.utils import set_default_torch_dtype
+
+        logger.info(
+            f"[gpu={self.gpu_id}] Update weights begin. "
+            f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
+        )
+
+        target_device = torch.device(self.device_config.device)
+
+        try:
+            vllm_model_config = VllmModelConfig(
+                model=model_path,
+                quantization=self.server_args.quantization,
+                tokenizer=None,
+                tokenizer_mode=None,
+                trust_remote_code=self.server_args.trust_remote_code,
+                dtype=self.server_args.dtype,
+                seed=42,
+                skip_tokenizer_init=True,
+            )
+        except Exception as e:
+            logger.error(f"Failed to load model config: {e}")
+            return False, "Failed to update model weights"
+
+        load_config = LoadConfig(load_format=load_format)
+
+        # Only support vllm DefaultModelLoader for now
+        loader = get_model_loader(load_config)
+        if not isinstance(loader, DefaultModelLoader):
+            logger.error("Failed to get weights iterator: Unsupported loader")
+            return False, "Failed to update model weights"
+
+        def get_weight_iter(config):
+            iter = loader._get_weights_iterator(
+                config.model,
+                config.revision,
+                fall_back_to_pt=getattr(
+                    self.model, "fall_back_to_pt_during_load", True
+                ),
+            )
+            return iter
+
+        def model_load_weights(model, iter):
+            model.load_weights(iter)
+            for _, module in self.model.named_modules():
+                quant_method = getattr(module, "quant_method", None)
+                if quant_method is not None:
+                    with device_loading_context(module, target_device):
+                        quant_method.process_weights_after_loading(module)
+            return model
+
+        with set_default_torch_dtype(vllm_model_config.dtype):
+            try:
+                iter = get_weight_iter(vllm_model_config)
+            except Exception as e:
+                message = f"Failed to get weights iterator: {e}"
+                logger.error(message)
+                return False, message
+            try:
+                model = model_load_weights(self.model, iter)
+            except Exception as e:
+                message = f"Failed to update weights: {e}.\nRolling back to original weights"
+                logger.error(message)
+                del iter
+                gc.collect()
+                iter = get_weight_iter(self.vllm_model_config)
+                self.model = model_load_weights(self.model, iter)
+                return False, message
+
+        self.model = model
+        self.server_args.model_path = model_path
+        self.server_args.load_format = load_format
+        self.vllm_model_config = vllm_model_config
+        self.load_config = load_config
+        self.model_config.path = model_path
+
+        logger.info(f"[gpu={self.gpu_id}] Update weights end.")
+        return True, "Succeeded to update model weights"
+
     def profile_max_num_token(self, total_gpu_memory):
         available_gpu_memory = get_available_gpu_memory(
             self.gpu_id, distributed=self.tp_size > 1
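The failure handling above follows a load-or-rollback pattern: keep the old config around, try streaming the new checkpoint into the live module, and on any exception re-stream the original weights so the server stays serviceable. A generic sketch of the same pattern, independent of vLLM (get_iter and load are placeholder callables, not real APIs):

    import gc

    def swap_weights(model, new_cfg, old_cfg, get_iter, load):
        """Try to load new_cfg's weights into model; roll back to old_cfg on failure."""
        try:
            weight_iter = get_iter(new_cfg)
        except Exception as e:
            return False, f"Failed to get weights iterator: {e}"
        try:
            load(model, weight_iter)
        except Exception as e:
            # Free the partially consumed iterator before reloading the old weights.
            del weight_iter
            gc.collect()
            load(model, get_iter(old_cfg))
            return False, f"Failed to update weights: {e}. Rolled back."
        return True, "Succeeded to update model weights"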
python/sglang/srt/server.py

@@ -51,7 +51,11 @@ from sglang.srt.managers.controller_single import (
     start_controller_process as start_controller_process_single,
 )
 from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
-from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput
+from sglang.srt.managers.io_struct import (
+    EmbeddingReqInput,
+    GenerateReqInput,
+    UpdateWeightReqInput,
+)
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
 from sglang.srt.openai_api.adapter import (
     load_chat_template_for_openai_api,

@@ -136,6 +140,23 @@ async def flush_cache():
     )


+@app.post("/update_weights")
+async def update_weights(obj: UpdateWeightReqInput, request: Request):
+    success, message = await tokenizer_manager.update_weights(obj, request)
+    content = {"message": message, "success": str(success)}
+    if success:
+        return JSONResponse(
+            content,
+            status_code=HTTPStatus.OK,
+        )
+    else:
+        return JSONResponse(
+            content,
+            status_code=HTTPStatus.BAD_REQUEST,
+        )
+
+
 async def generate_request(obj: GenerateReqInput, request: Request):
     """Handle a generate request."""
     if obj.stream:
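Given the handler above, clients can branch on the HTTP status alone; the body carries the worker's message either way. Example response bodies (success is stringified by the handler; the messages are the ones returned by ModelRunner.update_weights):

    # HTTP 200
    ok_body = {"message": "Succeeded to update model weights", "success": "True"}
    # HTTP 400, e.g. when the new checkpoint cannot be loaded
    err_body = {"message": "Failed to update model weights", "success": "False"}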
test/srt/test_update_weights.py (new file, 0 → 100644)

import json
import unittest

import requests

from sglang.srt.utils import kill_child_process
from sglang.test.test_utils import (
    DEFAULT_MODEL_NAME_FOR_TEST,
    DEFAULT_URL_FOR_UNIT_TEST,
    popen_launch_server,
)


class TestReplaceWeights(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_UNIT_TEST
        cls.process = popen_launch_server(cls.model, cls.base_url, timeout=300)

    @classmethod
    def tearDownClass(cls):
        kill_child_process(cls.process.pid)

    def run_decode(self):
        response = requests.post(
            self.base_url + "/generate",
            json={
                "text": "The capital of France is",
                "sampling_params": {
                    "temperature": 0,
                    "max_new_tokens": 32,
                    "n": 1,
                },
                "stream": False,
                "return_logprob": False,
                "top_logprobs_num": 0,
                "return_text_in_logprobs": False,
                "logprob_start_len": 0,
            },
        )
        print(json.dumps(response.json()))
        print("=" * 100)
        # return the "text" in response
        text = response.json()["text"]
        return text

    def get_model_info(self):
        response = requests.get(self.base_url + "/get_model_info")
        model_path = response.json()["model_path"]
        print(json.dumps(response.json()))
        return model_path

    def run_update_weights(self, model_path):
        response = requests.post(
            self.base_url + "/update_weights",
            json={
                "model_path": model_path,
            },
        )
        print(json.dumps(response.json()))

    def test_replace_weights(self):
        origin_model_path = self.get_model_info()
        print(f"origin_model_path: {origin_model_path}")
        origin_response = self.run_decode()

        # update weights
        new_model_path = "meta-llama/Meta-Llama-3.1-8B"
        self.run_update_weights(new_model_path)
        updated_model_path = self.get_model_info()
        print(f"updated_model_path: {updated_model_path}")
        assert updated_model_path == new_model_path
        assert updated_model_path != origin_model_path

        updated_response = self.run_decode()
        assert origin_response[:32] != updated_response[:32]

        # update weights back
        self.run_update_weights(origin_model_path)
        updated_model_path = self.get_model_info()
        assert updated_model_path == origin_model_path

        updated_response = self.run_decode()
        assert origin_response[:32] == updated_response[:32]

    def test_replace_weights_unexist_model(self):
        origin_model_path = self.get_model_info()
        print(f"origin_model_path: {origin_model_path}")
        origin_response = self.run_decode()

        # update weights
        new_model_path = "meta-llama/Meta-Llama-3.1-8B-1"
        self.run_update_weights(new_model_path)
        updated_model_path = self.get_model_info()
        print(f"updated_model_path: {updated_model_path}")
        assert updated_model_path == origin_model_path

        updated_response = self.run_decode()
        assert origin_response[:32] == updated_response[:32]


if __name__ == "__main__":
    unittest.main()