Commit 20685639 authored by Liangzhe Yuan's avatar Liangzhe Yuan Committed by A. Unique TensorFlower
Browse files

Add freeze_backbone flag for image_classification.

PiperOrigin-RevId: 460570094
parent ad83b2db
......@@ -63,6 +63,7 @@ class ImageClassificationTask(cfg.TaskConfig):
evaluation: Evaluation = Evaluation()
init_checkpoint: Optional[str] = None
init_checkpoint_modules: str = 'all' # all or backbone
freeze_backbone: bool = False
IMAGENET_TRAIN_EXAMPLES = 1281167
......
......@@ -97,6 +97,7 @@ class ImageClassificationTask(cfg.TaskConfig):
init_checkpoint_modules: str = 'all' # all or backbone
model_output_keys: Optional[List[int]] = dataclasses.field(
default_factory=list)
freeze_backbone: bool = False
@exp_factory.register_config_factory('image_classification')
......
......@@ -49,6 +49,9 @@ class ImageClassificationTask(base_task.Task):
input_specs=input_specs,
model_config=self.task_config.model,
l2_regularizer=l2_regularizer)
if self.task_config.freeze_backbone:
model.backbone.trainable = False
return model
def initialize(self, model: tf.keras.Model):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment