Commit 14c32065 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Freeze backbone support for RetinaNet

PiperOrigin-RevId: 443733011
parent e9dbffec
...@@ -167,6 +167,10 @@ class RetinaNetTask(cfg.TaskConfig): ...@@ -167,6 +167,10 @@ class RetinaNetTask(cfg.TaskConfig):
# If set, the Waymo Open Dataset evaluator would be used. # If set, the Waymo Open Dataset evaluator would be used.
use_wod_metrics: bool = False use_wod_metrics: bool = False
# If set, freezes the backbone during training.
# TODO(crisnv) Add paper link when available.
freeze_backbone: bool = False
@exp_factory.register_config_factory('retinanet') @exp_factory.register_config_factory('retinanet')
def retinanet() -> cfg.ExperimentConfig: def retinanet() -> cfg.ExperimentConfig:
......
...@@ -59,6 +59,10 @@ class RetinaNetTask(base_task.Task): ...@@ -59,6 +59,10 @@ class RetinaNetTask(base_task.Task):
input_specs=input_specs, input_specs=input_specs,
model_config=self.task_config.model, model_config=self.task_config.model,
l2_regularizer=l2_regularizer) l2_regularizer=l2_regularizer)
if self.task_config.freeze_backbone:
model.backbone.trainable = False
return model return model
def initialize(self, model: tf.keras.Model): def initialize(self, model: tf.keras.Model):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment