Unverified Commit 9036241e authored by youkaichao's avatar youkaichao Committed by GitHub
Browse files

[Enhancement] Change the order of condition to make fx wok (#2883)

parent f64d4858
...@@ -41,7 +41,7 @@ class NewEmptyTensorOp(torch.autograd.Function): ...@@ -41,7 +41,7 @@ class NewEmptyTensorOp(torch.autograd.Function):
class Conv2d(nn.Conv2d): class Conv2d(nn.Conv2d):
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)): if obsolete_torch_version(TORCH_VERSION, (1, 4)) and x.numel() == 0:
out_shape = [x.shape[0], self.out_channels] out_shape = [x.shape[0], self.out_channels]
for i, k, p, s, d in zip(x.shape[-2:], self.kernel_size, for i, k, p, s, d in zip(x.shape[-2:], self.kernel_size,
self.padding, self.stride, self.dilation): self.padding, self.stride, self.dilation):
...@@ -62,7 +62,7 @@ class Conv2d(nn.Conv2d): ...@@ -62,7 +62,7 @@ class Conv2d(nn.Conv2d):
class Conv3d(nn.Conv3d): class Conv3d(nn.Conv3d):
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)): if obsolete_torch_version(TORCH_VERSION, (1, 4)) and x.numel() == 0:
out_shape = [x.shape[0], self.out_channels] out_shape = [x.shape[0], self.out_channels]
for i, k, p, s, d in zip(x.shape[-3:], self.kernel_size, for i, k, p, s, d in zip(x.shape[-3:], self.kernel_size,
self.padding, self.stride, self.dilation): self.padding, self.stride, self.dilation):
...@@ -84,7 +84,7 @@ class Conv3d(nn.Conv3d): ...@@ -84,7 +84,7 @@ class Conv3d(nn.Conv3d):
class ConvTranspose2d(nn.ConvTranspose2d): class ConvTranspose2d(nn.ConvTranspose2d):
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)): if obsolete_torch_version(TORCH_VERSION, (1, 4)) and x.numel() == 0:
out_shape = [x.shape[0], self.out_channels] out_shape = [x.shape[0], self.out_channels]
for i, k, p, s, d, op in zip(x.shape[-2:], self.kernel_size, for i, k, p, s, d, op in zip(x.shape[-2:], self.kernel_size,
self.padding, self.stride, self.padding, self.stride,
...@@ -106,7 +106,7 @@ class ConvTranspose2d(nn.ConvTranspose2d): ...@@ -106,7 +106,7 @@ class ConvTranspose2d(nn.ConvTranspose2d):
class ConvTranspose3d(nn.ConvTranspose3d): class ConvTranspose3d(nn.ConvTranspose3d):
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)): if obsolete_torch_version(TORCH_VERSION, (1, 4)) and x.numel() == 0:
out_shape = [x.shape[0], self.out_channels] out_shape = [x.shape[0], self.out_channels]
for i, k, p, s, d, op in zip(x.shape[-3:], self.kernel_size, for i, k, p, s, d, op in zip(x.shape[-3:], self.kernel_size,
self.padding, self.stride, self.padding, self.stride,
...@@ -127,7 +127,7 @@ class MaxPool2d(nn.MaxPool2d): ...@@ -127,7 +127,7 @@ class MaxPool2d(nn.MaxPool2d):
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
# PyTorch 1.9 does not support empty tensor inference yet # PyTorch 1.9 does not support empty tensor inference yet
if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 9)): if obsolete_torch_version(TORCH_VERSION, (1, 9)) and x.numel() == 0:
out_shape = list(x.shape[:2]) out_shape = list(x.shape[:2])
for i, k, p, s, d in zip(x.shape[-2:], _pair(self.kernel_size), for i, k, p, s, d in zip(x.shape[-2:], _pair(self.kernel_size),
_pair(self.padding), _pair(self.stride), _pair(self.padding), _pair(self.stride),
...@@ -145,7 +145,7 @@ class MaxPool3d(nn.MaxPool3d): ...@@ -145,7 +145,7 @@ class MaxPool3d(nn.MaxPool3d):
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
# PyTorch 1.9 does not support empty tensor inference yet # PyTorch 1.9 does not support empty tensor inference yet
if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 9)): if obsolete_torch_version(TORCH_VERSION, (1, 9)) and x.numel() == 0:
out_shape = list(x.shape[:2]) out_shape = list(x.shape[:2])
for i, k, p, s, d in zip(x.shape[-3:], _triple(self.kernel_size), for i, k, p, s, d in zip(x.shape[-3:], _triple(self.kernel_size),
_triple(self.padding), _triple(self.padding),
...@@ -164,7 +164,7 @@ class Linear(torch.nn.Linear): ...@@ -164,7 +164,7 @@ class Linear(torch.nn.Linear):
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
# empty tensor forward of Linear layer is supported in Pytorch 1.6 # empty tensor forward of Linear layer is supported in Pytorch 1.6
if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 5)): if obsolete_torch_version(TORCH_VERSION, (1, 5)) and x.numel() == 0:
out_shape = [x.shape[0], self.out_features] out_shape = [x.shape[0], self.out_features]
empty = NewEmptyTensorOp.apply(x, out_shape) empty = NewEmptyTensorOp.apply(x, out_shape)
if self.training: if self.training:
......
...@@ -4,6 +4,8 @@ from unittest.mock import patch ...@@ -4,6 +4,8 @@ from unittest.mock import patch
import pytest import pytest
import torch import torch
import torch.nn as nn import torch.nn as nn
from mmengine.utils import digit_version
from mmengine.utils.dl_utils import TORCH_VERSION
from mmcv.cnn.bricks import (Conv2d, Conv3d, ConvTranspose2d, ConvTranspose3d, from mmcv.cnn.bricks import (Conv2d, Conv3d, ConvTranspose2d, ConvTranspose3d,
Linear, MaxPool2d, MaxPool3d) Linear, MaxPool2d, MaxPool3d)
...@@ -374,3 +376,21 @@ def test_nn_op_forward_called(): ...@@ -374,3 +376,21 @@ def test_nn_op_forward_called():
wrapper = Linear(3, 3) wrapper = Linear(3, 3)
wrapper(x_normal) wrapper(x_normal)
nn_module_forward.assert_called_with(x_normal) nn_module_forward.assert_called_with(x_normal)
@pytest.mark.skipif(
digit_version(TORCH_VERSION) < digit_version('1.10'),
reason='MaxPool2d and MaxPool3d will fail fx for torch<=1.9')
def test_fx_compatibility():
from torch import fx
# ensure the fx trace can pass the network
for Net in (MaxPool2d, MaxPool3d):
net = Net(1)
gm_module = fx.symbolic_trace(net) # noqa: F841
for Net in (Linear, ):
net = Net(1, 1)
gm_module = fx.symbolic_trace(net) # noqa: F841
for Net in (Conv2d, ConvTranspose2d, Conv3d, ConvTranspose3d):
net = Net(1, 1, 1)
gm_module = fx.symbolic_trace(net) # noqa: F841
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment