Commit 2cae9509 authored by apache2046's avatar apache2046 Committed by Francisco Massa
Browse files

Fix the old flatten method which use the size(0) to caculate the batch size,...

Fix the old flatten method which use the size(0) to caculate the batch size, the old method will intruduce Gather opertion in the onnx output, which will faild parsed by tensorRT 5.0 (#1134)
parent bbd363ca
import torch
import torch.nn as nn
from .utils import load_state_dict_from_url
......@@ -43,7 +44,7 @@ class AlexNet(nn.Module):
def forward(self, x):
x = self.features(x)
x = self.avgpool(x)
x = x.view(x.size(0), 256 * 6 * 6)
x = torch.flatten(x, 1)
x = self.classifier(x)
return x
......
......@@ -154,7 +154,8 @@ class DenseNet(nn.Module):
def forward(self, x):
features = self.features(x)
out = F.relu(features, inplace=True)
out = F.adaptive_avg_pool2d(out, (1, 1)).view(features.size(0), -1)
out = F.adaptive_avg_pool2d(out, (1, 1))
out = torch.flatten(out, 1)
out = self.classifier(out)
return out
......
......@@ -151,7 +151,7 @@ class GoogLeNet(nn.Module):
x = self.avgpool(x)
# N x 1024 x 1 x 1
x = x.view(x.size(0), -1)
x = torch.flatten(x, 1)
# N x 1024
x = self.dropout(x)
x = self.fc(x)
......@@ -208,7 +208,7 @@ class InceptionAux(nn.Module):
# aux1: N x 512 x 4 x 4, aux2: N x 528 x 4 x 4
x = self.conv(x)
# N x 128 x 4 x 4
x = x.view(x.size(0), -1)
x = torch.flatten(x, 1)
# N x 2048
x = F.relu(self.fc1(x), inplace=True)
# N x 2048
......
......@@ -142,7 +142,7 @@ class Inception3(nn.Module):
# N x 2048 x 1 x 1
x = F.dropout(x, training=self.training)
# N x 2048 x 1 x 1
x = x.view(x.size(0), -1)
x = torch.flatten(x, 1)
# N x 2048
x = self.fc(x)
# N x 1000 (num_classes)
......@@ -334,7 +334,7 @@ class InceptionAux(nn.Module):
# Adaptive average pooling
x = F.adaptive_avg_pool2d(x, (1, 1))
# N x 768 x 1 x 1
x = x.view(x.size(0), -1)
x = torch.flatten(x, 1)
# N x 768
x = self.fc(x)
# N x 1000
......
import torch
import torch.nn as nn
from .utils import load_state_dict_from_url
......@@ -203,7 +204,7 @@ class ResNet(nn.Module):
x = self.layer4(x)
x = self.avgpool(x)
x = x.reshape(x.size(0), -1)
x = torch.flatten(x, 1)
x = self.fc(x)
return x
......
......@@ -99,7 +99,7 @@ class SqueezeNet(nn.Module):
def forward(self, x):
x = self.features(x)
x = self.classifier(x)
return x.view(x.size(0), -1)
return torch.flatten(x, 1)
def _squeezenet(version, pretrained, progress, **kwargs):
......
import torch
import torch.nn as nn
from .utils import load_state_dict_from_url
......@@ -41,7 +42,7 @@ class VGG(nn.Module):
def forward(self, x):
x = self.features(x)
x = self.avgpool(x)
x = x.view(x.size(0), -1)
x = torch.flatten(x, 1)
x = self.classifier(x)
return x
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment