Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
cc08dc87
Commit
cc08dc87
authored
Apr 12, 2018
by
Zhichao Lu
Committed by
pkulzc
Apr 13, 2018
Browse files
Update to trainer to allow for reading multiclass scores
PiperOrigin-RevId: 192624207
parent
1bddd18e
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
82 additions
and
6 deletions
+82
-6
research/object_detection/trainer.py
research/object_detection/trainer.py
+34
-5
research/object_detection/trainer_test.py
research/object_detection/trainer_test.py
+48
-1
No files found.
research/object_detection/trainer.py
View file @
cc08dc87
...
...
@@ -69,10 +69,13 @@ def create_input_queue(batch_size_per_clone, create_tensor_dict_fn,
in
tensor_dict
)
include_keypoints
=
(
fields
.
InputDataFields
.
groundtruth_keypoints
in
tensor_dict
)
include_multiclass_scores
=
(
fields
.
InputDataFields
.
multiclass_scores
in
tensor_dict
)
if
data_augmentation_options
:
tensor_dict
=
preprocessor
.
preprocess
(
tensor_dict
,
data_augmentation_options
,
func_arg_map
=
preprocessor
.
get_default_func_arg_map
(
include_multiclass_scores
=
include_multiclass_scores
,
include_instance_masks
=
include_instance_masks
,
include_keypoints
=
include_keypoints
))
...
...
@@ -85,7 +88,10 @@ def create_input_queue(batch_size_per_clone, create_tensor_dict_fn,
return
input_queue
def
get_inputs
(
input_queue
,
num_classes
,
merge_multiple_label_boxes
=
False
):
def
get_inputs
(
input_queue
,
num_classes
,
merge_multiple_label_boxes
=
False
,
use_multiclass_scores
=
False
):
"""Dequeues batch and constructs inputs to object detection model.
Args:
...
...
@@ -95,6 +101,8 @@ def get_inputs(input_queue, num_classes, merge_multiple_label_boxes=False):
or not. Defaults to false. Merged boxes are represented with a single
box and a k-hot encoding of the multiple labels associated with the
boxes.
use_multiclass_scores: Whether to use multiclass scores instead of
groundtruth_classes.
Returns:
images: a list of 3-D float tensor of images.
...
...
@@ -123,9 +131,19 @@ def get_inputs(input_queue, num_classes, merge_multiple_label_boxes=False):
classes_gt
=
tf
.
cast
(
read_data
[
fields
.
InputDataFields
.
groundtruth_classes
],
tf
.
int32
)
classes_gt
-=
label_id_offset
if
merge_multiple_label_boxes
and
use_multiclass_scores
:
raise
ValueError
(
'Using both merge_multiple_label_boxes and use_multiclass_scores is'
'not supported'
)
if
merge_multiple_label_boxes
:
location_gt
,
classes_gt
,
_
=
util_ops
.
merge_boxes_with_multiple_labels
(
location_gt
,
classes_gt
,
num_classes
)
elif
use_multiclass_scores
:
classes_gt
=
tf
.
cast
(
read_data
[
fields
.
InputDataFields
.
multiclass_scores
],
tf
.
float32
)
else
:
classes_gt
=
util_ops
.
padded_one_hot_encoding
(
indices
=
classes_gt
,
depth
=
num_classes
,
left_pad
=
0
)
...
...
@@ -155,7 +173,8 @@ def _create_losses(input_queue, create_model_fn, train_config):
groundtruth_masks_list
,
groundtruth_keypoints_list
,
_
)
=
get_inputs
(
input_queue
,
detection_model
.
num_classes
,
train_config
.
merge_multiple_label_boxes
)
train_config
.
merge_multiple_label_boxes
,
train_config
.
use_multiclass_scores
)
preprocessed_images
=
[]
true_image_shapes
=
[]
...
...
@@ -183,9 +202,19 @@ def _create_losses(input_queue, create_model_fn, train_config):
tf
.
losses
.
add_loss
(
loss_tensor
)
def
train
(
create_tensor_dict_fn
,
create_model_fn
,
train_config
,
master
,
task
,
num_clones
,
worker_replicas
,
clone_on_cpu
,
ps_tasks
,
worker_job_name
,
is_chief
,
train_dir
,
graph_hook_fn
=
None
):
def
train
(
create_tensor_dict_fn
,
create_model_fn
,
train_config
,
master
,
task
,
num_clones
,
worker_replicas
,
clone_on_cpu
,
ps_tasks
,
worker_job_name
,
is_chief
,
train_dir
,
graph_hook_fn
=
None
):
"""Training function for detection models.
Args:
...
...
research/object_detection/trainer_test.py
View file @
cc08dc87
...
...
@@ -37,12 +37,15 @@ def get_input_function():
[
1
],
minval
=
0
,
maxval
=
NUMBER_OF_CLASSES
,
dtype
=
tf
.
int32
)
box_label
=
tf
.
random_uniform
(
[
1
,
4
],
minval
=
0.4
,
maxval
=
0.6
,
dtype
=
tf
.
float32
)
multiclass_scores
=
tf
.
random_uniform
(
[
1
,
NUMBER_OF_CLASSES
],
minval
=
0.4
,
maxval
=
0.6
,
dtype
=
tf
.
float32
)
return
{
fields
.
InputDataFields
.
image
:
image
,
fields
.
InputDataFields
.
key
:
key
,
fields
.
InputDataFields
.
groundtruth_classes
:
class_label
,
fields
.
InputDataFields
.
groundtruth_boxes
:
box_label
fields
.
InputDataFields
.
groundtruth_boxes
:
box_label
,
fields
.
InputDataFields
.
multiclass_scores
:
multiclass_scores
}
...
...
@@ -203,6 +206,50 @@ class TrainerTest(tf.test.TestCase):
train_dir
=
self
.
get_temp_dir
()
trainer
.
train
(
create_tensor_dict_fn
=
get_input_function
,
create_model_fn
=
FakeDetectionModel
,
train_config
=
train_config
,
master
=
''
,
task
=
0
,
num_clones
=
1
,
worker_replicas
=
1
,
clone_on_cpu
=
True
,
ps_tasks
=
0
,
worker_job_name
=
'worker'
,
is_chief
=
True
,
train_dir
=
train_dir
)
def
test_configure_trainer_with_multiclass_scores_and_train_two_steps
(
self
):
train_config_text_proto
=
"""
optimizer {
adam_optimizer {
learning_rate {
constant_learning_rate {
learning_rate: 0.01
}
}
}
}
data_augmentation_options {
random_adjust_brightness {
max_delta: 0.2
}
}
data_augmentation_options {
random_adjust_contrast {
min_delta: 0.7
max_delta: 1.1
}
}
num_steps: 2
use_multiclass_scores: true
"""
train_config
=
train_pb2
.
TrainConfig
()
text_format
.
Merge
(
train_config_text_proto
,
train_config
)
train_dir
=
self
.
get_temp_dir
()
trainer
.
train
(
create_tensor_dict_fn
=
get_input_function
,
create_model_fn
=
FakeDetectionModel
,
train_config
=
train_config
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment