Unverified Commit 3e06bc6f authored by Shunta Saito's avatar Shunta Saito Committed by GitHub
Browse files

Use Module objects instead of functions for some layers of Inception3 (#2287)

parent 11a39aaa
...@@ -90,8 +90,10 @@ class Inception3(nn.Module): ...@@ -90,8 +90,10 @@ class Inception3(nn.Module):
self.Conv2d_1a_3x3 = conv_block(3, 32, kernel_size=3, stride=2) self.Conv2d_1a_3x3 = conv_block(3, 32, kernel_size=3, stride=2)
self.Conv2d_2a_3x3 = conv_block(32, 32, kernel_size=3) self.Conv2d_2a_3x3 = conv_block(32, 32, kernel_size=3)
self.Conv2d_2b_3x3 = conv_block(32, 64, kernel_size=3, padding=1) self.Conv2d_2b_3x3 = conv_block(32, 64, kernel_size=3, padding=1)
self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2)
self.Conv2d_3b_1x1 = conv_block(64, 80, kernel_size=1) self.Conv2d_3b_1x1 = conv_block(64, 80, kernel_size=1)
self.Conv2d_4a_3x3 = conv_block(80, 192, kernel_size=3) self.Conv2d_4a_3x3 = conv_block(80, 192, kernel_size=3)
self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=2)
self.Mixed_5b = inception_a(192, pool_features=32) self.Mixed_5b = inception_a(192, pool_features=32)
self.Mixed_5c = inception_a(256, pool_features=64) self.Mixed_5c = inception_a(256, pool_features=64)
self.Mixed_5d = inception_a(288, pool_features=64) self.Mixed_5d = inception_a(288, pool_features=64)
...@@ -105,6 +107,8 @@ class Inception3(nn.Module): ...@@ -105,6 +107,8 @@ class Inception3(nn.Module):
self.Mixed_7a = inception_d(768) self.Mixed_7a = inception_d(768)
self.Mixed_7b = inception_e(1280) self.Mixed_7b = inception_e(1280)
self.Mixed_7c = inception_e(2048) self.Mixed_7c = inception_e(2048)
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.dropout = nn.Dropout()
self.fc = nn.Linear(2048, num_classes) self.fc = nn.Linear(2048, num_classes)
if init_weights: if init_weights:
for m in self.modules(): for m in self.modules():
...@@ -136,13 +140,13 @@ class Inception3(nn.Module): ...@@ -136,13 +140,13 @@ class Inception3(nn.Module):
# N x 32 x 147 x 147 # N x 32 x 147 x 147
x = self.Conv2d_2b_3x3(x) x = self.Conv2d_2b_3x3(x)
# N x 64 x 147 x 147 # N x 64 x 147 x 147
x = F.max_pool2d(x, kernel_size=3, stride=2) x = self.maxpool1(x)
# N x 64 x 73 x 73 # N x 64 x 73 x 73
x = self.Conv2d_3b_1x1(x) x = self.Conv2d_3b_1x1(x)
# N x 80 x 73 x 73 # N x 80 x 73 x 73
x = self.Conv2d_4a_3x3(x) x = self.Conv2d_4a_3x3(x)
# N x 192 x 71 x 71 # N x 192 x 71 x 71
x = F.max_pool2d(x, kernel_size=3, stride=2) x = self.maxpool2(x)
# N x 192 x 35 x 35 # N x 192 x 35 x 35
x = self.Mixed_5b(x) x = self.Mixed_5b(x)
# N x 256 x 35 x 35 # N x 256 x 35 x 35
...@@ -173,9 +177,9 @@ class Inception3(nn.Module): ...@@ -173,9 +177,9 @@ class Inception3(nn.Module):
x = self.Mixed_7c(x) x = self.Mixed_7c(x)
# N x 2048 x 8 x 8 # N x 2048 x 8 x 8
# Adaptive average pooling # Adaptive average pooling
x = F.adaptive_avg_pool2d(x, (1, 1)) x = self.avgpool(x)
# N x 2048 x 1 x 1 # N x 2048 x 1 x 1
x = F.dropout(x, training=self.training) x = self.dropout(x)
# N x 2048 x 1 x 1 # N x 2048 x 1 x 1
x = torch.flatten(x, 1) x = torch.flatten(x, 1)
# N x 2048 # N x 2048
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment