Unverified commit 9484eba4, authored by fzyzcjy, committed by GitHub

Support logging expert balancedness metrics (#6482)

parent e9feb488
@@ -15,10 +15,12 @@ import logging
 import os
 import time
 from abc import ABC
+from collections import deque
 from contextlib import contextmanager
 from pathlib import Path
 from typing import Dict, List, Literal, Optional, Tuple, Type

+import einops
 import torch
 import torch.distributed
@@ -472,7 +474,80 @@ class _Accumulator(ABC):
         pass


-class _StatAccumulator(_Accumulator):
+class _UtilizationRateAccumulatorMixin(_Accumulator):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._enable = self._server_args.enable_expert_distribution_metrics
+
+        if self._enable:
+            window_sizes = [10, 100, 1000]
+            self._history = _DequeCollection(maxlens=window_sizes)
+            self._rank = torch.distributed.get_rank()
+
+    def append(
+        self,
+        forward_pass_id: int,
+        gatherer_key: str,
+        single_pass_data: Dict,
+    ):
+        super().append(forward_pass_id, gatherer_key, single_pass_data)
+        if self._enable:
+            self._append_utilization_rate(
+                forward_pass_id, single_pass_data["global_physical_count"]
+            )
+
+    def reset(self):
+        super().reset()
+        if self._enable:
+            self._history.clear()
+
+    def _append_utilization_rate(
+        self, forward_pass_id: int, single_pass_global_physical_count: torch.Tensor
+    ):
+        gpu_physical_count = compute_gpu_physical_count(
+            single_pass_global_physical_count,
+            num_gpu=self._expert_location_metadata.ep_size,
+        )
+        gpu_physical_count = gpu_physical_count.to(self._server_args.device)
+        torch.distributed.reduce(
+            gpu_physical_count, dst=0, op=torch.distributed.ReduceOp.SUM
+        )
+
+        if self._rank == 0:
+            utilization_rate_tensor = compute_utilization_rate(gpu_physical_count)
+            utilization_rate = torch.mean(utilization_rate_tensor).item()
+            self._history.append(utilization_rate)
+            gpu_physical_count_sum = gpu_physical_count.sum().item()
+            logger.info(
+                f"[Expert Balancedness] "
+                f"forward_pass_id={forward_pass_id} "
+                f"current_pass_balancedness={utilization_rate:.03f} "
+                f"{''.join(f'last_{size}_average_balancedness={value:.03f} ' for size, value in self._history.mean().items())} "
+                f"gpu_physical_count_sum={gpu_physical_count_sum}"
+                # f"current_pass_per_layer={[round(x, 2) for x in utilization_rate_tensor.cpu().tolist()]}"
+            )
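An editorial aside, not part of the diff: given the format string above, a single rank-0 log line looks like the following (all numbers here are purely illustrative):

[Expert Balancedness] forward_pass_id=128 current_pass_balancedness=0.812 last_10_average_balancedness=0.805 last_100_average_balancedness=0.797 last_1000_average_balancedness=0.790 gpu_physical_count_sum=65536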
+
+
+class _DequeCollection:
+    def __init__(self, maxlens: List[int]):
+        self._dequeues = [deque(maxlen=maxlen) for maxlen in maxlens]
+
+    def append(self, value):
+        for d in self._dequeues:
+            d.append(value)
+
+    def clear(self):
+        for d in self._dequeues:
+            d.clear()
+
+    def mean(self) -> Dict[int, float]:
+        return {d.maxlen: sum(d) / len(d) for d in self._dequeues}
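A quick aside on `_DequeCollection`: one `append` fans out to several fixed-size windows, and `mean()` reports one rolling average per window size. A standalone sketch (the helper is copied out of the diff so the snippet runs on its own):

from collections import deque
from typing import Dict, List

class DequeCollection:
    """Standalone copy of the private helper above."""

    def __init__(self, maxlens: List[int]):
        self._dequeues = [deque(maxlen=maxlen) for maxlen in maxlens]

    def append(self, value):
        for d in self._dequeues:
            d.append(value)

    def mean(self) -> Dict[int, float]:
        return {d.maxlen: sum(d) / len(d) for d in self._dequeues}

dc = DequeCollection(maxlens=[2, 4])
for v in [1.0, 2.0, 3.0, 4.0]:
    dc.append(v)
print(dc.mean())  # {2: 3.5, 4: 2.5} -- the size-2 window only remembers [3.0, 4.0]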
+
+
+class _StatAccumulator(_UtilizationRateAccumulatorMixin):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self._global_physical_count_of_buffered_step = _Buffer.init_new(
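Worth noting about the wiring above: `_StatAccumulator` now derives from `_UtilizationRateAccumulatorMixin` rather than `_Accumulator` directly, so its inherited `append` and `reset` pick up the balancedness tracking through `super()` without any call sites changing; when `--enable-expert-distribution-metrics` is off, the mixin is a pass-through.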
@@ -619,3 +694,34 @@ def _convert_global_physical_count_to_logical_count(
         src=global_physical_count,
     )
     return logical_count
+
+
+def compute_gpu_physical_count(
+    physical_count_of_whatever: torch.Tensor,  # (..., num_layer, num_physical_expert)
+    num_gpu: int,
+):
+    """output: gpu_physical_count_of_batch (..., num_layer, num_gpu)"""
+    return einops.reduce(
+        physical_count_of_whatever,
+        "... num_layer (num_gpu num_expert_per_gpu) -> ... num_layer num_gpu",
+        "sum",
+        num_gpu=num_gpu,
+    )
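An aside to make the einops pattern concrete: the last axis is split as (num_gpu, num_expert_per_gpu), i.e. physical experts are assumed to be laid out contiguously per GPU, and counts are summed within each GPU's slice. A tiny sketch with made-up numbers:

import einops
import torch

# 1 layer, 4 physical experts on 2 GPUs: experts 0-1 on GPU 0, experts 2-3 on GPU 1.
physical_count = torch.tensor([[3.0, 1.0, 0.0, 4.0]])  # (num_layer=1, num_physical_expert=4)
gpu_count = einops.reduce(
    physical_count,
    "... num_layer (num_gpu num_expert_per_gpu) -> ... num_layer num_gpu",
    "sum",
    num_gpu=2,
)
print(gpu_count)  # tensor([[4., 4.]])  -- GPU 0 saw 3+1 tokens, GPU 1 saw 0+4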
+
+
+def compute_utilization_rate(
+    gpu_physical_count_of_batch: torch.Tensor,  # (..., num_layer, num_gpu)
+):
+    """output: utilization_rate (..., num_layer)"""
+    gpu_physical_count_of_batch = gpu_physical_count_of_batch.float()
+    max_gpu_physical_count = einops.reduce(
+        gpu_physical_count_of_batch,
+        "... num_layer num_gpu -> ... num_layer",
+        "max",
+    )
+    avg_gpu_physical_count = einops.reduce(
+        gpu_physical_count_of_batch,
+        "... num_layer num_gpu -> ... num_layer",
+        "mean",
+    )
+    return (avg_gpu_physical_count + 1e-5) / (max_gpu_physical_count + 1e-5)
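In words: balancedness is the per-layer mean over max across GPUs. It equals 1.0 when every GPU processed the same number of expert tokens and falls toward 1/num_gpu when a single GPU takes all the load; the 1e-5 terms keep an all-idle layer from dividing by zero. Two quick checks, assuming `compute_utilization_rate` from the diff above is in scope:

import torch

# Perfectly balanced layer: mean == max -> balancedness ~= 1.0
print(compute_utilization_rate(torch.tensor([[5.0, 5.0, 5.0, 5.0]])))  # tensor([1.0000])

# One hot GPU takes all the load: mean/max = 2.5/10 -> 0.25
print(compute_utilization_rate(torch.tensor([[10.0, 0.0, 0.0, 0.0]])))  # tensor([0.2500])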
@@ -177,6 +177,7 @@ class ServerArgs:
Literal["stat", "per_pass", "per_token"]
] = None
expert_distribution_recorder_buffer_size: Optional[int] = None
enable_expert_distribution_metrics: bool = False
deepep_config: Optional[str] = None
enable_torch_compile: bool = False
torch_compile_max_bs: int = 32
@@ -1304,6 +1305,11 @@ class ServerArgs:
             default=ServerArgs.expert_distribution_recorder_buffer_size,
             help="Circular buffer size of expert distribution recorder. Set to -1 to denote infinite buffer.",
         )
+        parser.add_argument(
+            "--enable-expert-distribution-metrics",
+            action="store_true",
+            help="Enable logging metrics for expert balancedness",
+        )
         parser.add_argument(
             "--deepep-config",
             type=str,
......
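To actually see these logs, expert distribution recording must be active so that the accumulator's `append` is fed. A hypothetical launch line showing where the new flag fits (the model path is a placeholder, and `--expert-distribution-recorder-mode` is inferred from the ServerArgs fields visible above, not taken from this diff):

python -m sglang.launch_server \
    --model-path <moe-model> \
    --expert-distribution-recorder-mode stat \
    --enable-expert-distribution-metrics

The flag itself only switches on the rank-0 balancedness logging; the utilization data it consumes comes from the existing expert distribution recorder.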