"vscode:/vscode.git/clone" did not exist on "e5352297d97b96101a7bd6944de420ed17ae62d3"
Unverified commit fd7b6887 authored by Jonathan Huang, committed by GitHub

Merge pull request #3293 from pkulzc/master

Internal changes of object_detection 
parents f98ec55e 1efe98bb
......@@ -116,18 +116,17 @@ def build_faster_rcnn_classification_loss(loss_config):
loss_type = loss_config.WhichOneof('classification_loss')
if loss_type == 'weighted_sigmoid':
config = loss_config.weighted_sigmoid
return losses.WeightedSigmoidClassificationLoss(
anchorwise_output=config.anchorwise_output)
return losses.WeightedSigmoidClassificationLoss()
if loss_type == 'weighted_softmax':
config = loss_config.weighted_softmax
return losses.WeightedSoftmaxClassificationLoss(
anchorwise_output=config.anchorwise_output)
logit_scale=config.logit_scale)
# By default, Faster RCNN second stage classifier uses Softmax loss
# with anchor-wise outputs.
config = loss_config.weighted_softmax
return losses.WeightedSoftmaxClassificationLoss(
anchorwise_output=True)
logit_scale=config.logit_scale)
def _build_localization_loss(loss_config):
......@@ -148,14 +147,10 @@ def _build_localization_loss(loss_config):
loss_type = loss_config.WhichOneof('localization_loss')
if loss_type == 'weighted_l2':
config = loss_config.weighted_l2
return losses.WeightedL2LocalizationLoss(
anchorwise_output=config.anchorwise_output)
return losses.WeightedL2LocalizationLoss()
if loss_type == 'weighted_smooth_l1':
config = loss_config.weighted_smooth_l1
return losses.WeightedSmoothL1LocalizationLoss(
anchorwise_output=config.anchorwise_output)
return losses.WeightedSmoothL1LocalizationLoss()
if loss_type == 'weighted_iou':
return losses.WeightedIOULocalizationLoss()
......@@ -181,9 +176,7 @@ def _build_classification_loss(loss_config):
loss_type = loss_config.WhichOneof('classification_loss')
if loss_type == 'weighted_sigmoid':
config = loss_config.weighted_sigmoid
return losses.WeightedSigmoidClassificationLoss(
anchorwise_output=config.anchorwise_output)
return losses.WeightedSigmoidClassificationLoss()
if loss_type == 'weighted_sigmoid_focal':
config = loss_config.weighted_sigmoid_focal
......@@ -191,21 +184,18 @@ def _build_classification_loss(loss_config):
if config.HasField('alpha'):
alpha = config.alpha
return losses.SigmoidFocalClassificationLoss(
anchorwise_output=config.anchorwise_output,
gamma=config.gamma,
alpha=alpha)
if loss_type == 'weighted_softmax':
config = loss_config.weighted_softmax
return losses.WeightedSoftmaxClassificationLoss(
anchorwise_output=config.anchorwise_output,
logit_scale=config.logit_scale)
if loss_type == 'bootstrapped_sigmoid':
config = loss_config.bootstrapped_sigmoid
return losses.BootstrappedSigmoidClassificationLoss(
alpha=config.alpha,
bootstrap_type=('hard' if config.hard_bootstrap else 'soft'),
anchorwise_output=config.anchorwise_output)
bootstrap_type=('hard' if config.hard_bootstrap else 'soft'))
raise ValueError('Empty loss config.')
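With anchorwise_output gone, every loss object returns per-anchor values and the builder only forwards the remaining fields (logit_scale, gamma, alpha, and so on). A minimal usage sketch, assuming losses_pb2 exposes a ClassificationLoss message matching the oneof names above:

from google.protobuf import text_format
from object_detection.builders import losses_builder
from object_detection.protos import losses_pb2

loss_text_proto = """
  weighted_softmax {
    logit_scale: 2.0
  }
"""
loss_proto = losses_pb2.ClassificationLoss()
text_format.Merge(loss_text_proto, loss_proto)
# The builder now reads logit_scale (not anchorwise_output) and returns a
# WeightedSoftmaxClassificationLoss whose output is per-anchor.
classification_loss = losses_builder._build_classification_loss(loss_proto)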
......@@ -80,7 +80,6 @@ class LocalizationLossBuilderTest(tf.test.TestCase):
losses_text_proto = """
localization_loss {
weighted_smooth_l1 {
anchorwise_output: true
}
}
classification_loss {
......@@ -245,7 +244,7 @@ class ClassificationLossBuilderTest(tf.test.TestCase):
targets = tf.constant([[[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]])
weights = tf.constant([[1.0, 1.0]])
loss = classification_loss(predictions, targets, weights=weights)
self.assertEqual(loss.shape, [1, 2])
self.assertEqual(loss.shape, [1, 2, 3])
def test_raise_error_on_empty_config(self):
losses_text_proto = """
......
......@@ -45,7 +45,9 @@ def build(matcher_config):
matched_threshold=matched_threshold,
unmatched_threshold=unmatched_threshold,
negatives_lower_than_unmatched=matcher.negatives_lower_than_unmatched,
force_match_for_each_row=matcher.force_match_for_each_row)
force_match_for_each_row=matcher.force_match_for_each_row,
use_matmul_gather=matcher.use_matmul_gather)
if matcher_config.WhichOneof('matcher_oneof') == 'bipartite_matcher':
return bipartite_matcher.GreedyBipartiteMatcher()
matcher = matcher_config.bipartite_matcher
return bipartite_matcher.GreedyBipartiteMatcher(matcher.use_matmul_gather)
raise ValueError('Empty matcher.')
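use_matmul_gather lets both matcher types swap dynamic gather ops for matmul-based ones (the field name comes from the diff; the accelerator-compatibility motivation is an assumption). A sketch:

from google.protobuf import text_format
from object_detection.builders import matcher_builder
from object_detection.protos import matcher_pb2

matcher_text_proto = """
  bipartite_matcher {
    use_matmul_gather: true
  }
"""
matcher_proto = matcher_pb2.Matcher()
text_format.Merge(matcher_text_proto, matcher_proto)
# build() now forwards use_matmul_gather into GreedyBipartiteMatcher.
matcher = matcher_builder.build(matcher_proto)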
......@@ -62,6 +62,7 @@ class MatcherBuilderTest(tf.test.TestCase):
unmatched_threshold: 0.3
negatives_lower_than_unmatched: false
force_match_for_each_row: true
use_matmul_gather: true
}
"""
matcher_proto = matcher_pb2.Matcher()
......@@ -72,6 +73,7 @@ class MatcherBuilderTest(tf.test.TestCase):
self.assertAlmostEqual(matcher_object._unmatched_threshold, 0.3)
self.assertFalse(matcher_object._negatives_lower_than_unmatched)
self.assertTrue(matcher_object._force_match_for_each_row)
self.assertTrue(matcher_object._use_matmul_gather)
def test_build_bipartite_matcher(self):
matcher_text_proto = """
......
......@@ -31,6 +31,7 @@ from object_detection.models import faster_rcnn_inception_resnet_v2_feature_extractor as frcnn_inc_res
from object_detection.models import faster_rcnn_inception_v2_feature_extractor as frcnn_inc_v2
from object_detection.models import faster_rcnn_nas_feature_extractor as frcnn_nas
from object_detection.models import faster_rcnn_resnet_v1_feature_extractor as frcnn_resnet_v1
from object_detection.models import ssd_resnet_v1_fpn_feature_extractor as ssd_resnet_v1_fpn
from object_detection.models.embedded_ssd_mobilenet_v1_feature_extractor import EmbeddedSSDMobileNetV1FeatureExtractor
from object_detection.models.ssd_inception_v2_feature_extractor import SSDInceptionV2FeatureExtractor
from object_detection.models.ssd_inception_v3_feature_extractor import SSDInceptionV3FeatureExtractor
......@@ -42,6 +43,9 @@ SSD_FEATURE_EXTRACTOR_CLASS_MAP = {
'ssd_inception_v2': SSDInceptionV2FeatureExtractor,
'ssd_inception_v3': SSDInceptionV3FeatureExtractor,
'ssd_mobilenet_v1': SSDMobileNetV1FeatureExtractor,
'ssd_resnet50_v1_fpn': ssd_resnet_v1_fpn.SSDResnet50V1FpnFeatureExtractor,
'ssd_resnet101_v1_fpn': ssd_resnet_v1_fpn.SSDResnet101V1FpnFeatureExtractor,
'ssd_resnet152_v1_fpn': ssd_resnet_v1_fpn.SSDResnet152V1FpnFeatureExtractor,
'embedded_ssd_mobilenet_v1': EmbeddedSSDMobileNetV1FeatureExtractor,
}
......@@ -62,13 +66,14 @@ FASTER_RCNN_FEATURE_EXTRACTOR_CLASS_MAP = {
}
def build(model_config, is_training):
def build(model_config, is_training, add_summaries=True):
"""Builds a DetectionModel based on the model config.
Args:
model_config: A model.proto object containing the config for the desired
DetectionModel.
is_training: True if this model is being built for training purposes.
add_summaries: Whether to add tensorflow summaries in the model graph.
Returns:
DetectionModel based on the config.
......@@ -80,9 +85,10 @@ def build(model_config, is_training):
raise ValueError('model_config not of type model_pb2.DetectionModel.')
meta_architecture = model_config.WhichOneof('model')
if meta_architecture == 'ssd':
return _build_ssd_model(model_config.ssd, is_training)
return _build_ssd_model(model_config.ssd, is_training, add_summaries)
if meta_architecture == 'faster_rcnn':
return _build_faster_rcnn_model(model_config.faster_rcnn, is_training)
return _build_faster_rcnn_model(model_config.faster_rcnn, is_training,
add_summaries)
raise ValueError('Unknown meta architecture: {}'.format(meta_architecture))
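Callers can now suppress graph summaries at build time (the default preserves the old behaviour). A hedged sketch; model_text_proto stands in for any of the detection configs exercised in the tests below:

from google.protobuf import text_format
from object_detection.builders import model_builder
from object_detection.protos import model_pb2

model_proto = model_pb2.DetectionModel()
text_format.Merge(model_text_proto, model_proto)  # model_text_proto: see tests below
# e.g. an eval-only graph that should stay free of training summaries:
model = model_builder.build(model_proto, is_training=False, add_summaries=False)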
......@@ -106,6 +112,8 @@ def _build_ssd_feature_extractor(feature_extractor_config, is_training,
min_depth = feature_extractor_config.min_depth
pad_to_multiple = feature_extractor_config.pad_to_multiple
batch_norm_trainable = feature_extractor_config.batch_norm_trainable
use_explicit_padding = feature_extractor_config.use_explicit_padding
use_depthwise = feature_extractor_config.use_depthwise
conv_hyperparams = hyperparams_builder.build(
feature_extractor_config.conv_hyperparams, is_training)
......@@ -115,16 +123,18 @@ def _build_ssd_feature_extractor(feature_extractor_config, is_training,
feature_extractor_class = SSD_FEATURE_EXTRACTOR_CLASS_MAP[feature_type]
return feature_extractor_class(is_training, depth_multiplier, min_depth,
pad_to_multiple, conv_hyperparams,
batch_norm_trainable, reuse_weights)
batch_norm_trainable, reuse_weights,
use_explicit_padding, use_depthwise)
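The two new fields are plumbed straight from the SSD feature-extractor config into the extractor constructor. A sketch of the corresponding config block in the test-proto style (field names from the code above; the inline glosses are assumptions, not documented semantics):

feature_extractor_text_proto = """
  type: 'ssd_mobilenet_v1'
  use_explicit_padding: true  # assumed: pad explicitly rather than rely on SAME padding
  use_depthwise: true         # assumed: prefer depthwise-separable convolutions
  conv_hyperparams {
    regularizer { l2_regularizer { } }
    initializer { truncated_normal_initializer { } }
  }
"""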
def _build_ssd_model(ssd_config, is_training):
def _build_ssd_model(ssd_config, is_training, add_summaries):
"""Builds an SSD detection model based on the model config.
Args:
ssd_config: A ssd.proto object containing the config for the desired
SSDMetaArch.
is_training: True if this model is being built for training purposes.
add_summaries: Whether to add tf summaries in the model.
Returns:
SSDMetaArch based on the config.
......@@ -171,7 +181,8 @@ def _build_ssd_model(ssd_config, is_training):
classification_weight,
localization_weight,
normalize_loss_by_num_matches,
hard_example_miner)
hard_example_miner,
add_summaries=add_summaries)
def _build_faster_rcnn_feature_extractor(
......@@ -205,7 +216,7 @@ def _build_faster_rcnn_feature_extractor(
batch_norm_trainable, reuse_weights)
def _build_faster_rcnn_model(frcnn_config, is_training):
def _build_faster_rcnn_model(frcnn_config, is_training, add_summaries):
"""Builds a Faster R-CNN or R-FCN detection model based on the model config.
Builds R-FCN model if the second_stage_box_predictor in the config is of type
......@@ -213,8 +224,9 @@ def _build_faster_rcnn_model(frcnn_config, is_training):
Args:
frcnn_config: A faster_rcnn.proto object containing the config for the
desired FasterRCNNMetaArch or RFCNMetaArch.
is_training: True if this model is being built for training purposes.
add_summaries: Whether to add tf summaries in the model.
Returns:
FasterRCNNMetaArch based on the config.
......@@ -228,7 +240,7 @@ def _build_faster_rcnn_model(frcnn_config, is_training):
feature_extractor = _build_faster_rcnn_feature_extractor(
frcnn_config.feature_extractor, is_training)
first_stage_only = frcnn_config.first_stage_only
number_of_stages = frcnn_config.number_of_stages
first_stage_anchor_generator = anchor_generator_builder.build(
frcnn_config.first_stage_anchor_generator)
......@@ -283,7 +295,7 @@ def _build_faster_rcnn_model(frcnn_config, is_training):
'num_classes': num_classes,
'image_resizer_fn': image_resizer_fn,
'feature_extractor': feature_extractor,
'first_stage_only': first_stage_only,
'number_of_stages': number_of_stages,
'first_stage_anchor_generator': first_stage_anchor_generator,
'first_stage_atrous_rate': first_stage_atrous_rate,
'first_stage_box_predictor_arg_scope':
......@@ -310,7 +322,8 @@ def _build_faster_rcnn_model(frcnn_config, is_training):
second_stage_classification_loss,
'second_stage_classification_loss_weight':
second_stage_classification_loss_weight,
'hard_example_miner': hard_example_miner}
'hard_example_miner': hard_example_miner,
'add_summaries': add_summaries}
if isinstance(second_stage_box_predictor, box_predictor.RfcnBoxPredictor):
return rfcn_meta_arch.RFCNMetaArch(
......
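The boolean first_stage_only field gives way to an integer number_of_stages; the mapping below is inferred from the rename, not stated in the diff:

# Config migration (assumed semantics):
#   first_stage_only: true   ->  number_of_stages: 1   (RPN proposals only)
#   first_stage_only: false  ->  number_of_stages: 2   (full two-stage model)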
......@@ -26,12 +26,14 @@ from object_detection.models import faster_rcnn_inception_resnet_v2_feature_extractor as frcnn_inc_res
from object_detection.models import faster_rcnn_inception_v2_feature_extractor as frcnn_inc_v2
from object_detection.models import faster_rcnn_nas_feature_extractor as frcnn_nas
from object_detection.models import faster_rcnn_resnet_v1_feature_extractor as frcnn_resnet_v1
from object_detection.models import ssd_resnet_v1_fpn_feature_extractor as ssd_resnet_v1_fpn
from object_detection.models.embedded_ssd_mobilenet_v1_feature_extractor import EmbeddedSSDMobileNetV1FeatureExtractor
from object_detection.models.ssd_inception_v2_feature_extractor import SSDInceptionV2FeatureExtractor
from object_detection.models.ssd_inception_v3_feature_extractor import SSDInceptionV3FeatureExtractor
from object_detection.models.ssd_mobilenet_v1_feature_extractor import SSDMobileNetV1FeatureExtractor
from object_detection.protos import model_pb2
FEATURE_EXTRACTOR_MAPS = {
FRCNN_RESNET_FEAT_MAPS = {
'faster_rcnn_resnet50':
frcnn_resnet_v1.FasterRCNNResnet50FeatureExtractor,
'faster_rcnn_resnet101':
......@@ -40,6 +42,15 @@ FEATURE_EXTRACTOR_MAPS = {
frcnn_resnet_v1.FasterRCNNResnet152FeatureExtractor
}
SSD_RESNET_V1_FPN_FEAT_MAPS = {
'ssd_resnet50_v1_fpn':
ssd_resnet_v1_fpn.SSDResnet50V1FpnFeatureExtractor,
'ssd_resnet101_v1_fpn':
ssd_resnet_v1_fpn.SSDResnet101V1FpnFeatureExtractor,
'ssd_resnet152_v1_fpn':
ssd_resnet_v1_fpn.SSDResnet152V1FpnFeatureExtractor
}
class ModelBuilderTest(tf.test.TestCase):
......@@ -197,6 +208,87 @@ class ModelBuilderTest(tf.test.TestCase):
self.assertIsInstance(model._feature_extractor,
SSDInceptionV3FeatureExtractor)
def test_create_ssd_resnet_v1_fpn_model_from_config(self):
model_text_proto = """
ssd {
feature_extractor {
type: 'ssd_resnet50_v1_fpn'
conv_hyperparams {
regularizer {
l2_regularizer {
}
}
initializer {
truncated_normal_initializer {
}
}
}
batch_norm_trainable: true
}
box_coder {
faster_rcnn_box_coder {
}
}
matcher {
argmax_matcher {
}
}
similarity_calculator {
iou_similarity {
}
}
anchor_generator {
multiscale_anchor_generator {
aspect_ratios: [1.0, 2.0, 0.5]
scales_per_octave: 2
}
}
image_resizer {
fixed_shape_resizer {
height: 320
width: 320
}
}
box_predictor {
weight_shared_convolutional_box_predictor {
depth: 32
conv_hyperparams {
regularizer {
l2_regularizer {
}
}
initializer {
truncated_normal_initializer {
}
}
}
num_layers_before_predictor: 1
}
}
loss {
classification_loss {
weighted_sigmoid_focal {
alpha: 0.25
gamma: 2.0
}
}
localization_loss {
weighted_smooth_l1 {
}
}
classification_weight: 1.0
localization_weight: 1.0
}
}"""
model_proto = model_pb2.DetectionModel()
text_format.Merge(model_text_proto, model_proto)
for extractor_type, extractor_class in SSD_RESNET_V1_FPN_FEAT_MAPS.items():
model_proto.ssd.feature_extractor.type = extractor_type
model = model_builder.build(model_proto, is_training=True)
self.assertIsInstance(model, ssd_meta_arch.SSDMetaArch)
self.assertIsInstance(model._feature_extractor, extractor_class)
def test_create_ssd_mobilenet_v1_model_from_config(self):
model_text_proto = """
ssd {
......@@ -270,6 +362,78 @@ class ModelBuilderTest(tf.test.TestCase):
SSDMobileNetV1FeatureExtractor)
self.assertTrue(model._feature_extractor._batch_norm_trainable)
def test_create_embedded_ssd_mobilenet_v1_model_from_config(self):
model_text_proto = """
ssd {
feature_extractor {
type: 'embedded_ssd_mobilenet_v1'
conv_hyperparams {
regularizer {
l2_regularizer {
}
}
initializer {
truncated_normal_initializer {
}
}
}
batch_norm_trainable: true
}
box_coder {
faster_rcnn_box_coder {
}
}
matcher {
argmax_matcher {
}
}
similarity_calculator {
iou_similarity {
}
}
anchor_generator {
ssd_anchor_generator {
aspect_ratios: 1.0
}
}
image_resizer {
fixed_shape_resizer {
height: 256
width: 256
}
}
box_predictor {
convolutional_box_predictor {
conv_hyperparams {
regularizer {
l2_regularizer {
}
}
initializer {
truncated_normal_initializer {
}
}
}
}
}
loss {
classification_loss {
weighted_softmax {
}
}
localization_loss {
weighted_smooth_l1 {
}
}
}
}"""
model_proto = model_pb2.DetectionModel()
text_format.Merge(model_text_proto, model_proto)
model = self.create_model(model_proto)
self.assertIsInstance(model, ssd_meta_arch.SSDMetaArch)
self.assertIsInstance(model._feature_extractor,
EmbeddedSSDMobileNetV1FeatureExtractor)
def test_create_faster_rcnn_resnet_v1_models_from_config(self):
model_text_proto = """
faster_rcnn {
......@@ -331,7 +495,7 @@ class ModelBuilderTest(tf.test.TestCase):
}"""
model_proto = model_pb2.DetectionModel()
text_format.Merge(model_text_proto, model_proto)
for extractor_type, extractor_class in FEATURE_EXTRACTOR_MAPS.items():
for extractor_type, extractor_class in FRCNN_RESNET_FEAT_MAPS.items():
model_proto.faster_rcnn.feature_extractor.type = extractor_type
model = model_builder.build(model_proto, is_training=True)
self.assertIsInstance(model, faster_rcnn_meta_arch.FasterRCNNMetaArch)
......@@ -730,7 +894,7 @@ class ModelBuilderTest(tf.test.TestCase):
}"""
model_proto = model_pb2.DetectionModel()
text_format.Merge(model_text_proto, model_proto)
for extractor_type, extractor_class in FEATURE_EXTRACTOR_MAPS.items():
for extractor_type, extractor_class in FRCNN_RESNET_FEAT_MAPS.items():
model_proto.faster_rcnn.feature_extractor.type = extractor_type
model = model_builder.build(model_proto, is_training=True)
self.assertIsInstance(model, rfcn_meta_arch.RFCNMetaArch)
......
......@@ -19,15 +19,14 @@ import tensorflow as tf
from object_detection.utils import learning_schedules
def build(optimizer_config, global_summaries):
def build(optimizer_config):
"""Create optimizer based on config.
Args:
optimizer_config: A Optimizer proto message.
global_summaries: A set to attach learning rate summary to.
Returns:
An optimizer.
An optimizer and a list of variables for summary.
Raises:
ValueError: when using an unsupported input data type.
......@@ -35,24 +34,30 @@ def build(optimizer_config, global_summaries):
optimizer_type = optimizer_config.WhichOneof('optimizer')
optimizer = None
summary_vars = []
if optimizer_type == 'rms_prop_optimizer':
config = optimizer_config.rms_prop_optimizer
learning_rate = _create_learning_rate(config.learning_rate)
summary_vars.append(learning_rate)
optimizer = tf.train.RMSPropOptimizer(
_create_learning_rate(config.learning_rate, global_summaries),
learning_rate,
decay=config.decay,
momentum=config.momentum_optimizer_value,
epsilon=config.epsilon)
if optimizer_type == 'momentum_optimizer':
config = optimizer_config.momentum_optimizer
learning_rate = _create_learning_rate(config.learning_rate)
summary_vars.append(learning_rate)
optimizer = tf.train.MomentumOptimizer(
_create_learning_rate(config.learning_rate, global_summaries),
learning_rate,
momentum=config.momentum_optimizer_value)
if optimizer_type == 'adam_optimizer':
config = optimizer_config.adam_optimizer
optimizer = tf.train.AdamOptimizer(
_create_learning_rate(config.learning_rate, global_summaries))
learning_rate = _create_learning_rate(config.learning_rate)
summary_vars.append(learning_rate)
optimizer = tf.train.AdamOptimizer(learning_rate)
if optimizer is None:
raise ValueError('Optimizer %s not supported.' % optimizer_type)
......@@ -61,15 +66,14 @@ def build(optimizer_config, global_summaries):
optimizer = tf.contrib.opt.MovingAverageOptimizer(
optimizer, average_decay=optimizer_config.moving_average_decay)
return optimizer
return optimizer, summary_vars
def _create_learning_rate(learning_rate_config, global_summaries):
def _create_learning_rate(learning_rate_config):
"""Create optimizer learning rate based on config.
Args:
learning_rate_config: A LearningRate proto message.
global_summaries: A set to attach learning rate summary to.
Returns:
A learning rate.
......@@ -81,7 +85,7 @@ def _create_learning_rate(learning_rate_config, global_summaries):
learning_rate_type = learning_rate_config.WhichOneof('learning_rate')
if learning_rate_type == 'constant_learning_rate':
config = learning_rate_config.constant_learning_rate
learning_rate = config.learning_rate
learning_rate = tf.constant(config.learning_rate, dtype=tf.float32)
if learning_rate_type == 'exponential_decay_learning_rate':
config = learning_rate_config.exponential_decay_learning_rate
......@@ -115,5 +119,4 @@ def _create_learning_rate(learning_rate_config, global_summaries):
if learning_rate is None:
raise ValueError('Learning_rate %s not supported.' % learning_rate_type)
global_summaries.add(tf.summary.scalar('Learning_Rate', learning_rate))
return learning_rate
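build() now returns an (optimizer, summary_vars) pair and no longer writes the learning-rate summary itself; attaching summaries becomes the caller's responsibility. A minimal sketch of the new convention:

import tensorflow as tf
from google.protobuf import text_format
from object_detection.builders import optimizer_builder
from object_detection.protos import optimizer_pb2

optimizer_proto = optimizer_pb2.Optimizer()
text_format.Merge("""
  momentum_optimizer {
    learning_rate { constant_learning_rate { learning_rate: 0.004 } }
    momentum_optimizer_value: 0.9
  }
  use_moving_average: false
""", optimizer_proto)

optimizer, summary_vars = optimizer_builder.build(optimizer_proto)
for var in summary_vars:  # e.g. the learning-rate tensor
  tf.summary.scalar(var.op.name, var)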
......@@ -31,12 +31,13 @@ class LearningRateBuilderTest(tf.test.TestCase):
learning_rate: 0.004
}
"""
global_summaries = set([])
learning_rate_proto = optimizer_pb2.LearningRate()
text_format.Merge(learning_rate_text_proto, learning_rate_proto)
learning_rate = optimizer_builder._create_learning_rate(
learning_rate_proto, global_summaries)
self.assertAlmostEqual(learning_rate, 0.004)
learning_rate_proto)
with self.test_session():
learning_rate_out = learning_rate.eval()
self.assertAlmostEqual(learning_rate_out, 0.004)
def testBuildExponentialDecayLearningRate(self):
learning_rate_text_proto = """
......@@ -47,11 +48,10 @@ class LearningRateBuilderTest(tf.test.TestCase):
staircase: false
}
"""
global_summaries = set([])
learning_rate_proto = optimizer_pb2.LearningRate()
text_format.Merge(learning_rate_text_proto, learning_rate_proto)
learning_rate = optimizer_builder._create_learning_rate(
learning_rate_proto, global_summaries)
learning_rate_proto)
self.assertTrue(isinstance(learning_rate, tf.Tensor))
def testBuildManualStepLearningRate(self):
......@@ -67,11 +67,10 @@ class LearningRateBuilderTest(tf.test.TestCase):
}
}
"""
global_summaries = set([])
learning_rate_proto = optimizer_pb2.LearningRate()
text_format.Merge(learning_rate_text_proto, learning_rate_proto)
learning_rate = optimizer_builder._create_learning_rate(
learning_rate_proto, global_summaries)
learning_rate_proto)
self.assertTrue(isinstance(learning_rate, tf.Tensor))
def testBuildCosineDecayLearningRate(self):
......@@ -83,22 +82,19 @@ class LearningRateBuilderTest(tf.test.TestCase):
warmup_steps: 1000
}
"""
global_summaries = set([])
learning_rate_proto = optimizer_pb2.LearningRate()
text_format.Merge(learning_rate_text_proto, learning_rate_proto)
learning_rate = optimizer_builder._create_learning_rate(
learning_rate_proto, global_summaries)
learning_rate_proto)
self.assertTrue(isinstance(learning_rate, tf.Tensor))
def testRaiseErrorOnEmptyLearningRate(self):
learning_rate_text_proto = """
"""
global_summaries = set([])
learning_rate_proto = optimizer_pb2.LearningRate()
text_format.Merge(learning_rate_text_proto, learning_rate_proto)
with self.assertRaises(ValueError):
optimizer_builder._create_learning_rate(
learning_rate_proto, global_summaries)
optimizer_builder._create_learning_rate(learning_rate_proto)
class OptimizerBuilderTest(tf.test.TestCase):
......@@ -119,10 +115,9 @@ class OptimizerBuilderTest(tf.test.TestCase):
}
use_moving_average: false
"""
global_summaries = set([])
optimizer_proto = optimizer_pb2.Optimizer()
text_format.Merge(optimizer_text_proto, optimizer_proto)
optimizer = optimizer_builder.build(optimizer_proto, global_summaries)
optimizer, _ = optimizer_builder.build(optimizer_proto)
self.assertTrue(isinstance(optimizer, tf.train.RMSPropOptimizer))
def testBuildMomentumOptimizer(self):
......@@ -137,10 +132,9 @@ class OptimizerBuilderTest(tf.test.TestCase):
}
use_moving_average: false
"""
global_summaries = set([])
optimizer_proto = optimizer_pb2.Optimizer()
text_format.Merge(optimizer_text_proto, optimizer_proto)
optimizer = optimizer_builder.build(optimizer_proto, global_summaries)
optimizer, _ = optimizer_builder.build(optimizer_proto)
self.assertTrue(isinstance(optimizer, tf.train.MomentumOptimizer))
def testBuildAdamOptimizer(self):
......@@ -154,10 +148,9 @@ class OptimizerBuilderTest(tf.test.TestCase):
}
use_moving_average: false
"""
global_summaries = set([])
optimizer_proto = optimizer_pb2.Optimizer()
text_format.Merge(optimizer_text_proto, optimizer_proto)
optimizer = optimizer_builder.build(optimizer_proto, global_summaries)
optimizer, _ = optimizer_builder.build(optimizer_proto)
self.assertTrue(isinstance(optimizer, tf.train.AdamOptimizer))
def testBuildMovingAverageOptimizer(self):
......@@ -171,10 +164,9 @@ class OptimizerBuilderTest(tf.test.TestCase):
}
use_moving_average: True
"""
global_summaries = set([])
optimizer_proto = optimizer_pb2.Optimizer()
text_format.Merge(optimizer_text_proto, optimizer_proto)
optimizer = optimizer_builder.build(optimizer_proto, global_summaries)
optimizer, _ = optimizer_builder.build(optimizer_proto)
self.assertTrue(
isinstance(optimizer, tf.contrib.opt.MovingAverageOptimizer))
......@@ -190,23 +182,21 @@ class OptimizerBuilderTest(tf.test.TestCase):
use_moving_average: True
moving_average_decay: 0.2
"""
global_summaries = set([])
optimizer_proto = optimizer_pb2.Optimizer()
text_format.Merge(optimizer_text_proto, optimizer_proto)
optimizer = optimizer_builder.build(optimizer_proto, global_summaries)
optimizer, _ = optimizer_builder.build(optimizer_proto)
self.assertTrue(
isinstance(optimizer, tf.contrib.opt.MovingAverageOptimizer))
# TODO(rathodv): Find a way to not depend on the private members.
# TODO: Find a way to not depend on the private members.
self.assertAlmostEqual(optimizer._ema._decay, 0.2)
def testBuildEmptyOptimizer(self):
optimizer_text_proto = """
"""
global_summaries = set([])
optimizer_proto = optimizer_pb2.Optimizer()
text_format.Merge(optimizer_text_proto, optimizer_proto)
with self.assertRaises(ValueError):
optimizer_builder.build(optimizer_proto, global_summaries)
optimizer_builder.build(optimizer_proto)
if __name__ == '__main__':
......
......@@ -83,6 +83,7 @@ PREPROCESSING_FUNCTION_MAP = {
'random_jitter_boxes': preprocessor.random_jitter_boxes,
'random_crop_to_aspect_ratio': preprocessor.random_crop_to_aspect_ratio,
'random_black_patches': preprocessor.random_black_patches,
'rgb_to_gray': preprocessor.rgb_to_gray,
'scale_boxes_to_pixel_coordinates': (
preprocessor.scale_boxes_to_pixel_coordinates),
'subtract_channel_mean': preprocessor.subtract_channel_mean,
......
......@@ -379,6 +379,16 @@ class PreprocessorBuilderTest(tf.test.TestCase):
'new_width': 100,
'method': tf.image.ResizeMethod.BICUBIC})
def test_build_rgb_to_gray(self):
preprocessor_text_proto = """
rgb_to_gray {}
"""
preprocessor_proto = preprocessor_pb2.PreprocessingStep()
text_format.Merge(preprocessor_text_proto, preprocessor_proto)
function, args = preprocessor_builder.build(preprocessor_proto)
self.assertEqual(function, preprocessor.rgb_to_gray)
self.assertEqual(args, {})
def test_build_subtract_channel_mean(self):
preprocessor_text_proto = """
subtract_channel_mean {
......
......@@ -53,7 +53,7 @@ py_library(
deps = [
":box_list",
"//tensorflow",
"//tensorflow_models/object_detection/utils:shape_utils",
"//tensorflow/models/research/object_detection/utils:shape_utils",
],
)
......@@ -113,7 +113,7 @@ py_library(
":box_list",
":box_list_ops",
"//tensorflow",
"//tensorflow_models/object_detection/utils:ops",
"//tensorflow/models/research/object_detection/utils:ops",
],
)
......@@ -123,6 +123,7 @@ py_library(
"matcher.py",
],
deps = [
"//tensorflow/models/research/object_detection/utils:ops",
],
)
......@@ -160,8 +161,17 @@ py_library(
":box_list",
":box_list_ops",
":keypoint_ops",
":preprocessor_cache",
":standard_fields",
"//tensorflow",
"//tensorflow/models/research/object_detection/utils:shape_utils",
],
)
py_library(
name = "preprocessor_cache",
srcs = [
"preprocessor_cache.py",
],
)
......@@ -172,6 +182,7 @@ py_test(
],
deps = [
":preprocessor",
":preprocessor_cache",
"//tensorflow",
],
)
......@@ -211,6 +222,7 @@ py_library(
":box_list_ops",
":standard_fields",
"//tensorflow",
"//tensorflow/models/research/object_detection/utils:shape_utils",
],
)
......@@ -232,15 +244,16 @@ py_library(
],
deps = [
":box_list",
":box_list_ops",
":matcher",
":region_similarity_calculator",
":standard_fields",
"//tensorflow",
"//tensorflow_models/object_detection/box_coders:faster_rcnn_box_coder",
"//tensorflow_models/object_detection/box_coders:mean_stddev_box_coder",
"//tensorflow_models/object_detection/core:box_coder",
"//tensorflow_models/object_detection/matchers:argmax_matcher",
"//tensorflow_models/object_detection/matchers:bipartite_matcher",
"//tensorflow/models/research/object_detection/box_coders:faster_rcnn_box_coder",
"//tensorflow/models/research/object_detection/box_coders:mean_stddev_box_coder",
"//tensorflow/models/research/object_detection/core:box_coder",
"//tensorflow/models/research/object_detection/matchers:argmax_matcher",
"//tensorflow/models/research/object_detection/matchers:bipartite_matcher",
"//tensorflow/models/research/object_detection/utils:shape_utils",
],
)
......@@ -254,8 +267,10 @@ py_test(
":region_similarity_calculator",
":target_assigner",
"//tensorflow",
"//tensorflow_models/object_detection/box_coders:mean_stddev_box_coder",
"//tensorflow_models/object_detection/matchers:bipartite_matcher",
"//tensorflow/models/research/object_detection/box_coders:keypoint_box_coder",
"//tensorflow/models/research/object_detection/box_coders:mean_stddev_box_coder",
"//tensorflow/models/research/object_detection/matchers:bipartite_matcher",
"//tensorflow/models/research/object_detection/utils:test_case",
],
)
......@@ -274,9 +289,9 @@ py_library(
srcs = ["box_predictor.py"],
deps = [
"//tensorflow",
"//tensorflow_models/object_detection/utils:ops",
"//tensorflow_models/object_detection/utils:shape_utils",
"//tensorflow_models/object_detection/utils:static_shape",
"//tensorflow/models/research/object_detection/utils:ops",
"//tensorflow/models/research/object_detection/utils:shape_utils",
"//tensorflow/models/research/object_detection/utils:static_shape",
],
)
......@@ -286,8 +301,9 @@ py_test(
deps = [
":box_predictor",
"//tensorflow",
"//tensorflow_models/object_detection/builders:hyperparams_builder",
"//tensorflow_models/object_detection/protos:hyperparams_py_pb2",
"//tensorflow/models/research/object_detection/builders:hyperparams_builder",
"//tensorflow/models/research/object_detection/protos:hyperparams_py_pb2",
"//tensorflow/models/research/object_detection/utils:test_case",
],
)
......@@ -298,7 +314,7 @@ py_library(
],
deps = [
"//tensorflow",
"//tensorflow_models/object_detection/core:box_list_ops",
"//tensorflow/models/research/object_detection/core:box_list_ops",
],
)
......@@ -309,7 +325,7 @@ py_test(
],
deps = [
":region_similarity_calculator",
"//tensorflow_models/object_detection/core:box_list",
"//tensorflow/models/research/object_detection/core:box_list",
],
)
......@@ -330,7 +346,7 @@ py_library(
],
deps = [
"//tensorflow",
"//tensorflow_models/object_detection/utils:ops",
"//tensorflow/models/research/object_detection/utils:ops",
],
)
......
......@@ -77,8 +77,8 @@ class AnchorGenerator(object):
def generate(self, feature_map_shape_list, **params):
"""Generates a collection of bounding boxes to be used as anchors.
TODO: remove **params from argument list and make stride and offsets (for
multiple_grid_anchor_generator) constructor arguments.
TODO: remove **params from argument list and make stride and
offsets (for multiple_grid_anchor_generator) constructor arguments.
Args:
feature_map_shape_list: list of (height, width) pairs in the format
......@@ -140,3 +140,4 @@ class AnchorGenerator(object):
* feature_map_shape[0]
* feature_map_shape[1])
return tf.assert_equal(expected_num_anchors, anchors.num_boxes())
......@@ -183,7 +183,8 @@ def prune_completely_outside_window(boxlist, window, scope=None):
scope: name scope.
Returns:
pruned_corners: a tensor with shape [M_out, 4] where M_out <= M_in
pruned_boxlist: a new BoxList with all bounding boxes partially or fully in
the window.
valid_indices: a tensor with shape [M_out] indexing the valid bounding boxes
in the input tensor.
"""
......@@ -982,3 +983,79 @@ def pad_or_clip_box_list(boxlist, num_boxes, scope=None):
boxlist.get_field(field), num_boxes)
subboxlist.add_field(field, subfield)
return subboxlist
def select_random_box(boxlist,
default_box=None,
seed=None,
scope=None):
"""Selects a random bounding box from a `BoxList`.
Args:
boxlist: A BoxList.
default_box: A [1, 4] float32 tensor. If no boxes are present in `boxlist`,
this default box will be returned. If None, will use a default box of
[[-1., -1., -1., -1.]].
seed: Random seed.
scope: Name scope.
Returns:
bbox: A [1, 4] tensor with a random bounding box.
valid: A bool tensor indicating whether a valid bounding box is returned
(True) or whether the default box is returned (False).
"""
with tf.name_scope(scope, 'SelectRandomBox'):
bboxes = boxlist.get()
combined_shape = shape_utils.combined_static_and_dynamic_shape(bboxes)
number_of_boxes = combined_shape[0]
default_box = default_box or tf.constant([[-1., -1., -1., -1.]])
def select_box():
random_index = tf.random_uniform([],
maxval=number_of_boxes,
dtype=tf.int32,
seed=seed)
return tf.expand_dims(bboxes[random_index], axis=0), tf.constant(True)
return tf.cond(
tf.greater_equal(number_of_boxes, 1),
true_fn=select_box,
false_fn=lambda: (default_box, tf.constant(False)))
def get_minimal_coverage_box(boxlist,
default_box=None,
scope=None):
"""Creates a single bounding box which covers all boxes in the boxlist.
Args:
boxlist: A Boxlist.
default_box: A [1, 4] float32 tensor. If no boxes are present in `boxlist`,
this default box will be returned. If None, will use a default box of
[[0., 0., 1., 1.]].
scope: Name scope.
Returns:
A [1, 4] float32 tensor with a bounding box that tightly covers all the
boxes in the box list. If the boxlist does not contain any boxes, the
default box is returned.
"""
with tf.name_scope(scope, 'CreateCoverageBox'):
num_boxes = boxlist.num_boxes()
def coverage_box(bboxes):
y_min, x_min, y_max, x_max = tf.split(
value=bboxes, num_or_size_splits=4, axis=1)
y_min_coverage = tf.reduce_min(y_min, axis=0)
x_min_coverage = tf.reduce_min(x_min, axis=0)
y_max_coverage = tf.reduce_max(y_max, axis=0)
x_max_coverage = tf.reduce_max(x_max, axis=0)
return tf.stack(
[y_min_coverage, x_min_coverage, y_max_coverage, x_max_coverage],
axis=1)
default_box = default_box or tf.constant([[0., 0., 1., 1.]])
return tf.cond(
tf.greater_equal(num_boxes, 1),
true_fn=lambda: coverage_box(boxlist.get()),
false_fn=lambda: default_box)
......@@ -153,6 +153,25 @@ class BoxListOpsTest(tf.test.TestCase):
extra_data_out = sess.run(pruned.get_field('extra_data'))
self.assertAllEqual(extra_data_out, [[1], [2], [3], [4], [6]])
def test_prune_completely_outside_window_with_empty_boxlist(self):
window = tf.constant([0, 0, 9, 14], tf.float32)
corners = tf.zeros(shape=[0, 4], dtype=tf.float32)
boxes = box_list.BoxList(corners)
boxes.add_field('extra_data', tf.zeros(shape=[0], dtype=tf.int32))
pruned, keep_indices = box_list_ops.prune_completely_outside_window(boxes,
window)
pruned_boxes = pruned.get()
extra = pruned.get_field('extra_data')
exp_pruned_boxes = np.zeros(shape=[0, 4], dtype=np.float32)
exp_extra = np.zeros(shape=[0], dtype=np.int32)
with self.test_session() as sess:
pruned_boxes_out, keep_indices_out, extra_out = sess.run(
[pruned_boxes, keep_indices, extra])
self.assertAllClose(exp_pruned_boxes, pruned_boxes_out)
self.assertAllEqual([], keep_indices_out)
self.assertAllEqual(exp_extra, extra_out)
def test_intersection(self):
corners1 = tf.constant([[4.0, 3.0, 7.0, 5.0], [5.0, 6.0, 10.0, 7.0]])
corners2 = tf.constant([[3.0, 4.0, 6.0, 8.0], [14.0, 14.0, 15.0, 15.0],
......@@ -593,6 +612,58 @@ class BoxListOpsTest(tf.test.TestCase):
self.assertAllEqual(expected_classes, classes_out)
self.assertAllClose(expected_scores, scores_out)
def test_select_random_box(self):
boxes = [[0., 0., 1., 1.],
[0., 1., 2., 3.],
[0., 2., 3., 4.]]
corners = tf.constant(boxes, dtype=tf.float32)
boxlist = box_list.BoxList(corners)
random_bbox, valid = box_list_ops.select_random_box(boxlist)
with self.test_session() as sess:
random_bbox_out, valid_out = sess.run([random_bbox, valid])
norm_small = any(
[np.linalg.norm(random_bbox_out - box) < 1e-6 for box in boxes])
self.assertTrue(norm_small)
self.assertTrue(valid_out)
def test_select_random_box_with_empty_boxlist(self):
corners = tf.constant([], shape=[0, 4], dtype=tf.float32)
boxlist = box_list.BoxList(corners)
random_bbox, valid = box_list_ops.select_random_box(boxlist)
with self.test_session() as sess:
random_bbox_out, valid_out = sess.run([random_bbox, valid])
expected_bbox_out = np.array([[-1., -1., -1., -1.]], dtype=np.float32)
self.assertAllEqual(expected_bbox_out, random_bbox_out)
self.assertFalse(valid_out)
def test_get_minimal_coverage_box(self):
boxes = [[0., 0., 1., 1.],
[-1., 1., 2., 3.],
[0., 2., 3., 4.]]
expected_coverage_box = [[-1., 0., 3., 4.]]
corners = tf.constant(boxes, dtype=tf.float32)
boxlist = box_list.BoxList(corners)
coverage_box = box_list_ops.get_minimal_coverage_box(boxlist)
with self.test_session() as sess:
coverage_box_out = sess.run(coverage_box)
self.assertAllClose(expected_coverage_box, coverage_box_out)
def test_get_minimal_coverage_box_with_empty_boxlist(self):
corners = tf.constant([], shape=[0, 4], dtype=tf.float32)
boxlist = box_list.BoxList(corners)
coverage_box = box_list_ops.get_minimal_coverage_box(boxlist)
with self.test_session() as sess:
coverage_box_out = sess.run(coverage_box)
self.assertAllClose([[0.0, 0.0, 1.0, 1.0]], coverage_box_out)
class ConcatenateTest(tf.test.TestCase):
......@@ -958,5 +1029,6 @@ class BoxRefinementTest(tf.test.TestCase):
self.assertAllClose(expected_scores, scores_out)
self.assertAllEqual(extra_field_out, [0, 1, 1])
if __name__ == '__main__':
tf.test.main()
......@@ -50,8 +50,10 @@ class Loss(object):
"""Call the loss function.
Args:
prediction_tensor: a tensor representing predicted quantities.
target_tensor: a tensor representing regression or classification targets.
prediction_tensor: an N-d tensor of shape [batch, anchors, ...]
representing predicted quantities.
target_tensor: an N-d tensor of shape [batch, anchors, ...] representing
regression or classification targets.
ignore_nan_targets: whether to ignore nan targets in the loss computation.
E.g. can be used if the target tensor is missing groundtruth data that
shouldn't be factored into the loss.
......@@ -81,7 +83,8 @@ class Loss(object):
the Loss.
Returns:
loss: a tensor representing the value of the loss function
loss: an N-d tensor of shape [batch, anchors, ...] containing the loss per
anchor
"""
pass
......@@ -92,15 +95,6 @@ class WeightedL2LocalizationLoss(Loss):
Loss[b,a] = .5 * ||weights[b,a] * (prediction[b,a,:] - target[b,a,:])||^2
"""
def __init__(self, anchorwise_output=False):
"""Constructor.
Args:
anchorwise_output: Outputs loss per anchor. (default False)
"""
self._anchorwise_output = anchorwise_output
def _compute_loss(self, prediction_tensor, target_tensor, weights):
"""Compute loss function.
......@@ -112,15 +106,13 @@ class WeightedL2LocalizationLoss(Loss):
weights: a float tensor of shape [batch_size, num_anchors]
Returns:
loss: a (scalar) tensor representing the value of the loss function
or a float tensor of shape [batch_size, num_anchors]
loss: a float tensor of shape [batch_size, num_anchors] representing
the value of the loss function.
"""
weighted_diff = (prediction_tensor - target_tensor) * tf.expand_dims(
weights, 2)
square_diff = 0.5 * tf.square(weighted_diff)
if self._anchorwise_output:
return tf.reduce_sum(square_diff, 2)
return tf.reduce_sum(square_diff)
return tf.reduce_sum(square_diff, 2)
class WeightedSmoothL1LocalizationLoss(Loss):
......@@ -132,15 +124,6 @@ class WeightedSmoothL1LocalizationLoss(Loss):
See also Equation (3) in the Fast R-CNN paper by Ross Girshick (ICCV 2015)
"""
def __init__(self, anchorwise_output=False):
"""Constructor.
Args:
anchorwise_output: Outputs loss per anchor. (default False)
"""
self._anchorwise_output = anchorwise_output
def _compute_loss(self, prediction_tensor, target_tensor, weights):
"""Compute loss function.
......@@ -152,7 +135,8 @@ class WeightedSmoothL1LocalizationLoss(Loss):
weights: a float tensor of shape [batch_size, num_anchors]
Returns:
loss: a (scalar) tensor representing the value of the loss function
loss: a float tensor of shape [batch_size, num_anchors] representing
the value of the loss function.
"""
diff = prediction_tensor - target_tensor
abs_diff = tf.abs(diff)
......@@ -160,9 +144,7 @@ class WeightedSmoothL1LocalizationLoss(Loss):
anchorwise_smooth_l1norm = tf.reduce_sum(
tf.where(abs_diff_lt_1, 0.5 * tf.square(abs_diff), abs_diff - 0.5),
2) * weights
if self._anchorwise_output:
return anchorwise_smooth_l1norm
return tf.reduce_sum(anchorwise_smooth_l1norm)
return anchorwise_smooth_l1norm
class WeightedIOULocalizationLoss(Loss):
......@@ -184,27 +166,19 @@ class WeightedIOULocalizationLoss(Loss):
weights: a float tensor of shape [batch_size, num_anchors]
Returns:
loss: a (scalar) tensor representing the value of the loss function
loss: a float tensor of shape [batch_size, num_anchors] representing
the value of the loss function.
"""
predicted_boxes = box_list.BoxList(tf.reshape(prediction_tensor, [-1, 4]))
target_boxes = box_list.BoxList(tf.reshape(target_tensor, [-1, 4]))
per_anchor_iou_loss = 1.0 - box_list_ops.matched_iou(predicted_boxes,
target_boxes)
return tf.reduce_sum(tf.reshape(weights, [-1]) * per_anchor_iou_loss)
return tf.reshape(weights, [-1]) * per_anchor_iou_loss
class WeightedSigmoidClassificationLoss(Loss):
"""Sigmoid cross entropy classification loss function."""
def __init__(self, anchorwise_output=False):
"""Constructor.
Args:
anchorwise_output: Outputs loss per anchor. (default False)
"""
self._anchorwise_output = anchorwise_output
def _compute_loss(self,
prediction_tensor,
target_tensor,
......@@ -222,8 +196,8 @@ class WeightedSigmoidClassificationLoss(Loss):
If provided, computes loss only for the specified class indices.
Returns:
loss: a (scalar) tensor representing the value of the loss function
or a float tensor of shape [batch_size, num_anchors]
loss: a float tensor of shape [batch_size, num_anchors, num_classes]
representing the value of the loss function.
"""
weights = tf.expand_dims(weights, 2)
if class_indices is not None:
......@@ -233,9 +207,7 @@ class WeightedSigmoidClassificationLoss(Loss):
[1, 1, -1])
per_entry_cross_ent = (tf.nn.sigmoid_cross_entropy_with_logits(
labels=target_tensor, logits=prediction_tensor))
if self._anchorwise_output:
return tf.reduce_sum(per_entry_cross_ent * weights, 2)
return tf.reduce_sum(per_entry_cross_ent * weights)
return per_entry_cross_ent * weights
class SigmoidFocalClassificationLoss(Loss):
......@@ -245,15 +217,13 @@ class SigmoidFocalClassificationLoss(Loss):
examples. See https://arxiv.org/pdf/1708.02002.pdf for the loss definition.
"""
def __init__(self, anchorwise_output=False, gamma=2.0, alpha=0.25):
def __init__(self, gamma=2.0, alpha=0.25):
"""Constructor.
Args:
anchorwise_output: Outputs loss per anchor. (default False)
gamma: exponent of the modulating factor (1 - p_t) ^ gamma.
alpha: optional alpha weighting factor to balance positives vs negatives.
"""
self._anchorwise_output = anchorwise_output
self._alpha = alpha
self._gamma = gamma
......@@ -274,8 +244,8 @@ class SigmoidFocalClassificationLoss(Loss):
If provided, computes loss only for the specified class indices.
Returns:
loss: a (scalar) tensor representing the value of the loss function
or a float tensor of shape [batch_size, num_anchors]
loss: a float tensor of shape [batch_size, num_anchors, num_classes]
representing the value of the loss function.
"""
weights = tf.expand_dims(weights, 2)
if class_indices is not None:
......@@ -297,25 +267,21 @@ class SigmoidFocalClassificationLoss(Loss):
(1 - target_tensor) * (1 - self._alpha))
focal_cross_entropy_loss = (modulating_factor * alpha_weight_factor *
per_entry_cross_ent)
if self._anchorwise_output:
return tf.reduce_sum(focal_cross_entropy_loss * weights, 2)
return tf.reduce_sum(focal_cross_entropy_loss * weights)
return focal_cross_entropy_loss * weights
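As a sanity check on the focal term (per-entry values are now returned as-is rather than summed), the loss from the linked paper is FL(p_t) = -alpha_t * (1 - p_t)**gamma * log(p_t). A quick numeric sketch with the defaults gamma=2.0, alpha=0.25:

import math

def focal_term(p, gamma=2.0, alpha=0.25):
  # Per-entry sigmoid focal loss for a positive target with predicted
  # probability p (see https://arxiv.org/pdf/1708.02002.pdf).
  return -alpha * (1.0 - p) ** gamma * math.log(p)

print(focal_term(0.9))  # ~0.00026: easy example, heavily down-weighted
print(focal_term(0.1))  # ~0.466: hard example keeps most of its weight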
class WeightedSoftmaxClassificationLoss(Loss):
"""Softmax loss function."""
def __init__(self, anchorwise_output=False, logit_scale=1.0):
def __init__(self, logit_scale=1.0):
"""Constructor.
Args:
anchorwise_output: Whether to output loss per anchor (default False)
logit_scale: When this value is high, the prediction is "diffused" and
when this value is low, the prediction is made peakier.
(default 1.0)
"""
self._anchorwise_output = anchorwise_output
self._logit_scale = logit_scale
def _compute_loss(self, prediction_tensor, target_tensor, weights):
......@@ -329,7 +295,8 @@ class WeightedSoftmaxClassificationLoss(Loss):
weights: a float tensor of shape [batch_size, num_anchors]
Returns:
loss: a (scalar) tensor representing the value of the loss function
loss: a float tensor of shape [batch_size, num_anchors]
representing the value of the loss function.
"""
num_classes = prediction_tensor.get_shape().as_list()[-1]
prediction_tensor = tf.divide(
......@@ -337,9 +304,7 @@ class WeightedSoftmaxClassificationLoss(Loss):
per_row_cross_ent = (tf.nn.softmax_cross_entropy_with_logits(
labels=tf.reshape(target_tensor, [-1, num_classes]),
logits=tf.reshape(prediction_tensor, [-1, num_classes])))
if self._anchorwise_output:
return tf.reshape(per_row_cross_ent, tf.shape(weights)) * weights
return tf.reduce_sum(per_row_cross_ent * tf.reshape(weights, [-1]))
return tf.reshape(per_row_cross_ent, tf.shape(weights)) * weights
class BootstrappedSigmoidClassificationLoss(Loss):
......@@ -359,14 +324,13 @@ class BootstrappedSigmoidClassificationLoss(Loss):
Reed et al. (ICLR 2015).
"""
def __init__(self, alpha, bootstrap_type='soft', anchorwise_output=False):
def __init__(self, alpha, bootstrap_type='soft'):
"""Constructor.
Args:
alpha: a float32 scalar tensor between 0 and 1 representing interpolation
weight
bootstrap_type: set to either 'hard' or 'soft' (default)
anchorwise_output: Outputs loss per anchor. (default False)
Raises:
ValueError: if bootstrap_type is not either 'hard' or 'soft'
......@@ -376,7 +340,6 @@ class BootstrappedSigmoidClassificationLoss(Loss):
'\'hard\' or \'soft.\'')
self._alpha = alpha
self._bootstrap_type = bootstrap_type
self._anchorwise_output = anchorwise_output
def _compute_loss(self, prediction_tensor, target_tensor, weights):
"""Compute loss function.
......@@ -389,8 +352,8 @@ class BootstrappedSigmoidClassificationLoss(Loss):
weights: a float tensor of shape [batch_size, num_anchors]
Returns:
loss: a (scalar) tensor representing the value of the loss function
or a float tensor of shape [batch_size, num_anchors]
loss: a float tensor of shape [batch_size, num_anchors, num_classes]
representing the value of the loss function.
"""
if self._bootstrap_type == 'soft':
bootstrap_target_tensor = self._alpha * target_tensor + (
......@@ -401,9 +364,7 @@ class BootstrappedSigmoidClassificationLoss(Loss):
tf.sigmoid(prediction_tensor) > 0.5, tf.float32)
per_entry_cross_ent = (tf.nn.sigmoid_cross_entropy_with_logits(
labels=bootstrap_target_tensor, logits=prediction_tensor))
if self._anchorwise_output:
return tf.reduce_sum(per_entry_cross_ent * tf.expand_dims(weights, 2), 2)
return tf.reduce_sum(per_entry_cross_ent * tf.expand_dims(weights, 2))
return per_entry_cross_ent * tf.expand_dims(weights, 2)
class HardExampleMiner(object):
......
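The net effect across losses.py: _compute_loss implementations return per-anchor (or per-anchor, per-class) tensors unconditionally, and any scalar reduction happens at the call site. A minimal sketch mirroring the updated test above:

import tensorflow as tf
from object_detection.core import losses

predictions = tf.constant([[[-100., 100., -100.], [100., -100., -100.]]])
targets = tf.constant([[[0., 1., 0.], [1., 0., 0.]]])
weights = tf.constant([[1., 1.]])

loss_fn = losses.WeightedSigmoidClassificationLoss()
per_entry = loss_fn(predictions, targets, weights=weights)  # shape [1, 2, 3]
total = tf.reduce_sum(per_entry)  # reduction is now explicit in the caller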