Commit 441f14a6 authored by Vighnesh Birodkar's avatar Vighnesh Birodkar Committed by TF Object Detection Team
Browse files

Add test for DeepMAC builder.

PiperOrigin-RevId: 373010443
parent fd67924c
......@@ -34,6 +34,7 @@ from object_detection.core import post_processing
from object_detection.core import target_assigner
from object_detection.meta_architectures import center_net_meta_arch
from object_detection.meta_architectures import context_rcnn_meta_arch
from object_detection.meta_architectures import deepmac_meta_arch
from object_detection.meta_architectures import faster_rcnn_meta_arch
from object_detection.meta_architectures import rfcn_meta_arch
from object_detection.meta_architectures import ssd_meta_arch
......
......@@ -25,6 +25,7 @@ from google.protobuf import text_format
from object_detection.builders import model_builder
from object_detection.builders import model_builder_test
from object_detection.core import losses
from object_detection.meta_architectures import deepmac_meta_arch
from object_detection.models import center_net_hourglass_feature_extractor
from object_detection.models.keras_models import hourglass_network
from object_detection.protos import center_net_pb2
......@@ -466,6 +467,57 @@ class ModelBuilderTF2Test(
# Verify that there are up_sampling2d layers.
self.assertGreater(num_up_sampling2d_layers, 0)
def test_create_center_net_deepmac(self):
"""Test building a CenterNet DeepMAC model."""
proto_txt = """
center_net {
num_classes: 90
feature_extractor {
type: "hourglass_52"
}
image_resizer {
keep_aspect_ratio_resizer {
min_dimension: 512
max_dimension: 512
pad_to_max_dimension: true
}
}
object_detection_task {
task_loss_weight: 1.0
offset_loss_weight: 1.0
scale_loss_weight: 0.1
localization_loss {
l1_localization_loss {
}
}
}
object_center_params {
object_center_loss_weight: 1.0
min_box_overlap_iou: 0.7
max_box_predictions: 100
classification_loss {
penalty_reduced_logistic_focal_loss {
alpha: 2.0
beta: 4.0
}
}
}
deepmac_mask_estimation {
classification_loss {
weighted_sigmoid {}
}
}
}
"""
# Set up the configuration proto.
config = text_format.Parse(proto_txt, model_pb2.DetectionModel())
# Build the model from the configuration.
model = model_builder.build(config, is_training=True)
self.assertIsInstance(model, deepmac_meta_arch.DeepMACMetaArch)
if __name__ == '__main__':
tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment