test_resampling.py 5.01 KB
Newer Older
1
2
import pytest
import torch
Ruilong Li's avatar
Ruilong Li committed
3
from functorch import vmap
4

Ruilong Li(李瑞龙)'s avatar
Ruilong Li(李瑞龙) committed
5
from nerfacc import pack_info, ray_marching, ray_resampling
Ruilong Li's avatar
Ruilong Li committed
6
from nerfacc.cuda import ray_pdf_query
7
8
9

# Device string passed as `device=` to the tensor constructors in this module.
device = "cuda:0"
# Module-level value; test_resampling assigns its own local `batch_size`.
batch_size = 128
Ruilong Li's avatar
Ruilong Li committed
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
# float32 machine epsilon; _resampling subtracts it from the upper end of the
# range of CDF positions it samples.
eps = torch.finfo(torch.float32).eps


def _interp(x, xp, fp):
    """Piecewise-linear interpolation through the points :math:`(xp, fp)`.

    Each consecutive pair of data points defines one line segment. Every query
    in `x` is evaluated on the segment whose left endpoint is the last `xp`
    value that is ``<= x``. Queries outside ``[xp[0], xp[-1]]`` are evaluated
    on the first or last segment respectively, i.e. they are linearly
    extrapolated rather than clamped.

    Args:
        x: the :math:`x`-coordinates at which to evaluate the interpolant.
        xp: the :math:`x`-coordinates of the data points, must be increasing.
        fp: the :math:`y`-coordinates of the data points, same length as `xp`.

    Returns:
        the interpolated values, same size as `x`.
    """
    knots = xp.contiguous()
    queries = x.contiguous()

    # Slope and intercept of the line through each pair of adjacent points.
    slope = (fp[1:] - fp[:-1]) / (knots[1:] - knots[:-1])
    intercept = fp[:-1] - (slope * knots[:-1])

    # right=True returns the insertion point after any equal knots, so
    # subtracting one yields the segment that starts at or before each query.
    segment = torch.searchsorted(knots, queries, right=True) - 1
    # Out-of-range queries fall back to the first / last segment.
    segment = segment.clamp(0, slope.shape[0] - 1)

    return slope[segment] * queries + intercept[segment]


def _integrate_weights(w):
    """Compute the cumulative sum of w, assuming all weight vectors sum to 1.

    The output's size on the last dimension is one greater than that of the input,
    because we're computing the integral corresponding to the endpoints of a step
    function, not the integral of the interior/bin values.

    Args:
      w: Tensor, which will be integrated along the last axis. This is assumed to
        sum to 1 along the last axis, and this function will (silently) break if
        that is not the case.

    Returns:
      cw0: Tensor with the same dtype and device as `w`, the integral of w,
        where cw0[..., 0] = 0 and cw0[..., -1] = 1
    """
    # The last weight is left out of the cumsum: the final CDF value is pinned
    # to exactly 1 below instead of relying on the weights summing to 1.
    cw = torch.clamp(torch.cumsum(w[..., :-1], dim=-1), max=1)
    shape = cw.shape[:-1] + (1,)
    # Ensure that the CDF starts with exactly 0 and ends with exactly 1.
    # new_zeros / new_ones inherit dtype and device from `w`, so the
    # concatenation never changes the dtype of the result.
    zeros = w.new_zeros(shape)
    ones = w.new_ones(shape)
    cw0 = torch.cat([zeros, cw, ones], dim=-1)
    return cw0


def _invert_cdf(u, t, w_logits):
    """Invert the CDF defined by (t, w) at the points specified by u in [0, 1).

    The bin weights are the softmax of `w_logits` over the last axis. Their
    running integral gives the CDF value at every bin endpoint in `t`, and
    each `u` is mapped back to a `t` value by linear interpolation of
    (CDF value -> t).

    Note: `vmap` is applied with its default settings, so `_interp` is run
    independently for each entry along the leading dimension of the inputs.
    """
    # Per-bin probabilities, then the CDF at the bin endpoints.
    pdf = torch.softmax(w_logits, dim=-1)
    cdf = _integrate_weights(pdf)
    # Interpolate into the inverse CDF, one row at a time.
    invert_rows = vmap(_interp)
    return invert_rows(u, cdf, t)


def _resampling(t, w_logits, num_samples):
    """Piecewise-Constant PDF sampling from a step function.

    The CDF is inverted at `num_samples` deterministic, evenly spaced
    positions, so the result is reproducible (no random jitter).

    Args:
        t: [..., num_bins + 1], bin endpoint coordinates (must be sorted).
        w_logits: [..., num_bins], logits corresponding to bin weights.
        num_samples: int, the number of samples.

    Returns:
        t_samples: [..., num_samples], the sampled t values
    """
    # Evenly spaced CDF positions starting half a step (pad) inside 0; the
    # upper end is additionally lowered by eps.
    pad = 1 / (2 * num_samples)
    # Allocate on the same device as `t` rather than on the module-level
    # default, so the helper works for inputs living on any device.
    u = torch.linspace(pad, 1.0 - pad - eps, num_samples, device=t.device)
    # Use the same positions for every leading (batch) entry of `t`.
    u = torch.broadcast_to(u, t.shape[:-1] + (num_samples,))
    return _invert_cdf(u, t, w_logits)
89
90
91
92


@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device")
def test_resampling():
    """Compare `ray_resampling` on packed samples with the dense reference.

    A dense batch of sorted bin edges and random logits is built, a random
    subset of bins per ray is kept, and the kept bins are packed into the
    flattened (packed_info, t_starts, t_ends, weights) layout passed to
    `ray_resampling`. The dense reference `_resampling` sees the same
    selection because dropped bins get a logit of -inf (zero softmax weight).
    """
    batch_size = 1024
    num_bins = 128
    num_samples = 128

    # Sorted random bin edges and small random logits, one row per ray.
    t = torch.randn((batch_size, num_bins + 1), device=device)
    t = torch.sort(t, dim=-1).values
    w_logits = torch.randn((batch_size, num_bins), device=device) * 0.1
    # Weights are taken before masking; after selecting the kept bins they
    # no longer sum to 1 per ray.
    # NOTE(review): presumably ray_resampling normalizes per ray - confirm.
    w = torch.softmax(w_logits, dim=-1)
    masks = w_logits > 0
    w_logits[~masks] = -torch.inf

    # Reference: num_samples + 1 positions, i.e. the edges of num_samples
    # consecutive intervals.
    t_samples = _resampling(t, w_logits, num_samples + 1)

    # Pack the kept bins of all rays into flat tensors.
    t_starts = t[:, :-1][masks].unsqueeze(-1)
    t_ends = t[:, 1:][masks].unsqueeze(-1)
    w = w[masks]
    num_steps = masks.long().sum(dim=-1)
    cum_steps = torch.cumsum(num_steps, dim=0)
    # Per ray: (offset of its first kept bin in the flat tensors, bin count).
    packed_info = torch.stack([cum_steps - num_steps, num_steps], dim=-1).int()

    _, t_starts, t_ends = ray_resampling(
        packed_info, t_starts, t_ends, w, num_samples
    )

    # Interval starts / ends must match consecutive reference positions.
    assert torch.allclose(
        t_starts.view(batch_size, num_samples), t_samples[:, :-1], atol=1e-3
    )
    assert torch.allclose(
        t_ends.view(batch_size, num_samples), t_samples[:, 1:], atol=1e-3
    )


Ruilong Li's avatar
Ruilong Li committed
132
133
134
135
136
@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device")
def test_pdf_query():
    """Smoke test for `ray_pdf_query`.

    Marches a single random ray, assigns random weights to its samples and
    queries the resulting PDF on the same intervals shifted by 0.3. The
    returned weights are not checked; the test only verifies that the call
    completes without raising.
    """
    rays_o = torch.rand((1, 3), device=device)
    rays_d = torch.randn((1, 3), device=device)
    # Normalize the direction to unit length.
    rays_d = rays_d / rays_d.norm(dim=-1, keepdim=True)

    ray_indices, t_starts, t_ends = ray_marching(
        rays_o,
        rays_d,
        near_plane=0.1,
        far_plane=1.0,
        render_step_size=0.2,
    )
    packed_info = pack_info(ray_indices, rays_o.shape[0])

    # One random weight per marched sample.
    weights = torch.rand((t_starts.shape[0],), device=device)
    # Query with the same packing but with every interval shifted by 0.3.
    weights_new = ray_pdf_query(
        packed_info,
        t_starts,
        t_ends,
        weights,
        packed_info,
        t_starts + 0.3,
        t_ends + 0.3,
    )

Ruilong Li's avatar
Ruilong Li committed
157

158
159
# Script entry point: call both tests directly when this file is executed.
if __name__ == "__main__":
    test_resampling()
Ruilong Li's avatar
Ruilong Li committed
160
    test_pdf_query()