Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
48693cad
Commit
48693cad
authored
Nov 21, 2019
by
Pengchong Jin
Committed by
A. Unique TensorFlower
Nov 21, 2019
Browse files
Move get_non_empty_box_indices to box_utils.
PiperOrigin-RevId: 281846940
parent
4c872f63
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
16 additions
and
16 deletions
+16
-16
official/vision/detection/dataloader/maskrcnn_parser.py
official/vision/detection/dataloader/maskrcnn_parser.py
+1
-1
official/vision/detection/dataloader/retinanet_parser.py
official/vision/detection/dataloader/retinanet_parser.py
+3
-3
official/vision/detection/dataloader/shapemask_parser.py
official/vision/detection/dataloader/shapemask_parser.py
+2
-2
official/vision/detection/utils/box_utils.py
official/vision/detection/utils/box_utils.py
+10
-0
official/vision/detection/utils/input_utils.py
official/vision/detection/utils/input_utils.py
+0
-10
No files found.
official/vision/detection/dataloader/maskrcnn_parser.py
View file @
48693cad
...
@@ -234,7 +234,7 @@ class Parser(object):
...
@@ -234,7 +234,7 @@ class Parser(object):
boxes
,
image_scale
,
(
image_height
,
image_width
),
offset
)
boxes
,
image_scale
,
(
image_height
,
image_width
),
offset
)
# Filters out ground truth boxes that are all zeros.
# Filters out ground truth boxes that are all zeros.
indices
=
input
_utils
.
get_non_empty_box_indices
(
boxes
)
indices
=
box
_utils
.
get_non_empty_box_indices
(
boxes
)
boxes
=
tf
.
gather
(
boxes
,
indices
)
boxes
=
tf
.
gather
(
boxes
,
indices
)
classes
=
tf
.
gather
(
classes
,
indices
)
classes
=
tf
.
gather
(
classes
,
indices
)
if
self
.
_include_mask
:
if
self
.
_include_mask
:
...
...
official/vision/detection/dataloader/retinanet_parser.py
View file @
48693cad
...
@@ -251,7 +251,7 @@ class Parser(object):
...
@@ -251,7 +251,7 @@ class Parser(object):
boxes
=
input_utils
.
resize_and_crop_boxes
(
boxes
=
input_utils
.
resize_and_crop_boxes
(
boxes
,
image_scale
,
(
image_height
,
image_width
),
offset
)
boxes
,
image_scale
,
(
image_height
,
image_width
),
offset
)
# Filters out ground truth boxes that are all zeros.
# Filters out ground truth boxes that are all zeros.
indices
=
input
_utils
.
get_non_empty_box_indices
(
boxes
)
indices
=
box
_utils
.
get_non_empty_box_indices
(
boxes
)
boxes
=
tf
.
gather
(
boxes
,
indices
)
boxes
=
tf
.
gather
(
boxes
,
indices
)
classes
=
tf
.
gather
(
classes
,
indices
)
classes
=
tf
.
gather
(
classes
,
indices
)
...
@@ -311,7 +311,7 @@ class Parser(object):
...
@@ -311,7 +311,7 @@ class Parser(object):
boxes
=
input_utils
.
resize_and_crop_boxes
(
boxes
=
input_utils
.
resize_and_crop_boxes
(
boxes
,
image_scale
,
(
image_height
,
image_width
),
offset
)
boxes
,
image_scale
,
(
image_height
,
image_width
),
offset
)
# Filters out ground truth boxes that are all zeros.
# Filters out ground truth boxes that are all zeros.
indices
=
input
_utils
.
get_non_empty_box_indices
(
boxes
)
indices
=
box
_utils
.
get_non_empty_box_indices
(
boxes
)
boxes
=
tf
.
gather
(
boxes
,
indices
)
boxes
=
tf
.
gather
(
boxes
,
indices
)
classes
=
tf
.
gather
(
classes
,
indices
)
classes
=
tf
.
gather
(
classes
,
indices
)
...
@@ -414,7 +414,7 @@ class Parser(object):
...
@@ -414,7 +414,7 @@ class Parser(object):
boxes
=
input_utils
.
resize_and_crop_boxes
(
boxes
=
input_utils
.
resize_and_crop_boxes
(
boxes
,
image_scale
,
(
image_height
,
image_width
),
offset
)
boxes
,
image_scale
,
(
image_height
,
image_width
),
offset
)
# Filters out ground truth boxes that are all zeros.
# Filters out ground truth boxes that are all zeros.
indices
=
input
_utils
.
get_non_empty_box_indices
(
boxes
)
indices
=
box
_utils
.
get_non_empty_box_indices
(
boxes
)
boxes
=
tf
.
gather
(
boxes
,
indices
)
boxes
=
tf
.
gather
(
boxes
,
indices
)
# Assigns anchors.
# Assigns anchors.
...
...
official/vision/detection/dataloader/shapemask_parser.py
View file @
48693cad
...
@@ -268,7 +268,7 @@ class Parser(object):
...
@@ -268,7 +268,7 @@ class Parser(object):
boxes
,
image_scale
,
self
.
_output_size
,
offset
)
boxes
,
image_scale
,
self
.
_output_size
,
offset
)
# Filters out ground truth boxes that are all zeros.
# Filters out ground truth boxes that are all zeros.
indices
=
input
_utils
.
get_non_empty_box_indices
(
boxes
)
indices
=
box
_utils
.
get_non_empty_box_indices
(
boxes
)
boxes
=
tf
.
gather
(
boxes
,
indices
)
boxes
=
tf
.
gather
(
boxes
,
indices
)
classes
=
tf
.
gather
(
classes
,
indices
)
classes
=
tf
.
gather
(
classes
,
indices
)
masks
=
tf
.
gather
(
masks
,
indices
)
masks
=
tf
.
gather
(
masks
,
indices
)
...
@@ -427,7 +427,7 @@ class Parser(object):
...
@@ -427,7 +427,7 @@ class Parser(object):
tf
.
expand_dims
(
masks
,
axis
=-
1
),
image_scale
,
self
.
_output_size
,
offset
)
tf
.
expand_dims
(
masks
,
axis
=-
1
),
image_scale
,
self
.
_output_size
,
offset
)
# Filters out ground truth boxes that are all zeros.
# Filters out ground truth boxes that are all zeros.
indices
=
input
_utils
.
get_non_empty_box_indices
(
boxes
)
indices
=
box
_utils
.
get_non_empty_box_indices
(
boxes
)
boxes
=
tf
.
gather
(
boxes
,
indices
)
boxes
=
tf
.
gather
(
boxes
,
indices
)
classes
=
tf
.
gather
(
classes
,
indices
)
classes
=
tf
.
gather
(
classes
,
indices
)
...
...
official/vision/detection/utils/box_utils.py
View file @
48693cad
...
@@ -523,3 +523,13 @@ def bbox_overlap(boxes, gt_boxes):
...
@@ -523,3 +523,13 @@ def bbox_overlap(boxes, gt_boxes):
iou
=
tf
.
where
(
padding_mask
,
-
tf
.
ones_like
(
iou
),
iou
)
iou
=
tf
.
where
(
padding_mask
,
-
tf
.
ones_like
(
iou
),
iou
)
return
iou
return
iou
def
get_non_empty_box_indices
(
boxes
):
"""Get indices for non-empty boxes."""
# Selects indices if box height or width is 0.
height
=
boxes
[:,
2
]
-
boxes
[:,
0
]
width
=
boxes
[:,
3
]
-
boxes
[:,
1
]
indices
=
tf
.
where
(
tf
.
logical_and
(
tf
.
greater
(
height
,
0
),
tf
.
greater
(
width
,
0
)))
return
indices
[:,
0
]
official/vision/detection/utils/input_utils.py
View file @
48693cad
...
@@ -362,13 +362,3 @@ def resize_and_crop_masks(masks,
...
@@ -362,13 +362,3 @@ def resize_and_crop_masks(masks,
def
random_horizontal_flip
(
image
,
boxes
=
None
,
masks
=
None
):
def
random_horizontal_flip
(
image
,
boxes
=
None
,
masks
=
None
):
"""Randomly flips input image and bounding boxes."""
"""Randomly flips input image and bounding boxes."""
return
preprocessor
.
random_horizontal_flip
(
image
,
boxes
,
masks
)
return
preprocessor
.
random_horizontal_flip
(
image
,
boxes
,
masks
)
def
get_non_empty_box_indices
(
boxes
):
"""Get indices for non-empty boxes."""
# Selects indices if box height or width is 0.
height
=
boxes
[:,
2
]
-
boxes
[:,
0
]
width
=
boxes
[:,
3
]
-
boxes
[:,
1
]
indices
=
tf
.
where
(
tf
.
logical_and
(
tf
.
greater
(
height
,
0
),
tf
.
greater
(
width
,
0
)))
return
indices
[:,
0
]
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment