Unverified Commit 21fab0a3 authored by Le Yang's avatar Le Yang Committed by GitHub
Browse files

fix(moe): fix RoutedExpertsCapturer assertion failure with DP>1 and MK path (#37879)

parent 3244a2eb
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import types import types
from types import SimpleNamespace
from unittest.mock import patch
import pytest import pytest
import torch import torch
from vllm.distributed.eplb.eplb_state import EplbLayerState from vllm.distributed.eplb.eplb_state import EplbLayerState
from vllm.model_executor.layers.fused_moe.config import RoutingMethodType from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
from vllm.model_executor.layers.fused_moe.routed_experts_capturer import (
RoutedExpertsCapturer,
)
from vllm.model_executor.layers.fused_moe.router.base_router import BaseRouter from vllm.model_executor.layers.fused_moe.router.base_router import BaseRouter
pytestmark = pytest.mark.cpu_test pytestmark = pytest.mark.cpu_test
_REC_MODULE = "vllm.model_executor.layers.fused_moe.routed_experts_capturer"
def _capturer_with_buffer(
*,
max_tokens: int = 8,
num_layers: int = 4,
num_experts_per_tok: int = 2,
dp_rank: int = 0,
) -> RoutedExpertsCapturer:
c = RoutedExpertsCapturer()
c.dp_rank = dp_rank
c._device_buffer = torch.full(
(max_tokens, num_layers, num_experts_per_tok),
-1,
dtype=torch.int32,
)
return c
class DummyRouter(BaseRouter): class DummyRouter(BaseRouter):
@property @property
...@@ -159,3 +183,61 @@ def test_gpu_model_runner_binding_stage(monkeypatch): ...@@ -159,3 +183,61 @@ def test_gpu_model_runner_binding_stage(monkeypatch):
assert callable(dummy_module.router.capture_fn) assert callable(dummy_module.router.capture_fn)
dummy_module.router.capture_fn(torch.tensor([[9, 10]])) dummy_module.router.capture_fn(torch.tensor([[9, 10]]))
assert len(capturer.calls) == 1 assert len(capturer.calls) == 1
def test_routed_experts_capturer_single_dp_no_metadata():
"""dp_metadata is None: capture writes the full topk_ids rows."""
capturer = _capturer_with_buffer(dp_rank=0)
topk = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=torch.int32)
ctx = SimpleNamespace(dp_metadata=None)
with patch(f"{_REC_MODULE}.get_forward_context", return_value=ctx):
capturer.capture(layer_id=0, topk_ids=topk)
assert torch.equal(capturer._device_buffer[:3, 0, :], topk)
assert capturer._device_buffer[3, 0, 0].item() == -1
def test_routed_experts_capturer_dp_naive_concatenated_all_ranks():
"""n == sum(num_tokens_dp): slice this rank's segment from concatenated topk."""
capturer = _capturer_with_buffer(dp_rank=1)
num_tokens_dp = torch.tensor([2, 3], dtype=torch.int32)
ctx = SimpleNamespace(
dp_metadata=SimpleNamespace(num_tokens_across_dp_cpu=num_tokens_dp)
)
# Concatenated order: rank0 rows then rank1 rows.
topk = torch.tensor(
[[0, 1], [2, 3], [10, 11], [12, 13], [14, 15]], dtype=torch.int32
)
with patch(f"{_REC_MODULE}.get_forward_context", return_value=ctx):
capturer.capture(layer_id=0, topk_ids=topk)
want = topk[2:5]
assert torch.equal(capturer._device_buffer[:3, 0, :], want)
def test_routed_experts_capturer_dp_modular_local_tokens():
"""n == token_num_per_dp: topk is already local to this DP rank."""
capturer = _capturer_with_buffer(dp_rank=1)
num_tokens_dp = torch.tensor([2, 3], dtype=torch.int32)
ctx = SimpleNamespace(
dp_metadata=SimpleNamespace(num_tokens_across_dp_cpu=num_tokens_dp)
)
topk = torch.tensor([[10, 11], [12, 13], [14, 15]], dtype=torch.int32)
with patch(f"{_REC_MODULE}.get_forward_context", return_value=ctx):
capturer.capture(layer_id=0, topk_ids=topk)
assert torch.equal(capturer._device_buffer[:3, 0, :], topk)
def test_routed_experts_capturer_dp_unexpected_batch_raises():
"""Mismatch between topk batch dim and DP layout: fail fast."""
capturer = _capturer_with_buffer(dp_rank=0)
num_tokens_dp = torch.tensor([2, 3], dtype=torch.int32)
ctx = SimpleNamespace(
dp_metadata=SimpleNamespace(num_tokens_across_dp_cpu=num_tokens_dp)
)
# total=5, local=2: n=1 matches neither naive (5) nor modular (2).
topk = torch.tensor([[1, 2]], dtype=torch.int32)
with (
patch(f"{_REC_MODULE}.get_forward_context", return_value=ctx),
pytest.raises(AssertionError, match="unexpected topk_ids batch dim"),
):
capturer.capture(layer_id=0, topk_ids=topk)
assert capturer._device_buffer[0, 0, 0].item() == -1
...@@ -176,11 +176,27 @@ class RoutedExpertsCapturer: ...@@ -176,11 +176,27 @@ class RoutedExpertsCapturer:
end_loc = topk_ids.shape[0] end_loc = topk_ids.shape[0]
token_num_per_dp = topk_ids.shape[0] token_num_per_dp = topk_ids.shape[0]
else: # multi dp else: # multi dp
token_num_per_dp = ctx.dp_metadata.num_tokens_across_dp_cpu[self.dp_rank] num_tokens_dp = ctx.dp_metadata.num_tokens_across_dp_cpu
cumsum = torch.cumsum(ctx.dp_metadata.num_tokens_across_dp_cpu, dim=0) token_num_per_dp = int(num_tokens_dp[self.dp_rank].item())
assert cumsum[-1] == topk_ids.shape[0] total = int(num_tokens_dp.sum().item())
end_loc = cumsum[self.dp_rank] n = topk_ids.shape[0]
start_loc = end_loc - token_num_per_dp
if n == total:
# Naive dispatch: all DP ranks' tokens concatenated before routing.
cumsum = torch.cumsum(num_tokens_dp, dim=0)
end_loc = int(cumsum[self.dp_rank].item())
start_loc = end_loc - token_num_per_dp
elif n == token_num_per_dp:
# Modular-kernel path: DP combine happens inside quant_method.apply;
# select_experts only sees this rank's tokens.
start_loc = 0
end_loc = token_num_per_dp
else:
raise AssertionError(
"RoutedExpertsCapturer: unexpected topk_ids batch dim "
f"{n} (expected {total} or {token_num_per_dp} "
f"for dp_rank={self.dp_rank})"
)
if layer_id >= self._device_buffer.shape[1]: if layer_id >= self._device_buffer.shape[1]:
return return
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment