# test_pdf_query.py (2.05 KB, as captured from the repository web view)
# Last commit by Ruilong Li.
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
from struct import pack

import torch
from torch import Tensor

from nerfacc import ray_marching, ray_resampling
from nerfacc.cuda import ray_pdf_query

# Every tensor in this test is allocated on the first CUDA device; the
# nerfacc ops imported above (from `nerfacc.cuda`) presumably require a GPU.
device: str = "cuda:0"


def outer(
    t0_starts: Tensor,
    t0_ends: Tensor,
    t1_starts: Tensor,
    t1_ends: Tensor,
    y1: Tensor,
) -> Tensor:
    """Reference outer measure of the step function (t1, y1) on the t0 intervals.

    For every query interval ``[t0_starts[..., i], t0_ends[..., i]]`` this
    returns the sum of ``y1[..., k]`` over all source intervals ``k`` with
    ``t1_starts[..., k] <= t0_ends[..., i]`` and
    ``t1_ends[..., k] > t0_starts[..., i]``. In other words, it sums from the
    source interval containing the query start through the one containing
    the query end, with source intervals treated as half-open ``[start, end)``.

    Query intervals that touch no source interval (entirely before the first,
    at/after the last, or inside a gap between sources) get exactly 0.

    Assumes the ``t1_*`` intervals are sorted and non-overlapping along the
    last dim and that ``t0_starts <= t0_ends`` elementwise (TODO confirm
    against the ``ray_pdf_query`` contract). Leading dims are batch dims and
    must be compatible between the t0 and t1 tensors, as required by
    ``torch.searchsorted`` / ``torch.take_along_dim``.

    Args:
        t0_starts: Start of each query interval.
        t0_ends: End of each query interval.
        t1_starts: Start of each source interval.
        t1_ends: End of each source interval.
        y1: Value (e.g. weight) attached to each source interval.

    Returns:
        Tensor shaped like ``t0_starts`` holding the outer measure per query.
    """
    # cy1[..., k] = sum(y1[..., :k]); one longer than y1 so that any index
    # in [0, n] (n = number of source intervals) is valid without clamping.
    cy1 = torch.cat(
        [torch.zeros_like(y1[..., :1]), torch.cumsum(y1, dim=-1)], dim=-1
    )

    # First touched source interval = number of sources that end at or
    # before the query start.
    idx_lo = torch.searchsorted(
        t1_ends.contiguous(), t0_starts.contiguous(), side="right"
    )
    # One past the last touched source interval = number of sources that
    # start at or before the query end.
    idx_hi = torch.searchsorted(
        t1_starts.contiguous(), t0_ends.contiguous(), side="right"
    )

    cy1_lo = torch.take_along_dim(cy1, idx_lo, dim=-1)
    cy1_hi = torch.take_along_dim(cy1, idx_hi, dim=-1)

    # An empty index range means no overlap: report an exact 0 rather than
    # relying on cumsum differences (and guard against inverted queries).
    y0_outer = torch.where(
        idx_hi > idx_lo, cy1_hi - cy1_lo, torch.zeros_like(cy1_hi)
    )

    return y0_outer


def test_pdf_query():
    """Check ``ray_pdf_query`` against the pure-PyTorch ``outer`` reference.

    Marches a random ray, attaches random weights to its sample intervals,
    shifts every interval by -0.3 to form the query intervals, and requires
    the CUDA result to match ``outer`` evaluated on those queries.
    """
    # Kept at 1 ray: the reference reshapes the packed samples to
    # (n_rays, -1), which presumably only splits correctly when every ray has
    # the same number of samples — trivially true for a single ray.
    n_rays = 1
    rays_o = torch.rand((n_rays, 3), device=device)
    rays_d = torch.randn((n_rays, 3), device=device)
    rays_d = rays_d / rays_d.norm(dim=-1, keepdim=True)

    packed_info, t_starts, t_ends = ray_marching(
        rays_o,
        rays_d,
        near_plane=0.1,
        far_plane=1.0,
        render_step_size=0.2,
    )
    weights = torch.rand((t_starts.shape[0],), device=device)

    # Query intervals share the packing but are shifted by -0.3, so some of
    # them straddle source intervals and some miss them entirely.
    packed_info_new = packed_info
    t_starts_new = t_starts - 0.3
    t_ends_new = t_ends - 0.3

    weights_new_ref = outer(
        t_starts_new.reshape(n_rays, -1),
        t_ends_new.reshape(n_rays, -1),
        t_starts.reshape(n_rays, -1),
        t_ends.reshape(n_rays, -1),
        weights.reshape(n_rays, -1),
    ).flatten()

    weights_new = ray_pdf_query(
        packed_info,
        t_starts,
        t_ends,
        weights,
        packed_info_new,
        t_starts_new,
        t_ends_new,
    ).flatten()

    # Assert instead of printing, so a kernel/reference mismatch fails the
    # test; the message carries both tensors for debugging.
    assert weights_new.shape == weights_new_ref.shape, (
        weights_new.shape,
        weights_new_ref.shape,
    )
    assert torch.allclose(weights_new, weights_new_ref, atol=1e-5), (
        f"ray_pdf_query mismatch:\n{weights_new}\nreference:\n{weights_new_ref}"
    )


if __name__ == "__main__":
    test_pdf_query()