test_non_local.py 2.92 KB
Newer Older
Jintao Lin's avatar
Jintao Lin committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import pytest
import torch
import torch.nn as nn

from mmcv.cnn import NonLocal1d, NonLocal2d, NonLocal3d
from mmcv.cnn.bricks.non_local import _NonLocalNd


def _check_shape_preserved(module, imgs):
    """Run ``module`` on ``imgs`` and assert the output keeps the input shape.

    When torch reports its version as ``'parrots'`` and CUDA is available,
    both the module and the input are moved to the GPU before the forward
    pass (NonLocal is only implemented on gpu in parrots).

    Args:
        module (nn.Module): The NonLocal block under test.
        imgs (torch.Tensor): Input tensor fed to ``module``.
    """
    if torch.__version__ == 'parrots' and torch.cuda.is_available():
        module = module.cuda()
        imgs = imgs.cuda()
    out = module(imgs)
    assert out.shape == imgs.shape


def _check_sub_sample(module, pool_type, kernel_size):
    """Assert ``module.g`` and ``module.phi`` end with the expected pooling.

    Args:
        module (nn.Module): A NonLocal block built with ``sub_sample=True``.
        pool_type (type): Expected max-pool layer class.
        kernel_size (int | tuple[int]): Expected ``kernel_size`` of the pool.
    """
    for m in [module.g, module.phi]:
        assert isinstance(m, nn.Sequential) and len(m) == 2
        assert isinstance(m[1], pool_type)
        assert m[1].kernel_size == kernel_size


def test_nonlocal():
    """Test construction and forward shape of the NonLocal 1d/2d/3d blocks."""
    with pytest.raises(ValueError):
        # mode should be in ['embedded_gaussian', 'dot_product']
        _NonLocalNd(3, mode='unsupport_mode')

    # _NonLocalNd
    _NonLocalNd(3, norm_cfg=dict(type='BN'))
    # Not Zero initialization: zeros_init must be False here, otherwise this
    # call does not exercise the non-zero initialization path at all.
    _NonLocalNd(3, norm_cfg=dict(type='BN'), zeros_init=False)

    # NonLocal3d
    imgs = torch.randn(2, 3, 10, 20, 20)
    _check_shape_preserved(NonLocal3d(3), imgs)

    nonlocal_3d = NonLocal3d(3, mode='dot_product')
    assert nonlocal_3d.mode == 'dot_product'
    _check_shape_preserved(nonlocal_3d, imgs)

    nonlocal_3d = NonLocal3d(3, mode='dot_product', sub_sample=True)
    _check_sub_sample(nonlocal_3d, nn.MaxPool3d, (1, 2, 2))
    _check_shape_preserved(nonlocal_3d, imgs)

    # NonLocal2d
    imgs = torch.randn(2, 3, 20, 20)
    _check_shape_preserved(NonLocal2d(3), imgs)

    nonlocal_2d = NonLocal2d(3, mode='dot_product', sub_sample=True)
    _check_sub_sample(nonlocal_2d, nn.MaxPool2d, (2, 2))
    _check_shape_preserved(nonlocal_2d, imgs)

    # NonLocal1d
    imgs = torch.randn(2, 3, 20)
    _check_shape_preserved(NonLocal1d(3), imgs)

    nonlocal_1d = NonLocal1d(3, mode='dot_product', sub_sample=True)
    _check_sub_sample(nonlocal_1d, nn.MaxPool1d, 2)
    _check_shape_preserved(nonlocal_1d, imgs)