Unverified Commit 86fe943b authored by fzyzcjy, committed by GitHub

Fix OOM caused by expert distribution dumping (#6967)

parent 9ecb1856
...
@@ -703,6 +703,7 @@ class _StatAccumulator(_UtilizationRateAccumulatorMixin):
             dtype=torch.int32,
             device=self._server_args.device,
         )
+        self._first_dump = True

     def append(
         self,
@@ -727,9 +728,15 @@ class _StatAccumulator(_UtilizationRateAccumulatorMixin):
             num_logical_experts=self._expert_location_metadata.num_logical_experts,
             physical_to_logical_map=self._expert_location_metadata.physical_to_logical_map,
         )
+
+        if self._first_dump:
+            self._first_dump = False
+            torch.cuda.empty_cache()
+
         torch.distributed.all_reduce(
             logical_count_of_buffered_step, op=torch.distributed.ReduceOp.SUM
         )
+
         output = dict(
             rank=self._rank,
             logical_count=logical_count_of_buffered_step,
...
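The fix adds a `_first_dump` flag and calls `torch.cuda.empty_cache()` immediately before the first `all_reduce` of the buffered logical counts, presumably so that blocks cached by the PyTorch allocator are released before the collective allocates its communication workspace on the first dump. Below is a minimal sketch of the same pattern; the class, method, and parameter names are illustrative assumptions, not the SGLang implementation — only the `empty_cache()` / `all_reduce()` calls mirror the actual change.

```python
# Sketch of the empty-cache-before-first-all_reduce pattern used by this fix.
# StatAccumulatorSketch, dump(), num_layers, etc. are hypothetical names.
import torch
import torch.distributed as dist


class StatAccumulatorSketch:
    def __init__(self, num_layers: int, num_logical_experts: int, device: str):
        # Per-layer, per-expert activation counts accumulated between dumps.
        self._buffer = torch.zeros(
            (num_layers, num_logical_experts), dtype=torch.int32, device=device
        )
        # Empty the CUDA allocator cache only once, on the first dump.
        self._first_dump = True

    def dump(self) -> dict:
        logical_count = self._buffer.clone()

        if self._first_dump:
            self._first_dump = False
            if torch.cuda.is_available():
                # Release cached allocator blocks so the upcoming collective
                # has headroom for its buffers instead of hitting OOM.
                torch.cuda.empty_cache()

        if dist.is_available() and dist.is_initialized():
            # Sum the counts across ranks, as in the patched code path.
            dist.all_reduce(logical_count, op=dist.ReduceOp.SUM)

        return dict(logical_count=logical_count)
```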