"src/vscode:/vscode.git/clone" did not exist on "3231712b7da66f9c1b160f2d5363b186b03daafe"
Commit 41ee590e authored by yl-1993's avatar yl-1993
Browse files

keep name rule consistent with official model

parent 3333bab6
...@@ -6,28 +6,27 @@ from .weight_init import constant_init, normal_init, kaiming_init ...@@ -6,28 +6,27 @@ from .weight_init import constant_init, normal_init, kaiming_init
from ..runner import load_checkpoint from ..runner import load_checkpoint
def conv3x3(in_planes, out_planes, dilation=1, bias=False): def conv3x3(in_planes, out_planes, dilation=1):
"3x3 convolution with padding" "3x3 convolution with padding"
return nn.Conv2d( return nn.Conv2d(
in_planes, in_planes,
out_planes, out_planes,
kernel_size=3, kernel_size=3,
padding=dilation, padding=dilation,
dilation=dilation, dilation=dilation)
bias=bias)
def make_vgg_layer(inplanes, planes, num_blocks, dilation=1, with_bn=False): def make_vgg_layer(inplanes, planes, num_blocks, dilation=1, with_bn=False):
layers = [] layers = []
for _ in range(num_blocks): for _ in range(num_blocks):
layers.append(conv3x3(inplanes, planes, dilation, not with_bn)) layers.append(conv3x3(inplanes, planes, dilation))
if with_bn: if with_bn:
layers.append(nn.BatchNorm2d(planes)) layers.append(nn.BatchNorm2d(planes))
layers.append(nn.ReLU(inplace=True)) layers.append(nn.ReLU(inplace=True))
inplanes = planes inplanes = planes
layers.append(nn.MaxPool2d(kernel_size=2, stride=2)) layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
return nn.Sequential(*layers) return layers
class VGG(nn.Module): class VGG(nn.Module):
...@@ -69,9 +68,9 @@ class VGG(nn.Module): ...@@ -69,9 +68,9 @@ class VGG(nn.Module):
raise KeyError('invalid depth {} for vgg'.format(depth)) raise KeyError('invalid depth {} for vgg'.format(depth))
assert num_stages >= 1 and num_stages <= 5 assert num_stages >= 1 and num_stages <= 5
stage_blocks = self.arch_settings[depth] stage_blocks = self.arch_settings[depth]
stage_blocks = stage_blocks[:num_stages] self.stage_blocks = stage_blocks[:num_stages]
assert len(dilations) == num_stages assert len(dilations) == num_stages
assert max(out_indices) < num_stages assert max(out_indices) <= num_stages
self.num_classes = num_classes self.num_classes = num_classes
self.out_indices = out_indices self.out_indices = out_indices
...@@ -80,8 +79,12 @@ class VGG(nn.Module): ...@@ -80,8 +79,12 @@ class VGG(nn.Module):
self.bn_frozen = bn_frozen self.bn_frozen = bn_frozen
self.inplanes = 3 self.inplanes = 3
self.vgg_layers = [] start_idx = 0
for i, num_blocks in enumerate(stage_blocks): vgg_layers = []
self.range_sub_modules = []
for i, num_blocks in enumerate(self.stage_blocks):
num_modules = num_blocks * (2 + with_bn) + 1
end_idx = start_idx + num_modules
dilation = dilations[i] dilation = dilations[i]
planes = 64 * 2**i if i < 4 else 512 planes = 64 * 2**i if i < 4 else 512
vgg_layer = make_vgg_layer( vgg_layer = make_vgg_layer(
...@@ -90,10 +93,12 @@ class VGG(nn.Module): ...@@ -90,10 +93,12 @@ class VGG(nn.Module):
num_blocks, num_blocks,
dilation=dilation, dilation=dilation,
with_bn=with_bn) with_bn=with_bn)
vgg_layers.extend(vgg_layer)
self.inplanes = planes self.inplanes = planes
layer_name = 'layer{}'.format(i + 1) self.range_sub_modules.append([start_idx, end_idx])
self.add_module(layer_name, vgg_layer) start_idx = end_idx
self.vgg_layers.append(layer_name) self.module_name = 'features'
self.add_module(self.module_name, nn.Sequential(*vgg_layers))
if self.num_classes > 0: if self.num_classes > 0:
self.classifier = nn.Sequential( self.classifier = nn.Sequential(
...@@ -123,8 +128,10 @@ class VGG(nn.Module): ...@@ -123,8 +128,10 @@ class VGG(nn.Module):
def forward(self, x): def forward(self, x):
outs = [] outs = []
for i, layer_name in enumerate(self.vgg_layers): vgg_layers = getattr(self, self.module_name)
vgg_layer = getattr(self, layer_name) for i, num_blocks in enumerate(self.stage_blocks):
for j in range(*self.range_sub_modules[i]):
vgg_layer = vgg_layers[j]
x = vgg_layer(x) x = vgg_layer(x)
if i in self.out_indices: if i in self.out_indices:
outs.append(x) outs.append(x)
...@@ -146,9 +153,11 @@ class VGG(nn.Module): ...@@ -146,9 +153,11 @@ class VGG(nn.Module):
if self.bn_frozen: if self.bn_frozen:
for params in m.parameters(): for params in m.parameters():
params.requires_grad = False params.requires_grad = False
vgg_layers = getattr(self, self.module_name)
if mode and self.frozen_stages >= 0: if mode and self.frozen_stages >= 0:
for i in range(1, self.frozen_stages + 1): for i in range(self.frozen_stages):
mod = getattr(self, 'layer{}'.format(i)) for j in range(*self.range_sub_modules[i]):
mod = vgg_layers[j]
mod.eval() mod.eval()
for param in mod.parameters(): for param in mod.parameters():
param.requires_grad = False param.requires_grad = False
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment