"torchvision/vscode:/vscode.git/clone" did not exist on "a370e79eb71a5120c7b8c58101ad326764167326"
Unverified commit 23376e62 authored by Martin Wicke, committed by GitHub

Merge pull request #5603 from dreamdragon/master

Open sourcing mobile video object detection framework
parents b2522f9b 77b2556e
syntax = "proto2";
package lstm_object_detection.input_readers;
import "third_party/tensorflow_models/object_detection/protos/input_reader.proto";
message GoogleInputReader {
extend object_detection.protos.ExternalInputReader {
optional GoogleInputReader google_input_reader = 444;
}
oneof input_reader {
TFRecordVideoInputReader tf_record_video_input_reader = 1;
}
}
message TFRecordVideoInputReader {
// Path(s) to tfrecords of input data.
repeated string input_path = 1;
enum DataType {
UNSPECIFIED = 0;
ANNOTATED_IMAGE = 1;
TF_EXAMPLE = 2;
TF_SEQUENCE_EXAMPLE = 3;
}
optional DataType data_type = 2 [default=TF_SEQUENCE_EXAMPLE];
// Length of the video sequence. All input video sequences should have the
// same number of frames, e.g. 5.
optional int32 video_length = 3;
}
syntax = "proto2";
package object_detection.protos;
import "third_party/tensorflow_models/object_detection/protos/pipeline.proto";
extend TrainEvalPipelineConfig {
optional LstmModel lstm_model = 205743444;
}
// Message for extra fields needed for configuring LSTM model.
message LstmModel {
// Unroll length for training LSTMs.
optional int32 train_unroll_length = 1;
// Unroll length for evaluating LSTMs.
optional int32 eval_unroll_length = 2;
// Depth of the lstm feature map.
optional int32 lstm_state_depth = 3 [default = 256];
}
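The two extensions above are consumed by the Python builders and tests below. As a minimal sketch (not part of this change), they could be populated programmatically as follows; the imports mirror the paths used elsewhere in this change, and the input path is a placeholder.

from lstm_object_detection.protos import input_reader_google_pb2
from lstm_object_detection.protos import pipeline_pb2 as internal_pipeline_pb2
from google3.third_party.tensorflow_models.object_detection.protos import input_reader_pb2
from google3.third_party.tensorflow_models.object_detection.protos import pipeline_pb2

# Select the TFRecord video reader through the GoogleInputReader extension,
# mirroring the extension access used by seq_dataset_builder.build().
input_reader = input_reader_pb2.InputReader()
video_reader = input_reader.external_input_reader.Extensions[
    input_reader_google_pb2.GoogleInputReader.google_input_reader
].tf_record_video_input_reader
video_reader.input_path.append('/placeholder/train.tfrecord')  # placeholder path
video_reader.video_length = 4

# Attach the LstmModel extension to a pipeline config, as config_util.py reads it.
pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
lstm_model = pipeline_config.Extensions[internal_pipeline_pb2.lstm_model]
lstm_model.train_unroll_length = 4
lstm_model.eval_unroll_length = 4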
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""tf.data.Dataset builder.
Creates data sources for DetectionModels from an InputReader config. See
input_reader.proto for options.
Note: If users wish to use their own InputReaders with the Object Detection
configuration framework, they should define their own builder function that
wraps this build function.
"""
import tensorflow as tf
import tensorflow.google as google_tf
from google3.learning.brain.contrib.slim.data import parallel_reader
from tensorflow.contrib.training.python.training import sequence_queueing_state_saver as sqss
from lstm_object_detection import tf_sequence_example_decoder
from lstm_object_detection.protos import input_reader_google_pb2
from google3.third_party.tensorflow_models.object_detection.core import preprocessor
from google3.third_party.tensorflow_models.object_detection.core import preprocessor_cache
from google3.third_party.tensorflow_models.object_detection.core import standard_fields as fields
from google3.third_party.tensorflow_models.object_detection.protos import input_reader_pb2
from google3.third_party.tensorflow_models.object_detection.utils import ops as util_ops
# TODO(yinxiao): Make the following variable into configurable proto.
# Padding size for the labeled objects in each frame. Here we assume each
# frame has a total number of objects less than _PADDING_SIZE.
_PADDING_SIZE = 30
def _build_training_batch_dict(batch_sequences_with_states, unroll_length,
batch_size):
"""Builds training batch samples.
Args:
batch_sequences_with_states: A batch_sequences_with_states object.
unroll_length: Unrolled length for LSTM training.
batch_size: Batch size for queue outputs.
Returns:
A dictionary mapping image and groundtruth keys to tuples of per-frame
tensors, plus the batch_sequences_with_states object under the 'batch' key.
"""
seq_tensors_dict = {
fields.InputDataFields.image: [],
fields.InputDataFields.groundtruth_boxes: [],
fields.InputDataFields.groundtruth_classes: [],
'batch': batch_sequences_with_states,
}
for i in range(unroll_length):
for j in range(batch_size):
filtered_dict = util_ops.filter_groundtruth_with_nan_box_coordinates({
fields.InputDataFields.groundtruth_boxes: (
batch_sequences_with_states.sequences['groundtruth_boxes'][j][i]),
fields.InputDataFields.groundtruth_classes: (
batch_sequences_with_states.sequences['groundtruth_classes'][j][i]
),
})
filtered_dict = util_ops.retain_groundtruth_with_positive_classes(
filtered_dict)
seq_tensors_dict[fields.InputDataFields.image].append(
batch_sequences_with_states.sequences['image'][j][i])
seq_tensors_dict[fields.InputDataFields.groundtruth_boxes].append(
filtered_dict[fields.InputDataFields.groundtruth_boxes])
seq_tensors_dict[fields.InputDataFields.groundtruth_classes].append(
filtered_dict[fields.InputDataFields.groundtruth_classes])
seq_tensors_dict[fields.InputDataFields.image] = tuple(
seq_tensors_dict[fields.InputDataFields.image])
seq_tensors_dict[fields.InputDataFields.groundtruth_boxes] = tuple(
seq_tensors_dict[fields.InputDataFields.groundtruth_boxes])
seq_tensors_dict[fields.InputDataFields.groundtruth_classes] = tuple(
seq_tensors_dict[fields.InputDataFields.groundtruth_classes])
return seq_tensors_dict
def build(input_reader_config,
model_config,
lstm_config,
unroll_length,
data_augmentation_options=None,
batch_size=1):
"""Builds a tensor dictionary based on the InputReader config.
Args:
input_reader_config: An input_reader_builder.InputReader object.
model_config: A model.proto object containing the config for the desired
DetectionModel.
lstm_config: LSTM specific configs.
unroll_length: Unrolled length for LSTM training.
data_augmentation_options: A list of tuples, where each tuple contains a
data augmentation function and a dictionary containing arguments and their
values (see preprocessor.py).
batch_size: Batch size for queue outputs.
Returns:
A dictionary of tensors based on items in the input_reader_config.
Raises:
ValueError: On invalid input reader proto.
ValueError: If no input paths are specified.
"""
if not isinstance(input_reader_config, input_reader_pb2.InputReader):
raise ValueError('input_reader_config not of type '
'input_reader_pb2.InputReader.')
external_reader_config = input_reader_config.external_input_reader
google_input_reader_config = external_reader_config.Extensions[
input_reader_google_pb2.GoogleInputReader.google_input_reader]
input_reader_type = google_input_reader_config.WhichOneof('input_reader')
if input_reader_type == 'tf_record_video_input_reader':
config = google_input_reader_config.tf_record_video_input_reader
reader_type_class = tf.TFRecordReader
else:
raise ValueError(
'Unsupported reader in input_reader_config: %s' % input_reader_type)
if not config.input_path:
raise ValueError('At least one input path must be specified in '
'`input_reader_config`.')
key, value = parallel_reader.parallel_read(
config.input_path[:], # Convert `RepeatedScalarContainer` to list.
reader_class=reader_type_class,
num_epochs=(input_reader_config.num_epochs
if input_reader_config.num_epochs else None),
num_readers=input_reader_config.num_readers,
shuffle=input_reader_config.shuffle,
dtypes=[tf.string, tf.string],
capacity=input_reader_config.queue_capacity,
min_after_dequeue=input_reader_config.min_after_dequeue)
# TODO(yinxiao): Add loading instance mask option.
decoder = tf_sequence_example_decoder.TfSequenceExampleDecoder()
keys_to_decode = [
fields.InputDataFields.image, fields.InputDataFields.groundtruth_boxes,
fields.InputDataFields.groundtruth_classes
]
tensor_dict = decoder.decode(value, items=keys_to_decode)
tensor_dict['image'].set_shape([None, None, None, 3])
tensor_dict['groundtruth_boxes'].set_shape([None, None, 4])
height = model_config.ssd.image_resizer.fixed_shape_resizer.height
width = model_config.ssd.image_resizer.fixed_shape_resizer.width
# If data augmentation is specified in the config file, the preprocessor
# will be called here to augment the data as specified. Most common
# augmentations include horizontal flip and cropping.
if data_augmentation_options:
images_pre = tf.split(tensor_dict['image'], config.video_length, axis=0)
bboxes_pre = tf.split(
tensor_dict['groundtruth_boxes'], config.video_length, axis=0)
labels_pre = tf.split(
tensor_dict['groundtruth_classes'], config.video_length, axis=0)
images_proc, bboxes_proc, labels_proc = [], [], []
cache = preprocessor_cache.PreprocessorCache()
for i, _ in enumerate(images_pre):
image_dict = {
fields.InputDataFields.image:
images_pre[i],
fields.InputDataFields.groundtruth_boxes:
tf.squeeze(bboxes_pre[i], axis=0),
fields.InputDataFields.groundtruth_classes:
tf.squeeze(labels_pre[i], axis=0),
}
image_dict = preprocessor.preprocess(
image_dict,
data_augmentation_options,
func_arg_map=preprocessor.get_default_func_arg_map(),
preprocess_vars_cache=cache)
# Pads detection count to _PADDING_SIZE.
image_dict[fields.InputDataFields.groundtruth_boxes] = tf.pad(
image_dict[fields.InputDataFields.groundtruth_boxes],
[[0, _PADDING_SIZE], [0, 0]])
image_dict[fields.InputDataFields.groundtruth_boxes] = tf.slice(
image_dict[fields.InputDataFields.groundtruth_boxes], [0, 0],
[_PADDING_SIZE, -1])
image_dict[fields.InputDataFields.groundtruth_classes] = tf.pad(
image_dict[fields.InputDataFields.groundtruth_classes],
[[0, _PADDING_SIZE]])
image_dict[fields.InputDataFields.groundtruth_classes] = tf.slice(
image_dict[fields.InputDataFields.groundtruth_classes], [0],
[_PADDING_SIZE])
images_proc.append(image_dict[fields.InputDataFields.image])
bboxes_proc.append(image_dict[fields.InputDataFields.groundtruth_boxes])
labels_proc.append(image_dict[fields.InputDataFields.groundtruth_classes])
tensor_dict['image'] = tf.concat(images_proc, axis=0)
tensor_dict['groundtruth_boxes'] = tf.stack(bboxes_proc, axis=0)
tensor_dict['groundtruth_classes'] = tf.stack(labels_proc, axis=0)
else:
# Pads detection count to _PADDING_SIZE per frame.
tensor_dict['groundtruth_boxes'] = tf.pad(
tensor_dict['groundtruth_boxes'], [[0, 0], [0, _PADDING_SIZE], [0, 0]])
tensor_dict['groundtruth_boxes'] = tf.slice(
tensor_dict['groundtruth_boxes'], [0, 0, 0], [-1, _PADDING_SIZE, -1])
tensor_dict['groundtruth_classes'] = tf.pad(
tensor_dict['groundtruth_classes'], [[0, 0], [0, _PADDING_SIZE]])
tensor_dict['groundtruth_classes'] = tf.slice(
tensor_dict['groundtruth_classes'], [0, 0], [-1, _PADDING_SIZE])
tensor_dict['image'], _ = preprocessor.resize_image(
tensor_dict['image'], new_height=height, new_width=width)
num_steps = config.video_length / unroll_length
init_states = {
'lstm_state_c':
tf.zeros([height / 32, width / 32, lstm_config.lstm_state_depth]),
'lstm_state_h':
tf.zeros([height / 32, width / 32, lstm_config.lstm_state_depth]),
'lstm_state_step':
tf.constant(num_steps, shape=[]),
}
batch = sqss.batch_sequences_with_states(
input_key=key,
input_sequences=tensor_dict,
input_context={},
input_length=None,
initial_states=init_states,
num_unroll=unroll_length,
batch_size=batch_size,
num_threads=batch_size,
make_keys_unique=True,
capacity=batch_size * batch_size)
return _build_training_batch_dict(batch, unroll_length, batch_size)
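As the module docstring above notes, users who want to plug in their own InputReaders should define a builder that wraps build(). A minimal sketch of such a wrapper, assuming the fallback condition and any custom reader dispatch are project specific:

from lstm_object_detection import seq_dataset_builder

def build_input(input_reader_config, model_config, lstm_config, unroll_length,
                data_augmentation_options=None, batch_size=1):
  """Hypothetical wrapper that delegates external readers to seq_dataset_builder."""
  if input_reader_config.HasField('external_input_reader'):
    return seq_dataset_builder.build(
        input_reader_config, model_config, lstm_config, unroll_length,
        data_augmentation_options=data_augmentation_options,
        batch_size=batch_size)
  # A project-specific reader would be dispatched here instead.
  raise ValueError('Unsupported input reader configuration.')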
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for dataset_builder."""
import os
import numpy as np
import tensorflow as tf
from google.protobuf import text_format
from google3.testing.pybase import parameterized
from tensorflow.core.example import example_pb2
from tensorflow.core.example import feature_pb2
from lstm_object_detection import seq_dataset_builder
from lstm_object_detection.protos import pipeline_pb2 as internal_pipeline_pb2
from google3.third_party.tensorflow_models.object_detection.builders import preprocessor_builder
from google3.third_party.tensorflow_models.object_detection.core import standard_fields as fields
from google3.third_party.tensorflow_models.object_detection.protos import input_reader_pb2
from google3.third_party.tensorflow_models.object_detection.protos import pipeline_pb2
from google3.third_party.tensorflow_models.object_detection.protos import preprocessor_pb2
class DatasetBuilderTest(parameterized.TestCase):
def _create_tf_record(self):
path = os.path.join(self.get_temp_dir(), 'tfrecord')
writer = tf.python_io.TFRecordWriter(path)
image_tensor = np.random.randint(255, size=(16, 16, 3)).astype(np.uint8)
with self.test_session():
encoded_jpeg = tf.image.encode_jpeg(tf.constant(image_tensor)).eval()
sequence_example = example_pb2.SequenceExample(
context=feature_pb2.Features(
feature={
'image/format':
feature_pb2.Feature(
bytes_list=feature_pb2.BytesList(
value=['jpeg'.encode('utf-8')])),
'image/height':
feature_pb2.Feature(
int64_list=feature_pb2.Int64List(value=[16])),
'image/width':
feature_pb2.Feature(
int64_list=feature_pb2.Int64List(value=[16])),
}),
feature_lists=feature_pb2.FeatureLists(
feature_list={
'image/encoded':
feature_pb2.FeatureList(feature=[
feature_pb2.Feature(
bytes_list=feature_pb2.BytesList(
value=[encoded_jpeg])),
]),
'image/object/bbox/xmin':
feature_pb2.FeatureList(feature=[
feature_pb2.Feature(
float_list=feature_pb2.FloatList(value=[0.0])),
]),
'image/object/bbox/xmax':
feature_pb2.FeatureList(feature=[
feature_pb2.Feature(
float_list=feature_pb2.FloatList(value=[1.0]))
]),
'image/object/bbox/ymin':
feature_pb2.FeatureList(feature=[
feature_pb2.Feature(
float_list=feature_pb2.FloatList(value=[0.0])),
]),
'image/object/bbox/ymax':
feature_pb2.FeatureList(feature=[
feature_pb2.Feature(
float_list=feature_pb2.FloatList(value=[1.0]))
]),
'image/object/class/label':
feature_pb2.FeatureList(feature=[
feature_pb2.Feature(
int64_list=feature_pb2.Int64List(value=[2]))
]),
}))
writer.write(sequence_example.SerializeToString())
writer.close()
return path
def _get_model_configs_from_proto(self):
"""Creates a model text proto for testing.
Returns:
A dictionary of model configs.
"""
model_text_proto = """
[object_detection.protos.lstm_model] {
train_unroll_length: 4
eval_unroll_length: 4
}
model {
ssd {
feature_extractor {
type: 'lstm_mobilenet_v1_fpn'
conv_hyperparams {
regularizer {
l2_regularizer {
}
}
initializer {
truncated_normal_initializer {
}
}
}
}
negative_class_weight: 2.0
box_coder {
faster_rcnn_box_coder {
}
}
matcher {
argmax_matcher {
}
}
similarity_calculator {
iou_similarity {
}
}
anchor_generator {
ssd_anchor_generator {
aspect_ratios: 1.0
}
}
image_resizer {
fixed_shape_resizer {
height: 32
width: 32
}
}
box_predictor {
convolutional_box_predictor {
conv_hyperparams {
regularizer {
l2_regularizer {
}
}
initializer {
truncated_normal_initializer {
}
}
}
}
}
normalize_loc_loss_by_codesize: true
loss {
classification_loss {
weighted_softmax {
}
}
localization_loss {
weighted_smooth_l1 {
}
}
}
}
}"""
pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
text_format.Merge(model_text_proto, pipeline_config)
configs = {}
configs['model'] = pipeline_config.model
configs['lstm_model'] = pipeline_config.Extensions[
internal_pipeline_pb2.lstm_model]
return configs
def _get_data_augmentation_preprocessor_proto(self):
preprocessor_text_proto = """
random_horizontal_flip {
}
"""
preprocessor_proto = preprocessor_pb2.PreprocessingStep()
text_format.Merge(preprocessor_text_proto, preprocessor_proto)
return preprocessor_proto
def _create_training_dict(self, tensor_dict):
image_dict = {}
all_dict = {}
all_dict['batch'] = tensor_dict.pop('batch')
for i, _ in enumerate(tensor_dict[fields.InputDataFields.image]):
for key, val in tensor_dict.items():
image_dict[key] = val[i]
image_dict[fields.InputDataFields.image] = tf.to_float(
tf.expand_dims(image_dict[fields.InputDataFields.image], 0))
suffix = str(i)
for key, val in image_dict.items():
all_dict[key + suffix] = val
return all_dict
def _get_input_proto(self, input_reader):
return """
external_input_reader {
[lstm_object_detection.input_readers.GoogleInputReader.google_input_reader] {
%s: {
input_path: '{0}'
data_type: TF_SEQUENCE_EXAMPLE
video_length: 4
}
}
}
""" % input_reader
@parameterized.named_parameters(('tf_record', 'tf_record_video_input_reader'))
def test_video_input_reader(self, video_input_type):
input_reader_proto = input_reader_pb2.InputReader()
text_format.Merge(
self._get_input_proto(video_input_type), input_reader_proto)
configs = self._get_model_configs_from_proto()
tensor_dict = seq_dataset_builder.build(
input_reader_proto,
configs['model'],
configs['lstm_model'],
unroll_length=1)
all_dict = self._create_training_dict(tensor_dict)
self.assertEqual((1, 32, 32, 3), all_dict['image0'].shape)
self.assertEqual(4, all_dict['groundtruth_boxes0'].shape[1])
def test_build_with_data_augmentation(self):
input_reader_proto = input_reader_pb2.InputReader()
text_format.Merge(
self._get_input_proto('tf_record_video_input_reader'),
input_reader_proto)
configs = self._get_model_configs_from_proto()
data_augmentation_options = [
preprocessor_builder.build(
self._get_data_augmentation_preprocessor_proto())
]
tensor_dict = seq_dataset_builder.build(
input_reader_proto,
configs['model'],
configs['lstm_model'],
unroll_length=1,
data_augmentation_options=data_augmentation_options)
all_dict = self._create_training_dict(tensor_dict)
self.assertEqual((1, 32, 32, 3), all_dict['image0'].shape)
self.assertEqual(4, all_dict['groundtruth_boxes0'].shape[1])
def test_raises_error_without_input_paths(self):
input_reader_text_proto = """
shuffle: false
num_readers: 1
load_instance_masks: true
"""
input_reader_proto = input_reader_pb2.InputReader()
text_format.Merge(input_reader_text_proto, input_reader_proto)
configs = self._get_model_configs_from_proto()
with self.assertRaises(ValueError):
_ = seq_dataset_builder.build(
input_reader_proto,
configs['model'],
configs['lstm_model'],
unroll_length=1)
if __name__ == '__main__':
tf.test.main()
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tensorflow Sequence Example proto decoder.
A decoder to decode string tensors containing serialized
tensorflow.SequenceExample protos.
TODO(yinxiao): When TensorFlow object detection API officially supports
tensorflow.SequenceExample, merge this decoder.
"""
import tensorflow as tf
from google3.learning.brain.contrib.slim.data import tfexample_decoder
from google3.third_party.tensorflow_models.object_detection.core import data_decoder
from google3.third_party.tensorflow_models.object_detection.core import standard_fields as fields
slim_example_decoder = tf.contrib.slim.tfexample_decoder
class TfSequenceExampleDecoder(data_decoder.DataDecoder):
"""Tensorflow Sequence Example proto decoder."""
def __init__(self):
"""Constructor sets keys_to_features and items_to_handlers."""
self.keys_to_context_features = {
'image/format':
tf.FixedLenFeature((), tf.string, default_value='jpeg'),
'image/filename':
tf.FixedLenFeature((), tf.string, default_value=''),
'image/key/sha256':
tf.FixedLenFeature((), tf.string, default_value=''),
'image/source_id':
tf.FixedLenFeature((), tf.string, default_value=''),
'image/height':
tf.FixedLenFeature((), tf.int64, 1),
'image/width':
tf.FixedLenFeature((), tf.int64, 1),
}
self.keys_to_features = {
'image/encoded': tf.FixedLenSequenceFeature((), tf.string),
'bbox/xmin': tf.VarLenFeature(dtype=tf.float32),
'bbox/xmax': tf.VarLenFeature(dtype=tf.float32),
'bbox/ymin': tf.VarLenFeature(dtype=tf.float32),
'bbox/ymax': tf.VarLenFeature(dtype=tf.float32),
'bbox/label/index': tf.VarLenFeature(dtype=tf.int64),
'bbox/label/string': tf.VarLenFeature(tf.string),
'area': tf.VarLenFeature(tf.float32),
'is_crowd': tf.VarLenFeature(tf.int64),
'difficult': tf.VarLenFeature(tf.int64),
'group_of': tf.VarLenFeature(tf.int64),
}
self.items_to_handlers = {
fields.InputDataFields.image:
slim_example_decoder.Image(
image_key='image/encoded',
format_key='image/format',
channels=3,
repeated=True),
fields.InputDataFields.source_id: (
slim_example_decoder.Tensor('image/source_id')),
fields.InputDataFields.key: (
slim_example_decoder.Tensor('image/key/sha256')),
fields.InputDataFields.filename: (
slim_example_decoder.Tensor('image/filename')),
# Object boxes and classes.
fields.InputDataFields.groundtruth_boxes:
tfexample_decoder.BoundingBoxSequence(prefix='bbox/'),
fields.InputDataFields.groundtruth_classes: (
slim_example_decoder.Tensor('bbox/label/index')),
fields.InputDataFields.groundtruth_area:
slim_example_decoder.Tensor('area'),
fields.InputDataFields.groundtruth_is_crowd: (
slim_example_decoder.Tensor('is_crowd')),
fields.InputDataFields.groundtruth_difficult: (
slim_example_decoder.Tensor('difficult')),
fields.InputDataFields.groundtruth_group_of: (
slim_example_decoder.Tensor('group_of'))
}
def decode(self, tf_seq_example_string_tensor, items=None):
"""Decodes serialized tf.SequenceExample and returns a tensor dictionary.
Args:
tf_seq_example_string_tensor: A string tensor holding a serialized
tensorflow.SequenceExample proto.
items: The list of items to decode. These must be a subset of the item
keys in self._items_to_handlers. If `items` is left as None, then all
of the items in self._items_to_handlers are decoded.
Returns:
A dictionary of the following tensors.
fields.InputDataFields.image - 4-D uint8 tensor of shape
[seq, None, None, 3] containing the decoded frames.
fields.InputDataFields.source_id - string tensor containing original
image id.
fields.InputDataFields.key - string tensor with unique sha256 hash key.
fields.InputDataFields.filename - string tensor with original dataset
filename.
fields.InputDataFields.groundtruth_boxes - 2D float32 tensor of shape
[None, 4] containing box corners.
fields.InputDataFields.groundtruth_classes - 1D int64 tensor of shape
[None] containing classes for the boxes.
fields.InputDataFields.groundtruth_area - 1D float32 tensor of shape
[None] containing object mask area in pixel squared.
fields.InputDataFields.groundtruth_is_crowd - 1D bool tensor of shape
[None] indicating if the boxes enclose a crowd.
fields.InputDataFields.groundtruth_difficult - 1D bool tensor of shape
[None] indicating if the boxes represent `difficult` instances.
"""
serialized_example = tf.reshape(tf_seq_example_string_tensor, shape=[])
decoder = TFSequenceExampleDecoderHelper(self.keys_to_context_features,
self.keys_to_features,
self.items_to_handlers)
if not items:
items = decoder.list_items()
tensors = decoder.decode(serialized_example, items=items)
tensor_dict = dict(zip(items, tensors))
return tensor_dict
class TFSequenceExampleDecoderHelper(data_decoder.DataDecoder):
"""A decoder helper class for TensorFlow SequenceExamples.
To perform this decoding operation, a SequenceExampleDecoder is given a list
of ItemHandlers. Each ItemHandler indicates the set of features it depends on
and how to convert them into the final item tensor.
"""
def __init__(self, keys_to_context_features, keys_to_sequence_features,
items_to_handlers):
"""Constructs the decoder.
Args:
keys_to_context_features: A dictionary from TF-SequenceExample context
keys to either tf.VarLenFeature or tf.FixedLenFeature instances.
See tensorflow's parsing_ops.py.
keys_to_sequence_features: A dictionary from TF-SequenceExample sequence
keys to either tf.VarLenFeature or tf.FixedLenSequenceFeature instances.
items_to_handlers: A dictionary from items (strings) to ItemHandler
instances. Note that the ItemHandlers are provided the keys that they
use to return the final item Tensors.
Raises:
ValueError: If the same key is present for context features and sequence
features.
"""
unique_keys = set()
unique_keys.update(keys_to_context_features)
unique_keys.update(keys_to_sequence_features)
if len(unique_keys) != (
len(keys_to_context_features) + len(keys_to_sequence_features)):
# This situation is ambiguous in the decoder's keys_to_tensors variable.
raise ValueError('Context and sequence keys are not unique. \n'
' Context keys: %s \n Sequence keys: %s' %
(list(keys_to_context_features.keys()),
list(keys_to_sequence_features.keys())))
self._keys_to_context_features = keys_to_context_features
self._keys_to_sequence_features = keys_to_sequence_features
self._items_to_handlers = items_to_handlers
def list_items(self):
"""Returns keys of items."""
return self._items_to_handlers.keys()
def decode(self, serialized_example, items=None):
"""Decodes the given serialized TF-SequenceExample.
Args:
serialized_example: A serialized TF-SequenceExample tensor.
items: The list of items to decode. These must be a subset of the item
keys in self._items_to_handlers. If `items` is left as None, then all
of the items in self._items_to_handlers are decoded.
Returns:
The decoded items, a list of tensors.
"""
context, feature_list = tf.parse_single_sequence_example(
serialized_example, self._keys_to_context_features,
self._keys_to_sequence_features)
# Reshape non-sparse elements just once:
for k in self._keys_to_context_features:
v = self._keys_to_context_features[k]
if isinstance(v, tf.FixedLenFeature):
context[k] = tf.reshape(context[k], v.shape)
if not items:
items = self._items_to_handlers.keys()
outputs = []
for item in items:
handler = self._items_to_handlers[item]
keys_to_tensors = {
key: context[key] if key in context else feature_list[key]
for key in handler.keys
}
outputs.append(handler.tensors_to_item(keys_to_tensors))
return outputs
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for video_object_detection.tf_sequence_example_decoder."""
import numpy as np
import tensorflow as tf
from tensorflow.core.example import example_pb2
from tensorflow.core.example import feature_pb2
from lstm_object_detection import tf_sequence_example_decoder
from google3.third_party.tensorflow_models.object_detection.core import standard_fields as fields
class TfSequenceExampleDecoderTest(tf.test.TestCase):
"""Tests for sequence example decoder."""
def _EncodeImage(self, image_tensor, encoding_type='jpeg'):
with self.test_session():
if encoding_type == 'jpeg':
image_encoded = tf.image.encode_jpeg(tf.constant(image_tensor)).eval()
else:
raise ValueError('Invalid encoding type.')
return image_encoded
def _DecodeImage(self, image_encoded, encoding_type='jpeg'):
with self.test_session():
if encoding_type == 'jpeg':
image_decoded = tf.image.decode_jpeg(tf.constant(image_encoded)).eval()
else:
raise ValueError('Invalid encoding type.')
return image_decoded
def testDecodeJpegImageAndBoundingBox(self):
"""Test if the decoder can correctly decode the image and bounding box.
A set of random images (represented as an image tensor) is first decoded as
the groundtrue image. Meanwhile, the image tensor will be encoded and pass
through the sequence example, and then decoded as images. The groundtruth
image and the decoded image are expected to be equal. Similar tests are
also applied to labels such as bounding box.
"""
image_tensor = np.random.randint(256, size=(256, 256, 3)).astype(np.uint8)
encoded_jpeg = self._EncodeImage(image_tensor)
decoded_jpeg = self._DecodeImage(encoded_jpeg)
sequence_example = example_pb2.SequenceExample(
feature_lists=feature_pb2.FeatureLists(
feature_list={
'image/encoded':
feature_pb2.FeatureList(feature=[
feature_pb2.Feature(
bytes_list=feature_pb2.BytesList(
value=[encoded_jpeg])),
]),
'bbox/xmin':
feature_pb2.FeatureList(feature=[
feature_pb2.Feature(
float_list=feature_pb2.FloatList(value=[0.0])),
]),
'bbox/xmax':
feature_pb2.FeatureList(feature=[
feature_pb2.Feature(
float_list=feature_pb2.FloatList(value=[1.0]))
]),
'bbox/ymin':
feature_pb2.FeatureList(feature=[
feature_pb2.Feature(
float_list=feature_pb2.FloatList(value=[0.0])),
]),
'bbox/ymax':
feature_pb2.FeatureList(feature=[
feature_pb2.Feature(
float_list=feature_pb2.FloatList(value=[1.0]))
]),
})).SerializeToString()
example_decoder = tf_sequence_example_decoder.TfSequenceExampleDecoder()
tensor_dict = example_decoder.decode(tf.convert_to_tensor(sequence_example))
# Test tensor dict image dimension.
self.assertAllEqual(
(tensor_dict[fields.InputDataFields.image].get_shape().as_list()),
[None, None, None, 3])
with self.test_session() as sess:
tensor_dict[fields.InputDataFields.image] = tf.squeeze(
tensor_dict[fields.InputDataFields.image])
tensor_dict[fields.InputDataFields.groundtruth_boxes] = tf.squeeze(
tensor_dict[fields.InputDataFields.groundtruth_boxes])
tensor_dict = sess.run(tensor_dict)
# Test decoded image.
self.assertAllEqual(decoded_jpeg, tensor_dict[fields.InputDataFields.image])
# Test decoded bounding box.
self.assertAllEqual([0.0, 0.0, 1.0, 1.0],
tensor_dict[fields.InputDataFields.groundtruth_boxes])
if __name__ == '__main__':
tf.test.main()
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Training executable for detection models.
This executable is used to train DetectionModels. There are two ways of
configuring the training job:
1) A single pipeline_pb2.TrainEvalPipelineConfig configuration file
can be specified by --pipeline_config_path.
Example usage:
./train \
--logtostderr \
--train_dir=path/to/train_dir \
--pipeline_config_path=pipeline_config.pbtxt
2) Three configuration files can be provided: a model_pb2.DetectionModel
configuration file to define what type of DetectionModel is being trained, an
input_reader_pb2.InputReader file to specify what training data will be used and
a train_pb2.TrainConfig file to configure training parameters.
Example usage:
./train \
--logtostderr \
--train_dir=path/to/train_dir \
--model_config_path=model_config.pbtxt \
--train_config_path=train_config.pbtxt \
--input_config_path=train_input_config.pbtxt
"""
import functools
import json
import os
from absl import flags
import tensorflow as tf
from lstm_object_detection import model_builder
from lstm_object_detection import seq_dataset_builder
from lstm_object_detection import trainer
from lstm_object_detection.utils import config_util
from google3.third_party.tensorflow_models.object_detection.builders import preprocessor_builder
flags.DEFINE_string('master', '', 'Name of the TensorFlow master to use.')
flags.DEFINE_integer('task', 0, 'task id')
flags.DEFINE_integer('num_clones', 1, 'Number of clones to deploy per worker.')
flags.DEFINE_boolean(
'clone_on_cpu', False,
'Force clones to be deployed on CPU. Note that even if '
'set to False (allowing ops to run on gpu), some ops may '
'still be run on the CPU if they have no GPU kernel.')
flags.DEFINE_integer('worker_replicas', 1, 'Number of worker+trainer '
'replicas.')
flags.DEFINE_integer(
'ps_tasks', 0, 'Number of parameter server tasks. If None, does not use '
'a parameter server.')
flags.DEFINE_string(
'train_dir', '',
'Directory to save the checkpoints and training summaries.')
flags.DEFINE_string(
'pipeline_config_path', '',
'Path to a pipeline_pb2.TrainEvalPipelineConfig config '
'file. If provided, other configs are ignored')
flags.DEFINE_string('train_config_path', '',
'Path to a train_pb2.TrainConfig config file.')
flags.DEFINE_string('input_config_path', '',
'Path to an input_reader_pb2.InputReader config file.')
flags.DEFINE_string('model_config_path', '',
'Path to a model_pb2.DetectionModel config file.')
FLAGS = flags.FLAGS
def main(_):
assert FLAGS.train_dir, '`train_dir` is missing.'
if FLAGS.task == 0:
tf.gfile.MakeDirs(FLAGS.train_dir)
if FLAGS.pipeline_config_path:
configs = config_util.get_configs_from_pipeline_file(
FLAGS.pipeline_config_path)
if FLAGS.task == 0:
tf.gfile.Copy(
FLAGS.pipeline_config_path,
os.path.join(FLAGS.train_dir, 'pipeline.config'),
overwrite=True)
else:
configs = config_util.get_configs_from_multiple_files(
model_config_path=FLAGS.model_config_path,
train_config_path=FLAGS.train_config_path,
train_input_config_path=FLAGS.input_config_path)
if FLAGS.task == 0:
for name, config in [('model.config', FLAGS.model_config_path),
('train.config', FLAGS.train_config_path),
('input.config', FLAGS.input_config_path)]:
tf.gfile.Copy(
config, os.path.join(FLAGS.train_dir, name), overwrite=True)
model_config = configs['model']
lstm_config = configs['lstm_model']
train_config = configs['train_config']
input_config = configs['train_input_config']
model_fn = functools.partial(
model_builder.build,
model_config=model_config,
lstm_config=lstm_config,
is_training=True)
def get_next(config, model_config, lstm_config, unroll_length):
data_augmentation_options = [
preprocessor_builder.build(step)
for step in train_config.data_augmentation_options
]
return seq_dataset_builder.build(
config,
model_config,
lstm_config,
unroll_length,
data_augmentation_options,
batch_size=train_config.batch_size)
create_input_dict_fn = functools.partial(get_next, input_config, model_config,
lstm_config,
lstm_config.train_unroll_length)
env = json.loads(os.environ.get('TF_CONFIG', '{}'))
cluster_data = env.get('cluster', None)
cluster = tf.train.ClusterSpec(cluster_data) if cluster_data else None
task_data = env.get('task', None) or {'type': 'master', 'index': 0}
task_info = type('TaskSpec', (object,), task_data)
# Parameters for a single worker.
ps_tasks = 0
worker_replicas = 1
worker_job_name = 'lonely_worker'
task = 0
is_chief = True
master = ''
if cluster_data and 'worker' in cluster_data:
# The total number of worker replicas includes the "worker"s and the "master".
worker_replicas = len(cluster_data['worker']) + 1
if cluster_data and 'ps' in cluster_data:
ps_tasks = len(cluster_data['ps'])
if worker_replicas > 1 and ps_tasks < 1:
raise ValueError('At least 1 ps task is needed for distributed training.')
if worker_replicas >= 1 and ps_tasks > 0:
# Set up distributed training.
server = tf.train.Server(
tf.train.ClusterSpec(cluster),
protocol='grpc',
job_name=task_info.type,
task_index=task_info.index)
if task_info.type == 'ps':
server.join()
return
worker_job_name = '%s/task:%d' % (task_info.type, task_info.index)
task = task_info.index
is_chief = (task_info.type == 'master')
master = server.target
trainer.train(create_input_dict_fn, model_fn, train_config, master, task,
FLAGS.num_clones, worker_replicas, FLAGS.clone_on_cpu, ps_tasks,
worker_job_name, is_chief, FLAGS.train_dir)
if __name__ == '__main__':
tf.app.run()
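main() above reads the cluster layout from the standard TF_CONFIG environment variable. As a minimal sketch, a worker process in a distributed job might export something like the following (hostnames, ports, and the task assignment are placeholders):

import json
import os

# Hypothetical cluster: one master, two workers, one parameter server.
os.environ['TF_CONFIG'] = json.dumps({
    'cluster': {
        'master': ['host0:2222'],
        'worker': ['host1:2222', 'host2:2222'],
        'ps': ['host3:2222'],
    },
    # This process runs the first worker task; main() derives worker_replicas,
    # ps_tasks, and the worker job name from these entries.
    'task': {'type': 'worker', 'index': 0},
})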
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Detection model trainer.
This file provides a generic training method that can be used to train a
DetectionModel.
"""
import functools
import tensorflow as tf
from google3.pyglib import logging
from google3.third_party.tensorflow_models.object_detection.builders import optimizer_builder
from google3.third_party.tensorflow_models.object_detection.core import standard_fields as fields
from google3.third_party.tensorflow_models.object_detection.utils import ops as util_ops
from google3.third_party.tensorflow_models.object_detection.utils import variables_helper
from deployment import model_deploy
slim = tf.contrib.slim
def create_input_queue(create_tensor_dict_fn):
"""Sets up reader, prefetcher and returns input queue.
Args:
create_tensor_dict_fn: function to create tensor dictionary.
Returns:
all_dict: A dictionary holding tensors for images, boxes, and targets.
"""
tensor_dict = create_tensor_dict_fn()
all_dict = {}
num_images = len(tensor_dict[fields.InputDataFields.image])
all_dict['batch'] = tensor_dict['batch']
del tensor_dict['batch']
for i in range(num_images):
suffix = str(i)
for key, val in tensor_dict.items():
all_dict[key + suffix] = val[i]
all_dict[fields.InputDataFields.image + suffix] = tf.to_float(
tf.expand_dims(all_dict[fields.InputDataFields.image + suffix], 0))
return all_dict
def get_inputs(input_queue, num_classes, merge_multiple_label_boxes=False):
"""Dequeues batch and constructs inputs to object detection model.
Args:
input_queue: BatchQueue object holding enqueued tensor_dicts.
num_classes: Number of classes.
merge_multiple_label_boxes: Whether to merge boxes with multiple labels
or not. Defaults to false. Merged boxes are represented with a single
box and a k-hot encoding of the multiple labels associated with the
boxes.
Returns:
images: a list of 3-D float tensor of images.
image_keys: a list of string keys for the images.
locations: a list of tensors of shape [num_boxes, 4] containing the corners
of the groundtruth boxes.
classes: a list of padded one-hot tensors containing target classes.
masks: a list of 3-D float tensors of shape [num_boxes, image_height,
image_width] containing instance masks for objects if present in the
input_queue. Else returns None.
keypoints: a list of 3-D float tensors of shape [num_boxes, num_keypoints,
2] containing keypoints for objects if present in the
input queue. Else returns None.
"""
read_data_list = input_queue
label_id_offset = 1
def extract_images_and_targets(read_data):
"""Extract images and targets from the input dict."""
suffix = 0
images = []
keys = []
locations = []
classes = []
masks = []
keypoints = []
while fields.InputDataFields.image + str(suffix) in read_data:
image = read_data[fields.InputDataFields.image + str(suffix)]
key = ''
if fields.InputDataFields.source_id in read_data:
key = read_data[fields.InputDataFields.source_id + str(suffix)]
location_gt = (
read_data[fields.InputDataFields.groundtruth_boxes + str(suffix)])
classes_gt = tf.cast(
read_data[fields.InputDataFields.groundtruth_classes + str(suffix)],
tf.int32)
classes_gt -= label_id_offset
masks_gt = read_data.get(
fields.InputDataFields.groundtruth_instance_masks + str(suffix))
keypoints_gt = read_data.get(
fields.InputDataFields.groundtruth_keypoints + str(suffix))
if merge_multiple_label_boxes:
location_gt, classes_gt, _ = util_ops.merge_boxes_with_multiple_labels(
location_gt, classes_gt, num_classes)
else:
classes_gt = util_ops.padded_one_hot_encoding(
indices=classes_gt, depth=num_classes, left_pad=0)
# Batch-read input data and groundtruth. Images, locations, and classes
# should by default have the same number of items.
images.append(image)
keys.append(key)
locations.append(location_gt)
classes.append(classes_gt)
masks.append(masks_gt)
keypoints.append(keypoints_gt)
suffix += 1
return (images, keys, locations, classes, masks, keypoints)
return extract_images_and_targets(read_data_list)
def _create_losses(input_queue, create_model_fn, train_config):
"""Creates loss function for a DetectionModel.
Args:
input_queue: BatchQueue object holding enqueued tensor_dicts.
create_model_fn: A function to create the DetectionModel.
train_config: a train_pb2.TrainConfig protobuf.
"""
detection_model = create_model_fn()
(images, _, groundtruth_boxes_list, groundtruth_classes_list,
groundtruth_masks_list, groundtruth_keypoints_list) = get_inputs(
input_queue, detection_model.num_classes,
train_config.merge_multiple_label_boxes)
preprocessed_images = []
true_image_shapes = []
for image in images:
resized_image, true_image_shape = detection_model.preprocess(image)
preprocessed_images.append(resized_image)
true_image_shapes.append(true_image_shape)
images = tf.concat(preprocessed_images, 0)
true_image_shapes = tf.concat(true_image_shapes, 0)
if any(mask is None for mask in groundtruth_masks_list):
groundtruth_masks_list = None
if any(keypoints is None for keypoints in groundtruth_keypoints_list):
groundtruth_keypoints_list = None
detection_model.provide_groundtruth(
groundtruth_boxes_list, groundtruth_classes_list, groundtruth_masks_list,
groundtruth_keypoints_list)
prediction_dict = detection_model.predict(images, true_image_shapes,
input_queue['batch'])
losses_dict = detection_model.loss(prediction_dict, true_image_shapes)
for loss_tensor in losses_dict.values():
tf.losses.add_loss(loss_tensor)
def get_restore_checkpoint_ops(restore_checkpoints, detection_model,
train_config):
"""Restore checkpoint from saved checkpoints.
Args:
restore_checkpoints: A list of checkpoint paths to restore from.
detection_model: Object detection model built from config file.
train_config: a train_pb2.TrainConfig protobuf.
Returns:
restorers: A list of savers used to initialize the model from the checkpoints.
"""
restorers = []
vars_restored = []
for restore_checkpoint in restore_checkpoints:
var_map = detection_model.restore_map(
fine_tune_checkpoint_type=train_config.fine_tune_checkpoint_type)
available_var_map = (
variables_helper.get_variables_available_in_checkpoint(
var_map, restore_checkpoint))
for var_name, var in available_var_map.iteritems():
if var in vars_restored:
logging.info('Variable %s contained in multiple checkpoints',
var.op.name)
del available_var_map[var_name]
else:
vars_restored.append(var)
# Initialize from ExponentialMovingAverages if possible.
available_ema_var_map = {}
ckpt_reader = tf.train.NewCheckpointReader(restore_checkpoint)
ckpt_vars_to_shape_map = ckpt_reader.get_variable_to_shape_map()
for var_name, var in available_var_map.iteritems():
var_name_ema = var_name + '/ExponentialMovingAverage'
if var_name_ema in ckpt_vars_to_shape_map:
available_ema_var_map[var_name_ema] = var
else:
available_ema_var_map[var_name] = var
available_var_map = available_ema_var_map
init_saver = tf.train.Saver(available_var_map)
if available_var_map.keys():
restorers.append(init_saver)
else:
logging.info('WARNING: Checkpoint %s has no restorable variables',
restore_checkpoint)
return restorers
def train(create_tensor_dict_fn,
create_model_fn,
train_config,
master,
task,
num_clones,
worker_replicas,
clone_on_cpu,
ps_tasks,
worker_job_name,
is_chief,
train_dir,
graph_hook_fn=None):
"""Training function for detection models.
Args:
create_tensor_dict_fn: a function to create a tensor input dictionary.
create_model_fn: a function that creates a DetectionModel and generates
losses.
train_config: a train_pb2.TrainConfig protobuf.
master: BNS name of the TensorFlow master to use.
task: The task id of this training instance.
num_clones: The number of clones to run per machine.
worker_replicas: The number of work replicas to train with.
clone_on_cpu: True if clones should be forced to run on CPU.
ps_tasks: Number of parameter server tasks.
worker_job_name: Name of the worker job.
is_chief: Whether this replica is the chief replica.
train_dir: Directory to write checkpoints and training summaries to.
graph_hook_fn: Optional function that is called after the training graph is
completely built. This is helpful to perform additional changes to the
training graph such as optimizing batchnorm. The function should modify
the default graph.
"""
detection_model = create_model_fn()
with tf.Graph().as_default():
# Build a configuration specifying multi-GPU and multi-replicas.
deploy_config = model_deploy.DeploymentConfig(
num_clones=num_clones,
clone_on_cpu=clone_on_cpu,
replica_id=task,
num_replicas=worker_replicas,
num_ps_tasks=ps_tasks,
worker_job_name=worker_job_name)
# Place the global step on the device storing the variables.
with tf.device(deploy_config.variables_device()):
global_step = slim.create_global_step()
with tf.device(deploy_config.inputs_device()):
input_queue = create_input_queue(create_tensor_dict_fn)
# Gather initial summaries.
# TODO(rathodv): See if summaries can be added/extracted from global tf
# collections so that they don't have to be passed around.
summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
global_summaries = set([])
model_fn = functools.partial(
_create_losses,
create_model_fn=create_model_fn,
train_config=train_config)
clones = model_deploy.create_clones(deploy_config, model_fn, [input_queue])
first_clone_scope = clones[0].scope
# Gather update_ops from the first clone. These contain, for example,
# the updates for the batch_norm variables created by model_fn.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)
with tf.device(deploy_config.optimizer_device()):
training_optimizer, optimizer_summary_vars = optimizer_builder.build(
train_config.optimizer)
for var in optimizer_summary_vars:
tf.summary.scalar(var.op.name, var)
sync_optimizer = None
if train_config.sync_replicas:
training_optimizer = tf.train.SyncReplicasOptimizer(
training_optimizer,
replicas_to_aggregate=train_config.replicas_to_aggregate,
total_num_replicas=train_config.worker_replicas)
sync_optimizer = training_optimizer
# Create ops required to initialize the model from a given checkpoint.
init_fn = None
if train_config.fine_tune_checkpoint:
restore_checkpoints = [
path.strip() for path in train_config.fine_tune_checkpoint.split(',')
]
restorers = get_restore_checkpoint_ops(restore_checkpoints,
detection_model, train_config)
def initializer_fn(sess):
for i, restorer in enumerate(restorers):
restorer.restore(sess, restore_checkpoints[i])
init_fn = initializer_fn
with tf.device(deploy_config.optimizer_device()):
regularization_losses = (
None if train_config.add_regularization_loss else [])
total_loss, grads_and_vars = model_deploy.optimize_clones(
clones,
training_optimizer,
regularization_losses=regularization_losses)
total_loss = tf.check_numerics(total_loss, 'LossTensor is inf or nan.')
# Optionally multiply bias gradients by train_config.bias_grad_multiplier.
if train_config.bias_grad_multiplier:
biases_regex_list = ['.*/biases']
grads_and_vars = variables_helper.multiply_gradients_matching_regex(
grads_and_vars,
biases_regex_list,
multiplier=train_config.bias_grad_multiplier)
# Optionally clip gradients
if train_config.gradient_clipping_by_norm > 0:
with tf.name_scope('clip_grads'):
grads_and_vars = slim.learning.clip_gradient_norms(
grads_and_vars, train_config.gradient_clipping_by_norm)
moving_average_variables = slim.get_model_variables()
variable_averages = tf.train.ExponentialMovingAverage(0.9999, global_step)
update_ops.append(variable_averages.apply(moving_average_variables))
# Create gradient updates.
grad_updates = training_optimizer.apply_gradients(
grads_and_vars, global_step=global_step)
update_ops.append(grad_updates)
update_op = tf.group(*update_ops, name='update_barrier')
with tf.control_dependencies([update_op]):
train_tensor = tf.identity(total_loss, name='train_op')
if graph_hook_fn:
with tf.device(deploy_config.variables_device()):
graph_hook_fn()
# Add summaries.
for model_var in slim.get_model_variables():
global_summaries.add(tf.summary.histogram(model_var.op.name, model_var))
for loss_tensor in tf.losses.get_losses():
global_summaries.add(tf.summary.scalar(loss_tensor.op.name, loss_tensor))
global_summaries.add(
tf.summary.scalar('TotalLoss', tf.losses.get_total_loss()))
# Add the summaries from the first clone. These contain the summaries
# created by model_fn and either optimize_clones() or _gather_clone_loss().
summaries |= set(
tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))
summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES, 'critic_loss'))
summaries |= global_summaries
# Merge all summaries together.
summary_op = tf.summary.merge(list(summaries), name='summary_op')
# Soft placement allows placing on CPU ops without GPU implementation.
session_config = tf.ConfigProto(
allow_soft_placement=True, log_device_placement=False)
# Save checkpoints regularly.
keep_checkpoint_every_n_hours = train_config.keep_checkpoint_every_n_hours
saver = tf.train.Saver(
keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours)
slim.learning.train(
train_tensor,
logdir=train_dir,
master=master,
is_chief=is_chief,
session_config=session_config,
startup_delay_steps=train_config.startup_delay_steps,
init_fn=init_fn,
summary_op=summary_op,
number_of_steps=(train_config.num_steps
if train_config.num_steps else None),
save_summaries_secs=120,
sync_optimizer=sync_optimizer,
saver=saver)
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Added functionality to load from pipeline config for lstm framework."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from google.protobuf import text_format
from lstm_object_detection.protos import input_reader_google_pb2 # pylint: disable=unused-import
from lstm_object_detection.protos import pipeline_pb2 as internal_pipeline_pb2
from google3.third_party.tensorflow_models.object_detection.protos import pipeline_pb2
from google3.third_party.tensorflow_models.object_detection.utils import config_util
def get_configs_from_pipeline_file(pipeline_config_path):
"""Reads configuration from a pipeline_pb2.TrainEvalPipelineConfig.
Args:
pipeline_config_path: Path to pipeline_pb2.TrainEvalPipelineConfig text
proto.
Returns:
Dictionary of configuration objects. Keys are `model`, `train_config`,
`train_input_config`, `eval_config`, `eval_input_config`, `lstm_model`.
Values are the corresponding config objects.
"""
pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
with tf.gfile.GFile(pipeline_config_path, "r") as f:
proto_str = f.read()
text_format.Merge(proto_str, pipeline_config)
configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
if pipeline_config.HasExtension(internal_pipeline_pb2.lstm_model):
configs["lstm_model"] = pipeline_config.Extensions[
internal_pipeline_pb2.lstm_model]
return configs
def create_pipeline_proto_from_configs(configs):
"""Creates a pipeline_pb2.TrainEvalPipelineConfig from configs dictionary.
This function nearly performs the inverse operation of
get_configs_from_pipeline_file(). Instead of returning a file path, it returns
a `TrainEvalPipelineConfig` object.
Args:
configs: Dictionary of configs. See get_configs_from_pipeline_file().
Returns:
A fully populated pipeline_pb2.TrainEvalPipelineConfig.
"""
pipeline_config = config_util.create_pipeline_proto_from_configs(configs)
if "lstm_model" in configs:
pipeline_config.Extensions[internal_pipeline_pb2.lstm_model].CopyFrom(
configs["lstm_model"])
return pipeline_config
def get_configs_from_multiple_files(model_config_path="",
train_config_path="",
train_input_config_path="",
eval_config_path="",
eval_input_config_path="",
lstm_config_path=""):
"""Reads training configuration from multiple config files.
Args:
model_config_path: Path to model_pb2.DetectionModel.
train_config_path: Path to train_pb2.TrainConfig.
train_input_config_path: Path to input_reader_pb2.InputReader.
eval_config_path: Path to eval_pb2.EvalConfig.
eval_input_config_path: Path to input_reader_pb2.InputReader.
lstm_config_path: Path to pipeline_pb2.LstmModel.
Returns:
Dictionary of configuration objects. Keys are `model`, `train_config`,
`train_input_config`, `eval_config`, `eval_input_config`, `lstm_model`.
Key/Values are returned only for valid (non-empty) strings.
"""
configs = config_util.get_configs_from_multiple_files(
model_config_path=model_config_path,
train_config_path=train_config_path,
train_input_config_path=train_input_config_path,
eval_config_path=eval_config_path,
eval_input_config_path=eval_input_config_path)
if lstm_config_path:
lstm_config = internal_pipeline_pb2.LstmModel()
with tf.gfile.GFile(lstm_config_path, "r") as f:
text_format.Merge(f.read(), lstm_config)
configs["lstm_model"] = lstm_config
return configs
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for object_detection.utils.config_util."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import tensorflow as tf
from google.protobuf import text_format
from lstm_object_detection.protos import pipeline_pb2 as internal_pipeline_pb2
from lstm_object_detection.utils import config_util
from google3.third_party.tensorflow_models.object_detection.protos import pipeline_pb2
def _write_config(config, config_path):
"""Writes a config object to disk."""
config_text = text_format.MessageToString(config)
with tf.gfile.Open(config_path, "wb") as f:
f.write(config_text)
class ConfigUtilTest(tf.test.TestCase):
def test_get_configs_from_pipeline_file(self):
"""Test that proto configs can be read from pipeline config file."""
pipeline_config_path = os.path.join(self.get_temp_dir(), "pipeline.config")
pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
pipeline_config.model.ssd.num_classes = 10
pipeline_config.train_config.batch_size = 32
pipeline_config.train_input_reader.label_map_path = "path/to/label_map"
pipeline_config.eval_config.num_examples = 20
pipeline_config.eval_input_reader.add().queue_capacity = 100
pipeline_config.Extensions[
internal_pipeline_pb2.lstm_model].train_unroll_length = 5
pipeline_config.Extensions[
internal_pipeline_pb2.lstm_model].eval_unroll_length = 10
_write_config(pipeline_config, pipeline_config_path)
configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
self.assertProtoEquals(pipeline_config.model, configs["model"])
self.assertProtoEquals(pipeline_config.train_config,
configs["train_config"])
self.assertProtoEquals(pipeline_config.train_input_reader,
configs["train_input_config"])
self.assertProtoEquals(pipeline_config.eval_config, configs["eval_config"])
self.assertProtoEquals(pipeline_config.eval_input_reader,
configs["eval_input_configs"])
self.assertProtoEquals(
pipeline_config.Extensions[internal_pipeline_pb2.lstm_model],
configs["lstm_model"])
def test_create_pipeline_proto_from_configs(self):
"""Tests that proto can be reconstructed from configs dictionary."""
pipeline_config_path = os.path.join(self.get_temp_dir(), "pipeline.config")
pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
pipeline_config.model.ssd.num_classes = 10
pipeline_config.train_config.batch_size = 32
pipeline_config.train_input_reader.label_map_path = "path/to/label_map"
pipeline_config.eval_config.num_examples = 20
pipeline_config.eval_input_reader.add().queue_capacity = 100
pipeline_config.Extensions[
internal_pipeline_pb2.lstm_model].train_unroll_length = 5
pipeline_config.Extensions[
internal_pipeline_pb2.lstm_model].eval_unroll_length = 10
_write_config(pipeline_config, pipeline_config_path)
configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
pipeline_config_reconstructed = (
config_util.create_pipeline_proto_from_configs(configs))
self.assertEqual(pipeline_config, pipeline_config_reconstructed)
if __name__ == "__main__":
tf.test.main()