Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
8ecf6b9d
Unverified
Commit
8ecf6b9d
authored
Aug 10, 2025
by
Stefan He
Committed by
GitHub
Aug 10, 2025
Browse files
Support Flatten Tensor Update Weights to speed up MOE Update Weights by 20% (#8079)
parent
0418b9d4
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
210 additions
and
3 deletions
+210
-3
python/sglang/srt/entrypoints/engine.py
python/sglang/srt/entrypoints/engine.py
+8
-3
python/sglang/srt/model_executor/model_runner.py
python/sglang/srt/model_executor/model_runner.py
+42
-0
python/sglang/srt/weight_sync/tensor_bucket.py
python/sglang/srt/weight_sync/tensor_bucket.py
+106
-0
test/srt/rl/test_update_weights_from_tensor.py
test/srt/rl/test_update_weights_from_tensor.py
+54
-0
No files found.
python/sglang/srt/entrypoints/engine.py
View file @
8ecf6b9d
...
...
@@ -451,15 +451,20 @@ class Engine(EngineBase):
):
"""Update weights from distributed source. If there are going to be more updates, set `flush_cache` to be false
to avoid duplicated cache cleaning operation."""
obj
=
UpdateWeightsFromTensorReqInput
(
serialized_named_tensors
=
[
if
load_format
==
"flattened_bucket"
:
serialized_named_tensors
=
named_tensors
else
:
serialized_named_tensors
=
[
MultiprocessingSerializer
.
serialize
(
named_tensors
)
for
_
in
range
(
self
.
server_args
.
tp_size
)
],
]
obj
=
UpdateWeightsFromTensorReqInput
(
serialized_named_tensors
=
serialized_named_tensors
,
load_format
=
load_format
,
flush_cache
=
flush_cache
,
)
loop
=
asyncio
.
get_event_loop
()
return
loop
.
run_until_complete
(
self
.
tokenizer_manager
.
update_weights_from_tensor
(
obj
,
None
)
)
...
...
python/sglang/srt/model_executor/model_runner.py
View file @
8ecf6b9d
...
...
@@ -121,6 +121,10 @@ from sglang.srt.utils import (
set_cpu_offload_max_bytes
,
set_cuda_arch
,
)
from
sglang.srt.weight_sync.tensor_bucket
import
(
FlattenedTensorBucket
,
FlattenedTensorMetadata
,
)
_is_hip
=
is_hip
()
_is_npu
=
is_npu
()
...
...
@@ -896,6 +900,12 @@ class ModelRunner:
load_format
:
Optional
[
str
]
=
None
,
):
monkey_patch_torch_reductions
()
if
load_format
==
"flattened_bucket"
:
# Handle flattened bucket format
return
self
.
_update_weights_from_flattened_bucket
(
flattened_tensor_bucket_dict
=
named_tensors
)
# We need to get device after patch otherwise the device would be wrong
infered_device
=
torch
.
cuda
.
current_device
()
...
...
@@ -914,6 +924,38 @@ class ModelRunner:
raise
NotImplementedError
(
f
"Unknown load_format=
{
load_format
}
"
)
return
True
,
"Success"
def _update_weights_from_flattened_bucket(
    self,
    flattened_tensor_bucket_dict,
):
    """Load model weights that arrive packed in a flattened bucket.

    Args:
        flattened_tensor_bucket_dict: Mapping indexed with the keys
            ``"flattened_tensor"`` and ``"metadata"``. The first value is
            handed to ``FlattenedTensorBucket`` as its ``flattened_tensor``;
            the second is iterated, and every entry must expose ``name``,
            ``shape``, ``dtype``, ``start_idx``, ``end_idx`` and ``numel``
            attributes.

    Returns:
        The tuple ``(True, "Success")`` after ``self.model.load_weights``
        has been called with the reconstructed tensors.
    """
    packed_tensor = flattened_tensor_bucket_dict["flattened_tensor"]
    incoming_metadata = flattened_tensor_bucket_dict["metadata"]

    # Copy every incoming entry into a FlattenedTensorMetadata instance
    rebuilt_metadata = [
        FlattenedTensorMetadata(
            name=entry.name,
            shape=entry.shape,
            dtype=entry.dtype,
            start_idx=entry.start_idx,
            end_idx=entry.end_idx,
            numel=entry.numel,
        )
        for entry in incoming_metadata
    ]

    # Split the packed tensor back into (name, tensor) pairs
    named_weights = FlattenedTensorBucket(
        flattened_tensor=packed_tensor, metadata=rebuilt_metadata
    ).reconstruct_tensors()

    # Hand the pairs to the model's regular weight-loading entry point
    self.model.load_weights(named_weights)

    return True, "Success"
def
get_weights_by_name
(
self
,
name
:
str
,
truncate_size
:
int
=
100
)
->
Optional
[
torch
.
Tensor
]:
...
...
python/sglang/srt/weight_sync/tensor_bucket.py
0 → 100644
View file @
8ecf6b9d
from
dataclasses
import
dataclass
from
typing
import
List
,
Tuple
import
torch
@dataclass
class FlattenedTensorMetadata:
    """Record describing one tensor stored inside a flattened bucket."""

    # Name paired with the tensor.
    name: str
    # Shape the tensor is reshaped to when it is reconstructed.
    shape: torch.Size
    # Dtype the reconstructed tensor is cast to when it differs.
    dtype: torch.dtype
    # Slice start (inclusive) of this tensor within the flattened tensor.
    start_idx: int
    # Slice end (exclusive) of this tensor within the flattened tensor.
    end_idx: int
    # Number of elements of the tensor.
    numel: int
class FlattenedTensorBucket:
    """
    Packs several named tensors into one flat tensor and keeps the per-tensor
    metadata that is needed to split it back into the individual tensors.
    """

    def __init__(
        self,
        named_tensors: List[Tuple[str, torch.Tensor]] = None,
        flattened_tensor: torch.Tensor = None,
        metadata: List[FlattenedTensorMetadata] = None,
    ):
        """
        Build a bucket either from named tensors or from already-flattened data.

        Args:
            named_tensors: List of (name, tensor) tuples to pack. When given,
                the two other arguments are not used.
            flattened_tensor: Already-flattened tensor (reconstruction path)
            metadata: Metadata matching ``flattened_tensor`` (reconstruction path)

        Raises:
            ValueError: If ``named_tensors`` is empty, or if it is None and
                ``flattened_tensor`` or ``metadata`` is None as well.
        """
        if named_tensors is None:
            # Reconstruction path: adopt the caller's data as-is
            if flattened_tensor is None or metadata is None:
                raise ValueError(
                    "Must provide either named_tensors or both flattened_tensor and metadata"
                )
            self.flattened_tensor = flattened_tensor
            self.metadata = metadata
            return

        # Packing path
        if len(named_tensors) == 0:
            raise ValueError("Cannot create empty tensor bucket")

        pieces: List[torch.Tensor] = []
        records: List[FlattenedTensorMetadata] = []
        offset = 0
        for tensor_name, tensor in named_tensors:
            flat = tensor.flatten()
            count = flat.numel()
            # Each tensor occupies the half-open range [offset, offset + count)
            records.append(
                FlattenedTensorMetadata(
                    name=tensor_name,
                    shape=tensor.shape,
                    dtype=tensor.dtype,
                    start_idx=offset,
                    end_idx=offset + count,
                    numel=count,
                )
            )
            pieces.append(flat)
            offset += count

        self.metadata = records
        # Join all flattened pieces into one tensor
        self.flattened_tensor = torch.cat(pieces, dim=0)

    def get_flattened_tensor(self) -> torch.Tensor:
        """Return the flat tensor that holds every tensor of the bucket"""
        return self.flattened_tensor

    def get_metadata(self) -> List[FlattenedTensorMetadata]:
        """Return the metadata records, one per tensor in the bucket"""
        return self.metadata

    def reconstruct_tensors(self) -> List[Tuple[str, torch.Tensor]]:
        """
        Split the flat tensor back into (name, tensor) tuples.

        For every metadata record the range ``[start_idx, end_idx)`` is sliced
        out, reshaped to the recorded shape and, only when the dtype differs
        from the recorded one, cast to it. Results follow the metadata order.
        """
        buffer = self.flattened_tensor
        rebuilt = []
        for record in self.metadata:
            piece = buffer[record.start_idx : record.end_idx].reshape(record.shape)
            # Cast only when the stored dtype differs from the recorded one
            if piece.dtype != record.dtype:
                piece = piece.to(record.dtype)
            rebuilt.append((record.name, piece))
        return rebuilt
test/srt/rl/test_update_weights_from_tensor.py
View file @
8ecf6b9d
...
...
@@ -5,6 +5,7 @@ import unittest
import
torch
import
sglang
as
sgl
from
sglang.srt.weight_sync.tensor_bucket
import
FlattenedTensorBucket
from
sglang.test.test_utils
import
DEFAULT_SMALL_MODEL_NAME_FOR_TEST
,
CustomTestCase
...
...
@@ -112,6 +113,59 @@ class TestUpdateWeightsFromTensor(CustomTestCase):
engine
.
shutdown
()
def test_update_weights_from_tensor_load_format_flattened_bucket(self):
    """Test updating weights using flattened_bucket format"""
    from sglang.srt.utils import MultiprocessingSerializer

    engine = sgl.Engine(model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST)
    # Shut the engine down even when a check below fails, so a failing
    # assertion does not leave the engine running.
    try:
        # Create a small set of parameters for testing
        param_names = [f"model.layers.{i}.mlp.up_proj.weight" for i in range(6, 10)]

        # Check original values
        _check_param(
            engine, param_names[0], [0.0087, -0.0214, -0.0004, 0.0039, 0.0110]
        )

        # Replace every selected parameter with a tensor filled with one constant
        value = 2.0
        new_tensors = [
            (name, torch.full((16384, 2048), value, device="cuda"))
            for name in param_names
        ]

        # Create a flattened bucket
        flattened_bucket = FlattenedTensorBucket(named_tensors=new_tensors)

        # Create the dict format expected by _update_weights_from_flattened_bucket
        bucket_dict = {
            "flattened_tensor": flattened_bucket.get_flattened_tensor(),
            "metadata": flattened_bucket.get_metadata(),
        }

        # Serialize the bucket data
        serialized_bucket = MultiprocessingSerializer.serialize(
            bucket_dict, output_str=True
        )

        # One serialized entry is passed here; the list is forwarded
        # unchanged as the per-rank payload for the flattened_bucket format
        serialized_bucket_list = [serialized_bucket]

        # Update weights using flattened_bucket format
        time_start = time.perf_counter()
        engine.update_weights_from_tensor(
            named_tensors=serialized_bucket_list, load_format="flattened_bucket"
        )
        update_time = time.perf_counter() - time_start
        print(f"Flattened bucket update time: {update_time:.03f}")

        # Verify the weights were updated correctly
        for param_name in param_names:
            _check_param(engine, param_name, [value] * 5)
    finally:
        engine.shutdown()
def
_check_param
(
engine
,
param_name
,
expect_values
):
actual_values
=
torch
.
tensor
(
engine
.
get_weights_by_name
(
param_name
))[
0
,
:
5
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment