"tests/bnb/test_mixed_int8.py" did not exist on "700e0cd65f0cee59cb298b30932caf311e6c9519"
bipartite_matcher.py 2.47 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Bipartite matcher implementation."""

import tensorflow as tf

from tensorflow.contrib.image.python.ops import image_ops
from object_detection.core import matcher


class GreedyBipartiteMatcher(matcher.Matcher):
  """Wraps a Tensorflow greedy bipartite matcher."""

27
28
29
30
31
32
33
34
35
36
37
  def __init__(self, use_matmul_gather=False):
    """Constructs a Matcher.

    Args:
      use_matmul_gather: Force constructed match objects to use matrix
        multiplication based gather instead of standard tf.gather.
        (Default: False).
    """
    super(GreedyBipartiteMatcher, self).__init__(
        use_matmul_gather=use_matmul_gather)

38
39
40
  def _match(self, similarity_matrix, num_valid_rows=-1):
    """Bipartite matches a collection rows and columns. A greedy bi-partite.

41
42
    TODO: Add num_valid_columns options to match only that many columns
    with all the rows.
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64

    Args:
      similarity_matrix: Float tensor of shape [N, M] with pairwise similarity
        where higher values mean more similar.
      num_valid_rows: A scalar or a 1-D tensor with one element describing the
        number of valid rows of similarity_matrix to consider for the bipartite
        matching. If set to be negative, then all rows from similarity_matrix
        are used.

    Returns:
      match_results: int32 tensor of shape [M] with match_results[i]=-1
        meaning that column i is not matched and otherwise that it is matched to
        row match_results[i].
    """
    # Convert similarity matrix to distance matrix as tf.image.bipartite tries
    # to find minimum distance matches.
    distance_matrix = -1 * similarity_matrix
    _, match_results = image_ops.bipartite_match(
        distance_matrix, num_valid_rows)
    match_results = tf.reshape(match_results, [-1])
    match_results = tf.cast(match_results, tf.int32)
    return match_results