"vscode:/vscode.git/clone" did not exist on "c18f957d0e078f799da5e44e4ac4251cb16b72d4"
Unverified commit 2c3b71d6 authored by fzyzcjy, committed by GitHub

Improve EPLB logical to physical dispatch map (#6727)
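This commit replaces the per-rank random replica choice with a deterministic, GPU-aware construction: for each (layer, logical expert) pair, every GPU first takes a physical replica that already lives on that GPU, and the GPUs left without a local replica are filled by `_fair_choices`, which spreads the remaining assignments as evenly as possible across the candidate replicas using a single shared seed. A minimal sketch of the fairness helper with an illustrative call (the replica ids and counts below are made up for illustration, not taken from a real deployment):

```python
import random
from typing import List


def _fair_choices(arr: List, k: int, r: random.Random) -> List:
    # Repeat the candidate list as many whole times as fits into k,
    # then draw the remainder without replacement and shuffle the result.
    quotient, remainder = divmod(k, len(arr))
    ans = arr * quotient + r.sample(arr, k=remainder)
    r.shuffle(ans)
    return ans


# Example: a logical expert has replicas at physical ids 3 and 7,
# and 5 GPUs still need an assignment after the same-GPU pass.
r = random.Random(42)
print(_fair_choices([3, 7], k=5, r=r))  # 5 entries; counts of 3 and 7 differ by at most one
```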

parent 51cdd81f
@@ -13,6 +13,7 @@
 # ==============================================================================
 import json
 import logging
+import random
 from dataclasses import dataclass
 from pathlib import Path
 from typing import List, Optional
@@ -205,10 +206,10 @@ class ExpertLocationMetadata:
             logical_to_all_physical_map_num_valid=logical_to_all_physical_map_num_valid,
             logical_to_rank_dispatch_physical_map=compute_logical_to_rank_dispatch_physical_map(
                 logical_to_all_physical_map=logical_to_all_physical_map,
-                logical_to_all_physical_map_num_valid=logical_to_all_physical_map_num_valid,
                 num_gpus=ep_size,
                 num_physical_experts=num_physical_experts,
-                ep_rank=torch.distributed.get_rank(),
+                # TODO improve when we have real EP rank
+                ep_rank=torch.distributed.get_rank() % ep_size,
             ),
         )
@@ -296,49 +297,82 @@ def _pad_nested_array(arr, pad_value):
     return padded


-# TODO use more sophisticated approaches
+# TODO optimize performance (rewrite and/or run in separate process with overlap)
 def compute_logical_to_rank_dispatch_physical_map(
     logical_to_all_physical_map: torch.Tensor,
-    logical_to_all_physical_map_num_valid: torch.Tensor,
     num_gpus: int,
     num_physical_experts: int,
     ep_rank: int,
-    base_seed: int = 42,
+    seed: int = 42,
 ):
-    device = logical_to_all_physical_map.device
+    r = random.Random(seed)
+
     num_local_physical_experts = num_physical_experts // num_gpus
     num_layers, num_logical_experts, _ = logical_to_all_physical_map.shape
+    dtype = logical_to_all_physical_map.dtype
 
-    g = torch.Generator(device=device)
-    g.manual_seed(base_seed + ep_rank)
-
-    output_shape = (num_layers, num_logical_experts)
-    chosen_index = (
-        torch.randint(
-            0, 65536, output_shape, dtype=torch.int32, device=device, generator=g
-        )
-        % logical_to_all_physical_map_num_valid
-    )
-    logical_to_rank_dispatch_physical_map = torch.gather(
-        logical_to_all_physical_map, dim=2, index=chosen_index.unsqueeze(-1)
-    ).squeeze(-1)
-    assert logical_to_rank_dispatch_physical_map.shape == output_shape
+    logical_to_rank_dispatch_physical_map = torch.full(
+        size=(num_gpus, num_layers, num_logical_experts),
+        fill_value=-1,
+        dtype=dtype,
+    )
 
-    for index in range(logical_to_all_physical_map_num_valid.max().item()):
-        partial_logical_to_all_physical_map = logical_to_all_physical_map[:, :, index]
-        is_valid = partial_logical_to_all_physical_map != -1
-        is_same_gpu = (
-            partial_logical_to_all_physical_map // num_local_physical_experts
-        ) == ep_rank
-        logical_to_rank_dispatch_physical_map = torch.where(
-            is_valid & is_same_gpu,
-            partial_logical_to_all_physical_map,
-            logical_to_rank_dispatch_physical_map,
-        )
+    for layer_id in range(num_layers):
+        for logical_expert_id in range(num_logical_experts):
+            candidate_physical_expert_ids = _logical_to_all_physical_raw(
+                logical_to_all_physical_map, layer_id, logical_expert_id
+            )
+            output_partial = logical_to_rank_dispatch_physical_map[
+                :, layer_id, logical_expert_id
+            ]
+
+            for gpu_id in range(num_gpus):
+                same_gpu_physical_expert_ids = [
+                    physical_expert_id
+                    for physical_expert_id in candidate_physical_expert_ids
+                    if _compute_gpu_id_of_physical_expert(
+                        physical_expert_id, num_local_physical_experts
+                    )
+                    == gpu_id
+                ]
+                if len(same_gpu_physical_expert_ids) > 0:
+                    output_partial[gpu_id] = same_gpu_physical_expert_ids[0]
+
+            num_remain = torch.sum(output_partial == -1).item()
+            output_partial[output_partial == -1] = torch.tensor(
+                _fair_choices(candidate_physical_expert_ids, k=num_remain, r=r),
+                dtype=dtype,
+            )
 
     assert torch.all(logical_to_rank_dispatch_physical_map != -1)
-    return logical_to_rank_dispatch_physical_map
+
+    device = logical_to_all_physical_map.device
+    return logical_to_rank_dispatch_physical_map[ep_rank, :, :].to(device)
+
+
+def _logical_to_all_physical_raw(
+    logical_to_all_physical_map, layer_id: int, logical_expert_id: int
+) -> List[int]:
+    return [
+        physical_expert_id
+        for physical_expert_id in logical_to_all_physical_map[
+            layer_id, logical_expert_id
+        ].tolist()
+        if physical_expert_id != -1
+    ]
+
+
+def _compute_gpu_id_of_physical_expert(
+    physical_expert_id: int, num_local_physical_experts: int
+) -> int:
+    return physical_expert_id // num_local_physical_experts
+
+
+def _fair_choices(arr: List, k: int, r: random.Random) -> List:
+    quotient, remainder = divmod(k, len(arr))
+    ans = arr * quotient + r.sample(arr, k=remainder)
+    r.shuffle(ans)
+    return ans
 
 
 @dataclass
...
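For reference, a hypothetical walk-through of the same-GPU preference step above (the GPU count, experts-per-GPU, and replica ids are illustrative only, not taken from the commit):

```python
# Hypothetical numbers: 4 GPUs, 2 physical experts per GPU, and one logical
# expert replicated at physical ids 1 and 4.
num_local_physical_experts = 2
candidate_physical_expert_ids = [1, 4]

dispatch = {}
for gpu_id in range(4):
    # A physical expert's GPU is its id integer-divided by the experts-per-GPU count.
    local = [
        p
        for p in candidate_physical_expert_ids
        if p // num_local_physical_experts == gpu_id
    ]
    if local:
        dispatch[gpu_id] = local[0]

print(dispatch)  # {0: 1, 2: 4}; GPUs 1 and 3 are filled later by _fair_choices
```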