"""
Copyright (c) 2022 Ruilong Li, UC Berkeley.
"""

import argparse
import math
import pathlib
import time

import imageio
import numpy as np
import torch
import torch.nn.functional as F
import tqdm
from radiance_fields.mlp import VanillaNeRFRadianceField
from utils import render_image, set_random_seed

from nerfacc import ContractionType, OccupancyGrid

device = "cuda:0"
set_random_seed(42)

parser = argparse.ArgumentParser()
parser.add_argument(
    "--data_root",
    type=str,
    default=str(pathlib.Path.cwd() / "data/nerf_synthetic"),
    help="the root dir of the dataset",
)
parser.add_argument(
    "--train_split",
    type=str,
    default="trainval",
    choices=["train", "trainval"],
    help="which train split to use",
)
parser.add_argument(
    "--scene",
    type=str,
    default="lego",
    choices=[
        # nerf synthetic
        "chair",
        "drums",
        "ficus",
        "hotdog",
        "lego",
        "materials",
        "mic",
        "ship",
        # mipnerf360 unbounded
        "garden",
    ],
    help="which scene to use",
)
parser.add_argument(
    "--aabb",
    type=lambda s: [float(item) for item in s.split(",")],
    default="-1.5,-1.5,-1.5,1.5,1.5,1.5",
    help="scene AABB as comma-separated floats: xmin,ymin,zmin,xmax,ymax,zmax",
)
parser.add_argument(
    "--test_chunk_size",
    type=int,
    default=8192,
    help="number of rays per chunk when rendering test views",
)
parser.add_argument(
    "--unbounded",
    action="store_true",
    help="whether to use unbounded rendering",
)
parser.add_argument("--cone_angle", type=float, default=0.0)
args = parser.parse_args()

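# nominal number of samples per ray; used below to derive the marching step
# size and the initial ray batch size.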
render_n_samples = 1024

# setup the scene bounding box.
if args.unbounded:
    print("Using unbounded rendering")
    contraction_type = ContractionType.UN_BOUNDED_SPHERE
    # contraction_type = ContractionType.UN_BOUNDED_TANH
    scene_aabb = None
    near_plane = 0.2
    far_plane = 1e4
    render_step_size = 1e-2
else:
    contraction_type = ContractionType.AABB
    scene_aabb = torch.tensor(args.aabb, dtype=torch.float32, device=device)
    near_plane = None
    far_plane = None
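    # step size chosen so that roughly render_n_samples steps span the scene
    # diagonal (max extent * sqrt(3) upper-bounds the diagonal length).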
    render_step_size = (
        (scene_aabb[3:] - scene_aabb[:3]).max()
        * math.sqrt(3)
        / render_n_samples
    ).item()

# setup the radiance field we want to train.
max_steps = 50000
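# with init_scale=1 and no update() calls, the scaler leaves gradients
# effectively unscaled; backward() is simply routed through it.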
grad_scaler = torch.cuda.amp.GradScaler(1)
radiance_field = VanillaNeRFRadianceField().to(device)
optimizer = torch.optim.Adam(radiance_field.parameters(), lr=5e-4)
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer,
    milestones=[
        max_steps // 2,
        max_steps * 3 // 4,
        max_steps * 5 // 6,
        max_steps * 9 // 10,
    ],
    gamma=0.33,
)

# setup the dataset
train_dataset_kwargs = {}
test_dataset_kwargs = {}
if args.scene == "garden":
    from datasets.nerf_360_v2 import SubjectLoader

    target_sample_batch_size = 1 << 16
    train_dataset_kwargs = {"color_bkgd_aug": "random", "factor": 4}
    test_dataset_kwargs = {"factor": 4}
    grid_resolution = 128
else:
    from datasets.nerf_synthetic import SubjectLoader

    target_sample_batch_size = 1 << 16
    grid_resolution = 128

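# size the initial ray batch so that rays * samples-per-ray ~= target_sample_batch_size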
train_dataset = SubjectLoader(
    subject_id=args.scene,
    root_fp=args.data_root,
    split=args.train_split,
    num_rays=target_sample_batch_size // render_n_samples,
    **train_dataset_kwargs,
)

test_dataset = SubjectLoader(
    subject_id=args.scene,
    root_fp=args.data_root,
    split="test",
    num_rays=None,
    **test_dataset_kwargs,
)

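# binary occupancy grid used to skip empty space during ray marching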
occupancy_grid = OccupancyGrid(
    roi_aabb=args.aabb,
    resolution=grid_resolution,
    contraction_type=contraction_type,
).to(device)

# training
step = 0
tic = time.time()
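# loop over epochs indefinitely; training terminates via exit() once step
# reaches max_steps (see the end of the loop body).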
for epoch in range(10000000):
    for i in range(len(train_dataset)):
        radiance_field.train()
        data = train_dataset[i]

        render_bkgd = data["color_bkgd"]
        rays = data["rays"]
        pixels = data["pixels"]

        # update occupancy grid
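        # (every_n_step periodically re-evaluates per-cell opacity so that
        # empty cells can be skipped during ray marching)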
        occupancy_grid.every_n_step(
            step=step,
            occ_eval_fn=lambda x: radiance_field.query_opacity(
                x, render_step_size
            ),
        )

        # render
        rgb, acc, depth, n_rendering_samples = render_image(
            radiance_field,
            occupancy_grid,
            rays,
            scene_aabb,
            # rendering options
            near_plane=near_plane,
            far_plane=far_plane,
            render_step_size=render_step_size,
            render_bkgd=render_bkgd,
            cone_angle=args.cone_angle,
        )
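        # skip this batch if the occupancy grid filtered out every sample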
        if n_rendering_samples == 0:
            continue

        # dynamic batch size for rays to keep sample batch size constant.
        num_rays = len(pixels)
        num_rays = int(
            num_rays * (target_sample_batch_size / float(n_rendering_samples))
        )
        train_dataset.update_num_rays(num_rays)
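        # rays with zero accumulated opacity carry no supervision signal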
        alive_ray_mask = acc.squeeze(-1) > 0

        # compute loss
        loss = F.smooth_l1_loss(rgb[alive_ray_mask], pixels[alive_ray_mask])

        optimizer.zero_grad()
        # gradients are not unscaled before stepping: the scale is fixed at 1,
        # and Adam's update is largely insensitive to a constant gradient scale.
        grad_scaler.scale(loss).backward()
        optimizer.step()
        scheduler.step()

        if step % 5000 == 0:
            elapsed_time = time.time() - tic
            # report MSE (rather than the smooth-L1 training loss) for logging
            mse = F.mse_loss(rgb[alive_ray_mask], pixels[alive_ray_mask])
            print(
                f"elapsed_time={elapsed_time:.2f}s | step={step} | "
                f"loss={mse:.5f} | "
                f"alive_ray_mask={alive_ray_mask.long().sum():d} | "
                f"n_rendering_samples={n_rendering_samples:d} | num_rays={len(pixels):d} |"
            )

        if step > 0 and step % max_steps == 0:
            # evaluation
            radiance_field.eval()

            psnrs = []
            with torch.no_grad():
                for i in tqdm.tqdm(range(len(test_dataset))):
                    data = test_dataset[i]
                    render_bkgd = data["color_bkgd"]
                    rays = data["rays"]
                    pixels = data["pixels"]

                    # rendering
                    rgb, acc, depth, _ = render_image(
                        radiance_field,
                        occupancy_grid,
                        rays,
                        scene_aabb,
                        # rendering options
                        near_plane=None,
                        far_plane=None,
                        render_step_size=render_step_size,
                        render_bkgd=render_bkgd,
                        cone_angle=args.cone_angle,
                        # test options
                        test_chunk_size=args.test_chunk_size,
                    )
                    mse = F.mse_loss(rgb, pixels)
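                    # PSNR for images in [0, 1]: -10 * log10(MSE)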
                    psnr = -10.0 * torch.log(mse) / np.log(10.0)
                    psnrs.append(psnr.item())
                    # imageio.imwrite(
                    #     "acc_binary_test.png",
                    #     ((acc > 0).float().cpu().numpy() * 255).astype(np.uint8),
                    # )
                    # imageio.imwrite(
                    #     "rgb_test.png",
                    #     (rgb.cpu().numpy() * 255).astype(np.uint8),
                    # )
                    # break
            psnr_avg = sum(psnrs) / len(psnrs)
            print(f"evaluation: psnr_avg={psnr_avg}")
            train_dataset.training = True

        if step == max_steps:
            print("training stops")
            exit()

        step += 1