Commit 7853818c authored by Davide Caroselli, committed by Facebook GitHub Bot

FIX: '--user-dir' on multi-gpu (#449)

Summary:
In a multi-GPU training scenario, the `train.py` script spawns new processes with `torch.multiprocessing.spawn`. Unfortunately, those child processes do not inherit the modules imported with `--user-dir`.

This pull request fixes the problem: the custom module import is now explicit in every `main()` function.
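For context, a minimal sketch of the failure mode and the fix (simplified, hypothetical code, not the actual fairseq sources; `distributed_main` and the flag handling are condensed):

```python
import torch.multiprocessing as mp

from fairseq.utils import import_user_module


def distributed_main(rank, args):
    # Spawned workers start from a fresh interpreter state: modules the
    # parent loaded dynamically via --user-dir are missing here, so each
    # worker has to re-import them explicitly.
    import_user_module(args)
    # ... set up device `rank` and run training ...


def main(args):
    import_user_module(args)  # covers the single-process path
    if args.distributed_world_size > 1:
        # spawn() does not hand dynamically imported modules down to its
        # children, which is exactly what this commit works around.
        mp.spawn(distributed_main, args=(args,),
                 nprocs=args.distributed_world_size)
```

Because `spawn` launches each child in a fresh interpreter, anything the parent imported via `importlib` is simply absent there; re-running `import_user_module` at every entry point is the most robust fix.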
Pull Request resolved: https://github.com/pytorch/fairseq/pull/449

Differential Revision: D13676922

Pulled By: myleott

fbshipit-source-id: 520358d66155697885b878a37e7d0484bddbc1c6
parent bdec179b
eval_lm.py
@@ -16,6 +16,7 @@ import torch
 from fairseq import options, progress_bar, tasks, utils
 from fairseq.meters import StopwatchMeter, TimeMeter
 from fairseq.sequence_scorer import SequenceScorer
+from fairseq.utils import import_user_module


 class WordStat(object):
@@ -47,6 +48,8 @@ class WordStat(object):
 def main(parsed_args):
     assert parsed_args.path is not None, '--path required for evaluation!'

+    import_user_module(parsed_args)
+
     print(parsed_args)

     use_cuda = torch.cuda.is_available() and not parsed_args.cpu
fairseq/options.py
@@ -127,9 +127,7 @@ def get_parser(desc, default_task='translation'):
     usr_parser = argparse.ArgumentParser(add_help=False)
     usr_parser.add_argument('--user-dir', default=None)
     usr_args, _ = usr_parser.parse_known_args()
-
-    if usr_args.user_dir is not None:
-        import_user_module(usr_args.user_dir)
+    import_user_module(usr_args)

     parser = argparse.ArgumentParser()
     # fmt: off
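The deleted `None` check moves into `import_user_module` itself; what remains is a two-pass parsing trick worth spelling out. A small self-contained sketch (the argv values below are made up for illustration): a throwaway parser extracts `--user-dir` first, so user modules, which may register new tasks or architectures, are imported before the real parser is built.

```python
import argparse

# A minimal throwaway parser: add_help=False avoids clashing with the
# real parser's -h, and parse_known_args() ignores every other flag.
usr_parser = argparse.ArgumentParser(add_help=False)
usr_parser.add_argument('--user-dir', default=None)

argv = ['--user-dir', '/tmp/plugins', '--task', 'translation']  # made-up argv
usr_args, rest = usr_parser.parse_known_args(argv)
print(usr_args.user_dir)  # '/tmp/plugins'
print(rest)               # ['--task', 'translation'] is left for the real parser
```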
fairseq/utils.py
@@ -437,10 +437,15 @@ def resolve_max_positions(*args):
     return max_positions


-def import_user_module(module_path):
-    module_path = os.path.abspath(module_path)
-    module_parent, module_name = os.path.split(module_path)
-
-    sys.path.insert(0, module_parent)
-    importlib.import_module(module_name)
-    sys.path.pop(0)
+def import_user_module(args):
+    if hasattr(args, 'user_dir'):
+        module_path = args.user_dir
+        if module_path is not None:
+            module_path = os.path.abspath(args.user_dir)
+            module_parent, module_name = os.path.split(module_path)
+
+            if module_name not in sys.modules:
+                sys.path.insert(0, module_parent)
+                importlib.import_module(module_name)
+                sys.path.pop(0)
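A quick usage sketch of the new signature (directory names are hypothetical): the function now receives the whole args namespace and degrades to a no-op when `--user-dir` is absent or unset, which is what lets every `main()` call it unconditionally.

```python
import argparse

from fairseq.utils import import_user_module

# Hypothetical plugin directory, for illustration only.
args = argparse.Namespace(user_dir='/path/to/my_plugins')
import_user_module(args)  # puts the parent dir on sys.path and imports my_plugins once

# Both of these are now safe no-ops:
import_user_module(argparse.Namespace(user_dir=None))  # flag present but unset
import_user_module(argparse.Namespace())               # namespace lacks user_dir entirely
```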
generate.py
@@ -15,6 +15,7 @@ from fairseq import bleu, options, progress_bar, tasks, tokenizer, utils
 from fairseq.meters import StopwatchMeter, TimeMeter
 from fairseq.sequence_generator import SequenceGenerator
 from fairseq.sequence_scorer import SequenceScorer
+from fairseq.utils import import_user_module


 def main(args):
@@ -24,6 +25,8 @@ def main(args):
     assert args.replace_unk is None or args.raw_text, \
         '--replace-unk requires a raw text dataset (--raw-text)'

+    import_user_module(args)
+
     if args.max_tokens is None and args.max_sentences is None:
         args.max_tokens = 12000
     print(args)
interactive.py
@@ -17,7 +17,7 @@ import torch
 from fairseq import data, options, tasks, tokenizer, utils
 from fairseq.sequence_generator import SequenceGenerator
-
+from fairseq.utils import import_user_module

 Batch = namedtuple('Batch', 'srcs tokens lengths')
 Translation = namedtuple('Translation', 'src_str hypos pos_scores alignments')
@@ -56,6 +56,8 @@ def make_batches(lines, args, task, max_positions):
 def main(args):
+    import_user_module(args)
+
     if args.buffer_size < 1:
         args.buffer_size = 1
     if args.max_tokens is None and args.max_sentences is None:
preprocess.py
@@ -20,6 +20,8 @@ from fairseq.data import indexed_dataset, dictionary
 from fairseq.tokenizer import Tokenizer, tokenize_line
 from multiprocessing import Pool

+from fairseq.utils import import_user_module
+

 def get_parser():
     parser = argparse.ArgumentParser()
@@ -66,6 +68,8 @@ def get_parser():
 def main(args):
+    import_user_module(args)
+
     print(args)
     os.makedirs(args.destdir, exist_ok=True)
     target = not args.only_source
train.py
@@ -21,9 +21,12 @@ from fairseq import distributed_utils, options, progress_bar, tasks, utils
 from fairseq.data import iterators
 from fairseq.trainer import Trainer
 from fairseq.meters import AverageMeter, StopwatchMeter
+from fairseq.utils import import_user_module


 def main(args):
+    import_user_module(args)
+
     if args.max_tokens is None:
         args.max_tokens = 6000
     print(args)