Commit 901c4cc4 authored by Vinh Nguyen

Merge remote-tracking branch 'upstream/master' into amp_resnet50

parents ef30de93 824ff2d6
......@@ -22,7 +22,7 @@ import os
from absl import logging
import tensorflow as tf
from official.resnet.keras import imagenet_preprocessing
from official.vision.image_classification import imagenet_preprocessing
HEIGHT = 32
WIDTH = 32
......
......@@ -20,17 +20,13 @@ from __future__ import print_function
import multiprocessing
import os
import numpy as np
# pylint: disable=g-bad-import-order
from absl import flags
import numpy as np
import tensorflow as tf
from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_v2
from official.utils.flags import core as flags_core
from official.utils.misc import keras_utils
# pylint: disable=ungrouped-imports
from tensorflow.python.keras.optimizer_v2 import (gradient_descent as
gradient_descent_v2)
FLAGS = flags.FLAGS
BASE_LEARNING_RATE = 0.1 # This matches Jing's version.
......@@ -262,6 +258,7 @@ def define_keras_flags(dynamic_loss_scale=True):
force_v2_in_keras_compile=True)
flags_core.define_image()
flags_core.define_benchmark()
flags_core.define_distribution()
flags.adopt_module_key_flags(flags_core)
flags.DEFINE_boolean(name='enable_eager', default=False, help='Enable eager?')
......
......@@ -12,21 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the keras_common module."""
"""Tests for the common module."""
from __future__ import absolute_import
from __future__ import print_function
from mock import Mock
import numpy as np
import tensorflow as tf # pylint: disable=g-bad-import-order
from tensorflow.python.platform import googletest
import tensorflow as tf
from official.resnet.keras import keras_common
from tensorflow.python.platform import googletest
from official.utils.misc import keras_utils
from official.vision.image_classification import common
class KerasCommonTests(tf.test.TestCase):
"""Tests for keras_common."""
"""Tests for common."""
@classmethod
def setUpClass(cls): # pylint: disable=invalid-name
......@@ -42,7 +42,7 @@ class KerasCommonTests(tf.test.TestCase):
keras_utils.BatchTimestamp(1, 2),
keras_utils.BatchTimestamp(2, 3)]
th.train_finish_time = 12345
stats = keras_common.build_stats(history, eval_output, [th])
stats = common.build_stats(history, eval_output, [th])
self.assertEqual(1.145, stats['loss'])
self.assertEqual(.99988, stats['training_accuracy_top_1'])
......@@ -57,7 +57,7 @@ class KerasCommonTests(tf.test.TestCase):
history = self._build_history(1.145, cat_accuracy_sparse=.99988)
eval_output = self._build_eval_output(.928, 1.9844)
stats = keras_common.build_stats(history, eval_output, None)
stats = common.build_stats(history, eval_output, None)
self.assertEqual(1.145, stats['loss'])
self.assertEqual(.99988, stats['training_accuracy_top_1'])
......
......@@ -22,13 +22,13 @@ from absl import app as absl_app
from absl import flags
import tensorflow as tf
from official.resnet.keras import cifar_preprocessing
from official.resnet.keras import keras_common
from official.resnet.keras import resnet_cifar_model
from official.utils.flags import core as flags_core
from official.utils.logs import logger
from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils
from official.vision.image_classification import cifar_preprocessing
from official.vision.image_classification import common
from official.vision.image_classification import resnet_cifar_model
LR_SCHEDULE = [ # (multiplier, epoch to start) tuples
......@@ -55,7 +55,7 @@ def learning_rate_schedule(current_epoch,
Adjusted learning rate.
"""
del current_batch, batches_per_epoch # not used
initial_learning_rate = keras_common.BASE_LEARNING_RATE * batch_size / 128
initial_learning_rate = common.BASE_LEARNING_RATE * batch_size / 128
learning_rate = initial_learning_rate
for mult, start_epoch in LR_SCHEDULE:
if current_epoch >= start_epoch:
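For context, here is a minimal, self-contained sketch of the piecewise schedule this hunk truncates. The real LR_SCHEDULE entries and the loop body are elided above, so the multiplier table and the `else: break` below are assumptions, not the file's values:

```python
# Hedged sketch of the piecewise schedule above; the table is illustrative.
BASE_LEARNING_RATE = 0.1
LR_SCHEDULE = [(1.0, 0), (0.1, 91), (0.01, 136)]  # hypothetical (multiplier, start epoch)

def learning_rate_schedule(current_epoch, batch_size):
  # Scale the base rate linearly with batch size (reference batch size 128).
  initial_learning_rate = BASE_LEARNING_RATE * batch_size / 128
  learning_rate = initial_learning_rate
  for mult, start_epoch in LR_SCHEDULE:
    if current_epoch >= start_epoch:
      learning_rate = initial_learning_rate * mult
    else:
      break
  return learning_rate

print(learning_rate_schedule(current_epoch=100, batch_size=128))  # -> 0.01
```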
......@@ -83,8 +83,8 @@ def run(flags_obj):
# Execute flag override logic for better model performance
if flags_obj.tf_gpu_thread_mode:
keras_common.set_gpu_thread_mode_and_count(flags_obj)
keras_common.set_cudnn_batchnorm_mode()
common.set_gpu_thread_mode_and_count(flags_obj)
common.set_cudnn_batchnorm_mode()
dtype = flags_core.get_tf_dtype(flags_obj)
if dtype == 'fp16':
......@@ -116,7 +116,7 @@ def run(flags_obj):
if flags_obj.use_synthetic_data:
distribution_utils.set_up_synthetic_data()
input_fn = keras_common.get_synth_input_fn(
input_fn = common.get_synth_input_fn(
height=cifar_preprocessing.HEIGHT,
width=cifar_preprocessing.WIDTH,
num_channels=cifar_preprocessing.NUM_CHANNELS,
......@@ -150,7 +150,7 @@ def run(flags_obj):
parse_record_fn=cifar_preprocessing.parse_record)
with strategy_scope:
optimizer = keras_common.get_optimizer()
optimizer = common.get_optimizer()
model = resnet_cifar_model.resnet56(classes=cifar_preprocessing.NUM_CLASSES)
# TODO(b/138957587): Remove when force_v2_in_keras_compile is no longer
......@@ -171,7 +171,7 @@ def run(flags_obj):
if flags_obj.report_accuracy_metrics else None),
run_eagerly=flags_obj.run_eagerly)
callbacks = keras_common.get_callbacks(
callbacks = common.get_callbacks(
learning_rate_schedule, cifar_preprocessing.NUM_IMAGES['train'])
train_steps = cifar_preprocessing.NUM_IMAGES['train'] // flags_obj.batch_size
......@@ -216,12 +216,12 @@ def run(flags_obj):
if not strategy and flags_obj.explicit_gpu_placement:
no_dist_strat_device.__exit__()
stats = keras_common.build_stats(history, eval_output, callbacks)
stats = common.build_stats(history, eval_output, callbacks)
return stats
def define_cifar_flags():
keras_common.define_keras_flags(dynamic_loss_scale=False)
common.define_keras_flags(dynamic_loss_scale=False)
flags_core.set_defaults(data_dir='/tmp/cifar10_data/cifar-10-batches-bin',
model_dir='/tmp/cifar10_model',
......
......@@ -18,17 +18,16 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tempfile import mkdtemp
import tempfile
import tensorflow as tf
from official.resnet.keras import cifar_preprocessing
from official.resnet.keras import keras_cifar_main
from official.resnet.keras import keras_common
from official.utils.misc import keras_utils
from official.utils.testing import integration
# pylint: disable=ungrouped-imports
from tensorflow.python.eager import context
from tensorflow.python.platform import googletest
from official.utils.misc import keras_utils
from official.utils.testing import integration
from official.vision.image_classification import cifar_preprocessing
from official.vision.image_classification import resnet_cifar_main
class KerasCifarTest(googletest.TestCase):
......@@ -43,13 +42,13 @@ class KerasCifarTest(googletest.TestCase):
def get_temp_dir(self):
if not self._tempdir:
self._tempdir = mkdtemp(dir=googletest.GetTempDir())
self._tempdir = tempfile.mkdtemp(dir=googletest.GetTempDir())
return self._tempdir
@classmethod
def setUpClass(cls): # pylint: disable=invalid-name
super(KerasCifarTest, cls).setUpClass()
keras_cifar_main.define_cifar_flags()
resnet_cifar_main.define_cifar_flags()
def setUp(self):
super(KerasCifarTest, self).setUp()
......@@ -72,7 +71,7 @@ class KerasCifarTest(googletest.TestCase):
extra_flags = extra_flags + self._extra_flags
integration.run_synthetic(
main=keras_cifar_main.run,
main=resnet_cifar_main.run,
tmp_root=self.get_temp_dir(),
extra_flags=extra_flags
)
......@@ -88,7 +87,7 @@ class KerasCifarTest(googletest.TestCase):
extra_flags = extra_flags + self._extra_flags
integration.run_synthetic(
main=keras_cifar_main.run,
main=resnet_cifar_main.run,
tmp_root=self.get_temp_dir(),
extra_flags=extra_flags
)
......@@ -112,7 +111,7 @@ class KerasCifarTest(googletest.TestCase):
extra_flags = extra_flags + self._extra_flags
integration.run_synthetic(
main=keras_cifar_main.run,
main=resnet_cifar_main.run,
tmp_root=self.get_temp_dir(),
extra_flags=extra_flags
)
......@@ -134,7 +133,7 @@ class KerasCifarTest(googletest.TestCase):
extra_flags = extra_flags + self._extra_flags
integration.run_synthetic(
main=keras_cifar_main.run,
main=resnet_cifar_main.run,
tmp_root=self.get_temp_dir(),
extra_flags=extra_flags
)
......@@ -157,7 +156,7 @@ class KerasCifarTest(googletest.TestCase):
extra_flags = extra_flags + self._extra_flags
integration.run_synthetic(
main=keras_cifar_main.run,
main=resnet_cifar_main.run,
tmp_root=self.get_temp_dir(),
extra_flags=extra_flags
)
......@@ -178,7 +177,7 @@ class KerasCifarTest(googletest.TestCase):
extra_flags = extra_flags + self._extra_flags
integration.run_synthetic(
main=keras_cifar_main.run,
main=resnet_cifar_main.run,
tmp_root=self.get_temp_dir(),
extra_flags=extra_flags
)
......
......@@ -21,17 +21,17 @@ from __future__ import print_function
from absl import app as absl_app
from absl import flags
from absl import logging
import tensorflow as tf # pylint: disable=g-bad-import-order
import tensorflow as tf
from official.resnet.keras import imagenet_preprocessing
from official.resnet.keras import keras_common
from official.resnet.keras import resnet_model
from official.resnet.keras import trivial_model
from official.utils.flags import core as flags_core
from official.utils.logs import logger
from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils
from official.utils.misc import model_helpers
from official.vision.image_classification import common
from official.vision.image_classification import imagenet_preprocessing
from official.vision.image_classification import resnet_model
from official.vision.image_classification import trivial_model
LR_SCHEDULE = [ # (multiplier, epoch to start) tuples
......@@ -57,7 +57,7 @@ def learning_rate_schedule(current_epoch,
Returns:
Adjusted learning rate.
"""
initial_lr = keras_common.BASE_LEARNING_RATE * batch_size / 256
initial_lr = common.BASE_LEARNING_RATE * batch_size / 256
epoch = current_epoch + float(current_batch) / batches_per_epoch
warmup_lr_multiplier, warmup_end_epoch = LR_SCHEDULE[0]
if epoch < warmup_end_epoch:
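The warmup branch body falls outside this hunk; below is a hedged, self-contained sketch of how such a linear warmup typically completes (the default multiplier and warmup length are illustrative, not read from the file):

```python
def warmup_learning_rate(current_epoch, current_batch, batches_per_epoch,
                         batch_size, base_lr=0.1, warmup_lr_multiplier=1.0,
                         warmup_end_epoch=5):
  """Sketch of the truncated warmup branch; defaults are illustrative."""
  initial_lr = base_lr * batch_size / 256
  epoch = current_epoch + float(current_batch) / batches_per_epoch
  if epoch < warmup_end_epoch:
    # Ramp linearly from 0 up to the post-warmup starting rate.
    return initial_lr * warmup_lr_multiplier * epoch / warmup_end_epoch
  return initial_lr

print(warmup_learning_rate(2, 0, 100, 256))  # 2/5 of the way through warmup -> 0.04
```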
......@@ -89,10 +89,10 @@ def run(flags_obj):
# Execute flag override logic for better model performance
if flags_obj.tf_gpu_thread_mode:
keras_common.set_gpu_thread_mode_and_count(flags_obj)
common.set_gpu_thread_mode_and_count(flags_obj)
if flags_obj.data_delay_prefetch:
keras_common.data_delay_prefetch()
keras_common.set_cudnn_batchnorm_mode()
common.data_delay_prefetch()
common.set_cudnn_batchnorm_mode()
dtype = flags_core.get_tf_dtype(flags_obj)
if dtype == 'float16':
......@@ -105,10 +105,14 @@ def run(flags_obj):
if tf.test.is_built_with_cuda() else 'channels_last')
tf.keras.backend.set_image_data_format(data_format)
# Configures cluster spec for distribution strategy.
num_workers = distribution_utils.configure_cluster(flags_obj.worker_hosts,
flags_obj.task_index)
strategy = distribution_utils.get_distribution_strategy(
distribution_strategy=flags_obj.distribution_strategy,
num_gpus=flags_obj.num_gpus,
num_workers=distribution_utils.configure_cluster(),
num_workers=num_workers,
all_reduce_alg=flags_obj.all_reduce_alg,
num_packs=flags_obj.num_packs)
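A hedged invocation example for the new cluster flags; `worker_hosts` and `task_index` come from the diff above, while the host addresses and strategy name are illustrative:

```bash
# Worker 0 of a two-worker job; worker 1 runs the same command with
# --task_index=1. Host addresses are placeholders.
python official/vision/image_classification/resnet_imagenet_main.py \
  --distribution_strategy=multi_worker_mirrored \
  --worker_hosts=10.0.0.1:5000,10.0.0.2:5000 \
  --task_index=0 \
  --num_gpus=1
```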
......@@ -125,7 +129,7 @@ def run(flags_obj):
# pylint: disable=protected-access
if flags_obj.use_synthetic_data:
distribution_utils.set_up_synthetic_data()
input_fn = keras_common.get_synth_input_fn(
input_fn = common.get_synth_input_fn(
height=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
width=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
num_channels=imagenet_preprocessing.NUM_CHANNELS,
......@@ -165,7 +169,7 @@ def run(flags_obj):
lr_schedule = 0.1
if flags_obj.use_tensor_lr:
lr_schedule = keras_common.PiecewiseConstantDecayWithWarmup(
lr_schedule = common.PiecewiseConstantDecayWithWarmup(
batch_size=flags_obj.batch_size,
epoch_size=imagenet_preprocessing.NUM_IMAGES['train'],
warmup_epochs=LR_SCHEDULE[0][1],
......@@ -174,7 +178,7 @@ def run(flags_obj):
compute_lr_on_cpu=True)
with strategy_scope:
optimizer = keras_common.get_optimizer(lr_schedule)
optimizer = common.get_optimizer(lr_schedule)
if dtype == 'float16':
# TODO(reedwm): Remove manually wrapping optimizer once mixed precision
# can be enabled with a single line of code.
......@@ -212,7 +216,7 @@ def run(flags_obj):
if flags_obj.report_accuracy_metrics else None),
run_eagerly=flags_obj.run_eagerly)
callbacks = keras_common.get_callbacks(
callbacks = common.get_callbacks(
learning_rate_schedule, imagenet_preprocessing.NUM_IMAGES['train'])
train_steps = (
......@@ -262,13 +266,14 @@ def run(flags_obj):
if not strategy and flags_obj.explicit_gpu_placement:
no_dist_strat_device.__exit__()
stats = keras_common.build_stats(history, eval_output, callbacks)
stats = common.build_stats(history, eval_output, callbacks)
return stats
def define_imagenet_keras_flags():
keras_common.define_keras_flags()
common.define_keras_flags()
flags_core.set_defaults(train_epochs=90)
flags.adopt_module_key_flags(common)
def main(_):
......
......@@ -18,16 +18,16 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tempfile import mkdtemp
import tempfile
import tensorflow as tf
from official.resnet.keras import imagenet_preprocessing
from official.resnet.keras import keras_imagenet_main
from official.utils.misc import keras_utils
from official.utils.testing import integration
# pylint: disable=ungrouped-imports
from tensorflow.python.eager import context
from tensorflow.python.platform import googletest
from official.utils.misc import keras_utils
from official.utils.testing import integration
from official.vision.image_classification import imagenet_preprocessing
from official.vision.image_classification import resnet_imagenet_main
class KerasImagenetTest(googletest.TestCase):
......@@ -42,13 +42,13 @@ class KerasImagenetTest(googletest.TestCase):
def get_temp_dir(self):
if not self._tempdir:
self._tempdir = mkdtemp(dir=googletest.GetTempDir())
self._tempdir = tempfile.mkdtemp(dir=googletest.GetTempDir())
return self._tempdir
@classmethod
def setUpClass(cls): # pylint: disable=invalid-name
super(KerasImagenetTest, cls).setUpClass()
keras_imagenet_main.define_imagenet_keras_flags()
resnet_imagenet_main.define_imagenet_keras_flags()
def setUp(self):
super(KerasImagenetTest, self).setUp()
......@@ -71,7 +71,7 @@ class KerasImagenetTest(googletest.TestCase):
extra_flags = extra_flags + self._extra_flags
integration.run_synthetic(
main=keras_imagenet_main.run,
main=resnet_imagenet_main.run,
tmp_root=self.get_temp_dir(),
extra_flags=extra_flags
)
......@@ -87,7 +87,7 @@ class KerasImagenetTest(googletest.TestCase):
extra_flags = extra_flags + self._extra_flags
integration.run_synthetic(
main=keras_imagenet_main.run,
main=resnet_imagenet_main.run,
tmp_root=self.get_temp_dir(),
extra_flags=extra_flags
)
......@@ -111,7 +111,7 @@ class KerasImagenetTest(googletest.TestCase):
extra_flags = extra_flags + self._extra_flags
integration.run_synthetic(
main=keras_imagenet_main.run,
main=resnet_imagenet_main.run,
tmp_root=self.get_temp_dir(),
extra_flags=extra_flags
)
......@@ -133,7 +133,7 @@ class KerasImagenetTest(googletest.TestCase):
extra_flags = extra_flags + self._extra_flags
integration.run_synthetic(
main=keras_imagenet_main.run,
main=resnet_imagenet_main.run,
tmp_root=self.get_temp_dir(),
extra_flags=extra_flags
)
......@@ -156,7 +156,7 @@ class KerasImagenetTest(googletest.TestCase):
extra_flags = extra_flags + self._extra_flags
integration.run_synthetic(
main=keras_imagenet_main.run,
main=resnet_imagenet_main.run,
tmp_root=self.get_temp_dir(),
extra_flags=extra_flags
)
......@@ -180,7 +180,7 @@ class KerasImagenetTest(googletest.TestCase):
extra_flags = extra_flags + self._extra_flags
integration.run_synthetic(
main=keras_imagenet_main.run,
main=resnet_imagenet_main.run,
tmp_root=self.get_temp_dir(),
extra_flags=extra_flags
)
......@@ -204,7 +204,7 @@ class KerasImagenetTest(googletest.TestCase):
extra_flags = extra_flags + self._extra_flags
integration.run_synthetic(
main=keras_imagenet_main.run,
main=resnet_imagenet_main.run,
tmp_root=self.get_temp_dir(),
extra_flags=extra_flags
)
......@@ -229,7 +229,7 @@ class KerasImagenetTest(googletest.TestCase):
extra_flags = extra_flags + self._extra_flags
integration.run_synthetic(
main=keras_imagenet_main.run,
main=resnet_imagenet_main.run,
tmp_root=self.get_temp_dir(),
extra_flags=extra_flags
)
......@@ -250,7 +250,7 @@ class KerasImagenetTest(googletest.TestCase):
extra_flags = extra_flags + self._extra_flags
integration.run_synthetic(
main=keras_imagenet_main.run,
main=resnet_imagenet_main.run,
tmp_root=self.get_temp_dir(),
extra_flags=extra_flags
)
......@@ -272,7 +272,7 @@ class KerasImagenetTest(googletest.TestCase):
extra_flags = extra_flags + self._extra_flags
integration.run_synthetic(
main=keras_imagenet_main.run,
main=resnet_imagenet_main.run,
tmp_root=self.get_temp_dir(),
extra_flags=extra_flags
)
......
......@@ -32,3 +32,8 @@ https://scholar.googleusercontent.com/scholar.bib?q=info:rLqvkztmWYgJ:scholar.go
* yinxiao@google.com
* menglong@google.com
* yongzhe@google.com
## Table of Contents
* <a href='g3doc/exporting_models.md'>Exporting a trained model</a>
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# For training on Imagenet Video with LSTM Interleaved Mobilenet V2
[lstm_object_detection.protos.lstm_model] {
train_unroll_length: 4
eval_unroll_length: 4
lstm_state_depth: 320
depth_multipliers: 1.4
depth_multipliers: 0.35
pre_bottleneck: true
low_res: true
train_interleave_method: 'RANDOM_SKIP_SMALL'
eval_interleave_method: 'SKIP3'
}
model {
ssd {
num_classes: 30 # Number of classes in the ImageNet VID dataset.
box_coder {
faster_rcnn_box_coder {
y_scale: 10.0
x_scale: 10.0
height_scale: 5.0
width_scale: 5.0
}
}
matcher {
argmax_matcher {
matched_threshold: 0.5
unmatched_threshold: 0.5
ignore_thresholds: false
negatives_lower_than_unmatched: true
force_match_for_each_row: true
}
}
similarity_calculator {
iou_similarity {
}
}
anchor_generator {
ssd_anchor_generator {
num_layers: 5
min_scale: 0.2
max_scale: 0.95
aspect_ratios: 1.0
aspect_ratios: 2.0
aspect_ratios: 0.5
aspect_ratios: 3.0
aspect_ratios: 0.3333
}
}
image_resizer {
fixed_shape_resizer {
height: 320
width: 320
}
}
box_predictor {
convolutional_box_predictor {
min_depth: 0
max_depth: 0
num_layers_before_predictor: 3
use_dropout: false
dropout_keep_probability: 0.8
kernel_size: 3
box_code_size: 4
apply_sigmoid_to_scores: false
use_depthwise: true
conv_hyperparams {
activation: RELU_6,
regularizer {
l2_regularizer {
weight: 0.00004
}
}
initializer {
truncated_normal_initializer {
stddev: 0.03
mean: 0.0
}
}
batch_norm {
train: true,
scale: true,
center: true,
decay: 0.9997,
epsilon: 0.001,
}
}
}
}
feature_extractor {
type: 'lstm_ssd_interleaved_mobilenet_v2'
conv_hyperparams {
activation: RELU_6,
regularizer {
l2_regularizer {
weight: 0.00004
}
}
initializer {
truncated_normal_initializer {
stddev: 0.03
mean: 0.0
}
}
batch_norm {
train: true,
scale: true,
center: true,
decay: 0.9997,
epsilon: 0.001,
}
}
}
loss {
classification_loss {
weighted_sigmoid {
}
}
localization_loss {
weighted_smooth_l1 {
}
}
hard_example_miner {
num_hard_examples: 3000
iou_threshold: 0.99
loss_type: CLASSIFICATION
max_negatives_per_positive: 3
min_negatives_per_image: 0
}
classification_weight: 1.0
localization_weight: 4.0
}
normalize_loss_by_num_matches: true
post_processing {
batch_non_max_suppression {
score_threshold: -20.0
iou_threshold: 0.5
max_detections_per_class: 100
max_total_detections: 100
}
score_converter: SIGMOID
}
}
}
train_config: {
batch_size: 8
optimizer {
use_moving_average: false
rms_prop_optimizer: {
learning_rate: {
exponential_decay_learning_rate {
initial_learning_rate: 0.002
decay_steps: 200000
decay_factor: 0.95
}
}
momentum_optimizer_value: 0.9
decay: 0.9
epsilon: 1.0
}
}
gradient_clipping_by_norm: 10.0
batch_queue_capacity: 12
prefetch_queue_capacity: 4
}
train_input_reader: {
shuffle_buffer_size: 32
queue_capacity: 12
prefetch_size: 12
min_after_dequeue: 4
label_map_path: "path/to/label_map"
external_input_reader {
[lstm_object_detection.protos.GoogleInputReader.google_input_reader] {
tf_record_video_input_reader: {
input_path: '/data/lstm_detection/tfrecords/test.tfrecord'
data_type: TF_SEQUENCE_EXAMPLE
video_length: 4
}
}
}
}
eval_config: {
metrics_set: "coco_evaluation_all_frames"
use_moving_averages: true
min_score_threshold: 0.5
max_num_boxes_to_visualize: 300
visualize_groundtruth_boxes: true
groundtruth_box_visualization_color: "red"
}
eval_input_reader {
label_map_path: "path/to/label_map"
shuffle: true
num_epochs: 1
num_parallel_batches: 1
num_readers: 1
external_input_reader {
[lstm_object_detection.protos.GoogleInputReader.google_input_reader] {
tf_record_video_input_reader: {
input_path: "path/to/sequence_example/data"
data_type: TF_SEQUENCE_EXAMPLE
video_length: 10
}
}
}
}
eval_input_reader: {
label_map_path: "path/to/label_map"
external_input_reader {
[lstm_object_detection.protos.GoogleInputReader.google_input_reader] {
tf_record_video_input_reader: {
input_path: "path/to/sequence_example/data"
data_type: TF_SEQUENCE_EXAMPLE
video_length: 4
}
}
}
shuffle: true
num_readers: 1
}
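A minimal sketch of loading this pipeline config programmatically, using the same `config_util` helper that the export scripts later in this commit rely on (the config filename is a placeholder):

```python
from lstm_object_detection.utils import config_util

# Load the full pipeline proto into a dict of config objects.
configs = config_util.get_configs_from_pipeline_file(
    'path/to/lstm_ssd_interleaved_mobilenet_v2.config')
lstm_config = configs['lstm_model']
model_config = configs['model']

print(lstm_config.eval_unroll_length)   # 4, per the config above
print(model_config.ssd.num_classes)     # 30
```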
......@@ -27,8 +27,6 @@ import functools
import os
import tensorflow as tf
from google.protobuf import text_format
from google3.pyglib import app
from google3.pyglib import flags
from lstm_object_detection import evaluator
from lstm_object_detection import model_builder
from lstm_object_detection.inputs import seq_dataset_builder
......@@ -107,4 +105,4 @@ def main(unused_argv):
FLAGS.checkpoint_dir, FLAGS.eval_dir)
if __name__ == '__main__':
app.run()
tf.app.run()
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Exports an LSTM detection model to use with tf-lite.
Output file:
* A tflite-compatible frozen graph - $output_directory/tflite_graph.pb
The exported graph has the following input and output nodes.
Inputs:
'input_video_tensor': a float32 tensor of shape
[unroll_length, height, width, 3] containing the normalized input image.
Note that the height and width must be compatible with the height and
width configured in the fixed_shape_resizer options in the pipeline
config proto.
Outputs:
If add_postprocessing_op is true, the frozen graph adds a
TFLite_Detection_PostProcess custom op node with four outputs:
detection_boxes: a float32 tensor of shape [1, num_boxes, 4] with box
locations
detection_classes: a float32 tensor of shape [1, num_boxes]
with class indices
detection_scores: a float32 tensor of shape [1, num_boxes]
with class scores
num_boxes: a float32 tensor of size 1 containing the number of detected boxes
else:
the graph has three outputs:
'raw_outputs/box_encodings': a float32 tensor of shape [1, num_anchors, 4]
containing the encoded box predictions.
'raw_outputs/class_predictions': a float32 tensor of shape
[1, num_anchors, num_classes] containing the class scores for each anchor
after applying score conversion.
'anchors': a float32 constant tensor of shape [num_anchors, 4]
containing the anchor boxes.
Example Usage:
--------------
python lstm_object_detection/export_tflite_lstd_graph.py \
--pipeline_config_path path/to/lstm_pipeline.config \
--trained_checkpoint_prefix path/to/model.ckpt \
--output_directory path/to/exported_model_directory
The expected output would be in the directory
path/to/exported_model_directory (which is created if it does not exist)
with contents:
- tflite_graph.pbtxt
- tflite_graph.pb
Config overrides (see the `config_override` flag) are text protobufs
(also of type pipeline_pb2.TrainEvalPipelineConfig) which are used to override
certain fields in the provided pipeline_config_path. These are useful for
making small changes to the inference graph that differ from the training or
eval config.
Example Usage (in which we change the NMS iou_threshold to be 0.5 and
NMS score_threshold to be 0.0):
python lstm_object_detection/export_tflite_lstd_graph.py \
--pipeline_config_path path/to/lstm_pipeline.config \
--trained_checkpoint_prefix path/to/model.ckpt \
--output_directory path/to/exported_model_directory \
--config_override " \
model{ \
ssd{ \
post_processing { \
batch_non_max_suppression { \
score_threshold: 0.0 \
iou_threshold: 0.5 \
} \
} \
} \
} \
"
"""
import tensorflow as tf
from lstm_object_detection import export_tflite_lstd_graph_lib
from lstm_object_detection.utils import config_util
flags = tf.app.flags
flags.DEFINE_string('output_directory', None, 'Path to write outputs.')
flags.DEFINE_string(
'pipeline_config_path', None,
'Path to a pipeline_pb2.TrainEvalPipelineConfig config '
'file.')
flags.DEFINE_string('trained_checkpoint_prefix', None, 'Checkpoint prefix.')
flags.DEFINE_integer('max_detections', 10,
'Maximum number of detections (boxes) to show.')
flags.DEFINE_integer('max_classes_per_detection', 1,
'Maximum number of classes to output per detection box.')
flags.DEFINE_integer(
'detections_per_class', 100,
'Number of anchors used per class in Regular Non-Max-Suppression.')
flags.DEFINE_bool('add_postprocessing_op', True,
'Add TFLite custom op for postprocessing to the graph.')
flags.DEFINE_bool(
'use_regular_nms', False,
'Flag to set postprocessing op to use Regular NMS instead of Fast NMS.')
flags.DEFINE_string(
'config_override', '', 'pipeline_pb2.TrainEvalPipelineConfig '
'text proto to override pipeline_config_path.')
FLAGS = flags.FLAGS
def main(argv):
del argv # Unused.
flags.mark_flag_as_required('output_directory')
flags.mark_flag_as_required('pipeline_config_path')
flags.mark_flag_as_required('trained_checkpoint_prefix')
pipeline_config = config_util.get_configs_from_pipeline_file(
FLAGS.pipeline_config_path)
export_tflite_lstd_graph_lib.export_tflite_graph(
pipeline_config,
FLAGS.trained_checkpoint_prefix,
FLAGS.output_directory,
FLAGS.add_postprocessing_op,
FLAGS.max_detections,
FLAGS.max_classes_per_detection,
use_regular_nms=FLAGS.use_regular_nms)
if __name__ == '__main__':
tf.app.run(main)
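A hedged sketch, not part of this commit, showing how the exported frozen graph could be exercised with the node names documented in the docstring above; it assumes the export ran without the TFLite postprocessing op, so the `raw_outputs/*` nodes are still present:

```python
import numpy as np
import tensorflow as tf  # TF1-style APIs, matching the script above

graph_def = tf.GraphDef()
with tf.gfile.GFile('path/to/exported_model_directory/tflite_graph.pb', 'rb') as f:
  graph_def.ParseFromString(f.read())

with tf.Graph().as_default() as graph:
  tf.import_graph_def(graph_def, name='')
  # [unroll_length, height, width, 3]; 4 and 320 match the pipeline config.
  video = np.zeros([4, 320, 320, 3], dtype=np.float32)
  with tf.Session(graph=graph) as sess:
    boxes, scores = sess.run(
        ['raw_outputs/box_encodings:0', 'raw_outputs/class_predictions:0'],
        feed_dict={'input_video_tensor:0': video})
```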
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Exports detection models to use with tf-lite.
See export_tflite_lstd_graph.py for usage.
"""
import os
import tempfile
import numpy as np
import tensorflow as tf
from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import types_pb2
from tensorflow.core.protobuf import saver_pb2
from tensorflow.tools.graph_transforms import TransformGraph
from lstm_object_detection import model_builder
from object_detection import exporter
from object_detection.builders import graph_rewriter_builder
from object_detection.builders import post_processing_builder
from object_detection.core import box_list
_DEFAULT_NUM_CHANNELS = 3
_DEFAULT_NUM_COORD_BOX = 4
def get_const_center_size_encoded_anchors(anchors):
"""Exports center-size encoded anchors as a constant tensor.
Args:
anchors: a float32 tensor of shape [num_anchors, 4] containing the anchor
boxes
Returns:
encoded_anchors: a float32 constant tensor of shape [num_anchors, 4]
containing the anchor boxes.
"""
anchor_boxlist = box_list.BoxList(anchors)
y, x, h, w = anchor_boxlist.get_center_coordinates_and_sizes()
num_anchors = y.get_shape().as_list()
with tf.Session() as sess:
y_out, x_out, h_out, w_out = sess.run([y, x, h, w])
encoded_anchors = tf.constant(
np.transpose(np.stack((y_out, x_out, h_out, w_out))),
dtype=tf.float32,
shape=[num_anchors[0], _DEFAULT_NUM_COORD_BOX],
name='anchors')
return encoded_anchors
def append_postprocessing_op(frozen_graph_def,
max_detections,
max_classes_per_detection,
nms_score_threshold,
nms_iou_threshold,
num_classes,
scale_values,
detections_per_class=100,
use_regular_nms=False):
"""Appends postprocessing custom op.
Args:
frozen_graph_def: Frozen GraphDef for SSD model after freezing the
checkpoint
max_detections: Maximum number of detections (boxes) to show
max_classes_per_detection: Number of classes to display per detection
nms_score_threshold: Score threshold used in Non-maximal suppression in
post-processing
nms_iou_threshold: Intersection-over-union threshold used in Non-maximal
suppression in post-processing
num_classes: number of classes in SSD detector
scale_values: a dict with the key-value pairs
{y_scale: 10, x_scale: 10, h_scale: 5, w_scale: 5} used to decode
center-size boxes
detections_per_class: In regular NonMaxSuppression, number of anchors used
for NonMaxSuppression per class
use_regular_nms: Flag to set postprocessing op to use Regular NMS instead of
Fast NMS.
Returns:
transformed_graph_def: Frozen GraphDef with postprocessing custom op
appended
The TFLite_Detection_PostProcess custom op node has four outputs:
detection_boxes: a float32 tensor of shape [1, num_boxes, 4] with box
locations
detection_classes: a float32 tensor of shape [1, num_boxes]
with class indices
detection_scores: a float32 tensor of shape [1, num_boxes]
with class scores
num_boxes: a float32 tensor of size 1 containing the number of detected
boxes
"""
new_output = frozen_graph_def.node.add()
new_output.op = 'TFLite_Detection_PostProcess'
new_output.name = 'TFLite_Detection_PostProcess'
new_output.attr['_output_quantized'].CopyFrom(
attr_value_pb2.AttrValue(b=True))
new_output.attr['_output_types'].list.type.extend([
types_pb2.DT_FLOAT, types_pb2.DT_FLOAT, types_pb2.DT_FLOAT,
types_pb2.DT_FLOAT
])
new_output.attr['_support_output_type_float_in_quantized_op'].CopyFrom(
attr_value_pb2.AttrValue(b=True))
new_output.attr['max_detections'].CopyFrom(
attr_value_pb2.AttrValue(i=max_detections))
new_output.attr['max_classes_per_detection'].CopyFrom(
attr_value_pb2.AttrValue(i=max_classes_per_detection))
new_output.attr['nms_score_threshold'].CopyFrom(
attr_value_pb2.AttrValue(f=nms_score_threshold.pop()))
new_output.attr['nms_iou_threshold'].CopyFrom(
attr_value_pb2.AttrValue(f=nms_iou_threshold.pop()))
new_output.attr['num_classes'].CopyFrom(
attr_value_pb2.AttrValue(i=num_classes))
new_output.attr['y_scale'].CopyFrom(
attr_value_pb2.AttrValue(f=scale_values['y_scale'].pop()))
new_output.attr['x_scale'].CopyFrom(
attr_value_pb2.AttrValue(f=scale_values['x_scale'].pop()))
new_output.attr['h_scale'].CopyFrom(
attr_value_pb2.AttrValue(f=scale_values['h_scale'].pop()))
new_output.attr['w_scale'].CopyFrom(
attr_value_pb2.AttrValue(f=scale_values['w_scale'].pop()))
new_output.attr['detections_per_class'].CopyFrom(
attr_value_pb2.AttrValue(i=detections_per_class))
new_output.attr['use_regular_nms'].CopyFrom(
attr_value_pb2.AttrValue(b=use_regular_nms))
new_output.input.extend(
['raw_outputs/box_encodings', 'raw_outputs/class_predictions', 'anchors'])
# Transform the graph to append new postprocessing op
input_names = []
output_names = ['TFLite_Detection_PostProcess']
transforms = ['strip_unused_nodes']
transformed_graph_def = TransformGraph(frozen_graph_def, input_names,
output_names, transforms)
return transformed_graph_def
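A hedged call sketch for `append_postprocessing_op`. The thresholds and scale values are passed as single-element sets, which the `.pop()` calls above unwrap; `frozen_graph_def` is assumed to come from a prior freeze step, and the numbers mirror the pipeline config earlier in this commit:

```python
# frozen_graph_def: a GraphDef produced by a prior freeze step (assumed).
transformed = append_postprocessing_op(
    frozen_graph_def=frozen_graph_def,
    max_detections=10,
    max_classes_per_detection=1,
    nms_score_threshold={-20.0},   # single-element sets, unwrapped via .pop()
    nms_iou_threshold={0.5},
    num_classes=30,
    scale_values={'y_scale': {10.0}, 'x_scale': {10.0},
                  'h_scale': {5.0}, 'w_scale': {5.0}})
```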
def export_tflite_graph(pipeline_config,
trained_checkpoint_prefix,
output_dir,
add_postprocessing_op,
max_detections,
max_classes_per_detection,
detections_per_class=100,
use_regular_nms=False,
binary_graph_name='tflite_graph.pb',
txt_graph_name='tflite_graph.pbtxt'):
"""Exports a tflite compatible graph and anchors for ssd detection model.
Anchors are written to a tensor and tflite compatible graph
is written to output_dir/tflite_graph.pb.
Args:
pipeline_config: Dictionary of configuration objects. Keys are `model`,
`train_config`, `train_input_config`, `eval_config`, `eval_input_config`,
`lstm_model`. Values are the corresponding config objects.
trained_checkpoint_prefix: a file prefix for the checkpoint containing the
trained parameters of the SSD model.
output_dir: A directory to write the tflite graph and anchor file to.
add_postprocessing_op: If true, the frozen graph adds a
TFLite_Detection_PostProcess custom op.
max_detections: Maximum number of detections (boxes) to show
max_classes_per_detection: Number of classes to display per detection
detections_per_class: In regular NonMaxSuppression, number of anchors used
for NonMaxSuppression per class
use_regular_nms: Flag to set postprocessing op to use Regular NMS instead of
Fast NMS.
binary_graph_name: Name of the exported graph file in binary format.
txt_graph_name: Name of the exported graph file in text format.
Raises:
ValueError: if the pipeline config contains a model other than ssd, or uses an
image resizer other than fixed_shape_resizer.
"""
model_config = pipeline_config['model']
lstm_config = pipeline_config['lstm_model']
eval_config = pipeline_config['eval_config']
tf.gfile.MakeDirs(output_dir)
if model_config.WhichOneof('model') != 'ssd':
raise ValueError('Only ssd models are supported in tflite. '
'Found {} in config'.format(
model_config.WhichOneof('model')))
num_classes = model_config.ssd.num_classes
nms_score_threshold = {
model_config.ssd.post_processing.batch_non_max_suppression.score_threshold
}
nms_iou_threshold = {
model_config.ssd.post_processing.batch_non_max_suppression.iou_threshold
}
scale_values = {}
scale_values['y_scale'] = {
model_config.ssd.box_coder.faster_rcnn_box_coder.y_scale
}
scale_values['x_scale'] = {
model_config.ssd.box_coder.faster_rcnn_box_coder.x_scale
}
scale_values['h_scale'] = {
model_config.ssd.box_coder.faster_rcnn_box_coder.height_scale
}
scale_values['w_scale'] = {
model_config.ssd.box_coder.faster_rcnn_box_coder.width_scale
}
image_resizer_config = model_config.ssd.image_resizer
image_resizer = image_resizer_config.WhichOneof('image_resizer_oneof')
num_channels = _DEFAULT_NUM_CHANNELS
if image_resizer == 'fixed_shape_resizer':
height = image_resizer_config.fixed_shape_resizer.height
width = image_resizer_config.fixed_shape_resizer.width
if image_resizer_config.fixed_shape_resizer.convert_to_grayscale:
num_channels = 1
shape = [lstm_config.eval_unroll_length, height, width, num_channels]
else:
raise ValueError(
'Only fixed_shape_resizer '
'is supported with tflite. Found {}'.format(
image_resizer_config.WhichOneof('image_resizer_oneof')))
video_tensor = tf.placeholder(
tf.float32, shape=shape, name='input_video_tensor')
detection_model = model_builder.build(
model_config, lstm_config, is_training=False)
preprocessed_video, true_image_shapes = detection_model.preprocess(
tf.to_float(video_tensor))
predicted_tensors = detection_model.predict(preprocessed_video,
true_image_shapes)
# predicted_tensors = detection_model.postprocess(predicted_tensors,
# true_image_shapes)
# The score conversion occurs before the post-processing custom op
_, score_conversion_fn = post_processing_builder.build(
model_config.ssd.post_processing)
class_predictions = score_conversion_fn(
predicted_tensors['class_predictions_with_background'])
with tf.name_scope('raw_outputs'):
# 'raw_outputs/box_encodings': a float32 tensor of shape [1, num_anchors, 4]
# containing the encoded box predictions. Note that these are raw
# predictions: no non-max suppression and no center-size box decoding
# has been applied to them.
tf.identity(predicted_tensors['box_encodings'], name='box_encodings')
# 'raw_outputs/class_predictions': a float32 tensor of shape
# [1, num_anchors, num_classes] containing the class scores for each anchor
# after applying score conversion.
tf.identity(class_predictions, name='class_predictions')
# 'anchors': a float32 tensor of shape
# [num_anchors, 4] containing the anchors as a constant node.
tf.identity(
get_const_center_size_encoded_anchors(predicted_tensors['anchors']),
name='anchors')
# Add global step to the graph, so we know the training step number when we
# evaluate the model.
tf.train.get_or_create_global_step()
# graph rewriter
is_quantized = ('graph_rewriter' in pipeline_config)
if is_quantized:
graph_rewriter_config = pipeline_config['graph_rewriter']
graph_rewriter_fn = graph_rewriter_builder.build(
graph_rewriter_config, is_training=False, is_export=True)
graph_rewriter_fn()
if model_config.ssd.feature_extractor.HasField('fpn'):
exporter.rewrite_nn_resize_op(is_quantized)
# freeze the graph
saver_kwargs = {}
if eval_config.use_moving_averages:
saver_kwargs['write_version'] = saver_pb2.SaverDef.V1
moving_average_checkpoint = tempfile.NamedTemporaryFile()
exporter.replace_variable_values_with_moving_averages(
tf.get_default_graph(), trained_checkpoint_prefix,
moving_average_checkpoint.name)
checkpoint_to_use = moving_average_checkpoint.name
else:
checkpoint_to_use = trained_checkpoint_prefix
saver = tf.train.Saver(**saver_kwargs)
input_saver_def = saver.as_saver_def()
frozen_graph_def = exporter.freeze_graph_with_def_protos(
input_graph_def=tf.get_default_graph().as_graph_def(),
input_saver_def=input_saver_def,
input_checkpoint=checkpoint_to_use,
output_node_names=','.join([
'raw_outputs/box_encodings', 'raw_outputs/class_predictions',
'anchors'
]),
restore_op_name='save/restore_all',
filename_tensor_name='save/Const:0',
clear_devices=True,
output_graph='',
initializer_nodes='')
# Add new operation to do post processing in a custom op (TF Lite only)
if add_postprocessing_op:
transformed_graph_def = append_postprocessing_op(
frozen_graph_def, max_detections, max_classes_per_detection,
nms_score_threshold, nms_iou_threshold, num_classes, scale_values,
detections_per_class, use_regular_nms)
else:
# Return frozen without adding post-processing custom op
transformed_graph_def = frozen_graph_def
binary_graph = os.path.join(output_dir, binary_graph_name)
with tf.gfile.GFile(binary_graph, 'wb') as f:
f.write(transformed_graph_def.SerializeToString())
txt_graph = os.path.join(output_dir, txt_graph_name)
with tf.gfile.GFile(txt_graph, 'w') as f:
f.write(str(transformed_graph_def))
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Export a LSTD model in tflite format."""
import os
from absl import flags
import tensorflow as tf
from lstm_object_detection.utils import config_util
flags.DEFINE_string('export_path', None, 'Path to export model.')
flags.DEFINE_string('frozen_graph_path', None, 'Path to frozen graph.')
flags.DEFINE_string(
'pipeline_config_path', '',
'Path to a pipeline_pb2.TrainEvalPipelineConfig config file.')
FLAGS = flags.FLAGS
def main(_):
flags.mark_flag_as_required('export_path')
flags.mark_flag_as_required('frozen_graph_path')
flags.mark_flag_as_required('pipeline_config_path')
configs = config_util.get_configs_from_pipeline_file(
FLAGS.pipeline_config_path)
lstm_config = configs['lstm_model']
input_arrays = ['input_video_tensor']
output_arrays = [
'TFLite_Detection_PostProcess',
'TFLite_Detection_PostProcess:1',
'TFLite_Detection_PostProcess:2',
'TFLite_Detection_PostProcess:3',
]
input_shapes = {
'input_video_tensor': [lstm_config.eval_unroll_length, 320, 320, 3],
}
converter = tf.lite.TFLiteConverter.from_frozen_graph(
FLAGS.frozen_graph_path,
input_arrays,
output_arrays,
input_shapes=input_shapes)
converter.allow_custom_ops = True
tflite_model = converter.convert()
ofilename = os.path.join(FLAGS.export_path)
with open(ofilename, 'wb') as f:
  f.write(tflite_model)
if __name__ == '__main__':
tf.app.run()
# Exporting a tflite model from a checkpoint
Starting from a trained model checkpoint, creating a tflite model requires two
steps:
* exporting a tflite frozen graph from a checkpoint
* exporting a tflite model from a frozen graph
## Exporting a tflite frozen graph from a checkpoint
With a candidate checkpoint to export, run the following command from
tensorflow/models/research:
```bash
# from tensorflow/models/research
PIPELINE_CONFIG_PATH={path to pipeline config}
TRAINED_CKPT_PREFIX=/{path to model.ckpt}
EXPORT_DIR={path to folder that will be used for export}
python lstm_object_detection/export_tflite_lstd_graph.py \
--pipeline_config_path ${PIPELINE_CONFIG_PATH} \
--trained_checkpoint_prefix ${TRAINED_CKPT_PREFIX} \
--output_directory ${EXPORT_DIR} \
--add_postprocessing_op
```
After export, you should see the directory ${EXPORT_DIR} containing the
following files:
* `tflite_graph.pb`
* `tflite_graph.pbtxt`
## Exporting a tflite model from a frozen graph
We then take the exported tflite-compatible frozen graph and convert it to a
TFLite FlatBuffer file by running the following:
```bash
# from tensorflow/models/research
FROZEN_GRAPH_PATH={path to exported tflite_graph.pb}
EXPORT_PATH={path to filename that will be used for export}
PIPELINE_CONFIG_PATH={path to pipeline config}
python lstm_object_detection/export_tflite_lstd_model.py \
--export_path ${EXPORT_PATH} \
--frozen_graph_path ${FROZEN_GRAPH_PATH} \
--pipeline_config_path ${PIPELINE_CONFIG_PATH}
```
After export, you should see the file ${EXPORT_PATH} containing the FlatBuffer
model to be used by an application.
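As a hedged illustration, an application might drive the resulting FlatBuffer like this (the model path is a placeholder, and the custom postprocessing op requires a TFLite build that registers `TFLite_Detection_PostProcess`):

```python
import numpy as np
import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path='exported_model.tflite')
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
video = np.zeros(input_details[0]['shape'], dtype=np.float32)
interpreter.set_tensor(input_details[0]['index'], video)
interpreter.invoke()

# Output order follows output_arrays in export_tflite_lstd_model.py:
# boxes, classes, scores, num_detections.
boxes = interpreter.get_tensor(interpreter.get_output_details()[0]['index'])
```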
......@@ -33,68 +33,6 @@ from object_detection.protos import preprocessor_pb2
class DatasetBuilderTest(tf.test.TestCase):
def _create_tf_record(self):
path = os.path.join(self.get_temp_dir(), 'tfrecord')
writer = tf.python_io.TFRecordWriter(path)
image_tensor = np.random.randint(255, size=(16, 16, 3)).astype(np.uint8)
with self.test_session():
encoded_jpeg = tf.image.encode_jpeg(tf.constant(image_tensor)).eval()
sequence_example = example_pb2.SequenceExample(
context=feature_pb2.Features(
feature={
'image/format':
feature_pb2.Feature(
bytes_list=feature_pb2.BytesList(
value=['jpeg'.encode('utf-8')])),
'image/height':
feature_pb2.Feature(
int64_list=feature_pb2.Int64List(value=[16])),
'image/width':
feature_pb2.Feature(
int64_list=feature_pb2.Int64List(value=[16])),
}),
feature_lists=feature_pb2.FeatureLists(
feature_list={
'image/encoded':
feature_pb2.FeatureList(feature=[
feature_pb2.Feature(
bytes_list=feature_pb2.BytesList(
value=[encoded_jpeg])),
]),
'image/object/bbox/xmin':
feature_pb2.FeatureList(feature=[
feature_pb2.Feature(
float_list=feature_pb2.FloatList(value=[0.0])),
]),
'image/object/bbox/xmax':
feature_pb2.FeatureList(feature=[
feature_pb2.Feature(
float_list=feature_pb2.FloatList(value=[1.0]))
]),
'image/object/bbox/ymin':
feature_pb2.FeatureList(feature=[
feature_pb2.Feature(
float_list=feature_pb2.FloatList(value=[0.0])),
]),
'image/object/bbox/ymax':
feature_pb2.FeatureList(feature=[
feature_pb2.Feature(
float_list=feature_pb2.FloatList(value=[1.0]))
]),
'image/object/class/label':
feature_pb2.FeatureList(feature=[
feature_pb2.Feature(
int64_list=feature_pb2.Int64List(value=[2]))
]),
}))
writer.write(sequence_example.SerializeToString())
writer.close()
return path
def _get_model_configs_from_proto(self):
"""Creates a model text proto for testing.
......
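The `_create_tf_record` helper above writes the kind of `TF_SEQUENCE_EXAMPLE` record that `tf_record_video_input_reader` consumes; a hedged sketch of parsing one back with stock TF1 ops, with feature names copied from the helper:

```python
import tensorflow as tf

context_features = {
    'image/format': tf.FixedLenFeature([], tf.string),
    'image/height': tf.FixedLenFeature([], tf.int64),
    'image/width': tf.FixedLenFeature([], tf.int64),
}
sequence_features = {
    'image/encoded': tf.FixedLenSequenceFeature([], tf.string),
    'image/object/bbox/xmin': tf.VarLenFeature(tf.float32),
    'image/object/class/label': tf.VarLenFeature(tf.int64),
}

serialized = tf.placeholder(tf.string, shape=[])
context, sequence = tf.parse_single_sequence_example(
    serialized,
    context_features=context_features,
    sequence_features=sequence_features)
```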