Unverified commit 87068b5c authored by fzyzcjy, committed by GitHub

Support gathering expert distribution details (#6665)

parent a564e001
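
In short, this diff adds a detail-level single-pass gatherer that snapshots per-layer top-k expert ids together with batch metadata, and a detail accumulator that dumps the collected records to one .pt file per rank; it also enables the previously commented-out "per_pass" and "per_token" recorder modes and their tests. As a rough illustration of how such a dump could be inspected afterwards (only the file-name pattern and dictionary keys come from the diff below; the loading script itself is an assumption, not part of this commit):

# Illustrative sketch only -- not part of this commit.
# Load the per-rank dumps written by _DetailAccumulator.dump() and print a
# short summary of each recorded forward pass.
import glob

import torch

for path in sorted(glob.glob("expert_distribution_recorder_*.pt")):
    dump = torch.load(path)
    for record in dump["records"]:
        topk = record["topk_ids_of_layer"]  # (num_layers, num_tokens, top_k); -1 marks unused slots
        print(
            f"pass={record['forward_pass_id']} rank={record['rank']} "
            f"layers={topk.shape[0]} tokens={topk.shape[1]}"
        )
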
......@@ -18,7 +18,7 @@ from abc import ABC
from collections import deque
from contextlib import contextmanager
from pathlib import Path
from typing import Dict, List, Literal, Optional, Tuple, Type
from typing import Any, Dict, List, Literal, Optional, Tuple, Type
import einops
import torch
......@@ -293,6 +293,79 @@ class _SinglePassGatherer(ABC):
raise NotImplementedError
class _DetailSinglePassGatherer(_SinglePassGatherer):
# DeepSeek V3 has this value; should generalize later
_TOP_K_NUM = 8
def __init__(
self,
server_args: ServerArgs,
expert_location_metadata: "ExpertLocationMetadata",
rank: int,
):
super().__init__(expert_location_metadata, rank)
self._metadata: Optional[Dict[str, Any]] = None
self._topk_ids_of_layer = torch.zeros(
(
expert_location_metadata.num_layers,
# TODO determine the max number
server_args.chunked_prefill_size * 8,
self._TOP_K_NUM,
),
dtype=torch.int32,
device=server_args.device,
)
self._misc_objects: List[Dict[str, Any]] = []
assert (
not server_args.enable_two_batch_overlap
), "DetailSinglePassGatherer does not support TBO yet"
# TODO assert shared experts fusion is disabled, otherwise the recorded data is wrong
def on_forward_pass_start(self, forward_batch: ForwardBatch):
assert self._metadata is None
self._metadata = dict(
# TODO pr-chain
# rids=forward_batch.rids,
input_ids=forward_batch.input_ids.cpu().tolist(),
positions=forward_batch.positions.cpu().tolist(),
extend_seq_lens=forward_batch.extend_seq_lens_cpu,
forward_mode=forward_batch.forward_mode.value,
)
def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor):
self._topk_ids_of_layer[layer_idx, : topk_ids.shape[0], :] = topk_ids
def on_deepep_dispatch_normal(
self,
layer_idx: int,
local_physical_count_of_layer: List[int],
num_tokens_per_rank,
num_tokens_per_rdma_rank,
num_tokens_per_expert,
):
self._misc_objects.append(
dict(
layer_id=layer_idx,
num_tokens_per_rank=num_tokens_per_rank.cpu().tolist(),
num_tokens_per_rdma_rank=num_tokens_per_rdma_rank.cpu().tolist(),
num_tokens_per_expert=num_tokens_per_expert.cpu().tolist(),
)
)
def reset(self):
self._topk_ids_of_layer[...] = -1
self._misc_objects.clear()
self._metadata = None
def collect(self) -> Dict:
num_tokens = len(self._metadata["input_ids"])
return dict(
**self._metadata,
topk_ids_of_layer=self._topk_ids_of_layer[:, :num_tokens, :].clone().cpu(),
misc_objects=self._misc_objects,
)
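# Illustrative driver sketch -- not part of the diff. The hooks above are meant to
# be called once per forward pass; `gatherer` and `per_layer_topk_ids` are
# placeholder names for this illustration, not identifiers from the codebase.
def _record_one_pass(gatherer, forward_batch, per_layer_topk_ids):
    gatherer.on_forward_pass_start(forward_batch)  # snapshot batch metadata
    for layer_idx, topk_ids in enumerate(per_layer_topk_ids):
        gatherer.on_select_experts(layer_idx, topk_ids)  # (num_tokens, top_k) expert ids
    data = gatherer.collect()  # metadata plus topk_ids_of_layer sliced to num_tokens
    gatherer.reset()  # clear buffers before the next pass
    return data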
class _LayerBasedSinglePassGatherer(_SinglePassGatherer):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
......@@ -438,9 +511,8 @@ class _Accumulator(ABC):
def get_class(server_args: ServerArgs) -> Type["_Accumulator"]:
return {
"stat": _StatAccumulator,
# TODO pr-chain: enable this later
# "per_pass": _DetailAccumulator,
# "per_token": _DetailAccumulator,
"per_pass": _DetailAccumulator,
"per_token": _DetailAccumulator,
}[server_args.expert_distribution_recorder_mode]
def __init__(
......@@ -547,6 +619,63 @@ class _DequeCollection:
return {d.maxlen: sum(d) / len(d) for d in self._dequeues}
class _DetailAccumulator(_UtilizationRateAccumulatorMixin):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._records = []
def get_single_pass_gatherer_keys(self):
if False: # TODO `server_args.enable_two_batch_overlap`
return [_SINGLE_PASS_GATHERER_KEY_PRIMARY, "child_a", "child_b"]
return super().get_single_pass_gatherer_keys()
def get_single_pass_gatherer_key(self, debug_name: Optional[str]):
if False: # TODO `server_args.enable_two_batch_overlap`
return debug_name or _SINGLE_PASS_GATHERER_KEY_PRIMARY
return super().get_single_pass_gatherer_key(debug_name)
def append(
self,
forward_pass_id: int,
gatherer_key: str,
single_pass_data: Dict,
):
super().append(forward_pass_id, gatherer_key, single_pass_data)
def _process_object(obj):
if isinstance(obj, torch.Tensor):
return obj.cpu().clone()
return obj
single_pass_data_processed = {
k: _process_object(v) for k, v in single_pass_data.items()
}
self._records.append(
dict(
forward_pass_id=forward_pass_id,
rank=self._rank,
gatherer_key=gatherer_key,
**single_pass_data_processed,
)
)
def reset(self):
super().reset()
self._records.clear()
def dump(self, output_mode: _OutputMode):
assert output_mode == "file"
output = dict(
records=self._records,
# NOTE: This may change during recording, so here we say it is the "last" one
last_physical_to_logical_map=self._expert_location_metadata.physical_to_logical_map,
)
_dump_to_file(
f"expert_distribution_recorder_{time.time()}_{self._rank}.pt", output
)
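# Illustrative aggregation sketch -- not part of the diff: one way a consumer of the
# dumped records above could reduce a single record to per-expert hit counts. The
# helper name and the aggregation itself are assumptions for illustration.
import torch

def count_expert_hits(record: dict) -> dict:
    topk = record["topk_ids_of_layer"]  # (num_layers, num_tokens, top_k); -1 = unused slot
    hit_ids = topk[topk >= 0]  # keep only real routings
    ids, counts = torch.unique(hit_ids, return_counts=True)
    return dict(zip(ids.tolist(), counts.tolist()))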
class _StatAccumulator(_UtilizationRateAccumulatorMixin):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
......
......@@ -23,9 +23,8 @@ class TestExpertDistribution(CustomTestCase):
dict(model_path="deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct"),
dict(model_path="Qwen/Qwen1.5-MoE-A2.7B"),
dict(model_path="Qwen/Qwen1.5-MoE-A2.7B", tp_size=2),
# TODO enable in next PR
# dict(model_path="Qwen/Qwen1.5-MoE-A2.7B", mode="per_pass"),
# dict(model_path="Qwen/Qwen1.5-MoE-A2.7B", mode="per_token"),
dict(model_path="Qwen/Qwen1.5-MoE-A2.7B", mode="per_pass"),
dict(model_path="Qwen/Qwen1.5-MoE-A2.7B", mode="per_token"),
]:
with self.subTest(info=info):
self._execute_core(**info)
......