sglang / Commits / 9183c23e

Speed up `update_weights_from_tensor` (#2695)

Unverified commit 9183c23e, authored Jan 02, 2025 by fzyzcjy; committed by GitHub on Jan 02, 2025.
Parent: 148254d4
Showing 6 changed files with 48 additions and 25 deletions (+48 / -25):

  python/sglang/srt/managers/io_struct.py            +1   -2
  python/sglang/srt/managers/tp_worker.py            +2   -2
  python/sglang/srt/model_executor/model_runner.py   +4   -4
  python/sglang/srt/server.py                        +8   -3
  python/sglang/srt/utils.py                         +15  -2
  test/srt/test_update_weights_from_tensor.py        +18  -12
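Summary, as read from the diffs below: `Engine.update_weights_from_tensor` now takes a list of `(name, tensor)` pairs instead of a single name/tensor pair, the request object carries the whole batch as one `ForkingPickler`-serialized `bytes` blob, and the TP worker deserializes that blob and hands the full list to `model.load_weights` in a single call. A minimal caller-side sketch of the new API (model path, parameter names, and tensor shape are borrowed from the updated test at the bottom of this diff; this is an illustration, not part of the commit):

    import torch
    import sglang as sgl
    from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST

    engine = sgl.Engine(model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST)

    # Old API (removed in this commit): one engine call per tensor, e.g.
    #     engine.update_weights_from_tensor(name, tensor)
    # New API: one call updates a whole batch of named tensors.
    engine.update_weights_from_tensor(
        [(f"model.layers.{i}.mlp.up_proj.weight", torch.full((3072, 2048), 1.5)) for i in range(6, 16)]
    )
    engine.shutdown()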
python/sglang/srt/managers/io_struct.py

@@ -426,8 +426,7 @@ class UpdateWeightsFromDistributedReqOutput:
 @dataclass
 class UpdateWeightsFromTensorReqInput:
-    name: str
-    tensor: torch.Tensor
+    serialized_named_tensors: bytes  # indeed Dict[str, torch.Tensor]
 
 
 @dataclass
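The request dataclass now carries an opaque pre-pickled payload rather than a raw tensor field, so any number of tensors fits in one request, and the serialization is done once by the caller via the new `MultiprocessingSerializer` in `utils.py` below (a `ForkingPickler` wrapper; with torch's multiprocessing reducers registered, tensors are pickled as shared-memory / CUDA IPC handles rather than copied element by element). A small sketch of how the field is produced and consumed, mirroring the `server.py` and `tp_worker.py` hunks of this commit; the tensor is a made-up placeholder, and in sglang the two halves run in different processes:

    import torch
    from sglang.srt.managers.io_struct import UpdateWeightsFromTensorReqInput
    from sglang.srt.utils import MultiprocessingSerializer

    named_tensors = [("model.layers.0.mlp.up_proj.weight", torch.zeros(8, 8))]  # placeholder

    # Caller side (see Engine.update_weights_from_tensor below): pickle once, wrap in the request.
    req = UpdateWeightsFromTensorReqInput(
        serialized_named_tensors=MultiprocessingSerializer.serialize(named_tensors)
    )

    # Worker side (see TpModelWorker.update_weights_from_tensor below): recover the list.
    restored = MultiprocessingSerializer.deserialize(req.serialized_named_tensors)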
python/sglang/srt/managers/tp_worker.py

@@ -30,7 +30,7 @@ from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_a
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import broadcast_pyobj, set_random_seed
+from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj, set_random_seed
 
 logger = logging.getLogger(__name__)

@@ -197,7 +197,7 @@ class TpModelWorker:
     def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
         success, message = self.model_runner.update_weights_from_tensor(
-            recv_req.name, recv_req.tensor
+            MultiprocessingSerializer.deserialize(recv_req.serialized_named_tensors)
         )
         return success, message
python/sglang/srt/model_executor/model_runner.py

@@ -17,7 +17,7 @@ import gc
 import json
 import logging
 import time
-from typing import Optional
+from typing import List, Optional, Tuple
 
 import torch
 import torch.distributed as dist

@@ -428,9 +428,9 @@ class ModelRunner:
             logger.error(error_msg)
             return False, error_msg
 
-    def update_weights_from_tensor(self, name, tensor: torch.Tensor):
-        self.model.load_weights([(name, tensor)])
-        return True, "Success"  # TODO error handling
+    def update_weights_from_tensor(self, named_tensors: List[Tuple[str, torch.Tensor]]):
+        self.model.load_weights(named_tensors)
+        return True, "Success"
 
     def get_weights_by_name(
         self, name: str, truncate_size: int = 100
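`ModelRunner.update_weights_from_tensor` now forwards the whole batch to `model.load_weights`, the same `(name, tensor)`-iterable interface sglang model classes use when loading a checkpoint. A toy illustration of that contract (not sglang's actual implementation, which also remaps and shards fused weights):

    from typing import Iterable, Tuple

    import torch
    from torch import nn


    class ToyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.proj = nn.Linear(4, 4, bias=False)

        def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
            # Real sglang models map checkpoint names onto (possibly fused/sharded)
            # parameters; this toy version just copies by exact name.
            params = dict(self.named_parameters())
            for name, loaded in weights:
                params[name].data.copy_(loaded)


    model = ToyModel()
    model.load_weights([("proj.weight", torch.ones(4, 4))])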
python/sglang/srt/server.py

@@ -27,7 +27,9 @@ import signal
 import threading
 import time
 from http import HTTPStatus
-from typing import AsyncIterator, Dict, List, Optional, Union
+from typing import AsyncIterator, Dict, List, Optional, Tuple, Union
+
+import torch
 
 # Fix a bug of Python threading
 setattr(threading, "_register_atexit", lambda *args, **kwargs: None)

@@ -78,6 +80,7 @@ from sglang.srt.openai_api.adapter import (
 from sglang.srt.openai_api.protocol import ModelCard, ModelList
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import (
+    MultiprocessingSerializer,
     add_api_key_middleware,
     add_prometheus_middleware,
     assert_pkg_version,

@@ -874,9 +877,11 @@ class Engine:
             tokenizer_manager.update_weights_from_distributed(obj, None)
         )
 
-    def update_weights_from_tensor(self, name, tensor):
+    def update_weights_from_tensor(self, named_tensors: List[Tuple[str, torch.Tensor]]):
         """Update weights from distributed source."""
-        obj = UpdateWeightsFromTensorReqInput(name=name, tensor=tensor)
+        obj = UpdateWeightsFromTensorReqInput(
+            serialized_named_tensors=MultiprocessingSerializer.serialize(named_tensors)
+        )
         loop = asyncio.get_event_loop()
         return loop.run_until_complete(
             tokenizer_manager.update_weights_from_tensor(obj, None)
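The engine serializes the batch exactly once, in the caller's process, before the request travels to the tokenizer manager and scheduler, which is presumably where the speedup comes from: the per-request overhead is paid once per batch instead of once per tensor. A hedged sketch of the kind of weight-sync loop this enables; the trainer model, the chunk size, and the assumption that its parameter names match what the serving model's `load_weights` expects are all hypothetical, not part of this commit:

    from typing import List, Tuple

    import torch


    def sync_weights(engine, trainer_model: torch.nn.Module, chunk_size: int = 32):
        """Push a trainer's current parameters into a running sgl.Engine in batches."""
        items: List[Tuple[str, torch.Tensor]] = [
            (name, param.data) for name, param in trainer_model.named_parameters()
        ]
        for i in range(0, len(items), chunk_size):
            # Each call ships one serialized request covering up to `chunk_size` tensors.
            engine.update_weights_from_tensor(items[i : i + chunk_size])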
python/sglang/srt/utils.py

@@ -15,6 +15,7 @@
 import base64
 import dataclasses
+import io
 import ipaddress
 import itertools
 import json

@@ -34,6 +35,7 @@ import warnings
 from functools import lru_cache
 from importlib.metadata import PackageNotFoundError, version
 from io import BytesIO
+from multiprocessing.reduction import ForkingPickler
 from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple, Union
 
 import numpy as np

@@ -60,7 +62,6 @@ from triton.runtime.cache import (
 ...
 logger = logging.getLogger(__name__)
 
 show_time_cost = False
 time_infos = {}

@@ -1206,7 +1207,6 @@ def _cuda_device_count_stateless(cuda_visible_devices: Optional[str] = None) ->
     # https://github.com/pytorch/pytorch/blob/
     # c1cd946818442aca8c7f812b16d187ce1586c3bc/
     # torch/cuda/__init__.py#L831C1-L831C17
-    import torch.cuda
     import torch.version
 
     if not torch.cuda._is_compiled():

@@ -1335,3 +1335,16 @@ def parse_tool_response(text, tools, **kwargs):
         for call_info in call_info_list
     ]
     return text, call_info_list
+
+
+class MultiprocessingSerializer:
+    @staticmethod
+    def serialize(obj):
+        buf = io.BytesIO()
+        ForkingPickler(buf).dump(obj)
+        buf.seek(0)
+        return buf.read()
+
+    @staticmethod
+    def deserialize(data):
+        return ForkingPickler.loads(data)
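The new `MultiprocessingSerializer` is a thin wrapper around the standard library's `ForkingPickler`. The practical difference from plain `pickle` is that, once torch's multiprocessing reducers are registered (importing `torch.multiprocessing` does this), tensors are reduced to shared-memory or CUDA IPC handles instead of having their elements copied into the byte stream; those handles only resolve in other processes on the same machine while the sender keeps the tensor alive. A rough standalone sketch of that size difference, not part of the commit (exact numbers depend on the torch build and sharing strategy):

    import pickle

    import torch
    import torch.multiprocessing  # registers torch's ForkingPickler reducers

    from sglang.srt.utils import MultiprocessingSerializer

    t = torch.randn(1024, 1024)  # ~4 MB of float32 data

    plain = pickle.dumps(t)                          # carries all elements by value
    handle = MultiprocessingSerializer.serialize(t)  # carries a handle to shared storage

    print(f"plain pickle: {len(plain)} bytes, ForkingPickler: {len(handle)} bytes")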
test/srt/test_update_weights_from_tensor.py

+import time
 import unittest
 
 import torch

@@ -6,27 +7,32 @@ import sglang as sgl
 from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST
 
 
-class TestReleaseGPUOccupation(unittest.TestCase):
-    def test_release_and_resume_occupation(self):
+class TestUpdateWeightsFromTensor(unittest.TestCase):
+    def test_update_weights_from_tensor(self):
         engine = sgl.Engine(model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST)
 
-        param_name = "model.layers.2.self_attn.k_proj.weight"
+        param_names = [f"model.layers.{i}.mlp.up_proj.weight" for i in range(6, 16)]
 
-        def _check_param(expect_values):
-            actual_values = torch.tensor(engine.get_weights_by_name(param_name))[0, :5]
-            assert torch.allclose(
-                actual_values, torch.tensor(expect_values), atol=0.001
-            ), f"{actual_values=}"
+        _check_param(engine, param_names[0], [0.0087, -0.0214, -0.0004, 0.0039, 0.0110])
 
-        _check_param([0.0571, -0.0114, 0.0444, 0.0215, -0.0149])
+        new_tensor = torch.full((3072, 2048), 1.5)
 
-        new_tensor = torch.full((16384, 2048), 1.5)
-        engine.update_weights_from_tensor(param_name, new_tensor)
+        time_start = time.time()
+        engine.update_weights_from_tensor([(x, new_tensor) for x in param_names])
+        print(f"Time delta: {time.time() - time_start:.03f}")
 
-        _check_param([1.5] * 5)
+        for param_name in param_names[:3]:
+            _check_param(engine, param_name, [1.5] * 5)
 
         engine.shutdown()
 
 
+def _check_param(engine, param_name, expect_values):
+    actual_values = torch.tensor(engine.get_weights_by_name(param_name))[0, :5]
+    assert torch.allclose(
+        actual_values, torch.tensor(expect_values), atol=0.002
+    ), f"{actual_values=}"
+
+
 if __name__ == "__main__":
     unittest.main()
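The rewritten test now pushes ten `up_proj` weights in a single call, prints the elapsed time, and spot-checks the first three parameters (with a slightly looser `atol`). To see the effect the commit title claims, a rough standalone comparison like the one below contrasts ten single-tensor requests against one batched request using only the new API; timings vary with model and hardware, and the model path, names, and shape are reused from the test above:

    import time

    import torch
    import sglang as sgl
    from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST

    engine = sgl.Engine(model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST)
    param_names = [f"model.layers.{i}.mlp.up_proj.weight" for i in range(6, 16)]
    new_tensor = torch.full((3072, 2048), 1.5)

    t0 = time.time()
    for name in param_names:  # one request per tensor
        engine.update_weights_from_tensor([(name, new_tensor)])
    print(f"per-tensor: {time.time() - t0:.3f}s")

    t0 = time.time()  # one request for the whole batch
    engine.update_weights_from_tensor([(name, new_tensor) for name in param_names])
    print(f"batched: {time.time() - t0:.3f}s")

    engine.shutdown()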