"...data/git@developer.sourcefind.cn:wangsen/paddle_dbnet.git" did not exist on "e92ed072ee3ebce116e2ec7f55ac4f09e9dd6b70"
Unverified commit 8cb4ecca authored by Noam Wies, committed by GitHub

Avoid unnecessary DDP synchronization when gradient_accumulation_steps > 1 (#7742)

* use DDP no_sync when possible

* fix is_nlp_available addition mistake

* reformat trainer.py

* reformat trainer.py

* drop support for pytorch < 1.2

* return support for pytorch < 1.2
parent 52f7d743
@@ -101,6 +101,11 @@ else:
     _use_native_amp = True
     from torch.cuda.amp import autocast
 
+if version.parse(torch.__version__) < version.parse("1.2"):
+    _use_ddp_no_sync = False
+else:
+    _use_ddp_no_sync = True
+
 if is_datasets_available():
     import datasets
 
@@ -687,7 +692,15 @@ class Trainer:
                 if (step + 1) % self.args.gradient_accumulation_steps == 0:
                     self.control = self.callback_handler.on_step_begin(self.args, self.state, self.control)
 
-                tr_loss += self.training_step(model, inputs)
+                if (
+                    ((step + 1) % self.args.gradient_accumulation_steps != 0)
+                    and self.args.local_rank != -1
+                    and _use_ddp_no_sync
+                ):
+                    with model.no_sync():
+                        tr_loss += self.training_step(model, inputs)
+                else:
+                    tr_loss += self.training_step(model, inputs)
                 self._total_flos += self.floating_point_ops(inputs)
 
                 if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
...
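For readers outside the Trainer diff context, here is a minimal, self-contained sketch of the pattern this change applies: on intermediate gradient-accumulation steps, DDP's no_sync() context manager skips the gradient all-reduce, and synchronization happens only on the micro-step that precedes the optimizer update. The toy model, data, backend choice, and accumulation_steps value below are illustrative assumptions, not part of the commit; launching with torchrun (or otherwise setting the process-group environment variables) is also assumed.

```python
# Minimal sketch of DDP gradient accumulation with no_sync().
# Assumes launch via `torchrun --nproc_per_node=N this_script.py` so the
# default process-group environment variables are already set.
# The toy model, data, backend, and accumulation_steps are placeholders.
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

dist.init_process_group(backend="gloo")  # backend choice is an assumption

model = DDP(torch.nn.Linear(10, 1))
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = torch.nn.MSELoss()
accumulation_steps = 4

for step in range(100):
    inputs, targets = torch.randn(8, 10), torch.randn(8, 1)
    if (step + 1) % accumulation_steps != 0:
        # Intermediate micro-step: accumulate local gradients without
        # triggering DDP's gradient all-reduce.
        with model.no_sync():
            loss_fn(model(inputs), targets).backward()
    else:
        # Final micro-step of the accumulation window: let DDP synchronize
        # gradients across ranks, then apply the optimizer update.
        loss_fn(model(inputs), targets).backward()
        optimizer.step()
        optimizer.zero_grad()

dist.destroy_process_group()
```

The version gate in the diff mirrors this: _use_ddp_no_sync is set to False for PyTorch versions older than 1.2, where the commit treats no_sync() as unavailable, so those versions keep the original always-synchronizing code path.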