Unverified Commit 8dcfbad9 authored by Ruilong Li(李瑞龙)'s avatar Ruilong Li(李瑞龙) Committed by GitHub
Browse files

Reformat (#31)



* seems working

* contraction func in cuda

* Update type

* More type updates

* disable DDA for contraction

* update contraction performance in readme

* 360 data: Garden

* eval at max_steps

* add perform of 360 to readme

* fix contraction scaling

* tiny hot fix

* new volrend

* cleanup ray_marching.cu

* cleanup backend

* tests

* cleaning up Grid

* fix doc for grid base class

* check and fix for contraction

* test grid

* rendering and marching

* transmittance_compress verified

* rendering is indeed faster

* pipeline is working

* lego example

* cleanup

* cuda folder is cleaned up! finally!

* cuda formatting

* contraction verify

* upgrade grid

* test for ray marching

* pipeline

* ngp with contraction

* train_ngp runs but slow

* transmittance separated into two. Now NGP is as fast as before

* verified faster than before

* bug fix for contraction

* ngp contraction fix

* tiny cleanup

* contraction works! yay!

* contraction with tanh seems working

* minor update

* support alpha rendering

* absorb visibility to ray marching

* tiny import update

* get rid of contraction temperature;

* doc for ContractionType

* doc for Grid

* doc for grid.py is done

* doc for ray marching

* rendering function

* fix doc for rendering

* doc for vol rend

* autosummary for utils

* fix autosummary line break

* utils docs

* api doc is done

* starting work on examples

* contraction for ngp is in python now

* further clean up examples

* mlp nerf is running

* dnerf is in

* update readme command

* merge

* disable pylint error for now

* reformatting and skip tests without cuda

* fix the type issue for contractiontype

* fix cuda attribute issue

* bump to 0.1.0
Co-authored-by: default avatarMatt Tancik <tancik@berkeley.edu>
parent a7611603
import collections
# Container for a batch of rays: per-ray `origins` and `viewdirs` tensors.
# NOTE(review): viewdirs are presumably unit-normalized by the dataloader — confirm.
Rays = collections.namedtuple("Rays", ("origins", "viewdirs"))
def namedtuple_map(fn, tup):
    """Apply `fn` to each element of `tup` and cast to `tup`'s namedtuple.

    `None` fields are passed through untouched so partially-populated
    bundles survive the mapping.
    """
    mapped = []
    for field in tup:
        mapped.append(None if field is None else fn(field))
    return type(tup)(*mapped)
Subproject commit 18c64d8fe9f888fdcc50b9e4170fcf9f7dc11eaa
from abc import abstractmethod
from typing import Tuple
import torch
import torch.nn as nn
class BaseRadianceField(nn.Module):
    """An abstract RadianceField class (supports both 2D and 3D).

    The key functions to be implemented are:

        - forward(positions, directions, masks): returns rgb and density.
        - query_density(positions): returns density only.

    NOTE(review): `@abstractmethod` is not enforced at instantiation time
    here because `nn.Module` does not use `ABCMeta`; the
    `NotImplementedError` raises act as the runtime guard — confirm this
    is intentional.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__()

    @abstractmethod
    def forward(
        self,
        positions: torch.Tensor,
        directions: torch.Tensor = None,
        masks: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Returns {rgb, density}.

        Args:
            positions: sample positions.
            directions: optional per-sample view directions.
            masks: optional per-sample validity mask.
        """
        raise NotImplementedError()

    @abstractmethod
    def query_density(self, positions: torch.Tensor) -> torch.Tensor:
        """Returns {density} at the given positions."""
        raise NotImplementedError()
......@@ -44,7 +44,11 @@ class MLP(nn.Module):
self.hidden_layers.append(
nn.Linear(in_features, self.net_width, bias=bias_enabled)
)
if (self.skip_layer is not None) and (i % self.skip_layer == 0) and (i > 0):
if (
(self.skip_layer is not None)
and (i % self.skip_layer == 0)
and (i > 0)
):
in_features = self.net_width + self.input_dim
else:
in_features = self.net_width
......@@ -82,7 +86,11 @@ class MLP(nn.Module):
for i in range(self.net_depth):
x = self.hidden_layers[i](x)
x = self.hidden_activation(x)
if (self.skip_layer is not None) and (i % self.skip_layer == 0) and (i > 0):
if (
(self.skip_layer is not None)
and (i % self.skip_layer == 0)
and (i > 0)
):
x = torch.cat([x, inputs], dim=-1)
if self.output_enabled:
x = self.output_layer(x)
......@@ -169,7 +177,9 @@ class SinusoidalEncoder(nn.Module):
@property
def latent_dim(self) -> int:
return (int(self.use_identity) + (self.max_deg - self.min_deg) * 2) * self.x_dim
return (
int(self.use_identity) + (self.max_deg - self.min_deg) * 2
) * self.x_dim
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
......@@ -212,6 +222,13 @@ class VanillaNeRFRadianceField(nn.Module):
net_width_condition=net_width_condition,
)
def query_opacity(self, x, step_size):
    """Approximate per-point opacity for occupancy-grid updates.

    Uses the first-order approximation of the volume-rendering alpha,
    which matches the exact value when density * step_size is small:
    # opacity = 1.0 - torch.exp(-density * step_size)
    """
    sigma = self.query_density(x)
    return sigma * step_size
def query_density(self, x):
x = self.posi_encoder(x)
sigma = self.mlp.query_density(x)
......@@ -231,7 +248,8 @@ class DNeRFRadianceField(nn.Module):
self.posi_encoder = SinusoidalEncoder(3, 0, 0, True)
self.time_encoder = SinusoidalEncoder(1, 0, 0, True)
self.warp = MLP(
input_dim=self.posi_encoder.latent_dim + self.time_encoder.latent_dim,
input_dim=self.posi_encoder.latent_dim
+ self.time_encoder.latent_dim,
output_dim=3,
net_depth=4,
net_width=64,
......@@ -240,6 +258,15 @@ class DNeRFRadianceField(nn.Module):
)
self.nerf = VanillaNeRFRadianceField()
def query_opacity(self, x, timestamps, step_size):
    """Approximate per-point opacity for occupancy-grid updates (dynamic scene).

    Each sample is paired with a randomly drawn training timestamp so the
    grid reflects occupancy across time.
    """
    t_idx = torch.randint(0, len(timestamps), (x.shape[0],), device=x.device)
    sigma = self.query_density(x, timestamps[t_idx])
    # First-order approximation; exact when sigma * step_size is small:
    # opacity = 1.0 - torch.exp(-density * step_size)
    return sigma * step_size
def query_density(self, x, t):
x = x + self.warp(
torch.cat([self.posi_encoder(x), self.time_encoder(t)], dim=-1)
......
......@@ -14,8 +14,6 @@ except ImportError as e:
)
exit()
from .base import BaseRadianceField
class _TruncExp(Function): # pylint: disable=abstract-method
# Implementation from torch-ngp:
......@@ -36,7 +34,7 @@ class _TruncExp(Function): # pylint: disable=abstract-method
trunc_exp = _TruncExp.apply
class NGPradianceField(BaseRadianceField):
class NGPradianceField(torch.nn.Module):
"""Instance-NGP radiance Field"""
def __init__(
......@@ -45,7 +43,7 @@ class NGPradianceField(BaseRadianceField):
num_dim: int = 3,
use_viewdirs: bool = True,
density_activation: Callable = lambda x: trunc_exp(x - 1),
# density_activation: Callable = lambda x: torch.nn.functional.softplus(x - 1),
unbounded: bool = False,
) -> None:
super().__init__()
if not isinstance(aabb, torch.Tensor):
......@@ -54,6 +52,7 @@ class NGPradianceField(BaseRadianceField):
self.num_dim = num_dim
self.use_viewdirs = use_viewdirs
self.density_activation = density_activation
self.unbounded = unbounded
self.geo_feat_dim = 15
per_level_scale = 1.4472692012786865
......@@ -96,7 +95,11 @@ class NGPradianceField(BaseRadianceField):
self.mlp_head = tcnn.Network(
n_input_dims=(
(self.direction_encoding.n_output_dims if self.use_viewdirs else 0)
(
self.direction_encoding.n_output_dims
if self.use_viewdirs
else 0
)
+ self.geo_feat_dim
),
n_output_dims=3,
......@@ -109,9 +112,36 @@ class NGPradianceField(BaseRadianceField):
},
)
def query_opacity(self, x, step_size):
    """Return per-point opacity approximations used to update the occupancy grid.

    Args:
        x: positions with shape (N, num_dim), in world space (they are
            normalized against `self.aabb` below).
        step_size: scalar ray-marching step size.

    Returns:
        opacity tensor with shape (N, 1).
    """
    density = self.query_density(x)
    aabb_min, aabb_max = torch.split(self.aabb, self.num_dim, dim=-1)
    if self.unbounded:
        # TODO: [revisit] is this necessary?
        # 1.0 / derivative of tanh contraction
        x = (x - aabb_min) / (aabb_max - aabb_min)
        x = x - 0.5
        # Clamp the derivative (which is in (0, 1] and vanishes for large
        # |x|) away from zero BEFORE inverting, so the scaling is capped
        # instead of blowing up. The previous `min=1e6` clamped the whole
        # term up to 1e6, collapsing the expression to a constant and
        # contradicting the "1 / derivative" intent above.
        scaling = 1.0 / (
            torch.clamp(1.0 - torch.tanh(x) ** 2, min=1e-6) * 0.5
        )
        scaling = scaling * (aabb_max - aabb_min)
    else:
        scaling = aabb_max - aabb_min
    # Convert the normalized step into a world-space step length.
    step_size = step_size * scaling.norm(dim=-1, keepdim=True)
    # if the density is small enough those two are the same.
    # opacity = 1.0 - torch.exp(-density * step_size)
    opacity = density * step_size
    return opacity
def query_density(self, x, return_feat: bool = False):
bb_min, bb_max = torch.split(self.aabb, [self.num_dim, self.num_dim], dim=0)
x = (x - bb_min) / (bb_max - bb_min)
if self.unbounded:
# tanh contraction
aabb_min, aabb_max = torch.split(self.aabb, self.num_dim, dim=-1)
x = (x - aabb_min) / (aabb_max - aabb_min)
x = x - 0.5
x = (torch.tanh(x) + 1) * 0.5
else:
aabb_min, aabb_max = torch.split(self.aabb, self.num_dim, dim=-1)
x = (x - aabb_min) / (aabb_max - aabb_min)
selector = ((x > 0.0) & (x < 1.0)).all(dim=-1)
x = (
self.mlp_base(x.view(-1, self.num_dim))
......@@ -122,7 +152,8 @@ class NGPradianceField(BaseRadianceField):
x, [1, self.geo_feat_dim], dim=-1
)
density = (
self.density_activation(density_before_activation) * selector[..., None]
self.density_activation(density_before_activation)
* selector[..., None]
)
if return_feat:
return density, base_mlp_out
......@@ -137,36 +168,22 @@ class NGPradianceField(BaseRadianceField):
h = torch.cat([d, embedding.view(-1, self.geo_feat_dim)], dim=-1)
else:
h = embedding.view(-1, self.geo_feat_dim)
rgb = self.mlp_head(h).view(list(embedding.shape[:-1]) + [3]).to(embedding)
rgb = (
self.mlp_head(h)
.view(list(embedding.shape[:-1]) + [3])
.to(embedding)
)
return rgb
def forward(
    self,
    positions: torch.Tensor,
    directions: torch.Tensor = None,
    mask: torch.Tensor = None,
    only_density: bool = False,
):
    """Returns (rgb, density), optionally evaluating only where `mask` is set.

    When `mask` is given, unmasked entries stay zero; when
    `only_density` is True, returns the density tensor alone.
    """
    if self.use_viewdirs and (directions is not None):
        assert (
            positions.shape == directions.shape
        ), f"{positions.shape} v.s. {directions.shape}"
    if mask is not None:
        density = torch.zeros_like(positions[..., :1])
        rgb = torch.zeros(list(positions.shape[:-1]) + [3], device=positions.device)
        # NOTE(review): this unpacks two values from query_density without
        # passing return_feat=True, unlike the else-branch below — confirm
        # which query_density signature this relies on.
        density[mask], embedding = self.query_density(positions[mask])
        if only_density:
            return density
        # NOTE(review): this branch calls `query_rgb` while the else-branch
        # calls `_query_rgb` — confirm which name actually exists.
        rgb[mask] = self.query_rgb(
            directions[mask] if directions is not None else None,
            embedding=embedding,
        )
    else:
        density, embedding = self.query_density(positions, return_feat=True)
        if only_density:
            return density
        rgb = self._query_rgb(directions, embedding=embedding)
    return rgb, density
import argparse
import math
import os
import time
import imageio
import numpy as np
import torch
import torch.nn.functional as F
import tqdm
from datasets.dnerf_synthetic import SubjectLoader
from radiance_fields.mlp import DNeRFRadianceField
from utils import render_image, set_random_seed
from nerfacc import ContractionType, OccupancyGrid
if __name__ == "__main__":
    # Train a D-NeRF (dynamic scene) MLP radiance field with an
    # occupancy grid for accelerated ray marching.
    device = "cuda:0"
    set_random_seed(42)

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--train_split",
        type=str,
        default="train",
        choices=["train"],
        help="which train split to use",
    )
    parser.add_argument(
        "--scene",
        type=str,
        default="lego",
        choices=[
            # dnerf
            "bouncingballs",
            "hellwarrior",
            "hook",
            "jumpingjacks",
            "lego",
            "mutant",
            "standup",
            "trex",
        ],
        help="which scene to use",
    )
    parser.add_argument(
        "--aabb",
        type=lambda s: [float(item) for item in s.split(",")],
        default="-1.5,-1.5,-1.5,1.5,1.5,1.5",
        help="delimited list input",
    )
    parser.add_argument(
        "--test_chunk_size",
        type=int,
        default=8192,
    )
    parser.add_argument("--cone_angle", type=float, default=0.0)
    args = parser.parse_args()

    render_n_samples = 1024

    # setup the scene bounding box.
    contraction_type = ContractionType.AABB
    scene_aabb = torch.tensor(args.aabb, dtype=torch.float32, device=device)
    near_plane = None
    far_plane = None
    # Step size chosen so the aabb diagonal is covered by render_n_samples steps.
    render_step_size = (
        (scene_aabb[3:] - scene_aabb[:3]).max()
        * math.sqrt(3)
        / render_n_samples
    ).item()

    # setup the radiance field we want to train.
    max_steps = 40000
    grad_scaler = torch.cuda.amp.GradScaler(1)
    radiance_field = DNeRFRadianceField().to(device)
    optimizer = torch.optim.Adam(radiance_field.parameters(), lr=5e-4)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer,
        milestones=[max_steps // 2, max_steps * 3 // 4, max_steps * 9 // 10],
        gamma=0.33,
    )

    # setup the dataset
    # NOTE(review): data path is hard-coded to the author's machine.
    data_root_fp = "/home/ruilongli/data/dnerf/"
    target_sample_batch_size = 1 << 16
    grid_resolution = 128

    train_dataset = SubjectLoader(
        subject_id=args.scene,
        root_fp=data_root_fp,
        split=args.train_split,
        num_rays=target_sample_batch_size // render_n_samples,
    )
    train_dataset.images = train_dataset.images.to(device)
    train_dataset.camtoworlds = train_dataset.camtoworlds.to(device)
    train_dataset.K = train_dataset.K.to(device)
    train_dataset.timestamps = train_dataset.timestamps.to(device)

    test_dataset = SubjectLoader(
        subject_id=args.scene,
        root_fp=data_root_fp,
        split="test",
        num_rays=None,
    )
    test_dataset.images = test_dataset.images.to(device)
    test_dataset.camtoworlds = test_dataset.camtoworlds.to(device)
    test_dataset.K = test_dataset.K.to(device)
    test_dataset.timestamps = test_dataset.timestamps.to(device)

    occupancy_grid = OccupancyGrid(
        roi_aabb=args.aabb,
        resolution=grid_resolution,
        contraction_type=contraction_type,
    ).to(device)

    # training
    step = 0
    tic = time.time()
    for epoch in range(10000000):
        for i in range(len(train_dataset)):
            radiance_field.train()
            data = train_dataset[i]

            render_bkgd = data["color_bkgd"]
            rays = data["rays"]
            pixels = data["pixels"]
            timestamps = data["timestamps"]

            # update occupancy grid
            occupancy_grid.every_n_step(
                step=step,
                occ_eval_fn=lambda x: radiance_field.query_opacity(
                    x, timestamps, render_step_size
                ),
            )

            # render
            rgb, acc, depth, n_rendering_samples = render_image(
                radiance_field,
                occupancy_grid,
                rays,
                scene_aabb,
                # rendering options
                near_plane=near_plane,
                far_plane=far_plane,
                render_step_size=render_step_size,
                render_bkgd=render_bkgd,
                cone_angle=args.cone_angle,
                # dnerf options
                timestamps=timestamps,
            )

            # dynamic batch size for rays to keep sample batch size constant.
            num_rays = len(pixels)
            num_rays = int(
                num_rays
                * (target_sample_batch_size / float(n_rendering_samples))
            )
            train_dataset.update_num_rays(num_rays)
            alive_ray_mask = acc.squeeze(-1) > 0

            # compute loss only on rays that hit any occupied region.
            loss = F.smooth_l1_loss(rgb[alive_ray_mask], pixels[alive_ray_mask])

            optimizer.zero_grad()
            # do not unscale it because we are using Adam.
            grad_scaler.scale(loss).backward()
            optimizer.step()
            scheduler.step()

            if step % 100 == 0:
                elapsed_time = time.time() - tic
                # report MSE for logging even though training uses smooth L1.
                loss = F.mse_loss(rgb[alive_ray_mask], pixels[alive_ray_mask])
                print(
                    f"elapsed_time={elapsed_time:.2f}s | {step=} | "
                    f"loss={loss:.5f} | "
                    f"alive_ray_mask={alive_ray_mask.long().sum():d} | "
                    f"n_rendering_samples={n_rendering_samples:d} | num_rays={len(pixels):d} |"
                )

            if step >= 0 and step % max_steps == 0 and step > 0:
                # evaluation
                radiance_field.eval()

                psnrs = []
                with torch.no_grad():
                    for i in tqdm.tqdm(range(len(test_dataset))):
                        data = test_dataset[i]
                        render_bkgd = data["color_bkgd"]
                        rays = data["rays"]
                        pixels = data["pixels"]
                        timestamps = data["timestamps"]

                        # rendering
                        rgb, acc, depth, _ = render_image(
                            radiance_field,
                            occupancy_grid,
                            rays,
                            scene_aabb,
                            # rendering options
                            near_plane=None,
                            far_plane=None,
                            render_step_size=render_step_size,
                            render_bkgd=render_bkgd,
                            cone_angle=args.cone_angle,
                            # test options
                            test_chunk_size=args.test_chunk_size,
                            # dnerf options
                            timestamps=timestamps,
                        )
                        mse = F.mse_loss(rgb, pixels)
                        psnr = -10.0 * torch.log(mse) / np.log(10.0)
                        psnrs.append(psnr.item())
                        # imageio.imwrite(
                        #     "acc_binary_test.png",
                        #     ((acc > 0).float().cpu().numpy() * 255).astype(np.uint8),
                        # )
                        # imageio.imwrite(
                        #     "rgb_test.png",
                        #     (rgb.cpu().numpy() * 255).astype(np.uint8),
                        # )
                        # break
                psnr_avg = sum(psnrs) / len(psnrs)
                print(f"evaluation: {psnr_avg=}")
                train_dataset.training = True

            if step == max_steps:
                print("training stops")
                exit()

            step += 1
import argparse
import math
import os
import random
import time
import imageio
......@@ -9,118 +8,17 @@ import numpy as np
import torch
import torch.nn.functional as F
import tqdm
from radiance_fields.mlp import VanillaNeRFRadianceField
from utils import render_image, set_random_seed
from nerfacc import OccupancyField, volumetric_rendering_pipeline
device = "cuda:0"
def _set_random_seed(seed):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
def render_image(
    radiance_field,
    rays,
    timestamps,
    render_bkgd,
    render_step_size,
    test_chunk_size=81920,
):
    """Render the pixels of an image by chunked volumetric rendering.

    Args:
        radiance_field: the radiance field of nerf.
        rays: a `Rays` namedtuple, the rays to be rendered; origins may be
            (H, W, 3) or flat (N, 3).
        timestamps: optional per-ray timestamps (dnerf); None for static scenes.
        render_bkgd: background color composited behind the scene.
        render_step_size: ray-marching step size.
        test_chunk_size: rays per chunk at eval time (training renders all at once).

    Returns:
        (colors, opacities, n_marching_samples, n_rendering_samples) — the
        first two reshaped back to the input ray layout.
        NOTE(review): the tuple does not include a depth image despite what
        an earlier docstring claimed.

    NOTE(review): reads the module-level global `occ_field` (defined
    elsewhere in the script) rather than taking it as a parameter.
    """
    rays_shape = rays.origins.shape
    if len(rays_shape) == 3:
        height, width, _ = rays_shape
        num_rays = height * width
        # flatten (H, W, ...) rays to (H*W, ...)
        rays = namedtuple_map(lambda r: r.reshape([num_rays] + list(r.shape[2:])), rays)
    else:
        num_rays, _ = rays_shape

    # Both closures late-bind `chunk_rays`, which is rebound each loop
    # iteration below before the pipeline invokes them.
    def sigma_fn(frustum_starts, frustum_ends, ray_indices):
        ray_indices = ray_indices.long()
        frustum_origins = chunk_rays.origins[ray_indices]
        frustum_dirs = chunk_rays.viewdirs[ray_indices]
        # sample at segment midpoints
        positions = (
            frustum_origins + frustum_dirs * (frustum_starts + frustum_ends) / 2.0
        )
        if timestamps is None:
            return radiance_field.query_density(positions)
        else:
            if radiance_field.training:
                t = timestamps[ray_indices]
            else:
                # eval renders one frame: broadcast the single timestamp
                t = timestamps.expand_as(positions[:, :1])
            return radiance_field.query_density(positions, t)

    def rgb_sigma_fn(frustum_starts, frustum_ends, ray_indices):
        ray_indices = ray_indices.long()
        frustum_origins = chunk_rays.origins[ray_indices]
        frustum_dirs = chunk_rays.viewdirs[ray_indices]
        positions = (
            frustum_origins + frustum_dirs * (frustum_starts + frustum_ends) / 2.0
        )
        if timestamps is None:
            return radiance_field(positions, frustum_dirs)
        else:
            if radiance_field.training:
                t = timestamps[ray_indices]
            else:
                t = timestamps.expand_as(positions[:, :1])
            return radiance_field(positions, t, frustum_dirs)

    results = []
    # Training renders all rays in one chunk; eval respects test_chunk_size.
    chunk = torch.iinfo(torch.int32).max if radiance_field.training else test_chunk_size
    for i in range(0, num_rays, chunk):
        chunk_rays = namedtuple_map(lambda r: r[i : i + chunk], rays)
        chunk_results = volumetric_rendering_pipeline(
            sigma_fn=sigma_fn,
            rgb_sigma_fn=rgb_sigma_fn,
            rays_o=chunk_rays.origins,
            rays_d=chunk_rays.viewdirs,
            scene_aabb=occ_field.aabb,
            scene_occ_binary=occ_field.occ_grid_binary,
            scene_resolution=occ_field.resolution,
            render_bkgd=render_bkgd,
            render_step_size=render_step_size,
            near_plane=0.0,
            stratified=radiance_field.training,
        )
        results.append(chunk_results)
    # Concatenate tensor outputs across chunks; sample counts stay lists.
    colors, opacities, n_marching_samples, n_rendering_samples = [
        torch.cat(r, dim=0) if isinstance(r[0], torch.Tensor) else r
        for r in zip(*results)
    ]
    return (
        colors.view((*rays_shape[:-1], -1)),
        opacities.view((*rays_shape[:-1], -1)),
        sum(n_marching_samples),
        sum(n_rendering_samples),
    )
from nerfacc import ContractionType, OccupancyGrid
if __name__ == "__main__":
_set_random_seed(42)
device = "cuda:0"
set_random_seed(42)
parser = argparse.ArgumentParser()
parser.add_argument(
"method",
type=str,
default="ngp",
choices=["ngp", "vanilla", "dnerf"],
help="which nerf to use",
)
parser.add_argument(
"--train_split",
type=str,
......@@ -142,161 +40,149 @@ if __name__ == "__main__":
"materials",
"mic",
"ship",
# dnerf
"bouncingballs",
"hellwarrior",
"hook",
"jumpingjacks",
"lego",
"mutant",
"standup",
"trex",
# mipnerf360 unbounded
"garden",
],
help="which scene to use",
)
parser.add_argument(
"--aabb",
type=list,
default=[-1.5, -1.5, -1.5, 1.5, 1.5, 1.5],
type=lambda s: [float(item) for item in s.split(",")],
default="-1.5,-1.5,-1.5,1.5,1.5,1.5",
help="delimited list input",
)
parser.add_argument(
"--test_chunk_size",
type=int,
default=81920,
default=8192,
)
parser.add_argument(
"--unbounded",
action="store_true",
help="whether to use unbounded rendering",
)
parser.add_argument("--cone_angle", type=float, default=0.0)
args = parser.parse_args()
if args.method == "ngp":
from datasets.nerf_synthetic import SubjectLoader, namedtuple_map
from radiance_fields.ngp import NGPradianceField
render_n_samples = 1024
radiance_field = NGPradianceField(aabb=args.aabb).to(device)
optimizer = torch.optim.Adam(radiance_field.parameters(), lr=1e-2, eps=1e-15)
max_steps = 20000
occ_field_warmup_steps = 256
grad_scaler = torch.cuda.amp.GradScaler(2**10)
data_root_fp = "/home/ruilongli/data/nerf_synthetic/"
target_sample_batch_size = 1 << 18
# setup the scene bounding box.
if args.unbounded:
print("Using unbounded rendering")
contraction_type = ContractionType.UN_BOUNDED_SPHERE
# contraction_type = ContractionType.UN_BOUNDED_TANH
scene_aabb = None
near_plane = 0.2
far_plane = 1e4
render_step_size = 1e-2
else:
contraction_type = ContractionType.AABB
scene_aabb = torch.tensor(args.aabb, dtype=torch.float32, device=device)
near_plane = None
far_plane = None
render_step_size = (
(scene_aabb[3:] - scene_aabb[:3]).max()
* math.sqrt(3)
/ render_n_samples
).item()
# setup the radiance field we want to train.
max_steps = 40000
grad_scaler = torch.cuda.amp.GradScaler(1)
radiance_field = VanillaNeRFRadianceField().to(device)
optimizer = torch.optim.Adam(radiance_field.parameters(), lr=5e-4)
scheduler = torch.optim.lr_scheduler.MultiStepLR(
optimizer,
milestones=[max_steps // 2, max_steps * 3 // 4, max_steps * 9 // 10],
gamma=0.33,
)
elif args.method == "vanilla":
from datasets.nerf_synthetic import SubjectLoader, namedtuple_map
from radiance_fields.mlp import VanillaNeRFRadianceField
# setup the dataset
train_dataset_kwargs = {}
test_dataset_kwargs = {}
if args.scene == "garden":
from datasets.nerf_360_v2 import SubjectLoader
radiance_field = VanillaNeRFRadianceField().to(device)
optimizer = torch.optim.Adam(radiance_field.parameters(), lr=5e-4)
max_steps = 40000
occ_field_warmup_steps = 2000
grad_scaler = torch.cuda.amp.GradScaler(1)
data_root_fp = "/home/ruilongli/data/nerf_synthetic/"
data_root_fp = "/home/ruilongli/data/360_v2/"
target_sample_batch_size = 1 << 16
train_dataset_kwargs = {"color_bkgd_aug": "random", "factor": 4}
test_dataset_kwargs = {"factor": 4}
grid_resolution = 128
else:
from datasets.nerf_synthetic import SubjectLoader
elif args.method == "dnerf":
from datasets.dnerf_synthetic import SubjectLoader, namedtuple_map
from radiance_fields.mlp import DNeRFRadianceField
radiance_field = DNeRFRadianceField().to(device)
optimizer = torch.optim.Adam(radiance_field.parameters(), lr=5e-4)
max_steps = 40000
occ_field_warmup_steps = 2000
grad_scaler = torch.cuda.amp.GradScaler(1)
data_root_fp = "/home/ruilongli/data/dnerf/"
data_root_fp = "/home/ruilongli/data/nerf_synthetic/"
target_sample_batch_size = 1 << 16
grid_resolution = 128
scene = args.scene
# setup the scene bounding box.
scene_aabb = torch.tensor(args.aabb)
# setup some rendering settings
render_n_samples = 1024
render_step_size = (
(scene_aabb[3:] - scene_aabb[:3]).max() * math.sqrt(3) / render_n_samples
).item()
# setup dataset
train_dataset = SubjectLoader(
subject_id=scene,
subject_id=args.scene,
root_fp=data_root_fp,
split=args.train_split,
num_rays=target_sample_batch_size // render_n_samples,
# color_bkgd_aug="random",
**train_dataset_kwargs,
)
train_dataset.images = train_dataset.images.to(device)
train_dataset.camtoworlds = train_dataset.camtoworlds.to(device)
train_dataset.K = train_dataset.K.to(device)
if hasattr(train_dataset, "timestamps"):
train_dataset.timestamps = train_dataset.timestamps.to(device)
test_dataset = SubjectLoader(
subject_id=scene,
subject_id=args.scene,
root_fp=data_root_fp,
split="test",
num_rays=None,
**test_dataset_kwargs,
)
test_dataset.images = test_dataset.images.to(device)
test_dataset.camtoworlds = test_dataset.camtoworlds.to(device)
test_dataset.K = test_dataset.K.to(device)
if hasattr(train_dataset, "timestamps"):
test_dataset.timestamps = test_dataset.timestamps.to(device)
scheduler = torch.optim.lr_scheduler.MultiStepLR(
optimizer,
milestones=[max_steps // 2, max_steps * 3 // 4, max_steps * 9 // 10],
gamma=0.33,
)
# setup occupancy field with eval function
def occ_eval_fn(x: torch.Tensor) -> torch.Tensor:
"""Evaluate occupancy given positions.
Args:
x: positions with shape (N, 3).
Returns:
occupancy values with shape (N, 1).
"""
if args.method == "dnerf":
idxs = torch.randint(
0, len(train_dataset.timestamps), (x.shape[0],), device=x.device
)
t = train_dataset.timestamps[idxs]
density_after_activation = radiance_field.query_density(x, t)
else:
density_after_activation = radiance_field.query_density(x)
# those two are similar when density is small.
# occupancy = 1.0 - torch.exp(-density_after_activation * render_step_size)
occupancy = density_after_activation * render_step_size
return occupancy
occ_field = OccupancyField(
occ_eval_fn=occ_eval_fn, aabb=scene_aabb, resolution=128
occupancy_grid = OccupancyGrid(
roi_aabb=args.aabb,
resolution=grid_resolution,
contraction_type=contraction_type,
).to(device)
# training
step = 0
tic = time.time()
data_time = 0
tic_data = time.time()
for epoch in range(10000000):
for i in range(len(train_dataset)):
radiance_field.train()
data = train_dataset[i]
data_time += time.time() - tic_data
render_bkgd = data["color_bkgd"]
rays = data["rays"]
pixels = data["pixels"]
timestamps = data.get("timestamps", None)
# update occupancy grid
occ_field.every_n_step(step, warmup_steps=occ_field_warmup_steps)
occupancy_grid.every_n_step(
step=step,
occ_eval_fn=lambda x: radiance_field.query_opacity(
x, render_step_size
),
)
rgb, acc, counter, compact_counter = render_image(
radiance_field, rays, timestamps, render_bkgd, render_step_size
# render
rgb, acc, depth, n_rendering_samples = render_image(
radiance_field,
occupancy_grid,
rays,
scene_aabb,
# rendering options
near_plane=near_plane,
far_plane=far_plane,
render_step_size=render_step_size,
render_bkgd=render_bkgd,
cone_angle=args.cone_angle,
)
# dynamic batch size for rays to keep sample batch size constant.
num_rays = len(pixels)
num_rays = int(
num_rays * (target_sample_batch_size / float(compact_counter))
num_rays
* (target_sample_batch_size / float(n_rendering_samples))
)
train_dataset.update_num_rays(num_rays)
alive_ray_mask = acc.squeeze(-1) > 0
......@@ -314,13 +200,12 @@ if __name__ == "__main__":
elapsed_time = time.time() - tic
loss = F.mse_loss(rgb[alive_ray_mask], pixels[alive_ray_mask])
print(
f"elapsed_time={elapsed_time:.2f}s (data={data_time:.2f}s) | {step=} | "
f"elapsed_time={elapsed_time:.2f}s | {step=} | "
f"loss={loss:.5f} | "
f"alive_ray_mask={alive_ray_mask.long().sum():d} | "
f"counter={counter:d} | compact_counter={compact_counter:d} | num_rays={len(pixels):d} |"
f"n_rendering_samples={n_rendering_samples:d} | num_rays={len(pixels):d} |"
)
# if time.time() - tic > 300:
if step >= 0 and step % max_steps == 0 and step > 0:
# evaluation
radiance_field.eval()
......@@ -332,40 +217,34 @@ if __name__ == "__main__":
render_bkgd = data["color_bkgd"]
rays = data["rays"]
pixels = data["pixels"]
timestamps = data.get("timestamps", None)
# rendering
rgb, acc, _, _ = render_image(
rgb, acc, depth, _ = render_image(
radiance_field,
occupancy_grid,
rays,
timestamps,
render_bkgd,
render_step_size,
scene_aabb,
# rendering options
near_plane=None,
far_plane=None,
render_step_size=render_step_size,
render_bkgd=render_bkgd,
cone_angle=args.cone_angle,
# test options
test_chunk_size=args.test_chunk_size,
)
mse = F.mse_loss(rgb, pixels)
psnr = -10.0 * torch.log(mse) / np.log(10.0)
psnrs.append(psnr.item())
# if step == max_steps:
# output_dir = os.path.join("./outputs/nerfacc/", scene)
# os.makedirs(output_dir, exist_ok=True)
# save = torch.cat([pixels, rgb], dim=1)
# imageio.imwrite(
# os.path.join(output_dir, "%05d.png" % i),
# (save.cpu().numpy() * 255).astype(np.uint8),
# )
# else:
# imageio.imwrite(
# "acc_binary_test.png",
# ((acc > 0).float().cpu().numpy() * 255).astype(
# np.uint8
# ),
# )
# imageio.imwrite(
# "rgb_test.png",
# (rgb.cpu().numpy() * 255).astype(np.uint8),
# )
# break
# imageio.imwrite(
# "acc_binary_test.png",
# ((acc > 0).float().cpu().numpy() * 255).astype(np.uint8),
# )
# imageio.imwrite(
# "rgb_test.png",
# (rgb.cpu().numpy() * 255).astype(np.uint8),
# )
# break
psnr_avg = sum(psnrs) / len(psnrs)
print(f"evaluation: {psnr_avg=}")
train_dataset.training = True
......@@ -373,6 +252,5 @@ if __name__ == "__main__":
if step == max_steps:
print("training stops")
exit()
tic_data = time.time()
step += 1
import argparse
import math
import os
import time
import imageio
import numpy as np
import torch
import torch.nn.functional as F
import tqdm
from radiance_fields.ngp import NGPradianceField
from utils import render_image, set_random_seed
from nerfacc import ContractionType, OccupancyGrid
if __name__ == "__main__":
    # Train an Instant-NGP radiance field with an occupancy grid;
    # supports bounded (AABB) and unbounded (contracted) scenes.
    device = "cuda:0"
    set_random_seed(42)

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--train_split",
        type=str,
        default="trainval",
        choices=["train", "trainval"],
        help="which train split to use",
    )
    parser.add_argument(
        "--scene",
        type=str,
        default="lego",
        choices=[
            # nerf synthetic
            "chair",
            "drums",
            "ficus",
            "hotdog",
            "lego",
            "materials",
            "mic",
            "ship",
            # mipnerf360 unbounded
            "garden",
        ],
        help="which scene to use",
    )
    parser.add_argument(
        "--aabb",
        type=lambda s: [float(item) for item in s.split(",")],
        default="-1.5,-1.5,-1.5,1.5,1.5,1.5",
        help="delimited list input",
    )
    parser.add_argument(
        "--test_chunk_size",
        type=int,
        default=8192,
    )
    parser.add_argument(
        "--unbounded",
        action="store_true",
        help="whether to use unbounded rendering",
    )
    parser.add_argument("--cone_angle", type=float, default=0.0)
    args = parser.parse_args()

    render_n_samples = 1024

    # setup the scene bounding box.
    if args.unbounded:
        print("Using unbounded rendering")
        contraction_type = ContractionType.UN_BOUNDED_SPHERE
        # contraction_type = ContractionType.UN_BOUNDED_TANH
        scene_aabb = None
        near_plane = 0.2
        far_plane = 1e4
        render_step_size = 1e-2
    else:
        contraction_type = ContractionType.AABB
        scene_aabb = torch.tensor(args.aabb, dtype=torch.float32, device=device)
        near_plane = None
        far_plane = None
        # Step size chosen so the aabb diagonal is covered by render_n_samples steps.
        render_step_size = (
            (scene_aabb[3:] - scene_aabb[:3]).max()
            * math.sqrt(3)
            / render_n_samples
        ).item()

    # setup the radiance field we want to train.
    max_steps = 20000
    grad_scaler = torch.cuda.amp.GradScaler(2**10)
    radiance_field = NGPradianceField(
        aabb=args.aabb,
        unbounded=args.unbounded,
    ).to(device)
    optimizer = torch.optim.Adam(
        radiance_field.parameters(), lr=1e-2, eps=1e-15
    )
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer,
        milestones=[max_steps // 2, max_steps * 3 // 4, max_steps * 9 // 10],
        gamma=0.33,
    )

    # setup the dataset
    # NOTE(review): data paths are hard-coded to the author's machine.
    train_dataset_kwargs = {}
    test_dataset_kwargs = {}
    if args.scene == "garden":
        from datasets.nerf_360_v2 import SubjectLoader

        data_root_fp = "/home/ruilongli/data/360_v2/"
        target_sample_batch_size = 1 << 20
        train_dataset_kwargs = {"color_bkgd_aug": "random", "factor": 4}
        test_dataset_kwargs = {"factor": 4}
        grid_resolution = 128
    else:
        from datasets.nerf_synthetic import SubjectLoader

        data_root_fp = "/home/ruilongli/data/nerf_synthetic/"
        target_sample_batch_size = 1 << 18
        grid_resolution = 128

    train_dataset = SubjectLoader(
        subject_id=args.scene,
        root_fp=data_root_fp,
        split=args.train_split,
        num_rays=target_sample_batch_size // render_n_samples,
        **train_dataset_kwargs,
    )
    train_dataset.images = train_dataset.images.to(device)
    train_dataset.camtoworlds = train_dataset.camtoworlds.to(device)
    train_dataset.K = train_dataset.K.to(device)

    test_dataset = SubjectLoader(
        subject_id=args.scene,
        root_fp=data_root_fp,
        split="test",
        num_rays=None,
        **test_dataset_kwargs,
    )
    test_dataset.images = test_dataset.images.to(device)
    test_dataset.camtoworlds = test_dataset.camtoworlds.to(device)
    test_dataset.K = test_dataset.K.to(device)

    occupancy_grid = OccupancyGrid(
        roi_aabb=args.aabb,
        resolution=grid_resolution,
        contraction_type=contraction_type,
    ).to(device)

    # training
    step = 0
    tic = time.time()
    for epoch in range(10000000):
        for i in range(len(train_dataset)):
            radiance_field.train()
            data = train_dataset[i]

            render_bkgd = data["color_bkgd"]
            rays = data["rays"]
            pixels = data["pixels"]

            # update occupancy grid
            occupancy_grid.every_n_step(
                step=step,
                occ_eval_fn=lambda x: radiance_field.query_opacity(
                    x, render_step_size
                ),
            )

            # render
            rgb, acc, depth, n_rendering_samples = render_image(
                radiance_field,
                occupancy_grid,
                rays,
                scene_aabb,
                # rendering options
                near_plane=near_plane,
                far_plane=far_plane,
                render_step_size=render_step_size,
                render_bkgd=render_bkgd,
                cone_angle=args.cone_angle,
            )

            # dynamic batch size for rays to keep sample batch size constant.
            num_rays = len(pixels)
            num_rays = int(
                num_rays
                * (target_sample_batch_size / float(n_rendering_samples))
            )
            train_dataset.update_num_rays(num_rays)
            alive_ray_mask = acc.squeeze(-1) > 0

            # compute loss only on rays that hit any occupied region.
            loss = F.smooth_l1_loss(rgb[alive_ray_mask], pixels[alive_ray_mask])

            optimizer.zero_grad()
            # do not unscale it because we are using Adam.
            grad_scaler.scale(loss).backward()
            optimizer.step()
            scheduler.step()

            if step % 100 == 0:
                elapsed_time = time.time() - tic
                # report MSE for logging even though training uses smooth L1.
                loss = F.mse_loss(rgb[alive_ray_mask], pixels[alive_ray_mask])
                print(
                    f"elapsed_time={elapsed_time:.2f}s | {step=} | "
                    f"loss={loss:.5f} | "
                    f"alive_ray_mask={alive_ray_mask.long().sum():d} | "
                    f"n_rendering_samples={n_rendering_samples:d} | num_rays={len(pixels):d} |"
                )

            if step >= 0 and step % max_steps == 0 and step > 0:
                # evaluation
                radiance_field.eval()

                psnrs = []
                with torch.no_grad():
                    for i in tqdm.tqdm(range(len(test_dataset))):
                        data = test_dataset[i]
                        render_bkgd = data["color_bkgd"]
                        rays = data["rays"]
                        pixels = data["pixels"]

                        # rendering
                        rgb, acc, depth, _ = render_image(
                            radiance_field,
                            occupancy_grid,
                            rays,
                            scene_aabb,
                            # rendering options
                            near_plane=None,
                            far_plane=None,
                            render_step_size=render_step_size,
                            render_bkgd=render_bkgd,
                            cone_angle=args.cone_angle,
                            # test options
                            test_chunk_size=args.test_chunk_size,
                        )
                        mse = F.mse_loss(rgb, pixels)
                        psnr = -10.0 * torch.log(mse) / np.log(10.0)
                        psnrs.append(psnr.item())
                        # imageio.imwrite(
                        #     "acc_binary_test.png",
                        #     ((acc > 0).float().cpu().numpy() * 255).astype(np.uint8),
                        # )
                        # imageio.imwrite(
                        #     "rgb_test.png",
                        #     (rgb.cpu().numpy() * 255).astype(np.uint8),
                        # )
                        # break
                psnr_avg = sum(psnrs) / len(psnrs)
                print(f"evaluation: {psnr_avg=}")
                train_dataset.training = True

            if step == max_steps:
                print("training stops")
                exit()

            step += 1
import random
from typing import Optional
import numpy as np
import torch
from datasets.utils import Rays, namedtuple_map
from nerfacc import OccupancyGrid, ray_marching, rendering
def set_random_seed(seed):
    """Seed Python's, NumPy's and PyTorch's RNGs for reproducible runs."""
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
def render_image(
    # scene
    radiance_field: torch.nn.Module,
    occupancy_grid: OccupancyGrid,
    rays: Rays,
    scene_aabb: torch.Tensor,
    # rendering options
    near_plane: Optional[float] = None,
    far_plane: Optional[float] = None,
    render_step_size: float = 1e-3,
    render_bkgd: Optional[torch.Tensor] = None,
    cone_angle: float = 0.0,
    # test options
    test_chunk_size: int = 8192,
    # only useful for dnerf
    timestamps: Optional[torch.Tensor] = None,
):
    """Render the pixels of an image.

    Args:
        radiance_field: Queried via ``query_density`` (density only, during
            marching) and ``__call__`` (rgb + density, during rendering).
        occupancy_grid: Grid passed to ``ray_marching`` for empty-space skipping.
        rays: Ray bundle; each field is either (H, W, ...) for a full image
            or (N, ...) for a flat batch of rays.
        scene_aabb: Scene bounds forwarded to ``ray_marching``.
        near_plane: Optional near clipping distance for marching.
        far_plane: Optional far clipping distance for marching.
        render_step_size: Marching step size.
        render_bkgd: Optional background color composited by ``rendering``.
        cone_angle: Cone angle forwarded to ``ray_marching``.
        test_chunk_size: Rays per chunk at eval time; during training all
            rays are rendered in a single chunk (see ``chunk`` below).
        timestamps: If given, switches to time-conditioned (dnerf) queries;
            indexed per-ray during training, broadcast per-sample at eval.

    Returns:
        Tuple ``(colors, opacities, depths, n_rendering_samples)`` where the
        first three are reshaped back to the input ray layout and the last
        is the total number of marching samples across all chunks.
    """
    rays_shape = rays.origins.shape
    if len(rays_shape) == 3:
        # (H, W, ...) layout: flatten to a single batch of H*W rays.
        height, width, _ = rays_shape
        num_rays = height * width
        rays = namedtuple_map(
            lambda r: r.reshape([num_rays] + list(r.shape[2:])), rays
        )
    else:
        num_rays, _ = rays_shape

    def sigma_fn(t_starts, t_ends, ray_indices):
        """Density at segment midpoints.

        NOTE: closes over ``chunk_rays``, which is (re)bound by the chunk
        loop below before ``ray_marching`` invokes this callback.
        """
        ray_indices = ray_indices.long()
        t_origins = chunk_rays.origins[ray_indices]
        t_dirs = chunk_rays.viewdirs[ray_indices]
        # Evaluate the field at the midpoint of each marching segment.
        positions = t_origins + t_dirs * (t_starts + t_ends) / 2.0
        if timestamps is not None:
            # dnerf: per-ray time when training, broadcast scalar at eval.
            t = (
                timestamps[ray_indices]
                if radiance_field.training
                else timestamps.expand_as(positions[:, :1])
            )
            return radiance_field.query_density(positions, t)
        return radiance_field.query_density(positions)

    def rgb_sigma_fn(t_starts, t_ends, ray_indices):
        """Color and density at segment midpoints; also closes over
        ``chunk_rays`` (see ``sigma_fn``)."""
        ray_indices = ray_indices.long()
        t_origins = chunk_rays.origins[ray_indices]
        t_dirs = chunk_rays.viewdirs[ray_indices]
        positions = t_origins + t_dirs * (t_starts + t_ends) / 2.0
        if timestamps is not None:
            # dnerf
            t = (
                timestamps[ray_indices]
                if radiance_field.training
                else timestamps.expand_as(positions[:, :1])
            )
            return radiance_field(positions, t, t_dirs)
        return radiance_field(positions, t_dirs)

    results = []
    # Training renders all rays in one chunk (the batch size is tuned
    # upstream); eval chunks rays to bound peak memory.
    chunk = (
        torch.iinfo(torch.int32).max
        if radiance_field.training
        else test_chunk_size
    )
    for i in range(0, num_rays, chunk):
        # Rebinds the name captured by sigma_fn / rgb_sigma_fn above.
        chunk_rays = namedtuple_map(lambda r: r[i : i + chunk], rays)
        packed_info, t_starts, t_ends = ray_marching(
            chunk_rays.origins,
            chunk_rays.viewdirs,
            scene_aabb=scene_aabb,
            grid=occupancy_grid,
            sigma_fn=sigma_fn,
            near_plane=near_plane,
            far_plane=far_plane,
            render_step_size=render_step_size,
            stratified=radiance_field.training,
            cone_angle=cone_angle,
        )
        rgb, opacity, depth = rendering(
            rgb_sigma_fn,
            packed_info,
            t_starts,
            t_ends,
            render_bkgd=render_bkgd,
        )
        chunk_results = [rgb, opacity, depth, len(t_starts)]
        results.append(chunk_results)
    # Concatenate tensor outputs across chunks; the per-chunk sample counts
    # (plain ints) are kept as a tuple and summed below.
    colors, opacities, depths, n_rendering_samples = [
        torch.cat(r, dim=0) if isinstance(r[0], torch.Tensor) else r
        for r in zip(*results)
    ]
    return (
        colors.view((*rays_shape[:-1], -1)),
        opacities.view((*rays_shape[:-1], -1)),
        depths.view((*rays_shape[:-1], -1)),
        sum(n_rendering_samples),
    )
"""nerfacc - A Python package for the fast volumetric rendering."""
from .occupancy_field import OccupancyField
from .utils import (
from .contraction import ContractionType, contract, contract_inv
from .grid import Grid, OccupancyGrid
from .pipeline import rendering, volumetric_rendering
from .ray_marching import (
ray_aabb_intersect,
ray_marching,
unpack_to_ray_indices,
volumetric_marching,
volumetric_rendering_accumulate,
volumetric_rendering_steps,
volumetric_rendering_weights,
)
from .volumetric_rendering import volumetric_rendering_pipeline
from .vol_rendering import (
accumulate_along_rays,
render_visibility,
render_weight_from_alpha,
render_weight_from_density,
)
# Public API of the nerfacc package; keep in sync with the imports above.
__all__ = [
    "OccupancyField",
    "Grid",
    "OccupancyGrid",
    "ContractionType",
    "contract",
    "contract_inv",
    "ray_aabb_intersect",
    "volumetric_marching",
    "volumetric_rendering_accumulate",
    "volumetric_rendering_steps",
    "volumetric_rendering_weights",
    "volumetric_rendering_pipeline",
    "ray_marching",
    "unpack_to_ray_indices",
    "accumulate_along_rays",
    "render_visibility",
    "render_weight_from_alpha",
    "render_weight_from_density",
    "volumetric_rendering",
    "rendering",
]
from enum import Enum
import torch
import nerfacc.cuda as _C
class ContractionType(Enum):
    """Space contraction options.

    Describes how a :class:`nerfacc.Grid` covers the 3D space, and is used
    by :func:`nerfacc.ray_marching` to decide how to march rays through the
    grid.

    Attributes:
        AABB: Linearly map the region of interest :math:`[x_0, x_1]` to a
            unit cube in :math:`[0, 1]`.

            .. math:: f(x) = \\frac{x - x_0}{x_1 - x_0}

        UN_BOUNDED_TANH: Contract an unbounded space into a unit cube in
            :math:`[0, 1]` using tanh. The region of interest
            :math:`[x_0, x_1]` is first mapped into :math:`[-0.5, +0.5]`
            before applying tanh.

            .. math:: f(x) = \\frac{1}{2}(tanh(\\frac{x - x_0}{x_1 - x_0} - \\frac{1}{2}) + 1)

        UN_BOUNDED_SPHERE: Contract an unbounded space into a unit sphere.
            Used in `Mip-NeRF 360: Unbounded Anti-Aliased Neural Radiance
            Fields`_.

            .. math::
                f(x) =
                \\begin{cases}
                z(x) & ||z(x)|| \\leq 1 \\\\
                (2 - \\frac{1}{||z(x)||})(\\frac{z(x)}{||z(x)||}) & ||z(x)|| > 1
                \\end{cases}

            .. math::
                z(x) = \\frac{x - x_0}{x_1 - x_0} * 2 - 1

    .. _Mip-NeRF 360\\: Unbounded Anti-Aliased Neural Radiance Fields:
        https://arxiv.org/abs/2111.12077
    """

    # Values must stay aligned with the C++ ContractionType enum in the
    # CUDA extension (they are bridged via ``_C.ContractionType(value)``).
    AABB = 0
    UN_BOUNDED_TANH = 1
    UN_BOUNDED_SPHERE = 2
@torch.no_grad()
def contract(
    x: torch.Tensor,
    roi: torch.Tensor,
    type: ContractionType = ContractionType.AABB,
) -> torch.Tensor:
    """Contract the space into [0, 1]^3.

    Args:
        x (torch.Tensor): Un-contracted points.
        roi (torch.Tensor): Region of interest.
        type (ContractionType): Contraction type.

    Returns:
        torch.Tensor: Contracted points ([0, 1]^3).
    """
    # Bridge the Python enum to the extension's enum, then dispatch to CUDA.
    return _C.contract(
        x.contiguous(), roi.contiguous(), _C.ContractionType(type.value)
    )
@torch.no_grad()
def contract_inv(
    x: torch.Tensor,
    roi: torch.Tensor,
    type: ContractionType = ContractionType.AABB,
) -> torch.Tensor:
    """Recover the space from [0, 1]^3 by inverse contraction.

    Args:
        x (torch.Tensor): Contracted points ([0, 1]^3).
        roi (torch.Tensor): Region of interest.
        type (ContractionType): Contraction type.

    Returns:
        torch.Tensor: Un-contracted points.
    """
    # Bridge the Python enum to the extension's enum, then dispatch to CUDA.
    return _C.contract_inv(
        x.contiguous(), roi.contiguous(), _C.ContractionType(type.value)
    )
from typing import Callable
from typing import Any, Callable
def _make_lazy_cuda(name: str) -> Callable:
def _make_lazy_cuda_func(name: str) -> Callable:
def call_cuda(*args, **kwargs):
# pylint: disable=import-outside-toplevel
from ._backend import _C
......@@ -11,14 +11,27 @@ def _make_lazy_cuda(name: str) -> Callable:
return call_cuda
ray_aabb_intersect = _make_lazy_cuda("ray_aabb_intersect")
volumetric_marching = _make_lazy_cuda("volumetric_marching")
volumetric_rendering_steps = _make_lazy_cuda("volumetric_rendering_steps")
volumetric_rendering_weights_forward = _make_lazy_cuda(
"volumetric_rendering_weights_forward"
)
volumetric_rendering_weights_backward = _make_lazy_cuda(
"volumetric_rendering_weights_backward"
)
unpack_to_ray_indices = _make_lazy_cuda("unpack_to_ray_indices")
query_occ = _make_lazy_cuda("query_occ")
def _make_lazy_cuda_attribute(name: str) -> Any:
    """Resolve ``name`` on the compiled CUDA extension.

    Unlike ``_make_lazy_cuda_func`` this resolves eagerly at import time,
    returning ``None`` when the extension could not be built.
    """
    # pylint: disable=import-outside-toplevel
    from ._backend import _C

    return None if _C is None else getattr(_C, name)
# ``ContractionType`` is an attribute (a class) on the extension, so it is
# resolved eagerly; it is None when the CUDA backend is unavailable.
ContractionType = _make_lazy_cuda_attribute("ContractionType")
# The remaining names are lazy wrappers: the extension is only imported on
# the first call.
contract = _make_lazy_cuda_func("contract")
contract_inv = _make_lazy_cuda_func("contract_inv")
query_occ = _make_lazy_cuda_func("query_occ")
ray_aabb_intersect = _make_lazy_cuda_func("ray_aabb_intersect")
ray_marching = _make_lazy_cuda_func("ray_marching")
unpack_to_ray_indices = _make_lazy_cuda_func("unpack_to_ray_indices")
rendering_forward = _make_lazy_cuda_func("rendering_forward")
rendering_backward = _make_lazy_cuda_func("rendering_backward")
rendering_alphas_forward = _make_lazy_cuda_func("rendering_alphas_forward")
rendering_alphas_backward = _make_lazy_cuda_func("rendering_alphas_backward")
......@@ -4,9 +4,6 @@ import os
from subprocess import DEVNULL, call
from rich.console import Console
console = Console()
from torch.utils.cpp_extension import load
PATH = os.path.dirname(os.path.abspath(__file__))
......@@ -14,7 +11,6 @@ PATH = os.path.dirname(os.path.abspath(__file__))
def cuda_toolkit_available():
"""Check if the nvcc is avaiable on the machine."""
# https://github.com/idiap/fast-transformers/blob/master/setup.py
try:
call(["nvcc"], stdout=DEVNULL, stderr=DEVNULL)
return True
......@@ -22,22 +18,18 @@ def cuda_toolkit_available():
return False
_C = None
if cuda_toolkit_available():
sources = glob.glob(os.path.join(PATH, "csrc/*.cu"))
else:
sources = glob.glob(os.path.join(PATH, "csrc/*.cpp"))
extra_cflags = ["-O3"]
extra_cuda_cflags = ["-O3"]
with console.status(
"[bold yellow]Setting up CUDA (This may take a few minutes the first time)",
spinner="bouncingBall",
):
_C = load(
name="nerfacc_cuda",
sources=sources,
extra_cflags=extra_cflags,
extra_cuda_cflags=extra_cuda_cflags,
)
console = Console()
with console.status(
"[bold yellow]Setting up CUDA (This may take a few minutes the first time)",
spinner="bouncingBall",
):
_C = load(
name="nerfacc_cuda",
sources=glob.glob(os.path.join(PATH, "csrc/*.cu")),
extra_cflags=["-O3"],
extra_cuda_cflags=["-O3"],
)
__all__ = ["_C"]
#include "include/helpers_cuda.h"
#include "include/helpers_math.h"
#include "include/helpers_contraction.h"
// Map each 3D sample through the selected contraction into the unit cube.
__global__ void contract_kernel(
    // samples info
    const uint32_t n_samples,
    const float *samples, // (n_samples, 3)
    // contraction
    const float *roi, // (6,) [xmin, ymin, zmin, xmax, ymax, zmax]
    const ContractionType type,
    // outputs
    float *out_samples) // (n_samples, 3)
{
    // One thread per sample; the macro returns for out-of-range thread ids.
    CUDA_GET_THREAD_ID(i, n_samples);

    // locate this thread's input/output rows
    samples += i * 3;
    out_samples += i * 3;

    const float3 roi_min = make_float3(roi[0], roi[1], roi[2]);
    const float3 roi_max = make_float3(roi[3], roi[4], roi[5]);
    const float3 xyz = make_float3(samples[0], samples[1], samples[2]);
    float3 xyz_unit = apply_contraction(xyz, roi_min, roi_max, type);

    out_samples[0] = xyz_unit.x;
    out_samples[1] = xyz_unit.y;
    out_samples[2] = xyz_unit.z;
    return;
}
// Map each contracted sample (unit-cube coordinates) back to world space.
__global__ void contract_inv_kernel(
    // samples info
    const uint32_t n_samples,
    const float *samples, // (n_samples, 3), contracted coordinates
    // contraction
    const float *roi, // (6,) [xmin, ymin, zmin, xmax, ymax, zmax]
    const ContractionType type,
    // outputs
    float *out_samples) // (n_samples, 3), world coordinates
{
    // One thread per sample; the macro returns for out-of-range thread ids.
    CUDA_GET_THREAD_ID(i, n_samples);

    // locate this thread's input/output rows
    samples += i * 3;
    out_samples += i * 3;

    const float3 roi_min = make_float3(roi[0], roi[1], roi[2]);
    const float3 roi_max = make_float3(roi[3], roi[4], roi[5]);
    const float3 xyz_unit = make_float3(samples[0], samples[1], samples[2]);
    float3 xyz = apply_contraction_inv(xyz_unit, roi_min, roi_max, type);

    out_samples[0] = xyz.x;
    out_samples[1] = xyz.y;
    out_samples[2] = xyz.z;
    return;
}
// Contract `samples` (n, 3) into [0, 1]^3 using the given contraction type.
// `roi` is a (6,) tensor [xmin, ymin, zmin, xmax, ymax, zmax].
torch::Tensor contract(
    const torch::Tensor samples,
    // contraction
    const torch::Tensor roi,
    const ContractionType type)
{
    DEVICE_GUARD(samples);
    CHECK_INPUT(samples);
    // Fix: roi is dereferenced as a raw device pointer in the kernel, so it
    // must be validated (CUDA + contiguous) just like `samples`.
    CHECK_INPUT(roi);

    const int n_samples = samples.size(0);
    const int threads = 256;
    const int blocks = CUDA_N_BLOCKS_NEEDED(n_samples, threads);

    torch::Tensor out_samples = torch::zeros({n_samples, 3}, samples.options());

    contract_kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
        n_samples,
        samples.data_ptr<float>(),
        // contraction
        roi.data_ptr<float>(),
        type,
        // outputs
        out_samples.data_ptr<float>());
    return out_samples;
}
// Inverse of `contract`: map contracted samples (n, 3) back to world space.
// `roi` is a (6,) tensor [xmin, ymin, zmin, xmax, ymax, zmax].
torch::Tensor contract_inv(
    const torch::Tensor samples,
    // contraction
    const torch::Tensor roi,
    const ContractionType type)
{
    DEVICE_GUARD(samples);
    CHECK_INPUT(samples);
    // Fix: roi is dereferenced as a raw device pointer in the kernel, so it
    // must be validated (CUDA + contiguous) just like `samples`.
    CHECK_INPUT(roi);

    const int n_samples = samples.size(0);
    const int threads = 256;
    const int blocks = CUDA_N_BLOCKS_NEEDED(n_samples, threads);

    torch::Tensor out_samples = torch::zeros({n_samples, 3}, samples.options());

    contract_inv_kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
        n_samples,
        samples.data_ptr<float>(),
        // contraction
        roi.data_ptr<float>(),
        type,
        // outputs
        out_samples.data_ptr<float>());
    return out_samples;
}
#pragma once
#ifdef __CUDACC__
#define CUDA_HOSTDEV __host__ __device__
#else
#define CUDA_HOSTDEV
#endif
#include <torch/extension.h>
// NOTE(review): identifiers starting with double underscore are reserved in
// C++; renaming would touch every caller, so they are documented as-is.

// Compile-time sqrt(3), the length of a unit cube's diagonal halved twice
// over an axis; handy as a conservative step bound inside a unit cube.
inline constexpr CUDA_HOSTDEV float __SQRT3() { return 1.73205080757f; }

// Swap two values in place (usable from host and device code).
template <typename scalar_t>
inline CUDA_HOSTDEV void __swap(scalar_t &a, scalar_t &b)
{
    scalar_t c = a;
    a = b;
    b = c;
}

// Clamp f into [a, b].
inline CUDA_HOSTDEV float __clamp(float f, float a, float b) { return fmaxf(a, fminf(f, b)); }
// NOTE(review): std::max/std::min may not be callable from device code on
// older toolchains -- confirm this overload is only used on the host.
inline CUDA_HOSTDEV int __clamp(int f, int a, int b) { return std::max(a, std::min(f, b)); }
// Sign of x as +/-1.0f (copysignf maps 0 to +1).
inline CUDA_HOSTDEV float __sign(float x) { return copysignf(1.0, x); }

// Spread the low 10 bits of v so each ends up 3 positions apart (bit i moves
// to bit 3*i); building block for 3D Morton (Z-order) encoding.
inline CUDA_HOSTDEV uint32_t __expand_bits(uint32_t v)
{
    v = (v * 0x00010001u) & 0xFF0000FFu;
    v = (v * 0x00000101u) & 0x0F00F00Fu;
    v = (v * 0x00000011u) & 0xC30C30C3u;
    v = (v * 0x00000005u) & 0x49249249u;
    return v;
}

// Interleave the low 10 bits of x, y, z into a 30-bit Morton code.
inline CUDA_HOSTDEV uint32_t __morton3D(uint32_t x, uint32_t y, uint32_t z)
{
    uint32_t xx = __expand_bits(x);
    uint32_t yy = __expand_bits(y);
    uint32_t zz = __expand_bits(z);
    return xx | (yy << 1) | (zz << 2);
}

// Inverse of __expand_bits: collect every third bit of x back into the low
// 10 bits; call with (code), (code >> 1), (code >> 2) to recover x, y, z.
inline CUDA_HOSTDEV uint32_t __morton3D_invert(uint32_t x)
{
    x = x & 0x49249249;
    x = (x | (x >> 2)) & 0xc30c30c3;
    x = (x | (x >> 4)) & 0x0f00f00f;
    x = (x | (x >> 8)) & 0xff0000ff;
    x = (x | (x >> 16)) & 0x0000ffff;
    return x;
}
\ No newline at end of file
#pragma once
#include "helpers_math.h"
// Mirrors the Python-side nerfacc.ContractionType enum; the numeric values
// must stay aligned (Python bridges via _C.ContractionType(type.value)).
enum ContractionType
{
    AABB = 0,
    UN_BOUNDED_TANH = 1,
    UN_BOUNDED_SPHERE = 2,
};
inline __device__ __host__ float3 roi_to_unit(
    const float3 xyz, const float3 roi_min, const float3 roi_max)
{
    // Affine map: roi -> [0, 1]^3. No clamping, so points outside the roi
    // land outside the unit cube.
    return (xyz - roi_min) / (roi_max - roi_min);
}
inline __device__ __host__ float3 unit_to_roi(
    const float3 xyz, const float3 roi_min, const float3 roi_max)
{
    // Inverse affine map of roi_to_unit: [0, 1]^3 -> roi.
    return xyz * (roi_max - roi_min) + roi_min;
}
inline __device__ __host__ float3 inf_to_unit_tanh(
    const float3 xyz, float3 roi_min, const float3 roi_max)
{
    /**
    Tanh contraction: [-inf, inf]^3 -> (0, 1)^3, applied per axis.
    The roi is first mapped to [-0.5, 0.5]^3, then squashed with
    0.5 * tanh(.) + 0.5 -- so the roi itself lands in roughly
    [0.27, 0.73]^3 (tanh(0.5) ~= 0.46), i.e. only approximately the
    [0.25, 0.75]^3 cube.
    **/
    float3 xyz_unit = roi_to_unit(xyz, roi_min, roi_max); // roi -> [0, 1]^3
    xyz_unit = xyz_unit - 0.5f; // roi -> [-0.5, 0.5]^3
    return make_float3(tanhf(xyz_unit.x), tanhf(xyz_unit.y), tanhf(xyz_unit.z)) * 0.5f + 0.5f;
}
inline __device__ __host__ float3 unit_to_inf_tanh(
    const float3 xyz, float3 roi_min, const float3 roi_max)
{
    /**
    Inverse tanh contraction: (0, 1)^3 -> [-inf, inf]^3.
    Per-axis atanh undoes 0.5 * tanh(.) + 0.5; the result is clamped to
    +/-1e10 so inputs at exactly 0 or 1 stay finite.
    **/
    float3 xyz_unit = clamp(
        make_float3(
            atanhf(xyz.x * 2.0f - 1.0f),
            atanhf(xyz.y * 2.0f - 1.0f),
            atanhf(xyz.z * 2.0f - 1.0f)),
        -1e10f,
        1e10f);
    // Undo the -0.5 shift from the forward map, then rescale back to roi.
    xyz_unit = xyz_unit + 0.5f;
    xyz_unit = unit_to_roi(xyz_unit, roi_min, roi_max);
    return xyz_unit;
}
inline __device__ __host__ float3 inf_to_unit_sphere(
    const float3 xyz, const float3 roi_min, const float3 roi_max)
{
    /** From MipNeRF360
    [-inf, inf]^3 -> sphere of [0, 1]^3;
    roi -> sphere of [0.25, 0.75]^3
    **/
    float3 xyz_unit = roi_to_unit(xyz, roi_min, roi_max); // roi -> [0, 1]^3
    xyz_unit = xyz_unit * 2.0f - 1.0f; // roi -> [-1, 1]^3

    float norm_sq = dot(xyz_unit, xyz_unit);
    float norm = sqrt(norm_sq);
    if (norm > 1.0f)
    {
        // Contract points outside the unit ball: radius r maps to
        // 2 - 1/r, which stays strictly below 2.
        xyz_unit = (2.0f - 1.0f / norm) * (xyz_unit / norm);
    }
    // The radius-2 ball shrinks to a radius-0.5 ball centered at 0.5.
    xyz_unit = xyz_unit * 0.25f + 0.5f; // [-1, 1]^3 -> [0.25, 0.75]^3
    return xyz_unit;
}
inline __device__ __host__ float3 unit_sphere_to_inf(
    const float3 xyz, const float3 roi_min, const float3 roi_max)
{
    /** From MipNeRF360 (exact inverse of inf_to_unit_sphere)
    sphere of [0, 1]^3 -> [-inf, inf]^3;
    sphere of [0.25, 0.75]^3 -> roi
    **/
    float3 xyz_unit = (xyz - 0.5f) * 4.0f; // [0.25, 0.75]^3 -> [-1, 1]^3

    float norm_sq = dot(xyz_unit, xyz_unit);
    float norm = sqrt(norm_sq);
    if (norm > 1.0f)
    {
        // Invert r -> 2 - 1/r: for contracted radius m the original point
        // is y / (2m - m^2); the epsilon guards division as m -> 2.
        xyz_unit = xyz_unit / fmaxf((2.0f * norm - 1.0f * norm_sq), 1e-10f);
    }
    xyz_unit = xyz_unit * 0.5f + 0.5f; // [-1, 1]^3 -> [0, 1]^3
    xyz_unit = unit_to_roi(xyz_unit, roi_min, roi_max); // [0, 1]^3 -> roi
    return xyz_unit;
}
// Dispatch to the forward contraction selected by `type`.
inline __device__ __host__ float3 apply_contraction(
    const float3 xyz, const float3 roi_min, const float3 roi_max,
    const ContractionType type)
{
    switch (type)
    {
    case AABB:
        return roi_to_unit(xyz, roi_min, roi_max);
    case UN_BOUNDED_TANH:
        return inf_to_unit_tanh(xyz, roi_min, roi_max);
    case UN_BOUNDED_SPHERE:
        return inf_to_unit_sphere(xyz, roi_min, roi_max);
    default:
        // Fix: the switch previously had no fall-through return, which is
        // undefined behavior for a non-void function if `type` ever holds
        // an out-of-range value. Fall back to the AABB mapping.
        return roi_to_unit(xyz, roi_min, roi_max);
    }
}
// Dispatch to the inverse contraction selected by `type`.
inline __device__ __host__ float3 apply_contraction_inv(
    const float3 xyz, const float3 roi_min, const float3 roi_max,
    const ContractionType type)
{
    switch (type)
    {
    case AABB:
        return unit_to_roi(xyz, roi_min, roi_max);
    case UN_BOUNDED_TANH:
        return unit_to_inf_tanh(xyz, roi_min, roi_max);
    case UN_BOUNDED_SPHERE:
        return unit_sphere_to_inf(xyz, roi_min, roi_max);
    default:
        // Fix: the switch previously had no fall-through return, which is
        // undefined behavior for a non-void function if `type` ever holds
        // an out-of-range value. Fall back to the AABB mapping.
        return unit_to_roi(xyz, roi_min, roi_max);
    }
}
#pragma once
#include "helpers.h"
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
......@@ -15,4 +15,4 @@
return
#define CUDA_N_BLOCKS_NEEDED(Q, CUDA_N_THREADS) ((Q - 1) / CUDA_N_THREADS + 1)
#define DEVICE_GUARD(_ten) \
const at::cuda::OptionalCUDAGuard device_guard(device_of(_ten));
\ No newline at end of file
const at::cuda::OptionalCUDAGuard device_guard(device_of(_ten));
......@@ -241,7 +241,6 @@ inline __host__ __device__ int4 make_int4(float4 a)
return make_int4(int(a.x), int(a.y), int(a.z), int(a.w));
}
inline __host__ __device__ uint4 make_uint4(uint s)
{
return make_uint4(s, s, s, s);
......@@ -361,7 +360,6 @@ inline __host__ __device__ void operator+=(uint2 &a, uint b)
a.y += b;
}
inline __host__ __device__ float3 operator+(float3 a, float3 b)
{
return make_float3(a.x + b.x, a.y + b.y, a.z + b.z);
......@@ -440,7 +438,7 @@ inline __host__ __device__ float3 operator+(float b, float3 a)
inline __host__ __device__ float4 operator+(float4 a, float4 b)
{
return make_float4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w);
return make_float4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w);
}
inline __host__ __device__ void operator+=(float4 &a, float4 b)
{
......@@ -467,7 +465,7 @@ inline __host__ __device__ void operator+=(float4 &a, float b)
inline __host__ __device__ int4 operator+(int4 a, int4 b)
{
return make_int4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w);
return make_int4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w);
}
inline __host__ __device__ void operator+=(int4 &a, int4 b)
{
......@@ -478,11 +476,11 @@ inline __host__ __device__ void operator+=(int4 &a, int4 b)
}
inline __host__ __device__ int4 operator+(int4 a, int b)
{
return make_int4(a.x + b, a.y + b, a.z + b, a.w + b);
return make_int4(a.x + b, a.y + b, a.z + b, a.w + b);
}
inline __host__ __device__ int4 operator+(int b, int4 a)
{
return make_int4(a.x + b, a.y + b, a.z + b, a.w + b);
return make_int4(a.x + b, a.y + b, a.z + b, a.w + b);
}
inline __host__ __device__ void operator+=(int4 &a, int b)
{
......@@ -494,7 +492,7 @@ inline __host__ __device__ void operator+=(int4 &a, int b)
inline __host__ __device__ uint4 operator+(uint4 a, uint4 b)
{
return make_uint4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w);
return make_uint4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w);
}
inline __host__ __device__ void operator+=(uint4 &a, uint4 b)
{
......@@ -505,11 +503,11 @@ inline __host__ __device__ void operator+=(uint4 &a, uint4 b)
}
inline __host__ __device__ uint4 operator+(uint4 a, uint b)
{
return make_uint4(a.x + b, a.y + b, a.z + b, a.w + b);
return make_uint4(a.x + b, a.y + b, a.z + b, a.w + b);
}
inline __host__ __device__ uint4 operator+(uint b, uint4 a)
{
return make_uint4(a.x + b, a.y + b, a.z + b, a.w + b);
return make_uint4(a.x + b, a.y + b, a.z + b, a.w + b);
}
inline __host__ __device__ void operator+=(uint4 &a, uint b)
{
......@@ -669,7 +667,7 @@ inline __host__ __device__ void operator-=(uint3 &a, uint b)
inline __host__ __device__ float4 operator-(float4 a, float4 b)
{
return make_float4(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w);
return make_float4(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w);
}
inline __host__ __device__ void operator-=(float4 &a, float4 b)
{
......@@ -680,7 +678,7 @@ inline __host__ __device__ void operator-=(float4 &a, float4 b)
}
inline __host__ __device__ float4 operator-(float4 a, float b)
{
return make_float4(a.x - b, a.y - b, a.z - b, a.w - b);
return make_float4(a.x - b, a.y - b, a.z - b, a.w - b);
}
inline __host__ __device__ void operator-=(float4 &a, float b)
{
......@@ -692,7 +690,7 @@ inline __host__ __device__ void operator-=(float4 &a, float b)
inline __host__ __device__ int4 operator-(int4 a, int4 b)
{
return make_int4(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w);
return make_int4(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w);
}
inline __host__ __device__ void operator-=(int4 &a, int4 b)
{
......@@ -703,7 +701,7 @@ inline __host__ __device__ void operator-=(int4 &a, int4 b)
}
inline __host__ __device__ int4 operator-(int4 a, int b)
{
return make_int4(a.x - b, a.y - b, a.z - b, a.w - b);
return make_int4(a.x - b, a.y - b, a.z - b, a.w - b);
}
inline __host__ __device__ int4 operator-(int b, int4 a)
{
......@@ -719,7 +717,7 @@ inline __host__ __device__ void operator-=(int4 &a, int b)
inline __host__ __device__ uint4 operator-(uint4 a, uint4 b)
{
return make_uint4(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w);
return make_uint4(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w);
}
inline __host__ __device__ void operator-=(uint4 &a, uint4 b)
{
......@@ -730,7 +728,7 @@ inline __host__ __device__ void operator-=(uint4 &a, uint4 b)
}
inline __host__ __device__ uint4 operator-(uint4 a, uint b)
{
return make_uint4(a.x - b, a.y - b, a.z - b, a.w - b);
return make_uint4(a.x - b, a.y - b, a.z - b, a.w - b);
}
inline __host__ __device__ uint4 operator-(uint b, uint4 a)
{
......@@ -894,7 +892,7 @@ inline __host__ __device__ void operator*=(uint3 &a, uint b)
inline __host__ __device__ float4 operator*(float4 a, float4 b)
{
return make_float4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w);
return make_float4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w);
}
inline __host__ __device__ void operator*=(float4 &a, float4 b)
{
......@@ -905,7 +903,7 @@ inline __host__ __device__ void operator*=(float4 &a, float4 b)
}
inline __host__ __device__ float4 operator*(float4 a, float b)
{
return make_float4(a.x * b, a.y * b, a.z * b, a.w * b);
return make_float4(a.x * b, a.y * b, a.z * b, a.w * b);
}
inline __host__ __device__ float4 operator*(float b, float4 a)
{
......@@ -921,7 +919,7 @@ inline __host__ __device__ void operator*=(float4 &a, float b)
inline __host__ __device__ int4 operator*(int4 a, int4 b)
{
return make_int4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w);
return make_int4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w);
}
inline __host__ __device__ void operator*=(int4 &a, int4 b)
{
......@@ -932,7 +930,7 @@ inline __host__ __device__ void operator*=(int4 &a, int4 b)
}
inline __host__ __device__ int4 operator*(int4 a, int b)
{
return make_int4(a.x * b, a.y * b, a.z * b, a.w * b);
return make_int4(a.x * b, a.y * b, a.z * b, a.w * b);
}
inline __host__ __device__ int4 operator*(int b, int4 a)
{
......@@ -948,7 +946,7 @@ inline __host__ __device__ void operator*=(int4 &a, int b)
inline __host__ __device__ uint4 operator*(uint4 a, uint4 b)
{
return make_uint4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w);
return make_uint4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w);
}
inline __host__ __device__ void operator*=(uint4 &a, uint4 b)
{
......@@ -959,7 +957,7 @@ inline __host__ __device__ void operator*=(uint4 &a, uint4 b)
}
inline __host__ __device__ uint4 operator*(uint4 a, uint b)
{
return make_uint4(a.x * b, a.y * b, a.z * b, a.w * b);
return make_uint4(a.x * b, a.y * b, a.z * b, a.w * b);
}
inline __host__ __device__ uint4 operator*(uint b, uint4 a)
{
......@@ -1027,7 +1025,7 @@ inline __host__ __device__ float3 operator/(float b, float3 a)
inline __host__ __device__ float4 operator/(float4 a, float4 b)
{
return make_float4(a.x / b.x, a.y / b.y, a.z / b.z, a.w / b.w);
return make_float4(a.x / b.x, a.y / b.y, a.z / b.z, a.w / b.w);
}
inline __host__ __device__ void operator/=(float4 &a, float4 b)
{
......@@ -1038,7 +1036,7 @@ inline __host__ __device__ void operator/=(float4 &a, float4 b)
}
inline __host__ __device__ float4 operator/(float4 a, float b)
{
return make_float4(a.x / b, a.y / b, a.z / b, a.w / b);
return make_float4(a.x / b, a.y / b, a.z / b, a.w / b);
}
inline __host__ __device__ void operator/=(float4 &a, float b)
{
......@@ -1056,43 +1054,43 @@ inline __host__ __device__ float4 operator/(float b, float4 a)
// min
////////////////////////////////////////////////////////////////////////////////
inline __host__ __device__ float2 fminf(float2 a, float2 b)
inline __host__ __device__ float2 fminf(float2 a, float2 b)
{
return make_float2(fminf(a.x,b.x), fminf(a.y,b.y));
return make_float2(fminf(a.x, b.x), fminf(a.y, b.y));
}
inline __host__ __device__ float3 fminf(float3 a, float3 b)
{
return make_float3(fminf(a.x,b.x), fminf(a.y,b.y), fminf(a.z,b.z));
return make_float3(fminf(a.x, b.x), fminf(a.y, b.y), fminf(a.z, b.z));
}
inline __host__ __device__ float4 fminf(float4 a, float4 b)
inline __host__ __device__ float4 fminf(float4 a, float4 b)
{
return make_float4(fminf(a.x,b.x), fminf(a.y,b.y), fminf(a.z,b.z), fminf(a.w,b.w));
return make_float4(fminf(a.x, b.x), fminf(a.y, b.y), fminf(a.z, b.z), fminf(a.w, b.w));
}
inline __host__ __device__ int2 min(int2 a, int2 b)
{
return make_int2(min(a.x,b.x), min(a.y,b.y));
return make_int2(min(a.x, b.x), min(a.y, b.y));
}
inline __host__ __device__ int3 min(int3 a, int3 b)
{
return make_int3(min(a.x,b.x), min(a.y,b.y), min(a.z,b.z));
return make_int3(min(a.x, b.x), min(a.y, b.y), min(a.z, b.z));
}
inline __host__ __device__ int4 min(int4 a, int4 b)
{
return make_int4(min(a.x,b.x), min(a.y,b.y), min(a.z,b.z), min(a.w,b.w));
return make_int4(min(a.x, b.x), min(a.y, b.y), min(a.z, b.z), min(a.w, b.w));
}
inline __host__ __device__ uint2 min(uint2 a, uint2 b)
{
return make_uint2(min(a.x,b.x), min(a.y,b.y));
return make_uint2(min(a.x, b.x), min(a.y, b.y));
}
inline __host__ __device__ uint3 min(uint3 a, uint3 b)
{
return make_uint3(min(a.x,b.x), min(a.y,b.y), min(a.z,b.z));
return make_uint3(min(a.x, b.x), min(a.y, b.y), min(a.z, b.z));
}
inline __host__ __device__ uint4 min(uint4 a, uint4 b)
{
return make_uint4(min(a.x,b.x), min(a.y,b.y), min(a.z,b.z), min(a.w,b.w));
return make_uint4(min(a.x, b.x), min(a.y, b.y), min(a.z, b.z), min(a.w, b.w));
}
////////////////////////////////////////////////////////////////////////////////
......@@ -1101,41 +1099,41 @@ inline __host__ __device__ uint4 min(uint4 a, uint4 b)
inline __host__ __device__ float2 fmaxf(float2 a, float2 b)
{
return make_float2(fmaxf(a.x,b.x), fmaxf(a.y,b.y));
return make_float2(fmaxf(a.x, b.x), fmaxf(a.y, b.y));
}
inline __host__ __device__ float3 fmaxf(float3 a, float3 b)
{
return make_float3(fmaxf(a.x,b.x), fmaxf(a.y,b.y), fmaxf(a.z,b.z));
return make_float3(fmaxf(a.x, b.x), fmaxf(a.y, b.y), fmaxf(a.z, b.z));
}
inline __host__ __device__ float4 fmaxf(float4 a, float4 b)
{
return make_float4(fmaxf(a.x,b.x), fmaxf(a.y,b.y), fmaxf(a.z,b.z), fmaxf(a.w,b.w));
return make_float4(fmaxf(a.x, b.x), fmaxf(a.y, b.y), fmaxf(a.z, b.z), fmaxf(a.w, b.w));
}
inline __host__ __device__ int2 max(int2 a, int2 b)
{
return make_int2(max(a.x,b.x), max(a.y,b.y));
return make_int2(max(a.x, b.x), max(a.y, b.y));
}
inline __host__ __device__ int3 max(int3 a, int3 b)
{
return make_int3(max(a.x,b.x), max(a.y,b.y), max(a.z,b.z));
return make_int3(max(a.x, b.x), max(a.y, b.y), max(a.z, b.z));
}
inline __host__ __device__ int4 max(int4 a, int4 b)
{
return make_int4(max(a.x,b.x), max(a.y,b.y), max(a.z,b.z), max(a.w,b.w));
return make_int4(max(a.x, b.x), max(a.y, b.y), max(a.z, b.z), max(a.w, b.w));
}
inline __host__ __device__ uint2 max(uint2 a, uint2 b)
{
return make_uint2(max(a.x,b.x), max(a.y,b.y));
return make_uint2(max(a.x, b.x), max(a.y, b.y));
}
inline __host__ __device__ uint3 max(uint3 a, uint3 b)
{
return make_uint3(max(a.x,b.x), max(a.y,b.y), max(a.z,b.z));
return make_uint3(max(a.x, b.x), max(a.y, b.y), max(a.z, b.z));
}
inline __host__ __device__ uint4 max(uint4 a, uint4 b)
{
return make_uint4(max(a.x,b.x), max(a.y,b.y), max(a.z,b.z), max(a.w,b.w));
return make_uint4(max(a.x, b.x), max(a.y, b.y), max(a.z, b.z), max(a.w, b.w));
}
////////////////////////////////////////////////////////////////////////////////
......@@ -1145,19 +1143,19 @@ inline __host__ __device__ uint4 max(uint4 a, uint4 b)
inline __device__ __host__ float lerp(float a, float b, float t)
{
return a + t*(b-a);
return a + t * (b - a);
}
inline __device__ __host__ float2 lerp(float2 a, float2 b, float t)
{
return a + t*(b-a);
return a + t * (b - a);
}
inline __device__ __host__ float3 lerp(float3 a, float3 b, float t)
{
return a + t*(b-a);
return a + t * (b - a);
}
inline __device__ __host__ float4 lerp(float4 a, float4 b, float t)
{
return a + t*(b-a);
return a + t * (b - a);
}
////////////////////////////////////////////////////////////////////////////////
......@@ -1426,7 +1424,7 @@ inline __host__ __device__ int4 abs(int4 v)
inline __host__ __device__ float3 reflect(float3 i, float3 n)
{
return i - 2.0f * n * dot(n,i);
return i - 2.0f * n * dot(n, i);
}
////////////////////////////////////////////////////////////////////////////////
......@@ -1435,7 +1433,7 @@ inline __host__ __device__ float3 reflect(float3 i, float3 n)
inline __host__ __device__ float3 cross(float3 a, float3 b)
{
return make_float3(a.y*b.z - a.z*b.y, a.z*b.x - a.x*b.z, a.x*b.y - a.y*b.x);
return make_float3(a.y * b.z - a.z * b.y, a.z * b.x - a.x * b.z, a.x * b.y - a.y * b.x);
}
////////////////////////////////////////////////////////////////////////////////
......@@ -1448,22 +1446,31 @@ inline __host__ __device__ float3 cross(float3 a, float3 b)
inline __device__ __host__ float smoothstep(float a, float b, float x)
{
float y = clamp((x - a) / (b - a), 0.0f, 1.0f);
return (y*y*(3.0f - (2.0f*y)));
return (y * y * (3.0f - (2.0f * y)));
}
inline __device__ __host__ float2 smoothstep(float2 a, float2 b, float2 x)
{
float2 y = clamp((x - a) / (b - a), 0.0f, 1.0f);
return (y*y*(make_float2(3.0f) - (make_float2(2.0f)*y)));
return (y * y * (make_float2(3.0f) - (make_float2(2.0f) * y)));
}
inline __device__ __host__ float3 smoothstep(float3 a, float3 b, float3 x)
{
float3 y = clamp((x - a) / (b - a), 0.0f, 1.0f);
return (y*y*(make_float3(3.0f) - (make_float3(2.0f)*y)));
return (y * y * (make_float3(3.0f) - (make_float3(2.0f) * y)));
}
inline __device__ __host__ float4 smoothstep(float4 a, float4 b, float4 x)
{
float4 y = clamp((x - a) / (b - a), 0.0f, 1.0f);
return (y*y*(make_float4(3.0f) - (make_float4(2.0f)*y)));
return (y * y * (make_float4(3.0f) - (make_float4(2.0f) * y)));
}
////////////////////////////////////////////////////////////////////////////////
// sign
////////////////////////////////////////////////////////////////////////////////
// Component-wise sign of a, returning +/-1.0f per component
// (copysignf maps 0 to +1).
inline __device__ __host__ float3 sign(float3 a)
{
    return make_float3(
        copysignf(1.0f, a.x), copysignf(1.0f, a.y), copysignf(1.0f, a.z));
}
#endif
\ No newline at end of file
#include "include/helpers_cuda.h"
// Exchange the values of a and b in place (host and device callable).
template <typename scalar_t>
inline __host__ __device__ void _swap(scalar_t &a, scalar_t &b)
{
    const scalar_t tmp = a;
    a = b;
    b = tmp;
}
template <typename scalar_t>
inline __host__ __device__ void _ray_aabb_intersect(
const scalar_t* rays_o,
const scalar_t* rays_d,
const scalar_t* aabb,
scalar_t* near,
scalar_t* far
) {
const scalar_t *rays_o,
const scalar_t *rays_d,
const scalar_t *aabb,
scalar_t *near,
scalar_t *far)
{
// aabb is [xmin, ymin, zmin, xmax, ymax, zmax]
scalar_t tmin = (aabb[0] - rays_o[0]) / rays_d[0];
scalar_t tmax = (aabb[3] - rays_o[0]) / rays_d[0];
if (tmin > tmax) __swap(tmin, tmax);
if (tmin > tmax)
_swap(tmin, tmax);
scalar_t tymin = (aabb[1] - rays_o[1]) / rays_d[1];
scalar_t tymax = (aabb[4] - rays_o[1]) / rays_d[1];
if (tymin > tymax) __swap(tymin, tymax);
if (tymin > tymax)
_swap(tymin, tymax);
if (tmin > tymax || tymin > tmax){
if (tmin > tymax || tymin > tmax)
{
*near = 1e10;
*far = 1e10;
return;
}
if (tymin > tmin) tmin = tymin;
if (tymax < tmax) tmax = tymax;
if (tymin > tmin)
tmin = tymin;
if (tymax < tmax)
tmax = tymax;
scalar_t tzmin = (aabb[2] - rays_o[2]) / rays_d[2];
scalar_t tzmax = (aabb[5] - rays_o[2]) / rays_d[2];
if (tzmin > tzmax) __swap(tzmin, tzmax);
if (tzmin > tzmax)
_swap(tzmin, tzmax);
if (tmin > tzmax || tzmin > tmax){
if (tmin > tzmax || tzmin > tmax)
{
*near = 1e10;
*far = 1e10;
return;
}
if (tzmin > tmin) tmin = tzmin;
if (tzmax < tmax) tmax = tzmax;
if (tzmin > tmin)
tmin = tzmin;
if (tzmax < tmax)
tmax = tzmax;
*near = tmin;
*far = tmax;
return;
}
// One-thread-per-ray kernel: intersect each ray with a shared AABB.
//
// Reconstructed from a diff view (duplicated signature, hunk markers eliding
// a few unchanged lines). Writes t_min/t_max per ray; t_min is clamped to a
// minimum of zero so samples never start behind the ray origin.
//
// Args:
//   N:      number of rays.
//   rays_o: [N, 3] flattened ray origins.
//   rays_d: [N, 3] flattened ray directions.
//   aabb:   [6] shared scene bounds [xmin, ymin, zmin, xmax, ymax, zmax].
//   t_min, t_max: [N] outputs.
template <typename scalar_t>
__global__ void ray_aabb_intersect_kernel(
    const int N,
    const scalar_t *rays_o,
    const scalar_t *rays_d,
    const scalar_t *aabb,
    scalar_t *t_min,
    scalar_t *t_max)
{
    CUDA_GET_THREAD_ID(thread_id, N);

    // Advance the pointers to this thread's ray / output slot.
    // NOTE(review): the rays_o advance was elided by the diff hunk; it must
    // mirror the rays_d advance below — confirm against the full source.
    rays_o += thread_id * 3;
    rays_d += thread_id * 3;
    t_min += thread_id;
    t_max += thread_id;

    _ray_aabb_intersect<scalar_t>(rays_o, rays_d, aabb, t_min, t_max);

    // Clamp the entry distance to zero: intersections behind the origin are
    // not useful for marching.
    scalar_t zero = static_cast<scalar_t>(0.f);
    *t_min = *t_min > zero ? *t_min : zero;
    return;
}
/**
 * @brief Ray AABB Test
 *
 * @param rays_o Ray origins. Tensor with shape [N, 3].
 * @param rays_d Normalized ray directions. Tensor with shape [N, 3].
 * @param aabb Scene AABB [xmin, ymin, zmin, xmax, ymax, zmax]. Tensor with shape [6].
 * @return std::vector<torch::Tensor>
 *    Ray AABB intersection {t_min, t_max} with shape [N] respectively. Note the t_min is
 *    clipped to minimum zero. 1e10 is returned if no intersection.
 */
// Host entry point: launch the per-ray AABB intersection kernel.
//
// Reconstructed from a diff view (duplicated signature, hunk markers eliding
// the aabb input check, block-count computation, and t_min allocation).
// Also replaces bitwise `&` with logical `&&` in the shape checks — the
// result is the same for bool operands, but `&&` is the intended operator.
std::vector<torch::Tensor> ray_aabb_intersect(
    const torch::Tensor rays_o, const torch::Tensor rays_d, const torch::Tensor aabb)
{
    DEVICE_GUARD(rays_o);
    CHECK_INPUT(rays_o);
    CHECK_INPUT(rays_d);
    CHECK_INPUT(aabb);
    TORCH_CHECK(rays_o.ndimension() == 2 && rays_o.size(1) == 3)
    TORCH_CHECK(rays_d.ndimension() == 2 && rays_d.size(1) == 3)
    TORCH_CHECK(aabb.ndimension() == 1 && aabb.size(0) == 6)

    const int N = rays_o.size(0);

    const int threads = 256;
    // NOTE(review): block count was elided by the diff hunk; ceil-division is
    // the standard form — confirm the original used the same (or an
    // equivalent helper macro).
    const int blocks = (N + threads - 1) / threads;

    // Outputs, one entry per ray, same dtype/device as the inputs.
    torch::Tensor t_min = torch::empty({N}, rays_o.options());
    torch::Tensor t_max = torch::empty({N}, rays_o.options());

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
        rays_o.scalar_type(), "ray_aabb_intersect",
        ([&]
         { ray_aabb_intersect_kernel<scalar_t><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
               N,
               rays_o.data_ptr<scalar_t>(),
               rays_d.data_ptr<scalar_t>(),
               aabb.data_ptr<scalar_t>(),
               t_min.data_ptr<scalar_t>(),
               t_max.data_ptr<scalar_t>()); }));

    return {t_min, t_max};
}
\ No newline at end of file
#include "include/helpers_cuda.h"
#include "include/helpers_math.h"
#include "include/helpers_contraction.h"
// Forward declarations of the CUDA-implemented entry points bound below.
//
// Reconstructed from a diff view that interleaved removed declarations
// (volumetric_marching, volumetric_rendering_steps,
// volumetric_rendering_weights_forward/backward) with their replacements and
// declared unpack_to_ray_indices twice; only the post-diff set is kept.

// -- rendering (transmittance from densities) --
std::vector<torch::Tensor> rendering_forward(
    torch::Tensor packed_info,
    torch::Tensor starts,
    torch::Tensor ends,
    torch::Tensor sigmas,
    float early_stop_eps,
    bool compression);

torch::Tensor rendering_backward(
    torch::Tensor weights,
    torch::Tensor grad_weights,
    torch::Tensor packed_info,
    torch::Tensor starts,
    torch::Tensor ends,
    torch::Tensor sigmas,
    float early_stop_eps);

// -- rendering (transmittance from per-sample alphas) --
std::vector<torch::Tensor> rendering_alphas_forward(
    torch::Tensor packed_info,
    torch::Tensor alphas,
    float early_stop_eps,
    bool compression);

torch::Tensor rendering_alphas_backward(
    torch::Tensor weights,
    torch::Tensor grad_weights,
    torch::Tensor packed_info,
    torch::Tensor alphas,
    float early_stop_eps);

// -- ray marching --
std::vector<torch::Tensor> ray_aabb_intersect(
    const torch::Tensor rays_o,
    const torch::Tensor rays_d,
    const torch::Tensor aabb);

std::vector<torch::Tensor> ray_marching(
    // rays
    const torch::Tensor rays_o,
    const torch::Tensor rays_d,
    const torch::Tensor t_min,
    const torch::Tensor t_max,
    // occupancy grid & contraction
    const torch::Tensor roi,
    const torch::Tensor grid_binary,
    const ContractionType type,
    // sampling
    const float step_size,
    const float cone_angle);

torch::Tensor unpack_to_ray_indices(const torch::Tensor packed_info);

// -- occupancy grid --
torch::Tensor query_occ(
    const torch::Tensor samples,
    // occupancy grid & contraction
    const torch::Tensor roi,
    const torch::Tensor grid_binary,
    const ContractionType type);

// -- scene contraction --
torch::Tensor contract(
    const torch::Tensor samples,
    // contraction
    const torch::Tensor roi,
    const ContractionType type);

torch::Tensor contract_inv(
    const torch::Tensor samples,
    // contraction
    const torch::Tensor roi,
    const ContractionType type);
// Python bindings for the extension.
//
// Reconstructed from a diff view: the old-side defs for the removed
// volumetric_* functions are dropped, and the duplicated query_occ /
// unpack_to_ray_indices registrations are collapsed to one each.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    // contraction
    py::enum_<ContractionType>(m, "ContractionType")
        .value("AABB", ContractionType::AABB)
        .value("UN_BOUNDED_TANH", ContractionType::UN_BOUNDED_TANH)
        .value("UN_BOUNDED_SPHERE", ContractionType::UN_BOUNDED_SPHERE);
    m.def("contract", &contract);
    m.def("contract_inv", &contract_inv);
    // grid
    m.def("query_occ", &query_occ);
    // marching
    m.def("ray_aabb_intersect", &ray_aabb_intersect);
    m.def("ray_marching", &ray_marching);
    m.def("unpack_to_ray_indices", &unpack_to_ray_indices);
    // rendering
    m.def("rendering_forward", &rendering_forward);
    m.def("rendering_backward", &rendering_backward);
    m.def("rendering_alphas_forward", &rendering_alphas_forward);
    m.def("rendering_alphas_backward", &rendering_alphas_backward);
}
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment