Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
cf5732d4
Commit
cf5732d4
authored
Jun 16, 2021
by
A. Unique TensorFlower
Browse files
Merge pull request #10052 from tensorflow:panoptic-segmentation
PiperOrigin-RevId: 379834961
parents
25eaaf7f
35c3a79f
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
724 additions
and
3 deletions
+724
-3
official/vision/beta/modeling/maskrcnn_model.py
official/vision/beta/modeling/maskrcnn_model.py
+6
-3
official/vision/beta/projects/panoptic_maskrcnn/README.md
official/vision/beta/projects/panoptic_maskrcnn/README.md
+20
-0
official/vision/beta/projects/panoptic_maskrcnn/__init__.py
official/vision/beta/projects/panoptic_maskrcnn/__init__.py
+27
-0
official/vision/beta/projects/panoptic_maskrcnn/modeling/panoptic_maskrcnn_model.py
...cts/panoptic_maskrcnn/modeling/panoptic_maskrcnn_model.py
+182
-0
official/vision/beta/projects/panoptic_maskrcnn/modeling/panoptic_maskrcnn_model_test.py
...anoptic_maskrcnn/modeling/panoptic_maskrcnn_model_test.py
+489
-0
No files found.
official/vision/beta/modeling/maskrcnn_model.py
View file @
cf5732d4
...
@@ -16,7 +16,6 @@
...
@@ -16,7 +16,6 @@
from
typing
import
Any
,
List
,
Mapping
,
Optional
,
Union
from
typing
import
Any
,
List
,
Mapping
,
Optional
,
Union
# Import libraries
import
tensorflow
as
tf
import
tensorflow
as
tf
from
official.vision.beta.ops
import
anchor
from
official.vision.beta.ops
import
anchor
...
@@ -147,14 +146,18 @@ class MaskRCNNModel(tf.keras.Model):
...
@@ -147,14 +146,18 @@ class MaskRCNNModel(tf.keras.Model):
model_outputs
=
{}
model_outputs
=
{}
# Feature extraction.
# Feature extraction.
features
=
self
.
backbone
(
images
)
backbone_
features
=
self
.
backbone
(
images
)
if
self
.
decoder
:
if
self
.
decoder
:
features
=
self
.
decoder
(
features
)
features
=
self
.
decoder
(
backbone_features
)
else
:
features
=
backbone_features
# Region proposal network.
# Region proposal network.
rpn_scores
,
rpn_boxes
=
self
.
rpn_head
(
features
)
rpn_scores
,
rpn_boxes
=
self
.
rpn_head
(
features
)
model_outputs
.
update
({
model_outputs
.
update
({
'backbone_features'
:
backbone_features
,
'decoder_features'
:
features
,
'rpn_boxes'
:
rpn_boxes
,
'rpn_boxes'
:
rpn_boxes
,
'rpn_scores'
:
rpn_scores
'rpn_scores'
:
rpn_scores
})
})
...
...
official/vision/beta/projects/panoptic_maskrcnn/README.md
0 → 100644
View file @
cf5732d4
# Panoptic Segmentation
## Description
Panoptic Segmentation combines the two distinct vision tasks - semantic
segmentation and instance segmentation. These tasks are unified such that, each
pixel in the image is assigned the label of the class it belongs to, and also
the instance identifier of the object it a part of.
## Environment setup
The code can be run on multiple GPUs or TPUs with different distribution
strategies. See the TensorFlow distributed training
[
guide
](
https://www.tensorflow.org/guide/distributed_training
)
for an overview
of
`tf.distribute`
.
The code is compatible with TensorFlow 2.4+. See requirements.txt for all
prerequisites, and you can also install them using the following command.
`pip
install -r ./official/requirements.txt`
**DISCLAIMER**
: Panoptic MaskRCNN is still under active development, stay tuned!
official/vision/beta/projects/panoptic_maskrcnn/__init__.py
0 → 100644
View file @
cf5732d4
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
official/vision/beta/projects/panoptic_maskrcnn/modeling/panoptic_maskrcnn_model.py
0 → 100644
View file @
cf5732d4
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Panoptic Segmentation model."""
from
typing
import
List
,
Mapping
,
Optional
,
Union
import
tensorflow
as
tf
from
official.vision.beta.modeling
import
maskrcnn_model
@
tf
.
keras
.
utils
.
register_keras_serializable
(
package
=
'Vision'
)
class
PanopticMaskRCNNModel
(
maskrcnn_model
.
MaskRCNNModel
):
"""The Panoptic Segmentation model."""
def
__init__
(
self
,
backbone
:
tf
.
keras
.
Model
,
decoder
:
tf
.
keras
.
Model
,
rpn_head
:
tf
.
keras
.
layers
.
Layer
,
detection_head
:
Union
[
tf
.
keras
.
layers
.
Layer
,
List
[
tf
.
keras
.
layers
.
Layer
]],
roi_generator
:
tf
.
keras
.
layers
.
Layer
,
roi_sampler
:
Union
[
tf
.
keras
.
layers
.
Layer
,
List
[
tf
.
keras
.
layers
.
Layer
]],
roi_aligner
:
tf
.
keras
.
layers
.
Layer
,
detection_generator
:
tf
.
keras
.
layers
.
Layer
,
mask_head
:
Optional
[
tf
.
keras
.
layers
.
Layer
]
=
None
,
mask_sampler
:
Optional
[
tf
.
keras
.
layers
.
Layer
]
=
None
,
mask_roi_aligner
:
Optional
[
tf
.
keras
.
layers
.
Layer
]
=
None
,
segmentation_backbone
:
Optional
[
tf
.
keras
.
Model
]
=
None
,
segmentation_decoder
:
Optional
[
tf
.
keras
.
Model
]
=
None
,
segmentation_head
:
tf
.
keras
.
layers
.
Layer
=
None
,
class_agnostic_bbox_pred
:
bool
=
False
,
cascade_class_ensemble
:
bool
=
False
,
min_level
:
Optional
[
int
]
=
None
,
max_level
:
Optional
[
int
]
=
None
,
num_scales
:
Optional
[
int
]
=
None
,
aspect_ratios
:
Optional
[
List
[
float
]]
=
None
,
anchor_size
:
Optional
[
float
]
=
None
,
**
kwargs
):
"""Initializes the Panoptic Mask R-CNN model.
Args:
backbone: `tf.keras.Model`, the backbone network.
decoder: `tf.keras.Model`, the decoder network.
rpn_head: the RPN head.
detection_head: the detection head or a list of heads.
roi_generator: the ROI generator.
roi_sampler: a single ROI sampler or a list of ROI samplers for cascade
detection heads.
roi_aligner: the ROI aligner.
detection_generator: the detection generator.
mask_head: the mask head.
mask_sampler: the mask sampler.
mask_roi_aligner: the ROI alginer for mask prediction.
segmentation_backbone: `tf.keras.Model`, the backbone network for the
segmentation head for panoptic task. Providing `segmentation_backbone`
will allow the segmentation head to use a standlone backbone. Setting
`segmentation_backbone=None` would enable backbone sharing between the
MaskRCNN model and segmentation head.
segmentation_decoder: `tf.keras.Model`, the decoder network for the
segmentation head for panoptic task. Providing `segmentation_decoder`
will allow the segmentation head to use a standlone decoder. Setting
`segmentation_decoder=None` would enable decoder sharing between the
MaskRCNN model and segmentation head. Decoders can only be shared when
`segmentation_backbone` is shared as well.
segmentation_head: segmentatation head for panoptic task.
class_agnostic_bbox_pred: if True, perform class agnostic bounding box
prediction. Needs to be `True` for Cascade RCNN models.
cascade_class_ensemble: if True, ensemble classification scores over all
detection heads.
min_level: Minimum level in output feature maps.
max_level: Maximum level in output feature maps.
num_scales: A number representing intermediate scales added on each level.
For instances, num_scales=2 adds one additional intermediate anchor
scales [2^0, 2^0.5] on each level.
aspect_ratios: A list representing the aspect raito anchors added on each
level. The number indicates the ratio of width to height. For instances,
aspect_ratios=[1.0, 2.0, 0.5] adds three anchors on each scale level.
anchor_size: A number representing the scale of size of the base anchor to
the feature stride 2^level.
**kwargs: keyword arguments to be passed.
"""
super
(
PanopticMaskRCNNModel
,
self
).
__init__
(
backbone
=
backbone
,
decoder
=
decoder
,
rpn_head
=
rpn_head
,
detection_head
=
detection_head
,
roi_generator
=
roi_generator
,
roi_sampler
=
roi_sampler
,
roi_aligner
=
roi_aligner
,
detection_generator
=
detection_generator
,
mask_head
=
mask_head
,
mask_sampler
=
mask_sampler
,
mask_roi_aligner
=
mask_roi_aligner
,
class_agnostic_bbox_pred
=
class_agnostic_bbox_pred
,
cascade_class_ensemble
=
cascade_class_ensemble
,
min_level
=
min_level
,
max_level
=
max_level
,
num_scales
=
num_scales
,
aspect_ratios
=
aspect_ratios
,
anchor_size
=
anchor_size
,
**
kwargs
)
self
.
_config_dict
.
update
({
'segmentation_backbone'
:
segmentation_backbone
,
'segmentation_decoder'
:
segmentation_decoder
,
'segmentation_head'
:
segmentation_head
})
if
not
self
.
_include_mask
:
raise
ValueError
(
'`mask_head` needs to be provided for Panoptic Mask R-CNN.'
)
if
segmentation_backbone
is
not
None
and
segmentation_decoder
is
None
:
raise
ValueError
(
'`segmentation_decoder` needs to be provided for Panoptic Mask R-CNN'
'if `backbone` is not shared.'
)
self
.
segmentation_backbone
=
segmentation_backbone
self
.
segmentation_decoder
=
segmentation_decoder
self
.
segmentation_head
=
segmentation_head
def
call
(
self
,
images
:
tf
.
Tensor
,
image_shape
:
tf
.
Tensor
,
anchor_boxes
:
Optional
[
Mapping
[
str
,
tf
.
Tensor
]]
=
None
,
gt_boxes
:
Optional
[
tf
.
Tensor
]
=
None
,
gt_classes
:
Optional
[
tf
.
Tensor
]
=
None
,
gt_masks
:
Optional
[
tf
.
Tensor
]
=
None
,
training
:
Optional
[
bool
]
=
None
)
->
Mapping
[
str
,
tf
.
Tensor
]:
model_outputs
=
super
(
PanopticMaskRCNNModel
,
self
).
call
(
images
=
images
,
image_shape
=
image_shape
,
anchor_boxes
=
anchor_boxes
,
gt_boxes
=
gt_boxes
,
gt_classes
=
gt_classes
,
gt_masks
=
gt_masks
,
training
=
training
)
if
self
.
segmentation_backbone
is
not
None
:
backbone_features
=
self
.
segmentation_backbone
(
images
,
training
=
training
)
else
:
backbone_features
=
model_outputs
[
'backbone_features'
]
if
self
.
segmentation_decoder
is
not
None
:
decoder_features
=
self
.
segmentation_decoder
(
backbone_features
,
training
=
training
)
else
:
decoder_features
=
model_outputs
[
'decoder_features'
]
segmentation_outputs
=
self
.
segmentation_head
(
backbone_features
,
decoder_features
,
training
=
training
)
model_outputs
.
update
({
'segmentation_outputs'
:
segmentation_outputs
,
})
return
model_outputs
@
property
def
checkpoint_items
(
self
)
->
Mapping
[
str
,
Union
[
tf
.
keras
.
Model
,
tf
.
keras
.
layers
.
Layer
]]:
"""Returns a dictionary of items to be additionally checkpointed."""
items
=
super
(
PanopticMaskRCNNModel
,
self
).
checkpoint_items
if
self
.
segmentation_backbone
is
not
None
:
items
.
update
(
segmentation_backbone
=
self
.
segmentation_backbone
)
if
self
.
segmentation_decoder
is
not
None
:
items
.
update
(
segmentation_decoder
=
self
.
segmentation_decoder
)
items
.
update
(
segmentation_head
=
self
.
segmentation_head
)
return
items
official/vision/beta/projects/panoptic_maskrcnn/modeling/panoptic_maskrcnn_model_test.py
0 → 100644
View file @
cf5732d4
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for panoptic_maskrcnn_model.py."""
import
os
from
absl.testing
import
parameterized
import
numpy
as
np
import
tensorflow
as
tf
from
tensorflow.python.distribute
import
combinations
from
tensorflow.python.distribute
import
strategy_combinations
from
official.vision.beta.modeling.backbones
import
resnet
from
official.vision.beta.modeling.decoders
import
aspp
from
official.vision.beta.modeling.decoders
import
fpn
from
official.vision.beta.modeling.heads
import
dense_prediction_heads
from
official.vision.beta.modeling.heads
import
instance_heads
from
official.vision.beta.modeling.heads
import
segmentation_heads
from
official.vision.beta.modeling.layers
import
detection_generator
from
official.vision.beta.modeling.layers
import
mask_sampler
from
official.vision.beta.modeling.layers
import
roi_aligner
from
official.vision.beta.modeling.layers
import
roi_generator
from
official.vision.beta.modeling.layers
import
roi_sampler
from
official.vision.beta.ops
import
anchor
from
official.vision.beta.projects.panoptic_maskrcnn.modeling
import
panoptic_maskrcnn_model
class
PanopticMaskRCNNModelTest
(
parameterized
.
TestCase
,
tf
.
test
.
TestCase
):
@
combinations
.
generate
(
combinations
.
combine
(
use_separable_conv
=
[
True
,
False
],
build_anchor_boxes
=
[
True
,
False
],
shared_backbone
=
[
True
,
False
],
shared_decoder
=
[
True
,
False
],
is_training
=
[
True
,
False
]))
def
test_build_model
(
self
,
use_separable_conv
,
build_anchor_boxes
,
shared_backbone
,
shared_decoder
,
is_training
=
True
):
num_classes
=
3
min_level
=
3
max_level
=
7
num_scales
=
3
aspect_ratios
=
[
1.0
]
anchor_size
=
3
resnet_model_id
=
50
segmentation_resnet_model_id
=
50
segmentation_output_stride
=
16
aspp_dilation_rates
=
[
6
,
12
,
18
]
aspp_decoder_level
=
int
(
np
.
math
.
log2
(
segmentation_output_stride
))
fpn_decoder_level
=
3
num_anchors_per_location
=
num_scales
*
len
(
aspect_ratios
)
image_size
=
128
images
=
np
.
random
.
rand
(
2
,
image_size
,
image_size
,
3
)
image_shape
=
np
.
array
([[
image_size
,
image_size
],
[
image_size
,
image_size
]])
shared_decoder
=
shared_decoder
and
shared_backbone
if
build_anchor_boxes
:
anchor_boxes
=
anchor
.
Anchor
(
min_level
=
min_level
,
max_level
=
max_level
,
num_scales
=
num_scales
,
aspect_ratios
=
aspect_ratios
,
anchor_size
=
3
,
image_size
=
(
image_size
,
image_size
)).
multilevel_boxes
for
l
in
anchor_boxes
:
anchor_boxes
[
l
]
=
tf
.
tile
(
tf
.
expand_dims
(
anchor_boxes
[
l
],
axis
=
0
),
[
2
,
1
,
1
,
1
])
else
:
anchor_boxes
=
None
backbone
=
resnet
.
ResNet
(
model_id
=
resnet_model_id
)
decoder
=
fpn
.
FPN
(
input_specs
=
backbone
.
output_specs
,
min_level
=
min_level
,
max_level
=
max_level
,
use_separable_conv
=
use_separable_conv
)
rpn_head
=
dense_prediction_heads
.
RPNHead
(
min_level
=
min_level
,
max_level
=
max_level
,
num_anchors_per_location
=
num_anchors_per_location
,
num_convs
=
1
)
detection_head
=
instance_heads
.
DetectionHead
(
num_classes
=
num_classes
)
roi_generator_obj
=
roi_generator
.
MultilevelROIGenerator
()
roi_sampler_obj
=
roi_sampler
.
ROISampler
()
roi_aligner_obj
=
roi_aligner
.
MultilevelROIAligner
()
detection_generator_obj
=
detection_generator
.
DetectionGenerator
()
mask_head
=
instance_heads
.
MaskHead
(
num_classes
=
num_classes
,
upsample_factor
=
2
)
mask_sampler_obj
=
mask_sampler
.
MaskSampler
(
mask_target_size
=
28
,
num_sampled_masks
=
1
)
mask_roi_aligner_obj
=
roi_aligner
.
MultilevelROIAligner
(
crop_size
=
14
)
if
shared_backbone
:
segmentation_backbone
=
None
else
:
segmentation_backbone
=
resnet
.
ResNet
(
model_id
=
segmentation_resnet_model_id
)
if
not
shared_decoder
:
level
=
aspp_decoder_level
segmentation_decoder
=
aspp
.
ASPP
(
level
=
level
,
dilation_rates
=
aspp_dilation_rates
)
else
:
level
=
fpn_decoder_level
segmentation_decoder
=
None
segmentation_head
=
segmentation_heads
.
SegmentationHead
(
num_classes
=
2
,
# stuff and common class for things,
level
=
level
,
num_convs
=
2
)
model
=
panoptic_maskrcnn_model
.
PanopticMaskRCNNModel
(
backbone
,
decoder
,
rpn_head
,
detection_head
,
roi_generator_obj
,
roi_sampler_obj
,
roi_aligner_obj
,
detection_generator_obj
,
mask_head
,
mask_sampler_obj
,
mask_roi_aligner_obj
,
segmentation_backbone
=
segmentation_backbone
,
segmentation_decoder
=
segmentation_decoder
,
segmentation_head
=
segmentation_head
,
min_level
=
min_level
,
max_level
=
max_level
,
num_scales
=
num_scales
,
aspect_ratios
=
aspect_ratios
,
anchor_size
=
anchor_size
)
gt_boxes
=
np
.
array
(
[[[
10
,
10
,
15
,
15
],
[
2.5
,
2.5
,
7.5
,
7.5
],
[
-
1
,
-
1
,
-
1
,
-
1
]],
[[
100
,
100
,
150
,
150
],
[
-
1
,
-
1
,
-
1
,
-
1
],
[
-
1
,
-
1
,
-
1
,
-
1
]]],
dtype
=
np
.
float32
)
gt_classes
=
np
.
array
([[
2
,
1
,
-
1
],
[
1
,
-
1
,
-
1
]],
dtype
=
np
.
int32
)
gt_masks
=
np
.
ones
((
2
,
3
,
100
,
100
))
# Results will be checked in test_forward.
_
=
model
(
images
,
image_shape
,
anchor_boxes
,
gt_boxes
,
gt_classes
,
gt_masks
,
training
=
is_training
)
@
combinations
.
generate
(
combinations
.
combine
(
strategy
=
[
strategy_combinations
.
cloud_tpu_strategy
,
strategy_combinations
.
one_device_strategy_gpu
,
],
shared_backbone
=
[
True
,
False
],
shared_decoder
=
[
True
,
False
],
training
=
[
True
,
False
],
))
def
test_forward
(
self
,
strategy
,
training
,
shared_backbone
,
shared_decoder
):
num_classes
=
3
min_level
=
3
max_level
=
4
num_scales
=
3
aspect_ratios
=
[
1.0
]
anchor_size
=
3
segmentation_resnet_model_id
=
101
segmentation_output_stride
=
16
aspp_dilation_rates
=
[
6
,
12
,
18
]
aspp_decoder_level
=
int
(
np
.
math
.
log2
(
segmentation_output_stride
))
fpn_decoder_level
=
3
class_agnostic_bbox_pred
=
False
cascade_class_ensemble
=
False
image_size
=
(
256
,
256
)
images
=
np
.
random
.
rand
(
2
,
image_size
[
0
],
image_size
[
1
],
3
)
image_shape
=
np
.
array
([[
224
,
100
],
[
100
,
224
]])
shared_decoder
=
shared_decoder
and
shared_backbone
with
strategy
.
scope
():
anchor_boxes
=
anchor
.
Anchor
(
min_level
=
min_level
,
max_level
=
max_level
,
num_scales
=
num_scales
,
aspect_ratios
=
aspect_ratios
,
anchor_size
=
anchor_size
,
image_size
=
image_size
).
multilevel_boxes
num_anchors_per_location
=
len
(
aspect_ratios
)
*
num_scales
input_specs
=
tf
.
keras
.
layers
.
InputSpec
(
shape
=
[
None
,
None
,
None
,
3
])
backbone
=
resnet
.
ResNet
(
model_id
=
50
,
input_specs
=
input_specs
)
decoder
=
fpn
.
FPN
(
min_level
=
min_level
,
max_level
=
max_level
,
input_specs
=
backbone
.
output_specs
)
rpn_head
=
dense_prediction_heads
.
RPNHead
(
min_level
=
min_level
,
max_level
=
max_level
,
num_anchors_per_location
=
num_anchors_per_location
)
detection_head
=
instance_heads
.
DetectionHead
(
num_classes
=
num_classes
,
class_agnostic_bbox_pred
=
class_agnostic_bbox_pred
)
roi_generator_obj
=
roi_generator
.
MultilevelROIGenerator
()
roi_sampler_cascade
=
[]
roi_sampler_obj
=
roi_sampler
.
ROISampler
()
roi_sampler_cascade
.
append
(
roi_sampler_obj
)
roi_aligner_obj
=
roi_aligner
.
MultilevelROIAligner
()
detection_generator_obj
=
detection_generator
.
DetectionGenerator
()
mask_head
=
instance_heads
.
MaskHead
(
num_classes
=
num_classes
,
upsample_factor
=
2
)
mask_sampler_obj
=
mask_sampler
.
MaskSampler
(
mask_target_size
=
28
,
num_sampled_masks
=
1
)
mask_roi_aligner_obj
=
roi_aligner
.
MultilevelROIAligner
(
crop_size
=
14
)
if
shared_backbone
:
segmentation_backbone
=
None
else
:
segmentation_backbone
=
resnet
.
ResNet
(
model_id
=
segmentation_resnet_model_id
)
if
not
shared_decoder
:
level
=
aspp_decoder_level
segmentation_decoder
=
aspp
.
ASPP
(
level
=
level
,
dilation_rates
=
aspp_dilation_rates
)
else
:
level
=
fpn_decoder_level
segmentation_decoder
=
None
segmentation_head
=
segmentation_heads
.
SegmentationHead
(
num_classes
=
2
,
# stuff and common class for things,
level
=
level
,
num_convs
=
2
)
model
=
panoptic_maskrcnn_model
.
PanopticMaskRCNNModel
(
backbone
,
decoder
,
rpn_head
,
detection_head
,
roi_generator_obj
,
roi_sampler_obj
,
roi_aligner_obj
,
detection_generator_obj
,
mask_head
,
mask_sampler_obj
,
mask_roi_aligner_obj
,
segmentation_backbone
=
segmentation_backbone
,
segmentation_decoder
=
segmentation_decoder
,
segmentation_head
=
segmentation_head
,
class_agnostic_bbox_pred
=
class_agnostic_bbox_pred
,
cascade_class_ensemble
=
cascade_class_ensemble
,
min_level
=
min_level
,
max_level
=
max_level
,
num_scales
=
num_scales
,
aspect_ratios
=
aspect_ratios
,
anchor_size
=
anchor_size
)
gt_boxes
=
np
.
array
(
[[[
10
,
10
,
15
,
15
],
[
2.5
,
2.5
,
7.5
,
7.5
],
[
-
1
,
-
1
,
-
1
,
-
1
]],
[[
100
,
100
,
150
,
150
],
[
-
1
,
-
1
,
-
1
,
-
1
],
[
-
1
,
-
1
,
-
1
,
-
1
]]],
dtype
=
np
.
float32
)
gt_classes
=
np
.
array
([[
2
,
1
,
-
1
],
[
1
,
-
1
,
-
1
]],
dtype
=
np
.
int32
)
gt_masks
=
np
.
ones
((
2
,
3
,
100
,
100
))
results
=
model
(
images
,
image_shape
,
anchor_boxes
,
gt_boxes
,
gt_classes
,
gt_masks
,
training
=
training
)
self
.
assertIn
(
'rpn_boxes'
,
results
)
self
.
assertIn
(
'rpn_scores'
,
results
)
if
training
:
self
.
assertIn
(
'class_targets'
,
results
)
self
.
assertIn
(
'box_targets'
,
results
)
self
.
assertIn
(
'class_outputs'
,
results
)
self
.
assertIn
(
'box_outputs'
,
results
)
self
.
assertIn
(
'mask_outputs'
,
results
)
else
:
self
.
assertIn
(
'detection_boxes'
,
results
)
self
.
assertIn
(
'detection_scores'
,
results
)
self
.
assertIn
(
'detection_classes'
,
results
)
self
.
assertIn
(
'num_detections'
,
results
)
self
.
assertIn
(
'detection_masks'
,
results
)
self
.
assertIn
(
'segmentation_outputs'
,
results
)
self
.
assertAllEqual
(
[
2
,
image_size
[
0
]
//
(
2
**
level
),
image_size
[
1
]
//
(
2
**
level
),
2
],
results
[
'segmentation_outputs'
].
numpy
().
shape
)
@
combinations
.
generate
(
combinations
.
combine
(
shared_backbone
=
[
True
,
False
],
shared_decoder
=
[
True
,
False
]))
def
test_serialize_deserialize
(
self
,
shared_backbone
,
shared_decoder
):
input_specs
=
tf
.
keras
.
layers
.
InputSpec
(
shape
=
[
None
,
None
,
None
,
3
])
backbone
=
resnet
.
ResNet
(
model_id
=
50
,
input_specs
=
input_specs
)
decoder
=
fpn
.
FPN
(
min_level
=
3
,
max_level
=
7
,
input_specs
=
backbone
.
output_specs
)
rpn_head
=
dense_prediction_heads
.
RPNHead
(
min_level
=
3
,
max_level
=
7
,
num_anchors_per_location
=
3
)
detection_head
=
instance_heads
.
DetectionHead
(
num_classes
=
2
)
roi_generator_obj
=
roi_generator
.
MultilevelROIGenerator
()
roi_sampler_obj
=
roi_sampler
.
ROISampler
()
roi_aligner_obj
=
roi_aligner
.
MultilevelROIAligner
()
detection_generator_obj
=
detection_generator
.
DetectionGenerator
()
segmentation_resnet_model_id
=
101
segmentation_output_stride
=
16
aspp_dilation_rates
=
[
6
,
12
,
18
]
aspp_decoder_level
=
int
(
np
.
math
.
log2
(
segmentation_output_stride
))
fpn_decoder_level
=
3
shared_decoder
=
shared_decoder
and
shared_backbone
mask_head
=
instance_heads
.
MaskHead
(
num_classes
=
2
,
upsample_factor
=
2
)
mask_sampler_obj
=
mask_sampler
.
MaskSampler
(
mask_target_size
=
28
,
num_sampled_masks
=
1
)
mask_roi_aligner_obj
=
roi_aligner
.
MultilevelROIAligner
(
crop_size
=
14
)
if
shared_backbone
:
segmentation_backbone
=
None
else
:
segmentation_backbone
=
resnet
.
ResNet
(
model_id
=
segmentation_resnet_model_id
)
if
not
shared_decoder
:
level
=
aspp_decoder_level
segmentation_decoder
=
aspp
.
ASPP
(
level
=
level
,
dilation_rates
=
aspp_dilation_rates
)
else
:
level
=
fpn_decoder_level
segmentation_decoder
=
None
segmentation_head
=
segmentation_heads
.
SegmentationHead
(
num_classes
=
2
,
# stuff and common class for things,
level
=
level
,
num_convs
=
2
)
model
=
panoptic_maskrcnn_model
.
PanopticMaskRCNNModel
(
backbone
,
decoder
,
rpn_head
,
detection_head
,
roi_generator_obj
,
roi_sampler_obj
,
roi_aligner_obj
,
detection_generator_obj
,
mask_head
,
mask_sampler_obj
,
mask_roi_aligner_obj
,
segmentation_backbone
=
segmentation_backbone
,
segmentation_decoder
=
segmentation_decoder
,
segmentation_head
=
segmentation_head
,
min_level
=
3
,
max_level
=
7
,
num_scales
=
3
,
aspect_ratios
=
[
1.0
],
anchor_size
=
3
)
config
=
model
.
get_config
()
new_model
=
panoptic_maskrcnn_model
.
PanopticMaskRCNNModel
.
from_config
(
config
)
# Validate that the config can be forced to JSON.
_
=
new_model
.
to_json
()
# If the serialization was successful, the new config should match the old.
self
.
assertAllEqual
(
model
.
get_config
(),
new_model
.
get_config
())
@
combinations
.
generate
(
combinations
.
combine
(
shared_backbone
=
[
True
,
False
],
shared_decoder
=
[
True
,
False
]))
def
test_checkpoint
(
self
,
shared_backbone
,
shared_decoder
):
input_specs
=
tf
.
keras
.
layers
.
InputSpec
(
shape
=
[
None
,
None
,
None
,
3
])
backbone
=
resnet
.
ResNet
(
model_id
=
50
,
input_specs
=
input_specs
)
decoder
=
fpn
.
FPN
(
min_level
=
3
,
max_level
=
7
,
input_specs
=
backbone
.
output_specs
)
rpn_head
=
dense_prediction_heads
.
RPNHead
(
min_level
=
3
,
max_level
=
7
,
num_anchors_per_location
=
3
)
detection_head
=
instance_heads
.
DetectionHead
(
num_classes
=
2
)
roi_generator_obj
=
roi_generator
.
MultilevelROIGenerator
()
roi_sampler_obj
=
roi_sampler
.
ROISampler
()
roi_aligner_obj
=
roi_aligner
.
MultilevelROIAligner
()
detection_generator_obj
=
detection_generator
.
DetectionGenerator
()
segmentation_resnet_model_id
=
101
segmentation_output_stride
=
16
aspp_dilation_rates
=
[
6
,
12
,
18
]
aspp_decoder_level
=
int
(
np
.
math
.
log2
(
segmentation_output_stride
))
fpn_decoder_level
=
3
shared_decoder
=
shared_decoder
and
shared_backbone
mask_head
=
instance_heads
.
MaskHead
(
num_classes
=
2
,
upsample_factor
=
2
)
mask_sampler_obj
=
mask_sampler
.
MaskSampler
(
mask_target_size
=
28
,
num_sampled_masks
=
1
)
mask_roi_aligner_obj
=
roi_aligner
.
MultilevelROIAligner
(
crop_size
=
14
)
if
shared_backbone
:
segmentation_backbone
=
None
else
:
segmentation_backbone
=
resnet
.
ResNet
(
model_id
=
segmentation_resnet_model_id
)
if
not
shared_decoder
:
level
=
aspp_decoder_level
segmentation_decoder
=
aspp
.
ASPP
(
level
=
level
,
dilation_rates
=
aspp_dilation_rates
)
else
:
level
=
fpn_decoder_level
segmentation_decoder
=
None
segmentation_head
=
segmentation_heads
.
SegmentationHead
(
num_classes
=
2
,
# stuff and common class for things,
level
=
level
,
num_convs
=
2
)
model
=
panoptic_maskrcnn_model
.
PanopticMaskRCNNModel
(
backbone
,
decoder
,
rpn_head
,
detection_head
,
roi_generator_obj
,
roi_sampler_obj
,
roi_aligner_obj
,
detection_generator_obj
,
mask_head
,
mask_sampler_obj
,
mask_roi_aligner_obj
,
segmentation_backbone
=
segmentation_backbone
,
segmentation_decoder
=
segmentation_decoder
,
segmentation_head
=
segmentation_head
,
min_level
=
3
,
max_level
=
7
,
num_scales
=
3
,
aspect_ratios
=
[
1.0
],
anchor_size
=
3
)
expect_checkpoint_items
=
dict
(
backbone
=
backbone
,
decoder
=
decoder
,
rpn_head
=
rpn_head
,
detection_head
=
[
detection_head
])
expect_checkpoint_items
[
'mask_head'
]
=
mask_head
if
not
shared_backbone
:
expect_checkpoint_items
[
'segmentation_backbone'
]
=
segmentation_backbone
if
not
shared_decoder
:
expect_checkpoint_items
[
'segmentation_decoder'
]
=
segmentation_decoder
expect_checkpoint_items
[
'segmentation_head'
]
=
segmentation_head
self
.
assertAllEqual
(
expect_checkpoint_items
,
model
.
checkpoint_items
)
# Test save and load checkpoints.
ckpt
=
tf
.
train
.
Checkpoint
(
model
=
model
,
**
model
.
checkpoint_items
)
save_dir
=
self
.
create_tempdir
().
full_path
ckpt
.
save
(
os
.
path
.
join
(
save_dir
,
'ckpt'
))
partial_ckpt
=
tf
.
train
.
Checkpoint
(
backbone
=
backbone
)
partial_ckpt
.
restore
(
tf
.
train
.
latest_checkpoint
(
save_dir
)).
expect_partial
().
assert_existing_objects_matched
()
partial_ckpt_mask
=
tf
.
train
.
Checkpoint
(
backbone
=
backbone
,
mask_head
=
mask_head
)
partial_ckpt_mask
.
restore
(
tf
.
train
.
latest_checkpoint
(
save_dir
)).
expect_partial
().
assert_existing_objects_matched
()
if
not
shared_backbone
:
partial_ckpt_segmentation
=
tf
.
train
.
Checkpoint
(
segmentation_backbone
=
segmentation_backbone
,
segmentation_decoder
=
segmentation_decoder
,
segmentation_head
=
segmentation_head
)
elif
not
shared_decoder
:
partial_ckpt_segmentation
=
tf
.
train
.
Checkpoint
(
segmentation_decoder
=
segmentation_decoder
,
segmentation_head
=
segmentation_head
)
else
:
partial_ckpt_segmentation
=
tf
.
train
.
Checkpoint
(
segmentation_head
=
segmentation_head
)
partial_ckpt_segmentation
.
restore
(
tf
.
train
.
latest_checkpoint
(
save_dir
)).
expect_partial
().
assert_existing_objects_matched
()
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment