Commit 9b40999e authored by Myle Ott, committed by Facebook Github Bot

Add generic registry mechanism

Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/792

Differential Revision: D15741781

Pulled By: myleott

fbshipit-source-id: c256c7900c307d485904e69b1526b9acbe08fec9
parent 9dc9a486
fairseq/criterions/__init__.py
@@ -8,34 +8,15 @@
 import importlib
 import os
 
-from .fairseq_criterion import FairseqCriterion
-
-CRITERION_REGISTRY = {}
-CRITERION_CLASS_NAMES = set()
-
-
-def build_criterion(args, task):
-    return CRITERION_REGISTRY[args.criterion].build_criterion(args, task)
-
-
-def register_criterion(name):
-    """Decorator to register a new criterion."""
-
-    def register_criterion_cls(cls):
-        if name in CRITERION_REGISTRY:
-            raise ValueError('Cannot register duplicate criterion ({})'.format(name))
-        if not issubclass(cls, FairseqCriterion):
-            raise ValueError('Criterion ({}: {}) must extend FairseqCriterion'.format(name, cls.__name__))
-        if cls.__name__ in CRITERION_CLASS_NAMES:
-            # We use the criterion class name as a unique identifier in
-            # checkpoints, so all criterions must have unique class names.
-            raise ValueError('Cannot register criterion with duplicate class name ({})'.format(cls.__name__))
-        CRITERION_REGISTRY[name] = cls
-        CRITERION_CLASS_NAMES.add(cls.__name__)
-        return cls
-
-    return register_criterion_cls
+from fairseq import registry
+from fairseq.criterions.fairseq_criterion import FairseqCriterion
+
+build_criterion, register_criterion, CRITERION_REGISTRY = registry.setup_registry(
+    '--criterion',
+    base_class=FairseqCriterion,
+    default='cross_entropy',
+)
 
 
 # automatically import any Python files in the criterions/ directory
...
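From a caller's perspective, registration is unchanged: the decorator returned by setup_registry() behaves like the hand-written one it replaces. A minimal sketch (the criterion name and class below are hypothetical, with a simplified FairseqCriterion signature; not part of this commit):

from fairseq.criterions import register_criterion
from fairseq.criterions.fairseq_criterion import FairseqCriterion

@register_criterion('my_dummy_criterion')  # hypothetical name
class MyDummyCriterion(FairseqCriterion):

    def forward(self, model, sample, reduce=True):
        # return (loss, sample_size, logging_output); body elided
        raise NotImplementedError

The duplicate-name, duplicate-class-name, and base-class checks that the deleted code performed inline are now enforced once, by the shared register_x closure in fairseq/registry.py (shown at the end of this diff).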
fairseq/optim/__init__.py
@@ -8,8 +8,9 @@
 import importlib
 import os
 
-from .fairseq_optimizer import FairseqOptimizer
-from .fp16_optimizer import FP16Optimizer, MemoryEfficientFP16Optimizer
+from fairseq import registry
+from fairseq.optim.fairseq_optimizer import FairseqOptimizer
+from fairseq.optim.fp16_optimizer import FP16Optimizer, MemoryEfficientFP16Optimizer
 
 __all__ = [
@@ -19,32 +20,16 @@ __all__ = [
 ]
 
-OPTIMIZER_REGISTRY = {}
-OPTIMIZER_CLASS_NAMES = set()
+_build_optimizer, register_optimizer, OPTIMIZER_REGISTRY = registry.setup_registry(
+    '--optimizer',
+    base_class=FairseqOptimizer,
+    default='nag',
+)
 
 
-def build_optimizer(args, params):
+def build_optimizer(args, params, *extra_args, **extra_kwargs):
     params = list(filter(lambda p: p.requires_grad, params))
-    return OPTIMIZER_REGISTRY[args.optimizer](args, params)
-
-
-def register_optimizer(name):
-    """Decorator to register a new optimizer."""
-
-    def register_optimizer_cls(cls):
-        if name in OPTIMIZER_REGISTRY:
-            raise ValueError('Cannot register duplicate optimizer ({})'.format(name))
-        if not issubclass(cls, FairseqOptimizer):
-            raise ValueError('Optimizer ({}: {}) must extend FairseqOptimizer'.format(name, cls.__name__))
-        if cls.__name__ in OPTIMIZER_CLASS_NAMES:
-            # We use the optimizer class name as a unique identifier in
-            # checkpoints, so all optimizer must have unique class names.
-            raise ValueError('Cannot register optimizer with duplicate class name ({})'.format(cls.__name__))
-        OPTIMIZER_REGISTRY[name] = cls
-        OPTIMIZER_CLASS_NAMES.add(cls.__name__)
-        return cls
-
-    return register_optimizer_cls
+    return _build_optimizer(args, params, *extra_args, **extra_kwargs)
 
 
 # automatically import any Python files in the optim/ directory
...
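Note that build_optimizer stays a thin wrapper rather than being replaced wholesale, so the requires_grad filtering is preserved. A hedged usage sketch; the model is a stand-in, and the exact args fields expected by the 'nag' optimizer (lr, momentum, weight_decay) are an assumption:

import argparse
import torch.nn as nn
from fairseq import optim

args = argparse.Namespace(optimizer='nag', lr=[0.25], momentum=0.99, weight_decay=0.0)
model = nn.Linear(4, 2)           # stand-in model
model.bias.requires_grad_(False)  # frozen params are filtered out before construction

# the registry resolves args.optimizer to a class and instantiates it
# with only the parameters that still require gradients
optimizer = optim.build_optimizer(args, model.parameters())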
fairseq/optim/lr_scheduler/__init__.py
@@ -8,29 +8,15 @@
 import importlib
 import os
 
-from .fairseq_lr_scheduler import FairseqLRScheduler
-
-LR_SCHEDULER_REGISTRY = {}
-
-
-def build_lr_scheduler(args, optimizer):
-    return LR_SCHEDULER_REGISTRY[args.lr_scheduler](args, optimizer)
-
-
-def register_lr_scheduler(name):
-    """Decorator to register a new LR scheduler."""
-
-    def register_lr_scheduler_cls(cls):
-        if name in LR_SCHEDULER_REGISTRY:
-            raise ValueError('Cannot register duplicate LR scheduler ({})'.format(name))
-        if not issubclass(cls, FairseqLRScheduler):
-            raise ValueError('LR Scheduler ({}: {}) must extend FairseqLRScheduler'.format(name, cls.__name__))
-        LR_SCHEDULER_REGISTRY[name] = cls
-        return cls
-
-    return register_lr_scheduler_cls
+from fairseq import registry
+from fairseq.optim.lr_scheduler.fairseq_lr_scheduler import FairseqLRScheduler
+
+build_lr_scheduler, register_lr_scheduler, LR_SCHEDULER_REGISTRY = registry.setup_registry(
+    '--lr-scheduler',
+    base_class=FairseqLRScheduler,
+    default='fixed',
+)
 
 
 # automatically import any Python files in the optim/lr_scheduler/ directory
 for file in os.listdir(os.path.dirname(__file__)):
...
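The same decorator pattern applies here; in addition, any registered class may expose add_args(), which the reworked parse_args_and_arch() in the next file discovers generically instead of via per-registry if-statements. A hypothetical sketch (the scheduler name and the exact FairseqLRScheduler interface are assumptions, not part of this commit):

from fairseq.optim.lr_scheduler import register_lr_scheduler
from fairseq.optim.lr_scheduler.fairseq_lr_scheduler import FairseqLRScheduler

@register_lr_scheduler('my_constant')  # hypothetical name
class MyConstantSchedule(FairseqLRScheduler):

    @staticmethod
    def add_args(parser):
        # added to the parser automatically when --lr-scheduler=my_constant
        parser.add_argument('--my-warmup-updates', type=int, default=0)

    def step_update(self, num_updates):
        # keep the optimizer's current learning rate on every update
        return self.optimizer.get_lr()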
fairseq/options.py
@@ -10,12 +10,7 @@ import argparse
 import torch
 import sys
 
-from fairseq.criterions import CRITERION_REGISTRY
-from fairseq.models import ARCH_MODEL_REGISTRY, ARCH_CONFIG_REGISTRY
-from fairseq.optim import OPTIMIZER_REGISTRY
-from fairseq.optim.lr_scheduler import LR_SCHEDULER_REGISTRY
-from fairseq.tasks import TASK_REGISTRY
-from fairseq.utils import import_user_module
+from fairseq import utils
 
 
 def get_preprocessing_parser(default_task='translation'):
@@ -75,6 +70,8 @@ def eval_bool(x, default=False):
 
 def parse_args_and_arch(parser, input_args=None, parse_known=False):
+    from fairseq.models import ARCH_MODEL_REGISTRY, ARCH_CONFIG_REGISTRY
+
     # The parser doesn't know about model/criterion/optimizer-specific args, so
     # we parse twice. First we parse the model/criterion/optimizer, then we
     # parse a second time after adding the *-specific arguments.
@@ -92,13 +89,15 @@ def parse_args_and_arch(parser, input_args=None, parse_known=False):
         ARCH_MODEL_REGISTRY[args.arch].add_args(model_specific_group)
 
     # Add *-specific args to parser.
-    if hasattr(args, 'criterion'):
-        CRITERION_REGISTRY[args.criterion].add_args(parser)
-    if hasattr(args, 'optimizer'):
-        OPTIMIZER_REGISTRY[args.optimizer].add_args(parser)
-    if hasattr(args, 'lr_scheduler'):
-        LR_SCHEDULER_REGISTRY[args.lr_scheduler].add_args(parser)
+    from fairseq.registry import REGISTRIES
+    for registry_name, REGISTRY in REGISTRIES.items():
+        choice = getattr(args, registry_name, None)
+        if choice is not None:
+            cls = REGISTRY['registry'][choice]
+            if hasattr(cls, 'add_args'):
+                cls.add_args(parser)
     if hasattr(args, 'task'):
+        from fairseq.tasks import TASK_REGISTRY
         TASK_REGISTRY[args.task].add_args(parser)
 
     # Parse a second time.
@@ -130,7 +129,7 @@ def get_parser(desc, default_task='translation'):
     usr_parser = argparse.ArgumentParser(add_help=False, allow_abbrev=False)
     usr_parser.add_argument('--user-dir', default=None)
     usr_args, _ = usr_parser.parse_known_args()
-    import_user_module(usr_args)
+    utils.import_user_module(usr_args)
 
     parser = argparse.ArgumentParser(allow_abbrev=False)
     # fmt: off
@@ -163,7 +162,16 @@ def get_parser(desc, default_task='translation'):
     parser.add_argument('--user-dir', default=None,
                         help='path to a python module containing custom extensions (tasks and/or architectures)')
+
+    from fairseq.registry import REGISTRIES
+    for registry_name, REGISTRY in REGISTRIES.items():
+        parser.add_argument(
+            '--' + registry_name.replace('_', '-'),
+            default=REGISTRY['default'],
+            choices=REGISTRY['registry'].keys(),
+        )
+
     # Task definitions can be found under fairseq/tasks/
+    from fairseq.tasks import TASK_REGISTRY
     parser.add_argument('--task', metavar='TASK', default=default_task,
                         choices=TASK_REGISTRY.keys(),
                         help='task')
@@ -306,20 +314,10 @@ def add_optimization_args(parser):
     group.add_argument('--update-freq', default='1', metavar='N1,N2,...,N_K',
                        type=lambda uf: eval_str_list(uf, type=int),
                        help='update parameters every N_i batches, when in epoch i')
-    # Optimizer definitions can be found under fairseq/optim/
-    group.add_argument('--optimizer', default='nag', metavar='OPT',
-                       choices=OPTIMIZER_REGISTRY.keys(),
-                       help='Optimizer')
     group.add_argument('--lr', '--learning-rate', default='0.25', type=eval_str_list,
                        metavar='LR_1,LR_2,...,LR_N',
                        help='learning rate for the first N epochs; all epochs >N using LR_N'
                            ' (note: this may be interpreted differently depending on --lr-scheduler)')
-    # Learning rate schedulers can be found under fairseq/optim/lr_scheduler/
-    group.add_argument('--lr-scheduler', default='fixed',
-                       choices=LR_SCHEDULER_REGISTRY.keys(),
-                       help='Learning Rate Scheduler')
     group.add_argument('--min-lr', default=-1, type=float, metavar='LR',
                        help='stop training when the learning rate reaches this minimum')
     # fmt: on
@@ -469,13 +467,9 @@ def add_model_args(parser):
     # 1) model defaults (lowest priority)
     # 2) --arch argument
     # 3) --encoder/decoder-* arguments (highest priority)
+    from fairseq.models import ARCH_MODEL_REGISTRY
     group.add_argument('--arch', '-a', default='fconv', metavar='ARCH', required=True,
                        choices=ARCH_MODEL_REGISTRY.keys(),
                        help='Model Architecture')
-    # Criterion definitions can be found under fairseq/criterions/
-    group.add_argument('--criterion', default='cross_entropy', metavar='CRIT',
-                       choices=CRITERION_REGISTRY.keys(),
-                       help='Training Criterion')
     # fmt: on
     return group
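The net effect on the CLI: --criterion, --optimizer, and --lr-scheduler are no longer declared by hand; they fall out of the REGISTRIES loop in get_parser() with their defaults and choices intact. A rough, self-contained sketch of that loop in isolation (the import side effects that populate REGISTRIES are assumed):

import argparse

# importing these packages runs their setup_registry() calls,
# which populate fairseq.registry.REGISTRIES as a side effect
import fairseq.criterions          # registers 'criterion'
import fairseq.optim               # registers 'optimizer'
import fairseq.optim.lr_scheduler  # registers 'lr_scheduler'
from fairseq.registry import REGISTRIES

parser = argparse.ArgumentParser()
for registry_name, REGISTRY in REGISTRIES.items():
    parser.add_argument(
        '--' + registry_name.replace('_', '-'),
        default=REGISTRY['default'],
        choices=REGISTRY['registry'].keys(),
    )

args = parser.parse_args([])
print(args.criterion, args.optimizer, args.lr_scheduler)
# -> cross_entropy nag fixed (the registry defaults)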
fairseq/registry.py (new file)

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.


REGISTRIES = {}


def setup_registry(
    registry_name: str,
    base_class=None,
    default=None,
):
    assert registry_name.startswith('--')
    registry_name = registry_name[2:].replace('-', '_')

    REGISTRY = {}
    REGISTRY_CLASS_NAMES = set()

    # maintain a registry of all registries
    if registry_name in REGISTRIES:
        raise ValueError('Cannot setup duplicate registry: {}'.format(registry_name))
    REGISTRIES[registry_name] = {
        'registry': REGISTRY,
        'default': default,
    }

    def build_x(args, *extra_args, **extra_kwargs):
        choice = getattr(args, registry_name, None)
        if choice is None:
            return None
        cls = REGISTRY[choice]
        # prefer a build_<registry_name> factory method if the class
        # defines one, otherwise instantiate the class directly
        if hasattr(cls, 'build_' + registry_name):
            builder = getattr(cls, 'build_' + registry_name)
        else:
            builder = cls
        return builder(args, *extra_args, **extra_kwargs)

    def register_x(name):

        def register_x_cls(cls):
            if name in REGISTRY:
                raise ValueError('Cannot register duplicate {} ({})'.format(registry_name, name))
            if cls.__name__ in REGISTRY_CLASS_NAMES:
                # class names serve as unique identifiers in checkpoints,
                # so they must be unique within a registry
                raise ValueError(
                    'Cannot register {} with duplicate class name ({})'.format(
                        registry_name, cls.__name__,
                    )
                )
            if base_class is not None and not issubclass(cls, base_class):
                raise ValueError('{} must extend {}'.format(cls.__name__, base_class.__name__))
            REGISTRY[name] = cls
            REGISTRY_CLASS_NAMES.add(cls.__name__)
            return cls

        return register_x_cls

    return build_x, register_x, REGISTRY
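To see how little a brand-new component type now costs, here is a hypothetical example built on setup_registry(); the 'tokenizer' registry and WhitespaceTokenizer class are illustrations, not part of this commit:

from fairseq import registry

# one call yields the builder, the registration decorator, and the dict;
# '--tokenizer' is normalized to the args attribute name 'tokenizer'
build_tokenizer, register_tokenizer, TOKENIZER_REGISTRY = registry.setup_registry(
    '--tokenizer',
    default=None,  # build_tokenizer(args) returns None if --tokenizer is unset
)


@register_tokenizer('whitespace')  # hypothetical name
class WhitespaceTokenizer(object):

    def __init__(self, args):
        self.args = args

    def tokenize(self, line):
        return line.split()

Since no base_class is given, the issubclass check is skipped; and because WhitespaceTokenizer defines no build_tokenizer classmethod, build_x falls back to calling the class itself.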