Commit 48693cad authored by Pengchong Jin's avatar Pengchong Jin Committed by A. Unique TensorFlower
Browse files

Move get_non_empty_box_indices to box_utils.

PiperOrigin-RevId: 281846940
parent 4c872f63
...@@ -234,7 +234,7 @@ class Parser(object): ...@@ -234,7 +234,7 @@ class Parser(object):
boxes, image_scale, (image_height, image_width), offset) boxes, image_scale, (image_height, image_width), offset)
# Filters out ground truth boxes that are all zeros. # Filters out ground truth boxes that are all zeros.
indices = input_utils.get_non_empty_box_indices(boxes) indices = box_utils.get_non_empty_box_indices(boxes)
boxes = tf.gather(boxes, indices) boxes = tf.gather(boxes, indices)
classes = tf.gather(classes, indices) classes = tf.gather(classes, indices)
if self._include_mask: if self._include_mask:
......
...@@ -251,7 +251,7 @@ class Parser(object): ...@@ -251,7 +251,7 @@ class Parser(object):
boxes = input_utils.resize_and_crop_boxes( boxes = input_utils.resize_and_crop_boxes(
boxes, image_scale, (image_height, image_width), offset) boxes, image_scale, (image_height, image_width), offset)
# Filters out ground truth boxes that are all zeros. # Filters out ground truth boxes that are all zeros.
indices = input_utils.get_non_empty_box_indices(boxes) indices = box_utils.get_non_empty_box_indices(boxes)
boxes = tf.gather(boxes, indices) boxes = tf.gather(boxes, indices)
classes = tf.gather(classes, indices) classes = tf.gather(classes, indices)
...@@ -311,7 +311,7 @@ class Parser(object): ...@@ -311,7 +311,7 @@ class Parser(object):
boxes = input_utils.resize_and_crop_boxes( boxes = input_utils.resize_and_crop_boxes(
boxes, image_scale, (image_height, image_width), offset) boxes, image_scale, (image_height, image_width), offset)
# Filters out ground truth boxes that are all zeros. # Filters out ground truth boxes that are all zeros.
indices = input_utils.get_non_empty_box_indices(boxes) indices = box_utils.get_non_empty_box_indices(boxes)
boxes = tf.gather(boxes, indices) boxes = tf.gather(boxes, indices)
classes = tf.gather(classes, indices) classes = tf.gather(classes, indices)
...@@ -414,7 +414,7 @@ class Parser(object): ...@@ -414,7 +414,7 @@ class Parser(object):
boxes = input_utils.resize_and_crop_boxes( boxes = input_utils.resize_and_crop_boxes(
boxes, image_scale, (image_height, image_width), offset) boxes, image_scale, (image_height, image_width), offset)
# Filters out ground truth boxes that are all zeros. # Filters out ground truth boxes that are all zeros.
indices = input_utils.get_non_empty_box_indices(boxes) indices = box_utils.get_non_empty_box_indices(boxes)
boxes = tf.gather(boxes, indices) boxes = tf.gather(boxes, indices)
# Assigns anchors. # Assigns anchors.
......
...@@ -268,7 +268,7 @@ class Parser(object): ...@@ -268,7 +268,7 @@ class Parser(object):
boxes, image_scale, self._output_size, offset) boxes, image_scale, self._output_size, offset)
# Filters out ground truth boxes that are all zeros. # Filters out ground truth boxes that are all zeros.
indices = input_utils.get_non_empty_box_indices(boxes) indices = box_utils.get_non_empty_box_indices(boxes)
boxes = tf.gather(boxes, indices) boxes = tf.gather(boxes, indices)
classes = tf.gather(classes, indices) classes = tf.gather(classes, indices)
masks = tf.gather(masks, indices) masks = tf.gather(masks, indices)
...@@ -427,7 +427,7 @@ class Parser(object): ...@@ -427,7 +427,7 @@ class Parser(object):
tf.expand_dims(masks, axis=-1), image_scale, self._output_size, offset) tf.expand_dims(masks, axis=-1), image_scale, self._output_size, offset)
# Filters out ground truth boxes that are all zeros. # Filters out ground truth boxes that are all zeros.
indices = input_utils.get_non_empty_box_indices(boxes) indices = box_utils.get_non_empty_box_indices(boxes)
boxes = tf.gather(boxes, indices) boxes = tf.gather(boxes, indices)
classes = tf.gather(classes, indices) classes = tf.gather(classes, indices)
......
...@@ -523,3 +523,13 @@ def bbox_overlap(boxes, gt_boxes): ...@@ -523,3 +523,13 @@ def bbox_overlap(boxes, gt_boxes):
iou = tf.where(padding_mask, -tf.ones_like(iou), iou) iou = tf.where(padding_mask, -tf.ones_like(iou), iou)
return iou return iou
def get_non_empty_box_indices(boxes):
"""Get indices for non-empty boxes."""
# Selects indices if box height or width is 0.
height = boxes[:, 2] - boxes[:, 0]
width = boxes[:, 3] - boxes[:, 1]
indices = tf.where(tf.logical_and(tf.greater(height, 0),
tf.greater(width, 0)))
return indices[:, 0]
...@@ -362,13 +362,3 @@ def resize_and_crop_masks(masks, ...@@ -362,13 +362,3 @@ def resize_and_crop_masks(masks,
def random_horizontal_flip(image, boxes=None, masks=None): def random_horizontal_flip(image, boxes=None, masks=None):
"""Randomly flips input image and bounding boxes.""" """Randomly flips input image and bounding boxes."""
return preprocessor.random_horizontal_flip(image, boxes, masks) return preprocessor.random_horizontal_flip(image, boxes, masks)
def get_non_empty_box_indices(boxes):
"""Get indices for non-empty boxes."""
# Selects indices if box height or width is 0.
height = boxes[:, 2] - boxes[:, 0]
width = boxes[:, 3] - boxes[:, 1]
indices = tf.where(tf.logical_and(tf.greater(height, 0),
tf.greater(width, 0)))
return indices[:, 0]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment