"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "a80f6892003e102f56bc956e9f8707b52c5d4487"
Commit b625d53d authored by Myle Ott, committed by Facebook Github Bot

Support DDP.no_sync context manager

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/671

Differential Revision: D15925248

fbshipit-source-id: 9eeea8a257929347e2458afdfc1def8dbb925a72
parent 6be5f07c
@@ -7,14 +7,16 @@
 
 """
 A modified version of the legacy DistributedDataParallel module that uses c10d
-communication primitives. This is necessary for models that have conditional
-computation (e.g., AdaptiveSoftmax) and which therefore do not work with the
-c10d version of DDP.
+communication primitives. This version is simpler than the latest PyTorch
+version and is useful for debugging. Notably it does not overlap gradient
+communication with the backward pass, which makes it slower but more robust
+than the PyTorch version.
 
-This version also supports the *accumulate_grads* feature, which allows faster
+This version also supports the *no_sync* context manager, which allows faster
 training with `--update-freq`.
 """
 
+from contextlib import contextmanager
 import copy
 
 import torch
@@ -74,6 +76,14 @@ class LegacyDistributedDataParallel(nn.Module):
         super().__setstate__(state)
         self._register_grad_hook()
 
+    @contextmanager
+    def no_sync(self):
+        """A context manager to disable gradient synchronization."""
+        old_accumulate_grads = self.accumulate_grads
+        self.accumulate_grads = True
+        yield
+        self.accumulate_grads = old_accumulate_grads
+
     def forward(self, *inputs, **kwargs):
         return self.module(*inputs, **kwargs)
...
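As a rough usage sketch (not part of this commit), the loop below shows how a caller is expected to use the new no_sync() context manager to accumulate gradients over several mini-batches, e.g. with `--update-freq=3`. The StubDDP wrapper, the toy linear model, and the random samples are invented for illustration; only the no_sync() contract of temporarily setting accumulate_grads mirrors the diff above.

from contextlib import contextmanager

import torch
import torch.nn as nn


class StubDDP(nn.Module):
    """Stand-in for LegacyDistributedDataParallel exposing the same no_sync API."""

    def __init__(self, module):
        super().__init__()
        self.module = module
        self.accumulate_grads = False

    @contextmanager
    def no_sync(self):
        # Mirror of the commit: flip the flag for the duration of the block.
        old_accumulate_grads = self.accumulate_grads
        self.accumulate_grads = True
        yield
        self.accumulate_grads = old_accumulate_grads

    def forward(self, x):
        return self.module(x)


model = StubDDP(nn.Linear(4, 1))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
samples = [torch.randn(8, 4) for _ in range(3)]  # e.g. --update-freq=3

optimizer.zero_grad()
for i, x in enumerate(samples):
    loss = model(x).sum()
    if i < len(samples) - 1:
        # Accumulate gradients locally; a real DDP wrapper skips the
        # all-reduce inside this block.
        with model.no_sync():
            loss.backward()
    else:
        # Final backward pass; gradients would be all-reduced here.
        loss.backward()
optimizer.step()

In a real multi-process run, only the final backward pass (outside no_sync) triggers the gradient all-reduce.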
@@ -10,6 +10,7 @@ Train a network across multiple GPUs.
 """
 
 from collections import OrderedDict
+import contextlib
 from itertools import chain
 import math
 import os
@@ -242,23 +243,28 @@ class Trainer(object):
             else:
                 ignore_grad = False
 
+            def maybe_no_sync():
+                """
+                Whenever *samples* contains more than one mini-batch, we
+                want to accumulate gradients locally and only call
+                all-reduce in the last backwards pass.
+                """
+                if (
+                    self.args.distributed_world_size > 1
+                    and hasattr(self.model, 'no_sync')
+                    and i < len(samples) - 1
+                ):
+                    return self.model.no_sync()
+                else:
+                    return contextlib.ExitStack()  # dummy contextmanager
+
             try:
-                if self.args.distributed_world_size > 1:
-                    # Whenever *samples* contains more than one mini-batch, we
-                    # want to accumulate gradients locally and only call
-                    # all-reduce in the last backwards pass. Currently the
-                    # *accumulate_grads* flag is only supported by
-                    # LegacyDistributedDataParallel.
-                    if i < len(samples) - 1:
-                        self.model.accumulate_grads = True
-                    else:
-                        self.model.accumulate_grads = False
-
-                # forward and backward
-                loss, sample_size, logging_output = self.task.train_step(
-                    sample, self.model, self.criterion, self.optimizer,
-                    ignore_grad
-                )
+                with maybe_no_sync():
+                    # forward and backward
+                    loss, sample_size, logging_output = self.task.train_step(
+                        sample, self.model, self.criterion, self.optimizer,
+                        ignore_grad
+                    )
 
                 if not ignore_grad:
                     logging_outputs.append(logging_output)
...
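One detail worth calling out, illustrated below with a standalone sketch that is not part of the commit: an empty contextlib.ExitStack() behaves as a no-op context manager, which is why it works as the fallback when gradient synchronization should not be skipped. The FakeModel class and the simplified maybe_no_sync() signature are made up for this example.

import contextlib


class FakeModel:
    """Toy stand-in with a no_sync() context manager, for illustration only."""

    def __init__(self):
        self.sync = True

    @contextlib.contextmanager
    def no_sync(self):
        self.sync = False
        try:
            yield
        finally:
            self.sync = True


def maybe_no_sync(model, is_last_batch, world_size):
    # Same decision as in the trainer: skip gradient sync only on non-final
    # mini-batches of a multi-process run, and only if the wrapper supports it.
    if world_size > 1 and hasattr(model, 'no_sync') and not is_last_batch:
        return model.no_sync()
    return contextlib.ExitStack()  # empty ExitStack == no-op context manager


model = FakeModel()
with maybe_no_sync(model, is_last_batch=False, world_size=2):
    assert not model.sync  # sync disabled inside the block
with maybe_no_sync(model, is_last_batch=True, world_size=2):
    assert model.sync  # dummy context manager leaves sync enabled

Because both branches return a context manager, the same `with maybe_no_sync():` statement covers both cases and the forward/backward code is not duplicated.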