"git@developer.sourcefind.cn:OpenDAS/torch-scatter.git" did not exist on "bf36f19a5dee39a8e340e8cff1b1edae1ae7dfb0"
Commit 930ccd92 authored by Zhichao Lu, committed by lzc5123016
Browse files

Add encode_background_as_zeros option to the SSDMetaArch class --- now clients...

Add encode_background_as_zeros option to the SSDMetaArch class --- now clients have the option of encoding background targets as an all zeros vector or a one-hot vector with the 0th dimension corresponding to a background prediction.

PiperOrigin-RevId: 185228281
parent 8f932583
...@@ -152,6 +152,7 @@ def _build_ssd_model(ssd_config, is_training, add_summaries): ...@@ -152,6 +152,7 @@ def _build_ssd_model(ssd_config, is_training, add_summaries):
matcher = matcher_builder.build(ssd_config.matcher) matcher = matcher_builder.build(ssd_config.matcher)
region_similarity_calculator = sim_calc.build( region_similarity_calculator = sim_calc.build(
ssd_config.similarity_calculator) ssd_config.similarity_calculator)
encode_background_as_zeros = ssd_config.encode_background_as_zeros
ssd_box_predictor = box_predictor_builder.build(hyperparams_builder.build, ssd_box_predictor = box_predictor_builder.build(hyperparams_builder.build,
ssd_config.box_predictor, ssd_config.box_predictor,
is_training, num_classes) is_training, num_classes)
...@@ -173,6 +174,7 @@ def _build_ssd_model(ssd_config, is_training, add_summaries): ...@@ -173,6 +174,7 @@ def _build_ssd_model(ssd_config, is_training, add_summaries):
feature_extractor, feature_extractor,
matcher, matcher,
region_similarity_calculator, region_similarity_calculator,
encode_background_as_zeros,
image_resizer_fn, image_resizer_fn,
non_max_suppression_fn, non_max_suppression_fn,
score_conversion_fn, score_conversion_fn,
......
...@@ -121,6 +121,7 @@ class SSDMetaArch(model.DetectionModel): ...@@ -121,6 +121,7 @@ class SSDMetaArch(model.DetectionModel):
feature_extractor, feature_extractor,
matcher, matcher,
region_similarity_calculator, region_similarity_calculator,
encode_background_as_zeros,
image_resizer_fn, image_resizer_fn,
non_max_suppression_fn, non_max_suppression_fn,
score_conversion_fn, score_conversion_fn,
...@@ -147,6 +148,9 @@ class SSDMetaArch(model.DetectionModel): ...@@ -147,6 +148,9 @@ class SSDMetaArch(model.DetectionModel):
matcher: a matcher.Matcher object. matcher: a matcher.Matcher object.
region_similarity_calculator: a region_similarity_calculator: a
region_similarity_calculator.RegionSimilarityCalculator object. region_similarity_calculator.RegionSimilarityCalculator object.
encode_background_as_zeros: boolean determining whether background
targets are to be encoded as an all zeros vector or a one-hot
vector (where background is the 0th class).
image_resizer_fn: a callable for image resizing. This callable always image_resizer_fn: a callable for image resizing. This callable always
takes a rank-3 image tensor (corresponding to a single image) and takes a rank-3 image tensor (corresponding to a single image) and
returns a rank-3 image tensor, possibly with new spatial dimensions and returns a rank-3 image tensor, possibly with new spatial dimensions and
...@@ -190,7 +194,12 @@ class SSDMetaArch(model.DetectionModel): ...@@ -190,7 +194,12 @@ class SSDMetaArch(model.DetectionModel):
# TODO: handle agnostic mode and positive/negative class # TODO: handle agnostic mode and positive/negative class
# weights # weights
unmatched_cls_target = None unmatched_cls_target = None
unmatched_cls_target = tf.constant([1] + self.num_classes * [0], tf.float32) unmatched_cls_target = tf.constant([1] + self.num_classes * [0],
tf.float32)
if encode_background_as_zeros:
unmatched_cls_target = tf.constant((self.num_classes + 1) * [0],
tf.float32)
self._target_assigner = target_assigner.TargetAssigner( self._target_assigner = target_assigner.TargetAssigner(
self._region_similarity_calculator, self._region_similarity_calculator,
self._matcher, self._matcher,
......
...@@ -84,7 +84,7 @@ class SsdMetaArchTest(test_case.TestCase): ...@@ -84,7 +84,7 @@ class SsdMetaArchTest(test_case.TestCase):
fake_feature_extractor = FakeSSDFeatureExtractor() fake_feature_extractor = FakeSSDFeatureExtractor()
mock_matcher = test_utils.MockMatcher() mock_matcher = test_utils.MockMatcher()
region_similarity_calculator = sim_calc.IouSimilarity() region_similarity_calculator = sim_calc.IouSimilarity()
encode_background_as_zeros = False
def image_resizer_fn(image): def image_resizer_fn(image):
return [tf.identity(image), tf.shape(image)] return [tf.identity(image), tf.shape(image)]
...@@ -111,10 +111,10 @@ class SsdMetaArchTest(test_case.TestCase): ...@@ -111,10 +111,10 @@ class SsdMetaArchTest(test_case.TestCase):
model = ssd_meta_arch.SSDMetaArch( model = ssd_meta_arch.SSDMetaArch(
is_training, mock_anchor_generator, mock_box_predictor, mock_box_coder, is_training, mock_anchor_generator, mock_box_predictor, mock_box_coder,
fake_feature_extractor, mock_matcher, region_similarity_calculator, fake_feature_extractor, mock_matcher, region_similarity_calculator,
image_resizer_fn, non_max_suppression_fn, tf.identity, encode_background_as_zeros, image_resizer_fn, non_max_suppression_fn,
classification_loss, localization_loss, classification_loss_weight, tf.identity, classification_loss, localization_loss,
localization_loss_weight, normalize_loss_by_num_matches, classification_loss_weight, localization_loss_weight,
hard_example_miner, add_summaries=False) normalize_loss_by_num_matches, hard_example_miner, add_summaries=False)
return model, num_classes, mock_anchor_generator.num_anchors(), code_size return model, num_classes, mock_anchor_generator.num_anchors(), code_size
def test_preprocess_preserves_shapes_with_dynamic_input_image(self): def test_preprocess_preserves_shapes_with_dynamic_input_image(self):
......
...@@ -32,6 +32,10 @@ message Ssd { ...@@ -32,6 +32,10 @@ message Ssd {
// Region similarity calculator to compute similarity of boxes. // Region similarity calculator to compute similarity of boxes.
optional RegionSimilarityCalculator similarity_calculator = 6; optional RegionSimilarityCalculator similarity_calculator = 6;
// Whether background targets are to be encoded as an all
// zeros vector or a one-hot vector (where background is the 0th class).
optional bool encode_background_as_zeros = 12 [default=false];
// Box predictor to attach to the features. // Box predictor to attach to the features.
optional BoxPredictor box_predictor = 7; optional BoxPredictor box_predictor = 7;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment