Project: ModelZoo / ResNet50_tensorflow

Commit 27b4acd4 authored Sep 25, 2018 by Aman Gupta

    Merge remote-tracking branch 'upstream/master'

Parents: 5133522f, d4e1f97f

Showing 20 changed files with 1559 additions and 254 deletions (+1559, -254)
research/object_detection/protos/eval.proto                                    +3    -2
research/object_detection/protos/faster_rcnn.proto                             +7    -0
research/object_detection/protos/input_reader.proto                            +11   -1
research/object_detection/protos/pipeline.proto                                +2    -1
research/object_detection/protos/post_processing.proto                         +3    -0
research/object_detection/protos/ssd.proto                                     +3    -0
research/object_detection/protos/train.proto                                   +4    -1
research/object_detection/samples/configs/faster_rcnn_resnet101_fgvc.config    +135  -0
research/object_detection/samples/configs/faster_rcnn_resnet50_fgvc.config     +135  -0
research/object_detection/utils/config_util.py                                 +328  -79
research/object_detection/utils/config_util_test.py                            +269  -35
research/object_detection/utils/label_map_util.py                              +38   -13
research/object_detection/utils/label_map_util_test.py                         +66   -3
research/object_detection/utils/ops.py                                         +136  -71
research/object_detection/utils/ops_test.py                                    +126  -39
research/object_detection/utils/shape_utils.py                                 +23   -0
research/object_detection/utils/test_case.py                                   +3    -0
research/object_detection/utils/test_utils.py                                  +35   -9
research/object_detection/utils/visualization_utils.py                         +157  -0
research/object_detection/utils/visualization_utils_test.py                    +75   -0
research/object_detection/protos/eval.proto
@@ -4,17 +4,18 @@ package object_detection.protos;
 // Message for configuring DetectionModel evaluation jobs (eval.py).
 message EvalConfig {
+  optional uint32 batch_size = 25 [default = 1];
   // Number of visualization images to generate.
   optional uint32 num_visualizations = 1 [default = 10];

   // Number of examples to process of evaluation.
-  optional uint32 num_examples = 2 [default = 5000];
+  optional uint32 num_examples = 2 [default = 5000, deprecated = true];

   // How often to run evaluation.
   optional uint32 eval_interval_secs = 3 [default = 300];

   // Maximum number of times to run evaluation. If set to 0, will run forever.
-  optional uint32 max_evals = 4 [default = 0];
+  optional uint32 max_evals = 4 [default = 0, deprecated = true];

   // Whether the TensorFlow graph used for evaluation should be saved to disk.
   optional bool save_graph = 5 [default = false];
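For reference, a minimal sketch of how the new field is exercised from Python, assuming the generated eval_pb2 module from object_detection/protos is importable; the field numbers and defaults are taken from the hunk above:

    from object_detection.protos import eval_pb2

    eval_config = eval_pb2.EvalConfig()
    eval_config.batch_size = 8          # new field (id 25), defaults to 1
    eval_config.num_visualizations = 20
    # num_examples and max_evals still parse, but are now flagged deprecated.
    eval_config.num_examples = 5000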
research/object_detection/protos/faster_rcnn.proto
@@ -157,6 +157,13 @@ message FasterRcnn {
   // Whether to use the balanced positive negative sampler implementation with
   // static shape guarantees.
   optional bool use_static_balanced_label_sampler = 34 [default = false];
+
+  // If True, uses implementation of ops with static shape guarantees.
+  optional bool use_static_shapes = 35 [default = false];
+
+  // Whether the masks present in groundtruth should be resized in the model to
+  // match the image size.
+  optional bool resize_masks = 36 [default = true];
 }
research/object_detection/protos/input_reader.proto
@@ -22,7 +22,12 @@ enum InstanceMaskType {
   PNG_MASKS = 2;  // Encoded PNG masks.
 }

+// Next id: 24
 message InputReader {
+  // Name of input reader. Typically used to describe the dataset that is read
+  // by this input reader.
+  optional string name = 23 [default = ""];
+
   // Path to StringIntLabelMap pbtxt file specifying the mapping from string
   // labels to integer ids.
   optional string label_map_path = 1 [default = ""];
@@ -41,6 +46,12 @@ message InputReader {
   // will be reused indefinitely.
   optional uint32 num_epochs = 5 [default = 0];

+  // Integer representing how often an example should be sampled. To feed
+  // only 1/3 of your data into your model, set `sample_1_of_n_examples` to 3.
+  // This is particularly useful for evaluation, where you might not prefer to
+  // evaluate all of your samples.
+  optional uint32 sample_1_of_n_examples = 22 [default = 1];
+
   // Number of file shards to read in parallel.
   optional uint32 num_readers = 6 [default = 64];
@@ -62,7 +73,6 @@ message InputReader {
   // to generate a good random shuffle.
   optional uint32 min_after_dequeue = 4 [default = 1000, deprecated = true];

   // Number of records to read from each reader at once.
   optional uint32 read_block_length = 15 [default = 32];
research/object_detection/protos/pipeline.proto
@@ -10,12 +10,13 @@ import "object_detection/protos/train.proto";
 // Convenience message for configuring a training and eval pipeline. Allows all
 // of the pipeline parameters to be configured from one file.
+// Next id: 7
 message TrainEvalPipelineConfig {
   optional DetectionModel model = 1;
   optional TrainConfig train_config = 2;
   optional InputReader train_input_reader = 3;
   optional EvalConfig eval_config = 4;
-  optional InputReader eval_input_reader = 5;
+  repeated InputReader eval_input_reader = 5;
   optional GraphRewriter graph_rewriter = 6;
   extensions 1000 to max;
 }
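Since eval_input_reader is now a repeated field, a pipeline can carry several evaluation input readers. A minimal sketch, assuming the generated pipeline_pb2 module and the new InputReader.name field from this commit (the reader names and label map paths are made up for illustration):

    from object_detection.protos import pipeline_pb2

    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    # Repeated field: evaluation readers are appended with add().
    coco_eval = pipeline_config.eval_input_reader.add()
    coco_eval.name = "eval_coco"
    coco_eval.label_map_path = "PATH_TO_BE_CONFIGURED/coco_label_map.pbtxt"
    voc_eval = pipeline_config.eval_input_reader.add()
    voc_eval.name = "eval_voc"
    voc_eval.label_map_path = "PATH_TO_BE_CONFIGURED/voc_label_map.pbtxt"
    assert len(pipeline_config.eval_input_reader) == 2

Existing single-reader configs keep working, because a lone eval_input_reader block in pbtxt still parses into a one-element repeated field.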
research/object_detection/protos/post_processing.proto
@@ -17,6 +17,9 @@ message BatchNonMaxSuppression {
   // Maximum number of detections to retain across all classes.
   optional int32 max_total_detections = 5 [default = 100];
+
+  // Whether to use the implementation of NMS that guarantees static shapes.
+  optional bool use_static_shapes = 6 [default = false];
 }

 // Configuration proto for post-processing predicted boxes and
research/object_detection/protos/ssd.proto
@@ -163,5 +163,8 @@ message FeaturePyramidNetworks {
   // maximum level in feature pyramid
   optional int32 max_level = 2 [default = 7];
+
+  // channel depth for additional coarse feature layers.
+  optional int32 additional_layer_depth = 3 [default = 256];
 }
research/object_detection/protos/train.proto
@@ -6,7 +6,7 @@ import "object_detection/protos/optimizer.proto";
 import "object_detection/protos/preprocessor.proto";

 // Message for configuring DetectionModel training jobs (train.py).
-// Next id: 26
+// Next id: 27
 message TrainConfig {
   // Effective batch size to use for training.
   // For TPU (or sync SGD jobs), the batch size per core (or GPU) is going to be
@@ -112,4 +112,7 @@ message TrainConfig {
   // dictionary, so that they can be displayed in Tensorboard. Note that this
   // will lead to a larger memory footprint.
   optional bool retain_original_images = 23 [default = false];
+
+  // Whether to use bfloat16 for training.
+  optional bool use_bfloat16 = 26 [default = false];
 }
research/object_detection/samples/configs/faster_rcnn_resnet101_fgvc.config  (new file, mode 100644)

# Faster R-CNN with Resnet-101 (v1)
# Users should configure the fine_tune_checkpoint field in the train config as
# well as the label_map_path and input_path fields in the train_input_reader and
# eval_input_reader. Search for "PATH_TO_BE_CONFIGURED" to find the fields that
# should be configured.

model {
  faster_rcnn {
    num_classes: 2854
    image_resizer {
      keep_aspect_ratio_resizer {
        min_dimension: 600
        max_dimension: 1024
      }
    }
    feature_extractor {
      type: 'faster_rcnn_resnet101'
      first_stage_features_stride: 16
    }
    first_stage_anchor_generator {
      grid_anchor_generator {
        scales: [0.25, 0.5, 1.0, 2.0]
        aspect_ratios: [0.5, 1.0, 2.0]
        height_stride: 16
        width_stride: 16
      }
    }
    first_stage_box_predictor_conv_hyperparams {
      op: CONV
      regularizer {
        l2_regularizer {
          weight: 0.0
        }
      }
      initializer {
        truncated_normal_initializer {
          stddev: 0.01
        }
      }
    }
    first_stage_nms_score_threshold: 0.0
    first_stage_nms_iou_threshold: 0.7
    first_stage_max_proposals: 32
    first_stage_localization_loss_weight: 2.0
    first_stage_objectness_loss_weight: 1.0
    initial_crop_size: 14
    maxpool_kernel_size: 2
    maxpool_stride: 2
    second_stage_batch_size: 32
    second_stage_box_predictor {
      mask_rcnn_box_predictor {
        use_dropout: false
        dropout_keep_probability: 1.0
        fc_hyperparams {
          op: FC
          regularizer {
            l2_regularizer {
              weight: 0.0
            }
          }
          initializer {
            variance_scaling_initializer {
              factor: 1.0
              uniform: true
              mode: FAN_AVG
            }
          }
        }
      }
    }
    second_stage_post_processing {
      batch_non_max_suppression {
        score_threshold: 0.0
        iou_threshold: 0.6
        max_detections_per_class: 5
        max_total_detections: 5
      }
      score_converter: SOFTMAX
    }
    second_stage_localization_loss_weight: 2.0
    second_stage_classification_loss_weight: 1.0
  }
}

train_config: {
  batch_size: 1
  num_steps: 4000000
  optimizer {
    momentum_optimizer: {
      learning_rate: {
        manual_step_learning_rate {
          initial_learning_rate: 0.0003
          schedule {
            step: 3000000
            learning_rate: .00003
          }
          schedule {
            step: 3500000
            learning_rate: .000003
          }
        }
      }
      momentum_optimizer_value: 0.9
    }
    use_moving_average: false
  }
  gradient_clipping_by_norm: 10.0
  fine_tune_checkpoint: "PATH_TO_BE_CONFIGURED/model.ckpt"
  data_augmentation_options {
    random_horizontal_flip {
    }
  }
}

train_input_reader: {
  label_map_path: "PATH_TO_BE_CONFIGURED/fgvc_2854_classes_label_map.pbtxt"
  tf_record_input_reader {
    input_path: "PATH_TO_BE_CONFIGURED/animal_2854_train.record"
  }
}

eval_config: {
  metrics_set: "pascal_voc_detection_metrics"
  use_moving_averages: false
  num_examples: 48736
}

eval_input_reader: {
  label_map_path: "PATH_TO_BE_CONFIGURED/fgvc_2854_classes_label_map.pbtxt"
  shuffle: false
  num_readers: 1
  tf_record_input_reader {
    input_path: "PATH_TO_BE_CONFIGURED/animal_2854_val.record"
  }
}
research/object_detection/samples/configs/faster_rcnn_resnet50_fgvc.config  (new file, mode 100644)

# Faster R-CNN with Resnet-50 (v1)
# Users should configure the fine_tune_checkpoint field in the train config as
# well as the label_map_path and input_path fields in the train_input_reader and
# eval_input_reader. Search for "PATH_TO_BE_CONFIGURED" to find the fields that
# should be configured.

model {
  faster_rcnn {
    num_classes: 2854
    image_resizer {
      keep_aspect_ratio_resizer {
        min_dimension: 600
        max_dimension: 1024
      }
    }
    feature_extractor {
      type: 'faster_rcnn_resnet50'
      first_stage_features_stride: 16
    }
    first_stage_anchor_generator {
      grid_anchor_generator {
        scales: [0.25, 0.5, 1.0, 2.0]
        aspect_ratios: [0.5, 1.0, 2.0]
        height_stride: 16
        width_stride: 16
      }
    }
    first_stage_box_predictor_conv_hyperparams {
      op: CONV
      regularizer {
        l2_regularizer {
          weight: 0.0
        }
      }
      initializer {
        truncated_normal_initializer {
          stddev: 0.01
        }
      }
    }
    first_stage_nms_score_threshold: 0.0
    first_stage_nms_iou_threshold: 0.7
    first_stage_max_proposals: 32
    first_stage_localization_loss_weight: 2.0
    first_stage_objectness_loss_weight: 1.0
    initial_crop_size: 14
    maxpool_kernel_size: 2
    maxpool_stride: 2
    second_stage_batch_size: 32
    second_stage_box_predictor {
      mask_rcnn_box_predictor {
        use_dropout: false
        dropout_keep_probability: 1.0
        fc_hyperparams {
          op: FC
          regularizer {
            l2_regularizer {
              weight: 0.0
            }
          }
          initializer {
            variance_scaling_initializer {
              factor: 1.0
              uniform: true
              mode: FAN_AVG
            }
          }
        }
      }
    }
    second_stage_post_processing {
      batch_non_max_suppression {
        score_threshold: 0.0
        iou_threshold: 0.6
        max_detections_per_class: 5
        max_total_detections: 5
      }
      score_converter: SOFTMAX
    }
    second_stage_localization_loss_weight: 2.0
    second_stage_classification_loss_weight: 1.0
  }
}

train_config: {
  batch_size: 1
  num_steps: 4000000
  optimizer {
    momentum_optimizer: {
      learning_rate: {
        manual_step_learning_rate {
          initial_learning_rate: 0.0003
          schedule {
            step: 3000000
            learning_rate: .00003
          }
          schedule {
            step: 3500000
            learning_rate: .000003
          }
        }
      }
      momentum_optimizer_value: 0.9
    }
    use_moving_average: false
  }
  gradient_clipping_by_norm: 10.0
  fine_tune_checkpoint: "PATH_TO_BE_CONFIGURED/model.ckpt"
  data_augmentation_options {
    random_horizontal_flip {
    }
  }
}

train_input_reader: {
  label_map_path: "PATH_TO_BE_CONFIGURED/fgvc_2854_classes_label_map.pbtxt"
  tf_record_input_reader {
    input_path: "PATH_TO_BE_CONFIGURED/animal_2854_train.record"
  }
}

eval_config: {
  metrics_set: "pascal_voc_detection_metrics"
  use_moving_averages: false
  num_examples: 10
}

eval_input_reader: {
  label_map_path: "PATH_TO_BE_CONFIGURED/fgvc_2854_classes_label_map.pbtxt"
  shuffle: false
  num_readers: 1
  tf_record_input_reader {
    input_path: "PATH_TO_BE_CONFIGURED/animal_2854_val.record"
  }
}
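A rough sketch of how these sample configs are meant to be consumed, using the config_util helpers touched later in this commit; the local data paths standing in for the PATH_TO_BE_CONFIGURED placeholders are made up for illustration:

    from object_detection.utils import config_util

    configs = config_util.get_configs_from_pipeline_file(
        "research/object_detection/samples/configs/"
        "faster_rcnn_resnet50_fgvc.config")
    # Legacy single-string override keys; each applies to the train reader and
    # to the (single) eval reader of this sample config.
    configs = config_util.merge_external_params_with_configs(
        configs,
        kwargs_dict={
            "label_map_path": "/data/fgvc_2854_classes_label_map.pbtxt",
            "train_input_path": "/data/animal_2854_train.record",
            "eval_input_path": "/data/animal_2854_val.record",
        })
    pipeline_proto = config_util.create_pipeline_proto_from_configs(configs)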
research/object_detection/utils/config_util.py
@@ -103,15 +103,20 @@ def create_configs_from_pipeline_proto(pipeline_config):
   Returns:
     Dictionary of configuration objects. Keys are `model`, `train_config`,
-      `train_input_config`, `eval_config`, `eval_input_config`. Value are the
-      corresponding config objects.
+      `train_input_config`, `eval_config`, `eval_input_configs`. Value are
+      the corresponding config objects or list of config objects (only for
+      eval_input_configs).
   """
   configs = {}
   configs["model"] = pipeline_config.model
   configs["train_config"] = pipeline_config.train_config
   configs["train_input_config"] = pipeline_config.train_input_reader
   configs["eval_config"] = pipeline_config.eval_config
-  configs["eval_input_config"] = pipeline_config.eval_input_reader
+  configs["eval_input_configs"] = pipeline_config.eval_input_reader
+  # Keeps eval_input_config only for backwards compatibility. All clients should
+  # read eval_input_configs instead.
+  if configs["eval_input_configs"]:
+    configs["eval_input_config"] = configs["eval_input_configs"][0]
   if pipeline_config.HasField("graph_rewriter"):
     configs["graph_rewriter_config"] = pipeline_config.graph_rewriter
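In practice this means callers now read a list of eval input configs. A short sketch of the backwards-compatible access pattern, assuming a hypothetical "pipeline.config" with at least one eval_input_reader:

    from object_detection.utils import config_util

    configs = config_util.get_configs_from_pipeline_file("pipeline.config")
    eval_input_configs = configs["eval_input_configs"]  # list of InputReader protos
    # "eval_input_config" is an alias for eval_input_configs[0], kept only so
    # that pre-existing callers keep working.
    first_eval_input = configs["eval_input_config"]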
@@ -150,7 +155,7 @@ def create_pipeline_proto_from_configs(configs):
   pipeline_config.train_config.CopyFrom(configs["train_config"])
   pipeline_config.train_input_reader.CopyFrom(configs["train_input_config"])
   pipeline_config.eval_config.CopyFrom(configs["eval_config"])
-  pipeline_config.eval_input_reader.CopyFrom(configs["eval_input_config"])
+  pipeline_config.eval_input_reader.extend(configs["eval_input_configs"])
   if "graph_rewriter_config" in configs:
     pipeline_config.graph_rewriter.CopyFrom(configs["graph_rewriter_config"])
   return pipeline_config
@@ -224,7 +229,7 @@ def get_configs_from_multiple_files(model_config_path="",
   eval_input_config = input_reader_pb2.InputReader()
   with tf.gfile.GFile(eval_input_config_path, "r") as f:
     text_format.Merge(f.read(), eval_input_config)
-  configs["eval_input_config"] = eval_input_config
+  configs["eval_input_configs"] = [eval_input_config]

   if graph_rewriter_config_path:
     configs["graph_rewriter_config"] = get_graph_rewriter_config_from_file(
@@ -284,14 +289,133 @@ def _is_generic_key(key):
                  "graph_rewriter_config",
                  "model",
                  "train_input_config",
-                 "train_config"]:
+                 "train_config",
+                 "eval_config"]:
     if key.startswith(prefix + "."):
       return True
   return False


-def merge_external_params_with_configs(configs, hparams=None, **kwargs):
+def _check_and_convert_legacy_input_config_key(key):
+  """Checks key and converts legacy input config update to specific update.
+
+  Args:
+    key: string indicates the target of update operation.
+
+  Returns:
+    is_valid_input_config_key: A boolean indicating whether the input key is to
+      update input config(s).
+    key_name: 'eval_input_configs' or 'train_input_config' string if
+      is_valid_input_config_key is true. None if is_valid_input_config_key is
+      false.
+    input_name: always returns None since legacy input config key never
+      specifies the target input config. Keeping this output only to match the
+      output form defined for input config update.
+    field_name: the field name in input config. `key` itself if
+      is_valid_input_config_key is false.
+  """
+  key_name = None
+  input_name = None
+  field_name = key
+  is_valid_input_config_key = True
+  if field_name == "train_shuffle":
+    key_name = "train_input_config"
+    field_name = "shuffle"
+  elif field_name == "eval_shuffle":
+    key_name = "eval_input_configs"
+    field_name = "shuffle"
+  elif field_name == "train_input_path":
+    key_name = "train_input_config"
+    field_name = "input_path"
+  elif field_name == "eval_input_path":
+    key_name = "eval_input_configs"
+    field_name = "input_path"
+  elif field_name == "append_train_input_path":
+    key_name = "train_input_config"
+    field_name = "input_path"
+  elif field_name == "append_eval_input_path":
+    key_name = "eval_input_configs"
+    field_name = "input_path"
+  else:
+    is_valid_input_config_key = False
+
+  return is_valid_input_config_key, key_name, input_name, field_name
+
+
+def check_and_parse_input_config_key(configs, key):
+  """Checks key and returns specific fields if key is valid input config update.
+
+  Args:
+    configs: Dictionary of configuration objects. See outputs from
+      get_configs_from_pipeline_file() or get_configs_from_multiple_files().
+    key: string indicates the target of update operation.
+
+  Returns:
+    is_valid_input_config_key: A boolean indicate whether the input key is to
+      update input config(s).
+    key_name: 'eval_input_configs' or 'train_input_config' string if
+      is_valid_input_config_key is true. None if is_valid_input_config_key is
+      false.
+    input_name: the name of the input config to be updated. None if
+      is_valid_input_config_key is false.
+    field_name: the field name in input config. `key` itself if
+      is_valid_input_config_key is false.
+
+  Raises:
+    ValueError: when the input key format doesn't match any known formats.
+    ValueError: if key_name doesn't match 'eval_input_configs' or
+      'train_input_config'.
+    ValueError: if input_name doesn't match any name in train or eval input
+      configs.
+    ValueError: if field_name doesn't match any supported fields.
+  """
+  key_name = None
+  input_name = None
+  field_name = None
+  fields = key.split(":")
+  if len(fields) == 1:
+    field_name = key
+    return _check_and_convert_legacy_input_config_key(key)
+  elif len(fields) == 3:
+    key_name = fields[0]
+    input_name = fields[1]
+    field_name = fields[2]
+  else:
+    raise ValueError("Invalid key format when overriding configs.")
+
+  # Checks if key_name is valid for specific update.
+  if key_name not in ["eval_input_configs", "train_input_config"]:
+    raise ValueError("Invalid key_name when overriding input config.")
+
+  # Checks if input_name is valid for specific update. For train input config it
+  # should match configs[key_name].name, for eval input configs it should match
+  # the name field of one of the eval_input_configs.
+  if isinstance(configs[key_name], input_reader_pb2.InputReader):
+    is_valid_input_name = configs[key_name].name == input_name
+  else:
+    is_valid_input_name = input_name in [
+        eval_input_config.name for eval_input_config in configs[key_name]
+    ]
+  if not is_valid_input_name:
+    raise ValueError("Invalid input_name when overriding input config.")
+
+  # Checks if field_name is valid for specific update.
+  if field_name not in [
+      "input_path", "label_map_path", "shuffle", "mask_type",
+      "sample_1_of_n_examples"
+  ]:
+    raise ValueError("Invalid field_name when overriding input config.")
+
+  return True, key_name, input_name, field_name
+
+
+def merge_external_params_with_configs(configs, hparams=None, kwargs_dict=None):
   """Updates `configs` dictionary based on supplied parameters.

   This utility is for modifying specific fields in the object detection configs.
@@ -304,6 +428,31 @@ def merge_external_params_with_configs(configs, hparams=None, **kwargs):
   1. Strategy-based overrides, which update multiple relevant configuration
   options. For example, updating `learning_rate` will update both the warmup and
   final learning rates.
+  In this case key can be one of the following formats:
+      1. legacy update: single string that indicates the attribute to be
+        updated. E.g. 'label_map_path', 'eval_input_path', 'shuffle'.
+        Note that when updating fields (e.g. eval_input_path, eval_shuffle) in
+        eval_input_configs, the override will only be applied when
+        eval_input_configs has exactly 1 element.
+      2. specific update: colon separated string that indicates which field in
+        which input_config to update. It should have 3 fields:
+        - key_name: Name of the input config we should update, either
+          'train_input_config' or 'eval_input_configs'
+        - input_name: a 'name' that can be used to identify elements, especially
+          when configs[key_name] is a repeated field.
+        - field_name: name of the field that you want to override.
+        For example, given configs dict as below:
+          configs = {
+            'model': {...}
+            'train_config': {...}
+            'train_input_config': {...}
+            'eval_config': {...}
+            'eval_input_configs': [{ name:"eval_coco", ...},
+                                   { name:"eval_voc", ... }]
+          }
+        Assume we want to update the input_path of the eval_input_config
+        whose name is 'eval_coco'. The `key` would then be:
+          'eval_input_configs:eval_coco:input_path'
   2. Generic key/value, which update a specific parameter based on namespaced
   configuration keys. For example,
   `model.ssd.loss.hard_example_miner.max_negatives_per_positive` will update the
@@ -314,55 +463,29 @@ def merge_external_params_with_configs(configs, hparams=None, **kwargs):
     configs: Dictionary of configuration objects. See outputs from
       get_configs_from_pipeline_file() or get_configs_from_multiple_files().
     hparams: A `HParams`.
-    **kwargs: Extra keyword arguments that are treated the same way as
+    kwargs_dict: Extra keyword arguments that are treated the same way as
       attribute/value pairs in `hparams`. Note that hyperparameters with the
       same names will override keyword arguments.

   Returns:
     `configs` dictionary.
+
+  Raises:
+    ValueError: when the key string doesn't match any of its allowed formats.
   """
+  if kwargs_dict is None:
+    kwargs_dict = {}
   if hparams:
-    kwargs.update(hparams.values())
-  for key, value in kwargs.items():
+    kwargs_dict.update(hparams.values())
+  for key, value in kwargs_dict.items():
     tf.logging.info("Maybe overwriting %s: %s", key, value)
     # pylint: disable=g-explicit-bool-comparison
     if value == "" or value is None:
       continue
     # pylint: enable=g-explicit-bool-comparison
-    if key == "learning_rate":
-      _update_initial_learning_rate(configs, value)
-    elif key == "batch_size":
-      _update_batch_size(configs, value)
-    elif key == "momentum_optimizer_value":
-      _update_momentum_optimizer_value(configs, value)
-    elif key == "classification_localization_weight_ratio":
-      # Localization weight is fixed to 1.0.
-      _update_classification_localization_weight_ratio(configs, value)
-    elif key == "focal_loss_gamma":
-      _update_focal_loss_gamma(configs, value)
-    elif key == "focal_loss_alpha":
-      _update_focal_loss_alpha(configs, value)
-    elif key == "train_steps":
-      _update_train_steps(configs, value)
-    elif key == "eval_steps":
-      _update_eval_steps(configs, value)
-    elif key == "train_input_path":
-      _update_input_path(configs["train_input_config"], value)
-    elif key == "eval_input_path":
-      _update_input_path(configs["eval_input_config"], value)
-    elif key == "label_map_path":
-      _update_label_map_path(configs, value)
-    elif key == "mask_type":
-      _update_mask_type(configs, value)
-    elif key == "eval_with_moving_averages":
-      _update_use_moving_averages(configs, value)
-    elif key == "train_shuffle":
-      _update_shuffle(configs["train_input_config"], value)
-    elif key == "eval_shuffle":
-      _update_shuffle(configs["eval_input_config"], value)
-    elif key == "retain_original_images_in_eval":
-      _update_retain_original_images(configs["eval_config"], value)
+    elif _maybe_update_config_with_key_value(configs, key, value):
+      continue
     elif _is_generic_key(key):
       _update_generic(configs, key, value)
     else:
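The collapsed dispatch above routes every override through _maybe_update_config_with_key_value, so a caller can mix legacy single-string keys with the new colon-separated ones in a single kwargs_dict. A hedged sketch (the path values are invented, and an eval input config named "eval_coco" is assumed to exist in the loaded pipeline):

    from object_detection.utils import config_util

    configs = config_util.get_configs_from_pipeline_file("pipeline.config")
    configs = config_util.merge_external_params_with_configs(
        configs,
        kwargs_dict={
            # Legacy key: updates the single train input reader's input path.
            "train_input_path": "/data/train.record",
            # Specific key, key_name:input_name:field_name: targets only the
            # eval input config whose `name` field is "eval_coco".
            "eval_input_configs:eval_coco:shuffle": False,
        })

Switching from **kwargs to an explicit kwargs_dict is presumably what makes the second form possible at all, since a colon-separated key is not a valid Python identifier and could never be passed as a keyword argument.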
@@ -370,6 +493,148 @@ def merge_external_params_with_configs(configs, hparams=None, **kwargs):
   return configs


+def _maybe_update_config_with_key_value(configs, key, value):
+  """Checks key type and updates `configs` with the key value pair accordingly.
+
+  Args:
+    configs: Dictionary of configuration objects. See outputs from
+      get_configs_from_pipeline_file() or get_configs_from_multiple_files().
+    key: String indicates the field(s) to be updated.
+    value: Value used to override existing field value.
+
+  Returns:
+    A boolean value that indicates whether the override succeeds.
+
+  Raises:
+    ValueError: when the key string doesn't match any of the formats above.
+  """
+  is_valid_input_config_key, key_name, input_name, field_name = (
+      check_and_parse_input_config_key(configs, key))
+  if is_valid_input_config_key:
+    update_input_reader_config(
+        configs,
+        key_name=key_name,
+        input_name=input_name,
+        field_name=field_name,
+        value=value)
+  elif field_name == "learning_rate":
+    _update_initial_learning_rate(configs, value)
+  elif field_name == "batch_size":
+    _update_batch_size(configs, value)
+  elif field_name == "momentum_optimizer_value":
+    _update_momentum_optimizer_value(configs, value)
+  elif field_name == "classification_localization_weight_ratio":
+    # Localization weight is fixed to 1.0.
+    _update_classification_localization_weight_ratio(configs, value)
+  elif field_name == "focal_loss_gamma":
+    _update_focal_loss_gamma(configs, value)
+  elif field_name == "focal_loss_alpha":
+    _update_focal_loss_alpha(configs, value)
+  elif field_name == "train_steps":
+    _update_train_steps(configs, value)
+  elif field_name == "label_map_path":
+    _update_label_map_path(configs, value)
+  elif field_name == "mask_type":
+    _update_mask_type(configs, value)
+  elif field_name == "sample_1_of_n_eval_examples":
+    _update_all_eval_input_configs(configs, "sample_1_of_n_examples", value)
+  elif field_name == "eval_num_epochs":
+    _update_all_eval_input_configs(configs, "num_epochs", value)
+  elif field_name == "eval_with_moving_averages":
+    _update_use_moving_averages(configs, value)
+  elif field_name == "retain_original_images_in_eval":
+    _update_retain_original_images(configs["eval_config"], value)
+  elif field_name == "use_bfloat16":
+    _update_use_bfloat16(configs, value)
+  else:
+    return False
+  return True
+
+
+def _update_tf_record_input_path(input_config, input_path):
+  """Updates input configuration to reflect a new input path.
+
+  The input_config object is updated in place, and hence not returned.
+
+  Args:
+    input_config: A input_reader_pb2.InputReader.
+    input_path: A path to data or list of paths.
+
+  Raises:
+    TypeError: if input reader type is not `tf_record_input_reader`.
+  """
+  input_reader_type = input_config.WhichOneof("input_reader")
+  if input_reader_type == "tf_record_input_reader":
+    input_config.tf_record_input_reader.ClearField("input_path")
+    if isinstance(input_path, list):
+      input_config.tf_record_input_reader.input_path.extend(input_path)
+    else:
+      input_config.tf_record_input_reader.input_path.append(input_path)
+  else:
+    raise TypeError("Input reader type must be `tf_record_input_reader`.")
+
+
+def update_input_reader_config(configs,
+                               key_name=None,
+                               input_name=None,
+                               field_name=None,
+                               value=None,
+                               path_updater=_update_tf_record_input_path):
+  """Updates specified input reader config field.
+
+  Args:
+    configs: Dictionary of configuration objects. See outputs from
+      get_configs_from_pipeline_file() or get_configs_from_multiple_files().
+    key_name: Name of the input config we should update, either
+      'train_input_config' or 'eval_input_configs'
+    input_name: String name used to identify input config to update with. Should
+      be either None or value of the 'name' field in one of the input reader
+      configs.
+    field_name: Field name in input_reader_pb2.InputReader.
+    value: Value used to override existing field value.
+    path_updater: helper function used to update the input path. Only used when
+      field_name is "input_path".
+
+  Raises:
+    ValueError: when input field_name is None.
+    ValueError: when input_name is None and number of eval_input_readers does
+      not equal to 1.
+  """
+  if isinstance(configs[key_name], input_reader_pb2.InputReader):
+    # Updates singular input_config object.
+    target_input_config = configs[key_name]
+    if field_name == "input_path":
+      path_updater(input_config=target_input_config, input_path=value)
+    else:
+      setattr(target_input_config, field_name, value)
+  elif input_name is None and len(configs[key_name]) == 1:
+    # Updates first (and the only) object of input_config list.
+    target_input_config = configs[key_name][0]
+    if field_name == "input_path":
+      path_updater(input_config=target_input_config, input_path=value)
+    else:
+      setattr(target_input_config, field_name, value)
+  elif input_name is not None and len(configs[key_name]):
+    # Updates input_config whose name matches input_name.
+    update_count = 0
+    for input_config in configs[key_name]:
+      if input_config.name == input_name:
+        setattr(input_config, field_name, value)
+        update_count = update_count + 1
+    if not update_count:
+      raise ValueError(
+          "Input name {} not found when overriding.".format(input_name))
+    elif update_count > 1:
+      raise ValueError("Duplicate input name found when overriding.")
+  else:
+    key_name = "None" if key_name is None else key_name
+    input_name = "None" if input_name is None else input_name
+    field_name = "None" if field_name is None else field_name
+    raise ValueError("Unknown input config overriding: "
+                     "key_name:{}, input_name:{}, field_name:{}.".format(
+                         key_name, input_name, field_name))
+
+
 def _update_initial_learning_rate(configs, learning_rate):
   """Updates `configs` to reflect the new initial learning rate.
@@ -596,27 +861,10 @@ def _update_eval_steps(configs, eval_steps):
   configs["eval_config"].num_examples = int(eval_steps)


-def _update_input_path(input_config, input_path):
-  """Updates input configuration to reflect a new input path.
-
-  The input_config object is updated in place, and hence not returned.
-
-  Args:
-    input_config: A input_reader_pb2.InputReader.
-    input_path: A path to data or list of paths.
-
-  Raises:
-    TypeError: if input reader type is not `tf_record_input_reader`.
-  """
-  input_reader_type = input_config.WhichOneof("input_reader")
-  if input_reader_type == "tf_record_input_reader":
-    input_config.tf_record_input_reader.ClearField("input_path")
-    if isinstance(input_path, list):
-      input_config.tf_record_input_reader.input_path.extend(input_path)
-    else:
-      input_config.tf_record_input_reader.input_path.append(input_path)
-  else:
-    raise TypeError("Input reader type must be `tf_record_input_reader`.")
+def _update_all_eval_input_configs(configs, field, value):
+  """Updates the content of `field` with `value` for all eval input configs."""
+  for eval_input_config in configs["eval_input_configs"]:
+    setattr(eval_input_config, field, value)


 def _update_label_map_path(configs, label_map_path):
@@ -630,7 +878,7 @@ def _update_label_map_path(configs, label_map_path):
     label_map_path: New path to `StringIntLabelMap` pbtxt file.
   """
   configs["train_input_config"].label_map_path = label_map_path
-  configs["eval_input_config"].label_map_path = label_map_path
+  _update_all_eval_input_configs(configs, "label_map_path", label_map_path)


 def _update_mask_type(configs, mask_type):
@@ -645,7 +893,7 @@ def _update_mask_type(configs, mask_type):
       input_reader_pb2.InstanceMaskType
   """
   configs["train_input_config"].mask_type = mask_type
-  configs["eval_input_config"].mask_type = mask_type
+  _update_all_eval_input_configs(configs, "mask_type", mask_type)


 def _update_use_moving_averages(configs, use_moving_averages):
@@ -662,18 +910,6 @@ def _update_use_moving_averages(configs, use_moving_averages):
   configs["eval_config"].use_moving_averages = use_moving_averages


-def _update_shuffle(input_config, shuffle):
-  """Updates input configuration to reflect a new shuffle configuration.
-
-  The input_config object is updated in place, and hence not returned.
-
-  Args:
-    input_config: A input_reader_pb2.InputReader.
-    shuffle: Whether or not to shuffle the input data before reading.
-  """
-  input_config.shuffle = shuffle
-
-
 def _update_retain_original_images(eval_config, retain_original_images):
   """Updates eval config with option to retain original images.
@@ -685,3 +921,16 @@ def _update_retain_original_images(eval_config, retain_original_images):
     in eval mode.
   """
   eval_config.retain_original_images = retain_original_images
+
+
+def _update_use_bfloat16(configs, use_bfloat16):
+  """Updates `configs` to reflect the new setup on whether to use bfloat16.
+
+  The configs dictionary is updated in place, and hence not returned.
+
+  Args:
+    configs: Dictionary of configuration objects. See outputs from
+      get_configs_from_pipeline_file() or get_configs_from_multiple_files().
+    use_bfloat16: A bool, indicating whether to use bfloat16 for training.
+  """
+  configs["train_config"].use_bfloat16 = use_bfloat16
research/object_detection/utils/config_util_test.py
@@ -83,7 +83,7 @@ class ConfigUtilTest(tf.test.TestCase):
     pipeline_config.train_config.batch_size = 32
     pipeline_config.train_input_reader.label_map_path = "path/to/label_map"
     pipeline_config.eval_config.num_examples = 20
-    pipeline_config.eval_input_reader.queue_capacity = 100
+    pipeline_config.eval_input_reader.add().queue_capacity = 100
     _write_config(pipeline_config, pipeline_config_path)
@@ -96,7 +96,7 @@ class ConfigUtilTest(tf.test.TestCase):
     self.assertProtoEquals(pipeline_config.eval_config,
                            configs["eval_config"])
     self.assertProtoEquals(pipeline_config.eval_input_reader,
-                           configs["eval_input_config"])
+                           configs["eval_input_configs"])

   def test_create_configs_from_pipeline_proto(self):
     """Tests creating configs dictionary from pipeline proto."""
@@ -106,7 +106,7 @@ class ConfigUtilTest(tf.test.TestCase):
     pipeline_config.train_config.batch_size = 32
     pipeline_config.train_input_reader.label_map_path = "path/to/label_map"
     pipeline_config.eval_config.num_examples = 20
-    pipeline_config.eval_input_reader.queue_capacity = 100
+    pipeline_config.eval_input_reader.add().queue_capacity = 100

     configs = config_util.create_configs_from_pipeline_proto(pipeline_config)
     self.assertProtoEquals(pipeline_config.model, configs["model"])
@@ -116,7 +116,7 @@ class ConfigUtilTest(tf.test.TestCase):
                            configs["train_input_config"])
     self.assertProtoEquals(pipeline_config.eval_config, configs["eval_config"])
     self.assertProtoEquals(pipeline_config.eval_input_reader,
-                           configs["eval_input_config"])
+                           configs["eval_input_configs"])

   def test_create_pipeline_proto_from_configs(self):
     """Tests that proto can be reconstructed from configs dictionary."""
@@ -127,7 +127,7 @@ class ConfigUtilTest(tf.test.TestCase):
     pipeline_config.train_config.batch_size = 32
     pipeline_config.train_input_reader.label_map_path = "path/to/label_map"
     pipeline_config.eval_config.num_examples = 20
-    pipeline_config.eval_input_reader.queue_capacity = 100
+    pipeline_config.eval_input_reader.add().queue_capacity = 100
     _write_config(pipeline_config, pipeline_config_path)

     configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
@@ -142,7 +142,7 @@ class ConfigUtilTest(tf.test.TestCase):
     pipeline_config.train_config.batch_size = 32
     pipeline_config.train_input_reader.label_map_path = "path/to/label_map"
     pipeline_config.eval_config.num_examples = 20
-    pipeline_config.eval_input_reader.queue_capacity = 100
+    pipeline_config.eval_input_reader.add().queue_capacity = 100

     config_util.save_pipeline_config(pipeline_config, self.get_temp_dir())
     configs = config_util.get_configs_from_pipeline_file(
@@ -197,8 +197,7 @@ class ConfigUtilTest(tf.test.TestCase):
     self.assertProtoEquals(train_input_config,
                            configs["train_input_config"])
     self.assertProtoEquals(eval_config, configs["eval_config"])
-    self.assertProtoEquals(eval_input_config,
-                           configs["eval_input_config"])
+    self.assertProtoEquals(eval_input_config,
+                           configs["eval_input_configs"][0])

   def _assertOptimizerWithNewLearningRate(self, optimizer_name):
     """Asserts successful updating of all learning rate schemes."""
@@ -282,6 +281,41 @@ class ConfigUtilTest(tf.test.TestCase):
     """Tests new learning rates for Adam Optimizer."""
     self._assertOptimizerWithNewLearningRate("adam_optimizer")

+  def testGenericConfigOverride(self):
+    """Tests generic config overrides for all top-level configs."""
+    # Set one parameter for each of the top-level pipeline configs:
+    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
+    pipeline_config.model.ssd.num_classes = 1
+    pipeline_config.train_config.batch_size = 1
+    pipeline_config.eval_config.num_visualizations = 1
+    pipeline_config.train_input_reader.label_map_path = "/some/path"
+    pipeline_config.eval_input_reader.add().label_map_path = "/some/path"
+    pipeline_config.graph_rewriter.quantization.weight_bits = 1
+
+    pipeline_config_path = os.path.join(self.get_temp_dir(), "pipeline.config")
+    _write_config(pipeline_config, pipeline_config_path)
+
+    # Override each of the parameters:
+    configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
+    hparams = tf.contrib.training.HParams(
+        **{
+            "model.ssd.num_classes": 2,
+            "train_config.batch_size": 2,
+            "train_input_config.label_map_path": "/some/other/path",
+            "eval_config.num_visualizations": 2,
+            "graph_rewriter_config.quantization.weight_bits": 2
+        })
+    configs = config_util.merge_external_params_with_configs(configs, hparams)
+
+    # Ensure that the parameters have the overridden values:
+    self.assertEqual(2, configs["model"].ssd.num_classes)
+    self.assertEqual(2, configs["train_config"].batch_size)
+    self.assertEqual("/some/other/path",
+                     configs["train_input_config"].label_map_path)
+    self.assertEqual(2, configs["eval_config"].num_visualizations)
+    self.assertEqual(
+        2, configs["graph_rewriter_config"].quantization.weight_bits)
+
   def testNewBatchSize(self):
     """Tests that batch size is updated appropriately."""
     original_batch_size = 2
@@ -406,25 +440,19 @@ class ConfigUtilTest(tf.test.TestCase):
   def testMergingKeywordArguments(self):
     """Tests that keyword arguments get merged as do hyperparameters."""
     original_num_train_steps = 100
-    original_num_eval_steps = 5
     desired_num_train_steps = 10
-    desired_num_eval_steps = 1
     pipeline_config_path = os.path.join(self.get_temp_dir(), "pipeline.config")

     pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
     pipeline_config.train_config.num_steps = original_num_train_steps
-    pipeline_config.eval_config.num_examples = original_num_eval_steps
     _write_config(pipeline_config, pipeline_config_path)

     configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
+    override_dict = {"train_steps": desired_num_train_steps}
     configs = config_util.merge_external_params_with_configs(
-        configs,
-        train_steps=desired_num_train_steps,
-        eval_steps=desired_num_eval_steps)
+        configs, kwargs_dict=override_dict)
     train_steps = configs["train_config"].num_steps
-    eval_steps = configs["eval_config"].num_examples
     self.assertEqual(desired_num_train_steps, train_steps)
-    self.assertEqual(desired_num_eval_steps, eval_steps)

   def testGetNumberOfClasses(self):
     """Tests that number of classes can be retrieved."""
@@ -449,8 +477,9 @@ class ConfigUtilTest(tf.test.TestCase):
     _write_config(pipeline_config, pipeline_config_path)

     configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
+    override_dict = {"train_input_path": new_train_path}
     configs = config_util.merge_external_params_with_configs(
-        configs, train_input_path=new_train_path)
+        configs, kwargs_dict=override_dict)
     reader_config = configs["train_input_config"].tf_record_input_reader
     final_path = reader_config.input_path
     self.assertEqual([new_train_path], final_path)
@@ -467,8 +496,9 @@ class ConfigUtilTest(tf.test.TestCase):
     _write_config(pipeline_config, pipeline_config_path)

     configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
+    override_dict = {"train_input_path": new_train_path}
     configs = config_util.merge_external_params_with_configs(
-        configs, train_input_path=new_train_path)
+        configs, kwargs_dict=override_dict)
     reader_config = configs["train_input_config"].tf_record_input_reader
     final_path = reader_config.input_path
     self.assertEqual(new_train_path, final_path)
@@ -482,17 +512,18 @@ class ConfigUtilTest(tf.test.TestCase):
     pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
     train_input_reader = pipeline_config.train_input_reader
     train_input_reader.label_map_path = original_label_map_path
-    eval_input_reader = pipeline_config.eval_input_reader
+    eval_input_reader = pipeline_config.eval_input_reader.add()
     eval_input_reader.label_map_path = original_label_map_path
     _write_config(pipeline_config, pipeline_config_path)

     configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
+    override_dict = {"label_map_path": new_label_map_path}
     configs = config_util.merge_external_params_with_configs(
-        configs, label_map_path=new_label_map_path)
+        configs, kwargs_dict=override_dict)
     self.assertEqual(new_label_map_path,
                      configs["train_input_config"].label_map_path)
-    self.assertEqual(new_label_map_path,
-                     configs["eval_input_config"].label_map_path)
+    for eval_input_config in configs["eval_input_configs"]:
+      self.assertEqual(new_label_map_path, eval_input_config.label_map_path)

   def testDontOverwriteEmptyLabelMapPath(self):
     """Tests that label map path will not by overwritten with empty string."""
@@ -503,17 +534,18 @@ class ConfigUtilTest(tf.test.TestCase):
     pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
     train_input_reader = pipeline_config.train_input_reader
     train_input_reader.label_map_path = original_label_map_path
-    eval_input_reader = pipeline_config.eval_input_reader
+    eval_input_reader = pipeline_config.eval_input_reader.add()
     eval_input_reader.label_map_path = original_label_map_path
     _write_config(pipeline_config, pipeline_config_path)

     configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
+    override_dict = {"label_map_path": new_label_map_path}
     configs = config_util.merge_external_params_with_configs(
-        configs, label_map_path=new_label_map_path)
+        configs, kwargs_dict=override_dict)
     self.assertEqual(original_label_map_path,
                      configs["train_input_config"].label_map_path)
     self.assertEqual(original_label_map_path,
-                     configs["eval_input_config"].label_map_path)
+                     configs["eval_input_configs"][0].label_map_path)

   def testNewMaskType(self):
     """Tests that mask type can be overwritten in input readers."""
@@ -524,15 +556,16 @@ class ConfigUtilTest(tf.test.TestCase):
     pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
     train_input_reader = pipeline_config.train_input_reader
     train_input_reader.mask_type = original_mask_type
-    eval_input_reader = pipeline_config.eval_input_reader
+    eval_input_reader = pipeline_config.eval_input_reader.add()
     eval_input_reader.mask_type = original_mask_type
     _write_config(pipeline_config, pipeline_config_path)

     configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
+    override_dict = {"mask_type": new_mask_type}
     configs = config_util.merge_external_params_with_configs(
-        configs, mask_type=new_mask_type)
+        configs, kwargs_dict=override_dict)
     self.assertEqual(new_mask_type, configs["train_input_config"].mask_type)
-    self.assertEqual(new_mask_type, configs["eval_input_config"].mask_type)
+    self.assertEqual(new_mask_type, configs["eval_input_configs"][0].mask_type)

   def testUseMovingAverageForEval(self):
     use_moving_averages_orig = False
@@ -543,8 +576,9 @@ class ConfigUtilTest(tf.test.TestCase):
     _write_config(pipeline_config, pipeline_config_path)

     configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
+    override_dict = {"eval_with_moving_averages": True}
     configs = config_util.merge_external_params_with_configs(
-        configs, eval_with_moving_averages=True)
+        configs, kwargs_dict=override_dict)
     self.assertEqual(True, configs["eval_config"].use_moving_averages)

   def testGetImageResizerConfig(self):
@@ -585,14 +619,14 @@ class ConfigUtilTest(tf.test.TestCase):
     pipeline_config_path = os.path.join(self.get_temp_dir(), "pipeline.config")

     pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
-    pipeline_config.eval_input_reader.shuffle = original_shuffle
+    pipeline_config.eval_input_reader.add().shuffle = original_shuffle
     _write_config(pipeline_config, pipeline_config_path)

     configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
+    override_dict = {"eval_shuffle": desired_shuffle}
configs
=
config_util
.
merge_external_params_with_configs
(
configs
=
config_util
.
merge_external_params_with_configs
(
configs
,
eval_shuffle
=
desired_shuffle
)
configs
,
kwargs_dict
=
override_dict
)
eval_shuffle
=
configs
[
"eval_input_config"
].
shuffle
self
.
assertEqual
(
desired_shuffle
,
configs
[
"eval_input_configs"
][
0
].
shuffle
)
self
.
assertEqual
(
desired_shuffle
,
eval_shuffle
)
def
testTrainShuffle
(
self
):
def
testTrainShuffle
(
self
):
"""Tests that `train_shuffle` keyword arguments are applied correctly."""
"""Tests that `train_shuffle` keyword arguments are applied correctly."""
...
@@ -605,8 +639,9 @@ class ConfigUtilTest(tf.test.TestCase):
...
@@ -605,8 +639,9 @@ class ConfigUtilTest(tf.test.TestCase):
_write_config
(
pipeline_config
,
pipeline_config_path
)
_write_config
(
pipeline_config
,
pipeline_config_path
)
configs
=
config_util
.
get_configs_from_pipeline_file
(
pipeline_config_path
)
configs
=
config_util
.
get_configs_from_pipeline_file
(
pipeline_config_path
)
override_dict
=
{
"train_shuffle"
:
desired_shuffle
}
configs
=
config_util
.
merge_external_params_with_configs
(
configs
=
config_util
.
merge_external_params_with_configs
(
configs
,
train_shuffle
=
desired_shuffle
)
configs
,
kwargs_dict
=
override_dict
)
train_shuffle
=
configs
[
"train_input_config"
].
shuffle
train_shuffle
=
configs
[
"train_input_config"
].
shuffle
self
.
assertEqual
(
desired_shuffle
,
train_shuffle
)
self
.
assertEqual
(
desired_shuffle
,
train_shuffle
)
...
@@ -622,11 +657,210 @@ class ConfigUtilTest(tf.test.TestCase):
...
@@ -622,11 +657,210 @@ class ConfigUtilTest(tf.test.TestCase):
_write_config
(
pipeline_config
,
pipeline_config_path
)
_write_config
(
pipeline_config
,
pipeline_config_path
)
configs
=
config_util
.
get_configs_from_pipeline_file
(
pipeline_config_path
)
configs
=
config_util
.
get_configs_from_pipeline_file
(
pipeline_config_path
)
override_dict
=
{
"retain_original_images_in_eval"
:
desired_retain_original_images
}
configs
=
config_util
.
merge_external_params_with_configs
(
configs
=
config_util
.
merge_external_params_with_configs
(
configs
,
retain_original_images_in_eval
=
desired_retain_original_images
)
configs
,
kwargs_dict
=
override_dict
)
retain_original_images
=
configs
[
"eval_config"
].
retain_original_images
retain_original_images
=
configs
[
"eval_config"
].
retain_original_images
self
.
assertEqual
(
desired_retain_original_images
,
retain_original_images
)
self
.
assertEqual
(
desired_retain_original_images
,
retain_original_images
)
def
testOverwriteAllEvalSampling
(
self
):
original_num_eval_examples
=
1
new_num_eval_examples
=
10
pipeline_config_path
=
os
.
path
.
join
(
self
.
get_temp_dir
(),
"pipeline.config"
)
pipeline_config
=
pipeline_pb2
.
TrainEvalPipelineConfig
()
pipeline_config
.
eval_input_reader
.
add
().
sample_1_of_n_examples
=
(
original_num_eval_examples
)
pipeline_config
.
eval_input_reader
.
add
().
sample_1_of_n_examples
=
(
original_num_eval_examples
)
_write_config
(
pipeline_config
,
pipeline_config_path
)
configs
=
config_util
.
get_configs_from_pipeline_file
(
pipeline_config_path
)
override_dict
=
{
"sample_1_of_n_eval_examples"
:
new_num_eval_examples
}
configs
=
config_util
.
merge_external_params_with_configs
(
configs
,
kwargs_dict
=
override_dict
)
for
eval_input_config
in
configs
[
"eval_input_configs"
]:
self
.
assertEqual
(
new_num_eval_examples
,
eval_input_config
.
sample_1_of_n_examples
)
def
testOverwriteAllEvalNumEpochs
(
self
):
original_num_epochs
=
10
new_num_epochs
=
1
pipeline_config_path
=
os
.
path
.
join
(
self
.
get_temp_dir
(),
"pipeline.config"
)
pipeline_config
=
pipeline_pb2
.
TrainEvalPipelineConfig
()
pipeline_config
.
eval_input_reader
.
add
().
num_epochs
=
original_num_epochs
pipeline_config
.
eval_input_reader
.
add
().
num_epochs
=
original_num_epochs
_write_config
(
pipeline_config
,
pipeline_config_path
)
configs
=
config_util
.
get_configs_from_pipeline_file
(
pipeline_config_path
)
override_dict
=
{
"eval_num_epochs"
:
new_num_epochs
}
configs
=
config_util
.
merge_external_params_with_configs
(
configs
,
kwargs_dict
=
override_dict
)
for
eval_input_config
in
configs
[
"eval_input_configs"
]:
self
.
assertEqual
(
new_num_epochs
,
eval_input_config
.
num_epochs
)
def
testUpdateMaskTypeForAllInputConfigs
(
self
):
original_mask_type
=
input_reader_pb2
.
NUMERICAL_MASKS
new_mask_type
=
input_reader_pb2
.
PNG_MASKS
pipeline_config_path
=
os
.
path
.
join
(
self
.
get_temp_dir
(),
"pipeline.config"
)
pipeline_config
=
pipeline_pb2
.
TrainEvalPipelineConfig
()
train_config
=
pipeline_config
.
train_input_reader
train_config
.
mask_type
=
original_mask_type
eval_1
=
pipeline_config
.
eval_input_reader
.
add
()
eval_1
.
mask_type
=
original_mask_type
eval_1
.
name
=
"eval_1"
eval_2
=
pipeline_config
.
eval_input_reader
.
add
()
eval_2
.
mask_type
=
original_mask_type
eval_2
.
name
=
"eval_2"
_write_config
(
pipeline_config
,
pipeline_config_path
)
configs
=
config_util
.
get_configs_from_pipeline_file
(
pipeline_config_path
)
override_dict
=
{
"mask_type"
:
new_mask_type
}
configs
=
config_util
.
merge_external_params_with_configs
(
configs
,
kwargs_dict
=
override_dict
)
self
.
assertEqual
(
configs
[
"train_input_config"
].
mask_type
,
new_mask_type
)
for
eval_input_config
in
configs
[
"eval_input_configs"
]:
self
.
assertEqual
(
eval_input_config
.
mask_type
,
new_mask_type
)
def
testErrorOverwritingMultipleInputConfig
(
self
):
original_shuffle
=
False
new_shuffle
=
True
pipeline_config_path
=
os
.
path
.
join
(
self
.
get_temp_dir
(),
"pipeline.config"
)
pipeline_config
=
pipeline_pb2
.
TrainEvalPipelineConfig
()
eval_1
=
pipeline_config
.
eval_input_reader
.
add
()
eval_1
.
shuffle
=
original_shuffle
eval_1
.
name
=
"eval_1"
eval_2
=
pipeline_config
.
eval_input_reader
.
add
()
eval_2
.
shuffle
=
original_shuffle
eval_2
.
name
=
"eval_2"
_write_config
(
pipeline_config
,
pipeline_config_path
)
configs
=
config_util
.
get_configs_from_pipeline_file
(
pipeline_config_path
)
override_dict
=
{
"eval_shuffle"
:
new_shuffle
}
with
self
.
assertRaises
(
ValueError
):
configs
=
config_util
.
merge_external_params_with_configs
(
configs
,
kwargs_dict
=
override_dict
)
def
testCheckAndParseInputConfigKey
(
self
):
pipeline_config_path
=
os
.
path
.
join
(
self
.
get_temp_dir
(),
"pipeline.config"
)
pipeline_config
=
pipeline_pb2
.
TrainEvalPipelineConfig
()
pipeline_config
.
eval_input_reader
.
add
().
name
=
"eval_1"
pipeline_config
.
eval_input_reader
.
add
().
name
=
"eval_2"
_write_config
(
pipeline_config
,
pipeline_config_path
)
configs
=
config_util
.
get_configs_from_pipeline_file
(
pipeline_config_path
)
specific_shuffle_update_key
=
"eval_input_configs:eval_2:shuffle"
is_valid_input_config_key
,
key_name
,
input_name
,
field_name
=
(
config_util
.
check_and_parse_input_config_key
(
configs
,
specific_shuffle_update_key
))
self
.
assertTrue
(
is_valid_input_config_key
)
self
.
assertEqual
(
key_name
,
"eval_input_configs"
)
self
.
assertEqual
(
input_name
,
"eval_2"
)
self
.
assertEqual
(
field_name
,
"shuffle"
)
legacy_shuffle_update_key
=
"eval_shuffle"
is_valid_input_config_key
,
key_name
,
input_name
,
field_name
=
(
config_util
.
check_and_parse_input_config_key
(
configs
,
legacy_shuffle_update_key
))
self
.
assertTrue
(
is_valid_input_config_key
)
self
.
assertEqual
(
key_name
,
"eval_input_configs"
)
self
.
assertEqual
(
input_name
,
None
)
self
.
assertEqual
(
field_name
,
"shuffle"
)
non_input_config_update_key
=
"label_map_path"
is_valid_input_config_key
,
key_name
,
input_name
,
field_name
=
(
config_util
.
check_and_parse_input_config_key
(
configs
,
non_input_config_update_key
))
self
.
assertFalse
(
is_valid_input_config_key
)
self
.
assertEqual
(
key_name
,
None
)
self
.
assertEqual
(
input_name
,
None
)
self
.
assertEqual
(
field_name
,
"label_map_path"
)
with
self
.
assertRaisesRegexp
(
ValueError
,
"Invalid key format when overriding configs."
):
config_util
.
check_and_parse_input_config_key
(
configs
,
"train_input_config:shuffle"
)
with
self
.
assertRaisesRegexp
(
ValueError
,
"Invalid key_name when overriding input config."
):
config_util
.
check_and_parse_input_config_key
(
configs
,
"invalid_key_name:train_name:shuffle"
)
with
self
.
assertRaisesRegexp
(
ValueError
,
"Invalid input_name when overriding input config."
):
config_util
.
check_and_parse_input_config_key
(
configs
,
"eval_input_configs:unknown_eval_name:shuffle"
)
with
self
.
assertRaisesRegexp
(
ValueError
,
"Invalid field_name when overriding input config."
):
config_util
.
check_and_parse_input_config_key
(
configs
,
"eval_input_configs:eval_2:unknown_field_name"
)
def
testUpdateInputReaderConfigSuccess
(
self
):
original_shuffle
=
False
new_shuffle
=
True
pipeline_config_path
=
os
.
path
.
join
(
self
.
get_temp_dir
(),
"pipeline.config"
)
pipeline_config
=
pipeline_pb2
.
TrainEvalPipelineConfig
()
pipeline_config
.
train_input_reader
.
shuffle
=
original_shuffle
_write_config
(
pipeline_config
,
pipeline_config_path
)
configs
=
config_util
.
get_configs_from_pipeline_file
(
pipeline_config_path
)
config_util
.
update_input_reader_config
(
configs
,
key_name
=
"train_input_config"
,
input_name
=
None
,
field_name
=
"shuffle"
,
value
=
new_shuffle
)
self
.
assertEqual
(
configs
[
"train_input_config"
].
shuffle
,
new_shuffle
)
config_util
.
update_input_reader_config
(
configs
,
key_name
=
"train_input_config"
,
input_name
=
None
,
field_name
=
"shuffle"
,
value
=
new_shuffle
)
self
.
assertEqual
(
configs
[
"train_input_config"
].
shuffle
,
new_shuffle
)
def
testUpdateInputReaderConfigErrors
(
self
):
pipeline_config_path
=
os
.
path
.
join
(
self
.
get_temp_dir
(),
"pipeline.config"
)
pipeline_config
=
pipeline_pb2
.
TrainEvalPipelineConfig
()
pipeline_config
.
eval_input_reader
.
add
().
name
=
"same_eval_name"
pipeline_config
.
eval_input_reader
.
add
().
name
=
"same_eval_name"
_write_config
(
pipeline_config
,
pipeline_config_path
)
configs
=
config_util
.
get_configs_from_pipeline_file
(
pipeline_config_path
)
with
self
.
assertRaisesRegexp
(
ValueError
,
"Duplicate input name found when overriding."
):
config_util
.
update_input_reader_config
(
configs
,
key_name
=
"eval_input_configs"
,
input_name
=
"same_eval_name"
,
field_name
=
"shuffle"
,
value
=
False
)
with
self
.
assertRaisesRegexp
(
ValueError
,
"Input name name_not_exist not found when overriding."
):
config_util
.
update_input_reader_config
(
configs
,
key_name
=
"eval_input_configs"
,
input_name
=
"name_not_exist"
,
field_name
=
"shuffle"
,
value
=
False
)
with
self
.
assertRaisesRegexp
(
ValueError
,
"Unknown input config overriding."
):
config_util
.
update_input_reader_config
(
configs
,
key_name
=
"eval_input_configs"
,
input_name
=
None
,
field_name
=
"shuffle"
,
value
=
False
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
tf
.
test
.
main
()
tf
.
test
.
main
()
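Note: the tests above exercise the new dictionary-based override path. Instead of passing each override as a separate keyword argument, callers build an override dict and hand it to `merge_external_params_with_configs` via `kwargs_dict`, and per-eval-reader fields are looked up in `configs["eval_input_configs"]`. A minimal, illustrative usage sketch (the config and label map paths below are hypothetical, not part of this commit):

from object_detection.utils import config_util

# Hypothetical paths; any serialized TrainEvalPipelineConfig works here.
configs = config_util.get_configs_from_pipeline_file("path/to/pipeline.config")

# Overrides are passed as a dict. "label_map_path" is applied to the train
# input reader and to every entry of configs["eval_input_configs"].
configs = config_util.merge_external_params_with_configs(
    configs, kwargs_dict={"label_map_path": "path/to/label_map.pbtxt"})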
research/object_detection/utils/label_map_util.py
...
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
 """Label map utility functions."""

 import logging
...
@@ -73,10 +72,10 @@ def get_max_label_map_index(label_map):
 def convert_label_map_to_categories(label_map,
                                     max_num_classes,
                                     use_display_name=True):
-  """Loads label map proto and returns categories list compatible with eval.
+  """Given label map proto returns categories list compatible with eval.

-  This function loads a label map and returns a list of dicts, each of which
-  has the following keys:
+  This function converts label map proto and returns a list of dicts, each of
+  which has the following keys:
   'id': (required) an integer id uniquely identifying this category.
   'name': (required) string representing category name
     e.g., 'cat', 'dog', 'pizza'.
...
@@ -89,9 +88,10 @@ def convert_label_map_to_categories(label_map,
     label_map: a StringIntLabelMapProto or None. If None, a default categories
       list is created with max_num_classes categories.
     max_num_classes: maximum number of (consecutive) label indices to include.
-    use_display_name: (boolean) choose whether to load 'display_name' field
-      as category name. If False or if the display_name field does not exist,
-      uses 'name' field as category names instead.
+    use_display_name: (boolean) choose whether to load 'display_name' field as
+      category name. If False or if the display_name field does not exist, uses
+      'name' field as category names instead.
+
   Returns:
     categories: a list of dictionaries representing all possible categories.
   """
...
@@ -107,8 +107,9 @@ def convert_label_map_to_categories(label_map,
     return categories
   for item in label_map.item:
     if not 0 < item.id <= max_num_classes:
-      logging.info('Ignore item %d since it falls outside of requested '
-                   'label range.', item.id)
+      logging.info(
+          'Ignore item %d since it falls outside of requested '
+          'label range.', item.id)
       continue
     if use_display_name and item.HasField('display_name'):
       name = item.display_name
...
@@ -188,20 +189,44 @@ def get_label_map_dict(label_map_path,
   return label_map_dict


-def create_category_index_from_labelmap(label_map_path):
+def create_categories_from_labelmap(label_map_path, use_display_name=True):
+  """Reads a label map and returns categories list compatible with eval.
+
+  This function converts label map proto and returns a list of dicts, each of
+  which has the following keys:
+    'id': an integer id uniquely identifying this category.
+    'name': string representing category name e.g., 'cat', 'dog'.
+
+  Args:
+    label_map_path: Path to `StringIntLabelMap` proto text file.
+    use_display_name: (boolean) choose whether to load 'display_name' field
+      as category name. If False or if the display_name field does not exist,
+      uses 'name' field as category names instead.
+
+  Returns:
+    categories: a list of dictionaries representing all possible categories.
+  """
+  label_map = load_labelmap(label_map_path)
+  max_num_classes = max(item.id for item in label_map.item)
+  return convert_label_map_to_categories(label_map, max_num_classes,
+                                         use_display_name)
+
+
+def create_category_index_from_labelmap(label_map_path, use_display_name=True):
   """Reads a label map and returns a category index.

   Args:
     label_map_path: Path to `StringIntLabelMap` proto text file.
+    use_display_name: (boolean) choose whether to load 'display_name' field
+      as category name. If False or if the display_name field does not exist,
+      uses 'name' field as category names instead.

   Returns:
     A category index, which is a dictionary that maps integer ids to dicts
     containing categories, e.g.
     {1: {'id': 1, 'name': 'dog'}, 2: {'id': 2, 'name': 'cat'}, ...}
   """
-  label_map = load_labelmap(label_map_path)
-  max_num_classes = max(item.id for item in label_map.item)
-  categories = convert_label_map_to_categories(label_map, max_num_classes)
+  categories = create_categories_from_labelmap(label_map_path,
+                                               use_display_name)
   return create_category_index(categories)
...
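For reference, a minimal sketch of how the new `create_categories_from_labelmap` helper and the extended `create_category_index_from_labelmap` signature are used (the label map path is hypothetical):

from object_detection.utils import label_map_util

# Hypothetical StringIntLabelMap text file, e.g. with 'dog' (id 1) and 'cat' (id 2).
categories = label_map_util.create_categories_from_labelmap(
    "path/to/label_map.pbtxt", use_display_name=True)
# -> [{'id': 1, 'name': 'dog'}, {'id': 2, 'name': 'cat'}]

category_index = label_map_util.create_category_index_from_labelmap(
    "path/to/label_map.pbtxt")
# -> {1: {'id': 1, 'name': 'dog'}, 2: {'id': 2, 'name': 'cat'}}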
research/object_detection/utils/label_map_util_test.py
...
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
 """Tests for object_detection.utils.label_map_util."""

 import os
...
@@ -189,7 +188,7 @@ class LabelMapUtilTest(tf.test.TestCase):
     }]
     self.assertListEqual(expected_categories_list, categories)

-  def test_convert_label_map_to_coco_categories(self):
+  def test_convert_label_map_to_categories(self):
     label_map_proto = self._generate_label_map(num_classes=4)
     categories = label_map_util.convert_label_map_to_categories(
         label_map_proto, max_num_classes=3)
...
@@ -205,7 +204,7 @@ class LabelMapUtilTest(tf.test.TestCase):
     }]
     self.assertListEqual(expected_categories_list, categories)

-  def test_convert_label_map_to_coco_categories_with_few_classes(self):
+  def test_convert_label_map_to_categories_with_few_classes(self):
     label_map_proto = self._generate_label_map(num_classes=4)
     cat_no_offset = label_map_util.convert_label_map_to_categories(
         label_map_proto, max_num_classes=2)
...
@@ -238,6 +237,30 @@ class LabelMapUtilTest(tf.test.TestCase):
         }
     }, category_index)

+  def test_create_categories_from_labelmap(self):
+    label_map_string = """
+      item {
+        id:1
+        name:'dog'
+      }
+      item {
+        id:2
+        name:'cat'
+      }
+    """
+    label_map_path = os.path.join(self.get_temp_dir(), 'label_map.pbtxt')
+    with tf.gfile.Open(label_map_path, 'wb') as f:
+      f.write(label_map_string)
+
+    categories = label_map_util.create_categories_from_labelmap(label_map_path)
+    self.assertListEqual([{
+        'name': u'dog',
+        'id': 1
+    }, {
+        'name': u'cat',
+        'id': 2
+    }], categories)
+
   def test_create_category_index_from_labelmap(self):
     label_map_string = """
       item {
...
@@ -266,6 +289,46 @@ class LabelMapUtilTest(tf.test.TestCase):
         }
     }, category_index)

+  def test_create_category_index_from_labelmap_display(self):
+    label_map_string = """
+      item {
+        id:2
+        name:'cat'
+        display_name:'meow'
+      }
+      item {
+        id:1
+        name:'dog'
+        display_name:'woof'
+      }
+    """
+    label_map_path = os.path.join(self.get_temp_dir(), 'label_map.pbtxt')
+    with tf.gfile.Open(label_map_path, 'wb') as f:
+      f.write(label_map_string)
+
+    self.assertDictEqual({
+        1: {
+            'name': u'dog',
+            'id': 1
+        },
+        2: {
+            'name': u'cat',
+            'id': 2
+        }
+    }, label_map_util.create_category_index_from_labelmap(
+        label_map_path, False))
+
+    self.assertDictEqual({
+        1: {
+            'name': u'woof',
+            'id': 1
+        },
+        2: {
+            'name': u'meow',
+            'id': 2
+        }
+    }, label_map_util.create_category_index_from_labelmap(label_map_path))
+
 if __name__ == '__main__':
   tf.test.main()
research/object_detection/utils/ops.py
...
@@ -160,6 +160,9 @@ def pad_to_multiple(tensor, multiple):
   Returns:
     padded_tensor: the tensor zero padded to the specified multiple.
   """
+  if multiple == 1:
+    return tensor
+
   tensor_shape = tensor.get_shape()
   batch_size = static_shape.get_batch_size(tensor_shape)
   tensor_height = static_shape.get_height(tensor_shape)
...
@@ -697,8 +700,11 @@ def position_sensitive_crop_regions(image,
   image_crops = []
   for (split, box) in zip(image_splits, position_sensitive_boxes):
     if split.shape.is_fully_defined() and box.shape.is_fully_defined():
-      crop = matmul_crop_and_resize(
-          tf.expand_dims(split, 0), box, bin_crop_size)
+      crop = tf.squeeze(
+          matmul_crop_and_resize(
+              tf.expand_dims(split, axis=0), tf.expand_dims(box, axis=0),
+              bin_crop_size),
+          axis=0)
     else:
       crop = tf.image.crop_and_resize(
           tf.expand_dims(split, 0), box,
...
@@ -785,50 +791,85 @@ def reframe_box_masks_to_image_masks(box_masks, boxes, image_height,
   return tf.squeeze(image_masks, axis=3)


-def merge_boxes_with_multiple_labels(boxes, classes, num_classes):
+def merge_boxes_with_multiple_labels(boxes,
+                                     classes,
+                                     confidences,
+                                     num_classes,
+                                     quantization_bins=10000):
   """Merges boxes with same coordinates and returns K-hot encoded classes.

   Args:
-    boxes: A tf.float32 tensor with shape [N, 4] holding N boxes.
-    classes: A tf.int32 tensor with shape [N] holding class indices.
-      The class index starts at 0.
+    boxes: A tf.float32 tensor with shape [N, 4] holding N boxes. Only
+      normalized coordinates are allowed.
+    classes: A tf.int32 tensor with shape [N] holding class indices.
+      The class index starts at 0.
+    confidences: A tf.float32 tensor with shape [N] holding class confidences.
     num_classes: total number of classes to use for K-hot encoding.
+    quantization_bins: the number of bins used to quantize the box coordinate.

   Returns:
     merged_boxes: A tf.float32 tensor with shape [N', 4] holding boxes,
       where N' <= N.
     class_encodings: A tf.int32 tensor with shape [N', num_classes] holding
-      k-hot encodings for the merged boxes.
+      K-hot encodings for the merged boxes.
+    confidence_encodings: A tf.float32 tensor with shape [N', num_classes]
+      holding encodings of confidences for the merged boxes.
     merged_box_indices: A tf.int32 tensor with shape [N'] holding original
       indices of the boxes.
   """
-  def merge_numpy_boxes(boxes, classes, num_classes):
-    """Python function to merge numpy boxes."""
-    if boxes.size < 1:
-      return (np.zeros([0, 4], dtype=np.float32),
-              np.zeros([0, num_classes], dtype=np.int32),
-              np.zeros([0], dtype=np.int32))
-    box_to_class_indices = {}
-    for box_index in range(boxes.shape[0]):
-      box = tuple(boxes[box_index, :].tolist())
-      class_index = classes[box_index]
-      if box not in box_to_class_indices:
-        box_to_class_indices[box] = [box_index, np.zeros([num_classes])]
-      box_to_class_indices[box][1][class_index] = 1
-    merged_boxes = np.vstack(box_to_class_indices.keys()).astype(np.float32)
-    class_encodings = [item[1] for item in box_to_class_indices.values()]
-    class_encodings = np.vstack(class_encodings).astype(np.int32)
-    merged_box_indices = [item[0] for item in box_to_class_indices.values()]
-    merged_box_indices = np.array(merged_box_indices).astype(np.int32)
-    return merged_boxes, class_encodings, merged_box_indices
-
-  merged_boxes, class_encodings, merged_box_indices = tf.py_func(
-      merge_numpy_boxes, [boxes, classes, num_classes],
-      [tf.float32, tf.int32, tf.int32])
-  merged_boxes = tf.reshape(merged_boxes, [-1, 4])
-  class_encodings = tf.reshape(class_encodings, [-1, num_classes])
-  merged_box_indices = tf.reshape(merged_box_indices, [-1])
-  return merged_boxes, class_encodings, merged_box_indices
+  boxes_shape = tf.shape(boxes)
+  classes_shape = tf.shape(classes)
+  confidences_shape = tf.shape(confidences)
+  box_class_shape_assert = shape_utils.assert_shape_equal_along_first_dimension(
+      boxes_shape, classes_shape)
+  box_confidence_shape_assert = (
+      shape_utils.assert_shape_equal_along_first_dimension(
+          boxes_shape, confidences_shape))
+  box_dimension_assert = tf.assert_equal(boxes_shape[1], 4)
+  box_normalized_assert = shape_utils.assert_box_normalized(boxes)
+
+  with tf.control_dependencies(
+      [box_class_shape_assert, box_confidence_shape_assert,
+       box_dimension_assert, box_normalized_assert]):
+    quantized_boxes = tf.to_int64(boxes * (quantization_bins - 1))
+    ymin, xmin, ymax, xmax = tf.unstack(quantized_boxes, axis=1)
+    hashcodes = (
+        ymin +
+        xmin * quantization_bins +
+        ymax * quantization_bins * quantization_bins +
+        xmax * quantization_bins * quantization_bins * quantization_bins)
+    unique_hashcodes, unique_indices = tf.unique(hashcodes)
+    num_boxes = tf.shape(boxes)[0]
+    num_unique_boxes = tf.shape(unique_hashcodes)[0]
+    merged_box_indices = tf.unsorted_segment_min(
+        tf.range(num_boxes), unique_indices, num_unique_boxes)
+    merged_boxes = tf.gather(boxes, merged_box_indices)
+
+    def map_box_encodings(i):
+      """Produces box K-hot and score encodings for each class index."""
+      box_mask = tf.equal(
+          unique_indices, i * tf.ones(num_boxes, dtype=tf.int32))
+      box_mask = tf.reshape(box_mask, [-1])
+      box_indices = tf.boolean_mask(classes, box_mask)
+      box_confidences = tf.boolean_mask(confidences, box_mask)
+      box_class_encodings = tf.sparse_to_dense(
+          box_indices, [num_classes], 1, validate_indices=False)
+      box_confidence_encodings = tf.sparse_to_dense(
+          box_indices, [num_classes], box_confidences, validate_indices=False)
+      return box_class_encodings, box_confidence_encodings
+
+    class_encodings, confidence_encodings = tf.map_fn(
+        map_box_encodings,
+        tf.range(num_unique_boxes),
+        back_prop=False,
+        dtype=(tf.int32, tf.float32))
+
+    merged_boxes = tf.reshape(merged_boxes, [-1, 4])
+    class_encodings = tf.reshape(class_encodings, [-1, num_classes])
+    confidence_encodings = tf.reshape(confidence_encodings, [-1, num_classes])
+    merged_box_indices = tf.reshape(merged_box_indices, [-1])
+  return (merged_boxes, class_encodings, confidence_encodings,
+          merged_box_indices)


 def nearest_neighbor_upsampling(input_tensor, scale):
...
@@ -895,7 +936,8 @@ def matmul_crop_and_resize(image, boxes, crop_size, scope=None):
   Returns a tensor with crops from the input image at positions defined at
   the bounding box locations in boxes. The cropped boxes are all resized
   (with bilinear interpolation) to a fixed size = `[crop_height, crop_width]`.
-  The result is a 4-D tensor `[num_boxes, crop_height, crop_width, depth]`.
+  The result is a 5-D tensor `[batch, num_boxes, crop_height, crop_width,
+  depth]`.

   Running time complexity:
     O((# channels) * (# boxes) * (crop_size)^2 * M), where M is the number
...
@@ -914,14 +956,13 @@ def matmul_crop_and_resize(image, boxes, crop_size, scope=None):
   Args:
     image: A `Tensor`. Must be one of the following types: `uint8`, `int8`,
-      `int16`, `int32`, `int64`, `half`, `float32`, `float64`.
+      `int16`, `int32`, `int64`, `half`, 'bfloat16', `float32`, `float64`.
       A 4-D tensor of shape `[batch, image_height, image_width, depth]`.
       Both `image_height` and `image_width` need to be positive.
-    boxes: A `Tensor` of type `float32`.
-      A 2-D tensor of shape `[num_boxes, 4]`. The `i`-th row of the tensor
-      specifies the coordinates of a box in the `box_ind[i]` image and is
-      specified in normalized coordinates `[y1, x1, y2, x2]`. A normalized
-      coordinate value of `y` is mapped to the image coordinate at
+    boxes: A `Tensor` of type `float32` or 'bfloat16'.
+      A 3-D tensor of shape `[batch, num_boxes, 4]`. The boxes are specified in
+      normalized coordinates and are of the form `[y1, x1, y2, x2]`. A
+      normalized coordinate value of `y` is mapped to the image coordinate at
       `y * (image_height - 1)`, so as the `[0, 1]` interval of normalized image
       height is mapped to `[0, image_height - 1] in image height coordinates.
       We do allow y1 > y2, in which case the sampled crop is an up-down flipped
...
@@ -935,14 +976,14 @@ def matmul_crop_and_resize(image, boxes, crop_size, scope=None):
     scope: A name for the operation (optional).

   Returns:
-    A 4-D tensor of shape `[num_boxes, crop_height, crop_width, depth]`
+    A 5-D tensor of shape `[batch, num_boxes, crop_height, crop_width, depth]`

   Raises:
     ValueError: if image tensor does not have shape
-      `[1, image_height, image_width, depth]` and all dimensions statically
+      `[batch, image_height, image_width, depth]` and all dimensions statically
       defined.
-    ValueError: if boxes tensor does not have shape `[num_boxes, 4]` where
-      num_boxes > 0.
+    ValueError: if boxes tensor does not have shape `[batch, num_boxes, 4]`
+      where num_boxes > 0.
     ValueError: if crop_size is not a list of two positive integers
   """
   img_shape = image.shape.as_list()
...
@@ -953,13 +994,11 @@ def matmul_crop_and_resize(image, boxes, crop_size, scope=None):
   dimensions = img_shape + crop_size + boxes_shape
   if not all([isinstance(dim, int) for dim in dimensions]):
     raise ValueError('all input shapes must be statically defined')
-  if len(crop_size) != 2:
-    raise ValueError('`crop_size` must be a list of length 2')
-  if len(boxes_shape) != 2 or boxes_shape[1] != 4:
-    raise ValueError('`boxes` should have shape `[num_boxes, 4]`')
-  if len(img_shape) != 4 and img_shape[0] != 1:
+  if len(boxes_shape) != 3 or boxes_shape[2] != 4:
+    raise ValueError('`boxes` should have shape `[batch, num_boxes, 4]`')
+  if len(img_shape) != 4:
     raise ValueError('image should have shape '
-                     '`[1, image_height, image_width, depth]`')
+                     '`[batch, image_height, image_width, depth]`')
   num_crops = boxes_shape[0]
   if not num_crops > 0:
     raise ValueError('number of boxes must be > 0')
...
@@ -968,43 +1007,69 @@ def matmul_crop_and_resize(image, boxes, crop_size, scope=None):
   def _lin_space_weights(num, img_size):
     if num > 1:
-      alpha = (img_size - 1) / float(num - 1)
-      indices = np.reshape(np.arange(num), (1, num))
-      start_weights = alpha * (num - 1 - indices)
-      stop_weights = alpha * indices
+      start_weights = tf.linspace(img_size - 1.0, 0.0, num)
+      stop_weights = img_size - 1 - start_weights
     else:
-      start_weights = num * [.5 * (img_size - 1)]
-      stop_weights = num * [.5 * (img_size - 1)]
-    return (tf.constant(start_weights, dtype=tf.float32),
-            tf.constant(stop_weights, dtype=tf.float32))
+      start_weights = tf.constant(num * [.5 * (img_size - 1)], dtype=tf.float32)
+      stop_weights = tf.constant(num * [.5 * (img_size - 1)], dtype=tf.float32)
+    return (start_weights, stop_weights)

   with tf.name_scope(scope, 'MatMulCropAndResize'):
     y1_weights, y2_weights = _lin_space_weights(crop_size[0], img_height)
     x1_weights, x2_weights = _lin_space_weights(crop_size[1], img_width)
-    [y1, x1, y2, x2] = tf.split(value=boxes, num_or_size_splits=4, axis=1)
+    y1_weights = tf.cast(y1_weights, boxes.dtype)
+    y2_weights = tf.cast(y2_weights, boxes.dtype)
+    x1_weights = tf.cast(x1_weights, boxes.dtype)
+    x2_weights = tf.cast(x2_weights, boxes.dtype)
+    [y1, x1, y2, x2] = tf.unstack(boxes, axis=2)

     # Pixel centers of input image and grid points along height and width
     image_idx_h = tf.constant(
-        np.reshape(np.arange(img_height), (1, 1, img_height)), dtype=tf.float32)
+        np.reshape(np.arange(img_height), (1, 1, 1, img_height)),
+        dtype=boxes.dtype)
     image_idx_w = tf.constant(
-        np.reshape(np.arange(img_width), (1, 1, img_width)), dtype=tf.float32)
-    grid_pos_h = tf.expand_dims(y1 * y1_weights + y2 * y2_weights, 2)
-    grid_pos_w = tf.expand_dims(x1 * x1_weights + x2 * x2_weights, 2)
+        np.reshape(np.arange(img_width), (1, 1, 1, img_width)),
+        dtype=boxes.dtype)
+    grid_pos_h = tf.expand_dims(
+        tf.einsum('ab,c->abc', y1, y1_weights) + tf.einsum(
+            'ab,c->abc', y2, y2_weights),
+        axis=3)
+    grid_pos_w = tf.expand_dims(
+        tf.einsum('ab,c->abc', x1, x1_weights) + tf.einsum(
+            'ab,c->abc', x2, x2_weights),
+        axis=3)

     # Create kernel matrices of pairwise kernel evaluations between pixel
     # centers of image and grid points.
     kernel_h = tf.nn.relu(1 - tf.abs(image_idx_h - grid_pos_h))
     kernel_w = tf.nn.relu(1 - tf.abs(image_idx_w - grid_pos_w))

-    # TODO(jonathanhuang): investigate whether all channels can be processed
-    # without the explicit unstack --- possibly with a permute and map_fn call.
-    result_channels = []
-    for channel in tf.unstack(image, axis=3):
-      result_channels.append(
-          tf.matmul(
-              tf.matmul(kernel_h, tf.tile(channel, [num_crops, 1, 1])),
-              kernel_w, transpose_b=True))
-    return tf.stack(result_channels, axis=3)
+    # Compute matrix multiplication between the spatial dimensions of the image
+    # and height-wise kernel using einsum.
+    intermediate_image = tf.einsum('abci,aiop->abcop', kernel_h, image)
+    # Compute matrix multiplication between the spatial dimensions of the
+    # intermediate_image and width-wise kernel using einsum.
+    return tf.einsum('abno,abcop->abcnp', kernel_w, intermediate_image)
+
+
+def native_crop_and_resize(image, boxes, crop_size, scope=None):
+  """Same as `matmul_crop_and_resize` but uses tf.image.crop_and_resize."""
+  def get_box_inds(proposals):
+    proposals_shape = proposals.get_shape().as_list()
+    if any(dim is None for dim in proposals_shape):
+      proposals_shape = tf.shape(proposals)
+    ones_mat = tf.ones(proposals_shape[:2], dtype=tf.int32)
+    multiplier = tf.expand_dims(
+        tf.range(start=0, limit=proposals_shape[0]), 1)
+    return tf.reshape(ones_mat * multiplier, [-1])
+
+  with tf.name_scope(scope, 'CropAndResize'):
+    cropped_regions = tf.image.crop_and_resize(
+        image, tf.reshape(boxes, [-1] + boxes.shape.as_list()[2:]),
+        get_box_inds(boxes), crop_size)
+    final_shape = tf.concat([tf.shape(boxes)[:2],
+                             tf.shape(cropped_regions)[1:]], axis=0)
+    return tf.reshape(cropped_regions, final_shape)


 def expected_classification_loss_under_sampling(batch_cls_targets, cls_losses,
...
View file @
27b4acd4
...
@@ -1147,36 +1147,76 @@ class MergeBoxesWithMultipleLabelsTest(tf.test.TestCase):
...
@@ -1147,36 +1147,76 @@ class MergeBoxesWithMultipleLabelsTest(tf.test.TestCase):
[
0.25
,
0.25
,
0.75
,
0.75
]],
[
0.25
,
0.25
,
0.75
,
0.75
]],
dtype
=
tf
.
float32
)
dtype
=
tf
.
float32
)
class_indices
=
tf
.
constant
([
0
,
4
,
2
],
dtype
=
tf
.
int32
)
class_indices
=
tf
.
constant
([
0
,
4
,
2
],
dtype
=
tf
.
int32
)
class_confidences
=
tf
.
constant
([
0.8
,
0.2
,
0.1
],
dtype
=
tf
.
float32
)
num_classes
=
5
num_classes
=
5
merged_boxes
,
merged_classes
,
merged_box_indices
=
(
merged_boxes
,
merged_classes
,
merged_confidences
,
merged_box_indices
=
(
ops
.
merge_boxes_with_multiple_labels
(
boxes
,
class_indices
,
num_classes
))
ops
.
merge_boxes_with_multiple_labels
(
boxes
,
class_indices
,
class_confidences
,
num_classes
))
expected_merged_boxes
=
np
.
array
(
expected_merged_boxes
=
np
.
array
(
[[
0.25
,
0.25
,
0.75
,
0.75
],
[
0.0
,
0.0
,
0.5
,
0.75
]],
dtype
=
np
.
float32
)
[[
0.25
,
0.25
,
0.75
,
0.75
],
[
0.0
,
0.0
,
0.5
,
0.75
]],
dtype
=
np
.
float32
)
expected_merged_classes
=
np
.
array
(
expected_merged_classes
=
np
.
array
(
[[
1
,
0
,
1
,
0
,
0
],
[
0
,
0
,
0
,
0
,
1
]],
dtype
=
np
.
int32
)
[[
1
,
0
,
1
,
0
,
0
],
[
0
,
0
,
0
,
0
,
1
]],
dtype
=
np
.
int32
)
expected_merged_confidences
=
np
.
array
(
[[
0.8
,
0
,
0.1
,
0
,
0
],
[
0
,
0
,
0
,
0
,
0.2
]],
dtype
=
np
.
float32
)
expected_merged_box_indices
=
np
.
array
([
0
,
1
],
dtype
=
np
.
int32
)
expected_merged_box_indices
=
np
.
array
([
0
,
1
],
dtype
=
np
.
int32
)
with
self
.
test_session
()
as
sess
:
with
self
.
test_session
()
as
sess
:
np_merged_boxes
,
np_merged_classes
,
np_merged_box_indices
=
sess
.
run
(
(
np_merged_boxes
,
np_merged_classes
,
np_merged_confidences
,
[
merged_boxes
,
merged_classes
,
merged_box_indices
])
np_merged_box_indices
)
=
sess
.
run
(
if
np_merged_classes
[
0
,
0
]
!=
1
:
[
merged_boxes
,
merged_classes
,
merged_confidences
,
expected_merged_boxes
=
expected_merged_boxes
[::
-
1
,
:]
merged_box_indices
])
expected_merged_classes
=
expected_merged_classes
[::
-
1
,
:]
expected_merged_box_indices
=
expected_merged_box_indices
[::
-
1
,
:]
self
.
assertAllClose
(
np_merged_boxes
,
expected_merged_boxes
)
self
.
assertAllClose
(
np_merged_boxes
,
expected_merged_boxes
)
self
.
assertAllClose
(
np_merged_classes
,
expected_merged_classes
)
self
.
assertAllClose
(
np_merged_classes
,
expected_merged_classes
)
self
.
assertAllClose
(
np_merged_confidences
,
expected_merged_confidences
)
self
.
assertAllClose
(
np_merged_box_indices
,
expected_merged_box_indices
)
def
testMergeBoxesWithMultipleLabelsCornerCase
(
self
):
boxes
=
tf
.
constant
(
[[
0
,
0
,
1
,
1
],
[
0
,
1
,
1
,
1
],
[
1
,
0
,
1
,
1
],
[
1
,
1
,
1
,
1
],
[
1
,
1
,
1
,
1
],
[
1
,
0
,
1
,
1
],
[
0
,
1
,
1
,
1
],
[
0
,
0
,
1
,
1
]],
dtype
=
tf
.
float32
)
class_indices
=
tf
.
constant
([
0
,
1
,
2
,
3
,
2
,
1
,
0
,
3
],
dtype
=
tf
.
int32
)
class_confidences
=
tf
.
constant
([
0.1
,
0.9
,
0.2
,
0.8
,
0.3
,
0.7
,
0.4
,
0.6
],
dtype
=
tf
.
float32
)
num_classes
=
4
merged_boxes
,
merged_classes
,
merged_confidences
,
merged_box_indices
=
(
ops
.
merge_boxes_with_multiple_labels
(
boxes
,
class_indices
,
class_confidences
,
num_classes
))
expected_merged_boxes
=
np
.
array
(
[[
0
,
0
,
1
,
1
],
[
0
,
1
,
1
,
1
],
[
1
,
0
,
1
,
1
],
[
1
,
1
,
1
,
1
]],
dtype
=
np
.
float32
)
expected_merged_classes
=
np
.
array
(
[[
1
,
0
,
0
,
1
],
[
1
,
1
,
0
,
0
],
[
0
,
1
,
1
,
0
],
[
0
,
0
,
1
,
1
]],
dtype
=
np
.
int32
)
expected_merged_confidences
=
np
.
array
(
[[
0.1
,
0
,
0
,
0.6
],
[
0.4
,
0.9
,
0
,
0
],
[
0
,
0.7
,
0.2
,
0
],
[
0
,
0
,
0.3
,
0.8
]],
dtype
=
np
.
float32
)
expected_merged_box_indices
=
np
.
array
([
0
,
1
,
2
,
3
],
dtype
=
np
.
int32
)
with
self
.
test_session
()
as
sess
:
(
np_merged_boxes
,
np_merged_classes
,
np_merged_confidences
,
np_merged_box_indices
)
=
sess
.
run
(
[
merged_boxes
,
merged_classes
,
merged_confidences
,
merged_box_indices
])
self
.
assertAllClose
(
np_merged_boxes
,
expected_merged_boxes
)
self
.
assertAllClose
(
np_merged_classes
,
expected_merged_classes
)
self
.
assertAllClose
(
np_merged_confidences
,
expected_merged_confidences
)
self
.
assertAllClose
(
np_merged_box_indices
,
expected_merged_box_indices
)
self
.
assertAllClose
(
np_merged_box_indices
,
expected_merged_box_indices
)
def
testMergeBoxesWithEmptyInputs
(
self
):
def
testMergeBoxesWithEmptyInputs
(
self
):
boxes
=
tf
.
constant
([[]])
boxes
=
tf
.
zeros
([
0
,
4
],
dtype
=
tf
.
float32
)
class_indices
=
tf
.
constant
([])
class_indices
=
tf
.
constant
([],
dtype
=
tf
.
int32
)
class_confidences
=
tf
.
constant
([],
dtype
=
tf
.
float32
)
num_classes
=
5
num_classes
=
5
merged_boxes
,
merged_classes
,
merged_box_indices
=
(
merged_boxes
,
merged_classes
,
merged_confidences
,
merged_box_indices
=
(
ops
.
merge_boxes_with_multiple_labels
(
boxes
,
class_indices
,
num_classes
))
ops
.
merge_boxes_with_multiple_labels
(
boxes
,
class_indices
,
class_confidences
,
num_classes
))
with
self
.
test_session
()
as
sess
:
with
self
.
test_session
()
as
sess
:
np_merged_boxes
,
np_merged_classes
,
np_merged_box_indices
=
sess
.
run
(
(
np_merged_boxes
,
np_merged_classes
,
np_merged_confidences
,
[
merged_boxes
,
merged_classes
,
merged_box_indices
])
np_merged_box_indices
)
=
sess
.
run
(
[
merged_boxes
,
merged_classes
,
merged_confidences
,
merged_box_indices
])
self
.
assertAllEqual
(
np_merged_boxes
.
shape
,
[
0
,
4
])
self
.
assertAllEqual
(
np_merged_boxes
.
shape
,
[
0
,
4
])
self
.
assertAllEqual
(
np_merged_classes
.
shape
,
[
0
,
5
])
self
.
assertAllEqual
(
np_merged_classes
.
shape
,
[
0
,
5
])
self
.
assertAllEqual
(
np_merged_confidences
.
shape
,
[
0
,
5
])
self
.
assertAllEqual
(
np_merged_box_indices
.
shape
,
[
0
])
self
.
assertAllEqual
(
np_merged_box_indices
.
shape
,
[
0
])
...
@@ -1268,8 +1308,8 @@ class OpsTestMatMulCropAndResize(test_case.TestCase):
...
@@ -1268,8 +1308,8 @@ class OpsTestMatMulCropAndResize(test_case.TestCase):
return
ops
.
matmul_crop_and_resize
(
image
,
boxes
,
crop_size
=
[
1
,
1
])
return
ops
.
matmul_crop_and_resize
(
image
,
boxes
,
crop_size
=
[
1
,
1
])
image
=
np
.
array
([[[[
1
],
[
2
]],
[[
3
],
[
4
]]]],
dtype
=
np
.
float32
)
image
=
np
.
array
([[[[
1
],
[
2
]],
[[
3
],
[
4
]]]],
dtype
=
np
.
float32
)
boxes
=
np
.
array
([[
0
,
0
,
1
,
1
]],
dtype
=
np
.
float32
)
boxes
=
np
.
array
([[
[
0
,
0
,
1
,
1
]]
]
,
dtype
=
np
.
float32
)
expected_output
=
[[[[
2.5
]]]]
expected_output
=
[[[[
[
2.5
]]]]
]
crop_output
=
self
.
execute
(
graph_fn
,
[
image
,
boxes
])
crop_output
=
self
.
execute
(
graph_fn
,
[
image
,
boxes
])
self
.
assertAllClose
(
crop_output
,
expected_output
)
self
.
assertAllClose
(
crop_output
,
expected_output
)
...
@@ -1279,8 +1319,8 @@ class OpsTestMatMulCropAndResize(test_case.TestCase):
...
@@ -1279,8 +1319,8 @@ class OpsTestMatMulCropAndResize(test_case.TestCase):
return
ops
.
matmul_crop_and_resize
(
image
,
boxes
,
crop_size
=
[
1
,
1
])
return
ops
.
matmul_crop_and_resize
(
image
,
boxes
,
crop_size
=
[
1
,
1
])
image
=
np
.
array
([[[[
1
],
[
2
]],
[[
3
],
[
4
]]]],
dtype
=
np
.
float32
)
image
=
np
.
array
([[[[
1
],
[
2
]],
[[
3
],
[
4
]]]],
dtype
=
np
.
float32
)
boxes
=
np
.
array
([[
1
,
1
,
0
,
0
]],
dtype
=
np
.
float32
)
boxes
=
np
.
array
([[
[
1
,
1
,
0
,
0
]]
]
,
dtype
=
np
.
float32
)
expected_output
=
[[[[
2.5
]]]]
expected_output
=
[[[[
[
2.5
]]]]
]
crop_output
=
self
.
execute
(
graph_fn
,
[
image
,
boxes
])
crop_output
=
self
.
execute
(
graph_fn
,
[
image
,
boxes
])
self
.
assertAllClose
(
crop_output
,
expected_output
)
self
.
assertAllClose
(
crop_output
,
expected_output
)
...
@@ -1290,10 +1330,10 @@ class OpsTestMatMulCropAndResize(test_case.TestCase):
...
@@ -1290,10 +1330,10 @@ class OpsTestMatMulCropAndResize(test_case.TestCase):
return
ops
.
matmul_crop_and_resize
(
image
,
boxes
,
crop_size
=
[
3
,
3
])
return
ops
.
matmul_crop_and_resize
(
image
,
boxes
,
crop_size
=
[
3
,
3
])
image
=
np
.
array
([[[[
1
],
[
2
]],
[[
3
],
[
4
]]]],
dtype
=
np
.
float32
)
image
=
np
.
array
([[[[
1
],
[
2
]],
[[
3
],
[
4
]]]],
dtype
=
np
.
float32
)
boxes
=
np
.
array
([[
0
,
0
,
1
,
1
]],
dtype
=
np
.
float32
)
boxes
=
np
.
array
([[
[
0
,
0
,
1
,
1
]]
]
,
dtype
=
np
.
float32
)
expected_output
=
[[[[
1.0
],
[
1.5
],
[
2.0
]],
expected_output
=
[[[[
[
1.0
],
[
1.5
],
[
2.0
]],
[[
2.0
],
[
2.5
],
[
3.0
]],
[[
2.0
],
[
2.5
],
[
3.0
]],
[[
3.0
],
[
3.5
],
[
4.0
]]]]
[[
3.0
],
[
3.5
],
[
4.0
]]]]
]
crop_output
=
self
.
execute
(
graph_fn
,
[
image
,
boxes
])
crop_output
=
self
.
execute
(
graph_fn
,
[
image
,
boxes
])
self
.
assertAllClose
(
crop_output
,
expected_output
)
self
.
assertAllClose
(
crop_output
,
expected_output
)
...
@@ -1303,10 +1343,10 @@ class OpsTestMatMulCropAndResize(test_case.TestCase):
...
@@ -1303,10 +1343,10 @@ class OpsTestMatMulCropAndResize(test_case.TestCase):
return
ops
.
matmul_crop_and_resize
(
image
,
boxes
,
crop_size
=
[
3
,
3
])
return
ops
.
matmul_crop_and_resize
(
image
,
boxes
,
crop_size
=
[
3
,
3
])
image
=
np
.
array
([[[[
1
],
[
2
]],
[[
3
],
[
4
]]]],
dtype
=
np
.
float32
)
image
=
np
.
array
([[[[
1
],
[
2
]],
[[
3
],
[
4
]]]],
dtype
=
np
.
float32
)
boxes
=
np
.
array
([[
1
,
1
,
0
,
0
]],
dtype
=
np
.
float32
)
boxes
=
np
.
array
([[
[
1
,
1
,
0
,
0
]]
]
,
dtype
=
np
.
float32
)
expected_output
=
[[[[
4.0
],
[
3.5
],
[
3.0
]],
expected_output
=
[[[[
[
4.0
],
[
3.5
],
[
3.0
]],
[[
3.0
],
[
2.5
],
[
2.0
]],
[[
3.0
],
[
2.5
],
[
2.0
]],
[[
2.0
],
[
1.5
],
[
1.0
]]]]
[[
2.0
],
[
1.5
],
[
1.0
]]]]
]
crop_output
=
self
.
execute
(
graph_fn
,
[
image
,
boxes
])
crop_output
=
self
.
execute
(
graph_fn
,
[
image
,
boxes
])
self
.
assertAllClose
(
crop_output
,
expected_output
)
self
.
assertAllClose
(
crop_output
,
expected_output
)
...
@@ -1318,14 +1358,14 @@ class OpsTestMatMulCropAndResize(test_case.TestCase):
...
@@ -1318,14 +1358,14 @@ class OpsTestMatMulCropAndResize(test_case.TestCase):
image
=
np
.
array
([[[[
1
],
[
2
],
[
3
]],
image
=
np
.
array
([[[[
1
],
[
2
],
[
3
]],
[[
4
],
[
5
],
[
6
]],
[[
4
],
[
5
],
[
6
]],
[[
7
],
[
8
],
[
9
]]]],
dtype
=
np
.
float32
)
[[
7
],
[
8
],
[
9
]]]],
dtype
=
np
.
float32
)
boxes
=
np
.
array
([[
0
,
0
,
1
,
1
],
boxes
=
np
.
array
([[
[
0
,
0
,
1
,
1
],
[
0
,
0
,
.
5
,
.
5
]],
dtype
=
np
.
float32
)
[
0
,
0
,
.
5
,
.
5
]]
]
,
dtype
=
np
.
float32
)
expected_output
=
[[[[
1
],
[
3
]],
[[
7
],
[
9
]]],
expected_output
=
[[[[
[
1
],
[
3
]],
[[
7
],
[
9
]]],
[[[
1
],
[
2
]],
[[
4
],
[
5
]]]]
[[[
1
],
[
2
]],
[[
4
],
[
5
]]]]
]
crop_output
=
self
.
execute
(
graph_fn
,
[
image
,
boxes
])
crop_output
=
self
.
execute
(
graph_fn
,
[
image
,
boxes
])
self
.
assertAllClose
(
crop_output
,
expected_output
)
self
.
assertAllClose
(
crop_output
,
expected_output
)
def
testMatMulCropAndResize3x3To2x2
Multi
Channel
(
self
):
def
testMatMulCropAndResize3x3To2x2
_2
Channel
s
(
self
):
def
graph_fn
(
image
,
boxes
):
def
graph_fn
(
image
,
boxes
):
return
ops
.
matmul_crop_and_resize
(
image
,
boxes
,
crop_size
=
[
2
,
2
])
return
ops
.
matmul_crop_and_resize
(
image
,
boxes
,
crop_size
=
[
2
,
2
])
...
@@ -1333,10 +1373,32 @@ class OpsTestMatMulCropAndResize(test_case.TestCase):
...
@@ -1333,10 +1373,32 @@ class OpsTestMatMulCropAndResize(test_case.TestCase):
     image = np.array([[[[1, 0], [2, 1], [3, 2]],
                        [[4, 3], [5, 4], [6, 5]],
                        [[7, 6], [8, 7], [9, 8]]]], dtype=np.float32)
-    boxes = np.array([[0, 0, 1, 1],
-                      [0, 0, .5, .5]], dtype=np.float32)
-    expected_output = [[[[1, 0], [3, 2]], [[7, 6], [9, 8]]],
-                       [[[1, 0], [2, 1]], [[4, 3], [5, 4]]]]
+    boxes = np.array([[[0, 0, 1, 1],
+                       [0, 0, .5, .5]]], dtype=np.float32)
+    expected_output = [[[[[1, 0], [3, 2]], [[7, 6], [9, 8]]],
+                        [[[1, 0], [2, 1]], [[4, 3], [5, 4]]]]]
+    crop_output = self.execute(graph_fn, [image, boxes])
+    self.assertAllClose(crop_output, expected_output)
+
+  def testBatchMatMulCropAndResize3x3To2x2_2Channels(self):
+
+    def graph_fn(image, boxes):
+      return ops.matmul_crop_and_resize(image, boxes, crop_size=[2, 2])
+
+    image = np.array([[[[1, 0], [2, 1], [3, 2]],
+                       [[4, 3], [5, 4], [6, 5]],
+                       [[7, 6], [8, 7], [9, 8]]],
+                      [[[1, 0], [2, 1], [3, 2]],
+                       [[4, 3], [5, 4], [6, 5]],
+                       [[7, 6], [8, 7], [9, 8]]]], dtype=np.float32)
+    boxes = np.array([[[0, 0, 1, 1],
+                       [0, 0, .5, .5]],
+                      [[1, 1, 0, 0],
+                       [.5, .5, 0, 0]]], dtype=np.float32)
+    expected_output = [[[[[1, 0], [3, 2]], [[7, 6], [9, 8]]],
+                        [[[1, 0], [2, 1]], [[4, 3], [5, 4]]]],
+                       [[[[9, 8], [7, 6]], [[3, 2], [1, 0]]],
+                        [[[5, 4], [4, 3]], [[2, 1], [1, 0]]]]]
     crop_output = self.execute(graph_fn, [image, boxes])
     self.assertAllClose(crop_output, expected_output)
...
@@ -1348,10 +1410,10 @@ class OpsTestMatMulCropAndResize(test_case.TestCase):
     image = np.array([[[[1], [2], [3]],
                        [[4], [5], [6]],
                        [[7], [8], [9]]]], dtype=np.float32)
-    boxes = np.array([[1, 1, 0, 0],
-                      [.5, .5, 0, 0]], dtype=np.float32)
-    expected_output = [[[[9], [7]], [[3], [1]]],
-                       [[[5], [4]], [[2], [1]]]]
+    boxes = np.array([[[1, 1, 0, 0],
+                       [.5, .5, 0, 0]]], dtype=np.float32)
+    expected_output = [[[[[9], [7]], [[3], [1]]],
+                        [[[5], [4]], [[2], [1]]]]]
     crop_output = self.execute(graph_fn, [image, boxes])
     self.assertAllClose(crop_output, expected_output)
...
@@ -1363,6 +1425,31 @@ class OpsTestMatMulCropAndResize(test_case.TestCase):
       _ = ops.matmul_crop_and_resize(image, boxes, crop_size)


+class OpsTestCropAndResize(test_case.TestCase):
+
+  def testBatchCropAndResize3x3To2x2_2Channels(self):
+
+    def graph_fn(image, boxes):
+      return ops.native_crop_and_resize(image, boxes, crop_size=[2, 2])
+
+    image = np.array([[[[1, 0], [2, 1], [3, 2]],
+                       [[4, 3], [5, 4], [6, 5]],
+                       [[7, 6], [8, 7], [9, 8]]],
+                      [[[1, 0], [2, 1], [3, 2]],
+                       [[4, 3], [5, 4], [6, 5]],
+                       [[7, 6], [8, 7], [9, 8]]]], dtype=np.float32)
+    boxes = np.array([[[0, 0, 1, 1],
+                       [0, 0, .5, .5]],
+                      [[1, 1, 0, 0],
+                       [.5, .5, 0, 0]]], dtype=np.float32)
+    expected_output = [[[[[1, 0], [3, 2]], [[7, 6], [9, 8]]],
+                        [[[1, 0], [2, 1]], [[4, 3], [5, 4]]]],
+                       [[[[9, 8], [7, 6]], [[3, 2], [1, 0]]],
+                        [[[5, 4], [4, 3]], [[2, 1], [1, 0]]]]]
+    crop_output = self.execute_cpu(graph_fn, [image, boxes])
+    self.assertAllClose(crop_output, expected_output)
+
+
 class OpsTestExpectedClassificationLoss(test_case.TestCase):

   def testExpectedClassificationLossUnderSamplingWithHardLabels(self):
...
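The updated tests above reflect a signature change in ops.matmul_crop_and_resize: boxes now carry a leading batch dimension ([batch, num_boxes, 4]) and the returned crops gain a per-image box axis. As a rough illustration (an editor's sketch, not part of this commit, assuming the batched shapes exercised by the tests):

# Sketch only: image [batch, H, W, C], boxes [batch, num_boxes, 4]
# -> crops [batch, num_boxes, crop_h, crop_w, C].
import numpy as np
import tensorflow as tf
from object_detection.utils import ops

image = np.arange(18, dtype=np.float32).reshape([2, 3, 3, 1])
boxes = np.array([[[0., 0., 1., 1.]],
                  [[0., 0., .5, .5]]], dtype=np.float32)
crops = ops.matmul_crop_and_resize(
    tf.constant(image), tf.constant(boxes), crop_size=[2, 2])
with tf.Session() as sess:
  print(sess.run(crops).shape)  # expected: (2, 1, 2, 2, 1)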
research/object_detection/utils/shape_utils.py
View file @ 27b4acd4
...
@@ -342,3 +342,26 @@ def assert_shape_equal_along_first_dimension(shape_a, shape_b):
       else:
         return tf.no_op()
   else:
     return tf.assert_equal(shape_a[0], shape_b[0])
+
+
+def assert_box_normalized(boxes, maximum_normalized_coordinate=1.1):
+  """Asserts the input box tensor is normalized.
+
+  Args:
+    boxes: a tensor of shape [N, 4] where N is the number of boxes.
+    maximum_normalized_coordinate: Maximum coordinate value to be considered
+      as normalized, default to 1.1.
+
+  Returns:
+    a tf.Assert op which fails when the input box tensor is not normalized.
+
+  Raises:
+    ValueError: When the input box tensor is not normalized.
+  """
+  box_minimum = tf.reduce_min(boxes)
+  box_maximum = tf.reduce_max(boxes)
+  return tf.Assert(
+      tf.logical_and(
+          tf.less_equal(box_maximum, maximum_normalized_coordinate),
+          tf.greater_equal(box_minimum, 0)),
+      [boxes])
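Since assert_box_normalized returns a tf.Assert op, the natural way to use it (a sketch, not something this commit adds) is as a control dependency guarding downstream box math:

# Sketch only: the assertion fires at run time if any coordinate falls
# outside [0, maximum_normalized_coordinate].
import tensorflow as tf
from object_detection.utils import shape_utils

boxes = tf.placeholder(tf.float32, shape=[None, 4])
assert_op = shape_utils.assert_box_normalized(boxes)
with tf.control_dependencies([assert_op]):
  # Box areas for [ymin, xmin, ymax, xmax] boxes.
  areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])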
research/object_detection/utils/test_case.py
View file @ 27b4acd4
...
@@ -14,6 +14,7 @@
 # ==============================================================================
 """A convenience wrapper around tf.test.TestCase to enable TPU tests."""
+import os
 import tensorflow as tf
 from tensorflow.contrib import tpu
...
@@ -23,6 +24,8 @@ flags.DEFINE_bool('tpu_test', False, 'Whether to configure test for TPU.')
 FLAGS = flags.FLAGS


 class TestCase(tf.test.TestCase):
   """Extends tf.test.TestCase to optionally allow running tests on TPU."""
...
research/object_detection/utils/test_utils.py
View file @ 27b4acd4
...
@@ -24,6 +24,9 @@ from object_detection.core import box_predictor
 from object_detection.core import matcher
 from object_detection.utils import shape_utils

+# Default size (both width and height) used for testing mask predictions.
+DEFAULT_MASK_SIZE = 5
+

 class MockBoxCoder(box_coder.BoxCoder):
   """Simple `difference` BoxCoder."""
...
@@ -42,8 +45,9 @@ class MockBoxCoder(box_coder.BoxCoder):
 class MockBoxPredictor(box_predictor.BoxPredictor):
   """Simple box predictor that ignores inputs and outputs all zeros."""

-  def __init__(self, is_training, num_classes):
+  def __init__(self, is_training, num_classes, predict_mask=False):
     super(MockBoxPredictor, self).__init__(is_training, num_classes)
+    self._predict_mask = predict_mask

   def _predict(self, image_features, num_predictions_per_location):
     image_feature = image_features[0]
...
@@ -57,17 +61,29 @@ class MockBoxPredictor(box_predictor.BoxPredictor):
         (batch_size, num_anchors, 1, code_size), dtype=tf.float32)
     class_predictions_with_background = zero + tf.zeros(
         (batch_size, num_anchors, self.num_classes + 1), dtype=tf.float32)
-    return {box_predictor.BOX_ENCODINGS: box_encodings,
-            box_predictor.CLASS_PREDICTIONS_WITH_BACKGROUND:
-            class_predictions_with_background}
+    masks = zero + tf.zeros(
+        (batch_size, num_anchors, self.num_classes, DEFAULT_MASK_SIZE,
+         DEFAULT_MASK_SIZE), dtype=tf.float32)
+    predictions_dict = {
+        box_predictor.BOX_ENCODINGS: box_encodings,
+        box_predictor.CLASS_PREDICTIONS_WITH_BACKGROUND:
+            class_predictions_with_background
+    }
+    if self._predict_mask:
+      predictions_dict[box_predictor.MASK_PREDICTIONS] = masks
+    return predictions_dict


 class MockKerasBoxPredictor(box_predictor.KerasBoxPredictor):
   """Simple box predictor that ignores inputs and outputs all zeros."""

-  def __init__(self, is_training, num_classes):
+  def __init__(self, is_training, num_classes, predict_mask=False):
     super(MockKerasBoxPredictor, self).__init__(
         is_training, num_classes, False, False)
+    self._predict_mask = predict_mask

   def _predict(self, image_features, **kwargs):
     image_feature = image_features[0]
...
@@ -81,9 +97,19 @@ class MockKerasBoxPredictor(box_predictor.KerasBoxPredictor):
         (batch_size, num_anchors, 1, code_size), dtype=tf.float32)
     class_predictions_with_background = zero + tf.zeros(
         (batch_size, num_anchors, self.num_classes + 1), dtype=tf.float32)
-    return {box_predictor.BOX_ENCODINGS: box_encodings,
-            box_predictor.CLASS_PREDICTIONS_WITH_BACKGROUND:
-            class_predictions_with_background}
+    masks = zero + tf.zeros(
+        (batch_size, num_anchors, self.num_classes, DEFAULT_MASK_SIZE,
+         DEFAULT_MASK_SIZE), dtype=tf.float32)
+    predictions_dict = {
+        box_predictor.BOX_ENCODINGS: box_encodings,
+        box_predictor.CLASS_PREDICTIONS_WITH_BACKGROUND:
+            class_predictions_with_background
+    }
+    if self._predict_mask:
+      predictions_dict[box_predictor.MASK_PREDICTIONS] = masks
+    return predictions_dict


 class MockAnchorGenerator(anchor_generator.AnchorGenerator):
...
@@ -103,7 +129,7 @@ class MockAnchorGenerator(anchor_generator.AnchorGenerator):
 class MockMatcher(matcher.Matcher):
   """Simple matcher that matches first anchor to first groundtruth box."""

-  def _match(self, similarity_matrix):
+  def _match(self, similarity_matrix, valid_rows):
     return tf.constant([0, -1, -1, -1], dtype=tf.int32)
...
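The new predict_mask flag is what lets Mask R-CNN style tests receive MASK_PREDICTIONS from the mocks without a real mask head. A hedged sketch of the behaviour (an editor's illustration, not from the commit; it calls the mock's _predict directly so the sketch does not depend on the BoxPredictor.predict wrapper signature):

# Sketch only: with predict_mask=True the mock also returns all-zero
# MASK_PREDICTIONS of shape
# [batch, num_anchors, num_classes, DEFAULT_MASK_SIZE, DEFAULT_MASK_SIZE].
import tensorflow as tf
from object_detection.core import box_predictor
from object_detection.utils import test_utils

mock = test_utils.MockBoxPredictor(
    is_training=False, num_classes=2, predict_mask=True)
features = [tf.zeros([1, 4, 4, 8], dtype=tf.float32)]
predictions = mock._predict(features, num_predictions_per_location=1)
assert box_predictor.MASK_PREDICTIONS in predictions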
research/object_detection/utils/visualization_utils.py
View file @ 27b4acd4
...
@@ -19,6 +19,8 @@ These functions often receive an image, perform some visualization on the image.
 The functions do not return a value, instead they modify the image itself.

 """
+from abc import ABCMeta
+from abc import abstractmethod
 import collections
 import functools
 # Set headless-friendly backend.
...
@@ -731,3 +733,158 @@ def add_hist_image_summary(values, bins, name):
     return image

   hist_plot = tf.py_func(hist_plot, [values, bins], tf.uint8)
   tf.summary.image(name, hist_plot)
+
+
+class EvalMetricOpsVisualization(object):
+  """Abstract base class responsible for visualizations during evaluation.
+
+  Currently, summary images are not run during evaluation. One way to produce
+  evaluation images in Tensorboard is to provide tf.summary.image strings as
+  `value_ops` in tf.estimator.EstimatorSpec's `eval_metric_ops`. This class is
+  responsible for accruing images (with overlaid detections and groundtruth)
+  and returning a dictionary that can be passed to `eval_metric_ops`.
+  """
+  __metaclass__ = ABCMeta
+
+  def __init__(self,
+               category_index,
+               max_examples_to_draw=5,
+               max_boxes_to_draw=20,
+               min_score_thresh=0.2,
+               use_normalized_coordinates=True,
+               summary_name_prefix='evaluation_image'):
+    """Creates an EvalMetricOpsVisualization.
+
+    Args:
+      category_index: A category index (dictionary) produced from a labelmap.
+      max_examples_to_draw: The maximum number of example summaries to produce.
+      max_boxes_to_draw: The maximum number of boxes to draw for detections.
+      min_score_thresh: The minimum score threshold for showing detections.
+      use_normalized_coordinates: Whether to assume boxes and keypoints are in
+        normalized coordinates (as opposed to absolute coordinates).
+        Default is True.
+      summary_name_prefix: A string prefix for each image summary.
+    """
+    self._category_index = category_index
+    self._max_examples_to_draw = max_examples_to_draw
+    self._max_boxes_to_draw = max_boxes_to_draw
+    self._min_score_thresh = min_score_thresh
+    self._use_normalized_coordinates = use_normalized_coordinates
+    self._summary_name_prefix = summary_name_prefix
+    self._images = []
+
+  def clear(self):
+    self._images = []
+
+  def add_images(self, images):
+    """Store a list of images, each with shape [1, H, W, C]."""
+    if len(self._images) >= self._max_examples_to_draw:
+      return
+
+    # Store images and clip list if necessary.
+    self._images.extend(images)
+    if len(self._images) > self._max_examples_to_draw:
+      self._images[self._max_examples_to_draw:] = []
+
+  def get_estimator_eval_metric_ops(self, eval_dict):
+    """Returns metric ops for use in tf.estimator.EstimatorSpec.
+
+    Args:
+      eval_dict: A dictionary that holds an image, groundtruth, and detections
+        for a single example. See eval_util.result_dict_for_single_example()
+        for a convenient method for constructing such a dictionary. The
+        dictionary contains
+        fields.InputDataFields.original_image: [1, H, W, 3] image.
+        fields.InputDataFields.groundtruth_boxes - [num_boxes, 4] float32
+          tensor with groundtruth boxes in range [0.0, 1.0].
+        fields.InputDataFields.groundtruth_classes - [num_boxes] int64
+          tensor with 1-indexed groundtruth classes.
+        fields.InputDataFields.groundtruth_instance_masks - (optional)
+          [num_boxes, H, W] int64 tensor with instance masks.
+        fields.DetectionResultFields.detection_boxes - [max_num_boxes, 4]
+          float32 tensor with detection boxes in range [0.0, 1.0].
+        fields.DetectionResultFields.detection_classes - [max_num_boxes]
+          int64 tensor with 1-indexed detection classes.
+        fields.DetectionResultFields.detection_scores - [max_num_boxes]
+          float32 tensor with detection scores.
+        fields.DetectionResultFields.detection_masks - (optional)
+          [max_num_boxes, H, W] float32 tensor of binarized masks.
+        fields.DetectionResultFields.detection_keypoints - (optional)
+          [max_num_boxes, num_keypoints, 2] float32 tensor with keypoints.
+
+    Returns:
+      A dictionary of image summary names to tuple of (value_op, update_op).
+      The `update_op` is the same for all items in the dictionary, and is
+      responsible for saving a single side-by-side image with detections and
+      groundtruth. Each `value_op` holds the tf.summary.image string for a
+      given image.
+    """
+    images = self.images_from_evaluation_dict(eval_dict)
+
+    def get_images():
+      """Returns a list of images, padded to self._max_images_to_draw."""
+      images = self._images
+      while len(images) < self._max_examples_to_draw:
+        images.append(np.array(0, dtype=np.uint8))
+      self.clear()
+      return images
+
+    def image_summary_or_default_string(summary_name, image):
+      """Returns image summaries for non-padded elements."""
+      return tf.cond(
+          tf.equal(tf.size(tf.shape(image)), 4),
+          lambda: tf.summary.image(summary_name, image),
+          lambda: tf.constant(''))
+
+    update_op = tf.py_func(self.add_images, [images], [])
+    image_tensors = tf.py_func(
+        get_images, [], [tf.uint8] * self._max_examples_to_draw)
+    eval_metric_ops = {}
+    for i, image in enumerate(image_tensors):
+      summary_name = self._summary_name_prefix + '/' + str(i)
+      value_op = image_summary_or_default_string(summary_name, image)
+      eval_metric_ops[summary_name] = (value_op, update_op)
+    return eval_metric_ops
+
+  @abstractmethod
+  def images_from_evaluation_dict(self, eval_dict):
+    """Converts evaluation dictionary into a list of image tensors.
+
+    To be overridden by implementations.
+
+    Args:
+      eval_dict: A dictionary with all the necessary information for producing
+        visualizations.
+
+    Returns:
+      A list of [1, H, W, C] uint8 tensors.
+    """
+    raise NotImplementedError
+
+
+class VisualizeSingleFrameDetections(EvalMetricOpsVisualization):
+  """Class responsible for single-frame object detection visualizations."""
+
+  def __init__(self,
+               category_index,
+               max_examples_to_draw=5,
+               max_boxes_to_draw=20,
+               min_score_thresh=0.2,
+               use_normalized_coordinates=True,
+               summary_name_prefix='Detections_Left_Groundtruth_Right'):
+    super(VisualizeSingleFrameDetections, self).__init__(
+        category_index=category_index,
+        max_examples_to_draw=max_examples_to_draw,
+        max_boxes_to_draw=max_boxes_to_draw,
+        min_score_thresh=min_score_thresh,
+        use_normalized_coordinates=use_normalized_coordinates,
+        summary_name_prefix=summary_name_prefix)
+
+  def images_from_evaluation_dict(self, eval_dict):
+    return [draw_side_by_side_evaluation_image(
+        eval_dict, self._category_index, self._max_boxes_to_draw,
+        self._min_score_thresh, self._use_normalized_coordinates)]
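To add a new kind of evaluation visualization, only images_from_evaluation_dict needs to be overridden; the base class handles accruing images via the shared update_op and emitting the per-slot value_ops. A minimal illustrative subclass (an editor's sketch, not part of the commit):

# Sketch only: a toy subclass that forwards the unannotated input image;
# a real implementation would overlay detections and groundtruth first,
# as VisualizeSingleFrameDetections does above.
from object_detection.core import standard_fields as fields
from object_detection.utils import visualization_utils as vis_utils


class VisualizeRawImages(vis_utils.EvalMetricOpsVisualization):
  """Emits the [1, H, W, 3] uint8 original image for each evaluated example."""

  def images_from_evaluation_dict(self, eval_dict):
    return [eval_dict[fields.InputDataFields.original_image]]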
research/object_detection/utils/visualization_utils_test.py
View file @ 27b4acd4
...
@@ -21,6 +21,7 @@ import numpy as np
 import PIL.Image as Image
 import tensorflow as tf

+from object_detection.core import standard_fields as fields
 from object_detection.utils import visualization_utils

 _TESTDATA_PATH = 'object_detection/test_images'
...
@@ -225,6 +226,80 @@ class VisualizationUtilsTest(tf.test.TestCase):
     with self.test_session():
       hist_image_summary.eval()

+  def test_eval_metric_ops(self):
+    category_index = {1: {'id': 1, 'name': 'dog'}, 2: {'id': 2, 'name': 'cat'}}
+    max_examples_to_draw = 4
+    metric_op_base = 'Detections_Left_Groundtruth_Right'
+    eval_metric_ops = visualization_utils.VisualizeSingleFrameDetections(
+        category_index,
+        max_examples_to_draw=max_examples_to_draw,
+        summary_name_prefix=metric_op_base)
+    original_image = tf.placeholder(tf.uint8, [1, None, None, 3])
+    detection_boxes = tf.random_uniform([20, 4],
+                                        minval=0.0,
+                                        maxval=1.0,
+                                        dtype=tf.float32)
+    detection_classes = tf.random_uniform([20],
+                                          minval=1,
+                                          maxval=3,
+                                          dtype=tf.int64)
+    detection_scores = tf.random_uniform([20],
+                                         minval=0.,
+                                         maxval=1.,
+                                         dtype=tf.float32)
+    groundtruth_boxes = tf.random_uniform([8, 4],
+                                          minval=0.0,
+                                          maxval=1.0,
+                                          dtype=tf.float32)
+    groundtruth_classes = tf.random_uniform([8],
+                                            minval=1,
+                                            maxval=3,
+                                            dtype=tf.int64)
+    eval_dict = {
+        fields.DetectionResultFields.detection_boxes: detection_boxes,
+        fields.DetectionResultFields.detection_classes: detection_classes,
+        fields.DetectionResultFields.detection_scores: detection_scores,
+        fields.InputDataFields.original_image: original_image,
+        fields.InputDataFields.groundtruth_boxes: groundtruth_boxes,
+        fields.InputDataFields.groundtruth_classes: groundtruth_classes
+    }
+    metric_ops = eval_metric_ops.get_estimator_eval_metric_ops(eval_dict)
+    _, update_op = metric_ops[metric_ops.keys()[0]]
+
+    with self.test_session() as sess:
+      sess.run(tf.global_variables_initializer())
+      value_ops = {}
+      for key, (value_op, _) in metric_ops.iteritems():
+        value_ops[key] = value_op
+
+      # First run enough update steps to surpass `max_examples_to_draw`.
+      for i in range(max_examples_to_draw):
+        # Use a unique image shape on each eval image.
+        sess.run(update_op, feed_dict={
+            original_image: np.random.randint(low=0, high=256,
+                                              size=(1, 6 + i, 7 + i, 3),
+                                              dtype=np.uint8)
+        })
+      value_ops_out = sess.run(value_ops)
+      for key, value_op in value_ops_out.iteritems():
+        self.assertNotEqual('', value_op)
+
+      # Now run fewer update steps than `max_examples_to_draw`. A single value
+      # op will be the empty string, since not enough image summaries can be
+      # produced.
+      for i in range(max_examples_to_draw - 1):
+        # Use a unique image shape on each eval image.
+        sess.run(update_op, feed_dict={
+            original_image: np.random.randint(low=0, high=256,
+                                              size=(1, 6 + i, 7 + i, 3),
+                                              dtype=np.uint8)
+        })
+      value_ops_out = sess.run(value_ops)
+      self.assertEqual(
+          '',
+          value_ops_out[metric_op_base + '/' + str(max_examples_to_draw - 1)])
+

 if __name__ == '__main__':
   tf.test.main()
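The test above drives the metric ops directly in a session; inside a model_fn the same (value_op, update_op) pairs would simply be merged into the evaluation metrics. A hedged sketch of that wiring (an editor's illustration, not part of the commit; eval_metric_ops, eval_dict, and category_index stand in for the tensors a real model_fn builds):

# Sketch only: merge the visualization summaries into an existing
# eval_metric_ops dictionary before returning a tf.estimator.EstimatorSpec.
from object_detection.utils import visualization_utils as vis_utils


def add_eval_image_summaries(eval_metric_ops, eval_dict, category_index,
                             max_examples_to_draw=4):
  """Adds (value_op, update_op) image-summary pairs to eval_metric_ops."""
  visualizer = vis_utils.VisualizeSingleFrameDetections(
      category_index, max_examples_to_draw=max_examples_to_draw)
  eval_metric_ops.update(visualizer.get_estimator_eval_metric_ops(eval_dict))
  return eval_metric_ops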