"docs/vscode:/vscode.git/clone" did not exist on "1bc0d37ffe0eba638f230bdae2015f1ee4696e5f"
Unverified commit ec92b0ce, authored by Yingchun Lai, committed by GitHub

EPLB: prefer to use physical experts on the same GPU or node (#10874)

parent e03b6bee
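This commit teaches the EPLB expert-location logic to prefer a physical expert replica that already lives on the local GPU, or failing that the local node, instead of choosing among all replicas. Because physical experts are laid out contiguously across EP ranks, the locality test is plain integer division. A minimal standalone sketch of that arithmetic (the concrete sizes below are invented for illustration, not taken from the commit):

# Mirrors _compute_gpu_id_of_physical_expert / _compute_node_id_of_physical_expert
# in the diff below; all numbers here are assumptions for the example.
num_physical_experts = 64          # assumed: replicated physical experts per layer
ep_size = 8                        # assumed: 8 EP ranks (GPUs)
nnodes = 2                         # assumed: 2 nodes, 4 GPUs each

num_local_gpu_physical_experts = num_physical_experts // ep_size   # 8 per GPU
num_gpus_per_node = ep_size // nnodes                              # 4 GPUs per node
num_local_node_physical_experts = (
    num_local_gpu_physical_experts * num_gpus_per_node             # 32 per node
)

def gpu_of(physical_expert_id: int) -> int:
    return physical_expert_id // num_local_gpu_physical_experts

def node_of(physical_expert_id: int) -> int:
    return physical_expert_id // num_local_node_physical_experts

assert gpu_of(19) == 2    # physical expert 19 lives on GPU 2 ...
assert node_of(19) == 0   # ... which sits on node 0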
@@ -85,7 +85,9 @@ class ExpertLocationMetadata:
     # -------------------------------- construction ------------------------------------

     @staticmethod
-    def init_trivial(server_args: ServerArgs, model_config: ModelConfig):
+    def init_trivial(
+        server_args: ServerArgs, model_config: ModelConfig, moe_ep_rank: int
+    ):
         """Trivial location - logical expert i corresponds to physical expert i"""
         common = ExpertLocationMetadata._init_common(server_args, model_config)
@@ -106,6 +108,7 @@ class ExpertLocationMetadata:
             server_args,
             model_config,
             physical_to_logical_map=physical_to_logical_map,
+            moe_ep_rank=moe_ep_rank,
         )

     @staticmethod
@@ -113,6 +116,7 @@ class ExpertLocationMetadata:
         server_args: ServerArgs,
         model_config: ModelConfig,
         physical_to_logical_map,
+        moe_ep_rank: int = None,
     ):
         if not isinstance(physical_to_logical_map, torch.Tensor):
             physical_to_logical_map = torch.tensor(physical_to_logical_map)
@@ -125,8 +129,11 @@ class ExpertLocationMetadata:
         model_config_for_expert_location = common["model_config_for_expert_location"]
         logical_to_all_physical_map = _compute_logical_to_all_physical_map(
-            physical_to_logical_map,
+            server_args=server_args,
+            physical_to_logical_map=physical_to_logical_map,
             num_logical_experts=model_config_for_expert_location.num_logical_experts,
+            ep_size=common["ep_size"],
+            moe_ep_rank=moe_ep_rank,
         )

         return ExpertLocationMetadata._init_raw(
@@ -233,7 +240,7 @@ class ExpertLocationMetadata:
             compute_logical_to_rank_dispatch_physical_map(
                 server_args=server_args,
                 logical_to_all_physical_map=logical_to_all_physical_map,
-                num_gpus=ep_size,
+                ep_size=ep_size,
                 num_physical_experts=num_physical_experts,
                 # TODO improve when we have real EP rank
                 ep_rank=torch.distributed.get_rank() % ep_size,
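The TODO above keeps the pre-existing approximation of the EP rank. An illustrative check of why the modulo works when global ranks are laid out in repeating groups of ep_size (the sizes here are invented):

# With 16 global ranks and ep_size = 8, ranks 0..7 and 8..15 fold onto the
# same EP ranks 0..7; this is exactly what get_rank() % ep_size computes.
ep_size, world_size = 8, 16
assert [g % ep_size for g in range(world_size)] == [*range(8), *range(8)]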
@@ -303,7 +310,11 @@ def set_global_expert_location_metadata(value):
 def _compute_logical_to_all_physical_map(
-    physical_to_logical_map: torch.Tensor, num_logical_experts: int
+    server_args: ServerArgs,
+    physical_to_logical_map: torch.Tensor,
+    num_logical_experts: int,
+    ep_size: int,
+    moe_ep_rank: int,
 ):
     # This is rarely called, so we use for loops for maximum clarity
@@ -312,6 +323,8 @@ def _compute_logical_to_all_physical_map(
     logical_to_all_physical_map = [
         [[] for _ in range(num_logical_experts)] for _ in range(num_layers)
     ]
+
+    # Find out the candidate physical experts for each logical expert on each layer
     for layer_id in range(num_layers):
         for physical_expert_id in range(num_physical_experts):
             logical_expert_id = physical_to_logical_map[
@@ -321,6 +334,32 @@ def _compute_logical_to_all_physical_map(
                 physical_expert_id
             )

+    # Replace by the physical expert on local GPU or node if possible
+    if moe_ep_rank is not None:
+        num_gpus_per_node = server_args.ep_size // server_args.nnodes
+        num_local_gpu_physical_experts = num_physical_experts // ep_size
+        num_local_node_physical_experts = (
+            num_local_gpu_physical_experts * num_gpus_per_node
+        )
+        for layer_id in range(num_layers):
+            for logical_expert_id in range(num_logical_experts):
+                # Try to find the nearest physical expert
+                nearest_expert = _find_nearest_expert(
+                    candidate_physical_expert_ids=logical_to_all_physical_map[layer_id][
+                        logical_expert_id
+                    ],
+                    num_local_gpu_physical_experts=num_local_gpu_physical_experts,
+                    moe_ep_rank=moe_ep_rank,
+                    num_gpus_per_node=num_gpus_per_node,
+                    num_local_node_physical_experts=num_local_node_physical_experts,
+                )
+                # Replace by the nearest physical expert
+                if nearest_expert != -1:
+                    logical_to_all_physical_map[layer_id][logical_expert_id] = [
+                        nearest_expert
+                    ]
+
     logical_to_all_physical_map = _pad_nested_array(
         logical_to_all_physical_map, pad_value=-1
     )
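To see the candidate collection and the new replacement step end to end, here is a tiny pure-Python walk-through. The shapes are invented for the example: one layer, four physical slots spread over two GPUs (two slots each), three logical experts.

# [num_layers][num_physical_experts]; logical expert 0 is replicated
# on physical slots 0 and 2.
physical_to_logical_map = [[0, 1, 0, 2]]
num_layers, num_logical_experts = 1, 3

candidates = [[[] for _ in range(num_logical_experts)] for _ in range(num_layers)]
for layer_id in range(num_layers):
    for physical_expert_id, logical_expert_id in enumerate(
        physical_to_logical_map[layer_id]
    ):
        candidates[layer_id][logical_expert_id].append(physical_expert_id)

assert candidates[0][0] == [0, 2]   # two replicas of logical expert 0

With two slots per GPU, slot 0 is on GPU 0 and slot 2 on GPU 1, so for moe_ep_rank=1 the replacement step would shrink the candidate list for logical expert 0 to [2]; logical experts 1 and 2 have a single replica each and are returned directly.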
@@ -343,21 +382,21 @@ def _pad_nested_array(arr, pad_value):
 def compute_logical_to_rank_dispatch_physical_map(
     server_args: ServerArgs,
     logical_to_all_physical_map: torch.Tensor,
-    num_gpus: int,
+    ep_size: int,
     num_physical_experts: int,
     ep_rank: int,
     seed: int = 42,
 ):
     r = random.Random(seed)

-    num_local_gpu_physical_experts = num_physical_experts // num_gpus
+    num_local_gpu_physical_experts = num_physical_experts // ep_size
     num_gpus_per_node = server_args.ep_size // server_args.nnodes
     num_local_node_physical_experts = num_local_gpu_physical_experts * num_gpus_per_node
     num_layers, num_logical_experts, _ = logical_to_all_physical_map.shape
     dtype = logical_to_all_physical_map.dtype

     logical_to_rank_dispatch_physical_map = torch.full(
-        size=(num_gpus, num_layers, num_logical_experts),
+        size=(ep_size, num_layers, num_logical_experts),
         fill_value=-1,
         dtype=dtype,
     )
@@ -371,33 +410,17 @@ def compute_logical_to_rank_dispatch_physical_map(
                 :, layer_id, logical_expert_id
             ]

-            for gpu_id in range(num_gpus):
-                same_gpu_physical_expert_ids = [
-                    physical_expert_id
-                    for physical_expert_id in candidate_physical_expert_ids
-                    if _compute_gpu_id_of_physical_expert(
-                        physical_expert_id, num_local_gpu_physical_experts
-                    )
-                    == gpu_id
-                ]
-                if len(same_gpu_physical_expert_ids) > 0:
-                    # 1. Prefer same-GPU experts
-                    output_partial[gpu_id] = same_gpu_physical_expert_ids[0]
-                else:
-                    # 2. Otherwise, prefer same-node experts
-                    node_id = gpu_id // num_gpus_per_node
-                    same_node_physical_expert_ids = [
-                        physical_expert_id
-                        for physical_expert_id in candidate_physical_expert_ids
-                        if _compute_node_id_of_physical_expert(
-                            physical_expert_id, num_local_node_physical_experts
-                        )
-                        == node_id
-                    ]
-                    if len(same_node_physical_expert_ids) > 0:
-                        output_partial[gpu_id] = same_node_physical_expert_ids[0]
+            for moe_ep_rank in range(ep_size):
+                # Fill with the nearest physical expert
+                output_partial[moe_ep_rank] = _find_nearest_expert(
+                    candidate_physical_expert_ids=candidate_physical_expert_ids,
+                    num_local_gpu_physical_experts=num_local_gpu_physical_experts,
+                    moe_ep_rank=moe_ep_rank,
+                    num_gpus_per_node=num_gpus_per_node,
+                    num_local_node_physical_experts=num_local_node_physical_experts,
+                )

-            # 3. Fill remaining slots with fair random choices
+            # Fill remaining slots with fair random choices
             num_remain = torch.sum(output_partial == -1).item()
             output_partial[output_partial == -1] = torch.tensor(
                 _fair_choices(candidate_physical_expert_ids, k=num_remain, r=r),
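The loop body above now delegates the same-GPU/same-node preference to _find_nearest_expert (added further down) and keeps only the fair random fill for ranks with no nearby replica. A pure-Python sketch of the per-rank fill under invented sizes (two GPUs with two physical experts each, candidate replicas [0, 2]):

candidates = [0, 2]
num_local_gpu_physical_experts = 2   # assumed: 2 physical experts per GPU

def find_nearest(rank):
    # Same-GPU tier only; the same-node fallback is omitted for brevity.
    same_gpu = [p for p in candidates
                if p // num_local_gpu_physical_experts == rank]
    return same_gpu[0] if same_gpu else -1

dispatch = [find_nearest(rank) for rank in range(2)]
assert dispatch == [0, 2]   # each EP rank dispatches to its local replica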
@@ -434,6 +457,46 @@ def _compute_node_id_of_physical_expert(
     return physical_expert_id // num_local_host_physical_experts


+def _find_nearest_expert(
+    candidate_physical_expert_ids: List[int],
+    num_local_gpu_physical_experts: int,
+    moe_ep_rank: int,
+    num_gpus_per_node: int,
+    num_local_node_physical_experts: int,
+) -> int:
+    # 1. If only one candidate, return it directly
+    if len(candidate_physical_expert_ids) == 1:
+        return candidate_physical_expert_ids[0]
+
+    # 2. Prefer same-GPU experts
+    same_gpu_physical_expert_ids = [
+        physical_expert_id
+        for physical_expert_id in candidate_physical_expert_ids
+        if _compute_gpu_id_of_physical_expert(
+            physical_expert_id, num_local_gpu_physical_experts
+        )
+        == moe_ep_rank
+    ]
+    if len(same_gpu_physical_expert_ids) > 0:
+        return same_gpu_physical_expert_ids[0]
+
+    # 3. Otherwise, prefer same-node experts
+    node_rank = moe_ep_rank // num_gpus_per_node
+    same_node_physical_expert_ids = [
+        physical_expert_id
+        for physical_expert_id in candidate_physical_expert_ids
+        if _compute_node_id_of_physical_expert(
+            physical_expert_id, num_local_node_physical_experts
+        )
+        == node_rank
+    ]
+    if len(same_node_physical_expert_ids) > 0:
+        return same_node_physical_expert_ids[0]
+
+    # 4. At last, leave it as -1 to indicate not found.
+    return -1


 def _fair_choices(arr: List, k: int, r: random.Random) -> List:
     quotient, remainder = divmod(k, len(arr))
     ans = arr * quotient + r.sample(arr, k=remainder)
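A hypothetical call exercising all three tiers of the new helper (numbers invented: 2 nodes × 4 GPUs, 8 physical experts per GPU, hence 32 per node):

nearest = _find_nearest_expert(
    candidate_physical_expert_ids=[5, 40],
    num_local_gpu_physical_experts=8,
    moe_ep_rank=6,                    # GPU 6 sits on node 1 (6 // 4)
    num_gpus_per_node=4,
    num_local_node_physical_experts=32,
)
# Neither candidate is on GPU 6 (5 // 8 == 0, 40 // 8 == 5), but 40 is on
# node 1 (40 // 32 == 1), so the same-node tier wins.
assert nearest == 40

If neither tier matches, the caller sees -1 and _fair_choices (context above) spreads the remaining slots as evenly as possible over the candidates, randomizing only the remainder of k divided by len(arr).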
@@ -459,11 +522,15 @@ class ModelConfigForExpertLocation:
 def compute_initial_expert_location_metadata(
-    server_args: ServerArgs, model_config: ModelConfig
+    server_args: ServerArgs,
+    model_config: ModelConfig,
+    moe_ep_rank: int,
 ) -> Optional[ExpertLocationMetadata]:
     data = server_args.init_expert_location
     if data == "trivial":
-        return ExpertLocationMetadata.init_trivial(server_args, model_config)
+        return ExpertLocationMetadata.init_trivial(
+            server_args, model_config, moe_ep_rank
+        )

     # TODO unify with the utils function
     if data.endswith(".pt"):
@@ -478,7 +545,10 @@ def compute_initial_expert_location_metadata(
             "init_expert_location from init_by_mapping using ServerArgs.init_expert_location"
         )
         return ExpertLocationMetadata.init_by_mapping(
-            server_args, model_config, **data_dict
+            server_args,
+            model_config,
+            **data_dict,
+            moe_ep_rank=moe_ep_rank,
         )
     elif "logical_count" in data_dict:
         logger.info(
...
@@ -348,7 +348,11 @@ class ModelRunner:
         if not self.is_draft_worker:
             set_global_expert_location_metadata(
-                compute_initial_expert_location_metadata(server_args, self.model_config)
+                compute_initial_expert_location_metadata(
+                    server_args=server_args,
+                    model_config=self.model_config,
+                    moe_ep_rank=self.moe_ep_rank,
+                )
             )
             if self.tp_rank == 0 and get_bool_env_var(
                 "SGLANG_LOG_EXPERT_LOCATION_METADATA"
...