Unverified commit 655f3c3f authored by BigBigDream, committed by GitHub

fix mmcv_ci test_wrappers.py for parrots (#758)

parent 86e0d62a
@@ -7,8 +7,13 @@ import torch.nn as nn
 from mmcv.cnn.bricks import (Conv2d, Conv3d, ConvTranspose2d, ConvTranspose3d,
                              Linear, MaxPool2d, MaxPool3d)
 
+if torch.__version__ != 'parrots':
+    torch_version = '1.1'
+else:
+    torch_version = 'parrots'
+
 
-@patch('torch.__version__', '1.1')
+@patch('torch.__version__', torch_version)
 @pytest.mark.parametrize(
     'in_w,in_h,in_channel,out_channel,kernel_size,stride,padding,dilation',
     [(10, 10, 1, 1, 3, 1, 0, 1), (20, 20, 3, 3, 5, 2, 1, 2)])
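The hunk above computes torch_version once at import time and reuses it in every @patch decorator, so the patched version string matches whichever runtime (stock PyTorch or parrots) is executing the tests. A minimal, self-contained sketch of that pattern, outside mmcv, with a hypothetical test name:

# Sketch only, not part of the diff: how a module-level torch_version
# feeds into unittest.mock.patch used as a decorator.
from unittest.mock import patch

import torch

if torch.__version__ != 'parrots':
    torch_version = '1.1'
else:
    torch_version = 'parrots'


@patch('torch.__version__', torch_version)
def test_version_is_patched():  # hypothetical test name
    # Inside the test body torch.__version__ reports the patched value;
    # mock.patch restores the real value when the function returns.
    assert torch.__version__ == torch_version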
@@ -65,7 +70,7 @@ def test_conv2d(in_w, in_h, in_channel, out_channel, kernel_size, stride,
         wrapper(x_empty)
 
 
-@patch('torch.__version__', '1.1')
+@patch('torch.__version__', torch_version)
 @pytest.mark.parametrize(
     'in_w,in_h,in_t,in_channel,out_channel,kernel_size,stride,padding,dilation',  # noqa: E501
     [(10, 10, 10, 1, 1, 3, 1, 0, 1), (20, 20, 20, 3, 3, 5, 2, 1, 2)])
@@ -123,7 +128,7 @@ def test_conv3d(in_w, in_h, in_t, in_channel, out_channel, kernel_size, stride,
         wrapper(x_empty)
 
 
-@patch('torch.__version__', '1.1')
+@patch('torch.__version__', torch_version)
 @pytest.mark.parametrize(
     'in_w,in_h,in_channel,out_channel,kernel_size,stride,padding,dilation',
     [(10, 10, 1, 1, 3, 1, 0, 1), (20, 20, 3, 3, 5, 2, 1, 2)])
@@ -133,6 +138,8 @@ def test_conv_transposed_2d(in_w, in_h, in_channel, out_channel, kernel_size,
     x_empty = torch.randn(0, in_channel, in_h, in_w, requires_grad=True)
     # out padding must be smaller than either stride or dilation
     op = min(stride, dilation) - 1
+    if torch.__version__ == 'parrots':
+        op = 0
     torch.manual_seed(0)
     wrapper = ConvTranspose2d(
         in_channel,
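Under parrots this hunk forces output_padding to 0, while stock PyTorch keeps the largest legal value: PyTorch requires output padding to be smaller than either stride or dilation. A hedged sketch of what that bound means for the output size, using plain torch.nn and the second parametrize case (the test name and shapes here are illustrative, not mmcv code):

# Sketch only: why `op = min(stride, dilation) - 1` is the largest legal
# output_padding, and how it enters the transposed-conv output size.
import torch
import torch.nn as nn

stride, dilation, kernel_size, padding = 2, 2, 5, 1
op = min(stride, dilation) - 1  # 1 here; any larger value raises an error

conv = nn.ConvTranspose2d(
    3, 3, kernel_size, stride=stride, padding=padding,
    dilation=dilation, output_padding=op)
x = torch.randn(1, 3, 20, 20)
out_size = conv(x).shape[-1]

# ConvTranspose2d output size:
# (in - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + op + 1
expected = (20 - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + op + 1
assert out_size == expected  # 38 - 2 + 8 + 1 + 1 = 46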
@@ -180,7 +187,7 @@ def test_conv_transposed_2d(in_w, in_h, in_channel, out_channel, kernel_size,
         wrapper(x_empty)
 
 
-@patch('torch.__version__', '1.1')
+@patch('torch.__version__', torch_version)
 @pytest.mark.parametrize(
     'in_w,in_h,in_t,in_channel,out_channel,kernel_size,stride,padding,dilation',  # noqa: E501
     [(10, 10, 10, 1, 1, 3, 1, 0, 1), (20, 20, 20, 3, 3, 5, 2, 1, 2)])
@@ -237,7 +244,7 @@ def test_conv_transposed_3d(in_w, in_h, in_t, in_channel, out_channel,
         wrapper(x_empty)
 
 
-@patch('torch.__version__', '1.1')
+@patch('torch.__version__', torch_version)
 @pytest.mark.parametrize(
     'in_w,in_h,in_channel,out_channel,kernel_size,stride,padding,dilation',
     [(10, 10, 1, 1, 3, 1, 0, 1), (20, 20, 3, 3, 5, 2, 1, 2)])
@@ -261,22 +268,28 @@ def test_max_pool_2d(in_w, in_h, in_channel, out_channel, kernel_size, stride,
     assert torch.equal(wrapper(x_normal), ref_out)
 
 
-@patch('torch.__version__', '1.1')
+@patch('torch.__version__', torch_version)
 @pytest.mark.parametrize(
     'in_w,in_h,in_t,in_channel,out_channel,kernel_size,stride,padding,dilation',  # noqa: E501
     [(10, 10, 10, 1, 1, 3, 1, 0, 1), (20, 20, 20, 3, 3, 5, 2, 1, 2)])
+@pytest.mark.skipif(
+    torch.__version__ == 'parrots' and not torch.cuda.is_available(),
+    reason='parrots requires CUDA support')
 def test_max_pool_3d(in_w, in_h, in_t, in_channel, out_channel, kernel_size,
                      stride, padding, dilation):
     # wrapper op with 0-dim input
     x_empty = torch.randn(0, in_channel, in_t, in_h, in_w, requires_grad=True)
     wrapper = MaxPool3d(
         kernel_size, stride=stride, padding=padding, dilation=dilation)
+    if torch.__version__ == 'parrots':
+        x_empty = x_empty.cuda()
     wrapper_out = wrapper(x_empty)
 
     # torch op with 3-dim input as shape reference
     x_normal = torch.randn(3, in_channel, in_t, in_h, in_w)
     ref = nn.MaxPool3d(
         kernel_size, stride=stride, padding=padding, dilation=dilation)
+    if torch.__version__ == 'parrots':
+        x_normal = x_normal.cuda()
     ref_out = ref(x_normal)
 
     assert wrapper_out.shape[0] == 0
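The new skipif marker drops test_max_pool_3d at collection time when running under parrots without a visible GPU, and the .cuda() calls move both inputs to the GPU when parrots is in use, since the parrots build apparently runs these ops on GPU only. A small stand-alone sketch of the same skip-or-move pattern, using plain torch.nn and a hypothetical test name:

# Sketch only: skip under parrots-without-CUDA, otherwise move inputs to GPU
# when the runtime is parrots, then check the pooled output shape.
import pytest
import torch
import torch.nn as nn


@pytest.mark.skipif(
    torch.__version__ == 'parrots' and not torch.cuda.is_available(),
    reason='parrots requires CUDA support')
def test_max_pool_3d_shape():  # hypothetical test name
    x = torch.randn(3, 2, 10, 10, 10)
    pool = nn.MaxPool3d(3, stride=1, padding=0, dilation=1)
    if torch.__version__ == 'parrots':
        # assumption: the parrots pooling op expects GPU tensors
        x = x.cuda()
    out = pool(x)
    # kernel 3, stride 1, no padding: each spatial dim shrinks by 2
    assert out.shape == (3, 2, 8, 8, 8)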
@@ -285,7 +298,7 @@ def test_max_pool_3d(in_w, in_h, in_t, in_channel, out_channel, kernel_size,
     assert torch.equal(wrapper(x_normal), ref_out)
 
 
-@patch('torch.__version__', '1.1')
+@patch('torch.__version__', torch_version)
 @pytest.mark.parametrize('in_w,in_h,in_feature,out_feature', [(10, 10, 1, 1),
                                                               (20, 20, 3, 3)])
 def test_linear(in_w, in_h, in_feature, out_feature):
 ...