test_pdf_query.py 2.09 KB
Newer Older
Ruilong Li's avatar
Ruilong Li committed
1
2
3
import torch
from torch import Tensor

Ruilong Li's avatar
Ruilong Li committed
4
from nerfacc import pack_info, ray_marching, ray_resampling
Ruilong Li's avatar
Ruilong Li committed
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
from nerfacc.cuda import ray_pdf_query

# Every tensor in this test is allocated on this device, so running the test
# requires a CUDA-capable GPU.
device = "cuda:0"


def outer(
    t0_starts: Tensor,
    t0_ends: Tensor,
    t1_starts: Tensor,
    t1_ends: Tensor,
    y1: Tensor,
) -> Tensor:
    """Reference "outer" resampling of ``y1`` onto the ``t0`` query intervals.

    For every query interval ``[t0_starts[..., i], t0_ends[..., i]]``, this
    returns the sum of ``y1`` over a contiguous range of source intervals
    ``[t1_starts, t1_ends]``:

    * It starts at the last source interval whose start is ``<=`` the query
      start.
    * It ends, inclusively, at the first source interval whose end is ``>``
      the query end.

    A prefix sum makes each query a pair of binary searches.

    Args:
        t0_starts: Query interval starts, searched against ``t1_starts``.
        t0_ends: Query interval ends, searched against ``t1_ends``.
        t1_starts: Source interval starts; ``torch.searchsorted`` requires
            them sorted along the last dim.
        t1_ends: Source interval ends; likewise sorted along the last dim.
        y1: Per-source-interval values, with the same last-dim size as
            ``t1_starts``.

    Returns:
        Summed values, one per query interval. This has the shape of
        ``t0_starts`` when the leading dims match.

    NOTE(review): both indices are clamped into ``[0, n - 1]`` and the end
    index is inclusive. As a result, a query interval lying entirely before
    (or after) every source interval still picks up ``y1[..., 0]`` (or
    ``y1[..., -1]``) rather than 0. Confirm this matches the kernel under
    test.
    """
    last = y1.shape[-1] - 1

    # Exclusive prefix sum with a leading zero: prefix[..., k] == y1[..., :k].sum(-1).
    prefix = torch.cat(
        [torch.zeros_like(y1[..., :1]), torch.cumsum(y1, dim=-1)], dim=-1
    )

    # Last source interval whose start is <= the query start.
    start_bin = (
        torch.searchsorted(
            t1_starts.contiguous(), t0_starts.contiguous(), side="right"
        )
        .sub(1)
        .clamp(min=0, max=last)
    )
    # First source interval whose end is > the query end.
    end_bin = torch.searchsorted(
        t1_ends.contiguous(), t0_ends.contiguous(), side="right"
    ).clamp(min=0, max=last)

    # sum(y1[start_bin..end_bin]) == prefix[end_bin + 1] - prefix[start_bin]
    upper = torch.take_along_dim(prefix[..., 1:], end_bin, dim=-1)
    lower = torch.take_along_dim(prefix[..., :-1], start_bin, dim=-1)
    return upper - lower


def test_pdf_query():
    """Check the ``ray_pdf_query`` kernel against the ``outer`` reference.

    The test proceeds in four steps:

    1. March a single ray to get packed source intervals.
    2. Give each interval a random weight.
    3. Query those weights on the same intervals shifted by -0.3.
    4. Compare the kernel output with the pure-PyTorch ``outer`` reference
       evaluated on the dense per-ray tensors.
    """
    # Fixed seed so a mismatch is reproducible.
    torch.manual_seed(0)

    n_rays = 1
    rays_o = torch.rand((n_rays, 3), device=device)
    rays_d = torch.randn((n_rays, 3), device=device)
    rays_d = rays_d / rays_d.norm(dim=-1, keepdim=True)

    ray_indices, t_starts, t_ends = ray_marching(
        rays_o,
        rays_d,
        near_plane=0.1,
        far_plane=1.0,
        render_step_size=0.2,
    )
    packed_info = pack_info(ray_indices, n_rays)
    weights = torch.rand((t_starts.shape[0],), device=device)

    # Query intervals are the source intervals shifted by -0.3. They keep the
    # same per-ray sample counts, so the packing can be reused.
    packed_info_new = packed_info
    t_starts_new = t_starts - 0.3
    t_ends_new = t_ends - 0.3

    # ``reshape(n_rays, -1)`` assumes every ray has the same number of
    # samples. That holds trivially here because n_rays == 1.
    weights_new_ref = outer(
        t_starts_new.reshape(n_rays, -1),
        t_ends_new.reshape(n_rays, -1),
        t_starts.reshape(n_rays, -1),
        t_ends.reshape(n_rays, -1),
        weights.reshape(n_rays, -1),
    ).flatten()

    weights_new = ray_pdf_query(
        packed_info,
        t_starts,
        t_ends,
        weights,
        packed_info_new,
        t_starts_new,
        t_ends_new,
    ).flatten()

    # Previously these were only printed. Assert so the test can actually
    # fail. The shape check first keeps allclose from silently broadcasting.
    assert weights_new.shape == weights_new_ref.shape
    assert torch.allclose(weights_new, weights_new_ref)
Ruilong Li's avatar
Ruilong Li committed
82
83
84
85


# Allow running the comparison directly as a script, outside a test runner.
if __name__ == "__main__":
    test_pdf_query()