Unverified Commit b6d0ce9f authored by fzyzcjy, committed by GitHub

Minor add metrics to expert location updater (#6816)

parent 0ea330ca
@@ -12,8 +12,10 @@
 # limitations under the License.
 # ==============================================================================
 import logging
-from typing import Dict, List, Tuple
+from collections import defaultdict
+from typing import Dict, List, Optional, Tuple

+import einops
 import torch
 import torch.distributed
 from torch.distributed import P2POp
@@ -22,6 +24,7 @@ from sglang.srt.managers.expert_location import (
     ExpertLocationMetadata,
     get_global_expert_location_metadata,
 )
+from sglang.srt.utils import get_bool_env_var

 logger = logging.getLogger(__name__)
@@ -59,6 +62,8 @@ def _update_expert_weights(
     nnodes: int,
     rank: int,
 ):
+    log_metrics = get_bool_env_var("SGLANG_EXPERT_LOCATION_UPDATER_LOG_METRICS")
+
     temp_buffers = create_temp_buffers(
         next(iter(routed_experts_weights_of_layer.values()))
     )
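The new logging is opt-in via an environment variable. Below is a minimal sketch of the expected gate, assuming `get_bool_env_var` from `sglang.srt.utils` treats values like `"true"`/`"1"` (case-insensitive) as enabled; its exact implementation is not part of this diff, so `get_bool_env_var_sketch` is a hypothetical stand-in:

```python
import os

# Hypothetical stand-in for sglang.srt.utils.get_bool_env_var (assumed
# semantics: case-insensitive "true"/"1" means enabled).
def get_bool_env_var_sketch(name: str, default: str = "false") -> bool:
    return os.getenv(name, default).strip().lower() in ("true", "1")

os.environ["SGLANG_EXPERT_LOCATION_UPDATER_LOG_METRICS"] = "1"
assert get_bool_env_var_sketch("SGLANG_EXPERT_LOCATION_UPDATER_LOG_METRICS")
```

In practice this means launching the server with `SGLANG_EXPERT_LOCATION_UPDATER_LOG_METRICS=1` in its environment.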
@@ -83,6 +88,8 @@
             num_local_physical_experts=num_local_physical_experts,
             num_gpu_per_node=num_gpu_per_node,
             rank=rank,
+            world_size=world_size,
+            log_metrics=log_metrics,
         )
@@ -98,7 +105,9 @@ def update_expert_weights_single_layer(
     num_local_physical_experts: int,
     num_gpu_per_node: int,
     rank: int,
+    world_size: Optional[int] = None,
     debug: bool = False,
+    log_metrics: bool = False,
 ):
     assert all(
         tensor.shape[0] == num_local_physical_experts
@@ -130,6 +139,14 @@ def update_expert_weights_single_layer(
     _execute_p2p_ops(p2p_op_infos)
     _execute_buffer2weight_copies(buffer2weight_copy_infos)

+    if log_metrics:
+        _log_p2p_op_metrics(
+            p2p_op_infos,
+            world_size=world_size,
+            num_gpu_per_node=num_gpu_per_node,
+            self_node_id=self_node_id,
+        )
+
     if debug:
         output_logs.append(f"{p2p_op_infos=}")
         output_logs.append(f"{buffer2weight_copy_infos=}")
@@ -429,3 +446,53 @@ def _deduplicate_ordered(arr: List[int]):
         if len(output) == 0 or item != output[-1]:
             output.append(item)
     return output
+
+
+def _log_p2p_op_metrics(
+    p2p_op_infos: List[Tuple[int, List[P2POp]]],
+    num_gpu_per_node: int,
+    world_size: int,
+    self_node_id: int,
+):
+    text = ""
+
+    all_ops = [op for _, ops in p2p_op_infos for op in ops]
+
+    for direction, ops in _group_by(all_ops, _get_direction_from_op).items():
+        nbytes_of_gpu = [0] * world_size
+        for op in ops:
+            nbytes_of_gpu[op.peer] += op.tensor.nbytes
+        nbytes_of_gpu = torch.tensor(nbytes_of_gpu, dtype=torch.int64)
+
+        nbytes_of_node = einops.reduce(
+            nbytes_of_gpu,
+            "(num_nodes num_gpu_per_node) -> num_nodes",
+            num_gpu_per_node=num_gpu_per_node,
+            reduction="sum",
+        )
+
+        nbytes_curr_node = nbytes_of_node[self_node_id]
+        nbytes_cross_node = torch.sum(nbytes_of_node) - nbytes_curr_node
+
+        text += (
+            f"{direction}_nbytes_of_gpu={nbytes_of_gpu.tolist()} "
+            f"{direction}_nbytes_of_node={nbytes_of_node.tolist()} "
+            f"{direction}_nbytes_curr_node={nbytes_curr_node.item()} "
+            f"{direction}_nbytes_cross_node={nbytes_cross_node.item()} "
+        )
+
+    logger.info(f"[ExpertLocationUpdater] {text}")
+
+
+def _get_direction_from_op(op: P2POp):
+    if op.op == torch.distributed.isend:
+        return "isend"
+    if op.op == torch.distributed.irecv:
+        return "irecv"
+    raise NotImplementedError
+
+
+def _group_by(items, keyfunc):
+    ans = defaultdict(list)
+    for item in items:
+        ans[keyfunc(item)].append(item)
+    return dict(ans)
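For readers unfamiliar with `einops.reduce` over a flat rank axis: the pattern `"(num_nodes num_gpu_per_node) -> num_nodes"` reshapes the per-GPU byte counts into `(num_nodes, num_gpu_per_node)` and sums over the GPUs within each node. A self-contained sketch with made-up numbers (4 GPUs, 2 per node; only the variable names mirror the diff):

```python
import einops
import torch

# Invented per-GPU transferred byte counts for a 4-GPU, 2-GPUs-per-node setup.
nbytes_of_gpu = torch.tensor([100, 200, 300, 400], dtype=torch.int64)

# Same reduction as _log_p2p_op_metrics: GPUs are grouped node-by-node
# (ranks 0-1 -> node 0, ranks 2-3 -> node 1) and summed within each node.
nbytes_of_node = einops.reduce(
    nbytes_of_gpu,
    "(num_nodes num_gpu_per_node) -> num_nodes",
    num_gpu_per_node=2,
    reduction="sum",
)
print(nbytes_of_node.tolist())  # [300, 700]

self_node_id = 0
nbytes_curr_node = nbytes_of_node[self_node_id]                   # 300
nbytes_cross_node = torch.sum(nbytes_of_node) - nbytes_curr_node  # 700
```

This relies on ranks being laid out contiguously per node (`rank // num_gpu_per_node == node_id`), which is the same assumption the diff's reduction pattern makes.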