Commit 27b4acd4 authored by Aman Gupta's avatar Aman Gupta
Browse files

Merge remote-tracking branch 'upstream/master'

parents 5133522f d4e1f97f
......@@ -4,17 +4,18 @@ package object_detection.protos;
// Message for configuring DetectionModel evaluation jobs (eval.py).
message EvalConfig {
optional uint32 batch_size = 25 [default=1];
// Number of visualization images to generate.
optional uint32 num_visualizations = 1 [default=10];
// Number of examples to process of evaluation.
optional uint32 num_examples = 2 [default=5000];
optional uint32 num_examples = 2 [default=5000, deprecated=true];
// How often to run evaluation.
optional uint32 eval_interval_secs = 3 [default=300];
// Maximum number of times to run evaluation. If set to 0, will run forever.
optional uint32 max_evals = 4 [default=0];
optional uint32 max_evals = 4 [default=0, deprecated=true];
// Whether the TensorFlow graph used for evaluation should be saved to disk.
optional bool save_graph = 5 [default=false];
......
......@@ -157,6 +157,13 @@ message FasterRcnn {
// Whether to use the balanced positive negative sampler implementation with
// static shape guarantees.
optional bool use_static_balanced_label_sampler = 34 [default = false];
// If True, uses implementation of ops with static shape guarantees.
optional bool use_static_shapes = 35 [default = false];
// Whether the masks present in groundtruth should be resized in the model to
// match the image size.
optional bool resize_masks = 36 [default = true];
}
......
......@@ -22,7 +22,12 @@ enum InstanceMaskType {
PNG_MASKS = 2; // Encoded PNG masks.
}
// Next id: 24
message InputReader {
// Name of input reader. Typically used to describe the dataset that is read
// by this input reader.
optional string name = 23 [default=""];
// Path to StringIntLabelMap pbtxt file specifying the mapping from string
// labels to integer ids.
optional string label_map_path = 1 [default=""];
......@@ -41,6 +46,12 @@ message InputReader {
// will be reused indefinitely.
optional uint32 num_epochs = 5 [default=0];
// Integer representing how often an example should be sampled. To feed
// only 1/3 of your data into your model, set `sample_1_of_n_examples` to 3.
// This is particularly useful for evaluation, where you might not prefer to
// evaluate all of your samples.
optional uint32 sample_1_of_n_examples = 22 [default=1];
// Number of file shards to read in parallel.
optional uint32 num_readers = 6 [default=64];
......@@ -62,7 +73,6 @@ message InputReader {
// to generate a good random shuffle.
optional uint32 min_after_dequeue = 4 [default=1000, deprecated=true];
// Number of records to read from each reader at once.
optional uint32 read_block_length = 15 [default=32];
......
......@@ -10,12 +10,13 @@ import "object_detection/protos/train.proto";
// Convenience message for configuring a training and eval pipeline. Allows all
// of the pipeline parameters to be configured from one file.
// Next id: 7
// Convenience message for configuring a training and eval pipeline. Allows all
// of the pipeline parameters to be configured from one file.
// Next id: 7
message TrainEvalPipelineConfig {
  optional DetectionModel model = 1;
  optional TrainConfig train_config = 2;
  optional InputReader train_input_reader = 3;
  optional EvalConfig eval_config = 4;
  // Repeated to allow evaluating over multiple datasets; legacy clients that
  // expect a single eval reader read the first element.
  repeated InputReader eval_input_reader = 5;
  optional GraphRewriter graph_rewriter = 6;
  extensions 1000 to max;
}
......@@ -17,6 +17,9 @@ message BatchNonMaxSuppression {
// Maximum number of detections to retain across all classes.
optional int32 max_total_detections = 5 [default = 100];
// Whether to use the implementation of NMS that guarantees static shapes.
optional bool use_static_shapes = 6 [default = false];
}
// Configuration proto for post-processing predicted boxes and
......
......@@ -163,5 +163,8 @@ message FeaturePyramidNetworks {
// maximum level in feature pyramid
optional int32 max_level = 2 [default = 7];
// channel depth for additional coarse feature layers.
optional int32 additional_layer_depth = 3 [default = 256];
}
......@@ -6,7 +6,7 @@ import "object_detection/protos/optimizer.proto";
import "object_detection/protos/preprocessor.proto";
// Message for configuring DetectionModel training jobs (train.py).
// Next id: 26
// Next id: 27
message TrainConfig {
// Effective batch size to use for training.
// For TPU (or sync SGD jobs), the batch size per core (or GPU) is going to be
......@@ -112,4 +112,7 @@ message TrainConfig {
// dictionary, so that they can be displayed in Tensorboard. Note that this
// will lead to a larger memory footprint.
optional bool retain_original_images = 23 [default=false];
// Whether to use bfloat16 for training.
optional bool use_bfloat16 = 26 [default=false];
}
# Faster R-CNN with Resnet-101 (v1)
# Users should configure the fine_tune_checkpoint field in the train config as
# well as the label_map_path and input_path fields in the train_input_reader and
# eval_input_reader. Search for "PATH_TO_BE_CONFIGURED" to find the fields that
# should be configured.
model {
faster_rcnn {
num_classes: 2854
image_resizer {
keep_aspect_ratio_resizer {
min_dimension: 600
max_dimension: 1024
}
}
feature_extractor {
type: 'faster_rcnn_resnet101'
first_stage_features_stride: 16
}
first_stage_anchor_generator {
grid_anchor_generator {
scales: [0.25, 0.5, 1.0, 2.0]
aspect_ratios: [0.5, 1.0, 2.0]
height_stride: 16
width_stride: 16
}
}
first_stage_box_predictor_conv_hyperparams {
op: CONV
regularizer {
l2_regularizer {
weight: 0.0
}
}
initializer {
truncated_normal_initializer {
stddev: 0.01
}
}
}
first_stage_nms_score_threshold: 0.0
first_stage_nms_iou_threshold: 0.7
first_stage_max_proposals: 32
first_stage_localization_loss_weight: 2.0
first_stage_objectness_loss_weight: 1.0
initial_crop_size: 14
maxpool_kernel_size: 2
maxpool_stride: 2
second_stage_batch_size: 32
second_stage_box_predictor {
mask_rcnn_box_predictor {
use_dropout: false
dropout_keep_probability: 1.0
fc_hyperparams {
op: FC
regularizer {
l2_regularizer {
weight: 0.0
}
}
initializer {
variance_scaling_initializer {
factor: 1.0
uniform: true
mode: FAN_AVG
}
}
}
}
}
second_stage_post_processing {
batch_non_max_suppression {
score_threshold: 0.0
iou_threshold: 0.6
max_detections_per_class: 5
max_total_detections: 5
}
score_converter: SOFTMAX
}
second_stage_localization_loss_weight: 2.0
second_stage_classification_loss_weight: 1.0
}
}
train_config: {
batch_size: 1
num_steps: 4000000
optimizer {
momentum_optimizer: {
learning_rate: {
manual_step_learning_rate {
initial_learning_rate: 0.0003
schedule {
step: 3000000
learning_rate: .00003
}
schedule {
step: 3500000
learning_rate: .000003
}
}
}
momentum_optimizer_value: 0.9
}
use_moving_average: false
}
gradient_clipping_by_norm: 10.0
fine_tune_checkpoint: "PATH_TO_BE_CONFIGURED/model.ckpt"
data_augmentation_options {
random_horizontal_flip {
}
}
}
train_input_reader: {
label_map_path: "PATH_TO_BE_CONFIGURED/fgvc_2854_classes_label_map.pbtxt"
tf_record_input_reader {
input_path: "PATH_TO_BE_CONFIGURED/animal_2854_train.record"
}
}
eval_config: {
metrics_set: "pascal_voc_detection_metrics"
use_moving_averages: false
num_examples: 48736
}
eval_input_reader: {
label_map_path: "PATH_TO_BE_CONFIGURED/fgvc_2854_classes_label_map.pbtxt"
shuffle: false
num_readers: 1
tf_record_input_reader {
input_path: "PATH_TO_BE_CONFIGURED/animal_2854_val.record"
}
}
# Faster R-CNN with Resnet-50 (v1)
# Users should configure the fine_tune_checkpoint field in the train config as
# well as the label_map_path and input_path fields in the train_input_reader and
# eval_input_reader. Search for "PATH_TO_BE_CONFIGURED" to find the fields that
# should be configured.
model {
faster_rcnn {
num_classes: 2854
image_resizer {
keep_aspect_ratio_resizer {
min_dimension: 600
max_dimension: 1024
}
}
feature_extractor {
type: 'faster_rcnn_resnet50'
first_stage_features_stride: 16
}
first_stage_anchor_generator {
grid_anchor_generator {
scales: [0.25, 0.5, 1.0, 2.0]
aspect_ratios: [0.5, 1.0, 2.0]
height_stride: 16
width_stride: 16
}
}
first_stage_box_predictor_conv_hyperparams {
op: CONV
regularizer {
l2_regularizer {
weight: 0.0
}
}
initializer {
truncated_normal_initializer {
stddev: 0.01
}
}
}
first_stage_nms_score_threshold: 0.0
first_stage_nms_iou_threshold: 0.7
first_stage_max_proposals: 32
first_stage_localization_loss_weight: 2.0
first_stage_objectness_loss_weight: 1.0
initial_crop_size: 14
maxpool_kernel_size: 2
maxpool_stride: 2
second_stage_batch_size: 32
second_stage_box_predictor {
mask_rcnn_box_predictor {
use_dropout: false
dropout_keep_probability: 1.0
fc_hyperparams {
op: FC
regularizer {
l2_regularizer {
weight: 0.0
}
}
initializer {
variance_scaling_initializer {
factor: 1.0
uniform: true
mode: FAN_AVG
}
}
}
}
}
second_stage_post_processing {
batch_non_max_suppression {
score_threshold: 0.0
iou_threshold: 0.6
max_detections_per_class: 5
max_total_detections: 5
}
score_converter: SOFTMAX
}
second_stage_localization_loss_weight: 2.0
second_stage_classification_loss_weight: 1.0
}
}
train_config: {
batch_size: 1
num_steps: 4000000
optimizer {
momentum_optimizer: {
learning_rate: {
manual_step_learning_rate {
initial_learning_rate: 0.0003
schedule {
step: 3000000
learning_rate: .00003
}
schedule {
step: 3500000
learning_rate: .000003
}
}
}
momentum_optimizer_value: 0.9
}
use_moving_average: false
}
gradient_clipping_by_norm: 10.0
fine_tune_checkpoint: "PATH_TO_BE_CONFIGURED/model.ckpt"
data_augmentation_options {
random_horizontal_flip {
}
}
}
train_input_reader: {
label_map_path: "PATH_TO_BE_CONFIGURED/fgvc_2854_classes_label_map.pbtxt"
tf_record_input_reader {
input_path: "PATH_TO_BE_CONFIGURED/animal_2854_train.record"
}
}
eval_config: {
metrics_set: "pascal_voc_detection_metrics"
use_moving_averages: false
num_examples: 10
}
eval_input_reader: {
label_map_path: "PATH_TO_BE_CONFIGURED/fgvc_2854_classes_label_map.pbtxt"
shuffle: false
num_readers: 1
tf_record_input_reader {
input_path: "PATH_TO_BE_CONFIGURED/animal_2854_val.record"
}
}
......@@ -103,15 +103,20 @@ def create_configs_from_pipeline_proto(pipeline_config):
Returns:
Dictionary of configuration objects. Keys are `model`, `train_config`,
`train_input_config`, `eval_config`, `eval_input_configs`. Value are
the corresponding config objects or list of config objects (only for
eval_input_configs).
"""
configs = {}
configs["model"] = pipeline_config.model
configs["train_config"] = pipeline_config.train_config
configs["train_input_config"] = pipeline_config.train_input_reader
configs["eval_config"] = pipeline_config.eval_config
configs["eval_input_config"] = pipeline_config.eval_input_reader
configs["eval_input_configs"] = pipeline_config.eval_input_reader
# Keeps eval_input_config only for backwards compatibility. All clients should
# read eval_input_configs instead.
if configs["eval_input_configs"]:
configs["eval_input_config"] = configs["eval_input_configs"][0]
if pipeline_config.HasField("graph_rewriter"):
configs["graph_rewriter_config"] = pipeline_config.graph_rewriter
......@@ -150,7 +155,7 @@ def create_pipeline_proto_from_configs(configs):
pipeline_config.train_config.CopyFrom(configs["train_config"])
pipeline_config.train_input_reader.CopyFrom(configs["train_input_config"])
pipeline_config.eval_config.CopyFrom(configs["eval_config"])
pipeline_config.eval_input_reader.CopyFrom(configs["eval_input_config"])
pipeline_config.eval_input_reader.extend(configs["eval_input_configs"])
if "graph_rewriter_config" in configs:
pipeline_config.graph_rewriter.CopyFrom(configs["graph_rewriter_config"])
return pipeline_config
......@@ -224,7 +229,7 @@ def get_configs_from_multiple_files(model_config_path="",
eval_input_config = input_reader_pb2.InputReader()
with tf.gfile.GFile(eval_input_config_path, "r") as f:
text_format.Merge(f.read(), eval_input_config)
configs["eval_input_config"] = eval_input_config
configs["eval_input_configs"] = [eval_input_config]
if graph_rewriter_config_path:
configs["graph_rewriter_config"] = get_graph_rewriter_config_from_file(
......@@ -284,14 +289,133 @@ def _is_generic_key(key):
"graph_rewriter_config",
"model",
"train_input_config",
"train_input_config",
"train_config"]:
"train_config",
"eval_config"]:
if key.startswith(prefix + "."):
return True
return False
def merge_external_params_with_configs(configs, hparams=None, **kwargs):
def _check_and_convert_legacy_input_config_key(key):
"""Checks key and converts legacy input config update to specific update.
Args:
key: string indicates the target of update operation.
Returns:
is_valid_input_config_key: A boolean indicating whether the input key is to
update input config(s).
key_name: 'eval_input_configs' or 'train_input_config' string if
is_valid_input_config_key is true. None if is_valid_input_config_key is
false.
input_name: always returns None since legacy input config key never
specifies the target input config. Keeping this output only to match the
output form defined for input config update.
field_name: the field name in input config. `key` itself if
is_valid_input_config_key is false.
"""
key_name = None
input_name = None
field_name = key
is_valid_input_config_key = True
if field_name == "train_shuffle":
key_name = "train_input_config"
field_name = "shuffle"
elif field_name == "eval_shuffle":
key_name = "eval_input_configs"
field_name = "shuffle"
elif field_name == "train_input_path":
key_name = "train_input_config"
field_name = "input_path"
elif field_name == "eval_input_path":
key_name = "eval_input_configs"
field_name = "input_path"
elif field_name == "train_input_path":
key_name = "train_input_config"
field_name = "input_path"
elif field_name == "eval_input_path":
key_name = "eval_input_configs"
field_name = "input_path"
elif field_name == "append_train_input_path":
key_name = "train_input_config"
field_name = "input_path"
elif field_name == "append_eval_input_path":
key_name = "eval_input_configs"
field_name = "input_path"
else:
is_valid_input_config_key = False
return is_valid_input_config_key, key_name, input_name, field_name
def check_and_parse_input_config_key(configs, key):
  """Checks key and returns specific fields if key is valid input config update.

  Args:
    configs: Dictionary of configuration objects. See outputs from
      get_configs_from_pipeline_file() or get_configs_from_multiple_files().
    key: string indicates the target of update operation.

  Returns:
    is_valid_input_config_key: A boolean indicate whether the input key is to
      update input config(s).
    key_name: 'eval_input_configs' or 'train_input_config' string if
      is_valid_input_config_key is true. None if is_valid_input_config_key is
      false.
    input_name: the name of the input config to be updated. None if
      is_valid_input_config_key is false.
    field_name: the field name in input config. `key` itself if
      is_valid_input_config_key is false.

  Raises:
    ValueError: when the input key format doesn't match any known formats.
    ValueError: if key_name doesn't match 'eval_input_configs' or
      'train_input_config'.
    ValueError: if input_name doesn't match any name in train or eval input
      configs.
    ValueError: if field_name doesn't match any supported fields.
  """
  parts = key.split(":")
  if len(parts) == 1:
    # Single-token keys are treated as legacy updates.
    return _check_and_convert_legacy_input_config_key(key)
  if len(parts) != 3:
    raise ValueError("Invalid key format when overriding configs.")
  key_name, input_name, field_name = parts

  # The specific-update form may only target input reader configs.
  if key_name not in ("eval_input_configs", "train_input_config"):
    raise ValueError("Invalid key_name when overriding input config.")

  # input_name must match the 'name' of the targeted config: the single
  # train input config, or one element of the repeated eval input configs.
  target = configs[key_name]
  if isinstance(target, input_reader_pb2.InputReader):
    name_matches = target.name == input_name
  else:
    name_matches = any(cfg.name == input_name for cfg in target)
  if not name_matches:
    raise ValueError("Invalid input_name when overriding input config.")

  # Only a fixed set of InputReader fields may be overridden this way.
  supported_fields = ("input_path", "label_map_path", "shuffle", "mask_type",
                      "sample_1_of_n_examples")
  if field_name not in supported_fields:
    raise ValueError("Invalid field_name when overriding input config.")
  return True, key_name, input_name, field_name
def merge_external_params_with_configs(configs, hparams=None, kwargs_dict=None):
"""Updates `configs` dictionary based on supplied parameters.
This utility is for modifying specific fields in the object detection configs.
......@@ -304,6 +428,31 @@ def merge_external_params_with_configs(configs, hparams=None, **kwargs):
1. Strategy-based overrides, which update multiple relevant configuration
options. For example, updating `learning_rate` will update both the warmup and
final learning rates.
In this case key can be one of the following formats:
1. legacy update: single string that indicates the attribute to be
updated. E.g. 'label_map_path', 'eval_input_path', 'shuffle'.
Note that when updating fields (e.g. eval_input_path, eval_shuffle) in
eval_input_configs, the override will only be applied when
eval_input_configs has exactly 1 element.
2. specific update: colon separated string that indicates which field in
which input_config to update. It should have 3 fields:
- key_name: Name of the input config we should update, either
'train_input_config' or 'eval_input_configs'
- input_name: a 'name' that can be used to identify elements, especially
when configs[key_name] is a repeated field.
- field_name: name of the field that you want to override.
For example, given configs dict as below:
configs = {
'model': {...}
'train_config': {...}
'train_input_config': {...}
'eval_config': {...}
'eval_input_configs': [{ name:"eval_coco", ...},
{ name:"eval_voc", ... }]
}
Assume we want to update the input_path of the eval_input_config
whose name is 'eval_coco'. The `key` would then be:
'eval_input_configs:eval_coco:input_path'
2. Generic key/value, which update a specific parameter based on namespaced
configuration keys. For example,
`model.ssd.loss.hard_example_miner.max_negatives_per_positive` will update the
......@@ -314,55 +463,29 @@ def merge_external_params_with_configs(configs, hparams=None, **kwargs):
configs: Dictionary of configuration objects. See outputs from
get_configs_from_pipeline_file() or get_configs_from_multiple_files().
hparams: A `HParams`.
**kwargs: Extra keyword arguments that are treated the same way as
kwargs_dict: Extra keyword arguments that are treated the same way as
attribute/value pairs in `hparams`. Note that hyperparameters with the
same names will override keyword arguments.
Returns:
`configs` dictionary.
Raises:
ValueError: when the key string doesn't match any of its allowed formats.
"""
if kwargs_dict is None:
kwargs_dict = {}
if hparams:
kwargs.update(hparams.values())
for key, value in kwargs.items():
kwargs_dict.update(hparams.values())
for key, value in kwargs_dict.items():
tf.logging.info("Maybe overwriting %s: %s", key, value)
# pylint: disable=g-explicit-bool-comparison
if value == "" or value is None:
continue
# pylint: enable=g-explicit-bool-comparison
if key == "learning_rate":
_update_initial_learning_rate(configs, value)
elif key == "batch_size":
_update_batch_size(configs, value)
elif key == "momentum_optimizer_value":
_update_momentum_optimizer_value(configs, value)
elif key == "classification_localization_weight_ratio":
# Localization weight is fixed to 1.0.
_update_classification_localization_weight_ratio(configs, value)
elif key == "focal_loss_gamma":
_update_focal_loss_gamma(configs, value)
elif key == "focal_loss_alpha":
_update_focal_loss_alpha(configs, value)
elif key == "train_steps":
_update_train_steps(configs, value)
elif key == "eval_steps":
_update_eval_steps(configs, value)
elif key == "train_input_path":
_update_input_path(configs["train_input_config"], value)
elif key == "eval_input_path":
_update_input_path(configs["eval_input_config"], value)
elif key == "label_map_path":
_update_label_map_path(configs, value)
elif key == "mask_type":
_update_mask_type(configs, value)
elif key == "eval_with_moving_averages":
_update_use_moving_averages(configs, value)
elif key == "train_shuffle":
_update_shuffle(configs["train_input_config"], value)
elif key == "eval_shuffle":
_update_shuffle(configs["eval_input_config"], value)
elif key == "retain_original_images_in_eval":
_update_retain_original_images(configs["eval_config"], value)
elif _maybe_update_config_with_key_value(configs, key, value):
continue
elif _is_generic_key(key):
_update_generic(configs, key, value)
else:
......@@ -370,6 +493,148 @@ def merge_external_params_with_configs(configs, hparams=None, **kwargs):
return configs
def _maybe_update_config_with_key_value(configs, key, value):
  """Checks key type and updates `configs` with the key value pair accordingly.

  Args:
    configs: Dictionary of configuration objects. See outputs from
      get_configs_from_pipeline_file() or get_configs_from_multiple_files().
    key: String indicates the field(s) to be updated.
    value: Value used to override existing field value.

  Returns:
    A boolean value that indicates whether the override succeeds.

  Raises:
    ValueError: when the key string doesn't match any of the formats above.
  """
  is_valid_input_config_key, key_name, input_name, field_name = (
      check_and_parse_input_config_key(configs, key))
  if is_valid_input_config_key:
    update_input_reader_config(
        configs,
        key_name=key_name,
        input_name=input_name,
        field_name=field_name,
        value=value)
    return True

  # Strategy-based overrides, dispatched on the field name. Each handler
  # takes (configs, value).
  updaters = {
      "learning_rate": _update_initial_learning_rate,
      "batch_size": _update_batch_size,
      "momentum_optimizer_value": _update_momentum_optimizer_value,
      # Localization weight is fixed to 1.0.
      "classification_localization_weight_ratio":
          _update_classification_localization_weight_ratio,
      "focal_loss_gamma": _update_focal_loss_gamma,
      "focal_loss_alpha": _update_focal_loss_alpha,
      "train_steps": _update_train_steps,
      "label_map_path": _update_label_map_path,
      "mask_type": _update_mask_type,
      "sample_1_of_n_eval_examples":
          lambda c, v: _update_all_eval_input_configs(
              c, "sample_1_of_n_examples", v),
      "eval_num_epochs":
          lambda c, v: _update_all_eval_input_configs(c, "num_epochs", v),
      "eval_with_moving_averages": _update_use_moving_averages,
      "retain_original_images_in_eval":
          lambda c, v: _update_retain_original_images(c["eval_config"], v),
      "use_bfloat16": _update_use_bfloat16,
  }
  updater = updaters.get(field_name)
  if updater is None:
    return False
  updater(configs, value)
  return True
def _update_tf_record_input_path(input_config, input_path):
"""Updates input configuration to reflect a new input path.
The input_config object is updated in place, and hence not returned.
Args:
input_config: A input_reader_pb2.InputReader.
input_path: A path to data or list of paths.
Raises:
TypeError: if input reader type is not `tf_record_input_reader`.
"""
input_reader_type = input_config.WhichOneof("input_reader")
if input_reader_type == "tf_record_input_reader":
input_config.tf_record_input_reader.ClearField("input_path")
if isinstance(input_path, list):
input_config.tf_record_input_reader.input_path.extend(input_path)
else:
input_config.tf_record_input_reader.input_path.append(input_path)
else:
raise TypeError("Input reader type must be `tf_record_input_reader`.")
def update_input_reader_config(configs,
                               key_name=None,
                               input_name=None,
                               field_name=None,
                               value=None,
                               path_updater=_update_tf_record_input_path):
  """Updates specified input reader config field.

  The matching input reader config inside `configs` is mutated in place.

  Args:
    configs: Dictionary of configuration objects. See outputs from
      get_configs_from_pipeline_file() or get_configs_from_multiple_files().
    key_name: Name of the input config we should update, either
      'train_input_config' or 'eval_input_configs'
    input_name: String name used to identify input config to update with. Should
      be either None or value of the 'name' field in one of the input reader
      configs.
    field_name: Field name in input_reader_pb2.InputReader.
    value: Value used to override existing field value.
    path_updater: helper function used to update the input path. Only used when
      field_name is "input_path".

  Raises:
    ValueError: when input field_name is None.
    ValueError: when input_name is None and number of eval_input_readers does
      not equal to 1.
  """
  if isinstance(configs[key_name], input_reader_pb2.InputReader):
    # Updates singular input_config object.
    target_input_config = configs[key_name]
    if field_name == "input_path":
      path_updater(input_config=target_input_config, input_path=value)
    else:
      setattr(target_input_config, field_name, value)
  elif input_name is None and len(configs[key_name]) == 1:
    # Updates first (and the only) object of input_config list. A legacy
    # (unnamed) override is only unambiguous when exactly one config exists.
    target_input_config = configs[key_name][0]
    if field_name == "input_path":
      path_updater(input_config=target_input_config, input_path=value)
    else:
      setattr(target_input_config, field_name, value)
  elif input_name is not None and len(configs[key_name]):
    # Updates input_config whose name matches input_name.
    # NOTE(review): unlike the branches above, this path uses setattr even
    # when field_name is "input_path", bypassing path_updater — confirm this
    # is intended for named input configs.
    update_count = 0
    for input_config in configs[key_name]:
      if input_config.name == input_name:
        setattr(input_config, field_name, value)
        update_count = update_count + 1
    if not update_count:
      raise ValueError(
          "Input name {} not found when overriding.".format(input_name))
    elif update_count > 1:
      raise ValueError("Duplicate input name found when overriding.")
  else:
    # No branch matched: report all three selectors for easier debugging.
    key_name = "None" if key_name is None else key_name
    input_name = "None" if input_name is None else input_name
    field_name = "None" if field_name is None else field_name
    raise ValueError("Unknown input config overriding: "
                     "key_name:{}, input_name:{}, field_name:{}.".format(
                         key_name, input_name, field_name))
def _update_initial_learning_rate(configs, learning_rate):
"""Updates `configs` to reflect the new initial learning rate.
......@@ -596,27 +861,10 @@ def _update_eval_steps(configs, eval_steps):
configs["eval_config"].num_examples = int(eval_steps)
def _update_input_path(input_config, input_path):
"""Updates input configuration to reflect a new input path.
The input_config object is updated in place, and hence not returned.
Args:
input_config: A input_reader_pb2.InputReader.
input_path: A path to data or list of paths.
Raises:
TypeError: if input reader type is not `tf_record_input_reader`.
"""
input_reader_type = input_config.WhichOneof("input_reader")
if input_reader_type == "tf_record_input_reader":
input_config.tf_record_input_reader.ClearField("input_path")
if isinstance(input_path, list):
input_config.tf_record_input_reader.input_path.extend(input_path)
else:
input_config.tf_record_input_reader.input_path.append(input_path)
else:
raise TypeError("Input reader type must be `tf_record_input_reader`.")
def _update_all_eval_input_configs(configs, field, value):
"""Updates the content of `field` with `value` for all eval input configs."""
for eval_input_config in configs["eval_input_configs"]:
setattr(eval_input_config, field, value)
def _update_label_map_path(configs, label_map_path):
......@@ -630,7 +878,7 @@ def _update_label_map_path(configs, label_map_path):
label_map_path: New path to `StringIntLabelMap` pbtxt file.
"""
configs["train_input_config"].label_map_path = label_map_path
configs["eval_input_config"].label_map_path = label_map_path
_update_all_eval_input_configs(configs, "label_map_path", label_map_path)
def _update_mask_type(configs, mask_type):
......@@ -645,7 +893,7 @@ def _update_mask_type(configs, mask_type):
input_reader_pb2.InstanceMaskType
"""
configs["train_input_config"].mask_type = mask_type
configs["eval_input_config"].mask_type = mask_type
_update_all_eval_input_configs(configs, "mask_type", mask_type)
def _update_use_moving_averages(configs, use_moving_averages):
......@@ -662,18 +910,6 @@ def _update_use_moving_averages(configs, use_moving_averages):
configs["eval_config"].use_moving_averages = use_moving_averages
def _update_shuffle(input_config, shuffle):
"""Updates input configuration to reflect a new shuffle configuration.
The input_config object is updated in place, and hence not returned.
Args:
input_config: A input_reader_pb2.InputReader.
shuffle: Whether or not to shuffle the input data before reading.
"""
input_config.shuffle = shuffle
def _update_retain_original_images(eval_config, retain_original_images):
"""Updates eval config with option to retain original images.
......@@ -685,3 +921,16 @@ def _update_retain_original_images(eval_config, retain_original_images):
in eval mode.
"""
eval_config.retain_original_images = retain_original_images
def _update_use_bfloat16(configs, use_bfloat16):
"""Updates `configs` to reflect the new setup on whether to use bfloat16.
The configs dictionary is updated in place, and hence not returned.
Args:
configs: Dictionary of configuration objects. See outputs from
get_configs_from_pipeline_file() or get_configs_from_multiple_files().
use_bfloat16: A bool, indicating whether to use bfloat16 for training.
"""
configs["train_config"].use_bfloat16 = use_bfloat16
......@@ -83,7 +83,7 @@ class ConfigUtilTest(tf.test.TestCase):
pipeline_config.train_config.batch_size = 32
pipeline_config.train_input_reader.label_map_path = "path/to/label_map"
pipeline_config.eval_config.num_examples = 20
pipeline_config.eval_input_reader.queue_capacity = 100
pipeline_config.eval_input_reader.add().queue_capacity = 100
_write_config(pipeline_config, pipeline_config_path)
......@@ -96,7 +96,7 @@ class ConfigUtilTest(tf.test.TestCase):
self.assertProtoEquals(pipeline_config.eval_config,
configs["eval_config"])
self.assertProtoEquals(pipeline_config.eval_input_reader,
configs["eval_input_config"])
configs["eval_input_configs"])
def test_create_configs_from_pipeline_proto(self):
"""Tests creating configs dictionary from pipeline proto."""
......@@ -106,7 +106,7 @@ class ConfigUtilTest(tf.test.TestCase):
pipeline_config.train_config.batch_size = 32
pipeline_config.train_input_reader.label_map_path = "path/to/label_map"
pipeline_config.eval_config.num_examples = 20
pipeline_config.eval_input_reader.queue_capacity = 100
pipeline_config.eval_input_reader.add().queue_capacity = 100
configs = config_util.create_configs_from_pipeline_proto(pipeline_config)
self.assertProtoEquals(pipeline_config.model, configs["model"])
......@@ -116,7 +116,7 @@ class ConfigUtilTest(tf.test.TestCase):
configs["train_input_config"])
self.assertProtoEquals(pipeline_config.eval_config, configs["eval_config"])
self.assertProtoEquals(pipeline_config.eval_input_reader,
configs["eval_input_config"])
configs["eval_input_configs"])
def test_create_pipeline_proto_from_configs(self):
"""Tests that proto can be reconstructed from configs dictionary."""
......@@ -127,7 +127,7 @@ class ConfigUtilTest(tf.test.TestCase):
pipeline_config.train_config.batch_size = 32
pipeline_config.train_input_reader.label_map_path = "path/to/label_map"
pipeline_config.eval_config.num_examples = 20
pipeline_config.eval_input_reader.queue_capacity = 100
pipeline_config.eval_input_reader.add().queue_capacity = 100
_write_config(pipeline_config, pipeline_config_path)
configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
......@@ -142,7 +142,7 @@ class ConfigUtilTest(tf.test.TestCase):
pipeline_config.train_config.batch_size = 32
pipeline_config.train_input_reader.label_map_path = "path/to/label_map"
pipeline_config.eval_config.num_examples = 20
pipeline_config.eval_input_reader.queue_capacity = 100
pipeline_config.eval_input_reader.add().queue_capacity = 100
config_util.save_pipeline_config(pipeline_config, self.get_temp_dir())
configs = config_util.get_configs_from_pipeline_file(
......@@ -197,8 +197,7 @@ class ConfigUtilTest(tf.test.TestCase):
self.assertProtoEquals(train_input_config,
configs["train_input_config"])
self.assertProtoEquals(eval_config, configs["eval_config"])
self.assertProtoEquals(eval_input_config,
configs["eval_input_config"])
self.assertProtoEquals(eval_input_config, configs["eval_input_configs"][0])
def _assertOptimizerWithNewLearningRate(self, optimizer_name):
"""Asserts successful updating of all learning rate schemes."""
......@@ -282,6 +281,41 @@ class ConfigUtilTest(tf.test.TestCase):
"""Tests new learning rates for Adam Optimizer."""
self._assertOptimizerWithNewLearningRate("adam_optimizer")
def testGenericConfigOverride(self):
  """Tests generic config overrides for all top-level configs."""
  # Seed every top-level pipeline config with a distinguishable value.
  proto = pipeline_pb2.TrainEvalPipelineConfig()
  proto.model.ssd.num_classes = 1
  proto.train_config.batch_size = 1
  proto.eval_config.num_visualizations = 1
  proto.train_input_reader.label_map_path = "/some/path"
  proto.eval_input_reader.add().label_map_path = "/some/path"
  proto.graph_rewriter.quantization.weight_bits = 1
  config_path = os.path.join(self.get_temp_dir(), "pipeline.config")
  _write_config(proto, config_path)

  # Apply an override to each of the seeded parameters via HParams.
  configs = config_util.get_configs_from_pipeline_file(config_path)
  overrides = tf.contrib.training.HParams(**{
      "model.ssd.num_classes": 2,
      "train_config.batch_size": 2,
      "train_input_config.label_map_path": "/some/other/path",
      "eval_config.num_visualizations": 2,
      "graph_rewriter_config.quantization.weight_bits": 2
  })
  configs = config_util.merge_external_params_with_configs(configs, overrides)

  # Every top-level config should now carry the overridden value.
  self.assertEqual(2, configs["model"].ssd.num_classes)
  self.assertEqual(2, configs["train_config"].batch_size)
  self.assertEqual("/some/other/path",
                   configs["train_input_config"].label_map_path)
  self.assertEqual(2, configs["eval_config"].num_visualizations)
  self.assertEqual(2,
                   configs["graph_rewriter_config"].quantization.weight_bits)
def testNewBatchSize(self):
"""Tests that batch size is updated appropriately."""
original_batch_size = 2
......@@ -406,25 +440,19 @@ class ConfigUtilTest(tf.test.TestCase):
def testMergingKeywordArguments(self):
  """Tests that keyword arguments get merged as do hyperparameters."""
  original_num_train_steps = 100
  desired_num_train_steps = 10
  pipeline_config_path = os.path.join(self.get_temp_dir(), "pipeline.config")
  pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
  pipeline_config.train_config.num_steps = original_num_train_steps
  _write_config(pipeline_config, pipeline_config_path)

  configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
  # Overrides are passed through kwargs_dict; the old bare-keyword form
  # (train_steps=..., eval_steps=...) and the stale eval_steps/num_examples
  # lines from the unresolved merge are removed here.
  override_dict = {"train_steps": desired_num_train_steps}
  configs = config_util.merge_external_params_with_configs(
      configs, kwargs_dict=override_dict)
  train_steps = configs["train_config"].num_steps
  self.assertEqual(desired_num_train_steps, train_steps)
def testGetNumberOfClasses(self):
"""Tests that number of classes can be retrieved."""
......@@ -449,8 +477,9 @@ class ConfigUtilTest(tf.test.TestCase):
_write_config(pipeline_config, pipeline_config_path)
configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
override_dict = {"train_input_path": new_train_path}
configs = config_util.merge_external_params_with_configs(
configs, train_input_path=new_train_path)
configs, kwargs_dict=override_dict)
reader_config = configs["train_input_config"].tf_record_input_reader
final_path = reader_config.input_path
self.assertEqual([new_train_path], final_path)
......@@ -467,8 +496,9 @@ class ConfigUtilTest(tf.test.TestCase):
_write_config(pipeline_config, pipeline_config_path)
configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
override_dict = {"train_input_path": new_train_path}
configs = config_util.merge_external_params_with_configs(
configs, train_input_path=new_train_path)
configs, kwargs_dict=override_dict)
reader_config = configs["train_input_config"].tf_record_input_reader
final_path = reader_config.input_path
self.assertEqual(new_train_path, final_path)
......@@ -482,17 +512,18 @@ class ConfigUtilTest(tf.test.TestCase):
pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
train_input_reader = pipeline_config.train_input_reader
train_input_reader.label_map_path = original_label_map_path
eval_input_reader = pipeline_config.eval_input_reader
eval_input_reader = pipeline_config.eval_input_reader.add()
eval_input_reader.label_map_path = original_label_map_path
_write_config(pipeline_config, pipeline_config_path)
configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
override_dict = {"label_map_path": new_label_map_path}
configs = config_util.merge_external_params_with_configs(
configs, label_map_path=new_label_map_path)
configs, kwargs_dict=override_dict)
self.assertEqual(new_label_map_path,
configs["train_input_config"].label_map_path)
self.assertEqual(new_label_map_path,
configs["eval_input_config"].label_map_path)
for eval_input_config in configs["eval_input_configs"]:
self.assertEqual(new_label_map_path, eval_input_config.label_map_path)
def testDontOverwriteEmptyLabelMapPath(self):
"""Tests that label map path will not by overwritten with empty string."""
......@@ -503,17 +534,18 @@ class ConfigUtilTest(tf.test.TestCase):
pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
train_input_reader = pipeline_config.train_input_reader
train_input_reader.label_map_path = original_label_map_path
eval_input_reader = pipeline_config.eval_input_reader
eval_input_reader = pipeline_config.eval_input_reader.add()
eval_input_reader.label_map_path = original_label_map_path
_write_config(pipeline_config, pipeline_config_path)
configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
override_dict = {"label_map_path": new_label_map_path}
configs = config_util.merge_external_params_with_configs(
configs, label_map_path=new_label_map_path)
configs, kwargs_dict=override_dict)
self.assertEqual(original_label_map_path,
configs["train_input_config"].label_map_path)
self.assertEqual(original_label_map_path,
configs["eval_input_config"].label_map_path)
configs["eval_input_configs"][0].label_map_path)
def testNewMaskType(self):
"""Tests that mask type can be overwritten in input readers."""
......@@ -524,15 +556,16 @@ class ConfigUtilTest(tf.test.TestCase):
pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
train_input_reader = pipeline_config.train_input_reader
train_input_reader.mask_type = original_mask_type
eval_input_reader = pipeline_config.eval_input_reader
eval_input_reader = pipeline_config.eval_input_reader.add()
eval_input_reader.mask_type = original_mask_type
_write_config(pipeline_config, pipeline_config_path)
configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
override_dict = {"mask_type": new_mask_type}
configs = config_util.merge_external_params_with_configs(
configs, mask_type=new_mask_type)
configs, kwargs_dict=override_dict)
self.assertEqual(new_mask_type, configs["train_input_config"].mask_type)
self.assertEqual(new_mask_type, configs["eval_input_config"].mask_type)
self.assertEqual(new_mask_type, configs["eval_input_configs"][0].mask_type)
def testUseMovingAverageForEval(self):
use_moving_averages_orig = False
......@@ -543,8 +576,9 @@ class ConfigUtilTest(tf.test.TestCase):
_write_config(pipeline_config, pipeline_config_path)
configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
override_dict = {"eval_with_moving_averages": True}
configs = config_util.merge_external_params_with_configs(
configs, eval_with_moving_averages=True)
configs, kwargs_dict=override_dict)
self.assertEqual(True, configs["eval_config"].use_moving_averages)
def testGetImageResizerConfig(self):
......@@ -585,14 +619,14 @@ class ConfigUtilTest(tf.test.TestCase):
pipeline_config_path = os.path.join(self.get_temp_dir(), "pipeline.config")
pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
pipeline_config.eval_input_reader.shuffle = original_shuffle
pipeline_config.eval_input_reader.add().shuffle = original_shuffle
_write_config(pipeline_config, pipeline_config_path)
configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
override_dict = {"eval_shuffle": desired_shuffle}
configs = config_util.merge_external_params_with_configs(
configs, eval_shuffle=desired_shuffle)
eval_shuffle = configs["eval_input_config"].shuffle
self.assertEqual(desired_shuffle, eval_shuffle)
configs, kwargs_dict=override_dict)
self.assertEqual(desired_shuffle, configs["eval_input_configs"][0].shuffle)
def testTrainShuffle(self):
"""Tests that `train_shuffle` keyword arguments are applied correctly."""
......@@ -605,8 +639,9 @@ class ConfigUtilTest(tf.test.TestCase):
_write_config(pipeline_config, pipeline_config_path)
configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
override_dict = {"train_shuffle": desired_shuffle}
configs = config_util.merge_external_params_with_configs(
configs, train_shuffle=desired_shuffle)
configs, kwargs_dict=override_dict)
train_shuffle = configs["train_input_config"].shuffle
self.assertEqual(desired_shuffle, train_shuffle)
......@@ -622,11 +657,210 @@ class ConfigUtilTest(tf.test.TestCase):
_write_config(pipeline_config, pipeline_config_path)
configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
override_dict = {
"retain_original_images_in_eval": desired_retain_original_images
}
configs = config_util.merge_external_params_with_configs(
configs, retain_original_images_in_eval=desired_retain_original_images)
configs, kwargs_dict=override_dict)
retain_original_images = configs["eval_config"].retain_original_images
self.assertEqual(desired_retain_original_images, retain_original_images)
def testOverwriteAllEvalSampling(self):
  """Tests sample-1-of-n override is applied to every eval input reader."""
  original_num_eval_examples = 1
  new_num_eval_examples = 10

  config_path = os.path.join(self.get_temp_dir(), "pipeline.config")
  proto = pipeline_pb2.TrainEvalPipelineConfig()
  # Two eval input readers, both with the original sampling rate.
  for _ in range(2):
    proto.eval_input_reader.add().sample_1_of_n_examples = (
        original_num_eval_examples)
  _write_config(proto, config_path)

  configs = config_util.get_configs_from_pipeline_file(config_path)
  configs = config_util.merge_external_params_with_configs(
      configs,
      kwargs_dict={"sample_1_of_n_eval_examples": new_num_eval_examples})
  # The override must reach every eval input config.
  for eval_input_config in configs["eval_input_configs"]:
    self.assertEqual(new_num_eval_examples,
                     eval_input_config.sample_1_of_n_examples)
def testOverwriteAllEvalNumEpochs(self):
  """Tests that an eval_num_epochs override reaches every eval input reader."""
  original_num_epochs = 10
  new_num_epochs = 1
  pipeline_config_path = os.path.join(self.get_temp_dir(), "pipeline.config")
  pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
  # Two eval input readers, both starting at the original epoch count.
  pipeline_config.eval_input_reader.add().num_epochs = original_num_epochs
  pipeline_config.eval_input_reader.add().num_epochs = original_num_epochs
  _write_config(pipeline_config, pipeline_config_path)

  configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
  override_dict = {"eval_num_epochs": new_num_epochs}
  configs = config_util.merge_external_params_with_configs(
      configs, kwargs_dict=override_dict)
  # Every eval input config must carry the overridden epoch count.
  for eval_input_config in configs["eval_input_configs"]:
    self.assertEqual(new_num_epochs, eval_input_config.num_epochs)
def testUpdateMaskTypeForAllInputConfigs(self):
  """Tests a mask_type override updates train and all eval input readers."""
  original_mask_type = input_reader_pb2.NUMERICAL_MASKS
  new_mask_type = input_reader_pb2.PNG_MASKS
  pipeline_config_path = os.path.join(self.get_temp_dir(), "pipeline.config")
  pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
  train_config = pipeline_config.train_input_reader
  train_config.mask_type = original_mask_type
  # Two named eval input readers sharing the original mask type.
  eval_1 = pipeline_config.eval_input_reader.add()
  eval_1.mask_type = original_mask_type
  eval_1.name = "eval_1"
  eval_2 = pipeline_config.eval_input_reader.add()
  eval_2.mask_type = original_mask_type
  eval_2.name = "eval_2"
  _write_config(pipeline_config, pipeline_config_path)

  configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
  override_dict = {"mask_type": new_mask_type}
  configs = config_util.merge_external_params_with_configs(
      configs, kwargs_dict=override_dict)
  # mask_type applies to all input readers, so every one must change.
  self.assertEqual(configs["train_input_config"].mask_type, new_mask_type)
  for eval_input_config in configs["eval_input_configs"]:
    self.assertEqual(eval_input_config.mask_type, new_mask_type)
def testErrorOverwritingMultipleInputConfig(self):
  """Overriding a per-reader field with several eval readers must raise."""
  original_shuffle = False
  new_shuffle = True
  config_path = os.path.join(self.get_temp_dir(), "pipeline.config")
  proto = pipeline_pb2.TrainEvalPipelineConfig()
  # Two distinctly named eval input readers with the same shuffle setting.
  for reader_name in ("eval_1", "eval_2"):
    reader = proto.eval_input_reader.add()
    reader.shuffle = original_shuffle
    reader.name = reader_name
  _write_config(proto, config_path)

  configs = config_util.get_configs_from_pipeline_file(config_path)
  # The legacy "eval_shuffle" key is ambiguous when more than one eval
  # reader exists, so the merge is expected to fail.
  with self.assertRaises(ValueError):
    configs = config_util.merge_external_params_with_configs(
        configs, kwargs_dict={"eval_shuffle": new_shuffle})
def testCheckAndParseInputConfigKey(self):
  """Tests parsing of input-config override keys, valid and invalid forms."""
  pipeline_config_path = os.path.join(self.get_temp_dir(), "pipeline.config")
  pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
  pipeline_config.eval_input_reader.add().name = "eval_1"
  pipeline_config.eval_input_reader.add().name = "eval_2"
  _write_config(pipeline_config, pipeline_config_path)
  configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)

  # Fully qualified key: "<key_name>:<input_name>:<field_name>".
  specific_shuffle_update_key = "eval_input_configs:eval_2:shuffle"
  is_valid_input_config_key, key_name, input_name, field_name = (
      config_util.check_and_parse_input_config_key(
          configs, specific_shuffle_update_key))
  self.assertTrue(is_valid_input_config_key)
  self.assertEqual(key_name, "eval_input_configs")
  self.assertEqual(input_name, "eval_2")
  self.assertEqual(field_name, "shuffle")

  # Legacy key: no explicit input name, so input_name comes back as None.
  legacy_shuffle_update_key = "eval_shuffle"
  is_valid_input_config_key, key_name, input_name, field_name = (
      config_util.check_and_parse_input_config_key(configs,
                                                   legacy_shuffle_update_key))
  self.assertTrue(is_valid_input_config_key)
  self.assertEqual(key_name, "eval_input_configs")
  self.assertEqual(input_name, None)
  self.assertEqual(field_name, "shuffle")

  # A plain field name is not an input-config key; only field_name is set.
  non_input_config_update_key = "label_map_path"
  is_valid_input_config_key, key_name, input_name, field_name = (
      config_util.check_and_parse_input_config_key(
          configs, non_input_config_update_key))
  self.assertFalse(is_valid_input_config_key)
  self.assertEqual(key_name, None)
  self.assertEqual(input_name, None)
  self.assertEqual(field_name, "label_map_path")

  # Malformed or unknown keys must raise ValueError with specific messages.
  with self.assertRaisesRegexp(ValueError,
                               "Invalid key format when overriding configs."):
    config_util.check_and_parse_input_config_key(
        configs, "train_input_config:shuffle")

  with self.assertRaisesRegexp(
      ValueError, "Invalid key_name when overriding input config."):
    config_util.check_and_parse_input_config_key(
        configs, "invalid_key_name:train_name:shuffle")

  with self.assertRaisesRegexp(
      ValueError, "Invalid input_name when overriding input config."):
    config_util.check_and_parse_input_config_key(
        configs, "eval_input_configs:unknown_eval_name:shuffle")

  with self.assertRaisesRegexp(
      ValueError, "Invalid field_name when overriding input config."):
    config_util.check_and_parse_input_config_key(
        configs, "eval_input_configs:eval_2:unknown_field_name")
def testUpdateInputReaderConfigSuccess(self):
  """Tests updating a single (train) input reader config by key name."""
  original_shuffle = False
  new_shuffle = True
  pipeline_config_path = os.path.join(self.get_temp_dir(), "pipeline.config")
  pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
  pipeline_config.train_input_reader.shuffle = original_shuffle
  _write_config(pipeline_config, pipeline_config_path)
  configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)

  # input_name is None because "train_input_config" is a single reader.
  config_util.update_input_reader_config(
      configs,
      key_name="train_input_config",
      input_name=None,
      field_name="shuffle",
      value=new_shuffle)
  self.assertEqual(configs["train_input_config"].shuffle, new_shuffle)

  # NOTE(review): this second call repeats the first verbatim. It exercises
  # updating an already-updated config, but it may have been intended to
  # target "eval_input_configs" instead — confirm the original intent.
  config_util.update_input_reader_config(
      configs,
      key_name="train_input_config",
      input_name=None,
      field_name="shuffle",
      value=new_shuffle)
  self.assertEqual(configs["train_input_config"].shuffle, new_shuffle)
def testUpdateInputReaderConfigErrors(self):
  """Tests the error cases of update_input_reader_config."""
  pipeline_config_path = os.path.join(self.get_temp_dir(), "pipeline.config")
  pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
  # Two eval readers sharing a name make name-based lookup ambiguous.
  pipeline_config.eval_input_reader.add().name = "same_eval_name"
  pipeline_config.eval_input_reader.add().name = "same_eval_name"
  _write_config(pipeline_config, pipeline_config_path)
  configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)

  # Duplicate reader names must be rejected.
  with self.assertRaisesRegexp(ValueError,
                               "Duplicate input name found when overriding."):
    config_util.update_input_reader_config(
        configs,
        key_name="eval_input_configs",
        input_name="same_eval_name",
        field_name="shuffle",
        value=False)

  # A name that matches no reader must be rejected.
  with self.assertRaisesRegexp(
      ValueError, "Input name name_not_exist not found when overriding."):
    config_util.update_input_reader_config(
        configs,
        key_name="eval_input_configs",
        input_name="name_not_exist",
        field_name="shuffle",
        value=False)

  # Omitting input_name while multiple eval readers exist is ambiguous.
  with self.assertRaisesRegexp(ValueError,
                               "Unknown input config overriding."):
    config_util.update_input_reader_config(
        configs,
        key_name="eval_input_configs",
        input_name=None,
        field_name="shuffle",
        value=False)
if __name__ == "__main__":
tf.test.main()
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Label map utility functions."""
import logging
......@@ -73,10 +72,10 @@ def get_max_label_map_index(label_map):
def convert_label_map_to_categories(label_map,
max_num_classes,
use_display_name=True):
"""Loads label map proto and returns categories list compatible with eval.
"""Given label map proto returns categories list compatible with eval.
This function loads a label map and returns a list of dicts, each of which
has the following keys:
This function converts label map proto and returns a list of dicts, each of
which has the following keys:
'id': (required) an integer id uniquely identifying this category.
'name': (required) string representing category name
e.g., 'cat', 'dog', 'pizza'.
......@@ -89,9 +88,10 @@ def convert_label_map_to_categories(label_map,
label_map: a StringIntLabelMapProto or None. If None, a default categories
list is created with max_num_classes categories.
max_num_classes: maximum number of (consecutive) label indices to include.
use_display_name: (boolean) choose whether to load 'display_name' field
as category name. If False or if the display_name field does not exist,
uses 'name' field as category names instead.
use_display_name: (boolean) choose whether to load 'display_name' field as
category name. If False or if the display_name field does not exist, uses
'name' field as category names instead.
Returns:
categories: a list of dictionaries representing all possible categories.
"""
......@@ -107,8 +107,9 @@ def convert_label_map_to_categories(label_map,
return categories
for item in label_map.item:
if not 0 < item.id <= max_num_classes:
logging.info('Ignore item %d since it falls outside of requested '
'label range.', item.id)
logging.info(
'Ignore item %d since it falls outside of requested '
'label range.', item.id)
continue
if use_display_name and item.HasField('display_name'):
name = item.display_name
......@@ -188,20 +189,44 @@ def get_label_map_dict(label_map_path,
return label_map_dict
def create_category_index_from_labelmap(label_map_path):
def create_categories_from_labelmap(label_map_path, use_display_name=True):
  """Reads a label map and returns categories list compatible with eval.

  This function converts label map proto and returns a list of dicts, each of
  which has the following keys:
    'id': an integer id uniquely identifying this category.
    'name': string representing category name e.g., 'cat', 'dog'.

  Args:
    label_map_path: Path to `StringIntLabelMap` proto text file.
    use_display_name: (boolean) choose whether to load 'display_name' field
      as category name. If False or if the display_name field does not exist,
      uses 'name' field as category names instead.

  Returns:
    categories: a list of dictionaries representing all possible categories.
  """
  label_map_proto = load_labelmap(label_map_path)
  # The largest id in the map bounds the number of consecutive classes kept.
  class_ids = [item.id for item in label_map_proto.item]
  return convert_label_map_to_categories(label_map_proto, max(class_ids),
                                         use_display_name)
def create_category_index_from_labelmap(label_map_path, use_display_name=True):
  """Reads a label map and returns a category index.

  Args:
    label_map_path: Path to `StringIntLabelMap` proto text file.
    use_display_name: (boolean) choose whether to load 'display_name' field
      as category name. If False or if the display_name field does not exist,
      uses 'name' field as category names instead.

  Returns:
    A category index, which is a dictionary that maps integer ids to dicts
    containing categories, e.g.
    {1: {'id': 1, 'name': 'dog'}, 2: {'id': 2, 'name': 'cat'}, ...}
  """
  # Delegate loading/conversion to create_categories_from_labelmap; the old
  # inline load_labelmap/convert_label_map_to_categories lines left over from
  # the unresolved merge are removed as dead duplicate code.
  categories = create_categories_from_labelmap(label_map_path, use_display_name)
  return create_category_index(categories)
......
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for object_detection.utils.label_map_util."""
import os
......@@ -189,7 +188,7 @@ class LabelMapUtilTest(tf.test.TestCase):
}]
self.assertListEqual(expected_categories_list, categories)
def test_convert_label_map_to_coco_categories(self):
def test_convert_label_map_to_categories(self):
label_map_proto = self._generate_label_map(num_classes=4)
categories = label_map_util.convert_label_map_to_categories(
label_map_proto, max_num_classes=3)
......@@ -205,7 +204,7 @@ class LabelMapUtilTest(tf.test.TestCase):
}]
self.assertListEqual(expected_categories_list, categories)
def test_convert_label_map_to_coco_categories_with_few_classes(self):
def test_convert_label_map_to_categories_with_few_classes(self):
label_map_proto = self._generate_label_map(num_classes=4)
cat_no_offset = label_map_util.convert_label_map_to_categories(
label_map_proto, max_num_classes=2)
......@@ -238,6 +237,30 @@ class LabelMapUtilTest(tf.test.TestCase):
}
}, category_index)
def test_create_categories_from_labelmap(self):
  """Tests building an eval-compatible categories list from a pbtxt file."""
  # Minimal two-item label map in text proto format.
  label_map_string = """
item {
id:1
name:'dog'
}
item {
id:2
name:'cat'
}
"""
  # Write the label map to a temp file so the loader can read it from disk.
  label_map_path = os.path.join(self.get_temp_dir(), 'label_map.pbtxt')
  with tf.gfile.Open(label_map_path, 'wb') as f:
    f.write(label_map_string)
  categories = label_map_util.create_categories_from_labelmap(label_map_path)
  # Each category dict carries the integer 'id' and the 'name' field.
  self.assertListEqual([{
      'name': u'dog',
      'id': 1
  }, {
      'name': u'cat',
      'id': 2
  }], categories)
def test_create_category_index_from_labelmap(self):
label_map_string = """
item {
......@@ -266,6 +289,46 @@ class LabelMapUtilTest(tf.test.TestCase):
}
}, category_index)
def test_create_category_index_from_labelmap_display(self):
  """Tests use_display_name toggling when building a category index."""
  # Items are deliberately out of id order; display names differ from names.
  label_map_string = """
item {
id:2
name:'cat'
display_name:'meow'
}
item {
id:1
name:'dog'
display_name:'woof'
}
"""
  label_map_path = os.path.join(self.get_temp_dir(), 'label_map.pbtxt')
  with tf.gfile.Open(label_map_path, 'wb') as f:
    f.write(label_map_string)
  # With use_display_name=False the 'name' field is used as category name.
  self.assertDictEqual({
      1: {
          'name': u'dog',
          'id': 1
      },
      2: {
          'name': u'cat',
          'id': 2
      }
  }, label_map_util.create_category_index_from_labelmap(
      label_map_path, False))
  # By default (use_display_name=True) the 'display_name' field is used.
  self.assertDictEqual({
      1: {
          'name': u'woof',
          'id': 1
      },
      2: {
          'name': u'meow',
          'id': 2
      }
  }, label_map_util.create_category_index_from_labelmap(label_map_path))
if __name__ == '__main__':
tf.test.main()
......@@ -160,6 +160,9 @@ def pad_to_multiple(tensor, multiple):
Returns:
padded_tensor: the tensor zero padded to the specified multiple.
"""
if multiple == 1:
return tensor
tensor_shape = tensor.get_shape()
batch_size = static_shape.get_batch_size(tensor_shape)
tensor_height = static_shape.get_height(tensor_shape)
......@@ -697,8 +700,11 @@ def position_sensitive_crop_regions(image,
image_crops = []
for (split, box) in zip(image_splits, position_sensitive_boxes):
if split.shape.is_fully_defined() and box.shape.is_fully_defined():
crop = matmul_crop_and_resize(
tf.expand_dims(split, 0), box, bin_crop_size)
crop = tf.squeeze(
matmul_crop_and_resize(
tf.expand_dims(split, axis=0), tf.expand_dims(box, axis=0),
bin_crop_size),
axis=0)
else:
crop = tf.image.crop_and_resize(
tf.expand_dims(split, 0), box,
......@@ -785,50 +791,85 @@ def reframe_box_masks_to_image_masks(box_masks, boxes, image_height,
return tf.squeeze(image_masks, axis=3)
def merge_boxes_with_multiple_labels(boxes,
                                     classes,
                                     confidences,
                                     num_classes,
                                     quantization_bins=10000):
  """Merges boxes with same coordinates and returns K-hot encoded classes.

  The pre-merge numpy/py_func implementation that was left interleaved with
  this TF-graph implementation is removed; only the post-merge version that
  quantizes and hashes box coordinates is kept.

  Args:
    boxes: A tf.float32 tensor with shape [N, 4] holding N boxes. Only
      normalized coordinates are allowed.
    classes: A tf.int32 tensor with shape [N] holding class indices.
      The class index starts at 0.
    confidences: A tf.float32 tensor with shape [N] holding class confidences.
    num_classes: total number of classes to use for K-hot encoding.
    quantization_bins: the number of bins used to quantize the box coordinate.

  Returns:
    merged_boxes: A tf.float32 tensor with shape [N', 4] holding boxes,
      where N' <= N.
    class_encodings: A tf.int32 tensor with shape [N', num_classes] holding
      K-hot encodings for the merged boxes.
    confidence_encodings: A tf.float32 tensor with shape [N', num_classes]
      holding encodings of confidences for the merged boxes.
    merged_box_indices: A tf.int32 tensor with shape [N'] holding original
      indices of the boxes.
  """
  boxes_shape = tf.shape(boxes)
  classes_shape = tf.shape(classes)
  confidences_shape = tf.shape(confidences)
  # Validate that boxes, classes and confidences line up and that boxes are
  # normalized [N, 4] tensors before touching their contents.
  box_class_shape_assert = shape_utils.assert_shape_equal_along_first_dimension(
      boxes_shape, classes_shape)
  box_confidence_shape_assert = (
      shape_utils.assert_shape_equal_along_first_dimension(
          boxes_shape, confidences_shape))
  box_dimension_assert = tf.assert_equal(boxes_shape[1], 4)
  box_normalized_assert = shape_utils.assert_box_normalized(boxes)

  with tf.control_dependencies(
      [box_class_shape_assert, box_confidence_shape_assert,
       box_dimension_assert, box_normalized_assert]):
    # Quantize each coordinate and combine the four coordinates into a single
    # integer hashcode so identical boxes hash to the same value.
    quantized_boxes = tf.to_int64(boxes * (quantization_bins - 1))
    ymin, xmin, ymax, xmax = tf.unstack(quantized_boxes, axis=1)
    hashcodes = (
        ymin +
        xmin * quantization_bins +
        ymax * quantization_bins * quantization_bins +
        xmax * quantization_bins * quantization_bins * quantization_bins)
    unique_hashcodes, unique_indices = tf.unique(hashcodes)
    num_boxes = tf.shape(boxes)[0]
    num_unique_boxes = tf.shape(unique_hashcodes)[0]
    # For each group of identical boxes, keep the smallest original index.
    merged_box_indices = tf.unsorted_segment_min(
        tf.range(num_boxes), unique_indices, num_unique_boxes)
    merged_boxes = tf.gather(boxes, merged_box_indices)

    def map_box_encodings(i):
      """Produces box K-hot and score encodings for each class index."""
      box_mask = tf.equal(
          unique_indices, i * tf.ones(num_boxes, dtype=tf.int32))
      box_mask = tf.reshape(box_mask, [-1])
      box_indices = tf.boolean_mask(classes, box_mask)
      box_confidences = tf.boolean_mask(confidences, box_mask)
      box_class_encodings = tf.sparse_to_dense(
          box_indices, [num_classes], 1, validate_indices=False)
      box_confidence_encodings = tf.sparse_to_dense(
          box_indices, [num_classes], box_confidences, validate_indices=False)
      return box_class_encodings, box_confidence_encodings

    class_encodings, confidence_encodings = tf.map_fn(
        map_box_encodings,
        tf.range(num_unique_boxes),
        back_prop=False,
        dtype=(tf.int32, tf.float32))

    merged_boxes = tf.reshape(merged_boxes, [-1, 4])
    class_encodings = tf.reshape(class_encodings, [-1, num_classes])
    confidence_encodings = tf.reshape(confidence_encodings, [-1, num_classes])
    merged_box_indices = tf.reshape(merged_box_indices, [-1])
  return (merged_boxes, class_encodings, confidence_encodings,
          merged_box_indices)
def nearest_neighbor_upsampling(input_tensor, scale):
......@@ -895,7 +936,8 @@ def matmul_crop_and_resize(image, boxes, crop_size, scope=None):
Returns a tensor with crops from the input image at positions defined at
the bounding box locations in boxes. The cropped boxes are all resized
(with bilinear interpolation) to a fixed size = `[crop_height, crop_width]`.
The result is a 4-D tensor `[num_boxes, crop_height, crop_width, depth]`.
The result is a 5-D tensor `[batch, num_boxes, crop_height, crop_width,
depth]`.
Running time complexity:
O((# channels) * (# boxes) * (crop_size)^2 * M), where M is the number
......@@ -914,14 +956,13 @@ def matmul_crop_and_resize(image, boxes, crop_size, scope=None):
Args:
image: A `Tensor`. Must be one of the following types: `uint8`, `int8`,
`int16`, `int32`, `int64`, `half`, `float32`, `float64`.
`int16`, `int32`, `int64`, `half`, 'bfloat16', `float32`, `float64`.
A 4-D tensor of shape `[batch, image_height, image_width, depth]`.
Both `image_height` and `image_width` need to be positive.
boxes: A `Tensor` of type `float32`.
A 2-D tensor of shape `[num_boxes, 4]`. The `i`-th row of the tensor
specifies the coordinates of a box in the `box_ind[i]` image and is
specified in normalized coordinates `[y1, x1, y2, x2]`. A normalized
coordinate value of `y` is mapped to the image coordinate at
boxes: A `Tensor` of type `float32` or 'bfloat16'.
A 3-D tensor of shape `[batch, num_boxes, 4]`. The boxes are specified in
normalized coordinates and are of the form `[y1, x1, y2, x2]`. A
normalized coordinate value of `y` is mapped to the image coordinate at
`y * (image_height - 1)`, so as the `[0, 1]` interval of normalized image
height is mapped to `[0, image_height - 1] in image height coordinates.
We do allow y1 > y2, in which case the sampled crop is an up-down flipped
......@@ -935,14 +976,14 @@ def matmul_crop_and_resize(image, boxes, crop_size, scope=None):
scope: A name for the operation (optional).
Returns:
A 4-D tensor of shape `[num_boxes, crop_height, crop_width, depth]`
A 5-D tensor of shape `[batch, num_boxes, crop_height, crop_width, depth]`
Raises:
ValueError: if image tensor does not have shape
`[1, image_height, image_width, depth]` and all dimensions statically
`[batch, image_height, image_width, depth]` and all dimensions statically
defined.
ValueError: if boxes tensor does not have shape `[num_boxes, 4]` where
num_boxes > 0.
ValueError: if boxes tensor does not have shape `[batch, num_boxes, 4]`
where num_boxes > 0.
ValueError: if crop_size is not a list of two positive integers
"""
img_shape = image.shape.as_list()
......@@ -953,13 +994,11 @@ def matmul_crop_and_resize(image, boxes, crop_size, scope=None):
dimensions = img_shape + crop_size + boxes_shape
if not all([isinstance(dim, int) for dim in dimensions]):
raise ValueError('all input shapes must be statically defined')
if len(crop_size) != 2:
raise ValueError('`crop_size` must be a list of length 2')
if len(boxes_shape) != 2 or boxes_shape[1] != 4:
raise ValueError('`boxes` should have shape `[num_boxes, 4]`')
if len(img_shape) != 4 and img_shape[0] != 1:
if len(boxes_shape) != 3 or boxes_shape[2] != 4:
raise ValueError('`boxes` should have shape `[batch, num_boxes, 4]`')
if len(img_shape) != 4:
raise ValueError('image should have shape '
'`[1, image_height, image_width, depth]`')
'`[batch, image_height, image_width, depth]`')
num_crops = boxes_shape[0]
if not num_crops > 0:
raise ValueError('number of boxes must be > 0')
......@@ -968,43 +1007,69 @@ def matmul_crop_and_resize(image, boxes, crop_size, scope=None):
def _lin_space_weights(num, img_size):
if num > 1:
alpha = (img_size - 1) / float(num - 1)
indices = np.reshape(np.arange(num), (1, num))
start_weights = alpha * (num - 1 - indices)
stop_weights = alpha * indices
start_weights = tf.linspace(img_size - 1.0, 0.0, num)
stop_weights = img_size - 1 - start_weights
else:
start_weights = num * [.5 * (img_size - 1)]
stop_weights = num * [.5 * (img_size - 1)]
return (tf.constant(start_weights, dtype=tf.float32),
tf.constant(stop_weights, dtype=tf.float32))
start_weights = tf.constant(num * [.5 * (img_size - 1)], dtype=tf.float32)
stop_weights = tf.constant(num * [.5 * (img_size - 1)], dtype=tf.float32)
return (start_weights, stop_weights)
with tf.name_scope(scope, 'MatMulCropAndResize'):
y1_weights, y2_weights = _lin_space_weights(crop_size[0], img_height)
x1_weights, x2_weights = _lin_space_weights(crop_size[1], img_width)
[y1, x1, y2, x2] = tf.split(value=boxes, num_or_size_splits=4, axis=1)
y1_weights = tf.cast(y1_weights, boxes.dtype)
y2_weights = tf.cast(y2_weights, boxes.dtype)
x1_weights = tf.cast(x1_weights, boxes.dtype)
x2_weights = tf.cast(x2_weights, boxes.dtype)
[y1, x1, y2, x2] = tf.unstack(boxes, axis=2)
# Pixel centers of input image and grid points along height and width
image_idx_h = tf.constant(
np.reshape(np.arange(img_height), (1, 1, img_height)), dtype=tf.float32)
np.reshape(np.arange(img_height), (1, 1, 1, img_height)),
dtype=boxes.dtype)
image_idx_w = tf.constant(
np.reshape(np.arange(img_width), (1, 1, img_width)), dtype=tf.float32)
grid_pos_h = tf.expand_dims(y1 * y1_weights + y2 * y2_weights, 2)
grid_pos_w = tf.expand_dims(x1 * x1_weights + x2 * x2_weights, 2)
np.reshape(np.arange(img_width), (1, 1, 1, img_width)),
dtype=boxes.dtype)
grid_pos_h = tf.expand_dims(
tf.einsum('ab,c->abc', y1, y1_weights) + tf.einsum(
'ab,c->abc', y2, y2_weights),
axis=3)
grid_pos_w = tf.expand_dims(
tf.einsum('ab,c->abc', x1, x1_weights) + tf.einsum(
'ab,c->abc', x2, x2_weights),
axis=3)
# Create kernel matrices of pairwise kernel evaluations between pixel
# centers of image and grid points.
kernel_h = tf.nn.relu(1 - tf.abs(image_idx_h - grid_pos_h))
kernel_w = tf.nn.relu(1 - tf.abs(image_idx_w - grid_pos_w))
# TODO(jonathanhuang): investigate whether all channels can be processed
# without the explicit unstack --- possibly with a permute and map_fn call.
result_channels = []
for channel in tf.unstack(image, axis=3):
result_channels.append(
tf.matmul(
tf.matmul(kernel_h, tf.tile(channel, [num_crops, 1, 1])),
kernel_w, transpose_b=True))
return tf.stack(result_channels, axis=3)
# Compute matrix multiplication between the spatial dimensions of the image
# and height-wise kernel using einsum.
intermediate_image = tf.einsum('abci,aiop->abcop', kernel_h, image)
# Compute matrix multiplication between the spatial dimensions of the
# intermediate_image and width-wise kernel using einsum.
return tf.einsum('abno,abcop->abcnp', kernel_w, intermediate_image)
def native_crop_and_resize(image, boxes, crop_size, scope=None):
  """Same as `matmul_crop_and_resize` but uses tf.image.crop_and_resize.

  `tf.image.crop_and_resize` takes a flat [total_boxes, 4] boxes tensor plus a
  box-to-image index vector, so this wrapper flattens the batched
  [batch, num_boxes, 4] boxes, builds the index vector, and reshapes the
  result back to [batch, num_boxes, crop_height, crop_width, depth].

  Args:
    image: A [batch, image_height, image_width, depth] tensor.
    boxes: A [batch, num_boxes, 4] tensor of normalized box coordinates.
    crop_size: A list of two integers [crop_height, crop_width].
    scope: A name for the operation (optional).

  Returns:
    A [batch, num_boxes, crop_height, crop_width, depth] tensor of crops.
  """

  def get_box_inds(proposals):
    # Builds the flat box->image index vector expected by
    # tf.image.crop_and_resize: [0]*num_boxes + [1]*num_boxes + ...
    proposals_shape = proposals.get_shape().as_list()
    # Fall back to the dynamic shape when any dimension is unknown at
    # graph-construction time.
    if any(dim is None for dim in proposals_shape):
      proposals_shape = tf.shape(proposals)
    ones_mat = tf.ones(proposals_shape[:2], dtype=tf.int32)
    multiplier = tf.expand_dims(
        tf.range(start=0, limit=proposals_shape[0]), 1)
    # Broadcasting ones * column-vector of image indices gives each box its
    # owning image index; reshape flattens to [batch * num_boxes].
    return tf.reshape(ones_mat * multiplier, [-1])

  with tf.name_scope(scope, 'CropAndResize'):
    cropped_regions = tf.image.crop_and_resize(
        image, tf.reshape(boxes, [-1] + boxes.shape.as_list()[2:]),
        get_box_inds(boxes), crop_size)
    # Restore the leading [batch, num_boxes] dimensions on the flat crops.
    final_shape = tf.concat([tf.shape(boxes)[:2],
                             tf.shape(cropped_regions)[1:]], axis=0)
    return tf.reshape(cropped_regions, final_shape)
def expected_classification_loss_under_sampling(batch_cls_targets, cls_losses,
......
......@@ -1147,36 +1147,76 @@ class MergeBoxesWithMultipleLabelsTest(tf.test.TestCase):
[0.25, 0.25, 0.75, 0.75]],
dtype=tf.float32)
class_indices = tf.constant([0, 4, 2], dtype=tf.int32)
class_confidences = tf.constant([0.8, 0.2, 0.1], dtype=tf.float32)
num_classes = 5
merged_boxes, merged_classes, merged_box_indices = (
ops.merge_boxes_with_multiple_labels(boxes, class_indices, num_classes))
merged_boxes, merged_classes, merged_confidences, merged_box_indices = (
ops.merge_boxes_with_multiple_labels(
boxes, class_indices, class_confidences, num_classes))
expected_merged_boxes = np.array(
[[0.25, 0.25, 0.75, 0.75], [0.0, 0.0, 0.5, 0.75]], dtype=np.float32)
expected_merged_classes = np.array(
[[1, 0, 1, 0, 0], [0, 0, 0, 0, 1]], dtype=np.int32)
expected_merged_confidences = np.array(
[[0.8, 0, 0.1, 0, 0], [0, 0, 0, 0, 0.2]], dtype=np.float32)
expected_merged_box_indices = np.array([0, 1], dtype=np.int32)
with self.test_session() as sess:
np_merged_boxes, np_merged_classes, np_merged_box_indices = sess.run(
[merged_boxes, merged_classes, merged_box_indices])
if np_merged_classes[0, 0] != 1:
expected_merged_boxes = expected_merged_boxes[::-1, :]
expected_merged_classes = expected_merged_classes[::-1, :]
expected_merged_box_indices = expected_merged_box_indices[::-1, :]
(np_merged_boxes, np_merged_classes, np_merged_confidences,
np_merged_box_indices) = sess.run(
[merged_boxes, merged_classes, merged_confidences,
merged_box_indices])
self.assertAllClose(np_merged_boxes, expected_merged_boxes)
self.assertAllClose(np_merged_classes, expected_merged_classes)
self.assertAllClose(np_merged_confidences, expected_merged_confidences)
self.assertAllClose(np_merged_box_indices, expected_merged_box_indices)
def testMergeBoxesWithMultipleLabelsCornerCase(self):
  """Merges boxes when every distinct box appears twice with two labels."""
  # Eight input rows but only four distinct boxes; each distinct box occurs
  # twice (once in the first half, once mirrored in the second half), each
  # time with a different class index and confidence.
  boxes = tf.constant(
      [[0, 0, 1, 1], [0, 1, 1, 1], [1, 0, 1, 1], [1, 1, 1, 1],
       [1, 1, 1, 1], [1, 0, 1, 1], [0, 1, 1, 1], [0, 0, 1, 1]],
      dtype=tf.float32)
  class_indices = tf.constant([0, 1, 2, 3, 2, 1, 0, 3], dtype=tf.int32)
  class_confidences = tf.constant([0.1, 0.9, 0.2, 0.8, 0.3, 0.7, 0.4, 0.6],
                                  dtype=tf.float32)
  num_classes = 4
  merged_boxes, merged_classes, merged_confidences, merged_box_indices = (
      ops.merge_boxes_with_multiple_labels(
          boxes, class_indices, class_confidences, num_classes))
  # Each distinct box is expected exactly once, carrying a multi-hot class
  # vector and the per-class confidences from both of its source rows.
  expected_merged_boxes = np.array(
      [[0, 0, 1, 1], [0, 1, 1, 1], [1, 0, 1, 1], [1, 1, 1, 1]],
      dtype=np.float32)
  expected_merged_classes = np.array(
      [[1, 0, 0, 1], [1, 1, 0, 0], [0, 1, 1, 0], [0, 0, 1, 1]],
      dtype=np.int32)
  expected_merged_confidences = np.array(
      [[0.1, 0, 0, 0.6], [0.4, 0.9, 0, 0],
       [0, 0.7, 0.2, 0], [0, 0, 0.3, 0.8]], dtype=np.float32)
  expected_merged_box_indices = np.array([0, 1, 2, 3], dtype=np.int32)
  with self.test_session() as sess:
    (np_merged_boxes, np_merged_classes, np_merged_confidences,
     np_merged_box_indices) = sess.run(
         [merged_boxes, merged_classes, merged_confidences,
          merged_box_indices])
    self.assertAllClose(np_merged_boxes, expected_merged_boxes)
    self.assertAllClose(np_merged_classes, expected_merged_classes)
    self.assertAllClose(np_merged_confidences, expected_merged_confidences)
    self.assertAllClose(np_merged_box_indices, expected_merged_box_indices)
def testMergeBoxesWithEmptyInputs(self):
boxes = tf.constant([[]])
class_indices = tf.constant([])
boxes = tf.zeros([0, 4], dtype=tf.float32)
class_indices = tf.constant([], dtype=tf.int32)
class_confidences = tf.constant([], dtype=tf.float32)
num_classes = 5
merged_boxes, merged_classes, merged_box_indices = (
ops.merge_boxes_with_multiple_labels(boxes, class_indices, num_classes))
merged_boxes, merged_classes, merged_confidences, merged_box_indices = (
ops.merge_boxes_with_multiple_labels(
boxes, class_indices, class_confidences, num_classes))
with self.test_session() as sess:
np_merged_boxes, np_merged_classes, np_merged_box_indices = sess.run(
[merged_boxes, merged_classes, merged_box_indices])
(np_merged_boxes, np_merged_classes, np_merged_confidences,
np_merged_box_indices) = sess.run(
[merged_boxes, merged_classes, merged_confidences,
merged_box_indices])
self.assertAllEqual(np_merged_boxes.shape, [0, 4])
self.assertAllEqual(np_merged_classes.shape, [0, 5])
self.assertAllEqual(np_merged_confidences.shape, [0, 5])
self.assertAllEqual(np_merged_box_indices.shape, [0])
......@@ -1268,8 +1308,8 @@ class OpsTestMatMulCropAndResize(test_case.TestCase):
return ops.matmul_crop_and_resize(image, boxes, crop_size=[1, 1])
image = np.array([[[[1], [2]], [[3], [4]]]], dtype=np.float32)
boxes = np.array([[0, 0, 1, 1]], dtype=np.float32)
expected_output = [[[[2.5]]]]
boxes = np.array([[[0, 0, 1, 1]]], dtype=np.float32)
expected_output = [[[[[2.5]]]]]
crop_output = self.execute(graph_fn, [image, boxes])
self.assertAllClose(crop_output, expected_output)
......@@ -1279,8 +1319,8 @@ class OpsTestMatMulCropAndResize(test_case.TestCase):
return ops.matmul_crop_and_resize(image, boxes, crop_size=[1, 1])
image = np.array([[[[1], [2]], [[3], [4]]]], dtype=np.float32)
boxes = np.array([[1, 1, 0, 0]], dtype=np.float32)
expected_output = [[[[2.5]]]]
boxes = np.array([[[1, 1, 0, 0]]], dtype=np.float32)
expected_output = [[[[[2.5]]]]]
crop_output = self.execute(graph_fn, [image, boxes])
self.assertAllClose(crop_output, expected_output)
......@@ -1290,10 +1330,10 @@ class OpsTestMatMulCropAndResize(test_case.TestCase):
return ops.matmul_crop_and_resize(image, boxes, crop_size=[3, 3])
image = np.array([[[[1], [2]], [[3], [4]]]], dtype=np.float32)
boxes = np.array([[0, 0, 1, 1]], dtype=np.float32)
expected_output = [[[[1.0], [1.5], [2.0]],
[[2.0], [2.5], [3.0]],
[[3.0], [3.5], [4.0]]]]
boxes = np.array([[[0, 0, 1, 1]]], dtype=np.float32)
expected_output = [[[[[1.0], [1.5], [2.0]],
[[2.0], [2.5], [3.0]],
[[3.0], [3.5], [4.0]]]]]
crop_output = self.execute(graph_fn, [image, boxes])
self.assertAllClose(crop_output, expected_output)
......@@ -1303,10 +1343,10 @@ class OpsTestMatMulCropAndResize(test_case.TestCase):
return ops.matmul_crop_and_resize(image, boxes, crop_size=[3, 3])
image = np.array([[[[1], [2]], [[3], [4]]]], dtype=np.float32)
boxes = np.array([[1, 1, 0, 0]], dtype=np.float32)
expected_output = [[[[4.0], [3.5], [3.0]],
[[3.0], [2.5], [2.0]],
[[2.0], [1.5], [1.0]]]]
boxes = np.array([[[1, 1, 0, 0]]], dtype=np.float32)
expected_output = [[[[[4.0], [3.5], [3.0]],
[[3.0], [2.5], [2.0]],
[[2.0], [1.5], [1.0]]]]]
crop_output = self.execute(graph_fn, [image, boxes])
self.assertAllClose(crop_output, expected_output)
......@@ -1318,14 +1358,14 @@ class OpsTestMatMulCropAndResize(test_case.TestCase):
image = np.array([[[[1], [2], [3]],
[[4], [5], [6]],
[[7], [8], [9]]]], dtype=np.float32)
boxes = np.array([[0, 0, 1, 1],
[0, 0, .5, .5]], dtype=np.float32)
expected_output = [[[[1], [3]], [[7], [9]]],
[[[1], [2]], [[4], [5]]]]
boxes = np.array([[[0, 0, 1, 1],
[0, 0, .5, .5]]], dtype=np.float32)
expected_output = [[[[[1], [3]], [[7], [9]]],
[[[1], [2]], [[4], [5]]]]]
crop_output = self.execute(graph_fn, [image, boxes])
self.assertAllClose(crop_output, expected_output)
def testMatMulCropAndResize3x3To2x2MultiChannel(self):
def testMatMulCropAndResize3x3To2x2_2Channels(self):
def graph_fn(image, boxes):
return ops.matmul_crop_and_resize(image, boxes, crop_size=[2, 2])
......@@ -1333,10 +1373,32 @@ class OpsTestMatMulCropAndResize(test_case.TestCase):
image = np.array([[[[1, 0], [2, 1], [3, 2]],
[[4, 3], [5, 4], [6, 5]],
[[7, 6], [8, 7], [9, 8]]]], dtype=np.float32)
boxes = np.array([[0, 0, 1, 1],
[0, 0, .5, .5]], dtype=np.float32)
expected_output = [[[[1, 0], [3, 2]], [[7, 6], [9, 8]]],
[[[1, 0], [2, 1]], [[4, 3], [5, 4]]]]
boxes = np.array([[[0, 0, 1, 1],
[0, 0, .5, .5]]], dtype=np.float32)
expected_output = [[[[[1, 0], [3, 2]], [[7, 6], [9, 8]]],
[[[1, 0], [2, 1]], [[4, 3], [5, 4]]]]]
crop_output = self.execute(graph_fn, [image, boxes])
self.assertAllClose(crop_output, expected_output)
def testBatchMatMulCropAndResize3x3To2x2_2Channels(self):
  """Crops a batch of two identical 3x3, 2-channel images to 2x2 patches."""

  def crop_fn(image, boxes):
    return ops.matmul_crop_and_resize(image, boxes, crop_size=[2, 2])

  # Both batch elements share the same 3x3 two-channel image.
  single_image = [[[1, 0], [2, 1], [3, 2]],
                  [[4, 3], [5, 4], [6, 5]],
                  [[7, 6], [8, 7], [9, 8]]]
  image = np.array([single_image, single_image], dtype=np.float32)
  # Second image's boxes are coordinate-flipped (y1 > y2, x1 > x2), which
  # should yield flipped crops.
  boxes = np.array([[[0, 0, 1, 1],
                     [0, 0, .5, .5]],
                    [[1, 1, 0, 0],
                     [.5, .5, 0, 0]]], dtype=np.float32)
  expected_output = [
      [[[[1, 0], [3, 2]], [[7, 6], [9, 8]]],
       [[[1, 0], [2, 1]], [[4, 3], [5, 4]]]],
      [[[[9, 8], [7, 6]], [[3, 2], [1, 0]]],
       [[[5, 4], [4, 3]], [[2, 1], [1, 0]]]]]
  crop_output = self.execute(crop_fn, [image, boxes])
  self.assertAllClose(crop_output, expected_output)
......@@ -1348,10 +1410,10 @@ class OpsTestMatMulCropAndResize(test_case.TestCase):
image = np.array([[[[1], [2], [3]],
[[4], [5], [6]],
[[7], [8], [9]]]], dtype=np.float32)
boxes = np.array([[1, 1, 0, 0],
[.5, .5, 0, 0]], dtype=np.float32)
expected_output = [[[[9], [7]], [[3], [1]]],
[[[5], [4]], [[2], [1]]]]
boxes = np.array([[[1, 1, 0, 0],
[.5, .5, 0, 0]]], dtype=np.float32)
expected_output = [[[[[9], [7]], [[3], [1]]],
[[[5], [4]], [[2], [1]]]]]
crop_output = self.execute(graph_fn, [image, boxes])
self.assertAllClose(crop_output, expected_output)
......@@ -1363,6 +1425,31 @@ class OpsTestMatMulCropAndResize(test_case.TestCase):
_ = ops.matmul_crop_and_resize(image, boxes, crop_size)
class OpsTestCropAndResize(test_case.TestCase):
  """Tests for ops.native_crop_and_resize (tf.image-backed crop-and-resize)."""

  def testBatchCropAndResize3x3To2x2_2Channels(self):
    """Crops a batch of two 3x3, 2-channel images to 2x2 patches."""

    def graph_fn(image, boxes):
      return ops.native_crop_and_resize(image, boxes, crop_size=[2, 2])

    # Two identical 3x3 images with 2 channels each.
    image = np.array([[[[1, 0], [2, 1], [3, 2]],
                       [[4, 3], [5, 4], [6, 5]],
                       [[7, 6], [8, 7], [9, 8]]],
                      [[[1, 0], [2, 1], [3, 2]],
                       [[4, 3], [5, 4], [6, 5]],
                       [[7, 6], [8, 7], [9, 8]]]], dtype=np.float32)
    # The second image's boxes are coordinate-flipped (y1 > y2, x1 > x2),
    # producing up-down / left-right flipped crops.
    boxes = np.array([[[0, 0, 1, 1],
                       [0, 0, .5, .5]],
                      [[1, 1, 0, 0],
                       [.5, .5, 0, 0]]], dtype=np.float32)
    expected_output = [[[[[1, 0], [3, 2]], [[7, 6], [9, 8]]],
                        [[[1, 0], [2, 1]], [[4, 3], [5, 4]]]],
                       [[[[9, 8], [7, 6]], [[3, 2], [1, 0]]],
                        [[[5, 4], [4, 3]], [[2, 1], [1, 0]]]]]
    # NOTE(review): runs via execute_cpu, presumably because this op has no
    # TPU-compatible path -- confirm against test_case.TestCase.
    crop_output = self.execute_cpu(graph_fn, [image, boxes])
    self.assertAllClose(crop_output, expected_output)
class OpsTestExpectedClassificationLoss(test_case.TestCase):
def testExpectedClassificationLossUnderSamplingWithHardLabels(self):
......
......@@ -342,3 +342,26 @@ def assert_shape_equal_along_first_dimension(shape_a, shape_b):
else: return tf.no_op()
else:
return tf.assert_equal(shape_a[0], shape_b[0])
def assert_box_normalized(boxes, maximum_normalized_coordinate=1.1):
  """Asserts the input box tensor is normalized.

  Args:
    boxes: a tensor of shape [N, 4] where N is the number of boxes.
    maximum_normalized_coordinate: Maximum coordinate value to be considered
      as normalized, default to 1.1.

  Returns:
    a tf.Assert op which fails when the input box tensor is not normalized.
    Note the check runs at graph execution time; no Python exception is
    raised while constructing the op.
  """
  box_minimum = tf.reduce_min(boxes)
  box_maximum = tf.reduce_max(boxes)
  return tf.Assert(
      tf.logical_and(
          tf.less_equal(box_maximum, maximum_normalized_coordinate),
          tf.greater_equal(box_minimum, 0)),
      # Include a human-readable message next to the offending tensor so an
      # assertion failure is easier to diagnose than a bare tensor dump.
      ['maximum box coordinate value is larger than %f or lower than 0: ' %
       maximum_normalized_coordinate, boxes])
