"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "a4da216125c29001dc33fefd37f959231bc4bf0a"
Commit eea4d20b authored by Myle Ott, committed by Facebook Github Bot

Automatically fill in default values from add_args

Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/797

Differential Revision: D15761071

Pulled By: myleott

fbshipit-source-id: 257d4a2297e83da7e59baed154dbafd6bfe614bf
parent 1b937bb2
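
In short, checkpoint loading becomes tolerant of options added after a model was saved: load_checkpoint_to_cpu() now applies arg_overrides to the stored args before upgrading, and _upgrade_state_dict() harvests the declared defaults from the registered task's, architecture's, and other registry components' add_args() and copies any missing ones onto state['args']. A minimal usage sketch of the resulting behavior; the checkpoint path and the override value below are made up, not part of this commit:

from fairseq import checkpoint_utils

# Load an ensemble from an older checkpoint. Options introduced after the
# checkpoint was written are filled in from the components' add_args()
# defaults, and arg_overrides is applied to the stored args before upgrading.
models, args = checkpoint_utils.load_model_ensemble(
    ['checkpoints/checkpoint_best.pt'],   # hypothetical path
    arg_overrides={'raw_text': True},     # example override; keys are argument dest names
)
print(args.task)  # now always present; very old checkpoints fall back to 'translation'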
@@ -5,6 +5,7 @@
 # the root directory of this source tree. An additional grant of patent rights
 # can be found in the PATENTS file in the same directory.
+import argparse
 from collections import OrderedDict
 from typing import Union
 import collections
@@ -17,16 +18,16 @@ import shutil
 import torch
 from torch.serialization import default_restore_location
-from fairseq import tasks, distributed_utils
 from fairseq.models import FairseqEncoder, FairseqDecoder
-from fairseq.meters import StopwatchMeter
 def save_checkpoint(args, trainer, epoch_itr, val_loss):
+    from fairseq import distributed_utils, meters
     if args.no_save or not distributed_utils.is_master(args):
         return
-    write_timer = StopwatchMeter()
+    write_timer = meters.StopwatchMeter()
     write_timer.start()
     epoch = epoch_itr.epoch
@@ -127,11 +128,15 @@ def load_checkpoint(args, trainer):
     return extra_state, epoch_itr
-def load_checkpoint_to_cpu(path):
+def load_checkpoint_to_cpu(path, arg_overrides=None):
     """Loads a checkpoint to CPU (with upgrading for backward compatibility)."""
     state = torch.load(
         path, map_location=lambda s, l: default_restore_location(s, 'cpu'),
     )
+    args = state['args']
+    if arg_overrides is not None:
+        for arg_name, arg_val in arg_overrides.items():
+            setattr(args, arg_name, arg_val)
     state = _upgrade_state_dict(state)
     return state
@@ -145,17 +150,20 @@ def load_model_ensemble(filenames, arg_overrides=None, task=None):
             were used during model training
         task (fairseq.tasks.FairseqTask, optional): task to use for loading
     """
+    ensemble, args, _task = _load_model_ensemble(filenames, arg_overrides, task)
+    return ensemble, args
+def _load_model_ensemble(filenames, arg_overrides=None, task=None):
+    from fairseq import tasks
     ensemble = []
     for filename in filenames:
         if not os.path.exists(filename):
             raise IOError('Model file not found: {}'.format(filename))
-        state = load_checkpoint_to_cpu(filename)
+        state = load_checkpoint_to_cpu(filename, arg_overrides)
         args = state['args']
-        if arg_overrides is not None:
-            for arg_name, arg_val in arg_overrides.items():
-                setattr(args, arg_name, arg_val)
         if task is None:
             task = tasks.setup_task(args)
@@ -163,8 +171,7 @@ def load_model_ensemble(filenames, arg_overrides=None, task=None):
         model = task.build_model(args)
         model.load_state_dict(state['model'], strict=True)
         ensemble.append(model)
-    return ensemble, args
+    return ensemble, args, task
 def checkpoint_paths(path, pattern=r'checkpoint(\d+)\.pt'):
@@ -236,6 +243,8 @@ def save_state(
 def _upgrade_state_dict(state):
     """Helper for upgrading old model checkpoints."""
+    from fairseq import models, registry, tasks
     # add optimizer_history
     if 'optimizer_history' not in state:
         state['optimizer_history'] = [
@@ -284,6 +293,35 @@ def _upgrade_state_dict(state):
             'epoch': state['extra_state']['epoch'],
             'iterations_in_epoch': state['extra_state'].get('batch_offset', 0),
         }
+    # default to translation task
+    if not hasattr(state['args'], 'task'):
+        state['args'].task = 'translation'
+    def set_defaults(cls):
+        if not hasattr(cls, 'add_args'):
+            return
+        parser = argparse.ArgumentParser(argument_default=argparse.SUPPRESS, allow_abbrev=False)
+        cls.add_args(parser)
+        # copied from argparse.py:
+        defaults = argparse.Namespace()
+        for action in parser._actions:
+            if action.dest is not argparse.SUPPRESS:
+                if not hasattr(defaults, action.dest):
+                    if action.default is not argparse.SUPPRESS:
+                        setattr(defaults, action.dest, action.default)
+        for key, default_value in vars(defaults).items():
+            if not hasattr(state['args'], key):
+                setattr(state['args'], key, default_value)
+    # set any missing default values in the task, model or other registries
+    set_defaults(tasks.TASK_REGISTRY[state['args'].task])
+    set_defaults(models.ARCH_MODEL_REGISTRY[state['args'].arch])
+    for registry_name, REGISTRY in registry.REGISTRIES.items():
+        choice = getattr(state['args'], registry_name, None)
+        if choice is not None:
+            cls = REGISTRY['registry'][choice]
+            set_defaults(cls)
     return state
...
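
The set_defaults() helper above relies on a standard argparse pattern: build a throwaway parser with argument_default=argparse.SUPPRESS, let the component's add_args() populate it, then read each action's own declared default back from parser._actions. Options that declare no explicit default stay suppressed and are not filled in, which is why the remaining hunks in this commit add explicit defaults to (or drop required=True from) several add_argument() calls. A self-contained sketch of that pattern; DummyComponent is illustrative, not a fairseq class:

import argparse

class DummyComponent:
    # Stand-in for a fairseq task/model/criterion exposing add_args() (illustrative only).
    @staticmethod
    def add_args(parser):
        parser.add_argument('--dropout', type=float, default=0.1)
        parser.add_argument('--activation-fn', default='relu')
        parser.add_argument('--no-final-norm', action='store_true')  # no explicit default

# argument_default=SUPPRESS keeps the parser from inventing values, but each
# action still carries the default its own add_argument() call declared.
parser = argparse.ArgumentParser(argument_default=argparse.SUPPRESS, allow_abbrev=False)
DummyComponent.add_args(parser)

defaults = {
    action.dest: action.default
    for action in parser._actions
    if action.dest is not argparse.SUPPRESS and action.default is not argparse.SUPPRESS
}
print(defaults)  # {'dropout': 0.1, 'activation_fn': 'relu'}; the bare store_true flag is skipped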
@@ -25,6 +25,7 @@ DEFAULT_MAX_TARGET_POSITIONS = 1024
 @register_model('transformer_lm')
 class TransformerLanguageModel(FairseqLanguageModel):
     def __init__(self, decoder):
         super().__init__(decoder)
@@ -55,7 +56,7 @@ class TransformerLanguageModel(FairseqLanguageModel):
                             help='num decoder attention heads')
         parser.add_argument('--decoder-normalize-before', default=False, action='store_true',
                             help='apply layernorm before each decoder block')
-        parser.add_argument('--no-decoder-final-norm', action='store_true',
+        parser.add_argument('--no-decoder-final-norm', default=False, action='store_true',
                             help='don\'t add an extra layernorm after the last decoder block')
         parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR',
                             help='comma separated list of adaptive softmax cutoff points. '
@@ -73,9 +74,9 @@ class TransformerLanguageModel(FairseqLanguageModel):
         parser.add_argument('--character-filters', type=str, metavar='LIST',
                             default='[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]',
                             help='size of character embeddings')
-        parser.add_argument('--character-embedding-dim', type=int, metavar='N', default=4,
+        parser.add_argument('--character-embedding-dim', default=4, type=int, metavar='N',
                             help='size of character embeddings')
-        parser.add_argument('--char-embedder-highway-layers', type=int, metavar='N', default=2,
+        parser.add_argument('--char-embedder-highway-layers', default=2, type=int, metavar='N',
                             help='number of highway layers for character token embeddder')
         parser.add_argument('--adaptive-input', action='store_true',
                             help='if set, uses adaptive input')
@@ -132,12 +133,14 @@ class TransformerLanguageModel(FairseqLanguageModel):
 @register_model_architecture('transformer_lm', 'transformer_lm')
 def base_lm_architecture(args):
     # backward compatibility for older model checkpoints
-    if hasattr(args, 'no_tie_adaptive_proj') and args.no_tie_adaptive_proj is False:
-        args.tie_adaptive_proj = True
+    if hasattr(args, 'no_tie_adaptive_proj'):
+        # previous models defined --no-tie-adaptive-proj, so use the existence of
+        # that option to determine if this is an "old" model checkpoint
+        args.no_decoder_final_norm = True  # old models always set this to True
+        if args.no_tie_adaptive_proj is False:
+            args.tie_adaptive_proj = True
     if hasattr(args, 'decoder_final_norm'):
         args.no_decoder_final_norm = not args.decoder_final_norm
-    if not hasattr(args, 'no_decoder_final_norm'):
-        args.no_decoder_final_norm = True  # old models always set this to True
     args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 512)
     args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', 2048)
...
@@ -79,7 +79,7 @@ class CosineSchedule(FairseqLRScheduler):
                             help='warmup the learning rate linearly for the first N updates')
         parser.add_argument('--warmup-init-lr', default=-1, type=float, metavar='LR',
                             help='initial learning rate during warmup phase; default is args.lr')
-        parser.add_argument('--max-lr', required=True, type=float, metavar='LR',
+        parser.add_argument('--max-lr', type=float, metavar='LR',
                             help='max learning rate, must be more than args.lr')
         parser.add_argument('--t-mult', default=1, type=float, metavar='LR',
                             help='factor to grow the length of each period')
...
@@ -111,7 +111,7 @@ class TranslationTask(FairseqTask):
                             help='target language')
         parser.add_argument('--lazy-load', action='store_true',
                             help='load the dataset lazily')
-        parser.add_argument('--raw-text', default=False, action='store_true',
+        parser.add_argument('--raw-text', action='store_true',
                             help='load raw text dataset')
         parser.add_argument('--left-pad-source', default='True', type=str, metavar='BOOL',
                             help='pad the source on the left')
...
@@ -51,9 +51,9 @@ class TranslationMoETask(TranslationTask):
         """Add task-specific arguments to the parser."""
         # fmt: off
         TranslationTask.add_args(parser)
-        parser.add_argument('--method', required=True,
+        parser.add_argument('--method', default='hMoEup',
                             choices=['sMoElp', 'sMoEup', 'hMoElp', 'hMoEup'])
-        parser.add_argument('--num-experts', type=int, metavar='N', required=True,
+        parser.add_argument('--num-experts', default=3, type=int, metavar='N',
                             help='number of experts')
         parser.add_argument('--mean-pool-gating-network', action='store_true',
                             help='use a simple mean-pooling gating network')
...