Commit 34726d56 authored by Myle Ott, committed by Facebook Github Bot

Move distributed_init into DistributedFairseqModel (#687)

Summary:
This should make rendezvous happen as lazily as possible.
Pull Request resolved: https://github.com/pytorch/fairseq/pull/687

Differential Revision: D15151145

Pulled By: myleott

fbshipit-source-id: d70816a85414c5d509a6b12e2b339b4736db2c88
parent fb18be00
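
The "as lazily as possible" in the summary comes down to two changes visible in the diffs below: distributed_init() becomes safe to call more than once, and the call moves out of train.py's main() into the DistributedFairseqModel wrapper, so the rendezvous only happens once a distributed model is actually built. As background, here is a minimal standalone sketch of the guarded-initialization pattern; this is not fairseq code, and the helper name lazy_distributed_init is invented for illustration:

import warnings

import torch.distributed as dist


def lazy_distributed_init(backend, init_method, world_size, rank):
    # Join the process group only if it has not been set up yet; a repeated
    # call degrades to a warning instead of a hard failure.
    if dist.is_initialized():
        warnings.warn('Distributed is already initialized, cannot initialize twice!')
    else:
        dist.init_process_group(
            backend=backend,
            init_method=init_method,
            world_size=world_size,
            rank=rank,
        )
    # Report the rank as seen by the (now existing) process group.
    return dist.get_rank()
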
fairseq/distributed_utils.py
@@ -9,6 +9,7 @@ from collections import namedtuple
 import os
 import pickle
 import subprocess
+import warnings
 
 import torch
 import torch.distributed as dist
@@ -54,18 +55,22 @@ def distributed_init(args):
     if args.distributed_world_size == 1:
         raise ValueError('Cannot initialize distributed with distributed_world_size=1')
 
-    print('| distributed init (rank {}): {}'.format(
-        args.distributed_rank, args.distributed_init_method), flush=True)
-
-    dist.init_process_group(
-        backend=args.distributed_backend,
-        init_method=args.distributed_init_method,
-        world_size=args.distributed_world_size,
-        rank=args.distributed_rank,
-    )
-
-    suppress_output(is_master(args))
+    if torch.distributed.is_initialized():
+        warnings.warn('Distributed is already initialized, cannot initialize twice!')
+    else:
+        print('| distributed init (rank {}): {}'.format(
+            args.distributed_rank, args.distributed_init_method), flush=True)
+
+        dist.init_process_group(
+            backend=args.distributed_backend,
+            init_method=args.distributed_init_method,
+            world_size=args.distributed_world_size,
+            rank=args.distributed_rank,
+        )
+
+        suppress_output(is_master(args))
+
+    args.distributed_rank = torch.distributed.get_rank()
 
     return args.distributed_rank
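
With the guard above, a second call to distributed_init() is harmless: it warns instead of re-running the rendezvous and still returns the rank read back from the existing process group. A rough illustration of the calling pattern, assuming an args namespace already populated with the distributed_* options shown above:

rank = distributed_utils.distributed_init(args)  # first call: joins the process group
rank = distributed_utils.distributed_init(args)  # later call: warns, returns torch.distributed.get_rank()
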
fairseq/models/distributed_fairseq_model.py
@@ -6,8 +6,11 @@
 # can be found in the PATENTS file in the same directory.
 
 import inspect
+import socket
 
 from torch.nn import parallel
 
+from fairseq import distributed_utils
+
 from fairseq.legacy_distributed_data_parallel import LegacyDistributedDataParallel
 
 from . import BaseFairseqModel
@@ -26,6 +29,9 @@ def DistributedFairseqModel(args, model):
         args (argparse.Namespace): fairseq args
         model (BaseFairseqModel): model to wrap
     """
+    # rendezvous with other workers
+    args.distributed_rank = distributed_utils.distributed_init(args)
+    print('| initialized host {} as rank {}'.format(socket.gethostname(), args.distributed_rank))
 
     # determine which DDP class to extend
     assert isinstance(model, BaseFairseqModel)
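
After this change, wrapping a model is what triggers the rendezvous, rather than train.py's startup code. A hedged sketch of the caller's side; build_distributed_model is an invented helper and task stands in for whichever fairseq task object builds the model:

def build_distributed_model(args, task):
    # Illustrative only: the rendezvous now runs inside DistributedFairseqModel,
    # not at the start of train.py's main().
    from fairseq.models.distributed_fairseq_model import DistributedFairseqModel

    model = task.build_model(args)                # plain BaseFairseqModel; no rendezvous yet
    model = DistributedFairseqModel(args, model)  # distributed_init() happens here, lazily
    return model
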
train.py
@@ -23,7 +23,7 @@ from fairseq.trainer import Trainer
 from fairseq.meters import AverageMeter, StopwatchMeter
 
 
-def main(args, init_distributed=False):
+def main(args):
     utils.import_user_module(args)
 
     if args.max_tokens is None:
@@ -82,12 +82,6 @@ def main(args, init_distributed=False):
         num_workers=args.num_workers,
     )
 
-    # Initialize distributed training (after data loading)
-    if init_distributed:
-        import socket
-        args.distributed_rank = distributed_utils.distributed_init(args)
-        print('| initialized host {} as rank {}'.format(socket.gethostname(), args.distributed_rank))
-
     # Load the latest checkpoint if one is available
     if not load_checkpoint(args, trainer, epoch_itr):
         trainer.dummy_train_step([dummy_batch])
@@ -390,7 +384,7 @@ def distributed_main(i, args):
     args.device_id = i
     if args.distributed_rank is None:  # torch.multiprocessing.spawn
         args.distributed_rank = i
-    main(args, init_distributed=True)
+    main(args)
 
 
 def cli_main():
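
For completeness, a simplified sketch of the multi-process launch path after this commit; distributed_main_sketch and launch are invented names that mirror distributed_main() and cli_main() above, main refers to train.py's main, and the init-method and port handling done by the real cli_main() is omitted:

import torch.multiprocessing as mp


def distributed_main_sketch(i, args, main):
    # Per-process entry point, mirroring distributed_main() above.
    args.device_id = i
    if args.distributed_rank is None:  # rank index supplied by mp.spawn
        args.distributed_rank = i
    main(args)  # no init_distributed flag; rendezvous is deferred to DistributedFairseqModel


def launch(args, main):
    # Spawn one training process per distributed worker on this node.
    mp.spawn(
        fn=distributed_main_sketch,
        args=(args, main),
        nprocs=args.distributed_world_size,
    )
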