Unverified Commit 5b61a5c8 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Overwriting FrozenBN eps=0.0 if pretrained=True for detection models. (#2940)

* Overwriting FrozenBN eps=0.0 if pretrained=True for detection models.

* Moving the method to detection utils and adding comments.
parent 6b071be9
...@@ -8,7 +8,7 @@ from torchvision import models ...@@ -8,7 +8,7 @@ from torchvision import models
import unittest import unittest
import random import random
from torchvision.ops.misc import FrozenBatchNorm2d from torchvision.models.detection._utils import overwrite_eps
def set_rng_seed(seed): def set_rng_seed(seed):
...@@ -151,9 +151,7 @@ class ModelTester(TestCase): ...@@ -151,9 +151,7 @@ class ModelTester(TestCase):
kwargs["score_thresh"] = 0.013 kwargs["score_thresh"] = 0.013
model = models.detection.__dict__[name](num_classes=50, pretrained_backbone=False, **kwargs) model = models.detection.__dict__[name](num_classes=50, pretrained_backbone=False, **kwargs)
if "keypointrcnn" in name or "retinanet" in name: if "keypointrcnn" in name or "retinanet" in name:
for module in model.modules(): overwrite_eps(model, 0.0)
if isinstance(module, FrozenBatchNorm2d):
module.eps = 0
model.eval().to(device=dev) model.eval().to(device=dev)
input_shape = (3, 300, 300) input_shape = (3, 300, 300)
# RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests # RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests
......
...@@ -3,7 +3,8 @@ import math ...@@ -3,7 +3,8 @@ import math
import torch import torch
from torch.jit.annotations import List, Tuple from torch.jit.annotations import List, Tuple
from torch import Tensor from torch import Tensor
import torchvision
from torchvision.ops.misc import FrozenBatchNorm2d
class BalancedPositiveNegativeSampler(object): class BalancedPositiveNegativeSampler(object):
...@@ -349,3 +350,21 @@ def smooth_l1_loss(input, target, beta: float = 1. / 9, size_average: bool = Tru ...@@ -349,3 +350,21 @@ def smooth_l1_loss(input, target, beta: float = 1. / 9, size_average: bool = Tru
if size_average: if size_average:
return loss.mean() return loss.mean()
return loss.sum() return loss.sum()
def overwrite_eps(model, eps):
"""
This method overwrites the default eps values of all the
FrozenBatchNorm2d layers of the model with the provided value.
This is necessary to address the BC-breaking change introduced
by the bug-fix at pytorch/vision#2933. The overwrite is applied
only when the pretrained weights are loaded to maintain compatibility
with previous versions.
Arguments:
model (nn.Module): The model on which we perform the overwrite.
eps (float): The new value of eps.
"""
for module in model.modules():
if isinstance(module, FrozenBatchNorm2d):
module.eps = eps
...@@ -7,6 +7,7 @@ import torch.nn.functional as F ...@@ -7,6 +7,7 @@ import torch.nn.functional as F
from torchvision.ops import misc as misc_nn_ops from torchvision.ops import misc as misc_nn_ops
from torchvision.ops import MultiScaleRoIAlign from torchvision.ops import MultiScaleRoIAlign
from ._utils import overwrite_eps
from ..utils import load_state_dict_from_url from ..utils import load_state_dict_from_url
from .anchor_utils import AnchorGenerator from .anchor_utils import AnchorGenerator
...@@ -361,4 +362,5 @@ def fasterrcnn_resnet50_fpn(pretrained=False, progress=True, ...@@ -361,4 +362,5 @@ def fasterrcnn_resnet50_fpn(pretrained=False, progress=True,
state_dict = load_state_dict_from_url(model_urls['fasterrcnn_resnet50_fpn_coco'], state_dict = load_state_dict_from_url(model_urls['fasterrcnn_resnet50_fpn_coco'],
progress=progress) progress=progress)
model.load_state_dict(state_dict) model.load_state_dict(state_dict)
overwrite_eps(model, 0.0)
return model return model
...@@ -3,6 +3,7 @@ from torch import nn ...@@ -3,6 +3,7 @@ from torch import nn
from torchvision.ops import MultiScaleRoIAlign from torchvision.ops import MultiScaleRoIAlign
from ._utils import overwrite_eps
from ..utils import load_state_dict_from_url from ..utils import load_state_dict_from_url
from .faster_rcnn import FasterRCNN from .faster_rcnn import FasterRCNN
...@@ -332,4 +333,5 @@ def keypointrcnn_resnet50_fpn(pretrained=False, progress=True, ...@@ -332,4 +333,5 @@ def keypointrcnn_resnet50_fpn(pretrained=False, progress=True,
state_dict = load_state_dict_from_url(model_urls[key], state_dict = load_state_dict_from_url(model_urls[key],
progress=progress) progress=progress)
model.load_state_dict(state_dict) model.load_state_dict(state_dict)
overwrite_eps(model, 0.0)
return model return model
...@@ -7,6 +7,7 @@ import torch.nn.functional as F ...@@ -7,6 +7,7 @@ import torch.nn.functional as F
from torchvision.ops import misc as misc_nn_ops from torchvision.ops import misc as misc_nn_ops
from torchvision.ops import MultiScaleRoIAlign from torchvision.ops import MultiScaleRoIAlign
from ._utils import overwrite_eps
from ..utils import load_state_dict_from_url from ..utils import load_state_dict_from_url
from .faster_rcnn import FasterRCNN from .faster_rcnn import FasterRCNN
...@@ -328,4 +329,5 @@ def maskrcnn_resnet50_fpn(pretrained=False, progress=True, ...@@ -328,4 +329,5 @@ def maskrcnn_resnet50_fpn(pretrained=False, progress=True,
state_dict = load_state_dict_from_url(model_urls['maskrcnn_resnet50_fpn_coco'], state_dict = load_state_dict_from_url(model_urls['maskrcnn_resnet50_fpn_coco'],
progress=progress) progress=progress)
model.load_state_dict(state_dict) model.load_state_dict(state_dict)
overwrite_eps(model, 0.0)
return model return model
...@@ -7,6 +7,7 @@ import torch.nn as nn ...@@ -7,6 +7,7 @@ import torch.nn as nn
from torch import Tensor from torch import Tensor
from torch.jit.annotations import Dict, List, Tuple from torch.jit.annotations import Dict, List, Tuple
from ._utils import overwrite_eps
from ..utils import load_state_dict_from_url from ..utils import load_state_dict_from_url
from . import _utils as det_utils from . import _utils as det_utils
...@@ -628,4 +629,5 @@ def retinanet_resnet50_fpn(pretrained=False, progress=True, ...@@ -628,4 +629,5 @@ def retinanet_resnet50_fpn(pretrained=False, progress=True,
state_dict = load_state_dict_from_url(model_urls['retinanet_resnet50_fpn_coco'], state_dict = load_state_dict_from_url(model_urls['retinanet_resnet50_fpn_coco'],
progress=progress) progress=progress)
model.load_state_dict(state_dict) model.load_state_dict(state_dict)
overwrite_eps(model, 0.0)
return model return model
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment