Unverified Commit 5b61a5c8 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Overwriting FrozenBN eps=0.0 if pretrained=True for detection models. (#2940)

* Overwriting FrozenBN eps=0.0 if pretrained=True for detection models.

* Moving the method to detection utils and adding comments.
parent 6b071be9
......@@ -8,7 +8,7 @@ from torchvision import models
import unittest
import random
from torchvision.ops.misc import FrozenBatchNorm2d
from torchvision.models.detection._utils import overwrite_eps
def set_rng_seed(seed):
......@@ -151,9 +151,7 @@ class ModelTester(TestCase):
kwargs["score_thresh"] = 0.013
model = models.detection.__dict__[name](num_classes=50, pretrained_backbone=False, **kwargs)
if "keypointrcnn" in name or "retinanet" in name:
for module in model.modules():
if isinstance(module, FrozenBatchNorm2d):
module.eps = 0
overwrite_eps(model, 0.0)
model.eval().to(device=dev)
input_shape = (3, 300, 300)
# RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests
......
......@@ -3,7 +3,8 @@ import math
import torch
from torch.jit.annotations import List, Tuple
from torch import Tensor
import torchvision
from torchvision.ops.misc import FrozenBatchNorm2d
class BalancedPositiveNegativeSampler(object):
......@@ -349,3 +350,21 @@ def smooth_l1_loss(input, target, beta: float = 1. / 9, size_average: bool = Tru
if size_average:
return loss.mean()
return loss.sum()
def overwrite_eps(model, eps):
"""
This method overwrites the default eps values of all the
FrozenBatchNorm2d layers of the model with the provided value.
This is necessary to address the BC-breaking change introduced
by the bug-fix at pytorch/vision#2933. The overwrite is applied
only when the pretrained weights are loaded to maintain compatibility
with previous versions.
Arguments:
model (nn.Module): The model on which we perform the overwrite.
eps (float): The new value of eps.
"""
for module in model.modules():
if isinstance(module, FrozenBatchNorm2d):
module.eps = eps
......@@ -7,6 +7,7 @@ import torch.nn.functional as F
from torchvision.ops import misc as misc_nn_ops
from torchvision.ops import MultiScaleRoIAlign
from ._utils import overwrite_eps
from ..utils import load_state_dict_from_url
from .anchor_utils import AnchorGenerator
......@@ -361,4 +362,5 @@ def fasterrcnn_resnet50_fpn(pretrained=False, progress=True,
state_dict = load_state_dict_from_url(model_urls['fasterrcnn_resnet50_fpn_coco'],
progress=progress)
model.load_state_dict(state_dict)
overwrite_eps(model, 0.0)
return model
......@@ -3,6 +3,7 @@ from torch import nn
from torchvision.ops import MultiScaleRoIAlign
from ._utils import overwrite_eps
from ..utils import load_state_dict_from_url
from .faster_rcnn import FasterRCNN
......@@ -332,4 +333,5 @@ def keypointrcnn_resnet50_fpn(pretrained=False, progress=True,
state_dict = load_state_dict_from_url(model_urls[key],
progress=progress)
model.load_state_dict(state_dict)
overwrite_eps(model, 0.0)
return model
......@@ -7,6 +7,7 @@ import torch.nn.functional as F
from torchvision.ops import misc as misc_nn_ops
from torchvision.ops import MultiScaleRoIAlign
from ._utils import overwrite_eps
from ..utils import load_state_dict_from_url
from .faster_rcnn import FasterRCNN
......@@ -328,4 +329,5 @@ def maskrcnn_resnet50_fpn(pretrained=False, progress=True,
state_dict = load_state_dict_from_url(model_urls['maskrcnn_resnet50_fpn_coco'],
progress=progress)
model.load_state_dict(state_dict)
overwrite_eps(model, 0.0)
return model
......@@ -7,6 +7,7 @@ import torch.nn as nn
from torch import Tensor
from torch.jit.annotations import Dict, List, Tuple
from ._utils import overwrite_eps
from ..utils import load_state_dict_from_url
from . import _utils as det_utils
......@@ -628,4 +629,5 @@ def retinanet_resnet50_fpn(pretrained=False, progress=True,
state_dict = load_state_dict_from_url(model_urls['retinanet_resnet50_fpn_coco'],
progress=progress)
model.load_state_dict(state_dict)
overwrite_eps(model, 0.0)
return model
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment