# test_resampling.py -- test for nerfacc's ray_resampling.
import pytest
import torch

from nerfacc import pack_info, ray_marching, ray_resampling

# Device string passed as `device=` to every tensor created in this test.
device = "cuda:0"
# Number of rays generated: the leading dimension of rays_o / rays_d and the
# `n_rays` argument given to pack_info.
batch_size = 128


# NOTE: is_available must be *called*. The bare function object is always
# truthy, so `not torch.cuda.is_available` is always False and the test would
# never be skipped on machines without CUDA.
@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device")
def test_resampling():
    """Check the output shape of ``ray_resampling``.

    Random rays are marched with ``ray_marching``, the resulting samples are
    given random weights, and ``ray_resampling`` is asked for ``n_samples``
    samples. The test passes when both returned ``t_starts`` and ``t_ends``
    have shape ``(batch_size * n_samples, 1)``.
    """
    n_samples = 32

    # Random origins and unit-length random directions, one row per ray.
    rays_o = torch.rand((batch_size, 3), device=device)
    rays_d = torch.randn((batch_size, 3), device=device)
    rays_d = rays_d / rays_d.norm(dim=-1, keepdim=True)

    ray_indices, t_starts, t_ends = ray_marching(
        rays_o,
        rays_d,
        near_plane=0.1,
        far_plane=1.0,
        render_step_size=1e-3,
    )
    packed_info = pack_info(ray_indices, n_rays=batch_size)

    # One random weight per marched sample.
    weights = torch.rand((t_starts.shape[0],), device=device)
    packed_info, t_starts, t_ends = ray_resampling(
        packed_info, t_starts, t_ends, weights, n_samples=n_samples
    )
    assert t_starts.shape == t_ends.shape == (batch_size * n_samples, 1)


if __name__ == "__main__":
    # Allows running the file directly as a script. Calling the test function
    # this way does not go through pytest, so the skipif marker is not
    # evaluated and a CUDA device is required.
    test_resampling()