sglang · Commits · 4ca43b06

Commit 4ca43b06 (Unverified)
Authored Aug 02, 2025 by Stefan He; committed via GitHub, Aug 02, 2025

Add tensor.detach() back to update weight util (#8691)

parent ea93079b
Changes: 2 changed files, with 63 additions and 65 deletions

python/sglang/srt/weight_sync/utils.py    (+1, -1)
test/srt/test_utils_update_weights.py     (+62, -64)
python/sglang/srt/weight_sync/utils.py @ 4ca43b06
...
@@ -45,7 +45,7 @@ async def update_weights(
         (
             name,
             MultiprocessingSerializer.serialize(
-                _preprocess_tensor_for_update_weights(tensor)
+                _preprocess_tensor_for_update_weights(tensor.detach())
             ),
         )
         for name, tensor in params_batch
...
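Why the .detach() matters: the serialized payload is handed to another process, and a tensor that still carries autograd history would drag its grad_fn chain into serialization. Detaching yields a view that shares the same storage but is cut out of the graph. A minimal sketch of the invariant, using plain torch (the helper name below is a hypothetical stand-in, not sglang's actual _preprocess_tensor_for_update_weights):

import torch

def preprocess_for_update(tensor: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for _preprocess_tensor_for_update_weights:
    # detach() returns a view that shares storage but carries no autograd
    # history, so serializing it cannot pull the graph across processes.
    return tensor.detach()

w = torch.nn.Linear(4, 4).weight           # leaf parameter, requires_grad=True
payload = preprocess_for_update(w)

assert payload.requires_grad is False      # cut off from autograd
assert payload.data_ptr() == w.data_ptr()  # same storage, no extra copy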
test/srt/test_utils_update_weights.py @ 4ca43b06
 import asyncio
 import os
+import unittest

-import pytest
 import torch
 import torch.distributed as dist
-from loguru import logger
 from torch.distributed.device_mesh import init_device_mesh
 from transformers import AutoModelForCausalLM
...
@@ -39,11 +38,29 @@ def setup_single_process_distributed():
     os.environ["LOCAL_RANK"] = "0"


-class TestUtilsUpdateWeights:
+class TestUtilsUpdateWeights(unittest.TestCase):
     """Test class for utils.update_weights function"""

-    @pytest.fixture(scope="class")
-    def setup_distributed(self):
+    @classmethod
+    def setUpClass(cls):
+        """Setup distributed environment and test fixtures for the entire test class"""
+        cls.setup_distributed()
+        cls.setup_test_engine()
+        cls.setup_test_model()
+        cls.setup_device_mesh()
+
+    @classmethod
+    def tearDownClass(cls):
+        """Cleanup after all tests"""
+        if hasattr(cls, "engine") and cls.engine:
+            cls.engine.shutdown()
+        # Cleanup distributed
+        if dist.is_initialized():
+            dist.destroy_process_group()
+
+    @classmethod
+    def setup_distributed(cls):
         """Setup distributed environment for testing"""
         setup_single_process_distributed()
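The hunks above and below convert class-scoped pytest fixtures into unittest's class-level hooks: state that a fixture used to yield is now stored on cls in setUpClass, and the cleanup that followed the fixture's yield moves to tearDownClass. A minimal sketch of that pattern (resource and names are illustrative only, not from the diff):

import unittest

class ExampleSuite(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        # Runs once before any test in the class; shared state lives on
        # cls instead of being yielded from a class-scoped pytest fixture.
        cls.resource = open(__file__)

    @classmethod
    def tearDownClass(cls):
        # Runs once after the last test; mirrors the cleanup that used to
        # follow the fixture's yield.
        cls.resource.close()

    def test_resource_is_ready(self):
        self.assertFalse(self.resource.closed)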
...
@@ -53,13 +70,15 @@ class TestUtilsUpdateWeights:
                 backend="nccl" if torch.cuda.is_available() else "gloo"
             )
         except Exception as e:
-            pytest.skip(f"Could not initialize distributed backend: {e}")
+            raise unittest.SkipTest(f"Could not initialize distributed backend: {e}")

-        rank = dist.get_rank()
-        world_size = dist.get_world_size()
+        cls.rank = dist.get_rank()
+        cls.world_size = dist.get_world_size()
         if torch.cuda.is_available():
-            torch.cuda.set_device(rank % torch.cuda.device_count())
+            torch.cuda.set_device(cls.rank % torch.cuda.device_count())

         # Set up environment variables
         os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
...
@@ -68,38 +87,26 @@ class TestUtilsUpdateWeights:
         os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4"
         os.environ["CUDA_MODULE_LOADING"] = "AUTO"

-        yield rank, world_size
-
-        # Cleanup
-        if dist.is_initialized():
-            dist.destroy_process_group()
-
-    @pytest.fixture(scope="class")
-    def test_engine(self, setup_distributed):
+    @classmethod
+    def setup_test_engine(cls):
         """Setup test engine"""
-        rank, world_size = setup_distributed
-
-        if rank == 0:
-            os.environ["SGLANG_BLOCK_NONZERO_RANK_CHILDREN"] = "0"
-            engine = AsyncEngine(
+        if cls.rank == 0:
+            cls.engine = AsyncEngine(
                 model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
                 dtype="bfloat16",
                 mem_fraction_static=0.3,
                 enable_memory_saver=True,
-                tp_size=world_size,
-                disable_cuda_graph=True,
+                tp_size=cls.world_size,
+                disable_cuda_graph=False,
             )
-            yield engine
-            engine.shutdown()
         else:
-            yield None
+            cls.engine = None

-    @pytest.fixture(scope="class")
-    def test_model(self):
+    @classmethod
+    def setup_test_model(cls):
         """Load test model"""
         try:
-            model = AutoModelForCausalLM.from_pretrained(
+            cls.model = AutoModelForCausalLM.from_pretrained(
                 DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
                 device_map="cpu",
                 trust_remote_code=True,
...
@@ -108,25 +115,20 @@ class TestUtilsUpdateWeights:
                     torch.float16 if torch.cuda.is_available() else torch.float32
                 ),
             )
-            return model
         except Exception as e:
-            pytest.skip(f"Could not load test model: {e}")
+            raise unittest.SkipTest(f"Could not load test model: {e}")

-    @pytest.fixture(scope="class")
-    def device_mesh(self, setup_distributed):
+    @classmethod
+    def setup_device_mesh(cls):
         """Create device mesh for testing"""
-        rank, world_size = setup_distributed
         if not torch.cuda.is_available():
-            pytest.skip("CUDA not available for device mesh")
+            raise unittest.SkipTest("CUDA not available for device mesh")

-        device_mesh_key = "tp"
-        mesh = init_device_mesh(
-            "cuda", (world_size,), mesh_dim_names=(device_mesh_key,)
+        cls.device_mesh_key = "tp"
+        cls.mesh = init_device_mesh(
+            "cuda", (cls.world_size,), mesh_dim_names=(cls.device_mesh_key,)
         )
-        return device_mesh_key, mesh

     def create_test_params_batch(self, model, num_params=64):
         """Create a batch of test parameters from the model"""
         param_names = []
...
@@ -143,31 +145,27 @@ class TestUtilsUpdateWeights:
         return list(zip(param_names, test_tensors))

-    @pytest.mark.asyncio
-    async def test_utils_update_weights(
-        self, setup_distributed, test_engine, test_model, device_mesh
-    ):
+    def test_utils_update_weights(self):
         """Test basic functionality of utils.update_weights"""
-        rank, world_size = setup_distributed
-        device_mesh_key, mesh = device_mesh

-        # Create test parameters batch
-        params_batch = self.create_test_params_batch(test_model, num_params=2)
-
-        print(
-            f"Rank {rank} testing utils.update_weights with {len(params_batch)} parameters"
-        )
-
-        # Test the utils.update_weights function
-        result = await update_weights(
-            engine=test_engine,
-            params_batch=params_batch,
-            device_mesh_key=device_mesh_key,
-            device_mesh=mesh,
-            load_format=None,
-        )
-
-        assert "Success" in result
+        async def async_test():
+            # Create test parameters batch
+            params_batch = self.create_test_params_batch(self.model, num_params=2)
+
+            # Test the utils.update_weights function
+            result = await update_weights(
+                engine=self.engine,
+                params_batch=params_batch,
+                device_mesh_key=self.device_mesh_key,
+                device_mesh=self.mesh,
+                load_format=None,
+            )
+
+            self.assertIn("Success", result)
+
+        # Run the async test
+        asyncio.run(async_test())


 if __name__ == "__main__":
-    pytest.main([__file__])
+    unittest.main()
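With pytest.main replaced by unittest.main(), the file runs under the stdlib runner, and the async body is driven explicitly with asyncio.run from a synchronous test method, so the pytest-asyncio plugin is no longer needed. A minimal, self-contained sketch of that wrapping pattern (illustrative only, not the actual test):

import asyncio
import unittest

class AsyncWrappedTest(unittest.TestCase):
    def test_coroutine_result(self):
        async def async_test():
            # Any awaitable work goes here; assertions use unittest's API.
            await asyncio.sleep(0)
            return "Success"

        # asyncio.run drives the coroutine from a synchronous test method,
        # which is why @pytest.mark.asyncio is no longer required.
        self.assertIn("Success", asyncio.run(async_test()))

if __name__ == "__main__":
    unittest.main()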