"""
Copyright (c) 2022 Ruilong Li, UC Berkeley.
"""

import random
from typing import Literal, Optional, Sequence

import numpy as np
import torch
from datasets.utils import Rays, namedtuple_map
from torch.utils.data._utils.collate import collate, default_collate_fn_map

from nerfacc.estimators.occ_grid import OccGridEstimator
from nerfacc.estimators.prop_net import PropNetEstimator
from nerfacc.volrend import rendering

NERF_SYNTHETIC_SCENES = [
    "chair",
    "drums",
    "ficus",
    "hotdog",
    "lego",
    "materials",
    "mic",
    "ship",
]
MIPNERF360_UNBOUNDED_SCENES = [
    "garden",
    "bicycle",
    "bonsai",
    "counter",
    "kitchen",
    "room",
    "stump",
]


def set_random_seed(seed):
    """Seed Python, NumPy, and PyTorch RNGs for reproducible runs."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


def render_image_with_occgrid(
    # scene
    radiance_field: torch.nn.Module,
    estimator: OccGridEstimator,
    rays: Rays,
    # rendering options
    near_plane: float = 0.0,
    far_plane: float = 1e10,
    render_step_size: float = 1e-3,
    render_bkgd: Optional[torch.Tensor] = None,
    cone_angle: float = 0.0,
    alpha_thre: float = 0.0,
    # test options
    test_chunk_size: int = 8192,
    # only useful for dnerf
    timestamps: Optional[torch.Tensor] = None,
):
    """Render the pixels of an image."""
    rays_shape = rays.origins.shape
    if len(rays_shape) == 3:
        height, width, _ = rays_shape
        num_rays = height * width
        rays = namedtuple_map(
            lambda r: r.reshape([num_rays] + list(r.shape[2:])), rays
        )
    else:
        num_rays, _ = rays_shape

    def sigma_fn(t_starts, t_ends, ray_indices):
        t_origins = chunk_rays.origins[ray_indices]
        t_dirs = chunk_rays.viewdirs[ray_indices]
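        # query density at the midpoint of each sampled ray segment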
        positions = t_origins + t_dirs * (t_starts + t_ends)[:, None] / 2.0
        if timestamps is not None:
            # dnerf
            t = (
                timestamps[ray_indices]
                if radiance_field.training
                else timestamps.expand_as(positions[:, :1])
            )
            sigmas = radiance_field.query_density(positions, t)
        else:
            sigmas = radiance_field.query_density(positions)
        return sigmas.squeeze(-1)

    def rgb_sigma_fn(t_starts, t_ends, ray_indices):
        t_origins = chunk_rays.origins[ray_indices]
        t_dirs = chunk_rays.viewdirs[ray_indices]
        positions = t_origins + t_dirs * (t_starts + t_ends)[:, None] / 2.0
        if timestamps is not None:
            # dnerf
            t = (
                timestamps[ray_indices]
                if radiance_field.training
                else timestamps.expand_as(positions[:, :1])
            )
            rgbs, sigmas = radiance_field(positions, t, t_dirs)
        else:
            rgbs, sigmas = radiance_field(positions, t_dirs)
        return rgbs, sigmas.squeeze(-1)

    results = []
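    # render all rays in one pass when training; chunk at test time
    # to bound memory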
    chunk = (
        torch.iinfo(torch.int32).max
        if radiance_field.training
        else test_chunk_size
    )
    for i in range(0, num_rays, chunk):
        chunk_rays = namedtuple_map(lambda r: r[i : i + chunk], rays)
        ray_indices, t_starts, t_ends = estimator.sampling(
            chunk_rays.origins,
            chunk_rays.viewdirs,
            sigma_fn=sigma_fn,
            near_plane=near_plane,
            far_plane=far_plane,
            render_step_size=render_step_size,
            stratified=radiance_field.training,
            cone_angle=cone_angle,
            alpha_thre=alpha_thre,
        )
        rgb, opacity, depth, extras = rendering(
            t_starts,
            t_ends,
            ray_indices,
            n_rays=chunk_rays.origins.shape[0],
            rgb_sigma_fn=rgb_sigma_fn,
            render_bkgd=render_bkgd,
        )
        chunk_results = [rgb, opacity, depth, len(t_starts)]
        results.append(chunk_results)
    colors, opacities, depths, n_rendering_samples = [
        torch.cat(r, dim=0) if isinstance(r[0], torch.Tensor) else r
        for r in zip(*results)
    ]
    return (
        colors.view((*rays_shape[:-1], -1)),
        opacities.view((*rays_shape[:-1], -1)),
        depths.view((*rays_shape[:-1], -1)),
        sum(n_rendering_samples),
    )
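

# A minimal usage sketch (illustrative, not part of the original module): the
# radiance field, estimator, and rays are assumed to come from the training
# script, and the step size is a typical value rather than a prescription.
#
#     rgb, opacity, depth, n_samples = render_image_with_occgrid(
#         radiance_field,
#         estimator,
#         rays,
#         near_plane=0.0,
#         render_step_size=5e-3,
#         render_bkgd=torch.ones(3, device=rays.origins.device),
#     )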


def render_image_with_propnet(
    # scene
    radiance_field: torch.nn.Module,
    proposal_networks: Sequence[torch.nn.Module],
    estimator: PropNetEstimator,
    rays: Rays,
    # rendering options
    num_samples: int,
    num_samples_per_prop: Sequence[int],
    near_plane: Optional[float] = None,
    far_plane: Optional[float] = None,
    sampling_type: Literal["uniform", "lindisp"] = "lindisp",
    opaque_bkgd: bool = True,
    render_bkgd: Optional[torch.Tensor] = None,
    # train options
    proposal_requires_grad: bool = False,
    # test options
    test_chunk_size: int = 8192,
):
    """Render the pixels of an image."""
    rays_shape = rays.origins.shape
    if len(rays_shape) == 3:
        height, width, _ = rays_shape
        num_rays = height * width
        rays = namedtuple_map(
            lambda r: r.reshape([num_rays] + list(r.shape[2:])), rays
        )
    else:
        num_rays, _ = rays_shape

    def prop_sigma_fn(t_starts, t_ends, proposal_network):
        t_origins = chunk_rays.origins[..., None, :]
        t_dirs = chunk_rays.viewdirs[..., None, :]
        positions = t_origins + t_dirs * (t_starts + t_ends)[..., None] / 2.0
        sigmas = proposal_network(positions)
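        # force the last sample on each ray to be opaque so the ray terminates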
        if opaque_bkgd:
            sigmas[..., -1, :] = torch.inf
        return sigmas.squeeze(-1)

    def rgb_sigma_fn(t_starts, t_ends, ray_indices):
        t_origins = chunk_rays.origins[..., None, :]
        # repeat each ray's direction once per sample along that ray
        t_dirs = chunk_rays.viewdirs[..., None, :].repeat_interleave(
            t_starts.shape[-1], dim=-2
        )
        positions = t_origins + t_dirs * (t_starts + t_ends)[..., None] / 2.0
        rgb, sigmas = radiance_field(positions, t_dirs)
        if opaque_bkgd:
            sigmas[..., -1, :] = torch.inf
        return rgb, sigmas.squeeze(-1)

    results = []
    chunk = (
        torch.iinfo(torch.int32).max
        if radiance_field.training
        else test_chunk_size
    )
    for i in range(0, num_rays, chunk):
        chunk_rays = namedtuple_map(lambda r: r[i : i + chunk], rays)
        t_starts, t_ends = estimator.sampling(
            prop_sigma_fns=[
                # bind `p` eagerly: a plain closure would capture only the
                # last proposal network once the comprehension finishes
                lambda *args, p=p: prop_sigma_fn(*args, p)
                for p in proposal_networks
            ],
            prop_samples=num_samples_per_prop,
            num_samples=num_samples,
            n_rays=chunk_rays.origins.shape[0],
            near_plane=near_plane,
            far_plane=far_plane,
            sampling_type=sampling_type,
            stratified=radiance_field.training,
            requires_grad=proposal_requires_grad,
        )
        rgb, opacity, depth, extras = rendering(
            t_starts,
            t_ends,
            ray_indices=None,
            n_rays=None,
            rgb_sigma_fn=rgb_sigma_fn,
            render_bkgd=render_bkgd,
        )
        chunk_results = [rgb, opacity, depth]
        results.append(chunk_results)

    # stitch the per-chunk outputs back together; the collate machinery simply
    # concatenates matching tensors along the ray dimension
    colors, opacities, depths = collate(
        results,
        collate_fn_map={
            **default_collate_fn_map,
            torch.Tensor: lambda x, **_: torch.cat(x, 0),
        },
    )
    return (
        colors.view((*rays_shape[:-1], -1)),
        opacities.view((*rays_shape[:-1], -1)),
        depths.view((*rays_shape[:-1], -1)),
        extras,  # auxiliary rendering outputs from the last chunk
    )
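

# A minimal usage sketch (illustrative, not part of the original module): the
# proposal networks, estimator, and sample counts below are placeholder
# values, not prescribed settings.
#
#     rgb, opacity, depth, extras = render_image_with_propnet(
#         radiance_field,
#         proposal_networks,
#         estimator,
#         rays,
#         num_samples=48,
#         num_samples_per_prop=[256, 96],
#         near_plane=0.2,
#         far_plane=1e3,
#         render_bkgd=torch.ones(3, device=rays.origins.device),
#     )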