"examples/sampling/graphbolt/vscode:/vscode.git/clone" did not exist on "2968c9b247314dcb0ff64d00416654e99ca01de7"
Unverified commit 8ecf6b9d, authored by Stefan He, committed by GitHub

Support Flattened Tensor Update Weights to speed up MoE Update Weights by 20% (#8079)

parent 0418b9d4
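The mechanics, pieced together from the diff below: instead of serializing each named tensor separately for every TP rank, the sender packs all tensors into a single flattened buffer plus per-tensor metadata, serializes that one payload, and sends it with load_format="flattened_bucket"; the receiver slices the buffer back into named tensors and calls model.load_weights. A minimal client-side sketch, mirroring the test added at the end of this commit ("<your-model>" is a placeholder, and the parameter names and shapes must match the target model):

import torch

import sglang as sgl
from sglang.srt.utils import MultiprocessingSerializer
from sglang.srt.weight_sync.tensor_bucket import FlattenedTensorBucket

# Placeholder model path; parameter names/shapes must match the target model.
engine = sgl.Engine(model_path="<your-model>")
named_tensors = [
    (f"model.layers.{i}.mlp.up_proj.weight", torch.full((16384, 2048), 2.0, device="cuda"))
    for i in range(6, 10)
]

# Flatten everything into one contiguous buffer plus per-tensor metadata,
# so the payload is serialized once instead of once per tensor.
bucket = FlattenedTensorBucket(named_tensors=named_tensors)
payload = {
    "flattened_tensor": bucket.get_flattened_tensor(),
    "metadata": bucket.get_metadata(),
}
serialized = MultiprocessingSerializer.serialize(payload, output_str=True)

# One serialized payload per TP rank (tp_size == 1 here).
engine.update_weights_from_tensor(
    named_tensors=[serialized], load_format="flattened_bucket"
)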
@@ -451,15 +451,20 @@ class Engine(EngineBase):
     ):
         """Update weights from distributed source. If there are going to be more updates, set `flush_cache` to be false
         to avoid duplicated cache cleaning operation."""
-        obj = UpdateWeightsFromTensorReqInput(
-            serialized_named_tensors=[
+        if load_format == "flattened_bucket":
+            serialized_named_tensors = named_tensors
+        else:
+            serialized_named_tensors = [
                 MultiprocessingSerializer.serialize(named_tensors)
                 for _ in range(self.server_args.tp_size)
-            ],
+            ]
+        obj = UpdateWeightsFromTensorReqInput(
+            serialized_named_tensors=serialized_named_tensors,
             load_format=load_format,
             flush_cache=flush_cache,
         )
         loop = asyncio.get_event_loop()
         return loop.run_until_complete(
             self.tokenizer_manager.update_weights_from_tensor(obj, None)
         )
@@ -121,6 +121,10 @@ from sglang.srt.utils import (
     set_cpu_offload_max_bytes,
     set_cuda_arch,
 )
+from sglang.srt.weight_sync.tensor_bucket import (
+    FlattenedTensorBucket,
+    FlattenedTensorMetadata,
+)

 _is_hip = is_hip()
 _is_npu = is_npu()
@@ -896,6 +900,12 @@ class ModelRunner:
         load_format: Optional[str] = None,
     ):
         monkey_patch_torch_reductions()
+        if load_format == "flattened_bucket":
+            # Handle the flattened bucket format separately
+            return self._update_weights_from_flattened_bucket(
+                flattened_tensor_bucket_dict=named_tensors
+            )
+
         # We need to get device after patch otherwise the device would be wrong
         infered_device = torch.cuda.current_device()
@@ -914,6 +924,38 @@ class ModelRunner:
             raise NotImplementedError(f"Unknown load_format={load_format}")

         return True, "Success"

+    def _update_weights_from_flattened_bucket(
+        self,
+        flattened_tensor_bucket_dict,
+    ):
+        """Handle the flattened bucket format for weight updates."""
+        flattened_tensor = flattened_tensor_bucket_dict["flattened_tensor"]
+        metadata = flattened_tensor_bucket_dict["metadata"]
+
+        # Convert the incoming metadata objects into local FlattenedTensorMetadata instances
+        converted_metadata = []
+        for meta in metadata:
+            converted_meta = FlattenedTensorMetadata(
+                name=meta.name,
+                shape=meta.shape,
+                dtype=meta.dtype,
+                start_idx=meta.start_idx,
+                end_idx=meta.end_idx,
+                numel=meta.numel,
+            )
+            converted_metadata.append(converted_meta)
+
+        # Create a bucket and reconstruct the original named tensors
+        bucket = FlattenedTensorBucket(
+            flattened_tensor=flattened_tensor, metadata=converted_metadata
+        )
+        reconstructed_tensors = bucket.reconstruct_tensors()
+
+        # Load the reconstructed tensors using the standard method
+        self.model.load_weights(reconstructed_tensors)
+
+        return True, "Success"
+
     def get_weights_by_name(
         self, name: str, truncate_size: int = 100
     ) -> Optional[torch.Tensor]:
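After deserialization, the handler above expects a dict with two keys, "flattened_tensor" and "metadata", where each metadata entry exposes name, shape, dtype, start_idx, end_idx, and numel. A minimal sketch of that contract with two hypothetical weights (the metadata re-conversion and the final model.load_weights call are elided):

import torch

from sglang.srt.weight_sync.tensor_bucket import FlattenedTensorBucket

# Build the payload the way a sender would (hypothetical two-tensor example).
sender = FlattenedTensorBucket(
    named_tensors=[("w1", torch.zeros(2, 3)), ("w2", torch.ones(4))]
)
payload = {
    "flattened_tensor": sender.get_flattened_tensor(),
    "metadata": sender.get_metadata(),
}

# The receiving side rebuilds the bucket and slices out the named tensors.
bucket = FlattenedTensorBucket(
    flattened_tensor=payload["flattened_tensor"], metadata=payload["metadata"]
)
named_tensors = bucket.reconstruct_tensors()
print([(name, tuple(t.shape)) for name, t in named_tensors])
# [('w1', (2, 3)), ('w2', (4,))]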
New file: sglang/srt/weight_sync/tensor_bucket.py

from dataclasses import dataclass
from typing import List, Optional, Tuple

import torch


@dataclass
class FlattenedTensorMetadata:
    """Metadata for a tensor in a flattened bucket"""

    name: str
    shape: torch.Size
    dtype: torch.dtype
    start_idx: int
    end_idx: int
    numel: int


class FlattenedTensorBucket:
    """
    A bucket that flattens multiple tensors into a single tensor for efficient processing
    while preserving all metadata needed for reconstruction.
    """

    def __init__(
        self,
        named_tensors: Optional[List[Tuple[str, torch.Tensor]]] = None,
        flattened_tensor: Optional[torch.Tensor] = None,
        metadata: Optional[List[FlattenedTensorMetadata]] = None,
    ):
        """
        Initialize a tensor bucket from a list of named tensors OR from pre-flattened data.

        Args:
            named_tensors: List of (name, tensor) tuples (for creating a new bucket)
            flattened_tensor: Pre-flattened tensor (for reconstruction)
            metadata: Pre-computed metadata (for reconstruction)
        """
        if named_tensors is not None:
            # Create the bucket from named tensors
            if not named_tensors:
                raise ValueError("Cannot create an empty tensor bucket")

            self.metadata: List[FlattenedTensorMetadata] = [None] * len(named_tensors)

            # Collect metadata and flatten tensors
            current_idx = 0
            flattened_tensors: List[torch.Tensor] = [None] * len(named_tensors)

            for i, (name, tensor) in enumerate(named_tensors):
                flattened = tensor.flatten()
                flattened_tensors[i] = flattened

                # Store metadata
                numel = flattened.numel()
                self.metadata[i] = FlattenedTensorMetadata(
                    name=name,
                    shape=tensor.shape,
                    dtype=tensor.dtype,
                    start_idx=current_idx,
                    end_idx=current_idx + numel,
                    numel=numel,
                )
                current_idx += numel

            # Concatenate all flattened tensors into one contiguous buffer
            self.flattened_tensor = torch.cat(flattened_tensors, dim=0)
        else:
            # Initialize from pre-flattened data
            if flattened_tensor is None or metadata is None:
                raise ValueError(
                    "Must provide either named_tensors or both flattened_tensor and metadata"
                )
            self.flattened_tensor = flattened_tensor
            self.metadata = metadata

    def get_flattened_tensor(self) -> torch.Tensor:
        """Get the flattened tensor containing all bucket tensors"""
        return self.flattened_tensor

    def get_metadata(self) -> List[FlattenedTensorMetadata]:
        """Get metadata for all tensors in the bucket"""
        return self.metadata

    def reconstruct_tensors(self) -> List[Tuple[str, torch.Tensor]]:
        """
        Reconstruct the original tensors from the flattened tensor.

        Uses memory-efficient slicing to minimize allocations and copies.
        """
        # Preallocate the result list
        reconstructed = [None] * len(self.metadata)

        for i, meta in enumerate(self.metadata):
            tensor = self.flattened_tensor[meta.start_idx : meta.end_idx].reshape(
                meta.shape
            )

            # Convert the dtype back if it differs from the flattened buffer's dtype
            if tensor.dtype != meta.dtype:
                tensor = tensor.to(meta.dtype)

            reconstructed[i] = (meta.name, tensor)

        return reconstructed
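Not part of the patch, but a quick round-trip check of the bucket semantics: flattening mixed-dtype tensors promotes them inside the shared buffer (via torch.cat), and reconstruct_tensors casts each slice back to its recorded dtype:

import torch

from sglang.srt.weight_sync.tensor_bucket import FlattenedTensorBucket

# Two small tensors with different shapes and dtypes.
named = [("a", torch.randn(2, 3)), ("b", torch.ones(4, dtype=torch.float16))]

bucket = FlattenedTensorBucket(named_tensors=named)
flat = bucket.get_flattened_tensor()
assert flat.numel() == 2 * 3 + 4  # one contiguous buffer

# Rebuild from the flattened tensor + metadata, as the receiving side does.
rebuilt = FlattenedTensorBucket(flattened_tensor=flat, metadata=bucket.get_metadata())
for (name, orig), (rname, rec) in zip(named, rebuilt.reconstruct_tensors()):
    assert name == rname
    assert rec.dtype == orig.dtype and torch.equal(rec, orig)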
@@ -5,6 +5,7 @@ import unittest
 import torch

 import sglang as sgl
+from sglang.srt.weight_sync.tensor_bucket import FlattenedTensorBucket
 from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST, CustomTestCase

@@ -112,6 +113,59 @@ class TestUpdateWeightsFromTensor(CustomTestCase):
         engine.shutdown()
+    def test_update_weights_from_tensor_load_format_flattened_bucket(self):
+        """Test updating weights using the flattened_bucket format."""
+        engine = sgl.Engine(model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST)
+
+        # Create a small set of parameters for testing
+        param_names = [f"model.layers.{i}.mlp.up_proj.weight" for i in range(6, 10)]
+
+        # Check the original values
+        _check_param(engine, param_names[0], [0.0087, -0.0214, -0.0004, 0.0039, 0.0110])
+
+        # Create new tensors, all filled with the same constant value
+        new_tensors = []
+        for name in param_names:
+            value = 2.0
+            new_tensor = torch.full((16384, 2048), value, device="cuda")
+            new_tensors.append((name, new_tensor))
+
+        # Create a flattened bucket
+        flattened_bucket = FlattenedTensorBucket(named_tensors=new_tensors)
+
+        # Extract the flattened tensor and metadata in the format expected by model_runner
+        flattened_tensor = flattened_bucket.get_flattened_tensor()
+        metadata = flattened_bucket.get_metadata()
+
+        # Create the dict format expected by _update_weights_from_flattened_bucket
+        bucket_dict = {"flattened_tensor": flattened_tensor, "metadata": metadata}
+
+        # Serialize the bucket data
+        from sglang.srt.utils import MultiprocessingSerializer
+
+        serialized_bucket = MultiprocessingSerializer.serialize(
+            bucket_dict, output_str=True
+        )
+
+        # Create a list where each rank receives the same serialized data.
+        # This simulates the distributed setup in which every rank gets the same payload.
+        serialized_bucket_list = [serialized_bucket]
+
+        # Update the weights using the flattened_bucket format
+        time_start = time.perf_counter()
+        engine.update_weights_from_tensor(
+            named_tensors=serialized_bucket_list, load_format="flattened_bucket"
+        )
+        update_time = time.perf_counter() - time_start
+        print(f"Flattened bucket update time: {update_time:.03f}s")
+
+        # Verify the weights were updated correctly
+        for param_name in param_names:
+            _check_param(engine, param_name, [2.0] * 5)
+
+        engine.shutdown()
 def _check_param(engine, param_name, expect_values):
     actual_values = torch.tensor(engine.get_weights_by_name(param_name))[0, :5]