Unverified commit df7f61ee, authored by fzyzcjy and committed by GitHub
Browse files

Speed up rebalancing when using non-static dispatch algorithms (#6812)

parent ef21729c
......@@ -35,7 +35,8 @@ class ExpertLocationMetadata:
physical_to_logical_map: torch.Tensor # (layers, num_physical_experts)
logical_to_all_physical_map: torch.Tensor # (layers, num_logical_experts, X)
logical_to_all_physical_map_num_valid: torch.Tensor # (layers, num_logical_experts)
logical_to_rank_dispatch_physical_map: torch.Tensor # (layers, num_logical_experts)
# (layers, num_logical_experts)
logical_to_rank_dispatch_physical_map: Optional[torch.Tensor]
# -------------------------------- properties ------------------------------------
......@@ -70,11 +71,8 @@ class ExpertLocationMetadata:
num_layers_2, num_logical_experts_1 = (
self.logical_to_all_physical_map_num_valid.shape
)
num_layers_3, num_logical_experts_2 = (
self.logical_to_rank_dispatch_physical_map.shape
)
assert num_layers_0 == num_layers_1 == num_layers_2 == num_layers_3
assert num_logical_experts_0 == num_logical_experts_1 == num_logical_experts_2
assert num_layers_0 == num_layers_1 == num_layers_2
assert num_logical_experts_0 == num_logical_experts_1
assert num_physical_experts_0 == num_physical_experts_1
# -------------------------------- construction ------------------------------------
......@@ -117,6 +115,7 @@ class ExpertLocationMetadata:
)
return ExpertLocationMetadata._init_raw(
server_args=server_args,
ep_size=common["ep_size"],
physical_to_logical_map=physical_to_logical_map,
logical_to_all_physical_map=logical_to_all_physical_map,
......@@ -154,6 +153,7 @@ class ExpertLocationMetadata:
)
return ExpertLocationMetadata._init_raw(
server_args=server_args,
ep_size=common["ep_size"],
physical_to_logical_map=physical_to_logical_map.to(server_args.device),
logical_to_all_physical_map=logical_to_all_physical_map.to(
......@@ -184,6 +184,7 @@ class ExpertLocationMetadata:
@staticmethod
def _init_raw(
server_args: ServerArgs,
ep_size: int,
physical_to_logical_map: torch.Tensor,
logical_to_all_physical_map: torch.Tensor,
......@@ -204,12 +205,16 @@ class ExpertLocationMetadata:
physical_to_logical_map=physical_to_logical_map,
logical_to_all_physical_map=logical_to_all_physical_map_padded,
logical_to_all_physical_map_num_valid=logical_to_all_physical_map_num_valid,
logical_to_rank_dispatch_physical_map=compute_logical_to_rank_dispatch_physical_map(
logical_to_all_physical_map=logical_to_all_physical_map,
num_gpus=ep_size,
num_physical_experts=num_physical_experts,
# TODO improve when we have real EP rank
ep_rank=torch.distributed.get_rank() % ep_size,
logical_to_rank_dispatch_physical_map=(
compute_logical_to_rank_dispatch_physical_map(
logical_to_all_physical_map=logical_to_all_physical_map,
num_gpus=ep_size,
num_physical_experts=num_physical_experts,
# TODO improve when we have real EP rank
ep_rank=torch.distributed.get_rank() % ep_size,
)
if server_args.ep_dispatch_algorithm == "static"
else None
),
)
......@@ -230,8 +235,11 @@ class ExpertLocationMetadata:
"logical_to_all_physical_map_num_valid",
"logical_to_rank_dispatch_physical_map",
]:
src = getattr(other, field)
dst = getattr(self, field)
dst[...] = getattr(other, field)
assert (src is not None) == (dst is not None)
if dst is not None:
dst[...] = src
# -------------------------------- usage ------------------------------------
......
......@@ -25,7 +25,7 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict
class ExpertLocationDispatchInfo:
ep_dispatch_algorithm: Literal["static", "random"]
# (num_logical_experts,)
partial_logical_to_rank_dispatch_physical_map: torch.Tensor
partial_logical_to_rank_dispatch_physical_map: Optional[torch.Tensor]
# (num_logical_experts, X)
partial_logical_to_all_physical_map: torch.Tensor
# (num_logical_experts,)
......@@ -42,9 +42,14 @@ class ExpertLocationDispatchInfo:
return cls(
ep_dispatch_algorithm=ep_dispatch_algorithm,
partial_logical_to_rank_dispatch_physical_map=expert_location_metadata.logical_to_rank_dispatch_physical_map[
layer_id, :
],
partial_logical_to_rank_dispatch_physical_map=(
expert_location_metadata.logical_to_rank_dispatch_physical_map[
layer_id, :
]
if expert_location_metadata.logical_to_rank_dispatch_physical_map
is not None
else None
),
partial_logical_to_all_physical_map=expert_location_metadata.logical_to_all_physical_map[
layer_id, :
],
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment