Commit b974c3f9 authored by Haoyu Zhang's avatar Haoyu Zhang Committed by A. Unique TensorFlower
Browse files

Moving Keras ResNet models to `official/vision/image_classification` and...

Moving Keras ResNet models to `official/vision/image_classification` and benchmarks to `official/benchmark`.

PiperOrigin-RevId: 264268533
parent b1188d03
...@@ -22,8 +22,8 @@ import time ...@@ -22,8 +22,8 @@ import time
from absl import flags from absl import flags
import tensorflow as tf # pylint: disable=g-bad-import-order import tensorflow as tf # pylint: disable=g-bad-import-order
from official.resnet.keras import keras_benchmark from official.benchmark import keras_benchmark
from official.resnet.keras import keras_cifar_main from official.vision.image_classification import resnet_cifar_main
MIN_TOP_1_ACCURACY = 0.929 MIN_TOP_1_ACCURACY = 0.929
MAX_TOP_1_ACCURACY = 0.938 MAX_TOP_1_ACCURACY = 0.938
...@@ -47,7 +47,7 @@ class Resnet56KerasAccuracy(keras_benchmark.KerasBenchmark): ...@@ -47,7 +47,7 @@ class Resnet56KerasAccuracy(keras_benchmark.KerasBenchmark):
""" """
self.data_dir = os.path.join(root_data_dir, CIFAR_DATA_DIR_NAME) self.data_dir = os.path.join(root_data_dir, CIFAR_DATA_DIR_NAME)
flag_methods = [keras_cifar_main.define_cifar_flags] flag_methods = [resnet_cifar_main.define_cifar_flags]
super(Resnet56KerasAccuracy, self).__init__( super(Resnet56KerasAccuracy, self).__init__(
output_dir=output_dir, flag_methods=flag_methods) output_dir=output_dir, flag_methods=flag_methods)
...@@ -199,7 +199,7 @@ class Resnet56KerasAccuracy(keras_benchmark.KerasBenchmark): ...@@ -199,7 +199,7 @@ class Resnet56KerasAccuracy(keras_benchmark.KerasBenchmark):
def _run_and_report_benchmark(self): def _run_and_report_benchmark(self):
start_time_sec = time.time() start_time_sec = time.time()
stats = keras_cifar_main.run(FLAGS) stats = resnet_cifar_main.run(FLAGS)
wall_time_sec = time.time() - start_time_sec wall_time_sec = time.time() - start_time_sec
super(Resnet56KerasAccuracy, self)._report_benchmark( super(Resnet56KerasAccuracy, self)._report_benchmark(
...@@ -215,7 +215,7 @@ class Resnet56KerasBenchmarkBase(keras_benchmark.KerasBenchmark): ...@@ -215,7 +215,7 @@ class Resnet56KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
"""Short performance tests for ResNet56 via Keras and CIFAR-10.""" """Short performance tests for ResNet56 via Keras and CIFAR-10."""
def __init__(self, output_dir=None, default_flags=None): def __init__(self, output_dir=None, default_flags=None):
flag_methods = [keras_cifar_main.define_cifar_flags] flag_methods = [resnet_cifar_main.define_cifar_flags]
super(Resnet56KerasBenchmarkBase, self).__init__( super(Resnet56KerasBenchmarkBase, self).__init__(
output_dir=output_dir, output_dir=output_dir,
...@@ -224,7 +224,7 @@ class Resnet56KerasBenchmarkBase(keras_benchmark.KerasBenchmark): ...@@ -224,7 +224,7 @@ class Resnet56KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
def _run_and_report_benchmark(self): def _run_and_report_benchmark(self):
start_time_sec = time.time() start_time_sec = time.time()
stats = keras_cifar_main.run(FLAGS) stats = resnet_cifar_main.run(FLAGS)
wall_time_sec = time.time() - start_time_sec wall_time_sec = time.time() - start_time_sec
super(Resnet56KerasBenchmarkBase, self)._report_benchmark( super(Resnet56KerasBenchmarkBase, self)._report_benchmark(
......
...@@ -21,8 +21,8 @@ import time ...@@ -21,8 +21,8 @@ import time
from absl import flags from absl import flags
import tensorflow as tf # pylint: disable=g-bad-import-order import tensorflow as tf # pylint: disable=g-bad-import-order
from official.resnet.keras import keras_benchmark from official.benchmark import keras_benchmark
from official.resnet.keras import keras_imagenet_main from official.vision.image_classification import resnet_imagenet_main
MIN_TOP_1_ACCURACY = 0.76 MIN_TOP_1_ACCURACY = 0.76
MAX_TOP_1_ACCURACY = 0.77 MAX_TOP_1_ACCURACY = 0.77
...@@ -44,7 +44,7 @@ class Resnet50KerasAccuracy(keras_benchmark.KerasBenchmark): ...@@ -44,7 +44,7 @@ class Resnet50KerasAccuracy(keras_benchmark.KerasBenchmark):
named arguments before updating the constructor. named arguments before updating the constructor.
""" """
flag_methods = [keras_imagenet_main.define_imagenet_keras_flags] flag_methods = [resnet_imagenet_main.define_imagenet_keras_flags]
self.data_dir = os.path.join(root_data_dir, 'imagenet') self.data_dir = os.path.join(root_data_dir, 'imagenet')
super(Resnet50KerasAccuracy, self).__init__( super(Resnet50KerasAccuracy, self).__init__(
...@@ -158,7 +158,7 @@ class Resnet50KerasAccuracy(keras_benchmark.KerasBenchmark): ...@@ -158,7 +158,7 @@ class Resnet50KerasAccuracy(keras_benchmark.KerasBenchmark):
top_1_min=MIN_TOP_1_ACCURACY, top_1_min=MIN_TOP_1_ACCURACY,
top_1_max=MAX_TOP_1_ACCURACY): top_1_max=MAX_TOP_1_ACCURACY):
start_time_sec = time.time() start_time_sec = time.time()
stats = keras_imagenet_main.run(flags.FLAGS) stats = resnet_imagenet_main.run(flags.FLAGS)
wall_time_sec = time.time() - start_time_sec wall_time_sec = time.time() - start_time_sec
super(Resnet50KerasAccuracy, self)._report_benchmark( super(Resnet50KerasAccuracy, self)._report_benchmark(
...@@ -177,7 +177,7 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark): ...@@ -177,7 +177,7 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
"""Resnet50 benchmarks.""" """Resnet50 benchmarks."""
def __init__(self, output_dir=None, default_flags=None): def __init__(self, output_dir=None, default_flags=None):
flag_methods = [keras_imagenet_main.define_imagenet_keras_flags] flag_methods = [resnet_imagenet_main.define_imagenet_keras_flags]
super(Resnet50KerasBenchmarkBase, self).__init__( super(Resnet50KerasBenchmarkBase, self).__init__(
output_dir=output_dir, output_dir=output_dir,
...@@ -186,7 +186,7 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark): ...@@ -186,7 +186,7 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
def _run_and_report_benchmark(self): def _run_and_report_benchmark(self):
start_time_sec = time.time() start_time_sec = time.time()
stats = keras_imagenet_main.run(FLAGS) stats = resnet_imagenet_main.run(FLAGS)
wall_time_sec = time.time() - start_time_sec wall_time_sec = time.time() - start_time_sec
# Number of logged step time entries that are excluded in performance # Number of logged step time entries that are excluded in performance
# report. We keep results from last 100 batches in this case. # report. We keep results from last 100 batches in this case.
...@@ -779,7 +779,7 @@ class TrivialKerasBenchmarkReal(keras_benchmark.KerasBenchmark): ...@@ -779,7 +779,7 @@ class TrivialKerasBenchmarkReal(keras_benchmark.KerasBenchmark):
"""Trivial model with real data benchmark tests.""" """Trivial model with real data benchmark tests."""
def __init__(self, output_dir=None, root_data_dir=None, **kwargs): def __init__(self, output_dir=None, root_data_dir=None, **kwargs):
flag_methods = [keras_imagenet_main.define_imagenet_keras_flags] flag_methods = [resnet_imagenet_main.define_imagenet_keras_flags]
def_flags = {} def_flags = {}
def_flags['use_trivial_model'] = True def_flags['use_trivial_model'] = True
...@@ -799,7 +799,7 @@ class TrivialKerasBenchmarkReal(keras_benchmark.KerasBenchmark): ...@@ -799,7 +799,7 @@ class TrivialKerasBenchmarkReal(keras_benchmark.KerasBenchmark):
def _run_and_report_benchmark(self): def _run_and_report_benchmark(self):
start_time_sec = time.time() start_time_sec = time.time()
stats = keras_imagenet_main.run(FLAGS) stats = resnet_imagenet_main.run(FLAGS)
wall_time_sec = time.time() - start_time_sec wall_time_sec = time.time() - start_time_sec
super(TrivialKerasBenchmarkReal, self)._report_benchmark( super(TrivialKerasBenchmarkReal, self)._report_benchmark(
......
# ResNet in TensorFlow # ResNet in TensorFlow
* For the Keras version of the ResNet model, see * For the Keras version of the ResNet model, see
[`official/resnet/keras`](keras). [`official/vision/image_classification`](../vision/image_classification).
* For the Keras custom training loop version, see * For the Keras custom training loop version, see
[`official/resnet/ctl`](ctl). [`official/resnet/ctl`](ctl).
* For the Estimator version, see [`official/r1/resnet`](../r1/resnet). * For the Estimator version, see [`official/r1/resnet`](../r1/resnet).
# Copyright 2015 The TensorFlow Authors. All Rights Reserved. # Copyright 2019 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Bring in the shared Keras ResNet modules into this module.
The TensorFlow official Keras models are moved under
official/vision/image_classification
In order to be backward compatible with models that directly import its modules,
we import the Keras ResNet modules under official.resnet.keras.
New TF models should not depend on modules directly under this path.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from official.vision.image_classification import cifar_preprocessing
from official.vision.image_classification import common as keras_common
from official.vision.image_classification import imagenet_preprocessing
from official.vision.image_classification import resnet_cifar_main as keras_cifar_main
from official.vision.image_classification import resnet_cifar_model
from official.vision.image_classification import resnet_imagenet_main as keras_imagenet_main
from official.vision.image_classification import resnet_model
del absolute_import
del division
del print_function
This folder contains the Keras implementation of the ResNet models. For more This folder contains the Keras implementation of the ResNet models. For more
information about the models, please refer to this [README file](../README.md). information about the models, please refer to this [README file](../../README.md).
Similar to the [estimator implementation](/official/resnet), the Keras Similar to the [estimator implementation](../../r1/resnet), the Keras
implementation has code for both CIFAR-10 data and ImageNet data. The CIFAR-10 implementation has code for both CIFAR-10 data and ImageNet data. The CIFAR-10
version uses a ResNet56 model implemented in version uses a ResNet56 model implemented in
[`resnet_cifar_model.py`](./resnet_cifar_model.py), and the ImageNet version [`resnet_cifar_model.py`](./resnet_cifar_model.py), and the ImageNet version
......
...@@ -22,7 +22,7 @@ import os ...@@ -22,7 +22,7 @@ import os
from absl import logging from absl import logging
import tensorflow as tf import tensorflow as tf
from official.resnet.keras import imagenet_preprocessing from official.vision.image_classification import imagenet_preprocessing
HEIGHT = 32 HEIGHT = 32
WIDTH = 32 WIDTH = 32
......
...@@ -20,17 +20,13 @@ from __future__ import print_function ...@@ -20,17 +20,13 @@ from __future__ import print_function
import multiprocessing import multiprocessing
import os import os
import numpy as np
# pylint: disable=g-bad-import-order
from absl import flags from absl import flags
import numpy as np
import tensorflow as tf import tensorflow as tf
from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_v2
from official.utils.flags import core as flags_core from official.utils.flags import core as flags_core
from official.utils.misc import keras_utils from official.utils.misc import keras_utils
# pylint: disable=ungrouped-imports
from tensorflow.python.keras.optimizer_v2 import (gradient_descent as
gradient_descent_v2)
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
BASE_LEARNING_RATE = 0.1 # This matches Jing's version. BASE_LEARNING_RATE = 0.1 # This matches Jing's version.
......
...@@ -12,21 +12,21 @@ ...@@ -12,21 +12,21 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Tests for the keras_common module.""" """Tests for the common module."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import print_function from __future__ import print_function
from mock import Mock from mock import Mock
import numpy as np import numpy as np
import tensorflow as tf # pylint: disable=g-bad-import-order import tensorflow as tf
from tensorflow.python.platform import googletest
from official.resnet.keras import keras_common from tensorflow.python.platform import googletest
from official.utils.misc import keras_utils from official.utils.misc import keras_utils
from official.vision.image_classification import common
class KerasCommonTests(tf.test.TestCase): class KerasCommonTests(tf.test.TestCase):
"""Tests for keras_common.""" """Tests for common."""
@classmethod @classmethod
def setUpClass(cls): # pylint: disable=invalid-name def setUpClass(cls): # pylint: disable=invalid-name
...@@ -42,7 +42,7 @@ class KerasCommonTests(tf.test.TestCase): ...@@ -42,7 +42,7 @@ class KerasCommonTests(tf.test.TestCase):
keras_utils.BatchTimestamp(1, 2), keras_utils.BatchTimestamp(1, 2),
keras_utils.BatchTimestamp(2, 3)] keras_utils.BatchTimestamp(2, 3)]
th.train_finish_time = 12345 th.train_finish_time = 12345
stats = keras_common.build_stats(history, eval_output, [th]) stats = common.build_stats(history, eval_output, [th])
self.assertEqual(1.145, stats['loss']) self.assertEqual(1.145, stats['loss'])
self.assertEqual(.99988, stats['training_accuracy_top_1']) self.assertEqual(.99988, stats['training_accuracy_top_1'])
...@@ -57,7 +57,7 @@ class KerasCommonTests(tf.test.TestCase): ...@@ -57,7 +57,7 @@ class KerasCommonTests(tf.test.TestCase):
history = self._build_history(1.145, cat_accuracy_sparse=.99988) history = self._build_history(1.145, cat_accuracy_sparse=.99988)
eval_output = self._build_eval_output(.928, 1.9844) eval_output = self._build_eval_output(.928, 1.9844)
stats = keras_common.build_stats(history, eval_output, None) stats = common.build_stats(history, eval_output, None)
self.assertEqual(1.145, stats['loss']) self.assertEqual(1.145, stats['loss'])
self.assertEqual(.99988, stats['training_accuracy_top_1']) self.assertEqual(.99988, stats['training_accuracy_top_1'])
......
...@@ -22,13 +22,13 @@ from absl import app as absl_app ...@@ -22,13 +22,13 @@ from absl import app as absl_app
from absl import flags from absl import flags
import tensorflow as tf import tensorflow as tf
from official.resnet.keras import cifar_preprocessing
from official.resnet.keras import keras_common
from official.resnet.keras import resnet_cifar_model
from official.utils.flags import core as flags_core from official.utils.flags import core as flags_core
from official.utils.logs import logger from official.utils.logs import logger
from official.utils.misc import distribution_utils from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils from official.utils.misc import keras_utils
from official.vision.image_classification import cifar_preprocessing
from official.vision.image_classification import common
from official.vision.image_classification import resnet_cifar_model
LR_SCHEDULE = [ # (multiplier, epoch to start) tuples LR_SCHEDULE = [ # (multiplier, epoch to start) tuples
...@@ -55,7 +55,7 @@ def learning_rate_schedule(current_epoch, ...@@ -55,7 +55,7 @@ def learning_rate_schedule(current_epoch,
Adjusted learning rate. Adjusted learning rate.
""" """
del current_batch, batches_per_epoch # not used del current_batch, batches_per_epoch # not used
initial_learning_rate = keras_common.BASE_LEARNING_RATE * batch_size / 128 initial_learning_rate = common.BASE_LEARNING_RATE * batch_size / 128
learning_rate = initial_learning_rate learning_rate = initial_learning_rate
for mult, start_epoch in LR_SCHEDULE: for mult, start_epoch in LR_SCHEDULE:
if current_epoch >= start_epoch: if current_epoch >= start_epoch:
...@@ -83,8 +83,8 @@ def run(flags_obj): ...@@ -83,8 +83,8 @@ def run(flags_obj):
# Execute flag override logic for better model performance # Execute flag override logic for better model performance
if flags_obj.tf_gpu_thread_mode: if flags_obj.tf_gpu_thread_mode:
keras_common.set_gpu_thread_mode_and_count(flags_obj) common.set_gpu_thread_mode_and_count(flags_obj)
keras_common.set_cudnn_batchnorm_mode() common.set_cudnn_batchnorm_mode()
dtype = flags_core.get_tf_dtype(flags_obj) dtype = flags_core.get_tf_dtype(flags_obj)
if dtype == 'fp16': if dtype == 'fp16':
...@@ -116,7 +116,7 @@ def run(flags_obj): ...@@ -116,7 +116,7 @@ def run(flags_obj):
if flags_obj.use_synthetic_data: if flags_obj.use_synthetic_data:
distribution_utils.set_up_synthetic_data() distribution_utils.set_up_synthetic_data()
input_fn = keras_common.get_synth_input_fn( input_fn = common.get_synth_input_fn(
height=cifar_preprocessing.HEIGHT, height=cifar_preprocessing.HEIGHT,
width=cifar_preprocessing.WIDTH, width=cifar_preprocessing.WIDTH,
num_channels=cifar_preprocessing.NUM_CHANNELS, num_channels=cifar_preprocessing.NUM_CHANNELS,
...@@ -150,7 +150,7 @@ def run(flags_obj): ...@@ -150,7 +150,7 @@ def run(flags_obj):
parse_record_fn=cifar_preprocessing.parse_record) parse_record_fn=cifar_preprocessing.parse_record)
with strategy_scope: with strategy_scope:
optimizer = keras_common.get_optimizer() optimizer = common.get_optimizer()
model = resnet_cifar_model.resnet56(classes=cifar_preprocessing.NUM_CLASSES) model = resnet_cifar_model.resnet56(classes=cifar_preprocessing.NUM_CLASSES)
# TODO(b/138957587): Remove when force_v2_in_keras_compile is on longer # TODO(b/138957587): Remove when force_v2_in_keras_compile is on longer
...@@ -171,7 +171,7 @@ def run(flags_obj): ...@@ -171,7 +171,7 @@ def run(flags_obj):
if flags_obj.report_accuracy_metrics else None), if flags_obj.report_accuracy_metrics else None),
run_eagerly=flags_obj.run_eagerly) run_eagerly=flags_obj.run_eagerly)
callbacks = keras_common.get_callbacks( callbacks = common.get_callbacks(
learning_rate_schedule, cifar_preprocessing.NUM_IMAGES['train']) learning_rate_schedule, cifar_preprocessing.NUM_IMAGES['train'])
train_steps = cifar_preprocessing.NUM_IMAGES['train'] // flags_obj.batch_size train_steps = cifar_preprocessing.NUM_IMAGES['train'] // flags_obj.batch_size
...@@ -216,12 +216,12 @@ def run(flags_obj): ...@@ -216,12 +216,12 @@ def run(flags_obj):
if not strategy and flags_obj.explicit_gpu_placement: if not strategy and flags_obj.explicit_gpu_placement:
no_dist_strat_device.__exit__() no_dist_strat_device.__exit__()
stats = keras_common.build_stats(history, eval_output, callbacks) stats = common.build_stats(history, eval_output, callbacks)
return stats return stats
def define_cifar_flags(): def define_cifar_flags():
keras_common.define_keras_flags(dynamic_loss_scale=False) common.define_keras_flags(dynamic_loss_scale=False)
flags_core.set_defaults(data_dir='/tmp/cifar10_data/cifar-10-batches-bin', flags_core.set_defaults(data_dir='/tmp/cifar10_data/cifar-10-batches-bin',
model_dir='/tmp/cifar10_model', model_dir='/tmp/cifar10_model',
......
...@@ -18,17 +18,16 @@ from __future__ import absolute_import ...@@ -18,17 +18,16 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from tempfile import mkdtemp import tempfile
import tensorflow as tf import tensorflow as tf
from official.resnet.keras import cifar_preprocessing
from official.resnet.keras import keras_cifar_main
from official.resnet.keras import keras_common
from official.utils.misc import keras_utils
from official.utils.testing import integration
# pylint: disable=ungrouped-imports
from tensorflow.python.eager import context from tensorflow.python.eager import context
from tensorflow.python.platform import googletest from tensorflow.python.platform import googletest
from official.utils.misc import keras_utils
from official.utils.testing import integration
from official.vision.image_classification import cifar_preprocessing
from official.vision.image_classification import resnet_cifar_main
class KerasCifarTest(googletest.TestCase): class KerasCifarTest(googletest.TestCase):
...@@ -43,13 +42,13 @@ class KerasCifarTest(googletest.TestCase): ...@@ -43,13 +42,13 @@ class KerasCifarTest(googletest.TestCase):
def get_temp_dir(self): def get_temp_dir(self):
if not self._tempdir: if not self._tempdir:
self._tempdir = mkdtemp(dir=googletest.GetTempDir()) self._tempdir = tempfile.mkdtemp(dir=googletest.GetTempDir())
return self._tempdir return self._tempdir
@classmethod @classmethod
def setUpClass(cls): # pylint: disable=invalid-name def setUpClass(cls): # pylint: disable=invalid-name
super(KerasCifarTest, cls).setUpClass() super(KerasCifarTest, cls).setUpClass()
keras_cifar_main.define_cifar_flags() resnet_cifar_main.define_cifar_flags()
def setUp(self): def setUp(self):
super(KerasCifarTest, self).setUp() super(KerasCifarTest, self).setUp()
...@@ -72,7 +71,7 @@ class KerasCifarTest(googletest.TestCase): ...@@ -72,7 +71,7 @@ class KerasCifarTest(googletest.TestCase):
extra_flags = extra_flags + self._extra_flags extra_flags = extra_flags + self._extra_flags
integration.run_synthetic( integration.run_synthetic(
main=keras_cifar_main.run, main=resnet_cifar_main.run,
tmp_root=self.get_temp_dir(), tmp_root=self.get_temp_dir(),
extra_flags=extra_flags extra_flags=extra_flags
) )
...@@ -88,7 +87,7 @@ class KerasCifarTest(googletest.TestCase): ...@@ -88,7 +87,7 @@ class KerasCifarTest(googletest.TestCase):
extra_flags = extra_flags + self._extra_flags extra_flags = extra_flags + self._extra_flags
integration.run_synthetic( integration.run_synthetic(
main=keras_cifar_main.run, main=resnet_cifar_main.run,
tmp_root=self.get_temp_dir(), tmp_root=self.get_temp_dir(),
extra_flags=extra_flags extra_flags=extra_flags
) )
...@@ -112,7 +111,7 @@ class KerasCifarTest(googletest.TestCase): ...@@ -112,7 +111,7 @@ class KerasCifarTest(googletest.TestCase):
extra_flags = extra_flags + self._extra_flags extra_flags = extra_flags + self._extra_flags
integration.run_synthetic( integration.run_synthetic(
main=keras_cifar_main.run, main=resnet_cifar_main.run,
tmp_root=self.get_temp_dir(), tmp_root=self.get_temp_dir(),
extra_flags=extra_flags extra_flags=extra_flags
) )
...@@ -134,7 +133,7 @@ class KerasCifarTest(googletest.TestCase): ...@@ -134,7 +133,7 @@ class KerasCifarTest(googletest.TestCase):
extra_flags = extra_flags + self._extra_flags extra_flags = extra_flags + self._extra_flags
integration.run_synthetic( integration.run_synthetic(
main=keras_cifar_main.run, main=resnet_cifar_main.run,
tmp_root=self.get_temp_dir(), tmp_root=self.get_temp_dir(),
extra_flags=extra_flags extra_flags=extra_flags
) )
...@@ -157,7 +156,7 @@ class KerasCifarTest(googletest.TestCase): ...@@ -157,7 +156,7 @@ class KerasCifarTest(googletest.TestCase):
extra_flags = extra_flags + self._extra_flags extra_flags = extra_flags + self._extra_flags
integration.run_synthetic( integration.run_synthetic(
main=keras_cifar_main.run, main=resnet_cifar_main.run,
tmp_root=self.get_temp_dir(), tmp_root=self.get_temp_dir(),
extra_flags=extra_flags extra_flags=extra_flags
) )
...@@ -178,7 +177,7 @@ class KerasCifarTest(googletest.TestCase): ...@@ -178,7 +177,7 @@ class KerasCifarTest(googletest.TestCase):
extra_flags = extra_flags + self._extra_flags extra_flags = extra_flags + self._extra_flags
integration.run_synthetic( integration.run_synthetic(
main=keras_cifar_main.run, main=resnet_cifar_main.run,
tmp_root=self.get_temp_dir(), tmp_root=self.get_temp_dir(),
extra_flags=extra_flags extra_flags=extra_flags
) )
......
...@@ -21,17 +21,17 @@ from __future__ import print_function ...@@ -21,17 +21,17 @@ from __future__ import print_function
from absl import app as absl_app from absl import app as absl_app
from absl import flags from absl import flags
from absl import logging from absl import logging
import tensorflow as tf # pylint: disable=g-bad-import-order import tensorflow as tf
from official.resnet.keras import imagenet_preprocessing
from official.resnet.keras import keras_common
from official.resnet.keras import resnet_model
from official.resnet.keras import trivial_model
from official.utils.flags import core as flags_core from official.utils.flags import core as flags_core
from official.utils.logs import logger from official.utils.logs import logger
from official.utils.misc import distribution_utils from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils from official.utils.misc import keras_utils
from official.utils.misc import model_helpers from official.utils.misc import model_helpers
from official.vision.image_classification import common
from official.vision.image_classification import imagenet_preprocessing
from official.vision.image_classification import resnet_model
from official.vision.image_classification import trivial_model
LR_SCHEDULE = [ # (multiplier, epoch to start) tuples LR_SCHEDULE = [ # (multiplier, epoch to start) tuples
...@@ -57,7 +57,7 @@ def learning_rate_schedule(current_epoch, ...@@ -57,7 +57,7 @@ def learning_rate_schedule(current_epoch,
Returns: Returns:
Adjusted learning rate. Adjusted learning rate.
""" """
initial_lr = keras_common.BASE_LEARNING_RATE * batch_size / 256 initial_lr = common.BASE_LEARNING_RATE * batch_size / 256
epoch = current_epoch + float(current_batch) / batches_per_epoch epoch = current_epoch + float(current_batch) / batches_per_epoch
warmup_lr_multiplier, warmup_end_epoch = LR_SCHEDULE[0] warmup_lr_multiplier, warmup_end_epoch = LR_SCHEDULE[0]
if epoch < warmup_end_epoch: if epoch < warmup_end_epoch:
...@@ -89,10 +89,10 @@ def run(flags_obj): ...@@ -89,10 +89,10 @@ def run(flags_obj):
# Execute flag override logic for better model performance # Execute flag override logic for better model performance
if flags_obj.tf_gpu_thread_mode: if flags_obj.tf_gpu_thread_mode:
keras_common.set_gpu_thread_mode_and_count(flags_obj) common.set_gpu_thread_mode_and_count(flags_obj)
if flags_obj.data_delay_prefetch: if flags_obj.data_delay_prefetch:
keras_common.data_delay_prefetch() common.data_delay_prefetch()
keras_common.set_cudnn_batchnorm_mode() common.set_cudnn_batchnorm_mode()
dtype = flags_core.get_tf_dtype(flags_obj) dtype = flags_core.get_tf_dtype(flags_obj)
if dtype == 'float16': if dtype == 'float16':
...@@ -129,7 +129,7 @@ def run(flags_obj): ...@@ -129,7 +129,7 @@ def run(flags_obj):
# pylint: disable=protected-access # pylint: disable=protected-access
if flags_obj.use_synthetic_data: if flags_obj.use_synthetic_data:
distribution_utils.set_up_synthetic_data() distribution_utils.set_up_synthetic_data()
input_fn = keras_common.get_synth_input_fn( input_fn = common.get_synth_input_fn(
height=imagenet_preprocessing.DEFAULT_IMAGE_SIZE, height=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
width=imagenet_preprocessing.DEFAULT_IMAGE_SIZE, width=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
num_channels=imagenet_preprocessing.NUM_CHANNELS, num_channels=imagenet_preprocessing.NUM_CHANNELS,
...@@ -169,7 +169,7 @@ def run(flags_obj): ...@@ -169,7 +169,7 @@ def run(flags_obj):
lr_schedule = 0.1 lr_schedule = 0.1
if flags_obj.use_tensor_lr: if flags_obj.use_tensor_lr:
lr_schedule = keras_common.PiecewiseConstantDecayWithWarmup( lr_schedule = common.PiecewiseConstantDecayWithWarmup(
batch_size=flags_obj.batch_size, batch_size=flags_obj.batch_size,
epoch_size=imagenet_preprocessing.NUM_IMAGES['train'], epoch_size=imagenet_preprocessing.NUM_IMAGES['train'],
warmup_epochs=LR_SCHEDULE[0][1], warmup_epochs=LR_SCHEDULE[0][1],
...@@ -178,7 +178,7 @@ def run(flags_obj): ...@@ -178,7 +178,7 @@ def run(flags_obj):
compute_lr_on_cpu=True) compute_lr_on_cpu=True)
with strategy_scope: with strategy_scope:
optimizer = keras_common.get_optimizer(lr_schedule) optimizer = common.get_optimizer(lr_schedule)
if dtype == 'float16': if dtype == 'float16':
# TODO(reedwm): Remove manually wrapping optimizer once mixed precision # TODO(reedwm): Remove manually wrapping optimizer once mixed precision
# can be enabled with a single line of code. # can be enabled with a single line of code.
...@@ -211,7 +211,7 @@ def run(flags_obj): ...@@ -211,7 +211,7 @@ def run(flags_obj):
if flags_obj.report_accuracy_metrics else None), if flags_obj.report_accuracy_metrics else None),
run_eagerly=flags_obj.run_eagerly) run_eagerly=flags_obj.run_eagerly)
callbacks = keras_common.get_callbacks( callbacks = common.get_callbacks(
learning_rate_schedule, imagenet_preprocessing.NUM_IMAGES['train']) learning_rate_schedule, imagenet_preprocessing.NUM_IMAGES['train'])
train_steps = ( train_steps = (
...@@ -261,14 +261,14 @@ def run(flags_obj): ...@@ -261,14 +261,14 @@ def run(flags_obj):
if not strategy and flags_obj.explicit_gpu_placement: if not strategy and flags_obj.explicit_gpu_placement:
no_dist_strat_device.__exit__() no_dist_strat_device.__exit__()
stats = keras_common.build_stats(history, eval_output, callbacks) stats = common.build_stats(history, eval_output, callbacks)
return stats return stats
def define_imagenet_keras_flags(): def define_imagenet_keras_flags():
keras_common.define_keras_flags() common.define_keras_flags()
flags_core.set_defaults(train_epochs=90) flags_core.set_defaults(train_epochs=90)
flags.adopt_module_key_flags(keras_common) flags.adopt_module_key_flags(common)
def main(_): def main(_):
......
...@@ -18,16 +18,16 @@ from __future__ import absolute_import ...@@ -18,16 +18,16 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from tempfile import mkdtemp import tempfile
import tensorflow as tf import tensorflow as tf
from official.resnet.keras import imagenet_preprocessing
from official.resnet.keras import keras_imagenet_main
from official.utils.misc import keras_utils
from official.utils.testing import integration
# pylint: disable=ungrouped-imports
from tensorflow.python.eager import context from tensorflow.python.eager import context
from tensorflow.python.platform import googletest from tensorflow.python.platform import googletest
from official.utils.misc import keras_utils
from official.utils.testing import integration
from official.vision.image_classification import imagenet_preprocessing
from official.vision.image_classification import resnet_imagenet_main
class KerasImagenetTest(googletest.TestCase): class KerasImagenetTest(googletest.TestCase):
...@@ -42,13 +42,13 @@ class KerasImagenetTest(googletest.TestCase): ...@@ -42,13 +42,13 @@ class KerasImagenetTest(googletest.TestCase):
def get_temp_dir(self): def get_temp_dir(self):
if not self._tempdir: if not self._tempdir:
self._tempdir = mkdtemp(dir=googletest.GetTempDir()) self._tempdir = tempfile.mkdtemp(dir=googletest.GetTempDir())
return self._tempdir return self._tempdir
@classmethod @classmethod
def setUpClass(cls): # pylint: disable=invalid-name def setUpClass(cls): # pylint: disable=invalid-name
super(KerasImagenetTest, cls).setUpClass() super(KerasImagenetTest, cls).setUpClass()
keras_imagenet_main.define_imagenet_keras_flags() resnet_imagenet_main.define_imagenet_keras_flags()
def setUp(self): def setUp(self):
super(KerasImagenetTest, self).setUp() super(KerasImagenetTest, self).setUp()
...@@ -71,7 +71,7 @@ class KerasImagenetTest(googletest.TestCase): ...@@ -71,7 +71,7 @@ class KerasImagenetTest(googletest.TestCase):
extra_flags = extra_flags + self._extra_flags extra_flags = extra_flags + self._extra_flags
integration.run_synthetic( integration.run_synthetic(
main=keras_imagenet_main.run, main=resnet_imagenet_main.run,
tmp_root=self.get_temp_dir(), tmp_root=self.get_temp_dir(),
extra_flags=extra_flags extra_flags=extra_flags
) )
...@@ -87,7 +87,7 @@ class KerasImagenetTest(googletest.TestCase): ...@@ -87,7 +87,7 @@ class KerasImagenetTest(googletest.TestCase):
extra_flags = extra_flags + self._extra_flags extra_flags = extra_flags + self._extra_flags
integration.run_synthetic( integration.run_synthetic(
main=keras_imagenet_main.run, main=resnet_imagenet_main.run,
tmp_root=self.get_temp_dir(), tmp_root=self.get_temp_dir(),
extra_flags=extra_flags extra_flags=extra_flags
) )
...@@ -111,7 +111,7 @@ class KerasImagenetTest(googletest.TestCase): ...@@ -111,7 +111,7 @@ class KerasImagenetTest(googletest.TestCase):
extra_flags = extra_flags + self._extra_flags extra_flags = extra_flags + self._extra_flags
integration.run_synthetic( integration.run_synthetic(
main=keras_imagenet_main.run, main=resnet_imagenet_main.run,
tmp_root=self.get_temp_dir(), tmp_root=self.get_temp_dir(),
extra_flags=extra_flags extra_flags=extra_flags
) )
...@@ -133,7 +133,7 @@ class KerasImagenetTest(googletest.TestCase): ...@@ -133,7 +133,7 @@ class KerasImagenetTest(googletest.TestCase):
extra_flags = extra_flags + self._extra_flags extra_flags = extra_flags + self._extra_flags
integration.run_synthetic( integration.run_synthetic(
main=keras_imagenet_main.run, main=resnet_imagenet_main.run,
tmp_root=self.get_temp_dir(), tmp_root=self.get_temp_dir(),
extra_flags=extra_flags extra_flags=extra_flags
) )
...@@ -156,7 +156,7 @@ class KerasImagenetTest(googletest.TestCase): ...@@ -156,7 +156,7 @@ class KerasImagenetTest(googletest.TestCase):
extra_flags = extra_flags + self._extra_flags extra_flags = extra_flags + self._extra_flags
integration.run_synthetic( integration.run_synthetic(
main=keras_imagenet_main.run, main=resnet_imagenet_main.run,
tmp_root=self.get_temp_dir(), tmp_root=self.get_temp_dir(),
extra_flags=extra_flags extra_flags=extra_flags
) )
...@@ -180,7 +180,7 @@ class KerasImagenetTest(googletest.TestCase): ...@@ -180,7 +180,7 @@ class KerasImagenetTest(googletest.TestCase):
extra_flags = extra_flags + self._extra_flags extra_flags = extra_flags + self._extra_flags
integration.run_synthetic( integration.run_synthetic(
main=keras_imagenet_main.run, main=resnet_imagenet_main.run,
tmp_root=self.get_temp_dir(), tmp_root=self.get_temp_dir(),
extra_flags=extra_flags extra_flags=extra_flags
) )
...@@ -204,7 +204,7 @@ class KerasImagenetTest(googletest.TestCase): ...@@ -204,7 +204,7 @@ class KerasImagenetTest(googletest.TestCase):
extra_flags = extra_flags + self._extra_flags extra_flags = extra_flags + self._extra_flags
integration.run_synthetic( integration.run_synthetic(
main=keras_imagenet_main.run, main=resnet_imagenet_main.run,
tmp_root=self.get_temp_dir(), tmp_root=self.get_temp_dir(),
extra_flags=extra_flags extra_flags=extra_flags
) )
...@@ -229,7 +229,7 @@ class KerasImagenetTest(googletest.TestCase): ...@@ -229,7 +229,7 @@ class KerasImagenetTest(googletest.TestCase):
extra_flags = extra_flags + self._extra_flags extra_flags = extra_flags + self._extra_flags
integration.run_synthetic( integration.run_synthetic(
main=keras_imagenet_main.run, main=resnet_imagenet_main.run,
tmp_root=self.get_temp_dir(), tmp_root=self.get_temp_dir(),
extra_flags=extra_flags extra_flags=extra_flags
) )
...@@ -250,7 +250,7 @@ class KerasImagenetTest(googletest.TestCase): ...@@ -250,7 +250,7 @@ class KerasImagenetTest(googletest.TestCase):
extra_flags = extra_flags + self._extra_flags extra_flags = extra_flags + self._extra_flags
integration.run_synthetic( integration.run_synthetic(
main=keras_imagenet_main.run, main=resnet_imagenet_main.run,
tmp_root=self.get_temp_dir(), tmp_root=self.get_temp_dir(),
extra_flags=extra_flags extra_flags=extra_flags
) )
...@@ -272,7 +272,7 @@ class KerasImagenetTest(googletest.TestCase): ...@@ -272,7 +272,7 @@ class KerasImagenetTest(googletest.TestCase):
extra_flags = extra_flags + self._extra_flags extra_flags = extra_flags + self._extra_flags
integration.run_synthetic( integration.run_synthetic(
main=keras_imagenet_main.run, main=resnet_imagenet_main.run,
tmp_root=self.get_temp_dir(), tmp_root=self.get_temp_dir(),
extra_flags=extra_flags extra_flags=extra_flags
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment