"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "37a44bb2839c1af18940b6cf38f5639c9c279caf"
Commit ff12df6b authored by mohammad's avatar mohammad
Browse files

Refactored the learning rate scheduler so that adding support for variable batch sizes is easier.

parent 16193619
...@@ -22,25 +22,25 @@ from megatron import print_rank_0 ...@@ -22,25 +22,25 @@ from megatron import print_rank_0
class AnnealingLR(object): class AnnealingLR(object):
"""Anneals the learning rate.""" """Anneals the learning rate."""
def __init__(self, optimizer, start_lr, def __init__(self, optimizer, max_lr, min_lr,
warmup_iter, total_iters, warmup_steps, decay_steps,
decay_style, last_iter, min_lr=0.0, decay_style, num_steps,
use_checkpoint_lr_scheduler=True, use_checkpoint_lr_scheduler=True,
override_lr_scheduler=False): override_lr_scheduler=False):
# Class values. # Class values.
self.optimizer = optimizer self.optimizer = optimizer
self.start_lr = float(start_lr) self.max_lr = float(max_lr)
self.min_lr = min_lr self.min_lr = min_lr
assert self.min_lr >= 0.0 assert self.min_lr >= 0.0
assert self.start_lr >= self.min_lr assert self.max_lr >= self.min_lr
self.warmup_iter = warmup_iter self.warmup_steps = warmup_steps
self.num_iters = last_iter self.num_steps = num_steps
self.end_iter = total_iters self.decay_steps = decay_steps
assert self.end_iter > 0 assert self.decay_steps > 0
assert self.warmup_iter < self.end_iter assert self.warmup_steps < self.decay_steps
self.decay_style = decay_style self.decay_style = decay_style
...@@ -51,7 +51,7 @@ class AnnealingLR(object): ...@@ -51,7 +51,7 @@ class AnnealingLR(object):
'use-checkpoint are set.' 'use-checkpoint are set.'
# Set the learning rate # Set the learning rate
self.step(self.num_iters) self.step(step_num=self.num_steps)
print_rank_0('> learning rate decay style: {}'.format(self.decay_style)) print_rank_0('> learning rate decay style: {}'.format(self.decay_style))
...@@ -61,25 +61,25 @@ class AnnealingLR(object): ...@@ -61,25 +61,25 @@ class AnnealingLR(object):
https://openreview.net/pdf?id=BJYwwY9ll pg. 4""" https://openreview.net/pdf?id=BJYwwY9ll pg. 4"""
# Use linear warmup for the initial part. # Use linear warmup for the initial part.
if self.warmup_iter > 0 and self.num_iters <= self.warmup_iter: if self.warmup_steps > 0 and self.num_steps <= self.warmup_steps:
return self.start_lr * float(self.num_iters) / \ return self.max_lr * float(self.num_steps) / \
float(self.warmup_iter) float(self.warmup_steps)
# If the learning rate is constant, just return the initial value. # If the learning rate is constant, just return the initial value.
if self.decay_style == 'constant': if self.decay_style == 'constant':
return self.start_lr return self.max_lr
# For any iterations larger than `self.end_iter`, use `self.min_lr`. # For any steps larger than `self.decay_steps`, use `self.min_lr`.
if self.num_iters > self.end_iter: if self.num_steps > self.decay_steps:
return self.min_lr return self.min_lr
# If we are done with the warmup period, use the decay style. # If we are done with the warmup period, use the decay style.
current_iter = self.num_iters - self.warmup_iter num_steps_ = self.num_steps - self.warmup_steps
decay_iters = self.end_iter - self.warmup_iter decay_steps_ = self.decay_steps - self.warmup_steps
decay_ratio = float(current_iter) / float(decay_iters) decay_ratio = float(num_steps_) / float(decay_steps_)
assert decay_ratio >= 0.0 assert decay_ratio >= 0.0
assert decay_ratio <= 1.0 assert decay_ratio <= 1.0
delta_lr = self.start_lr - self.min_lr delta_lr = self.max_lr - self.min_lr
if self.decay_style == 'linear': if self.decay_style == 'linear':
coeff = (1.0 - decay_ratio) coeff = (1.0 - decay_ratio)
...@@ -92,11 +92,11 @@ class AnnealingLR(object): ...@@ -92,11 +92,11 @@ class AnnealingLR(object):
return self.min_lr + coeff * delta_lr return self.min_lr + coeff * delta_lr
def step(self, increment=1, step_num=None):
    """Advance the step counter and set the LR on all optimizer param groups.

    Args:
        increment: how many steps to advance when ``step_num`` is not given.
            Defaults to 1, so the pre-refactor call ``step()`` behaves the same.
        step_num: absolute step count to jump to; when provided it overrides
            ``increment`` entirely (used when restoring from a checkpoint).
    """
    if step_num is None:
        # Normal training path: move forward by `increment` steps.
        step_num = self.num_steps + increment
    self.num_steps = step_num
    new_lr = self.get_lr()
    # The same learning rate is applied uniformly to every parameter group
    # of the wrapped optimizer.
    for group in self.optimizer.param_groups:
        group['lr'] = new_lr
...@@ -104,11 +104,11 @@ class AnnealingLR(object): ...@@ -104,11 +104,11 @@ class AnnealingLR(object):
def state_dict(self): def state_dict(self):
state_dict = { state_dict = {
'start_lr': self.start_lr, 'max_lr': self.max_lr,
'warmup_iter': self.warmup_iter, 'warmup_steps': self.warmup_steps,
'num_iters': self.num_iters, 'num_steps': self.num_steps,
'decay_style': self.decay_style, 'decay_style': self.decay_style,
'end_iter': self.end_iter, 'decay_steps': self.decay_steps,
'min_lr': self.min_lr 'min_lr': self.min_lr
} }
return state_dict return state_dict
...@@ -131,18 +131,36 @@ class AnnealingLR(object): ...@@ -131,18 +131,36 @@ class AnnealingLR(object):
def load_state_dict(self, sd):
    """Restore scheduler state from a checkpoint state dict.

    Accepts both the new key names (``max_lr``, ``warmup_steps``,
    ``decay_steps``, ``num_steps``) and the legacy pre-rename names
    (``start_lr``, ``warmup_iter``, ``end_iter``, ``num_iters``) so old
    checkpoints keep loading. Each value goes through ``_check_and_set``,
    which (per its call sites here) arbitrates between the constructor
    value and the checkpoint value.

    Args:
        sd: state dict previously produced by ``state_dict`` (either the
            current or the legacy format).
    """
    # Legacy checkpoints stored the peak LR under 'start_lr'.
    if 'start_lr' in sd:
        max_lr_ = sd['start_lr']
    else:
        max_lr_ = sd['max_lr']
    self.max_lr = self._check_and_set(self.max_lr, max_lr_,
                                      'learning rate')

    self.min_lr = self._check_and_set(self.min_lr, sd['min_lr'],
                                      'minimum learning rate')

    # Legacy name: 'warmup_iter' -> 'warmup_steps'.
    if 'warmup_iter' in sd:
        warmup_steps_ = sd['warmup_iter']
    else:
        warmup_steps_ = sd['warmup_steps']
    self.warmup_steps = self._check_and_set(self.warmup_steps,
                                            warmup_steps_,
                                            'warmup iterations')

    # Legacy name: 'end_iter' -> 'decay_steps'.
    if 'end_iter' in sd:
        decay_steps_ = sd['end_iter']
    else:
        decay_steps_ = sd['decay_steps']
    self.decay_steps = self._check_and_set(self.decay_steps, decay_steps_,
                                           'total number of iterations')

    self.decay_style = self._check_and_set(self.decay_style,
                                           sd['decay_style'],
                                           'decay style')

    # Legacy name: 'num_iters' -> 'num_steps'.
    if 'num_iters' in sd:
        self.num_steps = sd['num_iters']
    else:
        self.num_steps = sd['num_steps']
    # Re-apply the restored step count so the optimizer's LR is set now.
    self.step(step_num=self.num_steps)
...@@ -196,12 +196,12 @@ def get_learning_rate_scheduler(optimizer): ...@@ -196,12 +196,12 @@ def get_learning_rate_scheduler(optimizer):
warmup_iter = args.warmup * num_iters warmup_iter = args.warmup * num_iters
lr_scheduler = AnnealingLR( lr_scheduler = AnnealingLR(
optimizer, optimizer,
start_lr=args.lr, max_lr=args.lr,
warmup_iter=warmup_iter,
total_iters=num_iters,
decay_style=args.lr_decay_style,
last_iter=init_step,
min_lr=args.min_lr, min_lr=args.min_lr,
warmup_steps=warmup_iter,
decay_steps=num_iters,
decay_style=args.lr_decay_style,
num_steps=init_step,
use_checkpoint_lr_scheduler=args.use_checkpoint_lr_scheduler, use_checkpoint_lr_scheduler=args.use_checkpoint_lr_scheduler,
override_lr_scheduler=args.override_lr_scheduler) override_lr_scheduler=args.override_lr_scheduler)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment