Unverified Commit 350a3e8e authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Use frozen BN only if pre-trained. (#5443)

parent b4cb352c
...@@ -383,15 +383,15 @@ def fasterrcnn_resnet50_fpn( ...@@ -383,15 +383,15 @@ def fasterrcnn_resnet50_fpn(
Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is
passed (the default) this value is set to 3. passed (the default) this value is set to 3.
""" """
trainable_backbone_layers = _validate_trainable_layers( is_trained = pretrained or pretrained_backbone
pretrained or pretrained_backbone, trainable_backbone_layers, 5, 3 trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
) norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
if pretrained: if pretrained:
# no need to download the backbone if pretrained is set # no need to download the backbone if pretrained is set
pretrained_backbone = False pretrained_backbone = False
backbone = resnet50(pretrained=pretrained_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d) backbone = resnet50(pretrained=pretrained_backbone, progress=progress, norm_layer=norm_layer)
backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers) backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
model = FasterRCNN(backbone, num_classes, **kwargs) model = FasterRCNN(backbone, num_classes, **kwargs)
if pretrained: if pretrained:
...@@ -410,16 +410,14 @@ def _fasterrcnn_mobilenet_v3_large_fpn( ...@@ -410,16 +410,14 @@ def _fasterrcnn_mobilenet_v3_large_fpn(
trainable_backbone_layers=None, trainable_backbone_layers=None,
**kwargs, **kwargs,
): ):
trainable_backbone_layers = _validate_trainable_layers( is_trained = pretrained or pretrained_backbone
pretrained or pretrained_backbone, trainable_backbone_layers, 6, 3 trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 6, 3)
) norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
if pretrained: if pretrained:
pretrained_backbone = False pretrained_backbone = False
backbone = mobilenet_v3_large( backbone = mobilenet_v3_large(pretrained=pretrained_backbone, progress=progress, norm_layer=norm_layer)
pretrained=pretrained_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d
)
backbone = _mobilenet_extractor(backbone, True, trainable_backbone_layers) backbone = _mobilenet_extractor(backbone, True, trainable_backbone_layers)
anchor_sizes = ( anchor_sizes = (
......
...@@ -686,15 +686,15 @@ def fcos_resnet50_fpn( ...@@ -686,15 +686,15 @@ def fcos_resnet50_fpn(
from final block. Valid values are between 0 and 5, with 5 meaning all backbone layers are from final block. Valid values are between 0 and 5, with 5 meaning all backbone layers are
trainable. If ``None`` is passed (the default) this value is set to 3. Default: None trainable. If ``None`` is passed (the default) this value is set to 3. Default: None
""" """
trainable_backbone_layers = _validate_trainable_layers( is_trained = pretrained or pretrained_backbone
pretrained or pretrained_backbone, trainable_backbone_layers, 5, 3 trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
) norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
if pretrained: if pretrained:
# no need to download the backbone if pretrained is set # no need to download the backbone if pretrained is set
pretrained_backbone = False pretrained_backbone = False
backbone = resnet50(pretrained=pretrained_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d) backbone = resnet50(pretrained=pretrained_backbone, progress=progress, norm_layer=norm_layer)
backbone = _resnet_fpn_extractor( backbone = _resnet_fpn_extractor(
backbone, trainable_backbone_layers, returned_layers=[2, 3, 4], extra_blocks=LastLevelP6P7(256, 256) backbone, trainable_backbone_layers, returned_layers=[2, 3, 4], extra_blocks=LastLevelP6P7(256, 256)
) )
......
...@@ -365,15 +365,15 @@ def keypointrcnn_resnet50_fpn( ...@@ -365,15 +365,15 @@ def keypointrcnn_resnet50_fpn(
Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is
passed (the default) this value is set to 3. passed (the default) this value is set to 3.
""" """
trainable_backbone_layers = _validate_trainable_layers( is_trained = pretrained or pretrained_backbone
pretrained or pretrained_backbone, trainable_backbone_layers, 5, 3 trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
) norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
if pretrained: if pretrained:
# no need to download the backbone if pretrained is set # no need to download the backbone if pretrained is set
pretrained_backbone = False pretrained_backbone = False
backbone = resnet50(pretrained=pretrained_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d) backbone = resnet50(pretrained=pretrained_backbone, progress=progress, norm_layer=norm_layer)
backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers) backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
model = KeypointRCNN(backbone, num_classes, num_keypoints=num_keypoints, **kwargs) model = KeypointRCNN(backbone, num_classes, num_keypoints=num_keypoints, **kwargs)
if pretrained: if pretrained:
......
...@@ -360,15 +360,15 @@ def maskrcnn_resnet50_fpn( ...@@ -360,15 +360,15 @@ def maskrcnn_resnet50_fpn(
Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is
passed (the default) this value is set to 3. passed (the default) this value is set to 3.
""" """
trainable_backbone_layers = _validate_trainable_layers( is_trained = pretrained or pretrained_backbone
pretrained or pretrained_backbone, trainable_backbone_layers, 5, 3 trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
) norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
if pretrained: if pretrained:
# no need to download the backbone if pretrained is set # no need to download the backbone if pretrained is set
pretrained_backbone = False pretrained_backbone = False
backbone = resnet50(pretrained=pretrained_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d) backbone = resnet50(pretrained=pretrained_backbone, progress=progress, norm_layer=norm_layer)
backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers) backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
model = MaskRCNN(backbone, num_classes, **kwargs) model = MaskRCNN(backbone, num_classes, **kwargs)
if pretrained: if pretrained:
......
...@@ -626,15 +626,15 @@ def retinanet_resnet50_fpn( ...@@ -626,15 +626,15 @@ def retinanet_resnet50_fpn(
Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is
passed (the default) this value is set to 3. passed (the default) this value is set to 3.
""" """
trainable_backbone_layers = _validate_trainable_layers( is_trained = pretrained or pretrained_backbone
pretrained or pretrained_backbone, trainable_backbone_layers, 5, 3 trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
) norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
if pretrained: if pretrained:
# no need to download the backbone if pretrained is set # no need to download the backbone if pretrained is set
pretrained_backbone = False pretrained_backbone = False
backbone = resnet50(pretrained=pretrained_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d) backbone = resnet50(pretrained=pretrained_backbone, progress=progress, norm_layer=norm_layer)
# skip P2 because it generates too many anchors (according to their paper) # skip P2 because it generates too many anchors (according to their paper)
backbone = _resnet_fpn_extractor( backbone = _resnet_fpn_extractor(
backbone, trainable_backbone_layers, returned_layers=[2, 3, 4], extra_blocks=LastLevelP6P7(256, 256) backbone, trainable_backbone_layers, returned_layers=[2, 3, 4], extra_blocks=LastLevelP6P7(256, 256)
......
from typing import Any, Optional, Union from typing import Any, Optional, Union
from torch import nn
from torchvision.prototype.transforms import CocoEval from torchvision.prototype.transforms import CocoEval
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
...@@ -103,11 +104,11 @@ def fasterrcnn_resnet50_fpn( ...@@ -103,11 +104,11 @@ def fasterrcnn_resnet50_fpn(
elif num_classes is None: elif num_classes is None:
num_classes = 91 num_classes = 91
trainable_backbone_layers = _validate_trainable_layers( is_trained = weights is not None or weights_backbone is not None
weights is not None or weights_backbone is not None, trainable_backbone_layers, 5, 3 trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
) norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d) backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers) backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
model = FasterRCNN(backbone, num_classes=num_classes, **kwargs) model = FasterRCNN(backbone, num_classes=num_classes, **kwargs)
...@@ -134,11 +135,11 @@ def _fasterrcnn_mobilenet_v3_large_fpn( ...@@ -134,11 +135,11 @@ def _fasterrcnn_mobilenet_v3_large_fpn(
elif num_classes is None: elif num_classes is None:
num_classes = 91 num_classes = 91
trainable_backbone_layers = _validate_trainable_layers( is_trained = weights is not None or weights_backbone is not None
weights is not None or weights_backbone is not None, trainable_backbone_layers, 6, 3 trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 6, 3)
) norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
backbone = mobilenet_v3_large(weights=weights_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d) backbone = mobilenet_v3_large(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
backbone = _mobilenet_extractor(backbone, True, trainable_backbone_layers) backbone = _mobilenet_extractor(backbone, True, trainable_backbone_layers)
anchor_sizes = ( anchor_sizes = (
( (
......
from typing import Any, Optional from typing import Any, Optional
from torch import nn
from torchvision.prototype.transforms import CocoEval from torchvision.prototype.transforms import CocoEval
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
...@@ -63,11 +64,11 @@ def fcos_resnet50_fpn( ...@@ -63,11 +64,11 @@ def fcos_resnet50_fpn(
elif num_classes is None: elif num_classes is None:
num_classes = 91 num_classes = 91
trainable_backbone_layers = _validate_trainable_layers( is_trained = weights is not None or weights_backbone is not None
weights is not None or weights_backbone is not None, trainable_backbone_layers, 5, 3 trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
) norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d) backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
backbone = _resnet_fpn_extractor( backbone = _resnet_fpn_extractor(
backbone, trainable_backbone_layers, returned_layers=[2, 3, 4], extra_blocks=LastLevelP6P7(256, 256) backbone, trainable_backbone_layers, returned_layers=[2, 3, 4], extra_blocks=LastLevelP6P7(256, 256)
) )
......
from typing import Any, Optional from typing import Any, Optional
from torch import nn
from torchvision.prototype.transforms import CocoEval from torchvision.prototype.transforms import CocoEval
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
...@@ -91,11 +92,11 @@ def keypointrcnn_resnet50_fpn( ...@@ -91,11 +92,11 @@ def keypointrcnn_resnet50_fpn(
if num_keypoints is None: if num_keypoints is None:
num_keypoints = 17 num_keypoints = 17
trainable_backbone_layers = _validate_trainable_layers( is_trained = weights is not None or weights_backbone is not None
weights is not None or weights_backbone is not None, trainable_backbone_layers, 5, 3 trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
) norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d) backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers) backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
model = KeypointRCNN(backbone, num_classes, num_keypoints=num_keypoints, **kwargs) model = KeypointRCNN(backbone, num_classes, num_keypoints=num_keypoints, **kwargs)
......
from typing import Any, Optional from typing import Any, Optional
from torch import nn
from torchvision.prototype.transforms import CocoEval from torchvision.prototype.transforms import CocoEval
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
...@@ -64,11 +65,11 @@ def maskrcnn_resnet50_fpn( ...@@ -64,11 +65,11 @@ def maskrcnn_resnet50_fpn(
elif num_classes is None: elif num_classes is None:
num_classes = 91 num_classes = 91
trainable_backbone_layers = _validate_trainable_layers( is_trained = weights is not None or weights_backbone is not None
weights is not None or weights_backbone is not None, trainable_backbone_layers, 5, 3 trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
) norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d) backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers) backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
model = MaskRCNN(backbone, num_classes=num_classes, **kwargs) model = MaskRCNN(backbone, num_classes=num_classes, **kwargs)
......
from typing import Any, Optional from typing import Any, Optional
from torch import nn
from torchvision.prototype.transforms import CocoEval from torchvision.prototype.transforms import CocoEval
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
...@@ -64,11 +65,11 @@ def retinanet_resnet50_fpn( ...@@ -64,11 +65,11 @@ def retinanet_resnet50_fpn(
elif num_classes is None: elif num_classes is None:
num_classes = 91 num_classes = 91
trainable_backbone_layers = _validate_trainable_layers( is_trained = weights is not None or weights_backbone is not None
weights is not None or weights_backbone is not None, trainable_backbone_layers, 5, 3 trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
) norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d) backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
# skip P2 because it generates too many anchors (according to their paper) # skip P2 because it generates too many anchors (according to their paper)
backbone = _resnet_fpn_extractor( backbone = _resnet_fpn_extractor(
backbone, trainable_backbone_layers, returned_layers=[2, 3, 4], extra_blocks=LastLevelP6P7(256, 256) backbone, trainable_backbone_layers, returned_layers=[2, 3, 4], extra_blocks=LastLevelP6P7(256, 256)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment