train_ngp_nerf_occ.py 7.35 KB
Newer Older
1
2
3
4
"""
Copyright (c) 2022 Ruilong Li, UC Berkeley.
"""

Ruilong Li(李瑞龙)'s avatar
Ruilong Li(李瑞龙) committed
5
6
import argparse
import math
Jingchen Ye's avatar
Jingchen Ye committed
7
import pathlib
Ruilong Li(李瑞龙)'s avatar
Ruilong Li(李瑞龙) committed
8
9
10
11
12
13
14
import time

import imageio
import numpy as np
import torch
import torch.nn.functional as F
import tqdm
15
16
from lpips import LPIPS
from radiance_fields.ngp import NGPRadianceField
17
18

from examples.utils import (
19
20
    MIPNERF360_UNBOUNDED_SCENES,
    NERF_SYNTHETIC_SCENES,
21
    render_image_with_occgrid,
22
    render_image_with_occgrid_test,
23
24
    set_random_seed,
)
25
from nerfacc.estimators.occ_grid import OccGridEstimator
Ruilong Li(李瑞龙)'s avatar
Ruilong Li(李瑞龙) committed
26

27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
# Command-line interface: where the data lives, which training split to load
# and which scene to train on.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--data_root",
    # default=str(pathlib.Path.cwd() / "data/360_v2"),
    default=str(pathlib.Path.cwd().joinpath("data/nerf_synthetic")),
    type=str,
    help="the root dir of the dataset",
)
parser.add_argument(
    "--train_split",
    choices=["train", "trainval"],
    default="train",
    type=str,
    help="which train split to use",
)
parser.add_argument(
    "--scene",
    # Any scene from either supported dataset family is accepted; the family
    # is resolved later from the scene name.
    choices=NERF_SYNTHETIC_SCENES + MIPNERF360_UNBOUNDED_SCENES,
    default="lego",
    type=str,
    help="which scene to use",
)
args = parser.parse_args()
Ruilong Li(李瑞龙)'s avatar
Ruilong Li(李瑞龙) committed
50

51
52
# Device string handed to the tensors, datasets and models created below.
device = "cuda:0"
# Fixed seed for the project helper (presumably seeds the RNGs used for
# sampling -- see examples.utils.set_random_seed to confirm).
set_random_seed(42)
Ruilong Li(李瑞龙)'s avatar
Ruilong Li(李瑞龙) committed
53

54
55
# Per-dataset configuration. The scene name decides which dataset loader is
# imported and which hyper-parameters the rest of the script uses. Both
# branches define the same set of names.
if args.scene in MIPNERF360_UNBOUNDED_SCENES:
    from datasets.nerf_360_v2 import SubjectLoader

    # training parameters
    max_steps = 20000
    # number of rays per training item at start (passed as `num_rays`)
    init_batch_size = 1024
    # target number of rendering samples per step (262144); the ray count is
    # rescaled towards it in the training loop
    target_sample_batch_size = 1 << 18
    weight_decay = 0.0
    # scene parameters
    aabb = torch.tensor([-1.0, -1.0, -1.0, 1.0, 1.0, 1.0], device=device)
    near_plane = 0.2
    # NOTE(review): far_plane is not referenced elsewhere in the visible script.
    far_plane = 1.0e10
    # dataset parameters (forwarded as keyword arguments to SubjectLoader)
    train_dataset_kwargs = {"color_bkgd_aug": "random", "factor": 4}
    test_dataset_kwargs = {"factor": 4}
    # model parameters (occupancy grid resolution and number of levels)
    grid_resolution = 128
    grid_nlvl = 4
    # render parameters
    render_step_size = 1e-3
    alpha_thre = 1e-2
    cone_angle = 0.004

else:
    from datasets.nerf_synthetic import SubjectLoader

    # training parameters
    max_steps = 20000
    # number of rays per training item at start (passed as `num_rays`)
    init_batch_size = 1024
    # target number of rendering samples per step (262144); the ray count is
    # rescaled towards it in the training loop
    target_sample_batch_size = 1 << 18
    # a larger weight decay is used for three specific scenes
    weight_decay = (
        1e-5 if args.scene in ["materials", "ficus", "drums"] else 1e-6
    )
    # scene parameters
    aabb = torch.tensor([-1.5, -1.5, -1.5, 1.5, 1.5, 1.5], device=device)
    near_plane = 0.0
    # NOTE(review): far_plane is not referenced elsewhere in the visible script.
    far_plane = 1.0e10
    # dataset parameters (forwarded as keyword arguments to SubjectLoader)
    train_dataset_kwargs = {}
    test_dataset_kwargs = {}
    # model parameters (occupancy grid resolution and number of levels)
    grid_resolution = 128
    grid_nlvl = 1
    # render parameters
    render_step_size = 5e-3
    alpha_thre = 0.0
    cone_angle = 0.0
Ruilong Li(李瑞龙)'s avatar
Ruilong Li(李瑞龙) committed
101

102
103
104
105
106
107
108
109
# Training loader. It starts out serving `init_batch_size` rays per item; the
# training loop later adjusts that count through `update_num_rays`.
train_dataset = SubjectLoader(
    root_fp=args.data_root,
    subject_id=args.scene,
    split=args.train_split,
    device=device,
    num_rays=init_batch_size,
    **train_dataset_kwargs,
)
Ruilong Li(李瑞龙)'s avatar
Ruilong Li(李瑞龙) committed
110

111
112
113
114
115
116
117
118
# Evaluation loader over the "test" split.
# num_rays=None: presumably makes every item a full image rather than a ray
# subset -- confirm against SubjectLoader.
test_dataset = SubjectLoader(
    root_fp=args.data_root,
    subject_id=args.scene,
    split="test",
    device=device,
    num_rays=None,
    **test_dataset_kwargs,
)
Ruilong Li(李瑞龙)'s avatar
Ruilong Li(李瑞龙) committed
119

120
121
122
# Occupancy-grid estimator built over the scene box `aabb`, with the
# per-dataset grid resolution and number of levels, then moved to `device`.
estimator = OccGridEstimator(
    roi_aabb=aabb,
    resolution=grid_resolution,
    levels=grid_nlvl,
)
estimator = estimator.to(device)
Ruilong Li(李瑞龙)'s avatar
Ruilong Li(李瑞龙) committed
123

124
125
# Gradient scaler created with 2**10 as its first positional argument; the
# training loop uses it only to scale the loss before `backward()`.
grad_scaler = torch.cuda.amp.GradScaler(2**10)
# Set up the radiance field we want to train. Its bounding box is the last
# entry of the estimator's `aabbs`.
radiance_field = NGPRadianceField(aabb=estimator.aabbs[-1]).to(device)
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
# Adam over every parameter of the radiance field, with the per-dataset
# weight decay chosen above.
optimizer = torch.optim.Adam(
    radiance_field.parameters(),
    lr=1e-2,
    eps=1e-15,
    weight_decay=weight_decay,
)
# Learning-rate schedule: two schedulers chained together and stepped once
# per training iteration.
scheduler = torch.optim.lr_scheduler.ChainedScheduler(
    [
        # linear ramp starting at 1% of the base rate, over 100 steps
        torch.optim.lr_scheduler.LinearLR(
            optimizer, total_iters=100, start_factor=0.01
        ),
        # multiply the rate by 0.33 at 1/2, 3/4 and 9/10 of `max_steps`
        torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            gamma=0.33,
            milestones=[
                max_steps * numer // denom
                for numer, denom in ((1, 2), (3, 4), (9, 10))
            ],
        ),
    ]
)
# LPIPS network (VGG variant) used as the perceptual metric in evaluation.
lpips_net = LPIPS(net="vgg").to(device)


def lpips_norm_fn(x):
    """Prepare an image tensor for ``lpips_net``.

    Adds a leading axis, moves the last axis to position 1 and maps values
    through ``x * 2 - 1``.

    NOTE(review): assumes ``x`` is laid out as (H, W, C) with values in
    [0, 1] -- confirm against the dataset / renderer output.
    """
    return x[None, ...].permute(0, 3, 1, 2) * 2 - 1


def lpips_fn(x, y):
    """Return the mean ``lpips_net`` output for ``x`` vs ``y``.

    Both inputs are passed through ``lpips_norm_fn`` first.
    """
    return lpips_net(lpips_norm_fn(x), lpips_norm_fn(y)).mean()
Ruilong Li(李瑞龙)'s avatar
Ruilong Li(李瑞龙) committed
149

150
151
152
153
def occ_eval_fn(x):
    """Occupancy value at points ``x`` for the occupancy-grid update.

    Returns the radiance field's queried density multiplied by the render
    step size. Defined once here instead of being re-created on every
    training iteration.
    """
    density = radiance_field.query_density(x)
    return density * render_step_size


