Unverified commit c5fee834, authored by Kai Chen, committed by GitHub
Browse files

Merge pull request #17 from yl-1993/cnn

support initializing with official pretrained model
parents 3333bab6 0164a046
import logging import logging
import math
import torch.nn as nn import torch.nn as nn
import torch.utils.checkpoint as cp import torch.utils.checkpoint as cp
......
...@@ -6,28 +6,27 @@ from .weight_init import constant_init, normal_init, kaiming_init ...@@ -6,28 +6,27 @@ from .weight_init import constant_init, normal_init, kaiming_init
from ..runner import load_checkpoint from ..runner import load_checkpoint
def conv3x3(in_planes, out_planes, dilation=1):
    """Build a 3x3 convolution whose padding equals its dilation.

    With a 3x3 kernel at the default stride of 1, ``padding == dilation``
    keeps the spatial size of the input unchanged.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        dilation: Dilation rate of the kernel; also used as the padding.

    Returns:
        The configured ``nn.Conv2d`` layer. ``bias`` is not passed, so the
        ``nn.Conv2d`` default applies.
    """
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        padding=dilation,
        dilation=dilation)
def make_vgg_layer(inplanes, planes, num_blocks, dilation=1, with_bn=False):
    """Assemble the modules of one VGG stage as a flat list.

    Each of the ``num_blocks`` blocks contributes a 3x3 convolution, an
    optional ``nn.BatchNorm2d``, and an in-place ``nn.ReLU``. A single 2x2,
    stride-2 ``nn.MaxPool2d`` closes the stage. The list therefore holds
    ``num_blocks * (2 + with_bn) + 1`` modules.

    Args:
        inplanes: Number of input channels of the first convolution.
        planes: Number of output channels of every convolution in the stage.
        num_blocks: Number of conv(-bn)-relu blocks in the stage.
        dilation: Dilation rate forwarded to every convolution.
        with_bn: Whether to insert batch normalization after each convolution.

    Returns:
        A plain ``list`` of modules (not an ``nn.Sequential``), so the caller
        can concatenate several stages into one container.
    """
    layers = []
    for _ in range(num_blocks):
        layers.append(conv3x3(inplanes, planes, dilation))
        if with_bn:
            layers.append(nn.BatchNorm2d(planes))
        layers.append(nn.ReLU(inplace=True))
        # Only the first conv of a stage changes the channel count.
        inplanes = planes
    layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
    return layers
class VGG(nn.Module): class VGG(nn.Module):
...@@ -69,9 +68,9 @@ class VGG(nn.Module): ...@@ -69,9 +68,9 @@ class VGG(nn.Module):
raise KeyError('invalid depth {} for vgg'.format(depth)) raise KeyError('invalid depth {} for vgg'.format(depth))
assert num_stages >= 1 and num_stages <= 5 assert num_stages >= 1 and num_stages <= 5
stage_blocks = self.arch_settings[depth] stage_blocks = self.arch_settings[depth]
stage_blocks = stage_blocks[:num_stages] self.stage_blocks = stage_blocks[:num_stages]
assert len(dilations) == num_stages assert len(dilations) == num_stages
assert max(out_indices) < num_stages assert max(out_indices) <= num_stages
self.num_classes = num_classes self.num_classes = num_classes
self.out_indices = out_indices self.out_indices = out_indices
...@@ -80,8 +79,12 @@ class VGG(nn.Module): ...@@ -80,8 +79,12 @@ class VGG(nn.Module):
self.bn_frozen = bn_frozen self.bn_frozen = bn_frozen
self.inplanes = 3 self.inplanes = 3
self.vgg_layers = [] start_idx = 0
for i, num_blocks in enumerate(stage_blocks): vgg_layers = []
self.range_sub_modules = []
for i, num_blocks in enumerate(self.stage_blocks):
num_modules = num_blocks * (2 + with_bn) + 1
end_idx = start_idx + num_modules
dilation = dilations[i] dilation = dilations[i]
planes = 64 * 2**i if i < 4 else 512 planes = 64 * 2**i if i < 4 else 512
vgg_layer = make_vgg_layer( vgg_layer = make_vgg_layer(
...@@ -90,10 +93,12 @@ class VGG(nn.Module): ...@@ -90,10 +93,12 @@ class VGG(nn.Module):
num_blocks, num_blocks,
dilation=dilation, dilation=dilation,
with_bn=with_bn) with_bn=with_bn)
vgg_layers.extend(vgg_layer)
self.inplanes = planes self.inplanes = planes
layer_name = 'layer{}'.format(i + 1) self.range_sub_modules.append([start_idx, end_idx])
self.add_module(layer_name, vgg_layer) start_idx = end_idx
self.vgg_layers.append(layer_name) self.module_name = 'features'
self.add_module(self.module_name, nn.Sequential(*vgg_layers))
if self.num_classes > 0: if self.num_classes > 0:
self.classifier = nn.Sequential( self.classifier = nn.Sequential(
...@@ -123,8 +128,10 @@ class VGG(nn.Module): ...@@ -123,8 +128,10 @@ class VGG(nn.Module):
def forward(self, x): def forward(self, x):
outs = [] outs = []
for i, layer_name in enumerate(self.vgg_layers): vgg_layers = getattr(self, self.module_name)
vgg_layer = getattr(self, layer_name) for i, num_blocks in enumerate(self.stage_blocks):
for j in range(*self.range_sub_modules[i]):
vgg_layer = vgg_layers[j]
x = vgg_layer(x) x = vgg_layer(x)
if i in self.out_indices: if i in self.out_indices:
outs.append(x) outs.append(x)
...@@ -146,9 +153,11 @@ class VGG(nn.Module): ...@@ -146,9 +153,11 @@ class VGG(nn.Module):
if self.bn_frozen: if self.bn_frozen:
for params in m.parameters(): for params in m.parameters():
params.requires_grad = False params.requires_grad = False
vgg_layers = getattr(self, self.module_name)
if mode and self.frozen_stages >= 0: if mode and self.frozen_stages >= 0:
for i in range(1, self.frozen_stages + 1): for i in range(self.frozen_stages):
mod = getattr(self, 'layer{}'.format(i)) for j in range(*self.range_sub_modules[i]):
mod = vgg_layers[j]
mod.eval() mod.eval()
for param in mod.parameters(): for param in mod.parameters():
param.requires_grad = False param.requires_grad = False
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment