Commit 324d6dc3 authored by Zhichao Lu's avatar Zhichao Lu Committed by pkulzc

Merged commit includes the following changes:

196161788  by Zhichao Lu:

    Add eval_on_train_steps parameter.

    This is needed because the number of samples in the train dataset usually differs from the number of samples in the eval dataset.

--
196151742  by Zhichao Lu:

    Add an optional random sampling process to the SSD meta arch, and update the mean stddev box coder to use a default standard deviation when the corresponding tensor is not added as a boxlist field.

--
196148940  by Zhichao Lu:

    Release ssdlite mobilenet v2 coco trained model.

--
196058528  by Zhichao Lu:

    Apply FPN feature map generation before we add additional layers on top of resnet feature extractor.

--
195818367  by Zhichao Lu:

    Add support for exporting detection keypoints.

--
195745420  by Zhichao Lu:

    Introduce include_metrics_per_category option to Object Detection eval_config.

--
195734733  by Zhichao Lu:

    Rename SSDLite config to be more explicit.

--
195717383  by Zhichao Lu:

    Add quantized training to object_detection.

--
195683542  by...
parent 63054210
object_detection/box_coders/mean_stddev_box_coder.py
@@ -25,6 +25,14 @@ from object_detection.core import box_list
 class MeanStddevBoxCoder(box_coder.BoxCoder):
   """Mean stddev box coder."""
 
+  def __init__(self, stddev=0.01):
+    """Constructor for MeanStddevBoxCoder.
+
+    Args:
+      stddev: The standard deviation used to encode and decode boxes.
+    """
+    self._stddev = stddev
+
   @property
   def code_size(self):
     return 4
@@ -34,37 +42,38 @@ class MeanStddevBoxCoder(box_coder.BoxCoder):
     Args:
       boxes: BoxList holding N boxes to be encoded.
-      anchors: BoxList of N anchors. We assume that anchors has an associated
-        stddev field.
+      anchors: BoxList of N anchors.
 
     Returns:
       a tensor representing N anchor-encoded boxes
 
     Raises:
-      ValueError: if the anchors BoxList does not have a stddev field
+      ValueError: if the anchors still have the deprecated stddev field.
     """
-    if not anchors.has_field('stddev'):
-      raise ValueError('anchors must have a stddev field')
     box_corners = boxes.get()
+    if anchors.has_field('stddev'):
+      raise ValueError("'stddev' is a parameter of MeanStddevBoxCoder and "
+                       "should not be specified in the box list.")
     means = anchors.get()
-    stddev = anchors.get_field('stddev')
-    return (box_corners - means) / stddev
+    return (box_corners - means) / self._stddev
 
   def _decode(self, rel_codes, anchors):
     """Decode.
 
     Args:
       rel_codes: a tensor representing N anchor-encoded boxes.
-      anchors: BoxList of anchors. We assume that anchors has an associated
-        stddev field.
+      anchors: BoxList of anchors.
 
     Returns:
       boxes: BoxList holding N bounding boxes
 
     Raises:
-      ValueError: if the anchors BoxList does not have a stddev field
+      ValueError: if the anchors still have the deprecated stddev field and
+        expect the decode method to use the stddev value from that field.
     """
-    if not anchors.has_field('stddev'):
-      raise ValueError('anchors must have a stddev field')
     means = anchors.get()
-    stddevs = anchors.get_field('stddev')
-    box_corners = rel_codes * stddevs + means
+    if anchors.has_field('stddev'):
+      raise ValueError("'stddev' is a parameter of MeanStddevBoxCoder and "
+                       "should not be specified in the box list.")
+    box_corners = rel_codes * self._stddev + means
     return box_list.BoxList(box_corners)
object_detection/box_coders/mean_stddev_box_coder_test.py
@@ -28,11 +28,9 @@ class MeanStddevBoxCoderTest(tf.test.TestCase):
     boxes = box_list.BoxList(tf.constant(box_corners))
     expected_rel_codes = [[0.0, 0.0, 0.0, 0.0], [-5.0, -5.0, -5.0, -3.0]]
     prior_means = tf.constant([[0.0, 0.0, 0.5, 0.5], [0.5, 0.5, 1.0, 0.8]])
-    prior_stddevs = tf.constant(2 * [4 * [.1]])
     priors = box_list.BoxList(prior_means)
-    priors.add_field('stddev', prior_stddevs)
-    coder = mean_stddev_box_coder.MeanStddevBoxCoder()
+    coder = mean_stddev_box_coder.MeanStddevBoxCoder(stddev=0.1)
     rel_codes = coder.encode(boxes, priors)
     with self.test_session() as sess:
       rel_codes_out = sess.run(rel_codes)
@@ -42,11 +40,9 @@ class MeanStddevBoxCoderTest(tf.test.TestCase):
     rel_codes = tf.constant([[0.0, 0.0, 0.0, 0.0], [-5.0, -5.0, -5.0, -3.0]])
     expected_box_corners = [[0.0, 0.0, 0.5, 0.5], [0.0, 0.0, 0.5, 0.5]]
     prior_means = tf.constant([[0.0, 0.0, 0.5, 0.5], [0.5, 0.5, 1.0, 0.8]])
-    prior_stddevs = tf.constant(2 * [4 * [.1]])
     priors = box_list.BoxList(prior_means)
-    priors.add_field('stddev', prior_stddevs)
-    coder = mean_stddev_box_coder.MeanStddevBoxCoder()
+    coder = mean_stddev_box_coder.MeanStddevBoxCoder(stddev=0.1)
     decoded_boxes = coder.decode(rel_codes, priors)
     decoded_box_corners = decoded_boxes.get()
     with self.test_session() as sess:
...
object_detection/builders/box_coder_builder.py
@@ -55,7 +55,8 @@ def build(box_coder_config):
   ])
   if (box_coder_config.WhichOneof('box_coder_oneof') ==
       'mean_stddev_box_coder'):
-    return mean_stddev_box_coder.MeanStddevBoxCoder()
+    return mean_stddev_box_coder.MeanStddevBoxCoder(
+        stddev=box_coder_config.mean_stddev_box_coder.stddev)
   if box_coder_config.WhichOneof('box_coder_oneof') == 'square_box_coder':
     return square_box_coder.SquareBoxCoder(scale_factors=[
         box_coder_config.square_box_coder.y_scale,
...
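
For context, a minimal sketch of how the new stddev value flows from a pipeline config into the coder. The 0.1 value is illustrative, and the snippet assumes the mean_stddev_box_coder proto message carries the stddev field that the builder change above reads:

    from google.protobuf import text_format
    from object_detection.builders import box_coder_builder
    from object_detection.protos import box_coder_pb2

    # Illustrative config snippet exercising the new field.
    box_coder_text_proto = """
      mean_stddev_box_coder {
        stddev: 0.1
      }
    """
    box_coder_proto = box_coder_pb2.BoxCoder()
    text_format.Merge(box_coder_text_proto, box_coder_proto)
    # Builds MeanStddevBoxCoder(stddev=0.1); anchors no longer carry a
    # 'stddev' field.
    box_coder = box_coder_builder.build(box_coder_proto)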
object_detection/builders/graph_rewriter_builder.py (new file)
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions for quantized training and evaluation."""
import tensorflow as tf
def build(graph_rewriter_config, is_training):
"""Returns a function that modifies default graph based on options.
Args:
graph_rewriter_config: graph_rewriter_pb2.GraphRewriter proto.
is_training: whether in training of eval mode.
"""
def graph_rewrite_fn():
"""Function to quantize weights and activation of the default graph."""
if (graph_rewriter_config.quantization.weight_bits != 8 or
graph_rewriter_config.quantization.activation_bits != 8):
raise ValueError('Only 8bit quantization is supported')
# Quantize the graph by inserting quantize ops for weights and activations
if is_training:
tf.contrib.quantize.create_training_graph(
input_graph=tf.get_default_graph(),
quant_delay=graph_rewriter_config.quantization.delay)
else:
tf.contrib.quantize.create_eval_graph(input_graph=tf.get_default_graph())
tf.contrib.layers.summarize_collection('quant_vars')
return graph_rewrite_fn
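
A short usage sketch for the new rewriter (the delay value is illustrative; the proto fields match those exercised in the tests below):

    import tensorflow as tf
    from object_detection.builders import graph_rewriter_builder
    from object_detection.protos import graph_rewriter_pb2

    graph_rewriter_proto = graph_rewriter_pb2.GraphRewriter()
    graph_rewriter_proto.quantization.delay = 500000    # steps before quantizing
    graph_rewriter_proto.quantization.weight_bits = 8   # only 8 bits is supported
    graph_rewriter_proto.quantization.activation_bits = 8

    # Build the model graph first, then rewrite it in place with fake-quant ops.
    graph_rewrite_fn = graph_rewriter_builder.build(
        graph_rewriter_proto, is_training=True)
    graph_rewrite_fn()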
object_detection/builders/graph_rewriter_builder_test.py (new file)
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for graph_rewriter_builder."""
import mock
import tensorflow as tf
from object_detection.builders import graph_rewriter_builder
from object_detection.protos import graph_rewriter_pb2
class QuantizationBuilderTest(tf.test.TestCase):
def testQuantizationBuilderSetsUpCorrectTrainArguments(self):
with mock.patch.object(
tf.contrib.quantize, 'create_training_graph') as mock_quant_fn:
with mock.patch.object(tf.contrib.layers,
'summarize_collection') as mock_summarize_col:
graph_rewriter_proto = graph_rewriter_pb2.GraphRewriter()
graph_rewriter_proto.quantization.delay = 10
graph_rewriter_proto.quantization.weight_bits = 8
graph_rewriter_proto.quantization.activation_bits = 8
graph_rewrite_fn = graph_rewriter_builder.build(
graph_rewriter_proto, is_training=True)
graph_rewrite_fn()
_, kwargs = mock_quant_fn.call_args
self.assertEqual(kwargs['input_graph'], tf.get_default_graph())
self.assertEqual(kwargs['quant_delay'], 10)
mock_summarize_col.assert_called_with('quant_vars')
def testQuantizationBuilderSetsUpCorrectEvalArguments(self):
with mock.patch.object(tf.contrib.quantize,
'create_eval_graph') as mock_quant_fn:
with mock.patch.object(tf.contrib.layers,
'summarize_collection') as mock_summarize_col:
graph_rewriter_proto = graph_rewriter_pb2.GraphRewriter()
graph_rewriter_proto.quantization.delay = 10
graph_rewrite_fn = graph_rewriter_builder.build(
graph_rewriter_proto, is_training=False)
graph_rewrite_fn()
_, kwargs = mock_quant_fn.call_args
self.assertEqual(kwargs['input_graph'], tf.get_default_graph())
mock_summarize_col.assert_called_with('quant_vars')
if __name__ == '__main__':
tf.test.main()
object_detection/builders/losses_builder.py
@@ -15,6 +15,7 @@
 """A function to build localization and classification losses from config."""
 
+from object_detection.core import balanced_positive_negative_sampler as sampler
 from object_detection.core import losses
 from object_detection.protos import losses_pb2
@@ -34,9 +35,12 @@ def build(loss_config):
     classification_weight: Classification loss weight.
     localization_weight: Localization loss weight.
     hard_example_miner: Hard example miner object.
+    random_example_sampler: BalancedPositiveNegativeSampler object.
 
   Raises:
     ValueError: If hard_example_miner is used with sigmoid_focal_loss.
+    ValueError: If random_example_sampler is configured with a non-positive
+      value as the desired positive example fraction.
   """
   classification_loss = _build_classification_loss(
       loss_config.classification_loss)
@@ -54,9 +58,16 @@ def build(loss_config):
         loss_config.hard_example_miner,
         classification_weight,
         localization_weight)
-  return (classification_loss, localization_loss,
-          classification_weight,
-          localization_weight, hard_example_miner)
+  random_example_sampler = None
+  if loss_config.HasField('random_example_sampler'):
+    if loss_config.random_example_sampler.positive_sample_fraction <= 0:
+      raise ValueError('RandomExampleSampler should not use a non-positive '
+                       'value as positive sample fraction.')
+    random_example_sampler = sampler.BalancedPositiveNegativeSampler(
+        positive_fraction=loss_config.random_example_sampler.
+        positive_sample_fraction)
+  return (classification_loss, localization_loss, classification_weight,
+          localization_weight, hard_example_miner, random_example_sampler)
 
 
 def build_hard_example_miner(config,
...
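
As a usage sketch, a Loss config that turns the new sampler on; the message and field names follow the builder code above, and the 0.01 fraction is illustrative:

    from google.protobuf import text_format
    from object_detection.builders import losses_builder
    from object_detection.protos import losses_pb2

    losses_text_proto = """
      localization_loss {
        weighted_l2 {
        }
      }
      classification_loss {
        weighted_softmax {
        }
      }
      random_example_sampler {
        positive_sample_fraction: 0.01
      }
    """
    losses_proto = losses_pb2.Loss()
    text_format.Merge(losses_text_proto, losses_proto)
    # The sixth element of the returned tuple is the sampler (None when the
    # config omits random_example_sampler).
    _, _, _, _, _, random_example_sampler = losses_builder.build(losses_proto)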
object_detection/builders/losses_builder_test.py
@@ -38,7 +38,7 @@ class LocalizationLossBuilderTest(tf.test.TestCase):
     """
     losses_proto = losses_pb2.Loss()
     text_format.Merge(losses_text_proto, losses_proto)
-    _, localization_loss, _, _, _ = losses_builder.build(losses_proto)
+    _, localization_loss, _, _, _, _ = losses_builder.build(losses_proto)
     self.assertTrue(isinstance(localization_loss,
                                losses.WeightedL2LocalizationLoss))
@@ -55,7 +55,7 @@ class LocalizationLossBuilderTest(tf.test.TestCase):
     """
     losses_proto = losses_pb2.Loss()
     text_format.Merge(losses_text_proto, losses_proto)
-    _, localization_loss, _, _, _ = losses_builder.build(losses_proto)
+    _, localization_loss, _, _, _, _ = losses_builder.build(losses_proto)
     self.assertTrue(isinstance(localization_loss,
                                losses.WeightedSmoothL1LocalizationLoss))
     self.assertAlmostEqual(localization_loss._delta, 1.0)
@@ -74,7 +74,7 @@ class LocalizationLossBuilderTest(tf.test.TestCase):
     """
     losses_proto = losses_pb2.Loss()
     text_format.Merge(losses_text_proto, losses_proto)
-    _, localization_loss, _, _, _ = losses_builder.build(losses_proto)
+    _, localization_loss, _, _, _, _ = losses_builder.build(losses_proto)
     self.assertTrue(isinstance(localization_loss,
                                losses.WeightedSmoothL1LocalizationLoss))
     self.assertAlmostEqual(localization_loss._delta, 0.1)
@@ -92,7 +92,7 @@ class LocalizationLossBuilderTest(tf.test.TestCase):
     """
     losses_proto = losses_pb2.Loss()
     text_format.Merge(losses_text_proto, losses_proto)
-    _, localization_loss, _, _, _ = losses_builder.build(losses_proto)
+    _, localization_loss, _, _, _, _ = losses_builder.build(losses_proto)
     self.assertTrue(isinstance(localization_loss,
                                losses.WeightedIOULocalizationLoss))
@@ -109,7 +109,7 @@ class LocalizationLossBuilderTest(tf.test.TestCase):
     """
     losses_proto = losses_pb2.Loss()
     text_format.Merge(losses_text_proto, losses_proto)
-    _, localization_loss, _, _, _ = losses_builder.build(losses_proto)
+    _, localization_loss, _, _, _, _ = losses_builder.build(losses_proto)
     self.assertTrue(isinstance(localization_loss,
                                losses.WeightedSmoothL1LocalizationLoss))
     predictions = tf.constant([[[0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 1.0, 1.0]]])
@@ -146,7 +146,7 @@ class ClassificationLossBuilderTest(tf.test.TestCase):
     """
     losses_proto = losses_pb2.Loss()
     text_format.Merge(losses_text_proto, losses_proto)
-    classification_loss, _, _, _, _ = losses_builder.build(losses_proto)
+    classification_loss, _, _, _, _, _ = losses_builder.build(losses_proto)
     self.assertTrue(isinstance(classification_loss,
                                losses.WeightedSigmoidClassificationLoss))
@@ -163,7 +163,7 @@ class ClassificationLossBuilderTest(tf.test.TestCase):
     """
     losses_proto = losses_pb2.Loss()
     text_format.Merge(losses_text_proto, losses_proto)
-    classification_loss, _, _, _, _ = losses_builder.build(losses_proto)
+    classification_loss, _, _, _, _, _ = losses_builder.build(losses_proto)
     self.assertTrue(isinstance(classification_loss,
                                losses.SigmoidFocalClassificationLoss))
     self.assertAlmostEqual(classification_loss._alpha, None)
@@ -184,7 +184,7 @@ class ClassificationLossBuilderTest(tf.test.TestCase):
     """
     losses_proto = losses_pb2.Loss()
     text_format.Merge(losses_text_proto, losses_proto)
-    classification_loss, _, _, _, _ = losses_builder.build(losses_proto)
+    classification_loss, _, _, _, _, _ = losses_builder.build(losses_proto)
     self.assertTrue(isinstance(classification_loss,
                                losses.SigmoidFocalClassificationLoss))
     self.assertAlmostEqual(classification_loss._alpha, 0.25)
@@ -203,7 +203,7 @@ class ClassificationLossBuilderTest(tf.test.TestCase):
     """
     losses_proto = losses_pb2.Loss()
     text_format.Merge(losses_text_proto, losses_proto)
-    classification_loss, _, _, _, _ = losses_builder.build(losses_proto)
+    classification_loss, _, _, _, _, _ = losses_builder.build(losses_proto)
     self.assertTrue(isinstance(classification_loss,
                                losses.WeightedSoftmaxClassificationLoss))
@@ -220,7 +220,7 @@ class ClassificationLossBuilderTest(tf.test.TestCase):
     """
     losses_proto = losses_pb2.Loss()
     text_format.Merge(losses_text_proto, losses_proto)
-    classification_loss, _, _, _, _ = losses_builder.build(losses_proto)
+    classification_loss, _, _, _, _, _ = losses_builder.build(losses_proto)
     self.assertTrue(
         isinstance(classification_loss,
                    losses.WeightedSoftmaxClassificationAgainstLogitsLoss))
@@ -239,7 +239,7 @@ class ClassificationLossBuilderTest(tf.test.TestCase):
     """
     losses_proto = losses_pb2.Loss()
     text_format.Merge(losses_text_proto, losses_proto)
-    classification_loss, _, _, _, _ = losses_builder.build(losses_proto)
+    classification_loss, _, _, _, _, _ = losses_builder.build(losses_proto)
     self.assertTrue(isinstance(classification_loss,
                                losses.WeightedSoftmaxClassificationLoss))
@@ -257,7 +257,7 @@ class ClassificationLossBuilderTest(tf.test.TestCase):
     """
     losses_proto = losses_pb2.Loss()
     text_format.Merge(losses_text_proto, losses_proto)
-    classification_loss, _, _, _, _ = losses_builder.build(losses_proto)
+    classification_loss, _, _, _, _, _ = losses_builder.build(losses_proto)
     self.assertTrue(isinstance(classification_loss,
                                losses.BootstrappedSigmoidClassificationLoss))
@@ -275,7 +275,7 @@ class ClassificationLossBuilderTest(tf.test.TestCase):
     """
     losses_proto = losses_pb2.Loss()
     text_format.Merge(losses_text_proto, losses_proto)
-    classification_loss, _, _, _, _ = losses_builder.build(losses_proto)
+    classification_loss, _, _, _, _, _ = losses_builder.build(losses_proto)
     self.assertTrue(isinstance(classification_loss,
                                losses.WeightedSigmoidClassificationLoss))
     predictions = tf.constant([[[0.0, 1.0, 0.0], [0.0, 0.5, 0.5]]])
@@ -312,7 +312,7 @@ class HardExampleMinerBuilderTest(tf.test.TestCase):
     """
     losses_proto = losses_pb2.Loss()
     text_format.Merge(losses_text_proto, losses_proto)
-    _, _, _, _, hard_example_miner = losses_builder.build(losses_proto)
+    _, _, _, _, hard_example_miner, _ = losses_builder.build(losses_proto)
     self.assertEqual(hard_example_miner, None)
 
   def test_build_hard_example_miner_for_classification_loss(self):
@@ -331,7 +331,7 @@ class HardExampleMinerBuilderTest(tf.test.TestCase):
     """
     losses_proto = losses_pb2.Loss()
     text_format.Merge(losses_text_proto, losses_proto)
-    _, _, _, _, hard_example_miner = losses_builder.build(losses_proto)
+    _, _, _, _, hard_example_miner, _ = losses_builder.build(losses_proto)
     self.assertTrue(isinstance(hard_example_miner, losses.HardExampleMiner))
     self.assertEqual(hard_example_miner._loss_type, 'cls')
@@ -351,7 +351,7 @@ class HardExampleMinerBuilderTest(tf.test.TestCase):
     """
     losses_proto = losses_pb2.Loss()
     text_format.Merge(losses_text_proto, losses_proto)
-    _, _, _, _, hard_example_miner = losses_builder.build(losses_proto)
+    _, _, _, _, hard_example_miner, _ = losses_builder.build(losses_proto)
     self.assertTrue(isinstance(hard_example_miner, losses.HardExampleMiner))
     self.assertEqual(hard_example_miner._loss_type, 'loc')
@@ -375,7 +375,7 @@ class HardExampleMinerBuilderTest(tf.test.TestCase):
     """
     losses_proto = losses_pb2.Loss()
     text_format.Merge(losses_text_proto, losses_proto)
-    _, _, _, _, hard_example_miner = losses_builder.build(losses_proto)
+    _, _, _, _, hard_example_miner, _ = losses_builder.build(losses_proto)
     self.assertTrue(isinstance(hard_example_miner, losses.HardExampleMiner))
     self.assertEqual(hard_example_miner._num_hard_examples, 32)
     self.assertAlmostEqual(hard_example_miner._iou_threshold, 0.5)
@@ -404,7 +404,7 @@ class LossBuilderTest(tf.test.TestCase):
     text_format.Merge(losses_text_proto, losses_proto)
     (classification_loss, localization_loss,
      classification_weight, localization_weight,
-     hard_example_miner) = losses_builder.build(losses_proto)
+     hard_example_miner, _) = losses_builder.build(losses_proto)
     self.assertTrue(isinstance(hard_example_miner, losses.HardExampleMiner))
     self.assertTrue(isinstance(classification_loss,
                                losses.WeightedSoftmaxClassificationLoss))
...
object_detection/builders/model_builder.py
@@ -180,8 +180,8 @@ def _build_ssd_model(ssd_config, is_training, add_summaries,
   non_max_suppression_fn, score_conversion_fn = post_processing_builder.build(
       ssd_config.post_processing)
   (classification_loss, localization_loss, classification_weight,
-   localization_weight,
-   hard_example_miner) = losses_builder.build(ssd_config.loss)
+   localization_weight, hard_example_miner,
+   random_example_sampler) = losses_builder.build(ssd_config.loss)
   normalize_loss_by_num_matches = ssd_config.normalize_loss_by_num_matches
   normalize_loc_loss_by_codesize = ssd_config.normalize_loc_loss_by_codesize
@@ -208,7 +208,8 @@ def _build_ssd_model(ssd_config, is_training, add_summaries,
       normalize_loc_loss_by_codesize=normalize_loc_loss_by_codesize,
       freeze_batchnorm=ssd_config.freeze_batchnorm,
       inplace_batchnorm_update=ssd_config.inplace_batchnorm_update,
-      add_background_class=add_background_class)
+      add_background_class=add_background_class,
+      random_example_sampler=random_example_sampler)
 
 
 def _build_faster_rcnn_feature_extractor(
...
object_detection/core/balanced_positive_negative_sampler.py
@@ -39,6 +39,7 @@ class BalancedPositiveNegativeSampler(minibatch_sampler.MinibatchSampler):
     Args:
       positive_fraction: desired fraction of positive examples (scalar in [0,1])
+        in the batch.
 
     Raises:
       ValueError: if positive_fraction < 0, or positive_fraction > 1
@@ -53,7 +54,9 @@ class BalancedPositiveNegativeSampler(minibatch_sampler.MinibatchSampler):
     Args:
       indicator: boolean tensor of shape [N] whose True entries can be sampled.
-      batch_size: desired batch size.
+      batch_size: desired batch size. If None, keeps all positive samples and
+        randomly selects negative samples so that the positive sample fraction
+        matches self._positive_fraction.
       labels: boolean tensor of shape [N] denoting positive(=True) and negative
         (=False) examples.
@@ -83,9 +86,19 @@ class BalancedPositiveNegativeSampler(minibatch_sampler.MinibatchSampler):
     negative_idx = tf.logical_and(negative_idx, indicator)
 
     # Sample positive and negative samples separately
-    max_num_pos = int(self._positive_fraction * batch_size)
+    if batch_size is None:
+      max_num_pos = tf.reduce_sum(tf.to_int32(positive_idx))
+    else:
+      max_num_pos = int(self._positive_fraction * batch_size)
     sampled_pos_idx = self.subsample_indicator(positive_idx, max_num_pos)
-    max_num_neg = batch_size - tf.reduce_sum(tf.cast(sampled_pos_idx, tf.int32))
+    num_sampled_pos = tf.reduce_sum(tf.cast(sampled_pos_idx, tf.int32))
+    if batch_size is None:
+      negative_positive_ratio = (
+          1 - self._positive_fraction) / self._positive_fraction
+      max_num_neg = tf.to_int32(
+          negative_positive_ratio * tf.to_float(num_sampled_pos))
+    else:
+      max_num_neg = batch_size - num_sampled_pos
     sampled_neg_idx = self.subsample_indicator(negative_idx, max_num_neg)
 
     sampled_idx = tf.logical_or(sampled_pos_idx, sampled_neg_idx)
...
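
A small sketch of the batch_size=None arithmetic (values illustrative): every sampleable positive is kept, and negatives are drawn at a ratio of (1 - p) / p per sampled positive.

    import numpy as np
    import tensorflow as tf
    from object_detection.core import balanced_positive_negative_sampler

    labels = tf.constant(np.arange(1000) < 10)          # 10 positives
    indicator = tf.constant(np.ones(1000, dtype=bool))  # all entries sampleable
    sampler = (balanced_positive_negative_sampler.
               BalancedPositiveNegativeSampler(positive_fraction=0.1))
    # With p = 0.1: keep the 10 positives and draw (0.9 / 0.1) * 10 = 90
    # negatives, 100 samples in total.
    is_sampled = sampler.subsample(indicator, None, labels)
    with tf.Session() as sess:
      print(sess.run(tf.reduce_sum(tf.to_int32(is_sampled))))  # -> 100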
object_detection/core/balanced_positive_negative_sampler_test.py
@@ -19,9 +19,10 @@ import numpy as np
 import tensorflow as tf
 
 from object_detection.core import balanced_positive_negative_sampler
+from object_detection.utils import test_case
 
 
-class BalancedPositiveNegativeSamplerTest(tf.test.TestCase):
+class BalancedPositiveNegativeSamplerTest(test_case.TestCase):
 
   def test_subsample_all_examples(self):
     numpy_labels = np.random.permutation(300)
@@ -62,6 +63,28 @@ class BalancedPositiveNegativeSamplerTest(tf.test.TestCase):
     self.assertAllEqual(is_sampled, np.logical_and(is_sampled,
                                                    numpy_indicator))
 
+  def test_subsample_selection_no_batch_size(self):
+    # Test random sampling when only some examples can be sampled:
+    # 1000 samples, 6 positives (5 can be sampled).
+    numpy_labels = np.arange(1000)
+    numpy_indicator = numpy_labels < 999
+    indicator = tf.constant(numpy_indicator)
+    numpy_labels = (numpy_labels - 994) >= 0
+    labels = tf.constant(numpy_labels)
+    sampler = (balanced_positive_negative_sampler.
+               BalancedPositiveNegativeSampler(0.01))
+    is_sampled = sampler.subsample(indicator, None, labels)
+    with self.test_session() as sess:
+      is_sampled = sess.run(is_sampled)
+      self.assertTrue(sum(is_sampled) == 500)
+      self.assertTrue(sum(np.logical_and(numpy_labels, is_sampled)) == 5)
+      self.assertTrue(sum(np.logical_and(
+          np.logical_not(numpy_labels), is_sampled)) == 495)
+      self.assertAllEqual(is_sampled, np.logical_and(is_sampled,
+                                                     numpy_indicator))
+
   def test_raises_error_with_incorrect_label_shape(self):
     labels = tf.constant([[True, False, False]])
     indicator = tf.constant([True, False, True])
...
object_detection/core/model.py
@@ -237,7 +237,8 @@ class DetectionModel(object):
                           groundtruth_classes_list,
                           groundtruth_masks_list=None,
                           groundtruth_keypoints_list=None,
-                          groundtruth_weights_list=None):
+                          groundtruth_weights_list=None,
+                          groundtruth_is_crowd_list=None):
     """Provide groundtruth tensors.
 
     Args:
@@ -260,6 +261,8 @@ class DetectionModel(object):
         missing keypoints should be encoded as NaN.
       groundtruth_weights_list: A list of 1-D tf.float32 tensors of shape
         [num_boxes] containing weights for groundtruth boxes.
+      groundtruth_is_crowd_list: A list of 1-D tf.bool tensors of shape
+        [num_boxes] containing is_crowd annotations.
     """
     self._groundtruth_lists[fields.BoxListFields.boxes] = groundtruth_boxes_list
     self._groundtruth_lists[
@@ -273,6 +276,9 @@ class DetectionModel(object):
     if groundtruth_keypoints_list:
       self._groundtruth_lists[
           fields.BoxListFields.keypoints] = groundtruth_keypoints_list
+    if groundtruth_is_crowd_list:
+      self._groundtruth_lists[
+          fields.BoxListFields.is_crowd] = groundtruth_is_crowd_list
 
   @abstractmethod
   def restore_map(self, fine_tune_checkpoint_type='detection'):
...
object_detection/core/standard_fields.py
@@ -132,6 +132,7 @@ class BoxListFields(object):
     boundaries: boundaries per bounding box.
     keypoints: keypoints per bounding box.
     keypoint_heatmaps: keypoint heatmaps per bounding box.
+    is_crowd: is_crowd annotation per bounding box.
   """
   boxes = 'boxes'
   classes = 'classes'
@@ -142,6 +143,7 @@ class BoxListFields(object):
   boundaries = 'boundaries'
   keypoints = 'keypoints'
   keypoint_heatmaps = 'keypoint_heatmaps'
+  is_crowd = 'is_crowd'
 
 
 class TfExampleFields(object):
...
object_detection/eval.py
@@ -49,6 +49,7 @@ import tensorflow as tf
 
 from object_detection import evaluator
 from object_detection.builders import dataset_builder
+from object_detection.builders import graph_rewriter_builder
 from object_detection.builders import model_builder
 from object_detection.utils import config_util
 from object_detection.utils import dataset_util
@@ -127,8 +128,19 @@ def main(unused_argv):
   if FLAGS.run_once:
     eval_config.max_evals = 1
 
-  evaluator.evaluate(create_input_dict_fn, model_fn, eval_config, categories,
-                     FLAGS.checkpoint_dir, FLAGS.eval_dir)
+  graph_rewriter_fn = None
+  if 'graph_rewriter_config' in configs:
+    graph_rewriter_fn = graph_rewriter_builder.build(
+        configs['graph_rewriter_config'], is_training=False)
+
+  evaluator.evaluate(
+      create_input_dict_fn,
+      model_fn,
+      eval_config,
+      categories,
+      FLAGS.checkpoint_dir,
+      FLAGS.eval_dir,
+      graph_hook_fn=graph_rewriter_fn)
 
 
 if __name__ == '__main__':
...
object_detection/eval_util.py
@@ -588,7 +588,8 @@ def get_eval_metric_ops_for_evaluators(evaluation_metrics,
         'name': (required) string representing category name e.g., 'cat', 'dog'.
     eval_dict: An evaluation dictionary, returned from
       result_dict_for_single_example().
-    include_metrics_per_category: If True, include metrics for each category.
+    include_metrics_per_category: If True, additionally include per-category
+      metrics.
 
   Returns:
     A dictionary of metric names to tuple of value_op and update_op that can be
@@ -615,7 +616,9 @@ def get_eval_metric_ops_for_evaluators(evaluation_metrics,
             input_data_fields.groundtruth_classes],
         detection_boxes=eval_dict[detection_fields.detection_boxes],
         detection_scores=eval_dict[detection_fields.detection_scores],
-        detection_classes=eval_dict[detection_fields.detection_classes]))
+        detection_classes=eval_dict[detection_fields.detection_classes],
+        groundtruth_is_crowd=eval_dict.get(
+            input_data_fields.groundtruth_is_crowd)))
   elif metric == 'coco_mask_metrics':
     coco_mask_evaluator = coco_evaluation.CocoMaskEvaluator(
         categories, include_metrics_per_category=include_metrics_per_category)
@@ -629,7 +632,9 @@ def get_eval_metric_ops_for_evaluators(evaluation_metrics,
             input_data_fields.groundtruth_instance_masks],
         detection_scores=eval_dict[detection_fields.detection_scores],
         detection_classes=eval_dict[detection_fields.detection_classes],
-        detection_masks=eval_dict[detection_fields.detection_masks]))
+        detection_masks=eval_dict[detection_fields.detection_masks],
+        groundtruth_is_crowd=eval_dict.get(
+            input_data_fields.groundtruth_is_crowd)))
   else:
     raise ValueError('The only evaluation metrics supported are '
                      '"coco_detection_metrics" and "coco_mask_metrics". '
...
object_detection/exporter.py
@@ -197,6 +197,9 @@ def _add_output_tensor_nodes(postprocessed_tensors,
       containing scores for the detected boxes.
     * detection_classes: float32 tensor of shape [batch_size, num_boxes]
       containing class predictions for the detected boxes.
+    * detection_keypoints: (Optional) float32 tensor of shape
+      [batch_size, num_boxes, num_keypoints, 2] containing keypoints for each
+      detection box.
     * detection_masks: (Optional) float32 tensor of shape
       [batch_size, num_boxes, mask_height, mask_width] containing masks for
       each detection box.
@@ -220,6 +223,7 @@ def _add_output_tensor_nodes(postprocessed_tensors,
   scores = postprocessed_tensors.get(detection_fields.detection_scores)
   classes = postprocessed_tensors.get(
       detection_fields.detection_classes) + label_id_offset
+  keypoints = postprocessed_tensors.get(detection_fields.detection_keypoints)
   masks = postprocessed_tensors.get(detection_fields.detection_masks)
   num_detections = postprocessed_tensors.get(detection_fields.num_detections)
   outputs = {}
@@ -231,6 +235,9 @@ def _add_output_tensor_nodes(postprocessed_tensors,
       classes, name=detection_fields.detection_classes)
   outputs[detection_fields.num_detections] = tf.identity(
       num_detections, name=detection_fields.num_detections)
+  if keypoints is not None:
+    outputs[detection_fields.detection_keypoints] = tf.identity(
+        keypoints, name=detection_fields.detection_keypoints)
   if masks is not None:
     outputs[detection_fields.detection_masks] = tf.identity(
         masks, name=detection_fields.detection_masks)
...
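
Once a keypoint-predicting model is exported, the frozen graph exposes the new output under the 'detection_keypoints' node name, mirroring the tf.identity naming above. A sketch, with a hypothetical path to the exported graph:

    import tensorflow as tf

    graph_def = tf.GraphDef()
    with tf.gfile.GFile('frozen_inference_graph.pb', 'rb') as fid:
      graph_def.ParseFromString(fid.read())
    with tf.Graph().as_default() as graph:
      tf.import_graph_def(graph_def, name='')
      # Present only when the exported model actually predicts keypoints.
      keypoints = graph.get_tensor_by_name('detection_keypoints:0')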
This diff is collapsed.
object_detection/g3doc/detection_model_zoo.md
@@ -71,6 +71,7 @@ Some remarks on frozen inference graphs:
 | ------------ | :--------------: | :--------------: | :-------------: |
 | [ssd_mobilenet_v1_coco](http://download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_coco_2017_11_17.tar.gz) | 30 | 21 | Boxes |
 | [ssd_mobilenet_v2_coco](http://download.tensorflow.org/models/object_detection/ssd_mobilenet_v2_coco_2018_03_29.tar.gz) | 31 | 22 | Boxes |
+| [ssdlite_mobilenet_v2_coco](http://download.tensorflow.org/models/object_detection/ssdlite_mobilenet_v2_coco_2018_05_09.tar.gz) | 27 | 22 | Boxes |
 | [ssd_inception_v2_coco](http://download.tensorflow.org/models/object_detection/ssd_inception_v2_coco_2017_11_17.tar.gz) | 42 | 24 | Boxes |
 | [faster_rcnn_inception_v2_coco](http://download.tensorflow.org/models/object_detection/faster_rcnn_inception_v2_coco_2018_01_28.tar.gz) | 58 | 28 | Boxes |
 | [faster_rcnn_resnet50_coco](http://download.tensorflow.org/models/object_detection/faster_rcnn_resnet50_coco_2018_01_28.tar.gz) | 89 | 30 | Boxes |
...
object_detection/inputs.py
@@ -71,11 +71,10 @@ def transform_input_data(tensor_dict,
     model_preprocess_fn: model's preprocess function to apply on image tensor.
       This function must take in a 4-D float tensor and return a 4-D preprocess
       float tensor and a tensor containing the true image shape.
-    image_resizer_fn: image resizer function to apply on original image (if
-      `retain_original_image` is True) and groundtruth instance masks. This
-      function must take a 3-D float tensor of an image and a 3-D tensor of
-      instance masks and return a resized version of these along with the true
-      shapes.
+    image_resizer_fn: image resizer function to apply on groundtruth instance
+      masks. This function must take a 3-D float tensor of an image and a 3-D
+      tensor of instance masks and return a resized version of these along
+      with the true shapes.
     num_classes: number of max classes to one-hot (or k-hot) encode the class
       labels.
     data_augmentation_fn: (optional) data augmentation function to apply on
@@ -90,10 +89,8 @@ def transform_input_data(tensor_dict,
       after applying all the transformations.
   """
   if retain_original_image:
-    original_image_resized, _ = image_resizer_fn(
-        tensor_dict[fields.InputDataFields.image])
     tensor_dict[fields.InputDataFields.original_image] = tf.cast(
-        original_image_resized, tf.uint8)
+        tensor_dict[fields.InputDataFields.image], tf.uint8)
 
   # Apply data augmentation ops.
   if data_augmentation_fn is not None:
@@ -350,7 +347,7 @@ def create_eval_input_fn(eval_config, eval_input_config, model_config):
     TypeError: if the `eval_config`, `eval_input_config` or `model_config`
       are not of the correct type.
   """
-  del params
+  params = params or {}
   if not isinstance(eval_config, eval_pb2.EvalConfig):
     raise TypeError('For eval mode, the `eval_config` must be a '
                     'train_pb2.EvalConfig.')
@@ -375,7 +372,7 @@ def create_eval_input_fn(eval_config, eval_input_config, model_config):
   dataset = INPUT_BUILDER_UTIL_MAP['dataset_build'](
       eval_input_config,
       transform_input_data_fn=transform_data_fn,
-      batch_size=1,
+      batch_size=params.get('batch_size', 1),
       num_classes=config_util.get_number_of_classes(model_config),
       spatial_image_shape=config_util.get_spatial_image_size(
           image_resizer_config))
...
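
A sketch of what the params change enables, assuming (per the TF Estimator convention) that the returned input_fn accepts a params dict; the pipeline path is hypothetical. Evaluation batch size is no longer pinned to 1:

    from object_detection import inputs
    from object_detection.utils import config_util

    configs = config_util.get_configs_from_pipeline_file('pipeline.config')
    eval_input_fn = inputs.create_eval_input_fn(
        configs['eval_config'], configs['eval_input_config'], configs['model'])
    # Callers such as TPUEstimator can now pass their own batch size; without
    # params the dataset still defaults to batch_size=1.
    features, labels = eval_input_fn(params={'batch_size': 8})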
object_detection/inputs_test.py
@@ -482,7 +482,7 @@ class DataTransformationFnTest(tf.test.TestCase):
     self.assertAllEqual(transformed_inputs[
         fields.InputDataFields.original_image].dtype, tf.uint8)
     self.assertAllEqual(transformed_inputs[
-        fields.InputDataFields.original_image].shape, [8, 8, 3])
+        fields.InputDataFields.original_image].shape, [4, 4, 3])
     self.assertAllEqual(transformed_inputs[
         fields.InputDataFields.groundtruth_instance_masks].shape, [2, 8, 8])
...