Commit 6093da8f authored by Ruilong Li

better with tips

parent b5a2af68
@@ -16,5 +16,12 @@ Tested with the default settings on the Lego test set.

| Model | Data split | PSNR | Train time | Test FPS | GPU |
| - | - | - | - | - | - |
| instant-ngp (paper) | trainval? | 36.39 | - | - | 3090 |
| torch-ngp (`-O`) | train (30K steps) | 34.15 | 310 sec | 7.8 fps | V100 |
| ours | train (30K steps) | 34.40 | 296 sec | 6.2 fps | TITAN RTX |
| ours | trainval (30K steps) | 35.42 | 291 sec | 6.4 fps | TITAN RTX |
## Tips

1. Sampling rays over all images per iteration (`batch_over_images=True`) is better: `PSNR 33.31 -> 33.75`.
2. Using a learning-rate scheduler (`MultiStepLR(optimizer, milestones=[20000, 30000], gamma=0.1)`) gives: `PSNR 33.75 -> 34.40`.
3. Increasing the chunk size (`chunk: 8192 -> 81920`) during inference gives a speedup: `FPS 4.x -> 6.2`. See the sketch after this list.
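
Tip 3 amounts to rendering more rays per forward pass at test time. Below is a minimal sketch of the chunking pattern, assuming a `render_fn` that maps a batch of rays to per-ray colors (`render_fn` and `render_image_in_chunks` are placeholder names for illustration, not this repo's API; the per-chunk slicing mirrors the `namedtuple_map` helper in the code below):

```python
import collections

import torch

Rays = collections.namedtuple("Rays", ("origins", "viewdirs"))


def namedtuple_map(fn, tup):
    """Apply `fn` to each element of `tup` and cast to `tup`'s namedtuple."""
    return type(tup)(*(None if x is None else fn(x) for x in tup))


@torch.no_grad()
def render_image_in_chunks(render_fn, rays: Rays, chunk: int = 81920):
    """Render a full image by slicing the flattened rays into `chunk`-sized
    pieces; a larger chunk means fewer forward passes and higher FPS, at the
    cost of more GPU memory per pass."""
    num_rays = rays.origins.shape[0]
    results = []
    for i in range(0, num_rays, chunk):
        chunk_rays = namedtuple_map(lambda x: x[i : i + chunk], rays)
        results.append(render_fn(chunk_rays))
    return torch.cat(results, dim=0)
```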
# Copyright (c) Meta Platforms, Inc. and affiliates.
import collections
import json
import os

import imageio.v2 as imageio
import numpy as np
import torch
import torch.nn.functional as F

Rays = collections.namedtuple("Rays", ("origins", "viewdirs"))


def namedtuple_map(fn, tup):
    """Apply `fn` to each element of `tup` and cast to `tup`'s namedtuple."""
    return type(tup)(*(None if x is None else fn(x) for x in tup))
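
# Example usage (illustrative): `namedtuple_map` applies one function to every
# field of the tuple, so a whole `Rays` batch can be moved or sliced at once:
#   rays = Rays(origins=torch.zeros(8, 3), viewdirs=torch.zeros(8, 3))
#   rays = namedtuple_map(lambda x: x.to("cuda"), rays)  # move both fields
#   half = namedtuple_map(lambda x: x[:4], rays)         # slice both fields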

def _load_renderings(root_fp: str, subject_id: str, split: str):
@@ -33,8 +39,8 @@ def _load_renderings(root_fp: str, subject_id: str, split: str):
        camtoworlds.append(frame["transform_matrix"])
        images.append(rgba)

    images = np.stack(images, axis=0)
    camtoworlds = np.stack(camtoworlds, axis=0)

    h, w = images.shape[1:3]
    camera_angle_x = float(meta["camera_angle_x"])
@@ -69,6 +75,7 @@ class SubjectLoader(torch.utils.data.Dataset):
        num_rays: int = None,
        near: float = None,
        far: float = None,
        batch_over_images: bool = True,
    ):
        super().__init__()
        assert split in self.SPLITS, "%s" % split
@@ -80,6 +87,7 @@ class SubjectLoader(torch.utils.data.Dataset):
        self.far = self.FAR if far is None else far
        self.training = (num_rays is not None) and (split in ["train", "trainval"])
        self.color_bkgd_aug = color_bkgd_aug
        self.batch_over_images = batch_over_images
        if split == "trainval":
            _images_train, _camtoworlds_train, _focal_train = _load_renderings(
                root_fp, subject_id, "train"
@@ -94,11 +102,22 @@ class SubjectLoader(torch.utils.data.Dataset):
            self.images, self.camtoworlds, self.focal = _load_renderings(
                root_fp, subject_id, split
            )
        self.images = torch.from_numpy(self.images).to(torch.uint8)
        self.camtoworlds = torch.from_numpy(self.camtoworlds).to(torch.float32)
        self.K = torch.tensor(
            [
                [self.focal, 0, self.WIDTH / 2.0],
                [0, self.focal, self.HEIGHT / 2.0],
                [0, 0, 1],
            ],
            dtype=torch.float32,
        )  # (3, 3)
        assert self.images.shape[1:3] == (self.HEIGHT, self.WIDTH)

    def __len__(self):
        return len(self.images)

    @torch.no_grad()
    def __getitem__(self, index):
        data = self.fetch_data(index)
        data = self.preprocess(data)
@@ -111,14 +130,14 @@ class SubjectLoader(torch.utils.data.Dataset):
        if self.training:
            if self.color_bkgd_aug == "random":
                color_bkgd = torch.rand(3, device=self.images.device)
            elif self.color_bkgd_aug == "white":
                color_bkgd = torch.ones(3, device=self.images.device)
            elif self.color_bkgd_aug == "black":
                color_bkgd = torch.zeros(3, device=self.images.device)
        else:
            # just use white during inference
            color_bkgd = torch.ones(3, device=self.images.device)

        pixels = pixels * alpha + color_bkgd * (1.0 - alpha)
        return {
@@ -130,48 +149,65 @@ class SubjectLoader(torch.utils.data.Dataset):
    def fetch_data(self, index):
        """Fetch the data (it may be cached for multiple batches)."""
        if self.training:
            if self.batch_over_images:
                image_id = torch.randint(
                    0,
                    len(self.images),
                    size=(self.num_rays,),
                    device=self.images.device,
                )
            else:
                image_id = [index]
            x = torch.randint(
                0, self.WIDTH, size=(self.num_rays,), device=self.images.device
            )
            y = torch.randint(
                0, self.HEIGHT, size=(self.num_rays,), device=self.images.device
            )
        else:
            image_id = [index]
            x, y = torch.meshgrid(
                torch.arange(self.WIDTH, device=self.images.device),
                torch.arange(self.HEIGHT, device=self.images.device),
                indexing="xy",
            )
            x = x.flatten()
            y = y.flatten()

        # generate rays
        rgba = self.images[image_id, y, x] / 255.0  # (num_rays, 4)
        c2w = self.camtoworlds[image_id]  # (num_rays, 4, 4)
        camera_dirs = F.pad(
            torch.stack(
                [
                    (x - self.K[0, 2] + 0.5) / self.K[0, 0],
                    (y - self.K[1, 2] + 0.5) / self.K[1, 1],
                ],
                dim=-1,
            ),
            (0, 1),
            value=1,
        )  # [num_rays, 3]
        # flip the y and z axes: this dataset's cameras use the OpenGL
        # convention, not OpenCV's.
        camera_dirs[..., [1, 2]] *= -1

        # rotate the camera-frame directions into world space; [num_rays, 3]
        directions = (camera_dirs[:, None, :] * c2w[:, :3, :3]).sum(dim=-1)
        origins = torch.broadcast_to(c2w[:, :3, -1], directions.shape)
        viewdirs = directions / torch.linalg.norm(directions, dim=-1, keepdim=True)

        if self.training:
            origins = torch.reshape(origins, (self.num_rays, 3))
            viewdirs = torch.reshape(viewdirs, (self.num_rays, 3))
            rgba = torch.reshape(rgba, (self.num_rays, 4))
        else:
            origins = torch.reshape(origins, (self.HEIGHT, self.WIDTH, 3))
            viewdirs = torch.reshape(viewdirs, (self.HEIGHT, self.WIDTH, 3))
            rgba = torch.reshape(rgba, (self.HEIGHT, self.WIDTH, 4))

        rays = Rays(origins=origins, viewdirs=viewdirs)

        return {
            "rgba": rgba,  # [h, w, 4] or [num_rays, 4]
            "rays": rays,  # [h, w, 3] or [num_rays, 3]
        }
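
# Example usage (illustrative; the dataset root path is hypothetical):
#
#   dataset = SubjectLoader(
#       subject_id="lego",
#       root_fp="/path/to/nerf_synthetic/",
#       split="train",
#       num_rays=8192,
#       batch_over_images=True,
#   )
#   data = dataset.fetch_data(0)
#   data["rays"].origins.shape    # torch.Size([8192, 3])
#   data["rgba"].shape            # torch.Size([8192, 4])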
# Copyright (c) Meta Platforms, Inc. and affiliates.
import collections

import torch
import torch.nn.functional as F

Rays = collections.namedtuple("Rays", ("origins", "viewdirs"))
Cameras = collections.namedtuple(
    "Cameras", ("intrins", "extrins", "distorts", "width", "height")
)


def namedtuple_map(fn, tup):
    """Apply `fn` to each element of `tup` and cast to `tup`'s namedtuple."""
    return type(tup)(*(None if x is None else fn(x) for x in tup))


def homo(points: torch.Tensor) -> torch.Tensor:
    """Get the homogeneous coordinates."""
    return F.pad(points, (0, 1), value=1)


def transform_cameras(cameras: Cameras, resize_factor: float) -> Cameras:
    intrins = cameras.intrins
    intrins[..., :2, :] = intrins[..., :2, :] * resize_factor
    width = int(cameras.width * resize_factor + 0.5)
    height = int(cameras.height * resize_factor + 0.5)
    return Cameras(
        intrins=intrins,
        extrins=cameras.extrins,
        distorts=cameras.distorts,
        width=width,
        height=height,
    )


def generate_rays(
    cameras: Cameras,
    opencv_format: bool = True,
    pixels_xy: torch.Tensor = None,
) -> Rays:
    """Generate rays for a single camera or multiple cameras.

    :params cameras: [(n_cams,)]
    :returns: Rays, with shape
        [(n_cams,) height, width] if pixels_xy is None,
        [(n_cams,) num_pixels] if pixels_xy is given.
    """
    if pixels_xy is not None:
        K = cameras.intrins[..., None, :, :]
        c2w = cameras.extrins[..., None, :, :].inverse()
        x, y = pixels_xy[..., 0], pixels_xy[..., 1]
    else:
        K = cameras.intrins[..., None, None, :, :]
        c2w = cameras.extrins[..., None, None, :, :].inverse()
        x, y = torch.meshgrid(
            torch.arange(cameras.width, dtype=K.dtype),
            torch.arange(cameras.height, dtype=K.dtype),
            indexing="xy",
        )  # [height, width]

    camera_dirs = homo(
        torch.stack(
            [
                (x - K[..., 0, 2] + 0.5) / K[..., 0, 0],
                (y - K[..., 1, 2] + 0.5) / K[..., 1, 1],
            ],
            dim=-1,
        )
    )  # [n_cams, height, width, 3]
    if not opencv_format:
        # flip y and z for OpenGL-convention cameras
        camera_dirs[..., [1, 2]] *= -1

    # [n_cams, height, width, 3]
    directions = (camera_dirs[..., None, :] * c2w[..., :3, :3]).sum(dim=-1)
    origins = torch.broadcast_to(c2w[..., :3, -1], directions.shape)
    viewdirs = directions / torch.linalg.norm(directions, dim=-1, keepdim=True)

    rays = Rays(
        origins=origins,  # [n_cams, height, width, 3]
        viewdirs=viewdirs,  # [n_cams, height, width, 3]
    )
    return rays
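
# Example usage (illustrative; assumes a 3x3 intrinsics tensor `K` and a 4x4
# world-to-camera tensor `w2c` for one 800x800 OpenGL-convention camera):
#
#   cameras = Cameras(intrins=K, extrins=w2c, distorts=None, width=800, height=800)
#   rays = generate_rays(cameras, opencv_format=False)  # full pixel grid
#   rays.origins.shape, rays.viewdirs.shape             # both [800, 800, 3]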
@@ -5,8 +5,7 @@ import numpy as np
import torch
import torch.nn.functional as F
import tqdm

from datasets.nerf_synthetic import SubjectLoader, namedtuple_map
from radiance_fields.ngp import NGPradianceField

from nerfacc import OccupancyField, volumetric_rendering
@@ -67,19 +66,26 @@ if __name__ == "__main__":
split="train", split="train",
num_rays=8192, num_rays=8192,
) )
# train_dataset.images = train_dataset.images.to(device)
# train_dataset.camtoworlds = train_dataset.camtoworlds.to(device)
# train_dataset.K = train_dataset.K.to(device)
train_dataloader = torch.utils.data.DataLoader( train_dataloader = torch.utils.data.DataLoader(
train_dataset, train_dataset,
num_workers=1, num_workers=4,
batch_size=None, batch_size=None,
persistent_workers=True, persistent_workers=True,
shuffle=True, shuffle=True,
) )
    test_dataset = SubjectLoader(
        subject_id="lego",
        root_fp="/home/ruilongli/data/nerf_synthetic/",
        split="test",
        num_rays=None,
    )
    # test_dataset.images = test_dataset.images.to(device)
    # test_dataset.camtoworlds = test_dataset.camtoworlds.to(device)
    # test_dataset.K = test_dataset.K.to(device)
    test_dataloader = torch.utils.data.DataLoader(
        test_dataset,
        num_workers=4,
@@ -102,9 +108,9 @@ if __name__ == "__main__":
    )
    optimizer = torch.optim.Adam(radiance_field.parameters(), lr=3e-3, eps=1e-15)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=[20000, 30000], gamma=0.1
    )
    # setup occupancy field with eval function
    def occ_eval_fn(x: torch.Tensor) -> torch.Tensor:
@@ -156,7 +162,7 @@ if __name__ == "__main__":
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        if step % 50 == 0:
            elapsed_time = time.time() - tic
@@ -189,6 +195,18 @@ if __name__ == "__main__":
    # "train"
    # elapsed_time=298.27s (data=60.08s) | step=30000 | loss=0.00026
    # evaluation: psnr_avg=33.305334663391115 (6.42 it/s)

    # "train" batch_over_images=True
    # elapsed_time=335.21s (data=68.99s) | step=30000 | loss=0.00028
    # evaluation: psnr_avg=33.74970862388611 (6.23 it/s)

    # "train" batch_over_images=True, schedule
    # elapsed_time=296.30s (data=54.38s) | step=30000 | loss=0.00022
    # evaluation: psnr_avg=34.3978275680542 (6.22 it/s)

    # "trainval"
    # elapsed_time=289.94s (data=51.99s) | step=30000 | loss=0.00021
    # evaluation: psnr_avg=34.44980221748352 (6.61 it/s)

    # "trainval" batch_over_images=True, schedule
    # elapsed_time=291.42s (data=52.82s) | step=30000 | loss=0.00020
    # evaluation: psnr_avg=35.41630497932434 (6.40 it/s)