"tests/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "32ff4773d4b6662ddbb35c4a75f7178eb2b70cf0"
Commit d96d2e3e authored by Kaushik Shivakumar's avatar Kaushik Shivakumar
Browse files

finish pylint

parent fbe9b495
...@@ -58,6 +58,6 @@ class HungarianBipartiteMatcher(matcher.Matcher): ...@@ -58,6 +58,6 @@ class HungarianBipartiteMatcher(matcher.Matcher):
return tf.numpy_function(numpy_matching, inputs, Tout=[tf.int32]) return tf.numpy_function(numpy_matching, inputs, Tout=[tf.int32])
matching_result = tf.autograph.experimental.do_not_convert( matching_result = tf.autograph.experimental.do_not_convert(
numpy_wrapper)([distance_matrix]) numpy_wrapper)([distance_matrix])
return tf.reshape(matching_result, [-1]) return tf.reshape(matching_result, [-1])
...@@ -26,7 +26,7 @@ if tf_version.is_tf2(): ...@@ -26,7 +26,7 @@ if tf_version.is_tf2():
@unittest.skipIf(tf_version.is_tf1(), 'Skipping TF2.X only test.') @unittest.skipIf(tf_version.is_tf1(), 'Skipping TF2.X only test.')
class GreedyBipartiteMatcherTest(test_case.TestCase): class HungarianBipartiteMatcherTest(test_case.TestCase):
def test_get_expected_matches_when_all_rows_are_valid(self): def test_get_expected_matches_when_all_rows_are_valid(self):
similarity_matrix = np.array([[0.50, 0.1, 0.8], [0.15, 0.2, 0.3]], similarity_matrix = np.array([[0.50, 0.1, 0.8], [0.15, 0.2, 0.3]],
...@@ -37,9 +37,9 @@ class GreedyBipartiteMatcherTest(test_case.TestCase): ...@@ -37,9 +37,9 @@ class GreedyBipartiteMatcherTest(test_case.TestCase):
matcher = hungarian_matcher.HungarianBipartiteMatcher() matcher = hungarian_matcher.HungarianBipartiteMatcher()
match_results_out = matcher.match(similarity_matrix, valid_rows=valid_rows) match_results_out = matcher.match(similarity_matrix, valid_rows=valid_rows)
self.assertAllEqual(match_results_out._match_results.numpy(), expected_match_results) self.assertAllEqual(match_results_out._match_results.numpy(),
expected_match_results)
def test_get_expected_matches_with_all_rows_be_default(self): def test_get_expected_matches_with_all_rows_be_default(self):
similarity_matrix = np.array([[0.50, 0.1, 0.8], [0.15, 0.2, 0.3]], similarity_matrix = np.array([[0.50, 0.1, 0.8], [0.15, 0.2, 0.3]],
dtype=np.float32) dtype=np.float32)
...@@ -47,41 +47,45 @@ class GreedyBipartiteMatcherTest(test_case.TestCase): ...@@ -47,41 +47,45 @@ class GreedyBipartiteMatcherTest(test_case.TestCase):
matcher = hungarian_matcher.HungarianBipartiteMatcher() matcher = hungarian_matcher.HungarianBipartiteMatcher()
match_results_out = matcher.match(similarity_matrix) match_results_out = matcher.match(similarity_matrix)
self.assertAllEqual(match_results_out._match_results.numpy(), expected_match_results) self.assertAllEqual(match_results_out._match_results.numpy(),
expected_match_results)
def test_get_no_matches_with_zero_valid_rows(self): def test_get_no_matches_with_zero_valid_rows(self):
similarity_matrix = np.array([[0.50, 0.1, 0.8], [0.15, 0.2, 0.3]], similarity_matrix = np.array([[0.50, 0.1, 0.8], [0.15, 0.2, 0.3]],
dtype=np.float32) dtype=np.float32)
valid_rows = np.zeros([2], dtype=np.bool) valid_rows = np.zeros([2], dtype=np.bool)
expected_match_results = [-1, -1, -1] expected_match_results = [-1, -1, -1]
matcher = hungarian_matcher.HungarianBipartiteMatcher() matcher = hungarian_matcher.HungarianBipartiteMatcher()
match_results_out = matcher.match(similarity_matrix, valid_rows=valid_rows) match_results_out = matcher.match(similarity_matrix, valid_rows=valid_rows)
self.assertAllEqual(match_results_out._match_results.numpy(), expected_match_results) self.assertAllEqual(match_results_out._match_results.numpy(),
expected_match_results)
def test_get_expected_matches_with_only_one_valid_row(self): def test_get_expected_matches_with_only_one_valid_row(self):
similarity_matrix = np.array([[0.50, 0.1, 0.8], [0.15, 0.2, 0.3]], similarity_matrix = np.array([[0.50, 0.1, 0.8], [0.15, 0.2, 0.3]],
dtype=np.float32) dtype=np.float32)
valid_rows = np.array([True, False], dtype=np.bool) valid_rows = np.array([True, False], dtype=np.bool)
expected_match_results = [-1, -1, 0] expected_match_results = [-1, -1, 0]
matcher = hungarian_matcher.HungarianBipartiteMatcher() matcher = hungarian_matcher.HungarianBipartiteMatcher()
match_results_out = matcher.match(similarity_matrix, valid_rows=valid_rows) match_results_out = matcher.match(similarity_matrix, valid_rows=valid_rows)
self.assertAllEqual(match_results_out._match_results.numpy(), expected_match_results) self.assertAllEqual(match_results_out._match_results.numpy(),
expected_match_results)
def test_get_expected_matches_with_only_one_valid_row_at_bottom(self): def test_get_expected_matches_with_only_one_valid_row_at_bottom(self):
similarity_matrix = np.array([[0.15, 0.2, 0.3], [0.50, 0.1, 0.8]], similarity_matrix = np.array([[0.15, 0.2, 0.3], [0.50, 0.1, 0.8]],
dtype=np.float32) dtype=np.float32)
valid_rows = np.array([False, True], dtype=np.bool) valid_rows = np.array([False, True], dtype=np.bool)
expected_match_results = [-1, -1, 0] expected_match_results = [-1, -1, 0]
matcher = hungarian_matcher.HungarianBipartiteMatcher() matcher = hungarian_matcher.HungarianBipartiteMatcher()
match_results_out = matcher.match(similarity_matrix, valid_rows=valid_rows) match_results_out = matcher.match(similarity_matrix, valid_rows=valid_rows)
self.assertAllEqual(match_results_out._match_results.numpy(), expected_match_results) self.assertAllEqual(match_results_out._match_results.numpy(),
expected_match_results)
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment