Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
a3c7d7e8
Unverified
Commit
a3c7d7e8
authored
Oct 27, 2017
by
vivek rathod
Committed by
GitHub
Oct 27, 2017
Browse files
Merge pull request #2621 from tombstone/trainer
Use config parsing routines from utils/config_utils.py in train.py
parents
9cfd2d93
3616ab28
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
84 additions
and
84 deletions
+84
-84
research/object_detection/BUILD
research/object_detection/BUILD
+2
-5
research/object_detection/train.py
research/object_detection/train.py
+22
-57
research/object_detection/trainer.py
research/object_detection/trainer.py
+58
-22
research/object_detection/trainer_test.py
research/object_detection/trainer_test.py
+2
-0
No files found.
research/object_detection/BUILD
View file @
a3c7d7e8
...
...
@@ -7,7 +7,6 @@ package(
licenses
([
"notice"
])
# Apache 2.0
py_binary
(
name
=
"train"
,
srcs
=
[
...
...
@@ -18,10 +17,7 @@ py_binary(
"//tensorflow"
,
"//tensorflow_models/object_detection/builders:input_reader_builder"
,
"//tensorflow_models/object_detection/builders:model_builder"
,
"//tensorflow_models/object_detection/protos:input_reader_py_pb2"
,
"//tensorflow_models/object_detection/protos:model_py_pb2"
,
"//tensorflow_models/object_detection/protos:pipeline_py_pb2"
,
"//tensorflow_models/object_detection/protos:train_py_pb2"
,
"//tensorflow_models/object_detection/utils:config_util"
,
],
)
...
...
@@ -33,6 +29,7 @@ py_library(
"//tensorflow_models/object_detection/builders:optimizer_builder"
,
"//tensorflow_models/object_detection/builders:preprocessor_builder"
,
"//tensorflow_models/object_detection/core:batcher"
,
"//tensorflow_models/object_detection/core:preprocessor"
,
"//tensorflow_models/object_detection/core:standard_fields"
,
"//tensorflow_models/object_detection/utils:ops"
,
"//tensorflow_models/object_detection/utils:variables_helper"
,
...
...
research/object_detection/train.py
View file @
a3c7d7e8
...
...
@@ -46,15 +46,10 @@ import json
import
os
import
tensorflow
as
tf
from
google.protobuf
import
text_format
from
object_detection
import
trainer
from
object_detection.builders
import
input_reader_builder
from
object_detection.builders
import
model_builder
from
object_detection.protos
import
input_reader_pb2
from
object_detection.protos
import
model_pb2
from
object_detection.protos
import
pipeline_pb2
from
object_detection.protos
import
train_pb2
from
object_detection.utils
import
config_util
tf
.
logging
.
set_verbosity
(
tf
.
logging
.
INFO
)
...
...
@@ -88,61 +83,31 @@ flags.DEFINE_string('model_config_path', '',
FLAGS
=
flags
.
FLAGS
def
get_configs_from_pipeline_file
():
"""Reads training configuration from a pipeline_pb2.TrainEvalPipelineConfig.
Reads training config from file specified by pipeline_config_path flag.
Returns:
model_config: model_pb2.DetectionModel
train_config: train_pb2.TrainConfig
input_config: input_reader_pb2.InputReader
"""
pipeline_config
=
pipeline_pb2
.
TrainEvalPipelineConfig
()
with
tf
.
gfile
.
GFile
(
FLAGS
.
pipeline_config_path
,
'r'
)
as
f
:
text_format
.
Merge
(
f
.
read
(),
pipeline_config
)
model_config
=
pipeline_config
.
model
train_config
=
pipeline_config
.
train_config
input_config
=
pipeline_config
.
train_input_reader
return
model_config
,
train_config
,
input_config
def
get_configs_from_multiple_files
():
"""Reads training configuration from multiple config files.
Reads the training config from the following files:
model_config: Read from --model_config_path
train_config: Read from --train_config_path
input_config: Read from --input_config_path
Returns:
model_config: model_pb2.DetectionModel
train_config: train_pb2.TrainConfig
input_config: input_reader_pb2.InputReader
"""
train_config
=
train_pb2
.
TrainConfig
()
with
tf
.
gfile
.
GFile
(
FLAGS
.
train_config_path
,
'r'
)
as
f
:
text_format
.
Merge
(
f
.
read
(),
train_config
)
model_config
=
model_pb2
.
DetectionModel
()
with
tf
.
gfile
.
GFile
(
FLAGS
.
model_config_path
,
'r'
)
as
f
:
text_format
.
Merge
(
f
.
read
(),
model_config
)
input_config
=
input_reader_pb2
.
InputReader
()
with
tf
.
gfile
.
GFile
(
FLAGS
.
input_config_path
,
'r'
)
as
f
:
text_format
.
Merge
(
f
.
read
(),
input_config
)
return
model_config
,
train_config
,
input_config
def
main
(
_
):
assert
FLAGS
.
train_dir
,
'`train_dir` is missing.'
if
FLAGS
.
task
==
0
:
tf
.
gfile
.
MakeDirs
(
FLAGS
.
train_dir
)
if
FLAGS
.
pipeline_config_path
:
model_config
,
train_config
,
input_config
=
get_configs_from_pipeline_file
()
configs
=
config_util
.
get_configs_from_pipeline_file
(
FLAGS
.
pipeline_config_path
)
if
FLAGS
.
task
==
0
:
tf
.
gfile
.
Copy
(
FLAGS
.
pipeline_config_path
,
os
.
path
.
join
(
FLAGS
.
train_dir
,
'pipeline.config'
),
overwrite
=
True
)
else
:
model_config
,
train_config
,
input_config
=
get_configs_from_multiple_files
()
configs
=
config_util
.
get_configs_from_multiple_files
(
model_config_path
=
FLAGS
.
model_config_path
,
train_config_path
=
FLAGS
.
train_config_path
,
train_input_config_path
=
FLAGS
.
input_config_path
)
if
FLAGS
.
task
==
0
:
for
name
,
config
in
[(
'model.config'
,
FLAGS
.
model_config_path
),
(
'train.config'
,
FLAGS
.
train_config_path
),
(
'input.config'
,
FLAGS
.
input_config_path
)]:
tf
.
gfile
.
Copy
(
config
,
os
.
path
.
join
(
FLAGS
.
train_dir
,
name
),
overwrite
=
True
)
model_config
=
configs
[
'model'
]
train_config
=
configs
[
'train_config'
]
input_config
=
configs
[
'train_input_config'
]
model_fn
=
functools
.
partial
(
model_builder
.
build
,
...
...
research/object_detection/trainer.py
View file @
a3c7d7e8
...
...
@@ -35,9 +35,9 @@ from deployment import model_deploy
slim
=
tf
.
contrib
.
slim
def
_
create_input_queue
(
batch_size_per_clone
,
create_tensor_dict_fn
,
batch_queue_capacity
,
num_batch_queue_threads
,
prefetch_queue_capacity
,
data_augmentation_options
):
def
create_input_queue
(
batch_size_per_clone
,
create_tensor_dict_fn
,
batch_queue_capacity
,
num_batch_queue_threads
,
prefetch_queue_capacity
,
data_augmentation_options
):
"""Sets up reader, prefetcher and returns input queue.
Args:
...
...
@@ -65,9 +65,16 @@ def _create_input_queue(batch_size_per_clone, create_tensor_dict_fn,
float_images
=
tf
.
to_float
(
images
)
tensor_dict
[
fields
.
InputDataFields
.
image
]
=
float_images
include_instance_masks
=
(
fields
.
InputDataFields
.
groundtruth_instance_masks
in
tensor_dict
)
include_keypoints
=
(
fields
.
InputDataFields
.
groundtruth_keypoints
in
tensor_dict
)
if
data_augmentation_options
:
tensor_dict
=
preprocessor
.
preprocess
(
tensor_dict
,
data_augmentation_options
)
tensor_dict
=
preprocessor
.
preprocess
(
tensor_dict
,
data_augmentation_options
,
func_arg_map
=
preprocessor
.
get_default_func_arg_map
(
include_instance_masks
=
include_instance_masks
,
include_keypoints
=
include_keypoints
))
input_queue
=
batcher
.
BatchQueue
(
tensor_dict
,
...
...
@@ -78,56 +85,83 @@ def _create_input_queue(batch_size_per_clone, create_tensor_dict_fn,
return
input_queue
def
_
get_inputs
(
input_queue
,
num_classes
):
"""Dequeue batch and construct inputs to object detection model.
def
get_inputs
(
input_queue
,
num_classes
,
merge_multiple_label_boxes
=
False
):
"""Dequeue
s
batch and construct
s
inputs to object detection model.
Args:
input_queue: BatchQueue object holding enqueued tensor_dicts.
num_classes: Number of classes.
merge_multiple_label_boxes: Whether to merge boxes with multiple labels
or not. Defaults to false. Merged boxes are represented with a single
box and a k-hot encoding of the multiple labels associated with the
boxes.
Returns:
images: a list of 3-D float tensor of images.
image_keys: a list of string keys for the images.
locations_list: a list of tensors of shape [num_boxes, 4]
containing the corners of the groundtruth boxes.
classes_list: a list of padded one-hot tensors containing target classes.
masks_list: a list of 3-D float tensors of shape [num_boxes, image_height,
image_width] containing instance masks for objects if present in the
input_queue. Else returns None.
keypoints_list: a list of 3-D float tensors of shape [num_boxes,
num_keypoints, 2] containing keypoints for objects if present in the
input queue. Else returns None.
"""
read_data_list
=
input_queue
.
dequeue
()
label_id_offset
=
1
def
extract_images_and_targets
(
read_data
):
"""Extract images and targets from the input dict."""
image
=
read_data
[
fields
.
InputDataFields
.
image
]
key
=
''
if
fields
.
InputDataFields
.
source_id
in
read_data
:
key
=
read_data
[
fields
.
InputDataFields
.
source_id
]
location_gt
=
read_data
[
fields
.
InputDataFields
.
groundtruth_boxes
]
classes_gt
=
tf
.
cast
(
read_data
[
fields
.
InputDataFields
.
groundtruth_classes
],
tf
.
int32
)
classes_gt
-=
label_id_offset
classes_gt
=
util_ops
.
padded_one_hot_encoding
(
indices
=
classes_gt
,
depth
=
num_classes
,
left_pad
=
0
)
if
merge_multiple_label_boxes
:
location_gt
,
classes_gt
,
_
=
util_ops
.
merge_boxes_with_multiple_labels
(
location_gt
,
classes_gt
,
num_classes
)
else
:
classes_gt
=
util_ops
.
padded_one_hot_encoding
(
indices
=
classes_gt
,
depth
=
num_classes
,
left_pad
=
0
)
masks_gt
=
read_data
.
get
(
fields
.
InputDataFields
.
groundtruth_instance_masks
)
return
image
,
location_gt
,
classes_gt
,
masks_gt
keypoints_gt
=
read_data
.
get
(
fields
.
InputDataFields
.
groundtruth_keypoints
)
if
(
merge_multiple_label_boxes
and
(
masks_gt
is
not
None
or
keypoints_gt
is
not
None
)):
raise
NotImplementedError
(
'Multi-label support is only for boxes.'
)
return
image
,
key
,
location_gt
,
classes_gt
,
masks_gt
,
keypoints_gt
return
zip
(
*
map
(
extract_images_and_targets
,
read_data_list
))
def
_create_losses
(
input_queue
,
create_model_fn
):
def
_create_losses
(
input_queue
,
create_model_fn
,
train_config
):
"""Creates loss function for a DetectionModel.
Args:
input_queue: BatchQueue object holding enqueued tensor_dicts.
create_model_fn: A function to create the DetectionModel.
train_config: a train_pb2.TrainConfig protobuf.
"""
detection_model
=
create_model_fn
()
(
images
,
groundtruth_boxes_list
,
groundtruth_classes_list
,
groundtruth_masks_list
)
=
_get_inputs
(
input_queue
,
detection_model
.
num_classes
)
(
images
,
_
,
groundtruth_boxes_list
,
groundtruth_classes_list
,
groundtruth_masks_list
,
groundtruth_keypoints_list
)
=
get_inputs
(
input_queue
,
detection_model
.
num_classes
,
train_config
.
merge_multiple_label_boxes
)
images
=
[
detection_model
.
preprocess
(
image
)
for
image
in
images
]
images
=
tf
.
concat
(
images
,
0
)
if
any
(
mask
is
None
for
mask
in
groundtruth_masks_list
):
groundtruth_masks_list
=
None
if
any
(
keypoints
is
None
for
keypoints
in
groundtruth_keypoints_list
):
groundtruth_keypoints_list
=
None
detection_model
.
provide_groundtruth
(
groundtruth_boxes_list
,
groundtruth_classes_list
,
groundtruth_masks_list
)
groundtruth_masks_list
,
groundtruth_keypoints_list
)
prediction_dict
=
detection_model
.
predict
(
images
)
losses_dict
=
detection_model
.
loss
(
prediction_dict
)
...
...
@@ -176,19 +210,21 @@ def train(create_tensor_dict_fn, create_model_fn, train_config, master, task,
global_step
=
slim
.
create_global_step
()
with
tf
.
device
(
deploy_config
.
inputs_device
()):
input_queue
=
_create_input_queue
(
train_config
.
batch_size
//
num_clones
,
create_tensor_dict_fn
,
train_config
.
batch_queue_capacity
,
train_config
.
num_batch_queue_threads
,
train_config
.
prefetch_queue_capacity
,
data_augmentation_options
)
input_queue
=
create_input_queue
(
train_config
.
batch_size
//
num_clones
,
create_tensor_dict_fn
,
train_config
.
batch_queue_capacity
,
train_config
.
num_batch_queue_threads
,
train_config
.
prefetch_queue_capacity
,
data_augmentation_options
)
# Gather initial summaries.
# TODO(rathodv): See if summaries can be added/extracted from global tf
# collections so that they don't have to be passed around.
summaries
=
set
(
tf
.
get_collection
(
tf
.
GraphKeys
.
SUMMARIES
))
global_summaries
=
set
([])
model_fn
=
functools
.
partial
(
_create_losses
,
create_model_fn
=
create_model_fn
)
create_model_fn
=
create_model_fn
,
train_config
=
train_config
)
clones
=
model_deploy
.
create_clones
(
deploy_config
,
model_fn
,
[
input_queue
])
first_clone_scope
=
clones
[
0
].
scope
...
...
research/object_detection/trainer_test.py
View file @
a3c7d7e8
...
...
@@ -32,6 +32,7 @@ NUMBER_OF_CLASSES = 2
def
get_input_function
():
"""A function to get test inputs. Returns an image with one box."""
image
=
tf
.
random_uniform
([
32
,
32
,
3
],
dtype
=
tf
.
float32
)
key
=
tf
.
constant
(
'image_000000'
)
class_label
=
tf
.
random_uniform
(
[
1
],
minval
=
0
,
maxval
=
NUMBER_OF_CLASSES
,
dtype
=
tf
.
int32
)
box_label
=
tf
.
random_uniform
(
...
...
@@ -39,6 +40,7 @@ def get_input_function():
return
{
fields
.
InputDataFields
.
image
:
image
,
fields
.
InputDataFields
.
key
:
key
,
fields
.
InputDataFields
.
groundtruth_classes
:
class_label
,
fields
.
InputDataFields
.
groundtruth_boxes
:
box_label
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment