"docs/source/git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "3a96fb57dd868c413befc4f60bad3d2effebdef7"
Commit d45db804 authored by Myle Ott, committed by Facebook GitHub Bot

Merge internal changes (#654)

Summary:
- Add --add-bos-token option to LM task
- Cleanup utils.py and options.py
Pull Request resolved: https://github.com/pytorch/fairseq/pull/654

Differential Revision: D15041794

Pulled By: myleott

fbshipit-source-id: 3ad00007769d5f48308052cfd40de39c5ffa1a6e
parent 89a69616
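
Note: the new --add-bos-token option is exposed by the language modeling task and honoured at evaluation time in eval_lm.py (see the hunk below that strips the leading BOS position before scoring). A minimal sketch of how the flag might be passed on the command line; the data and checkpoint paths are placeholders, not part of this commit:

    # Score a language model that was trained with --add-bos-token, so a <s>
    # symbol is prepended to every sample before scoring (paths are illustrative).
    fairseq-eval-lm data-bin/wikitext-103 \
        --path checkpoints/transformer_wikitext-103/checkpoint_best.pt \
        --add-bos-token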
@@ -15,9 +15,15 @@ Optimizers update the Model parameters based on the gradients.
     :members:
     :undoc-members:
+.. autoclass:: fairseq.optim.adadelta.Adadelta
+    :members:
+    :undoc-members:
 .. autoclass:: fairseq.optim.adagrad.Adagrad
     :members:
     :undoc-members:
+.. autoclass:: fairseq.optim.adafactor.FairseqAdafactor
+    :members:
+    :undoc-members:
 .. autoclass:: fairseq.optim.adam.FairseqAdam
     :members:
     :undoc-members:
......
@@ -28,11 +28,12 @@ fairseq implements the following high-level training flow::
             lr_scheduler.step_update(num_updates)
         lr_scheduler.step(epoch)
-where the default implementation for ``train.train_step`` is roughly::
+where the default implementation for ``task.train_step`` is roughly::
     def train_step(self, batch, model, criterion, optimizer):
         loss = criterion(model, batch)
         optimizer.backward(loss)
+        return loss
 **Registering new plug-ins**
......
@@ -354,7 +354,7 @@ The model files should appear in the :file:`checkpoints/` directory.
 Finally we can write a short script to evaluate our model on new inputs. Create
 a new file named :file:`eval_classifier.py` with the following contents::
-    from fairseq import data, options, tasks, utils
+    from fairseq import checkpoint_utils, data, options, tasks
     # Parse command-line arguments for generation
     parser = options.get_generation_parser(default_task='simple_classification')
@@ -365,7 +365,7 @@ a new file named :file:`eval_classifier.py` with the following contents::
     # Load model
     print('| loading model from {}'.format(args.path))
-    models, _model_args = utils.load_ensemble_for_inference([args.path], task)
+    models, _model_args = checkpoint_utils.load_model_ensemble([args.path], task=task)
     model = models[0]
     while True:
......
@@ -13,11 +13,10 @@ Evaluate the perplexity of a trained language model.
 import numpy as np
 import torch
-from fairseq import options, progress_bar, tasks, utils
+from fairseq import checkpoint_utils, options, progress_bar, tasks, utils
 from fairseq.data import LMContextWindowDataset
 from fairseq.meters import StopwatchMeter, TimeMeter
 from fairseq.sequence_scorer import SequenceScorer
-from fairseq.utils import import_user_module
 class WordStat(object):
@@ -49,7 +48,7 @@ class WordStat(object):
 def main(parsed_args):
     assert parsed_args.path is not None, '--path required for evaluation!'
-    import_user_module(parsed_args)
+    utils.import_user_module(parsed_args)
     print(parsed_args)
@@ -59,12 +58,17 @@ def main(parsed_args):
     # Load ensemble
     print('| loading model(s) from {}'.format(parsed_args.path))
-    models, args = utils.load_ensemble_for_inference(
-        parsed_args.path.split(':'), task, model_arg_overrides=eval(parsed_args.model_overrides),
+    models, args = checkpoint_utils.load_model_ensemble(
+        parsed_args.path.split(':'),
+        arg_overrides=eval(parsed_args.model_overrides),
+        task=task,
     )
     for arg in vars(parsed_args).keys():
-        if arg not in {'self_target', 'future_target', 'past_target', 'tokens_per_sample', 'output_size_dictionary'}:
+        if arg not in {
+            'self_target', 'future_target', 'past_target', 'tokens_per_sample',
+            'output_size_dictionary', 'add_bos_token',
+        }:
             setattr(args, arg, getattr(parsed_args, arg))
     # reduce tokens per sample by the required context window size
@@ -151,6 +155,11 @@ def main(parsed_args):
             tgt_len = tokens.numel()
             pos_scores = hypo['positional_scores'].float()
+            if args.add_bos_token:
+                assert hypo['tokens'][0].item() == task.target_dictionary.bos()
+                tokens = tokens[1:]
+                pos_scores = pos_scores[1:]
             skipped_toks = 0
             if bpe_toks is not None:
                 for i in range(tgt_len - 1):
......
@@ -39,7 +39,7 @@ $ fairseq-train --task language_modeling data-bin/wikitext-103 \
     --save-dir checkpoints/transformer_wikitext-103 --arch transformer_lm_wiki103 \
     --max-update 286000 --max-lr 1.0 --t-mult 2 --lr-period-updates 270000 --lr-scheduler cosine --lr-shrink 0.75 \
     --warmup-updates 16000 --warmup-init-lr 1e-07 --min-lr 1e-09 --optimizer nag --lr 0.0001 --clip-norm 0.1 \
-    --criterion adaptive_loss --max-tokens 3072 --update-freq 4 --tokens-per-sample 3072 --seed 1 \
+    --criterion adaptive_loss --max-tokens 3072 --update-freq 3 --tokens-per-sample 3072 --seed 1 \
     --sample-break-mode none --skip-invalid-size-inputs-valid-test --ddp-backend=no_c10d
 # Evaluate:
......
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from collections import OrderedDict
import logging
import os
import re
import traceback
import torch
from torch.serialization import default_restore_location
from fairseq import tasks
def load_checkpoint_to_cpu(path):
"""Loads a checkpoint to CPU (with upgrading for backward compatibility)."""
state = torch.load(
path, map_location=lambda s, l: default_restore_location(s, 'cpu'),
)
state = _upgrade_state_dict(state)
return state
def load_model_ensemble(filenames, arg_overrides=None, task=None):
"""Loads an ensemble of models.
Args:
filenames (List[str]): checkpoint files to load
arg_overrides (Dict[str,Any], optional): override model args that
were used during model training
task (fairseq.tasks.FairseqTask, optional): task to use for loading
"""
ensemble = []
for filename in filenames:
if not os.path.exists(filename):
raise IOError('Model file not found: {}'.format(filename))
state = load_checkpoint_to_cpu(filename)
args = state['args']
if arg_overrides is not None:
for arg_name, arg_val in arg_overrides.items():
setattr(args, arg_name, arg_val)
if task is None:
task = tasks.setup_task(args)
# build model for ensemble
model = task.build_model(args)
model.load_state_dict(state['model'], strict=True)
ensemble.append(model)
return ensemble, args
def checkpoint_paths(path, pattern=r'checkpoint(\d+)\.pt'):
"""Retrieves all checkpoints found in `path` directory.
Checkpoints are identified by matching filename to the specified pattern. If
the pattern contains groups, the result will be sorted by the first group in
descending order.
"""
pt_regexp = re.compile(pattern)
files = os.listdir(path)
entries = []
for i, f in enumerate(files):
m = pt_regexp.fullmatch(f)
if m is not None:
idx = int(m.group(1)) if len(m.groups()) > 0 else i
entries.append((idx, m.group(0)))
return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)]
def torch_persistent_save(*args, **kwargs):
for i in range(3):
try:
return torch.save(*args, **kwargs)
except Exception:
if i == 2:
logging.error(traceback.format_exc())
def convert_state_dict_type(state_dict, ttype=torch.FloatTensor):
if isinstance(state_dict, dict):
cpu_dict = OrderedDict()
for k, v in state_dict.items():
cpu_dict[k] = convert_state_dict_type(v)
return cpu_dict
elif isinstance(state_dict, list):
return [convert_state_dict_type(v) for v in state_dict]
elif torch.is_tensor(state_dict):
return state_dict.type(ttype)
else:
return state_dict
def save_state(
filename, args, model_state_dict, criterion, optimizer, lr_scheduler,
num_updates, optim_history=None, extra_state=None,
):
if optim_history is None:
optim_history = []
if extra_state is None:
extra_state = {}
state_dict = {
'args': args,
'model': model_state_dict if model_state_dict else {},
'optimizer_history': optim_history + [
{
'criterion_name': criterion.__class__.__name__,
'optimizer_name': optimizer.__class__.__name__,
'lr_scheduler_state': lr_scheduler.state_dict(),
'num_updates': num_updates,
}
],
'last_optimizer_state': convert_state_dict_type(optimizer.state_dict()),
'extra_state': extra_state,
}
torch_persistent_save(state_dict, filename)
def _upgrade_state_dict(state):
"""Helper for upgrading old model checkpoints."""
# add optimizer_history
if 'optimizer_history' not in state:
state['optimizer_history'] = [
{
'criterion_name': 'CrossEntropyCriterion',
'best_loss': state['best_loss'],
},
]
state['last_optimizer_state'] = state['optimizer']
del state['optimizer']
del state['best_loss']
# move extra_state into sub-dictionary
if 'epoch' in state and 'extra_state' not in state:
state['extra_state'] = {
'epoch': state['epoch'],
'batch_offset': state['batch_offset'],
'val_loss': state['val_loss'],
}
del state['epoch']
del state['batch_offset']
del state['val_loss']
# reduce optimizer history's memory usage (only keep the last state)
if 'optimizer' in state['optimizer_history'][-1]:
state['last_optimizer_state'] = state['optimizer_history'][-1]['optimizer']
for optim_hist in state['optimizer_history']:
del optim_hist['optimizer']
# record the optimizer class name
if 'optimizer_name' not in state['optimizer_history'][-1]:
state['optimizer_history'][-1]['optimizer_name'] = 'FairseqNAG'
# move best_loss into lr_scheduler_state
if 'lr_scheduler_state' not in state['optimizer_history'][-1]:
state['optimizer_history'][-1]['lr_scheduler_state'] = {
'best': state['optimizer_history'][-1]['best_loss'],
}
del state['optimizer_history'][-1]['best_loss']
# keep track of number of updates
if 'num_updates' not in state['optimizer_history'][-1]:
state['optimizer_history'][-1]['num_updates'] = 0
# old model checkpoints may not have separate source/target positions
if hasattr(state['args'], 'max_positions') and not hasattr(state['args'], 'max_source_positions'):
state['args'].max_source_positions = state['args'].max_positions
state['args'].max_target_positions = state['args'].max_positions
# use stateful training data iterator
if 'train_iterator' not in state['extra_state']:
state['extra_state']['train_iterator'] = {
'epoch': state['extra_state']['epoch'],
'iterations_in_epoch': state['extra_state'].get('batch_offset', 0),
}
return state
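
The new checkpoint_utils module above centralises checkpoint loading. A minimal usage sketch, assuming a trained checkpoint already exists at the placeholder paths shown:

    from fairseq import checkpoint_utils

    # Load an ensemble from one or more checkpoint files. When task is omitted,
    # load_model_ensemble() sets it up from the args stored in the checkpoint.
    models, args = checkpoint_utils.load_model_ensemble(
        ['checkpoints/checkpoint_best.pt'],        # placeholder path
        arg_overrides={'add_bos_token': True},     # optionally override saved args
    )
    model = models[0]

    # List checkpoints in a directory, sorted by index in descending order.
    paths = checkpoint_utils.checkpoint_paths('checkpoints/', pattern=r'checkpoint(\d+)\.pt')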
@@ -18,13 +18,12 @@ from fairseq.data import data_utils
 class Dictionary(object):
     """A mapping from symbols to consecutive integers"""
-    def __init__(self, pad='<pad>', eos='</s>', unk='<unk>'):
+    def __init__(self, pad='<pad>', eos='</s>', unk='<unk>', bos='<s>'):
         self.unk_word, self.pad_word, self.eos_word = unk, pad, eos
         self.symbols = []
         self.count = []
         self.indices = {}
-        # dictionary indexing starts at 1 for consistency with Lua
-        self.add_symbol('<Lua heritage>')
+        self.bos_index = self.add_symbol(bos)
         self.pad_index = self.add_symbol(pad)
         self.eos_index = self.add_symbol(eos)
         self.unk_index = self.add_symbol(unk)
@@ -143,6 +142,10 @@ class Dictionary(object):
         self.symbols = list(new_symbols)
         self.indices = new_indices
+    def bos(self):
+        """Helper to get index of beginning-of-sentence symbol"""
+        return self.bos_index
+
     def pad(self):
         """Helper to get index of pad symbol"""
         return self.pad_index
......
@@ -62,13 +62,14 @@ class MonolingualDataset(FairseqDataset):
     """
     def __init__(self, dataset, sizes, src_vocab, tgt_vocab, add_eos_for_other_targets, shuffle,
-                 targets=None):
+                 targets=None, add_bos_token=False):
         self.dataset = dataset
         self.sizes = np.array(sizes)
         self.vocab = src_vocab
         self.tgt_vocab = tgt_vocab
         self.add_eos_for_other_targets = add_eos_for_other_targets
         self.shuffle = shuffle
+        self.add_bos_token = add_bos_token
         assert targets is None or all(t in {'self', 'future', 'past'} for t in targets), \
             "targets must be none or one of 'self', 'future', 'past'"
@@ -91,6 +92,7 @@ class MonolingualDataset(FairseqDataset):
         else:
             source = self.dataset[index]
             target = None
+        source, target = self._maybe_add_bos(source, target)
         return {'id': index, 'source': source, 'target': target}
 
     def __len__(self):
@@ -129,6 +131,13 @@ class MonolingualDataset(FairseqDataset):
             return source, self._filter_vocab(target)
 
+    def _maybe_add_bos(self, source, target):
+        if self.add_bos_token:
+            source = torch.cat([source.new([self.vocab.bos()]), source])
+            if target is not None:
+                target = torch.cat([target.new([self.tgt_vocab.bos()]), target])
+        return source, target
+
     def _filter_vocab(self, target):
         if len(self.tgt_vocab) != len(self.vocab):
             def _filter(target):
@@ -173,6 +182,7 @@ class MonolingualDataset(FairseqDataset):
         target = self.vocab.dummy_sentence(tgt_len + 2)
         source, past_target, future_target = target[1:-1], target[2:], target[:-2]
         source, target = self._make_source_target(source, past_target, future_target)
+        source, target = self._maybe_add_bos(source, target)
         return self.collater([
             {'id': i, 'source': source, 'target': target}
......
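
The _maybe_add_bos helper above simply prepends the dictionary's beginning-of-sentence index to each 1-D tensor of token ids. A standalone sketch with toy values (the index 0 is illustrative; the real code uses vocab.bos()):

    import torch

    bos_index = 0                                   # illustrative; real code uses vocab.bos()
    source = torch.tensor([11, 12, 13])             # toy token ids
    source = torch.cat([source.new([bos_index]), source])
    # source is now tensor([0, 11, 12, 13])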
@@ -141,7 +141,7 @@ class FConvLanguageModel(FairseqLanguageModel):
         # make sure all arguments are present in older models
         base_lm_architecture(args)
-        if hasattr(args, 'max_target_positions'):
+        if hasattr(args, 'max_target_positions') and not hasattr(args, 'tokens_per_sample'):
             args.tokens_per_sample = args.max_target_positions
         decoder = FConvDecoder(
......
@@ -4,7 +4,6 @@
 # This source code is licensed under the license found in the LICENSE file in
 # the root directory of this source tree. An additional grant of patent rights
 # can be found in the PATENTS file in the same directory.
-#
 import math
@@ -12,11 +11,11 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from fairseq import checkpoint_utils
 from fairseq.modules import (
     DownsampledMultiHeadAttention, GradMultiply, LayerNorm,
     LearnedPositionalEmbedding, LinearizedConvolution,
 )
-from fairseq import utils
 from . import (
     FairseqEncoder, CompositeEncoder, FairseqDecoder, FairseqModel,
@@ -84,8 +83,7 @@ class FConvModelSelfAtt(FairseqModel):
         pretrained = eval(args.pretrained)
         if pretrained:
             print("| loading pretrained model")
-            trained_model = utils.load_ensemble_for_inference(
-                # not actually for inference, but loads pretrained model parameters
+            trained_model = checkpoint_utils.load_model_ensemble(
                 filenames=[args.pretrained_checkpoint],
                 task=task,
             )[0][0]
......
@@ -830,6 +830,7 @@ def base_lm_architecture(args):
     args.decoder_learned_pos = getattr(args, 'decoder_learned_pos', False)
     args.activation_fn = getattr(args, 'activation_fn', 'relu')
+    args.add_bos_token = getattr(args, 'add_bos_token', False)
     args.character_embeddings = getattr(args, 'character_embeddings', False)
     args.decoder_output_dim = getattr(args, 'decoder_output_dim', args.decoder_embed_dim)
@@ -927,7 +928,7 @@ def transformer_wmt_en_de(args):
     base_architecture(args)
-# parameters used in the "Attention Is All You Need" paper (Vaswani, et al, 2017)
+# parameters used in the "Attention Is All You Need" paper (Vaswani et al., 2017)
 @register_model_architecture('transformer', 'transformer_vaswani_wmt_en_de_big')
 def transformer_vaswani_wmt_en_de_big(args):
     args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 1024)
......
@@ -8,7 +8,7 @@
 import os
 from typing import Any, Dict
-from fairseq import utils
+from fairseq import checkpoint_utils
 from fairseq.data.masked_lm_dictionary import MaskedLMDictionary
 from fairseq.models.transformer import (
     TransformerDecoder,
@@ -92,7 +92,7 @@ def upgrade_state_dict_with_xlm_weights(
     if not os.path.exists(pretrained_xlm_checkpoint):
         raise IOError(f"Model file not found: {pretrained_xlm_checkpoint}")
-    state = utils.load_checkpoint_to_cpu(pretrained_xlm_checkpoint)
+    state = checkpoint_utils.load_checkpoint_to_cpu(pretrained_xlm_checkpoint)
     xlm_state_dict = state["model"]
     for key in xlm_state_dict.keys():
......
@@ -4,17 +4,15 @@
 # This source code is licensed under the license found in the LICENSE file in
 # the root directory of this source tree. An additional grant of patent rights
 # can be found in the PATENTS file in the same directory.
-import math
-import torch
 """
 See "Gaussian Error Linear Units (GELUs)" by Dan Hendrycks and Kevin Gimpel with
 the corresponding GitHub repo: https://github.com/hendrycks/GELUs
 """
+import math
+import torch
 
 def gelu_fast(x):
     if not hasattr(gelu_fast, "_a"):
......
@@ -19,11 +19,15 @@ class Adadelta(FairseqOptimizer):
     @staticmethod
     def add_args(parser):
         """Add optimizer-specific arguments to the parser."""
+        # fmt: off
         parser.add_argument('--adadelta-rho', type=float, default=0.9, metavar='RHO',
                             help='coefficient used for computing a running average of squared gradients')
        parser.add_argument('--adadelta-eps', type=float, default=1e-6, metavar='EPS',
                             help='term added to the denominator to improve numerical stability')
+        parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD',
+                            help='weight decay')
         parser.add_argument('--anneal-eps', action='store_true', help='flag to anneal eps')
+        # fmt: on
 
     @property
     def optimizer_config(self):
......
@@ -21,6 +21,7 @@ class FairseqAdafactor(FairseqOptimizer):
     @staticmethod
     def add_args(parser):
         """Add optimizer-specific arguments to the parser."""
+        # fmt: off
         parser.add_argument('--adafactor-eps', default='(1e-30, 1e-3)', metavar="E",
                             help='epsilons for Adafactor optimizer')
         parser.add_argument('--clip-threshold', type=float, default=1.0, metavar="C",
@@ -31,11 +32,14 @@ class FairseqAdafactor(FairseqOptimizer):
                             help='beta for first moment estimator. Optional')
         parser.add_argument('--scale-parameter', action='store_true',
                             help='scale learning rate by root mean square of parameter.')
+        parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD',
+                            help='weight decay')
         parser.add_argument('--warmup-init', action='store_true',
                             help='use relative step for warm-up learning rate schedule')
         parser.add_argument('--relative-step', action='store_true',
                             help='set learning rate to inverse square root of timestep.'
                                  'If false, external learning rate applied')
+        # fmt: on
 
     @property
     def optimizer_config(self):
......
@@ -16,6 +16,14 @@ class Adagrad(FairseqOptimizer):
         super().__init__(args, params)
         self._optimizer = torch.optim.Adagrad(params, **self.optimizer_config)
 
+    @staticmethod
+    def add_args(parser):
+        """Add optimizer-specific arguments to the parser."""
+        # fmt: off
+        parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD',
+                            help='weight decay')
+        # fmt: on
+
     @property
     def optimizer_config(self):
         """
......
@@ -30,6 +30,8 @@ class FairseqAdam(FairseqOptimizer):
                             help='betas for Adam optimizer')
         parser.add_argument('--adam-eps', type=float, default=1e-8, metavar='D',
                             help='epsilon for Adam optimizer')
+        parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD',
+                            help='weight decay')
         # fmt: on
 
     @property
......
@@ -85,6 +85,8 @@ class CosineSchedule(FairseqLRScheduler):
                             help='factor to grow the length of each period')
         parser.add_argument('--lr-period-updates', default=-1, type=float, metavar='LR',
                             help='initial number of updates per period')
+        parser.add_argument('--lr-shrink', default=0.1, type=float, metavar='LS',
+                            help='shrink factor for annealing')
         # fmt: on
 
     def step(self, epoch, val_loss=None):
......
@@ -30,6 +30,8 @@ class FixedSchedule(FairseqLRScheduler):
         # fmt: off
         parser.add_argument('--force-anneal', '--fa', type=int, metavar='N',
                             help='force annealing at specified epoch')
+        parser.add_argument('--lr-shrink', default=0.1, type=float, metavar='LS',
+                            help='shrink factor for annealing, lr_new = (lr * lr_shrink)')
         parser.add_argument('--warmup-updates', default=0, type=int, metavar='N',
                             help='warmup the learning rate linearly for the first N updates')
         # fmt: on
......
@@ -24,6 +24,14 @@ class ReduceLROnPlateau(FairseqLRScheduler):
         self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
             self.optimizer.optimizer, patience=0, factor=args.lr_shrink)
 
+    @staticmethod
+    def add_args(parser):
+        """Add arguments to the parser for this LR scheduler."""
+        # fmt: off
+        parser.add_argument('--lr-shrink', default=0.1, type=float, metavar='LS',
+                            help='shrink factor for annealing, lr_new = (lr * lr_shrink)')
+        # fmt: on
+
     def state_dict(self):
         """Return the LR scheduler state dict."""
         return {
......