"""
Copyright (c) 2022 Ruilong Li, UC Berkeley.
"""

import argparse
import math
import pathlib
import time

import imageio
import numpy as np
import torch
import torch.nn.functional as F
import tqdm
from datasets.dnerf_synthetic import SubjectLoader
from lpips import LPIPS
from radiance_fields.mlp import TNeRFRadianceField

from examples.utils import render_image_with_occgrid, set_random_seed
from nerfacc.estimators.occ_grid import OccGridEstimator

# Single-GPU script; fixed seed so runs are reproducible.
device = "cuda:0"
set_random_seed(42)
# Command-line interface. Defaults reproduce the paper's lego setup.
_DNERF_SCENES = [
    # dnerf
    "bouncingballs",
    "hellwarrior",
    "hook",
    "jumpingjacks",
    "lego",
    "mutant",
    "standup",
    "trex",
]

parser = argparse.ArgumentParser()
parser.add_argument("--data_root", type=str, default=str(pathlib.Path.cwd() / "data/dnerf"), help="the root dir of the dataset")
parser.add_argument("--train_split", type=str, default="train", choices=["train"], help="which train split to use")
parser.add_argument("--scene", type=str, default="lego", choices=_DNERF_SCENES, help="which scene to use")
parser.add_argument("--test_chunk_size", type=int, default=4096)
args = parser.parse_args()
# training parameters
max_steps = 30000  # total optimization steps; evaluation runs at the last one
init_batch_size = 1024  # initial ray batch size, resized dynamically during training
target_sample_batch_size = 1 << 16  # keep ~65k rendered samples per step
# scene parameters
aabb = torch.tensor([-1.5, -1.5, -1.5, 1.5, 1.5, 1.5], device=device)
near_plane = 0.0
far_plane = 1.0e10  # NOTE(review): defined but never passed to the renderer below
# model parameters
grid_resolution = 128
grid_nlvl = 1  # single occupancy-grid level (bounded synthetic scene)
# render parameters
render_step_size = 5e-3

# setup the train/test datasets
train_dataset = SubjectLoader(
    subject_id=args.scene,
    root_fp=args.data_root,
    split=args.train_split,
    num_rays=init_batch_size,
    device=device,
)
test_dataset = SubjectLoader(
    subject_id=args.scene,
    root_fp=args.data_root,
    split="test",
    num_rays=None,  # test split renders full images
    device=device,
)

# occupancy grid used to skip empty space during ray marching
estimator = OccGridEstimator(
    roi_aabb=aabb, resolution=grid_resolution, levels=grid_nlvl
).to(device)

# setup the radiance field we want to train.
radiance_field = TNeRFRadianceField().to(device)
optimizer = torch.optim.Adam(radiance_field.parameters(), lr=5e-4)
# step-decay schedule: lr shrinks by 3x at 50%, 75%, ~83% and 90% of training
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer,
    milestones=[
        max_steps // 2,
        max_steps * 3 // 4,
        max_steps * 5 // 6,
        max_steps * 9 // 10,
    ],
    gamma=0.33,
)

# LPIPS (VGG backbone) for perceptual evaluation.
lpips_net = LPIPS(net="vgg").to(device)


def lpips_norm_fn(x):
    # HWC image in [0, 1] -> NCHW tensor in [-1, 1], the range LPIPS expects.
    return x[None, ...].permute(0, 3, 1, 2) * 2 - 1


def lpips_fn(x, y):
    # Mean LPIPS distance between two images.
    return lpips_net(lpips_norm_fn(x), lpips_norm_fn(y)).mean()

# training
tic = time.time()
for step in range(max_steps + 1):
    radiance_field.train()
    estimator.train()

    # draw one random training view
    view_idx = torch.randint(0, len(train_dataset), (1,)).item()
    data = train_dataset[view_idx]

    render_bkgd = data["color_bkgd"]
    rays = data["rays"]
    pixels = data["pixels"]
    timestamps = data["timestamps"]

    # update occupancy grid (opacity queried at this batch's timestamps)
    estimator.update_every_n_steps(
        step=step,
        occ_eval_fn=lambda x: radiance_field.query_opacity(
            x, timestamps, render_step_size
        ),
        occ_thre=1e-2,
    )

    # render
    rgb, acc, depth, n_rendering_samples = render_image_with_occgrid(
        radiance_field,
        estimator,
        rays,
        # rendering options
        near_plane=near_plane,
        render_step_size=render_step_size,
        render_bkgd=render_bkgd,
        # only prune low-alpha samples once training has warmed up
        alpha_thre=0.01 if step > 1000 else 0.00,
        # t-nerf options
        timestamps=timestamps,
    )
    if n_rendering_samples == 0:
        # grid pruned everything along these rays; skip the step
        continue

    if target_sample_batch_size > 0:
        # dynamic batch size for rays to keep sample batch size constant.
        scale = target_sample_batch_size / float(n_rendering_samples)
        train_dataset.update_num_rays(int(len(pixels) * scale))

    # compute loss
    loss = F.smooth_l1_loss(rgb, pixels)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()

    if step % 5000 == 0:
        elapsed_time = time.time() - tic
        # report MSE-based PSNR: -10 * log10(mse), via natural-log conversion
        mse = F.mse_loss(rgb, pixels)
        psnr = -10.0 * torch.log(mse) / np.log(10.0)
        print(
            f"elapsed_time={elapsed_time:.2f}s | step={step} | "
            f"loss={mse:.5f} | psnr={psnr:.2f} | "
            f"n_rendering_samples={n_rendering_samples:d} | num_rays={len(pixels):d} | "
            f"max_depth={depth.max():.3f} | "
        )

    if step > 0 and step % max_steps == 0:
        # evaluation on the full test split (runs once, at the final step)
        radiance_field.eval()
        estimator.eval()

        psnrs = []
        lpips = []
        with torch.no_grad():
            for test_idx in tqdm.tqdm(range(len(test_dataset))):
                data = test_dataset[test_idx]
                render_bkgd = data["color_bkgd"]
                rays = data["rays"]
                pixels = data["pixels"]
                timestamps = data["timestamps"]

                # rendering
                rgb, acc, depth, _ = render_image_with_occgrid(
                    radiance_field,
                    estimator,
                    rays,
                    # rendering options
                    near_plane=near_plane,
                    render_step_size=render_step_size,
                    render_bkgd=render_bkgd,
                    alpha_thre=0.01,
                    # test options
                    test_chunk_size=args.test_chunk_size,
                    # t-nerf options
                    timestamps=timestamps,
                )
                mse = F.mse_loss(rgb, pixels)
                psnr = -10.0 * torch.log(mse) / np.log(10.0)
                psnrs.append(psnr.item())
                lpips.append(lpips_fn(rgb, pixels).item())
        psnr_avg = sum(psnrs) / len(psnrs)
        lpips_avg = sum(lpips) / len(lpips)
        print(f"evaluation: psnr_avg={psnr_avg}, lpips_avg={lpips_avg}")