Commit a47baff1 authored by Cheng Wan, committed by GitHub

[hotfix] use the original implementation in 8785 (#8994)

parent fd7e15b7
...
@@ -553,6 +553,10 @@ class CommunicateSummableTensorPairFn:
                 forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
                 hidden_states,
             )
+            if hidden_states.data_ptr() is global_hidden_states.data_ptr():
+                hidden_states = torch.empty_like(hidden_states)
             if allow_reduce_scatter and forward_batch.dp_padding_mode.is_max_len():
                 # When using padding, all_reduce is skipped after MLP and MOE and reduce scatter is used here instead.
                 dp_reduce_scatter_tensor(hidden_states, global_hidden_states)
...
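
The guard added in this hunk gives the output tensor a fresh buffer when it would otherwise share storage with the input of the collective. A minimal sketch of that idea without PyTorch follows. `FakeTensor`, `empty_like`, and `ensure_distinct_output` are hypothetical names introduced only for illustration and are not part of the commit. Since `torch.Tensor.data_ptr()` returns a Python `int`, the sketch compares the two addresses by value with `==`; `is` tests object identity, which Python does not guarantee to agree with equality for integers.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class FakeTensor:
    """Hypothetical stand-in for torch.Tensor; it only models buffer identity."""

    storage: List[float] = field(default_factory=list)

    def data_ptr(self) -> int:
        # Stand-in for torch.Tensor.data_ptr(): an integer identifying the buffer.
        return id(self.storage)


def empty_like(t: FakeTensor) -> FakeTensor:
    # Stand-in for torch.empty_like: a fresh buffer of the same length.
    # (The real call leaves the contents uninitialized; zeros are used here.)
    return FakeTensor([0.0] * len(t.storage))


def ensure_distinct_output(out: FakeTensor, src: FakeTensor) -> FakeTensor:
    """Return an output tensor that does not share its buffer with src."""
    # Compare the integer addresses by value, not by object identity.
    if out.data_ptr() == src.data_ptr():
        return empty_like(out)
    return out


shared = [1.0, 2.0, 3.0]
src = FakeTensor(shared)
aliased_out = FakeTensor(shared)            # same underlying buffer as src
separate_out = FakeTensor([0.0, 0.0, 0.0])  # already a distinct buffer

fixed = ensure_distinct_output(aliased_out, src)
assert fixed.data_ptr() != src.data_ptr()
assert ensure_distinct_output(separate_out, src) is separate_out
```

In the sketch, the aliased output is swapped for a new buffer while the already distinct output is returned unchanged, which is the behavior the guard in the hunk is aiming for before `dp_reduce_scatter_tensor` writes into `hidden_states`.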