Commit d09135fb (unverified), authored by Ilya Markov, committed by GitHub
Browse files

[BugFix] Async Eplb fix potential race condition (#32881)


Signed-off-by: ilmarkov <markovilya197@gmail.com>
parent 8688c3d4
...@@ -86,6 +86,12 @@ async def transfer_run_periodically( ...@@ -86,6 +86,12 @@ async def transfer_run_periodically(
if model_state.layer_to_transfer >= current_num_layers: if model_state.layer_to_transfer >= current_num_layers:
break break
# Wait for the main thread to finish consuming the buffer
# before overwriting it
if model_state.buffer_consumed_event is not None:
cuda_stream.wait_event(model_state.buffer_consumed_event)
model_state.buffer_consumed_event = None
( (
model_state.is_unchanged, model_state.is_unchanged,
model_state.is_received_locally, model_state.is_received_locally,
......
...@@ -151,6 +151,11 @@ class EplbModelState: ...@@ -151,6 +151,11 @@ class EplbModelState:
CUDA event recorded when the async worker finishes filling the buffer. CUDA event recorded when the async worker finishes filling the buffer.
The main thread waits on this before consuming the buffer. The main thread waits on this before consuming the buffer.
""" """
buffer_consumed_event: torch.cuda.Event | None
"""
CUDA event recorded after the main thread finishes consuming the buffer.
The async worker waits on this before writing to the buffer again.
"""
ep_buffer_ready: int ep_buffer_ready: int
""" """
The flag indicates whether the expert buffer is ready for transfer. The flag indicates whether the expert buffer is ready for transfer.
...@@ -502,6 +507,7 @@ class EplbState: ...@@ -502,6 +507,7 @@ class EplbState:
expert_buffer=expert_buffer, expert_buffer=expert_buffer,
buffer_lock=threading.Lock(), buffer_lock=threading.Lock(),
buffer_ready_event=None, buffer_ready_event=None,
buffer_consumed_event=None,
ep_buffer_ready=0, ep_buffer_ready=0,
layer_to_transfer=0, layer_to_transfer=0,
rebalanced=False, rebalanced=False,
...@@ -1012,6 +1018,12 @@ class EplbState: ...@@ -1012,6 +1018,12 @@ class EplbState:
new_indices=new_indices, new_indices=new_indices,
ep_rank=ep_group.rank(), ep_rank=ep_group.rank(),
) )
# Record event after consuming buffer to signal async thread
# that it's safe to overwrite the buffer
consumed_event = torch.cuda.Event()
consumed_event.record()
model_state.buffer_consumed_event = consumed_event
transferred_layer = model_state.layer_to_transfer transferred_layer = model_state.layer_to_transfer
self._update_layer_mapping_from_new(model_state, transferred_layer) self._update_layer_mapping_from_new(model_state, transferred_layer)
# After the main thread consumes, advance layer_to_transfer # After the main thread consumes, advance layer_to_transfer
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment