Commit 5b0ef1fc authored by Nimit Nigania
Browse files

Merge branch 'master' into ncf_f16

parents 1cba90f3 bf748370
...@@ -21,17 +21,17 @@ from __future__ import print_function ...@@ -21,17 +21,17 @@ from __future__ import print_function
from absl import app as absl_app from absl import app as absl_app
from absl import flags from absl import flags
from absl import logging from absl import logging
import tensorflow as tf # pylint: disable=g-bad-import-order import tensorflow as tf
from official.resnet.keras import imagenet_preprocessing
from official.resnet.keras import keras_common
from official.resnet.keras import resnet_model
from official.resnet.keras import trivial_model
from official.utils.flags import core as flags_core from official.utils.flags import core as flags_core
from official.utils.logs import logger from official.utils.logs import logger
from official.utils.misc import distribution_utils from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils from official.utils.misc import keras_utils
from official.utils.misc import model_helpers from official.utils.misc import model_helpers
from official.vision.image_classification import common
from official.vision.image_classification import imagenet_preprocessing
from official.vision.image_classification import resnet_model
from official.benchmark.models import trivial_model
LR_SCHEDULE = [ # (multiplier, epoch to start) tuples LR_SCHEDULE = [ # (multiplier, epoch to start) tuples
...@@ -57,7 +57,7 @@ def learning_rate_schedule(current_epoch, ...@@ -57,7 +57,7 @@ def learning_rate_schedule(current_epoch,
Returns: Returns:
Adjusted learning rate. Adjusted learning rate.
""" """
initial_lr = keras_common.BASE_LEARNING_RATE * batch_size / 256 initial_lr = common.BASE_LEARNING_RATE * batch_size / 256
epoch = current_epoch + float(current_batch) / batches_per_epoch epoch = current_epoch + float(current_batch) / batches_per_epoch
warmup_lr_multiplier, warmup_end_epoch = LR_SCHEDULE[0] warmup_lr_multiplier, warmup_end_epoch = LR_SCHEDULE[0]
if epoch < warmup_end_epoch: if epoch < warmup_end_epoch:
...@@ -89,10 +89,10 @@ def run(flags_obj): ...@@ -89,10 +89,10 @@ def run(flags_obj):
# Execute flag override logic for better model performance # Execute flag override logic for better model performance
if flags_obj.tf_gpu_thread_mode: if flags_obj.tf_gpu_thread_mode:
keras_common.set_gpu_thread_mode_and_count(flags_obj) common.set_gpu_thread_mode_and_count(flags_obj)
if flags_obj.data_delay_prefetch: if flags_obj.data_delay_prefetch:
keras_common.data_delay_prefetch() common.data_delay_prefetch()
keras_common.set_cudnn_batchnorm_mode() common.set_cudnn_batchnorm_mode()
dtype = flags_core.get_tf_dtype(flags_obj) dtype = flags_core.get_tf_dtype(flags_obj)
if dtype == 'float16': if dtype == 'float16':
...@@ -105,10 +105,14 @@ def run(flags_obj): ...@@ -105,10 +105,14 @@ def run(flags_obj):
if tf.test.is_built_with_cuda() else 'channels_last') if tf.test.is_built_with_cuda() else 'channels_last')
tf.keras.backend.set_image_data_format(data_format) tf.keras.backend.set_image_data_format(data_format)
# Configures cluster spec for distribution strategy.
num_workers = distribution_utils.configure_cluster(flags_obj.worker_hosts,
flags_obj.task_index)
strategy = distribution_utils.get_distribution_strategy( strategy = distribution_utils.get_distribution_strategy(
distribution_strategy=flags_obj.distribution_strategy, distribution_strategy=flags_obj.distribution_strategy,
num_gpus=flags_obj.num_gpus, num_gpus=flags_obj.num_gpus,
num_workers=distribution_utils.configure_cluster(), num_workers=num_workers,
all_reduce_alg=flags_obj.all_reduce_alg, all_reduce_alg=flags_obj.all_reduce_alg,
num_packs=flags_obj.num_packs) num_packs=flags_obj.num_packs)
...@@ -125,7 +129,7 @@ def run(flags_obj): ...@@ -125,7 +129,7 @@ def run(flags_obj):
# pylint: disable=protected-access # pylint: disable=protected-access
if flags_obj.use_synthetic_data: if flags_obj.use_synthetic_data:
distribution_utils.set_up_synthetic_data() distribution_utils.set_up_synthetic_data()
input_fn = keras_common.get_synth_input_fn( input_fn = common.get_synth_input_fn(
height=imagenet_preprocessing.DEFAULT_IMAGE_SIZE, height=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
width=imagenet_preprocessing.DEFAULT_IMAGE_SIZE, width=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
num_channels=imagenet_preprocessing.NUM_CHANNELS, num_channels=imagenet_preprocessing.NUM_CHANNELS,
...@@ -165,7 +169,7 @@ def run(flags_obj): ...@@ -165,7 +169,7 @@ def run(flags_obj):
lr_schedule = 0.1 lr_schedule = 0.1
if flags_obj.use_tensor_lr: if flags_obj.use_tensor_lr:
lr_schedule = keras_common.PiecewiseConstantDecayWithWarmup( lr_schedule = common.PiecewiseConstantDecayWithWarmup(
batch_size=flags_obj.batch_size, batch_size=flags_obj.batch_size,
epoch_size=imagenet_preprocessing.NUM_IMAGES['train'], epoch_size=imagenet_preprocessing.NUM_IMAGES['train'],
warmup_epochs=LR_SCHEDULE[0][1], warmup_epochs=LR_SCHEDULE[0][1],
...@@ -174,7 +178,7 @@ def run(flags_obj): ...@@ -174,7 +178,7 @@ def run(flags_obj):
compute_lr_on_cpu=True) compute_lr_on_cpu=True)
with strategy_scope: with strategy_scope:
optimizer = keras_common.get_optimizer(lr_schedule) optimizer = common.get_optimizer(lr_schedule)
if dtype == 'float16': if dtype == 'float16':
# TODO(reedwm): Remove manually wrapping optimizer once mixed precision # TODO(reedwm): Remove manually wrapping optimizer once mixed precision
# can be enabled with a single line of code. # can be enabled with a single line of code.
...@@ -182,6 +186,7 @@ def run(flags_obj): ...@@ -182,6 +186,7 @@ def run(flags_obj):
optimizer, loss_scale=flags_core.get_loss_scale(flags_obj, optimizer, loss_scale=flags_core.get_loss_scale(flags_obj,
default_for_fp16=128)) default_for_fp16=128))
# TODO(hongkuny): Remove trivial model usage and move it to benchmark.
if flags_obj.use_trivial_model: if flags_obj.use_trivial_model:
model = trivial_model.trivial_model( model = trivial_model.trivial_model(
imagenet_preprocessing.NUM_CLASSES, dtype) imagenet_preprocessing.NUM_CLASSES, dtype)
...@@ -207,7 +212,7 @@ def run(flags_obj): ...@@ -207,7 +212,7 @@ def run(flags_obj):
if flags_obj.report_accuracy_metrics else None), if flags_obj.report_accuracy_metrics else None),
run_eagerly=flags_obj.run_eagerly) run_eagerly=flags_obj.run_eagerly)
callbacks = keras_common.get_callbacks( callbacks = common.get_callbacks(
learning_rate_schedule, imagenet_preprocessing.NUM_IMAGES['train']) learning_rate_schedule, imagenet_preprocessing.NUM_IMAGES['train'])
train_steps = ( train_steps = (
...@@ -257,13 +262,14 @@ def run(flags_obj): ...@@ -257,13 +262,14 @@ def run(flags_obj):
if not strategy and flags_obj.explicit_gpu_placement: if not strategy and flags_obj.explicit_gpu_placement:
no_dist_strat_device.__exit__() no_dist_strat_device.__exit__()
stats = keras_common.build_stats(history, eval_output, callbacks) stats = common.build_stats(history, eval_output, callbacks)
return stats return stats
def define_imagenet_keras_flags(): def define_imagenet_keras_flags():
keras_common.define_keras_flags() common.define_keras_flags()
flags_core.set_defaults(train_epochs=90) flags_core.set_defaults(train_epochs=90)
flags.adopt_module_key_flags(common)
def main(_): def main(_):
......
...@@ -18,16 +18,16 @@ from __future__ import absolute_import ...@@ -18,16 +18,16 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from tempfile import mkdtemp import tempfile
import tensorflow as tf import tensorflow as tf
from official.resnet.keras import imagenet_preprocessing
from official.resnet.keras import keras_imagenet_main
from official.utils.misc import keras_utils
from official.utils.testing import integration
# pylint: disable=ungrouped-imports
from tensorflow.python.eager import context from tensorflow.python.eager import context
from tensorflow.python.platform import googletest from tensorflow.python.platform import googletest
from official.utils.misc import keras_utils
from official.utils.testing import integration
from official.vision.image_classification import imagenet_preprocessing
from official.vision.image_classification import resnet_imagenet_main
class KerasImagenetTest(googletest.TestCase): class KerasImagenetTest(googletest.TestCase):
...@@ -42,13 +42,13 @@ class KerasImagenetTest(googletest.TestCase): ...@@ -42,13 +42,13 @@ class KerasImagenetTest(googletest.TestCase):
def get_temp_dir(self): def get_temp_dir(self):
if not self._tempdir: if not self._tempdir:
self._tempdir = mkdtemp(dir=googletest.GetTempDir()) self._tempdir = tempfile.mkdtemp(dir=googletest.GetTempDir())
return self._tempdir return self._tempdir
@classmethod @classmethod
def setUpClass(cls): # pylint: disable=invalid-name def setUpClass(cls): # pylint: disable=invalid-name
super(KerasImagenetTest, cls).setUpClass() super(KerasImagenetTest, cls).setUpClass()
keras_imagenet_main.define_imagenet_keras_flags() resnet_imagenet_main.define_imagenet_keras_flags()
def setUp(self): def setUp(self):
super(KerasImagenetTest, self).setUp() super(KerasImagenetTest, self).setUp()
...@@ -71,7 +71,7 @@ class KerasImagenetTest(googletest.TestCase): ...@@ -71,7 +71,7 @@ class KerasImagenetTest(googletest.TestCase):
extra_flags = extra_flags + self._extra_flags extra_flags = extra_flags + self._extra_flags
integration.run_synthetic( integration.run_synthetic(
main=keras_imagenet_main.run, main=resnet_imagenet_main.run,
tmp_root=self.get_temp_dir(), tmp_root=self.get_temp_dir(),
extra_flags=extra_flags extra_flags=extra_flags
) )
...@@ -87,7 +87,7 @@ class KerasImagenetTest(googletest.TestCase): ...@@ -87,7 +87,7 @@ class KerasImagenetTest(googletest.TestCase):
extra_flags = extra_flags + self._extra_flags extra_flags = extra_flags + self._extra_flags
integration.run_synthetic( integration.run_synthetic(
main=keras_imagenet_main.run, main=resnet_imagenet_main.run,
tmp_root=self.get_temp_dir(), tmp_root=self.get_temp_dir(),
extra_flags=extra_flags extra_flags=extra_flags
) )
...@@ -111,7 +111,7 @@ class KerasImagenetTest(googletest.TestCase): ...@@ -111,7 +111,7 @@ class KerasImagenetTest(googletest.TestCase):
extra_flags = extra_flags + self._extra_flags extra_flags = extra_flags + self._extra_flags
integration.run_synthetic( integration.run_synthetic(
main=keras_imagenet_main.run, main=resnet_imagenet_main.run,
tmp_root=self.get_temp_dir(), tmp_root=self.get_temp_dir(),
extra_flags=extra_flags extra_flags=extra_flags
) )
...@@ -133,7 +133,7 @@ class KerasImagenetTest(googletest.TestCase): ...@@ -133,7 +133,7 @@ class KerasImagenetTest(googletest.TestCase):
extra_flags = extra_flags + self._extra_flags extra_flags = extra_flags + self._extra_flags
integration.run_synthetic( integration.run_synthetic(
main=keras_imagenet_main.run, main=resnet_imagenet_main.run,
tmp_root=self.get_temp_dir(), tmp_root=self.get_temp_dir(),
extra_flags=extra_flags extra_flags=extra_flags
) )
...@@ -156,7 +156,7 @@ class KerasImagenetTest(googletest.TestCase): ...@@ -156,7 +156,7 @@ class KerasImagenetTest(googletest.TestCase):
extra_flags = extra_flags + self._extra_flags extra_flags = extra_flags + self._extra_flags
integration.run_synthetic( integration.run_synthetic(
main=keras_imagenet_main.run, main=resnet_imagenet_main.run,
tmp_root=self.get_temp_dir(), tmp_root=self.get_temp_dir(),
extra_flags=extra_flags extra_flags=extra_flags
) )
...@@ -180,7 +180,7 @@ class KerasImagenetTest(googletest.TestCase): ...@@ -180,7 +180,7 @@ class KerasImagenetTest(googletest.TestCase):
extra_flags = extra_flags + self._extra_flags extra_flags = extra_flags + self._extra_flags
integration.run_synthetic( integration.run_synthetic(
main=keras_imagenet_main.run, main=resnet_imagenet_main.run,
tmp_root=self.get_temp_dir(), tmp_root=self.get_temp_dir(),
extra_flags=extra_flags extra_flags=extra_flags
) )
...@@ -204,7 +204,7 @@ class KerasImagenetTest(googletest.TestCase): ...@@ -204,7 +204,7 @@ class KerasImagenetTest(googletest.TestCase):
extra_flags = extra_flags + self._extra_flags extra_flags = extra_flags + self._extra_flags
integration.run_synthetic( integration.run_synthetic(
main=keras_imagenet_main.run, main=resnet_imagenet_main.run,
tmp_root=self.get_temp_dir(), tmp_root=self.get_temp_dir(),
extra_flags=extra_flags extra_flags=extra_flags
) )
...@@ -229,7 +229,7 @@ class KerasImagenetTest(googletest.TestCase): ...@@ -229,7 +229,7 @@ class KerasImagenetTest(googletest.TestCase):
extra_flags = extra_flags + self._extra_flags extra_flags = extra_flags + self._extra_flags
integration.run_synthetic( integration.run_synthetic(
main=keras_imagenet_main.run, main=resnet_imagenet_main.run,
tmp_root=self.get_temp_dir(), tmp_root=self.get_temp_dir(),
extra_flags=extra_flags extra_flags=extra_flags
) )
...@@ -250,7 +250,7 @@ class KerasImagenetTest(googletest.TestCase): ...@@ -250,7 +250,7 @@ class KerasImagenetTest(googletest.TestCase):
extra_flags = extra_flags + self._extra_flags extra_flags = extra_flags + self._extra_flags
integration.run_synthetic( integration.run_synthetic(
main=keras_imagenet_main.run, main=resnet_imagenet_main.run,
tmp_root=self.get_temp_dir(), tmp_root=self.get_temp_dir(),
extra_flags=extra_flags extra_flags=extra_flags
) )
...@@ -272,7 +272,7 @@ class KerasImagenetTest(googletest.TestCase): ...@@ -272,7 +272,7 @@ class KerasImagenetTest(googletest.TestCase):
extra_flags = extra_flags + self._extra_flags extra_flags = extra_flags + self._extra_flags
integration.run_synthetic( integration.run_synthetic(
main=keras_imagenet_main.run, main=resnet_imagenet_main.run,
tmp_root=self.get_temp_dir(), tmp_root=self.get_temp_dir(),
extra_flags=extra_flags extra_flags=extra_flags
) )
......
...@@ -28,7 +28,7 @@ from __future__ import division ...@@ -28,7 +28,7 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
from tensorflow.python.keras import backend from tensorflow.python.keras import backend
from tensorflow.python.keras import initializers from tensorflow.python.keras import initializers
from tensorflow.python.keras import layers from tensorflow.python.keras import layers
from tensorflow.python.keras import models from tensorflow.python.keras import models
from tensorflow.python.keras import regularizers from tensorflow.python.keras import regularizers
...@@ -39,7 +39,16 @@ BATCH_NORM_DECAY = 0.9 ...@@ -39,7 +39,16 @@ BATCH_NORM_DECAY = 0.9
BATCH_NORM_EPSILON = 1e-5 BATCH_NORM_EPSILON = 1e-5
def identity_block(input_tensor, kernel_size, filters, stage, block): def _gen_l2_regularizer(use_l2_regularizer=True):
return regularizers.l2(L2_WEIGHT_DECAY) if use_l2_regularizer else None
def identity_block(input_tensor,
kernel_size,
filters,
stage,
block,
use_l2_regularizer=True):
"""The identity block is the block that has no conv layer at shortcut. """The identity block is the block that has no conv layer at shortcut.
Args: Args:
...@@ -48,6 +57,7 @@ def identity_block(input_tensor, kernel_size, filters, stage, block): ...@@ -48,6 +57,7 @@ def identity_block(input_tensor, kernel_size, filters, stage, block):
filters: list of integers, the filters of 3 conv layer at main path filters: list of integers, the filters of 3 conv layer at main path
stage: integer, current stage label, used for generating layer names stage: integer, current stage label, used for generating layer names
block: 'a','b'..., current block label, used for generating layer names block: 'a','b'..., current block label, used for generating layer names
use_l2_regularizer: whether to use L2 regularizer on Conv layer.
Returns: Returns:
Output tensor for the block. Output tensor for the block.
...@@ -60,35 +70,51 @@ def identity_block(input_tensor, kernel_size, filters, stage, block): ...@@ -60,35 +70,51 @@ def identity_block(input_tensor, kernel_size, filters, stage, block):
conv_name_base = 'res' + str(stage) + block + '_branch' conv_name_base = 'res' + str(stage) + block + '_branch'
bn_name_base = 'bn' + str(stage) + block + '_branch' bn_name_base = 'bn' + str(stage) + block + '_branch'
x = layers.Conv2D(filters1, (1, 1), use_bias=False, x = layers.Conv2D(
kernel_initializer='he_normal', filters1, (1, 1),
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY), use_bias=False,
name=conv_name_base + '2a')(input_tensor) kernel_initializer='he_normal',
x = layers.BatchNormalization(axis=bn_axis, kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
momentum=BATCH_NORM_DECAY, name=conv_name_base + '2a')(
epsilon=BATCH_NORM_EPSILON, input_tensor)
name=bn_name_base + '2a')(x) x = layers.BatchNormalization(
axis=bn_axis,
momentum=BATCH_NORM_DECAY,
epsilon=BATCH_NORM_EPSILON,
name=bn_name_base + '2a')(
x)
x = layers.Activation('relu')(x) x = layers.Activation('relu')(x)
x = layers.Conv2D(filters2, kernel_size, x = layers.Conv2D(
padding='same', use_bias=False, filters2,
kernel_initializer='he_normal', kernel_size,
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY), padding='same',
name=conv_name_base + '2b')(x) use_bias=False,
x = layers.BatchNormalization(axis=bn_axis, kernel_initializer='he_normal',
momentum=BATCH_NORM_DECAY, kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
epsilon=BATCH_NORM_EPSILON, name=conv_name_base + '2b')(
name=bn_name_base + '2b')(x) x)
x = layers.BatchNormalization(
axis=bn_axis,
momentum=BATCH_NORM_DECAY,
epsilon=BATCH_NORM_EPSILON,
name=bn_name_base + '2b')(
x)
x = layers.Activation('relu')(x) x = layers.Activation('relu')(x)
x = layers.Conv2D(filters3, (1, 1), use_bias=False, x = layers.Conv2D(
kernel_initializer='he_normal', filters3, (1, 1),
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY), use_bias=False,
name=conv_name_base + '2c')(x) kernel_initializer='he_normal',
x = layers.BatchNormalization(axis=bn_axis, kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
momentum=BATCH_NORM_DECAY, name=conv_name_base + '2c')(
epsilon=BATCH_NORM_EPSILON, x)
name=bn_name_base + '2c')(x) x = layers.BatchNormalization(
axis=bn_axis,
momentum=BATCH_NORM_DECAY,
epsilon=BATCH_NORM_EPSILON,
name=bn_name_base + '2c')(
x)
x = layers.add([x, input_tensor]) x = layers.add([x, input_tensor])
x = layers.Activation('relu')(x) x = layers.Activation('relu')(x)
...@@ -100,7 +126,8 @@ def conv_block(input_tensor, ...@@ -100,7 +126,8 @@ def conv_block(input_tensor,
filters, filters,
stage, stage,
block, block,
strides=(2, 2)): strides=(2, 2),
use_l2_regularizer=True):
"""A block that has a conv layer at shortcut. """A block that has a conv layer at shortcut.
Note that from stage 3, Note that from stage 3,
...@@ -114,6 +141,7 @@ def conv_block(input_tensor, ...@@ -114,6 +141,7 @@ def conv_block(input_tensor,
stage: integer, current stage label, used for generating layer names stage: integer, current stage label, used for generating layer names
block: 'a','b'..., current block label, used for generating layer names block: 'a','b'..., current block label, used for generating layer names
strides: Strides for the second conv layer in the block. strides: Strides for the second conv layer in the block.
use_l2_regularizer: whether to use L2 regularizer on Conv layer.
Returns: Returns:
Output tensor for the block. Output tensor for the block.
...@@ -126,114 +154,231 @@ def conv_block(input_tensor, ...@@ -126,114 +154,231 @@ def conv_block(input_tensor,
conv_name_base = 'res' + str(stage) + block + '_branch' conv_name_base = 'res' + str(stage) + block + '_branch'
bn_name_base = 'bn' + str(stage) + block + '_branch' bn_name_base = 'bn' + str(stage) + block + '_branch'
x = layers.Conv2D(filters1, (1, 1), use_bias=False, x = layers.Conv2D(
kernel_initializer='he_normal', filters1, (1, 1),
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY), use_bias=False,
name=conv_name_base + '2a')(input_tensor) kernel_initializer='he_normal',
x = layers.BatchNormalization(axis=bn_axis, kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
momentum=BATCH_NORM_DECAY, name=conv_name_base + '2a')(
epsilon=BATCH_NORM_EPSILON, input_tensor)
name=bn_name_base + '2a')(x) x = layers.BatchNormalization(
axis=bn_axis,
momentum=BATCH_NORM_DECAY,
epsilon=BATCH_NORM_EPSILON,
name=bn_name_base + '2a')(
x)
x = layers.Activation('relu')(x) x = layers.Activation('relu')(x)
x = layers.Conv2D(filters2, kernel_size, strides=strides, padding='same', x = layers.Conv2D(
use_bias=False, kernel_initializer='he_normal', filters2,
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY), kernel_size,
name=conv_name_base + '2b')(x) strides=strides,
x = layers.BatchNormalization(axis=bn_axis, padding='same',
momentum=BATCH_NORM_DECAY, use_bias=False,
epsilon=BATCH_NORM_EPSILON, kernel_initializer='he_normal',
name=bn_name_base + '2b')(x) kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
name=conv_name_base + '2b')(
x)
x = layers.BatchNormalization(
axis=bn_axis,
momentum=BATCH_NORM_DECAY,
epsilon=BATCH_NORM_EPSILON,
name=bn_name_base + '2b')(
x)
x = layers.Activation('relu')(x) x = layers.Activation('relu')(x)
x = layers.Conv2D(filters3, (1, 1), use_bias=False, x = layers.Conv2D(
kernel_initializer='he_normal', filters3, (1, 1),
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY), use_bias=False,
name=conv_name_base + '2c')(x) kernel_initializer='he_normal',
x = layers.BatchNormalization(axis=bn_axis, kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
momentum=BATCH_NORM_DECAY, name=conv_name_base + '2c')(
epsilon=BATCH_NORM_EPSILON, x)
name=bn_name_base + '2c')(x) x = layers.BatchNormalization(
axis=bn_axis,
shortcut = layers.Conv2D(filters3, (1, 1), strides=strides, use_bias=False, momentum=BATCH_NORM_DECAY,
kernel_initializer='he_normal', epsilon=BATCH_NORM_EPSILON,
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY), name=bn_name_base + '2c')(
name=conv_name_base + '1')(input_tensor) x)
shortcut = layers.BatchNormalization(axis=bn_axis,
momentum=BATCH_NORM_DECAY, shortcut = layers.Conv2D(
epsilon=BATCH_NORM_EPSILON, filters3, (1, 1),
name=bn_name_base + '1')(shortcut) strides=strides,
use_bias=False,
kernel_initializer='he_normal',
kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
name=conv_name_base + '1')(
input_tensor)
shortcut = layers.BatchNormalization(
axis=bn_axis,
momentum=BATCH_NORM_DECAY,
epsilon=BATCH_NORM_EPSILON,
name=bn_name_base + '1')(
shortcut)
x = layers.add([x, shortcut]) x = layers.add([x, shortcut])
x = layers.Activation('relu')(x) x = layers.Activation('relu')(x)
return x return x
def resnet50(num_classes, dtype='float32', batch_size=None): def resnet50(num_classes,
dtype='float32',
batch_size=None,
use_l2_regularizer=True):
"""Instantiates the ResNet50 architecture. """Instantiates the ResNet50 architecture.
Args: Args:
num_classes: `int` number of classes for image classification. num_classes: `int` number of classes for image classification.
dtype: dtype to use float32 or float16 are most common. dtype: dtype to use float32 or float16 are most common.
batch_size: Size of the batches for each step. batch_size: Size of the batches for each step.
use_l2_regularizer: whether to use L2 regularizer on Conv/Dense layer.
Returns: Returns:
A Keras model instance. A Keras model instance.
""" """
input_shape = (224, 224, 3) input_shape = (224, 224, 3)
img_input = layers.Input(shape=input_shape, dtype=dtype, img_input = layers.Input(
batch_size=batch_size) shape=input_shape, dtype=dtype, batch_size=batch_size)
if backend.image_data_format() == 'channels_first': if backend.image_data_format() == 'channels_first':
x = layers.Lambda(lambda x: backend.permute_dimensions(x, (0, 3, 1, 2)), x = layers.Lambda(
name='transpose')(img_input) lambda x: backend.permute_dimensions(x, (0, 3, 1, 2)),
name='transpose')(
img_input)
bn_axis = 1 bn_axis = 1
else: # channels_last else: # channels_last
x = img_input x = img_input
bn_axis = 3 bn_axis = 3
x = layers.ZeroPadding2D(padding=(3, 3), name='conv1_pad')(x) x = layers.ZeroPadding2D(padding=(3, 3), name='conv1_pad')(x)
x = layers.Conv2D(64, (7, 7), x = layers.Conv2D(
strides=(2, 2), 64, (7, 7),
padding='valid', use_bias=False, strides=(2, 2),
kernel_initializer='he_normal', padding='valid',
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY), use_bias=False,
name='conv1')(x) kernel_initializer='he_normal',
x = layers.BatchNormalization(axis=bn_axis, kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
momentum=BATCH_NORM_DECAY, name='conv1')(
epsilon=BATCH_NORM_EPSILON, x)
name='bn_conv1')(x) x = layers.BatchNormalization(
axis=bn_axis,
momentum=BATCH_NORM_DECAY,
epsilon=BATCH_NORM_EPSILON,
name='bn_conv1')(
x)
x = layers.Activation('relu')(x) x = layers.Activation('relu')(x)
x = layers.MaxPooling2D((3, 3), strides=(2, 2), padding='same')(x) x = layers.MaxPooling2D((3, 3), strides=(2, 2), padding='same')(x)
x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1)) x = conv_block(
x = identity_block(x, 3, [64, 64, 256], stage=2, block='b') x,
x = identity_block(x, 3, [64, 64, 256], stage=2, block='c') 3, [64, 64, 256],
stage=2,
x = conv_block(x, 3, [128, 128, 512], stage=3, block='a') block='a',
x = identity_block(x, 3, [128, 128, 512], stage=3, block='b') strides=(1, 1),
x = identity_block(x, 3, [128, 128, 512], stage=3, block='c') use_l2_regularizer=use_l2_regularizer)
x = identity_block(x, 3, [128, 128, 512], stage=3, block='d') x = identity_block(
x,
x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a') 3, [64, 64, 256],
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b') stage=2,
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c') block='b',
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d') use_l2_regularizer=use_l2_regularizer)
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e') x = identity_block(
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f') x,
3, [64, 64, 256],
x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a') stage=2,
x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b') block='c',
x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c') use_l2_regularizer=use_l2_regularizer)
x = conv_block(
x,
3, [128, 128, 512],
stage=3,
block='a',
use_l2_regularizer=use_l2_regularizer)
x = identity_block(
x,
3, [128, 128, 512],
stage=3,
block='b',
use_l2_regularizer=use_l2_regularizer)
x = identity_block(
x,
3, [128, 128, 512],
stage=3,
block='c',
use_l2_regularizer=use_l2_regularizer)
x = identity_block(
x,
3, [128, 128, 512],
stage=3,
block='d',
use_l2_regularizer=use_l2_regularizer)
x = conv_block(
x,
3, [256, 256, 1024],
stage=4,
block='a',
use_l2_regularizer=use_l2_regularizer)
x = identity_block(
x,
3, [256, 256, 1024],
stage=4,
block='b',
use_l2_regularizer=use_l2_regularizer)
x = identity_block(
x,
3, [256, 256, 1024],
stage=4,
block='c',
use_l2_regularizer=use_l2_regularizer)
x = identity_block(
x,
3, [256, 256, 1024],
stage=4,
block='d',
use_l2_regularizer=use_l2_regularizer)
x = identity_block(
x,
3, [256, 256, 1024],
stage=4,
block='e',
use_l2_regularizer=use_l2_regularizer)
x = identity_block(
x,
3, [256, 256, 1024],
stage=4,
block='f',
use_l2_regularizer=use_l2_regularizer)
x = conv_block(
x,
3, [512, 512, 2048],
stage=5,
block='a',
use_l2_regularizer=use_l2_regularizer)
x = identity_block(
x,
3, [512, 512, 2048],
stage=5,
block='b',
use_l2_regularizer=use_l2_regularizer)
x = identity_block(
x,
3, [512, 512, 2048],
stage=5,
block='c',
use_l2_regularizer=use_l2_regularizer)
rm_axes = [1, 2] if backend.image_data_format() == 'channels_last' else [2, 3] rm_axes = [1, 2] if backend.image_data_format() == 'channels_last' else [2, 3]
x = layers.Lambda(lambda x: backend.mean(x, rm_axes), name='reduce_mean')(x) x = layers.Lambda(lambda x: backend.mean(x, rm_axes), name='reduce_mean')(x)
x = layers.Dense( x = layers.Dense(
num_classes, num_classes,
kernel_initializer=initializers.RandomNormal(stddev=0.01), kernel_initializer=initializers.RandomNormal(stddev=0.01),
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY), kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
bias_regularizer=regularizers.l2(L2_WEIGHT_DECAY), bias_regularizer=_gen_l2_regularizer(use_l2_regularizer),
name='fc1000')(x) name='fc1000')(
x)
# TODO(reedwm): Remove manual casts once mixed precision can be enabled with a # TODO(reedwm): Remove manual casts once mixed precision can be enabled with a
# single line of code. # single line of code.
......
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
r"""Exports an LSTM detection model to use with tf-lite. r"""Exports an LSTM detection model to use with tf-lite.
Outputs file: Outputs file:
...@@ -86,8 +85,9 @@ python lstm_object_detection/export_tflite_lstd_graph.py \ ...@@ -86,8 +85,9 @@ python lstm_object_detection/export_tflite_lstd_graph.py \
""" """
import tensorflow as tf import tensorflow as tf
from lstm_object_detection.utils import config_util
from lstm_object_detection import export_tflite_lstd_graph_lib from lstm_object_detection import export_tflite_lstd_graph_lib
from lstm_object_detection.utils import config_util
flags = tf.app.flags flags = tf.app.flags
flags.DEFINE_string('output_directory', None, 'Path to write outputs.') flags.DEFINE_string('output_directory', None, 'Path to write outputs.')
...@@ -122,12 +122,16 @@ def main(argv): ...@@ -122,12 +122,16 @@ def main(argv):
flags.mark_flag_as_required('trained_checkpoint_prefix') flags.mark_flag_as_required('trained_checkpoint_prefix')
pipeline_config = config_util.get_configs_from_pipeline_file( pipeline_config = config_util.get_configs_from_pipeline_file(
FLAGS.pipeline_config_path) FLAGS.pipeline_config_path)
export_tflite_lstd_graph_lib.export_tflite_graph( export_tflite_lstd_graph_lib.export_tflite_graph(
pipeline_config, FLAGS.trained_checkpoint_prefix, FLAGS.output_directory, pipeline_config,
FLAGS.add_postprocessing_op, FLAGS.max_detections, FLAGS.trained_checkpoint_prefix,
FLAGS.max_classes_per_detection, use_regular_nms=FLAGS.use_regular_nms) FLAGS.output_directory,
FLAGS.add_postprocessing_op,
FLAGS.max_detections,
FLAGS.max_classes_per_detection,
use_regular_nms=FLAGS.use_regular_nms)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -12,26 +12,26 @@ ...@@ -12,26 +12,26 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
r"""Exports detection models to use with tf-lite. r"""Exports detection models to use with tf-lite.
See export_tflite_lstd_graph.py for usage. See export_tflite_lstd_graph.py for usage.
""" """
import os import os
import tempfile import tempfile
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
from tensorflow.core.framework import attr_value_pb2 from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import types_pb2 from tensorflow.core.framework import types_pb2
from tensorflow.core.protobuf import saver_pb2 from tensorflow.core.protobuf import saver_pb2
from tensorflow.tools.graph_transforms import TransformGraph from tensorflow.tools.graph_transforms import TransformGraph
from lstm_object_detection import model_builder
from object_detection import exporter from object_detection import exporter
from object_detection.builders import graph_rewriter_builder from object_detection.builders import graph_rewriter_builder
from object_detection.builders import post_processing_builder from object_detection.builders import post_processing_builder
from object_detection.core import box_list from object_detection.core import box_list
from lstm_object_detection import model_builder
_DEFAULT_NUM_CHANNELS = 3 _DEFAULT_NUM_CHANNELS = 3
_DEFAULT_NUM_COORD_BOX = 4 _DEFAULT_NUM_COORD_BOX = 4
...@@ -84,11 +84,11 @@ def append_postprocessing_op(frozen_graph_def, ...@@ -84,11 +84,11 @@ def append_postprocessing_op(frozen_graph_def,
num_classes: number of classes in SSD detector num_classes: number of classes in SSD detector
scale_values: scale values is a dict with following key-value pairs scale_values: scale values is a dict with following key-value pairs
{y_scale: 10, x_scale: 10, h_scale: 5, w_scale: 5} that are used in decode {y_scale: 10, x_scale: 10, h_scale: 5, w_scale: 5} that are used in decode
centersize boxes centersize boxes
detections_per_class: In regular NonMaxSuppression, number of anchors used detections_per_class: In regular NonMaxSuppression, number of anchors used
for NonMaxSuppression per class for NonMaxSuppression per class
use_regular_nms: Flag to set postprocessing op to use Regular NMS instead use_regular_nms: Flag to set postprocessing op to use Regular NMS instead of
of Fast NMS. Fast NMS.
Returns: Returns:
transformed_graph_def: Frozen GraphDef with postprocessing custom op transformed_graph_def: Frozen GraphDef with postprocessing custom op
...@@ -165,9 +165,9 @@ def export_tflite_graph(pipeline_config, ...@@ -165,9 +165,9 @@ def export_tflite_graph(pipeline_config,
is written to output_dir/tflite_graph.pb. is written to output_dir/tflite_graph.pb.
Args: Args:
pipeline_config: Dictionary of configuration objects. Keys are `model`, `train_config`, pipeline_config: Dictionary of configuration objects. Keys are `model`,
`train_input_config`, `eval_config`, `eval_input_config`, `lstm_model`. `train_config`, `train_input_config`, `eval_config`, `eval_input_config`,
Value are the corresponding config objects. `lstm_model`. Value are the corresponding config objects.
trained_checkpoint_prefix: a file prefix for the checkpoint containing the trained_checkpoint_prefix: a file prefix for the checkpoint containing the
trained parameters of the SSD model. trained parameters of the SSD model.
output_dir: A directory to write the tflite graph and anchor file to. output_dir: A directory to write the tflite graph and anchor file to.
...@@ -176,9 +176,9 @@ def export_tflite_graph(pipeline_config, ...@@ -176,9 +176,9 @@ def export_tflite_graph(pipeline_config,
max_detections: Maximum number of detections (boxes) to show max_detections: Maximum number of detections (boxes) to show
max_classes_per_detection: Number of classes to display per detection max_classes_per_detection: Number of classes to display per detection
detections_per_class: In regular NonMaxSuppression, number of anchors used detections_per_class: In regular NonMaxSuppression, number of anchors used
for NonMaxSuppression per class for NonMaxSuppression per class
use_regular_nms: Flag to set postprocessing op to use Regular NMS instead use_regular_nms: Flag to set postprocessing op to use Regular NMS instead of
of Fast NMS. Fast NMS.
binary_graph_name: Name of the exported graph file in binary format. binary_graph_name: Name of the exported graph file in binary format.
txt_graph_name: Name of the exported graph file in text format. txt_graph_name: Name of the exported graph file in text format.
...@@ -197,12 +197,10 @@ def export_tflite_graph(pipeline_config, ...@@ -197,12 +197,10 @@ def export_tflite_graph(pipeline_config,
num_classes = model_config.ssd.num_classes num_classes = model_config.ssd.num_classes
nms_score_threshold = { nms_score_threshold = {
model_config.ssd.post_processing.batch_non_max_suppression. model_config.ssd.post_processing.batch_non_max_suppression.score_threshold
score_threshold
} }
nms_iou_threshold = { nms_iou_threshold = {
model_config.ssd.post_processing.batch_non_max_suppression. model_config.ssd.post_processing.batch_non_max_suppression.iou_threshold
iou_threshold
} }
scale_values = {} scale_values = {}
scale_values['y_scale'] = { scale_values['y_scale'] = {
...@@ -226,7 +224,7 @@ def export_tflite_graph(pipeline_config, ...@@ -226,7 +224,7 @@ def export_tflite_graph(pipeline_config,
width = image_resizer_config.fixed_shape_resizer.width width = image_resizer_config.fixed_shape_resizer.width
if image_resizer_config.fixed_shape_resizer.convert_to_grayscale: if image_resizer_config.fixed_shape_resizer.convert_to_grayscale:
num_channels = 1 num_channels = 1
#TODO(richardbrks) figure out how to make with a None defined batch size
shape = [lstm_config.eval_unroll_length, height, width, num_channels] shape = [lstm_config.eval_unroll_length, height, width, num_channels]
else: else:
raise ValueError( raise ValueError(
...@@ -235,14 +233,14 @@ def export_tflite_graph(pipeline_config, ...@@ -235,14 +233,14 @@ def export_tflite_graph(pipeline_config,
image_resizer_config.WhichOneof('image_resizer_oneof'))) image_resizer_config.WhichOneof('image_resizer_oneof')))
video_tensor = tf.placeholder( video_tensor = tf.placeholder(
tf.float32, shape=shape, name='input_video_tensor') tf.float32, shape=shape, name='input_video_tensor')
detection_model = model_builder.build(model_config, lstm_config, detection_model = model_builder.build(
is_training=False) model_config, lstm_config, is_training=False)
preprocessed_video, true_image_shapes = detection_model.preprocess( preprocessed_video, true_image_shapes = detection_model.preprocess(
tf.to_float(video_tensor)) tf.to_float(video_tensor))
predicted_tensors = detection_model.predict(preprocessed_video, predicted_tensors = detection_model.predict(preprocessed_video,
true_image_shapes) true_image_shapes)
# predicted_tensors = detection_model.postprocess(predicted_tensors, # predicted_tensors = detection_model.postprocess(predicted_tensors,
# true_image_shapes) # true_image_shapes)
# The score conversion occurs before the post-processing custom op # The score conversion occurs before the post-processing custom op
...@@ -311,7 +309,7 @@ def export_tflite_graph(pipeline_config, ...@@ -311,7 +309,7 @@ def export_tflite_graph(pipeline_config,
initializer_nodes='') initializer_nodes='')
# Add new operation to do post processing in a custom op (TF Lite only) # Add new operation to do post processing in a custom op (TF Lite only)
#(richardbrks) Do use this or detection_model.postprocess?
if add_postprocessing_op: if add_postprocessing_op:
transformed_graph_def = append_postprocessing_op( transformed_graph_def = append_postprocessing_op(
frozen_graph_def, max_detections, max_classes_per_detection, frozen_graph_def, max_detections, max_classes_per_detection,
......
...@@ -13,6 +13,8 @@ ...@@ -13,6 +13,8 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Export a LSTD model in tflite format."""
import os import os
from absl import flags from absl import flags
import tensorflow as tf import tensorflow as tf
...@@ -29,34 +31,35 @@ FLAGS = flags.FLAGS ...@@ -29,34 +31,35 @@ FLAGS = flags.FLAGS
def main(_):
  """Converts a frozen LSTD detection graph into a TFLite flatbuffer.

  Reads the LSTM model config from `FLAGS.pipeline_config_path`, converts the
  frozen graph at `FLAGS.frozen_graph_path` (custom ops allowed, so the
  `TFLite_Detection_PostProcess` outputs can be kept), and writes the
  serialized model to `FLAGS.export_path`.

  Args:
    _: Unused; positional arguments handed over by the app runner.
  """
  # NOTE(review): these are registered inside main, presumably after the app
  # runner has already parsed the command line, so they may not reject a
  # missing flag at startup -- confirm, and consider moving them to the
  # `__main__` guard.
  flags.mark_flag_as_required('export_path')
  flags.mark_flag_as_required('frozen_graph_path')
  flags.mark_flag_as_required('pipeline_config_path')

  configs = config_util.get_configs_from_pipeline_file(
      FLAGS.pipeline_config_path)
  lstm_config = configs['lstm_model']

  input_arrays = ['input_video_tensor']
  # The four outputs of the detection post-processing custom op.
  output_arrays = [
      'TFLite_Detection_PostProcess',
      'TFLite_Detection_PostProcess:1',
      'TFLite_Detection_PostProcess:2',
      'TFLite_Detection_PostProcess:3',
  ]
  # Only the leading (unroll length) dimension comes from the config; the
  # 320x320x3 frame shape is fixed here.
  input_shapes = {
      'input_video_tensor': [lstm_config.eval_unroll_length, 320, 320, 3],
  }

  converter = tf.lite.TFLiteConverter.from_frozen_graph(
      FLAGS.frozen_graph_path,
      input_arrays,
      output_arrays,
      input_shapes=input_shapes)
  converter.allow_custom_ops = True
  tflite_model = converter.convert()

  # Write through a context manager so the handle is flushed and closed
  # deterministically instead of being left to garbage collection. The former
  # single-argument os.path.join() was a no-op, so the flag value is used
  # directly.
  with open(FLAGS.export_path, 'wb') as output_file:
    output_file.write(tflite_model)


if __name__ == '__main__':
  tf.app.run()
...@@ -59,12 +59,19 @@ cc_library( ...@@ -59,12 +59,19 @@ cc_library(
name = "mobile_lstd_tflite_client", name = "mobile_lstd_tflite_client",
srcs = ["mobile_lstd_tflite_client.cc"], srcs = ["mobile_lstd_tflite_client.cc"],
hdrs = ["mobile_lstd_tflite_client.h"], hdrs = ["mobile_lstd_tflite_client.h"],
defines = select({
"//conditions:default": [],
"enable_edgetpu": ["ENABLE_EDGETPU"],
}),
deps = [ deps = [
":mobile_ssd_client", ":mobile_ssd_client",
":mobile_ssd_tflite_client", ":mobile_ssd_tflite_client",
"@com_google_absl//absl/base:core_headers",
"@com_google_glog//:glog", "@com_google_glog//:glog",
"@com_google_absl//absl/base:core_headers",
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops", "@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
], ] + select({
"//conditions:default": [],
"enable_edgetpu": ["@libedgetpu//libedgetpu:header"],
}),
alwayslink = 1, alwayslink = 1,
) )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment