Commit f2c61881 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Move CTL resnet example.

PiperOrigin-RevId: 275417626
parent 6cd426d9
......@@ -23,8 +23,7 @@ from absl import flags
import tensorflow as tf
from official.vision.image_classification import common
from official.resnet.ctl import ctl_imagenet_main
from official.resnet.ctl import ctl_common
from official.vision.image_classification import resnet_ctl_imagenet_main
from official.utils.testing.perfzero_benchmark import PerfZeroBenchmark
from official.utils.flags import core as flags_core
......@@ -121,7 +120,7 @@ class Resnet50CtlAccuracy(CtlBenchmark):
arguments before updating the constructor.
"""
flag_methods = [ctl_common.define_ctl_flags, common.define_keras_flags]
flag_methods = [common.define_keras_flags]
self.data_dir = os.path.join(root_data_dir, 'imagenet')
super(Resnet50CtlAccuracy, self).__init__(
......@@ -158,7 +157,7 @@ class Resnet50CtlAccuracy(CtlBenchmark):
def _run_and_report_benchmark(self):
start_time_sec = time.time()
stats = ctl_imagenet_main.run(flags.FLAGS)
stats = resnet_ctl_imagenet_main.run(flags.FLAGS)
wall_time_sec = time.time() - start_time_sec
super(Resnet50CtlAccuracy, self)._report_benchmark(
......@@ -177,7 +176,7 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark):
"""Resnet50 benchmarks."""
def __init__(self, output_dir=None, default_flags=None):
flag_methods = [ctl_common.define_ctl_flags, common.define_keras_flags]
flag_methods = [common.define_keras_flags]
super(Resnet50CtlBenchmarkBase, self).__init__(
output_dir=output_dir,
......@@ -186,7 +185,7 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark):
def _run_and_report_benchmark(self):
start_time_sec = time.time()
stats = ctl_imagenet_main.run(FLAGS)
stats = resnet_ctl_imagenet_main.run(FLAGS)
wall_time_sec = time.time() - start_time_sec
# Number of logged step time entries that are excluded in performance
......
......@@ -2,6 +2,6 @@
* For the Keras version of the ResNet model, see
[`official/vision/image_classification`](../vision/image_classification).
* For the Keras custom training loop version, see
[`official/resnet/ctl`](ctl).
* For the Keras custom training loop version, also see
[`official/vision/image_classification`](../vision/image_classification).
* For the Estimator version, see [`official/r1/resnet`](../r1/resnet).
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Bring in the shared ResNet modules into this module.
The TensorFlow v1 official models are moved under official/r1/resnet. In order
to be backward compatible with models that directly import v1 modules, we import
the v1 ResNet modules under official.resnet.
New TF models should not depend on modules directly under this path (which will
soon be deprecated and removed).
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from official.r1.resnet import cifar10_main
from official.r1.resnet import imagenet_main
from official.r1.resnet import imagenet_preprocessing
from official.r1.resnet import resnet_model
from official.r1.resnet import resnet_run_loop
del absolute_import
del division
del print_function
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Common util functions and classes used by CTL."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import flags
def define_ctl_flags():
"""Define flags for CTL."""
flags.DEFINE_boolean(name='use_tf_function', default=True,
help='Wrap the train and test step inside a '
'tf.function.')
flags.DEFINE_boolean(name='single_l2_loss_op', default=False,
help='Calculate L2_loss on concatenated weights, '
'instead of using Keras per-layer L2 loss.')
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test the ResNet model with ImageNet data using CTL."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tempfile import mkdtemp
import tensorflow as tf
from tensorflow.python.platform import googletest
from official.resnet.ctl import ctl_common
from official.resnet.ctl import ctl_imagenet_main
from official.vision.image_classification import imagenet_preprocessing
from official.vision.image_classification import common
from official.utils.misc import keras_utils
from official.utils.testing import integration
class CtlImagenetTest(googletest.TestCase):
"""Unit tests for Keras ResNet with ImageNet using CTL."""
_extra_flags = [
'-batch_size', '4',
'-train_steps', '4',
'-use_synthetic_data', 'true'
]
_tempdir = None
def get_temp_dir(self):
if not self._tempdir:
self._tempdir = mkdtemp(dir=googletest.GetTempDir())
return self._tempdir
@classmethod
def setUpClass(cls): # pylint: disable=invalid-name
super(CtlImagenetTest, cls).setUpClass()
common.define_keras_flags()
ctl_common.define_ctl_flags()
def setUp(self):
super(CtlImagenetTest, self).setUp()
if not keras_utils.is_v2_0():
tf.compat.v1.enable_v2_behavior()
imagenet_preprocessing.NUM_IMAGES['validation'] = 4
def tearDown(self):
super(CtlImagenetTest, self).tearDown()
tf.io.gfile.rmtree(self.get_temp_dir())
def test_end_to_end_tpu(self):
"""Test Keras model with TPU distribution strategy."""
extra_flags = [
'-distribution_strategy', 'tpu',
'-model_dir', 'ctl_imagenet_tpu_dist_strat',
'-data_format', 'channels_last',
'-use_tf_function', 'true',
'-single_l2_loss_op', 'true',
]
extra_flags = extra_flags + self._extra_flags
integration.run_synthetic(
main=ctl_imagenet_main.run,
tmp_root=self.get_temp_dir(),
extra_flags=extra_flags
)
def test_end_to_end_tpu_bf16(self):
"""Test Keras model with TPU and bfloat16 activation."""
extra_flags = [
'-distribution_strategy', 'tpu',
'-model_dir', 'ctl_imagenet_tpu_dist_strat_bf16',
'-data_format', 'channels_last',
'-use_tf_function', 'true',
'-single_l2_loss_op', 'true',
'-dtype', 'bf16',
]
extra_flags = extra_flags + self._extra_flags
integration.run_synthetic(
main=ctl_imagenet_main.run,
tmp_root=self.get_temp_dir(),
extra_flags=extra_flags
)
if __name__ == '__main__':
googletest.main()
......@@ -23,7 +23,6 @@ from absl import flags
from absl import logging
import tensorflow as tf
from official.resnet.ctl import ctl_common
from official.vision.image_classification import imagenet_preprocessing
from official.vision.image_classification import common
from official.vision.image_classification import resnet_model
......@@ -33,6 +32,13 @@ from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils
from official.utils.misc import model_helpers
flags.DEFINE_boolean(name='use_tf_function', default=True,
help='Wrap the train and test step inside a '
'tf.function.')
flags.DEFINE_boolean(name='single_l2_loss_op', default=False,
help='Calculate L2_loss on concatenated weights, '
'instead of using Keras per-layer L2 loss.')
def build_stats(train_result, eval_result, time_callback):
"""Normalizes and returns dictionary of stats.
......@@ -379,6 +385,4 @@ def main(_):
if __name__ == '__main__':
logging.set_verbosity(logging.INFO)
common.define_keras_flags()
ctl_common.define_ctl_flags()
flags.adopt_module_key_flags(ctl_common)
app.run(main)
......@@ -18,20 +18,16 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tempfile import mkdtemp
import tensorflow as tf
import tensorflow.compat.v2 as tf
from tensorflow.python.eager import context
from tensorflow.python.platform import googletest
from official.resnet.ctl import ctl_common
from official.resnet.ctl import ctl_imagenet_main
from official.vision.image_classification import imagenet_preprocessing
from official.vision.image_classification import common
from official.utils.misc import keras_utils
from official.utils.testing import integration
from official.vision.image_classification import common
from official.vision.image_classification import imagenet_preprocessing
from official.vision.image_classification import resnet_ctl_imagenet_main
class CtlImagenetTest(googletest.TestCase):
class CtlImagenetTest(tf.test.TestCase):
"""Unit tests for Keras ResNet with ImageNet using CTL."""
_extra_flags = [
......@@ -41,21 +37,13 @@ class CtlImagenetTest(googletest.TestCase):
]
_tempdir = None
def get_temp_dir(self):
if not self._tempdir:
self._tempdir = mkdtemp(dir=googletest.GetTempDir())
return self._tempdir
@classmethod
def setUpClass(cls): # pylint: disable=invalid-name
super(CtlImagenetTest, cls).setUpClass()
common.define_keras_flags()
ctl_common.define_ctl_flags()
def setUp(self):
super(CtlImagenetTest, self).setUp()
if not keras_utils.is_v2_0():
tf.compat.v1.enable_v2_behavior()
imagenet_preprocessing.NUM_IMAGES['validation'] = 4
def tearDown(self):
......@@ -73,7 +61,7 @@ class CtlImagenetTest(googletest.TestCase):
extra_flags = extra_flags + self._extra_flags
integration.run_synthetic(
main=ctl_imagenet_main.run,
main=resnet_ctl_imagenet_main.run,
tmp_root=self.get_temp_dir(),
extra_flags=extra_flags
)
......@@ -93,10 +81,11 @@ class CtlImagenetTest(googletest.TestCase):
extra_flags = extra_flags + self._extra_flags
integration.run_synthetic(
main=ctl_imagenet_main.run,
main=resnet_ctl_imagenet_main.run,
tmp_root=self.get_temp_dir(),
extra_flags=extra_flags
)
if __name__ == '__main__':
googletest.main()
assert tf.version.VERSION.startswith('2.')
tf.test.main()
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test the keras ResNet model with ImageNet data on TPU."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from official.utils.misc import keras_utils
from official.utils.testing import integration
from official.vision.image_classification import imagenet_preprocessing
from official.vision.image_classification import resnet_imagenet_main
class KerasImagenetTest(tf.test.TestCase):
"""Unit tests for Keras ResNet with ImageNet."""
_extra_flags = [
"-batch_size", "4",
"-train_steps", "1",
"-use_synthetic_data", "true"
]
_tempdir = None
@classmethod
def setUpClass(cls): # pylint: disable=invalid-name
super(KerasImagenetTest, cls).setUpClass()
resnet_imagenet_main.define_imagenet_keras_flags()
def setUp(self):
super(KerasImagenetTest, self).setUp()
imagenet_preprocessing.NUM_IMAGES["validation"] = 4
def tearDown(self):
super(KerasImagenetTest, self).tearDown()
tf.io.gfile.rmtree(self.get_temp_dir())
def test_end_to_end_tpu(self):
"""Test Keras model with TPU distribution strategy."""
config = keras_utils.get_config_proto_v1()
tf.compat.v1.enable_eager_execution(config=config)
extra_flags = [
"-distribution_strategy", "tpu",
"-data_format", "channels_last",
]
extra_flags = extra_flags + self._extra_flags
integration.run_synthetic(
main=resnet_imagenet_main.run,
tmp_root=self.get_temp_dir(),
extra_flags=extra_flags
)
def test_end_to_end_tpu_bf16(self):
"""Test Keras model with TPU and bfloat16 activation."""
config = keras_utils.get_config_proto_v1()
tf.compat.v1.enable_eager_execution(config=config)
extra_flags = [
"-distribution_strategy", "tpu",
"-data_format", "channels_last",
"-dtype", "bf16",
]
extra_flags = extra_flags + self._extra_flags
integration.run_synthetic(
main=resnet_imagenet_main.run,
tmp_root=self.get_temp_dir(),
extra_flags=extra_flags
)
if __name__ == "__main__":
tf.compat.v1.enable_v2_behavior()
tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment