Unverified Commit 68a34911 authored by Ning Xie's avatar Ning Xie Committed by GitHub
Browse files

[Misc] enhance type hint for rearrange return value (#23519)


Signed-off-by: default avatarAndy Xie <andy.xning@gmail.com>
parent e80bca30
...@@ -409,12 +409,14 @@ class EplbState: ...@@ -409,12 +409,14 @@ class EplbState:
self.expert_rearrangement_step = 0 self.expert_rearrangement_step = 0
self.rearrange(model) self.rearrange(model)
def rearrange(self, def rearrange(
model: MixtureOfExperts, self,
is_profile: bool = False, model: MixtureOfExperts,
execute_shuffle: bool = True, is_profile: bool = False,
global_expert_load: Optional[torch.Tensor] = None, execute_shuffle: bool = True,
rank_mapping: Optional[dict[int, int]] = None) -> None: global_expert_load: Optional[torch.Tensor] = None,
rank_mapping: Optional[dict[int,
int]] = None) -> Optional[torch.Tensor]:
""" """
Rearrange the experts according to the current load. Rearrange the experts according to the current load.
""" """
...@@ -548,6 +550,7 @@ class EplbState: ...@@ -548,6 +550,7 @@ class EplbState:
" (profile) " if is_profile else " ", " (profile) " if is_profile else " ",
time_end - time_start, time_end - time_start,
) )
return None
@staticmethod @staticmethod
def recv_state() -> tuple[torch.Tensor, torch.Tensor]: def recv_state() -> tuple[torch.Tensor, torch.Tensor]:
...@@ -613,4 +616,4 @@ def _node_count_with_rank_mapping( ...@@ -613,4 +616,4 @@ def _node_count_with_rank_mapping(
if is_same_node and node_assignment[other_rank] == 0: if is_same_node and node_assignment[other_rank] == 0:
node_assignment[other_rank] = next_node_id node_assignment[other_rank] = next_node_id
return next_node_id return next_node_id
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment