Commit 5b6f0c61 authored by Ruilong Li's avatar Ruilong Li
Browse files

remove alive_mask

parent 224e217d
...@@ -73,14 +73,13 @@ def render_image(radiance_field, rays, render_bkgd, render_step_size): ...@@ -73,14 +73,13 @@ def render_image(radiance_field, rays, render_bkgd, render_step_size):
render_step_size=render_step_size, render_step_size=render_step_size,
) )
results.append(chunk_results) results.append(chunk_results)
rgb, depth, acc, alive_ray_mask, counter, compact_counter = [ rgb, depth, acc, counter, compact_counter = [
torch.cat(r, dim=0) for r in zip(*results) torch.cat(r, dim=0) for r in zip(*results)
] ]
return ( return (
rgb.view((*rays_shape[:-1], -1)), rgb.view((*rays_shape[:-1], -1)),
depth.view((*rays_shape[:-1], -1)), depth.view((*rays_shape[:-1], -1)),
acc.view((*rays_shape[:-1], -1)), acc.view((*rays_shape[:-1], -1)),
alive_ray_mask.view(*rays_shape[:-1]),
counter.sum(), counter.sum(),
compact_counter.sum(), compact_counter.sum(),
) )
...@@ -192,7 +191,7 @@ if __name__ == "__main__": ...@@ -192,7 +191,7 @@ if __name__ == "__main__":
# update occupancy grid # update occupancy grid
occ_field.every_n_step(step) occ_field.every_n_step(step)
rgb, depth, acc, alive_ray_mask, counter, compact_counter = render_image( rgb, depth, acc, counter, compact_counter = render_image(
radiance_field, rays, render_bkgd, render_step_size radiance_field, rays, render_bkgd, render_step_size
) )
num_rays = len(pixels) num_rays = len(pixels)
...@@ -200,6 +199,7 @@ if __name__ == "__main__": ...@@ -200,6 +199,7 @@ if __name__ == "__main__":
num_rays * (TARGET_SAMPLE_BATCH_SIZE / float(compact_counter.item())) num_rays * (TARGET_SAMPLE_BATCH_SIZE / float(compact_counter.item()))
) )
train_dataset.update_num_rays(num_rays) train_dataset.update_num_rays(num_rays)
alive_ray_mask = acc.squeeze(-1) > 0
# compute loss # compute loss
loss = F.mse_loss(rgb[alive_ray_mask], pixels[alive_ray_mask]) loss = F.mse_loss(rgb[alive_ray_mask], pixels[alive_ray_mask])
...@@ -231,7 +231,7 @@ if __name__ == "__main__": ...@@ -231,7 +231,7 @@ if __name__ == "__main__":
pixels = data["pixels"].to(device) pixels = data["pixels"].to(device)
render_bkgd = data["color_bkgd"].to(device) render_bkgd = data["color_bkgd"].to(device)
# rendering # rendering
rgb, depth, acc, alive_ray_mask, _, _ = render_image( rgb, depth, acc, _, _ = render_image(
radiance_field, rays, render_bkgd, render_step_size radiance_field, rays, render_bkgd, render_step_size
) )
mse = F.mse_loss(rgb, pixels) mse = F.mse_loss(rgb, pixels)
......
from ._backend import _C from ._backend import _C
volumetric_marching = _C.volumetric_marching
volumetric_rendering_steps = _C.volumetric_rendering_steps
...@@ -53,8 +53,7 @@ __global__ void volumetric_rendering_weights_forward_kernel( ...@@ -53,8 +53,7 @@ __global__ void volumetric_rendering_weights_forward_kernel(
const scalar_t* sigmas, // input density after activation const scalar_t* sigmas, // input density after activation
// should be all-zero initialized // should be all-zero initialized
scalar_t* weights, // output scalar_t* weights, // output
int* samples_ray_ids, // output int* samples_ray_ids // output
bool* mask // output
) { ) {
CUDA_GET_THREAD_ID(i, n_rays); CUDA_GET_THREAD_ID(i, n_rays);
...@@ -68,7 +67,6 @@ __global__ void volumetric_rendering_weights_forward_kernel( ...@@ -68,7 +67,6 @@ __global__ void volumetric_rendering_weights_forward_kernel(
sigmas += base; sigmas += base;
weights += base; weights += base;
samples_ray_ids += base; samples_ray_ids += base;
mask += i;
for (int j = 0; j < steps; ++j) { for (int j = 0; j < steps; ++j) {
samples_ray_ids[j] = i; samples_ray_ids[j] = i;
...@@ -87,7 +85,6 @@ __global__ void volumetric_rendering_weights_forward_kernel( ...@@ -87,7 +85,6 @@ __global__ void volumetric_rendering_weights_forward_kernel(
weights[j] = weight; weights[j] = weight;
T *= (1.f - alpha); T *= (1.f - alpha);
} }
mask[0] = true;
} }
...@@ -167,7 +164,7 @@ std::vector<torch::Tensor> volumetric_rendering_steps( ...@@ -167,7 +164,7 @@ std::vector<torch::Tensor> volumetric_rendering_steps(
AT_DISPATCH_FLOATING_TYPES_AND_HALF( AT_DISPATCH_FLOATING_TYPES_AND_HALF(
sigmas.scalar_type(), sigmas.scalar_type(),
"volumetric_rendering_inference", "volumetric_marching_steps",
([&] ([&]
{ volumetric_rendering_steps_kernel<scalar_t><<<blocks, threads>>>( { volumetric_rendering_steps_kernel<scalar_t><<<blocks, threads>>>(
n_rays, n_rays,
...@@ -212,8 +209,6 @@ std::vector<torch::Tensor> volumetric_rendering_weights_forward( ...@@ -212,8 +209,6 @@ std::vector<torch::Tensor> volumetric_rendering_weights_forward(
// outputs // outputs
torch::Tensor weights = torch::zeros({n_samples}, sigmas.options()); torch::Tensor weights = torch::zeros({n_samples}, sigmas.options());
torch::Tensor ray_indices = torch::zeros({n_samples}, packed_info.options()); torch::Tensor ray_indices = torch::zeros({n_samples}, packed_info.options());
// The rays that are not skipped during sampling.
torch::Tensor mask = torch::zeros({n_rays}, sigmas.options().dtype(torch::kBool));
AT_DISPATCH_FLOATING_TYPES_AND_HALF( AT_DISPATCH_FLOATING_TYPES_AND_HALF(
sigmas.scalar_type(), sigmas.scalar_type(),
...@@ -226,12 +221,11 @@ std::vector<torch::Tensor> volumetric_rendering_weights_forward( ...@@ -226,12 +221,11 @@ std::vector<torch::Tensor> volumetric_rendering_weights_forward(
ends.data_ptr<scalar_t>(), ends.data_ptr<scalar_t>(),
sigmas.data_ptr<scalar_t>(), sigmas.data_ptr<scalar_t>(),
weights.data_ptr<scalar_t>(), weights.data_ptr<scalar_t>(),
ray_indices.data_ptr<int>(), ray_indices.data_ptr<int>()
mask.data_ptr<bool>()
); );
})); }));
return {weights, ray_indices, mask}; return {weights, ray_indices};
} }
......
...@@ -5,6 +5,7 @@ import torch ...@@ -5,6 +5,7 @@ import torch
from .cuda import _C from .cuda import _C
@torch.no_grad()
def ray_aabb_intersect( def ray_aabb_intersect(
rays_o: torch.Tensor, rays_d: torch.Tensor, aabb: torch.Tensor rays_o: torch.Tensor, rays_d: torch.Tensor, aabb: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
...@@ -32,11 +33,141 @@ def ray_aabb_intersect( ...@@ -32,11 +33,141 @@ def ray_aabb_intersect(
return t_min, t_max return t_min, t_max
@torch.no_grad()
def volumetric_marching(
    rays_o: torch.Tensor,
    rays_d: torch.Tensor,
    aabb: torch.Tensor,
    scene_occ_binary: torch.Tensor,
    t_min: torch.Tensor = None,
    t_max: torch.Tensor = None,
    render_step_size: float = 1e-3,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """March rays through the scene, skipping voxels marked unoccupied.

    Note: this function is not differentiable with respect to its inputs.

    Args:
        rays_o: Ray origins. Tensor with shape (n_rays, 3).
        rays_d: Normalized ray directions. Tensor with shape (n_rays, 3).
        aabb: Scene bounding box {xmin, ymin, zmin, xmax, ymax, zmax}.
            Tensor with shape (6,).
        scene_occ_binary: Scene occupancy binary field. BoolTensor with shape
            (resx, resy, resz).
        t_min: Optional ray near planes, shape (n_rays,). Computed with an
            aabb intersection test when not given. Default is None.
        t_max: Optional ray far planes, shape (n_rays,). Computed with an
            aabb intersection test when not given. Default is None.
        render_step_size: Marching step size. Default is 1e-3.

    Returns:
        packed_info: Per-ray (start index, number of samples) pairs mapping
            samples back to the ray they belong to. Shape (n_rays, 2).
        frustum_origins: Sampled frustum origins. Shape (n_samples, 3).
        frustum_dirs: Sampled frustum directions. Shape (n_samples, 3).
        frustum_starts: Sampled frustum starts. Shape (n_samples, 1).
        frustum_ends: Sampled frustum ends. Shape (n_samples, 1).
    """
    if not rays_o.is_cuda:
        raise NotImplementedError("Only support cuda inputs.")
    if t_min is None or t_max is None:
        # Derive missing near/far planes from a ray/aabb intersection test.
        t_min, t_max = ray_aabb_intersect(rays_o, rays_d, aabb)
    assert scene_occ_binary.dim() == 3, f"Shape {scene_occ_binary.shape} is not right!"
    results = _C.volumetric_marching(
        # rays
        rays_o.contiguous(),
        rays_d.contiguous(),
        t_min.contiguous(),
        t_max.contiguous(),
        # density grid
        aabb.contiguous(),
        list(scene_occ_binary.shape),
        scene_occ_binary.contiguous(),
        # sampling
        render_step_size,
    )
    packed_info, frustum_origins, frustum_dirs, frustum_starts, frustum_ends = results
    return (
        packed_info,
        frustum_origins,
        frustum_dirs,
        frustum_starts,
        frustum_ends,
    )
@torch.no_grad()
def volumetric_rendering_steps(
    packed_info: torch.Tensor,
    sigmas: torch.Tensor,
    frustum_starts: torch.Tensor,
    frustum_ends: torch.Tensor,
    *args,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Compact the per-ray samples by early-terminating the marching.

    Marching along a ray is stopped once the transmittance reaches 0.9999,
    so samples past that point are dropped. It is recommended to run this
    under `torch.no_grad()` before evaluating your network with gradients
    enabled, to quickly filter out unneeded samples.

    Note: this function is not differentiable with respect to its inputs.

    Args:
        packed_info: Stores information on which samples belong to the same
            ray. See `volumetric_marching` for details.
        sigmas: Densities at those samples. Tensor with shape (n_samples, 1).
        frustum_starts: Where the frustum-shape sample starts along a ray.
            Tensor with shape (n_samples, 1).
        frustum_ends: Where the frustum-shape sample ends along a ray.
            Tensor with shape (n_samples, 1).
        *args: Extra per-sample tensors compacted with the same selector.

    Returns:
        compact_packed_info: Compacted version of input `packed_info`.
        compact_frustum_starts: Compacted version of input `frustum_starts`.
        compact_frustum_ends: Compacted version of input `frustum_ends`.
        ... all the things in *args, compacted the same way.
    """
    all_on_cuda = (
        packed_info.is_cuda
        and frustum_starts.is_cuda
        and frustum_ends.is_cuda
        and sigmas.is_cuda
    )
    if not all_on_cuda:
        raise NotImplementedError("Only support cuda inputs.")
    packed_info = packed_info.contiguous()
    frustum_starts = frustum_starts.contiguous()
    frustum_ends = frustum_ends.contiguous()
    sigmas = sigmas.contiguous()
    compact_packed_info, compact_selector = _C.volumetric_rendering_steps(
        packed_info, frustum_starts, frustum_ends, sigmas
    )
    # Apply the selector to the frustum bounds and every extra tensor alike.
    compacted = [
        tensor[compact_selector] for tensor in (frustum_starts, frustum_ends, *args)
    ]
    return (compact_packed_info, *compacted)
def volumetric_rendering_weights( def volumetric_rendering_weights(
packed_info: torch.Tensor, packed_info: torch.Tensor,
t_starts: torch.Tensor,
t_ends: torch.Tensor,
sigmas: torch.Tensor, sigmas: torch.Tensor,
frustum_starts: torch.Tensor,
frustum_ends: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Compute weights for volumetric rendering. """Compute weights for volumetric rendering.
...@@ -44,35 +175,37 @@ def volumetric_rendering_weights( ...@@ -44,35 +175,37 @@ def volumetric_rendering_weights(
Args: Args:
packed_info: Stores infomation on which samples belong to the same ray. packed_info: Stores infomation on which samples belong to the same ray.
See `volumetric_sampling` for details. Tensor with shape (n_rays, 3). See `volumetric_marching` for details. Tensor with shape (n_rays, 3).
t_starts: Where the frustum-shape sample starts along a ray. Tensor with sigmas: Densities at those samples. Tensor with shape (n_samples, 1).
shape (n_samples, 1). frustum_starts: Where the frustum-shape sample starts along a ray. Tensor with
t_ends: Where the frustum-shape sample ends along a ray. Tensor with
shape (n_samples, 1). shape (n_samples, 1).
sigmas: Densities at those samples. Tensor with frustum_ends: Where the frustum-shape sample ends along a ray. Tensor with
shape (n_samples, 1). shape (n_samples, 1).
Returns: Returns:
weights: Volumetric rendering weights for those samples. Tensor with shape weights: Volumetric rendering weights for those samples. Tensor with shape
(n_samples). (n_samples).
ray_indices: Ray index of each sample. IntTensor with shape (n_sample). ray_indices: Ray index of each sample. IntTensor with shape (n_sample).
ray_alive_masks: Whether we skipped this ray during sampling. BoolTensor with
shape (n_rays)
""" """
if packed_info.is_cuda and t_starts.is_cuda and t_ends.is_cuda and sigmas.is_cuda: if (
packed_info.is_cuda
and frustum_starts.is_cuda
and frustum_ends.is_cuda
and sigmas.is_cuda
):
packed_info = packed_info.contiguous() packed_info = packed_info.contiguous()
t_starts = t_starts.contiguous() frustum_starts = frustum_starts.contiguous()
t_ends = t_ends.contiguous() frustum_ends = frustum_ends.contiguous()
sigmas = sigmas.contiguous() sigmas = sigmas.contiguous()
weights, ray_indices, ray_alive_masks = _volumetric_rendering_weights.apply( weights, ray_indices = _volumetric_rendering_weights.apply(
packed_info, t_starts, t_ends, sigmas packed_info, frustum_starts, frustum_ends, sigmas
) )
else: else:
raise NotImplementedError("Only support cuda inputs.") raise NotImplementedError("Only support cuda inputs.")
return weights, ray_indices, ray_alive_masks return weights, ray_indices
def volumetric_accumulate( def volumetric_rendering_accumulate(
weights: torch.Tensor, weights: torch.Tensor,
ray_indices: torch.Tensor, ray_indices: torch.Tensor,
values: torch.Tensor = None, values: torch.Tensor = None,
...@@ -122,29 +255,25 @@ def volumetric_accumulate( ...@@ -122,29 +255,25 @@ def volumetric_accumulate(
class _volumetric_rendering_weights(torch.autograd.Function): class _volumetric_rendering_weights(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, packed_info, t_starts, t_ends, sigmas): def forward(ctx, packed_info, frustum_starts, frustum_ends, sigmas):
( weights, ray_indices = _C.volumetric_rendering_weights_forward(
weights, packed_info, frustum_starts, frustum_ends, sigmas
ray_indices,
ray_alive_masks,
) = _C.volumetric_rendering_weights_forward(
packed_info, t_starts, t_ends, sigmas
) )
ctx.save_for_backward( ctx.save_for_backward(
packed_info, packed_info,
t_starts, frustum_starts,
t_ends, frustum_ends,
sigmas, sigmas,
weights, weights,
) )
return weights, ray_indices, ray_alive_masks return weights, ray_indices
@staticmethod @staticmethod
def backward(ctx, grad_weights, _grad_ray_indices, _grad_ray_alive_masks): def backward(ctx, grad_weights, _grad_ray_indices):
( (
packed_info, packed_info,
t_starts, frustum_starts,
t_ends, frustum_ends,
sigmas, sigmas,
weights, weights,
) = ctx.saved_tensors ) = ctx.saved_tensors
...@@ -152,8 +281,8 @@ class _volumetric_rendering_weights(torch.autograd.Function): ...@@ -152,8 +281,8 @@ class _volumetric_rendering_weights(torch.autograd.Function):
weights, weights,
grad_weights, grad_weights,
packed_info, packed_info,
t_starts, frustum_starts,
t_ends, frustum_ends,
sigmas, sigmas,
) )
return None, None, None, grad_sigmas return None, None, None, grad_sigmas
...@@ -3,13 +3,10 @@ from typing import Callable, Tuple ...@@ -3,13 +3,10 @@ from typing import Callable, Tuple
import torch import torch
from .cuda import ( # ComputeWeight,; VolumeRenderer,; ray_aabb_intersect, from .utils import (
volumetric_marching, volumetric_marching,
volumetric_rendering_accumulate,
volumetric_rendering_steps, volumetric_rendering_steps,
)
from .utils import (
ray_aabb_intersect,
volumetric_accumulate,
volumetric_rendering_weights, volumetric_rendering_weights,
) )
...@@ -47,8 +44,6 @@ def volumetric_rendering( ...@@ -47,8 +44,6 @@ def volumetric_rendering(
# get packed samples from ray marching & occupancy check. # get packed samples from ray marching & occupancy check.
with torch.no_grad(): with torch.no_grad():
t_min, t_max = ray_aabb_intersect(rays_o, rays_d, scene_aabb)
( (
packed_info, packed_info,
frustum_origins, frustum_origins,
...@@ -59,16 +54,12 @@ def volumetric_rendering( ...@@ -59,16 +54,12 @@ def volumetric_rendering(
# rays # rays
rays_o, rays_o,
rays_d, rays_d,
t_min,
t_max,
# density grid # density grid
scene_aabb, aabb=scene_aabb,
scene_resolution, scene_occ_binary=scene_occ_binary.reshape(scene_resolution),
scene_occ_binary,
# sampling # sampling
render_step_size, render_step_size=render_step_size,
) )
frustum_positions = ( frustum_positions = (
frustum_origins + frustum_dirs * (frustum_starts + frustum_ends) / 2.0 frustum_origins + frustum_dirs * (frustum_starts + frustum_ends) / 2.0
) )
...@@ -79,16 +70,26 @@ def volumetric_rendering( ...@@ -79,16 +70,26 @@ def volumetric_rendering(
densities = query_fn( densities = query_fn(
frustum_positions, frustum_dirs, only_density=True, **kwargs frustum_positions, frustum_dirs, only_density=True, **kwargs
) )
compact_packed_info, compact_selector = volumetric_rendering_steps( (
packed_info.contiguous(), compact_packed_info,
frustum_starts.contiguous(), compact_frustum_starts,
frustum_ends.contiguous(), compact_frustum_ends,
densities.contiguous(), compact_frustum_positions,
compact_frustum_dirs,
) = volumetric_rendering_steps(
packed_info,
densities,
frustum_starts,
frustum_ends,
frustum_positions,
frustum_dirs,
) )
compact_frustum_positions = frustum_positions[compact_selector] # compact_frustum_positions = (
compact_frustum_dirs = frustum_dirs[compact_selector] # compact_frustum_origins
compact_frustum_starts = frustum_starts[compact_selector] # + compact_frustum_dirs
compact_frustum_ends = frustum_ends[compact_selector] # * (compact_frustum_starts + compact_frustum_ends)
# / 2.0
# )
compact_steps_counter = compact_packed_info[:, -1].sum(0, keepdim=True) compact_steps_counter = compact_packed_info[:, -1].sum(0, keepdim=True)
# network # network
...@@ -98,33 +99,31 @@ def volumetric_rendering( ...@@ -98,33 +99,31 @@ def volumetric_rendering(
compact_rgbs, compact_densities = compact_query_results[0], compact_query_results[1] compact_rgbs, compact_densities = compact_query_results[0], compact_query_results[1]
# accumulation # accumulation
compact_weights, compact_ray_indices, alive_ray_mask = volumetric_rendering_weights( compact_weights, compact_ray_indices = volumetric_rendering_weights(
compact_packed_info, compact_packed_info,
compact_densities,
compact_frustum_starts, compact_frustum_starts,
compact_frustum_ends, compact_frustum_ends,
compact_densities,
) )
accumulated_color = volumetric_accumulate( accumulated_color = volumetric_rendering_accumulate(
compact_weights, compact_ray_indices, compact_rgbs, n_rays compact_weights, compact_ray_indices, compact_rgbs, n_rays
) )
accumulated_weight = volumetric_accumulate( accumulated_weight = volumetric_rendering_accumulate(
compact_weights, compact_ray_indices, None, n_rays compact_weights, compact_ray_indices, None, n_rays
) )
accumulated_depth = volumetric_accumulate( accumulated_depth = volumetric_rendering_accumulate(
compact_weights, compact_weights,
compact_ray_indices, compact_ray_indices,
(compact_frustum_starts + compact_frustum_ends) / 2.0, (compact_frustum_starts + compact_frustum_ends) / 2.0,
n_rays, n_rays,
) )
accumulated_depth = torch.clip(accumulated_depth, t_min[:, None], t_max[:, None])
accumulated_color = accumulated_color + render_bkgd * (1.0 - accumulated_weight) accumulated_color = accumulated_color + render_bkgd * (1.0 - accumulated_weight)
return ( return (
accumulated_color, accumulated_color,
accumulated_depth, accumulated_depth,
accumulated_weight, accumulated_weight,
alive_ray_mask,
steps_counter, steps_counter,
compact_steps_counter, compact_steps_counter,
) )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment