Commit f66bacfa authored by Ruilong Li

vanilla nerf examples

parent 255352aa
......@@ -2,15 +2,13 @@
This is a **tiny** toolbox for **accelerating** NeRF training & rendering using PyTorch CUDA extensions. Plug-and-play for most NeRFs!
## Instant-NGP example
## Examples: Instant-NGP NeRF
``` bash
python examples/trainval.py ngp --train_split trainval
```
python examples/trainval.py
```
## Performance Reference
Ours on TITAN RTX :
Performance on TITAN RTX:
| trainval | Lego | Mic | Materials | Chair | Hotdog |
| - | - | - | - | - | - |
......@@ -25,6 +23,25 @@ Instant-NGP paper (5 min) on 3090 (w/ mask):
| PSNR | 36.39 | 36.22 | 29.78 | 35.00 | 37.40 |
## Examples: Vanilla MLP NeRF
``` bash
python examples/trainval.py vanilla --train_split train
```
Performance on the test set (parentheses denote the training split):
| | Lego |
| - | - |
| Paper PSNR (train set) | 32.54 |
| Our PSNR (train set) | 33.21 |
| Our PSNR (trainval set) | 33.66 |
| Our train time & test FPS | 45 min; 0.43 FPS |
For reference, the vanilla NeRF paper trains on a V100 GPU for 1-2 days per scene, and test-time rendering takes about 30 seconds per 800x800 image. Our model is trained on a TITAN X.
Note: We use a single MLP with more samples (1024) instead of two MLPs with coarse-to-fine sampling as in the paper. Both approaches share the same goal of sampling densely around the surface. Our fast rendering inherently skips samples far from the surface, so we can simply increase the number of samples for a single MLP and achieve the same effect as coarse-to-fine sampling, without runtime or memory issues.
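For intuition, the sample budget directly sets the ray-marching step size: 1024 steps are sized to cover the diagonal of the scene bounding box, and the occupancy field then skips the steps that land in empty space. A minimal sketch of that computation, mirroring `examples/trainval.py`:

``` python
import math
import torch

# Scene bounding box used in the example script: a cube from -1.5 to 1.5.
scene_aabb = torch.tensor([-1.5, -1.5, -1.5, 1.5, 1.5, 1.5])

# Step size such that 1024 samples span the aabb diagonal (the longest ray).
render_n_samples = 1024
render_step_size = (
    (scene_aabb[3:] - scene_aabb[:3]).max() * math.sqrt(3) / render_n_samples
).item()
print(render_step_size)  # ~0.0051
```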
<!--
Tested with the default settings on the Lego test set.
......
""" The MLPs and Voxels. """
import math
from typing import Callable, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
class MLP(nn.Module):
def __init__(
self,
input_dim: int, # The number of input tensor channels.
        output_dim: Optional[int] = None,  # The number of output tensor channels.
net_depth: int = 8, # The depth of the MLP.
net_width: int = 256, # The width of the MLP.
        skip_layer: int = 4,  # Re-concatenate the input after every skip_layer-th hidden layer.
hidden_init: Callable = nn.init.xavier_uniform_,
hidden_activation: Callable = nn.ReLU(),
output_enabled: bool = True,
output_init: Optional[Callable] = nn.init.xavier_uniform_,
output_activation: Optional[Callable] = nn.Identity(),
bias_enabled: bool = True,
bias_init: Callable = nn.init.zeros_,
):
super().__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.net_depth = net_depth
self.net_width = net_width
self.skip_layer = skip_layer
self.hidden_init = hidden_init
self.hidden_activation = hidden_activation
self.output_enabled = output_enabled
self.output_init = output_init
self.output_activation = output_activation
self.bias_enabled = bias_enabled
self.bias_init = bias_init
self.hidden_layers = nn.ModuleList()
in_features = self.input_dim
for i in range(self.net_depth):
self.hidden_layers.append(
nn.Linear(in_features, self.net_width, bias=bias_enabled)
)
if (self.skip_layer is not None) and (i % self.skip_layer == 0) and (i > 0):
in_features = self.net_width + self.input_dim
else:
in_features = self.net_width
if self.output_enabled:
self.output_layer = nn.Linear(
in_features, self.output_dim, bias=bias_enabled
)
else:
self.output_dim = in_features
self.initialize()
def initialize(self):
def init_func_hidden(m):
if isinstance(m, nn.Linear):
if self.hidden_init is not None:
self.hidden_init(m.weight)
if self.bias_enabled and self.bias_init is not None:
self.bias_init(m.bias)
self.hidden_layers.apply(init_func_hidden)
if self.output_enabled:
def init_func_output(m):
if isinstance(m, nn.Linear):
if self.output_init is not None:
self.output_init(m.weight)
if self.bias_enabled and self.bias_init is not None:
self.bias_init(m.bias)
self.output_layer.apply(init_func_output)
def forward(self, x):
inputs = x
for i in range(self.net_depth):
x = self.hidden_layers[i](x)
x = self.hidden_activation(x)
if (self.skip_layer is not None) and (i % self.skip_layer == 0) and (i > 0):
x = torch.cat([x, inputs], dim=-1)
if self.output_enabled:
x = self.output_layer(x)
x = self.output_activation(x)
return x
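# Shape check (hypothetical values): with input_dim=63, i.e. the latent dim of
# SinusoidalEncoder(3, 0, 10, True) below, and output_dim=4, the MLP maps
# [..., 63] -> [..., 4]:
#   mlp = MLP(input_dim=63, output_dim=4)
#   out = mlp(torch.randn(128, 63))  # out.shape == (128, 4)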
class DenseLayer(MLP):
def __init__(self, input_dim, output_dim, **kwargs):
super().__init__(
input_dim=input_dim,
output_dim=output_dim,
net_depth=0, # no hidden layers
**kwargs,
)
class NerfMLP(nn.Module):
def __init__(
self,
input_dim: int, # The number of input tensor channels.
condition_dim: int, # The number of condition tensor channels.
net_depth: int = 8, # The depth of the MLP.
net_width: int = 256, # The width of the MLP.
        skip_layer: int = 4,  # Re-concatenate the input after every skip_layer-th hidden layer.
net_depth_condition: int = 1, # The depth of the second part of MLP.
net_width_condition: int = 128, # The width of the second part of MLP.
):
super().__init__()
self.base = MLP(
input_dim=input_dim,
net_depth=net_depth,
net_width=net_width,
skip_layer=skip_layer,
output_enabled=False,
)
hidden_features = self.base.output_dim
self.sigma_layer = DenseLayer(hidden_features, 1)
if condition_dim > 0:
self.bottleneck_layer = DenseLayer(hidden_features, net_width)
self.rgb_layer = MLP(
input_dim=net_width + condition_dim,
output_dim=3,
net_depth=net_depth_condition,
net_width=net_width_condition,
skip_layer=None,
)
else:
self.rgb_layer = DenseLayer(hidden_features, 3)
def query_density(self, x):
x = self.base(x)
raw_sigma = self.sigma_layer(x)
return raw_sigma
def forward(self, x, condition=None):
x = self.base(x)
raw_sigma = self.sigma_layer(x)
if condition is not None:
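            # `condition` carries one embedded view direction per ray, while `x`
            # carries per-sample features; tile it across the sample dimension.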
if condition.shape[:-1] != x.shape[:-1]:
num_rays, n_dim = condition.shape
condition = condition.view(
[num_rays] + [1] * (x.dim() - condition.dim()) + [n_dim]
).expand(list(x.shape[:-1]) + [n_dim])
bottleneck = self.bottleneck_layer(x)
x = torch.cat([bottleneck, condition], dim=-1)
raw_rgb = self.rgb_layer(x)
return raw_rgb, raw_sigma
class SinusoidalEncoder(nn.Module):
"""Sinusoidal Positional Encoder used in NeRF."""
def __init__(self, x_dim, min_deg, max_deg, use_identity: bool = True):
super().__init__()
self.x_dim = x_dim
self.min_deg = min_deg
self.max_deg = max_deg
self.use_identity = use_identity
self.register_buffer(
"scales", torch.tensor([2**i for i in range(min_deg, max_deg)])
)
@property
def latent_dim(self) -> int:
return (int(self.use_identity) + (self.max_deg - self.min_deg) * 2) * self.x_dim
    def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x: [..., x_dim]
Returns:
latent: [..., latent_dim]
"""
xb = torch.reshape(
(x[Ellipsis, None, :] * self.scales[:, None]),
list(x.shape[:-1]) + [(self.max_deg - self.min_deg) * self.x_dim],
)
latent = torch.sin(torch.cat([xb, xb + 0.5 * math.pi], dim=-1))
if self.use_identity:
            latent = torch.cat([x, latent], dim=-1)
return latent
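# Worked example (hypothetical): SinusoidalEncoder(x_dim=3, min_deg=0,
# max_deg=10, use_identity=True) maps [..., 3] -> [..., 63], since
# (1 + (10 - 0) * 2) * 3 = 63.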
class VanillaNeRFRadianceField(nn.Module):
def __init__(
self,
net_depth: int = 8, # The depth of the MLP.
net_width: int = 256, # The width of the MLP.
        skip_layer: int = 4,  # Re-concatenate the input after every skip_layer-th hidden layer.
net_depth_condition: int = 1, # The depth of the second part of MLP.
net_width_condition: int = 128, # The width of the second part of MLP.
) -> None:
super().__init__()
self.posi_encoder = SinusoidalEncoder(3, 0, 10, True)
self.view_encoder = SinusoidalEncoder(3, 0, 4, True)
self.mlp = NerfMLP(
input_dim=self.posi_encoder.latent_dim,
condition_dim=self.view_encoder.latent_dim,
net_depth=net_depth,
net_width=net_width,
skip_layer=skip_layer,
net_depth_condition=net_depth_condition,
net_width_condition=net_width_condition,
)
def query_density(self, x):
x = self.posi_encoder(x)
sigma = self.mlp.query_density(x)
return F.relu(sigma)
def forward(self, x, condition=None):
x = self.posi_encoder(x)
if condition is not None:
condition = self.view_encoder(condition)
rgb, sigma = self.mlp(x, condition=condition)
return torch.sigmoid(rgb), F.relu(sigma)
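# Putting it together (hypothetical shapes): query 256 rays x 1024 samples
# with one view direction per ray:
#   rf = VanillaNeRFRadianceField()
#   x = torch.rand(256, 1024, 3)        # sample positions
#   dirs = torch.rand(256, 3)           # per-ray view directions
#   rgb, sigma = rf(x, condition=dirs)  # [256, 1024, 3], [256, 1024, 1]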
import argparse
import math
import time
......@@ -6,11 +7,12 @@ import torch
import torch.nn.functional as F
import tqdm
from datasets.nerf_synthetic import SubjectLoader, namedtuple_map
from radiance_fields.mlp import VanillaNeRFRadianceField
from radiance_fields.ngp import NGPradianceField
from nerfacc import OccupancyField, volumetric_rendering
TARGET_SAMPLE_BATCH_SIZE = 1 << 18
TARGET_SAMPLE_BATCH_SIZE = 1 << 16
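# Target number of point samples per optimization step; the ray batch size is
# derived from this below (2**16 samples / 1024 samples per ray = 64 rays).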
def render_image(radiance_field, rays, render_bkgd, render_step_size):
......@@ -78,15 +80,40 @@ def render_image(radiance_field, rays, render_bkgd, render_step_size):
if __name__ == "__main__":
torch.manual_seed(42)
parser = argparse.ArgumentParser()
parser.add_argument(
"method",
type=str,
default="ngp",
choices=["ngp", "vanilla"],
help="which nerf to use",
)
parser.add_argument(
"--train_split",
type=str,
default="trainval",
choices=["train", "trainval"],
help="which train split to use",
)
args = parser.parse_args()
device = "cuda:0"
scene = "lego"
# setup the scene bounding box.
scene_aabb = torch.tensor([-1.5, -1.5, -1.5, 1.5, 1.5, 1.5])
# setup some rendering settings
render_n_samples = 1024
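    # Choose the step size so that render_n_samples steps cover the diagonal
    # of the aabb, i.e. the longest possible ray through the scene box.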
render_step_size = (
(scene_aabb[3:] - scene_aabb[:3]).max() * math.sqrt(3) / render_n_samples
).item()
# setup dataset
train_dataset = SubjectLoader(
subject_id=scene,
root_fp="/home/ruilongli/data/nerf_synthetic/",
split="trainval",
num_rays=1024,
split=args.train_split,
num_rays=TARGET_SAMPLE_BATCH_SIZE // render_n_samples,
# color_bkgd_aug="random",
)
......@@ -116,30 +143,28 @@ if __name__ == "__main__":
batch_size=None,
)
# setup the scene bounding box.
scene_aabb = torch.tensor([-1.5, -1.5, -1.5, 1.5, 1.5, 1.5])
# setup the scene radiance field. Assume you have a NeRF model and
    # it has the following functions:
# - query_density(): {x} -> {density}
# - forward(): {x, dirs} -> {rgb, density}
radiance_field = NGPradianceField(aabb=scene_aabb).to(device)
if args.method == "ngp":
radiance_field = NGPradianceField(aabb=scene_aabb).to(device)
optimizer = torch.optim.Adam(radiance_field.parameters(), lr=1e-2, eps=1e-15)
max_steps = 20000
occ_field_warmup_steps = 2000
grad_scaler = torch.cuda.amp.GradScaler(1)
elif args.method == "vanilla":
radiance_field = VanillaNeRFRadianceField().to(device)
optimizer = torch.optim.Adam(radiance_field.parameters(), lr=5e-4)
max_steps = 40000
occ_field_warmup_steps = 256
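        # Scaling up the gradients for Adam (a larger init scale than the NGP branch).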
grad_scaler = torch.cuda.amp.GradScaler(2**10)
# setup some rendering settings
render_n_samples = 1024
render_step_size = (
(scene_aabb[3:] - scene_aabb[:3]).max() * math.sqrt(3) / render_n_samples
).item()
optimizer = torch.optim.Adam(
radiance_field.parameters(),
lr=1e-2,
# betas=(0.9, 0.99),
eps=1e-15,
# weight_decay=1e-6,
)
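    # Decay the learning rate by roughly 3x at 50%, 75%, and 90% of training.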
scheduler = torch.optim.lr_scheduler.MultiStepLR(
optimizer, milestones=[10000, 15000, 18000], gamma=0.33
optimizer,
milestones=[max_steps // 2, max_steps * 3 // 4, max_steps * 9 // 10],
gamma=0.33,
)
# setup occupancy field with eval function
......@@ -167,8 +192,6 @@ if __name__ == "__main__":
data_time = 0
tic_data = time.time()
# Scaling up the gradients for Adam
grad_scaler = torch.cuda.amp.GradScaler(2**10)
for epoch in range(10000000):
for i in range(len(train_dataset)):
radiance_field.train()
......@@ -183,7 +206,7 @@ if __name__ == "__main__":
pixels = data["pixels"]
# update occupancy grid
occ_field.every_n_step(step)
occ_field.every_n_step(step, warmup_steps=occ_field_warmup_steps)
rgb, acc, counter, compact_counter = render_image(
radiance_field, rays, render_bkgd, render_step_size
......@@ -215,7 +238,7 @@ if __name__ == "__main__":
)
# if time.time() - tic > 300:
if step >= 20_000 and step % 5000 == 0 and step > 0:
if step >= max_steps and step % max_steps == 0 and step > 0:
# evaluation
radiance_field.eval()
......@@ -240,23 +263,23 @@ if __name__ == "__main__":
# ((acc > 0).float().cpu().numpy() * 255).astype(np.uint8),
# )
psnrs = []
train_dataset.training = False
with torch.no_grad():
for data in tqdm.tqdm(train_dataloader):
# generate rays from data and the gt pixel color
rays = namedtuple_map(lambda x: x.to(device), data["rays"])
pixels = data["pixels"].to(device)
render_bkgd = data["color_bkgd"].to(device)
# rendering
rgb, acc, _, _ = render_image(
radiance_field, rays, render_bkgd, render_step_size
)
mse = F.mse_loss(rgb, pixels)
psnr = -10.0 * torch.log(mse) / np.log(10.0)
psnrs.append(psnr.item())
psnr_avg = sum(psnrs) / len(psnrs)
print(f"evaluation on train: {psnr_avg=}")
# psnrs = []
# train_dataset.training = False
# with torch.no_grad():
# for data in tqdm.tqdm(train_dataloader):
# # generate rays from data and the gt pixel color
# rays = namedtuple_map(lambda x: x.to(device), data["rays"])
# pixels = data["pixels"].to(device)
# render_bkgd = data["color_bkgd"].to(device)
# # rendering
# rgb, acc, _, _ = render_image(
# radiance_field, rays, render_bkgd, render_step_size
# )
# mse = F.mse_loss(rgb, pixels)
# psnr = -10.0 * torch.log(mse) / np.log(10.0)
# psnrs.append(psnr.item())
# psnr_avg = sum(psnrs) / len(psnrs)
# print(f"evaluation on train: {psnr_avg=}")
# imageio.imwrite(
# "acc_binary_train.png",
# ((acc > 0).float().cpu().numpy() * 255).astype(np.uint8),
......@@ -267,7 +290,7 @@ if __name__ == "__main__":
# )
train_dataset.training = True
if step == 20_000:
if step == max_steps:
print("training stops")
exit()
tic_data = time.time()
......