Commit 5b6f0c61 authored by Ruilong Li

remove alive_mask

parent 224e217d
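
The change in a nutshell: the CUDA kernels no longer emit a per-ray alive mask; the training script instead derives it from the accumulated opacity `acc` returned by rendering. A minimal sketch of the recovered pattern (not part of the commit; the tensor values are made up):

import torch

# acc holds the per-ray accumulated opacity, shape (n_rays, 1).
acc = torch.tensor([[0.0], [0.73], [0.99]])
# A ray is "alive" iff at least one of its samples carried nonzero weight.
alive_ray_mask = acc.squeeze(-1) > 0
print(alive_ray_mask)  # tensor([False,  True,  True])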
......@@ -73,14 +73,13 @@ def render_image(radiance_field, rays, render_bkgd, render_step_size):
render_step_size=render_step_size,
)
results.append(chunk_results)
- rgb, depth, acc, alive_ray_mask, counter, compact_counter = [
+ rgb, depth, acc, counter, compact_counter = [
torch.cat(r, dim=0) for r in zip(*results)
]
return (
rgb.view((*rays_shape[:-1], -1)),
depth.view((*rays_shape[:-1], -1)),
acc.view((*rays_shape[:-1], -1)),
- alive_ray_mask.view(*rays_shape[:-1]),
counter.sum(),
compact_counter.sum(),
)
......@@ -192,7 +191,7 @@ if __name__ == "__main__":
# update occupancy grid
occ_field.every_n_step(step)
- rgb, depth, acc, alive_ray_mask, counter, compact_counter = render_image(
+ rgb, depth, acc, counter, compact_counter = render_image(
radiance_field, rays, render_bkgd, render_step_size
)
num_rays = len(pixels)
......@@ -200,6 +199,7 @@ if __name__ == "__main__":
num_rays * (TARGET_SAMPLE_BATCH_SIZE / float(compact_counter.item()))
)
train_dataset.update_num_rays(num_rays)
+ alive_ray_mask = acc.squeeze(-1) > 0
# compute loss
loss = F.mse_loss(rgb[alive_ray_mask], pixels[alive_ray_mask])
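
For context on the dynamic batch sizing above: the script rescales the ray count so that each batch yields roughly TARGET_SAMPLE_BATCH_SIZE compacted samples. A worked sketch with made-up numbers:

TARGET_SAMPLE_BATCH_SIZE = 2 ** 16  # desired compacted samples per batch
num_rays = 4096
compact_counter = 2 ** 18  # compacted samples produced by the last batch
num_rays = int(num_rays * (TARGET_SAMPLE_BATCH_SIZE / float(compact_counter)))
print(num_rays)  # 1024: four times fewer rays, since we overshot by 4x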
......@@ -231,7 +231,7 @@ if __name__ == "__main__":
pixels = data["pixels"].to(device)
render_bkgd = data["color_bkgd"].to(device)
# rendering
- rgb, depth, acc, alive_ray_mask, _, _ = render_image(
+ rgb, depth, acc, _, _ = render_image(
radiance_field, rays, render_bkgd, render_step_size
)
mse = F.mse_loss(rgb, pixels)
......
from ._backend import _C
volumetric_marching = _C.volumetric_marching
volumetric_rendering_steps = _C.volumetric_rendering_steps
......@@ -53,8 +53,7 @@ __global__ void volumetric_rendering_weights_forward_kernel(
const scalar_t* sigmas, // input density after activation
// should be all-zero initialized
scalar_t* weights, // output
- int* samples_ray_ids, // output
- bool* mask // output
+ int* samples_ray_ids // output
) {
CUDA_GET_THREAD_ID(i, n_rays);
......@@ -68,7 +67,6 @@ __global__ void volumetric_rendering_weights_forward_kernel(
sigmas += base;
weights += base;
samples_ray_ids += base;
- mask += i;
for (int j = 0; j < steps; ++j) {
samples_ray_ids[j] = i;
......@@ -87,7 +85,6 @@ __global__ void volumetric_rendering_weights_forward_kernel(
weights[j] = weight;
T *= (1.f - alpha);
}
- mask[0] = true;
}
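
For reference, each thread of the kernel above marches the samples of one ray and composites weights front to back; removing `mask[0] = true` drops only the side output, not this recurrence. A Python sketch of the per-ray computation (assuming the standard alpha = 1 - exp(-sigma * dt) formulation computed in the elided part of the kernel; names are illustrative):

import torch

def rendering_weights_one_ray(sigmas, starts, ends):
    # weights[j] = T_j * alpha_j, with T_j the transmittance accumulated over
    # samples 0..j-1, matching `weights[j] = weight; T *= (1.f - alpha);` above.
    T = 1.0
    weights = torch.zeros_like(sigmas)
    for j in range(sigmas.shape[0]):
        alpha = 1.0 - torch.exp(-sigmas[j] * (ends[j] - starts[j]))
        weights[j] = T * alpha
        T = T * (1.0 - alpha)
    return weights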
......@@ -167,7 +164,7 @@ std::vector<torch::Tensor> volumetric_rendering_steps(
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
sigmas.scalar_type(),
"volumetric_rendering_inference",
"volumetric_marching_steps",
([&]
{ volumetric_rendering_steps_kernel<scalar_t><<<blocks, threads>>>(
n_rays,
......@@ -212,8 +209,6 @@ std::vector<torch::Tensor> volumetric_rendering_weights_forward(
// outputs
torch::Tensor weights = torch::zeros({n_samples}, sigmas.options());
torch::Tensor ray_indices = torch::zeros({n_samples}, packed_info.options());
- // The rays that are not skipped during sampling.
- torch::Tensor mask = torch::zeros({n_rays}, sigmas.options().dtype(torch::kBool));
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
sigmas.scalar_type(),
......@@ -226,12 +221,11 @@ std::vector<torch::Tensor> volumetric_rendering_weights_forward(
ends.data_ptr<scalar_t>(),
sigmas.data_ptr<scalar_t>(),
weights.data_ptr<scalar_t>(),
- ray_indices.data_ptr<int>(),
- mask.data_ptr<bool>()
+ ray_indices.data_ptr<int>()
);
}));
- return {weights, ray_indices, mask};
+ return {weights, ray_indices};
}
......
......@@ -5,6 +5,7 @@ import torch
from .cuda import _C
+ @torch.no_grad()
def ray_aabb_intersect(
rays_o: torch.Tensor, rays_d: torch.Tensor, aabb: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
......@@ -32,11 +33,141 @@ def ray_aabb_intersect(
return t_min, t_max
@torch.no_grad()
def volumetric_marching(
rays_o: torch.Tensor,
rays_d: torch.Tensor,
aabb: torch.Tensor,
scene_occ_binary: torch.Tensor,
t_min: torch.Tensor = None,
t_max: torch.Tensor = None,
render_step_size: float = 1e-3,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Volumetric marching with occupancy test.
Note: this function is not differentiable to inputs.
Args:
rays_o: Ray origins. Tensor with shape (n_rays, 3).
rays_d: Normalized ray directions. Tensor with shape (n_rays, 3).
aabb: Scene bounding box {xmin, ymin, zmin, xmax, ymax, zmax}.
Tensor with shape (6)
scene_occ_binary: Scene occupancy binary field. BoolTensor with shape
(resx, resy, resz)
t_min: Optional. Ray near planes. Tensor with shape (n_rays,).
If not given it will be computed using the aabb test. Default is None.
t_max: Optional. Ray far planes. Tensor with shape (n_rays,).
If not given it will be computed using the aabb test. Default is None.
render_step_size: Marching step size. Default is 1e-3.
Returns:
packed_info: Stores information about which samples belong to which ray.
It is a tensor with shape (n_rays, 2). For each ray, the two values
indicate the start index and the number of samples for this ray,
respectively.
frustum_origins: Sampled frustum origins. Tensor with shape (n_samples, 3).
frustum_dirs: Sampled frustum directions. Tensor with shape (n_samples, 3).
frustum_starts: Sampled frustum starts. Tensor with shape (n_samples, 1).
frustum_ends: Sampled frustum ends. Tensor with shape (n_samples, 1).
"""
if not rays_o.is_cuda:
raise NotImplementedError("Only supports cuda inputs.")
if t_min is None or t_max is None:
t_min, t_max = ray_aabb_intersect(rays_o, rays_d, aabb)
assert scene_occ_binary.dim() == 3, f"Expected a 3D occupancy grid, got shape {scene_occ_binary.shape}!"
(
packed_info,
frustum_origins,
frustum_dirs,
frustum_starts,
frustum_ends,
) = _C.volumetric_marching(
# rays
rays_o.contiguous(),
rays_d.contiguous(),
t_min.contiguous(),
t_max.contiguous(),
# density grid
aabb.contiguous(),
list(scene_occ_binary.shape),
scene_occ_binary.contiguous(),
# sampling
render_step_size,
)
return (
packed_info,
frustum_origins,
frustum_dirs,
frustum_starts,
frustum_ends,
)
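
A minimal usage sketch for the wrapper above (assuming `volumetric_marching` from this module is in scope; the ray and grid contents are made up):

import torch
import torch.nn.functional as F

device = "cuda"
rays_o = torch.rand((1024, 3), device=device)
rays_d = F.normalize(torch.rand((1024, 3), device=device), dim=-1)
aabb = torch.tensor([0.0, 0.0, 0.0, 1.0, 1.0, 1.0], device=device)
occ = torch.ones((128, 128, 128), dtype=torch.bool, device=device)  # fully occupied

packed_info, origins, dirs, starts, ends = volumetric_marching(
    rays_o, rays_d, aabb, scene_occ_binary=occ, render_step_size=5e-3
)
# packed_info[i] = (start index, number of samples) for ray i.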
@torch.no_grad()
def volumetric_rendering_steps(
packed_info: torch.Tensor,
sigmas: torch.Tensor,
frustum_starts: torch.Tensor,
frustum_ends: torch.Tensor,
*args,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Compute rendering marching steps.
This function compacts the samples by terminating the marching once the
accumulated opacity reaches 0.9999 (i.e., the transmittance drops below 1e-4).
It is recommended that, before running your network with gradients enabled,
you first run this function without gradients (`torch.no_grad()`) to quickly
filter out some samples.
Note: this function is not differentiable with respect to its inputs.
Args:
packed_info: Stores information about which samples belong to which ray.
See `volumetric_marching` for details. Tensor with shape (n_rays, 2).
sigmas: Densities at those samples. Tensor with shape (n_samples, 1).
frustum_starts: Where the frustum-shape sample starts along a ray. Tensor with
shape (n_samples, 1).
frustum_ends: Where the frustum-shape sample ends along a ray. Tensor with
shape (n_samples, 1).
Returns:
compact_packed_info: Compacted version of input `packed_info`.
compact_frustum_starts: Compacted version of input `frustum_starts`.
compact_frustum_ends: Compacted version of input `frustum_ends`.
...: compacted versions of all extra tensors passed in *args.
"""
if (
packed_info.is_cuda
and frustum_starts.is_cuda
and frustum_ends.is_cuda
and sigmas.is_cuda
):
packed_info = packed_info.contiguous()
frustum_starts = frustum_starts.contiguous()
frustum_ends = frustum_ends.contiguous()
sigmas = sigmas.contiguous()
compact_packed_info, compact_selector = _C.volumetric_rendering_steps(
packed_info, frustum_starts, frustum_ends, sigmas
)
compact_frustum_starts = frustum_starts[compact_selector]
compact_frustum_ends = frustum_ends[compact_selector]
extras = (arg[compact_selector] for arg in args)
else:
raise NotImplementedError("Only supports cuda inputs.")
return (
compact_packed_info,
compact_frustum_starts,
compact_frustum_ends,
*extras,
)
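
The pattern the docstring recommends, sketched by continuing the marching example above (`query_density` stands in for a density-only evaluation of your radiance field):

with torch.no_grad():
    positions = origins + dirs * (starts + ends) / 2.0  # frustum midpoints
    sigmas = query_density(positions)
    # Compact the samples; extra tensors ride along through *args.
    packed_info, starts, ends, origins, dirs = volumetric_rendering_steps(
        packed_info, sigmas, starts, ends, origins, dirs
    )
# Now run the full network with gradients on the surviving samples only.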
def volumetric_rendering_weights(
packed_info: torch.Tensor,
t_starts: torch.Tensor,
t_ends: torch.Tensor,
sigmas: torch.Tensor,
frustum_starts: torch.Tensor,
frustum_ends: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute weights for volumetric rendering.
......@@ -44,35 +175,37 @@ def volumetric_rendering_weights(
Args:
packed_info: Stores information about which samples belong to which ray.
- See `volumetric_sampling` for details. Tensor with shape (n_rays, 3).
- t_starts: Where the frustum-shape sample starts along a ray. Tensor with
- shape (n_samples, 1).
- t_ends: Where the frustum-shape sample ends along a ray. Tensor with
- shape (n_samples, 1).
- sigmas: Densities at those samples. Tensor with
- shape (n_samples, 1).
+ See `volumetric_marching` for details. Tensor with shape (n_rays, 2).
+ sigmas: Densities at those samples. Tensor with shape (n_samples, 1).
+ frustum_starts: Where the frustum-shape sample starts along a ray. Tensor with
+ shape (n_samples, 1).
+ frustum_ends: Where the frustum-shape sample ends along a ray. Tensor with
+ shape (n_samples, 1).
Returns:
weights: Volumetric rendering weights for those samples. Tensor with shape
(n_samples).
ray_indices: Ray index of each sample. IntTensor with shape (n_samples,).
- ray_alive_masks: Whether we skipped this ray during sampling. BoolTensor with
- shape (n_rays)
"""
- if packed_info.is_cuda and t_starts.is_cuda and t_ends.is_cuda and sigmas.is_cuda:
+ if (
+ packed_info.is_cuda
+ and frustum_starts.is_cuda
+ and frustum_ends.is_cuda
+ and sigmas.is_cuda
+ ):
packed_info = packed_info.contiguous()
- t_starts = t_starts.contiguous()
- t_ends = t_ends.contiguous()
+ frustum_starts = frustum_starts.contiguous()
+ frustum_ends = frustum_ends.contiguous()
sigmas = sigmas.contiguous()
- weights, ray_indices, ray_alive_masks = _volumetric_rendering_weights.apply(
- packed_info, t_starts, t_ends, sigmas
+ weights, ray_indices = _volumetric_rendering_weights.apply(
+ packed_info, frustum_starts, frustum_ends, sigmas
)
else:
raise NotImplementedError("Only supports cuda inputs.")
- return weights, ray_indices, ray_alive_masks
+ return weights, ray_indices
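
How the public pieces compose after this commit (a sketch; `rgbs` is your per-sample network output and `n_rays` the ray count, both assumed from context):

weights, ray_indices = volumetric_rendering_weights(
    packed_info, sigmas, frustum_starts, frustum_ends
)
colors = volumetric_rendering_accumulate(weights, ray_indices, rgbs, n_rays)
opacities = volumetric_rendering_accumulate(weights, ray_indices, None, n_rays)
# The removed alive mask is recovered from the accumulated opacity:
alive_ray_mask = opacities.squeeze(-1) > 0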
- def volumetric_accumulate(
+ def volumetric_rendering_accumulate(
weights: torch.Tensor,
ray_indices: torch.Tensor,
values: torch.Tensor = None,
......@@ -122,29 +255,25 @@ def volumetric_accumulate(
class _volumetric_rendering_weights(torch.autograd.Function):
@staticmethod
- def forward(ctx, packed_info, t_starts, t_ends, sigmas):
- (
- weights,
- ray_indices,
- ray_alive_masks,
- ) = _C.volumetric_rendering_weights_forward(
- packed_info, t_starts, t_ends, sigmas
+ def forward(ctx, packed_info, frustum_starts, frustum_ends, sigmas):
+ weights, ray_indices = _C.volumetric_rendering_weights_forward(
+ packed_info, frustum_starts, frustum_ends, sigmas
)
ctx.save_for_backward(
packed_info,
- t_starts,
- t_ends,
+ frustum_starts,
+ frustum_ends,
sigmas,
weights,
)
- return weights, ray_indices, ray_alive_masks
+ return weights, ray_indices
@staticmethod
- def backward(ctx, grad_weights, _grad_ray_indices, _grad_ray_alive_masks):
+ def backward(ctx, grad_weights, _grad_ray_indices):
(
packed_info,
- t_starts,
- t_ends,
+ frustum_starts,
+ frustum_ends,
sigmas,
weights,
) = ctx.saved_tensors
......@@ -152,8 +281,8 @@ class _volumetric_rendering_weights(torch.autograd.Function):
weights,
grad_weights,
packed_info,
- t_starts,
- t_ends,
+ frustum_starts,
+ frustum_ends,
sigmas,
)
return None, None, None, grad_sigmas
......@@ -3,13 +3,10 @@ from typing import Callable, Tuple
import torch
- from .cuda import ( # ComputeWeight,; VolumeRenderer,; ray_aabb_intersect,
+ from .utils import (
volumetric_marching,
+ volumetric_rendering_accumulate,
volumetric_rendering_steps,
- )
- from .utils import (
- ray_aabb_intersect,
- volumetric_accumulate,
volumetric_rendering_weights,
)
......@@ -47,8 +44,6 @@ def volumetric_rendering(
# get packed samples from ray marching & occupancy check.
with torch.no_grad():
- t_min, t_max = ray_aabb_intersect(rays_o, rays_d, scene_aabb)
(
packed_info,
frustum_origins,
......@@ -59,16 +54,12 @@ def volumetric_rendering(
# rays
rays_o,
rays_d,
- t_min,
- t_max,
# density grid
- scene_aabb,
- scene_resolution,
- scene_occ_binary,
+ aabb=scene_aabb,
+ scene_occ_binary=scene_occ_binary.reshape(scene_resolution),
# sampling
- render_step_size,
+ render_step_size=render_step_size,
)
frustum_positions = (
frustum_origins + frustum_dirs * (frustum_starts + frustum_ends) / 2.0
)
......@@ -79,16 +70,26 @@ def volumetric_rendering(
densities = query_fn(
frustum_positions, frustum_dirs, only_density=True, **kwargs
)
- compact_packed_info, compact_selector = volumetric_rendering_steps(
- packed_info.contiguous(),
- frustum_starts.contiguous(),
- frustum_ends.contiguous(),
- densities.contiguous(),
+ (
+ compact_packed_info,
+ compact_frustum_starts,
+ compact_frustum_ends,
+ compact_frustum_positions,
+ compact_frustum_dirs,
+ ) = volumetric_rendering_steps(
+ packed_info,
+ densities,
+ frustum_starts,
+ frustum_ends,
+ frustum_positions,
+ frustum_dirs,
)
- compact_frustum_positions = frustum_positions[compact_selector]
- compact_frustum_dirs = frustum_dirs[compact_selector]
- compact_frustum_starts = frustum_starts[compact_selector]
- compact_frustum_ends = frustum_ends[compact_selector]
+ # compact_frustum_positions = (
+ #     compact_frustum_origins
+ #     + compact_frustum_dirs
+ #     * (compact_frustum_starts + compact_frustum_ends)
+ #     / 2.0
+ # )
compact_steps_counter = compact_packed_info[:, -1].sum(0, keepdim=True)
# network
......@@ -98,33 +99,31 @@ def volumetric_rendering(
compact_rgbs, compact_densities = compact_query_results[0], compact_query_results[1]
# accumulation
- compact_weights, compact_ray_indices, alive_ray_mask = volumetric_rendering_weights(
+ compact_weights, compact_ray_indices = volumetric_rendering_weights(
compact_packed_info,
+ compact_densities,
compact_frustum_starts,
compact_frustum_ends,
- compact_densities,
)
- accumulated_color = volumetric_accumulate(
+ accumulated_color = volumetric_rendering_accumulate(
compact_weights, compact_ray_indices, compact_rgbs, n_rays
)
- accumulated_weight = volumetric_accumulate(
+ accumulated_weight = volumetric_rendering_accumulate(
compact_weights, compact_ray_indices, None, n_rays
)
- accumulated_depth = volumetric_accumulate(
+ accumulated_depth = volumetric_rendering_accumulate(
compact_weights,
compact_ray_indices,
(compact_frustum_starts + compact_frustum_ends) / 2.0,
n_rays,
)
- accumulated_depth = torch.clip(accumulated_depth, t_min[:, None], t_max[:, None])
accumulated_color = accumulated_color + render_bkgd * (1.0 - accumulated_weight)
return (
accumulated_color,
accumulated_depth,
accumulated_weight,
- alive_ray_mask,
steps_counter,
compact_steps_counter,
)