sglang · Commits · 70b3c6ee

Commit 70b3c6ee (unverified), authored Mar 05, 2025 by Jhin, committed by GitHub on Mar 05, 2025.

Add update_weights_from_disk endpoint to Engine (#4102)

Co-authored-by: zhaochenyang20 <zhaochen20@outlook.com>

Parent: ef9d3b3c

Showing 3 changed files with 239 additions and 31 deletions (+239 / -31):

- python/sglang/srt/entrypoints/engine.py (+22 / -0)
- test/srt/test_update_weights_from_disk.py (+214 / -30)
- test/srt/test_update_weights_from_distributed.py (+3 / -1)
python/sglang/srt/entrypoints/engine.py (view file @ 70b3c6ee)

@@ -44,6 +44,7 @@ from sglang.srt.managers.io_struct import (
    InitWeightsUpdateGroupReqInput,
    ReleaseMemoryOccupationReqInput,
    ResumeMemoryOccupationReqInput,
    UpdateWeightFromDiskReqInput,
    UpdateWeightsFromDistributedReqInput,
    UpdateWeightsFromTensorReqInput,
)

@@ -302,6 +303,27 @@ class Engine:
            self.tokenizer_manager.update_weights_from_tensor(obj, None)
        )

    def update_weights_from_disk(
        self,
        model_path: str,
        load_format: Optional[str] = None,
    ):
        """Update the weights from disk inplace without re-launching the engine.

        This method allows updating the model weights from disk without restarting
        the engine. It can be used to load a different model or update weights with
        new training.
        """
        obj = UpdateWeightFromDiskReqInput(
            model_path=model_path,
            load_format=load_format,
        )
        loop = asyncio.get_event_loop()
        return loop.run_until_complete(
            self.tokenizer_manager.update_weights_from_disk(obj, None)
        )

    def get_weights_by_name(self, name: str, truncate_size: int = 100):
        """Get weights by parameter name."""
        obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size)
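For context, a minimal offline usage sketch of the new Engine.update_weights_from_disk method, pieced together from the signature above and the engine-mode tests below. The model paths are placeholders, and the return value is treated as a tuple whose first element is the success flag, as the tests assume.

# Hypothetical offline example; model paths are placeholders.
import sglang as sgl

engine = sgl.Engine(model_path="meta-llama/Llama-3.1-8B-Instruct")

# Decode with the original weights.
out = engine.generate(["The capital of France is"], {"temperature": 0, "max_new_tokens": 32})
print(out[0]["text"])

# Hot-swap different weights from disk without re-launching the engine.
ret = engine.update_weights_from_disk("meta-llama/Llama-3.1-8B")
assert ret[0], "weight update from disk failed"

# Decode again with the updated weights.
out = engine.generate(["The capital of France is"], {"temperature": 0, "max_new_tokens": 32})
print(out[0]["text"])

engine.shutdown()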
test/srt/test_update_weights_from_disk.py (view file @ 70b3c6ee)

import json
import random
import unittest

import requests

import sglang as sgl
from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    is_in_ci,
    popen_launch_server,
)


###############################################################################
# Engine Mode Tests (Single-configuration)
###############################################################################
class TestEngineUpdateWeightsFromDisk(unittest.TestCase):
    def setUp(self):
        self.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
        # Initialize the engine in offline (direct) mode.
        self.engine = sgl.Engine(model_path=self.model)

    def tearDown(self):
        self.engine.shutdown()

    def run_decode(self):
        prompts = ["The capital of France is"]
        sampling_params = {"temperature": 0, "max_new_tokens": 32}
        outputs = self.engine.generate(prompts, sampling_params)
        print("=" * 100)
        print(
            f"[Engine Mode] Prompt: {prompts[0]}\nGenerated text: {outputs[0]['text']}"
        )
        return outputs[0]["text"]

    def run_update_weights(self, model_path):
        ret = self.engine.update_weights_from_disk(model_path)
        print(json.dumps(ret))
        return ret

    def test_update_weights(self):
        origin_response = self.run_decode()
        # Update weights: use new model (remove "-Instruct")
        new_model_path = self.model.replace("-Instruct", "")
        ret = self.run_update_weights(new_model_path)
        self.assertTrue(ret[0])  # ret is a tuple; index 0 holds the success flag
        updated_response = self.run_decode()
        self.assertNotEqual(origin_response[:32], updated_response[:32])
        # Revert back to original weights
        ret = self.run_update_weights(self.model)
        self.assertTrue(ret[0])
        reverted_response = self.run_decode()
        self.assertEqual(origin_response[:32], reverted_response[:32])

    def test_update_weights_unexist_model(self):
        origin_response = self.run_decode()
        new_model_path = self.model.replace("-Instruct", "wrong")
        ret = self.run_update_weights(new_model_path)
        self.assertFalse(ret[0])
        updated_response = self.run_decode()
        self.assertEqual(origin_response[:32], updated_response[:32])


###############################################################################
# HTTP Server Mode Tests (Single-configuration)
###############################################################################
class TestServerUpdateWeightsFromDisk(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
@@ -30,16 +88,12 @@ class TestUpdateWeights(unittest.TestCase):
            self.base_url + "/generate",
            json={
                "text": "The capital of France is",
                "sampling_params": {"temperature": 0, "max_new_tokens": 32},
            },
        )
        print(json.dumps(response.json()))
        print("=" * 100)
        print(f"[Server Mode] Generated text: {response.json()['text']}")
        return response.json()["text"]

    def get_model_info(self):
        response = requests.get(self.base_url + "/get_model_info")
@@ -50,58 +104,188 @@ class TestUpdateWeights(unittest.TestCase):
    def run_update_weights(self, model_path):
        response = requests.post(
            self.base_url + "/update_weights_from_disk",
            json={"model_path": model_path},
        )
        ret = response.json()
        print(json.dumps(ret))
        return ret

    def test_update_weights(self):
        origin_model_path = self.get_model_info()
        print(f"[Server Mode] origin_model_path: {origin_model_path}")
        origin_response = self.run_decode()

        # update weights
        new_model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST.replace("-Instruct", "")
        ret = self.run_update_weights(new_model_path)
        self.assertTrue(ret["success"])

        updated_model_path = self.get_model_info()
        print(f"[Server Mode] updated_model_path: {updated_model_path}")
        self.assertEqual(updated_model_path, new_model_path)
        self.assertNotEqual(updated_model_path, origin_model_path)

        updated_response = self.run_decode()
        self.assertNotEqual(origin_response[:32], updated_response[:32])

        # update weights back
        ret = self.run_update_weights(origin_model_path)
        self.assertTrue(ret["success"])

        updated_model_path = self.get_model_info()
        self.assertEqual(updated_model_path, origin_model_path)

        updated_response = self.run_decode()
        self.assertEqual(origin_response[:32], updated_response[:32])

    def test_update_weights_unexist_model(self):
        origin_model_path = self.get_model_info()
        print(f"[Server Mode] origin_model_path: {origin_model_path}")
        origin_response = self.run_decode()

        # update weights
        new_model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST.replace("-Instruct", "wrong")
        ret = self.run_update_weights(new_model_path)
        self.assertFalse(ret["success"])

        updated_model_path = self.get_model_info()
        print(f"[Server Mode] updated_model_path: {updated_model_path}")
        self.assertEqual(updated_model_path, origin_model_path)

        updated_response = self.run_decode()
        self.assertEqual(origin_response[:32], updated_response[:32])


###############################################################################
# Parameterized Tests for update_weights_from_disk
# Test coverage is determined based on the value of is_in_ci:
#   - In a CI environment: randomly select one mode (Engine or Server) and test
#     only with tp=1, dp=1.
#   - In a non-CI environment: test both Engine and Server modes, and enumerate
#     all combinations with tp and dp ranging from 1 to 2.
###############################################################################
class TestUpdateWeightsFromDiskParameterized(unittest.TestCase):
    def run_common_test(self, mode, tp, dp):
        """
        Common test procedure for update_weights_from_disk.

        For Engine mode, we instantiate the engine with tp_size=tp.
        For Server mode, we launch the server with additional arguments for tp
        (dp is not used in the server launch here).
        """
        if mode == "Engine":
            # Instantiate the engine with the additional tp_size parameter.
            print(f"[Parameterized Engine] Testing with tp={tp}, dp={dp}")
            engine = sgl.Engine(
                model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
                random_seed=42,
                tp_size=tp,
                # The dp parameter is not explicitly used in this API.
            )
            try:
                origin_response = self._engine_update_weights_test(engine)
            finally:
                engine.shutdown()
        elif mode == "Server":
            print(f"[Parameterized Server] Testing with tp={tp}, dp={dp}")
            # Pass additional arguments to launch the server.
            base_args = ["--tp-size", str(tp)]
            process = popen_launch_server(
                DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
                DEFAULT_URL_FOR_TEST,
                timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
                other_args=base_args,
            )
            try:
                origin_response = self._server_update_weights_test(DEFAULT_URL_FOR_TEST)
            finally:
                kill_process_tree(process.pid)
        else:
            raise ValueError(f"Unknown mode: {mode}")

    def _engine_update_weights_test(self, engine):
        # Run the update-weights test on the given engine instance.
        def run_decode():
            prompts = ["The capital of France is"]
            sampling_params = {"temperature": 0, "max_new_tokens": 32}
            outputs = engine.generate(prompts, sampling_params)
            print("=" * 100)
            print(
                f"[Parameterized Engine] Prompt: {prompts[0]}\nGenerated text: {outputs[0]['text']}"
            )
            return outputs[0]["text"]

        def run_update_weights(model_path):
            ret = engine.update_weights_from_disk(model_path)
            print(json.dumps(ret))
            return ret

        origin_response = run_decode()
        new_model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST.replace("-Instruct", "")
        ret = run_update_weights(new_model_path)
        self.assertTrue(ret[0])
        updated_response = run_decode()
        self.assertNotEqual(origin_response[:32], updated_response[:32])

        ret = run_update_weights(DEFAULT_SMALL_MODEL_NAME_FOR_TEST)
        self.assertTrue(ret[0])
        reverted_response = run_decode()
        self.assertEqual(origin_response[:32], reverted_response[:32])
        return origin_response

    def _server_update_weights_test(self, base_url):
        def run_decode():
            response = requests.post(
                base_url + "/generate",
                json={
                    "text": "The capital of France is",
                    "sampling_params": {"temperature": 0, "max_new_tokens": 32},
                },
            )
            print("=" * 100)
            print(f"[Parameterized Server] Generated text: {response.json()['text']}")
            return response.json()["text"]

        def get_model_info():
            response = requests.get(base_url + "/get_model_info")
            model_path = response.json()["model_path"]
            print(json.dumps(response.json()))
            return model_path

        def run_update_weights(model_path):
            response = requests.post(
                base_url + "/update_weights_from_disk",
                json={"model_path": model_path},
            )
            ret = response.json()
            print(json.dumps(ret))
            return ret

        origin_model_path = get_model_info()
        origin_response = run_decode()

        new_model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST.replace("-Instruct", "")
        ret = run_update_weights(new_model_path)
        self.assertTrue(ret["success"])
        updated_model_path = get_model_info()
        self.assertEqual(updated_model_path, new_model_path)
        self.assertNotEqual(updated_model_path, origin_model_path)
        updated_response = run_decode()
        self.assertNotEqual(origin_response[:32], updated_response[:32])

        ret = run_update_weights(origin_model_path)
        self.assertTrue(ret["success"])
        updated_model_path = get_model_info()
        self.assertEqual(updated_model_path, origin_model_path)
        reverted_response = run_decode()
        self.assertEqual(origin_response[:32], reverted_response[:32])
        return origin_response

    def test_parameterized_update_weights(self):
        if is_in_ci():
            # In CI, choose one random mode (Engine or Server) with tp=1, dp=1.
            mode = random.choice(["Engine", "Server"])
            test_suits = [(1, 1, mode)]
        else:
            # Otherwise, test both modes and enumerate tp, dp combinations from 1 to 2.
            test_suits = []
            for mode in ["Engine", "Server"]:
                for tp in [1, 2]:
                    for dp in [1, 2]:
                        test_suits.append((tp, dp, mode))

        for tp, dp, mode in test_suits:
            with self.subTest(mode=mode, tp=tp, dp=dp):
                self.run_common_test(mode, tp, dp)


if __name__ == "__main__":
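The server-mode tests above exercise the same flow over HTTP. As a rough sketch, assuming an sglang HTTP server is already running at base_url (the URL and model path below are placeholders):

# Hypothetical server-mode example; base_url and the model path are placeholders.
import requests

base_url = "http://127.0.0.1:30000"

# Check which weights are currently loaded.
print(requests.get(base_url + "/get_model_info").json()["model_path"])

# Hot-swap weights from disk through the HTTP endpoint.
ret = requests.post(
    base_url + "/update_weights_from_disk",
    json={"model_path": "meta-llama/Llama-3.1-8B"},
).json()
assert ret["success"]

# Decode with the updated weights.
out = requests.post(
    base_url + "/generate",
    json={
        "text": "The capital of France is",
        "sampling_params": {"temperature": 0, "max_new_tokens": 32},
    },
).json()
print(out["text"])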
test/srt/test_update_weights_from_distributed.py (view file @ 70b3c6ee)

@@ -15,6 +15,7 @@ distributed setup.
import gc
import os
import random
import time
import unittest

@@ -529,8 +530,9 @@ class TestUpdateWeightsFromDistributed(unittest.TestCase):
        assert torch.cuda.device_count() >= 2, "At least 2 GPUs are required"

        # test_suits : tp, dp, model_name, backend
        if is_in_ci():
            mode = random.choice(["Engine", "Server"])
            test_suits = [
                (1, 1, DEFAULT_SMALL_MODEL_NAME_FOR_TEST, mode),
            ]
        else:
            test_suits = [