"vscode:/vscode.git/clone" did not exist on "45c726c2c43f44081817ae2e05edc2ff1ed86e3a"
Commit ef76912d authored by Ronny Votel's avatar Ronny Votel Committed by TF Object Detection Team
Browse files

Allowing for input resolutions other than 224x224 in CenterNet ResNet models.

PiperOrigin-RevId: 327521336
parent 6b2bc083
...@@ -46,10 +46,12 @@ class CenterNetResnetFeatureExtractor(CenterNetFeatureExtractor): ...@@ -46,10 +46,12 @@ class CenterNetResnetFeatureExtractor(CenterNetFeatureExtractor):
channel_means=channel_means, channel_stds=channel_stds, channel_means=channel_means, channel_stds=channel_stds,
bgr_ordering=bgr_ordering) bgr_ordering=bgr_ordering)
if resnet_type == 'resnet_v2_101': if resnet_type == 'resnet_v2_101':
self._base_model = tf.keras.applications.ResNet101V2(weights=None) self._base_model = tf.keras.applications.ResNet101V2(weights=None,
include_top=False)
output_layer = 'conv5_block3_out' output_layer = 'conv5_block3_out'
elif resnet_type == 'resnet_v2_50': elif resnet_type == 'resnet_v2_50':
self._base_model = tf.keras.applications.ResNet50V2(weights=None) self._base_model = tf.keras.applications.ResNet50V2(weights=None,
include_top=False)
output_layer = 'conv5_block3_out' output_layer = 'conv5_block3_out'
else: else:
raise ValueError('Unknown Resnet Model {}'.format(resnet_type)) raise ValueError('Unknown Resnet Model {}'.format(resnet_type))
......
...@@ -31,11 +31,11 @@ class CenterNetResnetFeatureExtractorTest(test_case.TestCase): ...@@ -31,11 +31,11 @@ class CenterNetResnetFeatureExtractorTest(test_case.TestCase):
model = center_net_resnet_feature_extractor.\ model = center_net_resnet_feature_extractor.\
CenterNetResnetFeatureExtractor('resnet_v2_101') CenterNetResnetFeatureExtractor('resnet_v2_101')
def graph_fn(): def graph_fn():
img = np.zeros((8, 224, 224, 3), dtype=np.float32) img = np.zeros((8, 512, 512, 3), dtype=np.float32)
processed_img = model.preprocess(img) processed_img = model.preprocess(img)
return model(processed_img) return model(processed_img)
outputs = self.execute(graph_fn, []) outputs = self.execute(graph_fn, [])
self.assertEqual(outputs.shape, (8, 56, 56, 64)) self.assertEqual(outputs.shape, (8, 128, 128, 64))
def test_output_size_resnet50(self): def test_output_size_resnet50(self):
"""Verify that shape of features returned by the backbone is correct.""" """Verify that shape of features returned by the backbone is correct."""
......
...@@ -71,13 +71,15 @@ class CenterNetResnetV1FpnFeatureExtractor(CenterNetFeatureExtractor): ...@@ -71,13 +71,15 @@ class CenterNetResnetV1FpnFeatureExtractor(CenterNetFeatureExtractor):
channel_means=channel_means, channel_stds=channel_stds, channel_means=channel_means, channel_stds=channel_stds,
bgr_ordering=bgr_ordering) bgr_ordering=bgr_ordering)
if resnet_type == 'resnet_v1_50': if resnet_type == 'resnet_v1_50':
self._base_model = tf.keras.applications.ResNet50(weights=None) self._base_model = tf.keras.applications.ResNet50(weights=None,
include_top=False)
elif resnet_type == 'resnet_v1_101': elif resnet_type == 'resnet_v1_101':
self._base_model = tf.keras.applications.ResNet101(weights=None) self._base_model = tf.keras.applications.ResNet101(weights=None,
include_top=False)
elif resnet_type == 'resnet_v1_18': elif resnet_type == 'resnet_v1_18':
self._base_model = resnet_v1.resnet_v1_18(weights=None) self._base_model = resnet_v1.resnet_v1_18(weights=None, include_top=False)
elif resnet_type == 'resnet_v1_34': elif resnet_type == 'resnet_v1_34':
self._base_model = resnet_v1.resnet_v1_34(weights=None) self._base_model = resnet_v1.resnet_v1_34(weights=None, include_top=False)
else: else:
raise ValueError('Unknown Resnet Model {}'.format(resnet_type)) raise ValueError('Unknown Resnet Model {}'.format(resnet_type))
output_layers = _RESNET_MODEL_OUTPUT_LAYERS[resnet_type] output_layers = _RESNET_MODEL_OUTPUT_LAYERS[resnet_type]
......
...@@ -40,11 +40,11 @@ class CenterNetResnetV1FpnFeatureExtractorTest(test_case.TestCase, ...@@ -40,11 +40,11 @@ class CenterNetResnetV1FpnFeatureExtractorTest(test_case.TestCase,
model = center_net_resnet_v1_fpn_feature_extractor.\ model = center_net_resnet_v1_fpn_feature_extractor.\
CenterNetResnetV1FpnFeatureExtractor(resnet_type) CenterNetResnetV1FpnFeatureExtractor(resnet_type)
def graph_fn(): def graph_fn():
img = np.zeros((8, 224, 224, 3), dtype=np.float32) img = np.zeros((8, 512, 512, 3), dtype=np.float32)
processed_img = model.preprocess(img) processed_img = model.preprocess(img)
return model(processed_img) return model(processed_img)
self.assertEqual(self.execute(graph_fn, []).shape, (8, 56, 56, 64)) self.assertEqual(self.execute(graph_fn, []).shape, (8, 128, 128, 64))
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment