Commit d646a4a8 authored by Myle Ott

Add support for additional optimizers

parent cab76554
@@ -31,6 +31,8 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
     (prefixed with `_async_`), which run on each process in parallel.
     """
 
+    OPTIMIZERS = ['adagrad', 'adam', 'nag', 'sgd']
+
     def __init__(self, args, model, criterion, device_ids=None,
                  multiprocessing_method='spawn'):
         if device_ids is None:
@@ -69,15 +71,32 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
         self.criterion = criterion.cuda()
 
         # initialize optimizer
-        self.optimizer = NAG(self.model.parameters(), lr=self.args.lr,
-                             momentum=self.args.momentum,
-                             weight_decay=self.args.weight_decay)
+        self.optimizer = self._build_optimizer()
         self.flat_grads = None
         self.loss = None
 
         # initialize LR scheduler
         self.lr_scheduler = self._build_lr_scheduler()
 
+    def _build_optimizer(self):
+        if self.args.optimizer == 'adagrad':
+            return torch.optim.Adagrad(self.model.parameters(), lr=self.args.lr,
+                                       weight_decay=self.args.weight_decay)
+        elif self.args.optimizer == 'adam':
+            return torch.optim.Adam(self.model.parameters(), lr=self.args.lr,
+                                    betas=eval(self.args.adam_betas),
+                                    weight_decay=self.args.weight_decay)
+        elif self.args.optimizer == 'nag':
+            return NAG(self.model.parameters(), lr=self.args.lr,
+                       momentum=self.args.momentum,
+                       weight_decay=self.args.weight_decay)
+        elif self.args.optimizer == 'sgd':
+            return torch.optim.SGD(self.model.parameters(), lr=self.args.lr,
+                                   momentum=self.args.momentum,
+                                   weight_decay=self.args.weight_decay)
+        else:
+            raise ValueError('Unknown optimizer: {}'.format(self.args.optimizer))
+
     def _build_lr_scheduler(self):
         if self.args.force_anneal > 0:
             def anneal(e):
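For readers who want to try the new dispatch outside of the trainer, a minimal sketch along the lines of _build_optimizer is shown below; the argparse.Namespace values and the nn.Linear stand-in model are illustrative only, and fairseq's custom NAG optimizer is left out because it is not part of stock PyTorch.

import argparse

import torch
import torch.nn as nn

# Illustrative arguments mirroring the fields that _build_optimizer reads;
# the concrete values here are made up for the example.
args = argparse.Namespace(optimizer='adam', lr=0.25, momentum=0.99,
                          weight_decay=0.0, adam_betas='(0.9, 0.999)')
model = nn.Linear(8, 8)  # stand-in for the real fairseq model

if args.optimizer == 'adagrad':
    optimizer = torch.optim.Adagrad(model.parameters(), lr=args.lr,
                                    weight_decay=args.weight_decay)
elif args.optimizer == 'adam':
    # --adam-betas is stored as a string and turned into a (beta1, beta2) tuple
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr,
                                 betas=eval(args.adam_betas),
                                 weight_decay=args.weight_decay)
elif args.optimizer == 'sgd':
    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
else:
    raise ValueError('Unknown optimizer: {}'.format(args.optimizer))

print(type(optimizer).__name__)  # Adam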
@@ -9,6 +9,7 @@
 import argparse
 
 from fairseq import models
+from fairseq.multiprocessing_trainer import MultiprocessingTrainer
 
 
 def get_parser(desc):
@@ -41,6 +42,9 @@ def add_dataset_args(parser):
 
 def add_optimization_args(parser):
     group = parser.add_argument_group('Optimization')
+    group.add_argument('--optimizer', default='nag', metavar='OPT',
+                       choices=MultiprocessingTrainer.OPTIMIZERS,
+                       help='optimizer ({})'.format(', '.join(MultiprocessingTrainer.OPTIMIZERS)))
     group.add_argument('--lr', '--learning-rate', default=0.25, type=float, metavar='LR',
                        help='initial learning rate')
     group.add_argument('--min-lr', metavar='LR', default=1e-5, type=float,
@@ -53,6 +57,8 @@ def add_optimization_args(parser):
                        help='learning rate shrink factor for annealing, lr_new = (lr * lrshrink)')
     group.add_argument('--momentum', default=0.99, type=float, metavar='M',
                        help='momentum factor')
+    group.add_argument('--adam-betas', default='(0.9, 0.999)', metavar='B',
+                       help='betas for Adam optimizer')
     group.add_argument('--clip-norm', default=25, type=float, metavar='NORM',
                        help='clip threshold of gradients')
     group.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD',
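A minimal stand-alone sketch of how the new flags behave at parse time is below; the reduced parser only mirrors the two new options rather than fairseq's full get_parser()/add_optimization_args(), and the argument values (and the rejected 'adadelta') are illustrative, not taken from the commit.

import argparse

# Mirrors MultiprocessingTrainer.OPTIMIZERS
OPTIMIZERS = ['adagrad', 'adam', 'nag', 'sgd']

parser = argparse.ArgumentParser()
group = parser.add_argument_group('Optimization')
group.add_argument('--optimizer', default='nag', metavar='OPT', choices=OPTIMIZERS,
                   help='optimizer ({})'.format(', '.join(OPTIMIZERS)))
group.add_argument('--adam-betas', default='(0.9, 0.999)', metavar='B',
                   help='betas for Adam optimizer')

args = parser.parse_args(['--optimizer', 'adam', '--adam-betas', '(0.9, 0.98)'])
print(args.optimizer)   # 'adam'
print(args.adam_betas)  # raw string '(0.9, 0.98)'; the trainer eval()s it into a tuple

# A name outside OPTIMIZERS is rejected by argparse's choices check, e.g.:
# parser.parse_args(['--optimizer', 'adadelta'])  # -> error: invalid choice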