......@@ -14,6 +14,7 @@
# ==============================================================================
"""A convenience wrapper around tf.test.TestCase to enable TPU tests."""
import os
import tensorflow as tf
from tensorflow.contrib import tpu
......@@ -23,6 +24,8 @@ flags.DEFINE_bool('tpu_test', False, 'Whether to configure test for TPU.')
FLAGS = flags.FLAGS
class TestCase(tf.test.TestCase):
"""Extends tf.test.TestCase to optionally allow running tests on TPU."""
......
......@@ -24,6 +24,9 @@ from object_detection.core import box_predictor
from object_detection.core import matcher
from object_detection.utils import shape_utils
# Default size (both width and height) used for testing mask predictions.
DEFAULT_MASK_SIZE = 5
class MockBoxCoder(box_coder.BoxCoder):
"""Simple `difference` BoxCoder."""
......@@ -42,8 +45,9 @@ class MockBoxCoder(box_coder.BoxCoder):
class MockBoxPredictor(box_predictor.BoxPredictor):
"""Simple box predictor that ignores inputs and outputs all zeros."""
def __init__(self, is_training, num_classes):
def __init__(self, is_training, num_classes, predict_mask=False):
super(MockBoxPredictor, self).__init__(is_training, num_classes)
self._predict_mask = predict_mask
def _predict(self, image_features, num_predictions_per_location):
image_feature = image_features[0]
......@@ -57,17 +61,29 @@ class MockBoxPredictor(box_predictor.BoxPredictor):
(batch_size, num_anchors, 1, code_size), dtype=tf.float32)
class_predictions_with_background = zero + tf.zeros(
(batch_size, num_anchors, self.num_classes + 1), dtype=tf.float32)
return {box_predictor.BOX_ENCODINGS: box_encodings,
box_predictor.CLASS_PREDICTIONS_WITH_BACKGROUND:
class_predictions_with_background}
masks = zero + tf.zeros(
(batch_size, num_anchors, self.num_classes, DEFAULT_MASK_SIZE,
DEFAULT_MASK_SIZE),
dtype=tf.float32)
predictions_dict = {
box_predictor.BOX_ENCODINGS:
box_encodings,
box_predictor.CLASS_PREDICTIONS_WITH_BACKGROUND:
class_predictions_with_background
}
if self._predict_mask:
predictions_dict[box_predictor.MASK_PREDICTIONS] = masks
return predictions_dict
class MockKerasBoxPredictor(box_predictor.KerasBoxPredictor):
"""Simple box predictor that ignores inputs and outputs all zeros."""
def __init__(self, is_training, num_classes):
def __init__(self, is_training, num_classes, predict_mask=False):
super(MockKerasBoxPredictor, self).__init__(
is_training, num_classes, False, False)
self._predict_mask = predict_mask
def _predict(self, image_features, **kwargs):
image_feature = image_features[0]
......@@ -81,9 +97,19 @@ class MockKerasBoxPredictor(box_predictor.KerasBoxPredictor):
(batch_size, num_anchors, 1, code_size), dtype=tf.float32)
class_predictions_with_background = zero + tf.zeros(
(batch_size, num_anchors, self.num_classes + 1), dtype=tf.float32)
return {box_predictor.BOX_ENCODINGS: box_encodings,
box_predictor.CLASS_PREDICTIONS_WITH_BACKGROUND:
class_predictions_with_background}
masks = zero + tf.zeros(
(batch_size, num_anchors, self.num_classes, DEFAULT_MASK_SIZE,
DEFAULT_MASK_SIZE),
dtype=tf.float32)
predictions_dict = {
box_predictor.BOX_ENCODINGS:
box_encodings,
box_predictor.CLASS_PREDICTIONS_WITH_BACKGROUND:
class_predictions_with_background
}
if self._predict_mask:
predictions_dict[box_predictor.MASK_PREDICTIONS] = masks
return predictions_dict
class MockAnchorGenerator(anchor_generator.AnchorGenerator):
......@@ -103,7 +129,7 @@ class MockAnchorGenerator(anchor_generator.AnchorGenerator):
class MockMatcher(matcher.Matcher):
"""Simple matcher that matches first anchor to first groundtruth box."""
def _match(self, similarity_matrix):
def _match(self, similarity_matrix, valid_rows):
return tf.constant([0, -1, -1, -1], dtype=tf.int32)
......
......@@ -19,6 +19,8 @@ These functions often receive an image, perform some visualization on the image.
The functions do not return a value, instead they modify the image itself.
"""
from abc import ABCMeta
from abc import abstractmethod
import collections
import functools
# Set headless-friendly backend.
......@@ -731,3 +733,158 @@ def add_hist_image_summary(values, bins, name):
return image
hist_plot = tf.py_func(hist_plot, [values, bins], tf.uint8)
tf.summary.image(name, hist_plot)
class EvalMetricOpsVisualization(object):
  """Abstract base class responsible for visualizations during evaluation.

  Currently, summary images are not run during evaluation. One way to produce
  evaluation images in Tensorboard is to provide tf.summary.image strings as
  `value_ops` in tf.estimator.EstimatorSpec's `eval_metric_ops`. This class is
  responsible for accruing images (with overlaid detections and groundtruth)
  and returning a dictionary that can be passed to `eval_metric_ops`.
  """
  __metaclass__ = ABCMeta

  def __init__(self,
               category_index,
               max_examples_to_draw=5,
               max_boxes_to_draw=20,
               min_score_thresh=0.2,
               use_normalized_coordinates=True,
               summary_name_prefix='evaluation_image'):
    """Creates an EvalMetricOpsVisualization.

    Args:
      category_index: A category index (dictionary) produced from a labelmap.
      max_examples_to_draw: The maximum number of example summaries to produce.
      max_boxes_to_draw: The maximum number of boxes to draw for detections.
      min_score_thresh: The minimum score threshold for showing detections.
      use_normalized_coordinates: Whether to assume boxes and keypoints are in
        normalized coordinates (as opposed to absolute coordinates).
        Default is True.
      summary_name_prefix: A string prefix for each image summary.
    """
    self._category_index = category_index
    self._max_examples_to_draw = max_examples_to_draw
    self._max_boxes_to_draw = max_boxes_to_draw
    self._min_score_thresh = min_score_thresh
    self._use_normalized_coordinates = use_normalized_coordinates
    self._summary_name_prefix = summary_name_prefix
    # Python-side accumulator of image arrays, mutated from within py_funcs.
    self._images = []

  def clear(self):
    # Drops all accrued images so the next evaluation round starts fresh.
    self._images = []

  def add_images(self, images):
    """Store a list of images, each with shape [1, H, W, C]."""
    if len(self._images) >= self._max_examples_to_draw:
      return

    # Store images and clip list if necessary.
    self._images.extend(images)
    if len(self._images) > self._max_examples_to_draw:
      self._images[self._max_examples_to_draw:] = []

  def get_estimator_eval_metric_ops(self, eval_dict):
    """Returns metric ops for use in tf.estimator.EstimatorSpec.

    Args:
      eval_dict: A dictionary that holds an image, groundtruth, and detections
        for a single example. See eval_util.result_dict_for_single_example() for
        a convenient method for constructing such a dictionary. The dictionary
        contains
        fields.InputDataFields.original_image: [1, H, W, 3] image.
        fields.InputDataFields.groundtruth_boxes - [num_boxes, 4] float32
          tensor with groundtruth boxes in range [0.0, 1.0].
        fields.InputDataFields.groundtruth_classes - [num_boxes] int64
          tensor with 1-indexed groundtruth classes.
        fields.InputDataFields.groundtruth_instance_masks - (optional)
          [num_boxes, H, W] int64 tensor with instance masks.
        fields.DetectionResultFields.detection_boxes - [max_num_boxes, 4]
          float32 tensor with detection boxes in range [0.0, 1.0].
        fields.DetectionResultFields.detection_classes - [max_num_boxes]
          int64 tensor with 1-indexed detection classes.
        fields.DetectionResultFields.detection_scores - [max_num_boxes]
          float32 tensor with detection scores.
        fields.DetectionResultFields.detection_masks - (optional)
          [max_num_boxes, H, W] float32 tensor of binarized masks.
        fields.DetectionResultFields.detection_keypoints - (optional)
          [max_num_boxes, num_keypoints, 2] float32 tensor with keypoints.

    Returns:
      A dictionary of image summary names to tuple of (value_op, update_op). The
      `update_op` is the same for all items in the dictionary, and is
      responsible for saving a single side-by-side image with detections and
      groundtruth. Each `value_op` holds the tf.summary.image string for a given
      image.
    """
    images = self.images_from_evaluation_dict(eval_dict)

    def get_images():
      """Returns a list of images, padded to self._max_examples_to_draw."""
      images = self._images
      # Pad with scalar uint8 placeholders so py_func always returns exactly
      # `max_examples_to_draw` tensors; rank distinguishes padding below.
      while len(images) < self._max_examples_to_draw:
        images.append(np.array(0, dtype=np.uint8))
      self.clear()
      return images

    def image_summary_or_default_string(summary_name, image):
      """Returns image summaries for non-padded elements."""
      # Real images are rank 4 ([1, H, W, C]); padding placeholders are
      # rank-0 scalars and yield an empty summary string instead.
      return tf.cond(
          tf.equal(tf.size(tf.shape(image)), 4),
          lambda: tf.summary.image(summary_name, image),
          lambda: tf.constant(''))

    # update_op appends this example's image to the Python-side accumulator;
    # image_tensors reads (and clears) the accumulated images.
    update_op = tf.py_func(self.add_images, [images], [])
    image_tensors = tf.py_func(
        get_images, [], [tf.uint8] * self._max_examples_to_draw)
    eval_metric_ops = {}
    for i, image in enumerate(image_tensors):
      summary_name = self._summary_name_prefix + '/' + str(i)
      value_op = image_summary_or_default_string(summary_name, image)
      # Every entry shares the same update_op by design (see docstring).
      eval_metric_ops[summary_name] = (value_op, update_op)
    return eval_metric_ops

  @abstractmethod
  def images_from_evaluation_dict(self, eval_dict):
    """Converts evaluation dictionary into a list of image tensors.

    To be overridden by implementations.

    Args:
      eval_dict: A dictionary with all the necessary information for producing
        visualizations.

    Returns:
      A list of [1, H, W, C] uint8 tensors.
    """
    raise NotImplementedError
class VisualizeSingleFrameDetections(EvalMetricOpsVisualization):
  """Class responsible for single-frame object detection visualizations."""

  def __init__(self,
               category_index,
               max_examples_to_draw=5,
               max_boxes_to_draw=20,
               min_score_thresh=0.2,
               use_normalized_coordinates=True,
               summary_name_prefix='Detections_Left_Groundtruth_Right'):
    """Forwards all configuration to EvalMetricOpsVisualization unchanged."""
    super(VisualizeSingleFrameDetections, self).__init__(
        category_index=category_index,
        max_examples_to_draw=max_examples_to_draw,
        max_boxes_to_draw=max_boxes_to_draw,
        min_score_thresh=min_score_thresh,
        use_normalized_coordinates=use_normalized_coordinates,
        summary_name_prefix=summary_name_prefix)

  def images_from_evaluation_dict(self, eval_dict):
    """Renders one side-by-side detections/groundtruth image for eval_dict."""
    side_by_side_image = draw_side_by_side_evaluation_image(
        eval_dict,
        self._category_index,
        self._max_boxes_to_draw,
        self._min_score_thresh,
        self._use_normalized_coordinates)
    return [side_by_side_image]
......@@ -21,6 +21,7 @@ import numpy as np
import PIL.Image as Image
import tensorflow as tf
from object_detection.core import standard_fields as fields
from object_detection.utils import visualization_utils
_TESTDATA_PATH = 'object_detection/test_images'
......@@ -225,6 +226,80 @@ class VisualizationUtilsTest(tf.test.TestCase):
with self.test_session():
hist_image_summary.eval()
def test_eval_metric_ops(self):
  """Checks value/update ops produced by VisualizeSingleFrameDetections.

  Runs the shared update op enough times to fill `max_examples_to_draw` and
  verifies every value op yields a non-empty summary string; then runs it one
  time fewer and verifies the final value op is the empty-string placeholder.
  """
  category_index = {1: {'id': 1, 'name': 'dog'}, 2: {'id': 2, 'name': 'cat'}}
  max_examples_to_draw = 4
  metric_op_base = 'Detections_Left_Groundtruth_Right'
  eval_metric_ops = visualization_utils.VisualizeSingleFrameDetections(
      category_index,
      max_examples_to_draw=max_examples_to_draw,
      summary_name_prefix=metric_op_base)
  original_image = tf.placeholder(tf.uint8, [1, None, None, 3])
  detection_boxes = tf.random_uniform([20, 4],
                                      minval=0.0,
                                      maxval=1.0,
                                      dtype=tf.float32)
  detection_classes = tf.random_uniform([20],
                                        minval=1,
                                        maxval=3,
                                        dtype=tf.int64)
  detection_scores = tf.random_uniform([20],
                                       minval=0.,
                                       maxval=1.,
                                       dtype=tf.float32)
  groundtruth_boxes = tf.random_uniform([8, 4],
                                        minval=0.0,
                                        maxval=1.0,
                                        dtype=tf.float32)
  groundtruth_classes = tf.random_uniform([8],
                                          minval=1,
                                          maxval=3,
                                          dtype=tf.int64)
  eval_dict = {
      fields.DetectionResultFields.detection_boxes: detection_boxes,
      fields.DetectionResultFields.detection_classes: detection_classes,
      fields.DetectionResultFields.detection_scores: detection_scores,
      fields.InputDataFields.original_image: original_image,
      fields.InputDataFields.groundtruth_boxes: groundtruth_boxes,
      fields.InputDataFields.groundtruth_classes: groundtruth_classes}
  metric_ops = eval_metric_ops.get_estimator_eval_metric_ops(eval_dict)
  # Every entry shares the same update op; grab it from an arbitrary key.
  # `next(iter(...))` (rather than `.keys()[0]`) works on both Python 2 and
  # Python 3, where dict views are not indexable.
  _, update_op = metric_ops[next(iter(metric_ops))]

  with self.test_session() as sess:
    sess.run(tf.global_variables_initializer())
    value_ops = {}
    # `.items()` instead of Python-2-only `.iteritems()`.
    for key, (value_op, _) in metric_ops.items():
      value_ops[key] = value_op

    # First run enough update steps to surpass `max_examples_to_draw`.
    for i in range(max_examples_to_draw):
      # Use a unique image shape on each eval image.
      sess.run(update_op, feed_dict={
          original_image: np.random.randint(low=0,
                                            high=256,
                                            size=(1, 6 + i, 7 + i, 3),
                                            dtype=np.uint8)
      })
    value_ops_out = sess.run(value_ops)
    for key, value_op in value_ops_out.items():
      self.assertNotEqual('', value_op)

    # Now run fewer update steps than `max_examples_to_draw`. A single value
    # op will be the empty string, since not enough image summaries can be
    # produced.
    for i in range(max_examples_to_draw - 1):
      # Use a unique image shape on each eval image.
      sess.run(update_op, feed_dict={
          original_image: np.random.randint(low=0,
                                            high=256,
                                            size=(1, 6 + i, 7 + i, 3),
                                            dtype=np.uint8)
      })
    value_ops_out = sess.run(value_ops)
    self.assertEqual(
        '',
        value_ops_out[metric_op_base + '/' + str(max_examples_to_draw - 1)])
if __name__ == '__main__':
tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment