"docs/basic_usage/qwen3.md" did not exist on "1ee11df8acccd9c59d0936f8faf69afac5bbb146"
Commit bf9805c5 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 393926742
parent 164bab98
......@@ -384,7 +384,7 @@ class MaskRCNNModelTest(parameterized.TestCase, tf.test.TestCase):
ckpt.save(os.path.join(save_dir, 'ckpt'))
partial_ckpt = tf.train.Checkpoint(backbone=backbone)
partial_ckpt.restore(tf.train.latest_checkpoint(
partial_ckpt.read(tf.train.latest_checkpoint(
save_dir)).expect_partial().assert_existing_objects_matched()
if include_mask:
......
......@@ -460,7 +460,7 @@ class PanopticMaskRCNNModelTest(parameterized.TestCase, tf.test.TestCase):
ckpt.save(os.path.join(save_dir, 'ckpt'))
partial_ckpt = tf.train.Checkpoint(backbone=backbone)
partial_ckpt.restore(tf.train.latest_checkpoint(
partial_ckpt.read(tf.train.latest_checkpoint(
save_dir)).expect_partial().assert_existing_objects_matched()
partial_ckpt_mask = tf.train.Checkpoint(
......
......@@ -77,14 +77,14 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask):
checkpoint_path = _get_checkpoint_path(
self.task_config.init_checkpoint)
ckpt = tf.train.Checkpoint(**model.checkpoint_items)
status = ckpt.restore(checkpoint_path)
status.assert_consumed()
status = ckpt.read(checkpoint_path)
status.expect_partial().assert_existing_objects_matched()
elif init_module == 'backbone':
checkpoint_path = _get_checkpoint_path(
self.task_config.init_checkpoint)
ckpt = tf.train.Checkpoint(backbone=model.backbone)
status = ckpt.restore(checkpoint_path)
status = ckpt.read(checkpoint_path)
status.expect_partial().assert_existing_objects_matched()
elif init_module == 'segmentation_backbone':
......@@ -92,7 +92,7 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask):
self.task_config.segmentation_init_checkpoint)
ckpt = tf.train.Checkpoint(
segmentation_backbone=model.segmentation_backbone)
status = ckpt.restore(checkpoint_path)
status = ckpt.read(checkpoint_path)
status.expect_partial().assert_existing_objects_matched()
elif init_module == 'segmentation_decoder':
......@@ -100,7 +100,7 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask):
self.task_config.segmentation_init_checkpoint)
ckpt = tf.train.Checkpoint(
segmentation_decoder=model.segmentation_decoder)
status = ckpt.restore(checkpoint_path)
status = ckpt.read(checkpoint_path)
status.expect_partial().assert_existing_objects_matched()
else:
......
......@@ -150,11 +150,11 @@ class SimCLRPretrainTask(base_task.Task):
# Restoring checkpoint.
if self.task_config.init_checkpoint_modules == 'all':
ckpt = tf.train.Checkpoint(**model.checkpoint_items)
status = ckpt.restore(ckpt_dir_or_file)
status.assert_consumed()
status = ckpt.read(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched()
elif self.task_config.init_checkpoint_modules == 'backbone':
ckpt = tf.train.Checkpoint(backbone=model.backbone)
status = ckpt.restore(ckpt_dir_or_file)
status = ckpt.read(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched()
else:
assert "Only 'all' or 'backbone' can be used to initialize the model."
......@@ -455,16 +455,16 @@ class SimCLRFinetuneTask(base_task.Task):
# Restoring checkpoint.
if self.task_config.init_checkpoint_modules == 'all':
ckpt = tf.train.Checkpoint(**model.checkpoint_items)
status = ckpt.restore(ckpt_dir_or_file)
status.assert_consumed()
status = ckpt.read(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched()
elif self.task_config.init_checkpoint_modules == 'backbone_projection':
ckpt = tf.train.Checkpoint(
backbone=model.backbone, projection_head=model.projection_head)
status = ckpt.restore(ckpt_dir_or_file)
status = ckpt.read(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched()
elif self.task_config.init_checkpoint_modules == 'backbone':
ckpt = tf.train.Checkpoint(backbone=model.backbone)
status = ckpt.restore(ckpt_dir_or_file)
status = ckpt.read(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched()
else:
assert "Only 'all' or 'backbone' can be used to initialize the model."
......
......@@ -41,7 +41,7 @@ class VideoSSLEvalTask(video_classification.VideoClassificationTask):
# Restoring checkpoint.
if self.task_config.init_checkpoint_modules == 'backbone':
ckpt = tf.train.Checkpoint(backbone=model.backbone)
ckpt.restore(ckpt_dir_or_file)
ckpt.read(ckpt_dir_or_file)
else:
raise NotImplementedError
......
......@@ -79,8 +79,8 @@ class SemanticSegmentation3DTask(base_task.Task):
# Restoring checkpoint.
if 'all' in self.task_config.init_checkpoint_modules:
ckpt = tf.train.Checkpoint(**model.checkpoint_items)
status = ckpt.restore(ckpt_dir_or_file)
status.assert_consumed()
status = ckpt.read(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched()
else:
ckpt_items = {}
if 'backbone' in self.task_config.init_checkpoint_modules:
......@@ -89,7 +89,7 @@ class SemanticSegmentation3DTask(base_task.Task):
ckpt_items.update(decoder=model.decoder)
ckpt = tf.train.Checkpoint(**ckpt_items)
status = ckpt.restore(ckpt_dir_or_file)
status = ckpt.read(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched()
logging.info('Finished loading pretrained checkpoint from %s',
......
......@@ -62,11 +62,11 @@ class ImageClassificationTask(base_task.Task):
# Restoring checkpoint.
if self.task_config.init_checkpoint_modules == 'all':
ckpt = tf.train.Checkpoint(**model.checkpoint_items)
status = ckpt.restore(ckpt_dir_or_file)
status.assert_consumed()
status = ckpt.read(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched()
elif self.task_config.init_checkpoint_modules == 'backbone':
ckpt = tf.train.Checkpoint(backbone=model.backbone)
status = ckpt.restore(ckpt_dir_or_file)
status = ckpt.read(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched()
else:
raise ValueError(
......
......@@ -96,11 +96,11 @@ class MaskRCNNTask(base_task.Task):
# Restoring checkpoint.
if self.task_config.init_checkpoint_modules == 'all':
ckpt = tf.train.Checkpoint(**model.checkpoint_items)
status = ckpt.restore(ckpt_dir_or_file)
status.assert_consumed()
status = ckpt.read(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched()
elif self.task_config.init_checkpoint_modules == 'backbone':
ckpt = tf.train.Checkpoint(backbone=model.backbone)
status = ckpt.restore(ckpt_dir_or_file)
status = ckpt.read(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched()
else:
raise ValueError(
......
......@@ -71,11 +71,11 @@ class RetinaNetTask(base_task.Task):
# Restoring checkpoint.
if self.task_config.init_checkpoint_modules == 'all':
ckpt = tf.train.Checkpoint(**model.checkpoint_items)
status = ckpt.restore(ckpt_dir_or_file)
status.assert_consumed()
status = ckpt.read(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched()
elif self.task_config.init_checkpoint_modules == 'backbone':
ckpt = tf.train.Checkpoint(backbone=model.backbone)
status = ckpt.restore(ckpt_dir_or_file)
status = ckpt.read(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched()
else:
raise ValueError(
......
......@@ -63,8 +63,8 @@ class SemanticSegmentationTask(base_task.Task):
# Restoring checkpoint.
if 'all' in self.task_config.init_checkpoint_modules:
ckpt = tf.train.Checkpoint(**model.checkpoint_items)
status = ckpt.restore(ckpt_dir_or_file)
status.assert_consumed()
status = ckpt.read(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched()
else:
ckpt_items = {}
if 'backbone' in self.task_config.init_checkpoint_modules:
......@@ -73,7 +73,7 @@ class SemanticSegmentationTask(base_task.Task):
ckpt_items.update(decoder=model.decoder)
ckpt = tf.train.Checkpoint(**ckpt_items)
status = ckpt.restore(ckpt_dir_or_file)
status = ckpt.read(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched()
logging.info('Finished loading pretrained checkpoint from %s',
......
......@@ -86,11 +86,11 @@ class VideoClassificationTask(base_task.Task):
# Restoring checkpoint.
if self.task_config.init_checkpoint_modules == 'all':
ckpt = tf.train.Checkpoint(**model.checkpoint_items)
status = ckpt.restore(ckpt_dir_or_file)
status.assert_consumed()
status = ckpt.read(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched()
elif self.task_config.init_checkpoint_modules == 'backbone':
ckpt = tf.train.Checkpoint(backbone=model.backbone)
status = ckpt.restore(ckpt_dir_or_file)
status = ckpt.read(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched()
else:
raise ValueError(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment