Unverified Commit 4be02a37 authored by WeiQing Chen's avatar WeiQing Chen Committed by GitHub
Browse files

[Bugfix] EPLB load statistics problem (#22167)


Signed-off-by: default avatarycyaw66 <497410282@qq.com>
Signed-off-by: default avatarDavid Chen <530634352@qq.com>
Co-authored-by: default avatarycyaw66 <497410282@qq.com>
parent f6278b62
...@@ -32,7 +32,7 @@ from dataclasses import dataclass ...@@ -32,7 +32,7 @@ from dataclasses import dataclass
from typing import Optional, Union from typing import Optional, Union
import torch import torch
from torch.distributed import ProcessGroup, all_gather, all_reduce from torch.distributed import ProcessGroup, all_reduce
from vllm.config import ParallelConfig from vllm.config import ParallelConfig
from vllm.distributed.parallel_state import (get_ep_group, get_node_count, from vllm.distributed.parallel_state import (get_ep_group, get_node_count,
...@@ -112,13 +112,21 @@ class EplbState: ...@@ -112,13 +112,21 @@ class EplbState:
Expert load during this forward pass. Expert load during this forward pass.
We use the token count each expert processes as the load. We use the token count each expert processes as the load.
Shape: (num_moe_layers, num_local_physical_experts) Shape: (num_moe_layers, num_physical_experts)
""" """
expert_load_window: torch.Tensor expert_load_window: torch.Tensor
""" """
A sliding window of expert load. A sliding window of expert load.
Shape: (window_size, num_moe_layers, num_local_physical_experts) Shape: (window_size, num_moe_layers, num_physical_experts)
NOTE: The expert_load_view now records load for all physical experts
rather than just local experts. This ensures consistent load statistics
across different dispatch methods (naive all-to-all, DeepEP, pplx-kernels).
The recorded load will be multiplied by dp_size when using naive all-to-all
due to each DP rank contributing the same token set to the calculation.
See:
https://github.com/vllm-project/vllm/pull/22167#pullrequestreview-3086143856
""" """
expert_load_window_step: int = 0 expert_load_window_step: int = 0
""" """
...@@ -232,14 +240,14 @@ class EplbState: ...@@ -232,14 +240,14 @@ class EplbState:
).contiguous() ).contiguous()
expert_load_pass = torch.zeros( expert_load_pass = torch.zeros(
(model.num_moe_layers, model.num_local_physical_experts), (model.num_moe_layers, model.num_physical_experts),
dtype=torch.int32, dtype=torch.int32,
device=device, device=device,
) )
expert_load_window_size = parallel_config.eplb_window_size expert_load_window_size = parallel_config.eplb_window_size
expert_load_window = torch.zeros( expert_load_window = torch.zeros(
(expert_load_window_size, model.num_moe_layers, (expert_load_window_size, model.num_moe_layers,
model.num_local_physical_experts), model.num_physical_experts),
dtype=torch.int32, dtype=torch.int32,
device=device, device=device,
) )
...@@ -353,18 +361,18 @@ class EplbState: ...@@ -353,18 +361,18 @@ class EplbState:
self.expert_load_pass.zero_() self.expert_load_pass.zero_()
if log_stats: if log_stats:
# `num_tokens`: (num_moe_layers,) # total_expert_load_pass: (num_moe_layers, num_physical_experts)
num_tokens = self.expert_load_pass.sum(dim=-1) total_expert_load_pass = self.expert_load_pass.clone()
# Collect load metrics from all ranks # Collect load metrics from all ranks
ep_group = get_ep_group().device_group ep_group = get_ep_group().device_group
assert ep_group is not None assert ep_group is not None
num_tokens_list = [ all_reduce(total_expert_load_pass, group=ep_group)
torch.empty_like(num_tokens) for _ in range(ep_group.size())
] # num_tokens_per_rank: (num_moe_layers, num_ranks)
all_gather(num_tokens_list, num_tokens, group=ep_group) num_tokens_per_rank = total_expert_load_pass.reshape(
# Stack to get (num_ranks, num_moe_layers) total_expert_load_pass.shape[0], ep_group.size(),
num_tokens_per_rank = torch.stack(num_tokens_list).float() -1).sum(dim=-1).float()
# Compute balancedness ratio: # Compute balancedness ratio:
# for each layer: # for each layer:
...@@ -426,17 +434,7 @@ class EplbState: ...@@ -426,17 +434,7 @@ class EplbState:
"(profile)" if is_profile else "") "(profile)" if is_profile else "")
if global_expert_load is None: if global_expert_load is None:
# This mapping is only used here, so we do not store it in the state # Map the physical expert load to global logical experts
physical_expert_start = ep_rank * model.num_local_physical_experts
physical_expert_end = (physical_expert_start +
model.num_local_physical_experts)
# (num_moe_layers, num_local_physical_experts)
local_physical_to_logical_map = self.physical_to_logical_map[
:,
physical_expert_start:physical_expert_end,
]
# Map the local physical expert load to global logical experts
logical_expert_load_window = torch.zeros( logical_expert_load_window = torch.zeros(
self.expert_load_window_size, self.expert_load_window_size,
model.num_moe_layers, model.num_moe_layers,
...@@ -446,7 +444,7 @@ class EplbState: ...@@ -446,7 +444,7 @@ class EplbState:
) )
logical_expert_load_window.scatter_add_( logical_expert_load_window.scatter_add_(
dim=-1, dim=-1,
index=local_physical_to_logical_map.unsqueeze(0).expand_as( index=self.physical_to_logical_map.unsqueeze(0).expand_as(
self.expert_load_window).long(), self.expert_load_window).long(),
src=self.expert_load_window, src=self.expert_load_window,
) )
......
...@@ -1430,23 +1430,10 @@ class FusedMoE(torch.nn.Module): ...@@ -1430,23 +1430,10 @@ class FusedMoE(torch.nn.Module):
# to the modular kernel, we can move this logic there # to the modular kernel, we can move this logic there
# to achieve better efficiency. # to achieve better efficiency.
# `expert_load_view`: (num_logical_experts,) # `expert_load_view`: (num_physical_experts,)
# Mask out non-local experts
if expert_map is not None:
topk_ids_local = expert_map[topk_ids]
topk_ids_flatten = topk_ids_local.flatten()
else:
topk_ids_flatten = topk_ids.flatten() topk_ids_flatten = topk_ids.flatten()
# Should be equivalent to:
# ```
# topk_ids_masked = topk_ids_local[topk_ids_local >= 0]
# expert_load_view += topk_ids_masked.bincount(
# minlength=expert_load_view.shape[0])
# ```
# We use `scatter_add_` since `bincount` cannot be compiled
# Performance optimization: # Performance optimization:
# `masked_fill` is significantly faster than `masked_select` # `masked_fill` is significantly faster than `masked_select`
invalid_mask = topk_ids_flatten < 0 invalid_mask = topk_ids_flatten < 0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment