"vscode:/vscode.git/clone" did not exist on "553334f37ffb331fdf44aed0f78684a5e4a16514"
Commit a9d416f1 authored by Xianzhi Du's avatar Xianzhi Du Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 394064826
parent a91f3779
......@@ -17,7 +17,7 @@
import dataclasses
import os
from typing import List, Optional
from typing import List, Optional, Union
from official.core import config_definitions as cfg
from official.core import exp_factory
......@@ -201,7 +201,8 @@ class MaskRCNNTask(cfg.TaskConfig):
drop_remainder=False)
losses: Losses = Losses()
init_checkpoint: Optional[str] = None
init_checkpoint_modules: str = 'all' # all or backbone
init_checkpoint_modules: Union[
str, List[str]] = 'all' # all, backbone, and/or decoder
annotation_file: Optional[str] = None
per_category_metrics: bool = False
# If set, we only use masks for the specified class IDs.
......
......@@ -17,7 +17,7 @@
import dataclasses
import os
from typing import List, Optional
from typing import List, Optional, Union
from official.core import config_definitions as cfg
from official.core import exp_factory
......@@ -145,7 +145,8 @@ class RetinaNetTask(cfg.TaskConfig):
validation_data: DataConfig = DataConfig(is_training=False)
losses: Losses = Losses()
init_checkpoint: Optional[str] = None
init_checkpoint_modules: str = 'all' # all or backbone
init_checkpoint_modules: Union[
str, List[str]] = 'all' # all, backbone, and/or decoder
annotation_file: Optional[str] = None
per_category_metrics: bool = False
export_config: ExportConfig = ExportConfig()
......
......@@ -98,13 +98,16 @@ class MaskRCNNTask(base_task.Task):
ckpt = tf.train.Checkpoint(**model.checkpoint_items)
status = ckpt.read(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched()
elif self.task_config.init_checkpoint_modules == 'backbone':
ckpt = tf.train.Checkpoint(backbone=model.backbone)
else:
ckpt_items = {}
if 'backbone' in self.task_config.init_checkpoint_modules:
ckpt_items.update(backbone=model.backbone)
if 'decoder' in self.task_config.init_checkpoint_modules:
ckpt_items.update(decoder=model.decoder)
ckpt = tf.train.Checkpoint(**ckpt_items)
status = ckpt.read(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched()
else:
raise ValueError(
"Only 'all' or 'backbone' can be used to initialize the model.")
logging.info('Finished loading pretrained checkpoint from %s',
ckpt_dir_or_file)
......
......@@ -73,13 +73,16 @@ class RetinaNetTask(base_task.Task):
ckpt = tf.train.Checkpoint(**model.checkpoint_items)
status = ckpt.read(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched()
elif self.task_config.init_checkpoint_modules == 'backbone':
ckpt = tf.train.Checkpoint(backbone=model.backbone)
else:
ckpt_items = {}
if 'backbone' in self.task_config.init_checkpoint_modules:
ckpt_items.update(backbone=model.backbone)
if 'decoder' in self.task_config.init_checkpoint_modules:
ckpt_items.update(decoder=model.decoder)
ckpt = tf.train.Checkpoint(**ckpt_items)
status = ckpt.read(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched()
else:
raise ValueError(
"Only 'all' or 'backbone' can be used to initialize the model.")
logging.info('Finished loading pretrained checkpoint from %s',
ckpt_dir_or_file)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment