Commit 9b40999e authored by Myle Ott, committed by Facebook Github Bot

Add generic registry mechanism

Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/792

Differential Revision: D15741781

Pulled By: myleott

fbshipit-source-id: c256c7900c307d485904e69b1526b9acbe08fec9
parent 9dc9a486
fairseq/criterions/__init__.py
@@ -8,34 +8,15 @@
import importlib
import os
from .fairseq_criterion import FairseqCriterion
from fairseq import registry
from fairseq.criterions.fairseq_criterion import FairseqCriterion
CRITERION_REGISTRY = {}
CRITERION_CLASS_NAMES = set()
def build_criterion(args, task):
    return CRITERION_REGISTRY[args.criterion].build_criterion(args, task)


def register_criterion(name):
    """Decorator to register a new criterion."""

    def register_criterion_cls(cls):
        if name in CRITERION_REGISTRY:
            raise ValueError('Cannot register duplicate criterion ({})'.format(name))
        if not issubclass(cls, FairseqCriterion):
            raise ValueError('Criterion ({}: {}) must extend FairseqCriterion'.format(name, cls.__name__))
        if cls.__name__ in CRITERION_CLASS_NAMES:
            # We use the criterion class name as a unique identifier in
            # checkpoints, so all criterions must have unique class names.
            raise ValueError('Cannot register criterion with duplicate class name ({})'.format(cls.__name__))
        CRITERION_REGISTRY[name] = cls
        CRITERION_CLASS_NAMES.add(cls.__name__)
        return cls

    return register_criterion_cls


build_criterion, register_criterion, CRITERION_REGISTRY = registry.setup_registry(
    '--criterion',
    base_class=FairseqCriterion,
    default='cross_entropy',
)
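
# Example (hypothetical, not part of this commit): registering a criterion is
# unchanged for callers; the decorator simply comes from the shared registry
# helper now. The FairseqCriterion(args, task) constructor and forward()
# contract are assumed from fairseq/criterions/fairseq_criterion.py.
@register_criterion('my_smoothed_cross_entropy')
class MySmoothedCrossEntropy(FairseqCriterion):

    @staticmethod
    def add_args(parser):
        # picked up by options.parse_args_and_arch() through REGISTRIES
        parser.add_argument('--my-smoothing', type=float, default=0.1,
                            help='label smoothing value (hypothetical flag)')

    def forward(self, model, sample, reduce=True):
        # the actual loss computation is omitted in this sketch
        raise NotImplementedError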
# automatically import any Python files in the criterions/ directory
......
fairseq/optim/__init__.py
@@ -8,8 +8,9 @@
import importlib
import os
from .fairseq_optimizer import FairseqOptimizer
from .fp16_optimizer import FP16Optimizer, MemoryEfficientFP16Optimizer
from fairseq import registry
from fairseq.optim.fairseq_optimizer import FairseqOptimizer
from fairseq.optim.fp16_optimizer import FP16Optimizer, MemoryEfficientFP16Optimizer
__all__ = [
@@ -19,32 +20,16 @@ __all__ = [
]
OPTIMIZER_REGISTRY = {}
OPTIMIZER_CLASS_NAMES = set()
_build_optimizer, register_optimizer, OPTIMIZER_REGISTRY = registry.setup_registry(
    '--optimizer',
    base_class=FairseqOptimizer,
    default='nag',
)


def build_optimizer(args, params):
def build_optimizer(args, params, *extra_args, **extra_kwargs):
    params = list(filter(lambda p: p.requires_grad, params))
    return OPTIMIZER_REGISTRY[args.optimizer](args, params)


def register_optimizer(name):
    """Decorator to register a new optimizer."""

    def register_optimizer_cls(cls):
        if name in OPTIMIZER_REGISTRY:
            raise ValueError('Cannot register duplicate optimizer ({})'.format(name))
        if not issubclass(cls, FairseqOptimizer):
            raise ValueError('Optimizer ({}: {}) must extend FairseqOptimizer'.format(name, cls.__name__))
        if cls.__name__ in OPTIMIZER_CLASS_NAMES:
            # We use the optimizer class name as a unique identifier in
            # checkpoints, so all optimizers must have unique class names.
            raise ValueError('Cannot register optimizer with duplicate class name ({})'.format(cls.__name__))
        OPTIMIZER_REGISTRY[name] = cls
        OPTIMIZER_CLASS_NAMES.add(cls.__name__)
        return cls

    return register_optimizer_cls
    return _build_optimizer(args, params, *extra_args, **extra_kwargs)
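
# Example (hypothetical, not part of this commit): an optimizer registers
# itself exactly as before. The FairseqOptimizer(args, params) constructor and
# the optimizer_config property are assumptions based on
# fairseq/optim/fairseq_optimizer.py.
import torch.optim


@register_optimizer('plain_sgd')
class PlainSGD(FairseqOptimizer):

    def __init__(self, args, params):
        super().__init__(args, params)
        # params were already filtered to requires_grad=True by build_optimizer()
        self._optimizer = torch.optim.SGD(params, **self.optimizer_config)

    @property
    def optimizer_config(self):
        return {'lr': self.args.lr[0]}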
# automatically import any Python files in the optim/ directory
......
fairseq/optim/lr_scheduler/__init__.py
@@ -8,29 +8,15 @@
import importlib
import os
from .fairseq_lr_scheduler import FairseqLRScheduler
from fairseq import registry
from fairseq.optim.lr_scheduler.fairseq_lr_scheduler import FairseqLRScheduler
LR_SCHEDULER_REGISTRY = {}
def build_lr_scheduler(args, optimizer):
    return LR_SCHEDULER_REGISTRY[args.lr_scheduler](args, optimizer)


def register_lr_scheduler(name):
    """Decorator to register a new LR scheduler."""

    def register_lr_scheduler_cls(cls):
        if name in LR_SCHEDULER_REGISTRY:
            raise ValueError('Cannot register duplicate LR scheduler ({})'.format(name))
        if not issubclass(cls, FairseqLRScheduler):
            raise ValueError('LR Scheduler ({}: {}) must extend FairseqLRScheduler'.format(name, cls.__name__))
        LR_SCHEDULER_REGISTRY[name] = cls
        return cls

    return register_lr_scheduler_cls


build_lr_scheduler, register_lr_scheduler, LR_SCHEDULER_REGISTRY = registry.setup_registry(
    '--lr-scheduler',
    base_class=FairseqLRScheduler,
    default='fixed',
)
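
# Example (hypothetical, not part of this commit): schedulers register the same
# way; the issubclass(FairseqLRScheduler) check that used to live here is now
# enforced by setup_registry(base_class=...). The FairseqLRScheduler(args,
# optimizer) constructor and the step_update() hook are assumptions based on
# fairseq_lr_scheduler.py.
@register_lr_scheduler('constant')
class ConstantSchedule(FairseqLRScheduler):
    """Keep the learning rate at its initial value for the whole run."""

    def step_update(self, num_updates):
        # called after every optimizer update; leave the LR untouched
        return self.optimizer.get_lr()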
# automatically import any Python files in the optim/lr_scheduler/ directory
for file in os.listdir(os.path.dirname(__file__)):
......
fairseq/options.py
@@ -10,12 +10,7 @@ import argparse
import torch
import sys
from fairseq.criterions import CRITERION_REGISTRY
from fairseq.models import ARCH_MODEL_REGISTRY, ARCH_CONFIG_REGISTRY
from fairseq.optim import OPTIMIZER_REGISTRY
from fairseq.optim.lr_scheduler import LR_SCHEDULER_REGISTRY
from fairseq.tasks import TASK_REGISTRY
from fairseq.utils import import_user_module
from fairseq import utils
def get_preprocessing_parser(default_task='translation'):
@@ -75,6 +70,8 @@ def eval_bool(x, default=False):
def parse_args_and_arch(parser, input_args=None, parse_known=False):
    from fairseq.models import ARCH_MODEL_REGISTRY, ARCH_CONFIG_REGISTRY

    # The parser doesn't know about model/criterion/optimizer-specific args, so
    # we parse twice. First we parse the model/criterion/optimizer, then we
    # parse a second time after adding the *-specific arguments.
@@ -92,13 +89,15 @@ def parse_args_and_arch(parser, input_args=None, parse_known=False):
        ARCH_MODEL_REGISTRY[args.arch].add_args(model_specific_group)

    # Add *-specific args to parser.
    if hasattr(args, 'criterion'):
        CRITERION_REGISTRY[args.criterion].add_args(parser)
    if hasattr(args, 'optimizer'):
        OPTIMIZER_REGISTRY[args.optimizer].add_args(parser)
    if hasattr(args, 'lr_scheduler'):
        LR_SCHEDULER_REGISTRY[args.lr_scheduler].add_args(parser)
    from fairseq.registry import REGISTRIES
    for registry_name, REGISTRY in REGISTRIES.items():
        choice = getattr(args, registry_name, None)
        if choice is not None:
            cls = REGISTRY['registry'][choice]
            if hasattr(cls, 'add_args'):
                cls.add_args(parser)
    if hasattr(args, 'task'):
        from fairseq.tasks import TASK_REGISTRY
        TASK_REGISTRY[args.task].add_args(parser)
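    # Note: the generic loop above replaces the hard-coded
    # hasattr(args, 'criterion'/'optimizer'/'lr_scheduler') blocks removed in
    # this hunk. For every registry created via registry.setup_registry(), it
    # looks up the selected class (e.g.
    # REGISTRIES['criterion']['registry'][args.criterion]) and lets that class
    # add its own command-line arguments, so a new registry needs no further
    # edits in this function.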
# Parse a second time.
@@ -130,7 +129,7 @@ def get_parser(desc, default_task='translation'):
    usr_parser = argparse.ArgumentParser(add_help=False, allow_abbrev=False)
    usr_parser.add_argument('--user-dir', default=None)
    usr_args, _ = usr_parser.parse_known_args()
    import_user_module(usr_args)
    utils.import_user_module(usr_args)

    parser = argparse.ArgumentParser(allow_abbrev=False)
    # fmt: off
@@ -163,7 +162,16 @@ def get_parser(desc, default_task='translation'):
    parser.add_argument('--user-dir', default=None,
                        help='path to a python module containing custom extensions (tasks and/or architectures)')

    from fairseq.registry import REGISTRIES
    for registry_name, REGISTRY in REGISTRIES.items():
        parser.add_argument(
            '--' + registry_name.replace('_', '-'),
            default=REGISTRY['default'],
            choices=REGISTRY['registry'].keys(),
        )

    # Task definitions can be found under fairseq/tasks/
    from fairseq.tasks import TASK_REGISTRY
    parser.add_argument('--task', metavar='TASK', default=default_task,
                        choices=TASK_REGISTRY.keys(),
                        help='task')
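    # The REGISTRIES loop above is roughly equivalent to the hand-written flags
    # removed further down in this diff, e.g.:
    #   parser.add_argument('--criterion', default='cross_entropy',
    #                       choices=CRITERION_REGISTRY.keys())
    #   parser.add_argument('--optimizer', default='nag',
    #                       choices=OPTIMIZER_REGISTRY.keys())
    #   parser.add_argument('--lr-scheduler', default='fixed',
    #                       choices=LR_SCHEDULER_REGISTRY.keys())
    # except that the flag names, defaults and choices are now derived from
    # REGISTRIES, so any new registry gets its command-line flag for free.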
@@ -306,20 +314,10 @@ def add_optimization_args(parser):
    group.add_argument('--update-freq', default='1', metavar='N1,N2,...,N_K',
                       type=lambda uf: eval_str_list(uf, type=int),
                       help='update parameters every N_i batches, when in epoch i')

    # Optimizer definitions can be found under fairseq/optim/
    group.add_argument('--optimizer', default='nag', metavar='OPT',
                       choices=OPTIMIZER_REGISTRY.keys(),
                       help='Optimizer')

    group.add_argument('--lr', '--learning-rate', default='0.25', type=eval_str_list,
                       metavar='LR_1,LR_2,...,LR_N',
                       help='learning rate for the first N epochs; all epochs >N using LR_N'
                            ' (note: this may be interpreted differently depending on --lr-scheduler)')

    # Learning rate schedulers can be found under fairseq/optim/lr_scheduler/
    group.add_argument('--lr-scheduler', default='fixed',
                       choices=LR_SCHEDULER_REGISTRY.keys(),
                       help='Learning Rate Scheduler')

    group.add_argument('--min-lr', default=-1, type=float, metavar='LR',
                       help='stop training when the learning rate reaches this minimum')
    # fmt: on
@@ -469,13 +467,9 @@ def add_model_args(parser):
    # 1) model defaults (lowest priority)
    # 2) --arch argument
    # 3) --encoder/decoder-* arguments (highest priority)
    from fairseq.models import ARCH_MODEL_REGISTRY
    group.add_argument('--arch', '-a', default='fconv', metavar='ARCH', required=True,
                       choices=ARCH_MODEL_REGISTRY.keys(),
                       help='Model Architecture')

    # Criterion definitions can be found under fairseq/criterions/
    group.add_argument('--criterion', default='cross_entropy', metavar='CRIT',
                       choices=CRITERION_REGISTRY.keys(),
                       help='Training Criterion')
    # fmt: on
    return group
fairseq/registry.py
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
REGISTRIES = {}
def setup_registry(
    registry_name: str,
    base_class=None,
    default=None,
):
    assert registry_name.startswith('--')
    registry_name = registry_name[2:].replace('-', '_')

    REGISTRY = {}
    REGISTRY_CLASS_NAMES = set()

    # maintain a registry of all registries
    if registry_name in REGISTRIES:
        raise ValueError('Cannot setup duplicate registry: {}'.format(registry_name))
    REGISTRIES[registry_name] = {
        'registry': REGISTRY,
        'default': default,
    }

    def build_x(args, *extra_args, **extra_kwargs):
        choice = getattr(args, registry_name, None)
        if choice is None:
            return None
        cls = REGISTRY[choice]
        if hasattr(cls, 'build_' + registry_name):
            builder = getattr(cls, 'build_' + registry_name)
        else:
            builder = cls
        return builder(args, *extra_args, **extra_kwargs)

    def register_x(name):

        def register_x_cls(cls):
            if name in REGISTRY:
                raise ValueError('Cannot register duplicate {} ({})'.format(registry_name, name))
            if cls.__name__ in REGISTRY_CLASS_NAMES:
                raise ValueError(
                    'Cannot register {} with duplicate class name ({})'.format(
                        registry_name, cls.__name__,
                    )
                )
            if base_class is not None and not issubclass(cls, base_class):
                raise ValueError('{} must extend {}'.format(cls.__name__, base_class.__name__))
            REGISTRY[name] = cls
            REGISTRY_CLASS_NAMES.add(cls.__name__)
            return cls

        return register_x_cls

    return build_x, register_x, REGISTRY
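
# Usage sketch (hypothetical, not part of this commit): any module, for example
# a plugin loaded via --user-dir, can create its own registry with the helper
# above. All names below (--tokenizer, 'space', SpaceTokenizer) are invented
# for illustration.
from fairseq import registry

build_tokenizer, register_tokenizer, TOKENIZER_REGISTRY = registry.setup_registry(
    '--tokenizer',
    default='space',  # no base_class, so any class may be registered
)


@register_tokenizer('space')
class SpaceTokenizer(object):

    def __init__(self, args):
        self.args = args

    def encode(self, line):
        return line.split()


# Because get_parser() imports --user-dir modules before iterating over
# REGISTRIES, it would then expose a --tokenizer flag automatically (default
# 'space', choices taken from TOKENIZER_REGISTRY), and build_tokenizer(args)
# would instantiate the selected class.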