"vscode:/vscode.git/clone" did not exist on "8bc5a1a5aa9068820d1dfeb26b4887e0740833a2"
Unverified commit fd7b6887, authored by Jonathan Huang and committed by GitHub

Merge pull request #3293 from pkulzc/master

Internal changes of object_detection 
parents f98ec55e 1efe98bb
@@ -116,18 +116,17 @@ def build_faster_rcnn_classification_loss(loss_config):
   loss_type = loss_config.WhichOneof('classification_loss')
   if loss_type == 'weighted_sigmoid':
-    config = loss_config.weighted_sigmoid
-    return losses.WeightedSigmoidClassificationLoss(
-        anchorwise_output=config.anchorwise_output)
+    return losses.WeightedSigmoidClassificationLoss()
   if loss_type == 'weighted_softmax':
     config = loss_config.weighted_softmax
     return losses.WeightedSoftmaxClassificationLoss(
-        anchorwise_output=config.anchorwise_output)
+        logit_scale=config.logit_scale)
   # By default, Faster RCNN second stage classifier uses Softmax loss
   # with anchor-wise outputs.
+  config = loss_config.weighted_softmax
   return losses.WeightedSoftmaxClassificationLoss(
-      anchorwise_output=True)
+      logit_scale=config.logit_scale)

 def _build_localization_loss(loss_config):
@@ -148,14 +147,10 @@ def _build_localization_loss(loss_config):
   loss_type = loss_config.WhichOneof('localization_loss')
   if loss_type == 'weighted_l2':
-    config = loss_config.weighted_l2
-    return losses.WeightedL2LocalizationLoss(
-        anchorwise_output=config.anchorwise_output)
+    return losses.WeightedL2LocalizationLoss()
   if loss_type == 'weighted_smooth_l1':
-    config = loss_config.weighted_smooth_l1
-    return losses.WeightedSmoothL1LocalizationLoss(
-        anchorwise_output=config.anchorwise_output)
+    return losses.WeightedSmoothL1LocalizationLoss()
   if loss_type == 'weighted_iou':
     return losses.WeightedIOULocalizationLoss()
@@ -181,9 +176,7 @@ def _build_classification_loss(loss_config):
   loss_type = loss_config.WhichOneof('classification_loss')
   if loss_type == 'weighted_sigmoid':
-    config = loss_config.weighted_sigmoid
-    return losses.WeightedSigmoidClassificationLoss(
-        anchorwise_output=config.anchorwise_output)
+    return losses.WeightedSigmoidClassificationLoss()
   if loss_type == 'weighted_sigmoid_focal':
     config = loss_config.weighted_sigmoid_focal
@@ -191,21 +184,18 @@ def _build_classification_loss(loss_config):
     if config.HasField('alpha'):
       alpha = config.alpha
     return losses.SigmoidFocalClassificationLoss(
-        anchorwise_output=config.anchorwise_output,
         gamma=config.gamma,
         alpha=alpha)
   if loss_type == 'weighted_softmax':
     config = loss_config.weighted_softmax
     return losses.WeightedSoftmaxClassificationLoss(
-        anchorwise_output=config.anchorwise_output,
         logit_scale=config.logit_scale)
   if loss_type == 'bootstrapped_sigmoid':
     config = loss_config.bootstrapped_sigmoid
     return losses.BootstrappedSigmoidClassificationLoss(
         alpha=config.alpha,
-        bootstrap_type=('hard' if config.hard_bootstrap else 'soft'),
-        anchorwise_output=config.anchorwise_output)
+        bootstrap_type=('hard' if config.hard_bootstrap else 'soft'))
   raise ValueError('Empty loss config.')
@@ -80,7 +80,6 @@ class LocalizationLossBuilderTest(tf.test.TestCase):
     losses_text_proto = """
       localization_loss {
         weighted_smooth_l1 {
-          anchorwise_output: true
         }
       }
       classification_loss {
@@ -245,7 +244,7 @@ class ClassificationLossBuilderTest(tf.test.TestCase):
     targets = tf.constant([[[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]])
     weights = tf.constant([[1.0, 1.0]])
     loss = classification_loss(predictions, targets, weights=weights)
-    self.assertEqual(loss.shape, [1, 2])
+    self.assertEqual(loss.shape, [1, 2, 3])

   def test_raise_error_on_empty_config(self):
     losses_text_proto = """
......
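Note: the losses change above makes every built loss anchorwise by default. A minimal sketch of what callers now see; the prediction values are hypothetical, while the constructor, call convention, and output shape come from this diff and its updated test:

import tensorflow as tf
from object_detection.core import losses

# anchorwise_output is gone from the constructors; per-anchor output is now
# the only behavior.
classification_loss = losses.WeightedSigmoidClassificationLoss()
predictions = tf.constant([[[-100.0, 100.0, -100.0],
                            [100.0, -100.0, -100.0]]])
targets = tf.constant([[[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]])
weights = tf.constant([[1.0, 1.0]])
loss = classification_loss(predictions, targets, weights=weights)
# loss.shape == [1, 2, 3]: per-image, per-anchor, per-class, as the updated
# test asserts, instead of a reduced tensor.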
@@ -45,7 +45,9 @@ def build(matcher_config):
         matched_threshold=matched_threshold,
         unmatched_threshold=unmatched_threshold,
         negatives_lower_than_unmatched=matcher.negatives_lower_than_unmatched,
-        force_match_for_each_row=matcher.force_match_for_each_row)
+        force_match_for_each_row=matcher.force_match_for_each_row,
+        use_matmul_gather=matcher.use_matmul_gather)
   if matcher_config.WhichOneof('matcher_oneof') == 'bipartite_matcher':
-    return bipartite_matcher.GreedyBipartiteMatcher()
+    matcher = matcher_config.bipartite_matcher
+    return bipartite_matcher.GreedyBipartiteMatcher(matcher.use_matmul_gather)
   raise ValueError('Empty matcher.')
@@ -62,6 +62,7 @@ class MatcherBuilderTest(tf.test.TestCase):
         unmatched_threshold: 0.3
         negatives_lower_than_unmatched: false
         force_match_for_each_row: true
+        use_matmul_gather: true
       }
     """
     matcher_proto = matcher_pb2.Matcher()
@@ -72,6 +73,7 @@ class MatcherBuilderTest(tf.test.TestCase):
     self.assertAlmostEqual(matcher_object._unmatched_threshold, 0.3)
     self.assertFalse(matcher_object._negatives_lower_than_unmatched)
     self.assertTrue(matcher_object._force_match_for_each_row)
+    self.assertTrue(matcher_object._use_matmul_gather)

   def test_build_bipartite_matcher(self):
     matcher_text_proto = """
......
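Note: use_matmul_gather is a new field on both matcher configs; judging by the name and the constructor arguments above, it routes gather ops through a matmul-based implementation. A minimal sketch of flipping it on through the builder (all other fields left at their defaults):

from google.protobuf import text_format
from object_detection.builders import matcher_builder
from object_detection.protos import matcher_pb2

matcher_text_proto = """
  argmax_matcher {
    use_matmul_gather: true
  }
"""
matcher_proto = matcher_pb2.Matcher()
text_format.Merge(matcher_text_proto, matcher_proto)
# The built ArgMaxMatcher receives use_matmul_gather=True, per the new
# constructor argument wired up in the diff above.
matcher_object = matcher_builder.build(matcher_proto)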
@@ -31,6 +31,7 @@ from object_detection.models import faster_rcnn_inception_resnet_v2_feature_extr
 from object_detection.models import faster_rcnn_inception_v2_feature_extractor as frcnn_inc_v2
 from object_detection.models import faster_rcnn_nas_feature_extractor as frcnn_nas
 from object_detection.models import faster_rcnn_resnet_v1_feature_extractor as frcnn_resnet_v1
+from object_detection.models import ssd_resnet_v1_fpn_feature_extractor as ssd_resnet_v1_fpn
 from object_detection.models.embedded_ssd_mobilenet_v1_feature_extractor import EmbeddedSSDMobileNetV1FeatureExtractor
 from object_detection.models.ssd_inception_v2_feature_extractor import SSDInceptionV2FeatureExtractor
 from object_detection.models.ssd_inception_v3_feature_extractor import SSDInceptionV3FeatureExtractor
@@ -42,6 +43,9 @@ SSD_FEATURE_EXTRACTOR_CLASS_MAP = {
     'ssd_inception_v2': SSDInceptionV2FeatureExtractor,
     'ssd_inception_v3': SSDInceptionV3FeatureExtractor,
     'ssd_mobilenet_v1': SSDMobileNetV1FeatureExtractor,
+    'ssd_resnet50_v1_fpn': ssd_resnet_v1_fpn.SSDResnet50V1FpnFeatureExtractor,
+    'ssd_resnet101_v1_fpn': ssd_resnet_v1_fpn.SSDResnet101V1FpnFeatureExtractor,
+    'ssd_resnet152_v1_fpn': ssd_resnet_v1_fpn.SSDResnet152V1FpnFeatureExtractor,
     'embedded_ssd_mobilenet_v1': EmbeddedSSDMobileNetV1FeatureExtractor,
 }
@@ -62,13 +66,14 @@ FASTER_RCNN_FEATURE_EXTRACTOR_CLASS_MAP = {
 }

-def build(model_config, is_training):
+def build(model_config, is_training, add_summaries=True):
   """Builds a DetectionModel based on the model config.

   Args:
     model_config: A model.proto object containing the config for the desired
       DetectionModel.
     is_training: True if this model is being built for training purposes.
+    add_summaries: Whether to add tensorflow summaries in the model graph.

   Returns:
     DetectionModel based on the config.
@@ -80,9 +85,10 @@ def build(model_config, is_training):
     raise ValueError('model_config not of type model_pb2.DetectionModel.')
   meta_architecture = model_config.WhichOneof('model')
   if meta_architecture == 'ssd':
-    return _build_ssd_model(model_config.ssd, is_training)
+    return _build_ssd_model(model_config.ssd, is_training, add_summaries)
   if meta_architecture == 'faster_rcnn':
-    return _build_faster_rcnn_model(model_config.faster_rcnn, is_training)
+    return _build_faster_rcnn_model(model_config.faster_rcnn, is_training,
+                                    add_summaries)
   raise ValueError('Unknown meta architecture: {}'.format(meta_architecture))
@@ -106,6 +112,8 @@ def _build_ssd_feature_extractor(feature_extractor_config, is_training,
   min_depth = feature_extractor_config.min_depth
   pad_to_multiple = feature_extractor_config.pad_to_multiple
   batch_norm_trainable = feature_extractor_config.batch_norm_trainable
+  use_explicit_padding = feature_extractor_config.use_explicit_padding
+  use_depthwise = feature_extractor_config.use_depthwise
   conv_hyperparams = hyperparams_builder.build(
       feature_extractor_config.conv_hyperparams, is_training)
@@ -115,16 +123,18 @@ def _build_ssd_feature_extractor(feature_extractor_config, is_training,
   feature_extractor_class = SSD_FEATURE_EXTRACTOR_CLASS_MAP[feature_type]
   return feature_extractor_class(is_training, depth_multiplier, min_depth,
                                  pad_to_multiple, conv_hyperparams,
-                                 batch_norm_trainable, reuse_weights)
+                                 batch_norm_trainable, reuse_weights,
+                                 use_explicit_padding, use_depthwise)

-def _build_ssd_model(ssd_config, is_training):
+def _build_ssd_model(ssd_config, is_training, add_summaries):
   """Builds an SSD detection model based on the model config.

   Args:
     ssd_config: A ssd.proto object containing the config for the desired
       SSDMetaArch.
     is_training: True if this model is being built for training purposes.
+    add_summaries: Whether to add tf summaries in the model.

   Returns:
     SSDMetaArch based on the config.
@@ -171,7 +181,8 @@ def _build_ssd_model(ssd_config, is_training):
       classification_weight,
       localization_weight,
       normalize_loss_by_num_matches,
-      hard_example_miner)
+      hard_example_miner,
+      add_summaries=add_summaries)

 def _build_faster_rcnn_feature_extractor(
@@ -205,7 +216,7 @@ def _build_faster_rcnn_feature_extractor(
       batch_norm_trainable, reuse_weights)

-def _build_faster_rcnn_model(frcnn_config, is_training):
+def _build_faster_rcnn_model(frcnn_config, is_training, add_summaries):
   """Builds a Faster R-CNN or R-FCN detection model based on the model config.

   Builds R-FCN model if the second_stage_box_predictor in the config is of type
@@ -215,6 +226,7 @@ def _build_faster_rcnn_model(frcnn_config, is_training):
     frcnn_config: A faster_rcnn.proto object containing the config for the
       desired FasterRCNNMetaArch or RFCNMetaArch.
     is_training: True if this model is being built for training purposes.
+    add_summaries: Whether to add tf summaries in the model.

   Returns:
     FasterRCNNMetaArch based on the config.
@@ -228,7 +240,7 @@ def _build_faster_rcnn_model(frcnn_config, is_training):
   feature_extractor = _build_faster_rcnn_feature_extractor(
       frcnn_config.feature_extractor, is_training)

-  first_stage_only = frcnn_config.first_stage_only
+  number_of_stages = frcnn_config.number_of_stages
   first_stage_anchor_generator = anchor_generator_builder.build(
       frcnn_config.first_stage_anchor_generator)
@@ -283,7 +295,7 @@ def _build_faster_rcnn_model(frcnn_config, is_training):
       'num_classes': num_classes,
       'image_resizer_fn': image_resizer_fn,
       'feature_extractor': feature_extractor,
-      'first_stage_only': first_stage_only,
+      'number_of_stages': number_of_stages,
       'first_stage_anchor_generator': first_stage_anchor_generator,
       'first_stage_atrous_rate': first_stage_atrous_rate,
       'first_stage_box_predictor_arg_scope':
@@ -310,7 +322,8 @@ def _build_faster_rcnn_model(frcnn_config, is_training):
           second_stage_classification_loss,
       'second_stage_classification_loss_weight':
           second_stage_classification_loss_weight,
-      'hard_example_miner': hard_example_miner}
+      'hard_example_miner': hard_example_miner,
+      'add_summaries': add_summaries}
   if isinstance(second_stage_box_predictor, box_predictor.RfcnBoxPredictor):
     return rfcn_meta_arch.RFCNMetaArch(
......
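Note: build() now accepts add_summaries (default True) and threads it down to both meta-architectures, so existing call sites are unaffected. A minimal sketch of the new knob; model_config is an assumed, already-parsed model_pb2.DetectionModel:

from object_detection.builders import model_builder

# Default keeps prior behavior: tf.summary ops are added to the model graph.
train_model = model_builder.build(model_config, is_training=True)
# Opt out of summaries, e.g. for a leaner evaluation or export graph.
eval_model = model_builder.build(model_config, is_training=False,
                                 add_summaries=False)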
@@ -26,12 +26,14 @@ from object_detection.models import faster_rcnn_inception_resnet_v2_feature_extr
 from object_detection.models import faster_rcnn_inception_v2_feature_extractor as frcnn_inc_v2
 from object_detection.models import faster_rcnn_nas_feature_extractor as frcnn_nas
 from object_detection.models import faster_rcnn_resnet_v1_feature_extractor as frcnn_resnet_v1
+from object_detection.models import ssd_resnet_v1_fpn_feature_extractor as ssd_resnet_v1_fpn
+from object_detection.models.embedded_ssd_mobilenet_v1_feature_extractor import EmbeddedSSDMobileNetV1FeatureExtractor
 from object_detection.models.ssd_inception_v2_feature_extractor import SSDInceptionV2FeatureExtractor
 from object_detection.models.ssd_inception_v3_feature_extractor import SSDInceptionV3FeatureExtractor
 from object_detection.models.ssd_mobilenet_v1_feature_extractor import SSDMobileNetV1FeatureExtractor
 from object_detection.protos import model_pb2

-FEATURE_EXTRACTOR_MAPS = {
+FRCNN_RESNET_FEAT_MAPS = {
     'faster_rcnn_resnet50':
         frcnn_resnet_v1.FasterRCNNResnet50FeatureExtractor,
     'faster_rcnn_resnet101':
@@ -40,6 +42,15 @@ FEATURE_EXTRACTOR_MAPS = {
         frcnn_resnet_v1.FasterRCNNResnet152FeatureExtractor
 }

+SSD_RESNET_V1_FPN_FEAT_MAPS = {
+    'ssd_resnet50_v1_fpn':
+        ssd_resnet_v1_fpn.SSDResnet50V1FpnFeatureExtractor,
+    'ssd_resnet101_v1_fpn':
+        ssd_resnet_v1_fpn.SSDResnet101V1FpnFeatureExtractor,
+    'ssd_resnet152_v1_fpn':
+        ssd_resnet_v1_fpn.SSDResnet152V1FpnFeatureExtractor
+}

 class ModelBuilderTest(tf.test.TestCase):
@@ -197,6 +208,87 @@ class ModelBuilderTest(tf.test.TestCase):
     self.assertIsInstance(model._feature_extractor,
                           SSDInceptionV3FeatureExtractor)

+  def test_create_ssd_resnet_v1_fpn_model_from_config(self):
+    model_text_proto = """
+      ssd {
+        feature_extractor {
+          type: 'ssd_resnet50_v1_fpn'
+          conv_hyperparams {
+            regularizer {
+              l2_regularizer {
+              }
+            }
+            initializer {
+              truncated_normal_initializer {
+              }
+            }
+          }
+          batch_norm_trainable: true
+        }
+        box_coder {
+          faster_rcnn_box_coder {
+          }
+        }
+        matcher {
+          argmax_matcher {
+          }
+        }
+        similarity_calculator {
+          iou_similarity {
+          }
+        }
+        anchor_generator {
+          multiscale_anchor_generator {
+            aspect_ratios: [1.0, 2.0, 0.5]
+            scales_per_octave: 2
+          }
+        }
+        image_resizer {
+          fixed_shape_resizer {
+            height: 320
+            width: 320
+          }
+        }
+        box_predictor {
+          weight_shared_convolutional_box_predictor {
+            depth: 32
+            conv_hyperparams {
+              regularizer {
+                l2_regularizer {
+                }
+              }
+              initializer {
+                truncated_normal_initializer {
+                }
+              }
+            }
+            num_layers_before_predictor: 1
+          }
+        }
+        loss {
+          classification_loss {
+            weighted_sigmoid_focal {
+              alpha: 0.25
+              gamma: 2.0
+            }
+          }
+          localization_loss {
+            weighted_smooth_l1 {
+            }
+          }
+          classification_weight: 1.0
+          localization_weight: 1.0
+        }
+      }"""
+    model_proto = model_pb2.DetectionModel()
+    text_format.Merge(model_text_proto, model_proto)
+    for extractor_type, extractor_class in SSD_RESNET_V1_FPN_FEAT_MAPS.items():
+      model_proto.ssd.feature_extractor.type = extractor_type
+      model = model_builder.build(model_proto, is_training=True)
+      self.assertIsInstance(model, ssd_meta_arch.SSDMetaArch)
+      self.assertIsInstance(model._feature_extractor, extractor_class)

   def test_create_ssd_mobilenet_v1_model_from_config(self):
     model_text_proto = """
       ssd {
@@ -270,6 +362,78 @@ class ModelBuilderTest(tf.test.TestCase):
                           SSDMobileNetV1FeatureExtractor)
     self.assertTrue(model._feature_extractor._batch_norm_trainable)

+  def test_create_embedded_ssd_mobilenet_v1_model_from_config(self):
+    model_text_proto = """
+      ssd {
+        feature_extractor {
+          type: 'embedded_ssd_mobilenet_v1'
+          conv_hyperparams {
+            regularizer {
+              l2_regularizer {
+              }
+            }
+            initializer {
+              truncated_normal_initializer {
+              }
+            }
+          }
+          batch_norm_trainable: true
+        }
+        box_coder {
+          faster_rcnn_box_coder {
+          }
+        }
+        matcher {
+          argmax_matcher {
+          }
+        }
+        similarity_calculator {
+          iou_similarity {
+          }
+        }
+        anchor_generator {
+          ssd_anchor_generator {
+            aspect_ratios: 1.0
+          }
+        }
+        image_resizer {
+          fixed_shape_resizer {
+            height: 256
+            width: 256
+          }
+        }
+        box_predictor {
+          convolutional_box_predictor {
+            conv_hyperparams {
+              regularizer {
+                l2_regularizer {
+                }
+              }
+              initializer {
+                truncated_normal_initializer {
+                }
+              }
+            }
+          }
+        }
+        loss {
+          classification_loss {
+            weighted_softmax {
+            }
+          }
+          localization_loss {
+            weighted_smooth_l1 {
+            }
+          }
+        }
+      }"""
+    model_proto = model_pb2.DetectionModel()
+    text_format.Merge(model_text_proto, model_proto)
+    model = self.create_model(model_proto)
+    self.assertIsInstance(model, ssd_meta_arch.SSDMetaArch)
+    self.assertIsInstance(model._feature_extractor,
+                          EmbeddedSSDMobileNetV1FeatureExtractor)

   def test_create_faster_rcnn_resnet_v1_models_from_config(self):
     model_text_proto = """
       faster_rcnn {
@@ -331,7 +495,7 @@ class ModelBuilderTest(tf.test.TestCase):
     }"""
     model_proto = model_pb2.DetectionModel()
     text_format.Merge(model_text_proto, model_proto)
-    for extractor_type, extractor_class in FEATURE_EXTRACTOR_MAPS.items():
+    for extractor_type, extractor_class in FRCNN_RESNET_FEAT_MAPS.items():
       model_proto.faster_rcnn.feature_extractor.type = extractor_type
       model = model_builder.build(model_proto, is_training=True)
       self.assertIsInstance(model, faster_rcnn_meta_arch.FasterRCNNMetaArch)
@@ -730,7 +894,7 @@ class ModelBuilderTest(tf.test.TestCase):
     }"""
     model_proto = model_pb2.DetectionModel()
     text_format.Merge(model_text_proto, model_proto)
-    for extractor_type, extractor_class in FEATURE_EXTRACTOR_MAPS.items():
+    for extractor_type, extractor_class in FRCNN_RESNET_FEAT_MAPS.items():
       model_proto.faster_rcnn.feature_extractor.type = extractor_type
       model = model_builder.build(model_proto, is_training=True)
       self.assertIsInstance(model, rfcn_meta_arch.RFCNMetaArch)
......
@@ -19,15 +19,14 @@ import tensorflow as tf
 from object_detection.utils import learning_schedules

-def build(optimizer_config, global_summaries):
+def build(optimizer_config):
   """Create optimizer based on config.

   Args:
     optimizer_config: A Optimizer proto message.
-    global_summaries: A set to attach learning rate summary to.

   Returns:
-    An optimizer.
+    An optimizer and a list of variables for summary.

   Raises:
     ValueError: when using an unsupported input data type.
@@ -35,24 +34,30 @@ def build(optimizer_config, global_summaries):
   optimizer_type = optimizer_config.WhichOneof('optimizer')
   optimizer = None

+  summary_vars = []
   if optimizer_type == 'rms_prop_optimizer':
     config = optimizer_config.rms_prop_optimizer
+    learning_rate = _create_learning_rate(config.learning_rate)
+    summary_vars.append(learning_rate)
     optimizer = tf.train.RMSPropOptimizer(
-        _create_learning_rate(config.learning_rate, global_summaries),
+        learning_rate,
         decay=config.decay,
         momentum=config.momentum_optimizer_value,
         epsilon=config.epsilon)

   if optimizer_type == 'momentum_optimizer':
     config = optimizer_config.momentum_optimizer
+    learning_rate = _create_learning_rate(config.learning_rate)
+    summary_vars.append(learning_rate)
     optimizer = tf.train.MomentumOptimizer(
-        _create_learning_rate(config.learning_rate, global_summaries),
+        learning_rate,
         momentum=config.momentum_optimizer_value)

   if optimizer_type == 'adam_optimizer':
     config = optimizer_config.adam_optimizer
-    optimizer = tf.train.AdamOptimizer(
-        _create_learning_rate(config.learning_rate, global_summaries))
+    learning_rate = _create_learning_rate(config.learning_rate)
+    summary_vars.append(learning_rate)
+    optimizer = tf.train.AdamOptimizer(learning_rate)

   if optimizer is None:
     raise ValueError('Optimizer %s not supported.' % optimizer_type)
@@ -61,15 +66,14 @@ def build(optimizer_config, global_summaries):
     optimizer = tf.contrib.opt.MovingAverageOptimizer(
         optimizer, average_decay=optimizer_config.moving_average_decay)

-  return optimizer
+  return optimizer, summary_vars

-def _create_learning_rate(learning_rate_config, global_summaries):
+def _create_learning_rate(learning_rate_config):
   """Create optimizer learning rate based on config.

   Args:
     learning_rate_config: A LearningRate proto message.
-    global_summaries: A set to attach learning rate summary to.

   Returns:
     A learning rate.
@@ -81,7 +85,7 @@ def _create_learning_rate(learning_rate_config, global_summaries):
   learning_rate_type = learning_rate_config.WhichOneof('learning_rate')
   if learning_rate_type == 'constant_learning_rate':
     config = learning_rate_config.constant_learning_rate
-    learning_rate = config.learning_rate
+    learning_rate = tf.constant(config.learning_rate, dtype=tf.float32)

   if learning_rate_type == 'exponential_decay_learning_rate':
     config = learning_rate_config.exponential_decay_learning_rate
@@ -115,5 +119,4 @@ def _create_learning_rate(learning_rate_config, global_summaries):
   if learning_rate is None:
     raise ValueError('Learning_rate %s not supported.' % learning_rate_type)

-  global_summaries.add(tf.summary.scalar('Learning_Rate', learning_rate))
   return learning_rate
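Note: the builder no longer mutates a global_summaries set; it returns the variables worth summarizing and lets the caller attach summaries. A sketch of the new call-site pattern; the proto content is a minimal assumed config, and the summary loop mirrors how a trainer would typically consume summary_vars:

import tensorflow as tf
from google.protobuf import text_format
from object_detection.builders import optimizer_builder
from object_detection.protos import optimizer_pb2

optimizer_text_proto = """
  rms_prop_optimizer: {
    learning_rate: {
      constant_learning_rate {
        learning_rate: 0.004
      }
    }
  }
  use_moving_average: false
"""
optimizer_proto = optimizer_pb2.Optimizer()
text_format.Merge(optimizer_text_proto, optimizer_proto)

# build() now returns (optimizer, summary_vars) instead of writing into a set.
optimizer, summary_vars = optimizer_builder.build(optimizer_proto)
for var in summary_vars:
  tf.summary.scalar(var.op.name, var)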
@@ -31,12 +31,13 @@ class LearningRateBuilderTest(tf.test.TestCase):
         learning_rate: 0.004
       }
     """
-    global_summaries = set([])
     learning_rate_proto = optimizer_pb2.LearningRate()
     text_format.Merge(learning_rate_text_proto, learning_rate_proto)
     learning_rate = optimizer_builder._create_learning_rate(
-        learning_rate_proto, global_summaries)
-    self.assertAlmostEqual(learning_rate, 0.004)
+        learning_rate_proto)
+    with self.test_session():
+      learning_rate_out = learning_rate.eval()
+    self.assertAlmostEqual(learning_rate_out, 0.004)

   def testBuildExponentialDecayLearningRate(self):
     learning_rate_text_proto = """
@@ -47,11 +48,10 @@ class LearningRateBuilderTest(tf.test.TestCase):
         staircase: false
       }
     """
-    global_summaries = set([])
     learning_rate_proto = optimizer_pb2.LearningRate()
     text_format.Merge(learning_rate_text_proto, learning_rate_proto)
     learning_rate = optimizer_builder._create_learning_rate(
-        learning_rate_proto, global_summaries)
+        learning_rate_proto)
     self.assertTrue(isinstance(learning_rate, tf.Tensor))

   def testBuildManualStepLearningRate(self):
@@ -67,11 +67,10 @@ class LearningRateBuilderTest(tf.test.TestCase):
         }
       }
     """
-    global_summaries = set([])
     learning_rate_proto = optimizer_pb2.LearningRate()
     text_format.Merge(learning_rate_text_proto, learning_rate_proto)
     learning_rate = optimizer_builder._create_learning_rate(
-        learning_rate_proto, global_summaries)
+        learning_rate_proto)
     self.assertTrue(isinstance(learning_rate, tf.Tensor))

   def testBuildCosineDecayLearningRate(self):
@@ -83,22 +82,19 @@ class LearningRateBuilderTest(tf.test.TestCase):
         warmup_steps: 1000
       }
     """
-    global_summaries = set([])
     learning_rate_proto = optimizer_pb2.LearningRate()
     text_format.Merge(learning_rate_text_proto, learning_rate_proto)
     learning_rate = optimizer_builder._create_learning_rate(
-        learning_rate_proto, global_summaries)
+        learning_rate_proto)
     self.assertTrue(isinstance(learning_rate, tf.Tensor))

   def testRaiseErrorOnEmptyLearningRate(self):
     learning_rate_text_proto = """
     """
-    global_summaries = set([])
     learning_rate_proto = optimizer_pb2.LearningRate()
     text_format.Merge(learning_rate_text_proto, learning_rate_proto)
     with self.assertRaises(ValueError):
-      optimizer_builder._create_learning_rate(
-          learning_rate_proto, global_summaries)
+      optimizer_builder._create_learning_rate(learning_rate_proto)

 class OptimizerBuilderTest(tf.test.TestCase):
@@ -119,10 +115,9 @@ class OptimizerBuilderTest(tf.test.TestCase):
       }
       use_moving_average: false
     """
-    global_summaries = set([])
     optimizer_proto = optimizer_pb2.Optimizer()
     text_format.Merge(optimizer_text_proto, optimizer_proto)
-    optimizer = optimizer_builder.build(optimizer_proto, global_summaries)
+    optimizer, _ = optimizer_builder.build(optimizer_proto)
     self.assertTrue(isinstance(optimizer, tf.train.RMSPropOptimizer))

   def testBuildMomentumOptimizer(self):
@@ -137,10 +132,9 @@ class OptimizerBuilderTest(tf.test.TestCase):
       }
      use_moving_average: false
     """
-    global_summaries = set([])
     optimizer_proto = optimizer_pb2.Optimizer()
     text_format.Merge(optimizer_text_proto, optimizer_proto)
-    optimizer = optimizer_builder.build(optimizer_proto, global_summaries)
+    optimizer, _ = optimizer_builder.build(optimizer_proto)
     self.assertTrue(isinstance(optimizer, tf.train.MomentumOptimizer))

   def testBuildAdamOptimizer(self):
@@ -154,10 +148,9 @@ class OptimizerBuilderTest(tf.test.TestCase):
       }
       use_moving_average: false
     """
-    global_summaries = set([])
     optimizer_proto = optimizer_pb2.Optimizer()
     text_format.Merge(optimizer_text_proto, optimizer_proto)
-    optimizer = optimizer_builder.build(optimizer_proto, global_summaries)
+    optimizer, _ = optimizer_builder.build(optimizer_proto)
     self.assertTrue(isinstance(optimizer, tf.train.AdamOptimizer))

   def testBuildMovingAverageOptimizer(self):
@@ -171,10 +164,9 @@ class OptimizerBuilderTest(tf.test.TestCase):
       }
       use_moving_average: True
     """
-    global_summaries = set([])
     optimizer_proto = optimizer_pb2.Optimizer()
     text_format.Merge(optimizer_text_proto, optimizer_proto)
-    optimizer = optimizer_builder.build(optimizer_proto, global_summaries)
+    optimizer, _ = optimizer_builder.build(optimizer_proto)
     self.assertTrue(
         isinstance(optimizer, tf.contrib.opt.MovingAverageOptimizer))
@@ -190,23 +182,21 @@ class OptimizerBuilderTest(tf.test.TestCase):
       use_moving_average: True
       moving_average_decay: 0.2
     """
-    global_summaries = set([])
     optimizer_proto = optimizer_pb2.Optimizer()
     text_format.Merge(optimizer_text_proto, optimizer_proto)
-    optimizer = optimizer_builder.build(optimizer_proto, global_summaries)
+    optimizer, _ = optimizer_builder.build(optimizer_proto)
     self.assertTrue(
         isinstance(optimizer, tf.contrib.opt.MovingAverageOptimizer))
-    # TODO(rathodv): Find a way to not depend on the private members.
+    # TODO: Find a way to not depend on the private members.
     self.assertAlmostEqual(optimizer._ema._decay, 0.2)

   def testBuildEmptyOptimizer(self):
     optimizer_text_proto = """
     """
-    global_summaries = set([])
     optimizer_proto = optimizer_pb2.Optimizer()
     text_format.Merge(optimizer_text_proto, optimizer_proto)
     with self.assertRaises(ValueError):
-      optimizer_builder.build(optimizer_proto, global_summaries)
+      optimizer_builder.build(optimizer_proto)

 if __name__ == '__main__':
......
@@ -83,6 +83,7 @@ PREPROCESSING_FUNCTION_MAP = {
     'random_jitter_boxes': preprocessor.random_jitter_boxes,
     'random_crop_to_aspect_ratio': preprocessor.random_crop_to_aspect_ratio,
     'random_black_patches': preprocessor.random_black_patches,
+    'rgb_to_gray': preprocessor.rgb_to_gray,
     'scale_boxes_to_pixel_coordinates': (
         preprocessor.scale_boxes_to_pixel_coordinates),
     'subtract_channel_mean': preprocessor.subtract_channel_mean,
......
@@ -379,6 +379,16 @@ class PreprocessorBuilderTest(tf.test.TestCase):
          'new_width': 100,
          'method': tf.image.ResizeMethod.BICUBIC})

+  def test_build_rgb_to_gray(self):
+    preprocessor_text_proto = """
+    rgb_to_gray {}
+    """
+    preprocessor_proto = preprocessor_pb2.PreprocessingStep()
+    text_format.Merge(preprocessor_text_proto, preprocessor_proto)
+    function, args = preprocessor_builder.build(preprocessor_proto)
+    self.assertEqual(function, preprocessor.rgb_to_gray)
+    self.assertEqual(args, {})

   def test_build_subtract_channel_mean(self):
     preprocessor_text_proto = """
       subtract_channel_mean {
......
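Note: with 'rgb_to_gray' registered in PREPROCESSING_FUNCTION_MAP, the step can be listed like any other data-augmentation option. A minimal direct-call sketch; the image tensor is a hypothetical stand-in, and the empty-args mapping asserted in the test above suggests the function takes only the image:

import tensorflow as tf
from object_detection.core import preprocessor

# Hypothetical rank-3 float image, [height, width, 3].
image = tf.random_uniform([4, 4, 3], maxval=255, dtype=tf.float32)
gray = preprocessor.rgb_to_gray(image)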
@@ -53,7 +53,7 @@ py_library(
     deps = [
         ":box_list",
         "//tensorflow",
-        "//tensorflow_models/object_detection/utils:shape_utils",
+        "//tensorflow/models/research/object_detection/utils:shape_utils",
     ],
 )
@@ -113,7 +113,7 @@ py_library(
         ":box_list",
         ":box_list_ops",
         "//tensorflow",
-        "//tensorflow_models/object_detection/utils:ops",
+        "//tensorflow/models/research/object_detection/utils:ops",
     ],
 )
@@ -123,6 +123,7 @@ py_library(
         "matcher.py",
     ],
     deps = [
+        "//tensorflow/models/research/object_detection/utils:ops",
     ],
 )
@@ -160,8 +161,17 @@ py_library(
         ":box_list",
         ":box_list_ops",
         ":keypoint_ops",
+        ":preprocessor_cache",
         ":standard_fields",
         "//tensorflow",
+        "//tensorflow/models/research/object_detection/utils:shape_utils",
+    ],
+)
+
+py_library(
+    name = "preprocessor_cache",
+    srcs = [
+        "preprocessor_cache.py",
     ],
 )
@@ -172,6 +182,7 @@ py_test(
     ],
     deps = [
         ":preprocessor",
+        ":preprocessor_cache",
         "//tensorflow",
     ],
 )
@@ -211,6 +222,7 @@ py_library(
         ":box_list_ops",
         ":standard_fields",
         "//tensorflow",
+        "//tensorflow/models/research/object_detection/utils:shape_utils",
     ],
 )
@@ -232,15 +244,16 @@ py_library(
     ],
     deps = [
         ":box_list",
+        ":box_list_ops",
         ":matcher",
         ":region_similarity_calculator",
+        ":standard_fields",
         "//tensorflow",
-        "//tensorflow_models/object_detection/box_coders:faster_rcnn_box_coder",
-        "//tensorflow_models/object_detection/box_coders:mean_stddev_box_coder",
-        "//tensorflow_models/object_detection/core:box_coder",
-        "//tensorflow_models/object_detection/matchers:argmax_matcher",
-        "//tensorflow_models/object_detection/matchers:bipartite_matcher",
+        "//tensorflow/models/research/object_detection/box_coders:faster_rcnn_box_coder",
+        "//tensorflow/models/research/object_detection/box_coders:mean_stddev_box_coder",
+        "//tensorflow/models/research/object_detection/core:box_coder",
+        "//tensorflow/models/research/object_detection/matchers:argmax_matcher",
+        "//tensorflow/models/research/object_detection/matchers:bipartite_matcher",
+        "//tensorflow/models/research/object_detection/utils:shape_utils",
     ],
 )
@@ -254,8 +267,10 @@ py_test(
         ":region_similarity_calculator",
         ":target_assigner",
         "//tensorflow",
-        "//tensorflow_models/object_detection/box_coders:mean_stddev_box_coder",
-        "//tensorflow_models/object_detection/matchers:bipartite_matcher",
+        "//tensorflow/models/research/object_detection/box_coders:keypoint_box_coder",
+        "//tensorflow/models/research/object_detection/box_coders:mean_stddev_box_coder",
+        "//tensorflow/models/research/object_detection/matchers:bipartite_matcher",
+        "//tensorflow/models/research/object_detection/utils:test_case",
     ],
 )
@@ -274,9 +289,9 @@ py_library(
     srcs = ["box_predictor.py"],
     deps = [
         "//tensorflow",
-        "//tensorflow_models/object_detection/utils:ops",
-        "//tensorflow_models/object_detection/utils:shape_utils",
-        "//tensorflow_models/object_detection/utils:static_shape",
+        "//tensorflow/models/research/object_detection/utils:ops",
+        "//tensorflow/models/research/object_detection/utils:shape_utils",
+        "//tensorflow/models/research/object_detection/utils:static_shape",
     ],
 )
@@ -286,8 +301,9 @@ py_test(
     deps = [
         ":box_predictor",
         "//tensorflow",
-        "//tensorflow_models/object_detection/builders:hyperparams_builder",
-        "//tensorflow_models/object_detection/protos:hyperparams_py_pb2",
+        "//tensorflow/models/research/object_detection/builders:hyperparams_builder",
+        "//tensorflow/models/research/object_detection/protos:hyperparams_py_pb2",
+        "//tensorflow/models/research/object_detection/utils:test_case",
     ],
 )
@@ -298,7 +314,7 @@ py_library(
     ],
     deps = [
         "//tensorflow",
-        "//tensorflow_models/object_detection/core:box_list_ops",
+        "//tensorflow/models/research/object_detection/core:box_list_ops",
     ],
 )
@@ -309,7 +325,7 @@ py_test(
     ],
     deps = [
         ":region_similarity_calculator",
-        "//tensorflow_models/object_detection/core:box_list",
+        "//tensorflow/models/research/object_detection/core:box_list",
     ],
 )
@@ -330,7 +346,7 @@ py_library(
     ],
     deps = [
         "//tensorflow",
-        "//tensorflow_models/object_detection/utils:ops",
+        "//tensorflow/models/research/object_detection/utils:ops",
     ],
 )
......
@@ -77,8 +77,8 @@ class AnchorGenerator(object):
   def generate(self, feature_map_shape_list, **params):
     """Generates a collection of bounding boxes to be used as anchors.

-    TODO: remove **params from argument list and make stride and offsets (for
-    multiple_grid_anchor_generator) constructor arguments.
+    TODO: remove **params from argument list and make stride and
+    offsets (for multiple_grid_anchor_generator) constructor arguments.

     Args:
       feature_map_shape_list: list of (height, width) pairs in the format
@@ -140,3 +140,4 @@ class AnchorGenerator(object):
                 * feature_map_shape[0]
                 * feature_map_shape[1])
     return tf.assert_equal(expected_num_anchors, anchors.num_boxes())
@@ -183,7 +183,8 @@ def prune_completely_outside_window(boxlist, window, scope=None):
     scope: name scope.

   Returns:
-    pruned_corners: a tensor with shape [M_out, 4] where M_out <= M_in
+    pruned_boxlist: a new BoxList with all bounding boxes partially or fully in
+      the window.
     valid_indices: a tensor with shape [M_out] indexing the valid bounding boxes
       in the input tensor.
   """
@@ -982,3 +983,79 @@ def pad_or_clip_box_list(boxlist, num_boxes, scope=None):
           boxlist.get_field(field), num_boxes)
       subboxlist.add_field(field, subfield)
   return subboxlist
+
+
+def select_random_box(boxlist,
+                      default_box=None,
+                      seed=None,
+                      scope=None):
+  """Selects a random bounding box from a `BoxList`.
+
+  Args:
+    boxlist: A BoxList.
+    default_box: A [1, 4] float32 tensor. If no boxes are present in `boxlist`,
+      this default box will be returned. If None, will use a default box of
+      [[-1., -1., -1., -1.]].
+    seed: Random seed.
+    scope: Name scope.
+
+  Returns:
+    bbox: A [1, 4] tensor with a random bounding box.
+    valid: A bool tensor indicating whether a valid bounding box is returned
+      (True) or whether the default box is returned (False).
+  """
+  with tf.name_scope(scope, 'SelectRandomBox'):
+    bboxes = boxlist.get()
+    combined_shape = shape_utils.combined_static_and_dynamic_shape(bboxes)
+    number_of_boxes = combined_shape[0]
+    default_box = default_box or tf.constant([[-1., -1., -1., -1.]])
+
+    def select_box():
+      random_index = tf.random_uniform([],
+                                       maxval=number_of_boxes,
+                                       dtype=tf.int32,
+                                       seed=seed)
+      return tf.expand_dims(bboxes[random_index], axis=0), tf.constant(True)
+
+    return tf.cond(
+        tf.greater_equal(number_of_boxes, 1),
+        true_fn=select_box,
+        false_fn=lambda: (default_box, tf.constant(False)))
+
+
+def get_minimal_coverage_box(boxlist,
+                             default_box=None,
+                             scope=None):
+  """Creates a single bounding box which covers all boxes in the boxlist.
+
+  Args:
+    boxlist: A Boxlist.
+    default_box: A [1, 4] float32 tensor. If no boxes are present in `boxlist`,
+      this default box will be returned. If None, will use a default box of
+      [[0., 0., 1., 1.]].
+    scope: Name scope.
+
+  Returns:
+    A [1, 4] float32 tensor with a bounding box that tightly covers all the
+    boxes in the box list. If the boxlist does not contain any boxes, the
+    default box is returned.
+  """
+  with tf.name_scope(scope, 'CreateCoverageBox'):
+    num_boxes = boxlist.num_boxes()
+
+    def coverage_box(bboxes):
+      y_min, x_min, y_max, x_max = tf.split(
+          value=bboxes, num_or_size_splits=4, axis=1)
+      y_min_coverage = tf.reduce_min(y_min, axis=0)
+      x_min_coverage = tf.reduce_min(x_min, axis=0)
+      y_max_coverage = tf.reduce_max(y_max, axis=0)
+      x_max_coverage = tf.reduce_max(x_max, axis=0)
+      return tf.stack(
+          [y_min_coverage, x_min_coverage, y_max_coverage, x_max_coverage],
+          axis=1)
+
+    default_box = default_box or tf.constant([[0., 0., 1., 1.]])
+    return tf.cond(
+        tf.greater_equal(num_boxes, 1),
+        true_fn=lambda: coverage_box(boxlist.get()),
+        false_fn=lambda: default_box)
@@ -153,6 +153,25 @@ class BoxListOpsTest(tf.test.TestCase):
       extra_data_out = sess.run(pruned.get_field('extra_data'))
       self.assertAllEqual(extra_data_out, [[1], [2], [3], [4], [6]])

+  def test_prune_completely_outside_window_with_empty_boxlist(self):
+    window = tf.constant([0, 0, 9, 14], tf.float32)
+    corners = tf.zeros(shape=[0, 4], dtype=tf.float32)
+    boxes = box_list.BoxList(corners)
+    boxes.add_field('extra_data', tf.zeros(shape=[0], dtype=tf.int32))
+    pruned, keep_indices = box_list_ops.prune_completely_outside_window(boxes,
+                                                                        window)
+    pruned_boxes = pruned.get()
+    extra = pruned.get_field('extra_data')
+
+    exp_pruned_boxes = np.zeros(shape=[0, 4], dtype=np.float32)
+    exp_extra = np.zeros(shape=[0], dtype=np.int32)
+    with self.test_session() as sess:
+      pruned_boxes_out, keep_indices_out, extra_out = sess.run(
+          [pruned_boxes, keep_indices, extra])
+      self.assertAllClose(exp_pruned_boxes, pruned_boxes_out)
+      self.assertAllEqual([], keep_indices_out)
+      self.assertAllEqual(exp_extra, extra_out)

   def test_intersection(self):
     corners1 = tf.constant([[4.0, 3.0, 7.0, 5.0], [5.0, 6.0, 10.0, 7.0]])
     corners2 = tf.constant([[3.0, 4.0, 6.0, 8.0], [14.0, 14.0, 15.0, 15.0],
@@ -593,6 +612,58 @@ class BoxListOpsTest(tf.test.TestCase):
       self.assertAllEqual(expected_classes, classes_out)
       self.assertAllClose(expected_scores, scores_out)

+  def test_select_random_box(self):
+    boxes = [[0., 0., 1., 1.],
+             [0., 1., 2., 3.],
+             [0., 2., 3., 4.]]
+    corners = tf.constant(boxes, dtype=tf.float32)
+    boxlist = box_list.BoxList(corners)
+    random_bbox, valid = box_list_ops.select_random_box(boxlist)
+    with self.test_session() as sess:
+      random_bbox_out, valid_out = sess.run([random_bbox, valid])
+
+    norm_small = any(
+        [np.linalg.norm(random_bbox_out - box) < 1e-6 for box in boxes])
+
+    self.assertTrue(norm_small)
+    self.assertTrue(valid_out)
+
+  def test_select_random_box_with_empty_boxlist(self):
+    corners = tf.constant([], shape=[0, 4], dtype=tf.float32)
+    boxlist = box_list.BoxList(corners)
+    random_bbox, valid = box_list_ops.select_random_box(boxlist)
+    with self.test_session() as sess:
+      random_bbox_out, valid_out = sess.run([random_bbox, valid])
+
+    expected_bbox_out = np.array([[-1., -1., -1., -1.]], dtype=np.float32)
+    self.assertAllEqual(expected_bbox_out, random_bbox_out)
+    self.assertFalse(valid_out)
+
+  def test_get_minimal_coverage_box(self):
+    boxes = [[0., 0., 1., 1.],
+             [-1., 1., 2., 3.],
+             [0., 2., 3., 4.]]
+    expected_coverage_box = [[-1., 0., 3., 4.]]
+    corners = tf.constant(boxes, dtype=tf.float32)
+    boxlist = box_list.BoxList(corners)
+    coverage_box = box_list_ops.get_minimal_coverage_box(boxlist)
+    with self.test_session() as sess:
+      coverage_box_out = sess.run(coverage_box)
+    self.assertAllClose(expected_coverage_box, coverage_box_out)
+
+  def test_get_minimal_coverage_box_with_empty_boxlist(self):
+    corners = tf.constant([], shape=[0, 4], dtype=tf.float32)
+    boxlist = box_list.BoxList(corners)
+    coverage_box = box_list_ops.get_minimal_coverage_box(boxlist)
+    with self.test_session() as sess:
+      coverage_box_out = sess.run(coverage_box)
+    self.assertAllClose([[0.0, 0.0, 1.0, 1.0]], coverage_box_out)

 class ConcatenateTest(tf.test.TestCase):
@@ -958,5 +1029,6 @@ class BoxRefinementTest(tf.test.TestCase):
     self.assertAllClose(expected_scores, scores_out)
     self.assertAllEqual(extra_field_out, [0, 1, 1])

 if __name__ == '__main__':
   tf.test.main()
...@@ -50,8 +50,10 @@ class Loss(object): ...@@ -50,8 +50,10 @@ class Loss(object):
"""Call the loss function. """Call the loss function.
Args: Args:
prediction_tensor: a tensor representing predicted quantities. prediction_tensor: an N-d tensor of shape [batch, anchors, ...]
target_tensor: a tensor representing regression or classification targets. representing predicted quantities.
target_tensor: an N-d tensor of shape [batch, anchors, ...] representing
regression or classification targets.
ignore_nan_targets: whether to ignore nan targets in the loss computation. ignore_nan_targets: whether to ignore nan targets in the loss computation.
E.g. can be used if the target tensor is missing groundtruth data that E.g. can be used if the target tensor is missing groundtruth data that
shouldn't be factored into the loss. shouldn't be factored into the loss.
...@@ -81,7 +83,8 @@ class Loss(object): ...@@ -81,7 +83,8 @@ class Loss(object):
the Loss. the Loss.
Returns: Returns:
loss: a tensor representing the value of the loss function loss: an N-d tensor of shape [batch, anchors, ...] containing the loss per
anchor
""" """
pass pass
...@@ -92,15 +95,6 @@ class WeightedL2LocalizationLoss(Loss): ...@@ -92,15 +95,6 @@ class WeightedL2LocalizationLoss(Loss):
Loss[b,a] = .5 * ||weights[b,a] * (prediction[b,a,:] - target[b,a,:])||^2 Loss[b,a] = .5 * ||weights[b,a] * (prediction[b,a,:] - target[b,a,:])||^2
""" """
def __init__(self, anchorwise_output=False):
"""Constructor.
Args:
anchorwise_output: Outputs loss per anchor. (default False)
"""
self._anchorwise_output = anchorwise_output
def _compute_loss(self, prediction_tensor, target_tensor, weights): def _compute_loss(self, prediction_tensor, target_tensor, weights):
"""Compute loss function. """Compute loss function.
...@@ -112,15 +106,13 @@ class WeightedL2LocalizationLoss(Loss): ...@@ -112,15 +106,13 @@ class WeightedL2LocalizationLoss(Loss):
weights: a float tensor of shape [batch_size, num_anchors] weights: a float tensor of shape [batch_size, num_anchors]
Returns: Returns:
loss: a (scalar) tensor representing the value of the loss function loss: a float tensor of shape [batch_size, num_anchors] tensor
or a float tensor of shape [batch_size, num_anchors] representing the value of the loss function.
""" """
weighted_diff = (prediction_tensor - target_tensor) * tf.expand_dims( weighted_diff = (prediction_tensor - target_tensor) * tf.expand_dims(
weights, 2) weights, 2)
square_diff = 0.5 * tf.square(weighted_diff) square_diff = 0.5 * tf.square(weighted_diff)
if self._anchorwise_output:
return tf.reduce_sum(square_diff, 2) return tf.reduce_sum(square_diff, 2)
return tf.reduce_sum(square_diff)
class WeightedSmoothL1LocalizationLoss(Loss):

@@ -132,15 +124,6 @@ class WeightedSmoothL1LocalizationLoss(Loss):
  See also Equation (3) in the Fast R-CNN paper by Ross Girshick (ICCV 2015)
  """

-  def __init__(self, anchorwise_output=False):
-    """Constructor.
-
-    Args:
-      anchorwise_output: Outputs loss per anchor. (default False)
-    """
-    self._anchorwise_output = anchorwise_output
-
  def _compute_loss(self, prediction_tensor, target_tensor, weights):
    """Compute loss function.

@@ -152,7 +135,8 @@ class WeightedSmoothL1LocalizationLoss(Loss):
      weights: a float tensor of shape [batch_size, num_anchors]

    Returns:
-      loss: a (scalar) tensor representing the value of the loss function
+      loss: a float tensor of shape [batch_size, num_anchors] representing
+        the value of the loss function.
    """
    diff = prediction_tensor - target_tensor
    abs_diff = tf.abs(diff)

@@ -160,9 +144,7 @@ class WeightedSmoothL1LocalizationLoss(Loss):
    anchorwise_smooth_l1norm = tf.reduce_sum(
        tf.where(abs_diff_lt_1, 0.5 * tf.square(abs_diff), abs_diff - 0.5),
        2) * weights
-    if self._anchorwise_output:
-      return anchorwise_smooth_l1norm
-    return tf.reduce_sum(anchorwise_smooth_l1norm)
+    return anchorwise_smooth_l1norm
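For reference, the per-coordinate term selected by tf.where above is the smooth L1 (Huber-style) function; a plain-Python sketch:

def smooth_l1(x):
  # 0.5 * x**2 in the quadratic zone |x| < 1, linear (|x| - 0.5) outside;
  # the two pieces meet with matching value and slope at |x| == 1.
  return 0.5 * x * x if abs(x) < 1.0 else abs(x) - 0.5

# smooth_l1(0.5) == 0.125, smooth_l1(2.0) == 1.5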
class WeightedIOULocalizationLoss(Loss):

@@ -184,27 +166,19 @@ class WeightedIOULocalizationLoss(Loss):
      weights: a float tensor of shape [batch_size, num_anchors]

    Returns:
-      loss: a (scalar) tensor representing the value of the loss function
+      loss: a float tensor of shape [batch_size, num_anchors] representing
+        the value of the loss function.
    """
    predicted_boxes = box_list.BoxList(tf.reshape(prediction_tensor, [-1, 4]))
    target_boxes = box_list.BoxList(tf.reshape(target_tensor, [-1, 4]))
    per_anchor_iou_loss = 1.0 - box_list_ops.matched_iou(predicted_boxes,
                                                         target_boxes)
-    return tf.reduce_sum(tf.reshape(weights, [-1]) * per_anchor_iou_loss)
+    return tf.reshape(weights, [-1]) * per_anchor_iou_loss
class WeightedSigmoidClassificationLoss(Loss):
  """Sigmoid cross entropy classification loss function."""

-  def __init__(self, anchorwise_output=False):
-    """Constructor.
-
-    Args:
-      anchorwise_output: Outputs loss per anchor. (default False)
-    """
-    self._anchorwise_output = anchorwise_output
-
  def _compute_loss(self,
                    prediction_tensor,
                    target_tensor,

@@ -222,8 +196,8 @@ class WeightedSigmoidClassificationLoss(Loss):
        If provided, computes loss only for the specified class indices.

    Returns:
-      loss: a (scalar) tensor representing the value of the loss function
-        or a float tensor of shape [batch_size, num_anchors]
+      loss: a float tensor of shape [batch_size, num_anchors, num_classes]
+        representing the value of the loss function.
    """
    weights = tf.expand_dims(weights, 2)
    if class_indices is not None:

@@ -233,9 +207,7 @@ class WeightedSigmoidClassificationLoss(Loss):
          [1, 1, -1])
    per_entry_cross_ent = (tf.nn.sigmoid_cross_entropy_with_logits(
        labels=target_tensor, logits=prediction_tensor))
-    if self._anchorwise_output:
-      return tf.reduce_sum(per_entry_cross_ent * weights, 2)
-    return tf.reduce_sum(per_entry_cross_ent * weights)
+    return per_entry_cross_ent * weights
class SigmoidFocalClassificationLoss(Loss):

@@ -245,15 +217,13 @@ class SigmoidFocalClassificationLoss(Loss):
  examples. See https://arxiv.org/pdf/1708.02002.pdf for the loss definition.
  """

-  def __init__(self, anchorwise_output=False, gamma=2.0, alpha=0.25):
+  def __init__(self, gamma=2.0, alpha=0.25):
    """Constructor.

    Args:
-      anchorwise_output: Outputs loss per anchor. (default False)
      gamma: exponent of the modulating factor (1 - p_t) ^ gamma.
      alpha: optional alpha weighting factor to balance positives vs negatives.
    """
-    self._anchorwise_output = anchorwise_output
    self._alpha = alpha
    self._gamma = gamma

@@ -274,8 +244,8 @@ class SigmoidFocalClassificationLoss(Loss):
        If provided, computes loss only for the specified class indices.

    Returns:
-      loss: a (scalar) tensor representing the value of the loss function
-        or a float tensor of shape [batch_size, num_anchors]
+      loss: a float tensor of shape [batch_size, num_anchors, num_classes]
+        representing the value of the loss function.
    """
    weights = tf.expand_dims(weights, 2)
    if class_indices is not None:

@@ -297,25 +267,21 @@ class SigmoidFocalClassificationLoss(Loss):
                           (1 - target_tensor) * (1 - self._alpha))
    focal_cross_entropy_loss = (modulating_factor * alpha_weight_factor *
                                per_entry_cross_ent)
-    if self._anchorwise_output:
-      return tf.reduce_sum(focal_cross_entropy_loss * weights, 2)
-    return tf.reduce_sum(focal_cross_entropy_loss * weights)
+    return focal_cross_entropy_loss * weights
class WeightedSoftmaxClassificationLoss(Loss):
  """Softmax loss function."""

-  def __init__(self, anchorwise_output=False, logit_scale=1.0):
+  def __init__(self, logit_scale=1.0):
    """Constructor.

    Args:
-      anchorwise_output: Whether to output loss per anchor (default False)
      logit_scale: When this value is high, the prediction is "diffused" and
        when this value is low, the prediction is made peakier.
        (default 1.0)
    """
-    self._anchorwise_output = anchorwise_output
    self._logit_scale = logit_scale

  def _compute_loss(self, prediction_tensor, target_tensor, weights):

@@ -329,7 +295,8 @@ class WeightedSoftmaxClassificationLoss(Loss):
      weights: a float tensor of shape [batch_size, num_anchors]

    Returns:
-      loss: a (scalar) tensor representing the value of the loss function
+      loss: a float tensor of shape [batch_size, num_anchors]
+        representing the value of the loss function.
    """
    num_classes = prediction_tensor.get_shape().as_list()[-1]
    prediction_tensor = tf.divide(

@@ -337,9 +304,7 @@ class WeightedSoftmaxClassificationLoss(Loss):
    per_row_cross_ent = (tf.nn.softmax_cross_entropy_with_logits(
        labels=tf.reshape(target_tensor, [-1, num_classes]),
        logits=tf.reshape(prediction_tensor, [-1, num_classes])))
-    if self._anchorwise_output:
-      return tf.reshape(per_row_cross_ent, tf.shape(weights)) * weights
-    return tf.reduce_sum(per_row_cross_ent * tf.reshape(weights, [-1]))
+    return tf.reshape(per_row_cross_ent, tf.shape(weights)) * weights
class BootstrappedSigmoidClassificationLoss(Loss):

@@ -359,14 +324,13 @@ class BootstrappedSigmoidClassificationLoss(Loss):
  Reed et al. (ICLR 2015).
  """

-  def __init__(self, alpha, bootstrap_type='soft', anchorwise_output=False):
+  def __init__(self, alpha, bootstrap_type='soft'):
    """Constructor.

    Args:
      alpha: a float32 scalar tensor between 0 and 1 representing interpolation
        weight
      bootstrap_type: set to either 'hard' or 'soft' (default)
-      anchorwise_output: Outputs loss per anchor. (default False)

    Raises:
      ValueError: if bootstrap_type is not either 'hard' or 'soft'

@@ -376,7 +340,6 @@ class BootstrappedSigmoidClassificationLoss(Loss):
                       '\'hard\' or \'soft.\'')
    self._alpha = alpha
    self._bootstrap_type = bootstrap_type
-    self._anchorwise_output = anchorwise_output

  def _compute_loss(self, prediction_tensor, target_tensor, weights):
    """Compute loss function.

@@ -389,8 +352,8 @@ class BootstrappedSigmoidClassificationLoss(Loss):
      weights: a float tensor of shape [batch_size, num_anchors]

    Returns:
-      loss: a (scalar) tensor representing the value of the loss function
-        or a float tensor of shape [batch_size, num_anchors]
+      loss: a float tensor of shape [batch_size, num_anchors, num_classes]
+        representing the value of the loss function.
    """
    if self._bootstrap_type == 'soft':
      bootstrap_target_tensor = self._alpha * target_tensor + (

@@ -401,9 +364,7 @@ class BootstrappedSigmoidClassificationLoss(Loss):
        tf.sigmoid(prediction_tensor) > 0.5, tf.float32)
    per_entry_cross_ent = (tf.nn.sigmoid_cross_entropy_with_logits(
        labels=bootstrap_target_tensor, logits=prediction_tensor))
-    if self._anchorwise_output:
-      return tf.reduce_sum(per_entry_cross_ent * tf.expand_dims(weights, 2), 2)
-    return tf.reduce_sum(per_entry_cross_ent * tf.expand_dims(weights, 2))
+    return per_entry_cross_ent * tf.expand_dims(weights, 2)
class HardExampleMiner(object):
......
@@ -36,6 +36,8 @@ from abc import abstractmethod
import tensorflow as tf

+from object_detection.utils import ops
+

class Match(object):
  """Class to store results from the matcher.

@@ -44,7 +46,7 @@ class Match(object):
  convenient methods to query the matching results.
  """

-  def __init__(self, match_results):
+  def __init__(self, match_results, use_matmul_gather=False):
    """Constructs a Match object.

    Args:

@@ -52,6 +54,8 @@ class Match(object):
        meaning that column i is matched with row match_results[i].
        (2) match_results[i]=-1, meaning that column i is not matched.
        (3) match_results[i]=-2, meaning that column i is ignored.
+      use_matmul_gather: Use matrix multiplication based gather instead of
+        standard tf.gather. (Default: False).

    Raises:
      ValueError: if match_results does not have rank 1 or is not an

@@ -63,6 +67,9 @@ class Match(object):
      raise ValueError('match_results should be an int32 or int64 scalar '
                       'tensor')
    self._match_results = match_results
+    self._gather_op = tf.gather
+    if use_matmul_gather:
+      self._gather_op = ops.matmul_gather_on_zeroth_axis
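The matmul-based gather exists because some accelerator backends handle a dense matmul better than tf.gather. The core idea, in a simplified sketch for a 2-D params tensor (this is not the library implementation, just an assumed illustration of the technique): build a one-hot selector matrix and multiply.

import tensorflow as tf

def matmul_gather_2d(params, indices):
  # params: [N, d] float tensor; indices: [K] int32 tensor.
  # one_hot is [K, N]; the matmul selects one row of params per index,
  # giving a [K, d] result equal to tf.gather(params, indices).
  one_hot = tf.one_hot(indices, depth=tf.shape(params)[0], dtype=params.dtype)
  return tf.matmul(one_hot, params)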
  @property
  def match_results(self):

@@ -163,17 +170,55 @@ class Match(object):
      row_indices: int32 tensor of shape [K] with row indices.
    """
    return self._reshape_and_cast(
-        tf.gather(self._match_results, self.matched_column_indices()))
+        self._gather_op(self._match_results, self.matched_column_indices()))

  def _reshape_and_cast(self, t):
    return tf.cast(tf.reshape(t, [-1]), tf.int32)

+  def gather_based_on_match(self, input_tensor, unmatched_value,
+                            ignored_value):
+    """Gathers elements from `input_tensor` based on match results.
+
+    For columns that are matched to a row, gathered_tensor[col] is set to
+    input_tensor[match_results[col]]. For columns that are unmatched,
+    gathered_tensor[col] is set to unmatched_value. Finally, for columns that
+    are ignored gathered_tensor[col] is set to ignored_value.
+
+    Note that the input_tensor.shape[1:] must match with unmatched_value.shape
+    and ignored_value.shape.
+
+    Args:
+      input_tensor: Tensor to gather values from.
+      unmatched_value: Constant tensor value for unmatched columns.
+      ignored_value: Constant tensor value for ignored columns.
+
+    Returns:
+      gathered_tensor: A tensor containing values gathered from input_tensor.
+        The shape of the gathered tensor is [match_results.shape[0]] +
+        input_tensor.shape[1:].
+    """
+    input_tensor = tf.concat([tf.stack([ignored_value, unmatched_value]),
+                              input_tensor], axis=0)
+    gather_indices = tf.maximum(self.match_results + 2, 0)
+    gathered_tensor = self._gather_op(input_tensor, gather_indices)
+    return gathered_tensor
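A usage sketch for gather_based_on_match (the match vector and values below are assumed for illustration): matched columns receive the matched row's value, while the -1 and -2 codes map onto the unmatched and ignored fill values via the maximum(match_results + 2, 0) index shift.

import tensorflow as tf

match = Match(tf.constant([1, -1, 0, -2]))
row_values = tf.constant([10.0, 20.0])
gathered = match.gather_based_on_match(row_values,
                                       unmatched_value=tf.constant(0.0),
                                       ignored_value=tf.constant(0.0))
# After the +2 shift the gather indices are [3, 1, 2, 0], so
# gathered evaluates to [20.0, 0.0, 10.0, 0.0].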
class Matcher(object):
  """Abstract base class for matcher.
  """
  __metaclass__ = ABCMeta

+  def __init__(self, use_matmul_gather=False):
+    """Constructs a Matcher.
+
+    Args:
+      use_matmul_gather: Force constructed match objects to use matrix
+        multiplication based gather instead of standard tf.gather.
+        (Default: False).
+    """
+    self._use_matmul_gather = use_matmul_gather
+
  def match(self, similarity_matrix, scope=None, **params):
    """Computes matches among row and column indices and returns the result.

@@ -191,11 +236,12 @@ class Matcher(object):
      A Match object with the results of matching.
    """
    with tf.name_scope(scope, 'Match', [similarity_matrix, params]) as scope:
-      return Match(self._match(similarity_matrix, **params))
+      return Match(self._match(similarity_matrix, **params),
+                   self._use_matmul_gather)

  @abstractmethod
  def _match(self, similarity_matrix, **params):
-    """Method to be overriden by implementations.
+    """Method to be overridden by implementations.

    Args:
      similarity_matrix: Float tensor of shape [N, M] with pairwise similarity
......