Commit 7e0d222c authored by Myle Ott, committed by Facebook Github Bot

Only use c10d distributed primitives

Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/471

Differential Revision: D13818918

Pulled By: myleott

fbshipit-source-id: d3b8dc50e81ee1d2dcc5efc5815998be8461085f
parent 9196c0b6
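Context note (not part of the commit message): the change collapses the old dual-module dispatch (dist_c10d vs. dist_no_c10d) onto plain torch.distributed, which is c10d-backed in PyTorch >= 1.0. A minimal before/after sketch of the calling pattern, assuming such a PyTorch version:

# Sketch only: contrasts the removed dispatch with the simplified call path.
import torch.distributed as dist

# Old pattern (removed in the diff below): pick a module at import time.
#     rank = dist_c10d.get_rank() if _use_c10d[0] else dist_no_c10d.get_rank()
#
# New pattern: call the c10d-backed primitive directly.
def get_rank():
    return dist.get_rank()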
@@ -11,6 +11,7 @@ import pickle
 import subprocess
 
 import torch
+import torch.distributed as dist
 from torch import nn
 
 from fairseq import utils
@@ -20,30 +21,6 @@ def is_master(args):
     return args.distributed_rank == 0
 
 
-_use_c10d = [True]
-
-C10dStatus = namedtuple('C10dStatus', ['has_c10d', 'is_default'])
-
-if hasattr(nn.parallel, 'deprecated'):
-    c10d_status = C10dStatus(has_c10d=True, is_default=True)
-elif hasattr(torch.distributed, 'c10d') and hasattr(torch.distributed.c10d, 'init_process_group'):
-    c10d_status = C10dStatus(has_c10d=True, is_default=False)
-else:
-    c10d_status = C10dStatus(has_c10d=False, is_default=False)
-
-if c10d_status.is_default:
-    import torch.distributed as dist_c10d
-    import torch.distributed.deprecated as dist_no_c10d
-elif c10d_status.has_c10d:
-    import torch.distributed.c10d as dist_c10d
-    import torch.distributed as dist_no_c10d
-else:
-    import torch.distributed as dist_no_c10d
-
-
 def infer_init_method(args):
     if args.distributed_init_method is not None:
         return
@@ -80,19 +57,10 @@ def distributed_init(args):
     if args.distributed_world_size == 1:
         raise ValueError('Cannot initialize distributed with distributed_world_size=1')
 
-    if args.ddp_backend == 'no_c10d' or not c10d_status.has_c10d:
-        args.ddp_backend = 'no_c10d'
-        _use_c10d[0] = False
-
     print('| distributed init (rank {}): {}'.format(
         args.distributed_rank, args.distributed_init_method), flush=True)
 
-    if _use_c10d[0]:
-        init_fn = dist_c10d.init_process_group
-    else:
-        init_fn = dist_no_c10d.init_process_group
-
-    init_fn(
+    dist.init_process_group(
        backend=args.distributed_backend,
        init_method=args.distributed_init_method,
        world_size=args.distributed_world_size,
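
For reference, a self-contained sketch of the process-group initialization that distributed_init() now performs unconditionally; the backend choice and rendezvous address below are placeholder assumptions, not values from the commit:

import torch.distributed as dist

def example_distributed_init(rank, world_size):
    # Single code path: always initialize through torch.distributed (c10d).
    dist.init_process_group(
        backend='gloo',                       # assumption; NCCL is the usual choice for multi-GPU
        init_method='tcp://127.0.0.1:29500',  # placeholder rendezvous address
        world_size=world_size,
        rank=rank,
    )
    print('| distributed init (rank {}): done'.format(dist.get_rank()), flush=True)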
@@ -118,33 +86,21 @@ def suppress_output(is_master):
 
 
 def get_rank():
-    if _use_c10d[0]:
-        return dist_c10d.get_rank()
-    else:
-        return dist_no_c10d.get_rank()
+    return dist.get_rank()
 
 
 def get_world_size():
-    if _use_c10d[0]:
-        return dist_c10d.get_world_size()
-    else:
-        return dist_no_c10d.get_world_size()
+    return dist.get_world_size()
 
 
 def get_default_group():
-    if _use_c10d[0]:
-        return dist_c10d.group.WORLD
-    else:
-        return dist_no_c10d.group.WORLD
+    return dist.group.WORLD
 
 
 def all_reduce(tensor, group=None):
     if group is None:
         group = get_default_group()
-    if _use_c10d[0]:
-        return dist_c10d.all_reduce(tensor, group=group)
-    else:
-        return dist_no_c10d.all_reduce(tensor, group=group)
+    return dist.all_reduce(tensor, group=group)
 
 
 def all_gather_list(data, group=None, max_size=16384):
...
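
The simplified helpers above are now thin wrappers over torch.distributed. A usage sketch, assuming a process group has already been initialized (the helper name here is hypothetical):

import torch
import torch.distributed as dist

def sum_across_workers(value):
    # Hypothetical helper: all-reduce a scalar over the default WORLD group,
    # mirroring distributed_utils.all_reduce(tensor, group=dist.group.WORLD).
    t = torch.tensor([float(value)])
    dist.all_reduce(t, group=dist.group.WORLD)
    return t.item()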
@@ -8,7 +8,6 @@
 import inspect
 
 from torch.nn import parallel
 
-from fairseq.distributed_utils import c10d_status
 from fairseq.legacy_distributed_data_parallel import LegacyDistributedDataParallel
 
 from . import BaseFairseqModel
@@ -31,15 +30,7 @@ def DistributedFairseqModel(args, model):
     # determine which DDP class to extend
     assert isinstance(model, BaseFairseqModel)
     if args.ddp_backend == 'c10d':
-        if c10d_status.is_default:
-            ddp_class = parallel.DistributedDataParallel
-        elif c10d_status.has_c10d:
-            ddp_class = parallel._DistributedDataParallelC10d
-        else:
-            raise Exception(
-                'Can\'t find c10d version of DistributedDataParallel. '
-                'Please update PyTorch.'
-            )
+        ddp_class = parallel.DistributedDataParallel
         init_kwargs = dict(
             module=model,
             device_ids=[args.device_id],
@@ -50,7 +41,6 @@ def DistributedFairseqModel(args, model):
         # Maintain backward compatibility for 0.4 or earlier
         if 'check_reduction' in inspect.getargspec(ddp_class)[0]:
             init_kwargs['check_reduction'] = True
-
     elif args.ddp_backend == 'no_c10d':
         ddp_class = LegacyDistributedDataParallel
         init_kwargs = dict(
...
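
With the version check removed, the c10d branch of DistributedFairseqModel (presumably fairseq/models/distributed_fairseq_model.py) always wraps the model in a plain torch.nn.parallel.DistributedDataParallel. A sketch with hypothetical argument values; fairseq-specific kwargs such as broadcast_buffers are omitted:

from torch.nn import parallel

def wrap_c10d(model, device_id, bucket_cap_mb=25):
    # Sketch of the kwargs the c10d branch assembles for DistributedDataParallel.
    return parallel.DistributedDataParallel(
        module=model,
        device_ids=[device_id],
        output_device=device_id,
        bucket_cap_mb=bucket_cap_mb,
    )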
@@ -211,7 +211,7 @@ def add_distributed_training_args(parser):
     group.add_argument('--ddp-backend', default='c10d', type=str,
                        choices=['c10d', 'no_c10d'],
                        help='DistributedDataParallel backend')
-    group.add_argument('--bucket-cap-mb', default=150, type=int, metavar='MB',
+    group.add_argument('--bucket-cap-mb', default=25, type=int, metavar='MB',
                        help='bucket size for reduction')
     group.add_argument('--fix-batches-to-gpus', action='store_true',
                        help='don\'t shuffle batches between GPUs; this reduces overall '
...
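
The --bucket-cap-mb default drops from 150 MB to 25 MB, in line with the default bucket size of torch.nn.parallel.DistributedDataParallel. A standalone argparse sketch (a hypothetical parser, not fairseq's own options plumbing) showing how these flags behave:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--ddp-backend', default='c10d', type=str,
                    choices=['c10d', 'no_c10d'],
                    help='DistributedDataParallel backend')
parser.add_argument('--bucket-cap-mb', default=25, type=int, metavar='MB',
                    help='bucket size for reduction')

# e.g. `python train.py ... --ddp-backend c10d --bucket-cap-mb 25`
args = parser.parse_args([])
assert args.bucket_cap_mb == 25  # new default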