Unverified Commit 98457c04 authored by Huaixin Chang's avatar Huaixin Chang Committed by GitHub
Browse files

[Bugfix] Avoid unnecessary reduce-scatter call in prepare_mlp (#9169)

parent 0fc8bf2c
...@@ -292,6 +292,10 @@ def _dp_gather_via_all_gather( ...@@ -292,6 +292,10 @@ def _dp_gather_via_all_gather(
forward_batch: ForwardBatch, forward_batch: ForwardBatch,
is_partial: bool, is_partial: bool,
): ):
if get_attention_tp_size() == 1:
get_tp_group().all_gather_into_tensor(global_tokens, local_tokens)
return
if not is_partial: if not is_partial:
if get_attention_tp_rank() != 0: if get_attention_tp_rank() != 0:
local_tokens.fill_(0) local_tokens.fill_(0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment