test_raysampler.py 3.76 KB
Newer Older
1
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
David Novotny's avatar
David Novotny committed
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76

import unittest

import torch
from nerf.raysampler import NeRFRaysampler, ProbabilisticRaysampler
from pytorch3d.renderer import PerspectiveCameras
from pytorch3d.transforms.rotation_conversions import random_rotations


class TestRaysampler(unittest.TestCase):
    """Tests for the NeRF project's `NeRFRaysampler` and `ProbabilisticRaysampler`."""

    def setUp(self) -> None:
        # Fix the seed so the random cameras and ray weights are reproducible.
        torch.manual_seed(42)

    @staticmethod
    def _grid_raysampler(n_pts_per_ray: int, min_depth: float):
        """Build the 10x10-pixel, non-stratified `NeRFRaysampler` shared by the tests."""
        return NeRFRaysampler(
            min_x=0.0,
            max_x=10.0,
            min_y=0.0,
            max_y=10.0,
            n_pts_per_ray=n_pts_per_ray,
            min_depth=min_depth,
            max_depth=10.0,
            n_rays_per_image=12,
            image_width=10,
            image_height=10,
            stratified=False,
            stratified_test=False,
            invert_directions=True,
        )

    @staticmethod
    def _random_cameras(batch_size: int):
        """Create `PerspectiveCameras` with `batch_size` random poses and intrinsics."""
        R = random_rotations(batch_size)
        T = torch.randn(batch_size, 3)
        # The +0.5 offset keeps the sampled focal lengths away from zero.
        focal_length = torch.rand(batch_size, 2) + 0.5
        principal_point = torch.randn(batch_size, 2)
        return PerspectiveCameras(
            focal_length=focal_length,
            principal_point=principal_point,
            R=R,
            T=T,
        )

    def test_raysampler_caching(self, batch_size=10):
        """
        Tests the consistency of the NeRF raysampler caching.

        Rays are first generated (in eval mode) for `batch_size` random
        cameras without a camera hash. `precache_rays` is then called with the
        hashes `0..batch_size-1`. Every field of the bundle returned when
        re-querying each camera with its hash must match the uncached bundle,
        both in shape and in value.
        """
        raysampler = self._grid_raysampler(n_pts_per_ray=10, min_depth=0.1)
        raysampler.eval()

        cameras, rays = [], []
        for _ in range(batch_size):
            camera = self._random_cameras(1)
            cameras.append(camera)
            rays.append(raysampler(camera))

        raysampler.precache_rays(cameras, list(range(batch_size)))

        for cam_index, rays_ in enumerate(rays):
            # NOTE(review): presumably served from the cache filled by
            # `precache_rays` (looked up via `camera_hash`); confirm in raysampler.
            rays_cached_ = raysampler(
                cameras=cameras[cam_index],
                chunksize=None,
                chunk_idx=0,
                camera_hash=cam_index,
                caching=False,
            )

            for field_idx, (v, v_cached) in enumerate(zip(rays_, rays_cached_)):
                with self.subTest(camera=cam_index, field=field_idx):
                    # `torch.allclose` broadcasts its inputs, so compare shapes
                    # explicitly to avoid hiding a shape mismatch.
                    self.assertEqual(v.shape, v_cached.shape)
                    self.assertTrue(torch.allclose(v, v_cached))

    def test_probabilistic_raysampler(self, batch_size=1, n_pts_per_ray=60):
        """
        Check that the probabilistic ray sampler does not crash for various
        settings.

        Every combination of `stratified` / `stratified_test` is exercised in
        both train and eval mode with `add_input_samples=True`. Each output is
        also checked for three properties:
        - it keeps the input ray layout;
        - it holds the input depths plus `n_pts_per_ray` new samples per ray;
        - it is sorted by depth.
        """
        raysampler_grid = self._grid_raysampler(
            n_pts_per_ray=n_pts_per_ray, min_depth=1.0
        )
        camera = self._random_cameras(batch_size)

        raysampler_grid.eval()

        ray_bundle = raysampler_grid(cameras=camera)
        ray_weights = torch.rand_like(ray_bundle.lengths)

        input_shape = ray_bundle.lengths.shape
        # Input samples are kept and `n_pts_per_ray` new ones are added per ray.
        expected_n_pts = input_shape[-1] + n_pts_per_ray

        for stratified_test in (True, False):
            for stratified in (True, False):
                raysampler_prob = ProbabilisticRaysampler(
                    n_pts_per_ray=n_pts_per_ray,
                    stratified=stratified,
                    stratified_test=stratified_test,
                    add_input_samples=True,
                )
                for mode in ("train", "eval"):
                    getattr(raysampler_prob, mode)()
                    with self.subTest(
                        stratified=stratified,
                        stratified_test=stratified_test,
                        mode=mode,
                    ):
                        # Repeat so that stochastic (stratified) settings are
                        # exercised more than once.
                        for _ in range(10):
                            lengths = raysampler_prob(ray_bundle, ray_weights).lengths
                            self.assertEqual(lengths.shape[:-1], input_shape[:-1])
                            self.assertEqual(lengths.shape[-1], expected_n_pts)
                            self.assertTrue(
                                bool((lengths[..., 1:] >= lengths[..., :-1]).all())
                            )