Commit 0cb87130 authored by Myle Ott, committed by Facebook Github Bot

Make multiprocessing_train.py work with multi-node setups

Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/425

Differential Revision: D13558340

Pulled By: myleott

fbshipit-source-id: dff8c77027e821d8c80bfbd6a6ccce9ca1a44b78
parent 58dd1862
--- a/multiprocessing_train.py
+++ b/multiprocessing_train.py
@@ -7,7 +7,6 @@
 # can be found in the PATENTS file in the same directory.
 import os
-import random
 import signal
 import torch
@@ -17,12 +16,6 @@ from train import main as single_process_main
 def main(args):
-    # Set distributed training parameters for a single node.
-    args.distributed_world_size = torch.cuda.device_count()
-    port = random.randint(10000, 20000)
-    args.distributed_init_method = 'tcp://localhost:{port}'.format(port=port)
-    args.distributed_init_host = 'localhost'
-    args.distributed_port = port + 1
     if max(args.update_freq) > 1 and args.ddp_backend != 'no_c10d':
         print('| WARNING: when using --update-freq on a single machine, you '
               'will get better performance with --ddp-backend=no_c10d')
@@ -35,8 +28,9 @@ def main(args):
     # Train with multiprocessing.
     procs = []
-    for i in range(args.distributed_world_size):
-        args.distributed_rank = i
+    base_rank = args.distributed_rank
+    for i in range(torch.cuda.device_count()):
+        args.distributed_rank = base_rank + i
         args.device_id = i
         procs.append(mp.Process(target=run, args=(args, error_queue, ), daemon=True))
         procs[i].start()
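The change above offsets each spawned worker's global rank by the node's pre-set base rank instead of using the local GPU index directly, which is what allows the same spawning loop to run on several nodes at once. A minimal sketch of that rank arithmetic, assuming a hypothetical two-node job with four GPUs per node (assign_ranks is a name local to this sketch, not a fairseq function):

# Illustrative only: how the patched loop assigns global ranks on one node.
def assign_ranks(base_rank, num_local_gpus):
    """Return (global_rank, local_device_id) pairs for one node's workers."""
    return [(base_rank + i, i) for i in range(num_local_gpus)]

if __name__ == '__main__':
    # Node 0 of a hypothetical 2-node, 4-GPU-per-node job: global ranks 0-3.
    print(assign_ranks(base_rank=0, num_local_gpus=4))
    # Node 1 of the same job: global ranks 4-7, local device ids still 0-3.
    print(assign_ranks(base_rank=4, num_local_gpus=4))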
--- a/train.py
+++ b/train.py
@@ -13,6 +13,8 @@ import collections
 import itertools
 import os
 import math
+import random
 import torch
 from fairseq import distributed_utils, options, progress_bar, tasks, utils
@@ -355,6 +357,12 @@ if __name__ == '__main__':
     elif args.distributed_world_size > 1:
         from multiprocessing_train import main as multiprocessing_main
+        # Set distributed training parameters for a single node.
+        args.distributed_world_size = torch.cuda.device_count()
+        port = random.randint(10000, 20000)
+        args.distributed_init_method = 'tcp://localhost:{port}'.format(port=port)
+        args.distributed_port = port + 1
         multiprocessing_main(args)
     else:
         main(args)
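Taken together, the two hunks move the single-node defaults (world size, rendezvous address, port) out of multiprocessing_train.py and into the single-node branch of train.py, so a multi-node launcher can pre-populate those fields itself before calling multiprocessing_train's main. A minimal sketch of what the per-node settings could look like, assuming a hypothetical 2-node, 4-GPU-per-node job (node_args, the address, and the port below are made up for illustration and only mirror the attribute names touched in the diff):

# Illustrative only: per-node distributed settings for a made-up 2x4-GPU job.
from types import SimpleNamespace

def node_args(node_id, gpus_per_node=4, num_nodes=2,
              master='tcp://192.168.1.10:12345'):
    return SimpleNamespace(
        distributed_world_size=num_nodes * gpus_per_node,  # global size, not per-node
        distributed_init_method=master,                    # same rendezvous on every node
        distributed_rank=node_id * gpus_per_node,          # base rank; workers add 0..gpus-1
    )

print(node_args(0))  # ranks 0-3 spawn on node 0
print(node_args(1))  # ranks 4-7 spawn on node 1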