Commit b15f5f53 authored by Davide Caroselli's avatar Davide Caroselli Committed by Facebook Github Bot
Browse files

New command line option '--user-dir' (#440)

Summary:
Following discussion on official fairseq (https://github.com/pytorch/fairseq/issues/438), I added the `--user-dir` option to the command line. The user can now specify a path in order to import a custom module with proprietary tasks, architectures and so on.
Pull Request resolved: https://github.com/pytorch/fairseq/pull/440

Differential Revision: D13651721

Pulled By: myleott

fbshipit-source-id: 38b87454487f1ffa5eaf19c4bcefa0b3b15a8f43
parent d9284ee7
# JetBrains PyCharm IDE
.idea/
# Byte-compiled / optimized / DLL files # Byte-compiled / optimized / DLL files
__pycache__/ __pycache__/
*.py[cod] *.py[cod]
......
...@@ -14,6 +14,7 @@ from fairseq.models import ARCH_MODEL_REGISTRY, ARCH_CONFIG_REGISTRY ...@@ -14,6 +14,7 @@ from fairseq.models import ARCH_MODEL_REGISTRY, ARCH_CONFIG_REGISTRY
from fairseq.optim import OPTIMIZER_REGISTRY from fairseq.optim import OPTIMIZER_REGISTRY
from fairseq.optim.lr_scheduler import LR_SCHEDULER_REGISTRY from fairseq.optim.lr_scheduler import LR_SCHEDULER_REGISTRY
from fairseq.tasks import TASK_REGISTRY from fairseq.tasks import TASK_REGISTRY
from fairseq.utils import import_user_module
def get_training_parser(default_task='translation'): def get_training_parser(default_task='translation'):
...@@ -121,6 +122,15 @@ def parse_args_and_arch(parser, input_args=None, parse_known=False): ...@@ -121,6 +122,15 @@ def parse_args_and_arch(parser, input_args=None, parse_known=False):
def get_parser(desc, default_task='translation'): def get_parser(desc, default_task='translation'):
# Before creating the true parser, we need to import optional user module
# in order to eagerly import custom tasks, optimizers, architectures, etc.
usr_parser = argparse.ArgumentParser()
usr_parser.add_argument('--user-dir', default=None)
usr_args, _ = usr_parser.parse_known_args()
if usr_args.user_dir is not None:
import_user_module(usr_args.user_dir)
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
# fmt: off # fmt: off
parser.add_argument('--no-progress-bar', action='store_true', help='disable progress bar') parser.add_argument('--no-progress-bar', action='store_true', help='disable progress bar')
...@@ -140,6 +150,8 @@ def get_parser(desc, default_task='translation'): ...@@ -140,6 +150,8 @@ def get_parser(desc, default_task='translation'):
help='number of updates before increasing loss scale') help='number of updates before increasing loss scale')
parser.add_argument('--fp16-scale-tolerance', default=0.0, type=float, parser.add_argument('--fp16-scale-tolerance', default=0.0, type=float,
help='pct of updates that can overflow before decreasing the loss scale') help='pct of updates that can overflow before decreasing the loss scale')
parser.add_argument('--user-dir', default=None,
help='path to a python module containing custom extensions (tasks and/or architectures)')
# Task definitions can be found under fairseq/tasks/ # Task definitions can be found under fairseq/tasks/
parser.add_argument('--task', metavar='TASK', default=default_task, parser.add_argument('--task', metavar='TASK', default=default_task,
......
...@@ -4,14 +4,15 @@ ...@@ -4,14 +4,15 @@
# This source code is licensed under the license found in the LICENSE file in # This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights # the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory. # can be found in the PATENTS file in the same directory.
import importlib.util
from collections import defaultdict, OrderedDict
import logging import logging
import os import os
import re import re
import torch import sys
import traceback import traceback
from collections import defaultdict, OrderedDict
import torch
from torch.serialization import default_restore_location from torch.serialization import default_restore_location
...@@ -434,3 +435,12 @@ def resolve_max_positions(*args): ...@@ -434,3 +435,12 @@ def resolve_max_positions(*args):
map(nullsafe_min, zip(max_positions, arg)) map(nullsafe_min, zip(max_positions, arg))
) )
return max_positions return max_positions
def import_user_module(module_path):
module_path = os.path.abspath(module_path)
module_parent, module_name = os.path.split(module_path)
sys.path.insert(0, module_parent)
importlib.import_module(module_name)
sys.path.pop(0)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment