Commit ef76912d authored by Ronny Votel's avatar Ronny Votel Committed by TF Object Detection Team
Browse files

Allowing for input resolutions other than 224x224 in CenterNet ResNet models.

PiperOrigin-RevId: 327521336
parent 6b2bc083
......@@ -46,10 +46,12 @@ class CenterNetResnetFeatureExtractor(CenterNetFeatureExtractor):
channel_means=channel_means, channel_stds=channel_stds,
bgr_ordering=bgr_ordering)
if resnet_type == 'resnet_v2_101':
self._base_model = tf.keras.applications.ResNet101V2(weights=None)
self._base_model = tf.keras.applications.ResNet101V2(weights=None,
include_top=False)
output_layer = 'conv5_block3_out'
elif resnet_type == 'resnet_v2_50':
self._base_model = tf.keras.applications.ResNet50V2(weights=None)
self._base_model = tf.keras.applications.ResNet50V2(weights=None,
include_top=False)
output_layer = 'conv5_block3_out'
else:
raise ValueError('Unknown Resnet Model {}'.format(resnet_type))
......
......@@ -31,11 +31,11 @@ class CenterNetResnetFeatureExtractorTest(test_case.TestCase):
model = center_net_resnet_feature_extractor.\
CenterNetResnetFeatureExtractor('resnet_v2_101')
def graph_fn():
img = np.zeros((8, 224, 224, 3), dtype=np.float32)
img = np.zeros((8, 512, 512, 3), dtype=np.float32)
processed_img = model.preprocess(img)
return model(processed_img)
outputs = self.execute(graph_fn, [])
self.assertEqual(outputs.shape, (8, 56, 56, 64))
self.assertEqual(outputs.shape, (8, 128, 128, 64))
def test_output_size_resnet50(self):
"""Verify that shape of features returned by the backbone is correct."""
......
......@@ -71,13 +71,15 @@ class CenterNetResnetV1FpnFeatureExtractor(CenterNetFeatureExtractor):
channel_means=channel_means, channel_stds=channel_stds,
bgr_ordering=bgr_ordering)
if resnet_type == 'resnet_v1_50':
self._base_model = tf.keras.applications.ResNet50(weights=None)
self._base_model = tf.keras.applications.ResNet50(weights=None,
include_top=False)
elif resnet_type == 'resnet_v1_101':
self._base_model = tf.keras.applications.ResNet101(weights=None)
self._base_model = tf.keras.applications.ResNet101(weights=None,
include_top=False)
elif resnet_type == 'resnet_v1_18':
self._base_model = resnet_v1.resnet_v1_18(weights=None)
self._base_model = resnet_v1.resnet_v1_18(weights=None, include_top=False)
elif resnet_type == 'resnet_v1_34':
self._base_model = resnet_v1.resnet_v1_34(weights=None)
self._base_model = resnet_v1.resnet_v1_34(weights=None, include_top=False)
else:
raise ValueError('Unknown Resnet Model {}'.format(resnet_type))
output_layers = _RESNET_MODEL_OUTPUT_LAYERS[resnet_type]
......
......@@ -40,11 +40,11 @@ class CenterNetResnetV1FpnFeatureExtractorTest(test_case.TestCase,
model = center_net_resnet_v1_fpn_feature_extractor.\
CenterNetResnetV1FpnFeatureExtractor(resnet_type)
def graph_fn():
img = np.zeros((8, 224, 224, 3), dtype=np.float32)
img = np.zeros((8, 512, 512, 3), dtype=np.float32)
processed_img = model.preprocess(img)
return model(processed_img)
self.assertEqual(self.execute(graph_fn, []).shape, (8, 56, 56, 64))
self.assertEqual(self.execute(graph_fn, []).shape, (8, 128, 128, 64))
if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment