Commit e6abe821 authored by Kaushik Shivakumar's avatar Kaushik Shivakumar
Browse files

fix pr

parent d96d2e3e
...@@ -21,8 +21,9 @@ import numpy as np ...@@ -21,8 +21,9 @@ import numpy as np
from object_detection.core import matcher from object_detection.core import matcher
from scipy.optimize import linear_sum_assignment from scipy.optimize import linear_sum_assignment
class HungarianBipartiteMatcher(matcher.Matcher): class HungarianBipartiteMatcher(matcher.Matcher):
"""Wraps a Tensorflow greedy bipartite matcher.""" """Wraps a Hungarian bipartite matcher into TensorFlow."""
def __init__(self): def __init__(self):
"""Constructs a Matcher.""" """Constructs a Matcher."""
...@@ -51,8 +52,7 @@ class HungarianBipartiteMatcher(matcher.Matcher): ...@@ -51,8 +52,7 @@ class HungarianBipartiteMatcher(matcher.Matcher):
def numpy_matching(input_matrix): def numpy_matching(input_matrix):
row_indices, col_indices = linear_sum_assignment(input_matrix) row_indices, col_indices = linear_sum_assignment(input_matrix)
match_results = np.full(input_matrix.shape[1], -1) match_results = np.full(input_matrix.shape[1], -1)
for i in range(len(col_indices)): match_results[col_indices] = row_indices
match_results[col_indices[i]] = row_indices[i]
return match_results.astype(np.int32) return match_results.astype(np.int32)
return tf.numpy_function(numpy_matching, inputs, Tout=[tf.int32]) return tf.numpy_function(numpy_matching, inputs, Tout=[tf.int32])
......
...@@ -87,5 +87,18 @@ class HungarianBipartiteMatcherTest(test_case.TestCase): ...@@ -87,5 +87,18 @@ class HungarianBipartiteMatcherTest(test_case.TestCase):
self.assertAllEqual(match_results_out._match_results.numpy(), self.assertAllEqual(match_results_out._match_results.numpy(),
expected_match_results) expected_match_results)
def test_get_expected_matches_with_two_valid_rows(self):
similarity_matrix = np.array([[0.15, 0.2, 0.3], [0.50, 0.1, 0.8],
[0.84, 0.32, 0.2]],
dtype=np.float32)
valid_rows = np.array([True, False, True], dtype=np.bool)
expected_match_results = [1, -1, 0]
matcher = hungarian_matcher.HungarianBipartiteMatcher()
match_results_out = matcher.match(similarity_matrix, valid_rows=valid_rows)
self.assertAllEqual(match_results_out._match_results.numpy(),
expected_match_results)
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment