Commit 7140eede authored by syiming's avatar syiming
Browse files

fix coding style

parent 8db480c9
...@@ -43,22 +43,22 @@ class FasterRCNNResnetV1FPNKerasFeatureExtractorTest(tf.test.TestCase): ...@@ -43,22 +43,22 @@ class FasterRCNNResnetV1FPNKerasFeatureExtractorTest(tf.test.TestCase):
text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams) text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams)
return hyperparams_builder.KerasLayerHyperparams(conv_hyperparams) return hyperparams_builder.KerasLayerHyperparams(conv_hyperparams)
def _build_feature_extractor(self, architecture='resnet_v1_50'): def _build_feature_extractor(self):
return frcnn_res_fpn.FasterRCNNResnet50FPNKerasFeatureExtractor( return frcnn_res_fpn.FasterRCNNResnet50FPNKerasFeatureExtractor(
is_training=False, is_training=False,
conv_hyperparams=self._build_conv_hyperparams(), conv_hyperparams=self._build_conv_hyperparams(),
first_stage_features_stride=16, first_stage_features_stride=16,
batch_norm_trainable=False, batch_norm_trainable=False,
weight_decay=0.0) weight_decay=0.0)
def test_extract_proposal_features_returns_expected_size(self): def test_extract_proposal_features_returns_expected_size(self):
feature_extractor = self._build_feature_extractor() feature_extractor = self._build_feature_extractor()
preprocessed_inputs = tf.random_uniform( preprocessed_inputs = tf.random_uniform(
[2, 448, 448, 3], maxval=255, dtype=tf.float32) [2, 448, 448, 3], maxval=255, dtype=tf.float32)
rpn_feature_maps = feature_extractor.get_proposal_feature_extractor_model( rpn_feature_maps = feature_extractor.get_proposal_feature_extractor_model(
name='TestScope')(preprocessed_inputs) name='TestScope')(preprocessed_inputs)
features_shapes = [tf.shape(rpn_feature_map) features_shapes = [tf.shape(rpn_feature_map)
for rpn_feature_map in rpn_feature_maps] for rpn_feature_map in rpn_feature_maps]
self.assertAllEqual(features_shapes[0].numpy(), [2, 112, 112, 256]) self.assertAllEqual(features_shapes[0].numpy(), [2, 112, 112, 256])
self.assertAllEqual(features_shapes[1].numpy(), [2, 56, 56, 256]) self.assertAllEqual(features_shapes[1].numpy(), [2, 56, 56, 256])
...@@ -71,9 +71,9 @@ class FasterRCNNResnetV1FPNKerasFeatureExtractorTest(tf.test.TestCase): ...@@ -71,9 +71,9 @@ class FasterRCNNResnetV1FPNKerasFeatureExtractorTest(tf.test.TestCase):
[2, 224, 224, 3], maxval=255, dtype=tf.float32) [2, 224, 224, 3], maxval=255, dtype=tf.float32)
rpn_feature_maps = feature_extractor.get_proposal_feature_extractor_model( rpn_feature_maps = feature_extractor.get_proposal_feature_extractor_model(
name='TestScope')(preprocessed_inputs) name='TestScope')(preprocessed_inputs)
features_shapes = [tf.shape(rpn_feature_map) features_shapes = [tf.shape(rpn_feature_map)
for rpn_feature_map in rpn_feature_maps] for rpn_feature_map in rpn_feature_maps]
self.assertAllEqual(features_shapes[0].numpy(), [2, 56, 56, 256]) self.assertAllEqual(features_shapes[0].numpy(), [2, 56, 56, 256])
self.assertAllEqual(features_shapes[1].numpy(), [2, 28, 28, 256]) self.assertAllEqual(features_shapes[1].numpy(), [2, 28, 28, 256])
self.assertAllEqual(features_shapes[2].numpy(), [2, 14, 14, 256]) self.assertAllEqual(features_shapes[2].numpy(), [2, 14, 14, 256])
...@@ -102,4 +102,4 @@ class FasterRCNNResnetV1FPNKerasFeatureExtractorTest(tf.test.TestCase): ...@@ -102,4 +102,4 @@ class FasterRCNNResnetV1FPNKerasFeatureExtractorTest(tf.test.TestCase):
if __name__ == '__main__': if __name__ == '__main__':
tf.enable_v2_behavior() tf.enable_v2_behavior()
tf.test.main() tf.test.main()
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment