Commit fbe8ce65 authored by Myle Ott

Better support for various c10d API changes

parent 78071e0f
@@ -22,10 +22,11 @@ class AdaptiveLoss(FairseqCriterion):

     def __init__(self, args, task):
         super().__init__(args, task)
-        if not args.no_c10d:
+        if args.ddp_backend == 'c10d':
             raise Exception(
-                'AdaptiveLoss is not compatible with the c10d version of '
-                'DistributedDataParallel. Please add the `--no-c10d` flag.'
+                'AdaptiveLoss is not compatible with the c10d '
+                'version of DistributedDataParallel. Please use '
+                '`--ddp-backend=no_c10d` instead.'
             )

     def forward(self, model, sample, reduce=True):
...
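As a quick illustration of the new guard, here is a minimal, self-contained sketch (the `AdaptiveLossSketch` class and the bare `Namespace` args are illustrative stand-ins, not fairseq code): construction fails fast under the incompatible c10d backend, and the error message names the replacement flag.

from argparse import Namespace


class AdaptiveLossSketch:
    """Stand-in criterion that refuses to run under the c10d DDP backend."""

    def __init__(self, args):
        if args.ddp_backend == 'c10d':
            raise Exception(
                'AdaptiveLoss is not compatible with the c10d '
                'version of DistributedDataParallel. Please use '
                '`--ddp-backend=no_c10d` instead.'
            )
        self.args = args


if __name__ == '__main__':
    AdaptiveLossSketch(Namespace(ddp_backend='no_c10d'))   # constructs fine
    try:
        AdaptiveLossSketch(Namespace(ddp_backend='c10d'))  # raises
    except Exception as e:
        print(e)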
@@ -5,10 +5,11 @@
 # the root directory of this source tree. An additional grant of patent rights
 # can be found in the PATENTS file in the same directory.

+from collections import namedtuple
 import pickle

 import torch
-from torch import distributed
+from torch import distributed, nn
 from torch.distributed import group

 from fairseq import utils
@@ -18,33 +19,42 @@ def is_master(args):
     return args.distributed_rank == 0


-_use_c10d = [None]
+_use_c10d = [True]
+
+C10dStatus = namedtuple('C10dStatus', ['has_c10d', 'is_default'])
+
+if hasattr(nn.parallel, 'deprecated'):
+    c10d_status = C10dStatus(has_c10d=True, is_default=True)
+elif hasattr(nn.parallel, '_DistributedDataParallelC10d'):
+    c10d_status = C10dStatus(has_c10d=True, is_default=False)
+else:
+    c10d_status = C10dStatus(has_c10d=False, is_default=False)


 def distributed_init(args):
     if args.distributed_world_size == 1:
         raise ValueError('Cannot initialize distributed with distributed_world_size=1')

-    if _use_c10d[0] is None:
-        _use_c10d[0] = not args.no_c10d
-
-    if _use_c10d[0] and not hasattr(torch.nn.parallel, '_DistributedDataParallelC10d'):
+    if args.ddp_backend == 'no_c10d':
         _use_c10d[0] = False
-        print('WARNING: cannot find DistributedDataParallelC10d, '
-              'falling back to standard DistributedDataParallel')

     print('| distributed init (rank {}): {}'.format(
         args.distributed_rank, args.distributed_init_method), flush=True)

     if _use_c10d[0]:
-        distributed.c10d.init_process_group(
-            backend=args.distributed_backend,
-            init_method=args.distributed_init_method,
-            world_size=args.distributed_world_size,
-            rank=args.distributed_rank,
-        )
+        if c10d_status.is_default:
+            init_fn = distributed.init_process_group
+        else:
+            init_fn = distributed.c10d.init_process_group
     else:
-        distributed.init_process_group(
-            backend=args.distributed_backend,
-            init_method=args.distributed_init_method,
-            world_size=args.distributed_world_size,
+        if c10d_status.is_default:
+            init_fn = distributed.deprecated.init_process_group
+        else:
+            init_fn = distributed.init_process_group
+
+    init_fn(
+        backend=args.distributed_backend,
+        init_method=args.distributed_init_method,
+        world_size=args.distributed_world_size,
...
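The interesting part of this change is the feature probe: rather than checking PyTorch version numbers, the module inspects `torch.nn.parallel` to decide whether the c10d implementation is still the experimental variant or has become the default (with the old implementation moved under `deprecated`). Below is a standalone sketch of that idea, assuming a 2018-era PyTorch build; `pick_init_fn` is an illustrative helper name, not fairseq API.

from collections import namedtuple

from torch import distributed, nn

C10dStatus = namedtuple('C10dStatus', ['has_c10d', 'is_default'])

# Probe the installed PyTorch: if `nn.parallel.deprecated` exists, c10d is the
# default DDP; if only `_DistributedDataParallelC10d` exists, c10d is still the
# experimental variant; otherwise this build has no c10d at all.
if hasattr(nn.parallel, 'deprecated'):
    status = C10dStatus(has_c10d=True, is_default=True)
elif hasattr(nn.parallel, '_DistributedDataParallelC10d'):
    status = C10dStatus(has_c10d=True, is_default=False)
else:
    status = C10dStatus(has_c10d=False, is_default=False)


def pick_init_fn(use_c10d):
    """Return the init_process_group variant matching this build.

    Note: the non-default namespaces (`distributed.c10d`, `distributed.deprecated`)
    only exist on the corresponding 2018-era builds.
    """
    if use_c10d:
        return (distributed.init_process_group if status.is_default
                else distributed.c10d.init_process_group)
    return (distributed.deprecated.init_process_group if status.is_default
            else distributed.init_process_group)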
@@ -5,9 +5,10 @@
 # the root directory of this source tree. An additional grant of patent rights
 # can be found in the PATENTS file in the same directory.

-from torch.distributed import c10d
 from torch.nn import parallel

+from fairseq.distributed_utils import c10d_status
+
 from . import BaseFairseqModel
@@ -28,21 +29,36 @@ class DistributedFairseqModel(BaseFairseqModel):

     def __init__(self, args, model):
         super().__init__()
         assert isinstance(model, BaseFairseqModel)
-        if args.no_c10d:
-            self.ddp_model = parallel.DistributedDataParallel(
+        if args.ddp_backend == 'c10d':
+            if c10d_status.is_default:
+                ddp_class = parallel.DistributedDataParallel
+            elif c10d_status.has_c10d:
+                ddp_class = parallel._DistributedDataParallelC10d
+            else:
+                raise Exception(
+                    'Can\'t find c10d version of DistributedDataParallel. '
+                    'Please update PyTorch.'
+                )
+            self.ddp_model = ddp_class(
                 module=model,
                 device_ids=[args.device_id],
                 output_device=args.device_id,
                 broadcast_buffers=False,
+                bucket_cap_mb=args.bucket_cap_mb,
             )
-        else:
-            self.ddp_model = parallel._DistributedDataParallelC10d(
+        elif args.ddp_backend == 'no_c10d':
+            if c10d_status.is_default:
+                ddp_class = parallel.deprecated.DistributedDataParallel
+            else:
+                ddp_class = parallel.DistributedDataParallel
+            self.ddp_model = ddp_class(
                 module=model,
                 device_ids=[args.device_id],
                 output_device=args.device_id,
                 broadcast_buffers=False,
-                bucket_cap_mb=args.c10d_bucket_cap_mb,
             )
+        else:
+            raise ValueError('Unknown --ddp-backend: ' + args.ddp_backend)

     def __call__(self, *args, **kwargs):
         return self.ddp_model(*args, **kwargs)
...
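The same probe drives class selection in the model wrapper. Sketched standalone below (the `select_ddp_class` helper is made up for illustration; `status` is a `C10dStatus` namedtuple as above), the mapping from the `--ddp-backend` value plus the probed PyTorch build to a DDP class looks like this:

from torch.nn import parallel


def select_ddp_class(ddp_backend, status):
    """Map (--ddp-backend value, probed C10dStatus) to a DDP wrapper class."""
    if ddp_backend == 'c10d':
        if status.is_default:
            return parallel.DistributedDataParallel
        elif status.has_c10d:
            return parallel._DistributedDataParallelC10d
        raise Exception('Can\'t find c10d version of DistributedDataParallel. '
                        'Please update PyTorch.')
    elif ddp_backend == 'no_c10d':
        if status.is_default:
            return parallel.deprecated.DistributedDataParallel
        return parallel.DistributedDataParallel
    raise ValueError('Unknown --ddp-backend: ' + ddp_backend)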
@@ -185,10 +185,11 @@ def add_distributed_training_args(parser):
                        help='port number (not required if using --distributed-init-method)')
     group.add_argument('--device-id', default=0, type=int,
                        help='which GPU to use (usually configured automatically)')
-    group.add_argument('--no-c10d', action='store_true',
-                       help='don\'t use c10d distributed backend')
-    group.add_argument('--c10d-bucket-cap-mb', default=150, type=int, metavar='MB',
-                       help='bucket size for c10d backend')
+    group.add_argument('--ddp-backend', default='c10d', type=str,
+                       choices=['c10d', 'no_c10d'],
+                       help='DistributedDataParallel backend')
+    group.add_argument('--bucket-cap-mb', default=150, type=int, metavar='MB',
+                       help='bucket size for reduction')
     return group
...
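The boolean `--no-c10d` flag becomes an explicit `--ddp-backend` choice, and the bucket-size option drops its c10d prefix since it now applies to whichever backend is selected. The new flags can be exercised in isolation with a bare parser (standalone sketch, not fairseq's full option set):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--ddp-backend', default='c10d', type=str,
                    choices=['c10d', 'no_c10d'],
                    help='DistributedDataParallel backend')
parser.add_argument('--bucket-cap-mb', default=150, type=int, metavar='MB',
                    help='bucket size for reduction')

args = parser.parse_args(['--ddp-backend', 'no_c10d'])
print(args.ddp_backend, args.bucket_cap_mb)  # -> no_c10d 150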
@@ -265,7 +265,7 @@ class Trainer(object):

         return logging_output

-    def valid_step(self, sample):
+    def valid_step(self, sample, raise_oom=False):
        """Do forward pass in evaluation mode."""
         with torch.no_grad():
             self.model.eval()
@@ -277,9 +277,20 @@
             else:
                 ignore_results = False

-            _loss, sample_size, logging_output = self.task.get_loss(
-                self.model, self.criterion, sample,
-            )
+            try:
+                _loss, sample_size, logging_output = self.task.get_loss(
+                    self.model, self.criterion, sample,
+                )
+            except RuntimeError as e:
+                if 'out of memory' in str(e) and not raise_oom:
+                    print('| WARNING: ran out of memory, retrying batch')
+                    for p in self.model.parameters():
+                        if p.grad is not None:
+                            del p.grad  # free some memory
+                    torch.cuda.empty_cache()
+                    return self.valid_step(sample, raise_oom=True)
+                else:
+                    raise e

             if ignore_results:
                 logging_output, sample_size = {}, 0
...
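Unrelated to the c10d changes, `valid_step` gains a one-shot retry on CUDA out-of-memory errors: free any held gradients, empty the CUDA cache, and re-run the step with `raise_oom=True` so a second failure propagates. The pattern, extracted into a standalone sketch where `run_step` stands in for `task.get_loss(...)`:

import torch


def valid_step_with_retry(model, run_step, sample, raise_oom=False):
    """Run one evaluation step, retrying once after a CUDA OOM."""
    with torch.no_grad():
        try:
            return run_step(model, sample)
        except RuntimeError as e:
            if 'out of memory' in str(e) and not raise_oom:
                print('| WARNING: ran out of memory, retrying batch')
                for p in model.parameters():
                    if p.grad is not None:
                        del p.grad  # free some memory
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                return valid_step_with_retry(model, run_step, sample,
                                             raise_oom=True)
            raise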
@@ -292,7 +292,7 @@ def train_language_model(data_dir, arch):
             '--max-epoch', '1',
             '--no-progress-bar',
             '--distributed-world-size', '1',
-            '--no-c10d',
+            '--ddp-backend', 'no_c10d',
         ],
     )
     train.main(train_args)
...