Commit 7853818c authored by Davide Caroselli's avatar Davide Caroselli Committed by Facebook Github Bot
Browse files

FIX: '--user-dir' on multi-gpu (#449)

Summary:
On a multi-gpu training scenario, the `train.py` script spawns new processes with `torch.multiprocessing.spawn`. Unfortunately those child processes don't inherit the modules imported with `--user-dir`.

This pull request fixes this problem: custom module import is now explicit in every `main()` function.
Pull Request resolved: https://github.com/pytorch/fairseq/pull/449

Differential Revision: D13676922

Pulled By: myleott

fbshipit-source-id: 520358d66155697885b878a37e7d0484bddbc1c6
parent bdec179b
......@@ -16,6 +16,7 @@ import torch
from fairseq import options, progress_bar, tasks, utils
from fairseq.meters import StopwatchMeter, TimeMeter
from fairseq.sequence_scorer import SequenceScorer
from fairseq.utils import import_user_module
class WordStat(object):
......@@ -47,6 +48,8 @@ class WordStat(object):
def main(parsed_args):
assert parsed_args.path is not None, '--path required for evaluation!'
import_user_module(parsed_args)
print(parsed_args)
use_cuda = torch.cuda.is_available() and not parsed_args.cpu
......
......@@ -127,9 +127,7 @@ def get_parser(desc, default_task='translation'):
usr_parser = argparse.ArgumentParser(add_help=False)
usr_parser.add_argument('--user-dir', default=None)
usr_args, _ = usr_parser.parse_known_args()
if usr_args.user_dir is not None:
import_user_module(usr_args.user_dir)
import_user_module(usr_args)
parser = argparse.ArgumentParser()
# fmt: off
......
......@@ -437,10 +437,15 @@ def resolve_max_positions(*args):
return max_positions
def import_user_module(module_path):
    """Import a custom user module given its filesystem path.

    The module's parent directory is temporarily prepended to ``sys.path``
    so the module can be resolved by name with ``importlib``.

    Args:
        module_path (str): path to the module directory (or importable
            file/package name) supplied via ``--user-dir``.

    Raises:
        ImportError: if the module cannot be imported.
    """
    module_path = os.path.abspath(module_path)
    module_parent, module_name = os.path.split(module_path)
    sys.path.insert(0, module_parent)
    try:
        importlib.import_module(module_name)
    finally:
        # Remove the temporary entry even when the import raises, so a
        # failed --user-dir does not leave a stray sys.path entry behind.
        sys.path.pop(0)
def import_user_module(args):
    """Import the custom module named by ``args.user_dir``, if any.

    Designed to be called explicitly from every ``main()`` so that child
    processes spawned by ``torch.multiprocessing.spawn`` (which do not
    inherit imports done in the parent) re-import the user module.

    Args:
        args: a parsed argument namespace; ``user_dir`` may be absent or
            ``None``, in which case this is a no-op.

    Raises:
        ImportError: if ``user_dir`` is set but cannot be imported.
    """
    # getattr flattens the original hasattr + None check; same behavior.
    module_path = getattr(args, 'user_dir', None)
    if module_path is not None:
        module_path = os.path.abspath(module_path)
        module_parent, module_name = os.path.split(module_path)
        # Skip re-import if the module is already loaded (e.g. in the
        # parent process, or on a second call within one process).
        if module_name not in sys.modules:
            sys.path.insert(0, module_parent)
            try:
                importlib.import_module(module_name)
            finally:
                # Always restore sys.path, even if the import fails, so
                # repeated calls do not accumulate stale entries.
                sys.path.pop(0)
......@@ -15,6 +15,7 @@ from fairseq import bleu, options, progress_bar, tasks, tokenizer, utils
from fairseq.meters import StopwatchMeter, TimeMeter
from fairseq.sequence_generator import SequenceGenerator
from fairseq.sequence_scorer import SequenceScorer
from fairseq.utils import import_user_module
def main(args):
......@@ -24,6 +25,8 @@ def main(args):
assert args.replace_unk is None or args.raw_text, \
'--replace-unk requires a raw text dataset (--raw-text)'
import_user_module(args)
if args.max_tokens is None and args.max_sentences is None:
args.max_tokens = 12000
print(args)
......
......@@ -17,7 +17,7 @@ import torch
from fairseq import data, options, tasks, tokenizer, utils
from fairseq.sequence_generator import SequenceGenerator
from fairseq.utils import import_user_module
Batch = namedtuple('Batch', 'srcs tokens lengths')
Translation = namedtuple('Translation', 'src_str hypos pos_scores alignments')
......@@ -56,6 +56,8 @@ def make_batches(lines, args, task, max_positions):
def main(args):
import_user_module(args)
if args.buffer_size < 1:
args.buffer_size = 1
if args.max_tokens is None and args.max_sentences is None:
......
......@@ -20,6 +20,8 @@ from fairseq.data import indexed_dataset, dictionary
from fairseq.tokenizer import Tokenizer, tokenize_line
from multiprocessing import Pool
from fairseq.utils import import_user_module
def get_parser():
parser = argparse.ArgumentParser()
......@@ -66,6 +68,8 @@ def get_parser():
def main(args):
import_user_module(args)
print(args)
os.makedirs(args.destdir, exist_ok=True)
target = not args.only_source
......
......@@ -21,9 +21,12 @@ from fairseq import distributed_utils, options, progress_bar, tasks, utils
from fairseq.data import iterators
from fairseq.trainer import Trainer
from fairseq.meters import AverageMeter, StopwatchMeter
from fairseq.utils import import_user_module
def main(args):
import_user_module(args)
if args.max_tokens is None:
args.max_tokens = 6000
print(args)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment