Unverified Commit 350a3e8e authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Use frozen BN only if pre-trained. (#5443)

parent b4cb352c
......@@ -383,15 +383,15 @@ def fasterrcnn_resnet50_fpn(
Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is
passed (the default) this value is set to 3.
"""
trainable_backbone_layers = _validate_trainable_layers(
pretrained or pretrained_backbone, trainable_backbone_layers, 5, 3
)
is_trained = pretrained or pretrained_backbone
trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
if pretrained:
# no need to download the backbone if pretrained is set
pretrained_backbone = False
backbone = resnet50(pretrained=pretrained_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d)
backbone = resnet50(pretrained=pretrained_backbone, progress=progress, norm_layer=norm_layer)
backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
model = FasterRCNN(backbone, num_classes, **kwargs)
if pretrained:
......@@ -410,16 +410,14 @@ def _fasterrcnn_mobilenet_v3_large_fpn(
trainable_backbone_layers=None,
**kwargs,
):
trainable_backbone_layers = _validate_trainable_layers(
pretrained or pretrained_backbone, trainable_backbone_layers, 6, 3
)
is_trained = pretrained or pretrained_backbone
trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 6, 3)
norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
if pretrained:
pretrained_backbone = False
backbone = mobilenet_v3_large(
pretrained=pretrained_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d
)
backbone = mobilenet_v3_large(pretrained=pretrained_backbone, progress=progress, norm_layer=norm_layer)
backbone = _mobilenet_extractor(backbone, True, trainable_backbone_layers)
anchor_sizes = (
......
......@@ -686,15 +686,15 @@ def fcos_resnet50_fpn(
from final block. Valid values are between 0 and 5, with 5 meaning all backbone layers are
trainable. If ``None`` is passed (the default) this value is set to 3. Default: None
"""
trainable_backbone_layers = _validate_trainable_layers(
pretrained or pretrained_backbone, trainable_backbone_layers, 5, 3
)
is_trained = pretrained or pretrained_backbone
trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
if pretrained:
# no need to download the backbone if pretrained is set
pretrained_backbone = False
backbone = resnet50(pretrained=pretrained_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d)
backbone = resnet50(pretrained=pretrained_backbone, progress=progress, norm_layer=norm_layer)
backbone = _resnet_fpn_extractor(
backbone, trainable_backbone_layers, returned_layers=[2, 3, 4], extra_blocks=LastLevelP6P7(256, 256)
)
......
......@@ -365,15 +365,15 @@ def keypointrcnn_resnet50_fpn(
Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is
passed (the default) this value is set to 3.
"""
trainable_backbone_layers = _validate_trainable_layers(
pretrained or pretrained_backbone, trainable_backbone_layers, 5, 3
)
is_trained = pretrained or pretrained_backbone
trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
if pretrained:
# no need to download the backbone if pretrained is set
pretrained_backbone = False
backbone = resnet50(pretrained=pretrained_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d)
backbone = resnet50(pretrained=pretrained_backbone, progress=progress, norm_layer=norm_layer)
backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
model = KeypointRCNN(backbone, num_classes, num_keypoints=num_keypoints, **kwargs)
if pretrained:
......
......@@ -360,15 +360,15 @@ def maskrcnn_resnet50_fpn(
Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is
passed (the default) this value is set to 3.
"""
trainable_backbone_layers = _validate_trainable_layers(
pretrained or pretrained_backbone, trainable_backbone_layers, 5, 3
)
is_trained = pretrained or pretrained_backbone
trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
if pretrained:
# no need to download the backbone if pretrained is set
pretrained_backbone = False
backbone = resnet50(pretrained=pretrained_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d)
backbone = resnet50(pretrained=pretrained_backbone, progress=progress, norm_layer=norm_layer)
backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
model = MaskRCNN(backbone, num_classes, **kwargs)
if pretrained:
......
......@@ -626,15 +626,15 @@ def retinanet_resnet50_fpn(
Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is
passed (the default) this value is set to 3.
"""
trainable_backbone_layers = _validate_trainable_layers(
pretrained or pretrained_backbone, trainable_backbone_layers, 5, 3
)
is_trained = pretrained or pretrained_backbone
trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
if pretrained:
# no need to download the backbone if pretrained is set
pretrained_backbone = False
backbone = resnet50(pretrained=pretrained_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d)
backbone = resnet50(pretrained=pretrained_backbone, progress=progress, norm_layer=norm_layer)
# skip P2 because it generates too many anchors (according to their paper)
backbone = _resnet_fpn_extractor(
backbone, trainable_backbone_layers, returned_layers=[2, 3, 4], extra_blocks=LastLevelP6P7(256, 256)
......
from typing import Any, Optional, Union
from torch import nn
from torchvision.prototype.transforms import CocoEval
from torchvision.transforms.functional import InterpolationMode
......@@ -103,11 +104,11 @@ def fasterrcnn_resnet50_fpn(
elif num_classes is None:
num_classes = 91
trainable_backbone_layers = _validate_trainable_layers(
weights is not None or weights_backbone is not None, trainable_backbone_layers, 5, 3
)
is_trained = weights is not None or weights_backbone is not None
trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d)
backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
model = FasterRCNN(backbone, num_classes=num_classes, **kwargs)
......@@ -134,11 +135,11 @@ def _fasterrcnn_mobilenet_v3_large_fpn(
elif num_classes is None:
num_classes = 91
trainable_backbone_layers = _validate_trainable_layers(
weights is not None or weights_backbone is not None, trainable_backbone_layers, 6, 3
)
is_trained = weights is not None or weights_backbone is not None
trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 6, 3)
norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
backbone = mobilenet_v3_large(weights=weights_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d)
backbone = mobilenet_v3_large(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
backbone = _mobilenet_extractor(backbone, True, trainable_backbone_layers)
anchor_sizes = (
(
......
from typing import Any, Optional
from torch import nn
from torchvision.prototype.transforms import CocoEval
from torchvision.transforms.functional import InterpolationMode
......@@ -63,11 +64,11 @@ def fcos_resnet50_fpn(
elif num_classes is None:
num_classes = 91
trainable_backbone_layers = _validate_trainable_layers(
weights is not None or weights_backbone is not None, trainable_backbone_layers, 5, 3
)
is_trained = weights is not None or weights_backbone is not None
trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d)
backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
backbone = _resnet_fpn_extractor(
backbone, trainable_backbone_layers, returned_layers=[2, 3, 4], extra_blocks=LastLevelP6P7(256, 256)
)
......
from typing import Any, Optional
from torch import nn
from torchvision.prototype.transforms import CocoEval
from torchvision.transforms.functional import InterpolationMode
......@@ -91,11 +92,11 @@ def keypointrcnn_resnet50_fpn(
if num_keypoints is None:
num_keypoints = 17
trainable_backbone_layers = _validate_trainable_layers(
weights is not None or weights_backbone is not None, trainable_backbone_layers, 5, 3
)
is_trained = weights is not None or weights_backbone is not None
trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d)
backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
model = KeypointRCNN(backbone, num_classes, num_keypoints=num_keypoints, **kwargs)
......
from typing import Any, Optional
from torch import nn
from torchvision.prototype.transforms import CocoEval
from torchvision.transforms.functional import InterpolationMode
......@@ -64,11 +65,11 @@ def maskrcnn_resnet50_fpn(
elif num_classes is None:
num_classes = 91
trainable_backbone_layers = _validate_trainable_layers(
weights is not None or weights_backbone is not None, trainable_backbone_layers, 5, 3
)
is_trained = weights is not None or weights_backbone is not None
trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d)
backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
model = MaskRCNN(backbone, num_classes=num_classes, **kwargs)
......
from typing import Any, Optional
from torch import nn
from torchvision.prototype.transforms import CocoEval
from torchvision.transforms.functional import InterpolationMode
......@@ -64,11 +65,11 @@ def retinanet_resnet50_fpn(
elif num_classes is None:
num_classes = 91
trainable_backbone_layers = _validate_trainable_layers(
weights is not None or weights_backbone is not None, trainable_backbone_layers, 5, 3
)
is_trained = weights is not None or weights_backbone is not None
trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d)
backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
# skip P2 because it generates too many anchors (according to their paper)
backbone = _resnet_fpn_extractor(
backbone, trainable_backbone_layers, returned_layers=[2, 3, 4], extra_blocks=LastLevelP6P7(256, 256)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment