Commit 83b2dfb2 authored by ekka's avatar ekka Committed by Francisco Massa
Browse files

Changing to AdaptiveAvgPool2d on VGG (#747)

The update allows VGG to process images larger or smaller than prescribed imagenet size using adaptive average pooling. Will be useful while finetuning or testing on different resolution images. Similar to https://github.com/pytorch/vision/pull/643 and https://github.com/pytorch/vision/pull/672. I did not include adaptive avg pool in features or classifier block so that these predefined blocks can be used as it is.
parent 6434dead
...@@ -25,6 +25,7 @@ class VGG(nn.Module): ...@@ -25,6 +25,7 @@ class VGG(nn.Module):
def __init__(self, features, num_classes=1000, init_weights=True): def __init__(self, features, num_classes=1000, init_weights=True):
super(VGG, self).__init__() super(VGG, self).__init__()
self.features = features self.features = features
self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
self.classifier = nn.Sequential( self.classifier = nn.Sequential(
nn.Linear(512 * 7 * 7, 4096), nn.Linear(512 * 7 * 7, 4096),
nn.ReLU(True), nn.ReLU(True),
...@@ -39,6 +40,7 @@ class VGG(nn.Module): ...@@ -39,6 +40,7 @@ class VGG(nn.Module):
def forward(self, x): def forward(self, x):
x = self.features(x) x = self.features(x)
x = self.avgpool(x)
x = x.view(x.size(0), -1) x = x.view(x.size(0), -1)
x = self.classifier(x) x = self.classifier(x)
return x return x
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment