Unverified Commit ef812085 authored by Ruilong Li (李瑞龙), committed by GitHub

support dnerf. cleanup api. expose rendering pipeline (#22)

* dyn

* clean up for dyn

* update test

* ngp is up and running but slow

* benchmark on three examples

* fix and update ngp in readme

* clean up and expose volumetric_rendering_pipeline

* update doc and fix import

* fix import
parent 08761ab8
@@ -118,3 +118,4 @@ venv.bak/
.vsocde
benchmarks/
+outputs/
\ No newline at end of file
@@ -12,9 +12,9 @@ Performance on TITAN RTX :

| trainval | Lego | Mic | Materials | Chair | Hotdog |
| - | - | - | - | - | - |
-| Time | 300s | 272s | 258s | 331s | 287s |
-| PSNR | 36.61 | 37.45 | 30.15 | 36.06 | 38.17 |
-| FPS | 11.49 | 21.48 | 8.86 | 15.61 | 7.38 |
+| Time | 300s | 274s | 266s | 341s | 277s |
+| PSNR | 36.61 | 37.62 | 30.11 | 36.09 | 38.09 |
+| FPS | 12.87 | 23.67 | 9.33 | 16.91 | 7.48 |

Instant-NGP paper (5 min) on 3090 (w/ mask):

@@ -44,17 +44,23 @@ Note: We only use a single MLP with more samples (1024), instead of two MLPs wit

*FPS for some scenes is tested under `--test_chunk_size=8192` (default is `81920`) to avoid OOM.
-<!--
-Tested with the default settings on the Lego test set.
-
-| Model | Split | PSNR | Train Time | Test Speed | GPU | Train Memory |
-| - | - | - | - | - | - | - |
-| instant-ngp (paper) | trainval? | 36.39 | - | - | 3090 |
-| instant-ngp (code) | train (35k steps) | 36.08 | 308 sec | 55.32 fps | TITAN RTX | 1734MB |
-| instant-ngp (code) w/o rng bkgd | train (35k steps) | 34.17 | - | - | - | - |
-| torch-ngp (`-O`) | train (30K steps) | 34.15 | 310 sec | 7.8 fps | V100 |
-| ours | trainval (35K steps) | 36.22 | 378 sec | 12.08 fps | TITAN RTX | -->
+## Examples: MLP NeRF on Dynamic objects
+
+Here we trained something similar to D-NeRF on the dnerf dataset:
+
+``` bash
+python examples/trainval.py dnerf --train_split train --test_chunk_size=8192
+```
+
+Performance on the test set:
+
+| | Lego | Stand Up |
+| - | - | - |
+| DNeRF paper PSNR (train set) | 21.64 | 32.79 |
+| Our PSNR (train set) | 24.66 | 33.98 |
+| Our train time & test FPS | 43min; 0.15FPS | 41min; 0.4FPS |
## Tips:
...
@@ -3,6 +3,8 @@ Volumetric Rendering

.. currentmodule:: nerfacc

+.. autofunction:: volumetric_rendering_pipeline
.. autofunction:: volumetric_rendering_steps
.. autofunction:: volumetric_rendering_weights
.. autofunction:: volumetric_rendering_accumulate
+.. autofunction:: unpack_to_ray_indices
\ No newline at end of file
import collections
import json
import os

import imageio.v2 as imageio
import numpy as np
import torch
import torch.nn.functional as F

Rays = collections.namedtuple("Rays", ("origins", "viewdirs"))
def namedtuple_map(fn, tup):
    """Apply `fn` to each element of `tup` and cast to `tup`'s namedtuple."""
    return type(tup)(*(None if x is None else fn(x) for x in tup))
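As a quick aside, `namedtuple_map` applies a function to every field of a `Rays` tuple at once; a hypothetical use (values are illustrative, the slicing pattern matches the chunked rendering loop later in this diff):

```python
rays = Rays(origins=torch.zeros(8, 3), viewdirs=torch.ones(8, 3))
chunk = namedtuple_map(lambda r: r[:4], rays)            # Rays with (4, 3) fields
on_gpu = namedtuple_map(lambda r: r.to("cuda:0"), rays)  # both fields moved at once
```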
def _load_renderings(root_fp: str, subject_id: str, split: str):
    """Load images from disk."""
    if not root_fp.startswith("/"):
        # allow relative path. e.g., "./data/dnerf_synthetic/"
        root_fp = os.path.join(
            os.path.dirname(os.path.abspath(__file__)),
            "..",
            "..",
            root_fp,
        )

    data_dir = os.path.join(root_fp, subject_id)
    with open(os.path.join(data_dir, "transforms_{}.json".format(split)), "r") as fp:
        meta = json.load(fp)

    images = []
    camtoworlds = []
    timestamps = []

    for i in range(len(meta["frames"])):
        frame = meta["frames"][i]
        fname = os.path.join(data_dir, frame["file_path"] + ".png")
        rgba = imageio.imread(fname)
        timestamp = (
            frame["time"] if "time" in frame else float(i) / (len(meta["frames"]) - 1)
        )
        timestamps.append(timestamp)
        camtoworlds.append(frame["transform_matrix"])
        images.append(rgba)

    images = np.stack(images, axis=0)
    camtoworlds = np.stack(camtoworlds, axis=0)
    timestamps = np.stack(timestamps, axis=0)
    h, w = images.shape[1:3]
    camera_angle_x = float(meta["camera_angle_x"])
    focal = 0.5 * w / np.tan(0.5 * camera_angle_x)

    return images, camtoworlds, focal, timestamps
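For reference, a minimal sketch of the metadata this loader expects, written as the equivalent Python dict. The field names (`camera_angle_x`, `frames`, `file_path`, `time`, `transform_matrix`) come from the code above; the concrete values are purely illustrative:

```python
# Illustrative contents of transforms_train.json (values are made up):
meta = {
    "camera_angle_x": 0.6911,  # horizontal field of view in radians
    "frames": [
        {
            "file_path": "./train/r_000",  # ".png" is appended by the loader
            "time": 0.0,                   # optional; uniform [0, 1] spacing otherwise
            "transform_matrix": [          # 4x4 camera-to-world matrix
                [1.0, 0.0, 0.0, 0.0],
                [0.0, 1.0, 0.0, 0.0],
                [0.0, 0.0, 1.0, 4.0],
                [0.0, 0.0, 0.0, 1.0],
            ],
        },
        # ... one entry per frame
    ],
}
```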
class SubjectLoader(torch.utils.data.Dataset):
    """Single subject data loader for training and evaluation."""

    SPLITS = ["train", "val", "test"]
    SUBJECT_IDS = [
        "bouncingballs",
        "hellwarrior",
        "hook",
        "jumpingjacks",
        "lego",
        "mutant",
        "standup",
        "trex",
    ]

    WIDTH, HEIGHT = 800, 800
    NEAR, FAR = 2.0, 6.0
    OPENGL_CAMERA = True

    def __init__(
        self,
        subject_id: str,
        root_fp: str,
        split: str,
        color_bkgd_aug: str = "white",
        num_rays: int = None,
        near: float = None,
        far: float = None,
        batch_over_images: bool = True,
    ):
        super().__init__()
        assert split in self.SPLITS, "%s" % split
        assert subject_id in self.SUBJECT_IDS, "%s" % subject_id
        assert color_bkgd_aug in ["white", "black", "random"]
        self.split = split
        self.num_rays = num_rays
        self.near = self.NEAR if near is None else near
        self.far = self.FAR if far is None else far
        self.training = (num_rays is not None) and (split in ["train", "trainval"])
        self.color_bkgd_aug = color_bkgd_aug
        self.batch_over_images = batch_over_images
        self.images, self.camtoworlds, self.focal, self.timestamps = _load_renderings(
            root_fp, subject_id, split
        )
        self.images = torch.from_numpy(self.images).to(torch.uint8)
        self.camtoworlds = torch.from_numpy(self.camtoworlds).to(torch.float32)
        self.timestamps = torch.from_numpy(self.timestamps).to(torch.float32)[:, None]
        self.K = torch.tensor(
            [
                [self.focal, 0, self.WIDTH / 2.0],
                [0, self.focal, self.HEIGHT / 2.0],
                [0, 0, 1],
            ],
            dtype=torch.float32,
        )  # (3, 3)
        assert self.images.shape[1:3] == (self.HEIGHT, self.WIDTH)

    def __len__(self):
        return len(self.images)

    @torch.no_grad()
    def __getitem__(self, index):
        data = self.fetch_data(index)
        data = self.preprocess(data)
        return data

    def preprocess(self, data):
        """Process the fetched / cached data with randomness."""
        rgba, rays = data["rgba"], data["rays"]
        pixels, alpha = torch.split(rgba, [3, 1], dim=-1)

        if self.training:
            if self.color_bkgd_aug == "random":
                color_bkgd = torch.rand(3, device=self.images.device)
            elif self.color_bkgd_aug == "white":
                color_bkgd = torch.ones(3, device=self.images.device)
            elif self.color_bkgd_aug == "black":
                color_bkgd = torch.zeros(3, device=self.images.device)
        else:
            # just use white during inference
            color_bkgd = torch.ones(3, device=self.images.device)

        pixels = pixels * alpha + color_bkgd * (1.0 - alpha)
        return {
            "pixels": pixels,  # [n_rays, 3] or [h, w, 3]
            "rays": rays,  # [n_rays,] or [h, w]
            "color_bkgd": color_bkgd,  # [3,]
            **{k: v for k, v in data.items() if k not in ["rgba", "rays"]},
        }

    def update_num_rays(self, num_rays):
        self.num_rays = num_rays

    def fetch_data(self, index):
        """Fetch the data (it may be cached for multiple batches)."""
        num_rays = self.num_rays

        if self.training:
            if self.batch_over_images:
                image_id = torch.randint(
                    0,
                    len(self.images),
                    size=(num_rays,),
                    device=self.images.device,
                )
            else:
                image_id = [index]
            x = torch.randint(
                0, self.WIDTH, size=(num_rays,), device=self.images.device
            )
            y = torch.randint(
                0, self.HEIGHT, size=(num_rays,), device=self.images.device
            )
        else:
            image_id = [index]
            x, y = torch.meshgrid(
                torch.arange(self.WIDTH, device=self.images.device),
                torch.arange(self.HEIGHT, device=self.images.device),
                indexing="xy",
            )
            x = x.flatten()
            y = y.flatten()

        # generate rays
        rgba = self.images[image_id, y, x] / 255.0  # (num_rays, 4)
        c2w = self.camtoworlds[image_id]  # (num_rays, 3, 4)
        camera_dirs = F.pad(
            torch.stack(
                [
                    (x - self.K[0, 2] + 0.5) / self.K[0, 0],
                    (y - self.K[1, 2] + 0.5)
                    / self.K[1, 1]
                    * (-1.0 if self.OPENGL_CAMERA else 1.0),
                ],
                dim=-1,
            ),
            (0, 1),
            value=(-1.0 if self.OPENGL_CAMERA else 1.0),
        )  # [num_rays, 3]

        # [n_cams, height, width, 3]
        directions = (camera_dirs[:, None, :] * c2w[:, :3, :3]).sum(dim=-1)
        origins = torch.broadcast_to(c2w[:, :3, -1], directions.shape)
        viewdirs = directions / torch.linalg.norm(directions, dim=-1, keepdims=True)

        if self.training:
            origins = torch.reshape(origins, (num_rays, 3))
            viewdirs = torch.reshape(viewdirs, (num_rays, 3))
            rgba = torch.reshape(rgba, (num_rays, 4))
        else:
            origins = torch.reshape(origins, (self.HEIGHT, self.WIDTH, 3))
            viewdirs = torch.reshape(viewdirs, (self.HEIGHT, self.WIDTH, 3))
            rgba = torch.reshape(rgba, (self.HEIGHT, self.WIDTH, 4))

        rays = Rays(origins=origins, viewdirs=viewdirs)
        timestamps = self.timestamps[image_id]

        return {
            "rgba": rgba,  # [h, w, 4] or [num_rays, 4]
            "rays": rays,  # [h, w, 3] or [num_rays, 3]
            "timestamps": timestamps,  # [num_rays, 1] (training) or [1, 1] (testing)
        }
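A minimal usage sketch (the data path is hypothetical), mirroring how the training script later in this diff consumes the loader:

```python
dataset = SubjectLoader(
    subject_id="standup",
    root_fp="./data/dnerf/",       # hypothetical local path
    split="train",
    num_rays=4096,                 # enables training mode (random rays)
)
data = dataset[0]
print(data["pixels"].shape)        # torch.Size([4096, 3])
print(data["rays"].origins.shape)  # torch.Size([4096, 3])
print(data["timestamps"].shape)    # torch.Size([4096, 1])
```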
""" The MLPs and Voxels. """ """ The MLPs and Voxels. """
import functools
import math import math
from typing import Callable, Dict, Optional from typing import Callable, Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -170,13 +171,15 @@ class SinusoidalEncoder(nn.Module): ...@@ -170,13 +171,15 @@ class SinusoidalEncoder(nn.Module):
def latent_dim(self) -> int: def latent_dim(self) -> int:
return (int(self.use_identity) + (self.max_deg - self.min_deg) * 2) * self.x_dim return (int(self.use_identity) + (self.max_deg - self.min_deg) * 2) * self.x_dim
def forward(self, x: torch.Tensor) -> Dict: def forward(self, x: torch.Tensor) -> torch.Tensor:
""" """
Args: Args:
x: [..., x_dim] x: [..., x_dim]
Returns: Returns:
latent: [..., latent_dim] latent: [..., latent_dim]
""" """
if self.max_deg == self.min_deg:
return x
xb = torch.reshape( xb = torch.reshape(
(x[Ellipsis, None, :] * self.scales[:, None]), (x[Ellipsis, None, :] * self.scales[:, None]),
list(x.shape[:-1]) + [(self.max_deg - self.min_deg) * self.x_dim], list(x.shape[:-1]) + [(self.max_deg - self.min_deg) * self.x_dim],
@@ -220,3 +223,31 @@ class VanillaNeRFRadianceField(nn.Module):
        condition = self.view_encoder(condition)
        rgb, sigma = self.mlp(x, condition=condition)
        return torch.sigmoid(rgb), F.relu(sigma)
+
+class DNeRFRadianceField(nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        # with min_deg == max_deg == 0 and use_identity=True, these encoders
+        # pass x and t through unchanged (see SinusoidalEncoder.forward above)
+        self.posi_encoder = SinusoidalEncoder(3, 0, 0, True)
+        self.time_encoder = SinusoidalEncoder(1, 0, 0, True)
+        self.warp = MLP(
+            input_dim=self.posi_encoder.latent_dim + self.time_encoder.latent_dim,
+            output_dim=3,
+            net_depth=4,
+            net_width=64,
+            skip_layer=2,
+            output_init=functools.partial(torch.nn.init.uniform_, b=1e-4),
+        )
+        self.nerf = VanillaNeRFRadianceField()
+
+    def query_density(self, x, t):
+        x = x + self.warp(
+            torch.cat([self.posi_encoder(x), self.time_encoder(t)], dim=-1)
+        )
+        return self.nerf.query_density(x)
+
+    def forward(self, x, t, condition=None):
+        x = x + self.warp(
+            torch.cat([self.posi_encoder(x), self.time_encoder(t)], dim=-1)
+        )
+        return self.nerf(x, condition=condition)
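Conceptually, the field warps each sample at time `t` back into a shared canonical space, `x -> x + warp(x, t)`, and evaluates a vanilla NeRF there. A minimal query sketch (shapes and values are illustrative):

```python
field = DNeRFRadianceField().to("cuda:0")
x = torch.rand(1024, 3, device="cuda:0")      # sample positions
t = torch.rand(1024, 1, device="cuda:0")      # per-sample timestamps in [0, 1]
dirs = torch.randn(1024, 3, device="cuda:0")  # view directions

sigma = field.query_density(x, t)             # (1024, 1) densities
rgb, sigma = field(x, t, dirs)                # (1024, 3) colors, (1024, 1) densities
```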
import argparse
import math
+import os
+import random
import time

+import imageio
import numpy as np
import torch
import torch.nn.functional as F
import tqdm

-from datasets.nerf_synthetic import SubjectLoader, namedtuple_map
-from radiance_fields.mlp import VanillaNeRFRadianceField
-from radiance_fields.ngp import NGPradianceField
-from nerfacc import OccupancyField, volumetric_rendering
+from nerfacc import OccupancyField, volumetric_rendering_pipeline

-TARGET_SAMPLE_BATCH_SIZE = 1 << 16
+device = "cuda:0"

+def _set_random_seed(seed):
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
def render_image(
-    radiance_field, rays, render_bkgd, render_step_size, test_chunk_size=81920
+    radiance_field,
+    rays,
+    timestamps,
+    render_bkgd,
+    render_step_size,
+    test_chunk_size=81920,
):
    """Render the pixels of an image.
@@ -37,25 +48,45 @@ def render_image(
    else:
        num_rays, _ = rays_shape
-    def sigma_fn(frustum_origins, frustum_dirs, frustum_starts, frustum_ends):
+    def sigma_fn(frustum_starts, frustum_ends, ray_indices):
+        ray_indices = ray_indices.long()
+        frustum_origins = chunk_rays.origins[ray_indices]
+        frustum_dirs = chunk_rays.viewdirs[ray_indices]
        positions = (
            frustum_origins + frustum_dirs * (frustum_starts + frustum_ends) / 2.0
        )
+        if timestamps is None:
+            return radiance_field.query_density(positions)
+        else:
+            if radiance_field.training:
+                t = timestamps[ray_indices]
+            else:
+                t = timestamps.expand_as(positions[:, :1])
+            return radiance_field.query_density(positions, t)

-    def sigma_rgb_fn(frustum_origins, frustum_dirs, frustum_starts, frustum_ends):
+    def rgb_sigma_fn(frustum_starts, frustum_ends, ray_indices):
+        ray_indices = ray_indices.long()
+        frustum_origins = chunk_rays.origins[ray_indices]
+        frustum_dirs = chunk_rays.viewdirs[ray_indices]
        positions = (
            frustum_origins + frustum_dirs * (frustum_starts + frustum_ends) / 2.0
        )
+        if timestamps is None:
+            return radiance_field(positions, frustum_dirs)
+        else:
+            if radiance_field.training:
+                t = timestamps[ray_indices]
+            else:
+                t = timestamps.expand_as(positions[:, :1])
+            return radiance_field(positions, t, frustum_dirs)
    results = []
    chunk = torch.iinfo(torch.int32).max if radiance_field.training else test_chunk_size
    for i in range(0, num_rays, chunk):
        chunk_rays = namedtuple_map(lambda r: r[i : i + chunk], rays)
-        chunk_results = volumetric_rendering(
+        chunk_results = volumetric_rendering_pipeline(
            sigma_fn=sigma_fn,
-            sigma_rgb_fn=sigma_rgb_fn,
+            rgb_sigma_fn=rgb_sigma_fn,
            rays_o=chunk_rays.origins,
            rays_d=chunk_rays.viewdirs,
            scene_aabb=occ_field.aabb,
@@ -80,14 +111,14 @@ def render_image(
if __name__ == "__main__":
-    torch.manual_seed(42)
+    _set_random_seed(42)

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "method",
        type=str,
        default="ngp",
-        choices=["ngp", "vanilla"],
+        choices=["ngp", "vanilla", "dnerf"],
        help="which nerf to use",
    )
    parser.add_argument(
@@ -102,6 +133,7 @@ if __name__ == "__main__":
        type=str,
        default="lego",
        choices=[
+            # nerf synthetic
            "chair",
            "drums",
            "ficus",
@@ -110,9 +142,23 @@ if __name__ == "__main__":
            "materials",
            "mic",
            "ship",
+            # dnerf
+            "bouncingballs",
+            "hellwarrior",
+            "hook",
+            "jumpingjacks",
+            "lego",
+            "mutant",
+            "standup",
+            "trex",
        ],
        help="which scene to use",
    )
+    parser.add_argument(
+        "--aabb",
+        type=list,
+        default=[-1.5, -1.5, -1.5, 1.5, 1.5, 1.5],
+    )
    parser.add_argument(
        "--test_chunk_size",
        type=int,
@@ -120,11 +166,46 @@ if __name__ == "__main__":
    )
    args = parser.parse_args()
device = "cuda:0" if args.method == "ngp":
from datasets.nerf_synthetic import SubjectLoader, namedtuple_map
from radiance_fields.ngp import NGPradianceField
radiance_field = NGPradianceField(aabb=args.aabb).to(device)
optimizer = torch.optim.Adam(radiance_field.parameters(), lr=1e-2, eps=1e-15)
max_steps = 20000
occ_field_warmup_steps = 256
grad_scaler = torch.cuda.amp.GradScaler(2**10)
data_root_fp = "/home/ruilongli/data/nerf_synthetic/"
target_sample_batch_size = 1 << 18
elif args.method == "vanilla":
from datasets.nerf_synthetic import SubjectLoader, namedtuple_map
from radiance_fields.mlp import VanillaNeRFRadianceField
radiance_field = VanillaNeRFRadianceField().to(device)
optimizer = torch.optim.Adam(radiance_field.parameters(), lr=5e-4)
max_steps = 40000
occ_field_warmup_steps = 2000
grad_scaler = torch.cuda.amp.GradScaler(1)
data_root_fp = "/home/ruilongli/data/nerf_synthetic/"
target_sample_batch_size = 1 << 16
elif args.method == "dnerf":
from datasets.dnerf_synthetic import SubjectLoader, namedtuple_map
from radiance_fields.mlp import DNeRFRadianceField
radiance_field = DNeRFRadianceField().to(device)
optimizer = torch.optim.Adam(radiance_field.parameters(), lr=5e-4)
max_steps = 40000
occ_field_warmup_steps = 2000
grad_scaler = torch.cuda.amp.GradScaler(1)
data_root_fp = "/home/ruilongli/data/dnerf/"
target_sample_batch_size = 1 << 16
    scene = args.scene

    # setup the scene bounding box.
-    scene_aabb = torch.tensor([-1.5, -1.5, -1.5, 1.5, 1.5, 1.5])
+    scene_aabb = torch.tensor(args.aabb)
    # setup some rendering settings
    render_n_samples = 1024
    render_step_size = (
@@ -134,55 +215,29 @@ if __name__ == "__main__":
# setup dataset
    train_dataset = SubjectLoader(
        subject_id=scene,
-        root_fp="/home/ruilongli/data/nerf_synthetic/",
+        root_fp=data_root_fp,
        split=args.train_split,
-        num_rays=TARGET_SAMPLE_BATCH_SIZE // render_n_samples,
+        num_rays=target_sample_batch_size // render_n_samples,
        # color_bkgd_aug="random",
    )
    train_dataset.images = train_dataset.images.to(device)
    train_dataset.camtoworlds = train_dataset.camtoworlds.to(device)
    train_dataset.K = train_dataset.K.to(device)
-    train_dataloader = torch.utils.data.DataLoader(
-        train_dataset,
-        num_workers=0,
-        batch_size=None,
-        # persistent_workers=True,
-        shuffle=True,
-    )
+    if hasattr(train_dataset, "timestamps"):
+        train_dataset.timestamps = train_dataset.timestamps.to(device)
    test_dataset = SubjectLoader(
        subject_id=scene,
-        root_fp="/home/ruilongli/data/nerf_synthetic/",
+        root_fp=data_root_fp,
        split="test",
        num_rays=None,
    )
    test_dataset.images = test_dataset.images.to(device)
    test_dataset.camtoworlds = test_dataset.camtoworlds.to(device)
    test_dataset.K = test_dataset.K.to(device)
-    test_dataloader = torch.utils.data.DataLoader(
-        test_dataset,
-        num_workers=0,
-        batch_size=None,
-    )
-
-    # setup the scene radiance field. Assume you have a NeRF model and
-    # it has following functions:
-    # - query_density(): {x} -> {density}
-    # - forward(): {x, dirs} -> {rgb, density}
-    if args.method == "ngp":
-        radiance_field = NGPradianceField(aabb=scene_aabb).to(device)
-        optimizer = torch.optim.Adam(radiance_field.parameters(), lr=1e-2, eps=1e-15)
-        max_steps = 20000
-        occ_field_warmup_steps = 2000
-        grad_scaler = torch.cuda.amp.GradScaler(1)
-    elif args.method == "vanilla":
-        radiance_field = VanillaNeRFRadianceField().to(device)
-        optimizer = torch.optim.Adam(radiance_field.parameters(), lr=5e-4)
-        max_steps = 40000
-        occ_field_warmup_steps = 256
-        grad_scaler = torch.cuda.amp.GradScaler(2**10)
+    if hasattr(test_dataset, "timestamps"):
+        test_dataset.timestamps = test_dataset.timestamps.to(device)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer,
@@ -199,6 +254,13 @@ if __name__ == "__main__":
        Returns:
            occupancy values with shape (N, 1).
        """
+        if args.method == "dnerf":
+            # sample random training timestamps for the occupancy query
+            idxs = torch.randint(
+                0, len(train_dataset.timestamps), (x.shape[0],), device=x.device
+            )
+            t = train_dataset.timestamps[idxs]
+            density_after_activation = radiance_field.query_density(x, t)
+        else:
+            density_after_activation = radiance_field.query_density(x)
        # those two are similar when density is small.
        # occupancy = 1.0 - torch.exp(-density_after_activation * render_step_size)
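The comment above appeals to the standard first-order expansion relating density to segment opacity; for a small step size the exact opacity and the linear approximation agree:

```latex
1 - e^{-\sigma\delta} \;=\; \sigma\delta \;-\; \tfrac{1}{2}(\sigma\delta)^2 \;+\; \cdots \;\approx\; \sigma\delta,
\qquad \sigma\delta \ll 1,
```

where \(\sigma\) is `density_after_activation` and \(\delta\) is `render_step_size`.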
@@ -221,22 +283,20 @@ if __name__ == "__main__":
            data = train_dataset[i]
            data_time += time.time() - tic_data

-            # generate rays from data and the gt pixel color
-            # rays = namedtuple_map(lambda x: x.to(device), data["rays"])
-            # pixels = data["pixels"].to(device)
            render_bkgd = data["color_bkgd"]
            rays = data["rays"]
            pixels = data["pixels"]
+            timestamps = data.get("timestamps", None)

            # update occupancy grid
            occ_field.every_n_step(step, warmup_steps=occ_field_warmup_steps)

            rgb, acc, counter, compact_counter = render_image(
-                radiance_field, rays, render_bkgd, render_step_size
+                radiance_field, rays, timestamps, render_bkgd, render_step_size
            )

            num_rays = len(pixels)
            num_rays = int(
-                num_rays * (TARGET_SAMPLE_BATCH_SIZE / float(compact_counter))
+                num_rays * (target_sample_batch_size / float(compact_counter))
            )
            train_dataset.update_num_rays(num_rays)
            alive_ray_mask = acc.squeeze(-1) > 0
@@ -261,21 +321,24 @@ if __name__ == "__main__":
            )

            # if time.time() - tic > 300:
-            if step >= max_steps and step % max_steps == 0 and step > 0:
+            if step >= 0 and step % max_steps == 0 and step > 0:
                # evaluation
                radiance_field.eval()

                psnrs = []
                with torch.no_grad():
-                    for data in tqdm.tqdm(test_dataloader):
-                        # generate rays from data and the gt pixel color
-                        rays = namedtuple_map(lambda x: x.to(device), data["rays"])
-                        pixels = data["pixels"].to(device)
-                        render_bkgd = data["color_bkgd"].to(device)
+                    for i in tqdm.tqdm(range(len(test_dataset))):
+                        data = test_dataset[i]
+                        render_bkgd = data["color_bkgd"]
+                        rays = data["rays"]
+                        pixels = data["pixels"]
+                        timestamps = data.get("timestamps", None)

                        # rendering
                        rgb, acc, _, _ = render_image(
                            radiance_field,
                            rays,
+                            timestamps,
                            render_bkgd,
                            render_step_size,
                            test_chunk_size=args.test_chunk_size,
@@ -283,38 +346,28 @@ if __name__ == "__main__":
                        mse = F.mse_loss(rgb, pixels)
                        psnr = -10.0 * torch.log(mse) / np.log(10.0)
                        psnrs.append(psnr.item())
-                psnr_avg = sum(psnrs) / len(psnrs)
-                print(f"evaluation: {psnr_avg=}")
-                # imageio.imwrite(
-                #     "acc_binary_test.png",
-                #     ((acc > 0).float().cpu().numpy() * 255).astype(np.uint8),
-                # )
-                # psnrs = []
-                # train_dataset.training = False
-                # with torch.no_grad():
-                #     for data in tqdm.tqdm(train_dataloader):
-                #         # generate rays from data and the gt pixel color
-                #         rays = namedtuple_map(lambda x: x.to(device), data["rays"])
-                #         pixels = data["pixels"].to(device)
-                #         render_bkgd = data["color_bkgd"].to(device)
-                #         # rendering
-                #         rgb, acc, _, _ = render_image(
-                #             radiance_field, rays, render_bkgd, render_step_size
-                #         )
-                #         mse = F.mse_loss(rgb, pixels)
-                #         psnr = -10.0 * torch.log(mse) / np.log(10.0)
-                #         psnrs.append(psnr.item())
-                # psnr_avg = sum(psnrs) / len(psnrs)
-                # print(f"evaluation on train: {psnr_avg=}")
-                # imageio.imwrite(
-                #     "acc_binary_train.png",
-                #     ((acc > 0).float().cpu().numpy() * 255).astype(np.uint8),
-                # )
-                # imageio.imwrite(
-                #     "rgb_train.png",
-                #     (rgb.cpu().numpy() * 255).astype(np.uint8),
-                # )
+                        # if step == max_steps:
+                        #     output_dir = os.path.join("./outputs/nerfacc/", scene)
+                        #     os.makedirs(output_dir, exist_ok=True)
+                        #     save = torch.cat([pixels, rgb], dim=1)
+                        #     imageio.imwrite(
+                        #         os.path.join(output_dir, "%05d.png" % i),
+                        #         (save.cpu().numpy() * 255).astype(np.uint8),
+                        #     )
+                        # else:
+                        #     imageio.imwrite(
+                        #         "acc_binary_test.png",
+                        #         ((acc > 0).float().cpu().numpy() * 255).astype(
+                        #             np.uint8
+                        #         ),
+                        #     )
+                        #     imageio.imwrite(
+                        #         "rgb_test.png",
+                        #         (rgb.cpu().numpy() * 255).astype(np.uint8),
+                        #     )
+                        #     break
+                psnr_avg = sum(psnrs) / len(psnrs)
+                print(f"evaluation: {psnr_avg=}")
                train_dataset.training = True
            if step == max_steps:
...
from .occupancy_field import OccupancyField
from .utils import (
    ray_aabb_intersect,
+    unpack_to_ray_indices,
    volumetric_marching,
    volumetric_rendering_accumulate,
    volumetric_rendering_steps,
    volumetric_rendering_weights,
)
-from .volumetric_rendering import volumetric_rendering
+from .volumetric_rendering import volumetric_rendering_pipeline

__all__ = [
    "OccupancyField",
@@ -15,5 +16,6 @@ __all__ = [
    "volumetric_rendering_accumulate",
    "volumetric_rendering_steps",
    "volumetric_rendering_weights",
-    "volumetric_rendering",
+    "volumetric_rendering_pipeline",
+    "unpack_to_ray_indices",
]
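With the new exports in place, a minimal end-to-end call might look like the sketch below. It mirrors the unit test at the bottom of this diff; the keyword names follow the `volumetric_rendering_pipeline` docstring, random sigma/rgb functions stand in for a real radiance field, and the occupancy grid arguments are left at their defaults:

```python
import torch
from nerfacc import volumetric_rendering_pipeline

device = "cuda:0"
rays_o = torch.rand(128, 3, device=device)
rays_d = torch.randn(128, 3, device=device)
rays_d = rays_d / rays_d.norm(dim=-1, keepdim=True)  # pipeline expects unit directions

def sigma_fn(frustum_starts, frustum_ends, ray_indices):
    return torch.rand_like(frustum_starts)  # (N, 1) post-activation densities

def rgb_sigma_fn(frustum_starts, frustum_ends, ray_indices):
    rgbs = torch.rand((frustum_starts.shape[0], 3), device=device)
    return rgbs, torch.rand_like(frustum_starts)  # (N, 3) colors, (N, 1) densities

rgb, opacity, n_marching, n_rendering = volumetric_rendering_pipeline(
    sigma_fn=sigma_fn,
    rgb_sigma_fn=rgb_sigma_fn,
    rays_o=rays_o,
    rays_d=rays_d,
    scene_aabb=torch.tensor([-1.5, -1.5, -1.5, 1.5, 1.5, 1.5], device=device),
    render_bkgd=torch.ones(3, device=device),
    render_step_size=1e-2,
)
print(rgb.shape, opacity.shape)  # torch.Size([128, 3]) torch.Size([128, 1])
```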
from typing import Callable

def _make_lazy_cuda(name: str) -> Callable:
    def call_cuda(*args, **kwargs):
        # pylint: disable=import-outside-toplevel
@@ -9,8 +10,14 @@ def _make_lazy_cuda(name: str) -> Callable:
    return call_cuda

ray_aabb_intersect = _make_lazy_cuda("ray_aabb_intersect")
volumetric_marching = _make_lazy_cuda("volumetric_marching")
volumetric_rendering_steps = _make_lazy_cuda("volumetric_rendering_steps")
-volumetric_rendering_weights_forward = _make_lazy_cuda("volumetric_rendering_weights_forward")
-volumetric_rendering_weights_backward = _make_lazy_cuda("volumetric_rendering_weights_backward")
\ No newline at end of file
+volumetric_rendering_weights_forward = _make_lazy_cuda(
+    "volumetric_rendering_weights_forward"
+)
+volumetric_rendering_weights_backward = _make_lazy_cuda(
+    "volumetric_rendering_weights_backward"
+)
+unpack_to_ray_indices = _make_lazy_cuda("unpack_to_ray_indices")
@@ -14,7 +14,7 @@ std::vector<torch::Tensor> volumetric_rendering_steps(
    torch::Tensor sigmas
);

-std::vector<torch::Tensor> volumetric_rendering_weights_forward(
+torch::Tensor volumetric_rendering_weights_forward(
    torch::Tensor packed_info,
    torch::Tensor starts,
    torch::Tensor ends,
@@ -44,6 +44,8 @@ std::vector<torch::Tensor> volumetric_marching(
    const float dt
);

+torch::Tensor unpack_to_ray_indices(const torch::Tensor packed_info);
+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("ray_aabb_intersect", &ray_aabb_intersect);
@@ -51,4 +53,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
    m.def("volumetric_rendering_steps", &volumetric_rendering_steps);
    m.def("volumetric_rendering_weights_forward", &volumetric_rendering_weights_forward);
    m.def("volumetric_rendering_weights_backward", &volumetric_rendering_weights_backward);
+    m.def("unpack_to_ray_indices", &unpack_to_ray_indices);
}
\ No newline at end of file
@@ -144,8 +144,6 @@ __global__ void marching_forward_kernel(
    const float dt,
    const int* packed_info,
    // frustum outputs
-    float* frustum_origins,
-    float* frustum_dirs,
    float* frustum_starts,
    float* frustum_ends
) {
@@ -165,8 +163,6 @@ __global__ void marching_forward_kernel(
    const float near = t_min[0], far = t_max[0];

    // locate
-    frustum_origins += base * 3;
-    frustum_dirs += base * 3;
    frustum_starts += base;
    frustum_ends += base;
@@ -182,12 +178,6 @@ __global__ void marching_forward_kernel(
        const float z = oz + t_mid * dz;
        if (grid_occupied_at(x, y, z, resx, resy, resz, aabb, occ_binary)) {
-            frustum_origins[j * 3 + 0] = ox;
-            frustum_origins[j * 3 + 1] = oy;
-            frustum_origins[j * 3 + 2] = oz;
-            frustum_dirs[j * 3 + 0] = dx;
-            frustum_dirs[j * 3 + 1] = dy;
-            frustum_dirs[j * 3 + 2] = dz;
            frustum_starts[j] = t0;
            frustum_ends[j] = t1;
            ++j;
@@ -211,6 +201,27 @@ __global__ void marching_forward_kernel(
    return;
}

+__global__ void ray_indices_kernel(
+    // input
+    const int n_rays,
+    const int* packed_info,
+    // output
+    int* ray_indices
+) {
+    CUDA_GET_THREAD_ID(i, n_rays);
+
+    // locate
+    const int base = packed_info[i * 2 + 0];   // point idx start.
+    const int steps = packed_info[i * 2 + 1];  // point idx shift.
+    if (steps == 0) return;
+
+    ray_indices += base;
+    // every sample produced by ray i stores the ray id i
+    for (int j = 0; j < steps; ++j) {
+        ray_indices[j] = i;
+    }
+}
+
std::vector<torch::Tensor> volumetric_marching(
    // rays
@@ -271,8 +282,6 @@ std::vector<torch::Tensor> volumetric_marching(
    // output frustum samples
    int total_steps = cum_steps[cum_steps.size(0) - 1].item<int>();
-    torch::Tensor frustum_origins = torch::zeros({total_steps, 3}, rays_o.options());
-    torch::Tensor frustum_dirs = torch::zeros({total_steps, 3}, rays_o.options());
    torch::Tensor frustum_starts = torch::zeros({total_steps, 1}, rays_o.options());
    torch::Tensor frustum_ends = torch::zeros({total_steps, 1}, rays_o.options());
@@ -293,12 +302,32 @@ std::vector<torch::Tensor> volumetric_marching(
        dt,
        packed_info.data_ptr<int>(),
        // outputs
-        frustum_origins.data_ptr<float>(),
-        frustum_dirs.data_ptr<float>(),
        frustum_starts.data_ptr<float>(),
        frustum_ends.data_ptr<float>()
    );

-    return {packed_info, frustum_origins, frustum_dirs, frustum_starts, frustum_ends};
+    return {packed_info, frustum_starts, frustum_ends};
+}
+
+torch::Tensor unpack_to_ray_indices(const torch::Tensor packed_info) {
+    DEVICE_GUARD(packed_info);
+    CHECK_INPUT(packed_info);
+
+    const int n_rays = packed_info.size(0);
+    const int threads = 256;
+    const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads);
+
+    // the last row of packed_info holds {start, steps} for the final ray,
+    // so its sum equals the total number of samples
+    int n_samples = packed_info[n_rays - 1].sum(0).item<int>();
+    torch::Tensor ray_indices = torch::zeros(
+        {n_samples}, packed_info.options().dtype(torch::kInt32));
+
+    ray_indices_kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
+        n_rays,
+        packed_info.data_ptr<int>(),
+        ray_indices.data_ptr<int>()
+    );
+    return ray_indices;
}
@@ -52,8 +52,7 @@ __global__ void volumetric_rendering_weights_forward_kernel(
    const scalar_t* ends,    // input end t
    const scalar_t* sigmas,  // input density after activation
    // should be all-zero initialized
-    scalar_t* weights,      // output
-    int* samples_ray_ids    // output
+    scalar_t* weights       // output
) {
    CUDA_GET_THREAD_ID(i, n_rays);
@@ -66,11 +65,6 @@ __global__ void volumetric_rendering_weights_forward_kernel(
    ends += base;
    sigmas += base;
    weights += base;
-    samples_ray_ids += base;
-
-    for (int j = 0; j < steps; ++j) {
-        samples_ray_ids[j] = i;
-    }

    // accumulated rendering
    scalar_t T = 1.f;
@@ -184,7 +178,7 @@ std::vector<torch::Tensor> volumetric_rendering_steps(
}

-std::vector<torch::Tensor> volumetric_rendering_weights_forward(
+torch::Tensor volumetric_rendering_weights_forward(
    torch::Tensor packed_info,
    torch::Tensor starts,
    torch::Tensor ends,
@@ -208,7 +202,6 @@ std::vector<torch::Tensor> volumetric_rendering_weights_forward(
    // outputs
    torch::Tensor weights = torch::zeros({n_samples}, sigmas.options());
-    torch::Tensor ray_indices = torch::zeros({n_samples}, packed_info.options());

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
        sigmas.scalar_type(),
@@ -220,12 +213,11 @@ std::vector<torch::Tensor> volumetric_rendering_weights_forward(
            starts.data_ptr<scalar_t>(),
            ends.data_ptr<scalar_t>(),
            sigmas.data_ptr<scalar_t>(),
-            weights.data_ptr<scalar_t>(),
-            ray_indices.data_ptr<int>()
+            weights.data_ptr<scalar_t>()
        );
    }));

-    return {weights, ray_indices};
+    return weights;
}
...
@@ -75,8 +75,6 @@ def volumetric_marching(
            It is a tensor with shape (n_rays, 2). For each ray, the two values \
            indicate the start index and the number of samples for this ray, \
            respectively.
-        - **frustum_origins**: Sampled frustum origins. Tensor with shape (n_samples, 3).
-        - **frustum_dirs**: Sampled frustum directions. Tensor with shape (n_samples, 3).
        - **frustum_starts**: Sampled frustum start distances along each ray. Tensor with shape (n_samples, 1).
        - **frustum_ends**: Sampled frustum end distances along each ray. Tensor with shape (n_samples, 1).
@@ -94,13 +92,7 @@ def volumetric_marching(
    if stratified:
        t_min = t_min + torch.rand_like(t_min) * render_step_size

-    (
-        packed_info,
-        frustum_origins,
-        frustum_dirs,
-        frustum_starts,
-        frustum_ends,
-    ) = nerfacc_cuda.volumetric_marching(
+    packed_info, frustum_starts, frustum_ends = nerfacc_cuda.volumetric_marching(
        # rays
        rays_o.contiguous(),
        rays_d.contiguous(),
@@ -114,13 +106,7 @@ def volumetric_marching(
        render_step_size,
    )

-    return (
-        packed_info,
-        frustum_origins,
-        frustum_dirs,
-        frustum_starts,
-        frustum_ends,
-    )
+    return packed_info, frustum_starts, frustum_ends


@torch.no_grad()
@@ -206,7 +192,6 @@ def volumetric_rendering_weights(
        A tensor containing
        - **weights**: Volumetric rendering weights for those samples. Tensor with shape (n_samples,).
-        - **ray_indices**: Ray index of each sample. IntTensor with shape (n_samples,).
    """
    if (
@@ -219,12 +204,12 @@ def volumetric_rendering_weights(
        frustum_starts = frustum_starts.contiguous()
        frustum_ends = frustum_ends.contiguous()
        sigmas = sigmas.contiguous()
-        weights, ray_indices = _volumetric_rendering_weights.apply(
+        weights = _volumetric_rendering_weights.apply(
            packed_info, frustum_starts, frustum_ends, sigmas
        )
    else:
        raise NotImplementedError("Only support cuda inputs.")
-    return weights, ray_indices
+    return weights


def volumetric_rendering_accumulate(
@@ -275,10 +260,32 @@ def volumetric_rendering_accumulate(
    return outputs
+@torch.no_grad()
+def unpack_to_ray_indices(packed_info: Tensor) -> Tensor:
+    """Unpack `packed_info` to ray indices. Useful for converting per-ray data to per-sample data.
+
+    Note: this function is not differentiable with respect to its inputs.
+
+    Args:
+        packed_info: Stores information on which samples belong to the same ray. \
+            See ``volumetric_marching`` for details. Tensor with shape (n_rays, 2).
+
+    Returns:
+        Ray index of each sample. IntTensor with shape (n_samples,).
+    """
+    if packed_info.is_cuda:
+        packed_info = packed_info.contiguous()
+        ray_indices = nerfacc_cuda.unpack_to_ray_indices(packed_info)
+    else:
+        raise NotImplementedError("Only support cuda inputs.")
+    return ray_indices
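For intuition, here is a pure-PyTorch sketch of what `unpack_to_ray_indices` computes; the real implementation dispatches to the `ray_indices_kernel` CUDA kernel shown earlier, so this reference is illustration only:

```python
import torch

def unpack_to_ray_indices_ref(packed_info: torch.Tensor) -> torch.Tensor:
    # Repeat each ray id by its sample count: ray i contributes
    # packed_info[i, 1] consecutive entries equal to i.
    n_rays = packed_info.shape[0]
    steps = packed_info[:, 1].long()
    ids = torch.arange(n_rays, device=packed_info.device)
    return torch.repeat_interleave(ids, steps).int()

# Rows are (start, steps): ray 0 has 2 samples, ray 1 none, ray 2 has 3.
packed_info = torch.tensor([[0, 2], [2, 0], [2, 3]], dtype=torch.int32)
print(unpack_to_ray_indices_ref(packed_info))  # tensor([0, 0, 2, 2, 2], dtype=torch.int32)
```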
class _volumetric_rendering_weights(torch.autograd.Function):
    @staticmethod
    def forward(ctx, packed_info, frustum_starts, frustum_ends, sigmas):
-        weights, ray_indices = nerfacc_cuda.volumetric_rendering_weights_forward(
+        weights = nerfacc_cuda.volumetric_rendering_weights_forward(
            packed_info, frustum_starts, frustum_ends, sigmas
        )
        ctx.save_for_backward(
@@ -288,10 +295,10 @@ class _volumetric_rendering_weights(torch.autograd.Function):
            sigmas,
            weights,
        )
-        return weights, ray_indices
+        return weights

    @staticmethod
-    def backward(ctx, grad_weights, _grad_ray_indices):
+    def backward(ctx, grad_weights):
        (
            packed_info,
            frustum_starts,
...
@@ -3,6 +3,7 @@ from typing import Callable, List, Optional, Tuple
import torch

from .utils import (
+    unpack_to_ray_indices,
    volumetric_marching,
    volumetric_rendering_accumulate,
    volumetric_rendering_steps,
@@ -10,9 +11,9 @@ from .utils import (
)

-def volumetric_rendering(
+def volumetric_rendering_pipeline(
    sigma_fn: Callable,
-    sigma_rgb_fn: Callable,
+    rgb_sigma_fn: Callable,
    rays_o: torch.Tensor,
    rays_d: torch.Tensor,
    scene_aabb: torch.Tensor,
@@ -23,7 +24,34 @@ def volumetric_rendering(
    near_plane: float = 0.0,
    stratified: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, int, int]:
"""Differentiable volumetric rendering.""" """Differentiable volumetric rendering pipeline.
This function is the integration of those individual functions:
- ray_aabb_intersect
- volumetric_marhcing
- volumetric_rendering_steps
- volumetric_rendering_weights
- volumetric_rendering_accumulate
Args:
sigma_fn: A function that takes in the {frustum starts (N, 1), frustum ends (N, 1), and
ray indices (N,)} and returns the post-activation sigma values (N, 1).
rgb_sigma_fn: A function that takes in the {frustum starts (N, 1), frustum ends (N, 1), and
ray indices (N,)} and returns the post-activation rgb values (N, 3) and sigma values (N, 1).
rays_o: The origin of the rays (n_rays, 3).
rays_d: The normalized direction of the rays (n_rays, 3).
scene_aabb: The scene axis-aligned bounding box {xmin, ymin, zmin, xmax, ymax, zmax}.
scene_resolution: The scene resolution (3,). Defaults to None.
scene_occ_binary: The scene occupancy binary tensor used to skip samples (n_cells,). Defaults to None.
render_bkgd: The background color (3,). Default: None.
render_step_size: The step size for the volumetric rendering. Default: 1e-3.
near_plane: The near plane for the volumetric rendering. Default: 0.0.
stratified: Whether to use stratified sampling. Default: False.
Returns:
Ray colors (n_rays, 3), and opacities (n_rays, 1), the number of marching steps, and the number of rendering steps.
"""
    n_rays = rays_o.shape[0]
    if scene_occ_binary is None:
@@ -45,13 +73,7 @@ def volumetric_rendering_pipeline(
    with torch.no_grad():
        # Ray marching and occupancy check.
-        (
-            packed_info,
-            frustum_origins,
-            frustum_dirs,
-            frustum_starts,
-            frustum_ends,
-        ) = volumetric_marching(
+        packed_info, frustum_starts, frustum_ends = volumetric_marching(
            rays_o,
            rays_d,
            aabb=scene_aabb,
@@ -62,44 +84,30 @@ def volumetric_rendering_pipeline(
            stratified=stratified,
        )
        n_marching_samples = frustum_starts.shape[0]
+        ray_indices = unpack_to_ray_indices(packed_info)

        # Query sigma without gradients
-        sigmas = sigma_fn(
-            frustum_origins,
-            frustum_dirs,
-            frustum_starts,
-            frustum_ends,
-        )
+        sigmas = sigma_fn(frustum_starts, frustum_ends, ray_indices)

        # Ray marching and rendering check.
-        (
-            packed_info,
-            frustum_starts,
-            frustum_ends,
-            frustum_origins,
-            frustum_dirs,
-        ) = volumetric_rendering_steps(
+        packed_info, frustum_starts, frustum_ends = volumetric_rendering_steps(
            packed_info,
            sigmas,
            frustum_starts,
            frustum_ends,
-            frustum_origins,
-            frustum_dirs,
        )
        n_rendering_samples = frustum_starts.shape[0]
+        ray_indices = unpack_to_ray_indices(packed_info)

    # Query sigma and color with gradients
-    rgbs, sigmas = sigma_rgb_fn(
-        frustum_origins,
-        frustum_dirs,
-        frustum_starts,
-        frustum_ends,
-    )
-    assert rgbs.shape[-1] == 3, "rgbs must have 3 channels"
-    assert sigmas.shape[-1] == 1, "sigmas must have 1 channel"
+    rgbs, sigmas = rgb_sigma_fn(frustum_starts, frustum_ends, ray_indices)
+    assert rgbs.shape[-1] == 3, "rgbs must have 3 channels, got {}".format(rgbs.shape)
+    assert sigmas.shape[-1] == 1, "sigmas must have 1 channel, got {}".format(
+        sigmas.shape
+    )

    # Rendering: compute the weights per sample.
-    weights, ray_indices = volumetric_rendering_weights(
+    weights = volumetric_rendering_weights(
        packed_info, sigmas, frustum_starts, frustum_ends
    )
...
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "nerfacc"
-version = "0.0.7"
+version = "0.0.8"
authors = [{name = "Ruilong", email = "ruilongli94@gmail.com"}]
license = { text="MIT" }
requires-python = ">=3.8"
...
import torch
import tqdm

-from nerfacc import volumetric_rendering
+from nerfacc import volumetric_rendering_pipeline

device = "cuda:0"

-def sigma_fn(frustum_origins, frustum_dirs, frustum_starts, frustum_ends):
+def sigma_fn(frustum_starts, frustum_ends, ray_indices):
    return torch.rand_like(frustum_ends[:, :1])

-def sigma_rgb_fn(frustum_origins, frustum_dirs, frustum_starts, frustum_ends):
-    return torch.rand_like(frustum_ends[:, :1]), torch.rand_like(frustum_ends[:, :3])
+def rgb_sigma_fn(frustum_starts, frustum_ends, ray_indices):
+    return torch.rand((frustum_ends.shape[0], 3), device=device), torch.rand_like(
+        frustum_ends
+    )

def test_rendering():
@@ -24,9 +26,9 @@ def test_rendering():
    render_bkgd = torch.ones(3, device=device)

    for step in tqdm.tqdm(range(1000)):
-        volumetric_rendering(
+        volumetric_rendering_pipeline(
            sigma_fn,
-            sigma_rgb_fn,
+            rgb_sigma_fn,
            rays_o,
            rays_d,
            scene_aabb,
...