Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
91ba98fe
"vscode:/vscode.git/clone" did not exist on "73de9abe23189e5ee61e9093398db97f99f9afcb"
Unverified
Commit
91ba98fe
authored
Mar 17, 2025
by
Wei Wu
Committed by
GitHub
Mar 17, 2025
Browse files
[Fix] Resolve GPU Memory Leak in update_weights_from_tensor (#4446)
parent
c614dbdf
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
38 additions
and
14 deletions
+38
-14
python/sglang/srt/entrypoints/engine.py
python/sglang/srt/entrypoints/engine.py
+4
-1
python/sglang/srt/managers/tp_worker.py
python/sglang/srt/managers/tp_worker.py
+1
-1
test/srt/test_update_weights_from_tensor.py
test/srt/test_update_weights_from_tensor.py
+33
-12
No files found.
python/sglang/srt/entrypoints/engine.py
View file @
91ba98fe
...
...
@@ -320,7 +320,10 @@ class Engine:
"""Update weights from distributed source. If there are going to be more updates, set `flush_cache` to be true
to avoid duplicated operations such as clearing cache."""
obj
=
UpdateWeightsFromTensorReqInput
(
serialized_named_tensors
=
MultiprocessingSerializer
.
serialize
(
named_tensors
),
serialized_named_tensors
=
[
MultiprocessingSerializer
.
serialize
(
named_tensors
)
for
_
in
range
(
self
.
server_args
.
tp_size
)
],
load_format
=
load_format
,
flush_cache
=
flush_cache
,
)
...
...
python/sglang/srt/managers/tp_worker.py
View file @
91ba98fe
...
...
@@ -214,7 +214,7 @@ class TpModelWorker:
def
update_weights_from_tensor
(
self
,
recv_req
:
UpdateWeightsFromTensorReqInput
):
success
,
message
=
self
.
model_runner
.
update_weights_from_tensor
(
named_tensors
=
MultiprocessingSerializer
.
deserialize
(
recv_req
.
serialized_named_tensors
recv_req
.
serialized_named_tensors
[
self
.
tp_rank
]
),
load_format
=
recv_req
.
load_format
,
)
...
...
test/srt/test_update_weights_from_tensor.py
View file @
91ba98fe
import
gc
import
time
import
unittest
...
...
@@ -7,24 +8,44 @@ import sglang as sgl
from
sglang.test.test_utils
import
DEFAULT_SMALL_MODEL_NAME_FOR_TEST
class
T
est
U
pdate
W
eights
F
rom
T
ensor
(
unittest
.
TestCas
e
):
def
test_update_weights_from_tensor
(
self
):
engine
=
sgl
.
Engine
(
model_path
=
DEFAULT_SMALL_MODEL_NAME_FOR_TEST
)
def
t
est
_u
pdate
_w
eights
_f
rom
_t
ensor
(
tp_siz
e
):
assert
torch
.
cuda
.
device_count
()
>=
tp_size
,
f
"At least
{
tp_size
}
GPUs are required"
torch
.
cuda
.
empty_cache
(
)
param_names
=
[
f
"model.layers.
{
i
}
.mlp.up_proj.weight"
for
i
in
range
(
6
,
16
)]
engine
=
sgl
.
Engine
(
model_path
=
DEFAULT_SMALL_MODEL_NAME_FOR_TEST
,
tp_size
=
tp_size
)
_check_param
(
engine
,
param_names
[
0
],
[
0.0087
,
-
0.0214
,
-
0.0004
,
0.0039
,
0.0110
])
param_names
=
[
f
"model.layers.
{
i
}
.mlp.up_proj.weight"
for
i
in
range
(
6
,
16
)]
new_tensor
=
torch
.
full
((
16384
,
2048
),
1.5
)
_check_param
(
engine
,
param_names
[
0
],
[
0.0087
,
-
0.0214
,
-
0.0004
,
0.0039
,
0.0110
]
)
time_start
=
time
.
time
()
engine
.
update_weights_from_tensor
([(
x
,
new_tensor
)
for
x
in
param_names
])
print
(
f
"Time delta:
{
time
.
time
()
-
time_start
:.
03
f
}
"
)
memory_before
=
torch
.
cuda
.
memory_allocated
()
new_tensor
=
torch
.
full
((
16384
,
2048
),
1.5
,
device
=
"cuda"
)
for
param_name
in
param_names
[:
3
]:
_check_param
(
engine
,
param_name
,
[
1.5
]
*
5
)
time_start
=
time
.
time
()
engine
.
update_weights_from_tensor
([(
x
,
new_tensor
)
for
x
in
param_names
])
print
(
f
"Time delta:
{
time
.
time
()
-
time_start
:.
03
f
}
"
)
engine
.
shutdown
()
for
param_name
in
param_names
[:
3
]:
_check_param
(
engine
,
param_name
,
[
1.5
]
*
5
)
engine
.
shutdown
()
del
new_tensor
gc
.
collect
()
torch
.
cuda
.
ipc_collect
()
torch
.
cuda
.
empty_cache
()
memory_after
=
torch
.
cuda
.
memory_allocated
()
assert
(
memory_after
<=
memory_before
+
1024
),
f
"Memory leak detected:
{
memory_after
-
memory_before
}
bytes"
class
TestUpdateWeightsFromTensor
(
unittest
.
TestCase
):
def
test_update_weights_from_tensor
(
self
):
tp_sizes
=
[
1
,
2
]
for
tp_size
in
tp_sizes
:
with
self
.
subTest
(
tp_size
=
tp_size
):
test_update_weights_from_tensor
(
tp_size
)
def
test_update_weights_from_tensor_load_format_direct
(
self
):
engine
=
sgl
.
Engine
(
model_path
=
DEFAULT_SMALL_MODEL_NAME_FOR_TEST
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment