Unverified Commit 005c4087 authored by Wenwei Zhang's avatar Wenwei Zhang Committed by GitHub
Browse files

Fix wrappers version comparison (#602)

* add version check in wrappers

* fix assersion

* use digital version for version comparison

* fix unit tests

* reformat

* fall back to compare the first two version

* fix unittest

* fix unittest

* fix unit test

* clean unnecessary change
parent fe83261b
...@@ -12,6 +12,10 @@ from torch.nn.modules.utils import _pair ...@@ -12,6 +12,10 @@ from torch.nn.modules.utils import _pair
from .registry import CONV_LAYERS, UPSAMPLE_LAYERS from .registry import CONV_LAYERS, UPSAMPLE_LAYERS
# torch.__version__ could be 1.3.1+cu92, we only need the first two
# for comparison
TORCH_VERSION = tuple(int(x) for x in torch.__version__.split('.')[:2])
class NewEmptyTensorOp(torch.autograd.Function): class NewEmptyTensorOp(torch.autograd.Function):
...@@ -30,7 +34,7 @@ class NewEmptyTensorOp(torch.autograd.Function): ...@@ -30,7 +34,7 @@ class NewEmptyTensorOp(torch.autograd.Function):
class Conv2d(nn.Conv2d): class Conv2d(nn.Conv2d):
def forward(self, x): def forward(self, x):
if x.numel() == 0 and torch.__version__ <= '1.4.0': if x.numel() == 0 and TORCH_VERSION <= (1, 4):
out_shape = [x.shape[0], self.out_channels] out_shape = [x.shape[0], self.out_channels]
for i, k, p, s, d in zip(x.shape[-2:], self.kernel_size, for i, k, p, s, d in zip(x.shape[-2:], self.kernel_size,
self.padding, self.stride, self.dilation): self.padding, self.stride, self.dilation):
...@@ -53,7 +57,7 @@ class Conv2d(nn.Conv2d): ...@@ -53,7 +57,7 @@ class Conv2d(nn.Conv2d):
class ConvTranspose2d(nn.ConvTranspose2d): class ConvTranspose2d(nn.ConvTranspose2d):
def forward(self, x): def forward(self, x):
if x.numel() == 0 and torch.__version__ <= '1.4.0': if x.numel() == 0 and TORCH_VERSION <= (1, 4):
out_shape = [x.shape[0], self.out_channels] out_shape = [x.shape[0], self.out_channels]
for i, k, p, s, d, op in zip(x.shape[-2:], self.kernel_size, for i, k, p, s, d, op in zip(x.shape[-2:], self.kernel_size,
self.padding, self.stride, self.padding, self.stride,
...@@ -74,7 +78,7 @@ class MaxPool2d(nn.MaxPool2d): ...@@ -74,7 +78,7 @@ class MaxPool2d(nn.MaxPool2d):
def forward(self, x): def forward(self, x):
# PyTorch 1.6 does not support empty tensor inference yet # PyTorch 1.6 does not support empty tensor inference yet
if x.numel() == 0 and torch.__version__ <= '1.6.0': if x.numel() == 0 and TORCH_VERSION <= (1, 6):
out_shape = list(x.shape[:2]) out_shape = list(x.shape[:2])
for i, k, p, s, d in zip(x.shape[-2:], _pair(self.kernel_size), for i, k, p, s, d in zip(x.shape[-2:], _pair(self.kernel_size),
_pair(self.padding), _pair(self.stride), _pair(self.padding), _pair(self.stride),
...@@ -91,7 +95,8 @@ class MaxPool2d(nn.MaxPool2d): ...@@ -91,7 +95,8 @@ class MaxPool2d(nn.MaxPool2d):
class Linear(torch.nn.Linear): class Linear(torch.nn.Linear):
def forward(self, x): def forward(self, x):
if x.numel() == 0: # empty tensor forward of Linear layer is supported in Pytorch 1.6
if x.numel() == 0 and TORCH_VERSION <= (1, 5):
out_shape = [x.shape[0], self.out_features] out_shape = [x.shape[0], self.out_features]
empty = NewEmptyTensorOp.apply(x, out_shape) empty = NewEmptyTensorOp.apply(x, out_shape)
if self.training: if self.training:
......
...@@ -169,7 +169,7 @@ def test_linear(): ...@@ -169,7 +169,7 @@ def test_linear():
wrapper(x_empty) wrapper(x_empty)
@patch('torch.__version__', '1.6.1') @patch('mmcv.cnn.bricks.wrappers.TORCH_VERSION', (1, 7))
def test_nn_op_forward_called(): def test_nn_op_forward_called():
for m in ['Conv2d', 'ConvTranspose2d', 'MaxPool2d']: for m in ['Conv2d', 'ConvTranspose2d', 'MaxPool2d']:
...@@ -191,7 +191,7 @@ def test_nn_op_forward_called(): ...@@ -191,7 +191,7 @@ def test_nn_op_forward_called():
x_empty = torch.randn(0, 3) x_empty = torch.randn(0, 3)
wrapper = Linear(3, 3) wrapper = Linear(3, 3)
wrapper(x_empty) wrapper(x_empty)
nn_module_forward.assert_not_called() nn_module_forward.assert_called_with(x_empty)
# non-randn input # non-randn input
x_normal = torch.randn(1, 3) x_normal = torch.randn(1, 3)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment