Unverified Commit ec92b0ce authored by Yingchun Lai's avatar Yingchun Lai Committed by GitHub
Browse files

EPLB: prefer to use physical experts in the same gpu or node (#10874)

parent e03b6bee
......@@ -85,7 +85,9 @@ class ExpertLocationMetadata:
# -------------------------------- construction ------------------------------------
@staticmethod
def init_trivial(server_args: ServerArgs, model_config: ModelConfig):
def init_trivial(
server_args: ServerArgs, model_config: ModelConfig, moe_ep_rank: int
):
"""Trivial location - logical expert i corresponds to physical expert i"""
common = ExpertLocationMetadata._init_common(server_args, model_config)
......@@ -106,6 +108,7 @@ class ExpertLocationMetadata:
server_args,
model_config,
physical_to_logical_map=physical_to_logical_map,
moe_ep_rank=moe_ep_rank,
)
@staticmethod
......@@ -113,6 +116,7 @@ class ExpertLocationMetadata:
server_args: ServerArgs,
model_config: ModelConfig,
physical_to_logical_map,
moe_ep_rank: int = None,
):
if not isinstance(physical_to_logical_map, torch.Tensor):
physical_to_logical_map = torch.tensor(physical_to_logical_map)
......@@ -125,8 +129,11 @@ class ExpertLocationMetadata:
model_config_for_expert_location = common["model_config_for_expert_location"]
logical_to_all_physical_map = _compute_logical_to_all_physical_map(
physical_to_logical_map,
server_args=server_args,
physical_to_logical_map=physical_to_logical_map,
num_logical_experts=model_config_for_expert_location.num_logical_experts,
ep_size=common["ep_size"],
moe_ep_rank=moe_ep_rank,
)
return ExpertLocationMetadata._init_raw(
......@@ -233,7 +240,7 @@ class ExpertLocationMetadata:
compute_logical_to_rank_dispatch_physical_map(
server_args=server_args,
logical_to_all_physical_map=logical_to_all_physical_map,
num_gpus=ep_size,
ep_size=ep_size,
num_physical_experts=num_physical_experts,
# TODO improve when we have real EP rank
ep_rank=torch.distributed.get_rank() % ep_size,
......@@ -303,7 +310,11 @@ def set_global_expert_location_metadata(value):
def _compute_logical_to_all_physical_map(
physical_to_logical_map: torch.Tensor, num_logical_experts: int
server_args: ServerArgs,
physical_to_logical_map: torch.Tensor,
num_logical_experts: int,
ep_size: int,
moe_ep_rank: int,
):
# This is rarely called, so we use for loops for maximum clarity
......@@ -312,6 +323,8 @@ def _compute_logical_to_all_physical_map(
logical_to_all_physical_map = [
[[] for _ in range(num_logical_experts)] for _ in range(num_layers)
]
# Find out the candidate physical experts for each logical expert on each layer
for layer_id in range(num_layers):
for physical_expert_id in range(num_physical_experts):
logical_expert_id = physical_to_logical_map[
......@@ -321,6 +334,32 @@ def _compute_logical_to_all_physical_map(
physical_expert_id
)
# Replace by the physical expert on local GPU or node if possible
if moe_ep_rank is not None:
num_gpus_per_node = server_args.ep_size // server_args.nnodes
num_local_gpu_physical_experts = num_physical_experts // ep_size
num_local_node_physical_experts = (
num_local_gpu_physical_experts * num_gpus_per_node
)
for layer_id in range(num_layers):
for logical_expert_id in range(num_logical_experts):
# Try to find the nearest physical expert
nearest_expert = _find_nearest_expert(
candidate_physical_expert_ids=logical_to_all_physical_map[layer_id][
logical_expert_id
],
num_local_gpu_physical_experts=num_local_gpu_physical_experts,
moe_ep_rank=moe_ep_rank,
num_gpus_per_node=num_gpus_per_node,
num_local_node_physical_experts=num_local_node_physical_experts,
)
# Replace by the nearest physical expert
if nearest_expert != -1:
logical_to_all_physical_map[layer_id][logical_expert_id] = [
nearest_expert
]
logical_to_all_physical_map = _pad_nested_array(
logical_to_all_physical_map, pad_value=-1
)
......@@ -343,21 +382,21 @@ def _pad_nested_array(arr, pad_value):
def compute_logical_to_rank_dispatch_physical_map(
server_args: ServerArgs,
logical_to_all_physical_map: torch.Tensor,
num_gpus: int,
ep_size: int,
num_physical_experts: int,
ep_rank: int,
seed: int = 42,
):
r = random.Random(seed)
num_local_gpu_physical_experts = num_physical_experts // num_gpus
num_local_gpu_physical_experts = num_physical_experts // ep_size
num_gpus_per_node = server_args.ep_size // server_args.nnodes
num_local_node_physical_experts = num_local_gpu_physical_experts * num_gpus_per_node
num_layers, num_logical_experts, _ = logical_to_all_physical_map.shape
dtype = logical_to_all_physical_map.dtype
logical_to_rank_dispatch_physical_map = torch.full(
size=(num_gpus, num_layers, num_logical_experts),
size=(ep_size, num_layers, num_logical_experts),
fill_value=-1,
dtype=dtype,
)
......@@ -371,33 +410,17 @@ def compute_logical_to_rank_dispatch_physical_map(
:, layer_id, logical_expert_id
]
for gpu_id in range(num_gpus):
same_gpu_physical_expert_ids = [
physical_expert_id
for physical_expert_id in candidate_physical_expert_ids
if _compute_gpu_id_of_physical_expert(
physical_expert_id, num_local_gpu_physical_experts
)
== gpu_id
]
if len(same_gpu_physical_expert_ids) > 0:
# 1. Prefer same-GPU experts
output_partial[gpu_id] = same_gpu_physical_expert_ids[0]
else:
# 2. Otherwise, prefer same-node experts
node_id = gpu_id // num_gpus_per_node
same_node_physical_expert_ids = [
physical_expert_id
for physical_expert_id in candidate_physical_expert_ids
if _compute_node_id_of_physical_expert(
physical_expert_id, num_local_node_physical_experts
for moe_ep_rank in range(ep_size):
# Fill with the nearest physical expert
output_partial[moe_ep_rank] = _find_nearest_expert(
candidate_physical_expert_ids=candidate_physical_expert_ids,
num_local_gpu_physical_experts=num_local_gpu_physical_experts,
moe_ep_rank=moe_ep_rank,
num_gpus_per_node=num_gpus_per_node,
num_local_node_physical_experts=num_local_node_physical_experts,
)
== node_id
]
if len(same_node_physical_expert_ids) > 0:
output_partial[gpu_id] = same_node_physical_expert_ids[0]
# 3. Fill remaining slots with fair random choices
# Fill remaining slots with fair random choices
num_remain = torch.sum(output_partial == -1).item()
output_partial[output_partial == -1] = torch.tensor(
_fair_choices(candidate_physical_expert_ids, k=num_remain, r=r),
......@@ -434,6 +457,46 @@ def _compute_node_id_of_physical_expert(
return physical_expert_id // num_local_host_physical_experts
def _find_nearest_expert(
    candidate_physical_expert_ids: List[int],
    num_local_gpu_physical_experts: int,
    moe_ep_rank: int,
    num_gpus_per_node: int,
    num_local_node_physical_experts: int,
) -> int:
    """Pick the candidate physical expert nearest to this EP rank.

    Preference order: a lone candidate wins outright; otherwise the first
    candidate hosted on the same GPU as ``moe_ep_rank``; otherwise the first
    candidate hosted on the same node. Returns -1 when no same-GPU or
    same-node candidate exists, so the caller can fall back to another
    selection strategy.
    """
    candidates = candidate_physical_expert_ids

    # A single candidate needs no locality comparison.
    if len(candidates) == 1:
        return candidates[0]

    # First pass: look for a candidate living on this exact GPU.
    for physical_expert_id in candidates:
        gpu_id = _compute_gpu_id_of_physical_expert(
            physical_expert_id, num_local_gpu_physical_experts
        )
        if gpu_id == moe_ep_rank:
            return physical_expert_id

    # Second pass: settle for a candidate on the same node.
    node_rank = moe_ep_rank // num_gpus_per_node
    for physical_expert_id in candidates:
        node_id = _compute_node_id_of_physical_expert(
            physical_expert_id, num_local_node_physical_experts
        )
        if node_id == node_rank:
            return physical_expert_id

    # Nothing local was found; -1 signals "no nearby expert".
    return -1
def _fair_choices(arr: List, k: int, r: random.Random) -> List:
quotient, remainder = divmod(k, len(arr))
ans = arr * quotient + r.sample(arr, k=remainder)
......@@ -459,11 +522,15 @@ class ModelConfigForExpertLocation:
def compute_initial_expert_location_metadata(
server_args: ServerArgs, model_config: ModelConfig
server_args: ServerArgs,
model_config: ModelConfig,
moe_ep_rank: int,
) -> Optional[ExpertLocationMetadata]:
data = server_args.init_expert_location
if data == "trivial":
return ExpertLocationMetadata.init_trivial(server_args, model_config)
return ExpertLocationMetadata.init_trivial(
server_args, model_config, moe_ep_rank
)
# TODO unify with the utils function
if data.endswith(".pt"):
......@@ -478,7 +545,10 @@ def compute_initial_expert_location_metadata(
"init_expert_location from init_by_mapping using ServerArgs.init_expert_location"
)
return ExpertLocationMetadata.init_by_mapping(
server_args, model_config, **data_dict
server_args,
model_config,
**data_dict,
moe_ep_rank=moe_ep_rank,
)
elif "logical_count" in data_dict:
logger.info(
......
......@@ -348,7 +348,11 @@ class ModelRunner:
if not self.is_draft_worker:
set_global_expert_location_metadata(
compute_initial_expert_location_metadata(server_args, self.model_config)
compute_initial_expert_location_metadata(
server_args=server_args,
model_config=self.model_config,
moe_ep_rank=self.moe_ep_rank,
)
)
if self.tp_rank == 0 and get_bool_env_var(
"SGLANG_LOG_EXPERT_LOCATION_METADATA"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment