# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Definition of target gather, which gathers targets from indices."""
import tensorflow as tf


class TargetGather:
  """Target gather for dense object detector."""

  def __call__(self, labels, match_indices, mask=None, mask_val=0.0):
    """Labels anchors with ground truth inputs.

    B: batch_size
    N: number of groundtruth boxes.
    M: number of entries in `match_indices` (per batch element).

    Args:
      labels: A numeric tensor with shape [N], [N, dims] (unbatched) or
        [B, N, dims] (batched) representing groundtruth labels.
      match_indices: An integer tensor with shape [M] or [B, M] holding, for
        every output position, the row of `labels` to gather.
      mask: An optional boolean tensor shaped like the gathered targets
        ([M], [M, dims] or [B, M, dims]). Where True, the gathered value is
        replaced by `mask_val`.
      mask_val: A scalar used to fill masked positions; it is cast to
        `labels.dtype`.

    Returns:
      target: A tensor with dtype `labels.dtype` and shape [M], [M, dims] or
        [B, M, dims]. When the unbatched (or B == 1) labels have no rows,
        every entry is `mask_val` and `mask` is not consulted.

    Raises:
      ValueError: If `labels` has rank higher than 3.
    """
    rank = len(labels.shape)
    if rank <= 2:
      return self._gather_unbatched(labels, match_indices, mask, mask_val)
    if rank == 3:
      return self._gather_batched(labels, match_indices, mask, mask_val)
    raise ValueError("`TargetGather` does not support `labels` with rank "
                     "larger than 3, got {}".format(len(labels.shape)))

  @staticmethod
  def _apply_mask(targets, mask, mask_val):
    """Replaces entries of `targets` where `mask` is True with `mask_val`."""
    if mask is None:
      return targets
    masked_targets = tf.cast(mask_val, targets.dtype) * tf.ones_like(
        mask, dtype=targets.dtype)
    return tf.where(mask, masked_targets, targets)

  def _gather_unbatched(self, labels, match_indices, mask, mask_val):
    """Gather based on unbatched labels and boxes."""
    num_gt_boxes = tf.shape(labels)[0]

    def _assign_when_rows_empty():
      # Nothing to gather from: fill every output position with `mask_val`.
      # The output shape is [M] + labels.shape[1:], computed dynamically so
      # it also works when M or `dims` is unknown at trace time (static
      # `.shape[i]` would be None there and `tf.ones` would fail).
      fill_shape = tf.concat(
          [tf.shape(match_indices)[:1], tf.shape(labels)[1:]], axis=0)
      return tf.fill(fill_shape, tf.cast(mask_val, labels.dtype))

    def _assign_when_rows_not_empty():
      targets = tf.gather(labels, match_indices)
      return self._apply_mask(targets, mask, mask_val)

    return tf.cond(tf.greater(num_gt_boxes, 0),
                   _assign_when_rows_not_empty,
                   _assign_when_rows_empty)

  def _gather_batched(self, labels, match_indices, mask, mask_val):
    """Gather based on batched labels."""
    batch_size = labels.shape[0]
    if batch_size == 1:
      # Drop the batch dimension and reuse the unbatched path (which also
      # handles labels with zero rows), then restore the batch dimension.
      squeezed_mask = None if mask is None else tf.squeeze(mask, axis=0)
      result = self._gather_unbatched(
          tf.squeeze(labels, axis=0), tf.squeeze(match_indices, axis=0),
          squeezed_mask, mask_val)
      return tf.expand_dims(result, axis=0)

    # Build [B, M, 2] (batch_index, row_index) pairs for tf.gather_nd.
    # NOTE(review): unlike the B == 1 path, this path does not special-case
    # N == 0; tf.gather_nd errors on out-of-range indices on CPU (and
    # returns zeros on GPU) - confirm callers never hit that case.
    indices_shape = tf.shape(match_indices)
    indices_dtype = match_indices.dtype
    batch_indices = (tf.expand_dims(
        tf.range(indices_shape[0], dtype=indices_dtype), axis=-1) *
                     tf.ones([1, indices_shape[-1]], dtype=indices_dtype))
    gather_nd_indices = tf.stack([batch_indices, match_indices], axis=-1)
    targets = tf.gather_nd(labels, gather_nd_indices)
    return self._apply_mask(targets, mask, mask_val)