# training
# Runs steps 0..max_steps inclusive. A progress line is printed every 10000
# steps and the test split is evaluated once, on the final step.
tic = time.time()
for step in range(max_steps + 1):
    radiance_field.train()
    estimator.train()

    # draw one random training item
    i = torch.randint(0, len(train_dataset), (1,)).item()
    data = train_dataset[i]

    render_bkgd = data["color_bkgd"]
    rays = data["rays"]
    pixels = data["pixels"]

    # update occupancy grid
    estimator.update_every_n_steps(
        step=step,
        occ_eval_fn=occ_eval_fn,
        occ_thre=1e-2,
    )

    # render
    rgb, acc, depth, n_rendering_samples = render_image_with_occgrid(
        radiance_field,
        estimator,
        rays,
        # rendering options
        near_plane=near_plane,
        render_step_size=render_step_size,
        render_bkgd=render_bkgd,
        cone_angle=cone_angle,
        alpha_thre=alpha_thre,
    )
    if n_rendering_samples == 0:
        # Nothing was sampled, so there is nothing to optimise.
        # NOTE: `continue` also skips the logging and evaluation below for
        # this step.
        continue

    if target_sample_batch_size > 0:
        # dynamic batch size for rays to keep sample batch size constant.
        # The current ray count is rescaled by (target samples / samples
        # actually rendered) and handed to the dataset for later items.
        num_rays = len(pixels)
        num_rays = int(
            num_rays * (target_sample_batch_size / float(n_rendering_samples))
        )
        train_dataset.update_num_rays(num_rays)

    # compute loss
    loss = F.smooth_l1_loss(rgb, pixels)

    optimizer.zero_grad()
    # do not unscale it because we are using Adam.
    # (The scaler only multiplies the loss here; its step()/update() are
    # never called in this loop.)
    grad_scaler.scale(loss).backward()
    optimizer.step()
    scheduler.step()

    if step % 10000 == 0:
        elapsed_time = time.time() - tic
        # Report an MSE-based PSNR on the current training batch. It is
        # computed without autograd tracking and kept in its own variable so
        # the training `loss` above is not overwritten. The printed label is
        # still "loss".
        with torch.no_grad():
            mse = F.mse_loss(rgb, pixels)
            psnr = -10.0 * torch.log(mse) / np.log(10.0)
        print(
            f"elapsed_time={elapsed_time:.2f}s | step={step} | "
            f"loss={mse:.5f} | psnr={psnr:.2f} | "
            f"n_rendering_samples={n_rendering_samples:d} | num_rays={len(pixels):d} | "
            f"max_depth={depth.max():.3f} | "
        )

    if step > 0 and step % max_steps == 0:
        # evaluation
        radiance_field.eval()
        estimator.eval()

        psnrs = []
        lpips_scores = []
        with torch.no_grad():
            for i in tqdm.tqdm(range(len(test_dataset))):
                data = test_dataset[i]
                render_bkgd = data["color_bkgd"]
                rays = data["rays"]
                pixels = data["pixels"]

                # rendering
                # NOTE(review): the leading 1024 is presumably a per-call
                # sample/chunk budget -- confirm in examples.utils.
                rgb, acc, depth, _ = render_image_with_occgrid_test(
                    1024,
                    # scene
                    radiance_field,
                    estimator,
                    rays,
                    # rendering options
                    near_plane=near_plane,
                    render_step_size=render_step_size,
                    render_bkgd=render_bkgd,
                    cone_angle=cone_angle,
                    alpha_thre=alpha_thre,
                )
                mse = F.mse_loss(rgb, pixels)
                psnr = -10.0 * torch.log(mse) / np.log(10.0)
                psnrs.append(psnr.item())
                lpips_scores.append(lpips_fn(rgb, pixels).item())
                # if i == 0:
                #     imageio.imwrite(
                #         "rgb_test.png",
                #         (rgb.cpu().numpy() * 255).astype(np.uint8),
                #     )
                #     imageio.imwrite(
                #         "rgb_error.png",
                #         (
                #             (rgb - pixels).norm(dim=-1).cpu().numpy() * 255
                #         ).astype(np.uint8),
                #     )
        # plain averages over all test items
        psnr_avg = sum(psnrs) / len(psnrs)
        lpips_avg = sum(lpips_scores) / len(lpips_scores)
        print(f"evaluation: psnr_avg={psnr_avg}, lpips_avg={lpips_avg}")