Commit 3dcc078a authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 362993979
parent 3b6c3d10
......@@ -178,6 +178,7 @@ class MaskRCNNModel(tf.keras.Model):
# Mask head.
raw_masks = self.mask_head([mask_roi_features, roi_classes])
if training:
model_outputs.update({
'mask_outputs': raw_masks,
......@@ -188,6 +189,20 @@ class MaskRCNNModel(tf.keras.Model):
})
return model_outputs
@property
def checkpoint_items(self):
"""Returns a dictionary of items to be additionally checkpointed."""
items = dict(
backbone=self.backbone,
rpn_head=self.rpn_head,
detection_head=self.detection_head)
if self.decoder is not None:
items.update(decoder=self.decoder)
if self._include_mask:
items.update(mask_head=self.mask_head)
return items
def get_config(self):
return self._config_dict
......
......@@ -15,6 +15,7 @@
# ==============================================================================
"""Tests for maskrcnn_model.py."""
import os
# Import libraries
from absl.testing import parameterized
import numpy as np
......@@ -274,6 +275,60 @@ class MaskRCNNModelTest(parameterized.TestCase, tf.test.TestCase):
# If the serialization was successful, the new config should match the old.
self.assertAllEqual(model.get_config(), new_model.get_config())
@parameterized.parameters(
(False,),
(True,),
)
def test_checkpoint(self, include_mask):
input_specs = tf.keras.layers.InputSpec(shape=[None, None, None, 3])
backbone = resnet.ResNet(model_id=50, input_specs=input_specs)
decoder = fpn.FPN(
min_level=3, max_level=7, input_specs=backbone.output_specs)
rpn_head = dense_prediction_heads.RPNHead(
min_level=3, max_level=7, num_anchors_per_location=3)
detection_head = instance_heads.DetectionHead(num_classes=2)
roi_generator_obj = roi_generator.MultilevelROIGenerator()
roi_sampler_obj = roi_sampler.ROISampler()
roi_aligner_obj = roi_aligner.MultilevelROIAligner()
detection_generator_obj = detection_generator.DetectionGenerator()
if include_mask:
mask_head = instance_heads.MaskHead(num_classes=2, upsample_factor=2)
mask_sampler_obj = mask_sampler.MaskSampler(
mask_target_size=28, num_sampled_masks=1)
mask_roi_aligner_obj = roi_aligner.MultilevelROIAligner(crop_size=14)
else:
mask_head = None
mask_sampler_obj = None
mask_roi_aligner_obj = None
model = maskrcnn_model.MaskRCNNModel(backbone, decoder, rpn_head,
detection_head, roi_generator_obj,
roi_sampler_obj, roi_aligner_obj,
detection_generator_obj, mask_head,
mask_sampler_obj, mask_roi_aligner_obj)
expect_checkpoint_items = dict(
backbone=backbone,
decoder=decoder,
rpn_head=rpn_head,
detection_head=detection_head)
if include_mask:
expect_checkpoint_items['mask_head'] = mask_head
self.assertAllEqual(expect_checkpoint_items, model.checkpoint_items)
# Test save and load checkpoints.
ckpt = tf.train.Checkpoint(model=model, **model.checkpoint_items)
save_dir = self.create_tempdir().full_path
ckpt.save(os.path.join(save_dir, 'ckpt'))
partial_ckpt = tf.train.Checkpoint(backbone=backbone)
partial_ckpt.restore(tf.train.latest_checkpoint(
save_dir)).expect_partial().assert_existing_objects_matched()
if include_mask:
partial_ckpt_mask = tf.train.Checkpoint(
backbone=backbone, mask_head=mask_head)
partial_ckpt_mask.restore(tf.train.latest_checkpoint(
save_dir)).expect_partial().assert_existing_objects_matched()
if __name__ == '__main__':
tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment