Unverified commit 0de5e7d4, authored by fzyzcjy, committed by GitHub

Support layerwise rebalancing experts (#6851)

parent 72a110f6
 import logging
 import time
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, List

 import torch.cuda
@@ -20,6 +20,10 @@ class EPLBManager:
         super().__init__()

         self._model_runner = model_runner
         self._server_args = model_runner.server_args
+        self._rebalance_layers_per_chunk = (
+            self._server_args.eplb_rebalance_layers_per_chunk
+        )
+        self._rebalance_num_iterations = self._server_args.eplb_rebalance_num_iterations
         # Otherwise, the circular buffer will contain stale data. If the case is needed, it can be implemented.
         assert (
@@ -31,17 +35,30 @@ class EPLBManager:
             get_global_expert_distribution_recorder().start_record()
         logger.info(
-            f"[EPLBManager] system started, will rebalance per {self._server_args.eplb_rebalance_num_iterations} iterations."
+            f"[EPLBManager] system started, will rebalance per {self._rebalance_num_iterations} iterations."
         )

-    def on_forward_pass_end(self, forward_pass_id: int):
-        if forward_pass_id % self._server_args.eplb_rebalance_num_iterations == 0:
-            self.rebalance()
+        self._main_generator = self._entrypoint()
+
+    def on_forward_pass_end(self):
+        next(self._main_generator)
+
+    # can be more complex if needed
+    def _entrypoint(self):
+        while True:
+            for _ in range(self._rebalance_num_iterations):
+                yield
+            yield from self.rebalance()

     def rebalance(self):
         logger.info("[EPLBManager] rebalance start")
-        torch.cuda.synchronize()
-        time_start = time.time()
+
+        enable_timing = self._rebalance_layers_per_chunk is None
+        if enable_timing:
+            torch.cuda.synchronize()
+            time_start = time.time()

         logical_count = get_global_expert_distribution_recorder().dump_record(
             output_mode="object"
@@ -49,8 +66,31 @@ class EPLBManager:
         expert_location_metadata = ExpertLocationMetadata.init_by_eplb(
             self._server_args, self._model_runner.model_config, logical_count
         )
-        self._model_runner.update_expert_location(expert_location_metadata)

-        torch.cuda.synchronize()
-        time_end = time.time()
-        logger.info(f"[EPLBManager] rebalance end time={time_end - time_start:.3f}s")
+        update_layer_ids_chunks = self._compute_update_layer_ids_chunks()
+        for chunk_index, update_layer_ids in enumerate(update_layer_ids_chunks):
+            if len(update_layer_ids_chunks) > 1:
+                yield
+            self._model_runner.update_expert_location(
+                expert_location_metadata,
+                update_layer_ids=update_layer_ids,
+            )
+
+        msg = f"[EPLBManager] rebalance end"
+        if enable_timing:
+            torch.cuda.synchronize()
+            time_end = time.time()
+            msg += f" time={time_end - time_start:.3f}s"
+        logger.info(msg)
+
+    def _compute_update_layer_ids_chunks(self) -> List[List[int]]:
+        all_layer_ids = sorted(
+            list(self._model_runner.model.routed_experts_weights_of_layer.keys())
+        )
+        chunk_size = self._rebalance_layers_per_chunk or 1000000
+        return list(_chunk_list(all_layer_ids, chunk_size=chunk_size))
+
+
+def _chunk_list(items: List, chunk_size):
+    for start_index in range(0, len(items), chunk_size):
+        yield items[start_index : start_index + chunk_size]
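
The control-flow change above is easiest to see in isolation: the manager owns a long-lived generator, the per-pass hook just advances it, and a rebalance split into chunks yields between chunks so each chunk lands on a different forward pass. Below is a minimal standalone sketch of that pattern; all names (DummyManager, chunk_list, etc.) are hypothetical, and only the control flow mirrors the EPLBManager change.

from typing import Iterator, List


def chunk_list(items: List[int], chunk_size: int) -> Iterator[List[int]]:
    # Split a list of layer ids into consecutive chunks.
    for start in range(0, len(items), chunk_size):
        yield items[start : start + chunk_size]


class DummyManager:
    def __init__(self, rebalance_every: int, layers_per_chunk: int, num_layers: int):
        self._rebalance_every = rebalance_every
        self._layers_per_chunk = layers_per_chunk
        self._num_layers = num_layers
        self._gen = self._entrypoint()

    def on_forward_pass_end(self):
        next(self._gen)  # advance one step per forward pass

    def _entrypoint(self):
        while True:
            for _ in range(self._rebalance_every):
                yield  # idle passes between rebalances
            yield from self._rebalance()

    def _rebalance(self):
        chunks = list(chunk_list(list(range(self._num_layers)), self._layers_per_chunk))
        for layer_ids in chunks:
            if len(chunks) > 1:
                yield  # spread the chunks over consecutive forward passes
            print(f"updating layers {layer_ids}")


if __name__ == "__main__":
    manager = DummyManager(rebalance_every=3, layers_per_chunk=2, num_layers=4)
    for _ in range(10):
        manager.on_forward_pass_end()

With rebalance_every=3 and two chunks, the sketch prints the two layer groups on the fifth and sixth calls, i.e. the chunks are applied on consecutive forward passes after the idle window, which is exactly what --eplb-rebalance-layers-per-chunk buys in the real manager.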
@@ -33,6 +33,7 @@ logger = logging.getLogger(__name__)
 @dataclass
 class ExpertLocationMetadata:
     physical_to_logical_map: torch.Tensor  # (layers, num_physical_experts)
+    physical_to_logical_map_cpu: torch.Tensor
     logical_to_all_physical_map: torch.Tensor  # (layers, num_logical_experts, X)
     logical_to_all_physical_map_num_valid: torch.Tensor  # (layers, num_logical_experts)
     # (layers, num_logical_experts)
@@ -203,6 +204,7 @@ class ExpertLocationMetadata:
         return ExpertLocationMetadata(
             physical_to_logical_map=physical_to_logical_map,
+            physical_to_logical_map_cpu=physical_to_logical_map.cpu(),
             logical_to_all_physical_map=logical_to_all_physical_map_padded,
             logical_to_all_physical_map_num_valid=logical_to_all_physical_map_num_valid,
             logical_to_rank_dispatch_physical_map=(
@@ -223,6 +225,7 @@ class ExpertLocationMetadata:
     def update(
         self,
         other: "ExpertLocationMetadata",
+        update_layer_ids: List[int],
     ):
         for field in [
             "ep_size",
@@ -231,15 +234,21 @@ class ExpertLocationMetadata:
         for field in [
             "physical_to_logical_map",
+            "physical_to_logical_map_cpu",
             "logical_to_all_physical_map",
             "logical_to_all_physical_map_num_valid",
             "logical_to_rank_dispatch_physical_map",
         ]:
-            src = getattr(other, field)
-            dst = getattr(self, field)
-            assert (src is not None) == (dst is not None)
-            if dst is not None:
-                dst[...] = src
+            other_field = getattr(other, field)
+            self_field = getattr(self, field)
+            assert (other_field is not None) == (self_field is not None)
+            if self_field is not None:
+                mask_update = torch.tensor(
+                    [i in update_layer_ids for i in range(self.num_layers)]
+                )
+                mask_update = mask_update.view(*([-1] + [1] * (self_field.dim() - 1)))
+                mask_update = mask_update.to(self_field.device, non_blocking=True)
+                self_field[...] = torch.where(mask_update, other_field, self_field)


 # -------------------------------- usage ------------------------------------
......
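The update() change above replaces the unconditional dst[...] = src copy with a per-layer select, so only the layers in update_layer_ids take the new metadata while all other layers keep their current values. A self-contained sketch with toy tensors (shapes are illustrative only; the real tensors live in ExpertLocationMetadata):

import torch

num_layers, num_physical_experts = 4, 8
old = torch.zeros(num_layers, num_physical_experts, dtype=torch.int64)
new = torch.ones(num_layers, num_physical_experts, dtype=torch.int64)

update_layer_ids = [1, 3]  # only these layers take the new values

# Boolean mask over the layer dimension, reshaped to broadcast over the rest.
mask = torch.tensor([i in update_layer_ids for i in range(num_layers)])
mask = mask.view(*([-1] + [1] * (old.dim() - 1)))

# In-place write: updated layers come from `new`, the others keep the old values.
old[...] = torch.where(mask, new, old)

assert old[1].eq(1).all() and old[0].eq(0).all()

The view(*([-1] + [1] * (dim - 1))) reshape turns the layer mask into shape (layers, 1, ..., 1), so the same masking code works for every metadata field regardless of how many trailing dimensions it has.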
@@ -24,6 +24,7 @@ from sglang.srt.managers.expert_location import (
     ExpertLocationMetadata,
     get_global_expert_location_metadata,
 )
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.utils import get_bool_env_var

 logger = logging.getLogger(__name__)
@@ -37,6 +38,7 @@ class ExpertLocationUpdater:
         self,
         routed_experts_weights_of_layer: Dict[int, List[torch.Tensor]],
         new_expert_location_metadata: ExpertLocationMetadata,
+        update_layer_ids: List[int],
         nnodes: int,
         rank: int,
     ):
@@ -46,45 +48,47 @@ class ExpertLocationUpdater:
         old_expert_location_metadata = get_global_expert_location_metadata()

         _update_expert_weights(
-            routed_experts_weights_of_layer,
-            old_expert_location_metadata,
-            new_expert_location_metadata,
-            nnodes,
-            rank,
+            routed_experts_weights_of_layer=routed_experts_weights_of_layer,
+            old_expert_location_metadata=old_expert_location_metadata,
+            new_expert_location_metadata=new_expert_location_metadata,
+            update_layer_ids=update_layer_ids,
+            nnodes=nnodes,
+            rank=rank,
+        )
+        old_expert_location_metadata.update(
+            new_expert_location_metadata,
+            update_layer_ids=update_layer_ids,
         )
-        old_expert_location_metadata.update(new_expert_location_metadata)


 def _update_expert_weights(
     routed_experts_weights_of_layer: Dict[int, List[torch.Tensor]],
     old_expert_location_metadata: ExpertLocationMetadata,
     new_expert_location_metadata: ExpertLocationMetadata,
+    update_layer_ids: List[int],
     nnodes: int,
     rank: int,
 ):
     log_metrics = get_bool_env_var("SGLANG_EXPERT_LOCATION_UPDATER_LOG_METRICS")

     temp_buffers = create_temp_buffers(
-        next(iter(routed_experts_weights_of_layer.values()))
+        routed_experts_weights_of_layer[update_layer_ids[0]]
     )

     world_size = torch.distributed.get_world_size()
     num_local_physical_experts = old_expert_location_metadata.num_local_physical_experts
     num_gpu_per_node = world_size // nnodes

-    old_physical_to_logical_map = (
-        old_expert_location_metadata.physical_to_logical_map.tolist()
-    )
-    new_physical_to_logical_map = (
-        new_expert_location_metadata.physical_to_logical_map.tolist()
-    )
-
-    for layer_id in sorted(routed_experts_weights_of_layer.keys()):
+    for layer_id in update_layer_ids:
         update_expert_weights_single_layer(
             routed_experts_weights=routed_experts_weights_of_layer[layer_id],
             temp_buffers=temp_buffers,
-            old_physical_to_logical_map=old_physical_to_logical_map[layer_id],
-            new_physical_to_logical_map=new_physical_to_logical_map[layer_id],
+            old_physical_to_logical_map=old_expert_location_metadata.physical_to_logical_map_cpu[
+                layer_id
+            ].tolist(),
+            new_physical_to_logical_map=new_expert_location_metadata.physical_to_logical_map_cpu[
+                layer_id
+            ].tolist(),
             num_local_physical_experts=num_local_physical_experts,
             num_gpu_per_node=num_gpu_per_node,
             rank=rank,
......
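A note on the new physical_to_logical_map_cpu field used above: the old code converted the whole GPU-resident map to a Python list once per rebalance, while the chunked path looks up one layer at a time, so keeping a CPU copy (built once in init_by_eplb) presumably avoids a device-to-host copy and sync inside the per-layer loop. A toy sketch of the indexing, with made-up shapes:

import torch

num_layers, num_physical_experts = 2, 4
# In the real metadata this tensor lives on the GPU; here it is CPU-only.
physical_to_logical_map = torch.arange(num_layers * num_physical_experts).view(
    num_layers, num_physical_experts
)
physical_to_logical_map_cpu = physical_to_logical_map.cpu()  # built once, reused per chunk

for layer_id in [1]:  # e.g. the layer ids of one chunk
    # One row per updated layer, read from host memory.
    row = physical_to_logical_map_cpu[layer_id].tolist()
    print(layer_id, row)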
@@ -611,11 +611,14 @@ class ModelRunner:
             ) from None

     def update_expert_location(
-        self, new_expert_location_metadata: ExpertLocationMetadata
+        self,
+        new_expert_location_metadata: ExpertLocationMetadata,
+        update_layer_ids: List[int],
     ):
         self.expert_location_updater.update(
             self.model.routed_experts_weights_of_layer,
             new_expert_location_metadata,
+            update_layer_ids=update_layer_ids,
             nnodes=self.server_args.nnodes,
             rank=self.tp_rank,
         )
@@ -1203,7 +1206,7 @@ class ModelRunner:
         )

         if self.eplb_manager is not None:
-            self.eplb_manager.on_forward_pass_end(self.forward_pass_id)
+            self.eplb_manager.on_forward_pass_end()

         return output
......
@@ -180,6 +180,7 @@ class ServerArgs:
     enable_eplb: bool = False
     eplb_algorithm: str = "auto"
     eplb_rebalance_num_iterations: int = 1000
+    eplb_rebalance_layers_per_chunk: Optional[int] = None
     expert_distribution_recorder_mode: Optional[
         Literal["stat", "per_pass", "per_token"]
     ] = None
@@ -1367,6 +1368,12 @@ class ServerArgs:
             default=ServerArgs.eplb_rebalance_num_iterations,
             help="Number of iterations to automatically trigger a EPLB re-balance.",
         )
+        parser.add_argument(
+            "--eplb-rebalance-layers-per-chunk",
+            type=int,
+            default=ServerArgs.eplb_rebalance_layers_per_chunk,
+            help="Number of layers to rebalance per forward pass.",
+        )
         parser.add_argument(
             "--expert-distribution-recorder-mode",
             type=str,
......
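For reference, when --eplb-rebalance-layers-per-chunk is left unset the manager updates all layers in a single chunk and keeps the timing log; setting it spreads one rebalance over several forward passes. A hedged usage sketch, assuming sgl.Engine forwards keyword arguments to ServerArgs; the model path is a placeholder, not a tested checkpoint:

import sglang as sgl

engine = sgl.Engine(
    model_path="/path/to/some-moe-model",   # placeholder path
    enable_eplb=True,                       # turn on dynamic expert rebalancing
    eplb_rebalance_num_iterations=1000,     # trigger a rebalance every 1000 passes
    eplb_rebalance_layers_per_chunk=4,      # rebalance 4 layers per forward pass
)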
@@ -5,7 +5,6 @@ from pathlib import Path
 from types import SimpleNamespace

 import sglang as sgl
-from sglang.srt.managers.expert_distribution_storage import ExpertDistributionStorage
 from sglang.srt.utils import kill_process_tree
 from sglang.test.run_eval import run_eval
 from sglang.test.test_utils import (
@@ -17,7 +16,9 @@ from sglang.test.test_utils import (
 )


-class TestDynamicEPLB(CustomTestCase):
+class _BaseTestDynamicEPLB(CustomTestCase):
+    extra_args = []
+
     @classmethod
     def setUpClass(cls):
         cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST
@@ -51,8 +52,13 @@ class TestDynamicEPLB(CustomTestCase):
                 "stat",
                 "--ep-dispatch-algorithm",
                 "static",
+                *cls.extra_args,
             ],
-            env={"SGL_ENABLE_JIT_DEEPGEMM": "0", **os.environ},
+            env={
+                "SGL_ENABLE_JIT_DEEPGEMM": "0",
+                "SGLANG_EXPERT_LOCATION_UPDATER_CANARY": "1",
+                **os.environ,
+            },
         )

     @classmethod
@@ -72,6 +78,14 @@ class TestDynamicEPLB(CustomTestCase):
         self.assertGreater(metrics["score"], 0.5)


+class TestDynamicEPLBSimple(_BaseTestDynamicEPLB):
+    pass
+
+
+class TestDynamicEPLBMultiChunk(_BaseTestDynamicEPLB):
+    extra_args = ["--eplb-rebalance-layers-per-chunk", "1"]
+
+
 class TestStaticEPLB(CustomTestCase):
     def test_save_expert_distribution_and_init_expert_location(self):
         os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "0"
......