ModelZoo / ResNet50_tensorflow / Commits

Commit 0d8e49ec
Authored Jul 30, 2018 by Yinxiao Li; committed by dreamdragon on Oct 24, 2018
PiperOrigin-RevId: 206648257
Parent: d7676c1c
Showing 7 of 27 changed files, with 1393 additions and 0 deletions (+1393 −0)
research/lstm_object_detection/seq_dataset_builder_test.py  +283 −0
research/lstm_object_detection/tf_sequence_example_decoder.py  +204 −0
research/lstm_object_detection/tf_sequence_example_decoder_test.py  +111 −0
research/lstm_object_detection/train.py  +185 −0
research/lstm_object_detection/trainer.py  +410 −0
research/lstm_object_detection/utils/config_util.py  +106 −0
research/lstm_object_detection/utils/config_util_test.py  +94 −0
research/lstm_object_detection/seq_dataset_builder_test.py (new file, mode 100644)
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for dataset_builder."""

import os
import numpy as np
import tensorflow as tf

from google.protobuf import text_format
from google3.testing.pybase import parameterized
from tensorflow.core.example import example_pb2
from tensorflow.core.example import feature_pb2
from lstm_object_detection import seq_dataset_builder
from lstm_object_detection.protos import pipeline_pb2 as internal_pipeline_pb2
from google3.third_party.tensorflow_models.object_detection.builders import preprocessor_builder
from google3.third_party.tensorflow_models.object_detection.core import standard_fields as fields
from google3.third_party.tensorflow_models.object_detection.protos import input_reader_pb2
from google3.third_party.tensorflow_models.object_detection.protos import pipeline_pb2
from google3.third_party.tensorflow_models.object_detection.protos import preprocessor_pb2


class DatasetBuilderTest(parameterized.TestCase):

  def _create_tf_record(self):
    path = os.path.join(self.get_temp_dir(), 'tfrecord')
    writer = tf.python_io.TFRecordWriter(path)

    image_tensor = np.random.randint(255, size=(16, 16, 3)).astype(np.uint8)
    with self.test_session():
      encoded_jpeg = tf.image.encode_jpeg(tf.constant(image_tensor)).eval()

    sequence_example = example_pb2.SequenceExample(
        context=feature_pb2.Features(
            feature={
                'image/format':
                    feature_pb2.Feature(
                        bytes_list=feature_pb2.BytesList(
                            value=['jpeg'.encode('utf-8')])),
                'image/height':
                    feature_pb2.Feature(
                        int64_list=feature_pb2.Int64List(value=[16])),
                'image/width':
                    feature_pb2.Feature(
                        int64_list=feature_pb2.Int64List(value=[16])),
            }),
        feature_lists=feature_pb2.FeatureLists(
            feature_list={
                'image/encoded':
                    feature_pb2.FeatureList(feature=[
                        feature_pb2.Feature(
                            bytes_list=feature_pb2.BytesList(
                                value=[encoded_jpeg])),
                    ]),
                'image/object/bbox/xmin':
                    feature_pb2.FeatureList(feature=[
                        feature_pb2.Feature(
                            float_list=feature_pb2.FloatList(value=[0.0])),
                    ]),
                'image/object/bbox/xmax':
                    feature_pb2.FeatureList(feature=[
                        feature_pb2.Feature(
                            float_list=feature_pb2.FloatList(value=[1.0]))
                    ]),
                'image/object/bbox/ymin':
                    feature_pb2.FeatureList(feature=[
                        feature_pb2.Feature(
                            float_list=feature_pb2.FloatList(value=[0.0])),
                    ]),
                'image/object/bbox/ymax':
                    feature_pb2.FeatureList(feature=[
                        feature_pb2.Feature(
                            float_list=feature_pb2.FloatList(value=[1.0]))
                    ]),
                'image/object/class/label':
                    feature_pb2.FeatureList(feature=[
                        feature_pb2.Feature(
                            int64_list=feature_pb2.Int64List(value=[2]))
                    ]),
            }))
    writer.write(sequence_example.SerializeToString())
    writer.close()
    return path

  def _get_model_configs_from_proto(self):
    """Creates a model text proto for testing.

    Returns:
      A dictionary of model configs.
    """
    model_text_proto = """
    [object_detection.protos.lstm_model] {
      train_unroll_length: 4
      eval_unroll_length: 4
    }
    model {
      ssd {
        feature_extractor {
          type: 'lstm_mobilenet_v1_fpn'
          conv_hyperparams {
            regularizer {
              l2_regularizer {
              }
            }
            initializer {
              truncated_normal_initializer {
              }
            }
          }
        }
        negative_class_weight: 2.0
        box_coder {
          faster_rcnn_box_coder {
          }
        }
        matcher {
          argmax_matcher {
          }
        }
        similarity_calculator {
          iou_similarity {
          }
        }
        anchor_generator {
          ssd_anchor_generator {
            aspect_ratios: 1.0
          }
        }
        image_resizer {
          fixed_shape_resizer {
            height: 32
            width: 32
          }
        }
        box_predictor {
          convolutional_box_predictor {
            conv_hyperparams {
              regularizer {
                l2_regularizer {
                }
              }
              initializer {
                truncated_normal_initializer {
                }
              }
            }
          }
        }
        normalize_loc_loss_by_codesize: true
        loss {
          classification_loss {
            weighted_softmax {
            }
          }
          localization_loss {
            weighted_smooth_l1 {
            }
          }
        }
      }
    }"""

    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    text_format.Merge(model_text_proto, pipeline_config)
    configs = {}
    configs['model'] = pipeline_config.model
    configs['lstm_model'] = pipeline_config.Extensions[
        internal_pipeline_pb2.lstm_model]

    return configs

  def _get_data_augmentation_preprocessor_proto(self):
    preprocessor_text_proto = """
    random_horizontal_flip {
    }
    """
    preprocessor_proto = preprocessor_pb2.PreprocessingStep()
    text_format.Merge(preprocessor_text_proto, preprocessor_proto)
    return preprocessor_proto

  def _create_training_dict(self, tensor_dict):
    image_dict = {}
    all_dict = {}
    all_dict['batch'] = tensor_dict.pop('batch')
    for i, _ in enumerate(tensor_dict[fields.InputDataFields.image]):
      for key, val in tensor_dict.items():
        image_dict[key] = val[i]

      image_dict[fields.InputDataFields.image] = tf.to_float(
          tf.expand_dims(image_dict[fields.InputDataFields.image], 0))
      suffix = str(i)
      for key, val in image_dict.items():
        all_dict[key + suffix] = val
    return all_dict

  def _get_input_proto(self, input_reader):
    return """
      external_input_reader {
        [lstm_object_detection.input_readers.GoogleInputReader.google_input_reader] {
          %s: {
            input_path: '{0}'
            data_type: TF_SEQUENCE_EXAMPLE
            video_length: 4
          }
        }
      }
    """ % input_reader

  @parameterized.named_parameters(
      ('tf_record', 'tf_record_video_input_reader'))
  def test_video_input_reader(self, video_input_type):
    input_reader_proto = input_reader_pb2.InputReader()
    text_format.Merge(
        self._get_input_proto(video_input_type), input_reader_proto)

    configs = self._get_model_configs_from_proto()
    tensor_dict = seq_dataset_builder.build(
        input_reader_proto,
        configs['model'],
        configs['lstm_model'],
        unroll_length=1)

    all_dict = self._create_training_dict(tensor_dict)

    self.assertEqual((1, 32, 32, 3), all_dict['image0'].shape)
    self.assertEqual(4, all_dict['groundtruth_boxes0'].shape[1])

  def test_build_with_data_augmentation(self):
    input_reader_proto = input_reader_pb2.InputReader()
    text_format.Merge(
        self._get_input_proto('tf_record_video_input_reader'),
        input_reader_proto)

    configs = self._get_model_configs_from_proto()
    data_augmentation_options = [
        preprocessor_builder.build(
            self._get_data_augmentation_preprocessor_proto())
    ]
    tensor_dict = seq_dataset_builder.build(
        input_reader_proto,
        configs['model'],
        configs['lstm_model'],
        unroll_length=1,
        data_augmentation_options=data_augmentation_options)

    all_dict = self._create_training_dict(tensor_dict)
    self.assertEqual((1, 32, 32, 3), all_dict['image0'].shape)
    self.assertEqual(4, all_dict['groundtruth_boxes0'].shape[1])

  def test_raises_error_without_input_paths(self):
    input_reader_text_proto = """
      shuffle: false
      num_readers: 1
      load_instance_masks: true
    """
    input_reader_proto = input_reader_pb2.InputReader()
    text_format.Merge(input_reader_text_proto, input_reader_proto)

    configs = self._get_model_configs_from_proto()
    with self.assertRaises(ValueError):
      _ = seq_dataset_builder.build(
          input_reader_proto,
          configs['model'],
          configs['lstm_model'],
          unroll_length=1)


if __name__ == '__main__':
  tf.test.main()
research/lstm_object_detection/tf_sequence_example_decoder.py (new file, mode 100644)
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tensorflow Sequence Example proto decoder.
A decoder to decode string tensors containing serialized
tensorflow.SequenceExample protos.
TODO(yinxiao): When TensorFlow object detection API officially supports
tensorflow.SequenceExample, merge this decoder.
"""
import tensorflow as tf
from google3.learning.brain.contrib.slim.data import tfexample_decoder
from google3.third_party.tensorflow_models.object_detection.core import data_decoder
from google3.third_party.tensorflow_models.object_detection.core import standard_fields as fields

slim_example_decoder = tf.contrib.slim.tfexample_decoder


class TfSequenceExampleDecoder(data_decoder.DataDecoder):
  """Tensorflow Sequence Example proto decoder."""

  def __init__(self):
    """Constructor sets keys_to_features and items_to_handlers."""
    self.keys_to_context_features = {
        'image/format':
            tf.FixedLenFeature((), tf.string, default_value='jpeg'),
        'image/filename':
            tf.FixedLenFeature((), tf.string, default_value=''),
        'image/key/sha256':
            tf.FixedLenFeature((), tf.string, default_value=''),
        'image/source_id':
            tf.FixedLenFeature((), tf.string, default_value=''),
        'image/height':
            tf.FixedLenFeature((), tf.int64, 1),
        'image/width':
            tf.FixedLenFeature((), tf.int64, 1),
    }
    self.keys_to_features = {
        'image/encoded': tf.FixedLenSequenceFeature((), tf.string),
        'bbox/xmin': tf.VarLenFeature(dtype=tf.float32),
        'bbox/xmax': tf.VarLenFeature(dtype=tf.float32),
        'bbox/ymin': tf.VarLenFeature(dtype=tf.float32),
        'bbox/ymax': tf.VarLenFeature(dtype=tf.float32),
        'bbox/label/index': tf.VarLenFeature(dtype=tf.int64),
        'bbox/label/string': tf.VarLenFeature(tf.string),
        'area': tf.VarLenFeature(tf.float32),
        'is_crowd': tf.VarLenFeature(tf.int64),
        'difficult': tf.VarLenFeature(tf.int64),
        'group_of': tf.VarLenFeature(tf.int64),
    }
    self.items_to_handlers = {
        fields.InputDataFields.image:
            slim_example_decoder.Image(
                image_key='image/encoded',
                format_key='image/format',
                channels=3,
                repeated=True),
        fields.InputDataFields.source_id: (
            slim_example_decoder.Tensor('image/source_id')),
        fields.InputDataFields.key: (
            slim_example_decoder.Tensor('image/key/sha256')),
        fields.InputDataFields.filename: (
            slim_example_decoder.Tensor('image/filename')),
        # Object boxes and classes.
        fields.InputDataFields.groundtruth_boxes:
            tfexample_decoder.BoundingBoxSequence(prefix='bbox/'),
        fields.InputDataFields.groundtruth_classes: (
            slim_example_decoder.Tensor('bbox/label/index')),
        fields.InputDataFields.groundtruth_area:
            slim_example_decoder.Tensor('area'),
        fields.InputDataFields.groundtruth_is_crowd: (
            slim_example_decoder.Tensor('is_crowd')),
        fields.InputDataFields.groundtruth_difficult: (
            slim_example_decoder.Tensor('difficult')),
        fields.InputDataFields.groundtruth_group_of: (
            slim_example_decoder.Tensor('group_of'))
    }

  def decode(self, tf_seq_example_string_tensor, items=None):
    """Decodes serialized tf.SequenceExample and returns a tensor dictionary.

    Args:
      tf_seq_example_string_tensor: A string tensor holding a serialized
        tensorflow example proto.
      items: The list of items to decode. These must be a subset of the item
        keys in self._items_to_handlers. If `items` is left as None, then all
        of the items in self._items_to_handlers are decoded.

    Returns:
      A dictionary of the following tensors.
      fields.InputDataFields.image - 3D uint8 tensor of shape [None, None, seq]
        containing image(s).
      fields.InputDataFields.source_id - string tensor containing original
        image id.
      fields.InputDataFields.key - string tensor with unique sha256 hash key.
      fields.InputDataFields.filename - string tensor with original dataset
        filename.
      fields.InputDataFields.groundtruth_boxes - 2D float32 tensor of shape
        [None, 4] containing box corners.
      fields.InputDataFields.groundtruth_classes - 1D int64 tensor of shape
        [None] containing classes for the boxes.
      fields.InputDataFields.groundtruth_area - 1D float32 tensor of shape
        [None] containing object mask area in pixel squared.
      fields.InputDataFields.groundtruth_is_crowd - 1D bool tensor of shape
        [None] indicating if the boxes enclose a crowd.
      fields.InputDataFields.groundtruth_difficult - 1D bool tensor of shape
        [None] indicating if the boxes represent `difficult` instances.
    """
    serialized_example = tf.reshape(tf_seq_example_string_tensor, shape=[])
    decoder = TFSequenceExampleDecoderHelper(self.keys_to_context_features,
                                             self.keys_to_features,
                                             self.items_to_handlers)
    if not items:
      items = decoder.list_items()
    tensors = decoder.decode(serialized_example, items=items)
    tensor_dict = dict(zip(items, tensors))

    return tensor_dict


class TFSequenceExampleDecoderHelper(data_decoder.DataDecoder):
  """A decoder helper class for TensorFlow SequenceExamples.

  To perform this decoding operation, a SequenceExampleDecoder is given a list
  of ItemHandlers. Each ItemHandler indicates the set of features.
  """

  def __init__(self, keys_to_context_features, keys_to_sequence_features,
               items_to_handlers):
    """Constructs the decoder.

    Args:
      keys_to_context_features: A dictionary from TF-SequenceExample context
        keys to either tf.VarLenFeature or tf.FixedLenFeature instances.
        See tensorflow's parsing_ops.py.
      keys_to_sequence_features: A dictionary from TF-SequenceExample sequence
        keys to either tf.VarLenFeature or tf.FixedLenSequenceFeature instances.
      items_to_handlers: A dictionary from items (strings) to ItemHandler
        instances. Note that the ItemHandler's are provided the keys that they
        use to return the final item Tensors.
    Raises:
      ValueError: If the same key is present for context features and sequence
        features.
    """
    unique_keys = set()
    unique_keys.update(keys_to_context_features)
    unique_keys.update(keys_to_sequence_features)
    if len(unique_keys) != (
        len(keys_to_context_features) + len(keys_to_sequence_features)):
      # This situation is ambiguous in the decoder's keys_to_tensors variable.
      raise ValueError('Context and sequence keys are not unique. \n'
                       ' Context keys: %s \n Sequence keys: %s' %
                       (list(keys_to_context_features.keys()),
                        list(keys_to_sequence_features.keys())))
    self._keys_to_context_features = keys_to_context_features
    self._keys_to_sequence_features = keys_to_sequence_features
    self._items_to_handlers = items_to_handlers

  def list_items(self):
    """Returns keys of items."""
    return self._items_to_handlers.keys()

  def decode(self, serialized_example, items=None):
    """Decodes the given serialized TF-SequenceExample.

    Args:
      serialized_example: A serialized TF-SequenceExample tensor.
      items: The list of items to decode. These must be a subset of the item
        keys in self._items_to_handlers. If `items` is left as None, then all
        of the items in self._items_to_handlers are decoded.

    Returns:
      The decoded items, a list of tensors.
    """
    context, feature_list = tf.parse_single_sequence_example(
        serialized_example, self._keys_to_context_features,
        self._keys_to_sequence_features)
    # Reshape non-sparse elements just once:
    for k in self._keys_to_context_features:
      v = self._keys_to_context_features[k]
      if isinstance(v, tf.FixedLenFeature):
        context[k] = tf.reshape(context[k], v.shape)
    if not items:
      items = self._items_to_handlers.keys()
    outputs = []
    for item in items:
      handler = self._items_to_handlers[item]
      keys_to_tensors = {
          key: context[key] if key in context else feature_list[key]
          for key in handler.keys
      }
      outputs.append(handler.tensors_to_item(keys_to_tensors))
    return outputs
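
Editor's note: a minimal decoding sketch, not part of the commit, assuming `serialized` holds one serialized tf.SequenceExample (for instance, one record from the TFRecord written in seq_dataset_builder_test.py) and the TF 1.x session runtime this code targets.

import tensorflow as tf
from lstm_object_detection import tf_sequence_example_decoder
from google3.third_party.tensorflow_models.object_detection.core import standard_fields as fields

decoder = tf_sequence_example_decoder.TfSequenceExampleDecoder()
tensor_dict = decoder.decode(tf.convert_to_tensor(serialized))
with tf.Session() as sess:
  out = sess.run(tensor_dict)
  # Frames decode to shape [seq, height, width, 3] because the Image handler
  # above is constructed with repeated=True.
  frames = out[fields.InputDataFields.image]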
research/lstm_object_detection/tf_sequence_example_decoder_test.py (new file, mode 100644)
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for video_object_detection.tf_sequence_example_decoder."""

import numpy as np
import tensorflow as tf

from tensorflow.core.example import example_pb2
from tensorflow.core.example import feature_pb2
from lstm_object_detection import tf_sequence_example_decoder
from google3.third_party.tensorflow_models.object_detection.core import standard_fields as fields


class TfSequenceExampleDecoderTest(tf.test.TestCase):
  """Tests for sequence example decoder."""

  def _EncodeImage(self, image_tensor, encoding_type='jpeg'):
    with self.test_session():
      if encoding_type == 'jpeg':
        image_encoded = tf.image.encode_jpeg(tf.constant(image_tensor)).eval()
      else:
        raise ValueError('Invalid encoding type.')
    return image_encoded

  def _DecodeImage(self, image_encoded, encoding_type='jpeg'):
    with self.test_session():
      if encoding_type == 'jpeg':
        image_decoded = tf.image.decode_jpeg(tf.constant(image_encoded)).eval()
      else:
        raise ValueError('Invalid encoding type.')
    return image_decoded

  def testDecodeJpegImageAndBoundingBox(self):
    """Test if the decoder can correctly decode the image and bounding box.

    A set of random images (represented as an image tensor) is first decoded
    as the groundtruth image. Meanwhile, the image tensor is encoded, passed
    through the sequence example, and then decoded as images. The groundtruth
    image and the decoded image are expected to be equal. Similar tests are
    also applied to labels such as bounding box.
    """
    image_tensor = np.random.randint(256, size=(256, 256, 3)).astype(np.uint8)
    encoded_jpeg = self._EncodeImage(image_tensor)
    decoded_jpeg = self._DecodeImage(encoded_jpeg)

    sequence_example = example_pb2.SequenceExample(
        feature_lists=feature_pb2.FeatureLists(
            feature_list={
                'image/encoded':
                    feature_pb2.FeatureList(feature=[
                        feature_pb2.Feature(
                            bytes_list=feature_pb2.BytesList(
                                value=[encoded_jpeg])),
                    ]),
                'bbox/xmin':
                    feature_pb2.FeatureList(feature=[
                        feature_pb2.Feature(
                            float_list=feature_pb2.FloatList(value=[0.0])),
                    ]),
                'bbox/xmax':
                    feature_pb2.FeatureList(feature=[
                        feature_pb2.Feature(
                            float_list=feature_pb2.FloatList(value=[1.0]))
                    ]),
                'bbox/ymin':
                    feature_pb2.FeatureList(feature=[
                        feature_pb2.Feature(
                            float_list=feature_pb2.FloatList(value=[0.0])),
                    ]),
                'bbox/ymax':
                    feature_pb2.FeatureList(feature=[
                        feature_pb2.Feature(
                            float_list=feature_pb2.FloatList(value=[1.0]))
                    ]),
            })).SerializeToString()

    example_decoder = tf_sequence_example_decoder.TfSequenceExampleDecoder()
    tensor_dict = example_decoder.decode(
        tf.convert_to_tensor(sequence_example))

    # Test tensor dict image dimension.
    self.assertAllEqual(
        (tensor_dict[fields.InputDataFields.image].get_shape().as_list()),
        [None, None, None, 3])
    with self.test_session() as sess:
      tensor_dict[fields.InputDataFields.image] = tf.squeeze(
          tensor_dict[fields.InputDataFields.image])
      tensor_dict[fields.InputDataFields.groundtruth_boxes] = tf.squeeze(
          tensor_dict[fields.InputDataFields.groundtruth_boxes])
      tensor_dict = sess.run(tensor_dict)

    # Test decoded image.
    self.assertAllEqual(decoded_jpeg,
                        tensor_dict[fields.InputDataFields.image])
    # Test decoded bounding box.
    self.assertAllEqual([0.0, 0.0, 1.0, 1.0],
                        tensor_dict[fields.InputDataFields.groundtruth_boxes])


if __name__ == '__main__':
  tf.test.main()
research/lstm_object_detection/train.py (new file, mode 100644)
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Training executable for detection models.

This executable is used to train DetectionModels. There are two ways of
configuring the training job:

1) A single pipeline_pb2.TrainEvalPipelineConfig configuration file
can be specified by --pipeline_config_path.

Example usage:
    ./train \
        --logtostderr \
        --train_dir=path/to/train_dir \
        --pipeline_config_path=pipeline_config.pbtxt

2) Three configuration files can be provided: a model_pb2.DetectionModel
configuration file to define what type of DetectionModel is being trained, an
input_reader_pb2.InputReader file to specify what training data will be used
and a train_pb2.TrainConfig file to configure training parameters.

Example usage:
    ./train \
        --logtostderr \
        --train_dir=path/to/train_dir \
        --model_config_path=model_config.pbtxt \
        --train_config_path=train_config.pbtxt \
        --input_config_path=train_input_config.pbtxt
"""

import functools
import json
import os
from absl import flags
import tensorflow as tf
from lstm_object_detection import model_builder
from lstm_object_detection import seq_dataset_builder
from lstm_object_detection import trainer
from lstm_object_detection.utils import config_util
from google3.third_party.tensorflow_models.object_detection.builders import preprocessor_builder

flags.DEFINE_string('master', '', 'Name of the TensorFlow master to use.')
flags.DEFINE_integer('task', 0, 'task id')
flags.DEFINE_integer('num_clones', 1, 'Number of clones to deploy per worker.')
flags.DEFINE_boolean('clone_on_cpu', False,
                     'Force clones to be deployed on CPU. Note that even if '
                     'set to False (allowing ops to run on gpu), some ops may '
                     'still be run on the CPU if they have no GPU kernel.')
flags.DEFINE_integer('worker_replicas', 1, 'Number of worker+trainer '
                     'replicas.')
flags.DEFINE_integer('ps_tasks', 0,
                     'Number of parameter server tasks. If None, does not use '
                     'a parameter server.')
flags.DEFINE_string('train_dir', '',
                    'Directory to save the checkpoints and training summaries.')
flags.DEFINE_string('pipeline_config_path', '',
                    'Path to a pipeline_pb2.TrainEvalPipelineConfig config '
                    'file. If provided, other configs are ignored')
flags.DEFINE_string('train_config_path', '',
                    'Path to a train_pb2.TrainConfig config file.')
flags.DEFINE_string('input_config_path', '',
                    'Path to an input_reader_pb2.InputReader config file.')
flags.DEFINE_string('model_config_path', '',
                    'Path to a model_pb2.DetectionModel config file.')

FLAGS = flags.FLAGS


def main(_):
  assert FLAGS.train_dir, '`train_dir` is missing.'
  if FLAGS.task == 0:
    tf.gfile.MakeDirs(FLAGS.train_dir)
  if FLAGS.pipeline_config_path:
    configs = config_util.get_configs_from_pipeline_file(
        FLAGS.pipeline_config_path)
    if FLAGS.task == 0:
      tf.gfile.Copy(
          FLAGS.pipeline_config_path,
          os.path.join(FLAGS.train_dir, 'pipeline.config'),
          overwrite=True)
  else:
    configs = config_util.get_configs_from_multiple_files(
        model_config_path=FLAGS.model_config_path,
        train_config_path=FLAGS.train_config_path,
        train_input_config_path=FLAGS.input_config_path)
    if FLAGS.task == 0:
      for name, config in [('model.config', FLAGS.model_config_path),
                           ('train.config', FLAGS.train_config_path),
                           ('input.config', FLAGS.input_config_path)]:
        tf.gfile.Copy(config, os.path.join(FLAGS.train_dir, name),
                      overwrite=True)

  model_config = configs['model']
  lstm_config = configs['lstm_model']
  train_config = configs['train_config']
  input_config = configs['train_input_config']

  model_fn = functools.partial(
      model_builder.build,
      model_config=model_config,
      lstm_config=lstm_config,
      is_training=True)

  def get_next(config, model_config, lstm_config, unroll_length):
    data_augmentation_options = [
        preprocessor_builder.build(step)
        for step in train_config.data_augmentation_options
    ]
    return seq_dataset_builder.build(
        config,
        model_config,
        lstm_config,
        unroll_length,
        data_augmentation_options,
        batch_size=train_config.batch_size)

  create_input_dict_fn = functools.partial(get_next, input_config,
                                           model_config, lstm_config,
                                           lstm_config.train_unroll_length)

  env = json.loads(os.environ.get('TF_CONFIG', '{}'))
  cluster_data = env.get('cluster', None)
  cluster = tf.train.ClusterSpec(cluster_data) if cluster_data else None
  task_data = env.get('task', None) or {'type': 'master', 'index': 0}
  task_info = type('TaskSpec', (object,), task_data)

  # Parameters for a single worker.
  ps_tasks = 0
  worker_replicas = 1
  worker_job_name = 'lonely_worker'
  task = 0
  is_chief = True
  master = ''

  if cluster_data and 'worker' in cluster_data:
    # Number of total worker replicas include "worker"s and the "master".
    worker_replicas = len(cluster_data['worker']) + 1
  if cluster_data and 'ps' in cluster_data:
    ps_tasks = len(cluster_data['ps'])

  if worker_replicas > 1 and ps_tasks < 1:
    raise ValueError('At least 1 ps task is needed for distributed training.')

  if worker_replicas >= 1 and ps_tasks > 0:
    # Set up distributed training.
    server = tf.train.Server(
        tf.train.ClusterSpec(cluster),
        protocol='grpc',
        job_name=task_info.type,
        task_index=task_info.index)
    if task_info.type == 'ps':
      server.join()
      return

    worker_job_name = '%s/task:%d' % (task_info.type, task_info.index)
    task = task_info.index
    is_chief = (task_info.type == 'master')
    master = server.target

  trainer.train(create_input_dict_fn, model_fn, train_config, master, task,
                FLAGS.num_clones, worker_replicas, FLAGS.clone_on_cpu,
                ps_tasks, worker_job_name, is_chief, FLAGS.train_dir)


if __name__ == '__main__':
  tf.app.run()
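
Editor's note: main() above derives its cluster parameters from the TF_CONFIG environment variable. The following hypothetical sketch, not part of the commit, shows the JSON shape it parses; all host:port values are placeholders.

import json
import os

# worker_replicas becomes len(cluster['worker']) + 1 (the master counts as a
# worker), ps_tasks becomes len(cluster['ps']), and 'task' identifies this
# replica. An empty TF_CONFIG falls back to {'type': 'master', 'index': 0}
# and single-worker ("lonely_worker") training.
os.environ['TF_CONFIG'] = json.dumps({
    'cluster': {
        'master': ['host0:2222'],
        'worker': ['host1:2222', 'host2:2222'],
        'ps': ['host3:2222'],
    },
    'task': {'type': 'worker', 'index': 0},
})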
research/lstm_object_detection/trainer.py (new file, mode 100644)
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Detection model trainer.
This file provides a generic training method that can be used to train a
DetectionModel.
"""
import functools
import tensorflow as tf
from google3.pyglib import logging
from google3.third_party.tensorflow_models.object_detection.builders import optimizer_builder
from google3.third_party.tensorflow_models.object_detection.core import standard_fields as fields
from google3.third_party.tensorflow_models.object_detection.utils import ops as util_ops
from google3.third_party.tensorflow_models.object_detection.utils import variables_helper
from deployment import model_deploy

slim = tf.contrib.slim


def create_input_queue(create_tensor_dict_fn):
  """Sets up reader, prefetcher and returns input queue.

  Args:
    create_tensor_dict_fn: function to create tensor dictionary.

  Returns:
    all_dict: A dictionary holding tensors for images, boxes, and targets.
  """
  tensor_dict = create_tensor_dict_fn()
  all_dict = {}

  num_images = len(tensor_dict[fields.InputDataFields.image])
  all_dict['batch'] = tensor_dict['batch']
  del tensor_dict['batch']

  for i in range(num_images):
    suffix = str(i)
    for key, val in tensor_dict.items():
      all_dict[key + suffix] = val[i]

    all_dict[fields.InputDataFields.image + suffix] = tf.to_float(
        tf.expand_dims(all_dict[fields.InputDataFields.image + suffix], 0))

  return all_dict


def get_inputs(input_queue, num_classes, merge_multiple_label_boxes=False):
  """Dequeues batch and constructs inputs to object detection model.

  Args:
    input_queue: BatchQueue object holding enqueued tensor_dicts.
    num_classes: Number of classes.
    merge_multiple_label_boxes: Whether to merge boxes with multiple labels
      or not. Defaults to false. Merged boxes are represented with a single
      box and a k-hot encoding of the multiple labels associated with the
      boxes.

  Returns:
    images: a list of 3-D float tensor of images.
    image_keys: a list of string keys for the images.
    locations: a list of tensors of shape [num_boxes, 4] containing the corners
      of the groundtruth boxes.
    classes: a list of padded one-hot tensors containing target classes.
    masks: a list of 3-D float tensors of shape [num_boxes, image_height,
      image_width] containing instance masks for objects if present in the
      input_queue. Else returns None.
    keypoints: a list of 3-D float tensors of shape [num_boxes, num_keypoints,
      2] containing keypoints for objects if present in the
      input queue. Else returns None.
  """
  read_data_list = input_queue
  label_id_offset = 1

  def extract_images_and_targets(read_data):
    """Extract images and targets from the input dict."""
    suffix = 0

    images = []
    keys = []
    locations = []
    classes = []
    masks = []
    keypoints = []

    while fields.InputDataFields.image + str(suffix) in read_data:
      image = read_data[fields.InputDataFields.image + str(suffix)]
      key = ''
      if fields.InputDataFields.source_id in read_data:
        key = read_data[fields.InputDataFields.source_id + str(suffix)]
      location_gt = (
          read_data[fields.InputDataFields.groundtruth_boxes + str(suffix)])
      classes_gt = tf.cast(
          read_data[fields.InputDataFields.groundtruth_classes + str(suffix)],
          tf.int32)
      classes_gt -= label_id_offset
      masks_gt = read_data.get(
          fields.InputDataFields.groundtruth_instance_masks + str(suffix))
      keypoints_gt = read_data.get(
          fields.InputDataFields.groundtruth_keypoints + str(suffix))

      if merge_multiple_label_boxes:
        location_gt, classes_gt, _ = util_ops.merge_boxes_with_multiple_labels(
            location_gt, classes_gt, num_classes)
      else:
        classes_gt = util_ops.padded_one_hot_encoding(
            indices=classes_gt, depth=num_classes, left_pad=0)

      # Batch read input data and groundtruth. Images and locations, classes by
      # default should have the same number of items.
      images.append(image)
      keys.append(key)
      locations.append(location_gt)
      classes.append(classes_gt)
      masks.append(masks_gt)
      keypoints.append(keypoints_gt)

      suffix += 1

    return (images, keys, locations, classes, masks, keypoints)

  return extract_images_and_targets(read_data_list)


def _create_losses(input_queue, create_model_fn, train_config):
  """Creates loss function for a DetectionModel.

  Args:
    input_queue: BatchQueue object holding enqueued tensor_dicts.
    create_model_fn: A function to create the DetectionModel.
    train_config: a train_pb2.TrainConfig protobuf.
  """
  detection_model = create_model_fn()
  (images, _, groundtruth_boxes_list, groundtruth_classes_list,
   groundtruth_masks_list, groundtruth_keypoints_list) = get_inputs(
       input_queue, detection_model.num_classes,
       train_config.merge_multiple_label_boxes)

  preprocessed_images = []
  true_image_shapes = []
  for image in images:
    resized_image, true_image_shape = detection_model.preprocess(image)
    preprocessed_images.append(resized_image)
    true_image_shapes.append(true_image_shape)

  images = tf.concat(preprocessed_images, 0)
  true_image_shapes = tf.concat(true_image_shapes, 0)

  if any(mask is None for mask in groundtruth_masks_list):
    groundtruth_masks_list = None
  if any(keypoints is None for keypoints in groundtruth_keypoints_list):
    groundtruth_keypoints_list = None

  detection_model.provide_groundtruth(groundtruth_boxes_list,
                                      groundtruth_classes_list,
                                      groundtruth_masks_list,
                                      groundtruth_keypoints_list)
  prediction_dict = detection_model.predict(images, true_image_shapes,
                                            input_queue['batch'])

  losses_dict = detection_model.loss(prediction_dict, true_image_shapes)
  for loss_tensor in losses_dict.values():
    tf.losses.add_loss(loss_tensor)


def get_restore_checkpoint_ops(restore_checkpoints, detection_model,
                               train_config):
  """Restore checkpoint from saved checkpoints.

  Args:
    restore_checkpoints: loaded checkpoints.
    detection_model: Object detection model built from config file.
    train_config: a train_pb2.TrainConfig protobuf.

  Returns:
    restorers: A list of ops to initialize the model from checkpoints.
  """
  restorers = []
  vars_restored = []
  for restore_checkpoint in restore_checkpoints:
    var_map = detection_model.restore_map(
        fine_tune_checkpoint_type=train_config.fine_tune_checkpoint_type)
    available_var_map = (
        variables_helper.get_variables_available_in_checkpoint(
            var_map, restore_checkpoint))
    for var_name, var in available_var_map.iteritems():
      if var in vars_restored:
        logging.info('Variable %s contained in multiple checkpoints',
                     var.op.name)
        del available_var_map[var_name]
      else:
        vars_restored.append(var)

    # Initialize from ExponentialMovingAverages if possible.
    available_ema_var_map = {}
    ckpt_reader = tf.train.NewCheckpointReader(restore_checkpoint)
    ckpt_vars_to_shape_map = ckpt_reader.get_variable_to_shape_map()
    for var_name, var in available_var_map.iteritems():
      var_name_ema = var_name + '/ExponentialMovingAverage'
      if var_name_ema in ckpt_vars_to_shape_map:
        available_ema_var_map[var_name_ema] = var
      else:
        available_ema_var_map[var_name] = var
    available_var_map = available_ema_var_map
    init_saver = tf.train.Saver(available_var_map)
    if available_var_map.keys():
      restorers.append(init_saver)
    else:
      logging.info('WARNING: Checkpoint %s has no restorable variables',
                   restore_checkpoint)

  return restorers


def train(create_tensor_dict_fn,
          create_model_fn,
          train_config,
          master,
          task,
          num_clones,
          worker_replicas,
          clone_on_cpu,
          ps_tasks,
          worker_job_name,
          is_chief,
          train_dir,
          graph_hook_fn=None):
  """Training function for detection models.

  Args:
    create_tensor_dict_fn: a function to create a tensor input dictionary.
    create_model_fn: a function that creates a DetectionModel and generates
      losses.
    train_config: a train_pb2.TrainConfig protobuf.
    master: BNS name of the TensorFlow master to use.
    task: The task id of this training instance.
    num_clones: The number of clones to run per machine.
    worker_replicas: The number of worker replicas to train with.
    clone_on_cpu: True if clones should be forced to run on CPU.
    ps_tasks: Number of parameter server tasks.
    worker_job_name: Name of the worker job.
    is_chief: Whether this replica is the chief replica.
    train_dir: Directory to write checkpoints and training summaries to.
    graph_hook_fn: Optional function that is called after the training graph is
      completely built. This is helpful to perform additional changes to the
      training graph such as optimizing batchnorm. The function should modify
      the default graph.
  """
  detection_model = create_model_fn()

  with tf.Graph().as_default():
    # Build a configuration specifying multi-GPU and multi-replicas.
    deploy_config = model_deploy.DeploymentConfig(
        num_clones=num_clones,
        clone_on_cpu=clone_on_cpu,
        replica_id=task,
        num_replicas=worker_replicas,
        num_ps_tasks=ps_tasks,
        worker_job_name=worker_job_name)

    # Place the global step on the device storing the variables.
    with tf.device(deploy_config.variables_device()):
      global_step = slim.create_global_step()

    with tf.device(deploy_config.inputs_device()):
      input_queue = create_input_queue(create_tensor_dict_fn)

    # Gather initial summaries.
    # TODO(rathodv): See if summaries can be added/extracted from global tf
    # collections so that they don't have to be passed around.
    summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
    global_summaries = set([])

    model_fn = functools.partial(
        _create_losses,
        create_model_fn=create_model_fn,
        train_config=train_config)
    clones = model_deploy.create_clones(deploy_config, model_fn, [input_queue])
    first_clone_scope = clones[0].scope

    # Gather update_ops from the first clone. These contain, for example,
    # the updates for the batch_norm variables created by model_fn.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)

    with tf.device(deploy_config.optimizer_device()):
      training_optimizer, optimizer_summary_vars = optimizer_builder.build(
          train_config.optimizer)
      for var in optimizer_summary_vars:
        tf.summary.scalar(var.op.name, var)

    sync_optimizer = None
    if train_config.sync_replicas:
      training_optimizer = tf.train.SyncReplicasOptimizer(
          training_optimizer,
          replicas_to_aggregate=train_config.replicas_to_aggregate,
          total_num_replicas=train_config.worker_replicas)
      sync_optimizer = training_optimizer

    # Create ops required to initialize the model from a given checkpoint.
    init_fn = None
    if train_config.fine_tune_checkpoint:
      restore_checkpoints = [
          path.strip() for path in train_config.fine_tune_checkpoint.split(',')
      ]
      restorers = get_restore_checkpoint_ops(restore_checkpoints,
                                             detection_model, train_config)

      def initializer_fn(sess):
        for i, restorer in enumerate(restorers):
          restorer.restore(sess, restore_checkpoints[i])

      init_fn = initializer_fn

    with tf.device(deploy_config.optimizer_device()):
      regularization_losses = (
          None if train_config.add_regularization_loss else [])
      total_loss, grads_and_vars = model_deploy.optimize_clones(
          clones, training_optimizer,
          regularization_losses=regularization_losses)
      total_loss = tf.check_numerics(total_loss, 'LossTensor is inf or nan.')

      # Optionally multiply bias gradients by train_config.bias_grad_multiplier.
      if train_config.bias_grad_multiplier:
        biases_regex_list = ['.*/biases']
        grads_and_vars = variables_helper.multiply_gradients_matching_regex(
            grads_and_vars,
            biases_regex_list,
            multiplier=train_config.bias_grad_multiplier)

      # Optionally clip gradients
      if train_config.gradient_clipping_by_norm > 0:
        with tf.name_scope('clip_grads'):
          grads_and_vars = slim.learning.clip_gradient_norms(
              grads_and_vars, train_config.gradient_clipping_by_norm)

      moving_average_variables = slim.get_model_variables()
      variable_averages = tf.train.ExponentialMovingAverage(
          0.9999, global_step)
      update_ops.append(variable_averages.apply(moving_average_variables))

      # Create gradient updates.
      grad_updates = training_optimizer.apply_gradients(
          grads_and_vars, global_step=global_step)
      update_ops.append(grad_updates)
      update_op = tf.group(*update_ops, name='update_barrier')
      with tf.control_dependencies([update_op]):
        train_tensor = tf.identity(total_loss, name='train_op')

    if graph_hook_fn:
      with tf.device(deploy_config.variables_device()):
        graph_hook_fn()

    # Add summaries.
    for model_var in slim.get_model_variables():
      global_summaries.add(
          tf.summary.histogram(model_var.op.name, model_var))
    for loss_tensor in tf.losses.get_losses():
      global_summaries.add(
          tf.summary.scalar(loss_tensor.op.name, loss_tensor))
    global_summaries.add(
        tf.summary.scalar('TotalLoss', tf.losses.get_total_loss()))

    # Add the summaries from the first clone. These contain the summaries
    # created by model_fn and either optimize_clones() or _gather_clone_loss().
    summaries |= set(
        tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))
    summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES, 'critic_loss'))
    summaries |= global_summaries

    # Merge all summaries together.
    summary_op = tf.summary.merge(list(summaries), name='summary_op')

    # Soft placement allows placing on CPU ops without GPU implementation.
    session_config = tf.ConfigProto(
        allow_soft_placement=True, log_device_placement=False)

    # Save checkpoints regularly.
    keep_checkpoint_every_n_hours = train_config.keep_checkpoint_every_n_hours
    saver = tf.train.Saver(
        keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours)

    slim.learning.train(
        train_tensor,
        logdir=train_dir,
        master=master,
        is_chief=is_chief,
        session_config=session_config,
        startup_delay_steps=train_config.startup_delay_steps,
        init_fn=init_fn,
        summary_op=summary_op,
        number_of_steps=(
            train_config.num_steps if train_config.num_steps else None),
        save_summaries_secs=120,
        sync_optimizer=sync_optimizer,
        saver=saver)
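
Editor's note: a single-machine sketch, not part of the commit, of calling trainer.train() with the "lonely worker" defaults that train.py computes when TF_CONFIG is empty; create_input_dict_fn, model_fn, and train_config are assumed to be built as in train.py, and the train_dir path is a placeholder.

trainer.train(
    create_tensor_dict_fn=create_input_dict_fn,  # seq_dataset_builder via functools.partial
    create_model_fn=model_fn,                    # model_builder via functools.partial
    train_config=train_config,                   # configs['train_config']
    master='',
    task=0,
    num_clones=1,
    worker_replicas=1,
    clone_on_cpu=False,
    ps_tasks=0,
    worker_job_name='lonely_worker',
    is_chief=True,
    train_dir='/tmp/train_dir')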
research/lstm_object_detection/utils/config_util.py (new file, mode 100644)
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Added functionality to load from pipeline config for lstm framework."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
from google.protobuf import text_format
from lstm_object_detection.protos import input_reader_google_pb2  # pylint: disable=unused-import
from lstm_object_detection.protos import pipeline_pb2 as internal_pipeline_pb2
from google3.third_party.tensorflow_models.object_detection.protos import pipeline_pb2
from google3.third_party.tensorflow_models.object_detection.utils import config_util


def get_configs_from_pipeline_file(pipeline_config_path):
  """Reads configuration from a pipeline_pb2.TrainEvalPipelineConfig.

  Args:
    pipeline_config_path: Path to pipeline_pb2.TrainEvalPipelineConfig text
      proto.

  Returns:
    Dictionary of configuration objects. Keys are `model`, `train_config`,
    `train_input_config`, `eval_config`, `eval_input_config`, `lstm_model`.
    Values are the corresponding config objects.
  """
  pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
  with tf.gfile.GFile(pipeline_config_path, "r") as f:
    proto_str = f.read()
    text_format.Merge(proto_str, pipeline_config)
  configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
  if pipeline_config.HasExtension(internal_pipeline_pb2.lstm_model):
    configs["lstm_model"] = pipeline_config.Extensions[
        internal_pipeline_pb2.lstm_model]
  return configs


def create_pipeline_proto_from_configs(configs):
  """Creates a pipeline_pb2.TrainEvalPipelineConfig from configs dictionary.

  This function nearly performs the inverse operation of
  get_configs_from_pipeline_file(). Instead of returning a file path, it
  returns a `TrainEvalPipelineConfig` object.

  Args:
    configs: Dictionary of configs. See get_configs_from_pipeline_file().

  Returns:
    A fully populated pipeline_pb2.TrainEvalPipelineConfig.
  """
  pipeline_config = config_util.create_pipeline_proto_from_configs(configs)
  if "lstm_model" in configs:
    pipeline_config.Extensions[internal_pipeline_pb2.lstm_model].CopyFrom(
        configs["lstm_model"])
  return pipeline_config


def get_configs_from_multiple_files(model_config_path="",
                                    train_config_path="",
                                    train_input_config_path="",
                                    eval_config_path="",
                                    eval_input_config_path="",
                                    lstm_config_path=""):
  """Reads training configuration from multiple config files.

  Args:
    model_config_path: Path to model_pb2.DetectionModel.
    train_config_path: Path to train_pb2.TrainConfig.
    train_input_config_path: Path to input_reader_pb2.InputReader.
    eval_config_path: Path to eval_pb2.EvalConfig.
    eval_input_config_path: Path to input_reader_pb2.InputReader.
    lstm_config_path: Path to pipeline_pb2.LstmModel.

  Returns:
    Dictionary of configuration objects. Keys are `model`, `train_config`,
    `train_input_config`, `eval_config`, `eval_input_config`, `lstm_model`.
    Key/Values are returned only for valid (non-empty) strings.
  """
  configs = config_util.get_configs_from_multiple_files(
      model_config_path=model_config_path,
      train_config_path=train_config_path,
      train_input_config_path=train_input_config_path,
      eval_config_path=eval_config_path,
      eval_input_config_path=eval_input_config_path)
  if lstm_config_path:
    lstm_config = internal_pipeline_pb2.LstmModel()
    with tf.gfile.GFile(lstm_config_path, "r") as f:
      text_format.Merge(f.read(), lstm_config)
      configs["lstm_model"] = lstm_config
  return configs
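
Editor's note: a short usage sketch, not part of the commit, assuming a pipeline .config text proto on disk that sets the [object_detection.protos.lstm_model] extension, as exercised in config_util_test.py below; the config path is a placeholder.

from lstm_object_detection.utils import config_util

configs = config_util.get_configs_from_pipeline_file('path/to/pipeline.config')
model_config = configs['model']
# 'lstm_model' is present only when the pipeline proto carries the extension.
lstm_config = configs['lstm_model']
unroll_length = lstm_config.train_unroll_length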
research/lstm_object_detection/utils/config_util_test.py (new file, mode 100644)
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for object_detection.utils.config_util."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import tensorflow as tf
from google.protobuf import text_format
from lstm_object_detection.protos import pipeline_pb2 as internal_pipeline_pb2
from lstm_object_detection.utils import config_util
from google3.third_party.tensorflow_models.object_detection.protos import pipeline_pb2


def _write_config(config, config_path):
  """Writes a config object to disk."""
  config_text = text_format.MessageToString(config)
  with tf.gfile.Open(config_path, "wb") as f:
    f.write(config_text)


class ConfigUtilTest(tf.test.TestCase):

  def test_get_configs_from_pipeline_file(self):
    """Test that proto configs can be read from pipeline config file."""
    pipeline_config_path = os.path.join(self.get_temp_dir(), "pipeline.config")

    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    pipeline_config.model.ssd.num_classes = 10
    pipeline_config.train_config.batch_size = 32
    pipeline_config.train_input_reader.label_map_path = "path/to/label_map"
    pipeline_config.eval_config.num_examples = 20
    pipeline_config.eval_input_reader.queue_capacity = 100
    pipeline_config.Extensions[
        internal_pipeline_pb2.lstm_model].train_unroll_length = 5
    pipeline_config.Extensions[
        internal_pipeline_pb2.lstm_model].eval_unroll_length = 10
    _write_config(pipeline_config, pipeline_config_path)

    configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
    self.assertProtoEquals(pipeline_config.model, configs["model"])
    self.assertProtoEquals(pipeline_config.train_config,
                           configs["train_config"])
    self.assertProtoEquals(pipeline_config.train_input_reader,
                           configs["train_input_config"])
    self.assertProtoEquals(pipeline_config.eval_config, configs["eval_config"])
    self.assertProtoEquals(pipeline_config.eval_input_reader,
                           configs["eval_input_config"])
    self.assertProtoEquals(
        pipeline_config.Extensions[internal_pipeline_pb2.lstm_model],
        configs["lstm_model"])

  def test_create_pipeline_proto_from_configs(self):
    """Tests that proto can be reconstructed from configs dictionary."""
    pipeline_config_path = os.path.join(self.get_temp_dir(), "pipeline.config")

    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    pipeline_config.model.ssd.num_classes = 10
    pipeline_config.train_config.batch_size = 32
    pipeline_config.train_input_reader.label_map_path = "path/to/label_map"
    pipeline_config.eval_config.num_examples = 20
    pipeline_config.eval_input_reader.queue_capacity = 100
    pipeline_config.Extensions[
        internal_pipeline_pb2.lstm_model].train_unroll_length = 5
    pipeline_config.Extensions[
        internal_pipeline_pb2.lstm_model].eval_unroll_length = 10
    _write_config(pipeline_config, pipeline_config_path)

    configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
    pipeline_config_reconstructed = (
        config_util.create_pipeline_proto_from_configs(configs))
    self.assertEqual(pipeline_config, pipeline_config_reconstructed)


if __name__ == "__main__":
  tf.test.main()