Commit 144e7567 authored by dreamerlin's avatar dreamerlin
Browse files

use pytest.mark.parametrize

parent 86d9f468
from collections import OrderedDict
from itertools import product
from unittest.mock import patch from unittest.mock import patch
import pytest
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -10,239 +9,248 @@ from mmcv.cnn.bricks import (Conv2d, ConvTranspose2d, ConvTranspose3d, Linear, ...@@ -10,239 +9,248 @@ from mmcv.cnn.bricks import (Conv2d, ConvTranspose2d, ConvTranspose3d, Linear,
@patch('torch.__version__', '1.1') @patch('torch.__version__', '1.1')
def test_conv2d(): @pytest.mark.parametrize(
'in_w,in_h,in_channel,out_channel,kernel_size,stride,padding,dilation',
[(10, 10, 1, 1, 3, 1, 0, 1), (20, 20, 3, 3, 5, 2, 1, 2)])
def test_conv2d(in_w, in_h, in_channel, out_channel, kernel_size, stride,
padding, dilation):
""" """
CommandLine: CommandLine:
xdoctest -m tests/test_wrappers.py test_conv2d xdoctest -m tests/test_wrappers.py test_conv2d
""" """
test_cases = OrderedDict([('in_w', [10, 20]), ('in_h', [10, 20]),
('in_channel', [1, 3]), ('out_channel', [1, 3]),
('kernel_size', [3, 5]), ('stride', [1, 2]),
('padding', [0, 1]), ('dilation', [1, 2])])
# train mode # train mode
for in_h, in_w, in_cha, out_cha, k, s, p, d in product( # wrapper op with 0-dim input
*list(test_cases.values())): x_empty = torch.randn(0, in_channel, in_h, in_w)
# wrapper op with 0-dim input torch.manual_seed(0)
x_empty = torch.randn(0, in_cha, in_h, in_w) wrapper = Conv2d(
torch.manual_seed(0) in_channel,
wrapper = Conv2d(in_cha, out_cha, k, stride=s, padding=p, dilation=d) out_channel,
wrapper_out = wrapper(x_empty) kernel_size,
stride=stride,
# torch op with 3-dim input as shape reference padding=padding,
x_normal = torch.randn(3, in_cha, in_h, in_w).requires_grad_(True) dilation=dilation)
torch.manual_seed(0) wrapper_out = wrapper(x_empty)
ref = nn.Conv2d(in_cha, out_cha, k, stride=s, padding=p, dilation=d)
ref_out = ref(x_normal) # torch op with 3-dim input as shape reference
x_normal = torch.randn(3, in_channel, in_h, in_w).requires_grad_(True)
assert wrapper_out.shape[0] == 0 torch.manual_seed(0)
assert wrapper_out.shape[1:] == ref_out.shape[1:] ref = nn.Conv2d(
in_channel,
wrapper_out.sum().backward() out_channel,
assert wrapper.weight.grad is not None kernel_size,
assert wrapper.weight.grad.shape == wrapper.weight.shape stride=stride,
padding=padding,
assert torch.equal(wrapper(x_normal), ref_out) dilation=dilation)
ref_out = ref(x_normal)
assert wrapper_out.shape[0] == 0
assert wrapper_out.shape[1:] == ref_out.shape[1:]
wrapper_out.sum().backward()
assert wrapper.weight.grad is not None
assert wrapper.weight.grad.shape == wrapper.weight.shape
assert torch.equal(wrapper(x_normal), ref_out)
# eval mode # eval mode
x_empty = torch.randn(0, in_cha, in_h, in_w) x_empty = torch.randn(0, in_channel, in_h, in_w)
wrapper = Conv2d(in_cha, out_cha, k, stride=s, padding=p, dilation=d) wrapper = Conv2d(
in_channel,
out_channel,
kernel_size,
stride=stride,
padding=padding,
dilation=dilation)
wrapper.eval() wrapper.eval()
wrapper(x_empty) wrapper(x_empty)
@patch('torch.__version__', '1.1') @patch('torch.__version__', '1.1')
def test_conv_transposed_2d(): @pytest.mark.parametrize(
test_cases = OrderedDict([('in_w', [10, 20]), ('in_h', [10, 20]), 'in_w,in_h,in_channel,out_channel,kernel_size,stride,padding,dilation',
('in_channel', [1, 3]), ('out_channel', [1, 3]), [(10, 10, 1, 1, 3, 1, 0, 1), (20, 20, 3, 3, 5, 2, 1, 2)])
('kernel_size', [3, 5]), ('stride', [1, 2]), def test_conv_transposed_2d(in_w, in_h, in_channel, out_channel, kernel_size,
('padding', [0, 1]), ('dilation', [1, 2])]) stride, padding, dilation):
# wrapper op with 0-dim input
for in_h, in_w, in_cha, out_cha, k, s, p, d in product( x_empty = torch.randn(0, in_channel, in_h, in_w, requires_grad=True)
*list(test_cases.values())): # out padding must be smaller than either stride or dilation
# wrapper op with 0-dim input op = min(stride, dilation) - 1
x_empty = torch.randn(0, in_cha, in_h, in_w, requires_grad=True) torch.manual_seed(0)
# out padding must be smaller than either stride or dilation wrapper = ConvTranspose2d(
op = min(s, d) - 1 in_channel,
torch.manual_seed(0) out_channel,
wrapper = ConvTranspose2d( kernel_size,
in_cha, stride=stride,
out_cha, padding=padding,
k, dilation=dilation,
stride=s, output_padding=op)
padding=p, wrapper_out = wrapper(x_empty)
dilation=d,
output_padding=op) # torch op with 3-dim input as shape reference
wrapper_out = wrapper(x_empty) x_normal = torch.randn(3, in_channel, in_h, in_w)
torch.manual_seed(0)
# torch op with 3-dim input as shape reference ref = nn.ConvTranspose2d(
x_normal = torch.randn(3, in_cha, in_h, in_w) in_channel,
torch.manual_seed(0) out_channel,
ref = nn.ConvTranspose2d( kernel_size,
in_cha, stride=stride,
out_cha, padding=padding,
k, dilation=dilation,
stride=s, output_padding=op)
padding=p, ref_out = ref(x_normal)
dilation=d,
output_padding=op) assert wrapper_out.shape[0] == 0
ref_out = ref(x_normal) assert wrapper_out.shape[1:] == ref_out.shape[1:]
assert wrapper_out.shape[0] == 0 wrapper_out.sum().backward()
assert wrapper_out.shape[1:] == ref_out.shape[1:] assert wrapper.weight.grad is not None
assert wrapper.weight.grad.shape == wrapper.weight.shape
wrapper_out.sum().backward()
assert wrapper.weight.grad is not None assert torch.equal(wrapper(x_normal), ref_out)
assert wrapper.weight.grad.shape == wrapper.weight.shape
assert torch.equal(wrapper(x_normal), ref_out)
# eval mode # eval mode
x_empty = torch.randn(0, in_cha, in_h, in_w) x_empty = torch.randn(0, in_channel, in_h, in_w)
wrapper = ConvTranspose2d( wrapper = ConvTranspose2d(
in_cha, out_cha, k, stride=s, padding=p, dilation=d, output_padding=op) in_channel,
out_channel,
kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
output_padding=op)
wrapper.eval() wrapper.eval()
wrapper(x_empty) wrapper(x_empty)
@patch('torch.__version__', '1.1') @patch('torch.__version__', '1.1')
def test_conv_transposed_3d(): @pytest.mark.parametrize(
test_cases = OrderedDict([('in_w', [10, 20]), ('in_h', [10, 20]), 'in_w,in_h,in_t,in_channel,out_channel,kernel_size,stride,padding,dilation', # noqa: E501
('in_t', [10, 20]), ('in_channel', [1, 3]), [(10, 10, 10, 1, 1, 3, 1, 0, 1), (20, 20, 20, 3, 3, 5, 2, 1, 2)])
('out_channel', [1, 3]), ('kernel_size', [3, 5]), def test_conv_transposed_3d(in_w, in_h, in_t, in_channel, out_channel,
('stride', [1, 2]), ('padding', [0, 1]), kernel_size, stride, padding, dilation):
('dilation', [1, 2])]) # wrapper op with 0-dim input
x_empty = torch.randn(0, in_channel, in_t, in_h, in_w, requires_grad=True)
for in_h, in_w, in_t, in_cha, out_cha, k, s, p, d in product( # out padding must be smaller than either stride or dilation
*list(test_cases.values())): op = min(stride, dilation) - 1
# wrapper op with 0-dim input torch.manual_seed(0)
x_empty = torch.randn(0, in_cha, in_t, in_h, in_w, requires_grad=True) wrapper = ConvTranspose3d(
# out padding must be smaller than either stride or dilation in_channel,
op = min(s, d) - 1 out_channel,
torch.manual_seed(0) kernel_size,
wrapper = ConvTranspose3d( stride=stride,
in_cha, padding=padding,
out_cha, dilation=dilation,
k, output_padding=op)
stride=s, wrapper_out = wrapper(x_empty)
padding=p,
dilation=d, # torch op with 3-dim input as shape reference
output_padding=op) x_normal = torch.randn(3, in_channel, in_t, in_h, in_w)
wrapper_out = wrapper(x_empty) torch.manual_seed(0)
ref = nn.ConvTranspose3d(
# torch op with 3-dim input as shape reference in_channel,
x_normal = torch.randn(3, in_cha, in_t, in_h, in_w) out_channel,
torch.manual_seed(0) kernel_size,
ref = nn.ConvTranspose3d( stride=stride,
in_cha, padding=padding,
out_cha, dilation=dilation,
k, output_padding=op)
stride=s, ref_out = ref(x_normal)
padding=p,
dilation=d, assert wrapper_out.shape[0] == 0
output_padding=op) assert wrapper_out.shape[1:] == ref_out.shape[1:]
ref_out = ref(x_normal)
wrapper_out.sum().backward()
assert wrapper_out.shape[0] == 0 assert wrapper.weight.grad is not None
assert wrapper_out.shape[1:] == ref_out.shape[1:] assert wrapper.weight.grad.shape == wrapper.weight.shape
wrapper_out.sum().backward() assert torch.equal(wrapper(x_normal), ref_out)
assert wrapper.weight.grad is not None
assert wrapper.weight.grad.shape == wrapper.weight.shape
assert torch.equal(wrapper(x_normal), ref_out)
# eval mode # eval mode
x_empty = torch.randn(0, in_cha, in_t, in_h, in_w) x_empty = torch.randn(0, in_channel, in_t, in_h, in_w)
wrapper = ConvTranspose3d( wrapper = ConvTranspose3d(
in_cha, out_cha, k, stride=s, padding=p, dilation=d, output_padding=op) in_channel,
out_channel,
kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
output_padding=op)
wrapper.eval() wrapper.eval()
wrapper(x_empty) wrapper(x_empty)
@patch('torch.__version__', '1.1') @patch('torch.__version__', '1.1')
def test_max_pool_2d(): @pytest.mark.parametrize(
test_cases = OrderedDict([('in_w', [10, 20]), ('in_h', [10, 20]), 'in_w,in_h,in_channel,out_channel,kernel_size,stride,padding,dilation',
('in_channel', [1, 3]), ('out_channel', [1, 3]), [(10, 10, 1, 1, 3, 1, 0, 1), (20, 20, 3, 3, 5, 2, 1, 2)])
('kernel_size', [3, 5]), ('stride', [1, 2]), def test_max_pool_2d(in_w, in_h, in_channel, out_channel, kernel_size, stride,
('padding', [0, 1]), ('dilation', [1, 2])]) padding, dilation):
# wrapper op with 0-dim input
x_empty = torch.randn(0, in_channel, in_h, in_w, requires_grad=True)
wrapper = MaxPool2d(
kernel_size, stride=stride, padding=padding, dilation=dilation)
wrapper_out = wrapper(x_empty)
for in_h, in_w, in_cha, out_cha, k, s, p, d in product( # torch op with 3-dim input as shape reference
*list(test_cases.values())): x_normal = torch.randn(3, in_channel, in_h, in_w)
# wrapper op with 0-dim input ref = nn.MaxPool2d(
x_empty = torch.randn(0, in_cha, in_h, in_w, requires_grad=True) kernel_size, stride=stride, padding=padding, dilation=dilation)
wrapper = MaxPool2d(k, stride=s, padding=p, dilation=d) ref_out = ref(x_normal)
wrapper_out = wrapper(x_empty)
# torch op with 3-dim input as shape reference assert wrapper_out.shape[0] == 0
x_normal = torch.randn(3, in_cha, in_h, in_w) assert wrapper_out.shape[1:] == ref_out.shape[1:]
ref = nn.MaxPool2d(k, stride=s, padding=p, dilation=d)
ref_out = ref(x_normal)
assert wrapper_out.shape[0] == 0 assert torch.equal(wrapper(x_normal), ref_out)
assert wrapper_out.shape[1:] == ref_out.shape[1:]
assert torch.equal(wrapper(x_normal), ref_out)
@patch('torch.__version__', '1.1') @patch('torch.__version__', '1.1')
def test_max_pool_3d(): @pytest.mark.parametrize(
test_cases = OrderedDict([('in_w', [10, 20]), ('in_h', [10, 20]), 'in_w,in_h,in_t,in_channel,out_channel,kernel_size,stride,padding,dilation', # noqa: E501
('in_t', [10, 20]), ('in_channel', [1, 3]), [(10, 10, 10, 1, 1, 3, 1, 0, 1), (20, 20, 20, 3, 3, 5, 2, 1, 2)])
('out_channel', [1, 3]), ('kernel_size', [3, 5]), def test_max_pool_3d(in_w, in_h, in_t, in_channel, out_channel, kernel_size,
('stride', [1, 2]), ('padding', [0, 1]), stride, padding, dilation):
('dilation', [1, 2])]) # wrapper op with 0-dim input
x_empty = torch.randn(0, in_channel, in_t, in_h, in_w, requires_grad=True)
for in_h, in_w, in_t, in_cha, out_cha, k, s, p, d in product( wrapper = MaxPool3d(
*list(test_cases.values())): kernel_size, stride=stride, padding=padding, dilation=dilation)
# wrapper op with 0-dim input wrapper_out = wrapper(x_empty)
x_empty = torch.randn(0, in_cha, in_t, in_h, in_w, requires_grad=True)
wrapper = MaxPool3d(k, stride=s, padding=p, dilation=d)
wrapper_out = wrapper(x_empty)
# torch op with 3-dim input as shape reference # torch op with 3-dim input as shape reference
x_normal = torch.randn(3, in_cha, in_t, in_h, in_w) x_normal = torch.randn(3, in_channel, in_t, in_h, in_w)
ref = nn.MaxPool3d(k, stride=s, padding=p, dilation=d) ref = nn.MaxPool3d(
ref_out = ref(x_normal) kernel_size, stride=stride, padding=padding, dilation=dilation)
ref_out = ref(x_normal)
assert wrapper_out.shape[0] == 0 assert wrapper_out.shape[0] == 0
assert wrapper_out.shape[1:] == ref_out.shape[1:] assert wrapper_out.shape[1:] == ref_out.shape[1:]
assert torch.equal(wrapper(x_normal), ref_out) assert torch.equal(wrapper(x_normal), ref_out)
@patch('torch.__version__', '1.1') @patch('torch.__version__', '1.1')
def test_linear(): @pytest.mark.parametrize('in_w,in_h,in_feature,out_feature', [(10, 10, 1, 1),
test_cases = OrderedDict([ (20, 20, 3, 3)])
('in_w', [10, 20]), def test_linear(in_w, in_h, in_feature, out_feature):
('in_h', [10, 20]), # wrapper op with 0-dim input
('in_feature', [1, 3]), x_empty = torch.randn(0, in_feature, requires_grad=True)
('out_feature', [1, 3]), torch.manual_seed(0)
]) wrapper = Linear(in_feature, out_feature)
wrapper_out = wrapper(x_empty)
for in_h, in_w, in_feature, out_feature in product(
*list(test_cases.values())): # torch op with 3-dim input as shape reference
# wrapper op with 0-dim input x_normal = torch.randn(3, in_feature)
x_empty = torch.randn(0, in_feature, requires_grad=True) torch.manual_seed(0)
torch.manual_seed(0) ref = nn.Linear(in_feature, out_feature)
wrapper = Linear(in_feature, out_feature) ref_out = ref(x_normal)
wrapper_out = wrapper(x_empty)
assert wrapper_out.shape[0] == 0
# torch op with 3-dim input as shape reference assert wrapper_out.shape[1:] == ref_out.shape[1:]
x_normal = torch.randn(3, in_feature)
torch.manual_seed(0) wrapper_out.sum().backward()
ref = nn.Linear(in_feature, out_feature) assert wrapper.weight.grad is not None
ref_out = ref(x_normal) assert wrapper.weight.grad.shape == wrapper.weight.shape
assert wrapper_out.shape[0] == 0 assert torch.equal(wrapper(x_normal), ref_out)
assert wrapper_out.shape[1:] == ref_out.shape[1:]
wrapper_out.sum().backward()
assert wrapper.weight.grad is not None
assert wrapper.weight.grad.shape == wrapper.weight.shape
assert torch.equal(wrapper(x_normal), ref_out)
# eval mode # eval mode
x_empty = torch.randn(0, in_feature) x_empty = torch.randn(0, in_feature)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment