Unverified Commit b53f8a4e authored by Ruilong Li (李瑞龙), committed by GitHub

Support multi-res occ grid & prop net (#176)

* multi-res grid

* prop

* benchmark with prop and occ

* benchmark blender with weight_decay

* docs

* bump version
parent 82fd69c7
......@@ -118,4 +118,5 @@ venv.bak/
.vsocde
benchmarks/
outputs/
\ No newline at end of file
outputs/
data
\ No newline at end of file
nerfacc.ray_resampling
=======================
.. currentmodule:: nerfacc
.. autofunction:: ray_resampling
\ No newline at end of file
......@@ -17,7 +17,6 @@ Utils
render_weight_from_alpha
render_visibility
ray_resampling
pack_data
unpack_data
\ No newline at end of file
......@@ -7,7 +7,7 @@ See code `examples/train_ngp_nerf.py` at our `github repository`_ for details.
Benchmarks
------------
*updated on 2022-10-12*
*updated on 2023-03-14*
Here we trained an `Instant-NGP Nerf`_ model on the `Nerf-Synthetic dataset`_. We follow the same
settings as the Instant-NGP paper, which uses the train split for training and the test split for
......@@ -30,11 +30,15 @@ memory footprint is about 3GB.
+-----------------------+-------+-------+---------+-------+-------+-------+-------+-------+-------+
|(training time) | 309s | 258s | 256s | 316s | 292s | 207s | 218s | 250s | 263s |
+-----------------------+-------+-------+---------+-------+-------+-------+-------+-------+-------+
|Ours 20k steps | 35.50 | 36.16 | 29.14 | 35.23 | 37.15 | 31.71 | 24.88 | 29.91 | 32.46 |
|Ours (occ) 20k steps | 35.81 | 36.87 | 29.59 | 35.70 | 37.45 | 33.63 | 24.98 | 30.64 | 33.08 |
+-----------------------+-------+-------+---------+-------+-------+-------+-------+-------+-------+
|(training time) | 287s | 274s | 269s | 317s | 269s | 244s | 249s | 257s | 271s |
|(training time) | 288s | 255s | 247s | 319s | 274s | 238s | 247s | 252s | 265s |
+-----------------------+-------+-------+---------+-------+-------+-------+-------+-------+-------+
|Ours (prop) 20k steps | 34.06 | 34.32 | 27.93 | 34.27 | 36.47 | 31.39 | 24.39 | 30.57 | 31.68 |
+-----------------------+-------+-------+---------+-------+-------+-------+-------+-------+-------+
|(training time) | 238s | 236s | 250s | 235s | 235s | 236s | 236s | 236s | 240s |
+-----------------------+-------+-------+---------+-------+-------+-------+-------+-------+-------+
.. _`Instant-NGP Nerf`: https://github.com/NVlabs/instant-ngp/tree/51e4107edf48338e9ab0316d56a222e0adf87143
.. _`github repository`: https://github.com/KAIR-BAIR/nerfacc/tree/76c0f9817da4c9c8b5ccf827eb069ee2ce854b75
.. _`github repository`: https://github.com/KAIR-BAIR/nerfacc/
.. _`Nerf-Synthetic dataset`: https://drive.google.com/drive/folders/1JDdLGDruGNXWnM1eqY1FNL9PlStjaKWi
......@@ -5,7 +5,7 @@ See code `examples/train_ngp_nerf.py` at our `github repository`_ for details.
Benchmarks
------------
*updated on 2022-11-07*
*updated on 2023-03-14*
Here we trained an `Instant-NGP Nerf`_ on the `MipNerf360`_ dataset. We used the train
split for training and the test split for evaluation. Our experiments are conducted on a
......@@ -32,12 +32,19 @@ that takes from `MipNerf360`_.
+----------------------+-------+-------+-------+-------+-------+-------+-------+-------+
| MipNerf360 (~days) | 26.98 | 24.37 | 33.46 | 29.55 | 32.23 | 31.63 | 26.40 | 29.23 |
+----------------------+-------+-------+-------+-------+-------+-------+-------+-------+
| Ours (~20 mins) | 25.41 | 22.97 | 30.71 | 27.34 | 30.32 | 31.00 | 23.43 | 27.31 |
| Ours (occ) | 24.76 | 22.38 | 29.72 | 26.80 | 28.02 | 30.67 | 22.39 | 26.39 |
+----------------------+-------+-------+-------+-------+-------+-------+-------+-------+
| Ours (Training time) | 25min | 17min | 19min | 23min | 28min | 20min | 17min | 21min |
| Ours (Training time) | 323s | 302s | 300s | 337s | 347s | 320s | 322s | 322s |
+----------------------+-------+-------+-------+-------+-------+-------+-------+-------+
| Ours (prop) | 25.44 | 23.21 | 30.62 | 26.75 | 30.63 | 30.93 | 25.20 | 27.54 |
+----------------------+-------+-------+-------+-------+-------+-------+-------+-------+
| Ours (Training time) | 308s | 304s | 308s | 306s | 313s | 301s | 287s | 304s |
+----------------------+-------+-------+-------+-------+-------+-------+-------+-------+
Note that `Ours (prop)` is essentially a `Nerfacto`_ model.
.. _`Instant-NGP Nerf`: https://arxiv.org/abs/2201.05989
.. _`MipNerf360`: https://arxiv.org/abs/2111.12077
.. _`Nerf++`: https://arxiv.org/abs/2010.07492
.. _`github repository`: https://github.com/KAIR-BAIR/nerfacc/tree/76c0f9817da4c9c8b5ccf827eb069ee2ce854b75
.. _`github repository`: https://github.com/KAIR-BAIR/nerfacc/
.. _`Nerfacto`: https://github.com/nerfstudio-project/nerfstudio/blob/main/nerfstudio/models/nerfacto.py
\ No newline at end of file
......@@ -123,7 +123,7 @@ Links:
.. toctree::
:glob:
:maxdepth: 1
:caption: Example Usages
:caption: Example Usages and Benchmarks
examples/*
......
......@@ -86,7 +86,7 @@ class SubjectLoader(torch.utils.data.Dataset):
near: float = None,
far: float = None,
batch_over_images: bool = True,
device: str = "cuda:0",
device: str = "cpu",
):
super().__init__()
assert split in self.SPLITS, "%s" % split
......
......@@ -22,7 +22,7 @@ sys.path.insert(
from scene_manager import SceneManager
def _load_colmap(root_fp: str, subject_id: str, split: str, factor: int = 1):
def _load_colmap(root_fp: str, subject_id: str, factor: int = 1):
assert factor in [1, 2, 4, 8]
data_dir = os.path.join(root_fp, subject_id)
......@@ -134,12 +134,66 @@ def _load_colmap(root_fp: str, subject_id: str, split: str, factor: int = 1):
"test": all_indices[all_indices % 8 == 0],
"train": all_indices[all_indices % 8 != 0],
}
indices = split_indices[split]
# All per-image quantities must be re-indexed using the split indices.
images = images[indices]
camtoworlds = camtoworlds[indices]
return images, camtoworlds, K, split_indices
def similarity_from_cameras(c2w, strict_scaling):
"""
reference: nerf-factory
Get a similarity transform to normalize dataset
from c2w (OpenCV convention) cameras
:param c2w: (N, 4)
:return T (4,4) , scale (float)
"""
t = c2w[:, :3, 3]
R = c2w[:, :3, :3]
# (1) Rotate the world so that z+ is the up axis
# we estimate the up axis by averaging the camera up axes
ups = np.sum(R * np.array([0, -1.0, 0]), axis=-1)
world_up = np.mean(ups, axis=0)
world_up /= np.linalg.norm(world_up)
up_camspace = np.array([0.0, -1.0, 0.0])
c = (up_camspace * world_up).sum()
cross = np.cross(world_up, up_camspace)
skew = np.array(
[
[0.0, -cross[2], cross[1]],
[cross[2], 0.0, -cross[0]],
[-cross[1], cross[0], 0.0],
]
)
if c > -1:
R_align = np.eye(3) + skew + (skew @ skew) * 1 / (1 + c)
else:
# In the unlikely case the original data has y+ up axis,
# rotate 180-deg about x axis
R_align = np.array([[-1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
# R_align = np.eye(3) # DEBUG
R = R_align @ R
fwds = np.sum(R * np.array([0, 0.0, 1.0]), axis=-1)
t = (R_align @ t[..., None])[..., 0]
# (2) Recenter the scene using camera center rays
# find the closest point to the origin for each camera's center ray
nearest = t + (fwds * -t).sum(-1)[:, None] * fwds
# median for more robustness
translate = -np.median(nearest, axis=0)
# translate = -np.mean(t, axis=0) # DEBUG
transform = np.eye(4)
transform[:3, 3] = translate
transform[:3, :3] = R_align
# (3) Rescale the scene using camera distances
scale_fn = np.max if strict_scaling else np.median
scale = 1.0 / scale_fn(np.linalg.norm(t + translate, axis=-1))
return transform, scale
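# Illustrative usage (editorial sketch, mirroring SubjectLoader below): the returned
# 4x4 transform and scalar scale are applied to OpenCV-convention camera-to-world
# matrices before splitting into train/test:
#
#   T, sscale = similarity_from_cameras(camtoworlds, strict_scaling=False)
#   camtoworlds = np.einsum("nij, ki -> nkj", camtoworlds, T)  # T @ c2w per camera
#   camtoworlds[:, :3, 3] *= sscale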
class SubjectLoader(torch.utils.data.Dataset):
......@@ -169,7 +223,7 @@ class SubjectLoader(torch.utils.data.Dataset):
far: float = None,
batch_over_images: bool = True,
factor: int = 1,
device: str = "cuda:0",
device: str = "cpu",
):
super().__init__()
assert split in self.SPLITS, "%s" % split
......@@ -184,14 +238,25 @@ class SubjectLoader(torch.utils.data.Dataset):
)
self.color_bkgd_aug = color_bkgd_aug
self.batch_over_images = batch_over_images
self.images, self.camtoworlds, self.K = _load_colmap(
root_fp, subject_id, split, factor
self.images, self.camtoworlds, self.K, split_indices = _load_colmap(
root_fp, subject_id, factor
)
# normalize the scene
T, sscale = similarity_from_cameras(
self.camtoworlds, strict_scaling=False
)
self.images = torch.from_numpy(self.images).to(device).to(torch.uint8)
self.camtoworlds = np.einsum("nij, ki -> nkj", self.camtoworlds, T)
self.camtoworlds[:, :3, 3] *= sscale
# split
indices = split_indices[split]
self.images = self.images[indices]
self.camtoworlds = self.camtoworlds[indices]
# to tensor
self.images = torch.from_numpy(self.images).to(torch.uint8).to(device)
self.camtoworlds = (
torch.from_numpy(self.camtoworlds).to(device).to(torch.float32)
torch.from_numpy(self.camtoworlds).to(torch.float32).to(device)
)
self.K = torch.tensor(self.K).to(device).to(torch.float32)
self.K = torch.tensor(self.K).to(torch.float32).to(device)
self.height, self.width = self.images.shape[1:3]
def __len__(self):
......@@ -275,7 +340,7 @@ class SubjectLoader(torch.utils.data.Dataset):
value=(-1.0 if self.OPENGL_CAMERA else 1.0),
) # [num_rays, 3]
# [n_cams, height, width, 3]
# [num_rays, 3]
directions = (camera_dirs[:, None, :] * c2w[:, :3, :3]).sum(dim=-1)
origins = torch.broadcast_to(c2w[:, :3, -1], directions.shape)
viewdirs = directions / torch.linalg.norm(
......
......@@ -79,7 +79,7 @@ class SubjectLoader(torch.utils.data.Dataset):
near: float = None,
far: float = None,
batch_over_images: bool = True,
device: str = "cuda:0",
device: torch.device = torch.device("cpu"),
):
super().__init__()
assert split in self.SPLITS, "%s" % split
......@@ -110,10 +110,8 @@ class SubjectLoader(torch.utils.data.Dataset):
self.images, self.camtoworlds, self.focal = _load_renderings(
root_fp, subject_id, split
)
self.images = torch.from_numpy(self.images).to(device).to(torch.uint8)
self.camtoworlds = (
torch.from_numpy(self.camtoworlds).to(device).to(torch.float32)
)
self.images = torch.from_numpy(self.images).to(torch.uint8)
self.camtoworlds = torch.from_numpy(self.camtoworlds).to(torch.float32)
self.K = torch.tensor(
[
[self.focal, 0, self.WIDTH / 2.0],
......@@ -121,8 +119,10 @@ class SubjectLoader(torch.utils.data.Dataset):
[0, 0, 1],
],
dtype=torch.float32,
device=device,
) # (3, 3)
self.images = self.images.to(device)
self.camtoworlds = self.camtoworlds.to(device)
self.K = self.K.to(device)
assert self.images.shape[1:3] == (self.HEIGHT, self.WIDTH)
def __len__(self):
......
......@@ -4,6 +4,7 @@ Copyright (c) 2022 Ruilong Li, UC Berkeley.
from typing import Callable, List, Union
import numpy as np
import torch
from torch.autograd import Function
from torch.cuda.amp import custom_bwd, custom_fwd
......@@ -41,13 +42,15 @@ trunc_exp = _TruncExp.apply
def contract_to_unisphere(
x: torch.Tensor,
aabb: torch.Tensor,
ord: Union[str, int] = 2,
# ord: Union[float, int] = float("inf"),
eps: float = 1e-6,
derivative: bool = False,
):
aabb_min, aabb_max = torch.split(aabb, 3, dim=-1)
x = (x - aabb_min) / (aabb_max - aabb_min)
x = x * 2 - 1 # aabb is at [-1, 1]
mag = x.norm(dim=-1, keepdim=True)
mag = torch.linalg.norm(x, ord=ord, dim=-1, keepdim=True)
mask = mag.squeeze(-1) > 1
if derivative:
......@@ -63,8 +66,8 @@ def contract_to_unisphere(
return x
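# Editorial note (hedged): the new `ord` argument selects the norm used for the
# contraction; with ord=2 points outside the unit ball are contracted
# (MipNeRF-360 style), while the commented-out float("inf") alternative would
# contract points outside the unit cube instead.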
class NGPradianceField(torch.nn.Module):
"""Instance-NGP radiance Field"""
class NGPRadianceField(torch.nn.Module):
"""Instance-NGP Radiance Field"""
def __init__(
self,
......@@ -73,6 +76,8 @@ class NGPradianceField(torch.nn.Module):
use_viewdirs: bool = True,
density_activation: Callable = lambda x: trunc_exp(x - 1),
unbounded: bool = False,
base_resolution: int = 16,
max_resolution: int = 4096,
geo_feat_dim: int = 15,
n_levels: int = 16,
log2_hashmap_size: int = 19,
......@@ -85,9 +90,15 @@ class NGPradianceField(torch.nn.Module):
self.use_viewdirs = use_viewdirs
self.density_activation = density_activation
self.unbounded = unbounded
self.base_resolution = base_resolution
self.max_resolution = max_resolution
self.geo_feat_dim = geo_feat_dim
per_level_scale = 1.4472692012786865
self.n_levels = n_levels
self.log2_hashmap_size = log2_hashmap_size
per_level_scale = np.exp(
(np.log(max_resolution) - np.log(base_resolution)) / (n_levels - 1)
).tolist()
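# Editorial check of the geometric growth factor above: with the defaults
# base_resolution=16, max_resolution=4096 and n_levels=16 it equals
# exp((ln(4096) - ln(16)) / 15) = 256 ** (1 / 15) ≈ 1.4473, matching the
# previously hardcoded per_level_scale of 1.4472692012786865.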
if self.use_viewdirs:
self.direction_encoding = tcnn.Encoding(
......@@ -113,7 +124,7 @@ class NGPradianceField(torch.nn.Module):
"n_levels": n_levels,
"n_features_per_level": 2,
"log2_hashmap_size": log2_hashmap_size,
"base_resolution": 16,
"base_resolution": base_resolution,
"per_level_scale": per_level_scale,
},
network_config={
......@@ -138,7 +149,7 @@ class NGPradianceField(torch.nn.Module):
network_config={
"otype": "FullyFusedMLP",
"activation": "ReLU",
"output_activation": "Sigmoid",
"output_activation": "None",
"n_neurons": 64,
"n_hidden_layers": 2,
},
......@@ -168,19 +179,21 @@ class NGPradianceField(torch.nn.Module):
else:
return density
def _query_rgb(self, dir, embedding):
def _query_rgb(self, dir, embedding, apply_act: bool = True):
# tcnn requires directions in the range [0, 1]
if self.use_viewdirs:
dir = (dir + 1.0) / 2.0
d = self.direction_encoding(dir.view(-1, dir.shape[-1]))
h = torch.cat([d, embedding.view(-1, self.geo_feat_dim)], dim=-1)
d = self.direction_encoding(dir.reshape(-1, dir.shape[-1]))
h = torch.cat([d, embedding.reshape(-1, self.geo_feat_dim)], dim=-1)
else:
h = embedding.view(-1, self.geo_feat_dim)
h = embedding.reshape(-1, self.geo_feat_dim)
rgb = (
self.mlp_head(h)
.view(list(embedding.shape[:-1]) + [3])
.reshape(list(embedding.shape[:-1]) + [3])
.to(embedding)
)
if apply_act:
rgb = torch.sigmoid(rgb)
return rgb
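# Editorial note: because the tcnn head's output_activation was changed from
# "Sigmoid" to "None", apply_act=False now returns raw pre-sigmoid values, while
# the default apply_act=True reproduces the previous behaviour.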
def forward(
......@@ -194,4 +207,73 @@ class NGPradianceField(torch.nn.Module):
), f"{positions.shape} v.s. {directions.shape}"
density, embedding = self.query_density(positions, return_feat=True)
rgb = self._query_rgb(directions, embedding=embedding)
return rgb, density
return rgb, density # type: ignore
class NGPDensityField(torch.nn.Module):
"""Instance-NGP Density Field used for resampling"""
def __init__(
self,
aabb: Union[torch.Tensor, List[float]],
num_dim: int = 3,
density_activation: Callable = lambda x: trunc_exp(x - 1),
unbounded: bool = False,
base_resolution: int = 16,
max_resolution: int = 128,
n_levels: int = 5,
log2_hashmap_size: int = 17,
) -> None:
super().__init__()
if not isinstance(aabb, torch.Tensor):
aabb = torch.tensor(aabb, dtype=torch.float32)
self.register_buffer("aabb", aabb)
self.num_dim = num_dim
self.density_activation = density_activation
self.unbounded = unbounded
self.base_resolution = base_resolution
self.max_resolution = max_resolution
self.n_levels = n_levels
self.log2_hashmap_size = log2_hashmap_size
per_level_scale = np.exp(
(np.log(max_resolution) - np.log(base_resolution)) / (n_levels - 1)
).tolist()
self.mlp_base = tcnn.NetworkWithInputEncoding(
n_input_dims=num_dim,
n_output_dims=1,
encoding_config={
"otype": "HashGrid",
"n_levels": n_levels,
"n_features_per_level": 2,
"log2_hashmap_size": log2_hashmap_size,
"base_resolution": base_resolution,
"per_level_scale": per_level_scale,
},
network_config={
"otype": "FullyFusedMLP",
"activation": "ReLU",
"output_activation": "None",
"n_neurons": 64,
"n_hidden_layers": 1,
},
)
def forward(self, positions: torch.Tensor):
if self.unbounded:
positions = contract_to_unisphere(positions, self.aabb)
else:
aabb_min, aabb_max = torch.split(self.aabb, self.num_dim, dim=-1)
positions = (positions - aabb_min) / (aabb_max - aabb_min)
selector = ((positions > 0.0) & (positions < 1.0)).all(dim=-1)
density_before_activation = (
self.mlp_base(positions.view(-1, self.num_dim))
.view(list(positions.shape[:-1]) + [1])
.to(positions)
)
density = (
self.density_activation(density_before_activation)
* selector[..., None]
)
return density
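# Illustrative usage (editorial sketch, not part of the file): the density field
# serves as a lightweight proposal network mapping points to densities, e.g.
#
#   density_field = NGPDensityField(aabb=[-1.5, -1.5, -1.5, 1.5, 1.5, 1.5]).to("cuda")
#   x = torch.rand(4096, 3, device="cuda") * 3.0 - 1.5  # points inside the aabb
#   sigmas = density_field(x)  # -> shape (4096, 1)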
......@@ -3,4 +3,5 @@ opencv-python
imageio
numpy
tqdm
scipy
\ No newline at end of file
scipy
lpips
\ No newline at end of file
......@@ -16,227 +16,221 @@ from datasets.dnerf_synthetic import SubjectLoader
from radiance_fields.mlp import DNeRFRadianceField
from utils import render_image, set_random_seed
from nerfacc import ContractionType, OccupancyGrid
if __name__ == "__main__":
device = "cuda:0"
set_random_seed(42)
parser = argparse.ArgumentParser()
parser.add_argument(
"--data_root",
type=str,
default=str(pathlib.Path.cwd() / "data/dnerf"),
help="the root dir of the dataset",
)
parser.add_argument(
"--train_split",
type=str,
default="train",
choices=["train"],
help="which train split to use",
)
parser.add_argument(
"--scene",
type=str,
default="lego",
choices=[
# dnerf
"bouncingballs",
"hellwarrior",
"hook",
"jumpingjacks",
"lego",
"mutant",
"standup",
"trex",
],
help="which scene to use",
)
parser.add_argument(
"--aabb",
type=lambda s: [float(item) for item in s.split(",")],
default="-1.5,-1.5,-1.5,1.5,1.5,1.5",
help="delimited list input",
)
parser.add_argument(
"--test_chunk_size",
type=int,
default=8192,
)
parser.add_argument("--cone_angle", type=float, default=0.0)
args = parser.parse_args()
render_n_samples = 1024
# setup the scene bounding box.
contraction_type = ContractionType.AABB
scene_aabb = torch.tensor(args.aabb, dtype=torch.float32, device=device)
near_plane = None
far_plane = None
render_step_size = (
(scene_aabb[3:] - scene_aabb[:3]).max()
* math.sqrt(3)
/ render_n_samples
).item()
# setup the radiance field we want to train.
max_steps = 30000
grad_scaler = torch.cuda.amp.GradScaler(1)
radiance_field = DNeRFRadianceField().to(device)
optimizer = torch.optim.Adam(radiance_field.parameters(), lr=5e-4)
scheduler = torch.optim.lr_scheduler.MultiStepLR(
optimizer,
milestones=[
max_steps // 2,
max_steps * 3 // 4,
max_steps * 5 // 6,
max_steps * 9 // 10,
],
gamma=0.33,
)
# setup the dataset
target_sample_batch_size = 1 << 16
grid_resolution = 128
train_dataset = SubjectLoader(
subject_id=args.scene,
root_fp=args.data_root,
split=args.train_split,
num_rays=target_sample_batch_size // render_n_samples,
)
test_dataset = SubjectLoader(
subject_id=args.scene,
root_fp=args.data_root,
split="test",
num_rays=None,
)
occupancy_grid = OccupancyGrid(
roi_aabb=args.aabb,
resolution=grid_resolution,
contraction_type=contraction_type,
).to(device)
# training
step = 0
tic = time.time()
for epoch in range(10000000):
for i in range(len(train_dataset)):
radiance_field.train()
data = train_dataset[i]
render_bkgd = data["color_bkgd"]
rays = data["rays"]
pixels = data["pixels"]
timestamps = data["timestamps"]
# update occupancy grid
occupancy_grid.every_n_step(
step=step,
occ_eval_fn=lambda x: radiance_field.query_opacity(
x, timestamps, render_step_size
),
from nerfacc import OccupancyGrid
device = "cuda:0"
set_random_seed(42)
parser = argparse.ArgumentParser()
parser.add_argument(
"--data_root",
type=str,
default=str(pathlib.Path.cwd() / "data/dnerf"),
help="the root dir of the dataset",
)
parser.add_argument(
"--train_split",
type=str,
default="train",
choices=["train"],
help="which train split to use",
)
parser.add_argument(
"--scene",
type=str,
default="lego",
choices=[
# dnerf
"bouncingballs",
"hellwarrior",
"hook",
"jumpingjacks",
"lego",
"mutant",
"standup",
"trex",
],
help="which scene to use",
)
parser.add_argument(
"--aabb",
type=lambda s: [float(item) for item in s.split(",")],
default="-1.5,-1.5,-1.5,1.5,1.5,1.5",
help="delimited list input",
)
parser.add_argument(
"--test_chunk_size",
type=int,
default=8192,
)
parser.add_argument("--cone_angle", type=float, default=0.0)
args = parser.parse_args()
render_n_samples = 1024
# setup the scene bounding box.
scene_aabb = torch.tensor(args.aabb, dtype=torch.float32, device=device)
near_plane = None
far_plane = None
render_step_size = (
(scene_aabb[3:] - scene_aabb[:3]).max() * math.sqrt(3) / render_n_samples
).item()
# setup the radiance field we want to train.
max_steps = 30000
grad_scaler = torch.cuda.amp.GradScaler(1)
radiance_field = DNeRFRadianceField().to(device)
optimizer = torch.optim.Adam(radiance_field.parameters(), lr=5e-4)
scheduler = torch.optim.lr_scheduler.MultiStepLR(
optimizer,
milestones=[
max_steps // 2,
max_steps * 3 // 4,
max_steps * 5 // 6,
max_steps * 9 // 10,
],
gamma=0.33,
)
# setup the dataset
target_sample_batch_size = 1 << 16
grid_resolution = 128
train_dataset = SubjectLoader(
subject_id=args.scene,
root_fp=args.data_root,
split=args.train_split,
num_rays=target_sample_batch_size // render_n_samples,
device=device,
)
test_dataset = SubjectLoader(
subject_id=args.scene,
root_fp=args.data_root,
split="test",
num_rays=None,
device=device,
)
occupancy_grid = OccupancyGrid(
roi_aabb=args.aabb, resolution=grid_resolution
).to(device)
# training
step = 0
tic = time.time()
for epoch in range(10000000):
for i in range(len(train_dataset)):
radiance_field.train()
data = train_dataset[i]
render_bkgd = data["color_bkgd"]
rays = data["rays"]
pixels = data["pixels"]
timestamps = data["timestamps"]
# update occupancy grid
occupancy_grid.every_n_step(
step=step,
occ_eval_fn=lambda x: radiance_field.query_opacity(
x, timestamps, render_step_size
),
)
# render
rgb, acc, depth, n_rendering_samples = render_image(
radiance_field,
occupancy_grid,
rays,
scene_aabb,
# rendering options
near_plane=near_plane,
far_plane=far_plane,
render_step_size=render_step_size,
render_bkgd=render_bkgd,
cone_angle=args.cone_angle,
alpha_thre=0.01 if step > 1000 else 0.00,
# dnerf options
timestamps=timestamps,
)
if n_rendering_samples == 0:
continue
# dynamic batch size for rays to keep sample batch size constant.
num_rays = len(pixels)
num_rays = int(
num_rays * (target_sample_batch_size / float(n_rendering_samples))
)
train_dataset.update_num_rays(num_rays)
alive_ray_mask = acc.squeeze(-1) > 0
# compute loss
loss = F.smooth_l1_loss(rgb[alive_ray_mask], pixels[alive_ray_mask])
optimizer.zero_grad()
# do not unscale it because we are using Adam.
grad_scaler.scale(loss).backward()
optimizer.step()
scheduler.step()
if step % 5000 == 0:
elapsed_time = time.time() - tic
loss = F.mse_loss(rgb[alive_ray_mask], pixels[alive_ray_mask])
print(
f"elapsed_time={elapsed_time:.2f}s | step={step} | "
f"loss={loss:.5f} | "
f"alive_ray_mask={alive_ray_mask.long().sum():d} | "
f"n_rendering_samples={n_rendering_samples:d} | num_rays={len(pixels):d} |"
)
# render
rgb, acc, depth, n_rendering_samples = render_image(
radiance_field,
occupancy_grid,
rays,
scene_aabb,
# rendering options
near_plane=near_plane,
far_plane=far_plane,
render_step_size=render_step_size,
render_bkgd=render_bkgd,
cone_angle=args.cone_angle,
alpha_thre=0.01 if step > 1000 else 0.00,
# dnerf options
timestamps=timestamps,
)
if n_rendering_samples == 0:
continue
# dynamic batch size for rays to keep sample batch size constant.
num_rays = len(pixels)
num_rays = int(
num_rays
* (target_sample_batch_size / float(n_rendering_samples))
)
train_dataset.update_num_rays(num_rays)
alive_ray_mask = acc.squeeze(-1) > 0
# compute loss
loss = F.smooth_l1_loss(rgb[alive_ray_mask], pixels[alive_ray_mask])
optimizer.zero_grad()
# do not unscale it because we are using Adam.
grad_scaler.scale(loss).backward()
optimizer.step()
scheduler.step()
if step % 5000 == 0:
elapsed_time = time.time() - tic
loss = F.mse_loss(rgb[alive_ray_mask], pixels[alive_ray_mask])
print(
f"elapsed_time={elapsed_time:.2f}s | step={step} | "
f"loss={loss:.5f} | "
f"alive_ray_mask={alive_ray_mask.long().sum():d} | "
f"n_rendering_samples={n_rendering_samples:d} | num_rays={len(pixels):d} |"
)
if step > 0 and step % max_steps == 0:
# evaluation
radiance_field.eval()
psnrs = []
with torch.no_grad():
for i in tqdm.tqdm(range(len(test_dataset))):
data = test_dataset[i]
render_bkgd = data["color_bkgd"]
rays = data["rays"]
pixels = data["pixels"]
timestamps = data["timestamps"]
# rendering
rgb, acc, depth, _ = render_image(
radiance_field,
occupancy_grid,
rays,
scene_aabb,
# rendering options
near_plane=None,
far_plane=None,
render_step_size=render_step_size,
render_bkgd=render_bkgd,
cone_angle=args.cone_angle,
alpha_thre=0.01,
# test options
test_chunk_size=args.test_chunk_size,
# dnerf options
timestamps=timestamps,
)
mse = F.mse_loss(rgb, pixels)
psnr = -10.0 * torch.log(mse) / np.log(10.0)
psnrs.append(psnr.item())
# imageio.imwrite(
# "acc_binary_test.png",
# ((acc > 0).float().cpu().numpy() * 255).astype(np.uint8),
# )
# imageio.imwrite(
# "rgb_test.png",
# (rgb.cpu().numpy() * 255).astype(np.uint8),
# )
# break
psnr_avg = sum(psnrs) / len(psnrs)
print(f"evaluation: psnr_avg={psnr_avg}")
train_dataset.training = True
if step == max_steps:
print("training stops")
exit()
step += 1
if step > 0 and step % max_steps == 0:
# evaluation
radiance_field.eval()
psnrs = []
with torch.no_grad():
for i in tqdm.tqdm(range(len(test_dataset))):
data = test_dataset[i]
render_bkgd = data["color_bkgd"]
rays = data["rays"]
pixels = data["pixels"]
timestamps = data["timestamps"]
# rendering
rgb, acc, depth, _ = render_image(
radiance_field,
occupancy_grid,
rays,
scene_aabb,
# rendering options
near_plane=None,
far_plane=None,
render_step_size=render_step_size,
render_bkgd=render_bkgd,
cone_angle=args.cone_angle,
alpha_thre=0.01,
# test options
test_chunk_size=args.test_chunk_size,
# dnerf options
timestamps=timestamps,
)
mse = F.mse_loss(rgb, pixels)
psnr = -10.0 * torch.log(mse) / np.log(10.0)
psnrs.append(psnr.item())
# imageio.imwrite(
# "acc_binary_test.png",
# ((acc > 0).float().cpu().numpy() * 255).astype(np.uint8),
# )
# imageio.imwrite(
# "rgb_test.png",
# (rgb.cpu().numpy() * 255).astype(np.uint8),
# )
# break
psnr_avg = sum(psnrs) / len(psnrs)
print(f"evaluation: psnr_avg={psnr_avg}")
train_dataset.training = True
if step == max_steps:
print("training stops")
exit()
step += 1
......@@ -17,248 +17,245 @@ from utils import render_image, set_random_seed
from nerfacc import ContractionType, OccupancyGrid
if __name__ == "__main__":
device = "cuda:0"
set_random_seed(42)
device = "cuda:0"
set_random_seed(42)
parser = argparse.ArgumentParser()
parser.add_argument(
"--data_root",
type=str,
default=str(pathlib.Path.cwd() / "data/nerf_synthetic"),
help="the root dir of the dataset",
)
parser.add_argument(
"--train_split",
type=str,
default="trainval",
choices=["train", "trainval"],
help="which train split to use",
)
parser.add_argument(
"--scene",
type=str,
default="lego",
choices=[
# nerf synthetic
"chair",
"drums",
"ficus",
"hotdog",
"lego",
"materials",
"mic",
"ship",
# mipnerf360 unbounded
"garden",
],
help="which scene to use",
)
parser.add_argument(
"--aabb",
type=lambda s: [float(item) for item in s.split(",")],
default="-1.5,-1.5,-1.5,1.5,1.5,1.5",
help="delimited list input",
)
parser.add_argument(
"--test_chunk_size",
type=int,
default=8192,
)
parser.add_argument(
"--unbounded",
action="store_true",
help="whether to use unbounded rendering",
)
parser.add_argument("--cone_angle", type=float, default=0.0)
args = parser.parse_args()
parser = argparse.ArgumentParser()
parser.add_argument(
"--data_root",
type=str,
default=str(pathlib.Path.cwd() / "data/nerf_synthetic"),
help="the root dir of the dataset",
)
parser.add_argument(
"--train_split",
type=str,
default="trainval",
choices=["train", "trainval"],
help="which train split to use",
)
parser.add_argument(
"--scene",
type=str,
default="lego",
choices=[
# nerf synthetic
"chair",
"drums",
"ficus",
"hotdog",
"lego",
"materials",
"mic",
"ship",
# mipnerf360 unbounded
"garden",
],
help="which scene to use",
)
parser.add_argument(
"--aabb",
type=lambda s: [float(item) for item in s.split(",")],
default="-1.5,-1.5,-1.5,1.5,1.5,1.5",
help="delimited list input",
)
parser.add_argument(
"--test_chunk_size",
type=int,
default=8192,
)
parser.add_argument(
"--unbounded",
action="store_true",
help="whether to use unbounded rendering",
)
parser.add_argument("--cone_angle", type=float, default=0.0)
args = parser.parse_args()
render_n_samples = 1024
render_n_samples = 1024
# setup the scene bounding box.
if args.unbounded:
print("Using unbounded rendering")
contraction_type = ContractionType.UN_BOUNDED_SPHERE
# contraction_type = ContractionType.UN_BOUNDED_TANH
scene_aabb = None
near_plane = 0.2
far_plane = 1e4
render_step_size = 1e-2
else:
contraction_type = ContractionType.AABB
scene_aabb = torch.tensor(args.aabb, dtype=torch.float32, device=device)
near_plane = None
far_plane = None
render_step_size = (
(scene_aabb[3:] - scene_aabb[:3]).max()
* math.sqrt(3)
/ render_n_samples
).item()
# setup the scene bounding box.
if args.unbounded:
print("Using unbounded rendering")
contraction_type = ContractionType.UN_BOUNDED_SPHERE
# contraction_type = ContractionType.UN_BOUNDED_TANH
scene_aabb = None
near_plane = 0.2
far_plane = 1e4
render_step_size = 1e-2
else:
contraction_type = ContractionType.AABB
scene_aabb = torch.tensor(args.aabb, dtype=torch.float32, device=device)
near_plane = None
far_plane = None
render_step_size = (
(scene_aabb[3:] - scene_aabb[:3]).max()
* math.sqrt(3)
/ render_n_samples
).item()
# setup the radiance field we want to train.
max_steps = 50000
grad_scaler = torch.cuda.amp.GradScaler(1)
radiance_field = VanillaNeRFRadianceField().to(device)
optimizer = torch.optim.Adam(radiance_field.parameters(), lr=5e-4)
scheduler = torch.optim.lr_scheduler.MultiStepLR(
optimizer,
milestones=[
max_steps // 2,
max_steps * 3 // 4,
max_steps * 5 // 6,
max_steps * 9 // 10,
],
gamma=0.33,
)
# setup the radiance field we want to train.
max_steps = 50000
grad_scaler = torch.cuda.amp.GradScaler(1)
radiance_field = VanillaNeRFRadianceField().to(device)
optimizer = torch.optim.Adam(radiance_field.parameters(), lr=5e-4)
scheduler = torch.optim.lr_scheduler.MultiStepLR(
optimizer,
milestones=[
max_steps // 2,
max_steps * 3 // 4,
max_steps * 5 // 6,
max_steps * 9 // 10,
],
gamma=0.33,
)
# setup the dataset
train_dataset_kwargs = {}
test_dataset_kwargs = {}
if args.scene == "garden":
from datasets.nerf_360_v2 import SubjectLoader
# setup the dataset
train_dataset_kwargs = {}
test_dataset_kwargs = {}
if args.scene == "garden":
from datasets.nerf_360_v2 import SubjectLoader
target_sample_batch_size = 1 << 16
train_dataset_kwargs = {"color_bkgd_aug": "random", "factor": 4}
test_dataset_kwargs = {"factor": 4}
grid_resolution = 128
else:
from datasets.nerf_synthetic import SubjectLoader
target_sample_batch_size = 1 << 16
train_dataset_kwargs = {"color_bkgd_aug": "random", "factor": 4}
test_dataset_kwargs = {"factor": 4}
grid_resolution = 128
else:
from datasets.nerf_synthetic import SubjectLoader
target_sample_batch_size = 1 << 16
grid_resolution = 128
target_sample_batch_size = 1 << 16
grid_resolution = 128
train_dataset = SubjectLoader(
subject_id=args.scene,
root_fp=args.data_root,
split=args.train_split,
num_rays=target_sample_batch_size // render_n_samples,
**train_dataset_kwargs,
)
train_dataset = SubjectLoader(
subject_id=args.scene,
root_fp=args.data_root,
split=args.train_split,
num_rays=target_sample_batch_size // render_n_samples,
**train_dataset_kwargs,
)
test_dataset = SubjectLoader(
subject_id=args.scene,
root_fp=args.data_root,
split="test",
num_rays=None,
**test_dataset_kwargs,
)
test_dataset = SubjectLoader(
subject_id=args.scene,
root_fp=args.data_root,
split="test",
num_rays=None,
**test_dataset_kwargs,
)
occupancy_grid = OccupancyGrid(
roi_aabb=args.aabb,
resolution=grid_resolution,
contraction_type=contraction_type,
).to(device)
occupancy_grid = OccupancyGrid(
roi_aabb=args.aabb,
resolution=grid_resolution,
contraction_type=contraction_type,
).to(device)
# training
step = 0
tic = time.time()
for epoch in range(10000000):
for i in range(len(train_dataset)):
radiance_field.train()
data = train_dataset[i]
# training
step = 0
tic = time.time()
for epoch in range(10000000):
for i in range(len(train_dataset)):
radiance_field.train()
data = train_dataset[i]
render_bkgd = data["color_bkgd"]
rays = data["rays"]
pixels = data["pixels"]
render_bkgd = data["color_bkgd"]
rays = data["rays"]
pixels = data["pixels"]
# update occupancy grid
occupancy_grid.every_n_step(
step=step,
occ_eval_fn=lambda x: radiance_field.query_opacity(
x, render_step_size
),
)
# update occupancy grid
occupancy_grid.every_n_step(
step=step,
occ_eval_fn=lambda x: radiance_field.query_opacity(
x, render_step_size
),
)
# render
rgb, acc, depth, n_rendering_samples = render_image(
radiance_field,
occupancy_grid,
rays,
scene_aabb,
# rendering options
near_plane=near_plane,
far_plane=far_plane,
render_step_size=render_step_size,
render_bkgd=render_bkgd,
cone_angle=args.cone_angle,
)
if n_rendering_samples == 0:
continue
# render
rgb, acc, depth, n_rendering_samples = render_image(
radiance_field,
occupancy_grid,
rays,
scene_aabb,
# rendering options
near_plane=near_plane,
far_plane=far_plane,
render_step_size=render_step_size,
render_bkgd=render_bkgd,
cone_angle=args.cone_angle,
)
if n_rendering_samples == 0:
continue
# dynamic batch size for rays to keep sample batch size constant.
num_rays = len(pixels)
num_rays = int(
num_rays * (target_sample_batch_size / float(n_rendering_samples))
)
train_dataset.update_num_rays(num_rays)
alive_ray_mask = acc.squeeze(-1) > 0
# dynamic batch size for rays to keep sample batch size constant.
num_rays = len(pixels)
num_rays = int(
num_rays
* (target_sample_batch_size / float(n_rendering_samples))
)
train_dataset.update_num_rays(num_rays)
alive_ray_mask = acc.squeeze(-1) > 0
# compute loss
loss = F.smooth_l1_loss(rgb[alive_ray_mask], pixels[alive_ray_mask])
# compute loss
loss = F.smooth_l1_loss(rgb[alive_ray_mask], pixels[alive_ray_mask])
optimizer.zero_grad()
# do not unscale it because we are using Adam.
grad_scaler.scale(loss).backward()
optimizer.step()
scheduler.step()
optimizer.zero_grad()
# do not unscale it because we are using Adam.
grad_scaler.scale(loss).backward()
optimizer.step()
scheduler.step()
if step % 5000 == 0:
elapsed_time = time.time() - tic
loss = F.mse_loss(rgb[alive_ray_mask], pixels[alive_ray_mask])
print(
f"elapsed_time={elapsed_time:.2f}s | step={step} | "
f"loss={loss:.5f} | "
f"alive_ray_mask={alive_ray_mask.long().sum():d} | "
f"n_rendering_samples={n_rendering_samples:d} | num_rays={len(pixels):d} |"
)
if step % 5000 == 0:
elapsed_time = time.time() - tic
loss = F.mse_loss(rgb[alive_ray_mask], pixels[alive_ray_mask])
print(
f"elapsed_time={elapsed_time:.2f}s | step={step} | "
f"loss={loss:.5f} | "
f"alive_ray_mask={alive_ray_mask.long().sum():d} | "
f"n_rendering_samples={n_rendering_samples:d} | num_rays={len(pixels):d} |"
)
if step > 0 and step % max_steps == 0:
# evaluation
radiance_field.eval()
if step > 0 and step % max_steps == 0:
# evaluation
radiance_field.eval()
psnrs = []
with torch.no_grad():
for i in tqdm.tqdm(range(len(test_dataset))):
data = test_dataset[i]
render_bkgd = data["color_bkgd"]
rays = data["rays"]
pixels = data["pixels"]
psnrs = []
with torch.no_grad():
for i in tqdm.tqdm(range(len(test_dataset))):
data = test_dataset[i]
render_bkgd = data["color_bkgd"]
rays = data["rays"]
pixels = data["pixels"]
# rendering
rgb, acc, depth, _ = render_image(
radiance_field,
occupancy_grid,
rays,
scene_aabb,
# rendering options
near_plane=None,
far_plane=None,
render_step_size=render_step_size,
render_bkgd=render_bkgd,
cone_angle=args.cone_angle,
# test options
test_chunk_size=args.test_chunk_size,
)
mse = F.mse_loss(rgb, pixels)
psnr = -10.0 * torch.log(mse) / np.log(10.0)
psnrs.append(psnr.item())
# imageio.imwrite(
# "acc_binary_test.png",
# ((acc > 0).float().cpu().numpy() * 255).astype(np.uint8),
# )
# imageio.imwrite(
# "rgb_test.png",
# (rgb.cpu().numpy() * 255).astype(np.uint8),
# )
# break
psnr_avg = sum(psnrs) / len(psnrs)
print(f"evaluation: psnr_avg={psnr_avg}")
train_dataset.training = True
# rendering
rgb, acc, depth, _ = render_image(
radiance_field,
occupancy_grid,
rays,
scene_aabb,
# rendering options
near_plane=None,
far_plane=None,
render_step_size=render_step_size,
render_bkgd=render_bkgd,
cone_angle=args.cone_angle,
# test options
test_chunk_size=args.test_chunk_size,
)
mse = F.mse_loss(rgb, pixels)
psnr = -10.0 * torch.log(mse) / np.log(10.0)
psnrs.append(psnr.item())
# imageio.imwrite(
# "acc_binary_test.png",
# ((acc > 0).float().cpu().numpy() * 255).astype(np.uint8),
# )
# imageio.imwrite(
# "rgb_test.png",
# (rgb.cpu().numpy() * 255).astype(np.uint8),
# )
# break
psnr_avg = sum(psnrs) / len(psnrs)
print(f"evaluation: psnr_avg={psnr_avg}")
train_dataset.training = True
if step == max_steps:
print("training stops")
exit()
if step == max_steps:
print("training stops")
exit()
step += 1
step += 1
......@@ -12,297 +12,259 @@ import numpy as np
import torch
import torch.nn.functional as F
import tqdm
from radiance_fields.ngp import NGPradianceField
from utils import render_image, set_random_seed
from lpips import LPIPS
from radiance_fields.ngp import NGPRadianceField
from utils import (
MIPNERF360_UNBOUNDED_SCENES,
NERF_SYNTHETIC_SCENES,
enlarge_aabb,
render_image,
set_random_seed,
)
from nerfacc import ContractionType, OccupancyGrid
from nerfacc import OccupancyGrid
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--data_root",
type=str,
# default=str(pathlib.Path.cwd() / "data/360_v2"),
default=str(pathlib.Path.cwd() / "data/nerf_synthetic"),
help="the root dir of the dataset",
)
parser.add_argument(
"--train_split",
type=str,
default="train",
choices=["train", "trainval"],
help="which train split to use",
)
parser.add_argument(
"--scene",
type=str,
default="lego",
choices=NERF_SYNTHETIC_SCENES + MIPNERF360_UNBOUNDED_SCENES,
help="which scene to use",
)
parser.add_argument(
"--test_chunk_size",
type=int,
default=8192,
)
args = parser.parse_args()
device = "cuda:0"
set_random_seed(42)
device = "cuda:0"
set_random_seed(42)
parser = argparse.ArgumentParser()
parser.add_argument(
"--data_root",
type=str,
default=str(pathlib.Path.cwd() / "data"),
help="the root dir of the dataset",
)
parser.add_argument(
"--train_split",
type=str,
default="trainval",
choices=["train", "trainval"],
help="which train split to use",
)
parser.add_argument(
"--scene",
type=str,
default="lego",
choices=[
# nerf synthetic
"chair",
"drums",
"ficus",
"hotdog",
"lego",
"materials",
"mic",
"ship",
# mipnerf360 unbounded
"garden",
"bicycle",
"bonsai",
"counter",
"kitchen",
"room",
"stump",
],
help="which scene to use",
)
parser.add_argument(
"--aabb",
type=lambda s: [float(item) for item in s.split(",")],
default="-1.5,-1.5,-1.5,1.5,1.5,1.5",
help="delimited list input",
)
parser.add_argument(
"--test_chunk_size",
type=int,
default=8192,
)
parser.add_argument(
"--unbounded",
action="store_true",
help="whether to use unbounded rendering",
)
parser.add_argument(
"--auto_aabb",
action="store_true",
help="whether to automatically compute the aabb",
)
parser.add_argument("--cone_angle", type=float, default=0.0)
args = parser.parse_args()
render_n_samples = 1024
# setup the dataset
train_dataset_kwargs = {}
test_dataset_kwargs = {}
if args.unbounded:
from datasets.nerf_360_v2 import SubjectLoader
if args.scene in MIPNERF360_UNBOUNDED_SCENES:
from datasets.nerf_360_v2 import SubjectLoader
target_sample_batch_size = 1 << 20
train_dataset_kwargs = {"color_bkgd_aug": "random", "factor": 4}
test_dataset_kwargs = {"factor": 4}
grid_resolution = 256
else:
from datasets.nerf_synthetic import SubjectLoader
# training parameters
max_steps = 100000
init_batch_size = 1024
target_sample_batch_size = 1 << 18
weight_decay = 0.0
# scene parameters
aabb = torch.tensor([-1.0, -1.0, -1.0, 1.0, 1.0, 1.0], device=device)
near_plane = 0.02
far_plane = None
# dataset parameters
train_dataset_kwargs = {"color_bkgd_aug": "random", "factor": 4}
test_dataset_kwargs = {"factor": 4}
# model parameters
grid_resolution = 128
grid_nlvl = 4
# render parameters
render_step_size = 1e-3
alpha_thre = 1e-2
cone_angle = 0.004
target_sample_batch_size = 1 << 18
grid_resolution = 128
else:
from datasets.nerf_synthetic import SubjectLoader
train_dataset = SubjectLoader(
subject_id=args.scene,
root_fp=args.data_root,
split=args.train_split,
num_rays=target_sample_batch_size // render_n_samples,
**train_dataset_kwargs,
)
test_dataset = SubjectLoader(
subject_id=args.scene,
root_fp=args.data_root,
split="test",
num_rays=None,
**test_dataset_kwargs,
# training parameters
max_steps = 20000
init_batch_size = 1024
target_sample_batch_size = 1 << 18
weight_decay = (
1e-5 if args.scene in ["materials", "ficus", "drums"] else 1e-6
)
# scene parameters
aabb = torch.tensor([-1.5, -1.5, -1.5, 1.5, 1.5, 1.5], device=device)
near_plane = None
far_plane = None
# dataset parameters
train_dataset_kwargs = {}
test_dataset_kwargs = {}
# model parameters
grid_resolution = 128
grid_nlvl = 1
# render parameters
render_step_size = 5e-3
alpha_thre = 0.0
cone_angle = 0.0
if args.auto_aabb:
camera_locs = torch.cat(
[train_dataset.camtoworlds, test_dataset.camtoworlds]
)[:, :3, -1]
args.aabb = torch.cat(
[camera_locs.min(dim=0).values, camera_locs.max(dim=0).values]
).tolist()
print("Using auto aabb", args.aabb)
train_dataset = SubjectLoader(
subject_id=args.scene,
root_fp=args.data_root,
split=args.train_split,
num_rays=init_batch_size,
device=device,
**train_dataset_kwargs,
)
# setup the scene bounding box.
if args.unbounded:
print("Using unbounded rendering")
contraction_type = ContractionType.UN_BOUNDED_SPHERE
# contraction_type = ContractionType.UN_BOUNDED_TANH
scene_aabb = None
near_plane = 0.2
far_plane = 1e4
render_step_size = 1e-2
alpha_thre = 1e-2
else:
contraction_type = ContractionType.AABB
scene_aabb = torch.tensor(args.aabb, dtype=torch.float32, device=device)
near_plane = None
far_plane = None
render_step_size = (
(scene_aabb[3:] - scene_aabb[:3]).max()
* math.sqrt(3)
/ render_n_samples
).item()
alpha_thre = 0.0
test_dataset = SubjectLoader(
subject_id=args.scene,
root_fp=args.data_root,
split="test",
num_rays=None,
device=device,
**test_dataset_kwargs,
)
# setup the radiance field we want to train.
max_steps = 20000
grad_scaler = torch.cuda.amp.GradScaler(2**10)
radiance_field = NGPradianceField(
aabb=args.aabb,
unbounded=args.unbounded,
).to(device)
optimizer = torch.optim.Adam(
radiance_field.parameters(), lr=1e-2, eps=1e-15
)
scheduler = torch.optim.lr_scheduler.MultiStepLR(
optimizer,
milestones=[max_steps // 2, max_steps * 3 // 4, max_steps * 9 // 10],
gamma=0.33,
)
# setup scene aabb
scene_aabb = enlarge_aabb(aabb, 1 << (grid_nlvl - 1))
occupancy_grid = OccupancyGrid(
roi_aabb=args.aabb,
resolution=grid_resolution,
contraction_type=contraction_type,
).to(device)
# setup the radiance field we want to train.
grad_scaler = torch.cuda.amp.GradScaler(2**10)
radiance_field = NGPRadianceField(aabb=scene_aabb).to(device)
optimizer = torch.optim.Adam(
radiance_field.parameters(), lr=1e-2, eps=1e-15, weight_decay=weight_decay
)
scheduler = torch.optim.lr_scheduler.ChainedScheduler(
[
torch.optim.lr_scheduler.LinearLR(
optimizer, start_factor=0.01, total_iters=100
),
torch.optim.lr_scheduler.MultiStepLR(
optimizer,
milestones=[
max_steps // 2,
max_steps * 3 // 4,
max_steps * 9 // 10,
],
gamma=0.33,
),
]
)
occupancy_grid = OccupancyGrid(
roi_aabb=aabb, resolution=grid_resolution, levels=grid_nlvl
).to(device)
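# Editorial note (hedged): with levels=grid_nlvl the occupancy grid presumably
# covers a cascade of nested boxes, each level doubling the extent of the
# previous one, which is why scene_aabb above enlarges the base aabb by
# 2 ** (grid_nlvl - 1).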
# training
step = 0
tic = time.time()
for epoch in range(10000000):
for i in range(len(train_dataset)):
radiance_field.train()
data = train_dataset[i]
lpips_net = LPIPS(net="vgg").to(device)
lpips_norm_fn = lambda x: x[None, ...].permute(0, 3, 1, 2) * 2 - 1
lpips_fn = lambda x, y: lpips_net(lpips_norm_fn(x), lpips_norm_fn(y)).mean()
render_bkgd = data["color_bkgd"]
rays = data["rays"]
pixels = data["pixels"]
# training
tic = time.time()
for step in range(max_steps + 1):
radiance_field.train()
def occ_eval_fn(x):
if args.cone_angle > 0.0:
# randomly sample a camera for computing step size.
camera_ids = torch.randint(
0, len(train_dataset), (x.shape[0],), device=device
)
origins = train_dataset.camtoworlds[camera_ids, :3, -1]
t = (origins - x).norm(dim=-1, keepdim=True)
# compute actual step size used in marching, based on the distance to the camera.
step_size = torch.clamp(
t * args.cone_angle, min=render_step_size
)
# filter out the points that are not in the near far plane.
if (near_plane is not None) and (far_plane is not None):
step_size = torch.where(
(t > near_plane) & (t < far_plane),
step_size,
torch.zeros_like(step_size),
)
else:
step_size = render_step_size
# compute occupancy
density = radiance_field.query_density(x)
return density * step_size
i = torch.randint(0, len(train_dataset), (1,)).item()
data = train_dataset[i]
# update occupancy grid
occupancy_grid.every_n_step(step=step, occ_eval_fn=occ_eval_fn)
render_bkgd = data["color_bkgd"]
rays = data["rays"]
pixels = data["pixels"]
# render
rgb, acc, depth, n_rendering_samples = render_image(
radiance_field,
occupancy_grid,
rays,
scene_aabb,
# rendering options
near_plane=near_plane,
far_plane=far_plane,
render_step_size=render_step_size,
render_bkgd=render_bkgd,
cone_angle=args.cone_angle,
alpha_thre=alpha_thre,
)
if n_rendering_samples == 0:
continue
def occ_eval_fn(x):
density = radiance_field.query_density(x)
return density * render_step_size
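# Editorial note: density * step size is a first-order estimate of per-sample
# opacity, 1 - exp(-sigma * dt) ≈ sigma * dt, which the grid update below
# thresholds via occ_thre=1e-2.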
# dynamic batch size for rays to keep sample batch size constant.
num_rays = len(pixels)
num_rays = int(
num_rays
* (target_sample_batch_size / float(n_rendering_samples))
)
train_dataset.update_num_rays(num_rays)
alive_ray_mask = acc.squeeze(-1) > 0
# update occupancy grid
occupancy_grid.every_n_step(
step=step,
occ_eval_fn=occ_eval_fn,
occ_thre=1e-2,
)
# compute loss
loss = F.smooth_l1_loss(rgb[alive_ray_mask], pixels[alive_ray_mask])
# render
rgb, acc, depth, n_rendering_samples = render_image(
radiance_field,
occupancy_grid,
rays,
scene_aabb=scene_aabb,
# rendering options
near_plane=near_plane,
render_step_size=render_step_size,
render_bkgd=render_bkgd,
cone_angle=cone_angle,
alpha_thre=alpha_thre,
)
if n_rendering_samples == 0:
continue
optimizer.zero_grad()
# do not unscale it because we are using Adam.
grad_scaler.scale(loss).backward()
optimizer.step()
scheduler.step()
if target_sample_batch_size > 0:
# dynamic batch size for rays to keep sample batch size constant.
num_rays = len(pixels)
num_rays = int(
num_rays * (target_sample_batch_size / float(n_rendering_samples))
)
train_dataset.update_num_rays(num_rays)
if step % 10000 == 0:
elapsed_time = time.time() - tic
loss = F.mse_loss(rgb[alive_ray_mask], pixels[alive_ray_mask])
print(
f"elapsed_time={elapsed_time:.2f}s | step={step} | "
f"loss={loss:.5f} | "
f"alive_ray_mask={alive_ray_mask.long().sum():d} | "
f"n_rendering_samples={n_rendering_samples:d} | num_rays={len(pixels):d} |"
)
# compute loss
loss = F.smooth_l1_loss(rgb, pixels)
if step > 0 and step % max_steps == 0:
# evaluation
radiance_field.eval()
optimizer.zero_grad()
# do not unscale it because we are using Adam.
grad_scaler.scale(loss).backward()
optimizer.step()
scheduler.step()
psnrs = []
with torch.no_grad():
for i in tqdm.tqdm(range(len(test_dataset))):
data = test_dataset[i]
render_bkgd = data["color_bkgd"]
rays = data["rays"]
pixels = data["pixels"]
if step % 5000 == 0:
elapsed_time = time.time() - tic
loss = F.mse_loss(rgb, pixels)
psnr = -10.0 * torch.log(loss) / np.log(10.0)
print(
f"elapsed_time={elapsed_time:.2f}s | step={step} | "
f"loss={loss:.5f} | psnr={psnr:.2f} | "
f"n_rendering_samples={n_rendering_samples:d} | num_rays={len(pixels):d} | "
f"max_depth={depth.max():.3f} | "
)
# rendering
rgb, acc, depth, _ = render_image(
radiance_field,
occupancy_grid,
rays,
scene_aabb,
# rendering options
near_plane=near_plane,
far_plane=far_plane,
render_step_size=render_step_size,
render_bkgd=render_bkgd,
cone_angle=args.cone_angle,
alpha_thre=alpha_thre,
# test options
test_chunk_size=args.test_chunk_size,
)
mse = F.mse_loss(rgb, pixels)
psnr = -10.0 * torch.log(mse) / np.log(10.0)
psnrs.append(psnr.item())
# imageio.imwrite(
# "acc_binary_test.png",
# ((acc > 0).float().cpu().numpy() * 255).astype(np.uint8),
# )
# imageio.imwrite(
# "rgb_test.png",
# (rgb.cpu().numpy() * 255).astype(np.uint8),
# )
# break
psnr_avg = sum(psnrs) / len(psnrs)
print(f"evaluation: psnr_avg={psnr_avg}")
train_dataset.training = True
if step > 0 and step % max_steps == 0:
# evaluation
radiance_field.eval()
if step == max_steps:
print("training stops")
exit()
psnrs = []
lpips = []
with torch.no_grad():
for i in tqdm.tqdm(range(len(test_dataset))):
data = test_dataset[i]
render_bkgd = data["color_bkgd"]
rays = data["rays"]
pixels = data["pixels"]
step += 1
# rendering
rgb, acc, depth, _ = render_image(
radiance_field,
occupancy_grid,
rays,
scene_aabb=scene_aabb,
# rendering options
near_plane=near_plane,
render_step_size=render_step_size,
render_bkgd=render_bkgd,
cone_angle=cone_angle,
alpha_thre=alpha_thre,
# test options
test_chunk_size=args.test_chunk_size,
)
mse = F.mse_loss(rgb, pixels)
psnr = -10.0 * torch.log(mse) / np.log(10.0)
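# Editorial note: equivalent to PSNR = 10 * log10(1 / mse), i.e. peak
# signal-to-noise ratio in dB for pixel values in [0, 1].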
psnrs.append(psnr.item())
lpips.append(lpips_fn(rgb, pixels).item())
# if i == 0:
# imageio.imwrite(
# "rgb_test.png",
# (rgb.cpu().numpy() * 255).astype(np.uint8),
# )
# imageio.imwrite(
# "rgb_error.png",
# (
# (rgb - pixels).norm(dim=-1).cpu().numpy() * 255
# ).astype(np.uint8),
# )
psnr_avg = sum(psnrs) / len(psnrs)
lpips_avg = sum(lpips) / len(lpips)
print(f"evaluation: psnr_avg={psnr_avg}, lpips_avg={lpips_avg}")
"""
Copyright (c) 2022 Ruilong Li, UC Berkeley.
"""
import argparse
import itertools
import pathlib
import time
import imageio
import numpy as np
import torch
import torch.nn.functional as F
import tqdm
from lpips import LPIPS
from radiance_fields.ngp import NGPDensityField, NGPRadianceField
from utils import (
MIPNERF360_UNBOUNDED_SCENES,
NERF_SYNTHETIC_SCENES,
render_image_proposal,
set_random_seed,
)
from nerfacc.proposal import (
compute_prop_loss,
get_proposal_annealing_fn,
get_proposal_requires_grad_fn,
)
parser = argparse.ArgumentParser()
parser.add_argument(
"--data_root",
type=str,
# default=str(pathlib.Path.cwd() / "data/360_v2"),
default=str(pathlib.Path.cwd() / "data/nerf_synthetic"),
help="the root dir of the dataset",
)
parser.add_argument(
"--train_split",
type=str,
default="train",
choices=["train", "trainval"],
help="which train split to use",
)
parser.add_argument(
"--scene",
type=str,
default="lego",
choices=NERF_SYNTHETIC_SCENES + MIPNERF360_UNBOUNDED_SCENES,
help="which scene to use",
)
parser.add_argument(
"--test_chunk_size",
type=int,
default=8192,
)
args = parser.parse_args()
device = "cuda:0"
set_random_seed(42)
if args.scene in MIPNERF360_UNBOUNDED_SCENES:
from datasets.nerf_360_v2 import SubjectLoader
# training parameters
max_steps = 100000
init_batch_size = 4096
weight_decay = 0.0
# scene parameters
unbounded = True
aabb = torch.tensor([-1.0, -1.0, -1.0, 1.0, 1.0, 1.0], device=device)
near_plane = 0.2 # TODO: Try 0.02
far_plane = 1e3
# dataset parameters
train_dataset_kwargs = {"color_bkgd_aug": "random", "factor": 4}
test_dataset_kwargs = {"factor": 4}
# model parameters
proposal_networks = [
NGPDensityField(
aabb=aabb,
unbounded=unbounded,
n_levels=5,
max_resolution=128,
).to(device),
NGPDensityField(
aabb=aabb,
unbounded=unbounded,
n_levels=5,
max_resolution=256,
).to(device),
]
# render parameters
num_samples = 48
num_samples_per_prop = [256, 96]
sampling_type = "lindisp"
opaque_bkgd = True
else:
from datasets.nerf_synthetic import SubjectLoader
# training parameters
max_steps = 20000
init_batch_size = 4096
weight_decay = (
1e-5 if args.scene in ["materials", "ficus", "drums"] else 1e-6
)
# scene parameters
unbounded = False
aabb = torch.tensor([-1.5, -1.5, -1.5, 1.5, 1.5, 1.5], device=device)
near_plane = 2.0
far_plane = 6.0
# dataset parameters
train_dataset_kwargs = {}
test_dataset_kwargs = {}
# model parameters
proposal_networks = [
NGPDensityField(
aabb=aabb,
unbounded=unbounded,
n_levels=5,
max_resolution=128,
).to(device),
]
# render parameters
num_samples = 64
num_samples_per_prop = [128]
sampling_type = "uniform"
opaque_bkgd = False
train_dataset = SubjectLoader(
subject_id=args.scene,
root_fp=args.data_root,
split=args.train_split,
num_rays=init_batch_size,
device=device,
**train_dataset_kwargs,
)
test_dataset = SubjectLoader(
subject_id=args.scene,
root_fp=args.data_root,
split="test",
num_rays=None,
device=device,
**test_dataset_kwargs,
)
# setup the radiance field we want to train.
grad_scaler = torch.cuda.amp.GradScaler(2**10)
radiance_field = NGPRadianceField(aabb=aabb, unbounded=unbounded).to(device)
optimizer = torch.optim.Adam(
itertools.chain(
radiance_field.parameters(),
*[p.parameters() for p in proposal_networks],
),
lr=1e-2,
eps=1e-15,
weight_decay=weight_decay,
)
scheduler = torch.optim.lr_scheduler.ChainedScheduler(
[
torch.optim.lr_scheduler.LinearLR(
optimizer, start_factor=0.01, total_iters=100
),
torch.optim.lr_scheduler.MultiStepLR(
optimizer,
milestones=[
max_steps // 2,
max_steps * 3 // 4,
max_steps * 9 // 10,
],
gamma=0.33,
),
]
)
proposal_requires_grad_fn = get_proposal_requires_grad_fn()
proposal_annealing_fn = get_proposal_annealing_fn()
lpips_net = LPIPS(net="vgg").to(device)
lpips_norm_fn = lambda x: x[None, ...].permute(0, 3, 1, 2) * 2 - 1
lpips_fn = lambda x, y: lpips_net(lpips_norm_fn(x), lpips_norm_fn(y)).mean()
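# Editorial note: the helpers above convert HWC images in [0, 1] into the NCHW,
# [-1, 1] layout expected by the LPIPS(net="vgg") metric.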
# training
tic = time.time()
for step in range(max_steps + 1):
radiance_field.train()
for p in proposal_networks:
p.train()
i = torch.randint(0, len(train_dataset), (1,)).item()
data = train_dataset[i]
render_bkgd = data["color_bkgd"]
rays = data["rays"]
pixels = data["pixels"]
# render
(
rgb,
acc,
depth,
weights_per_level,
s_vals_per_level,
) = render_image_proposal(
radiance_field,
proposal_networks,
rays,
scene_aabb=None,
# rendering options
num_samples=num_samples,
num_samples_per_prop=num_samples_per_prop,
near_plane=near_plane,
far_plane=far_plane,
sampling_type=sampling_type,
opaque_bkgd=opaque_bkgd,
render_bkgd=render_bkgd,
# train options
proposal_requires_grad=proposal_requires_grad_fn(step),
proposal_annealing=proposal_annealing_fn(step),
)
# compute loss
loss = F.smooth_l1_loss(rgb, pixels)
loss_prop = compute_prop_loss(s_vals_per_level, weights_per_level)
loss = loss + loss_prop
optimizer.zero_grad()
# do not unscale it because we are using Adam.
grad_scaler.scale(loss).backward()
optimizer.step()
scheduler.step()
if step % 5000 == 0:
elapsed_time = time.time() - tic
loss = F.mse_loss(rgb, pixels)
psnr = -10.0 * torch.log(loss) / np.log(10.0)
print(
f"elapsed_time={elapsed_time:.2f}s | step={step} | "
f"loss={loss:.5f} | psnr={psnr:.2f} | "
f"num_rays={len(pixels):d} | "
f"max_depth={depth.max():.3f} | "
)
if step > 0 and step % max_steps == 0:
# evaluation
radiance_field.eval()
for p in proposal_networks:
p.eval()
psnrs = []
lpips = []
with torch.no_grad():
for i in tqdm.tqdm(range(len(test_dataset))):
data = test_dataset[i]
render_bkgd = data["color_bkgd"]
rays = data["rays"]
pixels = data["pixels"]
# rendering
rgb, acc, depth, _, _, = render_image_proposal(
radiance_field,
proposal_networks,
rays,
scene_aabb=None,
# rendering options
num_samples=num_samples,
num_samples_per_prop=num_samples_per_prop,
near_plane=near_plane,
far_plane=far_plane,
sampling_type=sampling_type,
opaque_bkgd=opaque_bkgd,
render_bkgd=render_bkgd,
proposal_annealing=proposal_annealing_fn(step),
# test options
test_chunk_size=args.test_chunk_size,
)
mse = F.mse_loss(rgb, pixels)
psnr = -10.0 * torch.log(mse) / np.log(10.0)
psnrs.append(psnr.item())
lpips.append(lpips_fn(rgb, pixels).item())
# if i == 0:
# imageio.imwrite(
# "rgb_test.png",
# (rgb.cpu().numpy() * 255).astype(np.uint8),
# )
# imageio.imwrite(
# "rgb_error.png",
# (
# (rgb - pixels).norm(dim=-1).cpu().numpy() * 255
# ).astype(np.uint8),
# )
psnr_avg = sum(psnrs) / len(psnrs)
lpips_avg = sum(lpips) / len(lpips)
print(f"evaluation: psnr_avg={psnr_avg}, lpips_avg={lpips_avg}")
......@@ -3,13 +3,35 @@ Copyright (c) 2022 Ruilong Li, UC Berkeley.
"""
import random
from typing import Optional
from typing import Literal, Optional, Sequence
import numpy as np
import torch
from datasets.utils import Rays, namedtuple_map
from torch.utils.data._utils.collate import collate, default_collate_fn_map
from nerfacc import OccupancyGrid, ray_marching, rendering
from nerfacc.proposal import rendering as rendering_proposal
NERF_SYNTHETIC_SCENES = [
"chair",
"drums",
"ficus",
"hotdog",
"lego",
"materials",
"mic",
"ship",
]
MIPNERF360_UNBOUNDED_SCENES = [
"garden",
"bicycle",
"bonsai",
"counter",
"kitchen",
"room",
"stump",
]
def set_random_seed(seed):
......@@ -18,6 +40,12 @@ def set_random_seed(seed):
torch.manual_seed(seed)
def enlarge_aabb(aabb, factor: float) -> torch.Tensor:
center = (aabb[:3] + aabb[3:]) / 2
extent = (aabb[3:] - aabb[:3]) / 2
return torch.cat([center - extent * factor, center + extent * factor])
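# Editorial example: enlarge_aabb(torch.tensor([-1.0, -1.0, -1.0, 1.0, 1.0, 1.0]), 2.0)
# returns tensor([-2., -2., -2., 2., 2., 2.]); the box grows symmetrically around
# its center, matching the factor 1 << (grid_nlvl - 1) used in the occupancy-grid
# training script.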
def render_image(
# scene
radiance_field: torch.nn.Module,
......@@ -116,3 +144,99 @@ def render_image(
depths.view((*rays_shape[:-1], -1)),
sum(n_rendering_samples),
)
def render_image_proposal(
# scene
radiance_field: torch.nn.Module,
proposal_networks: Sequence[torch.nn.Module],
rays: Rays,
scene_aabb: torch.Tensor,
# rendering options
num_samples: int,
num_samples_per_prop: Sequence[int],
near_plane: Optional[float] = None,
far_plane: Optional[float] = None,
sampling_type: Literal["uniform", "lindisp"] = "lindisp",
opaque_bkgd: bool = True,
render_bkgd: Optional[torch.Tensor] = None,
# train options
proposal_requires_grad: bool = False,
proposal_annealing: float = 1.0,
# test options
test_chunk_size: int = 8192,
):
"""Render the pixels of an image."""
rays_shape = rays.origins.shape
if len(rays_shape) == 3:
height, width, _ = rays_shape
num_rays = height * width
rays = namedtuple_map(
lambda r: r.reshape([num_rays] + list(r.shape[2:])), rays
)
else:
num_rays, _ = rays_shape
def prop_sigma_fn(t_starts, t_ends, proposal_network):
t_origins = chunk_rays.origins[..., None, :]
t_dirs = chunk_rays.viewdirs[..., None, :]
positions = t_origins + t_dirs * (t_starts + t_ends) / 2.0
return proposal_network(positions)
def rgb_sigma_fn(t_starts, t_ends):
t_origins = chunk_rays.origins[..., None, :]
t_dirs = chunk_rays.viewdirs[..., None, :].repeat_interleave(
t_starts.shape[-2], dim=-2
)
positions = t_origins + t_dirs * (t_starts + t_ends) / 2.0
return radiance_field(positions, t_dirs)
results = []
chunk = (
torch.iinfo(torch.int32).max
if radiance_field.training
else test_chunk_size
)
for i in range(0, num_rays, chunk):
chunk_rays = namedtuple_map(lambda r: r[i : i + chunk], rays)
(
rgb,
opacity,
depth,
(weights_per_level, s_vals_per_level),
) = rendering_proposal(
rgb_sigma_fn=rgb_sigma_fn,
num_samples=num_samples,
prop_sigma_fns=[
# bind p at definition time; a plain closure would make every
# entry query only the last proposal network in the list
lambda *args, p=p: prop_sigma_fn(*args, p)
for p in proposal_networks
],
num_samples_per_prop=num_samples_per_prop,
rays_o=chunk_rays.origins,
rays_d=chunk_rays.viewdirs,
scene_aabb=scene_aabb,
near_plane=near_plane,
far_plane=far_plane,
stratified=radiance_field.training,
sampling_type=sampling_type,
opaque_bkgd=opaque_bkgd,
render_bkgd=render_bkgd,
proposal_requires_grad=proposal_requires_grad,
proposal_annealing=proposal_annealing,
)
chunk_results = [rgb, opacity, depth]
results.append(chunk_results)
colors, opacities, depths = collate(
results,
collate_fn_map={
**default_collate_fn_map,
torch.Tensor: lambda x, **_: torch.cat(x, 0),
},
)
return (
colors.view((*rays_shape[:-1], -1)),
opacities.view((*rays_shape[:-1], -1)),
depths.view((*rays_shape[:-1], -1)),
weights_per_level,
s_vals_per_level,
)
"""
Copyright (c) 2022 Ruilong Li, UC Berkeley.
"""
import warnings
from .cdf import ray_resampling
from .contraction import ContractionType, contract, contract_inv
from .grid import Grid, OccupancyGrid, query_grid
from .intersection import ray_aabb_intersect
from .losses import distortion as loss_distortion
from .pack import pack_data, pack_info, unpack_data, unpack_info
from .ray_marching import ray_marching
from .version import __version__
......@@ -21,39 +17,30 @@ from .vol_rendering import (
rendering,
)
# About to be deprecated
def unpack_to_ray_indices(*args, **kwargs):
warnings.warn(
"`unpack_to_ray_indices` will be deprecated. Please use `unpack_info` instead.",
DeprecationWarning,
stacklevel=2,
)
return unpack_info(*args, **kwargs)
__all__ = [
"__version__",
# occ grid
"Grid",
"OccupancyGrid",
"query_grid",
"ContractionType",
# contraction
"contract",
"contract_inv",
# marching
"ray_aabb_intersect",
"ray_marching",
# rendering
"accumulate_along_rays",
"render_visibility",
"render_weight_from_alpha",
"render_weight_from_density",
"render_transmittance_from_density",
"render_transmittance_from_alpha",
"rendering",
# pack
"pack_data",
"unpack_data",
"unpack_info",
"pack_info",
"ray_resampling",
"loss_distortion",
"unpack_to_ray_indices",
"render_transmittance_from_density",
"render_transmittance_from_alpha",
]
"""
Copyright (c) 2022 Ruilong Li, UC Berkeley.
"""
from typing import Tuple
from torch import Tensor
import nerfacc.cuda as _C
def ray_resampling(
packed_info: Tensor,
t_starts: Tensor,
t_ends: Tensor,
weights: Tensor,
n_samples: int,
) -> Tuple[Tensor, Tensor, Tensor]:
"""Resample a set of rays based on the CDF of the weights.
Args:
packed_info (Tensor): Stores information on which samples belong to the same ray. \
See :func:`nerfacc.ray_marching` for details. Tensor with shape (n_rays, 2).
t_starts: Where the frustum-shaped sample starts along a ray. Tensor with \
shape (all_samples, 1).
t_ends: Where the frustum-shaped sample ends along a ray. Tensor with \
shape (all_samples, 1).
weights: Volumetric rendering weights for those samples. Tensor with shape \
(all_samples,).
n_samples (int): Number of samples per ray to resample.
Returns:
Resampled packed_info (n_rays, 2), plus resampled t_starts and t_ends, each with \
shape (n_resampled, 1), where n_resampled is `n_samples` times the number of rays \
that have at least one input sample.
"""
(
resampled_packed_info,
resampled_t_starts,
resampled_t_ends,
) = _C.ray_resampling(
packed_info.contiguous(),
t_starts.contiguous(),
t_ends.contiguous(),
weights.contiguous(),
n_samples,
)
return resampled_packed_info, resampled_t_starts, resampled_t_ends
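For intuition, the per-ray computation done by the CUDA kernel below (cdf_resampling_kernel) can be sketched in plain PyTorch. resample_one_ray is a hypothetical helper written only to illustrate the algorithm (weight padding, per-ray CDF, uniform quantiles at bin centers, linear inversion); it is not the code path the library executes.

import torch


def resample_one_ray(t_starts, t_ends, weights, n_samples, eps=1e-5):
    # Normalize this ray's weights into a PDF, padding them so the total
    # mass is at least `eps` (mirrors the kernel's padding step).
    w = weights.flatten()
    padding = max(eps - w.sum().item(), 0.0)
    w = w + padding / w.numel()
    cdf = torch.cat([torch.zeros(1), torch.cumsum(w / w.sum(), dim=0)])

    # Place n_samples + 1 uniform quantiles at bin centers; adjacent
    # quantiles become the start/end of one resampled interval.
    num_bins = n_samples + 1
    u = (torch.arange(num_bins) + 0.5) / num_bins

    # Invert the CDF: locate the input interval each quantile falls into
    # and interpolate linearly between that interval's start and end.
    idx = torch.searchsorted(cdf, u, right=True).clamp(1, len(cdf) - 1)
    lo, hi = cdf[idx - 1], cdf[idx]
    seg_start = t_starts.flatten()[idx - 1]
    seg_end = t_ends.flatten()[idx - 1]
    t = seg_start + (u - lo) / (hi - lo).clamp_min(1e-12) * (seg_end - seg_start)
    # Adjacent quantile positions become the new (t_starts, t_ends) bins.
    return t[:-1, None], t[1:, None]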
......@@ -23,7 +23,6 @@ grid_query = _make_lazy_cuda_func("grid_query")
ray_aabb_intersect = _make_lazy_cuda_func("ray_aabb_intersect")
ray_marching = _make_lazy_cuda_func("ray_marching")
ray_resampling = _make_lazy_cuda_func("ray_resampling")
is_cub_available = _make_lazy_cuda_func("is_cub_available")
transmittance_from_sigma_forward_cub = _make_lazy_cuda_func(
......@@ -68,3 +67,6 @@ weight_from_alpha_backward_naive = _make_lazy_cuda_func(
unpack_data = _make_lazy_cuda_func("unpack_data")
unpack_info = _make_lazy_cuda_func("unpack_info")
unpack_info_to_mask = _make_lazy_cuda_func("unpack_info_to_mask")
pdf_readout = _make_lazy_cuda_func("pdf_readout")
pdf_sampling = _make_lazy_cuda_func("pdf_sampling")
/*
* Copyright (c) 2022 Ruilong Li, UC Berkeley.
*/
#include "include/helpers_cuda.h"
template <typename scalar_t>
__global__ void cdf_resampling_kernel(
const uint32_t n_rays,
const int *packed_info, // input ray & point indices.
const scalar_t *starts, // input start t
const scalar_t *ends, // input end t
const scalar_t *weights, // transmittance weights
const int *resample_packed_info,
scalar_t *resample_starts,
scalar_t *resample_ends)
{
CUDA_GET_THREAD_ID(i, n_rays);
// locate
const int base = packed_info[i * 2 + 0]; // point idx start.
const int steps = packed_info[i * 2 + 1]; // point idx shift.
const int resample_base = resample_packed_info[i * 2 + 0];
const int resample_steps = resample_packed_info[i * 2 + 1];
if (steps == 0)
return;
starts += base;
ends += base;
weights += base;
resample_starts += resample_base;
resample_ends += resample_base;
// normalize weights **per ray**
scalar_t weights_sum = 0.0f;
for (int j = 0; j < steps; j++)
weights_sum += weights[j];
scalar_t padding = fmaxf(1e-5f - weights_sum, 0.0f);
scalar_t padding_step = padding / steps;
weights_sum += padding;
int num_bins = resample_steps + 1;
scalar_t cdf_step_size = (1.0f - 1.0 / num_bins) / resample_steps;
int idx = 0, j = 0;
scalar_t cdf_prev = 0.0f, cdf_next = (weights[idx] + padding_step) / weights_sum;
scalar_t cdf_u = 1.0 / (2 * num_bins);
while (j < num_bins)
{
if (cdf_u < cdf_next)
{
// printf("cdf_u: %f, cdf_next: %f\n", cdf_u, cdf_next);
// resample in this interval
scalar_t scaling = (ends[idx] - starts[idx]) / (cdf_next - cdf_prev);
scalar_t t = (cdf_u - cdf_prev) * scaling + starts[idx];
if (j < num_bins - 1)
resample_starts[j] = t;
if (j > 0)
resample_ends[j - 1] = t;
// going further to next resample
cdf_u += cdf_step_size;
j += 1;
}
else
{
// go to the next interval; clamp the index so floating-point
// round-off cannot step past the last valid interval of this ray
idx = min(idx + 1, steps - 1);
cdf_prev = cdf_next;
cdf_next += (weights[idx] + padding_step) / weights_sum;
}
}
if (j != num_bins)
{
printf("Error: %d %d %f\n", j, num_bins, weights_sum);
}
return;
}
// template <typename scalar_t>
// __global__ void cdf_resampling_kernel(
// const uint32_t n_rays,
// const int *packed_info, // input ray & point indices.
// const scalar_t *starts, // input start t
// const scalar_t *ends, // input end t
// const scalar_t *weights, // transmittance weights
// const int *resample_packed_info,
// scalar_t *resample_starts,
// scalar_t *resample_ends)
// {
// CUDA_GET_THREAD_ID(i, n_rays);
// // locate
// const int base = packed_info[i * 2 + 0]; // point idx start.
// const int steps = packed_info[i * 2 + 1]; // point idx shift.
// const int resample_base = resample_packed_info[i * 2 + 0];
// const int resample_steps = resample_packed_info[i * 2 + 1];
// if (steps == 0)
// return;
// starts += base;
// ends += base;
// weights += base;
// resample_starts += resample_base;
// resample_ends += resample_base;
// scalar_t cdf_step_size = 1.0f / resample_steps;
// // normalize weights **per ray**
// scalar_t weights_sum = 0.0f;
// for (int j = 0; j < steps; j++)
// weights_sum += weights[j];
// scalar_t padding = fmaxf(1e-5f - weights_sum, 0.0f);
// scalar_t padding_step = padding / steps;
// weights_sum += padding;
// int idx = 0, j = 0;
// scalar_t cdf_prev = 0.0f, cdf_next = (weights[idx] + padding_step) / weights_sum;
// scalar_t cdf_u = 0.5f * cdf_step_size;
// while (cdf_u < 1.0f)
// {
// if (cdf_u < cdf_next)
// {
// // resample in this interval
// scalar_t scaling = (ends[idx] - starts[idx]) / (cdf_next - cdf_prev);
// scalar_t resample_mid = (cdf_u - cdf_prev) * scaling + starts[idx];
// scalar_t resample_half_size = cdf_step_size * scaling * 0.5;
// resample_starts[j] = fmaxf(resample_mid - resample_half_size, starts[idx]);
// resample_ends[j] = fminf(resample_mid + resample_half_size, ends[idx]);
// // going further to next resample
// cdf_u += cdf_step_size;
// j += 1;
// }
// else
// {
// // go to next interval
// idx += 1;
// if (idx == steps)
// break;
// cdf_prev = cdf_next;
// cdf_next += (weights[idx] + padding_step) / weights_sum;
// }
// }
// if (j != resample_steps)
// {
// printf("Error: %d %d %f\n", j, resample_steps, weights_sum);
// }
// return;
// }
std::vector<torch::Tensor> ray_resampling(
torch::Tensor packed_info,
torch::Tensor starts,
torch::Tensor ends,
torch::Tensor weights,
const int steps)
{
DEVICE_GUARD(packed_info);
CHECK_INPUT(packed_info);
CHECK_INPUT(starts);
CHECK_INPUT(ends);
CHECK_INPUT(weights);
TORCH_CHECK(packed_info.ndimension() == 2 && packed_info.size(1) == 2);
TORCH_CHECK(starts.ndimension() == 2 && starts.size(1) == 1);
TORCH_CHECK(ends.ndimension() == 2 && ends.size(1) == 1);
TORCH_CHECK(weights.ndimension() == 1);
const uint32_t n_rays = packed_info.size(0);
const uint32_t n_samples = weights.size(0);
const int threads = 256;
const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads);
torch::Tensor num_steps = torch::split(packed_info, 1, 1)[1];
torch::Tensor resample_num_steps = (num_steps > 0).to(num_steps.options()) * steps;
torch::Tensor resample_cum_steps = resample_num_steps.cumsum(0, torch::kInt32);
torch::Tensor resample_packed_info = torch::cat(
{resample_cum_steps - resample_num_steps, resample_num_steps}, 1);
int total_steps = resample_cum_steps[resample_cum_steps.size(0) - 1].item<int>();
torch::Tensor resample_starts = torch::zeros({total_steps, 1}, starts.options());
torch::Tensor resample_ends = torch::zeros({total_steps, 1}, ends.options());
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
weights.scalar_type(),
"ray_resampling",
([&]
{ cdf_resampling_kernel<scalar_t><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
n_rays,
// inputs
packed_info.data_ptr<int>(),
starts.data_ptr<scalar_t>(),
ends.data_ptr<scalar_t>(),
weights.data_ptr<scalar_t>(),
resample_packed_info.data_ptr<int>(),
// outputs
resample_starts.data_ptr<scalar_t>(),
resample_ends.data_ptr<scalar_t>()); }));
return {resample_packed_info, resample_starts, resample_ends};
}
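The packed-info bookkeeping in the wrapper above (resample_num_steps / resample_cum_steps / resample_packed_info) has a compact PyTorch equivalent. The snippet below is a toy illustration of that data layout, not library code:

import torch

num_steps = torch.tensor([[3], [0], [2]], dtype=torch.int32)  # input samples per ray
steps = 4  # n_samples to draw per ray
resample_num_steps = (num_steps > 0).to(num_steps.dtype) * steps
resample_cum_steps = resample_num_steps.cumsum(0).to(torch.int32)
resample_packed_info = torch.cat(
    [resample_cum_steps - resample_num_steps, resample_num_steps], dim=1
)
# -> [[0, 4], [4, 0], [4, 4]]: per-ray [start index, count] into the flat outputs,
#    so rays with zero input samples get zero resampled samples.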