Unverified Commit 9c26a104 authored by Jintao Lin's avatar Jintao Lin Committed by GitHub
Browse files

empty tensor inference backward compatible (#1131)

* empty tensor inference backward continuity

* update

* add 3d
parent 59ed0ddd
...@@ -128,8 +128,8 @@ class ConvTranspose3d(nn.ConvTranspose3d): ...@@ -128,8 +128,8 @@ class ConvTranspose3d(nn.ConvTranspose3d):
class MaxPool2d(nn.MaxPool2d): class MaxPool2d(nn.MaxPool2d):
def forward(self, x): def forward(self, x):
# PyTorch 1.7 does not support empty tensor inference yet # PyTorch 1.9 does not support empty tensor inference yet
if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 7)): if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 9)):
out_shape = list(x.shape[:2]) out_shape = list(x.shape[:2])
for i, k, p, s, d in zip(x.shape[-2:], _pair(self.kernel_size), for i, k, p, s, d in zip(x.shape[-2:], _pair(self.kernel_size),
_pair(self.padding), _pair(self.stride), _pair(self.padding), _pair(self.stride),
...@@ -146,8 +146,8 @@ class MaxPool2d(nn.MaxPool2d): ...@@ -146,8 +146,8 @@ class MaxPool2d(nn.MaxPool2d):
class MaxPool3d(nn.MaxPool3d): class MaxPool3d(nn.MaxPool3d):
def forward(self, x): def forward(self, x):
# PyTorch 1.7 does not support empty tensor inference yet # PyTorch 1.9 does not support empty tensor inference yet
if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 7)): if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 9)):
out_shape = list(x.shape[:2]) out_shape = list(x.shape[:2])
for i, k, p, s, d in zip(x.shape[-3:], _triple(self.kernel_size), for i, k, p, s, d in zip(x.shape[-3:], _triple(self.kernel_size),
_triple(self.padding), _triple(self.padding),
......
...@@ -330,7 +330,7 @@ def test_linear(in_w, in_h, in_feature, out_feature): ...@@ -330,7 +330,7 @@ def test_linear(in_w, in_h, in_feature, out_feature):
wrapper(x_empty) wrapper(x_empty)
@patch('mmcv.cnn.bricks.wrappers.TORCH_VERSION', (1, 8)) @patch('mmcv.cnn.bricks.wrappers.TORCH_VERSION', (1, 10))
def test_nn_op_forward_called(): def test_nn_op_forward_called():
for m in ['Conv2d', 'ConvTranspose2d', 'MaxPool2d']: for m in ['Conv2d', 'ConvTranspose2d', 'MaxPool2d']:
...@@ -347,6 +347,20 @@ def test_nn_op_forward_called(): ...@@ -347,6 +347,20 @@ def test_nn_op_forward_called():
wrapper(x_normal) wrapper(x_normal)
nn_module_forward.assert_called_with(x_normal) nn_module_forward.assert_called_with(x_normal)
for m in ['Conv3d', 'ConvTranspose3d', 'MaxPool3d']:
with patch(f'torch.nn.{m}.forward') as nn_module_forward:
# randn input
x_empty = torch.randn(0, 3, 10, 10, 10)
wrapper = eval(m)(3, 2, 1)
wrapper(x_empty)
nn_module_forward.assert_called_with(x_empty)
# non-randn input
x_normal = torch.randn(1, 3, 10, 10, 10)
wrapper = eval(m)(3, 2, 1)
wrapper(x_normal)
nn_module_forward.assert_called_with(x_normal)
with patch('torch.nn.Linear.forward') as nn_module_forward: with patch('torch.nn.Linear.forward') as nn_module_forward:
# randn input # randn input
x_empty = torch.randn(0, 3) x_empty = torch.randn(0, 3)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment