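"""Tests for nerfacc scan ops.

Each test compares three implementations of the same op (inclusive/exclusive
sum or prod) and checks that their outputs and gradients agree:
the dense 2D path, the packed_info path, and the *_cub path.
"""
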
import pytest
import torch

device = "cuda:0"


@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device")
def test_inclusive_sum():
    from nerfacc.scan import inclusive_sum
    from nerfacc.scan_cub import inclusive_sum_cub

    torch.manual_seed(42)

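    # Dense path: inclusive sum along the last dim of a (n_rays, n_samples) tensor.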
    data = torch.rand((5, 1000), device=device, requires_grad=True)
    outputs1 = inclusive_sum(data)
    outputs1 = outputs1.flatten()
    outputs1.sum().backward()
    grad1 = data.grad.clone()
    data.grad.zero_()

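    # Packed path: same values flattened, chunks described by packed_info = (start, count) rows.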
    chunk_starts = torch.arange(
        0, data.numel(), data.shape[1], device=device, dtype=torch.long
    )
    chunk_cnts = torch.full(
        (data.shape[0],), data.shape[1], dtype=torch.long, device=device
    )
    packed_info = torch.stack([chunk_starts, chunk_cnts], dim=-1)
    flatten_data = data.flatten()
    outputs2 = inclusive_sum(flatten_data, packed_info=packed_info)
    outputs2.sum().backward()
    grad2 = data.grad.clone()
    data.grad.zero_()

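    # scan_cub path: chunks described by a per-element row index instead of packed_info.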
    indices = torch.arange(data.shape[0], device=device, dtype=torch.long)
    indices = indices.repeat_interleave(data.shape[1])
    indices = indices.flatten()
    outputs3 = inclusive_sum_cub(flatten_data, indices)
    outputs3.sum().backward()
    grad3 = data.grad.clone()
    data.grad.zero_()

    assert torch.allclose(outputs1, outputs2)
    assert torch.allclose(grad1, grad2)

    assert torch.allclose(outputs1, outputs3)
    assert torch.allclose(grad1, grad3)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device")
def test_exclusive_sum():
    from nerfacc.scan import exclusive_sum
    from nerfacc.scan_cub import exclusive_sum_cub

    torch.manual_seed(42)

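    # Dense path: exclusive sum along the last dim of a (n_rays, n_samples) tensor.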
    data = torch.rand((5, 1000), device=device, requires_grad=True)
    outputs1 = exclusive_sum(data)
    outputs1 = outputs1.flatten()
    outputs1.sum().backward()
    grad1 = data.grad.clone()
    data.grad.zero_()

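    # Packed path: same values flattened, chunks described by packed_info = (start, count) rows.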
    chunk_starts = torch.arange(
        0, data.numel(), data.shape[1], device=device, dtype=torch.long
    )
    chunk_cnts = torch.full(
        (data.shape[0],), data.shape[1], dtype=torch.long, device=device
    )
    packed_info = torch.stack([chunk_starts, chunk_cnts], dim=-1)
    flatten_data = data.flatten()
    outputs2 = exclusive_sum(flatten_data, packed_info=packed_info)
    outputs2.sum().backward()
    grad2 = data.grad.clone()
    data.grad.zero_()

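    # scan_cub path: chunks described by a per-element row index instead of packed_info.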
    indices = torch.arange(data.shape[0], device=device, dtype=torch.long)
    indices = indices.repeat_interleave(data.shape[1])
    indices = indices.flatten()
    outputs3 = exclusive_sum_cub(flatten_data, indices)
    outputs3.sum().backward()
    grad3 = data.grad.clone()
    data.grad.zero_()

    # TODO: check exclusive sum. numeric error?
    # print((outputs1 - outputs2).abs().max())  # 0.0002
    assert torch.allclose(outputs1, outputs2, atol=3e-4)
    assert torch.allclose(grad1, grad2)

    assert torch.allclose(outputs1, outputs3, atol=3e-4)
    assert torch.allclose(grad1, grad3)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device")
def test_inclusive_prod():
    from nerfacc.scan import inclusive_prod
    from nerfacc.scan_cub import inclusive_prod_cub

    torch.manual_seed(42)

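    # Dense path: inclusive prod along the last dim of a (n_rays, n_samples) tensor.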
    data = torch.rand((5, 1000), device=device, requires_grad=True)
    outputs1 = inclusive_prod(data)
    outputs1 = outputs1.flatten()
    outputs1.sum().backward()
    grad1 = data.grad.clone()
    data.grad.zero_()

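    # Packed path: same values flattened, chunks described by packed_info = (start, count) rows.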
    chunk_starts = torch.arange(
        0, data.numel(), data.shape[1], device=device, dtype=torch.long
    )
    chunk_cnts = torch.full(
        (data.shape[0],), data.shape[1], dtype=torch.long, device=device
    )
    packed_info = torch.stack([chunk_starts, chunk_cnts], dim=-1)
    flatten_data = data.flatten()
    outputs2 = inclusive_prod(flatten_data, packed_info=packed_info)
    outputs2.sum().backward()
    grad2 = data.grad.clone()
    data.grad.zero_()

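    # scan_cub path: chunks described by a per-element row index instead of packed_info.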
    indices = torch.arange(data.shape[0], device=device, dtype=torch.long)
    indices = indices.repeat_interleave(data.shape[1])
    indices = indices.flatten()
    outputs3 = inclusive_prod_cub(flatten_data, indices)
    outputs3.sum().backward()
    grad3 = data.grad.clone()
    data.grad.zero_()

    assert torch.allclose(outputs1, outputs2)
    assert torch.allclose(grad1, grad2)

    assert torch.allclose(outputs1, outputs3)
    assert torch.allclose(grad1, grad3)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device")
def test_exclusive_prod():
    from nerfacc.scan import exclusive_prod
    from nerfacc.scan_cub import exclusive_prod_cub

    torch.manual_seed(42)

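    # Dense path: exclusive prod along the last dim of a (n_rays, n_samples) tensor.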
    data = torch.rand((5, 1000), device=device, requires_grad=True)
    outputs1 = exclusive_prod(data)
    outputs1 = outputs1.flatten()
    outputs1.sum().backward()
    grad1 = data.grad.clone()
    data.grad.zero_()

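    # Packed path: same values flattened, chunks described by packed_info = (start, count) rows.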
    chunk_starts = torch.arange(
        0, data.numel(), data.shape[1], device=device, dtype=torch.long
    )
    chunk_cnts = torch.full(
        (data.shape[0],), data.shape[1], dtype=torch.long, device=device
    )
    packed_info = torch.stack([chunk_starts, chunk_cnts], dim=-1)
    flatten_data = data.flatten()
    outputs2 = exclusive_prod(flatten_data, packed_info=packed_info)
    outputs2.sum().backward()
    grad2 = data.grad.clone()
    data.grad.zero_()

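    # scan_cub path: chunks described by a per-element row index instead of packed_info.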
    indices = torch.arange(data.shape[0], device=device, dtype=torch.long)
    indices = indices.repeat_interleave(data.shape[1])
    indices = indices.flatten()
    outputs3 = exclusive_prod_cub(flatten_data, indices)
    outputs3.sum().backward()
    grad3 = data.grad.clone()
    data.grad.zero_()

    # TODO: check exclusive prod. numeric error?
    # print((outputs1 - outputs2).abs().max())
    assert torch.allclose(outputs1, outputs2)
    assert torch.allclose(grad1, grad2)

    assert torch.allclose(outputs1, outputs3)
    assert torch.allclose(grad1, grad3)


if __name__ == "__main__":
    test_inclusive_sum()
    test_exclusive_sum()
    test_inclusive_prod()
    test_exclusive_prod()