Unverified Commit 542f4310 authored by Ruilong Li(李瑞龙)'s avatar Ruilong Li(李瑞龙) Committed by GitHub
Browse files

Support setting alpha threshold for marching and rendering (#42)

* support alpha_thre for rendering and ray marching. default to zero

* bump version
parent daf3559a
......@@ -8,8 +8,8 @@ project = "nerfacc"
copyright = "2022, Ruilong"
author = "Ruilong"
release = "0.1.2"
version = "0.1.2"
release = "0.1.4"
version = "0.1.4"
# -- General configuration
......
......@@ -8,6 +8,7 @@ std::vector<torch::Tensor> rendering_forward(
torch::Tensor ends,
torch::Tensor sigmas,
float early_stop_eps,
float alpha_thre,
bool compression);
torch::Tensor rendering_backward(
......@@ -17,7 +18,8 @@ torch::Tensor rendering_backward(
torch::Tensor starts,
torch::Tensor ends,
torch::Tensor sigmas,
float early_stop_eps);
float early_stop_eps,
float alpha_thre);
std::vector<torch::Tensor> ray_aabb_intersect(
const torch::Tensor rays_o,
......@@ -65,12 +67,14 @@ torch::Tensor rendering_alphas_backward(
torch::Tensor grad_weights,
torch::Tensor packed_info,
torch::Tensor alphas,
float early_stop_eps);
float early_stop_eps,
float alpha_thre);
std::vector<torch::Tensor> rendering_alphas_forward(
torch::Tensor packed_info,
torch::Tensor alphas,
float early_stop_eps,
float alpha_thre,
bool compression);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
......
......@@ -9,6 +9,7 @@ __global__ void rendering_forward_kernel(
const scalar_t *sigmas, // input density after activation
const scalar_t *alphas, // input alpha (opacity) values.
const scalar_t early_stop_eps, // transmittance threshold for early stop
const scalar_t alpha_thre, // alpha threshold for empty space
// outputs: should be all-zero initialized
int *num_steps, // the number of valid steps for each ray
scalar_t *weights, // the rendering weights for each sample
......@@ -70,6 +71,11 @@ __global__ void rendering_forward_kernel(
scalar_t delta = ends[j] - starts[j];
alpha = 1.f - __expf(-sigmas[j] * delta);
}
if (alpha < alpha_thre)
{
// empty space
continue;
}
const scalar_t weight = alpha * T;
T *= (1.f - alpha);
if (weights != nullptr)
......@@ -97,6 +103,7 @@ __global__ void rendering_backward_kernel(
const scalar_t *sigmas, // input density after activation
const scalar_t *alphas, // input alpha (opacity) values.
const scalar_t early_stop_eps, // transmittance threshold for early stop
const scalar_t alpha_thre, // alpha threshold for empty space
const scalar_t *weights, // forward output
const scalar_t *grad_weights, // input gradients
// if alphas was given, we compute the gradients for alphas.
......@@ -150,6 +157,11 @@ __global__ void rendering_backward_kernel(
{
// rendering with alpha
alpha = alphas[j];
if (alpha < alpha_thre)
{
// empty space
continue;
}
grad_alphas[j] = (grad_weights[j] * T - accum) / fmaxf(1.f - alpha, 1e-10f);
}
else
......@@ -157,6 +169,11 @@ __global__ void rendering_backward_kernel(
// rendering with density
scalar_t delta = ends[j] - starts[j];
alpha = 1.f - __expf(-sigmas[j] * delta);
if (alpha < alpha_thre)
{
// empty space
continue;
}
grad_sigmas[j] = (grad_weights[j] * T - accum) * delta;
}
......@@ -171,6 +188,7 @@ std::vector<torch::Tensor> rendering_forward(
torch::Tensor ends,
torch::Tensor sigmas,
float early_stop_eps,
float alpha_thre,
bool compression)
{
DEVICE_GUARD(packed_info);
......@@ -211,6 +229,7 @@ std::vector<torch::Tensor> rendering_forward(
sigmas.data_ptr<scalar_t>(),
nullptr, // alphas
early_stop_eps,
alpha_thre,
// outputs
num_steps.data_ptr<int>(),
nullptr,
......@@ -238,6 +257,7 @@ std::vector<torch::Tensor> rendering_forward(
sigmas.data_ptr<scalar_t>(),
nullptr, // alphas
early_stop_eps,
alpha_thre,
// outputs
nullptr,
weights.data_ptr<scalar_t>(),
......@@ -254,7 +274,8 @@ torch::Tensor rendering_backward(
torch::Tensor starts,
torch::Tensor ends,
torch::Tensor sigmas,
float early_stop_eps)
float early_stop_eps,
float alpha_thre)
{
DEVICE_GUARD(packed_info);
const uint32_t n_rays = packed_info.size(0);
......@@ -279,6 +300,7 @@ torch::Tensor rendering_backward(
sigmas.data_ptr<scalar_t>(),
nullptr, // alphas
early_stop_eps,
alpha_thre,
weights.data_ptr<scalar_t>(),
grad_weights.data_ptr<scalar_t>(),
// outputs
......@@ -295,6 +317,7 @@ std::vector<torch::Tensor> rendering_alphas_forward(
torch::Tensor packed_info,
torch::Tensor alphas,
float early_stop_eps,
float alpha_thre,
bool compression)
{
DEVICE_GUARD(packed_info);
......@@ -331,6 +354,7 @@ std::vector<torch::Tensor> rendering_alphas_forward(
nullptr, // sigmas
alphas.data_ptr<scalar_t>(),
early_stop_eps,
alpha_thre,
// outputs
num_steps.data_ptr<int>(),
nullptr,
......@@ -358,6 +382,7 @@ std::vector<torch::Tensor> rendering_alphas_forward(
nullptr, // sigmas
alphas.data_ptr<scalar_t>(),
early_stop_eps,
alpha_thre,
// outputs
nullptr,
weights.data_ptr<scalar_t>(),
......@@ -372,7 +397,8 @@ torch::Tensor rendering_alphas_backward(
torch::Tensor grad_weights,
torch::Tensor packed_info,
torch::Tensor alphas,
float early_stop_eps)
float early_stop_eps,
float alpha_thre)
{
DEVICE_GUARD(packed_info);
const uint32_t n_rays = packed_info.size(0);
......@@ -397,6 +423,7 @@ torch::Tensor rendering_alphas_backward(
nullptr, // sigmas
alphas.data_ptr<scalar_t>(),
early_stop_eps,
alpha_thre,
weights.data_ptr<scalar_t>(),
grad_weights.data_ptr<scalar_t>(),
// outputs
......
......@@ -15,6 +15,7 @@ def rendering(
t_ends: torch.Tensor,
# rendering options
early_stop_eps: float = 1e-4,
alpha_thre: float = 1e-2,
render_bkgd: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Render the rays through the radiance field defined by `rgb_sigma_fn`.
......@@ -33,6 +34,7 @@ def rendering(
t_starts: Per-sample start distance. Tensor with shape (n_samples, 1).
t_ends: Per-sample end distance. Tensor with shape (n_samples, 1).
early_stop_eps: Early stop threshold during transmittance accumulation. Default: 1e-4.
alpha_thre: Alpha threshold for skipping empty space. Default: 1e-2.
render_bkgd: Optional. Background color. Tensor with shape (3,).
Returns:
......@@ -82,7 +84,7 @@ def rendering(
# Rendering: compute weights and ray indices.
weights = render_weight_from_density(
packed_info, t_starts, t_ends, sigmas, early_stop_eps
packed_info, t_starts, t_ends, sigmas, early_stop_eps, alpha_thre
)
# Rendering: accumulate rgbs, opacities, and depths along the rays.
......
......@@ -104,6 +104,7 @@ def ray_marching(
# sigma function for skipping invisible space
sigma_fn: Optional[Callable] = None,
early_stop_eps: float = 1e-4,
alpha_thre: float = 0.0,
# rendering options
near_plane: Optional[float] = None,
far_plane: Optional[float] = None,
......@@ -140,6 +141,7 @@ def ray_marching(
function that takes in samples {t_starts (N, 1), t_ends (N, 1),
ray indices (N,)} and returns the post-activation density values (N, 1).
early_stop_eps: Early stop threshold for skipping invisible space. Default: 1e-4.
alpha_thre: Alpha threshold for skipping empty space. Default: 0.0.
near_plane: Optional. Near plane distance. If provided, it will be used
to clip t_min.
far_plane: Optional. Far plane distance. If provided, it will be used
......@@ -272,7 +274,7 @@ def ray_marching(
# Compute visibility of the samples, and filter out invisible samples
visibility, packed_info_visible = render_visibility(
packed_info, alphas, early_stop_eps
packed_info, alphas, early_stop_eps, alpha_thre
)
t_starts, t_ends = t_starts[visibility], t_ends[visibility]
packed_info = packed_info_visible
......
......@@ -82,6 +82,7 @@ def render_weight_from_density(
t_ends,
sigmas,
early_stop_eps: float = 1e-4,
alpha_thre: float = 0.0,
) -> torch.Tensor:
"""Compute transmittance weights from density.
......@@ -94,6 +95,7 @@ def render_weight_from_density(
shape (n_samples, 1).
sigmas: The density values of the samples. Tensor with shape (n_samples, 1).
early_stop_eps: The epsilon value for early stopping. Default is 1e-4.
alpha_thre: Alpha threshold for skipping empty space. Default: 0.0.
Returns:
transmittance weights with shape (n_samples,).
......@@ -123,7 +125,7 @@ def render_weight_from_density(
if not sigmas.is_cuda:
raise NotImplementedError("Only support cuda inputs.")
weights = _RenderingDensity.apply(
packed_info, t_starts, t_ends, sigmas, early_stop_eps
packed_info, t_starts, t_ends, sigmas, early_stop_eps, alpha_thre
)
return weights
......@@ -132,6 +134,7 @@ def render_weight_from_alpha(
packed_info,
alphas,
early_stop_eps: float = 1e-4,
alpha_thre: float = 0.0,
) -> Tuple[torch.Tensor, ...]:
"""Compute transmittance weights from alpha (opacity).
......@@ -140,7 +143,8 @@ def render_weight_from_alpha(
See :func:`nerfacc.ray_marching` for details. Tensor with shape (n_rays, 2).
alphas: The opacity values of the samples. Tensor with shape (n_samples, 1).
early_stop_eps: The epsilon value for early stopping. Default is 1e-4.
alpha_thre: Alpha threshold for skipping empty space. Default: 0.0.
Returns:
transmittance weights with shape (n_samples,).
......@@ -168,7 +172,9 @@ def render_weight_from_alpha(
"""
if not alphas.is_cuda:
raise NotImplementedError("Only support cuda inputs.")
weights = _RenderingAlpha.apply(packed_info, alphas, early_stop_eps)
weights = _RenderingAlpha.apply(
packed_info, alphas, early_stop_eps, alpha_thre
)
return weights
......@@ -177,6 +183,7 @@ def render_visibility(
packed_info: torch.Tensor,
alphas: torch.Tensor,
early_stop_eps: float = 1e-4,
alpha_thre: float = 0.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Filter out invisible samples given alpha (opacity).
......@@ -185,6 +192,7 @@ def render_visibility(
See :func:`nerfacc.ray_marching` for details. Tensor with shape (n_rays, 2).
alphas: The opacity values of the samples. Tensor with shape (n_samples, 1).
early_stop_eps: The epsilon value for early stopping. Default is 1e-4.
alpha_thre: Alpha threshold for skipping empty space. Default: 0.0.
Returns:
A tuple of tensors.
......@@ -223,6 +231,7 @@ def render_visibility(
packed_info.contiguous(),
alphas.contiguous(),
early_stop_eps,
alpha_thre,
True, # compute visibility instead of weights
)
return visibility, packed_info_visible
......@@ -239,6 +248,7 @@ class _RenderingDensity(torch.autograd.Function):
t_ends,
sigmas,
early_stop_eps: float = 1e-4,
alpha_thre: float = 0.0,
):
packed_info = packed_info.contiguous()
t_starts = t_starts.contiguous()
......@@ -250,6 +260,7 @@ class _RenderingDensity(torch.autograd.Function):
t_ends,
sigmas,
early_stop_eps,
alpha_thre,
False, # not doing filtering
)[0]
if ctx.needs_input_grad[3]: # sigmas
......@@ -261,12 +272,14 @@ class _RenderingDensity(torch.autograd.Function):
weights,
)
ctx.early_stop_eps = early_stop_eps
ctx.alpha_thre = alpha_thre
return weights
@staticmethod
def backward(ctx, grad_weights):
grad_weights = grad_weights.contiguous()
early_stop_eps = ctx.early_stop_eps
alpha_thre = ctx.alpha_thre
(
packed_info,
t_starts,
......@@ -282,8 +295,9 @@ class _RenderingDensity(torch.autograd.Function):
t_ends,
sigmas,
early_stop_eps,
alpha_thre,
)
return None, None, None, grad_sigmas, None
return None, None, None, grad_sigmas, None, None
class _RenderingAlpha(torch.autograd.Function):
......@@ -295,6 +309,7 @@ class _RenderingAlpha(torch.autograd.Function):
packed_info,
alphas,
early_stop_eps: float = 1e-4,
alpha_thre: float = 0.0,
):
packed_info = packed_info.contiguous()
alphas = alphas.contiguous()
......@@ -302,6 +317,7 @@ class _RenderingAlpha(torch.autograd.Function):
packed_info,
alphas,
early_stop_eps,
alpha_thre,
False, # not doing filtering
)[0]
if ctx.needs_input_grad[1]: # alphas
......@@ -311,12 +327,14 @@ class _RenderingAlpha(torch.autograd.Function):
weights,
)
ctx.early_stop_eps = early_stop_eps
ctx.alpha_thre = alpha_thre
return weights
@staticmethod
def backward(ctx, grad_weights):
grad_weights = grad_weights.contiguous()
early_stop_eps = ctx.early_stop_eps
alpha_thre = ctx.alpha_thre
(
packed_info,
alphas,
......@@ -328,5 +346,6 @@ class _RenderingAlpha(torch.autograd.Function):
packed_info,
alphas,
early_stop_eps,
alpha_thre,
)
return None, grad_sigmas, None
return None, grad_sigmas, None, None
......@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "nerfacc"
version = "0.1.3"
version = "0.1.4"
authors = [{name = "Ruilong", email = "ruilongli94@gmail.com"}]
license = { text="MIT" }
requires-python = ">=3.8"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment