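# test_scan.py
#
# Tests for the scan operators in ``nerfacc.scan``. Each operator is run on a
# batched 2D tensor and on the same data flattened (segmented either by
# ``packed_info`` or by ``indices``); outputs and gradients must agree.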
import pytest
import torch

device = "cuda:0"


# NOTE: `is_available` must be *called*; the bare function object is always
# truthy, which would make the skip condition permanently False.
@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device")
def test_inclusive_sum():
    """Check that the three call forms of ``inclusive_sum`` agree.

    The same random ``(5, 1000)`` tensor is scanned:

    1. as a 2D batch,
    2. flattened, with ``packed_info`` rows of ``(start, count)`` per chunk,
    3. flattened, with a per-element ``indices`` tensor holding the row id.

    Outputs, and the gradients of ``outputs.sum()`` w.r.t. the input, are
    compared with ``torch.allclose``.
    """
    from nerfacc.scan import inclusive_sum

    torch.manual_seed(42)

    data = torch.rand((5, 1000), device=device, requires_grad=True)

    # (1) Batched call on the 2D tensor.
    outputs1 = inclusive_sum(data)
    outputs1 = outputs1.flatten()
    outputs1.sum().backward()
    grad1 = data.grad.clone()
    data.grad.zero_()  # reset so the next backward() does not accumulate

    # (2) Flattened call; every row becomes one chunk described by
    # (start offset, element count).
    chunk_starts = torch.arange(
        0, data.numel(), data.shape[1], device=device, dtype=torch.long
    )
    chunk_cnts = torch.full(
        (data.shape[0],), data.shape[1], dtype=torch.long, device=device
    )
    packed_info = torch.stack([chunk_starts, chunk_cnts], dim=-1)
    flatten_data = data.flatten()
    outputs2 = inclusive_sum(flatten_data, packed_info=packed_info)
    outputs2.sum().backward()
    grad2 = data.grad.clone()
    data.grad.zero_()

    # (3) Flattened call; each element is tagged with the id of its row.
    indices = torch.arange(data.shape[0], device=device, dtype=torch.long)
    indices = indices.repeat_interleave(data.shape[1])
    indices = indices.flatten()
    outputs3 = inclusive_sum(flatten_data, indices=indices)
    outputs3.sum().backward()
    grad3 = data.grad.clone()
    data.grad.zero_()

    assert torch.allclose(outputs1, outputs2)
    assert torch.allclose(grad1, grad2)

    assert torch.allclose(outputs1, outputs3)
    assert torch.allclose(grad1, grad3)


# NOTE: `is_available` must be *called*; the bare function object is always
# truthy, which would make the skip condition permanently False.
@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device")
def test_exclusive_sum():
    """Check that the three call forms of ``exclusive_sum`` agree.

    The same random ``(5, 1000)`` tensor is scanned:

    1. as a 2D batch,
    2. flattened, with ``packed_info`` rows of ``(start, count)`` per chunk,
    3. flattened, with a per-element ``indices`` tensor holding the row id.

    Gradients of ``outputs.sum()`` w.r.t. the input are compared with the
    default ``torch.allclose`` tolerance; outputs use ``atol=3e-4``.
    """
    from nerfacc.scan import exclusive_sum

    torch.manual_seed(42)

    data = torch.rand((5, 1000), device=device, requires_grad=True)

    # (1) Batched call on the 2D tensor.
    outputs1 = exclusive_sum(data)
    outputs1 = outputs1.flatten()
    outputs1.sum().backward()
    grad1 = data.grad.clone()
    data.grad.zero_()  # reset so the next backward() does not accumulate

    # (2) Flattened call; every row becomes one chunk described by
    # (start offset, element count).
    chunk_starts = torch.arange(
        0, data.numel(), data.shape[1], device=device, dtype=torch.long
    )
    chunk_cnts = torch.full(
        (data.shape[0],), data.shape[1], dtype=torch.long, device=device
    )
    packed_info = torch.stack([chunk_starts, chunk_cnts], dim=-1)
    flatten_data = data.flatten()
    outputs2 = exclusive_sum(flatten_data, packed_info=packed_info)
    outputs2.sum().backward()
    grad2 = data.grad.clone()
    data.grad.zero_()

    # (3) Flattened call; each element is tagged with the id of its row.
    indices = torch.arange(data.shape[0], device=device, dtype=torch.long)
    indices = indices.repeat_interleave(data.shape[1])
    indices = indices.flatten()
    outputs3 = exclusive_sum(flatten_data, indices=indices)
    outputs3.sum().backward()
    grad3 = data.grad.clone()
    data.grad.zero_()

    # TODO: investigate why the outputs need a looser tolerance here; the
    # max abs difference between outputs1 and outputs2 was observed to be
    # about 2e-4 (numeric error?).
    assert torch.allclose(outputs1, outputs2, atol=3e-4)
    assert torch.allclose(grad1, grad2)

    assert torch.allclose(outputs1, outputs3, atol=3e-4)
    assert torch.allclose(grad1, grad3)


# NOTE: `is_available` must be *called*; the bare function object is always
# truthy, which would make the skip condition permanently False.
@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device")
def test_inclusive_prod():
    """Check that the three call forms of ``inclusive_prod`` agree.

    The same random ``(5, 1000)`` tensor is scanned:

    1. as a 2D batch,
    2. flattened, with ``packed_info`` rows of ``(start, count)`` per chunk,
    3. flattened, with a per-element ``indices`` tensor holding the row id.

    Outputs, and the gradients of ``outputs.sum()`` w.r.t. the input, are
    compared with ``torch.allclose``.
    """
    from nerfacc.scan import inclusive_prod

    torch.manual_seed(42)

    data = torch.rand((5, 1000), device=device, requires_grad=True)

    # (1) Batched call on the 2D tensor.
    outputs1 = inclusive_prod(data)
    outputs1 = outputs1.flatten()
    outputs1.sum().backward()
    grad1 = data.grad.clone()
    data.grad.zero_()  # reset so the next backward() does not accumulate

    # (2) Flattened call; every row becomes one chunk described by
    # (start offset, element count).
    chunk_starts = torch.arange(
        0, data.numel(), data.shape[1], device=device, dtype=torch.long
    )
    chunk_cnts = torch.full(
        (data.shape[0],), data.shape[1], dtype=torch.long, device=device
    )
    packed_info = torch.stack([chunk_starts, chunk_cnts], dim=-1)
    flatten_data = data.flatten()
    outputs2 = inclusive_prod(flatten_data, packed_info=packed_info)
    outputs2.sum().backward()
    grad2 = data.grad.clone()
    data.grad.zero_()

    # (3) Flattened call; each element is tagged with the id of its row.
    indices = torch.arange(data.shape[0], device=device, dtype=torch.long)
    indices = indices.repeat_interleave(data.shape[1])
    indices = indices.flatten()
    outputs3 = inclusive_prod(flatten_data, indices=indices)
    outputs3.sum().backward()
    grad3 = data.grad.clone()
    data.grad.zero_()

    assert torch.allclose(outputs1, outputs2)
    assert torch.allclose(grad1, grad2)

    assert torch.allclose(outputs1, outputs3)
    assert torch.allclose(grad1, grad3)


# NOTE: `is_available` must be *called*; the bare function object is always
# truthy, which would make the skip condition permanently False.
@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device")
def test_exclusive_prod():
    """Check that the three call forms of ``exclusive_prod`` agree.

    The same random ``(5, 1000)`` tensor is scanned:

    1. as a 2D batch,
    2. flattened, with ``packed_info`` rows of ``(start, count)`` per chunk,
    3. flattened, with a per-element ``indices`` tensor holding the row id.

    Outputs, and the gradients of ``outputs.sum()`` w.r.t. the input, are
    compared with ``torch.allclose``.
    """
    from nerfacc.scan import exclusive_prod

    torch.manual_seed(42)

    data = torch.rand((5, 1000), device=device, requires_grad=True)

    # (1) Batched call on the 2D tensor.
    outputs1 = exclusive_prod(data)
    outputs1 = outputs1.flatten()
    outputs1.sum().backward()
    grad1 = data.grad.clone()
    data.grad.zero_()  # reset so the next backward() does not accumulate

    # (2) Flattened call; every row becomes one chunk described by
    # (start offset, element count).
    chunk_starts = torch.arange(
        0, data.numel(), data.shape[1], device=device, dtype=torch.long
    )
    chunk_cnts = torch.full(
        (data.shape[0],), data.shape[1], dtype=torch.long, device=device
    )
    packed_info = torch.stack([chunk_starts, chunk_cnts], dim=-1)
    flatten_data = data.flatten()
    outputs2 = exclusive_prod(flatten_data, packed_info=packed_info)
    outputs2.sum().backward()
    grad2 = data.grad.clone()
    data.grad.zero_()

    # (3) Flattened call; each element is tagged with the id of its row.
    indices = torch.arange(data.shape[0], device=device, dtype=torch.long)
    indices = indices.repeat_interleave(data.shape[1])
    indices = indices.flatten()
    outputs3 = exclusive_prod(flatten_data, indices=indices)
    outputs3.sum().backward()
    grad3 = data.grad.clone()
    data.grad.zero_()

    assert torch.allclose(outputs1, outputs2)
    assert torch.allclose(grad1, grad2)

    assert torch.allclose(outputs1, outputs3)
    assert torch.allclose(grad1, grad3)


def profile():
    """Run ``inclusive_sum`` forward + backward in two progress-bar loops.

    An ``(8192, 8192)`` random tensor is flattened and scanned 2000 times with
    ``packed_info`` segmentation, then 2000 times with ``indices``
    segmentation. ``tqdm`` displays the progress of each loop. Nothing is
    returned or asserted.
    """
    import tqdm

    from nerfacc.scan import inclusive_sum

    torch.manual_seed(42)

    data = torch.rand((8192, 8192), device=device, requires_grad=True)
    n_rows, n_cols = data.shape[0], data.shape[1]

    # One (start, count) pair per row of `data`.
    row_offsets = torch.arange(
        0, data.numel(), n_cols, device=device, dtype=torch.long
    )
    row_lengths = torch.full(
        (n_rows,), n_cols, dtype=torch.long, device=device
    )
    packed_info = torch.stack([row_offsets, row_lengths], dim=-1)
    flatten_data = data.flatten()

    # Drain pending GPU work before the first loop starts.
    torch.cuda.synchronize()
    for _ in tqdm.trange(2000):
        packed_out = inclusive_sum(flatten_data, packed_info=packed_info)
        packed_out.sum().backward()

    # Row id for every flattened element.
    indices = (
        torch.arange(n_rows, device=device, dtype=torch.long)
        .repeat_interleave(n_cols)
        .flatten()
    )

    torch.cuda.synchronize()
    for _ in tqdm.trange(2000):
        indexed_out = inclusive_sum(flatten_data, indices=indices)
        indexed_out.sum().backward()


if __name__ == "__main__":
    # Script entry point: run each test in order, then the profiling loops.
    for entry in (
        test_inclusive_sum,
        test_exclusive_sum,
        test_inclusive_prod,
        test_exclusive_prod,
        profile,
    ):
        entry()