"test/srt/cpu/test_norm.py" did not exist on "3ded6235c9e423d39cb12e7fd3453c714e7a33ef"
anchor.py 18.9 KB
Newer Older
A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
1
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Vishnu Banna's avatar
Vishnu Banna committed
2
3
4
5
6
7
8
9
10
11
12
13
14
15
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Yolo anchor labeler."""
Vishnu Banna's avatar
Vishnu Banna committed
16
import numpy as np
Vishnu Banna's avatar
Vishnu Banna committed
17
import tensorflow as tf
Vishnu Banna's avatar
Vishnu Banna committed
18

Abdullah Rashwan's avatar
Abdullah Rashwan committed
19
20
21
from official.projects.yolo.ops import box_ops
from official.projects.yolo.ops import loss_utils
from official.projects.yolo.ops import preprocessing_ops
Vishnu Banna's avatar
Vishnu Banna committed
22

Vishnu Banna's avatar
Vishnu Banna committed
23
# Large sentinel value: in the anchor free search it replaces the area of
# boxes that do not cover a grid point, so they never win the minimum-area
# tie break.
INF = 10_000_000
Vishnu Banna's avatar
Vishnu Banna committed
24

25

Vishnu Banna's avatar
Vishnu Banna committed
26
27
28
29
30
31
def get_best_anchor(y_true,
                    anchors,
                    stride,
                    width=1,
                    height=1,
                    iou_thresh=0.25,
                    best_match_only=False,
                    use_tie_breaker=True):
  """Ranks the anchors for every ground truth box and keeps the good matches.

  Boxes and anchors are compared by shape only: both are placed at the origin
  and compared either with IoU (when `iou_thresh < 1.0`) or with the largest
  width/height ratio between box and anchor (when `iou_thresh >= 1.0`, as used
  by the scaled models).

  Args:
    y_true: tf.Tensor[] of bounding boxes in the yolo format; only the width
      and height in `y_true[..., 2:4]` are used.
    anchors: list or tensor of the (width, height) anchor boxes used in
      prediction, found via Kmeans.
    stride: `int` (or broadcastable tensor) the anchors are divided by.
    width: `int` image width; box widths are multiplied by it.
    height: `int` image height; box heights are multiplied by it.
    iou_thresh: `float` minimum IoU (below 1.0) or maximum width/height ratio
      (1.0 and above) an anchor must satisfy to be kept for a box.
    best_match_only: `bool` when True every candidate, including the best one,
      is dropped (set to -1) if it does not pass the threshold.
    use_tie_breaker: `bool` when True (and `best_match_only` is False) every
      anchor passing the threshold is kept next to the best one; when False
      only the best anchor is kept.

  Returns:
    A tuple `(iou_index, values)` of float32 tensors. `iou_index` holds, for
    each box, the anchor indexes ordered from best to worst match, with
    dropped entries set to -1. `values` holds the matching IoU or ratio scores
    in the same order.
  """
  with tf.name_scope('get_best_anchor'):
    scaler = tf.convert_to_tensor([
        tf.cast(width, dtype=tf.float32),
        tf.cast(height, dtype=tf.float32),
    ])
    # scale the box width/height to the level and the anchors down to it
    box_wh = tf.cast(y_true[..., 2:4], dtype=tf.float32) * scaler
    anchor_wh = tf.cast(anchors, dtype=tf.float32) / stride
    num_anchors = tf.cast(tf.shape(anchor_wh)[0], dtype=tf.int32)

    # Center every box and anchor at the origin so only the shapes matter.
    anchor_boxes = tf.concat([tf.zeros_like(anchor_wh), anchor_wh], axis=-1)
    truth_boxes = tf.concat([tf.zeros_like(box_wh), box_wh], axis=-1)

    if iou_thresh >= 1.0:
      # Ratio between every anchor (rows) and every box (columns).
      ratio = (tf.expand_dims(truth_boxes, axis=-3)[..., 2:4] /
               tf.expand_dims(anchor_boxes, axis=-2)[..., 2:4])
      ratio = tf.where(tf.math.is_nan(ratio), tf.zeros_like(ratio), ratio)
      ratio = tf.maximum(ratio, 1 / ratio)
      ratio = tf.where(tf.math.is_nan(ratio), tf.zeros_like(ratio), ratio)
      worst_ratio = tf.reduce_max(ratio, axis=-1)

      # top_k returns the largest entries, so negate to rank smallest first.
      neg_values, indexes = tf.math.top_k(
          tf.transpose(-worst_ratio, perm=[1, 0]), k=num_anchors, sorted=True)
      values = -neg_values
      keep = tf.cast(values < iou_thresh, dtype=indexes.dtype)
    else:
      iou_raw = box_ops.aggregated_comparitive_iou(
          box_ops.xcycwh_to_yxyx(truth_boxes),
          box_ops.xcycwh_to_yxyx(anchor_boxes),
          iou_type=3,
      )
      values, indexes = tf.math.top_k(iou_raw, k=num_anchors, sorted=True)
      keep = tf.cast(values >= iou_thresh, dtype=indexes.dtype)

    # (index + 1) * keep - 1 maps kept slots to their anchor index and dropped
    # slots to -1.
    if best_match_only:
      iou_index = (indexes + 1) * keep - 1
    else:
      best = tf.expand_dims(indexes[..., 0], axis=-1)
      if use_tie_breaker:
        others = (indexes[..., 1:] + 1) * keep[..., 1:] - 1
      else:
        others = tf.zeros_like(indexes[..., 1:]) - 1
      iou_index = tf.concat([best, others], axis=-1)

  return tf.cast(iou_index, dtype=tf.float32), tf.cast(values, dtype=tf.float32)

119

Vishnu Banna's avatar
Vishnu Banna committed
120
class YoloAnchorLabeler:
  """Anchor labeler for the Yolo Models."""

  def __init__(self,
               anchors=None,
               anchor_free_level_limits=None,
               level_strides=None,
               center_radius=None,
               max_num_instances=200,
               match_threshold=0.25,
               best_matches_only=False,
               use_tie_breaker=True,
               darknet=False,
               dtype='float32'):
    """Initialization for anchor labeler.

    Args:
      anchors: `Dict[str, List[List[Union[int, float]]]]` mapping each FPN
        level key (e.g. '3') to the anchor boxes of that level. The level keys
        must be consecutive integers written as strings.
      anchor_free_level_limits: `List` the box sizes that will be allowed at
        each FPN level as is done in the FCOS and YOLOX paper for anchor free
        box assignment. When set, `match_threshold` is forced to -0.01 and
        every level gets a budget of 2000 instances.
      level_strides: `Dict[str, int]` for how much the model scales down the
        images at each level.
      center_radius: `Dict[str, float]` for the radius around each box center
        to search for extra centers in each level. Ignored (set to None) when
        `darknet` is True and no anchor free limits are given.
      max_num_instances: `int` for the number of boxes to compute loss on. In
        the anchor based, non-darknet mode, level i (in `anchors` order) gets
        a budget of `(6 - i) * max_num_instances`.
      match_threshold: `float` indicating the threshold over which an anchor
        will be considered for prediction, at zero, all the anchors will be
        used and at 1.0 only the best will be used. For anchor thresholds
        larger than 1.0 we stop using the IOU for anchor comparison and resort
        directly to comparing the width and height, this is used for the
        scaled models.
      best_matches_only: `boolean` indicating how boxes are selected for
        optimization. When True each level searches only its own anchors.
      use_tie_breaker: `boolean` indicating whether to keep every anchor that
        passes the threshold (instead of only the best one) and, in the
        anchor free path, to give each grid point only to its smallest box.
      darknet: `boolean` indicating which data pipeline to use. Setting to True
        swaps the pipeline to output images relative to Yolov4 and older.
      dtype: `str` indicating the output datatype of the datapipeline selecting
        from {"float32", "float16", "bfloat16"}.
    """
    self.anchors = anchors
    self.masks = self._get_mask()
    self.anchor_free_level_limits = self._get_level_limits(
        anchor_free_level_limits)

    if darknet and self.anchor_free_level_limits is None:
      center_radius = None

    self.keys = self.anchors.keys()
    if self.anchor_free_level_limits is not None:
      maxim = 2000
      match_threshold = -0.01
      self.num_instances = {key: maxim for key in self.keys}
    elif not darknet:
      self.num_instances = {
          key: (6 - i) * max_num_instances for i, key in enumerate(self.keys)
      }
    else:
      self.num_instances = {key: max_num_instances for key in self.keys}

    self.center_radius = center_radius
    self.level_strides = level_strides
    self.match_threshold = match_threshold
    self.best_matches_only = best_matches_only
    self.use_tie_breaker = use_tie_breaker
    self.dtype = dtype

  def _get_mask(self):
    """For each level, gets the global indexes of that level's anchors.

    Anchors of all levels are numbered consecutively starting at the lowest
    level, so that indexes found by a search over the stitched anchor list
    (see `_stitch_anchors`) can be mapped back to a level.

    Returns:
      `Dict[str, List[int]]` mapping each level key to its anchor indexes.
    """
    masks = {}
    start = 0

    # Compare the levels numerically; comparing the string keys directly
    # would order '10' before '9' and produce an empty or wrong range.
    levels = [int(key) for key in self.anchors.keys()]
    for i in range(min(levels), max(levels) + 1):
      per_scale = len(self.anchors[str(i)])
      masks[str(i)] = list(range(start, per_scale + start))
      start += per_scale
    return masks

  def _stitch_anchors(self):
    """Flattens the anchors of all levels in the order used by `self.masks`.

    Returns:
      `List` of anchor boxes whose positions match the indexes stored in
      `self.masks`, independent of the insertion order of `self.anchors`.
    """
    stitched = []
    for level in self.masks:
      stitched.extend(self.anchors[level])
    return stitched

  def _get_level_limits(self, level_limits):
    """For each level, gets the box size range for anchor free placement.

    Args:
      level_limits: sequence of box size boundaries between consecutive
        levels, or None.

    Returns:
      `Dict[str, List[float]]` mapping each level key (in `self.anchors`
      order) to its `[lower, upper]` limits, or None when `level_limits` is
      None.
    """
    if level_limits is None:
      return None
    # list() lets callers pass tuples (or any sequence) as well as lists.
    bounds = [0.0] + list(level_limits) + [np.inf]
    return {
        key: bounds[i:i + 2] for i, key in enumerate(self.anchors.keys())
    }

  def _tie_breaking_search(self, anchors, mask, boxes, classes):
    """After search, link each anchor ind to the correct map in ground truth.

    Args:
      anchors: anchor index tensor from `get_best_anchor` (dropped slots are
        -1); assumed to be rank 2 [num_boxes, num_candidates] — TODO confirm.
      mask: iterable of the anchor indexes that belong to this level.
      boxes: ground truth boxes, one row per box.
      classes: ground truth classes, one entry per box.

    Returns:
      boxes, classes and anchor ids for every (box, level anchor) match.
      `classes` and `anchor_id` get a trailing axis of size 1 and are cast to
      the dtype of `boxes`; `anchor_id` is the position inside `mask`.
    """
    mask = tf.cast(tf.reshape(mask, [1, 1, 1, -1]), anchors.dtype)
    anchors = tf.expand_dims(anchors, axis=-1)
    # Rows of `viable` are (box index, candidate slot, position in mask).
    viable = tf.where(tf.squeeze(anchors == mask, axis=0))

    gather_id, _, anchor_id = tf.split(viable, 3, axis=-1)

    boxes = tf.gather_nd(boxes, gather_id)
    classes = tf.gather_nd(classes, gather_id)

    classes = tf.expand_dims(classes, axis=-1)
    classes = tf.cast(classes, boxes.dtype)
    anchor_id = tf.cast(anchor_id, boxes.dtype)
    return boxes, classes, anchor_id

  def _get_anchor_id(self,
                     key,
                     boxes,
                     classes,
                     width,
                     height,
                     stride,
                     iou_index=None):
    """Find the object anchor assignments in an anchor based paradigm.

    Returns:
      boxes, classes, per-level anchor ids (see `_tie_breaking_search`) and
      the number of anchors of this level.
    """
    # find the best anchor
    anchors = self.anchors[key]
    num_anchors = len(anchors)
    if self.best_matches_only:
      # get the best anchor for each box, searching only this level
      iou_index, _ = get_best_anchor(
          boxes,
          anchors,
          stride,
          width=width,
          height=height,
          best_match_only=True,
          iou_thresh=self.match_threshold)
      mask = range(num_anchors)
    else:
      # search is done across FPN levels, get the mask of anchor indexes
      # correlated to this level.
      mask = self.masks[key]

    # search for the correct box to use
    (boxes, classes,
     anchors) = self._tie_breaking_search(iou_index, mask, boxes, classes)
    return boxes, classes, anchors, num_anchors

  def _get_centers(self, boxes, classes, anchors, width, height, scale_xy):
    """Find the object center assignments in an anchor based paradigm.

    Each box is placed at the grid cell containing its center. When
    `scale_xy != 1`, copies of the box are also placed at neighboring cells
    when the box center lies within `0.5 * (scale_xy - 1)` of a cell border,
    and the shifted cells are clamped to the grid.

    Returns:
      boxes, classes and int32 `centers` rows of (y, x, anchor).
    """
    offset = tf.cast(0.5 * (scale_xy - 1), boxes.dtype)

    grid_xy, _ = tf.split(boxes, 2, axis=-1)
    wh_scale = tf.cast(tf.convert_to_tensor([width, height]), boxes.dtype)

    grid_xy = grid_xy * wh_scale
    centers = tf.math.floor(grid_xy)

    if offset != 0.0:
      clamp = lambda x, ma: tf.maximum(  # pylint:disable=g-long-lambda
          tf.minimum(x, tf.cast(ma, x.dtype)), tf.zeros_like(x))

      grid_xy_index = grid_xy - centers
      positive_shift = ((grid_xy_index < offset) & (grid_xy > 1.))
      negative_shift = ((grid_xy_index > (1 - offset)) & (grid_xy <
                                                          (wh_scale - 1.)))

      # Column 0 (always on) keeps the original cell; columns 1-4 are the
      # candidate shifts listed in `offset` below.
      zero, _ = tf.split(tf.ones_like(positive_shift), 2, axis=-1)
      shift_mask = tf.concat([zero, positive_shift, negative_shift], axis=-1)
      offset = tf.cast([[0, 0], [1, 0], [0, 1], [-1, 0], [0, -1]],
                       offset.dtype) * offset

      num_shifts = tf.shape(shift_mask)
      num_shifts = num_shifts[-1]
      boxes = tf.tile(tf.expand_dims(boxes, axis=-2), [1, num_shifts, 1])
      classes = tf.tile(tf.expand_dims(classes, axis=-2), [1, num_shifts, 1])
      anchors = tf.tile(tf.expand_dims(anchors, axis=-2), [1, num_shifts, 1])

      # Shift id for enabled shifts, -1 for disabled ones.
      shift_mask = tf.cast(shift_mask, boxes.dtype)
      shift_ind = shift_mask * tf.range(0, num_shifts, dtype=boxes.dtype)
      shift_ind = shift_ind - (1 - shift_mask)
      shift_ind = tf.expand_dims(shift_ind, axis=-1)

      boxes_and_centers = tf.concat([boxes, classes, anchors, shift_ind],
                                    axis=-1)
      boxes_and_centers = tf.reshape(boxes_and_centers, [-1, 7])
      _, center_ids = tf.split(boxes_and_centers, [6, 1], axis=-1)

      select = tf.where(center_ids >= 0)
      select, _ = tf.split(select, 2, axis=-1)

      boxes_and_centers = tf.gather_nd(boxes_and_centers, select)

      center_ids = tf.gather_nd(center_ids, select)
      center_ids = tf.cast(center_ids, tf.int32)
      shifts = tf.gather_nd(offset, center_ids)

      boxes, classes, anchors, _ = tf.split(
          boxes_and_centers, [4, 1, 1, 1], axis=-1)
      grid_xy, _ = tf.split(boxes, 2, axis=-1)
      centers = tf.math.floor(grid_xy * wh_scale - shifts)
      centers = clamp(centers, wh_scale - 1)

    x, y = tf.split(centers, 2, axis=-1)
    centers = tf.cast(tf.concat([y, x, anchors], axis=-1), tf.int32)
    return boxes, classes, centers

  def _get_anchor_free(self, key, boxes, classes, height, width, stride,
                       center_radius):
    """Find the box assignments in an anchor free paradigm.

    A grid point receives a box when it lies both inside the box (subject to
    this level's size limits) and within `center_radius * stride` of the box
    center. With `use_tie_breaker`, each grid point keeps only its
    smallest-area box.

    Returns:
      `indexes` rows of (y, x, 0) and `samples` rows of (box[4], 1, class).
    """
    level_limits = self.anchor_free_level_limits[key]
    gen = loss_utils.GridGenerator(anchors=[[1, 1]], scale_anchors=stride)
    grid_points = gen(width, height, 1, boxes.dtype)[0]
    grid_points = tf.squeeze(grid_points, axis=0)
    box_list = boxes
    class_list = classes

    grid_points = (grid_points + 0.5) * stride
    x_centers, y_centers = grid_points[..., 0], grid_points[..., 1]
    boxes *= (tf.convert_to_tensor([width, height, width, height]) * stride)

    tlbr_boxes = box_ops.xcycwh_to_yxyx(boxes)

    boxes = tf.reshape(boxes, [1, 1, -1, 4])
    tlbr_boxes = tf.reshape(tlbr_boxes, [1, 1, -1, 4])
    if self.use_tie_breaker:
      area = tf.reduce_prod(boxes[..., 2:], axis=-1)

    # check if the box is in the receptive field of this fpn level
    b_t = y_centers - tlbr_boxes[..., 0]
    b_l = x_centers - tlbr_boxes[..., 1]
    b_b = tlbr_boxes[..., 2] - y_centers
    b_r = tlbr_boxes[..., 3] - x_centers
    box_delta = tf.stack([b_t, b_l, b_b, b_r], axis=-1)
    if level_limits is not None:
      max_reg_targets_per_im = tf.reduce_max(box_delta, axis=-1)
      gt_min = max_reg_targets_per_im >= level_limits[0]
      gt_max = max_reg_targets_per_im <= level_limits[1]
      is_in_boxes = tf.logical_and(gt_min, gt_max)
    else:
      is_in_boxes = tf.reduce_min(box_delta, axis=-1) > 0.0
    is_in_boxes_all = tf.reduce_any(is_in_boxes, axis=(0, 1), keepdims=True)

    # check if the center is in the receptive field of this fpn level
    c_t = y_centers - (boxes[..., 1] - center_radius * stride)
    c_l = x_centers - (boxes[..., 0] - center_radius * stride)
    c_b = (boxes[..., 1] + center_radius * stride) - y_centers
    c_r = (boxes[..., 0] + center_radius * stride) - x_centers
    centers_delta = tf.stack([c_t, c_l, c_b, c_r], axis=-1)
    is_in_centers = tf.reduce_min(centers_delta, axis=-1) > 0.0
    is_in_centers_all = tf.reduce_any(is_in_centers, axis=(0, 1), keepdims=True)

    # collate all masks to get the final locations
    is_in_index = tf.logical_or(is_in_boxes_all, is_in_centers_all)
    is_in_boxes_and_center = tf.logical_and(is_in_boxes, is_in_centers)
    is_in_boxes_and_center = tf.logical_and(is_in_index, is_in_boxes_and_center)

    if self.use_tie_breaker:
      # Non-matching boxes get area INF so they never win the minimum; grid
      # points with no match at all compare against -1 and stay empty.
      boxes_all = tf.cast(is_in_boxes_and_center, area.dtype)
      boxes_all = ((boxes_all * area) + ((1 - boxes_all) * INF))
      boxes_min = tf.reduce_min(boxes_all, axis=-1, keepdims=True)
      boxes_min = tf.where(boxes_min == INF, -1.0, boxes_min)
      is_in_boxes_and_center = boxes_all == boxes_min

    # construct the index update grid
    indexes = tf.cast(tf.where(is_in_boxes_and_center), tf.int32)
    y, x, t = tf.split(indexes, 3, axis=-1)

    boxes = tf.gather_nd(box_list, t)
    classes = tf.cast(tf.gather_nd(class_list, t), boxes.dtype)
    classes = tf.cast(tf.expand_dims(classes, axis=-1), boxes.dtype)
    conf = tf.ones_like(classes)

    # return the samples and the indexes
    samples = tf.concat([boxes, conf, classes], axis=-1)
    indexes = tf.concat([y, x, tf.zeros_like(t)], axis=-1)
    return indexes, samples

  def build_label_per_path(self,
                           key,
                           boxes,
                           classes,
                           width,
                           height,
                           iou_index=None):
    """Builds the labels for one path (FPN level).

    Args:
      key: `str` level key.
      boxes: boxes in (x_center, y_center, width, height) form; presumably
        normalized to the image size — TODO confirm against callers.
      classes: class id per box.
      width: image width; floor-divided by the level stride for the grid.
      height: image height; floor-divided by the level stride for the grid.
      iou_index: anchor search result from `get_best_anchor` over all levels,
        or None when each level searches its own anchors.

    Returns:
      centers: int32 rows of (y, x, anchor) grid indexes, passed through
        `preprocessing_ops.pad_max_instances` with this level's budget.
      updates: rows of (box[4], mask, class) in `self.dtype`, padded the same
        way.
      full: [height // stride, width // stride, num_anchors, 1] grid counting
        the boxes assigned to each cell and anchor, in `self.dtype`.
    """
    stride = self.level_strides[key]
    scale_xy = self.center_radius[key] if self.center_radius is not None else 1

    width = tf.cast(width // stride, boxes.dtype)
    height = tf.cast(height // stride, boxes.dtype)

    if self.anchor_free_level_limits is None:
      (boxes, classes, anchors, num_anchors) = self._get_anchor_id(
          key, boxes, classes, width, height, stride, iou_index=iou_index)
      boxes, classes, centers = self._get_centers(boxes, classes, anchors,
                                                  width, height, scale_xy)
      ind_mask = tf.ones_like(classes)
      updates = tf.concat([boxes, ind_mask, classes], axis=-1)
    else:
      num_anchors = 1
      (centers, updates) = self._get_anchor_free(key, boxes, classes, height,
                                                 width, stride, scale_xy)
      boxes, ind_mask, classes = tf.split(updates, [4, 1, 1], axis=-1)

    width = tf.cast(width, tf.int32)
    height = tf.cast(height, tf.int32)
    full = tf.zeros([height, width, num_anchors, 1], dtype=classes.dtype)
    full = tf.tensor_scatter_nd_add(full, centers, ind_mask)

    num_instances = int(self.num_instances[key])
    centers = preprocessing_ops.pad_max_instances(
        centers, num_instances, pad_value=0, pad_axis=0)
    updates = preprocessing_ops.pad_max_instances(
        updates, num_instances, pad_value=0, pad_axis=0)

    updates = tf.cast(updates, self.dtype)
    full = tf.cast(full, self.dtype)
    return centers, updates, full

  def __call__(self, boxes, classes, width, height):
    """Builds the labels for a single image, not functional in batch mode.

    Args:
      boxes: `Tensor` of shape [None, 4] indicating the object locations in an
        image, in (y_min, x_min, y_max, x_max) order (converted to center
        form below); presumably normalized — TODO confirm.
      classes: `Tensor` of shape [None] indicating each object's class.
      width: `int` for the image's width.
      height: `int` for the image's height.

    Returns:
      indexes: `Dict[str, Tensor]` keyed by level; each value holds [*, 3]
        (y, x, anchor) indexes in the level grid where boxes are located.
      updates: `Dict[str, Tensor]` keyed by level; each value holds [*, 6]
        rows of (box[4], mask, class) to place in the level grid.
      true_grids: `Dict[str, Tensor]` keyed by level; each value has shape
        [height // stride, width // stride, num_anchors, 1] and marks where
        boxes are located for confidence losses.
    """
    indexes = {}
    updates = {}
    true_grids = {}
    iou_index = None

    boxes = box_ops.yxyx_to_xcycwh(boxes)
    if not self.best_matches_only and self.anchor_free_level_limits is None:
      # Stitch the anchors of all levels in the same order used by
      # `self.masks`, so searched indexes map back to the right level even if
      # `self.anchors` was not built in ascending level order.
      anchorsvec = self._stitch_anchors()

      # dividing the pixel anchors by the image size makes them relative
      stride = tf.cast([width, height], boxes.dtype)
      # get the best anchor for each box
      iou_index, _ = get_best_anchor(
          boxes,
          anchorsvec,
          stride,
          width=1.0,
          height=1.0,
          best_match_only=False,
          use_tie_breaker=self.use_tie_breaker,
          iou_thresh=self.match_threshold)

    for key in self.keys:
      indexes[key], updates[key], true_grids[key] = self.build_label_per_path(
          key, boxes, classes, width, height, iou_index=iou_index)
    return indexes, updates, true_grids