Commit 65c1903e authored by Myle Ott, committed by Facebook Github Bot

Misc fixes

Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/529

Differential Revision: D14218384

Pulled By: myleott

fbshipit-source-id: 5d2cbb1f56ea42e9929785aff4a5ae5f44d13724
parent 94fedf00
@@ -5,8 +5,6 @@
 # the root directory of this source tree. An additional grant of patent rights
 # can be found in the PATENTS file in the same directory.
-from .multiprocessing_pdb import pdb
 __all__ = ['pdb']
 __version__ = '0.6.1'
@@ -15,4 +13,5 @@ import fairseq.models
 import fairseq.modules
 import fairseq.optim
 import fairseq.optim.lr_scheduler
+import fairseq.pdb
 import fairseq.tasks
@@ -159,7 +159,7 @@ class Adafactor(torch.optim.Optimizer):
 state['step'] += 1
 state['RMS'] = self._rms(p.data)
-lr = self._get_lr(group, state)
+group['lr'] = self._get_lr(group, state)
 beta2t = 1.0 - math.pow(state['step'], group['decay_rate'])
 update = (grad**2) + group['eps'][0]
@@ -180,7 +180,7 @@ class Adafactor(torch.optim.Optimizer):
 torch.rsqrt(exp_avg_sq, out=update).mul_(grad)
 update.div_(max(1.0, self._rms(update) / group['clip_threshold']))
-update.mul_(lr)
+update.mul_(group['lr'])
 if use_first_moment:
 exp_avg = state['exp_avg']
@@ -188,7 +188,7 @@ class Adafactor(torch.optim.Optimizer):
 update = exp_avg
 if group['weight_decay'] != 0:
-p.data.add_(-group['weight_decay'] * lr, p.data)
+p.data.add_(-group['weight_decay'] * group['lr'], p.data)
 p.data.add_(-update)
......
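The Adafactor change above stores the computed step size in group['lr'] rather than a local lr variable, so anything that inspects optimizer.param_groups (loggers, LR schedulers, checkpoints) sees the value that was actually applied. Below is a minimal sketch of that write-back pattern using a hypothetical ToyRelativeStepOptimizer with an Adafactor-style relative step size; it is not the fairseq implementation.

```python
import math
import torch


class ToyRelativeStepOptimizer(torch.optim.Optimizer):
    """Toy optimizer illustrating the group['lr'] write-back pattern."""

    def __init__(self, params):
        super().__init__(params, defaults=dict(lr=0.0))

    @torch.no_grad()
    def step(self, closure=None):
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                state = self.state[p]
                state['step'] = state.get('step', 0) + 1
                # Adafactor-style relative step size: min(1e-2, 1/sqrt(t)).
                step_size = min(1e-2, 1.0 / math.sqrt(state['step']))
                # Write it back so external code reading param_groups sees
                # the learning rate that was actually used this step.
                group['lr'] = step_size
                p.add_(p.grad, alpha=-step_size)


x = torch.nn.Parameter(torch.ones(3))
opt = ToyRelativeStepOptimizer([x])
x.sum().backward()
opt.step()
print(opt.param_groups[0]['lr'])  # 0.01 after the first step
```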
@@ -12,11 +12,15 @@ from fairseq import optim, utils
 class DynamicLossScaler(object):
-def __init__(self, init_scale=2.**15, scale_factor=2., scale_window=2000, tolerance=0.05):
+def __init__(
+self, init_scale=2.**15, scale_factor=2., scale_window=2000,
+tolerance=0.05, threshold=None,
+):
 self.loss_scale = init_scale
 self.scale_factor = scale_factor
 self.scale_window = scale_window
 self.tolerance = tolerance
+self.threshold = threshold
 self._iter = 0
 self._last_overflow_iter = -1
 self._last_rescale_iter = -1
@@ -29,7 +33,7 @@ class DynamicLossScaler(object):
 self._overflows_since_rescale += 1
 pct_overflow = self._overflows_since_rescale / float(iter_since_rescale)
 if pct_overflow >= self.tolerance:
-self.loss_scale /= self.scale_factor
+self._decrease_loss_scale()
 self._last_rescale_iter = self._iter
 self._overflows_since_rescale = 0
 elif (self._iter - self._last_overflow_iter) % self.scale_window == 0:
@@ -37,6 +41,11 @@ class DynamicLossScaler(object):
 self._last_rescale_iter = self._iter
 self._iter += 1
+def _decrease_loss_scale(self):
+self.loss_scale /= self.scale_factor
+if self.threshold is not None:
+self.loss_scale = max(self.loss_scale, self.threshold)
 @staticmethod
 def has_overflow(grad_norm):
 # detect inf and nan
@@ -69,6 +78,7 @@ class FP16Optimizer(optim.FairseqOptimizer):
 init_scale=args.fp16_init_scale,
 scale_window=scale_window,
 tolerance=args.fp16_scale_tolerance,
+threshold=args.threshold_loss_scale,
 )
 @classmethod
@@ -279,6 +289,7 @@ class MemoryEfficientFP16Optimizer(optim.FairseqOptimizer):
 init_scale=args.fp16_init_scale,
 scale_window=scale_window,
 tolerance=args.fp16_scale_tolerance,
+threshold=args.threshold_loss_scale,
 )
 @classmethod
......
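The new threshold argument (wired to --threshold-loss-scale) puts a floor under the loss scale: on overflow the scale is still divided by scale_factor, but it never drops below the threshold. Below is a standalone sketch of just that decrease rule; the names mirror the diff, but the hypothetical LossScaleSketch class is not the full fairseq DynamicLossScaler.

```python
class LossScaleSketch:
    """Minimal stand-in for the decrease logic added to DynamicLossScaler."""

    def __init__(self, init_scale=2.**15, scale_factor=2., threshold=None):
        self.loss_scale = init_scale
        self.scale_factor = scale_factor
        self.threshold = threshold

    def _decrease_loss_scale(self):
        self.loss_scale /= self.scale_factor
        if self.threshold is not None:
            # Clamp from below so repeated overflows cannot push the scale
            # arbitrarily close to zero.
            self.loss_scale = max(self.loss_scale, self.threshold)


scaler = LossScaleSketch(init_scale=8., scale_factor=2., threshold=4.)
scaler._decrease_loss_scale()  # 8 -> 4
scaler._decrease_loss_scale()  # would be 2, clamped back up to 4
assert scaler.loss_scale == 4.
```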
@@ -108,10 +108,6 @@ def parse_args_and_arch(parser, input_args=None, parse_known=False):
 extra = None
 # Post-process args.
-if hasattr(args, 'lr'):
-args.lr = eval_str_list(args.lr, type=float)
-if hasattr(args, 'update_freq'):
-args.update_freq = eval_str_list(args.update_freq, type=int)
 if hasattr(args, 'max_sentences_valid') and args.max_sentences_valid is None:
 args.max_sentences_valid = args.max_sentences
 if getattr(args, 'memory_efficient_fp16', False):
@@ -154,6 +150,10 @@ def get_parser(desc, default_task='translation'):
 help='number of updates before increasing loss scale')
 parser.add_argument('--fp16-scale-tolerance', default=0.0, type=float,
 help='pct of updates that can overflow before decreasing the loss scale')
+parser.add_argument('--min-loss-scale', default=1e-4, type=float, metavar='D',
+help='minimum FP16 loss scale, after which training is stopped')
+parser.add_argument('--threshold-loss-scale', type=float,
+help='threshold FP16 loss scale from below')
 parser.add_argument('--user-dir', default=None,
 help='path to a python module containing custom extensions (tasks and/or architectures)')
@@ -283,14 +283,16 @@ def add_optimization_args(parser):
 group.add_argument('--sentence-avg', action='store_true',
 help='normalize gradients by the number of sentences in a batch'
 ' (default is to normalize by number of tokens)')
-group.add_argument('--update-freq', default='1', metavar='N',
+group.add_argument('--update-freq', default='1', metavar='N1,N2,...,N_K',
+type=lambda uf: eval_str_list(uf, type=int),
 help='update parameters every N_i batches, when in epoch i')
 # Optimizer definitions can be found under fairseq/optim/
 group.add_argument('--optimizer', default='nag', metavar='OPT',
 choices=OPTIMIZER_REGISTRY.keys(),
 help='Optimizer')
-group.add_argument('--lr', '--learning-rate', default='0.25', metavar='LR_1,LR_2,...,LR_N',
+group.add_argument('--lr', '--learning-rate', default='0.25', type=eval_str_list,
+metavar='LR_1,LR_2,...,LR_N',
 help='learning rate for the first N epochs; all epochs >N using LR_N'
 ' (note: this may be interpreted differently depending on --lr-scheduler)')
 group.add_argument('--momentum', default=0.99, type=float, metavar='M',
@@ -306,8 +308,6 @@ def add_optimization_args(parser):
 help='learning rate shrink factor for annealing, lr_new = (lr * lr_shrink)')
 group.add_argument('--min-lr', default=1e-5, type=float, metavar='LR',
 help='minimum learning rate')
-group.add_argument('--min-loss-scale', default=1e-4, type=float, metavar='D',
-help='minimum loss scale (for FP16 training)')
 # fmt: on
 return group
......
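With --lr and --update-freq now parsed via eval_str_list at argparse time, the hasattr-based post-processing in parse_args_and_arch is no longer needed. Roughly, eval_str_list turns a single value or a comma-separated string into a list of the requested type. The sketch below only approximates that behavior and is not the verbatim fairseq helper.

```python
import argparse


def eval_str_list(x, type=float):
    # Approximation of fairseq's helper: '0.25' -> [0.25], '0.25,0.1' -> [0.25, 0.1].
    if x is None:
        return None
    if isinstance(x, str):
        x = eval(x)
    try:
        return list(map(type, x))
    except TypeError:
        return [type(x)]


parser = argparse.ArgumentParser()
parser.add_argument('--lr', default='0.25', type=eval_str_list,
                    metavar='LR_1,LR_2,...,LR_N')
parser.add_argument('--update-freq', default='1', metavar='N1,N2,...,N_K',
                    type=lambda uf: eval_str_list(uf, type=int))

args = parser.parse_args(['--lr', '0.1,0.01', '--update-freq', '4'])
print(args.lr)           # [0.1, 0.01]
print(args.update_freq)  # [4]

args = parser.parse_args([])  # string defaults also go through `type`
print(args.lr, args.update_freq)  # [0.25] [1]
```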
@@ -11,6 +11,9 @@ import pdb
 import sys
+__all__ = ['set_trace']
 class MultiprocessingPdb(pdb.Pdb):
 """A Pdb wrapper that works in a multiprocessing environment.
@@ -36,4 +39,6 @@ class MultiprocessingPdb(pdb.Pdb):
 sys.stdin = stdin_bak
-pdb = MultiprocessingPdb()
+def set_trace():
+pdb = MultiprocessingPdb()
+pdb.set_trace(sys._getframe().f_back)
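Instead of a module-level pdb = MultiprocessingPdb() instance, set_trace() now builds a fresh MultiprocessingPdb per call and attaches it to the caller's frame. Assuming the wrapper is reachable as fairseq.pdb (which the new import fairseq.pdb line in the package __init__ suggests), usage stays a one-liner; the forward function below is just a hypothetical call site.

```python
# Hypothetical usage sketch: drop a multiprocessing-safe breakpoint anywhere
# in model or training code.
from fairseq import pdb


def forward(x):
    pdb.set_trace()  # each call constructs its own MultiprocessingPdb
    return x * 2
```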
@@ -122,13 +122,11 @@ def train(args, trainer, task, epoch_itr):
 """Train the model for one epoch."""
 # Update parameters every N batches
-if epoch_itr.epoch <= len(args.update_freq):
-update_freq = args.update_freq[epoch_itr.epoch - 1]
-else:
-update_freq = args.update_freq[-1]
 # Initialize data iterator
 itr = epoch_itr.next_epoch_itr(fix_batches_to_gpus=args.fix_batches_to_gpus)
+update_freq = args.update_freq[epoch_itr.epoch - 1] \
+if epoch_itr.epoch <= len(args.update_freq) else args.update_freq[-1]
 itr = iterators.GroupedIterator(itr, update_freq)
 progress = progress_bar.build_progress_bar(
 args, itr, epoch_itr.epoch, no_progress_bar='simple',
......
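The rewritten train() picks the per-epoch update frequency directly from the parsed --update-freq list: epoch i (1-based) uses entry i-1, and epochs past the end of the list reuse the last entry. A small sketch of that indexing rule, using a hypothetical get_update_freq helper that mirrors the conditional expression above:

```python
def get_update_freq(update_freq, epoch):
    # epoch is 1-based; epochs beyond the list fall back to the last entry.
    return update_freq[epoch - 1] if epoch <= len(update_freq) else update_freq[-1]


assert get_update_freq([4, 2, 1], epoch=1) == 4   # accumulate 4 batches in epoch 1
assert get_update_freq([4, 2, 1], epoch=3) == 1
assert get_update_freq([4, 2, 1], epoch=10) == 1  # past the list: reuse last value
```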