Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3173441b
Unverified
Commit
3173441b
authored
Apr 20, 2026
by
Sage Moore
Committed by
GitHub
Apr 20, 2026
Browse files
[EPLB] Consolidate is_unchanged/is_received_locally into TransferMetadata (#37341)
Signed-off-by:
Sage Moore
<
sage@neuralmagic.com
>
parent
8b1f3beb
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
36 additions
and
64 deletions
+36
-64
tests/distributed/test_eplb_execute.py
tests/distributed/test_eplb_execute.py
+2
-4
vllm/distributed/eplb/async_worker.py
vllm/distributed/eplb/async_worker.py
+2
-8
vllm/distributed/eplb/eplb_state.py
vllm/distributed/eplb/eplb_state.py
+1
-3
vllm/distributed/eplb/rebalance_execute.py
vllm/distributed/eplb/rebalance_execute.py
+31
-49
No files found.
tests/distributed/test_eplb_execute.py
View file @
3173441b
...
@@ -361,7 +361,7 @@ def _test_async_transfer_layer_without_mtp_worker(
...
@@ -361,7 +361,7 @@ def _test_async_transfer_layer_without_mtp_worker(
communicator
.
set_stream
(
cuda_stream
)
communicator
.
set_stream
(
cuda_stream
)
for
layer_idx
in
range
(
num_layers
):
for
layer_idx
in
range
(
num_layers
):
is_unchanged
,
is_received_locally
,
recv
_metadata
=
asyncio
.
run
(
transfer
_metadata
=
asyncio
.
run
(
transfer_layer
(
transfer_layer
(
old_layer_indices
=
old_indices_cpu
[
layer_idx
],
old_layer_indices
=
old_indices_cpu
[
layer_idx
],
new_layer_indices
=
new_indices_cpu
[
layer_idx
],
new_layer_indices
=
new_indices_cpu
[
layer_idx
],
...
@@ -376,9 +376,7 @@ def _test_async_transfer_layer_without_mtp_worker(
...
@@ -376,9 +376,7 @@ def _test_async_transfer_layer_without_mtp_worker(
move_from_buffer
(
move_from_buffer
(
expert_weights
=
expert_weights
[
layer_idx
],
expert_weights
=
expert_weights
[
layer_idx
],
expert_weights_buffers
=
expert_buffer
,
expert_weights_buffers
=
expert_buffer
,
is_unchanged
=
is_unchanged
,
transfer_metadata
=
transfer_metadata
,
is_received_locally
=
is_received_locally
,
recv_metadata
=
recv_metadata
,
new_indices
=
new_indices_cpu
[
layer_idx
].
numpy
(),
new_indices
=
new_indices_cpu
[
layer_idx
].
numpy
(),
ep_rank
=
ep_rank
,
ep_rank
=
ep_rank
,
)
)
...
...
vllm/distributed/eplb/async_worker.py
View file @
3173441b
...
@@ -118,11 +118,7 @@ async def transfer_run_periodically(
...
@@ -118,11 +118,7 @@ async def transfer_run_periodically(
# model_state.expert_buffer, which will be consumed by the main thread in
# model_state.expert_buffer, which will be consumed by the main thread in
# move_to_workspace
# move_to_workspace
while
model_state
.
rebalanced
and
layer_idx
<
num_layers
:
while
model_state
.
rebalanced
and
layer_idx
<
num_layers
:
(
transfer_metadata
=
await
transfer_layer
(
is_unchanged
,
is_received_locally
,
recv_metadata
,
)
=
await
transfer_layer
(
old_layer_indices
=
physical_to_logical_map_cpu
[
layer_idx
],
old_layer_indices
=
physical_to_logical_map_cpu
[
layer_idx
],
new_layer_indices
=
new_physical_to_logical_map
[
layer_idx
],
new_layer_indices
=
new_physical_to_logical_map
[
layer_idx
],
expert_weights
=
model_state
.
model
.
expert_weights
[
layer_idx
],
expert_weights
=
model_state
.
model
.
expert_weights
[
layer_idx
],
...
@@ -145,9 +141,7 @@ async def transfer_run_periodically(
...
@@ -145,9 +141,7 @@ async def transfer_run_periodically(
model_state
.
pending_result
=
AsyncEplbLayerResult
(
model_state
.
pending_result
=
AsyncEplbLayerResult
(
layer_idx
=
layer_idx
,
layer_idx
=
layer_idx
,
new_physical_to_logical_map
=
new_physical_to_logical_map
[
layer_idx
],
new_physical_to_logical_map
=
new_physical_to_logical_map
[
layer_idx
],
is_unchanged
=
is_unchanged
,
transfer_metadata
=
transfer_metadata
,
is_received_locally
=
is_received_locally
,
recv_metadata
=
recv_metadata
,
consumed_event
=
consumed_event
,
consumed_event
=
consumed_event
,
)
)
...
...
vllm/distributed/eplb/eplb_state.py
View file @
3173441b
...
@@ -1147,9 +1147,7 @@ def _move_to_workspace(
...
@@ -1147,9 +1147,7 @@ def _move_to_workspace(
move_from_buffer
(
move_from_buffer
(
expert_weights
=
model_state
.
model
.
expert_weights
[
result
.
layer_idx
],
expert_weights
=
model_state
.
model
.
expert_weights
[
result
.
layer_idx
],
expert_weights_buffers
=
model_state
.
expert_buffer
,
expert_weights_buffers
=
model_state
.
expert_buffer
,
is_unchanged
=
result
.
is_unchanged
,
transfer_metadata
=
result
.
transfer_metadata
,
is_received_locally
=
result
.
is_received_locally
,
recv_metadata
=
result
.
recv_metadata
,
new_indices
=
result
.
new_physical_to_logical_map
.
numpy
(),
new_indices
=
result
.
new_physical_to_logical_map
.
numpy
(),
ep_rank
=
ep_rank
,
ep_rank
=
ep_rank
,
)
)
...
...
vllm/distributed/eplb/rebalance_execute.py
View file @
3173441b
...
@@ -21,9 +21,13 @@ logger = init_logger(__name__)
...
@@ -21,9 +21,13 @@ logger = init_logger(__name__)
@
dataclass
@
dataclass
class
Recv
Metadata
:
class
Transfer
Metadata
:
"""Metadata describing
remote receives during EPLB rebalancing
."""
"""Metadata describing
a completed EPLB buffer transfer
."""
is_unchanged
:
np
.
ndarray
"""Mask of (num_local_experts,) indicating experts unchanged after rebalance."""
is_received_locally
:
np
.
ndarray
"""Mask of (num_local_experts,) indicating experts received from local data."""
recv_primary_mask
:
np
.
ndarray
recv_primary_mask
:
np
.
ndarray
"""Mask of (num_local_experts,) indicating primary experts received."""
"""Mask of (num_local_experts,) indicating primary experts received."""
recv_count
:
int
recv_count
:
int
...
@@ -34,10 +38,6 @@ class RecvMetadata:
...
@@ -34,10 +38,6 @@ class RecvMetadata:
"""Target expert indices (num_local_experts,) in local tensors to send."""
"""Target expert indices (num_local_experts,) in local tensors to send."""
# Type alias for the result of move_to_buffer or transfer_layer
MoveToBufferResult
=
tuple
[
np
.
ndarray
,
np
.
ndarray
,
RecvMetadata
]
@
dataclass
@
dataclass
class
AsyncEplbLayerResult
:
class
AsyncEplbLayerResult
:
"""
"""
...
@@ -51,11 +51,7 @@ class AsyncEplbLayerResult:
...
@@ -51,11 +51,7 @@ class AsyncEplbLayerResult:
New physical→logical mapping for layers_idx, on CPU.
New physical→logical mapping for layers_idx, on CPU.
Shape: (num_physical_experts)
Shape: (num_physical_experts)
"""
"""
is_unchanged
:
np
.
ndarray
transfer_metadata
:
TransferMetadata
"""Per-physical-expert flag: weight was not moved during transfer."""
is_received_locally
:
np
.
ndarray
"""Per-physical-expert flag: weight was received on this rank."""
recv_metadata
:
RecvMetadata
"""Metadata describing what was received during transfer_layer."""
"""Metadata describing what was received during transfer_layer."""
consumed_event
:
CpuGpuEvent
consumed_event
:
CpuGpuEvent
"""
"""
...
@@ -182,7 +178,7 @@ def move_to_buffer(
...
@@ -182,7 +178,7 @@ def move_to_buffer(
cuda_stream
:
torch
.
cuda
.
Stream
|
None
,
cuda_stream
:
torch
.
cuda
.
Stream
|
None
,
ep_rank
:
int
,
ep_rank
:
int
,
communicator
:
EplbCommunicator
,
communicator
:
EplbCommunicator
,
)
->
MoveToBufferResult
:
)
->
TransferMetadata
:
"""
"""
Rearranges expert weights during EPLB rebalancing.
Rearranges expert weights during EPLB rebalancing.
...
@@ -199,11 +195,7 @@ def move_to_buffer(
...
@@ -199,11 +195,7 @@ def move_to_buffer(
communicator: EplbCommunicator instance for P2P communication.
communicator: EplbCommunicator instance for P2P communication.
Returns:
Returns:
is_unchanged (np.ndarray): (num_local_experts,), True where an expert row
TransferMetadata: Metadata needed for completing remote weight transfers.
is unchanged after rebalance.
is_received_locally (np.ndarray): (num_local_experts,), True where a row
can be updated from local data.
RecvMetadata: Metadata needed for completing remote weight transfers.
"""
"""
assert
old_indices
.
shape
==
new_indices
.
shape
assert
old_indices
.
shape
==
new_indices
.
shape
recv_primary_mask
=
np
.
zeros
((
num_local_experts
,),
dtype
=
np
.
bool_
)
recv_primary_mask
=
np
.
zeros
((
num_local_experts
,),
dtype
=
np
.
bool_
)
...
@@ -339,24 +331,20 @@ def move_to_buffer(
...
@@ -339,24 +331,20 @@ def move_to_buffer(
# 4. Execute the P2P operations. The real communication happens here.
# 4. Execute the P2P operations. The real communication happens here.
communicator
.
execute
()
communicator
.
execute
()
# wait for the communication to finish
# wait for the communication to finish
return
(
return
TransferMetadata
(
is_unchanged
,
is_unchanged
=
is_unchanged
,
is_received_locally
,
is_received_locally
=
is_received_locally
,
RecvMetadata
(
recv_primary_mask
=
recv_primary_mask
,
recv_primary_mask
=
recv_primary_mask
,
recv_count
=
recv_count
,
recv_count
=
recv_count
,
recv_expert_ids
=
recv_expert_ids
,
recv_expert_ids
=
recv_expert_ids
,
recv_dst_rows
=
recv_dst_rows
,
recv_dst_rows
=
recv_dst_rows
,
),
)
)
def
move_from_buffer
(
def
move_from_buffer
(
expert_weights
:
Sequence
[
torch
.
Tensor
],
expert_weights
:
Sequence
[
torch
.
Tensor
],
expert_weights_buffers
:
list
[
torch
.
Tensor
],
expert_weights_buffers
:
list
[
torch
.
Tensor
],
is_unchanged
:
np
.
ndarray
,
transfer_metadata
:
TransferMetadata
,
is_received_locally
:
np
.
ndarray
,
recv_metadata
:
RecvMetadata
,
new_indices
:
np
.
ndarray
,
new_indices
:
np
.
ndarray
,
ep_rank
:
int
,
ep_rank
:
int
,
)
->
None
:
)
->
None
:
...
@@ -368,17 +356,17 @@ def move_from_buffer(
...
@@ -368,17 +356,17 @@ def move_from_buffer(
expert_weights: List of the actual MoE layer weights used in the execution.
expert_weights: List of the actual MoE layer weights used in the execution.
expert_weights_buffers: Intermediate buffers containing the experts weights
expert_weights_buffers: Intermediate buffers containing the experts weights
after the transfer is completed.
after the transfer is completed.
is_unchanged: (num_local_experts,), True where an expert row is unchanged.
transfer_metadata: TransferMetadata containing transfer metadata.
is_received_locally: (num_local_experts,), True where a row is updated locally.
recv_metadata: RecvMetadata containing remote receive metadata.
new_indices: (num_experts_total,) mapping from local rows to desired
new_indices: (num_experts_total,) mapping from local rows to desired
(possibly global) expert id, after rebalance.
(possibly global) expert id, after rebalance.
ep_rank: Rank of the process in the expert parallel group.
ep_rank: Rank of the process in the expert parallel group.
"""
"""
recv_primary_mask
=
recv_metadata
.
recv_primary_mask
is_unchanged
=
transfer_metadata
.
is_unchanged
recv_count
=
recv_metadata
.
recv_count
is_received_locally
=
transfer_metadata
.
is_received_locally
recv_expert_ids
=
recv_metadata
.
recv_expert_ids
recv_primary_mask
=
transfer_metadata
.
recv_primary_mask
recv_dst_rows
=
recv_metadata
.
recv_dst_rows
recv_count
=
transfer_metadata
.
recv_count
recv_expert_ids
=
transfer_metadata
.
recv_expert_ids
recv_dst_rows
=
transfer_metadata
.
recv_dst_rows
num_local_experts
=
is_unchanged
.
shape
[
0
]
num_local_experts
=
is_unchanged
.
shape
[
0
]
# Mask for rows to copy back from buffers:
# Mask for rows to copy back from buffers:
...
@@ -440,7 +428,7 @@ async def transfer_layer(
...
@@ -440,7 +428,7 @@ async def transfer_layer(
is_profile
:
bool
=
False
,
is_profile
:
bool
=
False
,
cuda_stream
:
torch
.
cuda
.
Stream
|
None
=
None
,
cuda_stream
:
torch
.
cuda
.
Stream
|
None
=
None
,
rank_mapping
:
dict
[
int
,
int
]
|
None
=
None
,
rank_mapping
:
dict
[
int
,
int
]
|
None
=
None
,
)
->
MoveToBufferResult
:
)
->
TransferMetadata
:
"""
"""
Rearranges the expert weights in place according to the new expert indices.
Rearranges the expert weights in place according to the new expert indices.
...
@@ -463,11 +451,8 @@ async def transfer_layer(
...
@@ -463,11 +451,8 @@ async def transfer_layer(
rank_mapping: Optional rank mapping for elastic expert parallelism.
rank_mapping: Optional rank mapping for elastic expert parallelism.
Returns:
Returns:
is_unchanged (np.ndarray): (num_local_experts,), True where expert
TransferMetadata: Metadata needed for completing remote weight transfers,
is left unchanged.
including is_unchanged and is_received_locally masks.
is_received_locally (np.ndarray): (num_local_experts,), True where expert
can be received locally.
RecvMetadata: Metadata needed for completing remote weight transfers.
"""
"""
ep_size
=
ep_group
.
size
()
ep_size
=
ep_group
.
size
()
if
rank_mapping
is
not
None
:
if
rank_mapping
is
not
None
:
...
@@ -502,7 +487,7 @@ async def transfer_layer(
...
@@ -502,7 +487,7 @@ async def transfer_layer(
old_layer_indices_np
=
old_layer_indices
.
cpu
().
numpy
()
old_layer_indices_np
=
old_layer_indices
.
cpu
().
numpy
()
new_layer_indices_np
=
new_layer_indices
.
cpu
().
numpy
()
new_layer_indices_np
=
new_layer_indices
.
cpu
().
numpy
()
is_unchanged
,
is_received_locally
,
recv_metadata
=
move_to_buffer
(
return
move_to_buffer
(
num_local_experts
=
num_local_physical_experts
,
num_local_experts
=
num_local_physical_experts
,
old_indices
=
old_layer_indices_np
,
old_indices
=
old_layer_indices_np
,
new_indices
=
new_layer_indices_np
,
new_indices
=
new_layer_indices_np
,
...
@@ -512,7 +497,6 @@ async def transfer_layer(
...
@@ -512,7 +497,6 @@ async def transfer_layer(
ep_rank
=
ep_group
.
rank
(),
ep_rank
=
ep_group
.
rank
(),
communicator
=
communicator
,
communicator
=
communicator
,
)
)
return
is_unchanged
,
is_received_locally
,
recv_metadata
def
rearrange_expert_weights_inplace
(
def
rearrange_expert_weights_inplace
(
...
@@ -605,7 +589,7 @@ def rearrange_expert_weights_inplace(
...
@@ -605,7 +589,7 @@ def rearrange_expert_weights_inplace(
new_global_expert_indices_cpu
=
new_global_expert_indices
.
cpu
().
numpy
()
new_global_expert_indices_cpu
=
new_global_expert_indices
.
cpu
().
numpy
()
for
layer_idx
in
range
(
num_moe_layers
):
for
layer_idx
in
range
(
num_moe_layers
):
is_unchanged
,
is_received_locally
,
recv
_metadata
=
move_to_buffer
(
transfer
_metadata
=
move_to_buffer
(
num_local_experts
=
num_local_physical_experts
,
num_local_experts
=
num_local_physical_experts
,
old_indices
=
old_global_expert_indices_cpu
[
layer_idx
],
old_indices
=
old_global_expert_indices_cpu
[
layer_idx
],
new_indices
=
new_global_expert_indices_cpu
[
layer_idx
],
new_indices
=
new_global_expert_indices_cpu
[
layer_idx
],
...
@@ -619,9 +603,7 @@ def rearrange_expert_weights_inplace(
...
@@ -619,9 +603,7 @@ def rearrange_expert_weights_inplace(
move_from_buffer
(
move_from_buffer
(
expert_weights
=
expert_weights
[
layer_idx
],
expert_weights
=
expert_weights
[
layer_idx
],
expert_weights_buffers
=
weights_buffer
,
expert_weights_buffers
=
weights_buffer
,
is_unchanged
=
is_unchanged
,
transfer_metadata
=
transfer_metadata
,
is_received_locally
=
is_received_locally
,
recv_metadata
=
recv_metadata
,
new_indices
=
new_global_expert_indices_cpu
[
layer_idx
],
new_indices
=
new_global_expert_indices_cpu
[
layer_idx
],
ep_rank
=
ep_rank
,
ep_rank
=
ep_rank
,
)
)
...
@@ -715,4 +697,4 @@ def _map_new_expert_indices_with_rank_mapping(
...
@@ -715,4 +697,4 @@ def _map_new_expert_indices_with_rank_mapping(
return
mapped_expert_indices
return
mapped_expert_indices
__all__
=
[
"transfer_layer"
,
"move_from_buffer"
,
"
Recv
Metadata"
]
__all__
=
[
"transfer_layer"
,
"move_from_buffer"
,
"
Transfer
Metadata"
]
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment