"...resnet50_tensorflow.git" did not exist on "f428c400b875f415f26f5c93bf565ba3cb4057f1"
Commit 4fd1790e authored by A. Unique TensorFlower's avatar A. Unique TensorFlower Committed by TF Object Detection Team
Browse files

Made changes in `center_net_meta_arch.py` for the particular use case in Marky...

Made changes in `center_net_meta_arch.py` for the particular use case in Marky Landmark Detection, i.e.
1. Accessing an access-protected field in the model;
2. Traverse layers to set trainability.

PiperOrigin-RevId: 385136419
parent c705089f
...@@ -2455,6 +2455,24 @@ class CenterNetMetaArch(model.DetectionModel): ...@@ -2455,6 +2455,24 @@ class CenterNetMetaArch(model.DetectionModel):
super(CenterNetMetaArch, self).__init__(num_classes) super(CenterNetMetaArch, self).__init__(num_classes)
def set_trainability_by_layer_traversal(self, trainable):
"""Sets trainability layer by layer.
The commonly-seen `model.trainable = False` method does not traverse
the children layer. For example, if the parent is not trainable, we won't
be able to set individual layers as trainable/non-trainable differentially.
Args:
trainable: (bool) Setting this for the model layer by layer except for
the parent itself.
"""
for layer in self._flatten_layers(include_self=False):
layer.trainable = trainable
@property
def prediction_head_dict(self):
return self._prediction_head_dict
@property @property
def batched_prediction_tensor_names(self): def batched_prediction_tensor_names(self):
if not self._batched_prediction_tensor_names: if not self._batched_prediction_tensor_names:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment