Commit eea4d20b authored by Myle Ott, committed by Facebook Github Bot

Automatically fill in default values from add_args

Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/797

Differential Revision: D15761071

Pulled By: myleott

fbshipit-source-id: 257d4a2297e83da7e59baed154dbafd6bfe614bf
parent 1b937bb2
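In practice this means checkpoints saved before an option existed can be loaded without per-option getattr guards: while the state dict is upgraded, any attribute missing from state['args'] is backfilled with the default declared in the relevant add_args method. A minimal consumer-side sketch (the checkpoint filename is illustrative, and it assumes the saved arch resolves to the transformer_lm model):

# Hypothetical: '--no-decoder-final-norm' did not exist when this checkpoint
# was trained, so the saved args namespace lacks the attribute.
state = load_checkpoint_to_cpu('checkpoint_last.pt')

# Previously callers needed a guard such as
#     getattr(state['args'], 'no_decoder_final_norm', False)
# Now the attribute is filled in from the parser default (False):
print(state['args'].no_decoder_final_norm)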
@@ -5,6 +5,7 @@
 # the root directory of this source tree. An additional grant of patent rights
 # can be found in the PATENTS file in the same directory.
 
+import argparse
 from collections import OrderedDict
 from typing import Union
 import collections
@@ -17,16 +18,16 @@ import shutil
 
 import torch
 from torch.serialization import default_restore_location
 
-from fairseq import tasks, distributed_utils
-from fairseq.models import FairseqEncoder, FairseqDecoder
-from fairseq.meters import StopwatchMeter
-
 
 def save_checkpoint(args, trainer, epoch_itr, val_loss):
+    from fairseq import distributed_utils, meters
+
     if args.no_save or not distributed_utils.is_master(args):
         return
 
-    write_timer = StopwatchMeter()
+    write_timer = meters.StopwatchMeter()
     write_timer.start()
 
     epoch = epoch_itr.epoch
@@ -127,11 +128,15 @@ def load_checkpoint(args, trainer):
     return extra_state, epoch_itr
 
 
-def load_checkpoint_to_cpu(path):
+def load_checkpoint_to_cpu(path, arg_overrides=None):
     """Loads a checkpoint to CPU (with upgrading for backward compatibility)."""
     state = torch.load(
         path, map_location=lambda s, l: default_restore_location(s, 'cpu'),
     )
+    args = state['args']
+    if arg_overrides is not None:
+        for arg_name, arg_val in arg_overrides.items():
+            setattr(args, arg_name, arg_val)
     state = _upgrade_state_dict(state)
     return state
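For reference, a sketch of how the new arg_overrides parameter might be used; the path and override value are illustrative, and the overrides are applied to state['args'] before _upgrade_state_dict fills in any remaining defaults:

# Hypothetical usage: force raw-text loading regardless of what was stored
# in the checkpoint's args.
state = load_checkpoint_to_cpu(
    'checkpoint_best.pt',
    arg_overrides={'raw_text': True},
)
assert state['args'].raw_text is True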
@@ -145,17 +150,20 @@ def load_model_ensemble(filenames, arg_overrides=None, task=None):
             were used during model training
         task (fairseq.tasks.FairseqTask, optional): task to use for loading
     """
+    ensemble, args, _task = _load_model_ensemble(filenames, arg_overrides, task)
+    return ensemble, args
+
+
+def _load_model_ensemble(filenames, arg_overrides=None, task=None):
+    from fairseq import tasks
+
     ensemble = []
     for filename in filenames:
         if not os.path.exists(filename):
            raise IOError('Model file not found: {}'.format(filename))
-        state = load_checkpoint_to_cpu(filename)
+        state = load_checkpoint_to_cpu(filename, arg_overrides)
         args = state['args']
-        if arg_overrides is not None:
-            for arg_name, arg_val in arg_overrides.items():
-                setattr(args, arg_name, arg_val)
 
         if task is None:
             task = tasks.setup_task(args)
@@ -163,8 +171,7 @@ def load_model_ensemble(filenames, arg_overrides=None, task=None):
         model = task.build_model(args)
         model.load_state_dict(state['model'], strict=True)
         ensemble.append(model)
-
-    return ensemble, args
+    return ensemble, args, task
 
 
 def checkpoint_paths(path, pattern=r'checkpoint(\d+)\.pt'):
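The public load_model_ensemble keeps its (models, args) return value, while the new private _load_model_ensemble also returns the task it built, presumably so internal callers can reuse it. A usage sketch with illustrative filenames and an illustrative data path:

# Hypothetical: load two checkpoints as an ensemble, pointing their stored
# 'data' argument at a local copy of the binarized dataset.
models, args = load_model_ensemble(
    ['model1.pt', 'model2.pt'],
    arg_overrides={'data': '/path/to/data-bin'},
)
for model in models:
    model.eval()  # switch to evaluation mode before generation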
@@ -236,6 +243,8 @@ def save_state(
 
 def _upgrade_state_dict(state):
     """Helper for upgrading old model checkpoints."""
+    from fairseq import models, registry, tasks
+
     # add optimizer_history
     if 'optimizer_history' not in state:
         state['optimizer_history'] = [
@@ -284,6 +293,35 @@ def _upgrade_state_dict(state):
             'epoch': state['extra_state']['epoch'],
             'iterations_in_epoch': state['extra_state'].get('batch_offset', 0),
         }
+    # default to translation task
+    if not hasattr(state['args'], 'task'):
+        state['args'].task = 'translation'
+
+    def set_defaults(cls):
+        if not hasattr(cls, 'add_args'):
+            return
+        parser = argparse.ArgumentParser(argument_default=argparse.SUPPRESS, allow_abbrev=False)
+        cls.add_args(parser)
+        # copied from argparse.py:
+        defaults = argparse.Namespace()
+        for action in parser._actions:
+            if action.dest is not argparse.SUPPRESS:
+                if not hasattr(defaults, action.dest):
+                    if action.default is not argparse.SUPPRESS:
+                        setattr(defaults, action.dest, action.default)
+        for key, default_value in vars(defaults).items():
+            if not hasattr(state['args'], key):
+                setattr(state['args'], key, default_value)
+
+    # set any missing default values in the task, model or other registries
+    set_defaults(tasks.TASK_REGISTRY[state['args'].task])
+    set_defaults(models.ARCH_MODEL_REGISTRY[state['args'].arch])
+    for registry_name, REGISTRY in registry.REGISTRIES.items():
+        choice = getattr(state['args'], registry_name, None)
+        if choice is not None:
+            cls = REGISTRY['registry'][choice]
+            set_defaults(cls)
 
     return state
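set_defaults leans on argparse's SUPPRESS sentinel: building the parser with argument_default=argparse.SUPPRESS means only options whose add_argument call declares an explicit default survive the filter over parser._actions, and exactly those defaults are copied onto state['args'] when the attribute is missing. A standalone sketch of the same technique, with made-up option names:

import argparse

parser = argparse.ArgumentParser(argument_default=argparse.SUPPRESS, allow_abbrev=False)
parser.add_argument('--dropout', type=float, default=0.1)
parser.add_argument('--optimizer', default='adam')
parser.add_argument('--seed', type=int)  # no explicit default -> contributes nothing

defaults = argparse.Namespace()
for action in parser._actions:
    # skip the built-in help action and anything without an explicit default
    if action.dest is not argparse.SUPPRESS and action.default is not argparse.SUPPRESS:
        if not hasattr(defaults, action.dest):
            setattr(defaults, action.dest, action.default)

print(vars(defaults))  # {'dropout': 0.1, 'optimizer': 'adam'}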
@@ -25,6 +25,7 @@ DEFAULT_MAX_TARGET_POSITIONS = 1024
 
 @register_model('transformer_lm')
 class TransformerLanguageModel(FairseqLanguageModel):
+
     def __init__(self, decoder):
         super().__init__(decoder)
@@ -55,7 +56,7 @@ class TransformerLanguageModel(FairseqLanguageModel):
                             help='num decoder attention heads')
         parser.add_argument('--decoder-normalize-before', default=False, action='store_true',
                             help='apply layernorm before each decoder block')
-        parser.add_argument('--no-decoder-final-norm', action='store_true',
+        parser.add_argument('--no-decoder-final-norm', default=False, action='store_true',
                             help='don\'t add an extra layernorm after the last decoder block')
         parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR',
                             help='comma separated list of adaptive softmax cutoff points. '
@@ -73,9 +74,9 @@ class TransformerLanguageModel(FairseqLanguageModel):
         parser.add_argument('--character-filters', type=str, metavar='LIST',
                             default='[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]',
                             help='size of character embeddings')
-        parser.add_argument('--character-embedding-dim', type=int, metavar='N', default=4,
+        parser.add_argument('--character-embedding-dim', default=4, type=int, metavar='N',
                             help='size of character embeddings')
-        parser.add_argument('--char-embedder-highway-layers', type=int, metavar='N', default=2,
+        parser.add_argument('--char-embedder-highway-layers', default=2, type=int, metavar='N',
                             help='number of highway layers for character token embeddder')
         parser.add_argument('--adaptive-input', action='store_true',
                             help='if set, uses adaptive input')
@@ -132,12 +133,14 @@ class TransformerLanguageModel(FairseqLanguageModel):
 
 @register_model_architecture('transformer_lm', 'transformer_lm')
 def base_lm_architecture(args):
     # backward compatibility for older model checkpoints
-    if hasattr(args, 'no_tie_adaptive_proj') and args.no_tie_adaptive_proj is False:
-        args.tie_adaptive_proj = True
-    if not hasattr(args, 'no_decoder_final_norm'):
-        args.no_decoder_final_norm = True  # old models always set this to True
+    if hasattr(args, 'no_tie_adaptive_proj'):
+        # previous models defined --no-tie-adaptive-proj, so use the existence of
+        # that option to determine if this is an "old" model checkpoint
+        args.no_decoder_final_norm = True  # old models always set this to True
+        if args.no_tie_adaptive_proj is False:
+            args.tie_adaptive_proj = True
+    if hasattr(args, 'decoder_final_norm'):
+        args.no_decoder_final_norm = not args.decoder_final_norm
 
     args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 512)
     args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', 2048)
@@ -79,7 +79,7 @@ class CosineSchedule(FairseqLRScheduler):
                             help='warmup the learning rate linearly for the first N updates')
         parser.add_argument('--warmup-init-lr', default=-1, type=float, metavar='LR',
                             help='initial learning rate during warmup phase; default is args.lr')
-        parser.add_argument('--max-lr', required=True, type=float, metavar='LR',
+        parser.add_argument('--max-lr', type=float, metavar='LR',
                             help='max learning rate, must be more than args.lr')
         parser.add_argument('--t-mult', default=1, type=float, metavar='LR',
                             help='factor to grow the length of each period')
@@ -111,7 +111,7 @@ class TranslationTask(FairseqTask):
                             help='target language')
         parser.add_argument('--lazy-load', action='store_true',
                             help='load the dataset lazily')
-        parser.add_argument('--raw-text', default=False, action='store_true',
+        parser.add_argument('--raw-text', action='store_true',
                             help='load raw text dataset')
         parser.add_argument('--left-pad-source', default='True', type=str, metavar='BOOL',
                             help='pad the source on the left')
@@ -51,9 +51,9 @@ class TranslationMoETask(TranslationTask):
         """Add task-specific arguments to the parser."""
         # fmt: off
         TranslationTask.add_args(parser)
-        parser.add_argument('--method', required=True,
+        parser.add_argument('--method', default='hMoEup',
                             choices=['sMoElp', 'sMoEup', 'hMoElp', 'hMoEup'])
-        parser.add_argument('--num-experts', type=int, metavar='N', required=True,
+        parser.add_argument('--num-experts', default=3, type=int, metavar='N',
                             help='number of experts')
         parser.add_argument('--mean-pool-gating-network', action='store_true',
                             help='use a simple mean-pooling gating network')