Commit b974c3f9 authored by Haoyu Zhang's avatar Haoyu Zhang Committed by A. Unique TensorFlower
Browse files

Moving Keras ResNet models to `official/vision/image_classification` and...

Moving Keras ResNet models to `official/vision/image_classification` and benchmarks to `official/benchmark`.

PiperOrigin-RevId: 264268533
parent b1188d03
......@@ -22,8 +22,8 @@ import time
from absl import flags
import tensorflow as tf # pylint: disable=g-bad-import-order
from official.resnet.keras import keras_benchmark
from official.resnet.keras import keras_cifar_main
from official.benchmark import keras_benchmark
from official.vision.image_classification import resnet_cifar_main
MIN_TOP_1_ACCURACY = 0.929
MAX_TOP_1_ACCURACY = 0.938
......@@ -47,7 +47,7 @@ class Resnet56KerasAccuracy(keras_benchmark.KerasBenchmark):
"""
self.data_dir = os.path.join(root_data_dir, CIFAR_DATA_DIR_NAME)
flag_methods = [keras_cifar_main.define_cifar_flags]
flag_methods = [resnet_cifar_main.define_cifar_flags]
super(Resnet56KerasAccuracy, self).__init__(
output_dir=output_dir, flag_methods=flag_methods)
......@@ -199,7 +199,7 @@ class Resnet56KerasAccuracy(keras_benchmark.KerasBenchmark):
def _run_and_report_benchmark(self):
start_time_sec = time.time()
stats = keras_cifar_main.run(FLAGS)
stats = resnet_cifar_main.run(FLAGS)
wall_time_sec = time.time() - start_time_sec
super(Resnet56KerasAccuracy, self)._report_benchmark(
......@@ -215,7 +215,7 @@ class Resnet56KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
"""Short performance tests for ResNet56 via Keras and CIFAR-10."""
def __init__(self, output_dir=None, default_flags=None):
flag_methods = [keras_cifar_main.define_cifar_flags]
flag_methods = [resnet_cifar_main.define_cifar_flags]
super(Resnet56KerasBenchmarkBase, self).__init__(
output_dir=output_dir,
......@@ -224,7 +224,7 @@ class Resnet56KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
def _run_and_report_benchmark(self):
start_time_sec = time.time()
stats = keras_cifar_main.run(FLAGS)
stats = resnet_cifar_main.run(FLAGS)
wall_time_sec = time.time() - start_time_sec
super(Resnet56KerasBenchmarkBase, self)._report_benchmark(
......
......@@ -21,8 +21,8 @@ import time
from absl import flags
import tensorflow as tf # pylint: disable=g-bad-import-order
from official.resnet.keras import keras_benchmark
from official.resnet.keras import keras_imagenet_main
from official.benchmark import keras_benchmark
from official.vision.image_classification import resnet_imagenet_main
MIN_TOP_1_ACCURACY = 0.76
MAX_TOP_1_ACCURACY = 0.77
......@@ -44,7 +44,7 @@ class Resnet50KerasAccuracy(keras_benchmark.KerasBenchmark):
named arguments before updating the constructor.
"""
flag_methods = [keras_imagenet_main.define_imagenet_keras_flags]
flag_methods = [resnet_imagenet_main.define_imagenet_keras_flags]
self.data_dir = os.path.join(root_data_dir, 'imagenet')
super(Resnet50KerasAccuracy, self).__init__(
......@@ -158,7 +158,7 @@ class Resnet50KerasAccuracy(keras_benchmark.KerasBenchmark):
top_1_min=MIN_TOP_1_ACCURACY,
top_1_max=MAX_TOP_1_ACCURACY):
start_time_sec = time.time()
stats = keras_imagenet_main.run(flags.FLAGS)
stats = resnet_imagenet_main.run(flags.FLAGS)
wall_time_sec = time.time() - start_time_sec
super(Resnet50KerasAccuracy, self)._report_benchmark(
......@@ -177,7 +177,7 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
"""Resnet50 benchmarks."""
def __init__(self, output_dir=None, default_flags=None):
flag_methods = [keras_imagenet_main.define_imagenet_keras_flags]
flag_methods = [resnet_imagenet_main.define_imagenet_keras_flags]
super(Resnet50KerasBenchmarkBase, self).__init__(
output_dir=output_dir,
......@@ -186,7 +186,7 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
def _run_and_report_benchmark(self):
start_time_sec = time.time()
stats = keras_imagenet_main.run(FLAGS)
stats = resnet_imagenet_main.run(FLAGS)
wall_time_sec = time.time() - start_time_sec
# Number of logged step time entries that are excluded in performance
# report. We keep results from last 100 batches in this case.
......@@ -779,7 +779,7 @@ class TrivialKerasBenchmarkReal(keras_benchmark.KerasBenchmark):
"""Trivial model with real data benchmark tests."""
def __init__(self, output_dir=None, root_data_dir=None, **kwargs):
flag_methods = [keras_imagenet_main.define_imagenet_keras_flags]
flag_methods = [resnet_imagenet_main.define_imagenet_keras_flags]
def_flags = {}
def_flags['use_trivial_model'] = True
......@@ -799,7 +799,7 @@ class TrivialKerasBenchmarkReal(keras_benchmark.KerasBenchmark):
def _run_and_report_benchmark(self):
start_time_sec = time.time()
stats = keras_imagenet_main.run(FLAGS)
stats = resnet_imagenet_main.run(FLAGS)
wall_time_sec = time.time() - start_time_sec
super(TrivialKerasBenchmarkReal, self)._report_benchmark(
......
# ResNet in TensorFlow
* For the Keras version of the ResNet model, see
[`official/resnet/keras`](keras).
[`official/vision/image_classification`](../vision/image_classification).
* For the Keras custom training loop version, see
[`official/resnet/ctl`](ctl).
* For the Estimator version, see [`official/r1/resnet`](../r1/resnet).
\ No newline at end of file
* For the Estimator version, see [`official/r1/resnet`](../r1/resnet).
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Bring in the shared Keras ResNet modules into this module.
The TensorFlow official Keras models are moved under
official/vision/image_classification
In order to be backward compatible with models that directly import its modules,
we import the Keras ResNet modules under official.resnet.keras.
New TF models should not depend on modules directly under this path.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from official.vision.image_classification import cifar_preprocessing
from official.vision.image_classification import common as keras_common
from official.vision.image_classification import imagenet_preprocessing
from official.vision.image_classification import resnet_cifar_main as keras_cifar_main
from official.vision.image_classification import resnet_cifar_model
from official.vision.image_classification import resnet_imagenet_main as keras_imagenet_main
from official.vision.image_classification import resnet_model
del absolute_import
del division
del print_function
This folder contains the Keras implementation of the ResNet models. For more
information about the models, please refer to this [README file](../README.md).
This folder contains the Keras implementation of the ResNet models. For more
information about the models, please refer to this [README file](../../README.md).
Similar to the [estimator implementation](/official/resnet), the Keras
Similar to the [estimator implementation](../../r1/resnet), the Keras
implementation has code for both CIFAR-10 data and ImageNet data. The CIFAR-10
version uses a ResNet56 model implemented in
[`resnet_cifar_model.py`](./resnet_cifar_model.py), and the ImageNet version
version uses a ResNet56 model implemented in
[`resnet_cifar_model.py`](./resnet_cifar_model.py), and the ImageNet version
uses a ResNet50 model implemented in [`resnet_model.py`](./resnet_model.py).
To use
either dataset, make sure that you have the latest version of TensorFlow
installed and
To use
either dataset, make sure that you have the latest version of TensorFlow
installed and
[add the models folder to your Python path](/official/#running-the-models),
otherwise you may encounter an error like `ImportError: No module named
otherwise you may encounter an error like `ImportError: No module named
official.resnet`.
## CIFAR-10
......@@ -36,7 +36,7 @@ python keras_cifar_main.py --data_dir=/path/to/cifar
## ImageNet
Download the ImageNet dataset and convert it to TFRecord format.
Download the ImageNet dataset and convert it to TFRecord format.
The following [script](https://github.com/tensorflow/tpu/blob/master/tools/datasets/imagenet_to_gcs.py)
and [README](https://github.com/tensorflow/tpu/tree/master/tools/datasets#imagenet_to_gcspy)
provide a few options.
......@@ -56,17 +56,17 @@ python keras_imagenet_main.py --data_dir=/path/to/imagenet
There are more flag options you can specify. Here are some examples:
- `--use_synthetic_data`: when set to true, synthetic data, rather than real
- `--use_synthetic_data`: when set to true, synthetic data, rather than real
data, are used;
- `--batch_size`: the batch size used for the model;
- `--model_dir`: the directory to save the model checkpoint;
- `--train_epochs`: number of epoches to run for training the model;
- `--train_steps`: number of steps to run for training the model. We now only
support a number that is smaller than the number of batches in an epoch.
- `--skip_eval`: when set to true, evaluation as well as validation during
- `--skip_eval`: when set to true, evaluation as well as validation during
training is skipped
For example, this is a typical command line to run with ImageNet data with
For example, this is a typical command line to run with ImageNet data with
batch size 128 per GPU:
```bash
......@@ -82,19 +82,19 @@ python -m keras_imagenet_main \
See [`keras_common.py`](keras_common.py) for full list of options.
## Using multiple GPUs
You can train these models on multiple GPUs using `tf.distribute.Strategy` API.
You can read more about them in this
You can train these models on multiple GPUs using `tf.distribute.Strategy` API.
You can read more about them in this
[guide](https://www.tensorflow.org/guide/distribute_strategy).
In this example, we have made it easier to use is with just a command line flag
`--num_gpus`. By default this flag is 1 if TensorFlow is compiled with CUDA,
In this example, we have made it easier to use is with just a command line flag
`--num_gpus`. By default this flag is 1 if TensorFlow is compiled with CUDA,
and 0 otherwise.
- --num_gpus=0: Uses tf.distribute.OneDeviceStrategy with CPU as the device.
- --num_gpus=1: Uses tf.distribute.OneDeviceStrategy with GPU as the device.
- --num_gpus=2+: Uses tf.distribute.MirroredStrategy to run synchronous
- --num_gpus=2+: Uses tf.distribute.MirroredStrategy to run synchronous
distributed training across the GPUs.
If you wish to run without `tf.distribute.Strategy`, you can do so by setting
If you wish to run without `tf.distribute.Strategy`, you can do so by setting
`--distribution_strategy=off`.
......@@ -22,7 +22,7 @@ import os
from absl import logging
import tensorflow as tf
from official.resnet.keras import imagenet_preprocessing
from official.vision.image_classification import imagenet_preprocessing
HEIGHT = 32
WIDTH = 32
......
......@@ -20,17 +20,13 @@ from __future__ import print_function
import multiprocessing
import os
import numpy as np
# pylint: disable=g-bad-import-order
from absl import flags
import numpy as np
import tensorflow as tf
from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_v2
from official.utils.flags import core as flags_core
from official.utils.misc import keras_utils
# pylint: disable=ungrouped-imports
from tensorflow.python.keras.optimizer_v2 import (gradient_descent as
gradient_descent_v2)
FLAGS = flags.FLAGS
BASE_LEARNING_RATE = 0.1 # This matches Jing's version.
......
......@@ -12,21 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the keras_common module."""
"""Tests for the common module."""
from __future__ import absolute_import
from __future__ import print_function
from mock import Mock
import numpy as np
import tensorflow as tf # pylint: disable=g-bad-import-order
from tensorflow.python.platform import googletest
import tensorflow as tf
from official.resnet.keras import keras_common
from tensorflow.python.platform import googletest
from official.utils.misc import keras_utils
from official.vision.image_classification import common
class KerasCommonTests(tf.test.TestCase):
"""Tests for keras_common."""
"""Tests for common."""
@classmethod
def setUpClass(cls): # pylint: disable=invalid-name
......@@ -42,7 +42,7 @@ class KerasCommonTests(tf.test.TestCase):
keras_utils.BatchTimestamp(1, 2),
keras_utils.BatchTimestamp(2, 3)]
th.train_finish_time = 12345
stats = keras_common.build_stats(history, eval_output, [th])
stats = common.build_stats(history, eval_output, [th])
self.assertEqual(1.145, stats['loss'])
self.assertEqual(.99988, stats['training_accuracy_top_1'])
......@@ -57,7 +57,7 @@ class KerasCommonTests(tf.test.TestCase):
history = self._build_history(1.145, cat_accuracy_sparse=.99988)
eval_output = self._build_eval_output(.928, 1.9844)
stats = keras_common.build_stats(history, eval_output, None)
stats = common.build_stats(history, eval_output, None)
self.assertEqual(1.145, stats['loss'])
self.assertEqual(.99988, stats['training_accuracy_top_1'])
......
......@@ -22,13 +22,13 @@ from absl import app as absl_app
from absl import flags
import tensorflow as tf
from official.resnet.keras import cifar_preprocessing
from official.resnet.keras import keras_common
from official.resnet.keras import resnet_cifar_model
from official.utils.flags import core as flags_core
from official.utils.logs import logger
from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils
from official.vision.image_classification import cifar_preprocessing
from official.vision.image_classification import common
from official.vision.image_classification import resnet_cifar_model
LR_SCHEDULE = [ # (multiplier, epoch to start) tuples
......@@ -55,7 +55,7 @@ def learning_rate_schedule(current_epoch,
Adjusted learning rate.
"""
del current_batch, batches_per_epoch # not used
initial_learning_rate = keras_common.BASE_LEARNING_RATE * batch_size / 128
initial_learning_rate = common.BASE_LEARNING_RATE * batch_size / 128
learning_rate = initial_learning_rate
for mult, start_epoch in LR_SCHEDULE:
if current_epoch >= start_epoch:
......@@ -83,8 +83,8 @@ def run(flags_obj):
# Execute flag override logic for better model performance
if flags_obj.tf_gpu_thread_mode:
keras_common.set_gpu_thread_mode_and_count(flags_obj)
keras_common.set_cudnn_batchnorm_mode()
common.set_gpu_thread_mode_and_count(flags_obj)
common.set_cudnn_batchnorm_mode()
dtype = flags_core.get_tf_dtype(flags_obj)
if dtype == 'fp16':
......@@ -116,7 +116,7 @@ def run(flags_obj):
if flags_obj.use_synthetic_data:
distribution_utils.set_up_synthetic_data()
input_fn = keras_common.get_synth_input_fn(
input_fn = common.get_synth_input_fn(
height=cifar_preprocessing.HEIGHT,
width=cifar_preprocessing.WIDTH,
num_channels=cifar_preprocessing.NUM_CHANNELS,
......@@ -150,7 +150,7 @@ def run(flags_obj):
parse_record_fn=cifar_preprocessing.parse_record)
with strategy_scope:
optimizer = keras_common.get_optimizer()
optimizer = common.get_optimizer()
model = resnet_cifar_model.resnet56(classes=cifar_preprocessing.NUM_CLASSES)
# TODO(b/138957587): Remove when force_v2_in_keras_compile is on longer
......@@ -171,7 +171,7 @@ def run(flags_obj):
if flags_obj.report_accuracy_metrics else None),
run_eagerly=flags_obj.run_eagerly)
callbacks = keras_common.get_callbacks(
callbacks = common.get_callbacks(
learning_rate_schedule, cifar_preprocessing.NUM_IMAGES['train'])
train_steps = cifar_preprocessing.NUM_IMAGES['train'] // flags_obj.batch_size
......@@ -216,12 +216,12 @@ def run(flags_obj):
if not strategy and flags_obj.explicit_gpu_placement:
no_dist_strat_device.__exit__()
stats = keras_common.build_stats(history, eval_output, callbacks)
stats = common.build_stats(history, eval_output, callbacks)
return stats
def define_cifar_flags():
keras_common.define_keras_flags(dynamic_loss_scale=False)
common.define_keras_flags(dynamic_loss_scale=False)
flags_core.set_defaults(data_dir='/tmp/cifar10_data/cifar-10-batches-bin',
model_dir='/tmp/cifar10_model',
......
......@@ -18,17 +18,16 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tempfile import mkdtemp
import tempfile
import tensorflow as tf
from official.resnet.keras import cifar_preprocessing
from official.resnet.keras import keras_cifar_main
from official.resnet.keras import keras_common
from official.utils.misc import keras_utils
from official.utils.testing import integration
# pylint: disable=ungrouped-imports
from tensorflow.python.eager import context
from tensorflow.python.platform import googletest
from official.utils.misc import keras_utils
from official.utils.testing import integration
from official.vision.image_classification import cifar_preprocessing
from official.vision.image_classification import resnet_cifar_main
class KerasCifarTest(googletest.TestCase):
......@@ -43,13 +42,13 @@ class KerasCifarTest(googletest.TestCase):
def get_temp_dir(self):
if not self._tempdir:
self._tempdir = mkdtemp(dir=googletest.GetTempDir())
self._tempdir = tempfile.mkdtemp(dir=googletest.GetTempDir())
return self._tempdir
@classmethod
def setUpClass(cls): # pylint: disable=invalid-name
super(KerasCifarTest, cls).setUpClass()
keras_cifar_main.define_cifar_flags()
resnet_cifar_main.define_cifar_flags()
def setUp(self):
super(KerasCifarTest, self).setUp()
......@@ -72,7 +71,7 @@ class KerasCifarTest(googletest.TestCase):
extra_flags = extra_flags + self._extra_flags
integration.run_synthetic(
main=keras_cifar_main.run,
main=resnet_cifar_main.run,
tmp_root=self.get_temp_dir(),
extra_flags=extra_flags
)
......@@ -88,7 +87,7 @@ class KerasCifarTest(googletest.TestCase):
extra_flags = extra_flags + self._extra_flags
integration.run_synthetic(
main=keras_cifar_main.run,
main=resnet_cifar_main.run,
tmp_root=self.get_temp_dir(),
extra_flags=extra_flags
)
......@@ -112,7 +111,7 @@ class KerasCifarTest(googletest.TestCase):
extra_flags = extra_flags + self._extra_flags
integration.run_synthetic(
main=keras_cifar_main.run,
main=resnet_cifar_main.run,
tmp_root=self.get_temp_dir(),
extra_flags=extra_flags
)
......@@ -134,7 +133,7 @@ class KerasCifarTest(googletest.TestCase):
extra_flags = extra_flags + self._extra_flags
integration.run_synthetic(
main=keras_cifar_main.run,
main=resnet_cifar_main.run,
tmp_root=self.get_temp_dir(),
extra_flags=extra_flags
)
......@@ -157,7 +156,7 @@ class KerasCifarTest(googletest.TestCase):
extra_flags = extra_flags + self._extra_flags
integration.run_synthetic(
main=keras_cifar_main.run,
main=resnet_cifar_main.run,
tmp_root=self.get_temp_dir(),
extra_flags=extra_flags
)
......@@ -178,7 +177,7 @@ class KerasCifarTest(googletest.TestCase):
extra_flags = extra_flags + self._extra_flags
integration.run_synthetic(
main=keras_cifar_main.run,
main=resnet_cifar_main.run,
tmp_root=self.get_temp_dir(),
extra_flags=extra_flags
)
......
......@@ -21,17 +21,17 @@ from __future__ import print_function
from absl import app as absl_app
from absl import flags
from absl import logging
import tensorflow as tf # pylint: disable=g-bad-import-order
import tensorflow as tf
from official.resnet.keras import imagenet_preprocessing
from official.resnet.keras import keras_common
from official.resnet.keras import resnet_model
from official.resnet.keras import trivial_model
from official.utils.flags import core as flags_core
from official.utils.logs import logger
from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils
from official.utils.misc import model_helpers
from official.vision.image_classification import common
from official.vision.image_classification import imagenet_preprocessing
from official.vision.image_classification import resnet_model
from official.vision.image_classification import trivial_model
LR_SCHEDULE = [ # (multiplier, epoch to start) tuples
......@@ -57,7 +57,7 @@ def learning_rate_schedule(current_epoch,
Returns:
Adjusted learning rate.
"""
initial_lr = keras_common.BASE_LEARNING_RATE * batch_size / 256
initial_lr = common.BASE_LEARNING_RATE * batch_size / 256
epoch = current_epoch + float(current_batch) / batches_per_epoch
warmup_lr_multiplier, warmup_end_epoch = LR_SCHEDULE[0]
if epoch < warmup_end_epoch:
......@@ -89,10 +89,10 @@ def run(flags_obj):
# Execute flag override logic for better model performance
if flags_obj.tf_gpu_thread_mode:
keras_common.set_gpu_thread_mode_and_count(flags_obj)
common.set_gpu_thread_mode_and_count(flags_obj)
if flags_obj.data_delay_prefetch:
keras_common.data_delay_prefetch()
keras_common.set_cudnn_batchnorm_mode()
common.data_delay_prefetch()
common.set_cudnn_batchnorm_mode()
dtype = flags_core.get_tf_dtype(flags_obj)
if dtype == 'float16':
......@@ -129,7 +129,7 @@ def run(flags_obj):
# pylint: disable=protected-access
if flags_obj.use_synthetic_data:
distribution_utils.set_up_synthetic_data()
input_fn = keras_common.get_synth_input_fn(
input_fn = common.get_synth_input_fn(
height=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
width=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
num_channels=imagenet_preprocessing.NUM_CHANNELS,
......@@ -169,7 +169,7 @@ def run(flags_obj):
lr_schedule = 0.1
if flags_obj.use_tensor_lr:
lr_schedule = keras_common.PiecewiseConstantDecayWithWarmup(
lr_schedule = common.PiecewiseConstantDecayWithWarmup(
batch_size=flags_obj.batch_size,
epoch_size=imagenet_preprocessing.NUM_IMAGES['train'],
warmup_epochs=LR_SCHEDULE[0][1],
......@@ -178,7 +178,7 @@ def run(flags_obj):
compute_lr_on_cpu=True)
with strategy_scope:
optimizer = keras_common.get_optimizer(lr_schedule)
optimizer = common.get_optimizer(lr_schedule)
if dtype == 'float16':
# TODO(reedwm): Remove manually wrapping optimizer once mixed precision
# can be enabled with a single line of code.
......@@ -211,7 +211,7 @@ def run(flags_obj):
if flags_obj.report_accuracy_metrics else None),
run_eagerly=flags_obj.run_eagerly)
callbacks = keras_common.get_callbacks(
callbacks = common.get_callbacks(
learning_rate_schedule, imagenet_preprocessing.NUM_IMAGES['train'])
train_steps = (
......@@ -261,14 +261,14 @@ def run(flags_obj):
if not strategy and flags_obj.explicit_gpu_placement:
no_dist_strat_device.__exit__()
stats = keras_common.build_stats(history, eval_output, callbacks)
stats = common.build_stats(history, eval_output, callbacks)
return stats
def define_imagenet_keras_flags():
keras_common.define_keras_flags()
common.define_keras_flags()
flags_core.set_defaults(train_epochs=90)
flags.adopt_module_key_flags(keras_common)
flags.adopt_module_key_flags(common)
def main(_):
......
......@@ -18,16 +18,16 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tempfile import mkdtemp
import tempfile
import tensorflow as tf
from official.resnet.keras import imagenet_preprocessing
from official.resnet.keras import keras_imagenet_main
from official.utils.misc import keras_utils
from official.utils.testing import integration
# pylint: disable=ungrouped-imports
from tensorflow.python.eager import context
from tensorflow.python.platform import googletest
from official.utils.misc import keras_utils
from official.utils.testing import integration
from official.vision.image_classification import imagenet_preprocessing
from official.vision.image_classification import resnet_imagenet_main
class KerasImagenetTest(googletest.TestCase):
......@@ -42,13 +42,13 @@ class KerasImagenetTest(googletest.TestCase):
def get_temp_dir(self):
if not self._tempdir:
self._tempdir = mkdtemp(dir=googletest.GetTempDir())
self._tempdir = tempfile.mkdtemp(dir=googletest.GetTempDir())
return self._tempdir
@classmethod
def setUpClass(cls): # pylint: disable=invalid-name
super(KerasImagenetTest, cls).setUpClass()
keras_imagenet_main.define_imagenet_keras_flags()
resnet_imagenet_main.define_imagenet_keras_flags()
def setUp(self):
super(KerasImagenetTest, self).setUp()
......@@ -71,7 +71,7 @@ class KerasImagenetTest(googletest.TestCase):
extra_flags = extra_flags + self._extra_flags
integration.run_synthetic(
main=keras_imagenet_main.run,
main=resnet_imagenet_main.run,
tmp_root=self.get_temp_dir(),
extra_flags=extra_flags
)
......@@ -87,7 +87,7 @@ class KerasImagenetTest(googletest.TestCase):
extra_flags = extra_flags + self._extra_flags
integration.run_synthetic(
main=keras_imagenet_main.run,
main=resnet_imagenet_main.run,
tmp_root=self.get_temp_dir(),
extra_flags=extra_flags
)
......@@ -111,7 +111,7 @@ class KerasImagenetTest(googletest.TestCase):
extra_flags = extra_flags + self._extra_flags
integration.run_synthetic(
main=keras_imagenet_main.run,
main=resnet_imagenet_main.run,
tmp_root=self.get_temp_dir(),
extra_flags=extra_flags
)
......@@ -133,7 +133,7 @@ class KerasImagenetTest(googletest.TestCase):
extra_flags = extra_flags + self._extra_flags
integration.run_synthetic(
main=keras_imagenet_main.run,
main=resnet_imagenet_main.run,
tmp_root=self.get_temp_dir(),
extra_flags=extra_flags
)
......@@ -156,7 +156,7 @@ class KerasImagenetTest(googletest.TestCase):
extra_flags = extra_flags + self._extra_flags
integration.run_synthetic(
main=keras_imagenet_main.run,
main=resnet_imagenet_main.run,
tmp_root=self.get_temp_dir(),
extra_flags=extra_flags
)
......@@ -180,7 +180,7 @@ class KerasImagenetTest(googletest.TestCase):
extra_flags = extra_flags + self._extra_flags
integration.run_synthetic(
main=keras_imagenet_main.run,
main=resnet_imagenet_main.run,
tmp_root=self.get_temp_dir(),
extra_flags=extra_flags
)
......@@ -204,7 +204,7 @@ class KerasImagenetTest(googletest.TestCase):
extra_flags = extra_flags + self._extra_flags
integration.run_synthetic(
main=keras_imagenet_main.run,
main=resnet_imagenet_main.run,
tmp_root=self.get_temp_dir(),
extra_flags=extra_flags
)
......@@ -229,7 +229,7 @@ class KerasImagenetTest(googletest.TestCase):
extra_flags = extra_flags + self._extra_flags
integration.run_synthetic(
main=keras_imagenet_main.run,
main=resnet_imagenet_main.run,
tmp_root=self.get_temp_dir(),
extra_flags=extra_flags
)
......@@ -250,7 +250,7 @@ class KerasImagenetTest(googletest.TestCase):
extra_flags = extra_flags + self._extra_flags
integration.run_synthetic(
main=keras_imagenet_main.run,
main=resnet_imagenet_main.run,
tmp_root=self.get_temp_dir(),
extra_flags=extra_flags
)
......@@ -272,7 +272,7 @@ class KerasImagenetTest(googletest.TestCase):
extra_flags = extra_flags + self._extra_flags
integration.run_synthetic(
main=keras_imagenet_main.run,
main=resnet_imagenet_main.run,
tmp_root=self.get_temp_dir(),
extra_flags=extra_flags
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment