Commit b3677ae2 authored by Allen Wang, committed by A. Unique TensorFlower

Remove piecewise decay with warmup and recompose it as stepwise decay + WarmupDecaySchedule.

PiperOrigin-RevId: 325093611
parent 0edeca54
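
In short, this change replaces the hand-rolled PiecewiseConstantDecayWithWarmup schedule with the composition of a stock Keras PiecewiseConstantDecay and the (extended) WarmupDecaySchedule wrapper. A minimal sketch of the recomposition, assuming this revision's module layout under official/vision/image_classification and the new ResNet defaults from the last hunk below:

import tensorflow as tf

from official.vision.image_classification import learning_rate

# New ResNet defaults (see the resnet_config.py hunk at the end).
batch_size = 256
steps_per_epoch = 1281167 // batch_size      # ImageNet examples per epoch
base_lr = 0.1 * batch_size * (1. / 256.)     # initial_lr scaled by batch size

# The 'stepwise' branch of build_learning_rate: epoch boundaries become step
# boundaries, and the per-256 multipliers are rescaled by the batch size.
stepwise = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
    boundaries=[epoch * steps_per_epoch for epoch in [30, 60, 80]],
    values=[batch_size * m
            for m in [0.1 / 256, 0.01 / 256, 0.001 / 256, 0.0001 / 256]])

# Warmup is layered on top; warmup_lr supplies the ramp target that
# PiecewiseConstantDecay does not expose as `initial_learning_rate`.
lr = learning_rate.WarmupDecaySchedule(
    stepwise, warmup_steps=5 * steps_per_epoch, warmup_lr=base_lr)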
@@ -40,8 +40,6 @@ model:
     momentum: 0.9
     decay: 0.9
     epsilon: 0.001
-  learning_rate:
-    name: 'piecewise_constant_with_warmup'
   loss:
     label_smoothing: 0.1
 train:
......
@@ -43,8 +43,6 @@ model:
     epsilon: 0.001
     moving_average_decay: 0.
     lookahead: False
-  learning_rate:
-    name: 'piecewise_constant_with_warmup'
   loss:
     label_smoothing: 0.1
 train:
......
@@ -18,7 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from typing import Any, List, Mapping
+from typing import Any, Mapping, Optional
 
 import numpy as np
 import tensorflow as tf
@@ -32,23 +32,33 @@ class WarmupDecaySchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
 
   def __init__(
       self,
       lr_schedule: tf.keras.optimizers.schedules.LearningRateSchedule,
-      warmup_steps: int):
+      warmup_steps: int,
+      warmup_lr: Optional[float] = None):
     """Add warmup decay to a learning rate schedule.
 
     Args:
       lr_schedule: base learning rate scheduler
       warmup_steps: number of warmup steps
+      warmup_lr: an optional field for the final warmup learning rate. This
+        should be provided if the base `lr_schedule` does not contain this
+        field.
     """
     super(WarmupDecaySchedule, self).__init__()
     self._lr_schedule = lr_schedule
     self._warmup_steps = warmup_steps
+    self._warmup_lr = warmup_lr
 
   def __call__(self, step: int):
     lr = self._lr_schedule(step)
     if self._warmup_steps:
-      initial_learning_rate = tf.convert_to_tensor(
-          self._lr_schedule.initial_learning_rate, name="initial_learning_rate")
+      if self._warmup_lr is not None:
+        initial_learning_rate = tf.convert_to_tensor(
+            self._warmup_lr, name="initial_learning_rate")
+      else:
+        initial_learning_rate = tf.convert_to_tensor(
+            self._lr_schedule.initial_learning_rate,
+            name="initial_learning_rate")
       dtype = initial_learning_rate.dtype
       global_step_recomp = tf.cast(step, dtype)
       warmup_steps = tf.cast(self._warmup_steps, dtype)
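
The tail of __call__ is truncated above, but the warmup test kept later in this diff (`step / warmup_steps * initial_lr`) pins down the behavior: the schedule ramps linearly from zero to the warmup target and then defers to the wrapped schedule. A plain-Python sketch of that computation (hypothetical helper, not the class itself):

def warmup_decay_sketch(step, warmup_steps, warmup_lr, base_schedule):
  # Linear ramp from 0 to `warmup_lr` over `warmup_steps`, then the wrapped
  # schedule takes over unchanged.
  if warmup_steps and step < warmup_steps:
    return warmup_lr * step / warmup_steps
  return base_schedule(step)

# e.g. with warmup_lr=0.1, warmup_steps=100, base_schedule(step) == 0.1:
#   step 0 -> 0.0, step 50 -> 0.05, steps >= 100 -> 0.1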
@@ -62,65 +72,11 @@ class WarmupDecaySchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
     config = self._lr_schedule.get_config()
     config.update({
         "warmup_steps": self._warmup_steps,
+        "warmup_lr": self._warmup_lr,
     })
     return config
 
 
-# TODO(b/149030439) - refactor this with
-# tf.keras.optimizers.schedules.PiecewiseConstantDecay + WarmupDecaySchedule.
-class PiecewiseConstantDecayWithWarmup(
-    tf.keras.optimizers.schedules.LearningRateSchedule):
-  """Piecewise constant decay with warmup schedule."""
-
-  def __init__(self,
-               batch_size: int,
-               epoch_size: int,
-               warmup_epochs: int,
-               boundaries: List[int],
-               multipliers: List[float]):
-    """Piecewise constant decay with warmup.
-
-    Args:
-      batch_size: The training batch size used in the experiment.
-      epoch_size: The size of an epoch, or the number of examples in an epoch.
-      warmup_epochs: The number of warmup epochs to apply.
-      boundaries: The list of floats with strictly increasing entries.
-      multipliers: The list of multipliers/learning rates to use for the
-        piecewise portion. The length must be 1 less than that of boundaries.
-    """
-    super(PiecewiseConstantDecayWithWarmup, self).__init__()
-    if len(boundaries) != len(multipliers) - 1:
-      raise ValueError("The length of boundaries must be 1 less than the "
-                       "length of multipliers")
-
-    base_lr_batch_size = 256
-    steps_per_epoch = epoch_size // batch_size
-
-    self._rescaled_lr = BASE_LEARNING_RATE * batch_size / base_lr_batch_size
-    self._step_boundaries = [float(steps_per_epoch) * x for x in boundaries]
-    self._lr_values = [self._rescaled_lr * m for m in multipliers]
-    self._warmup_steps = warmup_epochs * steps_per_epoch
-
-  def __call__(self, step: int):
-    """Compute learning rate at given step."""
-
-    def warmup_lr():
-      return self._rescaled_lr * (
-          step / tf.cast(self._warmup_steps, tf.float32))
-
-    def piecewise_lr():
-      return tf.compat.v1.train.piecewise_constant(
-          tf.cast(step, tf.float32), self._step_boundaries, self._lr_values)
-
-    return tf.cond(step < self._warmup_steps, warmup_lr, piecewise_lr)
-
-  def get_config(self) -> Mapping[str, Any]:
-    return {
-        "rescaled_lr": self._rescaled_lr,
-        "step_boundaries": self._step_boundaries,
-        "lr_values": self._lr_values,
-        "warmup_steps": self._warmup_steps,
-    }
-
-
 class CosineDecayWithWarmup(tf.keras.optimizers.schedules.LearningRateSchedule):
   """Class to generate learning rate tensor."""
......
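
A quick arithmetic check that the removal is behavior-preserving for the default ResNet recipe: the constants the deleted class derived match what the new 'stepwise' branch computes from the explicit config. This assumes BASE_LEARNING_RATE = 0.1, the module constant the deleted class referenced:

BASE_LEARNING_RATE = 0.1
batch_size, epoch_size = 256, 1281167
steps_per_epoch = epoch_size // batch_size                # 5004

# Deleted class internals with the old ResNet schedule
# (multipliers [1.0, 0.1, 0.01, 0.001], boundaries [30, 60, 80]):
rescaled_lr = BASE_LEARNING_RATE * batch_size / 256       # 0.1
old_values = [rescaled_lr * m for m in [1.0, 0.1, 0.01, 0.001]]

# New path: explicit per-256 multipliers rescaled by batch size.
new_values = [batch_size * m
              for m in [0.1 / 256, 0.01 / 256, 0.001 / 256, 0.0001 / 256]]

for old, new in zip(old_values, new_values):
  assert abs(old - new) < 1e-12    # both ~[0.1, 0.01, 0.001, 0.0001]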
@@ -46,44 +46,6 @@ class LearningRateTests(tf.test.TestCase):
     self.assertAllClose(self.evaluate(lr(step)),
                         step / warmup_steps * initial_lr)
 
-  def test_piecewise_constant_decay_with_warmup(self):
-    """Basic computational test for piecewise constant decay with warmup."""
-    boundaries = [1, 2, 3]
-    warmup_epochs = boundaries[0]
-    learning_rate_multipliers = [1.0, 0.1, 0.001]
-    expected_keys = [
-        'rescaled_lr', 'step_boundaries', 'lr_values', 'warmup_steps',
-    ]
-
-    expected_lrs = [0.0, 0.1, 0.1]
-
-    lr = learning_rate.PiecewiseConstantDecayWithWarmup(
-        batch_size=256,
-        epoch_size=256,
-        warmup_epochs=warmup_epochs,
-        boundaries=boundaries[1:],
-        multipliers=learning_rate_multipliers)
-
-    step = 0
-
-    config = lr.get_config()
-    self.assertAllInSet(list(config.keys()), expected_keys)
-
-    for boundary, expected_lr in zip(boundaries, expected_lrs):
-      for _ in range(step, boundary):
-        self.assertAllClose(self.evaluate(lr(step)), expected_lr)
-        step += 1
-
-  def test_piecewise_constant_decay_invalid_boundaries(self):
-    with self.assertRaisesRegex(ValueError,
-                                'The length of boundaries must be 1 less '):
-      learning_rate.PiecewiseConstantDecayWithWarmup(
-          batch_size=256,
-          epoch_size=256,
-          warmup_epochs=1,
-          boundaries=[1, 2],
-          multipliers=[1, 2])
-
   def test_cosine_decay_with_warmup(self):
     """Basic computational test for cosine decay with warmup."""
     expected_lrs = [0.0, 0.1, 0.05, 0.0]
......
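
The two deleted tests take their numerical coverage with them. An equivalent check against the recomposition could look like the following sketch (hypothetical test, mirroring the deleted case where batch_size == epoch_size gives one step per epoch and expected_lrs = [0.0, 0.1, 0.1]):

  def test_stepwise_with_warmup_recomposition(self):
    """Hypothetical replacement for the deleted piecewise-with-warmup test."""
    base = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
        boundaries=[2, 3], values=[0.1, 0.01, 0.0001])
    lr = learning_rate.WarmupDecaySchedule(base, warmup_steps=1, warmup_lr=0.1)
    for step, expected in enumerate([0.0, 0.1, 0.1]):
      self.assertAllClose(self.evaluate(lr(step)), expected)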
@@ -370,29 +370,26 @@ def build_learning_rate(params: base_configs.LearningRateConfig,
         decay_steps=decay_steps,
         decay_rate=decay_rate,
         staircase=params.staircase)
-  elif decay_type == 'piecewise_constant_with_warmup':
-    logging.info('Using Piecewise constant decay with warmup. '
-                 'Parameters: batch_size: %d, epoch_size: %d, '
-                 'warmup_epochs: %d, boundaries: %s, multipliers: %s',
-                 batch_size, params.examples_per_epoch,
-                 params.warmup_epochs, params.boundaries,
-                 params.multipliers)
-    lr = learning_rate.PiecewiseConstantDecayWithWarmup(
-        batch_size=batch_size,
-        epoch_size=params.examples_per_epoch,
-        warmup_epochs=params.warmup_epochs,
-        boundaries=params.boundaries,
-        multipliers=params.multipliers)
+  elif decay_type == 'stepwise':
+    steps_per_epoch = params.examples_per_epoch // batch_size
+    boundaries = [boundary * steps_per_epoch for boundary in params.boundaries]
+    multipliers = [batch_size * multiplier for multiplier in params.multipliers]
+    logging.info('Using stepwise learning rate. Parameters: '
+                 'boundaries: %s, values: %s',
+                 boundaries, multipliers)
+    lr = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
+        boundaries=boundaries,
+        values=multipliers)
   elif decay_type == 'cosine_with_warmup':
     lr = learning_rate.CosineDecayWithWarmup(
         batch_size=batch_size,
         total_steps=train_epochs * train_steps,
         warmup_steps=warmup_steps)
   if warmup_steps > 0:
-    if decay_type not in [
-        'piecewise_constant_with_warmup', 'cosine_with_warmup'
-    ]:
+    if decay_type not in ['cosine_with_warmup']:
       logging.info('Applying %d warmup steps to the learning rate',
                    warmup_steps)
-      lr = learning_rate.WarmupDecaySchedule(lr, warmup_steps)
+      lr = learning_rate.WarmupDecaySchedule(lr,
+                                             warmup_steps,
+                                             warmup_lr=base_lr)
   return lr
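
Note why warmup_lr=base_lr is now required here: tf.keras.optimizers.schedules.PiecewiseConstantDecay has no initial_learning_rate attribute, so the wrapper's fallback branch (added in learning_rate.py above) would fail for the 'stepwise' schedule. A minimal sketch of the wiring, with literal stand-ins for the computed boundaries and values:

stepwise = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
    boundaries=[1500, 3000], values=[0.4, 0.04, 0.004])
# WarmupDecaySchedule(stepwise, 500) would raise AttributeError when called,
# since there is no `initial_learning_rate` on the base schedule to read;
# warmup_lr supplies the ramp target instead.
lr = learning_rate.WarmupDecaySchedule(stepwise, 500, warmup_lr=0.4)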
@@ -93,7 +93,6 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
 
   @parameterized.named_parameters(
       ('exponential', 'exponential'),
-      ('piecewise_constant_with_warmup', 'piecewise_constant_with_warmup'),
       ('cosine_with_warmup', 'cosine_with_warmup'))
   def test_learning_rate_with_decay_and_warmup(self, lr_decay_type):
     """Basic smoke test for syntax."""
......
@@ -18,22 +18,12 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from typing import Any, Mapping
-
 import dataclasses
 
 from official.modeling.hyperparams import base_config
 from official.vision.image_classification.configs import base_configs
 
-_RESNET_LR_SCHEDULE = [  # (multiplier, epoch to start) tuples
-    (1.0, 5), (0.1, 30), (0.01, 60), (0.001, 80)
-]
-_RESNET_LR_BOUNDARIES = list(p[1] for p in _RESNET_LR_SCHEDULE[1:])
-_RESNET_LR_MULTIPLIERS = list(p[0] for p in _RESNET_LR_SCHEDULE)
-_RESNET_LR_WARMUP_EPOCHS = _RESNET_LR_SCHEDULE[0][1]
-
 
 @dataclasses.dataclass
 class ResNetModelConfig(base_configs.ModelConfig):
   """Configuration for the ResNet model."""
@@ -56,8 +46,13 @@ class ResNetModelConfig(base_configs.ModelConfig):
       moving_average_decay=None)
   learning_rate: base_configs.LearningRateConfig = (
       base_configs.LearningRateConfig(
-          name='piecewise_constant_with_warmup',
-          examples_per_epoch=1281167,
-          warmup_epochs=_RESNET_LR_WARMUP_EPOCHS,
-          boundaries=_RESNET_LR_BOUNDARIES,
-          multipliers=_RESNET_LR_MULTIPLIERS))
+          name='stepwise',
+          initial_lr=0.1,
+          examples_per_epoch=1281167,
+          boundaries=[30, 60, 80],
+          warmup_epochs=5,
+          scale_by_batch_size=1. / 256.,
+          multipliers=[0.1 / 256,
+                       0.01 / 256,
+                       0.001 / 256,
+                       0.0001 / 256]))
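
With these defaults the linear-scaling rule survives the refactor at any batch size: initial_lr is rescaled by batch_size * scale_by_batch_size (the usual reading of that field), and the per-256 multipliers are rescaled by batch_size in the 'stepwise' branch above. A worked example under that assumption:

batch_size = 1024
warmup_target = 0.1 * batch_size * (1. / 256.)    # 0.4
post_warmup = [batch_size * m
               for m in [0.1 / 256, 0.01 / 256, 0.001 / 256, 0.0001 / 256]]
# post_warmup == [0.4, 0.04, 0.004, 0.0004]: 4x the batch -> 4x the rate,
# dropping 10x at epochs 30, 60 and 80 after a 5-epoch warmup to 0.4.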