"...resnet50_tensorflow.git" did not exist on "d24f4bbaea382ed1ef2a37a38e98a9560f22d791"
Unverified commit 1e32b05e authored by Rafael Padilla, committed by GitHub
Browse files

improving TimmBackbone to support FrozenBatchNorm2d (#27160)



* supporting freeze_batch_norm_2d

* supporting freeze_batch_norm_2d

* including unfreeze + separate into methods

* fix typo

* calling unfreeze

* lint

* Update src/transformers/models/timm_backbone/modeling_timm_backbone.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------
Co-authored-by: Rafael Padilla <rafael.padilla@huggingface.co>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent 21a2fbaf
...@@ -43,6 +43,8 @@ class TimmBackboneConfig(PretrainedConfig): ...@@ -43,6 +43,8 @@ class TimmBackboneConfig(PretrainedConfig):
out_indices (`List[int]`, *optional*): out_indices (`List[int]`, *optional*):
If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how
many stages the model has). Will default to the last stage if unset. many stages the model has). Will default to the last stage if unset.
freeze_batch_norm_2d (`bool`, *optional*, defaults to `False`):
Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`.
Example: Example:
```python ```python
...@@ -67,6 +69,7 @@ class TimmBackboneConfig(PretrainedConfig): ...@@ -67,6 +69,7 @@ class TimmBackboneConfig(PretrainedConfig):
features_only=True, features_only=True,
use_pretrained_backbone=True, use_pretrained_backbone=True,
out_indices=None, out_indices=None,
freeze_batch_norm_2d=False,
**kwargs, **kwargs,
): ):
super().__init__(**kwargs) super().__init__(**kwargs)
...@@ -76,3 +79,4 @@ class TimmBackboneConfig(PretrainedConfig): ...@@ -76,3 +79,4 @@ class TimmBackboneConfig(PretrainedConfig):
self.use_pretrained_backbone = use_pretrained_backbone self.use_pretrained_backbone = use_pretrained_backbone
self.use_timm_backbone = True self.use_timm_backbone = True
self.out_indices = out_indices if out_indices is not None else (-1,) self.out_indices = out_indices if out_indices is not None else (-1,)
self.freeze_batch_norm_2d = freeze_batch_norm_2d
...@@ -72,6 +72,11 @@ class TimmBackbone(PreTrainedModel, BackboneMixin): ...@@ -72,6 +72,11 @@ class TimmBackbone(PreTrainedModel, BackboneMixin):
out_indices=out_indices, out_indices=out_indices,
**kwargs, **kwargs,
) )
# Converts all `BatchNorm2d` and `SyncBatchNorm` or `BatchNormAct2d` and `SyncBatchNormAct2d` layers of provided module into `FrozenBatchNorm2d` or `FrozenBatchNormAct2d` respectively
if getattr(config, "freeze_batch_norm_2d", False):
self.freeze_batch_norm_2d()
# These are used to control the output of the model when called. If output_hidden_states is True, then # These are used to control the output of the model when called. If output_hidden_states is True, then
# return_layers is modified to include all layers. # return_layers is modified to include all layers.
self._return_layers = self._backbone.return_layers self._return_layers = self._backbone.return_layers
...@@ -102,6 +107,12 @@ class TimmBackbone(PreTrainedModel, BackboneMixin): ...@@ -102,6 +107,12 @@ class TimmBackbone(PreTrainedModel, BackboneMixin):
) )
return super()._from_config(config, **kwargs) return super()._from_config(config, **kwargs)
def freeze_batch_norm_2d(self):
    """Freeze all 2d batch-norm layers of the wrapped timm backbone.

    Delegates to ``timm.layers.freeze_batch_norm_2d``, which converts
    ``BatchNorm2d``/``SyncBatchNorm`` (and their Act variants) modules
    into their frozen counterparts.
    """
    backbone = self._backbone
    timm.layers.freeze_batch_norm_2d(backbone)
def unfreeze_batch_norm_2d(self):
    """Undo a previous freeze of the backbone's 2d batch-norm layers.

    Delegates to ``timm.layers.unfreeze_batch_norm_2d``, restoring
    trainable batch-norm modules in place of the frozen variants.
    """
    backbone = self._backbone
    timm.layers.unfreeze_batch_norm_2d(backbone)
def _init_weights(self, module): def _init_weights(self, module):
""" """
Empty init weights function to ensure compatibility of the class in the library. Empty init weights function to ensure compatibility of the class in the library.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment