Commit 930ccd92 authored by Zhichao Lu's avatar Zhichao Lu Committed by lzc5123016
Browse files

Add encode_background_as_zeros option to the SSDMetaArch class --- now clients...

Add encode_background_as_zeros option to the SSDMetaArch class --- now clients have the option of encoding background targets as an all zeros vector or a one-hot vector with the 0th dimension corresponding to a background prediction.

PiperOrigin-RevId: 185228281
parent 8f932583
......@@ -152,6 +152,7 @@ def _build_ssd_model(ssd_config, is_training, add_summaries):
matcher = matcher_builder.build(ssd_config.matcher)
region_similarity_calculator = sim_calc.build(
ssd_config.similarity_calculator)
encode_background_as_zeros = ssd_config.encode_background_as_zeros
ssd_box_predictor = box_predictor_builder.build(hyperparams_builder.build,
ssd_config.box_predictor,
is_training, num_classes)
......@@ -173,6 +174,7 @@ def _build_ssd_model(ssd_config, is_training, add_summaries):
feature_extractor,
matcher,
region_similarity_calculator,
encode_background_as_zeros,
image_resizer_fn,
non_max_suppression_fn,
score_conversion_fn,
......
......@@ -121,6 +121,7 @@ class SSDMetaArch(model.DetectionModel):
feature_extractor,
matcher,
region_similarity_calculator,
encode_background_as_zeros,
image_resizer_fn,
non_max_suppression_fn,
score_conversion_fn,
......@@ -147,6 +148,9 @@ class SSDMetaArch(model.DetectionModel):
matcher: a matcher.Matcher object.
region_similarity_calculator: a
region_similarity_calculator.RegionSimilarityCalculator object.
encode_background_as_zeros: boolean determining whether background
targets are to be encoded as an all zeros vector or a one-hot
vector (where background is the 0th class).
image_resizer_fn: a callable for image resizing. This callable always
takes a rank-3 image tensor (corresponding to a single image) and
returns a rank-3 image tensor, possibly with new spatial dimensions and
......@@ -190,7 +194,12 @@ class SSDMetaArch(model.DetectionModel):
# TODO: handle agnostic mode and positive/negative class
# weights
unmatched_cls_target = None
unmatched_cls_target = tf.constant([1] + self.num_classes * [0], tf.float32)
unmatched_cls_target = tf.constant([1] + self.num_classes * [0],
tf.float32)
if encode_background_as_zeros:
unmatched_cls_target = tf.constant((self.num_classes + 1) * [0],
tf.float32)
self._target_assigner = target_assigner.TargetAssigner(
self._region_similarity_calculator,
self._matcher,
......
......@@ -84,7 +84,7 @@ class SsdMetaArchTest(test_case.TestCase):
fake_feature_extractor = FakeSSDFeatureExtractor()
mock_matcher = test_utils.MockMatcher()
region_similarity_calculator = sim_calc.IouSimilarity()
encode_background_as_zeros = False
def image_resizer_fn(image):
return [tf.identity(image), tf.shape(image)]
......@@ -111,10 +111,10 @@ class SsdMetaArchTest(test_case.TestCase):
model = ssd_meta_arch.SSDMetaArch(
is_training, mock_anchor_generator, mock_box_predictor, mock_box_coder,
fake_feature_extractor, mock_matcher, region_similarity_calculator,
image_resizer_fn, non_max_suppression_fn, tf.identity,
classification_loss, localization_loss, classification_loss_weight,
localization_loss_weight, normalize_loss_by_num_matches,
hard_example_miner, add_summaries=False)
encode_background_as_zeros, image_resizer_fn, non_max_suppression_fn,
tf.identity, classification_loss, localization_loss,
classification_loss_weight, localization_loss_weight,
normalize_loss_by_num_matches, hard_example_miner, add_summaries=False)
return model, num_classes, mock_anchor_generator.num_anchors(), code_size
def test_preprocess_preserves_shapes_with_dynamic_input_image(self):
......
......@@ -32,6 +32,10 @@ message Ssd {
// Region similarity calculator to compute similarity of boxes.
optional RegionSimilarityCalculator similarity_calculator = 6;
// Whether background targets are to be encoded as an all
// zeros vector or a one-hot vector (where background is the 0th class).
optional bool encode_background_as_zeros = 12 [default=false];
// Box predictor to attach to the features.
optional BoxPredictor box_predictor = 7;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment