"...git@developer.sourcefind.cn:modelzoo/gpt2_migraphx.git" did not exist on "816b3d5203b7d69f4cf1d457c78a82a4b69798ea"
Commit 6a2c9932 authored by syiming's avatar syiming
Browse files

add unit test for get_proposal_feature_extractor_model smaller input size

parent d26ef0e6
...@@ -53,6 +53,19 @@ class FasterRCNNResnetV1FPNKerasFeatureExtractorTest(tf.test.TestCase): ...@@ -53,6 +53,19 @@ class FasterRCNNResnetV1FPNKerasFeatureExtractorTest(tf.test.TestCase):
weight_decay=0.0) weight_decay=0.0)
def test_extract_proposal_features_returns_expected_size(self): def test_extract_proposal_features_returns_expected_size(self):
feature_extractor = self._build_feature_extractor()
preprocessed_inputs = tf.random_uniform(
[2, 320, 320, 3], maxval=255, dtype=tf.float32)
rpn_feature_maps = feature_extractor.get_proposal_feature_extractor_model(
name='TestScope')(preprocessed_inputs)
features_shapes = [tf.shape(rpn_feature_map)
for name, rpn_feature_map in rpn_feature_maps.items()]
self.assertAllEqual(features_shapes[0].numpy(), [2, 40, 40, 256])
self.assertAllEqual(features_shapes[1].numpy(), [2, 20, 20, 256])
self.assertAllEqual(features_shapes[2].numpy(), [2, 10, 10, 256])
def test_extract_proposal_features_half_size_input(self):
feature_extractor = self._build_feature_extractor() feature_extractor = self._build_feature_extractor()
preprocessed_inputs = tf.random_uniform( preprocessed_inputs = tf.random_uniform(
[2, 160, 160, 3], maxval=255, dtype=tf.float32) [2, 160, 160, 3], maxval=255, dtype=tf.float32)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment