Commit a34a571d authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 393926742
parent 5e5cdec3
...@@ -384,7 +384,7 @@ class MaskRCNNModelTest(parameterized.TestCase, tf.test.TestCase): ...@@ -384,7 +384,7 @@ class MaskRCNNModelTest(parameterized.TestCase, tf.test.TestCase):
ckpt.save(os.path.join(save_dir, 'ckpt')) ckpt.save(os.path.join(save_dir, 'ckpt'))
partial_ckpt = tf.train.Checkpoint(backbone=backbone) partial_ckpt = tf.train.Checkpoint(backbone=backbone)
partial_ckpt.restore(tf.train.latest_checkpoint( partial_ckpt.read(tf.train.latest_checkpoint(
save_dir)).expect_partial().assert_existing_objects_matched() save_dir)).expect_partial().assert_existing_objects_matched()
if include_mask: if include_mask:
......
...@@ -460,7 +460,7 @@ class PanopticMaskRCNNModelTest(parameterized.TestCase, tf.test.TestCase): ...@@ -460,7 +460,7 @@ class PanopticMaskRCNNModelTest(parameterized.TestCase, tf.test.TestCase):
ckpt.save(os.path.join(save_dir, 'ckpt')) ckpt.save(os.path.join(save_dir, 'ckpt'))
partial_ckpt = tf.train.Checkpoint(backbone=backbone) partial_ckpt = tf.train.Checkpoint(backbone=backbone)
partial_ckpt.restore(tf.train.latest_checkpoint( partial_ckpt.read(tf.train.latest_checkpoint(
save_dir)).expect_partial().assert_existing_objects_matched() save_dir)).expect_partial().assert_existing_objects_matched()
partial_ckpt_mask = tf.train.Checkpoint( partial_ckpt_mask = tf.train.Checkpoint(
......
...@@ -77,14 +77,14 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask): ...@@ -77,14 +77,14 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask):
checkpoint_path = _get_checkpoint_path( checkpoint_path = _get_checkpoint_path(
self.task_config.init_checkpoint) self.task_config.init_checkpoint)
ckpt = tf.train.Checkpoint(**model.checkpoint_items) ckpt = tf.train.Checkpoint(**model.checkpoint_items)
status = ckpt.restore(checkpoint_path) status = ckpt.read(checkpoint_path)
status.assert_consumed() status.expect_partial().assert_existing_objects_matched()
elif init_module == 'backbone': elif init_module == 'backbone':
checkpoint_path = _get_checkpoint_path( checkpoint_path = _get_checkpoint_path(
self.task_config.init_checkpoint) self.task_config.init_checkpoint)
ckpt = tf.train.Checkpoint(backbone=model.backbone) ckpt = tf.train.Checkpoint(backbone=model.backbone)
status = ckpt.restore(checkpoint_path) status = ckpt.read(checkpoint_path)
status.expect_partial().assert_existing_objects_matched() status.expect_partial().assert_existing_objects_matched()
elif init_module == 'segmentation_backbone': elif init_module == 'segmentation_backbone':
...@@ -92,7 +92,7 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask): ...@@ -92,7 +92,7 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask):
self.task_config.segmentation_init_checkpoint) self.task_config.segmentation_init_checkpoint)
ckpt = tf.train.Checkpoint( ckpt = tf.train.Checkpoint(
segmentation_backbone=model.segmentation_backbone) segmentation_backbone=model.segmentation_backbone)
status = ckpt.restore(checkpoint_path) status = ckpt.read(checkpoint_path)
status.expect_partial().assert_existing_objects_matched() status.expect_partial().assert_existing_objects_matched()
elif init_module == 'segmentation_decoder': elif init_module == 'segmentation_decoder':
...@@ -100,7 +100,7 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask): ...@@ -100,7 +100,7 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask):
self.task_config.segmentation_init_checkpoint) self.task_config.segmentation_init_checkpoint)
ckpt = tf.train.Checkpoint( ckpt = tf.train.Checkpoint(
segmentation_decoder=model.segmentation_decoder) segmentation_decoder=model.segmentation_decoder)
status = ckpt.restore(checkpoint_path) status = ckpt.read(checkpoint_path)
status.expect_partial().assert_existing_objects_matched() status.expect_partial().assert_existing_objects_matched()
else: else:
......
...@@ -150,11 +150,11 @@ class SimCLRPretrainTask(base_task.Task): ...@@ -150,11 +150,11 @@ class SimCLRPretrainTask(base_task.Task):
# Restoring checkpoint. # Restoring checkpoint.
if self.task_config.init_checkpoint_modules == 'all': if self.task_config.init_checkpoint_modules == 'all':
ckpt = tf.train.Checkpoint(**model.checkpoint_items) ckpt = tf.train.Checkpoint(**model.checkpoint_items)
status = ckpt.restore(ckpt_dir_or_file) status = ckpt.read(ckpt_dir_or_file)
status.assert_consumed() status.expect_partial().assert_existing_objects_matched()
elif self.task_config.init_checkpoint_modules == 'backbone': elif self.task_config.init_checkpoint_modules == 'backbone':
ckpt = tf.train.Checkpoint(backbone=model.backbone) ckpt = tf.train.Checkpoint(backbone=model.backbone)
status = ckpt.restore(ckpt_dir_or_file) status = ckpt.read(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched() status.expect_partial().assert_existing_objects_matched()
else: else:
assert "Only 'all' or 'backbone' can be used to initialize the model." assert "Only 'all' or 'backbone' can be used to initialize the model."
...@@ -455,16 +455,16 @@ class SimCLRFinetuneTask(base_task.Task): ...@@ -455,16 +455,16 @@ class SimCLRFinetuneTask(base_task.Task):
# Restoring checkpoint. # Restoring checkpoint.
if self.task_config.init_checkpoint_modules == 'all': if self.task_config.init_checkpoint_modules == 'all':
ckpt = tf.train.Checkpoint(**model.checkpoint_items) ckpt = tf.train.Checkpoint(**model.checkpoint_items)
status = ckpt.restore(ckpt_dir_or_file) status = ckpt.read(ckpt_dir_or_file)
status.assert_consumed() status.expect_partial().assert_existing_objects_matched()
elif self.task_config.init_checkpoint_modules == 'backbone_projection': elif self.task_config.init_checkpoint_modules == 'backbone_projection':
ckpt = tf.train.Checkpoint( ckpt = tf.train.Checkpoint(
backbone=model.backbone, projection_head=model.projection_head) backbone=model.backbone, projection_head=model.projection_head)
status = ckpt.restore(ckpt_dir_or_file) status = ckpt.read(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched() status.expect_partial().assert_existing_objects_matched()
elif self.task_config.init_checkpoint_modules == 'backbone': elif self.task_config.init_checkpoint_modules == 'backbone':
ckpt = tf.train.Checkpoint(backbone=model.backbone) ckpt = tf.train.Checkpoint(backbone=model.backbone)
status = ckpt.restore(ckpt_dir_or_file) status = ckpt.read(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched() status.expect_partial().assert_existing_objects_matched()
else: else:
assert "Only 'all' or 'backbone' can be used to initialize the model." assert "Only 'all' or 'backbone' can be used to initialize the model."
......
...@@ -41,7 +41,7 @@ class VideoSSLEvalTask(video_classification.VideoClassificationTask): ...@@ -41,7 +41,7 @@ class VideoSSLEvalTask(video_classification.VideoClassificationTask):
# Restoring checkpoint. # Restoring checkpoint.
if self.task_config.init_checkpoint_modules == 'backbone': if self.task_config.init_checkpoint_modules == 'backbone':
ckpt = tf.train.Checkpoint(backbone=model.backbone) ckpt = tf.train.Checkpoint(backbone=model.backbone)
ckpt.restore(ckpt_dir_or_file) ckpt.read(ckpt_dir_or_file)
else: else:
raise NotImplementedError raise NotImplementedError
......
...@@ -79,8 +79,8 @@ class SemanticSegmentation3DTask(base_task.Task): ...@@ -79,8 +79,8 @@ class SemanticSegmentation3DTask(base_task.Task):
# Restoring checkpoint. # Restoring checkpoint.
if 'all' in self.task_config.init_checkpoint_modules: if 'all' in self.task_config.init_checkpoint_modules:
ckpt = tf.train.Checkpoint(**model.checkpoint_items) ckpt = tf.train.Checkpoint(**model.checkpoint_items)
status = ckpt.restore(ckpt_dir_or_file) status = ckpt.read(ckpt_dir_or_file)
status.assert_consumed() status.expect_partial().assert_existing_objects_matched()
else: else:
ckpt_items = {} ckpt_items = {}
if 'backbone' in self.task_config.init_checkpoint_modules: if 'backbone' in self.task_config.init_checkpoint_modules:
...@@ -89,7 +89,7 @@ class SemanticSegmentation3DTask(base_task.Task): ...@@ -89,7 +89,7 @@ class SemanticSegmentation3DTask(base_task.Task):
ckpt_items.update(decoder=model.decoder) ckpt_items.update(decoder=model.decoder)
ckpt = tf.train.Checkpoint(**ckpt_items) ckpt = tf.train.Checkpoint(**ckpt_items)
status = ckpt.restore(ckpt_dir_or_file) status = ckpt.read(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched() status.expect_partial().assert_existing_objects_matched()
logging.info('Finished loading pretrained checkpoint from %s', logging.info('Finished loading pretrained checkpoint from %s',
......
...@@ -62,11 +62,11 @@ class ImageClassificationTask(base_task.Task): ...@@ -62,11 +62,11 @@ class ImageClassificationTask(base_task.Task):
# Restoring checkpoint. # Restoring checkpoint.
if self.task_config.init_checkpoint_modules == 'all': if self.task_config.init_checkpoint_modules == 'all':
ckpt = tf.train.Checkpoint(**model.checkpoint_items) ckpt = tf.train.Checkpoint(**model.checkpoint_items)
status = ckpt.restore(ckpt_dir_or_file) status = ckpt.read(ckpt_dir_or_file)
status.assert_consumed() status.expect_partial().assert_existing_objects_matched()
elif self.task_config.init_checkpoint_modules == 'backbone': elif self.task_config.init_checkpoint_modules == 'backbone':
ckpt = tf.train.Checkpoint(backbone=model.backbone) ckpt = tf.train.Checkpoint(backbone=model.backbone)
status = ckpt.restore(ckpt_dir_or_file) status = ckpt.read(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched() status.expect_partial().assert_existing_objects_matched()
else: else:
raise ValueError( raise ValueError(
......
...@@ -96,11 +96,11 @@ class MaskRCNNTask(base_task.Task): ...@@ -96,11 +96,11 @@ class MaskRCNNTask(base_task.Task):
# Restoring checkpoint. # Restoring checkpoint.
if self.task_config.init_checkpoint_modules == 'all': if self.task_config.init_checkpoint_modules == 'all':
ckpt = tf.train.Checkpoint(**model.checkpoint_items) ckpt = tf.train.Checkpoint(**model.checkpoint_items)
status = ckpt.restore(ckpt_dir_or_file) status = ckpt.read(ckpt_dir_or_file)
status.assert_consumed() status.expect_partial().assert_existing_objects_matched()
elif self.task_config.init_checkpoint_modules == 'backbone': elif self.task_config.init_checkpoint_modules == 'backbone':
ckpt = tf.train.Checkpoint(backbone=model.backbone) ckpt = tf.train.Checkpoint(backbone=model.backbone)
status = ckpt.restore(ckpt_dir_or_file) status = ckpt.read(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched() status.expect_partial().assert_existing_objects_matched()
else: else:
raise ValueError( raise ValueError(
......
...@@ -71,11 +71,11 @@ class RetinaNetTask(base_task.Task): ...@@ -71,11 +71,11 @@ class RetinaNetTask(base_task.Task):
# Restoring checkpoint. # Restoring checkpoint.
if self.task_config.init_checkpoint_modules == 'all': if self.task_config.init_checkpoint_modules == 'all':
ckpt = tf.train.Checkpoint(**model.checkpoint_items) ckpt = tf.train.Checkpoint(**model.checkpoint_items)
status = ckpt.restore(ckpt_dir_or_file) status = ckpt.read(ckpt_dir_or_file)
status.assert_consumed() status.expect_partial().assert_existing_objects_matched()
elif self.task_config.init_checkpoint_modules == 'backbone': elif self.task_config.init_checkpoint_modules == 'backbone':
ckpt = tf.train.Checkpoint(backbone=model.backbone) ckpt = tf.train.Checkpoint(backbone=model.backbone)
status = ckpt.restore(ckpt_dir_or_file) status = ckpt.read(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched() status.expect_partial().assert_existing_objects_matched()
else: else:
raise ValueError( raise ValueError(
......
...@@ -63,8 +63,8 @@ class SemanticSegmentationTask(base_task.Task): ...@@ -63,8 +63,8 @@ class SemanticSegmentationTask(base_task.Task):
# Restoring checkpoint. # Restoring checkpoint.
if 'all' in self.task_config.init_checkpoint_modules: if 'all' in self.task_config.init_checkpoint_modules:
ckpt = tf.train.Checkpoint(**model.checkpoint_items) ckpt = tf.train.Checkpoint(**model.checkpoint_items)
status = ckpt.restore(ckpt_dir_or_file) status = ckpt.read(ckpt_dir_or_file)
status.assert_consumed() status.expect_partial().assert_existing_objects_matched()
else: else:
ckpt_items = {} ckpt_items = {}
if 'backbone' in self.task_config.init_checkpoint_modules: if 'backbone' in self.task_config.init_checkpoint_modules:
...@@ -73,7 +73,7 @@ class SemanticSegmentationTask(base_task.Task): ...@@ -73,7 +73,7 @@ class SemanticSegmentationTask(base_task.Task):
ckpt_items.update(decoder=model.decoder) ckpt_items.update(decoder=model.decoder)
ckpt = tf.train.Checkpoint(**ckpt_items) ckpt = tf.train.Checkpoint(**ckpt_items)
status = ckpt.restore(ckpt_dir_or_file) status = ckpt.read(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched() status.expect_partial().assert_existing_objects_matched()
logging.info('Finished loading pretrained checkpoint from %s', logging.info('Finished loading pretrained checkpoint from %s',
......
...@@ -86,11 +86,11 @@ class VideoClassificationTask(base_task.Task): ...@@ -86,11 +86,11 @@ class VideoClassificationTask(base_task.Task):
# Restoring checkpoint. # Restoring checkpoint.
if self.task_config.init_checkpoint_modules == 'all': if self.task_config.init_checkpoint_modules == 'all':
ckpt = tf.train.Checkpoint(**model.checkpoint_items) ckpt = tf.train.Checkpoint(**model.checkpoint_items)
status = ckpt.restore(ckpt_dir_or_file) status = ckpt.read(ckpt_dir_or_file)
status.assert_consumed() status.expect_partial().assert_existing_objects_matched()
elif self.task_config.init_checkpoint_modules == 'backbone': elif self.task_config.init_checkpoint_modules == 'backbone':
ckpt = tf.train.Checkpoint(backbone=model.backbone) ckpt = tf.train.Checkpoint(backbone=model.backbone)
status = ckpt.restore(ckpt_dir_or_file) status = ckpt.read(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched() status.expect_partial().assert_existing_objects_matched()
else: else:
raise ValueError( raise ValueError(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment