Commit 0cb87130 authored by Myle Ott's avatar Myle Ott Committed by Facebook Github Bot
Browse files

Make multiprocessing_train.py work with multi-node setups

Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/425

Differential Revision: D13558340

Pulled By: myleott

fbshipit-source-id: dff8c77027e821d8c80bfbd6a6ccce9ca1a44b78
parent 58dd1862
......@@ -7,7 +7,6 @@
# can be found in the PATENTS file in the same directory.
import os
import random
import signal
import torch
......@@ -17,12 +16,6 @@ from train import main as single_process_main
def main(args):
# Set distributed training parameters for a single node.
args.distributed_world_size = torch.cuda.device_count()
port = random.randint(10000, 20000)
args.distributed_init_method = 'tcp://localhost:{port}'.format(port=port)
args.distributed_init_host = 'localhost'
args.distributed_port = port + 1
if max(args.update_freq) > 1 and args.ddp_backend != 'no_c10d':
print('| WARNING: when using --update-freq on a single machine, you '
'will get better performance with --ddp-backend=no_c10d')
......@@ -35,8 +28,9 @@ def main(args):
# Train with multiprocessing.
procs = []
for i in range(args.distributed_world_size):
args.distributed_rank = i
base_rank = args.distributed_rank
for i in range(torch.cuda.device_count()):
args.distributed_rank = base_rank + i
args.device_id = i
procs.append(mp.Process(target=run, args=(args, error_queue, ), daemon=True))
procs[i].start()
......
......@@ -13,6 +13,8 @@ import collections
import itertools
import os
import math
import random
import torch
from fairseq import distributed_utils, options, progress_bar, tasks, utils
......@@ -355,6 +357,12 @@ if __name__ == '__main__':
elif args.distributed_world_size > 1:
from multiprocessing_train import main as multiprocessing_main
# Set distributed training parameters for a single node.
args.distributed_world_size = torch.cuda.device_count()
port = random.randint(10000, 20000)
args.distributed_init_method = 'tcp://localhost:{port}'.format(port=port)
args.distributed_port = port + 1
multiprocessing_main(args)
else:
main(args)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment