Unverified Commit 12701e8a authored by Ilya Markov's avatar Ilya Markov Committed by GitHub
Browse files

[EPLB] Optmize eplb mapping and record in router for prefill (#36261)


Signed-off-by: default avatarilmarkov <markovilya197@gmail.com>
parent 494636b2
......@@ -8,6 +8,9 @@ import torch
from vllm._aiter_ops import rocm_aiter_ops
from vllm.distributed.eplb.eplb_state import EplbLayerState
from vllm.model_executor.layers.fused_moe.router.base_router import (
eplb_map_to_physical_and_record,
)
from vllm.model_executor.layers.fused_moe.router.router_factory import (
create_fused_moe_router,
)
......@@ -55,11 +58,13 @@ def setup_eplb_state(enable_eplb: bool, global_num_experts: int) -> EplbLayerSta
logical_replica_count = torch.ones(
global_num_experts, dtype=torch.int64, device="cuda"
)
should_record_tensor = torch.ones((), dtype=torch.bool, device="cuda")
return EplbLayerState(
expert_load_view=expert_load_view,
logical_to_physical_map=logical_to_physical_map,
logical_replica_count=logical_replica_count,
should_record_tensor=should_record_tensor,
)
......@@ -581,3 +586,152 @@ def test_custom(
# hidden_states, router_logits = make_test_data(m, k, global_num_experts)
# topk_weights, topk_ids = router.select_experts(hidden_states, router_logits)
# ---------------------------------------------------------------------------
# Tests for eplb_map_to_physical_and_record
# ---------------------------------------------------------------------------
@pytest.mark.parametrize("record_enabled", [True, False])
@pytest.mark.parametrize(
"l2p_map, replica_count, num_physical, topk_ids, expected_out, expected_load",
[
pytest.param(
# logical i → physical i
[[0], [1], [2], [3]],
[1, 1, 1, 1],
4,
[[0, 1], [2, 3], [0, 2]],
[[0, 1], [2, 3], [0, 2]],
[2, 1, 2, 1],
id="identity",
),
pytest.param(
# logical 0→3, 1→0, 2→1, 3→2
[[3], [0], [1], [2]],
[1, 1, 1, 1],
4,
[[0, 1], [2, 3], [0, 2]],
[[3, 0], [1, 2], [3, 1]],
[1, 2, 1, 2],
id="shuffled",
),
pytest.param(
# logical 0→5, 1→2, 2→7, 3→0 in a larger physical space
[[5], [2], [7], [0]],
[1, 1, 1, 1],
8,
[[0, 1], [2, 3]],
[[5, 2], [7, 0]],
[1, 0, 1, 0, 0, 1, 0, 1],
id="sparse",
),
],
)
def test_eplb_map_no_redundancy(
record_enabled,
l2p_map,
replica_count,
num_physical,
topk_ids,
expected_out,
expected_load,
):
l2p = torch.tensor(l2p_map, dtype=torch.int64, device="cuda")
rc = torch.tensor(replica_count, dtype=torch.int64, device="cuda")
load = torch.zeros(num_physical, dtype=torch.int32, device="cuda")
rec = torch.tensor(record_enabled, dtype=torch.bool, device="cuda")
ids = torch.tensor(topk_ids, dtype=torch.int32, device="cuda")
out = eplb_map_to_physical_and_record(
topk_ids=ids,
expert_load_view=load,
logical_to_physical_map=l2p,
logical_replica_count=rc,
record_enabled=rec,
)
exp_out = torch.tensor(expected_out, dtype=out.dtype, device="cuda")
torch.testing.assert_close(out, exp_out)
if record_enabled:
exp_load = torch.tensor(expected_load, dtype=torch.int32, device="cuda")
torch.testing.assert_close(load, exp_load)
else:
assert load.sum().item() == 0
@pytest.mark.parametrize("record_enabled", [True, False])
@pytest.mark.parametrize(
"l2p_map, replica_count, num_physical, topk_ids, expected_out, expected_load",
[
pytest.param(
# experts 0,1 have 2 replicas; 2,3 have 1
[[0, 4], [1, 5], [2, -1], [3, -1]],
[2, 2, 1, 1],
6,
[[0, 1], [2, 3], [0, 2]],
# offs: 0→0%2=0→p0, 1→1%2=1→p5, 2→2%1=0→p2,
# 3→3%1=0→p3, 4→4%2=0→p0, 5→5%1=0→p2
[[0, 5], [2, 3], [0, 2]],
[2, 0, 2, 1, 0, 1],
id="partial",
),
pytest.param(
# all 4 experts have 2 replicas
[[0, 4], [1, 5], [2, 6], [3, 7]],
[2, 2, 2, 2],
8,
[[0, 1], [2, 3], [0, 2]],
# offs: 0→0%2=0→p0, 1→1%2=1→p5, 2→2%2=0→p2,
# 3→3%2=1→p7, 4→4%2=0→p0, 5→5%2=1→p6
[[0, 5], [2, 7], [0, 6]],
[2, 0, 1, 0, 0, 1, 1, 1],
id="full",
),
pytest.param(
# expert 0: 4 replicas, experts 1,2: 2 replicas
[[0, 3, 5, 7], [1, 4, -1, -1], [2, 6, -1, -1]],
[4, 2, 2],
8,
[[0, 1], [2, 0], [1, 2]],
# offs: 0→0%4=0→p0, 1→1%2=1→p4, 2→2%2=0→p2,
# 3→3%4=3→p7, 4→4%2=0→p1, 5→5%2=1→p6
[[0, 4], [2, 7], [1, 6]],
[1, 1, 1, 0, 1, 0, 1, 1],
id="uneven",
),
],
)
def test_eplb_map_with_redundancy(
record_enabled,
l2p_map,
replica_count,
num_physical,
topk_ids,
expected_out,
expected_load,
):
l2p = torch.tensor(l2p_map, dtype=torch.int64, device="cuda")
rc = torch.tensor(replica_count, dtype=torch.int64, device="cuda")
load = torch.zeros(num_physical, dtype=torch.int32, device="cuda")
rec = torch.tensor(record_enabled, dtype=torch.bool, device="cuda")
ids = torch.tensor(topk_ids, dtype=torch.int32, device="cuda")
out = eplb_map_to_physical_and_record(
topk_ids=ids,
expert_load_view=load,
logical_to_physical_map=l2p,
logical_replica_count=rc,
record_enabled=rec,
)
exp_out = torch.tensor(expected_out, dtype=out.dtype, device="cuda")
torch.testing.assert_close(out, exp_out)
if record_enabled:
exp_load = torch.tensor(expected_load, dtype=torch.int32, device="cuda")
torch.testing.assert_close(load, exp_load)
else:
assert load.sum().item() == 0
......@@ -62,6 +62,7 @@ def test_base_router_capture_with_eplb_enabled():
router.eplb_state.expert_load_view = torch.zeros(32, dtype=torch.int64)
router.eplb_state.logical_to_physical_map = torch.arange(32).view(32, 1)
router.eplb_state.logical_replica_count = torch.ones(32, dtype=torch.int64)
router.eplb_state.should_record_tensor = torch.ones((), dtype=torch.bool)
captured = []
......
......@@ -53,9 +53,9 @@ All2AllBackend = Literal[
class EPLBConfig:
"""Configuration for Expert Parallel Load Balancing (EP)."""
window_size: int = 1000
window_size: int = Field(default=1000, gt=0)
"""Window size for expert load recording."""
step_interval: int = 3000
step_interval: int = Field(default=3000, gt=0)
"""
Interval for rearranging experts in expert parallelism.
......@@ -71,7 +71,7 @@ class EPLBConfig:
Log the balancedness each step of expert parallelism.
This is turned off by default since it will cause communication overhead.
"""
log_balancedness_interval: int = 1
log_balancedness_interval: int = Field(default=1, gt=0)
"""
Interval for logging the balancedness.
"""
......
......@@ -399,6 +399,7 @@ class ElasticEPScalingExecutor:
eplb_model_state.logical_to_physical_map,
eplb_model_state.logical_replica_count,
)
eplb_state._init_should_record_tensor(model)
model.update_physical_experts_metadata(
num_physical_experts=num_physical_experts,
num_local_physical_experts=num_local_experts,
......
......@@ -272,6 +272,13 @@ class EplbState:
Interval for expert rearrangement steps.
This is a constant and is taken from the config.
"""
self.should_record_tensor: torch.Tensor | None = None
"""
Shared scalar bool tensor for all layers. Every
:class:`EplbLayerState` holds a reference to the **same** object so
a single ``.fill_()`` updates all layers at once. Allocated on the
first call to :meth:`_init_should_record_tensor`.
"""
self.is_async: bool = False
"""
The flag indicates whether the EPLB is running in async mode.
......@@ -462,7 +469,7 @@ class EplbState:
logical_to_physical_map,
logical_replica_count,
)
self._init_should_record_tensor(model)
expert_buffer = [torch.empty_like(w) for w in model.expert_weights[0]]
model_state = EplbModelState(
......@@ -582,12 +589,15 @@ class EplbState:
# Update the expert load sliding window
if not is_dummy:
should_record = self._should_record_current_step(log_stats=log_stats)
for eplb_model_state in self.model_states.values():
eplb_model_state.expert_load_window[self.expert_load_window_step] = (
eplb_model_state.expert_load_pass.clone()
)
if should_record:
eplb_model_state.expert_load_window[
self.expert_load_window_step
].copy_(eplb_model_state.expert_load_pass)
eplb_model_state.expert_load_pass.zero_()
if should_record:
self.expert_load_window_step += 1
if self.expert_load_window_step >= self.expert_load_window_size:
self.expert_load_window_step = 0
......@@ -617,11 +627,66 @@ class EplbState:
eplb_model_state.rebalanced
for eplb_model_state in self.model_states.values()
):
# Still performing asynchronous rearrangement
# Still performing asynchronous rearrangement; update
# should_record (step > step_interval, so always True) and
# bail out before the step counter is reset.
self._update_layer_should_record(log_stats=log_stats)
return
self.expert_rearrangement_step = 0
self.rearrange()
self._update_layer_should_record(log_stats=log_stats)
def _should_record_current_step(self, log_stats: bool = False) -> bool:
"""Return whether expert-load recording should be enabled this step.
Recording is enabled when we are close to either:
1) The next rearrangement step, so the sliding window is ready.
2) The next balancedness logging step, when log_stats is enabled.
"""
steps_remaining = (
self.expert_rearrangement_step_interval - self.expert_rearrangement_step
)
should_record_for_rearrange = steps_remaining <= self.expert_load_window_size
if not log_stats:
return should_record_for_rearrange
log_interval = self.parallel_config.eplb_config.log_balancedness_interval
steps_until_next_log = (
log_interval - (self.expert_rearrangement_step % log_interval)
) % log_interval
should_record_for_log = steps_until_next_log <= self.expert_load_window_size
return should_record_for_rearrange or should_record_for_log
def _update_layer_should_record(self, log_stats: bool = False) -> None:
"""Update the shared ``should_record_tensor`` for all layers."""
if self.should_record_tensor is not None:
self.should_record_tensor.fill_(
self._should_record_current_step(log_stats=log_stats)
)
def _init_should_record_tensor(self, model: "MixtureOfExperts") -> None: # type: ignore[name-defined]
"""Allocate (once) and propagate the shared ``should_record_tensor``.
Must be called after :meth:`model.set_eplb_state` so that each
layer's ``eplb_state`` is already populated with the tensor views.
"""
layer_states = [
layer.eplb_state
for layer in model.moe_layers
if hasattr(layer, "eplb_state")
and isinstance(layer.eplb_state, EplbLayerState)
]
if self.should_record_tensor is None and layer_states:
self.should_record_tensor = torch.ones(
(), dtype=torch.bool, device=self.device
)
for ls in layer_states:
ls.should_record_tensor = self.should_record_tensor
def rearrange(
self,
is_profile: bool = False,
......@@ -993,6 +1058,17 @@ class EplbLayerState:
expert_load_view: torch.Tensor | None = None
logical_to_physical_map: torch.Tensor | None = None
logical_replica_count: torch.Tensor | None = None
should_record_tensor: torch.Tensor | None = None
"""
Shared scalar bool tensor controlling whether to accumulate expert load
metrics during this forward pass. All layers reference the **same**
tensor object, which is owned and updated by :class:`EplbState`.
Set to ``False`` for the first ``step_interval - window_size`` steps of
each rearrangement period: those steps would be overwritten in the
sliding window before the next rearrangement, so recording them wastes
GPU work.
"""
def _node_count_with_rank_mapping(
......
......@@ -10,61 +10,49 @@ from vllm.model_executor.layers.fused_moe.router.fused_moe_router import (
FusedMoERouter,
)
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
if current_platform.is_cuda_alike():
@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
def eplb_map_to_physical_and_record(
topk_ids: torch.Tensor,
expert_load_view: torch.Tensor,
logical_to_physical_map: torch.Tensor,
logical_replica_count: torch.Tensor,
) -> torch.Tensor:
"""
Map the logical expert ids to physical expert ids
and record the expert load metrics.
This will select a pseudo-random replica for each logical expert.
Only used for EPLB.
Args:
topk_ids: The logical expert ids.
expert_load_view: The expert load view.
logical_to_physical_map: The logical to physical map.
logical_replica_count: The logical replica count.
@triton.jit
def _eplb_map_and_record_i32_kernel(
topk_ids_ptr,
logical_replica_count_ptr,
logical_to_physical_ptr,
out_ids_ptr,
out_ptr,
record_enabled_ptr,
num_logical_experts,
map_slots,
out_size,
numel,
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(0)
offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offs < numel
Returns:
The physical expert ids.
"""
expert_id = tl.load(topk_ids_ptr + offs, mask=mask, other=0).to(tl.int64)
valid_expert = (expert_id >= 0) & (expert_id < num_logical_experts)
safe_expert_id = tl.where(valid_expert, expert_id, 0)
# 1. Convert the logical expert ids to physical expert ids
# Directly select a random replica for each logical expert
# In case `indices_type` is not `torch.long` or `torch.int`,
# e.g. `torch.uint32` as required by dispatch/combine kernels
topk_ids_long = topk_ids.long()
# Use (token position) modulo (replica count)
# to deterministically choose a replica
replica_count = logical_replica_count[topk_ids_long]
# Flatten-position based index, reshaped back to `topk_ids` shape
pos_indices = torch.arange(
topk_ids.numel(), device=topk_ids.device, dtype=torch.long
).reshape_as(topk_ids)
# Compute pseudo-random indices by modulo
replica_indices = (pos_indices % replica_count).unsqueeze(-1)
physical_ids = (
logical_to_physical_map[topk_ids_long]
.gather(-1, replica_indices)
.squeeze(-1)
replica_count = tl.load(
logical_replica_count_ptr + safe_expert_id,
mask=mask & valid_expert,
other=1,
)
topk_ids = physical_ids
# Avoid invalid modulo/div by forcing at least 1.
replica_count = tl.maximum(replica_count, 1)
# Match torch.compile path: use flattened token position.
replica_idx = offs % replica_count
# 2. Record expert load metrics.
# TODO(bowen): When using `FusedMoEModularKernel`, this
# can be done in a more unified way, since
# `FusedMoEPrepareAndFinalizeModular` will return the expert
# `FusedMoEPrepareAndFinalize` will return the expert
# token count, in some cases directly from the kernel.
# However, now there are many code paths not using
# the modular kernel, e.g. calling `fused_experts`,
......@@ -73,17 +61,63 @@ if current_platform.is_cuda_alike():
# If later refactor moved all the MoE kernel calls
# to the modular kernel, we can move this logic there
# to achieve better efficiency.
map_index = safe_expert_id * map_slots + replica_idx
physical_id = tl.load(
logical_to_physical_ptr + map_index,
mask=mask & valid_expert,
other=-1,
)
tl.store(out_ids_ptr + offs, physical_id, mask=mask)
# `expert_load_view`: (num_physical_experts,)
record_enabled = tl.load(record_enabled_ptr) != 0
valid = mask & record_enabled & (physical_id >= 0) & (physical_id < out_size)
safe_physical_id = tl.where(physical_id >= 0, physical_id, 0)
tl.atomic_add(out_ptr + safe_physical_id, 1, mask=valid)
# `torch.bincount` is not compilable, so use `scatter_add_` instead.
topk_ids_flatten = topk_ids.flatten()
expert_load_view.scatter_add_(
dim=0,
index=topk_ids_flatten.long(),
src=torch.ones_like(topk_ids_flatten).to(expert_load_view),
)
def _eplb_map_and_record_triton(
topk_ids: torch.Tensor,
logical_to_physical_map: torch.Tensor,
logical_replica_count: torch.Tensor,
expert_load_view: torch.Tensor,
record_enabled: torch.Tensor,
) -> torch.Tensor:
topk_ids_in = topk_ids.contiguous().to(dtype=torch.int32)
numel = topk_ids_in.numel()
if numel == 0:
return topk_ids
out_flat = torch.empty((numel,), device=topk_ids.device, dtype=topk_ids.dtype)
grid = lambda meta: (triton.cdiv(numel, meta["BLOCK_SIZE"]),)
assert expert_load_view.is_contiguous()
_eplb_map_and_record_i32_kernel[grid](
topk_ids_in,
logical_replica_count.contiguous(),
logical_to_physical_map.contiguous(),
out_flat,
expert_load_view,
record_enabled,
logical_replica_count.shape[0],
logical_to_physical_map.shape[1],
expert_load_view.shape[0],
numel,
BLOCK_SIZE=256,
)
return out_flat.reshape(topk_ids.shape)
def eplb_map_to_physical_and_record(
topk_ids: torch.Tensor,
expert_load_view: torch.Tensor,
logical_to_physical_map: torch.Tensor,
logical_replica_count: torch.Tensor,
record_enabled: torch.Tensor,
) -> torch.Tensor:
# Fused triton implementation: mapping + optional recording in one kernel.
return _eplb_map_and_record_triton(
topk_ids=topk_ids,
logical_to_physical_map=logical_to_physical_map,
logical_replica_count=logical_replica_count,
expert_load_view=expert_load_view,
record_enabled=record_enabled,
)
else:
def eplb_map_to_physical_and_record(
......@@ -91,8 +125,8 @@ else:
expert_load_view: torch.Tensor,
logical_to_physical_map: torch.Tensor,
logical_replica_count: torch.Tensor,
record_enabled: torch.Tensor,
) -> torch.Tensor:
# CPU fallback: no EPLB so just return as is
return topk_ids
......@@ -146,6 +180,10 @@ class BaseRouter(FusedMoERouter):
raise ValueError(
"enable_eplb=True requires logical_replica_count != None"
)
if self.eplb_state.should_record_tensor is None:
raise ValueError(
"enable_eplb=True requires should_record_tensor != None"
)
def _get_indices_type(self) -> torch.dtype | None:
"""Get the desired indices dtype from the getter function."""
......@@ -159,11 +197,13 @@ class BaseRouter(FusedMoERouter):
assert self.eplb_state.expert_load_view is not None
assert self.eplb_state.logical_to_physical_map is not None
assert self.eplb_state.logical_replica_count is not None
assert self.eplb_state.should_record_tensor is not None
return eplb_map_to_physical_and_record(
topk_ids=topk_ids,
expert_load_view=self.eplb_state.expert_load_view,
logical_to_physical_map=self.eplb_state.logical_to_physical_map,
logical_replica_count=self.eplb_state.logical_replica_count,
expert_load_view=self.eplb_state.expert_load_view,
record_enabled=self.eplb_state.should_record_tensor,
)
return topk_ids
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment