"""
Copyright (c) 2022 Ruilong Li, UC Berkeley.
"""

import random
from typing import Optional, Sequence

try:
    from typing import Literal
except ImportError:
    from typing_extensions import Literal

import numpy as np
import torch
from datasets.utils import Rays, namedtuple_map
from torch.utils.data._utils.collate import collate, default_collate_fn_map

from nerfacc.estimators.occ_grid import OccGridEstimator
from nerfacc.estimators.prop_net import PropNetEstimator
from nerfacc.volrend import rendering

NERF_SYNTHETIC_SCENES = [
    "chair",
    "drums",
    "ficus",
    "hotdog",
    "lego",
    "materials",
    "mic",
    "ship",
]
MIPNERF360_UNBOUNDED_SCENES = [
    "garden",
    "bicycle",
    "bonsai",
    "counter",
    "kitchen",
    "room",
    "stump",
]


def set_random_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
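    # torch.manual_seed also seeds CUDA devices; cuDNN autotuning may still
    # cause run-to-run variation unless torch.backends.cudnn.deterministic
    # is enabled (a speed/reproducibility trade-off left to the caller).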


def render_image_with_occgrid(
    # scene
    radiance_field: torch.nn.Module,
    estimator: OccGridEstimator,
    rays: Rays,
    # rendering options
    near_plane: float = 0.0,
    far_plane: float = 1e10,
    render_step_size: float = 1e-3,
    render_bkgd: Optional[torch.Tensor] = None,
    cone_angle: float = 0.0,
    alpha_thre: float = 0.0,
    # test options
    test_chunk_size: int = 8192,
    # only useful for dnerf
    timestamps: Optional[torch.Tensor] = None,
):
    """Render the pixels of an image."""
    rays_shape = rays.origins.shape
    if len(rays_shape) == 3:
        height, width, _ = rays_shape
        num_rays = height * width
        rays = namedtuple_map(
            lambda r: r.reshape([num_rays] + list(r.shape[2:])), rays
        )
    else:
        num_rays, _ = rays_shape

    def sigma_fn(t_starts, t_ends, ray_indices):
        t_origins = chunk_rays.origins[ray_indices]
        t_dirs = chunk_rays.viewdirs[ray_indices]
        positions = t_origins + t_dirs * (t_starts + t_ends)[:, None] / 2.0
        if timestamps is not None:
            # D-NeRF: per-ray timestamps in training; at eval a single
            # timestamp is broadcast across all samples.
            t = (
                timestamps[ray_indices]
                if radiance_field.training
                else timestamps.expand_as(positions[:, :1])
            )
            sigmas = radiance_field.query_density(positions, t)
        else:
            sigmas = radiance_field.query_density(positions)
        return sigmas.squeeze(-1)

    def rgb_sigma_fn(t_starts, t_ends, ray_indices):
        t_origins = chunk_rays.origins[ray_indices]
        t_dirs = chunk_rays.viewdirs[ray_indices]
        positions = t_origins + t_dirs * (t_starts + t_ends)[:, None] / 2.0
        if timestamps is not None:
            # D-NeRF: same timestamp handling as in sigma_fn above.
            t = (
                timestamps[ray_indices]
                if radiance_field.training
                else timestamps.expand_as(positions[:, :1])
            )
            rgbs, sigmas = radiance_field(positions, t, t_dirs)
        else:
            rgbs, sigmas = radiance_field(positions, t_dirs)
        return rgbs, sigmas.squeeze(-1)

    results = []
    chunk = (
        torch.iinfo(torch.int32).max
        if radiance_field.training
        else test_chunk_size
    )
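    # Training processes all rays in one chunk; evaluation chunks the rays
    # to bound peak memory.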
    for i in range(0, num_rays, chunk):
        chunk_rays = namedtuple_map(lambda r: r[i : i + chunk], rays)
        ray_indices, t_starts, t_ends = estimator.sampling(
            chunk_rays.origins,
            chunk_rays.viewdirs,
            sigma_fn=sigma_fn,
            near_plane=near_plane,
            far_plane=far_plane,
            render_step_size=render_step_size,
            stratified=radiance_field.training,
            cone_angle=cone_angle,
            alpha_thre=alpha_thre,
        )
        rgb, opacity, depth, extras = rendering(
            t_starts,
            t_ends,
            ray_indices,
            n_rays=chunk_rays.origins.shape[0],
            rgb_sigma_fn=rgb_sigma_fn,
            render_bkgd=render_bkgd,
        )
        chunk_results = [rgb, opacity, depth, len(t_starts)]
        results.append(chunk_results)
    colors, opacities, depths, n_rendering_samples = [
        torch.cat(r, dim=0) if isinstance(r[0], torch.Tensor) else r
        for r in zip(*results)
    ]
    return (
        colors.view((*rays_shape[:-1], -1)),
        opacities.view((*rays_shape[:-1], -1)),
        depths.view((*rays_shape[:-1], -1)),
        sum(n_rendering_samples),
    )
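
# A minimal usage sketch (hypothetical `radiance_field`; the AABB and step
# size below are assumptions, not values prescribed by this module):
#
#   estimator = OccGridEstimator(
#       roi_aabb=[-1.5, -1.5, -1.5, 1.5, 1.5, 1.5], resolution=128
#   ).cuda()
#   rgb, opacity, depth, n_samples = render_image_with_occgrid(
#       radiance_field, estimator, rays,
#       render_step_size=5e-3,
#       render_bkgd=torch.ones(3, device="cuda"),
#   )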


def render_image_with_propnet(
    # scene
    radiance_field: torch.nn.Module,
    proposal_networks: Sequence[torch.nn.Module],
    estimator: PropNetEstimator,
    rays: Rays,
    # rendering options
    num_samples: int,
    num_samples_per_prop: Sequence[int],
    near_plane: Optional[float] = None,
    far_plane: Optional[float] = None,
    sampling_type: Literal["uniform", "lindisp"] = "lindisp",
    opaque_bkgd: bool = True,
    render_bkgd: Optional[torch.Tensor] = None,
    # train options
    proposal_requires_grad: bool = False,
    # test options
    test_chunk_size: int = 8192,
):
    """Render the pixels of an image."""
    rays_shape = rays.origins.shape
    if len(rays_shape) == 3:
        height, width, _ = rays_shape
        num_rays = height * width
        rays = namedtuple_map(
            lambda r: r.reshape([num_rays] + list(r.shape[2:])), rays
        )
    else:
        num_rays, _ = rays_shape

    def prop_sigma_fn(t_starts, t_ends, proposal_network):
        t_origins = chunk_rays.origins[..., None, :]
        t_dirs = chunk_rays.viewdirs[..., None, :]
        positions = t_origins + t_dirs * (t_starts + t_ends)[..., None] / 2.0
        sigmas = proposal_network(positions)
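        # Force the farthest sample opaque so every ray terminates within
        # the sampled interval.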
        if opaque_bkgd:
            sigmas[..., -1, :] = torch.inf
        return sigmas.squeeze(-1)

    def rgb_sigma_fn(t_starts, t_ends, ray_indices):
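        # Proposal sampling is dense per ray (ray_indices is unused), so the
        # view directions are tiled along the sample dimension to match
        # positions of shape (n_rays, n_samples, 3).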
        t_origins = chunk_rays.origins[..., None, :]
        t_dirs = chunk_rays.viewdirs[..., None, :].repeat_interleave(
            t_starts.shape[-1], dim=-2
        )
        positions = t_origins + t_dirs * (t_starts + t_ends)[..., None] / 2.0
        rgb, sigmas = radiance_field(positions, t_dirs)
        if opaque_bkgd:
            sigmas[..., -1, :] = torch.inf
        return rgb, sigmas.squeeze(-1)

    results = []
    chunk = (
        torch.iinfo(torch.int32).max
        if radiance_field.training
        else test_chunk_size
    )
    for i in range(0, num_rays, chunk):
        chunk_rays = namedtuple_map(lambda r: r[i : i + chunk], rays)
        t_starts, t_ends = estimator.sampling(
            # Bind `p` as a default argument: a plain closure would capture
            # the loop variable by reference, so every fn would end up
            # calling the last proposal network.
            prop_sigma_fns=[
                lambda *args, p=p: prop_sigma_fn(*args, p)
                for p in proposal_networks
            ],
            prop_samples=num_samples_per_prop,
            num_samples=num_samples,
            n_rays=chunk_rays.origins.shape[0],
            near_plane=near_plane,
            far_plane=far_plane,
            sampling_type=sampling_type,
            stratified=radiance_field.training,
            requires_grad=proposal_requires_grad,
        )
        rgb, opacity, depth, extras = rendering(
            t_starts,
            t_ends,
            ray_indices=None,
            n_rays=None,
            rgb_sigma_fn=rgb_sigma_fn,
            render_bkgd=render_bkgd,
        )
        chunk_results = [rgb, opacity, depth]
        results.append(chunk_results)

    colors, opacities, depths = collate(
        results,
        collate_fn_map={
            **default_collate_fn_map,
            torch.Tensor: lambda x, **_: torch.cat(x, 0),
        },
    )
    return (
        colors.view((*rays_shape[:-1], -1)),
        opacities.view((*rays_shape[:-1], -1)),
        depths.view((*rays_shape[:-1], -1)),
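        # `extras` comes from the last chunk only; in training the whole
        # batch is a single chunk, so it covers all rays.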
        extras,
    )
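
# A minimal usage sketch for the proposal-network path (hypothetical
# `radiance_field` / `proposal_networks` / optimizer objects; the sample
# counts and planes below are assumptions):
#
#   estimator = PropNetEstimator(prop_optimizer, prop_scheduler)
#   rgb, opacity, depth, extras = render_image_with_propnet(
#       radiance_field, proposal_networks, estimator, rays,
#       num_samples=48, num_samples_per_prop=[128, 64],
#       near_plane=0.2, far_plane=1e3,
#       sampling_type="lindisp", opaque_bkgd=True,
#   )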