"encoding/vscode:/vscode.git/clone" did not exist on "331ecdd5306104614cb414b16fbcd9d1a8d40e1e"
Commit 20685639 authored by Liangzhe Yuan's avatar Liangzhe Yuan Committed by A. Unique TensorFlower
Browse files

Add freeze_backbone flag for image_classification.

PiperOrigin-RevId: 460570094
parent ad83b2db
...@@ -63,6 +63,7 @@ class ImageClassificationTask(cfg.TaskConfig): ...@@ -63,6 +63,7 @@ class ImageClassificationTask(cfg.TaskConfig):
evaluation: Evaluation = Evaluation() evaluation: Evaluation = Evaluation()
init_checkpoint: Optional[str] = None init_checkpoint: Optional[str] = None
init_checkpoint_modules: str = 'all' # all or backbone init_checkpoint_modules: str = 'all' # all or backbone
freeze_backbone: bool = False
IMAGENET_TRAIN_EXAMPLES = 1281167 IMAGENET_TRAIN_EXAMPLES = 1281167
......
...@@ -97,6 +97,7 @@ class ImageClassificationTask(cfg.TaskConfig): ...@@ -97,6 +97,7 @@ class ImageClassificationTask(cfg.TaskConfig):
init_checkpoint_modules: str = 'all' # all or backbone init_checkpoint_modules: str = 'all' # all or backbone
model_output_keys: Optional[List[int]] = dataclasses.field( model_output_keys: Optional[List[int]] = dataclasses.field(
default_factory=list) default_factory=list)
freeze_backbone: bool = False
@exp_factory.register_config_factory('image_classification') @exp_factory.register_config_factory('image_classification')
......
...@@ -49,6 +49,9 @@ class ImageClassificationTask(base_task.Task): ...@@ -49,6 +49,9 @@ class ImageClassificationTask(base_task.Task):
input_specs=input_specs, input_specs=input_specs,
model_config=self.task_config.model, model_config=self.task_config.model,
l2_regularizer=l2_regularizer) l2_regularizer=l2_regularizer)
if self.task_config.freeze_backbone:
model.backbone.trainable = False
return model return model
def initialize(self, model: tf.keras.Model): def initialize(self, model: tf.keras.Model):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment