ModelZoo / ResNet50_tensorflow / Commits

Commit 3abc1d1c, authored Oct 04, 2021 by A. Unique TensorFlower
Parent: 84a65a31

Support None num_boxes and refactor serving.

PiperOrigin-RevId: 400773326
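The two parts of the commit message map onto the diff below: the spatial_transform_ops.py changes read batch and box counts through tf.shape() so the ops keep working when num_boxes is None (unknown at graph-construction time), and the detection.py changes split serving into a reusable preprocess() step plus a thin serve(). A minimal sketch of the static-versus-dynamic shape distinction that motivates the first part; the boxes signature here is a hypothetical example, not code from this commit:

import tensorflow as tf

@tf.function(input_signature=[tf.TensorSpec([2, None, 4], tf.float32)])
def count_boxes(boxes):
  # Static shape: the num_boxes dimension is unknown while tracing, so
  # get_shape().as_list() yields [2, None, 4] and indexing it gives None.
  static_num_boxes = boxes.get_shape().as_list()[1]
  assert static_num_boxes is None
  # Dynamic shape: tf.shape() returns a tensor resolved at run time, which
  # can safely feed ops such as tf.reshape or tf.zeros.
  return tf.shape(boxes)[1]

print(count_boxes(tf.zeros([2, 7, 4])))  # tf.Tensor(7, shape=(), dtype=int32)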
Showing 4 changed files with 62 additions and 31 deletions (+62 -31):

  official/vision/beta/modeling/maskrcnn_model.py                             +28 -17
  official/vision/beta/ops/spatial_transform_ops.py                           +12 -9
  official/vision/beta/projects/deepmac_maskrcnn/modeling/maskrcnn_model.py   +1 -1
  official/vision/beta/serving/detection.py                                   +21 -4
official/vision/beta/modeling/maskrcnn_model.py

@@ -151,7 +151,7 @@ class MaskRCNNModel(tf.keras.Model):
       model_mask_outputs = self._call_mask_outputs(
           model_box_outputs=model_outputs,
-          features=intermediate_outputs['features'],
+          features=model_outputs['decoder_features'],
           current_rois=intermediate_outputs['current_rois'],
           matched_gt_indices=intermediate_outputs['matched_gt_indices'],
           matched_gt_boxes=intermediate_outputs['matched_gt_boxes'],
@@ -161,6 +161,15 @@ class MaskRCNNModel(tf.keras.Model):
       model_outputs.update(model_mask_outputs)
     return model_outputs

+  def _get_backbone_and_decoder_features(self, images):
+    backbone_features = self.backbone(images)
+    if self.decoder:
+      features = self.decoder(backbone_features)
+    else:
+      features = backbone_features
+    return backbone_features, features
+
   def _call_box_outputs(
       self, images: tf.Tensor, image_shape: tf.Tensor,
@@ -173,18 +182,15 @@ class MaskRCNNModel(tf.keras.Model):
     model_outputs = {}

     # Feature extraction.
-    backbone_features = self.backbone(images)
-    if self.decoder:
-      features = self.decoder(backbone_features)
-    else:
-      features = backbone_features
+    (backbone_features,
+     decoder_features) = self._get_backbone_and_decoder_features(images)

     # Region proposal network.
-    rpn_scores, rpn_boxes = self.rpn_head(features)
+    rpn_scores, rpn_boxes = self.rpn_head(decoder_features)

     model_outputs.update({
         'backbone_features': backbone_features,
-        'decoder_features': features,
+        'decoder_features': decoder_features,
         'rpn_boxes': rpn_boxes,
         'rpn_scores': rpn_scores
     })
@@ -219,7 +225,7 @@ class MaskRCNNModel(tf.keras.Model):
     (class_outputs, box_outputs, model_outputs, matched_gt_boxes,
      matched_gt_classes, matched_gt_indices,
      current_rois) = self._run_frcnn_head(
-         features=features,
+         features=decoder_features,
          rois=current_rois,
          gt_boxes=gt_boxes,
          gt_classes=gt_classes,
@@ -270,7 +276,6 @@ class MaskRCNNModel(tf.keras.Model):
         'matched_gt_boxes': matched_gt_boxes,
         'matched_gt_indices': matched_gt_indices,
         'matched_gt_classes': matched_gt_classes,
-        'features': features,
         'current_rois': current_rois,
     }
     return (model_outputs, intermediate_outputs)
@@ -302,19 +307,16 @@ class MaskRCNNModel(tf.keras.Model):
       current_rois = model_outputs['detection_boxes']
       roi_classes = model_outputs['detection_classes']

-    # Mask RoI align.
-    mask_roi_features = self.mask_roi_aligner(features, current_rois)
-
-    # Mask head.
-    raw_masks = self.mask_head([mask_roi_features, roi_classes])
+    mask_logits, mask_probs = self._features_to_mask_outputs(
+        features, current_rois, roi_classes)

     if training:
       model_outputs.update({
-          'mask_outputs': raw_masks,
+          'mask_outputs': mask_logits,
       })
     else:
       model_outputs.update({
-          'detection_masks': tf.math.sigmoid(raw_masks),
+          'detection_masks': mask_probs,
       })
     return model_outputs
@@ -395,6 +397,15 @@ class MaskRCNNModel(tf.keras.Model):
     return (class_outputs, box_outputs, model_outputs, matched_gt_boxes,
             matched_gt_classes, matched_gt_indices, rois)

+  def _features_to_mask_outputs(self, features, rois, roi_classes):
+    # Mask RoI align.
+    mask_roi_features = self.mask_roi_aligner(features, rois)
+
+    # Mask head.
+    raw_masks = self.mask_head([mask_roi_features, roi_classes])
+
+    return raw_masks, tf.nn.sigmoid(raw_masks)
+
   @property
   def checkpoint_items(
       self) -> Mapping[str, Union[tf.keras.Model, tf.keras.layers.Layer]]:
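Net effect of the maskrcnn_model.py hunks: feature extraction and mask prediction move into two private helpers, and the mask branch now reads decoder features back from model_outputs['decoder_features'] instead of the removed intermediate_outputs['features'] entry. A runnable toy that mirrors just the two helpers; every layer and shape below is a stand-in, not the real MaskRCNNModel configuration:

import tensorflow as tf

class ToyMaskRCNN(tf.keras.Model):

  def __init__(self):
    super().__init__()
    # Stand-in backbone/decoder/mask components for illustration only.
    self.backbone = tf.keras.layers.Conv2D(8, 3, padding='same')
    self.decoder = tf.keras.layers.Conv2D(8, 1, padding='same')
    self.mask_roi_aligner = lambda features, rois: tf.zeros([1, 4, 14, 14, 8])
    self.mask_head = lambda inputs: tf.zeros([1, 4, 28, 28])

  # Mirrors the new _get_backbone_and_decoder_features: one place decides
  # whether decoder features or raw backbone features feed the heads.
  def _get_backbone_and_decoder_features(self, images):
    backbone_features = self.backbone(images)
    if self.decoder:
      features = self.decoder(backbone_features)
    else:
      features = backbone_features
    return backbone_features, features

  # Mirrors the new _features_to_mask_outputs: mask logits for training and
  # sigmoid probabilities for inference come from a single helper.
  def _features_to_mask_outputs(self, features, rois, roi_classes):
    mask_roi_features = self.mask_roi_aligner(features, rois)
    raw_masks = self.mask_head([mask_roi_features, roi_classes])
    return raw_masks, tf.nn.sigmoid(raw_masks)

model = ToyMaskRCNN()
_, decoder_features = model._get_backbone_and_decoder_features(
    tf.zeros([1, 64, 64, 3]))
mask_logits, mask_probs = model._features_to_mask_outputs(
    decoder_features, rois=tf.zeros([1, 4, 4]), roi_classes=tf.zeros([1, 4]))
print(mask_logits.shape, mask_probs.shape)  # (1, 4, 28, 28) (1, 4, 28, 28)

Centralizing the logits/probability split in _features_to_mask_outputs is what lets the call path simply record 'mask_outputs' (logits) in training and 'detection_masks' (sigmoid probabilities) at inference, as the hunks above show.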
official/vision/beta/ops/spatial_transform_ops.py

@@ -43,10 +43,11 @@ def _feature_bilinear_interpolation(features, kernel_y, kernel_x):
       [batch_size, num_boxes, output_size, output_size, num_filters].
   """
-  batch_size, num_boxes, output_size, _, num_filters = (
-      features.get_shape().as_list())
-  if batch_size is None:
-    batch_size = tf.shape(features)[0]
+  features_shape = tf.shape(features)
+  batch_size, num_boxes, output_size, num_filters = (
+      features_shape[0], features_shape[1], features_shape[2],
+      features_shape[4])
+
   output_size = output_size // 2
   kernel_y = tf.reshape(kernel_y, [batch_size, num_boxes, output_size * 2, 1])
   kernel_x = tf.reshape(kernel_x, [batch_size, num_boxes, 1, output_size * 2])
@@ -88,7 +89,8 @@ def _compute_grid_positions(boxes, boundaries, output_size, sample_offset):
     box_grid_y0y1: Tensor of size [batch_size, boxes, output_size, 2]
     box_grid_x0x1: Tensor of size [batch_size, boxes, output_size, 2]
   """
-  batch_size, num_boxes, _ = boxes.get_shape().as_list()
+  boxes_shape = tf.shape(boxes)
+  batch_size, num_boxes = boxes_shape[0], boxes_shape[1]
   if batch_size is None:
     batch_size = tf.shape(boxes)[0]
   box_grid_x = []
@@ -161,11 +163,12 @@ def multilevel_crop_and_resize(features,
     levels = list(features.keys())
     min_level = int(min(levels))
    max_level = int(max(levels))
-    batch_size, max_feature_height, max_feature_width, num_filters = (
-        features[str(min_level)].get_shape().as_list())
-    if batch_size is None:
-      batch_size = tf.shape(features[str(min_level)])[0]
-    _, num_boxes, _ = boxes.get_shape().as_list()
+    features_shape = tf.shape(features[str(min_level)])
+    batch_size, max_feature_height, max_feature_width, num_filters = (
+        features_shape[0], features_shape[1], features_shape[2],
+        features_shape[3])
+    num_boxes = tf.shape(boxes)[1]

     # Stack feature pyramid into a features_all of shape
     # [batch_size, levels, height, width, num_filters].
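All three hunks in this file follow the same recipe: take sizes from tf.shape() instead of get_shape().as_list(), so a None num_boxes (or batch) dimension no longer needs special-casing. A small standalone example of the recipe, with a made-up RoI-feature shape rather than anything from this file:

import tensorflow as tf

# num_boxes (and batch) are left as None in the signature, as they would be in
# an exported serving graph; the reshape still works because its target shape
# is assembled from dynamic tf.shape() values.
@tf.function(
    input_signature=[tf.TensorSpec([None, None, 7, 7, 16], tf.float32)])
def flatten_rois(roi_features):
  shape = tf.shape(roi_features)
  batch_size, num_boxes = shape[0], shape[1]
  return tf.reshape(roi_features, [batch_size * num_boxes, 7, 7, 16])

print(flatten_rois(tf.zeros([2, 5, 7, 7, 16])).shape)  # (10, 7, 7, 16)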
official/vision/beta/projects/deepmac_maskrcnn/modeling/maskrcnn_model.py

@@ -131,7 +131,7 @@ class DeepMaskRCNNModel(maskrcnn_model.MaskRCNNModel):
       model_mask_outputs = self._call_mask_outputs(
           model_box_outputs=model_outputs,
-          features=intermediate_outputs['features'],
+          features=model_outputs['decoder_features'],
           current_rois=intermediate_outputs['current_rois'],
           matched_gt_indices=intermediate_outputs['matched_gt_indices'],
           matched_gt_boxes=intermediate_outputs['matched_gt_boxes'],
official/vision/beta/serving/detection.py

@@ -15,6 +15,7 @@
 # Lint as: python3
 """Detection input and model functions for serving/inference."""
+from typing import Mapping, Text

 import tensorflow as tf

 from official.vision.beta import configs
@@ -78,13 +79,17 @@ class DetectionModule(export_base.ExportModule):
     return image, anchor_boxes, image_info

-  def serve(self, images: tf.Tensor):
-    """Cast image to float and run inference.
+  def preprocess(
+      self, images: tf.Tensor
+  ) -> (tf.Tensor, Mapping[Text, tf.Tensor], tf.Tensor):
+    """Preprocess inputs to be suitable for the model.

     Args:
-      images: uint8 Tensor of shape [batch_size, None, None, 3]
+      images: The images tensor.

     Returns:
-      Tensor holding detection output logits.
+      images: The images tensor cast to float.
+      anchor_boxes: Dict mapping anchor levels to anchor boxes.
+      image_info: Tensor containing the details of the image resizing.
     """
     model_params = self.params.task.model
     with tf.device('cpu:0'):
@@ -117,6 +122,18 @@ class DetectionModule(export_base.ExportModule):
             image_info_spec),
         parallel_iterations=32))

+    return images, anchor_boxes, image_info
+
+  def serve(self, images: tf.Tensor):
+    """Cast image to float and run inference.
+
+    Args:
+      images: uint8 Tensor of shape [batch_size, None, None, 3]
+
+    Returns:
+      Tensor holding detection output logits.
+    """
+    images, anchor_boxes, image_info = self.preprocess(images)
     input_image_shape = image_info[:, 1, :]

     # To overcome keras.Model extra limitation to save a model with layers that
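The serving refactor splits the old serve() into a preprocess() method (cast, resize, anchor generation) and a thin serve() that calls it, so the preprocessing can be reused or tested on its own. A generic illustration of the same pattern on a toy tf.Module; this is not the DetectionModule API, and the toy "model body" below is a stand-in:

import tensorflow as tf

class ToyServingModule(tf.Module):

  def preprocess(self, images):
    # Cast and normalize once, and return whatever metadata inference needs.
    images = tf.cast(images, tf.float32) / 255.0
    image_info = tf.shape(images)[1:3]
    return images, image_info

  @tf.function(input_signature=[tf.TensorSpec([1, None, None, 3], tf.uint8)])
  def serve(self, images):
    # serve() stays a thin wrapper: preprocess, then run the model body.
    images, image_info = self.preprocess(images)
    logits = tf.reduce_mean(images, axis=[1, 2])  # stand-in for detection outputs
    return {'logits': logits, 'image_info': image_info}

module = ToyServingModule()
outputs = module.serve(tf.zeros([1, 64, 64, 3], tf.uint8))
print(outputs['logits'].shape)  # (1, 3)

Because serve() carries a tf.function input signature, such a module can be exported directly, e.g. tf.saved_model.save(module, '/tmp/toy_export', signatures={'serving_default': module.serve}).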