# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
from torch.autograd import gradcheck

from mmcv.ops import DynamicScatter
7
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
8

# Skip every test in this module when torch reports its version as
# 'parrots'; per the skip message, the op under test is not supported there.
if torch.__version__ == 'parrots':
    pytest.skip('not supported in parrots now', allow_module_level=True)

12

13
14
15
16
17
18
19
20
21
22
23
def _brute_force_scatter(feats, coors):
    """Build reference per-voxel mean/max features with plain tensor ops.

    Args:
        feats (torch.Tensor): Per-point features, one row per point.
        coors (torch.Tensor): Per-point integer voxel coordinates, one row
            per point (same number of rows as ``feats``).

    Returns:
        tuple[torch.Tensor, torch.Tensor, torch.Tensor]: The unique voxel
        coordinates (sorted, rows with any negative component removed),
        followed by the stacked per-voxel mean and per-voxel max features
        in the same row order.
    """
    voxel_coors = coors.unique(dim=0, sorted=True)
    # Rows with a negative component are the "out of bound" points of the
    # test cases below; they must not produce a voxel.
    voxel_coors = voxel_coors[voxel_coors.min(dim=-1).values >= 0]
    voxel_means = []
    voxel_maxes = []
    for voxel_coor in voxel_coors:
        in_voxel = feats[(coors == voxel_coor).all(dim=-1)]
        voxel_means.append(in_voxel.mean(dim=0))
        voxel_maxes.append(in_voxel.max(dim=0).values)
    return voxel_coors, torch.stack(voxel_means), torch.stack(voxel_maxes)


def _sort_by_coors(voxel_feats, voxel_coors, coor_bound=20):
    """Reorder voxels by their coordinates so outputs can be compared.

    Args:
        voxel_feats (torch.Tensor): Per-voxel features.
        voxel_coors (torch.Tensor): Per-voxel 3-component coordinates.
        coor_bound (int): Exclusive upper bound of every coordinate
            component; used as the base of the flattened sort key.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: ``voxel_feats`` and
        ``voxel_coors`` reordered by ascending flattened key.
    """
    # Flatten (c0, c1, c2) to one integer in base ``coor_bound``; with the
    # default bound this is c0 * 400 + c1 * 20 + c2.
    key = ((voxel_coors[:, 0] * coor_bound + voxel_coors[:, 1]) * coor_bound
           + voxel_coors[:, 2])
    order = key.argsort()
    return voxel_feats[order], voxel_coors[order]


def _assert_matches_reference(scatter, feats, coors, ref_coors, ref_feats):
    """Run ``scatter`` and assert its sorted output equals the reference."""
    out_feats, out_coors = _sort_by_coors(*scatter(feats, coors))
    assert (out_coors == ref_coors).all()
    assert torch.allclose(out_feats, ref_feats, atol=1e-2, rtol=1e-5)


@pytest.mark.parametrize('device', [
    pytest.param(
        'cuda',
        marks=pytest.mark.skipif(
            not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
    pytest.param(
        'mlu',
        marks=pytest.mark.skipif(
            not IS_MLU_AVAILABLE, reason='requires MLU support'))
])
def test_dynamic_scatter(device):
    """Check DynamicScatter (mean and max reduce) on the given device.

    Covers empty input, input whose points are all out of bound, random
    input with and without out-of-bound points (compared against a brute
    force reference), and numerical gradient checks.

    Args:
        device (str): Device the tensors are created on, 'cuda' or 'mlu'.
    """
    voxel_size = [0.32, 0.32, 6]
    point_cloud_range = [-74.88, -74.88, -2, 74.88, 74.88, 4]
    # The third argument switches between mean (True) and max (False) reduce.
    dsmean = DynamicScatter(voxel_size, point_cloud_range, True)
    dsmax = DynamicScatter(voxel_size, point_cloud_range, False)

    # test empty input
    empty_feats = torch.empty(size=(0, 3), dtype=torch.float32, device=device)
    empty_coors = torch.empty(size=(0, 3), dtype=torch.int32, device=device)

    empty_feats.requires_grad_()
    empty_feats_out_mean, empty_coors_out_mean = dsmean(
        empty_feats, empty_coors)
    empty_feats_out_mean.sum().backward()
    empty_feats_out_max, empty_coors_out_max = dsmax(empty_feats, empty_coors)
    empty_feats_out_max.sum().backward()

    assert empty_feats_out_mean.shape == empty_feats.shape
    assert empty_feats_out_max.shape == empty_feats.shape
    assert empty_coors_out_mean.shape == empty_coors.shape
    assert empty_coors_out_max.shape == empty_coors.shape

    # test empty reduced output: every coordinate is -1, so no input point
    # may receive a gradient.
    empty_o_feats = torch.rand(
        size=(200000, 3), dtype=torch.float32, device=device) * 100 - 50
    empty_o_coors = torch.randint(
        low=-1, high=0, size=(200000, 3), dtype=torch.int32, device=device)

    empty_o_feats.requires_grad_()
    for scatter in (dsmean, dsmax):
        empty_o_feats_out, _ = scatter(empty_o_feats, empty_o_coors)
        empty_o_feats_out.sum().backward()
        # .grad accumulates across both backward passes and must stay zero.
        assert (empty_o_feats.grad == 0).all()

    # test non-empty input, first with some points out of bound (low=-1),
    # then without any point out of bound (low=0)
    for low in (-1, 0):
        feats = torch.rand(
            size=(200000, 3), dtype=torch.float32, device=device) * 100 - 50
        coors = torch.randint(
            low=low,
            high=20,
            size=(200000, 3),
            dtype=torch.int32,
            device=device)

        ref_voxel_coors, ref_voxel_feats_mean, ref_voxel_feats_max = \
            _brute_force_scatter(feats, coors)
        _assert_matches_reference(dsmean, feats, coors, ref_voxel_coors,
                                  ref_voxel_feats_mean)
        _assert_matches_reference(dsmax, feats, coors, ref_voxel_coors,
                                  ref_voxel_feats_max)

    # test grad #
    feats = torch.rand(
        size=(100, 4), dtype=torch.float32, device=device) * 100 - 50
    coors = torch.randint(
        low=-1, high=3, size=(100, 3), dtype=torch.int32, device=device)
    feats.requires_grad_()
    # TODO(Cambricon): mlu only support max reduce in current version.
    # Gate on the device under test rather than on MLU availability, so the
    # CUDA case keeps its mean-reduce check on hosts that also have an MLU.
    if device != 'mlu':
        gradcheck(dsmean, (feats, coors), eps=1e-2, atol=1e-2, rtol=1e-5)
    gradcheck(dsmax, (feats, coors), eps=1e-2, atol=1e-2, rtol=1e-5)