Unverified commit 6beef12b, authored by Sage Moore, committed by GitHub
Browse files

[EPLB][Cleanup] Remove `is_async_enabled` from `EplbModelState` (#32050)


Signed-off-by: Sage Moore <sage@neuralmagic.com>
parent ab74b2a2
...@@ -30,6 +30,7 @@ def start_async_worker( ...@@ -30,6 +30,7 @@ def start_async_worker(
ep_group = get_ep_group().device_group ep_group = get_ep_group().device_group
rank = ep_group.rank() rank = ep_group.rank()
device_index = state.cuda_device_index device_index = state.cuda_device_index
assert state.is_async
def thread_target() -> None: def thread_target() -> None:
assert device_index is not None assert device_index is not None
...@@ -42,9 +43,9 @@ def start_async_worker( ...@@ -42,9 +43,9 @@ def start_async_worker(
transfer_run_periodically( transfer_run_periodically(
state=state, state=state,
ep_group=ep_group, ep_group=ep_group,
cuda_stream=cuda_stream,
is_profile=is_profile, is_profile=is_profile,
rank_mapping=rank_mapping, rank_mapping=rank_mapping,
cuda_stream=cuda_stream,
) )
) )
except Exception as exc: # pragma: no cover - diagnostic path except Exception as exc: # pragma: no cover - diagnostic path
...@@ -60,17 +61,16 @@ def start_async_worker( ...@@ -60,17 +61,16 @@ def start_async_worker(
async def transfer_run_periodically( async def transfer_run_periodically(
state: "EplbState", state: "EplbState",
ep_group: ProcessGroup, ep_group: ProcessGroup,
cuda_stream: torch.cuda.Stream,
is_profile: bool = False, is_profile: bool = False,
rank_mapping: dict[int, int] | None = None, rank_mapping: dict[int, int] | None = None,
cuda_stream: torch.cuda.Stream = None,
) -> None: ) -> None:
while True: while True:
await asyncio.to_thread(state.rearrange_event.wait) await asyncio.to_thread(state.rearrange_event.wait)
logger.info("async worker woke up for EPLB transfer") logger.info("async worker woke up for EPLB transfer")
assert state.is_async
for model_state in state.model_states.values(): for model_state in state.model_states.values():
if not model_state.is_async_enabled:
continue
current_num_layers = model_state.model.num_moe_layers current_num_layers = model_state.model.num_moe_layers
while ( while (
model_state.rebalanced model_state.rebalanced
......
...@@ -182,10 +182,6 @@ class EplbModelState: ...@@ -182,10 +182,6 @@ class EplbModelState:
""" """
intermediate variable between `move_to_buffer` and `move_to_workspace`. intermediate variable between `move_to_buffer` and `move_to_workspace`.
""" """
is_async_enabled: bool
"""
The flag indicates whether the EPLB is running in async mode.
"""
cuda_device_index: int | None cuda_device_index: int | None
""" """
CUDA device index for the async EPLB worker thread. CUDA device index for the async EPLB worker thread.
...@@ -518,7 +514,6 @@ class EplbState: ...@@ -518,7 +514,6 @@ class EplbState:
recv_expert_ids=np.array([]), recv_expert_ids=np.array([]),
recv_dst_rows=np.array([]), recv_dst_rows=np.array([]),
), ),
is_async_enabled=self.is_async,
cuda_device_index=self.cuda_device_index, cuda_device_index=self.cuda_device_index,
new_physical_to_logical_map=new_physical_to_logical_map, new_physical_to_logical_map=new_physical_to_logical_map,
new_logical_to_physical_map=new_logical_to_physical_map, new_logical_to_physical_map=new_logical_to_physical_map,
...@@ -630,19 +625,12 @@ class EplbState: ...@@ -630,19 +625,12 @@ class EplbState:
if self.is_async: if self.is_async:
for eplb_model_state in self.model_states.values(): for eplb_model_state in self.model_states.values():
if not eplb_model_state.is_async_enabled:
continue
all_ranks_buffer_ready = False all_ranks_buffer_ready = False
if eplb_model_state.pending_global_ready_check: if eplb_model_state.pending_global_ready_check:
all_ranks_buffer_ready = self._all_ranks_buffer_ready( all_ranks_buffer_ready = self._all_ranks_buffer_ready(
eplb_model_state eplb_model_state
) )
if ( if eplb_model_state.ep_buffer_ready and all_ranks_buffer_ready:
eplb_model_state.is_async_enabled
and eplb_model_state.ep_buffer_ready
and all_ranks_buffer_ready
):
self.move_to_workspace( self.move_to_workspace(
model_state=eplb_model_state, model_state=eplb_model_state,
ep_group=ep_group, ep_group=ep_group,
...@@ -664,8 +652,8 @@ class EplbState: ...@@ -664,8 +652,8 @@ class EplbState:
) )
if self.expert_rearrangement_step >= self.expert_rearrangement_step_interval: if self.expert_rearrangement_step >= self.expert_rearrangement_step_interval:
if any( if self.is_async and any(
eplb_model_state.is_async_enabled and eplb_model_state.rebalanced eplb_model_state.rebalanced
for eplb_model_state in self.model_states.values() for eplb_model_state in self.model_states.values()
): ):
# Still performing asynchronous rearrangement # Still performing asynchronous rearrangement
...@@ -822,7 +810,7 @@ class EplbState: ...@@ -822,7 +810,7 @@ class EplbState:
eplb_model_state.physical_to_logical_map, eplb_model_state.physical_to_logical_map,
) )
if not eplb_model_state.is_async_enabled or is_profile: if not self.is_async or is_profile:
# Update expert weights # Update expert weights
rearrange_expert_weights_inplace( rearrange_expert_weights_inplace(
eplb_model_state.physical_to_logical_map, eplb_model_state.physical_to_logical_map,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment