Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
12701e8a
Unverified
Commit
12701e8a
authored
Mar 30, 2026
by
Ilya Markov
Committed by
GitHub
Mar 30, 2026
Browse files
[EPLB] Optimize eplb mapping and record in router for prefill (#36261)
Signed-off-by:
ilmarkov
<
markovilya197@gmail.com
>
parent
494636b2
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
336 additions
and
64 deletions
+336
-64
tests/kernels/moe/test_routing.py
tests/kernels/moe/test_routing.py
+154
-0
tests/model_executor/test_routed_experts_capture.py
tests/model_executor/test_routed_experts_capture.py
+1
-0
vllm/config/parallel.py
vllm/config/parallel.py
+3
-3
vllm/distributed/elastic_ep/elastic_execute.py
vllm/distributed/elastic_ep/elastic_execute.py
+1
-0
vllm/distributed/eplb/eplb_state.py
vllm/distributed/eplb/eplb_state.py
+85
-9
vllm/model_executor/layers/fused_moe/router/base_router.py
vllm/model_executor/layers/fused_moe/router/base_router.py
+92
-52
No files found.
tests/kernels/moe/test_routing.py
View file @
12701e8a
...
...
@@ -8,6 +8,9 @@ import torch
from
vllm._aiter_ops
import
rocm_aiter_ops
from
vllm.distributed.eplb.eplb_state
import
EplbLayerState
from
vllm.model_executor.layers.fused_moe.router.base_router
import
(
eplb_map_to_physical_and_record
,
)
from
vllm.model_executor.layers.fused_moe.router.router_factory
import
(
create_fused_moe_router
,
)
...
...
@@ -55,11 +58,13 @@ def setup_eplb_state(enable_eplb: bool, global_num_experts: int) -> EplbLayerSta
logical_replica_count
=
torch
.
ones
(
global_num_experts
,
dtype
=
torch
.
int64
,
device
=
"cuda"
)
should_record_tensor
=
torch
.
ones
((),
dtype
=
torch
.
bool
,
device
=
"cuda"
)
return
EplbLayerState
(
expert_load_view
=
expert_load_view
,
logical_to_physical_map
=
logical_to_physical_map
,
logical_replica_count
=
logical_replica_count
,
should_record_tensor
=
should_record_tensor
,
)
...
...
@@ -581,3 +586,152 @@ def test_custom(
# hidden_states, router_logits = make_test_data(m, k, global_num_experts)
# topk_weights, topk_ids = router.select_experts(hidden_states, router_logits)
# ---------------------------------------------------------------------------
# Tests for eplb_map_to_physical_and_record
# ---------------------------------------------------------------------------
@pytest.mark.parametrize("record_enabled", [True, False])
@pytest.mark.parametrize(
    "l2p_map, replica_count, num_physical, topk_ids, expected_out, expected_load",
    [
        pytest.param(
            # Identity layout: logical expert i sits in physical slot i.
            [[0], [1], [2], [3]],
            [1, 1, 1, 1],
            4,
            [[0, 1], [2, 3], [0, 2]],
            [[0, 1], [2, 3], [0, 2]],
            [2, 1, 2, 1],
            id="identity",
        ),
        pytest.param(
            # Permuted layout: logical 0→3, 1→0, 2→1, 3→2.
            [[3], [0], [1], [2]],
            [1, 1, 1, 1],
            4,
            [[0, 1], [2, 3], [0, 2]],
            [[3, 0], [1, 2], [3, 1]],
            [1, 2, 1, 2],
            id="shuffled",
        ),
        pytest.param(
            # Physical space larger than logical: 0→5, 1→2, 2→7, 3→0.
            [[5], [2], [7], [0]],
            [1, 1, 1, 1],
            8,
            [[0, 1], [2, 3]],
            [[5, 2], [7, 0]],
            [1, 0, 1, 0, 0, 1, 0, 1],
            id="sparse",
        ),
    ],
)
def test_eplb_map_no_redundancy(
    record_enabled,
    l2p_map,
    replica_count,
    num_physical,
    topk_ids,
    expected_out,
    expected_load,
):
    """Check mapping and load recording when every expert has one replica.

    The returned ids must equal ``expected_out`` regardless of
    ``record_enabled``. The load buffer must equal ``expected_load`` when
    recording is on and must stay all-zero when it is off.
    """
    on_gpu = {"device": "cuda"}
    load_counts = torch.zeros(num_physical, dtype=torch.int32, **on_gpu)

    physical_ids = eplb_map_to_physical_and_record(
        topk_ids=torch.tensor(topk_ids, dtype=torch.int32, **on_gpu),
        expert_load_view=load_counts,
        logical_to_physical_map=torch.tensor(l2p_map, dtype=torch.int64, **on_gpu),
        logical_replica_count=torch.tensor(
            replica_count, dtype=torch.int64, **on_gpu
        ),
        record_enabled=torch.tensor(record_enabled, dtype=torch.bool, **on_gpu),
    )

    # Compare in whatever dtype the function chose for its output.
    torch.testing.assert_close(
        physical_ids,
        torch.tensor(expected_out, dtype=physical_ids.dtype, **on_gpu),
    )

    if not record_enabled:
        # Recording disabled: the load buffer must not have been touched.
        assert load_counts.sum().item() == 0
        return

    torch.testing.assert_close(
        load_counts,
        torch.tensor(expected_load, dtype=torch.int32, **on_gpu),
    )
@pytest.mark.parametrize("record_enabled", [True, False])
@pytest.mark.parametrize(
    "l2p_map, replica_count, num_physical, topk_ids, expected_out, expected_load",
    [
        pytest.param(
            # Experts 0 and 1 are replicated twice; 2 and 3 have one replica.
            [[0, 4], [1, 5], [2, -1], [3, -1]],
            [2, 2, 1, 1],
            6,
            [[0, 1], [2, 3], [0, 2]],
            # Per flat offset: 0→0%2=0→p0, 1→1%2=1→p5, 2→2%1=0→p2,
            #                  3→3%1=0→p3, 4→4%2=0→p0, 5→5%1=0→p2
            [[0, 5], [2, 3], [0, 2]],
            [2, 0, 2, 1, 0, 1],
            id="partial",
        ),
        pytest.param(
            # Every one of the 4 experts is replicated twice.
            [[0, 4], [1, 5], [2, 6], [3, 7]],
            [2, 2, 2, 2],
            8,
            [[0, 1], [2, 3], [0, 2]],
            # Per flat offset: 0→0%2=0→p0, 1→1%2=1→p5, 2→2%2=0→p2,
            #                  3→3%2=1→p7, 4→4%2=0→p0, 5→5%2=1→p6
            [[0, 5], [2, 7], [0, 6]],
            [2, 0, 1, 0, 0, 1, 1, 1],
            id="full",
        ),
        pytest.param(
            # Expert 0 has 4 replicas; experts 1 and 2 have 2 each.
            [[0, 3, 5, 7], [1, 4, -1, -1], [2, 6, -1, -1]],
            [4, 2, 2],
            8,
            [[0, 1], [2, 0], [1, 2]],
            # Per flat offset: 0→0%4=0→p0, 1→1%2=1→p4, 2→2%2=0→p2,
            #                  3→3%4=3→p7, 4→4%2=0→p1, 5→5%2=1→p6
            [[0, 4], [2, 7], [1, 6]],
            [1, 1, 1, 0, 1, 0, 1, 1],
            id="uneven",
        ),
    ],
)
def test_eplb_map_with_redundancy(
    record_enabled,
    l2p_map,
    replica_count,
    num_physical,
    topk_ids,
    expected_out,
    expected_load,
):
    """Check mapping and load recording when some experts have several replicas.

    The expected replica for each entry is ``flat_offset % replica_count``.
    Rows of ``l2p_map`` are padded with ``-1`` past their replica count. The
    load buffer must equal ``expected_load`` when recording is on and must
    stay all-zero when it is off.
    """
    on_gpu = {"device": "cuda"}
    load_counts = torch.zeros(num_physical, dtype=torch.int32, **on_gpu)

    physical_ids = eplb_map_to_physical_and_record(
        topk_ids=torch.tensor(topk_ids, dtype=torch.int32, **on_gpu),
        expert_load_view=load_counts,
        logical_to_physical_map=torch.tensor(l2p_map, dtype=torch.int64, **on_gpu),
        logical_replica_count=torch.tensor(
            replica_count, dtype=torch.int64, **on_gpu
        ),
        record_enabled=torch.tensor(record_enabled, dtype=torch.bool, **on_gpu),
    )

    # Compare in whatever dtype the function chose for its output.
    torch.testing.assert_close(
        physical_ids,
        torch.tensor(expected_out, dtype=physical_ids.dtype, **on_gpu),
    )

    if not record_enabled:
        # Recording disabled: the load buffer must not have been touched.
        assert load_counts.sum().item() == 0
        return

    torch.testing.assert_close(
        load_counts,
        torch.tensor(expected_load, dtype=torch.int32, **on_gpu),
    )
tests/model_executor/test_routed_experts_capture.py
View file @
12701e8a
...
...
@@ -62,6 +62,7 @@ def test_base_router_capture_with_eplb_enabled():
router
.
eplb_state
.
expert_load_view
=
torch
.
zeros
(
32
,
dtype
=
torch
.
int64
)
router
.
eplb_state
.
logical_to_physical_map
=
torch
.
arange
(
32
).
view
(
32
,
1
)
router
.
eplb_state
.
logical_replica_count
=
torch
.
ones
(
32
,
dtype
=
torch
.
int64
)
router
.
eplb_state
.
should_record_tensor
=
torch
.
ones
((),
dtype
=
torch
.
bool
)
captured
=
[]
...
...
vllm/config/parallel.py
View file @
12701e8a
...
...
@@ -53,9 +53,9 @@ All2AllBackend = Literal[
class
EPLBConfig
:
"""Configuration for Expert Parallel Load Balancing (EP)."""
window_size
:
int
=
1000
window_size
:
int
=
Field
(
default
=
1000
,
gt
=
0
)
"""Window size for expert load recording."""
step_interval
:
int
=
3000
step_interval
:
int
=
Field
(
default
=
3000
,
gt
=
0
)
"""
Interval for rearranging experts in expert parallelism.
...
...
@@ -71,7 +71,7 @@ class EPLBConfig:
Log the balancedness each step of expert parallelism.
This is turned off by default since it will cause communication overhead.
"""
log_balancedness_interval
:
int
=
1
log_balancedness_interval
:
int
=
Field
(
default
=
1
,
gt
=
0
)
"""
Interval for logging the balancedness.
"""
...
...
vllm/distributed/elastic_ep/elastic_execute.py
View file @
12701e8a
...
...
@@ -399,6 +399,7 @@ class ElasticEPScalingExecutor:
eplb_model_state
.
logical_to_physical_map
,
eplb_model_state
.
logical_replica_count
,
)
eplb_state
.
_init_should_record_tensor
(
model
)
model
.
update_physical_experts_metadata
(
num_physical_experts
=
num_physical_experts
,
num_local_physical_experts
=
num_local_experts
,
...
...
vllm/distributed/eplb/eplb_state.py
View file @
12701e8a
...
...
@@ -272,6 +272,13 @@ class EplbState:
Interval for expert rearrangement steps.
This is a constant and is taken from the config.
"""
self
.
should_record_tensor
:
torch
.
Tensor
|
None
=
None
"""
Shared scalar bool tensor for all layers. Every
:class:`EplbLayerState` holds a reference to the **same** object so
a single ``.fill_()`` updates all layers at once. Allocated on the
first call to :meth:`_init_should_record_tensor`.
"""
self
.
is_async
:
bool
=
False
"""
The flag indicates whether the EPLB is running in async mode.
...
...
@@ -462,7 +469,7 @@ class EplbState:
logical_to_physical_map
,
logical_replica_count
,
)
self
.
_init_should_record_tensor
(
model
)
expert_buffer
=
[
torch
.
empty_like
(
w
)
for
w
in
model
.
expert_weights
[
0
]]
model_state
=
EplbModelState
(
...
...
@@ -582,12 +589,15 @@ class EplbState:
# Update the expert load sliding window
if
not
is_dummy
:
should_record
=
self
.
_should_record_current_step
(
log_stats
=
log_stats
)
for
eplb_model_state
in
self
.
model_states
.
values
():
eplb_model_state
.
expert_load_window
[
self
.
expert_load_window_step
]
=
(
eplb_model_state
.
expert_load_pass
.
clone
()
)
if
should_record
:
eplb_model_state
.
expert_load_window
[
self
.
expert_load_window_step
].
copy_
(
eplb_model_state
.
expert_load_pass
)
eplb_model_state
.
expert_load_pass
.
zero_
()
if
should_record
:
self
.
expert_load_window_step
+=
1
if
self
.
expert_load_window_step
>=
self
.
expert_load_window_size
:
self
.
expert_load_window_step
=
0
...
...
@@ -617,11 +627,66 @@ class EplbState:
eplb_model_state
.
rebalanced
for
eplb_model_state
in
self
.
model_states
.
values
()
):
# Still performing asynchronous rearrangement
# Still performing asynchronous rearrangement; update
# should_record (step > step_interval, so always True) and
# bail out before the step counter is reset.
self
.
_update_layer_should_record
(
log_stats
=
log_stats
)
return
self
.
expert_rearrangement_step
=
0
self
.
rearrange
()
self
.
_update_layer_should_record
(
log_stats
=
log_stats
)
def _should_record_current_step(self, log_stats: bool = False) -> bool:
    """Decide whether expert-load recording is enabled for the current step.

    Recording is on when the step lies within one load window of either:
      1) the next rearrangement, so the sliding window is filled in time;
      2) the next balancedness log (only considered when ``log_stats``).

    Args:
        log_stats: Whether balancedness logging is active this step.

    Returns:
        ``True`` if expert load should be recorded this step.
    """
    window = self.expert_load_window_size
    step = self.expert_rearrangement_step

    # Steps left before the rearrangement interval elapses.
    near_rearrange = self.expert_rearrangement_step_interval - step <= window
    if not log_stats:
        return near_rearrange

    log_every = self.parallel_config.eplb_config.log_balancedness_interval
    # Distance to the next multiple of `log_every`; 0 when already on one.
    until_next_log = (log_every - step % log_every) % log_every
    near_log = until_next_log <= window
    return near_rearrange or near_log
def _update_layer_should_record(self, log_stats: bool = False) -> None:
    """Refresh the shared ``should_record_tensor`` with this step's decision.

    Does nothing while the tensor has not been allocated yet.

    Args:
        log_stats: Forwarded to :meth:`_should_record_current_step`.
    """
    shared_flag = self.should_record_tensor
    if shared_flag is None:
        return
    # In-place write so every holder of this tensor object sees the new value.
    shared_flag.fill_(self._should_record_current_step(log_stats=log_stats))
def _init_should_record_tensor(
    self, model: "MixtureOfExperts"  # type: ignore[name-defined]
) -> None:
    """Create the shared ``should_record_tensor`` once and attach it to layers.

    Must run after :meth:`model.set_eplb_state`, so that each layer's
    ``eplb_state`` is already populated with the tensor views.

    Args:
        model: Model whose ``moe_layers`` are scanned for an
            :class:`EplbLayerState` stored under ``eplb_state``.
    """
    eplb_layer_states = []
    for moe_layer in model.moe_layers:
        if not hasattr(moe_layer, "eplb_state"):
            continue
        candidate = moe_layer.eplb_state
        if isinstance(candidate, EplbLayerState):
            eplb_layer_states.append(candidate)

    if not eplb_layer_states:
        # No layer to attach to: neither allocate nor propagate.
        return

    if self.should_record_tensor is None:
        # Scalar bool that starts out True (recording enabled).
        self.should_record_tensor = torch.ones(
            (), dtype=torch.bool, device=self.device
        )

    # Every layer gets a reference to the same tensor object.
    shared_flag = self.should_record_tensor
    for layer_state in eplb_layer_states:
        layer_state.should_record_tensor = shared_flag
def
rearrange
(
self
,
is_profile
:
bool
=
False
,
...
...
@@ -993,6 +1058,17 @@ class EplbLayerState:
expert_load_view
:
torch
.
Tensor
|
None
=
None
logical_to_physical_map
:
torch
.
Tensor
|
None
=
None
logical_replica_count
:
torch
.
Tensor
|
None
=
None
should_record_tensor
:
torch
.
Tensor
|
None
=
None
"""
Shared scalar bool tensor controlling whether to accumulate expert load
metrics during this forward pass. All layers reference the **same**
tensor object, which is owned and updated by :class:`EplbState`.
Set to ``False`` for the first ``step_interval - window_size`` steps of
each rearrangement period: those steps would be overwritten in the
sliding window before the next rearrangement, so recording them wastes
GPU work.
"""
def
_node_count_with_rank_mapping
(
...
...
vllm/model_executor/layers/fused_moe/router/base_router.py
View file @
12701e8a
...
...
@@ -10,61 +10,49 @@ from vllm.model_executor.layers.fused_moe.router.fused_moe_router import (
FusedMoERouter
,
)
from
vllm.platforms
import
current_platform
from
vllm.triton_utils
import
tl
,
triton
if
current_platform
.
is_cuda_alike
():
@
torch
.
compile
(
dynamic
=
True
,
backend
=
current_platform
.
simple_compile_backend
)
def
eplb_map_to_physical_and_record
(
topk_ids
:
torch
.
Tensor
,
expert_load_view
:
torch
.
Tensor
,
logical_to_physical_map
:
torch
.
Tensor
,
logical_replica_count
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
"""
Map the logical expert ids to physical expert ids
and record the expert load metrics.
This will select a pseudo-random replica for each logical expert.
Only used for EPLB.
Args:
topk_ids: The logical expert ids.
expert_load_view: The expert load view.
logical_to_physical_map: The logical to physical map.
logical_replica_count: The logical replica count.
@
triton
.
jit
def
_eplb_map_and_record_i32_kernel
(
topk_ids_ptr
,
logical_replica_count_ptr
,
logical_to_physical_ptr
,
out_ids_ptr
,
out_ptr
,
record_enabled_ptr
,
num_logical_experts
,
map_slots
,
out_size
,
numel
,
BLOCK_SIZE
:
tl
.
constexpr
,
):
pid
=
tl
.
program_id
(
0
)
offs
=
pid
*
BLOCK_SIZE
+
tl
.
arange
(
0
,
BLOCK_SIZE
)
mask
=
offs
<
numel
Returns:
The phys
ical
expert
ids.
"""
expert_id
=
tl
.
load
(
topk_ids_ptr
+
offs
,
mask
=
mask
,
other
=
0
).
to
(
tl
.
int64
)
valid_expert
=
(
expert_id
>=
0
)
&
(
expert_id
<
num_log
ical
_
expert
s
)
safe_expert_id
=
tl
.
where
(
valid_expert
,
expert_id
,
0
)
# 1. Convert the logical expert ids to physical expert ids
# Directly select a random replica for each logical expert
# In case `indices_type` is not `torch.long` or `torch.int`,
# e.g. `torch.uint32` as required by dispatch/combine kernels
topk_ids_long
=
topk_ids
.
long
()
# Use (token position) modulo (replica count)
# to deterministically choose a replica
replica_count
=
logical_replica_count
[
topk_ids_long
]
# Flatten-position based index, reshaped back to `topk_ids` shape
pos_indices
=
torch
.
arange
(
topk_ids
.
numel
(),
device
=
topk_ids
.
device
,
dtype
=
torch
.
long
).
reshape_as
(
topk_ids
)
# Compute pseudo-random indices by modulo
replica_indices
=
(
pos_indices
%
replica_count
).
unsqueeze
(
-
1
)
physical_ids
=
(
logical_to_physical_map
[
topk_ids_long
]
.
gather
(
-
1
,
replica_indices
)
.
squeeze
(
-
1
)
replica_count
=
tl
.
load
(
logical_replica_count_ptr
+
safe_expert_id
,
mask
=
mask
&
valid_expert
,
other
=
1
,
)
topk_ids
=
physical_ids
# Avoid invalid modulo/div by forcing at least 1.
replica_count
=
tl
.
maximum
(
replica_count
,
1
)
# Match torch.compile path: use flattened token position.
replica_idx
=
offs
%
replica_count
# 2. Record expert load metrics.
# TODO(bowen): When using `FusedMoEModularKernel`, this
# can be done in a more unified way, since
# `FusedMoEPrepareAndFinalize
Modular
` will return the expert
# `FusedMoEPrepareAndFinalize` will return the expert
# token count, in some cases directly from the kernel.
# However, now there are many code paths not using
# the modular kernel, e.g. calling `fused_experts`,
...
...
@@ -73,17 +61,63 @@ if current_platform.is_cuda_alike():
# If later refactor moved all the MoE kernel calls
# to the modular kernel, we can move this logic there
# to achieve better efficiency.
map_index
=
safe_expert_id
*
map_slots
+
replica_idx
physical_id
=
tl
.
load
(
logical_to_physical_ptr
+
map_index
,
mask
=
mask
&
valid_expert
,
other
=-
1
,
)
tl
.
store
(
out_ids_ptr
+
offs
,
physical_id
,
mask
=
mask
)
# `expert_load_view`: (num_physical_experts,)
record_enabled
=
tl
.
load
(
record_enabled_ptr
)
!=
0
valid
=
mask
&
record_enabled
&
(
physical_id
>=
0
)
&
(
physical_id
<
out_size
)
safe_physical_id
=
tl
.
where
(
physical_id
>=
0
,
physical_id
,
0
)
tl
.
atomic_add
(
out_ptr
+
safe_physical_id
,
1
,
mask
=
valid
)
# `torch.bincount` is not compilable, so use `scatter_add_` instead.
topk_ids_flatten
=
topk_ids
.
flatten
()
expert_load_view
.
scatter_add_
(
dim
=
0
,
index
=
topk_ids_flatten
.
long
(),
src
=
torch
.
ones_like
(
topk_ids_flatten
).
to
(
expert_load_view
),
)
def _eplb_map_and_record_triton(
    topk_ids: torch.Tensor,
    logical_to_physical_map: torch.Tensor,
    logical_replica_count: torch.Tensor,
    expert_load_view: torch.Tensor,
    record_enabled: torch.Tensor,
) -> torch.Tensor:
    """Launch the fused map-and-record kernel over all entries of ``topk_ids``.

    Args:
        topk_ids: Logical expert ids; cast to a contiguous int32 copy for
            the kernel.
        logical_to_physical_map: 2-D map whose second dimension is passed to
            the kernel as the number of slots per logical expert.
        logical_replica_count: Per-logical-expert replica counts; its first
            dimension is passed as the number of logical experts.
        expert_load_view: Contiguous load buffer handed to the kernel as its
            counter output.
        record_enabled: Tensor flag forwarded to the kernel.

    Returns:
        A tensor with the shape and dtype of ``topk_ids`` holding the
        kernel's output ids. For empty input, ``topk_ids`` itself is
        returned and no kernel is launched.
    """
    ids_i32 = topk_ids.contiguous().to(dtype=torch.int32)
    total = ids_i32.numel()
    if total == 0:
        return topk_ids

    mapped_flat = torch.empty((total,), device=topk_ids.device, dtype=topk_ids.dtype)

    def launch_grid(meta):
        # 1-D launch: one program per BLOCK_SIZE chunk of the flat input.
        return (triton.cdiv(total, meta["BLOCK_SIZE"]),)

    assert expert_load_view.is_contiguous()

    replica_count_c = logical_replica_count.contiguous()
    l2p_map_c = logical_to_physical_map.contiguous()
    num_logical_experts = logical_replica_count.shape[0]
    slots_per_expert = logical_to_physical_map.shape[1]
    num_load_slots = expert_load_view.shape[0]

    _eplb_map_and_record_i32_kernel[launch_grid](
        ids_i32,
        replica_count_c,
        l2p_map_c,
        mapped_flat,
        expert_load_view,
        record_enabled,
        num_logical_experts,
        slots_per_expert,
        num_load_slots,
        total,
        BLOCK_SIZE=256,
    )
    return mapped_flat.reshape(topk_ids.shape)
def eplb_map_to_physical_and_record(
    topk_ids: torch.Tensor,
    expert_load_view: torch.Tensor,
    logical_to_physical_map: torch.Tensor,
    logical_replica_count: torch.Tensor,
    record_enabled: torch.Tensor,
) -> torch.Tensor:
    """Map logical expert ids to physical ids, optionally recording load.

    Thin entry point over the fused triton implementation, which does the
    mapping and the optional recording in a single kernel.

    Args:
        topk_ids: The logical expert ids.
        expert_load_view: The expert load buffer to accumulate into.
        logical_to_physical_map: The logical-to-physical map.
        logical_replica_count: The per-logical-expert replica count.
        record_enabled: Tensor flag controlling whether load is recorded.

    Returns:
        Whatever the fused implementation returns for these inputs.
    """
    # The helper takes the same tensors in a different positional order.
    mapped_ids = _eplb_map_and_record_triton(
        topk_ids,
        logical_to_physical_map,
        logical_replica_count,
        expert_load_view,
        record_enabled,
    )
    return mapped_ids
else
:
def
eplb_map_to_physical_and_record
(
...
...
@@ -91,8 +125,8 @@ else:
expert_load_view
:
torch
.
Tensor
,
logical_to_physical_map
:
torch
.
Tensor
,
logical_replica_count
:
torch
.
Tensor
,
record_enabled
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
# CPU fallback: no EPLB so just return as is
return
topk_ids
...
...
@@ -146,6 +180,10 @@ class BaseRouter(FusedMoERouter):
raise
ValueError
(
"enable_eplb=True requires logical_replica_count != None"
)
if
self
.
eplb_state
.
should_record_tensor
is
None
:
raise
ValueError
(
"enable_eplb=True requires should_record_tensor != None"
)
def
_get_indices_type
(
self
)
->
torch
.
dtype
|
None
:
"""Get the desired indices dtype from the getter function."""
...
...
@@ -159,11 +197,13 @@ class BaseRouter(FusedMoERouter):
assert
self
.
eplb_state
.
expert_load_view
is
not
None
assert
self
.
eplb_state
.
logical_to_physical_map
is
not
None
assert
self
.
eplb_state
.
logical_replica_count
is
not
None
assert
self
.
eplb_state
.
should_record_tensor
is
not
None
return
eplb_map_to_physical_and_record
(
topk_ids
=
topk_ids
,
expert_load_view
=
self
.
eplb_state
.
expert_load_view
,
logical_to_physical_map
=
self
.
eplb_state
.
logical_to_physical_map
,
logical_replica_count
=
self
.
eplb_state
.
logical_replica_count
,
expert_load_view
=
self
.
eplb_state
.
expert_load_view
,
record_enabled
=
self
.
eplb_state
.
should_record_tensor
,
)
return
topk_ids
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment