Commit c3578304 authored by Xianzhi Du's avatar Xianzhi Du Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 394064826
parent a5b1cfbe
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
import dataclasses import dataclasses
import os import os
from typing import List, Optional from typing import List, Optional, Union
from official.core import config_definitions as cfg from official.core import config_definitions as cfg
from official.core import exp_factory from official.core import exp_factory
...@@ -201,7 +201,8 @@ class MaskRCNNTask(cfg.TaskConfig): ...@@ -201,7 +201,8 @@ class MaskRCNNTask(cfg.TaskConfig):
drop_remainder=False) drop_remainder=False)
losses: Losses = Losses() losses: Losses = Losses()
init_checkpoint: Optional[str] = None init_checkpoint: Optional[str] = None
init_checkpoint_modules: str = 'all' # all or backbone init_checkpoint_modules: Union[
str, List[str]] = 'all' # all, backbone, and/or decoder
annotation_file: Optional[str] = None annotation_file: Optional[str] = None
per_category_metrics: bool = False per_category_metrics: bool = False
# If set, we only use masks for the specified class IDs. # If set, we only use masks for the specified class IDs.
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
import dataclasses import dataclasses
import os import os
from typing import List, Optional from typing import List, Optional, Union
from official.core import config_definitions as cfg from official.core import config_definitions as cfg
from official.core import exp_factory from official.core import exp_factory
...@@ -145,7 +145,8 @@ class RetinaNetTask(cfg.TaskConfig): ...@@ -145,7 +145,8 @@ class RetinaNetTask(cfg.TaskConfig):
validation_data: DataConfig = DataConfig(is_training=False) validation_data: DataConfig = DataConfig(is_training=False)
losses: Losses = Losses() losses: Losses = Losses()
init_checkpoint: Optional[str] = None init_checkpoint: Optional[str] = None
init_checkpoint_modules: str = 'all' # all or backbone init_checkpoint_modules: Union[
str, List[str]] = 'all' # all, backbone, and/or decoder
annotation_file: Optional[str] = None annotation_file: Optional[str] = None
per_category_metrics: bool = False per_category_metrics: bool = False
export_config: ExportConfig = ExportConfig() export_config: ExportConfig = ExportConfig()
......
...@@ -98,13 +98,16 @@ class MaskRCNNTask(base_task.Task): ...@@ -98,13 +98,16 @@ class MaskRCNNTask(base_task.Task):
ckpt = tf.train.Checkpoint(**model.checkpoint_items) ckpt = tf.train.Checkpoint(**model.checkpoint_items)
status = ckpt.read(ckpt_dir_or_file) status = ckpt.read(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched() status.expect_partial().assert_existing_objects_matched()
elif self.task_config.init_checkpoint_modules == 'backbone': else:
ckpt = tf.train.Checkpoint(backbone=model.backbone) ckpt_items = {}
if 'backbone' in self.task_config.init_checkpoint_modules:
ckpt_items.update(backbone=model.backbone)
if 'decoder' in self.task_config.init_checkpoint_modules:
ckpt_items.update(decoder=model.decoder)
ckpt = tf.train.Checkpoint(**ckpt_items)
status = ckpt.read(ckpt_dir_or_file) status = ckpt.read(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched() status.expect_partial().assert_existing_objects_matched()
else:
raise ValueError(
"Only 'all' or 'backbone' can be used to initialize the model.")
logging.info('Finished loading pretrained checkpoint from %s', logging.info('Finished loading pretrained checkpoint from %s',
ckpt_dir_or_file) ckpt_dir_or_file)
......
...@@ -73,13 +73,16 @@ class RetinaNetTask(base_task.Task): ...@@ -73,13 +73,16 @@ class RetinaNetTask(base_task.Task):
ckpt = tf.train.Checkpoint(**model.checkpoint_items) ckpt = tf.train.Checkpoint(**model.checkpoint_items)
status = ckpt.read(ckpt_dir_or_file) status = ckpt.read(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched() status.expect_partial().assert_existing_objects_matched()
elif self.task_config.init_checkpoint_modules == 'backbone': else:
ckpt = tf.train.Checkpoint(backbone=model.backbone) ckpt_items = {}
if 'backbone' in self.task_config.init_checkpoint_modules:
ckpt_items.update(backbone=model.backbone)
if 'decoder' in self.task_config.init_checkpoint_modules:
ckpt_items.update(decoder=model.decoder)
ckpt = tf.train.Checkpoint(**ckpt_items)
status = ckpt.read(ckpt_dir_or_file) status = ckpt.read(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched() status.expect_partial().assert_existing_objects_matched()
else:
raise ValueError(
"Only 'all' or 'backbone' can be used to initialize the model.")
logging.info('Finished loading pretrained checkpoint from %s', logging.info('Finished loading pretrained checkpoint from %s',
ckpt_dir_or_file) ckpt_dir_or_file)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment