"...text-generation-inference.git" did not exist on "92c1ecd0089d329d4b5a5e2b9327da828b888d34"
Commit b625d53d authored by Myle Ott, committed by Facebook Github Bot

Support DDP.no_sync context manager

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/671

Differential Revision: D15925248

fbshipit-source-id: 9eeea8a257929347e2458afdfc1def8dbb925a72
parent 6be5f07c
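
For readers coming to this change cold: *no_sync* here mirrors the context manager of the same name on torch.nn.parallel.DistributedDataParallel. Inside the context, backward passes only accumulate gradients locally and the gradient all-reduce is skipped, so an `--update-freq` cycle pays the communication cost once, on its final backward. A rough sketch of that usage pattern follows; model, batches, compute_loss, and optimizer are placeholder names, not part of this commit:

# Accumulate gradients over several mini-batches, all-reducing only once.
for i, batch in enumerate(batches):
    if i < len(batches) - 1:
        with model.no_sync():
            # Backward pass without gradient synchronization.
            compute_loss(model, batch).backward()
    else:
        # The final backward pass triggers the all-reduce as usual.
        compute_loss(model, batch).backward()
optimizer.step()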
@@ -7,14 +7,16 @@

 """
 A modified version of the legacy DistributedDataParallel module that uses c10d
-communication primitives. This is necessary for models that have conditional
-computation (e.g., AdaptiveSoftmax) and which therefore do not work with the
-c10d version of DDP.
+communication primitives. This version is simpler than the latest PyTorch
+version and is useful for debugging. Notably it does not overlap gradient
+communication with the backward pass, which makes it slower but more robust
+than the PyTorch version.

-This version also supports the *accumulate_grads* feature, which allows faster
+This version also supports the *no_sync* context manager, which allows faster
 training with `--update-freq`.
 """

+from contextlib import contextmanager
 import copy

 import torch
@@ -74,6 +76,14 @@ class LegacyDistributedDataParallel(nn.Module):
         super().__setstate__(state)
         self._register_grad_hook()

+    @contextmanager
+    def no_sync(self):
+        """A context manager to disable gradient synchronization."""
+        old_accumulate_grads = self.accumulate_grads
+        self.accumulate_grads = True
+        yield
+        self.accumulate_grads = old_accumulate_grads
+
     def forward(self, *inputs, **kwargs):
         return self.module(*inputs, **kwargs)
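
As a standalone illustration of the mechanism added above (toy code, not from this commit): the @contextmanager-decorated no_sync() method sets the module's existing accumulate_grads flag for the duration of the block, which makes the backward hooks skip the all-reduce, and restores the previous value on exit:

from contextlib import contextmanager

class ToySyncedModule:
    """Toy stand-in for the flag handling in LegacyDistributedDataParallel."""

    def __init__(self):
        # When True, the gradient all-reduce in the backward hook is skipped.
        self.accumulate_grads = False

    @contextmanager
    def no_sync(self):
        """Temporarily disable gradient synchronization."""
        old_accumulate_grads = self.accumulate_grads
        self.accumulate_grads = True
        yield
        self.accumulate_grads = old_accumulate_grads

m = ToySyncedModule()
with m.no_sync():
    assert m.accumulate_grads      # synchronization suppressed inside the block
assert not m.accumulate_grads      # previous value restored on exit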
@@ -10,6 +10,7 @@ Train a network across multiple GPUs.
 """

 from collections import OrderedDict
+import contextlib
 from itertools import chain
 import math
 import os
@@ -242,23 +243,28 @@ class Trainer(object):
             else:
                 ignore_grad = False

+            def maybe_no_sync():
+                """
+                Whenever *samples* contains more than one mini-batch, we
+                want to accumulate gradients locally and only call
+                all-reduce in the last backwards pass.
+                """
+                if (
+                    self.args.distributed_world_size > 1
+                    and hasattr(self.model, 'no_sync')
+                    and i < len(samples) - 1
+                ):
+                    return self.model.no_sync()
+                else:
+                    return contextlib.ExitStack()  # dummy contextmanager
+
             try:
-                if self.args.distributed_world_size > 1:
-                    # Whenever *samples* contains more than one mini-batch, we
-                    # want to accumulate gradients locally and only call
-                    # all-reduce in the last backwards pass. Currently the
-                    # *accumulate_grads* flag is only supported by
-                    # LegacyDistributedDataParallel.
-                    if i < len(samples) - 1:
-                        self.model.accumulate_grads = True
-                    else:
-                        self.model.accumulate_grads = False
-
-                # forward and backward
-                loss, sample_size, logging_output = self.task.train_step(
-                    sample, self.model, self.criterion, self.optimizer,
-                    ignore_grad
-                )
+                with maybe_no_sync():
+                    # forward and backward
+                    loss, sample_size, logging_output = self.task.train_step(
+                        sample, self.model, self.criterion, self.optimizer,
+                        ignore_grad
+                    )

                 if not ignore_grad:
                     logging_outputs.append(logging_output)
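
One detail of the trainer change worth spelling out: when synchronization should happen normally, maybe_no_sync() returns a contextlib.ExitStack() with nothing registered on it, which acts as a no-op context manager, so the same `with` statement covers both the accumulate and the synchronize cases. A minimal sketch of that trick, with maybe_sync_ctx as a hypothetical helper mirroring the closure above:

import contextlib

def maybe_sync_ctx(model, world_size, is_last_update):
    # Return a context that suppresses gradient all-reduce, or a no-op one.
    if world_size > 1 and hasattr(model, 'no_sync') and not is_last_update:
        return model.no_sync()
    # An empty ExitStack is inert: entering and exiting it has no effect.
    return contextlib.ExitStack()

with maybe_sync_ctx(model=object(), world_size=1, is_last_update=False):
    pass  # forward/backward would run here with all-reduce enabled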