Unverified commit 0de5e7d4 authored by fzyzcjy, committed by GitHub

Support layerwise rebalancing experts (#6851)

parent 72a110f6
import logging
import time
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, List
import torch.cuda
@@ -20,6 +20,10 @@ class EPLBManager:
super().__init__()
self._model_runner = model_runner
self._server_args = model_runner.server_args
self._rebalance_layers_per_chunk = (
self._server_args.eplb_rebalance_layers_per_chunk
)
self._rebalance_num_iterations = self._server_args.eplb_rebalance_num_iterations
# Otherwise, the circular buffer will contain stale data. If that case is ever needed, it can be implemented.
assert (
@@ -31,17 +35,30 @@ class EPLBManager:
get_global_expert_distribution_recorder().start_record()
logger.info(
f"[EPLBManager] system started, will rebalance per {self._server_args.eplb_rebalance_num_iterations} iterations."
f"[EPLBManager] system started, will rebalance per {self._rebalance_num_iterations} iterations."
)
def on_forward_pass_end(self, forward_pass_id: int):
if forward_pass_id % self._server_args.eplb_rebalance_num_iterations == 0:
self.rebalance()
self._main_generator = self._entrypoint()
def on_forward_pass_end(self):
next(self._main_generator)
# can be more complex if needed
def _entrypoint(self):
while True:
for _ in range(self._rebalance_num_iterations):
yield
yield from self.rebalance()
def rebalance(self):
logger.info("[EPLBManager] rebalance start")
torch.cuda.synchronize()
time_start = time.time()
enable_timing = self._rebalance_layers_per_chunk is None
if enable_timing:
torch.cuda.synchronize()
time_start = time.time()
logical_count = get_global_expert_distribution_recorder().dump_record(
output_mode="object"
@@ -49,8 +66,31 @@
expert_location_metadata = ExpertLocationMetadata.init_by_eplb(
self._server_args, self._model_runner.model_config, logical_count
)
self._model_runner.update_expert_location(expert_location_metadata)
torch.cuda.synchronize()
time_end = time.time()
logger.info(f"[EPLBManager] rebalance end time={time_end - time_start:.3f}s")
update_layer_ids_chunks = self._compute_update_layer_ids_chunks()
for chunk_index, update_layer_ids in enumerate(update_layer_ids_chunks):
if len(update_layer_ids_chunks) > 1:
yield
self._model_runner.update_expert_location(
expert_location_metadata,
update_layer_ids=update_layer_ids,
)
msg = f"[EPLBManager] rebalance end"
if enable_timing:
torch.cuda.synchronize()
time_end = time.time()
msg += f" time={time_end - time_start:.3f}s"
logger.info(msg)
def _compute_update_layer_ids_chunks(self) -> List[List[int]]:
all_layer_ids = sorted(
list(self._model_runner.model.routed_experts_weights_of_layer.keys())
)
chunk_size = self._rebalance_layers_per_chunk or 1000000
return list(_chunk_list(all_layer_ids, chunk_size=chunk_size))
def _chunk_list(items: List, chunk_size):
for start_index in range(0, len(items), chunk_size):
yield items[start_index : start_index + chunk_size]
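
For orientation, here is a minimal self-contained sketch of the generator pattern introduced above (class and variable names are illustrative, not the real sglang EPLBManager): the scheduler calls on_forward_pass_end() once per forward pass, the generator idles for a fixed number of passes, and a chunked rebalance yields between chunks so each pass only updates one chunk of layers.

# Illustrative sketch of the yield-per-forward-pass pattern; not the real EPLBManager.
from typing import Iterator, List


def _chunk_list(items: List[int], chunk_size: int) -> Iterator[List[int]]:
    for start_index in range(0, len(items), chunk_size):
        yield items[start_index : start_index + chunk_size]


class ToyEplbManager:
    def __init__(self, rebalance_num_iterations: int, layer_ids: List[int], layers_per_chunk: int):
        self._rebalance_num_iterations = rebalance_num_iterations
        self._layer_ids = layer_ids
        self._layers_per_chunk = layers_per_chunk
        self._main_generator = self._entrypoint()

    def on_forward_pass_end(self):
        # Called once per forward pass; advances the state machine by one step.
        next(self._main_generator)

    def _entrypoint(self):
        while True:
            for _ in range(self._rebalance_num_iterations):
                yield  # idle passes between rebalances
            yield from self.rebalance()  # a rebalance may span several passes

    def rebalance(self):
        chunks = list(_chunk_list(self._layer_ids, chunk_size=self._layers_per_chunk))
        for chunk in chunks:
            if len(chunks) > 1:
                yield  # hand control back to the scheduler before each chunk
            print(f"updating expert weights for layers {chunk}")


manager = ToyEplbManager(rebalance_num_iterations=3, layer_ids=[0, 1, 2, 3], layers_per_chunk=2)
for _ in range(12):
    manager.on_forward_pass_end()
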
@@ -33,6 +33,7 @@ logger = logging.getLogger(__name__)
@dataclass
class ExpertLocationMetadata:
physical_to_logical_map: torch.Tensor # (layers, num_physical_experts)
physical_to_logical_map_cpu: torch.Tensor
logical_to_all_physical_map: torch.Tensor # (layers, num_logical_experts, X)
logical_to_all_physical_map_num_valid: torch.Tensor # (layers, num_logical_experts)
# (layers, num_logical_experts)
@@ -203,6 +204,7 @@ class ExpertLocationMetadata:
return ExpertLocationMetadata(
physical_to_logical_map=physical_to_logical_map,
physical_to_logical_map_cpu=physical_to_logical_map.cpu(),
logical_to_all_physical_map=logical_to_all_physical_map_padded,
logical_to_all_physical_map_num_valid=logical_to_all_physical_map_num_valid,
logical_to_rank_dispatch_physical_map=(
@@ -223,6 +225,7 @@ class ExpertLocationMetadata:
def update(
self,
other: "ExpertLocationMetadata",
update_layer_ids: List[int],
):
for field in [
"ep_size",
@@ -231,15 +234,21 @@ class ExpertLocationMetadata:
for field in [
"physical_to_logical_map",
"physical_to_logical_map_cpu",
"logical_to_all_physical_map",
"logical_to_all_physical_map_num_valid",
"logical_to_rank_dispatch_physical_map",
]:
src = getattr(other, field)
dst = getattr(self, field)
assert (src is not None) == (dst is not None)
if dst is not None:
dst[...] = src
other_field = getattr(other, field)
self_field = getattr(self, field)
assert (other_field is not None) == (self_field is not None)
if self_field is not None:
mask_update = torch.tensor(
[i in update_layer_ids for i in range(self.num_layers)]
)
mask_update = mask_update.view(*([-1] + [1] * (self_field.dim() - 1)))
mask_update = mask_update.to(self_field.device, non_blocking=True)
self_field[...] = torch.where(mask_update, other_field, self_field)
# -------------------------------- usage ------------------------------------
......
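
The masked assignment in ExpertLocationMetadata.update overwrites only the rows whose layer id is in update_layer_ids. A small standalone sketch of that pattern with made-up shapes (not the real metadata tensors):

# Sketch of the per-layer masked copy; shapes and values are made up for illustration.
import torch

num_layers, num_physical_experts = 4, 6
self_field = torch.zeros(num_layers, num_physical_experts, dtype=torch.int64)
other_field = torch.arange(num_layers * num_physical_experts).view(num_layers, num_physical_experts)

update_layer_ids = [1, 3]
mask_update = torch.tensor([i in update_layer_ids for i in range(num_layers)])
# Reshape to (num_layers, 1) so the boolean mask broadcasts over the trailing dims;
# the real code also moves the mask to self_field.device with non_blocking=True.
mask_update = mask_update.view(*([-1] + [1] * (self_field.dim() - 1)))
self_field[...] = torch.where(mask_update, other_field, self_field)

print(self_field)  # rows 1 and 3 come from other_field; rows 0 and 2 keep their old zeros
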
@@ -24,6 +24,7 @@ from sglang.srt.managers.expert_location import (
ExpertLocationMetadata,
get_global_expert_location_metadata,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.utils import get_bool_env_var
logger = logging.getLogger(__name__)
@@ -37,6 +38,7 @@ class ExpertLocationUpdater:
self,
routed_experts_weights_of_layer: Dict[int, List[torch.Tensor]],
new_expert_location_metadata: ExpertLocationMetadata,
update_layer_ids: List[int],
nnodes: int,
rank: int,
):
@@ -46,45 +48,47 @@
old_expert_location_metadata = get_global_expert_location_metadata()
_update_expert_weights(
routed_experts_weights_of_layer,
old_expert_location_metadata,
routed_experts_weights_of_layer=routed_experts_weights_of_layer,
old_expert_location_metadata=old_expert_location_metadata,
new_expert_location_metadata=new_expert_location_metadata,
update_layer_ids=update_layer_ids,
nnodes=nnodes,
rank=rank,
)
old_expert_location_metadata.update(
new_expert_location_metadata,
nnodes,
rank,
update_layer_ids=update_layer_ids,
)
old_expert_location_metadata.update(new_expert_location_metadata)
def _update_expert_weights(
routed_experts_weights_of_layer: Dict[int, List[torch.Tensor]],
old_expert_location_metadata: ExpertLocationMetadata,
new_expert_location_metadata: ExpertLocationMetadata,
update_layer_ids: List[int],
nnodes: int,
rank: int,
):
log_metrics = get_bool_env_var("SGLANG_EXPERT_LOCATION_UPDATER_LOG_METRICS")
temp_buffers = create_temp_buffers(
next(iter(routed_experts_weights_of_layer.values()))
routed_experts_weights_of_layer[update_layer_ids[0]]
)
world_size = torch.distributed.get_world_size()
num_local_physical_experts = old_expert_location_metadata.num_local_physical_experts
num_gpu_per_node = world_size // nnodes
old_physical_to_logical_map = (
old_expert_location_metadata.physical_to_logical_map.tolist()
)
new_physical_to_logical_map = (
new_expert_location_metadata.physical_to_logical_map.tolist()
)
for layer_id in sorted(routed_experts_weights_of_layer.keys()):
for layer_id in update_layer_ids:
update_expert_weights_single_layer(
routed_experts_weights=routed_experts_weights_of_layer[layer_id],
temp_buffers=temp_buffers,
old_physical_to_logical_map=old_physical_to_logical_map[layer_id],
new_physical_to_logical_map=new_physical_to_logical_map[layer_id],
old_physical_to_logical_map=old_expert_location_metadata.physical_to_logical_map_cpu[
layer_id
].tolist(),
new_physical_to_logical_map=new_expert_location_metadata.physical_to_logical_map_cpu[
layer_id
].tolist(),
num_local_physical_experts=num_local_physical_experts,
num_gpu_per_node=num_gpu_per_node,
rank=rank,
......
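
One detail worth noting: the updater now reads each layer's map from the new physical_to_logical_map_cpu mirror rather than calling .tolist() on the device tensor, so the per-layer bookkeeping during a chunked rebalance is pure host-side work. A tiny illustrative sketch of the idea, with made-up shapes:

# Illustrative sketch (made-up shapes): keep a CPU mirror of the map so that the
# per-layer .tolist() calls during a chunked rebalance stay on the host and do not
# force a device-to-host copy on every chunk.
import torch

num_layers, num_physical_experts = 4, 6
physical_to_logical_map = torch.randint(0, 8, (num_layers, num_physical_experts))
# In the real metadata this tensor lives on the GPU; the CPU copy is built once at construction.
physical_to_logical_map_cpu = physical_to_logical_map.cpu()

update_layer_ids = [0, 2]
for layer_id in update_layer_ids:
    per_layer_map = physical_to_logical_map_cpu[layer_id].tolist()  # pure host-side work
    print(layer_id, per_layer_map)
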
@@ -611,11 +611,14 @@ class ModelRunner:
) from None
def update_expert_location(
self, new_expert_location_metadata: ExpertLocationMetadata
self,
new_expert_location_metadata: ExpertLocationMetadata,
update_layer_ids: List[int],
):
self.expert_location_updater.update(
self.model.routed_experts_weights_of_layer,
new_expert_location_metadata,
update_layer_ids=update_layer_ids,
nnodes=self.server_args.nnodes,
rank=self.tp_rank,
)
@@ -1203,7 +1206,7 @@ class ModelRunner:
)
if self.eplb_manager is not None:
self.eplb_manager.on_forward_pass_end(self.forward_pass_id)
self.eplb_manager.on_forward_pass_end()
return output
......
@@ -180,6 +180,7 @@ class ServerArgs:
enable_eplb: bool = False
eplb_algorithm: str = "auto"
eplb_rebalance_num_iterations: int = 1000
eplb_rebalance_layers_per_chunk: Optional[int] = None
expert_distribution_recorder_mode: Optional[
Literal["stat", "per_pass", "per_token"]
] = None
@@ -1367,6 +1368,12 @@ class ServerArgs:
default=ServerArgs.eplb_rebalance_num_iterations,
help="Number of iterations to automatically trigger a EPLB re-balance.",
)
parser.add_argument(
"--eplb-rebalance-layers-per-chunk",
type=int,
default=ServerArgs.eplb_rebalance_layers_per_chunk,
help="Number of layers to rebalance per forward pass.",
)
parser.add_argument(
"--expert-distribution-recorder-mode",
type=str,
......
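
To see how the new flag interacts with the existing interval, here is a toy parser mirroring the two options (not the real ServerArgs CLI): leaving --eplb-rebalance-layers-per-chunk unset keeps the previous single-shot behavior with timing, while a small value spreads one rebalance over several forward passes.

# Toy parser mirroring the two EPLB rebalance options; not the real sglang argument parser.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--eplb-rebalance-num-iterations", type=int, default=1000)
parser.add_argument("--eplb-rebalance-layers-per-chunk", type=int, default=None)

args = parser.parse_args(["--eplb-rebalance-layers-per-chunk", "1"])

# None -> a single huge chunk (all layers updated in one pass, timing enabled);
# 1    -> one layer per forward pass, so a model with 60 MoE layers spreads each
#         rebalance over 60 passes after the usual 1000-iteration interval.
chunk_size = args.eplb_rebalance_layers_per_chunk or 1_000_000
print(chunk_size)
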
@@ -5,7 +5,6 @@ from pathlib import Path
from types import SimpleNamespace
import sglang as sgl
from sglang.srt.managers.expert_distribution_storage import ExpertDistributionStorage
from sglang.srt.utils import kill_process_tree
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
@@ -17,7 +16,9 @@ from sglang.test.test_utils import (
)
class TestDynamicEPLB(CustomTestCase):
class _BaseTestDynamicEPLB(CustomTestCase):
extra_args = []
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST
@@ -51,8 +52,13 @@ class TestDynamicEPLB(CustomTestCase):
"stat",
"--ep-dispatch-algorithm",
"static",
*cls.extra_args,
],
env={"SGL_ENABLE_JIT_DEEPGEMM": "0", **os.environ},
env={
"SGL_ENABLE_JIT_DEEPGEMM": "0",
"SGLANG_EXPERT_LOCATION_UPDATER_CANARY": "1",
**os.environ,
},
)
@classmethod
@@ -72,6 +78,14 @@ class TestDynamicEPLB(CustomTestCase):
self.assertGreater(metrics["score"], 0.5)
class TestDynamicEPLBSimple(_BaseTestDynamicEPLB):
pass
class TestDynamicEPLBMultiChunk(_BaseTestDynamicEPLB):
extra_args = ["--eplb-rebalance-layers-per-chunk", "1"]
class TestStaticEPLB(CustomTestCase):
def test_save_expert_distribution_and_init_expert_location(self):
os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "0"
......