"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "c5594795929c9c0274ae4a72cbffb2e03d128efe"
Commit 19fafae6 authored by Myle Ott

Allow --lr to specify a fixed learning rate schedule

parent a233fceb
@@ -68,41 +68,43 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
         self.model = model.cuda()
         self.criterion = criterion.cuda()
 
-        # initialize optimizer
+        # initialize optimizer and LR scheduler
+        self.args.lr = list(map(float, self.args.lr.split(',')))
         self.optimizer = self._build_optimizer()
-        self.loss = None
-
-        # initialize LR scheduler
         self.lr_scheduler = self._build_lr_scheduler()
+        self.loss = None
 
         self._max_bsz_seen = 0
 
     def _build_optimizer(self):
         if self.args.optimizer == 'adagrad':
-            return torch.optim.Adagrad(self.model.parameters(), lr=self.args.lr,
+            return torch.optim.Adagrad(self.model.parameters(), lr=self.args.lr[0],
                                        weight_decay=self.args.weight_decay)
         elif self.args.optimizer == 'adam':
-            return torch.optim.Adam(self.model.parameters(), lr=self.args.lr,
+            return torch.optim.Adam(self.model.parameters(), lr=self.args.lr[0],
                                    betas=eval(self.args.adam_betas),
                                    weight_decay=self.args.weight_decay)
         elif self.args.optimizer == 'nag':
-            return NAG(self.model.parameters(), lr=self.args.lr,
+            return NAG(self.model.parameters(), lr=self.args.lr[0],
                        momentum=self.args.momentum,
                        weight_decay=self.args.weight_decay)
         elif self.args.optimizer == 'sgd':
-            return torch.optim.SGD(self.model.parameters(), lr=self.args.lr,
+            return torch.optim.SGD(self.model.parameters(), lr=self.args.lr[0],
                                    momentum=self.args.momentum,
                                    weight_decay=self.args.weight_decay)
         else:
             raise ValueError('Unknown optimizer: {}'.format(self.args.optimizer))
 
     def _build_lr_scheduler(self):
-        if self.args.force_anneal > 0:
+        if len(self.args.lr) > 1 or self.args.force_anneal > 0:
+            lrs = self.args.lr
             def anneal(e):
                 if e < self.args.force_anneal:
-                    return 1
+                    # use fixed LR schedule
+                    next_lr = lrs[min(e, len(lrs) - 1)]
                 else:
-                    return self.args.lrshrink ** (e + 1 - self.args.force_anneal)
+                    next_lr = lrs[-1] * self.args.lrshrink ** (e + 1 - self.args.force_anneal)
+                return next_lr / lrs[0]  # correct for scaling from LambdaLR
             lr_scheduler = LambdaLR(self.optimizer, anneal)
             lr_scheduler.best = None
         else:
...
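For reference, a minimal standalone sketch of how the new anneal factor drives per-epoch learning rates. The values for lrs, force_anneal, and lrshrink below are hypothetical and only mirror the names used in the hunk above; LambdaLR scales the optimizer's base LR (lrs[0]) by the returned factor, hence the final division.

import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR

# hypothetical settings, e.g. --lr 0.25,0.1,0.05 --force-anneal 3 --lrshrink 0.5
lrs = [0.25, 0.1, 0.05]
force_anneal = 3
lrshrink = 0.5

def anneal(e):
    if e < force_anneal:
        # fixed schedule: epoch e uses lrs[e], then stays at the last listed value
        next_lr = lrs[min(e, len(lrs) - 1)]
    else:
        # past force_anneal, shrink the last listed LR geometrically
        next_lr = lrs[-1] * lrshrink ** (e + 1 - force_anneal)
    return next_lr / lrs[0]  # LambdaLR multiplies the base LR (lrs[0]) by this factor

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = SGD(params, lr=lrs[0])
scheduler = LambdaLR(optimizer, anneal)

for epoch in range(6):
    print(epoch, optimizer.param_groups[0]['lr'])
    optimizer.step()
    scheduler.step()

This prints 0.25, 0.1, 0.05 for epochs 0-2, then 0.025, 0.0125, 0.00625 once the lrshrink decay takes over at epoch 3.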
@@ -49,8 +49,8 @@ def add_optimization_args(parser):
     group.add_argument('--optimizer', default='nag', metavar='OPT',
                        choices=MultiprocessingTrainer.OPTIMIZERS,
                        help='optimizer ({})'.format(', '.join(MultiprocessingTrainer.OPTIMIZERS)))
-    group.add_argument('--lr', '--learning-rate', default=0.25, type=float, metavar='LR',
-                       help='initial learning rate')
+    group.add_argument('--lr', '--learning-rate', default='0.25', metavar='LR1,LR2,...,LRn',
+                       help='learning rate for the first n epochs with all epochs >n using LRn')
     group.add_argument('--min-lr', metavar='LR', default=1e-5, type=float,
                        help='minimum learning rate')
     group.add_argument('--force-anneal', '--fa', default=0, type=int, metavar='N',
...
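And a small standalone sketch of how the revised --lr flag is parsed and then converted to a list of floats in the trainer constructor. The argument definition is copied from the hunk above; the example command-line value is hypothetical.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--lr', '--learning-rate', default='0.25', metavar='LR1,LR2,...,LRn',
                    help='learning rate for the first n epochs with all epochs >n using LRn')
args = parser.parse_args(['--lr', '0.25,0.1,0.05'])

# the trainer constructor now splits the comma-separated string into floats
args.lr = list(map(float, args.lr.split(',')))
print(args.lr)     # [0.25, 0.1, 0.05]
print(args.lr[0])  # base LR handed to the optimizer

With a single value (the default '0.25') the parsed list has length one, so the condition len(self.args.lr) > 1 or self.args.force_anneal > 0 only selects the fixed schedule when --force-anneal is set, preserving the previous default behaviour.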