Unverified Commit 3e69462f authored by Michael Kösel's avatar Michael Kösel Committed by GitHub
Browse files

Add support for other normalizations in MobileNetV2 (#2267)

* Add norm_layer to MobileNetV2

* Add simple test case

* Small fix
parent c8064cdb
......@@ -2,6 +2,7 @@ from common_utils import TestCase, map_nested_tensor_object, freeze_rng_state
from collections import OrderedDict
from itertools import product
import torch
import torch.nn as nn
import numpy as np
from torchvision import models
import unittest
......@@ -240,6 +241,17 @@ class ModelTester(TestCase):
out = model(x)
self.assertEqual(out.shape[-1], 1000)
def test_mobilenetv2_norm_layer(self):
model = models.__dict__["mobilenet_v2"]()
self.assertTrue(any(isinstance(x, nn.BatchNorm2d) for x in model.modules()))
def get_gn(num_channels):
return nn.GroupNorm(32, num_channels)
model = models.__dict__["mobilenet_v2"](norm_layer=get_gn)
self.assertFalse(any(isinstance(x, nn.BatchNorm2d) for x in model.modules()))
self.assertTrue(any(isinstance(x, nn.GroupNorm) for x in model.modules()))
def test_fasterrcnn_double(self):
model = models.detection.fasterrcnn_resnet50_fpn(num_classes=50, pretrained_backbone=False)
model.double()
......
......@@ -31,34 +31,39 @@ def _make_divisible(v, divisor, min_value=None):
class ConvBNReLU(nn.Sequential):
def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1, norm_layer=None):
padding = (kernel_size - 1) // 2
if norm_layer is None:
norm_layer = nn.BatchNorm2d
super(ConvBNReLU, self).__init__(
nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
nn.BatchNorm2d(out_planes),
norm_layer(out_planes),
nn.ReLU6(inplace=True)
)
class InvertedResidual(nn.Module):
def __init__(self, inp, oup, stride, expand_ratio):
def __init__(self, inp, oup, stride, expand_ratio, norm_layer=None):
super(InvertedResidual, self).__init__()
self.stride = stride
assert stride in [1, 2]
if norm_layer is None:
norm_layer = nn.BatchNorm2d
hidden_dim = int(round(inp * expand_ratio))
self.use_res_connect = self.stride == 1 and inp == oup
layers = []
if expand_ratio != 1:
# pw
layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1, norm_layer=norm_layer))
layers.extend([
# dw
ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim),
ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim, norm_layer=norm_layer),
# pw-linear
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
nn.BatchNorm2d(oup),
norm_layer(oup),
])
self.conv = nn.Sequential(*layers)
......@@ -75,7 +80,8 @@ class MobileNetV2(nn.Module):
width_mult=1.0,
inverted_residual_setting=None,
round_nearest=8,
block=None):
block=None,
norm_layer=None):
"""
MobileNet V2 main class
......@@ -86,12 +92,17 @@ class MobileNetV2(nn.Module):
round_nearest (int): Round the number of channels in each layer to be a multiple of this number
Set to 1 to turn off rounding
block: Module specifying inverted residual building block for mobilenet
norm_layer: Module specifying the normalization layer to use
"""
super(MobileNetV2, self).__init__()
if block is None:
block = InvertedResidual
if norm_layer is None:
norm_layer = nn.BatchNorm2d
input_channel = 32
last_channel = 1280
......@@ -115,16 +126,16 @@ class MobileNetV2(nn.Module):
# building first layer
input_channel = _make_divisible(input_channel * width_mult, round_nearest)
self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
features = [ConvBNReLU(3, input_channel, stride=2)]
features = [ConvBNReLU(3, input_channel, stride=2, norm_layer=norm_layer)]
# building inverted residual blocks
for t, c, n, s in inverted_residual_setting:
output_channel = _make_divisible(c * width_mult, round_nearest)
for i in range(n):
stride = s if i == 0 else 1
features.append(block(input_channel, output_channel, stride, expand_ratio=t))
features.append(block(input_channel, output_channel, stride, expand_ratio=t, norm_layer=norm_layer))
input_channel = output_channel
# building last several layers
features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1))
features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1, norm_layer=norm_layer))
# make it nn.Sequential
self.features = nn.Sequential(*features)
......@@ -140,7 +151,7 @@ class MobileNetV2(nn.Module):
nn.init.kaiming_normal_(m.weight, mode='fan_out')
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.BatchNorm2d):
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.ones_(m.weight)
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Linear):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment