"vscode:/vscode.git/clone" did not exist on "d5097f72dbefc7656e1a19a84d4fddf03f356074"
Unverified commit b53f8a4e, authored by Ruilong Li (李瑞龙) and committed by GitHub

Support multi-res occ grid & prop net (#176)

* multres grid

* prop

* benchmark with prop and occ

* benchmark blender with weight_decay

* docs

* bump version
parent 82fd69c7
...@@ -119,3 +119,4 @@ venv.bak/
 benchmarks/
 outputs/
+data
\ No newline at end of file
nerfacc.ray\_resampling
=======================
.. currentmodule:: nerfacc
.. autofunction:: ray_resampling
\ No newline at end of file
...@@ -17,7 +17,6 @@ Utils
 render_weight_from_alpha
 render_visibility
-ray_resampling
 pack_data
 unpack_data
\ No newline at end of file
...@@ -7,7 +7,7 @@ See code `examples/train_ngp_nerf.py` at our `github repository`_ for details.
 Benchmarks
 ------------
-*updated on 2022-10-12*
+*updated on 2023-03-14*

 Here we trained a `Instant-NGP Nerf`_ model on the `Nerf-Synthetic dataset`_. We follow the same
 settings with the Instant-NGP paper, which uses train split for training and test split for
...@@ -30,11 +30,15 @@ memory footprint is about 3GB.
 +-----------------------+-------+-------+---------+-------+-------+-------+-------+-------+-------+
 |(training time)        | 309s  | 258s  | 256s    | 316s  | 292s  | 207s  | 218s  | 250s  | 263s  |
 +-----------------------+-------+-------+---------+-------+-------+-------+-------+-------+-------+
-|Ours 20k steps         | 35.50 | 36.16 | 29.14   | 35.23 | 37.15 | 31.71 | 24.88 | 29.91 | 32.46 |
+|Ours (occ) 20k steps   | 35.81 | 36.87 | 29.59   | 35.70 | 37.45 | 33.63 | 24.98 | 30.64 | 33.08 |
 +-----------------------+-------+-------+---------+-------+-------+-------+-------+-------+-------+
-|(training time)        | 287s  | 274s  | 269s    | 317s  | 269s  | 244s  | 249s  | 257s  | 271s  |
+|(training time)        | 288s  | 255s  | 247s    | 319s  | 274s  | 238s  | 247s  | 252s  | 265s  |
++-----------------------+-------+-------+---------+-------+-------+-------+-------+-------+-------+
+|Ours (prop) 20k steps  | 34.06 | 34.32 | 27.93   | 34.27 | 36.47 | 31.39 | 24.39 | 30.57 | 31.68 |
++-----------------------+-------+-------+---------+-------+-------+-------+-------+-------+-------+
+|(training time)        | 238s  | 236s  | 250s    | 235s  | 235s  | 236s  | 236s  | 236s  | 240s  |
 +-----------------------+-------+-------+---------+-------+-------+-------+-------+-------+-------+

 .. _`Instant-NGP Nerf`: https://github.com/NVlabs/instant-ngp/tree/51e4107edf48338e9ab0316d56a222e0adf87143
-.. _`github repository`: https://github.com/KAIR-BAIR/nerfacc/tree/76c0f9817da4c9c8b5ccf827eb069ee2ce854b75
+.. _`github repository`: https://github.com/KAIR-BAIR/nerfacc/
 .. _`Nerf-Synthetic dataset`: https://drive.google.com/drive/folders/1JDdLGDruGNXWnM1eqY1FNL9PlStjaKWi
...@@ -5,7 +5,7 @@ See code `examples/train_ngp_nerf.py` at our `github repository`_ for details.
 Benchmarks
 ------------
-*updated on 2022-11-07*
+*updated on 2023-03-14*

 Here we trained a `Instant-NGP Nerf`_ on the `MipNerf360`_ dataset. We used train
 split for training and test split for evaluation. Our experiments are conducted on a
...@@ -32,12 +32,19 @@ that takes from `MipNerf360`_.
 +----------------------+-------+-------+-------+-------+-------+-------+-------+-------+
 | MipNerf360 (~days)   | 26.98 | 24.37 | 33.46 | 29.55 | 32.23 | 31.63 | 26.40 | 29.23 |
 +----------------------+-------+-------+-------+-------+-------+-------+-------+-------+
-| Ours (~20 mins)      | 25.41 | 22.97 | 30.71 | 27.34 | 30.32 | 31.00 | 23.43 | 27.31 |
+| Ours (occ)           | 24.76 | 22.38 | 29.72 | 26.80 | 28.02 | 30.67 | 22.39 | 26.39 |
 +----------------------+-------+-------+-------+-------+-------+-------+-------+-------+
-| Ours (Training time) | 25min | 17min | 19min | 23min | 28min | 20min | 17min | 21min |
+| Ours (Training time) | 323s  | 302s  | 300s  | 337s  | 347s  | 320s  | 322s  | 322s  |
 +----------------------+-------+-------+-------+-------+-------+-------+-------+-------+
+| Ours (prop)          | 25.44 | 23.21 | 30.62 | 26.75 | 30.63 | 30.93 | 25.20 | 27.54 |
++----------------------+-------+-------+-------+-------+-------+-------+-------+-------+
+| Ours (Training time) | 308s  | 304s  | 308s  | 306s  | 313s  | 301s  | 287s  | 304s  |
++----------------------+-------+-------+-------+-------+-------+-------+-------+-------+
+
+Note `Ours (prop)` is basically a `Nerfacto`_ model.

 .. _`Instant-NGP Nerf`: https://arxiv.org/abs/2201.05989
 .. _`MipNerf360`: https://arxiv.org/abs/2111.12077
 .. _`Nerf++`: https://arxiv.org/abs/2010.07492
-.. _`github repository`: https://github.com/KAIR-BAIR/nerfacc/tree/76c0f9817da4c9c8b5ccf827eb069ee2ce854b75
+.. _`github repository`: https://github.com/KAIR-BAIR/nerfacc/
+.. _`Nerfacto`: https://github.com/nerfstudio-project/nerfstudio/blob/main/nerfstudio/models/nerfacto.py
\ No newline at end of file
...@@ -123,7 +123,7 @@ Links:
 .. toctree::
    :glob:
    :maxdepth: 1
-   :caption: Example Usages
+   :caption: Example Usages and Benchmarks

 examples/*
...
...@@ -86,7 +86,7 @@ class SubjectLoader(torch.utils.data.Dataset):
         near: float = None,
         far: float = None,
         batch_over_images: bool = True,
-        device: str = "cuda:0",
+        device: str = "cpu",
     ):
         super().__init__()
         assert split in self.SPLITS, "%s" % split
...
...@@ -22,7 +22,7 @@ sys.path.insert(
 from scene_manager import SceneManager

-def _load_colmap(root_fp: str, subject_id: str, split: str, factor: int = 1):
+def _load_colmap(root_fp: str, subject_id: str, factor: int = 1):
     assert factor in [1, 2, 4, 8]
     data_dir = os.path.join(root_fp, subject_id)
...@@ -134,12 +134,66 @@ def _load_colmap(root_fp: str, subject_id: str, factor: int = 1):
         "test": all_indices[all_indices % 8 == 0],
         "train": all_indices[all_indices % 8 != 0],
     }
-    indices = split_indices[split]
-    # All per-image quantities must be re-indexed using the split indices.
-    images = images[indices]
-    camtoworlds = camtoworlds[indices]
-    return images, camtoworlds, K
+    return images, camtoworlds, K, split_indices
+
+
+def similarity_from_cameras(c2w, strict_scaling):
+    """
+    reference: nerf-factory
+    Get a similarity transform to normalize dataset
+    from c2w (OpenCV convention) cameras
+    :param c2w: (N, 4)
+    :return T (4,4) , scale (float)
+    """
+    t = c2w[:, :3, 3]
+    R = c2w[:, :3, :3]
+
+    # (1) Rotate the world so that z+ is the up axis
+    # we estimate the up axis by averaging the camera up axes
+    ups = np.sum(R * np.array([0, -1.0, 0]), axis=-1)
+    world_up = np.mean(ups, axis=0)
+    world_up /= np.linalg.norm(world_up)
+
+    up_camspace = np.array([0.0, -1.0, 0.0])
+    c = (up_camspace * world_up).sum()
+    cross = np.cross(world_up, up_camspace)
+    skew = np.array(
+        [
+            [0.0, -cross[2], cross[1]],
+            [cross[2], 0.0, -cross[0]],
+            [-cross[1], cross[0], 0.0],
+        ]
+    )
+    if c > -1:
+        R_align = np.eye(3) + skew + (skew @ skew) * 1 / (1 + c)
+    else:
+        # In the unlikely case the original data has y+ up axis,
+        # rotate 180-deg about x axis
+        R_align = np.array([[-1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
+    # R_align = np.eye(3) # DEBUG
+
+    R = R_align @ R
+    fwds = np.sum(R * np.array([0, 0.0, 1.0]), axis=-1)
+    t = (R_align @ t[..., None])[..., 0]
+
+    # (2) Recenter the scene using camera center rays
+    # find the closest point to the origin for each camera's center ray
+    nearest = t + (fwds * -t).sum(-1)[:, None] * fwds
+    # median for more robustness
+    translate = -np.median(nearest, axis=0)
+    # translate = -np.mean(t, axis=0)  # DEBUG
+
+    transform = np.eye(4)
+    transform[:3, 3] = translate
+    transform[:3, :3] = R_align
+
+    # (3) Rescale the scene using camera distances
+    scale_fn = np.max if strict_scaling else np.median
+    scale = 1.0 / scale_fn(np.linalg.norm(t + translate, axis=-1))
+    return transform, scale
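As an aside (not part of the diff): the loader below consumes this helper by left-multiplying every camera-to-world matrix with the returned transform and then rescaling the translations. A minimal sketch of that usage with toy cameras; the import path is illustrative and may not match the actual repo layout::

    import numpy as np
    # similarity_from_cameras is the function defined above; this import path is
    # only illustrative and depends on how the examples/ package is laid out.
    from datasets.nerf_360_v2 import similarity_from_cameras

    # Three toy cameras with identity rotation, placed at different heights along +z.
    c2w = np.tile(np.eye(4), (3, 1, 1))
    c2w[:, 2, 3] = [3.0, 4.0, 5.0]

    T, scale = similarity_from_cameras(c2w, strict_scaling=False)

    # Same pattern as SubjectLoader.__init__ below:
    c2w = np.einsum("nij, ki -> nkj", c2w, T)  # apply the rigid alignment / recentering
    c2w[:, :3, 3] *= scale                     # rescale so camera distances are roughly unit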
 class SubjectLoader(torch.utils.data.Dataset):
...@@ -169,7 +223,7 @@ class SubjectLoader(torch.utils.data.Dataset):
         far: float = None,
         batch_over_images: bool = True,
         factor: int = 1,
-        device: str = "cuda:0",
+        device: str = "cpu",
     ):
         super().__init__()
         assert split in self.SPLITS, "%s" % split
...@@ -184,14 +238,25 @@ class SubjectLoader(torch.utils.data.Dataset):
         )
         self.color_bkgd_aug = color_bkgd_aug
         self.batch_over_images = batch_over_images
-        self.images, self.camtoworlds, self.K = _load_colmap(
-            root_fp, subject_id, split, factor
-        )
-        self.images = torch.from_numpy(self.images).to(device).to(torch.uint8)
+        self.images, self.camtoworlds, self.K, split_indices = _load_colmap(
+            root_fp, subject_id, factor
+        )
+        # normalize the scene
+        T, sscale = similarity_from_cameras(
+            self.camtoworlds, strict_scaling=False
+        )
+        self.camtoworlds = np.einsum("nij, ki -> nkj", self.camtoworlds, T)
+        self.camtoworlds[:, :3, 3] *= sscale
+        # split
+        indices = split_indices[split]
+        self.images = self.images[indices]
+        self.camtoworlds = self.camtoworlds[indices]
+        # to tensor
+        self.images = torch.from_numpy(self.images).to(torch.uint8).to(device)
         self.camtoworlds = (
-            torch.from_numpy(self.camtoworlds).to(device).to(torch.float32)
+            torch.from_numpy(self.camtoworlds).to(torch.float32).to(device)
         )
-        self.K = torch.tensor(self.K).to(device).to(torch.float32)
+        self.K = torch.tensor(self.K).to(torch.float32).to(device)
         self.height, self.width = self.images.shape[1:3]

     def __len__(self):
...@@ -275,7 +340,7 @@ class SubjectLoader(torch.utils.data.Dataset):
             value=(-1.0 if self.OPENGL_CAMERA else 1.0),
         )  # [num_rays, 3]
-        # [n_cams, height, width, 3]
+        # [num_rays, 3]
         directions = (camera_dirs[:, None, :] * c2w[:, :3, :3]).sum(dim=-1)
         origins = torch.broadcast_to(c2w[:, :3, -1], directions.shape)
         viewdirs = directions / torch.linalg.norm(
...
...@@ -79,7 +79,7 @@ class SubjectLoader(torch.utils.data.Dataset):
         near: float = None,
         far: float = None,
         batch_over_images: bool = True,
-        device: str = "cuda:0",
+        device: torch.device = torch.device("cpu"),
     ):
         super().__init__()
         assert split in self.SPLITS, "%s" % split
...@@ -110,10 +110,8 @@ class SubjectLoader(torch.utils.data.Dataset):
         self.images, self.camtoworlds, self.focal = _load_renderings(
             root_fp, subject_id, split
         )
-        self.images = torch.from_numpy(self.images).to(device).to(torch.uint8)
-        self.camtoworlds = (
-            torch.from_numpy(self.camtoworlds).to(device).to(torch.float32)
-        )
+        self.images = torch.from_numpy(self.images).to(torch.uint8)
+        self.camtoworlds = torch.from_numpy(self.camtoworlds).to(torch.float32)
         self.K = torch.tensor(
             [
                 [self.focal, 0, self.WIDTH / 2.0],
...@@ -121,8 +119,10 @@ class SubjectLoader(torch.utils.data.Dataset):
                 [0, 0, 1],
             ],
             dtype=torch.float32,
-            device=device,
         )  # (3, 3)
+        self.images = self.images.to(device)
+        self.camtoworlds = self.camtoworlds.to(device)
+        self.K = self.K.to(device)
         assert self.images.shape[1:3] == (self.HEIGHT, self.WIDTH)

     def __len__(self):
...
...@@ -4,6 +4,7 @@ Copyright (c) 2022 Ruilong Li, UC Berkeley.
 from typing import Callable, List, Union

+import numpy as np
 import torch
 from torch.autograd import Function
 from torch.cuda.amp import custom_bwd, custom_fwd
...@@ -41,13 +42,15 @@ trunc_exp = _TruncExp.apply
 def contract_to_unisphere(
     x: torch.Tensor,
     aabb: torch.Tensor,
+    ord: Union[str, int] = 2,
+    # ord: Union[float, int] = float("inf"),
     eps: float = 1e-6,
     derivative: bool = False,
 ):
     aabb_min, aabb_max = torch.split(aabb, 3, dim=-1)
     x = (x - aabb_min) / (aabb_max - aabb_min)
     x = x * 2 - 1  # aabb is at [-1, 1]
-    mag = x.norm(dim=-1, keepdim=True)
+    mag = torch.linalg.norm(x, ord=ord, dim=-1, keepdim=True)
     mask = mag.squeeze(-1) > 1
     if derivative:
...@@ -63,8 +66,8 @@ def contract_to_unisphere(
     return x
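An aside, not from the repo: the new `ord` argument only changes how "outside the unit region" is decided before contraction, using an L2 ball by default or, per the commented-out alternative, an L-infinity cube. A tiny sketch of the mask computation above::

    import torch

    # A point just outside the unit sphere but still inside the unit cube.
    x = torch.tensor([[0.8, 0.8, 0.0]])

    mag_l2 = torch.linalg.norm(x, ord=2, dim=-1, keepdim=True)              # ~1.131
    mag_inf = torch.linalg.norm(x, ord=float("inf"), dim=-1, keepdim=True)  # 0.8

    print(mag_l2.squeeze(-1) > 1)   # tensor([True])  -> contracted with ord=2
    print(mag_inf.squeeze(-1) > 1)  # tensor([False]) -> left as-is with ord=inf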
-class NGPradianceField(torch.nn.Module):
-    """Instance-NGP radiance Field"""
+class NGPRadianceField(torch.nn.Module):
+    """Instance-NGP Radiance Field"""

     def __init__(
         self,
...@@ -73,6 +76,8 @@ class NGPradianceField(torch.nn.Module):
         use_viewdirs: bool = True,
         density_activation: Callable = lambda x: trunc_exp(x - 1),
         unbounded: bool = False,
+        base_resolution: int = 16,
+        max_resolution: int = 4096,
         geo_feat_dim: int = 15,
         n_levels: int = 16,
         log2_hashmap_size: int = 19,
...@@ -85,9 +90,15 @@ class NGPradianceField(torch.nn.Module):
         self.use_viewdirs = use_viewdirs
         self.density_activation = density_activation
         self.unbounded = unbounded
+        self.base_resolution = base_resolution
+        self.max_resolution = max_resolution
         self.geo_feat_dim = geo_feat_dim
-        per_level_scale = 1.4472692012786865
+        self.n_levels = n_levels
+        self.log2_hashmap_size = log2_hashmap_size
+
+        per_level_scale = np.exp(
+            (np.log(max_resolution) - np.log(base_resolution)) / (n_levels - 1)
+        ).tolist()

         if self.use_viewdirs:
             self.direction_encoding = tcnn.Encoding(
...@@ -113,7 +124,7 @@ class NGPradianceField(torch.nn.Module):
                 "n_levels": n_levels,
                 "n_features_per_level": 2,
                 "log2_hashmap_size": log2_hashmap_size,
-                "base_resolution": 16,
+                "base_resolution": base_resolution,
                 "per_level_scale": per_level_scale,
             },
             network_config={
...@@ -138,7 +149,7 @@ class NGPradianceField(torch.nn.Module):
             network_config={
                 "otype": "FullyFusedMLP",
                 "activation": "ReLU",
-                "output_activation": "Sigmoid",
+                "output_activation": "None",
                 "n_neurons": 64,
                 "n_hidden_layers": 2,
             },
...@@ -168,19 +179,21 @@ class NGPradianceField(torch.nn.Module):
         else:
             return density

-    def _query_rgb(self, dir, embedding):
+    def _query_rgb(self, dir, embedding, apply_act: bool = True):
         # tcnn requires directions in the range [0, 1]
         if self.use_viewdirs:
             dir = (dir + 1.0) / 2.0
-            d = self.direction_encoding(dir.view(-1, dir.shape[-1]))
-            h = torch.cat([d, embedding.view(-1, self.geo_feat_dim)], dim=-1)
+            d = self.direction_encoding(dir.reshape(-1, dir.shape[-1]))
+            h = torch.cat([d, embedding.reshape(-1, self.geo_feat_dim)], dim=-1)
         else:
-            h = embedding.view(-1, self.geo_feat_dim)
+            h = embedding.reshape(-1, self.geo_feat_dim)
         rgb = (
             self.mlp_head(h)
-            .view(list(embedding.shape[:-1]) + [3])
+            .reshape(list(embedding.shape[:-1]) + [3])
             .to(embedding)
         )
+        if apply_act:
+            rgb = torch.sigmoid(rgb)
         return rgb

     def forward(
...@@ -194,4 +207,73 @@ class NGPradianceField(torch.nn.Module):
         ), f"{positions.shape} v.s. {directions.shape}"
         density, embedding = self.query_density(positions, return_feat=True)
         rgb = self._query_rgb(directions, embedding=embedding)
-        return rgb, density
+        return rgb, density  # type: ignore
class NGPDensityField(torch.nn.Module):
"""Instance-NGP Density Field used for resampling"""
def __init__(
self,
aabb: Union[torch.Tensor, List[float]],
num_dim: int = 3,
density_activation: Callable = lambda x: trunc_exp(x - 1),
unbounded: bool = False,
base_resolution: int = 16,
max_resolution: int = 128,
n_levels: int = 5,
log2_hashmap_size: int = 17,
) -> None:
super().__init__()
if not isinstance(aabb, torch.Tensor):
aabb = torch.tensor(aabb, dtype=torch.float32)
self.register_buffer("aabb", aabb)
self.num_dim = num_dim
self.density_activation = density_activation
self.unbounded = unbounded
self.base_resolution = base_resolution
self.max_resolution = max_resolution
self.n_levels = n_levels
self.log2_hashmap_size = log2_hashmap_size
per_level_scale = np.exp(
(np.log(max_resolution) - np.log(base_resolution)) / (n_levels - 1)
).tolist()
self.mlp_base = tcnn.NetworkWithInputEncoding(
n_input_dims=num_dim,
n_output_dims=1,
encoding_config={
"otype": "HashGrid",
"n_levels": n_levels,
"n_features_per_level": 2,
"log2_hashmap_size": log2_hashmap_size,
"base_resolution": base_resolution,
"per_level_scale": per_level_scale,
},
network_config={
"otype": "FullyFusedMLP",
"activation": "ReLU",
"output_activation": "None",
"n_neurons": 64,
"n_hidden_layers": 1,
},
)
def forward(self, positions: torch.Tensor):
if self.unbounded:
positions = contract_to_unisphere(positions, self.aabb)
else:
aabb_min, aabb_max = torch.split(self.aabb, self.num_dim, dim=-1)
positions = (positions - aabb_min) / (aabb_max - aabb_min)
selector = ((positions > 0.0) & (positions < 1.0)).all(dim=-1)
density_before_activation = (
self.mlp_base(positions.view(-1, self.num_dim))
.view(list(positions.shape[:-1]) + [1])
.to(positions)
)
density = (
self.density_activation(density_before_activation)
* selector[..., None]
)
return density
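For reference (not part of the diff): both classes now derive the hash-grid growth factor from `base_resolution`, `max_resolution`, and `n_levels` instead of the previously hard-coded `per_level_scale = 1.4472692012786865`. A quick sketch checking that the defaults (16 to 4096 over 16 levels) reproduce the old constant::

    import numpy as np

    base_resolution, max_resolution, n_levels = 16, 4096, 16

    # Geometric growth factor between consecutive hash-grid levels.
    per_level_scale = np.exp(
        (np.log(max_resolution) - np.log(base_resolution)) / (n_levels - 1)
    )
    print(per_level_scale)  # ~1.4472692, matching the old hard-coded value

    # Per-level resolutions: base_resolution * per_level_scale**level.
    resolutions = base_resolution * per_level_scale ** np.arange(n_levels)
    print(resolutions[0], resolutions[-1])  # 16.0 ... ~4096.0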
...@@ -4,3 +4,4 @@ imageio
 numpy
 tqdm
 scipy
+lpips
\ No newline at end of file
...@@ -16,28 +16,26 @@ from datasets.dnerf_synthetic import SubjectLoader ...@@ -16,28 +16,26 @@ from datasets.dnerf_synthetic import SubjectLoader
from radiance_fields.mlp import DNeRFRadianceField from radiance_fields.mlp import DNeRFRadianceField
from utils import render_image, set_random_seed from utils import render_image, set_random_seed
from nerfacc import ContractionType, OccupancyGrid from nerfacc import OccupancyGrid
if __name__ == "__main__": device = "cuda:0"
set_random_seed(42)
device = "cuda:0" parser = argparse.ArgumentParser()
set_random_seed(42) parser.add_argument(
parser = argparse.ArgumentParser()
parser.add_argument(
"--data_root", "--data_root",
type=str, type=str,
default=str(pathlib.Path.cwd() / "data/dnerf"), default=str(pathlib.Path.cwd() / "data/dnerf"),
help="the root dir of the dataset", help="the root dir of the dataset",
) )
parser.add_argument( parser.add_argument(
"--train_split", "--train_split",
type=str, type=str,
default="train", default="train",
choices=["train"], choices=["train"],
help="which train split to use", help="which train split to use",
) )
parser.add_argument( parser.add_argument(
"--scene", "--scene",
type=str, type=str,
default="lego", default="lego",
...@@ -53,40 +51,37 @@ if __name__ == "__main__": ...@@ -53,40 +51,37 @@ if __name__ == "__main__":
"trex", "trex",
], ],
help="which scene to use", help="which scene to use",
) )
parser.add_argument( parser.add_argument(
"--aabb", "--aabb",
type=lambda s: [float(item) for item in s.split(",")], type=lambda s: [float(item) for item in s.split(",")],
default="-1.5,-1.5,-1.5,1.5,1.5,1.5", default="-1.5,-1.5,-1.5,1.5,1.5,1.5",
help="delimited list input", help="delimited list input",
) )
parser.add_argument( parser.add_argument(
"--test_chunk_size", "--test_chunk_size",
type=int, type=int,
default=8192, default=8192,
) )
parser.add_argument("--cone_angle", type=float, default=0.0) parser.add_argument("--cone_angle", type=float, default=0.0)
args = parser.parse_args() args = parser.parse_args()
render_n_samples = 1024 render_n_samples = 1024
# setup the scene bounding box. # setup the scene bounding box.
contraction_type = ContractionType.AABB scene_aabb = torch.tensor(args.aabb, dtype=torch.float32, device=device)
scene_aabb = torch.tensor(args.aabb, dtype=torch.float32, device=device) near_plane = None
near_plane = None far_plane = None
far_plane = None render_step_size = (
render_step_size = ( (scene_aabb[3:] - scene_aabb[:3]).max() * math.sqrt(3) / render_n_samples
(scene_aabb[3:] - scene_aabb[:3]).max() ).item()
* math.sqrt(3)
/ render_n_samples # setup the radiance field we want to train.
).item() max_steps = 30000
grad_scaler = torch.cuda.amp.GradScaler(1)
# setup the radiance field we want to train. radiance_field = DNeRFRadianceField().to(device)
max_steps = 30000 optimizer = torch.optim.Adam(radiance_field.parameters(), lr=5e-4)
grad_scaler = torch.cuda.amp.GradScaler(1) scheduler = torch.optim.lr_scheduler.MultiStepLR(
radiance_field = DNeRFRadianceField().to(device)
optimizer = torch.optim.Adam(radiance_field.parameters(), lr=5e-4)
scheduler = torch.optim.lr_scheduler.MultiStepLR(
optimizer, optimizer,
milestones=[ milestones=[
max_steps // 2, max_steps // 2,
...@@ -95,35 +90,35 @@ if __name__ == "__main__": ...@@ -95,35 +90,35 @@ if __name__ == "__main__":
max_steps * 9 // 10, max_steps * 9 // 10,
], ],
gamma=0.33, gamma=0.33,
) )
# setup the dataset # setup the dataset
target_sample_batch_size = 1 << 16 target_sample_batch_size = 1 << 16
grid_resolution = 128 grid_resolution = 128
train_dataset = SubjectLoader( train_dataset = SubjectLoader(
subject_id=args.scene, subject_id=args.scene,
root_fp=args.data_root, root_fp=args.data_root,
split=args.train_split, split=args.train_split,
num_rays=target_sample_batch_size // render_n_samples, num_rays=target_sample_batch_size // render_n_samples,
) device=device,
)
test_dataset = SubjectLoader( test_dataset = SubjectLoader(
subject_id=args.scene, subject_id=args.scene,
root_fp=args.data_root, root_fp=args.data_root,
split="test", split="test",
num_rays=None, num_rays=None,
) device=device,
)
occupancy_grid = OccupancyGrid( occupancy_grid = OccupancyGrid(
roi_aabb=args.aabb, roi_aabb=args.aabb, resolution=grid_resolution
resolution=grid_resolution, ).to(device)
contraction_type=contraction_type,
).to(device)
# training # training
step = 0 step = 0
tic = time.time() tic = time.time()
for epoch in range(10000000): for epoch in range(10000000):
for i in range(len(train_dataset)): for i in range(len(train_dataset)):
radiance_field.train() radiance_field.train()
data = train_dataset[i] data = train_dataset[i]
...@@ -163,8 +158,7 @@ if __name__ == "__main__": ...@@ -163,8 +158,7 @@ if __name__ == "__main__":
# dynamic batch size for rays to keep sample batch size constant. # dynamic batch size for rays to keep sample batch size constant.
num_rays = len(pixels) num_rays = len(pixels)
num_rays = int( num_rays = int(
num_rays num_rays * (target_sample_batch_size / float(n_rendering_samples))
* (target_sample_batch_size / float(n_rendering_samples))
) )
train_dataset.update_num_rays(num_rays) train_dataset.update_num_rays(num_rays)
alive_ray_mask = acc.squeeze(-1) > 0 alive_ray_mask = acc.squeeze(-1) > 0
......
...@@ -17,26 +17,24 @@ from utils import render_image, set_random_seed ...@@ -17,26 +17,24 @@ from utils import render_image, set_random_seed
from nerfacc import ContractionType, OccupancyGrid from nerfacc import ContractionType, OccupancyGrid
if __name__ == "__main__": device = "cuda:0"
set_random_seed(42)
device = "cuda:0" parser = argparse.ArgumentParser()
set_random_seed(42) parser.add_argument(
parser = argparse.ArgumentParser()
parser.add_argument(
"--data_root", "--data_root",
type=str, type=str,
default=str(pathlib.Path.cwd() / "data/nerf_synthetic"), default=str(pathlib.Path.cwd() / "data/nerf_synthetic"),
help="the root dir of the dataset", help="the root dir of the dataset",
) )
parser.add_argument( parser.add_argument(
"--train_split", "--train_split",
type=str, type=str,
default="trainval", default="trainval",
choices=["train", "trainval"], choices=["train", "trainval"],
help="which train split to use", help="which train split to use",
) )
parser.add_argument( parser.add_argument(
"--scene", "--scene",
type=str, type=str,
default="lego", default="lego",
...@@ -54,30 +52,30 @@ if __name__ == "__main__": ...@@ -54,30 +52,30 @@ if __name__ == "__main__":
"garden", "garden",
], ],
help="which scene to use", help="which scene to use",
) )
parser.add_argument( parser.add_argument(
"--aabb", "--aabb",
type=lambda s: [float(item) for item in s.split(",")], type=lambda s: [float(item) for item in s.split(",")],
default="-1.5,-1.5,-1.5,1.5,1.5,1.5", default="-1.5,-1.5,-1.5,1.5,1.5,1.5",
help="delimited list input", help="delimited list input",
) )
parser.add_argument( parser.add_argument(
"--test_chunk_size", "--test_chunk_size",
type=int, type=int,
default=8192, default=8192,
) )
parser.add_argument( parser.add_argument(
"--unbounded", "--unbounded",
action="store_true", action="store_true",
help="whether to use unbounded rendering", help="whether to use unbounded rendering",
) )
parser.add_argument("--cone_angle", type=float, default=0.0) parser.add_argument("--cone_angle", type=float, default=0.0)
args = parser.parse_args() args = parser.parse_args()
render_n_samples = 1024 render_n_samples = 1024
# setup the scene bounding box. # setup the scene bounding box.
if args.unbounded: if args.unbounded:
print("Using unbounded rendering") print("Using unbounded rendering")
contraction_type = ContractionType.UN_BOUNDED_SPHERE contraction_type = ContractionType.UN_BOUNDED_SPHERE
# contraction_type = ContractionType.UN_BOUNDED_TANH # contraction_type = ContractionType.UN_BOUNDED_TANH
...@@ -85,7 +83,7 @@ if __name__ == "__main__": ...@@ -85,7 +83,7 @@ if __name__ == "__main__":
near_plane = 0.2 near_plane = 0.2
far_plane = 1e4 far_plane = 1e4
render_step_size = 1e-2 render_step_size = 1e-2
else: else:
contraction_type = ContractionType.AABB contraction_type = ContractionType.AABB
scene_aabb = torch.tensor(args.aabb, dtype=torch.float32, device=device) scene_aabb = torch.tensor(args.aabb, dtype=torch.float32, device=device)
near_plane = None near_plane = None
...@@ -96,12 +94,12 @@ if __name__ == "__main__": ...@@ -96,12 +94,12 @@ if __name__ == "__main__":
/ render_n_samples / render_n_samples
).item() ).item()
# setup the radiance field we want to train. # setup the radiance field we want to train.
max_steps = 50000 max_steps = 50000
grad_scaler = torch.cuda.amp.GradScaler(1) grad_scaler = torch.cuda.amp.GradScaler(1)
radiance_field = VanillaNeRFRadianceField().to(device) radiance_field = VanillaNeRFRadianceField().to(device)
optimizer = torch.optim.Adam(radiance_field.parameters(), lr=5e-4) optimizer = torch.optim.Adam(radiance_field.parameters(), lr=5e-4)
scheduler = torch.optim.lr_scheduler.MultiStepLR( scheduler = torch.optim.lr_scheduler.MultiStepLR(
optimizer, optimizer,
milestones=[ milestones=[
max_steps // 2, max_steps // 2,
...@@ -110,50 +108,50 @@ if __name__ == "__main__": ...@@ -110,50 +108,50 @@ if __name__ == "__main__":
max_steps * 9 // 10, max_steps * 9 // 10,
], ],
gamma=0.33, gamma=0.33,
) )
# setup the dataset # setup the dataset
train_dataset_kwargs = {} train_dataset_kwargs = {}
test_dataset_kwargs = {} test_dataset_kwargs = {}
if args.scene == "garden": if args.scene == "garden":
from datasets.nerf_360_v2 import SubjectLoader from datasets.nerf_360_v2 import SubjectLoader
target_sample_batch_size = 1 << 16 target_sample_batch_size = 1 << 16
train_dataset_kwargs = {"color_bkgd_aug": "random", "factor": 4} train_dataset_kwargs = {"color_bkgd_aug": "random", "factor": 4}
test_dataset_kwargs = {"factor": 4} test_dataset_kwargs = {"factor": 4}
grid_resolution = 128 grid_resolution = 128
else: else:
from datasets.nerf_synthetic import SubjectLoader from datasets.nerf_synthetic import SubjectLoader
target_sample_batch_size = 1 << 16 target_sample_batch_size = 1 << 16
grid_resolution = 128 grid_resolution = 128
train_dataset = SubjectLoader( train_dataset = SubjectLoader(
subject_id=args.scene, subject_id=args.scene,
root_fp=args.data_root, root_fp=args.data_root,
split=args.train_split, split=args.train_split,
num_rays=target_sample_batch_size // render_n_samples, num_rays=target_sample_batch_size // render_n_samples,
**train_dataset_kwargs, **train_dataset_kwargs,
) )
test_dataset = SubjectLoader( test_dataset = SubjectLoader(
subject_id=args.scene, subject_id=args.scene,
root_fp=args.data_root, root_fp=args.data_root,
split="test", split="test",
num_rays=None, num_rays=None,
**test_dataset_kwargs, **test_dataset_kwargs,
) )
occupancy_grid = OccupancyGrid( occupancy_grid = OccupancyGrid(
roi_aabb=args.aabb, roi_aabb=args.aabb,
resolution=grid_resolution, resolution=grid_resolution,
contraction_type=contraction_type, contraction_type=contraction_type,
).to(device) ).to(device)
# training # training
step = 0 step = 0
tic = time.time() tic = time.time()
for epoch in range(10000000): for epoch in range(10000000):
for i in range(len(train_dataset)): for i in range(len(train_dataset)):
radiance_field.train() radiance_field.train()
data = train_dataset[i] data = train_dataset[i]
...@@ -189,8 +187,7 @@ if __name__ == "__main__": ...@@ -189,8 +187,7 @@ if __name__ == "__main__":
# dynamic batch size for rays to keep sample batch size constant. # dynamic batch size for rays to keep sample batch size constant.
num_rays = len(pixels) num_rays = len(pixels)
num_rays = int( num_rays = int(
num_rays num_rays * (target_sample_batch_size / float(n_rendering_samples))
* (target_sample_batch_size / float(n_rendering_samples))
) )
train_dataset.update_num_rays(num_rays) train_dataset.update_num_rays(num_rays)
alive_ray_mask = acc.squeeze(-1) > 0 alive_ray_mask = acc.squeeze(-1) > 0
......
...@@ -12,172 +12,155 @@ import numpy as np ...@@ -12,172 +12,155 @@ import numpy as np
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
import tqdm import tqdm
from radiance_fields.ngp import NGPradianceField from lpips import LPIPS
from utils import render_image, set_random_seed from radiance_fields.ngp import NGPRadianceField
from utils import (
MIPNERF360_UNBOUNDED_SCENES,
NERF_SYNTHETIC_SCENES,
enlarge_aabb,
render_image,
set_random_seed,
)
from nerfacc import ContractionType, OccupancyGrid from nerfacc import OccupancyGrid
if __name__ == "__main__": parser = argparse.ArgumentParser()
parser.add_argument(
device = "cuda:0"
set_random_seed(42)
parser = argparse.ArgumentParser()
parser.add_argument(
"--data_root", "--data_root",
type=str, type=str,
default=str(pathlib.Path.cwd() / "data"), # default=str(pathlib.Path.cwd() / "data/360_v2"),
default=str(pathlib.Path.cwd() / "data/nerf_synthetic"),
help="the root dir of the dataset", help="the root dir of the dataset",
) )
parser.add_argument( parser.add_argument(
"--train_split", "--train_split",
type=str, type=str,
default="trainval", default="train",
choices=["train", "trainval"], choices=["train", "trainval"],
help="which train split to use", help="which train split to use",
) )
parser.add_argument( parser.add_argument(
"--scene", "--scene",
type=str, type=str,
default="lego", default="lego",
choices=[ choices=NERF_SYNTHETIC_SCENES + MIPNERF360_UNBOUNDED_SCENES,
# nerf synthetic
"chair",
"drums",
"ficus",
"hotdog",
"lego",
"materials",
"mic",
"ship",
# mipnerf360 unbounded
"garden",
"bicycle",
"bonsai",
"counter",
"kitchen",
"room",
"stump",
],
help="which scene to use", help="which scene to use",
) )
parser.add_argument( parser.add_argument(
"--aabb",
type=lambda s: [float(item) for item in s.split(",")],
default="-1.5,-1.5,-1.5,1.5,1.5,1.5",
help="delimited list input",
)
parser.add_argument(
"--test_chunk_size", "--test_chunk_size",
type=int, type=int,
default=8192, default=8192,
) )
parser.add_argument( args = parser.parse_args()
"--unbounded",
action="store_true",
help="whether to use unbounded rendering",
)
parser.add_argument(
"--auto_aabb",
action="store_true",
help="whether to automatically compute the aabb",
)
parser.add_argument("--cone_angle", type=float, default=0.0)
args = parser.parse_args()
render_n_samples = 1024 device = "cuda:0"
set_random_seed(42)
# setup the dataset if args.scene in MIPNERF360_UNBOUNDED_SCENES:
train_dataset_kwargs = {}
test_dataset_kwargs = {}
if args.unbounded:
from datasets.nerf_360_v2 import SubjectLoader from datasets.nerf_360_v2 import SubjectLoader
target_sample_batch_size = 1 << 20 # training parameters
max_steps = 100000
init_batch_size = 1024
target_sample_batch_size = 1 << 18
weight_decay = 0.0
# scene parameters
aabb = torch.tensor([-1.0, -1.0, -1.0, 1.0, 1.0, 1.0], device=device)
near_plane = 0.02
far_plane = None
# dataset parameters
train_dataset_kwargs = {"color_bkgd_aug": "random", "factor": 4} train_dataset_kwargs = {"color_bkgd_aug": "random", "factor": 4}
test_dataset_kwargs = {"factor": 4} test_dataset_kwargs = {"factor": 4}
grid_resolution = 256 # model parameters
else: grid_resolution = 128
grid_nlvl = 4
# render parameters
render_step_size = 1e-3
alpha_thre = 1e-2
cone_angle = 0.004
else:
from datasets.nerf_synthetic import SubjectLoader from datasets.nerf_synthetic import SubjectLoader
# training parameters
max_steps = 20000
init_batch_size = 1024
target_sample_batch_size = 1 << 18 target_sample_batch_size = 1 << 18
weight_decay = (
1e-5 if args.scene in ["materials", "ficus", "drums"] else 1e-6
)
# scene parameters
aabb = torch.tensor([-1.5, -1.5, -1.5, 1.5, 1.5, 1.5], device=device)
near_plane = None
far_plane = None
# dataset parameters
train_dataset_kwargs = {}
test_dataset_kwargs = {}
# model parameters
grid_resolution = 128 grid_resolution = 128
grid_nlvl = 1
# render parameters
render_step_size = 5e-3
alpha_thre = 0.0
cone_angle = 0.0
train_dataset = SubjectLoader( train_dataset = SubjectLoader(
subject_id=args.scene, subject_id=args.scene,
root_fp=args.data_root, root_fp=args.data_root,
split=args.train_split, split=args.train_split,
num_rays=target_sample_batch_size // render_n_samples, num_rays=init_batch_size,
device=device,
**train_dataset_kwargs, **train_dataset_kwargs,
) )
test_dataset = SubjectLoader( test_dataset = SubjectLoader(
subject_id=args.scene, subject_id=args.scene,
root_fp=args.data_root, root_fp=args.data_root,
split="test", split="test",
num_rays=None, num_rays=None,
device=device,
**test_dataset_kwargs, **test_dataset_kwargs,
) )
if args.auto_aabb: # setup scene aabb
camera_locs = torch.cat( scene_aabb = enlarge_aabb(aabb, 1 << (grid_nlvl - 1))
[train_dataset.camtoworlds, test_dataset.camtoworlds]
)[:, :3, -1]
args.aabb = torch.cat(
[camera_locs.min(dim=0).values, camera_locs.max(dim=0).values]
).tolist()
print("Using auto aabb", args.aabb)
# setup the scene bounding box.
if args.unbounded:
print("Using unbounded rendering")
contraction_type = ContractionType.UN_BOUNDED_SPHERE
# contraction_type = ContractionType.UN_BOUNDED_TANH
scene_aabb = None
near_plane = 0.2
far_plane = 1e4
render_step_size = 1e-2
alpha_thre = 1e-2
else:
contraction_type = ContractionType.AABB
scene_aabb = torch.tensor(args.aabb, dtype=torch.float32, device=device)
near_plane = None
far_plane = None
render_step_size = (
(scene_aabb[3:] - scene_aabb[:3]).max()
* math.sqrt(3)
/ render_n_samples
).item()
alpha_thre = 0.0
# setup the radiance field we want to train. # setup the radiance field we want to train.
max_steps = 20000 grad_scaler = torch.cuda.amp.GradScaler(2**10)
grad_scaler = torch.cuda.amp.GradScaler(2**10) radiance_field = NGPRadianceField(aabb=scene_aabb).to(device)
radiance_field = NGPradianceField( optimizer = torch.optim.Adam(
aabb=args.aabb, radiance_field.parameters(), lr=1e-2, eps=1e-15, weight_decay=weight_decay
unbounded=args.unbounded, )
).to(device) scheduler = torch.optim.lr_scheduler.ChainedScheduler(
optimizer = torch.optim.Adam( [
radiance_field.parameters(), lr=1e-2, eps=1e-15 torch.optim.lr_scheduler.LinearLR(
) optimizer, start_factor=0.01, total_iters=100
scheduler = torch.optim.lr_scheduler.MultiStepLR( ),
torch.optim.lr_scheduler.MultiStepLR(
optimizer, optimizer,
milestones=[max_steps // 2, max_steps * 3 // 4, max_steps * 9 // 10], milestones=[
max_steps // 2,
max_steps * 3 // 4,
max_steps * 9 // 10,
],
gamma=0.33, gamma=0.33,
) ),
]
)
occupancy_grid = OccupancyGrid(
roi_aabb=aabb, resolution=grid_resolution, levels=grid_nlvl
).to(device)
occupancy_grid = OccupancyGrid( lpips_net = LPIPS(net="vgg").to(device)
roi_aabb=args.aabb, lpips_norm_fn = lambda x: x[None, ...].permute(0, 3, 1, 2) * 2 - 1
resolution=grid_resolution, lpips_fn = lambda x, y: lpips_net(lpips_norm_fn(x), lpips_norm_fn(y)).mean()
contraction_type=contraction_type,
).to(device)
# training # training
step = 0 tic = time.time()
tic = time.time() for step in range(max_steps + 1):
for epoch in range(10000000):
for i in range(len(train_dataset)):
radiance_field.train() radiance_field.train()
i = torch.randint(0, len(train_dataset), (1,)).item()
data = train_dataset[i] data = train_dataset[i]
render_bkgd = data["color_bkgd"] render_bkgd = data["color_bkgd"]
...@@ -185,61 +168,42 @@ if __name__ == "__main__": ...@@ -185,61 +168,42 @@ if __name__ == "__main__":
pixels = data["pixels"] pixels = data["pixels"]
     def occ_eval_fn(x):
-        if args.cone_angle > 0.0:
-            # randomly sample a camera for computing step size.
-            camera_ids = torch.randint(
-                0, len(train_dataset), (x.shape[0],), device=device
-            )
-            origins = train_dataset.camtoworlds[camera_ids, :3, -1]
-            t = (origins - x).norm(dim=-1, keepdim=True)
-            # compute actual step size used in marching, based on the distance to the camera.
-            step_size = torch.clamp(
-                t * args.cone_angle, min=render_step_size
-            )
-            # filter out the points that are not in the near far plane.
-            if (near_plane is not None) and (far_plane is not None):
-                step_size = torch.where(
-                    (t > near_plane) & (t < far_plane),
-                    step_size,
-                    torch.zeros_like(step_size),
-                )
-        else:
-            step_size = render_step_size
-        # compute occupancy
         density = radiance_field.query_density(x)
-        return density * step_size
+        return density * render_step_size
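For intuition (an aside, not part of the diff): `occ_eval_fn` now returns `density * render_step_size`, i.e. the first-order approximation of the opacity `1 - exp(-density * dt)` that a single marching step would contribute, and the grid update below compares this estimate against `occ_thre=1e-2` when deciding which cells to mark empty. A small numeric sketch::

    import torch

    sigma = torch.tensor([0.5, 5.0, 500.0])  # hypothetical densities
    dt = 5e-3                                # render_step_size used for the synthetic scenes above

    alpha_exact = 1.0 - torch.exp(-sigma * dt)
    alpha_approx = sigma * dt                # what occ_eval_fn returns

    print(alpha_exact)   # tensor([0.0025, 0.0247, 0.9179])
    print(alpha_approx)  # tensor([0.0025, 0.0250, 2.5000])
    # Either estimate puts the first cell below occ_thre=1e-2 (prunable) and the
    # others above it; very large densities saturate but are kept in both cases.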
# update occupancy grid # update occupancy grid
occupancy_grid.every_n_step(step=step, occ_eval_fn=occ_eval_fn) occupancy_grid.every_n_step(
step=step,
occ_eval_fn=occ_eval_fn,
occ_thre=1e-2,
)
# render # render
rgb, acc, depth, n_rendering_samples = render_image( rgb, acc, depth, n_rendering_samples = render_image(
radiance_field, radiance_field,
occupancy_grid, occupancy_grid,
rays, rays,
scene_aabb, scene_aabb=scene_aabb,
# rendering options # rendering options
near_plane=near_plane, near_plane=near_plane,
far_plane=far_plane,
render_step_size=render_step_size, render_step_size=render_step_size,
render_bkgd=render_bkgd, render_bkgd=render_bkgd,
cone_angle=args.cone_angle, cone_angle=cone_angle,
alpha_thre=alpha_thre, alpha_thre=alpha_thre,
) )
if n_rendering_samples == 0: if n_rendering_samples == 0:
continue continue
if target_sample_batch_size > 0:
# dynamic batch size for rays to keep sample batch size constant. # dynamic batch size for rays to keep sample batch size constant.
num_rays = len(pixels) num_rays = len(pixels)
num_rays = int( num_rays = int(
num_rays num_rays * (target_sample_batch_size / float(n_rendering_samples))
* (target_sample_batch_size / float(n_rendering_samples))
) )
train_dataset.update_num_rays(num_rays) train_dataset.update_num_rays(num_rays)
alive_ray_mask = acc.squeeze(-1) > 0
# compute loss # compute loss
loss = F.smooth_l1_loss(rgb[alive_ray_mask], pixels[alive_ray_mask]) loss = F.smooth_l1_loss(rgb, pixels)
optimizer.zero_grad() optimizer.zero_grad()
# do not unscale it because we are using Adam. # do not unscale it because we are using Adam.
...@@ -247,14 +211,15 @@ if __name__ == "__main__": ...@@ -247,14 +211,15 @@ if __name__ == "__main__":
optimizer.step() optimizer.step()
scheduler.step() scheduler.step()
if step % 10000 == 0: if step % 5000 == 0:
elapsed_time = time.time() - tic elapsed_time = time.time() - tic
loss = F.mse_loss(rgb[alive_ray_mask], pixels[alive_ray_mask]) loss = F.mse_loss(rgb, pixels)
psnr = -10.0 * torch.log(loss) / np.log(10.0)
print( print(
f"elapsed_time={elapsed_time:.2f}s | step={step} | " f"elapsed_time={elapsed_time:.2f}s | step={step} | "
f"loss={loss:.5f} | " f"loss={loss:.5f} | psnr={psnr:.2f} | "
f"alive_ray_mask={alive_ray_mask.long().sum():d} | " f"n_rendering_samples={n_rendering_samples:d} | num_rays={len(pixels):d} | "
f"n_rendering_samples={n_rendering_samples:d} | num_rays={len(pixels):d} |" f"max_depth={depth.max():.3f} | "
) )
if step > 0 and step % max_steps == 0: if step > 0 and step % max_steps == 0:
...@@ -262,6 +227,7 @@ if __name__ == "__main__": ...@@ -262,6 +227,7 @@ if __name__ == "__main__":
radiance_field.eval() radiance_field.eval()
psnrs = [] psnrs = []
lpips = []
with torch.no_grad(): with torch.no_grad():
for i in tqdm.tqdm(range(len(test_dataset))): for i in tqdm.tqdm(range(len(test_dataset))):
data = test_dataset[i] data = test_dataset[i]
...@@ -274,13 +240,12 @@ if __name__ == "__main__": ...@@ -274,13 +240,12 @@ if __name__ == "__main__":
radiance_field, radiance_field,
occupancy_grid, occupancy_grid,
rays, rays,
scene_aabb, scene_aabb=scene_aabb,
# rendering options # rendering options
near_plane=near_plane, near_plane=near_plane,
far_plane=far_plane,
render_step_size=render_step_size, render_step_size=render_step_size,
render_bkgd=render_bkgd, render_bkgd=render_bkgd,
cone_angle=args.cone_angle, cone_angle=cone_angle,
alpha_thre=alpha_thre, alpha_thre=alpha_thre,
# test options # test options
test_chunk_size=args.test_chunk_size, test_chunk_size=args.test_chunk_size,
...@@ -288,21 +253,18 @@ if __name__ == "__main__": ...@@ -288,21 +253,18 @@ if __name__ == "__main__":
mse = F.mse_loss(rgb, pixels) mse = F.mse_loss(rgb, pixels)
psnr = -10.0 * torch.log(mse) / np.log(10.0) psnr = -10.0 * torch.log(mse) / np.log(10.0)
psnrs.append(psnr.item()) psnrs.append(psnr.item())
# imageio.imwrite( lpips.append(lpips_fn(rgb, pixels).item())
# "acc_binary_test.png", # if i == 0:
# ((acc > 0).float().cpu().numpy() * 255).astype(np.uint8),
# )
# imageio.imwrite( # imageio.imwrite(
# "rgb_test.png", # "rgb_test.png",
# (rgb.cpu().numpy() * 255).astype(np.uint8), # (rgb.cpu().numpy() * 255).astype(np.uint8),
# ) # )
# break # imageio.imwrite(
# "rgb_error.png",
# (
# (rgb - pixels).norm(dim=-1).cpu().numpy() * 255
# ).astype(np.uint8),
# )
psnr_avg = sum(psnrs) / len(psnrs) psnr_avg = sum(psnrs) / len(psnrs)
print(f"evaluation: psnr_avg={psnr_avg}") lpips_avg = sum(lpips) / len(lpips)
train_dataset.training = True print(f"evaluation: psnr_avg={psnr_avg}, lpips_avg={lpips_avg}")
if step == max_steps:
print("training stops")
exit()
step += 1
"""
Copyright (c) 2022 Ruilong Li, UC Berkeley.
"""
import argparse
import itertools
import pathlib
import time
import imageio
import numpy as np
import torch
import torch.nn.functional as F
import tqdm
from lpips import LPIPS
from radiance_fields.ngp import NGPDensityField, NGPRadianceField
from utils import (
MIPNERF360_UNBOUNDED_SCENES,
NERF_SYNTHETIC_SCENES,
render_image_proposal,
set_random_seed,
)
from nerfacc.proposal import (
compute_prop_loss,
get_proposal_annealing_fn,
get_proposal_requires_grad_fn,
)
parser = argparse.ArgumentParser()
parser.add_argument(
"--data_root",
type=str,
# default=str(pathlib.Path.cwd() / "data/360_v2"),
default=str(pathlib.Path.cwd() / "data/nerf_synthetic"),
help="the root dir of the dataset",
)
parser.add_argument(
"--train_split",
type=str,
default="train",
choices=["train", "trainval"],
help="which train split to use",
)
parser.add_argument(
"--scene",
type=str,
default="lego",
choices=NERF_SYNTHETIC_SCENES + MIPNERF360_UNBOUNDED_SCENES,
help="which scene to use",
)
parser.add_argument(
"--test_chunk_size",
type=int,
default=8192,
)
args = parser.parse_args()
device = "cuda:0"
set_random_seed(42)
if args.scene in MIPNERF360_UNBOUNDED_SCENES:
from datasets.nerf_360_v2 import SubjectLoader
# training parameters
max_steps = 100000
init_batch_size = 4096
weight_decay = 0.0
# scene parameters
unbounded = True
aabb = torch.tensor([-1.0, -1.0, -1.0, 1.0, 1.0, 1.0], device=device)
near_plane = 0.2 # TODO: Try 0.02
far_plane = 1e3
# dataset parameters
train_dataset_kwargs = {"color_bkgd_aug": "random", "factor": 4}
test_dataset_kwargs = {"factor": 4}
# model parameters
proposal_networks = [
NGPDensityField(
aabb=aabb,
unbounded=unbounded,
n_levels=5,
max_resolution=128,
).to(device),
NGPDensityField(
aabb=aabb,
unbounded=unbounded,
n_levels=5,
max_resolution=256,
).to(device),
]
# render parameters
num_samples = 48
num_samples_per_prop = [256, 96]
sampling_type = "lindisp"
opaque_bkgd = True
else:
from datasets.nerf_synthetic import SubjectLoader
# training parameters
max_steps = 20000
init_batch_size = 4096
weight_decay = (
1e-5 if args.scene in ["materials", "ficus", "drums"] else 1e-6
)
# scene parameters
unbounded = False
aabb = torch.tensor([-1.5, -1.5, -1.5, 1.5, 1.5, 1.5], device=device)
near_plane = 2.0
far_plane = 6.0
# dataset parameters
train_dataset_kwargs = {}
test_dataset_kwargs = {}
# model parameters
proposal_networks = [
NGPDensityField(
aabb=aabb,
unbounded=unbounded,
n_levels=5,
max_resolution=128,
).to(device),
]
# render parameters
num_samples = 64
num_samples_per_prop = [128]
sampling_type = "uniform"
opaque_bkgd = False
train_dataset = SubjectLoader(
subject_id=args.scene,
root_fp=args.data_root,
split=args.train_split,
num_rays=init_batch_size,
device=device,
**train_dataset_kwargs,
)
test_dataset = SubjectLoader(
subject_id=args.scene,
root_fp=args.data_root,
split="test",
num_rays=None,
device=device,
**test_dataset_kwargs,
)
# setup the radiance field we want to train.
grad_scaler = torch.cuda.amp.GradScaler(2**10)
radiance_field = NGPRadianceField(aabb=aabb, unbounded=unbounded).to(device)
optimizer = torch.optim.Adam(
itertools.chain(
radiance_field.parameters(),
*[p.parameters() for p in proposal_networks],
),
lr=1e-2,
eps=1e-15,
weight_decay=weight_decay,
)
scheduler = torch.optim.lr_scheduler.ChainedScheduler(
[
torch.optim.lr_scheduler.LinearLR(
optimizer, start_factor=0.01, total_iters=100
),
torch.optim.lr_scheduler.MultiStepLR(
optimizer,
milestones=[
max_steps // 2,
max_steps * 3 // 4,
max_steps * 9 // 10,
],
gamma=0.33,
),
]
)
proposal_requires_grad_fn = get_proposal_requires_grad_fn()
proposal_annealing_fn = get_proposal_annealing_fn()
lpips_net = LPIPS(net="vgg").to(device)
lpips_norm_fn = lambda x: x[None, ...].permute(0, 3, 1, 2) * 2 - 1
lpips_fn = lambda x, y: lpips_net(lpips_norm_fn(x), lpips_norm_fn(y)).mean()
# training
tic = time.time()
for step in range(max_steps + 1):
radiance_field.train()
for p in proposal_networks:
p.train()
i = torch.randint(0, len(train_dataset), (1,)).item()
data = train_dataset[i]
render_bkgd = data["color_bkgd"]
rays = data["rays"]
pixels = data["pixels"]
# render
(
rgb,
acc,
depth,
weights_per_level,
s_vals_per_level,
) = render_image_proposal(
radiance_field,
proposal_networks,
rays,
scene_aabb=None,
# rendering options
num_samples=num_samples,
num_samples_per_prop=num_samples_per_prop,
near_plane=near_plane,
far_plane=far_plane,
sampling_type=sampling_type,
opaque_bkgd=opaque_bkgd,
render_bkgd=render_bkgd,
# train options
proposal_requires_grad=proposal_requires_grad_fn(step),
proposal_annealing=proposal_annealing_fn(step),
)
# compute loss
loss = F.smooth_l1_loss(rgb, pixels)
loss_prop = compute_prop_loss(s_vals_per_level, weights_per_level)
loss = loss + loss_prop
optimizer.zero_grad()
# do not unscale it because we are using Adam.
grad_scaler.scale(loss).backward()
optimizer.step()
scheduler.step()
if step % 5000 == 0:
elapsed_time = time.time() - tic
loss = F.mse_loss(rgb, pixels)
psnr = -10.0 * torch.log(loss) / np.log(10.0)
print(
f"elapsed_time={elapsed_time:.2f}s | step={step} | "
f"loss={loss:.5f} | psnr={psnr:.2f} | "
f"num_rays={len(pixels):d} | "
f"max_depth={depth.max():.3f} | "
)
if step > 0 and step % max_steps == 0:
# evaluation
radiance_field.eval()
for p in proposal_networks:
p.eval()
psnrs = []
lpips = []
with torch.no_grad():
for i in tqdm.tqdm(range(len(test_dataset))):
data = test_dataset[i]
render_bkgd = data["color_bkgd"]
rays = data["rays"]
pixels = data["pixels"]
# rendering
rgb, acc, depth, _, _, = render_image_proposal(
radiance_field,
proposal_networks,
rays,
scene_aabb=None,
# rendering options
num_samples=num_samples,
num_samples_per_prop=num_samples_per_prop,
near_plane=near_plane,
far_plane=far_plane,
sampling_type=sampling_type,
opaque_bkgd=opaque_bkgd,
render_bkgd=render_bkgd,
proposal_annealing=proposal_annealing_fn(step),
# test options
test_chunk_size=args.test_chunk_size,
)
mse = F.mse_loss(rgb, pixels)
psnr = -10.0 * torch.log(mse) / np.log(10.0)
psnrs.append(psnr.item())
lpips.append(lpips_fn(rgb, pixels).item())
# if i == 0:
# imageio.imwrite(
# "rgb_test.png",
# (rgb.cpu().numpy() * 255).astype(np.uint8),
# )
# imageio.imwrite(
# "rgb_error.png",
# (
# (rgb - pixels).norm(dim=-1).cpu().numpy() * 255
# ).astype(np.uint8),
# )
psnr_avg = sum(psnrs) / len(psnrs)
lpips_avg = sum(lpips) / len(lpips)
print(f"evaluation: psnr_avg={psnr_avg}, lpips_avg={lpips_avg}")
...@@ -3,13 +3,35 @@ Copyright (c) 2022 Ruilong Li, UC Berkeley.
 """
 import random
-from typing import Optional
+from typing import Literal, Optional, Sequence

 import numpy as np
 import torch
 from datasets.utils import Rays, namedtuple_map
+from torch.utils.data._utils.collate import collate, default_collate_fn_map

 from nerfacc import OccupancyGrid, ray_marching, rendering
+from nerfacc.proposal import rendering as rendering_proposal
+
+NERF_SYNTHETIC_SCENES = [
+    "chair",
+    "drums",
+    "ficus",
+    "hotdog",
+    "lego",
+    "materials",
+    "mic",
+    "ship",
+]
+MIPNERF360_UNBOUNDED_SCENES = [
+    "garden",
+    "bicycle",
+    "bonsai",
+    "counter",
+    "kitchen",
+    "room",
+    "stump",
+]


 def set_random_seed(seed):
...@@ -18,6 +40,12 @@ def set_random_seed(seed):
     torch.manual_seed(seed)


+def enlarge_aabb(aabb, factor: float) -> torch.Tensor:
+    center = (aabb[:3] + aabb[3:]) / 2
+    extent = (aabb[3:] - aabb[:3]) / 2
+    return torch.cat([center - extent * factor, center + extent * factor])
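An illustrative aside (not part of the diff): `enlarge_aabb` scales a box symmetrically about its center, and the occupancy-grid training script calls it as `enlarge_aabb(aabb, 1 << (grid_nlvl - 1))` so the coarsest of the multi-resolution grid levels covers the enlarged scene. A quick check with the unbounded settings used above::

    import torch

    def enlarge_aabb(aabb, factor: float) -> torch.Tensor:
        center = (aabb[:3] + aabb[3:]) / 2
        extent = (aabb[3:] - aabb[:3]) / 2
        return torch.cat([center - extent * factor, center + extent * factor])

    aabb = torch.tensor([-1.0, -1.0, -1.0, 1.0, 1.0, 1.0])
    grid_nlvl = 4  # levels used for the unbounded (360) scenes
    scene_aabb = enlarge_aabb(aabb, 1 << (grid_nlvl - 1))
    print(scene_aabb)  # tensor([-8., -8., -8.,  8.,  8.,  8.])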
def render_image( def render_image(
# scene # scene
radiance_field: torch.nn.Module, radiance_field: torch.nn.Module,
...@@ -116,3 +144,99 @@ def render_image( ...@@ -116,3 +144,99 @@ def render_image(
depths.view((*rays_shape[:-1], -1)), depths.view((*rays_shape[:-1], -1)),
sum(n_rendering_samples), sum(n_rendering_samples),
) )
def render_image_proposal(
# scene
radiance_field: torch.nn.Module,
proposal_networks: Sequence[torch.nn.Module],
rays: Rays,
scene_aabb: torch.Tensor,
# rendering options
num_samples: int,
num_samples_per_prop: Sequence[int],
near_plane: Optional[float] = None,
far_plane: Optional[float] = None,
sampling_type: Literal["uniform", "lindisp"] = "lindisp",
opaque_bkgd: bool = True,
render_bkgd: Optional[torch.Tensor] = None,
# train options
proposal_requires_grad: bool = False,
proposal_annealing: float = 1.0,
# test options
test_chunk_size: int = 8192,
):
"""Render the pixels of an image."""
rays_shape = rays.origins.shape
if len(rays_shape) == 3:
height, width, _ = rays_shape
num_rays = height * width
rays = namedtuple_map(
lambda r: r.reshape([num_rays] + list(r.shape[2:])), rays
)
else:
num_rays, _ = rays_shape
def prop_sigma_fn(t_starts, t_ends, proposal_network):
t_origins = chunk_rays.origins[..., None, :]
t_dirs = chunk_rays.viewdirs[..., None, :]
positions = t_origins + t_dirs * (t_starts + t_ends) / 2.0
return proposal_network(positions)
def rgb_sigma_fn(t_starts, t_ends):
t_origins = chunk_rays.origins[..., None, :]
t_dirs = chunk_rays.viewdirs[..., None, :].repeat_interleave(
t_starts.shape[-2], dim=-2
)
positions = t_origins + t_dirs * (t_starts + t_ends) / 2.0
return radiance_field(positions, t_dirs)
results = []
chunk = (
torch.iinfo(torch.int32).max
if radiance_field.training
else test_chunk_size
)
for i in range(0, num_rays, chunk):
chunk_rays = namedtuple_map(lambda r: r[i : i + chunk], rays)
(
rgb,
opacity,
depth,
(weights_per_level, s_vals_per_level),
) = rendering_proposal(
rgb_sigma_fn=rgb_sigma_fn,
num_samples=num_samples,
prop_sigma_fns=[
lambda *args: prop_sigma_fn(*args, p) for p in proposal_networks
],
num_samples_per_prop=num_samples_per_prop,
rays_o=chunk_rays.origins,
rays_d=chunk_rays.viewdirs,
scene_aabb=scene_aabb,
near_plane=near_plane,
far_plane=far_plane,
stratified=radiance_field.training,
sampling_type=sampling_type,
opaque_bkgd=opaque_bkgd,
render_bkgd=render_bkgd,
proposal_requires_grad=proposal_requires_grad,
proposal_annealing=proposal_annealing,
)
chunk_results = [rgb, opacity, depth]
results.append(chunk_results)
colors, opacities, depths = collate(
results,
collate_fn_map={
**default_collate_fn_map,
torch.Tensor: lambda x, **_: torch.cat(x, 0),
},
)
return (
colors.view((*rays_shape[:-1], -1)),
opacities.view((*rays_shape[:-1], -1)),
depths.view((*rays_shape[:-1], -1)),
weights_per_level,
s_vals_per_level,
)
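# Hedged usage sketch (argument values are illustrative; `radiance_field`,
# `proposal_networks`, `rays`, `scene_aabb`, and `render_bkgd` are assumed to be
# set up by the training script):
#
#   rgb, acc, depth, weights_per_level, s_vals_per_level = render_image_proposal(
#       radiance_field,
#       proposal_networks,
#       rays,
#       scene_aabb,
#       num_samples=48,
#       num_samples_per_prop=[128, 64],
#       near_plane=0.2,
#       far_plane=1e3,
#       render_bkgd=render_bkgd,
#       proposal_requires_grad=True,
#   )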
""" """
Copyright (c) 2022 Ruilong Li, UC Berkeley. Copyright (c) 2022 Ruilong Li, UC Berkeley.
""" """
import warnings
from .cdf import ray_resampling
from .contraction import ContractionType, contract, contract_inv
from .grid import Grid, OccupancyGrid, query_grid
from .intersection import ray_aabb_intersect
from .losses import distortion as loss_distortion
from .pack import pack_data, pack_info, unpack_data, unpack_info
from .ray_marching import ray_marching
from .version import __version__
...@@ -21,39 +17,30 @@ from .vol_rendering import (
rendering,
)
# About to be deprecated
def unpack_to_ray_indices(*args, **kwargs):
warnings.warn(
"`unpack_to_ray_indices` will be deprecated. Please use `unpack_info` instead.",
DeprecationWarning,
stacklevel=2,
)
return unpack_info(*args, **kwargs)
__all__ = [
"__version__",
# occ grid
"Grid", "Grid",
"OccupancyGrid", "OccupancyGrid",
"query_grid", "query_grid",
"ContractionType", "ContractionType",
# contraction
"contract", "contract",
"contract_inv", "contract_inv",
# marching
"ray_aabb_intersect", "ray_aabb_intersect",
"ray_marching", "ray_marching",
# rendering
"accumulate_along_rays", "accumulate_along_rays",
"render_visibility", "render_visibility",
"render_weight_from_alpha", "render_weight_from_alpha",
"render_weight_from_density", "render_weight_from_density",
"render_transmittance_from_density",
"render_transmittance_from_alpha",
"rendering", "rendering",
# pack
"pack_data", "pack_data",
"unpack_data", "unpack_data",
"unpack_info", "unpack_info",
"pack_info", "pack_info",
"ray_resampling",
"loss_distortion",
"unpack_to_ray_indices",
"render_transmittance_from_density",
"render_transmittance_from_alpha",
]
"""
Copyright (c) 2022 Ruilong Li, UC Berkeley.
"""
from typing import Tuple
from torch import Tensor
import nerfacc.cuda as _C
def ray_resampling(
packed_info: Tensor,
t_starts: Tensor,
t_ends: Tensor,
weights: Tensor,
n_samples: int,
) -> Tuple[Tensor, Tensor, Tensor]:
"""Resample a set of rays based on the CDF of the weights.
Args:
packed_info (Tensor): Stores information on which samples belong to the same ray. \
See :func:`nerfacc.ray_marching` for details. Tensor with shape (n_rays, 2).
t_starts: Where the frustum-shape sample starts along a ray. Tensor with \
shape (n_samples, 1).
t_ends: Where the frustum-shape sample ends along a ray. Tensor with \
shape (n_samples, 1).
weights: Volumetric rendering weights for those samples. Tensor with shape \
(n_samples,).
n_samples (int): Number of samples per ray to resample.
Returns:
Resampled packed info (n_rays, 2), t_starts (n_samples, 1), and t_ends (n_samples, 1).
"""
(
resampled_packed_info,
resampled_t_starts,
resampled_t_ends,
) = _C.ray_resampling(
packed_info.contiguous(),
t_starts.contiguous(),
t_ends.contiguous(),
weights.contiguous(),
n_samples,
)
return resampled_packed_info, resampled_t_starts, resampled_t_ends
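# Hedged usage sketch (assumes `packed_info`, `t_starts`, and `t_ends` come from
# `nerfacc.ray_marching`, and `weights` is the per-sample weight vector with
# shape (n_samples,) described in the docstring above):
#
#   new_packed_info, new_t_starts, new_t_ends = ray_resampling(
#       packed_info, t_starts, t_ends, weights, n_samples=64
#   )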
...@@ -23,7 +23,6 @@ grid_query = _make_lazy_cuda_func("grid_query")
ray_aabb_intersect = _make_lazy_cuda_func("ray_aabb_intersect")
ray_marching = _make_lazy_cuda_func("ray_marching")
ray_resampling = _make_lazy_cuda_func("ray_resampling")
is_cub_available = _make_lazy_cuda_func("is_cub_available")
transmittance_from_sigma_forward_cub = _make_lazy_cuda_func(
...@@ -68,3 +67,6 @@ weight_from_alpha_backward_naive = _make_lazy_cuda_func(
unpack_data = _make_lazy_cuda_func("unpack_data")
unpack_info = _make_lazy_cuda_func("unpack_info")
unpack_info_to_mask = _make_lazy_cuda_func("unpack_info_to_mask")
pdf_readout = _make_lazy_cuda_func("pdf_readout")
pdf_sampling = _make_lazy_cuda_func("pdf_sampling")
/*
* Copyright (c) 2022 Ruilong Li, UC Berkeley.
*/
#include "include/helpers_cuda.h"
template <typename scalar_t>
__global__ void cdf_resampling_kernel(
const uint32_t n_rays,
const int *packed_info, // input ray & point indices.
const scalar_t *starts, // input start t
const scalar_t *ends, // input end t
const scalar_t *weights, // transmittance weights
const int *resample_packed_info,
scalar_t *resample_starts,
scalar_t *resample_ends)
{
CUDA_GET_THREAD_ID(i, n_rays);
// locate
const int base = packed_info[i * 2 + 0]; // point idx start.
const int steps = packed_info[i * 2 + 1]; // point idx shift.
const int resample_base = resample_packed_info[i * 2 + 0];
const int resample_steps = resample_packed_info[i * 2 + 1];
if (steps == 0)
return;
starts += base;
ends += base;
weights += base;
resample_starts += resample_base;
resample_ends += resample_base;
// normalize weights **per ray**
scalar_t weights_sum = 0.0f;
for (int j = 0; j < steps; j++)
weights_sum += weights[j];
scalar_t padding = fmaxf(1e-5f - weights_sum, 0.0f);
scalar_t padding_step = padding / steps;
weights_sum += padding;
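// The padding above lifts the per-ray weight sum to at least 1e-5 so that rays
// whose weights are all (near) zero still produce a usable, monotone CDF.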
int num_bins = resample_steps + 1;
scalar_t cdf_step_size = (1.0f - 1.0 / num_bins) / resample_steps;
int idx = 0, j = 0;
scalar_t cdf_prev = 0.0f, cdf_next = (weights[idx] + padding_step) / weights_sum;
scalar_t cdf_u = 1.0 / (2 * num_bins);
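// Walk the per-ray CDF with two cursors: `idx` scans the input intervals while
// `j` scans `num_bins` uniformly spaced query positions `cdf_u`. Whenever `cdf_u`
// falls inside the current interval [cdf_prev, cdf_next), the piecewise-linear
// CDF is inverted to recover a distance t, which becomes the start of resampled
// segment j and the end of segment j-1.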
while (j < num_bins)
{
if (cdf_u < cdf_next)
{
// printf("cdf_u: %f, cdf_next: %f\n", cdf_u, cdf_next);
// resample in this interval
scalar_t scaling = (ends[idx] - starts[idx]) / (cdf_next - cdf_prev);
scalar_t t = (cdf_u - cdf_prev) * scaling + starts[idx];
if (j < num_bins - 1)
resample_starts[j] = t;
if (j > 0)
resample_ends[j - 1] = t;
// going further to next resample
cdf_u += cdf_step_size;
j += 1;
}
else
{
// going to next interval
idx += 1;
cdf_prev = cdf_next;
cdf_next += (weights[idx] + padding_step) / weights_sum;
}
}
if (j != num_bins)
{
printf("Error: %d %d %f\n", j, num_bins, weights_sum);
}
return;
}
// template <typename scalar_t>
// __global__ void cdf_resampling_kernel(
// const uint32_t n_rays,
// const int *packed_info, // input ray & point indices.
// const scalar_t *starts, // input start t
// const scalar_t *ends, // input end t
// const scalar_t *weights, // transmittance weights
// const int *resample_packed_info,
// scalar_t *resample_starts,
// scalar_t *resample_ends)
// {
// CUDA_GET_THREAD_ID(i, n_rays);
// // locate
// const int base = packed_info[i * 2 + 0]; // point idx start.
// const int steps = packed_info[i * 2 + 1]; // point idx shift.
// const int resample_base = resample_packed_info[i * 2 + 0];
// const int resample_steps = resample_packed_info[i * 2 + 1];
// if (steps == 0)
// return;
// starts += base;
// ends += base;
// weights += base;
// resample_starts += resample_base;
// resample_ends += resample_base;
// scalar_t cdf_step_size = 1.0f / resample_steps;
// // normalize weights **per ray**
// scalar_t weights_sum = 0.0f;
// for (int j = 0; j < steps; j++)
// weights_sum += weights[j];
// scalar_t padding = fmaxf(1e-5f - weights_sum, 0.0f);
// scalar_t padding_step = padding / steps;
// weights_sum += padding;
// int idx = 0, j = 0;
// scalar_t cdf_prev = 0.0f, cdf_next = (weights[idx] + padding_step) / weights_sum;
// scalar_t cdf_u = 0.5f * cdf_step_size;
// while (cdf_u < 1.0f)
// {
// if (cdf_u < cdf_next)
// {
// // resample in this interval
// scalar_t scaling = (ends[idx] - starts[idx]) / (cdf_next - cdf_prev);
// scalar_t resample_mid = (cdf_u - cdf_prev) * scaling + starts[idx];
// scalar_t resample_half_size = cdf_step_size * scaling * 0.5;
// resample_starts[j] = fmaxf(resample_mid - resample_half_size, starts[idx]);
// resample_ends[j] = fminf(resample_mid + resample_half_size, ends[idx]);
// // going further to next resample
// cdf_u += cdf_step_size;
// j += 1;
// }
// else
// {
// // go to next interval
// idx += 1;
// if (idx == steps)
// break;
// cdf_prev = cdf_next;
// cdf_next += (weights[idx] + padding_step) / weights_sum;
// }
// }
// if (j != resample_steps)
// {
// printf("Error: %d %d %f\n", j, resample_steps, weights_sum);
// }
// return;
// }
std::vector<torch::Tensor> ray_resampling(
torch::Tensor packed_info,
torch::Tensor starts,
torch::Tensor ends,
torch::Tensor weights,
const int steps)
{
DEVICE_GUARD(packed_info);
CHECK_INPUT(packed_info);
CHECK_INPUT(starts);
CHECK_INPUT(ends);
CHECK_INPUT(weights);
TORCH_CHECK(packed_info.ndimension() == 2 && packed_info.size(1) == 2);
TORCH_CHECK(starts.ndimension() == 2 && starts.size(1) == 1);
TORCH_CHECK(ends.ndimension() == 2 && ends.size(1) == 1);
TORCH_CHECK(weights.ndimension() == 1);
const uint32_t n_rays = packed_info.size(0);
const uint32_t n_samples = weights.size(0);
const int threads = 256;
const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads);
torch::Tensor num_steps = torch::split(packed_info, 1, 1)[1];
torch::Tensor resample_num_steps = (num_steps > 0).to(num_steps.options()) * steps;
torch::Tensor resample_cum_steps = resample_num_steps.cumsum(0, torch::kInt32);
torch::Tensor resample_packed_info = torch::cat(
{resample_cum_steps - resample_num_steps, resample_num_steps}, 1);
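// Every ray with at least one input sample is assigned exactly `steps` resampled
// intervals; empty rays keep zero samples. `resample_packed_info` stores a
// (start offset, count) pair per ray, mirroring the layout of `packed_info`.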
int total_steps = resample_cum_steps[resample_cum_steps.size(0) - 1].item<int>();
torch::Tensor resample_starts = torch::zeros({total_steps, 1}, starts.options());
torch::Tensor resample_ends = torch::zeros({total_steps, 1}, ends.options());
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
weights.scalar_type(),
"ray_resampling",
([&]
{ cdf_resampling_kernel<scalar_t><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
n_rays,
// inputs
packed_info.data_ptr<int>(),
starts.data_ptr<scalar_t>(),
ends.data_ptr<scalar_t>(),
weights.data_ptr<scalar_t>(),
resample_packed_info.data_ptr<int>(),
// outputs
resample_starts.data_ptr<scalar_t>(),
resample_ends.data_ptr<scalar_t>()); }));
return {resample_packed_info, resample_starts, resample_ends};
}