import pytest
import torch

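# The nerfacc kernels exercised below run on the GPU, so all test tensors
# are created on the first CUDA device.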
device = "cuda:0"


def _create_intervals(n_rays, n_samples, flat=False):
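    """Build a random RayIntervals spec for testing.

    With ``flat=False`` the intervals are returned in batched layout with
    ``vals`` of shape (n_rays, n_samples + 1). With ``flat=True`` a random
    subset of samples is kept and the surviving interval edges are packed
    into a flat tensor together with ``packed_info`` and the
    ``is_left`` / ``is_right`` masks.
    """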
    from nerfacc.data_specs import RayIntervals

    torch.manual_seed(42)
    vals = torch.rand((n_rays, n_samples + 1), device=device)
    vals = torch.sort(vals, -1)[0]

    sample_masks = torch.rand((n_rays, n_samples), device=device) > 0.5
    is_lefts = torch.cat(
        [
            sample_masks,
            torch.zeros((n_rays, 1), device=device, dtype=torch.bool),
        ],
        dim=-1,
    )
    is_rights = torch.cat(
        [
            torch.zeros((n_rays, 1), device=device, dtype=torch.bool),
            sample_masks,
        ],
        dim=-1,
    )
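    # The masks above only matter for the flat (packed) layout: sample_masks
    # decides which samples survive, and its shifted copies mark each kept
    # sample's left and right interval edge.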
    if not flat:
        return RayIntervals(vals=vals)
    else:
        interval_masks = is_lefts | is_rights
        vals = vals[interval_masks]
        is_lefts = is_lefts[interval_masks]
        is_rights = is_rights[interval_masks]
        chunk_cnts = (interval_masks).long().sum(-1)
        chunk_starts = torch.cumsum(chunk_cnts, 0) - chunk_cnts
        packed_info = torch.stack([chunk_starts, chunk_cnts], -1)

        return RayIntervals(
            vals, packed_info, is_left=is_lefts, is_right=is_rights
        )


@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device")
def test_searchsorted():
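    """searchsorted on batched RayIntervals should match torch.searchsorted."""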
    from nerfacc.data_specs import RayIntervals
    from nerfacc.pdf import searchsorted

    torch.manual_seed(42)
    query: RayIntervals = _create_intervals(10, 100, flat=False)
    key: RayIntervals = _create_intervals(10, 100, flat=False)

    ids_left, ids_right = searchsorted(key, query)
    y = key.vals.gather(-1, ids_right)

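    # Reference: plain torch.searchsorted with right=True, clamped to the
    # last valid index so the gather below stays in bounds.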
    _ids_right = torch.searchsorted(key.vals, query.vals, right=True)
    _ids_right = torch.clamp(_ids_right, 0, key.vals.shape[-1] - 1)
    _y = key.vals.gather(-1, _ids_right)

    assert torch.allclose(ids_right, _ids_right)
    assert torch.allclose(y, _y)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device")
def test_importance_sampling():
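    """Batched importance_sampling should match the per-ray reference sampler."""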
    from nerfacc.data_specs import RayIntervals
    from nerfacc.pdf import _sample_from_weighted, importance_sampling

    torch.manual_seed(42)
    intervals: RayIntervals = _create_intervals(5, 100, flat=False)
    cdfs = torch.rand_like(intervals.vals)
    cdfs = torch.sort(cdfs, -1)[0]
    n_intervals_per_ray = 100
    stratified = False

    _intervals, _samples = importance_sampling(
        intervals,
        cdfs,
        n_intervals_per_ray,
        stratified,
    )

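    # The reference sampler is called one ray at a time, with per-interval
    # weights (cdf differences) and the ray's min/max bounds.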
    for i in range(intervals.vals.shape[0]):
        _vals, _mids = _sample_from_weighted(
            intervals.vals[i : i + 1],
            cdfs[i : i + 1, 1:] - cdfs[i : i + 1, :-1],
            n_intervals_per_ray,
            stratified,
            intervals.vals[i].min(),
            intervals.vals[i].max(),
        )
        assert torch.allclose(_intervals.vals[i : i + 1], _vals, atol=1e-4)
        assert torch.allclose(_samples.vals[i : i + 1], _mids, atol=1e-4)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device")
def test_pdf_loss():
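    """_pdf_loss on RayIntervals specs should match the dense _lossfun_outer."""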
    from nerfacc.data_specs import RayIntervals
    from nerfacc.estimators.prop_net import _lossfun_outer, _pdf_loss
    from nerfacc.pdf import _sample_from_weighted, importance_sampling

    torch.manual_seed(42)
    intervals: RayIntervals = _create_intervals(5, 100, flat=False)
    cdfs = torch.rand_like(intervals.vals)
    cdfs = torch.sort(cdfs, -1)[0]
    n_intervals_per_ray = 10
    stratified = False

    _intervals, _samples = importance_sampling(
        intervals,
        cdfs,
        n_intervals_per_ray,
        stratified,
    )
    _cdfs = torch.rand_like(_intervals.vals)
    _cdfs = torch.sort(_cdfs, -1)[0]

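    # _pdf_loss consumes the interval specs and cdfs directly; the reference
    # _lossfun_outer takes raw edge values plus per-interval weights
    # (cdf differences).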
    loss = _pdf_loss(intervals, cdfs, _intervals, _cdfs)

    loss2 = _lossfun_outer(
        intervals.vals,
        cdfs[:, 1:] - cdfs[:, :-1],
        _intervals.vals,
        _cdfs[:, 1:] - _cdfs[:, :-1],
    )
    assert torch.allclose(loss, loss2, atol=1e-4)


if __name__ == "__main__":
    test_importance_sampling()
    test_searchsorted()
    test_pdf_loss()