Commit 7e0d222c authored by Myle Ott, committed by Facebook Github Bot

Only use c10d distributed primitives

Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/471

Differential Revision: D13818918

Pulled By: myleott

fbshipit-source-id: d3b8dc50e81ee1d2dcc5efc5815998be8461085f
parent 9196c0b6
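
This commit removes fairseq's runtime detection of the c10d rewrite of torch.distributed (needed while c10d was being merged into PyTorch) and calls torch.distributed directly, as shipped in PyTorch 1.0. Below is a minimal sketch of the pattern the diff converges on; the function names are illustrative and the snippet is not part of the commit.

import torch.distributed as dist

def init_sketch(init_method, rank, world_size, backend='nccl'):
    # One process group for the whole job; afterwards the primitives used
    # in this diff (get_rank, get_world_size, all_reduce) are available.
    dist.init_process_group(
        backend=backend,
        init_method=init_method,
        world_size=world_size,
        rank=rank,
    )

def all_reduce_average_sketch(tensor):
    # Sum `tensor` in place across all workers, then divide by the
    # number of workers to get the average.
    dist.all_reduce(tensor, group=dist.group.WORLD)
    tensor.div_(dist.get_world_size())
    return tensor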
fairseq/distributed_utils.py
@@ -11,6 +11,7 @@ import pickle
 import subprocess
 
 import torch
+import torch.distributed as dist
 from torch import nn
 
 from fairseq import utils
@@ -20,30 +21,6 @@ def is_master(args):
     return args.distributed_rank == 0
 
 
-_use_c10d = [True]
-
-C10dStatus = namedtuple('C10dStatus', ['has_c10d', 'is_default'])
-
-if hasattr(nn.parallel, 'deprecated'):
-    c10d_status = C10dStatus(has_c10d=True, is_default=True)
-elif hasattr(torch.distributed, 'c10d') and hasattr(torch.distributed.c10d, 'init_process_group'):
-    c10d_status = C10dStatus(has_c10d=True, is_default=False)
-else:
-    c10d_status = C10dStatus(has_c10d=False, is_default=False)
-
-if c10d_status.is_default:
-    import torch.distributed as dist_c10d
-    import torch.distributed.deprecated as dist_no_c10d
-elif c10d_status.has_c10d:
-    import torch.distributed.c10d as dist_c10d
-    import torch.distributed as dist_no_c10d
-else:
-    import torch.distributed as dist_no_c10d
 
 
 def infer_init_method(args):
     if args.distributed_init_method is not None:
         return
@@ -80,19 +57,10 @@ def distributed_init(args):
     if args.distributed_world_size == 1:
         raise ValueError('Cannot initialize distributed with distributed_world_size=1')
 
-    if args.ddp_backend == 'no_c10d' or not c10d_status.has_c10d:
-        args.ddp_backend = 'no_c10d'
-        _use_c10d[0] = False
-
     print('| distributed init (rank {}): {}'.format(
         args.distributed_rank, args.distributed_init_method), flush=True)
 
-    if _use_c10d[0]:
-        init_fn = dist_c10d.init_process_group
-    else:
-        init_fn = dist_no_c10d.init_process_group
-
-    init_fn(
+    dist.init_process_group(
         backend=args.distributed_backend,
         init_method=args.distributed_init_method,
         world_size=args.distributed_world_size,
@@ -118,33 +86,21 @@ def suppress_output(is_master):
 def get_rank():
-    if _use_c10d[0]:
-        return dist_c10d.get_rank()
-    else:
-        return dist_no_c10d.get_rank()
+    return dist.get_rank()
 
 
 def get_world_size():
-    if _use_c10d[0]:
-        return dist_c10d.get_world_size()
-    else:
-        return dist_no_c10d.get_world_size()
+    return dist.get_world_size()
 
 
 def get_default_group():
-    if _use_c10d[0]:
-        return dist_c10d.group.WORLD
-    else:
-        return dist_no_c10d.group.WORLD
+    return dist.group.WORLD
 
 
 def all_reduce(tensor, group=None):
     if group is None:
         group = get_default_group()
-    if _use_c10d[0]:
-        return dist_c10d.all_reduce(tensor, group=group)
-    else:
-        return dist_no_c10d.all_reduce(tensor, group=group)
+    return dist.all_reduce(tensor, group=group)
 
 
 def all_gather_list(data, group=None, max_size=16384):
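
The hunk above ends at all_gather_list, which layers on these primitives: each worker pickles an arbitrary Python object into its own slice of a fixed-size byte buffer, one collective fills in every slice, and each worker unpickles all of them. A hedged sketch of that length-prefix scheme follows; the function name, the CPU tensors, and the use of all_reduce as a gather are my assumptions, so see fairseq/distributed_utils.py for the actual implementation.

import pickle
import torch
import torch.distributed as dist

def all_gather_list_sketch(data, max_size=16384):
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    # Pickle the object and length-prefix it with two uint8s.
    enc = pickle.dumps(data)
    if len(enc) + 2 > max_size:
        raise ValueError('encoded data exceeds max_size: {}'.format(len(enc)))
    buf = torch.zeros(world_size * max_size, dtype=torch.uint8)
    offset = rank * max_size
    buf[offset] = len(enc) // 255
    buf[offset + 1] = len(enc) % 255
    buf[offset + 2:offset + 2 + len(enc)] = torch.tensor(list(enc), dtype=torch.uint8)

    # Each rank wrote only its own slice of a zero buffer, so summing
    # across ranks behaves like an all-gather.
    dist.all_reduce(buf)

    # Decode every rank's slice back into a Python object.
    results = []
    for i in range(world_size):
        o = i * max_size
        size = 255 * buf[o].item() + buf[o + 1].item()
        results.append(pickle.loads(bytes(buf[o + 2:o + 2 + size].tolist())))
    return results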
fairseq/models/distributed_fairseq_model.py
@@ -8,7 +8,6 @@
 import inspect
 
 from torch.nn import parallel
 
-from fairseq.distributed_utils import c10d_status
 from fairseq.legacy_distributed_data_parallel import LegacyDistributedDataParallel
 
 from . import BaseFairseqModel
@@ -31,15 +30,7 @@ def DistributedFairseqModel(args, model):
     # determine which DDP class to extend
     assert isinstance(model, BaseFairseqModel)
 
     if args.ddp_backend == 'c10d':
-        if c10d_status.is_default:
-            ddp_class = parallel.DistributedDataParallel
-        elif c10d_status.has_c10d:
-            ddp_class = parallel._DistributedDataParallelC10d
-        else:
-            raise Exception(
-                'Can\'t find c10d version of DistributedDataParallel. '
-                'Please update PyTorch.'
-            )
+        ddp_class = parallel.DistributedDataParallel
         init_kwargs = dict(
             module=model,
             device_ids=[args.device_id],
@@ -50,7 +41,6 @@ def DistributedFairseqModel(args, model):
         # Maintain backward compatibility for 0.4 or earlier
         if 'check_reduction' in inspect.getargspec(ddp_class)[0]:
             init_kwargs['check_reduction'] = True
     elif args.ddp_backend == 'no_c10d':
         ddp_class = LegacyDistributedDataParallel
         init_kwargs = dict(
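
With the feature detection gone, the 'c10d' branch above always extends torch.nn.parallel.DistributedDataParallel. A hedged sketch of the resulting wrapping is below; the output_device and broadcast_buffers keyword arguments are my assumptions (the diff truncates init_kwargs), and bucket_cap_mb ties in the --bucket-cap-mb flag changed further down.

from torch.nn import parallel

def wrap_ddp_sketch(args, model):
    # 'c10d' is now the single native-DDP path; 'no_c10d' still falls back
    # to fairseq's LegacyDistributedDataParallel (not sketched here).
    assert args.ddp_backend == 'c10d'
    return parallel.DistributedDataParallel(
        module=model,
        device_ids=[args.device_id],
        output_device=args.device_id,       # assumption: same GPU as device_ids
        broadcast_buffers=False,            # assumption: not shown in the hunk
        bucket_cap_mb=args.bucket_cap_mb,   # fed by --bucket-cap-mb below
    )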
fairseq/options.py
@@ -211,7 +211,7 @@ def add_distributed_training_args(parser):
     group.add_argument('--ddp-backend', default='c10d', type=str,
                        choices=['c10d', 'no_c10d'],
                        help='DistributedDataParallel backend')
-    group.add_argument('--bucket-cap-mb', default=150, type=int, metavar='MB',
+    group.add_argument('--bucket-cap-mb', default=25, type=int, metavar='MB',
                        help='bucket size for reduction')
     group.add_argument('--fix-batches-to-gpus', action='store_true',
                        help='don\'t shuffle batches between GPUs; this reduces overall '
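
For context on the last hunk: bucket_cap_mb sets how many megabytes of gradients DDP packs into one bucket before launching an all-reduce during the backward pass, and the new default of 25 MB matches the default of torch.nn.parallel.DistributedDataParallel itself. A standalone illustration of the flag, not taken from this commit:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--bucket-cap-mb', default=25, type=int, metavar='MB',
                    help='bucket size for reduction')
args = parser.parse_args([])

# Rough intuition: a model with ~300 MB of gradients and 25 MB buckets
# triggers about 12 all-reduces per backward pass, each overlapping with
# computation of the remaining layers.
print(300 // args.bucket_cap_mb)  # -> 12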