Commit d96d2e3e authored by Kaushik Shivakumar's avatar Kaushik Shivakumar
Browse files

finish pylint

parent fbe9b495
...@@ -26,7 +26,7 @@ if tf_version.is_tf2(): ...@@ -26,7 +26,7 @@ if tf_version.is_tf2():
@unittest.skipIf(tf_version.is_tf1(), 'Skipping TF2.X only test.') @unittest.skipIf(tf_version.is_tf1(), 'Skipping TF2.X only test.')
class GreedyBipartiteMatcherTest(test_case.TestCase): class HungarianBipartiteMatcherTest(test_case.TestCase):
def test_get_expected_matches_when_all_rows_are_valid(self): def test_get_expected_matches_when_all_rows_are_valid(self):
similarity_matrix = np.array([[0.50, 0.1, 0.8], [0.15, 0.2, 0.3]], similarity_matrix = np.array([[0.50, 0.1, 0.8], [0.15, 0.2, 0.3]],
...@@ -37,8 +37,8 @@ class GreedyBipartiteMatcherTest(test_case.TestCase): ...@@ -37,8 +37,8 @@ class GreedyBipartiteMatcherTest(test_case.TestCase):
matcher = hungarian_matcher.HungarianBipartiteMatcher() matcher = hungarian_matcher.HungarianBipartiteMatcher()
match_results_out = matcher.match(similarity_matrix, valid_rows=valid_rows) match_results_out = matcher.match(similarity_matrix, valid_rows=valid_rows)
self.assertAllEqual(match_results_out._match_results.numpy(), expected_match_results) self.assertAllEqual(match_results_out._match_results.numpy(),
expected_match_results)
def test_get_expected_matches_with_all_rows_be_default(self): def test_get_expected_matches_with_all_rows_be_default(self):
similarity_matrix = np.array([[0.50, 0.1, 0.8], [0.15, 0.2, 0.3]], similarity_matrix = np.array([[0.50, 0.1, 0.8], [0.15, 0.2, 0.3]],
...@@ -48,7 +48,8 @@ class GreedyBipartiteMatcherTest(test_case.TestCase): ...@@ -48,7 +48,8 @@ class GreedyBipartiteMatcherTest(test_case.TestCase):
matcher = hungarian_matcher.HungarianBipartiteMatcher() matcher = hungarian_matcher.HungarianBipartiteMatcher()
match_results_out = matcher.match(similarity_matrix) match_results_out = matcher.match(similarity_matrix)
self.assertAllEqual(match_results_out._match_results.numpy(), expected_match_results) self.assertAllEqual(match_results_out._match_results.numpy(),
expected_match_results)
def test_get_no_matches_with_zero_valid_rows(self): def test_get_no_matches_with_zero_valid_rows(self):
similarity_matrix = np.array([[0.50, 0.1, 0.8], [0.15, 0.2, 0.3]], similarity_matrix = np.array([[0.50, 0.1, 0.8], [0.15, 0.2, 0.3]],
...@@ -59,7 +60,8 @@ class GreedyBipartiteMatcherTest(test_case.TestCase): ...@@ -59,7 +60,8 @@ class GreedyBipartiteMatcherTest(test_case.TestCase):
matcher = hungarian_matcher.HungarianBipartiteMatcher() matcher = hungarian_matcher.HungarianBipartiteMatcher()
match_results_out = matcher.match(similarity_matrix, valid_rows=valid_rows) match_results_out = matcher.match(similarity_matrix, valid_rows=valid_rows)
self.assertAllEqual(match_results_out._match_results.numpy(), expected_match_results) self.assertAllEqual(match_results_out._match_results.numpy(),
expected_match_results)
def test_get_expected_matches_with_only_one_valid_row(self): def test_get_expected_matches_with_only_one_valid_row(self):
similarity_matrix = np.array([[0.50, 0.1, 0.8], [0.15, 0.2, 0.3]], similarity_matrix = np.array([[0.50, 0.1, 0.8], [0.15, 0.2, 0.3]],
...@@ -70,7 +72,8 @@ class GreedyBipartiteMatcherTest(test_case.TestCase): ...@@ -70,7 +72,8 @@ class GreedyBipartiteMatcherTest(test_case.TestCase):
matcher = hungarian_matcher.HungarianBipartiteMatcher() matcher = hungarian_matcher.HungarianBipartiteMatcher()
match_results_out = matcher.match(similarity_matrix, valid_rows=valid_rows) match_results_out = matcher.match(similarity_matrix, valid_rows=valid_rows)
self.assertAllEqual(match_results_out._match_results.numpy(), expected_match_results) self.assertAllEqual(match_results_out._match_results.numpy(),
expected_match_results)
def test_get_expected_matches_with_only_one_valid_row_at_bottom(self): def test_get_expected_matches_with_only_one_valid_row_at_bottom(self):
similarity_matrix = np.array([[0.15, 0.2, 0.3], [0.50, 0.1, 0.8]], similarity_matrix = np.array([[0.15, 0.2, 0.3], [0.50, 0.1, 0.8]],
...@@ -81,7 +84,8 @@ class GreedyBipartiteMatcherTest(test_case.TestCase): ...@@ -81,7 +84,8 @@ class GreedyBipartiteMatcherTest(test_case.TestCase):
matcher = hungarian_matcher.HungarianBipartiteMatcher() matcher = hungarian_matcher.HungarianBipartiteMatcher()
match_results_out = matcher.match(similarity_matrix, valid_rows=valid_rows) match_results_out = matcher.match(similarity_matrix, valid_rows=valid_rows)
self.assertAllEqual(match_results_out._match_results.numpy(), expected_match_results) self.assertAllEqual(match_results_out._match_results.numpy(),
expected_match_results)
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment