Commit 2cae9509 authored by apache2046's avatar apache2046 Committed by Francisco Massa
Browse files

Fix the old flatten method which use the size(0) to caculate the batch size,...

Fix the old flatten method which use the size(0) to caculate the batch size, the old method will intruduce Gather opertion in the onnx output, which will faild parsed by tensorRT 5.0 (#1134)
parent bbd363ca
import torch
import torch.nn as nn import torch.nn as nn
from .utils import load_state_dict_from_url from .utils import load_state_dict_from_url
...@@ -43,7 +44,7 @@ class AlexNet(nn.Module): ...@@ -43,7 +44,7 @@ class AlexNet(nn.Module):
def forward(self, x): def forward(self, x):
x = self.features(x) x = self.features(x)
x = self.avgpool(x) x = self.avgpool(x)
x = x.view(x.size(0), 256 * 6 * 6) x = torch.flatten(x, 1)
x = self.classifier(x) x = self.classifier(x)
return x return x
......
...@@ -154,7 +154,8 @@ class DenseNet(nn.Module): ...@@ -154,7 +154,8 @@ class DenseNet(nn.Module):
def forward(self, x): def forward(self, x):
features = self.features(x) features = self.features(x)
out = F.relu(features, inplace=True) out = F.relu(features, inplace=True)
out = F.adaptive_avg_pool2d(out, (1, 1)).view(features.size(0), -1) out = F.adaptive_avg_pool2d(out, (1, 1))
out = torch.flatten(out, 1)
out = self.classifier(out) out = self.classifier(out)
return out return out
......
...@@ -151,7 +151,7 @@ class GoogLeNet(nn.Module): ...@@ -151,7 +151,7 @@ class GoogLeNet(nn.Module):
x = self.avgpool(x) x = self.avgpool(x)
# N x 1024 x 1 x 1 # N x 1024 x 1 x 1
x = x.view(x.size(0), -1) x = torch.flatten(x, 1)
# N x 1024 # N x 1024
x = self.dropout(x) x = self.dropout(x)
x = self.fc(x) x = self.fc(x)
...@@ -208,7 +208,7 @@ class InceptionAux(nn.Module): ...@@ -208,7 +208,7 @@ class InceptionAux(nn.Module):
# aux1: N x 512 x 4 x 4, aux2: N x 528 x 4 x 4 # aux1: N x 512 x 4 x 4, aux2: N x 528 x 4 x 4
x = self.conv(x) x = self.conv(x)
# N x 128 x 4 x 4 # N x 128 x 4 x 4
x = x.view(x.size(0), -1) x = torch.flatten(x, 1)
# N x 2048 # N x 2048
x = F.relu(self.fc1(x), inplace=True) x = F.relu(self.fc1(x), inplace=True)
# N x 2048 # N x 2048
......
...@@ -142,7 +142,7 @@ class Inception3(nn.Module): ...@@ -142,7 +142,7 @@ class Inception3(nn.Module):
# N x 2048 x 1 x 1 # N x 2048 x 1 x 1
x = F.dropout(x, training=self.training) x = F.dropout(x, training=self.training)
# N x 2048 x 1 x 1 # N x 2048 x 1 x 1
x = x.view(x.size(0), -1) x = torch.flatten(x, 1)
# N x 2048 # N x 2048
x = self.fc(x) x = self.fc(x)
# N x 1000 (num_classes) # N x 1000 (num_classes)
...@@ -334,7 +334,7 @@ class InceptionAux(nn.Module): ...@@ -334,7 +334,7 @@ class InceptionAux(nn.Module):
# Adaptive average pooling # Adaptive average pooling
x = F.adaptive_avg_pool2d(x, (1, 1)) x = F.adaptive_avg_pool2d(x, (1, 1))
# N x 768 x 1 x 1 # N x 768 x 1 x 1
x = x.view(x.size(0), -1) x = torch.flatten(x, 1)
# N x 768 # N x 768
x = self.fc(x) x = self.fc(x)
# N x 1000 # N x 1000
......
import torch
import torch.nn as nn import torch.nn as nn
from .utils import load_state_dict_from_url from .utils import load_state_dict_from_url
...@@ -203,7 +204,7 @@ class ResNet(nn.Module): ...@@ -203,7 +204,7 @@ class ResNet(nn.Module):
x = self.layer4(x) x = self.layer4(x)
x = self.avgpool(x) x = self.avgpool(x)
x = x.reshape(x.size(0), -1) x = torch.flatten(x, 1)
x = self.fc(x) x = self.fc(x)
return x return x
......
...@@ -99,7 +99,7 @@ class SqueezeNet(nn.Module): ...@@ -99,7 +99,7 @@ class SqueezeNet(nn.Module):
def forward(self, x): def forward(self, x):
x = self.features(x) x = self.features(x)
x = self.classifier(x) x = self.classifier(x)
return x.view(x.size(0), -1) return torch.flatten(x, 1)
def _squeezenet(version, pretrained, progress, **kwargs): def _squeezenet(version, pretrained, progress, **kwargs):
......
import torch
import torch.nn as nn import torch.nn as nn
from .utils import load_state_dict_from_url from .utils import load_state_dict_from_url
...@@ -41,7 +42,7 @@ class VGG(nn.Module): ...@@ -41,7 +42,7 @@ class VGG(nn.Module):
def forward(self, x): def forward(self, x):
x = self.features(x) x = self.features(x)
x = self.avgpool(x) x = self.avgpool(x)
x = x.view(x.size(0), -1) x = torch.flatten(x, 1)
x = self.classifier(x) x = self.classifier(x)
return x return x
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment