Unverified commit f6ebba53, authored by fzyzcjy, committed by GitHub

Support both approximate and exact expert distribution collection (#6964)

parent 6716b417
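Rough summary of the change (inferred from the diff below): "stat" mode now always records per-expert token counts directly from the router's topk_ids (padded entries marked -1 are masked out), while the new "stat_approx" mode reuses the cheaper statistics already produced by DeepEP normal-mode dispatch and is therefore only available with enable_deepep_moe and deepep_mode="normal". An illustration of selecting the new mode programmatically; the recorder-related field names are taken from the ServerArgs diff below, but model_path and the surrounding launch flow are placeholders, not part of this commit:

# Illustration only: field names from the ServerArgs diff; model_path is a placeholder.
from sglang.srt.server_args import ServerArgs

args = ServerArgs(
    model_path="/path/to/moe-model",                   # placeholder
    enable_deepep_moe=True,                            # stat_approx piggybacks on DeepEP "normal" dispatch
    deepep_mode="normal",
    expert_distribution_recorder_mode="stat_approx",   # or "stat" for exact counts
)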
@@ -264,15 +264,23 @@ class _SinglePassGatherer(ABC):
             return _DetailSinglePassGatherer(
                 server_args, expert_location_metadata, rank
             )
+
+        if server_args.expert_distribution_recorder_mode == "stat_approx":
+            if server_args.enable_deepep_moe and (server_args.deepep_mode == "normal"):
+                return _DeepepNormalSinglePassGatherer(expert_location_metadata, rank)
+            else:
+                raise NotImplementedError
+
         if server_args.enable_deepep_moe:
             if server_args.deepep_mode == "normal":
-                return _DeepepNormalSinglePassGatherer(expert_location_metadata, rank)
+                return _SelectExpertsSinglePassGatherer(expert_location_metadata, rank)
             elif server_args.deepep_mode == "low_latency":
                 return _DeepepLowLatencySinglePassGatherer(
                     expert_location_metadata, rank
                 )
             else:
                 raise NotImplementedError
+
         return _SelectExpertsSinglePassGatherer(expert_location_metadata, rank)

     def __init__(self, expert_location_metadata: "ExpertLocationMetadata", rank: int):
@@ -347,7 +355,9 @@ class _DetailSinglePassGatherer(_SinglePassGatherer):
         )

     def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor):
-        self._topk_ids_of_layer[layer_idx, : topk_ids.shape[0], :] = topk_ids
+        self._topk_ids_of_layer[layer_idx, : topk_ids.shape[0], : topk_ids.shape[1]] = (
+            topk_ids
+        )

     def on_deepep_dispatch_normal(
         self,
@@ -380,7 +390,7 @@ class _DetailSinglePassGatherer(_SinglePassGatherer):
         )


-class _LayerBasedSinglePassGatherer(_SinglePassGatherer):
+class _LayerBasedCpuSinglePassGatherer(_SinglePassGatherer):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self._objects_of_layer = {}
@@ -409,29 +419,63 @@ def _list_sum(a: List, b: List) -> List:
     return [x + y for x, y in zip(a, b, strict=True)]


-class _SelectExpertsSinglePassGatherer(_LayerBasedSinglePassGatherer):
-    # pretty slow, but we will use the DeepEP Gatherer in production
-    def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor):
-        topk_ids_list = topk_ids.to("cpu", non_blocking=True).numpy().tolist()
-        torch.cuda.synchronize()
-
-        global_physical_count = [
-            0
-        ] * self._expert_location_metadata.num_physical_experts
-        for token_record in topk_ids_list:
-            for global_physical_expert_idx in token_record:
-                global_physical_count[global_physical_expert_idx] += 1
-
-        self._on_layer_data(layer_idx, global_physical_count)
-
-    def collect(self) -> Dict:
-        global_physical_count = super()._collect_objects(
-            pad_len=self._expert_location_metadata.num_physical_experts
-        )
+class _LayerBasedGpuSinglePassGatherer(_SinglePassGatherer):
+    def __init__(self, *args, enable_global_physical_experts: bool, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._enable_global_physical_experts = enable_global_physical_experts
+        self._data = torch.zeros(
+            (
+                self._expert_location_metadata.num_layers,
+                (
+                    self._expert_location_metadata.num_physical_experts
+                    if enable_global_physical_experts
+                    else self._expert_location_metadata.num_local_physical_experts
+                ),
+            ),
+            dtype=torch.int,
+            device="cuda",
+        )
+
+    def reset(self):
+        self._data[...] = 0
+
+    def collect(self) -> Dict:
+        if self._enable_global_physical_experts:
+            global_physical_count = self._data
+        else:
+            # Can optimize if bottleneck
+            global_physical_count = _convert_local_to_global_physical_count(
+                self._data,
+                rank=self._rank,
+                num_local_physical_experts=self._expert_location_metadata.num_local_physical_experts,
+                num_physical_experts=self._expert_location_metadata.num_physical_experts,
+            )
         return dict(global_physical_count=global_physical_count)


-class _DeepepNormalSinglePassGatherer(_LayerBasedSinglePassGatherer):
+class _SelectExpertsSinglePassGatherer(_LayerBasedGpuSinglePassGatherer):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs, enable_global_physical_experts=True)
+
+    # can optimize (e.g. fuse / compile)
+    def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor):
+        topk_ids = topk_ids.flatten()
+        mask = topk_ids != -1
+        self._data[layer_idx, :].scatter_add_(
+            dim=0, index=topk_ids.masked_fill(~mask, 0).long(), src=mask.int()
+        )
+
+
+class _DeepepNormalSinglePassGatherer(_LayerBasedCpuSinglePassGatherer):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        if torch.distributed.get_rank() == 0:
+            logger.info(
+                "DeepepNormalSinglePassGatherer gathers approximate statistics. "
+                "If used with small batch size, consider using expert_distribution_recorder_mode=stat."
+            )
+
     def on_deepep_dispatch_normal(
         self,
         layer_idx: int,
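The new _SelectExpertsSinglePassGatherer counts expert hits on the GPU with a masked scatter_add_ instead of the old CPU loop: entries equal to -1 (padded tokens) are redirected to index 0 but contribute 0, so they never affect the counts. A small self-contained sketch of the same pattern on CPU, with made-up sizes:

# Toy reproduction of the masked scatter_add_ counting above (CPU, made-up sizes).
import torch

num_physical_experts = 8
data = torch.zeros(num_physical_experts, dtype=torch.int)

topk_ids = torch.tensor([[0, 3], [3, 5], [-1, -1]])  # -1 marks padded tokens
flat = topk_ids.flatten()
mask = flat != -1
data.scatter_add_(dim=0, index=flat.masked_fill(~mask, 0).long(), src=mask.int())
print(data.tolist())  # [1, 0, 0, 2, 0, 1, 0, 0] -- padded entries add nothing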
@@ -456,17 +500,9 @@ class _DeepepNormalSinglePassGatherer(_LayerBasedSinglePassGatherer):
         return dict(global_physical_count=global_physical_count)


-class _DeepepLowLatencySinglePassGatherer(_SinglePassGatherer):
+class _DeepepLowLatencySinglePassGatherer(_LayerBasedGpuSinglePassGatherer):
     def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self._data = torch.zeros(
-            (
-                self._expert_location_metadata.num_layers,
-                self._expert_location_metadata.num_local_physical_experts,
-            ),
-            dtype=torch.int,
-            device="cuda",
-        )
+        super().__init__(*args, **kwargs, enable_global_physical_experts=False)

     def on_deepep_dispatch_low_latency(
         self, layer_idx: int, local_physical_count_of_layer: torch.Tensor
@@ -474,19 +510,6 @@ class _DeepepLowLatencySinglePassGatherer(_SinglePassGatherer):
         # Most naive implementation, can optimize later
         self._data[layer_idx, :] += local_physical_count_of_layer

-    def reset(self):
-        self._data[...] = 0
-
-    def collect(self) -> Dict:
-        # Can optimize if bottleneck
-        global_physical_count = _convert_local_to_global_physical_count(
-            self._data,
-            rank=self._rank,
-            num_local_physical_experts=self._expert_location_metadata.num_local_physical_experts,
-            num_physical_experts=self._expert_location_metadata.num_physical_experts,
-        )
-        return dict(global_physical_count=global_physical_count)
-

 def _convert_local_to_global_physical_count(
     local_physical_count: torch.Tensor,
@@ -525,6 +548,7 @@ class _Accumulator(ABC):
     def get_class(server_args: ServerArgs) -> Type["_Accumulator"]:
         return {
             "stat": _StatAccumulator,
+            "stat_approx": _StatAccumulator,
             "per_pass": _DetailAccumulator,
             "per_token": _DetailAccumulator,
         }[server_args.expert_distribution_recorder_mode]
...
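For context, _convert_local_to_global_physical_count (now called from the shared _LayerBasedGpuSinglePassGatherer.collect rather than from the low-latency gatherer) maps a rank's local expert counts onto the global physical-expert axis. Its body is not shown in this diff; the sketch below is one plausible way to do it and uses a hypothetical name, not the actual implementation:

import torch

def convert_local_to_global_sketch(
    local_physical_count: torch.Tensor,  # (num_layers, num_local_physical_experts)
    rank: int,
    num_local_physical_experts: int,
    num_physical_experts: int,
) -> torch.Tensor:
    # Zero-pad to the full expert axis and drop this rank's counts at its offset.
    num_layers = local_physical_count.shape[0]
    out = torch.zeros(
        num_layers,
        num_physical_experts,
        dtype=local_physical_count.dtype,
        device=local_physical_count.device,
    )
    offset = rank * num_local_physical_experts
    out[:, offset : offset + num_local_physical_experts] = local_physical_count
    return out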
@@ -460,22 +460,25 @@ class DeepseekV2MoE(nn.Module):
         hidden_states = state.hidden_states_mlp_input
         if router_logits is not None:
-            state.topk_weights_local, state.topk_idx_local = select_experts(
-                hidden_states=hidden_states,
-                router_logits=router_logits,
-                top_k=self.top_k,
-                use_grouped_topk=True,
-                renormalize=self.renormalize,
-                topk_group=self.topk_group,
-                num_expert_group=self.num_expert_group,
-                num_fused_shared_experts=self.num_fused_shared_experts,
-                correction_bias=self.correction_bias,
-                routed_scaling_factor=self.routed_scaling_factor,
-                num_token_non_padded=state.forward_batch.num_token_non_padded,
-                expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
-                    layer_id=self.layer_id,
-                ),
-            )
+            with get_global_expert_distribution_recorder().with_current_layer(
+                self.layer_id
+            ):
+                state.topk_weights_local, state.topk_idx_local = select_experts(
+                    hidden_states=hidden_states,
+                    router_logits=router_logits,
+                    top_k=self.top_k,
+                    use_grouped_topk=True,
+                    renormalize=self.renormalize,
+                    topk_group=self.topk_group,
+                    num_expert_group=self.num_expert_group,
+                    num_fused_shared_experts=self.num_fused_shared_experts,
+                    correction_bias=self.correction_bias,
+                    routed_scaling_factor=self.routed_scaling_factor,
+                    num_token_non_padded=state.forward_batch.num_token_non_padded,
+                    expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
+                        layer_id=self.layer_id,
+                    ),
+                )
         else:
             state.topk_idx_local = torch.full(
                 (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
...
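In the model code, select_experts is now wrapped in with_current_layer so the recorder's on_select_experts hook knows which layer the routed topk_ids belong to. get_global_expert_distribution_recorder comes from the expert-distribution module touched above; the sketch below only illustrates the shape such a layer-scoping context manager can take and is not the actual implementation:

from contextlib import contextmanager

class _RecorderSketch:
    """Minimal stand-in showing the shape of a layer-scoping context manager."""

    def __init__(self):
        self._current_layer_idx = None

    @contextmanager
    def with_current_layer(self, layer_idx):
        # Hooks fired inside the block (e.g. on_select_experts) can read the layer id.
        self._current_layer_idx = layer_idx
        try:
            yield
        finally:
            self._current_layer_idx = None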
@@ -255,17 +255,20 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
         router_logits = state.pop("router_logits")
         hidden_states = state.hidden_states_mlp_input
         if router_logits is not None:
-            state.topk_weights_local, state.topk_idx_local = select_experts(
-                hidden_states=hidden_states,
-                router_logits=router_logits,
-                top_k=self.top_k,
-                use_grouped_topk=False,
-                renormalize=self.renormalize,
-                num_token_non_padded=state.forward_batch.num_token_non_padded,
-                expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
-                    layer_id=self.layer_id,
-                ),
-            )
+            with get_global_expert_distribution_recorder().with_current_layer(
+                self.layer_id
+            ):
+                state.topk_weights_local, state.topk_idx_local = select_experts(
+                    hidden_states=hidden_states,
+                    router_logits=router_logits,
+                    top_k=self.top_k,
+                    use_grouped_topk=False,
+                    renormalize=self.renormalize,
+                    num_token_non_padded=state.forward_batch.num_token_non_padded,
+                    expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
+                        layer_id=self.layer_id,
+                    ),
+                )
         else:
             state.topk_idx_local = torch.full(
                 (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
...
@@ -182,7 +182,7 @@ class ServerArgs:
     eplb_rebalance_num_iterations: int = 1000
     eplb_rebalance_layers_per_chunk: Optional[int] = None
     expert_distribution_recorder_mode: Optional[
-        Literal["stat", "per_pass", "per_token"]
+        Literal["stat", "stat_approx", "per_pass", "per_token"]
     ] = None
     expert_distribution_recorder_buffer_size: Optional[int] = None
    enable_expert_distribution_metrics: bool = False
...
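Regardless of which gatherer is active, collect() returns dict(global_physical_count=...), where the tensor is indexed by (layer, global physical expert), and both "stat" and "stat_approx" then feed the same _StatAccumulator. A toy illustration of consuming that structure; the single-pass output here is faked, only the dict shape and key are taken from the diff:

import torch

# Fake single-pass result with 2 layers and 4 global physical experts.
single_pass_output = dict(
    global_physical_count=torch.tensor([[3, 0, 1, 2], [1, 1, 1, 3]])
)

# A stat-style consumer could accumulate these counts across passes and,
# for example, normalize per layer to inspect expert-load balance.
counts = single_pass_output["global_physical_count"].float()
load_share = counts / counts.sum(dim=-1, keepdim=True)
print(load_share)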