ModelZoo / ResNet50_tensorflow · Commit e12bd6a5

Authored Nov 01, 2019 by Pengchong Jin; committed by A. Unique TensorFlower, Nov 01, 2019.

Internal change

PiperOrigin-RevId: 277961522
Parent: ada2ed77

Showing 4 changed files with 26 additions and 25 deletions (+26 −25):
official/vision/detection/dataloader/maskrcnn_parser.py      +5 −6
official/vision/detection/modeling/architecture/factory.py   +1 −1
official/vision/detection/modeling/architecture/heads.py     +4 −4
official/vision/detection/ops/sampling_ops.py                +16 −14
official/vision/detection/dataloader/maskrcnn_parser.py

@@ -232,11 +232,6 @@ class Parser(object):
     offset = image_info[3, :]
     boxes = input_utils.resize_and_crop_boxes(
         boxes, image_scale, (image_height, image_width), offset)
-    if self._include_mask:
-      masks = input_utils.resize_and_crop_masks(
-          tf.expand_dims(masks, axis=-1), image_scale,
-          (image_height, image_width), offset)
-      masks = tf.squeeze(masks, axis=-1)
 
     # Filters out ground truth boxes that are all zeros.
     indices = input_utils.get_non_empty_box_indices(boxes)
@@ -244,10 +239,14 @@ class Parser(object):
     classes = tf.gather(classes, indices)
     if self._include_mask:
       masks = tf.gather(masks, indices)
+      cropped_boxes = boxes + tf.cast(
+          tf.tile(tf.expand_dims(offset, axis=0), [1, 2]), dtype=tf.float32)
+      cropped_boxes = box_utils.normalize_boxes(
+          cropped_boxes, image_info[1, :])
       num_masks = tf.shape(masks)[0]
       masks = tf.image.crop_and_resize(
           tf.expand_dims(masks, axis=-1),
-          box_utils.normalize_boxes(boxes, tf.shape(image)[0:2]),
+          cropped_boxes,
           box_indices=tf.range(num_masks, dtype=tf.int32),
           crop_size=[self._mask_crop_size, self._mask_crop_size],
           method='bilinear')
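For reference, here is a minimal standalone sketch (not part of this commit; tensor shapes and box values are made up) of how tf.image.crop_and_resize produces fixed-size per-instance mask crops from normalized boxes, which is the mechanism the updated parser uses via cropped_boxes:

import tensorflow as tf

# Three hypothetical 64x64 instance masks and their normalized boxes.
masks = tf.random.uniform([3, 64, 64])
boxes = tf.constant([[0.0, 0.0, 0.5, 0.5],    # [y1, x1, y2, x2] in [0, 1]
                     [0.25, 0.25, 0.75, 0.75],
                     [0.5, 0.5, 1.0, 1.0]])
mask_crop_size = 28  # stands in for self._mask_crop_size

cropped = tf.image.crop_and_resize(
    tf.expand_dims(masks, axis=-1),  # add a channel dimension
    boxes,
    box_indices=tf.range(tf.shape(masks)[0], dtype=tf.int32),
    crop_size=[mask_crop_size, mask_crop_size],
    method='bilinear')
print(cropped.shape)  # (3, 28, 28, 1)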
official/vision/detection/modeling/architecture/factory.py

@@ -104,7 +104,7 @@ def fast_rcnn_head_generator(params):
 def mask_rcnn_head_generator(params):
   """Generator function for Mask R-CNN head architecture."""
   return heads.MaskrcnnHead(params.num_classes,
-                            params.mrcnn_resolution,
+                            params.mask_target_size,
                             batch_norm_relu=batch_norm_relu_generator(
                                 params.batch_norm))
official/vision/detection/modeling/architecture/heads.py

@@ -177,18 +177,18 @@ class MaskrcnnHead(object):
   def __init__(self,
                num_classes,
-               mrcnn_resolution,
+               mask_target_size,
                batch_norm_relu=nn_ops.BatchNormRelu):
     """Initialize params to build Fast R-CNN head.
 
     Args:
       num_classes: a integer for the number of classes.
-      mrcnn_resolution: a integer that is the resolution of masks.
+      mask_target_size: a integer that is the resolution of masks.
       batch_norm_relu: an operation that includes a batch normalization layer
         followed by a relu layer(optional).
     """
     self._num_classes = num_classes
-    self._mrcnn_resolution = mrcnn_resolution
+    self._mask_target_size = mask_target_size
     self._batch_norm_relu = batch_norm_relu
 
   def __call__(self, roi_features, class_indices, is_training=None):
@@ -272,7 +272,7 @@ class MaskrcnnHead(object):
           name='mask_fcn_logits')(net)
 
       mask_outputs = tf.reshape(
           mask_outputs,
-          [-1, num_rois, self._mrcnn_resolution, self._mrcnn_resolution,
+          [-1, num_rois, self._mask_target_size, self._mask_target_size,
            self._num_classes])
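As a quick illustration of the renamed argument, a hedged sketch of constructing the head directly; the import path is inferred from the file locations above and num_classes is an arbitrary example value. This mirrors what mask_rcnn_head_generator in factory.py now passes from params.mask_target_size:

from official.vision.detection.modeling.architecture import heads, nn_ops

# Example values only; `mask_target_size` replaces the old `mrcnn_resolution`.
mask_head = heads.MaskrcnnHead(
    num_classes=91,
    mask_target_size=28,
    batch_norm_relu=nn_ops.BatchNormRelu)
# Calling the head on RoI features yields mask logits of shape
# [batch_size, num_rois, mask_target_size, mask_target_size, num_classes],
# per the tf.reshape shown in the diff above.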
official/vision/detection/ops/sampling_ops.py

@@ -220,8 +220,8 @@ def sample_and_crop_foreground_masks(candidate_rois,
                                      candidate_gt_classes,
                                      candidate_gt_indices,
                                      gt_masks,
-                                     num_mask_samples_per_image=28,
-                                     cropped_mask_size=28):
+                                     num_mask_samples_per_image=128,
+                                     mask_target_size=28):
   """Samples and creates cropped foreground masks for training.
 
   Args:
@@ -243,7 +243,7 @@ def sample_and_crop_foreground_masks(candidate_rois,
       containing all the groundtruth masks which sample masks are drawn from.
     num_mask_samples_per_image: an integer which specifies the number of masks
       to sample.
-    cropped_mask_size: an integer which specifies the final cropped mask size
+    mask_target_size: an integer which specifies the final cropped mask size
       after sampling. The output masks are resized w.r.t the sampled RoIs.
 
   Returns:
@@ -253,7 +253,7 @@ def sample_and_crop_foreground_masks(candidate_rois,
     foreground_classes: a tensor of shape of [batch_size, K] storing the classes
       corresponding to the sampled foreground masks.
     cropoped_foreground_masks: a tensor of shape of
-      [batch_size, K, cropped_mask_size, cropped_mask_size] storing the cropped
+      [batch_size, K, mask_target_size, mask_target_size] storing the cropped
       foreground masks used for training.
   """
   with tf.name_scope('sample_and_crop_foreground_masks'):
@@ -268,23 +268,25 @@ def sample_and_crop_foreground_masks(candidate_rois,
     gather_nd_instance_indices = tf.stack(
         [batch_indices, fg_instance_indices], axis=-1)
-    foreground_rois = tf.gather_nd(candidate_rois, gather_nd_instance_indices)
+    foreground_rois = tf.gather_nd(
+        candidate_rois, gather_nd_instance_indices)
     foreground_boxes = tf.gather_nd(
         candidate_gt_boxes, gather_nd_instance_indices)
     foreground_classes = tf.gather_nd(
         candidate_gt_classes, gather_nd_instance_indices)
-    fg_gt_indices = tf.gather_nd(
+    foreground_gt_indices = tf.gather_nd(
        candidate_gt_indices, gather_nd_instance_indices)
 
-    fg_gt_indices_shape = tf.shape(fg_gt_indices)
+    foreground_gt_indices_shape = tf.shape(foreground_gt_indices)
     batch_indices = (
-        tf.expand_dims(tf.range(fg_gt_indices_shape[0]), axis=-1) *
-        tf.ones([1, fg_gt_indices_shape[-1]], dtype=tf.int32))
-    gather_nd_gt_indices = tf.stack([batch_indices, fg_gt_indices], axis=-1)
+        tf.expand_dims(tf.range(foreground_gt_indices_shape[0]), axis=-1) *
+        tf.ones([1, foreground_gt_indices_shape[-1]], dtype=tf.int32))
+    gather_nd_gt_indices = tf.stack(
+        [batch_indices, foreground_gt_indices], axis=-1)
     foreground_masks = tf.gather_nd(gt_masks, gather_nd_gt_indices)
 
     cropped_foreground_masks = spatial_transform_ops.crop_mask_in_target_box(
-        foreground_masks, foreground_boxes, foreground_rois, cropped_mask_size)
+        foreground_masks, foreground_boxes, foreground_rois, mask_target_size)
 
     return foreground_rois, foreground_classes, cropped_foreground_masks
@@ -345,7 +347,7 @@ class MaskSampler(object):
   def __init__(self, params):
     self._num_mask_samples_per_image = params.num_mask_samples_per_image
-    self._cropped_mask_size = params.cropped_mask_size
+    self._mask_target_size = params.mask_target_size
 
   def __call__(self,
                candidate_rois,
@@ -381,7 +383,7 @@ class MaskSampler(object):
       foreground_classes: a tensor of shape of [batch_size, K] storing the
         classes corresponding to the sampled foreground masks.
       cropoped_foreground_masks: a tensor of shape of
-        [batch_size, K, cropped_mask_size, cropped_mask_size] storing the
+        [batch_size, K, mask_target_size, mask_target_size] storing the
        cropped foreground masks used for training.
     """
     foreground_rois, foreground_classes, cropped_foreground_masks = (
@@ -392,5 +394,5 @@ class MaskSampler(object):
            candidate_gt_indices,
            gt_masks,
            self._num_mask_samples_per_image,
-            self._cropped_mask_size))
+            self._mask_target_size))
 
    return foreground_rois, foreground_classes, cropped_foreground_masks
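The renamed foreground_gt_indices path uses the same batched tf.gather_nd pattern as before. The following is a small self-contained sketch of that pattern with toy shapes and indices (not taken from the commit):

import tensorflow as tf

candidate_rois = tf.random.uniform([2, 5, 4])        # [batch, N, 4] boxes
fg_instance_indices = tf.constant([[0, 2], [1, 4]])  # [batch, K] selected RoIs

# Pair every per-image index with its batch index, then gather.
batch_indices = (
    tf.expand_dims(tf.range(tf.shape(fg_instance_indices)[0]), axis=-1) *
    tf.ones([1, tf.shape(fg_instance_indices)[-1]], dtype=tf.int32))
gather_nd_indices = tf.stack([batch_indices, fg_instance_indices], axis=-1)

foreground_rois = tf.gather_nd(candidate_rois, gather_nd_indices)
print(foreground_rois.shape)  # (2, 2, 4)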