Commit fbe8ce65 authored by Myle Ott

Better support for various c10d API changes

parent 78071e0f
@@ -22,10 +22,11 @@ class AdaptiveLoss(FairseqCriterion):
     def __init__(self, args, task):
         super().__init__(args, task)
-        if not args.no_c10d:
+        if args.ddp_backend == 'c10d':
             raise Exception(
-                'AdaptiveLoss is not compatible with the c10d version of '
-                'DistributedDataParallel. Please add the `--no-c10d` flag.'
+                'AdaptiveLoss is not compatible with the c10d '
+                'version of DistributedDataParallel. Please use '
+                '`--ddp-backend=no_c10d` instead.'
             )

     def forward(self, model, sample, reduce=True):
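As the new message says, AdaptiveLoss (adaptive softmax) only works with the legacy DistributedDataParallel wrapper. A hedged illustration of the relevant flags, written as an argument list in the same style as the test update at the end of this diff; the data path and any omitted options are placeholders, not a complete fairseq command:

# Illustrative only: the two flags shown come from this diff; everything else is a placeholder.
train_flags = [
    'data-bin/placeholder_dataset',   # hypothetical preprocessed data directory
    '--criterion', 'adaptive_loss',   # incompatible with the c10d DDP wrapper
    '--ddp-backend', 'no_c10d',       # was: --no-c10d
]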
@@ -5,10 +5,11 @@
 # the root directory of this source tree. An additional grant of patent rights
 # can be found in the PATENTS file in the same directory.

+from collections import namedtuple
 import pickle

 import torch
-from torch import distributed
+from torch import distributed, nn
 from torch.distributed import group

 from fairseq import utils
@@ -18,33 +19,42 @@ def is_master(args):
     return args.distributed_rank == 0


-_use_c10d = [None]
+_use_c10d = [True]
+
+
+C10dStatus = namedtuple('C10dStatus', ['has_c10d', 'is_default'])
+
+if hasattr(nn.parallel, 'deprecated'):
+    c10d_status = C10dStatus(has_c10d=True, is_default=True)
+elif hasattr(nn.parallel, '_DistributedDataParallelC10d'):
+    c10d_status = C10dStatus(has_c10d=True, is_default=False)
+else:
+    c10d_status = C10dStatus(has_c10d=False, is_default=False)


 def distributed_init(args):
     if args.distributed_world_size == 1:
         raise ValueError('Cannot initialize distributed with distributed_world_size=1')

-    if _use_c10d[0] is None:
-        _use_c10d[0] = not args.no_c10d
-
-    if _use_c10d[0] and not hasattr(torch.nn.parallel, '_DistributedDataParallelC10d'):
+    if args.ddp_backend == 'no_c10d':
         _use_c10d[0] = False
-        print('WARNING: cannot find DistributedDataParallelC10d, '
-              'falling back to standard DistributedDataParallel')

     print('| distributed init (rank {}): {}'.format(
         args.distributed_rank, args.distributed_init_method), flush=True)

     if _use_c10d[0]:
-        distributed.c10d.init_process_group(
-            backend=args.distributed_backend,
-            init_method=args.distributed_init_method,
-            world_size=args.distributed_world_size,
-            rank=args.distributed_rank,
-        )
+        if c10d_status.is_default:
+            init_fn = distributed.init_process_group
+        else:
+            init_fn = distributed.c10d.init_process_group
     else:
-        distributed.init_process_group(
+        if c10d_status.is_default:
+            init_fn = distributed.deprecated.init_process_group
+        else:
+            init_fn = distributed.init_process_group
+
+    init_fn(
         backend=args.distributed_backend,
         init_method=args.distributed_init_method,
         world_size=args.distributed_world_size,
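The hunk above replaces the single no_c10d boolean with explicit detection of which c10d variant the installed PyTorch exposes, then picks the matching process-group initializer. As a reading aid, a standalone sketch of the same probe-and-select logic, mirroring the 2018-era layout this diff targets (the legacy API moves under `deprecated` once c10d becomes the default); this is not the fairseq code itself:

# Sketch only: restates the init_fn selection above for the two --ddp-backend values.
from collections import namedtuple
from torch import distributed, nn

C10dStatus = namedtuple('C10dStatus', ['has_c10d', 'is_default'])

def detect_c10d_status():
    if hasattr(nn.parallel, 'deprecated'):
        # newer PyTorch of that era: c10d is the default, legacy DDP lives under *.deprecated
        return C10dStatus(has_c10d=True, is_default=True)
    elif hasattr(nn.parallel, '_DistributedDataParallelC10d'):
        # transitional PyTorch: c10d exists but is not yet the default
        return C10dStatus(has_c10d=True, is_default=False)
    return C10dStatus(has_c10d=False, is_default=False)

def pick_init_fn(ddp_backend, status):
    if ddp_backend == 'c10d':
        return distributed.init_process_group if status.is_default \
            else distributed.c10d.init_process_group
    # 'no_c10d': use the legacy initializer, wherever this PyTorch version keeps it
    return distributed.deprecated.init_process_group if status.is_default \
        else distributed.init_process_group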
@@ -5,9 +5,10 @@
 # the root directory of this source tree. An additional grant of patent rights
 # can be found in the PATENTS file in the same directory.

-from torch.distributed import c10d
 from torch.nn import parallel

+from fairseq.distributed_utils import c10d_status
+
 from . import BaseFairseqModel
@@ -28,21 +29,36 @@ class DistributedFairseqModel(BaseFairseqModel):
     def __init__(self, args, model):
         super().__init__()
         assert isinstance(model, BaseFairseqModel)
-        if args.no_c10d:
-            self.ddp_model = parallel.DistributedDataParallel(
+        if args.ddp_backend == 'c10d':
+            if c10d_status.is_default:
+                ddp_class = parallel.DistributedDataParallel
+            elif c10d_status.has_c10d:
+                ddp_class = parallel._DistributedDataParallelC10d
+            else:
+                raise Exception(
+                    'Can\'t find c10d version of DistributedDataParallel. '
+                    'Please update PyTorch.'
+                )
+            self.ddp_model = ddp_class(
                 module=model,
                 device_ids=[args.device_id],
                 output_device=args.device_id,
                 broadcast_buffers=False,
+                bucket_cap_mb=args.bucket_cap_mb,
             )
-        else:
-            self.ddp_model = parallel._DistributedDataParallelC10d(
+        elif args.ddp_backend == 'no_c10d':
+            if c10d_status.is_default:
+                ddp_class = parallel.deprecated.DistributedDataParallel
+            else:
+                ddp_class = parallel.DistributedDataParallel
+            self.ddp_model = ddp_class(
                 module=model,
                 device_ids=[args.device_id],
                 output_device=args.device_id,
                 broadcast_buffers=False,
-                bucket_cap_mb=args.c10d_bucket_cap_mb,
             )
+        else:
+            raise ValueError('Unknown --ddp-backend: ' + args.ddp_backend)

     def __call__(self, *args, **kwargs):
         return self.ddp_model(*args, **kwargs)
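For reference, the `bucket_cap_mb` value forwarded on the c10d path is DistributedDataParallel's gradient-bucketing knob: gradients are coalesced into buckets of roughly that many megabytes before being all-reduced, so larger buckets mean fewer, bigger communication calls. A minimal sketch using today's torch.nn.parallel.DistributedDataParallel with the same constructor arguments as above; it assumes torch.distributed is already initialized, so it is shown as a helper rather than executed:

# Sketch only: not the fairseq wrapper, just the DDP construction it delegates to.
from torch.nn import parallel

def wrap_with_ddp(model, device_id, bucket_cap_mb=150):
    # Requires an initialized process group before this is called.
    return parallel.DistributedDataParallel(
        module=model,
        device_ids=[device_id],
        output_device=device_id,
        broadcast_buffers=False,
        bucket_cap_mb=bucket_cap_mb,  # size of each gradient all-reduce bucket, in MB
    )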
@@ -185,10 +185,11 @@ def add_distributed_training_args(parser):
                        help='port number (not required if using --distributed-init-method)')
     group.add_argument('--device-id', default=0, type=int,
                        help='which GPU to use (usually configured automatically)')
-    group.add_argument('--no-c10d', action='store_true',
-                       help='don\'t use c10d distributed backend')
-    group.add_argument('--c10d-bucket-cap-mb', default=150, type=int, metavar='MB',
-                       help='bucket size for c10d backend')
+    group.add_argument('--ddp-backend', default='c10d', type=str,
+                       choices=['c10d', 'no_c10d'],
+                       help='DistributedDataParallel backend')
+    group.add_argument('--bucket-cap-mb', default=150, type=int, metavar='MB',
+                       help='bucket size for reduction')
     return group
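The two renamed flags are ordinary argparse options; a self-contained sketch of just these arguments, outside fairseq's own parser, showing the defaults:

# Standalone argparse sketch mirroring the two new options (not the fairseq parser itself).
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--ddp-backend', default='c10d', type=str,
                    choices=['c10d', 'no_c10d'],
                    help='DistributedDataParallel backend')
parser.add_argument('--bucket-cap-mb', default=150, type=int, metavar='MB',
                    help='bucket size for reduction')

args = parser.parse_args(['--ddp-backend', 'no_c10d'])
assert args.ddp_backend == 'no_c10d' and args.bucket_cap_mb == 150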
@@ -265,7 +265,7 @@ class Trainer(object):

         return logging_output

-    def valid_step(self, sample):
+    def valid_step(self, sample, raise_oom=False):
         """Do forward pass in evaluation mode."""
         with torch.no_grad():
             self.model.eval()
@@ -277,9 +277,20 @@
             else:
                 ignore_results = False

-            _loss, sample_size, logging_output = self.task.get_loss(
-                self.model, self.criterion, sample,
-            )
+            try:
+                _loss, sample_size, logging_output = self.task.get_loss(
+                    self.model, self.criterion, sample,
+                )
+            except RuntimeError as e:
+                if 'out of memory' in str(e) and not raise_oom:
+                    print('| WARNING: ran out of memory, retrying batch')
+                    for p in self.model.parameters():
+                        if p.grad is not None:
+                            del p.grad  # free some memory
+                    torch.cuda.empty_cache()
+                    return self.valid_step(sample, raise_oom=True)
+                else:
+                    raise e

             if ignore_results:
                 logging_output, sample_size = {}, 0
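The change above lets validation survive a single out-of-memory error: it frees gradient buffers, empties the CUDA cache, and retries the batch once with raise_oom=True so that a second failure propagates normally. The same pattern in isolation; the helper name and signature here are illustrative, not fairseq API:

# Sketch of the retry-on-OOM pattern used in valid_step above.
import torch

def run_with_oom_retry(model, forward_fn, sample, raise_oom=False):
    try:
        return forward_fn(model, sample)
    except RuntimeError as e:
        if 'out of memory' in str(e) and not raise_oom:
            print('| WARNING: ran out of memory, retrying batch')
            for p in model.parameters():
                if p.grad is not None:
                    del p.grad            # drop stale gradient buffers
            torch.cuda.empty_cache()      # return cached blocks to the allocator
            # retry once; a second OOM propagates because raise_oom is now True
            return run_with_oom_retry(model, forward_fn, sample, raise_oom=True)
        raise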
@@ -292,7 +292,7 @@ def train_language_model(data_dir, arch):
             '--max-epoch', '1',
             '--no-progress-bar',
             '--distributed-world-size', '1',
-            '--no-c10d',
+            '--ddp-backend', 'no_c10d',
         ],
     )
     train.main(train_args)
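Existing scripts only need two flag renames to follow this commit. A small illustrative helper, not part of the commit, that rewrites an old argument list; it only handles the space-separated flag form:

# Illustrative migration helper: --no-c10d -> --ddp-backend no_c10d,
# --c10d-bucket-cap-mb -> --bucket-cap-mb.
def migrate_ddp_flags(argv):
    out = []
    for arg in argv:
        if arg == '--no-c10d':
            out.extend(['--ddp-backend', 'no_c10d'])
        elif arg == '--c10d-bucket-cap-mb':
            out.append('--bucket-cap-mb')
        else:
            out.append(arg)
    return out

assert migrate_ddp_flags(['--no-c10d']) == ['--ddp-backend', 'no_c10d']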