Unverified commit 8340e19d, authored by Ruilong Li (李瑞龙), committed by GitHub

0.5.0: Rewrite all the underlying CUDA. Speedup and Benchmarking. (#182)

* importance_sampling with test

* package importance_sampling

* compute_intervals tested and packaged

* compute_intervals_v2

* bicycle is failing

* fix cut in compute_intervals_v2, test pass for rendering

* hacky way to get opaque_bkgd work

* reorg ING

* PackedRaySegmentsSpec

* chunk_ids -> ray_ids

* binary -> occupied

* test_traverse_grid_basic checked

* fix traverse_grid with step size, checked

* support max_step_size, not verified

* _cuda and cuda; upgrade ray_marching

* inclusive scan

* test_exclusive_sum but seems to have numeric error

* inclusive_sum_backward verified

* exclusive sum backward

* merge fwd and bwd for scan

* inclusive & exclusive prod verified

* support normal scan with torch funcs

* rendering and tests

* a bit clean up

* importance_sampling verified

* stratified for importance_sampling

* importance_sampling in pdf.py

* RaySegmentsSpec in data_specs; fix various bugs

* verified with _proposal_packed.py

* importance sampling support batch input/output. need to verify

* prop script with batch samples

* try to use cumsum  instead of cumprod

* searchsorted

* benchmarking prop

* ray_aabb_intersect untested

* update prop benchmark numbers

* minor fixes

* batched ray_aabb_intersect

* ray_aabb_intersect and traverse with grid(s)

* tiny optimize for traverse_grids kernels

* traverse_grids return intervals and samples

* cub not verified

* cleanup

* propnet and occgrid as estimators

* training print iters 10k

* prop is good now

* benchmark in google sheet.

* really cleanup: scan.py and test

* pack.py and test

* rendering and test

* data_specs.py and pdf.py docs

* data_specs.py and pdf.py docs

* init and headers

* grid.py and test for it

* occ grid docs

* generated docs

* example docs for pack and scan function.

* doc fix for volrend.py

* doc fix for pdf.py

* fix doc for rendering function

* docs

* propnet docs

* update scripts

* docs: index.rst

* methodology docs

* docs for examples

* mlp nerf script

* update t-nerf script

* rename dnerf to tnerf

* misc update

* bug fix: pdf_loss with test

* minor fix

* update readme with submodules

* fix format

* update gitingore file

* fix doc failure. teaser png to jpg

* docs in examples/
.. _`Efficient Sampling`:

Efficient Sampling
===================================

Transmittance Estimator is All You Need.
------------------------------------------

Efficient sampling is a well-explored problem in graphics, where the emphasis
is on identifying the regions that contribute most to the final rendering.
This is generally accomplished through importance sampling, which aims to
distribute samples according to the probability density function (PDF),
denoted as :math:`p(t)`, over the range :math:`[t_n, t_f]`. After computing
the cumulative distribution function (CDF) by integration, *i.e.*,
:math:`F(t) = \int_{t_n}^{t} p(v)\,dv`, samples are generated with the
inverse transform sampling method:

.. math::

    t = F^{-1}(u) \quad \text{where} \quad u \sim \mathcal{U}[0,1].
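For a discrete PDF defined over a set of intervals, this inversion reduces to
a ``searchsorted`` over the CDF. Below is a minimal PyTorch sketch of the
idea; the helper name ``sample_from_pdf`` and the per-bin linear
interpolation are illustrative, not part of the `nerfacc` API:

.. code-block:: python

    import torch

    def sample_from_pdf(bins, pdf, n_samples):
        # bins: (n + 1,) interval edges in [t_n, t_f]; pdf: (n,) non-negative weights.
        pdf = pdf / pdf.sum()  # normalize to a valid PDF
        cdf = torch.cat([pdf.new_zeros(1), torch.cumsum(pdf, dim=0)])  # F at the edges
        u = torch.rand(n_samples)  # u ~ U[0, 1]
        # invert the CDF: find the bin with cdf[i - 1] < u <= cdf[i] ...
        inds = torch.searchsorted(cdf, u, right=True).clamp(1, len(bins) - 1)
        # ... and interpolate linearly within it, i.e. t = F^{-1}(u).
        frac = (u - cdf[inds - 1]) / (cdf[inds] - cdf[inds - 1]).clamp_min(1e-10)
        return bins[inds - 1] + frac * (bins[inds] - bins[inds - 1])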
In volumetric rendering, the contribution of each sample to the final
rendering is expressed by the accumulation weights :math:`T(t)\sigma(t)`:

.. math::

    C(\mathbf{r}) = \int_{t_n}^{t_f} T(t)\,\sigma(t)\,c(t)\,dt

Hence, the PDF for volumetric rendering is :math:`p(t) = T(t)\sigma(t)`
and the CDF is:

.. math::

    F(t) = \int_{t_n}^{t} T(v)\,\sigma(v)\,dv = 1 - T(t)

Therefore, inverse sampling the CDF :math:`F(t)` is equivalent to inverse
sampling the transmittance :math:`T(t)`: a transmittance estimator is
sufficient to determine the optimal samples. We refer interested readers to
the `SIGGRAPH 2017 Course: Production Volume Rendering`_ for more details on
this concept.

Occupancy Grid Estimator.
----------------------------

.. image:: ../_static/images/illustration_occgrid.png
    :class: float-right
    :width: 200px

The idea of the Occupancy Grid is to cache the density of the scene in a
binarized voxel grid. When sampling, the ray marches through the grid with a
preset step size and skips empty regions by querying the voxel grid.
Intuitively, the binarized voxel grid is an *estimator* of the radiance
field with a much faster readout. This technique was proposed in
`Instant-NGP`_ with a highly optimized CUDA implementation.

More formally, the estimator describes a binarized density distribution
:math:`\hat{\sigma}` along the ray with a conservative threshold :math:`\tau`:

.. math::

    \hat{\sigma}(t_i) = \mathbb{1}\big[\sigma(t_i) > \tau\big]

Consequently, the piece-wise constant PDF can be expressed as

.. math::

    p(t_i) = \hat{\sigma}(t_i) \Big/ \sum_{j=1}^{n} \hat{\sigma}(t_j)

and the piece-wise linear transmittance estimator is

.. math::

    T(t_i) = 1 - \sum_{j=1}^{i-1}\hat{\sigma}(t_j) \Big/ \sum_{j=1}^{n} \hat{\sigma}(t_j)

See the figure below for an illustration.

.. rst-class:: clear-both

.. image:: ../_static/images/plot_occgrid.png
    :align: center

|

In `nerfacc`, this is implemented via the :class:`nerfacc.OccGridEstimator` class.
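A condensed usage sketch, following the occupancy-grid training loops in the
example scripts later in this commit; ``radiance_field``, ``sigma_fn``,
``rgb_sigma_fn`` and the hyper-parameters are placeholders for user code:

.. code-block:: python

    from nerfacc import OccGridEstimator
    from nerfacc.volrend import rendering

    estimator = OccGridEstimator(roi_aabb=aabb, resolution=128, levels=1).to(device)

    # keep the binarized density cache in sync with the current radiance field
    estimator.update_every_n_steps(
        step=step,
        occ_eval_fn=lambda x: radiance_field.query_density(x) * render_step_size,
        occ_thre=1e-2,
    )

    # march rays through the grid, skipping cells marked as empty
    ray_indices, t_starts, t_ends = estimator.sampling(
        rays.origins,
        rays.viewdirs,
        sigma_fn=sigma_fn,
        near_plane=near_plane,
        far_plane=far_plane,
        render_step_size=render_step_size,
    )
    rgb, opacity, depth, extras = rendering(
        t_starts,
        t_ends,
        ray_indices,
        n_rays=len(rays.origins),
        rgb_sigma_fn=rgb_sigma_fn,
        render_bkgd=render_bkgd,
    )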
Proposal Network Estimator.
-----------------------------

.. image:: ../_static/images/illustration_propnet.png
    :class: float-right
    :width: 200px

Another type of approach is to directly estimate the PDF along the ray from
discrete samples. In `vanilla NeRF`_, the coarse MLP is trained using the
volumetric rendering loss to output a set of densities
:math:`\{\sigma(t_i)\}`. This allows for the creation of a piece-wise
constant PDF:

.. math::

    p(t_i) = \sigma(t_i)\exp(-\sigma(t_i)\,dt)

and a piece-wise linear transmittance estimator:

.. math::

    T(t_i) = \exp\Big(-\sum_{j=1}^{i-1}\sigma(t_j)\,dt\Big)

This approach was further improved in `Mip-NeRF 360`_ with a PDF matching
loss, which allows a much smaller MLP, namely the Proposal Network, to be
used at the coarse level to speed up PDF construction.
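In code, this transmittance estimator is just an exclusive cumulative sum of
:math:`\sigma\,dt` followed by an exponential. A minimal sketch, assuming
uniformly spaced samples along one ray (`nerfacc` ships fused
``inclusive_sum`` / ``exclusive_sum`` ops for the packed, per-ray case):

.. code-block:: python

    import torch

    def transmittance_from_densities(sigmas, dt):
        # T(t_i) = exp(-sum_{j<i} sigma(t_j) * dt): an exclusive cumulative sum.
        accum = torch.cumsum(sigmas * dt, dim=0)
        exclusive = torch.cat([accum.new_zeros(1), accum[:-1]])
        return torch.exp(-exclusive)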
See the figure below for an illustration.

.. image:: ../_static/images/plot_propnet.png
    :align: center

|

In `nerfacc`, this is implemented via the :class:`nerfacc.PropNetEstimator` class.
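A condensed usage sketch, following the proposal-network training loop shown
in the diffs later in this commit; the density/color closures and the
hyper-parameters are placeholders for user code:

.. code-block:: python

    from nerfacc import PropNetEstimator
    from nerfacc.volrend import rendering

    estimator = PropNetEstimator(prop_optimizer, prop_scheduler).to(device)

    # hierarchical sampling: each proposal network refines the intervals
    t_starts, t_ends = estimator.sampling(
        prop_sigma_fns=prop_sigma_fns,      # one density closure per proposal net
        prop_samples=num_samples_per_prop,  # samples drawn at each proposal level
        num_samples=num_samples,            # final samples for the radiance field
        n_rays=rays.origins.shape[0],
        near_plane=near_plane,
        far_plane=far_plane,
        sampling_type=sampling_type,
        stratified=radiance_field.training,
        requires_grad=proposal_requires_grad,
    )
    rgb, opacity, depth, extras = rendering(
        t_starts,
        t_ends,
        ray_indices=None,
        n_rays=None,
        rgb_sigma_fn=rgb_sigma_fn,
        render_bkgd=render_bkgd,
    )
    # the PDF matching loss is applied through the estimator itself
    estimator.update_every_n_steps(
        extras["trans"], proposal_requires_grad, loss_scaler=1024
    )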
Which Estimator to use?
-----------------------

- :class:`nerfacc.OccGridEstimator` is generally more efficient when most of the space in the scene is empty, as in the `NeRF-Synthetic`_ dataset. However, it still places samples within occupied but occluded areas that contribute little to the final rendering (e.g., the last sample in the illustration above).
- :class:`nerfacc.PropNetEstimator` generally provides a more accurate transmittance estimation, which allows samples to concentrate on high-contribution areas (e.g., surfaces) and to spread out more in both empty and occluded regions. This method also works nicely on unbounded scenes, as it does not require a preset bounding box of the scene. Thus, datasets such as `Mip-NeRF 360`_ are better suited to this estimator.

.. .. currentmodule:: nerfacc

.. .. autoclass:: OccGridEstimator
..    :members:

.. .. autoclass:: PropNetEstimator
..    :members:

.. _`SIGGRAPH 2017 Course: Production Volume Rendering`: https://graphics.pixar.com/library/ProductionVolumeRendering/paper.pdf
.. _`Instant-NGP`: https://arxiv.org/abs/2201.05989
.. _`Mip-NeRF 360`: https://arxiv.org/abs/2111.12077
.. _`vanilla NeRF`: https://arxiv.org/abs/2003.08934
.. _`NeRF-Synthetic`: https://arxiv.org/abs/2003.08934
@@ -245,7 +245,7 @@ class VanillaNeRFRadianceField(nn.Module):
         return torch.sigmoid(rgb), F.relu(sigma)


-class DNeRFRadianceField(nn.Module):
+class TNeRFRadianceField(nn.Module):
     def __init__(self) -> None:
         super().__init__()
         self.posi_encoder = SinusoidalEncoder(3, 0, 4, True)
...
"""
Copyright (c) 2022 Ruilong Li, UC Berkeley.
"""
import argparse
import math
import pathlib
import time
import imageio
import numpy as np
import torch
import torch.nn.functional as F
import tqdm
from datasets.dnerf_synthetic import SubjectLoader
from radiance_fields.mlp import DNeRFRadianceField
from utils import render_image, set_random_seed
from nerfacc import OccupancyGrid
device = "cuda:0"
set_random_seed(42)
parser = argparse.ArgumentParser()
parser.add_argument(
"--data_root",
type=str,
default=str(pathlib.Path.cwd() / "data/dnerf"),
help="the root dir of the dataset",
)
parser.add_argument(
"--train_split",
type=str,
default="train",
choices=["train"],
help="which train split to use",
)
parser.add_argument(
"--scene",
type=str,
default="lego",
choices=[
# dnerf
"bouncingballs",
"hellwarrior",
"hook",
"jumpingjacks",
"lego",
"mutant",
"standup",
"trex",
],
help="which scene to use",
)
parser.add_argument(
"--aabb",
type=lambda s: [float(item) for item in s.split(",")],
default="-1.5,-1.5,-1.5,1.5,1.5,1.5",
help="delimited list input",
)
parser.add_argument(
"--test_chunk_size",
type=int,
default=8192,
)
parser.add_argument("--cone_angle", type=float, default=0.0)
args = parser.parse_args()
render_n_samples = 1024
# setup the scene bounding box.
scene_aabb = torch.tensor(args.aabb, dtype=torch.float32, device=device)
near_plane = None
far_plane = None
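# set the marching step size so that a ray crossing the scene's longest
# diagonal (max extent * sqrt(3)) is covered by roughly render_n_samples steps.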
render_step_size = (
    (scene_aabb[3:] - scene_aabb[:3]).max() * math.sqrt(3) / render_n_samples
).item()

# setup the radiance field we want to train.
max_steps = 30000
grad_scaler = torch.cuda.amp.GradScaler(1)
radiance_field = DNeRFRadianceField().to(device)
optimizer = torch.optim.Adam(radiance_field.parameters(), lr=5e-4)
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer,
    milestones=[
        max_steps // 2,
        max_steps * 3 // 4,
        max_steps * 5 // 6,
        max_steps * 9 // 10,
    ],
    gamma=0.33,
)

# setup the dataset
target_sample_batch_size = 1 << 16
grid_resolution = 128

train_dataset = SubjectLoader(
    subject_id=args.scene,
    root_fp=args.data_root,
    split=args.train_split,
    num_rays=target_sample_batch_size // render_n_samples,
    device=device,
)
test_dataset = SubjectLoader(
    subject_id=args.scene,
    root_fp=args.data_root,
    split="test",
    num_rays=None,
    device=device,
)

occupancy_grid = OccupancyGrid(
    roi_aabb=args.aabb, resolution=grid_resolution
).to(device)

# training
step = 0
tic = time.time()
for epoch in range(10000000):
    for i in range(len(train_dataset)):
        radiance_field.train()
        data = train_dataset[i]

        render_bkgd = data["color_bkgd"]
        rays = data["rays"]
        pixels = data["pixels"]
        timestamps = data["timestamps"]

        # update occupancy grid
        occupancy_grid.every_n_step(
            step=step,
            occ_eval_fn=lambda x: radiance_field.query_opacity(
                x, timestamps, render_step_size
            ),
        )

        # render
        rgb, acc, depth, n_rendering_samples = render_image(
            radiance_field,
            occupancy_grid,
            rays,
            scene_aabb,
            # rendering options
            near_plane=near_plane,
            far_plane=far_plane,
            render_step_size=render_step_size,
            render_bkgd=render_bkgd,
            cone_angle=args.cone_angle,
            alpha_thre=0.01 if step > 1000 else 0.00,
            # dnerf options
            timestamps=timestamps,
        )
        if n_rendering_samples == 0:
            continue

        # dynamic batch size for rays to keep sample batch size constant.
        num_rays = len(pixels)
        num_rays = int(
            num_rays * (target_sample_batch_size / float(n_rendering_samples))
        )
        train_dataset.update_num_rays(num_rays)
        alive_ray_mask = acc.squeeze(-1) > 0

        # compute loss
        loss = F.smooth_l1_loss(rgb[alive_ray_mask], pixels[alive_ray_mask])

        optimizer.zero_grad()
        # do not unscale it because we are using Adam.
        grad_scaler.scale(loss).backward()
        optimizer.step()
        scheduler.step()

        if step % 5000 == 0:
            elapsed_time = time.time() - tic
            loss = F.mse_loss(rgb[alive_ray_mask], pixels[alive_ray_mask])
            print(
                f"elapsed_time={elapsed_time:.2f}s | step={step} | "
                f"loss={loss:.5f} | "
                f"alive_ray_mask={alive_ray_mask.long().sum():d} | "
                f"n_rendering_samples={n_rendering_samples:d} | num_rays={len(pixels):d} |"
            )

        if step > 0 and step % max_steps == 0:
            # evaluation
            radiance_field.eval()

            psnrs = []
            with torch.no_grad():
                for i in tqdm.tqdm(range(len(test_dataset))):
                    data = test_dataset[i]
                    render_bkgd = data["color_bkgd"]
                    rays = data["rays"]
                    pixels = data["pixels"]
                    timestamps = data["timestamps"]

                    # rendering
                    rgb, acc, depth, _ = render_image(
                        radiance_field,
                        occupancy_grid,
                        rays,
                        scene_aabb,
                        # rendering options
                        near_plane=None,
                        far_plane=None,
                        render_step_size=render_step_size,
                        render_bkgd=render_bkgd,
                        cone_angle=args.cone_angle,
                        alpha_thre=0.01,
                        # test options
                        test_chunk_size=args.test_chunk_size,
                        # dnerf options
                        timestamps=timestamps,
                    )
                    mse = F.mse_loss(rgb, pixels)
                    psnr = -10.0 * torch.log(mse) / np.log(10.0)
                    psnrs.append(psnr.item())
                    # imageio.imwrite(
                    #     "acc_binary_test.png",
                    #     ((acc > 0).float().cpu().numpy() * 255).astype(np.uint8),
                    # )
                    # imageio.imwrite(
                    #     "rgb_test.png",
                    #     (rgb.cpu().numpy() * 255).astype(np.uint8),
                    # )
                    # break
            psnr_avg = sum(psnrs) / len(psnrs)
            print(f"evaluation: psnr_avg={psnr_avg}")
            train_dataset.training = True

        if step == max_steps:
            print("training stops")
            exit()

        step += 1
@@ -3,7 +3,6 @@ Copyright (c) 2022 Ruilong Li, UC Berkeley.
 """

 import argparse
-import math
 import pathlib
 import time

@@ -12,15 +11,16 @@ import numpy as np
 import torch
 import torch.nn.functional as F
 import tqdm
+from datasets.nerf_synthetic import SubjectLoader
+from lpips import LPIPS
 from radiance_fields.mlp import VanillaNeRFRadianceField
-from utils import (
-    MIPNERF360_UNBOUNDED_SCENES,
+
+from examples.utils import (
     NERF_SYNTHETIC_SCENES,
-    render_image,
+    render_image_with_occgrid,
     set_random_seed,
 )
-
-from nerfacc import ContractionType, OccupancyGrid
+from nerfacc.estimators.occ_grid import OccGridEstimator

 device = "cuda:0"
 set_random_seed(42)

@@ -35,7 +35,7 @@ parser.add_argument(
 parser.add_argument(
     "--train_split",
     type=str,
-    default="trainval",
+    default="train",
     choices=["train", "trainval"],
     help="which train split to use",
 )

@@ -49,89 +49,51 @@ parser.add_argument(
     "--scene",
     type=str,
     default="lego",
-    choices=NERF_SYNTHETIC_SCENES + MIPNERF360_UNBOUNDED_SCENES,
+    choices=NERF_SYNTHETIC_SCENES,
     help="which scene to use",
 )
-parser.add_argument(
-    "--aabb",
-    type=lambda s: [float(item) for item in s.split(",")],
-    default="-1.5,-1.5,-1.5,1.5,1.5,1.5",
-    help="delimited list input",
-)
 parser.add_argument(
     "--test_chunk_size",
     type=int,
-    default=8192,
+    default=4096,
 )
-parser.add_argument(
-    "--unbounded",
-    action="store_true",
-    help="whether to use unbounded rendering",
-)
-parser.add_argument("--cone_angle", type=float, default=0.0)
 args = parser.parse_args()

-render_n_samples = 1024
+# training parameters
+max_steps = 50000
+init_batch_size = 1024
+target_sample_batch_size = 1 << 16
+# scene parameters
+aabb = torch.tensor([-1.5, -1.5, -1.5, 1.5, 1.5, 1.5], device=device)
+near_plane = 0.0
+far_plane = 1.0e10
+# model parameters
+grid_resolution = 128
+grid_nlvl = 1
+# render parameters
+render_step_size = 5e-3

 # setup the dataset
-train_dataset_kwargs = {}
-test_dataset_kwargs = {}
-if args.scene in MIPNERF360_UNBOUNDED_SCENES:
-    from datasets.nerf_360_v2 import SubjectLoader
-
-    print("Using unbounded rendering")
-    target_sample_batch_size = 1 << 16
-    train_dataset_kwargs["color_bkgd_aug"] = "random"
-    train_dataset_kwargs["factor"] = 4
-    test_dataset_kwargs["factor"] = 4
-    grid_resolution = 128
-elif args.scene in NERF_SYNTHETIC_SCENES:
-    from datasets.nerf_synthetic import SubjectLoader
-
-    target_sample_batch_size = 1 << 16
-    grid_resolution = 128
-
 train_dataset = SubjectLoader(
     subject_id=args.scene,
     root_fp=args.data_root,
     split=args.train_split,
-    num_rays=target_sample_batch_size // render_n_samples,
+    num_rays=init_batch_size,
     device=device,
-    **train_dataset_kwargs,
 )
 test_dataset = SubjectLoader(
     subject_id=args.scene,
     root_fp=args.data_root,
     split="test",
     num_rays=None,
     device=device,
-    **test_dataset_kwargs,
 )

-if args.unbounded:
-    print("Using unbounded rendering")
-    contraction_type = ContractionType.UN_BOUNDED_SPHERE
-    scene_aabb = None
-    near_plane = 0.2
-    far_plane = 1e4
-    render_step_size = 1e-2
-else:
-    contraction_type = ContractionType.AABB
-    scene_aabb = torch.tensor(args.aabb, dtype=torch.float32, device=device)
-    near_plane = None
-    far_plane = None
-    render_step_size = (
-        (scene_aabb[3:] - scene_aabb[:3]).max()
-        * math.sqrt(3)
-        / render_n_samples
-    ).item()
+estimator = OccGridEstimator(
+    roi_aabb=aabb, resolution=grid_resolution, levels=grid_nlvl
+).to(device)

 # setup the radiance field we want to train.
-max_steps = 50000
-grad_scaler = torch.cuda.amp.GradScaler(1)
 radiance_field = VanillaNeRFRadianceField().to(device)
 optimizer = torch.optim.Adam(radiance_field.parameters(), lr=5e-4)
 scheduler = torch.optim.lr_scheduler.MultiStepLR(

@@ -145,144 +107,137 @@ scheduler = torch.optim.lr_scheduler.MultiStepLR(
     gamma=0.33,
 )
-occupancy_grid = OccupancyGrid(
-    roi_aabb=args.aabb,
-    resolution=grid_resolution,
-    contraction_type=contraction_type,
-).to(device)
+lpips_net = LPIPS(net="vgg").to(device)
+lpips_norm_fn = lambda x: x[None, ...].permute(0, 3, 1, 2) * 2 - 1
+lpips_fn = lambda x, y: lpips_net(lpips_norm_fn(x), lpips_norm_fn(y)).mean()

 if args.model_path is not None:
     checkpoint = torch.load(args.model_path)
     radiance_field.load_state_dict(checkpoint["radiance_field_state_dict"])
     optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
     scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
-    occupancy_grid.load_state_dict(checkpoint["occupancy_grid_state_dict"])
+    estimator.load_state_dict(checkpoint["estimator_state_dict"])
     step = checkpoint["step"]
 else:
     step = 0

 # training
-step = 0
 tic = time.time()
-for epoch in range(10000000):
-    for i in range(len(train_dataset)):
-        radiance_field.train()
-        data = train_dataset[i]
-
-        render_bkgd = data["color_bkgd"]
-        rays = data["rays"]
-        pixels = data["pixels"]
-
-        # update occupancy grid
-        occupancy_grid.every_n_step(
-            step=step,
-            occ_eval_fn=lambda x: radiance_field.query_opacity(
-                x, render_step_size
-            ),
-        )
-
-        # render
-        rgb, acc, depth, n_rendering_samples = render_image(
-            radiance_field,
-            occupancy_grid,
-            rays,
-            scene_aabb,
-            # rendering options
-            near_plane=near_plane,
-            far_plane=far_plane,
-            render_step_size=render_step_size,
-            render_bkgd=render_bkgd,
-            cone_angle=args.cone_angle,
-        )
-        if n_rendering_samples == 0:
-            continue
+for step in range(max_steps + 1):
+    radiance_field.train()
+    estimator.train()
+
+    i = torch.randint(0, len(train_dataset), (1,)).item()
+    data = train_dataset[i]
+
+    render_bkgd = data["color_bkgd"]
+    rays = data["rays"]
+    pixels = data["pixels"]
+
+    def occ_eval_fn(x):
+        density = radiance_field.query_density(x)
+        return density * render_step_size
+
+    # update occupancy grid
+    estimator.update_every_n_steps(
+        step=step,
+        occ_eval_fn=occ_eval_fn,
+        occ_thre=1e-2,
+    )
+
+    # render
+    rgb, acc, depth, n_rendering_samples = render_image_with_occgrid(
+        radiance_field,
+        estimator,
+        rays,
+        # rendering options
+        near_plane=near_plane,
+        render_step_size=render_step_size,
+        render_bkgd=render_bkgd,
+    )
+    if n_rendering_samples == 0:
+        continue
+
+    if target_sample_batch_size > 0:
         # dynamic batch size for rays to keep sample batch size constant.
         num_rays = len(pixels)
         num_rays = int(
             num_rays * (target_sample_batch_size / float(n_rendering_samples))
         )
         train_dataset.update_num_rays(num_rays)
-        alive_ray_mask = acc.squeeze(-1) > 0

-        # compute loss
-        loss = F.smooth_l1_loss(rgb[alive_ray_mask], pixels[alive_ray_mask])
+    # compute loss
+    loss = F.smooth_l1_loss(rgb, pixels)

-        optimizer.zero_grad()
-        # do not unscale it because we are using Adam.
-        grad_scaler.scale(loss).backward()
-        optimizer.step()
-        scheduler.step()
+    optimizer.zero_grad()
+    loss.backward()
+    optimizer.step()
+    scheduler.step()

-        if step % 5000 == 0:
-            elapsed_time = time.time() - tic
-            loss = F.mse_loss(rgb[alive_ray_mask], pixels[alive_ray_mask])
-            psnr = -10.0 * torch.log(loss) / np.log(10.0)
-            print(
-                f"elapsed_time={elapsed_time:.2f}s | step={step} | "
-                f"loss={loss:.5f} | "
-                f"alive_ray_mask={alive_ray_mask.long().sum():d} | "
-                f"n_rendering_samples={n_rendering_samples:d} | num_rays={len(pixels):d} | "
-                f"psnr={psnr:.2f}"
-            )
+    if step % 5000 == 0:
+        elapsed_time = time.time() - tic
+        loss = F.mse_loss(rgb, pixels)
+        psnr = -10.0 * torch.log(loss) / np.log(10.0)
+        print(
+            f"elapsed_time={elapsed_time:.2f}s | step={step} | "
+            f"loss={loss:.5f} | psnr={psnr:.2f} | "
+            f"n_rendering_samples={n_rendering_samples:d} | num_rays={len(pixels):d} | "
+            f"max_depth={depth.max():.3f} | "
+        )

-        if step > 0 and step % max_steps == 0:
-            model_save_path = str(pathlib.Path.cwd() / f"mlp_nerf_{step}")
-            torch.save(
-                {
-                    "step": step,
-                    "radiance_field_state_dict": radiance_field.state_dict(),
-                    "optimizer_state_dict": optimizer.state_dict(),
-                    "scheduler_state_dict": scheduler.state_dict(),
-                    "occupancy_grid_state_dict": occupancy_grid.state_dict(),
-                },
-                model_save_path,
-            )
-            # evaluation
-            radiance_field.eval()
-            psnrs = []
-            with torch.no_grad():
-                for i in tqdm.tqdm(range(len(test_dataset))):
-                    data = test_dataset[i]
-                    render_bkgd = data["color_bkgd"]
-                    rays = data["rays"]
-                    pixels = data["pixels"]
-                    # rendering
-                    rgb, acc, depth, _ = render_image(
-                        radiance_field,
-                        occupancy_grid,
-                        rays,
-                        scene_aabb,
-                        # rendering options
-                        near_plane=near_plane,
-                        far_plane=far_plane,
-                        render_step_size=render_step_size,
-                        render_bkgd=render_bkgd,
-                        cone_angle=args.cone_angle,
-                        # test options
-                        test_chunk_size=args.test_chunk_size,
-                    )
-                    mse = F.mse_loss(rgb, pixels)
-                    psnr = -10.0 * torch.log(mse) / np.log(10.0)
-                    psnrs.append(psnr.item())
-                    # imageio.imwrite(
-                    #     "acc_binary_test.png",
-                    #     ((acc > 0).float().cpu().numpy() * 255).astype(np.uint8),
-                    # )
-                    # imageio.imwrite(
-                    #     f"rgb_test_{i}.png",
-                    #     (rgb.cpu().numpy() * 255).astype(np.uint8),
-                    # )
-                    # break
-            psnr_avg = sum(psnrs) / len(psnrs)
-            print(f"evaluation: psnr_avg={psnr_avg}")
-            train_dataset.training = True
-
-        if step == max_steps:
-            print("training stops")
-            exit()
-
-        step += 1
+    if step > 0 and step % max_steps == 0:
+        model_save_path = str(pathlib.Path.cwd() / f"mlp_nerf_{step}")
+        torch.save(
+            {
+                "step": step,
+                "radiance_field_state_dict": radiance_field.state_dict(),
+                "optimizer_state_dict": optimizer.state_dict(),
+                "scheduler_state_dict": scheduler.state_dict(),
+                "estimator_state_dict": estimator.state_dict(),
+            },
+            model_save_path,
+        )
+
+        # evaluation
+        radiance_field.eval()
+        estimator.eval()
+
+        psnrs = []
+        lpips = []
+        with torch.no_grad():
+            for i in tqdm.tqdm(range(len(test_dataset))):
+                data = test_dataset[i]
+                render_bkgd = data["color_bkgd"]
+                rays = data["rays"]
+                pixels = data["pixels"]
+
+                # rendering
+                rgb, acc, depth, _ = render_image_with_occgrid(
+                    radiance_field,
+                    estimator,
+                    rays,
+                    # rendering options
+                    near_plane=near_plane,
+                    render_step_size=render_step_size,
+                    render_bkgd=render_bkgd,
+                    # test options
+                    test_chunk_size=args.test_chunk_size,
+                )
+                mse = F.mse_loss(rgb, pixels)
+                psnr = -10.0 * torch.log(mse) / np.log(10.0)
+                psnrs.append(psnr.item())
+                lpips.append(lpips_fn(rgb, pixels).item())
+                # if i == 0:
+                #     imageio.imwrite(
+                #         "rgb_test.png",
+                #         (rgb.cpu().numpy() * 255).astype(np.uint8),
+                #     )
+                #     imageio.imwrite(
+                #         "rgb_error.png",
+                #         (
+                #             (rgb - pixels).norm(dim=-1).cpu().numpy() * 255
+                #         ).astype(np.uint8),
+                #     )
+        psnr_avg = sum(psnrs) / len(psnrs)
+        lpips_avg = sum(lpips) / len(lpips)
+        print(f"evaluation: psnr_avg={psnr_avg}, lpips_avg={lpips_avg}")
"""
Copyright (c) 2022 Ruilong Li, UC Berkeley.
"""
import argparse
import math
import pathlib
import time
import imageio
import numpy as np
import torch
import torch.nn.functional as F
import tqdm
from datasets.dnerf_synthetic import SubjectLoader
from lpips import LPIPS
from radiance_fields.mlp import TNeRFRadianceField
from examples.utils import render_image_with_occgrid, set_random_seed
from nerfacc.estimators.occ_grid import OccGridEstimator
device = "cuda:0"
set_random_seed(42)
parser = argparse.ArgumentParser()
parser.add_argument(
"--data_root",
type=str,
default=str(pathlib.Path.cwd() / "data/dnerf"),
help="the root dir of the dataset",
)
parser.add_argument(
"--train_split",
type=str,
default="train",
choices=["train"],
help="which train split to use",
)
parser.add_argument(
"--scene",
type=str,
default="lego",
choices=[
# dnerf
"bouncingballs",
"hellwarrior",
"hook",
"jumpingjacks",
"lego",
"mutant",
"standup",
"trex",
],
help="which scene to use",
)
parser.add_argument(
"--test_chunk_size",
type=int,
default=4096,
)
args = parser.parse_args()
# training parameters
max_steps = 30000
init_batch_size = 1024
target_sample_batch_size = 1 << 16
# scene parameters
aabb = torch.tensor([-1.5, -1.5, -1.5, 1.5, 1.5, 1.5], device=device)
near_plane = 0.0
far_plane = 1.0e10
# model parameters
grid_resolution = 128
grid_nlvl = 1
# render parameters
render_step_size = 5e-3
# setup the dataset
train_dataset = SubjectLoader(
subject_id=args.scene,
root_fp=args.data_root,
split=args.train_split,
num_rays=init_batch_size,
device=device,
)
test_dataset = SubjectLoader(
subject_id=args.scene,
root_fp=args.data_root,
split="test",
num_rays=None,
device=device,
)
estimator = OccGridEstimator(
roi_aabb=aabb, resolution=grid_resolution, levels=grid_nlvl
).to(device)
# setup the radiance field we want to train.
radiance_field = TNeRFRadianceField().to(device)
optimizer = torch.optim.Adam(radiance_field.parameters(), lr=5e-4)
scheduler = torch.optim.lr_scheduler.MultiStepLR(
optimizer,
milestones=[
max_steps // 2,
max_steps * 3 // 4,
max_steps * 5 // 6,
max_steps * 9 // 10,
],
gamma=0.33,
)
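# LPIPS expects NCHW inputs scaled to [-1, 1]; the rendered images are HWC in
# [0, 1], hence the permute and rescale in the normalization lambda below.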
lpips_net = LPIPS(net="vgg").to(device)
lpips_norm_fn = lambda x: x[None, ...].permute(0, 3, 1, 2) * 2 - 1
lpips_fn = lambda x, y: lpips_net(lpips_norm_fn(x), lpips_norm_fn(y)).mean()

# training
tic = time.time()
for step in range(max_steps + 1):
    radiance_field.train()
    estimator.train()

    i = torch.randint(0, len(train_dataset), (1,)).item()
    data = train_dataset[i]

    render_bkgd = data["color_bkgd"]
    rays = data["rays"]
    pixels = data["pixels"]
    timestamps = data["timestamps"]

    # update occupancy grid
    estimator.update_every_n_steps(
        step=step,
        occ_eval_fn=lambda x: radiance_field.query_opacity(
            x, timestamps, render_step_size
        ),
        occ_thre=1e-2,
    )

    # render
    rgb, acc, depth, n_rendering_samples = render_image_with_occgrid(
        radiance_field,
        estimator,
        rays,
        # rendering options
        near_plane=near_plane,
        render_step_size=render_step_size,
        render_bkgd=render_bkgd,
        alpha_thre=0.01 if step > 1000 else 0.00,
        # t-nerf options
        timestamps=timestamps,
    )
    if n_rendering_samples == 0:
        continue

    if target_sample_batch_size > 0:
        # dynamic batch size for rays to keep sample batch size constant.
        num_rays = len(pixels)
        num_rays = int(
            num_rays * (target_sample_batch_size / float(n_rendering_samples))
        )
        train_dataset.update_num_rays(num_rays)

    # compute loss
    loss = F.smooth_l1_loss(rgb, pixels)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()

    if step % 5000 == 0:
        elapsed_time = time.time() - tic
        loss = F.mse_loss(rgb, pixels)
        psnr = -10.0 * torch.log(loss) / np.log(10.0)
        print(
            f"elapsed_time={elapsed_time:.2f}s | step={step} | "
            f"loss={loss:.5f} | psnr={psnr:.2f} | "
            f"n_rendering_samples={n_rendering_samples:d} | num_rays={len(pixels):d} | "
            f"max_depth={depth.max():.3f} | "
        )

    if step > 0 and step % max_steps == 0:
        # evaluation
        radiance_field.eval()
        estimator.eval()

        psnrs = []
        lpips = []
        with torch.no_grad():
            for i in tqdm.tqdm(range(len(test_dataset))):
                data = test_dataset[i]
                render_bkgd = data["color_bkgd"]
                rays = data["rays"]
                pixels = data["pixels"]
                timestamps = data["timestamps"]

                # rendering
                rgb, acc, depth, _ = render_image_with_occgrid(
                    radiance_field,
                    estimator,
                    rays,
                    # rendering options
                    near_plane=near_plane,
                    render_step_size=render_step_size,
                    render_bkgd=render_bkgd,
                    alpha_thre=0.01,
                    # test options
                    test_chunk_size=args.test_chunk_size,
                    # t-nerf options
                    timestamps=timestamps,
                )
                mse = F.mse_loss(rgb, pixels)
                psnr = -10.0 * torch.log(mse) / np.log(10.0)
                psnrs.append(psnr.item())
                lpips.append(lpips_fn(rgb, pixels).item())
                # if i == 0:
                #     imageio.imwrite(
                #         "rgb_test.png",
                #         (rgb.cpu().numpy() * 255).astype(np.uint8),
                #     )
                #     imageio.imwrite(
                #         "rgb_error.png",
                #         (
                #             (rgb - pixels).norm(dim=-1).cpu().numpy() * 255
                #         ).astype(np.uint8),
                #     )
        psnr_avg = sum(psnrs) / len(psnrs)
        lpips_avg = sum(lpips) / len(lpips)
        print(f"evaluation: psnr_avg={psnr_avg}, lpips_avg={lpips_avg}")
@@ -14,15 +14,14 @@ import torch.nn.functional as F
 import tqdm
 from lpips import LPIPS
 from radiance_fields.ngp import NGPRadianceField
-from utils import (
+
+from examples.utils import (
     MIPNERF360_UNBOUNDED_SCENES,
     NERF_SYNTHETIC_SCENES,
-    enlarge_aabb,
-    render_image,
+    render_image_with_occgrid,
     set_random_seed,
 )
-
-from nerfacc import OccupancyGrid
+from nerfacc.estimators.occ_grid import OccGridEstimator

 parser = argparse.ArgumentParser()
 parser.add_argument(

@@ -66,8 +65,8 @@ if args.scene in MIPNERF360_UNBOUNDED_SCENES:
     weight_decay = 0.0
     # scene parameters
     aabb = torch.tensor([-1.0, -1.0, -1.0, 1.0, 1.0, 1.0], device=device)
-    near_plane = 0.02
-    far_plane = None
+    near_plane = 0.2
+    far_plane = 1.0e10
     # dataset parameters
     train_dataset_kwargs = {"color_bkgd_aug": "random", "factor": 4}
     test_dataset_kwargs = {"factor": 4}

@@ -91,8 +90,8 @@ else:
     )
     # scene parameters
     aabb = torch.tensor([-1.5, -1.5, -1.5, 1.5, 1.5, 1.5], device=device)
-    near_plane = None
-    far_plane = None
+    near_plane = 0.0
+    far_plane = 1.0e10
     # dataset parameters
     train_dataset_kwargs = {}
     test_dataset_kwargs = {}

@@ -122,12 +121,13 @@ test_dataset = SubjectLoader(
     **test_dataset_kwargs,
 )

-# setup scene aabb
-scene_aabb = enlarge_aabb(aabb, 1 << (grid_nlvl - 1))
+estimator = OccGridEstimator(
+    roi_aabb=aabb, resolution=grid_resolution, levels=grid_nlvl
+).to(device)

 # setup the radiance field we want to train.
 grad_scaler = torch.cuda.amp.GradScaler(2**10)
-radiance_field = NGPRadianceField(aabb=scene_aabb).to(device)
+radiance_field = NGPRadianceField(aabb=estimator.aabbs[-1]).to(device)
 optimizer = torch.optim.Adam(
     radiance_field.parameters(), lr=1e-2, eps=1e-15, weight_decay=weight_decay
 )

@@ -147,10 +147,6 @@ scheduler = torch.optim.lr_scheduler.ChainedScheduler(
         ),
     ]
 )
-occupancy_grid = OccupancyGrid(
-    roi_aabb=aabb, resolution=grid_resolution, levels=grid_nlvl
-).to(device)
 lpips_net = LPIPS(net="vgg").to(device)
 lpips_norm_fn = lambda x: x[None, ...].permute(0, 3, 1, 2) * 2 - 1
 lpips_fn = lambda x, y: lpips_net(lpips_norm_fn(x), lpips_norm_fn(y)).mean()

@@ -159,6 +155,7 @@ lpips_fn = lambda x, y: lpips_net(lpips_norm_fn(x), lpips_norm_fn(y)).mean()
 tic = time.time()
 for step in range(max_steps + 1):
     radiance_field.train()
+    estimator.train()

     i = torch.randint(0, len(train_dataset), (1,)).item()
     data = train_dataset[i]

@@ -172,18 +169,17 @@ for step in range(max_steps + 1):
         return density * render_step_size

     # update occupancy grid
-    occupancy_grid.every_n_step(
+    estimator.update_every_n_steps(
         step=step,
         occ_eval_fn=occ_eval_fn,
         occ_thre=1e-2,
     )

     # render
-    rgb, acc, depth, n_rendering_samples = render_image(
+    rgb, acc, depth, n_rendering_samples = render_image_with_occgrid(
         radiance_field,
-        occupancy_grid,
+        estimator,
         rays,
-        scene_aabb=scene_aabb,
         # rendering options
         near_plane=near_plane,
         render_step_size=render_step_size,

@@ -211,7 +207,7 @@ for step in range(max_steps + 1):
     optimizer.step()
     scheduler.step()

-    if step % 5000 == 0:
+    if step % 10000 == 0:
         elapsed_time = time.time() - tic
         loss = F.mse_loss(rgb, pixels)
         psnr = -10.0 * torch.log(loss) / np.log(10.0)

@@ -225,6 +221,7 @@ for step in range(max_steps + 1):
     if step > 0 and step % max_steps == 0:
         # evaluation
         radiance_field.eval()
+        estimator.eval()
         psnrs = []
         lpips = []

@@ -236,11 +233,10 @@ for step in range(max_steps + 1):
             pixels = data["pixels"]

             # rendering
-            rgb, acc, depth, _ = render_image(
+            rgb, acc, depth, _ = render_image_with_occgrid(
                 radiance_field,
-                occupancy_grid,
+                estimator,
                 rays,
-                scene_aabb=scene_aabb,
                 # rendering options
                 near_plane=near_plane,
                 render_step_size=render_step_size,
...
""" """
Copyright (c) 2022 Ruilong Li, UC Berkeley. Copyright (c) 2022 Ruilong Li, UC Berkeley.
""" """
import argparse import argparse
import itertools import itertools
import pathlib import pathlib
import time import time
from typing import Callable
import imageio import imageio
import numpy as np import numpy as np
...@@ -14,16 +14,15 @@ import torch.nn.functional as F ...@@ -14,16 +14,15 @@ import torch.nn.functional as F
import tqdm import tqdm
from lpips import LPIPS from lpips import LPIPS
from radiance_fields.ngp import NGPDensityField, NGPRadianceField from radiance_fields.ngp import NGPDensityField, NGPRadianceField
from utils import (
from examples.utils import (
MIPNERF360_UNBOUNDED_SCENES, MIPNERF360_UNBOUNDED_SCENES,
NERF_SYNTHETIC_SCENES, NERF_SYNTHETIC_SCENES,
render_image_proposal, render_image_with_propnet,
set_random_seed, set_random_seed,
) )
from nerfacc.estimators.prop_net import (
from nerfacc.proposal import ( PropNetEstimator,
compute_prop_loss,
get_proposal_annealing_fn,
get_proposal_requires_grad_fn, get_proposal_requires_grad_fn,
) )
...@@ -146,17 +145,40 @@ test_dataset = SubjectLoader( ...@@ -146,17 +145,40 @@ test_dataset = SubjectLoader(
) )
# setup the radiance field we want to train. # setup the radiance field we want to train.
grad_scaler = torch.cuda.amp.GradScaler(2**10) prop_optimizer = torch.optim.Adam(
radiance_field = NGPRadianceField(aabb=aabb, unbounded=unbounded).to(device)
optimizer = torch.optim.Adam(
itertools.chain( itertools.chain(
radiance_field.parameters(),
*[p.parameters() for p in proposal_networks], *[p.parameters() for p in proposal_networks],
), ),
lr=1e-2, lr=1e-2,
eps=1e-15, eps=1e-15,
weight_decay=weight_decay, weight_decay=weight_decay,
) )
prop_scheduler = torch.optim.lr_scheduler.ChainedScheduler(
[
torch.optim.lr_scheduler.LinearLR(
prop_optimizer, start_factor=0.01, total_iters=100
),
torch.optim.lr_scheduler.MultiStepLR(
prop_optimizer,
milestones=[
max_steps // 2,
max_steps * 3 // 4,
max_steps * 9 // 10,
],
gamma=0.33,
),
]
)
estimator = PropNetEstimator(prop_optimizer, prop_scheduler).to(device)
grad_scaler = torch.cuda.amp.GradScaler(2**10)
radiance_field = NGPRadianceField(aabb=aabb, unbounded=unbounded).to(device)
optimizer = torch.optim.Adam(
radiance_field.parameters(),
lr=1e-2,
eps=1e-15,
weight_decay=weight_decay,
)
scheduler = torch.optim.lr_scheduler.ChainedScheduler( scheduler = torch.optim.lr_scheduler.ChainedScheduler(
[ [
torch.optim.lr_scheduler.LinearLR( torch.optim.lr_scheduler.LinearLR(
...@@ -174,7 +196,7 @@ scheduler = torch.optim.lr_scheduler.ChainedScheduler( ...@@ -174,7 +196,7 @@ scheduler = torch.optim.lr_scheduler.ChainedScheduler(
] ]
) )
proposal_requires_grad_fn = get_proposal_requires_grad_fn() proposal_requires_grad_fn = get_proposal_requires_grad_fn()
proposal_annealing_fn = get_proposal_annealing_fn() # proposal_annealing_fn = get_proposal_annealing_fn()
lpips_net = LPIPS(net="vgg").to(device) lpips_net = LPIPS(net="vgg").to(device)
lpips_norm_fn = lambda x: x[None, ...].permute(0, 3, 1, 2) * 2 - 1 lpips_norm_fn = lambda x: x[None, ...].permute(0, 3, 1, 2) * 2 - 1
...@@ -186,6 +208,7 @@ for step in range(max_steps + 1): ...@@ -186,6 +208,7 @@ for step in range(max_steps + 1):
radiance_field.train() radiance_field.train()
for p in proposal_networks: for p in proposal_networks:
p.train() p.train()
estimator.train()
i = torch.randint(0, len(train_dataset), (1,)).item() i = torch.randint(0, len(train_dataset), (1,)).item()
data = train_dataset[i] data = train_dataset[i]
...@@ -194,18 +217,13 @@ for step in range(max_steps + 1): ...@@ -194,18 +217,13 @@ for step in range(max_steps + 1):
rays = data["rays"] rays = data["rays"]
pixels = data["pixels"] pixels = data["pixels"]
proposal_requires_grad = proposal_requires_grad_fn(step)
# render # render
( rgb, acc, depth, extras = render_image_with_propnet(
rgb,
acc,
depth,
weights_per_level,
s_vals_per_level,
) = render_image_proposal(
radiance_field, radiance_field,
proposal_networks, proposal_networks,
estimator,
rays, rays,
scene_aabb=None,
# rendering options # rendering options
num_samples=num_samples, num_samples=num_samples,
num_samples_per_prop=num_samples_per_prop, num_samples_per_prop=num_samples_per_prop,
...@@ -215,14 +233,14 @@ for step in range(max_steps + 1): ...@@ -215,14 +233,14 @@ for step in range(max_steps + 1):
opaque_bkgd=opaque_bkgd, opaque_bkgd=opaque_bkgd,
render_bkgd=render_bkgd, render_bkgd=render_bkgd,
# train options # train options
proposal_requires_grad=proposal_requires_grad_fn(step), proposal_requires_grad=proposal_requires_grad,
proposal_annealing=proposal_annealing_fn(step), )
estimator.update_every_n_steps(
extras["trans"], proposal_requires_grad, loss_scaler=1024
) )
# compute loss # compute loss
loss = F.smooth_l1_loss(rgb, pixels) loss = F.smooth_l1_loss(rgb, pixels)
loss_prop = compute_prop_loss(s_vals_per_level, weights_per_level)
loss = loss + loss_prop
optimizer.zero_grad() optimizer.zero_grad()
# do not unscale it because we are using Adam. # do not unscale it because we are using Adam.
...@@ -230,7 +248,7 @@ for step in range(max_steps + 1): ...@@ -230,7 +248,7 @@ for step in range(max_steps + 1):
optimizer.step() optimizer.step()
scheduler.step() scheduler.step()
if step % 5000 == 0: if step % 10000 == 0:
elapsed_time = time.time() - tic elapsed_time = time.time() - tic
loss = F.mse_loss(rgb, pixels) loss = F.mse_loss(rgb, pixels)
psnr = -10.0 * torch.log(loss) / np.log(10.0) psnr = -10.0 * torch.log(loss) / np.log(10.0)
...@@ -246,6 +264,7 @@ for step in range(max_steps + 1): ...@@ -246,6 +264,7 @@ for step in range(max_steps + 1):
radiance_field.eval() radiance_field.eval()
for p in proposal_networks: for p in proposal_networks:
p.eval() p.eval()
estimator.eval()
psnrs = [] psnrs = []
lpips = [] lpips = []
...@@ -257,11 +276,11 @@ for step in range(max_steps + 1): ...@@ -257,11 +276,11 @@ for step in range(max_steps + 1):
pixels = data["pixels"] pixels = data["pixels"]
# rendering # rendering
rgb, acc, depth, _, _, = render_image_proposal( rgb, acc, depth, _, = render_image_with_propnet(
radiance_field, radiance_field,
proposal_networks, proposal_networks,
estimator,
rays, rays,
scene_aabb=None,
# rendering options # rendering options
num_samples=num_samples, num_samples=num_samples,
num_samples_per_prop=num_samples_per_prop, num_samples_per_prop=num_samples_per_prop,
...@@ -270,7 +289,6 @@ for step in range(max_steps + 1): ...@@ -270,7 +289,6 @@ for step in range(max_steps + 1):
sampling_type=sampling_type, sampling_type=sampling_type,
opaque_bkgd=opaque_bkgd, opaque_bkgd=opaque_bkgd,
render_bkgd=render_bkgd, render_bkgd=render_bkgd,
proposal_annealing=proposal_annealing_fn(step),
# test options # test options
test_chunk_size=args.test_chunk_size, test_chunk_size=args.test_chunk_size,
) )
...@@ -289,6 +307,7 @@ for step in range(max_steps + 1): ...@@ -289,6 +307,7 @@ for step in range(max_steps + 1):
# (rgb - pixels).norm(dim=-1).cpu().numpy() * 255 # (rgb - pixels).norm(dim=-1).cpu().numpy() * 255
# ).astype(np.uint8), # ).astype(np.uint8),
# ) # )
# break
psnr_avg = sum(psnrs) / len(psnrs) psnr_avg = sum(psnrs) / len(psnrs)
lpips_avg = sum(lpips) / len(lpips) lpips_avg = sum(lpips) / len(lpips)
print(f"evaluation: psnr_avg={psnr_avg}, lpips_avg={lpips_avg}") print(f"evaluation: psnr_avg={psnr_avg}, lpips_avg={lpips_avg}")
@@ -10,8 +10,9 @@ import torch
 from datasets.utils import Rays, namedtuple_map
 from torch.utils.data._utils.collate import collate, default_collate_fn_map

-from nerfacc import OccupancyGrid, ray_marching, rendering
-from nerfacc.proposal import rendering as rendering_proposal
+from nerfacc.estimators.occ_grid import OccGridEstimator
+from nerfacc.estimators.prop_net import PropNetEstimator
+from nerfacc.volrend import rendering

 NERF_SYNTHETIC_SCENES = [
     "chair",

@@ -40,21 +41,14 @@ def set_random_seed(seed):
     torch.manual_seed(seed)

-def enlarge_aabb(aabb, factor: float) -> torch.Tensor:
-    center = (aabb[:3] + aabb[3:]) / 2
-    extent = (aabb[3:] - aabb[:3]) / 2
-    return torch.cat([center - extent * factor, center + extent * factor])
-
-
-def render_image(
+def render_image_with_occgrid(
     # scene
     radiance_field: torch.nn.Module,
-    occupancy_grid: OccupancyGrid,
+    estimator: OccGridEstimator,
     rays: Rays,
-    scene_aabb: torch.Tensor,
     # rendering options
-    near_plane: Optional[float] = None,
-    far_plane: Optional[float] = None,
+    near_plane: float = 0.0,
+    far_plane: float = 1e10,
     render_step_size: float = 1e-3,
     render_bkgd: Optional[torch.Tensor] = None,
     cone_angle: float = 0.0,

@@ -78,7 +72,7 @@ def render_image(
     def sigma_fn(t_starts, t_ends, ray_indices):
         t_origins = chunk_rays.origins[ray_indices]
         t_dirs = chunk_rays.viewdirs[ray_indices]
-        positions = t_origins + t_dirs * (t_starts + t_ends) / 2.0
+        positions = t_origins + t_dirs * (t_starts + t_ends)[:, None] / 2.0
         if timestamps is not None:
             # dnerf
             t = (

@@ -86,13 +80,15 @@ def render_image(
                 if radiance_field.training
                 else timestamps.expand_as(positions[:, :1])
             )
-            return radiance_field.query_density(positions, t)
-        return radiance_field.query_density(positions)
+            sigmas = radiance_field.query_density(positions, t)
+        else:
+            sigmas = radiance_field.query_density(positions)
+        return sigmas.squeeze(-1)

     def rgb_sigma_fn(t_starts, t_ends, ray_indices):
         t_origins = chunk_rays.origins[ray_indices]
         t_dirs = chunk_rays.viewdirs[ray_indices]
-        positions = t_origins + t_dirs * (t_starts + t_ends) / 2.0
+        positions = t_origins + t_dirs * (t_starts + t_ends)[:, None] / 2.0
         if timestamps is not None:
             # dnerf
             t = (

@@ -100,8 +96,10 @@ def render_image(
                 if radiance_field.training
                 else timestamps.expand_as(positions[:, :1])
             )
-            return radiance_field(positions, t, t_dirs)
-        return radiance_field(positions, t_dirs)
+            rgbs, sigmas = radiance_field(positions, t, t_dirs)
+        else:
+            rgbs, sigmas = radiance_field(positions, t_dirs)
+        return rgbs, sigmas.squeeze(-1)

     results = []
     chunk = (

@@ -111,11 +109,9 @@ def render_image(
     )
     for i in range(0, num_rays, chunk):
         chunk_rays = namedtuple_map(lambda r: r[i : i + chunk], rays)
-        ray_indices, t_starts, t_ends = ray_marching(
+        ray_indices, t_starts, t_ends = estimator.sampling(
             chunk_rays.origins,
             chunk_rays.viewdirs,
-            scene_aabb=scene_aabb,
-            grid=occupancy_grid,
             sigma_fn=sigma_fn,
             near_plane=near_plane,
             far_plane=far_plane,

@@ -124,7 +120,7 @@ def render_image(
             cone_angle=cone_angle,
             alpha_thre=alpha_thre,
         )
-        rgb, opacity, depth = rendering(
+        rgb, opacity, depth, extras = rendering(
             t_starts,
             t_ends,
             ray_indices,

@@ -146,12 +142,12 @@ def render_image(
     )

-def render_image_proposal(
+def render_image_with_propnet(
     # scene
     radiance_field: torch.nn.Module,
     proposal_networks: Sequence[torch.nn.Module],
+    estimator: PropNetEstimator,
     rays: Rays,
-    scene_aabb: torch.Tensor,
     # rendering options
     num_samples: int,
     num_samples_per_prop: Sequence[int],

@@ -162,7 +158,6 @@ def render_image_proposal(
     render_bkgd: Optional[torch.Tensor] = None,
     # train options
     proposal_requires_grad: bool = False,
-    proposal_annealing: float = 1.0,
     # test options
     test_chunk_size: int = 8192,
 ):

@@ -180,16 +175,22 @@ def render_image_proposal(
     def prop_sigma_fn(t_starts, t_ends, proposal_network):
         t_origins = chunk_rays.origins[..., None, :]
         t_dirs = chunk_rays.viewdirs[..., None, :]
-        positions = t_origins + t_dirs * (t_starts + t_ends) / 2.0
-        return proposal_network(positions)
+        positions = t_origins + t_dirs * (t_starts + t_ends)[..., None] / 2.0
+        sigmas = proposal_network(positions)
+        if opaque_bkgd:
+            sigmas[..., -1, :] = torch.inf
+        return sigmas.squeeze(-1)

-    def rgb_sigma_fn(t_starts, t_ends):
+    def rgb_sigma_fn(t_starts, t_ends, ray_indices):
         t_origins = chunk_rays.origins[..., None, :]
         t_dirs = chunk_rays.viewdirs[..., None, :].repeat_interleave(
-            t_starts.shape[-2], dim=-2
+            t_starts.shape[-1], dim=-2
         )
-        positions = t_origins + t_dirs * (t_starts + t_ends) / 2.0
-        return radiance_field(positions, t_dirs)
+        positions = t_origins + t_dirs * (t_starts + t_ends)[..., None] / 2.0
+        rgb, sigmas = radiance_field(positions, t_dirs)
+        if opaque_bkgd:
+            sigmas[..., -1, :] = torch.inf
+        return rgb, sigmas.squeeze(-1)

     results = []
     chunk = (

@@ -199,29 +200,26 @@ def render_image_proposal(
     )
     for i in range(0, num_rays, chunk):
         chunk_rays = namedtuple_map(lambda r: r[i : i + chunk], rays)
-        (
-            rgb,
-            opacity,
-            depth,
-            (weights_per_level, s_vals_per_level),
-        ) = rendering_proposal(
-            rgb_sigma_fn=rgb_sigma_fn,
-            num_samples=num_samples,
+        t_starts, t_ends = estimator.sampling(
             prop_sigma_fns=[
                 lambda *args: prop_sigma_fn(*args, p) for p in proposal_networks
             ],
-            num_samples_per_prop=num_samples_per_prop,
-            rays_o=chunk_rays.origins,
-            rays_d=chunk_rays.viewdirs,
-            scene_aabb=scene_aabb,
+            prop_samples=num_samples_per_prop,
+            num_samples=num_samples,
+            n_rays=chunk_rays.origins.shape[0],
             near_plane=near_plane,
             far_plane=far_plane,
-            stratified=radiance_field.training,
             sampling_type=sampling_type,
-            opaque_bkgd=opaque_bkgd,
+            stratified=radiance_field.training,
+            requires_grad=proposal_requires_grad,
+        )
+        rgb, opacity, depth, extras = rendering(
+            t_starts,
+            t_ends,
+            ray_indices=None,
+            n_rays=None,
+            rgb_sigma_fn=rgb_sigma_fn,
             render_bkgd=render_bkgd,
-            proposal_requires_grad=proposal_requires_grad,
-            proposal_annealing=proposal_annealing,
         )
         chunk_results = [rgb, opacity, depth]
         results.append(chunk_results)

@@ -237,6 +235,5 @@ def render_image_proposal(
         colors.view((*rays_shape[:-1], -1)),
         opacities.view((*rays_shape[:-1], -1)),
         depths.view((*rays_shape[:-1], -1)),
-        weights_per_level,
-        s_vals_per_level,
+        extras,
     )
""" """
Copyright (c) 2022 Ruilong Li, UC Berkeley. Copyright (c) 2022 Ruilong Li, UC Berkeley.
""" """
from .contraction import ContractionType, contract, contract_inv from .data_specs import RayIntervals, RaySamples
from .grid import Grid, OccupancyGrid, query_grid from .estimators.occ_grid import OccGridEstimator
from .intersection import ray_aabb_intersect from .estimators.prop_net import PropNetEstimator
from .pack import pack_data, pack_info, unpack_data, unpack_info from .grid import ray_aabb_intersect, traverse_grids
from .ray_marching import ray_marching from .pack import pack_info
from .pdf import importance_sampling, searchsorted
from .scan import exclusive_prod, exclusive_sum, inclusive_prod, inclusive_sum
from .version import __version__ from .version import __version__
from .vol_rendering import ( from .volrend import (
accumulate_along_rays, accumulate_along_rays,
render_transmittance_from_alpha, render_transmittance_from_alpha,
render_transmittance_from_density, render_transmittance_from_density,
render_visibility, render_visibility_from_alpha,
render_visibility_from_density,
render_weight_from_alpha, render_weight_from_alpha,
render_weight_from_density, render_weight_from_density,
rendering, rendering,
...@@ -19,28 +22,25 @@ from .vol_rendering import ( ...@@ -19,28 +22,25 @@ from .vol_rendering import (
__all__ = [ __all__ = [
"__version__", "__version__",
# occ grid "inclusive_prod",
"Grid", "exclusive_prod",
"OccupancyGrid", "inclusive_sum",
"query_grid", "exclusive_sum",
"ContractionType", "pack_info",
# contraction "render_visibility_from_alpha",
"contract", "render_visibility_from_density",
"contract_inv",
# marching
"ray_aabb_intersect",
"ray_marching",
# rendering
"accumulate_along_rays",
"render_visibility",
"render_weight_from_alpha", "render_weight_from_alpha",
"render_weight_from_density", "render_weight_from_density",
"render_transmittance_from_density",
"render_transmittance_from_alpha", "render_transmittance_from_alpha",
"render_transmittance_from_density",
"accumulate_along_rays",
"rendering", "rendering",
# pack "importance_sampling",
"pack_data", "searchsorted",
"unpack_data", "RayIntervals",
"unpack_info", "RaySamples",
"pack_info", "ray_aabb_intersect",
"traverse_grids",
"OccGridEstimator",
"PropNetEstimator",
] ]
"""
Copyright (c) 2022 Ruilong Li, UC Berkeley.
"""
from enum import Enum
import torch
import nerfacc.cuda as _C
class ContractionType(Enum):
"""Space contraction options.
This is an enum class that describes how a :class:`nerfacc.Grid` covers the 3D space.
It is also used by :func:`nerfacc.ray_marching` to determine how to perform ray marching
within the grid.
The options in this enum class are:
Attributes:
AABB: Linearly map the region of interest :math:`[x_0, x_1]` to a
unit cube in :math:`[0, 1]`.
.. math:: f(x) = \\frac{x - x_0}{x_1 - x_0}
UN_BOUNDED_TANH: Contract an unbounded space into a unit cube in :math:`[0, 1]`
using tanh. The region of interest :math:`[x_0, x_1]` is first
mapped into :math:`[-0.5, +0.5]` before applying tanh.
.. math:: f(x) = \\frac{1}{2}(tanh(\\frac{x - x_0}{x_1 - x_0} - \\frac{1}{2}) + 1)
UN_BOUNDED_SPHERE: Contract an unbounded space into a unit sphere. Used in
`Mip-Nerf 360: Unbounded Anti-Aliased Neural Radiance Fields`_.
.. math::
f(x) =
\\begin{cases}
z(x) & ||z(x)|| \\leq 1 \\\\
(2 - \\frac{1}{||z(x)||})(\\frac{z(x)}{||z(x)||}) & ||z(x)|| > 1
\\end{cases}
.. math::
z(x) = \\frac{x - x_0}{x_1 - x_0} * 2 - 1
.. _Mip-Nerf 360\: Unbounded Anti-Aliased Neural Radiance Fields:
https://arxiv.org/abs/2111.12077
"""
AABB = 0
UN_BOUNDED_TANH = 1
UN_BOUNDED_SPHERE = 2
def to_cpp_version(self):
"""Convert to the C++ version of the enum class.
Returns:
The C++ version of the enum class.
"""
return _C.ContractionTypeGetter(self.value)
@torch.no_grad()
def contract(
x: torch.Tensor,
roi: torch.Tensor,
type: ContractionType = ContractionType.AABB,
) -> torch.Tensor:
"""Contract the space into [0, 1]^3.
Args:
x (torch.Tensor): Un-contracted points.
roi (torch.Tensor): Region of interest.
type (ContractionType): Contraction type.
Returns:
torch.Tensor: Contracted points ([0, 1]^3).
"""
ctype = type.to_cpp_version()
return _C.contract(x.contiguous(), roi.contiguous(), ctype)
@torch.no_grad()
def contract_inv(
x: torch.Tensor,
roi: torch.Tensor,
type: ContractionType = ContractionType.AABB,
) -> torch.Tensor:
"""Recover the space from [0, 1]^3 by inverse contraction.
Args:
x (torch.Tensor): Contracted points ([0, 1]^3).
roi (torch.Tensor): Region of interest.
type (ContractionType): Contraction type.
Returns:
torch.Tensor: Un-contracted points.
"""
ctype = type.to_cpp_version()
return _C.contract_inv(x.contiguous(), roi.contiguous(), ctype)
@@ -15,58 +15,26 @@ def _make_lazy_cuda_func(name: str) -> Callable:
     return call_cuda

-ContractionTypeGetter = _make_lazy_cuda_func("ContractionType")
-contract = _make_lazy_cuda_func("contract")
-contract_inv = _make_lazy_cuda_func("contract_inv")
-grid_query = _make_lazy_cuda_func("grid_query")
-ray_aabb_intersect = _make_lazy_cuda_func("ray_aabb_intersect")
-ray_marching = _make_lazy_cuda_func("ray_marching")
 is_cub_available = _make_lazy_cuda_func("is_cub_available")
-transmittance_from_sigma_forward_cub = _make_lazy_cuda_func(
-    "transmittance_from_sigma_forward_cub"
-)
-transmittance_from_sigma_backward_cub = _make_lazy_cuda_func(
-    "transmittance_from_sigma_backward_cub"
-)
-transmittance_from_alpha_forward_cub = _make_lazy_cuda_func(
-    "transmittance_from_alpha_forward_cub"
-)
-transmittance_from_alpha_backward_cub = _make_lazy_cuda_func(
-    "transmittance_from_alpha_backward_cub"
-)
-transmittance_from_sigma_forward_naive = _make_lazy_cuda_func(
-    "transmittance_from_sigma_forward_naive"
-)
-transmittance_from_sigma_backward_naive = _make_lazy_cuda_func(
-    "transmittance_from_sigma_backward_naive"
-)
-transmittance_from_alpha_forward_naive = _make_lazy_cuda_func(
-    "transmittance_from_alpha_forward_naive"
-)
-transmittance_from_alpha_backward_naive = _make_lazy_cuda_func(
-    "transmittance_from_alpha_backward_naive"
-)
-weight_from_sigma_forward_naive = _make_lazy_cuda_func(
-    "weight_from_sigma_forward_naive"
-)
-weight_from_sigma_backward_naive = _make_lazy_cuda_func(
-    "weight_from_sigma_backward_naive"
-)
-weight_from_alpha_forward_naive = _make_lazy_cuda_func(
-    "weight_from_alpha_forward_naive"
-)
-weight_from_alpha_backward_naive = _make_lazy_cuda_func(
-    "weight_from_alpha_backward_naive"
-)
-unpack_data = _make_lazy_cuda_func("unpack_data")
-unpack_info = _make_lazy_cuda_func("unpack_info")
-unpack_info_to_mask = _make_lazy_cuda_func("unpack_info_to_mask")
-pdf_readout = _make_lazy_cuda_func("pdf_readout")
-pdf_sampling = _make_lazy_cuda_func("pdf_sampling")
+
+# data specs
+MultiScaleGridSpec = _make_lazy_cuda_func("MultiScaleGridSpec")
+RaysSpec = _make_lazy_cuda_func("RaysSpec")
+RaySegmentsSpec = _make_lazy_cuda_func("RaySegmentsSpec")
+
+# grid
+ray_aabb_intersect = _make_lazy_cuda_func("ray_aabb_intersect")
+traverse_grids = _make_lazy_cuda_func("traverse_grids")
+
+# scan
+exclusive_sum_by_key = _make_lazy_cuda_func("exclusive_sum_by_key")
+inclusive_sum = _make_lazy_cuda_func("inclusive_sum")
+exclusive_sum = _make_lazy_cuda_func("exclusive_sum")
+inclusive_prod_forward = _make_lazy_cuda_func("inclusive_prod_forward")
+inclusive_prod_backward = _make_lazy_cuda_func("inclusive_prod_backward")
+exclusive_prod_forward = _make_lazy_cuda_func("exclusive_prod_forward")
+exclusive_prod_backward = _make_lazy_cuda_func("exclusive_prod_backward")
+
+# pdf
+importance_sampling = _make_lazy_cuda_func("importance_sampling")
+searchsorted = _make_lazy_cuda_func("searchsorted")
@@ -44,6 +44,9 @@ extra_cflags = ["-O3"]
 extra_cuda_cflags = ["-O3"]
 _C = None
+sources = list(glob.glob(os.path.join(PATH, "csrc/*.cu"))) + list(
+    glob.glob(os.path.join(PATH, "csrc/*.cpp"))
+)
 try:
     # try to import the compiled module (via setup.py)
@@ -57,7 +60,7 @@ except ImportError:
         _C = load(
             name=name,
-            sources=glob.glob(os.path.join(PATH, "csrc/*.cu")),
+            sources=sources,
             extra_cflags=extra_cflags,
             extra_cuda_cflags=extra_cuda_cflags,
             extra_include_paths=extra_include_paths,
@@ -72,7 +75,7 @@ except ImportError:
     ):
         _C = load(
             name=name,
-            sources=glob.glob(os.path.join(PATH, "csrc/*.cu")),
+            sources=sources,
             extra_cflags=extra_cflags,
             extra_cuda_cflags=extra_cuda_cflags,
             extra_include_paths=extra_include_paths,
/*
* Copyright (c) 2022 Ruilong Li, UC Berkeley.
*/
#include "include/helpers_cuda.h"
#include "include/helpers_math.h"
#include "include/helpers_contraction.h"
__global__ void contract_kernel(
// samples info
const uint32_t n_samples,
const float *samples, // (n_samples, 3)
// contraction
const float *roi,
const ContractionType type,
// outputs
float *out_samples)
{
CUDA_GET_THREAD_ID(i, n_samples);
// locate
samples += i * 3;
out_samples += i * 3;
const float3 roi_min = make_float3(roi[0], roi[1], roi[2]);
const float3 roi_max = make_float3(roi[3], roi[4], roi[5]);
const float3 xyz = make_float3(samples[0], samples[1], samples[2]);
float3 xyz_unit = apply_contraction(xyz, roi_min, roi_max, type);
out_samples[0] = xyz_unit.x;
out_samples[1] = xyz_unit.y;
out_samples[2] = xyz_unit.z;
return;
}
__global__ void contract_inv_kernel(
// samples info
const uint32_t n_samples,
const float *samples, // (n_samples, 3)
// contraction
const float *roi,
const ContractionType type,
// outputs
float *out_samples)
{
CUDA_GET_THREAD_ID(i, n_samples);
// locate
samples += i * 3;
out_samples += i * 3;
const float3 roi_min = make_float3(roi[0], roi[1], roi[2]);
const float3 roi_max = make_float3(roi[3], roi[4], roi[5]);
const float3 xyz_unit = make_float3(samples[0], samples[1], samples[2]);
float3 xyz = apply_contraction_inv(xyz_unit, roi_min, roi_max, type);
out_samples[0] = xyz.x;
out_samples[1] = xyz.y;
out_samples[2] = xyz.z;
return;
}
torch::Tensor contract(
const torch::Tensor samples,
// contraction
const torch::Tensor roi,
const ContractionType type)
{
DEVICE_GUARD(samples);
CHECK_INPUT(samples);
const int n_samples = samples.size(0);
const int threads = 256;
const int blocks = CUDA_N_BLOCKS_NEEDED(n_samples, threads);
torch::Tensor out_samples = torch::empty({n_samples, 3}, samples.options());
contract_kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
n_samples,
samples.data_ptr<float>(),
// contraction
roi.data_ptr<float>(),
type,
// outputs
out_samples.data_ptr<float>());
return out_samples;
}
torch::Tensor contract_inv(
const torch::Tensor samples,
// contraction
const torch::Tensor roi,
const ContractionType type)
{
DEVICE_GUARD(samples);
CHECK_INPUT(samples);
const int n_samples = samples.size(0);
const int threads = 256;
const int blocks = CUDA_N_BLOCKS_NEEDED(n_samples, threads);
torch::Tensor out_samples = torch::empty({n_samples, 3}, samples.options());
contract_inv_kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
n_samples,
samples.data_ptr<float>(),
// contraction
roi.data_ptr<float>(),
type,
// outputs
out_samples.data_ptr<float>());
return out_samples;
}
#include <torch/extension.h>
#include <ATen/NumericUtils.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAGeneratorImpl.h>
#include <ATen/cuda/CUDAGraphsUtils.cuh>
#include <c10/util/MaybeOwned.h>
#include <curand.h>
#include <curand_kernel.h>
#include <curand_philox4x32_x.h>
#include "include/data_spec.hpp"
#include "include/data_spec_packed.cuh"
#include "include/utils_cuda.cuh"
#include "include/utils_grid.cuh"
#include "include/utils_math.cuh"
static constexpr uint32_t MAX_GRID_LEVELS = 8;
namespace {
namespace device {
inline __device__ float _calc_dt(
const float t, const float cone_angle,
const float dt_min, const float dt_max)
{
return clamp(t * cone_angle, dt_min, dt_max);
}
__global__ void traverse_grids_kernel(
// rays
int32_t n_rays,
float *rays_o, // [n_rays, 3]
float *rays_d, // [n_rays, 3]
// grids
int32_t n_grids,
int3 resolution,
bool *binaries, // [n_grids, resx, resy, resz]
float *aabbs, // [n_grids, 6]
// sorted intersections
bool *hits, // [n_rays, n_grids]
float *t_sorted, // [n_rays, n_grids * 2]
int64_t *t_indices, // [n_rays, n_grids * 2]
// options
float *near_planes,
float *far_planes,
float step_size,
float cone_angle,
// outputs
bool first_pass,
PackedRaySegmentsSpec intervals,
PackedRaySegmentsSpec samples)
{
float eps = 1e-6f;
// parallelize over rays
for (int32_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < n_rays; tid += blockDim.x * gridDim.x)
{
// skip rays that are empty.
if (intervals.chunk_cnts != nullptr)
if (!first_pass && intervals.chunk_cnts[tid] == 0) continue;
if (samples.chunk_cnts != nullptr)
if (!first_pass && samples.chunk_cnts[tid] == 0) continue;
int64_t chunk_start, chunk_start_bin;
if (!first_pass) {
if (intervals.chunk_cnts != nullptr)
chunk_start = intervals.chunk_starts[tid];
if (samples.chunk_cnts != nullptr)
chunk_start_bin = samples.chunk_starts[tid];
}
float near_plane = near_planes[tid];
float far_plane = far_planes[tid];
SingleRaySpec ray = SingleRaySpec(
rays_o + tid * 3, rays_d + tid * 3, near_plane, far_plane);
int32_t base_hits = tid * n_grids;
int32_t base_t_sorted = tid * n_grids * 2;
// loop over all intersections along the ray.
int64_t n_intervals = 0;
int64_t n_samples = 0;
float t_last = near_plane;
bool continuous = false;
for (int32_t i = base_t_sorted; i < base_t_sorted + n_grids * 2 - 1; i++) {
// whether this event is the ray entering or leaving this grid level.
bool is_entering = t_indices[i] < n_grids;
int64_t level = t_indices[i] % n_grids;
// printf("i=%d, level=%lld, is_entering=%d, hits=%d\n", i, level, is_entering, hits[level]);
if (!hits[base_hits + level]) {
continue; // this grid is not hit.
}
if (!is_entering) {
// we are leaving this grid. Are we inside the next grid?
bool next_is_entering = t_indices[i + 1] < n_grids;
if (next_is_entering) continue; // we are outside next grid.
level = t_indices[i + 1] % n_grids;
if (!hits[base_hits + level]) {
continue; // this grid is not hit.
}
}
float this_tmin = fmaxf(t_sorted[i], near_plane);
float this_tmax = fminf(t_sorted[i + 1], far_plane);
if (this_tmin >= this_tmax) continue; // this interval is invalid. e.g. (0.0f, 0.0f)
// printf("i=%d, this_tmin=%f, this_tmax=%f, level=%lld\n", i, this_tmin, this_tmax, level);
if (!continuous) {
if (step_size <= 0.0f) { // march to this_tmin.
t_last = this_tmin;
} else {
float dt = _calc_dt(t_last, cone_angle, step_size, 1e10f);
while (true) { // march until t_mid is right after this_tmin.
if (t_last + dt * 0.5f >= this_tmin) break;
t_last += dt;
}
}
}
// printf(
// "[traverse segment] i=%d, this_mip=%d, this_tmin=%f, this_tmax=%f\n",
// i, this_mip, this_tmin, this_tmax);
AABBSpec aabb = AABBSpec(aabbs + level * 6);
// init: pre-compute variables needed for traversal
float3 tdist, delta;
int3 step_index, current_index, final_index;
setup_traversal(
ray, this_tmin, this_tmax, eps,
aabb, resolution,
// outputs
delta, tdist, step_index, current_index, final_index);
// printf(
// "[traverse init], delta=(%f, %f, %f), step_index=(%d, %d, %d)\n",
// delta.x, delta.y, delta.z, step_index.x, step_index.y, step_index.z
// );
const int3 overflow_index = final_index + step_index;
while (true) {
float t_traverse = min(tdist.x, min(tdist.y, tdist.z));
int64_t cell_id = (
current_index.x * resolution.y * resolution.z
+ current_index.y * resolution.z
+ current_index.z
+ level * resolution.x * resolution.y * resolution.z
);
if (!binaries[cell_id]) {
// skip the cell that is empty.
if (step_size <= 0.0f) { // march to t_traverse.
t_last = t_traverse;
} else {
float dt = _calc_dt(t_last, cone_angle, step_size, 1e10f);
while (true) { // march until t_mid is right after t_traverse.
if (t_last + dt * 0.5f >= t_traverse) break;
t_last += dt;
}
}
continuous = false;
} else {
// this cell is not empty, so we need to traverse it.
while (true) {
float t_next;
if (step_size <= 0.0f) {
t_next = t_traverse;
} else { // march until t_mid is right after t_traverse.
float dt = _calc_dt(t_last, cone_angle, step_size, 1e10f);
if (t_last + dt * 0.5f >= t_traverse) break;
t_next = t_last + dt;
}
// writeout the interval.
if (intervals.chunk_cnts != nullptr) {
if (!continuous) {
if (!first_pass) { // left side of the interval
int64_t idx = chunk_start + n_intervals;
intervals.vals[idx] = t_last;
intervals.ray_indices[idx] = tid;
intervals.is_left[idx] = true;
}
n_intervals++;
if (!first_pass) { // right side of the interval
int64_t idx = chunk_start + n_intervals;
intervals.vals[idx] = t_next;
intervals.ray_indices[idx] = tid;
intervals.is_right[idx] = true;
}
n_intervals++;
} else {
if (!first_pass) { // right side of the interval
int64_t idx = chunk_start + n_intervals;
intervals.vals[idx] = t_next;
intervals.ray_indices[idx] = tid;
intervals.is_left[idx - 1] = true;
intervals.is_right[idx] = true;
}
n_intervals++;
}
}
// writeout the sample.
if (samples.chunk_cnts != nullptr) {
if (!first_pass) {
int64_t idx = chunk_start_bin + n_samples;
samples.vals[idx] = (t_next + t_last) * 0.5f;
samples.ray_indices[idx] = tid;
}
n_samples++;
}
continuous = true;
t_last = t_next;
if (t_next >= t_traverse) break;
}
}
// printf(
// "[traverse], t_last=%f, t_traverse=%f, cell_id=%d, current_index=(%d, %d, %d)\n",
// t_last, t_traverse, cell_id, current_index.x, current_index.y, current_index.z
// );
if (!single_traversal(tdist, current_index, overflow_index, step_index, delta)) {
break;
}
}
}
if (first_pass) {
if (intervals.chunk_cnts != nullptr)
intervals.chunk_cnts[tid] = n_intervals;
if (samples.chunk_cnts != nullptr)
samples.chunk_cnts[tid] = n_samples;
}
}
}
__global__ void ray_aabb_intersect_kernel(
const int32_t n_rays, float *rays_o, float *rays_d, float near, float far,
const int32_t n_aabbs, float *aabbs,
// outputs
const float miss_value,
float *t_mins, float *t_maxs, bool *hits)
{
int32_t numel = n_rays * n_aabbs;
// parallelize over all (ray, aabb) pairs
for (int32_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < numel; tid += blockDim.x * gridDim.x)
{
int32_t ray_id = tid / n_aabbs;
int32_t aabb_id = tid % n_aabbs;
float t_min, t_max;
bool hit = device::ray_aabb_intersect(
SingleRaySpec(rays_o + ray_id * 3, rays_d + ray_id * 3, near, far),
AABBSpec(aabbs + aabb_id * 6),
t_min, t_max
);
if (hit) {
t_mins[tid] = t_min;
t_maxs[tid] = t_max;
} else {
t_mins[tid] = miss_value;
t_maxs[tid] = miss_value;
}
hits[tid] = hit;
}
}
} // namespace device
} // namespace
std::vector<RaySegmentsSpec> traverse_grids(
// rays
const torch::Tensor rays_o, // [n_rays, 3]
const torch::Tensor rays_d, // [n_rays, 3]
// grids
const torch::Tensor binaries, // [n_grids, resx, resy, resz]
const torch::Tensor aabbs, // [n_grids, 6]
// intersections
const torch::Tensor t_mins, // [n_rays, n_grids]
const torch::Tensor t_maxs, // [n_rays, n_grids]
const torch::Tensor hits, // [n_rays, n_grids]
// options
const torch::Tensor near_planes,
const torch::Tensor far_planes,
const float step_size,
const float cone_angle,
const bool compute_intervals,
const bool compute_samples)
{
DEVICE_GUARD(rays_o);
int32_t n_rays = rays_o.size(0);
int32_t n_grids = binaries.size(0);
int3 resolution = make_int3(binaries.size(1), binaries.size(2), binaries.size(3));
at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();
int32_t max_threads = 512;
int32_t max_blocks = 65535;
dim3 threads = dim3(min(max_threads, n_rays));
dim3 blocks = dim3(min(max_blocks, ceil_div<int32_t>(n_rays, threads.x)));
// Sort the intersections. [n_rays, n_grids * 2]
torch::Tensor t_sorted, t_indices;
if (n_grids > 1) {
std::tie(t_sorted, t_indices) = torch::sort(torch::cat({t_mins, t_maxs}, -1), -1);
}
else {
t_sorted = torch::cat({t_mins, t_maxs}, -1);
t_indices = torch::arange(
0, n_grids * 2, t_mins.options().dtype(torch::kLong)
).expand({n_rays, n_grids * 2}).contiguous();
}
// outputs
RaySegmentsSpec intervals, samples;
// first pass to count the number of segments along each ray.
if (compute_intervals)
intervals.memalloc_cnts(n_rays, rays_o.options(), false);
if (compute_samples)
samples.memalloc_cnts(n_rays, rays_o.options(), false);
device::traverse_grids_kernel<<<blocks, threads, 0, stream>>>(
// rays
n_rays,
rays_o.data_ptr<float>(), // [n_rays, 3]
rays_d.data_ptr<float>(), // [n_rays, 3]
// grids
n_grids,
resolution,
binaries.data_ptr<bool>(), // [n_grids, resx, resy, resz]
aabbs.data_ptr<float>(), // [n_grids, 6]
// sorted intersections
hits.data_ptr<bool>(), // [n_rays, n_grids]
t_sorted.data_ptr<float>(), // [n_rays, n_grids * 2]
t_indices.data_ptr<int64_t>(), // [n_rays, n_grids * 2]
// options
near_planes.data_ptr<float>(), // [n_rays]
far_planes.data_ptr<float>(), // [n_rays]
step_size,
cone_angle,
// outputs
true,
device::PackedRaySegmentsSpec(intervals),
device::PackedRaySegmentsSpec(samples));
// second pass to record the segments.
if (compute_intervals)
intervals.memalloc_data(true, true);
if (compute_samples)
samples.memalloc_data(false, false);
device::traverse_grids_kernel<<<blocks, threads, 0, stream>>>(
// rays
n_rays,
rays_o.data_ptr<float>(), // [n_rays, 3]
rays_d.data_ptr<float>(), // [n_rays, 3]
// grids
n_grids,
resolution,
binaries.data_ptr<bool>(), // [n_grids, resx, resy, resz]
aabbs.data_ptr<float>(), // [n_grids, 6]
// sorted intersections
hits.data_ptr<bool>(), // [n_rays, n_grids]
t_sorted.data_ptr<float>(), // [n_rays, n_grids * 2]
t_indices.data_ptr<int64_t>(), // [n_rays, n_grids * 2]
// options
near_planes.data_ptr<float>(), // [n_rays]
far_planes.data_ptr<float>(), // [n_rays]
step_size,
cone_angle,
// outputs
false,
device::PackedRaySegmentsSpec(intervals),
device::PackedRaySegmentsSpec(samples));
return {intervals, samples};
}
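Note the two-pass structure above: the first launch only fills chunk_cnts (how many intervals/samples each ray produces), memalloc_data then turns the counts into offsets and allocates the packed buffers, and the second launch writes into them without any atomics. A small Python sketch of this count-then-write packing pattern, with hypothetical count_fn/write_fn standing in for the kernel's two modes:

import torch

def two_pass_pack(n_rays, count_fn, write_fn):
    # Pass 1: each ray reports how many values it will emit (cf. first_pass=true).
    chunk_cnts = torch.tensor([count_fn(i) for i in range(n_rays)], dtype=torch.long)
    # Offsets via an exclusive sum, exactly as memalloc_data does: cumsum - cnts.
    chunk_starts = torch.cumsum(chunk_cnts, 0) - chunk_cnts
    vals = torch.empty(int(chunk_cnts.sum()))
    # Pass 2: each ray writes at its own offset, so no contention between rays.
    for i in range(n_rays):
        s, c = int(chunk_starts[i]), int(chunk_cnts[i])
        write_fn(i, vals[s : s + c])
    return vals, chunk_starts, chunk_cnts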
std::vector<torch::Tensor> ray_aabb_intersect(
const torch::Tensor rays_o, // [n_rays, 3]
const torch::Tensor rays_d, // [n_rays, 3]
const torch::Tensor aabbs, // [n_aabbs, 6]
const float near_plane,
const float far_plane,
const float miss_value)
{
DEVICE_GUARD(rays_o);
int32_t n_rays = rays_o.size(0);
int32_t n_aabbs = aabbs.size(0);
int32_t numel = n_rays * n_aabbs;
at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();
int32_t max_threads = 512;
int32_t max_blocks = 65535;
dim3 threads = dim3(min(max_threads, numel));
dim3 blocks = dim3(min(max_blocks, ceil_div<int32_t>(numel, threads.x)));
// outputs
torch::Tensor t_mins = torch::empty({n_rays, n_aabbs}, rays_o.options());
torch::Tensor t_maxs = torch::empty({n_rays, n_aabbs}, rays_o.options());
torch::Tensor hits = torch::empty({n_rays, n_aabbs}, rays_d.options().dtype(torch::kBool));
device::ray_aabb_intersect_kernel<<<blocks, threads, 0, stream>>>(
// rays
n_rays,
rays_o.data_ptr<float>(), // [n_rays, 3]
rays_d.data_ptr<float>(), // [n_rays, 3]
near_plane,
far_plane,
// aabbs
n_aabbs,
aabbs.data_ptr<float>(), // [n_aabbs, 6]
// outputs
miss_value,
t_mins.data_ptr<float>(), // [n_rays, n_aabbs]
t_maxs.data_ptr<float>(), // [n_rays, n_aabbs]
hits.data_ptr<bool>()); // [n_rays, n_aabbs]
return {t_mins, t_maxs, hits};
}
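For readers less at home in CUDA, a batched pure-PyTorch sketch with the same slab-test semantics (every (ray, aabb) pair tested, misses filled with miss_value). This is an illustrative reimplementation, not the library call; edge cases such as zero direction components or the exact clamping order may behave slightly differently from the kernel.

import torch

def ray_aabb_intersect_torch(rays_o, rays_d, aabbs, near, far, miss_value):
    # rays_o, rays_d: [n_rays, 3]; aabbs: [n_aabbs, 6] laid out as (min_xyz, max_xyz).
    inv_d = 1.0 / rays_d                                          # [n_rays, 3]
    t0 = (aabbs[None, :, :3] - rays_o[:, None]) * inv_d[:, None]  # [n_rays, n_aabbs, 3]
    t1 = (aabbs[None, :, 3:] - rays_o[:, None]) * inv_d[:, None]
    t_min = torch.minimum(t0, t1).amax(-1).clamp_min(near)  # entry: latest per-axis entry
    t_max = torch.maximum(t0, t1).amin(-1).clamp_max(far)   # exit: earliest per-axis exit
    hits = (t_min < t_max) & (t_max > 0)
    t_min = torch.where(hits, t_min, torch.full_like(t_min, miss_value))
    t_max = torch.where(hits, t_max, torch.full_like(t_max, miss_value))
    return t_min, t_max, hits  # each [n_rays, n_aabbs]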
#pragma once
#include <torch/extension.h>
#include "utils_cuda.cuh"
struct MultiScaleGridSpec {
torch::Tensor data; // [levels, resx, resy, resz]
torch::Tensor occupied; // [levels, resx, resy, resz]
torch::Tensor base_aabb; // [6,]
inline void check() {
CHECK_INPUT(data);
CHECK_INPUT(occupied);
CHECK_INPUT(base_aabb);
TORCH_CHECK(data.ndimension() == 4);
TORCH_CHECK(occupied.ndimension() == 4);
TORCH_CHECK(base_aabb.ndimension() == 1);
TORCH_CHECK(data.numel() == occupied.numel());
TORCH_CHECK(base_aabb.numel() == 6);
}
};
struct RaysSpec {
torch::Tensor origins; // [n_rays, 3]
torch::Tensor dirs; // [n_rays, 3]
inline void check() {
CHECK_INPUT(origins);
CHECK_INPUT(dirs);
TORCH_CHECK(origins.ndimension() == 2);
TORCH_CHECK(dirs.ndimension() == 2);
TORCH_CHECK(origins.numel() == dirs.numel());
TORCH_CHECK(origins.size(1) == 3);
TORCH_CHECK(dirs.size(1) == 3);
}
};
struct RaySegmentsSpec {
torch::Tensor vals; // [n_edges] or [n_rays, n_edges_per_ray]
// for flattened tensor
torch::Tensor chunk_starts; // [n_rays]
torch::Tensor chunk_cnts; // [n_rays]
torch::Tensor ray_indices; // [n_edges]
torch::Tensor is_left; // [n_edges]; has n_bins true values
torch::Tensor is_right; // [n_edges]; has n_bins true values
inline void check() {
CHECK_INPUT(vals);
TORCH_CHECK(vals.defined());
// batched tensor [..., n_edges_per_ray]
if (vals.ndimension() > 1) return;
// flattened tensor [n_edges]
CHECK_INPUT(chunk_starts);
CHECK_INPUT(chunk_cnts);
TORCH_CHECK(chunk_starts.defined());
TORCH_CHECK(chunk_cnts.defined());
TORCH_CHECK(chunk_starts.ndimension() == 1);
TORCH_CHECK(chunk_cnts.ndimension() == 1);
TORCH_CHECK(chunk_starts.numel() == chunk_cnts.numel());
if (ray_indices.defined()) {
CHECK_INPUT(ray_indices);
TORCH_CHECK(ray_indices.ndimension() == 1);
TORCH_CHECK(vals.numel() == ray_indices.numel());
}
if (is_left.defined()) {
CHECK_INPUT(is_left);
TORCH_CHECK(is_left.ndimension() == 1);
TORCH_CHECK(vals.numel() == is_left.numel());
}
if (is_right.defined()) {
CHECK_INPUT(is_right);
TORCH_CHECK(is_right.ndimension() == 1);
TORCH_CHECK(vals.numel() == is_right.numel());
}
}
inline void memalloc_cnts(int32_t n_rays, at::TensorOptions options, bool zero_init = true) {
TORCH_CHECK(!chunk_cnts.defined());
if (zero_init) {
chunk_cnts = torch::zeros({n_rays}, options.dtype(torch::kLong));
} else {
chunk_cnts = torch::empty({n_rays}, options.dtype(torch::kLong));
}
}
inline int64_t memalloc_data(bool alloc_masks = true, bool zero_init = true) {
TORCH_CHECK(chunk_cnts.defined());
TORCH_CHECK(!chunk_starts.defined());
TORCH_CHECK(!vals.defined());
torch::Tensor cumsum = torch::cumsum(chunk_cnts, 0, chunk_cnts.scalar_type());
int64_t n_edges = cumsum[-1].item<int64_t>();
chunk_starts = cumsum - chunk_cnts;
if (zero_init) {
vals = torch::zeros({n_edges}, chunk_cnts.options().dtype(torch::kFloat32));
ray_indices = torch::zeros({n_edges}, chunk_cnts.options().dtype(torch::kLong));
if (alloc_masks) {
is_left = torch::zeros({n_edges}, chunk_cnts.options().dtype(torch::kBool));
is_right = torch::zeros({n_edges}, chunk_cnts.options().dtype(torch::kBool));
}
} else {
vals = torch::empty({n_edges}, chunk_cnts.options().dtype(torch::kFloat32));
ray_indices = torch::empty({n_edges}, chunk_cnts.options().dtype(torch::kLong));
if (alloc_masks) {
is_left = torch::empty({n_edges}, chunk_cnts.options().dtype(torch::kBool));
is_right = torch::empty({n_edges}, chunk_cnts.options().dtype(torch::kBool));
}
}
return 1;
}
};
\ No newline at end of file
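To make the flattened layout concrete: memalloc_data converts the per-ray counts into offsets with an exclusive sum, so ray i owns the slice vals[chunk_starts[i] : chunk_starts[i] + chunk_cnts[i]]. A tiny worked example in Python:

import torch

chunk_cnts = torch.tensor([2, 0, 3], dtype=torch.long)  # edges emitted by each ray
cumsum = torch.cumsum(chunk_cnts, 0)                    # [2, 2, 5]
n_edges = int(cumsum[-1])                               # 5 packed edges in total
chunk_starts = cumsum - chunk_cnts                      # [0, 2, 2]: per-ray offsets
# ray 0 -> vals[0:2], ray 1 -> vals[2:2] (empty), ray 2 -> vals[2:5]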
#pragma once
#include <torch/extension.h>
#include "data_spec.hpp"
namespace {
namespace device {
struct PackedRaySegmentsSpec {
PackedRaySegmentsSpec(RaySegmentsSpec& spec) :
vals(spec.vals.defined() ? spec.vals.data_ptr<float>() : nullptr),
is_batched(spec.vals.defined() ? spec.vals.dim() > 1 : false),
// for flattened tensor
chunk_starts(spec.chunk_starts.defined() ? spec.chunk_starts.data_ptr<int64_t>() : nullptr),
chunk_cnts(spec.chunk_cnts.defined() ? spec.chunk_cnts.data_ptr<int64_t>(): nullptr),
ray_indices(spec.ray_indices.defined() ? spec.ray_indices.data_ptr<int64_t>() : nullptr),
is_left(spec.is_left.defined() ? spec.is_left.data_ptr<bool>() : nullptr),
is_right(spec.is_right.defined() ? spec.is_right.data_ptr<bool>() : nullptr),
// for dimensions
n_edges(spec.vals.defined() ? spec.vals.numel() : 0),
n_rays(spec.chunk_cnts.defined() ? spec.chunk_cnts.size(0) : 0), // for flattened tensor
n_edges_per_ray(spec.vals.defined() ? spec.vals.size(-1) : 0) // for batched tensor
{ }
float* vals;
bool is_batched;
int64_t* chunk_starts;
int64_t* chunk_cnts;
int64_t* ray_indices;
bool* is_left;
bool* is_right;
int64_t n_edges;
int32_t n_rays;
int32_t n_edges_per_ray;
};
struct PackedMultiScaleGridSpec {
PackedMultiScaleGridSpec(MultiScaleGridSpec& spec) :
data(spec.data.data_ptr<float>()),
occupied(spec.occupied.data_ptr<bool>()),
base_aabb(spec.base_aabb.data_ptr<float>()),
levels(spec.data.size(0)),
resolution{
(int32_t)spec.data.size(1),
(int32_t)spec.data.size(2),
(int32_t)spec.data.size(3)}
{ }
float* data;
bool* occupied;
float* base_aabb;
int32_t levels;
int3 resolution;
};
struct PackedRaysSpec {
PackedRaysSpec(RaysSpec& spec) :
origins(spec.origins.data_ptr<float>()),
dirs(spec.dirs.data_ptr<float>()),
N(spec.origins.size(0))
{ }
float *origins;
float *dirs;
int32_t N;
};
struct SingleRaySpec {
// TODO: check inv_dir if dir is zero.
__device__ SingleRaySpec(
float *rays_o, float *rays_d, float tmin, float tmax) :
origin{rays_o[0], rays_o[1], rays_o[2]},
dir{rays_d[0], rays_d[1], rays_d[2]},
inv_dir{1.0f/rays_d[0], 1.0f/rays_d[1], 1.0f/rays_d[2]},
tmin{tmin},
tmax{tmax}
{ }
__device__ SingleRaySpec(
PackedRaysSpec& rays, int32_t id, float tmin, float tmax) :
origin{
rays.origins[id * 3],
rays.origins[id * 3 + 1],
rays.origins[id * 3 + 2]},
dir{
rays.dirs[id * 3],
rays.dirs[id * 3 + 1],
rays.dirs[id * 3 + 2]},
inv_dir{
1.0f / rays.dirs[id * 3],
1.0f / rays.dirs[id * 3 + 1],
1.0f / rays.dirs[id * 3 + 2]},
tmin{tmin},
tmax{tmax}
{ }
float3 origin;
float3 dir;
float3 inv_dir;
float tmin;
float tmax;
};
struct AABBSpec {
__device__ AABBSpec(float *aabb) :
min{aabb[0], aabb[1], aabb[2]},
max{aabb[3], aabb[4], aabb[5]}
{ }
__device__ AABBSpec(float3 min, float3 max) :
min{min.x, min.y, min.z},
max{max.x, max.y, max.z}
{ }
float3 min;
float3 max;
};
} // namespace device
} // namespace
\ No newline at end of file
@@ -4,7 +4,10 @@
 #pragma once

-#include "helpers_math.h"
+#include "utils_math.cuh"
+
+namespace {
+namespace device {

 enum ContractionType
 {
@@ -127,3 +130,6 @@ inline __device__ __host__ float3 apply_contraction_inv(
         return unit_sphere_to_inf(xyz, roi_min, roi_max);
     }
 }
+
+} // namespace device
+} // namespace
\ No newline at end of file
@@ -40,3 +40,9 @@
         func(temp_storage.get(), temp_storage_bytes, __VA_ARGS__); \
         AT_CUDA_CHECK(cudaGetLastError());                         \
     } while (false)
+
+template <typename scalar_t>
+inline __device__ __host__ scalar_t ceil_div(scalar_t a, scalar_t b)
+{
+    return (a + b - 1) / b;
+}
\ No newline at end of file
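The new ceil_div helper is the usual rounded-up integer quotient used when sizing kernel launches. An equivalent Python one-liner, for reference:

def ceil_div(a: int, b: int) -> int:
    # e.g. ceil_div(1000, 256) == 4: four blocks of 256 threads cover 1000 items
    return (a + b - 1) // b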
#pragma once
#include "data_spec_packed.cuh"
#include "utils_contraction.cuh"
#include "utils_math.cuh"
namespace {
namespace device {
inline __device__ bool ray_aabb_intersect(
SingleRaySpec ray, AABBSpec aabb,
// outputs
float& tmin, float& tmax)
{
float tmin_temp{};
float tmax_temp{};
if (ray.inv_dir.x >= 0) {
tmin = (aabb.min.x - ray.origin.x) * ray.inv_dir.x;
tmax = (aabb.max.x - ray.origin.x) * ray.inv_dir.x;
} else {
tmin = (aabb.max.x - ray.origin.x) * ray.inv_dir.x;
tmax = (aabb.min.x - ray.origin.x) * ray.inv_dir.x;
}
if (ray.inv_dir.y >= 0) {
tmin_temp = (aabb.min.y - ray.origin.y) * ray.inv_dir.y;
tmax_temp = (aabb.max.y - ray.origin.y) * ray.inv_dir.y;
} else {
tmin_temp = (aabb.max.y - ray.origin.y) * ray.inv_dir.y;
tmax_temp = (aabb.min.y - ray.origin.y) * ray.inv_dir.y;
}
if (tmin > tmax_temp || tmin_temp > tmax) return false;
if (tmin_temp > tmin) tmin = tmin_temp;
if (tmax_temp < tmax) tmax = tmax_temp;
if (ray.inv_dir.z >= 0) {
tmin_temp = (aabb.min.z - ray.origin.z) * ray.inv_dir.z;
tmax_temp = (aabb.max.z - ray.origin.z) * ray.inv_dir.z;
} else {
tmin_temp = (aabb.max.z - ray.origin.z) * ray.inv_dir.z;
tmax_temp = (aabb.min.z - ray.origin.z) * ray.inv_dir.z;
}
if (tmin > tmax_temp || tmin_temp > tmax) return false;
if (tmin_temp > tmin) tmin = tmin_temp;
if (tmax_temp < tmax) tmax = tmax_temp;
if (tmax <= 0) return false;
tmin = fmaxf(tmin, ray.tmin);
tmax = fminf(tmax, ray.tmax);
return true;
}
inline __device__ void setup_traversal(
SingleRaySpec ray, float tmin, float tmax, float eps,
AABBSpec aabb, int3 resolution,
// outputs
float3 &delta, float3 &tdist,
int3 &step_index, int3 &current_index, int3 &final_index)
{
const float3 res = make_float3(resolution);
const float3 voxel_size = (aabb.max - aabb.min) / res;
const float3 ray_start = ray.origin + ray.dir * (tmin + eps);
const float3 ray_end = ray.origin + ray.dir * (tmax - eps);
// get voxel index of start and end within grid
// TODO: check float error here!
current_index = make_int3(
apply_contraction(ray_start, aabb.min, aabb.max, ContractionType::AABB)
* res
);
current_index = clamp(current_index, make_int3(0, 0, 0), resolution - 1);
final_index = make_int3(
apply_contraction(ray_end, aabb.min, aabb.max, ContractionType::AABB)
* res
);
final_index = clamp(final_index, make_int3(0, 0, 0), resolution - 1);
//
const int3 index_delta = make_int3(
ray.dir.x > 0 ? 1 : 0, ray.dir.y > 0 ? 1 : 0, ray.dir.z > 0 ? 1 : 0
);
const int3 start_index = current_index + index_delta;
const float3 tmax_xyz = ((aabb.min +
((make_float3(start_index) * voxel_size) - ray_start)) * ray.inv_dir) + tmin;
tdist = make_float3(
(ray.dir.x == 0.0f) ? tmax : tmax_xyz.x,
(ray.dir.y == 0.0f) ? tmax : tmax_xyz.y,
(ray.dir.z == 0.0f) ? tmax : tmax_xyz.z
);
// printf("tdist: %f %f %f\n", tdist.x, tdist.y, tdist.z);
const float3 step_float = make_float3(
(ray.dir.x == 0.0f) ? 0.0f : (ray.dir.x > 0.0f ? 1.0f : -1.0f),
(ray.dir.y == 0.0f) ? 0.0f : (ray.dir.y > 0.0f ? 1.0f : -1.0f),
(ray.dir.z == 0.0f) ? 0.0f : (ray.dir.z > 0.0f ? 1.0f : -1.0f)
);
step_index = make_int3(step_float);
// printf("step_index: %d %d %d\n", step_index.x, step_index.y, step_index.z);
const float3 delta_temp = voxel_size * ray.inv_dir * step_float;
delta = make_float3(
(ray.dir.x == 0.0f) ? tmax : delta_temp.x,
(ray.dir.y == 0.0f) ? tmax : delta_temp.y,
(ray.dir.z == 0.0f) ? tmax : delta_temp.z
);
// printf("delta: %f %f %f\n", delta.x, delta.y, delta.z);
}
inline __device__ bool single_traversal(
float3& tdist, int3& current_index,
const int3 overflow_index, const int3 step_index, const float3 delta) {
if ((tdist.x < tdist.y) && (tdist.x < tdist.z)) {
// X-axis traversal.
current_index.x += step_index.x;
tdist.x += delta.x;
if (current_index.x == overflow_index.x) {
return false;
}
} else if (tdist.y < tdist.z) {
// Y-axis traversal.
current_index.y += step_index.y;
tdist.y += delta.y;
if (current_index.y == overflow_index.y) {
return false;
}
} else {
// Z-axis traversal.
current_index.z += step_index.z;
tdist.z += delta.z;
if (current_index.z == overflow_index.z) {
return false;
}
}
return true;
}
} // namespace device
} // namespace
\ No newline at end of file
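Together, setup_traversal and single_traversal implement an Amanatides-Woo style voxel walk: precompute, per axis, the distance to the next grid plane (tdist) and the per-cell increment (delta), then repeatedly step along whichever axis crosses first. A compact Python sketch of the stepping loop, reusing the variable names above (illustrative only; it omits the contraction, clamping, and tie-breaking details of the CUDA version):

def voxel_walk(tdist, delta, current_index, step_index, overflow_index):
    # Yield (t_exit, voxel_index) for each voxel crossed, mirroring single_traversal.
    while True:
        axis = min(range(3), key=lambda a: tdist[a])  # axis whose plane is crossed first
        yield tdist[axis], tuple(current_index)
        current_index[axis] += step_index[axis]       # step into the neighboring voxel
        tdist[axis] += delta[axis]                    # advance that axis's crossing distance
        if current_index[axis] == overflow_index[axis]:
            return                                    # stepped past the final voxel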