Unverified commit da5563a9, authored by Olatunji Ruwase, committed by GitHub

LR scheduler unit tests (#429)



* Add Linear warmup+decay lr schedule
Update lr schedule unit tests

* LR scheduler unit tests for LR Range Test and 1Cycle

* Disable yapf to preserve parameterization

* Disable test_pipe.py for CI debugging

* Disable test_lr_scheduler for CI debugging

* Disable test_lr_scheduler for CI debugging

* Enable all unit tests for CI debugging
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
parent c14b839d
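The first bullet refers to the linear warmup + decay schedule exercised by the new WARMUP_DECAY_LR tests. Below is a minimal standalone sketch of that shape; it is an illustration only, not DeepSpeed's WarmupDecayLR implementation, and the parameter names simply mirror the config keys used in the tests.

```python
# Sketch of a linear warmup + linear decay lr schedule (illustration only,
# not DeepSpeed's WarmupDecayLR; names mirror the test config keys).
def warmup_decay_lr(step, warmup_min_lr, warmup_max_lr, warmup_num_steps, total_num_steps):
    if step < warmup_num_steps:
        # Linear ramp from warmup_min_lr up to warmup_max_lr.
        frac = step / max(1, warmup_num_steps)
        return warmup_min_lr + (warmup_max_lr - warmup_min_lr) * frac
    # Linear decay from warmup_max_lr back toward zero over the remaining steps.
    remaining = max(0, total_num_steps - step)
    decay_span = max(1, total_num_steps - warmup_num_steps)
    return warmup_max_lr * remaining / decay_span

# Example with 5 warmup steps and 10 total steps, matching the WARMUP_DECAY_LR test parameters.
lrs = [warmup_decay_lr(s, 0.0, 0.1, 5, 10) for s in range(10)]
```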
@@ -367,10 +367,10 @@ class LRRangeTest(object):
             self._update_optimizer(self.min_lr)
 
     def _staircase_interval(self):
-        return math.floor(float(self.last_batch_iteration) / self.step_size)
+        return math.floor(float(self.last_batch_iteration + 1) / self.step_size)
 
     def _continous_interval(self):
-        return float(self.last_batch_iteration) / self.step_size
+        return float(self.last_batch_iteration + 1) / self.step_size
 
     def _get_increase(self):
         return (1 + self.step_rate * self.interval_fn())
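The `+ 1` shift above makes the interval 0 on the very first step, so the schedule starts exactly at the configured minimum, which is what the new `test_lr_range_test` asserts. A standalone numeric sketch of the rule (the final `min_lr * increase` combination is an assumption consistent with that starting-lr assertion, not a line taken from the diff):

```python
import math

# Minimal sketch of the LR Range Test increase rule shown in the hunk above
# (illustration only; the real scheduler lives in deepspeed.runtime.lr_schedules).
def lr_range_test_lr(last_batch_iteration, min_lr, step_rate, step_size, staircase):
    if staircase:
        interval = math.floor(float(last_batch_iteration + 1) / step_size)
    else:
        interval = float(last_batch_iteration + 1) / step_size
    increase = 1 + step_rate * interval
    return min_lr * increase  # assumed combination; matches the starting-lr check in the tests

# With the "+ 1" shift, iteration -1 (before the first optimizer step) maps to interval 0,
# so the first observed lr equals min_lr.
assert lr_range_test_lr(-1, 1e-4, 1e-5, 1, staircase=True) == 1e-4
```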
@@ -574,21 +574,19 @@ class OneCycle(object):
             for momentum, group in zip(self.min_moms, optimizer.param_groups):
                 group['betas'] = momentum
 
-    def _get_cycle_lr(self):
-        cycle = math.floor(1 + self.last_batch_iteration / self.total_size)
-        x = 1. + self.last_batch_iteration / self.total_size - cycle
+    def _get_scale_factor(self):
+        batch_iteration = (self.last_batch_iteration + 1)
+        cycle = math.floor(1 + batch_iteration / self.total_size)
+        x = 1. + batch_iteration / self.total_size - cycle
         if x <= self.step_ratio:
             scale_factor = x / self.step_ratio
         else:
             scale_factor = (x - 1) / (self.step_ratio - 1)
 
-        lrs = []
-        for cycle_min_lr, cycle_max_lr in zip(self.min_lrs, self.max_lrs):
-            base_height = (cycle_max_lr - cycle_min_lr) * scale_factor
-            lr = cycle_min_lr + base_height
-            lrs.append(lr)
-
-        if self.cycle_momentum:
-            momentums = []
-            for base_betas, max_betas in zip(self.min_moms, self.max_moms):
-                cycle_min_mom = base_betas[0]
+        return scale_factor
+
+    def _get_cycle_mom(self):
+        scale_factor = self._get_scale_factor()
+        momentums = []
+        for base_betas, max_betas in zip(self.min_moms, self.max_moms):
+            cycle_min_mom = base_betas[0]
@@ -596,44 +594,53 @@ class OneCycle(object):
             base_height = (cycle_max_mom - cycle_min_mom) * scale_factor
             momentum = cycle_max_mom - base_height
             momentums.append((momentum, base_betas[1]))
-            for param_group, momentum in zip(self.optimizer.param_groups, momentums):
-                param_group['betas'] = momentum
+        return momentums
+
+    def _get_cycle_lr(self):
+        scale_factor = self._get_scale_factor()
+        lrs = []
+        for cycle_min_lr, cycle_max_lr in zip(self.min_lrs, self.max_lrs):
+            base_height = (cycle_max_lr - cycle_min_lr) * scale_factor
+            lr = cycle_min_lr + base_height
+            lrs.append(lr)
 
         return lrs
 
+    def _get_decay_mom(self, decay_batch_iteration):
+        decay_interval = decay_batch_iteration / self.decay_step_size
+        mom_decay_factor = (1 + self.decay_mom_rate * decay_interval)
+        momentums = [(beta0 * mom_decay_factor, beta1) for beta0, beta1 in self.max_moms]
+        return momentums
+
     def _get_decay_lr(self, decay_batch_iteration):
         """Calculates the learning rate at batch index. This function is used
         after the cycle completes and post cycle decaying of lr/mom is enabled.
         This function treats `self.last_batch_iteration` as the last batch index.
-
-        If `self.cycle_momentum` is ``True``, this function has a side effect of
-        updating the optimizer's momentum.
         """
         decay_interval = decay_batch_iteration / self.decay_step_size
         lr_decay_factor = (1 + self.decay_lr_rate * decay_interval)
-        lrs = [cycle_min_lr * lr_decay_factor for cycle_min_lr in self.min_lrs]
-
-        if self.cycle_momentum:
-            mom_decay_factor = (1 + self.decay_mom_rate * decay_interval)
-            momentums = [(beta0 * mom_decay_factor,
-                          beta1) for beta0,
-                         beta1 in self.max_moms]
-            for param_group, momentum in zip(self.optimizer.param_groups, momentums):
-                param_group['betas'] = momentum
-
+        lrs = [cycle_min_lr / lr_decay_factor for cycle_min_lr in self.min_lrs]
         return lrs
 
     def get_lr(self):
         """Calculates the learning rate at batch index. This function treats
         `self.last_batch_iteration` as the last batch index.
-
-        If `self.cycle_momentum` is ``True``, this function has a side effect of
-        updating the optimizer's momentum.
         """
-        if self.last_batch_iteration <= self.total_size:
+        if self.last_batch_iteration < self.total_size:
             return self._get_cycle_lr()
-        return self._get_decay_lr(self.last_batch_iteration - self.total_size)
+        return self._get_decay_lr(self.last_batch_iteration - self.total_size + 1)
+
+    def get_mom(self):
+        """Calculates the momentum at batch index. This function treats
+        `self.last_batch_iteration` as the last batch index.
+        """
+        if not self.cycle_momentum:
+            return None
+
+        if self.last_batch_iteration < self.total_size:
+            return self._get_cycle_mom()
+        return self._get_decay_mom(self.last_batch_iteration - self.total_size + 1)
 
     def get_last_lr(self):
         """ Return last computed learning rate by current scheduler.
@@ -642,13 +649,24 @@ class OneCycle(object):
         return self._last_lr
 
     def step(self, batch_iteration=None):
+        """ Updates the optimizer with the learning rate for the last batch index.
+        `self.last_batch_iteration` is treated as the last batch index.
+        If self.cycle_momentum is true, also updates optimizer momentum.
+        """
         if batch_iteration is None:
             batch_iteration = self.last_batch_iteration + 1
 
         self.last_batch_iteration = batch_iteration
         for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
             param_group['lr'] = lr
         self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
 
+        if self.cycle_momentum:
+            momentums = self.get_mom()
+            for param_group, momentum in zip(self.optimizer.param_groups, momentums):
+                param_group['betas'] = momentum
+
     def state_dict(self):
         return {'last_batch_iteration': self.last_batch_iteration}
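The refactor above pulls the triangular scale factor out of `_get_cycle_lr` so that `_get_cycle_mom` can reuse it, and shifts the batch index by one so the cycle starts at the configured minimum lr. A standalone numeric sketch of that factor, mirroring the formulas in the diff; the `total_size = 2 * first_step_size` and `step_ratio = 0.5` values are assumptions chosen to match the test configurations below, not values from the diff:

```python
import math

# Standalone sketch of the triangular scale factor extracted into _get_scale_factor.
def scale_factor(last_batch_iteration, total_size, step_ratio):
    batch_iteration = last_batch_iteration + 1
    cycle = math.floor(1 + batch_iteration / total_size)
    x = 1. + batch_iteration / total_size - cycle
    if x <= step_ratio:
        return x / step_ratio          # ramp up during the first part of the cycle
    return (x - 1) / (step_ratio - 1)  # ramp back down for the rest of the cycle

# Assuming CYCLE_FIRST_STEP_SIZE = 5 with a mirrored second phase (total_size = 10,
# step_ratio = 0.5), the factor rises to 1.0 by the fifth step and falls back to 0
# over the next five. The lr is then cycle_min_lr + (cycle_max_lr - cycle_min_lr) * factor,
# while momentum moves inversely: cycle_max_mom - (cycle_max_mom - cycle_min_mom) * factor.
factors = [scale_factor(i, total_size=10, step_ratio=0.5) for i in range(-1, 10)]
```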
@@ -6,9 +6,28 @@ import json
 import os
 from common import distributed_test
 from simple_model import SimpleModel, SimpleOptimizer, random_dataloader, args_from_dict
-from deepspeed.runtime.lr_schedules import LR_RANGE_TEST, ONE_CYCLE, WARMUP_LR, WARMUP_DECAY_LR
-from deepspeed.runtime.lr_schedules import WARMUP_MIN_LR, WARMUP_MAX_LR, WARMUP_NUM_STEPS, TOTAL_NUM_STEPS
-from deepspeed.runtime.lr_schedules import CYCLE_MIN_LR, CYCLE_MAX_LR
+from deepspeed.runtime.lr_schedules import LR_RANGE_TEST, LR_RANGE_TEST_MIN_LR, LR_RANGE_TEST_STEP_RATE, LR_RANGE_TEST_STEP_SIZE, LR_RANGE_TEST_STAIRCASE
+from deepspeed.runtime.lr_schedules import WARMUP_LR, WARMUP_MIN_LR, WARMUP_MAX_LR, WARMUP_NUM_STEPS
+from deepspeed.runtime.lr_schedules import ONE_CYCLE, CYCLE_MIN_LR, CYCLE_MAX_LR, CYCLE_FIRST_STEP_SIZE, DECAY_LR_RATE, DECAY_STEP_SIZE
+from deepspeed.runtime.lr_schedules import CYCLE_MIN_MOM, CYCLE_MAX_MOM, DECAY_MOM_RATE
+from deepspeed.runtime.lr_schedules import WARMUP_DECAY_LR, TOTAL_NUM_STEPS
+
+
+def _verify_continuous_decrease(values):
+    for i in range(len(values) - 1):
+        assert values[i] > values[i + 1]
+
+
+def _verify_continuous_increase(values):
+    for i in range(len(values) - 1):
+        assert values[i] < values[i + 1]
+
+
+def _verify_staircase_increase(values, step_size):
+    num_values = len(values)
+    for i in range(0, num_values, step_size):
+        j = min(i + step_size, num_values)
+        assert all([values[i] == v for v in values[i:j]])
 
 
 @pytest.mark.parametrize("scheduler_type,params",
@@ -22,7 +41,7 @@ from deepspeed.runtime.lr_schedules import CYCLE_MIN_LR, CYCLE_MAX_LR
                          (ONE_CYCLE,
                           {
                               CYCLE_MIN_LR: 0,
-                              CYCLE_MAX_LR: 0
+                              CYCLE_MAX_LR: 0.1
                           }),
                          (LR_RANGE_TEST,
                           {})])
@@ -205,3 +224,304 @@ def test_lr_warmup_decay_schedule(tmpdir, warmup_num_steps):
                                          hidden_dim=hidden_dim,
                                          schedule_params=schedule_params,
                                          num_steps=total_num_steps)
@pytest.mark.parametrize("scheduler_type,params",
[(WARMUP_LR,
{}),
(WARMUP_DECAY_LR,
{
WARMUP_NUM_STEPS: 5,
TOTAL_NUM_STEPS: 10
}),
(ONE_CYCLE,
{
CYCLE_MIN_LR: 0,
CYCLE_MAX_LR: 0.1,
CYCLE_FIRST_STEP_SIZE: 5,
DECAY_STEP_SIZE: 5
}),
(LR_RANGE_TEST,
{
LR_RANGE_TEST_MIN_LR: 1e-4,
LR_RANGE_TEST_STEP_SIZE: 1
})])
def test_scheduler_optimizer_parity(tmpdir, scheduler_type, params):
config_dict = {
"train_batch_size": 2,
"steps_per_print": 1,
"optimizer": {
"type": "Adam",
"params": {
"lr": 0.00015
},
},
"scheduler": {
"type": scheduler_type,
"params": params
},
"gradient_clipping": 1.0
}
args = args_from_dict(tmpdir, config_dict)
hidden_dim = 10
model = SimpleModel(hidden_dim, empty_grad=False)
@distributed_test(world_size=[1])
def _test_scheduler_optimizer_parity(args, model, hidden_dim):
model, _, _, lr_scheduler = deepspeed.initialize(args=args,
model=model,
model_parameters=model.parameters())
data_loader = random_dataloader(model=model,
total_samples=50,
hidden_dim=hidden_dim,
device=model.device,
dtype=torch.float)
for n, batch in enumerate(data_loader):
loss = model(batch[0], batch[1])
model.backward(loss)
model.step()
assert lr_scheduler.get_lr() == model.get_lr()
_test_scheduler_optimizer_parity(args=args, model=model, hidden_dim=hidden_dim)
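The parity test checks that the scheduler's view of the learning rate matches what the engine actually applied each step. A minimal sketch of the same property written directly against the wrapped optimizer; it assumes (not shown in this diff) that the DeepSpeed engine exposes its optimizer as `engine.optimizer` and that each param group carries an `'lr'` key:

```python
# Sketch of the parity property the test asserts (assumptions noted above).
def lrs_match(engine, scheduler):
    optimizer_lrs = [group['lr'] for group in engine.optimizer.param_groups]
    return scheduler.get_lr() == optimizer_lrs
```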
@pytest.mark.parametrize("min_lr, step_rate, step_size, staircase",
[(1e-4, 1e-5, 1, True),
(1e-5, 1e-5, 1, False),
(1e-4, 1e-3, 10, True),
(1e-3, 1e-3, 10, False),
(1e-2, 1e-2, 19, True),
(1e-2, 1e-2, 19, False)
])# yapf: disable
def test_lr_range_test(tmpdir, min_lr, step_rate, step_size, staircase):
config_dict = {
"train_batch_size": 2,
"steps_per_print": 1,
"optimizer": {
"type": "Adam",
"params": {
"lr": 0.00015
},
},
"scheduler": {
"type": LR_RANGE_TEST,
"params": {
LR_RANGE_TEST_MIN_LR: min_lr,
LR_RANGE_TEST_STEP_RATE: step_rate,
LR_RANGE_TEST_STEP_SIZE: step_size,
LR_RANGE_TEST_STAIRCASE: staircase
}
},
"gradient_clipping": 1.0
}
args = args_from_dict(tmpdir, config_dict)
hidden_dim = 10
model = SimpleModel(hidden_dim, empty_grad=False)
@distributed_test(world_size=[1])
def _test_lr_range_test(args, model, hidden_dim, min_lr, step_size, staircase):
model, _, _, lr_scheduler = deepspeed.initialize(args=args,
model=model,
model_parameters=model.parameters())
data_loader = random_dataloader(model=model,
total_samples=max(50,
step_size * 2),
hidden_dim=hidden_dim,
device=model.device,
dtype=torch.float)
step_lrs = []
for _, batch in enumerate(data_loader):
step_lrs.append(lr_scheduler.get_lr())
loss = model(batch[0], batch[1])
model.backward(loss)
model.step()
# Verify starting lr
assert step_lrs[0] == min_lr
if staircase:
# Verify staircase increasing lr
_verify_staircase_increase(step_lrs, step_size)
else:
# Verify continuous increasing lr
_verify_continuous_increase(step_lrs)
_test_lr_range_test(args=args,
model=model,
hidden_dim=hidden_dim,
min_lr=[min_lr],
step_size=step_size,
staircase=staircase)
@pytest.mark.parametrize("min_lr, max_lr, decay_rate, step_size",
[
(1e-5, 1e-2, 1e-3, 10),
(1e-3, 1e-1, 0, 21),
(1e-5, 1e-2, 1e-3, 10),
(1e-3, 1e-1, 0, 21),
]) # yapf: disable
def test_onecycle_lr(tmpdir, min_lr, max_lr, decay_rate, step_size):
config_dict = {
"train_batch_size": 2,
"steps_per_print": 1,
"optimizer": {
"type": "Adam",
"params": {
"lr": 0.00015
},
},
"scheduler": {
"type": ONE_CYCLE,
"params": {
CYCLE_MIN_LR: min_lr,
CYCLE_MAX_LR: max_lr,
DECAY_LR_RATE: decay_rate,
CYCLE_FIRST_STEP_SIZE: step_size,
DECAY_STEP_SIZE: step_size
}
},
"gradient_clipping": 1.0
}
args = args_from_dict(tmpdir, config_dict)
hidden_dim = 10
model = SimpleModel(hidden_dim, empty_grad=False)
@distributed_test(world_size=[1])
def _test_onecycle_lr(args,
model,
hidden_dim,
min_lr,
max_lr,
step_size,
decay_rate):
model, _, _, lr_scheduler = deepspeed.initialize(args=args,
model=model,
model_parameters=model.parameters())
data_loader = random_dataloader(model=model,
total_samples=max(50,
step_size * 3),
hidden_dim=hidden_dim,
device=model.device,
dtype=torch.float)
step_lrs = []
for _, batch in enumerate(data_loader):
step_lrs.append(lr_scheduler.get_lr())
loss = model(batch[0], batch[1])
model.backward(loss)
model.step()
# Verify starting lr
assert step_lrs[0] == min_lr
# Verify peak lr
assert step_lrs[step_size] == max_lr
# Verify increasing phase
_verify_continuous_increase(step_lrs[:step_size])
# Verify decreasing phase
_verify_continuous_decrease(step_lrs[step_size:(step_size * 2)])
# Verify decay phase
if decay_rate > 0:
_verify_continuous_decrease(step_lrs[(step_size * 2):])
_test_onecycle_lr(args=args,
model=model,
hidden_dim=hidden_dim,
min_lr=[min_lr],
max_lr=[max_lr],
step_size=step_size,
decay_rate=decay_rate)
@pytest.mark.parametrize("min_mom, max_mom, decay_rate, step_size",
[
(0.08, 0.09, 1e-3, 10),
(0.08, 0.09, 0, 21),
(0.08, 0.09, 1e-3, 10),
(0.08, 0.09, 0, 21),
]) # yapf: disable
def test_onecycle_mom(tmpdir, min_mom, max_mom, decay_rate, step_size):
config_dict = {
"train_batch_size": 2,
"steps_per_print": 1,
"optimizer": {
"type": "Adam",
"params": {
"lr": 0.00015
},
},
"scheduler": {
"type": ONE_CYCLE,
"params": {
CYCLE_MIN_LR: 1e-3,
CYCLE_MAX_LR: 1e-2,
CYCLE_MIN_MOM: min_mom,
CYCLE_MAX_MOM: max_mom,
DECAY_MOM_RATE: decay_rate,
CYCLE_FIRST_STEP_SIZE: step_size,
DECAY_STEP_SIZE: step_size
}
},
"gradient_clipping": 1.0
}
args = args_from_dict(tmpdir, config_dict)
hidden_dim = 10
model = SimpleModel(hidden_dim, empty_grad=False)
@distributed_test(world_size=[1])
def _test_onecycle_mom(args,
model,
hidden_dim,
min_mom,
max_mom,
step_size,
decay_rate):
model, _, _, lr_scheduler = deepspeed.initialize(args=args,
model=model,
model_parameters=model.parameters())
data_loader = random_dataloader(model=model,
total_samples=max(50,
step_size * 3),
hidden_dim=hidden_dim,
device=model.device,
dtype=torch.float)
step_moms = []
for _, batch in enumerate(data_loader):
step_moms.append(lr_scheduler.get_mom())
loss = model(batch[0], batch[1])
model.backward(loss)
model.step()
# Verify starting lr
assert step_moms[0][0][0] == max_mom
# Verify peak lr
assert step_moms[step_size][0][0] == min_mom
# Verify decreasing phase
_verify_continuous_decrease(step_moms[:step_size])
# Verify increasing phase
_verify_continuous_increase(step_moms[step_size:(step_size * 2)])
# Verify decay phase
if decay_rate > 0:
_verify_continuous_increase(step_moms[(step_size * 2):])
_test_onecycle_mom(args=args,
model=model,
hidden_dim=hidden_dim,
min_mom=min_mom,
max_mom=max_mom,
step_size=step_size,
decay_rate=decay_rate)
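For reference, `get_mom()` returns one betas tuple per optimizer param group, which is why the assertions above index `step_moms[step][param_group][beta0]`. A tiny sketch of that shape with hypothetical values (the second beta of 0.999 is illustrative only):

```python
# Shape of the values recorded in step_moms: one (beta0, beta1) tuple per optimizer
# param group, per training step (hypothetical numbers for illustration).
step_moms = [
    [(0.09, 0.999)],   # step 0: single param group, momentum starts at CYCLE_MAX_MOM
    [(0.085, 0.999)],
    [(0.08, 0.999)],   # momentum bottoms out at CYCLE_MIN_MOM when the lr peaks
]
assert step_moms[0][0][0] == 0.09  # [step][param_group][beta0]
```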
@@ -10,6 +10,7 @@ import pytest
 import deepspeed
 import deepspeed.runtime.utils as ds_utils
 from deepspeed.runtime.pipe.topology import PipeDataParallelTopology, PipeModelDataParallelTopology
 PipeTopo = PipeDataParallelTopology
 import deepspeed.runtime.pipe.module as PipelineModule