Unverified Commit 1c0b326e authored by Olatunji Ruwase, committed by GitHub

Make lr schedulers support fp16 optimizers (#124)



* Add test cases for OneCycle policy with fp16/ZeRO

* Make lr schedulers support fp16 optimizers

* Fix formatting

* More specific naming
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
parent 27d83851
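Context for the change: DeepSpeed's fp16 optimizer wrappers (used for fp16 and ZeRO training) hold a torch.optim.Optimizer internally rather than subclassing it, so the schedulers' previous isinstance(optimizer, Optimizer) checks rejected them. A minimal sketch of the failure mode, where FP16WrapperSketch is a hypothetical stand-in for such a wrapper:

    import torch
    from torch.optim import Optimizer

    class FP16WrapperSketch:
        # Hypothetical stand-in: holds a real torch optimizer in self.optimizer
        # but is not itself an Optimizer subclass, like DeepSpeed's fp16 wrappers.
        def __init__(self, inner):
            self.optimizer = inner

    net = torch.nn.Linear(4, 2)
    wrapped = FP16WrapperSketch(torch.optim.Adam(net.parameters(), lr=1e-3))

    print(isinstance(wrapped, Optimizer))            # False -> the old checks raised TypeError
    print(isinstance(wrapped.optimizer, Optimizer))  # True  -> what this commit unwraps to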
@@ -13,6 +13,7 @@
 from torch.optim import Optimizer
 from typing import Union, List
 import math
 from deepspeed.pt.deepspeed_constants import *
+import logging

 LR_SCHEDULE = 'lr_schedule'
 LR_RANGE_TEST = 'LRRangeTest'
@@ -277,6 +278,24 @@ def get_lr_from_config(config):
     return lr_params[WARMUP_MAX_LR], ''

+
+"""
+Only optimizers that are subclasses of torch.optim.Optimizer are supported, so check both the
+passed optimizer and any wrapped optimizer to see whether the requirement is satisfied.
+TODO: Looking under the hood to examine the wrapped optimizer is a hack that requires a
+better long-term fix.
+"""
+
+
+def get_torch_optimizer(optimizer):
+    if isinstance(optimizer, Optimizer):
+        return optimizer
+
+    if hasattr(optimizer, 'optimizer') and isinstance(optimizer.optimizer, Optimizer):
+        return optimizer.optimizer
+
+    raise TypeError('{} is not a subclass of torch.optim.Optimizer'.format(
+        type(optimizer).__name__))
+
+
 class LRRangeTest(object):
     """Sets the learning rate of each parameter group according to
     learning rate range test (LRRT) policy. The policy increases learning
@@ -323,20 +342,18 @@ class LRRangeTest(object):
                  lr_range_test_staircase: bool = False,
                  last_batch_iteration: int = -1):

-        if not isinstance(optimizer, Optimizer):
-            raise TypeError('{} is not an Optimizer'.format(type(optimizer).__name__))
-        self.optimizer = optimizer
+        self.optimizer = get_torch_optimizer(optimizer)

         if isinstance(lr_range_test_min_lr,
                       list) or isinstance(lr_range_test_min_lr,
                                           tuple):
-            if len(lr_range_test_min_lr) != len(optimizer.param_groups):
+            if len(lr_range_test_min_lr) != len(self.optimizer.param_groups):
                 raise ValueError("expected {} lr_range_test_min_lr, got {}".format(
-                    len(optimizer.param_groups),
+                    len(self.optimizer.param_groups),
                     len(lr_range_test_min_lr)))
             self.min_lr = list(lr_range_test_min_lr)
         else:
-            self.min_lr = [lr_range_test_min_lr] * len(optimizer.param_groups)
+            self.min_lr = [lr_range_test_min_lr] * len(self.optimizer.param_groups)

         self.step_size = lr_range_test_step_size
         self.step_rate = lr_range_test_step_rate
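With get_torch_optimizer in place, LRRangeTest (and the schedulers below) accept either a bare torch optimizer or a wrapper exposing a .optimizer attribute. A usage sketch, reusing the hypothetical FP16WrapperSketch from above:

    opt = torch.optim.Adam(net.parameters(), lr=1e-3)

    assert get_torch_optimizer(opt) is opt                     # already a torch Optimizer
    assert get_torch_optimizer(FP16WrapperSketch(opt)) is opt  # unwrapped via .optimizer
    try:
        get_torch_optimizer(object())                          # neither case applies
    except TypeError as err:
        print(err)  # "object is not a subclass of torch.optim.Optimizer"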
@@ -463,45 +480,90 @@ class OneCycle(object):
                  decay_mom_rate=0.,
                  last_batch_iteration=-1):

-        if not isinstance(optimizer, Optimizer):
-            raise TypeError('{} is not an Optimizer'.format(type(optimizer).__name__))
-        self.optimizer = optimizer
+        self.optimizer = get_torch_optimizer(optimizer)

-        self.min_lrs = [cycle_min_lr] * len(optimizer.param_groups)
-        if last_batch_iteration == -1:
-            for lr, group in zip(self.min_lrs, optimizer.param_groups):
-                group['lr'] = lr
+        # Initialize cycle shape
+        self._initialize_cycle(cycle_first_step_size,
+                               cycle_second_step_size,
+                               cycle_first_stair_count,
+                               cycle_second_stair_count,
+                               decay_step_size)

-        self.max_lrs = [cycle_max_lr] * len(optimizer.param_groups)
+        # Initialize cycle lr
+        self._initialize_lr(self.optimizer,
+                            cycle_min_lr,
+                            cycle_max_lr,
+                            decay_lr_rate,
+                            last_batch_iteration)
+
+        # Initialize cyclic momentum
+        self.cycle_momentum = cycle_momentum
+        if cycle_momentum:
+            self._initialize_momentum(self.optimizer,
+                                      cycle_min_mom,
+                                      cycle_max_mom,
+                                      decay_mom_rate,
+                                      last_batch_iteration)
+
+        # Initialize batch iteration tracker
+        self.last_batch_iteration = last_batch_iteration
+
+    # Configure cycle shape
+    def _initialize_cycle(self,
+                          cycle_first_step_size,
+                          cycle_second_step_size,
+                          cycle_first_stair_count,
+                          cycle_second_stair_count,
+                          decay_step_size):
         cycle_first_step_size = float(cycle_first_step_size)
         cycle_second_step_size = float(
             cycle_second_step_size
         ) if cycle_second_step_size is not None else cycle_first_step_size

         self.total_size = cycle_first_step_size + cycle_second_step_size
         self.step_ratio = cycle_first_step_size / self.total_size
         self.first_stair_count = cycle_first_stair_count
         self.second_stair_count = cycle_first_stair_count if cycle_second_stair_count is None else cycle_second_stair_count
+        self.decay_step_size = decay_step_size
+
+    # Configure lr schedule
+    def _initialize_lr(self,
+                       optimizer,
+                       cycle_min_lr,
+                       cycle_max_lr,
+                       decay_lr_rate,
+                       last_batch_iteration):
+        self.min_lrs = [cycle_min_lr] * len(optimizer.param_groups)
+        if last_batch_iteration == -1:
+            for lr, group in zip(self.min_lrs, optimizer.param_groups):
+                group['lr'] = lr
+
+        self.max_lrs = [cycle_max_lr] * len(optimizer.param_groups)
         self.decay_lr_rate = decay_lr_rate
-        self.decay_mom_rate = decay_mom_rate
-        self.decay_step_size = decay_step_size
+
+    # Configure momentum schedule
+    def _initialize_momentum(self,
+                             optimizer,
+                             cycle_min_mom,
+                             cycle_max_mom,
+                             decay_mom_rate,
+                             last_batch_iteration):
+        if 'betas' not in optimizer.defaults:
+            optimizer_name = type(optimizer).__name__
+            logging.warn(
+                f"cycle_momentum is disabled because optimizer {optimizer_name} does not support momentum, no betas attribute in defaults"
+            )
+            self.cycle_momentum = False
+            return
+
+        self.decay_mom_rate = decay_mom_rate
         self.min_moms = [(cycle_min_mom, 0.99)] * len(optimizer.param_groups)
         self.max_moms = [(cycle_max_mom, 0.99)] * len(optimizer.param_groups)
-        self.cycle_momentum = cycle_momentum
-        self.last_batch_iteration = last_batch_iteration
-
-        if cycle_momentum:
-            if 'betas' not in optimizer.defaults:
-                raise ValueError(
-                    'optimizer must support betas with `cycle_momentum` option enabled')
-            if last_batch_iteration == -1:
-                for momentum, group in zip(self.min_moms, optimizer.param_groups):
-                    group['betas'] = momentum
+        if last_batch_iteration == -1:
+            for momentum, group in zip(self.min_moms, optimizer.param_groups):
+                group['betas'] = momentum

     def _get_cycle_lr(self):
         cycle = math.floor(1 + self.last_batch_iteration / self.total_size)
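The momentum behavior also changed: instead of raising ValueError, _initialize_momentum now disables cycling with a warning when the unwrapped optimizer has no 'betas' entry in its defaults. A standalone sketch (not code from this commit) showing which stock optimizers pass that check:

    import torch

    net = torch.nn.Linear(4, 2)
    adam = torch.optim.Adam(net.parameters(), lr=1e-3)
    sgd = torch.optim.SGD(net.parameters(), lr=1e-3, momentum=0.9)

    print('betas' in adam.defaults)  # True  -> momentum cycling stays enabled
    print('betas' in sgd.defaults)   # False -> cycle_momentum is turned off with a warning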
@@ -606,10 +668,10 @@ class WarmupLR(object):
                  warmup_num_steps: int = 1000,
                  last_batch_iteration: int = -1):

-        self.optimizer = optimizer
+        self.optimizer = get_torch_optimizer(optimizer)

-        self.min_lrs = self._format_param(optimizer, warmup_min_lr, "min_lr")
-        self.max_lrs = self._format_param(optimizer, warmup_max_lr, "max_lr")
+        self.min_lrs = self._format_param(self.optimizer, warmup_min_lr, "min_lr")
+        self.max_lrs = self._format_param(self.optimizer, warmup_max_lr, "max_lr")
         self.delta_lrs = [big - small for big, small in zip(self.max_lrs, self.min_lrs)]
         self.warmup_num_steps = warmup_num_steps
         self.inverse_log_warm_up = 1.0 / math.log(warmup_num_steps)
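WarmupLR now unwraps the optimizer the same way. The inverse_log_warm_up field initialized above implies a log-scaled warmup; the helper below is a hedged reconstruction of that shape (the actual step logic lives outside this hunk):

    import math

    def warmup_lr_sketch(step, min_lr=0.0, max_lr=1e-3, warmup_num_steps=1000):
        # Assumed shape: scale the lr delta by log(step + 1) / log(warmup_num_steps),
        # capped at 1.0 once warmup completes.
        gamma = min(1.0, math.log(step + 1) / math.log(warmup_num_steps))
        return min_lr + (max_lr - min_lr) * gamma

    print(warmup_lr_sketch(0))     # min_lr at the first step
    print(warmup_lr_sketch(1000))  # max_lr once warmup_num_steps is reached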
@@ -142,3 +142,107 @@ def test_adamw_fp16_empty_grad(tmpdir):
         model.step()

     _test_adamw_fp16_empty_grad(args=args, model=model, hidden_dim=hidden_dim)
+
+
+def test_adam_fp16_onecycle_compatibility(tmpdir):
+    config_dict = {
+        "train_batch_size": 1,
+        "steps_per_print": 1,
+        "optimizer": {
+            "type": "Adam",
+            "params": {
+                "lr": 0.00015
+            }
+        },
+        "scheduler": {
+            "type": "OneCycle",
+            "params": {
+                "cycle_first_step_size": 16000,
+                "cycle_first_stair_count": 8000,
+                "decay_step_size": 16000,
+                "cycle_min_lr": 1e-06,
+                "cycle_max_lr": 3e-05,
+                "decay_lr_rate": 1e-07,
+                "cycle_min_mom": 0.85,
+                "cycle_max_mom": 0.99,
+                "decay_mom_rate": 0.0
+            }
+        },
+        "fp16": {
+            "enabled": True
+        },
+        "zero_optimization": False
+    }
+    args = args_from_dict(tmpdir, config_dict)
+    hidden_dim = 10
+
+    model = SimpleModel(hidden_dim, empty_grad=True)
+
+    @distributed_test(world_size=[1])
+    def _test_adam_fp16_onecycle_compatibility(args, model, hidden_dim):
+        model, _, _, _ = deepspeed.initialize(args=args,
+                                              model=model,
+                                              model_parameters=model.parameters())
+        data_loader = random_dataloader(model=model,
+                                        total_samples=50,
+                                        hidden_dim=hidden_dim,
+                                        device=model.device)
+        for n, batch in enumerate(data_loader):
+            loss = model(batch[0], batch[1])
+            model.backward(loss)
+            model.step()
+
+    _test_adam_fp16_onecycle_compatibility(args=args, model=model, hidden_dim=hidden_dim)
+
+
+def test_adam_fp16_zero_onecycle_compatibility(tmpdir):
+    config_dict = {
+        "train_batch_size": 1,
+        "steps_per_print": 1,
+        "optimizer": {
+            "type": "Adam",
+            "params": {
+                "lr": 0.00015
+            }
+        },
+        "scheduler": {
+            "type": "OneCycle",
+            "params": {
+                "cycle_first_step_size": 16000,
+                "cycle_first_stair_count": 8000,
+                "decay_step_size": 16000,
+                "cycle_min_lr": 1e-06,
+                "cycle_max_lr": 3e-05,
+                "decay_lr_rate": 1e-07,
+                "cycle_min_mom": 0.85,
+                "cycle_max_mom": 0.99,
+                "decay_mom_rate": 0.0
+            }
+        },
+        "fp16": {
+            "enabled": True
+        },
+        "zero_optimization": True
+    }
+    args = args_from_dict(tmpdir, config_dict)
+    hidden_dim = 10
+
+    model = SimpleModel(hidden_dim, empty_grad=True)
+
+    @distributed_test(world_size=[1])
+    def _test_adam_fp16_zero_onecycle_compatibility(args, model, hidden_dim):
+        model, _, _, _ = deepspeed.initialize(args=args,
+                                              model=model,
+                                              model_parameters=model.parameters())
+        data_loader = random_dataloader(model=model,
+                                        total_samples=50,
+                                        hidden_dim=hidden_dim,
+                                        device=model.device)
+        for n, batch in enumerate(data_loader):
+            loss = model(batch[0], batch[1])
+            model.backward(loss)
+            model.step()
+
+    _test_adam_fp16_zero_onecycle_compatibility(args=args,
+                                                model=model,
+                                                hidden_dim=hidden_dim)
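The two tests share one OneCycle config and differ only in zero_optimization, exercising both the plain fp16 wrapper and the ZeRO optimizer. Outside the test harness, the scheduler can also be driven by hand; a hedged sketch (module path inferred from the imports in this diff, and the remaining cycle parameters are assumed to have defaults):

    from deepspeed.pt.deepspeed_lr_schedules import OneCycle

    # 'wrapped' is the hypothetical fp16 wrapper from the first sketch above;
    # thanks to get_torch_optimizer it is now accepted directly instead of
    # raising TypeError.
    sched = OneCycle(wrapped, cycle_min_lr=1e-06, cycle_max_lr=3e-05)
    for _ in range(10):
        # forward/backward/optimizer step would go here
        sched.step()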