test_raysampler.py 3.63 KB
Newer Older
David Novotny's avatar
David Novotny committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

import unittest

import torch
from nerf.raysampler import NeRFRaysampler, ProbabilisticRaysampler
from pytorch3d.renderer import PerspectiveCameras
from pytorch3d.transforms.rotation_conversions import random_rotations


class TestRaysampler(unittest.TestCase):
    """Tests for `NeRFRaysampler` and `ProbabilisticRaysampler`."""

    def setUp(self) -> None:
        # Seed the global RNG so the random cameras and ray weights are
        # identical across runs.
        torch.manual_seed(42)

    def test_raysampler_caching(self, batch_size=10):
        """
        Tests the consistency of the NeRF raysampler caching.

        For each of `batch_size` random cameras, rays are first computed by an
        eval-mode `NeRFRaysampler` called with only the camera. The same
        cameras are then passed to `precache_rays` with hashes
        `0 .. batch_size - 1`. Each camera is queried again with its hash, and
        every field of the returned bundle must be `torch.allclose` to the
        corresponding field of the first result.
        """

        raysampler = NeRFRaysampler(
            min_x=0.0,
            max_x=10.0,
            min_y=0.0,
            max_y=10.0,
            n_pts_per_ray=10,
            min_depth=0.1,
            max_depth=10.0,
            n_rays_per_image=12,
            image_width=10,
            image_height=10,
            stratified=False,
            stratified_test=False,
            invert_directions=True,
        )

        raysampler.eval()

        cameras, rays = [], []

        for _ in range(batch_size):
            R = random_rotations(1)
            T = torch.randn(1, 3)
            # Offset keeps the focal lengths in [0.5, 1.5).
            focal_length = torch.rand(1, 2) + 0.5
            principal_point = torch.randn(1, 2)

            camera = PerspectiveCameras(
                focal_length=focal_length,
                principal_point=principal_point,
                R=R,
                T=T,
            )

            cameras.append(camera)
            # Reference rays: called without a `camera_hash`.
            rays.append(raysampler(camera))

        # Camera `i` is precached under hash `i`.
        raysampler.precache_rays(cameras, list(range(batch_size)))

        for cam_index, rays_ in enumerate(rays):
            # Query with the hash this camera was precached under; presumably
            # served from the cache rather than recomputed.
            rays_cached_ = raysampler(
                cameras=cameras[cam_index],
                chunksize=None,
                chunk_idx=0,
                camera_hash=cam_index,
                caching=False,
            )

            for field_idx, (v, v_cached) in enumerate(zip(rays_, rays_cached_)):
                self.assertTrue(
                    torch.allclose(v, v_cached),
                    msg=(
                        f"camera {cam_index}: field #{field_idx} of the "
                        "cached ray bundle differs from the reference"
                    ),
                )

    def test_probabilistic_raysampler(self, batch_size=1, n_pts_per_ray=60):
        """
        Check that the probabilistic ray sampler does not crash for various
        settings.

        An eval-mode `NeRFRaysampler` produces a ray bundle for `batch_size`
        random cameras. Random weights with the same shape as the bundle's
        `lengths` are drawn. A `ProbabilisticRaysampler` (with
        `add_input_samples=True`) is then run 10 times on that bundle for
        every combination of `stratified`, `stratified_test` and train/eval
        mode. Each combination is reported as its own sub-test.
        """

        raysampler_grid = NeRFRaysampler(
            min_x=0.0,
            max_x=10.0,
            min_y=0.0,
            max_y=10.0,
            n_pts_per_ray=n_pts_per_ray,
            min_depth=1.0,
            max_depth=10.0,
            n_rays_per_image=12,
            image_width=10,
            image_height=10,
            stratified=False,
            stratified_test=False,
            invert_directions=True,
        )

        R = random_rotations(batch_size)
        T = torch.randn(batch_size, 3)
        focal_length = torch.rand(batch_size, 2) + 0.5
        principal_point = torch.randn(batch_size, 2)
        camera = PerspectiveCameras(
            focal_length=focal_length,
            principal_point=principal_point,
            R=R,
            T=T,
        )

        raysampler_grid.eval()

        ray_bundle = raysampler_grid(cameras=camera)

        ray_weights = torch.rand_like(ray_bundle.lengths)

        # Just check that we don't crash for all possible settings.
        for stratified_test in (True, False):
            for stratified in (True, False):
                raysampler_prob = ProbabilisticRaysampler(
                    n_pts_per_ray=n_pts_per_ray,
                    stratified=stratified,
                    stratified_test=stratified_test,
                    add_input_samples=True,
                )
                for mode in ("train", "eval"):
                    # subTest names the failing configuration if one crashes.
                    with self.subTest(
                        stratified=stratified,
                        stratified_test=stratified_test,
                        mode=mode,
                    ):
                        # Calls raysampler_prob.train() or raysampler_prob.eval().
                        getattr(raysampler_prob, mode)()
                        for _ in range(10):
                            raysampler_prob(ray_bundle, ray_weights)