Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
12701e8a
Unverified
Commit
12701e8a
authored
Mar 30, 2026
by
Ilya Markov
Committed by
GitHub
Mar 30, 2026
Browse files
[EPLB] Optmize eplb mapping and record in router for prefill (#36261)
Signed-off-by:
ilmarkov
<
markovilya197@gmail.com
>
parent
494636b2
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
336 additions
and
64 deletions
+336
-64
tests/kernels/moe/test_routing.py
tests/kernels/moe/test_routing.py
+154
-0
tests/model_executor/test_routed_experts_capture.py
tests/model_executor/test_routed_experts_capture.py
+1
-0
vllm/config/parallel.py
vllm/config/parallel.py
+3
-3
vllm/distributed/elastic_ep/elastic_execute.py
vllm/distributed/elastic_ep/elastic_execute.py
+1
-0
vllm/distributed/eplb/eplb_state.py
vllm/distributed/eplb/eplb_state.py
+85
-9
vllm/model_executor/layers/fused_moe/router/base_router.py
vllm/model_executor/layers/fused_moe/router/base_router.py
+92
-52
No files found.
tests/kernels/moe/test_routing.py
View file @
12701e8a
...
@@ -8,6 +8,9 @@ import torch
...
@@ -8,6 +8,9 @@ import torch
from
vllm._aiter_ops
import
rocm_aiter_ops
from
vllm._aiter_ops
import
rocm_aiter_ops
from
vllm.distributed.eplb.eplb_state
import
EplbLayerState
from
vllm.distributed.eplb.eplb_state
import
EplbLayerState
from
vllm.model_executor.layers.fused_moe.router.base_router
import
(
eplb_map_to_physical_and_record
,
)
from
vllm.model_executor.layers.fused_moe.router.router_factory
import
(
from
vllm.model_executor.layers.fused_moe.router.router_factory
import
(
create_fused_moe_router
,
create_fused_moe_router
,
)
)
...
@@ -55,11 +58,13 @@ def setup_eplb_state(enable_eplb: bool, global_num_experts: int) -> EplbLayerSta
...
@@ -55,11 +58,13 @@ def setup_eplb_state(enable_eplb: bool, global_num_experts: int) -> EplbLayerSta
logical_replica_count
=
torch
.
ones
(
logical_replica_count
=
torch
.
ones
(
global_num_experts
,
dtype
=
torch
.
int64
,
device
=
"cuda"
global_num_experts
,
dtype
=
torch
.
int64
,
device
=
"cuda"
)
)
should_record_tensor
=
torch
.
ones
((),
dtype
=
torch
.
bool
,
device
=
"cuda"
)
return
EplbLayerState
(
return
EplbLayerState
(
expert_load_view
=
expert_load_view
,
expert_load_view
=
expert_load_view
,
logical_to_physical_map
=
logical_to_physical_map
,
logical_to_physical_map
=
logical_to_physical_map
,
logical_replica_count
=
logical_replica_count
,
logical_replica_count
=
logical_replica_count
,
should_record_tensor
=
should_record_tensor
,
)
)
...
@@ -581,3 +586,152 @@ def test_custom(
...
@@ -581,3 +586,152 @@ def test_custom(
# hidden_states, router_logits = make_test_data(m, k, global_num_experts)
# hidden_states, router_logits = make_test_data(m, k, global_num_experts)
# topk_weights, topk_ids = router.select_experts(hidden_states, router_logits)
# topk_weights, topk_ids = router.select_experts(hidden_states, router_logits)
# ---------------------------------------------------------------------------
# Tests for eplb_map_to_physical_and_record
# ---------------------------------------------------------------------------
@
pytest
.
mark
.
parametrize
(
"record_enabled"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"l2p_map, replica_count, num_physical, topk_ids, expected_out, expected_load"
,
[
pytest
.
param
(
# logical i → physical i
[[
0
],
[
1
],
[
2
],
[
3
]],
[
1
,
1
,
1
,
1
],
4
,
[[
0
,
1
],
[
2
,
3
],
[
0
,
2
]],
[[
0
,
1
],
[
2
,
3
],
[
0
,
2
]],
[
2
,
1
,
2
,
1
],
id
=
"identity"
,
),
pytest
.
param
(
# logical 0→3, 1→0, 2→1, 3→2
[[
3
],
[
0
],
[
1
],
[
2
]],
[
1
,
1
,
1
,
1
],
4
,
[[
0
,
1
],
[
2
,
3
],
[
0
,
2
]],
[[
3
,
0
],
[
1
,
2
],
[
3
,
1
]],
[
1
,
2
,
1
,
2
],
id
=
"shuffled"
,
),
pytest
.
param
(
# logical 0→5, 1→2, 2→7, 3→0 in a larger physical space
[[
5
],
[
2
],
[
7
],
[
0
]],
[
1
,
1
,
1
,
1
],
8
,
[[
0
,
1
],
[
2
,
3
]],
[[
5
,
2
],
[
7
,
0
]],
[
1
,
0
,
1
,
0
,
0
,
1
,
0
,
1
],
id
=
"sparse"
,
),
],
)
def
test_eplb_map_no_redundancy
(
record_enabled
,
l2p_map
,
replica_count
,
num_physical
,
topk_ids
,
expected_out
,
expected_load
,
):
l2p
=
torch
.
tensor
(
l2p_map
,
dtype
=
torch
.
int64
,
device
=
"cuda"
)
rc
=
torch
.
tensor
(
replica_count
,
dtype
=
torch
.
int64
,
device
=
"cuda"
)
load
=
torch
.
zeros
(
num_physical
,
dtype
=
torch
.
int32
,
device
=
"cuda"
)
rec
=
torch
.
tensor
(
record_enabled
,
dtype
=
torch
.
bool
,
device
=
"cuda"
)
ids
=
torch
.
tensor
(
topk_ids
,
dtype
=
torch
.
int32
,
device
=
"cuda"
)
out
=
eplb_map_to_physical_and_record
(
topk_ids
=
ids
,
expert_load_view
=
load
,
logical_to_physical_map
=
l2p
,
logical_replica_count
=
rc
,
record_enabled
=
rec
,
)
exp_out
=
torch
.
tensor
(
expected_out
,
dtype
=
out
.
dtype
,
device
=
"cuda"
)
torch
.
testing
.
assert_close
(
out
,
exp_out
)
if
record_enabled
:
exp_load
=
torch
.
tensor
(
expected_load
,
dtype
=
torch
.
int32
,
device
=
"cuda"
)
torch
.
testing
.
assert_close
(
load
,
exp_load
)
else
:
assert
load
.
sum
().
item
()
==
0
@
pytest
.
mark
.
parametrize
(
"record_enabled"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"l2p_map, replica_count, num_physical, topk_ids, expected_out, expected_load"
,
[
pytest
.
param
(
# experts 0,1 have 2 replicas; 2,3 have 1
[[
0
,
4
],
[
1
,
5
],
[
2
,
-
1
],
[
3
,
-
1
]],
[
2
,
2
,
1
,
1
],
6
,
[[
0
,
1
],
[
2
,
3
],
[
0
,
2
]],
# offs: 0→0%2=0→p0, 1→1%2=1→p5, 2→2%1=0→p2,
# 3→3%1=0→p3, 4→4%2=0→p0, 5→5%1=0→p2
[[
0
,
5
],
[
2
,
3
],
[
0
,
2
]],
[
2
,
0
,
2
,
1
,
0
,
1
],
id
=
"partial"
,
),
pytest
.
param
(
# all 4 experts have 2 replicas
[[
0
,
4
],
[
1
,
5
],
[
2
,
6
],
[
3
,
7
]],
[
2
,
2
,
2
,
2
],
8
,
[[
0
,
1
],
[
2
,
3
],
[
0
,
2
]],
# offs: 0→0%2=0→p0, 1→1%2=1→p5, 2→2%2=0→p2,
# 3→3%2=1→p7, 4→4%2=0→p0, 5→5%2=1→p6
[[
0
,
5
],
[
2
,
7
],
[
0
,
6
]],
[
2
,
0
,
1
,
0
,
0
,
1
,
1
,
1
],
id
=
"full"
,
),
pytest
.
param
(
# expert 0: 4 replicas, experts 1,2: 2 replicas
[[
0
,
3
,
5
,
7
],
[
1
,
4
,
-
1
,
-
1
],
[
2
,
6
,
-
1
,
-
1
]],
[
4
,
2
,
2
],
8
,
[[
0
,
1
],
[
2
,
0
],
[
1
,
2
]],
# offs: 0→0%4=0→p0, 1→1%2=1→p4, 2→2%2=0→p2,
# 3→3%4=3→p7, 4→4%2=0→p1, 5→5%2=1→p6
[[
0
,
4
],
[
2
,
7
],
[
1
,
6
]],
[
1
,
1
,
1
,
0
,
1
,
0
,
1
,
1
],
id
=
"uneven"
,
),
],
)
def
test_eplb_map_with_redundancy
(
record_enabled
,
l2p_map
,
replica_count
,
num_physical
,
topk_ids
,
expected_out
,
expected_load
,
):
l2p
=
torch
.
tensor
(
l2p_map
,
dtype
=
torch
.
int64
,
device
=
"cuda"
)
rc
=
torch
.
tensor
(
replica_count
,
dtype
=
torch
.
int64
,
device
=
"cuda"
)
load
=
torch
.
zeros
(
num_physical
,
dtype
=
torch
.
int32
,
device
=
"cuda"
)
rec
=
torch
.
tensor
(
record_enabled
,
dtype
=
torch
.
bool
,
device
=
"cuda"
)
ids
=
torch
.
tensor
(
topk_ids
,
dtype
=
torch
.
int32
,
device
=
"cuda"
)
out
=
eplb_map_to_physical_and_record
(
topk_ids
=
ids
,
expert_load_view
=
load
,
logical_to_physical_map
=
l2p
,
logical_replica_count
=
rc
,
record_enabled
=
rec
,
)
exp_out
=
torch
.
tensor
(
expected_out
,
dtype
=
out
.
dtype
,
device
=
"cuda"
)
torch
.
testing
.
assert_close
(
out
,
exp_out
)
if
record_enabled
:
exp_load
=
torch
.
tensor
(
expected_load
,
dtype
=
torch
.
int32
,
device
=
"cuda"
)
torch
.
testing
.
assert_close
(
load
,
exp_load
)
else
:
assert
load
.
sum
().
item
()
==
0
tests/model_executor/test_routed_experts_capture.py
View file @
12701e8a
...
@@ -62,6 +62,7 @@ def test_base_router_capture_with_eplb_enabled():
...
@@ -62,6 +62,7 @@ def test_base_router_capture_with_eplb_enabled():
router
.
eplb_state
.
expert_load_view
=
torch
.
zeros
(
32
,
dtype
=
torch
.
int64
)
router
.
eplb_state
.
expert_load_view
=
torch
.
zeros
(
32
,
dtype
=
torch
.
int64
)
router
.
eplb_state
.
logical_to_physical_map
=
torch
.
arange
(
32
).
view
(
32
,
1
)
router
.
eplb_state
.
logical_to_physical_map
=
torch
.
arange
(
32
).
view
(
32
,
1
)
router
.
eplb_state
.
logical_replica_count
=
torch
.
ones
(
32
,
dtype
=
torch
.
int64
)
router
.
eplb_state
.
logical_replica_count
=
torch
.
ones
(
32
,
dtype
=
torch
.
int64
)
router
.
eplb_state
.
should_record_tensor
=
torch
.
ones
((),
dtype
=
torch
.
bool
)
captured
=
[]
captured
=
[]
...
...
vllm/config/parallel.py
View file @
12701e8a
...
@@ -53,9 +53,9 @@ All2AllBackend = Literal[
...
@@ -53,9 +53,9 @@ All2AllBackend = Literal[
class
EPLBConfig
:
class
EPLBConfig
:
"""Configuration for Expert Parallel Load Balancing (EP)."""
"""Configuration for Expert Parallel Load Balancing (EP)."""
window_size
:
int
=
1000
window_size
:
int
=
Field
(
default
=
1000
,
gt
=
0
)
"""Window size for expert load recording."""
"""Window size for expert load recording."""
step_interval
:
int
=
3000
step_interval
:
int
=
Field
(
default
=
3000
,
gt
=
0
)
"""
"""
Interval for rearranging experts in expert parallelism.
Interval for rearranging experts in expert parallelism.
...
@@ -71,7 +71,7 @@ class EPLBConfig:
...
@@ -71,7 +71,7 @@ class EPLBConfig:
Log the balancedness each step of expert parallelism.
Log the balancedness each step of expert parallelism.
This is turned off by default since it will cause communication overhead.
This is turned off by default since it will cause communication overhead.
"""
"""
log_balancedness_interval
:
int
=
1
log_balancedness_interval
:
int
=
Field
(
default
=
1
,
gt
=
0
)
"""
"""
Interval for logging the balancedness.
Interval for logging the balancedness.
"""
"""
...
...
vllm/distributed/elastic_ep/elastic_execute.py
View file @
12701e8a
...
@@ -399,6 +399,7 @@ class ElasticEPScalingExecutor:
...
@@ -399,6 +399,7 @@ class ElasticEPScalingExecutor:
eplb_model_state
.
logical_to_physical_map
,
eplb_model_state
.
logical_to_physical_map
,
eplb_model_state
.
logical_replica_count
,
eplb_model_state
.
logical_replica_count
,
)
)
eplb_state
.
_init_should_record_tensor
(
model
)
model
.
update_physical_experts_metadata
(
model
.
update_physical_experts_metadata
(
num_physical_experts
=
num_physical_experts
,
num_physical_experts
=
num_physical_experts
,
num_local_physical_experts
=
num_local_experts
,
num_local_physical_experts
=
num_local_experts
,
...
...
vllm/distributed/eplb/eplb_state.py
View file @
12701e8a
...
@@ -272,6 +272,13 @@ class EplbState:
...
@@ -272,6 +272,13 @@ class EplbState:
Interval for expert rearrangement steps.
Interval for expert rearrangement steps.
This is a constant and is taken from the config.
This is a constant and is taken from the config.
"""
"""
self
.
should_record_tensor
:
torch
.
Tensor
|
None
=
None
"""
Shared scalar bool tensor for all layers. Every
:class:`EplbLayerState` holds a reference to the **same** object so
a single ``.fill_()`` updates all layers at once. Allocated on the
first call to :meth:`_init_should_record_tensor`.
"""
self
.
is_async
:
bool
=
False
self
.
is_async
:
bool
=
False
"""
"""
The flag indicates whether the EPLB is running in async mode.
The flag indicates whether the EPLB is running in async mode.
...
@@ -462,7 +469,7 @@ class EplbState:
...
@@ -462,7 +469,7 @@ class EplbState:
logical_to_physical_map
,
logical_to_physical_map
,
logical_replica_count
,
logical_replica_count
,
)
)
self
.
_init_should_record_tensor
(
model
)
expert_buffer
=
[
torch
.
empty_like
(
w
)
for
w
in
model
.
expert_weights
[
0
]]
expert_buffer
=
[
torch
.
empty_like
(
w
)
for
w
in
model
.
expert_weights
[
0
]]
model_state
=
EplbModelState
(
model_state
=
EplbModelState
(
...
@@ -582,12 +589,15 @@ class EplbState:
...
@@ -582,12 +589,15 @@ class EplbState:
# Update the expert load sliding window
# Update the expert load sliding window
if
not
is_dummy
:
if
not
is_dummy
:
should_record
=
self
.
_should_record_current_step
(
log_stats
=
log_stats
)
for
eplb_model_state
in
self
.
model_states
.
values
():
for
eplb_model_state
in
self
.
model_states
.
values
():
eplb_model_state
.
expert_load_window
[
self
.
expert_load_window_step
]
=
(
if
should_record
:
eplb_model_state
.
expert_load_pass
.
clone
()
eplb_model_state
.
expert_load_window
[
)
self
.
expert_load_window_step
].
copy_
(
eplb_model_state
.
expert_load_pass
)
eplb_model_state
.
expert_load_pass
.
zero_
()
eplb_model_state
.
expert_load_pass
.
zero_
()
if
should_record
:
self
.
expert_load_window_step
+=
1
self
.
expert_load_window_step
+=
1
if
self
.
expert_load_window_step
>=
self
.
expert_load_window_size
:
if
self
.
expert_load_window_step
>=
self
.
expert_load_window_size
:
self
.
expert_load_window_step
=
0
self
.
expert_load_window_step
=
0
...
@@ -617,11 +627,66 @@ class EplbState:
...
@@ -617,11 +627,66 @@ class EplbState:
eplb_model_state
.
rebalanced
eplb_model_state
.
rebalanced
for
eplb_model_state
in
self
.
model_states
.
values
()
for
eplb_model_state
in
self
.
model_states
.
values
()
):
):
# Still performing asynchronous rearrangement
# Still performing asynchronous rearrangement; update
# should_record (step > step_interval, so always True) and
# bail out before the step counter is reset.
self
.
_update_layer_should_record
(
log_stats
=
log_stats
)
return
return
self
.
expert_rearrangement_step
=
0
self
.
expert_rearrangement_step
=
0
self
.
rearrange
()
self
.
rearrange
()
self
.
_update_layer_should_record
(
log_stats
=
log_stats
)
def
_should_record_current_step
(
self
,
log_stats
:
bool
=
False
)
->
bool
:
"""Return whether expert-load recording should be enabled this step.
Recording is enabled when we are close to either:
1) The next rearrangement step, so the sliding window is ready.
2) The next balancedness logging step, when log_stats is enabled.
"""
steps_remaining
=
(
self
.
expert_rearrangement_step_interval
-
self
.
expert_rearrangement_step
)
should_record_for_rearrange
=
steps_remaining
<=
self
.
expert_load_window_size
if
not
log_stats
:
return
should_record_for_rearrange
log_interval
=
self
.
parallel_config
.
eplb_config
.
log_balancedness_interval
steps_until_next_log
=
(
log_interval
-
(
self
.
expert_rearrangement_step
%
log_interval
)
)
%
log_interval
should_record_for_log
=
steps_until_next_log
<=
self
.
expert_load_window_size
return
should_record_for_rearrange
or
should_record_for_log
def
_update_layer_should_record
(
self
,
log_stats
:
bool
=
False
)
->
None
:
"""Update the shared ``should_record_tensor`` for all layers."""
if
self
.
should_record_tensor
is
not
None
:
self
.
should_record_tensor
.
fill_
(
self
.
_should_record_current_step
(
log_stats
=
log_stats
)
)
def
_init_should_record_tensor
(
self
,
model
:
"MixtureOfExperts"
)
->
None
:
# type: ignore[name-defined]
"""Allocate (once) and propagate the shared ``should_record_tensor``.
Must be called after :meth:`model.set_eplb_state` so that each
layer's ``eplb_state`` is already populated with the tensor views.
"""
layer_states
=
[
layer
.
eplb_state
for
layer
in
model
.
moe_layers
if
hasattr
(
layer
,
"eplb_state"
)
and
isinstance
(
layer
.
eplb_state
,
EplbLayerState
)
]
if
self
.
should_record_tensor
is
None
and
layer_states
:
self
.
should_record_tensor
=
torch
.
ones
(
(),
dtype
=
torch
.
bool
,
device
=
self
.
device
)
for
ls
in
layer_states
:
ls
.
should_record_tensor
=
self
.
should_record_tensor
def
rearrange
(
def
rearrange
(
self
,
self
,
is_profile
:
bool
=
False
,
is_profile
:
bool
=
False
,
...
@@ -993,6 +1058,17 @@ class EplbLayerState:
...
@@ -993,6 +1058,17 @@ class EplbLayerState:
expert_load_view
:
torch
.
Tensor
|
None
=
None
expert_load_view
:
torch
.
Tensor
|
None
=
None
logical_to_physical_map
:
torch
.
Tensor
|
None
=
None
logical_to_physical_map
:
torch
.
Tensor
|
None
=
None
logical_replica_count
:
torch
.
Tensor
|
None
=
None
logical_replica_count
:
torch
.
Tensor
|
None
=
None
should_record_tensor
:
torch
.
Tensor
|
None
=
None
"""
Shared scalar bool tensor controlling whether to accumulate expert load
metrics during this forward pass. All layers reference the **same**
tensor object, which is owned and updated by :class:`EplbState`.
Set to ``False`` for the first ``step_interval - window_size`` steps of
each rearrangement period: those steps would be overwritten in the
sliding window before the next rearrangement, so recording them wastes
GPU work.
"""
def
_node_count_with_rank_mapping
(
def
_node_count_with_rank_mapping
(
...
...
vllm/model_executor/layers/fused_moe/router/base_router.py
View file @
12701e8a
...
@@ -10,61 +10,49 @@ from vllm.model_executor.layers.fused_moe.router.fused_moe_router import (
...
@@ -10,61 +10,49 @@ from vllm.model_executor.layers.fused_moe.router.fused_moe_router import (
FusedMoERouter
,
FusedMoERouter
,
)
)
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.triton_utils
import
tl
,
triton
if
current_platform
.
is_cuda_alike
():
if
current_platform
.
is_cuda_alike
():
@
torch
.
compile
(
dynamic
=
True
,
backend
=
current_platform
.
simple_compile_backend
)
@
triton
.
jit
def
eplb_map_to_physical_and_record
(
def
_eplb_map_and_record_i32_kernel
(
topk_ids
:
torch
.
Tensor
,
topk_ids_ptr
,
expert_load_view
:
torch
.
Tensor
,
logical_replica_count_ptr
,
logical_to_physical_map
:
torch
.
Tensor
,
logical_to_physical_ptr
,
logical_replica_count
:
torch
.
Tensor
,
out_ids_ptr
,
)
->
torch
.
Tensor
:
out_ptr
,
"""
record_enabled_ptr
,
Map the logical expert ids to physical expert ids
num_logical_experts
,
and record the expert load metrics.
map_slots
,
out_size
,
This will select a pseudo-random replica for each logical expert.
numel
,
Only used for EPLB.
BLOCK_SIZE
:
tl
.
constexpr
,
):
Args:
pid
=
tl
.
program_id
(
0
)
topk_ids: The logical expert ids.
offs
=
pid
*
BLOCK_SIZE
+
tl
.
arange
(
0
,
BLOCK_SIZE
)
expert_load_view: The expert load view.
mask
=
offs
<
numel
logical_to_physical_map: The logical to physical map.
logical_replica_count: The logical replica count.
Returns:
expert_id
=
tl
.
load
(
topk_ids_ptr
+
offs
,
mask
=
mask
,
other
=
0
).
to
(
tl
.
int64
)
The phys
ical
expert
ids.
valid_expert
=
(
expert_id
>=
0
)
&
(
expert_id
<
num_log
ical
_
expert
s
)
"""
safe_expert_id
=
tl
.
where
(
valid_expert
,
expert_id
,
0
)
# 1. Convert the logical expert ids to physical expert ids
# 1. Convert the logical expert ids to physical expert ids
# Directly select a random replica for each logical expert
# Directly select a random replica for each logical expert
replica_count
=
tl
.
load
(
# In case `indices_type` is not `torch.long` or `torch.int`,
logical_replica_count_ptr
+
safe_expert_id
,
# e.g. `torch.uint32` as required by dispatch/combine kernels
mask
=
mask
&
valid_expert
,
topk_ids_long
=
topk_ids
.
long
()
other
=
1
,
# Use (token position) modulo (replica count)
# to deterministically choose a replica
replica_count
=
logical_replica_count
[
topk_ids_long
]
# Flatten-position based index, reshaped back to `topk_ids` shape
pos_indices
=
torch
.
arange
(
topk_ids
.
numel
(),
device
=
topk_ids
.
device
,
dtype
=
torch
.
long
).
reshape_as
(
topk_ids
)
# Compute pseudo-random indices by modulo
replica_indices
=
(
pos_indices
%
replica_count
).
unsqueeze
(
-
1
)
physical_ids
=
(
logical_to_physical_map
[
topk_ids_long
]
.
gather
(
-
1
,
replica_indices
)
.
squeeze
(
-
1
)
)
)
# Avoid invalid modulo/div by forcing at least 1.
topk_ids
=
physical_ids
replica_count
=
tl
.
maximum
(
replica_count
,
1
)
# Match torch.compile path: use flattened token position.
replica_idx
=
offs
%
replica_count
# 2. Record expert load metrics.
# 2. Record expert load metrics.
# TODO(bowen): When using `FusedMoEModularKernel`, this
# TODO(bowen): When using `FusedMoEModularKernel`, this
# can be done in a more unified way, since
# can be done in a more unified way, since
# `FusedMoEPrepareAndFinalize
Modular
` will return the expert
# `FusedMoEPrepareAndFinalize` will return the expert
# token count, in some cases directly from the kernel.
# token count, in some cases directly from the kernel.
# However, now there are many code paths not using
# However, now there are many code paths not using
# the modular kernel, e.g. calling `fused_experts`,
# the modular kernel, e.g. calling `fused_experts`,
...
@@ -73,17 +61,63 @@ if current_platform.is_cuda_alike():
...
@@ -73,17 +61,63 @@ if current_platform.is_cuda_alike():
# If later refactor moved all the MoE kernel calls
# If later refactor moved all the MoE kernel calls
# to the modular kernel, we can move this logic there
# to the modular kernel, we can move this logic there
# to achieve better efficiency.
# to achieve better efficiency.
map_index
=
safe_expert_id
*
map_slots
+
replica_idx
physical_id
=
tl
.
load
(
logical_to_physical_ptr
+
map_index
,
mask
=
mask
&
valid_expert
,
other
=-
1
,
)
tl
.
store
(
out_ids_ptr
+
offs
,
physical_id
,
mask
=
mask
)
# `expert_load_view`: (num_physical_experts,)
record_enabled
=
tl
.
load
(
record_enabled_ptr
)
!=
0
valid
=
mask
&
record_enabled
&
(
physical_id
>=
0
)
&
(
physical_id
<
out_size
)
safe_physical_id
=
tl
.
where
(
physical_id
>=
0
,
physical_id
,
0
)
tl
.
atomic_add
(
out_ptr
+
safe_physical_id
,
1
,
mask
=
valid
)
# `torch.bincount` is not compilable, so use `scatter_add_` instead.
def
_eplb_map_and_record_triton
(
topk_ids_flatten
=
topk_ids
.
flatten
()
topk_ids
:
torch
.
Tensor
,
expert_load_view
.
scatter_add_
(
logical_to_physical_map
:
torch
.
Tensor
,
dim
=
0
,
logical_replica_count
:
torch
.
Tensor
,
index
=
topk_ids_flatten
.
long
(),
expert_load_view
:
torch
.
Tensor
,
src
=
torch
.
ones_like
(
topk_ids_flatten
).
to
(
expert_load_view
),
record_enabled
:
torch
.
Tensor
,
)
)
->
torch
.
Tensor
:
topk_ids_in
=
topk_ids
.
contiguous
().
to
(
dtype
=
torch
.
int32
)
numel
=
topk_ids_in
.
numel
()
if
numel
==
0
:
return
topk_ids
return
topk_ids
out_flat
=
torch
.
empty
((
numel
,),
device
=
topk_ids
.
device
,
dtype
=
topk_ids
.
dtype
)
grid
=
lambda
meta
:
(
triton
.
cdiv
(
numel
,
meta
[
"BLOCK_SIZE"
]),)
assert
expert_load_view
.
is_contiguous
()
_eplb_map_and_record_i32_kernel
[
grid
](
topk_ids_in
,
logical_replica_count
.
contiguous
(),
logical_to_physical_map
.
contiguous
(),
out_flat
,
expert_load_view
,
record_enabled
,
logical_replica_count
.
shape
[
0
],
logical_to_physical_map
.
shape
[
1
],
expert_load_view
.
shape
[
0
],
numel
,
BLOCK_SIZE
=
256
,
)
return
out_flat
.
reshape
(
topk_ids
.
shape
)
def
eplb_map_to_physical_and_record
(
topk_ids
:
torch
.
Tensor
,
expert_load_view
:
torch
.
Tensor
,
logical_to_physical_map
:
torch
.
Tensor
,
logical_replica_count
:
torch
.
Tensor
,
record_enabled
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
# Fused triton implementation: mapping + optional recording in one kernel.
return
_eplb_map_and_record_triton
(
topk_ids
=
topk_ids
,
logical_to_physical_map
=
logical_to_physical_map
,
logical_replica_count
=
logical_replica_count
,
expert_load_view
=
expert_load_view
,
record_enabled
=
record_enabled
,
)
else
:
else
:
def
eplb_map_to_physical_and_record
(
def
eplb_map_to_physical_and_record
(
...
@@ -91,8 +125,8 @@ else:
...
@@ -91,8 +125,8 @@ else:
expert_load_view
:
torch
.
Tensor
,
expert_load_view
:
torch
.
Tensor
,
logical_to_physical_map
:
torch
.
Tensor
,
logical_to_physical_map
:
torch
.
Tensor
,
logical_replica_count
:
torch
.
Tensor
,
logical_replica_count
:
torch
.
Tensor
,
record_enabled
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
# CPU fallback: no EPLB so just return as is
return
topk_ids
return
topk_ids
...
@@ -146,6 +180,10 @@ class BaseRouter(FusedMoERouter):
...
@@ -146,6 +180,10 @@ class BaseRouter(FusedMoERouter):
raise
ValueError
(
raise
ValueError
(
"enable_eplb=True requires logical_replica_count != None"
"enable_eplb=True requires logical_replica_count != None"
)
)
if
self
.
eplb_state
.
should_record_tensor
is
None
:
raise
ValueError
(
"enable_eplb=True requires should_record_tensor != None"
)
def
_get_indices_type
(
self
)
->
torch
.
dtype
|
None
:
def
_get_indices_type
(
self
)
->
torch
.
dtype
|
None
:
"""Get the desired indices dtype from the getter function."""
"""Get the desired indices dtype from the getter function."""
...
@@ -159,11 +197,13 @@ class BaseRouter(FusedMoERouter):
...
@@ -159,11 +197,13 @@ class BaseRouter(FusedMoERouter):
assert
self
.
eplb_state
.
expert_load_view
is
not
None
assert
self
.
eplb_state
.
expert_load_view
is
not
None
assert
self
.
eplb_state
.
logical_to_physical_map
is
not
None
assert
self
.
eplb_state
.
logical_to_physical_map
is
not
None
assert
self
.
eplb_state
.
logical_replica_count
is
not
None
assert
self
.
eplb_state
.
logical_replica_count
is
not
None
assert
self
.
eplb_state
.
should_record_tensor
is
not
None
return
eplb_map_to_physical_and_record
(
return
eplb_map_to_physical_and_record
(
topk_ids
=
topk_ids
,
topk_ids
=
topk_ids
,
expert_load_view
=
self
.
eplb_state
.
expert_load_view
,
logical_to_physical_map
=
self
.
eplb_state
.
logical_to_physical_map
,
logical_to_physical_map
=
self
.
eplb_state
.
logical_to_physical_map
,
logical_replica_count
=
self
.
eplb_state
.
logical_replica_count
,
logical_replica_count
=
self
.
eplb_state
.
logical_replica_count
,
expert_load_view
=
self
.
eplb_state
.
expert_load_view
,
record_enabled
=
self
.
eplb_state
.
should_record_tensor
,
)
)
return
topk_ids
return
topk_ids
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment