Commit 65c1903e authored by Myle Ott, committed by Facebook Github Bot

Misc fixes

Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/529

Differential Revision: D14218384

Pulled By: myleott

fbshipit-source-id: 5d2cbb1f56ea42e9929785aff4a5ae5f44d13724
parent 94fedf00
@@ -5,8 +5,6 @@
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from .multiprocessing_pdb import pdb
__all__ = ['pdb']
__version__ = '0.6.1'
@@ -15,4 +13,5 @@ import fairseq.models
import fairseq.modules
import fairseq.optim
import fairseq.optim.lr_scheduler
import fairseq.pdb
import fairseq.tasks
@@ -89,7 +89,7 @@ class Adafactor(torch.optim.Optimizer):
"""
def __init__(self, params, lr=None, eps=(1e-30, 1e-3), clip_threshold=1.0,
decay_rate=-0.8, beta1=None, weight_decay=0.0, scale_parameter=True,
decay_rate=-0.8, beta1=None, weight_decay=0.0, scale_parameter=True,
relative_step=True, warmup_init=False):
defaults = dict(lr=lr, eps=eps, clip_threshold=clip_threshold, decay_rate=decay_rate,
beta1=beta1, weight_decay=weight_decay, scale_parameter=scale_parameter,
@@ -159,7 +159,7 @@ class Adafactor(torch.optim.Optimizer):
state['step'] += 1
state['RMS'] = self._rms(p.data)
lr = self._get_lr(group, state)
group['lr'] = self._get_lr(group, state)
beta2t = 1.0 - math.pow(state['step'], group['decay_rate'])
update = (grad**2) + group['eps'][0]
@@ -180,7 +180,7 @@ class Adafactor(torch.optim.Optimizer):
torch.rsqrt(exp_avg_sq, out=update).mul_(grad)
update.div_(max(1.0, self._rms(update) / group['clip_threshold']))
update.mul_(lr)
update.mul_(group['lr'])
if use_first_moment:
exp_avg = state['exp_avg']
@@ -188,7 +188,7 @@ class Adafactor(torch.optim.Optimizer):
update = exp_avg
if group['weight_decay'] != 0:
p.data.add_(-group['weight_decay'] * lr, p.data)
p.data.add_(-group['weight_decay'] * group['lr'], p.data)
p.data.add_(-update)
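Note: the Adafactor change above replaces the local lr variable with group['lr'], so the step size computed by _get_lr on each update is written back into the param group where callers (logging, schedulers) can read it. A minimal sketch of what that makes visible, with the module path assumed as fairseq.optim.adafactor:

    import torch
    from fairseq.optim.adafactor import Adafactor  # path assumed, not shown in this diff

    model = torch.nn.Linear(8, 8)
    # lr=None with relative_step=True: the step size is derived internally per update
    opt = Adafactor(model.parameters(), lr=None, relative_step=True, warmup_init=True)

    model(torch.randn(2, 8)).sum().backward()
    opt.step()
    # after this commit the effective step size is exposed on the param group
    print(opt.param_groups[0]['lr'])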
@@ -12,11 +12,15 @@ from fairseq import optim, utils
class DynamicLossScaler(object):
def __init__(self, init_scale=2.**15, scale_factor=2., scale_window=2000, tolerance=0.05):
def __init__(
self, init_scale=2.**15, scale_factor=2., scale_window=2000,
tolerance=0.05, threshold=None,
):
self.loss_scale = init_scale
self.scale_factor = scale_factor
self.scale_window = scale_window
self.tolerance = tolerance
self.threshold = threshold
self._iter = 0
self._last_overflow_iter = -1
self._last_rescale_iter = -1
@@ -29,7 +33,7 @@ class DynamicLossScaler(object):
self._overflows_since_rescale += 1
pct_overflow = self._overflows_since_rescale / float(iter_since_rescale)
if pct_overflow >= self.tolerance:
self.loss_scale /= self.scale_factor
self._decrease_loss_scale()
self._last_rescale_iter = self._iter
self._overflows_since_rescale = 0
elif (self._iter - self._last_overflow_iter) % self.scale_window == 0:
@@ -37,6 +41,11 @@ class DynamicLossScaler(object):
self._last_rescale_iter = self._iter
self._iter += 1
def _decrease_loss_scale(self):
self.loss_scale /= self.scale_factor
if self.threshold is not None:
self.loss_scale = max(self.loss_scale, self.threshold)
@staticmethod
def has_overflow(grad_norm):
# detect inf and nan
@@ -69,6 +78,7 @@ class FP16Optimizer(optim.FairseqOptimizer):
init_scale=args.fp16_init_scale,
scale_window=scale_window,
tolerance=args.fp16_scale_tolerance,
threshold=args.threshold_loss_scale,
)
@classmethod
@@ -279,6 +289,7 @@ class MemoryEfficientFP16Optimizer(optim.FairseqOptimizer):
init_scale=args.fp16_init_scale,
scale_window=scale_window,
tolerance=args.fp16_scale_tolerance,
threshold=args.threshold_loss_scale,
)
@classmethod
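Note: the new threshold argument (wired to --threshold-loss-scale below) puts a floor under the dynamic loss scale: _decrease_loss_scale still divides by scale_factor when too many updates overflow, but the result is clamped from below. A minimal sketch of that behaviour, assuming the class lives in fairseq.optim.fp16_optimizer:

    from fairseq.optim.fp16_optimizer import DynamicLossScaler  # path assumed

    scaler = DynamicLossScaler(init_scale=2.**15, threshold=2.**10)
    for _ in range(10):
        scaler._decrease_loss_scale()  # invoked when the overflow rate exceeds tolerance
    # without a threshold the scale would have fallen to 2.**5; with it, it stops at 2.**10
    print(scaler.loss_scale)  # 1024.0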
@@ -108,10 +108,6 @@ def parse_args_and_arch(parser, input_args=None, parse_known=False):
extra = None
# Post-process args.
if hasattr(args, 'lr'):
args.lr = eval_str_list(args.lr, type=float)
if hasattr(args, 'update_freq'):
args.update_freq = eval_str_list(args.update_freq, type=int)
if hasattr(args, 'max_sentences_valid') and args.max_sentences_valid is None:
args.max_sentences_valid = args.max_sentences
if getattr(args, 'memory_efficient_fp16', False):
@@ -154,6 +150,10 @@ def get_parser(desc, default_task='translation'):
help='number of updates before increasing loss scale')
parser.add_argument('--fp16-scale-tolerance', default=0.0, type=float,
help='pct of updates that can overflow before decreasing the loss scale')
parser.add_argument('--min-loss-scale', default=1e-4, type=float, metavar='D',
help='minimum FP16 loss scale, after which training is stopped')
parser.add_argument('--threshold-loss-scale', type=float,
help='threshold FP16 loss scale from below')
parser.add_argument('--user-dir', default=None,
help='path to a python module containing custom extensions (tasks and/or architectures)')
@@ -283,14 +283,16 @@ def add_optimization_args(parser):
group.add_argument('--sentence-avg', action='store_true',
help='normalize gradients by the number of sentences in a batch'
' (default is to normalize by number of tokens)')
group.add_argument('--update-freq', default='1', metavar='N',
group.add_argument('--update-freq', default='1', metavar='N1,N2,...,N_K',
type=lambda uf: eval_str_list(uf, type=int),
help='update parameters every N_i batches, when in epoch i')
# Optimizer definitions can be found under fairseq/optim/
group.add_argument('--optimizer', default='nag', metavar='OPT',
choices=OPTIMIZER_REGISTRY.keys(),
help='Optimizer')
group.add_argument('--lr', '--learning-rate', default='0.25', metavar='LR_1,LR_2,...,LR_N',
group.add_argument('--lr', '--learning-rate', default='0.25', type=eval_str_list,
metavar='LR_1,LR_2,...,LR_N',
help='learning rate for the first N epochs; all epochs >N using LR_N'
' (note: this may be interpreted differently depending on --lr-scheduler)')
group.add_argument('--momentum', default=0.99, type=float, metavar='M',
@@ -306,8 +308,6 @@ def add_optimization_args(parser):
help='learning rate shrink factor for annealing, lr_new = (lr * lr_shrink)')
group.add_argument('--min-lr', default=1e-5, type=float, metavar='LR',
help='minimum learning rate')
group.add_argument('--min-loss-scale', default=1e-4, type=float, metavar='D',
help='minimum loss scale (for FP16 training)')
# fmt: on
return group
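Note: --lr and --update-freq now convert their comma-separated values at argparse time via eval_str_list, which is why the hasattr(args, 'lr') / hasattr(args, 'update_freq') post-processing is removed from parse_args_and_arch above. A small sketch of the resulting behaviour (import location of eval_str_list assumed):

    import argparse
    from fairseq.options import eval_str_list  # location assumed

    parser = argparse.ArgumentParser()
    parser.add_argument('--lr', default='0.25', type=eval_str_list,
                        metavar='LR_1,LR_2,...,LR_N')
    parser.add_argument('--update-freq', default='1', metavar='N1,N2,...,N_K',
                        type=lambda uf: eval_str_list(uf, type=int))

    args = parser.parse_args(['--lr', '0.1,0.01', '--update-freq', '4'])
    print(args.lr)           # expected: [0.1, 0.01] -- already a list of floats
    print(args.update_freq)  # expected: [4]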
@@ -11,6 +11,9 @@ import pdb
import sys
__all__ = ['set_trace']
class MultiprocessingPdb(pdb.Pdb):
"""A Pdb wrapper that works in a multiprocessing environment.
@@ -36,4 +39,6 @@ class MultiprocessingPdb(pdb.Pdb):
sys.stdin = stdin_bak
pdb = MultiprocessingPdb()
def set_trace():
pdb = MultiprocessingPdb()
pdb.set_trace(sys._getframe().f_back)
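Note: set_trace() now constructs a fresh MultiprocessingPdb on every call instead of sharing a module-level instance, and the package __init__ imports fairseq.pdb rather than re-exporting pdb. Intended usage, sketched from the import fairseq.pdb line added above:

    from fairseq import pdb

    def forward_hook(x):
        # drops into a Pdb wrapper that, per the class docstring, works in a
        # multiprocessing environment (e.g. inside spawned worker processes)
        pdb.set_trace()
        return x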
@@ -122,13 +122,11 @@ def train(args, trainer, task, epoch_itr):
"""Train the model for one epoch."""
# Update parameters every N batches
if epoch_itr.epoch <= len(args.update_freq):
update_freq = args.update_freq[epoch_itr.epoch - 1]
else:
update_freq = args.update_freq[-1]
# Initialize data iterator
itr = epoch_itr.next_epoch_itr(fix_batches_to_gpus=args.fix_batches_to_gpus)
update_freq = args.update_freq[epoch_itr.epoch - 1] \
if epoch_itr.epoch <= len(args.update_freq) else args.update_freq[-1]
itr = iterators.GroupedIterator(itr, update_freq)
progress = progress_bar.build_progress_bar(
args, itr, epoch_itr.epoch, no_progress_bar='simple',
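Note: the if/else in train() is collapsed into a conditional expression; the per-epoch update frequency still comes from the --update-freq list, falling back to the last entry once the epoch index runs past it. A standalone sketch of that selection rule:

    def select_update_freq(update_freq, epoch):
        """Pick the number of batches per parameter update for a 1-based epoch index."""
        return update_freq[epoch - 1] if epoch <= len(update_freq) else update_freq[-1]

    assert select_update_freq([4, 2, 1], 1) == 4   # epoch 1 groups 4 batches per update
    assert select_update_freq([4, 2, 1], 5) == 1   # epochs past the list reuse the last value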