Unverified commit 333529de, authored by Artem Perevedentsev and committed by GitHub
Browse files

[EPLB] Fix replica selection bias in fused_moe router (#40810)


Signed-off-by: Artem Perevedentsev <aperevedents@nvidia.com>
parent 88256082
...@@ -662,6 +662,52 @@ def test_eplb_map_no_redundancy( ...@@ -662,6 +662,52 @@ def test_eplb_map_no_redundancy(
assert load.sum().item() == 0 assert load.sum().item() == 0
@pytest.mark.parametrize("top_k,R", [(2, 2), (4, 2), (8, 4), (8, 8)])
def test_eplb_map_hot_expert_replica_balance(top_k, R):
    """Check that a hot expert's traffic spreads evenly over its R replicas.

    Logical expert 0 is given ``R`` physical replicas (slots ``0..R-1``) and
    is placed in column 0 of every token's ``topk_ids`` row, so in the
    flattened view it always sits at an offset that is a multiple of
    ``top_k``. Every parametrized case has ``top_k`` divisible by ``R``;
    the test asserts that the recorded per-replica load still stays within
    15% of the mean rather than piling onto one replica.
    """
    num_tokens, num_logical = 8192, 16
    num_cold = num_logical - 1
    num_physical = R + num_cold
    device = "cuda"

    # Row 0 (the hot expert) lists replicas 0..R-1; every other logical
    # expert owns exactly one physical slot, numbered R, R+1, ... and keeps
    # -1 in its remaining columns.
    replica_table = torch.full(
        (num_logical, R), -1, dtype=torch.int64, device=device
    )
    replica_table[0] = torch.arange(R, dtype=torch.int64, device=device)
    replica_table[1:, 0] = torch.arange(
        R, R + num_cold, dtype=torch.int64, device=device
    )
    replica_counts = torch.tensor(
        [R, *([1] * num_cold)], dtype=torch.int64, device=device
    )

    # Seed immediately before the only random draw so the routing is
    # reproducible from run to run.
    torch.manual_seed(0)
    topk_ids = torch.randint(
        1,
        num_logical,
        (num_tokens, top_k),
        dtype=torch.int32,
        device=device,
    )
    # Force the hot expert into the first top-k column of every token.
    topk_ids[:, 0] = 0

    physical_load = torch.zeros(num_physical, dtype=torch.int32, device=device)
    recording_on = torch.tensor(True, dtype=torch.bool, device=device)

    eplb_map_to_physical_and_record(
        topk_ids=topk_ids,
        expert_load_view=physical_load,
        logical_to_physical_map=replica_table,
        logical_replica_count=replica_counts,
        record_enabled=recording_on,
    )

    # Only the first R physical slots belong to the hot expert.
    hot_load = physical_load[:R].float()
    max_mean = (hot_load.max() / hot_load.mean()).item()
    assert max_mean < 1.15, (
        f"Hot expert replicas uneven: {hot_load.tolist()}, max/mean={max_mean:.3f}"
    )
@pytest.mark.parametrize("record_enabled", [True, False]) @pytest.mark.parametrize("record_enabled", [True, False])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"l2p_map, replica_count, num_physical, topk_ids, expected_out, expected_load", "l2p_map, replica_count, num_physical, topk_ids, expected_out, expected_load",
...@@ -672,10 +718,12 @@ def test_eplb_map_no_redundancy( ...@@ -672,10 +718,12 @@ def test_eplb_map_no_redundancy(
[2, 2, 1, 1], [2, 2, 1, 1],
6, 6,
[[0, 1], [2, 3], [0, 2]], [[0, 1], [2, 3], [0, 2]],
# offs: 0→0%2=0→p0, 1→1%2=1→p5, 2→2%1=0→p2, # replica = ((token_idx * KNUTH) & 0xFFFFFFFF) % R.
# 3→3%1=0→p3, 4→4%2=0→p0, 5→5%1=0→p2 # token 0 hash=0x00000000: %2=0, %1=0.
[[0, 5], [2, 3], [0, 2]], # token 1 hash=0x9E3779B9: %2=1, %1=0.
[2, 0, 2, 1, 0, 1], # token 2 hash=0x3C6EF372: %2=0, %1=0.
[[0, 1], [2, 3], [0, 2]],
[2, 1, 2, 1, 0, 0],
id="partial", id="partial",
), ),
pytest.param( pytest.param(
...@@ -684,10 +732,11 @@ def test_eplb_map_no_redundancy( ...@@ -684,10 +732,11 @@ def test_eplb_map_no_redundancy(
[2, 2, 2, 2], [2, 2, 2, 2],
8, 8,
[[0, 1], [2, 3], [0, 2]], [[0, 1], [2, 3], [0, 2]],
# offs: 0→0%2=0→p0, 1→1%2=1→p5, 2→2%2=0→p2, # token 0 hash=0x00000000: %2=0.
# 3→3%2=1→p7, 4→4%2=0→p0, 5→5%2=1→p6 # token 1 hash=0x9E3779B9: %2=1.
[[0, 5], [2, 7], [0, 6]], # token 2 hash=0x3C6EF372: %2=0.
[2, 0, 1, 0, 0, 1, 1, 1], [[0, 1], [6, 7], [0, 2]],
[2, 1, 1, 0, 0, 0, 1, 1],
id="full", id="full",
), ),
pytest.param( pytest.param(
...@@ -696,10 +745,11 @@ def test_eplb_map_no_redundancy( ...@@ -696,10 +745,11 @@ def test_eplb_map_no_redundancy(
[4, 2, 2], [4, 2, 2],
8, 8,
[[0, 1], [2, 0], [1, 2]], [[0, 1], [2, 0], [1, 2]],
# offs: 0→0%4=0→p0, 1→1%2=1→p4, 2→2%2=0→p2, # token 0 hash=0x00000000: %4=0, %2=0.
# 3→3%4=3→p7, 4→4%2=0→p1, 5→5%2=1→p6 # token 1 hash=0x9E3779B9: %4=1, %2=1.
[[0, 4], [2, 7], [1, 6]], # token 2 hash=0x3C6EF372: %4=2, %2=0.
[1, 1, 1, 0, 1, 0, 1, 1], [[0, 1], [6, 3], [1, 2]],
[1, 2, 1, 1, 0, 0, 1, 0],
id="uneven", id="uneven",
), ),
], ],
......
...@@ -26,6 +26,7 @@ if current_platform.is_cuda_alike(): ...@@ -26,6 +26,7 @@ if current_platform.is_cuda_alike():
map_slots, map_slots,
out_size, out_size,
numel, numel,
num_active_experts,
BLOCK_SIZE: tl.constexpr, BLOCK_SIZE: tl.constexpr,
): ):
pid = tl.program_id(0) pid = tl.program_id(0)
...@@ -37,7 +38,6 @@ if current_platform.is_cuda_alike(): ...@@ -37,7 +38,6 @@ if current_platform.is_cuda_alike():
safe_expert_id = tl.where(valid_expert, expert_id, 0) safe_expert_id = tl.where(valid_expert, expert_id, 0)
# 1. Convert the logical expert ids to physical expert ids # 1. Convert the logical expert ids to physical expert ids
# Directly select a random replica for each logical expert
replica_count = tl.load( replica_count = tl.load(
logical_replica_count_ptr + safe_expert_id, logical_replica_count_ptr + safe_expert_id,
mask=mask & valid_expert, mask=mask & valid_expert,
...@@ -45,8 +45,11 @@ if current_platform.is_cuda_alike(): ...@@ -45,8 +45,11 @@ if current_platform.is_cuda_alike():
) )
# Avoid invalid modulo/div by forcing at least 1. # Avoid invalid modulo/div by forcing at least 1.
replica_count = tl.maximum(replica_count, 1) replica_count = tl.maximum(replica_count, 1)
# Match torch.compile path: use flattened token position. # floor(2^32 / phi), classic Knuth multiplicative hash multiplier.
replica_idx = offs % replica_count KNUTH_MULTIPLIER = 2654435769
token_idx = (offs // num_active_experts).to(tl.int64)
hashed = (token_idx * KNUTH_MULTIPLIER) & 0xFFFFFFFF
replica_idx = hashed % replica_count
# 2. Record expert load metrics. # 2. Record expert load metrics.
...@@ -85,6 +88,7 @@ if current_platform.is_cuda_alike(): ...@@ -85,6 +88,7 @@ if current_platform.is_cuda_alike():
numel = topk_ids_in.numel() numel = topk_ids_in.numel()
if numel == 0: if numel == 0:
return topk_ids return topk_ids
num_active_experts = topk_ids_in.shape[-1]
out_flat = torch.empty((numel,), device=topk_ids.device, dtype=topk_ids.dtype) out_flat = torch.empty((numel,), device=topk_ids.device, dtype=topk_ids.dtype)
grid = lambda meta: (triton.cdiv(numel, meta["BLOCK_SIZE"]),) grid = lambda meta: (triton.cdiv(numel, meta["BLOCK_SIZE"]),)
assert expert_load_view.is_contiguous() assert expert_load_view.is_contiguous()
...@@ -99,6 +103,7 @@ if current_platform.is_cuda_alike(): ...@@ -99,6 +103,7 @@ if current_platform.is_cuda_alike():
logical_to_physical_map.shape[1], logical_to_physical_map.shape[1],
expert_load_view.shape[0], expert_load_view.shape[0],
numel, numel,
num_active_experts,
BLOCK_SIZE=256, BLOCK_SIZE=256,
) )
return out_flat.reshape(topk_ids.shape) return out_flat.reshape(topk_ids.shape)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment