Unverified commit 89824189, authored by vipwangerxiao, committed by GitHub
Browse files

Fix 'KeyError' for per_token expert distribution recorder (#9501)


Signed-off-by: Peng Wang <rocking@linux.alibaba.com>
Co-authored-by: Peng Wang <rocking@linux.alibaba.com>
parent 433c622e
......@@ -415,10 +415,19 @@ class _DetailSinglePassGatherer(_SinglePassGatherer):
def collect(self) -> Dict:
    """Assemble the record for this pass as a plain dict.

    The result contains every entry of ``self._metadata`` plus:

    * ``topk_ids_of_layer``: a CPU copy of ``self._topk_ids_of_layer`` limited
      to the first ``len(self._metadata["input_ids"])`` entries along axis 1.
    * ``misc_objects``: ``self._misc_objects`` as-is (not copied).
    * ``global_physical_count``: the result of
      ``_convert_per_token_to_global_physical_count`` for the same token count.
    """
    token_count = len(self._metadata["input_ids"])
    location_meta = self._expert_location_metadata

    physical_counts = _convert_per_token_to_global_physical_count(
        token_count,
        num_layers=location_meta.num_layers,
        num_physical_experts=location_meta.num_physical_experts,
        _topk_ids_of_layer=self._topk_ids_of_layer,
    )
    # Only the first `token_count` slots along axis 1 are exported; clone
    # before moving to CPU so the result does not alias the recorder's buffer.
    topk_snapshot = self._topk_ids_of_layer[:, :token_count, :].clone().cpu()

    return dict(
        **self._metadata,
        topk_ids_of_layer=topk_snapshot,
        misc_objects=self._misc_objects,
        global_physical_count=physical_counts,
    )
......@@ -547,6 +556,27 @@ class _DeepepLowLatencySinglePassGatherer(_LayerBasedGpuSinglePassGatherer):
self._data[layer_idx, :] += local_physical_count_of_layer
def _convert_per_token_to_global_physical_count(
    num_tokens: int,
    num_layers: int,
    num_physical_experts: int,
    _topk_ids_of_layer: torch.Tensor,
) -> torch.Tensor:
    """Count, per layer, how often each expert id occurs in the top-k ids.

    Args:
        num_tokens: Number of leading entries along axis 1 of
            ``_topk_ids_of_layer`` to include in the count.
        num_layers: Number of rows of the result; the sliced ids are reshaped
            to ``(num_layers, -1)`` before counting.
        num_physical_experts: Number of columns of the result.
        _topk_ids_of_layer: Integer tensor indexed as ``[:, :num_tokens, :]``
            (presumably ``(num_layers, max_tokens, top_k)`` -- confirm against
            the caller). Entries equal to ``-1`` are ignored.

    Returns:
        A ``(num_layers, num_physical_experts)`` tensor with the dtype and
        device of ``_topk_ids_of_layer``; element ``[l, e]`` is the number of
        times id ``e`` occurs in row ``l`` of the reshaped ids.
    """
    topk_ids_layer_major = _topk_ids_of_layer[:, :num_tokens, :].reshape(num_layers, -1)
    mask = topk_ids_layer_major != -1
    # -1 is not a valid scatter index: point those slots at column 0, where
    # they contribute 0 because their `src` value below is 0.
    index = topk_ids_layer_major.masked_fill(~mask, 0).long()
    ans = torch.zeros(
        (num_layers, num_physical_experts),
        dtype=_topk_ids_of_layer.dtype,
        device=_topk_ids_of_layer.device,
    )
    # scatter_add_ requires src and destination dtypes to match, so derive the
    # increment dtype from `ans` instead of hard-coding int32; otherwise any
    # non-int32 id tensor (e.g. int64) raises a RuntimeError.
    src = mask.to(ans.dtype)
    ans.scatter_add_(dim=1, index=index, src=src)
    return ans
def _convert_local_to_global_physical_count(
local_physical_count: torch.Tensor,
rank: int,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment