Commit b6e0377b authored by Mohammad

refactored learning-rate

parent 3366a5b0
@@ -12,59 +12,68 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.

-"""PyTorch DataLoader for TFRecords"""
+"""Learning rate decay functions."""

-import torch
-from torch.optim.lr_scheduler import _LRScheduler
 import math

 from megatron import print_rank_0


-class AnnealingLR(_LRScheduler):
-    """Anneals the learning rate"""
-
-    DECAY_STYLES = ['linear', 'cosine', 'exponential', 'constant', 'None']
-
-    def __init__(self, optimizer, start_lr, warmup_iter, num_iters,
-                 decay_style=None, last_iter=-1, min_lr=0.0,
+class AnnealingLR(object):
+    """Anneals the learning rate."""
+
+    def __init__(self, optimizer, start_lr,
+                 warmup_iter, total_iters,
+                 decay_style, last_iter, min_lr=0.0,
                  use_checkpoint_lr_scheduler=True,
                  override_lr_scheduler=False):
+
+        # Class values.
         self.optimizer = optimizer
         self.start_lr = start_lr
         self.min_lr = min_lr
         self.warmup_iter = warmup_iter
-        self.num_iters = last_iter + 1
-        self.end_iter = num_iters
-        self.decay_style = decay_style.lower() if isinstance(decay_style, str) \
-                           else None
+        self.num_iters = last_iter
+        self.end_iter = total_iters
+        assert self.end_iter > 0
+        self.decay_style = decay_style
         self.override_lr_scheduler = override_lr_scheduler
         self.use_checkpoint_lr_scheduler = use_checkpoint_lr_scheduler
         if self.override_lr_scheduler:
             assert not self.use_checkpoint_lr_scheduler, 'both override and '\
                 'use-checkpoint are set.'
+
+        # Set the learning rate
         self.step(self.num_iters)
-        if torch.distributed.get_rank() == 0:
-            print('learning rate decaying', decay_style)
+
+        print_rank_0('> learning rate decay style: {}'.format(self.decay_style))

     def get_lr(self):
-        # https://openreview.net/pdf?id=BJYwwY9ll pg. 4
+        """Learning rate decay functions from:
+        https://openreview.net/pdf?id=BJYwwY9ll pg. 4"""
+
         num_iters_ = min(self.num_iters, self.end_iter - self.warmup_iter)
+        # Warmup.
         if self.warmup_iter > 0 and self.num_iters <= self.warmup_iter:
             return float(self.start_lr) * num_iters_ / self.warmup_iter
-        else:
-            if self.decay_style == self.DECAY_STYLES[0]:
-                lr = self.start_lr * ((self.end_iter - (num_iters_ - self.warmup_iter)) / self.end_iter)
-            elif self.decay_style == self.DECAY_STYLES[1]:
-                lr = self.start_lr / 2.0 * (math.cos(math.pi * (num_iters_ - self.warmup_iter) / self.end_iter) + 1)
-            elif self.decay_style == self.DECAY_STYLES[2]:
-                # exp(-0.693) = 1/2
-                lr = self.start_lr * math.exp(-0.693 * (num_iters_ - self.warmup_iter) / self.end_iter)
-            else:
-                lr = self.start_lr
-            return max(lr, self.min_lr)
+
+        num_iters_ = num_iters_ - self.warmup_iter
+        if self.decay_style == 'linear':
+            lr = self.start_lr * (self.end_iter - num_iters_) / self.end_iter
+        elif self.decay_style == 'cosine':
+            lr = self.start_lr / 2.0 * (math.cos(
+                math.pi * num_iters_ / self.end_iter) + 1)
+        elif self.decay_style == 'exponential':
+            # exp(-0.693) = 1/2
+            lr = self.start_lr * math.exp(-0.693 * num_iters_ / self.end_iter)
+        else:
+            lr = self.start_lr
+        return max(lr, self.min_lr)

     def step(self, step_num=None):
+        """Set lr for all parameter groups."""
         if step_num is None:
             step_num = self.num_iters + 1
         self.num_iters = step_num
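For reference, here is a minimal standalone sketch of what the refactored get_lr() computes. It is not part of the commit; the hyperparameters (start_lr=1.5e-4, 100 warmup iterations out of 1000 total, min_lr=1e-5) are made-up values for illustration:

import math

# Assumed toy hyperparameters, for illustration only.
start_lr, min_lr = 1.5e-4, 1.0e-5
warmup_iter, end_iter = 100, 1000

def lr_at(num_iters, decay_style='cosine'):
    # Mirrors the refactored AnnealingLR.get_lr() arithmetic.
    iters = min(num_iters, end_iter - warmup_iter)
    if warmup_iter > 0 and num_iters <= warmup_iter:
        return float(start_lr) * iters / warmup_iter      # linear warmup
    iters = iters - warmup_iter
    if decay_style == 'linear':
        lr = start_lr * (end_iter - iters) / end_iter
    elif decay_style == 'cosine':
        lr = start_lr / 2.0 * (math.cos(math.pi * iters / end_iter) + 1)
    elif decay_style == 'exponential':
        lr = start_lr * math.exp(-0.693 * iters / end_iter)  # halves over end_iter
    else:
        lr = start_lr                                     # constant
    return max(lr, min_lr)

for it in (0, 50, 100, 500, 900):
    print(it, lr_at(it))

The warmup branch rises linearly from 0 to start_lr over the first 100 iterations; after that, the chosen decay style takes over, floored at min_lr.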
@@ -72,42 +81,46 @@ class AnnealingLR(_LRScheduler):
         for group in self.optimizer.param_groups:
             group['lr'] = new_lr

     def state_dict(self):
-        sd = {
+        state_dict = {
             'start_lr': self.start_lr,
             'warmup_iter': self.warmup_iter,
             'num_iters': self.num_iters,
             'decay_style': self.decay_style,
             'end_iter': self.end_iter,
             'min_lr': self.min_lr
         }
-        return sd
+        return state_dict

-    def check_and_set_(self, cls_value, sd_value, name):
+    def _check_and_set(self, cls_value, sd_value, name):
+        """Auxiliary function for checking the values in the checkpoint and
+        setting them."""
         if self.override_lr_scheduler:
             print_rank_0(' > overriding {} value to {}'.format(name, cls_value))
             return cls_value
-        else:
-            if not self.use_checkpoint_lr_scheduler:
-                assert cls_value == sd_value, 'AnnealingLR: class input value' \
-                    'and checkpoint values for {} do not match'.format(name)
-            print_rank_0(' > using checkpoint value {} for {}'.format(sd_value,
-                                                                      name))
-            return sd_value
+
+        if not self.use_checkpoint_lr_scheduler:
+            assert cls_value == sd_value, 'AnnealingLR: class input value ' \
+                'and checkpoint values for {} do not match'.format(name)
+        print_rank_0(' > using checkpoint value {} for {}'.format(sd_value,
+                                                                  name))
+        return sd_value

     def load_state_dict(self, sd):
-        self.start_lr = self.check_and_set_(self.start_lr, sd['start_lr'],
-                                            'learning rate')
-        self.min_lr = self.check_and_set_(self.min_lr, sd['min_lr'],
-                                          'minimum learning rate')
-        self.warmup_iter = self.check_and_set_(self.warmup_iter,
-                                               sd['warmup_iter'],
-                                               'warmup iterations')
-        self.end_iter = self.check_and_set_(self.end_iter, sd['end_iter'],
-                                            'total number of iterations')
-        self.decay_style = self.check_and_set_(self.decay_style,
-                                               sd['decay_style'],
-                                               'decay style')
+        self.start_lr = self._check_and_set(self.start_lr, sd['start_lr'],
+                                            'learning rate')
+        self.min_lr = self._check_and_set(self.min_lr, sd['min_lr'],
+                                          'minimum learning rate')
+        self.warmup_iter = self._check_and_set(self.warmup_iter,
+                                               sd['warmup_iter'],
+                                               'warmup iterations')
+        self.end_iter = self._check_and_set(self.end_iter, sd['end_iter'],
+                                            'total number of iterations')
+        self.decay_style = self._check_and_set(self.decay_style,
+                                               sd['decay_style'],
+                                               'decay style')
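Since state_dict()/load_state_dict() interact with the override and use-checkpoint flags, a round-trip sketch may help. None of this is in the commit; the toy SGD optimizer and the import path are assumptions:

import torch
from megatron.learning_rates import AnnealingLR  # module path assumed

# Toy optimizer so the scheduler has a param group to write into.
opt = torch.optim.SGD([torch.zeros(1, requires_grad=True)], lr=1.5e-4)
sched = AnnealingLR(opt, start_lr=1.5e-4, warmup_iter=100, total_iters=1000,
                    decay_style='cosine', last_iter=0)

for _ in range(10):
    sched.step()                # advances num_iters and updates group['lr']

state = sched.state_dict()      # plain dict of the six values above

# Fresh instance, then restore. With use_checkpoint_lr_scheduler=True (the
# default) the checkpoint values win; with override_lr_scheduler=True the
# constructor arguments win instead, and setting both flags trips the assert.
resumed = AnnealingLR(opt, start_lr=1.5e-4, warmup_iter=100, total_iters=1000,
                      decay_style='cosine', last_iter=0)
resumed.load_state_dict(state)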
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

 """Megatron Module"""

-
 import torch
@@ -197,13 +197,13 @@ def get_learning_rate_scheduler(optimizer):
     else:
         num_iters = args.train_iters
     num_iters = max(1, num_iters)
-    init_step = -1
+    init_step = 0
     warmup_iter = args.warmup * num_iters
     lr_scheduler = AnnealingLR(
         optimizer,
         start_lr=args.lr,
         warmup_iter=warmup_iter,
-        num_iters=num_iters,
+        total_iters=num_iters,
         decay_style=args.lr_decay_style,
         last_iter=init_step,
         min_lr=args.min_lr,
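Note that a fresh run still starts at iteration 0: the old code passed last_iter=-1 and the constructor added one, while the new code passes init_step=0 and stores it directly. As a worked example of the warmup computation (the numbers are assumed, not from the repo):

# Assumed values for args.warmup and args.train_iters, for illustration only.
warmup = 0.01                      # warmup expressed as a fraction of training
train_iters = 300000
num_iters = max(1, train_iters)
warmup_iter = warmup * num_iters   # 3000.0: the first 3000 iterations warm up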
@@ -89,8 +89,6 @@ def check_adlr_autoresume_termination(iteration, model,
 ###################################################

-from megatron import mpu
-

 def get_ltor_masks_and_position_ids(data,
                                     eod_token,
@@ -148,16 +146,3 @@ def get_ltor_masks_and_position_ids(data,

     return attention_mask, loss_mask, position_ids
-
-
-def vocab_size_with_padding(num_tokens, args):
-    after = num_tokens
-    multiple = args.make_vocab_size_divisible_by * \
-               mpu.get_model_parallel_world_size()
-    while (after % multiple) != 0:
-        after += 1
-    print_rank_0('> padded vocab (size: {}) with {} dummy '
-                 'tokens (new size: {})'.format(
-                     num_tokens, after - num_tokens, after))
-    return after
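The deleted helper rounded the vocabulary size up to the next multiple of make_vocab_size_divisible_by times the model-parallel world size. A self-contained sketch of that arithmetic, with assumed numbers (a 50257-token GPT-2 vocabulary, divisor 128, 2-way model parallelism):

num_tokens = 50257   # assumed: GPT-2 vocabulary size
multiple = 128 * 2   # make_vocab_size_divisible_by * model-parallel world size
after = num_tokens
while after % multiple != 0:
    after += 1
print(num_tokens, after - num_tokens, after)   # 50257 175 50432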