Commit 6093da8f authored by Ruilong Li

better with tips

parent b5a2af68
@@ -16,5 +16,12 @@ Tested with the default settings on the Lego test set.
 | - | - | - | - | - | - |
 | instant-ngp (paper) | trainval? | 36.39 | - | - | 3090 |
 | torch-ngp (`-O`) | train (30K steps) | 34.15 | 310 sec | 7.8 fps | V100 |
-| ours | train (30K steps) | 33.31 | 298 sec | 6.4 fps | TITAN RTX |
-| ours | trainval (30K steps) | 34.45 | 290 sec | 6.6 fps | TITAN RTX |
\ No newline at end of file
+| ours | train (30K steps) | 34.40 | 296 sec | 6.2 fps | TITAN RTX |
+| ours | trainval (30K steps) | 35.42 | 291 sec | 6.4 fps | TITAN RTX |
+
+## Tips:
+1. Sampling rays over all images per iteration (`batch_over_images=True`) is better: `PSNR 33.31 -> 33.75`.
+2. Using a scheduler to adjust the learning rate (`MultiStepLR(optimizer, milestones=[20000, 30000], gamma=0.1)`) gives: `PSNR 33.75 -> 34.40`.
+3. Increasing the chunk size during inference (`chunk: 8192 -> 81920`) gives a speedup: `FPS 4.x -> 6.2`.
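The three tips map directly onto the code changes further down in this commit. As a quick reference, here is a minimal sketch of how they fit together; `SubjectLoader`, `batch_over_images`, and the scheduler settings come from this commit, while the stand-in radiance field, the placeholder data path, and the loop wiring are illustrative only:

```python
import torch

from datasets.nerf_synthetic import SubjectLoader

train_dataset = SubjectLoader(
    subject_id="lego",
    root_fp="/path/to/nerf_synthetic/",  # placeholder path
    split="trainval",
    num_rays=8192,
    batch_over_images=True,  # tip 1: rays come from all images each iteration
)
radiance_field = torch.nn.Linear(3, 4)  # stand-in for NGPradianceField
optimizer = torch.optim.Adam(radiance_field.parameters(), lr=3e-3, eps=1e-15)
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer, milestones=[20000, 30000], gamma=0.1  # tip 2: 10x decay twice
)
# In the training loop, call scheduler.step() once after each optimizer.step().
# Tip 3 applies at eval time only: render ~81920 rays per chunk instead of 8192.
```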
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 import collections
 import json
 import os

 import imageio.v2 as imageio
 import numpy as np
 import torch
+import torch.nn.functional as F

-from .utils import Cameras, generate_rays
+Rays = collections.namedtuple("Rays", ("origins", "viewdirs"))
+
+
+def namedtuple_map(fn, tup):
+    """Apply `fn` to each element of `tup` and cast to `tup`'s namedtuple."""
+    return type(tup)(*(None if x is None else fn(x) for x in tup))


 def _load_renderings(root_fp: str, subject_id: str, split: str):
@@ -33,8 +39,8 @@ def _load_renderings(root_fp: str, subject_id: str, split: str):
         camtoworlds.append(frame["transform_matrix"])
         images.append(rgba)

-    images = np.stack(images, axis=0).astype(np.float32)
-    camtoworlds = np.stack(camtoworlds, axis=0).astype(np.float32)
+    images = np.stack(images, axis=0)
+    camtoworlds = np.stack(camtoworlds, axis=0)

     h, w = images.shape[1:3]
     camera_angle_x = float(meta["camera_angle_x"])
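Dropping the `.astype(np.float32)` here pairs with the uint8 conversion added in `__init__` below: frames stay uint8 until a batch is actually fetched and normalized by `/ 255.0`. A quick arithmetic check of why that matters (frame count and resolution match the Lego train split; the sketch itself is illustrative):

```python
import numpy as np

# 100 RGBA frames at 800x800: uint8 storage is 4x smaller than float32.
frames = np.zeros((100, 800, 800, 4), dtype=np.uint8)
print(frames.nbytes / 2**20)                     # 244.140625 MiB
print(frames.astype(np.float32).nbytes / 2**20)  # 976.5625 MiB
```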
@@ -69,6 +75,7 @@ class SubjectLoader(torch.utils.data.Dataset):
         num_rays: int = None,
         near: float = None,
         far: float = None,
+        batch_over_images: bool = True,
     ):
         super().__init__()
         assert split in self.SPLITS, "%s" % split
@@ -80,6 +87,7 @@
         self.far = self.FAR if far is None else far
         self.training = (num_rays is not None) and (split in ["train", "trainval"])
         self.color_bkgd_aug = color_bkgd_aug
+        self.batch_over_images = batch_over_images
         if split == "trainval":
             _images_train, _camtoworlds_train, _focal_train = _load_renderings(
                 root_fp, subject_id, "train"
@@ -94,11 +102,22 @@
             self.images, self.camtoworlds, self.focal = _load_renderings(
                 root_fp, subject_id, split
             )
+        self.images = torch.from_numpy(self.images).to(torch.uint8)
+        self.camtoworlds = torch.from_numpy(self.camtoworlds).to(torch.float32)
+        self.K = torch.tensor(
+            [
+                [self.focal, 0, self.WIDTH / 2.0],
+                [0, self.focal, self.HEIGHT / 2.0],
+                [0, 0, 1],
+            ],
+            dtype=torch.float32,
+        )  # (3, 3)
+        assert self.images.shape[1:3] == (self.HEIGHT, self.WIDTH)

     def __len__(self):
         return len(self.images)

+    @torch.no_grad()
     def __getitem__(self, index):
         data = self.fetch_data(index)
         data = self.preprocess(data)
@@ -111,14 +130,14 @@
         if self.training:
             if self.color_bkgd_aug == "random":
-                color_bkgd = torch.rand(3)
+                color_bkgd = torch.rand(3, device=self.images.device)
             elif self.color_bkgd_aug == "white":
-                color_bkgd = torch.ones(3)
+                color_bkgd = torch.ones(3, device=self.images.device)
             elif self.color_bkgd_aug == "black":
-                color_bkgd = torch.zeros(3)
+                color_bkgd = torch.zeros(3, device=self.images.device)
         else:
             # just use white during inference
-            color_bkgd = torch.ones(3)
+            color_bkgd = torch.ones(3, device=self.images.device)

         pixels = pixels * alpha + color_bkgd * (1.0 - alpha)
         return {
@@ -130,48 +149,65 @@
     def fetch_data(self, index):
         """Fetch the data (it may be cached for multiple batches)."""
-        # load data
-        camera_id = index
-        K = np.array(
-            [
-                [self.focal, 0, self.WIDTH / 2.0],
-                [0, self.focal, self.HEIGHT / 2.0],
-                [0, 0, 1],
-            ]
-        ).astype(np.float32)
-        w2c = np.linalg.inv(self.camtoworlds[camera_id])
-        rgba = self.images[camera_id]
-
-        # create pixels
-        rgba = torch.from_numpy(rgba).float() / 255.0
-
-        # create rays from camera
-        cameras = Cameras(
-            intrins=torch.from_numpy(K).float(),
-            extrins=torch.from_numpy(w2c).float(),
-            distorts=None,
-            width=self.WIDTH,
-            height=self.HEIGHT,
-        )
-        if self.num_rays is not None:
-            x = torch.randint(0, self.WIDTH, size=(self.num_rays,))
-            y = torch.randint(0, self.HEIGHT, size=(self.num_rays,))
-            pixels_xy = torch.stack([x, y], dim=-1)
-            rgba = rgba[y, x, :]
-        else:
-            pixels_xy = None  # full image
-        # Be careful: This dataset's camera coordinate is not the same as
-        # opencv's camera coordinate! It is actually opengl.
-        rays = generate_rays(
-            cameras,
-            opencv_format=False,
-            pixels_xy=pixels_xy,
-        )
+        if self.training:
+            if self.batch_over_images:
+                image_id = torch.randint(
+                    0,
+                    len(self.images),
+                    size=(self.num_rays,),
+                    device=self.images.device,
+                )
+            else:
+                image_id = [index]
+            x = torch.randint(
+                0, self.WIDTH, size=(self.num_rays,), device=self.images.device
+            )
+            y = torch.randint(
+                0, self.HEIGHT, size=(self.num_rays,), device=self.images.device
+            )
+        else:
+            image_id = [index]
+            x, y = torch.meshgrid(
+                torch.arange(self.WIDTH, device=self.images.device),
+                torch.arange(self.HEIGHT, device=self.images.device),
+                indexing="xy",
+            )
+            x = x.flatten()
+            y = y.flatten()
+
+        # generate rays
+        rgba = self.images[image_id, y, x] / 255.0  # (num_rays, 4)
+        c2w = self.camtoworlds[image_id]  # (num_rays, 4, 4)
+        camera_dirs = F.pad(
+            torch.stack(
+                [
+                    (x - self.K[0, 2] + 0.5) / self.K[0, 0],
+                    (y - self.K[1, 2] + 0.5) / self.K[1, 1],
+                ],
+                dim=-1,
+            ),
+            (0, 1),
+            value=1,
+        )  # [num_rays, 3]
+        camera_dirs[..., [1, 2]] *= -1  # opengl format
+        # rotate camera dirs into world space; [num_rays, 3]
+        directions = (camera_dirs[:, None, :] * c2w[:, :3, :3]).sum(dim=-1)
+        origins = torch.broadcast_to(c2w[:, :3, -1], directions.shape)
+        viewdirs = directions / torch.linalg.norm(directions, dim=-1, keepdims=True)
+
+        if self.training:
+            origins = torch.reshape(origins, (self.num_rays, 3))
+            viewdirs = torch.reshape(viewdirs, (self.num_rays, 3))
+            rgba = torch.reshape(rgba, (self.num_rays, 4))
+        else:
+            origins = torch.reshape(origins, (self.HEIGHT, self.WIDTH, 3))
+            viewdirs = torch.reshape(viewdirs, (self.HEIGHT, self.WIDTH, 3))
+            rgba = torch.reshape(rgba, (self.HEIGHT, self.WIDTH, 4))
+
+        rays = Rays(origins=origins, viewdirs=viewdirs)

         return {
-            "camera_id": camera_id,
             "rgba": rgba,  # [h, w, 4] or [num_rays, 4]
-            "rays": rays,  # [h, w] or [num_rays, 4]
+            "rays": rays,  # [h, w, 3] or [num_rays, 3]
         }
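With `datasets/utils.py` gone (the all-removed file below), `nerf_synthetic.py` now generates rays inline. For readers who want to check the convention, here is a standalone sketch of the same math on toy inputs (the shapes and identity pose are made up for illustration): the `+0.5` shifts samples to pixel centers, and flipping y and z converts OpenCV-style directions to this dataset's OpenGL camera convention.

```python
import torch
import torch.nn.functional as F

H, W, focal = 4, 6, 10.0
K = torch.tensor([[focal, 0.0, W / 2], [0.0, focal, H / 2], [0.0, 0.0, 1.0]])
c2w = torch.eye(4)[None]  # a single identity camera pose, (1, 4, 4)

# pixel grid, flattened to match the inference path above
x, y = torch.meshgrid(torch.arange(W), torch.arange(H), indexing="xy")
x, y = x.flatten(), y.flatten()

camera_dirs = F.pad(
    torch.stack(
        [(x - K[0, 2] + 0.5) / K[0, 0], (y - K[1, 2] + 0.5) / K[1, 1]], dim=-1
    ),
    (0, 1),
    value=1,
)  # (H*W, 3), z fixed to 1
camera_dirs[..., [1, 2]] *= -1  # OpenGL: y up, camera looks down -z
directions = (camera_dirs[:, None, :] * c2w[:, :3, :3]).sum(dim=-1)  # (H*W, 3)
origins = torch.broadcast_to(c2w[:, :3, -1], directions.shape)
viewdirs = directions / torch.linalg.norm(directions, dim=-1, keepdim=True)
```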
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-import collections
-import math
-
-import torch
-import torch.nn.functional as F
-
-Rays = collections.namedtuple("Rays", ("origins", "viewdirs"))
-Cameras = collections.namedtuple(
-    "Cameras", ("intrins", "extrins", "distorts", "width", "height")
-)
-
-
-def namedtuple_map(fn, tup):
-    """Apply `fn` to each element of `tup` and cast to `tup`'s namedtuple."""
-    return type(tup)(*(None if x is None else fn(x) for x in tup))
-
-
-def homo(points: torch.Tensor) -> torch.Tensor:
-    """Get the homogeneous coordinates."""
-    return F.pad(points, (0, 1), value=1)
-
-
-def transform_cameras(cameras: Cameras, resize_factor: float) -> torch.Tensor:
-    intrins = cameras.intrins
-    intrins[..., :2, :] = intrins[..., :2, :] * resize_factor
-    width = int(cameras.width * resize_factor + 0.5)
-    height = int(cameras.height * resize_factor + 0.5)
-    return Cameras(
-        intrins=intrins,
-        extrins=cameras.extrins,
-        distorts=cameras.distorts,
-        width=width,
-        height=height,
-    )
-
-
-def generate_rays(
-    cameras: Cameras,
-    opencv_format: bool = True,
-    pixels_xy: torch.Tensor = None,
-) -> Rays:
-    """Generating rays for a single or multiple cameras.
-
-    :params cameras: [(n_cams,)]
-    :returns: Rays
-        [(n_cams,) height, width] if pixels_xy is None
-        [(n_cams,) num_pixels] if pixels_xy is given
-    """
-    if pixels_xy is not None:
-        K = cameras.intrins[..., None, :, :]
-        c2w = cameras.extrins[..., None, :, :].inverse()
-        x, y = pixels_xy[..., 0], pixels_xy[..., 1]
-    else:
-        K = cameras.intrins[..., None, None, :, :]
-        c2w = cameras.extrins[..., None, None, :, :].inverse()
-        x, y = torch.meshgrid(
-            torch.arange(cameras.width, dtype=K.dtype),
-            torch.arange(cameras.height, dtype=K.dtype),
-            indexing="xy",
-        )  # [height, width]
-    camera_dirs = homo(
-        torch.stack(
-            [
-                (x - K[..., 0, 2] + 0.5) / K[..., 0, 0],
-                (y - K[..., 1, 2] + 0.5) / K[..., 1, 1],
-            ],
-            dim=-1,
-        )
-    )  # [n_cams, height, width, 3]
-    if not opencv_format:
-        camera_dirs[..., [1, 2]] *= -1
-    # [n_cams, height, width, 3]
-    directions = (camera_dirs[..., None, :] * c2w[..., :3, :3]).sum(dim=-1)
-    origins = torch.broadcast_to(c2w[..., :3, -1], directions.shape)
-    viewdirs = directions / torch.linalg.norm(directions, dim=-1, keepdims=True)
-    rays = Rays(
-        origins=origins,  # [n_cams, height, width, 3]
-        viewdirs=viewdirs,  # [n_cams, height, width, 3]
-    )
-    return rays
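The all-minus block above is `datasets/utils.py`, which this commit deletes wholesale; only `Rays` and `namedtuple_map` survive by moving into `nerf_synthetic.py`. Since `namedtuple_map` remains the glue for ray bundles, here is a quick usage sketch (toy tensors, illustrative only):

```python
import collections
import torch

Rays = collections.namedtuple("Rays", ("origins", "viewdirs"))

def namedtuple_map(fn, tup):
    """Apply `fn` to each element of `tup` and cast to `tup`'s namedtuple."""
    return type(tup)(*(None if x is None else fn(x) for x in tup))

rays = Rays(origins=torch.zeros(2, 3, 3), viewdirs=torch.ones(2, 3, 3))
flat = namedtuple_map(lambda t: t.reshape(-1, 3), rays)  # flatten both fields
# Moving a whole bundle between devices works the same way:
# rays_gpu = namedtuple_map(lambda t: t.to("cuda"), rays)
```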
@@ -5,8 +5,7 @@ import numpy as np
 import torch
 import torch.nn.functional as F
 import tqdm

-from datasets.nerf_synthetic import SubjectLoader
-from datasets.utils import namedtuple_map
+from datasets.nerf_synthetic import SubjectLoader, namedtuple_map
 from radiance_fields.ngp import NGPradianceField

 from nerfacc import OccupancyField, volumetric_rendering
@@ -67,19 +66,26 @@ if __name__ == "__main__":
         split="train",
         num_rays=8192,
     )
+    # train_dataset.images = train_dataset.images.to(device)
+    # train_dataset.camtoworlds = train_dataset.camtoworlds.to(device)
+    # train_dataset.K = train_dataset.K.to(device)
     train_dataloader = torch.utils.data.DataLoader(
         train_dataset,
-        num_workers=1,
+        num_workers=4,
         batch_size=None,
         persistent_workers=True,
         shuffle=True,
     )
     test_dataset = SubjectLoader(
         subject_id="lego",
         root_fp="/home/ruilongli/data/nerf_synthetic/",
         split="test",
         num_rays=None,
     )
+    # test_dataset.images = test_dataset.images.to(device)
+    # test_dataset.camtoworlds = test_dataset.camtoworlds.to(device)
+    # test_dataset.K = test_dataset.K.to(device)
     test_dataloader = torch.utils.data.DataLoader(
         test_dataset,
         num_workers=4,
@@ -102,9 +108,9 @@
     )

     optimizer = torch.optim.Adam(radiance_field.parameters(), lr=3e-3, eps=1e-15)
-    # scheduler = torch.optim.lr_scheduler.MultiStepLR(
-    #     optimizer, milestones=[10000, 20000, 30000], gamma=0.1
-    # )
+    scheduler = torch.optim.lr_scheduler.MultiStepLR(
+        optimizer, milestones=[20000, 30000], gamma=0.1
+    )

    # setup occupancy field with eval function
    def occ_eval_fn(x: torch.Tensor) -> torch.Tensor:
@@ -156,7 +162,7 @@
             optimizer.zero_grad()
             loss.backward()
             optimizer.step()
-            # scheduler.step()
+            scheduler.step()

             if step % 50 == 0:
                 elapsed_time = time.time() - tic
@@ -189,6 +195,18 @@
    # elapsed_time=298.27s (data=60.08s) | step=30000 | loss=0.00026
    # evaluation: psnr_avg=33.305334663391115 (6.42 it/s)

+    # "train" batch_over_images=True
+    # elapsed_time=335.21s (data=68.99s) | step=30000 | loss=0.00028
+    # evaluation: psnr_avg=33.74970862388611 (6.23 it/s)
+
+    # "train" batch_over_images=True, schedule
+    # elapsed_time=296.30s (data=54.38s) | step=30000 | loss=0.00022
+    # evaluation: psnr_avg=34.3978275680542 (6.22 it/s)
+
    # "trainval"
    # elapsed_time=289.94s (data=51.99s) | step=30000 | loss=0.00021
    # evaluation: psnr_avg=34.44980221748352 (6.61 it/s)
+
+    # "trainval" batch_over_images=True, schedule
+    # elapsed_time=291.42s (data=52.82s) | step=30000 | loss=0.00020
+    # evaluation: psnr_avg=35.41630497932434 (6.40 it/s)
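To read the `psnr_avg` numbers against the logged `loss`: if the loss is a plain MSE over pixels in [0, 1], then PSNR is just `-10 * log10(MSE)`. A one-line sanity check (note the logged loss is a training-batch MSE, so it will not match the held-out test PSNR exactly):

```python
import math

def psnr_from_mse(mse: float) -> float:
    # PSNR in dB for pixel values in [0, 1], where the peak signal is 1.
    return -10.0 * math.log10(mse)

print(psnr_from_mse(2.0e-4))  # ~37.0 dB for the final trainval batch loss
```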