Commit 41ee590e authored by yl-1993's avatar yl-1993
Browse files

keep name rule consistent with official model

parent 3333bab6
......@@ -6,28 +6,27 @@ from .weight_init import constant_init, normal_init, kaiming_init
from ..runner import load_checkpoint
def conv3x3(in_planes, out_planes, dilation=1, bias=False):
def conv3x3(in_planes, out_planes, dilation=1):
"3x3 convolution with padding"
return nn.Conv2d(
in_planes,
out_planes,
kernel_size=3,
padding=dilation,
dilation=dilation,
bias=bias)
dilation=dilation)
def make_vgg_layer(inplanes, planes, num_blocks, dilation=1, with_bn=False):
layers = []
for _ in range(num_blocks):
layers.append(conv3x3(inplanes, planes, dilation, not with_bn))
layers.append(conv3x3(inplanes, planes, dilation))
if with_bn:
layers.append(nn.BatchNorm2d(planes))
layers.append(nn.ReLU(inplace=True))
inplanes = planes
layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
return nn.Sequential(*layers)
return layers
class VGG(nn.Module):
......@@ -69,9 +68,9 @@ class VGG(nn.Module):
raise KeyError('invalid depth {} for vgg'.format(depth))
assert num_stages >= 1 and num_stages <= 5
stage_blocks = self.arch_settings[depth]
stage_blocks = stage_blocks[:num_stages]
self.stage_blocks = stage_blocks[:num_stages]
assert len(dilations) == num_stages
assert max(out_indices) < num_stages
assert max(out_indices) <= num_stages
self.num_classes = num_classes
self.out_indices = out_indices
......@@ -80,8 +79,12 @@ class VGG(nn.Module):
self.bn_frozen = bn_frozen
self.inplanes = 3
self.vgg_layers = []
for i, num_blocks in enumerate(stage_blocks):
start_idx = 0
vgg_layers = []
self.range_sub_modules = []
for i, num_blocks in enumerate(self.stage_blocks):
num_modules = num_blocks * (2 + with_bn) + 1
end_idx = start_idx + num_modules
dilation = dilations[i]
planes = 64 * 2**i if i < 4 else 512
vgg_layer = make_vgg_layer(
......@@ -90,10 +93,12 @@ class VGG(nn.Module):
num_blocks,
dilation=dilation,
with_bn=with_bn)
vgg_layers.extend(vgg_layer)
self.inplanes = planes
layer_name = 'layer{}'.format(i + 1)
self.add_module(layer_name, vgg_layer)
self.vgg_layers.append(layer_name)
self.range_sub_modules.append([start_idx, end_idx])
start_idx = end_idx
self.module_name = 'features'
self.add_module(self.module_name, nn.Sequential(*vgg_layers))
if self.num_classes > 0:
self.classifier = nn.Sequential(
......@@ -123,9 +128,11 @@ class VGG(nn.Module):
def forward(self, x):
outs = []
for i, layer_name in enumerate(self.vgg_layers):
vgg_layer = getattr(self, layer_name)
x = vgg_layer(x)
vgg_layers = getattr(self, self.module_name)
for i, num_blocks in enumerate(self.stage_blocks):
for j in range(*self.range_sub_modules[i]):
vgg_layer = vgg_layers[j]
x = vgg_layer(x)
if i in self.out_indices:
outs.append(x)
if self.num_classes > 0:
......@@ -146,9 +153,11 @@ class VGG(nn.Module):
if self.bn_frozen:
for params in m.parameters():
params.requires_grad = False
vgg_layers = getattr(self, self.module_name)
if mode and self.frozen_stages >= 0:
for i in range(1, self.frozen_stages + 1):
mod = getattr(self, 'layer{}'.format(i))
mod.eval()
for param in mod.parameters():
param.requires_grad = False
for i in range(self.frozen_stages):
for j in range(*self.range_sub_modules[i]):
mod = vgg_layers[j]
mod.eval()
for param in mod.parameters():
param.requires_grad = False
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment