Unverified Commit ef812085 authored by Ruilong Li (李瑞龙), committed by GitHub

support dnerf. cleanup api. expose rendering pipeline (#22)

* dyn

* clean up for dyn

* update test

* ngp is up and running but slow

* benchmark on three examples

* fix and update ngp in readme

* clean up and expose volumetric_rendering_pipeline

* update doc and fix import

* fix import
parent 08761ab8
......@@ -118,3 +118,4 @@ venv.bak/
.vscode
benchmarks/
outputs/
\ No newline at end of file
......@@ -12,9 +12,9 @@ Performance on TITAN RTX :
| trainval | Lego | Mic | Materials | Chair | Hotdog |
| - | - | - | - | - | - |
| Time | 300s | 272s | 258s | 331s | 287s |
| PSNR | 36.61 | 37.45 | 30.15 | 36.06 | 38.17 |
| FPS | 11.49 | 21.48 | 8.86 | 15.61 | 7.38 |
| Time | 300s | 274s | 266s | 341s | 277s |
| PSNR | 36.61 | 37.62 | 30.11 | 36.09 | 38.09 |
| FPS | 12.87 | 23.67 | 9.33 | 16.91 | 7.48 |
Instant-NGP paper (5 min) on 3090 (w/ mask):
......@@ -44,17 +44,23 @@ Note: We only use a single MLP with more samples (1024), instead of two MLPs wit
*FPS for some scenes is measured under `--test_chunk_size=8192` (default is `81920`) to avoid OOM.
<!--
Tested with the default settings on the Lego test set.
## Example: MLP NeRF on dynamic objects
Here we train a model similar to D-NeRF on the D-NeRF synthetic dataset:
``` bash
python examples/trainval.py dnerf --train_split train --test_chunk_size=8192
```
Performance on the test set:
| | Lego | Stand Up |
| - | - | - |
| DNeRF paper PSNR (train set) | 21.64 | 32.79 |
| Our PSNR (train set) | 24.66 | 33.98 |
| Our train time & test FPS | 43min; 0.15FPS | 41min; 0.4FPS |
| Model | Split | PSNR | Train Time | Test Speed | GPU | Train Memory |
| - | - | - | - | - | - | - |
| instant-ngp (paper) | trainval? | 36.39 | - | - | 3090 | - |
| instant-ngp (code) | train (35k steps) | 36.08 | 308 sec | 55.32 fps | TITAN RTX | 1734MB |
| instant-ngp (code) w/o rng bkgd | train (35k steps) | 34.17 | - | - | - | - |
| torch-ngp (`-O`) | train (30K steps) | 34.15 | 310 sec | 7.8 fps | V100 | - |
| ours | trainval (35K steps) | 36.22 | 378 sec | 12.08 fps | TITAN RTX | - | -->
## Tips:
......
......@@ -3,6 +3,8 @@ Volumetric Rendering
.. currentmodule:: nerfacc
.. autofunction:: volumetric_rendering_pipeline
.. autofunction:: volumetric_rendering_steps
.. autofunction:: volumetric_rendering_weights
.. autofunction:: volumetric_rendering_accumulate
.. autofunction:: unpack_to_ray_indices
\ No newline at end of file
import collections
import json
import os
import imageio.v2 as imageio
import numpy as np
import torch
import torch.nn.functional as F
Rays = collections.namedtuple("Rays", ("origins", "viewdirs"))
def namedtuple_map(fn, tup):
"""Apply `fn` to each element of `tup` and cast to `tup`'s namedtuple."""
return type(tup)(*(None if x is None else fn(x) for x in tup))
def _load_renderings(root_fp: str, subject_id: str, split: str):
"""Load images from disk."""
if not root_fp.startswith("/"):
# allow relative path. e.g., "./data/dnerf_synthetic/"
root_fp = os.path.join(
os.path.dirname(os.path.abspath(__file__)),
"..",
"..",
root_fp,
)
data_dir = os.path.join(root_fp, subject_id)
with open(os.path.join(data_dir, "transforms_{}.json".format(split)), "r") as fp:
meta = json.load(fp)
images = []
camtoworlds = []
timestamps = []
for i in range(len(meta["frames"])):
frame = meta["frames"][i]
fname = os.path.join(data_dir, frame["file_path"] + ".png")
rgba = imageio.imread(fname)
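# D-NeRF stores a normalized "time" per frame; if it is missing, fall back
# to uniformly spaced timestamps in [0, 1]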
timestamp = (
frame["time"] if "time" in frame else float(i) / (len(meta["frames"]) - 1)
)
timestamps.append(timestamp)
camtoworlds.append(frame["transform_matrix"])
images.append(rgba)
images = np.stack(images, axis=0)
camtoworlds = np.stack(camtoworlds, axis=0)
timestamps = np.stack(timestamps, axis=0)
h, w = images.shape[1:3]
camera_angle_x = float(meta["camera_angle_x"])
focal = 0.5 * w / np.tan(0.5 * camera_angle_x)
return images, camtoworlds, focal, timestamps
class SubjectLoader(torch.utils.data.Dataset):
"""Single subject data loader for training and evaluation."""
SPLITS = ["train", "val", "test"]
SUBJECT_IDS = [
"bouncingballs",
"hellwarrior",
"hook",
"jumpingjacks",
"lego",
"mutant",
"standup",
"trex",
]
WIDTH, HEIGHT = 800, 800
NEAR, FAR = 2.0, 6.0
OPENGL_CAMERA = True
def __init__(
self,
subject_id: str,
root_fp: str,
split: str,
color_bkgd_aug: str = "white",
num_rays: int = None,
near: float = None,
far: float = None,
batch_over_images: bool = True,
):
super().__init__()
assert split in self.SPLITS, "%s" % split
assert subject_id in self.SUBJECT_IDS, "%s" % subject_id
assert color_bkgd_aug in ["white", "black", "random"]
self.split = split
self.num_rays = num_rays
self.near = self.NEAR if near is None else near
self.far = self.FAR if far is None else far
self.training = (num_rays is not None) and (split in ["train", "trainval"])
self.color_bkgd_aug = color_bkgd_aug
self.batch_over_images = batch_over_images
self.images, self.camtoworlds, self.focal, self.timestamps = _load_renderings(
root_fp, subject_id, split
)
self.images = torch.from_numpy(self.images).to(torch.uint8)
self.camtoworlds = torch.from_numpy(self.camtoworlds).to(torch.float32)
self.timestamps = torch.from_numpy(self.timestamps).to(torch.float32)[:, None]
self.K = torch.tensor(
[
[self.focal, 0, self.WIDTH / 2.0],
[0, self.focal, self.HEIGHT / 2.0],
[0, 0, 1],
],
dtype=torch.float32,
) # (3, 3)
assert self.images.shape[1:3] == (self.HEIGHT, self.WIDTH)
def __len__(self):
return len(self.images)
@torch.no_grad()
def __getitem__(self, index):
data = self.fetch_data(index)
data = self.preprocess(data)
return data
def preprocess(self, data):
"""Process the fetched / cached data with randomness."""
rgba, rays = data["rgba"], data["rays"]
pixels, alpha = torch.split(rgba, [3, 1], dim=-1)
if self.training:
if self.color_bkgd_aug == "random":
color_bkgd = torch.rand(3, device=self.images.device)
elif self.color_bkgd_aug == "white":
color_bkgd = torch.ones(3, device=self.images.device)
elif self.color_bkgd_aug == "black":
color_bkgd = torch.zeros(3, device=self.images.device)
else:
# just use white during inference
color_bkgd = torch.ones(3, device=self.images.device)
pixels = pixels * alpha + color_bkgd * (1.0 - alpha)
return {
"pixels": pixels, # [n_rays, 3] or [h, w, 3]
"rays": rays, # [n_rays,] or [h, w]
"color_bkgd": color_bkgd, # [3,]
**{k: v for k, v in data.items() if k not in ["rgba", "rays"]},
}
def update_num_rays(self, num_rays):
self.num_rays = num_rays
def fetch_data(self, index):
"""Fetch the data (it may be cached for multiple batches)."""
num_rays = self.num_rays
if self.training:
if self.batch_over_images:
image_id = torch.randint(
0,
len(self.images),
size=(num_rays,),
device=self.images.device,
)
else:
image_id = [index]
x = torch.randint(
0, self.WIDTH, size=(num_rays,), device=self.images.device
)
y = torch.randint(
0, self.HEIGHT, size=(num_rays,), device=self.images.device
)
else:
image_id = [index]
x, y = torch.meshgrid(
torch.arange(self.WIDTH, device=self.images.device),
torch.arange(self.HEIGHT, device=self.images.device),
indexing="xy",
)
x = x.flatten()
y = y.flatten()
# generate rays
rgba = self.images[image_id, y, x] / 255.0 # (num_rays, 4)
c2w = self.camtoworlds[image_id] # (num_rays, 3, 4)
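# unproject pixel centers (+0.5) through the intrinsics K; y and z are
# negated for OpenGL-style cameras, which look down the -z axis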
camera_dirs = F.pad(
torch.stack(
[
(x - self.K[0, 2] + 0.5) / self.K[0, 0],
(y - self.K[1, 2] + 0.5)
/ self.K[1, 1]
* (-1.0 if self.OPENGL_CAMERA else 1.0),
],
dim=-1,
),
(0, 1),
value=(-1.0 if self.OPENGL_CAMERA else 1.0),
) # [num_rays, 3]
# rotate camera-frame directions into world frame: (num_rays, 3)
directions = (camera_dirs[:, None, :] * c2w[:, :3, :3]).sum(dim=-1)
origins = torch.broadcast_to(c2w[:, :3, -1], directions.shape)
viewdirs = directions / torch.linalg.norm(directions, dim=-1, keepdims=True)
if self.training:
origins = torch.reshape(origins, (num_rays, 3))
viewdirs = torch.reshape(viewdirs, (num_rays, 3))
rgba = torch.reshape(rgba, (num_rays, 4))
else:
origins = torch.reshape(origins, (self.HEIGHT, self.WIDTH, 3))
viewdirs = torch.reshape(viewdirs, (self.HEIGHT, self.WIDTH, 3))
rgba = torch.reshape(rgba, (self.HEIGHT, self.WIDTH, 4))
rays = Rays(origins=origins, viewdirs=viewdirs)
timestamps = self.timestamps[image_id]
return {
"rgba": rgba, # [h, w, 4] or [num_rays, 4]
"rays": rays, # [h, w, 3] or [num_rays, 3]
"timestamps": timestamps, # [num_rays, 1]
}
""" The MLPs and Voxels. """
import functools
import math
from typing import Callable, Dict, Optional
from typing import Callable, Optional
import torch
import torch.nn as nn
......@@ -170,13 +171,15 @@ class SinusoidalEncoder(nn.Module):
def latent_dim(self) -> int:
return (int(self.use_identity) + (self.max_deg - self.min_deg) * 2) * self.x_dim
def forward(self, x: torch.Tensor) -> Dict:
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x: [..., x_dim]
Returns:
latent: [..., latent_dim]
"""
# a degenerate encoder (max_deg == min_deg) has no frequency bands,
# so it reduces to the identity mapping on x
if self.max_deg == self.min_deg:
return x
xb = torch.reshape(
(x[Ellipsis, None, :] * self.scales[:, None]),
list(x.shape[:-1]) + [(self.max_deg - self.min_deg) * self.x_dim],
......@@ -220,3 +223,31 @@ class VanillaNeRFRadianceField(nn.Module):
condition = self.view_encoder(condition)
rgb, sigma = self.mlp(x, condition=condition)
return torch.sigmoid(rgb), F.relu(sigma)
class DNeRFRadianceField(nn.Module):
def __init__(self) -> None:
super().__init__()
self.posi_encoder = SinusoidalEncoder(3, 0, 0, True)
self.time_encoder = SinusoidalEncoder(1, 0, 0, True)
self.warp = MLP(
input_dim=self.posi_encoder.latent_dim + self.time_encoder.latent_dim,
output_dim=3,
net_depth=4,
net_width=64,
skip_layer=2,
output_init=functools.partial(torch.nn.init.uniform_, b=1e-4),
)
self.nerf = VanillaNeRFRadianceField()
def query_density(self, x, t):
x = x + self.warp(
torch.cat([self.posi_encoder(x), self.time_encoder(t)], dim=-1)
)
return self.nerf.query_density(x)
def forward(self, x, t, condition=None):
x = x + self.warp(
torch.cat([self.posi_encoder(x), self.time_encoder(t)], dim=-1)
)
return self.nerf(x, condition=condition)
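For reference, a minimal sketch of querying this time-conditioned field; the shapes and the `viewdirs` tensor below are illustrative, not part of the file above:

```python
import torch

from radiance_fields.mlp import DNeRFRadianceField

field = DNeRFRadianceField()
x = torch.rand(1024, 3)          # sample positions
t = torch.rand(1024, 1)          # per-sample timestamps in [0, 1]
viewdirs = torch.randn(1024, 3)  # view directions (illustrative)

sigmas = field.query_density(x, t)              # densities at time t, (1024, 1)
rgbs, sigmas = field(x, t, condition=viewdirs)  # colors (1024, 3) + densities
```

The warp MLP predicts a per-sample offset that bends each point into the canonical frame before the vanilla NeRF is queried, which is why both entry points take `t`.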
import argparse
import math
import os
import random
import time
import imageio
import numpy as np
import torch
import torch.nn.functional as F
import tqdm
from datasets.nerf_synthetic import SubjectLoader, namedtuple_map
from radiance_fields.mlp import VanillaNeRFRadianceField
from radiance_fields.ngp import NGPradianceField
from nerfacc import OccupancyField, volumetric_rendering
from nerfacc import OccupancyField, volumetric_rendering_pipeline
TARGET_SAMPLE_BATCH_SIZE = 1 << 16
device = "cuda:0"
def _set_random_seed(seed):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
def render_image(
radiance_field, rays, render_bkgd, render_step_size, test_chunk_size=81920
radiance_field,
rays,
timestamps,
render_bkgd,
render_step_size,
test_chunk_size=81920,
):
"""Render the pixels of an image.
......@@ -37,25 +48,45 @@ def render_image(
else:
num_rays, _ = rays_shape
def sigma_fn(frustum_origins, frustum_dirs, frustum_starts, frustum_ends):
def sigma_fn(frustum_starts, frustum_ends, ray_indices):
ray_indices = ray_indices.long()
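# gather the per-ray origin/direction for each sample; this replaces the
# per-sample frustum_origins/frustum_dirs tensors the old API passed in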
frustum_origins = chunk_rays.origins[ray_indices]
frustum_dirs = chunk_rays.viewdirs[ray_indices]
positions = (
frustum_origins + frustum_dirs * (frustum_starts + frustum_ends) / 2.0
)
if timestamps is None:
return radiance_field.query_density(positions)
else:
if radiance_field.training:
t = timestamps[ray_indices]
else:
t = timestamps.expand_as(positions[:, :1])
return radiance_field.query_density(positions, t)
def sigma_rgb_fn(frustum_origins, frustum_dirs, frustum_starts, frustum_ends):
def rgb_sigma_fn(frustum_starts, frustum_ends, ray_indices):
ray_indices = ray_indices.long()
frustum_origins = chunk_rays.origins[ray_indices]
frustum_dirs = chunk_rays.viewdirs[ray_indices]
positions = (
frustum_origins + frustum_dirs * (frustum_starts + frustum_ends) / 2.0
)
if timestamps is None:
return radiance_field(positions, frustum_dirs)
else:
if radiance_field.training:
t = timestamps[ray_indices]
else:
t = timestamps.expand_as(positions[:, :1])
return radiance_field(positions, t, frustum_dirs)
results = []
chunk = torch.iinfo(torch.int32).max if radiance_field.training else test_chunk_size
for i in range(0, num_rays, chunk):
chunk_rays = namedtuple_map(lambda r: r[i : i + chunk], rays)
chunk_results = volumetric_rendering(
chunk_results = volumetric_rendering_pipeline(
sigma_fn=sigma_fn,
sigma_rgb_fn=sigma_rgb_fn,
rgb_sigma_fn=rgb_sigma_fn,
rays_o=chunk_rays.origins,
rays_d=chunk_rays.viewdirs,
scene_aabb=occ_field.aabb,
......@@ -80,14 +111,14 @@ def render_image(
if __name__ == "__main__":
torch.manual_seed(42)
_set_random_seed(42)
parser = argparse.ArgumentParser()
parser.add_argument(
"method",
type=str,
default="ngp",
choices=["ngp", "vanilla"],
choices=["ngp", "vanilla", "dnerf"],
help="which nerf to use",
)
parser.add_argument(
......@@ -102,6 +133,7 @@ if __name__ == "__main__":
type=str,
default="lego",
choices=[
# nerf synthetic
"chair",
"drums",
"ficus",
......@@ -110,9 +142,23 @@ if __name__ == "__main__":
"materials",
"mic",
"ship",
# dnerf
"bouncingballs",
"hellwarrior",
"hook",
"jumpingjacks",
"lego",
"mutant",
"standup",
"trex",
],
help="which scene to use",
)
parser.add_argument(
"--aabb",
type=float,
nargs=6,
default=[-1.5, -1.5, -1.5, 1.5, 1.5, 1.5],
)
parser.add_argument(
"--test_chunk_size",
type=int,
......@@ -120,11 +166,46 @@ if __name__ == "__main__":
)
args = parser.parse_args()
device = "cuda:0"
if args.method == "ngp":
from datasets.nerf_synthetic import SubjectLoader, namedtuple_map
from radiance_fields.ngp import NGPradianceField
radiance_field = NGPradianceField(aabb=args.aabb).to(device)
optimizer = torch.optim.Adam(radiance_field.parameters(), lr=1e-2, eps=1e-15)
max_steps = 20000
occ_field_warmup_steps = 256
grad_scaler = torch.cuda.amp.GradScaler(2**10)
data_root_fp = "/home/ruilongli/data/nerf_synthetic/"
target_sample_batch_size = 1 << 18
elif args.method == "vanilla":
from datasets.nerf_synthetic import SubjectLoader, namedtuple_map
from radiance_fields.mlp import VanillaNeRFRadianceField
radiance_field = VanillaNeRFRadianceField().to(device)
optimizer = torch.optim.Adam(radiance_field.parameters(), lr=5e-4)
max_steps = 40000
occ_field_warmup_steps = 2000
grad_scaler = torch.cuda.amp.GradScaler(1)
data_root_fp = "/home/ruilongli/data/nerf_synthetic/"
target_sample_batch_size = 1 << 16
elif args.method == "dnerf":
from datasets.dnerf_synthetic import SubjectLoader, namedtuple_map
from radiance_fields.mlp import DNeRFRadianceField
radiance_field = DNeRFRadianceField().to(device)
optimizer = torch.optim.Adam(radiance_field.parameters(), lr=5e-4)
max_steps = 40000
occ_field_warmup_steps = 2000
grad_scaler = torch.cuda.amp.GradScaler(1)
data_root_fp = "/home/ruilongli/data/dnerf/"
target_sample_batch_size = 1 << 16
scene = args.scene
# setup the scene bounding box.
scene_aabb = torch.tensor([-1.5, -1.5, -1.5, 1.5, 1.5, 1.5])
scene_aabb = torch.tensor(args.aabb)
# setup some rendering settings
render_n_samples = 1024
render_step_size = (
......@@ -134,55 +215,29 @@ if __name__ == "__main__":
# setup dataset
train_dataset = SubjectLoader(
subject_id=scene,
root_fp="/home/ruilongli/data/nerf_synthetic/",
root_fp=data_root_fp,
split=args.train_split,
num_rays=TARGET_SAMPLE_BATCH_SIZE // render_n_samples,
num_rays=target_sample_batch_size // render_n_samples,
# color_bkgd_aug="random",
)
train_dataset.images = train_dataset.images.to(device)
train_dataset.camtoworlds = train_dataset.camtoworlds.to(device)
train_dataset.K = train_dataset.K.to(device)
train_dataloader = torch.utils.data.DataLoader(
train_dataset,
num_workers=0,
batch_size=None,
# persistent_workers=True,
shuffle=True,
)
if hasattr(train_dataset, "timestamps"):
train_dataset.timestamps = train_dataset.timestamps.to(device)
test_dataset = SubjectLoader(
subject_id=scene,
root_fp="/home/ruilongli/data/nerf_synthetic/",
root_fp=data_root_fp,
split="test",
num_rays=None,
)
test_dataset.images = test_dataset.images.to(device)
test_dataset.camtoworlds = test_dataset.camtoworlds.to(device)
test_dataset.K = test_dataset.K.to(device)
test_dataloader = torch.utils.data.DataLoader(
test_dataset,
num_workers=0,
batch_size=None,
)
# setup the scene radiance field. Assume you have a NeRF model and
# it has the following functions:
# - query_density(): {x} -> {density}
# - forward(): {x, dirs} -> {rgb, density}
if args.method == "ngp":
radiance_field = NGPradianceField(aabb=scene_aabb).to(device)
optimizer = torch.optim.Adam(radiance_field.parameters(), lr=1e-2, eps=1e-15)
max_steps = 20000
occ_field_warmup_steps = 2000
grad_scaler = torch.cuda.amp.GradScaler(1)
elif args.method == "vanilla":
radiance_field = VanillaNeRFRadianceField().to(device)
optimizer = torch.optim.Adam(radiance_field.parameters(), lr=5e-4)
max_steps = 40000
occ_field_warmup_steps = 256
grad_scaler = torch.cuda.amp.GradScaler(2**10)
if hasattr(train_dataset, "timestamps"):
test_dataset.timestamps = test_dataset.timestamps.to(device)
scheduler = torch.optim.lr_scheduler.MultiStepLR(
optimizer,
......@@ -199,6 +254,13 @@ if __name__ == "__main__":
Returns:
occupancy values with shape (N, 1).
"""
if args.method == "dnerf":
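# the occupancy grid is shared across time, so evaluate densities at
# randomly sampled timestamps to cover the whole motion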
idxs = torch.randint(
0, len(train_dataset.timestamps), (x.shape[0],), device=x.device
)
t = train_dataset.timestamps[idxs]
density_after_activation = radiance_field.query_density(x, t)
else:
density_after_activation = radiance_field.query_density(x)
# these two are nearly identical when density is small, because
# 1 - exp(-sigma * dt) ~= sigma * dt for small sigma * dt.
# occupancy = 1.0 - torch.exp(-density_after_activation * render_step_size)
......@@ -221,22 +283,20 @@ if __name__ == "__main__":
data = train_dataset[i]
data_time += time.time() - tic_data
# generate rays from data and the gt pixel color
# rays = namedtuple_map(lambda x: x.to(device), data["rays"])
# pixels = data["pixels"].to(device)
render_bkgd = data["color_bkgd"]
rays = data["rays"]
pixels = data["pixels"]
timestamps = data.get("timestamps", None)
# update occupancy grid
occ_field.every_n_step(step, warmup_steps=occ_field_warmup_steps)
rgb, acc, counter, compact_counter = render_image(
radiance_field, rays, render_bkgd, render_step_size
radiance_field, rays, timestamps, render_bkgd, render_step_size
)
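# dynamically rescale the ray batch so that roughly target_sample_batch_size
# samples survive compaction on the next step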
num_rays = len(pixels)
num_rays = int(
num_rays * (TARGET_SAMPLE_BATCH_SIZE / float(compact_counter))
num_rays * (target_sample_batch_size / float(compact_counter))
)
train_dataset.update_num_rays(num_rays)
alive_ray_mask = acc.squeeze(-1) > 0
......@@ -261,21 +321,24 @@ if __name__ == "__main__":
)
# if time.time() - tic > 300:
if step >= max_steps and step % max_steps == 0 and step > 0:
if step >= 0 and step % max_steps == 0 and step > 0:
# evaluation
radiance_field.eval()
psnrs = []
with torch.no_grad():
for data in tqdm.tqdm(test_dataloader):
# generate rays from data and the gt pixel color
rays = namedtuple_map(lambda x: x.to(device), data["rays"])
pixels = data["pixels"].to(device)
render_bkgd = data["color_bkgd"].to(device)
for i in tqdm.tqdm(range(len(test_dataset))):
data = test_dataset[i]
render_bkgd = data["color_bkgd"]
rays = data["rays"]
pixels = data["pixels"]
timestamps = data.get("timestamps", None)
# rendering
rgb, acc, _, _ = render_image(
radiance_field,
rays,
timestamps,
render_bkgd,
render_step_size,
test_chunk_size=args.test_chunk_size,
......@@ -283,38 +346,28 @@ if __name__ == "__main__":
mse = F.mse_loss(rgb, pixels)
psnr = -10.0 * torch.log(mse) / np.log(10.0)
psnrs.append(psnr.item())
psnr_avg = sum(psnrs) / len(psnrs)
print(f"evaluation: {psnr_avg=}")
# if step == max_steps:
# output_dir = os.path.join("./outputs/nerfacc/", scene)
# os.makedirs(output_dir, exist_ok=True)
# save = torch.cat([pixels, rgb], dim=1)
# imageio.imwrite(
# "acc_binary_test.png",
# ((acc > 0).float().cpu().numpy() * 255).astype(np.uint8),
# )
# psnrs = []
# train_dataset.training = False
# with torch.no_grad():
# for data in tqdm.tqdm(train_dataloader):
# # generate rays from data and the gt pixel color
# rays = namedtuple_map(lambda x: x.to(device), data["rays"])
# pixels = data["pixels"].to(device)
# render_bkgd = data["color_bkgd"].to(device)
# # rendering
# rgb, acc, _, _ = render_image(
# radiance_field, rays, render_bkgd, render_step_size
# os.path.join(output_dir, "%05d.png" % i),
# (save.cpu().numpy() * 255).astype(np.uint8),
# )
# mse = F.mse_loss(rgb, pixels)
# psnr = -10.0 * torch.log(mse) / np.log(10.0)
# psnrs.append(psnr.item())
# psnr_avg = sum(psnrs) / len(psnrs)
# print(f"evaluation on train: {psnr_avg=}")
# else:
# imageio.imwrite(
# "acc_binary_train.png",
# ((acc > 0).float().cpu().numpy() * 255).astype(np.uint8),
# "acc_binary_test.png",
# ((acc > 0).float().cpu().numpy() * 255).astype(
# np.uint8
# ),
# )
# imageio.imwrite(
# "rgb_train.png",
# "rgb_test.png",
# (rgb.cpu().numpy() * 255).astype(np.uint8),
# )
# break
psnr_avg = sum(psnrs) / len(psnrs)
print(f"evaluation: {psnr_avg=}")
train_dataset.training = True
if step == max_steps:
......
from .occupancy_field import OccupancyField
from .utils import (
ray_aabb_intersect,
unpack_to_ray_indices,
volumetric_marching,
volumetric_rendering_accumulate,
volumetric_rendering_steps,
volumetric_rendering_weights,
)
from .volumetric_rendering import volumetric_rendering
from .volumetric_rendering import volumetric_rendering_pipeline
__all__ = [
"OccupancyField",
......@@ -15,5 +16,6 @@ __all__ = [
"volumetric_rendering_accumulate",
"volumetric_rendering_steps",
"volumetric_rendering_weights",
"volumetric_rendering",
"volumetric_rendering_pipeline",
"unpack_to_ray_indices",
]
from typing import Callable
def _make_lazy_cuda(name: str) -> Callable:
def call_cuda(*args, **kwargs):
# pylint: disable=import-outside-toplevel
......@@ -9,8 +10,14 @@ def _make_lazy_cuda(name: str) -> Callable:
return call_cuda
ray_aabb_intersect = _make_lazy_cuda("ray_aabb_intersect")
volumetric_marching = _make_lazy_cuda("volumetric_marching")
volumetric_rendering_steps = _make_lazy_cuda("volumetric_rendering_steps")
volumetric_rendering_weights_forward = _make_lazy_cuda("volumetric_rendering_weights_forward")
volumetric_rendering_weights_backward = _make_lazy_cuda("volumetric_rendering_weights_backward")
\ No newline at end of file
volumetric_rendering_weights_forward = _make_lazy_cuda(
"volumetric_rendering_weights_forward"
)
volumetric_rendering_weights_backward = _make_lazy_cuda(
"volumetric_rendering_weights_backward"
)
unpack_to_ray_indices = _make_lazy_cuda("unpack_to_ray_indices")
......@@ -14,7 +14,7 @@ std::vector<torch::Tensor> volumetric_rendering_steps(
torch::Tensor sigmas
);
std::vector<torch::Tensor> volumetric_rendering_weights_forward(
torch::Tensor volumetric_rendering_weights_forward(
torch::Tensor packed_info,
torch::Tensor starts,
torch::Tensor ends,
......@@ -44,6 +44,8 @@ std::vector<torch::Tensor> volumetric_marching(
const float dt
);
torch::Tensor unpack_to_ray_indices(const torch::Tensor packed_info);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("ray_aabb_intersect", &ray_aabb_intersect);
......@@ -51,4 +53,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
m.def("volumetric_rendering_steps", &volumetric_rendering_steps);
m.def("volumetric_rendering_weights_forward", &volumetric_rendering_weights_forward);
m.def("volumetric_rendering_weights_backward", &volumetric_rendering_weights_backward);
m.def("unpack_to_ray_indices", &unpack_to_ray_indices);
}
\ No newline at end of file
......@@ -144,8 +144,6 @@ __global__ void marching_forward_kernel(
const float dt,
const int* packed_info,
// frustum outputs
float* frustum_origins,
float* frustum_dirs,
float* frustum_starts,
float* frustum_ends
) {
......@@ -165,8 +163,6 @@ __global__ void marching_forward_kernel(
const float near = t_min[0], far = t_max[0];
// locate
frustum_origins += base * 3;
frustum_dirs += base * 3;
frustum_starts += base;
frustum_ends += base;
......@@ -182,12 +178,6 @@ __global__ void marching_forward_kernel(
const float z = oz + t_mid * dz;
if (grid_occupied_at(x, y, z, resx, resy, resz, aabb, occ_binary)) {
frustum_origins[j * 3 + 0] = ox;
frustum_origins[j * 3 + 1] = oy;
frustum_origins[j * 3 + 2] = oz;
frustum_dirs[j * 3 + 0] = dx;
frustum_dirs[j * 3 + 1] = dy;
frustum_dirs[j * 3 + 2] = dz;
frustum_starts[j] = t0;
frustum_ends[j] = t1;
++j;
......@@ -211,6 +201,27 @@ __global__ void marching_forward_kernel(
return;
}
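// Expands the packed (start, steps) per-ray layout into a flat array that
// stores, for every sample, the index of the ray it belongs to.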
__global__ void ray_indices_kernel(
// input
const int n_rays,
const int* packed_info,
// output
int* ray_indices
) {
CUDA_GET_THREAD_ID(i, n_rays);
// locate
const int base = packed_info[i * 2 + 0]; // point idx start.
const int steps = packed_info[i * 2 + 1]; // point idx shift.
if (steps == 0) return;
ray_indices += base;
for (int j = 0; j < steps; ++j) {
ray_indices[j] = i;
}
}
std::vector<torch::Tensor> volumetric_marching(
// rays
......@@ -271,8 +282,6 @@ std::vector<torch::Tensor> volumetric_marching(
// output frustum samples
int total_steps = cum_steps[cum_steps.size(0) - 1].item<int>();
torch::Tensor frustum_origins = torch::zeros({total_steps, 3}, rays_o.options());
torch::Tensor frustum_dirs = torch::zeros({total_steps, 3}, rays_o.options());
torch::Tensor frustum_starts = torch::zeros({total_steps, 1}, rays_o.options());
torch::Tensor frustum_ends = torch::zeros({total_steps, 1}, rays_o.options());
......@@ -293,12 +302,32 @@ std::vector<torch::Tensor> volumetric_marching(
dt,
packed_info.data_ptr<int>(),
// outputs
frustum_origins.data_ptr<float>(),
frustum_dirs.data_ptr<float>(),
frustum_starts.data_ptr<float>(),
frustum_ends.data_ptr<float>()
);
return {packed_info, frustum_origins, frustum_dirs, frustum_starts, frustum_ends};
return {packed_info, frustum_starts, frustum_ends};
}
torch::Tensor unpack_to_ray_indices(const torch::Tensor packed_info) {
DEVICE_GUARD(packed_info);
CHECK_INPUT(packed_info);
const int n_rays = packed_info.size(0);
const int threads = 256;
const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads);
int n_samples = packed_info[n_rays - 1].sum(0).item<int>();
torch::Tensor ray_indices = torch::zeros(
{n_samples}, packed_info.options().dtype(torch::kInt32));
ray_indices_kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
n_rays,
packed_info.data_ptr<int>(),
ray_indices.data_ptr<int>()
);
return ray_indices;
}
......@@ -52,8 +52,7 @@ __global__ void volumetric_rendering_weights_forward_kernel(
const scalar_t* ends, // input end t
const scalar_t* sigmas, // input density after activation
// should be all-zero initialized
scalar_t* weights, // output
int* samples_ray_ids // output
scalar_t* weights // output
) {
CUDA_GET_THREAD_ID(i, n_rays);
......@@ -66,11 +65,6 @@ __global__ void volumetric_rendering_weights_forward_kernel(
ends += base;
sigmas += base;
weights += base;
samples_ray_ids += base;
for (int j = 0; j < steps; ++j) {
samples_ray_ids[j] = i;
}
// accumulated rendering
scalar_t T = 1.f;
......@@ -184,7 +178,7 @@ std::vector<torch::Tensor> volumetric_rendering_steps(
}
std::vector<torch::Tensor> volumetric_rendering_weights_forward(
torch::Tensor volumetric_rendering_weights_forward(
torch::Tensor packed_info,
torch::Tensor starts,
torch::Tensor ends,
......@@ -208,7 +202,6 @@ std::vector<torch::Tensor> volumetric_rendering_weights_forward(
// outputs
torch::Tensor weights = torch::zeros({n_samples}, sigmas.options());
torch::Tensor ray_indices = torch::zeros({n_samples}, packed_info.options());
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
sigmas.scalar_type(),
......@@ -220,12 +213,11 @@ std::vector<torch::Tensor> volumetric_rendering_weights_forward(
starts.data_ptr<scalar_t>(),
ends.data_ptr<scalar_t>(),
sigmas.data_ptr<scalar_t>(),
weights.data_ptr<scalar_t>(),
ray_indices.data_ptr<int>()
weights.data_ptr<scalar_t>()
);
}));
return {weights, ray_indices};
return weights;
}
......
......@@ -75,8 +75,6 @@ def volumetric_marching(
It is a tensor with shape (n_rays, 2). For each ray, the two values \
indicate the start index and the number of samples for this ray, \
respectively.
- **frustum_origins**: Sampled frustum origins. Tensor with shape (n_samples, 3).
- **frustum_dirs**: Sampled frustum directions. Tensor with shape (n_samples, 3).
- **frustum_starts**: Sampled frustum start distances along the rays. Tensor with shape (n_samples, 1).
- **frustum_ends**: Sampled frustum end distances along the rays. Tensor with shape (n_samples, 1).
......@@ -94,13 +92,7 @@ def volumetric_marching(
if stratified:
t_min = t_min + torch.rand_like(t_min) * render_step_size
(
packed_info,
frustum_origins,
frustum_dirs,
frustum_starts,
frustum_ends,
) = nerfacc_cuda.volumetric_marching(
packed_info, frustum_starts, frustum_ends = nerfacc_cuda.volumetric_marching(
# rays
rays_o.contiguous(),
rays_d.contiguous(),
......@@ -114,13 +106,7 @@ def volumetric_marching(
render_step_size,
)
return (
packed_info,
frustum_origins,
frustum_dirs,
frustum_starts,
frustum_ends,
)
return packed_info, frustum_starts, frustum_ends
@torch.no_grad()
......@@ -206,7 +192,6 @@ def volumetric_rendering_weights(
A tuple of tensors containing
- **weights**: Volumetric rendering weights for those samples. Tensor with shape (n_samples).
- **ray_indices**: Ray index of each sample. IntTensor with shape (n_sample).
"""
if (
......@@ -219,12 +204,12 @@ def volumetric_rendering_weights(
frustum_starts = frustum_starts.contiguous()
frustum_ends = frustum_ends.contiguous()
sigmas = sigmas.contiguous()
weights, ray_indices = _volumetric_rendering_weights.apply(
weights = _volumetric_rendering_weights.apply(
packed_info, frustum_starts, frustum_ends, sigmas
)
else:
raise NotImplementedError("Only support cuda inputs.")
return weights, ray_indices
return weights
def volumetric_rendering_accumulate(
......@@ -275,10 +260,32 @@ def volumetric_rendering_accumulate(
return outputs
@torch.no_grad()
def unpack_to_ray_indices(packed_info: Tensor) -> Tensor:
"""Unpack `packed_info` to ray indices. Useful for converting per ray data to per sample data.
Note: this function is not differentiable with respect to its inputs.
Args:
packed_info: Stores information on which samples belong to the same ray. \
See ``volumetric_marching`` for details. Tensor with shape (n_rays, 2).
Returns:
Ray index of each sample. IntTensor with shape (n_samples,).
"""
if packed_info.is_cuda:
packed_info = packed_info.contiguous()
ray_indices = nerfacc_cuda.unpack_to_ray_indices(packed_info)
else:
raise NotImplementedError("Only support cuda inputs.")
return ray_indices
class _volumetric_rendering_weights(torch.autograd.Function):
@staticmethod
def forward(ctx, packed_info, frustum_starts, frustum_ends, sigmas):
weights, ray_indices = nerfacc_cuda.volumetric_rendering_weights_forward(
weights = nerfacc_cuda.volumetric_rendering_weights_forward(
packed_info, frustum_starts, frustum_ends, sigmas
)
ctx.save_for_backward(
......@@ -288,10 +295,10 @@ class _volumetric_rendering_weights(torch.autograd.Function):
sigmas,
weights,
)
return weights, ray_indices
return weights
@staticmethod
def backward(ctx, grad_weights, _grad_ray_indices):
def backward(ctx, grad_weights):
(
packed_info,
frustum_starts,
......
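To make the `packed_info` / `ray_indices` convention concrete, here is a CPU reference of what `unpack_to_ray_indices` computes; the real function accepts only CUDA tensors, so this loop is illustrative:

```python
import torch

# packed_info rows are (start index, number of samples) per ray
packed_info = torch.tensor([[0, 3], [3, 0], [3, 2]], dtype=torch.int32)

ray_indices = torch.cat(
    [
        torch.full((int(steps),), ray_id, dtype=torch.int32)
        for ray_id, (_, steps) in enumerate(packed_info)
    ]
)
print(ray_indices)  # tensor([0, 0, 0, 2, 2], dtype=torch.int32)
```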
......@@ -3,6 +3,7 @@ from typing import Callable, List, Optional, Tuple
import torch
from .utils import (
unpack_to_ray_indices,
volumetric_marching,
volumetric_rendering_accumulate,
volumetric_rendering_steps,
......@@ -10,9 +11,9 @@ from .utils import (
)
def volumetric_rendering(
def volumetric_rendering_pipeline(
sigma_fn: Callable,
sigma_rgb_fn: Callable,
rgb_sigma_fn: Callable,
rays_o: torch.Tensor,
rays_d: torch.Tensor,
scene_aabb: torch.Tensor,
......@@ -23,7 +24,34 @@ def volumetric_rendering(
near_plane: float = 0.0,
stratified: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, int, int]:
"""Differentiable volumetric rendering."""
"""Differentiable volumetric rendering pipeline.
This function composes the following individual functions:
- ray_aabb_intersect
- volumetric_marching
- volumetric_rendering_steps
- volumetric_rendering_weights
- volumetric_rendering_accumulate
Args:
sigma_fn: A function that takes the frustum starts (N, 1), frustum ends (N, 1), and
ray indices (N,) and returns the post-activation sigma values (N, 1).
rgb_sigma_fn: A function that takes the frustum starts (N, 1), frustum ends (N, 1), and
ray indices (N,) and returns the post-activation rgb values (N, 3) and sigma values (N, 1).
rays_o: The origin of the rays (n_rays, 3).
rays_d: The normalized direction of the rays (n_rays, 3).
scene_aabb: The scene axis-aligned bounding box {xmin, ymin, zmin, xmax, ymax, zmax}.
scene_resolution: The scene resolution (3,). Defaults to None.
scene_occ_binary: The scene occupancy binary tensor used to skip samples (n_cells,). Defaults to None.
render_bkgd: The background color (3,). Default: None.
render_step_size: The step size for the volumetric rendering. Default: 1e-3.
near_plane: The near plane for the volumetric rendering. Default: 0.0.
stratified: Whether to use stratified sampling. Default: False.
Returns:
Ray colors (n_rays, 3), opacities (n_rays, 1), the number of marching samples, and the number of rendering samples.
"""
n_rays = rays_o.shape[0]
if scene_occ_binary is None:
......@@ -45,13 +73,7 @@ def volumetric_rendering(
with torch.no_grad():
# Ray marching and occupancy check.
(
packed_info,
frustum_origins,
frustum_dirs,
frustum_starts,
frustum_ends,
) = volumetric_marching(
packed_info, frustum_starts, frustum_ends = volumetric_marching(
rays_o,
rays_d,
aabb=scene_aabb,
......@@ -62,44 +84,30 @@ def volumetric_rendering(
stratified=stratified,
)
n_marching_samples = frustum_starts.shape[0]
ray_indices = unpack_to_ray_indices(packed_info)
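# ray_indices maps every sample to its source ray, so the callbacks can
# gather per-ray data (origins, directions, timestamps, ...) themselves
# instead of the pipeline materializing per-sample tensors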
# Query sigma without gradients
sigmas = sigma_fn(
frustum_origins,
frustum_dirs,
frustum_starts,
frustum_ends,
)
sigmas = sigma_fn(frustum_starts, frustum_ends, ray_indices)
# Ray marching and rendering check.
(
packed_info,
frustum_starts,
frustum_ends,
frustum_origins,
frustum_dirs,
) = volumetric_rendering_steps(
packed_info, frustum_starts, frustum_ends = volumetric_rendering_steps(
packed_info,
sigmas,
frustum_starts,
frustum_ends,
frustum_origins,
frustum_dirs,
)
n_rendering_samples = frustum_starts.shape[0]
ray_indices = unpack_to_ray_indices(packed_info)
# Query sigma and color with gradients
rgbs, sigmas = sigma_rgb_fn(
frustum_origins,
frustum_dirs,
frustum_starts,
frustum_ends,
rgbs, sigmas = rgb_sigma_fn(frustum_starts, frustum_ends, ray_indices)
assert rgbs.shape[-1] == 3, "rgbs must have 3 channels, got {}".format(rgbs.shape)
assert sigmas.shape[-1] == 1, "sigmas must have 1 channel, got {}".format(
sigmas.shape
)
assert rgbs.shape[-1] == 3, "rgbs must have 3 channels"
assert sigmas.shape[-1] == 1, "sigmas must have 1 channel"
# Rendering: compute weights and ray indices.
weights, ray_indices = volumetric_rendering_weights(
weights = volumetric_rendering_weights(
packed_info, sigmas, frustum_starts, frustum_ends
)
......
......@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "nerfacc"
version = "0.0.7"
version = "0.0.8"
authors = [{name = "Ruilong", email = "ruilongli94@gmail.com"}]
license = { text="MIT" }
requires-python = ">=3.8"
......
import torch
import tqdm
from nerfacc import volumetric_rendering
from nerfacc import volumetric_rendering_pipeline
device = "cuda:0"
def sigma_fn(frustum_origins, frustum_dirs, frustum_starts, frustum_ends):
def sigma_fn(frustum_starts, frustum_ends, ray_indices):
return torch.rand_like(frustum_ends[:, :1])
def sigma_rgb_fn(frustum_origins, frustum_dirs, frustum_starts, frustum_ends):
return torch.rand_like(frustum_ends[:, :1]), torch.rand_like(frustum_ends[:, :3])
def rgb_sigma_fn(frustum_starts, frustum_ends, ray_indices):
return torch.rand((frustum_ends.shape[0], 3), device=device), torch.rand_like(
frustum_ends
)
def test_rendering():
......@@ -24,9 +26,9 @@ def test_rendering():
render_bkgd = torch.ones(3, device=device)
for step in tqdm.tqdm(range(1000)):
volumetric_rendering(
volumetric_rendering_pipeline(
sigma_fn,
sigma_rgb_fn,
rgb_sigma_fn,
rays_o,
rays_d,
scene_aabb,
......
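Putting the pieces together, a minimal end-to-end sketch of the renamed pipeline with the new callback signature. The toy callbacks mirror the test above; the dtypes of the grid tensors are assumptions, and a CUDA build of nerfacc is required:

```python
import torch

from nerfacc import volumetric_rendering_pipeline

device = "cuda:0"
n_rays = 4096

rays_o = torch.rand(n_rays, 3, device=device) * 0.5  # origins inside the aabb
rays_d = torch.randn(n_rays, 3, device=device)
rays_d = rays_d / rays_d.norm(dim=-1, keepdim=True)  # normalized directions
scene_aabb = torch.tensor([-1.0, -1.0, -1.0, 1.0, 1.0, 1.0], device=device)

# a fully occupied 64^3 grid, i.e. no samples are skipped (dtypes assumed)
scene_resolution = torch.tensor([64, 64, 64], dtype=torch.int32, device=device)
scene_occ_binary = torch.ones(64**3, dtype=torch.bool, device=device)

def sigma_fn(starts, ends, ray_indices):
    # a real field would gather rays_o/rays_d via ray_indices (see trainval.py)
    return torch.rand_like(starts)

def rgb_sigma_fn(starts, ends, ray_indices):
    return torch.rand((starts.shape[0], 3), device=device), torch.rand_like(starts)

colors, opacities, n_marching, n_rendering = volumetric_rendering_pipeline(
    sigma_fn,
    rgb_sigma_fn,
    rays_o,
    rays_d,
    scene_aabb,
    scene_resolution=scene_resolution,
    scene_occ_binary=scene_occ_binary,
    render_bkgd=torch.ones(3, device=device),
    render_step_size=1e-2,
)
print(colors.shape, opacities.shape, n_marching, n_rendering)
```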