test_merge_cells.py 2.44 KB
Newer Older
Cao Yuhang's avatar
Cao Yuhang committed
1
2
3
4
5
6
7
8
9
10
11
"""
CommandLine:
    pytest tests/test_merge_cells.py
"""
import torch
import torch.nn.functional as F

from mmcv.ops.merge_cells import (BaseMergeCell, ConcatCell, GlobalPoolingCell,
                                  SumCell)


limm's avatar
limm committed
12
13
14
def test_sum_cell():
    """Check that SumCell produces an output of the requested spatial size.

    ``inputs_x`` (32x32) and ``inputs_y`` (16x16) share batch and channel
    dims, so the output shape should equal whichever input's spatial size is
    passed as ``out_size``. With ``out_size`` omitted, the output is expected
    to match the larger input, ``inputs_x``.
    """
    inputs_x = torch.randn([2, 256, 32, 32])
    inputs_y = torch.randn([2, 256, 16, 16])
    sum_cell = SumCell(256, 256)

    # Explicit out_size: the output follows the requested spatial size.
    for reference in (inputs_x, inputs_y):
        output = sum_cell(inputs_x, inputs_y, out_size=reference.shape[-2:])
        assert output.size() == reference.size()

    # Implicit out_size: the output follows the larger input.
    output = sum_cell(inputs_x, inputs_y)
    assert output.size() == inputs_x.size()
Cao Yuhang's avatar
Cao Yuhang committed
22
23


limm's avatar
limm committed
24
25
26
def test_concat_cell():
    """Check that ConcatCell produces an output of the requested spatial size.

    ``inputs_x`` (32x32) and ``inputs_y`` (16x16) share batch and channel
    dims, so the output shape should equal whichever input's spatial size is
    passed as ``out_size``. With ``out_size`` omitted, the output is expected
    to match the larger input, ``inputs_x``.
    """
    inputs_x = torch.randn([2, 256, 32, 32])
    inputs_y = torch.randn([2, 256, 16, 16])
    concat_cell = ConcatCell(256, 256)

    # Explicit out_size: the output follows the requested spatial size.
    for reference in (inputs_x, inputs_y):
        output = concat_cell(
            inputs_x, inputs_y, out_size=reference.shape[-2:])
        assert output.size() == reference.size()

    # Implicit out_size: the output follows the larger input.
    output = concat_cell(inputs_x, inputs_y)
    assert output.size() == inputs_x.size()
Cao Yuhang's avatar
Cao Yuhang committed
34
35


limm's avatar
limm committed
36
37
38
def test_global_pool_cell():
    """Check GlobalPoolingCell output shape for two constructor configurations.

    Both inputs share one shape, so the output must match it whether the cell
    is built with ``with_out_conv=False`` or as ``GlobalPoolingCell(256, 256)``.
    """
    inputs_x = torch.randn([2, 256, 32, 32])
    inputs_y = torch.randn([2, 256, 32, 32])
    out_size = inputs_x.shape[-2:]

    for gp_cell in (GlobalPoolingCell(with_out_conv=False),
                    GlobalPoolingCell(256, 256)):
        gp_cell_out = gp_cell(inputs_x, inputs_y, out_size=out_size)
        assert gp_cell_out.size() == inputs_x.size()


limm's avatar
limm committed
47
def test_resize_methods():
    """Check BaseMergeCell._resize against the equivalent torch operations.

    Same-size and larger targets are compared with ``F.interpolate`` using
    the cell's ``upsample_mode``. A smaller target is compared with
    ``F.max_pool2d`` whose kernel size and stride equal the integer
    downscale factor.
    """
    inputs_x = torch.randn([2, 256, 128, 128])

    # Identity and upsampling: the result must equal F.interpolate exactly.
    target_resize_sizes = [(128, 128), (256, 256)]
    resize_methods_list = ['nearest', 'bilinear']
    for method in resize_methods_list:
        merge_cell = BaseMergeCell(upsample_mode=method)
        for target_size in target_resize_sizes:
            merge_cell_out = merge_cell._resize(inputs_x, target_size)
            gt_out = F.interpolate(inputs_x, size=target_size, mode=method)
            assert torch.equal(merge_cell_out, gt_out)

    # Downsampling 128 -> 64: expected to equal a max-pool whose kernel and
    # stride are the downscale factor (2 here).
    target_size = (64, 64)
    merge_cell = BaseMergeCell()
    merge_cell_out = merge_cell._resize(inputs_x, target_size)
    kernel_size = inputs_x.shape[-1] // target_size[-1]
    gt_out = F.max_pool2d(
        inputs_x, kernel_size=kernel_size, stride=kernel_size)
    # Check the shape independently of gt_out, which the test derives itself;
    # torch.equal then also rejects shape mismatches, unlike `==` + `.all()`.
    assert tuple(merge_cell_out.shape[-2:]) == target_size
    assert torch.equal(merge_cell_out, gt_out)