Unverified Commit 1e27b533 authored by talcs's avatar talcs Committed by GitHub
Browse files

replaced mean on dimensions 2,3 by adaptive_avg_pooling2d (#1838)

* replaced mean on dimensions 2,3 by adaptive_avg_pooling2d with destination of 1, to remove hardcoded dimension ordering

* replaced reshape command by torch.squeeze after global_avg_pool2d, which is cleaner

* reshape rather than squeeze for BS=1

* remove import torch
parent 0156d58e
...@@ -151,7 +151,8 @@ class MobileNetV2(nn.Module): ...@@ -151,7 +151,8 @@ class MobileNetV2(nn.Module):
# This exists since TorchScript doesn't support inheritance, so the superclass method # This exists since TorchScript doesn't support inheritance, so the superclass method
# (this one) needs to have a name other than `forward` that can be accessed in a subclass # (this one) needs to have a name other than `forward` that can be accessed in a subclass
x = self.features(x) x = self.features(x)
x = x.mean([2, 3]) # Cannot use "squeeze" as batch-size can be 1 => must use reshape with x.shape[0]
x = nn.functional.adaptive_avg_pool2d(x, 1).reshape(x.shape[0], -1)
x = self.classifier(x) x = self.classifier(x)
return x return x
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment