Commit 2fe71495 authored by Liangzhe Yuan's avatar Liangzhe Yuan Committed by A. Unique TensorFlower
Browse files

Add freeze backbone flag for video classification task.

PiperOrigin-RevId: 474361376
parent 3f7b5405
......@@ -150,6 +150,7 @@ class VideoClassificationTask(cfg.TaskConfig):
metrics: Metrics = Metrics()
init_checkpoint: Optional[str] = None
init_checkpoint_modules: str = 'all' # all or backbone
freeze_backbone: bool = False
# Spatial Partitioning fields.
train_input_partition_dims: Optional[Tuple[int, ...]] = None
eval_input_partition_dims: Optional[Tuple[int, ...]] = None
......
......@@ -60,7 +60,7 @@ class VideoClassificationTask(base_task.Task):
input_specs = tf.keras.layers.InputSpec(shape=[None] + common_input_shape)
logging.info('Build model input %r', common_input_shape)
l2_weight_decay = self.task_config.losses.l2_weight_decay
l2_weight_decay = float(self.task_config.losses.l2_weight_decay)
# Divide weight decay by 2.0 to match the implementation of tf.nn.l2_loss.
# (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
# (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
......@@ -73,6 +73,10 @@ class VideoClassificationTask(base_task.Task):
model_config=self.task_config.model,
num_classes=self._get_num_classes(),
l2_regularizer=l2_regularizer)
if self.task_config.freeze_backbone:
logging.info('Freezing model backbone.')
model.backbone.trainable = False
return model
def initialize(self, model: tf.keras.Model):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment