import pytest
import torch

from nerfacc import pack_info, ray_marching, ray_resampling

device = "cuda:0"  # all test tensors are allocated on the first CUDA device
batch_size = 128  # number of rays generated per test


# BUG FIX: the original condition was `not torch.cuda.is_available` — the bare
# function object, which is always truthy, so the skip never triggered and the
# test crashed on CPU-only machines. It must be *called*: is_available().
@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device")
def test_resampling():
    """Smoke-test ray_resampling on randomly marched rays.

    Generates `batch_size` random rays, marches them between the near/far
    planes, then resamples 32 intervals per ray using random weights and
    checks the packed output shape.
    """
    # Random origins in the unit cube; directions normalized to unit length.
    rays_o = torch.rand((batch_size, 3), device=device)
    rays_d = torch.randn((batch_size, 3), device=device)
    rays_d = rays_d / rays_d.norm(dim=-1, keepdim=True)

    # March every ray at a fixed step size between the near and far planes.
    ray_indices, t_starts, t_ends = ray_marching(
        rays_o,
        rays_d,
        near_plane=0.1,
        far_plane=1.0,
        render_step_size=1e-3,
    )
    packed_info = pack_info(ray_indices, n_rays=batch_size)

    # Resample with one random (unnormalized) weight per marched sample.
    weights = torch.rand((t_starts.shape[0],), device=device)
    packed_info, t_starts, t_ends = ray_resampling(
        packed_info, t_starts, t_ends, weights, n_samples=32
    )
    # Every ray contributes exactly n_samples intervals, flattened into a
    # single packed (batch_size * 32, 1) tensor.
    assert t_starts.shape == t_ends.shape == (batch_size * 32, 1)


# Allow running this test directly (python test_resampling.py) without pytest.
if __name__ == "__main__":
    test_resampling()