Commit 02af9bb5 authored by Zongwei Zhou, committed by A. Unique TensorFlower

Use learning schedule op by default in Resnet50. Remove learning rate callback code.

PiperOrigin-RevId: 296988935
parent 4b8f80c3
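
The pattern this commit adopts, sketched below with made-up boundary values (the real schedules are built from `LR_SCHEDULE` in the diffs that follow): instead of a `LearningRateBatchScheduler` callback rewriting `optimizer.learning_rate` before every batch, a `tf.keras.optimizers.schedules.LearningRateSchedule` object is handed to the optimizer, which evaluates it at the current step inside the training op.

```python
import tensorflow as tf

# Illustrative boundaries/values only; the real schedules are derived from
# LR_SCHEDULE below. Decay 10x at steps 30000 and 60000.
lr_schedule = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
    boundaries=[30000, 60000],
    values=[0.1, 0.01, 0.001])

# The schedule travels inside the optimizer; no per-batch callback has to
# mutate optimizer.learning_rate anymore.
optimizer = tf.keras.optimizers.SGD(learning_rate=lr_schedule, momentum=0.9)
```

One motivation for schedule ops is that the learning rate is computed as part of the training function rather than in a Python callback, so it composes cleanly with `tf.function` and distribution strategies.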
@@ -53,6 +53,10 @@ class Resnet56KerasAccuracy(keras_benchmark.KerasBenchmark):
     super(Resnet56KerasAccuracy, self).__init__(
         output_dir=output_dir, flag_methods=flag_methods)
 
+  def _setup(self):
+    super(Resnet56KerasAccuracy, self)._setup()
+    FLAGS.use_tensor_lr = False
+
   def benchmark_graph_1_gpu(self):
     """Test keras based model with Keras fit and distribution strategies."""
     self._setup()
@@ -439,6 +443,7 @@ class Resnet56KerasBenchmarkSynth(Resnet56KerasBenchmarkBase):
     default_flags['use_synthetic_data'] = True
     default_flags['train_steps'] = 110
     default_flags['log_steps'] = 10
+    default_flags['use_tensor_lr'] = False
 
     super(Resnet56KerasBenchmarkSynth, self).__init__(
         output_dir=output_dir, default_flags=default_flags)
@@ -453,6 +458,7 @@ class Resnet56KerasBenchmarkReal(Resnet56KerasBenchmarkBase):
     default_flags['data_dir'] = os.path.join(root_data_dir, CIFAR_DATA_DIR_NAME)
     default_flags['train_steps'] = 110
     default_flags['log_steps'] = 10
+    default_flags['use_tensor_lr'] = False
 
     super(Resnet56KerasBenchmarkReal, self).__init__(
         output_dir=output_dir, default_flags=default_flags)
......
@@ -71,7 +71,6 @@ class Resnet50KerasAccuracy(keras_benchmark.KerasBenchmark):
     FLAGS.epochs_between_evals = 10
     FLAGS.model_dir = self._get_model_dir('benchmark_graph_8_gpu')
     FLAGS.dtype = 'fp32'
-    FLAGS.use_tensor_lr = True
     self._run_and_report_benchmark()
 
   def benchmark_8_gpu(self):
@@ -87,7 +86,6 @@ class Resnet50KerasAccuracy(keras_benchmark.KerasBenchmark):
     FLAGS.enable_eager = True
     # Add some thread tunings to improve performance.
     FLAGS.datasets_num_private_threads = 14
-    FLAGS.use_tensor_lr = True
     self._run_and_report_benchmark()
 
   def benchmark_8_gpu_amp(self):
@@ -104,7 +102,6 @@ class Resnet50KerasAccuracy(keras_benchmark.KerasBenchmark):
     FLAGS.fp16_implementation = 'graph_rewrite'
     # Add some thread tunings to improve performance.
     FLAGS.datasets_num_private_threads = 14
-    FLAGS.use_tensor_lr = True
     self._run_and_report_benchmark()
 
   def benchmark_8_gpu_fp16(self):
@@ -120,7 +117,6 @@ class Resnet50KerasAccuracy(keras_benchmark.KerasBenchmark):
     FLAGS.enable_eager = True
     # Thread tuning to improve performance.
     FLAGS.tf_gpu_thread_mode = 'gpu_private'
-    FLAGS.use_tensor_lr = True
     self._run_and_report_benchmark()
 
   def benchmark_xla_8_gpu_fp16(self):
@@ -137,7 +133,6 @@ class Resnet50KerasAccuracy(keras_benchmark.KerasBenchmark):
     FLAGS.enable_xla = True
     # Thread tuning to improve performance.
     FLAGS.tf_gpu_thread_mode = 'gpu_private'
-    FLAGS.use_tensor_lr = True
     self._run_and_report_benchmark()
 
   def benchmark_8_gpu_mlperf_like(self):
@@ -179,7 +174,6 @@ class Resnet50KerasAccuracy(keras_benchmark.KerasBenchmark):
     FLAGS.loss_scale = 'dynamic'
     # Thread tuning to improve performance.
     FLAGS.tf_gpu_thread_mode = 'gpu_private'
-    FLAGS.use_tensor_lr = True
     self._run_and_report_benchmark(top_1_min=0.736)
 
   @benchmark_wrappers.enable_runtime_flags
@@ -241,7 +235,6 @@ class MobilenetV1KerasAccuracy(keras_benchmark.KerasBenchmark):
     FLAGS.enable_eager = True
     # Add some thread tunings to improve performance.
     FLAGS.datasets_num_private_threads = 14
-    FLAGS.use_tensor_lr = True
     self._run_and_report_benchmark()
 
   @benchmark_wrappers.enable_runtime_flags
@@ -472,7 +465,6 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
     FLAGS.model_dir = self._get_model_dir('benchmark_xla_1_gpu_fp16_tweaked')
     FLAGS.dtype = 'fp16'
     FLAGS.batch_size = 256
-    FLAGS.use_tensor_lr = True
     FLAGS.tf_gpu_thread_mode = 'gpu_private'
     self._run_and_report_benchmark()
 
@@ -550,7 +542,6 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
         'benchmark_graph_xla_1_gpu_fp16_tweaked')
     FLAGS.dtype = 'fp16'
     FLAGS.batch_size = 256
-    FLAGS.use_tensor_lr = True
     FLAGS.tf_gpu_thread_mode = 'gpu_private'
     self._run_and_report_benchmark()
 
@@ -587,7 +578,6 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
     FLAGS.distribution_strategy = 'mirrored'
     FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_tweaked')
     FLAGS.batch_size = 128 * 8  # 8 GPUs
-    FLAGS.use_tensor_lr = True
     FLAGS.datasets_num_private_threads = 14
     self._run_and_report_benchmark()
 
@@ -627,7 +617,6 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
     FLAGS.distribution_strategy = 'mirrored'
     FLAGS.model_dir = self._get_model_dir('benchmark_xla_8_gpu_tweaked')
     FLAGS.batch_size = 128 * 8
-    FLAGS.use_tensor_lr = True
     FLAGS.tf_gpu_thread_mode = 'gpu_private'
     FLAGS.datasets_num_private_threads = 24
     self._run_and_report_benchmark()
@@ -654,7 +643,6 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
     FLAGS.distribution_strategy = 'mirrored'
     FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_fp16_tweaked')
     FLAGS.batch_size = 256 * 8  # 8 GPUs
-    FLAGS.use_tensor_lr = True
     FLAGS.tf_gpu_thread_mode = 'gpu_private'
     self._run_and_report_benchmark()
 
@@ -670,7 +658,6 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
         'benchmark_8_gpu_fp16_dynamic_tweaked')
     FLAGS.batch_size = 256 * 8  # 8 GPUs
     FLAGS.loss_scale = 'dynamic'
-    FLAGS.use_tensor_lr = True
     FLAGS.tf_gpu_thread_mode = 'gpu_private'
     self._run_and_report_benchmark()
 
@@ -698,7 +685,6 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
     FLAGS.distribution_strategy = 'mirrored'
     FLAGS.model_dir = self._get_model_dir('benchmark_xla_8_gpu_fp16_tweaked')
     FLAGS.batch_size = 256 * 8  # 8 GPUs
-    FLAGS.use_tensor_lr = True
     FLAGS.tf_gpu_thread_mode = 'gpu_private'
     FLAGS.datasets_num_private_threads = 48
     self._run_and_report_benchmark()
@@ -718,7 +704,6 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
     FLAGS.model_dir = self._get_model_dir(
         'benchmark_xla_8_gpu_fp16_tweaked_delay_measure')
     FLAGS.batch_size = 256 * 8
-    FLAGS.use_tensor_lr = True
     FLAGS.tf_gpu_thread_mode = 'gpu_private'
     FLAGS.train_steps = 310
     self._run_and_report_benchmark()
@@ -736,7 +721,6 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
         'benchmark_xla_8_gpu_fp16_dynamic_tweaked')
     FLAGS.batch_size = 256 * 8  # 8 GPUs
     FLAGS.loss_scale = 'dynamic'
-    FLAGS.use_tensor_lr = True
     FLAGS.tf_gpu_thread_mode = 'gpu_private'
     FLAGS.datasets_num_private_threads = 48
     self._run_and_report_benchmark()
@@ -799,7 +783,6 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
     FLAGS.distribution_strategy = 'mirrored'
     FLAGS.model_dir = self._get_model_dir('benchmark_graph_8_gpu_fp16_tweaked')
     FLAGS.batch_size = 256 * 8  # 8 GPUs
-    FLAGS.use_tensor_lr = True
     FLAGS.tf_gpu_thread_mode = 'gpu_private'
     self._run_and_report_benchmark()
 
@@ -815,7 +798,6 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
     FLAGS.model_dir = self._get_model_dir(
         'benchmark_graph_xla_8_gpu_fp16_tweaked')
     FLAGS.batch_size = 256 * 8  # 8 GPUs
-    FLAGS.use_tensor_lr = True
     FLAGS.tf_gpu_thread_mode = 'gpu_private'
     self._run_and_report_benchmark()
 
@@ -834,7 +816,6 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
     FLAGS.model_dir = self._get_model_dir(
         'benchmark_graph_xla_8_gpu_fp16_tweaked_delay_measure')
     FLAGS.batch_size = 256 * 8
-    FLAGS.use_tensor_lr = True
     FLAGS.tf_gpu_thread_mode = 'gpu_private'
     FLAGS.train_steps = 310
     self._run_and_report_benchmark()
@@ -851,7 +832,6 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
         'benchmark_graph_8_gpu_fp16_dynamic_tweaked')
     FLAGS.batch_size = 256 * 8  # 8 GPUs
     FLAGS.loss_scale = 'dynamic'
-    FLAGS.use_tensor_lr = True
     FLAGS.tf_gpu_thread_mode = 'gpu_private'
     self._run_and_report_benchmark()
 
@@ -867,7 +847,6 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
     FLAGS.model_dir = self._get_model_dir(
         'benchmark_graph_xla_8_gpu_fp16_dynamic_tweaked')
     FLAGS.batch_size = 256 * 8  # 8 GPUs
-    FLAGS.use_tensor_lr = True
     FLAGS.loss_scale = 'dynamic'
     FLAGS.tf_gpu_thread_mode = 'gpu_private'
     self._run_and_report_benchmark()
@@ -963,7 +942,6 @@ class TrivialKerasBenchmarkReal(keras_benchmark.KerasBenchmark):
     def_flags['use_trivial_model'] = True
     def_flags['skip_eval'] = True
     def_flags['report_accuracy_metrics'] = False
-    def_flags['use_tensor_lr'] = True
     def_flags['dtype'] = 'fp16'
     def_flags['data_dir'] = os.path.join(root_data_dir, 'imagenet')
     def_flags['train_steps'] = 600
@@ -1097,7 +1075,6 @@ class Resnet50MultiWorkerKerasAccuracy(keras_benchmark.KerasBenchmark):
     FLAGS.enable_eager = eager
     FLAGS.enable_xla = False
     FLAGS.distribution_strategy = 'multi_worker_mirrored'
-    FLAGS.use_tensor_lr = True
     FLAGS.tf_gpu_thread_mode = 'gpu_private'
     FLAGS.datasets_num_private_threads = 32
     FLAGS.model_dir = self._get_model_dir(
@@ -1161,7 +1138,6 @@ class Resnet50MultiWorkerKerasBenchmark(Resnet50KerasBenchmarkBase):
     FLAGS.enable_eager = eager
     FLAGS.enable_xla = False
     FLAGS.distribution_strategy = 'multi_worker_mirrored'
-    FLAGS.use_tensor_lr = True
     FLAGS.tf_gpu_thread_mode = 'gpu_private'
     FLAGS.datasets_num_private_threads = 32
     FLAGS.model_dir = self._get_model_dir(
......
@@ -18,7 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from absl import app as absl_app
+import numpy as np
 from absl import flags
 import tensorflow as tf
 from official.benchmark.models import resnet_cifar_model
@@ -64,6 +64,46 @@ def learning_rate_schedule(current_epoch,
   return learning_rate
 
 
+class LearningRateBatchScheduler(tf.keras.callbacks.Callback):
+  """Callback to update learning rate on every batch (not epoch boundaries).
+
+  N.B. Only support Keras optimizers, not TF optimizers.
+
+  Attributes:
+      schedule: a function that takes an epoch index and a batch index as input
+          (both integer, indexed from 0) and returns a new learning rate as
+          output (float).
+  """
+
+  def __init__(self, schedule, batch_size, steps_per_epoch):
+    super(LearningRateBatchScheduler, self).__init__()
+    self.schedule = schedule
+    self.steps_per_epoch = steps_per_epoch
+    self.batch_size = batch_size
+    self.epochs = -1
+    self.prev_lr = -1
+
+  def on_epoch_begin(self, epoch, logs=None):
+    if not hasattr(self.model.optimizer, 'learning_rate'):
+      raise ValueError('Optimizer must have a "learning_rate" attribute.')
+    self.epochs += 1
+
+  def on_batch_begin(self, batch, logs=None):
+    """Executes before step begins."""
+    lr = self.schedule(self.epochs,
+                       batch,
+                       self.steps_per_epoch,
+                       self.batch_size)
+    if not isinstance(lr, (float, np.float32, np.float64)):
+      raise ValueError('The output of the "schedule" function should be float.')
+    if lr != self.prev_lr:
+      self.model.optimizer.learning_rate = lr  # lr should be a float here
+      self.prev_lr = lr
+      tf.compat.v1.logging.debug(
+          'Epoch %05d Batch %05d: LearningRateBatchScheduler '
+          'change learning rate to %s.', self.epochs, batch, lr)
+
+
 def run(flags_obj):
   """Run ResNet Cifar-10 training and eval loop using native Keras APIs.
@@ -151,8 +191,18 @@ def run(flags_obj):
       num_epochs=flags_obj.train_epochs,
       parse_record_fn=cifar_preprocessing.parse_record)
 
+  steps_per_epoch = (
+      cifar_preprocessing.NUM_IMAGES['train'] // flags_obj.batch_size)
+  lr_schedule = 0.1
+  if flags_obj.use_tensor_lr:
+    initial_learning_rate = common.BASE_LEARNING_RATE * flags_obj.batch_size / 128
+    lr_schedule = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
+        boundaries=list(p[1] * steps_per_epoch for p in LR_SCHEDULE),
+        values=[initial_learning_rate] +
+        list(p[0] * initial_learning_rate for p in LR_SCHEDULE))
+
   with strategy_scope:
-    optimizer = common.get_optimizer()
+    optimizer = common.get_optimizer(lr_schedule)
     model = resnet_cifar_model.resnet56(classes=cifar_preprocessing.NUM_CLASSES)
 
     # TODO(b/138957587): Remove when force_v2_in_keras_compile is no longer
@@ -173,11 +223,16 @@ def run(flags_obj):
                   if flags_obj.report_accuracy_metrics else None),
                   run_eagerly=flags_obj.run_eagerly)
 
-  steps_per_epoch = (
-      cifar_preprocessing.NUM_IMAGES['train'] // flags_obj.batch_size)
   train_epochs = flags_obj.train_epochs
 
-  callbacks = common.get_callbacks(steps_per_epoch, learning_rate_schedule)
+  callbacks = common.get_callbacks(steps_per_epoch)
+  if not flags_obj.use_tensor_lr:
+    lr_callback = LearningRateBatchScheduler(
+        schedule=learning_rate_schedule,
+        batch_size=flags_obj.batch_size,
+        steps_per_epoch=steps_per_epoch)
+    callbacks.append(lr_callback)
 
   # if multiple epochs, ignore the train_steps flag.
   if train_epochs <= 1 and flags_obj.train_steps:
......
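
As a sanity check, the `PiecewiseConstantDecay` construction above can be exercised standalone. A minimal sketch, assuming CIFAR-10's 50,000 training images, a batch size of 128, a `BASE_LEARNING_RATE` of 0.1, and (multiplier, epoch) tuples of (0.1, 91), (0.01, 136), (0.001, 182); all of these constants are illustrative guesses, not quoted from the repo:

```python
import tensorflow as tf

LR_SCHEDULE = [(0.1, 91), (0.01, 136), (0.001, 182)]  # assumed values
batch_size = 128
steps_per_epoch = 50000 // batch_size           # CIFAR-10 train split
initial_learning_rate = 0.1 * batch_size / 128  # assumes BASE_LEARNING_RATE = 0.1

lr_schedule = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
    boundaries=[epoch * steps_per_epoch for _, epoch in LR_SCHEDULE],
    values=[initial_learning_rate] +
           [mult * initial_learning_rate for mult, _ in LR_SCHEDULE])

# Evaluate just past each boundary: prints 0.1, then 0.01, then 0.0001.
for step in (0, 91 * steps_per_epoch + 1, 182 * steps_per_epoch + 1):
  print(step, float(lr_schedule(step)))
```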
@@ -20,7 +20,6 @@ from __future__ import print_function
 import os
 
 from absl import flags
-import numpy as np
 import tensorflow as tf
 
 from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_v2
@@ -36,78 +35,6 @@ LR_SCHEDULE = [  # (multiplier, epoch to start) tuples
 ]
 
 
-def learning_rate_schedule(current_epoch,
-                           current_batch,
-                           steps_per_epoch,
-                           batch_size):
-  """Handles linear scaling rule, gradual warmup, and LR decay.
-
-  Scale learning rate at epoch boundaries provided in LR_SCHEDULE by the
-  provided scaling factor.
-
-  Args:
-    current_epoch: integer, current epoch indexed from 0.
-    current_batch: integer, current batch in the current epoch, indexed from 0.
-    steps_per_epoch: integer, number of steps in an epoch.
-    batch_size: integer, total batch sized.
-
-  Returns:
-    Adjusted learning rate.
-  """
-  initial_lr = BASE_LEARNING_RATE * batch_size / 256
-  epoch = current_epoch + float(current_batch) / steps_per_epoch
-  warmup_lr_multiplier, warmup_end_epoch = LR_SCHEDULE[0]
-  if epoch < warmup_end_epoch:
-    # Learning rate increases linearly per step.
-    return initial_lr * warmup_lr_multiplier * epoch / warmup_end_epoch
-  for mult, start_epoch in LR_SCHEDULE:
-    if epoch >= start_epoch:
-      learning_rate = initial_lr * mult
-    else:
-      break
-  return learning_rate
-
-
-class LearningRateBatchScheduler(tf.keras.callbacks.Callback):
-  """Callback to update learning rate on every batch (not epoch boundaries).
-
-  N.B. Only support Keras optimizers, not TF optimizers.
-
-  Attributes:
-      schedule: a function that takes an epoch index and a batch index as input
-          (both integer, indexed from 0) and returns a new learning rate as
-          output (float).
-  """
-
-  def __init__(self, schedule, batch_size, steps_per_epoch):
-    super(LearningRateBatchScheduler, self).__init__()
-    self.schedule = schedule
-    self.steps_per_epoch = steps_per_epoch
-    self.batch_size = batch_size
-    self.epochs = -1
-    self.prev_lr = -1
-
-  def on_epoch_begin(self, epoch, logs=None):
-    if not hasattr(self.model.optimizer, 'learning_rate'):
-      raise ValueError('Optimizer must have a "learning_rate" attribute.')
-    self.epochs += 1
-
-  def on_batch_begin(self, batch, logs=None):
-    """Executes before step begins."""
-    lr = self.schedule(self.epochs,
-                       batch,
-                       self.steps_per_epoch,
-                       self.batch_size)
-    if not isinstance(lr, (float, np.float32, np.float64)):
-      raise ValueError('The output of the "schedule" function should be float.')
-    if lr != self.prev_lr:
-      self.model.optimizer.learning_rate = lr  # lr should be a float here
-      self.prev_lr = lr
-      tf.compat.v1.logging.debug(
-          'Epoch %05d Batch %05d: LearningRateBatchScheduler '
-          'change learning rate to %s.', self.epochs, batch, lr)
-
-
 class PiecewiseConstantDecayWithWarmup(
     tf.keras.optimizers.schedules.LearningRateSchedule):
   """Piecewise constant decay with warmup schedule."""
@@ -180,10 +107,8 @@ def get_optimizer(learning_rate=0.1):
   return gradient_descent_v2.SGD(learning_rate=learning_rate, momentum=0.9)
 
 
-# TODO(hongkuny,haoyuzhang): make cifar model use_tensor_lr to clean up code.
 def get_callbacks(
     steps_per_epoch,
-    learning_rate_schedule_fn=None,
     pruning_method=None,
     enable_checkpoint_and_export=False,
     model_dir=None):
@@ -194,13 +119,6 @@ def get_callbacks(
       logdir=FLAGS.model_dir if FLAGS.enable_tensorboard else None)
   callbacks = [time_callback]
 
-  if not FLAGS.use_tensor_lr and learning_rate_schedule_fn:
-    lr_callback = LearningRateBatchScheduler(
-        learning_rate_schedule_fn,
-        batch_size=FLAGS.batch_size,
-        steps_per_epoch=steps_per_epoch)
-    callbacks.append(lr_callback)
-
   if FLAGS.enable_tensorboard:
     tensorboard_callback = tf.keras.callbacks.TensorBoard(
         log_dir=FLAGS.model_dir)
@@ -317,7 +235,7 @@ def define_keras_flags(
                        help='Whether to use a trivial Keras model.')
   flags.DEFINE_boolean(name='report_accuracy_metrics', default=True,
                        help='Report metrics during training and evaluation.')
-  flags.DEFINE_boolean(name='use_tensor_lr', default=False,
+  flags.DEFINE_boolean(name='use_tensor_lr', default=True,
                        help='Use learning rate tensor instead of a callback.')
   flags.DEFINE_boolean(
       name='enable_tensorboard', default=False,
......
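
The deleted `learning_rate_schedule` function above is the same math that `PiecewiseConstantDecayWithWarmup` keeps computing, just as tensor ops evaluated per step instead of per-batch Python. A plain-Python restatement for reference; the linear-scaling divisor of 256 and the warmup-first convention come from the deleted code, while `base_lr` and the tuple values are placeholders:

```python
def warmup_piecewise_lr(global_step, batch_size, steps_per_epoch,
                        base_lr=0.1,
                        lr_schedule=((1.0, 5), (0.1, 30),
                                     (0.01, 60), (0.001, 80))):
  """Linear scaling + warmup + piecewise decay, as in the deleted function."""
  initial_lr = base_lr * batch_size / 256  # linear scaling rule
  epoch = global_step / steps_per_epoch    # fractional epoch
  warmup_mult, warmup_end = lr_schedule[0]
  if epoch < warmup_end:
    # Warm up linearly per step toward initial_lr * warmup_mult.
    return initial_lr * warmup_mult * epoch / warmup_end
  lr = initial_lr
  for mult, start_epoch in lr_schedule:
    if epoch >= start_epoch:
      lr = initial_lr * mult
  return lr
```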
@@ -155,23 +155,19 @@ def run(flags_obj):
       dtype=dtype,
       drop_remainder=drop_remainder)
 
-  lr_schedule = 0.1
-  if flags_obj.use_tensor_lr:
-    lr_schedule = common.PiecewiseConstantDecayWithWarmup(
-        batch_size=flags_obj.batch_size,
-        epoch_size=imagenet_preprocessing.NUM_IMAGES['train'],
-        warmup_epochs=common.LR_SCHEDULE[0][1],
-        boundaries=list(p[1] for p in common.LR_SCHEDULE[1:]),
-        multipliers=list(p[0] for p in common.LR_SCHEDULE),
-        compute_lr_on_cpu=True)
+  lr_schedule = common.PiecewiseConstantDecayWithWarmup(
+      batch_size=flags_obj.batch_size,
+      epoch_size=imagenet_preprocessing.NUM_IMAGES['train'],
+      warmup_epochs=common.LR_SCHEDULE[0][1],
+      boundaries=list(p[1] for p in common.LR_SCHEDULE[1:]),
+      multipliers=list(p[0] for p in common.LR_SCHEDULE),
+      compute_lr_on_cpu=True)
 
   steps_per_epoch = (
       imagenet_preprocessing.NUM_IMAGES['train'] // flags_obj.batch_size)
-  learning_rate_schedule_fn = None
 
   with strategy_scope:
     if flags_obj.optimizer == 'resnet50_default':
       optimizer = common.get_optimizer(lr_schedule)
-      learning_rate_schedule_fn = common.learning_rate_schedule
     elif flags_obj.optimizer == 'mobilenet_default':
       initial_learning_rate = \
           flags_obj.initial_learning_rate_per_sample * flags_obj.batch_size
@@ -248,7 +244,6 @@ def run(flags_obj):
   callbacks = common.get_callbacks(
       steps_per_epoch=steps_per_epoch,
-      learning_rate_schedule_fn=learning_rate_schedule_fn,
       pruning_method=flags_obj.pruning_method,
       enable_checkpoint_and_export=flags_obj.enable_checkpoint_and_export,
       model_dir=flags_obj.model_dir)
......
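
End to end, the new default wiring looks like the sketch below; this is a toy model and boundary, not the repo's code. The optimizer owns the schedule, `get_callbacks` no longer takes an LR entry, and the current rate can still be recovered from the optimizer's step counter.

```python
import numpy as np
import tensorflow as tf

lr_schedule = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
    boundaries=[100], values=[0.1, 0.01])  # toy boundary for illustration
model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=lr_schedule,
                                                momentum=0.9),
              loss='mse')

x = np.random.rand(256, 4).astype('float32')
y = np.random.rand(256, 1).astype('float32')
model.fit(x, y, batch_size=32, epochs=2, verbose=0)  # 16 steps, no LR callback

# Recover the learning rate at the current step: still 0.1 at step 16.
print(float(lr_schedule(model.optimizer.iterations)))
```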