Unverified Commit fd7b6887 authored by Jonathan Huang's avatar Jonathan Huang Committed by GitHub
Browse files

Merge pull request #3293 from pkulzc/master

Internal changes of object_detection 
parents f98ec55e 1efe98bb
...@@ -108,12 +108,10 @@ model { ...@@ -108,12 +108,10 @@ model {
loss { loss {
classification_loss { classification_loss {
weighted_sigmoid { weighted_sigmoid {
anchorwise_output: true
} }
} }
localization_loss { localization_loss {
weighted_smooth_l1 { weighted_smooth_l1 {
anchorwise_output: true
} }
} }
hard_example_miner { hard_example_miner {
......
package(
default_visibility = ["//visibility:public"],
)
licenses(["notice"])
exports_files([
"pets_examples.record",
])
package(
default_visibility = ["//visibility:public"],
)
licenses(["notice"])
exports_files([
"image1.jpg",
"image2.jpg",
])
...@@ -47,9 +47,10 @@ import os ...@@ -47,9 +47,10 @@ import os
import tensorflow as tf import tensorflow as tf
from object_detection import trainer from object_detection import trainer
from object_detection.builders import input_reader_builder from object_detection.builders import dataset_builder
from object_detection.builders import model_builder from object_detection.builders import model_builder
from object_detection.utils import config_util from object_detection.utils import config_util
from object_detection.utils import dataset_util
tf.logging.set_verbosity(tf.logging.INFO) tf.logging.set_verbosity(tf.logging.INFO)
...@@ -114,8 +115,13 @@ def main(_): ...@@ -114,8 +115,13 @@ def main(_):
model_config=model_config, model_config=model_config,
is_training=True) is_training=True)
create_input_dict_fn = functools.partial( def get_next(config):
input_reader_builder.build, input_config) return dataset_util.make_initializable_iterator(
dataset_builder.build(
config, num_workers=FLAGS.worker_replicas,
worker_index=FLAGS.task)).get_next()
create_input_dict_fn = functools.partial(get_next, input_config)
env = json.loads(os.environ.get('TF_CONFIG', '{}')) env = json.loads(os.environ.get('TF_CONFIG', '{}'))
cluster_data = env.get('cluster', None) cluster_data = env.get('cluster', None)
......
...@@ -108,6 +108,8 @@ def get_inputs(input_queue, num_classes, merge_multiple_label_boxes=False): ...@@ -108,6 +108,8 @@ def get_inputs(input_queue, num_classes, merge_multiple_label_boxes=False):
keypoints_list: a list of 3-D float tensors of shape [num_boxes, keypoints_list: a list of 3-D float tensors of shape [num_boxes,
num_keypoints, 2] containing keypoints for objects if present in the num_keypoints, 2] containing keypoints for objects if present in the
input queue. Else returns None. input queue. Else returns None.
weights_lists: a list of 1-D float32 tensors of shape [num_boxes]
containing groundtruth weight for each box.
""" """
read_data_list = input_queue.dequeue() read_data_list = input_queue.dequeue()
label_id_offset = 1 label_id_offset = 1
...@@ -132,7 +134,10 @@ def get_inputs(input_queue, num_classes, merge_multiple_label_boxes=False): ...@@ -132,7 +134,10 @@ def get_inputs(input_queue, num_classes, merge_multiple_label_boxes=False):
if (merge_multiple_label_boxes and ( if (merge_multiple_label_boxes and (
masks_gt is not None or keypoints_gt is not None)): masks_gt is not None or keypoints_gt is not None)):
raise NotImplementedError('Multi-label support is only for boxes.') raise NotImplementedError('Multi-label support is only for boxes.')
return image, key, location_gt, classes_gt, masks_gt, keypoints_gt weights_gt = read_data.get(
fields.InputDataFields.groundtruth_weights)
return (image, key, location_gt, classes_gt, masks_gt, keypoints_gt,
weights_gt)
return zip(*map(extract_images_and_targets, read_data_list)) return zip(*map(extract_images_and_targets, read_data_list))
...@@ -147,12 +152,21 @@ def _create_losses(input_queue, create_model_fn, train_config): ...@@ -147,12 +152,21 @@ def _create_losses(input_queue, create_model_fn, train_config):
""" """
detection_model = create_model_fn() detection_model = create_model_fn()
(images, _, groundtruth_boxes_list, groundtruth_classes_list, (images, _, groundtruth_boxes_list, groundtruth_classes_list,
groundtruth_masks_list, groundtruth_keypoints_list) = get_inputs( groundtruth_masks_list, groundtruth_keypoints_list, _) = get_inputs(
input_queue, input_queue,
detection_model.num_classes, detection_model.num_classes,
train_config.merge_multiple_label_boxes) train_config.merge_multiple_label_boxes)
images = [detection_model.preprocess(image) for image in images]
images = tf.concat(images, 0) preprocessed_images = []
true_image_shapes = []
for image in images:
resized_image, true_image_shape = detection_model.preprocess(image)
preprocessed_images.append(resized_image)
true_image_shapes.append(true_image_shape)
images = tf.concat(preprocessed_images, 0)
true_image_shapes = tf.concat(true_image_shapes, 0)
if any(mask is None for mask in groundtruth_masks_list): if any(mask is None for mask in groundtruth_masks_list):
groundtruth_masks_list = None groundtruth_masks_list = None
if any(keypoints is None for keypoints in groundtruth_keypoints_list): if any(keypoints is None for keypoints in groundtruth_keypoints_list):
...@@ -162,16 +176,16 @@ def _create_losses(input_queue, create_model_fn, train_config): ...@@ -162,16 +176,16 @@ def _create_losses(input_queue, create_model_fn, train_config):
groundtruth_classes_list, groundtruth_classes_list,
groundtruth_masks_list, groundtruth_masks_list,
groundtruth_keypoints_list) groundtruth_keypoints_list)
prediction_dict = detection_model.predict(images) prediction_dict = detection_model.predict(images, true_image_shapes)
losses_dict = detection_model.loss(prediction_dict) losses_dict = detection_model.loss(prediction_dict, true_image_shapes)
for loss_tensor in losses_dict.values(): for loss_tensor in losses_dict.values():
tf.losses.add_loss(loss_tensor) tf.losses.add_loss(loss_tensor)
def train(create_tensor_dict_fn, create_model_fn, train_config, master, task, def train(create_tensor_dict_fn, create_model_fn, train_config, master, task,
num_clones, worker_replicas, clone_on_cpu, ps_tasks, worker_job_name, num_clones, worker_replicas, clone_on_cpu, ps_tasks, worker_job_name,
is_chief, train_dir): is_chief, train_dir, graph_hook_fn=None):
"""Training function for detection models. """Training function for detection models.
Args: Args:
...@@ -188,6 +202,10 @@ def train(create_tensor_dict_fn, create_model_fn, train_config, master, task, ...@@ -188,6 +202,10 @@ def train(create_tensor_dict_fn, create_model_fn, train_config, master, task,
worker_job_name: Name of the worker job. worker_job_name: Name of the worker job.
is_chief: Whether this replica is the chief replica. is_chief: Whether this replica is the chief replica.
train_dir: Directory to write checkpoints and training summaries to. train_dir: Directory to write checkpoints and training summaries to.
graph_hook_fn: Optional function that is called after the training graph is
completely built. This is helpful to perform additional changes to the
training graph such as optimizing batchnorm. The function should modify
the default graph.
""" """
detection_model = create_model_fn() detection_model = create_model_fn()
...@@ -217,7 +235,7 @@ def train(create_tensor_dict_fn, create_model_fn, train_config, master, task, ...@@ -217,7 +235,7 @@ def train(create_tensor_dict_fn, create_model_fn, train_config, master, task,
train_config.prefetch_queue_capacity, data_augmentation_options) train_config.prefetch_queue_capacity, data_augmentation_options)
# Gather initial summaries. # Gather initial summaries.
# TODO(rathodv): See if summaries can be added/extracted from global tf # TODO: See if summaries can be added/extracted from global tf
# collections so that they don't have to be passed around. # collections so that they don't have to be passed around.
summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES)) summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
global_summaries = set([]) global_summaries = set([])
...@@ -233,8 +251,10 @@ def train(create_tensor_dict_fn, create_model_fn, train_config, master, task, ...@@ -233,8 +251,10 @@ def train(create_tensor_dict_fn, create_model_fn, train_config, master, task,
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope) update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)
with tf.device(deploy_config.optimizer_device()): with tf.device(deploy_config.optimizer_device()):
training_optimizer = optimizer_builder.build(train_config.optimizer, training_optimizer, optimizer_summary_vars = optimizer_builder.build(
global_summaries) train_config.optimizer)
for var in optimizer_summary_vars:
tf.summary.scalar(var.op.name, var)
sync_optimizer = None sync_optimizer = None
if train_config.sync_replicas: if train_config.sync_replicas:
...@@ -258,8 +278,11 @@ def train(create_tensor_dict_fn, create_model_fn, train_config, master, task, ...@@ -258,8 +278,11 @@ def train(create_tensor_dict_fn, create_model_fn, train_config, master, task,
init_fn = initializer_fn init_fn = initializer_fn
with tf.device(deploy_config.optimizer_device()): with tf.device(deploy_config.optimizer_device()):
regularization_losses = (None if train_config.add_regularization_loss
else [])
total_loss, grads_and_vars = model_deploy.optimize_clones( total_loss, grads_and_vars = model_deploy.optimize_clones(
clones, training_optimizer, regularization_losses=None) clones, training_optimizer,
regularization_losses=regularization_losses)
total_loss = tf.check_numerics(total_loss, 'LossTensor is inf or nan.') total_loss = tf.check_numerics(total_loss, 'LossTensor is inf or nan.')
# Optionally multiply bias gradients by train_config.bias_grad_multiplier. # Optionally multiply bias gradients by train_config.bias_grad_multiplier.
...@@ -285,11 +308,14 @@ def train(create_tensor_dict_fn, create_model_fn, train_config, master, task, ...@@ -285,11 +308,14 @@ def train(create_tensor_dict_fn, create_model_fn, train_config, master, task,
grad_updates = training_optimizer.apply_gradients(grads_and_vars, grad_updates = training_optimizer.apply_gradients(grads_and_vars,
global_step=global_step) global_step=global_step)
update_ops.append(grad_updates) update_ops.append(grad_updates)
update_op = tf.group(*update_ops, name='update_barrier')
update_op = tf.group(*update_ops)
with tf.control_dependencies([update_op]): with tf.control_dependencies([update_op]):
train_tensor = tf.identity(total_loss, name='train_op') train_tensor = tf.identity(total_loss, name='train_op')
if graph_hook_fn:
with tf.device(deploy_config.variables_device()):
graph_hook_fn()
# Add summaries. # Add summaries.
for model_var in slim.get_model_variables(): for model_var in slim.get_model_variables():
global_summaries.add(tf.summary.histogram(model_var.op.name, model_var)) global_summaries.add(tf.summary.histogram(model_var.op.name, model_var))
......
...@@ -51,10 +51,8 @@ class FakeDetectionModel(model.DetectionModel): ...@@ -51,10 +51,8 @@ class FakeDetectionModel(model.DetectionModel):
def __init__(self): def __init__(self):
super(FakeDetectionModel, self).__init__(num_classes=NUMBER_OF_CLASSES) super(FakeDetectionModel, self).__init__(num_classes=NUMBER_OF_CLASSES)
self._classification_loss = losses.WeightedSigmoidClassificationLoss( self._classification_loss = losses.WeightedSigmoidClassificationLoss()
anchorwise_output=True) self._localization_loss = losses.WeightedSmoothL1LocalizationLoss()
self._localization_loss = losses.WeightedSmoothL1LocalizationLoss(
anchorwise_output=True)
def preprocess(self, inputs): def preprocess(self, inputs):
"""Input preprocessing, resizes images to 28x28. """Input preprocessing, resizes images to 28x28.
...@@ -65,14 +63,24 @@ class FakeDetectionModel(model.DetectionModel): ...@@ -65,14 +63,24 @@ class FakeDetectionModel(model.DetectionModel):
Returns: Returns:
preprocessed_inputs: a [batch, 28, 28, channels] float32 tensor. preprocessed_inputs: a [batch, 28, 28, channels] float32 tensor.
true_image_shapes: int32 tensor of shape [batch, 3] where each row is
of the form [height, width, channels] indicating the shapes
of true images in the resized images, as resized images can be padded
with zeros.
""" """
return tf.image.resize_images(inputs, [28, 28]) true_image_shapes = [inputs.shape[:-1].as_list()
for _ in range(inputs.shape[-1])]
return tf.image.resize_images(inputs, [28, 28]), true_image_shapes
def predict(self, preprocessed_inputs): def predict(self, preprocessed_inputs, true_image_shapes):
"""Prediction tensors from inputs tensor. """Prediction tensors from inputs tensor.
Args: Args:
preprocessed_inputs: a [batch, 28, 28, channels] float32 tensor. preprocessed_inputs: a [batch, 28, 28, channels] float32 tensor.
true_image_shapes: int32 tensor of shape [batch, 3] where each row is
of the form [height, width, channels] indicating the shapes
of true images in the resized images, as resized images can be padded
with zeros.
Returns: Returns:
prediction_dict: a dictionary holding prediction tensors to be prediction_dict: a dictionary holding prediction tensors to be
...@@ -89,11 +97,15 @@ class FakeDetectionModel(model.DetectionModel): ...@@ -89,11 +97,15 @@ class FakeDetectionModel(model.DetectionModel):
'box_encodings': tf.reshape(box_prediction, [-1, 1, 4]) 'box_encodings': tf.reshape(box_prediction, [-1, 1, 4])
} }
def postprocess(self, prediction_dict, **params): def postprocess(self, prediction_dict, true_image_shapes, **params):
"""Convert predicted output tensors to final detections. Unused. """Convert predicted output tensors to final detections. Unused.
Args: Args:
prediction_dict: a dictionary holding prediction tensors. prediction_dict: a dictionary holding prediction tensors.
true_image_shapes: int32 tensor of shape [batch, 3] where each row is
of the form [height, width, channels] indicating the shapes
of true images in the resized images, as resized images can be padded
with zeros.
**params: Additional keyword arguments for specific implementations of **params: Additional keyword arguments for specific implementations of
DetectionModel. DetectionModel.
...@@ -107,7 +119,7 @@ class FakeDetectionModel(model.DetectionModel): ...@@ -107,7 +119,7 @@ class FakeDetectionModel(model.DetectionModel):
'num_detections': None 'num_detections': None
} }
def loss(self, prediction_dict): def loss(self, prediction_dict, true_image_shapes):
"""Compute scalar loss tensors with respect to provided groundtruth. """Compute scalar loss tensors with respect to provided groundtruth.
Calling this function requires that groundtruth tensors have been Calling this function requires that groundtruth tensors have been
...@@ -115,6 +127,10 @@ class FakeDetectionModel(model.DetectionModel): ...@@ -115,6 +127,10 @@ class FakeDetectionModel(model.DetectionModel):
Args: Args:
prediction_dict: a dictionary holding predicted tensors prediction_dict: a dictionary holding predicted tensors
true_image_shapes: int32 tensor of shape [batch, 3] where each row is
of the form [height, width, channels] indicating the shapes
of true images in the resized images, as resized images can be padded
with zeros.
Returns: Returns:
a dictionary mapping strings (loss names) to scalar tensors representing a dictionary mapping strings (loss names) to scalar tensors representing
......
...@@ -8,6 +8,12 @@ licenses(["notice"]) ...@@ -8,6 +8,12 @@ licenses(["notice"])
# Apache 2.0 # Apache 2.0
py_library(
name = "test_case",
srcs = ["test_case.py"],
deps = ["//tensorflow"],
)
py_library( py_library(
name = "category_util", name = "category_util",
srcs = ["category_util.py"], srcs = ["category_util.py"],
...@@ -18,12 +24,14 @@ py_library( ...@@ -18,12 +24,14 @@ py_library(
name = "config_util", name = "config_util",
srcs = ["config_util.py"], srcs = ["config_util.py"],
deps = [ deps = [
"//pyglib/logging",
"//tensorflow", "//tensorflow",
"//tensorflow_models/object_detection/protos:eval_py_pb2", "//tensorflow/models/research/object_detection/protos:eval_py_pb2",
"//tensorflow_models/object_detection/protos:input_reader_py_pb2", "//tensorflow/models/research/object_detection/protos:image_resizer_py_pb2",
"//tensorflow_models/object_detection/protos:model_py_pb2", "//tensorflow/models/research/object_detection/protos:input_reader_py_pb2",
"//tensorflow_models/object_detection/protos:pipeline_py_pb2", "//tensorflow/models/research/object_detection/protos:model_py_pb2",
"//tensorflow_models/object_detection/protos:train_py_pb2", "//tensorflow/models/research/object_detection/protos:pipeline_py_pb2",
"//tensorflow/models/research/object_detection/protos:train_py_pb2",
], ],
) )
...@@ -35,13 +43,28 @@ py_library( ...@@ -35,13 +43,28 @@ py_library(
], ],
) )
py_library(
name = "json_utils",
srcs = ["json_utils.py"],
deps = [],
)
py_test(
name = "json_utils_test",
srcs = ["json_utils_test.py"],
deps = [
":json_utils",
"//tensorflow",
],
)
py_library( py_library(
name = "label_map_util", name = "label_map_util",
srcs = ["label_map_util.py"], srcs = ["label_map_util.py"],
deps = [ deps = [
"//third_party/py/google/protobuf", "//google/protobuf",
"//tensorflow", "//tensorflow",
"//tensorflow_models/object_detection/protos:string_int_label_map_py_pb2", "//tensorflow/models/research/object_detection/protos:string_int_label_map_py_pb2",
], ],
) )
...@@ -56,13 +79,22 @@ py_library( ...@@ -56,13 +79,22 @@ py_library(
py_library( py_library(
name = "metrics", name = "metrics",
srcs = ["metrics.py"], srcs = ["metrics.py"],
deps = ["//third_party/py/numpy"], deps = ["//numpy"],
) )
py_library( py_library(
name = "np_box_list", name = "np_box_list",
srcs = ["np_box_list.py"], srcs = ["np_box_list.py"],
deps = ["//tensorflow"], deps = ["//numpy"],
)
py_library(
name = "np_box_mask_list",
srcs = ["np_box_mask_list.py"],
deps = [
":np_box_list",
"//numpy",
],
) )
py_library( py_library(
...@@ -71,7 +103,18 @@ py_library( ...@@ -71,7 +103,18 @@ py_library(
deps = [ deps = [
":np_box_list", ":np_box_list",
":np_box_ops", ":np_box_ops",
"//tensorflow", "//numpy",
],
)
py_library(
name = "np_box_mask_list_ops",
srcs = ["np_box_mask_list_ops.py"],
deps = [
":np_box_list_ops",
":np_box_mask_list",
":np_mask_ops",
"//numpy",
], ],
) )
...@@ -81,6 +124,12 @@ py_library( ...@@ -81,6 +124,12 @@ py_library(
deps = ["//tensorflow"], deps = ["//tensorflow"],
) )
py_library(
name = "np_mask_ops",
srcs = ["np_mask_ops.py"],
deps = ["//numpy"],
)
py_library( py_library(
name = "object_detection_evaluation", name = "object_detection_evaluation",
srcs = ["object_detection_evaluation.py"], srcs = ["object_detection_evaluation.py"],
...@@ -89,7 +138,7 @@ py_library( ...@@ -89,7 +138,7 @@ py_library(
":metrics", ":metrics",
":per_image_evaluation", ":per_image_evaluation",
"//tensorflow", "//tensorflow",
"//tensorflow_models/object_detection/core:standard_fields", "//tensorflow/models/research/object_detection/core:standard_fields",
], ],
) )
...@@ -97,11 +146,12 @@ py_library( ...@@ -97,11 +146,12 @@ py_library(
name = "ops", name = "ops",
srcs = ["ops.py"], srcs = ["ops.py"],
deps = [ deps = [
":shape_utils",
":static_shape", ":static_shape",
"//tensorflow", "//tensorflow",
"//tensorflow_models/object_detection/core:box_list", "//tensorflow/models/research/object_detection/core:box_list",
"//tensorflow_models/object_detection/core:box_list_ops", "//tensorflow/models/research/object_detection/core:box_list_ops",
"//tensorflow_models/object_detection/core:standard_fields", "//tensorflow/models/research/object_detection/core:standard_fields",
], ],
) )
...@@ -111,6 +161,8 @@ py_library( ...@@ -111,6 +161,8 @@ py_library(
deps = [ deps = [
":np_box_list", ":np_box_list",
":np_box_list_ops", ":np_box_list_ops",
":np_box_mask_list",
":np_box_mask_list_ops",
"//tensorflow", "//tensorflow",
], ],
) )
...@@ -118,7 +170,10 @@ py_library( ...@@ -118,7 +170,10 @@ py_library(
py_library( py_library(
name = "shape_utils", name = "shape_utils",
srcs = ["shape_utils.py"], srcs = ["shape_utils.py"],
deps = ["//tensorflow"], deps = [
":static_shape",
"//tensorflow",
],
) )
py_library( py_library(
...@@ -132,12 +187,12 @@ py_library( ...@@ -132,12 +187,12 @@ py_library(
srcs = ["test_utils.py"], srcs = ["test_utils.py"],
deps = [ deps = [
"//tensorflow", "//tensorflow",
"//tensorflow_models/object_detection/core:anchor_generator", "//tensorflow/models/research/object_detection/core:anchor_generator",
"//tensorflow_models/object_detection/core:box_coder", "//tensorflow/models/research/object_detection/core:box_coder",
"//tensorflow_models/object_detection/core:box_list", "//tensorflow/models/research/object_detection/core:box_list",
"//tensorflow_models/object_detection/core:box_predictor", "//tensorflow/models/research/object_detection/core:box_predictor",
"//tensorflow_models/object_detection/core:matcher", "//tensorflow/models/research/object_detection/core:matcher",
"//tensorflow_models/object_detection/utils:shape_utils", "//tensorflow/models/research/object_detection/utils:shape_utils",
], ],
) )
...@@ -153,10 +208,12 @@ py_library( ...@@ -153,10 +208,12 @@ py_library(
name = "visualization_utils", name = "visualization_utils",
srcs = ["visualization_utils.py"], srcs = ["visualization_utils.py"],
deps = [ deps = [
"//third_party/py/PIL:pil", "//PIL:pil",
"//third_party/py/matplotlib", "//Tkinter", # buildcleaner: keep
"//third_party/py/six", "//matplotlib",
"//six",
"//tensorflow", "//tensorflow",
"//tensorflow/models/research/object_detection/core:standard_fields",
], ],
) )
...@@ -174,11 +231,12 @@ py_test( ...@@ -174,11 +231,12 @@ py_test(
srcs = ["config_util_test.py"], srcs = ["config_util_test.py"],
deps = [ deps = [
":config_util", ":config_util",
"//tensorflow:tensorflow_google", "//tensorflow",
"//tensorflow_models/object_detection/protos:input_reader_py_pb2", "//tensorflow/models/research/object_detection/protos:image_resizer_py_pb2",
"//tensorflow_models/object_detection/protos:model_py_pb2", "//tensorflow/models/research/object_detection/protos:input_reader_py_pb2",
"//tensorflow_models/object_detection/protos:pipeline_py_pb2", "//tensorflow/models/research/object_detection/protos:model_py_pb2",
"//tensorflow_models/object_detection/protos:train_py_pb2", "//tensorflow/models/research/object_detection/protos:pipeline_py_pb2",
"//tensorflow/models/research/object_detection/protos:train_py_pb2",
], ],
) )
...@@ -188,6 +246,7 @@ py_test( ...@@ -188,6 +246,7 @@ py_test(
deps = [ deps = [
":dataset_util", ":dataset_util",
"//tensorflow", "//tensorflow",
"//tensorflow/models/research/object_detection/protos:input_reader_py_pb2",
], ],
) )
...@@ -223,6 +282,17 @@ py_test( ...@@ -223,6 +282,17 @@ py_test(
srcs = ["np_box_list_test.py"], srcs = ["np_box_list_test.py"],
deps = [ deps = [
":np_box_list", ":np_box_list",
"//numpy",
"//tensorflow",
],
)
py_test(
name = "np_box_mask_list_test",
srcs = ["np_box_mask_list_test.py"],
deps = [
":np_box_mask_list",
"//numpy",
"//tensorflow", "//tensorflow",
], ],
) )
...@@ -233,6 +303,18 @@ py_test( ...@@ -233,6 +303,18 @@ py_test(
deps = [ deps = [
":np_box_list", ":np_box_list",
":np_box_list_ops", ":np_box_list_ops",
"//numpy",
"//tensorflow",
],
)
py_test(
name = "np_box_mask_list_ops_test",
srcs = ["np_box_mask_list_ops_test.py"],
deps = [
":np_box_mask_list",
":np_box_mask_list_ops",
"//numpy",
"//tensorflow", "//tensorflow",
], ],
) )
...@@ -246,13 +328,22 @@ py_test( ...@@ -246,13 +328,22 @@ py_test(
], ],
) )
py_test(
name = "np_mask_ops_test",
srcs = ["np_mask_ops_test.py"],
deps = [
":np_mask_ops",
"//tensorflow",
],
)
py_test( py_test(
name = "object_detection_evaluation_test", name = "object_detection_evaluation_test",
srcs = ["object_detection_evaluation_test.py"], srcs = ["object_detection_evaluation_test.py"],
deps = [ deps = [
":object_detection_evaluation", ":object_detection_evaluation",
"//tensorflow", "//tensorflow",
"//tensorflow_models/object_detection/core:standard_fields", "//tensorflow/models/research/object_detection/core:standard_fields",
], ],
) )
...@@ -261,8 +352,9 @@ py_test( ...@@ -261,8 +352,9 @@ py_test(
srcs = ["ops_test.py"], srcs = ["ops_test.py"],
deps = [ deps = [
":ops", ":ops",
":test_case",
"//tensorflow", "//tensorflow",
"//tensorflow_models/object_detection/core:standard_fields", "//tensorflow/models/research/object_detection/core:standard_fields",
], ],
) )
...@@ -280,6 +372,7 @@ py_test( ...@@ -280,6 +372,7 @@ py_test(
srcs = ["shape_utils_test.py"], srcs = ["shape_utils_test.py"],
deps = [ deps = [
":shape_utils", ":shape_utils",
"//numpy",
"//tensorflow", "//tensorflow",
], ],
) )
...@@ -315,10 +408,11 @@ py_test( ...@@ -315,10 +408,11 @@ py_test(
name = "visualization_utils_test", name = "visualization_utils_test",
srcs = ["visualization_utils_test.py"], srcs = ["visualization_utils_test.py"],
data = [ data = [
"//tensorflow_models/object_detection/test_images:image1.jpg", "//tensorflow/models/research/object_detection/test_images:image1.jpg",
], ],
deps = [ deps = [
":visualization_utils", ":visualization_utils",
"//third_party/py/PIL:pil", "//pyglib/flags",
"//PIL:pil",
], ],
) )
...@@ -25,6 +25,51 @@ from object_detection.protos import pipeline_pb2 ...@@ -25,6 +25,51 @@ from object_detection.protos import pipeline_pb2
from object_detection.protos import train_pb2 from object_detection.protos import train_pb2
def get_image_resizer_config(model_config):
"""Returns the image resizer config from a model config.
Args:
model_config: A model_pb2.DetectionModel.
Returns:
An image_resizer_pb2.ImageResizer.
Raises:
ValueError: If the model type is not recognized.
"""
meta_architecture = model_config.WhichOneof("model")
if meta_architecture == "faster_rcnn":
return model_config.faster_rcnn.image_resizer
if meta_architecture == "ssd":
return model_config.ssd.image_resizer
raise ValueError("Unknown model type: {}".format(meta_architecture))
def get_spatial_image_size(image_resizer_config):
"""Returns expected spatial size of the output image from a given config.
Args:
image_resizer_config: An image_resizer_pb2.ImageResizer.
Returns:
A list of two integers of the form [height, width]. `height` and `width` are
set -1 if they cannot be determined during graph construction.
Raises:
ValueError: If the model type is not recognized.
"""
if image_resizer_config.HasField("fixed_shape_resizer"):
return [image_resizer_config.fixed_shape_resizer.height,
image_resizer_config.fixed_shape_resizer.width]
if image_resizer_config.HasField("keep_aspect_ratio_resizer"):
if image_resizer_config.keep_aspect_ratio_resizer.pad_to_max_dimension:
return [image_resizer_config.keep_aspect_ratio_resizer.max_dimension] * 2
else:
return [-1, -1]
raise ValueError("Unknown image resizer type.")
def get_configs_from_pipeline_file(pipeline_config_path): def get_configs_from_pipeline_file(pipeline_config_path):
"""Reads configuration from a pipeline_pb2.TrainEvalPipelineConfig. """Reads configuration from a pipeline_pb2.TrainEvalPipelineConfig.
...@@ -228,6 +273,9 @@ def merge_external_params_with_configs(configs, hparams=None, **kwargs): ...@@ -228,6 +273,9 @@ def merge_external_params_with_configs(configs, hparams=None, **kwargs):
if value: if value:
_update_label_map_path(configs, value) _update_label_map_path(configs, value)
tf.logging.info("Overwriting label map path: %s", value) tf.logging.info("Overwriting label map path: %s", value)
if key == "mask_type":
_update_mask_type(configs, value)
tf.logging.info("Overwritten mask type: %s", value)
return configs return configs
...@@ -450,3 +498,18 @@ def _update_label_map_path(configs, label_map_path): ...@@ -450,3 +498,18 @@ def _update_label_map_path(configs, label_map_path):
""" """
configs["train_input_config"].label_map_path = label_map_path configs["train_input_config"].label_map_path = label_map_path
configs["eval_input_config"].label_map_path = label_map_path configs["eval_input_config"].label_map_path = label_map_path
def _update_mask_type(configs, mask_type):
"""Updates the mask type for both train and eval input readers.
The configs dictionary is updated in place, and hence not returned.
Args:
configs: Dictionary of configuration objects. See outputs from
get_configs_from_pipeline_file() or get_configs_from_multiple_files().
mask_type: A string name representing a value of
input_reader_pb2.InstanceMaskType
"""
configs["train_input_config"].mask_type = mask_type
configs["eval_input_config"].mask_type = mask_type
...@@ -16,12 +16,12 @@ ...@@ -16,12 +16,12 @@
import os import os
import google3 import tensorflow as tf
import tensorflow.google as tf
from google.protobuf import text_format from google.protobuf import text_format
from object_detection.protos import eval_pb2 from object_detection.protos import eval_pb2
from object_detection.protos import image_resizer_pb2
from object_detection.protos import input_reader_pb2 from object_detection.protos import input_reader_pb2
from object_detection.protos import model_pb2 from object_detection.protos import model_pb2
from object_detection.protos import pipeline_pb2 from object_detection.protos import pipeline_pb2
...@@ -154,7 +154,7 @@ class ConfigUtilTest(tf.test.TestCase): ...@@ -154,7 +154,7 @@ class ConfigUtilTest(tf.test.TestCase):
"""Asserts successful updating of all learning rate schemes.""" """Asserts successful updating of all learning rate schemes."""
original_learning_rate = 0.7 original_learning_rate = 0.7
learning_rate_scaling = 0.1 learning_rate_scaling = 0.1
hparams = tf.HParams(learning_rate=0.15) hparams = tf.contrib.training.HParams(learning_rate=0.15)
pipeline_config_path = os.path.join(self.get_temp_dir(), "pipeline.config") pipeline_config_path = os.path.join(self.get_temp_dir(), "pipeline.config")
# Constant learning rate. # Constant learning rate.
...@@ -216,7 +216,7 @@ class ConfigUtilTest(tf.test.TestCase): ...@@ -216,7 +216,7 @@ class ConfigUtilTest(tf.test.TestCase):
def testNewBatchSize(self): def testNewBatchSize(self):
"""Tests that batch size is updated appropriately.""" """Tests that batch size is updated appropriately."""
original_batch_size = 2 original_batch_size = 2
hparams = tf.HParams(batch_size=16) hparams = tf.contrib.training.HParams(batch_size=16)
pipeline_config_path = os.path.join(self.get_temp_dir(), "pipeline.config") pipeline_config_path = os.path.join(self.get_temp_dir(), "pipeline.config")
pipeline_config = pipeline_pb2.TrainEvalPipelineConfig() pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
...@@ -231,7 +231,7 @@ class ConfigUtilTest(tf.test.TestCase): ...@@ -231,7 +231,7 @@ class ConfigUtilTest(tf.test.TestCase):
def testNewBatchSizeWithClipping(self): def testNewBatchSizeWithClipping(self):
"""Tests that batch size is clipped to 1 from below.""" """Tests that batch size is clipped to 1 from below."""
original_batch_size = 2 original_batch_size = 2
hparams = tf.HParams(batch_size=0.5) hparams = tf.contrib.training.HParams(batch_size=0.5)
pipeline_config_path = os.path.join(self.get_temp_dir(), "pipeline.config") pipeline_config_path = os.path.join(self.get_temp_dir(), "pipeline.config")
pipeline_config = pipeline_pb2.TrainEvalPipelineConfig() pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
...@@ -246,7 +246,7 @@ class ConfigUtilTest(tf.test.TestCase): ...@@ -246,7 +246,7 @@ class ConfigUtilTest(tf.test.TestCase):
def testNewMomentumOptimizerValue(self): def testNewMomentumOptimizerValue(self):
"""Tests that new momentum value is updated appropriately.""" """Tests that new momentum value is updated appropriately."""
original_momentum_value = 0.4 original_momentum_value = 0.4
hparams = tf.HParams(momentum_optimizer_value=1.1) hparams = tf.contrib.training.HParams(momentum_optimizer_value=1.1)
pipeline_config_path = os.path.join(self.get_temp_dir(), "pipeline.config") pipeline_config_path = os.path.join(self.get_temp_dir(), "pipeline.config")
pipeline_config = pipeline_pb2.TrainEvalPipelineConfig() pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
...@@ -265,7 +265,7 @@ class ConfigUtilTest(tf.test.TestCase): ...@@ -265,7 +265,7 @@ class ConfigUtilTest(tf.test.TestCase):
original_localization_weight = 0.1 original_localization_weight = 0.1
original_classification_weight = 0.2 original_classification_weight = 0.2
new_weight_ratio = 5.0 new_weight_ratio = 5.0
hparams = tf.HParams( hparams = tf.contrib.training.HParams(
classification_localization_weight_ratio=new_weight_ratio) classification_localization_weight_ratio=new_weight_ratio)
pipeline_config_path = os.path.join(self.get_temp_dir(), "pipeline.config") pipeline_config_path = os.path.join(self.get_temp_dir(), "pipeline.config")
...@@ -288,7 +288,8 @@ class ConfigUtilTest(tf.test.TestCase): ...@@ -288,7 +288,8 @@ class ConfigUtilTest(tf.test.TestCase):
original_gamma = 1.0 original_gamma = 1.0
new_alpha = 0.3 new_alpha = 0.3
new_gamma = 2.0 new_gamma = 2.0
hparams = tf.HParams(focal_loss_alpha=new_alpha, focal_loss_gamma=new_gamma) hparams = tf.contrib.training.HParams(
focal_loss_alpha=new_alpha, focal_loss_gamma=new_gamma)
pipeline_config_path = os.path.join(self.get_temp_dir(), "pipeline.config") pipeline_config_path = os.path.join(self.get_temp_dir(), "pipeline.config")
pipeline_config = pipeline_pb2.TrainEvalPipelineConfig() pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
...@@ -396,6 +397,56 @@ class ConfigUtilTest(tf.test.TestCase): ...@@ -396,6 +397,56 @@ class ConfigUtilTest(tf.test.TestCase):
self.assertEqual(new_label_map_path, self.assertEqual(new_label_map_path,
configs["eval_input_config"].label_map_path) configs["eval_input_config"].label_map_path)
def testNewMaskType(self):
"""Tests that mask type can be overwritten in input readers."""
original_mask_type = input_reader_pb2.NUMERICAL_MASKS
new_mask_type = input_reader_pb2.PNG_MASKS
pipeline_config_path = os.path.join(self.get_temp_dir(), "pipeline.config")
pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
train_input_reader = pipeline_config.train_input_reader
train_input_reader.mask_type = original_mask_type
eval_input_reader = pipeline_config.eval_input_reader
eval_input_reader.mask_type = original_mask_type
_write_config(pipeline_config, pipeline_config_path)
configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
configs = config_util.merge_external_params_with_configs(
configs, mask_type=new_mask_type)
self.assertEqual(new_mask_type, configs["train_input_config"].mask_type)
self.assertEqual(new_mask_type, configs["eval_input_config"].mask_type)
def test_get_image_resizer_config(self):
"""Tests that number of classes can be retrieved."""
model_config = model_pb2.DetectionModel()
model_config.faster_rcnn.image_resizer.fixed_shape_resizer.height = 100
model_config.faster_rcnn.image_resizer.fixed_shape_resizer.width = 300
image_resizer_config = config_util.get_image_resizer_config(model_config)
self.assertEqual(image_resizer_config.fixed_shape_resizer.height, 100)
self.assertEqual(image_resizer_config.fixed_shape_resizer.width, 300)
def test_get_spatial_image_size_from_fixed_shape_resizer_config(self):
image_resizer_config = image_resizer_pb2.ImageResizer()
image_resizer_config.fixed_shape_resizer.height = 100
image_resizer_config.fixed_shape_resizer.width = 200
image_shape = config_util.get_spatial_image_size(image_resizer_config)
self.assertAllEqual(image_shape, [100, 200])
def test_get_spatial_image_size_from_aspect_preserving_resizer_config(self):
image_resizer_config = image_resizer_pb2.ImageResizer()
image_resizer_config.keep_aspect_ratio_resizer.min_dimension = 100
image_resizer_config.keep_aspect_ratio_resizer.max_dimension = 600
image_resizer_config.keep_aspect_ratio_resizer.pad_to_max_dimension = True
image_shape = config_util.get_spatial_image_size(image_resizer_config)
self.assertAllEqual(image_shape, [600, 600])
def test_get_spatial_image_size_from_aspect_preserving_resizer_dynamic(self):
image_resizer_config = image_resizer_pb2.ImageResizer()
image_resizer_config.keep_aspect_ratio_resizer.min_dimension = 100
image_resizer_config.keep_aspect_ratio_resizer.max_dimension = 600
image_shape = config_util.get_spatial_image_size(image_resizer_config)
self.assertAllEqual(image_shape, [-1, -1])
if __name__ == "__main__": if __name__ == "__main__":
tf.test.main() tf.test.main()
...@@ -84,3 +84,64 @@ def recursive_parse_xml_to_dict(xml): ...@@ -84,3 +84,64 @@ def recursive_parse_xml_to_dict(xml):
result[child.tag] = [] result[child.tag] = []
result[child.tag].append(child_result[child.tag]) result[child.tag].append(child_result[child.tag])
return {xml.tag: result} return {xml.tag: result}
def make_initializable_iterator(dataset):
"""Creates an iterator, and initializes tables.
This is useful in cases where make_one_shot_iterator wouldn't work because
the graph contains a hash table that needs to be initialized.
Args:
dataset: A `tf.data.Dataset` object.
Returns:
A `tf.data.Iterator`.
"""
iterator = dataset.make_initializable_iterator()
tf.add_to_collection(tf.GraphKeys.TABLE_INITIALIZERS, iterator.initializer)
return iterator
def read_dataset(
file_read_func, decode_func, input_files, config, num_workers=1,
worker_index=0):
"""Reads a dataset, and handles repetition and shuffling.
Args:
file_read_func: Function to use in tf.data.Dataset.interleave, to read
every individual file into a tf.data.Dataset.
decode_func: Function to apply to all records.
input_files: A list of file paths to read.
config: A input_reader_builder.InputReader object.
num_workers: Number of workers / shards.
worker_index: Id for the current worker.
Returns:
A tf.data.Dataset based on config.
"""
# Shard, shuffle, and read files.
filenames = tf.concat([tf.matching_files(pattern) for pattern in input_files],
0)
dataset = tf.data.Dataset.from_tensor_slices(filenames)
dataset = dataset.shard(num_workers, worker_index)
dataset = dataset.repeat(config.num_epochs or None)
if config.shuffle:
dataset = dataset.shuffle(config.filenames_shuffle_buffer_size,
reshuffle_each_iteration=True)
# Read file records and shuffle them.
# If cycle_length is larger than the number of files, more than one reader
# will be assigned to the same file, leading to repetition.
cycle_length = tf.cast(
tf.minimum(config.num_readers, tf.size(filenames)), tf.int64)
# TODO: find the optimal block_length.
dataset = dataset.interleave(
file_read_func, cycle_length=cycle_length, block_length=1)
if config.shuffle:
dataset = dataset.shuffle(config.shuffle_buffer_size,
reshuffle_each_iteration=True)
dataset = dataset.map(decode_func, num_parallel_calls=config.num_readers)
return dataset.prefetch(config.prefetch_buffer_size)
...@@ -18,11 +18,29 @@ ...@@ -18,11 +18,29 @@
import os import os
import tensorflow as tf import tensorflow as tf
from object_detection.protos import input_reader_pb2
from object_detection.utils import dataset_util from object_detection.utils import dataset_util
class DatasetUtilTest(tf.test.TestCase): class DatasetUtilTest(tf.test.TestCase):
def setUp(self):
self._path_template = os.path.join(self.get_temp_dir(), 'examples_%s.txt')
for i in range(5):
path = self._path_template % i
with tf.gfile.Open(path, 'wb') as f:
f.write('\n'.join([str(i + 1), str((i + 1) * 10)]))
def _get_dataset_next(self, files, config, batch_size):
def decode_func(value):
return [tf.string_to_number(value, out_type=tf.int32)]
dataset = dataset_util.read_dataset(
tf.data.TextLineDataset, decode_func, files, config)
dataset = dataset.batch(batch_size)
return dataset.make_one_shot_iterator().get_next()
def test_read_examples_list(self): def test_read_examples_list(self):
example_list_data = """example1 1\nexample2 2""" example_list_data = """example1 1\nexample2 2"""
example_list_path = os.path.join(self.get_temp_dir(), 'examples.txt') example_list_path = os.path.join(self.get_temp_dir(), 'examples.txt')
...@@ -32,6 +50,47 @@ class DatasetUtilTest(tf.test.TestCase): ...@@ -32,6 +50,47 @@ class DatasetUtilTest(tf.test.TestCase):
examples = dataset_util.read_examples_list(example_list_path) examples = dataset_util.read_examples_list(example_list_path)
self.assertListEqual(['example1', 'example2'], examples) self.assertListEqual(['example1', 'example2'], examples)
def test_make_initializable_iterator_with_hashTable(self):
keys = [1, 0, -1]
dataset = tf.data.Dataset.from_tensor_slices([[1, 2, -1, 5]])
table = tf.contrib.lookup.HashTable(
initializer=tf.contrib.lookup.KeyValueTensorInitializer(
keys=keys,
values=list(reversed(keys))),
default_value=100)
dataset = dataset.map(table.lookup)
data = dataset_util.make_initializable_iterator(dataset).get_next()
init = tf.tables_initializer()
with self.test_session() as sess:
sess.run(init)
self.assertAllEqual(sess.run(data), [-1, 100, 1, 100])
def test_read_dataset(self):
config = input_reader_pb2.InputReader()
config.num_readers = 1
config.shuffle = False
data = self._get_dataset_next([self._path_template % '*'], config,
batch_size=20)
with self.test_session() as sess:
self.assertAllEqual(sess.run(data),
[[1, 10, 2, 20, 3, 30, 4, 40, 5, 50, 1, 10, 2, 20, 3,
30, 4, 40, 5, 50]])
def test_read_dataset_single_epoch(self):
config = input_reader_pb2.InputReader()
config.num_epochs = 1
config.num_readers = 1
config.shuffle = False
data = self._get_dataset_next([self._path_template % '0'], config,
batch_size=30)
with self.test_session() as sess:
# First batch will retrieve as much as it can, second batch will fail.
self.assertAllEqual(sess.run(data), [[1, 10]])
self.assertRaises(tf.errors.OutOfRangeError, sess.run, data)
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()
"""Utilities for dealing with writing json strings.
json_utils wraps json.dump and json.dumps so that they can be used to safely
control the precision of floats when writing to json strings or files.
"""
import json
from json import encoder
def Dump(obj, fid, float_digits=-1, **params):
"""Wrapper of json.dump that allows specifying the float precision used.
Args:
obj: The object to dump.
fid: The file id to write to.
float_digits: The number of digits of precision when writing floats out.
**params: Additional parameters to pass to json.dumps.
"""
original_encoder = encoder.FLOAT_REPR
if float_digits >= 0:
encoder.FLOAT_REPR = lambda o: format(o, '.%df' % float_digits)
try:
json.dump(obj, fid, **params)
finally:
encoder.FLOAT_REPR = original_encoder
def Dumps(obj, float_digits=-1, **params):
"""Wrapper of json.dumps that allows specifying the float precision used.
Args:
obj: The object to dump.
float_digits: The number of digits of precision when writing floats out.
**params: Additional parameters to pass to json.dumps.
Returns:
output: JSON string representation of obj.
"""
original_encoder = encoder.FLOAT_REPR
original_c_make_encoder = encoder.c_make_encoder
if float_digits >= 0:
encoder.FLOAT_REPR = lambda o: format(o, '.%df' % float_digits)
encoder.c_make_encoder = None
try:
output = json.dumps(obj, **params)
finally:
encoder.FLOAT_REPR = original_encoder
encoder.c_make_encoder = original_c_make_encoder
return output
def PrettyParams(**params):
"""Returns parameters for use with Dump and Dumps to output pretty json.
Example usage:
```json_str = json_utils.Dumps(obj, **json_utils.PrettyParams())```
```json_str = json_utils.Dumps(
obj, **json_utils.PrettyParams(allow_nans=False))```
Args:
**params: Additional params to pass to json.dump or json.dumps.
Returns:
params: Parameters that are compatible with json_utils.Dump and
json_utils.Dumps.
"""
params['float_digits'] = 4
params['sort_keys'] = True
params['indent'] = 2
params['separators'] = (',', ': ')
return params
"""Tests for google3.image.understanding.object_detection.utils.json_utils."""
import os
import tensorflow as tf
from object_detection.utils import json_utils
class JsonUtilsTest(tf.test.TestCase):
def testDumpReasonablePrecision(self):
output_path = os.path.join(tf.test.get_temp_dir(), 'test.json')
with tf.gfile.GFile(output_path, 'w') as f:
json_utils.Dump(1.0, f, float_digits=2)
with tf.gfile.GFile(output_path, 'r') as f:
self.assertEqual(f.read(), '1.00')
def testDumpPassExtraParams(self):
output_path = os.path.join(tf.test.get_temp_dir(), 'test.json')
with tf.gfile.GFile(output_path, 'w') as f:
json_utils.Dump([1.0], f, float_digits=2, indent=3)
with tf.gfile.GFile(output_path, 'r') as f:
self.assertEqual(f.read(), '[\n 1.00\n]')
def testDumpZeroPrecision(self):
output_path = os.path.join(tf.test.get_temp_dir(), 'test.json')
with tf.gfile.GFile(output_path, 'w') as f:
json_utils.Dump(1.0, f, float_digits=0, indent=3)
with tf.gfile.GFile(output_path, 'r') as f:
self.assertEqual(f.read(), '1')
def testDumpUnspecifiedPrecision(self):
output_path = os.path.join(tf.test.get_temp_dir(), 'test.json')
with tf.gfile.GFile(output_path, 'w') as f:
json_utils.Dump(1.012345, f)
with tf.gfile.GFile(output_path, 'r') as f:
self.assertEqual(f.read(), '1.012345')
def testDumpsReasonablePrecision(self):
s = json_utils.Dumps(1.0, float_digits=2)
self.assertEqual(s, '1.00')
def testDumpsPassExtraParams(self):
s = json_utils.Dumps([1.0], float_digits=2, indent=3)
self.assertEqual(s, '[\n 1.00\n]')
def testDumpsZeroPrecision(self):
s = json_utils.Dumps(1.0, float_digits=0)
self.assertEqual(s, '1')
def testDumpsUnspecifiedPrecision(self):
s = json_utils.Dumps(1.012345)
self.assertEqual(s, '1.012345')
def testPrettyParams(self):
s = json_utils.Dumps({'v': 1.012345, 'n': 2}, **json_utils.PrettyParams())
self.assertEqual(s, '{\n "n": 2,\n "v": 1.0123\n}')
def testPrettyParamsExtraParamsInside(self):
s = json_utils.Dumps(
{'v': 1.012345,
'n': float('nan')}, **json_utils.PrettyParams(allow_nan=True))
self.assertEqual(s, '{\n "n": NaN,\n "v": 1.0123\n}')
with self.assertRaises(ValueError):
s = json_utils.Dumps(
{'v': 1.012345,
'n': float('nan')}, **json_utils.PrettyParams(allow_nan=False))
def testPrettyParamsExtraParamsOutside(self):
s = json_utils.Dumps(
{'v': 1.012345,
'n': float('nan')}, allow_nan=True, **json_utils.PrettyParams())
self.assertEqual(s, '{\n "n": NaN,\n "v": 1.0123\n}')
with self.assertRaises(ValueError):
s = json_utils.Dumps(
{'v': 1.012345,
'n': float('nan')}, allow_nan=False, **json_utils.PrettyParams())
# Run all tests in this file when executed directly.
if __name__ == '__main__':
  tf.test.main()
...@@ -55,6 +55,18 @@ def create_category_index(categories): ...@@ -55,6 +55,18 @@ def create_category_index(categories):
return category_index return category_index
def get_max_label_map_index(label_map):
"""Get maximum index in label map.
Args:
label_map: a StringIntLabelMapProto
Returns:
an integer
"""
return max([item.id for item in label_map.item])
def convert_label_map_to_categories(label_map, def convert_label_map_to_categories(label_map,
max_num_classes, max_num_classes,
use_display_name=True): use_display_name=True):
......
...@@ -170,6 +170,12 @@ class LabelMapUtilTest(tf.test.TestCase): ...@@ -170,6 +170,12 @@ class LabelMapUtilTest(tf.test.TestCase):
}] }]
self.assertListEqual(expected_categories_list, cat_no_offset) self.assertListEqual(expected_categories_list, cat_no_offset)
def test_get_max_label_map_index(self):
num_classes = 4
label_map_proto = self._generate_label_map(num_classes=num_classes)
max_index = label_map_util.get_max_label_map_index(label_map_proto)
self.assertEqual(num_classes, max_index)
def test_create_category_index(self): def test_create_category_index(self):
categories = [{'name': u'1', 'id': 1}, {'name': u'2', 'id': 2}] categories = [{'name': u'1', 'id': 1}, {'name': u'2', 'id': 2}]
category_index = label_map_util.create_category_index(categories) category_index = label_map_util.create_category_index(categories)
......
...@@ -94,9 +94,8 @@ def cosine_decay_with_warmup(global_step, ...@@ -94,9 +94,8 @@ def cosine_decay_with_warmup(global_step,
raise ValueError('total_steps must be larger or equal to ' raise ValueError('total_steps must be larger or equal to '
'warmup_steps.') 'warmup_steps.')
learning_rate = 0.5 * learning_rate_base * ( learning_rate = 0.5 * learning_rate_base * (
1 + tf.cos(np.pi * tf.cast( 1 + tf.cos(np.pi * (tf.cast(global_step, tf.float32) - warmup_steps
global_step - warmup_steps, tf.float32 ) / float(total_steps - warmup_steps)))
) / float(total_steps - warmup_steps)))
if warmup_steps > 0: if warmup_steps > 0:
slope = (learning_rate_base - warmup_learning_rate) / warmup_steps slope = (learning_rate_base - warmup_learning_rate) / warmup_steps
pre_cosine_learning_rate = slope * tf.cast( pre_cosine_learning_rate = slope * tf.cast(
......
...@@ -21,7 +21,6 @@ Example box operations that are supported: ...@@ -21,7 +21,6 @@ Example box operations that are supported:
""" """
import numpy as np import numpy as np
from six.moves import xrange
from object_detection.utils import np_box_list from object_detection.utils import np_box_list
from object_detection.utils import np_box_ops from object_detection.utils import np_box_ops
...@@ -97,7 +96,7 @@ def ioa(boxlist1, boxlist2): ...@@ -97,7 +96,7 @@ def ioa(boxlist1, boxlist2):
def gather(boxlist, indices, fields=None): def gather(boxlist, indices, fields=None):
"""Gather boxes from BoxList according to indices and return new BoxList. """Gather boxes from BoxList according to indices and return new BoxList.
By default, Gather returns boxes corresponding to the input index list, as By default, gather returns boxes corresponding to the input index list, as
well as all additional fields stored in the boxlist (indexing into the well as all additional fields stored in the boxlist (indexing into the
first dimension). However one can optionally only gather from a first dimension). However one can optionally only gather from a
subset of fields. subset of fields.
......
...@@ -100,16 +100,16 @@ class AddExtraFieldTest(tf.test.TestCase): ...@@ -100,16 +100,16 @@ class AddExtraFieldTest(tf.test.TestCase):
def test_get_extra_fields(self): def test_get_extra_fields(self):
boxlist = self.boxlist boxlist = self.boxlist
self.assertSameElements(boxlist.get_extra_fields(), []) self.assertItemsEqual(boxlist.get_extra_fields(), [])
scores = np.array([0.5, 0.7, 0.9], dtype=float) scores = np.array([0.5, 0.7, 0.9], dtype=float)
boxlist.add_field('scores', scores) boxlist.add_field('scores', scores)
self.assertSameElements(boxlist.get_extra_fields(), ['scores']) self.assertItemsEqual(boxlist.get_extra_fields(), ['scores'])
labels = np.array([[0, 0, 0, 1, 0], [0, 1, 0, 0, 0], [0, 0, 0, 0, 1]], labels = np.array([[0, 0, 0, 1, 0], [0, 1, 0, 0, 0], [0, 0, 0, 0, 1]],
dtype=int) dtype=int)
boxlist.add_field('labels', labels) boxlist.add_field('labels', labels)
self.assertSameElements(boxlist.get_extra_fields(), ['scores', 'labels']) self.assertItemsEqual(boxlist.get_extra_fields(), ['scores', 'labels'])
def test_get_coordinates(self): def test_get_coordinates(self):
y_min, x_min, y_max, x_max = self.boxlist.get_coordinates() y_min, x_min, y_max, x_max = self.boxlist.get_coordinates()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment