Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
69eb52cd
Commit
69eb52cd
authored
Mar 15, 2021
by
A. Unique TensorFlower
Browse files
Internal change
PiperOrigin-RevId: 362993979
parent
a4b513b1
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
70 additions
and
0 deletions
+70
-0
official/vision/beta/modeling/maskrcnn_model.py
official/vision/beta/modeling/maskrcnn_model.py
+15
-0
official/vision/beta/modeling/maskrcnn_model_test.py
official/vision/beta/modeling/maskrcnn_model_test.py
+55
-0
No files found.
official/vision/beta/modeling/maskrcnn_model.py
View file @
69eb52cd
...
...
@@ -178,6 +178,7 @@ class MaskRCNNModel(tf.keras.Model):
# Mask head.
raw_masks
=
self
.
mask_head
([
mask_roi_features
,
roi_classes
])
if
training
:
model_outputs
.
update
({
'mask_outputs'
:
raw_masks
,
...
...
@@ -188,6 +189,20 @@ class MaskRCNNModel(tf.keras.Model):
})
return
model_outputs
@
property
def
checkpoint_items
(
self
):
"""Returns a dictionary of items to be additionally checkpointed."""
items
=
dict
(
backbone
=
self
.
backbone
,
rpn_head
=
self
.
rpn_head
,
detection_head
=
self
.
detection_head
)
if
self
.
decoder
is
not
None
:
items
.
update
(
decoder
=
self
.
decoder
)
if
self
.
_include_mask
:
items
.
update
(
mask_head
=
self
.
mask_head
)
return
items
def
get_config
(
self
):
return
self
.
_config_dict
...
...
official/vision/beta/modeling/maskrcnn_model_test.py
View file @
69eb52cd
...
...
@@ -15,6 +15,7 @@
# ==============================================================================
"""Tests for maskrcnn_model.py."""
import
os
# Import libraries
from
absl.testing
import
parameterized
import
numpy
as
np
...
...
@@ -274,6 +275,60 @@ class MaskRCNNModelTest(parameterized.TestCase, tf.test.TestCase):
# If the serialization was successful, the new config should match the old.
self
.
assertAllEqual
(
model
.
get_config
(),
new_model
.
get_config
())
@
parameterized
.
parameters
(
(
False
,),
(
True
,),
)
def
test_checkpoint
(
self
,
include_mask
):
input_specs
=
tf
.
keras
.
layers
.
InputSpec
(
shape
=
[
None
,
None
,
None
,
3
])
backbone
=
resnet
.
ResNet
(
model_id
=
50
,
input_specs
=
input_specs
)
decoder
=
fpn
.
FPN
(
min_level
=
3
,
max_level
=
7
,
input_specs
=
backbone
.
output_specs
)
rpn_head
=
dense_prediction_heads
.
RPNHead
(
min_level
=
3
,
max_level
=
7
,
num_anchors_per_location
=
3
)
detection_head
=
instance_heads
.
DetectionHead
(
num_classes
=
2
)
roi_generator_obj
=
roi_generator
.
MultilevelROIGenerator
()
roi_sampler_obj
=
roi_sampler
.
ROISampler
()
roi_aligner_obj
=
roi_aligner
.
MultilevelROIAligner
()
detection_generator_obj
=
detection_generator
.
DetectionGenerator
()
if
include_mask
:
mask_head
=
instance_heads
.
MaskHead
(
num_classes
=
2
,
upsample_factor
=
2
)
mask_sampler_obj
=
mask_sampler
.
MaskSampler
(
mask_target_size
=
28
,
num_sampled_masks
=
1
)
mask_roi_aligner_obj
=
roi_aligner
.
MultilevelROIAligner
(
crop_size
=
14
)
else
:
mask_head
=
None
mask_sampler_obj
=
None
mask_roi_aligner_obj
=
None
model
=
maskrcnn_model
.
MaskRCNNModel
(
backbone
,
decoder
,
rpn_head
,
detection_head
,
roi_generator_obj
,
roi_sampler_obj
,
roi_aligner_obj
,
detection_generator_obj
,
mask_head
,
mask_sampler_obj
,
mask_roi_aligner_obj
)
expect_checkpoint_items
=
dict
(
backbone
=
backbone
,
decoder
=
decoder
,
rpn_head
=
rpn_head
,
detection_head
=
detection_head
)
if
include_mask
:
expect_checkpoint_items
[
'mask_head'
]
=
mask_head
self
.
assertAllEqual
(
expect_checkpoint_items
,
model
.
checkpoint_items
)
# Test save and load checkpoints.
ckpt
=
tf
.
train
.
Checkpoint
(
model
=
model
,
**
model
.
checkpoint_items
)
save_dir
=
self
.
create_tempdir
().
full_path
ckpt
.
save
(
os
.
path
.
join
(
save_dir
,
'ckpt'
))
partial_ckpt
=
tf
.
train
.
Checkpoint
(
backbone
=
backbone
)
partial_ckpt
.
restore
(
tf
.
train
.
latest_checkpoint
(
save_dir
)).
expect_partial
().
assert_existing_objects_matched
()
if
include_mask
:
partial_ckpt_mask
=
tf
.
train
.
Checkpoint
(
backbone
=
backbone
,
mask_head
=
mask_head
)
partial_ckpt_mask
.
restore
(
tf
.
train
.
latest_checkpoint
(
save_dir
)).
expect_partial
().
assert_existing_objects_matched
()
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment