Commit d96d2e3e authored by Kaushik Shivakumar's avatar Kaushik Shivakumar
Browse files

finish pylint

parent fbe9b495
......@@ -26,7 +26,7 @@ if tf_version.is_tf2():
@unittest.skipIf(tf_version.is_tf1(), 'Skipping TF2.X only test.')
class GreedyBipartiteMatcherTest(test_case.TestCase):
class HungarianBipartiteMatcherTest(test_case.TestCase):
def test_get_expected_matches_when_all_rows_are_valid(self):
similarity_matrix = np.array([[0.50, 0.1, 0.8], [0.15, 0.2, 0.3]],
......@@ -37,8 +37,8 @@ class GreedyBipartiteMatcherTest(test_case.TestCase):
matcher = hungarian_matcher.HungarianBipartiteMatcher()
match_results_out = matcher.match(similarity_matrix, valid_rows=valid_rows)
self.assertAllEqual(match_results_out._match_results.numpy(), expected_match_results)
self.assertAllEqual(match_results_out._match_results.numpy(),
expected_match_results)
def test_get_expected_matches_with_all_rows_be_default(self):
similarity_matrix = np.array([[0.50, 0.1, 0.8], [0.15, 0.2, 0.3]],
......@@ -48,7 +48,8 @@ class GreedyBipartiteMatcherTest(test_case.TestCase):
matcher = hungarian_matcher.HungarianBipartiteMatcher()
match_results_out = matcher.match(similarity_matrix)
self.assertAllEqual(match_results_out._match_results.numpy(), expected_match_results)
self.assertAllEqual(match_results_out._match_results.numpy(),
expected_match_results)
def test_get_no_matches_with_zero_valid_rows(self):
similarity_matrix = np.array([[0.50, 0.1, 0.8], [0.15, 0.2, 0.3]],
......@@ -59,7 +60,8 @@ class GreedyBipartiteMatcherTest(test_case.TestCase):
matcher = hungarian_matcher.HungarianBipartiteMatcher()
match_results_out = matcher.match(similarity_matrix, valid_rows=valid_rows)
self.assertAllEqual(match_results_out._match_results.numpy(), expected_match_results)
self.assertAllEqual(match_results_out._match_results.numpy(),
expected_match_results)
def test_get_expected_matches_with_only_one_valid_row(self):
similarity_matrix = np.array([[0.50, 0.1, 0.8], [0.15, 0.2, 0.3]],
......@@ -70,7 +72,8 @@ class GreedyBipartiteMatcherTest(test_case.TestCase):
matcher = hungarian_matcher.HungarianBipartiteMatcher()
match_results_out = matcher.match(similarity_matrix, valid_rows=valid_rows)
self.assertAllEqual(match_results_out._match_results.numpy(), expected_match_results)
self.assertAllEqual(match_results_out._match_results.numpy(),
expected_match_results)
def test_get_expected_matches_with_only_one_valid_row_at_bottom(self):
similarity_matrix = np.array([[0.15, 0.2, 0.3], [0.50, 0.1, 0.8]],
......@@ -81,7 +84,8 @@ class GreedyBipartiteMatcherTest(test_case.TestCase):
matcher = hungarian_matcher.HungarianBipartiteMatcher()
match_results_out = matcher.match(similarity_matrix, valid_rows=valid_rows)
self.assertAllEqual(match_results_out._match_results.numpy(), expected_match_results)
self.assertAllEqual(match_results_out._match_results.numpy(),
expected_match_results)
if __name__ == '__main__':
tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment