Unverified Commit 9036241e authored by youkaichao's avatar youkaichao Committed by GitHub
Browse files

[Enhancement] Change the order of condition to make fx wok (#2883)

parent f64d4858
......@@ -41,7 +41,7 @@ class NewEmptyTensorOp(torch.autograd.Function):
class Conv2d(nn.Conv2d):
def forward(self, x: torch.Tensor) -> torch.Tensor:
if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)):
if obsolete_torch_version(TORCH_VERSION, (1, 4)) and x.numel() == 0:
out_shape = [x.shape[0], self.out_channels]
for i, k, p, s, d in zip(x.shape[-2:], self.kernel_size,
self.padding, self.stride, self.dilation):
......@@ -62,7 +62,7 @@ class Conv2d(nn.Conv2d):
class Conv3d(nn.Conv3d):
def forward(self, x: torch.Tensor) -> torch.Tensor:
if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)):
if obsolete_torch_version(TORCH_VERSION, (1, 4)) and x.numel() == 0:
out_shape = [x.shape[0], self.out_channels]
for i, k, p, s, d in zip(x.shape[-3:], self.kernel_size,
self.padding, self.stride, self.dilation):
......@@ -84,7 +84,7 @@ class Conv3d(nn.Conv3d):
class ConvTranspose2d(nn.ConvTranspose2d):
def forward(self, x: torch.Tensor) -> torch.Tensor:
if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)):
if obsolete_torch_version(TORCH_VERSION, (1, 4)) and x.numel() == 0:
out_shape = [x.shape[0], self.out_channels]
for i, k, p, s, d, op in zip(x.shape[-2:], self.kernel_size,
self.padding, self.stride,
......@@ -106,7 +106,7 @@ class ConvTranspose2d(nn.ConvTranspose2d):
class ConvTranspose3d(nn.ConvTranspose3d):
def forward(self, x: torch.Tensor) -> torch.Tensor:
if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)):
if obsolete_torch_version(TORCH_VERSION, (1, 4)) and x.numel() == 0:
out_shape = [x.shape[0], self.out_channels]
for i, k, p, s, d, op in zip(x.shape[-3:], self.kernel_size,
self.padding, self.stride,
......@@ -127,7 +127,7 @@ class MaxPool2d(nn.MaxPool2d):
def forward(self, x: torch.Tensor) -> torch.Tensor:
# PyTorch 1.9 does not support empty tensor inference yet
if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 9)):
if obsolete_torch_version(TORCH_VERSION, (1, 9)) and x.numel() == 0:
out_shape = list(x.shape[:2])
for i, k, p, s, d in zip(x.shape[-2:], _pair(self.kernel_size),
_pair(self.padding), _pair(self.stride),
......@@ -145,7 +145,7 @@ class MaxPool3d(nn.MaxPool3d):
def forward(self, x: torch.Tensor) -> torch.Tensor:
# PyTorch 1.9 does not support empty tensor inference yet
if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 9)):
if obsolete_torch_version(TORCH_VERSION, (1, 9)) and x.numel() == 0:
out_shape = list(x.shape[:2])
for i, k, p, s, d in zip(x.shape[-3:], _triple(self.kernel_size),
_triple(self.padding),
......@@ -164,7 +164,7 @@ class Linear(torch.nn.Linear):
def forward(self, x: torch.Tensor) -> torch.Tensor:
# empty tensor forward of Linear layer is supported in Pytorch 1.6
if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 5)):
if obsolete_torch_version(TORCH_VERSION, (1, 5)) and x.numel() == 0:
out_shape = [x.shape[0], self.out_features]
empty = NewEmptyTensorOp.apply(x, out_shape)
if self.training:
......
......@@ -4,6 +4,8 @@ from unittest.mock import patch
import pytest
import torch
import torch.nn as nn
from mmengine.utils import digit_version
from mmengine.utils.dl_utils import TORCH_VERSION
from mmcv.cnn.bricks import (Conv2d, Conv3d, ConvTranspose2d, ConvTranspose3d,
Linear, MaxPool2d, MaxPool3d)
......@@ -374,3 +376,21 @@ def test_nn_op_forward_called():
wrapper = Linear(3, 3)
wrapper(x_normal)
nn_module_forward.assert_called_with(x_normal)
@pytest.mark.skipif(
digit_version(TORCH_VERSION) < digit_version('1.10'),
reason='MaxPool2d and MaxPool3d will fail fx for torch<=1.9')
def test_fx_compatibility():
from torch import fx
# ensure the fx trace can pass the network
for Net in (MaxPool2d, MaxPool3d):
net = Net(1)
gm_module = fx.symbolic_trace(net) # noqa: F841
for Net in (Linear, ):
net = Net(1, 1)
gm_module = fx.symbolic_trace(net) # noqa: F841
for Net in (Conv2d, ConvTranspose2d, Conv3d, ConvTranspose3d):
net = Net(1, 1, 1)
gm_module = fx.symbolic_trace(net) # noqa: F841
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment