Unverified Commit b266c2f1 authored by Jithendra Paruchuri's avatar Jithendra Paruchuri Committed by GitHub
Browse files

Replace reshape with flatten (#3462)



Current implementation is generating bad graph after onnx conversion. So replacing with flatten like in mobilenetv3 code.
Co-authored-by: default avatarVasilis Vryniotis <datumbox@users.noreply.github.com>
parent 74d5e71d
import torch
from torch import nn
from torch import Tensor
from .utils import load_state_dict_from_url
......@@ -189,8 +190,9 @@ class MobileNetV2(nn.Module):
# This exists since TorchScript doesn't support inheritance, so the superclass method
# (this one) needs to have a name other than `forward` that can be accessed in a subclass
x = self.features(x)
# Cannot use "squeeze" as batch-size can be 1 => must use reshape with x.shape[0]
x = nn.functional.adaptive_avg_pool2d(x, (1, 1)).reshape(x.shape[0], -1)
# Cannot use "squeeze" as batch-size can be 1
x = nn.functional.adaptive_avg_pool2d(x, (1, 1))
x = torch.flatten(x, 1)
x = self.classifier(x)
return x
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment