Commit ec2d5d8d authored by Yu-hui Chen's avatar Yu-hui Chen Committed by TF Object Detection Team
Browse files

Internal change

PiperOrigin-RevId: 321810710
parent 4c83bd7f
...@@ -140,6 +140,10 @@ if tf_version.is_tf2(): ...@@ -140,6 +140,10 @@ if tf_version.is_tf2():
CENTER_NET_EXTRACTOR_FUNCTION_MAP = { CENTER_NET_EXTRACTOR_FUNCTION_MAP = {
'resnet_v2_50': center_net_resnet_feature_extractor.resnet_v2_50, 'resnet_v2_50': center_net_resnet_feature_extractor.resnet_v2_50,
'resnet_v2_101': center_net_resnet_feature_extractor.resnet_v2_101, 'resnet_v2_101': center_net_resnet_feature_extractor.resnet_v2_101,
'resnet_v1_18_fpn':
center_net_resnet_v1_fpn_feature_extractor.resnet_v1_18_fpn,
'resnet_v1_34_fpn':
center_net_resnet_v1_fpn_feature_extractor.resnet_v1_34_fpn,
'resnet_v1_50_fpn': 'resnet_v1_50_fpn':
center_net_resnet_v1_fpn_feature_extractor.resnet_v1_50_fpn, center_net_resnet_v1_fpn_feature_extractor.resnet_v1_50_fpn,
'resnet_v1_101_fpn': 'resnet_v1_101_fpn':
......
...@@ -21,9 +21,14 @@ ...@@ -21,9 +21,14 @@
import tensorflow.compat.v1 as tf import tensorflow.compat.v1 as tf
from object_detection.meta_architectures.center_net_meta_arch import CenterNetFeatureExtractor from object_detection.meta_architectures.center_net_meta_arch import CenterNetFeatureExtractor
from object_detection.models.keras_models import resnet_v1
_RESNET_MODEL_OUTPUT_LAYERS = { _RESNET_MODEL_OUTPUT_LAYERS = {
'resnet_v1_18': ['conv2_block2_out', 'conv3_block2_out',
'conv4_block2_out', 'conv5_block2_out'],
'resnet_v1_34': ['conv2_block3_out', 'conv3_block4_out',
'conv4_block6_out', 'conv5_block3_out'],
'resnet_v1_50': ['conv2_block3_out', 'conv3_block4_out', 'resnet_v1_50': ['conv2_block3_out', 'conv3_block4_out',
'conv4_block6_out', 'conv5_block3_out'], 'conv4_block6_out', 'conv5_block3_out'],
'resnet_v1_101': ['conv2_block3_out', 'conv3_block4_out', 'resnet_v1_101': ['conv2_block3_out', 'conv3_block4_out',
...@@ -69,6 +74,10 @@ class CenterNetResnetV1FpnFeatureExtractor(CenterNetFeatureExtractor): ...@@ -69,6 +74,10 @@ class CenterNetResnetV1FpnFeatureExtractor(CenterNetFeatureExtractor):
self._base_model = tf.keras.applications.ResNet50(weights=None) self._base_model = tf.keras.applications.ResNet50(weights=None)
elif resnet_type == 'resnet_v1_101': elif resnet_type == 'resnet_v1_101':
self._base_model = tf.keras.applications.ResNet101(weights=None) self._base_model = tf.keras.applications.ResNet101(weights=None)
elif resnet_type == 'resnet_v1_18':
self._base_model = resnet_v1.resnet_v1_18(weights=None)
elif resnet_type == 'resnet_v1_34':
self._base_model = resnet_v1.resnet_v1_34(weights=None)
else: else:
raise ValueError('Unknown Resnet Model {}'.format(resnet_type)) raise ValueError('Unknown Resnet Model {}'.format(resnet_type))
output_layers = _RESNET_MODEL_OUTPUT_LAYERS[resnet_type] output_layers = _RESNET_MODEL_OUTPUT_LAYERS[resnet_type]
...@@ -174,3 +183,24 @@ def resnet_v1_50_fpn(channel_means, channel_stds, bgr_ordering): ...@@ -174,3 +183,24 @@ def resnet_v1_50_fpn(channel_means, channel_stds, bgr_ordering):
channel_means=channel_means, channel_means=channel_means,
channel_stds=channel_stds, channel_stds=channel_stds,
bgr_ordering=bgr_ordering) bgr_ordering=bgr_ordering)
def resnet_v1_34_fpn(channel_means, channel_stds, bgr_ordering):
"""The ResNet v1 34 FPN feature extractor."""
return CenterNetResnetV1FpnFeatureExtractor(
resnet_type='resnet_v1_34',
channel_means=channel_means,
channel_stds=channel_stds,
bgr_ordering=bgr_ordering
)
def resnet_v1_18_fpn(channel_means, channel_stds, bgr_ordering):
"""The ResNet v1 18 FPN feature extractor."""
return CenterNetResnetV1FpnFeatureExtractor(
resnet_type='resnet_v1_18',
channel_means=channel_means,
channel_stds=channel_stds,
bgr_ordering=bgr_ordering)
...@@ -31,6 +31,8 @@ class CenterNetResnetV1FpnFeatureExtractorTest(test_case.TestCase, ...@@ -31,6 +31,8 @@ class CenterNetResnetV1FpnFeatureExtractorTest(test_case.TestCase,
@parameterized.parameters( @parameterized.parameters(
{'resnet_type': 'resnet_v1_50'}, {'resnet_type': 'resnet_v1_50'},
{'resnet_type': 'resnet_v1_101'}, {'resnet_type': 'resnet_v1_101'},
{'resnet_type': 'resnet_v1_18'},
{'resnet_type': 'resnet_v1_34'},
) )
def test_correct_output_size(self, resnet_type): def test_correct_output_size(self, resnet_type):
"""Verify that shape of features returned by the backbone is correct.""" """Verify that shape of features returned by the backbone is correct."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment