Unverified Commit d24c2b7f authored by Ruilong Li(李瑞龙), committed by GitHub
Browse files

support alpha; relax pytorch requirement (#94)

* support alpha for both marching and rendering; relax pytorch requirement

* bump version
parent e9aa8d3f
...@@ -99,10 +99,10 @@ def render_image( ...@@ -99,10 +99,10 @@ def render_image(
alpha_thre=alpha_thre, alpha_thre=alpha_thre,
) )
rgb, opacity, depth = rendering( rgb, opacity, depth = rendering(
rgb_sigma_fn,
packed_info, packed_info,
t_starts, t_starts,
t_ends, t_ends,
rgb_sigma_fn=rgb_sigma_fn,
render_bkgd=render_bkgd, render_bkgd=render_bkgd,
) )
chunk_results = [rgb, opacity, depth, len(t_starts)] chunk_results = [rgb, opacity, depth, len(t_starts)]
......
...@@ -22,8 +22,9 @@ def ray_marching( ...@@ -22,8 +22,9 @@ def ray_marching(
scene_aabb: Optional[torch.Tensor] = None, scene_aabb: Optional[torch.Tensor] = None,
# binarized grid for skipping empty space # binarized grid for skipping empty space
grid: Optional[Grid] = None, grid: Optional[Grid] = None,
# sigma function for skipping invisible space # sigma/alpha function for skipping invisible space
sigma_fn: Optional[Callable] = None, sigma_fn: Optional[Callable] = None,
alpha_fn: Optional[Callable] = None,
early_stop_eps: float = 1e-4, early_stop_eps: float = 1e-4,
alpha_thre: float = 0.0, alpha_thre: float = 0.0,
# rendering options # rendering options
...@@ -61,6 +62,12 @@ def ray_marching( ...@@ -61,6 +62,12 @@ def ray_marching(
by evaluating the density along the ray with `sigma_fn`. It should be a by evaluating the density along the ray with `sigma_fn`. It should be a
function that takes in samples {t_starts (N, 1), t_ends (N, 1), function that takes in samples {t_starts (N, 1), t_ends (N, 1),
ray indices (N,)} and returns the post-activation density values (N, 1). ray indices (N,)} and returns the post-activation density values (N, 1).
You should only provide either `sigma_fn` or `alpha_fn`.
alpha_fn: Optional. If provided, the marching will skip the invisible space
by evaluating the opacity along the ray with `alpha_fn`. It should be a
function that takes in samples {t_starts (N, 1), t_ends (N, 1),
ray indices (N,)} and returns the post-activation opacity values (N, 1).
You should only provide either `sigma_fn` or `alpha_fn`.
early_stop_eps: Early stop threshold for skipping invisible space. Default: 1e-4. early_stop_eps: Early stop threshold for skipping invisible space. Default: 1e-4.
alpha_thre: Alpha threshold for skipping empty space. Default: 0.0. alpha_thre: Alpha threshold for skipping empty space. Default: 0.0.
near_plane: Optional. Near plane distance. If provided, it will be used near_plane: Optional. Near plane distance. If provided, it will be used
...@@ -128,6 +135,10 @@ def ray_marching( ...@@ -128,6 +135,10 @@ def ray_marching(
""" """
if not rays_o.is_cuda: if not rays_o.is_cuda:
raise NotImplementedError("Only support cuda inputs.") raise NotImplementedError("Only support cuda inputs.")
if alpha_fn is not None and sigma_fn is not None:
raise ValueError(
"Only one of `alpha_fn` and `sigma_fn` should be provided."
)
# logic for t_min and t_max: # logic for t_min and t_max:
# 1. if t_min and t_max are given, use them with highest priority. # 1. if t_min and t_max are given, use them with highest priority.
...@@ -184,14 +195,20 @@ def ray_marching( ...@@ -184,14 +195,20 @@ def ray_marching(
) )
# skip invisible space # skip invisible space
if sigma_fn is not None: if sigma_fn is not None or alpha_fn is not None:
# Query sigma without gradients # Query sigma without gradients
ray_indices = unpack_info(packed_info) ray_indices = unpack_info(packed_info)
sigmas = sigma_fn(t_starts, t_ends, ray_indices.long()) if sigma_fn is not None:
assert ( sigmas = sigma_fn(t_starts, t_ends, ray_indices.long())
sigmas.shape == t_starts.shape assert (
), "sigmas must have shape of (N, 1)! Got {}".format(sigmas.shape) sigmas.shape == t_starts.shape
alphas = 1.0 - torch.exp(-sigmas * (t_ends - t_starts)) ), "sigmas must have shape of (N, 1)! Got {}".format(sigmas.shape)
alphas = 1.0 - torch.exp(-sigmas * (t_ends - t_starts))
elif alpha_fn is not None:
alphas = alpha_fn(t_starts, t_ends, ray_indices.long())
assert (
alphas.shape == t_starts.shape
), "alphas must have shape of (N, 1)! Got {}".format(alphas.shape)
# Compute visibility of the samples, and filter out invisible samples # Compute visibility of the samples, and filter out invisible samples
visibility, packed_info_visible = render_visibility( visibility, packed_info_visible = render_visibility(
......
...@@ -13,12 +13,13 @@ from .pack import unpack_info ...@@ -13,12 +13,13 @@ from .pack import unpack_info
def rendering( def rendering(
# radiance field
rgb_sigma_fn: Callable,
# ray marching results # ray marching results
packed_info: torch.Tensor, packed_info: torch.Tensor,
t_starts: torch.Tensor, t_starts: torch.Tensor,
t_ends: torch.Tensor, t_ends: torch.Tensor,
# radiance field
rgb_sigma_fn: Optional[Callable] = None,
rgb_alpha_fn: Optional[Callable] = None,
# rendering options # rendering options
early_stop_eps: float = 1e-4, early_stop_eps: float = 1e-4,
alpha_thre: float = 0.0, alpha_thre: float = 0.0,
...@@ -33,12 +34,17 @@ def rendering( ...@@ -33,12 +34,17 @@ def rendering(
This function is not differentiable to `t_starts`, `t_ends`. This function is not differentiable to `t_starts`, `t_ends`.
Args: Args:
rgb_sigma_fn: A function that takes in samples {t_starts (N, 1), t_ends (N, 1), \
ray indices (N,)} and returns the post-activation rgb (N, 3) and density \
values (N, 1).
packed_info: Packed ray marching info. See :func:`ray_marching` for details. packed_info: Packed ray marching info. See :func:`ray_marching` for details.
t_starts: Per-sample start distance. Tensor with shape (n_samples, 1). t_starts: Per-sample start distance. Tensor with shape (n_samples, 1).
t_ends: Per-sample end distance. Tensor with shape (n_samples, 1). t_ends: Per-sample end distance. Tensor with shape (n_samples, 1).
rgb_sigma_fn: A function that takes in samples {t_starts (N, 1), t_ends (N, 1), \
ray indices (N,)} and returns the post-activation rgb (N, 3) and density \
values (N, 1). At least one of `rgb_sigma_fn` and `rgb_alpha_fn` should be \
specified.
rgb_alpha_fn: A function that takes in samples {t_starts (N, 1), t_ends (N, 1), \
ray indices (N,)} and returns the post-activation rgb (N, 3) and opacity \
values (N, 1). At least one of `rgb_sigma_fn` and `rgb_alpha_fn` should be \
specified.
early_stop_eps: Early stop threshold during transmittance accumulation. Default: 1e-4. early_stop_eps: Early stop threshold during transmittance accumulation. Default: 1e-4.
alpha_thre: Alpha threshold for skipping empty space. Default: 0.0. alpha_thre: Alpha threshold for skipping empty space. Default: 0.0.
render_bkgd: Optional. Background color. Tensor with shape (3,). render_bkgd: Optional. Background color. Tensor with shape (3,).
...@@ -76,22 +82,47 @@ def rendering( ...@@ -76,22 +82,47 @@ def rendering(
print(colors.shape, opacities.shape, depths.shape) print(colors.shape, opacities.shape, depths.shape)
""" """
if callable(packed_info):
raise RuntimeError(
"You maybe want to use the nerfacc<=0.2.1 version. For nerfacc>0.2.1, "
"The first argument of `rendering` should be the packed ray packed info. "
"See the latest documentation for details: "
"https://www.nerfacc.com/en/latest/apis/rendering.html#nerfacc.rendering"
)
if rgb_sigma_fn is None and rgb_alpha_fn is None:
raise ValueError(
"At least one of `rgb_sigma_fn` and `rgb_alpha_fn` should be specified."
)
n_rays = packed_info.shape[0] n_rays = packed_info.shape[0]
ray_indices = unpack_info(packed_info) ray_indices = unpack_info(packed_info)
# Query sigma and color with gradients # Query sigma/alpha and color with gradients
rgbs, sigmas = rgb_sigma_fn(t_starts, t_ends, ray_indices.long()) if rgb_sigma_fn is not None:
assert rgbs.shape[-1] == 3, "rgbs must have 3 channels, got {}".format( rgbs, sigmas = rgb_sigma_fn(t_starts, t_ends, ray_indices.long())
rgbs.shape assert rgbs.shape[-1] == 3, "rgbs must have 3 channels, got {}".format(
) rgbs.shape
assert ( )
sigmas.shape == t_starts.shape assert (
), "sigmas must have shape of (N, 1)! Got {}".format(sigmas.shape) sigmas.shape == t_starts.shape
), "sigmas must have shape of (N, 1)! Got {}".format(sigmas.shape)
# Rendering: compute weights and ray indices. # Rendering: compute weights and ray indices.
weights = render_weight_from_density( weights = render_weight_from_density(
packed_info, t_starts, t_ends, sigmas, early_stop_eps, alpha_thre packed_info, t_starts, t_ends, sigmas, early_stop_eps, alpha_thre
) )
elif rgb_alpha_fn is not None:
rgbs, alphas = rgb_alpha_fn(t_starts, t_ends, ray_indices.long())
assert rgbs.shape[-1] == 3, "rgbs must have 3 channels, got {}".format(
rgbs.shape
)
assert (
alphas.shape == t_starts.shape
), "alphas must have shape of (N, 1)! Got {}".format(alphas.shape)
# Rendering: compute weights and ray indices.
weights = render_weight_from_alpha(
packed_info, alphas, early_stop_eps, alpha_thre
)
# Rendering: accumulate rgbs, opacities, and depths along the rays. # Rendering: accumulate rgbs, opacities, and depths along the rays.
colors = accumulate_along_rays( colors = accumulate_along_rays(
...@@ -244,7 +275,7 @@ def render_weight_from_alpha( ...@@ -244,7 +275,7 @@ def render_weight_from_alpha(
early_stop_eps: float = 1e-4, early_stop_eps: float = 1e-4,
alpha_thre: float = 0.0, alpha_thre: float = 0.0,
) -> torch.Tensor: ) -> torch.Tensor:
"""Compute transmittance weights from density. """Compute transmittance weights from opacity.
Args: Args:
packed_info: Stores information on which samples belong to the same ray. \ packed_info: Stores information on which samples belong to the same ray. \
......
...@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" ...@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project] [project]
name = "nerfacc" name = "nerfacc"
version = "0.2.1" version = "0.2.2"
description = "A General NeRF Acceleration Toolbox." description = "A General NeRF Acceleration Toolbox."
readme = "README.md" readme = "README.md"
authors = [{name = "Ruilong", email = "ruilongli94@gmail.com"}] authors = [{name = "Ruilong", email = "ruilongli94@gmail.com"}]
...@@ -14,7 +14,7 @@ dependencies = [ ...@@ -14,7 +14,7 @@ dependencies = [
"importlib_metadata>=5.0.0; python_version<'3.8'", "importlib_metadata>=5.0.0; python_version<'3.8'",
"ninja>=1.10.2.3", "ninja>=1.10.2.3",
"pybind11>=2.10.0", "pybind11>=2.10.0",
"torch>=1.12.0", "torch", # tested with 1.12.0
"rich>=12" "rich>=12"
] ]
......
...@@ -120,7 +120,9 @@ def test_rendering(): ...@@ -120,7 +120,9 @@ def test_rendering():
t_starts = torch.rand_like(sigmas) t_starts = torch.rand_like(sigmas)
t_ends = torch.rand_like(sigmas) + 1.0 t_ends = torch.rand_like(sigmas) + 1.0
_, _, _ = rendering(rgb_sigma_fn, packed_info, t_starts, t_ends) _, _, _ = rendering(
packed_info, t_starts, t_ends, rgb_sigma_fn=rgb_sigma_fn
)
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment