Unverified Commit 1c0b326e authored by Olatunji Ruwase, committed by GitHub

Make lr schedulers support fp16 optimizers (#124)



* Add test cases for the OneCycle policy with fp16/ZeRO

* Make lr schedulers support fp16 optimizers

* Fix formatting

* More specific naming
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
parent 27d83851
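Background for the change: the schedulers previously required isinstance(optimizer, torch.optim.Optimizer), which rejects DeepSpeed's FP16 optimizer wrappers because they are not Optimizer subclasses. The new get_torch_optimizer helper also accepts a wrapper that exposes the underlying torch optimizer through an `optimizer` attribute. Below is a minimal sketch of that idea; FP16OptimizerWrapper is a hypothetical stand-in, not DeepSpeed's actual FP16 optimizer class.

import torch
from torch.optim import Optimizer, Adam


class FP16OptimizerWrapper:
    """Hypothetical stand-in for an FP16 optimizer wrapper: not a
    torch.optim.Optimizer subclass, but it holds the real optimizer."""
    def __init__(self, optimizer):
        self.optimizer = optimizer  # the wrapped torch optimizer


def get_torch_optimizer(optimizer):
    # Mirrors the helper added in this commit: accept either a real torch
    # optimizer or a wrapper that exposes one via an `optimizer` attribute.
    if isinstance(optimizer, Optimizer):
        return optimizer
    if hasattr(optimizer, 'optimizer') and isinstance(optimizer.optimizer, Optimizer):
        return optimizer.optimizer
    raise TypeError('{} is not a subclass of torch.optim.Optimizer'.format(
        type(optimizer).__name__))


params = [torch.nn.Parameter(torch.zeros(1))]
wrapped = FP16OptimizerWrapper(Adam(params, lr=1e-3))

# Schedulers can now reach the inner optimizer's param_groups through the wrapper.
inner = get_torch_optimizer(wrapped)
print(inner.param_groups[0]['lr'])  # 0.001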
@@ -13,6 +13,7 @@ from torch.optim import Optimizer
from typing import Union, List
import math
from deepspeed.pt.deepspeed_constants import *
import logging
LR_SCHEDULE = 'lr_schedule'
LR_RANGE_TEST = 'LRRangeTest'
@@ -277,6 +278,24 @@ def get_lr_from_config(config):
return lr_params[WARMUP_MAX_LR], ''
"""
Only optimizers that are subclasses of torch.optim.Optimizer are supported. So check the passed optimizer and the wrapped
optimizer to see if the requirement is satisfied.
TODO: Looking under the hood to examine the wrapped optimizer is a hack that requires a better long-term fix.
"""
def get_torch_optimizer(optimizer):
if isinstance(optimizer, Optimizer):
return optimizer
if hasattr(optimizer, 'optimizer') and isinstance(optimizer.optimizer, Optimizer):
return optimizer.optimizer
raise TypeError('{} is not a subclass of torch.optim.Optimizer'.format(
type(optimizer).__name__))
class LRRangeTest(object):
"""Sets the learning rate of each parameter group according to
learning rate range test (LRRT) policy. The policy increases learning
@@ -323,20 +342,18 @@ class LRRangeTest(object):
lr_range_test_staircase: bool = False,
last_batch_iteration: int = -1):
if not isinstance(optimizer, Optimizer):
raise TypeError('{} is not an Optimizer'.format(type(optimizer).__name__))
self.optimizer = optimizer
self.optimizer = get_torch_optimizer(optimizer)
if isinstance(lr_range_test_min_lr,
list) or isinstance(lr_range_test_min_lr,
tuple):
if len(lr_range_test_min_lr) != len(optimizer.param_groups):
if len(lr_range_test_min_lr) != len(self.optimizer.param_groups):
raise ValueError("expected {} lr_range_test_min_lr, got {}".format(
len(optimizer.param_groups),
len(self.optimizer.param_groups),
len(lr_range_test_min_lr)))
self.min_lr = list(lr_range_test_min_lr)
else:
self.min_lr = [lr_range_test_min_lr] * len(optimizer.param_groups)
self.min_lr = [lr_range_test_min_lr] * len(self.optimizer.param_groups)
self.step_size = lr_range_test_step_size
self.step_rate = lr_range_test_step_rate
@@ -463,41 +480,86 @@ class OneCycle(object):
decay_mom_rate=0.,
last_batch_iteration=-1):
if not isinstance(optimizer, Optimizer):
raise TypeError('{} is not an Optimizer'.format(type(optimizer).__name__))
self.optimizer = optimizer
self.optimizer = get_torch_optimizer(optimizer)
self.min_lrs = [cycle_min_lr] * len(optimizer.param_groups)
if last_batch_iteration == -1:
for lr, group in zip(self.min_lrs, optimizer.param_groups):
group['lr'] = lr
# Initialize cycle shape
self._initialize_cycle(cycle_first_step_size,
cycle_second_step_size,
cycle_first_stair_count,
cycle_second_stair_count,
decay_step_size)
self.max_lrs = [cycle_max_lr] * len(optimizer.param_groups)
# Initialize cycle lr
self._initialize_lr(self.optimizer,
cycle_min_lr,
cycle_max_lr,
decay_lr_rate,
last_batch_iteration)
# Initialize cyclic momentum
self.cycle_momentum = cycle_momentum
if cycle_momentum:
self._initialize_momentum(self.optimizer,
cycle_min_mom,
cycle_max_mom,
decay_mom_rate,
last_batch_iteration)
# Initialize batch iteration tracker
self.last_batch_iteration = last_batch_iteration
# Configure cycle shape
def _initialize_cycle(self,
cycle_first_step_size,
cycle_second_step_size,
cycle_first_stair_count,
cycle_second_stair_count,
decay_step_size):
cycle_first_step_size = float(cycle_first_step_size)
cycle_second_step_size = float(
cycle_second_step_size
) if cycle_second_step_size is not None else cycle_first_step_size
self.total_size = cycle_first_step_size + cycle_second_step_size
self.step_ratio = cycle_first_step_size / self.total_size
self.first_stair_count = cycle_first_stair_count
self.second_stair_count = cycle_first_stair_count if cycle_second_stair_count is None else cycle_second_stair_count
self.decay_lr_rate = decay_lr_rate
self.decay_mom_rate = decay_mom_rate
self.decay_step_size = decay_step_size
self.min_moms = [(cycle_min_mom, 0.99)] * len(optimizer.param_groups)
self.max_moms = [(cycle_max_mom, 0.99)] * len(optimizer.param_groups)
self.cycle_momentum = cycle_momentum
# Configure lr schedule
def _initialize_lr(self,
optimizer,
cycle_min_lr,
cycle_max_lr,
decay_lr_rate,
last_batch_iteration):
self.min_lrs = [cycle_min_lr] * len(optimizer.param_groups)
if last_batch_iteration == -1:
for lr, group in zip(self.min_lrs, optimizer.param_groups):
group['lr'] = lr
self.last_batch_iteration = last_batch_iteration
self.max_lrs = [cycle_max_lr] * len(optimizer.param_groups)
self.decay_lr_rate = decay_lr_rate
if cycle_momentum:
# Configure momentum schedule
def _initialize_momentum(self,
optimizer,
cycle_min_mom,
cycle_max_mom,
decay_mom_rate,
last_batch_iteration):
if 'betas' not in optimizer.defaults:
raise ValueError(
'optimizer must support betas with `cycle_momentum` option enabled')
optimizer_name = type(optimizer).__name__
logging.warn(
f"cycle_momentum is disabled because optimizer {optimizer_name} does not support momentum, no betas attribute in defaults"
)
self.cycle_momentum = False
return
self.decay_mom_rate = decay_mom_rate
self.min_moms = [(cycle_min_mom, 0.99)] * len(optimizer.param_groups)
self.max_moms = [(cycle_max_mom, 0.99)] * len(optimizer.param_groups)
if last_batch_iteration == -1:
for momentum, group in zip(self.min_moms, optimizer.param_groups):
@@ -606,10 +668,10 @@ class WarmupLR(object):
warmup_num_steps: int = 1000,
last_batch_iteration: int = -1):
self.optimizer = optimizer
self.optimizer = get_torch_optimizer(optimizer)
self.min_lrs = self._format_param(optimizer, warmup_min_lr, "min_lr")
self.max_lrs = self._format_param(optimizer, warmup_max_lr, "max_lr")
self.min_lrs = self._format_param(self.optimizer, warmup_min_lr, "min_lr")
self.max_lrs = self._format_param(self.optimizer, warmup_max_lr, "max_lr")
self.delta_lrs = [big - small for big, small in zip(self.max_lrs, self.min_lrs)]
self.warmup_num_steps = warmup_num_steps
self.inverse_log_warm_up = 1.0 / math.log(warmup_num_steps)
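A note on the momentum handling introduced above: the refactored _initialize_momentum cycles momentum only when the optimizer's defaults contain 'betas' (Adam-style optimizers); otherwise it logs a warning and disables cycle_momentum instead of raising, so the learning-rate cycle still runs. A quick sketch of that condition follows; SGD appears here only for contrast and is not part of this commit.

import torch
from torch.optim import Adam, SGD

param = [torch.nn.Parameter(torch.zeros(1))]

# Adam-style optimizers expose 'betas' in defaults, so cyclic momentum stays enabled.
print('betas' in Adam(param, lr=1e-3).defaults)  # True

# SGD exposes 'momentum' rather than 'betas'; OneCycle would warn and disable cycle_momentum.
print('betas' in SGD(param, lr=1e-3).defaults)   # False

The unit tests added below exercise the OneCycle schedule end to end with fp16 enabled, both without and with ZeRO.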
@@ -142,3 +142,107 @@ def test_adamw_fp16_empty_grad(tmpdir):
model.step()
_test_adamw_fp16_empty_grad(args=args, model=model, hidden_dim=hidden_dim)
def test_adam_fp16_onecycle_compatibility(tmpdir):
config_dict = {
"train_batch_size": 1,
"steps_per_print": 1,
"optimizer": {
"type": "Adam",
"params": {
"lr": 0.00015
}
},
"scheduler": {
"type": "OneCycle",
"params": {
"cycle_first_step_size": 16000,
"cycle_first_stair_count": 8000,
"decay_step_size": 16000,
"cycle_min_lr": 1e-06,
"cycle_max_lr": 3e-05,
"decay_lr_rate": 1e-07,
"cycle_min_mom": 0.85,
"cycle_max_mom": 0.99,
"decay_mom_rate": 0.0
}
},
"fp16": {
"enabled": True
},
"zero_optimization": False
}
args = args_from_dict(tmpdir, config_dict)
hidden_dim = 10
model = SimpleModel(hidden_dim, empty_grad=True)
@distributed_test(world_size=[1])
def _test_adam_fp16_onecycle_compatibility(args, model, hidden_dim):
model, _, _,_ = deepspeed.initialize(args=args,
model=model,
model_parameters=model.parameters())
data_loader = random_dataloader(model=model,
total_samples=50,
hidden_dim=hidden_dim,
device=model.device)
for n, batch in enumerate(data_loader):
loss = model(batch[0], batch[1])
model.backward(loss)
model.step()
_test_adam_fp16_onecycle_compatibility(args=args, model=model, hidden_dim=hidden_dim)
def test_adam_fp16_zero_onecycle_compatibility(tmpdir):
config_dict = {
"train_batch_size": 1,
"steps_per_print": 1,
"optimizer": {
"type": "Adam",
"params": {
"lr": 0.00015
}
},
"scheduler": {
"type": "OneCycle",
"params": {
"cycle_first_step_size": 16000,
"cycle_first_stair_count": 8000,
"decay_step_size": 16000,
"cycle_min_lr": 1e-06,
"cycle_max_lr": 3e-05,
"decay_lr_rate": 1e-07,
"cycle_min_mom": 0.85,
"cycle_max_mom": 0.99,
"decay_mom_rate": 0.0
}
},
"fp16": {
"enabled": True
},
"zero_optimization": True
}
args = args_from_dict(tmpdir, config_dict)
hidden_dim = 10
model = SimpleModel(hidden_dim, empty_grad=True)
@distributed_test(world_size=[1])
def _test_adam_fp16_zero_onecycle_compatibility(args, model, hidden_dim):
model, _, _,_ = deepspeed.initialize(args=args,
model=model,
model_parameters=model.parameters())
data_loader = random_dataloader(model=model,
total_samples=50,
hidden_dim=hidden_dim,
device=model.device)
for n, batch in enumerate(data_loader):
loss = model(batch[0], batch[1])
model.backward(loss)
model.step()
_test_adam_fp16_zero_onecycle_compatibility(args=args,
model=model,
hidden_dim=hidden_dim)