Unverified commit f6ebba53 authored by fzyzcjy, committed by GitHub

Support both approximate and exact expert distribution collection (#6964)

parent 6716b417
@@ -264,15 +264,23 @@ class _SinglePassGatherer(ABC):
             return _DetailSinglePassGatherer(
                 server_args, expert_location_metadata, rank
             )
 
+        if server_args.expert_distribution_recorder_mode == "stat_approx":
+            if server_args.enable_deepep_moe and (server_args.deepep_mode == "normal"):
+                return _DeepepNormalSinglePassGatherer(expert_location_metadata, rank)
+            else:
+                raise NotImplementedError
+
         if server_args.enable_deepep_moe:
             if server_args.deepep_mode == "normal":
-                return _DeepepNormalSinglePassGatherer(expert_location_metadata, rank)
+                return _SelectExpertsSinglePassGatherer(expert_location_metadata, rank)
             elif server_args.deepep_mode == "low_latency":
                 return _DeepepLowLatencySinglePassGatherer(
                     expert_location_metadata, rank
                 )
             else:
                 raise NotImplementedError
 
         return _SelectExpertsSinglePassGatherer(expert_location_metadata, rank)
 
     def __init__(self, expert_location_metadata: "ExpertLocationMetadata", rank: int):
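For orientation while reading the hunks below: every gatherer selected above implements the same small hook surface. The sketch here is reconstructed only from the method names visible in this diff (it is not the upstream base class) and notes which hook each mode relies on.

# Illustrative sketch, reconstructed from method names visible in this diff only.
from abc import ABC
from typing import Dict

import torch


class SinglePassGathererSketch(ABC):
    """Collects per-expert routing counts for a single forward pass."""

    def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor) -> None:
        ...  # exact path ("stat"): record the raw top-k expert ids chosen by the router

    def on_deepep_dispatch_normal(self, layer_idx: int, *args, **kwargs) -> None:
        ...  # approximate path ("stat_approx"): record counts observed during DeepEP normal dispatch

    def on_deepep_dispatch_low_latency(
        self, layer_idx: int, local_physical_count_of_layer: torch.Tensor
    ) -> None:
        ...  # record per-local-expert counts reported by DeepEP low-latency dispatch

    def reset(self) -> None:
        ...  # clear buffers between passes

    def collect(self) -> Dict:
        ...  # return dict(global_physical_count=...)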
@@ -347,7 +355,9 @@ class _DetailSinglePassGatherer(_SinglePassGatherer):
         )
 
     def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor):
-        self._topk_ids_of_layer[layer_idx, : topk_ids.shape[0], :] = topk_ids
+        self._topk_ids_of_layer[layer_idx, : topk_ids.shape[0], : topk_ids.shape[1]] = (
+            topk_ids
+        )
 
     def on_deepep_dispatch_normal(
         self,
@@ -380,7 +390,7 @@ class _DetailSinglePassGatherer(_SinglePassGatherer):
         )
 
 
-class _LayerBasedSinglePassGatherer(_SinglePassGatherer):
+class _LayerBasedCpuSinglePassGatherer(_SinglePassGatherer):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self._objects_of_layer = {}
@@ -409,29 +419,63 @@ def _list_sum(a: List, b: List) -> List:
     return [x + y for x, y in zip(a, b, strict=True)]
 
 
-class _SelectExpertsSinglePassGatherer(_LayerBasedSinglePassGatherer):
-    # pretty slow, but we will use the DeepEP Gatherer in production
-    def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor):
-        topk_ids_list = topk_ids.to("cpu", non_blocking=True).numpy().tolist()
-        torch.cuda.synchronize()
-
-        global_physical_count = [
-            0
-        ] * self._expert_location_metadata.num_physical_experts
-        for token_record in topk_ids_list:
-            for global_physical_expert_idx in token_record:
-                global_physical_count[global_physical_expert_idx] += 1
-
-        self._on_layer_data(layer_idx, global_physical_count)
+class _LayerBasedGpuSinglePassGatherer(_SinglePassGatherer):
+    def __init__(self, *args, enable_global_physical_experts: bool, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._enable_global_physical_experts = enable_global_physical_experts
+        self._data = torch.zeros(
+            (
+                self._expert_location_metadata.num_layers,
+                (
+                    self._expert_location_metadata.num_physical_experts
+                    if enable_global_physical_experts
+                    else self._expert_location_metadata.num_local_physical_experts
+                ),
+            ),
+            dtype=torch.int,
+            device="cuda",
+        )
+
+    def reset(self):
+        self._data[...] = 0
 
     def collect(self) -> Dict:
-        global_physical_count = super()._collect_objects(
-            pad_len=self._expert_location_metadata.num_physical_experts
-        )
+        if self._enable_global_physical_experts:
+            global_physical_count = self._data
+        else:
+            # Can optimize if bottleneck
+            global_physical_count = _convert_local_to_global_physical_count(
+                self._data,
+                rank=self._rank,
+                num_local_physical_experts=self._expert_location_metadata.num_local_physical_experts,
+                num_physical_experts=self._expert_location_metadata.num_physical_experts,
+            )
+
         return dict(global_physical_count=global_physical_count)
 
 
-class _DeepepNormalSinglePassGatherer(_LayerBasedSinglePassGatherer):
+class _SelectExpertsSinglePassGatherer(_LayerBasedGpuSinglePassGatherer):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs, enable_global_physical_experts=True)
+
+    # can optimize (e.g. fuse / compile)
+    def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor):
+        topk_ids = topk_ids.flatten()
+        mask = topk_ids != -1
+        self._data[layer_idx, :].scatter_add_(
+            dim=0, index=topk_ids.masked_fill(~mask, 0).long(), src=mask.int()
+        )
+
+
+class _DeepepNormalSinglePassGatherer(_LayerBasedCpuSinglePassGatherer):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        if torch.distributed.get_rank() == 0:
+            logger.info(
+                "DeepepNormalSinglePassGatherer gathers approximate statistics. "
+                "If used with small batch size, consider using expert_distribution_recorder_mode=stat."
+            )
+
     def on_deepep_dispatch_normal(
         self,
         layer_idx: int,
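The new GPU path counts experts directly from topk_ids with a masked scatter_add_, avoiding the old CPU round-trip and torch.cuda.synchronize(). Below is a standalone illustration of that trick with toy sizes (not upstream code): padding slots marked -1 are redirected to index 0 but contribute 0, so they never inflate any count.

import torch

num_physical_experts = 8
counts = torch.zeros(num_physical_experts, dtype=torch.int)

# top-k ids for 3 tokens with top_k=2; -1 marks a padded slot that must not be counted
topk_ids = torch.tensor([[0, 3], [3, 5], [7, -1]])

flat = topk_ids.flatten()
mask = flat != -1
counts.scatter_add_(dim=0, index=flat.masked_fill(~mask, 0).long(), src=mask.int())

print(counts.tolist())  # [1, 0, 0, 2, 0, 1, 0, 1] -> expert 3 was routed to twice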
@@ -456,17 +500,9 @@ class _DeepepNormalSinglePassGatherer(_LayerBasedSinglePassGatherer):
         return dict(global_physical_count=global_physical_count)
 
 
-class _DeepepLowLatencySinglePassGatherer(_SinglePassGatherer):
+class _DeepepLowLatencySinglePassGatherer(_LayerBasedGpuSinglePassGatherer):
     def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self._data = torch.zeros(
-            (
-                self._expert_location_metadata.num_layers,
-                self._expert_location_metadata.num_local_physical_experts,
-            ),
-            dtype=torch.int,
-            device="cuda",
-        )
+        super().__init__(*args, **kwargs, enable_global_physical_experts=False)
 
     def on_deepep_dispatch_low_latency(
         self, layer_idx: int, local_physical_count_of_layer: torch.Tensor
@@ -474,19 +510,6 @@
         # Most naive implementation, can optimize later
         self._data[layer_idx, :] += local_physical_count_of_layer
 
-    def reset(self):
-        self._data[...] = 0
-
-    def collect(self) -> Dict:
-        # Can optimize if bottleneck
-        global_physical_count = _convert_local_to_global_physical_count(
-            self._data,
-            rank=self._rank,
-            num_local_physical_experts=self._expert_location_metadata.num_local_physical_experts,
-            num_physical_experts=self._expert_location_metadata.num_physical_experts,
-        )
-        return dict(global_physical_count=global_physical_count)
 
 
 def _convert_local_to_global_physical_count(
     local_physical_count: torch.Tensor,
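reset and collect now live in the shared _LayerBasedGpuSinglePassGatherer above, which calls _convert_local_to_global_physical_count whenever it only tracks local experts. The helper's body is not part of this diff; the following is a hedged sketch of the kind of transformation it performs, under the assumption that each rank owns one contiguous block of physical experts, so its local counts map to an offset of rank * num_local_physical_experts.

import torch


def convert_local_to_global_sketch(
    local_physical_count: torch.Tensor,  # (num_layers, num_local_physical_experts)
    rank: int,
    num_local_physical_experts: int,
    num_physical_experts: int,
) -> torch.Tensor:
    """Place this rank's local counts into a zero-padded global count tensor (sketch only)."""
    num_layers = local_physical_count.shape[0]
    global_physical_count = torch.zeros(
        num_layers,
        num_physical_experts,
        dtype=local_physical_count.dtype,
        device=local_physical_count.device,
    )
    offset = rank * num_local_physical_experts
    global_physical_count[:, offset : offset + num_local_physical_experts] = (
        local_physical_count
    )
    return global_physical_count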
@@ -525,6 +548,7 @@ class _Accumulator(ABC):
     def get_class(server_args: ServerArgs) -> Type["_Accumulator"]:
         return {
             "stat": _StatAccumulator,
+            "stat_approx": _StatAccumulator,
             "per_pass": _DetailAccumulator,
             "per_token": _DetailAccumulator,
         }[server_args.expert_distribution_recorder_mode]
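Both "stat" and "stat_approx" map to _StatAccumulator because the gatherers above emit the same payload shape, dict(global_physical_count=...); only how that count is measured differs. The minimal illustration below sums such payloads across passes (illustrative only; the real _StatAccumulator has more responsibilities than this sketch shows).

from typing import Dict

import torch


class StatAccumulatorSketch:
    def __init__(self, num_layers: int, num_physical_experts: int):
        self._total = torch.zeros(num_layers, num_physical_experts, dtype=torch.long)

    def append(self, single_pass_data: Dict) -> None:
        # sum the per-pass counts emitted by a gatherer's collect()
        self._total += single_pass_data["global_physical_count"].to(torch.long).cpu()

    def dump(self) -> torch.Tensor:
        return self._total.clone()


acc = StatAccumulatorSketch(num_layers=2, num_physical_experts=4)
acc.append({"global_physical_count": torch.tensor([[1, 0, 2, 0], [0, 3, 0, 1]])})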
@@ -460,22 +460,25 @@ class DeepseekV2MoE(nn.Module):
         hidden_states = state.hidden_states_mlp_input
         if router_logits is not None:
-            state.topk_weights_local, state.topk_idx_local = select_experts(
-                hidden_states=hidden_states,
-                router_logits=router_logits,
-                top_k=self.top_k,
-                use_grouped_topk=True,
-                renormalize=self.renormalize,
-                topk_group=self.topk_group,
-                num_expert_group=self.num_expert_group,
-                num_fused_shared_experts=self.num_fused_shared_experts,
-                correction_bias=self.correction_bias,
-                routed_scaling_factor=self.routed_scaling_factor,
-                num_token_non_padded=state.forward_batch.num_token_non_padded,
-                expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
-                    layer_id=self.layer_id,
-                ),
-            )
+            with get_global_expert_distribution_recorder().with_current_layer(
+                self.layer_id
+            ):
+                state.topk_weights_local, state.topk_idx_local = select_experts(
+                    hidden_states=hidden_states,
+                    router_logits=router_logits,
+                    top_k=self.top_k,
+                    use_grouped_topk=True,
+                    renormalize=self.renormalize,
+                    topk_group=self.topk_group,
+                    num_expert_group=self.num_expert_group,
+                    num_fused_shared_experts=self.num_fused_shared_experts,
+                    correction_bias=self.correction_bias,
+                    routed_scaling_factor=self.routed_scaling_factor,
+                    num_token_non_padded=state.forward_batch.num_token_non_padded,
+                    expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
+                        layer_id=self.layer_id,
+                    ),
+                )
         else:
             state.topk_idx_local = torch.full(
                 (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
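The wrapper matters for the exact mode: _SelectExpertsSinglePassGatherer records inside select_experts via on_select_experts, so the recorder has to know which layer the call belongs to, and with_current_layer(self.layer_id) supplies that. A minimal sketch of the pattern follows (assumed shape for illustration, not the actual recorder implementation).

from contextlib import contextmanager


class RecorderSketch:
    def __init__(self):
        self._current_layer_idx = None

    @contextmanager
    def with_current_layer(self, layer_idx: int):
        # hooks fired inside the block can read the layer id without it being
        # threaded through every call site
        self._current_layer_idx = layer_idx
        try:
            yield
        finally:
            self._current_layer_idx = None

    def on_select_experts(self, topk_ids):
        print(f"layer {self._current_layer_idx}: recorded {len(topk_ids)} routed ids")


recorder = RecorderSketch()
with recorder.with_current_layer(3):
    recorder.on_select_experts(topk_ids=[0, 5, 7])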
@@ -255,17 +255,20 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
         router_logits = state.pop("router_logits")
         hidden_states = state.hidden_states_mlp_input
         if router_logits is not None:
-            state.topk_weights_local, state.topk_idx_local = select_experts(
-                hidden_states=hidden_states,
-                router_logits=router_logits,
-                top_k=self.top_k,
-                use_grouped_topk=False,
-                renormalize=self.renormalize,
-                num_token_non_padded=state.forward_batch.num_token_non_padded,
-                expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
-                    layer_id=self.layer_id,
-                ),
-            )
+            with get_global_expert_distribution_recorder().with_current_layer(
+                self.layer_id
+            ):
+                state.topk_weights_local, state.topk_idx_local = select_experts(
+                    hidden_states=hidden_states,
+                    router_logits=router_logits,
+                    top_k=self.top_k,
+                    use_grouped_topk=False,
+                    renormalize=self.renormalize,
+                    num_token_non_padded=state.forward_batch.num_token_non_padded,
+                    expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
+                        layer_id=self.layer_id,
+                    ),
+                )
         else:
             state.topk_idx_local = torch.full(
                 (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
@@ -182,7 +182,7 @@ class ServerArgs:
     eplb_rebalance_num_iterations: int = 1000
     eplb_rebalance_layers_per_chunk: Optional[int] = None
     expert_distribution_recorder_mode: Optional[
-        Literal["stat", "per_pass", "per_token"]
+        Literal["stat", "stat_approx", "per_pass", "per_token"]
     ] = None
     expert_distribution_recorder_buffer_size: Optional[int] = None
     enable_expert_distribution_metrics: bool = False
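Extending the Literal is what makes the new mode a valid configuration value. A hedged, self-contained illustration of reading the accepted values back from such an annotation follows; the alias name is invented for the example, and the CLI flag spelling (presumably --expert-distribution-recorder-mode under the usual field-to-flag mapping) is not shown in this diff.

from typing import Literal, Optional, get_args

# Re-declared here for illustration; mirrors the annotation in the hunk above.
ExpertDistributionRecorderMode = Optional[
    Literal["stat", "stat_approx", "per_pass", "per_token"]
]

literal_type, _none_type = get_args(ExpertDistributionRecorderMode)
allowed_modes = set(get_args(literal_type)) | {None}
assert "stat_approx" in allowed_modes  # the value added by this commit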