Unverified Commit df7f61ee authored by fzyzcjy's avatar fzyzcjy Committed by GitHub
Browse files

Speed up rebalancing when using non-static dispatch algorithms (#6812)

parent ef21729c
...@@ -35,7 +35,8 @@ class ExpertLocationMetadata: ...@@ -35,7 +35,8 @@ class ExpertLocationMetadata:
physical_to_logical_map: torch.Tensor # (layers, num_physical_experts) physical_to_logical_map: torch.Tensor # (layers, num_physical_experts)
logical_to_all_physical_map: torch.Tensor # (layers, num_logical_experts, X) logical_to_all_physical_map: torch.Tensor # (layers, num_logical_experts, X)
logical_to_all_physical_map_num_valid: torch.Tensor # (layers, num_logical_experts) logical_to_all_physical_map_num_valid: torch.Tensor # (layers, num_logical_experts)
logical_to_rank_dispatch_physical_map: torch.Tensor # (layers, num_logical_experts) # (layers, num_logical_experts)
logical_to_rank_dispatch_physical_map: Optional[torch.Tensor]
# -------------------------------- properties ------------------------------------ # -------------------------------- properties ------------------------------------
...@@ -70,11 +71,8 @@ class ExpertLocationMetadata: ...@@ -70,11 +71,8 @@ class ExpertLocationMetadata:
num_layers_2, num_logical_experts_1 = ( num_layers_2, num_logical_experts_1 = (
self.logical_to_all_physical_map_num_valid.shape self.logical_to_all_physical_map_num_valid.shape
) )
num_layers_3, num_logical_experts_2 = ( assert num_layers_0 == num_layers_1 == num_layers_2
self.logical_to_rank_dispatch_physical_map.shape assert num_logical_experts_0 == num_logical_experts_1
)
assert num_layers_0 == num_layers_1 == num_layers_2 == num_layers_3
assert num_logical_experts_0 == num_logical_experts_1 == num_logical_experts_2
assert num_physical_experts_0 == num_physical_experts_1 assert num_physical_experts_0 == num_physical_experts_1
# -------------------------------- construction ------------------------------------ # -------------------------------- construction ------------------------------------
...@@ -117,6 +115,7 @@ class ExpertLocationMetadata: ...@@ -117,6 +115,7 @@ class ExpertLocationMetadata:
) )
return ExpertLocationMetadata._init_raw( return ExpertLocationMetadata._init_raw(
server_args=server_args,
ep_size=common["ep_size"], ep_size=common["ep_size"],
physical_to_logical_map=physical_to_logical_map, physical_to_logical_map=physical_to_logical_map,
logical_to_all_physical_map=logical_to_all_physical_map, logical_to_all_physical_map=logical_to_all_physical_map,
...@@ -154,6 +153,7 @@ class ExpertLocationMetadata: ...@@ -154,6 +153,7 @@ class ExpertLocationMetadata:
) )
return ExpertLocationMetadata._init_raw( return ExpertLocationMetadata._init_raw(
server_args=server_args,
ep_size=common["ep_size"], ep_size=common["ep_size"],
physical_to_logical_map=physical_to_logical_map.to(server_args.device), physical_to_logical_map=physical_to_logical_map.to(server_args.device),
logical_to_all_physical_map=logical_to_all_physical_map.to( logical_to_all_physical_map=logical_to_all_physical_map.to(
...@@ -184,6 +184,7 @@ class ExpertLocationMetadata: ...@@ -184,6 +184,7 @@ class ExpertLocationMetadata:
@staticmethod @staticmethod
def _init_raw( def _init_raw(
server_args: ServerArgs,
ep_size: int, ep_size: int,
physical_to_logical_map: torch.Tensor, physical_to_logical_map: torch.Tensor,
logical_to_all_physical_map: torch.Tensor, logical_to_all_physical_map: torch.Tensor,
...@@ -204,12 +205,16 @@ class ExpertLocationMetadata: ...@@ -204,12 +205,16 @@ class ExpertLocationMetadata:
physical_to_logical_map=physical_to_logical_map, physical_to_logical_map=physical_to_logical_map,
logical_to_all_physical_map=logical_to_all_physical_map_padded, logical_to_all_physical_map=logical_to_all_physical_map_padded,
logical_to_all_physical_map_num_valid=logical_to_all_physical_map_num_valid, logical_to_all_physical_map_num_valid=logical_to_all_physical_map_num_valid,
logical_to_rank_dispatch_physical_map=compute_logical_to_rank_dispatch_physical_map( logical_to_rank_dispatch_physical_map=(
logical_to_all_physical_map=logical_to_all_physical_map, compute_logical_to_rank_dispatch_physical_map(
num_gpus=ep_size, logical_to_all_physical_map=logical_to_all_physical_map,
num_physical_experts=num_physical_experts, num_gpus=ep_size,
# TODO improve when we have real EP rank num_physical_experts=num_physical_experts,
ep_rank=torch.distributed.get_rank() % ep_size, # TODO improve when we have real EP rank
ep_rank=torch.distributed.get_rank() % ep_size,
)
if server_args.ep_dispatch_algorithm == "static"
else None
), ),
) )
...@@ -230,8 +235,11 @@ class ExpertLocationMetadata: ...@@ -230,8 +235,11 @@ class ExpertLocationMetadata:
"logical_to_all_physical_map_num_valid", "logical_to_all_physical_map_num_valid",
"logical_to_rank_dispatch_physical_map", "logical_to_rank_dispatch_physical_map",
]: ]:
src = getattr(other, field)
dst = getattr(self, field) dst = getattr(self, field)
dst[...] = getattr(other, field) assert (src is not None) == (dst is not None)
if dst is not None:
dst[...] = src
# -------------------------------- usage ------------------------------------ # -------------------------------- usage ------------------------------------
......
...@@ -25,7 +25,7 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict ...@@ -25,7 +25,7 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict
class ExpertLocationDispatchInfo: class ExpertLocationDispatchInfo:
ep_dispatch_algorithm: Literal["static", "random"] ep_dispatch_algorithm: Literal["static", "random"]
# (num_logical_experts,) # (num_logical_experts,)
partial_logical_to_rank_dispatch_physical_map: torch.Tensor partial_logical_to_rank_dispatch_physical_map: Optional[torch.Tensor]
# (num_logical_experts, X) # (num_logical_experts, X)
partial_logical_to_all_physical_map: torch.Tensor partial_logical_to_all_physical_map: torch.Tensor
# (num_logical_experts,) # (num_logical_experts,)
...@@ -42,9 +42,14 @@ class ExpertLocationDispatchInfo: ...@@ -42,9 +42,14 @@ class ExpertLocationDispatchInfo:
return cls( return cls(
ep_dispatch_algorithm=ep_dispatch_algorithm, ep_dispatch_algorithm=ep_dispatch_algorithm,
partial_logical_to_rank_dispatch_physical_map=expert_location_metadata.logical_to_rank_dispatch_physical_map[ partial_logical_to_rank_dispatch_physical_map=(
layer_id, : expert_location_metadata.logical_to_rank_dispatch_physical_map[
], layer_id, :
]
if expert_location_metadata.logical_to_rank_dispatch_physical_map
is not None
else None
),
partial_logical_to_all_physical_map=expert_location_metadata.logical_to_all_physical_map[ partial_logical_to_all_physical_map=expert_location_metadata.logical_to_all_physical_map[
layer_id, : layer_id, :
], ],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment