"git@developer.sourcefind.cn:OpenDAS/torch-scatter.git" did not exist on "26a9e988ab78d075097fc4645883696d4d41ca9c"
Commit 88985493 authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 315545272
parent 43605bd5
@@ -333,6 +333,7 @@ def train_and_eval(
   learning_rate = optimizer_factory.build_learning_rate(
       params=params.model.learning_rate,
       batch_size=train_builder.global_batch_size,
+      train_epochs=train_epochs,
       train_steps=train_steps)
   optimizer = optimizer_factory.build_optimizer(
       optimizer_name=params.model.optimizer.name,
...
@@ -20,6 +20,7 @@ from __future__ import print_function
 from typing import Any, List, Mapping
 
+import numpy as np
 import tensorflow as tf
 
 BASE_LEARNING_RATE = 0.1
@@ -118,3 +119,46 @@ class PiecewiseConstantDecayWithWarmup(
         "lr_values": self._lr_values,
         "warmup_steps": self._warmup_steps,
     }
+
+
+class CosineDecayWithWarmup(tf.keras.optimizers.schedules.LearningRateSchedule):
+  """Class to generate learning rate tensor."""
+
+  def __init__(self, batch_size: int, total_steps: int, warmup_steps: int):
+    """Creates the cosine learning rate tensor with linear warmup.
+
+    Args:
+      batch_size: The training batch size used in the experiment.
+      total_steps: Total training steps.
+      warmup_steps: Steps for the warm up period.
+    """
+    super(CosineDecayWithWarmup, self).__init__()
+    base_lr_batch_size = 256
+    self._total_steps = total_steps
+    self._init_learning_rate = BASE_LEARNING_RATE * batch_size / base_lr_batch_size
+    self._warmup_steps = warmup_steps
+
+  def __call__(self, global_step: int):
+    global_step = tf.cast(global_step, dtype=tf.float32)
+    warmup_steps = self._warmup_steps
+    init_lr = self._init_learning_rate
+    total_steps = self._total_steps
+
+    linear_warmup = global_step / warmup_steps * init_lr
+
+    cosine_learning_rate = init_lr * (tf.cos(np.pi *
+                                             (global_step - warmup_steps) /
+                                             (total_steps - warmup_steps)) +
+                                      1.0) / 2.0
+
+    learning_rate = tf.where(global_step < warmup_steps, linear_warmup,
+                             cosine_learning_rate)
+    return learning_rate
+
+  def get_config(self):
+    return {
+        "total_steps": self._total_steps,
+        "warmup_learning_rate": self._warmup_learning_rate,
+        "warmup_steps": self._warmup_steps,
+        "init_learning_rate": self._init_learning_rate,
+    }
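
The schedule added above ramps the learning rate linearly from 0 up to a peak of BASE_LEARNING_RATE * batch_size / 256 over the warmup period, then follows a half cosine from that peak down to 0 at total_steps. As a plain-NumPy illustration of the same math (a sketch written for this page, not code from the commit):

import numpy as np

def cosine_with_warmup(step, batch_size, total_steps, warmup_steps,
                       base_lr=0.1, base_batch_size=256):
  """Sketch of the schedule above: linear warmup, then half-cosine decay to 0."""
  init_lr = base_lr * batch_size / base_batch_size  # peak rate, scaled by batch size
  if step < warmup_steps:
    return init_lr * step / warmup_steps  # linear ramp from 0 toward init_lr
  progress = (step - warmup_steps) / (total_steps - warmup_steps)
  return init_lr * (np.cos(np.pi * progress) + 1.0) / 2.0  # cos falls 1 -> -1, so lr falls init_lr -> 0

With batch_size=256, total_steps=3 and warmup_steps=1 this gives 0.0, 0.1, 0.05 and 0.0 at steps 0 through 3, matching the unit test added below.
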
@@ -84,6 +84,16 @@ class LearningRateTests(tf.test.TestCase):
         boundaries=[1, 2],
         multipliers=[1, 2])
+
+  def test_cosine_decay_with_warmup(self):
+    """Basic computational test for cosine decay with warmup."""
+    expected_lrs = [0.0, 0.1, 0.05, 0.0]
+
+    lr = learning_rate.CosineDecayWithWarmup(
+        batch_size=256, total_steps=3, warmup_steps=1)
+
+    for step in [0, 1, 2, 3]:
+      self.assertAllClose(lr(step), expected_lrs[step])
 
 
 if __name__ == '__main__':
   tf.test.main()
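
The expected values follow directly from the formula: with batch_size=256 the peak rate is BASE_LEARNING_RATE * 256 / 256 = 0.1. Step 0 lies inside the warmup, so the rate is 0/1 * 0.1 = 0.0; step 1 is the end of warmup, where cos(0) = 1 gives the full 0.1; step 2 is halfway through the decay window, where cos(pi/2) = 0 gives 0.05; step 3 is the final step, where cos(pi) = -1 gives 0.0.
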
@@ -331,6 +331,7 @@ def build_optimizer(
 def build_learning_rate(params: base_configs.LearningRateConfig,
                         batch_size: int = None,
+                        train_epochs: int = None,
                         train_steps: int = None):
   """Build the learning rate given the provided configuration."""
   decay_type = params.name
@@ -375,8 +376,15 @@ def build_learning_rate(params: base_configs.LearningRateConfig,
         warmup_epochs=params.warmup_epochs,
         boundaries=params.boundaries,
         multipliers=params.multipliers)
+  elif decay_type == 'cosine_with_warmup':
+    lr = learning_rate.CosineDecayWithWarmup(
+        batch_size=batch_size,
+        total_steps=train_epochs * train_steps,
+        warmup_steps=warmup_steps)
   if warmup_steps > 0:
-    if decay_type != 'piecewise_constant_with_warmup':
+    if decay_type not in [
+        'piecewise_constant_with_warmup', 'cosine_with_warmup'
+    ]:
       logging.info('Applying %d warmup steps to the learning rate',
                    warmup_steps)
       lr = learning_rate.WarmupDecaySchedule(lr, warmup_steps)
...
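
Two details of the factory change are worth spelling out. First, the new branch computes the schedule length as train_epochs * train_steps, which only makes sense if train_steps counts steps per epoch; this is also why build_learning_rate, and its caller in train_and_eval above, now take a train_epochs argument. Second, 'cosine_with_warmup' joins 'piecewise_constant_with_warmup' in the exclusion list because both schedules implement their own warmup; wrapping them in WarmupDecaySchedule as well would apply the warmup twice.
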
@@ -85,7 +85,8 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
   @parameterized.named_parameters(
       ('exponential', 'exponential'),
-      ('piecewise_constant_with_warmup', 'piecewise_constant_with_warmup'))
+      ('piecewise_constant_with_warmup', 'piecewise_constant_with_warmup'),
+      ('cosine_with_warmup', 'cosine_with_warmup'))
   def test_learning_rate_with_decay_and_warmup(self, lr_decay_type):
     """Basic smoke test for syntax."""
     params = base_configs.LearningRateConfig(
@@ -99,11 +100,13 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
         boundaries=[0],
         multipliers=[0, 1])
     batch_size = 1
+    train_epochs = 1
     train_steps = 1
 
     lr = optimizer_factory.build_learning_rate(
         params=params,
         batch_size=batch_size,
+        train_epochs=train_epochs,
         train_steps=train_steps)
     self.assertTrue(
         issubclass(
...