Commit 8a72df2d authored by Vivek Rathod

* Change evaluator.py and eval_util.py to use the new eval
  interface defined in utils/object_detection_evaluation.py.
* Update eval.py to use routines from utils/config_util.py
  to parse config files.
parent a3c7d7e8
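The eval.py change below replaces the local `get_configs_from_pipeline_file` and
`get_configs_from_multiple_files` helpers with calls into utils/config_util.py.
A minimal sketch of the new flow, assuming config_util returns a dict keyed by
'model', 'eval_config', and 'eval_input_config' as the diff suggests (the config
path used here is hypothetical):

```python
from object_detection.utils import config_util

# Sketch only: mirrors the new eval.py flow shown in the diff below.
configs = config_util.get_configs_from_pipeline_file(
    '/tmp/pipeline.config')                      # hypothetical config path
model_config = configs['model']                  # model_pb2.DetectionModel
eval_config = configs['eval_config']             # eval_pb2.EvalConfig
input_config = configs['eval_input_config']      # input_reader_pb2.InputReader
```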
@@ -57,8 +57,13 @@ py_library(
     ],
     deps = [
         "//tensorflow",
+        "//tensorflow_models/object_detection/core:box_list",
+        "//tensorflow_models/object_detection/core:box_list_ops",
+        "//tensorflow_models/object_detection/core:keypoint_ops",
+        "//tensorflow_models/object_detection/core:standard_fields",
         "//tensorflow_models/object_detection/utils:label_map_util",
         "//tensorflow_models/object_detection/utils:object_detection_evaluation",
+        "//tensorflow_models/object_detection/utils:ops",
         "//tensorflow_models/object_detection/utils:visualization_utils",
     ],
 )
@@ -69,11 +74,10 @@ py_library(
     deps = [
         "//tensorflow",
         "//tensorflow_models/object_detection:eval_util",
-        "//tensorflow_models/object_detection/core:box_list",
-        "//tensorflow_models/object_detection/core:box_list_ops",
         "//tensorflow_models/object_detection/core:prefetcher",
         "//tensorflow_models/object_detection/core:standard_fields",
         "//tensorflow_models/object_detection/protos:eval_py_pb2",
+        "//tensorflow_models/object_detection/utils:object_detection_evaluation",
     ],
 )
@@ -87,10 +91,7 @@ py_binary(
         "//tensorflow",
         "//tensorflow_models/object_detection/builders:input_reader_builder",
         "//tensorflow_models/object_detection/builders:model_builder",
-        "//tensorflow_models/object_detection/protos:eval_py_pb2",
-        "//tensorflow_models/object_detection/protos:input_reader_py_pb2",
-        "//tensorflow_models/object_detection/protos:model_py_pb2",
-        "//tensorflow_models/object_detection/protos:pipeline_py_pb2",
+        "//tensorflow_models/object_detection/utils:config_util",
         "//tensorflow_models/object_detection/utils:label_map_util",
     ],
 )
@@ -44,18 +44,16 @@ Example usage:
     --input_config_path=eval_input_config.pbtxt
 """
 import functools
+import os
 import tensorflow as tf
 
-from google.protobuf import text_format
 from object_detection import evaluator
 from object_detection.builders import input_reader_builder
 from object_detection.builders import model_builder
-from object_detection.protos import eval_pb2
-from object_detection.protos import input_reader_pb2
-from object_detection.protos import model_pb2
-from object_detection.protos import pipeline_pb2
+from object_detection.utils import config_util
 from object_detection.utils import label_map_util
 
 tf.logging.set_verbosity(tf.logging.INFO)
 
 flags = tf.app.flags
@@ -75,69 +73,37 @@ flags.DEFINE_string('input_config_path', '',
                     'Path to an input_reader_pb2.InputReader config file.')
 flags.DEFINE_string('model_config_path', '',
                     'Path to a model_pb2.DetectionModel config file.')
+flags.DEFINE_boolean('run_once', False, 'Option to only run a single pass of '
+                     'evaluation. Overrides the `max_evals` parameter in the '
+                     'provided config.')
 FLAGS = flags.FLAGS
 
 
-def get_configs_from_pipeline_file():
-  """Reads evaluation configuration from a pipeline_pb2.TrainEvalPipelineConfig.
-
-  Reads evaluation config from file specified by pipeline_config_path flag.
-
-  Returns:
-    model_config: a model_pb2.DetectionModel
-    eval_config: a eval_pb2.EvalConfig
-    input_config: a input_reader_pb2.InputReader
-  """
-  pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
-  with tf.gfile.GFile(FLAGS.pipeline_config_path, 'r') as f:
-    text_format.Merge(f.read(), pipeline_config)
-
-  model_config = pipeline_config.model
-  if FLAGS.eval_training_data:
-    eval_config = pipeline_config.train_config
-  else:
-    eval_config = pipeline_config.eval_config
-  input_config = pipeline_config.eval_input_reader
-
-  return model_config, eval_config, input_config
-
-
-def get_configs_from_multiple_files():
-  """Reads evaluation configuration from multiple config files.
-
-  Reads the evaluation config from the following files:
-    model_config: Read from --model_config_path
-    eval_config: Read from --eval_config_path
-    input_config: Read from --input_config_path
-
-  Returns:
-    model_config: a model_pb2.DetectionModel
-    eval_config: a eval_pb2.EvalConfig
-    input_config: a input_reader_pb2.InputReader
-  """
-  eval_config = eval_pb2.EvalConfig()
-  with tf.gfile.GFile(FLAGS.eval_config_path, 'r') as f:
-    text_format.Merge(f.read(), eval_config)
-
-  model_config = model_pb2.DetectionModel()
-  with tf.gfile.GFile(FLAGS.model_config_path, 'r') as f:
-    text_format.Merge(f.read(), model_config)
-
-  input_config = input_reader_pb2.InputReader()
-  with tf.gfile.GFile(FLAGS.input_config_path, 'r') as f:
-    text_format.Merge(f.read(), input_config)
-
-  return model_config, eval_config, input_config
-
-
 def main(unused_argv):
   assert FLAGS.checkpoint_dir, '`checkpoint_dir` is missing.'
   assert FLAGS.eval_dir, '`eval_dir` is missing.'
+  tf.gfile.MakeDirs(FLAGS.eval_dir)
   if FLAGS.pipeline_config_path:
-    model_config, eval_config, input_config = get_configs_from_pipeline_file()
+    configs = config_util.get_configs_from_pipeline_file(
+        FLAGS.pipeline_config_path)
+    tf.gfile.Copy(FLAGS.pipeline_config_path,
+                  os.path.join(FLAGS.eval_dir, 'pipeline.config'),
+                  overwrite=True)
   else:
-    model_config, eval_config, input_config = get_configs_from_multiple_files()
+    configs = config_util.get_configs_from_multiple_files(
+        model_config_path=FLAGS.model_config_path,
+        eval_config_path=FLAGS.eval_config_path,
+        eval_input_config_path=FLAGS.input_config_path)
+    for name, config in [('model.config', FLAGS.model_config_path),
+                         ('eval.config', FLAGS.eval_config_path),
+                         ('input.config', FLAGS.input_config_path)]:
+      tf.gfile.Copy(config,
+                    os.path.join(FLAGS.eval_dir, name),
+                    overwrite=True)
+
+  model_config = configs['model']
+  eval_config = configs['eval_config']
+  input_config = configs['eval_input_config']
 
   model_fn = functools.partial(
       model_builder.build,
@@ -153,6 +119,9 @@ def main(unused_argv):
   categories = label_map_util.convert_label_map_to_categories(
       label_map, max_num_classes)
 
+  if FLAGS.run_once:
+    eval_config.max_evals = 1
+
   evaluator.evaluate(create_input_dict_fn, model_fn, eval_config, categories,
                      FLAGS.checkpoint_dir, FLAGS.eval_dir)
(The diff for one changed file is collapsed and not shown here.)
@@ -12,26 +12,30 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
 
 """Detection model evaluator.
 
 This file provides a generic evaluation method that can be used to evaluate a
 DetectionModel.
 """
 import logging
 import tensorflow as tf
 
 from object_detection import eval_util
-from object_detection.core import box_list
-from object_detection.core import box_list_ops
 from object_detection.core import prefetcher
 from object_detection.core import standard_fields as fields
-from object_detection.utils import ops
+from object_detection.utils import object_detection_evaluation
 
-slim = tf.contrib.slim
-
-EVAL_METRICS_FN_DICT = {
-    'pascal_voc_metrics': eval_util.evaluate_detection_results_pascal_voc
+# A dictionary of metric names to classes that implement the metric. The classes
+# in the dictionary must implement
+# utils.object_detection_evaluation.DetectionEvaluator interface.
+EVAL_METRICS_CLASS_DICT = {
+    'pascal_voc_metrics':
+        object_detection_evaluation.PascalDetectionEvaluator,
+    'weighted_pascal_voc_metrics':
+        object_detection_evaluation.WeightedPascalDetectionEvaluator,
+    'open_images_metrics':
+        object_detection_evaluation.OpenImagesDetectionEvaluator
 }
@@ -56,54 +60,56 @@ def _extract_prediction_tensors(model,
   prediction_dict = model.predict(preprocessed_image)
   detections = model.postprocess(prediction_dict)
 
-  original_image_shape = tf.shape(original_image)
-  absolute_detection_boxlist = box_list_ops.to_absolute_coordinates(
-      box_list.BoxList(tf.squeeze(detections['detection_boxes'], axis=0)),
-      original_image_shape[1], original_image_shape[2])
-  label_id_offset = 1
-  tensor_dict = {
-      'original_image': original_image,
-      'image_id': input_dict[fields.InputDataFields.source_id],
-      'detection_boxes': absolute_detection_boxlist.get(),
-      'detection_scores': tf.squeeze(detections['detection_scores'], axis=0),
-      'detection_classes': (
-          tf.squeeze(detections['detection_classes'], axis=0) +
-          label_id_offset),
-  }
-  if 'detection_masks' in detections:
-    detection_masks = tf.squeeze(detections['detection_masks'],
-                                 axis=0)
-    detection_boxes = tf.squeeze(detections['detection_boxes'],
-                                 axis=0)
-    # TODO: This should be done in model's postprocess function ideally.
-    detection_masks_reframed = ops.reframe_box_masks_to_image_masks(
-        detection_masks,
-        detection_boxes,
-        original_image_shape[1], original_image_shape[2])
-    detection_masks_reframed = tf.to_float(tf.greater(detection_masks_reframed,
-                                                      0.5))
-    tensor_dict['detection_masks'] = detection_masks_reframed
-  # load groundtruth fields into tensor_dict
+  groundtruth = None
   if not ignore_groundtruth:
-    normalized_gt_boxlist = box_list.BoxList(
-        input_dict[fields.InputDataFields.groundtruth_boxes])
-    gt_boxlist = box_list_ops.scale(normalized_gt_boxlist,
-                                    tf.shape(original_image)[1],
-                                    tf.shape(original_image)[2])
-    groundtruth_boxes = gt_boxlist.get()
-    groundtruth_classes = input_dict[fields.InputDataFields.groundtruth_classes]
-    tensor_dict['groundtruth_boxes'] = groundtruth_boxes
-    tensor_dict['groundtruth_classes'] = groundtruth_classes
-    tensor_dict['area'] = input_dict[fields.InputDataFields.groundtruth_area]
-    tensor_dict['is_crowd'] = input_dict[
-        fields.InputDataFields.groundtruth_is_crowd]
-    tensor_dict['difficult'] = input_dict[
-        fields.InputDataFields.groundtruth_difficult]
-    if 'detection_masks' in tensor_dict:
-      tensor_dict['groundtruth_instance_masks'] = input_dict[
-          fields.InputDataFields.groundtruth_instance_masks]
-  return tensor_dict
+    groundtruth = {
+        fields.InputDataFields.groundtruth_boxes:
+            input_dict[fields.InputDataFields.groundtruth_boxes],
+        fields.InputDataFields.groundtruth_classes:
+            input_dict[fields.InputDataFields.groundtruth_classes],
+        fields.InputDataFields.groundtruth_area:
+            input_dict[fields.InputDataFields.groundtruth_area],
+        fields.InputDataFields.groundtruth_is_crowd:
+            input_dict[fields.InputDataFields.groundtruth_is_crowd],
+        fields.InputDataFields.groundtruth_difficult:
+            input_dict[fields.InputDataFields.groundtruth_difficult]
+    }
+    if fields.InputDataFields.groundtruth_group_of in input_dict:
+      groundtruth[fields.InputDataFields.groundtruth_group_of] = (
+          input_dict[fields.InputDataFields.groundtruth_group_of])
+    if fields.DetectionResultFields.detection_masks in detections:
+      groundtruth[fields.InputDataFields.groundtruth_instance_masks] = (
+          input_dict[fields.InputDataFields.groundtruth_instance_masks])
+
+  return eval_util.result_dict_for_single_example(
+      original_image,
+      input_dict[fields.InputDataFields.source_id],
+      detections,
+      groundtruth,
+      class_agnostic=(
+          fields.DetectionResultFields.detection_classes not in detections),
+      scale_to_absolute=True)
+
+
+def get_evaluators(eval_config, categories):
+  """Returns the evaluator class according to eval_config, valid for categories.
+
+  Args:
+    eval_config: evaluation configurations.
+    categories: a list of categories to evaluate.
+
+  Returns:
+    An list of instances of DetectionEvaluator.
+
+  Raises:
+    ValueError: if metric is not in the metric class dictionary.
+  """
+  eval_metric_fn_key = eval_config.metrics_set
+  if eval_metric_fn_key not in EVAL_METRICS_CLASS_DICT:
+    raise ValueError('Metric not found: {}'.format(eval_metric_fn_key))
+  return [
+      EVAL_METRICS_CLASS_DICT[eval_metric_fn_key](
+          categories=categories)
+  ]
 
 
 def evaluate(create_input_dict_fn, create_model_fn, eval_config, categories,
@@ -118,6 +124,10 @@ def evaluate(create_input_dict_fn, create_model_fn, eval_config, categories,
       have an integer 'id' field and string 'name' field.
     checkpoint_dir: directory to load the checkpoints to evaluate from.
     eval_dir: directory to write evaluation metrics summary to.
+
+  Returns:
+    metrics: A dictionary containing metric names and values from the latest
+      run.
   """
 
   model = create_model_fn()
@@ -131,7 +141,7 @@ def evaluate(create_input_dict_fn, create_model_fn, eval_config, categories,
       create_input_dict_fn=create_input_dict_fn,
       ignore_groundtruth=eval_config.ignore_groundtruth)
 
-  def _process_batch(tensor_dict, sess, batch_index, counters, update_op):
+  def _process_batch(tensor_dict, sess, batch_index, counters):
     """Evaluates tensors in tensor_dict, visualizing the first K examples.
 
     This function calls sess.run on tensor_dict, evaluating the original_image
@@ -146,66 +156,57 @@ def evaluate(create_input_dict_fn, create_model_fn, eval_config, categories,
         be updated to keep track of number of successful and failed runs,
         respectively. If these fields are not updated, then the success/skipped
         counter values shown at the end of evaluation will be incorrect.
-      update_op: An update op that has to be run along with output tensors. For
-        example this could be an op to compute statistics for slim metrics.
 
     Returns:
       result_dict: a dictionary of numpy arrays
     """
-    if batch_index >= eval_config.num_visualizations:
-      if 'original_image' in tensor_dict:
-        tensor_dict = {k: v for (k, v) in tensor_dict.items()
-                       if k != 'original_image'}
     try:
-      (result_dict, _) = sess.run([tensor_dict, update_op])
+      result_dict = sess.run(tensor_dict)
       counters['success'] += 1
     except tf.errors.InvalidArgumentError:
       logging.info('Skipping image')
      counters['skipped'] += 1
       return {}
-    global_step = tf.train.global_step(sess, slim.get_global_step())
+    global_step = tf.train.global_step(sess, tf.train.get_global_step())
     if batch_index < eval_config.num_visualizations:
       tag = 'image-{}'.format(batch_index)
       eval_util.visualize_detection_results(
-          result_dict, tag, global_step, categories=categories,
+          result_dict,
+          tag,
+          global_step,
+          categories=categories,
           summary_dir=eval_dir,
           export_dir=eval_config.visualization_export_dir,
           show_groundtruth=eval_config.visualization_export_dir)
     return result_dict
 
-  def _process_aggregated_results(result_lists):
-    eval_metric_fn_key = eval_config.metrics_set
-    if eval_metric_fn_key not in EVAL_METRICS_FN_DICT:
-      raise ValueError('Metric not found: {}'.format(eval_metric_fn_key))
-    return EVAL_METRICS_FN_DICT[eval_metric_fn_key](result_lists,
-                                                    categories=categories)
-
   variables_to_restore = tf.global_variables()
-  global_step = slim.get_or_create_global_step()
+  global_step = tf.train.get_or_create_global_step()
   variables_to_restore.append(global_step)
   if eval_config.use_moving_averages:
     variable_averages = tf.train.ExponentialMovingAverage(0.0)
     variables_to_restore = variable_averages.variables_to_restore()
   saver = tf.train.Saver(variables_to_restore)
 
   def _restore_latest_checkpoint(sess):
     latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
     saver.restore(sess, latest_checkpoint)
 
-  eval_util.repeated_checkpoint_run(
+  metrics = eval_util.repeated_checkpoint_run(
       tensor_dict=tensor_dict,
-      update_op=tf.no_op(),
       summary_dir=eval_dir,
-      aggregated_result_processor=_process_aggregated_results,
+      evaluators=get_evaluators(eval_config, categories),
      batch_processor=_process_batch,
       checkpoint_dirs=[checkpoint_dir],
       variables_to_restore=None,
      restore_fn=_restore_latest_checkpoint,
       num_batches=eval_config.num_examples,
       eval_interval_secs=eval_config.eval_interval_secs,
-      max_number_of_evaluations=(
-          1 if eval_config.ignore_groundtruth else
-          eval_config.max_evals if eval_config.max_evals else
-          None),
+      max_number_of_evaluations=(1 if eval_config.ignore_groundtruth else
+                                 eval_config.max_evals
                                 if eval_config.max_evals else None),
       master=eval_config.eval_master,
       save_graph=eval_config.save_graph,
       save_graph_dir=(eval_dir if eval_config.save_graph else ''))
+
+  return metrics
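For context on the new metric classes referenced above: get_evaluators returns
instances implementing the DetectionEvaluator interface from
utils/object_detection_evaluation.py, whose diff is not shown on this page. A
rough usage sketch, assuming an accumulate-then-evaluate API along these lines
(the method names and sample data are assumptions, not taken from this diff):

```python
from object_detection.utils import object_detection_evaluation

# Sketch only: feed per-image groundtruth and detections, then compute metrics.
evaluator = object_detection_evaluation.PascalDetectionEvaluator(
    categories=[{'id': 1, 'name': 'dog'}])  # hypothetical category list
for image_id in image_ids:                  # hypothetical iterable of image ids
  evaluator.add_single_ground_truth_image_info(image_id, groundtruth[image_id])
  evaluator.add_single_detected_image_info(image_id, detections[image_id])
metrics = evaluator.evaluate()              # dict of metric name -> value
```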