Unverified Commit d00997e6 authored by Wing Lian, committed by GitHub

ddp fixes for training (#22874)

DDP fixes for StableLM training: wrap the model in DistributedDataParallel only when it has trainable parameters, and guard the no_sync() call during gradient accumulation so it is only used when the model is actually DDP-wrapped.
parent eddf9eec
@@ -1565,12 +1565,13 @@ class Trainer:
                 kwargs["bucket_cap_mb"] = self.args.ddp_bucket_cap_mb
             if is_torch_neuroncore_available():
                 return model
-            model = nn.parallel.DistributedDataParallel(
-                model,
-                device_ids=[self.args.local_rank] if self.args._n_gpu != 0 else None,
-                output_device=self.args.local_rank if self.args._n_gpu != 0 else None,
-                **kwargs,
-            )
+            if any(p.requires_grad for p in model.parameters()):
+                model = nn.parallel.DistributedDataParallel(
+                    model,
+                    device_ids=[self.args.local_rank] if self.args._n_gpu != 0 else None,
+                    output_device=self.args.local_rank if self.args._n_gpu != 0 else None,
+                    **kwargs,
+                )
         # torch.compile() needs to be called after wrapping the model with FSDP or DDP
         # to ensure that it accounts for the graph breaks required by those wrappers
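The new guard matters because DistributedDataParallel raises an error when asked to wrap a module in which no parameter requires gradients (for example, a fully frozen base model), since there is nothing for it to synchronize. Below is a minimal sketch of the same pattern outside the Trainer; it assumes torch.distributed is already initialized, and maybe_wrap_ddp, local_rank, and ddp_kwargs are illustrative names, not Transformers API.

import torch.nn as nn

def maybe_wrap_ddp(model: nn.Module, local_rank: int, **ddp_kwargs) -> nn.Module:
    # Only wrap when at least one parameter still receives gradients;
    # a fully frozen module has no gradients for DDP to all-reduce and
    # PyTorch refuses to wrap it.
    if any(p.requires_grad for p in model.parameters()):
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[local_rank],
            output_device=local_rank,
            **ddp_kwargs,
        )
    return model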
@@ -1920,6 +1921,7 @@ class Trainer:
                     (total_batched_samples % args.gradient_accumulation_steps != 0)
                     and args.parallel_mode == ParallelMode.DISTRIBUTED
                     and args._no_sync_in_gradient_accumulation
+                    and hasattr(model, "no_sync")
                 ):
                     # Avoid unnecessary DDP synchronization since there will be no backward pass on this example.
                     with model.no_sync():
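The added hasattr(model, "no_sync") check is needed because, after the change above, a model with no trainable parameters is never wrapped in DDP and therefore exposes no no_sync() method; calling it on a plain nn.Module would raise AttributeError. A small sketch of the same guard, falling back to contextlib.nullcontext when no_sync is unavailable (grad_sync_context and is_accumulation_step are illustrative names):

import contextlib
import torch.nn as nn

def grad_sync_context(model: nn.Module, is_accumulation_step: bool):
    # Skip the gradient all-reduce on intermediate accumulation steps,
    # but only when the model is DDP-wrapped and actually provides no_sync();
    # otherwise use a no-op context manager.
    if is_accumulation_step and hasattr(model, "no_sync"):
        return model.no_sync()
    return contextlib.nullcontext()

# Usage inside a training loop:
#     with grad_sync_context(model, step % accum_steps != 0):
#         loss.backward()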