Unverified commit da5563a9, authored by Olatunji Ruwase and committed by GitHub

LR scheduler unit tests (#429)



* Add Linear warmup+decay lr schedule
Update lr schedule unit tests

* LR scheduler unit tests for LR Range Test and 1Cycle

* Disable yapf to preserve parameterization

* Disable test_pipe.py for CI debugging

* Disable test_lr_scheduler for CI debugging

* Disable test_lr_scheduler for CI debugging

* Enable all unit tests for CI debugging
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
parent c14b839d
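The headline change adds a linear warmup+decay LR schedule (WARMUP_DECAY_LR) alongside unit tests for the existing LR Range Test and 1Cycle schedules. For orientation, a minimal standalone sketch of a linear warmup+decay curve follows; the function name and the decay-to-zero endpoint are illustrative assumptions, not the exact DeepSpeed implementation.

def linear_warmup_decay_lr(step, min_lr, max_lr, warmup_num_steps, total_num_steps):
    # Illustrative sketch only; parameter names mirror the config keys used in the tests below.
    if step < warmup_num_steps:
        # Warmup phase: interpolate linearly from min_lr up to max_lr.
        return min_lr + (max_lr - min_lr) * step / warmup_num_steps
    # Decay phase (assumed): interpolate linearly from max_lr toward zero at total_num_steps.
    remaining = max(total_num_steps - step, 0)
    return max_lr * remaining / max(total_num_steps - warmup_num_steps, 1)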
@@ -367,10 +367,10 @@ class LRRangeTest(object):
self._update_optimizer(self.min_lr)
def _staircase_interval(self):
return math.floor(float(self.last_batch_iteration) / self.step_size)
return math.floor(float(self.last_batch_iteration + 1) / self.step_size)
def _continous_interval(self):
return float(self.last_batch_iteration) / self.step_size
return float(self.last_batch_iteration + 1) / self.step_size
def _get_increase(self):
return (1 + self.step_rate * self.interval_fn())
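The change above shifts both interval functions by one batch, so the very first step (last_batch_iteration == 0) already counts as one completed batch. A standalone sketch of the staircase behaviour, assuming the LR is min_lr scaled by the increase factor (consistent with the starting-lr check in the unit tests added below):

import math

def staircase_interval(last_batch_iteration, step_size):
    # New behaviour: the "+ 1" counts iteration 0 as one completed batch.
    return math.floor(float(last_batch_iteration + 1) / step_size)

def lr_range_test_lr(min_lr, step_rate, interval):
    # Assumed mapping: interval 0 yields exactly min_lr, matching the test's starting-lr assertion.
    return min_lr * (1 + step_rate * interval)

# With step_size=10, iterations 0..8 stay in interval 0; iteration 9 begins interval 1.
assert staircase_interval(0, 10) == 0
assert staircase_interval(9, 10) == 1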
@@ -574,21 +574,19 @@ class OneCycle(object):
for momentum, group in zip(self.min_moms, optimizer.param_groups):
group['betas'] = momentum
def _get_cycle_lr(self):
cycle = math.floor(1 + self.last_batch_iteration / self.total_size)
x = 1. + self.last_batch_iteration / self.total_size - cycle
def _get_scale_factor(self):
batch_iteration = (self.last_batch_iteration + 1)
cycle = math.floor(1 + batch_iteration / self.total_size)
x = 1. + batch_iteration / self.total_size - cycle
if x <= self.step_ratio:
scale_factor = x / self.step_ratio
else:
scale_factor = (x - 1) / (self.step_ratio - 1)
lrs = []
for cycle_min_lr, cycle_max_lr in zip(self.min_lrs, self.max_lrs):
base_height = (cycle_max_lr - cycle_min_lr) * scale_factor
lr = cycle_min_lr + base_height
lrs.append(lr)
return scale_factor
if self.cycle_momentum:
def _get_cycle_mom(self):
scale_factor = self._get_scale_factor()
momentums = []
for base_betas, max_betas in zip(self.min_moms, self.max_moms):
cycle_min_mom = base_betas[0]
@@ -596,44 +594,53 @@ class OneCycle(object):
base_height = (cycle_max_mom - cycle_min_mom) * scale_factor
momentum = cycle_max_mom - base_height
momentums.append((momentum, base_betas[1]))
for param_group, momentum in zip(self.optimizer.param_groups, momentums):
param_group['betas'] = momentum
return momentums
def _get_cycle_lr(self):
scale_factor = self._get_scale_factor()
lrs = []
for cycle_min_lr, cycle_max_lr in zip(self.min_lrs, self.max_lrs):
base_height = (cycle_max_lr - cycle_min_lr) * scale_factor
lr = cycle_min_lr + base_height
lrs.append(lr)
return lrs
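A standalone sketch of the triangular cycle computed by _get_scale_factor() and _get_cycle_lr(); step_ratio is assumed here to be the fraction of the cycle spent on the rising edge:

import math

def one_cycle_lr(batch_iteration, total_size, step_ratio, cycle_min_lr, cycle_max_lr):
    # Same arithmetic as _get_scale_factor(): x sweeps from 0 to 1 across one cycle.
    cycle = math.floor(1 + batch_iteration / total_size)
    x = 1.0 + batch_iteration / total_size - cycle
    if x <= step_ratio:
        scale_factor = x / step_ratio  # rising edge: 0 -> 1
    else:
        scale_factor = (x - 1) / (step_ratio - 1)  # falling edge: 1 -> 0
    return cycle_min_lr + (cycle_max_lr - cycle_min_lr) * scale_factor

# With total_size=10 and step_ratio=0.5, the LR peaks halfway through the cycle.
assert one_cycle_lr(0, 10, 0.5, 0.0, 0.1) == 0.0
assert one_cycle_lr(5, 10, 0.5, 0.0, 0.1) == 0.1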
def _get_decay_mom(self, decay_batch_iteration):
decay_interval = decay_batch_iteration / self.decay_step_size
mom_decay_factor = (1 + self.decay_mom_rate * decay_interval)
momentums = [(beta0 * mom_decay_factor, beta1) for beta0, beta1 in self.max_moms]
return momentums
def _get_decay_lr(self, decay_batch_iteration):
"""Calculates the learning rate at batch index. This function is used
after the cycle completes and post cycle decaying of lr/mom is enabled.
This function treats `self.last_batch_iteration` as the last batch index.
If `self.cycle_momentum` is ``True``, this function has a side effect of
updating the optimizer's momentum.
"""
decay_interval = decay_batch_iteration / self.decay_step_size
lr_decay_factor = (1 + self.decay_lr_rate * decay_interval)
lrs = [cycle_min_lr * lr_decay_factor for cycle_min_lr in self.min_lrs]
if self.cycle_momentum:
mom_decay_factor = (1 + self.decay_mom_rate * decay_interval)
momentums = [(beta0 * mom_decay_factor,
beta1) for beta0,
beta1 in self.max_moms]
for param_group, momentum in zip(self.optimizer.param_groups, momentums):
param_group['betas'] = momentum
lrs = [cycle_min_lr / lr_decay_factor for cycle_min_lr in self.min_lrs]
return lrs
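In other words, once the cycle completes, _get_decay_lr() now divides the cycle minimum by a linearly growing factor (the old line multiplied by that factor, which grew the LR), and _get_decay_mom() multiplies beta0 of max_moms by an analogous factor. A minimal sketch of the two decay formulas:

def decay_lr(cycle_min_lr, decay_lr_rate, decay_batch_iteration, decay_step_size):
    # LR shrinks: divide by a factor that grows linearly with the post-cycle step count.
    decay_interval = decay_batch_iteration / decay_step_size
    return cycle_min_lr / (1 + decay_lr_rate * decay_interval)

def decay_mom(max_betas, decay_mom_rate, decay_batch_iteration, decay_step_size):
    # Momentum grows: multiply beta0 by a factor that grows linearly with the post-cycle step count.
    decay_interval = decay_batch_iteration / decay_step_size
    beta0, beta1 = max_betas
    return (beta0 * (1 + decay_mom_rate * decay_interval), beta1)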
def get_lr(self):
"""Calculates the learning rate at batch index. This function treats
`self.last_batch_iteration` as the last batch index.
If `self.cycle_momentum` is ``True``, this function has a side effect of
updating the optimizer's momentum.
"""
if self.last_batch_iteration <= self.total_size:
if self.last_batch_iteration < self.total_size:
return self._get_cycle_lr()
return self._get_decay_lr(self.last_batch_iteration - self.total_size)
return self._get_decay_lr(self.last_batch_iteration - self.total_size + 1)
def get_mom(self):
"""Calculates the momentum at batch index. This function treats
`self.last_batch_iteration` as the last batch index.
"""
if not self.cycle_momentum:
return None
if self.last_batch_iteration < self.total_size:
return self._get_cycle_mom()
return self._get_decay_mom(self.last_batch_iteration - self.total_size + 1)
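A tiny sketch of the phase dispatch shared by get_lr() and get_mom(), using a hypothetical total_size of 10:

def select_phase(last_batch_iteration, total_size):
    # Mirrors get_lr()/get_mom(): cycle phase first, then decay with a 1-based decay index.
    if last_batch_iteration < total_size:
        return ('cycle', last_batch_iteration)
    return ('decay', last_batch_iteration - total_size + 1)

assert select_phase(9, 10) == ('cycle', 9)
assert select_phase(10, 10) == ('decay', 1)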
def get_last_lr(self):
""" Return last computed learning rate by current scheduler.
@@ -642,13 +649,24 @@ class OneCycle(object):
return self._last_lr
def step(self, batch_iteration=None):
""" Updates the optimizer with the learning rate for the last batch index.
`self.last_batch_iteration` is treated as the last batch index.
If self.cycle_momentum is true, also updates optimizer momentum.
"""
if batch_iteration is None:
batch_iteration = self.last_batch_iteration + 1
self.last_batch_iteration = batch_iteration
for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
param_group['lr'] = lr
self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
if self.cycle_momentum:
momentums = self.get_mom()
for param_group, momentum in zip(self.optimizer.param_groups, momentums):
param_group['betas'] = momentum
def state_dict(self):
return {'last_batch_iteration': self.last_batch_iteration}
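For reference, a hedged usage sketch of driving OneCycle manually with a plain PyTorch optimizer. The keyword names below are inferred from the schedule config keys exercised in the tests (CYCLE_MIN_LR, CYCLE_MAX_LR, CYCLE_FIRST_STEP_SIZE) and may not match the constructor signature exactly; model and data_loader are assumed to exist.

import torch
from deepspeed.runtime.lr_schedules import OneCycle

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)  # model: an existing torch.nn.Module
scheduler = OneCycle(optimizer,
                     cycle_min_lr=1e-3,          # inferred from CYCLE_MIN_LR
                     cycle_max_lr=1e-2,          # inferred from CYCLE_MAX_LR
                     cycle_first_step_size=100)  # inferred from CYCLE_FIRST_STEP_SIZE
for batch in data_loader:
    loss = model(batch)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    scheduler.step()  # updates param_group['lr'] and, when cycle_momentum is on, 'betas'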
@@ -6,9 +6,28 @@ import json
import os
from common import distributed_test
from simple_model import SimpleModel, SimpleOptimizer, random_dataloader, args_from_dict
from deepspeed.runtime.lr_schedules import LR_RANGE_TEST, ONE_CYCLE, WARMUP_LR, WARMUP_DECAY_LR
from deepspeed.runtime.lr_schedules import WARMUP_MIN_LR, WARMUP_MAX_LR, WARMUP_NUM_STEPS, TOTAL_NUM_STEPS
from deepspeed.runtime.lr_schedules import CYCLE_MIN_LR, CYCLE_MAX_LR
from deepspeed.runtime.lr_schedules import LR_RANGE_TEST, LR_RANGE_TEST_MIN_LR, LR_RANGE_TEST_STEP_RATE, LR_RANGE_TEST_STEP_SIZE, LR_RANGE_TEST_STAIRCASE
from deepspeed.runtime.lr_schedules import WARMUP_LR, WARMUP_MIN_LR, WARMUP_MAX_LR, WARMUP_NUM_STEPS
from deepspeed.runtime.lr_schedules import ONE_CYCLE, CYCLE_MIN_LR, CYCLE_MAX_LR, CYCLE_FIRST_STEP_SIZE, DECAY_LR_RATE, DECAY_STEP_SIZE
from deepspeed.runtime.lr_schedules import CYCLE_MIN_MOM, CYCLE_MAX_MOM, DECAY_MOM_RATE
from deepspeed.runtime.lr_schedules import WARMUP_DECAY_LR, TOTAL_NUM_STEPS
def _verify_continuous_decrease(values):
for i in range(len(values) - 1):
assert values[i] > values[i + 1]
def _verify_continuous_increase(values):
for i in range(len(values) - 1):
assert values[i] < values[i + 1]
def _verify_staircase_increase(values, step_size):
num_values = len(values)
for i in range(0, num_values, step_size):
j = min(i + step_size, num_values)
assert all([values[i] == v for v in values[i:j]])
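Illustrative values (not part of the test suite) that satisfy each helper; note the staircase check only requires equality within each window of step_size values:

def _example_verification_usage():
    # Hypothetical helper, for illustration only.
    _verify_staircase_increase([0.1, 0.1, 0.1, 0.2, 0.2, 0.2, 0.4], step_size=3)
    _verify_continuous_increase([0.1, 0.2, 0.3])
    _verify_continuous_decrease([0.3, 0.2, 0.1])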
@pytest.mark.parametrize("scheduler_type,params",
@@ -22,7 +41,7 @@ from deepspeed.runtime.lr_schedules import CYCLE_MIN_LR, CYCLE_MAX_LR
(ONE_CYCLE,
{
CYCLE_MIN_LR: 0,
CYCLE_MAX_LR: 0
CYCLE_MAX_LR: 0.1
}),
(LR_RANGE_TEST,
{})])
@@ -205,3 +224,304 @@ def test_lr_warmup_decay_schedule(tmpdir, warmup_num_steps):
hidden_dim=hidden_dim,
schedule_params=schedule_params,
num_steps=total_num_steps)
@pytest.mark.parametrize("scheduler_type,params",
[(WARMUP_LR,
{}),
(WARMUP_DECAY_LR,
{
WARMUP_NUM_STEPS: 5,
TOTAL_NUM_STEPS: 10
}),
(ONE_CYCLE,
{
CYCLE_MIN_LR: 0,
CYCLE_MAX_LR: 0.1,
CYCLE_FIRST_STEP_SIZE: 5,
DECAY_STEP_SIZE: 5
}),
(LR_RANGE_TEST,
{
LR_RANGE_TEST_MIN_LR: 1e-4,
LR_RANGE_TEST_STEP_SIZE: 1
})])
def test_scheduler_optimizer_parity(tmpdir, scheduler_type, params):
config_dict = {
"train_batch_size": 2,
"steps_per_print": 1,
"optimizer": {
"type": "Adam",
"params": {
"lr": 0.00015
},
},
"scheduler": {
"type": scheduler_type,
"params": params
},
"gradient_clipping": 1.0
}
args = args_from_dict(tmpdir, config_dict)
hidden_dim = 10
model = SimpleModel(hidden_dim, empty_grad=False)
@distributed_test(world_size=[1])
def _test_scheduler_optimizer_parity(args, model, hidden_dim):
model, _, _, lr_scheduler = deepspeed.initialize(args=args,
model=model,
model_parameters=model.parameters())
data_loader = random_dataloader(model=model,
total_samples=50,
hidden_dim=hidden_dim,
device=model.device,
dtype=torch.float)
for n, batch in enumerate(data_loader):
loss = model(batch[0], batch[1])
model.backward(loss)
model.step()
assert lr_scheduler.get_lr() == model.get_lr()
_test_scheduler_optimizer_parity(args=args, model=model, hidden_dim=hidden_dim)
@pytest.mark.parametrize("min_lr, step_rate, step_size, staircase",
[(1e-4, 1e-5, 1, True),
(1e-5, 1e-5, 1, False),
(1e-4, 1e-3, 10, True),
(1e-3, 1e-3, 10, False),
(1e-2, 1e-2, 19, True),
(1e-2, 1e-2, 19, False)
])# yapf: disable
def test_lr_range_test(tmpdir, min_lr, step_rate, step_size, staircase):
config_dict = {
"train_batch_size": 2,
"steps_per_print": 1,
"optimizer": {
"type": "Adam",
"params": {
"lr": 0.00015
},
},
"scheduler": {
"type": LR_RANGE_TEST,
"params": {
LR_RANGE_TEST_MIN_LR: min_lr,
LR_RANGE_TEST_STEP_RATE: step_rate,
LR_RANGE_TEST_STEP_SIZE: step_size,
LR_RANGE_TEST_STAIRCASE: staircase
}
},
"gradient_clipping": 1.0
}
args = args_from_dict(tmpdir, config_dict)
hidden_dim = 10
model = SimpleModel(hidden_dim, empty_grad=False)
@distributed_test(world_size=[1])
def _test_lr_range_test(args, model, hidden_dim, min_lr, step_size, staircase):
model, _, _, lr_scheduler = deepspeed.initialize(args=args,
model=model,
model_parameters=model.parameters())
data_loader = random_dataloader(model=model,
total_samples=max(50,
step_size * 2),
hidden_dim=hidden_dim,
device=model.device,
dtype=torch.float)
step_lrs = []
for _, batch in enumerate(data_loader):
step_lrs.append(lr_scheduler.get_lr())
loss = model(batch[0], batch[1])
model.backward(loss)
model.step()
# Verify starting lr
assert step_lrs[0] == min_lr
if staircase:
# Verify staircase increasing lr
_verify_staircase_increase(step_lrs, step_size)
else:
# Verify continuous increasing lr
_verify_continuous_increase(step_lrs)
_test_lr_range_test(args=args,
model=model,
hidden_dim=hidden_dim,
min_lr=[min_lr],
step_size=step_size,
staircase=staircase)
@pytest.mark.parametrize("min_lr, max_lr, decay_rate, step_size",
[
(1e-5, 1e-2, 1e-3, 10),
(1e-3, 1e-1, 0, 21),
(1e-5, 1e-2, 1e-3, 10),
(1e-3, 1e-1, 0, 21),
]) # yapf: disable
def test_onecycle_lr(tmpdir, min_lr, max_lr, decay_rate, step_size):
config_dict = {
"train_batch_size": 2,
"steps_per_print": 1,
"optimizer": {
"type": "Adam",
"params": {
"lr": 0.00015
},
},
"scheduler": {
"type": ONE_CYCLE,
"params": {
CYCLE_MIN_LR: min_lr,
CYCLE_MAX_LR: max_lr,
DECAY_LR_RATE: decay_rate,
CYCLE_FIRST_STEP_SIZE: step_size,
DECAY_STEP_SIZE: step_size
}
},
"gradient_clipping": 1.0
}
args = args_from_dict(tmpdir, config_dict)
hidden_dim = 10
model = SimpleModel(hidden_dim, empty_grad=False)
@distributed_test(world_size=[1])
def _test_onecycle_lr(args,
model,
hidden_dim,
min_lr,
max_lr,
step_size,
decay_rate):
model, _, _, lr_scheduler = deepspeed.initialize(args=args,
model=model,
model_parameters=model.parameters())
data_loader = random_dataloader(model=model,
total_samples=max(50,
step_size * 3),
hidden_dim=hidden_dim,
device=model.device,
dtype=torch.float)
step_lrs = []
for _, batch in enumerate(data_loader):
step_lrs.append(lr_scheduler.get_lr())
loss = model(batch[0], batch[1])
model.backward(loss)
model.step()
# Verify starting lr
assert step_lrs[0] == min_lr
# Verify peak lr
assert step_lrs[step_size] == max_lr
# Verify increasing phase
_verify_continuous_increase(step_lrs[:step_size])
# Verify decreasing phase
_verify_continuous_decrease(step_lrs[step_size:(step_size * 2)])
# Verify decay phase
if decay_rate > 0:
_verify_continuous_decrease(step_lrs[(step_size * 2):])
_test_onecycle_lr(args=args,
model=model,
hidden_dim=hidden_dim,
min_lr=[min_lr],
max_lr=[max_lr],
step_size=step_size,
decay_rate=decay_rate)
@pytest.mark.parametrize("min_mom, max_mom, decay_rate, step_size",
[
(0.08, 0.09, 1e-3, 10),
(0.08, 0.09, 0, 21),
(0.08, 0.09, 1e-3, 10),
(0.08, 0.09, 0, 21),
]) # yapf: disable
def test_onecycle_mom(tmpdir, min_mom, max_mom, decay_rate, step_size):
config_dict = {
"train_batch_size": 2,
"steps_per_print": 1,
"optimizer": {
"type": "Adam",
"params": {
"lr": 0.00015
},
},
"scheduler": {
"type": ONE_CYCLE,
"params": {
CYCLE_MIN_LR: 1e-3,
CYCLE_MAX_LR: 1e-2,
CYCLE_MIN_MOM: min_mom,
CYCLE_MAX_MOM: max_mom,
DECAY_MOM_RATE: decay_rate,
CYCLE_FIRST_STEP_SIZE: step_size,
DECAY_STEP_SIZE: step_size
}
},
"gradient_clipping": 1.0
}
args = args_from_dict(tmpdir, config_dict)
hidden_dim = 10
model = SimpleModel(hidden_dim, empty_grad=False)
@distributed_test(world_size=[1])
def _test_onecycle_mom(args,
model,
hidden_dim,
min_mom,
max_mom,
step_size,
decay_rate):
model, _, _, lr_scheduler = deepspeed.initialize(args=args,
model=model,
model_parameters=model.parameters())
data_loader = random_dataloader(model=model,
total_samples=max(50,
step_size * 3),
hidden_dim=hidden_dim,
device=model.device,
dtype=torch.float)
step_moms = []
for _, batch in enumerate(data_loader):
step_moms.append(lr_scheduler.get_mom())
loss = model(batch[0], batch[1])
model.backward(loss)
model.step()
# Verify starting momentum
assert step_moms[0][0][0] == max_mom
# Verify minimum momentum at the cycle peak
assert step_moms[step_size][0][0] == min_mom
# Verify decreasing phase
_verify_continuous_decrease(step_moms[:step_size])
# Verify increasing phase
_verify_continuous_increase(step_moms[step_size:(step_size * 2)])
# Verify decay phase
if decay_rate > 0:
_verify_continuous_increase(step_moms[(step_size * 2):])
_test_onecycle_mom(args=args,
model=model,
hidden_dim=hidden_dim,
min_mom=min_mom,
max_mom=max_mom,
step_size=step_size,
decay_rate=decay_rate)
@@ -10,6 +10,7 @@ import pytest
import deepspeed
import deepspeed.runtime.utils as ds_utils
from deepspeed.runtime.pipe.topology import PipeDataParallelTopology, PipeModelDataParallelTopology
PipeTopo = PipeDataParallelTopology
import deepspeed.runtime.pipe.module as PipelineModule