Unverified Commit 3f78f4cf authored by derekjchow's avatar derekjchow Committed by GitHub
Browse files

Merge pull request #3494 from pkulzc/master

Update object detection with internal changes and remove unused BUILD files.
parents 73748d01 0319908c
...@@ -256,7 +256,7 @@ def create_tf_record(output_filename, ...@@ -256,7 +256,7 @@ def create_tf_record(output_filename,
writer.close() writer.close()
# TODO: Add test for pet/PASCAL main files. # TODO(derekjchow): Add test for pet/PASCAL main files.
def main(_): def main(_):
data_dir = FLAGS.data_dir data_dir = FLAGS.data_dir
label_map_dict = label_map_util.get_label_map_dict(FLAGS.label_map_path) label_map_dict = label_map_util.get_label_map_dict(FLAGS.label_map_path)
......
...@@ -50,7 +50,7 @@ def write_metrics(metrics, global_step, summary_dir): ...@@ -50,7 +50,7 @@ def write_metrics(metrics, global_step, summary_dir):
logging.info('Metrics written to tf summary.') logging.info('Metrics written to tf summary.')
# TODO: Add tests. # TODO(rathodv): Add tests.
def visualize_detection_results(result_dict, def visualize_detection_results(result_dict,
tag, tag,
global_step, global_step,
...@@ -289,7 +289,7 @@ def _run_checkpoint_once(tensor_dict, ...@@ -289,7 +289,7 @@ def _run_checkpoint_once(tensor_dict,
for evaluator in evaluators: for evaluator in evaluators:
# TODO(b/65130867): Use image_id tensor once we fix the input data # TODO(b/65130867): Use image_id tensor once we fix the input data
# decoders to return correct image_id. # decoders to return correct image_id.
# TODO: result_dict contains batches of images, while # TODO(akuznetsa): result_dict contains batches of images, while
# add_single_ground_truth_image_info expects a single image. Fix # add_single_ground_truth_image_info expects a single image. Fix
evaluator.add_single_ground_truth_image_info( evaluator.add_single_ground_truth_image_info(
image_id=batch, groundtruth_dict=result_dict) image_id=batch, groundtruth_dict=result_dict)
...@@ -314,7 +314,7 @@ def _run_checkpoint_once(tensor_dict, ...@@ -314,7 +314,7 @@ def _run_checkpoint_once(tensor_dict,
return (global_step, all_evaluator_metrics) return (global_step, all_evaluator_metrics)
# TODO: Add tests. # TODO(rathodv): Add tests.
def repeated_checkpoint_run(tensor_dict, def repeated_checkpoint_run(tensor_dict,
summary_dir, summary_dir,
evaluators, evaluators,
...@@ -487,15 +487,12 @@ def result_dict_for_single_example(image, ...@@ -487,15 +487,12 @@ def result_dict_for_single_example(image,
detection_fields = fields.DetectionResultFields detection_fields = fields.DetectionResultFields
detection_boxes = detections[detection_fields.detection_boxes][0] detection_boxes = detections[detection_fields.detection_boxes][0]
output_dict[detection_fields.detection_boxes] = detection_boxes
image_shape = tf.shape(image) image_shape = tf.shape(image)
if scale_to_absolute: if scale_to_absolute:
absolute_detection_boxlist = box_list_ops.to_absolute_coordinates( absolute_detection_boxlist = box_list_ops.to_absolute_coordinates(
box_list.BoxList(detection_boxes), image_shape[1], image_shape[2]) box_list.BoxList(detection_boxes), image_shape[1], image_shape[2])
output_dict[detection_fields.detection_boxes] = ( detection_boxes = absolute_detection_boxlist.get()
absolute_detection_boxlist.get())
detection_scores = detections[detection_fields.detection_scores][0] detection_scores = detections[detection_fields.detection_scores][0]
output_dict[detection_fields.detection_scores] = detection_scores
if class_agnostic: if class_agnostic:
detection_classes = tf.ones_like(detection_scores, dtype=tf.int64) detection_classes = tf.ones_like(detection_scores, dtype=tf.int64)
...@@ -503,15 +500,22 @@ def result_dict_for_single_example(image, ...@@ -503,15 +500,22 @@ def result_dict_for_single_example(image,
detection_classes = ( detection_classes = (
tf.to_int64(detections[detection_fields.detection_classes][0]) + tf.to_int64(detections[detection_fields.detection_classes][0]) +
label_id_offset) label_id_offset)
num_detections = tf.to_int32(detections[detection_fields.num_detections][0])
detection_boxes = tf.slice(
detection_boxes, begin=[0, 0], size=[num_detections, -1])
detection_classes = tf.slice(
detection_classes, begin=[0], size=[num_detections])
detection_scores = tf.slice(
detection_scores, begin=[0], size=[num_detections])
output_dict[detection_fields.detection_boxes] = detection_boxes
output_dict[detection_fields.detection_classes] = detection_classes output_dict[detection_fields.detection_classes] = detection_classes
output_dict[detection_fields.detection_scores] = detection_scores
if detection_fields.detection_masks in detections: if detection_fields.detection_masks in detections:
detection_masks = detections[detection_fields.detection_masks][0] detection_masks = detections[detection_fields.detection_masks][0]
# TODO: This should be done in model's postprocess # TODO(rathodv): This should be done in model's postprocess
# function ideally. # function ideally.
num_detections = tf.to_int32(detections[detection_fields.num_detections][0])
detection_boxes = tf.slice(
detection_boxes, begin=[0, 0], size=[num_detections, -1])
detection_masks = tf.slice( detection_masks = tf.slice(
detection_masks, begin=[0, 0, 0], size=[num_detections, -1, -1]) detection_masks, begin=[0, 0, 0], size=[num_detections, -1, -1])
detection_masks_reframed = ops.reframe_box_masks_to_image_masks( detection_masks_reframed = ops.reframe_box_masks_to_image_masks(
......
...@@ -128,7 +128,7 @@ def get_evaluators(eval_config, categories): ...@@ -128,7 +128,7 @@ def get_evaluators(eval_config, categories):
def evaluate(create_input_dict_fn, create_model_fn, eval_config, categories, def evaluate(create_input_dict_fn, create_model_fn, eval_config, categories,
checkpoint_dir, eval_dir, graph_hook_fn=None): checkpoint_dir, eval_dir, graph_hook_fn=None, evaluator_list=None):
"""Evaluation function for detection models. """Evaluation function for detection models.
Args: Args:
...@@ -143,6 +143,8 @@ def evaluate(create_input_dict_fn, create_model_fn, eval_config, categories, ...@@ -143,6 +143,8 @@ def evaluate(create_input_dict_fn, create_model_fn, eval_config, categories,
completely built. This is helpful to perform additional changes to the completely built. This is helpful to perform additional changes to the
training graph such as optimizing batchnorm. The function should modify training graph such as optimizing batchnorm. The function should modify
the default graph. the default graph.
evaluator_list: Optional list of instances of DetectionEvaluator. If not
given, this list of metrics is created according to the eval_config.
Returns: Returns:
metrics: A dictionary containing metric names and values from the latest metrics: A dictionary containing metric names and values from the latest
...@@ -222,10 +224,13 @@ def evaluate(create_input_dict_fn, create_model_fn, eval_config, categories, ...@@ -222,10 +224,13 @@ def evaluate(create_input_dict_fn, create_model_fn, eval_config, categories,
latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir) latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
saver.restore(sess, latest_checkpoint) saver.restore(sess, latest_checkpoint)
if not evaluator_list:
evaluator_list = get_evaluators(eval_config, categories)
metrics = eval_util.repeated_checkpoint_run( metrics = eval_util.repeated_checkpoint_run(
tensor_dict=tensor_dict, tensor_dict=tensor_dict,
summary_dir=eval_dir, summary_dir=eval_dir,
evaluators=get_evaluators(eval_config, categories), evaluators=evaluator_list,
batch_processor=_process_batch, batch_processor=_process_batch,
checkpoint_dirs=[checkpoint_dir], checkpoint_dirs=[checkpoint_dir],
variables_to_restore=None, variables_to_restore=None,
......
...@@ -33,7 +33,7 @@ from object_detection.data_decoders import tf_example_decoder ...@@ -33,7 +33,7 @@ from object_detection.data_decoders import tf_example_decoder
slim = tf.contrib.slim slim = tf.contrib.slim
# TODO: Replace with freeze_graph.freeze_graph_with_def_protos when # TODO(derekjchow): Replace with freeze_graph.freeze_graph_with_def_protos when
# newer version of Tensorflow becomes more common. # newer version of Tensorflow becomes more common.
def freeze_graph_with_def_protos( def freeze_graph_with_def_protos(
input_graph_def, input_graph_def,
...@@ -242,7 +242,7 @@ def _add_output_tensor_nodes(postprocessed_tensors, ...@@ -242,7 +242,7 @@ def _add_output_tensor_nodes(postprocessed_tensors,
return outputs return outputs
def _write_frozen_graph(frozen_graph_path, frozen_graph_def): def write_frozen_graph(frozen_graph_path, frozen_graph_def):
"""Writes frozen graph to disk. """Writes frozen graph to disk.
Args: Args:
...@@ -254,10 +254,10 @@ def _write_frozen_graph(frozen_graph_path, frozen_graph_def): ...@@ -254,10 +254,10 @@ def _write_frozen_graph(frozen_graph_path, frozen_graph_def):
logging.info('%d ops in the final graph.', len(frozen_graph_def.node)) logging.info('%d ops in the final graph.', len(frozen_graph_def.node))
def _write_saved_model(saved_model_path, def write_saved_model(saved_model_path,
frozen_graph_def, frozen_graph_def,
inputs, inputs,
outputs): outputs):
"""Writes SavedModel to disk. """Writes SavedModel to disk.
If checkpoint_path is not None bakes the weights into the graph thereby If checkpoint_path is not None bakes the weights into the graph thereby
...@@ -301,10 +301,11 @@ def _write_saved_model(saved_model_path, ...@@ -301,10 +301,11 @@ def _write_saved_model(saved_model_path,
builder.save() builder.save()
def _write_graph_and_checkpoint(inference_graph_def, def write_graph_and_checkpoint(inference_graph_def,
model_path, model_path,
input_saver_def, input_saver_def,
trained_checkpoint_prefix): trained_checkpoint_prefix):
"""Writes the graph and the checkpoint into disk."""
for node in inference_graph_def.node: for node in inference_graph_def.node:
node.device = '' node.device = ''
with tf.Graph().as_default(): with tf.Graph().as_default():
...@@ -316,6 +317,44 @@ def _write_graph_and_checkpoint(inference_graph_def, ...@@ -316,6 +317,44 @@ def _write_graph_and_checkpoint(inference_graph_def,
saver.save(sess, model_path) saver.save(sess, model_path)
def _get_outputs_from_inputs(input_tensors, detection_model,
output_collection_name):
inputs = tf.to_float(input_tensors)
preprocessed_inputs, true_image_shapes = detection_model.preprocess(inputs)
output_tensors = detection_model.predict(
preprocessed_inputs, true_image_shapes)
postprocessed_tensors = detection_model.postprocess(
output_tensors, true_image_shapes)
return _add_output_tensor_nodes(postprocessed_tensors,
output_collection_name)
def _build_detection_graph(input_type, detection_model, input_shape,
output_collection_name, graph_hook_fn):
"""Build the detection graph."""
if input_type not in input_placeholder_fn_map:
raise ValueError('Unknown input type: {}'.format(input_type))
placeholder_args = {}
if input_shape is not None:
if input_type != 'image_tensor':
raise ValueError('Can only specify input shape for `image_tensor` '
'inputs.')
placeholder_args['input_shape'] = input_shape
placeholder_tensor, input_tensors = input_placeholder_fn_map[input_type](
**placeholder_args)
outputs = _get_outputs_from_inputs(
input_tensors=input_tensors,
detection_model=detection_model,
output_collection_name=output_collection_name)
# Add global step to the graph.
slim.get_or_create_global_step()
if graph_hook_fn: graph_hook_fn()
return outputs, placeholder_tensor
def _export_inference_graph(input_type, def _export_inference_graph(input_type,
detection_model, detection_model,
use_moving_averages, use_moving_averages,
...@@ -332,28 +371,12 @@ def _export_inference_graph(input_type, ...@@ -332,28 +371,12 @@ def _export_inference_graph(input_type,
saved_model_path = os.path.join(output_directory, 'saved_model') saved_model_path = os.path.join(output_directory, 'saved_model')
model_path = os.path.join(output_directory, 'model.ckpt') model_path = os.path.join(output_directory, 'model.ckpt')
if input_type not in input_placeholder_fn_map: outputs, placeholder_tensor = _build_detection_graph(
raise ValueError('Unknown input type: {}'.format(input_type)) input_type=input_type,
placeholder_args = {} detection_model=detection_model,
if input_shape is not None: input_shape=input_shape,
if input_type != 'image_tensor': output_collection_name=output_collection_name,
raise ValueError('Can only specify input shape for `image_tensor` ' graph_hook_fn=graph_hook_fn)
'inputs.')
placeholder_args['input_shape'] = input_shape
placeholder_tensor, input_tensors = input_placeholder_fn_map[input_type](
**placeholder_args)
inputs = tf.to_float(input_tensors)
preprocessed_inputs, true_image_shapes = detection_model.preprocess(inputs)
output_tensors = detection_model.predict(
preprocessed_inputs, true_image_shapes)
postprocessed_tensors = detection_model.postprocess(
output_tensors, true_image_shapes)
outputs = _add_output_tensor_nodes(postprocessed_tensors,
output_collection_name)
# Add global step to the graph.
slim.get_or_create_global_step()
if graph_hook_fn: graph_hook_fn()
saver_kwargs = {} saver_kwargs = {}
if use_moving_averages: if use_moving_averages:
...@@ -373,7 +396,7 @@ def _export_inference_graph(input_type, ...@@ -373,7 +396,7 @@ def _export_inference_graph(input_type,
saver = tf.train.Saver(**saver_kwargs) saver = tf.train.Saver(**saver_kwargs)
input_saver_def = saver.as_saver_def() input_saver_def = saver.as_saver_def()
_write_graph_and_checkpoint( write_graph_and_checkpoint(
inference_graph_def=tf.get_default_graph().as_graph_def(), inference_graph_def=tf.get_default_graph().as_graph_def(),
model_path=model_path, model_path=model_path,
input_saver_def=input_saver_def, input_saver_def=input_saver_def,
...@@ -393,9 +416,9 @@ def _export_inference_graph(input_type, ...@@ -393,9 +416,9 @@ def _export_inference_graph(input_type,
filename_tensor_name='save/Const:0', filename_tensor_name='save/Const:0',
clear_devices=True, clear_devices=True,
initializer_nodes='') initializer_nodes='')
_write_frozen_graph(frozen_graph_path, frozen_graph_def) write_frozen_graph(frozen_graph_path, frozen_graph_def)
_write_saved_model(saved_model_path, frozen_graph_def, write_saved_model(saved_model_path, frozen_graph_def,
placeholder_tensor, outputs) placeholder_tensor, outputs)
def export_inference_graph(input_type, def export_inference_graph(input_type,
......
...@@ -497,6 +497,66 @@ class ExportInferenceGraphTest(tf.test.TestCase): ...@@ -497,6 +497,66 @@ class ExportInferenceGraphTest(tf.test.TestCase):
self.assertAllClose(masks_np, np.arange(64).reshape([2, 2, 4, 4])) self.assertAllClose(masks_np, np.arange(64).reshape([2, 2, 4, 4]))
self.assertAllClose(num_detections_np, [2, 1]) self.assertAllClose(num_detections_np, [2, 1])
def test_write_frozen_graph(self):
tmp_dir = self.get_temp_dir()
trained_checkpoint_prefix = os.path.join(tmp_dir, 'model.ckpt')
self._save_checkpoint_from_mock_model(trained_checkpoint_prefix,
use_moving_averages=True)
output_directory = os.path.join(tmp_dir, 'output')
inference_graph_path = os.path.join(output_directory,
'frozen_inference_graph.pb')
tf.gfile.MakeDirs(output_directory)
with mock.patch.object(
model_builder, 'build', autospec=True) as mock_builder:
mock_builder.return_value = FakeModel(add_detection_masks=True)
pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
pipeline_config.eval_config.use_moving_averages = False
detection_model = model_builder.build(pipeline_config.model,
is_training=False)
outputs, _ = exporter._build_detection_graph(
input_type='tf_example',
detection_model=detection_model,
input_shape=None,
output_collection_name='inference_op',
graph_hook_fn=None)
output_node_names = ','.join(outputs.keys())
saver = tf.train.Saver()
input_saver_def = saver.as_saver_def()
frozen_graph_def = exporter.freeze_graph_with_def_protos(
input_graph_def=tf.get_default_graph().as_graph_def(),
input_saver_def=input_saver_def,
input_checkpoint=trained_checkpoint_prefix,
output_node_names=output_node_names,
restore_op_name='save/restore_all',
filename_tensor_name='save/Const:0',
clear_devices=True,
initializer_nodes='')
exporter.write_frozen_graph(inference_graph_path, frozen_graph_def)
inference_graph = self._load_inference_graph(inference_graph_path)
tf_example_np = np.expand_dims(self._create_tf_example(
np.ones((4, 4, 3)).astype(np.uint8)), axis=0)
with self.test_session(graph=inference_graph) as sess:
tf_example = inference_graph.get_tensor_by_name('tf_example:0')
boxes = inference_graph.get_tensor_by_name('detection_boxes:0')
scores = inference_graph.get_tensor_by_name('detection_scores:0')
classes = inference_graph.get_tensor_by_name('detection_classes:0')
masks = inference_graph.get_tensor_by_name('detection_masks:0')
num_detections = inference_graph.get_tensor_by_name('num_detections:0')
(boxes_np, scores_np, classes_np, masks_np, num_detections_np) = sess.run(
[boxes, scores, classes, masks, num_detections],
feed_dict={tf_example: tf_example_np})
self.assertAllClose(boxes_np, [[[0.0, 0.0, 0.5, 0.5],
[0.5, 0.5, 0.8, 0.8]],
[[0.5, 0.5, 1.0, 1.0],
[0.0, 0.0, 0.0, 0.0]]])
self.assertAllClose(scores_np, [[0.7, 0.6],
[0.9, 0.0]])
self.assertAllClose(classes_np, [[1, 2],
[2, 1]])
self.assertAllClose(masks_np, np.arange(64).reshape([2, 2, 4, 4]))
self.assertAllClose(num_detections_np, [2, 1])
def test_export_graph_saves_pipeline_file(self): def test_export_graph_saves_pipeline_file(self):
tmp_dir = self.get_temp_dir() tmp_dir = self.get_temp_dir()
trained_checkpoint_prefix = os.path.join(tmp_dir, 'model.ckpt') trained_checkpoint_prefix = os.path.join(tmp_dir, 'model.ckpt')
...@@ -578,6 +638,82 @@ class ExportInferenceGraphTest(tf.test.TestCase): ...@@ -578,6 +638,82 @@ class ExportInferenceGraphTest(tf.test.TestCase):
self.assertAllClose(masks_np, np.arange(64).reshape([2, 2, 4, 4])) self.assertAllClose(masks_np, np.arange(64).reshape([2, 2, 4, 4]))
self.assertAllClose(num_detections_np, [2, 1]) self.assertAllClose(num_detections_np, [2, 1])
def test_write_saved_model(self):
tmp_dir = self.get_temp_dir()
trained_checkpoint_prefix = os.path.join(tmp_dir, 'model.ckpt')
self._save_checkpoint_from_mock_model(trained_checkpoint_prefix,
use_moving_averages=False)
output_directory = os.path.join(tmp_dir, 'output')
saved_model_path = os.path.join(output_directory, 'saved_model')
tf.gfile.MakeDirs(output_directory)
with mock.patch.object(
model_builder, 'build', autospec=True) as mock_builder:
mock_builder.return_value = FakeModel(add_detection_masks=True)
pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
pipeline_config.eval_config.use_moving_averages = False
detection_model = model_builder.build(pipeline_config.model,
is_training=False)
outputs, placeholder_tensor = exporter._build_detection_graph(
input_type='tf_example',
detection_model=detection_model,
input_shape=None,
output_collection_name='inference_op',
graph_hook_fn=None)
output_node_names = ','.join(outputs.keys())
saver = tf.train.Saver()
input_saver_def = saver.as_saver_def()
frozen_graph_def = exporter.freeze_graph_with_def_protos(
input_graph_def=tf.get_default_graph().as_graph_def(),
input_saver_def=input_saver_def,
input_checkpoint=trained_checkpoint_prefix,
output_node_names=output_node_names,
restore_op_name='save/restore_all',
filename_tensor_name='save/Const:0',
clear_devices=True,
initializer_nodes='')
exporter.write_saved_model(
saved_model_path=saved_model_path,
frozen_graph_def=frozen_graph_def,
inputs=placeholder_tensor,
outputs=outputs)
tf_example_np = np.hstack([self._create_tf_example(
np.ones((4, 4, 3)).astype(np.uint8))] * 2)
with tf.Graph().as_default() as od_graph:
with self.test_session(graph=od_graph) as sess:
meta_graph = tf.saved_model.loader.load(
sess, [tf.saved_model.tag_constants.SERVING], saved_model_path)
signature = meta_graph.signature_def['serving_default']
input_tensor_name = signature.inputs['inputs'].name
tf_example = od_graph.get_tensor_by_name(input_tensor_name)
boxes = od_graph.get_tensor_by_name(
signature.outputs['detection_boxes'].name)
scores = od_graph.get_tensor_by_name(
signature.outputs['detection_scores'].name)
classes = od_graph.get_tensor_by_name(
signature.outputs['detection_classes'].name)
masks = od_graph.get_tensor_by_name(
signature.outputs['detection_masks'].name)
num_detections = od_graph.get_tensor_by_name(
signature.outputs['num_detections'].name)
(boxes_np, scores_np, classes_np, masks_np,
num_detections_np) = sess.run(
[boxes, scores, classes, masks, num_detections],
feed_dict={tf_example: tf_example_np})
self.assertAllClose(boxes_np, [[[0.0, 0.0, 0.5, 0.5],
[0.5, 0.5, 0.8, 0.8]],
[[0.5, 0.5, 1.0, 1.0],
[0.0, 0.0, 0.0, 0.0]]])
self.assertAllClose(scores_np, [[0.7, 0.6],
[0.9, 0.0]])
self.assertAllClose(classes_np, [[1, 2],
[2, 1]])
self.assertAllClose(masks_np, np.arange(64).reshape([2, 2, 4, 4]))
self.assertAllClose(num_detections_np, [2, 1])
def test_export_checkpoint_and_run_inference(self): def test_export_checkpoint_and_run_inference(self):
tmp_dir = self.get_temp_dir() tmp_dir = self.get_temp_dir()
trained_checkpoint_prefix = os.path.join(tmp_dir, 'model.ckpt') trained_checkpoint_prefix = os.path.join(tmp_dir, 'model.ckpt')
...@@ -626,6 +762,64 @@ class ExportInferenceGraphTest(tf.test.TestCase): ...@@ -626,6 +762,64 @@ class ExportInferenceGraphTest(tf.test.TestCase):
self.assertAllClose(masks_np, np.arange(64).reshape([2, 2, 4, 4])) self.assertAllClose(masks_np, np.arange(64).reshape([2, 2, 4, 4]))
self.assertAllClose(num_detections_np, [2, 1]) self.assertAllClose(num_detections_np, [2, 1])
def test_write_graph_and_checkpoint(self):
tmp_dir = self.get_temp_dir()
trained_checkpoint_prefix = os.path.join(tmp_dir, 'model.ckpt')
self._save_checkpoint_from_mock_model(trained_checkpoint_prefix,
use_moving_averages=False)
output_directory = os.path.join(tmp_dir, 'output')
model_path = os.path.join(output_directory, 'model.ckpt')
meta_graph_path = model_path + '.meta'
tf.gfile.MakeDirs(output_directory)
with mock.patch.object(
model_builder, 'build', autospec=True) as mock_builder:
mock_builder.return_value = FakeModel(add_detection_masks=True)
pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
pipeline_config.eval_config.use_moving_averages = False
detection_model = model_builder.build(pipeline_config.model,
is_training=False)
exporter._build_detection_graph(
input_type='tf_example',
detection_model=detection_model,
input_shape=None,
output_collection_name='inference_op',
graph_hook_fn=None)
saver = tf.train.Saver()
input_saver_def = saver.as_saver_def()
exporter.write_graph_and_checkpoint(
inference_graph_def=tf.get_default_graph().as_graph_def(),
model_path=model_path,
input_saver_def=input_saver_def,
trained_checkpoint_prefix=trained_checkpoint_prefix)
tf_example_np = np.hstack([self._create_tf_example(
np.ones((4, 4, 3)).astype(np.uint8))] * 2)
with tf.Graph().as_default() as od_graph:
with self.test_session(graph=od_graph) as sess:
new_saver = tf.train.import_meta_graph(meta_graph_path)
new_saver.restore(sess, model_path)
tf_example = od_graph.get_tensor_by_name('tf_example:0')
boxes = od_graph.get_tensor_by_name('detection_boxes:0')
scores = od_graph.get_tensor_by_name('detection_scores:0')
classes = od_graph.get_tensor_by_name('detection_classes:0')
masks = od_graph.get_tensor_by_name('detection_masks:0')
num_detections = od_graph.get_tensor_by_name('num_detections:0')
(boxes_np, scores_np, classes_np, masks_np,
num_detections_np) = sess.run(
[boxes, scores, classes, masks, num_detections],
feed_dict={tf_example: tf_example_np})
self.assertAllClose(boxes_np, [[[0.0, 0.0, 0.5, 0.5],
[0.5, 0.5, 0.8, 0.8]],
[[0.5, 0.5, 1.0, 1.0],
[0.0, 0.0, 0.0, 0.0]]])
self.assertAllClose(scores_np, [[0.7, 0.6],
[0.9, 0.0]])
self.assertAllClose(classes_np, [[1, 2],
[2, 1]])
self.assertAllClose(masks_np, np.arange(64).reshape([2, 2, 4, 4]))
self.assertAllClose(num_detections_np, [2, 1])
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()
...@@ -53,7 +53,7 @@ recommended. Read [our paper](https://arxiv.org/abs/1611.10012) for a more ...@@ -53,7 +53,7 @@ recommended. Read [our paper](https://arxiv.org/abs/1611.10012) for a more
detailed discussion on the speed vs accuracy tradeoff. detailed discussion on the speed vs accuracy tradeoff.
To help new users get started, sample model configurations have been provided To help new users get started, sample model configurations have been provided
in the object_detection/samples/model_configs folder. The contents of these in the object_detection/samples/configs folder. The contents of these
configuration files can be pasted into `model` field of the skeleton configuration files can be pasted into `model` field of the skeleton
configuration. Users should note that the `num_classes` field should be changed configuration. Users should note that the `num_classes` field should be changed
to a value suited for the dataset the user is training on. to a value suited for the dataset the user is training on.
......
# Frequently Asked Questions
## Q: AttributeError: 'module' object has no attribute 'BackupHandler'
A: This BackupHandler (tf.contrib.slim.tfexample_decoder.BackupHandler) was
introduced in tensorflow 1.5.0 so runing with earlier versions may cause this
issue. It now has been replaced by
object_detection.data_decoders.tf_example_decoder.BackupHandler. Whoever sees
this issue should be able to resolve it by syncing your fork to HEAD.
## Q: Why can't I get the inference time as reported in model zoo?
A: The inference time reported in model zoo is mean time of testing hundreds of
images with a internal machine. As mentioned in
[Tensorflow detection model zoo](detection_model_zoo.md), this speed depends
highly on one's specific hardware configuration and should be treated more as
relative timing.
...@@ -4,15 +4,15 @@ ...@@ -4,15 +4,15 @@
Tensorflow Object Detection API depends on the following libraries: Tensorflow Object Detection API depends on the following libraries:
* Protobuf 2.6 * Protobuf 2.6
* Python-tk * Python-tk
* Pillow 1.0 * Pillow 1.0
* lxml * lxml
* tf Slim (which is included in the "tensorflow/models/research/" checkout) * tf Slim (which is included in the "tensorflow/models/research/" checkout)
* Jupyter notebook * Jupyter notebook
* Matplotlib * Matplotlib
* Tensorflow * Tensorflow
* cocoapi * cocoapi
For detailed steps to install Tensorflow, follow the [Tensorflow installation For detailed steps to install Tensorflow, follow the [Tensorflow installation
instructions](https://www.tensorflow.org/install/). A typical user can install instructions](https://www.tensorflow.org/install/). A typical user can install
......
...@@ -103,7 +103,7 @@ FLAGS = flags.FLAGS ...@@ -103,7 +103,7 @@ FLAGS = flags.FLAGS
def create_tf_example(example): def create_tf_example(example):
# TODO: Populate the following variables from your example. # TODO(user): Populate the following variables from your example.
height = None # Image height height = None # Image height
width = None # Image width width = None # Image width
filename = None # Filename of the image. Empty if image is not from file filename = None # Filename of the image. Empty if image is not from file
...@@ -139,7 +139,7 @@ def create_tf_example(example): ...@@ -139,7 +139,7 @@ def create_tf_example(example):
def main(_): def main(_):
writer = tf.python_io.TFRecordWriter(FLAGS.output_path) writer = tf.python_io.TFRecordWriter(FLAGS.output_path)
# TODO: Write code to read in your dataset to examples variable # TODO(user): Write code to read in your dataset to examples variable
for example in examples: for example in examples:
tf_example = create_tf_example(example) tf_example = create_tf_example(example)
......
# Tensorflow Object Detection API: main runnables.
package(
default_visibility = ["//visibility:public"],
)
licenses(["notice"])
# Apache 2.0
py_library(
name = "detection_inference",
srcs = ["detection_inference.py"],
deps = [
"//tensorflow",
"//tensorflow/models/research/object_detection/core:standard_fields",
],
)
py_test(
name = "detection_inference_test",
srcs = ["detection_inference_test.py"],
deps = [
":detection_inference",
"//PIL:pil",
"//numpy",
"//tensorflow",
"//tensorflow/models/research/object_detection/core:standard_fields",
"//tensorflow/models/research/object_detection/utils:dataset_util",
],
)
py_binary(
name = "infer_detections",
srcs = ["infer_detections.py"],
deps = [
":detection_inference",
"//tensorflow",
],
)
...@@ -17,7 +17,6 @@ r"""Tests for detection_inference.py.""" ...@@ -17,7 +17,6 @@ r"""Tests for detection_inference.py."""
import os import os
import StringIO import StringIO
import numpy as np import numpy as np
from PIL import Image from PIL import Image
import tensorflow as tf import tensorflow as tf
......
...@@ -200,8 +200,8 @@ def create_train_input_fn(train_config, train_input_config, ...@@ -200,8 +200,8 @@ def create_train_input_fn(train_config, train_input_config,
keypoints for each box. keypoints for each box.
Raises: Raises:
TypeError: if the `train_config` or `train_input_config` are not of the TypeError: if the `train_config`, `train_input_config` or `model_config`
correct type. are not of the correct type.
""" """
if not isinstance(train_config, train_pb2.TrainConfig): if not isinstance(train_config, train_pb2.TrainConfig):
raise TypeError('For training mode, the `train_config` must be a ' raise TypeError('For training mode, the `train_config` must be a '
...@@ -316,8 +316,8 @@ def create_eval_input_fn(eval_config, eval_input_config, model_config): ...@@ -316,8 +316,8 @@ def create_eval_input_fn(eval_config, eval_input_config, model_config):
which represent instance masks for objects. which represent instance masks for objects.
Raises: Raises:
TypeError: if the `eval_config` or `eval_input_config` are not of the TypeError: if the `eval_config`, `eval_input_config` or `model_config`
correct type. are not of the correct type.
""" """
del params del params
if not isinstance(eval_config, eval_pb2.EvalConfig): if not isinstance(eval_config, eval_pb2.EvalConfig):
......
...@@ -34,7 +34,6 @@ FLAGS = tf.flags.FLAGS ...@@ -34,7 +34,6 @@ FLAGS = tf.flags.FLAGS
def _get_configs_for_model(model_name): def _get_configs_for_model(model_name):
"""Returns configurations for model.""" """Returns configurations for model."""
# TODO: Make sure these tests work fine outside google3.
fname = os.path.join( fname = os.path.join(
FLAGS.test_srcdir, FLAGS.test_srcdir,
('google3/third_party/tensorflow_models/' ('google3/third_party/tensorflow_models/'
......
# Tensorflow Object Detection API: Matcher implementations.
package(
default_visibility = ["//visibility:public"],
)
licenses(["notice"])
# Apache 2.0
py_library(
name = "argmax_matcher",
srcs = [
"argmax_matcher.py",
],
deps = [
"//tensorflow",
"//tensorflow/models/research/object_detection/core:matcher",
"//tensorflow/models/research/object_detection/utils:shape_utils",
],
)
py_test(
name = "argmax_matcher_test",
srcs = ["argmax_matcher_test.py"],
deps = [
":argmax_matcher",
"//tensorflow",
"//tensorflow/models/research/object_detection/utils:test_case",
],
)
py_library(
name = "bipartite_matcher",
srcs = [
"bipartite_matcher.py",
],
deps = [
"//tensorflow",
"//tensorflow/contrib/image:image_py",
"//tensorflow/models/research/object_detection/core:matcher",
],
)
py_test(
name = "bipartite_matcher_test",
srcs = [
"bipartite_matcher_test.py",
],
deps = [
":bipartite_matcher",
"//tensorflow",
],
)
...@@ -38,7 +38,7 @@ class GreedyBipartiteMatcher(matcher.Matcher): ...@@ -38,7 +38,7 @@ class GreedyBipartiteMatcher(matcher.Matcher):
def _match(self, similarity_matrix, num_valid_rows=-1): def _match(self, similarity_matrix, num_valid_rows=-1):
"""Bipartite matches a collection rows and columns. A greedy bi-partite. """Bipartite matches a collection rows and columns. A greedy bi-partite.
TODO: Add num_valid_columns options to match only that many columns TODO(rathodv): Add num_valid_columns options to match only that many columns
with all the rows. with all the rows.
Args: Args:
......
# Tensorflow Object Detection API: Meta-architectures.

package(default_visibility = ["//visibility:public"])

licenses(["notice"])  # Apache 2.0
# Library target for ssd_meta_arch.py.
# NOTE(review): a non-test library depending on utils:test_case (and
# utils:visualization_utils) is unusual — confirm these deps are intentional.
py_library(
    name = "ssd_meta_arch",
    srcs = ["ssd_meta_arch.py"],
    deps = [
        "//tensorflow",
        "//tensorflow/models/research/object_detection/core:box_list",
        "//tensorflow/models/research/object_detection/core:box_predictor",
        "//tensorflow/models/research/object_detection/core:model",
        "//tensorflow/models/research/object_detection/core:target_assigner",
        "//tensorflow/models/research/object_detection/utils:ops",
        "//tensorflow/models/research/object_detection/utils:shape_utils",
        "//tensorflow/models/research/object_detection/utils:test_case",
        "//tensorflow/models/research/object_detection/utils:visualization_utils",
    ],
)
# Unit test for :ssd_meta_arch; builds a model from core components
# (anchors, losses, post-processing, similarity calculators).
py_test(
    name = "ssd_meta_arch_test",
    srcs = ["ssd_meta_arch_test.py"],
    deps = [
        ":ssd_meta_arch",
        "//tensorflow",
        "//tensorflow/models/research/object_detection/core:anchor_generator",
        "//tensorflow/models/research/object_detection/core:box_list",
        "//tensorflow/models/research/object_detection/core:losses",
        "//tensorflow/models/research/object_detection/core:post_processing",
        "//tensorflow/models/research/object_detection/core:region_similarity_calculator",
        "//tensorflow/models/research/object_detection/utils:test_utils",
    ],
)
# Library target for faster_rcnn_meta_arch.py; aggregates the core
# components (box lists, predictors, losses, target assignment) plus the
# grid anchor generator used by the first stage.
py_library(
    name = "faster_rcnn_meta_arch",
    srcs = [
        "faster_rcnn_meta_arch.py",
    ],
    deps = [
        "//tensorflow",
        "//tensorflow/models/research/object_detection/anchor_generators:grid_anchor_generator",
        "//tensorflow/models/research/object_detection/core:balanced_positive_negative_sampler",
        "//tensorflow/models/research/object_detection/core:box_list",
        "//tensorflow/models/research/object_detection/core:box_list_ops",
        "//tensorflow/models/research/object_detection/core:box_predictor",
        "//tensorflow/models/research/object_detection/core:losses",
        "//tensorflow/models/research/object_detection/core:model",
        "//tensorflow/models/research/object_detection/core:post_processing",
        "//tensorflow/models/research/object_detection/core:standard_fields",
        "//tensorflow/models/research/object_detection/core:target_assigner",
        "//tensorflow/models/research/object_detection/utils:ops",
        "//tensorflow/models/research/object_detection/utils:shape_utils",
    ],
)
# Shared test library for the Faster R-CNN meta-architecture; reused by
# :faster_rcnn_meta_arch_test and :rfcn_meta_arch_test below.
py_library(
    name = "faster_rcnn_meta_arch_test_lib",
    srcs = [
        "faster_rcnn_meta_arch_test_lib.py",
    ],
    deps = [
        ":faster_rcnn_meta_arch",
        "//tensorflow",
        "//tensorflow/models/research/object_detection/anchor_generators:grid_anchor_generator",
        "//tensorflow/models/research/object_detection/builders:box_predictor_builder",
        "//tensorflow/models/research/object_detection/builders:hyperparams_builder",
        "//tensorflow/models/research/object_detection/builders:post_processing_builder",
        "//tensorflow/models/research/object_detection/core:losses",
        "//tensorflow/models/research/object_detection/protos:box_predictor_py_pb2",
        "//tensorflow/models/research/object_detection/protos:hyperparams_py_pb2",
        "//tensorflow/models/research/object_detection/protos:post_processing_py_pb2",
    ],
)
# Unit test for the Faster R-CNN meta-architecture; all real deps come in
# transitively through :faster_rcnn_meta_arch_test_lib.
py_test(
    name = "faster_rcnn_meta_arch_test",
    srcs = ["faster_rcnn_meta_arch_test.py"],
    deps = [
        ":faster_rcnn_meta_arch_test_lib",
    ],
)
# Library target for rfcn_meta_arch.py, which builds on
# :faster_rcnn_meta_arch.
py_library(
    name = "rfcn_meta_arch",
    srcs = ["rfcn_meta_arch.py"],
    deps = [
        ":faster_rcnn_meta_arch",
        "//tensorflow",
        "//tensorflow/models/research/object_detection/core:box_predictor",
        "//tensorflow/models/research/object_detection/utils:ops",
    ],
)
# Unit test for :rfcn_meta_arch; reuses the shared Faster R-CNN test
# library for fixtures.
py_test(
    name = "rfcn_meta_arch_test",
    srcs = ["rfcn_meta_arch_test.py"],
    deps = [
        ":faster_rcnn_meta_arch_test_lib",
        ":rfcn_meta_arch",
        "//tensorflow",
    ],
)
...@@ -365,7 +365,7 @@ class FasterRCNNMetaArch(model.DetectionModel): ...@@ -365,7 +365,7 @@ class FasterRCNNMetaArch(model.DetectionModel):
ValueError: If first_stage_anchor_generator is not of type ValueError: If first_stage_anchor_generator is not of type
grid_anchor_generator.GridAnchorGenerator. grid_anchor_generator.GridAnchorGenerator.
""" """
# TODO: add_summaries is currently unused. Respect that directive # TODO(rathodv): add_summaries is currently unused. Respect that directive
# in the future. # in the future.
super(FasterRCNNMetaArch, self).__init__(num_classes=num_classes) super(FasterRCNNMetaArch, self).__init__(num_classes=num_classes)
...@@ -597,7 +597,7 @@ class FasterRCNNMetaArch(model.DetectionModel): ...@@ -597,7 +597,7 @@ class FasterRCNNMetaArch(model.DetectionModel):
`num_anchors` can differ depending on whether the model is created in `num_anchors` can differ depending on whether the model is created in
training or inference mode. training or inference mode.
(and if number_of_stages=1): (and if number_of_stages > 1):
7) refined_box_encodings: a 3-D tensor with shape 7) refined_box_encodings: a 3-D tensor with shape
[total_num_proposals, num_classes, 4] representing predicted [total_num_proposals, num_classes, 4] representing predicted
(final) refined box encodings, where (final) refined box encodings, where
...@@ -910,8 +910,9 @@ class FasterRCNNMetaArch(model.DetectionModel): ...@@ -910,8 +910,9 @@ class FasterRCNNMetaArch(model.DetectionModel):
preprocessed_inputs, scope=self.first_stage_feature_extractor_scope) preprocessed_inputs, scope=self.first_stage_feature_extractor_scope)
feature_map_shape = tf.shape(rpn_features_to_crop) feature_map_shape = tf.shape(rpn_features_to_crop)
anchors = self._first_stage_anchor_generator.generate( anchors = box_list_ops.concatenate(
[(feature_map_shape[1], feature_map_shape[2])]) self._first_stage_anchor_generator.generate([(feature_map_shape[1],
feature_map_shape[2])]))
with slim.arg_scope(self._first_stage_box_predictor_arg_scope): with slim.arg_scope(self._first_stage_box_predictor_arg_scope):
kernel_size = self._first_stage_box_predictor_kernel_size kernel_size = self._first_stage_box_predictor_kernel_size
rpn_box_predictor_features = slim.conv2d( rpn_box_predictor_features = slim.conv2d(
...@@ -957,9 +958,11 @@ class FasterRCNNMetaArch(model.DetectionModel): ...@@ -957,9 +958,11 @@ class FasterRCNNMetaArch(model.DetectionModel):
num_anchors_per_location, num_anchors_per_location,
scope=self.first_stage_box_predictor_scope) scope=self.first_stage_box_predictor_scope)
box_encodings = box_predictions[box_predictor.BOX_ENCODINGS] box_encodings = tf.concat(
objectness_predictions_with_background = box_predictions[ box_predictions[box_predictor.BOX_ENCODINGS], axis=1)
box_predictor.CLASS_PREDICTIONS_WITH_BACKGROUND] objectness_predictions_with_background = tf.concat(
box_predictions[box_predictor.CLASS_PREDICTIONS_WITH_BACKGROUND],
axis=1)
return (tf.squeeze(box_encodings, axis=2), return (tf.squeeze(box_encodings, axis=2),
objectness_predictions_with_background) objectness_predictions_with_background)
...@@ -1796,7 +1799,7 @@ class FasterRCNNMetaArch(model.DetectionModel): ...@@ -1796,7 +1799,7 @@ class FasterRCNNMetaArch(model.DetectionModel):
# Create a new target assigner that matches the proposals to groundtruth # Create a new target assigner that matches the proposals to groundtruth
# and returns the mask targets. # and returns the mask targets.
# TODO: Move `unmatched_cls_target` from constructor to assign # TODO(rathodv): Move `unmatched_cls_target` from constructor to assign
# function. This will enable reuse of a single target assigner for both # function. This will enable reuse of a single target assigner for both
# class targets and mask targets. # class targets and mask targets.
mask_target_assigner = target_assigner.create_target_assigner( mask_target_assigner = target_assigner.create_target_assigner(
......
...@@ -745,7 +745,7 @@ class FasterRCNNMetaArchTestBase(tf.test.TestCase): ...@@ -745,7 +745,7 @@ class FasterRCNNMetaArchTestBase(tf.test.TestCase):
preprocessed_inputs, _ = model.preprocess(image_placeholder) preprocessed_inputs, _ = model.preprocess(image_placeholder)
self.assertAllEqual(preprocessed_inputs.shape.as_list(), image_shape) self.assertAllEqual(preprocessed_inputs.shape.as_list(), image_shape)
# TODO: Split test into two - with and without masks. # TODO(rathodv): Split test into two - with and without masks.
def test_loss_first_stage_only_mode(self): def test_loss_first_stage_only_mode(self):
model = self._build_model( model = self._build_model(
is_training=True, number_of_stages=1, second_stage_batch_size=6) is_training=True, number_of_stages=1, second_stage_batch_size=6)
...@@ -797,7 +797,7 @@ class FasterRCNNMetaArchTestBase(tf.test.TestCase): ...@@ -797,7 +797,7 @@ class FasterRCNNMetaArchTestBase(tf.test.TestCase):
self.assertTrue('second_stage_localization_loss' not in loss_dict_out) self.assertTrue('second_stage_localization_loss' not in loss_dict_out)
self.assertTrue('second_stage_classification_loss' not in loss_dict_out) self.assertTrue('second_stage_classification_loss' not in loss_dict_out)
# TODO: Split test into two - with and without masks. # TODO(rathodv): Split test into two - with and without masks.
def test_loss_full(self): def test_loss_full(self):
model = self._build_model( model = self._build_model(
is_training=True, number_of_stages=2, second_stage_batch_size=6) is_training=True, number_of_stages=2, second_stage_batch_size=6)
......
...@@ -164,7 +164,7 @@ class RFCNMetaArch(faster_rcnn_meta_arch.FasterRCNNMetaArch): ...@@ -164,7 +164,7 @@ class RFCNMetaArch(faster_rcnn_meta_arch.FasterRCNNMetaArch):
ValueError: If first_stage_anchor_generator is not of type ValueError: If first_stage_anchor_generator is not of type
grid_anchor_generator.GridAnchorGenerator. grid_anchor_generator.GridAnchorGenerator.
""" """
# TODO: add_summaries is currently unused. Respect that directive # TODO(rathodv): add_summaries is currently unused. Respect that directive
# in the future. # in the future.
super(RFCNMetaArch, self).__init__( super(RFCNMetaArch, self).__init__(
is_training, is_training,
...@@ -275,9 +275,11 @@ class RFCNMetaArch(faster_rcnn_meta_arch.FasterRCNNMetaArch): ...@@ -275,9 +275,11 @@ class RFCNMetaArch(faster_rcnn_meta_arch.FasterRCNNMetaArch):
scope=self.second_stage_box_predictor_scope, scope=self.second_stage_box_predictor_scope,
proposal_boxes=proposal_boxes_normalized) proposal_boxes=proposal_boxes_normalized)
refined_box_encodings = tf.squeeze( refined_box_encodings = tf.squeeze(
box_predictions[box_predictor.BOX_ENCODINGS], axis=1) tf.concat(box_predictions[box_predictor.BOX_ENCODINGS], axis=1), axis=1)
class_predictions_with_background = tf.squeeze( class_predictions_with_background = tf.squeeze(
box_predictions[box_predictor.CLASS_PREDICTIONS_WITH_BACKGROUND], tf.concat(
box_predictions[box_predictor.CLASS_PREDICTIONS_WITH_BACKGROUND],
axis=1),
axis=1) axis=1)
absolute_proposal_boxes = ops.normalized_to_image_coordinates( absolute_proposal_boxes = ops.normalized_to_image_coordinates(
......
...@@ -23,6 +23,7 @@ import re ...@@ -23,6 +23,7 @@ import re
import tensorflow as tf import tensorflow as tf
from object_detection.core import box_list from object_detection.core import box_list
from object_detection.core import box_list_ops
from object_detection.core import model from object_detection.core import model
from object_detection.core import standard_fields as fields from object_detection.core import standard_fields as fields
from object_detection.core import target_assigner from object_detection.core import target_assigner
...@@ -122,6 +123,7 @@ class SSDMetaArch(model.DetectionModel): ...@@ -122,6 +123,7 @@ class SSDMetaArch(model.DetectionModel):
matcher, matcher,
region_similarity_calculator, region_similarity_calculator,
encode_background_as_zeros, encode_background_as_zeros,
negative_class_weight,
image_resizer_fn, image_resizer_fn,
non_max_suppression_fn, non_max_suppression_fn,
score_conversion_fn, score_conversion_fn,
...@@ -131,7 +133,8 @@ class SSDMetaArch(model.DetectionModel): ...@@ -131,7 +133,8 @@ class SSDMetaArch(model.DetectionModel):
localization_loss_weight, localization_loss_weight,
normalize_loss_by_num_matches, normalize_loss_by_num_matches,
hard_example_miner, hard_example_miner,
add_summaries=True): add_summaries=True,
normalize_loc_loss_by_codesize=False):
"""SSDMetaArch Constructor. """SSDMetaArch Constructor.
TODO(rathodv,jonathanhuang): group NMS parameters + score converter into TODO(rathodv,jonathanhuang): group NMS parameters + score converter into
...@@ -151,6 +154,7 @@ class SSDMetaArch(model.DetectionModel): ...@@ -151,6 +154,7 @@ class SSDMetaArch(model.DetectionModel):
encode_background_as_zeros: boolean determining whether background encode_background_as_zeros: boolean determining whether background
targets are to be encoded as an all zeros vector or a one-hot targets are to be encoded as an all zeros vector or a one-hot
vector (where background is the 0th class). vector (where background is the 0th class).
negative_class_weight: Weight for confidence loss of negative anchors.
image_resizer_fn: a callable for image resizing. This callable always image_resizer_fn: a callable for image resizing. This callable always
takes a rank-3 image tensor (corresponding to a single image) and takes a rank-3 image tensor (corresponding to a single image) and
returns a rank-3 image tensor, possibly with new spatial dimensions and returns a rank-3 image tensor, possibly with new spatial dimensions and
...@@ -175,6 +179,8 @@ class SSDMetaArch(model.DetectionModel): ...@@ -175,6 +179,8 @@ class SSDMetaArch(model.DetectionModel):
hard_example_miner: a losses.HardExampleMiner object (can be None) hard_example_miner: a losses.HardExampleMiner object (can be None)
add_summaries: boolean (default: True) controlling whether summary ops add_summaries: boolean (default: True) controlling whether summary ops
should be added to tensorflow graph. should be added to tensorflow graph.
normalize_loc_loss_by_codesize: whether to normalize localization loss
by code size of the box encoder.
""" """
super(SSDMetaArch, self).__init__(num_classes=box_predictor.num_classes) super(SSDMetaArch, self).__init__(num_classes=box_predictor.num_classes)
self._is_training = is_training self._is_training = is_training
...@@ -191,7 +197,7 @@ class SSDMetaArch(model.DetectionModel): ...@@ -191,7 +197,7 @@ class SSDMetaArch(model.DetectionModel):
self._matcher = matcher self._matcher = matcher
self._region_similarity_calculator = region_similarity_calculator self._region_similarity_calculator = region_similarity_calculator
# TODO: handle agnostic mode and positive/negative class # TODO(jonathanhuang): handle agnostic mode
# weights # weights
unmatched_cls_target = None unmatched_cls_target = None
unmatched_cls_target = tf.constant([1] + self.num_classes * [0], unmatched_cls_target = tf.constant([1] + self.num_classes * [0],
...@@ -204,7 +210,7 @@ class SSDMetaArch(model.DetectionModel): ...@@ -204,7 +210,7 @@ class SSDMetaArch(model.DetectionModel):
self._region_similarity_calculator, self._region_similarity_calculator,
self._matcher, self._matcher,
self._box_coder, self._box_coder,
negative_class_weight=1.0, negative_class_weight=negative_class_weight,
unmatched_cls_target=unmatched_cls_target) unmatched_cls_target=unmatched_cls_target)
self._classification_loss = classification_loss self._classification_loss = classification_loss
...@@ -212,6 +218,7 @@ class SSDMetaArch(model.DetectionModel): ...@@ -212,6 +218,7 @@ class SSDMetaArch(model.DetectionModel):
self._classification_loss_weight = classification_loss_weight self._classification_loss_weight = classification_loss_weight
self._localization_loss_weight = localization_loss_weight self._localization_loss_weight = localization_loss_weight
self._normalize_loss_by_num_matches = normalize_loss_by_num_matches self._normalize_loss_by_num_matches = normalize_loss_by_num_matches
self._normalize_loc_loss_by_codesize = normalize_loc_loss_by_codesize
self._hard_example_miner = hard_example_miner self._hard_example_miner = hard_example_miner
self._image_resizer_fn = image_resizer_fn self._image_resizer_fn = image_resizer_fn
...@@ -254,7 +261,7 @@ class SSDMetaArch(model.DetectionModel): ...@@ -254,7 +261,7 @@ class SSDMetaArch(model.DetectionModel):
if inputs.dtype is not tf.float32: if inputs.dtype is not tf.float32:
raise ValueError('`preprocess` expects a tf.float32 tensor') raise ValueError('`preprocess` expects a tf.float32 tensor')
with tf.name_scope('Preprocessor'): with tf.name_scope('Preprocessor'):
# TODO: revisit whether to always use batch size as # TODO(jonathanhuang): revisit whether to always use batch size as
# the number of parallel iterations vs allow for dynamic batching. # the number of parallel iterations vs allow for dynamic batching.
outputs = shape_utils.static_or_dynamic_map_fn( outputs = shape_utils.static_or_dynamic_map_fn(
self._image_resizer_fn, self._image_resizer_fn,
...@@ -344,15 +351,17 @@ class SSDMetaArch(model.DetectionModel): ...@@ -344,15 +351,17 @@ class SSDMetaArch(model.DetectionModel):
feature_map_spatial_dims = self._get_feature_map_spatial_dims(feature_maps) feature_map_spatial_dims = self._get_feature_map_spatial_dims(feature_maps)
image_shape = shape_utils.combined_static_and_dynamic_shape( image_shape = shape_utils.combined_static_and_dynamic_shape(
preprocessed_inputs) preprocessed_inputs)
self._anchors = self._anchor_generator.generate( self._anchors = box_list_ops.concatenate(
feature_map_spatial_dims, self._anchor_generator.generate(
im_height=image_shape[1], feature_map_spatial_dims,
im_width=image_shape[2]) im_height=image_shape[1],
im_width=image_shape[2]))
prediction_dict = self._box_predictor.predict( prediction_dict = self._box_predictor.predict(
feature_maps, self._anchor_generator.num_anchors_per_location()) feature_maps, self._anchor_generator.num_anchors_per_location())
box_encodings = tf.squeeze(prediction_dict['box_encodings'], axis=2) box_encodings = tf.squeeze(
class_predictions_with_background = prediction_dict[ tf.concat(prediction_dict['box_encodings'], axis=1), axis=2)
'class_predictions_with_background'] class_predictions_with_background = tf.concat(
prediction_dict['class_predictions_with_background'], axis=1)
predictions_dict = { predictions_dict = {
'preprocessed_inputs': preprocessed_inputs, 'preprocessed_inputs': preprocessed_inputs,
'box_encodings': box_encodings, 'box_encodings': box_encodings,
...@@ -530,8 +539,11 @@ class SSDMetaArch(model.DetectionModel): ...@@ -530,8 +539,11 @@ class SSDMetaArch(model.DetectionModel):
1.0) 1.0)
with tf.name_scope('localization_loss'): with tf.name_scope('localization_loss'):
localization_loss = ((self._localization_loss_weight / normalizer) * localization_loss_normalizer = normalizer
localization_loss) if self._normalize_loc_loss_by_codesize:
localization_loss_normalizer *= self._box_coder.code_size
localization_loss = ((self._localization_loss_weight / (
localization_loss_normalizer)) * localization_loss)
with tf.name_scope('classification_loss'): with tf.name_scope('classification_loss'):
classification_loss = ((self._classification_loss_weight / normalizer) * classification_loss = ((self._classification_loss_weight / normalizer) *
classification_loss) classification_loss)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment