import argparse
import math
import os
import random
import time

import imageio
import numpy as np
import torch
import torch.nn.functional as F
import tqdm

from nerfacc import OccupancyField, volumetric_rendering_pipeline

device = "cuda:0"


def _set_random_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


def render_image(
    radiance_field,
    rays,
    timestamps,
    render_bkgd,
    render_step_size,
    test_chunk_size=81920,
):
    """Render the pixels of an image.

    Args:
      radiance_field: the radiance field of nerf.
      rays: a `Rays` namedtuple, the rays to be rendered.

    Returns:
      rgb: torch.tensor, rendered color image.
      depth: torch.tensor, rendered depth image.
      acc: torch.tensor, rendered accumulated weights per pixel.
    """
    rays_shape = rays.origins.shape
    if len(rays_shape) == 3:
        height, width, _ = rays_shape
        num_rays = height * width
        rays = namedtuple_map(lambda r: r.reshape([num_rays] + list(r.shape[2:])), rays)
    else:
        num_rays, _ = rays_shape
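
    # Callbacks consumed by the rendering pipeline: `sigma_fn` returns density
    # only, letting the pipeline discard low-contribution samples cheaply,
    # while `rgb_sigma_fn` returns color and density for the samples that are
    # actually composited.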
    def sigma_fn(frustum_starts, frustum_ends, ray_indices):
        ray_indices = ray_indices.long()
        frustum_origins = chunk_rays.origins[ray_indices]
        frustum_dirs = chunk_rays.viewdirs[ray_indices]
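        # sample positions are taken at the midpoint of each ray segment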
        positions = (
            frustum_origins + frustum_dirs * (frustum_starts + frustum_ends) / 2.0
        )
        if timestamps is None:
            return radiance_field.query_density(positions)
        else:
            if radiance_field.training:
                t = timestamps[ray_indices]
            else:
                t = timestamps.expand_as(positions[:, :1])
            return radiance_field.query_density(positions, t)

    def rgb_sigma_fn(frustum_starts, frustum_ends, ray_indices):
        ray_indices = ray_indices.long()
        frustum_origins = chunk_rays.origins[ray_indices]
        frustum_dirs = chunk_rays.viewdirs[ray_indices]
        positions = (
            frustum_origins + frustum_dirs * (frustum_starts + frustum_ends) / 2.0
        )
        if timestamps is None:
            return radiance_field(positions, frustum_dirs)
        else:
            if radiance_field.training:
                t = timestamps[ray_indices]
            else:
                t = timestamps.expand_as(positions[:, :1])
            return radiance_field(positions, t, frustum_dirs)

    results = []
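    # During training all rays are rendered in one pass; at test time rays are
    # split into chunks of `test_chunk_size` to bound peak memory.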
    chunk = torch.iinfo(torch.int32).max if radiance_field.training else test_chunk_size
    for i in range(0, num_rays, chunk):
        chunk_rays = namedtuple_map(lambda r: r[i : i + chunk], rays)
        chunk_results = volumetric_rendering_pipeline(
            sigma_fn=sigma_fn,
            rgb_sigma_fn=rgb_sigma_fn,
            rays_o=chunk_rays.origins,
            rays_d=chunk_rays.viewdirs,
            scene_aabb=occ_field.aabb,
            scene_occ_binary=occ_field.occ_grid_binary,
            scene_resolution=occ_field.resolution,
            render_bkgd=render_bkgd,
            render_step_size=render_step_size,
            near_plane=0.0,
            stratified=radiance_field.training,
        )
        results.append(chunk_results)
    colors, opacities, n_marching_samples, n_rendering_samples = [
        torch.cat(r, dim=0) if isinstance(r[0], torch.Tensor) else r
        for r in zip(*results)
    ]
    return (
        colors.view((*rays_shape[:-1], -1)),
        opacities.view((*rays_shape[:-1], -1)),
        sum(n_marching_samples),
        sum(n_rendering_samples),
    )


if __name__ == "__main__":
    _set_random_seed(42)

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "method",
        type=str,
        default="ngp",
        choices=["ngp", "vanilla", "dnerf"],
        help="which nerf to use",
    )
    parser.add_argument(
        "--train_split",
        type=str,
        default="trainval",
        choices=["train", "trainval"],
        help="which train split to use",
    )
    parser.add_argument(
        "--scene",
        type=str,
        default="lego",
        choices=[
            # nerf synthetic
            "chair",
            "drums",
            "ficus",
            "hotdog",
            "lego",
            "materials",
            "mic",
            "ship",
            # dnerf
            "bouncingballs",
            "hellwarrior",
            "hook",
            "jumpingjacks",
            "lego",
            "mutant",
            "standup",
            "trex",
        ],
        help="which scene to use",
    )
    parser.add_argument(
        "--aabb",
        type=lambda s: [float(x) for x in s.split(",")],
        default=[-1.5, -1.5, -1.5, 1.5, 1.5, 1.5],
        help="scene bounding box as xmin,ymin,zmin,xmax,ymax,zmax",
    )
    parser.add_argument(
        "--test_chunk_size",
        type=int,
        default=81920,
    )
    args = parser.parse_args()

    if args.method == "ngp":
        from datasets.nerf_synthetic import SubjectLoader, namedtuple_map
        from radiance_fields.ngp import NGPradianceField

        radiance_field = NGPradianceField(aabb=args.aabb).to(device)
        optimizer = torch.optim.Adam(radiance_field.parameters(), lr=1e-2, eps=1e-15)
        max_steps = 20000
        occ_field_warmup_steps = 256
        grad_scaler = torch.cuda.amp.GradScaler(2**10)
        data_root_fp = "/home/ruilongli/data/nerf_synthetic/"
        target_sample_batch_size = 1 << 18

    elif args.method == "vanilla":
        from datasets.nerf_synthetic import SubjectLoader, namedtuple_map
        from radiance_fields.mlp import VanillaNeRFRadianceField

        radiance_field = VanillaNeRFRadianceField().to(device)
        optimizer = torch.optim.Adam(radiance_field.parameters(), lr=5e-4)
        max_steps = 40000
        occ_field_warmup_steps = 2000
        grad_scaler = torch.cuda.amp.GradScaler(1)
        data_root_fp = "/home/ruilongli/data/nerf_synthetic/"
        target_sample_batch_size = 1 << 16

    elif args.method == "dnerf":
        from datasets.dnerf_synthetic import SubjectLoader, namedtuple_map
        from radiance_fields.mlp import DNeRFRadianceField

        radiance_field = DNeRFRadianceField().to(device)
        optimizer = torch.optim.Adam(radiance_field.parameters(), lr=5e-4)
        max_steps = 40000
        occ_field_warmup_steps = 2000
        grad_scaler = torch.cuda.amp.GradScaler(1)
        data_root_fp = "/home/ruilongli/data/dnerf/"
        target_sample_batch_size = 1 << 16

    scene = args.scene

    # setup the scene bounding box.
    scene_aabb = torch.tensor(args.aabb)
    # setup some rendering settings
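    # The step size is chosen so that a ray spanning the longest AABB side
    # times sqrt(3) (the cube diagonal) is covered by ~render_n_samples steps.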
    render_n_samples = 1024
    render_step_size = (
        (scene_aabb[3:] - scene_aabb[:3]).max() * math.sqrt(3) / render_n_samples
    ).item()

    # setup dataset
    train_dataset = SubjectLoader(
        subject_id=scene,
        root_fp=data_root_fp,
        split=args.train_split,
        num_rays=target_sample_batch_size // render_n_samples,
        # color_bkgd_aug="random",
    )

    train_dataset.images = train_dataset.images.to(device)
    train_dataset.camtoworlds = train_dataset.camtoworlds.to(device)
    train_dataset.K = train_dataset.K.to(device)
    if hasattr(train_dataset, "timestamps"):
        train_dataset.timestamps = train_dataset.timestamps.to(device)

    test_dataset = SubjectLoader(
        subject_id=scene,
        root_fp=data_root_fp,
        split="test",
        num_rays=None,
    )
    test_dataset.images = test_dataset.images.to(device)
    test_dataset.camtoworlds = test_dataset.camtoworlds.to(device)
    test_dataset.K = test_dataset.K.to(device)
    if hasattr(train_dataset, "timestamps"):
        test_dataset.timestamps = test_dataset.timestamps.to(device)

    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer,
        milestones=[max_steps // 2, max_steps * 3 // 4, max_steps * 9 // 10],
        gamma=0.33,
    )

    # setup occupancy field with eval function
    def occ_eval_fn(x: torch.Tensor) -> torch.Tensor:
        """Evaluate occupancy given positions.

        Args:
            x: positions with shape (N, 3).
        Returns:
            occupancy values with shape (N, 1).
        """
        if args.method == "dnerf":
            idxs = torch.randint(
                0, len(train_dataset.timestamps), (x.shape[0],), device=x.device
            )
            t = train_dataset.timestamps[idxs]
            density_after_activation = radiance_field.query_density(x, t)
        else:
            density_after_activation = radiance_field.query_density(x)
        # The exact per-step opacity is 1 - exp(-density * render_step_size);
        # for small density * render_step_size the two are nearly identical,
        # so the cheaper linear form below is used as the occupancy value.
        # occupancy = 1.0 - torch.exp(-density_after_activation * render_step_size)
        occupancy = density_after_activation * render_step_size
        return occupancy

    occ_field = OccupancyField(
        occ_eval_fn=occ_eval_fn, aabb=scene_aabb, resolution=128
    ).to(device)

    # training
    step = 0
    tic = time.time()
    data_time = 0
    tic_data = time.time()

    for epoch in range(10000000):
        for i in range(len(train_dataset)):
            radiance_field.train()
            data = train_dataset[i]
            data_time += time.time() - tic_data

            render_bkgd = data["color_bkgd"]
            rays = data["rays"]
            pixels = data["pixels"]
            timestamps = data.get("timestamps", None)

            # update occupancy grid
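            # (periodically re-evaluates occ_eval_fn over the grid so that
            # ray marching can skip cells marked as empty)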
            occ_field.every_n_step(step, warmup_steps=occ_field_warmup_steps)

            rgb, acc, counter, compact_counter = render_image(
                radiance_field, rays, timestamps, render_bkgd, render_step_size
            )
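            # Adapt the ray batch size so that the number of samples actually
            # rendered per step stays close to target_sample_batch_size.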
            num_rays = len(pixels)
            num_rays = int(
                num_rays * (target_sample_batch_size / float(compact_counter))
            )
            train_dataset.update_num_rays(num_rays)
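            # Only rays with nonzero accumulated opacity contribute to the loss.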
            alive_ray_mask = acc.squeeze(-1) > 0

            # compute loss
            loss = F.smooth_l1_loss(rgb[alive_ray_mask], pixels[alive_ray_mask])

            optimizer.zero_grad()
            # Gradients stay scaled: Adam normalizes by the running second
            # moment, so a constant loss scale largely cancels; no unscaling.
            grad_scaler.scale(loss).backward()
            optimizer.step()
            scheduler.step()

            if step % 100 == 0:
                elapsed_time = time.time() - tic
                loss = F.mse_loss(rgb[alive_ray_mask], pixels[alive_ray_mask])
                print(
                    f"elapsed_time={elapsed_time:.2f}s (data={data_time:.2f}s) | {step=} | "
                    f"loss={loss:.5f} | "
                    f"alive_ray_mask={alive_ray_mask.long().sum():d} | "
                    f"counter={counter:d} | compact_counter={compact_counter:d} | num_rays={len(pixels):d} |"
                )

            # if time.time() - tic > 300:
            if step > 0 and step % max_steps == 0:
                # evaluation
                radiance_field.eval()

                psnrs = []
                with torch.no_grad():
                    for i in tqdm.tqdm(range(len(test_dataset))):
                        data = test_dataset[i]
                        render_bkgd = data["color_bkgd"]
                        rays = data["rays"]
                        pixels = data["pixels"]
                        timestamps = data.get("timestamps", None)

                        # rendering
                        rgb, acc, _, _ = render_image(
                            radiance_field,
                            rays,
                            timestamps,
                            render_bkgd,
                            render_step_size,
                            test_chunk_size=args.test_chunk_size,
                        )
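                        # PSNR = -10 * log10(MSE) for pixel values in [0, 1].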
                        mse = F.mse_loss(rgb, pixels)
                        psnr = -10.0 * torch.log(mse) / np.log(10.0)
                        psnrs.append(psnr.item())
                        # if step == max_steps:
                        #     output_dir = os.path.join("./outputs/nerfacc/", scene)
                        #     os.makedirs(output_dir, exist_ok=True)
                        #     save = torch.cat([pixels, rgb], dim=1)
                        #     imageio.imwrite(
                        #         os.path.join(output_dir, "%05d.png" % i),
                        #         (save.cpu().numpy() * 255).astype(np.uint8),
                        #     )
                        # else:
                        #     imageio.imwrite(
                        #         "acc_binary_test.png",
                        #         ((acc > 0).float().cpu().numpy() * 255).astype(
                        #             np.uint8
                        #         ),
                        #     )
                        #     imageio.imwrite(
                        #         "rgb_test.png",
                        #         (rgb.cpu().numpy() * 255).astype(np.uint8),
                        #     )
                        #     break
                psnr_avg = sum(psnrs) / len(psnrs)
                print(f"evaluation: {psnr_avg=}")
                train_dataset.training = True

            if step == max_steps:
                print("training stops")
                exit()
            tic_data = time.time()

            step += 1