Commit a74f7e7b authored by Ruilong Li

proposal_sampling_with_filter: 7k; 229s; loss 64; 35.25 dB; 63 rays

parent 1aeee0a9
@@ -87,7 +87,15 @@ def render_image(
         chunk_rays.viewdirs,
         scene_aabb=scene_aabb,
         grid=None,
-        proposal_nets=proposal_nets,
+        # proposal density fns: {t_starts, t_ends, ray_indices} -> density
+        proposal_sigma_fns=[
+            lambda t_starts, t_ends, ray_indices, net=proposal_net: sigma_fn(
+                t_starts, t_ends, ray_indices, net
+            )
+            for proposal_net in proposal_nets
+        ],
+        proposal_n_samples=[32],
+        proposal_require_grads=proposal_nets_require_grads,
         sigma_fn=sigma_fn,
         near_plane=near_plane,
         far_plane=far_plane,
@@ -95,7 +103,6 @@ def render_image(
         stratified=radiance_field.training,
         cone_angle=cone_angle,
         alpha_thre=alpha_thre,
-        proposal_nets_require_grads=proposal_nets_require_grads,
     )
     rgb, opacity, depth, weights = rendering(
         t_starts,
...
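The call-site change above is the heart of the new API: ray_marching no longer takes the proposal networks themselves, it takes a list of density callbacks, one per network, each mapping {t_starts, t_ends, ray_indices} -> density, plus a matching list of per-round sample counts. A minimal sketch of building those callbacks for several networks (the names proposal_nets and sigma_fn come from the diff; everything else is illustrative):

# Sketch only: one density callback per proposal network. Binding the
# loop variable through a default argument (net=net) makes each lambda
# close over its own network rather than the last one in the loop.
proposal_sigma_fns = [
    lambda t_starts, t_ends, ray_indices, net=net: sigma_fn(
        t_starts, t_ends, ray_indices, net
    )
    for net in proposal_nets
]
# proposal_sampling_with_filter asserts the two lists have equal length,
# so supply one sample count per callback (32 is the value used above).
proposal_n_samples = [32] * len(proposal_nets)

The ray_marching.py changes below wire these arguments through to the new sampling module.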
@@ -2,58 +2,9 @@ from typing import Callable, Optional, Tuple
 import torch

-import nerfacc.cuda as _C
-
-from .cdf import ray_resampling
 from .grid import Grid
 from .intersection import ray_aabb_intersect
-from .pack import pack_info, unpack_info
-from .vol_rendering import (
-    render_visibility,
-    render_weight_from_alpha,
-    render_weight_from_density,
-)
-
-
-@torch.no_grad()
-def maybe_filter(
-    t_starts: torch.Tensor,
-    t_ends: torch.Tensor,
-    ray_indices: torch.Tensor,
-    n_rays: int,
-    # sigma/alpha function for skipping invisible space
-    sigma_fn: Optional[Callable] = None,
-    alpha_fn: Optional[Callable] = None,
-    net: Optional[torch.nn.Module] = None,
-    early_stop_eps: float = 1e-4,
-    alpha_thre: float = 0.0,
-):
-    alphas = None
-    if sigma_fn is not None:
-        alpha_fn = lambda *args: 1.0 - torch.exp(
-            -sigma_fn(*args) * (t_ends - t_starts)
-        )
-    if alpha_fn is not None:
-        alphas = alpha_fn(t_starts, t_ends, ray_indices.long(), net)
-        assert (
-            alphas.shape == t_starts.shape
-        ), "alphas must have shape of (N, 1)! Got {}".format(alphas.shape)
-        # Compute visibility of the samples, and filter out invisible samples
-        masks = render_visibility(
-            alphas,
-            ray_indices=ray_indices,
-            early_stop_eps=early_stop_eps,
-            alpha_thre=alpha_thre,
-            n_rays=n_rays,
-        )
-        ray_indices, t_starts, t_ends, alphas = (
-            ray_indices[masks],
-            t_starts[masks],
-            t_ends[masks],
-            alphas[masks],
-        )
-    return ray_indices, t_starts, t_ends, alphas
+from .sampling import proposal_sampling_with_filter, sample_along_rays


 @torch.no_grad()
@@ -71,10 +22,12 @@ def ray_marching(
     # sigma/alpha function for skipping invisible space
     sigma_fn: Optional[Callable] = None,
     alpha_fn: Optional[Callable] = None,
-    proposal_nets: Optional[torch.nn.Module] = None,
+    # proposal density fns: {t_starts, t_ends, ray_indices} -> density
+    proposal_sigma_fns: Tuple[Callable, ...] = (),
+    proposal_n_samples: Tuple[int, ...] = (),
+    proposal_require_grads: bool = False,
     early_stop_eps: float = 1e-4,
     alpha_thre: float = 0.0,
-    proposal_nets_require_grads: bool = True,
     # rendering options
     near_plane: Optional[float] = None,
     far_plane: Optional[float] = None,
@@ -177,7 +130,6 @@ def ray_marching(
         sample_locs = rays_o[ray_indices] + t_mid * rays_d[ray_indices]

     """
-    torch.cuda.synchronize()
     n_rays = rays_o.shape[0]
     if not rays_o.is_cuda:
@@ -209,85 +161,32 @@ def ray_marching(
     if stratified:
         t_min = t_min + torch.rand_like(t_min) * render_step_size

-    # use grid for skipping if given
-    if grid is not None:
-        # marching with grid-based skipping
-        packed_info, ray_indices, t_starts, t_ends = _C.ray_marching_with_grid(
-            # rays
-            rays_o.contiguous(),
-            rays_d.contiguous(),
-            t_min.contiguous(),
-            t_max.contiguous(),
-            # coontraction and grid
-            grid.roi_aabb.contiguous(),
-            grid.binary.contiguous(),
-            grid.contraction_type.to_cpp_version(),
-            # sampling
-            render_step_size,
-            cone_angle,
-        )
-    else:
-        # marching
-        packed_info, ray_indices, t_starts, t_ends = _C.ray_marching(
-            # rays
-            t_min.contiguous(),
-            t_max.contiguous(),
-            # sampling
-            render_step_size,
-            cone_angle,
-        )
-
-    proposal_sample_list = []
-    if proposal_nets is not None:
-        # resample with proposal nets
-        for net, num_samples in zip(proposal_nets, [32]):
-            ray_indices, t_starts, t_ends, alphas = maybe_filter(
-                t_starts=t_starts,
-                t_ends=t_ends,
-                ray_indices=ray_indices,
-                n_rays=n_rays,
-                sigma_fn=sigma_fn,
-                alpha_fn=alpha_fn,
-                net=net,
-                early_stop_eps=early_stop_eps,
-                alpha_thre=alpha_thre,
-            )
-            packed_info = pack_info(ray_indices, n_rays=n_rays)
-            if proposal_nets_require_grads:
-                with torch.enable_grad():
-                    sigmas = sigma_fn(
-                        t_starts, t_ends, ray_indices.long(), net=net
-                    )
-                    weights = render_weight_from_density(
-                        t_starts, t_ends, sigmas, ray_indices=ray_indices
-                    )
-                proposal_sample_list.append(
-                    (packed_info, t_starts, t_ends, weights)
-                )
-            else:
-                weights = render_weight_from_alpha(
-                    alphas, ray_indices=ray_indices
-                )
-            packed_info, t_starts, t_ends = ray_resampling(
-                packed_info, t_starts, t_ends, weights, n_samples=num_samples
-            )
-            ray_indices = unpack_info(packed_info, n_samples=t_starts.shape[0])
-
-    ray_indices, t_starts, t_ends, _ = maybe_filter(
+    ray_indices, t_starts, t_ends = sample_along_rays(
+        rays_o=rays_o,
+        rays_d=rays_d,
+        t_min=t_min,
+        t_max=t_max,
+        step_size=render_step_size,
+        cone_angle=cone_angle,
+        grid=grid,
+    )
+    (
+        ray_indices,
+        t_starts,
+        t_ends,
+        proposal_samples,
+    ) = proposal_sampling_with_filter(
         t_starts=t_starts,
         t_ends=t_ends,
         ray_indices=ray_indices,
         n_rays=n_rays,
         sigma_fn=sigma_fn,
-        alpha_fn=alpha_fn,
-        net=None,
+        proposal_sigma_fns=proposal_sigma_fns,
+        proposal_n_samples=proposal_n_samples,
+        proposal_require_grads=proposal_require_grads,
         early_stop_eps=early_stop_eps,
         alpha_thre=alpha_thre,
     )
-    if proposal_nets is not None:
-        return ray_indices, t_starts, t_ends, proposal_sample_list
-    else:
-        return ray_indices, t_starts, t_ends
+    return ray_indices, t_starts, t_ends, proposal_samples
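One consequence of the refactor worth noting: ray_marching now always returns four values, with proposal_samples simply an empty list when no proposal fns are passed, instead of conditionally returning three or four values as before. A hedged sketch of the updated call site (tensor names and the step size are placeholders, not part of this commit):

# rays_o / rays_d: [n_rays, 3] tensors; scene_aabb as in render_image above.
ray_indices, t_starts, t_ends, proposal_samples = ray_marching(
    rays_o,
    rays_d,
    scene_aabb=scene_aabb,
    sigma_fn=sigma_fn,
    proposal_sigma_fns=proposal_sigma_fns,
    proposal_n_samples=[32],
    proposal_require_grads=True,
    render_step_size=5e-3,
)
# Each entry of proposal_samples is (packed_info, t_starts, t_ends, weights);
# with proposal_require_grads=True the weights carry gradients, so a training
# loop can place a proposal-style loss on them.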
new file: nerfacc/sampling.py

import math
from typing import Callable, Optional, Tuple, Union, overload

import torch

import nerfacc.cuda as _C

from .cdf import ray_resampling
from .grid import Grid
from .pack import pack_info, unpack_info
from .vol_rendering import (
    render_transmittance_from_alpha,
    render_weight_from_density,
)
@overload
def sample_along_rays(
    rays_o: torch.Tensor,  # [n_rays, 3]
    rays_d: torch.Tensor,  # [n_rays, 3]
    t_min: torch.Tensor,  # [n_rays,]
    t_max: torch.Tensor,  # [n_rays,]
    step_size: float,
    cone_angle: float = 0.0,
    grid: Optional[Grid] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Sample along rays with per-ray min max."""
    ...


@overload
def sample_along_rays(
    rays_o: torch.Tensor,  # [n_rays, 3]
    rays_d: torch.Tensor,  # [n_rays, 3]
    t_min: float,
    t_max: float,
    step_size: float,
    cone_angle: float = 0.0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Sample along rays with near far plane."""
    ...


@torch.no_grad()
def sample_along_rays(
    rays_o: torch.Tensor,  # [n_rays, 3]
    rays_d: torch.Tensor,  # [n_rays, 3]
    t_min: Union[float, torch.Tensor],  # [n_rays,]
    t_max: Union[float, torch.Tensor],  # [n_rays,]
    step_size: float,
    cone_angle: float = 0.0,
    grid: Optional[Grid] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Sample intervals along rays."""
    if isinstance(t_min, float) and isinstance(t_max, float):
        n_rays = rays_o.shape[0]
        device = rays_o.device
        num_steps = math.floor((t_max - t_min) / step_size)
        t_starts = (
            (t_min + torch.arange(0, num_steps, device=device) * step_size)
            .expand(n_rays, -1)
            .reshape(-1, 1)
        )
        t_ends = t_starts + step_size
        ray_indices = torch.arange(0, n_rays, device=device).repeat_interleave(
            num_steps, dim=0
        )
    else:
        if grid is None:
            packed_info, ray_indices, t_starts, t_ends = _C.ray_marching(
                # rays
                t_min.contiguous(),
                t_max.contiguous(),
                # sampling
                step_size,
                cone_angle,
            )
        else:
            (
                packed_info,
                ray_indices,
                t_starts,
                t_ends,
            ) = _C.ray_marching_with_grid(
                # rays
                rays_o.contiguous(),
                rays_d.contiguous(),
                t_min.contiguous(),
                t_max.contiguous(),
                # contraction and grid
                grid.roi_aabb.contiguous(),
                grid.binary.contiguous(),
                grid.contraction_type.to_cpp_version(),
                # sampling
                step_size,
                cone_angle,
            )
    return ray_indices, t_starts, t_ends
@torch.no_grad()
def proposal_sampling_with_filter(
    t_starts: torch.Tensor,  # [n_samples, 1]
    t_ends: torch.Tensor,  # [n_samples, 1]
    ray_indices: torch.Tensor,  # [n_samples,]
    n_rays: Optional[int] = None,
    # compute density of samples: {t_starts, t_ends, ray_indices} -> density
    sigma_fn: Optional[Callable] = None,
    # proposal density fns: {t_starts, t_ends, ray_indices} -> density
    proposal_sigma_fns: Tuple[Callable, ...] = (),
    proposal_n_samples: Tuple[int, ...] = (),
    proposal_require_grads: bool = False,
    # acceleration options
    early_stop_eps: float = 1e-4,
    alpha_thre: float = 0.0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Heuristic marching with proposal fns."""
    assert len(proposal_sigma_fns) == len(proposal_n_samples), (
        "proposal_sigma_fns and proposal_n_samples must have the same length, "
        f"but got {len(proposal_sigma_fns)} and {len(proposal_n_samples)}."
    )
    if n_rays is None:
        n_rays = ray_indices.max() + 1

    # compute density from proposal fns
    proposal_samples = []
    for proposal_fn, n_samples in zip(proposal_sigma_fns, proposal_n_samples):
        # compute weights for resampling
        sigmas = proposal_fn(t_starts, t_ends, ray_indices.long())
        assert (
            sigmas.shape == t_starts.shape
        ), "sigmas must have shape of (N, 1)! Got {}".format(sigmas.shape)
        alphas = 1.0 - torch.exp(-sigmas * (t_ends - t_starts))
        transmittance = render_transmittance_from_alpha(
            alphas, ray_indices=ray_indices, n_rays=n_rays
        )
        weights = alphas * transmittance

        # Compute visibility for filtering
        if alpha_thre > 0 or early_stop_eps > 0:
            vis = (alphas >= alpha_thre) & (transmittance >= early_stop_eps)
            vis = vis.squeeze(-1)
            ray_indices, t_starts, t_ends, weights = (
                ray_indices[vis],
                t_starts[vis],
                t_ends[vis],
                weights[vis],
            )
        packed_info = pack_info(ray_indices, n_rays=n_rays)

        # Rerun the proposal function **with** gradients on filtered samples.
        if proposal_require_grads:
            with torch.enable_grad():
                sigmas = proposal_fn(t_starts, t_ends, ray_indices.long())
                weights = render_weight_from_density(
                    t_starts, t_ends, sigmas, ray_indices=ray_indices
                )
            proposal_samples.append((packed_info, t_starts, t_ends, weights))

        # resampling on filtered samples
        packed_info, t_starts, t_ends = ray_resampling(
            packed_info, t_starts, t_ends, weights, n_samples=n_samples
        )
        ray_indices = unpack_info(packed_info, t_starts.shape[0])

    # last round of filtering with sigma_fn
    if (alpha_thre > 0 or early_stop_eps > 0) and (sigma_fn is not None):
        sigmas = sigma_fn(t_starts, t_ends, ray_indices.long())
        assert (
            sigmas.shape == t_starts.shape
        ), "sigmas must have shape of (N, 1)! Got {}".format(sigmas.shape)
        alphas = 1.0 - torch.exp(-sigmas * (t_ends - t_starts))
        transmittance = render_transmittance_from_alpha(
            alphas, ray_indices=ray_indices, n_rays=n_rays
        )
        vis = (alphas >= alpha_thre) & (transmittance >= early_stop_eps)
        vis = vis.squeeze(-1)
        ray_indices, t_starts, t_ends = (
            ray_indices[vis],
            t_starts[vis],
            t_ends[vis],
        )
    return ray_indices, t_starts, t_ends, proposal_samples
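To make the arithmetic inside proposal_sampling_with_filter concrete, here is a self-contained toy version of the per-sample quantities it computes: opacity alpha = 1 - exp(-sigma * (t_end - t_start)), transmittance T (the product of 1 - alpha over the samples in front), resampling weights w = alpha * T, and the visibility test alpha >= alpha_thre and T >= early_stop_eps. This sketch assumes a single ray, so the packed ray_indices bookkeeping can be replaced by a plain cumulative product; all densities are made up:

import torch

# Four intervals along one ray, with made-up densities.
t_starts = torch.tensor([[0.0], [0.5], [1.0], [1.5]])
t_ends = t_starts + 0.5
sigmas = torch.tensor([[0.01], [2.0], [4.0], [0.05]])

# alpha_i = 1 - exp(-sigma_i * (t_end_i - t_start_i))
alphas = 1.0 - torch.exp(-sigmas * (t_ends - t_starts))

# T_i = prod_{j < i} (1 - alpha_j): transmittance in front of sample i.
transmittance = torch.cumprod(
    torch.cat([torch.ones_like(alphas[:1]), 1.0 - alphas[:-1]]), dim=0
)

# Resampling weights, and the same visibility test as the filter above.
weights = alphas * transmittance
alpha_thre, early_stop_eps = 0.0, 1e-4
vis = ((alphas >= alpha_thre) & (transmittance >= early_stop_eps)).squeeze(-1)

In the packed, multi-ray setting, render_transmittance_from_alpha computes the same prefix product per ray segment, which is why the module can reuse one visibility test for both the proposal rounds and the final sigma_fn pass.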