"git@developer.sourcefind.cn:OpenDAS/torch-cluster.git" did not exist on "0a0383342251ef9e3e3327ddd250fdff1714a032"
Unverified commit b53f8a4e, authored by Ruilong Li (李瑞龙) and committed by GitHub

Support multi-res occ grid & prop net (#176)

* multres grid

* prop

* benchmark with prop and occ

* benchmark blender with weight_decay

* docs

* bump version
parent 82fd69c7
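The commit introduces two alternative ways to decide where samples are placed along each ray: a multi-resolution (cascaded) occupancy grid for the "occ" training path, and small NGP density fields used as proposal samplers for the new "prop" path. A rough sketch of how the two are constructed, assembled from the code added in the diffs below (keyword names such as `levels`, `n_levels`, and `max_resolution` are taken from this commit and may differ in other nerfacc versions):

    import torch
    from nerfacc import OccupancyGrid
    from radiance_fields.ngp import NGPDensityField  # examples/radiance_fields module

    device = "cuda:0"
    aabb = torch.tensor([-1.0, -1.0, -1.0, 1.0, 1.0, 1.0], device=device)

    # Option 1 (occ): a multi-resolution occupancy grid; each extra level
    # covers a progressively larger box around the base AABB.
    occupancy_grid = OccupancyGrid(
        roi_aabb=aabb, resolution=128, levels=4
    ).to(device)

    # Option 2 (prop): coarse NGP density fields used as proposal samplers.
    proposal_networks = [
        NGPDensityField(aabb=aabb, n_levels=5, max_resolution=128).to(device),
        NGPDensityField(aabb=aabb, n_levels=5, max_resolution=256).to(device),
    ]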
@@ -118,4 +118,5 @@ venv.bak/
 .vsocde
 benchmarks/
 outputs/
+data
\ No newline at end of file
nerfacc.ray\_resampling
=======================
.. currentmodule:: nerfacc
.. autofunction:: ray_resampling
\ No newline at end of file
@@ -17,7 +17,6 @@ Utils
 render_weight_from_alpha
 render_visibility
-ray_resampling
 pack_data
 unpack_data
\ No newline at end of file
@@ -7,7 +7,7 @@ See code `examples/train_ngp_nerf.py` at our `github repository`_ for details.
 Benchmarks
 ------------
 
-*updated on 2022-10-12*
+*updated on 2023-03-14*
 
 Here we trained a `Instant-NGP Nerf`_ model on the `Nerf-Synthetic dataset`_. We follow the same
 settings with the Instant-NGP paper, which uses train split for training and test split for
@@ -30,11 +30,15 @@ memory footprint is about 3GB.
 +-----------------------+-------+-------+---------+-------+-------+-------+-------+-------+-------+
 |(training time)        | 309s  | 258s  | 256s    | 316s  | 292s  | 207s  | 218s  | 250s  | 263s  |
 +-----------------------+-------+-------+---------+-------+-------+-------+-------+-------+-------+
-|Ours 20k steps         | 35.50 | 36.16 | 29.14   | 35.23 | 37.15 | 31.71 | 24.88 | 29.91 | 32.46 |
+|Ours (occ) 20k steps   | 35.81 | 36.87 | 29.59   | 35.70 | 37.45 | 33.63 | 24.98 | 30.64 | 33.08 |
 +-----------------------+-------+-------+---------+-------+-------+-------+-------+-------+-------+
-|(training time)        | 287s  | 274s  | 269s    | 317s  | 269s  | 244s  | 249s  | 257s  | 271s  |
+|(training time)        | 288s  | 255s  | 247s    | 319s  | 274s  | 238s  | 247s  | 252s  | 265s  |
++-----------------------+-------+-------+---------+-------+-------+-------+-------+-------+-------+
+|Ours (prop) 20k steps  | 34.06 | 34.32 | 27.93   | 34.27 | 36.47 | 31.39 | 24.39 | 30.57 | 31.68 |
++-----------------------+-------+-------+---------+-------+-------+-------+-------+-------+-------+
+|(training time)        | 238s  | 236s  | 250s    | 235s  | 235s  | 236s  | 236s  | 236s  | 240s  |
 +-----------------------+-------+-------+---------+-------+-------+-------+-------+-------+-------+
 
 .. _`Instant-NGP Nerf`: https://github.com/NVlabs/instant-ngp/tree/51e4107edf48338e9ab0316d56a222e0adf87143
-.. _`github repository`: https://github.com/KAIR-BAIR/nerfacc/tree/76c0f9817da4c9c8b5ccf827eb069ee2ce854b75
+.. _`github repository`: https://github.com/KAIR-BAIR/nerfacc/
 .. _`Nerf-Synthetic dataset`: https://drive.google.com/drive/folders/1JDdLGDruGNXWnM1eqY1FNL9PlStjaKWi
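The PSNR numbers in these tables are computed from the per-image mean squared error, PSNR = -10 * log10(MSE), which is the formula used by the example training scripts later in this commit. A tiny sketch of the metric (assuming predictions and targets are float tensors in [0, 1]):

    import torch
    import torch.nn.functional as F

    def psnr(pred: torch.Tensor, target: torch.Tensor) -> float:
        # PSNR = -10 * log10(MSE), matching the example scripts.
        mse = F.mse_loss(pred, target)
        return (-10.0 * torch.log10(mse)).item()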
@@ -5,7 +5,7 @@ See code `examples/train_ngp_nerf.py` at our `github repository`_ for details.
 Benchmarks
 ------------
 
-*updated on 2022-11-07*
+*updated on 2023-03-14*
 
 Here we trained a `Instant-NGP Nerf`_ on the `MipNerf360`_ dataset. We used train
 split for training and test split for evaluation. Our experiments are conducted on a
@@ -32,12 +32,19 @@ that takes from `MipNerf360`_.
 +----------------------+-------+-------+-------+-------+-------+-------+-------+-------+
 | MipNerf360 (~days)   | 26.98 | 24.37 | 33.46 | 29.55 | 32.23 | 31.63 | 26.40 | 29.23 |
 +----------------------+-------+-------+-------+-------+-------+-------+-------+-------+
-| Ours (~20 mins)      | 25.41 | 22.97 | 30.71 | 27.34 | 30.32 | 31.00 | 23.43 | 27.31 |
+| Ours (occ)           | 24.76 | 22.38 | 29.72 | 26.80 | 28.02 | 30.67 | 22.39 | 26.39 |
 +----------------------+-------+-------+-------+-------+-------+-------+-------+-------+
-| Ours (Training time) | 25min | 17min | 19min | 23min | 28min | 20min | 17min | 21min |
+| Ours (Training time) | 323s  | 302s  | 300s  | 337s  | 347s  | 320s  | 322s  | 322s  |
 +----------------------+-------+-------+-------+-------+-------+-------+-------+-------+
+| Ours (prop)          | 25.44 | 23.21 | 30.62 | 26.75 | 30.63 | 30.93 | 25.20 | 27.54 |
++----------------------+-------+-------+-------+-------+-------+-------+-------+-------+
+| Ours (Training time) | 308s  | 304s  | 308s  | 306s  | 313s  | 301s  | 287s  | 304s  |
++----------------------+-------+-------+-------+-------+-------+-------+-------+-------+
+
+Note `Ours (prop)` is basically a `Nerfacto`_ model.
 
 .. _`Instant-NGP Nerf`: https://arxiv.org/abs/2201.05989
 .. _`MipNerf360`: https://arxiv.org/abs/2111.12077
 .. _`Nerf++`: https://arxiv.org/abs/2010.07492
-.. _`github repository`: https://github.com/KAIR-BAIR/nerfacc/tree/76c0f9817da4c9c8b5ccf827eb069ee2ce854b75
+.. _`github repository`: https://github.com/KAIR-BAIR/nerfacc/
+.. _`Nerfacto`: https://github.com/nerfstudio-project/nerfstudio/blob/main/nerfstudio/models/nerfacto.py
\ No newline at end of file
@@ -123,7 +123,7 @@ Links:
 .. toctree::
    :glob:
    :maxdepth: 1
-   :caption: Example Usages
+   :caption: Example Usages and Benchmarks
 
    examples/*
...
@@ -86,7 +86,7 @@ class SubjectLoader(torch.utils.data.Dataset):
         near: float = None,
         far: float = None,
         batch_over_images: bool = True,
-        device: str = "cuda:0",
+        device: str = "cpu",
     ):
         super().__init__()
         assert split in self.SPLITS, "%s" % split
...
@@ -22,7 +22,7 @@ sys.path.insert(
 from scene_manager import SceneManager
 
-def _load_colmap(root_fp: str, subject_id: str, split: str, factor: int = 1):
+def _load_colmap(root_fp: str, subject_id: str, factor: int = 1):
     assert factor in [1, 2, 4, 8]
     data_dir = os.path.join(root_fp, subject_id)
@@ -134,12 +134,66 @@ def _load_colmap(root_fp: str, subject_id: str, factor: int = 1):
         "test": all_indices[all_indices % 8 == 0],
         "train": all_indices[all_indices % 8 != 0],
     }
-    indices = split_indices[split]
-    # All per-image quantities must be re-indexed using the split indices.
-    images = images[indices]
-    camtoworlds = camtoworlds[indices]
-
-    return images, camtoworlds, K
+    return images, camtoworlds, K, split_indices
+
+
+def similarity_from_cameras(c2w, strict_scaling):
+    """
+    reference: nerf-factory
+    Get a similarity transform to normalize dataset
+    from c2w (OpenCV convention) cameras
+    :param c2w: (N, 4)
+    :return T (4,4) , scale (float)
+    """
+    t = c2w[:, :3, 3]
+    R = c2w[:, :3, :3]
+
+    # (1) Rotate the world so that z+ is the up axis
+    # we estimate the up axis by averaging the camera up axes
+    ups = np.sum(R * np.array([0, -1.0, 0]), axis=-1)
+    world_up = np.mean(ups, axis=0)
+    world_up /= np.linalg.norm(world_up)
+
+    up_camspace = np.array([0.0, -1.0, 0.0])
+    c = (up_camspace * world_up).sum()
+    cross = np.cross(world_up, up_camspace)
+    skew = np.array(
+        [
+            [0.0, -cross[2], cross[1]],
+            [cross[2], 0.0, -cross[0]],
+            [-cross[1], cross[0], 0.0],
+        ]
+    )
+    if c > -1:
+        R_align = np.eye(3) + skew + (skew @ skew) * 1 / (1 + c)
+    else:
+        # In the unlikely case the original data has y+ up axis,
+        # rotate 180-deg about x axis
+        R_align = np.array([[-1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
+    # R_align = np.eye(3) # DEBUG
+
+    R = R_align @ R
+    fwds = np.sum(R * np.array([0, 0.0, 1.0]), axis=-1)
+    t = (R_align @ t[..., None])[..., 0]
+
+    # (2) Recenter the scene using camera center rays
+    # find the closest point to the origin for each camera's center ray
+    nearest = t + (fwds * -t).sum(-1)[:, None] * fwds
+    # median for more robustness
+    translate = -np.median(nearest, axis=0)
+    # translate = -np.mean(t, axis=0)  # DEBUG
+
+    transform = np.eye(4)
+    transform[:3, 3] = translate
+    transform[:3, :3] = R_align
+
+    # (3) Rescale the scene using camera distances
+    scale_fn = np.max if strict_scaling else np.median
+    scale = 1.0 / scale_fn(np.linalg.norm(t + translate, axis=-1))
+
+    return transform, scale
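A short sketch of how this normalization is applied, mirroring the loader code just below (the toy poses here are purely illustrative; `similarity_from_cameras` is the function defined above): the returned 4x4 transform is left-multiplied onto every camera-to-world matrix and the translations are then rescaled.

    import numpy as np

    # toy (N, 4, 4) OpenCV-convention camera-to-world matrices
    camtoworlds = np.stack([np.eye(4) for _ in range(8)])
    camtoworlds[:, :3, 3] = np.random.randn(8, 3) * 3.0

    T, sscale = similarity_from_cameras(camtoworlds, strict_scaling=False)
    # c2w_new[n] = T @ c2w_old[n], written as an einsum over the batch
    camtoworlds = np.einsum("nij, ki -> nkj", camtoworlds, T)
    # rescale translations so the cameras end up at roughly unit scale
    camtoworlds[:, :3, 3] *= sscale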
 class SubjectLoader(torch.utils.data.Dataset):
@@ -169,7 +223,7 @@ class SubjectLoader(torch.utils.data.Dataset):
         far: float = None,
         batch_over_images: bool = True,
         factor: int = 1,
-        device: str = "cuda:0",
+        device: str = "cpu",
     ):
         super().__init__()
         assert split in self.SPLITS, "%s" % split
@@ -184,14 +238,25 @@ class SubjectLoader(torch.utils.data.Dataset):
         )
         self.color_bkgd_aug = color_bkgd_aug
         self.batch_over_images = batch_over_images
-        self.images, self.camtoworlds, self.K = _load_colmap(
-            root_fp, subject_id, split, factor
+        self.images, self.camtoworlds, self.K, split_indices = _load_colmap(
+            root_fp, subject_id, factor
         )
+        # normalize the scene
+        T, sscale = similarity_from_cameras(
+            self.camtoworlds, strict_scaling=False
+        )
+        self.camtoworlds = np.einsum("nij, ki -> nkj", self.camtoworlds, T)
+        self.camtoworlds[:, :3, 3] *= sscale
+        # split
+        indices = split_indices[split]
+        self.images = self.images[indices]
+        self.camtoworlds = self.camtoworlds[indices]
+        # to tensor
-        self.images = torch.from_numpy(self.images).to(device).to(torch.uint8)
+        self.images = torch.from_numpy(self.images).to(torch.uint8).to(device)
         self.camtoworlds = (
-            torch.from_numpy(self.camtoworlds).to(device).to(torch.float32)
+            torch.from_numpy(self.camtoworlds).to(torch.float32).to(device)
         )
-        self.K = torch.tensor(self.K).to(device).to(torch.float32)
+        self.K = torch.tensor(self.K).to(torch.float32).to(device)
         self.height, self.width = self.images.shape[1:3]
 
     def __len__(self):
@@ -275,7 +340,7 @@ class SubjectLoader(torch.utils.data.Dataset):
             value=(-1.0 if self.OPENGL_CAMERA else 1.0),
         ) # [num_rays, 3]
-        # [n_cams, height, width, 3]
+        # [num_rays, 3]
         directions = (camera_dirs[:, None, :] * c2w[:, :3, :3]).sum(dim=-1)
         origins = torch.broadcast_to(c2w[:, :3, -1], directions.shape)
         viewdirs = directions / torch.linalg.norm(
...
@@ -79,7 +79,7 @@ class SubjectLoader(torch.utils.data.Dataset):
         near: float = None,
         far: float = None,
         batch_over_images: bool = True,
-        device: str = "cuda:0",
+        device: torch.device = torch.device("cpu"),
     ):
         super().__init__()
         assert split in self.SPLITS, "%s" % split
@@ -110,10 +110,8 @@ class SubjectLoader(torch.utils.data.Dataset):
         self.images, self.camtoworlds, self.focal = _load_renderings(
             root_fp, subject_id, split
         )
-        self.images = torch.from_numpy(self.images).to(device).to(torch.uint8)
-        self.camtoworlds = (
-            torch.from_numpy(self.camtoworlds).to(device).to(torch.float32)
-        )
+        self.images = torch.from_numpy(self.images).to(torch.uint8)
+        self.camtoworlds = torch.from_numpy(self.camtoworlds).to(torch.float32)
         self.K = torch.tensor(
             [
                 [self.focal, 0, self.WIDTH / 2.0],
@@ -121,8 +119,10 @@ class SubjectLoader(torch.utils.data.Dataset):
                 [0, 0, 1],
             ],
             dtype=torch.float32,
-            device=device,
         ) # (3, 3)
+        self.images = self.images.to(device)
+        self.camtoworlds = self.camtoworlds.to(device)
+        self.K = self.K.to(device)
         assert self.images.shape[1:3] == (self.HEIGHT, self.WIDTH)
 
     def __len__(self):
...
@@ -4,6 +4,7 @@ Copyright (c) 2022 Ruilong Li, UC Berkeley.
 from typing import Callable, List, Union
 
+import numpy as np
 import torch
 from torch.autograd import Function
 from torch.cuda.amp import custom_bwd, custom_fwd
@@ -41,13 +42,15 @@ trunc_exp = _TruncExp.apply
 def contract_to_unisphere(
     x: torch.Tensor,
     aabb: torch.Tensor,
+    ord: Union[str, int] = 2,
+    # ord: Union[float, int] = float("inf"),
     eps: float = 1e-6,
     derivative: bool = False,
 ):
     aabb_min, aabb_max = torch.split(aabb, 3, dim=-1)
     x = (x - aabb_min) / (aabb_max - aabb_min)
     x = x * 2 - 1 # aabb is at [-1, 1]
-    mag = x.norm(dim=-1, keepdim=True)
+    mag = torch.linalg.norm(x, ord=ord, dim=-1, keepdim=True)
     mask = mag.squeeze(-1) > 1
     if derivative:
@@ -63,8 +66,8 @@ def contract_to_unisphere(
         return x
 
-class NGPradianceField(torch.nn.Module):
-    """Instance-NGP radiance Field"""
+class NGPRadianceField(torch.nn.Module):
+    """Instance-NGP Radiance Field"""
 
     def __init__(
         self,
@@ -73,6 +76,8 @@ class NGPradianceField(torch.nn.Module):
         use_viewdirs: bool = True,
         density_activation: Callable = lambda x: trunc_exp(x - 1),
         unbounded: bool = False,
+        base_resolution: int = 16,
+        max_resolution: int = 4096,
         geo_feat_dim: int = 15,
         n_levels: int = 16,
         log2_hashmap_size: int = 19,
@@ -85,9 +90,15 @@ class NGPradianceField(torch.nn.Module):
         self.use_viewdirs = use_viewdirs
         self.density_activation = density_activation
         self.unbounded = unbounded
+        self.base_resolution = base_resolution
+        self.max_resolution = max_resolution
         self.geo_feat_dim = geo_feat_dim
-        per_level_scale = 1.4472692012786865
+        self.n_levels = n_levels
+        self.log2_hashmap_size = log2_hashmap_size
+
+        per_level_scale = np.exp(
+            (np.log(max_resolution) - np.log(base_resolution)) / (n_levels - 1)
+        ).tolist()
 
         if self.use_viewdirs:
             self.direction_encoding = tcnn.Encoding(
@@ -113,7 +124,7 @@ class NGPradianceField(torch.nn.Module):
                 "n_levels": n_levels,
                 "n_features_per_level": 2,
                 "log2_hashmap_size": log2_hashmap_size,
-                "base_resolution": 16,
+                "base_resolution": base_resolution,
                 "per_level_scale": per_level_scale,
             },
             network_config={
@@ -138,7 +149,7 @@ class NGPradianceField(torch.nn.Module):
             network_config={
                 "otype": "FullyFusedMLP",
                 "activation": "ReLU",
-                "output_activation": "Sigmoid",
+                "output_activation": "None",
                 "n_neurons": 64,
                 "n_hidden_layers": 2,
             },
@@ -168,19 +179,21 @@ class NGPradianceField(torch.nn.Module):
         else:
             return density
 
-    def _query_rgb(self, dir, embedding):
+    def _query_rgb(self, dir, embedding, apply_act: bool = True):
         # tcnn requires directions in the range [0, 1]
         if self.use_viewdirs:
             dir = (dir + 1.0) / 2.0
-            d = self.direction_encoding(dir.view(-1, dir.shape[-1]))
-            h = torch.cat([d, embedding.view(-1, self.geo_feat_dim)], dim=-1)
+            d = self.direction_encoding(dir.reshape(-1, dir.shape[-1]))
+            h = torch.cat([d, embedding.reshape(-1, self.geo_feat_dim)], dim=-1)
         else:
-            h = embedding.view(-1, self.geo_feat_dim)
+            h = embedding.reshape(-1, self.geo_feat_dim)
         rgb = (
             self.mlp_head(h)
-            .view(list(embedding.shape[:-1]) + [3])
+            .reshape(list(embedding.shape[:-1]) + [3])
             .to(embedding)
         )
+        if apply_act:
+            rgb = torch.sigmoid(rgb)
         return rgb
 
     def forward(
@@ -194,4 +207,73 @@ class NGPradianceField(torch.nn.Module):
         ), f"{positions.shape} v.s. {directions.shape}"
         density, embedding = self.query_density(positions, return_feat=True)
         rgb = self._query_rgb(directions, embedding=embedding)
-        return rgb, density
+        return rgb, density  # type: ignore
class NGPDensityField(torch.nn.Module):
"""Instance-NGP Density Field used for resampling"""
def __init__(
self,
aabb: Union[torch.Tensor, List[float]],
num_dim: int = 3,
density_activation: Callable = lambda x: trunc_exp(x - 1),
unbounded: bool = False,
base_resolution: int = 16,
max_resolution: int = 128,
n_levels: int = 5,
log2_hashmap_size: int = 17,
) -> None:
super().__init__()
if not isinstance(aabb, torch.Tensor):
aabb = torch.tensor(aabb, dtype=torch.float32)
self.register_buffer("aabb", aabb)
self.num_dim = num_dim
self.density_activation = density_activation
self.unbounded = unbounded
self.base_resolution = base_resolution
self.max_resolution = max_resolution
self.n_levels = n_levels
self.log2_hashmap_size = log2_hashmap_size
per_level_scale = np.exp(
(np.log(max_resolution) - np.log(base_resolution)) / (n_levels - 1)
).tolist()
self.mlp_base = tcnn.NetworkWithInputEncoding(
n_input_dims=num_dim,
n_output_dims=1,
encoding_config={
"otype": "HashGrid",
"n_levels": n_levels,
"n_features_per_level": 2,
"log2_hashmap_size": log2_hashmap_size,
"base_resolution": base_resolution,
"per_level_scale": per_level_scale,
},
network_config={
"otype": "FullyFusedMLP",
"activation": "ReLU",
"output_activation": "None",
"n_neurons": 64,
"n_hidden_layers": 1,
},
)
def forward(self, positions: torch.Tensor):
if self.unbounded:
positions = contract_to_unisphere(positions, self.aabb)
else:
aabb_min, aabb_max = torch.split(self.aabb, self.num_dim, dim=-1)
positions = (positions - aabb_min) / (aabb_max - aabb_min)
selector = ((positions > 0.0) & (positions < 1.0)).all(dim=-1)
density_before_activation = (
self.mlp_base(positions.view(-1, self.num_dim))
.view(list(positions.shape[:-1]) + [1])
.to(positions)
)
density = (
self.density_activation(density_before_activation)
* selector[..., None]
)
return density
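`per_level_scale` is the geometric growth factor between consecutive hash-grid levels, so the previously hard-coded 1.4472692012786865 is exactly what the new `base_resolution=16`, `max_resolution=4096`, `n_levels=16` defaults produce. A small sketch of the arithmetic (plain NumPy, variable names are illustrative):

    import numpy as np

    base_resolution, max_resolution, n_levels = 16, 4096, 16
    per_level_scale = np.exp(
        (np.log(max_resolution) - np.log(base_resolution)) / (n_levels - 1)
    )
    print(per_level_scale)  # ~1.4473, the old hard-coded constant

    # level resolutions grow geometrically from base_resolution to max_resolution
    resolutions = base_resolution * per_level_scale ** np.arange(n_levels)
    print(np.round(resolutions).astype(int))  # 16, 23, 34, ..., 4096

The proposal density fields above use the same formula with much coarser settings (e.g. `max_resolution=128`, `n_levels=5`).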
@@ -3,4 +3,5 @@ opencv-python
 imageio
 numpy
 tqdm
 scipy
+lpips
\ No newline at end of file
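The new `lpips` requirement is used for the LPIPS perceptual metric during evaluation; the training scripts in this commit wrap it as below (mirrored from the script code further down in this diff):

    import torch
    from lpips import LPIPS

    lpips_net = LPIPS(net="vgg").to("cuda:0")
    # rendered/GT images are (H, W, 3) in [0, 1]; LPIPS expects NCHW in [-1, 1]
    lpips_norm_fn = lambda x: x[None, ...].permute(0, 3, 1, 2) * 2 - 1
    lpips_fn = lambda x, y: lpips_net(lpips_norm_fn(x), lpips_norm_fn(y)).mean()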
@@ -16,227 +16,221 @@ from datasets.dnerf_synthetic import SubjectLoader
 from radiance_fields.mlp import DNeRFRadianceField
 from utils import render_image, set_random_seed
 
-from nerfacc import ContractionType, OccupancyGrid
+from nerfacc import OccupancyGrid
if __name__ == "__main__": device = "cuda:0"
set_random_seed(42)
device = "cuda:0"
set_random_seed(42) parser = argparse.ArgumentParser()
parser.add_argument(
parser = argparse.ArgumentParser() "--data_root",
parser.add_argument( type=str,
"--data_root", default=str(pathlib.Path.cwd() / "data/dnerf"),
type=str, help="the root dir of the dataset",
default=str(pathlib.Path.cwd() / "data/dnerf"), )
help="the root dir of the dataset", parser.add_argument(
) "--train_split",
parser.add_argument( type=str,
"--train_split", default="train",
type=str, choices=["train"],
default="train", help="which train split to use",
choices=["train"], )
help="which train split to use", parser.add_argument(
) "--scene",
parser.add_argument( type=str,
"--scene", default="lego",
type=str, choices=[
default="lego", # dnerf
choices=[ "bouncingballs",
# dnerf "hellwarrior",
"bouncingballs", "hook",
"hellwarrior", "jumpingjacks",
"hook", "lego",
"jumpingjacks", "mutant",
"lego", "standup",
"mutant", "trex",
"standup", ],
"trex", help="which scene to use",
], )
help="which scene to use", parser.add_argument(
) "--aabb",
parser.add_argument( type=lambda s: [float(item) for item in s.split(",")],
"--aabb", default="-1.5,-1.5,-1.5,1.5,1.5,1.5",
type=lambda s: [float(item) for item in s.split(",")], help="delimited list input",
default="-1.5,-1.5,-1.5,1.5,1.5,1.5", )
help="delimited list input", parser.add_argument(
) "--test_chunk_size",
parser.add_argument( type=int,
"--test_chunk_size", default=8192,
type=int, )
default=8192, parser.add_argument("--cone_angle", type=float, default=0.0)
) args = parser.parse_args()
parser.add_argument("--cone_angle", type=float, default=0.0)
args = parser.parse_args() render_n_samples = 1024
render_n_samples = 1024 # setup the scene bounding box.
scene_aabb = torch.tensor(args.aabb, dtype=torch.float32, device=device)
# setup the scene bounding box. near_plane = None
contraction_type = ContractionType.AABB far_plane = None
scene_aabb = torch.tensor(args.aabb, dtype=torch.float32, device=device) render_step_size = (
near_plane = None (scene_aabb[3:] - scene_aabb[:3]).max() * math.sqrt(3) / render_n_samples
far_plane = None ).item()
render_step_size = (
(scene_aabb[3:] - scene_aabb[:3]).max() # setup the radiance field we want to train.
* math.sqrt(3) max_steps = 30000
/ render_n_samples grad_scaler = torch.cuda.amp.GradScaler(1)
).item() radiance_field = DNeRFRadianceField().to(device)
optimizer = torch.optim.Adam(radiance_field.parameters(), lr=5e-4)
# setup the radiance field we want to train. scheduler = torch.optim.lr_scheduler.MultiStepLR(
max_steps = 30000 optimizer,
grad_scaler = torch.cuda.amp.GradScaler(1) milestones=[
radiance_field = DNeRFRadianceField().to(device) max_steps // 2,
optimizer = torch.optim.Adam(radiance_field.parameters(), lr=5e-4) max_steps * 3 // 4,
scheduler = torch.optim.lr_scheduler.MultiStepLR( max_steps * 5 // 6,
optimizer, max_steps * 9 // 10,
milestones=[ ],
max_steps // 2, gamma=0.33,
max_steps * 3 // 4, )
max_steps * 5 // 6, # setup the dataset
max_steps * 9 // 10, target_sample_batch_size = 1 << 16
], grid_resolution = 128
gamma=0.33,
) train_dataset = SubjectLoader(
# setup the dataset subject_id=args.scene,
target_sample_batch_size = 1 << 16 root_fp=args.data_root,
grid_resolution = 128 split=args.train_split,
num_rays=target_sample_batch_size // render_n_samples,
train_dataset = SubjectLoader( device=device,
subject_id=args.scene, )
root_fp=args.data_root,
split=args.train_split, test_dataset = SubjectLoader(
num_rays=target_sample_batch_size // render_n_samples, subject_id=args.scene,
) root_fp=args.data_root,
split="test",
test_dataset = SubjectLoader( num_rays=None,
subject_id=args.scene, device=device,
root_fp=args.data_root, )
split="test",
num_rays=None, occupancy_grid = OccupancyGrid(
) roi_aabb=args.aabb, resolution=grid_resolution
).to(device)
occupancy_grid = OccupancyGrid(
roi_aabb=args.aabb, # training
resolution=grid_resolution, step = 0
contraction_type=contraction_type, tic = time.time()
).to(device) for epoch in range(10000000):
for i in range(len(train_dataset)):
# training radiance_field.train()
step = 0 data = train_dataset[i]
tic = time.time()
for epoch in range(10000000): render_bkgd = data["color_bkgd"]
for i in range(len(train_dataset)): rays = data["rays"]
radiance_field.train() pixels = data["pixels"]
data = train_dataset[i] timestamps = data["timestamps"]
render_bkgd = data["color_bkgd"] # update occupancy grid
rays = data["rays"] occupancy_grid.every_n_step(
pixels = data["pixels"] step=step,
timestamps = data["timestamps"] occ_eval_fn=lambda x: radiance_field.query_opacity(
x, timestamps, render_step_size
# update occupancy grid ),
occupancy_grid.every_n_step( )
step=step,
occ_eval_fn=lambda x: radiance_field.query_opacity( # render
x, timestamps, render_step_size rgb, acc, depth, n_rendering_samples = render_image(
), radiance_field,
occupancy_grid,
rays,
scene_aabb,
# rendering options
near_plane=near_plane,
far_plane=far_plane,
render_step_size=render_step_size,
render_bkgd=render_bkgd,
cone_angle=args.cone_angle,
alpha_thre=0.01 if step > 1000 else 0.00,
# dnerf options
timestamps=timestamps,
)
if n_rendering_samples == 0:
continue
# dynamic batch size for rays to keep sample batch size constant.
num_rays = len(pixels)
num_rays = int(
num_rays * (target_sample_batch_size / float(n_rendering_samples))
)
train_dataset.update_num_rays(num_rays)
alive_ray_mask = acc.squeeze(-1) > 0
# compute loss
loss = F.smooth_l1_loss(rgb[alive_ray_mask], pixels[alive_ray_mask])
optimizer.zero_grad()
# do not unscale it because we are using Adam.
grad_scaler.scale(loss).backward()
optimizer.step()
scheduler.step()
if step % 5000 == 0:
elapsed_time = time.time() - tic
loss = F.mse_loss(rgb[alive_ray_mask], pixels[alive_ray_mask])
print(
f"elapsed_time={elapsed_time:.2f}s | step={step} | "
f"loss={loss:.5f} | "
f"alive_ray_mask={alive_ray_mask.long().sum():d} | "
f"n_rendering_samples={n_rendering_samples:d} | num_rays={len(pixels):d} |"
) )
# render if step > 0 and step % max_steps == 0:
rgb, acc, depth, n_rendering_samples = render_image( # evaluation
radiance_field, radiance_field.eval()
occupancy_grid,
rays, psnrs = []
scene_aabb, with torch.no_grad():
# rendering options for i in tqdm.tqdm(range(len(test_dataset))):
near_plane=near_plane, data = test_dataset[i]
far_plane=far_plane, render_bkgd = data["color_bkgd"]
render_step_size=render_step_size, rays = data["rays"]
render_bkgd=render_bkgd, pixels = data["pixels"]
cone_angle=args.cone_angle, timestamps = data["timestamps"]
alpha_thre=0.01 if step > 1000 else 0.00,
# dnerf options # rendering
timestamps=timestamps, rgb, acc, depth, _ = render_image(
) radiance_field,
if n_rendering_samples == 0: occupancy_grid,
continue rays,
scene_aabb,
# dynamic batch size for rays to keep sample batch size constant. # rendering options
num_rays = len(pixels) near_plane=None,
num_rays = int( far_plane=None,
num_rays render_step_size=render_step_size,
* (target_sample_batch_size / float(n_rendering_samples)) render_bkgd=render_bkgd,
) cone_angle=args.cone_angle,
train_dataset.update_num_rays(num_rays) alpha_thre=0.01,
alive_ray_mask = acc.squeeze(-1) > 0 # test options
test_chunk_size=args.test_chunk_size,
# compute loss # dnerf options
loss = F.smooth_l1_loss(rgb[alive_ray_mask], pixels[alive_ray_mask]) timestamps=timestamps,
)
optimizer.zero_grad() mse = F.mse_loss(rgb, pixels)
# do not unscale it because we are using Adam. psnr = -10.0 * torch.log(mse) / np.log(10.0)
grad_scaler.scale(loss).backward() psnrs.append(psnr.item())
optimizer.step() # imageio.imwrite(
scheduler.step() # "acc_binary_test.png",
# ((acc > 0).float().cpu().numpy() * 255).astype(np.uint8),
if step % 5000 == 0: # )
elapsed_time = time.time() - tic # imageio.imwrite(
loss = F.mse_loss(rgb[alive_ray_mask], pixels[alive_ray_mask]) # "rgb_test.png",
print( # (rgb.cpu().numpy() * 255).astype(np.uint8),
f"elapsed_time={elapsed_time:.2f}s | step={step} | " # )
f"loss={loss:.5f} | " # break
f"alive_ray_mask={alive_ray_mask.long().sum():d} | " psnr_avg = sum(psnrs) / len(psnrs)
f"n_rendering_samples={n_rendering_samples:d} | num_rays={len(pixels):d} |" print(f"evaluation: psnr_avg={psnr_avg}")
) train_dataset.training = True
if step > 0 and step % max_steps == 0: if step == max_steps:
# evaluation print("training stops")
radiance_field.eval() exit()
psnrs = [] step += 1
with torch.no_grad():
for i in tqdm.tqdm(range(len(test_dataset))):
data = test_dataset[i]
render_bkgd = data["color_bkgd"]
rays = data["rays"]
pixels = data["pixels"]
timestamps = data["timestamps"]
# rendering
rgb, acc, depth, _ = render_image(
radiance_field,
occupancy_grid,
rays,
scene_aabb,
# rendering options
near_plane=None,
far_plane=None,
render_step_size=render_step_size,
render_bkgd=render_bkgd,
cone_angle=args.cone_angle,
alpha_thre=0.01,
# test options
test_chunk_size=args.test_chunk_size,
# dnerf options
timestamps=timestamps,
)
mse = F.mse_loss(rgb, pixels)
psnr = -10.0 * torch.log(mse) / np.log(10.0)
psnrs.append(psnr.item())
# imageio.imwrite(
# "acc_binary_test.png",
# ((acc > 0).float().cpu().numpy() * 255).astype(np.uint8),
# )
# imageio.imwrite(
# "rgb_test.png",
# (rgb.cpu().numpy() * 255).astype(np.uint8),
# )
# break
psnr_avg = sum(psnrs) / len(psnrs)
print(f"evaluation: psnr_avg={psnr_avg}")
train_dataset.training = True
if step == max_steps:
print("training stops")
exit()
step += 1
...@@ -17,248 +17,245 @@ from utils import render_image, set_random_seed ...@@ -17,248 +17,245 @@ from utils import render_image, set_random_seed
from nerfacc import ContractionType, OccupancyGrid from nerfacc import ContractionType, OccupancyGrid
if __name__ == "__main__": device = "cuda:0"
set_random_seed(42)
device = "cuda:0" parser = argparse.ArgumentParser()
set_random_seed(42) parser.add_argument(
"--data_root",
type=str,
default=str(pathlib.Path.cwd() / "data/nerf_synthetic"),
help="the root dir of the dataset",
)
parser.add_argument(
"--train_split",
type=str,
default="trainval",
choices=["train", "trainval"],
help="which train split to use",
)
parser.add_argument(
"--scene",
type=str,
default="lego",
choices=[
# nerf synthetic
"chair",
"drums",
"ficus",
"hotdog",
"lego",
"materials",
"mic",
"ship",
# mipnerf360 unbounded
"garden",
],
help="which scene to use",
)
parser.add_argument(
"--aabb",
type=lambda s: [float(item) for item in s.split(",")],
default="-1.5,-1.5,-1.5,1.5,1.5,1.5",
help="delimited list input",
)
parser.add_argument(
"--test_chunk_size",
type=int,
default=8192,
)
parser.add_argument(
"--unbounded",
action="store_true",
help="whether to use unbounded rendering",
)
parser.add_argument("--cone_angle", type=float, default=0.0)
args = parser.parse_args()
parser = argparse.ArgumentParser() render_n_samples = 1024
parser.add_argument(
"--data_root",
type=str,
default=str(pathlib.Path.cwd() / "data/nerf_synthetic"),
help="the root dir of the dataset",
)
parser.add_argument(
"--train_split",
type=str,
default="trainval",
choices=["train", "trainval"],
help="which train split to use",
)
parser.add_argument(
"--scene",
type=str,
default="lego",
choices=[
# nerf synthetic
"chair",
"drums",
"ficus",
"hotdog",
"lego",
"materials",
"mic",
"ship",
# mipnerf360 unbounded
"garden",
],
help="which scene to use",
)
parser.add_argument(
"--aabb",
type=lambda s: [float(item) for item in s.split(",")],
default="-1.5,-1.5,-1.5,1.5,1.5,1.5",
help="delimited list input",
)
parser.add_argument(
"--test_chunk_size",
type=int,
default=8192,
)
parser.add_argument(
"--unbounded",
action="store_true",
help="whether to use unbounded rendering",
)
parser.add_argument("--cone_angle", type=float, default=0.0)
args = parser.parse_args()
render_n_samples = 1024 # setup the scene bounding box.
if args.unbounded:
print("Using unbounded rendering")
contraction_type = ContractionType.UN_BOUNDED_SPHERE
# contraction_type = ContractionType.UN_BOUNDED_TANH
scene_aabb = None
near_plane = 0.2
far_plane = 1e4
render_step_size = 1e-2
else:
contraction_type = ContractionType.AABB
scene_aabb = torch.tensor(args.aabb, dtype=torch.float32, device=device)
near_plane = None
far_plane = None
render_step_size = (
(scene_aabb[3:] - scene_aabb[:3]).max()
* math.sqrt(3)
/ render_n_samples
).item()
# setup the scene bounding box. # setup the radiance field we want to train.
if args.unbounded: max_steps = 50000
print("Using unbounded rendering") grad_scaler = torch.cuda.amp.GradScaler(1)
contraction_type = ContractionType.UN_BOUNDED_SPHERE radiance_field = VanillaNeRFRadianceField().to(device)
# contraction_type = ContractionType.UN_BOUNDED_TANH optimizer = torch.optim.Adam(radiance_field.parameters(), lr=5e-4)
scene_aabb = None scheduler = torch.optim.lr_scheduler.MultiStepLR(
near_plane = 0.2 optimizer,
far_plane = 1e4 milestones=[
render_step_size = 1e-2 max_steps // 2,
else: max_steps * 3 // 4,
contraction_type = ContractionType.AABB max_steps * 5 // 6,
scene_aabb = torch.tensor(args.aabb, dtype=torch.float32, device=device) max_steps * 9 // 10,
near_plane = None ],
far_plane = None gamma=0.33,
render_step_size = ( )
(scene_aabb[3:] - scene_aabb[:3]).max()
* math.sqrt(3)
/ render_n_samples
).item()
# setup the radiance field we want to train. # setup the dataset
max_steps = 50000 train_dataset_kwargs = {}
grad_scaler = torch.cuda.amp.GradScaler(1) test_dataset_kwargs = {}
radiance_field = VanillaNeRFRadianceField().to(device) if args.scene == "garden":
optimizer = torch.optim.Adam(radiance_field.parameters(), lr=5e-4) from datasets.nerf_360_v2 import SubjectLoader
scheduler = torch.optim.lr_scheduler.MultiStepLR(
optimizer,
milestones=[
max_steps // 2,
max_steps * 3 // 4,
max_steps * 5 // 6,
max_steps * 9 // 10,
],
gamma=0.33,
)
# setup the dataset target_sample_batch_size = 1 << 16
train_dataset_kwargs = {} train_dataset_kwargs = {"color_bkgd_aug": "random", "factor": 4}
test_dataset_kwargs = {} test_dataset_kwargs = {"factor": 4}
if args.scene == "garden": grid_resolution = 128
from datasets.nerf_360_v2 import SubjectLoader else:
from datasets.nerf_synthetic import SubjectLoader
target_sample_batch_size = 1 << 16 target_sample_batch_size = 1 << 16
train_dataset_kwargs = {"color_bkgd_aug": "random", "factor": 4} grid_resolution = 128
test_dataset_kwargs = {"factor": 4}
grid_resolution = 128
else:
from datasets.nerf_synthetic import SubjectLoader
target_sample_batch_size = 1 << 16 train_dataset = SubjectLoader(
grid_resolution = 128 subject_id=args.scene,
root_fp=args.data_root,
split=args.train_split,
num_rays=target_sample_batch_size // render_n_samples,
**train_dataset_kwargs,
)
train_dataset = SubjectLoader( test_dataset = SubjectLoader(
subject_id=args.scene, subject_id=args.scene,
root_fp=args.data_root, root_fp=args.data_root,
split=args.train_split, split="test",
num_rays=target_sample_batch_size // render_n_samples, num_rays=None,
**train_dataset_kwargs, **test_dataset_kwargs,
) )
test_dataset = SubjectLoader( occupancy_grid = OccupancyGrid(
subject_id=args.scene, roi_aabb=args.aabb,
root_fp=args.data_root, resolution=grid_resolution,
split="test", contraction_type=contraction_type,
num_rays=None, ).to(device)
**test_dataset_kwargs,
)
occupancy_grid = OccupancyGrid( # training
roi_aabb=args.aabb, step = 0
resolution=grid_resolution, tic = time.time()
contraction_type=contraction_type, for epoch in range(10000000):
).to(device) for i in range(len(train_dataset)):
radiance_field.train()
data = train_dataset[i]
# training render_bkgd = data["color_bkgd"]
step = 0 rays = data["rays"]
tic = time.time() pixels = data["pixels"]
for epoch in range(10000000):
for i in range(len(train_dataset)):
radiance_field.train()
data = train_dataset[i]
render_bkgd = data["color_bkgd"] # update occupancy grid
rays = data["rays"] occupancy_grid.every_n_step(
pixels = data["pixels"] step=step,
occ_eval_fn=lambda x: radiance_field.query_opacity(
x, render_step_size
),
)
# update occupancy grid # render
occupancy_grid.every_n_step( rgb, acc, depth, n_rendering_samples = render_image(
step=step, radiance_field,
occ_eval_fn=lambda x: radiance_field.query_opacity( occupancy_grid,
x, render_step_size rays,
), scene_aabb,
) # rendering options
near_plane=near_plane,
far_plane=far_plane,
render_step_size=render_step_size,
render_bkgd=render_bkgd,
cone_angle=args.cone_angle,
)
if n_rendering_samples == 0:
continue
# render # dynamic batch size for rays to keep sample batch size constant.
rgb, acc, depth, n_rendering_samples = render_image( num_rays = len(pixels)
radiance_field, num_rays = int(
occupancy_grid, num_rays * (target_sample_batch_size / float(n_rendering_samples))
rays, )
scene_aabb, train_dataset.update_num_rays(num_rays)
# rendering options alive_ray_mask = acc.squeeze(-1) > 0
near_plane=near_plane,
far_plane=far_plane,
render_step_size=render_step_size,
render_bkgd=render_bkgd,
cone_angle=args.cone_angle,
)
if n_rendering_samples == 0:
continue
# dynamic batch size for rays to keep sample batch size constant. # compute loss
num_rays = len(pixels) loss = F.smooth_l1_loss(rgb[alive_ray_mask], pixels[alive_ray_mask])
num_rays = int(
num_rays
* (target_sample_batch_size / float(n_rendering_samples))
)
train_dataset.update_num_rays(num_rays)
alive_ray_mask = acc.squeeze(-1) > 0
# compute loss optimizer.zero_grad()
loss = F.smooth_l1_loss(rgb[alive_ray_mask], pixels[alive_ray_mask]) # do not unscale it because we are using Adam.
grad_scaler.scale(loss).backward()
optimizer.step()
scheduler.step()
optimizer.zero_grad() if step % 5000 == 0:
# do not unscale it because we are using Adam. elapsed_time = time.time() - tic
grad_scaler.scale(loss).backward() loss = F.mse_loss(rgb[alive_ray_mask], pixels[alive_ray_mask])
optimizer.step() print(
scheduler.step() f"elapsed_time={elapsed_time:.2f}s | step={step} | "
f"loss={loss:.5f} | "
if step % 5000 == 0: f"alive_ray_mask={alive_ray_mask.long().sum():d} | "
elapsed_time = time.time() - tic f"n_rendering_samples={n_rendering_samples:d} | num_rays={len(pixels):d} |"
loss = F.mse_loss(rgb[alive_ray_mask], pixels[alive_ray_mask]) )
print(
f"elapsed_time={elapsed_time:.2f}s | step={step} | "
f"loss={loss:.5f} | "
f"alive_ray_mask={alive_ray_mask.long().sum():d} | "
f"n_rendering_samples={n_rendering_samples:d} | num_rays={len(pixels):d} |"
)
if step > 0 and step % max_steps == 0: if step > 0 and step % max_steps == 0:
# evaluation # evaluation
radiance_field.eval() radiance_field.eval()
psnrs = [] psnrs = []
with torch.no_grad(): with torch.no_grad():
for i in tqdm.tqdm(range(len(test_dataset))): for i in tqdm.tqdm(range(len(test_dataset))):
data = test_dataset[i] data = test_dataset[i]
render_bkgd = data["color_bkgd"] render_bkgd = data["color_bkgd"]
rays = data["rays"] rays = data["rays"]
pixels = data["pixels"] pixels = data["pixels"]
# rendering # rendering
rgb, acc, depth, _ = render_image( rgb, acc, depth, _ = render_image(
radiance_field, radiance_field,
occupancy_grid, occupancy_grid,
rays, rays,
scene_aabb, scene_aabb,
# rendering options # rendering options
near_plane=None, near_plane=None,
far_plane=None, far_plane=None,
render_step_size=render_step_size, render_step_size=render_step_size,
render_bkgd=render_bkgd, render_bkgd=render_bkgd,
cone_angle=args.cone_angle, cone_angle=args.cone_angle,
# test options # test options
test_chunk_size=args.test_chunk_size, test_chunk_size=args.test_chunk_size,
) )
mse = F.mse_loss(rgb, pixels) mse = F.mse_loss(rgb, pixels)
psnr = -10.0 * torch.log(mse) / np.log(10.0) psnr = -10.0 * torch.log(mse) / np.log(10.0)
psnrs.append(psnr.item()) psnrs.append(psnr.item())
# imageio.imwrite( # imageio.imwrite(
# "acc_binary_test.png", # "acc_binary_test.png",
# ((acc > 0).float().cpu().numpy() * 255).astype(np.uint8), # ((acc > 0).float().cpu().numpy() * 255).astype(np.uint8),
# ) # )
# imageio.imwrite( # imageio.imwrite(
# "rgb_test.png", # "rgb_test.png",
# (rgb.cpu().numpy() * 255).astype(np.uint8), # (rgb.cpu().numpy() * 255).astype(np.uint8),
# ) # )
# break # break
psnr_avg = sum(psnrs) / len(psnrs) psnr_avg = sum(psnrs) / len(psnrs)
print(f"evaluation: psnr_avg={psnr_avg}") print(f"evaluation: psnr_avg={psnr_avg}")
train_dataset.training = True train_dataset.training = True
if step == max_steps: if step == max_steps:
print("training stops") print("training stops")
exit() exit()
step += 1 step += 1
@@ -12,297 +12,259 @@ import numpy as np
 import torch
 import torch.nn.functional as F
 import tqdm
-from radiance_fields.ngp import NGPradianceField
-from utils import render_image, set_random_seed
+from lpips import LPIPS
+from radiance_fields.ngp import NGPRadianceField
+from utils import (
+    MIPNERF360_UNBOUNDED_SCENES,
+    NERF_SYNTHETIC_SCENES,
+    enlarge_aabb,
+    render_image,
+    set_random_seed,
+)
 
-from nerfacc import ContractionType, OccupancyGrid
+from nerfacc import OccupancyGrid
if __name__ == "__main__": parser = argparse.ArgumentParser()
parser.add_argument(
"--data_root",
type=str,
# default=str(pathlib.Path.cwd() / "data/360_v2"),
default=str(pathlib.Path.cwd() / "data/nerf_synthetic"),
help="the root dir of the dataset",
)
parser.add_argument(
"--train_split",
type=str,
default="train",
choices=["train", "trainval"],
help="which train split to use",
)
parser.add_argument(
"--scene",
type=str,
default="lego",
choices=NERF_SYNTHETIC_SCENES + MIPNERF360_UNBOUNDED_SCENES,
help="which scene to use",
)
parser.add_argument(
"--test_chunk_size",
type=int,
default=8192,
)
args = parser.parse_args()
device = "cuda:0" device = "cuda:0"
set_random_seed(42) set_random_seed(42)
parser = argparse.ArgumentParser() if args.scene in MIPNERF360_UNBOUNDED_SCENES:
parser.add_argument( from datasets.nerf_360_v2 import SubjectLoader
"--data_root",
type=str,
default=str(pathlib.Path.cwd() / "data"),
help="the root dir of the dataset",
)
parser.add_argument(
"--train_split",
type=str,
default="trainval",
choices=["train", "trainval"],
help="which train split to use",
)
parser.add_argument(
"--scene",
type=str,
default="lego",
choices=[
# nerf synthetic
"chair",
"drums",
"ficus",
"hotdog",
"lego",
"materials",
"mic",
"ship",
# mipnerf360 unbounded
"garden",
"bicycle",
"bonsai",
"counter",
"kitchen",
"room",
"stump",
],
help="which scene to use",
)
parser.add_argument(
"--aabb",
type=lambda s: [float(item) for item in s.split(",")],
default="-1.5,-1.5,-1.5,1.5,1.5,1.5",
help="delimited list input",
)
parser.add_argument(
"--test_chunk_size",
type=int,
default=8192,
)
parser.add_argument(
"--unbounded",
action="store_true",
help="whether to use unbounded rendering",
)
parser.add_argument(
"--auto_aabb",
action="store_true",
help="whether to automatically compute the aabb",
)
parser.add_argument("--cone_angle", type=float, default=0.0)
args = parser.parse_args()
render_n_samples = 1024
# setup the dataset
train_dataset_kwargs = {}
test_dataset_kwargs = {}
if args.unbounded:
from datasets.nerf_360_v2 import SubjectLoader
target_sample_batch_size = 1 << 20 # training parameters
train_dataset_kwargs = {"color_bkgd_aug": "random", "factor": 4} max_steps = 100000
test_dataset_kwargs = {"factor": 4} init_batch_size = 1024
grid_resolution = 256 target_sample_batch_size = 1 << 18
else: weight_decay = 0.0
from datasets.nerf_synthetic import SubjectLoader # scene parameters
aabb = torch.tensor([-1.0, -1.0, -1.0, 1.0, 1.0, 1.0], device=device)
near_plane = 0.02
far_plane = None
# dataset parameters
train_dataset_kwargs = {"color_bkgd_aug": "random", "factor": 4}
test_dataset_kwargs = {"factor": 4}
# model parameters
grid_resolution = 128
grid_nlvl = 4
# render parameters
render_step_size = 1e-3
alpha_thre = 1e-2
cone_angle = 0.004
target_sample_batch_size = 1 << 18 else:
grid_resolution = 128 from datasets.nerf_synthetic import SubjectLoader
train_dataset = SubjectLoader( # training parameters
subject_id=args.scene, max_steps = 20000
root_fp=args.data_root, init_batch_size = 1024
split=args.train_split, target_sample_batch_size = 1 << 18
num_rays=target_sample_batch_size // render_n_samples, weight_decay = (
**train_dataset_kwargs, 1e-5 if args.scene in ["materials", "ficus", "drums"] else 1e-6
)
test_dataset = SubjectLoader(
subject_id=args.scene,
root_fp=args.data_root,
split="test",
num_rays=None,
**test_dataset_kwargs,
) )
# scene parameters
aabb = torch.tensor([-1.5, -1.5, -1.5, 1.5, 1.5, 1.5], device=device)
near_plane = None
far_plane = None
# dataset parameters
train_dataset_kwargs = {}
test_dataset_kwargs = {}
# model parameters
grid_resolution = 128
grid_nlvl = 1
# render parameters
render_step_size = 5e-3
alpha_thre = 0.0
cone_angle = 0.0
if args.auto_aabb: train_dataset = SubjectLoader(
camera_locs = torch.cat( subject_id=args.scene,
[train_dataset.camtoworlds, test_dataset.camtoworlds] root_fp=args.data_root,
)[:, :3, -1] split=args.train_split,
args.aabb = torch.cat( num_rays=init_batch_size,
[camera_locs.min(dim=0).values, camera_locs.max(dim=0).values] device=device,
).tolist() **train_dataset_kwargs,
print("Using auto aabb", args.aabb) )
# setup the scene bounding box. test_dataset = SubjectLoader(
if args.unbounded: subject_id=args.scene,
print("Using unbounded rendering") root_fp=args.data_root,
contraction_type = ContractionType.UN_BOUNDED_SPHERE split="test",
# contraction_type = ContractionType.UN_BOUNDED_TANH num_rays=None,
scene_aabb = None device=device,
near_plane = 0.2 **test_dataset_kwargs,
far_plane = 1e4 )
render_step_size = 1e-2
alpha_thre = 1e-2
else:
contraction_type = ContractionType.AABB
scene_aabb = torch.tensor(args.aabb, dtype=torch.float32, device=device)
near_plane = None
far_plane = None
render_step_size = (
(scene_aabb[3:] - scene_aabb[:3]).max()
* math.sqrt(3)
/ render_n_samples
).item()
alpha_thre = 0.0
# setup the radiance field we want to train. # setup scene aabb
max_steps = 20000 scene_aabb = enlarge_aabb(aabb, 1 << (grid_nlvl - 1))
grad_scaler = torch.cuda.amp.GradScaler(2**10)
radiance_field = NGPradianceField(
aabb=args.aabb,
unbounded=args.unbounded,
).to(device)
optimizer = torch.optim.Adam(
radiance_field.parameters(), lr=1e-2, eps=1e-15
)
scheduler = torch.optim.lr_scheduler.MultiStepLR(
optimizer,
milestones=[max_steps // 2, max_steps * 3 // 4, max_steps * 9 // 10],
gamma=0.33,
)
occupancy_grid = OccupancyGrid( # setup the radiance field we want to train.
roi_aabb=args.aabb, grad_scaler = torch.cuda.amp.GradScaler(2**10)
resolution=grid_resolution, radiance_field = NGPRadianceField(aabb=scene_aabb).to(device)
contraction_type=contraction_type, optimizer = torch.optim.Adam(
).to(device) radiance_field.parameters(), lr=1e-2, eps=1e-15, weight_decay=weight_decay
)
scheduler = torch.optim.lr_scheduler.ChainedScheduler(
[
torch.optim.lr_scheduler.LinearLR(
optimizer, start_factor=0.01, total_iters=100
),
torch.optim.lr_scheduler.MultiStepLR(
optimizer,
milestones=[
max_steps // 2,
max_steps * 3 // 4,
max_steps * 9 // 10,
],
gamma=0.33,
),
]
)
occupancy_grid = OccupancyGrid(
roi_aabb=aabb, resolution=grid_resolution, levels=grid_nlvl
).to(device)
# training lpips_net = LPIPS(net="vgg").to(device)
step = 0 lpips_norm_fn = lambda x: x[None, ...].permute(0, 3, 1, 2) * 2 - 1
tic = time.time() lpips_fn = lambda x, y: lpips_net(lpips_norm_fn(x), lpips_norm_fn(y)).mean()
for epoch in range(10000000):
for i in range(len(train_dataset)):
radiance_field.train()
data = train_dataset[i]
render_bkgd = data["color_bkgd"] # training
rays = data["rays"] tic = time.time()
pixels = data["pixels"] for step in range(max_steps + 1):
radiance_field.train()
def occ_eval_fn(x): i = torch.randint(0, len(train_dataset), (1,)).item()
if args.cone_angle > 0.0: data = train_dataset[i]
# randomly sample a camera for computing step size.
camera_ids = torch.randint(
0, len(train_dataset), (x.shape[0],), device=device
)
origins = train_dataset.camtoworlds[camera_ids, :3, -1]
t = (origins - x).norm(dim=-1, keepdim=True)
# compute actual step size used in marching, based on the distance to the camera.
step_size = torch.clamp(
t * args.cone_angle, min=render_step_size
)
# filter out the points that are not in the near far plane.
if (near_plane is not None) and (far_plane is not None):
step_size = torch.where(
(t > near_plane) & (t < far_plane),
step_size,
torch.zeros_like(step_size),
)
else:
step_size = render_step_size
# compute occupancy
density = radiance_field.query_density(x)
return density * step_size
# update occupancy grid render_bkgd = data["color_bkgd"]
occupancy_grid.every_n_step(step=step, occ_eval_fn=occ_eval_fn) rays = data["rays"]
pixels = data["pixels"]
# render def occ_eval_fn(x):
rgb, acc, depth, n_rendering_samples = render_image( density = radiance_field.query_density(x)
radiance_field, return density * render_step_size
occupancy_grid,
rays,
scene_aabb,
# rendering options
near_plane=near_plane,
far_plane=far_plane,
render_step_size=render_step_size,
render_bkgd=render_bkgd,
cone_angle=args.cone_angle,
alpha_thre=alpha_thre,
)
if n_rendering_samples == 0:
continue
# dynamic batch size for rays to keep sample batch size constant. # update occupancy grid
num_rays = len(pixels) occupancy_grid.every_n_step(
num_rays = int( step=step,
num_rays occ_eval_fn=occ_eval_fn,
* (target_sample_batch_size / float(n_rendering_samples)) occ_thre=1e-2,
) )
train_dataset.update_num_rays(num_rays)
alive_ray_mask = acc.squeeze(-1) > 0
# compute loss # render
loss = F.smooth_l1_loss(rgb[alive_ray_mask], pixels[alive_ray_mask]) rgb, acc, depth, n_rendering_samples = render_image(
radiance_field,
occupancy_grid,
rays,
scene_aabb=scene_aabb,
# rendering options
near_plane=near_plane,
render_step_size=render_step_size,
render_bkgd=render_bkgd,
cone_angle=cone_angle,
alpha_thre=alpha_thre,
)
if n_rendering_samples == 0:
continue
optimizer.zero_grad() if target_sample_batch_size > 0:
# do not unscale it because we are using Adam. # dynamic batch size for rays to keep sample batch size constant.
grad_scaler.scale(loss).backward() num_rays = len(pixels)
optimizer.step() num_rays = int(
scheduler.step() num_rays * (target_sample_batch_size / float(n_rendering_samples))
)
train_dataset.update_num_rays(num_rays)
if step % 10000 == 0: # compute loss
elapsed_time = time.time() - tic loss = F.smooth_l1_loss(rgb, pixels)
loss = F.mse_loss(rgb[alive_ray_mask], pixels[alive_ray_mask])
print(
f"elapsed_time={elapsed_time:.2f}s | step={step} | "
f"loss={loss:.5f} | "
f"alive_ray_mask={alive_ray_mask.long().sum():d} | "
f"n_rendering_samples={n_rendering_samples:d} | num_rays={len(pixels):d} |"
)
if step > 0 and step % max_steps == 0: optimizer.zero_grad()
# evaluation # do not unscale it because we are using Adam.
radiance_field.eval() grad_scaler.scale(loss).backward()
optimizer.step()
scheduler.step()
psnrs = [] if step % 5000 == 0:
with torch.no_grad(): elapsed_time = time.time() - tic
for i in tqdm.tqdm(range(len(test_dataset))): loss = F.mse_loss(rgb, pixels)
data = test_dataset[i] psnr = -10.0 * torch.log(loss) / np.log(10.0)
render_bkgd = data["color_bkgd"] print(
rays = data["rays"] f"elapsed_time={elapsed_time:.2f}s | step={step} | "
pixels = data["pixels"] f"loss={loss:.5f} | psnr={psnr:.2f} | "
f"n_rendering_samples={n_rendering_samples:d} | num_rays={len(pixels):d} | "
f"max_depth={depth.max():.3f} | "
)
# rendering if step > 0 and step % max_steps == 0:
rgb, acc, depth, _ = render_image( # evaluation
radiance_field, radiance_field.eval()
occupancy_grid,
rays,
scene_aabb,
# rendering options
near_plane=near_plane,
far_plane=far_plane,
render_step_size=render_step_size,
render_bkgd=render_bkgd,
cone_angle=args.cone_angle,
alpha_thre=alpha_thre,
# test options
test_chunk_size=args.test_chunk_size,
)
mse = F.mse_loss(rgb, pixels)
psnr = -10.0 * torch.log(mse) / np.log(10.0)
psnrs.append(psnr.item())
# imageio.imwrite(
# "acc_binary_test.png",
# ((acc > 0).float().cpu().numpy() * 255).astype(np.uint8),
# )
# imageio.imwrite(
# "rgb_test.png",
# (rgb.cpu().numpy() * 255).astype(np.uint8),
# )
# break
psnr_avg = sum(psnrs) / len(psnrs)
print(f"evaluation: psnr_avg={psnr_avg}")
train_dataset.training = True
if step == max_steps: psnrs = []
print("training stops") lpips = []
exit() with torch.no_grad():
for i in tqdm.tqdm(range(len(test_dataset))):
data = test_dataset[i]
render_bkgd = data["color_bkgd"]
rays = data["rays"]
pixels = data["pixels"]
step += 1 # rendering
rgb, acc, depth, _ = render_image(
radiance_field,
occupancy_grid,
rays,
scene_aabb=scene_aabb,
# rendering options
near_plane=near_plane,
render_step_size=render_step_size,
render_bkgd=render_bkgd,
cone_angle=cone_angle,
alpha_thre=alpha_thre,
# test options
test_chunk_size=args.test_chunk_size,
)
mse = F.mse_loss(rgb, pixels)
psnr = -10.0 * torch.log(mse) / np.log(10.0)
psnrs.append(psnr.item())
lpips.append(lpips_fn(rgb, pixels).item())
# if i == 0:
# imageio.imwrite(
# "rgb_test.png",
# (rgb.cpu().numpy() * 255).astype(np.uint8),
# )
# imageio.imwrite(
# "rgb_error.png",
# (
# (rgb - pixels).norm(dim=-1).cpu().numpy() * 255
# ).astype(np.uint8),
# )
psnr_avg = sum(psnrs) / len(psnrs)
lpips_avg = sum(lpips) / len(lpips)
print(f"evaluation: psnr_avg={psnr_avg}, lpips_avg={lpips_avg}")
"""
Copyright (c) 2022 Ruilong Li, UC Berkeley.
"""
import argparse
import itertools
import pathlib
import time
import imageio
import numpy as np
import torch
import torch.nn.functional as F
import tqdm
from lpips import LPIPS
from radiance_fields.ngp import NGPDensityField, NGPRadianceField
from utils import (
MIPNERF360_UNBOUNDED_SCENES,
NERF_SYNTHETIC_SCENES,
render_image_proposal,
set_random_seed,
)
from nerfacc.proposal import (
compute_prop_loss,
get_proposal_annealing_fn,
get_proposal_requires_grad_fn,
)
parser = argparse.ArgumentParser()
parser.add_argument(
"--data_root",
type=str,
# default=str(pathlib.Path.cwd() / "data/360_v2"),
default=str(pathlib.Path.cwd() / "data/nerf_synthetic"),
help="the root dir of the dataset",
)
parser.add_argument(
"--train_split",
type=str,
default="train",
choices=["train", "trainval"],
help="which train split to use",
)
parser.add_argument(
"--scene",
type=str,
default="lego",
choices=NERF_SYNTHETIC_SCENES + MIPNERF360_UNBOUNDED_SCENES,
help="which scene to use",
)
parser.add_argument(
"--test_chunk_size",
type=int,
default=8192,
)
args = parser.parse_args()
device = "cuda:0"
set_random_seed(42)
if args.scene in MIPNERF360_UNBOUNDED_SCENES:
from datasets.nerf_360_v2 import SubjectLoader
# training parameters
max_steps = 100000
init_batch_size = 4096
weight_decay = 0.0
# scene parameters
unbounded = True
aabb = torch.tensor([-1.0, -1.0, -1.0, 1.0, 1.0, 1.0], device=device)
near_plane = 0.2 # TODO: Try 0.02
far_plane = 1e3
# dataset parameters
train_dataset_kwargs = {"color_bkgd_aug": "random", "factor": 4}
test_dataset_kwargs = {"factor": 4}
# model parameters
proposal_networks = [
NGPDensityField(
aabb=aabb,
unbounded=unbounded,
n_levels=5,
max_resolution=128,
).to(device),
NGPDensityField(
aabb=aabb,
unbounded=unbounded,
n_levels=5,
max_resolution=256,
).to(device),
]
# render parameters
num_samples = 48
num_samples_per_prop = [256, 96]
sampling_type = "lindisp"
opaque_bkgd = True
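# Two proposal levels are queried coarse-to-fine (256 then 96 samples per ray)
# before the final 48-sample radiance-field pass; "lindisp" places samples
# linearly in disparity (1/distance), which suits unbounded 360 scenes.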
else:
from datasets.nerf_synthetic import SubjectLoader
# training parameters
max_steps = 20000
init_batch_size = 4096
weight_decay = (
1e-5 if args.scene in ["materials", "ficus", "drums"] else 1e-6
)
# scene parameters
unbounded = False
aabb = torch.tensor([-1.5, -1.5, -1.5, 1.5, 1.5, 1.5], device=device)
near_plane = 2.0
far_plane = 6.0
# dataset parameters
train_dataset_kwargs = {}
test_dataset_kwargs = {}
# model parameters
proposal_networks = [
NGPDensityField(
aabb=aabb,
unbounded=unbounded,
n_levels=5,
max_resolution=128,
).to(device),
]
# render parameters
num_samples = 64
num_samples_per_prop = [128]
sampling_type = "uniform"
opaque_bkgd = False
train_dataset = SubjectLoader(
subject_id=args.scene,
root_fp=args.data_root,
split=args.train_split,
num_rays=init_batch_size,
device=device,
**train_dataset_kwargs,
)
test_dataset = SubjectLoader(
subject_id=args.scene,
root_fp=args.data_root,
split="test",
num_rays=None,
device=device,
**test_dataset_kwargs,
)
# setup the radiance field we want to train.
grad_scaler = torch.cuda.amp.GradScaler(2**10)
radiance_field = NGPRadianceField(aabb=aabb, unbounded=unbounded).to(device)
optimizer = torch.optim.Adam(
itertools.chain(
radiance_field.parameters(),
*[p.parameters() for p in proposal_networks],
),
lr=1e-2,
eps=1e-15,
weight_decay=weight_decay,
)
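# Learning-rate schedule below: linear warmup from 1% of the base LR over the
# first 100 steps, then step decay by 0.33x at 50%, 75%, and 90% of training.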
scheduler = torch.optim.lr_scheduler.ChainedScheduler(
[
torch.optim.lr_scheduler.LinearLR(
optimizer, start_factor=0.01, total_iters=100
),
torch.optim.lr_scheduler.MultiStepLR(
optimizer,
milestones=[
max_steps // 2,
max_steps * 3 // 4,
max_steps * 9 // 10,
],
gamma=0.33,
),
]
)
proposal_requires_grad_fn = get_proposal_requires_grad_fn()
proposal_annealing_fn = get_proposal_annealing_fn()
lpips_net = LPIPS(net="vgg").to(device)
lpips_norm_fn = lambda x: x[None, ...].permute(0, 3, 1, 2) * 2 - 1
lpips_fn = lambda x, y: lpips_net(lpips_norm_fn(x), lpips_norm_fn(y)).mean()
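# lpips_norm_fn converts an (H, W, 3) image in [0, 1] to the (1, 3, H, W)
# layout in [-1, 1] that the LPIPS network expects; lpips_fn then reduces the
# resulting distance to a single scalar.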
# training
tic = time.time()
for step in range(max_steps + 1):
radiance_field.train()
for p in proposal_networks:
p.train()
i = torch.randint(0, len(train_dataset), (1,)).item()
data = train_dataset[i]
render_bkgd = data["color_bkgd"]
rays = data["rays"]
pixels = data["pixels"]
# render
(
rgb,
acc,
depth,
weights_per_level,
s_vals_per_level,
) = render_image_proposal(
radiance_field,
proposal_networks,
rays,
scene_aabb=None,
# rendering options
num_samples=num_samples,
num_samples_per_prop=num_samples_per_prop,
near_plane=near_plane,
far_plane=far_plane,
sampling_type=sampling_type,
opaque_bkgd=opaque_bkgd,
render_bkgd=render_bkgd,
# train options
proposal_requires_grad=proposal_requires_grad_fn(step),
proposal_annealing=proposal_annealing_fn(step),
)
# compute loss
loss = F.smooth_l1_loss(rgb, pixels)
loss_prop = compute_prop_loss(s_vals_per_level, weights_per_level)
loss = loss + loss_prop
optimizer.zero_grad()
# do not unscale it because we are using Adam.
grad_scaler.scale(loss).backward()
optimizer.step()
scheduler.step()
if step % 5000 == 0:
elapsed_time = time.time() - tic
loss = F.mse_loss(rgb, pixels)
psnr = -10.0 * torch.log(loss) / np.log(10.0)
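# PSNR for pixels in [0, 1]: psnr = 10 * log10(1 / mse) = -10 * ln(mse) / ln(10)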
print(
f"elapsed_time={elapsed_time:.2f}s | step={step} | "
f"loss={loss:.5f} | psnr={psnr:.2f} | "
f"num_rays={len(pixels):d} | "
f"max_depth={depth.max():.3f} | "
)
if step > 0 and step % max_steps == 0:
# evaluation
radiance_field.eval()
for p in proposal_networks:
p.eval()
psnrs = []
lpips = []
with torch.no_grad():
for i in tqdm.tqdm(range(len(test_dataset))):
data = test_dataset[i]
render_bkgd = data["color_bkgd"]
rays = data["rays"]
pixels = data["pixels"]
# rendering
rgb, acc, depth, _, _ = render_image_proposal(
radiance_field,
proposal_networks,
rays,
scene_aabb=None,
# rendering options
num_samples=num_samples,
num_samples_per_prop=num_samples_per_prop,
near_plane=near_plane,
far_plane=far_plane,
sampling_type=sampling_type,
opaque_bkgd=opaque_bkgd,
render_bkgd=render_bkgd,
proposal_annealing=proposal_annealing_fn(step),
# test options
test_chunk_size=args.test_chunk_size,
)
mse = F.mse_loss(rgb, pixels)
psnr = -10.0 * torch.log(mse) / np.log(10.0)
psnrs.append(psnr.item())
lpips.append(lpips_fn(rgb, pixels).item())
# if i == 0:
# imageio.imwrite(
# "rgb_test.png",
# (rgb.cpu().numpy() * 255).astype(np.uint8),
# )
# imageio.imwrite(
# "rgb_error.png",
# (
# (rgb - pixels).norm(dim=-1).cpu().numpy() * 255
# ).astype(np.uint8),
# )
psnr_avg = sum(psnrs) / len(psnrs)
lpips_avg = sum(lpips) / len(lpips)
print(f"evaluation: psnr_avg={psnr_avg}, lpips_avg={lpips_avg}")
...@@ -3,13 +3,35 @@ Copyright (c) 2022 Ruilong Li, UC Berkeley.
"""
import random
from typing import Literal, Optional, Sequence
import numpy as np
import torch
from datasets.utils import Rays, namedtuple_map
from torch.utils.data._utils.collate import collate, default_collate_fn_map
from nerfacc import OccupancyGrid, ray_marching, rendering
from nerfacc.proposal import rendering as rendering_proposal
NERF_SYNTHETIC_SCENES = [
"chair",
"drums",
"ficus",
"hotdog",
"lego",
"materials",
"mic",
"ship",
]
MIPNERF360_UNBOUNDED_SCENES = [
"garden",
"bicycle",
"bonsai",
"counter",
"kitchen",
"room",
"stump",
]
def set_random_seed(seed):
...@@ -18,6 +40,12 @@ def set_random_seed(seed):
torch.manual_seed(seed)
def enlarge_aabb(aabb, factor: float) -> torch.Tensor:
center = (aabb[:3] + aabb[3:]) / 2
extent = (aabb[3:] - aabb[:3]) / 2
return torch.cat([center - extent * factor, center + extent * factor])
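# Hypothetical usage: a factor of 2.0 doubles the half-extent around the box
# center, so the unit cube becomes [-2, 2]^3:
#   aabb = torch.tensor([-1.0, -1.0, -1.0, 1.0, 1.0, 1.0])
#   enlarge_aabb(aabb, 2.0)  # tensor([-2., -2., -2., 2., 2., 2.])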
def render_image(
# scene
radiance_field: torch.nn.Module,
...@@ -116,3 +144,99 @@ def render_image(
depths.view((*rays_shape[:-1], -1)),
sum(n_rendering_samples),
)
def render_image_proposal(
# scene
radiance_field: torch.nn.Module,
proposal_networks: Sequence[torch.nn.Module],
rays: Rays,
scene_aabb: torch.Tensor,
# rendering options
num_samples: int,
num_samples_per_prop: Sequence[int],
near_plane: Optional[float] = None,
far_plane: Optional[float] = None,
sampling_type: Literal["uniform", "lindisp"] = "lindisp",
opaque_bkgd: bool = True,
render_bkgd: Optional[torch.Tensor] = None,
# train options
proposal_requires_grad: bool = False,
proposal_annealing: float = 1.0,
# test options
test_chunk_size: int = 8192,
):
"""Render the pixels of an image."""
rays_shape = rays.origins.shape
if len(rays_shape) == 3:
height, width, _ = rays_shape
num_rays = height * width
rays = namedtuple_map(
lambda r: r.reshape([num_rays] + list(r.shape[2:])), rays
)
else:
num_rays, _ = rays_shape
def prop_sigma_fn(t_starts, t_ends, proposal_network):
t_origins = chunk_rays.origins[..., None, :]
t_dirs = chunk_rays.viewdirs[..., None, :]
positions = t_origins + t_dirs * (t_starts + t_ends) / 2.0
return proposal_network(positions)
def rgb_sigma_fn(t_starts, t_ends):
t_origins = chunk_rays.origins[..., None, :]
t_dirs = chunk_rays.viewdirs[..., None, :].repeat_interleave(
t_starts.shape[-2], dim=-2
)
positions = t_origins + t_dirs * (t_starts + t_ends) / 2.0
return radiance_field(positions, t_dirs)
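# Both closures evaluate their networks at segment midpoints; t_starts and
# t_ends carry a per-ray sample dimension here, so positions broadcast to
# (n_rays, n_samples, 3). The proposal networks predict density only, while
# the radiance field also consumes the (repeated) view directions.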
results = []
chunk = (
torch.iinfo(torch.int32).max
if radiance_field.training
else test_chunk_size
)
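# In training mode the chunk size is effectively unbounded (int32 max), so the
# whole ray batch is rendered in one pass; at test time rays are processed in
# chunks of test_chunk_size to cap peak memory.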
for i in range(0, num_rays, chunk):
chunk_rays = namedtuple_map(lambda r: r[i : i + chunk], rays)
(
rgb,
opacity,
depth,
(weights_per_level, s_vals_per_level),
) = rendering_proposal(
rgb_sigma_fn=rgb_sigma_fn,
num_samples=num_samples,
prop_sigma_fns=[
# bind each proposal network now; a bare closure would capture only the last one
lambda *args, p=p: prop_sigma_fn(*args, p)
for p in proposal_networks
],
num_samples_per_prop=num_samples_per_prop,
rays_o=chunk_rays.origins,
rays_d=chunk_rays.viewdirs,
scene_aabb=scene_aabb,
near_plane=near_plane,
far_plane=far_plane,
stratified=radiance_field.training,
sampling_type=sampling_type,
opaque_bkgd=opaque_bkgd,
render_bkgd=render_bkgd,
proposal_requires_grad=proposal_requires_grad,
proposal_annealing=proposal_annealing,
)
chunk_results = [rgb, opacity, depth]
results.append(chunk_results)
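# The collate call below concatenates the per-chunk rgb/opacity/depth tensors
# along dim 0 (the custom fn map replaces default stacking with torch.cat).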
colors, opacities, depths = collate(
results,
collate_fn_map={
**default_collate_fn_map,
torch.Tensor: lambda x, **_: torch.cat(x, 0),
},
)
return (
colors.view((*rays_shape[:-1], -1)),
opacities.view((*rays_shape[:-1], -1)),
depths.view((*rays_shape[:-1], -1)),
weights_per_level,
s_vals_per_level,
)
""" """
Copyright (c) 2022 Ruilong Li, UC Berkeley. Copyright (c) 2022 Ruilong Li, UC Berkeley.
""" """
import warnings
from .cdf import ray_resampling
from .contraction import ContractionType, contract, contract_inv from .contraction import ContractionType, contract, contract_inv
from .grid import Grid, OccupancyGrid, query_grid from .grid import Grid, OccupancyGrid, query_grid
from .intersection import ray_aabb_intersect from .intersection import ray_aabb_intersect
from .losses import distortion as loss_distortion
from .pack import pack_data, pack_info, unpack_data, unpack_info from .pack import pack_data, pack_info, unpack_data, unpack_info
from .ray_marching import ray_marching from .ray_marching import ray_marching
from .version import __version__ from .version import __version__
...@@ -21,39 +17,30 @@ from .vol_rendering import ( ...@@ -21,39 +17,30 @@ from .vol_rendering import (
rendering, rendering,
) )
# About to be deprecated
def unpack_to_ray_indices(*args, **kwargs):
warnings.warn(
"`unpack_to_ray_indices` will be deprecated. Please use `unpack_info` instead.",
DeprecationWarning,
stacklevel=2,
)
return unpack_info(*args, **kwargs)
__all__ = [ __all__ = [
"__version__", "__version__",
# occ grid
"Grid", "Grid",
"OccupancyGrid", "OccupancyGrid",
"query_grid", "query_grid",
"ContractionType", "ContractionType",
# contraction
"contract", "contract",
"contract_inv", "contract_inv",
# marching
"ray_aabb_intersect", "ray_aabb_intersect",
"ray_marching", "ray_marching",
# rendering
"accumulate_along_rays", "accumulate_along_rays",
"render_visibility", "render_visibility",
"render_weight_from_alpha", "render_weight_from_alpha",
"render_weight_from_density", "render_weight_from_density",
"render_transmittance_from_density",
"render_transmittance_from_alpha",
"rendering", "rendering",
# pack
"pack_data", "pack_data",
"unpack_data", "unpack_data",
"unpack_info", "unpack_info",
"pack_info", "pack_info",
"ray_resampling",
"loss_distortion",
"unpack_to_ray_indices",
"render_transmittance_from_density",
"render_transmittance_from_alpha",
] ]
"""
Copyright (c) 2022 Ruilong Li, UC Berkeley.
"""
from typing import Tuple
from torch import Tensor
import nerfacc.cuda as _C
def ray_resampling(
packed_info: Tensor,
t_starts: Tensor,
t_ends: Tensor,
weights: Tensor,
n_samples: int,
) -> Tuple[Tensor, Tensor, Tensor]:
"""Resample a set of rays based on the CDF of the weights.
Args:
packed_info (Tensor): Stores information on which samples belong to the same ray. \
See :func:`nerfacc.ray_marching` for details. Tensor with shape (n_rays, 2).
t_starts: Where the frustum-shape sample starts along a ray. Tensor with \
shape (n_samples, 1).
t_ends: Where the frustum-shape sample ends along a ray. Tensor with \
shape (n_samples, 1).
weights: Volumetric rendering weights for those samples. Tensor with shape \
(n_samples,).
n_samples (int): Number of samples per ray to resample.
Returns:
Resampled packed info (n_rays, 2), t_starts (n_samples, 1), and t_ends (n_samples, 1).
"""
(
resampled_packed_info,
resampled_t_starts,
resampled_t_ends,
) = _C.ray_resampling(
packed_info.contiguous(),
t_starts.contiguous(),
t_ends.contiguous(),
weights.contiguous(),
n_samples,
)
return resampled_packed_info, resampled_t_starts, resampled_t_ends
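# A minimal, hypothetical example with hand-built inputs (two rays holding 3
# and 2 samples); tensors are placed on CUDA because the underlying op is a
# CUDA kernel. Values and shapes here are illustrative only.
#
#   import torch
#   packed_info = torch.tensor([[0, 3], [3, 2]], dtype=torch.int32, device="cuda")
#   t_starts = torch.tensor([[0.0], [0.4], [0.8], [2.0], [2.5]], device="cuda")
#   t_ends = torch.tensor([[0.4], [0.8], [1.2], [2.5], [3.0]], device="cuda")
#   weights = torch.tensor([0.1, 0.7, 0.2, 0.5, 0.5], device="cuda")
#   info, starts, ends = ray_resampling(packed_info, t_starts, t_ends, weights, n_samples=8)
#   # info: (2, 2); starts and ends: (16, 1), i.e. 8 new samples per ray.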
...@@ -23,7 +23,6 @@ grid_query = _make_lazy_cuda_func("grid_query")
ray_aabb_intersect = _make_lazy_cuda_func("ray_aabb_intersect")
ray_marching = _make_lazy_cuda_func("ray_marching")
ray_resampling = _make_lazy_cuda_func("ray_resampling")
is_cub_available = _make_lazy_cuda_func("is_cub_available")
transmittance_from_sigma_forward_cub = _make_lazy_cuda_func(
...@@ -68,3 +67,6 @@ weight_from_alpha_backward_naive = _make_lazy_cuda_func(
unpack_data = _make_lazy_cuda_func("unpack_data")
unpack_info = _make_lazy_cuda_func("unpack_info")
unpack_info_to_mask = _make_lazy_cuda_func("unpack_info_to_mask")
pdf_readout = _make_lazy_cuda_func("pdf_readout")
pdf_sampling = _make_lazy_cuda_func("pdf_sampling")
/*
* Copyright (c) 2022 Ruilong Li, UC Berkeley.
*/
#include "include/helpers_cuda.h"
template <typename scalar_t>
__global__ void cdf_resampling_kernel(
const uint32_t n_rays,
const int *packed_info, // input ray & point indices.
const scalar_t *starts, // input start t
const scalar_t *ends, // input end t
const scalar_t *weights, // transmittance weights
const int *resample_packed_info,
scalar_t *resample_starts,
scalar_t *resample_ends)
{
CUDA_GET_THREAD_ID(i, n_rays);
// locate
const int base = packed_info[i * 2 + 0]; // point idx start.
const int steps = packed_info[i * 2 + 1]; // point idx shift.
const int resample_base = resample_packed_info[i * 2 + 0];
const int resample_steps = resample_packed_info[i * 2 + 1];
if (steps == 0)
return;
starts += base;
ends += base;
weights += base;
resample_starts += resample_base;
resample_ends += resample_base;
// normalize weights **per ray**
scalar_t weights_sum = 0.0f;
for (int j = 0; j < steps; j++)
weights_sum += weights[j];
scalar_t padding = fmaxf(1e-5f - weights_sum, 0.0f);
scalar_t padding_step = padding / steps;
weights_sum += padding;
int num_bins = resample_steps + 1;
scalar_t cdf_step_size = (1.0f - 1.0 / num_bins) / resample_steps;
int idx = 0, j = 0;
scalar_t cdf_prev = 0.0f, cdf_next = (weights[idx] + padding_step) / weights_sum;
scalar_t cdf_u = 1.0 / (2 * num_bins);
while (j < num_bins)
{
if (cdf_u < cdf_next)
{
// printf("cdf_u: %f, cdf_next: %f\n", cdf_u, cdf_next);
// resample in this interval
scalar_t scaling = (ends[idx] - starts[idx]) / (cdf_next - cdf_prev);
scalar_t t = (cdf_u - cdf_prev) * scaling + starts[idx];
if (j < num_bins - 1)
resample_starts[j] = t;
if (j > 0)
resample_ends[j - 1] = t;
// going further to next resample
cdf_u += cdf_step_size;
j += 1;
}
else
{
// going to next interval
idx += 1;
cdf_prev = cdf_next;
cdf_next += (weights[idx] + padding_step) / weights_sum;
}
}
if (j != num_bins)
{
printf("Error: %d %d %f\n", j, num_bins, weights_sum);
}
return;
}
// template <typename scalar_t>
// __global__ void cdf_resampling_kernel(
// const uint32_t n_rays,
// const int *packed_info, // input ray & point indices.
// const scalar_t *starts, // input start t
// const scalar_t *ends, // input end t
// const scalar_t *weights, // transmittance weights
// const int *resample_packed_info,
// scalar_t *resample_starts,
// scalar_t *resample_ends)
// {
// CUDA_GET_THREAD_ID(i, n_rays);
// // locate
// const int base = packed_info[i * 2 + 0]; // point idx start.
// const int steps = packed_info[i * 2 + 1]; // point idx shift.
// const int resample_base = resample_packed_info[i * 2 + 0];
// const int resample_steps = resample_packed_info[i * 2 + 1];
// if (steps == 0)
// return;
// starts += base;
// ends += base;
// weights += base;
// resample_starts += resample_base;
// resample_ends += resample_base;
// scalar_t cdf_step_size = 1.0f / resample_steps;
// // normalize weights **per ray**
// scalar_t weights_sum = 0.0f;
// for (int j = 0; j < steps; j++)
// weights_sum += weights[j];
// scalar_t padding = fmaxf(1e-5f - weights_sum, 0.0f);
// scalar_t padding_step = padding / steps;
// weights_sum += padding;
// int idx = 0, j = 0;
// scalar_t cdf_prev = 0.0f, cdf_next = (weights[idx] + padding_step) / weights_sum;
// scalar_t cdf_u = 0.5f * cdf_step_size;
// while (cdf_u < 1.0f)
// {
// if (cdf_u < cdf_next)
// {
// // resample in this interval
// scalar_t scaling = (ends[idx] - starts[idx]) / (cdf_next - cdf_prev);
// scalar_t resample_mid = (cdf_u - cdf_prev) * scaling + starts[idx];
// scalar_t resample_half_size = cdf_step_size * scaling * 0.5;
// resample_starts[j] = fmaxf(resample_mid - resample_half_size, starts[idx]);
// resample_ends[j] = fminf(resample_mid + resample_half_size, ends[idx]);
// // going further to next resample
// cdf_u += cdf_step_size;
// j += 1;
// }
// else
// {
// // go to next interval
// idx += 1;
// if (idx == steps)
// break;
// cdf_prev = cdf_next;
// cdf_next += (weights[idx] + padding_step) / weights_sum;
// }
// }
// if (j != resample_steps)
// {
// printf("Error: %d %d %f\n", j, resample_steps, weights_sum);
// }
// return;
// }
std::vector<torch::Tensor> ray_resampling(
torch::Tensor packed_info,
torch::Tensor starts,
torch::Tensor ends,
torch::Tensor weights,
const int steps)
{
DEVICE_GUARD(packed_info);
CHECK_INPUT(packed_info);
CHECK_INPUT(starts);
CHECK_INPUT(ends);
CHECK_INPUT(weights);
TORCH_CHECK(packed_info.ndimension() == 2 && packed_info.size(1) == 2);
TORCH_CHECK(starts.ndimension() == 2 && starts.size(1) == 1);
TORCH_CHECK(ends.ndimension() == 2 && ends.size(1) == 1);
TORCH_CHECK(weights.ndimension() == 1);
const uint32_t n_rays = packed_info.size(0);
const uint32_t n_samples = weights.size(0);
const int threads = 256;
const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads);
torch::Tensor num_steps = torch::split(packed_info, 1, 1)[1];
torch::Tensor resample_num_steps = (num_steps > 0).to(num_steps.options()) * steps;
torch::Tensor resample_cum_steps = resample_num_steps.cumsum(0, torch::kInt32);
torch::Tensor resample_packed_info = torch::cat(
{resample_cum_steps - resample_num_steps, resample_num_steps}, 1);
int total_steps = resample_cum_steps[resample_cum_steps.size(0) - 1].item<int>();
torch::Tensor resample_starts = torch::zeros({total_steps, 1}, starts.options());
torch::Tensor resample_ends = torch::zeros({total_steps, 1}, ends.options());
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
weights.scalar_type(),
"ray_resampling",
([&]
{ cdf_resampling_kernel<scalar_t><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
n_rays,
// inputs
packed_info.data_ptr<int>(),
starts.data_ptr<scalar_t>(),
ends.data_ptr<scalar_t>(),
weights.data_ptr<scalar_t>(),
resample_packed_info.data_ptr<int>(),
// outputs
resample_starts.data_ptr<scalar_t>(),
resample_ends.data_ptr<scalar_t>()); }));
return {resample_packed_info, resample_starts, resample_ends};
}
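# For reference, a rough NumPy sketch of the per-ray inverse-CDF resampling
# performed by cdf_resampling_kernel above (single ray, illustrative only;
# the function and variable names below are not part of the library).
import numpy as np

def cdf_resample_one_ray(starts, ends, weights, resample_steps):
    steps = len(weights)
    weights = np.asarray(weights, dtype=np.float64)
    weights_sum = weights.sum()
    padding = max(1e-5 - weights_sum, 0.0)  # keep the CDF well-defined for near-empty rays
    weights = weights + padding / steps
    weights_sum += padding
    # per-ray normalized CDF over the input intervals
    cdf = np.concatenate([[0.0], np.cumsum(weights) / weights_sum])
    num_bins = resample_steps + 1
    cdf_step = (1.0 - 1.0 / num_bins) / resample_steps  # equals 1 / num_bins
    cdf_u = 1.0 / (2 * num_bins)  # first query location in CDF space
    new_starts = np.zeros(resample_steps)
    new_ends = np.zeros(resample_steps)
    idx = 0
    for j in range(num_bins):
        while cdf_u >= cdf[idx + 1]:  # advance to the interval containing cdf_u
            idx += 1
        scale = (ends[idx] - starts[idx]) / (cdf[idx + 1] - cdf[idx])
        t = (cdf_u - cdf[idx]) * scale + starts[idx]
        if j < num_bins - 1:  # bin edges become the new interval starts/ends
            new_starts[j] = t
        if j > 0:
            new_ends[j - 1] = t
        cdf_u += cdf_step
    return new_starts, new_ends

# e.g. cdf_resample_one_ray([0.0, 0.4, 0.8], [0.4, 0.8, 1.2], [0.1, 0.7, 0.2], 8)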