# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

r"""Evaluation executable for detection models.

This executable is used to evaluate DetectionModels. There are two ways of
configuring the eval job.

1) A single pipeline_pb2.TrainEvalPipelineConfig file may be specified.
In this mode, the --eval_training_data flag may be given to force the pipeline
to evaluate on training data instead.

Example usage:
    ./eval \
        --logtostderr \
        --checkpoint_dir=path/to/checkpoint_dir \
        --eval_dir=path/to/eval_dir \
        --pipeline_config_path=pipeline_config.pbtxt

2) Three configuration files may be provided: a model_pb2.DetectionModel
configuration file to define what type of DetectionModel is being evaluated, an
input_reader_pb2.InputReader file to specify what data the model is evaluating
and an eval_pb2.EvalConfig file to configure evaluation parameters.

Example usage:
    ./eval \
        --logtostderr \
        --checkpoint_dir=path/to/checkpoint_dir \
        --eval_dir=path/to/eval_dir \
        --eval_config_path=eval_config.pbtxt \
        --model_config_path=model_config.pbtxt \
        --input_config_path=eval_input_config.pbtxt
"""
import functools
47
import os
48
49
50
import tensorflow as tf

from object_detection import evaluator
51
from object_detection.builders import dataset_builder
52
from object_detection.builders import graph_rewriter_builder
53
from object_detection.builders import model_builder
54
from object_detection.utils import config_util
55
from object_detection.utils import dataset_util
56
57
from object_detection.utils import label_map_util

58

59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
tf.logging.set_verbosity(tf.logging.INFO)

# Command-line flags for the eval job. `checkpoint_dir` and `eval_dir` are
# required by main(). The configuration comes either from a single pipeline
# file (`pipeline_config_path`) or from the three separate model / eval / input
# config files.
flags = tf.app.flags
flags.DEFINE_boolean('eval_training_data', False,
                     'If training data should be evaluated for this job.')
flags.DEFINE_string('checkpoint_dir', '',
                    'Directory containing checkpoints to evaluate, typically '
                    'set to `train_dir` used in the training job.')
flags.DEFINE_string('eval_dir', '',
                    'Directory to write eval summaries to.')
flags.DEFINE_string('pipeline_config_path', '',
                    'Path to a pipeline_pb2.TrainEvalPipelineConfig config '
                    'file. If provided, other configs are ignored')
flags.DEFINE_string('eval_config_path', '',
                    'Path to an eval_pb2.EvalConfig config file.')
flags.DEFINE_string('input_config_path', '',
                    'Path to an input_reader_pb2.InputReader config file.')
flags.DEFINE_string('model_config_path', '',
                    'Path to a model_pb2.DetectionModel config file.')
flags.DEFINE_boolean('run_once', False, 'Option to only run a single pass of '
                     'evaluation. Overrides the `max_evals` parameter in the '
                     'provided config.')
# Parsed flag values; read by main().
FLAGS = flags.FLAGS


def main(unused_argv):
  """Runs evaluation of a detection model as configured by the flags.

  Loads the configuration (from a single pipeline file when
  `--pipeline_config_path` is set, otherwise from the three separate model /
  eval / input config files), copies the config file(s) used into `eval_dir`,
  builds the model and input functions, and hands everything to
  `evaluator.evaluate`.

  Args:
    unused_argv: Positional command-line arguments; ignored.

  Raises:
    ValueError: If `--checkpoint_dir` or `--eval_dir` is empty.
  """
  # Explicit checks rather than `assert`, so that validation still happens when
  # Python runs with -O (which strips assert statements).
  if not FLAGS.checkpoint_dir:
    raise ValueError('`checkpoint_dir` is missing.')
  if not FLAGS.eval_dir:
    raise ValueError('`eval_dir` is missing.')

  tf.gfile.MakeDirs(FLAGS.eval_dir)
  if FLAGS.pipeline_config_path:
    configs = config_util.get_configs_from_pipeline_file(
        FLAGS.pipeline_config_path)
    # Keep a copy of the config next to the eval output.
    tf.gfile.Copy(FLAGS.pipeline_config_path,
                  os.path.join(FLAGS.eval_dir, 'pipeline.config'),
                  overwrite=True)
  else:
    configs = config_util.get_configs_from_multiple_files(
        model_config_path=FLAGS.model_config_path,
        eval_config_path=FLAGS.eval_config_path,
        eval_input_config_path=FLAGS.input_config_path)
    # Keep a copy of each individual config next to the eval output.
    for name, config_path in [('model.config', FLAGS.model_config_path),
                              ('eval.config', FLAGS.eval_config_path),
                              ('input.config', FLAGS.input_config_path)]:
      tf.gfile.Copy(config_path,
                    os.path.join(FLAGS.eval_dir, name),
                    overwrite=True)

  model_config = configs['model']
  eval_config = configs['eval_config']
  input_config = configs['eval_input_config']
  if FLAGS.eval_training_data:
    # NOTE(review): in the three-file mode no train input config path is passed
    # to get_configs_from_multiple_files, so this key may be absent there --
    # confirm against config_util.
    input_config = configs['train_input_config']

  model_fn = functools.partial(
      model_builder.build,
      model_config=model_config,
      is_training=False)

  def get_next(config):
    """Builds the dataset for `config` and returns its iterator's next element."""
    return dataset_util.make_initializable_iterator(
        dataset_builder.build(config)).get_next()

  create_input_dict_fn = functools.partial(get_next, input_config)

  label_map = label_map_util.load_labelmap(input_config.label_map_path)
  # The largest id in the label map is used as the number of classes.
  max_num_classes = max(item.id for item in label_map.item)
  categories = label_map_util.convert_label_map_to_categories(
      label_map, max_num_classes)

  if FLAGS.run_once:
    # --run_once overrides whatever `max_evals` the eval config specified.
    eval_config.max_evals = 1

  graph_rewriter_fn = None
  if 'graph_rewriter_config' in configs:
    graph_rewriter_fn = graph_rewriter_builder.build(
        configs['graph_rewriter_config'], is_training=False)

  evaluator.evaluate(
      create_input_dict_fn,
      model_fn,
      eval_config,
      categories,
      FLAGS.checkpoint_dir,
      FLAGS.eval_dir,
      graph_hook_fn=graph_rewriter_fn)
144
145
146
147


if __name__ == '__main__':
  # Script entry point: hand control to TensorFlow's app runner. It is expected
  # to parse the command-line flags and then call main() -- see the tf.app.run
  # documentation to confirm.
  tf.app.run()