Commit b6e0377b authored by Mohammad

refactored learning-rate

parent 3366a5b0
@@ -12,59 +12,68 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""PyTorch DataLoader for TFRecords"""
-import torch
-from torch.optim.lr_scheduler import _LRScheduler
+"""Learning rate decay functions."""
 import math
+from megatron import print_rank_0

-class AnnealingLR(_LRScheduler):
-    """Anneals the learning rate"""
-    DECAY_STYLES = ['linear', 'cosine', 'exponential', 'constant', 'None']
+class AnnealingLR(object):
+    """Anneals the learning rate."""

-    def __init__(self, optimizer, start_lr, warmup_iter, num_iters,
-                 decay_style=None, last_iter=-1, min_lr=0.0,
+    def __init__(self, optimizer, start_lr,
+                 warmup_iter, total_iters,
+                 decay_style, last_iter, min_lr=0.0,
                  use_checkpoint_lr_scheduler=True,
                  override_lr_scheduler=False):
+        # Class values.
         self.optimizer = optimizer
         self.start_lr = start_lr
         self.min_lr = min_lr
         self.warmup_iter = warmup_iter
-        self.num_iters = last_iter + 1
-        self.end_iter = num_iters
-        self.decay_style = decay_style.lower() if isinstance(decay_style, str) \
-                           else None
+        self.num_iters = last_iter
+        self.end_iter = total_iters
+        assert self.end_iter > 0
+        self.decay_style = decay_style
         self.override_lr_scheduler = override_lr_scheduler
         self.use_checkpoint_lr_scheduler = use_checkpoint_lr_scheduler
         if self.override_lr_scheduler:
             assert not self.use_checkpoint_lr_scheduler, 'both override and '\
                 'use-checkpoint are set.'
         # Set the learning rate
         self.step(self.num_iters)
-        if torch.distributed.get_rank() == 0:
-            print('learning rate decaying', decay_style)
+        print_rank_0('> learning rate decay style: {}'.format(self.decay_style))

     def get_lr(self):
-        # https://openreview.net/pdf?id=BJYwwY9ll pg. 4
+        """Learning rate decay functions from:
+              https://openreview.net/pdf?id=BJYwwY9ll pg. 4"""
         num_iters_ = min(self.num_iters, self.end_iter - self.warmup_iter)
         # Warmup.
         if self.warmup_iter > 0 and self.num_iters <= self.warmup_iter:
             return float(self.start_lr) * num_iters_ / self.warmup_iter
+        num_iters_ = num_iters_ - self.warmup_iter
+        if self.decay_style == 'linear':
+            lr = self.start_lr * (self.end_iter - num_iters_) / self.end_iter
+        elif self.decay_style == 'cosine':
+            lr = self.start_lr / 2.0 * (math.cos(
+                math.pi * num_iters_ / self.end_iter) + 1)
+        elif self.decay_style == 'exponential':
+            # exp(-0.693) = 1/2
+            lr = self.start_lr * math.exp(-0.693 * num_iters_ / self.end_iter)
         else:
-            if self.decay_style == self.DECAY_STYLES[0]:
-                lr = self.start_lr * ((self.end_iter - (num_iters_ - self.warmup_iter)) / self.end_iter)
-            elif self.decay_style == self.DECAY_STYLES[1]:
-                lr = self.start_lr / 2.0 * (math.cos(math.pi * (num_iters_ - self.warmup_iter) / self.end_iter) + 1)
-            elif self.decay_style == self.DECAY_STYLES[2]:
-                # exp(-0.693) = 1/2
-                lr = self.start_lr * math.exp(-0.693 * (num_iters_ - self.warmup_iter) / self.end_iter)
-            else:
-                lr = self.start_lr
-            return max(lr, self.min_lr)
+            lr = self.start_lr
+        return max(lr, self.min_lr)

     def step(self, step_num=None):
         """Set lr for all parameters groups."""
         if step_num is None:
             step_num = self.num_iters + 1
         self.num_iters = step_num
@@ -72,42 +81,46 @@ class AnnealingLR(_LRScheduler):
         for group in self.optimizer.param_groups:
             group['lr'] = new_lr

     def state_dict(self):
-        sd = {
-                'start_lr': self.start_lr,
-                'warmup_iter': self.warmup_iter,
-                'num_iters': self.num_iters,
-                'decay_style': self.decay_style,
-                'end_iter': self.end_iter,
-                'min_lr': self.min_lr
+        state_dict = {
+            'start_lr': self.start_lr,
+            'warmup_iter': self.warmup_iter,
+            'num_iters': self.num_iters,
+            'decay_style': self.decay_style,
+            'end_iter': self.end_iter,
+            'min_lr': self.min_lr
         }
-        return sd
+        return state_dict

-    def check_and_set_(self, cls_value, sd_value, name):
+    def _check_and_set(self, cls_value, sd_value, name):
         """Auxiliary function for checking the values in the checkpoint and
         setting them."""
         if self.override_lr_scheduler:
             print_rank_0(' > overriding {} value to {}'.format(name, cls_value))
             return cls_value
-        else:
-            if not self.use_checkpoint_lr_scheduler:
-                assert cls_value == sd_value, 'AnnealingLR: class input value' \
-                    'and checkpoint values for {} do not match'.format(name)
-            print_rank_0(' > using checkpoint value {} for {}'.format(sd_value,
-                                                                      name))
-            return sd_value
+        if not self.use_checkpoint_lr_scheduler:
+            assert cls_value == sd_value, 'AnnealingLR: class input value' \
+                'and checkpoint values for {} do not match'.format(name)
+        print_rank_0(' > using checkpoint value {} for {}'.format(sd_value,
+                                                                  name))
+        return sd_value

     def load_state_dict(self, sd):
-        self.start_lr = self.check_and_set_(self.start_lr, sd['start_lr'],
+        self.start_lr = self._check_and_set(self.start_lr, sd['start_lr'],
                                             'learning rate')
-        self.min_lr = self.check_and_set_(self.min_lr, sd['min_lr'],
+        self.min_lr = self._check_and_set(self.min_lr, sd['min_lr'],
                                           'minimum learning rate')
-        self.warmup_iter = self.check_and_set_(self.warmup_iter,
+        self.warmup_iter = self._check_and_set(self.warmup_iter,
                                                sd['warmup_iter'],
                                                'warmup iterations')
-        self.end_iter = self.check_and_set_(self.end_iter, sd['end_iter'],
+        self.end_iter = self._check_and_set(self.end_iter, sd['end_iter'],
                                             'total number of iterations')
-        self.decay_style = self.check_and_set_(self.decay_style,
+        self.decay_style = self._check_and_set(self.decay_style,
                                                sd['decay_style'],
                                                'decay style')
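As a sanity check on the refactored get_lr() above, here is a minimal standalone sketch (not part of the commit) that reproduces the warmup and the three decay styles as a plain function; the function name and the example numbers are illustrative only.

import math

def decayed_lr(start_lr, num_iters, warmup_iter, end_iter,
               decay_style='cosine', min_lr=0.0):
    # Mirrors AnnealingLR.get_lr(): linear warmup followed by the chosen decay.
    iters = min(num_iters, end_iter - warmup_iter)
    if warmup_iter > 0 and num_iters <= warmup_iter:
        # Linear warmup from 0 to start_lr over warmup_iter steps.
        return float(start_lr) * iters / warmup_iter
    iters -= warmup_iter
    if decay_style == 'linear':
        lr = start_lr * (end_iter - iters) / end_iter
    elif decay_style == 'cosine':
        lr = start_lr / 2.0 * (math.cos(math.pi * iters / end_iter) + 1)
    elif decay_style == 'exponential':
        # exp(-0.693) ~= 1/2, so the rate halves over end_iter post-warmup steps.
        lr = start_lr * math.exp(-0.693 * iters / end_iter)
    else:
        lr = start_lr  # constant
    return max(lr, min_lr)

# Example: decayed_lr(1.5e-4, 50000, 2000, 100000, 'cosine') ~= 8.0e-5,
# i.e. roughly half of start_lr after 48k post-warmup steps of a 100k-step schedule.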
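The checkpoint handling in _check_and_set() above resolves conflicts between constructor arguments and values restored from a checkpoint. A hypothetical round-trip, assuming `scheduler` is an AnnealingLR instance built from the current command-line arguments:

state = scheduler.state_dict()   # saved alongside the model checkpoint
# ... later, possibly after a restart with different arguments ...
scheduler.load_state_dict(state)
# override_lr_scheduler=True       -> constructor values win; checkpoint values are ignored
# use_checkpoint_lr_scheduler=True -> checkpoint values win (the default)
# both False                       -> the two values must match, otherwise the assert fires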
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Megatron Module"""
 import torch
@@ -197,13 +197,13 @@ def get_learning_rate_scheduler(optimizer):
     else:
         num_iters = args.train_iters
     num_iters = max(1, num_iters)
-    init_step = -1
+    init_step = 0
     warmup_iter = args.warmup * num_iters
     lr_scheduler = AnnealingLR(
         optimizer,
         start_lr=args.lr,
         warmup_iter=warmup_iter,
-        num_iters=num_iters,
+        total_iters=num_iters,
         decay_style=args.lr_decay_style,
         last_iter=init_step,
         min_lr=args.min_lr,
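For reference, a minimal construction sketch under the refactored signature, mirroring the call site above; the throwaway optimizer and all numeric values are placeholders, and the import path assumes this repository's layout (megatron/learning_rates.py):

import torch
from megatron.learning_rates import AnnealingLR

# A throwaway optimizer so the scheduler has param_groups to update.
params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.Adam(params, lr=1.5e-4)

lr_scheduler = AnnealingLR(
    optimizer,
    start_lr=1.5e-4,
    warmup_iter=0.01 * 300000,   # args.warmup * num_iters
    total_iters=300000,          # renamed from num_iters in this commit
    decay_style='cosine',        # args.lr_decay_style
    last_iter=0,                 # init_step now starts at 0 instead of -1
    min_lr=1e-5)

for iteration in range(1, 11):
    lr_scheduler.step()          # advances num_iters and writes the new lr into param_groups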
@@ -89,8 +89,6 @@ def check_adlr_autoresume_termination(iteration, model,
-###################################################
 from megatron import mpu

 def get_ltor_masks_and_position_ids(data,
                                     eod_token,
@@ -148,16 +146,3 @@ def get_ltor_masks_and_position_ids(data,
     return attention_mask, loss_mask, position_ids
-
-
-def vocab_size_with_padding(num_tokens, args):
-    after = num_tokens
-    multiple = args.make_vocab_size_divisible_by * \
-               mpu.get_model_parallel_world_size()
-    while (after % multiple) != 0:
-        after += 1
-    print_rank_0('> padded vocab (size: {}) with {} dummy '
-                 'tokens (new size: {})'.format(
-                     num_tokens, after - num_tokens, after))
-    return after
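Although vocab_size_with_padding() is removed from this file by the commit, the padding arithmetic it performs is worth a worked example (illustrative numbers; the real multiple comes from args.make_vocab_size_divisible_by and the model-parallel world size):

num_tokens = 50257        # e.g. the GPT-2 BPE vocabulary
multiple = 128 * 8        # make_vocab_size_divisible_by=128, 8-way model parallelism
after = num_tokens
while after % multiple != 0:
    after += 1
# after == 51200, so 943 dummy tokens are appended to make the padded
# vocabulary divisible by the model-parallel partition size.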