Commit 27b4acd4 authored by Aman Gupta's avatar Aman Gupta
Browse files

Merge remote-tracking branch 'upstream/master'

parents 5133522f d4e1f97f
...@@ -4,17 +4,18 @@ package object_detection.protos; ...@@ -4,17 +4,18 @@ package object_detection.protos;
// Message for configuring DetectionModel evaluation jobs (eval.py). // Message for configuring DetectionModel evaluation jobs (eval.py).
message EvalConfig { message EvalConfig {
optional uint32 batch_size = 25 [default=1];
// Number of visualization images to generate. // Number of visualization images to generate.
optional uint32 num_visualizations = 1 [default=10]; optional uint32 num_visualizations = 1 [default=10];
// Number of examples to process of evaluation. // Number of examples to process of evaluation.
optional uint32 num_examples = 2 [default=5000]; optional uint32 num_examples = 2 [default=5000, deprecated=true];
// How often to run evaluation. // How often to run evaluation.
optional uint32 eval_interval_secs = 3 [default=300]; optional uint32 eval_interval_secs = 3 [default=300];
// Maximum number of times to run evaluation. If set to 0, will run forever. // Maximum number of times to run evaluation. If set to 0, will run forever.
optional uint32 max_evals = 4 [default=0]; optional uint32 max_evals = 4 [default=0, deprecated=true];
// Whether the TensorFlow graph used for evaluation should be saved to disk. // Whether the TensorFlow graph used for evaluation should be saved to disk.
optional bool save_graph = 5 [default=false]; optional bool save_graph = 5 [default=false];
......
...@@ -157,6 +157,13 @@ message FasterRcnn { ...@@ -157,6 +157,13 @@ message FasterRcnn {
// Whether to use the balanced positive negative sampler implementation with // Whether to use the balanced positive negative sampler implementation with
// static shape guarantees. // static shape guarantees.
optional bool use_static_balanced_label_sampler = 34 [default = false]; optional bool use_static_balanced_label_sampler = 34 [default = false];
// If True, uses implementation of ops with static shape guarantees.
optional bool use_static_shapes = 35 [default = false];
// Whether the masks present in groundtruth should be resized in the model to
// match the image size.
optional bool resize_masks = 36 [default = true];
} }
......
...@@ -22,7 +22,12 @@ enum InstanceMaskType { ...@@ -22,7 +22,12 @@ enum InstanceMaskType {
PNG_MASKS = 2; // Encoded PNG masks. PNG_MASKS = 2; // Encoded PNG masks.
} }
// Next id: 24
message InputReader { message InputReader {
// Name of input reader. Typically used to describe the dataset that is read
// by this input reader.
optional string name = 23 [default=""];
// Path to StringIntLabelMap pbtxt file specifying the mapping from string // Path to StringIntLabelMap pbtxt file specifying the mapping from string
// labels to integer ids. // labels to integer ids.
optional string label_map_path = 1 [default=""]; optional string label_map_path = 1 [default=""];
...@@ -41,6 +46,12 @@ message InputReader { ...@@ -41,6 +46,12 @@ message InputReader {
// will be reused indefinitely. // will be reused indefinitely.
optional uint32 num_epochs = 5 [default=0]; optional uint32 num_epochs = 5 [default=0];
// Integer representing how often an example should be sampled. To feed
// only 1/3 of your data into your model, set `sample_1_of_n_examples` to 3.
// This is particularly useful for evaluation, where you might not prefer to
// evaluate all of your samples.
optional uint32 sample_1_of_n_examples = 22 [default=1];
// Number of file shards to read in parallel. // Number of file shards to read in parallel.
optional uint32 num_readers = 6 [default=64]; optional uint32 num_readers = 6 [default=64];
...@@ -62,7 +73,6 @@ message InputReader { ...@@ -62,7 +73,6 @@ message InputReader {
// to generate a good random shuffle. // to generate a good random shuffle.
optional uint32 min_after_dequeue = 4 [default=1000, deprecated=true]; optional uint32 min_after_dequeue = 4 [default=1000, deprecated=true];
// Number of records to read from each reader at once. // Number of records to read from each reader at once.
optional uint32 read_block_length = 15 [default=32]; optional uint32 read_block_length = 15 [default=32];
......
...@@ -10,12 +10,13 @@ import "object_detection/protos/train.proto"; ...@@ -10,12 +10,13 @@ import "object_detection/protos/train.proto";
// Convenience message for configuring a training and eval pipeline. Allows all // Convenience message for configuring a training and eval pipeline. Allows all
// of the pipeline parameters to be configured from one file. // of the pipeline parameters to be configured from one file.
// Next id: 7
message TrainEvalPipelineConfig { message TrainEvalPipelineConfig {
optional DetectionModel model = 1; optional DetectionModel model = 1;
optional TrainConfig train_config = 2; optional TrainConfig train_config = 2;
optional InputReader train_input_reader = 3; optional InputReader train_input_reader = 3;
optional EvalConfig eval_config = 4; optional EvalConfig eval_config = 4;
optional InputReader eval_input_reader = 5; repeated InputReader eval_input_reader = 5;
optional GraphRewriter graph_rewriter = 6; optional GraphRewriter graph_rewriter = 6;
extensions 1000 to max; extensions 1000 to max;
} }
...@@ -17,6 +17,9 @@ message BatchNonMaxSuppression { ...@@ -17,6 +17,9 @@ message BatchNonMaxSuppression {
// Maximum number of detections to retain across all classes. // Maximum number of detections to retain across all classes.
optional int32 max_total_detections = 5 [default = 100]; optional int32 max_total_detections = 5 [default = 100];
// Whether to use the implementation of NMS that guarantees static shapes.
optional bool use_static_shapes = 6 [default = false];
} }
// Configuration proto for post-processing predicted boxes and // Configuration proto for post-processing predicted boxes and
......
...@@ -163,5 +163,8 @@ message FeaturePyramidNetworks { ...@@ -163,5 +163,8 @@ message FeaturePyramidNetworks {
// maximum level in feature pyramid // maximum level in feature pyramid
optional int32 max_level = 2 [default = 7]; optional int32 max_level = 2 [default = 7];
// channel depth for additional coarse feature layers.
optional int32 additional_layer_depth = 3 [default = 256];
} }
...@@ -6,7 +6,7 @@ import "object_detection/protos/optimizer.proto"; ...@@ -6,7 +6,7 @@ import "object_detection/protos/optimizer.proto";
import "object_detection/protos/preprocessor.proto"; import "object_detection/protos/preprocessor.proto";
// Message for configuring DetectionModel training jobs (train.py). // Message for configuring DetectionModel training jobs (train.py).
// Next id: 26 // Next id: 27
message TrainConfig { message TrainConfig {
// Effective batch size to use for training. // Effective batch size to use for training.
// For TPU (or sync SGD jobs), the batch size per core (or GPU) is going to be // For TPU (or sync SGD jobs), the batch size per core (or GPU) is going to be
...@@ -112,4 +112,7 @@ message TrainConfig { ...@@ -112,4 +112,7 @@ message TrainConfig {
// dictionary, so that they can be displayed in Tensorboard. Note that this // dictionary, so that they can be displayed in Tensorboard. Note that this
// will lead to a larger memory footprint. // will lead to a larger memory footprint.
optional bool retain_original_images = 23 [default=false]; optional bool retain_original_images = 23 [default=false];
// Whether to use bfloat16 for training.
optional bool use_bfloat16 = 26 [default=false];
} }
# Faster R-CNN with Resnet-101 (v1)
# Users should configure the fine_tune_checkpoint field in the train config as
# well as the label_map_path and input_path fields in the train_input_reader and
# eval_input_reader. Search for "PATH_TO_BE_CONFIGURED" to find the fields that
# should be configured.
# Two-stage Faster R-CNN detector configured for 2854 classes.
model {
faster_rcnn {
num_classes: 2854
# Keep-aspect-ratio resize bounded by min_dimension / max_dimension.
image_resizer {
keep_aspect_ratio_resizer {
min_dimension: 600
max_dimension: 1024
}
}
# Backbone feature extractor.
feature_extractor {
type: 'faster_rcnn_resnet101'
first_stage_features_stride: 16
}
# Anchor generation for the first (region proposal) stage.
first_stage_anchor_generator {
grid_anchor_generator {
scales: [0.25, 0.5, 1.0, 2.0]
aspect_ratios: [0.5, 1.0, 2.0]
height_stride: 16
width_stride: 16
}
}
# Hyperparameters for the first-stage box predictor convolution.
first_stage_box_predictor_conv_hyperparams {
op: CONV
regularizer {
l2_regularizer {
weight: 0.0
}
}
initializer {
truncated_normal_initializer {
stddev: 0.01
}
}
}
# First-stage (RPN) NMS, proposal count, and loss weights.
# NOTE(review): first_stage_max_proposals is 32, far below the common 300 —
# presumably tuned for speed; confirm this is intentional.
first_stage_nms_score_threshold: 0.0
first_stage_nms_iou_threshold: 0.7
first_stage_max_proposals: 32
first_stage_localization_loss_weight: 2.0
first_stage_objectness_loss_weight: 1.0
initial_crop_size: 14
maxpool_kernel_size: 2
maxpool_stride: 2
second_stage_batch_size: 32
# Second-stage box classifier head.
second_stage_box_predictor {
mask_rcnn_box_predictor {
use_dropout: false
dropout_keep_probability: 1.0
fc_hyperparams {
op: FC
regularizer {
l2_regularizer {
weight: 0.0
}
}
initializer {
variance_scaling_initializer {
factor: 1.0
uniform: true
mode: FAN_AVG
}
}
}
}
}
# Final-detection post-processing: per-class NMS then softmax scores.
second_stage_post_processing {
batch_non_max_suppression {
score_threshold: 0.0
iou_threshold: 0.6
max_detections_per_class: 5
max_total_detections: 5
}
score_converter: SOFTMAX
}
second_stage_localization_loss_weight: 2.0
second_stage_classification_loss_weight: 1.0
}
}
# Training: momentum SGD with a manually stepped learning-rate schedule.
train_config: {
batch_size: 1
num_steps: 4000000
optimizer {
momentum_optimizer: {
learning_rate: {
manual_step_learning_rate {
initial_learning_rate: 0.0003
schedule {
step: 3000000
learning_rate: .00003
}
schedule {
step: 3500000
learning_rate: .000003
}
}
}
momentum_optimizer_value: 0.9
}
use_moving_average: false
}
gradient_clipping_by_norm: 10.0
# Must be replaced with a real checkpoint path before training.
fine_tune_checkpoint: "PATH_TO_BE_CONFIGURED/model.ckpt"
data_augmentation_options {
random_horizontal_flip {
}
}
}
# Training data: label map plus TFRecord input path (both placeholders).
train_input_reader: {
label_map_path: "PATH_TO_BE_CONFIGURED/fgvc_2854_classes_label_map.pbtxt"
tf_record_input_reader {
input_path: "PATH_TO_BE_CONFIGURED/animal_2854_train.record"
}
}
# Evaluation: PASCAL VOC metrics over 48736 examples.
eval_config: {
metrics_set: "pascal_voc_detection_metrics"
use_moving_averages: false
num_examples: 48736
}
# Evaluation data: single reader, no shuffling, for deterministic evaluation.
eval_input_reader: {
label_map_path: "PATH_TO_BE_CONFIGURED/fgvc_2854_classes_label_map.pbtxt"
shuffle: false
num_readers: 1
tf_record_input_reader {
input_path: "PATH_TO_BE_CONFIGURED/animal_2854_val.record"
}
}
# Faster R-CNN with Resnet-50 (v1)
# Users should configure the fine_tune_checkpoint field in the train config as
# well as the label_map_path and input_path fields in the train_input_reader and
# eval_input_reader. Search for "PATH_TO_BE_CONFIGURED" to find the fields that
# should be configured.
# Same pipeline shape as the Resnet-101 variant but with a Resnet-50 backbone.
model {
faster_rcnn {
num_classes: 2854
# Keep-aspect-ratio resize bounded by min_dimension / max_dimension.
image_resizer {
keep_aspect_ratio_resizer {
min_dimension: 600
max_dimension: 1024
}
}
# Backbone feature extractor.
feature_extractor {
type: 'faster_rcnn_resnet50'
first_stage_features_stride: 16
}
# Anchor generation for the first (region proposal) stage.
first_stage_anchor_generator {
grid_anchor_generator {
scales: [0.25, 0.5, 1.0, 2.0]
aspect_ratios: [0.5, 1.0, 2.0]
height_stride: 16
width_stride: 16
}
}
# Hyperparameters for the first-stage box predictor convolution.
first_stage_box_predictor_conv_hyperparams {
op: CONV
regularizer {
l2_regularizer {
weight: 0.0
}
}
initializer {
truncated_normal_initializer {
stddev: 0.01
}
}
}
# First-stage (RPN) NMS, proposal count, and loss weights.
first_stage_nms_score_threshold: 0.0
first_stage_nms_iou_threshold: 0.7
first_stage_max_proposals: 32
first_stage_localization_loss_weight: 2.0
first_stage_objectness_loss_weight: 1.0
initial_crop_size: 14
maxpool_kernel_size: 2
maxpool_stride: 2
second_stage_batch_size: 32
# Second-stage box classifier head.
second_stage_box_predictor {
mask_rcnn_box_predictor {
use_dropout: false
dropout_keep_probability: 1.0
fc_hyperparams {
op: FC
regularizer {
l2_regularizer {
weight: 0.0
}
}
initializer {
variance_scaling_initializer {
factor: 1.0
uniform: true
mode: FAN_AVG
}
}
}
}
}
# Final-detection post-processing: per-class NMS then softmax scores.
second_stage_post_processing {
batch_non_max_suppression {
score_threshold: 0.0
iou_threshold: 0.6
max_detections_per_class: 5
max_total_detections: 5
}
score_converter: SOFTMAX
}
second_stage_localization_loss_weight: 2.0
second_stage_classification_loss_weight: 1.0
}
}
# Training: momentum SGD with a manually stepped learning-rate schedule.
train_config: {
batch_size: 1
num_steps: 4000000
optimizer {
momentum_optimizer: {
learning_rate: {
manual_step_learning_rate {
initial_learning_rate: 0.0003
schedule {
step: 3000000
learning_rate: .00003
}
schedule {
step: 3500000
learning_rate: .000003
}
}
}
momentum_optimizer_value: 0.9
}
use_moving_average: false
}
gradient_clipping_by_norm: 10.0
# Must be replaced with a real checkpoint path before training.
fine_tune_checkpoint: "PATH_TO_BE_CONFIGURED/model.ckpt"
data_augmentation_options {
random_horizontal_flip {
}
}
}
# Training data: label map plus TFRecord input path (both placeholders).
train_input_reader: {
label_map_path: "PATH_TO_BE_CONFIGURED/fgvc_2854_classes_label_map.pbtxt"
tf_record_input_reader {
input_path: "PATH_TO_BE_CONFIGURED/animal_2854_train.record"
}
}
# Evaluation: PASCAL VOC metrics.
# NOTE(review): num_examples is only 10 — presumably a quick smoke-test
# setting rather than a full evaluation; confirm before relying on metrics.
eval_config: {
metrics_set: "pascal_voc_detection_metrics"
use_moving_averages: false
num_examples: 10
}
# Evaluation data: single reader, no shuffling, for deterministic evaluation.
eval_input_reader: {
label_map_path: "PATH_TO_BE_CONFIGURED/fgvc_2854_classes_label_map.pbtxt"
shuffle: false
num_readers: 1
tf_record_input_reader {
input_path: "PATH_TO_BE_CONFIGURED/animal_2854_val.record"
}
}
...@@ -103,15 +103,20 @@ def create_configs_from_pipeline_proto(pipeline_config): ...@@ -103,15 +103,20 @@ def create_configs_from_pipeline_proto(pipeline_config):
Returns: Returns:
Dictionary of configuration objects. Keys are `model`, `train_config`, Dictionary of configuration objects. Keys are `model`, `train_config`,
`train_input_config`, `eval_config`, `eval_input_config`. Value are the `train_input_config`, `eval_config`, `eval_input_configs`. Value are
corresponding config objects. the corresponding config objects or list of config objects (only for
eval_input_configs).
""" """
configs = {} configs = {}
configs["model"] = pipeline_config.model configs["model"] = pipeline_config.model
configs["train_config"] = pipeline_config.train_config configs["train_config"] = pipeline_config.train_config
configs["train_input_config"] = pipeline_config.train_input_reader configs["train_input_config"] = pipeline_config.train_input_reader
configs["eval_config"] = pipeline_config.eval_config configs["eval_config"] = pipeline_config.eval_config
configs["eval_input_config"] = pipeline_config.eval_input_reader configs["eval_input_configs"] = pipeline_config.eval_input_reader
# Keeps eval_input_config only for backwards compatibility. All clients should
# read eval_input_configs instead.
if configs["eval_input_configs"]:
configs["eval_input_config"] = configs["eval_input_configs"][0]
if pipeline_config.HasField("graph_rewriter"): if pipeline_config.HasField("graph_rewriter"):
configs["graph_rewriter_config"] = pipeline_config.graph_rewriter configs["graph_rewriter_config"] = pipeline_config.graph_rewriter
...@@ -150,7 +155,7 @@ def create_pipeline_proto_from_configs(configs): ...@@ -150,7 +155,7 @@ def create_pipeline_proto_from_configs(configs):
pipeline_config.train_config.CopyFrom(configs["train_config"]) pipeline_config.train_config.CopyFrom(configs["train_config"])
pipeline_config.train_input_reader.CopyFrom(configs["train_input_config"]) pipeline_config.train_input_reader.CopyFrom(configs["train_input_config"])
pipeline_config.eval_config.CopyFrom(configs["eval_config"]) pipeline_config.eval_config.CopyFrom(configs["eval_config"])
pipeline_config.eval_input_reader.CopyFrom(configs["eval_input_config"]) pipeline_config.eval_input_reader.extend(configs["eval_input_configs"])
if "graph_rewriter_config" in configs: if "graph_rewriter_config" in configs:
pipeline_config.graph_rewriter.CopyFrom(configs["graph_rewriter_config"]) pipeline_config.graph_rewriter.CopyFrom(configs["graph_rewriter_config"])
return pipeline_config return pipeline_config
...@@ -224,7 +229,7 @@ def get_configs_from_multiple_files(model_config_path="", ...@@ -224,7 +229,7 @@ def get_configs_from_multiple_files(model_config_path="",
eval_input_config = input_reader_pb2.InputReader() eval_input_config = input_reader_pb2.InputReader()
with tf.gfile.GFile(eval_input_config_path, "r") as f: with tf.gfile.GFile(eval_input_config_path, "r") as f:
text_format.Merge(f.read(), eval_input_config) text_format.Merge(f.read(), eval_input_config)
configs["eval_input_config"] = eval_input_config configs["eval_input_configs"] = [eval_input_config]
if graph_rewriter_config_path: if graph_rewriter_config_path:
configs["graph_rewriter_config"] = get_graph_rewriter_config_from_file( configs["graph_rewriter_config"] = get_graph_rewriter_config_from_file(
...@@ -284,14 +289,133 @@ def _is_generic_key(key): ...@@ -284,14 +289,133 @@ def _is_generic_key(key):
"graph_rewriter_config", "graph_rewriter_config",
"model", "model",
"train_input_config", "train_input_config",
"train_input_config", "train_config",
"train_config"]: "eval_config"]:
if key.startswith(prefix + "."): if key.startswith(prefix + "."):
return True return True
return False return False
def merge_external_params_with_configs(configs, hparams=None, **kwargs): def _check_and_convert_legacy_input_config_key(key):
"""Checks key and converts legacy input config update to specific update.
Args:
key: string indicates the target of update operation.
Returns:
is_valid_input_config_key: A boolean indicating whether the input key is to
update input config(s).
key_name: 'eval_input_configs' or 'train_input_config' string if
is_valid_input_config_key is true. None if is_valid_input_config_key is
false.
input_name: always returns None since legacy input config key never
specifies the target input config. Keeping this output only to match the
output form defined for input config update.
field_name: the field name in input config. `key` itself if
is_valid_input_config_key is false.
"""
key_name = None
input_name = None
field_name = key
is_valid_input_config_key = True
if field_name == "train_shuffle":
key_name = "train_input_config"
field_name = "shuffle"
elif field_name == "eval_shuffle":
key_name = "eval_input_configs"
field_name = "shuffle"
elif field_name == "train_input_path":
key_name = "train_input_config"
field_name = "input_path"
elif field_name == "eval_input_path":
key_name = "eval_input_configs"
field_name = "input_path"
elif field_name == "train_input_path":
key_name = "train_input_config"
field_name = "input_path"
elif field_name == "eval_input_path":
key_name = "eval_input_configs"
field_name = "input_path"
elif field_name == "append_train_input_path":
key_name = "train_input_config"
field_name = "input_path"
elif field_name == "append_eval_input_path":
key_name = "eval_input_configs"
field_name = "input_path"
else:
is_valid_input_config_key = False
return is_valid_input_config_key, key_name, input_name, field_name
def check_and_parse_input_config_key(configs, key):
  """Checks key and returns specific fields if key is a valid input config update.

  Args:
    configs: Dictionary of configuration objects. See outputs from
      get_configs_from_pipeline_file() or get_configs_from_multiple_files().
    key: string indicating the target of the update operation.

  Returns:
    is_valid_input_config_key: A boolean indicating whether the input key is to
      update input config(s).
    key_name: 'eval_input_configs' or 'train_input_config' string if
      is_valid_input_config_key is true. None if is_valid_input_config_key is
      false.
    input_name: the name of the input config to be updated. None if
      is_valid_input_config_key is false.
    field_name: the field name in the input config. `key` itself if
      is_valid_input_config_key is false.

  Raises:
    ValueError: when the input key format doesn't match any known formats.
    ValueError: if key_name doesn't match 'eval_input_configs' or
      'train_input_config'.
    ValueError: if input_name doesn't match any name in train or eval input
      configs.
    ValueError: if field_name doesn't match any supported fields.
  """
  tokens = key.split(":")
  # A single-token key is a legacy-style update; delegate its validation and
  # conversion entirely to the legacy helper.
  if len(tokens) == 1:
    return _check_and_convert_legacy_input_config_key(key)
  if len(tokens) != 3:
    raise ValueError("Invalid key format when overriding configs.")
  key_name, input_name, field_name = tokens

  # Checks if key_name is valid for specific update.
  if key_name not in ["eval_input_configs", "train_input_config"]:
    raise ValueError("Invalid key_name when overriding input config.")

  # Checks if input_name is valid. The train input config is a singular
  # message whose own `name` must match; eval input configs form a repeated
  # field, so input_name must match the `name` of one of its elements.
  if isinstance(configs[key_name], input_reader_pb2.InputReader):
    name_found = configs[key_name].name == input_name
  else:
    name_found = input_name in [
        eval_cfg.name for eval_cfg in configs[key_name]
    ]
  if not name_found:
    raise ValueError("Invalid input_name when overriding input config.")

  # Checks if field_name is one of the supported overridable fields.
  if field_name not in [
      "input_path", "label_map_path", "shuffle", "mask_type",
      "sample_1_of_n_examples"
  ]:
    raise ValueError("Invalid field_name when overriding input config.")

  return True, key_name, input_name, field_name
def merge_external_params_with_configs(configs, hparams=None, kwargs_dict=None):
"""Updates `configs` dictionary based on supplied parameters. """Updates `configs` dictionary based on supplied parameters.
This utility is for modifying specific fields in the object detection configs. This utility is for modifying specific fields in the object detection configs.
...@@ -304,6 +428,31 @@ def merge_external_params_with_configs(configs, hparams=None, **kwargs): ...@@ -304,6 +428,31 @@ def merge_external_params_with_configs(configs, hparams=None, **kwargs):
1. Strategy-based overrides, which update multiple relevant configuration 1. Strategy-based overrides, which update multiple relevant configuration
options. For example, updating `learning_rate` will update both the warmup and options. For example, updating `learning_rate` will update both the warmup and
final learning rates. final learning rates.
In this case key can be one of the following formats:
1. legacy update: single string that indicates the attribute to be
updated. E.g. 'label_map_path', 'eval_input_path', 'shuffle'.
Note that when updating fields (e.g. eval_input_path, eval_shuffle) in
eval_input_configs, the override will only be applied when
eval_input_configs has exactly 1 element.
2. specific update: colon separated string that indicates which field in
which input_config to update. It should have 3 fields:
- key_name: Name of the input config we should update, either
'train_input_config' or 'eval_input_configs'
- input_name: a 'name' that can be used to identify elements, especially
when configs[key_name] is a repeated field.
- field_name: name of the field that you want to override.
For example, given configs dict as below:
configs = {
'model': {...}
'train_config': {...}
'train_input_config': {...}
'eval_config': {...}
'eval_input_configs': [{ name:"eval_coco", ...},
{ name:"eval_voc", ... }]
}
Assume we want to update the input_path of the eval_input_config
whose name is 'eval_coco'. The `key` would then be:
'eval_input_configs:eval_coco:input_path'
2. Generic key/value, which update a specific parameter based on namespaced 2. Generic key/value, which update a specific parameter based on namespaced
configuration keys. For example, configuration keys. For example,
`model.ssd.loss.hard_example_miner.max_negatives_per_positive` will update the `model.ssd.loss.hard_example_miner.max_negatives_per_positive` will update the
...@@ -314,55 +463,29 @@ def merge_external_params_with_configs(configs, hparams=None, **kwargs): ...@@ -314,55 +463,29 @@ def merge_external_params_with_configs(configs, hparams=None, **kwargs):
configs: Dictionary of configuration objects. See outputs from configs: Dictionary of configuration objects. See outputs from
get_configs_from_pipeline_file() or get_configs_from_multiple_files(). get_configs_from_pipeline_file() or get_configs_from_multiple_files().
hparams: A `HParams`. hparams: A `HParams`.
**kwargs: Extra keyword arguments that are treated the same way as kwargs_dict: Extra keyword arguments that are treated the same way as
attribute/value pairs in `hparams`. Note that hyperparameters with the attribute/value pairs in `hparams`. Note that hyperparameters with the
same names will override keyword arguments. same names will override keyword arguments.
Returns: Returns:
`configs` dictionary. `configs` dictionary.
Raises:
ValueError: when the key string doesn't match any of its allowed formats.
""" """
if kwargs_dict is None:
kwargs_dict = {}
if hparams: if hparams:
kwargs.update(hparams.values()) kwargs_dict.update(hparams.values())
for key, value in kwargs.items(): for key, value in kwargs_dict.items():
tf.logging.info("Maybe overwriting %s: %s", key, value) tf.logging.info("Maybe overwriting %s: %s", key, value)
# pylint: disable=g-explicit-bool-comparison # pylint: disable=g-explicit-bool-comparison
if value == "" or value is None: if value == "" or value is None:
continue continue
# pylint: enable=g-explicit-bool-comparison # pylint: enable=g-explicit-bool-comparison
if key == "learning_rate": elif _maybe_update_config_with_key_value(configs, key, value):
_update_initial_learning_rate(configs, value) continue
elif key == "batch_size":
_update_batch_size(configs, value)
elif key == "momentum_optimizer_value":
_update_momentum_optimizer_value(configs, value)
elif key == "classification_localization_weight_ratio":
# Localization weight is fixed to 1.0.
_update_classification_localization_weight_ratio(configs, value)
elif key == "focal_loss_gamma":
_update_focal_loss_gamma(configs, value)
elif key == "focal_loss_alpha":
_update_focal_loss_alpha(configs, value)
elif key == "train_steps":
_update_train_steps(configs, value)
elif key == "eval_steps":
_update_eval_steps(configs, value)
elif key == "train_input_path":
_update_input_path(configs["train_input_config"], value)
elif key == "eval_input_path":
_update_input_path(configs["eval_input_config"], value)
elif key == "label_map_path":
_update_label_map_path(configs, value)
elif key == "mask_type":
_update_mask_type(configs, value)
elif key == "eval_with_moving_averages":
_update_use_moving_averages(configs, value)
elif key == "train_shuffle":
_update_shuffle(configs["train_input_config"], value)
elif key == "eval_shuffle":
_update_shuffle(configs["eval_input_config"], value)
elif key == "retain_original_images_in_eval":
_update_retain_original_images(configs["eval_config"], value)
elif _is_generic_key(key): elif _is_generic_key(key):
_update_generic(configs, key, value) _update_generic(configs, key, value)
else: else:
...@@ -370,6 +493,148 @@ def merge_external_params_with_configs(configs, hparams=None, **kwargs): ...@@ -370,6 +493,148 @@ def merge_external_params_with_configs(configs, hparams=None, **kwargs):
return configs return configs
def _maybe_update_config_with_key_value(configs, key, value):
  """Checks key type and updates `configs` with the key-value pair accordingly.

  Args:
    configs: Dictionary of configuration objects. See outputs from
      get_configs_from_pipeline_file() or get_configs_from_multiple_files().
    key: String indicating the field(s) to be updated.
    value: Value used to override the existing field value.

  Returns:
    A boolean value that indicates whether the override succeeds.

  Raises:
    ValueError: when the key string doesn't match any of the expected formats.
  """
  is_valid_input_config_key, key_name, input_name, field_name = (
      check_and_parse_input_config_key(configs, key))
  if is_valid_input_config_key:
    update_input_reader_config(
        configs,
        key_name=key_name,
        input_name=input_name,
        field_name=field_name,
        value=value)
    return True

  # Strategy-based overrides, keyed on the (possibly legacy-converted) field
  # name. Each handler mutates `configs` in place.
  handlers = {
      "learning_rate":
          lambda: _update_initial_learning_rate(configs, value),
      "batch_size":
          lambda: _update_batch_size(configs, value),
      "momentum_optimizer_value":
          lambda: _update_momentum_optimizer_value(configs, value),
      # Localization weight is fixed to 1.0.
      "classification_localization_weight_ratio":
          lambda: _update_classification_localization_weight_ratio(
              configs, value),
      "focal_loss_gamma":
          lambda: _update_focal_loss_gamma(configs, value),
      "focal_loss_alpha":
          lambda: _update_focal_loss_alpha(configs, value),
      "train_steps":
          lambda: _update_train_steps(configs, value),
      "label_map_path":
          lambda: _update_label_map_path(configs, value),
      "mask_type":
          lambda: _update_mask_type(configs, value),
      "sample_1_of_n_eval_examples":
          lambda: _update_all_eval_input_configs(
              configs, "sample_1_of_n_examples", value),
      "eval_num_epochs":
          lambda: _update_all_eval_input_configs(configs, "num_epochs", value),
      "eval_with_moving_averages":
          lambda: _update_use_moving_averages(configs, value),
      "retain_original_images_in_eval":
          lambda: _update_retain_original_images(configs["eval_config"], value),
      "use_bfloat16":
          lambda: _update_use_bfloat16(configs, value),
  }
  handler = handlers.get(field_name)
  if handler is None:
    return False
  handler()
  return True
def _update_tf_record_input_path(input_config, input_path):
"""Updates input configuration to reflect a new input path.
The input_config object is updated in place, and hence not returned.
Args:
input_config: A input_reader_pb2.InputReader.
input_path: A path to data or list of paths.
Raises:
TypeError: if input reader type is not `tf_record_input_reader`.
"""
input_reader_type = input_config.WhichOneof("input_reader")
if input_reader_type == "tf_record_input_reader":
input_config.tf_record_input_reader.ClearField("input_path")
if isinstance(input_path, list):
input_config.tf_record_input_reader.input_path.extend(input_path)
else:
input_config.tf_record_input_reader.input_path.append(input_path)
else:
raise TypeError("Input reader type must be `tf_record_input_reader`.")
def update_input_reader_config(configs,
                               key_name=None,
                               input_name=None,
                               field_name=None,
                               value=None,
                               path_updater=_update_tf_record_input_path):
  """Updates specified input reader config field.

  Args:
    configs: Dictionary of configuration objects. See outputs from
      get_configs_from_pipeline_file() or get_configs_from_multiple_files().
    key_name: Name of the input config we should update, either
      'train_input_config' or 'eval_input_configs'.
    input_name: String name used to identify the input config to update.
      Should be either None or the value of the 'name' field in one of the
      input reader configs.
    field_name: Field name in input_reader_pb2.InputReader.
    value: Value used to override the existing field value.
    path_updater: helper function used to update the input path. Only used
      when field_name is "input_path".

  Raises:
    ValueError: when input_name is None and the number of eval input readers
      does not equal 1, or no/duplicate input configs match input_name.
  """

  def _apply(target_input_config):
    # "input_path" lives inside the reader oneof (e.g. tf_record_input_reader)
    # and is a repeated field, so it cannot be assigned via setattr; route it
    # through the dedicated path updater instead.
    if field_name == "input_path":
      path_updater(input_config=target_input_config, input_path=value)
    else:
      setattr(target_input_config, field_name, value)

  if isinstance(configs[key_name], input_reader_pb2.InputReader):
    # Updates the singular input_config object.
    _apply(configs[key_name])
  elif input_name is None and len(configs[key_name]) == 1:
    # Updates the first (and only) object of the input_config list.
    _apply(configs[key_name][0])
  elif input_name is not None and len(configs[key_name]):
    # Updates the input_config whose name matches input_name.
    update_count = 0
    for input_config in configs[key_name]:
      if input_config.name == input_name:
        # BUGFIX: previously used a bare setattr here, which fails when
        # field_name is "input_path"; now consistent with the other branches.
        _apply(input_config)
        update_count += 1
    if not update_count:
      raise ValueError(
          "Input name {} not found when overriding.".format(input_name))
    elif update_count > 1:
      raise ValueError("Duplicate input name found when overriding.")
  else:
    raise ValueError("Unknown input config overriding: "
                     "key_name:{}, input_name:{}, field_name:{}.".format(
                         "None" if key_name is None else key_name,
                         "None" if input_name is None else input_name,
                         "None" if field_name is None else field_name))
def _update_initial_learning_rate(configs, learning_rate): def _update_initial_learning_rate(configs, learning_rate):
"""Updates `configs` to reflect the new initial learning rate. """Updates `configs` to reflect the new initial learning rate.
...@@ -596,27 +861,10 @@ def _update_eval_steps(configs, eval_steps): ...@@ -596,27 +861,10 @@ def _update_eval_steps(configs, eval_steps):
configs["eval_config"].num_examples = int(eval_steps) configs["eval_config"].num_examples = int(eval_steps)
def _update_input_path(input_config, input_path): def _update_all_eval_input_configs(configs, field, value):
"""Updates input configuration to reflect a new input path. """Updates the content of `field` with `value` for all eval input configs."""
for eval_input_config in configs["eval_input_configs"]:
The input_config object is updated in place, and hence not returned. setattr(eval_input_config, field, value)
Args:
input_config: A input_reader_pb2.InputReader.
input_path: A path to data or list of paths.
Raises:
TypeError: if input reader type is not `tf_record_input_reader`.
"""
input_reader_type = input_config.WhichOneof("input_reader")
if input_reader_type == "tf_record_input_reader":
input_config.tf_record_input_reader.ClearField("input_path")
if isinstance(input_path, list):
input_config.tf_record_input_reader.input_path.extend(input_path)
else:
input_config.tf_record_input_reader.input_path.append(input_path)
else:
raise TypeError("Input reader type must be `tf_record_input_reader`.")
def _update_label_map_path(configs, label_map_path): def _update_label_map_path(configs, label_map_path):
...@@ -630,7 +878,7 @@ def _update_label_map_path(configs, label_map_path): ...@@ -630,7 +878,7 @@ def _update_label_map_path(configs, label_map_path):
label_map_path: New path to `StringIntLabelMap` pbtxt file. label_map_path: New path to `StringIntLabelMap` pbtxt file.
""" """
configs["train_input_config"].label_map_path = label_map_path configs["train_input_config"].label_map_path = label_map_path
configs["eval_input_config"].label_map_path = label_map_path _update_all_eval_input_configs(configs, "label_map_path", label_map_path)
def _update_mask_type(configs, mask_type): def _update_mask_type(configs, mask_type):
...@@ -645,7 +893,7 @@ def _update_mask_type(configs, mask_type): ...@@ -645,7 +893,7 @@ def _update_mask_type(configs, mask_type):
input_reader_pb2.InstanceMaskType input_reader_pb2.InstanceMaskType
""" """
configs["train_input_config"].mask_type = mask_type configs["train_input_config"].mask_type = mask_type
configs["eval_input_config"].mask_type = mask_type _update_all_eval_input_configs(configs, "mask_type", mask_type)
def _update_use_moving_averages(configs, use_moving_averages): def _update_use_moving_averages(configs, use_moving_averages):
...@@ -662,18 +910,6 @@ def _update_use_moving_averages(configs, use_moving_averages): ...@@ -662,18 +910,6 @@ def _update_use_moving_averages(configs, use_moving_averages):
configs["eval_config"].use_moving_averages = use_moving_averages configs["eval_config"].use_moving_averages = use_moving_averages
def _update_shuffle(input_config, shuffle):
"""Updates input configuration to reflect a new shuffle configuration.
The input_config object is updated in place, and hence not returned.
Args:
input_config: A input_reader_pb2.InputReader.
shuffle: Whether or not to shuffle the input data before reading.
"""
input_config.shuffle = shuffle
def _update_retain_original_images(eval_config, retain_original_images): def _update_retain_original_images(eval_config, retain_original_images):
"""Updates eval config with option to retain original images. """Updates eval config with option to retain original images.
...@@ -685,3 +921,16 @@ def _update_retain_original_images(eval_config, retain_original_images): ...@@ -685,3 +921,16 @@ def _update_retain_original_images(eval_config, retain_original_images):
in eval mode. in eval mode.
""" """
eval_config.retain_original_images = retain_original_images eval_config.retain_original_images = retain_original_images
def _update_use_bfloat16(configs, use_bfloat16):
"""Updates `configs` to reflect the new setup on whether to use bfloat16.
The configs dictionary is updated in place, and hence not returned.
Args:
configs: Dictionary of configuration objects. See outputs from
get_configs_from_pipeline_file() or get_configs_from_multiple_files().
use_bfloat16: A bool, indicating whether to use bfloat16 for training.
"""
configs["train_config"].use_bfloat16 = use_bfloat16
...@@ -83,7 +83,7 @@ class ConfigUtilTest(tf.test.TestCase): ...@@ -83,7 +83,7 @@ class ConfigUtilTest(tf.test.TestCase):
pipeline_config.train_config.batch_size = 32 pipeline_config.train_config.batch_size = 32
pipeline_config.train_input_reader.label_map_path = "path/to/label_map" pipeline_config.train_input_reader.label_map_path = "path/to/label_map"
pipeline_config.eval_config.num_examples = 20 pipeline_config.eval_config.num_examples = 20
pipeline_config.eval_input_reader.queue_capacity = 100 pipeline_config.eval_input_reader.add().queue_capacity = 100
_write_config(pipeline_config, pipeline_config_path) _write_config(pipeline_config, pipeline_config_path)
...@@ -96,7 +96,7 @@ class ConfigUtilTest(tf.test.TestCase): ...@@ -96,7 +96,7 @@ class ConfigUtilTest(tf.test.TestCase):
self.assertProtoEquals(pipeline_config.eval_config, self.assertProtoEquals(pipeline_config.eval_config,
configs["eval_config"]) configs["eval_config"])
self.assertProtoEquals(pipeline_config.eval_input_reader, self.assertProtoEquals(pipeline_config.eval_input_reader,
configs["eval_input_config"]) configs["eval_input_configs"])
def test_create_configs_from_pipeline_proto(self): def test_create_configs_from_pipeline_proto(self):
"""Tests creating configs dictionary from pipeline proto.""" """Tests creating configs dictionary from pipeline proto."""
...@@ -106,7 +106,7 @@ class ConfigUtilTest(tf.test.TestCase): ...@@ -106,7 +106,7 @@ class ConfigUtilTest(tf.test.TestCase):
pipeline_config.train_config.batch_size = 32 pipeline_config.train_config.batch_size = 32
pipeline_config.train_input_reader.label_map_path = "path/to/label_map" pipeline_config.train_input_reader.label_map_path = "path/to/label_map"
pipeline_config.eval_config.num_examples = 20 pipeline_config.eval_config.num_examples = 20
pipeline_config.eval_input_reader.queue_capacity = 100 pipeline_config.eval_input_reader.add().queue_capacity = 100
configs = config_util.create_configs_from_pipeline_proto(pipeline_config) configs = config_util.create_configs_from_pipeline_proto(pipeline_config)
self.assertProtoEquals(pipeline_config.model, configs["model"]) self.assertProtoEquals(pipeline_config.model, configs["model"])
...@@ -116,7 +116,7 @@ class ConfigUtilTest(tf.test.TestCase): ...@@ -116,7 +116,7 @@ class ConfigUtilTest(tf.test.TestCase):
configs["train_input_config"]) configs["train_input_config"])
self.assertProtoEquals(pipeline_config.eval_config, configs["eval_config"]) self.assertProtoEquals(pipeline_config.eval_config, configs["eval_config"])
self.assertProtoEquals(pipeline_config.eval_input_reader, self.assertProtoEquals(pipeline_config.eval_input_reader,
configs["eval_input_config"]) configs["eval_input_configs"])
def test_create_pipeline_proto_from_configs(self): def test_create_pipeline_proto_from_configs(self):
"""Tests that proto can be reconstructed from configs dictionary.""" """Tests that proto can be reconstructed from configs dictionary."""
...@@ -127,7 +127,7 @@ class ConfigUtilTest(tf.test.TestCase): ...@@ -127,7 +127,7 @@ class ConfigUtilTest(tf.test.TestCase):
pipeline_config.train_config.batch_size = 32 pipeline_config.train_config.batch_size = 32
pipeline_config.train_input_reader.label_map_path = "path/to/label_map" pipeline_config.train_input_reader.label_map_path = "path/to/label_map"
pipeline_config.eval_config.num_examples = 20 pipeline_config.eval_config.num_examples = 20
pipeline_config.eval_input_reader.queue_capacity = 100 pipeline_config.eval_input_reader.add().queue_capacity = 100
_write_config(pipeline_config, pipeline_config_path) _write_config(pipeline_config, pipeline_config_path)
configs = config_util.get_configs_from_pipeline_file(pipeline_config_path) configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
...@@ -142,7 +142,7 @@ class ConfigUtilTest(tf.test.TestCase): ...@@ -142,7 +142,7 @@ class ConfigUtilTest(tf.test.TestCase):
pipeline_config.train_config.batch_size = 32 pipeline_config.train_config.batch_size = 32
pipeline_config.train_input_reader.label_map_path = "path/to/label_map" pipeline_config.train_input_reader.label_map_path = "path/to/label_map"
pipeline_config.eval_config.num_examples = 20 pipeline_config.eval_config.num_examples = 20
pipeline_config.eval_input_reader.queue_capacity = 100 pipeline_config.eval_input_reader.add().queue_capacity = 100
config_util.save_pipeline_config(pipeline_config, self.get_temp_dir()) config_util.save_pipeline_config(pipeline_config, self.get_temp_dir())
configs = config_util.get_configs_from_pipeline_file( configs = config_util.get_configs_from_pipeline_file(
...@@ -197,8 +197,7 @@ class ConfigUtilTest(tf.test.TestCase): ...@@ -197,8 +197,7 @@ class ConfigUtilTest(tf.test.TestCase):
self.assertProtoEquals(train_input_config, self.assertProtoEquals(train_input_config,
configs["train_input_config"]) configs["train_input_config"])
self.assertProtoEquals(eval_config, configs["eval_config"]) self.assertProtoEquals(eval_config, configs["eval_config"])
self.assertProtoEquals(eval_input_config, self.assertProtoEquals(eval_input_config, configs["eval_input_configs"][0])
configs["eval_input_config"])
def _assertOptimizerWithNewLearningRate(self, optimizer_name): def _assertOptimizerWithNewLearningRate(self, optimizer_name):
"""Asserts successful updating of all learning rate schemes.""" """Asserts successful updating of all learning rate schemes."""
...@@ -282,6 +281,41 @@ class ConfigUtilTest(tf.test.TestCase): ...@@ -282,6 +281,41 @@ class ConfigUtilTest(tf.test.TestCase):
"""Tests new learning rates for Adam Optimizer.""" """Tests new learning rates for Adam Optimizer."""
self._assertOptimizerWithNewLearningRate("adam_optimizer") self._assertOptimizerWithNewLearningRate("adam_optimizer")
def testGenericConfigOverride(self):
  """Tests generic config overrides for all top-level configs."""
  # Populate one field in every top-level config of the pipeline proto.
  config_proto = pipeline_pb2.TrainEvalPipelineConfig()
  config_proto.model.ssd.num_classes = 1
  config_proto.train_config.batch_size = 1
  config_proto.eval_config.num_visualizations = 1
  config_proto.train_input_reader.label_map_path = "/some/path"
  config_proto.eval_input_reader.add().label_map_path = "/some/path"
  config_proto.graph_rewriter.quantization.weight_bits = 1
  config_path = os.path.join(self.get_temp_dir(), "pipeline.config")
  _write_config(config_proto, config_path)

  # Reload the pipeline file and override each field through HParams.
  configs = config_util.get_configs_from_pipeline_file(config_path)
  overrides = {
      "model.ssd.num_classes": 2,
      "train_config.batch_size": 2,
      "train_input_config.label_map_path": "/some/other/path",
      "eval_config.num_visualizations": 2,
      "graph_rewriter_config.quantization.weight_bits": 2,
  }
  hparams = tf.contrib.training.HParams(**overrides)
  configs = config_util.merge_external_params_with_configs(configs, hparams)

  # Every override must now be reflected in the configs dictionary.
  self.assertEqual(2, configs["model"].ssd.num_classes)
  self.assertEqual(2, configs["train_config"].batch_size)
  self.assertEqual("/some/other/path",
                   configs["train_input_config"].label_map_path)
  self.assertEqual(2, configs["eval_config"].num_visualizations)
  self.assertEqual(2,
                   configs["graph_rewriter_config"].quantization.weight_bits)
def testNewBatchSize(self): def testNewBatchSize(self):
"""Tests that batch size is updated appropriately.""" """Tests that batch size is updated appropriately."""
original_batch_size = 2 original_batch_size = 2
...@@ -406,25 +440,19 @@ class ConfigUtilTest(tf.test.TestCase): ...@@ -406,25 +440,19 @@ class ConfigUtilTest(tf.test.TestCase):
def testMergingKeywordArguments(self): def testMergingKeywordArguments(self):
"""Tests that keyword arguments get merged as do hyperparameters.""" """Tests that keyword arguments get merged as do hyperparameters."""
original_num_train_steps = 100 original_num_train_steps = 100
original_num_eval_steps = 5
desired_num_train_steps = 10 desired_num_train_steps = 10
desired_num_eval_steps = 1
pipeline_config_path = os.path.join(self.get_temp_dir(), "pipeline.config") pipeline_config_path = os.path.join(self.get_temp_dir(), "pipeline.config")
pipeline_config = pipeline_pb2.TrainEvalPipelineConfig() pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
pipeline_config.train_config.num_steps = original_num_train_steps pipeline_config.train_config.num_steps = original_num_train_steps
pipeline_config.eval_config.num_examples = original_num_eval_steps
_write_config(pipeline_config, pipeline_config_path) _write_config(pipeline_config, pipeline_config_path)
configs = config_util.get_configs_from_pipeline_file(pipeline_config_path) configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
override_dict = {"train_steps": desired_num_train_steps}
configs = config_util.merge_external_params_with_configs( configs = config_util.merge_external_params_with_configs(
configs, configs, kwargs_dict=override_dict)
train_steps=desired_num_train_steps,
eval_steps=desired_num_eval_steps)
train_steps = configs["train_config"].num_steps train_steps = configs["train_config"].num_steps
eval_steps = configs["eval_config"].num_examples
self.assertEqual(desired_num_train_steps, train_steps) self.assertEqual(desired_num_train_steps, train_steps)
self.assertEqual(desired_num_eval_steps, eval_steps)
def testGetNumberOfClasses(self): def testGetNumberOfClasses(self):
"""Tests that number of classes can be retrieved.""" """Tests that number of classes can be retrieved."""
...@@ -449,8 +477,9 @@ class ConfigUtilTest(tf.test.TestCase): ...@@ -449,8 +477,9 @@ class ConfigUtilTest(tf.test.TestCase):
_write_config(pipeline_config, pipeline_config_path) _write_config(pipeline_config, pipeline_config_path)
configs = config_util.get_configs_from_pipeline_file(pipeline_config_path) configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
override_dict = {"train_input_path": new_train_path}
configs = config_util.merge_external_params_with_configs( configs = config_util.merge_external_params_with_configs(
configs, train_input_path=new_train_path) configs, kwargs_dict=override_dict)
reader_config = configs["train_input_config"].tf_record_input_reader reader_config = configs["train_input_config"].tf_record_input_reader
final_path = reader_config.input_path final_path = reader_config.input_path
self.assertEqual([new_train_path], final_path) self.assertEqual([new_train_path], final_path)
...@@ -467,8 +496,9 @@ class ConfigUtilTest(tf.test.TestCase): ...@@ -467,8 +496,9 @@ class ConfigUtilTest(tf.test.TestCase):
_write_config(pipeline_config, pipeline_config_path) _write_config(pipeline_config, pipeline_config_path)
configs = config_util.get_configs_from_pipeline_file(pipeline_config_path) configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
override_dict = {"train_input_path": new_train_path}
configs = config_util.merge_external_params_with_configs( configs = config_util.merge_external_params_with_configs(
configs, train_input_path=new_train_path) configs, kwargs_dict=override_dict)
reader_config = configs["train_input_config"].tf_record_input_reader reader_config = configs["train_input_config"].tf_record_input_reader
final_path = reader_config.input_path final_path = reader_config.input_path
self.assertEqual(new_train_path, final_path) self.assertEqual(new_train_path, final_path)
...@@ -482,17 +512,18 @@ class ConfigUtilTest(tf.test.TestCase): ...@@ -482,17 +512,18 @@ class ConfigUtilTest(tf.test.TestCase):
pipeline_config = pipeline_pb2.TrainEvalPipelineConfig() pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
train_input_reader = pipeline_config.train_input_reader train_input_reader = pipeline_config.train_input_reader
train_input_reader.label_map_path = original_label_map_path train_input_reader.label_map_path = original_label_map_path
eval_input_reader = pipeline_config.eval_input_reader eval_input_reader = pipeline_config.eval_input_reader.add()
eval_input_reader.label_map_path = original_label_map_path eval_input_reader.label_map_path = original_label_map_path
_write_config(pipeline_config, pipeline_config_path) _write_config(pipeline_config, pipeline_config_path)
configs = config_util.get_configs_from_pipeline_file(pipeline_config_path) configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
override_dict = {"label_map_path": new_label_map_path}
configs = config_util.merge_external_params_with_configs( configs = config_util.merge_external_params_with_configs(
configs, label_map_path=new_label_map_path) configs, kwargs_dict=override_dict)
self.assertEqual(new_label_map_path, self.assertEqual(new_label_map_path,
configs["train_input_config"].label_map_path) configs["train_input_config"].label_map_path)
self.assertEqual(new_label_map_path, for eval_input_config in configs["eval_input_configs"]:
configs["eval_input_config"].label_map_path) self.assertEqual(new_label_map_path, eval_input_config.label_map_path)
def testDontOverwriteEmptyLabelMapPath(self): def testDontOverwriteEmptyLabelMapPath(self):
"""Tests that label map path will not by overwritten with empty string.""" """Tests that label map path will not by overwritten with empty string."""
...@@ -503,17 +534,18 @@ class ConfigUtilTest(tf.test.TestCase): ...@@ -503,17 +534,18 @@ class ConfigUtilTest(tf.test.TestCase):
pipeline_config = pipeline_pb2.TrainEvalPipelineConfig() pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
train_input_reader = pipeline_config.train_input_reader train_input_reader = pipeline_config.train_input_reader
train_input_reader.label_map_path = original_label_map_path train_input_reader.label_map_path = original_label_map_path
eval_input_reader = pipeline_config.eval_input_reader eval_input_reader = pipeline_config.eval_input_reader.add()
eval_input_reader.label_map_path = original_label_map_path eval_input_reader.label_map_path = original_label_map_path
_write_config(pipeline_config, pipeline_config_path) _write_config(pipeline_config, pipeline_config_path)
configs = config_util.get_configs_from_pipeline_file(pipeline_config_path) configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
override_dict = {"label_map_path": new_label_map_path}
configs = config_util.merge_external_params_with_configs( configs = config_util.merge_external_params_with_configs(
configs, label_map_path=new_label_map_path) configs, kwargs_dict=override_dict)
self.assertEqual(original_label_map_path, self.assertEqual(original_label_map_path,
configs["train_input_config"].label_map_path) configs["train_input_config"].label_map_path)
self.assertEqual(original_label_map_path, self.assertEqual(original_label_map_path,
configs["eval_input_config"].label_map_path) configs["eval_input_configs"][0].label_map_path)
def testNewMaskType(self): def testNewMaskType(self):
"""Tests that mask type can be overwritten in input readers.""" """Tests that mask type can be overwritten in input readers."""
...@@ -524,15 +556,16 @@ class ConfigUtilTest(tf.test.TestCase): ...@@ -524,15 +556,16 @@ class ConfigUtilTest(tf.test.TestCase):
pipeline_config = pipeline_pb2.TrainEvalPipelineConfig() pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
train_input_reader = pipeline_config.train_input_reader train_input_reader = pipeline_config.train_input_reader
train_input_reader.mask_type = original_mask_type train_input_reader.mask_type = original_mask_type
eval_input_reader = pipeline_config.eval_input_reader eval_input_reader = pipeline_config.eval_input_reader.add()
eval_input_reader.mask_type = original_mask_type eval_input_reader.mask_type = original_mask_type
_write_config(pipeline_config, pipeline_config_path) _write_config(pipeline_config, pipeline_config_path)
configs = config_util.get_configs_from_pipeline_file(pipeline_config_path) configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
override_dict = {"mask_type": new_mask_type}
configs = config_util.merge_external_params_with_configs( configs = config_util.merge_external_params_with_configs(
configs, mask_type=new_mask_type) configs, kwargs_dict=override_dict)
self.assertEqual(new_mask_type, configs["train_input_config"].mask_type) self.assertEqual(new_mask_type, configs["train_input_config"].mask_type)
self.assertEqual(new_mask_type, configs["eval_input_config"].mask_type) self.assertEqual(new_mask_type, configs["eval_input_configs"][0].mask_type)
def testUseMovingAverageForEval(self): def testUseMovingAverageForEval(self):
use_moving_averages_orig = False use_moving_averages_orig = False
...@@ -543,8 +576,9 @@ class ConfigUtilTest(tf.test.TestCase): ...@@ -543,8 +576,9 @@ class ConfigUtilTest(tf.test.TestCase):
_write_config(pipeline_config, pipeline_config_path) _write_config(pipeline_config, pipeline_config_path)
configs = config_util.get_configs_from_pipeline_file(pipeline_config_path) configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
override_dict = {"eval_with_moving_averages": True}
configs = config_util.merge_external_params_with_configs( configs = config_util.merge_external_params_with_configs(
configs, eval_with_moving_averages=True) configs, kwargs_dict=override_dict)
self.assertEqual(True, configs["eval_config"].use_moving_averages) self.assertEqual(True, configs["eval_config"].use_moving_averages)
def testGetImageResizerConfig(self): def testGetImageResizerConfig(self):
...@@ -585,14 +619,14 @@ class ConfigUtilTest(tf.test.TestCase): ...@@ -585,14 +619,14 @@ class ConfigUtilTest(tf.test.TestCase):
pipeline_config_path = os.path.join(self.get_temp_dir(), "pipeline.config") pipeline_config_path = os.path.join(self.get_temp_dir(), "pipeline.config")
pipeline_config = pipeline_pb2.TrainEvalPipelineConfig() pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
pipeline_config.eval_input_reader.shuffle = original_shuffle pipeline_config.eval_input_reader.add().shuffle = original_shuffle
_write_config(pipeline_config, pipeline_config_path) _write_config(pipeline_config, pipeline_config_path)
configs = config_util.get_configs_from_pipeline_file(pipeline_config_path) configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
override_dict = {"eval_shuffle": desired_shuffle}
configs = config_util.merge_external_params_with_configs( configs = config_util.merge_external_params_with_configs(
configs, eval_shuffle=desired_shuffle) configs, kwargs_dict=override_dict)
eval_shuffle = configs["eval_input_config"].shuffle self.assertEqual(desired_shuffle, configs["eval_input_configs"][0].shuffle)
self.assertEqual(desired_shuffle, eval_shuffle)
def testTrainShuffle(self): def testTrainShuffle(self):
"""Tests that `train_shuffle` keyword arguments are applied correctly.""" """Tests that `train_shuffle` keyword arguments are applied correctly."""
...@@ -605,8 +639,9 @@ class ConfigUtilTest(tf.test.TestCase): ...@@ -605,8 +639,9 @@ class ConfigUtilTest(tf.test.TestCase):
_write_config(pipeline_config, pipeline_config_path) _write_config(pipeline_config, pipeline_config_path)
configs = config_util.get_configs_from_pipeline_file(pipeline_config_path) configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
override_dict = {"train_shuffle": desired_shuffle}
configs = config_util.merge_external_params_with_configs( configs = config_util.merge_external_params_with_configs(
configs, train_shuffle=desired_shuffle) configs, kwargs_dict=override_dict)
train_shuffle = configs["train_input_config"].shuffle train_shuffle = configs["train_input_config"].shuffle
self.assertEqual(desired_shuffle, train_shuffle) self.assertEqual(desired_shuffle, train_shuffle)
...@@ -622,11 +657,210 @@ class ConfigUtilTest(tf.test.TestCase): ...@@ -622,11 +657,210 @@ class ConfigUtilTest(tf.test.TestCase):
_write_config(pipeline_config, pipeline_config_path) _write_config(pipeline_config, pipeline_config_path)
configs = config_util.get_configs_from_pipeline_file(pipeline_config_path) configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
override_dict = {
"retain_original_images_in_eval": desired_retain_original_images
}
configs = config_util.merge_external_params_with_configs( configs = config_util.merge_external_params_with_configs(
configs, retain_original_images_in_eval=desired_retain_original_images) configs, kwargs_dict=override_dict)
retain_original_images = configs["eval_config"].retain_original_images retain_original_images = configs["eval_config"].retain_original_images
self.assertEqual(desired_retain_original_images, retain_original_images) self.assertEqual(desired_retain_original_images, retain_original_images)
def testOverwriteAllEvalSampling(self):
  """Tests that sampling rate is overridden on every eval input config."""
  original_rate = 1
  new_rate = 10
  config_path = os.path.join(self.get_temp_dir(), "pipeline.config")
  config_proto = pipeline_pb2.TrainEvalPipelineConfig()
  # Two eval input readers, both starting at the original sampling rate.
  for _ in range(2):
    config_proto.eval_input_reader.add().sample_1_of_n_examples = original_rate
  _write_config(config_proto, config_path)

  configs = config_util.get_configs_from_pipeline_file(config_path)
  configs = config_util.merge_external_params_with_configs(
      configs, kwargs_dict={"sample_1_of_n_eval_examples": new_rate})
  # The override must be applied to every eval input config.
  for eval_input_config in configs["eval_input_configs"]:
    self.assertEqual(new_rate, eval_input_config.sample_1_of_n_examples)
def testOverwriteAllEvalNumEpochs(self):
  """Tests that num_epochs is overridden on every eval input config."""
  old_epochs = 10
  new_epochs = 1
  config_path = os.path.join(self.get_temp_dir(), "pipeline.config")
  config_proto = pipeline_pb2.TrainEvalPipelineConfig()
  # Two eval input readers, both starting at the original epoch count.
  for _ in range(2):
    config_proto.eval_input_reader.add().num_epochs = old_epochs
  _write_config(config_proto, config_path)

  configs = config_util.get_configs_from_pipeline_file(config_path)
  configs = config_util.merge_external_params_with_configs(
      configs, kwargs_dict={"eval_num_epochs": new_epochs})
  # The override must reach every eval input config.
  for eval_input_config in configs["eval_input_configs"]:
    self.assertEqual(new_epochs, eval_input_config.num_epochs)
def testUpdateMaskTypeForAllInputConfigs(self):
  """Tests that mask_type is updated on train and all eval input configs."""
  old_mask_type = input_reader_pb2.NUMERICAL_MASKS
  new_mask_type = input_reader_pb2.PNG_MASKS
  config_path = os.path.join(self.get_temp_dir(), "pipeline.config")
  config_proto = pipeline_pb2.TrainEvalPipelineConfig()
  config_proto.train_input_reader.mask_type = old_mask_type
  # Two named eval input readers, both using the original mask type.
  for reader_name in ("eval_1", "eval_2"):
    reader = config_proto.eval_input_reader.add()
    reader.mask_type = old_mask_type
    reader.name = reader_name
  _write_config(config_proto, config_path)

  configs = config_util.get_configs_from_pipeline_file(config_path)
  configs = config_util.merge_external_params_with_configs(
      configs, kwargs_dict={"mask_type": new_mask_type})
  # The train config and every eval config must pick up the new mask type.
  self.assertEqual(configs["train_input_config"].mask_type, new_mask_type)
  for eval_input_config in configs["eval_input_configs"]:
    self.assertEqual(eval_input_config.mask_type, new_mask_type)
def testErrorOverwritingMultipleInputConfig(self):
  """Tests that a legacy single-target override raises with multiple configs."""
  # "eval_shuffle" targets a single eval input config; with two readers
  # present the override is ambiguous and must raise a ValueError.
  config_path = os.path.join(self.get_temp_dir(), "pipeline.config")
  config_proto = pipeline_pb2.TrainEvalPipelineConfig()
  for reader_name in ("eval_1", "eval_2"):
    reader = config_proto.eval_input_reader.add()
    reader.shuffle = False
    reader.name = reader_name
  _write_config(config_proto, config_path)

  configs = config_util.get_configs_from_pipeline_file(config_path)
  with self.assertRaises(ValueError):
    configs = config_util.merge_external_params_with_configs(
        configs, kwargs_dict={"eval_shuffle": True})
def testCheckAndParseInputConfigKey(self):
  """Tests parsing of input-config override keys, valid and invalid."""
  pipeline_config_path = os.path.join(self.get_temp_dir(), "pipeline.config")
  pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
  pipeline_config.eval_input_reader.add().name = "eval_1"
  pipeline_config.eval_input_reader.add().name = "eval_2"
  _write_config(pipeline_config, pipeline_config_path)
  configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)

  # Fully qualified key of the form "key_name:input_name:field_name".
  specific_shuffle_update_key = "eval_input_configs:eval_2:shuffle"
  is_valid_input_config_key, key_name, input_name, field_name = (
      config_util.check_and_parse_input_config_key(
          configs, specific_shuffle_update_key))
  self.assertTrue(is_valid_input_config_key)
  self.assertEqual(key_name, "eval_input_configs")
  self.assertEqual(input_name, "eval_2")
  self.assertEqual(field_name, "shuffle")

  # Legacy single-word key; parses to the eval configs with no input name.
  legacy_shuffle_update_key = "eval_shuffle"
  is_valid_input_config_key, key_name, input_name, field_name = (
      config_util.check_and_parse_input_config_key(configs,
                                                   legacy_shuffle_update_key))
  self.assertTrue(is_valid_input_config_key)
  self.assertEqual(key_name, "eval_input_configs")
  self.assertEqual(input_name, None)
  self.assertEqual(field_name, "shuffle")

  # A key that is not an input-config override: flagged invalid, only the
  # field name is returned.
  non_input_config_update_key = "label_map_path"
  is_valid_input_config_key, key_name, input_name, field_name = (
      config_util.check_and_parse_input_config_key(
          configs, non_input_config_update_key))
  self.assertFalse(is_valid_input_config_key)
  self.assertEqual(key_name, None)
  self.assertEqual(input_name, None)
  self.assertEqual(field_name, "label_map_path")

  # Malformed keys must raise ValueError with a descriptive message.
  with self.assertRaisesRegexp(ValueError,
                               "Invalid key format when overriding configs."):
    config_util.check_and_parse_input_config_key(
        configs, "train_input_config:shuffle")

  with self.assertRaisesRegexp(
      ValueError, "Invalid key_name when overriding input config."):
    config_util.check_and_parse_input_config_key(
        configs, "invalid_key_name:train_name:shuffle")

  with self.assertRaisesRegexp(
      ValueError, "Invalid input_name when overriding input config."):
    config_util.check_and_parse_input_config_key(
        configs, "eval_input_configs:unknown_eval_name:shuffle")

  with self.assertRaisesRegexp(
      ValueError, "Invalid field_name when overriding input config."):
    config_util.check_and_parse_input_config_key(
        configs, "eval_input_configs:eval_2:unknown_field_name")
def testUpdateInputReaderConfigSuccess(self):
  """Tests overriding a field on the (singular) train input config."""
  original_shuffle = False
  new_shuffle = True
  pipeline_config_path = os.path.join(self.get_temp_dir(), "pipeline.config")
  pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
  pipeline_config.train_input_reader.shuffle = original_shuffle
  _write_config(pipeline_config, pipeline_config_path)
  configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)

  # `input_name` is None because train_input_config is a single reader,
  # not a repeated list of readers.
  config_util.update_input_reader_config(
      configs,
      key_name="train_input_config",
      input_name=None,
      field_name="shuffle",
      value=new_shuffle)
  self.assertEqual(configs["train_input_config"].shuffle, new_shuffle)

  # NOTE(review): this second call repeats the first verbatim; it may have
  # been intended to target "eval_input_configs" instead — confirm.
  config_util.update_input_reader_config(
      configs,
      key_name="train_input_config",
      input_name=None,
      field_name="shuffle",
      value=new_shuffle)
  self.assertEqual(configs["train_input_config"].shuffle, new_shuffle)
def testUpdateInputReaderConfigErrors(self):
  """Tests that invalid input reader overrides raise descriptive errors."""
  config_path = os.path.join(self.get_temp_dir(), "pipeline.config")
  config_proto = pipeline_pb2.TrainEvalPipelineConfig()
  # Two eval readers sharing one name make name-based updates ambiguous.
  config_proto.eval_input_reader.add().name = "same_eval_name"
  config_proto.eval_input_reader.add().name = "same_eval_name"
  _write_config(config_proto, config_path)
  configs = config_util.get_configs_from_pipeline_file(config_path)

  with self.assertRaisesRegexp(ValueError,
                               "Duplicate input name found when overriding."):
    config_util.update_input_reader_config(
        configs,
        key_name="eval_input_configs",
        input_name="same_eval_name",
        field_name="shuffle",
        value=False)

  with self.assertRaisesRegexp(
      ValueError, "Input name name_not_exist not found when overriding."):
    config_util.update_input_reader_config(
        configs,
        key_name="eval_input_configs",
        input_name="name_not_exist",
        field_name="shuffle",
        value=False)

  # No input_name with multiple eval configs is an unknown override case.
  with self.assertRaisesRegexp(ValueError,
                               "Unknown input config overriding."):
    config_util.update_input_reader_config(
        configs,
        key_name="eval_input_configs",
        input_name=None,
        field_name="shuffle",
        value=False)
if __name__ == "__main__": if __name__ == "__main__":
tf.test.main() tf.test.main()
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Label map utility functions.""" """Label map utility functions."""
import logging import logging
...@@ -73,10 +72,10 @@ def get_max_label_map_index(label_map): ...@@ -73,10 +72,10 @@ def get_max_label_map_index(label_map):
def convert_label_map_to_categories(label_map, def convert_label_map_to_categories(label_map,
max_num_classes, max_num_classes,
use_display_name=True): use_display_name=True):
"""Loads label map proto and returns categories list compatible with eval. """Given label map proto returns categories list compatible with eval.
This function loads a label map and returns a list of dicts, each of which This function converts label map proto and returns a list of dicts, each of
has the following keys: which has the following keys:
'id': (required) an integer id uniquely identifying this category. 'id': (required) an integer id uniquely identifying this category.
'name': (required) string representing category name 'name': (required) string representing category name
e.g., 'cat', 'dog', 'pizza'. e.g., 'cat', 'dog', 'pizza'.
...@@ -89,9 +88,10 @@ def convert_label_map_to_categories(label_map, ...@@ -89,9 +88,10 @@ def convert_label_map_to_categories(label_map,
label_map: a StringIntLabelMapProto or None. If None, a default categories label_map: a StringIntLabelMapProto or None. If None, a default categories
list is created with max_num_classes categories. list is created with max_num_classes categories.
max_num_classes: maximum number of (consecutive) label indices to include. max_num_classes: maximum number of (consecutive) label indices to include.
use_display_name: (boolean) choose whether to load 'display_name' field use_display_name: (boolean) choose whether to load 'display_name' field as
as category name. If False or if the display_name field does not exist, category name. If False or if the display_name field does not exist, uses
uses 'name' field as category names instead. 'name' field as category names instead.
Returns: Returns:
categories: a list of dictionaries representing all possible categories. categories: a list of dictionaries representing all possible categories.
""" """
...@@ -107,8 +107,9 @@ def convert_label_map_to_categories(label_map, ...@@ -107,8 +107,9 @@ def convert_label_map_to_categories(label_map,
return categories return categories
for item in label_map.item: for item in label_map.item:
if not 0 < item.id <= max_num_classes: if not 0 < item.id <= max_num_classes:
logging.info('Ignore item %d since it falls outside of requested ' logging.info(
'label range.', item.id) 'Ignore item %d since it falls outside of requested '
'label range.', item.id)
continue continue
if use_display_name and item.HasField('display_name'): if use_display_name and item.HasField('display_name'):
name = item.display_name name = item.display_name
...@@ -188,20 +189,44 @@ def get_label_map_dict(label_map_path, ...@@ -188,20 +189,44 @@ def get_label_map_dict(label_map_path,
return label_map_dict return label_map_dict
def create_category_index_from_labelmap(label_map_path): def create_categories_from_labelmap(label_map_path, use_display_name=True):
"""Reads a label map and returns categories list compatible with eval.
This function converts label map proto and returns a list of dicts, each of
which has the following keys:
'id': an integer id uniquely identifying this category.
'name': string representing category name e.g., 'cat', 'dog'.
Args:
label_map_path: Path to `StringIntLabelMap` proto text file.
use_display_name: (boolean) choose whether to load 'display_name' field
as category name. If False or if the display_name field does not exist,
uses 'name' field as category names instead.
Returns:
categories: a list of dictionaries representing all possible categories.
"""
label_map = load_labelmap(label_map_path)
max_num_classes = max(item.id for item in label_map.item)
return convert_label_map_to_categories(label_map, max_num_classes,
use_display_name)
def create_category_index_from_labelmap(label_map_path, use_display_name=True):
"""Reads a label map and returns a category index. """Reads a label map and returns a category index.
Args: Args:
label_map_path: Path to `StringIntLabelMap` proto text file. label_map_path: Path to `StringIntLabelMap` proto text file.
use_display_name: (boolean) choose whether to load 'display_name' field
as category name. If False or if the display_name field does not exist,
uses 'name' field as category names instead.
Returns: Returns:
A category index, which is a dictionary that maps integer ids to dicts A category index, which is a dictionary that maps integer ids to dicts
containing categories, e.g. containing categories, e.g.
{1: {'id': 1, 'name': 'dog'}, 2: {'id': 2, 'name': 'cat'}, ...} {1: {'id': 1, 'name': 'dog'}, 2: {'id': 2, 'name': 'cat'}, ...}
""" """
label_map = load_labelmap(label_map_path) categories = create_categories_from_labelmap(label_map_path, use_display_name)
max_num_classes = max(item.id for item in label_map.item)
categories = convert_label_map_to_categories(label_map, max_num_classes)
return create_category_index(categories) return create_category_index(categories)
......
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Tests for object_detection.utils.label_map_util.""" """Tests for object_detection.utils.label_map_util."""
import os import os
...@@ -189,7 +188,7 @@ class LabelMapUtilTest(tf.test.TestCase): ...@@ -189,7 +188,7 @@ class LabelMapUtilTest(tf.test.TestCase):
}] }]
self.assertListEqual(expected_categories_list, categories) self.assertListEqual(expected_categories_list, categories)
def test_convert_label_map_to_coco_categories(self): def test_convert_label_map_to_categories(self):
label_map_proto = self._generate_label_map(num_classes=4) label_map_proto = self._generate_label_map(num_classes=4)
categories = label_map_util.convert_label_map_to_categories( categories = label_map_util.convert_label_map_to_categories(
label_map_proto, max_num_classes=3) label_map_proto, max_num_classes=3)
...@@ -205,7 +204,7 @@ class LabelMapUtilTest(tf.test.TestCase): ...@@ -205,7 +204,7 @@ class LabelMapUtilTest(tf.test.TestCase):
}] }]
self.assertListEqual(expected_categories_list, categories) self.assertListEqual(expected_categories_list, categories)
def test_convert_label_map_to_coco_categories_with_few_classes(self): def test_convert_label_map_to_categories_with_few_classes(self):
label_map_proto = self._generate_label_map(num_classes=4) label_map_proto = self._generate_label_map(num_classes=4)
cat_no_offset = label_map_util.convert_label_map_to_categories( cat_no_offset = label_map_util.convert_label_map_to_categories(
label_map_proto, max_num_classes=2) label_map_proto, max_num_classes=2)
...@@ -238,6 +237,30 @@ class LabelMapUtilTest(tf.test.TestCase): ...@@ -238,6 +237,30 @@ class LabelMapUtilTest(tf.test.TestCase):
} }
}, category_index) }, category_index)
def test_create_categories_from_labelmap(self):
label_map_string = """
item {
id:1
name:'dog'
}
item {
id:2
name:'cat'
}
"""
label_map_path = os.path.join(self.get_temp_dir(), 'label_map.pbtxt')
with tf.gfile.Open(label_map_path, 'wb') as f:
f.write(label_map_string)
categories = label_map_util.create_categories_from_labelmap(label_map_path)
self.assertListEqual([{
'name': u'dog',
'id': 1
}, {
'name': u'cat',
'id': 2
}], categories)
def test_create_category_index_from_labelmap(self): def test_create_category_index_from_labelmap(self):
label_map_string = """ label_map_string = """
item { item {
...@@ -266,6 +289,46 @@ class LabelMapUtilTest(tf.test.TestCase): ...@@ -266,6 +289,46 @@ class LabelMapUtilTest(tf.test.TestCase):
} }
}, category_index) }, category_index)
def test_create_category_index_from_labelmap_display(self):
label_map_string = """
item {
id:2
name:'cat'
display_name:'meow'
}
item {
id:1
name:'dog'
display_name:'woof'
}
"""
label_map_path = os.path.join(self.get_temp_dir(), 'label_map.pbtxt')
with tf.gfile.Open(label_map_path, 'wb') as f:
f.write(label_map_string)
self.assertDictEqual({
1: {
'name': u'dog',
'id': 1
},
2: {
'name': u'cat',
'id': 2
}
}, label_map_util.create_category_index_from_labelmap(
label_map_path, False))
self.assertDictEqual({
1: {
'name': u'woof',
'id': 1
},
2: {
'name': u'meow',
'id': 2
}
}, label_map_util.create_category_index_from_labelmap(label_map_path))
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()
...@@ -160,6 +160,9 @@ def pad_to_multiple(tensor, multiple): ...@@ -160,6 +160,9 @@ def pad_to_multiple(tensor, multiple):
Returns: Returns:
padded_tensor: the tensor zero padded to the specified multiple. padded_tensor: the tensor zero padded to the specified multiple.
""" """
if multiple == 1:
return tensor
tensor_shape = tensor.get_shape() tensor_shape = tensor.get_shape()
batch_size = static_shape.get_batch_size(tensor_shape) batch_size = static_shape.get_batch_size(tensor_shape)
tensor_height = static_shape.get_height(tensor_shape) tensor_height = static_shape.get_height(tensor_shape)
...@@ -697,8 +700,11 @@ def position_sensitive_crop_regions(image, ...@@ -697,8 +700,11 @@ def position_sensitive_crop_regions(image,
image_crops = [] image_crops = []
for (split, box) in zip(image_splits, position_sensitive_boxes): for (split, box) in zip(image_splits, position_sensitive_boxes):
if split.shape.is_fully_defined() and box.shape.is_fully_defined(): if split.shape.is_fully_defined() and box.shape.is_fully_defined():
crop = matmul_crop_and_resize( crop = tf.squeeze(
tf.expand_dims(split, 0), box, bin_crop_size) matmul_crop_and_resize(
tf.expand_dims(split, axis=0), tf.expand_dims(box, axis=0),
bin_crop_size),
axis=0)
else: else:
crop = tf.image.crop_and_resize( crop = tf.image.crop_and_resize(
tf.expand_dims(split, 0), box, tf.expand_dims(split, 0), box,
...@@ -785,50 +791,85 @@ def reframe_box_masks_to_image_masks(box_masks, boxes, image_height, ...@@ -785,50 +791,85 @@ def reframe_box_masks_to_image_masks(box_masks, boxes, image_height,
return tf.squeeze(image_masks, axis=3) return tf.squeeze(image_masks, axis=3)
def merge_boxes_with_multiple_labels(boxes, classes, num_classes): def merge_boxes_with_multiple_labels(boxes,
classes,
confidences,
num_classes,
quantization_bins=10000):
"""Merges boxes with same coordinates and returns K-hot encoded classes. """Merges boxes with same coordinates and returns K-hot encoded classes.
Args: Args:
boxes: A tf.float32 tensor with shape [N, 4] holding N boxes. boxes: A tf.float32 tensor with shape [N, 4] holding N boxes. Only
normalized coordinates are allowed.
classes: A tf.int32 tensor with shape [N] holding class indices. classes: A tf.int32 tensor with shape [N] holding class indices.
The class index starts at 0. The class index starts at 0.
confidences: A tf.float32 tensor with shape [N] holding class confidences.
num_classes: total number of classes to use for K-hot encoding. num_classes: total number of classes to use for K-hot encoding.
quantization_bins: the number of bins used to quantize the box coordinate.
Returns: Returns:
merged_boxes: A tf.float32 tensor with shape [N', 4] holding boxes, merged_boxes: A tf.float32 tensor with shape [N', 4] holding boxes,
where N' <= N. where N' <= N.
class_encodings: A tf.int32 tensor with shape [N', num_classes] holding class_encodings: A tf.int32 tensor with shape [N', num_classes] holding
k-hot encodings for the merged boxes. K-hot encodings for the merged boxes.
confidence_encodings: A tf.float32 tensor with shape [N', num_classes]
holding encodings of confidences for the merged boxes.
merged_box_indices: A tf.int32 tensor with shape [N'] holding original merged_box_indices: A tf.int32 tensor with shape [N'] holding original
indices of the boxes. indices of the boxes.
""" """
def merge_numpy_boxes(boxes, classes, num_classes): boxes_shape = tf.shape(boxes)
"""Python function to merge numpy boxes.""" classes_shape = tf.shape(classes)
if boxes.size < 1: confidences_shape = tf.shape(confidences)
return (np.zeros([0, 4], dtype=np.float32), box_class_shape_assert = shape_utils.assert_shape_equal_along_first_dimension(
np.zeros([0, num_classes], dtype=np.int32), boxes_shape, classes_shape)
np.zeros([0], dtype=np.int32)) box_confidence_shape_assert = (
box_to_class_indices = {} shape_utils.assert_shape_equal_along_first_dimension(
for box_index in range(boxes.shape[0]): boxes_shape, confidences_shape))
box = tuple(boxes[box_index, :].tolist()) box_dimension_assert = tf.assert_equal(boxes_shape[1], 4)
class_index = classes[box_index] box_normalized_assert = shape_utils.assert_box_normalized(boxes)
if box not in box_to_class_indices:
box_to_class_indices[box] = [box_index, np.zeros([num_classes])] with tf.control_dependencies(
box_to_class_indices[box][1][class_index] = 1 [box_class_shape_assert, box_confidence_shape_assert,
merged_boxes = np.vstack(box_to_class_indices.keys()).astype(np.float32) box_dimension_assert, box_normalized_assert]):
class_encodings = [item[1] for item in box_to_class_indices.values()] quantized_boxes = tf.to_int64(boxes * (quantization_bins - 1))
class_encodings = np.vstack(class_encodings).astype(np.int32) ymin, xmin, ymax, xmax = tf.unstack(quantized_boxes, axis=1)
merged_box_indices = [item[0] for item in box_to_class_indices.values()] hashcodes = (
merged_box_indices = np.array(merged_box_indices).astype(np.int32) ymin +
return merged_boxes, class_encodings, merged_box_indices xmin * quantization_bins +
ymax * quantization_bins * quantization_bins +
merged_boxes, class_encodings, merged_box_indices = tf.py_func( xmax * quantization_bins * quantization_bins * quantization_bins)
merge_numpy_boxes, [boxes, classes, num_classes], unique_hashcodes, unique_indices = tf.unique(hashcodes)
[tf.float32, tf.int32, tf.int32]) num_boxes = tf.shape(boxes)[0]
merged_boxes = tf.reshape(merged_boxes, [-1, 4]) num_unique_boxes = tf.shape(unique_hashcodes)[0]
class_encodings = tf.reshape(class_encodings, [-1, num_classes]) merged_box_indices = tf.unsorted_segment_min(
merged_box_indices = tf.reshape(merged_box_indices, [-1]) tf.range(num_boxes), unique_indices, num_unique_boxes)
return merged_boxes, class_encodings, merged_box_indices merged_boxes = tf.gather(boxes, merged_box_indices)
def map_box_encodings(i):
"""Produces box K-hot and score encodings for each class index."""
box_mask = tf.equal(
unique_indices, i * tf.ones(num_boxes, dtype=tf.int32))
box_mask = tf.reshape(box_mask, [-1])
box_indices = tf.boolean_mask(classes, box_mask)
box_confidences = tf.boolean_mask(confidences, box_mask)
box_class_encodings = tf.sparse_to_dense(
box_indices, [num_classes], 1, validate_indices=False)
box_confidence_encodings = tf.sparse_to_dense(
box_indices, [num_classes], box_confidences, validate_indices=False)
return box_class_encodings, box_confidence_encodings
class_encodings, confidence_encodings = tf.map_fn(
map_box_encodings,
tf.range(num_unique_boxes),
back_prop=False,
dtype=(tf.int32, tf.float32))
merged_boxes = tf.reshape(merged_boxes, [-1, 4])
class_encodings = tf.reshape(class_encodings, [-1, num_classes])
confidence_encodings = tf.reshape(confidence_encodings, [-1, num_classes])
merged_box_indices = tf.reshape(merged_box_indices, [-1])
return (merged_boxes, class_encodings, confidence_encodings,
merged_box_indices)
def nearest_neighbor_upsampling(input_tensor, scale): def nearest_neighbor_upsampling(input_tensor, scale):
...@@ -895,7 +936,8 @@ def matmul_crop_and_resize(image, boxes, crop_size, scope=None): ...@@ -895,7 +936,8 @@ def matmul_crop_and_resize(image, boxes, crop_size, scope=None):
Returns a tensor with crops from the input image at positions defined at Returns a tensor with crops from the input image at positions defined at
the bounding box locations in boxes. The cropped boxes are all resized the bounding box locations in boxes. The cropped boxes are all resized
(with bilinear interpolation) to a fixed size = `[crop_height, crop_width]`. (with bilinear interpolation) to a fixed size = `[crop_height, crop_width]`.
The result is a 4-D tensor `[num_boxes, crop_height, crop_width, depth]`. The result is a 5-D tensor `[batch, num_boxes, crop_height, crop_width,
depth]`.
Running time complexity: Running time complexity:
O((# channels) * (# boxes) * (crop_size)^2 * M), where M is the number O((# channels) * (# boxes) * (crop_size)^2 * M), where M is the number
...@@ -914,14 +956,13 @@ def matmul_crop_and_resize(image, boxes, crop_size, scope=None): ...@@ -914,14 +956,13 @@ def matmul_crop_and_resize(image, boxes, crop_size, scope=None):
Args: Args:
image: A `Tensor`. Must be one of the following types: `uint8`, `int8`, image: A `Tensor`. Must be one of the following types: `uint8`, `int8`,
`int16`, `int32`, `int64`, `half`, `float32`, `float64`. `int16`, `int32`, `int64`, `half`, 'bfloat16', `float32`, `float64`.
A 4-D tensor of shape `[batch, image_height, image_width, depth]`. A 4-D tensor of shape `[batch, image_height, image_width, depth]`.
Both `image_height` and `image_width` need to be positive. Both `image_height` and `image_width` need to be positive.
boxes: A `Tensor` of type `float32`. boxes: A `Tensor` of type `float32` or 'bfloat16'.
A 2-D tensor of shape `[num_boxes, 4]`. The `i`-th row of the tensor A 3-D tensor of shape `[batch, num_boxes, 4]`. The boxes are specified in
specifies the coordinates of a box in the `box_ind[i]` image and is normalized coordinates and are of the form `[y1, x1, y2, x2]`. A
specified in normalized coordinates `[y1, x1, y2, x2]`. A normalized normalized coordinate value of `y` is mapped to the image coordinate at
coordinate value of `y` is mapped to the image coordinate at
`y * (image_height - 1)`, so as the `[0, 1]` interval of normalized image `y * (image_height - 1)`, so as the `[0, 1]` interval of normalized image
height is mapped to `[0, image_height - 1] in image height coordinates. height is mapped to `[0, image_height - 1] in image height coordinates.
We do allow y1 > y2, in which case the sampled crop is an up-down flipped We do allow y1 > y2, in which case the sampled crop is an up-down flipped
...@@ -935,14 +976,14 @@ def matmul_crop_and_resize(image, boxes, crop_size, scope=None): ...@@ -935,14 +976,14 @@ def matmul_crop_and_resize(image, boxes, crop_size, scope=None):
scope: A name for the operation (optional). scope: A name for the operation (optional).
Returns: Returns:
A 4-D tensor of shape `[num_boxes, crop_height, crop_width, depth]` A 5-D tensor of shape `[batch, num_boxes, crop_height, crop_width, depth]`
Raises: Raises:
ValueError: if image tensor does not have shape ValueError: if image tensor does not have shape
`[1, image_height, image_width, depth]` and all dimensions statically `[batch, image_height, image_width, depth]` and all dimensions statically
defined. defined.
ValueError: if boxes tensor does not have shape `[num_boxes, 4]` where ValueError: if boxes tensor does not have shape `[batch, num_boxes, 4]`
num_boxes > 0. where num_boxes > 0.
ValueError: if crop_size is not a list of two positive integers ValueError: if crop_size is not a list of two positive integers
""" """
img_shape = image.shape.as_list() img_shape = image.shape.as_list()
...@@ -953,13 +994,11 @@ def matmul_crop_and_resize(image, boxes, crop_size, scope=None): ...@@ -953,13 +994,11 @@ def matmul_crop_and_resize(image, boxes, crop_size, scope=None):
dimensions = img_shape + crop_size + boxes_shape dimensions = img_shape + crop_size + boxes_shape
if not all([isinstance(dim, int) for dim in dimensions]): if not all([isinstance(dim, int) for dim in dimensions]):
raise ValueError('all input shapes must be statically defined') raise ValueError('all input shapes must be statically defined')
if len(crop_size) != 2: if len(boxes_shape) != 3 or boxes_shape[2] != 4:
raise ValueError('`crop_size` must be a list of length 2') raise ValueError('`boxes` should have shape `[batch, num_boxes, 4]`')
if len(boxes_shape) != 2 or boxes_shape[1] != 4: if len(img_shape) != 4:
raise ValueError('`boxes` should have shape `[num_boxes, 4]`')
if len(img_shape) != 4 and img_shape[0] != 1:
raise ValueError('image should have shape ' raise ValueError('image should have shape '
'`[1, image_height, image_width, depth]`') '`[batch, image_height, image_width, depth]`')
num_crops = boxes_shape[0] num_crops = boxes_shape[0]
if not num_crops > 0: if not num_crops > 0:
raise ValueError('number of boxes must be > 0') raise ValueError('number of boxes must be > 0')
...@@ -968,43 +1007,69 @@ def matmul_crop_and_resize(image, boxes, crop_size, scope=None): ...@@ -968,43 +1007,69 @@ def matmul_crop_and_resize(image, boxes, crop_size, scope=None):
def _lin_space_weights(num, img_size): def _lin_space_weights(num, img_size):
if num > 1: if num > 1:
alpha = (img_size - 1) / float(num - 1) start_weights = tf.linspace(img_size - 1.0, 0.0, num)
indices = np.reshape(np.arange(num), (1, num)) stop_weights = img_size - 1 - start_weights
start_weights = alpha * (num - 1 - indices)
stop_weights = alpha * indices
else: else:
start_weights = num * [.5 * (img_size - 1)] start_weights = tf.constant(num * [.5 * (img_size - 1)], dtype=tf.float32)
stop_weights = num * [.5 * (img_size - 1)] stop_weights = tf.constant(num * [.5 * (img_size - 1)], dtype=tf.float32)
return (tf.constant(start_weights, dtype=tf.float32), return (start_weights, stop_weights)
tf.constant(stop_weights, dtype=tf.float32))
with tf.name_scope(scope, 'MatMulCropAndResize'): with tf.name_scope(scope, 'MatMulCropAndResize'):
y1_weights, y2_weights = _lin_space_weights(crop_size[0], img_height) y1_weights, y2_weights = _lin_space_weights(crop_size[0], img_height)
x1_weights, x2_weights = _lin_space_weights(crop_size[1], img_width) x1_weights, x2_weights = _lin_space_weights(crop_size[1], img_width)
[y1, x1, y2, x2] = tf.split(value=boxes, num_or_size_splits=4, axis=1) y1_weights = tf.cast(y1_weights, boxes.dtype)
y2_weights = tf.cast(y2_weights, boxes.dtype)
x1_weights = tf.cast(x1_weights, boxes.dtype)
x2_weights = tf.cast(x2_weights, boxes.dtype)
[y1, x1, y2, x2] = tf.unstack(boxes, axis=2)
# Pixel centers of input image and grid points along height and width # Pixel centers of input image and grid points along height and width
image_idx_h = tf.constant( image_idx_h = tf.constant(
np.reshape(np.arange(img_height), (1, 1, img_height)), dtype=tf.float32) np.reshape(np.arange(img_height), (1, 1, 1, img_height)),
dtype=boxes.dtype)
image_idx_w = tf.constant( image_idx_w = tf.constant(
np.reshape(np.arange(img_width), (1, 1, img_width)), dtype=tf.float32) np.reshape(np.arange(img_width), (1, 1, 1, img_width)),
grid_pos_h = tf.expand_dims(y1 * y1_weights + y2 * y2_weights, 2) dtype=boxes.dtype)
grid_pos_w = tf.expand_dims(x1 * x1_weights + x2 * x2_weights, 2) grid_pos_h = tf.expand_dims(
tf.einsum('ab,c->abc', y1, y1_weights) + tf.einsum(
'ab,c->abc', y2, y2_weights),
axis=3)
grid_pos_w = tf.expand_dims(
tf.einsum('ab,c->abc', x1, x1_weights) + tf.einsum(
'ab,c->abc', x2, x2_weights),
axis=3)
# Create kernel matrices of pairwise kernel evaluations between pixel # Create kernel matrices of pairwise kernel evaluations between pixel
# centers of image and grid points. # centers of image and grid points.
kernel_h = tf.nn.relu(1 - tf.abs(image_idx_h - grid_pos_h)) kernel_h = tf.nn.relu(1 - tf.abs(image_idx_h - grid_pos_h))
kernel_w = tf.nn.relu(1 - tf.abs(image_idx_w - grid_pos_w)) kernel_w = tf.nn.relu(1 - tf.abs(image_idx_w - grid_pos_w))
# TODO(jonathanhuang): investigate whether all channels can be processed # Compute matrix multiplication between the spatial dimensions of the image
# without the explicit unstack --- possibly with a permute and map_fn call. # and height-wise kernel using einsum.
result_channels = [] intermediate_image = tf.einsum('abci,aiop->abcop', kernel_h, image)
for channel in tf.unstack(image, axis=3): # Compute matrix multiplication between the spatial dimensions of the
result_channels.append( # intermediate_image and width-wise kernel using einsum.
tf.matmul( return tf.einsum('abno,abcop->abcnp', kernel_w, intermediate_image)
tf.matmul(kernel_h, tf.tile(channel, [num_crops, 1, 1])),
kernel_w, transpose_b=True))
return tf.stack(result_channels, axis=3) def native_crop_and_resize(image, boxes, crop_size, scope=None):
"""Same as `matmul_crop_and_resize` but uses tf.image.crop_and_resize."""
def get_box_inds(proposals):
proposals_shape = proposals.get_shape().as_list()
if any(dim is None for dim in proposals_shape):
proposals_shape = tf.shape(proposals)
ones_mat = tf.ones(proposals_shape[:2], dtype=tf.int32)
multiplier = tf.expand_dims(
tf.range(start=0, limit=proposals_shape[0]), 1)
return tf.reshape(ones_mat * multiplier, [-1])
with tf.name_scope(scope, 'CropAndResize'):
cropped_regions = tf.image.crop_and_resize(
image, tf.reshape(boxes, [-1] + boxes.shape.as_list()[2:]),
get_box_inds(boxes), crop_size)
final_shape = tf.concat([tf.shape(boxes)[:2],
tf.shape(cropped_regions)[1:]], axis=0)
return tf.reshape(cropped_regions, final_shape)
def expected_classification_loss_under_sampling(batch_cls_targets, cls_losses, def expected_classification_loss_under_sampling(batch_cls_targets, cls_losses,
......
...@@ -1147,36 +1147,76 @@ class MergeBoxesWithMultipleLabelsTest(tf.test.TestCase): ...@@ -1147,36 +1147,76 @@ class MergeBoxesWithMultipleLabelsTest(tf.test.TestCase):
[0.25, 0.25, 0.75, 0.75]], [0.25, 0.25, 0.75, 0.75]],
dtype=tf.float32) dtype=tf.float32)
class_indices = tf.constant([0, 4, 2], dtype=tf.int32) class_indices = tf.constant([0, 4, 2], dtype=tf.int32)
class_confidences = tf.constant([0.8, 0.2, 0.1], dtype=tf.float32)
num_classes = 5 num_classes = 5
merged_boxes, merged_classes, merged_box_indices = ( merged_boxes, merged_classes, merged_confidences, merged_box_indices = (
ops.merge_boxes_with_multiple_labels(boxes, class_indices, num_classes)) ops.merge_boxes_with_multiple_labels(
boxes, class_indices, class_confidences, num_classes))
expected_merged_boxes = np.array( expected_merged_boxes = np.array(
[[0.25, 0.25, 0.75, 0.75], [0.0, 0.0, 0.5, 0.75]], dtype=np.float32) [[0.25, 0.25, 0.75, 0.75], [0.0, 0.0, 0.5, 0.75]], dtype=np.float32)
expected_merged_classes = np.array( expected_merged_classes = np.array(
[[1, 0, 1, 0, 0], [0, 0, 0, 0, 1]], dtype=np.int32) [[1, 0, 1, 0, 0], [0, 0, 0, 0, 1]], dtype=np.int32)
expected_merged_confidences = np.array(
[[0.8, 0, 0.1, 0, 0], [0, 0, 0, 0, 0.2]], dtype=np.float32)
expected_merged_box_indices = np.array([0, 1], dtype=np.int32) expected_merged_box_indices = np.array([0, 1], dtype=np.int32)
with self.test_session() as sess: with self.test_session() as sess:
np_merged_boxes, np_merged_classes, np_merged_box_indices = sess.run( (np_merged_boxes, np_merged_classes, np_merged_confidences,
[merged_boxes, merged_classes, merged_box_indices]) np_merged_box_indices) = sess.run(
if np_merged_classes[0, 0] != 1: [merged_boxes, merged_classes, merged_confidences,
expected_merged_boxes = expected_merged_boxes[::-1, :] merged_box_indices])
expected_merged_classes = expected_merged_classes[::-1, :]
expected_merged_box_indices = expected_merged_box_indices[::-1, :]
self.assertAllClose(np_merged_boxes, expected_merged_boxes) self.assertAllClose(np_merged_boxes, expected_merged_boxes)
self.assertAllClose(np_merged_classes, expected_merged_classes) self.assertAllClose(np_merged_classes, expected_merged_classes)
self.assertAllClose(np_merged_confidences, expected_merged_confidences)
self.assertAllClose(np_merged_box_indices, expected_merged_box_indices)
def testMergeBoxesWithMultipleLabelsCornerCase(self):
boxes = tf.constant(
[[0, 0, 1, 1], [0, 1, 1, 1], [1, 0, 1, 1], [1, 1, 1, 1],
[1, 1, 1, 1], [1, 0, 1, 1], [0, 1, 1, 1], [0, 0, 1, 1]],
dtype=tf.float32)
class_indices = tf.constant([0, 1, 2, 3, 2, 1, 0, 3], dtype=tf.int32)
class_confidences = tf.constant([0.1, 0.9, 0.2, 0.8, 0.3, 0.7, 0.4, 0.6],
dtype=tf.float32)
num_classes = 4
merged_boxes, merged_classes, merged_confidences, merged_box_indices = (
ops.merge_boxes_with_multiple_labels(
boxes, class_indices, class_confidences, num_classes))
expected_merged_boxes = np.array(
[[0, 0, 1, 1], [0, 1, 1, 1], [1, 0, 1, 1], [1, 1, 1, 1]],
dtype=np.float32)
expected_merged_classes = np.array(
[[1, 0, 0, 1], [1, 1, 0, 0], [0, 1, 1, 0], [0, 0, 1, 1]],
dtype=np.int32)
expected_merged_confidences = np.array(
[[0.1, 0, 0, 0.6], [0.4, 0.9, 0, 0],
[0, 0.7, 0.2, 0], [0, 0, 0.3, 0.8]], dtype=np.float32)
expected_merged_box_indices = np.array([0, 1, 2, 3], dtype=np.int32)
with self.test_session() as sess:
(np_merged_boxes, np_merged_classes, np_merged_confidences,
np_merged_box_indices) = sess.run(
[merged_boxes, merged_classes, merged_confidences,
merged_box_indices])
self.assertAllClose(np_merged_boxes, expected_merged_boxes)
self.assertAllClose(np_merged_classes, expected_merged_classes)
self.assertAllClose(np_merged_confidences, expected_merged_confidences)
self.assertAllClose(np_merged_box_indices, expected_merged_box_indices) self.assertAllClose(np_merged_box_indices, expected_merged_box_indices)
def testMergeBoxesWithEmptyInputs(self): def testMergeBoxesWithEmptyInputs(self):
boxes = tf.constant([[]]) boxes = tf.zeros([0, 4], dtype=tf.float32)
class_indices = tf.constant([]) class_indices = tf.constant([], dtype=tf.int32)
class_confidences = tf.constant([], dtype=tf.float32)
num_classes = 5 num_classes = 5
merged_boxes, merged_classes, merged_box_indices = ( merged_boxes, merged_classes, merged_confidences, merged_box_indices = (
ops.merge_boxes_with_multiple_labels(boxes, class_indices, num_classes)) ops.merge_boxes_with_multiple_labels(
boxes, class_indices, class_confidences, num_classes))
with self.test_session() as sess: with self.test_session() as sess:
np_merged_boxes, np_merged_classes, np_merged_box_indices = sess.run( (np_merged_boxes, np_merged_classes, np_merged_confidences,
[merged_boxes, merged_classes, merged_box_indices]) np_merged_box_indices) = sess.run(
[merged_boxes, merged_classes, merged_confidences,
merged_box_indices])
self.assertAllEqual(np_merged_boxes.shape, [0, 4]) self.assertAllEqual(np_merged_boxes.shape, [0, 4])
self.assertAllEqual(np_merged_classes.shape, [0, 5]) self.assertAllEqual(np_merged_classes.shape, [0, 5])
self.assertAllEqual(np_merged_confidences.shape, [0, 5])
self.assertAllEqual(np_merged_box_indices.shape, [0]) self.assertAllEqual(np_merged_box_indices.shape, [0])
...@@ -1268,8 +1308,8 @@ class OpsTestMatMulCropAndResize(test_case.TestCase): ...@@ -1268,8 +1308,8 @@ class OpsTestMatMulCropAndResize(test_case.TestCase):
return ops.matmul_crop_and_resize(image, boxes, crop_size=[1, 1]) return ops.matmul_crop_and_resize(image, boxes, crop_size=[1, 1])
image = np.array([[[[1], [2]], [[3], [4]]]], dtype=np.float32) image = np.array([[[[1], [2]], [[3], [4]]]], dtype=np.float32)
boxes = np.array([[0, 0, 1, 1]], dtype=np.float32) boxes = np.array([[[0, 0, 1, 1]]], dtype=np.float32)
expected_output = [[[[2.5]]]] expected_output = [[[[[2.5]]]]]
crop_output = self.execute(graph_fn, [image, boxes]) crop_output = self.execute(graph_fn, [image, boxes])
self.assertAllClose(crop_output, expected_output) self.assertAllClose(crop_output, expected_output)
...@@ -1279,8 +1319,8 @@ class OpsTestMatMulCropAndResize(test_case.TestCase): ...@@ -1279,8 +1319,8 @@ class OpsTestMatMulCropAndResize(test_case.TestCase):
return ops.matmul_crop_and_resize(image, boxes, crop_size=[1, 1]) return ops.matmul_crop_and_resize(image, boxes, crop_size=[1, 1])
image = np.array([[[[1], [2]], [[3], [4]]]], dtype=np.float32) image = np.array([[[[1], [2]], [[3], [4]]]], dtype=np.float32)
boxes = np.array([[1, 1, 0, 0]], dtype=np.float32) boxes = np.array([[[1, 1, 0, 0]]], dtype=np.float32)
expected_output = [[[[2.5]]]] expected_output = [[[[[2.5]]]]]
crop_output = self.execute(graph_fn, [image, boxes]) crop_output = self.execute(graph_fn, [image, boxes])
self.assertAllClose(crop_output, expected_output) self.assertAllClose(crop_output, expected_output)
...@@ -1290,10 +1330,10 @@ class OpsTestMatMulCropAndResize(test_case.TestCase): ...@@ -1290,10 +1330,10 @@ class OpsTestMatMulCropAndResize(test_case.TestCase):
return ops.matmul_crop_and_resize(image, boxes, crop_size=[3, 3]) return ops.matmul_crop_and_resize(image, boxes, crop_size=[3, 3])
image = np.array([[[[1], [2]], [[3], [4]]]], dtype=np.float32) image = np.array([[[[1], [2]], [[3], [4]]]], dtype=np.float32)
boxes = np.array([[0, 0, 1, 1]], dtype=np.float32) boxes = np.array([[[0, 0, 1, 1]]], dtype=np.float32)
expected_output = [[[[1.0], [1.5], [2.0]], expected_output = [[[[[1.0], [1.5], [2.0]],
[[2.0], [2.5], [3.0]], [[2.0], [2.5], [3.0]],
[[3.0], [3.5], [4.0]]]] [[3.0], [3.5], [4.0]]]]]
crop_output = self.execute(graph_fn, [image, boxes]) crop_output = self.execute(graph_fn, [image, boxes])
self.assertAllClose(crop_output, expected_output) self.assertAllClose(crop_output, expected_output)
...@@ -1303,10 +1343,10 @@ class OpsTestMatMulCropAndResize(test_case.TestCase): ...@@ -1303,10 +1343,10 @@ class OpsTestMatMulCropAndResize(test_case.TestCase):
return ops.matmul_crop_and_resize(image, boxes, crop_size=[3, 3]) return ops.matmul_crop_and_resize(image, boxes, crop_size=[3, 3])
image = np.array([[[[1], [2]], [[3], [4]]]], dtype=np.float32) image = np.array([[[[1], [2]], [[3], [4]]]], dtype=np.float32)
boxes = np.array([[1, 1, 0, 0]], dtype=np.float32) boxes = np.array([[[1, 1, 0, 0]]], dtype=np.float32)
expected_output = [[[[4.0], [3.5], [3.0]], expected_output = [[[[[4.0], [3.5], [3.0]],
[[3.0], [2.5], [2.0]], [[3.0], [2.5], [2.0]],
[[2.0], [1.5], [1.0]]]] [[2.0], [1.5], [1.0]]]]]
crop_output = self.execute(graph_fn, [image, boxes]) crop_output = self.execute(graph_fn, [image, boxes])
self.assertAllClose(crop_output, expected_output) self.assertAllClose(crop_output, expected_output)
...@@ -1318,14 +1358,14 @@ class OpsTestMatMulCropAndResize(test_case.TestCase): ...@@ -1318,14 +1358,14 @@ class OpsTestMatMulCropAndResize(test_case.TestCase):
image = np.array([[[[1], [2], [3]], image = np.array([[[[1], [2], [3]],
[[4], [5], [6]], [[4], [5], [6]],
[[7], [8], [9]]]], dtype=np.float32) [[7], [8], [9]]]], dtype=np.float32)
boxes = np.array([[0, 0, 1, 1], boxes = np.array([[[0, 0, 1, 1],
[0, 0, .5, .5]], dtype=np.float32) [0, 0, .5, .5]]], dtype=np.float32)
expected_output = [[[[1], [3]], [[7], [9]]], expected_output = [[[[[1], [3]], [[7], [9]]],
[[[1], [2]], [[4], [5]]]] [[[1], [2]], [[4], [5]]]]]
crop_output = self.execute(graph_fn, [image, boxes]) crop_output = self.execute(graph_fn, [image, boxes])
self.assertAllClose(crop_output, expected_output) self.assertAllClose(crop_output, expected_output)
def testMatMulCropAndResize3x3To2x2MultiChannel(self): def testMatMulCropAndResize3x3To2x2_2Channels(self):
def graph_fn(image, boxes): def graph_fn(image, boxes):
return ops.matmul_crop_and_resize(image, boxes, crop_size=[2, 2]) return ops.matmul_crop_and_resize(image, boxes, crop_size=[2, 2])
...@@ -1333,10 +1373,32 @@ class OpsTestMatMulCropAndResize(test_case.TestCase): ...@@ -1333,10 +1373,32 @@ class OpsTestMatMulCropAndResize(test_case.TestCase):
image = np.array([[[[1, 0], [2, 1], [3, 2]], image = np.array([[[[1, 0], [2, 1], [3, 2]],
[[4, 3], [5, 4], [6, 5]], [[4, 3], [5, 4], [6, 5]],
[[7, 6], [8, 7], [9, 8]]]], dtype=np.float32) [[7, 6], [8, 7], [9, 8]]]], dtype=np.float32)
boxes = np.array([[0, 0, 1, 1], boxes = np.array([[[0, 0, 1, 1],
[0, 0, .5, .5]], dtype=np.float32) [0, 0, .5, .5]]], dtype=np.float32)
expected_output = [[[[1, 0], [3, 2]], [[7, 6], [9, 8]]], expected_output = [[[[[1, 0], [3, 2]], [[7, 6], [9, 8]]],
[[[1, 0], [2, 1]], [[4, 3], [5, 4]]]] [[[1, 0], [2, 1]], [[4, 3], [5, 4]]]]]
crop_output = self.execute(graph_fn, [image, boxes])
self.assertAllClose(crop_output, expected_output)
def testBatchMatMulCropAndResize3x3To2x2_2Channels(self):
def graph_fn(image, boxes):
return ops.matmul_crop_and_resize(image, boxes, crop_size=[2, 2])
image = np.array([[[[1, 0], [2, 1], [3, 2]],
[[4, 3], [5, 4], [6, 5]],
[[7, 6], [8, 7], [9, 8]]],
[[[1, 0], [2, 1], [3, 2]],
[[4, 3], [5, 4], [6, 5]],
[[7, 6], [8, 7], [9, 8]]]], dtype=np.float32)
boxes = np.array([[[0, 0, 1, 1],
[0, 0, .5, .5]],
[[1, 1, 0, 0],
[.5, .5, 0, 0]]], dtype=np.float32)
expected_output = [[[[[1, 0], [3, 2]], [[7, 6], [9, 8]]],
[[[1, 0], [2, 1]], [[4, 3], [5, 4]]]],
[[[[9, 8], [7, 6]], [[3, 2], [1, 0]]],
[[[5, 4], [4, 3]], [[2, 1], [1, 0]]]]]
crop_output = self.execute(graph_fn, [image, boxes]) crop_output = self.execute(graph_fn, [image, boxes])
self.assertAllClose(crop_output, expected_output) self.assertAllClose(crop_output, expected_output)
...@@ -1348,10 +1410,10 @@ class OpsTestMatMulCropAndResize(test_case.TestCase): ...@@ -1348,10 +1410,10 @@ class OpsTestMatMulCropAndResize(test_case.TestCase):
image = np.array([[[[1], [2], [3]], image = np.array([[[[1], [2], [3]],
[[4], [5], [6]], [[4], [5], [6]],
[[7], [8], [9]]]], dtype=np.float32) [[7], [8], [9]]]], dtype=np.float32)
boxes = np.array([[1, 1, 0, 0], boxes = np.array([[[1, 1, 0, 0],
[.5, .5, 0, 0]], dtype=np.float32) [.5, .5, 0, 0]]], dtype=np.float32)
expected_output = [[[[9], [7]], [[3], [1]]], expected_output = [[[[[9], [7]], [[3], [1]]],
[[[5], [4]], [[2], [1]]]] [[[5], [4]], [[2], [1]]]]]
crop_output = self.execute(graph_fn, [image, boxes]) crop_output = self.execute(graph_fn, [image, boxes])
self.assertAllClose(crop_output, expected_output) self.assertAllClose(crop_output, expected_output)
...@@ -1363,6 +1425,31 @@ class OpsTestMatMulCropAndResize(test_case.TestCase): ...@@ -1363,6 +1425,31 @@ class OpsTestMatMulCropAndResize(test_case.TestCase):
_ = ops.matmul_crop_and_resize(image, boxes, crop_size) _ = ops.matmul_crop_and_resize(image, boxes, crop_size)
class OpsTestCropAndResize(test_case.TestCase):
  """Tests for ops.native_crop_and_resize."""

  def testBatchCropAndResize3x3To2x2_2Channels(self):
    """Crops two boxes from each image in a batch of two 2-channel images."""

    def graph_fn(image, boxes):
      # Resize every crop to a 2x2 output patch.
      return ops.native_crop_and_resize(image, boxes, crop_size=[2, 2])

    # A single 3x3 image with two channels; the batch repeats it twice.
    single_image = np.array([[[1, 0], [2, 1], [3, 2]],
                             [[4, 3], [5, 4], [6, 5]],
                             [[7, 6], [8, 7], [9, 8]]], dtype=np.float32)
    image = np.stack([single_image, single_image])
    # Two boxes per batch element. The second image's boxes have swapped
    # corners, which mirrors the resulting crops in both directions.
    boxes = np.array([[[0, 0, 1, 1],
                       [0, 0, .5, .5]],
                      [[1, 1, 0, 0],
                       [.5, .5, 0, 0]]], dtype=np.float32)
    expected_output = [[[[[1, 0], [3, 2]], [[7, 6], [9, 8]]],
                        [[[1, 0], [2, 1]], [[4, 3], [5, 4]]]],
                       [[[[9, 8], [7, 6]], [[3, 2], [1, 0]]],
                        [[[5, 4], [4, 3]], [[2, 1], [1, 0]]]]]
    crop_output = self.execute_cpu(graph_fn, [image, boxes])
    self.assertAllClose(crop_output, expected_output)
class OpsTestExpectedClassificationLoss(test_case.TestCase): class OpsTestExpectedClassificationLoss(test_case.TestCase):
def testExpectedClassificationLossUnderSamplingWithHardLabels(self): def testExpectedClassificationLossUnderSamplingWithHardLabels(self):
......
...@@ -342,3 +342,26 @@ def assert_shape_equal_along_first_dimension(shape_a, shape_b): ...@@ -342,3 +342,26 @@ def assert_shape_equal_along_first_dimension(shape_a, shape_b):
else: return tf.no_op() else: return tf.no_op()
else: else:
return tf.assert_equal(shape_a[0], shape_b[0]) return tf.assert_equal(shape_a[0], shape_b[0])
def assert_box_normalized(boxes, maximum_normalized_coordinate=1.1):
  """Builds an op asserting that all box coordinates are normalized.

  A box tensor is considered normalized when every coordinate lies in
  [0, maximum_normalized_coordinate].

  Args:
    boxes: a tensor of shape [N, 4] where N is the number of boxes.
    maximum_normalized_coordinate: Maximum coordinate value to be considered
      as normalized, default to 1.1.

  Returns:
    a tf.Assert op which fails when the input box tensor is not normalized.
  """
  # Checking the global min/max covers every coordinate of every box.
  coordinates_in_range = tf.logical_and(
      tf.greater_equal(tf.reduce_min(boxes), 0),
      tf.less_equal(tf.reduce_max(boxes), maximum_normalized_coordinate))
  return tf.Assert(coordinates_in_range, [boxes])
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
# ============================================================================== # ==============================================================================
"""A convenience wrapper around tf.test.TestCase to enable TPU tests.""" """A convenience wrapper around tf.test.TestCase to enable TPU tests."""
import os
import tensorflow as tf import tensorflow as tf
from tensorflow.contrib import tpu from tensorflow.contrib import tpu
...@@ -23,6 +24,8 @@ flags.DEFINE_bool('tpu_test', False, 'Whether to configure test for TPU.') ...@@ -23,6 +24,8 @@ flags.DEFINE_bool('tpu_test', False, 'Whether to configure test for TPU.')
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
class TestCase(tf.test.TestCase): class TestCase(tf.test.TestCase):
"""Extends tf.test.TestCase to optionally allow running tests on TPU.""" """Extends tf.test.TestCase to optionally allow running tests on TPU."""
......
...@@ -24,6 +24,9 @@ from object_detection.core import box_predictor ...@@ -24,6 +24,9 @@ from object_detection.core import box_predictor
from object_detection.core import matcher from object_detection.core import matcher
from object_detection.utils import shape_utils from object_detection.utils import shape_utils
# Default size (both width and height) used for testing mask predictions.
DEFAULT_MASK_SIZE = 5
class MockBoxCoder(box_coder.BoxCoder): class MockBoxCoder(box_coder.BoxCoder):
"""Simple `difference` BoxCoder.""" """Simple `difference` BoxCoder."""
...@@ -42,8 +45,9 @@ class MockBoxCoder(box_coder.BoxCoder): ...@@ -42,8 +45,9 @@ class MockBoxCoder(box_coder.BoxCoder):
class MockBoxPredictor(box_predictor.BoxPredictor): class MockBoxPredictor(box_predictor.BoxPredictor):
"""Simple box predictor that ignores inputs and outputs all zeros.""" """Simple box predictor that ignores inputs and outputs all zeros."""
def __init__(self, is_training, num_classes): def __init__(self, is_training, num_classes, predict_mask=False):
super(MockBoxPredictor, self).__init__(is_training, num_classes) super(MockBoxPredictor, self).__init__(is_training, num_classes)
self._predict_mask = predict_mask
def _predict(self, image_features, num_predictions_per_location): def _predict(self, image_features, num_predictions_per_location):
image_feature = image_features[0] image_feature = image_features[0]
...@@ -57,17 +61,29 @@ class MockBoxPredictor(box_predictor.BoxPredictor): ...@@ -57,17 +61,29 @@ class MockBoxPredictor(box_predictor.BoxPredictor):
(batch_size, num_anchors, 1, code_size), dtype=tf.float32) (batch_size, num_anchors, 1, code_size), dtype=tf.float32)
class_predictions_with_background = zero + tf.zeros( class_predictions_with_background = zero + tf.zeros(
(batch_size, num_anchors, self.num_classes + 1), dtype=tf.float32) (batch_size, num_anchors, self.num_classes + 1), dtype=tf.float32)
return {box_predictor.BOX_ENCODINGS: box_encodings, masks = zero + tf.zeros(
box_predictor.CLASS_PREDICTIONS_WITH_BACKGROUND: (batch_size, num_anchors, self.num_classes, DEFAULT_MASK_SIZE,
class_predictions_with_background} DEFAULT_MASK_SIZE),
dtype=tf.float32)
predictions_dict = {
box_predictor.BOX_ENCODINGS:
box_encodings,
box_predictor.CLASS_PREDICTIONS_WITH_BACKGROUND:
class_predictions_with_background
}
if self._predict_mask:
predictions_dict[box_predictor.MASK_PREDICTIONS] = masks
return predictions_dict
class MockKerasBoxPredictor(box_predictor.KerasBoxPredictor): class MockKerasBoxPredictor(box_predictor.KerasBoxPredictor):
"""Simple box predictor that ignores inputs and outputs all zeros.""" """Simple box predictor that ignores inputs and outputs all zeros."""
def __init__(self, is_training, num_classes): def __init__(self, is_training, num_classes, predict_mask=False):
super(MockKerasBoxPredictor, self).__init__( super(MockKerasBoxPredictor, self).__init__(
is_training, num_classes, False, False) is_training, num_classes, False, False)
self._predict_mask = predict_mask
def _predict(self, image_features, **kwargs): def _predict(self, image_features, **kwargs):
image_feature = image_features[0] image_feature = image_features[0]
...@@ -81,9 +97,19 @@ class MockKerasBoxPredictor(box_predictor.KerasBoxPredictor): ...@@ -81,9 +97,19 @@ class MockKerasBoxPredictor(box_predictor.KerasBoxPredictor):
(batch_size, num_anchors, 1, code_size), dtype=tf.float32) (batch_size, num_anchors, 1, code_size), dtype=tf.float32)
class_predictions_with_background = zero + tf.zeros( class_predictions_with_background = zero + tf.zeros(
(batch_size, num_anchors, self.num_classes + 1), dtype=tf.float32) (batch_size, num_anchors, self.num_classes + 1), dtype=tf.float32)
return {box_predictor.BOX_ENCODINGS: box_encodings, masks = zero + tf.zeros(
box_predictor.CLASS_PREDICTIONS_WITH_BACKGROUND: (batch_size, num_anchors, self.num_classes, DEFAULT_MASK_SIZE,
class_predictions_with_background} DEFAULT_MASK_SIZE),
dtype=tf.float32)
predictions_dict = {
box_predictor.BOX_ENCODINGS:
box_encodings,
box_predictor.CLASS_PREDICTIONS_WITH_BACKGROUND:
class_predictions_with_background
}
if self._predict_mask:
predictions_dict[box_predictor.MASK_PREDICTIONS] = masks
return predictions_dict
class MockAnchorGenerator(anchor_generator.AnchorGenerator): class MockAnchorGenerator(anchor_generator.AnchorGenerator):
...@@ -103,7 +129,7 @@ class MockAnchorGenerator(anchor_generator.AnchorGenerator): ...@@ -103,7 +129,7 @@ class MockAnchorGenerator(anchor_generator.AnchorGenerator):
class MockMatcher(matcher.Matcher): class MockMatcher(matcher.Matcher):
"""Simple matcher that matches first anchor to first groundtruth box.""" """Simple matcher that matches first anchor to first groundtruth box."""
def _match(self, similarity_matrix): def _match(self, similarity_matrix, valid_rows):
return tf.constant([0, -1, -1, -1], dtype=tf.int32) return tf.constant([0, -1, -1, -1], dtype=tf.int32)
......
...@@ -19,6 +19,8 @@ These functions often receive an image, perform some visualization on the image. ...@@ -19,6 +19,8 @@ These functions often receive an image, perform some visualization on the image.
The functions do not return a value, instead they modify the image itself. The functions do not return a value, instead they modify the image itself.
""" """
from abc import ABCMeta
from abc import abstractmethod
import collections import collections
import functools import functools
# Set headless-friendly backend. # Set headless-friendly backend.
...@@ -731,3 +733,158 @@ def add_hist_image_summary(values, bins, name): ...@@ -731,3 +733,158 @@ def add_hist_image_summary(values, bins, name):
return image return image
hist_plot = tf.py_func(hist_plot, [values, bins], tf.uint8) hist_plot = tf.py_func(hist_plot, [values, bins], tf.uint8)
tf.summary.image(name, hist_plot) tf.summary.image(name, hist_plot)
class EvalMetricOpsVisualization(object):
  """Abstract base class responsible for visualizations during evaluation.

  Currently, summary images are not run during evaluation. One way to produce
  evaluation images in Tensorboard is to provide tf.summary.image strings as
  `value_ops` in tf.estimator.EstimatorSpec's `eval_metric_ops`. This class is
  responsible for accruing images (with overlaid detections and groundtruth)
  and returning a dictionary that can be passed to `eval_metric_ops`.
  """
  # Python-2 style abstract base class marker; makes
  # `images_from_evaluation_dict` abstract via the decorator below.
  __metaclass__ = ABCMeta

  def __init__(self,
               category_index,
               max_examples_to_draw=5,
               max_boxes_to_draw=20,
               min_score_thresh=0.2,
               use_normalized_coordinates=True,
               summary_name_prefix='evaluation_image'):
    """Creates an EvalMetricOpsVisualization.

    Args:
      category_index: A category index (dictionary) produced from a labelmap.
      max_examples_to_draw: The maximum number of example summaries to produce.
      max_boxes_to_draw: The maximum number of boxes to draw for detections.
      min_score_thresh: The minimum score threshold for showing detections.
      use_normalized_coordinates: Whether to assume boxes and keypoints are in
        normalized coordinates (as opposed to absolute coordinates).
        Default is True.
      summary_name_prefix: A string prefix for each image summary.
    """
    self._category_index = category_index
    self._max_examples_to_draw = max_examples_to_draw
    self._max_boxes_to_draw = max_boxes_to_draw
    self._min_score_thresh = min_score_thresh
    self._use_normalized_coordinates = use_normalized_coordinates
    self._summary_name_prefix = summary_name_prefix
    # Accrued list of images (numpy arrays) collected across update_op calls.
    self._images = []

  def clear(self):
    # Drops all accrued images so a new evaluation round starts empty.
    self._images = []

  def add_images(self, images):
    """Store a list of images, each with shape [1, H, W, C]."""
    # Once the cap is reached, further images are silently dropped.
    if len(self._images) >= self._max_examples_to_draw:
      return

    # Store images and clip list if necessary.
    self._images.extend(images)
    if len(self._images) > self._max_examples_to_draw:
      self._images[self._max_examples_to_draw:] = []

  def get_estimator_eval_metric_ops(self, eval_dict):
    """Returns metric ops for use in tf.estimator.EstimatorSpec.

    Args:
      eval_dict: A dictionary that holds an image, groundtruth, and detections
        for a single example. See eval_util.result_dict_for_single_example() for
        a convenient method for constructing such a dictionary. The dictionary
        contains
        fields.InputDataFields.original_image: [1, H, W, 3] image.
        fields.InputDataFields.groundtruth_boxes - [num_boxes, 4] float32
          tensor with groundtruth boxes in range [0.0, 1.0].
        fields.InputDataFields.groundtruth_classes - [num_boxes] int64
          tensor with 1-indexed groundtruth classes.
        fields.InputDataFields.groundtruth_instance_masks - (optional)
          [num_boxes, H, W] int64 tensor with instance masks.
        fields.DetectionResultFields.detection_boxes - [max_num_boxes, 4]
          float32 tensor with detection boxes in range [0.0, 1.0].
        fields.DetectionResultFields.detection_classes - [max_num_boxes]
          int64 tensor with 1-indexed detection classes.
        fields.DetectionResultFields.detection_scores - [max_num_boxes]
          float32 tensor with detection scores.
        fields.DetectionResultFields.detection_masks - (optional)
          [max_num_boxes, H, W] float32 tensor of binarized masks.
        fields.DetectionResultFields.detection_keypoints - (optional)
          [max_num_boxes, num_keypoints, 2] float32 tensor with keypoints.

    Returns:
      A dictionary of image summary names to tuple of (value_op, update_op). The
      `update_op` is the same for all items in the dictionary, and is
      responsible for saving a single side-by-side image with detections and
      groundtruth. Each `value_op` holds the tf.summary.image string for a given
      image.
    """
    images = self.images_from_evaluation_dict(eval_dict)

    def get_images():
      """Returns a list of images, padded to self._max_examples_to_draw."""
      images = self._images
      # Pad with 0-d uint8 placeholders so py_func always returns exactly
      # `_max_examples_to_draw` tensors; placeholders are detected later by
      # their rank (a real image is rank 4).
      while len(images) < self._max_examples_to_draw:
        images.append(np.array(0, dtype=np.uint8))
      # Reading the images consumes them: state is reset for the next round.
      self.clear()
      return images

    def image_summary_or_default_string(summary_name, image):
      """Returns image summaries for non-padded elements."""
      # Rank-4 tensors are real [1, H, W, C] images; padding placeholders are
      # rank 0 and produce an empty-string summary instead.
      return tf.cond(
          tf.equal(tf.size(tf.shape(image)), 4),
          lambda: tf.summary.image(summary_name, image),
          lambda: tf.constant(''))

    # update_op accrues one image per evaluated example (bounded by the cap).
    update_op = tf.py_func(self.add_images, [images], [])
    # value_ops read back (and clear) the accrued images at summary time.
    image_tensors = tf.py_func(
        get_images, [], [tf.uint8] * self._max_examples_to_draw)
    eval_metric_ops = {}
    for i, image in enumerate(image_tensors):
      summary_name = self._summary_name_prefix + '/' + str(i)
      value_op = image_summary_or_default_string(summary_name, image)
      # NOTE: every entry shares the same update_op by design.
      eval_metric_ops[summary_name] = (value_op, update_op)
    return eval_metric_ops

  @abstractmethod
  def images_from_evaluation_dict(self, eval_dict):
    """Converts evaluation dictionary into a list of image tensors.

    To be overridden by implementations.

    Args:
      eval_dict: A dictionary with all the necessary information for producing
        visualizations.

    Returns:
      A list of [1, H, W, C] uint8 tensors.
    """
    raise NotImplementedError
class VisualizeSingleFrameDetections(EvalMetricOpsVisualization):
  """Class responsible for single-frame object detection visualizations."""

  def __init__(self,
               category_index,
               max_examples_to_draw=5,
               max_boxes_to_draw=20,
               min_score_thresh=0.2,
               use_normalized_coordinates=True,
               summary_name_prefix='Detections_Left_Groundtruth_Right'):
    """Forwards all configuration to the EvalMetricOpsVisualization base."""
    super(VisualizeSingleFrameDetections, self).__init__(
        category_index=category_index,
        max_examples_to_draw=max_examples_to_draw,
        max_boxes_to_draw=max_boxes_to_draw,
        min_score_thresh=min_score_thresh,
        use_normalized_coordinates=use_normalized_coordinates,
        summary_name_prefix=summary_name_prefix)

  def images_from_evaluation_dict(self, eval_dict):
    """Draws one side-by-side image (detections left, groundtruth right)."""
    side_by_side = draw_side_by_side_evaluation_image(
        eval_dict,
        self._category_index,
        self._max_boxes_to_draw,
        self._min_score_thresh,
        self._use_normalized_coordinates)
    return [side_by_side]
...@@ -21,6 +21,7 @@ import numpy as np ...@@ -21,6 +21,7 @@ import numpy as np
import PIL.Image as Image import PIL.Image as Image
import tensorflow as tf import tensorflow as tf
from object_detection.core import standard_fields as fields
from object_detection.utils import visualization_utils from object_detection.utils import visualization_utils
_TESTDATA_PATH = 'object_detection/test_images' _TESTDATA_PATH = 'object_detection/test_images'
...@@ -225,6 +226,80 @@ class VisualizationUtilsTest(tf.test.TestCase): ...@@ -225,6 +226,80 @@ class VisualizationUtilsTest(tf.test.TestCase):
with self.test_session(): with self.test_session():
hist_image_summary.eval() hist_image_summary.eval()
def test_eval_metric_ops(self):
  """Checks VisualizeSingleFrameDetections' eval_metric_ops end to end.

  Runs enough update steps to fill `max_examples_to_draw` (every value_op
  must then yield a non-empty summary string), then runs one step fewer and
  checks the final slot produces the empty-string placeholder.
  """
  category_index = {1: {'id': 1, 'name': 'dog'}, 2: {'id': 2, 'name': 'cat'}}
  max_examples_to_draw = 4
  metric_op_base = 'Detections_Left_Groundtruth_Right'
  eval_metric_ops = visualization_utils.VisualizeSingleFrameDetections(
      category_index,
      max_examples_to_draw=max_examples_to_draw,
      summary_name_prefix=metric_op_base)
  original_image = tf.placeholder(tf.uint8, [1, None, None, 3])
  detection_boxes = tf.random_uniform([20, 4],
                                      minval=0.0,
                                      maxval=1.0,
                                      dtype=tf.float32)
  detection_classes = tf.random_uniform([20],
                                        minval=1,
                                        maxval=3,
                                        dtype=tf.int64)
  detection_scores = tf.random_uniform([20],
                                       minval=0.,
                                       maxval=1.,
                                       dtype=tf.float32)
  groundtruth_boxes = tf.random_uniform([8, 4],
                                        minval=0.0,
                                        maxval=1.0,
                                        dtype=tf.float32)
  groundtruth_classes = tf.random_uniform([8],
                                          minval=1,
                                          maxval=3,
                                          dtype=tf.int64)
  eval_dict = {
      fields.DetectionResultFields.detection_boxes: detection_boxes,
      fields.DetectionResultFields.detection_classes: detection_classes,
      fields.DetectionResultFields.detection_scores: detection_scores,
      fields.InputDataFields.original_image: original_image,
      fields.InputDataFields.groundtruth_boxes: groundtruth_boxes,
      fields.InputDataFields.groundtruth_classes: groundtruth_classes}
  metric_ops = eval_metric_ops.get_estimator_eval_metric_ops(eval_dict)
  # All entries share the same update_op, so any key works. `keys()[0]`
  # is Python-2 only (dict views are not subscriptable in Python 3).
  _, update_op = metric_ops[next(iter(metric_ops))]

  with self.test_session() as sess:
    sess.run(tf.global_variables_initializer())
    value_ops = {}
    # `iteritems()` is Python-2 only; `items()` works on both versions.
    for key, (value_op, _) in metric_ops.items():
      value_ops[key] = value_op

    # First run enough update steps to surpass `max_examples_to_draw`.
    for i in range(max_examples_to_draw):
      # Use a unique image shape on each eval image.
      sess.run(update_op, feed_dict={
          original_image: np.random.randint(low=0,
                                            high=256,
                                            size=(1, 6 + i, 7 + i, 3),
                                            dtype=np.uint8)
      })
    value_ops_out = sess.run(value_ops)
    for key, value_op in value_ops_out.items():
      self.assertNotEqual('', value_op)

    # Now run fewer update steps than `max_examples_to_draw`. A single value
    # op will be the empty string, since not enough image summaries can be
    # produced.
    for i in range(max_examples_to_draw - 1):
      # Use a unique image shape on each eval image.
      sess.run(update_op, feed_dict={
          original_image: np.random.randint(low=0,
                                            high=256,
                                            size=(1, 6 + i, 7 + i, 3),
                                            dtype=np.uint8)
      })
    value_ops_out = sess.run(value_ops)
    self.assertEqual(
        '',
        value_ops_out[metric_op_base + '/' + str(max_examples_to_draw - 1)])
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment