Unverified Commit cbbb7383 authored by Stefan He's avatar Stefan He Committed by GitHub
Browse files

[2/3] Optimize Slime Update Weights: Avoid GPU-to-CPU Device Sync when update...

[2/3]  Optimize Slime Update Weights: Avoid GPU-to-CPU Device Sync when update expert weights (#8753)
parent 89588179
...@@ -35,6 +35,7 @@ class ExpertLocationMetadata: ...@@ -35,6 +35,7 @@ class ExpertLocationMetadata:
physical_to_logical_map: torch.Tensor # (layers, num_physical_experts) physical_to_logical_map: torch.Tensor # (layers, num_physical_experts)
physical_to_logical_map_cpu: torch.Tensor physical_to_logical_map_cpu: torch.Tensor
logical_to_all_physical_map: torch.Tensor # (layers, num_logical_experts, X) logical_to_all_physical_map: torch.Tensor # (layers, num_logical_experts, X)
logical_to_all_physical_map_cpu: torch.Tensor # CPU copy for performance
logical_to_all_physical_map_num_valid: torch.Tensor # (layers, num_logical_experts) logical_to_all_physical_map_num_valid: torch.Tensor # (layers, num_logical_experts)
# (layers, num_logical_experts) # (layers, num_logical_experts)
logical_to_rank_dispatch_physical_map: Optional[torch.Tensor] logical_to_rank_dispatch_physical_map: Optional[torch.Tensor]
...@@ -221,6 +222,7 @@ class ExpertLocationMetadata: ...@@ -221,6 +222,7 @@ class ExpertLocationMetadata:
physical_to_logical_map=physical_to_logical_map, physical_to_logical_map=physical_to_logical_map,
physical_to_logical_map_cpu=physical_to_logical_map.cpu(), physical_to_logical_map_cpu=physical_to_logical_map.cpu(),
logical_to_all_physical_map=logical_to_all_physical_map_padded, logical_to_all_physical_map=logical_to_all_physical_map_padded,
logical_to_all_physical_map_cpu=logical_to_all_physical_map_padded.cpu(),
logical_to_all_physical_map_num_valid=logical_to_all_physical_map_num_valid, logical_to_all_physical_map_num_valid=logical_to_all_physical_map_num_valid,
logical_to_rank_dispatch_physical_map=( logical_to_rank_dispatch_physical_map=(
compute_logical_to_rank_dispatch_physical_map( compute_logical_to_rank_dispatch_physical_map(
...@@ -251,6 +253,7 @@ class ExpertLocationMetadata: ...@@ -251,6 +253,7 @@ class ExpertLocationMetadata:
"physical_to_logical_map", "physical_to_logical_map",
"physical_to_logical_map_cpu", "physical_to_logical_map_cpu",
"logical_to_all_physical_map", "logical_to_all_physical_map",
"logical_to_all_physical_map_cpu",
"logical_to_all_physical_map_num_valid", "logical_to_all_physical_map_num_valid",
"logical_to_rank_dispatch_physical_map", "logical_to_rank_dispatch_physical_map",
]: ]:
...@@ -270,9 +273,10 @@ class ExpertLocationMetadata: ...@@ -270,9 +273,10 @@ class ExpertLocationMetadata:
def logical_to_all_physical( def logical_to_all_physical(
self, layer_id: int, logical_expert_id: int self, layer_id: int, logical_expert_id: int
) -> List[int]: ) -> List[int]:
# Use CPU copy to avoid GPU→CPU sync on every call, which is expensive in update weights scenario
return [ return [
physical_expert_id physical_expert_id
for physical_expert_id in self.logical_to_all_physical_map[ for physical_expert_id in self.logical_to_all_physical_map_cpu[
layer_id, logical_expert_id layer_id, logical_expert_id
].tolist() ].tolist()
if physical_expert_id != -1 if physical_expert_id != -1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment