Unverified commit bfd57ff3, authored by Min Xu, committed by GitHub

[fix] SDP syncing buffers during gradient accumulation (#1075)



- Fixes from Benjamin.

Original commit msg:
  - Fixes #1041. I just had a minute or two, hoping that it's enough :)
Co-authored-by: Min Xu <min.xu.public@gmail.com>
parent abfa7193
@@ -218,7 +218,7 @@ class ShardedDataParallel(nn.Module):
         if needs_setup:
             self.refresh_trainable()
 
-        if self._enable_broadcast_buffers:
+        if self._enable_broadcast_buffers and not self._should_accumulate_grads:
             # NCCL communications are on a different stream, needs to be blocking
             # for the subsequent FW to be correct
             self.sync_buffers(blocking=True)
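For context, here is a minimal sketch of the code path this hunk changes (not part of the patch): during gradient accumulation, forward passes run inside `ShardedDataParallel.no_sync()`, which sets `_should_accumulate_grads`, so with this fix the per-forward buffer broadcast is skipped until the final, synchronized step. The toy model, data, and step counts below are placeholders, and it assumes `torch.distributed` has already been initialized.

```python
import contextlib

import torch
from fairscale.nn.data_parallel import ShardedDataParallel
from fairscale.optim import OSS

# Placeholder model/optimizer; assumes dist.init_process_group() was already called.
model = torch.nn.Linear(2, 4).cuda()
optimizer = OSS(model.parameters(), optim=torch.optim.SGD, lr=0.1)
sdp_model = ShardedDataParallel(model, optimizer, broadcast_buffers=True)

accumulation_steps = 4
for step, batch in enumerate(batches):  # `batches` is a hypothetical iterable of input tensors
    last_micro_batch = (step + 1) % accumulation_steps == 0
    # Inside no_sync() gradients are accumulated locally; with this patch,
    # buffers are no longer broadcast on these intermediate forwards either.
    maybe_no_sync = contextlib.nullcontext() if last_micro_batch else sdp_model.no_sync()
    with maybe_no_sync:
        loss = sdp_model(batch).abs().sum()
        loss.backward()
    if last_micro_batch:
        # Only the final micro-batch runs a synchronized forward/backward,
        # after which the sharded optimizer steps.
        optimizer.step()
        optimizer.zero_grad()
```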
@@ -112,6 +112,14 @@ def run_one_step(
     with ddp_model.no_sync() if grad_accumulation else suppress():
         input_tensor = torch.rand((64, 2)).to(device)
         loss = ddp_model(input_tensor).abs().sum()
+
+        # If grad_accumulation, we can check after the forward that the models are different
+        # (not synced)
+        if grad_accumulation:
+            check_same_models_across_ranks(
+                ddp_model, dist.group.WORLD, params_should_be_equal=False, check_broadcast_buffers=True
+            )
+
         loss.backward()
 
     return loss
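The assertion relies on `check_same_models_across_ranks` from fairscale's test utilities. As a purely illustrative stand-in (not the real helper), a cross-rank buffer comparison could look roughly like the sketch below; it assumes an initialized process group whose backend can reduce the CPU flag tensor shown (e.g. gloo).

```python
import torch
import torch.distributed as dist

def buffers_match_rank0(model: torch.nn.Module, group=None) -> bool:
    """Return True on every rank iff all of the model's buffers equal rank 0's copies.

    Simplified, hypothetical equivalent of the buffer part of
    check_same_models_across_ranks; assumes a backend that can reduce
    the CPU flag tensor below (e.g. gloo).
    """
    all_equal = torch.ones(1)
    for buf in model.buffers():
        reference = buf.detach().clone()
        dist.broadcast(reference, src=0, group=group)  # fetch rank 0's value
        if not torch.equal(buf.detach(), reference):
            all_equal.zero_()
    # Make every rank agree on the outcome.
    dist.all_reduce(all_equal, op=dist.ReduceOp.MIN, group=group)
    return bool(all_equal.item())
```

In the test above, `params_should_be_equal=False` encodes the expectation that, inside `no_sync()`, the ranks' copies have drifted apart because neither gradients nor buffers were synchronized after the forward.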