Commit bf748370 authored by Nimit Nigania's avatar Nimit Nigania
Browse files

Merge remote-tracking branch 'upstream/master'

parents 7c732da7 0d2c2e01
......@@ -21,17 +21,17 @@ from __future__ import print_function
from absl import app as absl_app
from absl import flags
from absl import logging
import tensorflow as tf # pylint: disable=g-bad-import-order
import tensorflow as tf
from official.resnet.keras import imagenet_preprocessing
from official.resnet.keras import keras_common
from official.resnet.keras import resnet_model
from official.resnet.keras import trivial_model
from official.utils.flags import core as flags_core
from official.utils.logs import logger
from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils
from official.utils.misc import model_helpers
from official.vision.image_classification import common
from official.vision.image_classification import imagenet_preprocessing
from official.vision.image_classification import resnet_model
from official.benchmark.models import trivial_model
LR_SCHEDULE = [ # (multiplier, epoch to start) tuples
......@@ -57,7 +57,7 @@ def learning_rate_schedule(current_epoch,
Returns:
Adjusted learning rate.
"""
initial_lr = keras_common.BASE_LEARNING_RATE * batch_size / 256
initial_lr = common.BASE_LEARNING_RATE * batch_size / 256
epoch = current_epoch + float(current_batch) / batches_per_epoch
warmup_lr_multiplier, warmup_end_epoch = LR_SCHEDULE[0]
if epoch < warmup_end_epoch:
......@@ -89,10 +89,10 @@ def run(flags_obj):
# Execute flag override logic for better model performance
if flags_obj.tf_gpu_thread_mode:
keras_common.set_gpu_thread_mode_and_count(flags_obj)
common.set_gpu_thread_mode_and_count(flags_obj)
if flags_obj.data_delay_prefetch:
keras_common.data_delay_prefetch()
keras_common.set_cudnn_batchnorm_mode()
common.data_delay_prefetch()
common.set_cudnn_batchnorm_mode()
dtype = flags_core.get_tf_dtype(flags_obj)
if dtype == 'float16':
......@@ -105,10 +105,14 @@ def run(flags_obj):
if tf.test.is_built_with_cuda() else 'channels_last')
tf.keras.backend.set_image_data_format(data_format)
# Configures cluster spec for distribution strategy.
num_workers = distribution_utils.configure_cluster(flags_obj.worker_hosts,
flags_obj.task_index)
strategy = distribution_utils.get_distribution_strategy(
distribution_strategy=flags_obj.distribution_strategy,
num_gpus=flags_obj.num_gpus,
num_workers=distribution_utils.configure_cluster(),
num_workers=num_workers,
all_reduce_alg=flags_obj.all_reduce_alg,
num_packs=flags_obj.num_packs)
......@@ -125,7 +129,7 @@ def run(flags_obj):
# pylint: disable=protected-access
if flags_obj.use_synthetic_data:
distribution_utils.set_up_synthetic_data()
input_fn = keras_common.get_synth_input_fn(
input_fn = common.get_synth_input_fn(
height=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
width=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
num_channels=imagenet_preprocessing.NUM_CHANNELS,
......@@ -165,7 +169,7 @@ def run(flags_obj):
lr_schedule = 0.1
if flags_obj.use_tensor_lr:
lr_schedule = keras_common.PiecewiseConstantDecayWithWarmup(
lr_schedule = common.PiecewiseConstantDecayWithWarmup(
batch_size=flags_obj.batch_size,
epoch_size=imagenet_preprocessing.NUM_IMAGES['train'],
warmup_epochs=LR_SCHEDULE[0][1],
......@@ -174,7 +178,7 @@ def run(flags_obj):
compute_lr_on_cpu=True)
with strategy_scope:
optimizer = keras_common.get_optimizer(lr_schedule)
optimizer = common.get_optimizer(lr_schedule)
if dtype == 'float16':
# TODO(reedwm): Remove manually wrapping optimizer once mixed precision
# can be enabled with a single line of code.
......@@ -182,6 +186,7 @@ def run(flags_obj):
optimizer, loss_scale=flags_core.get_loss_scale(flags_obj,
default_for_fp16=128))
# TODO(hongkuny): Remove trivial model usage and move it to benchmark.
if flags_obj.use_trivial_model:
model = trivial_model.trivial_model(
imagenet_preprocessing.NUM_CLASSES, dtype)
......@@ -207,7 +212,7 @@ def run(flags_obj):
if flags_obj.report_accuracy_metrics else None),
run_eagerly=flags_obj.run_eagerly)
callbacks = keras_common.get_callbacks(
callbacks = common.get_callbacks(
learning_rate_schedule, imagenet_preprocessing.NUM_IMAGES['train'])
train_steps = (
......@@ -257,13 +262,14 @@ def run(flags_obj):
if not strategy and flags_obj.explicit_gpu_placement:
no_dist_strat_device.__exit__()
stats = keras_common.build_stats(history, eval_output, callbacks)
stats = common.build_stats(history, eval_output, callbacks)
return stats
def define_imagenet_keras_flags():
keras_common.define_keras_flags()
common.define_keras_flags()
flags_core.set_defaults(train_epochs=90)
flags.adopt_module_key_flags(common)
def main(_):
......
......@@ -18,16 +18,16 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tempfile import mkdtemp
import tempfile
import tensorflow as tf
from official.resnet.keras import imagenet_preprocessing
from official.resnet.keras import keras_imagenet_main
from official.utils.misc import keras_utils
from official.utils.testing import integration
# pylint: disable=ungrouped-imports
from tensorflow.python.eager import context
from tensorflow.python.platform import googletest
from official.utils.misc import keras_utils
from official.utils.testing import integration
from official.vision.image_classification import imagenet_preprocessing
from official.vision.image_classification import resnet_imagenet_main
class KerasImagenetTest(googletest.TestCase):
......@@ -42,13 +42,13 @@ class KerasImagenetTest(googletest.TestCase):
def get_temp_dir(self):
if not self._tempdir:
self._tempdir = mkdtemp(dir=googletest.GetTempDir())
self._tempdir = tempfile.mkdtemp(dir=googletest.GetTempDir())
return self._tempdir
@classmethod
def setUpClass(cls): # pylint: disable=invalid-name
super(KerasImagenetTest, cls).setUpClass()
keras_imagenet_main.define_imagenet_keras_flags()
resnet_imagenet_main.define_imagenet_keras_flags()
def setUp(self):
super(KerasImagenetTest, self).setUp()
......@@ -71,7 +71,7 @@ class KerasImagenetTest(googletest.TestCase):
extra_flags = extra_flags + self._extra_flags
integration.run_synthetic(
main=keras_imagenet_main.run,
main=resnet_imagenet_main.run,
tmp_root=self.get_temp_dir(),
extra_flags=extra_flags
)
......@@ -87,7 +87,7 @@ class KerasImagenetTest(googletest.TestCase):
extra_flags = extra_flags + self._extra_flags
integration.run_synthetic(
main=keras_imagenet_main.run,
main=resnet_imagenet_main.run,
tmp_root=self.get_temp_dir(),
extra_flags=extra_flags
)
......@@ -111,7 +111,7 @@ class KerasImagenetTest(googletest.TestCase):
extra_flags = extra_flags + self._extra_flags
integration.run_synthetic(
main=keras_imagenet_main.run,
main=resnet_imagenet_main.run,
tmp_root=self.get_temp_dir(),
extra_flags=extra_flags
)
......@@ -133,7 +133,7 @@ class KerasImagenetTest(googletest.TestCase):
extra_flags = extra_flags + self._extra_flags
integration.run_synthetic(
main=keras_imagenet_main.run,
main=resnet_imagenet_main.run,
tmp_root=self.get_temp_dir(),
extra_flags=extra_flags
)
......@@ -156,7 +156,7 @@ class KerasImagenetTest(googletest.TestCase):
extra_flags = extra_flags + self._extra_flags
integration.run_synthetic(
main=keras_imagenet_main.run,
main=resnet_imagenet_main.run,
tmp_root=self.get_temp_dir(),
extra_flags=extra_flags
)
......@@ -180,7 +180,7 @@ class KerasImagenetTest(googletest.TestCase):
extra_flags = extra_flags + self._extra_flags
integration.run_synthetic(
main=keras_imagenet_main.run,
main=resnet_imagenet_main.run,
tmp_root=self.get_temp_dir(),
extra_flags=extra_flags
)
......@@ -204,7 +204,7 @@ class KerasImagenetTest(googletest.TestCase):
extra_flags = extra_flags + self._extra_flags
integration.run_synthetic(
main=keras_imagenet_main.run,
main=resnet_imagenet_main.run,
tmp_root=self.get_temp_dir(),
extra_flags=extra_flags
)
......@@ -229,7 +229,7 @@ class KerasImagenetTest(googletest.TestCase):
extra_flags = extra_flags + self._extra_flags
integration.run_synthetic(
main=keras_imagenet_main.run,
main=resnet_imagenet_main.run,
tmp_root=self.get_temp_dir(),
extra_flags=extra_flags
)
......@@ -250,7 +250,7 @@ class KerasImagenetTest(googletest.TestCase):
extra_flags = extra_flags + self._extra_flags
integration.run_synthetic(
main=keras_imagenet_main.run,
main=resnet_imagenet_main.run,
tmp_root=self.get_temp_dir(),
extra_flags=extra_flags
)
......@@ -272,7 +272,7 @@ class KerasImagenetTest(googletest.TestCase):
extra_flags = extra_flags + self._extra_flags
integration.run_synthetic(
main=keras_imagenet_main.run,
main=resnet_imagenet_main.run,
tmp_root=self.get_temp_dir(),
extra_flags=extra_flags
)
......
......@@ -28,7 +28,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.keras import backend
from tensorflow.python.keras import initializers
from tensorflow.python.keras import initializers
from tensorflow.python.keras import layers
from tensorflow.python.keras import models
from tensorflow.python.keras import regularizers
......@@ -39,7 +39,16 @@ BATCH_NORM_DECAY = 0.9
BATCH_NORM_EPSILON = 1e-5
def identity_block(input_tensor, kernel_size, filters, stage, block):
def _gen_l2_regularizer(use_l2_regularizer=True):
return regularizers.l2(L2_WEIGHT_DECAY) if use_l2_regularizer else None
def identity_block(input_tensor,
kernel_size,
filters,
stage,
block,
use_l2_regularizer=True):
"""The identity block is the block that has no conv layer at shortcut.
Args:
......@@ -48,6 +57,7 @@ def identity_block(input_tensor, kernel_size, filters, stage, block):
filters: list of integers, the filters of 3 conv layer at main path
stage: integer, current stage label, used for generating layer names
block: 'a','b'..., current block label, used for generating layer names
use_l2_regularizer: whether to use L2 regularizer on Conv layer.
Returns:
Output tensor for the block.
......@@ -60,35 +70,51 @@ def identity_block(input_tensor, kernel_size, filters, stage, block):
conv_name_base = 'res' + str(stage) + block + '_branch'
bn_name_base = 'bn' + str(stage) + block + '_branch'
x = layers.Conv2D(filters1, (1, 1), use_bias=False,
kernel_initializer='he_normal',
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
name=conv_name_base + '2a')(input_tensor)
x = layers.BatchNormalization(axis=bn_axis,
momentum=BATCH_NORM_DECAY,
epsilon=BATCH_NORM_EPSILON,
name=bn_name_base + '2a')(x)
x = layers.Conv2D(
filters1, (1, 1),
use_bias=False,
kernel_initializer='he_normal',
kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
name=conv_name_base + '2a')(
input_tensor)
x = layers.BatchNormalization(
axis=bn_axis,
momentum=BATCH_NORM_DECAY,
epsilon=BATCH_NORM_EPSILON,
name=bn_name_base + '2a')(
x)
x = layers.Activation('relu')(x)
x = layers.Conv2D(filters2, kernel_size,
padding='same', use_bias=False,
kernel_initializer='he_normal',
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
name=conv_name_base + '2b')(x)
x = layers.BatchNormalization(axis=bn_axis,
momentum=BATCH_NORM_DECAY,
epsilon=BATCH_NORM_EPSILON,
name=bn_name_base + '2b')(x)
x = layers.Conv2D(
filters2,
kernel_size,
padding='same',
use_bias=False,
kernel_initializer='he_normal',
kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
name=conv_name_base + '2b')(
x)
x = layers.BatchNormalization(
axis=bn_axis,
momentum=BATCH_NORM_DECAY,
epsilon=BATCH_NORM_EPSILON,
name=bn_name_base + '2b')(
x)
x = layers.Activation('relu')(x)
x = layers.Conv2D(filters3, (1, 1), use_bias=False,
kernel_initializer='he_normal',
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
name=conv_name_base + '2c')(x)
x = layers.BatchNormalization(axis=bn_axis,
momentum=BATCH_NORM_DECAY,
epsilon=BATCH_NORM_EPSILON,
name=bn_name_base + '2c')(x)
x = layers.Conv2D(
filters3, (1, 1),
use_bias=False,
kernel_initializer='he_normal',
kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
name=conv_name_base + '2c')(
x)
x = layers.BatchNormalization(
axis=bn_axis,
momentum=BATCH_NORM_DECAY,
epsilon=BATCH_NORM_EPSILON,
name=bn_name_base + '2c')(
x)
x = layers.add([x, input_tensor])
x = layers.Activation('relu')(x)
......@@ -100,7 +126,8 @@ def conv_block(input_tensor,
filters,
stage,
block,
strides=(2, 2)):
strides=(2, 2),
use_l2_regularizer=True):
"""A block that has a conv layer at shortcut.
Note that from stage 3,
......@@ -114,6 +141,7 @@ def conv_block(input_tensor,
stage: integer, current stage label, used for generating layer names
block: 'a','b'..., current block label, used for generating layer names
strides: Strides for the second conv layer in the block.
use_l2_regularizer: whether to use L2 regularizer on Conv layer.
Returns:
Output tensor for the block.
......@@ -126,114 +154,231 @@ def conv_block(input_tensor,
conv_name_base = 'res' + str(stage) + block + '_branch'
bn_name_base = 'bn' + str(stage) + block + '_branch'
x = layers.Conv2D(filters1, (1, 1), use_bias=False,
kernel_initializer='he_normal',
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
name=conv_name_base + '2a')(input_tensor)
x = layers.BatchNormalization(axis=bn_axis,
momentum=BATCH_NORM_DECAY,
epsilon=BATCH_NORM_EPSILON,
name=bn_name_base + '2a')(x)
x = layers.Conv2D(
filters1, (1, 1),
use_bias=False,
kernel_initializer='he_normal',
kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
name=conv_name_base + '2a')(
input_tensor)
x = layers.BatchNormalization(
axis=bn_axis,
momentum=BATCH_NORM_DECAY,
epsilon=BATCH_NORM_EPSILON,
name=bn_name_base + '2a')(
x)
x = layers.Activation('relu')(x)
x = layers.Conv2D(filters2, kernel_size, strides=strides, padding='same',
use_bias=False, kernel_initializer='he_normal',
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
name=conv_name_base + '2b')(x)
x = layers.BatchNormalization(axis=bn_axis,
momentum=BATCH_NORM_DECAY,
epsilon=BATCH_NORM_EPSILON,
name=bn_name_base + '2b')(x)
x = layers.Conv2D(
filters2,
kernel_size,
strides=strides,
padding='same',
use_bias=False,
kernel_initializer='he_normal',
kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
name=conv_name_base + '2b')(
x)
x = layers.BatchNormalization(
axis=bn_axis,
momentum=BATCH_NORM_DECAY,
epsilon=BATCH_NORM_EPSILON,
name=bn_name_base + '2b')(
x)
x = layers.Activation('relu')(x)
x = layers.Conv2D(filters3, (1, 1), use_bias=False,
kernel_initializer='he_normal',
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
name=conv_name_base + '2c')(x)
x = layers.BatchNormalization(axis=bn_axis,
momentum=BATCH_NORM_DECAY,
epsilon=BATCH_NORM_EPSILON,
name=bn_name_base + '2c')(x)
shortcut = layers.Conv2D(filters3, (1, 1), strides=strides, use_bias=False,
kernel_initializer='he_normal',
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
name=conv_name_base + '1')(input_tensor)
shortcut = layers.BatchNormalization(axis=bn_axis,
momentum=BATCH_NORM_DECAY,
epsilon=BATCH_NORM_EPSILON,
name=bn_name_base + '1')(shortcut)
x = layers.Conv2D(
filters3, (1, 1),
use_bias=False,
kernel_initializer='he_normal',
kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
name=conv_name_base + '2c')(
x)
x = layers.BatchNormalization(
axis=bn_axis,
momentum=BATCH_NORM_DECAY,
epsilon=BATCH_NORM_EPSILON,
name=bn_name_base + '2c')(
x)
shortcut = layers.Conv2D(
filters3, (1, 1),
strides=strides,
use_bias=False,
kernel_initializer='he_normal',
kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
name=conv_name_base + '1')(
input_tensor)
shortcut = layers.BatchNormalization(
axis=bn_axis,
momentum=BATCH_NORM_DECAY,
epsilon=BATCH_NORM_EPSILON,
name=bn_name_base + '1')(
shortcut)
x = layers.add([x, shortcut])
x = layers.Activation('relu')(x)
return x
def resnet50(num_classes, dtype='float32', batch_size=None):
def resnet50(num_classes,
dtype='float32',
batch_size=None,
use_l2_regularizer=True):
"""Instantiates the ResNet50 architecture.
Args:
num_classes: `int` number of classes for image classification.
dtype: dtype to use float32 or float16 are most common.
batch_size: Size of the batches for each step.
use_l2_regularizer: whether to use L2 regularizer on Conv/Dense layer.
Returns:
A Keras model instance.
"""
input_shape = (224, 224, 3)
img_input = layers.Input(shape=input_shape, dtype=dtype,
batch_size=batch_size)
img_input = layers.Input(
shape=input_shape, dtype=dtype, batch_size=batch_size)
if backend.image_data_format() == 'channels_first':
x = layers.Lambda(lambda x: backend.permute_dimensions(x, (0, 3, 1, 2)),
name='transpose')(img_input)
x = layers.Lambda(
lambda x: backend.permute_dimensions(x, (0, 3, 1, 2)),
name='transpose')(
img_input)
bn_axis = 1
else: # channels_last
x = img_input
bn_axis = 3
x = layers.ZeroPadding2D(padding=(3, 3), name='conv1_pad')(x)
x = layers.Conv2D(64, (7, 7),
strides=(2, 2),
padding='valid', use_bias=False,
kernel_initializer='he_normal',
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
name='conv1')(x)
x = layers.BatchNormalization(axis=bn_axis,
momentum=BATCH_NORM_DECAY,
epsilon=BATCH_NORM_EPSILON,
name='bn_conv1')(x)
x = layers.Conv2D(
64, (7, 7),
strides=(2, 2),
padding='valid',
use_bias=False,
kernel_initializer='he_normal',
kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
name='conv1')(
x)
x = layers.BatchNormalization(
axis=bn_axis,
momentum=BATCH_NORM_DECAY,
epsilon=BATCH_NORM_EPSILON,
name='bn_conv1')(
x)
x = layers.Activation('relu')(x)
x = layers.MaxPooling2D((3, 3), strides=(2, 2), padding='same')(x)
x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1))
x = identity_block(x, 3, [64, 64, 256], stage=2, block='b')
x = identity_block(x, 3, [64, 64, 256], stage=2, block='c')
x = conv_block(x, 3, [128, 128, 512], stage=3, block='a')
x = identity_block(x, 3, [128, 128, 512], stage=3, block='b')
x = identity_block(x, 3, [128, 128, 512], stage=3, block='c')
x = identity_block(x, 3, [128, 128, 512], stage=3, block='d')
x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a')
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b')
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c')
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d')
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e')
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f')
x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a')
x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b')
x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c')
x = conv_block(
x,
3, [64, 64, 256],
stage=2,
block='a',
strides=(1, 1),
use_l2_regularizer=use_l2_regularizer)
x = identity_block(
x,
3, [64, 64, 256],
stage=2,
block='b',
use_l2_regularizer=use_l2_regularizer)
x = identity_block(
x,
3, [64, 64, 256],
stage=2,
block='c',
use_l2_regularizer=use_l2_regularizer)
x = conv_block(
x,
3, [128, 128, 512],
stage=3,
block='a',
use_l2_regularizer=use_l2_regularizer)
x = identity_block(
x,
3, [128, 128, 512],
stage=3,
block='b',
use_l2_regularizer=use_l2_regularizer)
x = identity_block(
x,
3, [128, 128, 512],
stage=3,
block='c',
use_l2_regularizer=use_l2_regularizer)
x = identity_block(
x,
3, [128, 128, 512],
stage=3,
block='d',
use_l2_regularizer=use_l2_regularizer)
x = conv_block(
x,
3, [256, 256, 1024],
stage=4,
block='a',
use_l2_regularizer=use_l2_regularizer)
x = identity_block(
x,
3, [256, 256, 1024],
stage=4,
block='b',
use_l2_regularizer=use_l2_regularizer)
x = identity_block(
x,
3, [256, 256, 1024],
stage=4,
block='c',
use_l2_regularizer=use_l2_regularizer)
x = identity_block(
x,
3, [256, 256, 1024],
stage=4,
block='d',
use_l2_regularizer=use_l2_regularizer)
x = identity_block(
x,
3, [256, 256, 1024],
stage=4,
block='e',
use_l2_regularizer=use_l2_regularizer)
x = identity_block(
x,
3, [256, 256, 1024],
stage=4,
block='f',
use_l2_regularizer=use_l2_regularizer)
x = conv_block(
x,
3, [512, 512, 2048],
stage=5,
block='a',
use_l2_regularizer=use_l2_regularizer)
x = identity_block(
x,
3, [512, 512, 2048],
stage=5,
block='b',
use_l2_regularizer=use_l2_regularizer)
x = identity_block(
x,
3, [512, 512, 2048],
stage=5,
block='c',
use_l2_regularizer=use_l2_regularizer)
rm_axes = [1, 2] if backend.image_data_format() == 'channels_last' else [2, 3]
x = layers.Lambda(lambda x: backend.mean(x, rm_axes), name='reduce_mean')(x)
x = layers.Dense(
num_classes,
kernel_initializer=initializers.RandomNormal(stddev=0.01),
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
bias_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
name='fc1000')(x)
kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
bias_regularizer=_gen_l2_regularizer(use_l2_regularizer),
name='fc1000')(
x)
# TODO(reedwm): Remove manual casts once mixed precision can be enabled with a
# single line of code.
......
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Exports an LSTM detection model to use with tf-lite.
Outputs file:
......@@ -86,8 +85,9 @@ python lstm_object_detection/export_tflite_lstd_graph.py \
"""
import tensorflow as tf
from lstm_object_detection.utils import config_util
from lstm_object_detection import export_tflite_lstd_graph_lib
from lstm_object_detection.utils import config_util
flags = tf.app.flags
flags.DEFINE_string('output_directory', None, 'Path to write outputs.')
......@@ -122,12 +122,16 @@ def main(argv):
flags.mark_flag_as_required('trained_checkpoint_prefix')
pipeline_config = config_util.get_configs_from_pipeline_file(
FLAGS.pipeline_config_path)
FLAGS.pipeline_config_path)
export_tflite_lstd_graph_lib.export_tflite_graph(
pipeline_config, FLAGS.trained_checkpoint_prefix, FLAGS.output_directory,
FLAGS.add_postprocessing_op, FLAGS.max_detections,
FLAGS.max_classes_per_detection, use_regular_nms=FLAGS.use_regular_nms)
pipeline_config,
FLAGS.trained_checkpoint_prefix,
FLAGS.output_directory,
FLAGS.add_postprocessing_op,
FLAGS.max_detections,
FLAGS.max_classes_per_detection,
use_regular_nms=FLAGS.use_regular_nms)
if __name__ == '__main__':
......
......@@ -12,26 +12,26 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Exports detection models to use with tf-lite.
See export_tflite_lstd_graph.py for usage.
"""
import os
import tempfile
import numpy as np
import tensorflow as tf
from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import types_pb2
from tensorflow.core.protobuf import saver_pb2
from tensorflow.tools.graph_transforms import TransformGraph
from lstm_object_detection import model_builder
from object_detection import exporter
from object_detection.builders import graph_rewriter_builder
from object_detection.builders import post_processing_builder
from object_detection.core import box_list
from lstm_object_detection import model_builder
_DEFAULT_NUM_CHANNELS = 3
_DEFAULT_NUM_COORD_BOX = 4
......@@ -84,11 +84,11 @@ def append_postprocessing_op(frozen_graph_def,
num_classes: number of classes in SSD detector
scale_values: scale values is a dict with following key-value pairs
{y_scale: 10, x_scale: 10, h_scale: 5, w_scale: 5} that are used in decode
centersize boxes
centersize boxes
detections_per_class: In regular NonMaxSuppression, number of anchors used
for NonMaxSuppression per class
use_regular_nms: Flag to set postprocessing op to use Regular NMS instead
of Fast NMS.
for NonMaxSuppression per class
use_regular_nms: Flag to set postprocessing op to use Regular NMS instead of
Fast NMS.
Returns:
transformed_graph_def: Frozen GraphDef with postprocessing custom op
......@@ -165,9 +165,9 @@ def export_tflite_graph(pipeline_config,
is written to output_dir/tflite_graph.pb.
Args:
pipeline_config: Dictionary of configuration objects. Keys are `model`, `train_config`,
`train_input_config`, `eval_config`, `eval_input_config`, `lstm_model`.
Value are the corresponding config objects.
pipeline_config: Dictionary of configuration objects. Keys are `model`,
`train_config`, `train_input_config`, `eval_config`, `eval_input_config`,
`lstm_model`. Value are the corresponding config objects.
trained_checkpoint_prefix: a file prefix for the checkpoint containing the
trained parameters of the SSD model.
output_dir: A directory to write the tflite graph and anchor file to.
......@@ -176,9 +176,9 @@ def export_tflite_graph(pipeline_config,
max_detections: Maximum number of detections (boxes) to show
max_classes_per_detection: Number of classes to display per detection
detections_per_class: In regular NonMaxSuppression, number of anchors used
for NonMaxSuppression per class
use_regular_nms: Flag to set postprocessing op to use Regular NMS instead
of Fast NMS.
for NonMaxSuppression per class
use_regular_nms: Flag to set postprocessing op to use Regular NMS instead of
Fast NMS.
binary_graph_name: Name of the exported graph file in binary format.
txt_graph_name: Name of the exported graph file in text format.
......@@ -197,12 +197,10 @@ def export_tflite_graph(pipeline_config,
num_classes = model_config.ssd.num_classes
nms_score_threshold = {
model_config.ssd.post_processing.batch_non_max_suppression.
score_threshold
model_config.ssd.post_processing.batch_non_max_suppression.score_threshold
}
nms_iou_threshold = {
model_config.ssd.post_processing.batch_non_max_suppression.
iou_threshold
model_config.ssd.post_processing.batch_non_max_suppression.iou_threshold
}
scale_values = {}
scale_values['y_scale'] = {
......@@ -226,7 +224,7 @@ def export_tflite_graph(pipeline_config,
width = image_resizer_config.fixed_shape_resizer.width
if image_resizer_config.fixed_shape_resizer.convert_to_grayscale:
num_channels = 1
#TODO(richardbrks) figure out how to make with a None defined batch size
shape = [lstm_config.eval_unroll_length, height, width, num_channels]
else:
raise ValueError(
......@@ -235,14 +233,14 @@ def export_tflite_graph(pipeline_config,
image_resizer_config.WhichOneof('image_resizer_oneof')))
video_tensor = tf.placeholder(
tf.float32, shape=shape, name='input_video_tensor')
tf.float32, shape=shape, name='input_video_tensor')
detection_model = model_builder.build(model_config, lstm_config,
is_training=False)
detection_model = model_builder.build(
model_config, lstm_config, is_training=False)
preprocessed_video, true_image_shapes = detection_model.preprocess(
tf.to_float(video_tensor))
tf.to_float(video_tensor))
predicted_tensors = detection_model.predict(preprocessed_video,
true_image_shapes)
true_image_shapes)
# predicted_tensors = detection_model.postprocess(predicted_tensors,
# true_image_shapes)
# The score conversion occurs before the post-processing custom op
......@@ -311,7 +309,7 @@ def export_tflite_graph(pipeline_config,
initializer_nodes='')
# Add new operation to do post processing in a custom op (TF Lite only)
#(richardbrks) Do use this or detection_model.postprocess?
if add_postprocessing_op:
transformed_graph_def = append_postprocessing_op(
frozen_graph_def, max_detections, max_classes_per_detection,
......
# Exporting a tflite model from a checkpoint
Starting from a trained model checkpoint, creating a tflite model requires 2 steps:
* exporting a tflite frozen graph from a checkpoint
* exporting a tflite model from a frozen graph
Starting from a trained model checkpoint, creating a tflite model requires 2
steps:
* exporting a tflite frozen graph from a checkpoint
* exporting a tflite model from a frozen graph
## Exporting a tflite frozen graph from a checkpoint
......@@ -20,14 +20,14 @@ python lstm_object_detection/export_tflite_lstd_graph.py \
--pipeline_config_path ${PIPELINE_CONFIG_PATH} \
--trained_checkpoint_prefix ${TRAINED_CKPT_PREFIX} \
--output_directory ${EXPORT_DIR} \
--add_preprocessing_op
--add_preprocessing_op
```
After export, you should see the directory ${EXPORT_DIR} containing the following files:
* `tflite_graph.pb`
* `tflite_graph.pbtxt`
After export, you should see the directory ${EXPORT_DIR} containing the
following files:
* `tflite_graph.pb`
* `tflite_graph.pbtxt`
## Exporting a tflite model from a frozen graph
......@@ -40,10 +40,10 @@ FROZEN_GRAPH_PATH={path to exported tflite_graph.pb}
EXPORT_PATH={path to filename that will be used for export}
PIPELINE_CONFIG_PATH={path to pipeline config}
python lstm_object_detection/export_tflite_lstd_model.py \
--export_path ${EXPORT_PATH} \
--frozen_graph_path ${FROZEN_GRAPH_PATH} \
--pipeline_config_path ${PIPELINE_CONFIG_PATH}
--export_path ${EXPORT_PATH} \
--frozen_graph_path ${FROZEN_GRAPH_PATH} \
--pipeline_config_path ${PIPELINE_CONFIG_PATH}
```
After export, you should see the file ${EXPORT_PATH} containing the FlatBuffer
model to be used by an application.
\ No newline at end of file
model to be used by an application.
This diff is collapsed.
......@@ -90,12 +90,6 @@ http_archive(
sha256 = "79d102c61e2a479a0b7e5fc167bcfaa4832a0c6aad4a75fa7da0480564931bcc",
)
#
# http_archive(
# name = "com_google_protobuf",
# strip_prefix = "protobuf-master",
# urls = ["https://github.com/protocolbuffers/protobuf/archive/master.zip"],
# )
# Needed by TensorFlow
http_archive(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment