"git@developer.sourcefind.cn:OpenDAS/torch-scatter.git" did not exist on "4dcae88c9f6dac12ef0c69cf686b2f620143d2eb"
Commit fc3edc9e authored by Liangzhe Yuan's avatar Liangzhe Yuan Committed by A. Unique TensorFlower
Browse files

Add freeze backbone flag for video classification task.

PiperOrigin-RevId: 474361376
parent ad886737
...@@ -150,6 +150,7 @@ class VideoClassificationTask(cfg.TaskConfig): ...@@ -150,6 +150,7 @@ class VideoClassificationTask(cfg.TaskConfig):
metrics: Metrics = Metrics() metrics: Metrics = Metrics()
init_checkpoint: Optional[str] = None init_checkpoint: Optional[str] = None
init_checkpoint_modules: str = 'all' # all or backbone init_checkpoint_modules: str = 'all' # all or backbone
freeze_backbone: bool = False
# Spatial Partitioning fields. # Spatial Partitioning fields.
train_input_partition_dims: Optional[Tuple[int, ...]] = None train_input_partition_dims: Optional[Tuple[int, ...]] = None
eval_input_partition_dims: Optional[Tuple[int, ...]] = None eval_input_partition_dims: Optional[Tuple[int, ...]] = None
......
...@@ -60,7 +60,7 @@ class VideoClassificationTask(base_task.Task): ...@@ -60,7 +60,7 @@ class VideoClassificationTask(base_task.Task):
input_specs = tf.keras.layers.InputSpec(shape=[None] + common_input_shape) input_specs = tf.keras.layers.InputSpec(shape=[None] + common_input_shape)
logging.info('Build model input %r', common_input_shape) logging.info('Build model input %r', common_input_shape)
l2_weight_decay = self.task_config.losses.l2_weight_decay l2_weight_decay = float(self.task_config.losses.l2_weight_decay)
# Divide weight decay by 2.0 to match the implementation of tf.nn.l2_loss. # Divide weight decay by 2.0 to match the implementation of tf.nn.l2_loss.
# (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2) # (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
# (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss) # (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
...@@ -73,6 +73,10 @@ class VideoClassificationTask(base_task.Task): ...@@ -73,6 +73,10 @@ class VideoClassificationTask(base_task.Task):
model_config=self.task_config.model, model_config=self.task_config.model,
num_classes=self._get_num_classes(), num_classes=self._get_num_classes(),
l2_regularizer=l2_regularizer) l2_regularizer=l2_regularizer)
if self.task_config.freeze_backbone:
logging.info('Freezing model backbone.')
model.backbone.trainable = False
return model return model
def initialize(self, model: tf.keras.Model): def initialize(self, model: tf.keras.Model):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment